
The question: Is it possible to write a function in Rust that accepts either float or complex ndarrays as inputs?

I'm new to Rust and come from Python/NumPy land, where float arrays and complex arrays play very nicely together. So when I write a function, I don't worry about whether one or several of the inputs are complex.

So, I want to write a function something like this:

```rust
use ndarray::{ArrayD, ArrayViewD};
use num_complex::Complex64;

fn example(a: f64, x: ArrayViewD<'_, Complex64>, y: ArrayViewD<'_, f64>) -> ArrayD<Complex64> {
    // scale the complex view by a real scalar, then add the real view element-wise
    &x * a + &y
}
```

But make the inputs generic. I'm guessing it should look something like this:

```rust
fn example<T: ???>(a: T, x: ArrayViewD<'_, T>, y: ArrayViewD<'_, T>) -> ArrayD<T> {
    &x * a + &y
}
```

But I'm not sure what traits to require.
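As a point of comparison, here is a minimal sketch of the same-type case on plain slices, assuming no `ndarray` at all (slices stand in for `ArrayViewD`, and the name `example_slices` is hypothetical). It shows which `std::ops` bounds the element type needs when every input and the output share one type `T`.

```rust
use std::ops::{Add, Mul};

/// Hypothetical slice-based analogue of `example`: computes `x[i] * a + y[i]`.
/// Requires `T * T -> T` and `T + T -> T`; `Copy` lets us read elements by value.
fn example_slices<T>(a: T, x: &[T], y: &[T]) -> Vec<T>
where
    T: Copy + Mul<Output = T> + Add<Output = T>,
{
    assert_eq!(x.len(), y.len(), "x and y must have the same length");
    // Pair up elements of x and y and apply the scale-and-add to each pair.
    x.iter().zip(y).map(|(&xi, &yi)| xi * a + yi).collect()
}
```

With `ndarray` views instead of slices, the array-by-scalar product `&x * a` additionally requires the scalar type to implement `ndarray::ScalarOperand` (implemented for the primitive numeric types and for `Complex<f32>`/`Complex<f64>`), while `num_traits::Num` already implies `Mul<Output = Self>` and `Add<Output = Self>`. So a bound along the lines of `T: Num + ScalarOperand` may be closer to what is needed than the `Mul<Output = T>` the compiler suggests; treat that as an assumption to check against the `ndarray` version in use.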

I suppose the brute-force method would be to coerce everything to complex. That could work, but it requires a bunch of conversions, doubles the memory, and generally just does not feel right.
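To make the cost of that brute-force route concrete, here is a small std-only sketch. `C64` is a hypothetical stand-in for `num_complex::Complex64` (two `f64` fields, `#[repr(C)]`), and `promote_to_complex` and `buffer_bytes` are illustrative helpers, not library functions.

```rust
use std::mem::size_of;

/// Hypothetical stand-in for `num_complex::Complex64`: a real and an imaginary `f64`.
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq)]
struct C64 {
    re: f64,
    im: f64,
}

/// Brute-force promotion: copy every real value into a new complex buffer
/// with a zero imaginary part (one full extra allocation).
fn promote_to_complex(y: &[f64]) -> Vec<C64> {
    y.iter().map(|&re| C64 { re, im: 0.0 }).collect()
}

/// Bytes needed to store `n` elements as real vs. as complex values.
fn buffer_bytes(n: usize) -> (usize, usize) {
    (n * size_of::<f64>(), n * size_of::<C64>())
}
```

For 1000 elements, `buffer_bytes(1000)` gives `(8000, 16000)`: the promoted copy takes twice the space of the real input, on top of the input itself.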

My example above is a modification of a PyO3 example: https://github.com/PyO3/rust-numpy/blob/main/examples/simple-extension/src/lib.rs


Edit

I'm still missing something. If I try:

```rust
fn example<T: Num>(a: T, x: ArrayViewD<'_, T>, y: ArrayViewD<'_, T>) -> ArrayD<T> {
    &x * a + &y
}
```

I get:

```text
error[E0369]: cannot multiply `&ArrayBase<ViewRepr<&T>, Dim<IxDynImpl>>` by `T`
  --> src\main.rs:26:8
   |
26 |     &x * a + &y
   |     -- ^ - T
   |     |
   |     &ArrayBase<ViewRepr<&T>, Dim<IxDynImpl>>
   |
help: consider further restricting this bound
   |
25 | fn example<T: Num + std::ops::Mul<Output = T>>(a: T, x: ArrayViewD<'_, T>, y: ArrayViewD<'_, T>) -> ArrayD<T> {
   |                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^

error[E0308]: mismatched types
  --> src\main.rs:26:14
   |
25 | fn example<T: Num>(a: T, x: ArrayViewD<'_, T>, y: ArrayViewD<'_, T>) -> ArrayD<T> {
   |            - this type parameter
26 |     &x * a + &y
   |              ^^ expected type parameter `T`, found reference
   |
   = note: expected type parameter `T`
                   found reference `&ArrayBase<ViewRepr<&T>, Dim<IxDynImpl>>`
```

So it wants me to add `std::ops::Mul<Output = T>`. But if I add it, I get a similar error, and it wants me to add it again (recursively?).

I think the issue is that my inputs `a`, `x`, and `y` can all have different types, which are not necessarily the same as the output type.
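If the inputs really do have different element types, one way to express that in Rust is to give each input its own type parameter and let the output type be whatever the multiplication produces. The sketch below is std-only and uses slices in place of `ndarray` views; `C64` is a hypothetical stand-in for `num_complex::Complex64` (which, as far as I know, also implements `Mul<f64>` and `Add<f64>`), and `mixed_axpy` is a hypothetical name.

```rust
use std::ops::{Add, Mul};

/// Hypothetical minimal complex type standing in for `num_complex::Complex64`.
#[derive(Clone, Copy, Debug, PartialEq)]
struct C64 {
    re: f64,
    im: f64,
}

// complex * real -> complex
impl Mul<f64> for C64 {
    type Output = C64;
    fn mul(self, rhs: f64) -> C64 {
        C64 { re: self.re * rhs, im: self.im * rhs }
    }
}

// complex + real -> complex (only the real part changes)
impl Add<f64> for C64 {
    type Output = C64;
    fn add(self, rhs: f64) -> C64 {
        C64 { re: self.re + rhs, im: self.im }
    }
}

/// Computes `x[i] * a + y[i]` where `a`, `x` and `y` may all have different types.
/// The output type `O` is whatever `X * A` yields, and `O + Y` must yield `O` again.
fn mixed_axpy<A, X, Y, O>(a: A, x: &[X], y: &[Y]) -> Vec<O>
where
    A: Copy,
    X: Copy + Mul<A, Output = O>,
    Y: Copy,
    O: Add<Y, Output = O>,
{
    assert_eq!(x.len(), y.len(), "x and y must have the same length");
    x.iter().zip(y).map(|(&xi, &yi)| xi * a + yi).collect()
}
```

The same function also covers the all-real case, e.g. `mixed_axpy(2.0_f64, &[1.0_f64, 2.0], &[0.5_f64, 0.5])`, because `f64 * f64` and `f64 + f64` both return `f64`. Separate type parameters avoid the recursive `Mul<Output = T>` bounds, at the cost of a longer `where` clause.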

CrispyDyne
  • Does this answer your question? [How can I create an is\_prime function that is generic over various integer types?](https://stackoverflow.com/questions/26810793/how-can-i-create-an-is-prime-function-that-is-generic-over-various-integer-types) – E_net4 May 12 '21 at 17:08
  • You would use the traits in the `num_traits` crate, such as [`Num`](https://docs.rs/num/0.4.0/num/traits/trait.Num.html). This problem is not specific to `ndarray`. – E_net4 May 12 '21 at 17:09
  • That is certainly helpful, but I am still missing something. The edit above shows my failed attempt with the `Num` trait. – CrispyDyne May 12 '21 at 17:56
  • They are accomplishing it in this linear algebra library, but I can't wrap my head around it: https://github.com/rust-ndarray/ndarray-linalg/blob/master/ndarray-linalg/tests/eig.rs – CrispyDyne May 12 '21 at 17:58
