Numba as a backend to Aesara

Hello,

The PyMC group has forked Theano and started a new project called Aesara. I’m leading the development of Aesara, and I’ve made it a top priority that we replace the old C code generation and compilation backend with Numba.

In the following, I’ll briefly introduce our project and describe how it relates to Numba, then I’ll describe what we’ve already done to support Numba in Aesara and one of the more important things we’re working on at the moment.

Background

Anyone who’s familiar with the old Theano might already have a good idea of what a Numba “backend” means, but, for those who aren’t familiar: Aesara builds graphs of symbolic NumPy operations, performs domain-specific optimizations on them, and then either converts them to functions that execute the underlying NumPy functions in pure Python, or converts them to C (i.e., it generates custom Python extensions, compiles and exposes them, and orchestrates their use).

This latter step, the one that creates computable functions from symbolic graphs, is almost exactly what Numba does with its object and nopython modes (albeit with LLVM instead of C).

The main difference is that Numba does all of this much better, and, from what I’ve seen, Aesara could easily be much better off investing its time and effort in “low-level” work on Numba instead of in its existing framework.
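To make the graph-to-function step described above a little more concrete, here is a toy sketch of the pure-Python path: a graph of NumPy operations is walked, and each node’s NumPy function is called on concrete arrays. The `Variable`, `Node`, and `compile_python` names are hypothetical and deliberately minimal; they are not Aesara’s actual classes, which also deal with types, optimizations, and shared subexpressions.

```python
import numpy as np


class Variable:
    """Hypothetical symbolic input; it is bound to a concrete value at call time."""

    def __init__(self, name):
        self.name = name


class Node:
    """Hypothetical graph node: a NumPy function applied to input nodes."""

    def __init__(self, op, *inputs):
        self.op = op          # e.g. np.add or np.exp
        self.inputs = inputs  # Variables or other Nodes


def compile_python(output, inputs):
    """Turn a graph into a plain Python function that calls NumPy directly."""

    def fn(*values):
        env = dict(zip(inputs, values))  # bind each input Variable to a value

        def evaluate(node):
            if isinstance(node, Variable):
                return env[node]
            # Evaluate the inputs recursively, then apply this node's NumPy function
            return node.op(*(evaluate(i) for i in node.inputs))

        return evaluate(output)

    return fn


x, y = Variable("x"), Variable("y")
graph = Node(np.exp, Node(np.add, x, y))  # exp(x + y)
f = compile_python(graph, [x, y])
result = f(np.zeros(3), np.zeros(3))
```

A Numba backend replaces the “call NumPy from Python” part of this picture with jitted code.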

At a “high level”, our intentions with Aesara are to make it an approachable Python platform for domain-specific optimizations, especially ones that exist within the spaces between symbolic math and efficient numerical computation. For more background, see the experiments in symbolic-pymc, the Python DSL work in the GitHub organization pythological, and this summary paper explaining the two.

With that in mind, I believe it’s easier to see how Aesara and Numba could fit together.

I can provide more details upon request, but that’s the gist of it. We have some old documentation that should explain more, but it’s still basically just the old Theano documentation.

Current Efforts

As of now, we have a simple dispatch-based foundation for converting Aesara graphs to Numba JITed functions: aesara.link.numba.dispatch.

It’s essentially a dispatch from Aesara Ops to numba.njit-ed functions that take the same types of inputs as the Ops, except that these inputs are non-symbolic NumPy arrays. In the most basic cases, the JITed functions are just wrappers for NumPy functions. These conversions are very easy to make, because Aesara is largely a symbolic version of NumPy/SciPy.
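A minimal sketch of that kind of dispatch, assuming `functools.singledispatch` as the mechanism. The `Op`, `Exp`, `Add`, and `funcify` names are hypothetical stand-ins, not the actual contents of `aesara.link.numba.dispatch`. Each registered conversion returns a jitted wrapper around the corresponding NumPy function, and the import falls back to plain Python if Numba isn’t installed so the sketch runs either way.

```python
from functools import singledispatch

import numpy as np

try:
    from numba import njit
except ImportError:  # fall back to plain Python so the sketch still runs
    def njit(fn):
        return fn


class Op:
    """Hypothetical base class for a symbolic operation."""


class Exp(Op):
    pass


class Add(Op):
    pass


@singledispatch
def funcify(op):
    """Return a jitted function that implements `op` on concrete NumPy arrays."""
    raise NotImplementedError(f"No Numba conversion for {type(op).__name__}")


@funcify.register(Exp)
def _(op):
    @njit
    def exp(x):
        return np.exp(x)

    return exp


@funcify.register(Add)
def _(op):
    @njit
    def add(x, y):
        return np.add(x, y)

    return add


exp_fn = funcify(Exp())
add_fn = funcify(Add())
```

Because dispatch is keyed on the Op’s class, supporting another Op means registering one more conversion; an unregistered Op raises `NotImplementedError`.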

There have been some challenges, though, and they’re all things the Numba community is well aware of. For example, the dispatch code includes an attempt at my slice-boxing feature request, as well as a hack that works around some vectorize varargs-related issues by constructing functions from source (using exec).
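The source-generation workaround can be sketched as follows, assuming the problem is an entry point that needs an explicit, fixed list of parameters rather than `*args`. `make_nary_sum` is a hypothetical example, not the code from Aesara: it writes out a function definition as a string, runs it with `exec` in a fresh namespace, and jits the result.

```python
try:
    from numba import njit
except ImportError:  # fall back to plain Python so the sketch still runs
    def njit(fn):
        return fn


def make_nary_sum(n):
    """Hypothetical example: build `nary_sum(x0, ..., x{n-1})` from source."""
    if n < 1:
        raise ValueError("need at least one argument")
    arg_names = [f"x{i}" for i in range(n)]
    params = ", ".join(arg_names)  # explicit parameters instead of *args
    body = " + ".join(arg_names)
    src = f"def nary_sum({params}):\n    return {body}\n"
    namespace = {}
    exec(src, namespace)  # defines nary_sum inside `namespace`
    return njit(namespace["nary_sum"])


sum3 = make_nary_sum(3)
total = sum3(1.0, 2.0, 3.0)
```

Each distinct arity gets its own generated (and separately compiled) function, which is the price of avoiding varargs.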

Work in Progress

There’s one thing I’m working on that’s particularly important, and I would like to get some low-level guidance on how to implement it in Numba.

It’s an Aesara Op called Scan, which represents a somewhat generalized loop. It’s one of the more complicated and least performant parts of Aesara, but it’s also a critical Op, so I’m working on fixing it ASAP, and a Numba-backed implementation could be a simple solution to the bulk of its performance problems.

Scan can be simplified (and somewhat generalized) as follows:

import numpy as np


def scan(carry, *in_sequences):
    # inner_fn is read as a module-level (global) function; it is defined below
    out_sequences = []
    for in_seq in np.nditer(in_sequences):
        carry, out_seq = inner_fn(*(carry + tuple(in_seq)))
        out_sequences.append(out_seq)
    return np.stack(out_sequences)

Here’s an example of how it’s used:

import numba


@numba.njit
def inner_fn(x_tm1, a_t, b_t):
    x_t = x_tm1 * a_t + b_t
    return (x_t,), x_t


res = scan((np.array(-1.0),), np.arange(10), np.arange(10) * 1e-1)

print(res)
[0.00000e+00 1.00000e-01 4.00000e-01 1.50000e+00 6.40000e+00 3.25000e+01
 1.95600e+02 1.36990e+03 1.09600e+04 9.86409e+04]

As you can guess, there are some issues with njit-ing this naive version of scan; however, I think I understand the errors, so I’ll spare you the details.
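One way to sidestep those issues, shown here as a sketch rather than Aesara’s implementation: specialize the loop for a fixed inner function, capture that function in a closure (Numba can call a jitted function referenced from the enclosing scope), index the sequences along their leading axis instead of using `np.nditer`, and preallocate the output instead of appending to a list and calling `np.stack`. The `make_scan` helper is hypothetical, and this version assumes a scalar carry, exactly two 1-D sequences, and an inner function that returns only the new carry (which doubles as the output).

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # fall back to plain Python so the sketch still runs
    def njit(fn):
        return fn


def make_scan(inner_fn):
    """Hypothetical helper: build a jitted scan for a scalar carry and two sequences."""

    @njit
    def scan(carry, a_seq, b_seq):
        n = a_seq.shape[0]
        out = np.empty(n, dtype=np.float64)  # preallocated instead of list + np.stack
        for t in range(n):
            carry = inner_fn(carry, a_seq[t], b_seq[t])
            out[t] = carry
        return out

    return scan


@njit
def inner_fn(x_tm1, a_t, b_t):
    return x_tm1 * a_t + b_t


scan = make_scan(inner_fn)
res = scan(-1.0, np.arange(10), np.arange(10) * 1e-1)
```

With the same inputs as the naive version, `res` should hold the same values as the output printed above.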

I did notice that the numba.stencil feature could cover a lot of the same ground as our scan, so I’m currently looking at its source for inspiration.

Epilogue

At this point, the code in StencilFunc seems to imply that the exec-based approach I’ve been using in Aesara might not be as off-base as I initially thought. (If I’m mistaken, please tell me!) In the meantime, I’ll attempt to mix the source-generation + exec approach with some other things I pick up from the StencilFunc implementation and report any blocking issues here.
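Combining the closure approach with source generation gives a scan that accepts any fixed number of sequences. This is a sketch under assumptions: `make_scan_n` is hypothetical, every sequence is 1-D with the same length, the carry is a float scalar, and the inner function returns only the new carry. The loop is generated as a string so every sequence becomes an explicit parameter, then compiled with `exec` and `njit`; the generated function finds `np` and `inner_fn` through the namespace dict passed to `exec`.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # fall back to plain Python so the sketch still runs
    def njit(fn):
        return fn


def make_scan_n(inner_fn, n_seqs):
    """Hypothetical helper: generate a jitted scan over `n_seqs` 1-D sequences."""
    if n_seqs < 1:
        raise ValueError("need at least one sequence")
    seq_names = [f"seq{i}" for i in range(n_seqs)]
    params = ", ".join(seq_names)
    step_args = ", ".join(f"{name}[t]" for name in seq_names)
    src = (
        f"def scan(carry, {params}):\n"
        "    n = seq0.shape[0]\n"
        "    out = np.empty(n, dtype=np.float64)\n"
        "    for t in range(n):\n"
        f"        carry = inner_fn(carry, {step_args})\n"
        "        out[t] = carry\n"
        "    return out\n"
    )
    # `np` and `inner_fn` become globals of the generated function
    namespace = {"np": np, "inner_fn": inner_fn}
    exec(src, namespace)
    return njit(namespace["scan"])


@njit
def inner3(x_tm1, a_t, b_t, c_t):
    return x_tm1 * a_t + b_t - c_t


scan3 = make_scan_n(inner3, 3)
res = scan3(0.0, np.ones(4), np.ones(4), np.zeros(4))
```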

Given the extent to which I would like Aesara to use Numba, I am personally interested in becoming involved in Numba development at whatever level the issues we run into ultimately need to be resolved. In other words, if there are more efficient and direct routes that make better use of Numba and improve it (e.g. through bug fixes, new features, etc.), then those are the paths I would like this work to take.

Otherwise, I just wanted to introduce our project and its new Numba-related efforts. If anyone is interested and wants to know more, don’t hesitate to ask. Better yet, if anyone has general advice related to how we should interact with Numba via Aesara, or would like to point us toward any other relevant information or similar endeavors, please do.

(Sorry about the lack of relevant links; there’s a limit set on new Discourse users.)

hi @brandonwillard, as a regular Numba user and contributor, I will watch this project with a lot of interest. I build jitted functions with exec a lot, so it will be nice to see how you do it. I feel like there must be a better way than what I do; I just haven’t found it.
The possibility of having access to numba-compatible functions built out of symbolic graphs is also extremely interesting, for any optimization library, not just PyMC.

cheers,
Luk

@brandonwillard, looking forward to seeing your progress on this; it sounds very interesting as a use case for Numba. For you and others, there’s both this forum and the numba/numba Gitter channel for support, and the issue tracker for issues; patches to Numba are also welcome. Further, Numba holds a weekly public meeting every Tuesday (see “Weekly Public Meeting every Tuesday for 2021”) for discussing more involved topics and research; feel free to suggest items for discussion.

As for the Numba backend itself, I’ve taken a look at the code in the JAX backend. Assuming that you are prepared to relax your use of some of the perhaps more idiomatic Python constructs, then, just from a quick skim, I’d say Numba has a reasonable chance of being able to compile a lot of the dispatched functions (as you’ve noted).

Do you also have plans to @jit compile the dispatched graph too so as to gain the benefits of IPO/loop fusion (you’ll have to excuse me probably not using the terminology common to Aesara)?

@stuartarchibald, if I’m understanding you correctly, yes: we jit-compile everything, so the result is really just one Numba-jitted function, superficially wrapped in one of Aesara’s Function objects for the time being.
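For readers unfamiliar with the pattern, here is a small sketch of what compiling the whole graph to one jitted function can look like: the per-Op functions are jitted individually, and a single outer jitted function calls them in graph order, so no Python-level calls happen between Ops and LLVM gets a chance to optimize across them. The function names are hypothetical; in Aesara, the outer function would be generated from the graph rather than written by hand.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # fall back to plain Python so the sketch still runs
    def njit(fn):
        return fn


# Per-Op jitted functions, as a dispatcher might produce them (hypothetical)
@njit
def add(x, y):
    return x + y


@njit
def exp(x):
    return np.exp(x)


# The whole graph exp(x + y) * z as a single jitted function that calls
# the per-Op functions from native code rather than from Python
@njit
def graph_fn(x, y, z):
    return exp(add(x, y)) * z


out = graph_fn(np.zeros(3), np.zeros(3), np.full(3, 2.0))
```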

If you’re interested, here’s the current Numba PR: Don't try to infer support shape of multivariate RVs by default by ricardoV94 · Pull Request #388 · pymc-devs/pytensor · GitHub.

There are quite a few more Ops covered in there (with an entertaining collection of Numba hacks), and we’re quickly closing in on complete parity with our JAX backend. Actually, it’s better than parity, since there are no “symbolic” shape restrictions in the Numba case.