You can do something like this (not thoroughly tested, so it may have issues!):
import numpy as np
from numba import jit, njit
from numba.core import extending, compiler, types
from numba.core.datamodel.models import ArrayModel
from numba.core.dispatcher import Dispatcher
from numba.core.registry import dispatcher_registry, cpu_target
from numba.np import numpy_support
# This is a wrapper class used to tell a shape-specialised argument apart from
# a plain np.ndarray; it simply holds a reference to the array argument.
class ShapeArray(object):
    """Marker wrapper around a NumPy array.

    Numba types instances of this class as ``ShapeArrayType`` (see the
    ``typeof_impl`` registration), a type that carries the wrapped array's
    shape.

    Attributes:
        arr: the wrapped ``np.ndarray``.
    """

    def __init__(self, arr):
        """Wrap *arr*.

        Raises:
            TypeError: if *arr* is not an ``np.ndarray``.
        """
        # Explicit check instead of `assert`: asserts are stripped under
        # `python -O`. A non-array would then get through and only fail
        # later, far from the cause, when ShapeArrayType reads its dtype/flags.
        if not isinstance(arr, np.ndarray):
            raise TypeError(
                "ShapeArray expects a numpy.ndarray, got %s"
                % type(arr).__name__
            )
        self.arr = arr
# Numba type for ShapeArray. It subclasses types.Array so that everything that
# already works for Array keeps working; the only real difference is that it
# also records the *shape* of the wrapped array in its `.shape` attribute.
class ShapeArrayType(types.Array):
    """``types.Array`` variant whose identity includes the concrete shape.

    Built from a ``ShapeArray``; its dtype, ndim, layout and readonly-ness are
    derived from the wrapped array. Its ``name`` is prefixed with the shape,
    and its ``key`` includes it.
    """

    def __init__(self, sarr):
        wrapped = sarr.arr
        # Record the shape before the base initialiser runs; `key` uses it.
        self.shape = wrapped.shape
        try:
            elem_type = numpy_support.from_dtype(wrapped.dtype)
        except NotImplementedError:
            raise ValueError("Unsupported array dtype: %s" % (wrapped.dtype,))
        super().__init__(
            elem_type,
            wrapped.ndim,
            numpy_support.map_layout(wrapped),
            readonly=not wrapped.flags.writeable,
        )
        # The base class has just set `self.name`; decorate it with the shape.
        self.name = 'ShapeArray[%s](%s)' % (str(self.shape), self.name)

    @property
    def key(self):
        # Put the shape in the key: arrays of different shapes then yield
        # distinct types, which is what lets dispatch specialise on shape.
        base_key = super().key
        return (base_key, self.shape)
# Register ShapeArray with Numba's `typeof` so its instances are typed as
# ShapeArrayType.
@extending.typeof_impl.register(ShapeArray)
def typeof_shapearray(val, c):
    """Return the ``ShapeArrayType`` describing the ShapeArray *val*.

    ``c`` is passed in by Numba's typeof dispatch and is unused here.
    """
    sa_type = ShapeArrayType(val)
    return sa_type
# Tell the backend to use the same data model as np.ndarray (ArrayModel) for
# ShapeArrayType. A ShapeArrayType value is therefore laid out exactly like a
# regular array; the shape is only held on the type object itself.
extending.register_model(ShapeArrayType)(ArrayModel)
# This is the most important part: a custom dispatcher. It is the standard CPU
# dispatcher, except that it overrides `_compile_for_args` and there wraps
# every positional `np.ndarray` argument in a ShapeArray, so that typing (and
# hence compilation) can specialise on the array's shape.
class ShapeSpecialiseDispatcher(Dispatcher):
    """CPU dispatcher that types positional ndarray arguments as ShapeArray.

    ShapeArray values are typed as ``ShapeArrayType``, whose key includes the
    shape, so each distinct shape gets its own compiled version.
    """

    targetdescr = cpu_target

    def __init__(self, py_func, locals=None, targetoptions=None,
                 impl_kind='direct', pipeline_class=compiler.Compiler):
        # Create fresh dicts per instance. `{}` defaults would be evaluated
        # once, at definition time, and shared by every dispatcher built
        # without explicit options.
        if locals is None:
            locals = {}
        if targetoptions is None:
            targetoptions = {}
        super().__init__(py_func,
                         locals=locals,
                         targetoptions=targetoptions,
                         impl_kind=impl_kind,
                         pipeline_class=pipeline_class)

    def _compile_for_args(self, *args, **kws):
        """Wrap ndarray positional args in ShapeArray, then compile as usual.

        Keyword arguments are forwarded unchanged.
        """
        # This is the important bit: ShapeArray is typed with its shape.
        nargs = [ShapeArray(x) if isinstance(x, np.ndarray) else x
                 for x in args]
        return super()._compile_for_args(*nargs, **kws)
# Tell the dispatcher registry about this shape-specialising dispatcher.
# 'shape_specialiser' is the `_target` name that shape_specialising_jit passes
# to numba.jit to select this dispatcher.
# NOTE(review): this relies on Numba internals (numba.core.registry); confirm
# it still applies to the Numba version in use.
dispatcher_registry['shape_specialiser'] = ShapeSpecialiseDispatcher
# A jit decorator: like numba.jit, but it also accepts the keyword argument
# "shape_specialise" and picks the `_target` passed to jit from it (this
# decides which dispatcher is used).
def shape_specialising_jit(*args, shape_specialise=False, **kwargs):
    """Drop-in for ``numba.jit`` with an extra ``shape_specialise`` keyword.

    A truthy ``shape_specialise`` selects the 'shape_specialiser' target
    (registered in ``dispatcher_registry``); otherwise the 'cpu' target is
    used. All other arguments are forwarded to ``jit`` unchanged.
    """
    if shape_specialise:
        target = 'shape_specialiser'
    else:
        target = 'cpu'
    return jit(*args, _target=target, **kwargs)
# Example of use:
def bar(x):
    """Pure-Python stub; its jitted behaviour is provided by ``ol_bar``.

    Called from plain Python it does nothing and returns None.
    """
    return None
# This demos the compile-time constant shape (or its absence), depending on
# the shape_specialise option.
@extending.overload(bar)
def ol_bar(x):
    """Typing-time implementation of ``bar``.

    Here ``x`` is the Numba type of the argument; it is printed, and so is its
    ``shape`` attribute if it has one. The shape string is captured as a
    constant in the returned implementation.
    """
    print("Demo specialised", x)
    compile_time_shape = str(getattr(x, 'shape', 'Not present'))
    print("Compile time shape: %s" % compile_time_shape)

    # The implementation closes over the string computed above.
    def _bar_impl(x):
        return "Runtime const str shape: " + compile_time_shape

    return _bar_impl
# A function compiled with the standard njit. It also works when called from
# shape-specialised code with a ShapeArrayType argument.
@njit
def baz(x):
    """Return the sum of all elements of ``x``."""
    total = x.sum()
    return total
# Demo: compile and run the same function without and then with shape
# specialisation, and dump the typed IR each time.
for specialise in [False, True]:
    print(("specialise=%s" % specialise).center(80, '-'))

    # `foo` is redefined on each pass so it gets a fresh dispatcher built
    # with the current `shape_specialise` setting.
    @shape_specialising_jit(shape_specialise=specialise)
    def foo(x):
        a = len(x)
        b = bar(x)
        c = baz(x)
        return a, b, c

    print(foo(np.ones(shape=(5, 4, 3, 2, 1))))
    print()
    foo.inspect_types()
Running it gives me this:
--------------------------------specialise=False--------------------------------
Demo specialised array(float64, 5d, C)
Compile time shape: Not present
(5, 'Runtime const str shape: Not present', 120.0)
foo (array(float64, 5d, C),)
--------------------------------------------------------------------------------
# File: di6.py
# --- LINE 107 ---
@shape_specialising_jit(shape_specialise=specialise)
# --- LINE 108 ---
def foo(x):
# --- LINE 109 ---
# label 0
# x = arg(0, name=x) :: array(float64, 5d, C)
# $2load_global.0 = global(len: <built-in function len>) :: Function(<built-in function len>)
# $6call_function.2 = call $2load_global.0(x, func=$2load_global.0, args=[Var(x, di6.py:109)], kws=(), vararg=None) :: (array(float64, 5d, C),) -> int64
# del $2load_global.0
# a = $6call_function.2 :: int64
# del $6call_function.2
a = len(x)
# --- LINE 110 ---
# $10load_global.3 = global(bar: <function bar at 0x7f940dbced40>) :: Function(<function bar at 0x7f940dbced40>)
# $14call_function.5 = call $10load_global.3(x, func=$10load_global.3, args=[Var(x, di6.py:109)], kws=(), vararg=None) :: (array(float64, 5d, C),) -> unicode_type
# del $10load_global.3
# b = $14call_function.5 :: unicode_type
# del $14call_function.5
b = bar(x)
# --- LINE 111 ---
# $18load_global.6 = global(baz: CPUDispatcher(<function baz at 0x7f940db44200>)) :: type(CPUDispatcher(<function baz at 0x7f940db44200>))
# $22call_function.8 = call $18load_global.6(x, func=$18load_global.6, args=[Var(x, di6.py:109)], kws=(), vararg=None) :: (array(float64, 5d, C),) -> float64
# del x
# del $18load_global.6
# c = $22call_function.8 :: float64
# del $22call_function.8
c = baz(x)
# --- LINE 112 ---
# $32build_tuple.12 = build_tuple(items=[Var(a, di6.py:109), Var(b, di6.py:110), Var(c, di6.py:111)]) :: Tuple(int64, unicode_type, float64)
# del c
# del b
# del a
# $34return_value.13 = cast(value=$32build_tuple.12) :: Tuple(int64, unicode_type, float64)
# del $32build_tuple.12
# return $34return_value.13
return a, b, c
================================================================================
--------------------------------specialise=True---------------------------------
Demo specialised ShapeArray[(5, 4, 3, 2, 1)](array(float64, 5d, C))
Compile time shape: (5, 4, 3, 2, 1)
(5, 'Runtime const str shape: (5, 4, 3, 2, 1)', 120.0)
foo (ShapeArray[(5, 4, 3, 2, 1)](array(float64, 5d, C)),)
--------------------------------------------------------------------------------
# File: di6.py
# --- LINE 107 ---
@shape_specialising_jit(shape_specialise=specialise)
# --- LINE 108 ---
def foo(x):
# --- LINE 109 ---
# label 0
# x = arg(0, name=x) :: ShapeArray[(5, 4, 3, 2, 1)](array(float64, 5d, C))
# $2load_global.0 = global(len: <built-in function len>) :: Function(<built-in function len>)
# $6call_function.2 = call $2load_global.0(x, func=$2load_global.0, args=[Var(x, di6.py:109)], kws=(), vararg=None) :: (ShapeArray[(5, 4, 3, 2, 1)](array(float64, 5d, C)),) -> int64
# del $2load_global.0
# a = $6call_function.2 :: int64
# del $6call_function.2
a = len(x)
# --- LINE 110 ---
# $10load_global.3 = global(bar: <function bar at 0x7f940dbced40>) :: Function(<function bar at 0x7f940dbced40>)
# $14call_function.5 = call $10load_global.3(x, func=$10load_global.3, args=[Var(x, di6.py:109)], kws=(), vararg=None) :: (ShapeArray[(5, 4, 3, 2, 1)](array(float64, 5d, C)),) -> unicode_type
# del $10load_global.3
# b = $14call_function.5 :: unicode_type
# del $14call_function.5
b = bar(x)
# --- LINE 111 ---
# $18load_global.6 = global(baz: CPUDispatcher(<function baz at 0x7f940db44200>)) :: type(CPUDispatcher(<function baz at 0x7f940db44200>))
# $22call_function.8 = call $18load_global.6(x, func=$18load_global.6, args=[Var(x, di6.py:109)], kws=(), vararg=None) :: (ShapeArray[(5, 4, 3, 2, 1)](array(float64, 5d, C)),) -> float64
# del x
# del $18load_global.6
# c = $22call_function.8 :: float64
# del $22call_function.8
c = baz(x)
# --- LINE 112 ---
# $32build_tuple.12 = build_tuple(items=[Var(a, di6.py:109), Var(b, di6.py:110), Var(c, di6.py:111)]) :: Tuple(int64, unicode_type, float64)
# del c
# del b
# del a
# $34return_value.13 = cast(value=$32build_tuple.12) :: Tuple(int64, unicode_type, float64)
# del $32build_tuple.12
# return $34return_value.13
return a, b, c
================================================================================