I have a (working) function that uses the heapq module to build a priority queue of tuples, and I would like to compile it with numba. However, I get a very long and unclear error, which seems to boil down to a problem with the tuple ordering comparison needed for the queue. The tuples have a fixed format: the first item is a floating point number (whose order I care about), followed by a numpy array, which I need for computation but which never gets compared when running normally. This is intended, because comparing numpy arrays yields an array, which cannot be used in a conditional and raises an exception. However, I guess numba needs a comparison that yields a scalar to be at least defined for every item in the tuple, hence the numba error.
I have a very minimal example:
import numpy
import numba

@numba.njit
def f():
    return 1 if (1, numpy.arange(3)) < (2, numpy.arange(3)) else 2

f()
where the numba compilation fails (without numba it works since it never needs to actually compare the arrays, as in the original code).
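To illustrate why the interpreted version never touches the arrays: Python compares tuples lexicographically and stops at the first pair of items that differ. The sketch below uses a hypothetical stand-in class, `Payload`, in place of a numpy array (it raises on any comparison, much as an array does when its comparison result is used as a truth value), so it runs without numpy. The helper `compare` is also just an illustrative name.

```python
class Payload:
    """Hypothetical stand-in for a numpy array: any comparison raises."""

    def __eq__(self, other):
        raise ValueError("ambiguous truth value")

    def __lt__(self, other):
        raise ValueError("ambiguous truth value")


def compare(a, b):
    """Return a < b, or the string 'error' if a payload comparison was reached."""
    try:
        return a < b
    except ValueError:
        return "error"


# First items differ, so the payloads are never compared.
first_differs = compare((1.0, Payload()), (2.0, Payload()))

# First items tie, so Python moves on to the payloads and the comparison raises.
first_ties = compare((1.0, Payload()), (1.0, Payload()))

print(first_differs, first_ties)
```

This only demonstrates the interpreter's behaviour; it says nothing about how numba types the comparison at compile time.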
Here is a slightly less minimal but maybe clearer example, which shows what I am actually doing:
from heapq import heappush
import numpy
import numba

@numba.njit
def f(n):
    heap = [(1, 0, numpy.random.rand(2, 3))]
    for unique_id in range(n):
        order = numpy.random.rand()
        data = numpy.random.rand(2, 3)
        heappush(heap, (order, unique_id, data))
    return heap[0]

f(100)
Here order is the value that determines the position in the queue, and unique_id is a trick to prevent the comparison from moving on to data and throwing an exception when two order values are equal.
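The effect of the tie-breaker can be checked in plain Python. This is a minimal sketch, again using a hypothetical `Payload` class as a stand-in for the numpy array (it raises on any comparison) and `itertools.count` to generate the unique ids. The order values are chosen to contain duplicates on purpose.

```python
from heapq import heappush, heappop
from itertools import count


class Payload:
    """Hypothetical stand-in for a numpy array: any comparison raises."""

    def __eq__(self, other):
        raise ValueError("ambiguous truth value")

    def __lt__(self, other):
        raise ValueError("ambiguous truth value")


# Without a tie-breaker, two equal order values force a payload comparison.
no_tiebreak = []
heappush(no_tiebreak, (0.5, Payload()))
try:
    heappush(no_tiebreak, (0.5, Payload()))
    failed_without_id = False
except ValueError:
    failed_without_id = True

# With a unique id in second position, the comparison always stops before the payload.
heap = []
ids = count()
for order in [0.5, 0.2, 0.5, 0.2]:
    heappush(heap, (order, next(ids), Payload()))

# Pop everything and keep only (order, unique_id) for inspection.
popped = [heappop(heap)[:2] for _ in range(4)]
print(failed_without_id, popped)
```

Entries with equal order come out in insertion order, because the ids increase monotonically.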
I tried to bypass the problem by converting the numpy array to a list when storing it in the tuple and back to an array for computation. While this compiles, the numba version is slower than the interpreted one, even though the array is quite small (usually 2x3). Without converting, I would need to rewrite the code as loops, which I would prefer to avoid (but is doable).
Is there a better alternative to get this working with numba, hopefully running faster than the python interpreter?
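One possible direction, sketched here in plain Python and not verified under numba: keep only the scalar `(order, unique_id)` pairs in the heap, so that every tuple item has a scalar comparison, and store the payloads in a separate container indexed by `unique_id`. The function name `smallest_with_side_table`, the `seed` parameter, and the nested lists standing in for the 2x3 arrays are assumptions made for illustration. Whether a numba-compiled variant of this idea is actually faster would need to be measured.

```python
from heapq import heappush
import random


def smallest_with_side_table(n, seed=0):
    """Return (order, unique_id, payload) for the smallest order value."""
    rng = random.Random(seed)

    def make_payload():
        # Nested 2x3 list standing in for numpy.random.rand(2, 3).
        return [[rng.random() for _ in range(3)] for _ in range(2)]

    # Payloads live outside the heap; position in the list is the unique id.
    payloads = [make_payload()]
    heap = [(1.0, 0)]  # initial entry, id 0

    # Ids start at 1 here so they never collide with the initial entry.
    for unique_id in range(1, n + 1):
        order = rng.random()
        payloads.append(make_payload())
        heappush(heap, (order, unique_id))  # only scalars are ever compared

    order, unique_id = heap[0]
    return order, unique_id, payloads[unique_id]


result = smallest_with_side_table(100)
print(result[0], result[1])
```

The heap entries are then homogeneous tuples of a float and an integer, and the payload is looked up only after the smallest entry has been found.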
Comparing the arrays raises "The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()" (i.e. it does not make sense to compare arrays the way you do). Using this as a feature is clearly not a good idea.