Hello Everyone!, I am trying to contribute to Issue #6702. I thought this can be achieved by simply updating compare_and_swap function by adding additional arguments to it. But, I got the following error
LoweringError: Failed in cuda mode pipeline (step: native lowering)
No definition for lowering <class 'numba.cuda.stubs.atomic.compare_and_swap_element'>(array(int32, 1d, C), int32, int32, int32) -> int32
File "<ipython-input-3-afad3a55cf02>", line 4:
def atomic_compare_and_swap_elem(res, old, idx,ary, fill_val):
<source elided>
if gid < res.size:
out = cuda.atomic.compare_and_swap_element(res[gid:], idx[gid],fill_val, ary[idx[gid]])
^
During: lowering "out = call $22.3($22.9, $22.12, fill_val, $22.18, func=$22.3, args=[Var($22.9, <ipython-input-3-afad3a55cf02>:4), Var($22.12, <ipython-input-3-afad3a55cf02>:4), Var(fill_val, <ipython-input-3-afad3a55cf02>:2), Var($22.18, <ipython-input-3-afad3a55cf02>:4)], kws=(), vararg=None)" at <ipython-input-3-afad3a55cf02> (4)
The way I tried to update the function by cloning compare_and_swap
implementation. I included original compare_and_swap to make it easier for reference.
First, In Cudadecl.py
following are the changes
@register
class Cuda_atomic_compare_and_swap(AbstractTemplate):
key = cuda.atomic.compare_and_swap
def generic(self, args, kws):
assert not kws
ary, old, val = args
dty = ary.dtype
if dty in integer_numba_types and ary.ndim == 1:
return signature(dty, ary, dty, dty)
@register
class Cuda_atomic_compare_and_swap_element(AbstractTemplate):
key = cuda.atomic.compare_and_swap_element
def generic(self, args, kws):
assert not kws
ary, idx, old, val = args
dty = ary.dtype
if dty in integer_numba_types and ary.ndim == 1:
return signature(dty, ary,idx, dty, dty)
def resolve_compare_and_swap(self, mod):
return types.Function(Cuda_atomic_compare_and_swap)
def resolve_compare_and_swap_element(self, mod):
return types.Function(Cuda_atomic_compare_and_swap_element)
Following in Cudaimpl.py
@lower(stubs.atomic.compare_and_swap, types.Array, types.Any, types.Any)
def ptx_atomic_cas_tuple(context, builder, sig, args):
aryty, oldty, valty = sig.args
ary, old, val = args
dtype = aryty.dtype
lary = context.make_array(aryty)(context, builder, ary)
zero = context.get_constant(types.intp, 0)
ptr = cgutils.get_item_pointer(context, builder, aryty, lary, (zero,))
if aryty.dtype in (cuda.cudadecl.integer_numba_types):
lmod = builder.module
bitwidth = aryty.dtype.bitwidth
return nvvmutils.atomic_cmpxchg(builder, lmod, bitwidth, ptr, old, val)
else:
raise TypeError('Unimplemented atomic compare_and_swap '
'with %s array' % dtype)
@lower(stubs.atomic.compare_and_swap_element, types.Array, types.intp, types.Any, types.Any)
def ptx_atomic_cas_element_tuple(context, builder, sig, args):
"""
Need to update correct implementation
"""
aryty, idxty, oldty, valty = sig.args
ary, idx, old, val = args
dtype = aryty.dtype
lary = context.make_array(aryty)(context, builder, ary)
zero = context.get_constant(types.intp, 0)
ptr = cgutils.get_item_pointer(context, builder, aryty, lary, (zero,))
if aryty.dtype in (cuda.cudadecl.integer_numba_types):
lmod = builder.module
bitwidth = aryty.dtype.bitwidth
return nvvmutils.atomic_cmpxchg(builder, lmod, bitwidth, ptr+idx, old, val)
else:
raise TypeError('Unimplemented atomic compare_and_swap_element ')
following in kernelapi.py
def compare_and_swap(self, array, old, val):
with caslock:
index = (0,) * array.ndim
loaded = array[index]
if loaded == old:
array[index] = val
return loaded
def compare_and_swap_element(self, array,idx, old, val):
with caslock:
index = (idx,) * array.ndim
loaded = array[index]
if loaded == old:
array[index] = val
return loaded