Mojo function
load_lds_fragment
load_lds_fragment[smem_layout: TensorLayout, reg_layout: TensorLayout, //, MMA_K: Int, swizzle: Optional[Swizzle] = Optional()](smem_tile: TileTensor[smem_tile.dtype, smem_layout, smem_tile.origin, address_space=AddressSpace.SHARED], reg_tile: TileTensor[smem_tile.dtype, reg_layout, reg_tile.origin, address_space=AddressSpace.LOCAL])
Load MMA fragments from SMEM to registers using the hardware access pattern.
Dimensions are derived from the tile layouts:
- num_mmas = reg rows, MMA_M = smem rows / num_mmas
- lds_frag_width = MMA_M * MMA_K / WARP_SIZE
- lds_row_stride: MMA_K (BF16 dense), smem stride (FP8 or strided)
- num_iterations = reg flat elements / lds_frag_width
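The arithmetic above can be checked on the host before writing a kernel. The following is a minimal plain-Python sketch, not part of the Mojo API: the helper `derive_fragment_dims`, the `FragmentDims` record, the `dense` flag, and the example values (including a warp size of 64) are all hypothetical and chosen only for illustration.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FragmentDims:
    """Hypothetical record of the dimensions listed above."""
    num_mmas: int
    mma_m: int
    lds_frag_width: int
    lds_row_stride: int
    num_iterations: int


def derive_fragment_dims(smem_rows, smem_row_stride, reg_rows, reg_cols,
                         mma_k, warp_size, dense):
    """Apply the documented formulas to plain integers (illustrative only)."""
    # num_mmas = reg rows
    num_mmas = reg_rows
    if smem_rows % num_mmas != 0:
        raise ValueError("smem rows must be a multiple of num_mmas")
    # MMA_M = smem rows / num_mmas
    mma_m = smem_rows // num_mmas

    if (mma_m * mma_k) % warp_size != 0:
        raise ValueError("MMA_M * MMA_K must be a multiple of the warp size")
    # lds_frag_width = MMA_M * MMA_K / WARP_SIZE
    lds_frag_width = mma_m * mma_k // warp_size

    # lds_row_stride: MMA_K for the dense case, otherwise the SMEM row stride.
    # The `dense` flag is an assumption standing in for "BF16 dense".
    lds_row_stride = mma_k if dense else smem_row_stride

    # num_iterations = reg flat elements / lds_frag_width
    reg_flat_elements = reg_rows * reg_cols
    if reg_flat_elements % lds_frag_width != 0:
        raise ValueError("register tile size must be a multiple of lds_frag_width")
    num_iterations = reg_flat_elements // lds_frag_width

    return FragmentDims(num_mmas, mma_m, lds_frag_width,
                        lds_row_stride, num_iterations)


# Assumed example: a [64, 16] SMEM tile, a [2, 16] register tile,
# MMA_K = 16, and a warp size of 64.
dims = derive_fragment_dims(smem_rows=64, smem_row_stride=64,
                            reg_rows=2, reg_cols=16,
                            mma_k=16, warp_size=64, dense=True)
print(dims)
```

With these assumed inputs, each of the 2 MMAs covers 32 SMEM rows, each lane holds 8 elements per load, and the register tile is filled in 4 iterations. Passing `dense=False` selects the SMEM row stride instead of MMA_K.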
Parameters:
- smem_layout (TensorLayout): Inferred layout of the SMEM source tile.
- reg_layout (TensorLayout): Inferred layout of the register destination tile.
- MMA_K (Int): MMA K dimension (hardware instruction width).
- swizzle (Optional): Optional element-space swizzle.
Args:
- smem_tile (TileTensor): Source [num_mmas * MMA_M, K] in SHARED.
- reg_tile (TileTensor): Destination [num_mmas, K_frags * frag_width] in LOCAL.