I'm working with three CUDA kernels (gemm3, gemm4, and gemm5) for matrix multiplication:
gemm3: baseline shared-memory GEMM
gemm4: fewer thread blocks in the x dimension (each thread computes 4 output columns)
gemm5: fewer blocks in both the x and y dimensions (each thread computes a 4x4 output tile)
After profiling, I noticed that the number of shared-memory store bank conflicts decreases progressively across these kernels, dropping to zero in gemm5.
As the answer here describes, using threadIdx.x as the last subscript of the shared-memory stores in gemm3 should make them conflict-free, yet the profiler shows that gemm3 has the most conflicts of the three. I also don't understand why the conflicts decrease in gemm4 and gemm5, whose stores are strided.
I'm hoping to understand what is actually going on here. The kernels are below, followed by the launch configurations I use and the profiling results.
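To show how I reason about the store patterns, here is a minimal host-side sketch of the naive bank model I have in mind, assuming 32 banks of 4-byte words and BLOCK_SIZE = 32, so that one warp covers a full row of the thread block (the actual BLOCK_SIZE is not shown below, so this is only illustrative). The row offset threadIdx.y * width is constant within a warp, so only the per-lane column stride matters:

#include <algorithm>
#include <cstdio>

// Naive model: each 4-byte word of shared memory maps to bank (word_index % 32).
// For one warp (threadIdx.y fixed, threadIdx.x = lane = 0..31), count the worst-case
// number of lanes whose single float store lands in the same bank.
int worst_case_conflict(int lane_stride_in_words) {
    const int NUM_BANKS = 32;
    int hits[NUM_BANKS] = {0};
    for (int lane = 0; lane < 32; lane++)
        hits[(lane * lane_stride_in_words) % NUM_BANKS]++;
    return *std::max_element(hits, hits + NUM_BANKS);
}

int main() {
    // gemm3:        smem_B[ty][tx]     -> lane stride of 1 word
    printf("gemm3-style store: %d-way\n", worst_case_conflict(1));  // prints 1 (conflict-free)
    // gemm4/gemm5:  smem_B[ty][tx * 4] -> lane stride of 4 words
    printf("gemm4-style store: %d-way\n", worst_case_conflict(4));  // prints 4
    return 0;
}

By this model the stores in gemm3 should be conflict-free while the strided stores in gemm4 and gemm5 should each be 4-way conflicted, which is the opposite of what the profiler reports.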
__global__ void gemm3(const float* A, const float* B, float* C, int M, int N, int K) {
    int xid = threadIdx.x + blockIdx.x * blockDim.x;
    int yid = threadIdx.y + blockIdx.y * blockDim.y;

    __shared__ float smem_A[BLOCK_SIZE][BLOCK_SIZE];
    __shared__ float smem_B[BLOCK_SIZE][BLOCK_SIZE];

    float sum = 0.0f;
    for (int kid = 0; kid < K / BLOCK_SIZE; kid++) {
        // Each thread loads one element of the A tile and one element of the B tile.
        smem_A[threadIdx.y][threadIdx.x] = A[yid * K + kid * BLOCK_SIZE + threadIdx.x];
        smem_B[threadIdx.y][threadIdx.x] = B[(kid * BLOCK_SIZE + threadIdx.y) * N + xid];
        __syncthreads();

        // Each thread accumulates one output element from the two tiles.
        for (int sub_k = 0; sub_k < BLOCK_SIZE; sub_k++) {
            sum += smem_A[threadIdx.y][sub_k] * smem_B[sub_k][threadIdx.x];
        }
        __syncthreads();
    }
    C[yid * N + xid] = sum;
}
__global__ void gemm4(const float* A, const float* B, float* C, int M, int N, int K) {
    const int tx = threadIdx.x;
    const int ty = threadIdx.y;
    const int bx = blockIdx.x;
    const int by = blockIdx.y;
    int xid = tx + bx * blockDim.x;
    int yid = ty + by * blockDim.y;

    // Each block now covers BLOCK_SIZE rows x (4 * BLOCK_SIZE) columns of C,
    // so each thread computes 4 output elements along x.
    const int block_offset = BLOCK_SIZE * 4;
    __shared__ float smem_A[BLOCK_SIZE][BLOCK_SIZE];
    __shared__ float smem_B[BLOCK_SIZE][BLOCK_SIZE * 4];

    float sum[4] = {0, 0, 0, 0};
    for (int kid = 0; kid < K / BLOCK_SIZE; kid++) {
        smem_A[threadIdx.y][threadIdx.x] = A[yid * K + kid * BLOCK_SIZE + threadIdx.x];
        // Each thread stores 4 consecutive floats of the B tile (lane stride of 4 elements).
        smem_B[threadIdx.y][threadIdx.x * 4]     = B[(kid * BLOCK_SIZE + threadIdx.y) * N + bx * block_offset + threadIdx.x * 4];
        smem_B[threadIdx.y][threadIdx.x * 4 + 1] = B[(kid * BLOCK_SIZE + threadIdx.y) * N + bx * block_offset + threadIdx.x * 4 + 1];
        smem_B[threadIdx.y][threadIdx.x * 4 + 2] = B[(kid * BLOCK_SIZE + threadIdx.y) * N + bx * block_offset + threadIdx.x * 4 + 2];
        smem_B[threadIdx.y][threadIdx.x * 4 + 3] = B[(kid * BLOCK_SIZE + threadIdx.y) * N + bx * block_offset + threadIdx.x * 4 + 3];
        __syncthreads();

        for (int sub_k = 0; sub_k < BLOCK_SIZE; sub_k++) {
            sum[0] = fma(smem_A[threadIdx.y][sub_k], smem_B[sub_k][threadIdx.x * 4],     sum[0]);
            sum[1] = fma(smem_A[threadIdx.y][sub_k], smem_B[sub_k][threadIdx.x * 4 + 1], sum[1]);
            sum[2] = fma(smem_A[threadIdx.y][sub_k], smem_B[sub_k][threadIdx.x * 4 + 2], sum[2]);
            sum[3] = fma(smem_A[threadIdx.y][sub_k], smem_B[sub_k][threadIdx.x * 4 + 3], sum[3]);
        }
        __syncthreads();
    }
    C[yid * N + bx * block_offset + threadIdx.x * 4]     = sum[0];
    C[yid * N + bx * block_offset + threadIdx.x * 4 + 1] = sum[1];
    C[yid * N + bx * block_offset + threadIdx.x * 4 + 2] = sum[2];
    C[yid * N + bx * block_offset + threadIdx.x * 4 + 3] = sum[3];
}
__global__ void gemm5(const float* A, const float* B, float* C, int M, int N, int K) {
    const int tx = threadIdx.x;
    const int ty = threadIdx.y;
    const int bx = blockIdx.x;
    const int by = blockIdx.y;
    int xid = tx + bx * blockDim.x;
    int yid = ty + by * blockDim.y;

    // Each block now covers a (4 * BLOCK_SIZE) x (4 * BLOCK_SIZE) tile of C,
    // so each thread computes a 4x4 sub-tile.
    const int block_offset = BLOCK_SIZE * 4;
    __shared__ float smem_A[BLOCK_SIZE * 4][BLOCK_SIZE];
    __shared__ float smem_B[BLOCK_SIZE][BLOCK_SIZE * 4];

    float sum[4][4] = {0.f};
    for (int kid = 0; kid < K / BLOCK_SIZE; kid++) {
        // Each thread loads 4 rows of the A tile and 4 consecutive columns of the B tile.
        smem_A[threadIdx.y * 4][threadIdx.x]     = A[(by * block_offset + threadIdx.y * 4) * K + kid * BLOCK_SIZE + threadIdx.x];
        smem_A[threadIdx.y * 4 + 1][threadIdx.x] = A[(by * block_offset + threadIdx.y * 4 + 1) * K + kid * BLOCK_SIZE + threadIdx.x];
        smem_A[threadIdx.y * 4 + 2][threadIdx.x] = A[(by * block_offset + threadIdx.y * 4 + 2) * K + kid * BLOCK_SIZE + threadIdx.x];
        smem_A[threadIdx.y * 4 + 3][threadIdx.x] = A[(by * block_offset + threadIdx.y * 4 + 3) * K + kid * BLOCK_SIZE + threadIdx.x];
        smem_B[threadIdx.y][threadIdx.x * 4]     = B[(kid * BLOCK_SIZE + threadIdx.y) * N + bx * block_offset + threadIdx.x * 4];
        smem_B[threadIdx.y][threadIdx.x * 4 + 1] = B[(kid * BLOCK_SIZE + threadIdx.y) * N + bx * block_offset + threadIdx.x * 4 + 1];
        smem_B[threadIdx.y][threadIdx.x * 4 + 2] = B[(kid * BLOCK_SIZE + threadIdx.y) * N + bx * block_offset + threadIdx.x * 4 + 2];
        smem_B[threadIdx.y][threadIdx.x * 4 + 3] = B[(kid * BLOCK_SIZE + threadIdx.y) * N + bx * block_offset + threadIdx.x * 4 + 3];
        __syncthreads();

        for (int sub_k = 0; sub_k < BLOCK_SIZE; sub_k++) {
            for (int i = 0; i < 4; i++) {
                for (int j = 0; j < 4; j++) {
                    sum[i][j] = fma(smem_A[threadIdx.y * 4 + i][sub_k], smem_B[sub_k][threadIdx.x * 4 + j], sum[i][j]);
                }
            }
        }
        __syncthreads();
    }
    for (int i = 0; i < 4; i++) {
        for (int j = 0; j < 4; j++) {
            C[(by * block_offset + threadIdx.y * 4 + i) * N + bx * block_offset + threadIdx.x * 4 + j] = sum[i][j];
        }
    }
}
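For completeness, this is how I launch the three kernels; the grid sizes follow directly from the indexing above (the kernels have no bounds checks, so M, N, K are assumed to be multiples of the tile sizes), and d_A, d_B, d_C are device pointers allocated elsewhere:

// All three kernels use the same BLOCK_SIZE x BLOCK_SIZE thread block.
void launch_all(const float* d_A, const float* d_B, float* d_C, int M, int N, int K) {
    dim3 block(BLOCK_SIZE, BLOCK_SIZE);

    // gemm3: one output element per thread.
    dim3 grid3(N / BLOCK_SIZE, M / BLOCK_SIZE);
    gemm3<<<grid3, block>>>(d_A, d_B, d_C, M, N, K);

    // gemm4: 4 output columns per thread -> 4x fewer blocks in x.
    dim3 grid4(N / (BLOCK_SIZE * 4), M / BLOCK_SIZE);
    gemm4<<<grid4, block>>>(d_A, d_B, d_C, M, N, K);

    // gemm5: a 4x4 output tile per thread -> 4x fewer blocks in both x and y.
    dim3 grid5(N / (BLOCK_SIZE * 4), M / (BLOCK_SIZE * 4));
    gemm5<<<grid5, block>>>(d_A, d_B, d_C, M, N, K);
}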
Profiling results (Nsight Compute):
gemm3(...), Context 1, Stream 19
----------------------------------------------------------------------
l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum 0
l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum 160811
----------------------------------------------------------------------
gemm4(...), Context 1, Stream 20
----------------------------------------------------------------------
l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum 0
l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum 57492
----------------------------------------------------------------------
gemm5(...), Context 1, Stream 21
----------------------------------------------------------------------
l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum 0
l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum 0
----------------------------------------------------------------------