How to pass an "optional" variable that is known not to be None?

I have a function that takes an optional array. If the array is None, the function loads it from elsewhere, and then it performs an operation on it (in this case, np.linalg.cholesky). However, I don’t know how to convert the numba.optional function argument into a variable that Numba knows is not None. Below is a minimal example:

```python
from typing import Optional  # type hints included for clarity
import numpy as np
import numba


@numba.njit((numba.optional(numba.float64[:, :]),))
def optional_cholesky(arr: Optional[np.ndarray]) -> np.ndarray:
    if arr is None:
        arr = np.eye(2, dtype=numba.float64)
    return np.linalg.cholesky(arr)


print(optional_cholesky(None))
```

When I run this example, I get an error:

```
No implementation of function Function(<function cholesky at 0x7f7af1cd9550>) found for signature:

 >>> cholesky(OptionalType(array(float64, 2d, A)) i.e. the type 'array(float64, 2d, A) or None')
```

It appears that np.linalg.cholesky doesn’t accept optional arguments. That makes some sense, because passing in None shouldn’t work, but I don’t know how to tell the compiler that, after the if block, arr is definitely an array and not None.

The example above uses a function, but my actual case is a lazily evaluated attribute of a jitclass: if the attribute is None, the correct value is calculated and then stored in it. I could work around this with some sort of indicator value, e.g. an array of NaNs, but that seems less intuitive and less efficient.

Hi @BillThePlatypus,

I don’t think Numba supports exactly what you have above yet. The np.linalg.* routines don’t handle Optional types, and there is also a more general need for better automatic Optional handling across Numba (see the GitHub issue “Need common util to handle Optional types in comparator”, numba/numba #7480).

This “trick” might help, though it may rely a bit too much on compile-time type information to be useful for your use case (if you post a cut-down version of your use case, someone might be able to take a deeper look):

```python
import numpy as np
import numba


@numba.njit
def optional_cholesky(arr):
    if arr is None:
        use = np.eye(2, dtype=numba.float64)
    else:
        use = arr
    return np.linalg.cholesky(use)


print(optional_cholesky(np.eye(2)))
print(optional_cholesky(None))
```

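The above makes use of dead branch pruning: at compile time Numba looks at the argument types, works out whether each branch can actually be reached for those types, and removes any branch that cannot. Because this version has no explicit signature, calling it with an array and with None compiles two separate specializations, and in the None one the else branch is pruned, so use is never typed as Optional. With the explicit numba.optional(...) signature from your example, arr is typed as Optional in the single compiled version, so there is nothing to prune.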

Hope this helps?
