How to specify the type signature of function pointers to jitclass member functions

(Previous post was hidden and still not released after several days, so try to post again. Will remove that one if released.)

Basically what I’m trying to do is to have a function pointer that points to different member functions according to some argument. I can achieve this when the functions are not members of a jitclass, like this:

import numba as nb

@nb.experimental.jitclass
class Foo:
    a: float
    f: nb.float64(nb.float64).as_type()

    def __init__(self, a):
        if a > 1e-6:
            self.f = func1
        else:
            self.f = func2

@nb.njit
def func1(x):
    return x * x

@nb.njit
def func2(x):
    return x + x

What I really want is a function pointer to member functions (which would let me, for example, access class attributes), something like this:

import numba as nb

@nb.experimental.jitclass
class Foo:
    a: float
    f: nb.float64(nb.float64).as_type()

    def __init__(self, a):
        self.a = a
        if a > 1e-6:
            self.f = self.func1
        else:
            self.f = self.func2

    def func1(self, x):
        return x * x

    def func2(self, x):
        return x + x + self.a

However, this code does not run as is. My question is: is this feature not yet supported in the latest version of numba (0.58), or, if it is supported, how do I make it work?

Hey @seekiu ,

Unfortunately, I can’t answer your question. I don’t know how to define a function pointer within a jitclass. Nevertheless, this is an interesting topic and I would like to add some content regarding your question if you don’t mind.
The straightforward way to generate the jitclass would be to implement the Python class with type annotations and simply convert it.
To help infer the data types, we can provide the field specs.
Unfortunately, there seems to be a problem casting the BoundFunction type into a FunctionType.
I would assume that the field spec for the BoundFunction type has to be adjusted to solve the problem.

# =============================================================================
# Try to convert Python class into jitclass using FunctionType in spec
# Error: Cannot cast BoundFunction[...] to FunctionType[...]
# =============================================================================
from collections.abc import Callable
import numba as nb
import numba.types as nbt

class Foo:
    a: float
    f: Callable[[float], float]

    def __init__(self, a):
        self.a = a
        self.f = (self.func1 if a > 1e-6 else self.func2)

    def func1(self, x: float) -> float:
        return x * x

    def func2(self, x: float) -> float:
        return x + x + self.a

foo1 = Foo(a=1.)
foo2 = Foo(a=1e-7)
print(f'foo1: {foo1.f(10.0)}')
print(f'foo2: {foo2.f(10.0)}')
# foo1: 100.0
# foo2: 20.0000001

spec = [('a', nbt.f8),
        ('f', nbt.FunctionType(nbt.f8(nbt.f8)))]
JitFoo = nb.experimental.jitclass(Foo, spec)
foo1 = JitFoo(a=1.)
foo2 = JitFoo(a=1e-7)
print(f'foo1: {foo1.f(10.0)}')
print(f'foo2: {foo2.f(10.0)}')
# TypingError: Failed in nopython mode pipeline (step: native lowering)
# Cannot cast BoundFunction[...] to FunctionType[...]

If we are not able to define the correct function type to make the conversion into a jitclass work, we could use a workaround.
If the class can remain a Python class and we just need the computational power of Numba in specific functions, we could use a function factory to generate a jitted function. We can then assign the jitted function to a function pointer in the Python class.

# =============================================================================
# Keep Python class and generate jitted function
# =============================================================================
from collections.abc import Callable
import numba as nb
import numba.types as nbt


def function_factory(a: float) -> Callable[[float], float]:
    @nb.njit(nbt.f8(nbt.f8))
    def func(x: float) -> float:
        return x * x if a > 1e-6 else x + x + a
    return func

class Foo:
    def __init__(self, a: float):
        self.a = a
        self.f = function_factory(a)

foo1 = Foo(a=1.)
foo2 = Foo(a=1e-7)
print(f'foo1: {foo1.f(10.0)}')
print(f'foo2: {foo2.f(10.0)}')
# foo1: 100.0
# foo2: 20.0000001
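
As a side note, a variation on the factory above picks the branch once, when the closure is built, rather than on every call. This is a minimal pure-Python sketch under my own assumptions: make_f is a hypothetical name, and in a Numba version each inner function would presumably be decorated with nb.njit as in the factory above, which I have not benchmarked.

```python
from collections.abc import Callable


def make_f(a: float) -> Callable[[float], float]:
    """Return the function to use for this `a`; the test on `a` runs only once."""
    if a > 1e-6:
        def f(x: float) -> float:
            return x * x
    else:
        def f(x: float) -> float:
            # `a` is captured by the closure
            return x + x + a
    return f


f_big = make_f(1.0)
f_small = make_f(1e-7)
print(f_big(10.0))
print(f_small(10.0))
```

The returned function contains no comparison at all, so the cost of choosing between the two variants is paid at construction time only.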

If it has to be a jitclass, we don’t need a function pointer, and there is no reason not to use external functions, we can follow your suggestion.

# =============================================================================
# Generate Python class and convert to jitclass using external functions
# =============================================================================
import numba as nb
import numba.types as nbt

@nb.njit(nbt.f8(nbt.f8))
def func1(x: float) -> float:
    return x * x

@nb.njit(nbt.f8(nbt.f8, nbt.f8))
def func2(x: float, a: float) -> float:
    return x * x if a > 1e-6 else x + x + a

class Foo:
    a: float

    def __init__(self, a):
        self.a = a

    def f(self, x: float):
        return (func1(x) if self.a > 1e-6
                else func2(x, self.a))

spec = [('a', nbt.f8)]
JitFoo = nb.experimental.jitclass(Foo, spec)
foo1 = JitFoo(a=1.)
foo2 = JitFoo(a=1e-7)
print(f'foo1: {foo1.f(10.0)}')
print(f'foo2: {foo2.f(10.0)}')
# foo1: 100.0
# foo2: 20.0000001

If the functions have to be part of the class and we don’t need the function pointer, we could write the Python class, define the functions as static methods, and simply convert it into a jitclass. We only have to specify the field a for the conversion.
All methods of the jitclass should be automatically compiled into nopython functions.

# =============================================================================
# Generate Python class and convert to jitclass using staticmethods
# =============================================================================
import numba as nb
import numba.types as nbt

class Foo:
    a: float

    def __init__(self, a):
        self.a = a

    def f(self, x: float):
        return (self.func1(x) if self.a > 1e-6
                else self.func2(x, self.a))

    @staticmethod
    def func1(x: float) -> float:
        return x * x

    @staticmethod
    def func2(x: float, a: float) -> float:
        return x * x if a > 1e-6 else x + x + a

spec = [('a', nbt.f8)]
JitFoo = nb.experimental.jitclass(Foo, spec)
foo1 = JitFoo(a=1.)
foo2 = JitFoo(a=1e-7)
print(f'foo1: {foo1.f(10.0)}')
print(f'foo2: {foo2.f(10.0)}')
# foo1: 100.0
# foo2: 20.0000001

I am curious to see the solution to your problem.

Thanks @Oyibo ! Although your answer didn’t answer my specific question, it’s quite helpful for me to understand numba better.

Let me add a bit more context to explain why I want function pointer (more specifically pointing to a member function) in a jitclass:

  • I would like to have a unified call interface for func1 and func2 via a surrogate f.
  • Since f will be called A LOT, I would like to avoid overhead as much as I can. This means that a) preferably Foo is jitted as a whole, and b) we avoid the if inside the surrogate f, which I suppose is not negligible because func1 and func2 are almost as lightweight as the if itself.

I’m totally open to other options/workarounds that satisfy the above 2 constraints. For now, jitting Foo and using function pointer seems to be the most straightforward one.

So I guess my question is twofold:

  1. For the purpose of addressing my specific problem: Is there a way to achieve the 2 constraints with existing numba features?
  2. For just getting to know numba better: If I insist on having a function pointer that points to a member function, is that doable, or is it not supported yet?

Hello @seekiu

Just a very quick note. I wouldn’t worry about that little overhead. Calling a jitclass method is most likely already a bit slower than calling a normal function. An indirect call can slow you down even more. It complicates optimizations further down the line. I would definitely leave it at the if statement and see if there is more room for improvement elsewhere. Numba and function pointers are usually not so trivial to handle. There was a discussion some time ago that you might find interesting and perhaps even useful:

As I said, I would go for the simple solution. However, depending on your actual problem (I assume you showed something simplified) you can try a branchless version:

def func(self, x):
    return (x + x + self.a, x * x)[nb.uint8(self.a > 1e-6)]

Whether a branchless version is faster depends on many things. So just try it out but with your real problem.
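
To make the indexing trick concrete, here is a minimal pure-Python sketch with a passed in explicitly instead of read from self. The names f_branch and f_branchless are hypothetical and only for illustration. Note that the tuple version evaluates both expressions on every call, which is one reason it may or may not turn out faster.

```python
def f_branch(x: float, a: float) -> float:
    if a > 1e-6:
        return x * x
    return x + x + a


def f_branchless(x: float, a: float) -> float:
    # Both candidates are computed; the comparison result
    # (False -> 0, True -> 1) selects which one is returned.
    return (x + x + a, x * x)[int(a > 1e-6)]


# The two versions agree on both sides of the threshold.
for a in (1.0, 1e-7):
    assert f_branch(10.0, a) == f_branchless(10.0, a)
```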

Hi @sschaer, thanks for the great info. I did some additional benchmarks to see how various ways of calling the underlying functions perform, and tbh the result is a little bit surprising. Here is the summary:

  • Compared to calling the njit functions directly, calling them from jitclass adds insignificant overhead (statistically about 10% slower, but the benchmark itself has about 10% variance)
  • Using if to decide which njit function to call adds roughly 100~200% overhead
  • Using a function pointer (hardcoding the a in my func2) results in about 50% overhead

The tests were done on my actual problem, which is indeed slightly more complex than the example I had above but still fairly lightweight (a single direct call is on the order of 10 nanoseconds). I don’t assume the conclusion to be universal, but based on the benchmarks, I would like to avoid using if because 200% overhead is too much. On the other hand, 50% overhead might be acceptable, though still worse than I thought, if it makes the implementation clean and nice.

Hi @seekiu

I briefly ran some tests with the code from your original post and cannot reproduce the problem you describe. In fact, I get very consistent measurements, and the differences between the implementations are minor. Maybe you would like to share your full example and also specify the Numba and Python versions you are using?

Hey @sschaer, thank you for taking the time to test on your side. I guess the example in the original post is too trivial to show the difference between implementations. I uploaded benchmark code here (which is still simplified but runs standalone): benchmark.ipynb · GitHub.

This benchmark is a little bit different from what I described in my earlier post. Previously I was benchmarking using 2-level loops and put the loops inside an njit function. Today I realized that I always use the inner loop as a whole, which means I can have the if outside the innermost loop. So now I have the inner loop inside an njitted function get_col and use a non-jitted outer loop to run the benchmark.
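
The restructuring described above can be sketched in plain Python as follows. This only illustrates hoisting the if out of the inner loop; the signature of get_col and the two kernels here are assumptions made for the sketch, not the code from the notebook.

```python
import math


def kernel_linear(x1, x2):
    return sum(p * q for p, q in zip(x1, x2))


def kernel_rbf(x1, x2, gamma):
    return math.exp(-gamma * sum((p - q) ** 2 for p, q in zip(x1, x2)))


def get_col(data, j, use_linear, gamma):
    """Kernel values between every row of `data` and row `j`."""
    xj = data[j]
    # The kernel choice is tested once per column, not once per element.
    if use_linear:
        return [kernel_linear(row, xj) for row in data]
    return [kernel_rbf(row, xj, gamma) for row in data]


data = [[0.0, 1.0], [1.0, 1.0], [2.0, 0.0]]
col_linear = get_col(data, 1, True, 0.5)
col_rbf = get_col(data, 1, False, 0.5)
```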

My environment is Mac M1, Python 3.11.5, numba 0.58.

(An interesting but off-topic observation: Having njitted 2-level loops gets me about 10~50% variance between benchmarks, but with the inner-njit-outer-python code I got very stable wall times, roughly 1% variance)

The comparison ktype == 'linear' is very expensive. When the two strings are equal, every single character has to be compared. This is a very different situation than your first example. Try to use enums instead and profile again.

import enum 

class KernelType(enum.IntEnum):
    LINEAR = enum.auto()
    RBF = enum.auto()
    
@nb.njit
def kernel(x1, x2, ktype, gamma):
    if ktype == KernelType.LINEAR:
        return kernel_linear(x1, x2, 1)
    else:
        return kernel_rbf(x1, x2, gamma)
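
For context, IntEnum members are plain integers underneath, so the test in kernel becomes a single integer comparison rather than a character-by-character string comparison. A quick check in pure Python, outside Numba:

```python
import enum


class KernelType(enum.IntEnum):
    LINEAR = enum.auto()
    RBF = enum.auto()


# auto() numbers the members starting from 1
print(int(KernelType.LINEAR), int(KernelType.RBF))
# Members compare equal to their integer values
print(KernelType.LINEAR == 1)
# Every member is an instance of int
print(isinstance(KernelType.RBF, int))
```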

Wow, that works perfectly. I suspected the string comparison was expensive, but didn’t expect it to be this expensive. So indeed this can be a nice and performant solution to my specific problem.

That leaves the other half of my question: If I must have function pointers, is it achievable with the current version of numba (without digging into the lower-level part)? Based on your reference above, I might assume the answer to be NO for now?


It depends. How important is caching and short compilation time for you? If it’s not of high priority, you can use something like this without sacrificing performance:

import numba as nb 
import numpy as np 

@nb.njit
def kernel_linear(x1, x2):
    s = 0
    for i in range(x1.shape[0]):
        s += (x1[i] * x2[i])
    return s

@nb.njit
def kernel_rbf(x1, x2, gamma):
    s = 0
    for i in range(x1.shape[0]):
        s += (x1[i] - x2[i]) ** 2
    return np.exp(-gamma * s)

@nb.njit
def test(x, kernel_func, *kernel_params):
    out = np.empty(x.shape[0], x.dtype)
    for i in range(x.shape[0]):
        for j in range(x.shape[0]):
            out[i] = kernel_func(x[i], x[j], *kernel_params)
    return out

x = np.random.rand(5_000, 3)

test(x, kernel_linear)
test(x, kernel_rbf, 1.0)

I just checked, and we discussed the pros and cons of various alternatives for such problems in the thread I posted above:

Hey @sschaer , thanks for the pointer and sorry for the delayed response.

I should’ve been clearer in my previous question. The other half of my question was meant to ask specifically about function pointers that point to member functions of a jitclass.

It seems that having such function pointers requires casting the BoundFunction type to FunctionType (according to the error message I got). I take it that the two types are different to numba and require some extra non-trivial work to convert. (If I understand correctly, the post you referenced above did not talk about this explicitly, but my takeaway is that such casting is not trivial.)
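
A plain-Python analogy may show why the two would be different. This illustrates Python’s own bound methods, not numba internals, and I am only assuming the distinction carries over: self.func2 is not a bare function but a function paired with the instance it is bound to, so it carries state that a signature like float64(float64) has no place for.

```python
class Foo:
    def __init__(self, a):
        self.a = a

    def func2(self, x):
        return x + x + self.a


foo = Foo(0.5)
bound = foo.func2      # bound method: function plus instance
unbound = Foo.func2    # plain function taking (self, x)

# The bound method remembers both its instance and the underlying function.
assert bound.__self__ is foo
assert bound.__func__ is unbound
# Calling it is the same as calling the plain function with the instance.
assert bound(10.0) == unbound(foo, 10.0) == 20.5
```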

Let me know if my explanation is clear enough or if I am still missing something from the post you referenced (which btw is a great in-depth discussion, and to be honest I probably missed something because of my lack of knowledge about the lower-level part of numba).

Update: I don’t have any special requirements for caching or compilation time.