I have recently been working on a (WIP) JAX-Numba bridge for calling Numba-jitted functions from JAX-jitted functions, and I was wondering whether the proposed API could be made more idiomatic for Numba users.
I also have some more specific questions whose answers I think would improve the code:
I was not able to unpack the arguments passed to the “wrapped” Numba function (line 160), so we have to resort to unpacking them inside the Numba function itself (e.g. line 21). Is there any way around this?
At the moment the input shapes are fixed, so we have to recompile whenever the function is called with different shapes. Is there a straightforward way to pass the number of dimensions of the carray dynamically?
I’m interested in having a play with this to work out where Numba’s getting stuck. Are there some instructions for getting set up or will the standard JAX instructions work ok? I presume the example here demonstrates the issues?
The Tuple type requires a known length, while lists are of arbitrary length. In your example, it would be trivial to write add(x, y), but I suspect you are thinking of a more complex example. Could the more complex example always be reduced to a tuple whose length is known in advance?
Look at these examples:
There might be some sample code showing how to transform a list of known length into a tuple, but I haven’t found it.
Thanks for the reply! The length of the list is indeed known, but it’s not a compile-time constant, because the cfunc is defined inside an outer scope (the code in question). One option that comes to mind is writing separate functions create_tuple_of_size_n(list) for each n, but I hope there’s a better solution.
I think you can write create_tuple_from_list(list, len_list) as a generated_jit function or an overload that interprets len_list as Literal[int] and therefore specializes on the size as part of dispatching. I don’t think you can avoid compiling once per unique length, though.