
I am currently trying to speed up a Python scientific code with the help of PyO3. Until now, numba's @jit has been doing wonders, but some routines are impossible to parallelize with numba. I decided to turn to Rust and Rayon, which provide the tools to write an efficient parallel version of said routine. However, replacing the jitted sequential function with the parallel Rust code through PyO3 has slowed down the simulation.

As this Rust function is called repeatedly from the Python code, I was wondering whether the thread pool used by Rayon is spawned on every call, adding a large overhead to the function.

If that is the case, is there a way to keep the Rust runtime, and its pool, alive between calls from the Python code?
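
For what it's worth, Rayon's global thread pool is created lazily the first time a parallel iterator runs and is then held in a process-wide static. Since the compiled extension module stays loaded in the Python process, the pool should not be respawned on each call (`rayon::ThreadPoolBuilder::build_global` exists to configure that global pool once, before its first use). The std-only sketch below illustrates the same pattern: a hypothetical four-thread pool stored in a `OnceLock`, spawned on the first call and reused afterwards, with a `SPAWNED` counter to make the reuse visible. `Pool`, `pool`, `parallel_sum` and `SPAWNED` are illustrative names, not Rayon's API.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::mpsc::{channel, Sender};
use std::sync::{Arc, Mutex, OnceLock};
use std::thread;

// Number of worker threads ever spawned (used only to show reuse).
static SPAWNED: AtomicUsize = AtomicUsize::new(0);

type Job = Box<dyn FnOnce() + Send + 'static>;

// Hypothetical minimal pool: jobs are sent over a channel to long-lived workers.
struct Pool {
    sender: Mutex<Sender<Job>>,
}

// Process-wide pool, initialized on first use and kept for the process lifetime.
static POOL: OnceLock<Pool> = OnceLock::new();

fn pool() -> &'static Pool {
    POOL.get_or_init(|| {
        let (tx, rx) = channel::<Job>();
        let rx = Arc::new(Mutex::new(rx));
        for _ in 0..4 {
            let rx = Arc::clone(&rx);
            SPAWNED.fetch_add(1, Ordering::SeqCst);
            thread::spawn(move || loop {
                // The lock guard is dropped at the end of this statement,
                // so other workers can receive jobs while this one runs.
                let job = match rx.lock().unwrap().recv() {
                    Ok(job) => job,
                    Err(_) => break, // all senders dropped
                };
                job();
            });
        }
        Pool { sender: Mutex::new(tx) }
    })
}

// Sums `data` by splitting it into at most four chunks handled by the pool.
fn parallel_sum(data: &[f64]) -> f64 {
    let chunk_len = ((data.len() + 3) / 4).max(1);
    let (tx, rx) = channel::<f64>();
    let mut n_jobs = 0;
    for chunk in data.chunks(chunk_len) {
        // Jobs must be 'static, so each chunk is copied into an owned Vec.
        let chunk = chunk.to_vec();
        let tx = tx.clone();
        let job: Job = Box::new(move || {
            tx.send(chunk.iter().sum::<f64>()).unwrap();
        });
        pool().sender.lock().unwrap().send(job).unwrap();
        n_jobs += 1;
    }
    // Collect exactly one partial sum per job.
    rx.iter().take(n_jobs).sum()
}

fn main() {
    let data: Vec<f64> = (1..=100).map(|v| v as f64).collect();
    // Three calls, like repeated calls from Python...
    for _ in 0..3 {
        assert_eq!(parallel_sum(&data), 5050.0);
    }
    // ...but the workers were spawned only once, during the first call.
    assert_eq!(SPAWNED.load(Ordering::SeqCst), 4);
    println!("spawned {} workers for 3 calls", SPAWNED.load(Ordering::SeqCst));
}
```

The channel-based design is the simplest way to keep workers alive between calls with only the standard library; Rayon's own pool uses work stealing instead, but the "create once, reuse forever" lifetime is the relevant part here.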

Here is the Rust function code:

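```rust
use ndarray::Array1;
use rayon::prelude::*;

pub fn gather_n(positions: &[f64], n_cell: usize) -> Array1<f64> {
    let mut density = positions
        .par_iter()
        .fold(
            // Each fold split starts from its own zeroed array of n_cell cells.
            || Array1::zeros(n_cell),
            |mut cells, x| {
                let i = x.floor() as usize;
                // A particle on the last node has no right neighbour, so all of
                // its weight goes to the last cell (`i == n_cell` alone would let
                // x in [n_cell - 1, n_cell) index cells[n_cell] and panic).
                if i >= n_cell - 1 {
                    cells[n_cell - 1] += 1.;
                } else {
                    // Linear weighting between nodes i and i + 1.
                    let d = x - i as f64;
                    cells[i] += 1. - d;
                    cells[i + 1] += d;
                }
                cells
            },
        )
        // Merge the per-split arrays element-wise.
        .reduce(
            || Array1::zeros(n_cell),
            |mut a, b| {
                a += &b;
                a
            },
        );
    density[0] *= 2.;
    density[n_cell - 1] *= 2.;
    density
}
```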

and the PyO3 wrapper:

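```rust
use numpy::{IntoPyArray, PyArray1};
use pyo3::prelude::*;

#[pyfunction]
fn gather_n<'py>(py: Python<'py>, x: &PyArray1<f64>, n: usize) -> PyResult<&'py PyArray1<f64>> {
    // Borrow the NumPy buffer as a slice (errors if the array is not contiguous).
    let x_slice = unsafe { x.as_slice()? };
    Ok(gather::gather_n(x_slice, n).into_pyarray(py))
}
```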
Quettle
  • Well, it's difficult to tell without the actual code. – Netwave Sep 15 '22 at 07:04
  • How big is `n` / `n_cell` typically? – Finomnis Sep 15 '22 at 12:04
  • What is `Array1`? Where does it come from? – Finomnis Sep 15 '22 at 12:07
  • @Finomnis A few hundred to a few thousand. The more the better. `Array1` is a 1D array type (here of f64) from the `ndarray` crate. This crate is the Rust counterpart of NumPy, and it is what PyO3's `numpy` crate uses to interface NumPy arrays with Rust code. – Quettle Sep 15 '22 at 15:19
  • I'm not entirely sure what I should tell you, honestly... the code looks fine to me. Did you benchmark the speedup against the non-parallel version? How well does it scale? That's a good indicator of where the overhead may be. – Finomnis Sep 15 '22 at 15:26
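
Following the suggestion to benchmark against a non-parallel version, here is a sketch of a single-threaded baseline with the same weighting logic, written on a plain `Vec<f64>` so that it needs only the standard library. `gather_n_seq` and the synthetic input are hypothetical; the boundary check is written as `i >= n_cell - 1` on the assumption that positions lie in `[0, n_cell - 1]`, so a particle exactly on the last node does not index past the end.

```rust
use std::time::Instant;

// Sequential version of `gather_n` on a plain Vec<f64> (no ndarray/rayon),
// intended only as a single-threaded baseline for timing comparisons.
// Assumes positions lie in [0, n_cell - 1] and n_cell >= 2.
fn gather_n_seq(positions: &[f64], n_cell: usize) -> Vec<f64> {
    let mut cells = vec![0.0_f64; n_cell];
    for &x in positions {
        let i = x.floor() as usize;
        if i >= n_cell - 1 {
            // On the last node there is no right neighbour: all weight goes there.
            cells[n_cell - 1] += 1.0;
        } else {
            // Linear weighting between nodes i and i + 1.
            let d = x - i as f64;
            cells[i] += 1.0 - d;
            cells[i + 1] += d;
        }
    }
    // Same boundary doubling as the original function.
    cells[0] *= 2.0;
    cells[n_cell - 1] *= 2.0;
    cells
}

fn main() {
    let n_cell = 1_000;
    let n_particles = 1_000_000;
    // Hypothetical input: particles spread evenly over [0, n_cell - 1).
    let positions: Vec<f64> = (0..n_particles)
        .map(|k| k as f64 / n_particles as f64 * (n_cell - 1) as f64)
        .collect();

    let start = Instant::now();
    let density = gather_n_seq(&positions, n_cell);
    println!("sequential gather of {} particles: {:?}", n_particles, start.elapsed());
    assert_eq!(density.len(), n_cell);
}
```

If the Rayon version turns out slower than this baseline on the same input, one thing that may be worth checking is the `fold` identity: Rayon calls `|| Array1::zeros(n_cell)` once per fold split, so every split allocates, and `reduce` later merges, a full `n_cell` array. With a few thousand cells and many small splits, that could cost more than the deposition itself; `with_min_len` on the parallel iterator is one way to limit how finely the input is split.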
