How to get the Numba type of a StructRef for a function signature

I have a Numba StructRef and a function that takes it as an argument, and I need to write that function's explicit signature.

I previously used a jitclass for the Position class in my chess engine. I'm now switching to a StructRef because jitclasses take too long to JIT compile.

@structref.register
class PositionStructType(types.StructRef):
    """Numba type backing ``PositionStruct`` instances.

    Field types are passed through ``types.unliteral`` so that, for example,
    a literal ``0`` argument yields a plain integer field type rather than a
    literal type.
    """

    def preprocess_fields(self, fields):
        # Rebuild each (name, type) pair with the type's non-literal form.
        unliteraled = [
            (field_name, types.unliteral(field_type))
            for field_name, field_type in fields
        ]
        return tuple(unliteraled)


class PositionStruct(structref.StructRefProxy):
    """Python-side proxy for a ``PositionStructType`` structref.

    The constructor forwards all eight field values to the structref
    constructor. Each read-only property delegates to the matching
    ``PositionStruct_get_*`` jitted accessor.
    """

    def __new__(cls, board,
                white_pieces,
                black_pieces,
                king_positions,
                castle_ability_bits,
                ep_square, side, hash_key):
        # Same order as the field list passed to structref.define_proxy.
        field_values = (board, white_pieces, black_pieces, king_positions,
                        castle_ability_bits, ep_square, side, hash_key)
        return structref.StructRefProxy.__new__(cls, *field_values)

    @property
    def board(self):
        """The ``board`` field, read through its jitted accessor."""
        return PositionStruct_get_board(self)

    @property
    def white_pieces(self):
        """The ``white_pieces`` field, read through its jitted accessor."""
        return PositionStruct_get_white_pieces(self)

    @property
    def black_pieces(self):
        """The ``black_pieces`` field, read through its jitted accessor."""
        return PositionStruct_get_black_pieces(self)

    @property
    def king_positions(self):
        """The ``king_positions`` field, read through its jitted accessor."""
        return PositionStruct_get_king_positions(self)

    @property
    def castle_ability_bits(self):
        """The ``castle_ability_bits`` field, read through its jitted accessor."""
        return PositionStruct_get_castle_ability_bits(self)

    @property
    def ep_square(self):
        """The ``ep_square`` field, read through its jitted accessor."""
        return PositionStruct_get_ep_square(self)

    @property
    def side(self):
        """The ``side`` field, read through its jitted accessor."""
        return PositionStruct_get_side(self)

    @property
    def hash_key(self):
        """The ``hash_key`` field, read through its jitted accessor."""
        return PositionStruct_get_hash_key(self)


@njit
def PositionStruct_get_board(self):
    """Jitted accessor: return the ``board`` field of a position structref."""
    board = self.board
    return board


@njit
def PositionStruct_get_white_pieces(self):
    """Jitted accessor: return the ``white_pieces`` field of a position structref."""
    white_pieces = self.white_pieces
    return white_pieces


@njit
def PositionStruct_get_black_pieces(self):
    """Jitted accessor: return the ``black_pieces`` field of a position structref."""
    black_pieces = self.black_pieces
    return black_pieces


@njit
def PositionStruct_get_king_positions(self):
    """Jitted accessor: return the ``king_positions`` field of a position structref."""
    king_positions = self.king_positions
    return king_positions


@njit
def PositionStruct_get_castle_ability_bits(self):
    """Jitted accessor: return the ``castle_ability_bits`` field of a position structref."""
    castle_bits = self.castle_ability_bits
    return castle_bits


@njit
def PositionStruct_get_ep_square(self):
    """Jitted accessor: return the ``ep_square`` field of a position structref."""
    ep_square = self.ep_square
    return ep_square


@njit
def PositionStruct_get_side(self):
    """Jitted accessor: return the ``side`` field of a position structref."""
    side = self.side
    return side


@njit
def PositionStruct_get_hash_key(self):
    """Jitted accessor: return the ``hash_key`` field of a position structref."""
    hash_key = self.hash_key
    return hash_key


# Bind the PositionStruct proxy class to PositionStructType and expose these
# fields. The order matches the parameters of PositionStruct.__new__.
structref.define_proxy(
    PositionStruct,
    PositionStructType,
    [
        "board",
        "white_pieces",
        "black_pieces",
        "king_positions",
        "castle_ability_bits",
        "ep_square",
        "side",
        "hash_key",
    ],
)


@njit
def init_position():
    """Build a zero-initialised ``PositionStruct``.

    The fields are set as follows:

    * ``board``: 120 int8 zeros.
    * ``white_pieces`` and ``black_pieces``: separate empty lists.
    * ``king_positions``: two uint8 zeros.
    * ``castle_ability_bits``, ``ep_square``, ``side`` and ``hash_key``: 0.
    """
    # A comprehension over an empty range gives an empty list whose element
    # type (int64) Numba can still infer from the expression.
    empty_white = [nb.int64(1) for _ in range(0)]
    empty_black = [nb.int64(1) for _ in range(0)]
    return PositionStruct(
        np.zeros(120, dtype=np.int8),
        empty_white,
        empty_black,
        np.zeros(2, dtype=np.uint8),
        0, 0, 0, 0,
    )

Above is my StructRef. I also have a function that takes the position as an argument and needs to return an unsigned 64-bit integer.

@nb.njit(nb.uint64(Position.class_type.instance_type), cache=True)

I used this as the function signature with the jitclass, but it no longer works with the StructRef, because `Position.class_type` only exists on a jitclass.

Can you leave the signature definition out of the njit call?

I want the function to return a uint64, and in the future I might try ahead-of-time compilation. I could leave the signature out, but I would prefer to keep it.

Can you post a minimal complete running example without the signature?

This is the function; it returns a uint64 hash key. Without the explicit signature it does not produce the correct result.

# No explicit signature: ``Position.class_type.instance_type`` belonged to the
# old jitclass and no longer exists. Numba infers the PositionStructType
# argument type on the first call.
# NOTE(review): for an eager/AOT signature, the structref type can presumably
# be taken from an instance, e.g. ``nb.typeof(init_position())`` - TODO confirm.
@nb.njit(cache=True)
def compute_hash(position):
    """Compute the Zobrist-style hash key of ``position`` from scratch.

    The hash XORs together:

    * the piece key for every square whose board code is not greater than
      ``BLACK_KING`` (presumably the occupied squares);
    * the en-passant key, when ``ep_square`` is non-zero;
    * the castling-rights key;
    * the side key, when ``side`` is 1 (black to move).

    Returns:
        The hash as a uint64.
    """
    # Seed the accumulator as uint64. A bare ``0`` is typed int64, and Numba
    # does not type int64 ^ uint64 as uint64. That is why the result was wrong
    # once the uint64 return signature was removed.
    code = nb.uint64(0)

    for square in range(64):
        pos = STANDARD_TO_MAILBOX[square]
        piece = position.board[pos]
        if piece > BLACK_KING:
            continue
        code ^= PIECE_HASH_KEYS[piece, square]

    if position.ep_square:
        code ^= EP_HASH_KEYS[MAILBOX_TO_STANDARD[position.ep_square]]

    code ^= CASTLE_HASH_KEYS[position.castle_ability_bits]

    if position.side:  # side 1 is black, 0 is white
        code ^= SIDE_HASH_KEY

    return code

These are the constants used in the code:

# Zobrist hash keys, each drawn uniformly from [1, 2**64 - 2].
# A fixed seed makes the keys identical on every run. compute_hash is compiled
# with cache=True, and Numba freezes global arrays into the compiled code, so
# unseeded keys would differ between the cached function and keys generated
# fresh in a later run.
_HASH_KEY_RNG = np.random.default_rng(0x5EED)

PIECE_HASH_KEYS = _HASH_KEY_RNG.integers(1, 2**64 - 1, size=(12, 64), dtype=np.uint64)
EP_HASH_KEYS = _HASH_KEY_RNG.integers(1, 2**64 - 1, size=64, dtype=np.uint64)
CASTLE_HASH_KEYS = _HASH_KEY_RNG.integers(1, 2**64 - 1, size=16, dtype=np.uint64)
SIDE_HASH_KEY = np.uint64(_HASH_KEY_RNG.integers(1, 2**64 - 1, dtype=np.uint64))

I tested it again, and without the signature it does work. However, if I want to compile it ahead of time in the future, I would need the signature, and I'm not sure whether it is possible to get the type of a StructRef.

Can you call it once without the signature and then print the signature of the jitted function?

How do I print the signature of the function?