
Mojo function

warp_specialize_gemm_with_multicasting

```mojo
warp_specialize_gemm_with_multicasting[
    c_type: DType,
    a_type: DType,
    b_type: DType, //,
    *,
    transpose_b: Bool,
    config: MatmulConfig[a_type, b_type, c_type, transpose_b],
    grid_shape: OptionalReg[IndexList[2]] = None,
    use_tma_store: Bool = False,
    elementwise_lambda_fn: Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None,
    elementwise_compute_lambda_fn: Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> SIMD[dtype, width]] = None,
    schedule: MatmulSchedule = MatmulSchedule.NONE,
    hilbert_swizzle: Bool = False,
    splits: Int = 0,
    raster_order: RasterOrder = RasterOrder.AlongM,
    swapAB: Bool = False,
](
    c_device: TileTensor[c_type, c_device.LayoutType, c_device.origin, address_space=c_device.address_space, linear_idx_type=c_device.linear_idx_type, element_size=c_device.element_size],
    a_device: TileTensor[a_type, a_device.LayoutType, a_device.origin, address_space=a_device.address_space, linear_idx_type=a_device.linear_idx_type, element_size=a_device.element_size],
    b_device: TileTensor[b_type, b_device.LayoutType, b_device.origin, address_space=b_device.address_space, linear_idx_type=b_device.linear_idx_type, element_size=b_device.element_size],
    ctx: DeviceContext,
)
```

Unified dispatcher for all matmul kernel variants. Based on the compile-time parameters (the `config`, `schedule`, rasterization and swizzling options, and optional elementwise epilogue lambdas), it selects and launches the appropriate warp-specialized GEMM kernel for the given device tensors on `ctx`.
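The following is a minimal call sketch, not a verified example: the import path, the `MatmulConfig` construction, and the tensor setup are assumptions made for illustration. Only the parameter names and their ordering come from the signature above.

```mojo
# Sketch only. The import below and the MatmulConfig construction are
# assumptions; check the defining module in your MAX/kernels checkout.
from gpu.host import DeviceContext

fn example(ctx: DeviceContext) raises:
    # Assume a_device, b_device, c_device are already-allocated,
    # device-resident TileTensors for a bf16 GEMM with B transposed
    # (creation omitted for brevity).
    alias config = MatmulConfig[
        DType.bfloat16, DType.bfloat16, DType.bfloat16, True
    ]()  # hypothetical default construction

    # Dispatch to a warp-specialized GEMM variant; unspecified keyword
    # parameters (grid_shape, schedule, splits, ...) keep their defaults.
    warp_specialize_gemm_with_multicasting[
        transpose_b=True,
        config=config,
        use_tma_store=True,  # optionally stage the output through TMA
    ](c_device, a_device, b_device, ctx)
```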
