Calling Numba from JAX

Hey all,

I have recently been working on a JAX-Numba bridge (still WIP) that calls Numba-jitted functions from JAX-jitted functions, and I was wondering whether the proposed API could be made more idiomatic for Numba users.

I also have some more specific questions that I think would improve the code if resolved:

  • I was not able to unpack the arguments passed to the “wrapped” Numba function (line 160), so we have to resort to unpacking inside the Numba function (e.g. line 21). Is there any way around this?
  • At the moment the input shapes are fixed, and we have to recompile if the function is called with different shapes. Is there a straightforward strategy for dynamically passing the number of dimensions of the carray?

Hi @josipd,

This is really interesting, thanks for posting.

I’m interested in having a play with this to work out where Numba’s getting stuck. Are there some instructions for getting set up or will the standard JAX instructions work ok? I presume the example here demonstrates the issues?

Thanks.

Glad to hear! The example in the module docstring does work out of the box with the current JAX version.

For the first problem, a minimal example is the following:

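import numba

@numba.jit
def add(x, y):
    return x + y

@numba.cfunc("float64(float64, float64)")
def foo(x, y):
    # Collect the arguments in a list, then try to unpack them into add().
    args = []
    args.append(x)
    args.append(y)
    # Fails to compile: Numba cannot convert a list into a tuple here.
    return add(*tuple(args))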

This doesn’t work because no overload of tuple(list) or tuple(ListType) exists. Is there a workaround?

The Tuple type requires a known length, while lists are of arbitrary length. In your example, it would be trivial to write add(x, y) but I suspect you are thinking of a more complex example. Could the more complex example always be reduced to a tuple length which is known in advance?

Look at these examples:


There might be some sample code on how to transform a list of known length into a tuple, but I haven’t found it.

Hope it helps,
Luk

Thanks for the reply! The length of the list is indeed known, but it’s not a constant as the cfunc is defined inside an outer scope (the code in question). One thing that comes to mind is to create functions create_tuple_of_size_n(list) for varying n, but I hope there’s a better solution.

I think you can write a create_tuple_from_list(list, len_list) using generated_jit or an overload that interprets len_list as Literal[int] and therefore specializes on the size as part of dispatch. I don’t think you can avoid compiling multiple times (once per unique length), though.
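To make the "one specialization per length" idea concrete, here is a rough sketch. It swaps in plain source generation with exec instead of generated_jit or an overload, so it is not the approach described above, only a stand-in that shows the same per-length pattern. The names make_tuple_builder and build are hypothetical. The decorator parameter is an assumption: the intent is that something like numba.njit could be passed in, but the snippet below only exercises the pure-Python path.

```python
from functools import lru_cache


def _identity(fn):
    # Default "decorator": leave the generated function as plain Python.
    return fn


@lru_cache(maxsize=None)
def make_tuple_builder(n, decorator=_identity):
    """Return a function that turns a length-n sequence into an n-tuple.

    Hypothetical helper: one function is generated (and cached) per
    distinct n, mirroring the "once per unique length" compilation cost.
    """
    # Spell out every index so the tuple length is fixed in the source itself.
    items = "".join(f"seq[{i}], " for i in range(n))
    source = f"def build(seq):\n    return ({items})\n"
    namespace = {}
    exec(source, namespace)
    # Assumption: a JIT decorator could be applied here instead of _identity.
    return decorator(namespace["build"])


build2 = make_tuple_builder(2)
build3 = make_tuple_builder(3)
print(build2([1.0, 2.0]))
print(build3([1.0, 2.0, 3.0]))
```

Because the length is known when the cfunc is defined in the outer scope, the builder for that length could be created there and captured by the inner function. Whether the generated function compiles cleanly under Numba for a given list type would still need checking.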