Documentation Index
Fetch the complete documentation index at: https://mintlify.com/NVIDIA/cutlass/llms.txt
Use this file to discover all available pages before exploring further.
Batched GEMM Example
This example demonstrates how to use CUTLASS to compute batched GEMM operations in two different ways:
- Strided batched GEMM: the matrices are separated by a fixed stride in memory
- Array GEMM: an arbitrary pointer is supplied for each matrix in the batch
Overview
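Batched GEMM operations compute multiple independent matrix multiplications, one per batch index $i$:

$$C_i \leftarrow \alpha\, A_i B_i + \beta\, C_i, \qquad i = 0, \dots, \text{batch\_count} - 1$$

Key Concepts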
- Strided batched GEMM: Efficient when matrices are laid out with uniform spacing
- Array GEMM: Flexible approach for arbitrary memory layouts
- Batch stride: Distance in memory between consecutive matrices
- Performance optimization: Amortize kernel launch overhead across multiple operations
Memory Layout
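Consider a batch of 2 matrices with dimensions M=6, N=3, K=2. The figure below shows the layout of matrix C (M=6, N=3, batch=2), stored column-major, where each element is labeled (batch_idx, row_idx, column_idx). With a leading dimension of ldc = M = 6, each matrix occupies ldc × N = 18 consecutive elements, and the next matrix in the batch begins right after it.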
The batch stride of C, in elements, is $\text{batch\_stride\_C} = \text{ldc} \times N$.
Implementation
#include "cutlass/gemm/device/gemm_batched.h"
// Runs `batch_count` single-precision GEMMs whose operands are spaced by fixed
// element strides, using CUTLASS's GemmBatched device operator.
//
// All three operands are column-major. Batch b reads A + b * batch_stride_A
// and B + b * batch_stride_B. Its result is written to C + b * batch_stride_C.
// C is passed as both the epilogue source and the destination, so it is
// updated in place. With CUTLASS's default linear-combination epilogue this is
// presumably C_b = alpha * A_b * B_b + beta * C_b (confirm against the
// epilogue GemmBatched is instantiated with).
//
// Returns cudaSuccess when CUTLASS reports kSuccess. Any other CUTLASS status
// is collapsed into cudaErrorUnknown; the specific status is not propagated.
cudaError_t cutlass_strided_batched_sgemm(
    int m, int n, int k,
    float alpha,
    float const *A, int lda, long long int batch_stride_A,
    float const *B, int ldb, long long int batch_stride_B,
    float *C, int ldc, long long int batch_stride_C,
    float beta,
    int batch_count) {
  using ColumnMajor = cutlass::layout::ColumnMajor;
  using Gemm = cutlass::gemm::device::GemmBatched<
      float, ColumnMajor,   // A
      float, ColumnMajor,   // B
      float, ColumnMajor>;  // C / D

  // Argument order: problem size, A, B, source C, destination D,
  // epilogue scalars, number of batches.
  Gemm::Arguments const args{
      {m, n, k},
      {A, lda}, batch_stride_A,
      {B, ldb}, batch_stride_B,
      {C, ldc}, batch_stride_C,   // epilogue source
      {C, ldc}, batch_stride_C,   // destination aliases the source
      {alpha, beta},
      batch_count};

  Gemm gemm_op;
  cutlass::Status const status = gemm_op(args);
  return status == cutlass::Status::kSuccess ? cudaSuccess : cudaErrorUnknown;
}
#include "cutlass/gemm/device/gemm_array.h"
// Runs `batch_count` single-precision GEMMs whose operands are located through
// per-batch pointer arrays, using CUTLASS's GemmArray device operator.
//
// A[b], B[b] and C[b] point at the b-th column-major matrix. All batches share
// the same m/n/k and leading dimensions lda/ldb/ldc. The pointer arrays are
// consumed by the GEMM on the GPU, so they presumably must be device-accessible
// (the usage excerpt on this page allocates them with cudaMalloc).
// C is both the epilogue source and the destination, so results overwrite C.
//
// Returns cudaSuccess when CUTLASS reports kSuccess, otherwise cudaErrorUnknown.
cudaError_t cutlass_array_sgemm(
    int m, int n, int k,
    float alpha,
    float const * const *A, int lda,   // array of per-batch A pointers
    float const * const *B, int ldb,   // array of per-batch B pointers
    float * const *C, int ldc,         // array of per-batch C pointers
    float beta,
    int batch_count) {
  using ColumnMajor = cutlass::layout::ColumnMajor;
  using Gemm = cutlass::gemm::device::GemmArray<
      float, ColumnMajor,   // A
      float, ColumnMajor,   // B
      float, ColumnMajor>;  // C / D

  Gemm::Arguments const args{
      {m, n, k},
      A, lda,
      B, ldb,
      C, ldc,   // epilogue source
      C, ldc,   // destination aliases the source
      {alpha, beta},
      batch_count};

  Gemm gemm_op;
  if (gemm_op(args) != cutlass::Status::kSuccess) {
    return cudaErrorUnknown;
  }
  return cudaSuccess;
}
// NOTE(review): this is an excerpt of a larger function. The early `return`s
// below assume the enclosing function returns cudaError_t (main() consumes
// run_batched_gemm's result as a cudaError_t) -- confirm against the full source.
int m = 520, n = 219, k = 129;
int batch_count = 17;
// All operands are column-major and non-transposed.
int lda = m;
// B stacks every batch's K x N matrix along the K (row) dimension. Its leading
// dimension therefore spans all batches, and consecutive batches are only k
// elements apart (see batch_stride_B).
int ldb = k * batch_count;
int ldc = m;
// Element stride between consecutive matrices in the batch. It is computed in
// 64-bit so the products cannot overflow int.
long long int batch_stride_A = static_cast<long long int>(lda) * static_cast<long long int>(k);
long long int batch_stride_B = static_cast<long long int>(k);
long long int batch_stride_C = static_cast<long long int>(ldc) * static_cast<long long int>(n);

// Array GEMM: matrices can be visited in any order; they are not required to be
// uniformly spaced. Batch index b is mapped to matrix permutation[b] of the
// strided layout above.
std::vector<size_t> permutation = {14, 11, 3, 10, 1, 13, 9, 4, 6, 16, 8, 15, 7, 12, 0, 2, 5};
// The hard-coded list is a permutation of 0..16. It is therefore only valid for
// exactly permutation.size() batches:
// - more batches would read permutation out of bounds;
// - fewer batches would yield matrix indices beyond the buffers.
if (batch_count < 0 || static_cast<size_t>(batch_count) != permutation.size()) {
  return cudaErrorInvalidValue;
}

// Host-side tables holding one matrix pointer per batch.
std::vector<float*> host_ptr_A(batch_count);
std::vector<float*> host_ptr_B(batch_count);
std::vector<float*> host_ptr_C(batch_count);
for (int b_idx = 0; b_idx < batch_count; ++b_idx) {
  host_ptr_A[b_idx] = A + permutation[b_idx] * batch_stride_A;
  host_ptr_B[b_idx] = B + permutation[b_idx] * batch_stride_B;
  host_ptr_C[b_idx] = C + permutation[b_idx] * batch_stride_C;
}

// Device copies of the pointer tables, which the array GEMM reads on the GPU.
float const **ptr_A = nullptr;
float const **ptr_B = nullptr;
float **ptr_C = nullptr;
size_t const ptr_table_bytes = static_cast<size_t>(batch_count) * sizeof(float*);

// Every CUDA call is checked; after the first failure the remaining steps are skipped.
cudaError_t array_result = cudaMalloc(&ptr_A, ptr_table_bytes);
if (array_result == cudaSuccess) {
  array_result = cudaMalloc(&ptr_B, ptr_table_bytes);
}
if (array_result == cudaSuccess) {
  array_result = cudaMalloc(&ptr_C, ptr_table_bytes);
}
if (array_result == cudaSuccess) {
  array_result = cudaMemcpy(ptr_A, host_ptr_A.data(), ptr_table_bytes, cudaMemcpyHostToDevice);
}
if (array_result == cudaSuccess) {
  array_result = cudaMemcpy(ptr_B, host_ptr_B.data(), ptr_table_bytes, cudaMemcpyHostToDevice);
}
if (array_result == cudaSuccess) {
  array_result = cudaMemcpy(ptr_C, host_ptr_C.data(), ptr_table_bytes, cudaMemcpyHostToDevice);
}

// Launch array GEMM.
if (array_result == cudaSuccess) {
  array_result = cutlass_array_sgemm(m, n, k, alpha, ptr_A, lda, ptr_B, ldb, ptr_C, ldc, beta, batch_count);
}
// Wait for the GEMM to finish before releasing the tables it reads. This also
// surfaces any asynchronous execution error.
if (array_result == cudaSuccess) {
  array_result = cudaDeviceSynchronize();
}

// The pointer tables are only needed by this launch. Release them on every
// path; cudaFree(nullptr) is a no-op for tables that were never allocated.
cudaFree(ptr_A);
cudaFree(ptr_B);
cudaFree(ptr_C);
if (array_result != cudaSuccess) {
  return array_result;
}
// Entry point: runs the strided-batched path and then the pointer-array path.
// Prints "Passed." after each successful run and stops at the first failure.
// Exits with 0 when every run succeeded, otherwise -1.
int main() {
  cudaError_t result = cudaSuccess;
  bool const modes[] = {false, true};  // false: strided batched, true: array GEMM
  for (bool const use_array : modes) {
    result = run_batched_gemm(use_array);
    if (result != cudaSuccess) {
      break;
    }
    std::cout << "Passed." << std::endl;
  }
  return (result != cudaSuccess) ? -1 : 0;
}
Building and Running
Build the example
Run the example
Source Code Location
The complete source code for this example is available at `examples/05_batched_gemm/batched_gemm.cu`.
What This Example Demonstrates
- Two batching modes: Both strided and array-based batched GEMM
- Flexible memory layouts: How to handle both regular and irregular memory patterns
- Pointer management: Setting up device pointer arrays for array GEMM
- Correctness verification: Reference implementation for validating results
Performance Considerations
- Strided batched GEMM is typically faster when matrices are uniformly spaced, because it offers:
  - Simpler addressing logic
  - Better memory access patterns
  - Less pointer indirection
- Array GEMM provides flexibility when:
  - Matrices are scattered in memory
  - Each batch item comes from a different allocation
  - You need an arbitrary ordering of the operations
Key Takeaways
- Use `GemmBatched` for strided batched operations with uniform spacing.
- Use `GemmArray` for arbitrary pointer arrays with irregular layouts.
- Batch operations amortize kernel launch overhead across multiple GEMMs.
- Both approaches share the same underlying optimizations for individual matrix multiplications
Next Steps
- Learn about Basic GEMM for single matrix multiplication
- Explore Fused Operations to combine GEMM with activation functions
- Check out Convolution for batched convolution operations