I am rewriting `heapq` to work with Numba and structured arrays. In one of the functions, I compare two floating-point values (two errors) stored inside a structured array.
The functions work fine when I do not use `@njit` and I access the data using square brackets, i.e. `arr['var']`. When I decorate the exact same functions with `@njit`, but this time access the data using dot notation, i.e. `arr.var`, the functions behave differently (they no longer give me the right result).
I printed both of the values in both versions before comparing them, and they are exactly the same.
The only difference I could find is their type: within the Numba-compiled functions the types returned were `float32`, whereas in the undecorated functions the types were `np.float32`. Still, this does not appear to explain the misbehavior of the functions.
Here is everything needed to reproduce my results:
import numpy as np
import numba as nb
# Layout of one heap entry: a 4-byte float 'error' and an 8-byte integer
# 'triangle_id'.
entryDtype = np.dtype([
    ('error', 'f4'),
    ('triangle_id', 'i8'),
])
# Numba type derived from the NumPy dtype above.
entryNBtype = nb.from_dtype(entryDtype)

# Four sample entries, each given as an (error, triangle_id) pair.
entry0, entry1, entry2, entry3 = [
    np.record(fields, dtype=entryNBtype)
    for fields in ((0.75, 1), (0.5, 0), (-0.5, 2), (0.01, 3))
]
def push(pq, loc, item):
    """Append ``item`` to the heap ``pq`` and restore the heap ordering.

    Args:
        pq: List used as a binary heap of entries.
        loc: Mapping that ``_siftdown`` updates with the index in ``pq`` of
            every entry it moves, keyed by the entry's ``'triangle_id'``.
        item: Entry to add to the heap.
    """
    pq.append(item)
    print(pq) # debug
    # The new entry occupies the last slot; move it towards the root (index 0).
    _siftdown(pq, loc, 0, len(pq)-1)
def _siftdown(pq, loc, startpos, pos):
newitem = pq[pos]
while pos > startpos:
parentpos = (pos - 1) >> 1
parent = pq[parentpos]
print('BEFORE: newitem.error, pos, parent.error, parentpos', (newitem['error'], pos, parent['error'], parentpos)) # debug
if newitem['error'] < parent['error']:
pq[pos] = parent
loc[parent['triangle_id']] = pos
pos = parentpos
print('\nAFTER: newitem.error, pos, parent.error, parentpos', (newitem['error'], pos, parent['error'], parentpos)) # debug
continue
break
pq[pos] = newitem
loc[newitem['triangle_id']] = pos
print('\nafter ',pq) #debug
# Start from an empty heap and an empty triangle_id -> index mapping, then
# push the four sample entries in order.
queue, loc = [], {}
for sample in (entry0, entry1, entry2, entry3):
    push(queue, loc, sample)
Here is the Numba version that does not work properly:
import numpy as np
import numba as nb
# Layout of one heap entry: a 4-byte float 'error' and an 8-byte integer
# 'triangle_id'.
entryDtype = np.dtype([
    ('error', 'f4'),
    ('triangle_id', 'i8'),
])
# Numba type derived from the NumPy dtype above; also used as the element
# type of the typed list that holds the heap.
entryNBtype = nb.from_dtype(entryDtype)

# Four sample entries, each given as an (error, triangle_id) pair.
entry0, entry1, entry2, entry3 = [
    np.record(fields, dtype=entryNBtype)
    for fields in ((0.75, 1), (0.5, 0), (-0.5, 2), (0.01, 3))
]
# Numba is imported as ``nb`` only, so the decorator must be qualified;
# a bare ``@njit`` raises NameError.
@nb.njit
def push2(pq, loc, item):
    """Append ``item`` to the heap ``pq`` and restore the heap ordering.

    Args:
        pq: Typed list used as a binary heap of record entries.
        loc: Typed dict that ``_siftdown2`` updates with the index in ``pq``
            of every entry it moves, keyed by the entry's ``triangle_id``.
        item: Record to add to the heap.
    """
    pq.append(item)
    print(pq) # debug
    # The new entry occupies the last slot; move it towards the root (index 0).
    _siftdown2(pq, loc, 0, len(pq)-1)
# Numba is imported as ``nb`` only, so the decorator must be qualified;
# a bare ``@njit`` raises NameError.
@nb.njit
def _siftdown2(pq, loc, startpos, pos):
    """Move the entry at ``pos`` towards ``startpos`` past larger parents.

    The entry is compared with its parent by the ``error`` field; while the
    entry's error is strictly smaller, the parent is moved down one level.
    Every entry that changes slot has its new index recorded in ``loc``.

    Args:
        pq: Typed list used as a binary heap; entries expose ``error`` and
            ``triangle_id`` attributes.
        loc: Typed dict from ``triangle_id`` to the entry's index in ``pq``;
            updated in place.
        startpos: Index at which the upward walk stops (the root is 0).
        pos: Index of the entry to reposition.
    """
    # NOTE(review): in compiled code a record read from a typed list may be a
    # view into the list's storage rather than an independent copy. If so,
    # the ``pq[pos] = parent`` store below would also overwrite ``newitem``.
    # Confirm against the Numba documentation before relying on ``newitem``
    # after the first swap.
    newitem = pq[pos]
    while pos > startpos:
        # Parent of slot ``pos`` in the implicit binary tree.
        parentpos = (pos - 1) >> 1
        parent = pq[parentpos]
        print('BEFORE: newitem.error, pos, parent.error, parentpos', (newitem.error, pos, parent.error, parentpos)) #debug
        if newitem.error < parent.error:
            # Pull the larger parent down into the current slot and keep
            # walking up from the parent's old position.
            pq[pos] = parent
            loc[parent.triangle_id] = pos
            pos = parentpos
            print('\nAFTER: newitem.error, pos, parent.error, parentpos', (newitem.error, pos, parent.error, parentpos)) #debug
            continue
        # Parent is not larger: the final slot has been found.
        break
    pq[pos] = newitem
    loc[newitem.triangle_id] = pos
    print('\nafter ',pq) #debug
# Typed containers for the compiled functions: a list of entry records and
# a dict keyed and valued by nb.i8.
queue2 = nb.typed.List.empty_list(entryNBtype)
loc2 = nb.typed.Dict.empty(nb.i8, nb.i8)

# Push the four sample entries in order.
for sample in (entry0, entry1, entry2, entry3):
    push2(queue2, loc2, sample)
Why aren't these two sets of functions behaving the same?