Trying to implement Compare and swap element

Hello Everyone!, I am trying to contribute to Issue #6702. I thought this can be achieved by simply updating compare_and_swap function by adding additional arguments to it. But, I got the following error

LoweringError: Failed in cuda mode pipeline (step: native lowering)
No definition for lowering <class 'numba.cuda.stubs.atomic.compare_and_swap_element'>(array(int32, 1d, C), int32, int32, int32) -> int32

File "<ipython-input-3-afad3a55cf02>", line 4:
def atomic_compare_and_swap_elem(res, old, idx,ary, fill_val):
    <source elided>
    if gid < res.size:
        out = cuda.atomic.compare_and_swap_element(res[gid:], idx[gid],fill_val, ary[idx[gid]])
        ^

During: lowering "out = call $22.3($22.9, $22.12, fill_val, $22.18, func=$22.3, args=[Var($22.9, <ipython-input-3-afad3a55cf02>:4), Var($22.12, <ipython-input-3-afad3a55cf02>:4), Var(fill_val, <ipython-input-3-afad3a55cf02>:2), Var($22.18, <ipython-input-3-afad3a55cf02>:4)], kws=(), vararg=None)" at <ipython-input-3-afad3a55cf02> (4)

The way I tried to update the function by cloning compare_and_swap implementation. I included original compare_and_swap to make it easier for reference.

First, In Cudadecl.py following are the changes

@register
class Cuda_atomic_compare_and_swap(AbstractTemplate):
    key = cuda.atomic.compare_and_swap

    def generic(self, args, kws):
        assert not kws
        ary, old, val = args
        dty = ary.dtype

        if dty in integer_numba_types and ary.ndim == 1:
            return signature(dty, ary, dty, dty)

@register
class Cuda_atomic_compare_and_swap_element(AbstractTemplate):
    key = cuda.atomic.compare_and_swap_element

    def generic(self, args, kws):
        assert not kws
        ary, idx, old, val = args
        dty = ary.dtype

        if dty in integer_numba_types and ary.ndim == 1:
            return signature(dty, ary,idx, dty, dty)

 def resolve_compare_and_swap(self, mod):
        return types.Function(Cuda_atomic_compare_and_swap)

 def resolve_compare_and_swap_element(self, mod):
        return types.Function(Cuda_atomic_compare_and_swap_element)

Following in Cudaimpl.py

@lower(stubs.atomic.compare_and_swap, types.Array, types.Any, types.Any)
def ptx_atomic_cas_tuple(context, builder, sig, args):
    aryty, oldty, valty = sig.args
    ary, old, val = args
    dtype = aryty.dtype

    lary = context.make_array(aryty)(context, builder, ary)
    zero = context.get_constant(types.intp, 0)
    ptr = cgutils.get_item_pointer(context, builder, aryty, lary, (zero,))

    if aryty.dtype in (cuda.cudadecl.integer_numba_types):
        lmod = builder.module
        bitwidth = aryty.dtype.bitwidth
        return nvvmutils.atomic_cmpxchg(builder, lmod, bitwidth, ptr, old, val)
    else:
        raise TypeError('Unimplemented atomic compare_and_swap '
                        'with %s array' % dtype)

@lower(stubs.atomic.compare_and_swap_element, types.Array, types.intp, types.Any, types.Any)
def ptx_atomic_cas_element_tuple(context, builder, sig, args):
    """
    Need to update correct implementation
    """
    aryty, idxty, oldty, valty = sig.args
    ary, idx, old, val = args
    dtype = aryty.dtype

    lary = context.make_array(aryty)(context, builder, ary)
    zero = context.get_constant(types.intp, 0)
    ptr = cgutils.get_item_pointer(context, builder, aryty, lary, (zero,))

    if aryty.dtype in (cuda.cudadecl.integer_numba_types):
        lmod = builder.module
        bitwidth = aryty.dtype.bitwidth
        return nvvmutils.atomic_cmpxchg(builder, lmod, bitwidth, ptr+idx, old, val)
    else:
        raise TypeError('Unimplemented atomic compare_and_swap_element ')

following in kernelapi.py

        def compare_and_swap(self, array, old, val):
            with caslock:
                index = (0,) * array.ndim
                loaded = array[index]
                if loaded == old:
                    array[index] = val
                return loaded

    def compare_and_swap_element(self, array,idx, old, val):
        with caslock:
            index = (idx,) * array.ndim
            loaded = array[index]
            if loaded == old:
                array[index] = val
            return loaded

There’s a few changes needed here. First of all, your test code is omitted, so some test code needs to be added:

diff --git a/numba/cuda/tests/cudapy/test_atomics.py b/numba/cuda/tests/cudapy/test_atomics.py
index edc7d5148..0435c372e 100644
--- a/numba/cuda/tests/cudapy/test_atomics.py
+++ b/numba/cuda/tests/cudapy/test_atomics.py
@@ -456,6 +456,14 @@ def atomic_compare_and_swap(res, old, ary, fill_val):
         old[gid] = out
 
 
+def atomic_compare_and_swap_element(res, old, ary, fill_val):
+    gid = cuda.grid(1)
+    if gid < res.size:
+        out = cuda.atomic.compare_and_swap_element(res, gid,
+                                                   fill_val, ary[gid])
+        old[gid] = out
+
+
 class TestCudaAtomics(CUDATestCase):
     def setUp(self):
         np.random.seed(0)
@@ -1277,23 +1285,55 @@ class TestCudaAtomics(CUDATestCase):
         np.testing.assert_array_equal(expect_res, res)
         np.testing.assert_array_equal(expect_out, out)
 
+    def check_compare_and_swap_element(self, n, fill, unfill, dtype):
+        res = [fill] * (n // 2) + [unfill] * (n // 2)
+        np.random.shuffle(res)
+        res = np.asarray(res, dtype=dtype)
+        out = np.zeros_like(res)
+        ary = np.random.randint(1, 10, size=res.size).astype(res.dtype)
+
+        fill_mask = res == fill
+        unfill_mask = res == unfill
+
+        expect_res = np.zeros_like(res)
+        expect_res[fill_mask] = ary[fill_mask]
+        expect_res[unfill_mask] = unfill
+
+        expect_out = np.zeros_like(out)
+        expect_out[fill_mask] = res[fill_mask]
+        expect_out[unfill_mask] = unfill
+
+        cuda_func = cuda.jit(atomic_compare_and_swap_element)
+        cuda_func[10, 10](res, out, ary, fill)
+
+        np.testing.assert_array_equal(expect_res, res)
+        np.testing.assert_array_equal(expect_out, out)
+
     def test_atomic_compare_and_swap(self):
         self.check_compare_and_swap(n=100, fill=-99, unfill=-1, dtype=np.int32)
+        self.check_compare_and_swap_element(n=100, fill=-99, unfill=-1,
+                                            dtype=np.int32)
 
     def test_atomic_compare_and_swap2(self):
         self.check_compare_and_swap(n=100, fill=-45, unfill=-1, dtype=np.int64)
+        self.check_compare_and_swap_element(n=100, fill=-45, unfill=-1,
+                                            dtype=np.int64)
 
     def test_atomic_compare_and_swap3(self):
         rfill = np.random.randint(50, 500, dtype=np.uint32)
         runfill = np.random.randint(1, 25, dtype=np.uint32)
         self.check_compare_and_swap(n=100, fill=rfill, unfill=runfill,
                                     dtype=np.uint32)
+        self.check_compare_and_swap_element(n=100, fill=rfill, unfill=runfill,
+                                            dtype=np.uint32)
 
     def test_atomic_compare_and_swap4(self):
         rfill = np.random.randint(50, 500, dtype=np.uint64)
         runfill = np.random.randint(1, 25, dtype=np.uint64)
         self.check_compare_and_swap(n=100, fill=rfill, unfill=runfill,
                                     dtype=np.uint64)
+        self.check_compare_and_swap_element(n=100, fill=rfill, unfill=runfill,
+                                            dtype=np.uint64)
 
     # Tests that the atomic add, min, and max operations return the old value -
     # in the simulator, they did not (see Issue #5458). The max and min have

This is based on the atomic_compare_and_swap test - there’s no need for an additional index, as the indexing in atomic_compare_and_swap is done using gid, but the fact that compare_and_swap doesn’t accept indices is worked around by passing res[gid:] to it - for atomic_compare_and_swap_element function we can just pass res as the array and gid as the index.

A stub is missing from your code above - I think you have it, but didn’t paste it above otherwise you wouldn’t have got as far as getting a LoweringError:

diff --git a/numba/cuda/stubs.py b/numba/cuda/stubs.py
index ba4d185d2..7ee15796f 100644
--- a/numba/cuda/stubs.py
+++ b/numba/cuda/stubs.py
@@ -552,3 +552,12 @@ class atomic(Stub):
 
         Returns the current value as if it is loaded atomically.
         """
+
+    class compare_and_swap_element(Stub):
+        """compare_and_swap_element(ary, idx, old, val)
+
+        Conditionally assign ``val`` to the first element of an 1D array ``ary``
+        if the current value matches ``old``.
+
+        Returns the current value as if it is loaded atomically.
+        """

Your typing implementation needs to type the index type as types.intp to match its typing in your lowering function. You also need to add a resolve_compare_and_swap_element function in the CudaAtomicTemplate class. The complete changes I have applied to cudadecl.py are:

diff --git a/numba/cuda/cudadecl.py b/numba/cuda/cudadecl.py
index d9e2b9da1..7a2d2c704 100644
--- a/numba/cuda/cudadecl.py
+++ b/numba/cuda/cudadecl.py
@@ -16,6 +16,24 @@ register_global = registry.register_global
 register_number_classes(register_global)
 
 
+@register
+class Cuda_atomic_compare_and_swap_element(AbstractTemplate):
+    key = cuda.atomic.compare_and_swap_element
+
+    def generic(self, args, kws):
+        assert not kws
+        ary, idx, old, val = args
+        dty = ary.dtype
+
+        if (dty in integer_numba_types and idx in integer_numba_types and
+                ary.ndim == 1):
+            return signature(dty, ary, types.intp, dty, dty)
+
+
+def resolve_compare_and_swap_element(self, mod):
+    return types.Function(Cuda_atomic_compare_and_swap_element)
+
+
 class GridFunction(CallableTemplate):
     def generic(self):
         def typer(ndim):
@@ -452,6 +470,9 @@ class CudaAtomicTemplate(AttributeTemplate):
     def resolve_compare_and_swap(self, mod):
         return types.Function(Cuda_atomic_compare_and_swap)
 
+    def resolve_compare_and_swap_element(self, mod):
+        return types.Function(Cuda_atomic_compare_and_swap_element)
+
 
 @register_attr
 class CudaModuleTemplate(AttributeTemplate):

Next, your lowering implementation doesn’t work because it’s not possible to just add an integer to the pointer (from your code above):

        return nvvmutils.atomic_cmpxchg(builder, lmod, bitwidth, ptr+idx, old, val)

This is because code in lowering functions needs to construct the LLVM IR for the operation they’re implementing, not directly implement it in Python. The _atomic_dispatcher decorator implements the indexing logic used by most of the other atomic functions, but as the compare_and_swap_element function doesn’t have the same signature as these other functions, it doesn’t apply directly. Instead we can make the lowering function work by duplicating most of the logic from it instead. This makes our lowering function look like:

@lower(stubs.atomic.compare_and_swap_element, types.Array, types.intp,
       types.Any, types.Any)
def ptx_atomic_cas_element_tuple(context, builder, sig, args):
    aryty, indty, oldty, valty = sig.args
    ary, inds, old, val = args
    dtype = aryty.dtype

    indty, indices = _normalize_indices(context, builder, indty, inds)

    if dtype != valty:
        raise TypeError("expect %s but got %s" % (dtype, valty))

    if aryty.ndim != len(indty):
        raise TypeError("indexing %d-D array with %d-D index" %
                        (aryty.ndim, len(indty)))

    lary = context.make_array(aryty)(context, builder, ary)
    ptr = cgutils.get_item_pointer(context, builder, aryty, lary, indices,
                                   wraparound=True)

    if dtype in (cuda.cudadecl.integer_numba_types):
        lmod = builder.module
        bitwidth = dtype.bitwidth
        return nvvmutils.atomic_cmpxchg(builder, lmod, bitwidth, ptr,
                                        old, val)
    else:
        raise TypeError('Unimplemented atomic compare_and_swap_element ')

This gives enough to make compare_and_swap_element work well enough to pass the additional tests added above. You can find these changes in the grm-cas-element branch in commit d3d7c43.

This is not quite PR-ready. Some additional work needed:

  • Rename the function to cas_element, or something else appropriate - compare_and_swap_element seems a bit long to me.
  • Split out the tests of compare_and_swap_element from the tests of compare_and_swap, do something about the duplication / copy-pasting of code between the tests.
  • Ensure that it works in the simulator (I haven’t looked at the simulator implementation).
  • Do something about the fact that most of the indexing logic is copied from _atomic_dispatcher function - perhaps the logic can be shared between _atomic_dispatcher and ptx_atomic_cas_element_tuple, or maybe _atomic_dispatcher can be adapted to decorate ptx_atomic_cas_element_tuple - I haven’t given this thought yet so would require some judgment.
  • Add documentation.
  • Fix the docstring in the stub.
  • In the interest of time, I pasted some of your changes into files at more-or-less random points - some functions could be placed more appropriately.
1 Like