Hi!
How can I rewrite the indexing in func_jit
below, so that it compiles?
MRE
import numpy as np
from numba import jit, f8
n = 3
a, b = np.arange(n**2).reshape(n,n).astype(float), np.arange(n**2).reshape(n,n).astype(float)*10
@jit(f8[:,:](f8[:,:], f8[:,:]) , nopython=True, parallel=True)
def func_jit( a, b ):
dist = a - b
condition = dist < 0.0
dist[condition] = dist[condition] + 100.0
# .... rest of function
return dist
---------------------------------------------------------------------------
TypingError Traceback (most recent call last)
Cell In[235], line 19
13 # .... rest of function
14 return dist
18 @jit(f8[:,:](f8[:,:], f8[:,:]) , nopython=True, parallel=True)
---> 19 def func_jit( a, b ):
21 dist = a - b
22 condition = dist < 0.0
File /srv/conda/envs/notebook/lib/python3.10/site-packages/numba/core/decorators.py:219, in _jit.<locals>.wrapper(func)
217 with typeinfer.register_dispatcher(disp):
218 for sig in sigs:
--> 219 disp.compile(sig)
220 disp.disable_compile()
221 return disp
File /srv/conda/envs/notebook/lib/python3.10/site-packages/numba/core/dispatcher.py:965, in Dispatcher.compile(self, sig)
963 with ev.trigger_event("numba:compile", data=ev_details):
964 try:
--> 965 cres = self._compiler.compile(args, return_type)
966 except errors.ForceLiteralArg as e:
967 def folded(args, kws):
File /srv/conda/envs/notebook/lib/python3.10/site-packages/numba/core/dispatcher.py:129, in _FunctionCompiler.compile(self, args, return_type)
127 return retval
128 else:
--> 129 raise retval
File /srv/conda/envs/notebook/lib/python3.10/site-packages/numba/core/dispatcher.py:139, in _FunctionCompiler._compile_cached(self, args, return_type)
136 pass
138 try:
--> 139 retval = self._compile_core(args, return_type)
140 except errors.TypingError as e:
141 self._failed_cache[key] = e
File /srv/conda/envs/notebook/lib/python3.10/site-packages/numba/core/dispatcher.py:152, in _FunctionCompiler._compile_core(self, args, return_type)
149 flags = self._customize_flags(flags)
151 impl = self._get_implementation(args, {})
--> 152 cres = compiler.compile_extra(self.targetdescr.typing_context,
153 self.targetdescr.target_context,
154 impl,
155 args=args, return_type=return_type,
156 flags=flags, locals=self.locals,
157 pipeline_class=self.pipeline_class)
158 # Check typing error if object mode is used
159 if cres.typing_error is not None and not flags.enable_pyobject:
File /srv/conda/envs/notebook/lib/python3.10/site-packages/numba/core/compiler.py:716, in compile_extra(typingctx, targetctx, func, args, return_type, flags, locals, library, pipeline_class)
692 """Compiler entry point
693
694 Parameter
(...)
712 compiler pipeline
713 """
714 pipeline = pipeline_class(typingctx, targetctx, library,
715 args, return_type, flags, locals)
--> 716 return pipeline.compile_extra(func)
File /srv/conda/envs/notebook/lib/python3.10/site-packages/numba/core/compiler.py:452, in CompilerBase.compile_extra(self, func)
450 self.state.lifted = ()
451 self.state.lifted_from = None
--> 452 return self._compile_bytecode()
File /srv/conda/envs/notebook/lib/python3.10/site-packages/numba/core/compiler.py:520, in CompilerBase._compile_bytecode(self)
516 """
517 Populate and run pipeline for bytecode input
518 """
519 assert self.state.func_ir is None
--> 520 return self._compile_core()
File /srv/conda/envs/notebook/lib/python3.10/site-packages/numba/core/compiler.py:499, in CompilerBase._compile_core(self)
497 self.state.status.fail_reason = e
498 if is_final_pipeline:
--> 499 raise e
500 else:
501 raise CompilerError("All available pipelines exhausted")
File /srv/conda/envs/notebook/lib/python3.10/site-packages/numba/core/compiler.py:486, in CompilerBase._compile_core(self)
484 res = None
485 try:
--> 486 pm.run(self.state)
487 if self.state.cr is not None:
488 break
File /srv/conda/envs/notebook/lib/python3.10/site-packages/numba/core/compiler_machinery.py:368, in PassManager.run(self, state)
365 msg = "Failed in %s mode pipeline (step: %s)" % \
366 (self.pipeline_name, pass_desc)
367 patched_exception = self._patch_error(msg, e)
--> 368 raise patched_exception
File /srv/conda/envs/notebook/lib/python3.10/site-packages/numba/core/compiler_machinery.py:356, in PassManager.run(self, state)
354 pass_inst = _pass_registry.get(pss).pass_inst
355 if isinstance(pass_inst, CompilerPass):
--> 356 self._runPass(idx, pass_inst, state)
357 else:
358 raise BaseException("Legacy pass in use")
File /srv/conda/envs/notebook/lib/python3.10/site-packages/numba/core/compiler_lock.py:35, in _CompilerLock.__call__.<locals>._acquire_compile_lock(*args, **kwargs)
32 @functools.wraps(func)
33 def _acquire_compile_lock(*args, **kwargs):
34 with self:
---> 35 return func(*args, **kwargs)
File /srv/conda/envs/notebook/lib/python3.10/site-packages/numba/core/compiler_machinery.py:311, in PassManager._runPass(self, index, pss, internal_state)
309 mutated |= check(pss.run_initialization, internal_state)
310 with SimpleTimer() as pass_time:
--> 311 mutated |= check(pss.run_pass, internal_state)
312 with SimpleTimer() as finalize_time:
313 mutated |= check(pss.run_finalizer, internal_state)
File /srv/conda/envs/notebook/lib/python3.10/site-packages/numba/core/compiler_machinery.py:273, in PassManager._runPass.<locals>.check(func, compiler_state)
272 def check(func, compiler_state):
--> 273 mangled = func(compiler_state)
274 if mangled not in (True, False):
275 msg = ("CompilerPass implementations should return True/False. "
276 "CompilerPass with name '%s' did not.")
File /srv/conda/envs/notebook/lib/python3.10/site-packages/numba/core/typed_passes.py:105, in BaseTypeInference.run_pass(self, state)
99 """
100 Type inference and legalization
101 """
102 with fallback_context(state, 'Function "%s" failed type inference'
103 % (state.func_id.func_name,)):
104 # Type inference
--> 105 typemap, return_type, calltypes, errs = type_inference_stage(
106 state.typingctx,
107 state.targetctx,
108 state.func_ir,
109 state.args,
110 state.return_type,
111 state.locals,
112 raise_errors=self._raise_errors)
113 state.typemap = typemap
114 # save errors in case of partial typing
File /srv/conda/envs/notebook/lib/python3.10/site-packages/numba/core/typed_passes.py:83, in type_inference_stage(typingctx, targetctx, interp, args, return_type, locals, raise_errors)
81 infer.build_constraint()
82 # return errors in case of partial typing
---> 83 errs = infer.propagate(raise_errors=raise_errors)
84 typemap, restype, calltypes = infer.unify(raise_errors=raise_errors)
86 # Output all Numba warnings
File /srv/conda/envs/notebook/lib/python3.10/site-packages/numba/core/typeinfer.py:1086, in TypeInferer.propagate(self, raise_errors)
1083 force_lit_args = [e for e in errors
1084 if isinstance(e, ForceLiteralArg)]
1085 if not force_lit_args:
-> 1086 raise errors[0]
1087 else:
1088 raise reduce(operator.or_, force_lit_args)
TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<built-in function getitem>) found for signature:
>>> getitem(array(float64, 2d, C), array(bool, 2d, C))
There are 22 candidate implementations:
- Of which 20 did not match due to:
Overload of function 'getitem': File: <numerous>: Line N/A.
With argument(s): '(array(float64, 2d, C), array(bool, 2d, C))':
No match.
- Of which 2 did not match due to:
Overload in function 'GetItemBuffer.generic': File: numba/core/typing/arraydecl.py: Line 166.
With argument(s): '(array(float64, 2d, C), array(bool, 2d, C))':
Rejected as the implementation raised a specific error:
NumbaTypeError: unsupported array index type array(bool, 2d, C) in [array(bool, 2d, C)]
raised from /srv/conda/envs/notebook/lib/python3.10/site-packages/numba/core/typing/arraydecl.py:72
During: typing of intrinsic-call at /tmp/ipykernel_726/2647517645.py (23)
File "../../tmp/ipykernel_726/2647517645.py", line 23:
<source missing, REPL/exec in use?>
After reading here I think this happens because the index-array condition
is not 1d.
Trying numba for the first time on a numpy- & scipy-based calculation.