Hi, is it possible to use atomic types?
Our use case is interacting with a shared memory and the synchronization is done via some atomics.
Is something like this possible with numba? I am guessing I have to resort to intrinsics?
Hi, is it possible to use atomic types?
Our use case is interacting with a shared memory and the synchronization is done via some atomics.
Is something like this possible with numba? I am guessing I have to resort to intrinsics?
Okay so, I used Claude Code and was able to get something working: GitHub — hbina/numba, branch `hbina-cpu-atomics-claude`.
Is this even usable? Is the LLM hallucinating the solution?
I am trying to grok the changes and some of it makes sense, some not.
The LLVM IR produced does seem to have “atomic” in it but I am not sure if it actually works.
I am certainly not able to make it work via numpy.memmap.
import multiprocessing
import tempfile
import os
import numba.cpu
from numba import njit
import numpy as np
# Create a memory-mapped int64 array backed by a temporary file so that
# multiple processes can share it.
temp_file = tempfile.NamedTemporaryFile(prefix='shared_array_', dir='/tmp', delete=False)
# np.memmap reopens the file by name, so close this handle immediately
# instead of leaking it for the lifetime of the script.
temp_file.close()
arr = np.memmap(temp_file.name, dtype=np.int64, mode='w+', shape=(1,))
@numba.njit(cache=True)
def fetch_add(iteration):
    """Atomically add 1 to arr[0] `iteration` times.

    Prints the value the final fetch_add read; returns None.  Reads the
    module-global `arr` memmap, which each process re-binds before calling.
    NOTE(review): `numba.cpu.atomic` is the custom fork's API, not upstream
    numba; `cache=True` with a captured global memmap may also pin a stale
    pointer in the cached artifact — verify against the fork.
    """
    final = 0
    for i in range(iteration):
        # BUG FIX: arr has shape (1,), so the only valid index is 0
        # (the original passed idx=1, which is out of bounds).
        final = numba.cpu.atomic.fetch_add(arr, 0, 1)
    print("iteration:", iteration, "final:", final)
def _run(iteration):
    """Worker entry point: re-open the shared memmap in this process, then
    hammer the shared counter with atomic increments."""
    global arr
    arr = np.memmap(temp_file.name, dtype=np.int64, mode='r+', shape=(1,))
    fetch_add(iteration)

# Define the worker once, at module level, instead of redefining it on every
# loop iteration: the repeated definition was redundant, and a loop-local
# function cannot be pickled under the 'spawn' start method at all.
# NOTE(review): this script has no `if __name__ == "__main__":` guard, which
# the multiprocessing docs require on spawn-based platforms — confirm the
# intended platform uses fork.
processes = [
    multiprocessing.Process(target=_run, args=(1024 * 1024,))
    for _ in range(8)
]
for p in processes:
    p.start()
for p in processes:
    p.join()
print("end:", arr[0])

# Clean up the temporary file.
arr._mmap.close()  # Close the underlying mmap
del arr            # Drop the array reference
os.unlink(temp_file.name)  # Remove the temporary file
# Test section remains unchanged
@numba.njit(cache=False)
def atomic_fetch_add_test(arr, idx, val):
    # `numba.cpu.atomic.fetch_add` comes from the custom fork
    # (hbina-cpu-atomics-claude), not upstream numba.  Per the asserts below
    # it returns the value stored at arr[idx] *before* the add.
    return numba.cpu.atomic.fetch_add(arr, idx, val)
# Test with uint8 array
arr = np.array([10, 20, 30], dtype=np.uint8)
old_val = atomic_fetch_add_test(arr, 1, np.uint8(5))
assert old_val == 20  # Previous value
assert arr[1] == 25  # New value
I would appreciate any feedback!
I was able to get it working.
The API currently looks like this
#!/usr/bin/env python3
"""
IPC counter shared via a real, on-disk mmap.
(Uses the standard library's mmap plus NumPy views and a numba-jitted worker.)
"""
import mmap
import multiprocessing
import struct
import tempfile
import numpy as np
import numba
# from multiprocessing import Process, Lock, set_start_method
ITERATION_COUNT = 16  # number of worker processes
ADD_COUNT = 1024 * 1024  # increments per process
INT64_FMT = "<q"  # little-endian signed int64, 8 bytes
BYTES_NEEDED = struct.calcsize(INT64_FMT)  # == 8
USE_ATOMIC = True  # set False to observe lost updates from the plain read-modify-write
def create_mmap_file(file_size):
    """Create a zero-filled on-disk temporary file of `file_size` bytes.

    Returns the file's path; the caller is responsible for deleting it.
    """
    tmp = tempfile.NamedTemporaryFile(prefix="shared_counter_", delete=False)
    try:
        # BUG FIX: the original truncated to BYTES_NEEDED (8 bytes) regardless
        # of the requested size, which is too small for the sync buffer and
        # only worked because the write below happened to extend the file.
        tmp.truncate(file_size)
    finally:
        tmp.close()
    # Zero-fill explicitly; the original leaked this handle by never closing it.
    with open(tmp.name, "r+b") as f:
        f.write(b"\x00" * file_size)
        f.flush()
    return tmp.name
@numba.njit(cache=False)
def _numba_run(sync_buffer, add_buffer, idx, iterations):
    """Barrier on sync_buffer, then atomically bump add_buffer[0] `iterations` times.

    sync_buffer: int8 view of a shared mmap, one readiness flag per process.
    add_buffer:  int64 view of a shared mmap holding the counter at index 0.
    idx:         this process's slot in sync_buffer.
    Returns the sum of the pre-increment counter values observed.
    NOTE(review): `numba.atomic.*` is the custom fork's API, not upstream numba.
    """
    total = 0
    # Atomically set this process as ready
    numba.atomic.store(sync_buffer, idx, 1)
    # Spin until every process has published its readiness flag — a busy-wait
    # barrier built from atomic loads (statement order here IS the protocol).
    while True:
        all_ready = True
        for i in range(ITERATION_COUNT):
            if numba.atomic.load(sync_buffer, i) != 1:
                all_ready = False
                break
        if all_ready:
            break
    for i in range(iterations):
        # Use atomic fetch_add to safely increment the counter
        if USE_ATOMIC:
            old_val = numba.atomic.fetch_add(add_buffer, 0, 1)
            total += old_val
        else:
            # Non-atomic path: racy read-modify-write, kept for comparison.
            total += add_buffer[0]
            add_buffer[0] += 1
    return total
def worker(sync_path: str, add_path: str, idx: int, iterations: int) -> None:
    """Map the shared sync and counter files, then run the jitted worker loop.

    The mmap objects remain valid after their backing file objects are closed,
    so the file handles (leaked by the original) are closed immediately.
    """
    with open(sync_path, "r+b") as sync_file:
        # One readiness flag per process: ITERATION_COUNT bytes viewed as int8
        # (the original comment claimed "8 bytes per int64", which was wrong).
        sync_mm = mmap.mmap(sync_file.fileno(), ITERATION_COUNT, access=mmap.ACCESS_WRITE)
    sync_buffer = np.frombuffer(sync_mm, dtype=np.int8)
    with open(add_path, "r+b") as add_file:
        # One int64 counter at offset 0.
        add_mm = mmap.mmap(add_file.fileno(), BYTES_NEEDED, access=mmap.ACCESS_WRITE)
    add_buffer = np.frombuffer(add_mm, dtype=np.int64)
    _numba_run(sync_buffer, add_buffer, idx, iterations)
def __main():
    """Spawn ITERATION_COUNT workers, wait for them, and verify the counter."""
    # Keep the original 8-bytes-per-slot sizing for the sync file even though
    # only ITERATION_COUNT bytes are mapped by each worker.
    sync_path = create_mmap_file(ITERATION_COUNT * 8)
    add_path = create_mmap_file(BYTES_NEEDED)
    procs = [
        multiprocessing.Process(
            target=worker, args=(sync_path, add_path, idx, ADD_COUNT)
        )
        for idx in range(ITERATION_COUNT)
    ]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
    # Read the final counter directly with struct; the original opened the
    # file (never closing it) and built a whole np.memmap for one int64.
    with open(add_path, "rb") as add_file:
        (final_val,) = struct.unpack(INT64_FMT, add_file.read(BYTES_NEEDED))
    expected = ITERATION_COUNT * ADD_COUNT
    print(f"Final value: {final_val}, expected: {expected}")
    assert final_val == expected

if __name__ == "__main__":
    __main()