
I currently need to intersect two sorted vectors.

Here are my implementations:

```cpp
#include <immintrin.h>

#include <algorithm>
#include <bit>
#include <chrono>
#include <cstdint>
#include <iostream>
#include <iterator>
#include <vector>

using namespace std;

// Reference version: the standard library's merge-based intersection.
void square(vector<uint64_t> &res, vector<uint64_t> &a, vector<uint64_t> &b) {
  set_intersection(a.begin(), a.end(), b.begin(), b.end(), back_inserter(res));
}

// AVX-512 version: broadcast each element of a and compare it against
// all of b, 8 elements at a time (b.size() must be a multiple of 8).
void super(vector<uint64_t> &res, vector<uint64_t> &a, vector<uint64_t> &b) {
  // The values are unique, so the result can't exceed the smaller input.
  // Sizing res here avoids writing past the end of an empty vector.
  res.resize(min(a.size(), b.size()));
  size_t offset = 0;

  for (auto it = a.begin(), it_end = a.end(); it != it_end; ++it) {
    auto r = _mm512_set1_epi64((long long)*it);

    for (size_t j = 0; j < b.size(); j += 8) {
      auto s = _mm512_loadu_epi64((const void *)(b.data() + j));

      auto m = _mm512_cmpeq_epu64_mask(r, s);

      // Store the matching lanes contiguously at the current output position.
      _mm512_mask_compressstoreu_epi64((void *)(res.data() + offset), m, s);

      offset += popcount(m);
    }
  }

  res.resize(offset);  // shrink to the number of matches actually found
}

int main() {
  vector<uint64_t> a, b;

  for (uint32_t x = 0; x < (64 * 1024); ++x) {
    a.push_back(x);
    b.push_back(x + 21);
  }
  vector<uint64_t> res;

  for (uint32_t t = 0; t < 10; ++t) {
    res.clear();
    auto c = chrono::high_resolution_clock::now();
    super(res, a, b);
    auto c2 = chrono::high_resolution_clock::now();
    auto d = chrono::duration_cast<chrono::nanoseconds>(c2 - c);

    cout << d.count() << endl;
    for (size_t x = 0; x < 4 && x < res.size(); ++x) {
      cout << res[x] << endl;
    }
  }

  cout << "-------" << endl;

  for (uint32_t t = 0; t < 10; ++t) {
    res.clear();  // set_intersection appends, so start from an empty vector
    auto c = chrono::high_resolution_clock::now();
    square(res, a, b);
    auto c2 = chrono::high_resolution_clock::now();
    auto d = chrono::duration_cast<chrono::nanoseconds>(c2 - c);

    cout << d.count() << endl;
    for (size_t x = 0; x < 4 && x < res.size(); ++x) {
      cout << res[x] << endl;
    }
  }
}
```

I created a Compiler Explorer (Godbolt) example to make testing easier for the community: Godbolt's test box

My concern is the abysmal performance of the SIMD version, even though it uses AVX-512 instructions; the latencies seem huge.

The result is correct though.

Does anyone have a clue how to make it fast?

Edit: For this test case I made the vector sizes a multiple of 8, so don't bother with the bounds checks.

Asked by Kroma; edited by Peter Cordes.

Comments:
  • Since your arrays are sorted, looping over the whole array2 for every element of array1 seems very sub-optimal (O(N^2)), and probably slower than scalar looping with one pass over both arrays. Also, if your arrays contain `[1, 1, 3, 4]` / `[1,1, 2, 4, 4]`, won't this output `[1,1,1,1, 4,4]`? i.e. both ones in B get output for each one in A. – Peter Cordes Jun 27 '21 at 19:58
  • Do you have some information about the typical distribution of the input? Are long consecutive sequences a common use case? – Marc Glisse Jun 27 '21 at 19:58
  • Also, you don't need compressstore (if you have padding in your output array to allow over-run): you can just `storeu(ptr2, r)` since `r` already has 8 copies of the one element that can match. Since the compare is only true on an exact match, it doesn't matter where the elements you store actually come from, just that there are (at least) popcnt contiguous elements of that value in the output. By storing 8 and incrementing the pointer by popcnt, you can overwrite any unneeded ones later. (And BTW, latency isn't a big factor in this, just throughput.) – Peter Cordes Jun 27 '21 at 20:03
  • 1. If your sequences are sorted, you don't need a full scan in the internal loop. You don't even have to start it from the beginning of `b`. 2. You're not pushing results into `res`, and its memory allocation is not seen in the code. 3. I don't think you need `_mm512_mask_compressstoreu_epi64`, which is expensive, since all matching elements will be consecutive because the elements are sorted. – Andrey Semashev Jun 27 '21 at 20:04
  • My advice would be to start from the code of set_intersection and see how you can tweak it for your case. Maybe you want to compare (equality) the next 8 elements of both vectors, count the number of leading "true" in the result and advance by that much. Or maybe when you have a smaller current element in the second vector you want to compare (inequality) the current element of the first vector to the next 8 elements of the second to see by how much you can advance. Etc. – Marc Glisse Jun 27 '21 at 20:09
  • BTW, one useful building block might be [`vpconflictq`](https://www.felixcloutier.com/x86/vpconflictd:vpconflictq) _mm512_conflict_epi64 to compare each element of one vector against every other element, and generate a bitmap in each element. Unfortunately it's not between two separate vectors, but with 4x uint64_t values in each half of a `__m512i`, that's 4x4 = 16 useful compares, vs. 8 from one broadcasted value against another. And possibly in a more useful order? What you really want is TGL `VP2INTERSECTQ` https://en.wikipedia.org/wiki/AVX-512#VP2INTERSECT for 8x8 useful compares. – Peter Cordes Jun 27 '21 at 20:09
  • Possibly scalar compare to keep your pointers in sync (or vector vs. broadcast compare for greater-than -> tzcnt), and use that to feed `vp2intersectq` (_mm512_2intersect_epi64) if available. Or to feed `vpconflictq` on 4x 64-bit from each array (blend both ways) -> `vptestmq` with a mask that selects the appropriate matches -> `vpcompressq`. – Peter Cordes Jun 27 '21 at 20:23
  • Are your sets unique, or can you have repeats of the same value in one input? – Peter Cordes Jun 27 '21 at 20:23
  • @PeterCordes: no, the sets are made of unique values :) – Kroma Jun 28 '21 at 15:06
  • @MarcGlisse: the distribution is random, but always sorted. It is just for the example that I didn't bother to make it more "real life" :) – Kroma Jun 28 '21 at 15:07
  • Ok, then it's weird to use compressstore to store 0 or 1 elements. You might as well just `*ptr2 = *it` or `_mm_storeu_si64(ptr2, _mm512_castsi512_si128(r))` once in the outer loop, and increment the pointer by 1 or not in the inner search loop, according to whether you ever found a match. If your arrays are tiny enough that branch mispredicts cost more than the redundant work you're doing scanning the rest of the 2nd set after finding the one and only match, that would keep the branching independent of the data. You don't need any store or popcnt in the loop, just mask OR. – Peter Cordes Jun 28 '21 at 15:17
  • Of course, it's also probably not optimal to do a brute-force linear search in a sorted list, especially when the values you're looking for arrive in sorted order. – Peter Cordes Jun 28 '21 at 15:19
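Following the comments by Peter Cordes and Andrey Semashev that the inner scan never needs to restart at the beginning of `b`, here is a minimal scalar sketch. It assumes both inputs are sorted ascending; `intersect_merge` is a hypothetical name. Resuming the scan where the previous search stopped turns the O(N*M) nested loop into a single O(N+M) pass, which is essentially what `std::set_intersection` already does.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: intersect two sorted vectors in one pass.
// The "search" for a[i] resumes at j instead of restarting at b[0].
std::vector<uint64_t> intersect_merge(const std::vector<uint64_t> &a,
                                      const std::vector<uint64_t> &b) {
  // Precondition (checked in debug builds): both inputs are sorted.
  assert(std::is_sorted(a.begin(), a.end()));
  assert(std::is_sorted(b.begin(), b.end()));

  std::vector<uint64_t> res;
  res.reserve(std::min(a.size(), b.size()));

  std::size_t j = 0;
  for (std::size_t i = 0; i < a.size(); ++i) {
    // Elements of b below a[i] can't match any later element of a either.
    while (j < b.size() && b[j] < a[i]) ++j;
    if (j == b.size()) break;  // b exhausted: no more matches possible
    if (b[j] == a[i]) res.push_back(b[j++]);
  }
  return res;
}
```

For the question's data (`a` = 0..65535, `b` = 21..65556) this returns the 65515 values 21..65535 after touching each element once.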
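If you want to keep the brute-force structure, Peter Cordes' suggestion for unique values can be sketched as follows: OR the compare masks together in the inner loop, store the element once per outer iteration, and advance the output by 1 only if some lane matched, so there is no compressstore or popcnt inside the inner loop. This is still O(N*M). Assumptions: values are unique, `b.size()` is a multiple of 8, and the compiler is GCC or Clang (`__attribute__((target("avx512f")))` and `__builtin_cpu_supports` are compiler extensions, used here so the file builds without `-mavx512f` and falls back to scalar code). All function names are hypothetical.

```cpp
#include <immintrin.h>

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// AVX-512 path: OR the compare masks, store a[i] once, and advance the
// output by 1 only if some lane matched. Still O(N*M).
__attribute__((target("avx512f")))
static std::size_t brute_avx512(const uint64_t *a, std::size_t na,
                                const uint64_t *b, std::size_t nb,
                                uint64_t *out) {
  std::size_t n = 0;
  for (std::size_t i = 0; i < na; ++i) {
    __m512i r = _mm512_set1_epi64((long long)a[i]);
    __mmask8 any = 0;
    for (std::size_t j = 0; j < nb; j += 8) {
      __m512i s = _mm512_loadu_si512((const void *)(b + j));
      any |= _mm512_cmpeq_epu64_mask(r, s);  // no store, no popcnt here
    }
    out[n] = a[i];    // unconditional store...
    n += (any != 0);  // ...kept only if a match was found
  }
  return n;
}

// Same logic in scalar code, for CPUs without AVX-512.
static std::size_t brute_scalar(const uint64_t *a, std::size_t na,
                                const uint64_t *b, std::size_t nb,
                                uint64_t *out) {
  std::size_t n = 0;
  for (std::size_t i = 0; i < na; ++i) {
    bool any = false;
    for (std::size_t j = 0; j < nb; ++j) any |= (b[j] == a[i]);
    out[n] = a[i];
    n += any;
  }
  return n;
}

// Hypothetical driver. Assumes unique values and b.size() % 8 == 0.
std::vector<uint64_t> intersect_brute(const std::vector<uint64_t> &a,
                                      const std::vector<uint64_t> &b) {
  assert(b.size() % 8 == 0);
  std::vector<uint64_t> res(a.size());  // at most one output per a element
  std::size_t n =
      __builtin_cpu_supports("avx512f")
          ? brute_avx512(a.data(), a.size(), b.data(), b.size(), res.data())
          : brute_scalar(a.data(), a.size(), b.data(), b.size(), res.data());
  res.resize(n);
  return res;
}
```

Unlike the compressstore version, the output position depends only on whether a match was found, which is all that matters when each value occurs at most once.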
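One possible reading of the suggestions by Marc Glisse ("compare (inequality) the current element of the first vector to the next 8 elements of the second to see by how much you can advance") and Peter Cordes ("vector vs. broadcast compare for greater-than -> tzcnt") is a merge that skips up to 8 elements of `b` per step. This is a sketch, not code from the thread; it assumes sorted inputs and GCC/Clang extensions as above, and the function names are hypothetical.

```cpp
#include <immintrin.h>

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Scalar merge resuming at positions i (in a) and j (in b) with n results
// already written. Used as the fallback and for the last < 8 elements of b.
static std::size_t merge_scalar(const uint64_t *a, std::size_t na, std::size_t i,
                                const uint64_t *b, std::size_t nb, std::size_t j,
                                uint64_t *out, std::size_t n) {
  while (i < na && j < nb) {
    if (b[j] < a[i]) { ++j; continue; }
    if (b[j] == a[i]) { out[n++] = a[i]; ++j; }
    ++i;
  }
  return n;
}

// Each step compares a[i] against 8 elements of b and skips every one
// that is smaller.
__attribute__((target("avx512f")))
static std::size_t merge_avx512(const uint64_t *a, std::size_t na,
                                const uint64_t *b, std::size_t nb,
                                uint64_t *out) {
  std::size_t i = 0, j = 0, n = 0;
  while (i < na && j + 8 <= nb) {
    __m512i r = _mm512_set1_epi64((long long)a[i]);
    __m512i s = _mm512_loadu_si512((const void *)(b + j));
    // Lanes with b[j+k] < a[i]. Because b is sorted they form a prefix, so
    // their popcount equals tzcnt of the complementary ">=" mask.
    __mmask8 lt = _mm512_cmplt_epu64_mask(s, r);
    unsigned skip = (unsigned)__builtin_popcount(lt);
    j += skip;
    if (skip == 8) continue;  // whole block < a[i]: load the next 8
    // Here b[j] >= a[i]; keep a[i] only if it matches.
    bool eq = (b[j] == a[i]);
    out[n] = a[i];
    n += eq;
    j += eq;
    ++i;
  }
  return merge_scalar(a, na, i, b, nb, j, out, n);
}

// Hypothetical driver: picks the AVX-512 path at run time if available.
std::vector<uint64_t> intersect_skip(const std::vector<uint64_t> &a,
                                     const std::vector<uint64_t> &b) {
  assert(std::is_sorted(a.begin(), a.end()));
  assert(std::is_sorted(b.begin(), b.end()));
  std::vector<uint64_t> res(a.size());  // at most one output per a element
  std::size_t n =
      __builtin_cpu_supports("avx512f")
          ? merge_avx512(a.data(), a.size(), b.data(), b.size(), res.data())
          : merge_scalar(a.data(), a.size(), 0, b.data(), b.size(), 0,
                         res.data(), 0);
  res.resize(n);
  return res;
}
```

When the two inputs interleave closely (as in the question's test data), each step skips at most one element, so this behaves like a plain merge that also does a vector load. The skip only pays off when `b` has long runs of values below the current element of `a`, which is why the input distribution Marc Glisse asked about matters.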
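Because the values being looked up arrive in sorted order, the search in `b` can start from the previous match position rather than from `b[0]`, and it can use a logarithmic search instead of a linear scan. The sketch below uses galloping (exponential) search followed by `std::lower_bound`. This is a common technique swapped in for illustration; it is not proposed in the thread. It assumes sorted inputs, and `intersect_gallop` is a hypothetical name.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: for each x in a, gallop forward in b from the last
// position, then binary-search the bracketed window with std::lower_bound.
std::vector<uint64_t> intersect_gallop(const std::vector<uint64_t> &a,
                                       const std::vector<uint64_t> &b) {
  assert(std::is_sorted(a.begin(), a.end()));
  assert(std::is_sorted(b.begin(), b.end()));

  std::vector<uint64_t> res;
  const std::size_t nb = b.size();
  std::size_t j = 0;  // everything before b[j] is known to be < the next x

  for (uint64_t x : a) {
    // Double the step while the whole block b[lo, lo + step) is < x.
    std::size_t lo = j, step = 1;
    while (lo + step <= nb && b[lo + step - 1] < x) {
      lo += step;
      step *= 2;
    }
    // The first element >= x (if any) lies in b[lo, hi).
    std::size_t hi = std::min(lo + step, nb);
    j = static_cast<std::size_t>(
        std::lower_bound(b.begin() + lo, b.begin() + hi, x) - b.begin());
    if (j == nb) break;  // every remaining element of b is < x
    if (b[j] == x) res.push_back(b[j++]);
  }
  return res;
}
```

The cost per element of `a` is logarithmic in the distance skipped in `b`, so this helps most when `a` is much smaller than `b`. For inputs of similar size that interleave closely, a plain merge is usually just as good.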
