Checking for None

Hi, I have the code below (a self-contained example).

When running without njit(), the output to the console is:
0.8906
0.0

When I add @njit() as a decorator, Numba chokes on the last line of the code. The line before it executes fine.
The only difference in the last line is that the w argument is now set to None, which should be handled by the if w is not None statement. Yet I get the following error from Numba:

Unknown attribute 'astype' of type none

File "slope_debug.py", line 7:
def slope(x, y, w):
    <source elided>
    if w is not None:
        w = w.astype(np.float64)
        ^

How can the errored line even be reached? Am I making some rookie mistake?

import numpy as np
from numba import jit, njit

@njit()
def slope(x, y, w):
    if w is not None:
        w = w.astype(np.float64)
        x2 = x.reshape(-1, 1)
        w2 = w.reshape(-1, 1)
        xw = x2 * np.sqrt(w2)
        yw = y * np.sqrt(w)
        xw = xw.astype(np.float64)
        yw = yw.astype(np.float64)
        slope, p, q, r = np.linalg.lstsq(xw, yw)
        slope = slope[0]
        return slope
    else:
        return 0.0

x = np.array([3.0, 4.0, 5.0, 6.0])
y = np.array([2.3, 4.5, 3.4, 4.0])

print(slope(x=x, y=y, w=np.array([1.0, 1.0, 0.5, 0.0])))
print(slope(x=x, y=y, w=None))

I’m using numba 0.56.0.

Thanks in advance

I don’t know the specific answer to your question, but for this kind of situation I sometimes find that generated_jit can be helpful

Thanks, generated_jit was a good idea, I had not tested it.
Unfortunately, it makes the function break in a different place :wink:

Using generated_jit results in AttributeError: 'Array' object has no attribute 'astype'
at the line w = w.astype(np.float64).

Moving all of the .astype’s to outside the loop is also not an option (I have tried it earlier).

This works for me on 0.55.1/python 3.7.1/windows 10

import numpy as np
import numba

@numba.generated_jit
def slope(x, y, w):
    if isinstance(w, numba.core.types.misc.NoneType):
        def _(x, y, w):
            return 0.0
    else:
        def _(x, y, w):
            w = w.astype(np.float64)
            x2 = x.reshape(-1, 1)
            w2 = w.reshape(-1, 1)
            xw = x2 * np.sqrt(w2)
            yw = y * np.sqrt(w)
            xw = xw.astype(np.float64)
            yw = yw.astype(np.float64)
            slope, p, q, r = np.linalg.lstsq(xw, yw)
            slope = slope[0]
            return slope
    return _

That being said, you might not get much speedup from this kind of code. I’ve generally found that ‘numpy-style’ code like this already has the meat of the algorithm compiled on the numpy side so there’s not a lot of benefit to numba here.

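Additionally, many numpy operations allocate memory. The allocation itself is likely highly optimized, but numba can often do the same job in place, without allocating (depending on the task). In the example below, the address of x stays the same after convert, while astype returns a new array at a different address: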

import numpy as np
import numba

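x = np.array([3.0, 4.0, 5.0, 6.0])
# (address, read-only flag) of the buffer backing x
print(x.__array_interface__['data'])

@numba.njit
def convert(arr):
    # Overwrite each element in place; no new array is allocated
    for i in range(len(arr)):
        arr[i] = np.float64(i)

convert(x)
# Same address as before: x was modified in place
print(x.__array_interface__['data'])

# astype() returns a copy, so y lives at a different address
y = x.astype(np.float64)
print(y.__array_interface__['data'])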
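
The same allocation behaviour can be checked in plain NumPy without comparing raw addresses. This is a minimal sketch using np.shares_memory and the copy argument of ndarray.astype; the variable names are only illustrative.

```python
import numpy as np

x = np.array([3.0, 4.0, 5.0, 6.0])

# astype() copies by default, even when the dtype already matches.
y = x.astype(np.float64)
print(np.shares_memory(x, y))  # False: y owns a new buffer

# With copy=False, NumPy returns the input itself when no conversion is needed.
z = x.astype(np.float64, copy=False)
print(z is x)  # True

# An in-place ufunc call writes into the existing buffer instead of allocating.
np.sqrt(x, out=x)
```

If the input dtype is not already float64, copy=False still has to allocate a converted array.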


Thanks a lot! This worked, and I learned a few new things here (mostly generated_jit and how you constructed the nested functions).

After some more experimenting this is what I ended up with:

I found a way to take my original code from the OP and move w = w.astype(np.float64) outside of the entire function. This (unbelievably to me) solved the original error message.

Speed in my project:
with numba: 2.15 seconds
without numba: 7.65 seconds

I originally expected the same as you: not much to gain on this rather ‘pure’ numpy code.
Good to see the speedup anyway! I expect it will save me hours of compute power in production.

Oh, after thinking about it some more, I suspect it’s because ‘w’ has two different types in the OP: the argument as passed in, and the array it is reassigned to inside the function. I guess if you use a different variable name it solves the problem as well:

w1 = w.astype(np.float64)

You’re right! This shaved off another 260 ms of the total runtime :smile:, measured over the function that wraps the slope function: 2.92s → 2.66s.