First class functions and structured arrays

Hi all,

I’m looking to take advantage of Numba’s support for first-class functions (see the “Types and signatures” page of the Numba documentation). Unfortunately, I am running into an unimplemented code path.

First, let’s consider how things work with standard NumPy numeric types:

import numba as nb
import numpy as np

@nb.vectorize
def fun(x):
    """Element-wise kernel: return ``x + 1``.

    Declared with a bare ``@nb.vectorize`` (no explicit signature); the
    demo below applies it to a uint8 array.
    """
    return x + 1

@nb.jit
def apply_fun(x, fun):
    """Call the first-class function argument ``fun`` on ``x`` in jitted code.

    ``fun`` is a compiled function passed in as a value (in the demos, an
    ``nb.njit``-compiled lambda). The parameter deliberately shadows the
    module-level ``fun``.
    """
    return fun(x)

# Input: three uint8 ones.
x_num = np.ones(3, dtype=np.uint8)
# Calling the vectorized function directly.
print(fun(x_num))
# no problem: [2 2 2]
# Passing it as a first-class function: wrap the call in an
# nb.njit-compiled lambda and hand that to apply_fun.
print(apply_fun(x_num, nb.njit(lambda elem: fun(elem))))
# no problem: [2 2 2]

Now let’s do the same but with structured arrays.

# Structured dtype with two uint8 fields (align=True). nb_dtype is the
# corresponding Numba Record type, used inside jitted code below.
np_dtype = np.dtype([("field_1", np.uint8), ("field_2", np.uint8)], align=True)
nb_dtype = nb.from_dtype(np_dtype)

# Variant also tried: an explicit signature taking and returning the record type.
# @nb.vectorize([nb_dtype(nb_dtype)])
@nb.vectorize
def fun(x):
    """Element-wise kernel over records: map one record to a new record.

    Computes ``field_1 = x.field_1 - 1`` and
    ``field_2 = (x.field_1 * x.field_2) << 1``. This redefines the
    numeric ``fun`` defined earlier.
    """
    # Build the result in a 1-element scratch record array, write each
    # field through its array view, then return that single record.
    out = np.empty(1, dtype=nb_dtype)
    out.field_1[:] = x.field_1 - 1
    out.field_2[:] = (x.field_1 * x.field_2) << 1
    return out[0]

# Input: three records, each with field_1 == 1 and field_2 == 1.
x_struct = np.ones(3, dtype=np_dtype)
# Calling the vectorized function directly works.
print(fun(x_struct))
# no problem: [(0, 2) (0, 2) (0, 2)]

# Passing it as a first-class function (same pattern as the numeric case) does not:
print(apply_fun(x_struct, nb.njit(lambda elem: fun(elem))))
# fails! -> NotImplementedError: unsupported type for input operand: Record(field_1[type=uint8;offset=0],field_2[type=uint8;offset=1];2;True)
# Per the traceback, the error is raised while compiling the njit lambda.
# Its fun(elem) call is lowered as an array expression through Numba's
# ufunc kernel (npyimpl.numpy_ufunc_kernel -> _prepare_argument), which
# rejects Record-typed input operands.

Note that I also tried specifying the type signature explicitly in the `@nb.vectorize` decorator (the commented-out line above). That fails in a slightly different code path, but omitting the referenced lines in `numba/np/ufunc/dufunc.py` (numba/numba on GitHub) leads to the same error.

Should this be characterized as a bug, or are there fundamental reasons for the NotImplementedError? I’m happy to attempt a fix if it’s the former, but I would appreciate any pointers from someone more familiar with the code base before attempting it.

Complete stack trace:

Traceback (most recent call last):
  File "/Users/adi/Documents/numba-vectorization-experiments/x.py", line 35, in <module>
    print(apply_fun(x_struct, nb.njit(lambda elem: fun(elem))))
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/adi/micromamba/envs/numba-bug/lib/python3.12/site-packages/numba/core/dispatcher.py", line 442, in _compile_for_args
    raise e
  File "/Users/adi/micromamba/envs/numba-vectorization/lib/python3.12/site-packages/numba/core/dispatcher.py", line 375, in _compile_for_args
    return_val = self.compile(tuple(argtypes))
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/adi/micromamba/envs/numba-vectorization/lib/python3.12/site-packages/numba/core/dispatcher.py", line 905, in compile
    cres = self._compiler.compile(args, return_type)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/adi/micromamba/envs/numba-vectorization/lib/python3.12/site-packages/numba/core/dispatcher.py", line 80, in compile
    status, retval = self._compile_cached(args, return_type)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/adi/micromamba/envs/numba-vectorization/lib/python3.12/site-packages/numba/core/dispatcher.py", line 94, in _compile_cached
    retval = self._compile_core(args, return_type)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/adi/micromamba/envs/numba-vectorization/lib/python3.12/site-packages/numba/core/dispatcher.py", line 107, in _compile_core
    cres = compiler.compile_extra(self.targetdescr.typing_context,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/adi/micromamba/envs/numba-vectorization/lib/python3.12/site-packages/numba/core/compiler.py", line 744, in compile_extra
    return pipeline.compile_extra(func)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/adi/micromamba/envs/numba-vectorization/lib/python3.12/site-packages/numba/core/compiler.py", line 438, in compile_extra
    return self._compile_bytecode()
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/adi/micromamba/envs/numba-vectorization/lib/python3.12/site-packages/numba/core/compiler.py", line 506, in _compile_bytecode
    return self._compile_core()
           ^^^^^^^^^^^^^^^^^^^^
  File "/Users/adi/micromamba/envs/numba-vectorization/lib/python3.12/site-packages/numba/core/compiler.py", line 481, in _compile_core
    raise e
  File "/Users/adi/micromamba/envs/numba-vectorization/lib/python3.12/site-packages/numba/core/compiler.py", line 472, in _compile_core
    pm.run(self.state)
  File "/Users/adi/micromamba/envs/numba-vectorization/lib/python3.12/site-packages/numba/core/compiler_machinery.py", line 364, in run
    raise e
  File "/Users/adi/micromamba/envs/numba-vectorization/lib/python3.12/site-packages/numba/core/compiler_machinery.py", line 356, in run
    self._runPass(idx, pass_inst, state)
  File "/Users/adi/micromamba/envs/numba-vectorization/lib/python3.12/site-packages/numba/core/compiler_lock.py", line 35, in _acquire_compile_lock
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/adi/micromamba/envs/numba-vectorization/lib/python3.12/site-packages/numba/core/compiler_machinery.py", line 311, in _runPass
    mutated |= check(pss.run_pass, internal_state)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/adi/micromamba/envs/numba-vectorization/lib/python3.12/site-packages/numba/core/compiler_machinery.py", line 273, in check
    mangled = func(compiler_state)
              ^^^^^^^^^^^^^^^^^^^^
  File "/Users/adi/micromamba/envs/numba-vectorization/lib/python3.12/site-packages/numba/core/typed_passes.py", line 112, in run_pass
    typemap, return_type, calltypes, errs = type_inference_stage(
                                            ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/adi/micromamba/envs/numba-vectorization/lib/python3.12/site-packages/numba/core/typed_passes.py", line 93, in type_inference_stage
    errs = infer.propagate(raise_errors=raise_errors)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/adi/micromamba/envs/numba-vectorization/lib/python3.12/site-packages/numba/core/typeinfer.py", line 1083, in propagate
    errors = self.constraints.propagate(self)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/adi/micromamba/envs/numba-vectorization/lib/python3.12/site-packages/numba/core/typeinfer.py", line 182, in propagate
    raise e
  File "/Users/adi/micromamba/envs/numba-vectorization/lib/python3.12/site-packages/numba/core/typeinfer.py", line 160, in propagate
    constraint(typeinfer)
  File "/Users/adi/micromamba/envs/numba-vectorization/lib/python3.12/site-packages/numba/core/typeinfer.py", line 583, in __call__
    self.resolve(typeinfer, typevars, fnty)
  File "/Users/adi/micromamba/envs/numba-vectorization/lib/python3.12/site-packages/numba/core/typeinfer.py", line 606, in resolve
    sig = typeinfer.resolve_call(fnty, pos_args, kw_args)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/adi/micromamba/envs/numba-vectorization/lib/python3.12/site-packages/numba/core/typeinfer.py", line 1577, in resolve_call
    return self.context.resolve_function_type(fnty, pos_args, kw_args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/adi/micromamba/envs/numba-vectorization/lib/python3.12/site-packages/numba/core/typing/context.py", line 196, in resolve_function_type
    res = self._resolve_user_function_type(func, args, kws)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/adi/micromamba/envs/numba-vectorization/lib/python3.12/site-packages/numba/core/typing/context.py", line 248, in _resolve_user_function_type
    return func.get_call_type(self, args, kws)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/adi/micromamba/envs/numba-vectorization/lib/python3.12/site-packages/numba/core/types/functions.py", line 541, in get_call_type
    self.dispatcher.get_call_template(args, kws)
  File "/Users/adi/micromamba/envs/numba-vectorization/lib/python3.12/site-packages/numba/core/dispatcher.py", line 318, in get_call_template
    self.compile(tuple(args))
  File "/Users/adi/micromamba/envs/numba-vectorization/lib/python3.12/site-packages/numba/core/dispatcher.py", line 905, in compile
    cres = self._compiler.compile(args, return_type)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/adi/micromamba/envs/numba-vectorization/lib/python3.12/site-packages/numba/core/dispatcher.py", line 80, in compile
    status, retval = self._compile_cached(args, return_type)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/adi/micromamba/envs/numba-vectorization/lib/python3.12/site-packages/numba/core/dispatcher.py", line 94, in _compile_cached
    retval = self._compile_core(args, return_type)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/adi/micromamba/envs/numba-vectorization/lib/python3.12/site-packages/numba/core/dispatcher.py", line 107, in _compile_core
    cres = compiler.compile_extra(self.targetdescr.typing_context,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/adi/micromamba/envs/numba-vectorization/lib/python3.12/site-packages/numba/core/compiler.py", line 744, in compile_extra
    return pipeline.compile_extra(func)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/adi/micromamba/envs/numba-vectorization/lib/python3.12/site-packages/numba/core/compiler.py", line 438, in compile_extra
    return self._compile_bytecode()
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/adi/micromamba/envs/numba-vectorization/lib/python3.12/site-packages/numba/core/compiler.py", line 506, in _compile_bytecode
    return self._compile_core()
           ^^^^^^^^^^^^^^^^^^^^
  File "/Users/adi/micromamba/envs/numba-vectorization/lib/python3.12/site-packages/numba/core/compiler.py", line 481, in _compile_core
    raise e
  File "/Users/adi/micromamba/envs/numba-vectorization/lib/python3.12/site-packages/numba/core/compiler.py", line 472, in _compile_core
    pm.run(self.state)
  File "/Users/adi/micromamba/envs/numba-vectorization/lib/python3.12/site-packages/numba/core/compiler_machinery.py", line 364, in run
    raise e
  File "/Users/adi/micromamba/envs/numba-vectorization/lib/python3.12/site-packages/numba/core/compiler_machinery.py", line 356, in run
    self._runPass(idx, pass_inst, state)
  File "/Users/adi/micromamba/envs/numba-vectorization/lib/python3.12/site-packages/numba/core/compiler_lock.py", line 35, in _acquire_compile_lock
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/adi/micromamba/envs/numba-vectorization/lib/python3.12/site-packages/numba/core/compiler_machinery.py", line 311, in _runPass
    mutated |= check(pss.run_pass, internal_state)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/adi/micromamba/envs/numba-vectorization/lib/python3.12/site-packages/numba/core/compiler_machinery.py", line 273, in check
    mangled = func(compiler_state)
              ^^^^^^^^^^^^^^^^^^^^
  File "/Users/adi/micromamba/envs/numba-vectorization/lib/python3.12/site-packages/numba/core/typed_passes.py", line 468, in run_pass
    lower.lower()
  File "/Users/adi/micromamba/envs/numba-vectorization/lib/python3.12/site-packages/numba/core/lowering.py", line 187, in lower
    self.lower_normal_function(self.fndesc)
  File "/Users/adi/micromamba/envs/numba-vectorization/lib/python3.12/site-packages/numba/core/lowering.py", line 226, in lower_normal_function
    entry_block_tail = self.lower_function_body()
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/adi/micromamba/envs/numba-vectorization/lib/python3.12/site-packages/numba/core/lowering.py", line 256, in lower_function_body
    self.lower_block(block)
  File "/Users/adi/micromamba/envs/numba-vectorization/lib/python3.12/site-packages/numba/core/lowering.py", line 270, in lower_block
    self.lower_inst(inst)
  File "/Users/adi/micromamba/envs/numba-vectorization/lib/python3.12/site-packages/numba/core/lowering.py", line 448, in lower_inst
    val = self.lower_assign(ty, inst)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/adi/micromamba/envs/numba-vectorization/lib/python3.12/site-packages/numba/core/lowering.py", line 660, in lower_assign
    return self.lower_expr(ty, value)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/adi/micromamba/envs/numba-vectorization/lib/python3.12/site-packages/numba/core/lowering.py", line 1407, in lower_expr
    res = self.context.special_ops[expr.op](self, expr)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/adi/micromamba/envs/numba-vectorization/lib/python3.12/site-packages/numba/np/ufunc/array_exprs.py", line 406, in _lower_array_expr
    cres = context.compile_subroutine(builder, impl, inner_sig, flags=flags,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/adi/micromamba/envs/numba-vectorization/lib/python3.12/site-packages/numba/core/base.py", line 866, in compile_subroutine
    cres = self._compile_subroutine_no_cache(builder, impl, sig,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/adi/micromamba/envs/numba-vectorization/lib/python3.12/site-packages/numba/core/base.py", line 837, in _compile_subroutine_no_cache
    cres = compiler.compile_internal(self.typing_context, self,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/adi/micromamba/envs/numba-vectorization/lib/python3.12/site-packages/numba/core/compiler.py", line 818, in compile_internal
    return pipeline.compile_extra(func)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/adi/micromamba/envs/numba-vectorization/lib/python3.12/site-packages/numba/core/compiler.py", line 438, in compile_extra
    return self._compile_bytecode()
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/adi/micromamba/envs/numba-vectorization/lib/python3.12/site-packages/numba/core/compiler.py", line 506, in _compile_bytecode
    return self._compile_core()
           ^^^^^^^^^^^^^^^^^^^^
  File "/Users/adi/micromamba/envs/numba-vectorization/lib/python3.12/site-packages/numba/core/compiler.py", line 481, in _compile_core
    raise e
  File "/Users/adi/micromamba/envs/numba-vectorization/lib/python3.12/site-packages/numba/core/compiler.py", line 472, in _compile_core
    pm.run(self.state)
  File "/Users/adi/micromamba/envs/numba-vectorization/lib/python3.12/site-packages/numba/core/compiler_machinery.py", line 364, in run
    raise e
  File "/Users/adi/micromamba/envs/numba-vectorization/lib/python3.12/site-packages/numba/core/compiler_machinery.py", line 356, in run
    self._runPass(idx, pass_inst, state)
  File "/Users/adi/micromamba/envs/numba-vectorization/lib/python3.12/site-packages/numba/core/compiler_lock.py", line 35, in _acquire_compile_lock
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/adi/micromamba/envs/numba-vectorization/lib/python3.12/site-packages/numba/core/compiler_machinery.py", line 311, in _runPass
    mutated |= check(pss.run_pass, internal_state)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/adi/micromamba/envs/numba-vectorization/lib/python3.12/site-packages/numba/core/compiler_machinery.py", line 273, in check
    mangled = func(compiler_state)
              ^^^^^^^^^^^^^^^^^^^^
  File "/Users/adi/micromamba/envs/numba-vectorization/lib/python3.12/site-packages/numba/core/typed_passes.py", line 468, in run_pass
    lower.lower()
  File "/Users/adi/micromamba/envs/numba-vectorization/lib/python3.12/site-packages/numba/core/lowering.py", line 187, in lower
    self.lower_normal_function(self.fndesc)
  File "/Users/adi/micromamba/envs/numba-vectorization/lib/python3.12/site-packages/numba/core/lowering.py", line 226, in lower_normal_function
    entry_block_tail = self.lower_function_body()
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/adi/micromamba/envs/numba-vectorization/lib/python3.12/site-packages/numba/core/lowering.py", line 256, in lower_function_body
    self.lower_block(block)
  File "/Users/adi/micromamba/envs/numba-vectorization/lib/python3.12/site-packages/numba/core/lowering.py", line 270, in lower_block
    self.lower_inst(inst)
  File "/Users/adi/micromamba/envs/numba-vectorization/lib/python3.12/site-packages/numba/core/lowering.py", line 448, in lower_inst
    val = self.lower_assign(ty, inst)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/adi/micromamba/envs/numba-vectorization/lib/python3.12/site-packages/numba/core/lowering.py", line 660, in lower_assign
    return self.lower_expr(ty, value)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/adi/micromamba/envs/numba-vectorization/lib/python3.12/site-packages/numba/core/lowering.py", line 1196, in lower_expr
    res = self.lower_call(resty, expr)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/adi/micromamba/envs/numba-vectorization/lib/python3.12/site-packages/numba/core/lowering.py", line 925, in lower_call
    res = self._lower_call_normal(fnty, expr, signature)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/adi/micromamba/envs/numba-vectorization/lib/python3.12/site-packages/numba/core/lowering.py", line 1167, in _lower_call_normal
    res = impl(self.builder, argvals, self.loc)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/adi/micromamba/envs/numba-vectorization/lib/python3.12/site-packages/numba/core/base.py", line 1190, in __call__
    res = self._imp(self._context, builder, self._sig, args, loc=loc)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/adi/micromamba/envs/numba-vectorization/lib/python3.12/site-packages/numba/core/base.py", line 1220, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/Users/adi/micromamba/envs/numba-vectorization/lib/python3.12/site-packages/numba/np/ufunc/ufunc_base.py", line 15, in __call__
    return self.make_ufunc_kernel_fn(context, builder, sig, args,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/adi/micromamba/envs/numba-vectorization/lib/python3.12/site-packages/numba/np/npyimpl.py", line 490, in numpy_ufunc_kernel
    arguments = [_prepare_argument(context, builder, arg, tyarg)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/adi/micromamba/envs/numba-vectorization/lib/python3.12/site-packages/numba/np/npyimpl.py", line 326, in _prepare_argument
    raise NotImplementedError('unsupported type for {0}: {1}'.format(where,
NotImplementedError: unsupported type for input operand: Record(field_1[type=uint8;offset=0],field_2[type=uint8;offset=1];2;True)
1 Like

To me, at first glance, this sits somewhere between a “bug” and a “feature request”. Could you open an issue on the issue tracker, please?

Sure. I created GitHub issue #9901 in numba/numba: “First class function and structured arrays `NotImplementedError`”.

1 Like

Thanks for opening the issue! It should be discussed at the triage meeting today.

2 Likes