Grouped GEMM enables you to compute a batch of GEMM operations where each operation can have a different problem size, unlike standard batched GEMM where all operations have the same dimensions.
Overview
Grouped GEMM is ideal for scenarios where you need to perform multiple matrix multiplications with varying dimensions, such as:
Multi-head attention with different head sizes
Variable-length sequence processing
Sparse neural network layers
Mixture-of-experts (MoE) models
What is Grouped GEMM?
Grouped GEMM differs from “Batched Array” GEMM:
Batched GEMM: All matrices have the same dimensions (M, N, K)
Grouped GEMM: Each group can have different dimensions
Each group i computes C[i] = A[i] × B[i], where A[i] is M_i × K_i, B[i] is K_i × N_i, and C[i] is M_i × N_i; M, N, and K can all differ from group to group.
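Conceptually, a grouped GEMM produces the same result as looping over the groups and running an independent GEMM for each one. The minimal NumPy sketch below shows that reference semantics, which is also handy for checking kernel output. The helper name grouped_gemm_reference is hypothetical and not part of CUTLASS; accumulating in float32 mirrors an --acc_dtype Float32 configuration.

```python
import numpy as np

def grouped_gemm_reference(As, Bs):
    """Hypothetical reference: C[i] = A[i] @ B[i] for every group, shapes may differ."""
    if len(As) != len(Bs):
        raise ValueError("A and B lists must have the same number of groups")
    Cs = []
    for i, (a, b) in enumerate(zip(As, Bs)):
        # Group i: A[i] is (M_i, K_i), B[i] is (K_i, N_i), C[i] is (M_i, N_i)
        if a.shape[1] != b.shape[0]:
            raise ValueError(f"group {i}: K mismatch {a.shape} vs {b.shape}")
        # Accumulate in float32, like a Float32 accumulator
        Cs.append(a.astype(np.float32) @ b.astype(np.float32))
    return Cs

# Three groups with different (M, N, K)
rng = np.random.default_rng(0)
shapes = [(64, 32, 16), (8, 128, 48), (256, 16, 32)]
As = [rng.standard_normal((m, k)).astype(np.float16) for m, n, k in shapes]
Bs = [rng.standard_normal((k, n)).astype(np.float16) for m, n, k in shapes]
Cs = grouped_gemm_reference(As, Bs)
print([c.shape for c in Cs])  # [(64, 32), (8, 128), (256, 16)]
```

A standard batched GEMM could not express this input directly, because the three groups share no common (M, N, K).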
Implementation Example
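Here's an abridged walkthrough of a grouped GEMM kernel for the Blackwell architecture, written with the CuTe DSL (the full implementation is listed under Complete Working Example):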
Step 1: Define the Kernel
import cuda.bindings.driver as cuda  # provides cuda.CUstream used in Step 2
import cutlass
import cutlass.cute as cute
import cutlass.utils as utils
import cutlass.utils.blackwell_helpers as sm100_utils
from cutlass.cute.nvgpu import cpasync, tcgen05
class GroupedGemmKernel:
    def __init__(
        self,
        acc_dtype,
        use_2cta_instrs,
        mma_tiler_mn,
        cluster_shape_mn,
        tensormap_update_mode=utils.TensorMapUpdateMode.SMEM,
    ):
        self.acc_dtype = acc_dtype
        self.use_2cta_instrs = use_2cta_instrs
        # The K extent of the MMA tiler is a placeholder (1) at construction time
        self.mma_tiler = (*mma_tiler_mn, 1)
        self.cluster_shape_mn = cluster_shape_mn
        self.tensormap_update_mode = tensormap_update_mode
        # Use 2-CTA tcgen05 MMA instructions when requested
        self.cta_group = (
            tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE
        )
Step 2: Set Up and Launch the Kernel
@cute.jit
def __call__(
    self,
    initial_a: cute.Tensor,
    initial_b: cute.Tensor,
    initial_c: cute.Tensor,
    group_count: int,
    problem_shape_mnkl: cute.Tensor,
    strides_abc: cute.Tensor,
    tensor_address_abc: cute.Tensor,
    total_num_clusters: int,
    tensormap_cute_tensor: cute.Tensor,
    max_active_clusters: int,
    stream: cuda.CUstream,
):
    # Set up attributes based on the input tensors
    self.a_dtype = initial_a.element_type
    self.b_dtype = initial_b.element_type
    self.c_dtype = initial_c.element_type
    # (Major modes, SMEM layouts, cluster_layout_vmnk, and grid are
    #  computed by setup code omitted from this excerpt.)
    # Create the tiled MMA for the tcgen05 instructions
    tiled_mma = sm100_utils.make_trivial_tiled_mma(
        self.a_dtype,
        self.a_major_mode,
        self.b_major_mode,
        self.acc_dtype,
        self.cta_group,
        self.mma_tiler[:2],
    )
    # Set up the TMA atom for A (B and C follow the same pattern)
    a_op = sm100_utils.cluster_shape_to_tma_atom_A(
        self.cluster_shape_mn, tiled_mma.thr_id
    )
    tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A(
        a_op,
        initial_a,
        a_smem_layout,
        self.mma_tiler,
        tiled_mma,
        self.cluster_layout_vmnk.shape,
    )
    # Launch the kernel (arguments elided)
    self.kernel( ... ).launch(
        grid=grid,
        block=[self.threads_per_cta, 1, 1],
        cluster=(*self.cluster_shape_mn, 1),
        stream=stream,
    )
Step 3: Implement the Device Kernel
@cute.kernel
def kernel(
    self,
    tiled_mma: cute.TiledMma,
    tma_atom_a: cute.CopyAtom,
    mA_mkl: cute.Tensor,
    tma_atom_b: cute.CopyAtom,
    mB_nkl: cute.Tensor,
    tma_atom_c: cute.CopyAtom,
    mC_mnl: cute.Tensor,
    # ... other parameters
):
    # Warp specialization
    warp_idx = cute.arch.warp_idx()
    # TMA warp: Load data
    if warp_idx == self.tma_warp_id:
        # Update tensormaps for each group
        # Perform TMA loads
        pass
    # MMA warp: Compute
    if warp_idx == self.mma_warp_id:
        # Perform matrix multiply-accumulate
        pass
    # Epilogue warps: Store results
    if warp_idx < self.mma_warp_id:
        # Store results to global memory
        pass
Running the Example
Prepare problem sizes
Define the dimensions (M, N, K, L) for each group:
problem_sizes = [
    (8192, 1280, 32, 1),    # Group 0: M=8192, N=1280, K=32, L=1
    (16, 384, 1536, 1),     # Group 1: M=16, N=384, K=1536, L=1
    (640, 1280, 16, 1),     # Group 2: M=640, N=1280, K=16, L=1
    (640, 160, 16, 1),      # Group 3: M=640, N=160, K=16, L=1
]
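Because every group has its own shape, the number of output tiles differs per group. The sketch below counts them for the sizes above with --mma_tiler_mn 128,64 and --cluster_shape_mn 1,1; the total is the kind of quantity the total_num_clusters argument in Step 2 represents. The helper names are hypothetical, and the rule that a 2-CTA MMA tile is split across two CTAs along M is an assumption stated for illustration, not a copy of the CUTLASS code.

```python
import math

# (M, N, K, L) per group, as defined above
problem_sizes = [
    (8192, 1280, 32, 1),
    (16, 384, 1536, 1),
    (640, 1280, 16, 1),
    (640, 160, 16, 1),
]

def cluster_tile_shape(mma_tiler_mn, cluster_shape_mn, use_2cta_instrs):
    # Assumption: with 2-CTA instructions one MMA tile spans two CTAs along M
    cta_m = mma_tiler_mn[0] // (2 if use_2cta_instrs else 1)
    cta_n = mma_tiler_mn[1]
    return cta_m * cluster_shape_mn[0], cta_n * cluster_shape_mn[1]

def clusters_per_group(problem_sizes, tile_mn):
    # Partial tiles at the M/N edges still need a full cluster
    tm, tn = tile_mn
    return [math.ceil(m / tm) * math.ceil(n / tn) for m, n, _k, _l in problem_sizes]

tile = cluster_tile_shape((128, 64), (1, 1), use_2cta_instrs=False)
counts = clusters_per_group(problem_sizes, tile)
print(tile)         # (128, 64)
print(counts)       # [1280, 6, 100, 15]
print(sum(counts))  # 1401
```

Group 1 (M=16) needs only 6 tiles, and each of them is mostly padding along M, which is one reason tile choice matters for mixed workloads.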
Run from command line
Execute the grouped GEMM example from the repository root:
python examples/python/CuTeDSL/blackwell/grouped_gemm.py \
    --ab_dtype Float16 --c_dtype Float16 --acc_dtype Float32 \
    --mma_tiler_mn 128,64 --cluster_shape_mn 1,1 \
    --problem_sizes_mnkl "(8192,1280,32,1),(16,384,1536,1),(640,1280,16,1),(640,160,16,1)" \
    --num_groups 4 --tensormap_update_mode SMEM
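When the group list is generated programmatically, the --problem_sizes_mnkl string can be built from the same tuples so that it stays consistent with --num_groups. The helper below is a hypothetical convenience, assuming the comma-separated "(M,N,K,L)" format shown in the command above.

```python
# Hypothetical helper: format (M, N, K, L) tuples as "(m,n,k,l),(m,n,k,l),..."
def format_problem_sizes(problem_sizes):
    return ",".join("(" + ",".join(str(d) for d in p) + ")" for p in problem_sizes)

problem_sizes = [(8192, 1280, 32, 1), (16, 384, 1536, 1), (640, 1280, 16, 1), (640, 160, 16, 1)]
args = [
    "--problem_sizes_mnkl", format_problem_sizes(problem_sizes),
    # Keep the group count in sync with the list length
    "--num_groups", str(len(problem_sizes)),
]
print(args[1])  # (8192,1280,32,1),(16,384,1536,1),(640,1280,16,1),(640,160,16,1)
```

The resulting list can be appended to the other flags, for example when launching the script through subprocess.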
Profile with NCU
Analyze performance with Nsight Compute:
ncu python examples/python/CuTeDSL/blackwell/grouped_gemm.py \
    --ab_dtype Float16 --c_dtype Float16 --acc_dtype Float32 \
    --mma_tiler_mn 128,64 --cluster_shape_mn 1,1 \
    --problem_sizes_mnkl "(8192,1280,32,1),(16,384,1536,1),(640,1280,16,1),(640,160,16,1)" \
    --num_groups 4 --tensormap_update_mode SMEM \
    --warmup_iterations 1 --iterations 10 --skip_ref_check
Key Features
Warp Specialization
Grouped GEMM uses specialized warps for different tasks:
TMA Warp: Handles tensormap updates and data loading
MMA Warp: Performs matrix multiply-accumulate operations
Epilogue Warps: Handle result storage and post-processing
Because loading, computing, and storing run in separate warps, loads for upcoming tiles can overlap with the MMAs and stores of the current tile, which hides memory latency.
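The role dispatch in the Step 3 kernel can be modeled on the host in plain Python. This sketch assumes a six-warp CTA with four epilogue warps (indices 0 to 3), the MMA warp at index 4, and the TMA warp at index 5; these IDs are illustrative assumptions chosen to be consistent with the kernel's `warp_idx < self.mma_warp_id` check, not values taken from this page.

```python
# Assumed warp layout (hypothetical): 4 epilogue warps, then the MMA warp, then the TMA warp
EPILOGUE_WARP_IDS = (0, 1, 2, 3)
MMA_WARP_ID = 4
TMA_WARP_ID = 5
THREADS_PER_CTA = 32 * (len(EPILOGUE_WARP_IDS) + 2)

def warp_role(warp_idx):
    """Mirror the kernel's if-chain: each warp gets exactly one role."""
    if warp_idx == TMA_WARP_ID:
        return "tma"       # tensormap updates and TMA loads
    if warp_idx == MMA_WARP_ID:
        return "mma"       # matrix multiply-accumulate
    if warp_idx < MMA_WARP_ID:
        return "epilogue"  # store results to global memory
    raise ValueError(f"warp {warp_idx} is outside the CTA")

roles = [warp_role(w) for w in range(THREADS_PER_CTA // 32)]
print(THREADS_PER_CTA, roles)
# 192 ['epilogue', 'epilogue', 'epilogue', 'epilogue', 'mma', 'tma']
```

Placing the epilogue warps at the lowest indices lets a single comparison select all of them.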
Tensormap Update Modes
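Grouped GEMM supports two modes for updating tensormaps, selected by the tensormap_update_mode argument (--tensormap_update_mode on the command line):
# Update tensormaps in global memory
tensormap_update_mode = utils.TensorMapUpdateMode.GMEM
# Update tensormaps in shared memory
tensormap_update_mode = utils.TensorMapUpdateMode.SMEM
# Buffers 3 tensormaps (A, B, C) in SMEM (128 B each)
# Better for workloads with frequent group changes
Performance varies by workload; profile both modes to find the optimal choice.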
Persistent Tile Scheduling
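The kernel uses persistent tile scheduling: it launches a bounded number of clusters (limited by max_active_clusters), and each cluster loops over output tiles drawn from all groups. This helps to:
Minimize kernel launch and CTA setup overhead
Balance load across groups of very different sizes
Keep the launched clusters busy across group boundaries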
# Tile scheduler handles work distribution
tile_sched = utils.StaticPersistentGroupTileScheduler.create(
    tile_sched_params,
    bid,
    grid_dim,
    cluster_tile_shape_mnk,
    utils.create_initial_search_state(),
    group_count,
    problem_sizes_mnkl,
)
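To build intuition for static persistent scheduling across groups, here is a simplified host-side model: all tiles of all groups are flattened into one linear index, each cluster starts at its own ID and strides by the grid size, and the owning group is found by binary search over cumulative tile counts. This is a sketch under stated assumptions, not the StaticPersistentGroupTileScheduler implementation; the function names, the row-major tile order inside a group, and max_active_clusters=128 are all hypothetical.

```python
import bisect
import math
from itertools import accumulate

# (M, N, K, L) per group, as in the example above
problem_sizes = [(8192, 1280, 32, 1), (16, 384, 1536, 1), (640, 1280, 16, 1), (640, 160, 16, 1)]

def grid_clusters(total_num_clusters, max_active_clusters):
    # Launch no more clusters than can be resident at once
    return min(total_num_clusters, max_active_clusters)

def schedule(problem_sizes, tile_mn, max_active_clusters):
    """Map every output tile to a cluster: {cluster_id: [(group, m_tile, n_tile), ...]}."""
    tm, tn = tile_mn
    tiles_m = [math.ceil(m / tm) for m, _n, _k, _l in problem_sizes]
    tiles_n = [math.ceil(n / tn) for _m, n, _k, _l in problem_sizes]
    # Cumulative tile counts give each group's end offset in the flattened index
    ends = list(accumulate(a * b for a, b in zip(tiles_m, tiles_n)))
    total = ends[-1]
    grid = grid_clusters(total, max_active_clusters)
    work = {cid: [] for cid in range(grid)}
    for cid in range(grid):
        # Static round-robin: start at the cluster id, stride by the grid size
        for linear in range(cid, total, grid):
            g = bisect.bisect_right(ends, linear)  # binary search for the owning group
            local = linear - (ends[g - 1] if g > 0 else 0)
            # Assumed order inside a group: row-major over (m_tile, n_tile)
            work[cid].append((g, local // tiles_n[g], local % tiles_n[g]))
    return work

work = schedule(problem_sizes, (128, 64), max_active_clusters=128)
print(len(work))    # 128
print(work[0][:3])  # [(0, 0, 0), (0, 6, 8), (0, 12, 16)]
```

With 1401 tiles and 128 clusters, every cluster processes 10 or 11 tiles, regardless of which groups those tiles come from.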
Constraints and Considerations
The following constraints apply to grouped GEMM:
Only FP16 and BF16 data types are supported for A and B
Output (C) can be FP16, BF16, or FP32
The contiguous dimension of A, B, and C must be 16-byte aligned (a multiple of 8 elements for 16-bit types, 4 for FP32)
Batch size (L) must be 1 for each group
All groups must have the same majorness for A, B, and C
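A host-side pre-flight check can catch dtype and batch-size violations before any kernel compiles. The function below is a hypothetical sketch that encodes only the dtype and L constraints listed above; dtype names follow the command-line flags (Float16, BFloat16, Float32), and majorness is not checked because plain size tuples don't carry it.

```python
SUPPORTED_AB = {"Float16", "BFloat16"}
SUPPORTED_C = {"Float16", "BFloat16", "Float32"}

def check_grouped_gemm_config(ab_dtype, c_dtype, problem_sizes):
    """Hypothetical pre-flight check; returns a list of error messages (empty if OK)."""
    errors = []
    if ab_dtype not in SUPPORTED_AB:
        errors.append(f"A/B dtype {ab_dtype} is not supported")
    if c_dtype not in SUPPORTED_C:
        errors.append(f"C dtype {c_dtype} is not supported")
    for i, (m, n, k, l) in enumerate(problem_sizes):
        # Each group must be a single (non-batched) GEMM
        if l != 1:
            errors.append(f"group {i}: L must be 1, got {l}")
    return errors

print(check_grouped_gemm_config("Float16", "Float32", [(8192, 1280, 32, 1)]))  # []
print(check_grouped_gemm_config("Int8", "Float16", [(16, 16, 16, 2)]))
```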
Choosing MMA Tile Size
Select tile sizes based on your problem sizes:
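Small problems: Use smaller tiles (64×64, 128×64)
Large problems: Use larger tiles (128×128, 256×128); an M extent of 256 requires use_2cta_instrs
Mixed sizes: Choose a tile that limits padding in the smallest groups while keeping tiles large enough for the biggest ones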
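One first-order way to compare tile sizes is the fraction of each computed tile that actually lands inside the output; the rest is padding work. The hypothetical helper below applies this to the M and N values from the running example. It ignores K, pipelining, and occupancy, so treat it as a rough guide rather than a performance model.

```python
import math

def tile_utilization(m, n, tile_m, tile_n):
    """Fraction of computed tile area that covers the real M x N output."""
    tiles = math.ceil(m / tile_m) * math.ceil(n / tile_n)
    return (m * n) / (tiles * tile_m * tile_n)

problem_sizes_mn = [(8192, 1280), (16, 384), (640, 1280), (640, 160)]
for tile in [(64, 64), (128, 64), (128, 128)]:
    fractions = [round(tile_utilization(m, n, *tile), 3) for m, n in problem_sizes_mn]
    print(tile, fractions)
# (64, 64) [1.0, 0.25, 1.0, 0.833]
# (128, 64) [1.0, 0.125, 1.0, 0.833]
# (128, 128) [1.0, 0.125, 1.0, 0.625]
```

Here the large groups are covered exactly by every candidate, so the choice is driven by the small groups (M=16, N=160).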
Cluster Configuration
Cluster shape affects performance, because CTAs in the same cluster can share A and B loads through TMA multicast:
(1,1): No clustering, good for small problems
(2,1) or (1,2): Light clustering, a balanced approach
(2,2): Larger clusters, best suited to large tiles
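Not every cluster shape is legal. The check below is a hypothetical sketch modeled on constraints commonly enforced by Blackwell GEMM examples: power-of-two extents, a total cluster size cap (max_cluster_size=16 is an assumed default), and an even M extent when 2-CTA instructions pair CTAs along M. Verify the exact rules against the example's own validation code.

```python
def is_valid_cluster_shape(cluster_shape_mn, use_2cta_instrs, max_cluster_size=16):
    """Hypothetical validity check for a (cluster_m, cluster_n) shape."""
    cm, cn = cluster_shape_mn

    def is_pow2(x):
        return x > 0 and (x & (x - 1)) == 0

    if not (is_pow2(cm) and is_pow2(cn)):
        return False
    if cm * cn > max_cluster_size:
        return False
    # Assumption: a 2-CTA MMA pairs CTAs along M, so cluster M must be even
    if use_2cta_instrs and cm % 2 != 0:
        return False
    return True

for shape, two_cta in [((1, 1), False), ((1, 1), True), ((2, 2), True), ((3, 1), False)]:
    print(shape, two_cta, is_valid_cluster_shape(shape, two_cta))
```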
Memory Alignment
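The contiguous dimension of every tensor must be 16-byte aligned; check this before launching:
# Check alignment: contiguous_dim is K for a K-major A or B, N for an N-major C, and so on
assert contiguous_dim * dtype_size % 16 == 0, "Contiguous dimension must be 16-byte aligned"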
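Since all groups share the same majorness, one majorness setting can be applied to every group when checking alignment. The sketch below reports each tensor whose contiguous extent is not a multiple of 16 bytes. The function name, the DTYPE_BYTES table, and the defaults (K-major A and B, N-major C, i.e. row-major operands) are assumptions for illustration.

```python
DTYPE_BYTES = {"Float16": 2, "BFloat16": 2, "Float32": 4}

def misaligned_groups(problem_sizes, ab_dtype, c_dtype, a_major="k", b_major="k", c_major="n"):
    """Return (group, tensor, extent) for every contiguous extent that is not 16-byte aligned."""
    assert a_major in ("m", "k") and b_major in ("n", "k") and c_major in ("m", "n")
    bad = []
    for i, (m, n, k, _l) in enumerate(problem_sizes):
        dims = {"m": m, "n": n, "k": k}
        checks = [
            ("A", dims[a_major], DTYPE_BYTES[ab_dtype]),
            ("B", dims[b_major], DTYPE_BYTES[ab_dtype]),
            ("C", dims[c_major], DTYPE_BYTES[c_dtype]),
        ]
        for name, extent, nbytes in checks:
            if (extent * nbytes) % 16 != 0:
                bad.append((i, name, extent))
    return bad

problem_sizes = [(8192, 1280, 32, 1), (16, 384, 1536, 1), (640, 1280, 16, 1), (640, 160, 16, 1)]
print(misaligned_groups(problem_sizes, "Float16", "Float16"))     # []
print(misaligned_groups([(100, 64, 36, 1)], "Float16", "Float16"))  # [(0, 'A', 36), (0, 'B', 36)]
```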
Complete Working Example
Find the full implementation:
examples/python/CuTeDSL/blackwell/grouped_gemm.py
The example includes:
Full kernel implementation with warp specialization
Tensormap management for variable problem sizes
Reference implementation for correctness checking
Performance benchmarking utilities
Legacy Grouped GEMM
For pre-SM90 architectures, you can use the legacy (now deprecated) high-level CUTLASS Python interface, which takes lists of per-group operands:
import numpy as np
import cutlass
# Define problem sizes: each group has its own (M, N, K)
problem_sizes = [(128, 512, 256), (64, 256, 128), (256, 128, 64)]
As, Bs, Cs, Ds = [], [], [], []
for m, n, k in problem_sizes:
    As.append(np.random.randn(m, k).astype(np.float16))
    Bs.append(np.random.randn(k, n).astype(np.float16))
    Cs.append(np.zeros((m, n), dtype=np.float16))
    Ds.append(np.zeros((m, n), dtype=np.float16))
# Create and run grouped GEMM (row-major operands)
plan = cutlass.op.GroupedGemm(element=np.float16, layout=cutlass.LayoutType.RowMajor)
plan.run(As, Bs, Cs, Ds)
See examples/python/deprecated/02_pytorch_extension_grouped_gemm.ipynb for details.
Next Steps
Basic GEMM: Start with single GEMM operations
Custom Epilogue: Add custom operations to grouped GEMM