
I'm working with different CUDA kernels (gemm3, gemm4, and gemm5) for matrix multiplication:

gemm3: baseline shared-memory GEMM
gemm4: fewer thread blocks in the x dimension
gemm5: fewer thread blocks in both the x and y dimensions

After profiling, I noticed that the number of shared-memory store bank conflicts decreases progressively across these kernels, and disappears entirely in gemm5.
According to an answer I read, using threadIdx.x as the last (fastest-varying) subscript should make the shared-memory stores in gemm3 conflict-free, yet the profiling results show that gemm3 has the highest number of store conflicts. I also don't understand why the conflicts decrease in gemm4 and vanish in gemm5.

I'm hoping to understand what is actually being counted here and why it differs between these kernels.
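
To make my reasoning concrete, here is a small host-side check (a minimal sketch, assuming BLOCK_SIZE = 32 and the usual 32 shared-memory banks of 4-byte words; BLOCK_SIZE is not shown in the kernels below) that computes which bank each lane of a warp touches for the smem_A[threadIdx.y][threadIdx.x] store in gemm3. Every lane maps to a different bank, which is why I expected no store conflicts:

#include <stdio.h>

#define BLOCK_SIZE 32   // assumed tile width (not shown in the kernels below)
#define NUM_BANKS  32   // 32 banks of 4-byte words

int main(void) {
    // With a 32x32 block, one warp covers one full tile row: fixed ty, tx = 0..31.
    int ty = 0;
    for (int tx = 0; tx < 32; ++tx) {
        int word = ty * BLOCK_SIZE + tx;   // linear word offset of smem_A[ty][tx]
        int bank = word % NUM_BANKS;       // bank hit by this 4-byte store
        printf("lane %2d -> bank %2d\n", tx, bank);
    }
    return 0;   // prints lane i -> bank i, i.e. a conflict-free warp store
}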

__global__ void gemm3(const float* A, const float* B, float* C, int M, int N, int K){
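    // gemm3: each thread stages one A-tile element and one B-tile element per K-step
    // and accumulates a single element of C.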
    int xid = threadIdx.x + blockIdx.x * blockDim.x;
    int yid = threadIdx.y + blockIdx.y * blockDim.y;
    __shared__ float smem_A[BLOCK_SIZE][BLOCK_SIZE];
    __shared__ float smem_B[BLOCK_SIZE][BLOCK_SIZE];
    float sum = 0.0f;
    for (int kid = 0; kid < K / BLOCK_SIZE; kid++){
        smem_A[threadIdx.y][threadIdx.x] = A[yid * K + kid * BLOCK_SIZE + threadIdx.x];
        smem_B[threadIdx.y][threadIdx.x] = B[(kid * BLOCK_SIZE + threadIdx.y) * N + xid];
        __syncthreads();
        for (int sub_k = 0; sub_k < BLOCK_SIZE; sub_k++){
            sum += smem_A[threadIdx.y][sub_k] * smem_B[sub_k][threadIdx.x];
        }
        __syncthreads();
    }    
    C[yid * N + xid] = sum;
}
__global__ void gemm4(const float* A, const float* B, float* C, int M, int N, int K){
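    // gemm4: each block covers BLOCK_SIZE*4 columns of C, so the grid needs fewer blocks
    // in x; each thread stages and computes four consecutive elements along x.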
    const int tx = threadIdx.x;
    const int ty = threadIdx.y;
    const int bx = blockIdx.x;
    const int by = blockIdx.y;
    int xid = tx + bx * blockDim.x;
    int yid = ty + by * blockDim.y;
    const int block_offset = BLOCK_SIZE * 4;
    __shared__ float smem_A[BLOCK_SIZE][BLOCK_SIZE];
    __shared__ float smem_B[BLOCK_SIZE][BLOCK_SIZE * 4];
    float sum[4] = {0,0,0,0};
    for (int kid = 0; kid < K / BLOCK_SIZE; kid++){
        smem_A[threadIdx.y][threadIdx.x] = A[yid * K + kid * BLOCK_SIZE + threadIdx.x];
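        // Each thread stages four consecutive B elements (tile columns tx*4 .. tx*4+3).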
        smem_B[threadIdx.y][threadIdx.x * 4]     = B[(kid * BLOCK_SIZE + threadIdx.y) * N + bx * block_offset + threadIdx.x * 4];
        smem_B[threadIdx.y][threadIdx.x * 4 + 1] = B[(kid * BLOCK_SIZE + threadIdx.y) * N + bx * block_offset + threadIdx.x * 4 + 1];
        smem_B[threadIdx.y][threadIdx.x * 4 + 2] = B[(kid * BLOCK_SIZE + threadIdx.y) * N + bx * block_offset + threadIdx.x * 4 + 2];
        smem_B[threadIdx.y][threadIdx.x * 4 + 3] = B[(kid * BLOCK_SIZE + threadIdx.y) * N + bx * block_offset + threadIdx.x * 4 + 3];
        __syncthreads();
        for (int sub_k = 0; sub_k < BLOCK_SIZE; sub_k++){
            sum[0] = fma(smem_A[threadIdx.y][sub_k], smem_B[sub_k][threadIdx.x*4], sum[0]);
            sum[1] = fma(smem_A[threadIdx.y][sub_k], smem_B[sub_k][threadIdx.x*4+1], sum[1]);
            sum[2] = fma(smem_A[threadIdx.y][sub_k], smem_B[sub_k][threadIdx.x*4+2], sum[2]);
            sum[3] = fma(smem_A[threadIdx.y][sub_k], smem_B[sub_k][threadIdx.x*4+3], sum[3]);
        }
        __syncthreads();
    }    
    C[yid * N + bx * block_offset + threadIdx.x * 4] = sum[0];
    C[yid * N + bx * block_offset + threadIdx.x * 4 + 1] = sum[1];
    C[yid * N + bx * block_offset + threadIdx.x * 4 + 2] = sum[2];
    C[yid * N + bx * block_offset + threadIdx.x * 4 + 3] = sum[3];
}
__global__ void gemm5(const float* A, const float* B, float* C, int M, int N, int K){
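    // gemm5: each block covers a (BLOCK_SIZE*4) x (BLOCK_SIZE*4) tile of C, so the grid
    // shrinks in both x and y; each thread computes a 4x4 sub-tile of C.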
    const int tx = threadIdx.x;
    const int ty = threadIdx.y;
    const int bx = blockIdx.x;
    const int by = blockIdx.y;
    int xid = tx + bx * blockDim.x;
    int yid = ty + by * blockDim.y;
    const int block_offset = BLOCK_SIZE * 4;
    __shared__ float smem_A[BLOCK_SIZE * 4][BLOCK_SIZE];
    __shared__ float smem_B[BLOCK_SIZE][BLOCK_SIZE * 4];
    float sum[4][4] = {0.f};
    for (int kid = 0; kid < K / BLOCK_SIZE; kid++){
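        // Stage one A element from each of four consecutive tile rows (ty*4 .. ty*4+3).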
        smem_A[threadIdx.y * 4][threadIdx.x]     = A[(by * block_offset + threadIdx.y * 4) * K + kid * BLOCK_SIZE + threadIdx.x];
        smem_A[threadIdx.y * 4 + 1][threadIdx.x] = A[(by * block_offset + threadIdx.y * 4 + 1) * K + kid * BLOCK_SIZE + threadIdx.x];
        smem_A[threadIdx.y * 4 + 2][threadIdx.x] = A[(by * block_offset + threadIdx.y * 4 + 2) * K + kid * BLOCK_SIZE + threadIdx.x];
        smem_A[threadIdx.y * 4 + 3][threadIdx.x] = A[(by * block_offset + threadIdx.y * 4 + 3) * K + kid * BLOCK_SIZE + threadIdx.x];

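        // Stage four consecutive B elements per thread, same pattern as in gemm4.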
        smem_B[threadIdx.y][threadIdx.x * 4]     = B[(kid * BLOCK_SIZE + threadIdx.y) * N + bx * block_offset + threadIdx.x * 4];
        smem_B[threadIdx.y][threadIdx.x * 4 + 1] = B[(kid * BLOCK_SIZE + threadIdx.y) * N + bx * block_offset + threadIdx.x * 4 + 1];
        smem_B[threadIdx.y][threadIdx.x * 4 + 2] = B[(kid * BLOCK_SIZE + threadIdx.y) * N + bx * block_offset + threadIdx.x * 4 + 2];
        smem_B[threadIdx.y][threadIdx.x * 4 + 3] = B[(kid * BLOCK_SIZE + threadIdx.y) * N + bx * block_offset + threadIdx.x * 4 + 3];
        __syncthreads();
        for (int sub_k = 0; sub_k < BLOCK_SIZE; sub_k++){
            for (int i = 0; i < 4; i++){
                for (int j = 0; j < 4; j++){
                    sum[i][j] = fma(smem_A[threadIdx.y * 4 + i][sub_k],  smem_B[sub_k][threadIdx.x * 4 + j], sum[i][j]);
                }
            }
        } 
        __syncthreads();
    }
    for (int i = 0; i < 4; i++){
        for (int j = 0; j < 4; j++){
            C[(by * block_offset + threadIdx.y * 4 + i) * N + bx * block_offset + threadIdx.x * 4 + j] = sum[i][j];
        }
    }
}

profiling results:

gemm3(...), Context 1, Stream 19
----------------------------------------------------------------------
l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum      0
l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum  160811
----------------------------------------------------------------------

gemm4(...),  Context 1, Stream 20
----------------------------------------------------------------------
l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum      0
l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum  57492
----------------------------------------------------------------------

gemm5(...), Context 1, Stream 21
----------------------------------------------------------------------
l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum      0
l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum      0
----------------------------------------------------------------------

  • For a 32x32 threadblock, at least, your first kernel (`gemm3`) doesn't have any bank-conflicted accesses. [This](https://forums.developer.nvidia.com/t/shared-memory-bank-conflicts-and-nsight-metric/115731/15) may be of interest: "There is not currently a hardware counter that only counts bank conflicts. Other arbitration conflicts that result in a replayed wavefront are included. Summing L1 Wavefronts Shared Excessive on the Source View page is the best method to only count bank conflicts." – Robert Crovella, Aug 20 '23
