Controlling inlining of jitted functions (with TypedDict as argument)

I’m writing some long functions which, for legibility, I would like to break up into smaller functions. This requires a certain amount of passing variables around, which ideally would be done using a TypedDict. Additionally, some parts of the larger function might change depending on external conditions. A minimal example could be something like this:

from numba import njit
import numpy as np

def build_function(ran, sign):

  @njit
  def setup():
    r = np.random.rand() * ran
    s = np.random.rand() * ran
    return {"r":r, "s":s}

  if sign == "plus":
    @njit
    def calculate(a, b, d):
      r = d['r']
      s = d['s']
      return a*r + b*s
  else:
    @njit
    def calculate(a, b, d):
      r = d['r']
      s = d['s']
      return a*r - b*s

  @njit
  def function(a, b):
    d = setup()
    return calculate(a,b,d)

  return function

# testing
function = build_function(2., "plus")

function(1.,1.)

%timeit function(1., 2.)

$   1.19 µs ± 378 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

If instead I only consider the “plus” case, I can write a more compact code:

def build_function(ran, sign):

  @njit
  def function(a, b):
    r = np.random.rand() * ran
    s = np.random.rand() * ran
    return a*r + b*s

  return function


function = build_function(2., "plus")

function(1.,1.)

%timeit function(1., 2.)

$   350 ns ± 31 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

The second case is significantly faster, which I assume comes down to not needing to allocate a TypedDict and pass it around. Intuition, however, suggests that in the first case it should be possible for the compiler to skip the allocation.
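
As a rough sanity check on that assumption, one could time the dictionary-building part on its own (a minimal sketch; ran is hard-coded here instead of coming from build_function, and the timing also includes boxing the dict back to Python):

from numba import njit
import numpy as np

ran = 2.

@njit
def setup_only():
  r = np.random.rand() * ran
  s = np.random.rand() * ran
  return {"r": r, "s": s}

setup_only()          # trigger compilation
%timeit setup_only()  # compare against the two timings above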

My first question is: can I somehow ensure that setup() and calculate() are inlined while compiling function()? I know that this can be done by setting inline="always" in @njit, as described in the Notes on Inlining page of the Numba documentation; however, that page mentions that no performance gain should be expected this way. In fact, when I added inline="always" I did not observe any significant change in speed.
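
For concreteness, this is roughly how I added the keyword inside build_function (a minimal sketch; ran comes from the enclosing closure and the "minus" branch is omitted):

  @njit(inline="always")
  def setup():
    r = np.random.rand() * ran
    s = np.random.rand() * ran
    return {"r": r, "s": s}

  @njit(inline="always")
  def calculate(a, b, d):
    return a*d['r'] + b*d['s']

  @njit
  def function(a, b):
    d = setup()
    return calculate(a, b, d)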

The second question is: even when inlining, would the compiler be able to recognize that there is no need to allocate a dictionary at all and skip that step entirely? And how could I check this?
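
(The best way I can think of to check is to dump the optimized LLVM for the compiled signature and grep it for dict-related calls, something like the snippet below; I am not sure that is the right thing to look for, so take it as a guess:)

function(1., 1.)  # make sure at least one signature has been compiled
for sig, llvm_ir in function.inspect_llvm().items():
  print(sig)
  print("dict-related symbols present:", "dict" in llvm_ir)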

(For clarity: of course, in this specific case there is no actual need to use a TypedDict, and it is much easier to simply pass r and s. However, passing a dictionary would be preferable when writing function templates.)

For nitty-gritty internal implementation reasons I wouldn’t hold out hope that the compiler will figure out how to cut out TypedDict allocations, for two reasons:

1) The current implementations of TypedDict and TypedList involve calling out to C modules that are treated as opaque functions, so the compiler probably wouldn’t be able to determine that it is safe to cut these out (this may change in future versions of Numba).

2) During type inference the strings in TypedDict assignments are probably going to get cast to unicode_type and not retain their Literal designation, so it’s unlikely that the compiler could ever replace d['a'] with a variable.

You might have better luck with namedtuples, however, since (I think) they boil down to something like a struct in C. Structref may work too, but structrefs are dynamically allocated, so that may incur unwanted overhead. Overall, a TypedDict is kind of overkill here; it seems like you don’t really need a hashtable, just some kind of object with fixed named slots.
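
For reference, a rough sketch of what a structref version of the container could look like, adapted from the structref example in the Numba documentation (untested here, so treat it as a starting point rather than a drop-in replacement):

from numba import njit
from numba.core import types
from numba.experimental import structref
import numpy as np

@structref.register
class NumbersType(types.StructRef):
  def preprocess_fields(self, fields):
    # strip Literal types so the struct stays generic over its field values
    return tuple((name, types.unliteral(typ)) for name, typ in fields)

class Numbers(structref.StructRefProxy):
  def __new__(cls, r, s):
    return structref.StructRefProxy.__new__(cls, r, s)

# associate the proxy class with the type and declare the field names;
# after this, Numbers(r, s) and attribute access like d.r work in jit code
structref.define_proxy(Numbers, NumbersType, ["r", "s"])

@njit
def setup(ran):
  r = np.random.rand() * ran
  s = np.random.rand() * ran
  return Numbers(r, s)

@njit
def calculate(a, b, d):
  return a * d.r + b * d.s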


Thank you for the suggestion of using namedtuples! Indeed, by doing

from numba import njit
import numpy as np
from collections import namedtuple

Numbers = namedtuple("Numbers", "r s")

def build_function(ran, sign):

  @njit
  def setup():
    r = np.random.rand() * ran
    s = np.random.rand() * ran
    return Numbers(r, s)

  if sign == "plus":
    @njit
    def calculate(a, b, d):
      r = d.r
      s = d.s
      return a*r + b*s
  else:
    @njit
    def calculate(a, b, d):
      r = d.r
      s = d.s
      return a*r - b*s

  @njit
  def function(a, b):
    d = setup()
    return calculate(a,b,d)

  return function


function = build_function(2., "plus")

function(1.,1.)

%timeit function(1., 2.)

I get something very close to the time required by the simple version of the function. I’m not sure if this is due to better inlining, or if namedtuples are simply easier for numba to throw around than a TypedDict, but either way it does the trick. Thanks!
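
(To tell the two explanations apart, I suppose I could compare the LLVM generated for the dict and namedtuple versions and look for calls into the Numba runtime allocator, along these lines; the string to search for is my guess at the relevant symbol:)

function(1., 1.)
for sig, llvm_ir in function.inspect_llvm().items():
  print(sig, "-> runtime allocations present:", "NRT_MemInfo_alloc" in llvm_ir)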