Mojo function

blockwise_scaled_fp8_with_epilogue

def blockwise_scaled_fp8_with_epilogue[c_type: DType, a_type: DType, b_type: DType, a_scales_type: DType, b_scales_type: DType, //, *, scales_granularity_mnk: IndexList[Int(3)], transpose_b: Bool = False, elementwise_lambda_fn: Optional[def[dtype: DType, width: SIMDSize, *, alignment: Int = Int(1)](IndexList[Int(2)], SIMD[dtype, width]) capturing -> None] = None](c: TileTensor[c_type, Storage=c.Storage, linear_idx_type=c.linear_idx_type, element_size=c.element_size], a: TileTensor[a_type, Storage=a.Storage, linear_idx_type=a.linear_idx_type, element_size=a.element_size], b: TileTensor[b_type, Storage=b.Storage, linear_idx_type=b.linear_idx_type, element_size=b.element_size], a_scales: TileTensor[a_scales_type, Storage=a_scales.Storage, linear_idx_type=a_scales.linear_idx_type, element_size=a_scales.element_size], b_scales: TileTensor[b_scales_type, Storage=b_scales.Storage, linear_idx_type=b_scales.linear_idx_type, element_size=b_scales.element_size], ctx: DeviceContext)

Our sm100 blockwise scaled fp8 matmul kernel does not yet support fusing elementwise operations. This temporary implementation runs the sm100 blockwise scaled fp8 matmul kernel and then dispatches a separate epilogue kernel to apply the elementwise operations. On non-B200 GPUs, it uses the naive blockwise scaled fp8 matmul, which supports the epilogue natively. Callers must allocate c; when an elementwise_lambda_fn is supplied, the matmul result is written into c and then read back by the lambda.
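The two-pass behavior described above can be illustrated with a small CPU reference model. This is a hedged sketch in plain Python, not the Mojo API. The helper names (`blockwise_scaled_matmul`, `matmul_then_epilogue`, `add_bias`) are hypothetical, and so are the scale layouts: `a_scales` indexed as `[m // gm][k // gk]` and `b_scales` as `[k // gk][n // gn]`. The real kernel's scale layout, fp8 quantization, and `transpose_b` handling are not modeled.

```python
def blockwise_scaled_matmul(a, b, a_scales, b_scales, granularity_mnk):
    """Reference matmul where each product is multiplied by per-block scales.

    ASSUMPTION (illustration only): a_scales is indexed [m // gm][k // gk]
    and b_scales is indexed [k // gk][n // gn].
    """
    gm, gn, gk = granularity_mnk
    rows, inner, cols = len(a), len(b), len(b[0])
    result = [[0.0] * cols for _ in range(rows)]
    for m in range(rows):
        for n in range(cols):
            acc = 0.0
            for k in range(inner):
                scale_a = a_scales[m // gm][k // gk]
                scale_b = b_scales[k // gk][n // gn]
                acc += a[m][k] * scale_a * b[k][n] * scale_b
            result[m][n] = acc
    return result


def matmul_then_epilogue(c, a, b, a_scales, b_scales, granularity_mnk,
                         epilogue=None):
    """Two-pass model: write the matmul into c, then run the epilogue over c."""
    # Pass 1: the matmul result lands in the caller-allocated buffer c.
    result = blockwise_scaled_matmul(a, b, a_scales, b_scales, granularity_mnk)
    for m, row in enumerate(result):
        c[m][:] = row
    # Pass 2: a separate sweep reads c back and hands each element, with its
    # (row, col) coordinates, to the epilogue. The epilogue returns nothing
    # and writes wherever its captured state tells it to.
    if epilogue is not None:
        for m in range(len(c)):
            for n in range(len(c[m])):
                epilogue((m, n), c[m][n])


# Hypothetical inputs: 2x2 operands, one scale block per row of a (gm=1),
# and a single scale block covering all of b (gn=2, gk=2).
a = [[1.0, 2.0], [3.0, 4.0]]
b = [[1.0, 0.0], [0.0, 1.0]]
a_scales = [[0.5], [2.0]]
b_scales = [[2.0]]

c = [[0.0, 0.0], [0.0, 0.0]]    # caller-allocated matmul output
out = [[0.0, 0.0], [0.0, 0.0]]  # buffer captured by the epilogue
bias = [10.0, 20.0]


def add_bias(coords, value):
    """Example epilogue: add a per-column bias and store into `out`."""
    m, n = coords
    out[m][n] = value + bias[n]


matmul_then_epilogue(c, a, b, a_scales, b_scales, (1, 2, 2), epilogue=add_bias)
print(c)    # [[1.0, 2.0], [12.0, 16.0]]
print(out)  # [[11.0, 22.0], [22.0, 36.0]]
```

In this model c still holds the raw scaled product after the call, which is why the caller has to allocate it even when only the epilogue's output is of interest. A fused epilogue would receive each value straight from the accumulator instead of rereading c.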