Use this file to discover all available pages before exploring further.
Custom epilogues allow you to fuse element-wise operations directly into the GEMM kernel, eliminating separate kernel launches and improving memory bandwidth efficiency.
This example demonstrates a GEMM with a custom epilogue that computes:
$$Y = A \times B, \qquad D = (A \times B) \cdot \alpha + C \cdot \beta + X \cdot x_{\text{factor}}$$
1
Define the kernel class
Start by subclassing the base `DenseGemmEFC` kernel class:
import cutlass
import cutlass.cute as cute
import cutlass.torch as cutlass_torch  # needed by create_arguments (cutlass_torch.matrix)
from common_dense_gemm_efc import DenseGemmEFC


class DenseGemmAlphaBeta(DenseGemmEFC):
    """Implements batched GEMM with custom epilogue fusion.

    Computes:
    - Y = A * B (accumulator stored to Y)
    - D = (A * B) * alpha + C * beta + X * x_factor
    """

    def __init__(
        self,
        acc_dtype,
        use_2cta_instrs,
        mma_tiler_mn,
        cluster_shape_mn,
    ):
        """Create the kernel; all configuration is forwarded to ``DenseGemmEFC``.

        No additional state is set up here. Each argument is passed through
        unchanged, in the same order, to ``DenseGemmEFC.__init__``. Their
        meaning is defined by the parent class (presumably: accumulator dtype,
        whether to use 2-CTA MMA instructions, the MMA tile shape in (M, N),
        and the cluster shape in (M, N); confirm against DenseGemmEFC).
        """
        super().__init__(
            acc_dtype,
            use_2cta_instrs,
            mma_tiler_mn,
            cluster_shape_mn,
        )
2
Define tensor arguments
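Specify the input and output tensors: extend the parent's standard A and B arguments with the auxiliary inputs C and X and the outputs D and Y: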
def create_arguments(
    self,
    l,
    m,
    n,
    k,
    a_major,
    b_major,
    cd_major,
    ab_dtype,
    c_dtype,
    d_dtype,
    x_dtype,
    y_dtype,
):
    """Build the tensor arguments for the fused GEMM.

    The parent class supplies the standard A/B arguments. This method appends
    four auxiliary matrices, all created with extents ``(l, m, n)`` and the
    ``cd_major`` layout, in the order C, X, D, Y.

    Returns:
        A tuple ``(*parent_args, c, x, d, y)``.
    """
    # Standard A and B tensors are produced by DenseGemmEFC.
    base_args = super().create_arguments(
        l, m, n, k, a_major, b_major, cd_major, ab_dtype
    )

    # C, X, D and Y differ only in dtype; the layout flag is shared.
    m_is_major = cd_major == "m"
    aux_tensors = tuple(
        cutlass_torch.matrix(l, m, n, m_is_major, dtype)
        for dtype in (c_dtype, x_dtype, d_dtype, y_dtype)
    )
    return (*base_args, *aux_tensors)
3
Define the epilogue operation
Implement the custom fusion logic:
@cute.jit
def epilogue_operation(
    self,
    accumulator: cute.Tensor,
    c: cute.Tensor,
    x: cute.Tensor,
    alpha: float,
    beta: float,
    x_factor: float,
) -> tuple:
    """Fused epilogue that produces the (D, Y) output pair.

    - Y is the raw GEMM accumulator (A * B).
    - D = accumulator * alpha + C * beta + X * x_factor.

    Returns:
        ``(d, y)``, in that order.
    """
    # Y is bound to the accumulator object itself (no copy is taken). The
    # arithmetic below uses binary operators, which produce new values.
    y_out = accumulator
    partial = accumulator * alpha + c * beta
    d_out = partial + x * x_factor
    return (d_out, y_out)
@cute.jit
def epilogue_multi_output(self, accumulator, c):
    """Epilogue that returns three outputs derived from ``accumulator + c``.

    Returns:
        ``(d, d_squared, d_norm)`` where ``d = accumulator + c``,
        ``d_squared = d * d`` and ``d_norm = d / cute.norm(d)``.
    """
    base = accumulator + c
    # NOTE(review): cute.norm only sees the values passed to this epilogue.
    # If the epilogue runs on a per-thread/per-tile fragment, this normalizes
    # per fragment rather than over the whole output (confirm). There is also
    # no guard against a zero norm.
    return (base, base * base, base / cute.norm(base))
@cute.jit
def epilogue_complex(self, accumulator, c, x, bias, scale):
    """Epilogue computing ``gelu((A * B + bias) * scale + C + tanh(X))``.

    The GELU activation is applied last, to the full pre-activation sum.
    """
    biased = accumulator + bias
    scaled = biased * scale
    with_c = scaled + c
    pre_activation = with_c + cute.tanh(x)
    return cute.gelu(pre_activation)
@cute.jit
def epilogue_debug(self, accumulator, c):
    """Debug epilogue: print the first accumulator element, then return ``accumulator + c``.

    Printing is gated on the thread's x index being 0, so this prints once per
    CTA for a 1-D thread block. Use only for small problems (only for small
    tensors!), because device-side printf is slow and its output is verbose.
    """
    # The CuTe DSL exposes the thread index as cute.arch.thread_idx(), which
    # returns an (x, y, z) tuple; there is no cute.thread_idx(). Unpack it and
    # test the x component.
    tidx, _, _ = cute.arch.thread_idx()
    if tidx == 0:
        cute.printf("Accumulator[0] = %f\n", accumulator[0])
    return accumulator + c