Passing config variables to functions so they behave as compile time constants

In Numba, I want to pass configuration variables to a function as compile-time constants. Specifically, this is what I want to do:

    from numba import njit
    from numba.typed import List

    @njit
    def physics(config):
        flagA = config.flagA
        flagB = config.flagB
        aNumbaList = List()
        for i in range(100):
            if flagA:
                aNumbaList.append(i)
            else:
                aNumbaList.append(i/10)
        return aNumbaList

If the config variables were compile-time constants, this would compile. Since they are not, I get an error saying that there are two candidate implementations:

    There are 2 candidate implementations:
          - Of which 2 did not match due to:
          ...
          ...

I looked through the Numba meeting minutes and found a suggested way to do this (Numba Meeting: 2024-03-05; I can't post the link).
I tried it, but it still raises the same error.

Here is the code with the error message:

.. code:: ipython3

    from numba import jit, types, njit
    from numba.extending import overload
    from numba.typed import List
    import functools

.. code:: ipython3

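    class Config:
        # Defaults keep Config() working; Config(True, False) matches the traceback below
        def __init__(self, flagA=True, flagB=False):
            self._flagA = flagA
            self._flagB = flagB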
    
        @property
        def flagA(self):
            return self._flagA
    
        @property
        def flagB(self):
            return self._flagB

.. code:: ipython3

    @functools.cache
    def obj2strkeydict(obj, config_name):
    
        # unpack object to freevars and close over them
        tmp_a = obj.flagA
        tmp_b = obj.flagB
        assert isinstance(config_name, str)
        tmp_force_heterogeneous = config_name
    
        @njit
        def configurator():
            d = {'flagA': tmp_a,
                 'flagB': tmp_b,
                 'config_name': tmp_force_heterogeneous}
            return d
    
        # return a configuration function that returns a string-key-dict
        # representation of the configuration object.
        return configurator

.. code:: ipython3

    @njit
    def physics(cfig_func):
        config = cfig_func()
        flagA = config['flagA']
        flagB = config['flagB']
        aNumbaList = List()
        for i in range(100):
            if flagA:
                aNumbaList.append(i)
            else:
                aNumbaList.append(i/10)
        return aNumbaList

.. code:: ipython3

    def demo():
        configuration1 = Config()
        jit_config1 = obj2strkeydict(configuration1, 'config1')
        physics(jit_config1)

.. code:: ipython3

    demo()


::


    ---------------------------------------------------------------------------

    TypingError                               Traceback (most recent call last)

    Cell In[83], line 1
    ----> 1 demo()


    Cell In[82], line 4, in demo()
          2 configuration1 = Config(True, False)
          3 jit_config1 = obj2strkeydict(configuration1, 'config1')
    ----> 4 physics(jit_config1)


    File ~/anaconda3/envs/tardis/lib/python3.11/site-packages/numba/core/dispatcher.py:468, in _DispatcherBase._compile_for_args(self, *args, **kws)
        464         msg = (f"{str(e).rstrip()} \n\nThis error may have been caused "
        465                f"by the following argument(s):\n{args_str}\n")
        466         e.patch_message(msg)
    --> 468     error_rewrite(e, 'typing')
        469 except errors.UnsupportedError as e:
        470     # Something unsupported is present in the user code, add help info
        471     error_rewrite(e, 'unsupported_error')


    File ~/anaconda3/envs/tardis/lib/python3.11/site-packages/numba/core/dispatcher.py:409, in _DispatcherBase._compile_for_args.<locals>.error_rewrite(e, issue_type)
        407     raise e
        408 else:
    --> 409     raise e.with_traceback(None)


    TypingError: Failed in nopython mode pipeline (step: nopython frontend)
    - Resolution failure for literal arguments:
    No implementation of function Function(<function impl_append at 0x7fd87d253920>) found for signature:
    
     >>> impl_append(ListType[int64], float64)
    
    There are 2 candidate implementations:
          - Of which 2 did not match due to:
          Overload in function 'impl_append': File: numba/typed/listobject.py: Line 592.
            With argument(s): '(ListType[int64], float64)':
           Rejected as the implementation raised a specific error:
             TypingError: Failed in nopython mode pipeline (step: nopython frontend)
           No implementation of function Function(<intrinsic _cast>) found for signature:
    
            >>> _cast(float64, class(int64))
    
           There are 2 candidate implementations:
                 - Of which 2 did not match due to:
                 Intrinsic in function '_cast': File: numba/typed/typedobjectutils.py: Line 22.
                   With argument(s): '(float64, class(int64))':
                  Rejected as the implementation raised a specific error:
                    TypingError: cannot safely cast float64 to int64. Please cast explicitly.
             raised from /home/sam/anaconda3/envs/tardis/lib/python3.11/site-packages/numba/typed/typedobjectutils.py:75
           
           During: resolving callee type: Function(<intrinsic _cast>)
           During: typing of call at /home/sam/anaconda3/envs/tardis/lib/python3.11/site-packages/numba/typed/listobject.py (600)
           
           
           File "../anaconda3/envs/tardis/lib/python3.11/site-packages/numba/typed/listobject.py", line 600:
               def impl(l, item):
                   casteditem = _cast(item, itemty)
                   ^
    
      raised from /home/sam/anaconda3/envs/tardis/lib/python3.11/site-packages/numba/core/typeinfer.py:1086
    
    - Resolution failure for non-literal arguments:
    None
    
    During: resolving callee type: BoundFunction((<class 'numba.core.types.containers.ListType'>, 'append') for ListType[int64])
    During: typing of call at /tmp/ipykernel_9889/739598600.py (11)
    
    
    File "../../../tmp/ipykernel_9889/739598600.py", line 11:
    <source missing, REPL/exec in use?>

Any help or pointers to related material would be much appreciated.
Thank you.

Hi @Sumit112192

The problem is that numba.typed.List is a homogeneous container, i.e. its elements must all have the same type. However, in the following loop, i is an integer and i/10 is a float. What about using integer division (i // 10) instead?

    for i in range(100):
        if flagA:
            aNumbaList.append(i)
        else:
            aNumbaList.append(i/10)

@sschaer
I agree. But if the config variables were treated as compile-time constants, the else branch would be pruned, since the flag is always True. So Numba would never see that I am appending a float.

    import numba
    from numba.typed import List

    flag1 = True

    @numba.njit
    def check1():
        anotherNumbaList = List()
        for i in range(100):
            if flag1:  # global, so this branch is resolved at compile time
                anotherNumbaList.append(i)
            else:
                anotherNumbaList.append(i/10)
        return anotherNumbaList

    check1()

The above code works perfectly fine because flag1 is a global variable treated as a compile-time constant by numba. Is there a way to make the config variables also a compile-time constant?

I don’t think there is a way to propagate literals through dictionaries. There is no such type as a literal dict. Numba knows about literal lists, so maybe that could work. But honestly, even if there was such a solution, relying on branch pruning to make typing work sounds like a very fragile solution to me.


I agree that it's not a good idea. But for now, if it worked, it would speed up the code until we come up with a better, more efficient solution.

Could you use overload or generated_jit to get the effect you’re looking for based on the typing of the config variable(s)?


generated_jit has been deprecated, I believe.

Hey @Sumit112192,

I'm not sure whether it works for your purpose, but you can access the initial values of dictionaries as literals using "numba.extending.overload".

    from numba import njit, types, literally
    from numba.typed import Dict, List
    from numba.extending import overload

    @njit
    def get_config_1():
        return {'flagA': True, 'flagB': False}

    @njit
    def get_config_2():
        return {'flagA': False, 'flagB': True}

    # Define the physics function to be overloaded
    def physics(config):
        pass

    # Create the overload implementation
    @overload(physics)
    def physics_overload(config):
        iv = config.initial_value

        if iv is None:
            return lambda config: literally(config)  # Force literal dispatch

        # Check for specific initial values
        if iv == {'flagA': True, 'flagB': False}:
            def physics_impl_1(config):
                aNumbaList = List()
                for i in range(10):
                    aNumbaList.append(i)
                return aNumbaList

            return physics_impl_1

        elif iv == {'flagA': False, 'flagB': True}:
            def physics_impl_2(config):
                aNumbaList = List()
                for i in range(10):
                    aNumbaList.append(i / 10.0)
                return aNumbaList

            return physics_impl_2


    # Wrap the call to physics in an njit function
    @njit
    def call_physics(config):
        return physics(config)

    config = get_config_1()
    result = call_physics(config)
    print(result)
    config = get_config_2()
    result = call_physics(config)
    print(result)

    # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ...]
    # [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, ...]

Can you look at this question?

It is, and yet it's still widely used.


Deprecated in the sense that it has become obsolete: the latest version doesn't even include it.


Thanks for the heads-up, I wasn’t aware. I guess I missed the deletion alert in the release notes.

Interestingly, dispatcher.py still seems to contain its compiler.

EDIT: @DannyWeitekamp you might be interested in this if you haven't already picked up on it. I think at least some versions of CRE used generated_jit.

@nelson2005 Yeah, the generated_jit deprecation has been a big annoyance for me. I haven't had time to make the conversion, and I really don't like how crufty the @overload solution is: too many separate function definitions just to accomplish what should logically be one self-contained thing. I still don't fully appreciate the technical details of why generated_jit needed to go; it definitely served a purpose that, as far as I can tell, has not been fully replicated.

I’m almost certain that I’ve solved the issue of doing multiple dispatch on a grouped config object with literals…

@Sumit112192 The other approach that comes to mind is defining a structref that explicitly uses literal types for some of its members. That approach has some annoying boilerplate, but it would probably be a lot cleaner than one that uses @overload.


Thanks for the follow-up. I will look into it.