I have recently been working on a (WIP) JAX-Numba bridge for calling Numba-jitted functions from within JAX-jitted functions, and I was wondering whether the proposed API could be made more idiomatic for Numba users.
I also have a few more specific questions. I think resolving them would improve the code:
- I was not able to unpack the arguments passed to the “wrapped” Numba function (line 160), so we have to resort to unpacking them inside the Numba function itself (e.g. line 21). Is there any way around this?
- At the moment the input shapes are fixed, so we have to recompile whenever the function is called with different shapes. Is there a straightforward strategy for passing the number of dimensions of the carray dynamically?