Support for atomic types in Numba?

Hi, is it possible to use atomic types?

Our use case involves interacting with shared memory, where synchronization is done via atomic operations.

Is something like this possible with Numba, or will I have to resort to intrinsics?

Okay, so I used Claude Code and was able to get something working: GitHub - hbina/numba, branch `hbina-cpu-atomics-claude`.
Is this even usable, or is the LLM hallucinating the solution?
I am trying to grok the changes; some of them make sense, some don't.
The generated LLVM IR does seem to contain "atomic", but I am not sure whether it actually works.
I certainly cannot make it work with numpy.memmap.

import multiprocessing
import tempfile
import os

import numba.cpu
from numba import njit
import numpy as np

# Stress test: several processes atomically increment one int64 stored in a
# memory-mapped temporary file.
# NOTE(review): ``numba.cpu.atomic`` is not upstream Numba; it comes from the
# experimental hbina/numba fork. Its arguments are assumed to be
# (array, index, value), returning the value before the add -- confirm.


@numba.njit(cache=True)
def _fetch_add_kernel(buf, iteration):
    """Add 1 to ``buf[0]`` *iteration* times and print the last pre-add value.

    *buf* is passed in rather than read from a module global. Numba treats
    global NumPy arrays as compile-time constants, so a global memmap would be
    frozen into the compiled code and the shared file would never change.
    """
    final = 0
    for _ in range(iteration):
        # Index 0: the shared array has exactly one element. Index 1 would be
        # out of bounds, and Numba does not bounds-check by default.
        final = numba.cpu.atomic.fetch_add(buf, 0, 1)
    print("iteration:", iteration, "final:", final)


def fetch_add(iteration, buf=None):
    """Run the atomic-increment kernel on *buf* (default: the global ``arr``).

    The global is looked up here, in plain Python, at call time. It therefore
    picks up whatever array the current process has bound to ``arr``.
    """
    if buf is None:
        buf = arr
    _fetch_add_kernel(buf, iteration)


def _worker(path, iteration):
    """Child-process entry point: map the shared file and run the increments.

    The path is passed explicitly instead of being read from a global.
    Under the ``spawn`` start method the child re-imports this module, so
    module-level state (such as a temp file) would not be the parent's.
    """
    shared = np.memmap(path, dtype=np.int64, mode='r+', shape=(1,))
    fetch_add(iteration, shared)


if __name__ == "__main__":
    # Guarded so that children started with ``spawn`` (the default on macOS
    # and Windows) do not recursively create temp files and processes.
    temp_file = tempfile.NamedTemporaryFile(prefix='shared_array_', dir='/tmp', delete=False)
    temp_file.close()  # only the path is needed; np.memmap opens its own handle
    arr = np.memmap(temp_file.name, dtype=np.int64, mode='w+', shape=(1,))

    try:
        processes = [
            multiprocessing.Process(target=_worker, args=(temp_file.name, 1024 * 1024))
            for _ in range(8)
        ]
        for p in processes:
            p.start()
        for p in processes:
            p.join()

        print("end:", arr[0], "expected:", 8 * 1024 * 1024)
    finally:
        # Dropping the last reference unmaps the file; then remove it.
        del arr
        os.unlink(temp_file.name)


# Single-process sanity check: fetch_add must return the old value and
# update the array in place.
@numba.njit(cache=False)
def atomic_fetch_add_test(arr, idx, val):
    """Add *val* to ``arr[idx]`` via the fork's atomic fetch_add; return the prior value."""
    previous = numba.cpu.atomic.fetch_add(arr, idx, val)
    return previous


# Exercise an 8-bit element type.
arr = np.array([10, 20, 30], dtype=np.uint8)
old_val = atomic_fetch_add_test(arr, 1, np.uint8(5))
assert old_val == 20  # value before the add
assert arr[1] == 20 + 5  # value after the add

I would appreciate any feedback :slight_smile:

I was able to get it working.
The API currently looks like this:

#!/usr/bin/env python3
"""
IPC counter shared between processes via a real, on-disk mmap,
viewed as NumPy arrays and updated with Numba atomics.
"""

import mmap
import multiprocessing
import struct
import tempfile

import numpy as np

import numba

# Stress-test parameters.
ITERATION_COUNT = 16  # number of worker processes (one ready flag each)
ADD_COUNT = 1 << 20  # increments performed by each process (1024 * 1024)
INT64_FMT = "<q"  # little-endian signed int64
BYTES_NEEDED = struct.calcsize(INT64_FMT)  # size of one counter: 8 bytes
USE_ATOMIC = True  # False makes _numba_run use a plain, non-atomic increment


def create_mmap_file(file_size):
    """Create a zero-filled temporary file of *file_size* bytes and return its path.

    The file is created with ``delete=False`` so that worker processes can
    reopen it by name. The caller is responsible for removing it.
    """
    with tempfile.NamedTemporaryFile(prefix="shared_counter_", delete=False) as tmp:
        # Write real zero bytes so every counter/flag starts at 0 and the file
        # is long enough to be memory-mapped at the requested size.
        tmp.write(b"\x00" * file_size)
    return tmp.name


@numba.njit(cache=False)
def _numba_run(sync_buffer, add_buffer, idx, iterations):
    """Wait at a spin barrier, then increment the shared counter ``iterations`` times.

    Parameters
    ----------
    sync_buffer : ready flags, one slot per process. This call sets slot
        ``idx`` to 1, then waits until slots 0..ITERATION_COUNT-1 all read 1.
    add_buffer : array whose element 0 is the shared counter.
    idx : this process's slot in ``sync_buffer``.
    iterations : number of increments to perform.

    Returns
    -------
    The sum of the counter values this process observed immediately before
    each of its increments.

    NOTE(review): ``numba.atomic`` is not part of upstream Numba; it comes
    from an experimental fork. The calls below assume ``store(arr, i, v)``,
    ``load(arr, i)`` and ``fetch_add(arr, i, v)`` (returning the pre-add
    value) -- confirm against the fork.
    NOTE: ITERATION_COUNT and USE_ATOMIC are module globals. Numba treats
    globals as compile-time constants, so changing them after compilation has
    no effect here.
    """
    total = 0

    # Atomically set this process as ready
    numba.atomic.store(sync_buffer, idx, 1)

    # Wait for all processes to be ready using atomic loads.
    # Busy-wait with no timeout or back-off: if any worker never stores its
    # flag (e.g. it crashed before getting here), all others spin forever.
    while True:
        all_ready = True
        for i in range(ITERATION_COUNT):
            if numba.atomic.load(sync_buffer, i) != 1:
                all_ready = False
                break

        if all_ready:
            break

    for i in range(iterations):
        # USE_ATOMIC selects the atomic increment or a plain (racy) one.
        if USE_ATOMIC:
            # Use atomic fetch_add to safely increment the counter
            old_val = numba.atomic.fetch_add(add_buffer, 0, 1)
            total += old_val
        else:
            # Plain read-modify-write: not atomic across processes, so
            # concurrent workers can lose increments.
            total += add_buffer[0]
            add_buffer[0] += 1

    return total


def map_shared_array(path: str, length: int, dtype) -> np.ndarray:
    """Map the first *length* bytes of *path* read/write and view them as *dtype*.

    The mapping uses ``mmap.ACCESS_WRITE``, so writes through the returned
    array go to the underlying file and are shared with every process that
    maps the same file.
    """
    with open(path, "r+b") as fh:
        # Closing the file object does not unmap the region, so the handle
        # does not need to outlive this call (the original leaked it).
        mm = mmap.mmap(fh.fileno(), length, access=mmap.ACCESS_WRITE)
    return np.frombuffer(mm, dtype=dtype)


def worker(sync_path: str, add_path: str, idx: int, iterations: int) -> None:
    """Process entry point: map both shared files and run the Numba kernel.

    The sync file is viewed as ITERATION_COUNT int8 ready flags (one byte
    per process); the add file is viewed as a single int64 counter.
    """
    sync_buffer = map_shared_array(sync_path, ITERATION_COUNT, np.int8)  # 1 byte per flag
    add_buffer = map_shared_array(add_path, 8, np.int64)  # one int64 counter
    _numba_run(sync_buffer, add_buffer, idx, iterations)


def __main():
    """Run ITERATION_COUNT worker processes against one shared counter and verify it.

    Each worker performs ADD_COUNT increments, so when every increment is
    atomic the counter must end at ITERATION_COUNT * ADD_COUNT. Both shared
    files are removed afterwards, even if a check fails.
    """
    import os  # only needed for temp-file cleanup

    # One int8 ready flag per worker; workers map exactly ITERATION_COUNT
    # bytes of this file, so no per-worker int64 space is needed.
    sync_path = create_mmap_file(ITERATION_COUNT)
    add_path = create_mmap_file(8)  # a single int64 counter
    try:
        procs = [
            multiprocessing.Process(
                target=worker, args=(sync_path, add_path, idx, ADD_COUNT)
            )
            for idx in range(ITERATION_COUNT)
        ]

        for p in procs:
            p.start()
        # NOTE: workers spin at a barrier until all are ready, so a worker that
        # dies before reaching it makes the others (and this join) hang.
        for p in procs:
            p.join()

        bad_exits = [p.exitcode for p in procs if p.exitcode != 0]
        if bad_exits:
            raise RuntimeError(f"worker(s) exited with non-zero codes: {bad_exits}")

        # Copy the counter out through a short-lived read-only map; the file
        # handle is closed right after (the original leaked it).
        with open(add_path, "rb") as add_file:
            final_val = np.memmap(add_file, dtype=np.int64, mode="r")[0]
        expected = ITERATION_COUNT * ADD_COUNT
        print(f"Final value: {final_val}, expected: {expected}")
        assert final_val == expected
    finally:
        os.unlink(sync_path)
        os.unlink(add_path)


if __name__ == "__main__":
    __main()