I am trying to implement an 8x8 DCT in a compute shader. The current implementation looks like this:
// One 8-point 1D transform step.
//   data    : the 8 input samples; data[0] holds samples 0..3, data[1] holds 4..7.
//   res0246 : dct8_matrix_even applied to the mirrored sums      (s0+s7, s1+s6, s2+s5, s3+s4).
//   res1357 : dct8_matrix_odd  applied to the mirrored differences (s0-s7, s1-s6, s2-s5, s3-s4).
// dct8_matrix_even / dct8_matrix_odd are defined outside this block; presumably
// they hold the even- and odd-indexed DCT basis rows (confirm at their definition).
void dct_8(in mat2x4 data, out vec4 res0246, out vec4 res1357) {
    vec4 front_half = data[0];
    // Reverse the second half so that component k pairs sample k with sample 7-k.
    vec4 back_half_mirrored = data[1].wzyx;
    res0246 = dct8_matrix_even*(front_half + back_half_mirrored);
    res1357 = dct8_matrix_odd*(front_half - back_half_mirrored);
}
// In-place separable 8x8 transform on a plain float array.
// Pass 1 runs dct_8 over every data[row][0..7] and stores the result in a
// scratch buffer; pass 2 runs dct_8 over every scratch[0..7][col] and writes
// the result back into data. In both passes the outputs of dct_8 are
// interleaved so that res0246 fills indices 0,2,4,6 and res1357 fills 1,3,5,7.
void dct_8x8(inout float data[8][8]){
    float scratch[8][8];

    // Pass 1: transform along the second index (one data[row][*] line at a time).
    for (int row = 0; row < 8; ++row) {
        // Eight scalars fill the mat2x4 column by column: 0..3 -> [0], 4..7 -> [1].
        mat2x4 samples = mat2x4(
            data[row][0], data[row][1], data[row][2], data[row][3],
            data[row][4], data[row][5], data[row][6], data[row][7]);
        vec4 even_coeffs;
        vec4 odd_coeffs;
        dct_8(samples, even_coeffs, odd_coeffs);
        // Interleave even/odd outputs back into natural index order.
        for (int k = 0; k < 4; ++k) {
            scratch[row][2 * k]     = even_coeffs[k];
            scratch[row][2 * k + 1] = odd_coeffs[k];
        }
    }

    // Pass 2: transform along the first index (one scratch[*][col] line at a time).
    for (int col = 0; col < 8; ++col) {
        mat2x4 samples = mat2x4(
            scratch[0][col], scratch[1][col], scratch[2][col], scratch[3][col],
            scratch[4][col], scratch[5][col], scratch[6][col], scratch[7][col]);
        vec4 even_coeffs;
        vec4 odd_coeffs;
        dct_8(samples, even_coeffs, odd_coeffs);
        for (int k = 0; k < 4; ++k) {
            data[2 * k][col]     = even_coeffs[k];
            data[2 * k + 1][col] = odd_coeffs[k];
        }
    }
}
In the horizontal pass, the swizzled stores into the buffer could be simplified: return a mat2x4, reinterpret that mat2x4 as a mat4x2, and then transpose it (and then store the result as 8 contiguous floats).
Is the compiler able to perform this optimization? If not, is it possible to write this swizzle operation explicitly?
(If it can be written explicitly, is the following implementation better?)
// Transposes, in place, an 8x8 matrix stored as a 2x2 grid of mat4x4 tiles.
// Each tile is transposed, and the two off-diagonal tiles additionally trade
// places: ([A B; C D])^T = [A^T C^T; B^T D^T].
void transpose_8x8(inout mat4x4 data[2][2])
{
    // Off-diagonal tiles: transpose and swap. One temporary is needed because
    // each tile's new value depends on the other tile's old value.
    mat4x4 swapped_tile = transpose(data[1][0]);
    data[1][0] = transpose(data[0][1]);
    data[0][1] = swapped_tile;

    // Diagonal tiles stay where they are and are only transposed.
    data[0][0] = transpose(data[0][0]);
    data[1][1] = transpose(data[1][1]);
}
// Applies the 8-point dct_8 to each of the 8 vec4-pair lines of a tiled 8x8 matrix.
//   data   : input, stored as a 2x2 grid of mat4x4 tiles. Line j (0..7) is made
//            of data[j/4][0][j%4] (first four samples) and data[j/4][1][j%4]
//            (last four samples).
//   result : output in the same tiled layout; every element is written.
// Coefficients are stored in natural order 0..7 within each line (first half
// 0..3, second half 4..7), the same interleaving the float[8][8] dct_8x8
// overload uses when it stores res0246 / res1357.
void dct8x8_vertical(in mat4x4 data[2][2], out mat4x4 result[2][2]){
    // A single loop over the 8 lines is sufficient: j/4 and j%4 already address
    // every tile column, so no outer repetition is needed.
    for (int j = 0; j < 8; ++j) {
        int tile = j / 4;
        int lane = j % 4;

        mat2x4 samples = mat2x4(data[tile][0][lane], data[tile][1][lane]);

        // dct_8 returns the even- and odd-indexed coefficients separately.
        vec4 res0246;
        vec4 res1357;
        dct_8(samples, res0246, res1357);

        // Re-interleave into natural order before storing the two halves.
        result[tile][0][lane] = vec4(res0246.x, res1357.x, res0246.y, res1357.y);
        result[tile][1][lane] = vec4(res0246.z, res1357.z, res0246.w, res1357.w);
    }
}
// In-place 8x8 transform on the tiled mat4x4[2][2] representation.
// The 1D pass (dct8x8_vertical) is run twice, with a transpose_8x8 after each
// run: the first transpose turns the other axis into the one the 1D pass
// operates on, and the final transpose restores the original orientation.
void dct_8x8(inout mat4x4 data[2][2])
{
    mat4x4 first_pass[2][2];

    // First axis: data -> first_pass.
    dct8x8_vertical(data, first_pass);

    // Second axis: flip, run the same 1D pass back into data, then flip back.
    transpose_8x8(first_pass);
    dct8x8_vertical(first_pass, data);
    transpose_8x8(data);
}