I am looking for an elegant way to flatten an array of arbitrary shape to a matrix based on a single parameter that specifies the dimension to retain. For illustration, I would like

def my_func(input, dim):
    # code to compute output
    return output

Given, for example, an input array of shape 2x3x4, the output should be an array of shape 12x2 for dim=0, 8x3 for dim=1, and 6x4 for dim=2. If I only want to retain the last dimension, this is easily accomplished by

input.reshape(-1, input.shape[-1])

But I would like to support an arbitrary dim (elegantly, without enumerating all possible cases and checking them with if conditions). It might be possible to first swap dimensions so that the dimension of interest is trailing, and then apply the operation above.

Any help?

Comment: With the corresponding transposes? (Commented Dec 12, 2018 at 16:12)

1 Answer

We can permute axes and reshape -

# a is input array; axis is input axis/dim
np.moveaxis(a,axis,-1).reshape(-1,a.shape[axis])

Functionally, this pushes the specified axis to the back and then reshapes: the length of that axis becomes the second axis of the result, and all remaining axes are merged to form the first axis.
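For convenience, the one-liner can be wrapped in a function matching the signature from the question. This is only a sketch: the name flatten_keep_axis is made up for illustration, and the np.asarray call is an added assumption so that nested lists are accepted too.

```python
import numpy as np

def flatten_keep_axis(a, axis):
    # Hypothetical helper name; wraps the moveaxis + reshape one-liner.
    a = np.asarray(a)  # assumption: also accept array-likes such as nested lists
    # Move the chosen axis to the end, then merge all other axes into one.
    return np.moveaxis(a, axis, -1).reshape(-1, a.shape[axis])

a = np.random.rand(2, 3, 4)
for axis in range(a.ndim):
    print(axis, flatten_keep_axis(a, axis).shape)
```

Negative values of axis should work as well, since both np.moveaxis and shape indexing accept them.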

Sample runs -

In [32]: a = np.random.rand(2,3,4)

In [33]: axis = 0

In [34]: np.moveaxis(a,axis,-1).reshape(-1,a.shape[axis]).shape
Out[34]: (12, 2)

In [35]: axis = 1

In [36]: np.moveaxis(a,axis,-1).reshape(-1,a.shape[axis]).shape
Out[36]: (8, 3)

In [37]: axis = 2

In [38]: np.moveaxis(a,axis,-1).reshape(-1,a.shape[axis]).shape
Out[38]: (6, 4)
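
Beyond the shapes, it may help to check what each row of the result contains. The sketch below uses np.arange instead of random data (a choice made here purely so the values are easy to read) and axis=1; each row of the output is then one 1-D slice of the input taken along that axis.

```python
import numpy as np

a = np.arange(24).reshape(2, 3, 4)
axis = 1

out = np.moveaxis(a, axis, -1).reshape(-1, a.shape[axis])

# After moveaxis the array has shape (2, 4, 3), so row i*4 + k of `out`
# holds a[i, :, k], i.e. the slice along the retained axis.
print(out[0], a[0, :, 0])
print(out[5], a[1, :, 1])
```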

1 Comment

Thanks. Worked for me.
