Given a 2-dimensional tensor in numpy (or in pytorch), I can partially slice along all dimensions at once as follows:

>>> import numpy as np
>>> a = np.arange(3*4).reshape(3,4)
>>> a
array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11]])
>>> a[1:,1:]
array([[ 5,  6,  7],
       [ 9, 10, 11]])

How can I achieve the same slicing pattern regardless of the number of dimensions in the tensor if I do not know the number of dimensions at implementation time? (i.e. I want a[1:] if a has only one dimension, a[1:,1:] for two dimensions, a[1:,1:,1:] for three dimensions, and so on)

It would be nice if I could do it in a single line of code like the following, but this is invalid:

a[(1:,) * len(a.shape)]  # SyntaxError: invalid syntax

I am specifically interested in a solution that works for pytorch tensors (just substitute torch for numpy above and the example is the same), but ideally the solution would work for both numpy and pytorch.

  • I was pleasantly surprised that I couldn't find a duplicate Commented Jun 11, 2019 at 1:00

1 Answer

Answer: Making a tuple of slice objects does the trick:

a[(slice(1,None),) * len(a.shape)]

Explanation: slice is a builtin python class (not tied to numpy or pytorch) which provides an alternative to the subscript notation for describing slices. The answer to a different question suggests using this as a way to store slice information in python variables. The python glossary points out that

The bracket (subscript) notation uses slice objects internally.

Since the __getitem__ methods for numpy ndarrays and pytorch tensors support multi-dimensional indexing with slices, they must also support multi-dimensional indexing with slice objects, so we can make a tuple of slice objects of the right length.
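As a minimal sketch of the one-liner above applied to arrays of different rank, assuming numpy is installed: the helper name drop_first_along_all_axes is hypothetical (it is not part of numpy or pytorch), and a.ndim is used as a shorthand for len(a.shape).

```python
import numpy as np


def drop_first_along_all_axes(a):
    """Hypothetical helper: same as a[1:], a[1:, 1:], a[1:, 1:, 1:], ... for any rank."""
    # Build one slice(1, None) per dimension; a.ndim equals len(a.shape).
    index = (slice(1, None),) * a.ndim
    return a[index]


# Try the helper on 1-, 2-, and 3-dimensional arrays.
for shape in [(4,), (3, 4), (2, 3, 4)]:
    a = np.arange(np.prod(shape)).reshape(shape)
    print(shape, "->", drop_first_along_all_axes(a).shape)
```

Each printed result shape should be one smaller than the input along every axis. The same indexing expression is expected to work on a pytorch tensor, per the reasoning above, though this sketch only exercises numpy.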

Btw, you can see how python uses slice objects by creating a dummy class as follows and then slicing it:

class A(object):
    def __getitem__(self, ix):
        return ix

print(A()[5])  # 5
print(A()[1:])  # slice(1, None, None)
print(A()[1:,1:])  # (slice(1, None, None), slice(1, None, None))
print(A()[1:,slice(1,None)])  # (slice(1, None, None), slice(1, None, None))
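The same dummy-class trick can be used to check that a tuple built by repetition is the same index as the literal subscript. This is a small sketch in plain Python; the class name IndexEcho and the variable names are made up for illustration.

```python
class IndexEcho:
    """Dummy class that returns whatever index it is subscripted with."""

    def __getitem__(self, ix):
        return ix


echo = IndexEcho()

# Literal subscript notation for three dimensions.
literal = echo[1:, 1:, 1:]

# The same index built programmatically from slice objects.
built = echo[(slice(1, None),) * 3]

print(literal == built)  # True
```

One difference shows up only in the one-dimensional case: echo[1:] yields a bare slice object, while (slice(1, None),) * 1 is a one-element tuple. numpy treats both the same way when indexing a 1-dimensional array, so the tuple form is still fine there.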
