Handling default NoneType arguments

Hi everyone,

I have a problem with the following function, which implements a Gaussian KDE:

@njit
def kde_function(points, dataset, weights=None, bw_method=None, in_log=False):
    """Validate the inputs of a Gaussian KDE and normalise the sample weights.

    Only the input-preparation part of the function is shown here; the
    density evaluation itself is elided at the end of the body.

    Parameters
    ----------
    points : array
        Evaluation points. Promoted to 2-D and read as
        ``(d_points, n_points)``.
    dataset : array
        Samples the estimate is built from. Promoted to 2-D and read as
        ``(d_dataset, n_dataset)``.
    weights : 1-D array or None, optional
        One weight per dataset sample. ``None`` selects uniform weights.
    bw_method, in_log : optional
        Not used in the portion of the body shown here.

    Raises
    ------
    ValueError
        If the dimensions of ``points`` and ``dataset`` cannot be
        reconciled, or if ``weights`` is not one-dimensional or does not
        have length ``n_dataset``.
    """

    # Rows are dimensions, columns are samples.
    dataset = np.atleast_2d(dataset)
    d_dataset, n_dataset = dataset.shape

    points = np.atleast_2d(points)
    d_points, n_points = points.shape

    if d_points != d_dataset:
        if d_points == 1 and n_points == d_dataset:
            # A single row with d_dataset entries is treated as one point:
            # transpose it into a (d_dataset, 1) column.
            points = points.T
            n_points = points.shape[1]
        else:
            msg = "points have dimension " + str(d_points) + ", dataset has dimension " + str(d_dataset)
            raise ValueError(msg)

    if weights is not None:
        if weights.ndim != 1:
            raise ValueError("`weights` input should be one-dimensional.")
        if len(weights) != n_dataset:
              raise ValueError("`weights` input should be of length n_dataset")
        # Rescale so the weights sum to one.
        weights = weights / np.sum(weights)
    else:
        # No weights supplied: every sample gets the same weight 1 / n_dataset.
        weights = np.full(n_dataset, 1.0 / n_dataset, dtype=dataset.dtype)

   # other stuff here

where the `weights` variable is optional and I set its default value to None.
However, when I call the function without passing that variable, the compilation of the function still processes this piece of code

 if weights is not None:
        if weights.ndim != 1:
            raise ValueError("`weights` input should be one-dimensional.")
        if len(weights) != n_dataset:
              raise ValueError("`weights` input should be of length n_dataset")
        weights = weights / np.sum(weights)

and of course fails with the error Unknown attribute 'ndim' of type none.

Do you know a workaround for this problem?

Thanks in advance!

You can try providing optional type in the signature:

# Explicit signature: two 1-D float64 arrays plus an Optional 1-D float64
# array, so that ``None`` is an accepted value for ``weights``. The declared
# return type is void.
@njit(nb.void(nb.f8[:], nb.f8[:], nb.types.Optional(nb.f8[:])))
def kde_function(points, dataset, weights):
    """Validate the KDE inputs and normalise ``weights`` (which may be None).

    Parameters
    ----------
    points : array
        Evaluation points. Promoted to 2-D and read as
        ``(d_points, n_points)``.
    dataset : array
        Samples. Promoted to 2-D and read as ``(d_dataset, n_dataset)``.
    weights : 1-D array or None
        One weight per dataset sample. ``None`` selects uniform weights.

    Raises
    ------
    ValueError
        If the dimensions of ``points`` and ``dataset`` cannot be
        reconciled, or if ``weights`` is not one-dimensional or does not
        have length ``n_dataset``.
    """

    # Rows are dimensions, columns are samples.
    dataset = np.atleast_2d(dataset)
    d_dataset, n_dataset = dataset.shape

    points = np.atleast_2d(points)
    d_points, n_points = points.shape

    if d_points != d_dataset:
        if d_points == 1 and n_points == d_dataset:
            # A single row with d_dataset entries is treated as one point:
            # transpose it into a (d_dataset, 1) column.
            points = points.T
            n_points = points.shape[1]
        else:
            msg = "points have dimension " + str(d_points) + ", dataset has dimension " + str(d_dataset)
            raise ValueError(msg)

    if weights is not None:
        if weights.ndim != 1:
            raise ValueError("`weights` input should be one-dimensional.")
        if len(weights) != n_dataset:
            raise ValueError("`weights` input should be of length n_dataset")
        # Rescale so the weights sum to one.
        weights = weights / np.sum(weights)
    else:
        # No weights supplied: every sample gets the same weight 1 / n_dataset.
        weights = np.full(n_dataset, 1.0 / n_dataset, dtype=dataset.dtype)

# other stuff here

# Sample 1-D inputs of equal length.
points = np.array([1.2, 2.3, 3.4])
dataset = np.array([3.2, 4.3, 5.4])
weights = np.array([0.1, 0.2, 0.3])

# Call once with ``None`` for the weights and once with an explicit array.
print(kde_function(points, dataset, None))
print(kde_function(points, dataset, weights))

Alternatively you can use overload, for instance

def kde_function_(points, dataset, weights=None):
    """Pure-Python placeholder: ignores its arguments and returns None.

    It exists so that an overload can be registered against it.
    """
    return None

@nb.extending.overload(kde_function_)
def ol_kde_function(points, dataset, weights=None):
    """Choose the implementation of ``kde_function_`` from the type of ``weights``.

    The arguments are inspected as Numba types: when ``weights`` is typed as
    ``NoneType`` the implementation returning 1 is selected, otherwise the
    one returning 2.
    """
    def impl_no_weights(points, dataset, weights=None):
        return 1

    def impl_with_weights(points, dataset, weights=None):
        return 2

    weights_is_none = isinstance(weights, nb.types.NoneType)
    return impl_no_weights if weights_is_none else impl_with_weights

@nb.njit
def kde_function(points, dataset, weights=None):
    """Jitted entry point: forward all three arguments to ``kde_function_``."""
    return kde_function_(points, dataset, weights)

# Omitting ``weights`` is expected to select the implementation returning 1.
assert kde_function(1, 2) == 1
# Passing a value for ``weights`` is expected to select the one returning 2.
assert kde_function(1, 2, 3) == 2

Thank you so much for the answer!
It solves the issue.

Just need an additional hint for another issue I have.

I’m calling njit as a function rather than using it as a decorator. How do I pass the signature to it?

Hey @mtagliazucchi ,

An optional/omitted type in the function signature should actually work out of the box.

import numpy as np
from numba import njit

@njit
def func(arr=None):
    """Return ``arr.ndim``, or -1 when ``arr`` is None or omitted."""
    if arr is None:
        return -1
    return arr.ndim

print(func(np.array([1])))  # array argument
print(func(None))  # explicit None
print(func())  # argument omitted, so the default None is used
# Output:
# 1
# -1
# -1

Your problem could be related to inferring the type of the local variable in these two lines:

weights = weights / np.sum(weights)
...
weights = np.full(n_dataset, 1.0 / n_dataset, dtype=dataset.dtype)

Can you change the local variable names?

weights_new = weights / np.sum(weights)
...
weights_new = np.full(n_dataset, 1.0 / n_dataset, dtype=dataset.dtype)

You’re completely right!
This really solves the problem.

Thank you