
I am new to Rust. I want to write a function that can later be imported into Python as a module using the pyo3 crate.

Below is the Python implementation of the function I want to implement in Rust:

def pcompare(a, b):
    letters = []
    for i, letter in enumerate(a):
        if letter != b[i]:
            letters.append(f'{letter}{i + 1}{b[i]}')
    return letters

The first Rust implementation I wrote looks like this:

use pyo3::prelude::*;


#[pyfunction]
fn compare_strings_to_vec(a: &str, b: &str) -> PyResult<Vec<String>> {

    if a.len() != b.len() {
        panic!(
            "Reads are not the same length! 
            First string is length {} and second string is length {}.",
            a.len(), b.len());
    }

    let a_vec: Vec<char> = a.chars().collect();
    let b_vec: Vec<char> = b.chars().collect();

    let mut mismatched_chars = Vec::new();

    for (mut index,(i,j)) in a_vec.iter().zip(b_vec.iter()).enumerate() {
        if i != j {
            index += 1;
            let mutation = format!("{i}{index}{j}");
            mismatched_chars.push(mutation);
        } 

    }
    Ok(mismatched_chars)
}


#[pymodule]
fn compare_strings(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
    m.add_function(wrap_pyfunction!(compare_strings_to_vec, m)?)?;
    Ok(())
}

I built it in --release mode. The module could be imported into Python, but its performance was quite similar to that of the Python implementation.

My first question is: why are the Python and Rust functions similar in speed?

Now I am working on a parallelization implementation in Rust. When just printing the result variable, the function works:

use rayon::prelude::*;

fn main() {
    
    let a: Vec<char> = String::from("aaaa").chars().collect();
    let b: Vec<char> = String::from("aaab").chars().collect();
    let length = a.len();
    let index: Vec<_> = (1..=length).collect();
    
    let mut mismatched_chars: Vec<String> = Vec::new();
    
    (a, index, b).into_par_iter().for_each(|(x, i, y)| {
        if x != y {
            let mutation = format!("{}{}{}", x, i, y).to_string();
            println!("{mutation}");
            //mismatched_chars.push(mutation);
        }
    });
    
}

However, when I try to push the mutation variable to the mismatched_chars vector:

use rayon::prelude::*;

fn main() {
    
    let a: Vec<char> = String::from("aaaa").chars().collect();
    let b: Vec<char> = String::from("aaab").chars().collect();
    let length = a.len();
    let index: Vec<_> = (1..=length).collect();
    
    let mut mismatched_chars: Vec<String> = Vec::new();
    
    (a, index, b).into_par_iter().for_each(|(x, i, y)| {
        if x != y {
            let mutation = format!("{}{}{}", x, i, y).to_string();
            //println!("{mutation}");
            mismatched_chars.push(mutation);
        }
    });
    
}

I get the following error:

error[E0596]: cannot borrow `mismatched_chars` as mutable, as it is a captured variable in a `Fn` closure
  --> src/main.rs:16:13
   |
16 |             mismatched_chars.push(mutation);
   |             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ cannot borrow as mutable

For more information about this error, try `rustc --explain E0596`.
error: could not compile `testing_compare_strings` due to previous error

I tried A LOT of different things. When I do:

use rayon::prelude::*;

fn main() {
    
    let a: Vec<char> = String::from("aaaa").chars().collect();
    let b: Vec<char> = String::from("aaab").chars().collect();
    let length = a.len();
    let index: Vec<_> = (1..=length).collect();
    
    let mut mismatched_chars: Vec<&str> = Vec::new();
    
    (a, index, b).into_par_iter().for_each(|(x, i, y)| {
        if x != y {
            let mutation = format!("{}{}{}", x, i, y).to_string();
            mismatched_chars.push(&mutation);
        }
    });
    
}

The error becomes:

error[E0596]: cannot borrow `mismatched_chars` as mutable, as it is a captured variable in a `Fn` closure
  --> src/main.rs:16:13
   |
16 |             mismatched_chars.push(&mutation);
   |             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ cannot borrow as mutable

error[E0597]: `mutation` does not live long enough
  --> src/main.rs:16:35
   |
10 |     let mut mismatched_chars: Vec<&str> = Vec::new();
   |         -------------------- lifetime `'1` appears in the type of `mismatched_chars`
...
16 |             mismatched_chars.push(&mutation);
   |             ----------------------^^^^^^^^^-
   |             |                     |
   |             |                     borrowed value does not live long enough
   |             argument requires that `mutation` is borrowed for `'1`
17 |         }
   |         - `mutation` dropped here while still borrowed

I suspect that the solution is quite simple, but I cannot see it myself.

Chayim Friedman
  • "My first question is: Why is the Python and Rust function similar in speed?" probably because most of the workload is creating strings, and Python has some ability to cache / intern such strings, which Rust does not. And for the simpler cases (small / identical strings) the majority of the workload is going to be allocating the unnecessary `a_vec` and `b_vec`. – Masklinn Sep 22 '22 at 09:03
  • Please explain why `a_vec` and `b_vec` are unnecessary. – William Rosenbaum Sep 22 '22 at 09:08
  • Because `zip` works on iterators, and `str::chars` returns an iterator, you can just zip the two `chars` iterators directly. – Masklinn Sep 22 '22 at 09:15
  • Given the simplicity of `pcompare` / `compare_strings_to_vec`, the vast majority of runtime will most likely be spent in interpreter overhead, unless the strings in question are very long (many megabytes). – user2722968 Sep 22 '22 at 09:16
  • Also, you should not panic from pyo3: a panic surfaces in Python as a `BaseException` subclass (`PanicException`). To raise an `Exception`, you need to return an `Err`. – Masklinn Sep 22 '22 at 09:19
  • And concurrency is useful when you have a fair amount to do; here I would expect synchronisation overhead to be way above the small amount of work per iteration. Not to mention the conversion between Rust and Python types. You might actually see some gain by creating and using the Python types directly, even if they're somewhat less convenient than the Rust ones: here Rust has to decode the Python strings to Rust strings on call, then it has to convert the Vec of Rust strings to a list of Python strings on output. – Masklinn Sep 22 '22 at 09:21

2 Answers


You have the right idea with what you are doing, but you will want to use an iterator chain with filter and map to remove items or convert them into different values. Rayon also provides a collect method, similar to the one on regular iterators, that converts the items into any type implementing FromParallelIterator (such as Vec<T>).

use rayon::prelude::*;

fn compare_strings_to_vec(a: &str, b: &str) -> Vec<String> {
    // Same as with the if statement, but just a little shorter to write.
    // Plus, it will print out the two values it is comparing if it fails.
    assert_eq!(a.len(), b.len(), "Reads are not the same length!");
    
    // Zip the character iterators from a and b together
    a.chars().zip(b.chars())
        // Iterate with the index of each item
        .enumerate()
        // Rayon function which turns a regular iterator into a parallel one.
        // Note: par_bridge does not preserve the original order, so the
        // mismatches may come out in a different order than in the input.
        .par_bridge()
        // Filter out values where the characters are the same
        .filter(|(_, (a, b))| a != b)
        // Convert the remaining values into a mutation string
        .map(|(index, (a, b))| {
            format!("{}{}{}", a, index + 1, b)
        })
        // Turn the items of this iterator into a Vec (or any other FromParallelIterator type).
        .collect()
}


Optimizing for speed

On the other hand, if you want speed, we need to approach this problem from a different direction. As you may have noticed, the rayon version is quite slow, since the cost of spawning threads and using concurrency structures is orders of magnitude higher than simply comparing the bytes on the original thread. In my benchmarks, I found that even with better workload distribution, additional threads only helped on my machine (64 GB RAM, 16 cores) when the strings were at least 1-2 million bytes long. Given that you have stated they are typically ~30,000 bytes long, I think using rayon (or really any other threading for comparisons of this size) will only slow down your code.

Using criterion for benchmarking, I eventually came to this implementation. It generally takes about 2.8156 µs per run on strings of 30,000 characters with 10 differing bytes. For comparison, the code posted in the original question usually takes around 61.156 µs on my system under the same conditions, so this should give a ~20x speedup. It can vary a bit, but it consistently got the best results in the benchmark. I'm guessing this should be fast enough that this step is no longer the bottleneck in your code.

The key focus of this implementation is to do the comparisons in batches. We can take advantage of the 128-bit registers on most CPUs to compare the input in 16-byte batches. When an inequality is found, the 16-byte section it covers is rescanned for the exact position of the discrepancy. This gives a decent boost to performance. I initially thought that a usize would work better, but that does not seem to be the case. I also attempted to use the portable_simd nightly feature to write a SIMD version of this code, but I was unable to match the speed of this code. I suspect this was due either to missed optimizations or to my lack of experience using SIMD effectively.

I was worried about drops in speed due to the alignment of chunks not being enforced for u128 values, but it seems to mostly be a non-issue. First of all, it is generally quite difficult to find allocators that are willing to allocate at an address which is not a multiple of the system word size. Of course, this is due to practicality rather than any actual requirement. When I manually gave it unaligned slices (unaligned for u128s), performance was not significantly affected. This is why I do not attempt to enforce that the start index of the slice be aligned to align_of::<u128>().
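To illustrate the alignment point with a standalone example (separate from the implementation itself): u128::from_ne_bytes takes a [u8; 16] by value, so the 16 bytes are copied out of the slice and no aligned pointer is ever required. The helper below, load_block, is a hypothetical name used only for this sketch; it reads a block at deliberately odd offsets.

```rust
use std::convert::TryInto;

/// Hypothetical helper: read 16 bytes starting at `offset` as a u128.
/// `from_ne_bytes` copies the bytes by value, so `offset` need not be a
/// multiple of 16 (or of align_of::<u128>()).
fn load_block(bytes: &[u8], offset: usize) -> u128 {
    let block: [u8; 16] = bytes[offset..offset + 16].try_into().unwrap();
    u128::from_ne_bytes(block)
}

fn main() {
    // 21 bytes each; they differ only in the last byte (index 20).
    let a = b"xACGTACGTACGTACGTACGT";
    let b = b"xACGTACGTACGTACGTACGA";

    // Bytes 1..17 are identical, so the blocks compare equal.
    assert_eq!(load_block(a, 1), load_block(b, 1));
    // Bytes 5..21 include index 20, so the blocks differ.
    assert_ne!(load_block(a, 5), load_block(b, 5));
    println!("unaligned block comparisons behave as expected");
}
```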

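use std::mem::size_of;

fn compare_strings_to_vec(a: &str, b: &str) -> Vec<String> {
    let a_bytes = a.as_bytes();
    let b_bytes = b.as_bytes();
    // The block and remainder arithmetic below assumes equal-length reads.
    assert_eq!(a_bytes.len(), b_bytes.len(), "Reads are not the same length!");
    let remainder = a_bytes.len() % size_of::<u128>();

    // Strongly suggest to the compiler we are iterating through u128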
    a_bytes
        .chunks_exact(size_of::<u128>())
        .zip(b_bytes.chunks_exact(size_of::<u128>()))
        .enumerate()
        .filter(|(_, (a, b))| {
            let a_block: &[u8; 16] = (*a).try_into().unwrap();
            let b_block: &[u8; 16] = (*b).try_into().unwrap();

            u128::from_ne_bytes(*a_block) != u128::from_ne_bytes(*b_block)
        })
        .flat_map(|(word_index, (a, b))| {
            fast_path(a, b).map(move |x| word_index * size_of::<u128>() + x)
        })
        .chain(
            fast_path(
                &a_bytes[a_bytes.len() - remainder..],
                &b_bytes[b_bytes.len() - remainder..],
            )
            .map(|x| a_bytes.len() - remainder + x),
        )
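        .map(|index| {
            // char::from(u8) maps each byte straight to a char, which is only
            // correct for single-byte (ASCII) characters such as A/C/G/T.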
            format!(
                "{}{}{}",
                char::from(a_bytes[index]),
                index + 1,
                char::from(b_bytes[index])
            )
        })
        .collect()
}

/// Very similar to the regular route, but with nothing fancy: just get the indices of the mismatching bytes
#[inline(always)]
fn fast_path<'a>(a: &'a [u8], b: &'a [u8]) -> impl 'a + Iterator<Item = usize> {
    a.iter()
        .zip(b.iter())
        .enumerate()
        .filter_map(|(x, (a, b))| (a != b).then_some(x))
}
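The byte-based version above treats every byte as one character (char::from(u8)), which matches the A/C/G/T reads described in the comments. For other input, byte offsets and character positions diverge, because str::len counts UTF-8 bytes. A small sketch of a guard, assuming you want to reject non-ASCII reads up front; ensure_ascii_reads is a hypothetical name, not part of the code above.

```rust
/// Hypothetical guard: the byte-based comparison is only valid when every
/// character is a single byte, i.e. when both reads are pure ASCII.
fn ensure_ascii_reads(a: &str, b: &str) -> Result<(), String> {
    if !a.is_ascii() || !b.is_ascii() {
        return Err("Reads must contain only ASCII characters".to_string());
    }
    if a.len() != b.len() {
        return Err(format!(
            "Reads have different lengths: {} vs {}",
            a.len(),
            b.len()
        ));
    }
    Ok(())
}

fn main() {
    // "é" is two bytes in UTF-8, so byte length and character count differ.
    let s = "AéC";
    assert_eq!(s.len(), 4);
    assert_eq!(s.chars().count(), 3);

    assert!(ensure_ascii_reads("ACGT", "ACGA").is_ok());
    assert!(ensure_ascii_reads("AéC", "AGCT").is_err());
    assert!(ensure_ascii_reads("ACG", "ACGT").is_err());
}
```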
Locke
  • @WilliamRosenbaum Do you know how long these strings are normally and how frequently you expect to find inconsistencies? You could likely get even greater speeds from different approaches. For large inputs, it would likely be more efficient to break input strings into blocks so each thread can do more work between dealing with other threads. It seems likely that this will be limited by how fast a single thread can create iterator items to give to worker threads after just a couple of cores are used. Additionally, if you don't expect many inconsistencies, you could compare words of memory. – Locke Sep 22 '22 at 11:00
  • They are around 30 000 characters long and I have a lot of them. The frequency in which a difference is found is approximately 10 per sequence. The sequences are stored in a csv file and I would be happy to write the resulting list to a csv file instead of storing it in memory. Thank you so much for your help! – William Rosenbaum Sep 22 '22 at 12:26
  • Is it also safe to assume that characters will only be an ascii byte? In your example you use `str::len()`, but iterate using `.chars()`. Since Rust uses UTF-8 to store strings, a character may be multiple bytes long leading to the byte length (`.len()`) differing from the character count. – Locke Sep 22 '22 at 13:20
  • The characters will be only "A", "C", "G" or "T". Read from a plain .txt file. – William Rosenbaum Sep 22 '22 at 13:46
  • @WilliamRosenbaum I had some free time, so I put together a faster implementation (added to the answer). From my benchmarks it should run around 15 to 20x faster than your initial implementation. I also found that it was far more expensive to spin up a second thread than it is to simply handle the function in a single-threaded implementation on my system. I only began to see marginal improvements on my 16-core laptop with 1,000,000-character strings, so unless your data includes many longer entries, I doubt you will see a performance boost from using `rayon` in this way. – Locke Sep 22 '22 at 23:50
  • @Locke I think Rayon defaults to keeping a global threadpool, so after the first call you should be able to benefit from it without the need to spin up new threads every time. Though I still doubt the parallelism can do much good on payloads of ~30k; I feel synchronisation overhead would still dominate. – Masklinn Sep 23 '22 at 06:02
  • @Masklinn, I was thinking the same thing. I used `criterion` for my benchmark so it would have started the threadpool during the 3 second warmup period before the benchmark samples were collected. Perhaps some time could have been shaved off by manually implementing some dedicated worker threads that continuously process requests via channels. This kind of use case seems perfect for map-reduce so if jobs had been batched together we could probably see greater speedups. To truly get the best performance though, we would need to know what the data CSV looks like to optimize it with IO in mind. – Locke Sep 23 '22 at 18:05
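The comments above suggest giving each thread a large block of work rather than one character at a time. A minimal sketch of that idea, using the standard library's std::thread::scope (Rust 1.63+) instead of rayon: the reads are split into one contiguous block per thread, each thread returns its own Vec<String>, and the blocks are concatenated in order, so the output order matches the sequential version. The function names and the n_threads parameter are assumptions for illustration, and it assumes ASCII reads of equal length; per the benchmarks discussed above, this is only likely to pay off for very long reads.

```rust
use std::thread;

/// Mismatches in one block, reported with 1-based positions relative to the
/// full read (`offset` is where this block starts).
fn mismatches_in_block(a: &[u8], b: &[u8], offset: usize) -> Vec<String> {
    a.iter()
        .zip(b)
        .enumerate()
        .filter(|(_, (x, y))| x != y)
        .map(|(i, (&x, &y))| format!("{}{}{}", x as char, offset + i + 1, y as char))
        .collect()
}

/// Hypothetical parallel version: one contiguous block per thread.
/// Assumes ASCII input of equal length (one byte per character).
fn compare_blocks_parallel(a: &str, b: &str, n_threads: usize) -> Vec<String> {
    assert!(n_threads > 0, "need at least one thread");
    assert_eq!(a.len(), b.len(), "Reads are not the same length!");
    let (a, b) = (a.as_bytes(), b.as_bytes());
    // Round up so that n_threads blocks cover the whole read; never use size 0.
    let block = ((a.len() + n_threads - 1) / n_threads).max(1);

    thread::scope(|s| {
        // One scoped thread per block; scoped threads may borrow `a` and `b`.
        let handles: Vec<_> = a
            .chunks(block)
            .zip(b.chunks(block))
            .enumerate()
            .map(|(k, (ca, cb))| s.spawn(move || mismatches_in_block(ca, cb, k * block)))
            .collect();
        // Join in spawn order, so the result keeps the original positional order.
        handles
            .into_iter()
            .flat_map(|h| h.join().unwrap())
            .collect()
    })
}

fn main() {
    let result = compare_blocks_parallel("ACGTACGTAC", "ACGAACGTCC", 3);
    println!("{:?}", result); // ["T4A", "A9C"]
}
```

Joining the handles in spawn order is what keeps the output in positional order without any locking: each thread only ever writes to its own Vec.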

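You cannot mutate mismatched_chars directly from inside the closure: rayon may run that closure on several threads at once, so for_each only accepts a Fn closure, and a Fn closure cannot mutably borrow a captured variable.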

Instead, you can wrap the vector in an Arc<RwLock<...>> so that each thread takes the write lock before pushing to it.

use rayon::prelude::*;
use std::sync::{Arc, RwLock};

fn main() {
    let a: Vec<char> = String::from("aaaa").chars().collect();
    let b: Vec<char> = String::from("aaab").chars().collect();
    let length = a.len();
    let index: Vec<_> = (1..=length).collect();

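    // Shared, lock-protected vector. Every worker takes the write lock before
    // pushing, so pushes are serialized and their order is not deterministic.
    let mismatched_chars: Arc<RwLock<Vec<String>>> = Arc::new(RwLock::new(Vec::new()));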

    (a, index, b).into_par_iter().for_each(|(x, i, y)| {
        if x != y {
            let mutation = format!("{}{}{}", x, i, y);
            mismatched_chars
                .write()
                .expect("could not acquire write lock")
                .push(mutation);
        }
    });

    for mismatch in mismatched_chars
        .read()
        .expect("could not acquire read lock")
        .iter()
    {
        eprintln!("{}", mismatch);
    }
}
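A variation on the same locking idea, sketched with the standard library's scoped threads instead of rayon so it has no dependencies. Because std::thread::scope only needs the spawned closures to borrow shared state, a plain Mutex<Vec<...>> is enough and no Arc is required; rayon's for_each similarly only requires its closure to be Send + Sync, so borrowing a Mutex would also work there. As with the RwLock version, the push order depends on thread scheduling, so this sketch stores each mismatch with its position and sorts afterwards. The function name locked_mismatches and the two-thread split are arbitrary choices for illustration.

```rust
use std::sync::Mutex;
use std::thread;

/// Collect mismatches from several threads into one lock-protected Vec.
/// Each entry keeps its 1-based position so the order can be restored later.
fn locked_mismatches(a: &[char], b: &[char]) -> Vec<String> {
    assert_eq!(a.len(), b.len(), "Reads are not the same length!");
    // No Arc: scoped threads may borrow this Mutex directly.
    let found: Mutex<Vec<(usize, String)>> = Mutex::new(Vec::new());
    let mid = a.len() / 2;

    thread::scope(|s| {
        // Two halves, one per thread (an arbitrary split for illustration).
        for (start, end) in [(0, mid), (mid, a.len())] {
            let found = &found;
            s.spawn(move || {
                for i in start..end {
                    if a[i] != b[i] {
                        let mutation = format!("{}{}{}", a[i], i + 1, b[i]);
                        found.lock().expect("mutex poisoned").push((i + 1, mutation));
                    }
                }
            });
        }
    }); // all scoped threads are joined here

    // Push order depends on scheduling; sort by position to restore it.
    let mut found = found.into_inner().expect("mutex poisoned");
    found.sort_by_key(|&(pos, _)| pos);
    found.into_iter().map(|(_, m)| m).collect()
}

fn main() {
    let a: Vec<char> = "aaaa".chars().collect();
    let b: Vec<char> = "abab".chars().collect();
    println!("{:?}", locked_mismatches(&a, &b)); // ["a2b", "a4b"]
}
```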
AlexN