>>> import torch
>>> A = torch.tensor([[[ 1.,  2.,  3.],
...                    [ 5.,  6.,  7.]],
...
...                   [[ 9., 10., 11.],
...                    [13., 14., 15.]]])
>>> B = torch.tensor([[0, 2],
...                   [1, 0]])
>>> A.shape
torch.Size([2, 2, 3])
>>> B.shape
torch.Size([2, 2])
>>> C = torch.zeros_like(B, dtype=A.dtype)
>>> for i in range(B.shape[0]):
...     for j in range(B.shape[1]):
...         C[i, j] = A[i, j, B[i, j]]
...
>>> C
tensor([[ 1.,  7.],
        [10., 13.]])
>>> torch.gather(A, -1, B.unsqueeze(-1))
tensor([[[ 1.],
         [ 7.]],

        [[10.],
         [13.]]])
>>> torch.gather(A, -1, B.unsqueeze(-1)).shape
torch.Size([2, 2, 1])
>>> torch.gather(A, -1, B.unsqueeze(-1)).squeeze(-1)
tensor([[ 1.,  7.],
        [10., 13.]])
Hi, you can use torch.gather(A, -1, B.unsqueeze(-1)).squeeze(-1).
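The first -1 (the dim argument, between A and B.unsqueeze(-1)) is the dimension along which you want to pick the elements. For a 3-D input and dim=-1, gather computes out[i][j][k] = A[i][j][index[i][j][k]] with index = B.unsqueeze(-1), so out[i][j][0] = A[i][j][B[i][j]], which is exactly the double loop above.
The second -1, in B.unsqueeze(-1), adds a trailing dimension to B so that the index tensor has the same number of dimensions as A; otherwise you get RuntimeError: Index tensor must have the same number of dimensions as input tensor.
The last -1, in .squeeze(-1), removes that size-1 dimension again, reshaping the result from torch.Size([2, 2, 1]) to torch.Size([2, 2]).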
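If you need this in more than one place, here is a minimal sketch that wraps the same one-liner in a helper and checks it against an explicit loop for any number of leading dimensions. It assumes A has exactly one more dimension than B, A.shape[:-1] == B.shape, and B holds int64 indices. The names gather_last_dim and gather_last_dim_loop are hypothetical, not part of PyTorch.

```python
import itertools

import torch


def gather_last_dim(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: C[...] = A[..., B[...]].

    Assumes A.shape[:-1] == B.shape and B is an int64 (long) tensor.
    """
    # dim=-1: pick elements along the last dimension of A.
    # B.unsqueeze(-1): the index tensor must have as many dims as A.
    # .squeeze(-1): drop the size-1 dimension that gather leaves behind.
    return torch.gather(A, -1, B.unsqueeze(-1)).squeeze(-1)


def gather_last_dim_loop(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """Hypothetical reference version: explicit loop over every index of B."""
    C = torch.zeros(B.shape, dtype=A.dtype)  # same dtype as A, so no truncation
    for idx in itertools.product(*(range(n) for n in B.shape)):
        # idx is a tuple of ints; append the index stored in B for the last dim
        C[idx] = A[idx + (int(B[idx]),)]
    return C


if __name__ == "__main__":
    A = torch.tensor([[[1., 2., 3.], [5., 6., 7.]],
                      [[9., 10., 11.], [13., 14., 15.]]])
    B = torch.tensor([[0, 2], [1, 0]])
    print(gather_last_dim(A, B))  # values [[1., 7.], [10., 13.]]

    # The same call works with extra leading (batch) dimensions.
    A4 = torch.arange(48, dtype=torch.float32).reshape(2, 3, 2, 4)
    B4 = torch.randint(0, 4, (2, 3, 2))
    print(torch.equal(gather_last_dim(A4, B4), gather_last_dim_loop(A4, B4)))
```

The loop version is only there as a check; gather performs the same selection in a single vectorized call.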