Mojo function

load_lds_fragment

load_lds_fragment[dtype: DType, smem_layout: Layout, smem_element_layout: Layout, frag_layout: Layout, frag_element_layout: Layout, //, mma_access_layout: Layout, swizzle: OptionalReg[Swizzle] = OptionalReg[Swizzle]()](smem_tile: LayoutTensor[dtype, smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, element_layout=smem_element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], reg_frag: LayoutTensor[dtype, frag_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL, element_layout=frag_element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])

Loads a fragment from LDS (shared memory) into registers using the MMA access pattern.

Why mma_access_layout differs from the global→LDS thread layout:

┌─────────────┬──────────────────────┬───────────────────────────┐
│ Layout      │ Purpose              │ Constraint                │
├─────────────┼──────────────────────┼───────────────────────────┤
│ load_thread │ Global → LDS write   │ Coalesced global reads    │
│ mma_access  │ LDS → Registers read │ AMD WMMA hardware pattern │
└─────────────┴──────────────────────┴───────────────────────────┘

mma_access_layout encodes how AMD's WMMA instruction expects data:

  • Lane decomposition: (lane % 16, lane // 16) = (col_group, row_group)
  • Offset computation: col_group * 32 + row_group * 8
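The lane decomposition and offset formula above can be modelled on the host to see which LDS offsets a wavefront touches. The following is a minimal Python sketch of that arithmetic only, not the Mojo implementation; the function name `lane_to_offset` is hypothetical, and the wavefront size of 64 is taken from the requirements listed on this page.

```python
WARP_SIZE = 64  # wavefront size, as stated in the layout requirements


def lane_to_offset(lane: int) -> int:
    """Model of the documented mapping from lane id to LDS element offset.

    Hypothetical helper for illustration; it mirrors the two formulas
    in the text and nothing else.
    """
    col_group = lane % 16   # position within a group of 16 lanes
    row_group = lane // 16  # which group of 16 lanes (0..3 for 64 lanes)
    return col_group * 32 + row_group * 8


# Offsets for every lane in one wavefront.
offsets = [lane_to_offset(lane) for lane in range(WARP_SIZE)]

# Print a few lanes to show how neighbouring lanes are spaced.
for lane in (0, 1, 15, 16, 17, 63):
    print(f"lane {lane:2d} -> offset {lane_to_offset(lane)}")
```

Under this model, the 64 lanes produce 64 distinct offsets that are exactly the multiples of 8 from 0 to 504. If each lane reads 8 contiguous elements from its offset (an assumption about the fragment width, not stated on this page), the wavefront covers 512 elements with no overlap.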

Using RuntimeLayout ensures compile-time evaluation (no GPU heap alloc).

Layout compatibility requirements:

  • mma_access_layout must map exactly WARP_SIZE (64) threads
  • smem must have enough elements for: num_iterations * WARP_SIZE * frag_width
  • frag must store: num_iterations * frag_width elements
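The two size requirements above reduce to simple products, sketched here in Python as a quick check. The helper name `required_sizes` and the example values (2 iterations, fragment width 8) are hypothetical and chosen only for illustration; the actual values depend on the layouts passed to `load_lds_fragment`.

```python
def required_sizes(num_iterations: int, frag_width: int, warp_size: int = 64):
    """Return (minimum smem elements, fragment elements) per the listed rules.

    Hypothetical helper: it restates the documented formulas, it is not
    part of the Mojo API.
    """
    smem_min = num_iterations * warp_size * frag_width  # shared-memory tile
    frag_size = num_iterations * frag_width             # per-thread registers
    return smem_min, frag_size


# Example with assumed values: 2 iterations, 8 elements per fragment load.
smem_min, frag_size = required_sizes(num_iterations=2, frag_width=8)
print(smem_min, frag_size)
```

With these assumed values, the shared-memory tile needs at least 1024 elements and each thread's register fragment holds 16.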