
Consider this Fortran module in the file test.f90:

module mymod
  use iso_c_binding, only: c_double
  implicit none
contains
  subroutine addstuff(a,b,c) bind(c,name='addstuff_wrap')
    real(c_double), intent(in) :: a, b
    real(c_double), intent(out) :: c
    c = a + b
  end subroutine
end module

It can be compiled with gfortran test.f90 -shared -fPIC -o test.so. I can call it from Python with:

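import ctypes as ct
# './' makes dlopen look in the current directory rather than only the system search path
mylib = ct.CDLL('./test.so')
# the symbol is exported under its bind(c) name; every argument is passed by reference
addstuff = mylib.addstuff_wrap
addstuff.argtypes = [ct.POINTER(ct.c_double), ct.POINTER(ct.c_double), ct.POINTER(ct.c_double)]
addstuff.restype = None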

a = ct.c_double(1.0)
b = ct.c_double(2.0)
c = ct.c_double()

addstuff(ct.byref(a),ct.byref(b),ct.byref(c))
print(c.value)

This prints the correct answer, 3.0. However, I want to call this from a Numba-jitted function:

from numba import njit
@njit
def test(a, b):
    c = ct.c_double()
    addstuff(ct.byref(ct.c_double(a)), \
             ct.byref(ct.c_double(b)), \
             ct.byref(c))
    return c.value
test(1.0, 2.0)

But this doesn't work. It raises the error:

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Unknown attribute 'c_double' of type Module(<module 'ctypes' from ...)

File "<ipython-input-14-f8fb94981395>", line 3:
def test(a, b):
    c = ct.c_double()
    ^

Does anyone know a workaround? This is frustrating because Numba does claim to support the c_double type.

nicholaswogan
  • Numba doesn't support byref on double values; you can only pass pointers to (heap-allocated) arrays. I guess the easiest and most performant way would be to allocate an array on the stack and pass a pointer to it: https://stackoverflow.com/a/59538114/4045774 – max9111 May 06 '21 at 07:30
  • In the answer I gave below, if `a` was an array, would it be heap or stack allocated? – nicholaswogan May 08 '21 at 21:16

1 Answer


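Here is a solution: instead of creating ctypes objects inside the jitted function, store each value in a 0-d NumPy array, pass the address of its data buffer (arr.ctypes.data), and declare the arguments as c_void_p: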

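import ctypes as ct
import numpy as np
from numba import njit

mylib = ct.CDLL('./test.so')
addstuff = mylib.addstuff_wrap
# raw addresses (arr.ctypes.data) are passed instead of ctypes objects
addstuff.argtypes = [ct.c_void_p, ct.c_void_p, ct.c_void_p]
addstuff.restype = None

@njit
def test(a, b):
    # 0-d float64 arrays hold the by-reference arguments
    aa = np.array(a, np.float64)
    bb = np.array(b, np.float64)
    c = np.array(0.0, np.float64)
    addstuff(aa.ctypes.data, bb.ctypes.data, c.ctypes.data)
    return c.item()

print(test(1.0, 2.0))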
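Checking the answer requires gfortran and Numba. As a rough sketch, the pointer-passing convention it relies on can be exercised in plain Python by standing in for the compiled Fortran routine with a ctypes callback. The callback `_fake_addstuff` and the helper `add_via_pointers` are hypothetical names, not part of the original code. The sketch assumes that a bind(c) argument without the `value` attribute arrives as a `double *`, so each argument is the address of a NumPy float64 buffer. It does not show whether Numba accepts this particular callback inside `@njit`.

```python
import ctypes as ct
import numpy as np

# Hypothetical stand-in for addstuff_wrap from test.so: a ctypes callback with
# the assumed C signature void addstuff_wrap(double *a, double *b, double *c),
# declared with c_void_p arguments like in the answer.
ADDSTUFF_T = ct.CFUNCTYPE(None, ct.c_void_p, ct.c_void_p, ct.c_void_p)

def _fake_addstuff(a_addr, b_addr, c_addr):
    # Each c_void_p argument arrives as an integer address; view it as double*.
    a = ct.cast(a_addr, ct.POINTER(ct.c_double))
    b = ct.cast(b_addr, ct.POINTER(ct.c_double))
    c = ct.cast(c_addr, ct.POINTER(ct.c_double))
    c[0] = a[0] + b[0]  # write the result through the output pointer

addstuff = ADDSTUFF_T(_fake_addstuff)

def add_via_pointers(a, b):
    # Same pattern as the answer, without @njit: 0-d float64 arrays act as
    # by-reference scalars, and .ctypes.data is the address of each buffer.
    aa = np.array(a, np.float64)
    bb = np.array(b, np.float64)
    c = np.array(0.0, np.float64)
    addstuff(aa.ctypes.data, bb.ctypes.data, c.ctypes.data)
    return c.item()

print(add_via_pointers(1.0, 2.0))  # 3.0
```

Because the output array c is written through its address, its value is only valid while c is alive. This is why the answer keeps c as a local array and reads it back with c.item() before returning.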
nicholaswogan