In my code I have to handle "unmasking" of WebSocket packets, which essentially means XORing unaligned data of arbitrary length with a repeating 4-byte mask. Thanks to SO (Websocket data unmasking / multi byte xor) I have already found out how to (hopefully) speed this up using the SSE2/AVX2 extensions, but looking at it now, my handling of the unaligned head and tail seems sub-optimal. Is there any way to optimize my code, or at least make it simpler with the same performance, or is it already about as good as it gets?
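For reference, the plain byte-by-byte unmasking I'm trying to speed up looks roughly like the sketch below (plain_xor_32 is just a name for this example; it assumes the mask is applied starting at byte offset 0):

#include <stdint.h>

// baseline: apply the 4-byte mask cyclically, one byte at a time
void plain_xor_32(uint32_t mask, uint8_t *ds, uint8_t *de) {
uint8_t maskOffset = 0;
for (; ds < de; ds++) {
*ds ^= *((uint8_t *)(&mask) + maskOffset); // pick the next mask byte
maskOffset = (maskOffset + 1) & 3; // wrap around after 4 bytes
}
}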
Here's the important part of the code (for this question I'm assuming the data will always be long enough to run the AVX2 loop at least once, though in practice it will usually only run a few times):
#include <stdint.h>
#include <immintrin.h> // SSE2 / AVX2 intrinsics

// circular shift left for uint32
uint32_t cshiftl_u32(uint32_t num, uint8_t shift) {
return (num << shift) | (num >> (32 - shift));
}
// circular shift right for uint32
uint32_t cshiftr_u32(uint32_t num, uint8_t shift) {
return (num >> shift) | (num << (32 - shift));
}
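// XOR the buffer [ds, de) with the 4-byte mask, repeating the mask cyclically.
// Strategy: peel off single bytes and then 4/8/16 byte blocks until ds is aligned
// for the widest available vector loop, run that loop, then drain the tail with
// progressively smaller blocks.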
void optimized_xor_32( uint32_t mask, uint8_t *ds, uint8_t *de ) {
if (ds == de) return; // zero data len -> nothing to do
uint8_t maskOffset = 0;
// process single bytes till 4 byte alignment ( <= 3 )
for (; ds < de && ( (uint64_t)ds & (uint64_t)3 ); ds++) {
*ds ^= *((uint8_t *)(&mask) + maskOffset);
maskOffset = (maskOffset + 1) & (uint8_t)3;
}
if (ds == de) return; // done, return
if (maskOffset != 0) { // rotate the mask so its byte 0 is the next mask byte to apply
// on little-endian (x86) the first byte in memory is the low byte, so rotate right by whole bytes
mask = cshiftr_u32(mask, maskOffset * 8);
maskOffset = 0;
}
// process a 4 byte block till 8 byte alignment ( <= 1 )
uint8_t *de32 = (uint8_t *)((uint64_t)de & ~((uint64_t)3));
if ( ds < de32 && ( (uint64_t)ds & (uint64_t)7 ) ) {
*(uint32_t *)ds ^= mask; // mask is uint32_t
if ((ds += 4) == de) return;
}
// process an 8 byte block till 16 byte alignment ( <= 1 )
uint64_t mask64 = (uint64_t)mask | ((uint64_t)mask << 32);
uint8_t *de64 = (uint8_t *)((uint64_t)de & ~((uint64_t)7));
if ( ds < de64 && ( (uint64_t)ds & (uint64_t)15 ) ) {
*(uint64_t *)ds ^= mask64;
if ((ds += 8) == de) return; // done, return
}
// process 16 byte block till 32 byte alignment ( <= 1) (if supported)
#ifdef CPU_SSE2
__m128i v128, v128_mask;
v128_mask = _mm_set1_epi32(mask);
uint8_t *de128 = (uint8_t *)((uint64_t)de & ~((uint64_t)15));
if ( ds < de128 && ( (uint64_t)ds & (uint64_t)31 ) ) {
v128 = _mm_load_si128((__m128i *)ds);
v128 = _mm_xor_si128(v128, v128_mask);
_mm_store_si128((__m128i *)ds, v128);
if ((ds += 16) == de) return; // done, return
}
#endif
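// here ds should be 32-byte aligned (assuming the SSE2 step above ran),
// so the aligned 256-bit loads/stores below are safe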
#ifdef CPU_AVX2 // process 32 byte blocks (if supported -> haswell upwards)
__m256i v256, v256_mask;
v256_mask = _mm256_set1_epi32(mask);
uint8_t *de256 = (uint8_t *)((uint64_t)de & ~((uint64_t)31));
for (; ds < de256; ds+=32) {
v256 = _mm256_load_si256((__m256i *)ds);
v256 = _mm256_xor_si256(v256, v256_mask);
_mm256_store_si256((__m256i *)ds, v256);
}
if (ds == de) return; // done, return
#endif
#ifdef CPU_SSE2 // process remaining 16 byte blocks (if supported)
for (; ds < de128; ds+=16) {
v128 = _mm_load_si128((__m128i *)ds);
v128 = _mm_xor_si128(v128, v128_mask);
_mm_store_si128((__m128i *)ds, v128);
}
if (ds == de) return; // done, return
#endif
// process remaining 8 byte blocks
// (64-bit XOR is always available; if the SSE/AVX paths above ran, this loop executes at most once)
for (; ds < de64; ds += 8) {
*(uint64_t *)ds ^= mask64;
}
if (ds == de) return; // done, return
// process remaining 4 byte blocks ( <= 1)
if (ds < de32) {
*(uint32_t *)ds ^= mask;
if ((ds += 4) == de) return; // done, return
}
// process remaining bytes ( <= 3)
for (; ds < de; ds ++) {
*ds ^= *((uint8_t *)(&mask) + maskOffset);
maskOffset = (maskOffset + 1) & (uint8_t)3;
}
}
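For context, a call site looks roughly like this (the buffer size and mask value are just placeholders for this example):

#include <string.h>
#include <stdio.h>

int main(void) {
uint8_t payload[128]; // placeholder payload, long enough for at least one AVX2 pass
memset(payload, 0xAB, sizeof(payload));
uint32_t mask = 0x37FA213D; // placeholder 4-byte masking key
optimized_xor_32(mask, payload, payload + sizeof(payload));
printf("%02x %02x %02x %02x ...\n", payload[0], payload[1], payload[2], payload[3]);
return 0;
}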
P.S.: Please ignore the use of #ifdef instead of cpuid or similar runtime CPU feature detection.