71

With python lists, we can do:

a = [1, 2, 3]
assert a.index(2) == 1

Is there a direct equivalent of .index() for a PyTorch tensor?

11 Answers

100

I think there is no direct translation from list.index() to a pytorch function. However, you can achieve similar results using tensor==number and then the nonzero() function. For example:

import torch

t = torch.Tensor([1, 2, 3])
print((t == 2).nonzero(as_tuple=True)[0])

This piece of code prints

tensor([1])
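The same == plus nonzero() idea can be wrapped into a list.index() lookalike. This is a minimal sketch, assuming a 1-D tensor; tensor_index is a hypothetical helper name, not a PyTorch API. It returns a plain Python int for the first match and raises ValueError when nothing matches, to mirror list.index:

```python
import torch


def tensor_index(t: torch.Tensor, value) -> int:
    """Hypothetical helper mimicking list.index() for a 1-D tensor.

    Returns the position of the first element equal to value,
    or raises ValueError when there is no match (like list.index).
    """
    if t.dim() != 1:
        raise ValueError("tensor_index expects a 1-D tensor")
    matches = (t == value).nonzero(as_tuple=True)[0]  # positions of all matches
    if matches.numel() == 0:
        raise ValueError(f"{value} is not in tensor")
    return int(matches[0])  # first occurrence, as a plain Python int


a = torch.tensor([1, 2, 3])
print(tensor_index(a, 2))  # 1
```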


4 Comments

what would happen if there is no match? how would you handle that?
@CharlieParker If there is no match, it returns an empty tensor: tensor([], dtype=torch.int64)
How can we extend this to get batch indices? In this case, I'd want indices of values torch.Tensor([1, 2, 3]) all at once, not just 2. Is there a way without for-loops?
@KaranShah, sure, you can use pytorch.org/docs/stable/generated/torch.eq.html for multidirectional elementwise equal.
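A sketch of the broadcasting approach the comment points to, assuming a 1-D tensor t and a 1-D tensor of query values (both names are illustrative). Comparing t[:, None] against values[None, :] builds a match matrix in one call, with no Python loop:

```python
import torch

t = torch.tensor([1, 2, 3, 2])
values = torch.tensor([2, 3])

# t.unsqueeze(1) has shape (4, 1) and values.unsqueeze(0) has shape (1, 2),
# so the broadcast comparison has shape (4, 2): hits[i, j] is t[i] == values[j]
hits = torch.eq(t.unsqueeze(1), values.unsqueeze(0))

# Each row of pairs is (position in t, position in values) for one match
pairs = hits.nonzero()
print(pairs)
# tensor([[1, 0],
#         [2, 1],
#         [3, 0]])

# First occurrence of each value: argmax returns the first maximal index.
# argmax also returns 0 for values that never occur, so check found as well.
first = hits.int().argmax(dim=0)
found = hits.any(dim=0)
print(first, found)  # tensor([1, 2]) tensor([True, True])
```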
27

For multidimensional tensors you can do:

(tensor == target_value).nonzero(as_tuple=False)

The resulting tensor has shape number_of_matches x tensor_dimension. For example, if tensor is a 3 x 4 tensor (i.e. it has 2 dimensions), the result is a 2-D tensor in which each row holds the (row, column) index of one match.

>>> import torch
>>> tensor = torch.Tensor([[1, 2, 2, 7], [3, 1, 2, 4], [3, 1, 9, 4]])
>>> (tensor == 2).nonzero(as_tuple=False)
tensor([[0, 1],
        [0, 2],
        [1, 2]])
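For comparison, a short sketch of the as_tuple=True form on the same data: instead of one row per match, it returns one 1-D index tensor per dimension, and that tuple can be fed straight back into advanced indexing:

```python
import torch

tensor = torch.tensor([[1, 2, 2, 7], [3, 1, 2, 4], [3, 1, 9, 4]])

# as_tuple=True gives one 1-D index tensor per dimension: (rows, cols)
rows, cols = (tensor == 2).nonzero(as_tuple=True)
print(rows)  # tensor([0, 0, 1])
print(cols)  # tensor([1, 2, 2])

# The tuple form plugs directly into advanced indexing
print(tensor[rows, cols])  # tensor([2, 2, 2])
```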

3 Comments

Most complete answer! general beyond just flat tensors.
what would happen if there is no match? how would you handle that?
In that case you'd get an empty tensor. Its shape still reflects the number of dimensions of the input, so in the example above, searching for 8 would give an (empty) tensor of shape 0 x 2.
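A small sketch of the no-match case described in the comment, with one possible guard (checking the number of rows before picking a match):

```python
import torch

tensor = torch.tensor([[1, 2, 2, 7], [3, 1, 2, 4], [3, 1, 9, 4]])

matches = (tensor == 8).nonzero()
print(matches.shape)  # torch.Size([0, 2])

# Guard on the number of rows before picking a match
if matches.size(0) == 0:
    print("8 not found")
else:
    print("first match at", matches[0].tolist())
```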
4
import torch

x = torch.Tensor([11, 22, 33, 22])
print((x == 22).nonzero().squeeze())

tensor([1, 3])
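One caveat with squeeze(), shown as a short sketch: when there is exactly one match, squeeze() removes every size-1 dimension and returns a 0-d tensor, so code that expects a 1-D result may break. flatten() keeps the result 1-D in every case:

```python
import torch

x = torch.tensor([11, 22, 33, 22])
y = torch.tensor([11, 22, 33])

# A single match: squeeze() turns the (1, 1) result into a 0-d tensor
print((y == 22).nonzero().squeeze())  # tensor(1)

# flatten() always returns a 1-D tensor of positions
print((y == 22).nonzero().flatten())  # tensor([1])
print((x == 22).nonzero().flatten())  # tensor([1, 3])
```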


3
>>> import torch
>>> a = torch.tensor([1, 2, 3])
>>> torch.where(a == 2)[0]
tensor([1])
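With a single argument, torch.where(condition) is documented as identical to torch.nonzero(condition, as_tuple=True), so it also works for multidimensional tensors. A quick sketch checking that equivalence on a 2-D example:

```python
import torch

a = torch.tensor([[1, 2], [2, 3]])

# One-argument torch.where returns one index tensor per dimension
w = torch.where(a == 2)
n = (a == 2).nonzero(as_tuple=True)
print(w)  # (tensor([0, 1]), tensor([1, 0]))

# Both forms yield the same (rows, cols) tuple
assert all(torch.equal(p, q) for p, q in zip(w, n))
```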


2

Based on others' answers:

import torch

t = torch.Tensor([1, 2, 3])
print((t == 1).nonzero().item())  # prints 0

1 Comment

This is okay only if the tensor contains exactly one occurrence of the intended number, because the .item() method can only be called on a one-element tensor; with zero or several matches it raises an error.
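A minimal sketch that keeps the .item() style but avoids the error described in the comment: take only the first row of the nonzero() result, and return -1 when there is no match. first_index is a hypothetical helper name, and returning -1 is just one possible convention:

```python
import torch


def first_index(t: torch.Tensor, value) -> int:
    """Hypothetical helper: position of the first match in a 1-D tensor, or -1."""
    idx = (t == value).nonzero()  # shape [number_of_matches, 1]
    if idx.numel() == 0:
        return -1  # no match: calling .item() here would raise
    return idx[0].item()  # first row only, so .item() sees exactly one element


t = torch.tensor([1, 2, 1, 2])
print(first_index(t, 2))  # 1
print(first_index(t, 5))  # -1
```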
2

The answers already given are great, but they don't handle the case where there is no match. For that, see this:

from typing import Union

from torch import Tensor


def index(tensor: Tensor, value, ith_match: int = 0) -> Union[Tensor, int]:
    """
    Returns the generalized index (i.e. location/coordinate) of an occurrence of value
    in tensor. For flat tensors (i.e. arrays/lists) this is a one-element tensor holding
    the position; otherwise it is a coordinate with one entry per dimension.
    If there are multiple occurrences, choose which one you want with ith_match,
    e.g. ith_match=0 gives the first occurrence.

    Reference: https://stackoverflow.com/a/67175757/1601580
    :return: coordinate tensor of the chosen match, or -1 if value does not occur
    """
    # get matches as a "coordinate list" of where the value occurred
    matches = (tensor == value).nonzero()  # [number_of_matches, tensor_dimension]
    if matches.size(0) == 0:  # no matches
        return -1
    else:
        # get index/coordinate of the occurrence you want (e.g. 1st occurrence: ith_match=0)
        index = matches[ith_match]
        return index

Credit to this great answer: https://stackoverflow.com/a/67175757/1601580


1

In my opinion, calling tolist() is simple and easy to understand.

import torch

t = torch.Tensor([1, 2, 3])
t.tolist().index(2)  # -> 1
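A short sketch of what tolist() hands back and how the no-match case behaves. The list holds plain Python floats for a float tensor, so the lookup relies on 2 == 2.0 in Python, and a missing value raises ValueError exactly as list.index does:

```python
import torch

t = torch.tensor([1.0, 2.0, 3.0])
values = t.tolist()  # [1.0, 2.0, 3.0], plain Python floats

print(values.index(2))  # 1, because 2 == 2.0 in Python

# Like any list.index call, a missing value raises ValueError
try:
    values.index(5)
except ValueError:
    print("5 not found")
```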


0

Can be done by converting to numpy as follows:

import numpy as np
import torch
x = torch.arange(1, 5, dtype=torch.float32)  # torch.range is deprecated
print(x)
===> tensor([1., 2., 3., 4.])
nx = x.numpy()
np.where(nx == 3)[0][0]
===> 2
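A sketch addressing both comments on this answer: .numpy() refuses tensors that require grad, so the tensor is detached first (the array is then outside the autograd graph), and an empty result from np.flatnonzero signals no match. Returning -1 in that case is an illustrative convention, not something the answer specifies:

```python
import numpy as np
import torch

x = torch.arange(1, 5, dtype=torch.float32, requires_grad=True)

# .numpy() raises on tensors that require grad, so detach (and move to CPU) first
nx = x.detach().cpu().numpy()

hits = np.flatnonzero(nx == 3)  # 1-D array of matching positions
print(hits)  # [2]

# An empty array means no match
first = int(hits[0]) if hits.size > 0 else -1
print(first)  # 2
```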

2 Comments

Be aware that if you convert to numpy you lose the gradient graph.
what would happen if there is no match? how would you handle that?
-2

To find the index of an element in a 1-D tensor/array, for example the index of 5 in:

import torch

mat = torch.tensor([1, 8, 5, 3])

target = 5
numb_of_col = mat.numel()
for o in range(numb_of_col):
    if mat[o] == target:
        print(torch.tensor([o]))

To find the index of an element in a 2-D/3-D tensor, first flatten it into a 1-D tensor, e.g. with example.view(number_of_elements). The printed index is then the position in the flattened tensor.

Example: to find the index of 2 in a 2 x 2 tensor

mat = torch.tensor([[1, 2], [4, 3]])

target = 2
mat = mat.view(4)  # flatten the 2 x 2 tensor into 4 elements
numb_of_col = 4
for o in range(numb_of_col):
    if mat[o] == target:
        print(torch.tensor([o]))
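A vectorized sketch of the same flatten-then-search idea, which also maps each flat position back to a (row, column) pair with divmod. It assumes a 2-D tensor; reshape(-1) flattens in row-major order, which is what makes the divmod conversion valid:

```python
import torch

mat = torch.tensor([[1, 2], [4, 3]])
target = 2

# Flat positions of every match in the row-major flattened tensor
flat = (mat.reshape(-1) == target).nonzero(as_tuple=True)[0]

# Convert each flat position back to (row, col) for a 2-D tensor
n_cols = mat.shape[1]
coords = [divmod(int(i), n_cols) for i in flat]
print(flat)    # tensor([1])
print(coords)  # [(0, 1)]
```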

1 Comment

This is an old question and the OP is most likely not looking for a quick answer. If you want to post an answer, please take your time to explain what your code does and how it adds to the already existing answers. The code block as it is now is of very little value to the community.
-3

For floating point tensors, I use this to get the index of the element in the tensor.

print((torch.abs((torch.max(your_tensor).item()-your_tensor))<0.0001).nonzero())

Here I want the index of the maximum value in the float tensor; you can also substitute your own value, as below, to get the index of any element in the tensor.

print((torch.abs((YOUR_VALUE-your_tensor))<0.0001).nonzero())

1 Comment

FYI torch.isclose(your_tensor, torch.full_like(your_tensor, your_value)) is a better way to do this.
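A sketch of the torch.isclose variant. torch.isclose takes two tensors of the same dtype, so the target value is expanded with torch.full_like; setting rtol=0.0 with atol=1e-4 makes it a pure absolute-tolerance test, close to the 0.0001 threshold used in the answer (isclose uses <= rather than <). The sample values are made up for illustration:

```python
import torch

your_tensor = torch.tensor([0.1, 0.2 + 1e-7, 0.3])
your_value = 0.2

# Build a tensor of the target value with matching dtype and shape
target = torch.full_like(your_tensor, your_value)

# |your_tensor - target| <= atol + rtol * |target|, here just <= 1e-4
close = torch.isclose(your_tensor, target, rtol=0.0, atol=1e-4)
idx = close.nonzero(as_tuple=True)[0]
print(idx)  # tensor([1])
```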
-4
    import torch
    from torch.autograd import Variable

    x_data = Variable(torch.Tensor([[1.0], [2.0], [3.0]]))
    print(x_data.data[0])
    # tensor([1.])
