
I have a square numpy.ndarray and a numpy boolean mask of the same shape. I want to find the first element in each row of the array that is not masked.

My code currently relies on numpy.ma.notmasked_edges(), which does exactly what I need. However, I now need to migrate my code to JAX, and jax.numpy does not implement numpy.ma.

What would be the simplest way to find the index of the first unmasked element in each row, calling only numpy functions that have been implemented in JAX (which exclude numpy.ma)?

The code I'm trying to reproduce is something like:

import numpy as np
my_array = np.random.rand(5,5)
mask = (my_array < 0.5)
my_masked_array = np.ma.masked_array(my_array, mask=mask)
np.ma.notmasked_edges(my_masked_array, axis=1)[0]

I'm sure there are many ways to do this, but I'm looking for the least unwieldy way.

  • I'm not sure you'd be able to do this in JAX, as the size of the output from np.ma.notmasked_edges depends on the values of the input, which isn't allowed under jax.jit because shapes must be static. You can get pretty close to what you want with (np.arange(my_array.shape[0]), (my_array >= 0.5).argmax(axis=1)). The main issue with this is that if all values in a row are masked, it will give 0 for that row. Commented Jun 24, 2024 at 2:53
  • This should work for me, because in my actual code it can never happen that all values in a row are masked. I had not realized that argmax worked on boolean arrays. Commented Jun 24, 2024 at 3:12
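In JAX, the argmax approach from the comments can be wrapped in jax.jit, because its output shape depends only on the input shape. A minimal sketch, assuming a 2-D boolean mask where True means masked (the numpy.ma convention); first_unmasked is a hypothetical name, and returning -1 for fully masked rows is an arbitrary sentinel choice rather than anything numpy.ma does:

```python
import jax
import jax.numpy as jnp


@jax.jit
def first_unmasked(mask):
    # mask: 2-D boolean array, True = masked (numpy.ma convention).
    unmasked = ~mask
    # argmax on a boolean array returns the index of the first True,
    # i.e. the first unmasked column in each row (0 if there is none).
    idx = jnp.argmax(unmasked, axis=1)
    # Hypothetical sentinel: mark rows with no unmasked value as -1.
    has_unmasked = unmasked.any(axis=1)
    return jnp.where(has_unmasked, idx, -1)


mask = jnp.array([[True, False, False, True],
                  [False, False, True, True],
                  [True, True, True, True]])
print(first_unmasked(mask))  # [ 1  0 -1]
```

If you want the same (row_indices, column_indices) tuple layout that notmasked_edges returns, pair the result with jnp.arange(mask.shape[0]).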

1 Answer


Here's a JAX implementation of notmasked_edges, which takes a boolean mask and returns the same indices as the numpy.ma function:

import jax.numpy as jnp

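def notmasked_edges(mask, axis=None):
  mask = jnp.asarray(mask)
  assert mask.dtype == bool
  if axis is None:
    # With no axis, work on the flattened mask.
    mask = mask.ravel()
    axis = 0
  # Shape of the remaining (non-reduced) axes.
  shape = list(mask.shape)
  del shape[axis]
  # Lines along `axis` where every entry is masked; these are dropped below.
  alltrue = mask.all(axis=axis).ravel()
  # Index arrays for the remaining axes, flattened and filtered.
  indices = jnp.meshgrid(*(jnp.arange(n) for n in shape), indexing='ij')
  indices = [jnp.ravel(ind)[~alltrue] for ind in indices]

  # argmin over a boolean array gives the position of the first False (unmasked) entry.
  first = indices.copy()
  first.insert(axis, jnp.argmin(mask, axis=axis).ravel()[~alltrue])

  # Flipping along `axis` turns "last unmasked" into "first unmasked".
  last = indices.copy()
  last.insert(axis, mask.shape[axis] - 1 - jnp.argmin(jnp.flip(mask, axis=axis), axis=axis).ravel()[~alltrue])

  return [tuple(first), tuple(last)]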

This will not be compatible with JIT, because the size of the output arrays depends on the values of the mask (rows with no unmasked value are left out).
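To see the problem in isolation, here is a minimal sketch built around a hypothetical helper, rows_with_unmasked, that reproduces only the [~alltrue] step. Boolean indexing works in eager mode, but under jax.jit the mask values are not known at trace time, so JAX raises jax.errors.NonConcreteBooleanIndexError:

```python
import jax
import jax.numpy as jnp


def rows_with_unmasked(mask):
    # Boolean indexing: the result length depends on the *values* of mask,
    # just like the [~alltrue] indexing in notmasked_edges.
    alltrue = mask.all(axis=1)
    return jnp.arange(mask.shape[0])[~alltrue]


mask = jnp.array([[True, False],
                  [True, True]])

# Eager mode: the mask is concrete, so the output size can be computed.
print(rows_with_unmasked(mask))  # [0]

# Under jit the mask is an abstract tracer, so the output shape is unknown.
try:
    jax.jit(rows_with_unmasked)(mask)
except jax.errors.NonConcreteBooleanIndexError as err:
    print("jit failed:", type(err).__name__)
```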

If you want a JIT-compatible version, you can remove the [~alltrue] indexing. Rows that have no unmasked value are then kept in the output, with 0 returned as their first index and mask.shape[axis] - 1 returned as their last index:

def notmasked_edges_v2(mask, axis=None):
  mask = jnp.asarray(mask)
  assert mask.dtype == bool
  if axis is None:
    mask = mask.ravel()
    axis = 0
  shape = list(mask.shape)
  del shape[axis]
  indices = jnp.meshgrid(*(jnp.arange(n) for n in shape), indexing='ij')
  indices = [jnp.ravel(ind) for ind in indices]

  first = indices.copy()
  first.insert(axis, jnp.argmin(mask, axis=axis).ravel())

  last = indices.copy()
  last.insert(axis, mask.shape[axis] - 1 - jnp.argmin(jnp.flip(mask, axis=axis), axis=axis).ravel())

  return [tuple(first), tuple(last)]
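When calling a function like notmasked_edges_v2 under jax.jit, axis has to be marked static, since it is used in Python control flow and to compute shapes. A minimal sketch of that pattern, assuming a 2-D mask; edges_with_validity is a hypothetical helper (not a JAX API) that also returns a validity flag, so fully masked rows can be dropped afterwards outside of jit:

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical helper: fixed-size first/last unmasked indices along `axis`,
# plus `valid`, which is False where every entry along `axis` is masked.
# `axis` is static because it determines shapes inside the trace.
@partial(jax.jit, static_argnames="axis")
def edges_with_validity(mask, axis=1):
    n = mask.shape[axis]
    # First False (unmasked) entry along `axis`; 0 if the line is all True.
    first = jnp.argmin(mask, axis=axis)
    # The same trick on the flipped mask gives the last unmasked entry.
    last = n - 1 - jnp.argmin(jnp.flip(mask, axis=axis), axis=axis)
    valid = ~mask.all(axis=axis)
    return first, last, valid


mask = jnp.array([[True, False, False, True],
                  [False, False, True, True],
                  [True, True, True, True]])
first, last, valid = edges_with_validity(mask, axis=1)
# Drop fully masked rows outside of jit, where value-dependent shapes are fine.
print(first[valid], last[valid])  # [1 0] [2 1]
```

Returning the flag instead of filtering keeps every output shape static, which is what lets the function be compiled once per input shape.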

Here's an example:

import numpy as np
mask = np.array([[True, False, False, True],
                 [False, False, True, True],
                 [True, True, True, True]])

arr = np.ma.masked_array(np.ones_like(mask), mask=mask)
print(np.ma.notmasked_edges(arr, axis=1))
# [(array([0, 1]), array([1, 0])), (array([0, 1]), array([2, 1]))]

print(notmasked_edges(mask, axis=1))
# [(Array([0, 1], dtype=int32), Array([1, 0], dtype=int32)),
#  (Array([0, 1], dtype=int32), Array([2, 1], dtype=int32))]

print(notmasked_edges_v2(mask, axis=1))
# [(Array([0, 1, 2], dtype=int32), Array([1, 0, 0], dtype=int32)),
#  (Array([0, 1, 2], dtype=int32), Array([2, 1, 3], dtype=int32))]
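To gain some confidence beyond this single example, a quick randomized cross-check against numpy.ma.notmasked_edges on data like the question's can help. A minimal sketch, assuming a 2-D array and axis=1; it recomputes the first edge with jnp.argmin directly rather than importing the answer's function, and the seed is arbitrary:

```python
import numpy as np
import jax.numpy as jnp

rng = np.random.default_rng(0)
my_array = rng.random((5, 5))
mask = my_array < 0.5

# Reference from numpy.ma: fully masked rows are omitted from the output.
ref_rows, ref_cols = np.ma.notmasked_edges(
    np.ma.masked_array(my_array, mask=mask), axis=1)[0]

# argmin over the boolean mask finds the first False (unmasked) entry per row.
first = np.asarray(jnp.argmin(jnp.asarray(mask), axis=1))
keep = ~mask.all(axis=1)

assert np.array_equal(ref_rows, np.arange(mask.shape[0])[keep])
assert np.array_equal(ref_cols, first[keep])
```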