The question: Is it possible to write a function in Rust that accepts either float or complex ndarrays as inputs?
I'm new to Rust and come from Python/NumPy land, where float arrays and complex arrays play very nicely together. So when I write a function there, I don't worry whether one or several of the inputs are complex.
So, I want to write a function something like this:
use ndarray::{ArrayD, ArrayViewD};
use num_complex::Complex64;
use f64;
fn example(a: f64, x: ArrayViewD<'_, Complex64>, y: ArrayViewD<'_, f64>) -> ArrayD<Complex64> {
    &x * a + &y
}
But I want to make the inputs generic. I'm guessing it should look something like this:
fn example<T: ???>(a: T, x: ArrayViewD<'_, T>, y: ArrayViewD<'_, T>) -> ArrayD<T> {
    &x * a + &y
}
But I'm not sure what traits to require.
I suppose the brute force method would be to coerce everything to complex? Could work, but requires a bunch of conversions, double the memory, and generally just does not feel right.
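To make the cost of that brute-force route concrete, here is a minimal sketch that uses only the standard library. It does not use ndarray or num_complex: `Cplx` is a hypothetical stand-in for `Complex64`, plain slices stand in for array views, and `promote` and `payload_bytes` are illustrative names.

```rust
/// Hypothetical stand-in for `num_complex::Complex64` (an assumption made
/// so the sketch compiles without external crates).
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Cplx {
    pub re: f64,
    pub im: f64,
}

impl From<f64> for Cplx {
    /// A real number becomes a complex number with zero imaginary part.
    fn from(re: f64) -> Self {
        Cplx { re, im: 0.0 }
    }
}

/// Promotes a real slice to complex. This allocates a new buffer and
/// converts every element, which is the cost the question worries about.
pub fn promote(xs: &[f64]) -> Vec<Cplx> {
    xs.iter().map(|&v| Cplx::from(v)).collect()
}

/// Number of bytes occupied by the elements of a slice.
pub fn payload_bytes<T>(xs: &[T]) -> usize {
    std::mem::size_of_val(xs)
}
```

Each promoted element stores two `f64` values, so `payload_bytes(&promote(&real))` comes out at twice `payload_bytes(&real)`, on top of the original real buffer that still exists.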
My example above is a modification of a PyO3 example: https://github.com/PyO3/rust-numpy/blob/main/examples/simple-extension/src/lib.rs
Edit
I'm still missing something. If I try:
fn example<T: Num>(a: T, x: ArrayViewD<'_, T>, y: ArrayViewD<'_, T>) -> ArrayD<T> {
    &x * a + &y
}
I get:
error[E0369]: cannot multiply `&ArrayBase<ViewRepr<&T>, Dim<IxDynImpl>>` by `T`
  --> src\main.rs:26:8
   |
26 |     &x * a + &y
   |     -- ^ - T
   |     |
   |     &ArrayBase<ViewRepr<&T>, Dim<IxDynImpl>>
   |
help: consider further restricting this bound
   |
25 | fn example<T: Num + std::ops::Mul<Output = T>>(a: T, x: ArrayViewD<'_, T>, y: ArrayViewD<'_, T>) -> ArrayD<T> {
   |                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^
error[E0308]: mismatched types
  --> src\main.rs:26:14
   |
25 | fn example<T: Num>(a: T, x: ArrayViewD<'_, T>, y: ArrayViewD<'_, T>) -> ArrayD<T> {
   |            - this type parameter
26 |     &x * a + &y
   |              ^^ expected type parameter `T`, found reference
   |
   = note: expected type parameter `T`
                   found reference `&ArrayBase<ViewRepr<&T>, Dim<IxDynImpl>>`
So it wants me to add std::ops::Mul<Output = T>. But if I add it, I get a similar error, and it wants me to add it again (recursively?).
I think the issue is that my inputs a, x, and y can all be different types, and none of them has to be the same as the output type.
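That guess can be explored without ndarray. The following is a minimal, standard-library-only sketch, not ndarray's actual API: `Cplx` is a hypothetical stand-in for `Complex64`, plain slices stand in for array views, and `scale_add` is an illustrative name. The scalar, the two inputs, and the output each get their own type parameter, and the bounds spell out only the operations that `x * a + y` performs.

```rust
use std::ops::{Add, Mul};

/// Hypothetical stand-in for `num_complex::Complex64`; only the two
/// mixed real/complex operations needed below are implemented.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Cplx {
    pub re: f64,
    pub im: f64,
}

// complex * real -> complex
impl Mul<f64> for Cplx {
    type Output = Cplx;
    fn mul(self, rhs: f64) -> Cplx {
        Cplx { re: self.re * rhs, im: self.im * rhs }
    }
}

// complex + real -> complex
impl Add<f64> for Cplx {
    type Output = Cplx;
    fn add(self, rhs: f64) -> Cplx {
        Cplx { re: self.re + rhs, im: self.im }
    }
}

/// Computes `x[i] * a + y[i]` element-wise.
/// `A`, `X`, and `Y` may all differ; `O` is whatever `X * A` produces,
/// and it must accept a `Y` on the right-hand side of `+`.
/// If the slices differ in length, the result has the shorter length.
pub fn scale_add<A, X, Y, O>(a: A, x: &[X], y: &[Y]) -> Vec<O>
where
    A: Copy,
    X: Copy + Mul<A, Output = O>,
    Y: Copy,
    O: Add<Y, Output = O>,
{
    x.iter().zip(y.iter()).map(|(&xi, &yi)| xi * a + yi).collect()
}
```

With this signature, calling `scale_add` with all-`f64` arguments returns a `Vec<f64>`, while passing `Cplx` values for `x` and `f64` values for `a` and `y` returns a `Vec<Cplx>`, with no up-front conversion of the real inputs. ndarray's operator implementations carry their own bounds on the element and scalar types, so the where-clause needed there may look different.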