I'm trying to find the fastest way to get the functionality of numpy's 'where' function on a 2D numpy array; namely, retrieving the indices where a condition is met. In my tests it is much slower than the equivalents in other languages I have used (e.g., IDL, Matlab).
I have cythonized a function that marches through the array in nested for-loops. There is almost an order of magnitude increase in speed, but I would like to increase performance even more, if possible.
TEST.py:
from cython_where import *
import time
import numpy as np
data = np.zeros((2600,5200))
data[100:200,100:200] = 10
t0 = time.time()
inds,ct = cython_where(data,'EQ',10)
print time.time() - t0
t1 = time.time()
tmp = np.where(data == 10)
print time.time() - t1
My cython_where.pyx program:
from __future__ import division
import numpy as np
cimport numpy as np
cimport cython
DTYPE1 = np.float
ctypedef np.float_t DTYPE1_t
DTYPE2 = np.int
ctypedef np.int_t DTYPE2_t
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.nonecheck(False)
def cython_where(np.ndarray[DTYPE1_t, ndim=2] data, oper, DTYPE1_t val):
    assert data.dtype == DTYPE1
    cdef int xmax = data.shape[0]
    cdef int ymax = data.shape[1]
    cdef unsigned int x, y
    cdef int count = 0
    cdef np.ndarray[DTYPE2_t, ndim=1] xind = np.zeros(100000,dtype=int)
    cdef np.ndarray[DTYPE2_t, ndim=1] yind = np.zeros(100000,dtype=int)
    if(oper == 'EQ' or oper == 'eq'): #I didn't want to include GT, GE, LT, LE here
        for x in xrange(xmax):
            for y in xrange(ymax):
                if(data[x,y] == val):
                    xind[count] = x
                    yind[count] = y
                    count += 1
    return tuple([xind[0:count],yind[0:count]]),count
Output of TEST.py:
cython_test]$ python TEST.py
0.0139019489288
0.0982608795166
I've also tried numpy's argwhere, which is about as fast as where. I'm pretty new to numpy and cython, so if you have any other ideas to really increase performance, I'm all ears!
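For reference, np.argwhere is documented as almost the same as np.transpose(np.nonzero(a)), so it does the same search as np.where(condition) and only lays the result out differently. That would be consistent with the two timing about the same. A minimal sketch on a small array (the array size and the names mask, rows, cols and pairs are just for illustration):

```python
import numpy as np

# Small stand-in for the real array; the shape is arbitrary.
data = np.zeros((6, 8))
data[1:3, 2:4] = 10
mask = data == 10

# np.where with only a condition returns one index array per dimension.
rows, cols = np.where(mask)

# np.argwhere returns the same hits as an (N, 2) array of (row, col) pairs.
pairs = np.argwhere(mask)

# Stacking the np.where output column-wise reproduces the np.argwhere output.
same = np.array_equal(pairs, np.column_stack((rows, cols)))
```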
np.nonzero (which this where call uses) uses np.count_nonzero to allocate the result arrays. So it ends up looping through the array twice, but the count iteration is quite fast.
PyArray_Nonzero in https://github.com/numpy/numpy/blob/c0e48cfbbdef9cca954b0c4edd0052e1ec8a30aa/numpy/core/src/multiarray/item_selection.c is the source code for np.nonzero.
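The count-then-fill idea described in that comment can be sketched in plain Python as below. This is only an illustration of the two-pass structure, not NumPy's actual C implementation, and two_pass_where is a hypothetical name. Written like this it will be slow; the point is that counting first lets the index arrays be allocated at exactly the right size, instead of relying on a fixed-length buffer.

```python
import numpy as np

def two_pass_where(mask):
    """Illustrative count-then-fill search over a 2D boolean mask (hypothetical helper)."""
    # Pass 1: count the hits so the outputs can be allocated at their exact size.
    n_hits = int(np.count_nonzero(mask))
    rows = np.empty(n_hits, dtype=np.intp)
    cols = np.empty(n_hits, dtype=np.intp)

    # Pass 2: walk the array again in row-major order and record each hit.
    k = 0
    for i in range(mask.shape[0]):
        for j in range(mask.shape[1]):
            if mask[i, j]:
                rows[k] = i
                cols[k] = j
                k += 1
    return rows, cols

# Small test array; the shape is arbitrary and chosen to keep the loop quick.
data = np.zeros((50, 60))
data[10:20, 10:20] = 10

rows, cols = two_pass_where(data == 10)
expected_rows, expected_cols = np.nonzero(data == 10)
```

The same two-pass layout could be moved into the Cython loop: one typed pass to count, then a second pass to fill arrays of length count.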