I'm trying to optimize the function 'pw' in the following code using only NumPy functions (or perhaps list comprehensions).
from time import time
import numpy as np
def pw(x, udata):
    """
    Creates the step function
                 | 1, if d0 <= x < d1
                 | 2, if d1 <= x < d2
    pw(x,data) = ...
                 | N, if d(N-1) <= x < dN
                 | 0, otherwise
    where di is the ith element in data.
    INPUT:  x -- interval which the step function is defined over
            data -- an ordered set of data (without repetitions)
    OUTPUT: pw_func -- an array of size x.shape[0]
    """
    # Use a column vector so each data value broadcasts against all of x
    udata = udata.reshape(-1, 1)
    vals = np.arange(1,udata.shape[0]+1).reshape(udata.shape[0],1)
    pw_func = np.sum(np.where(np.greater_equal(x,udata)*np.less(x,np.roll(udata,-1)),vals,0),axis=0)
    return pw_func
N = 50000
x = np.linspace(0,10,N)
data = [1,3,4,5,5,7]
udata = np.unique(data)
ti = time()
pw(x,udata)
tf = time()
print(tf - ti)
import cProfile
cProfile.run('pw(x,udata)')
cProfile.run tells me that most of the overhead comes from np.where (about 1 ms), but I'd like to write faster code if possible. Performing the operations row-wise versus column-wise seems to make some difference, unless I'm mistaken, but I think I've accounted for that. I know that list comprehensions can sometimes be faster, but I couldn't find a way to use one that beats what I have now.
np.searchsorted seems to yield better performance, but that 1 ms still remains on my computer:
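A single time() difference is noisy at the millisecond scale, so a repeated measurement may give a steadier number. Below is a minimal sketch using the standard-library timeit module; pw_broadcast is a hypothetical name for the same broadcasting approach as pw above, and reshaping udata into a column inside the function is an assumption made so the example runs on its own.

```python
import timeit
import numpy as np

def pw_broadcast(x, udata):
    # Same broadcasting idea as pw; udata is reshaped to a column vector
    # here (an assumption) so it broadcasts against every element of x.
    col = udata.reshape(-1, 1)
    vals = np.arange(1, col.shape[0] + 1).reshape(-1, 1)
    mask = np.greater_equal(x, col) * np.less(x, np.roll(col, -1))
    return np.sum(np.where(mask, vals, 0), axis=0)

x = np.linspace(0, 10, 50000)
udata = np.unique([1, 3, 4, 5, 5, 7])

# 5 repeats of 20 calls each; the minimum is usually the least noisy estimate
runs = timeit.repeat(lambda: pw_broadcast(x, udata), number=20, repeat=5)
best = min(runs) / 20
print(f"best time per call: {best * 1e3:.3f} ms")
```

The number and repeat values are arbitrary choices for illustration.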
(modified)
def pw(xx, uu):
    """
    Creates the step function
                 | 1, if d0 <= x < d1
                 | 2, if d1 <= x < d2
    pw(x,data) = ...
                 | N, if d(N-1) <= x < dN
                 | 0, otherwise
    where di is the ith element in data.
    INPUT:  x -- interval which the step function is defined over
            data -- an ordered set of data (without repetitions)
    OUTPUT: pw_func -- an array of size x.shape[0]
    """
    inds = np.searchsorted(uu, xx, side='right')
    vals = np.arange(1,uu.shape[0]+1)
    pw_func = vals[inds[inds != uu.shape[0]]]
    num_mins = np.sum(xx < np.min(uu))
    num_maxs = np.sum(xx > np.max(uu))
    pw_func = np.concatenate((np.zeros(num_mins), pw_func, np.zeros(xx.shape[0]-pw_func.shape[0]-num_mins)))
    return pw_func
This answer using piecewise seems pretty close, but it works on scalars x0 and x1. How would I do it on arrays? And would it be more efficient?
Admittedly, x is pretty big, but I'm trying to put the function through a stress test.
I'm still learning, so any hints or tricks that can help me out would be great.
EDIT
There seems to be a mistake in the second function, since its result doesn't match that of the first one (which I'm confident works). Calling the first function pw1 and the second pw2:
N1 = pw1(x,udata.reshape(udata.shape[0],1)).shape[0]
N2 = np.sum(pw1(x,udata.reshape(udata.shape[0],1)) == pw2(x,udata))
print(N1 - N2)
yields
15000
data points that are not the same. So it seems that I don't know how to use 'searchsorted'.
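To see where the mismatch might come from, here is a small sketch of what np.searchsorted returns with side='right'. The arrays uu and xx are made-up illustration values, not the data from the test above.

```python
import numpy as np

uu = np.array([1, 3, 4, 5, 7])        # sorted, unique breakpoints
xx = np.array([0, 1, 2, 3, 6, 7, 8])  # hypothetical sample points

# With side='right', inds[i] is the number of breakpoints <= xx[i]
inds = np.searchsorted(uu, xx, side='right')
print(inds)  # [0 1 1 2 4 5 5]
```

Inside [d0, dN) the returned index already equals the step value, 0 marks points below d0, and len(uu) marks points at or above dN. Indexing vals (which starts at 1) directly with inds therefore shifts every value up by one.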
EDIT 2
Actually I fixed it:
pw_func = vals[inds[inds != uu.shape[0]]]
was changed to
pw_func = vals[inds[inds[(inds != uu.shape[0])*(inds != 0)]-1]]
so at least the resulting arrays match. But the question remains of whether there's a more efficient way of doing this.
EDIT 3
Thanks Tin Lai for pointing out the mistake. This one should work:
pw_func = vals[inds[(inds != uu.shape[0])*(inds != 0)]-1]
Maybe a more readable way of presenting it would be
# only consider the points in between the min/max data values
non_endpts = (inds != uu.shape[0])*(inds != 0)
# searchsorted with side='right' includes the left endpoint but not the right one, so a shift is needed
shift_inds = inds[non_endpts]-1
pw_func = vals[shift_inds]
I think I got lost in all those brackets! I guess that's the importance of readability.
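For reference, here is a minimal sketch of how the corrected searchsorted idea might be condensed. The names pw_searchsorted and pw_broadcast are hypothetical. It assumes uu is sorted and unique (for example, the output of np.unique), and it replaces the concatenation step with np.where, so xx would not need to be sorted.

```python
import numpy as np

def pw_searchsorted(xx, uu):
    # uu is assumed sorted and unique, e.g. the output of np.unique
    inds = np.searchsorted(uu, xx, side='right')
    # Inside [uu[0], uu[-1]) inds already equals the step value and inds == 0
    # is already correct below uu[0]; only inds == len(uu) must be zeroed.
    return np.where(inds < uu.shape[0], inds, 0)

def pw_broadcast(x, udata):
    # Broadcasting version used as a reference for the comparison
    col = udata.reshape(-1, 1)
    vals = np.arange(1, col.shape[0] + 1).reshape(-1, 1)
    mask = np.greater_equal(x, col) * np.less(x, np.roll(col, -1))
    return np.sum(np.where(mask, vals, 0), axis=0)

x = np.linspace(0, 10, 50000)
udata = np.unique([1, 3, 4, 5, 5, 7])

# Count the entries where the two versions disagree
n_diff = int(np.sum(pw_searchsorted(x, udata) != pw_broadcast(x, udata)))
print(n_diff)
```

Whether this is actually faster than the version above would still need to be timed on the machine in question.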

