I am trying to write a common interface for different types of matrices that provides a way to iterate mutably over their rows and modify them. I have the following matrix types:
struct NdArrayMatrix {
    matrix: Array2<f32>,
}
struct ByteMatrix<'a> {
    data: &'a mut [u8],
    rows: usize,
    cols: usize,
}
The first one is just a matrix stored in RAM, and the second is memory-mapped using the MMap library; for convenience, I omit those details. First, I made a trait so that both can be modified through the same interface:
trait ReadWrite {
    fn rw_read(&self, i: usize, j: usize) -> f32;
    fn rw_write(&mut self, i: usize, j: usize, val: f32);
}
Then I created a trait that produces a rayon::iter::IndexedParallelIterator from both of these:
trait Sliceable<'a> {
    type Output: IndexedParallelIterator;
    fn rows_par_iter(&'a mut self) -> Self::Output;
}
Up to this point everything works fine. But when I want to use these in a generic context, such as:
fn slice_and_write<'a, T>(matrix: T)
where T: Sliceable<'a>
{
    matrix.rows_par_iter()
        .map(|mut row| {
            row.rw_write(...);
        })
    ...
}
I run into problems. Obviously, row in this case is not known to implement ReadWrite, so no surprise there. So I tried to create an iterator trait based on IndexedParallelIterator:
trait RwIterator: IndexedParallelIterator {
    type Item: ReadWrite;
}
and modify Sliceable:
trait Sliceable<'a> {
    type Output: RwIterator;
    fn rows_par_iter(&'a mut self) -> Self::Output;
}
Compiling this, I get the error:
| row.rw_write(...);
| ^^^^^^^^ method not found in `<<T as Sliceable<'a>>::Output as ParallelIterator>::Item`
Which is, again, fairly obvious. I suspect that the map function only requires the ParallelIterator trait bound, so its closure receives ParallelIterator::Item (as the error message shows) and cannot take advantage of the separate Item declared in RwIterator.
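A std-only sketch of the same situation may make the suspicion concrete. It is not the rayon code: Iterator stands in for IndexedParallelIterator, a mutable slice stands in for an ndarray row view, ReadWrite is reduced to a column index, and VecMatrix, Row, rows_iter and fill_first_column are hypothetical names. Here the bound is placed on the iterator's own Item through a second associated type, rather than declaring a separate Item in a new trait, and the closure then sees a type that implements ReadWrite:

```rust
// Std-only analogue (assumption): Iterator replaces rayon's
// IndexedParallelIterator, and &mut [f32] replaces an ndarray row view.
trait ReadWrite {
    fn rw_read(&self, j: usize) -> f32;
    fn rw_write(&mut self, j: usize, val: f32);
}

impl ReadWrite for &mut [f32] {
    fn rw_read(&self, j: usize) -> f32 {
        self[j]
    }
    fn rw_write(&mut self, j: usize, val: f32) {
        self[j] = val;
    }
}

trait Sliceable<'a> {
    // Hypothetical associated type naming the row type.
    type Row: ReadWrite;
    // Constrain the iterator's own Item; no second `Item` is declared.
    type Output: Iterator<Item = Self::Row>;
    fn rows_iter(&'a mut self) -> Self::Output;
}

// Hypothetical row-major matrix; `cols` must be non-zero.
struct VecMatrix {
    data: Vec<f32>,
    cols: usize,
}

impl<'a> Sliceable<'a> for VecMatrix {
    type Row = &'a mut [f32];
    type Output = std::slice::ChunksMut<'a, f32>;
    fn rows_iter(&'a mut self) -> Self::Output {
        // Each chunk of `cols` elements is one row.
        self.data.chunks_mut(self.cols)
    }
}

// Generic code can call rw_write because Output::Item is known to be Row.
fn fill_first_column<'a, T: Sliceable<'a>>(matrix: &'a mut T, val: f32) {
    matrix.rows_iter().for_each(|mut row| row.rw_write(0, val));
}
```

Whether the same bound carries over unchanged to rayon (for example IndexedParallelIterator<Item = Self::Row> with Row: ReadWrite + Send) is an assumption I have not verified against the two matrix types above.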
My question is: is there any way around this problem, or an alternative way of doing this?
EDIT: Here is a minimal reproducible code example, only using one of the matrix structures.
use ndarray::Array2;
use rayon::prelude::*;
use ndarray::Axis;
use ndarray::parallel::Parallel;
use ndarray::Dim;
use ndarray::iter::AxisIterMut;
use rayon::iter::ParallelIterator;
use ndarray::ViewRepr;
use ndarray::ArrayBase;
struct NdArrayMatrix {
    matrix: Array2<f32>,
}
impl NdArrayMatrix {
    pub fn new() -> Self {
        let matrix = Array2::zeros((10, 10));
        Self {
            matrix,
        }
    }
}
trait ReadWrite {
    fn rw_read(&self, i: usize, j: usize) -> f32;
    fn rw_write(&mut self, i: usize, j: usize, val: f32);
}
impl ReadWrite for NdArrayMatrix {
    fn rw_read(&self, i: usize, j: usize) -> f32 {
        self.matrix[[i, j]]
    }
    fn rw_write(&mut self, i: usize, j: usize, val: f32) {
        self.matrix[[i, j]] = val;
    }
}
impl ReadWrite for ArrayBase<ViewRepr<&mut f32>, Dim<[usize; 1]>> {
    fn rw_read(&self, i: usize, j: usize) -> f32 {
        self[j]
    }
    fn rw_write(&mut self, i: usize, j: usize, val: f32) {
        self[j] = val;
    }
}
trait RwIterator: IndexedParallelIterator {
    type Item: ReadWrite;
}
impl<'a> RwIterator for Parallel<AxisIterMut<'a, f32, Dim<[usize; 1]>>> {
    type Item = ArrayBase<ViewRepr<&'a mut f32>, Dim<[usize; 1]>>;
}
trait Sliceable<'a> {
    type Output: RwIterator;
    fn rows_par_iter(&'a mut self) -> Self::Output;
}
impl<'a> Sliceable<'a> for NdArrayMatrix {
    type Output = Parallel<AxisIterMut<'a, f32, Dim<[usize; 1]>>>;
    fn rows_par_iter(&'a mut self) -> Self::Output {
        self.matrix
            .axis_iter_mut(Axis(0))
            .into_par_iter()
    }
}
fn main() {
    let mut matrix: NdArrayMatrix = NdArrayMatrix::new();
    test(matrix);
}
fn test<'a, T>(matrix: T)
where T: Sliceable<'a> + ReadWrite
{
    matrix.rows_par_iter()
        .map(|mut row| {
            row.rw_write(0, 0, 0.0);
        }).count();
}