Mojo function
gemm_mma_cpasync
gemm_mma_cpasync[pdl_level: PDLLevel = PDLLevel(), tile_k: Int = 128, elementwise_lambda_fn: Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None](c: TileTensor[address_space=c.address_space, linear_idx_type=c.linear_idx_type, element_size=c.element_size], act: TileTensor[address_space=act.address_space, linear_idx_type=act.linear_idx_type, element_size=act.element_size], weight: TileTensor[address_space=weight.address_space, linear_idx_type=weight.linear_idx_type, element_size=weight.element_size], gemm_m: Int, gemm_k: Int, gemm_n: Int, batch_size: Int, ctx: DeviceContext)
Launch the batched GEMM tensor-core kernel.
C[gemm_m, gemm_n] = act[gemm_m, gemm_k] x weight[gemm_n, gemm_k]^T.
Args:
- c (TileTensor[address_space=c.address_space, linear_idx_type=c.linear_idx_type, element_size=c.element_size]): Output, shape (gemm_m, gemm_n) or (batch, gemm_m, gemm_n).
- act (TileTensor[address_space=act.address_space, linear_idx_type=act.linear_idx_type, element_size=act.element_size]): Activation, shape (gemm_m, gemm_k) or (batch, gemm_m, gemm_k).
- weight (TileTensor[address_space=weight.address_space, linear_idx_type=weight.linear_idx_type, element_size=weight.element_size]): Weight, shape (gemm_n, gemm_k) or (batch, gemm_n, gemm_k).
- gemm_m (Int): Activation rows (output rows).
- gemm_k (Int): Reduction dimension.
- gemm_n (Int): Weight rows (output columns).
- batch_size (Int): Batch size; ignored for 2D inputs (treated as 1).
- ctx (DeviceContext): GPU device context.
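The function is parameterized over tile_k (the K-dimension tile width, default 128), an optional PDL level, and an optional elementwise epilogue lambda applied to the output. A minimal call-site sketch follows; it assumes c, act, and weight are TileTensor buffers with the shapes documented above, already resident on the device associated with ctx (the surrounding setup is illustrative, not from the source):

```mojo
# Hedged sketch: `c`, `act`, `weight`, and `ctx` are assumed to exist with
# the shapes and types documented in the Args list above.
gemm_mma_cpasync[tile_k=128](
    c,        # output, (gemm_m, gemm_n)
    act,      # activations, (gemm_m, gemm_k)
    weight,   # weights, (gemm_n, gemm_k); consumed transposed
    gemm_m,
    gemm_k,
    gemm_n,
    1,        # batch_size; treated as 1 for 2D inputs
    ctx,      # DeviceContext for the target GPU
)
```

Leaving elementwise_lambda_fn as None (the default) skips the epilogue; supplying a capturing lambda lets an elementwise operation (for example, a bias add) be fused into the write of each output element.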