I'm struggling to implement an efficient Radix Sort in DirectX11. I need to quickly sort no more than 256K items. It doesn't need to be in-place. Using CUDA / OpenCL is not an option.
I have it almost working, but it has a few problems:
- The histogram generation is quite fast, but still takes much longer than the figures quoted online
- On subsequent sort passes, the order of keys whose lower bits are identical is not preserved, because of the InterlockedAdd on the histogram buffer in cp_sort (see below)
- cp_sort is really slow, because every thread does that same InterlockedAdd on the histogram buffer in global memory
I've been trying to work out how to fix this from the algorithms described online, but I can't seem to follow them.
Here are my 3 kernels (it's for a billboarded particle system, so the term 'quad' just refers to an item to be sorted):
// 8 bits per radix
#define NUM_RADIXES 4
#define RADIX_SIZE 256
// buffers
StructuredBuffer<Quad> g_particleQuadBuffer : register( t20 );
RWStructuredBuffer<uint> g_indexBuffer : register( u0 );
RWStructuredBuffer<uint> g_histogram : register( u1 );
RWStructuredBuffer<uint> g_indexOutBuffer : register( u2 );
// quad buffer counter
cbuffer BufferCounter : register( b12 )
{
    uint numQuads;
    uint pad[ 3 ];
}
// on-chip memory for fast histogram calculation
#define SHARED_MEM_PADDING 8 // to try and reduce bank conflicts
groupshared uint g_localHistogram[ NUM_RADIXES * RADIX_SIZE * SHARED_MEM_PADDING ];
// convert a float to a sortable uint, assuming all inputs are negative
uint floatToUInt( float input )
{
    return 0xffffffff - asuint( input );
}
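// for reference, I believe the general version of this trick (handling either sign) flips
// all bits of negative floats and only the sign bit of positive ones - I'm not using it
// since all of my projected depths are negative, so treat this as an untested sketch
uint floatToUIntGeneral( float input )
{
    uint f = asuint( input );
    uint mask = ( ( f & 0x80000000 ) != 0 ) ? 0xffffffff : 0x80000000;
    return f ^ mask;
}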
// initialise the indices, and build the histograms (dispatched with numQuads / ( NUM_RADIXES * RADIX_SIZE ) thread groups)
[numthreads( NUM_RADIXES * RADIX_SIZE, 1, 1 )]
void cp_histogram( uint3 groupID : SV_GroupID, uint groupIndex : SV_GroupIndex )
{
    // initialise local histogram (each thread clears its own padded slot)
    g_localHistogram[ groupIndex * SHARED_MEM_PADDING ] = 0;
    GroupMemoryBarrierWithGroupSync();
    // check within range
    uint quadIndex = ( groupID.x * NUM_RADIXES * RADIX_SIZE ) + groupIndex;
    if ( quadIndex < numQuads )
    {
        // initialise index
        g_indexBuffer[ quadIndex ] = quadIndex;
        // floating point to sortable uint
        uint value = floatToUInt( g_particleQuadBuffer[ quadIndex ].v[ 0 ].projected.z );
        // build 8-bit histograms
        uint value0 = ( value ) & 0xff;
        uint value1 = ( value >> 8 ) & 0xff;
        uint value2 = ( value >> 16 ) & 0xff;
        uint value3 = ( value >> 24 );
        InterlockedAdd( g_localHistogram[ ( value0 ) * SHARED_MEM_PADDING ], 1 );
        InterlockedAdd( g_localHistogram[ ( value1 + 256 ) * SHARED_MEM_PADDING ], 1 );
        InterlockedAdd( g_localHistogram[ ( value2 + 512 ) * SHARED_MEM_PADDING ], 1 );
        InterlockedAdd( g_localHistogram[ ( value3 + 768 ) * SHARED_MEM_PADDING ], 1 );
    }
    // write back to histogram
    GroupMemoryBarrierWithGroupSync();
    InterlockedAdd( g_histogram[ groupIndex ], g_localHistogram[ groupIndex * SHARED_MEM_PADDING ] );
}
// build the offsets based on histograms (dispatched with 1 group)
// NOTE: I know this could be more efficient, but from my profiling its time is negligible compared to the other 2 stages, and I can optimise it separately with a parallel prefix sum if I need to (there's a rough sketch of what I have in mind after this kernel)
[numthreads( NUM_RADIXES, 1, 1 )]
void cp_offsets( uint groupIndex : SV_GroupIndex )
{
    uint sum = 0;
    uint base = ( groupIndex * RADIX_SIZE );
    for ( uint i = 0; i < RADIX_SIZE; i++ )
    {
        uint tempSum = g_histogram[ base + i ] + sum;
        g_histogram[ base + i ] = sum;
        sum = tempSum;
    }
}
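// rough sketch of the parallel version of cp_offsets I have in mind (a Hillis-Steele style
// exclusive scan, one 256-wide scan per radix inside a single 1024-thread group, reusing
// g_localHistogram as scratch, dispatched with 1 group) - untested, so treat the details
// as assumptions on my part
[numthreads( NUM_RADIXES * RADIX_SIZE, 1, 1 )]
void cp_offsets_parallel( uint groupIndex : SV_GroupIndex )
{
    uint lane = groupIndex & ( RADIX_SIZE - 1 );
    // shift everything right by one so the inclusive scan below yields an exclusive result
    uint previous = 0;
    if ( lane != 0 )
    {
        previous = g_histogram[ groupIndex - 1 ];
    }
    g_localHistogram[ groupIndex * SHARED_MEM_PADDING ] = previous;
    GroupMemoryBarrierWithGroupSync();
    // log2( RADIX_SIZE ) doubling steps: each lane adds the value 'offset' slots to its left
    for ( uint offset = 1; offset < RADIX_SIZE; offset <<= 1 )
    {
        uint value = 0;
        if ( lane >= offset )
        {
            value = g_localHistogram[ ( groupIndex - offset ) * SHARED_MEM_PADDING ];
        }
        GroupMemoryBarrierWithGroupSync();
        g_localHistogram[ groupIndex * SHARED_MEM_PADDING ] += value;
        GroupMemoryBarrierWithGroupSync();
    }
    g_histogram[ groupIndex ] = g_localHistogram[ groupIndex * SHARED_MEM_PADDING ];
}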
// move the data (dispatched with numQuads / ( NUM_RADIXES * RADIX_SIZE ) thread groups)
uint currentRadix; // which 8-bit digit this pass sorts on (0-3), set per pass
[numthreads( NUM_RADIXES * RADIX_SIZE, 1, 1 )]
void cp_sort( uint3 groupID : SV_GroupID, uint groupIndex : SV_GroupIndex )
{
    // check within range
    uint quadIndex = ( groupID.x * NUM_RADIXES * RADIX_SIZE ) + groupIndex;
    if ( quadIndex < numQuads )
    {
        uint fi = g_indexBuffer[ quadIndex ];
        uint depth = floatToUInt( g_particleQuadBuffer[ fi ].v[ 0 ].projected.z );
        uint radix = currentRadix;
        uint pos = ( depth >> ( radix * 8 ) ) & 0xff;
        uint writePosition;
        InterlockedAdd( g_histogram[ radix * RADIX_SIZE + pos ], 1, writePosition );
        g_indexOutBuffer[ writePosition ] = fi;
    }
}
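For clarity, this deliberately single-threaded version of the scatter shows the ordering I'm trying to reproduce (it's only an illustration of what I mean by keys with identical lower bits keeping their relative order, not something I'd actually dispatch):
[numthreads( 1, 1, 1 )]
void cp_sort_reference()
{
    // serial scatter: walking the items in their current order and bumping each digit's
    // offset as we go is exactly what keeps equal digits in their original relative order
    for ( uint i = 0; i < numQuads; i++ )
    {
        uint fi = g_indexBuffer[ i ];
        uint depth = floatToUInt( g_particleQuadBuffer[ fi ].v[ 0 ].projected.z );
        uint pos = ( depth >> ( currentRadix * 8 ) ) & 0xff;
        g_indexOutBuffer[ g_histogram[ currentRadix * RADIX_SIZE + pos ] ] = fi;
        g_histogram[ currentRadix * RADIX_SIZE + pos ]++;
    }
}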
Can anyone offer any help on how to fix/optimise this? I'd love to understand what some of the more complex algorithms for GPU Radix sorting are actually doing!
Thanks!