0

I'm trying to use bindings generated for cuBLAS using bindgen. Here's what my code looks like:

// The question's original test, as posted. It does not compile as-is: the
// `... some stuff` lines are placeholders for elided code.
mod tests {
    use super::*;

    #[test]
    pub fn alpha () {
        // Declared but never initialized; meant to be filled in by `cublasCreate_v2`.
        let mut handle: cublasHandle_t;
        let mut stat: cublasStatus_t;
        let mut cudaStat: cudaError_t;

       ... some stuff

        unsafe {
            cudaStat = cudaMalloc(a.as_mut_ptr() as *mut *mut c_void, a.len() as u64);
            cudaStat = cudaMalloc(b.as_mut_ptr() as *mut *mut c_void, b.len() as u64);
            cudaStat = cudaMalloc(c.as_mut_ptr() as *mut *mut c_void, c.len() as u64);

            // NOTE(review): this `as` cast parses fine. The quoted parse error
            // ("expected expression, found keyword `mut`" at column 37) appears
            // to come from an attempt like `cublasCreate_v2(*mut handle)`, where
            // `*` is read as a dereference. This line has two other problems:
            // it reads `handle` before it is initialized, and it passes the
            // handle's *value* rather than the *address of* `handle`, which is
            // what `*mut cublasHandle_t` asks for. Initialize `handle` and pass
            // `&mut handle` instead.
            stat = cublasCreate_v2(handle as *mut *mut cublasContext);
        }

        ...some stuff
    }
}

I get this error:

error: expected expression, found keyword `mut`
  --> src/lib.rs:44:37
   |
44 |             stat = cublasCreate_v2(handle as *mut *mut cublasContext);
   |                                     ^^^ expected expression

error: could not compile `cublas-rs` due to previous error

NOTE: cublasHandle_t is a typedef for *mut cublasContext.

I've tried doing just &handle, *mut handle, etc., but no dice.

cublasHandle_t is only supposed to be initialized by cublasCreate_v2.

Here's what things look like in bindings.rs:

// Opaque cuBLAS context as emitted by bindgen. Its only field is a private,
// zero-length `_unused` array, so the struct carries no data and cannot be
// built with a struct literal outside this module. Rust code is expected to
// handle it only behind a pointer (`cublasHandle_t`). Note that
// `cublasCreate_v2` does not take a context value; it takes a pointer to a
// handle.
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub struct cublasContext {
    _unused: [u8; 0],
}

// cuBLAS handle: a raw pointer to the opaque `cublasContext`. Its value is
// produced by `cublasCreate_v2`. A typical pattern is to start from
// `std::ptr::null_mut()` and pass `&mut handle` to that function.
pub type cublasHandle_t = *mut cublasContext;

// Creates a cuBLAS handle. The parameter is a pointer *to* a `cublasHandle_t`
// (i.e. `*mut *mut cublasContext`), and cuBLAS writes the new handle through it.
// Callers therefore pass `&mut handle`, which coerces to `*mut cublasHandle_t`.
// The returned `cublasStatus_t` should be checked before the handle is used.
extern "C" {
    pub fn cublasCreate_v2(handle: *mut cublasHandle_t) -> cublasStatus_t;
}

I've tried initializing it like this:

// `*mut T` is a type, not an expression. The parser reads `*` as a dereference
// and then hits `mut` ("expected expression, found keyword `mut`").
let mut handle: cublasHandle_t = *mut cublasContext { _unused: [] }; // no luck
// `as` cannot convert a struct value into a pointer (a non-primitive cast).
// A pointer needs an existing place to point at, e.g. `&mut some_context`.
let mut handle: cublasHandle_t = cublasContext { _unused: [] } as *mut cublasContext; // no

How do I call a function like this?

  • Look closely at what `cublasCreate_v2` takes as an argument, and what type `handle` is. They aren't compatible, and forcing it via a cast won't work. – Colonel Thirty Two Sep 02 '22 at 01:09
  • @ColonelThirtyTwo right, I was trying stuff. I don't know how to get *mut cublasHandle_t. If it's not compatible, then what is? – Rylan Yancey Sep 02 '22 at 01:22
  • You say you're using bindgen. Is `bindings.rs` generated with bindgen? Did you add those comments after the fact? – PitaJ Sep 02 '22 at 16:09

1 Answer

0

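You probably just want to initialize `handle` to a null pointer, since `cublasCreate_v2` is what creates the handle. Then pass `&mut handle`: a `&mut cublasHandle_t` coerces to the `*mut cublasHandle_t` that `cublasCreate_v2` expects, and cuBLAS writes the new handle into your variable.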

// Corrected test: `handle` starts out null and `cublasCreate_v2` fills it in.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    pub fn alpha() {
        // `cublasCreate_v2` is what creates the handle, so start from a null
        // pointer instead of leaving the binding uninitialized.
        let mut handle: cublasHandle_t = std::ptr::null_mut();
        let mut stat: cublasStatus_t;
        // Declared here because the `cudaMalloc` calls below assign to it.
        let mut cudaStat: cudaError_t;

        // ... some stuff

        unsafe {
            // NOTE(review): `cudaMalloc` writes the new *device* pointer through
            // its first argument and takes a size in *bytes*. Casting
            // `a.as_mut_ptr()` makes it write that pointer into `a`'s own host
            // buffer, and `a.len()` is an element count (bytes only if `a` holds
            // `u8`). The usual pattern is a separate
            // `let mut d_a: *mut c_void = std::ptr::null_mut();` passed as
            // `&mut d_a`, sized with `std::mem::size_of_val(&a[..])`. Confirm
            // against how `a`, `b` and `c` are declared.
            cudaStat = cudaMalloc(a.as_mut_ptr() as *mut *mut c_void, a.len() as u64);
            cudaStat = cudaMalloc(b.as_mut_ptr() as *mut *mut c_void, b.len() as u64);
            cudaStat = cudaMalloc(c.as_mut_ptr() as *mut *mut c_void, c.len() as u64);

            // `&mut handle` is a `&mut cublasHandle_t`, which coerces to the
            // `*mut cublasHandle_t` that `cublasCreate_v2` expects; cuBLAS
            // writes the new handle into `handle`.
            stat = cublasCreate_v2(&mut handle);
        }

        // NOTE(review): check `stat` against CUBLAS_STATUS_SUCCESS before using
        // `handle`. The exact Rust spelling depends on bindgen's enum settings.

        // ...some stuff
    }
}

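Alternatively, you can create a context value first and point `handle` at it, as below. This isn't actually required: `cublasCreate_v2` overwrites `handle` with its own pointer, so cuBLAS never uses the placeholder context. It does show how to make a pointer, though. To create a pointer to a value, you use `&mut value`, similar to `&value` in C, and a `&mut T` coerces to a `*mut T`.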

// Alternative test: `handle` starts out pointing at a local placeholder context.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    pub fn alpha() {
        // Placeholder value so `handle` has something to point at.
        // NOTE(review): `cublasCreate_v2` overwrites `handle` with its own
        // pointer, so cuBLAS never uses this `context`; `std::ptr::null_mut()`
        // works just as well. If creation fails, do not pass `handle` to other
        // cuBLAS calls, since it may still point at this zero-sized stack value.
        let mut context = cublasContext { _unused: [] };
        // `&mut context` (a `&mut cublasContext`) coerces to `*mut cublasContext`.
        let mut handle: cublasHandle_t = &mut context;
        let mut stat: cublasStatus_t;
        // Declared here because the `cudaMalloc` calls below assign to it.
        let mut cudaStat: cudaError_t;

        // ... some stuff

        unsafe {
            // NOTE(review): `cudaMalloc` writes the new device pointer through
            // its first argument and takes a size in bytes. Passing
            // `a.as_mut_ptr()` writes it into `a`'s host buffer, and `a.len()` is
            // an element count. Confirm against how `a`, `b` and `c` are declared.
            cudaStat = cudaMalloc(a.as_mut_ptr() as *mut *mut c_void, a.len() as u64);
            cudaStat = cudaMalloc(b.as_mut_ptr() as *mut *mut c_void, b.len() as u64);
            cudaStat = cudaMalloc(c.as_mut_ptr() as *mut *mut c_void, c.len() as u64);

            // Pass the address of `handle` (a `*mut cublasHandle_t`, i.e. a
            // double pointer), again via `&mut`.
            stat = cublasCreate_v2(&mut handle);
        }

        // NOTE(review): check `stat` against CUBLAS_STATUS_SUCCESS before using
        // `handle`. The exact Rust spelling depends on bindgen's enum settings.

        // ...some stuff
    }
}

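If you do construct a `cublasContext` this way, you might want it to implement `Default`, so you can write `cublasContext::default()` instead of spelling out `_unused`. Because `bindings.rs` is generated by bindgen, enable this with the builder option `.derive_default(true)` rather than adding `#[derive(Default)]` by hand, since manual edits are lost when the bindings are regenerated. With the null-pointer approach above, you never need to construct a `cublasContext` at all.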

PitaJ
  • 12,969
  • 6
  • 36
  • 55