2

I would like to have the tensor indexed a certain way.

Suppose my data, tensor X shaped (1, 3, 16, 9) is

tensor([[[[ 0.,  0.,  0.,  0.,  1.,  2.,  0.,  5.,  6.],
      [ 0.,  0.,  0.,  1.,  2.,  3.,  5.,  6.,  7.],
      [ 0.,  0.,  0.,  2.,  3.,  4.,  6.,  7.,  8.],
      [ 0.,  0.,  0.,  3.,  4.,  0.,  7.,  8.,  0.],
      [ 0.,  1.,  2.,  0.,  5.,  6.,  0.,  9., 10.],
      [ 1.,  2.,  3.,  5.,  6.,  7.,  9., 10., 11.],
      [ 2.,  3.,  4.,  6.,  7.,  8., 10., 11., 12.],
      [ 3.,  4.,  0.,  7.,  8.,  0., 11., 12.,  0.],
      [ 0.,  5.,  6.,  0.,  9., 10.,  0., 13., 14.],
      [ 5.,  6.,  7.,  9., 10., 11., 13., 14., 15.],
      [ 6.,  7.,  8., 10., 11., 12., 14., 15., 16.],
      [ 7.,  8.,  0., 11., 12.,  0., 15., 16.,  0.],
      [ 0.,  9., 10.,  0., 13., 14.,  0.,  0.,  0.],
      [ 9., 10., 11., 13., 14., 15.,  0.,  0.,  0.],
      [10., 11., 12., 14., 15., 16.,  0.,  0.,  0.],
      [11., 12.,  0., 15., 16.,  0.,  0.,  0.,  0.]],

     [[ 0.,  0.,  0.,  0., 17., 18.,  0., 21., 22.],
      [ 0.,  0.,  0., 17., 18., 19., 21., 22., 23.],
      [ 0.,  0.,  0., 18., 19., 20., 22., 23., 24.],
      [ 0.,  0.,  0., 19., 20.,  0., 23., 24.,  0.],
      [ 0., 17., 18.,  0., 21., 22.,  0., 25., 26.],
      [17., 18., 19., 21., 22., 23., 25., 26., 27.],
      [18., 19., 20., 22., 23., 24., 26., 27., 28.],
      [19., 20.,  0., 23., 24.,  0., 27., 28.,  0.],
      [ 0., 21., 22.,  0., 25., 26.,  0., 29., 30.],
      [21., 22., 23., 25., 26., 27., 29., 30., 31.],
      [22., 23., 24., 26., 27., 28., 30., 31., 32.],
      [23., 24.,  0., 27., 28.,  0., 31., 32.,  0.],
      [ 0., 25., 26.,  0., 29., 30.,  0.,  0.,  0.],
      [25., 26., 27., 29., 30., 31.,  0.,  0.,  0.],
      [26., 27., 28., 30., 31., 32.,  0.,  0.,  0.],
      [27., 28.,  0., 31., 32.,  0.,  0.,  0.,  0.]],

     [[ 0.,  0.,  0.,  0., 33., 34.,  0., 37., 38.],
      [ 0.,  0.,  0., 33., 34., 35., 37., 38., 39.],
      [ 0.,  0.,  0., 34., 35., 36., 38., 39., 40.],
      [ 0.,  0.,  0., 35., 36.,  0., 39., 40.,  0.],
      [ 0., 33., 34.,  0., 37., 38.,  0., 41., 42.],
      [33., 34., 35., 37., 38., 39., 41., 42., 43.],
      [34., 35., 36., 38., 39., 40., 42., 43., 44.],
      [35., 36.,  0., 39., 40.,  0., 43., 44.,  0.],
      [ 0., 37., 38.,  0., 41., 42.,  0., 45., 46.],
      [37., 38., 39., 41., 42., 43., 45., 46., 47.],
      [38., 39., 40., 42., 43., 44., 46., 47., 48.],
      [39., 40.,  0., 43., 44.,  0., 47., 48.,  0.],
      [ 0., 41., 42.,  0., 45., 46.,  0.,  0.,  0.],
      [41., 42., 43., 45., 46., 47.,  0.,  0.,  0.],
      [42., 43., 44., 46., 47., 48.,  0.,  0.,  0.],
      [43., 44.,  0., 47., 48.,  0.,  0.,  0.,  0.]]]]

I would like to have those rows where (row_index % n) == i (say n = 4 and i = 0 to 3) is saved in another tensor Y.

For example, for the data X[0][0]:

[[ 0.,  0.,  0.,  0.,  1.,  2.,  0.,  5.,  6.],
 [ 0.,  0.,  0.,  1.,  2.,  3.,  5.,  6.,  7.],
 [ 0.,  0.,  0.,  2.,  3.,  4.,  6.,  7.,  8.],
 [ 0.,  0.,  0.,  3.,  4.,  0.,  7.,  8.,  0.],
 [ 0.,  1.,  2.,  0.,  5.,  6.,  0.,  9., 10.],
 [ 1.,  2.,  3.,  5.,  6.,  7.,  9., 10., 11.],
 [ 2.,  3.,  4.,  6.,  7.,  8., 10., 11., 12.],
 [ 3.,  4.,  0.,  7.,  8.,  0., 11., 12.,  0.],
 [ 0.,  5.,  6.,  0.,  9., 10.,  0., 13., 14.],
 [ 5.,  6.,  7.,  9., 10., 11., 13., 14., 15.],
 [ 6.,  7.,  8., 10., 11., 12., 14., 15., 16.],
 [ 7.,  8.,  0., 11., 12.,  0., 15., 16.,  0.],
 [ 0.,  9., 10.,  0., 13., 14.,  0.,  0.,  0.],
 [ 9., 10., 11., 13., 14., 15.,  0.,  0.,  0.],
 [10., 11., 12., 14., 15., 16.,  0.,  0.,  0.],      
 [11., 12.,  0., 15., 16.,  0.,  0.,  0.,  0.]]

I would like to have a tensor containing the following data, which is basically collection of the rows where row_index % 4 == 0 (here i = 0):

[[ 0.,  0.,  0.,  0.,  1.,  2.,  0.,  5.,  6.],
 [ 0.,  1.,  2.,  0.,  5.,  6.,  0.,  9., 10.],
 [ 0.,  5.,  6.,  0.,  9., 10.,  0., 13., 14.],
 [ 0.,  9., 10.,  0., 13., 14.,  0.,  0.,  0.]]

Similarly, where i = 1, row_index % 4 == i will look like:

[[ 0.,  0.,  0.,  1.,  2.,  3.,  5.,  6.,  7.],
 [ 1.,  2.,  3.,  5.,  6.,  7.,  9., 10., 11.],
 [ 5.,  6.,  7.,  9., 10., 11., 13., 14., 15.],
 [ 9., 10., 11., 13., 14., 15.,  0.,  0.,  0.]]

when i = 2, row_index % 4 == i:

[[ 0.,  0.,  0.,  2.,  3.,  4.,  6.,  7.,  8.],
 [ 2.,  3.,  4.,  6.,  7.,  8., 10., 11., 12.],
 [ 6.,  7.,  8., 10., 11., 12., 14., 15., 16.],
 [10., 11., 12., 14., 15., 16.,  0.,  0.,  0.]]

when i = 3, row_index % 4 == i:

[[ 0.,  0.,  0.,  3.,  4.,  0.,  7.,  8.,  0.],
 [ 3.,  4.,  0.,  7.,  8.,  0., 11., 12.,  0.],
 [ 7.,  8.,  0., 11., 12.,  0., 15., 16.,  0.],
 [11., 12.,  0., 15., 16.,  0.,  0.,  0.,  0.]]

I have tried hard coding it and it doesn't seem practical when the data becomes larger and the size becomes dynamic and I assume that there would be a better way to come about it.

temp0 = data[0][0][0][:] 
temp1 = data[0][0][4][:]
temp2 = data[0][0][8][:]
temp3 = data[0][0][12][:]
temp = torch.stack([temp0,temp1,temp2,temp3],dim = 0)

Also, it would be great if the result can come back in one tensor like :

tensor Y = ([[[ 0.,  0.,  0.,  0.,  1.,  2.,  0.,  5.,  6.],
              [ 0.,  1.,  2.,  0.,  5.,  6.,  0.,  9., 10.],
              [ 0.,  5.,  6.,  0.,  9., 10.,  0., 13., 14.],
              [ 0.,  9., 10.,  0., 13., 14.,  0.,  0.,  0.]], 

             [[ 0.,  0.,  0.,  1.,  2.,  3.,  5.,  6.,  7.],
              [ 1.,  2.,  3.,  5.,  6.,  7.,  9., 10., 11.],
              [ 5.,  6.,  7.,  9., 10., 11., 13., 14., 15.],
              [ 9., 10., 11., 13., 14., 15.,  0.,  0.,  0.]], 
   
             [[ 0.,  0.,  0.,  2.,  3.,  4.,  6.,  7.,  8.],
              [ 2.,  3.,  4.,  6.,  7.,  8., 10., 11., 12.],
              [ 6.,  7.,  8., 10., 11., 12., 14., 15., 16.],
              [10., 11., 12., 14., 15., 16.,  0.,  0.,  0.]], 

             [[ 0.,  0.,  0.,  3.,  4.,  0.,  7.,  8.,  0.],
              [ 3.,  4.,  0.,  7.,  8.,  0., 11., 12.,  0.],
              [ 7.,  8.,  0., 11., 12.,  0., 15., 16.,  0.],
              [11., 12.,  0., 15., 16.,  0.,  0.,  0.,  0.]]])

2 Answers 2

1

You can achieve this by first constructing a tensor containing the selected rows, then using torch.gather to assemble the final tensor.

Assuming we two lists I and N containing the values of i and n respectively:

I = [0, 1, 2, 3]
N = [4, 4, 4, 4]

First we construct the index tensor:

>>> index = torch.stack([(torch.arange(16) % n == i).nonzero() for i, n in zip(I, N)])
tensor([[[ 0],
         [ 4],
         [ 8],
         [12]],

        [[ 1],
         [ 5],
         [ 9],
         [13]],

        [[ 2],
         [ 6],
         [10],
         [14]],

        [[ 3],
         [ 7],
         [11],
         [15]]])

Then some expanding and reshaping is required:

>>> index_ = index[None].flatten(1,2).expand(X.size(0), -1, X.size(-1))
tensor([[[ 0,  0,  0,  0,  0,  0,  0,  0,  0],
         [ 4,  4,  4,  4,  4,  4,  4,  4,  4],
         [ 8,  8,  8,  8,  8,  8,  8,  8,  8],
         [12, 12, 12, 12, 12, 12, 12, 12, 12],
         [ 1,  1,  1,  1,  1,  1,  1,  1,  1],
         [ 5,  5,  5,  5,  5,  5,  5,  5,  5],
         [ 9,  9,  9,  9,  9,  9,  9,  9,  9],
         [13, 13, 13, 13, 13, 13, 13, 13, 13],
         [ 2,  2,  2,  2,  2,  2,  2,  2,  2],
         [ 6,  6,  6,  6,  6,  6,  6,  6,  6],
         [10, 10, 10, 10, 10, 10, 10, 10, 10],
         [14, 14, 14, 14, 14, 14, 14, 14, 14],
         [ 3,  3,  3,  3,  3,  3,  3,  3,  3],
         [ 7,  7,  7,  7,  7,  7,  7,  7,  7],
         [11, 11, 11, 11, 11, 11, 11, 11, 11],
         [15, 15, 15, 15, 15, 15, 15, 15, 15]]])

As a rule of thumb, we want index_ to have the same number of dimensions as X.

Now we can apply torch.gather and reshape to the final form:

>>> X.gather(1, index_).reshape(len(X), *index.shape[:2], -1)
tensor([[[[ 0.,  0.,  0.,  0.,  1.,  2.,  0.,  5.,  6.],
          [ 0.,  1.,  2.,  0.,  5.,  6.,  0.,  9., 10.],
          [ 0.,  5.,  6.,  0.,  9., 10.,  0., 13., 14.],
          [ 0.,  9., 10.,  0., 13., 14.,  0.,  0.,  0.]],

         [[ 0.,  0.,  0.,  1.,  2.,  3.,  5.,  6.,  7.],
          [ 1.,  2.,  3.,  5.,  6.,  7.,  9., 10., 11.],
          [ 5.,  6.,  7.,  9., 10., 11., 13., 14., 15.],
          [ 9., 10., 11., 13., 14., 15.,  0.,  0.,  0.]],

         [[ 0.,  0.,  0.,  2.,  3.,  4.,  6.,  7.,  8.],
          [ 2.,  3.,  4.,  6.,  7.,  8., 10., 11., 12.],
          [ 6.,  7.,  8., 10., 11., 12., 14., 15., 16.],
          [10., 11., 12., 14., 15., 16.,  0.,  0.,  0.]],

         [[ 0.,  0.,  0.,  3.,  4.,  0.,  7.,  8.,  0.],
          [ 3.,  4.,  0.,  7.,  8.,  0., 11., 12.,  0.],
          [ 7.,  8.,  0., 11., 12.,  0., 15., 16.,  0.],
          [11., 12.,  0., 15., 16.,  0.,  0.,  0.,  0.]]]])

This method can be extended to batch tensors:

>>> index = torch.stack([(torch.arange(16) % n == i).nonzero() for i, n in zip(I, N)])
>>> index_  = index[None,None].flatten(2,3).expand(X.size(0), X.size(1), -1, X.size(-1))

>>> X.gather(2, index_).reshape(*X.shape[:2], *index.shape[:2], -1)
Sign up to request clarification or add additional context in comments.

3 Comments

I get this error when I run "X.gather(1, index_).reshape(len(X), *index.shape[:2], -1)" : RuntimeError: Index tensor must have the same number of dimensions as input tensor. Also is there a good article or course to look up, because I can't seem to visualize the data manipulation functions and still struggling to comprehend your code Thank you
What is the shape of your input tensor X? I will edit my question if details on the broadcasting/reshaping approach.
Thank you for getting back. The batch extension works and it seems expanding the index_ did the trick for the first part. "index_ = index_.view(index_.size(0),1,index_.size(1),index_.size(2)) index_ = index_.expand(-1,X.size(1),-1,-1)". Thank you once again
1

First, to get each patrition you can try this:

import torch

data = torch.tensor([[[[0., 0., 0., 0., 1., 2., 0., 5., 6.],
                       [0., 0., 0., 1., 2., 3., 5., 6., 7.],
                       [0., 0., 0., 2., 3., 4., 6., 7., 8.],
                       [0., 0., 0., 3., 4., 0., 7., 8., 0.],
                       [0., 1., 2., 0., 5., 6., 0., 9., 10.],
                       [1., 2., 3., 5., 6., 7., 9., 10., 11.],
                       [2., 3., 4., 6., 7., 8., 10., 11., 12.],
                       [3., 4., 0., 7., 8., 0., 11., 12., 0.],
                       [0., 5., 6., 0., 9., 10., 0., 13., 14.],
                       [5., 6., 7., 9., 10., 11., 13., 14., 15.],
                       [6., 7., 8., 10., 11., 12., 14., 15., 16.],
                       [7., 8., 0., 11., 12., 0., 15., 16., 0.],
                       [0., 9., 10., 0., 13., 14., 0., 0., 0.],
                       [9., 10., 11., 13., 14., 15., 0., 0., 0.],
                       [10., 11., 12., 14., 15., 16., 0., 0., 0.],
                       [11., 12., 0., 15., 16., 0., 0., 0., 0.]],

                      [[0., 0., 0., 0., 17., 18., 0., 21., 22.],
                       [0., 0., 0., 17., 18., 19., 21., 22., 23.],
                       [0., 0., 0., 18., 19., 20., 22., 23., 24.],
                       [0., 0., 0., 19., 20., 0., 23., 24., 0.],
                       [0., 17., 18., 0., 21., 22., 0., 25., 26.],
                       [17., 18., 19., 21., 22., 23., 25., 26., 27.],
                       [18., 19., 20., 22., 23., 24., 26., 27., 28.],
                       [19., 20., 0., 23., 24., 0., 27., 28., 0.],
                       [0., 21., 22., 0., 25., 26., 0., 29., 30.],
                       [21., 22., 23., 25., 26., 27., 29., 30., 31.],
                       [22., 23., 24., 26., 27., 28., 30., 31., 32.],
                       [23., 24., 0., 27., 28., 0., 31., 32., 0.],
                       [0., 25., 26., 0., 29., 30., 0., 0., 0.],
                       [25., 26., 27., 29., 30., 31., 0., 0., 0.],
                       [26., 27., 28., 30., 31., 32., 0., 0., 0.],
                       [27., 28., 0., 31., 32., 0., 0., 0., 0.]],

                      [[0., 0., 0., 0., 33., 34., 0., 37., 38.],
                       [0., 0., 0., 33., 34., 35., 37., 38., 39.],
                       [0., 0., 0., 34., 35., 36., 38., 39., 40.],
                       [0., 0., 0., 35., 36., 0., 39., 40., 0.],
                       [0., 33., 34., 0., 37., 38., 0., 41., 42.],
                       [33., 34., 35., 37., 38., 39., 41., 42., 43.],
                       [34., 35., 36., 38., 39., 40., 42., 43., 44.],
                       [35., 36., 0., 39., 40., 0., 43., 44., 0.],
                       [0., 37., 38., 0., 41., 42., 0., 45., 46.],
                       [37., 38., 39., 41., 42., 43., 45., 46., 47.],
                       [38., 39., 40., 42., 43., 44., 46., 47., 48.],
                       [39., 40., 0., 43., 44., 0., 47., 48., 0.],
                       [0., 41., 42., 0., 45., 46., 0., 0., 0.],
                       [41., 42., 43., 45., 46., 47., 0., 0., 0.],
                       [42., 43., 44., 46., 47., 48., 0., 0., 0.],
                       [43., 44., 0., 47., 48., 0., 0., 0., 0.]]]])

print(data.shape)

n, i = 4, 0
indices = [index for index in range(data.shape[2]) if index % n == i]
print(data[0, 0, indices])

For the combination of those tensors you can try using:

n = 4
result = []
for i in range(n):
    indices = [index for index in range(data.shape[2]) if index % n == i]
    result.append(data[0, 0, indices])

final = torch.stack(result, dim=0)

1 Comment

Great, let me know if that works for you. If it does, please mark this question as answered :)

Your Answer

By clicking “Post Your Answer”, you agree to our terms of service and acknowledge you have read our privacy policy.

Start asking to get answers

Find the answer to your question by asking.

Ask question

Explore related questions

See similar questions with these tags.