Mojo module
mla_decode_utils
comptime values
QOTMATile
comptime QOTMATile[dtype: DType, BM: Int, BK: Int, swizzle_mode: TensorMapSwizzle] = TMATensorTile[dtype, 2, IndexList(VariadicList(BM, BK), Tuple()), _default_desc_shape[2, dtype, IndexList(VariadicList(BM, BK), Tuple()), swizzle_mode]()]
Parameters
- dtype (DType):
- BM (Int):
- BK (Int):
- swizzle_mode (TensorMapSwizzle):
ScalesTMATile
comptime ScalesTMATile[BN: Int] = TMATensorTile[DType.float32, 2, IndexList(VariadicList(1, BN), Tuple())]
Parameters
- BN (Int):
Structs
- DecodeCConsumer:
- DecodeCProducer:
- DecodeKVConsumer:
- DecodeKVProducer:
- DecodeOConsumer:
- DecodeOProducer:
- DecodeOutConsumer:
- DecodeOutProducer:
- DecodePConsumer:
- DecodePConsumerN:
- DecodePProducer:
- DecodePProducerN:
- DecodeSConsumer:
- DecodeSConsumerN:
- DecodeSM100MiscMBars:
- DecodeSM100PVSS:
- DecodeSM100PVSS_FP8:
- DecodeSM100QKTSS:
- DecodeSM100QKTSS_Content_FP8:
- DecodeSM100QKTSS_FP8:
- DecodeSM100QKTSS_Rope_BF16:
- DecodeSProducer:
- DecodeSProducerN:
- KVCvt2MmaConsumer:
- KVCvt2MmaProducer:
- KVLoad2CvtConsumer:
- KVLoad2CvtProducer:
- KVPipelineGeneric: KVPipeline has `num_kv_stages * num_qk_stages` stages. `num_kv_stages` refers to how many `K` and `V` tiles we pipeline for performing the `S = Q@K'` and `O += P@V` MMAs. Each of these MMAs is broken up into `num_qk_stages` pipelined MMAs. We set `step=False` for all but the last MMA that completes the operation. An alternative implementation would separate the two, potentially allowing more overall stages at the cost of slightly more bookkeeping.
- MLA_Decode_Pack:
- MLA_SM100_Decode_Common:
- MLA_SM100_Decode_Config:
- OffsetPosition:
- OutPipeline: OutPipeline has `num_out_stages` stages. `num_out_stages` refers to how many output stages we pipeline for performing the output store.
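The stage bookkeeping described for `KVPipelineGeneric` can be sketched as follows. This is a hypothetical Python illustration of the indexing only, not the Mojo implementation; the stage counts are example values:

```python
# Hypothetical sketch of KVPipelineGeneric's stage bookkeeping: the
# pipeline has num_kv_stages * num_qk_stages stages, and only the
# final qk sub-MMA of each K/V tile sets step=True to complete the op.
num_kv_stages = 2   # example: K/V tiles in flight
num_qk_stages = 4   # example: sub-MMAs per S = Q@K' (or O += P@V) MMA

schedule = []
for kv in range(num_kv_stages):
    for qk in range(num_qk_stages):
        stage = kv * num_qk_stages + qk
        # step=False for all but the last sub-MMA of the tile
        step = qk == num_qk_stages - 1
        schedule.append((stage, step))

total_stages = num_kv_stages * num_qk_stages  # 8 stages here
```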
Functions
- build_mma_ss_ws:
- bulk_mma_ws:
- clamped_index_coordinate:
- cvt_fp8x8_from_2xu32_to_bf16x8_packed_u32x4:
- e8m0_to_bf16_broadcast: Convert an e8m0 scale byte to a bf16 value broadcast into both halves of a uint32.
- hmul2_bf16x8_by_scalar: Multiply 8 packed bf16 values (in 4 uint32 registers) by a bf16x2 scalar broadcast.
- ld_shared_v4_u32:
- num_matrix_view_rows_decode: TileTensor overload of `num_matrix_view_rows_decode`.
- st_shared_v4_b32_at_bf16_elem_off:
- st_shared_v4_b32_at_fp8_elem_off:
- tma_tile_qo:
- tma_tile_scales: Create a TMA descriptor for per-token float32 scales.
- write_bf16x2_row_to_smem_chunked: Chunked write with optional scaling; reduces register pressure.
- write_fp8_row_to_smem_chunked: Write float32 data to SMEM as FP8 with swizzle for SWIZZLE_64B.
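The bit manipulation behind `e8m0_to_bf16_broadcast` can be sketched in Python. This assumes e8m0 is the exponent-only scale format (8 exponent bits, bias 127, value `2**(e - 127)`); since bf16 shares that exponent bias, the conversion reduces to placing the byte in the bf16 exponent field. The function below is an illustrative stand-in, not the module's Mojo code:

```python
def e8m0_to_bf16_broadcast_sketch(e8m0: int) -> int:
    # Assumption: e8m0 encodes 2**(e8m0 - 127) (exponent-only, bias 127).
    # bf16 uses the same bias, so the bf16 encoding of that power of two
    # is just the byte shifted into the exponent field, bits [14:7]
    # (sign = 0, mantissa = 0).
    bf16_bits = (e8m0 & 0xFF) << 7
    # Broadcast the 16-bit pattern into both halves of a uint32.
    return (bf16_bits << 16) | bf16_bits
```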