How can I rewrite the indexing so the function compiles?

Hi!

How can I rewrite the indexing in func_jit below, so that it compiles?

MRE
import numpy as np
from numba import jit, f8

n = 3
# a holds 0 .. n*n-1 as floats on an n-by-n grid; b is the same grid scaled by 10.
a = np.arange(n * n).reshape((n, n)).astype(float)
b = 10 * np.arange(n * n).reshape((n, n)).astype(float)

# An explicit signature is passed to @jit (two 2-D float64 arrays in, one 2-D
# float64 array out), so compilation happens when the decorator is applied.
@jit(f8[:,:](f8[:,:], f8[:,:]) , nopython=True, parallel=True)
def func_jit( a, b ):
    """Intended to return ``a - b`` with 100.0 added to every negative element."""

    dist = a - b
    # 2-D boolean mask marking the negative entries of dist.
    condition = dist < 0.0
    # Indexing a 2-D array with a 2-D boolean array is the getitem call that
    # nopython type inference rejects in the TypingError shown below.
    dist[condition] = dist[condition] + 100.0
    # .... rest of function
    return dist

---------------------------------------------------------------------------
TypingError                               Traceback (most recent call last)
Cell In[235], line 19
     13     # .... rest of function
     14     return dist
     18 @jit(f8[:,:](f8[:,:], f8[:,:]) , nopython=True, parallel=True)
---> 19 def func_jit( a, b ):
     21     dist = a - b
     22     condition = dist < 0.0

File /srv/conda/envs/notebook/lib/python3.10/site-packages/numba/core/decorators.py:219, in _jit.<locals>.wrapper(func)
    217     with typeinfer.register_dispatcher(disp):
    218         for sig in sigs:
--> 219             disp.compile(sig)
    220         disp.disable_compile()
    221 return disp

File /srv/conda/envs/notebook/lib/python3.10/site-packages/numba/core/dispatcher.py:965, in Dispatcher.compile(self, sig)
    963 with ev.trigger_event("numba:compile", data=ev_details):
    964     try:
--> 965         cres = self._compiler.compile(args, return_type)
    966     except errors.ForceLiteralArg as e:
    967         def folded(args, kws):

File /srv/conda/envs/notebook/lib/python3.10/site-packages/numba/core/dispatcher.py:129, in _FunctionCompiler.compile(self, args, return_type)
    127     return retval
    128 else:
--> 129     raise retval

File /srv/conda/envs/notebook/lib/python3.10/site-packages/numba/core/dispatcher.py:139, in _FunctionCompiler._compile_cached(self, args, return_type)
    136     pass
    138 try:
--> 139     retval = self._compile_core(args, return_type)
    140 except errors.TypingError as e:
    141     self._failed_cache[key] = e

File /srv/conda/envs/notebook/lib/python3.10/site-packages/numba/core/dispatcher.py:152, in _FunctionCompiler._compile_core(self, args, return_type)
    149 flags = self._customize_flags(flags)
    151 impl = self._get_implementation(args, {})
--> 152 cres = compiler.compile_extra(self.targetdescr.typing_context,
    153                               self.targetdescr.target_context,
    154                               impl,
    155                               args=args, return_type=return_type,
    156                               flags=flags, locals=self.locals,
    157                               pipeline_class=self.pipeline_class)
    158 # Check typing error if object mode is used
    159 if cres.typing_error is not None and not flags.enable_pyobject:

File /srv/conda/envs/notebook/lib/python3.10/site-packages/numba/core/compiler.py:716, in compile_extra(typingctx, targetctx, func, args, return_type, flags, locals, library, pipeline_class)
    692 """Compiler entry point
    693 
    694 Parameter
   (...)
    712     compiler pipeline
    713 """
    714 pipeline = pipeline_class(typingctx, targetctx, library,
    715                           args, return_type, flags, locals)
--> 716 return pipeline.compile_extra(func)

File /srv/conda/envs/notebook/lib/python3.10/site-packages/numba/core/compiler.py:452, in CompilerBase.compile_extra(self, func)
    450 self.state.lifted = ()
    451 self.state.lifted_from = None
--> 452 return self._compile_bytecode()

File /srv/conda/envs/notebook/lib/python3.10/site-packages/numba/core/compiler.py:520, in CompilerBase._compile_bytecode(self)
    516 """
    517 Populate and run pipeline for bytecode input
    518 """
    519 assert self.state.func_ir is None
--> 520 return self._compile_core()

File /srv/conda/envs/notebook/lib/python3.10/site-packages/numba/core/compiler.py:499, in CompilerBase._compile_core(self)
    497         self.state.status.fail_reason = e
    498         if is_final_pipeline:
--> 499             raise e
    500 else:
    501     raise CompilerError("All available pipelines exhausted")

File /srv/conda/envs/notebook/lib/python3.10/site-packages/numba/core/compiler.py:486, in CompilerBase._compile_core(self)
    484 res = None
    485 try:
--> 486     pm.run(self.state)
    487     if self.state.cr is not None:
    488         break

File /srv/conda/envs/notebook/lib/python3.10/site-packages/numba/core/compiler_machinery.py:368, in PassManager.run(self, state)
    365 msg = "Failed in %s mode pipeline (step: %s)" % \
    366     (self.pipeline_name, pass_desc)
    367 patched_exception = self._patch_error(msg, e)
--> 368 raise patched_exception

File /srv/conda/envs/notebook/lib/python3.10/site-packages/numba/core/compiler_machinery.py:356, in PassManager.run(self, state)
    354 pass_inst = _pass_registry.get(pss).pass_inst
    355 if isinstance(pass_inst, CompilerPass):
--> 356     self._runPass(idx, pass_inst, state)
    357 else:
    358     raise BaseException("Legacy pass in use")

File /srv/conda/envs/notebook/lib/python3.10/site-packages/numba/core/compiler_lock.py:35, in _CompilerLock.__call__.<locals>._acquire_compile_lock(*args, **kwargs)
     32 @functools.wraps(func)
     33 def _acquire_compile_lock(*args, **kwargs):
     34     with self:
---> 35         return func(*args, **kwargs)

File /srv/conda/envs/notebook/lib/python3.10/site-packages/numba/core/compiler_machinery.py:311, in PassManager._runPass(self, index, pss, internal_state)
    309     mutated |= check(pss.run_initialization, internal_state)
    310 with SimpleTimer() as pass_time:
--> 311     mutated |= check(pss.run_pass, internal_state)
    312 with SimpleTimer() as finalize_time:
    313     mutated |= check(pss.run_finalizer, internal_state)

File /srv/conda/envs/notebook/lib/python3.10/site-packages/numba/core/compiler_machinery.py:273, in PassManager._runPass.<locals>.check(func, compiler_state)
    272 def check(func, compiler_state):
--> 273     mangled = func(compiler_state)
    274     if mangled not in (True, False):
    275         msg = ("CompilerPass implementations should return True/False. "
    276                "CompilerPass with name '%s' did not.")

File /srv/conda/envs/notebook/lib/python3.10/site-packages/numba/core/typed_passes.py:105, in BaseTypeInference.run_pass(self, state)
     99 """
    100 Type inference and legalization
    101 """
    102 with fallback_context(state, 'Function "%s" failed type inference'
    103                       % (state.func_id.func_name,)):
    104     # Type inference
--> 105     typemap, return_type, calltypes, errs = type_inference_stage(
    106         state.typingctx,
    107         state.targetctx,
    108         state.func_ir,
    109         state.args,
    110         state.return_type,
    111         state.locals,
    112         raise_errors=self._raise_errors)
    113     state.typemap = typemap
    114     # save errors in case of partial typing

File /srv/conda/envs/notebook/lib/python3.10/site-packages/numba/core/typed_passes.py:83, in type_inference_stage(typingctx, targetctx, interp, args, return_type, locals, raise_errors)
     81     infer.build_constraint()
     82     # return errors in case of partial typing
---> 83     errs = infer.propagate(raise_errors=raise_errors)
     84     typemap, restype, calltypes = infer.unify(raise_errors=raise_errors)
     86 # Output all Numba warnings

File /srv/conda/envs/notebook/lib/python3.10/site-packages/numba/core/typeinfer.py:1086, in TypeInferer.propagate(self, raise_errors)
   1083 force_lit_args = [e for e in errors
   1084                   if isinstance(e, ForceLiteralArg)]
   1085 if not force_lit_args:
-> 1086     raise errors[0]
   1087 else:
   1088     raise reduce(operator.or_, force_lit_args)

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<built-in function getitem>) found for signature:
 
 >>> getitem(array(float64, 2d, C), array(bool, 2d, C))
 
There are 22 candidate implementations:
      - Of which 20 did not match due to:
      Overload of function 'getitem': File: <numerous>: Line N/A.
        With argument(s): '(array(float64, 2d, C), array(bool, 2d, C))':
       No match.
      - Of which 2 did not match due to:
      Overload in function 'GetItemBuffer.generic': File: numba/core/typing/arraydecl.py: Line 166.
        With argument(s): '(array(float64, 2d, C), array(bool, 2d, C))':
       Rejected as the implementation raised a specific error:
         NumbaTypeError: unsupported array index type array(bool, 2d, C) in [array(bool, 2d, C)]
  raised from /srv/conda/envs/notebook/lib/python3.10/site-packages/numba/core/typing/arraydecl.py:72

During: typing of intrinsic-call at /tmp/ipykernel_726/2647517645.py (23)

File "../../tmp/ipykernel_726/2647517645.py", line 23:
<source missing, REPL/exec in use?>

After reading here, I think this happens because the boolean index array `condition` is not 1-D.

I am trying Numba for the first time on a NumPy- and SciPy-based calculation.

Hi @ofk123

I think you are right about the origin of the error. Fortunately, there are many ways to make this work. It is just a compromise between readability, brevity and efficiency.
For example, you can always write it with explicit loops. This is also more efficient in your case, because you only have to go through the arrays once.

Here are two ways to implement this. I hope this helps.
import numpy as np
import numba as nb

# Benchmark inputs: two n-by-n float64 matrices, b being ten times a.
n = 1_000
a = np.arange(n * n, dtype=float).reshape((n, n))
b = 10 * np.arange(n * n, dtype=float).reshape((n, n))

# @nb.njit
def func_jit1(a, b):
    """Pure-NumPy reference: ``a - b`` with 100.0 added wherever it is negative.

    The inputs are left untouched; a newly allocated array is returned.
    """
    diff = a - b
    negative = diff < 0.0
    # Plain assignment (not ``+=``) so the result is cast into diff's dtype
    # exactly as in a masked store.
    diff[negative] = 100.0 + diff[negative]
    return diff

# parallel=True is required for nb.prange to actually run in parallel; with
# parallel=False it silently degrades to a plain range.
@nb.njit(parallel=True)
def func_jit2(a, b):
    """Return ``a - b`` with 100 added to every negative element (explicit loops).

    ``a`` and ``b`` are 2-D arrays of the same shape. The result is a new
    array shaped like ``a``; the inputs are not modified.
    """
    dist = np.empty_like(a)
    # Only the outer loop is a prange: each row is written by exactly one
    # iteration, so rows can be processed independently.
    for i in nb.prange(a.shape[0]):
        for j in range(a.shape[1]):
            # Work on a local scalar and store once instead of re-reading
            # dist[i, j] after writing it.
            d = a[i, j] - b[i, j]
            if d < 0:
                d += 100
            dist[i, j] = d
    return dist

# parallel=True is required for nb.prange to actually run in parallel; with
# parallel=False it silently degrades to a plain range.
@nb.njit(parallel=True)
def func_jit3(a, b):
    """Return ``a - b`` with 100 added to every negative element (row-wise).

    ``a`` and ``b`` are 2-D arrays of the same shape. The result is a new
    array shaped like ``a``; the inputs are not modified.
    """
    dist = np.empty_like(a)
    for i in nb.prange(a.shape[0]):
        # Compute the row difference once and reuse it for both the value and
        # the negative-entry mask.
        row = a[i] - b[i]
        # The boolean mask times 100 yields 100 where row is negative, else 0.
        dist[i] = row + (row < 0) * 100
    return dist

# Assert equivalence with the NumPy reference: a bare expression is only
# echoed when it is the last one in a notebook cell, so a mismatch in the
# first check would otherwise go unnoticed.
assert np.allclose(func_jit1(a, b), func_jit2(a, b))
assert np.allclose(func_jit1(a, b), func_jit3(a, b))

%timeit func_jit1(a, b)
%timeit func_jit2(a, b)
%timeit func_jit3(a, b)
1 Like

Thanks a lot, @sschaer. This works. On my end, func_jit2 seems to be the fastest.

Can I ask a follow-up question: why do you use nb.prange on only one of the loops?