Runtime specialization

hi @gdonval, interesting question. I won’t claim to have an answer, but I’ll offer some ideas, hoping to learn from others’ replies.

My first thought is that if you want compile-time specialization, then you need to provide that information at compile time. If you don't tell the compiler the shape of arr, it is probably hard for it to optimize based on a number it cannot know.

A few ideas to get you closer to what you want.

  • I’m curious why you need loop_array_spec to be jitted. Your real example is probably more complex, but as posted, loop_array_spec itself does not need to be jitted.
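
    For what it's worth, a plain-Python wrapper could look like the sketch below (assuming DimensionalType and _loop_array from your original code can be used from ordinary Python, e.g. if DimensionalType is a jitclass):

        # Sketch: plain-Python wrapper, only the inner kernel stays jitted.
        # Hypothetical; depends on how DimensionalType is actually defined.
        def loop_array_spec(arr):
            tp = DimensionalType(arr.shape[-1])
            return _loop_array(tp, arr)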

  • I’m happy to treat pure Python as a kind of macro system for jitted code, and I usually write things like:

        from functools import lru_cache
        import numba as nb

        @lru_cache()
        def make_loop_array(dims):
            # dims is captured in the closure as a constant, so Numba compiles
            # a specialized loop_array_spec for each distinct dims value.
            @nb.njit(fastmath=True)
            def loop_array_spec(arr):
                tp = DimensionalType(dims)
                return _loop_array(tp, arr)

            return loop_array_spec
    
    

    The downside is that you need to track all your specialized versions manually.
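
    At the call sites that would look roughly like this (hypothetical usage, dispatching on the trailing dimension):

        # You choose the specialization yourself; lru_cache only avoids
        # rebuilding/recompiling a version that already exists.
        f = make_loop_array(arr.shape[-1])
        result = f(arr)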

  • Similar to the above, there might be a way to use literal dispatching to do the tracking for you. I’m speculating here, I’m not sure exactly how to do it, but something like:

    
        @nb.generated_jit(nopython=True, fastmath=True)
        def loop_array_spec(arr, shape):
            # At this level arr and shape are Numba *types*; this only works
            # if shape arrives as a Literal type (e.g. via numba.literally).
            dim = shape.literal_value

            def impl(arr, shape):
                tp = DimensionalType(dim)
                return _loop_array(tp, arr)

            return impl
    
    
  • I’m guessing you don’t have a huge number of possible array sizes, because if you did, specialization would spend a lot of time in compilation and possibly erase the benefits of faster runtime. So another idea is:

        @nb.njit(fastmath=True)
        def loop_array_spec(arr):
            # Dispatch at runtime; each branch constructs DimensionalType
            # with a constant, so the dimension is known inside that branch.
            if arr.shape[-1] == 1:
                tp = DimensionalType(1)
            elif arr.shape[-1] == 2:
                tp = DimensionalType(2)
            else:
                raise ValueError("unexpected trailing dimension")
            return _loop_array(tp, arr)
    

    Note that if you don’t want to write the above by hand (maybe because the exact values change over time), you can generate the function programmatically as text and then exec it. Just be aware that large text functions can take a long time to compile (Tips or tricks for speeding up compilation time on first call of large Numba-jitted NumPy-containing functions?).
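
    A rough sketch of that text-generation approach (the list of sizes here is hypothetical, adapt it to the dimensions you actually expect):

        # Build the if/elif dispatch chain as source text, exec it, jit it.
        sizes = [1, 2, 3]  # hypothetical trailing dimensions
        lines = ["def loop_array_spec(arr):"]
        for i, n in enumerate(sizes):
            kw = "if" if i == 0 else "elif"
            lines.append(f"    {kw} arr.shape[-1] == {n}:")
            lines.append(f"        tp = DimensionalType({n})")
        lines.append("    else:")
        lines.append("        raise ValueError('unexpected trailing dimension')")
        lines.append("    return _loop_array(tp, arr)")

        namespace = {"DimensionalType": DimensionalType, "_loop_array": _loop_array}
        exec("\n".join(lines), namespace)
        loop_array_spec = nb.njit(fastmath=True)(namespace["loop_array_spec"])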

I hope this helps; I’m curious to see what other people come up with.

Luk