Extending Numba for CUDA inside Cppyy

Hello,
I am a beginner with Numba and CUDA. I am working on cppyy, where I am trying to add support for passing C++ functions defined with cppdef to a Numba CUDA kernel (decorated with @cuda.jit).
For example, something like this should work:

from numba import cuda
import numpy as np
import cppyy
import cppyy.numba_ext
import math
cppyy.cppdef('''
             int foo(){
            return 42;
             }
             ''')

@cuda.jit()
def abs_kernel(x, out):
    pos = cuda.grid(1)
    if pos < x.size:
        out[pos] = math.fabs(x[pos])
    cppyy.gbl.foo()
 
n = 100000
# Example array with negative values
x = np.arange(-n, n).astype(np.float32)  
out = np.empty_like(x)
print("Before operation:", out[:10])
threads_per_block = 128
blocks_per_grid = (out.size + (threads_per_block - 1)) // threads_per_block
abs_kernel[blocks_per_grid, threads_per_block](x, out)
print("After operation:", out[:10])

foo() should be callable inside the CUDA JIT-compiled kernel as shown above, but right now this doesn’t work out of the box. To add support for it:
Step 1: Follow the logic in lower_external_call in cppyy/numba_ext.py, which is called to lower the C++ call to LLVM IR. As-is, it reuses the external-call infrastructure of the CPU context. Instead, it should check the type of the context and, if it is a GPU context, use the CUDA driver logic to dispatch the function. (This also requires changes to the get_pointer function so that it returns a pointer to the PTX-compiled, i.e. device, function.) The launch_kernel logic in driver.py can be used as a reference for this.

# Excerpt from cppyy/numba_ext.py; self, ol, extsig and args come from the enclosing method
@nb_iutils.lower_builtin(ol, *args)
def lower_external_call(context, builder, sig, args,
        ty=nb_types.ExternalFunctionPointer(extsig, ol.get_pointer),
        pyval=self._func, is_method=self._is_method):
    ptrty = context.get_function_pointer_type(ty)
    ptrval = context.add_dynamic_addr(
        builder, ty.get_pointer(pyval), info=str(pyval))
    fptr = builder.bitcast(ptrval, ptrty)
    return context.call_function_pointer(builder, fptr, args)

return ol.sig

I am not sure how to check whether the context in the code above is a GPU context. Can anyone explain how to proceed with this task? Any help would be appreciated.

I don’t know why I can’t add links to my post above, but these are the references:

Hi @gmarkall,
Maybe you could check the above issue and discuss what can be done.

@chococandy63,
I think the functionality you want already exists. It’s in Numba if you are willing to put the C++ code in a file; otherwise, you need to use pynvjitlink.
You can look at “Calling foreign functions from Python kernels” in the Numba documentation.
Also, take a look at numba/numba issue #9654 on GitHub (“Pass NULL pointer for cffi.FFI().from_buffer(empty array)”), an issue I wrote that uses the link feature of the cuda.jit decorator.

Hope this helps.
-Ed


@ed-o-saurus,
How do I check the context (GPU or CPU) in the lowering function (lower_external_call)?

@nb_iutils.lower_builtin(ol, *args)
def lower_external_call(context, builder, sig, args,
        ty=nb_types.ExternalFunctionPointer(extsig, ol.get_pointer),
        pyval=self._func, is_method=self._is_method):
    ptrty = context.get_function_pointer_type(ty)
    ptrval = context.add_dynamic_addr(
        builder, ty.get_pointer(pyval), info=str(pyval))
    fptr = builder.bitcast(ptrval, ptrty)
    return context.call_function_pointer(builder, fptr, args)

return ol.sig

Right now CppFunctionNumbaType can only handle the CPU context, not the GPU one. To enable GPU handling, lower_external_call needs to check whether the context is GPU or CPU and lower the call accordingly.

Some changes may also be needed in the get_pointer function. Here is the relevant code:
get_pointer: https://github.com/wlav/cppyy/blob/master/python/cppyy/numba_ext.py#L241
lower_external_call: https://github.com/wlav/cppyy/blob/master/python/cppyy/numba_ext.py#L224
CppFunctionNumbaType: https://github.com/wlav/cppyy/blob/master/python/cppyy/numba_ext.py#L175

@ed-o-saurus mentioned this already, but I put together a small example using pynvjitlink to embed the CUDA C++ code in the Python source:

from numba import cuda, int32
from pynvjitlink import patch

patch.patch_numba_linker()

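# Numba ABI: the result is written through the first (pointer) argument and
# the int return value is a status code (this example returns 0).
cu_functions = cuda.CUSource('''
extern "C" __device__ int foo(int* return_value){
  *return_value = 42;
  return 0;
}
''')

# Tell Numba that "foo" is a device function taking no arguments and
# returning int32; the call is resolved by name when the kernel is linked.
foo = cuda.declare_device('foo', int32())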


@cuda.jit(link=[cu_functions])
def kernel():
    print(foo())


kernel[1, 1]()
cuda.synchronize()

I think extending cppyy’s numba_ext to support the CUDA target will also require some source-code processing to handle the differences between the usual C++ ABI and the one Numba uses internally, along with qualifying functions with the __device__ keyword. There may be other things I haven’t thought of yet.

To answer your original question about how to determine the current target, I think you can do:

from numba.core.target_extension import current_target
target = current_target()

and target will then be either "cuda" or "cpu" depending on the current target.
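A rough sketch of how that target string could drive the choice of lowering in cppyy. Everything here is hypothetical (LOWERINGS, lower_cpu_call, lower_cuda_call, lower_external_call_for), and the target is passed in as a plain string so the sketch runs without Numba; in real lowering code it would come from current_target(). The CUDA branch only describes the strategy, since a device-side call has to be resolved by name at link time rather than through a function pointer.

```python
# Hypothetical per-target dispatch for a cppyy lowering function.
# In real code the target string would come from
# numba.core.target_extension.current_target(); here it is a parameter
# so that the sketch runs without Numba.


def lower_cpu_call(func_name):
    # CPU path: what cppyy already does, i.e. call through a host
    # function pointer obtained with cppyy.addressof().
    return f"call {func_name} through a host function pointer"


def lower_cuda_call(func_name):
    # CUDA path: device code can't safely call arbitrary function pointers,
    # so the call must be resolved by symbol name when the kernel is linked.
    return f"call {func_name} by symbol name (linked device code)"


LOWERINGS = {
    "cpu": lower_cpu_call,
    "cuda": lower_cuda_call,
}


def lower_external_call_for(target, func_name):
    """Pick the lowering strategy for a Numba target string."""
    try:
        lower = LOWERINGS[target]
    except KeyError:
        raise NotImplementedError(
            f"no cppyy lowering for target {target!r}") from None
    return lower(func_name)


print(lower_external_call_for("cpu", "foo"))
print(lower_external_call_for("cuda", "foo"))
```

Keeping one callable per target means a new target only needs a new table entry, and an unknown target fails loudly instead of silently falling back to the CPU path.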


@gmarkall Thanks for the example, but I saw in the comments at https://github.com/numba/numba/pull/9470#issuecomment-2186254416 that pynvjitlink requires CUDA 12.0 or later. I am on CUDA 11.2, so I guess I will have to go with the approach described in the Numba docs (Calling foreign functions from Python kernels: https://numba.readthedocs.io/en/stable/cuda/cuda_ffi.html).

As for extending Numba inside cppyy to support CUDA: I’m able to determine the target now. Next, I have to change get_pointer (https://github.com/wlav/cppyy/blob/master/python/cppyy/numba_ext.py#L241) so that it returns a pointer to the PTX-compiled, i.e. device, function.

def get_pointer(self, func, context='cpu'):
    # The context argument is not used yet; only the CPU path exists
    if func is None:
        func = self._func
    ol = func.__overload__(numba_arg_convertor(self.sig.args))
    address = cppyy.addressof(ol)
    if not address:
        raise RuntimeError("unresolved address for %s" % str(ol))
    return address

get_pointer is called by the lower_external_call function here: https://github.com/wlav/cppyy/blob/master/python/cppyy/numba_ext.py#L224

Could you give me some hints/resources/files on how this can be done?

Numba has a built-in (albeit slightly limited) interface to NVRTC which you might be able to use to compile CUDA C++ to PTX, which you can then add to the link - here’s a simple example:

from numba.cuda.cudadrv import nvrtc

code = '''
extern "C" __device__ int foo(int* return_value){
  *return_value = 42;
  return 0;
}
'''

# Returns a tuple of the PTX code and the compiler output messages
# Arguments are (<source code>, 
#                <source filename>, 
#                <compute capability as a tuple>)
ptx, log = nvrtc.compile(code, "test.cu", (8, 9))

print(ptx)

which prints (slightly abridged):

//
// Generated by NVIDIA NVVM Compiler
//
// Compiler Build ID: CL-34177558
// Cuda compilation tools, release 12.5, V12.5.40
// Based on NVVM 7.0.1
//

.version 8.5
.target sm_89
.address_size 64

	// .globl	foo

.visible .func  (.param .b32 func_retval0) foo(
	.param .b64 foo_param_0
)
{
	.reg .b32 	%r<3>;
	.reg .b64 	%rd<2>;

	ld.param.u64 	%rd1, [foo_param_0];
	mov.u32 	%r1, 42;
	st.u32 	[%rd1], %r1;
	mov.u32 	%r2, 0;
	st.param.b32 	[func_retval0+0], %r2;
	ret;
}

There are a couple of points to be aware of:

  • Arbitrary function pointers are not supported in CUDA, because of the way register allocation works (kernels can use different numbers of registers, but their register usage needs to be declared and is static for the kernel, so a device function that uses some unknown number of registers is not possible)
  • Therefore, in cppyy you will need an alternative mechanism to function pointers to call the C++ functions. The existing built-in CUDA C/C++ support does this with the declare_device() function to insert Numba typing and lowering that implements the call using the name of the C/C++ device function - you may be able to do the same in the cppyy implementation.

If you need to be able to have more flexibility in the way that NVRTC is called (e.g. using arbitrary flags, include paths, etc) then you can use the CUDA Python bindings to access it more directly:


@gmarkall I don’t quite understand why you mention a "test.cu" source file in ptx, log = nvrtc.compile(code, "test.cu", (8, 9)).
The CUDA code is already defined in the string containing foo() (starting with extern). Why do we need another .cu file, and what would it contain?


I was thinking of creating a cppyy helper function, cudadef, for defining CUDA code inside Python files (without needing external .cu files).
Something like this:

import cppyy
import cppyy.numba_ext
cppyy.cudadef('''
__global__ void MatrixMul(float* A, float* B, float* out) {
    // kernel logic for matrix multiplication
}
''')

@cuda.jit
def run_cuda_mul(A, B, out):
    # Allocate memory for input and output arrays on GPU
    # Define grid and block dimensions
    # Launch the kernel
    MatrixMul[griddim, blockdim](d_A, d_B, d_out)

But I haven’t figured out how to do this yet. Do you have any thoughts on this?

test.cu isn’t another source file, it’s the name of the file containing the code we’re passing in - it could be anything, but in a typical use case it would match the name of a file read from disk (though we’re not reading from disk here). The utility of it is so that compiler error messages can refer to the file in feedback to the user.

For example, if I introduce an error in the source (declaring a kernel with a non-void return type, which is not allowed):

from numba.cuda.cudadrv import nvrtc

code = '''
extern "C" __device__ int foo(int* return_value){
  *return_value = 42;
  return 0;
}

__global__ int k();
'''

ptx, log = nvrtc.compile(code, "test.cu", (8, 9))

then we get:

Traceback (most recent call last):
  File "/home/gmarkall/numbadev/issues/discourse-2644/nvrtc_bad_example.py", line 12, in <module>
    ptx, log = nvrtc.compile(code, "test.cu", (8, 9))
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/gmarkall/numbadev/numba/numba/cuda/cudadrv/nvrtc.py", line 252, in compile
    raise NvrtcError(msg)
numba.cuda.cudadrv.error.NvrtcError: NVRTC Compilation failure whilst compiling test.cu:

test.cu(7): error: a __global__ function must have a void return type
  __global__ int k();
  ^

1 error detected in the compilation of "test.cu".

The name “test.cu” passed into the compile function appears in the error message to refer to the line causing the error. If you’re passing in code that’s generated at runtime, you can just give a name like "<generated-code>" or anything you think would help the user to understand an issue if they see a compiler error message.
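Because that name is what appears in the diagnostics, a wrapper such as a cppyy cudadef could use it to map errors back to the code the user wrote. A minimal sketch, assuming the diagnostic format shown above (file(line): error: message); parse_nvrtc_log is a hypothetical helper, not part of Numba.

```python
import re

# Matches diagnostics such as:
#   test.cu(7): error: a __global__ function must have a void return type
DIAG_RE = re.compile(
    r"^(?P<file>.+?)\((?P<line>\d+)\): (?P<kind>error|warning): (?P<msg>.*)$"
)


def parse_nvrtc_log(log, source_name):
    """Return (line, kind, message) tuples reported for source_name."""
    results = []
    for raw in log.splitlines():
        m = DIAG_RE.match(raw.strip())
        if m and m.group("file") == source_name:
            results.append((int(m.group("line")), m.group("kind"), m.group("msg")))
    return results


log = """test.cu(7): error: a __global__ function must have a void return type
  __global__ int k();
  ^

1 error detected in the compilation of "test.cu".
"""
print(parse_nvrtc_log(log, "test.cu"))
```

The source echo and the summary line are skipped because they don't match the file(line): prefix.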

pynvjitlink does it by patching Numba to include the CUSource class and to modify the linker - once this functionality is upstreamed in Numba (this is on the to-do list but not yet started) you would be able to use it as the basis of your cppyy.cudadef implementation. In the meantime, this functionality is not CUDA 12-specific, so you could copy the necessary parts of the implementation from pynvjitlink:

The other thing your cudadef method will have to do is generate a stub that adapts the C++ function to match the Numba ABI (which is documented here) and call cuda.declare_device with the function name and argument types - this will need to be done for each device function.
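For device functions, such a stub could be generated as text. A minimal sketch, assuming scalar argument and return types only (void-returning functions would need a different stub); make_numba_stub, the numba_abi_ prefix and the type table are hypothetical choices, not part of cppyy or Numba. The generated source follows the pattern of the foo example earlier in the thread (result written through the first pointer argument, int status 0 returned), and the signature string is in the form cuda.declare_device accepts.

```python
# Hypothetical mapping from C++ scalar types to Numba type names.
CPP_TO_NUMBA = {
    "int": "int32",
    "long long": "int64",
    "float": "float32",
    "double": "float64",
}


def make_numba_stub(name, ret, args, stub_prefix="numba_abi_"):
    """Return (cuda_source, stub_name, numba_signature) for an existing
    C++ device function `ret name(args...)`."""
    stub_name = stub_prefix + name
    params = ", ".join(f"{t} a{i}" for i, t in enumerate(args))
    call_args = ", ".join(f"a{i}" for i in range(len(args)))
    # Numba ABI: result written through the first (pointer) parameter,
    # int status code returned.
    all_params = f"{ret}* return_value" + (", " + params if params else "")
    source = (
        f'extern "C" __device__ int {stub_name}({all_params}) {{\n'
        f"  *return_value = {name}({call_args});\n"
        f"  return 0;\n"
        f"}}\n"
    )
    sig = f"{CPP_TO_NUMBA[ret]}({', '.join(CPP_TO_NUMBA[t] for t in args)})"
    return source, stub_name, sig


source, stub_name, sig = make_numba_stub("foo", "float", ["float"])
print(source)
print(stub_name, sig)
# The stub would then be linked and declared with something like:
#   cuda.declare_device(stub_name, sig)
```

Using extern "C" on the stub keeps its symbol name predictable, so the original C++ function can keep its mangled name.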

For kernels, like in your example, you will have to generate a wrapper function to match the ABI with which Numba will call it. This isn’t documented (or expected to be stable, though we can discuss making it documented and stable) but you can follow Numba’s logic for generating wrappers in target.py in the generate_kernel_wrapper() function.

A couple of other questions on your example:

  • The run_cuda_mul() function is decorated with @cuda.jit, which makes a function a kernel. However, from the code inside the function, I think your intention is that this function runs on the host, and only launches a kernel on its final line?
  • There’s nothing putting MatrixMul into scope in the Python file - should it have been assigned the result of the cppyy.cudadef call?

i.e., should it look like:

import cppyy
import cppyy.numba_ext

MatrixMul = cppyy.cudadef('''
__global__ void MatrixMul(float* A, float* B, float* out) {
    // kernel logic for matrix multiplication
}
''')

def run_cuda_mul(A, B, out):
    # Allocate memory for input and output arrays on GPU
    # Define grid and block dimensions
    # Launch the kernel
    MatrixMul[griddim, blockdim](d_A, d_B, d_out)

?

@gmarkall When I try to run this code:

from numba.cuda.cudadrv import nvrtc

code = '''
extern "C" __device__ int foo(int* return_value){
  *return_value = 42;
  return 0;
}
'''

# Returns a tuple of the PTX code and the compiler output messages
# Arguments are (<source code>, 
#                <source filename>, 
#                <compute capability as a tuple>)
ptx, log = nvrtc.compile(code, "test.cu", (8, 9))

print(ptx)

I get the following errors:

Traceback (most recent call last):
  File "foo.py", line 14, in <module>
    ptx, log = nvrtc.compile(code, "test.cu", (8, 9))
  File "/home/trinity/compres/venv/lib/python3.8/site-packages/numba/cuda/cudadrv/nvrtc.py", line 244, in compile
    compile_error = nvrtc.compile_program(program, options)
  File "/home/trinity/compres/venv/lib/python3.8/site-packages/numba/cuda/cudadrv/nvrtc.py", line 177, in compile_program
    self.nvrtcCompileProgram(program.handle, len(options), c_options)
  File "/home/trinity/compres/venv/lib/python3.8/site-packages/numba/cuda/cudadrv/nvrtc.py", line 132, in checked_call
    raise NvrtcError(msg)
numba.cuda.cudadrv.error.NvrtcError: Failed to call nvrtcCompileProgram: NVRTC_ERROR_INVALID_OPTION

My system info:
Ubuntu 20.04.6 LTS x86_64, CUDA 11.2, python 3.8.10

It’s because (8, 9) refers to a compute capability of 8.9, which isn’t supported by CUDA 11.2. I picked it because my device has compute capability 8.9, but you should pick the CC that matches your device - you can see it with cuda.detect(), e.g.:

$ python -c "from numba import cuda; cuda.detect()"
Found 1 CUDA devices
id 0    b'NVIDIA RTX 3500 Ada Generation Laptop GPU'                              [SUPPORTED]
                      Compute Capability: 8.9
                           PCI Device ID: 0
                              PCI Bus ID: 1
                                    UUID: GPU-4e38c271-caf3-156e-a3a9-5e39d39ea974
                                Watchdog: Enabled
             FP32/FP64 Performance Ratio: 64
Summary:
	1/1 devices are supported

Just substitute the tuple for what you see in the Compute Capability row - e.g. if it’s 7.5, then (7, 5) would be appropriate.

You can also automate getting an appropriate tuple for your device with cuda.current_context().device.compute_capability - e.g. on my system:

$ python -c "from numba import cuda; print(cuda.current_context().device.compute_capability)"
(8, 9)
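The (8, 9) failure on CUDA 11.2 reported above could be caught before calling NVRTC. A small sketch, assuming the minimum-toolkit table below is accurate (it covers only a few architectures and should be checked against NVIDIA’s release notes); arch_names and toolkit_supports are hypothetical helpers. The sm_XY name is the one that appears on the .target line of the PTX shown earlier.

```python
# Minimum CUDA toolkit (major, minor) able to target each compute capability.
# Assumption: a partial table; verify against NVIDIA's release notes.
MIN_TOOLKIT = {
    (7, 0): (9, 0),
    (7, 5): (10, 0),
    (8, 0): (11, 0),
    (8, 6): (11, 1),
    (8, 9): (11, 8),
    (9, 0): (11, 8),
}


def arch_names(cc):
    """Return the virtual arch and PTX target names for a CC tuple."""
    major, minor = cc
    return f"compute_{major}{minor}", f"sm_{major}{minor}"


def toolkit_supports(cc, toolkit):
    """True if the CUDA toolkit version (major, minor) can target cc."""
    if cc not in MIN_TOOLKIT:
        raise KeyError(f"compute capability {cc} not in table")
    return toolkit >= MIN_TOOLKIT[cc]


print(arch_names((7, 5)))                  # ('compute_75', 'sm_75')
print(toolkit_supports((8, 9), (11, 2)))   # False
print(toolkit_supports((7, 5), (11, 2)))   # True
```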

Thanks, it works now! For me it shows (7, 5).
@gmarkall You provided code above to compile CUDA to PTX using NVRTC, but how do I find the pointer to the function (here, foo) in the PTX code?
I tried the following in order to get the pointer to the function:

from numba.cuda.cudadrv import devices, nvrtc

code ='''
__device__ float foo(float a) {
    return fabs(a);
}
'''
ptx, log = nvrtc.compile(code, "test.cu", (7, 5))
print(ptx)

The ptx for this is generated as follows:

//
// Generated by NVIDIA NVVM Compiler
//
// Compiler Build ID: CL-29618528
// Cuda compilation tools, release 11.2, V11.2.152
// Based on NVVM 7.0.1
//

.version 7.2
.target sm_75
.address_size 64

        // .globl       _Z3foof

.visible .func  (.param .b32 func_retval0) _Z3foof(
        .param .b32 _Z3foof_param_0
)
{
        .reg .f32       %f<3>;


        ld.param.f32    %f1, [_Z3foof_param_0];
        abs.f32         %f2, %f1;
        st.param.f32    [func_retval0+0], %f2;
        ret;

}

If I add these lines after printing ptx in the above code:

ctx = devices.get_context()
module = ctx.create_module_ptx(ptx)
foo = module.get_function("_Z3foof")

I get the following error:

Traceback (most recent call last):
  File "12.py", line 53, in <module>
    foo = module.get_function("_Z3foof")
  File "/home/trinity/compres/venv/lib/python3.8/site-packages/numba/cuda/cudadrv/driver.py", line 2414, in get_function
    driver.cuModuleGetFunction(byref(handle), self.handle,
  File "/home/trinity/compres/venv/lib/python3.8/site-packages/numba/cuda/cudadrv/driver.py", line 327, in safe_cuda_api_call
    self._check_ctypes_error(fname, retcode)
  File "/home/trinity/compres/venv/lib/python3.8/site-packages/numba/cuda/cudadrv/driver.py", line 395, in _check_ctypes_error
    raise CudaAPIError(retcode, msg)
numba.cuda.cudadrv.driver.CudaAPIError: [500] Call to cuModuleGetFunction results in CUDA_ERROR_NOT_FOUND

I am also curious why the function name in my PTX is _Z3foof, whereas in yours it is foo. I’m not sure why this error occurs; do you have any ideas? Am I missing something, and why did the name change from foo to _Z3foof?

Also, instead of compiling CUDA with NVRTC and keeping the PTX in memory as above, how can I load an external .ptx file into my cppyy script?
I looked for APIs in numba/cuda/cudadrv/driver.py, such as load_module_image (https://github.com/numba/numba/blob/main/numba/cuda/cudadrv/driver.py#L1531), which takes a context and an image as parameters. This Numba abstraction calls cuModuleLoadDataEx under the hood, which I thought might be useful for my case, but I am still not sure how to use it. Could you guide me, please?

TL;DR
Here are the tasks I need to perform:

  • Load an external .ptx file into my cppyy Python script.
  • Find the pointer to the function in the loaded PTX.

@gmarkall Replying to your questions above about the cudadef implementation:

Currently, a cudadef definition exists in cppyy; it uses Cling for incremental compilation at runtime. This is the implementation: https://github.com/chococandy63/cppyy/blob/cuda/python/cppyy/__init__.py#L211-L221
A few points:

  • cudadef takes CUDA code that is compiled by Cling (a C++ interpreter, https://github.com/root-project/cling); it calls IncrementalCUDADeviceCompiler under the hood.
  • The current implementation in cppyy exposes the CUDA symbols to Cling through PCH (precompiled headers). For more information, see https://cppyy.readthedocs.io/en/latest/cuda.html.

For example, with the current support, if I use cudadef as shown below:

from numba import cuda
import numpy as np
import cppyy
import cppyy.numba_ext
import math
import os
os.environ['CLING_ENABLE_CUDA'] = '1'
os.environ['CLING_CUDA_PATH'] = '/usr/local/cuda'
cppyy.add_include_path('/usr/local/cuda/include')
cppyy.add_library_path('/usr/local/cuda/lib64')
cppyy.include('iostream')
cppyy.include("cuda.h")
cppyy.include("cuda_runtime.h")
cppyy.load_library("cudart")
cppyy.cudadef('''
      __device__ int foo(){
            return 42;
            }
             ''')
@cuda.jit()
def abs_kernel(x, out):
    a=cppyy.gbl.foo()
    pos = cuda.grid(1)
    if pos < x.size:
        out[pos] = math.fabs(x[pos]) + a

n = 100000
x = np.arange(-n, n).astype(np.float32)  # Example array with negative values
out = np.empty_like(x)
print("Before operation:", out[:10])
threads_per_block = 128
blocks_per_grid = (out.size + (threads_per_block - 1)) // threads_per_block
abs_kernel[blocks_per_grid, threads_per_block](x, out)
print("After operation:", out[:10]) 

I get an error that says:

Before operation: [0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
13.py:74: NumbaPendingDeprecationWarning: Code using Numba extension API maybe depending on 'old_style' error-capturing, which is deprecated and will be replaced by 'new_style' in a future release. See details at https://numba.readthedocs.io/en/latest/reference/deprecation.html#deprecation-of-old-style-numba-captured-errors
Exception origin:
  File "/home/trinity/compres/venv/lib/python3.8/site-packages/numba/core/typing/context.py", line 348, in resolve_module_constants
    attrval = getattr(typ.pymod, attr)

  a=cppyy.gbl.foo()
Traceback (most recent call last):
  File "13.py", line 85, in <module>
    abs_kernel[blocks_per_grid, threads_per_block](x, out)
  File "/home/trinity/compres/venv/lib/python3.8/site-packages/numba/cuda/dispatcher.py", line 539, in __call__
    return self.dispatcher.call(args, self.griddim, self.blockdim,
  File "/home/trinity/compres/venv/lib/python3.8/site-packages/numba/cuda/dispatcher.py", line 673, in call
    kernel = _dispatcher.Dispatcher._cuda_call(self, *args)
  File "/home/trinity/compres/venv/lib/python3.8/site-packages/numba/cuda/dispatcher.py", line 681, in _compile_for_args
    return self.compile(tuple(argtypes))
  File "/home/trinity/compres/venv/lib/python3.8/site-packages/numba/cuda/dispatcher.py", line 924, in compile
    kernel = _Kernel(self.py_func, argtypes, **self.targetoptions)
  File "/home/trinity/compres/venv/lib/python3.8/site-packages/numba/core/compiler_lock.py", line 35, in _acquire_compile_lock
    return func(*args, **kwargs)
  File "/home/trinity/compres/venv/lib/python3.8/site-packages/numba/cuda/dispatcher.py", line 83, in __init__
    cres = compile_cuda(self.py_func, types.void, self.argtypes,
  File "/home/trinity/compres/venv/lib/python3.8/site-packages/numba/core/compiler_lock.py", line 35, in _acquire_compile_lock
    return func(*args, **kwargs)
  File "/home/trinity/compres/venv/lib/python3.8/site-packages/numba/cuda/compiler.py", line 194, in compile_cuda
    cres = compiler.compile_extra(typingctx=typingctx,
  File "/home/trinity/compres/venv/lib/python3.8/site-packages/numba/core/compiler.py", line 770, in compile_extra
    return pipeline.compile_extra(func)
  File "/home/trinity/compres/venv/lib/python3.8/site-packages/numba/core/compiler.py", line 461, in compile_extra
    return self._compile_bytecode()
  File "/home/trinity/compres/venv/lib/python3.8/site-packages/numba/core/compiler.py", line 529, in _compile_bytecode
    return self._compile_core()
  File "/home/trinity/compres/venv/lib/python3.8/site-packages/numba/core/compiler.py", line 508, in _compile_core
    raise e
  File "/home/trinity/compres/venv/lib/python3.8/site-packages/numba/core/compiler.py", line 495, in _compile_core
    pm.run(self.state)
  File "/home/trinity/compres/venv/lib/python3.8/site-packages/numba/core/compiler_machinery.py", line 368, in run
    raise patched_exception
  File "/home/trinity/compres/venv/lib/python3.8/site-packages/numba/core/compiler_machinery.py", line 356, in run
    self._runPass(idx, pass_inst, state)
  File "/home/trinity/compres/venv/lib/python3.8/site-packages/numba/core/compiler_lock.py", line 35, in _acquire_compile_lock
    return func(*args, **kwargs)
  File "/home/trinity/compres/venv/lib/python3.8/site-packages/numba/core/compiler_machinery.py", line 311, in _runPass
    mutated |= check(pss.run_pass, internal_state)
  File "/home/trinity/compres/venv/lib/python3.8/site-packages/numba/core/compiler_machinery.py", line 273, in check
    mangled = func(compiler_state)
  File "/home/trinity/compres/venv/lib/python3.8/site-packages/numba/core/typed_passes.py", line 110, in run_pass
    typemap, return_type, calltypes, errs = type_inference_stage(
  File "/home/trinity/compres/venv/lib/python3.8/site-packages/numba/core/typed_passes.py", line 91, in type_inference_stage
    errs = infer.propagate(raise_errors=raise_errors)
  File "/home/trinity/compres/venv/lib/python3.8/site-packages/numba/core/typeinfer.py", line 1086, in propagate
    raise errors[0]
numba.core.errors.TypingError: Failed in cuda mode pipeline (step: nopython frontend)
Internal error at resolving type of attribute "foo" of "$4load_attr.1".
<namespace cppyy.gbl at 0x2b1c260> has no attribute 'foo'. Full details:
  type object '' has no attribute 'foo'
  'foo' is not a known C++ class
  'foo' is not a known C++ template
  'foo' is not a known C++ enum
During: typing of get attribute at 13.py (74)
Enable logging at debug level for details.

File "13.py", line 74:
def abs_kernel(x, out):
    a=cppyy.gbl.foo()
    ^

There are two Cling interpreters. If CUDA was enabled when the backend was compiled, both cppdef and cudadef send their code to both interpreters. I think the attribute resolution fails because the meta lookup might only search the non-CUDA interpreter. I don’t know how this Cling issue can be solved (if you have thoughts, feel free to share), but right now my main focus is to get the Numba flow right. Once the Numba side is functional, we can move on to the Cling side.

The reason why we want to use cling for cudadef handling:

  • Incremental compilation
  • Runtime template instantiations
  • To use common CUDA libraries such as cuBLAS

The run_cuda_mul() function is decorated with @cuda.jit, which makes a function a kernel.
However, from the code inside the function, I think your intention is that this function runs on the host, and only launches a kernel on its final line?

Oh, sorry for the typo. I meant to use @numba.njit there, not @cuda.jit.

There’s nothing putting MatrixMul into scope in the Python file - should it have been assigned the result of the cppyy.cudadef call?

Scope is handled by cppyy’s internal mechanism: simply defining the foo() function inside cudadef or cppdef registers it in the global namespace (gbl), and we can access it anywhere in our Python file as cppyy.gbl.function_name.

Just replying to the first post for now - will take a look at the second one afterwards.

get_function() only works for kernels, not for device functions - so the following will work, which wraps foo() in a kernel:

from numba import cuda
from numba.cuda.cudadrv import nvrtc

code = '''
__device__ float foo(float a) {
    return fabs(a);
}

extern "C"
__global__ void foo_kernel(float *result, float a) {
    *result = foo(a);
}
'''
ptx, log = nvrtc.compile(code, "test.cu", (7, 5))
print(ptx)

ctx = cuda.current_context()
module = ctx.create_module_ptx(ptx)
foo_kernel = module.get_function("foo_kernel")

It’s not really useful to know the address of device functions, because there is no safe way to call them by address anyway - the call needs to be resolvable when the caller is compiled, because registers are allocated for the entire kernel and its callees all at once. Calling an arbitrary device function through a function pointer will in the general case result in register values being corrupted because there is no mechanism to preserve register usage that wasn’t known about at compile time.

The C++ compiler “mangles” the names of compiled functions to disambiguate symbols for multiple overloads of the same function - See e.g. C++ Name Mangling. I didn’t get the mangled name because I added the extern "C" linkage. I usually do this because I find it more convenient, but you can also use the mangled name as long as you do something to compute what it will be.
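For simple cases the mangled name can be computed directly. A rough sketch of the Itanium C++ ABI rules (used by GCC and Clang, and, as the _Z3foof name above suggests, by NVRTC), restricted to non-template functions in the global namespace with builtin scalar and pointer arguments. mangle_free_function is a hypothetical helper; namespaces, classes, const, references and repeated pointer types (which the ABI encodes with substitutions) are deliberately not handled.

```python
# Itanium C++ ABI codes for a few builtin types.
BUILTIN_CODES = {
    "void": "v",
    "bool": "b",
    "char": "c",
    "int": "i",
    "unsigned int": "j",
    "long": "l",
    "long long": "x",
    "float": "f",
    "double": "d",
}


def _type_code(t, seen_pointers):
    t = " ".join(t.split())      # normalise internal whitespace
    t = t.replace(" *", "*")      # "float *" -> "float*"
    if t.endswith("*"):
        if t in seen_pointers:
            # The ABI would emit a substitution (S_, S0_, ...) here,
            # which this sketch does not implement.
            raise NotImplementedError(f"repeated pointer type {t!r}")
        code = "P" + _type_code(t[:-1], seen_pointers)
        seen_pointers.add(t)
        return code
    return BUILTIN_CODES[t]


def mangle_free_function(name, arg_types):
    """Mangle a non-template function declared in the global namespace."""
    seen = set()
    params = "".join(_type_code(t, seen) for t in arg_types)
    # An empty parameter list is encoded as a single 'v' (void).
    return f"_Z{len(name)}{name}{params or 'v'}"


print(mangle_free_function("foo", ["float"]))   # _Z3foof
print(mangle_free_function("foo", []))          # _Z3foov
```

For anything beyond this subset, extern "C" is the simpler option.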

If I’m understanding the question correctly, it’s fairly similar to the example above, just with the PTX loaded in from a file - let me know if I’ve misunderstood the question, but here’s what I’d do:

from numba import cuda
from numba.cuda.cudadrv.driver import Linker


with open("foo_kernel.ptx", 'rb') as f:
    ptx = f.read()

# Note we create the context before creating the linker, because the linker
# does nothing to ensure there is a valid context (we are using internal
# Numba APIs, so we have to be a bit careful)
ctx = cuda.current_context()

linker = Linker.new()
linker.add_ptx(ptx)
cubin = linker.complete()

module = ctx.create_module_image(cubin)
cufunc = module.get_function("foo_kernel")

where foo_kernel.ptx contains

// Generated by NVIDIA NVVM Compiler
//
// Compiler Build ID: CL-34177558
// Cuda compilation tools, release 12.5, V12.5.40
// Based on NVVM 7.0.1
//

.version 8.5
.target sm_75
.address_size 64

	// .globl	_Z3foof

.visible .func  (.param .b32 func_retval0) _Z3foof(
	.param .b32 _Z3foof_param_0
)
{
	.reg .f32 	%f<3>;


	ld.param.f32 	%f1, [_Z3foof_param_0];
	abs.f32 	%f2, %f1;
	st.param.f32 	[func_retval0+0], %f2;
	ret;

}
	// .globl	foo_kernel
.visible .entry foo_kernel(
	.param .u64 foo_kernel_param_0,
	.param .f32 foo_kernel_param_1
)
{
	.reg .f32 	%f<3>;
	.reg .b64 	%rd<3>;


	ld.param.u64 	%rd1, [foo_kernel_param_0];
	ld.param.f32 	%f1, [foo_kernel_param_1];
	cvta.to.global.u64 	%rd2, %rd1;
	abs.f32 	%f2, %f1;
	st.global.f32 	[%rd2], %f2;
	ret;

}

Note that this approach will only work for kernels. If you want to compile C++ device functions to be called from Numba kernels, then you will need to ensure that the PTX (or C++) is linked by Numba, e.g. through the link kwarg of the @cuda.jit decorator, so that the compilation process has visibility of all the code at compile time in order to do the register allocation correctly.


I’m not really familiar with Cling, so I find it a bit hard to offer suggestions for it. I have used the clangInterpreter a bit when I was experimenting with JITting C++ host and CUDA code (example.py in the gmarkall/jitipy repository on GitHub), but I never found the time to get deep into that experiment.

At the moment it’s not possible to launch a CUDA kernel from a @numba.njit function directly. A kernel launch seems to work if it’s inside an object-mode context manager though. For example:

from numba import njit, cuda, objmode


@cuda.jit
def kernel():
    print("Hello from CUDA")


@njit
def kernel_launcher():
    print("Hello from njit")
    with objmode():
        kernel[1, 1]()
    print("Hello again from njit")


kernel_launcher()
cuda.synchronize()

which prints

Hello from njit
Hello again from njit
Hello from CUDA

Note that the CUDA print appears later due to the asynchronous nature of kernel launches.


@gmarkall

Do you mean something like this? I tried an experiment:

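from ctypes import c_int, c_void_p, sizeof

from numba import cuda
from numba.cuda.cudadrv import driver as _driver
from numba.cuda.cudadrv import nvrtc
from numba.cuda.cudadrv.driver import device_to_host, host_to_device, launch_kernel

code = '''
__device__ float foo(float a) {
    return fabs(a);
}
extern "C"
__global__ void foo_kernel(float *result, float a) {
    *result = foo(a);
}
'''
ptx, log = nvrtc.compile(code, "test.cu", (7, 5)) # 7.5 is the compute capability of my GPU
print(ptx)

ctx = cuda.current_context()
module = ctx.create_module_ptx(ptx)
# get_function() returns a Function object wrapping the kernel's CUfunction handle
cu_func = module.get_function("foo_kernel")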

# Array and Memory Allocation
array = (c_int * 100)()
memory = ctx.memalloc(sizeof(array))
host_to_device(memory, array, sizeof(array))

# Pointer and Stream Initialization
ptr = memory.device_ctypes_pointer
stream = 0

if _driver.USE_NV_BINDING:
    ptr = c_void_p(int(ptr))
    stream =  _driver.binding.CUstream(stream)
    print(stream)

#Kernel Launch
launch_kernel(cu_func.handle,  # Kernel
                1,   1, 1,        # gx, gy, gz
                100, 1, 1,        # bx, by, bz
                0,                # dynamic shared mem
                stream,           # stream
                [ptr])            # arguments

#Data Transfer Back to Host 
device_to_host(array, memory,sizeof(array) )

When I run this, the PTX is printed, but the kernel launch fails with an error:

//
// Generated by NVIDIA NVVM Compiler
//
// Compiler Build ID: CL-29618528
// Cuda compilation tools, release 11.2, V11.2.152
// Based on NVVM 7.0.1
//

.version 7.2
.target sm_75
.address_size 64

        // .globl       _Z3foof

.visible .func  (.param .b32 func_retval0) _Z3foof(
        .param .b32 _Z3foof_param_0
)
{
        .reg .f32       %f<3>;


        ld.param.f32    %f1, [_Z3foof_param_0];
        abs.f32         %f2, %f1;
        st.param.f32    [func_retval0+0], %f2;
        ret;

}
        // .globl       foo_kernel
.visible .entry foo_kernel(
        .param .u64 foo_kernel_param_0,
        .param .f32 foo_kernel_param_1
)
{
        .reg .f32       %f<3>;
        .reg .b64       %rd<3>;


        ld.param.u64    %rd1, [foo_kernel_param_0];
        ld.param.f32    %f1, [foo_kernel_param_1];
        cvta.to.global.u64      %rd2, %rd1;
        abs.f32         %f2, %f1;
        st.global.f32   [%rd2], %f2;
        ret;

}

Traceback (most recent call last):
  File "12.py", line 81, in <module>
    launch_kernel(cu_func.handle,  # Kernel
  File "/home/trinity/compres/venv/lib/python3.8/site-packages/numba/cuda/cudadrv/driver.py", line 2563, in launch_kernel
    driver.cuLaunchKernel(cufunc_handle,
  File "/home/trinity/compres/venv/lib/python3.8/site-packages/numba/cuda/cudadrv/driver.py", line 327, in safe_cuda_api_call
    self._check_ctypes_error(fname, retcode)
  File "/home/trinity/compres/venv/lib/python3.8/site-packages/numba/cuda/cudadrv/driver.py", line 395, in _check_ctypes_error
    raise CudaAPIError(retcode, msg)
numba.cuda.cudadrv.driver.CudaAPIError: [1] Call to cuLaunchKernel results in CUDA_ERROR_INVALID_VALUE

What am I doing wrong here?

Also, I don’t understand what you meant here:

Note that this approach will only work for kernels. If you want to compile C++ device functions to be called from Numba kernels, then you will need to ensure that the PTX (or C++) is linked by Numba, e.g. through the link kwarg of the @cuda.jit decorator, so that the compilation process has visibility of all the code at compile time in order to do the register allocation correctly.

Do you mean that I can successfully call the above foo_kernel inside a CUDA kernel (decorated with @cuda.jit), but I won’t be able to do the same with foo, which is a device function (for that I would need to add @cuda.jit(link=[foo]))? Could you please give an example of this?
For instance, if I try something like this:

import numpy as np
import cppyy
import cppyy.numba_ext
from numba import cuda
from numba.cuda.cudadrv import nvrtc

code = '''
__device__ float foo(float a) {
    return fabs(a);
}
extern "C"
__global__ void foo_kernel(float *result, float a) {
    *result = foo(a);
}
'''
ptx, log = nvrtc.compile(code, "test.cu", (7, 5))
print(ptx)
ctx = cuda.current_context()
module = ctx.create_module_ptx(ptx)
cu_func = module.get_function("foo_kernel")

@cuda.jit()
def abs_kernel(x, out):
    pos = cuda.grid(1)
    if pos < x.size:
        out[pos] = cppyy.gbl.foo_kernel(out[pos], x[pos])

n = 100000
x = np.arange(-n, n).astype(np.float32)  # Example array with negative values
out = np.empty_like(x)
print("Before operation:", out[:10])
threads_per_block = 128
blocks_per_grid = (out.size + (threads_per_block - 1)) // threads_per_block
abs_kernel[blocks_per_grid, threads_per_block](x, out)
print("After operation:", out[:10])

then I get this error:

Before operation: [1.2783930e-02 4.5559016e-41 1.2783930e-02 4.5559016e-41 8.7816775e-36
 0.0000000e+00 8.7816775e-36 0.0000000e+00           nan           nan]
12.py:94: NumbaPendingDeprecationWarning: Code using Numba extension API maybe depending on 'old_style' error-capturing, which is deprecated and will be replaced by 'new_style' in a future release. See details at https://numba.readthedocs.io/en/latest/reference/deprecation.html#deprecation-of-old-style-numba-captured-errors
Exception origin:
  File "/home/trinity/compres/venv/lib/python3.8/site-packages/numba/core/typing/context.py", line 348, in resolve_module_constants
    attrval = getattr(typ.pymod, attr)

  out[pos] = cppyy.gbl.foo_kernel(out[pos],x[pos])
Traceback (most recent call last):
  File "12.py", line 102, in <module>
    abs_kernel[blocks_per_grid, threads_per_block](x, out)
  File "/home/trinity/compres/venv/lib/python3.8/site-packages/numba/cuda/dispatcher.py", line 539, in __call__
    return self.dispatcher.call(args, self.griddim, self.blockdim,
  File "/home/trinity/compres/venv/lib/python3.8/site-packages/numba/cuda/dispatcher.py", line 673, in call
    kernel = _dispatcher.Dispatcher._cuda_call(self, *args)
  File "/home/trinity/compres/venv/lib/python3.8/site-packages/numba/cuda/dispatcher.py", line 681, in _compile_for_args
    return self.compile(tuple(argtypes))
  File "/home/trinity/compres/venv/lib/python3.8/site-packages/numba/cuda/dispatcher.py", line 924, in compile
    kernel = _Kernel(self.py_func, argtypes, **self.targetoptions)
  File "/home/trinity/compres/venv/lib/python3.8/site-packages/numba/core/compiler_lock.py", line 35, in _acquire_compile_lock
    return func(*args, **kwargs)
  File "/home/trinity/compres/venv/lib/python3.8/site-packages/numba/cuda/dispatcher.py", line 83, in __init__
    cres = compile_cuda(self.py_func, types.void, self.argtypes,
  File "/home/trinity/compres/venv/lib/python3.8/site-packages/numba/core/compiler_lock.py", line 35, in _acquire_compile_lock
    return func(*args, **kwargs)
  File "/home/trinity/compres/venv/lib/python3.8/site-packages/numba/cuda/compiler.py", line 194, in compile_cuda
    cres = compiler.compile_extra(typingctx=typingctx,
  File "/home/trinity/compres/venv/lib/python3.8/site-packages/numba/core/compiler.py", line 770, in compile_extra
    return pipeline.compile_extra(func)
  File "/home/trinity/compres/venv/lib/python3.8/site-packages/numba/core/compiler.py", line 461, in compile_extra
    return self._compile_bytecode()
  File "/home/trinity/compres/venv/lib/python3.8/site-packages/numba/core/compiler.py", line 529, in _compile_bytecode
    return self._compile_core()
  File "/home/trinity/compres/venv/lib/python3.8/site-packages/numba/core/compiler.py", line 508, in _compile_core
    raise e
  File "/home/trinity/compres/venv/lib/python3.8/site-packages/numba/core/compiler.py", line 495, in _compile_core
    pm.run(self.state)
  File "/home/trinity/compres/venv/lib/python3.8/site-packages/numba/core/compiler_machinery.py", line 368, in run
    raise patched_exception
  File "/home/trinity/compres/venv/lib/python3.8/site-packages/numba/core/compiler_machinery.py", line 356, in run
    self._runPass(idx, pass_inst, state)
  File "/home/trinity/compres/venv/lib/python3.8/site-packages/numba/core/compiler_lock.py", line 35, in _acquire_compile_lock
    return func(*args, **kwargs)
  File "/home/trinity/compres/venv/lib/python3.8/site-packages/numba/core/compiler_machinery.py", line 311, in _runPass
    mutated |= check(pss.run_pass, internal_state)
  File "/home/trinity/compres/venv/lib/python3.8/site-packages/numba/core/compiler_machinery.py", line 273, in check
    mangled = func(compiler_state)
  File "/home/trinity/compres/venv/lib/python3.8/site-packages/numba/core/typed_passes.py", line 110, in run_pass
    typemap, return_type, calltypes, errs = type_inference_stage(
  File "/home/trinity/compres/venv/lib/python3.8/site-packages/numba/core/typed_passes.py", line 91, in type_inference_stage
    errs = infer.propagate(raise_errors=raise_errors)
  File "/home/trinity/compres/venv/lib/python3.8/site-packages/numba/core/typeinfer.py", line 1086, in propagate
    raise errors[0]
numba.core.errors.TypingError: Failed in cuda mode pipeline (step: nopython frontend)
Internal error at resolving type of attribute "foo_kernel" of "$24load_attr.1".
<namespace cppyy.gbl at 0x2551190> has no attribute 'foo_kernel'. Full details:
  type object '' has no attribute 'foo_kernel'
  'foo_kernel' is not a known C++ class
  'foo_kernel' is not a known C++ template
  'foo_kernel' is not a known C++ enum
During: typing of get attribute at 12.py (94)
Enable logging at debug level for details.

File "12.py", line 94:
def abs_kernel(x, out):
    <source elided>
    if pos < x.size:
        out[pos] = cppyy.gbl.foo_kernel(out[pos],x[pos])

Your example was basically correct, except that it was one argument short for the kernel: the argument list contained only ptr, but the kernel takes a pointer to a float and a float. If I modify the example to pass both arguments and print out the result:

from ctypes import c_float, sizeof
from numba import cuda
from numba.cuda.cudadrv import driver, nvrtc

code = '''
__device__ float foo(float a) {
    return fabs(a);
}
extern "C"
__global__ void foo_kernel(float *result, float a) {
    *result = foo(a);
}
'''
ptx, log = nvrtc.compile(code, "test.cu", (7, 5))
print(ptx)

ctx = cuda.current_context()
module = ctx.create_module_ptx(ptx)
cu_func = module.get_function("foo_kernel")

# Array and Memory Allocation
array = (c_float * 100)()
memory = ctx.memalloc(sizeof(array))
driver.host_to_device(memory, array, sizeof(array))

# Pointer and Stream Initialization
ptr = memory.device_ctypes_pointer
arg = c_float(-2.0)
stream = 0

# Kernel Launch
driver.launch_kernel(cu_func.handle,  # Kernel
                     1, 1, 1,         # gx, gy, gz
                     1, 1, 1,         # bx, by, bz
                     0,               # dynamic shared mem
                     stream,          # stream
                     [ptr, arg])           # arguments

# Data Transfer Back to Host
driver.device_to_host(array, memory, sizeof(array))
print(array[0])

it gives:

$ python compile_and_launch.py 
//
// Generated by NVIDIA NVVM Compiler
//

... <output snipped > ...

2.0

(will respond to the latter part in a separate post)


I think the two cases here are:

  • Kernels: you won’t be able to call foo_kernel from inside a @cuda.jit-decorated kernel, because a kernel can’t call another kernel.
  • Device functions: you won’t be able to call foo just by getting a pointer to it, because CUDA has no ABI for calling an arbitrary function through a function pointer.

You’re correct that you will need to add link=[<something>] to be able to call an external function from a Numba kernel.

Here is an example of doing this (from post #5 by gmarkall above):

from numba import cuda, int32
from pynvjitlink import patch

patch.patch_numba_linker()

cu_functions = cuda.CUSource('''
extern "C" __device__ int foo(int* return_value){
  *return_value = 42;
  return 0;
}
''')

foo = cuda.declare_device('foo', int32())


# Instead of cu_functions, you can also give a filename on-disk
# and Numba will load it
@cuda.jit(link=[cu_functions])
def kernel():
    print(foo())


kernel[1, 1]()
cuda.synchronize()

If you have PTX rather than CUDA C++ source (e.g. from compilation with cppyy), replace the CUSource object with a PTXSource object containing the PTX code (or pass the filename of a PTX file on disk instead).

The immediate cause of the exception appears to be something missing in the cppyy implementation to help Numba resolve the type of cppyy.gbl.foo_kernel, but I’m not sure what (though even if it could resolve the type, calling the kernel from inside the kernel wouldn’t be valid).


@gmarkall

The immediate cause of the exception appears to be something missing in the cppyy implementation to help Numba resolve the type of cppyy.gbl.foo_kernel, but I’m not sure what (though even if it could resolve the type, calling the kernel from inside the kernel wouldn’t be valid).

Could you give a few pointers on how this could be solved? I was thinking that changing lower_external_call and get_pointer might help with the Numba CUDA flow.

This is the current definition of lower_external_call for handling cpu context:

@nb_iutils.lower_builtin(ol, *args)
def lower_external_call(context, builder, sig, args,
        ty=nb_types.ExternalFunctionPointer(extsig, ol.get_pointer), pyval=self._func, is_method=self._is_method):
    ptrty = context.get_function_pointer_type(ty)
    ptrval = context.add_dynamic_addr(
        builder, ty.get_pointer(pyval), info=str(pyval))
    fptr = builder.bitcast(ptrval, ptrty)
    return context.call_function_pointer(builder, fptr, args)

But we need another case in lower_external_call for the GPU context, which may require a completely new definition. For example, we might need something analogous to ptrty = context.get_function_pointer_type(ty), and a reimplementation of the get_pointer that is passed to ExternalFunctionPointer (I’m not sure what else might be needed). Something like:

if cuda.current_context() is not None:
    @nb_iutils.lower_builtin(ol, *args)
    def lower_external_call(context, builder, sig, args,
            ty=nb_types.ExternalFunctionPointer(extsig, ol.get_pointer), pyval=self._func, is_method=self._is_method):
        # handle the GPU memory allocation, etc.
        ...
else:
    # the above definition of lower_external_call for handling the CPU context
    ...

Could you give me a few suggestions on how to approach this problem, and which parts of the Numba CUDA flow I should look at?
This is the numba extension inside Cppyy: https://github.com/wlav/cppyy/blob/master/python/cppyy/numba_ext.py#L224-L249

@gmarkall
In post #7 of this thread (Extending Numba for CUDA inside Cppyy - #7 by gmarkall), you wrote:

Therefore, in cppyy you will need an alternative mechanism to function pointers to call the C++ functions. The existing built-in CUDA C/C++ support does this with the declare_device() function to insert Numba typing and lowering that implements the call using the name of the C/C++ device function - you may be able to do the same in the cppyy implementation.

You mentioned above that cppyy needs an alternative mechanism to function pointers to call the C++ functions. Just to be clear, you mean normal C++ functions and not the device functions (that we discussed above), right?
In either case (device or normal C++ functions), should there be something like declare_device in cppyy/numba_ext.py that handles calls to external functions inside a Numba CUDA kernel? Based on the declare_device definition, maybe the code below can help with calling device functions inside the Numba kernel (I’m not sure what changes might be required, though).

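from numba.core import funcdesc, typing
from numba.core.typing.templates import ConcreteTemplate
from numba.cuda.compiler import ExternFunction
from numba.cuda.descriptor import cuda_target


def declare_device_function_template(name, restype, argtypes):
    typingctx = cuda_target.typing_context
    targetctx = cuda_target.target_context
    sig = typing.signature(restype, *argtypes)
    extfn = ExternFunction(name, sig)

    # Typing: calls to extfn with this signature resolve to restype
    class device_function_template(ConcreteTemplate):
        key = extfn
        cases = [sig]

    # Lowering: calls are emitted against the external symbol `name`
    fndesc = funcdesc.ExternalFunctionDescriptor(
        name=name, restype=restype, argtypes=argtypes)
    typingctx.insert_user_function(extfn, device_function_template)
    targetctx.insert_user_function(extfn, fndesc)

    return device_function_template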

Is this the right direction?