I have a structref that a function takes as an argument, and I need to write that function's explicit signature.
I previously used a jitclass for the Position class in my chess engine, but I'm switching to a structref because jitclasses take too long to JIT-compile.
@structref.register
class PositionStructType(types.StructRef):
    """Numba type describing the native layout of a PositionStruct.

    Instances of this class (one per distinct set of field types) are what
    Numba uses as the type of a PositionStruct value inside compiled code.
    """

    def preprocess_fields(self, fields):
        """Strip literal-ness from every field type.

        Without this, a field initialised from a constant (e.g. ``0``) would be
        typed as that specific literal instead of the general scalar type.

        :param fields: iterable of ``(name, numba_type)`` pairs.
        :returns: tuple of ``(name, unliteral_numba_type)`` pairs, same order.
        """
        cleaned = []
        for field_name, field_type in fields:
            cleaned.append((field_name, types.unliteral(field_type)))
        return tuple(cleaned)
class PositionStruct(structref.StructRefProxy):
    """Python-side proxy for a chess position stored as a Numba structref.

    Construction forwards the eight field values, in declaration order, to
    ``StructRefProxy.__new__``. Every attribute below is read-only and
    delegates to a module-level ``@njit`` getter function.
    """

    def __new__(cls, board,
                white_pieces,
                black_pieces,
                king_positions,
                castle_ability_bits,
                ep_square, side, hash_key):
        # Order must match the field-name list passed to define_proxy.
        field_values = (
            board,
            white_pieces,
            black_pieces,
            king_positions,
            castle_ability_bits,
            ep_square,
            side,
            hash_key,
        )
        return structref.StructRefProxy.__new__(cls, *field_values)

    # The lambdas look up the getter names when the attribute is read, not
    # when the class is created, so the getters may be defined below this
    # class.
    board = property(
        lambda self: PositionStruct_get_board(self),
        doc="The ``board`` field.")
    white_pieces = property(
        lambda self: PositionStruct_get_white_pieces(self),
        doc="The ``white_pieces`` field.")
    black_pieces = property(
        lambda self: PositionStruct_get_black_pieces(self),
        doc="The ``black_pieces`` field.")
    king_positions = property(
        lambda self: PositionStruct_get_king_positions(self),
        doc="The ``king_positions`` field.")
    castle_ability_bits = property(
        lambda self: PositionStruct_get_castle_ability_bits(self),
        doc="The ``castle_ability_bits`` field.")
    ep_square = property(
        lambda self: PositionStruct_get_ep_square(self),
        doc="The ``ep_square`` field.")
    side = property(
        lambda self: PositionStruct_get_side(self),
        doc="The ``side`` field.")
    hash_key = property(
        lambda self: PositionStruct_get_hash_key(self),
        doc="The ``hash_key`` field.")
@njit
def PositionStruct_get_board(self):
    """Compiled accessor: return the ``board`` field of a PositionStruct."""
    value = self.board
    return value
@njit
def PositionStruct_get_white_pieces(self):
    """Compiled accessor: return the ``white_pieces`` field of a PositionStruct."""
    value = self.white_pieces
    return value
@njit
def PositionStruct_get_black_pieces(self):
    """Compiled accessor: return the ``black_pieces`` field of a PositionStruct."""
    value = self.black_pieces
    return value
@njit
def PositionStruct_get_king_positions(self):
    """Compiled accessor: return the ``king_positions`` field of a PositionStruct."""
    value = self.king_positions
    return value
@njit
def PositionStruct_get_castle_ability_bits(self):
    """Compiled accessor: return the ``castle_ability_bits`` field of a PositionStruct."""
    value = self.castle_ability_bits
    return value
@njit
def PositionStruct_get_ep_square(self):
    """Compiled accessor: return the ``ep_square`` field of a PositionStruct."""
    value = self.ep_square
    return value
@njit
def PositionStruct_get_side(self):
    """Compiled accessor: return the ``side`` field of a PositionStruct."""
    value = self.side
    return value
@njit
def PositionStruct_get_hash_key(self):
    """Compiled accessor: return the ``hash_key`` field of a PositionStruct."""
    value = self.hash_key
    return value
# Single source of truth for the PositionStruct layout: field names in
# constructor order, paired with the Numba types that init_position() builds.
# NOTE(review): these types are inferred from init_position(); if positions
# are built any other way (e.g. non-contiguous arrays, numba.typed.List, a
# uint64 hash), update this table to match, or explicit signatures that use
# position_struct_type will not match the argument.
POSITION_FIELDS = [
    # np.zeros(120, dtype=np.int8) inside njit -> 1-D C-contiguous int8 array
    ("board", types.int8[::1]),
    # list comprehension inside njit -> non-reflected list of int64
    ("white_pieces", types.List(types.int64, reflected=False)),
    ("black_pieces", types.List(types.int64, reflected=False)),
    # np.zeros(2, dtype=np.uint8) inside njit -> 1-D C-contiguous uint8 array
    ("king_positions", types.uint8[::1]),
    # Integer literal 0, unliteral'd by preprocess_fields; assumed to become
    # int64 (Numba's default integer type) -- confirm on your platform.
    ("castle_ability_bits", types.int64),
    ("ep_square", types.int64),
    ("side", types.int64),
    ("hash_key", types.int64),
]

# Concrete Numba type of a PositionStruct. Structrefs have no
# ``class_type.instance_type`` like jitclasses do; use this object in explicit
# signatures instead, e.g.:
#     @nb.njit(nb.uint64(position_struct_type), cache=True)
position_struct_type = PositionStructType(POSITION_FIELDS)

# Register PositionStruct as the Python proxy for PositionStructType. This
# also makes PositionStruct(...) callable from jitted code, with positional
# arguments taken in this field order.
structref.define_proxy(
    PositionStruct,
    PositionStructType,
    [field_name for field_name, _ in POSITION_FIELDS],
)
@njit
def init_position():
    """Build a blank PositionStruct from compiled code.

    The board is a zeroed int8 array of length 120, both piece lists are
    empty int64 lists, king_positions is a zeroed uint8 array of length 2,
    and every scalar field (castle bits, en-passant square, side, hash key)
    is 0.
    """
    empty_board = np.zeros(120, dtype=np.int8)
    empty_kings = np.zeros(2, dtype=np.uint8)
    # Iterating over an empty range yields an empty list whose element type
    # Numba can still infer as int64 from the nb.int64(1) expression.
    white_list = [nb.int64(1) for _ in range(0)]
    black_list = [nb.int64(1) for _ in range(0)]
    return PositionStruct(empty_board, white_list, black_list,
                          empty_kings, 0, 0, 0, 0)
Above is my structref. I have a function that takes the position as an argument and needs to return an unsigned 64-bit integer.
@nb.njit(nb.uint64(Position.class_type.instance_type), cache=True)
I used this signature with the jitclass, but it no longer works now that the position is a structref.