Runtime specialization

I am experimenting with ways to include runtime information in function specialization.

One archetypal example is the shape of an array: it is not included in types.Array (for good reasons!), but opting in to specialize on part of the shape can provide a large performance increase. For example, LLVM can reason about and vectorize inner loops when the number of iterations is a compile-time constant, and some algorithms can be optimized for specific shapes such as 2x2, 3x3 or 4x4 matrices.

Consider these two examples (what the loop_array_* functions do is not important; they only provide a benchmark):

import numpy as np
import numba as nb
from numba import types

# A minimal parametric type: its only payload is a number of dimensions; the
# key property makes instances with different values distinct, so compiled
# functions get specialized per value.
class DimensionalType(types.Type):

    def __init__(self, num_dimensions):
        self.num_dimensions = num_dimensions
        super().__init__(name='Dimension[%dD]'%self.num_dimensions)

    @property
    def key(self):
        return self.num_dimensions

@nb.njit(fastmath=True)
def loop_array_base(arr):
    _t = arr.dtype.type
    N, M = arr.shape
    S = _t(0.0)
    for i in range(N):
        s = _t(0.0)
        for j in range(M):
            el = arr[i, j]
            s += el*el
        if (i%2 == 0):
            S += s
        else:
            S -= s
    return np.sqrt(np.abs(S))


# generated_jit lets us read num_dimensions off the *type* of the argument at
# compile time and return it as a constant.
@nb.generated_jit(nopython=True)
def get_ndims(x):
    num_dimensions = x.instance_type.num_dimensions
    return lambda x: num_dimensions

@nb.njit(fastmath=True, inline="always")
def _loop_array(tp, arr):
    _t = arr.dtype.type
    N = len(arr)
    M = get_ndims(tp)
    S = _t(0.0)
    for i in range(N):
        s = _t(0.0)
        for j in range(M):
            el = arr[i, j]
            s += el*el
        if (i%2 == 0):
            S += s
        else:
            S -= s
    return np.sqrt(np.abs(S))

mytp = DimensionalType(3)
@nb.njit(fastmath=True)
def loop_array_spec(arr):
    tp = mytp
    return _loop_array(tp, arr)

And speed tests:

a = np.random.rand(50000, 3)

loop_array_base(a)
loop_array_spec(a)
%timeit loop_array_base(a)
%timeit loop_array_spec(a)
164 µs ± 61.9 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
46.9 µs ± 123 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

As you can see, LLVM makes much better use of SIMD instructions when it knows about (and specializes over) the inner shape of the array.
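
For what it's worth, one quick way to check where the speedup comes from is to look at the generated machine code. A rough, x86-specific sanity check (packed double-precision multiplies show up as mulpd/vmulpd, scalar ones as mulsd):

for f in (loop_array_base, loop_array_spec):
    for sig, asm in f.inspect_asm().items():
        print(f.py_func.__name__, sig, "has packed mulpd:", "mulpd" in asm)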

The problem is that the current approach requires DimensionalType to be parameterized in Python, outside the jitted code. What I would love to write instead is:

@nb.njit(fastmath=True)
def loop_array_spec(arr):
    tp = DimensionalType(arr.shape[-1])
    return _loop_array(tp, arr)

I know I can probably wrap my array in a specialized type instead, but then I would have to create full types with boxing/unboxing and all the bookkeeping and boilerplate for every single case that requires it (specializing on shape is not the only thing one might want to do).

My ultimate objective is to come up with a specialize decorator that would force specialization based on specific arguments (or attributes of them) in a general way, without having to alter the type system.

Hi @gdonval, interesting question. I won't claim to have an answer, but I'll offer some ideas, hoping to learn from others' replies.

My first thought is that if you want compile-time specialization, you need to provide that information at compile time. If you don't tell the compiler the shape of arr, you can't expect it to optimize based on that unknown number.

A few ideas to get you closer to what you want:

  • I’m curious why you need loop_array_spec to be jitted. Your real example is probably more complex, but as written, loop_array_spec itself does not need to be jitted.

  • I’m happy to treat pure Python as a kind of macro system for jitted code, and I usually write things like:

    from functools import lru_cache

    @lru_cache()
    def make_loop_array(dims):
        # the type is parameterized here, in plain Python, and captured as a
        # compile-time constant by the jitted closure below
        tp = DimensionalType(dims)

        @nb.njit(fastmath=True)
        def loop_array_spec(arr):
            return _loop_array(tp, arr)

        return loop_array_spec
    
    

    The downside is that you need to track all your versions manually.

  • Similar to the above, there might be a way to use literal dispatching to do the tracking for you. I’m speculating here (I’m not sure exactly how to do it), but something like:

    
    @nb.generated_jit(nopython=True, fastmath=True)
    def loop_array_spec(arr, shape):
        # shape has to arrive here as a Literal type for this to work
        dim = shape.literal_value
        # the type is built here, in plain Python, at compile time
        tp = DimensionalType(dim)

        def impl(arr, shape):
            return _loop_array(tp, arr)

        return impl
    
    
  • I’m guessing you don’t have a huge number of possible array sizes, because if you did, specialization would spend a lot of time in compilation and possibly erase the benefit of the faster runtime. So another idea is:

    # the type instances are still created in plain Python, outside the jitted code
    tp1 = DimensionalType(1)
    tp2 = DimensionalType(2)

    @nb.njit(fastmath=True)
    def loop_array_spec(arr):
        if arr.shape[-1] == 1:
            return _loop_array(tp1, arr)
        elif arr.shape[-1] == 2:
            return _loop_array(tp2, arr)
        else:
            raise ValueError("unsupported inner dimension")
    

    Note that if you don’t want to write the above by hand (maybe because the exact values change over time), you can generate the function programmatically as text and then exec it; a rough sketch follows below. Just be aware that large generated functions can take a long time to compile (Tips or tricks for speeding up compilation time on first call of large Numba-jitted NumPy-containing functions?).
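
    For concreteness, an untested sketch of that text-generation approach, reusing DimensionalType and _loop_array from above (SUPPORTED_DIMS and the _tp_<d> globals are names made up for this sketch):

    SUPPORTED_DIMS = (1, 2, 3, 4)

    # build the if/elif dispatcher above as source text, one branch per size;
    # the type instances themselves are still created in plain Python
    ns = {"nb": nb, "_loop_array": _loop_array}
    lines = ["@nb.njit(fastmath=True)", "def loop_array_spec(arr):"]
    for i, d in enumerate(SUPPORTED_DIMS):
        ns["_tp_%d" % d] = DimensionalType(d)
        kw = "if" if i == 0 else "elif"
        lines.append("    %s arr.shape[-1] == %d:" % (kw, d))
        lines.append("        return _loop_array(_tp_%d, arr)" % d)
    lines.append("    raise ValueError('unsupported inner dimension')")

    exec("\n".join(lines), ns)
    loop_array_spec = ns["loop_array_spec"]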

I hope this helps; I’m curious to see what other people come up with.

Luk

You can do something like this (untested, may have issues!):

import numpy as np
from numba import jit, njit
from numba.core import extending, compiler, types
from numba.core.datamodel.models import ArrayModel
from numba.core.dispatcher import Dispatcher
from numba.core.registry import dispatcher_registry, cpu_target
from numba.np import numpy_support

# this is a wrapper class to differentiate from np.ndarray; it simply holds
# a reference to the array arg
class ShapeArray(object):
    def __init__(self, arr):
        assert isinstance(arr, np.ndarray)
        self.arr = arr

# This is the type for ShapeArray. It's basically the same as types.Array and
# inherits from it so as to get access to everything that already works for
# Array; the only real difference is that this type holds the *shape* of the
# array it's wrapping in the `.shape` attr
class ShapeArrayType(types.Array):
    def __init__(self, sarr):
        val = sarr.arr
        self.shape = val.shape
        try:
            dtype = numpy_support.from_dtype(val.dtype)
        except NotImplementedError:
            raise ValueError("Unsupported array dtype: %s" % (val.dtype,))
        layout = numpy_support.map_layout(val)
        readonly = not val.flags.writeable
        super().__init__(dtype, val.ndim, layout, readonly=readonly)
        nm = self.name
        self.name = 'ShapeArray[%s](%s)' % (str(self.shape), nm)

    @property
    def key(self):
        # use the shape in the key, want to dispatch based on it!
        return super().key, self.shape

# register ShapeArray with `typeof`
@extending.typeof_impl.register(ShapeArray)
def typeof_shapearray(val, c):
    return ShapeArrayType(val)

# tell the backend to use the same datamodel as np.ndarray for ShapeArray
extending.register_model(ShapeArrayType)(ArrayModel)

# This is the most important part, creating a custom dispatcher. This dispatcher
# is the same as the standard cpu dispatcher, but it overrides
# `_compile_for_args` and in there wraps any `ndarray` argument instance in a
# ShapeArray class such that specialisation is possible.
class ShapeSpecialiseDispatcher(Dispatcher):
    targetdescr = cpu_target
    def __init__(self, py_func, locals={}, targetoptions={},
                 impl_kind='direct', pipeline_class=compiler.Compiler):
        super(ShapeSpecialiseDispatcher, self).__init__(py_func,
                                           locals=locals,
                                           targetoptions=targetoptions,
                                           impl_kind=impl_kind,
                                           pipeline_class=pipeline_class)

    def _compile_for_args(self, *args, **kws):
        # this is the important bit, wrap numpy arrays as ShapeArray
        nargs = []
        for x in args:
            if isinstance(x, np.ndarray):
                nargs.append(ShapeArray(x))
            else:
                nargs.append(x)

        return Dispatcher._compile_for_args(self, *nargs, **kws)

# tell the dispatcher registry about this shape specialising dispatcher
dispatcher_registry['shape_specialiser'] = ShapeSpecialiseDispatcher

# create a jit decorator, it's like numba.jit but takes the kwarg
# "shape_specialise" and sets the `_target` based on that (this influences
# which dispatcher to use).
def shape_specialising_jit(*args, **kwargs):
    _specialise = kwargs.pop('shape_specialise', False)
    _target = 'shape_specialiser' if _specialise else 'cpu'
    return jit(*args, _target=_target, **kwargs)

# Example of use:

def bar(x):
    pass

# this demos the compile-time constant shape (or not!) depending on the
# shape_specialise option
@extending.overload(bar)
def ol_bar(x):
    print("Demo specialised", x)
    shape_str = str(getattr(x, 'shape', 'Not present'))
    print("Compile time shape: %s" % shape_str)
    def impl(x):
        return "Runtime const str shape: " + shape_str
    return impl

# a function using standard JIT, works fine
@njit
def baz(x):
    return x.sum()

# demo...
for specialise in (False, True):
    print(("specialise=%s" % specialise).center(80, '-'))
    @shape_specialising_jit(shape_specialise=specialise)
    def foo(x):
        a = len(x)
        b = bar(x)
        c = baz(x)
        return a, b, c

    print(foo(np.ones((5, 4, 3, 2, 1))))
    print("")
    foo.inspect_types()

gives me this:

--------------------------------specialise=False--------------------------------
Demo specialised array(float64, 5d, C)
Compile time shape: Not present
(5, 'Runtime const str shape: Not present', 120.0)

foo (array(float64, 5d, C),)
--------------------------------------------------------------------------------
# File: di6.py
# --- LINE 107 --- 

@shape_specialising_jit(shape_specialise=specialise)

# --- LINE 108 --- 

def foo(x):

    # --- LINE 109 --- 
    # label 0
    #   x = arg(0, name=x)  :: array(float64, 5d, C)
    #   $2load_global.0 = global(len: <built-in function len>)  :: Function(<built-in function len>)
    #   $6call_function.2 = call $2load_global.0(x, func=$2load_global.0, args=[Var(x, di6.py:109)], kws=(), vararg=None)  :: (array(float64, 5d, C),) -> int64
    #   del $2load_global.0
    #   a = $6call_function.2  :: int64
    #   del $6call_function.2

    a = len(x)

    # --- LINE 110 --- 
    #   $10load_global.3 = global(bar: <function bar at 0x7f940dbced40>)  :: Function(<function bar at 0x7f940dbced40>)
    #   $14call_function.5 = call $10load_global.3(x, func=$10load_global.3, args=[Var(x, di6.py:109)], kws=(), vararg=None)  :: (array(float64, 5d, C),) -> unicode_type
    #   del $10load_global.3
    #   b = $14call_function.5  :: unicode_type
    #   del $14call_function.5

    b = bar(x)

    # --- LINE 111 --- 
    #   $18load_global.6 = global(baz: CPUDispatcher(<function baz at 0x7f940db44200>))  :: type(CPUDispatcher(<function baz at 0x7f940db44200>))
    #   $22call_function.8 = call $18load_global.6(x, func=$18load_global.6, args=[Var(x, di6.py:109)], kws=(), vararg=None)  :: (array(float64, 5d, C),) -> float64
    #   del x
    #   del $18load_global.6
    #   c = $22call_function.8  :: float64
    #   del $22call_function.8

    c = baz(x)

    # --- LINE 112 --- 
    #   $32build_tuple.12 = build_tuple(items=[Var(a, di6.py:109), Var(b, di6.py:110), Var(c, di6.py:111)])  :: Tuple(int64, unicode_type, float64)
    #   del c
    #   del b
    #   del a
    #   $34return_value.13 = cast(value=$32build_tuple.12)  :: Tuple(int64, unicode_type, float64)
    #   del $32build_tuple.12
    #   return $34return_value.13

    return a, b, c


================================================================================
--------------------------------specialise=True---------------------------------
Demo specialised ShapeArray[(5, 4, 3, 2, 1)](array(float64, 5d, C))
Compile time shape: (5, 4, 3, 2, 1)
(5, 'Runtime const str shape: (5, 4, 3, 2, 1)', 120.0)

foo (ShapeArray[(5, 4, 3, 2, 1)](array(float64, 5d, C)),)
--------------------------------------------------------------------------------
# File: di6.py
# --- LINE 107 --- 

@shape_specialising_jit(shape_specialise=specialise)

# --- LINE 108 --- 

def foo(x):

    # --- LINE 109 --- 
    # label 0
    #   x = arg(0, name=x)  :: ShapeArray[(5, 4, 3, 2, 1)](array(float64, 5d, C))
    #   $2load_global.0 = global(len: <built-in function len>)  :: Function(<built-in function len>)
    #   $6call_function.2 = call $2load_global.0(x, func=$2load_global.0, args=[Var(x, di6.py:109)], kws=(), vararg=None)  :: (ShapeArray[(5, 4, 3, 2, 1)](array(float64, 5d, C)),) -> int64
    #   del $2load_global.0
    #   a = $6call_function.2  :: int64
    #   del $6call_function.2

    a = len(x)

    # --- LINE 110 --- 
    #   $10load_global.3 = global(bar: <function bar at 0x7f940dbced40>)  :: Function(<function bar at 0x7f940dbced40>)
    #   $14call_function.5 = call $10load_global.3(x, func=$10load_global.3, args=[Var(x, di6.py:109)], kws=(), vararg=None)  :: (ShapeArray[(5, 4, 3, 2, 1)](array(float64, 5d, C)),) -> unicode_type
    #   del $10load_global.3
    #   b = $14call_function.5  :: unicode_type
    #   del $14call_function.5

    b = bar(x)

    # --- LINE 111 --- 
    #   $18load_global.6 = global(baz: CPUDispatcher(<function baz at 0x7f940db44200>))  :: type(CPUDispatcher(<function baz at 0x7f940db44200>))
    #   $22call_function.8 = call $18load_global.6(x, func=$18load_global.6, args=[Var(x, di6.py:109)], kws=(), vararg=None)  :: (ShapeArray[(5, 4, 3, 2, 1)](array(float64, 5d, C)),) -> float64
    #   del x
    #   del $18load_global.6
    #   c = $22call_function.8  :: float64
    #   del $22call_function.8

    c = baz(x)

    # --- LINE 112 --- 
    #   $32build_tuple.12 = build_tuple(items=[Var(a, di6.py:109), Var(b, di6.py:110), Var(c, di6.py:111)])  :: Tuple(int64, unicode_type, float64)
    #   del c
    #   del b
    #   del a
    #   $34return_value.13 = cast(value=$32build_tuple.12)  :: Tuple(int64, unicode_type, float64)
    #   del $32build_tuple.12
    #   return $34return_value.13

    return a, b, c


================================================================================