I'm hoping everyone is familiar with the standard "naive" method of multiplying two matrices (square, n x n, for simplicity).
In C this is:
/* Naive i-j-k ordering: the innermost k loop accumulates one element C[i][j]. */
for (int i = 0; i < n; ++i) {
    for (int j = 0; j < n; ++j) {
        for (int k = 0; k < n; ++k) {
            C[i*n + j] += A[i*n + k] * B[k*n + j];
        }
    }
}
The above method computes each element of C as the dot (inner) product of a row of A with a column of B,
and is easy to implement in OpenCL as follows:
/*
 * Inner-product form of C += A * B for row-major n x n matrices.
 *
 * Launch as a 2-D NDRange: dimension 0 indexes columns of C, dimension 1
 * indexes rows. Each work-item produces exactly one element of C.
 *
 *   A, B  input matrices, n*n floats each
 *   C     output matrix, n*n floats; the product is added to its current
 *         contents, so zero it on the host for a plain C = A * B
 *   n     matrix dimension
 */
__kernel void matmul_ocl(
    __global const float *A,
    __global const float *B,
    __global float *C,
    const int n
)
{
    const int row = get_global_id(1); // row
    const int col = get_global_id(0); // col

    // The global size may be rounded up past n (e.g. to a multiple of the
    // work-group size); surplus work-items must not touch memory.
    if (row >= n || col >= n)
        return;

    // Accumulate in a private variable instead of read-modify-writing global
    // memory n times. Seeding it with C keeps the original "+=" semantics and
    // the original order of the additions.
    float acc = C[row*n + col];
    for (int k = 0; k < n; ++k)
        acc += A[row*n + k] * B[k*n + col];
    C[row*n + col] = acc;
}
Interchanging the two inner-most loops of the original C implementation results in a method that computes outer products,
i.e., it computes rank-1 updates of the rows of the matrix C:
/* i-k-j ordering: for each (i, k) the innermost j loop adds A[i][k] times row k of B to row i of C. */
for (int i = 0; i < n; ++i) {
    for (int k = 0; k < n; ++k) {
        for (int j = 0; j < n; ++j) {
            C[i*n + j] += A[i*n + k] * B[k*n + j];
        }
    }
}
Does anybody know how to properly implement the above outer-product method in OpenCL? I have pasted two of my attempts below, but I just can't seem to nail it.
Attempt 1
/*
 * Outer-product (rank-1 update) form of C += A * B for row-major n x n
 * matrices: for every k, row k of B scaled by A[row][k] is added to row
 * `row` of C.
 *
 * Launch as a 2-D NDRange: dimension 1 indexes rows of C; the work-items
 * along dimension 0 share one row of C between them, each owning the
 * columns col, col + stride, col + 2*stride, ... Any dimension-0 size >= 1
 * works.
 *
 *   A, B  input matrices, n*n floats each
 *   C     output matrix, n*n floats; the product is added to its current
 *         contents, so zero it on the host for a plain C = A * B
 *   n     matrix dimension
 */
__kernel void matmul_ocl(
    __global const float *A,
    __global const float *B,
    __global float *C,
    const int n
)
{
    const int row = get_global_id(1);          // row of A and C handled here
    const int col = get_global_id(0);          // first column of C owned by this work-item
    const int col_stride = get_global_size(0); // distance to the next owned column

    // The global size may be rounded up past n; surplus rows have no work.
    if (row >= n)
        return;

    // Every element of C is owned by exactly one work-item, so the updates
    // need neither local memory nor barriers, and no two work-items ever
    // write the same location.
    for (int k = 0; k < n; ++k) {
        // Scalar that scales row k of B in this rank-1 update.
        const float a = A[row*n + k];
        for (int j = col; j < n; j += col_stride)
            C[row*n + j] += a * B[k*n + j];
    }
}
Attempt 2
// Attempt 2: variant that offsets A, B and C per work-group / work-item and
// stages two values of B in local memory on every loop iteration.
// NOTE(review): several statements below look incorrect (see the notes next
// to them). The intended tiling scheme and launch geometry are not visible
// here, so the code itself is left untouched.
#define TS 1
__kernel void matmul_ocl(
__global const float *A,
__global const float *B,
__global float *C,
int n)
{
// Thread coordinates (position of this work-item inside its work-group)
const int row = get_local_id(1); // row
const int col = get_local_id(0); // col
// Group tile coordinates (position of the work-group)
const int by = get_group_id(1); // row
const int bx = get_group_id(0); // col
// Move each base pointer to this work-item's starting element.
// NOTE(review): in A's offset `by` is added unscaled and `bx` is scaled by n,
// while C's offset uses only `bx` -- confirm this mapping is intended.
A += TS*by + TS*bx*n + n*row + (col);
B += TS*by*n + n*row + (col);
C += TS*bx*n + n*(row) + col;
// Loop sentinel: B advances by one element per iteration, so for n >= 1 the
// do/while below executes n times.
__global const float *Blast = B + n;
// Private accumulators; added to C[0] and C[1] after the loop.
float c[2] = {0.0f,0.0f};
float* cptr = &c[0];
// Two-element staging buffer in local memory.
__local float bs[2];
do
{
// Every work-item executes both stores.
// NOTE(review): with more than one work-item per work-group these are
// conflicting writes to the same two local slots.
bs[0] = B[0];
bs[1] = B[n];
barrier(CLK_LOCAL_MEM_FENCE);
// NOTE(review): both products are added to the same element (*cptr), and
// cptr is incremented on every iteration, so from the third iteration on it
// is dereferenced past the end of c[2].
*cptr += A[0] * bs[0];
*cptr++ += A[0] * bs[1];
// Only B advances; A[0] refers to the same element on every iteration.
B++;
// Second barrier precedes the next iteration's stores to bs.
barrier(CLK_LOCAL_MEM_FENCE);
} while( B < Blast );
// NOTE(review): C[0] and C[1] are adjacent elements, so work-items whose
// `col` differs by one would update overlapping elements -- verify against
// the work-group size actually used.
C[0] += c[0];
C[1] += c[1];
}