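Here's a JAX implementation of `notmasked_edges`. It takes a boolean mask (where `True` marks a masked entry) and, for a given `axis`, returns the same indices as `numpy.ma.notmasked_edges` (with `axis=None` the result is packed as `[(first,), (last,)]` rather than numpy's flat `array([first, last])`):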
import jax.numpy as jnp
def notmasked_edges(mask, axis=None):
    """Return the indices of the first and last unmasked entries along an axis.

    JAX counterpart of ``numpy.ma.notmasked_edges`` that operates on a bare
    boolean mask, where ``True`` marks a masked (invalid) entry.

    Args:
        mask: Boolean array-like; ``True`` marks a masked entry.
        axis: Axis to search along; negative values count from the end.
            If ``None``, the mask is flattened and searched along its only
            axis.

    Returns:
        ``[first, last]``: two tuples holding one index array per dimension
        of the (possibly flattened) mask. ``first`` locates the first
        unmasked entry of every 1-D slice along ``axis``, ``last`` the last
        one; slices are enumerated in C order. Slices that are entirely
        masked are omitted, so the output length depends on the mask
        *values* and this function cannot be traced under ``jax.jit``.

    Raises:
        TypeError: If ``mask`` does not have a boolean dtype.
        IndexError: If ``axis`` is out of range for ``mask``.
    """
    mask = jnp.asarray(mask)
    if mask.dtype != bool:
        raise TypeError(f"mask must have a boolean dtype, got {mask.dtype}")
    if axis is None:
        mask = mask.ravel()
        axis = 0
    ndim = mask.ndim
    if not -ndim <= axis < ndim:
        raise IndexError(f"axis {axis} is out of bounds for a mask of rank {ndim}")
    # Normalise to a non-negative axis: the list splicing below would place
    # the edge indices in the wrong position for a negative value.
    axis %= ndim

    # One index grid per remaining dimension, raveled in C order so it lines
    # up element-for-element with the raveled argmin results below.
    other_dims = mask.shape[:axis] + mask.shape[axis + 1:]
    grids = jnp.meshgrid(*(jnp.arange(n) for n in other_dims), indexing="ij")
    # Drop fully masked slices, mirroring numpy.ma's `.compressed()`.
    keep = ~mask.all(axis=axis).ravel()
    others = [grid.ravel()[keep] for grid in grids]

    # argmin on a boolean array yields the position of the first False,
    # i.e. the first unmasked entry; searching the flipped mask yields the
    # last unmasked entry counted from the end.
    first_pos = jnp.argmin(mask, axis=axis).ravel()[keep]
    last_pos = (
        mask.shape[axis]
        - 1
        - jnp.argmin(jnp.flip(mask, axis=axis), axis=axis).ravel()[keep]
    )

    first = others[:axis] + [first_pos] + others[axis:]
    last = others[:axis] + [last_pos] + others[axis:]
    return [tuple(first), tuple(last)]
This is not compatible with `jax.jit`, because the size of the output arrays depends on the values of the mask (rows with no unmasked value are left out), whereas JIT requires array shapes to be known at trace time.
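If you want a JIT-compatible version, you can remove the `[~alltrue]` indexing so that every row appears in the output. For a row with no unmasked value, the first edge is then reported as `0` and the last edge as `mask.shape[axis] - 1`. When jitting, pass `axis` as a static argument: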
def notmasked_edges_v2(mask, axis=None):
    """Return first/last unmasked indices with a value-independent shape.

    Variant of ``notmasked_edges`` that keeps an entry for every 1-D slice
    along ``axis`` instead of dropping fully masked slices. The output shape
    therefore depends only on ``mask.shape``, which makes the function usable
    under ``jax.jit`` as long as ``axis`` is a static argument.

    Args:
        mask: Boolean array-like; ``True`` marks a masked entry.
        axis: Axis to search along; negative values count from the end.
            If ``None``, the mask is flattened first.

    Returns:
        ``[first, last]``: two tuples holding one index array per dimension
        of the (possibly flattened) mask. Each array has one element per
        1-D slice along ``axis``, in C order. For a fully masked slice the
        reported first edge is ``0`` and the last edge is
        ``mask.shape[axis] - 1``; compute ``~mask.all(axis=axis).ravel()``
        to tell those slices apart.

    Raises:
        TypeError: If ``mask`` does not have a boolean dtype.
        IndexError: If ``axis`` is out of range for ``mask``.
    """
    mask = jnp.asarray(mask)
    if mask.dtype != bool:
        raise TypeError(f"mask must have a boolean dtype, got {mask.dtype}")
    if axis is None:
        mask = mask.ravel()
        axis = 0
    ndim = mask.ndim
    if not -ndim <= axis < ndim:
        raise IndexError(f"axis {axis} is out of bounds for a mask of rank {ndim}")
    # Normalise to a non-negative axis: the list splicing below would place
    # the edge indices in the wrong position for a negative value.
    axis %= ndim

    # One index grid per remaining dimension, raveled in C order so it lines
    # up element-for-element with the raveled argmin results below.
    other_dims = mask.shape[:axis] + mask.shape[axis + 1:]
    grids = jnp.meshgrid(*(jnp.arange(n) for n in other_dims), indexing="ij")
    others = [grid.ravel() for grid in grids]

    # argmin on a boolean array yields the position of the first False
    # (first unmasked entry), or 0 when the whole slice is True (masked).
    first_pos = jnp.argmin(mask, axis=axis).ravel()
    last_pos = (
        mask.shape[axis]
        - 1
        - jnp.argmin(jnp.flip(mask, axis=axis), axis=axis).ravel()
    )

    first = others[:axis] + [first_pos] + others[axis:]
    last = others[:axis] + [last_pos] + others[axis:]
    return [tuple(first), tuple(last)]
Here's an example:
import numpy as np
# Rows 0 and 1 each contain unmasked entries; row 2 is fully masked
# (True marks a masked entry, following the numpy.ma convention).
mask = np.array(
    [
        [True, False, False, True],
        [False, False, True, True],
        [True, True, True, True],
    ]
)
arr = np.ma.masked_array(np.ones_like(mask), mask=mask)

# numpy.ma reference: the fully masked row 2 does not appear.
print(np.ma.notmasked_edges(arr, axis=1))
# [(array([0, 1]), array([1, 0])), (array([0, 1]), array([2, 1]))]

# The JAX version reproduces the reference indices.
print(notmasked_edges(mask, axis=1))
# [(Array([0, 1], dtype=int32), Array([1, 0], dtype=int32)),
#  (Array([0, 1], dtype=int32), Array([2, 1], dtype=int32))]

# The fixed-shape variant keeps row 2 and reports 0 and 3 (the first and
# last column) as its edges.
print(notmasked_edges_v2(mask, axis=1))
# [(Array([0, 1, 2], dtype=int32), Array([1, 0, 0], dtype=int32)),
#  (Array([0, 1, 2], dtype=int32), Array([2, 1, 3], dtype=int32))]
The output size of `np.ma.notmasked_edges` depends on the values of the input, which isn't allowed under JIT because shapes must be static. You can get pretty close to what you want with `np.arange(my_array.shape[0]), (my_array > 0.5).argmax(axis=1)`. The main issue with this is that if all values in a row are masked, it gives `0` for that row, which is indistinguishable from a row whose first unmasked value is at index 0.