Skip to main content

Documentation Index

Fetch the complete documentation index at: https://mintlify.com/NVIDIA/cutlass/llms.txt

Use this file to discover all available pages before exploring further.

Batched GEMM Example

This example demonstrates how to use CUTLASS to compute batched GEMM operations in two different ways:
  1. Strided batched GEMM: Matrices separated by a fixed stride in memory
  2. Array GEMM: Arbitrary pointers to each matrix in the batch

Overview

Batched GEMM operations compute multiple independent matrix multiplications:
C[i] = alpha * (A[i] x B[i]) + beta * C[i]  for i = 0 to batch_count-1
This is common in many applications including neural network training, computer graphics, and scientific computing.

Key Concepts

  • Strided batched GEMM: Efficient when matrices are laid out with uniform spacing
  • Array GEMM: Flexible approach for arbitrary memory layouts
  • Batch stride: Distance in memory between consecutive matrices
  • Performance optimization: Amortize kernel launch overhead across multiple operations

Memory Layout

Consider a batch of 2 matrices with dimensions M=6, N=3, K=2:

Matrix C Layout (M=6, N=3, batch=2)

-----------------------------------------------------------
| (0,0,0) | (0,0,1) | (0,0,2) | (1,0,0) | (1,0,1) | (1,0,2) |
-----------------------------------------------------------
| (0,1,0) | (0,1,1) | (0,1,2) | (1,1,0) | (1,1,1) | (1,1,2) |
-----------------------------------------------------------
|    ...  |   ...   |   ...   |   ...   |   ...   |   ...   |
-----------------------------------------------------------
            batch 0          |           batch 1
Each entry is labeled (batch_idx, row_idx, column_idx). The two batches sit side by side in memory, so the stride between the first elements of consecutive batches is: batch_stride_C = ldc * N

Implementation

Step 1: Strided Batched GEMM

Use when your matrices are laid out with uniform spacing in memory:

#include "cutlass/gemm/device/gemm_batched.h"

/// Runs a single-precision strided batched GEMM through CUTLASS.
///
/// Every matrix is addressed as column-major. Matrix i of each operand starts
/// i * batch_stride_X elements after the base pointer X.
///
/// \param m, n, k          GEMM problem size shared by every batch entry.
/// \param alpha, beta      Scalars forwarded to the CUTLASS epilogue.
/// \param A, B, C          Base pointers of the first matrix of each operand.
/// \param lda, ldb, ldc    Leading dimensions of A, B and C.
/// \param batch_stride_A, batch_stride_B, batch_stride_C
///                         Offsets between consecutive matrices of an operand.
/// \param batch_count      Number of GEMMs in the batch.
/// \return cudaSuccess when CUTLASS reports Status::kSuccess, otherwise
///         cudaErrorUnknown.
cudaError_t cutlass_strided_batched_sgemm(
  int m,
  int n,
  int k,
  float alpha,
  float const *A,
  int lda,
  long long int batch_stride_A,
  float const *B,
  int ldb,
  long long int batch_stride_B,
  float *C,
  int ldc,
  long long int batch_stride_C,
  float beta,
  int batch_count) {

  using ColumnMajor = cutlass::layout::ColumnMajor;

  using Gemm = cutlass::gemm::device::GemmBatched<
    float, ColumnMajor,
    float, ColumnMajor,
    float, ColumnMajor
  >;

  // C (with its leading dimension and batch stride) is supplied twice in a
  // row, so the same buffer fills both of the trailing tensor slots.
  Gemm::Arguments arguments{
    {m, n, k},
    {A, lda}, batch_stride_A,
    {B, ldb}, batch_stride_B,
    {C, ldc}, batch_stride_C,
    {C, ldc}, batch_stride_C,
    {alpha, beta},
    batch_count
  };

  Gemm gemm_op;
  cutlass::Status const status = gemm_op(arguments);

  // Any CUTLASS failure is collapsed into a generic CUDA error code.
  return (status == cutlass::Status::kSuccess) ? cudaSuccess : cudaErrorUnknown;
}
Step 2: Array GEMM

Use when matrices are scattered in memory with irregular spacing:

#include "cutlass/gemm/device/gemm_array.h"

/// Runs a single-precision array (pointer-per-matrix) batched GEMM through CUTLASS.
///
/// Each operand is described by an array of batch_count pointers instead of a
/// base pointer plus stride. Every matrix is addressed as column-major.
///
/// \param m, n, k        GEMM problem size shared by every batch entry.
/// \param alpha, beta    Scalars forwarded to the CUTLASS epilogue.
/// \param A, B, C        Arrays of per-matrix pointers for each operand.
/// \param lda, ldb, ldc  Leading dimensions of the A, B and C matrices.
/// \param batch_count    Number of GEMMs in the batch.
/// \return cudaSuccess when CUTLASS reports Status::kSuccess, otherwise
///         cudaErrorUnknown.
cudaError_t cutlass_array_sgemm(
  int m,
  int n,
  int k,
  float alpha,
  float const * const *A,  // Array of pointers
  int lda,
  float const * const *B,  // Array of pointers
  int ldb,
  float * const *C,        // Array of pointers
  int ldc,
  float beta,
  int batch_count) {

  using ColumnMajor = cutlass::layout::ColumnMajor;

  using Gemm = cutlass::gemm::device::GemmArray<
    float, ColumnMajor,
    float, ColumnMajor,
    float, ColumnMajor
  >;

  // The C pointer array and its leading dimension are supplied twice in a
  // row, so the same buffers fill both of the trailing tensor slots.
  Gemm::Arguments arguments{
    {m, n, k},
    A, lda,
    B, ldb,
    C, ldc,
    C, ldc,
    {alpha, beta},
    batch_count
  };

  Gemm gemm_op;
  cutlass::Status const status = gemm_op(arguments);

  // Any CUTLASS failure is collapsed into a generic CUDA error code.
  return (status == cutlass::Status::kSuccess) ? cudaSuccess : cudaErrorUnknown;
}
Step 3: Calculate batch strides

For strided batched GEMM, calculate the stride (in elements) between the first elements of consecutive matrices:

// Problem size of each GEMM and the number of GEMMs in the batch.
int m = 520;
int n = 219;
int k = 129;
int batch_count = 17;

// Leading dimensions. A and C use their own row count (m). B uses
// k * batch_count, i.e. one leading dimension spanning the K extents of
// every matrix in the batch.
int lda = m;
int ldc = m;
int ldb = k * batch_count;

// Element offsets between consecutive matrices in the batch.
// The first factor is widened to 64 bits so the product is computed without
// 32-bit overflow.
long long int batch_stride_A = static_cast<long long int>(lda) * k;
long long int batch_stride_B = k;  // consecutive B matrices start k elements apart
long long int batch_stride_C = static_cast<long long int>(ldc) * n;
Step 4: Set up the array of pointers for Array GEMM

For array GEMM, build an array of pointers (one per matrix) on the host and copy it to the device:

// Host-side staging arrays: one pointer per matrix in the batch.
// NOTE(review): A, B and C are assumed to be device buffers holding every
// matrix of the batch - confirm against the surrounding code.
std::vector<float*> host_ptr_A(batch_count);
std::vector<float*> host_ptr_B(batch_count);
std::vector<float*> host_ptr_C(batch_count);

// Matrices can be in any order - not required to be uniformly spaced.
// The table must hold at least batch_count entries (17 here); each entry is
// the index of the matrix that batch slot b_idx should operate on.
std::vector<size_t> permutation = {14, 11, 3, 10, 1, 13, 9, 4, 6, 16, 8, 15, 7, 12, 0, 2, 5};
for (size_t b_idx = 0; b_idx < static_cast<size_t>(batch_count); b_idx++) {
  host_ptr_A[b_idx] = A + permutation[b_idx] * batch_stride_A;
  host_ptr_B[b_idx] = B + permutation[b_idx] * batch_stride_B;
  host_ptr_C[b_idx] = C + permutation[b_idx] * batch_stride_C;
}

// Device-side pointer arrays. They start as nullptr so the cleanup at the end
// is safe no matter which step fails (cudaFree on a null pointer does nothing).
float const **ptr_A = nullptr;
float const **ptr_B = nullptr;
float **ptr_C = nullptr;

size_t const ptr_array_bytes = static_cast<size_t>(batch_count) * sizeof(float*);

// Each step runs only if every previous step succeeded; the first failure is
// kept in array_result.
cudaError_t array_result = cudaMalloc(&ptr_A, ptr_array_bytes);
if (array_result == cudaSuccess) {
  array_result = cudaMalloc(&ptr_B, ptr_array_bytes);
}
if (array_result == cudaSuccess) {
  array_result = cudaMalloc(&ptr_C, ptr_array_bytes);
}

// Copy pointers to device
if (array_result == cudaSuccess) {
  array_result = cudaMemcpy(ptr_A, host_ptr_A.data(), ptr_array_bytes, cudaMemcpyHostToDevice);
}
if (array_result == cudaSuccess) {
  array_result = cudaMemcpy(ptr_B, host_ptr_B.data(), ptr_array_bytes, cudaMemcpyHostToDevice);
}
if (array_result == cudaSuccess) {
  array_result = cudaMemcpy(ptr_C, host_ptr_C.data(), ptr_array_bytes, cudaMemcpyHostToDevice);
}

// Launch array GEMM
if (array_result == cudaSuccess) {
  array_result = cutlass_array_sgemm(m, n, k, alpha, ptr_A, lda, ptr_B, ldb, ptr_C, ldc, beta, batch_count);
}

// Wait for the GEMM to finish before releasing the pointer arrays it reads;
// this also surfaces errors raised asynchronously during execution.
if (array_result == cudaSuccess) {
  array_result = cudaDeviceSynchronize();
}

// Release the device pointer arrays (the matrices themselves are untouched).
cudaFree(ptr_A);
cudaFree(ptr_B);
cudaFree(ptr_C);

if (array_result != cudaSuccess) {
  std::cerr << "Array GEMM failed: " << cudaGetErrorString(array_result) << std::endl;
}
Step 5: Complete example

The example's main function runs both approaches through run_batched_gemm (not shown here; see the full source under Source Code Location):

/// Entry point: runs run_batched_gemm once with use_array = false and once
/// with use_array = true, stopping at the first failure.
///
/// \return 0 when both runs succeed, -1 otherwise.
int main() {
  bool const use_array_modes[] = {false, true};

  cudaError_t result = cudaSuccess;

  for (bool use_array : use_array_modes) {
    result = run_batched_gemm(use_array);

    // Stop at the first failing mode; its error code decides the exit status.
    if (result != cudaSuccess) {
      break;
    }

    std::cout << "Passed." << std::endl;
  }

  return (result != cudaSuccess) ? -1 : 0;
}

Building and Running

Build the example

cd /path/to/cutlass
mkdir build && cd build
cmake .. -DCUTLASS_NVCC_ARCHS='75;80;86'
make 05_batched_gemm

Run the example

./examples/05_batched_gemm/05_batched_gemm
Expected output:
Running strided batched gemm
Passed.
Running array gemm
Passed.

Source Code Location

The complete source code for this example is available at:
  • examples/05_batched_gemm/batched_gemm.cu

What This Example Demonstrates

  1. Two batching modes: Both strided and array-based batched GEMM
  2. Flexible memory layouts: How to handle both regular and irregular memory patterns
  3. Pointer management: Setting up device pointer arrays for array GEMM
  4. Correctness verification: Reference implementation for validating results

Performance Considerations

  • Strided batched GEMM is typically faster when matrices are uniformly spaced because:
    • Simpler addressing logic
    • Better memory access patterns
    • Less pointer indirection
  • Array GEMM provides flexibility when:
    • Matrices are scattered in memory
    • Each batch item comes from different allocations
    • You need arbitrary ordering of operations

Key Takeaways

  • Use GemmBatched for strided batched operations with uniform spacing
  • Use GemmArray for arbitrary pointer arrays with irregular layouts
  • Batch operations amortize kernel launch overhead across multiple GEMMs
  • Both approaches share the same underlying optimizations for individual matrix multiplications

Next Steps