Hello,
I was hoping to use scipy.linalg.ordqz inside a numba njit function. That function calls two LAPACK routines, and I found an issue on GitHub (Issue 5301; sorry, I'm not allowed to include links) showing how to do this, so I was initially optimistic. The first routine I tried to implement was zgges, the complex generalized Schur (QZ) decomposition. I'm on Windows, with numba 0.55.1 and Python 3.10.4. The Windows part turns out to be relevant, quelle surprise!
Anyway, here is my attempt to implement the function:
from numba import njit
import numpy as np
from numba.extending import get_cython_function_address
import ctypes
# Datatype pointers to give to the cython LAPACK functions
_PTR = ctypes.POINTER
_dbl = ctypes.c_double
_int = ctypes.c_int
_ptr_dbl = _PTR(_dbl)
_ptr_int = _PTR(_int)
# zgges is the complex QZ-decomposition
zgges_addr = get_cython_function_address('scipy.linalg.cython_lapack', 'zgges')
zgges_functype = ctypes.CFUNCTYPE(None,
                                  _ptr_int,  # JOBVSL
                                  _ptr_int,  # JOBVSR
                                  _ptr_int,  # SORT
                                  _ptr_int,  # SELCTG
                                  _ptr_int,  # N
                                  _ptr_dbl,  # A, complex
                                  _ptr_int,  # LDA
                                  _ptr_dbl,  # B, complex
                                  _ptr_int,  # LDB
                                  _ptr_int,  # SDIM
                                  _ptr_dbl,  # ALPHA, complex
                                  _ptr_dbl,  # BETA, complex
                                  _ptr_dbl,  # VSL, complex
                                  _ptr_int,  # LDVSL
                                  _ptr_dbl,  # VSR, complex
                                  _ptr_int,  # LDVSR
                                  _ptr_dbl,  # WORK, complex
                                  _ptr_int,  # LWORK
                                  _ptr_dbl,  # RWORK
                                  _ptr_int,  # BWORK
                                  _ptr_int)  # INFO
zgges_fn = zgges_functype(zgges_addr)
@njit
def numba_zgges(x, y):
    _M, _N = x.shape
    # zgges overwrites A and B in place with the generalized Schur forms
    A = x
    B = y
    JOBVSL = np.array([ord('V')], dtype=np.int32)
    JOBVSR = np.array([ord('V')], dtype=np.int32)
    SORT = np.array([ord('N')], dtype=np.int32)
    SELCTG = np.empty(1, dtype=np.int32)
    N = np.array(_N, dtype=np.int32)
    LDA = np.array(_N, dtype=np.int32)
    LDB = np.array(_N, dtype=np.int32)
    SDIM = np.array(0, dtype=np.int32)  # out
    ALPHA = np.empty(_N, dtype=np.complex128)  # out
    BETA = np.empty(_N, dtype=np.complex128)  # out
    LDVSL = np.array(_N, dtype=np.int32)
    VSL = np.empty((_N, _N), dtype=np.complex128)  # out
    LDVSR = np.array(_N, dtype=np.int32)
    VSR = np.empty((_N, _N), dtype=np.complex128)  # out
    WORK = np.empty((1,), dtype=np.complex128)  # out
    LWORK = np.array(-1, dtype=np.int32)
    RWORK = np.empty(_N, dtype=np.float64)
    BWORK = np.empty(_N, dtype=np.int32)
    INFO = np.empty(1, dtype=np.int32)
    # First call: LWORK = -1 is a workspace query; the optimal size is returned in WORK[0]
    zgges_fn(JOBVSL.ctypes,
             JOBVSR.ctypes,
             SORT.ctypes,
             SELCTG.ctypes,
             N.ctypes,
             A.view(np.float64).ctypes,
             LDA.ctypes,
             B.view(np.float64).ctypes,
             LDB.ctypes,
             SDIM.ctypes,
             ALPHA.view(np.float64).ctypes,
             BETA.view(np.float64).ctypes,
             VSL.view(np.float64).ctypes,
             LDVSL.ctypes,
             VSR.view(np.float64).ctypes,
             LDVSR.ctypes,
             WORK.view(np.float64).ctypes,
             LWORK.ctypes,
             RWORK.ctypes,
             BWORK.ctypes,
             INFO.ctypes)
    print("Calculated workspace size as", WORK[0])
    WS_SIZE = np.int32(WORK[0].real)
    LWORK = np.array(WS_SIZE, np.int32)
    WORK = np.empty(WS_SIZE, dtype=np.complex128)
    # Second call: the actual decomposition
    zgges_fn(JOBVSL.ctypes,
             JOBVSR.ctypes,
             SORT.ctypes,
             SELCTG.ctypes,
             N.ctypes,
             A.view(np.float64).ctypes,
             LDA.ctypes,
             B.view(np.float64).ctypes,
             LDB.ctypes,
             SDIM.ctypes,
             ALPHA.view(np.float64).ctypes,
             BETA.view(np.float64).ctypes,
             VSL.view(np.float64).ctypes,
             LDVSL.ctypes,
             VSR.view(np.float64).ctypes,
             LDVSR.ctypes,
             WORK.view(np.float64).ctypes,
             LWORK.ctypes,
             RWORK.ctypes,
             BWORK.ctypes,
             INFO.ctypes)
    # The LAPACK function also returns SDIM, WORK, BWORK, but I don't need them here.
    # VSL and VSR were allocated in C order but filled column-major, hence the .T
    return A, B, ALPHA, BETA, VSL.T, VSR.T, INFO
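In LAPACK, JOBVSL, JOBVSR and SORT are single-character arguments, but the code above passes pointers to int32 arrays holding ord('V') or ord('N'). The following minimal sketch uses only the stdlib and shows why that works on x86-64 Windows: on a little-endian machine, the first byte of the int32 is the character itself. On a big-endian machine it would be a zero byte. `first_byte_as_char` is a hypothetical helper for illustration only.

```python
import ctypes
import sys


def first_byte_as_char(code: int) -> bytes:
    """Store `code` in a 32-bit int, then read back only its first byte as a char.

    This mimics passing np.array([ord('V')], dtype=np.int32) where a char* is
    expected. first_byte_as_char is a hypothetical helper, not part of any library.
    """
    value = ctypes.c_int32(code)
    # Reinterpret the address of the int32 as a char* and read a single byte.
    as_char = ctypes.cast(ctypes.pointer(value), ctypes.POINTER(ctypes.c_char))
    return as_char[0]


# Prints "little b'V'" on a little-endian machine such as x86-64.
print(sys.byteorder, first_byte_as_char(ord("V")))
```

So the int32 trick is harmless on Windows/x86-64, though it depends on byte order.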
This function actually works great on small matrices. Here I reproduce the relevant test case from SciPy:
n = 5
A = np.random.random([n, n])
B = np.random.random([n, n])
AA, BB, a, b, Q, Z, info = numba_zgges(np.asfortranarray(A, dtype='D'),
                                       np.asfortranarray(B, dtype='D'))
aa = Q @ AA @ Z.conj().T
assert np.allclose(aa.real, A)
assert np.allclose(aa.imag, 0)
bb = Q @ BB @ Z.conj().T
assert np.allclose(bb.real, B)
assert np.allclose(bb.imag, 0)
assert np.allclose(Q @ Q.conj().T, np.eye(n))
assert np.allclose(Z @ Z.conj().T, np.eye(n))
assert np.all(np.diag(BB) >= 0)
When n <= 5, all the tests pass, but for larger matrices I get a heap corruption error (code 0xc0000374). I also get this error when I run the test repeatedly in a loop, or when I try to pass the results of the function on to another jitted function.
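To run the SciPy test repeatedly or over a range of sizes, the assertions can be wrapped in a function that reports which check fails, instead of stopping at the first assert. This is a sketch assuming only NumPy. `check_qz` is a hypothetical name, and it checks the real part of diag(BB) rather than comparing complex values directly.

```python
import numpy as np


def check_qz(A, B, AA, BB, Q, Z):
    """Return named pass/fail results mirroring the SciPy zgges test.

    A, B are the original real inputs; AA, BB, Q, Z are the outputs of a
    complex QZ decomposition. check_qz is a hypothetical helper.
    """
    n = A.shape[0]
    aa = Q @ AA @ Z.conj().T  # should reconstruct A
    bb = Q @ BB @ Z.conj().T  # should reconstruct B
    return {
        "A reconstructed": np.allclose(aa.real, A) and np.allclose(aa.imag, 0),
        "B reconstructed": np.allclose(bb.real, B) and np.allclose(bb.imag, 0),
        "Q unitary": np.allclose(Q @ Q.conj().T, np.eye(n)),
        "Z unitary": np.allclose(Z @ Z.conj().T, np.eye(n)),
        "diag(BB) non-negative": np.all(np.diag(BB).real >= 0),
    }
```

For example, call `numba_zgges(np.asfortranarray(A, dtype='D'), np.asfortranarray(B, dtype='D'))` for each n in `range(1, 20)`. Then print `[k for k, ok in check_qz(A, B, AA, BB, Q, Z).items() if not ok]` together with `info`. This may show whether the results are already wrong before the process crashes. Because `numba_zgges` overwrites its inputs, it needs copies of A and B. `np.asfortranarray(..., dtype='D')` already makes one for real input.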
I did find another Stack Overflow question (again, apologies that I can't add links; here is the partial address: 64957662/cpython-memory-heap-corruption-issue) that seemed similar, involving heap corruption when repeatedly calling a C function. In that case it was evidently related to de-referencing, a word which I repeat here as if I had any idea what it means.
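For what it's worth, dereferencing a pointer means following it to the memory it refers to and reading or writing there. A function called through ctypes, like `zgges_fn` above, receives only addresses, not array lengths. The minimal stdlib-only sketch below illustrates this; the names `buf` and `ptr` are illustrative.

```python
import ctypes

# A C-style buffer of three doubles (a stand-in for a NumPy array's data).
buf = (ctypes.c_double * 3)(1.0, 2.0, 3.0)

# A raw pointer to its first element, similar to the `.ctypes` pointers passed to zgges.
ptr = ctypes.cast(buf, ctypes.POINTER(ctypes.c_double))

# Dereferencing: read and write the memory the pointer refers to.
first = ptr[0]  # reads buf[0]
ptr[2] = 30.0   # writes buf[2]; the buffer itself changes

print(first, list(buf))  # 1.0 [1.0, 2.0, 30.0]

# Nothing checks the index: ptr[3] would touch memory outside `buf`.
# A LAPACK routine given a too-small array is in the same position, and an
# out-of-bounds write may only surface later, e.g. as heap corruption.
```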
On the other hand, the example linked on the Numba GitHub that implements zgees does happily loop for arbitrarily large matrices (I tested this as well, just to be sure it wasn't a more general problem). In addition, there is a note in the SciPy code about zgges being janky on Windows. So might it be something idiosyncratic to this function?
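One thing worth checking is buffer sizes. The reference LAPACK documentation for ZGGES gives RWORK a dimension of 8*N, while `numba_zgges` allocates only N doubles. ZGEES, by contrast, documents RWORK with dimension N, so an allocation copied from a zgees example would fit there but not here. A write past the end of RWORK would be consistent with heap corruption that only shows up for larger n or after repeated calls, but that is an assumption, not something verified here. The sketch below compares the post's allocations against the documented minimums for JOBVSL = JOBVSR = 'V' and SORT = 'N'. `zgges_min_sizes`, `undersized` and `post_allocations` are hypothetical helpers.

```python
# Documented minimum lengths (in elements) for ZGGES arrays, assuming
# JOBVSL = JOBVSR = 'V', SORT = 'N' and leading dimensions equal to N.
def zgges_min_sizes(n: int) -> dict:
    return {
        "ALPHA": n,             # complex, dimension N
        "BETA": n,              # complex, dimension N
        "VSL": n * n,           # LDVSL x N
        "VSR": n * n,           # LDVSR x N
        "WORK": max(1, 2 * n),  # minimum LWORK; the workspace query may ask for more
        "RWORK": 8 * n,         # double precision, dimension 8*N
        "BWORK": n,             # not referenced when SORT = 'N'
    }


def undersized(allocated: dict, n: int) -> list:
    """Names of arrays whose allocated length is below the documented minimum."""
    required = zgges_min_sizes(n)
    return [name for name, size in allocated.items() if size < required[name]]


def post_allocations(n: int) -> dict:
    """Element counts allocated by numba_zgges in the post (WORK comes from the query)."""
    return {"ALPHA": n, "BETA": n, "VSL": n * n, "VSR": n * n,
            "RWORK": n, "BWORK": n}


print(undersized(post_allocations(6), 6))  # ['RWORK']
```

If this turns out to be the cause, the corresponding change inside `numba_zgges` would be `RWORK = np.empty(8 * _N, dtype=np.float64)`. Small overruns can go unnoticed for a while, depending on how the heap lays out neighbouring allocations. That might explain why n <= 5 appears to work.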
Anyway, I'm not sure this is Numba-related per se, but since I got the example from the Numba GitHub, here I am. I'm hoping someone more experienced with these sorts of Cython issues can quickly spot something obvious I missed.
Cheers.