Get type of structref for function signature

I have a structref that a function takes as an argument, and I need to write that function's signature.

I used to use a jitclass for this Position class for my chess engine, but now I’m trying to use a structref instead because jitclasses take too long to JIT compile.

@structref.register
class PositionStructType(types.StructRef):
    """Numba type that backs ``PositionStruct`` instances."""

    def preprocess_fields(self, fields):
        # Replace literal types (e.g. from a constant ``0`` passed at
        # construction) with their general type so the field type does not
        # depend on the particular constant used.
        cleaned = [(field_name, types.unliteral(field_type))
                   for field_name, field_type in fields]
        return tuple(cleaned)


class PositionStruct(structref.StructRefProxy):
    """Python-side proxy for a chess position stored as a Numba structref.

    Each read-only property reads its field through the matching ``@njit``
    accessor function (``PositionStruct_get_<field>``).
    """

    def __new__(cls, board,
                white_pieces,
                black_pieces,
                king_positions,
                castle_ability_bits,
                ep_square, side, hash_key):
        # Fields are forwarded positionally, in the order registered with
        # ``structref.define_proxy``.
        field_values = (board, white_pieces, black_pieces, king_positions,
                        castle_ability_bits, ep_square, side, hash_key)
        return super().__new__(cls, *field_values)

    board = property(lambda self: PositionStruct_get_board(self))
    white_pieces = property(lambda self: PositionStruct_get_white_pieces(self))
    black_pieces = property(lambda self: PositionStruct_get_black_pieces(self))
    king_positions = property(lambda self: PositionStruct_get_king_positions(self))
    castle_ability_bits = property(
        lambda self: PositionStruct_get_castle_ability_bits(self))
    ep_square = property(lambda self: PositionStruct_get_ep_square(self))
    side = property(lambda self: PositionStruct_get_side(self))
    hash_key = property(lambda self: PositionStruct_get_hash_key(self))


@njit
def PositionStruct_get_board(self):
    """Jitted accessor: return the ``board`` field of a PositionStruct."""
    value = self.board
    return value


@njit
def PositionStruct_get_white_pieces(self):
    """Jitted accessor: return the ``white_pieces`` field of a PositionStruct."""
    value = self.white_pieces
    return value


@njit
def PositionStruct_get_black_pieces(self):
    """Jitted accessor: return the ``black_pieces`` field of a PositionStruct."""
    value = self.black_pieces
    return value


@njit
def PositionStruct_get_king_positions(self):
    """Jitted accessor: return the ``king_positions`` field of a PositionStruct."""
    value = self.king_positions
    return value


@njit
def PositionStruct_get_castle_ability_bits(self):
    """Jitted accessor: return the ``castle_ability_bits`` field of a PositionStruct."""
    value = self.castle_ability_bits
    return value


@njit
def PositionStruct_get_ep_square(self):
    """Jitted accessor: return the ``ep_square`` field of a PositionStruct."""
    value = self.ep_square
    return value


@njit
def PositionStruct_get_side(self):
    """Jitted accessor: return the ``side`` field of a PositionStruct."""
    value = self.side
    return value


@njit
def PositionStruct_get_hash_key(self):
    """Jitted accessor: return the ``hash_key`` field of a PositionStruct."""
    value = self.hash_key
    return value


# Bind the proxy class to its Numba type and declare the field names. The
# order here is the order in which PositionStruct.__new__ passes its
# positional arguments.
structref.define_proxy(
    PositionStruct,
    PositionStructType,
    [
        "board",
        "white_pieces",
        "black_pieces",
        "king_positions",
        "castle_ability_bits",
        "ep_square",
        "side",
        "hash_key",
    ],
)


@njit
def init_position():
    """Create a PositionStruct with a zeroed board, empty piece lists and zeroed scalars.

    The piece lists come from an empty comprehension over ``nb.int64(1)`` so
    Numba can infer an int64 element type even though the lists start empty.

    ``hash_key`` is created as ``np.uint64(0)`` rather than ``0``. A literal
    ``0`` is unliteral-ed to int64, which would make the field signed while
    the hash keys (and the hash computed from them) are uint64.
    """
    white_pieces = [nb.int64(1) for _ in range(0)]
    black_pieces = [nb.int64(1) for _ in range(0)]
    position = PositionStruct(np.zeros(120, dtype=np.int8),
                              white_pieces, black_pieces,
                              np.zeros(2, dtype=np.uint8), 0, 0, 0,
                              np.uint64(0))

    return position

Above is my structref. I also have a function that takes the position as an argument and must return an unsigned 64-bit integer:

@nb.njit(nb.uint64(Position.class_type.instance_type), cache=True)

I used this signature when Position was a jitclass, but with the structref it no longer works.

Can you leave the signature definition out of the njit call?

I want to return a uint64, and in the future I might try ahead-of-time compilation. I could leave the signature out, but I would prefer to keep it.

Can you post a minimal complete running example without the signature?

This is the function, which returns a uint64 hash key. Without the function signature it doesn’t run correctly.

@nb.njit(nb.uint64(Position.class_type.instance_type), cache=True)
def compute_hash(position):
    """Compute the Zobrist hash key of ``position`` as an unsigned 64-bit integer.

    The function XORs together:
    - the key of each occupied square's piece (squares whose value exceeds
      ``BLACK_KING`` are skipped),
    - the en-passant key when ``ep_square`` is non-zero,
    - the castling-rights key,
    - the side-to-move key when ``side`` is 1 (black).
    """
    # Start from an unsigned zero. With a plain ``0`` the accumulator is
    # int64, and Numba types ``int64 ^ uint64`` as a signed int64, so without
    # an explicit uint64 signature the result comes back as a (possibly
    # negative) int64. A uint64 start keeps every XOR in uint64.
    code = nb.uint64(0)
    board = position.board

    for i in range(64):
        piece = board[STANDARD_TO_MAILBOX[i]]
        if piece > BLACK_KING:
            continue

        code ^= PIECE_HASH_KEYS[piece, i]

    if position.ep_square:
        code ^= EP_HASH_KEYS[MAILBOX_TO_STANDARD[position.ep_square]]

    code ^= CASTLE_HASH_KEYS[position.castle_ability_bits]

    if position.side:  # side 1 is black, 0 is white
        code ^= SIDE_HASH_KEY

    return code

These are the constants used in the code:

# Seed for the Zobrist keys. A fixed seed keeps hash values reproducible
# between runs. It also matters for Numba: global arrays are frozen into the
# compiled code, so with ``cache=True`` a cached jitted function would keep the
# keys from the run that compiled it, while Python code in a later run would
# see freshly drawn, different keys.
HASH_KEY_SEED = 0x5EED
_hash_rng = np.random.default_rng(HASH_KEY_SEED)

# Keys are drawn from [1, 2**64 - 2] (the upper bound is exclusive).
PIECE_HASH_KEYS = _hash_rng.integers(1, 2**64 - 1, size=(12, 64), dtype=np.uint64)
EP_HASH_KEYS = _hash_rng.integers(1, 2**64 - 1, size=64, dtype=np.uint64)
CASTLE_HASH_KEYS = _hash_rng.integers(1, 2**64 - 1, size=16, dtype=np.uint64)
SIDE_HASH_KEY = np.uint64(_hash_rng.integers(1, 2**64 - 1, dtype=np.uint64))

I tested it: without the function signature it does work. But if I want to compile it ahead of time in the future, I will need the signature, and I'm not sure whether it's possible to get the type of a structref.

Can you call it once without the signature and then print the signature of the jitted function?

How do I print the signature of the function?

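For reference, here is how to get the type of a structref for use in a signature: first create an instance of it, then inspect its type with typeof(). A structref's Numba type is an instance of the types.StructRef subclass registered with @structref.register, parameterized by its field types. Therefore typeof() on an instance returns the complete type you can use in a signature.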

In your example, before defining the init_position() function, you need:

from numba import typeof, njit, types


@njit
def _new_position():
    """Build an empty PositionStruct inside compiled code.

    The empty piece lists must be created in jitted code: an empty Python
    list passed from the interpreter has no element type Numba can infer.
    """
    return PositionStruct(np.zeros(120, dtype=np.int8),
                          [nb.int64(1) for _ in range(0)],
                          [nb.int64(1) for _ in range(0)],
                          np.zeros(2, dtype=np.uint8), 0, 0, 0, np.uint64(0))


# Create one real instance and ask Numba for its concrete structref type.
# This is the type to use for PositionStruct arguments and returns in
# explicit signatures.
_position_type = typeof(_new_position())


# ``_position_type()`` is a signature with this return type and no arguments.
# (``_position_type(types.void)`` would declare one argument of type void.)
@njit(_position_type())
def init_position():
    """Return an empty PositionStruct with an explicitly typed signature."""
    return _new_position()

Below is a complete example from my own code, where I needed to do the same thing you need to do with your Position struct.

Steps to obtain the type signature:

  1. Create an instance of NeuronRandNetParamtruct.
  2. Use numba.typeof() to obtain the corresponding Numba type.
  3. Use it in your function signature.

Solution:

1. Obtain the type signature

from numba import typeof, njit, types
from numba.experimental import structref

# Defining the structref type. preprocess_fields replaces literal field types
# with their general types so that the struct type does not depend on the
# particular constants used at construction.
@structref.register
class NeuronRandNetParamtructType(types.StructRef):
    def preprocess_fields(self, fields):
        return tuple((name, types.unliteral(typ)) for name, typ in fields)

class NeuronRandNetParamtruct(structref.StructRefProxy):
    def __new__(cls,Iext,mu,theta,Gamma,I,P_poisson,tauTinv,uT):
        # Overriding the __new__ method is optional, doing so
        # allows Python code to use keyword arguments,
        # or add other customized behavior.
        # The default __new__ takes `*args`.
        # IMPORTANT: Users should not override __init__.
        return structref.StructRefProxy.__new__(cls,Iext,mu,theta,Gamma,I,P_poisson,tauTinv,uT)

structref.define_proxy(NeuronRandNetParamtruct, NeuronRandNetParamtructType, ['Iext','mu','theta','Gamma','I','P_poisson','tauTinv','uT'])

# Create an instance of NeuronRandNetParamtruct (all fields are float values).
par_instance = NeuronRandNetParamtruct(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)

# Get its concrete Numba type, usable as an argument type in a signature.
_neuronRandNetParamStruct_signature = typeof(par_instance)

2. Use the obtained type in function signature

# A Tuple of four identical types is the same Numba type as UniTuple(float64, 4).
@njit(types.UniTuple(types.float64, 4)(
    types.float64, types.float64, types.float64, _neuronRandNetParamStruct_signature
))
def GLNetEIRand_adaptthresh_iter(V, X, synapticInput, par):
    """Advance the state by one iteration; ``par.theta`` is updated in place.

    Returns ``(V, X, par.theta, par.P_poisson)`` after the update.
    """
    # Threshold is rescaled by (1 - tauTinv + uT * X) and stored back on par.
    par.theta = par.theta * (1.0 - par.tauTinv + par.uT * X)

    total_input = par.mu * V + par.I + par.Iext + synapticInput
    V = total_input * (1.0 - X)

    fire_prob = PHI(V, par.theta, par.Gamma) * (1.0 - par.P_poisson) + par.P_poisson
    X = bool2float(random.random() < fire_prob)

    return V, X, par.theta, par.P_poisson

Now _neuronRandNetParamStruct_signature holds the Numba type of NeuronRandNetParamtruct, so you can use it as an argument type in the function signature.

“In most cases you don’t have to specify types. Numba will determine them the first time you call the function. If you are interested in the types you can call my_funct.nopython_signatures to see what was detected. –
max9111
CommentedFeb 15, 2021 at 12:22”
python 3.x - how to define multiple signatures for a function in numba - Stack Overflow