Structured array field index to field name

Is there a better way to do this? I’m trying to convert an index into a structured array’s fields to the corresponding field name.

My workaround is to explicitly pull the field names into a separate tuple, but I was hoping there was something more elegant.

from numba import njit, from_dtype
import numpy as np
x = np.array([('Rex', 9, 81.0), ('Fido', 3, 27.0)], dtype=[('name', 'U10'), ('age', 'i4'), ('weight', 'f4')])

numba_type = from_dtype(x.dtype)
print("x.dtype.names", x.dtype.names)  # print the field names using dtype
print("numba_type.members", [x[0] for x in numba_type.members]) # print the field names using Record

names = tuple(x.dtype.names)

@njit
def get_name(i):
    # this works in both plain python and jitted
    return names[i]

    # both of these methods work in plain python, neither works jitted
    # return x.dtype.names[i]
    # return numba_type.members[i][0]

print(get_name(1))

Hi @nelson2005,

I don’t think there’s a way to deal with something like x.dtype.names[i] in Numba yet, and certainly not numba_type.members[i][0]. You can, however, do something like this to resolve the name at compile time:

from numba import njit, from_dtype, literally, types
import numpy as np
x = np.array([('Rex', 9, 81.0), ('Fido', 3, 27.0)], dtype=[('name', 'U10'), ('age', 'i4'), ('weight', 'f4')])

numba_type = from_dtype(x.dtype)
print("x.dtype.names", x.dtype.names)  # print the field names using dtype
print("numba_type.members", [x[0] for x in numba_type.members]) # print the field names using Record

names = tuple(x.dtype.names)

from numba.extending import overload

# going to overload this function
def resolve(idx, dt):
    pass

@overload(resolve)
def ol_resolve(idx, dt):
    # this forces `idx` to be a literal value
    if not isinstance(idx, types.IntegerLiteral):
        return lambda idx, dt: literally(idx)
    else:
        # look up the field name for this literal index at compile time;
        # impl closes over it and just returns that constant
        index = idx.literal_value
        name = list(dt.dtype.fields)[index]
        def impl(idx, dt):
            return name
        return impl

@njit
def get_name(i, x_inst):
    a = names[i]
    b = resolve(i, x_inst.dtype)
    return a, b

print(get_name(1, x))

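This uses the numba.extending.overload API. The literally(idx) call makes Numba recompile get_name with i typed as an integer literal, so the index-to-name lookup happens at compile time.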

Hope this helps?

Hmm, I need to noodle over this one a little… Does this mean that i must be a literal in the code? Or could I write something like

[get_name(i, x) for i in range(10)]

It depends on what you are doing: if the indexing is into something heterogeneous in type, generally yes; if not, generally no. Here is another way to do the above, though it again relies on a baked-in tuple.

from numba import njit, from_dtype
import numpy as np
x = np.array([('Rex', 9, 81.0), ('Fido', 3, 27.0)], dtype=[('name', 'U10'), ('age', 'i4'), ('weight', 'f4')])

numba_type = from_dtype(x.dtype)
print("x.dtype.names", x.dtype.names)  # print the field names using dtype
print("numba_type.members", [x[0] for x in numba_type.members]) # print the field names using Record

names = tuple(x.dtype.names)

from numba.extending import overload

# going to overload this function
def resolve(idx, dt):
    pass

@overload(resolve)
def ol_resolve(idx, dt):
    names = tuple(dt.dtype.fields)  # computed at compile time and baked into impl as a constant
    def impl(idx, dt):
        return names[idx]
    return impl

@njit
def foo():
    print([resolve(i, x) for i in range(3)])

foo()