
Mojo struct

SubTileLoaderLDS_HK_st_8x32

struct SubTileLoaderLDS_HK_st_8x32[dtype: DType, BN: Int, depth: Int, BK: Int, num_threads: Int]

DRAM→LDS DMA loader for the HipKittens (HK) st_8x32_s shared-memory (SMEM) layout, used for the V operand.

Mirrors the group-level cooperative load() in ~/HipKittens/include/ops/group/memory/tile/global_to_shared.cuh:

  • Each thread (laneid 0..63 across all num_threads / 64 warps) writes bytes_per_thread = 16 bytes per iteration directly to LDS at the natural byte offset lane_byte_offset = thread_id * 16 + iter * num_threads * 16.
  • thread_id is warp_id * WARP_SIZE + lane_id, so the LDS bytes cover successive subtiles in HK's row-major-by-block-col ordering (subtile_id = subtile_row * subtiles_per_row + subtile_col, with subtile shape 8×BK BF16).
  • Each lane reads its source position in DRAM via the swizzle ↔ position bijection: subtile_lane_byte_offset → (row, col) within the 8×BK subtile, which then unpacks back into a (global_row, global_col) DRAM byte address. For HK st_8x32 BF16 the swizzle is the identity, so the global position is just the natural subtile-local position.
  • Writes go via rocdl.raw.ptr.buffer.load.lds with the same _alias_scope_attr that SubTileLoaderLDS uses. Consumer-side ds_read_tr* LDS reads tagged noalias_scopes=_alias_scope_attr can therefore skip s_waitcnt vmcnt(0) (the SIInsertWaitcnts vmcnt-relaxation handshake from LLVM PR #74537), provided the kernel maintains an explicit s_waitcnt vmcnt(0) + s_barrier fence at DMA/compute boundaries.
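The per-thread LDS offset formula in the first two bullets can be checked on the host. The following is a minimal Python model of that arithmetic only, not the Mojo kernel code. The helper names (`thread_id`, `lane_byte_offset`, `lds_offsets`) are hypothetical and chosen for illustration; the constants come from the description above.

```python
BYTES_PER_THREAD = 16  # bytes each thread writes per iteration (from the text)
WARP_SIZE = 64         # lanes per warp (laneid 0..63)


def thread_id(warp_id, lane_id):
    """Flat thread index across all warps in the cooperative load."""
    return warp_id * WARP_SIZE + lane_id


def lane_byte_offset(tid, it, num_threads):
    """LDS byte offset written by thread `tid` on iteration `it`."""
    return tid * BYTES_PER_THREAD + it * num_threads * BYTES_PER_THREAD


def lds_offsets(num_threads, num_iters):
    """All LDS byte offsets written, ordered by iteration, warp, lane."""
    num_warps = num_threads // WARP_SIZE
    return [
        lane_byte_offset(thread_id(w, l), it, num_threads)
        for it in range(num_iters)
        for w in range(num_warps)
        for l in range(WARP_SIZE)
    ]
```

With `num_threads = 512`, the offsets form a dense sequence of 16-byte chunks with no gaps or overlaps: each iteration covers 8192 contiguous bytes, and the next iteration starts where the previous one ended.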

The layout is hard-coded to HK's BF16 st_8x32_s:

  • subtile_rows = 8
  • subtile_cols = BK (32 for HKMHAExact)
  • No swizzle (st_8x32 BF16 returns offset unchanged in HK)
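Because the swizzle is the identity for this layout, the mapping from an LDS byte offset back to a tile position is plain index arithmetic. The sketch below is a host-side Python model, not the Mojo implementation. It assumes that elements inside each 8×BK subtile are stored row-major and that `subtiles_per_row = depth / BK`; the function name `lds_byte_to_global` is hypothetical. It returns element coordinates and does not model the DRAM row stride or the final byte address.

```python
ELEM_BYTES = 2    # BF16 element size
SUBTILE_ROWS = 8  # subtile_rows for st_8x32_s


def lds_byte_to_global(byte_off, depth, bk=32):
    """Map an LDS byte offset to (global_row, global_col) in elements.

    Assumes identity swizzle and row-major storage inside each subtile.
    """
    subtile_bytes = SUBTILE_ROWS * bk * ELEM_BYTES
    subtiles_per_row = depth // bk

    # Which subtile, and where inside it.
    subtile_id, local = divmod(byte_off, subtile_bytes)
    # subtile_id = subtile_row * subtiles_per_row + subtile_col
    subtile_row, subtile_col = divmod(subtile_id, subtiles_per_row)

    # Natural (unswizzled) position inside the 8 x bk subtile.
    row, col_bytes = divmod(local, bk * ELEM_BYTES)
    col = col_bytes // ELEM_BYTES

    return subtile_row * SUBTILE_ROWS + row, subtile_col * bk + col
```

For example, with `depth = 128` and `bk = 32`, byte offset 512 is the start of the second subtile and maps to `(0, 32)`, while byte offset 2048 is the start of the second subtile row and maps to `(8, 0)`.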

For the K operand (HK uses st_32x32_s with two-XOR swizzle), see SubTileLoaderLDS + swizzle/swizzle2 plumbing instead.

Parameters

  • dtype (DType): Element data type. Must be BF16: HK's st_8x32_s specialization assumes 2-byte elements, and FP32 would use a different shape.
  • BN (Int): KV block height in elements (= 64 for HKMHAExact).
  • depth (Int): V tile column span in elements (= D for the model's head_dim; 64, 128, or 256 for HKMHAExact).
  • BK (Int): Subtile column span in elements (= 32 for HK st_8x32_s).
  • num_threads (Int): Total threads in the cooperative load (= 8 warps × 64 lanes = 512 for HKMHAExact). Used to compute bytes_per_iter.
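As a rough guide to how these parameters interact, the Python sketch below computes `bytes_per_iter` and the number of load iterations needed to cover one V tile. It is an illustration, not the library's code: it assumes the tile occupies `BN * depth * 2` bytes, that each thread moves 16 bytes per iteration, and that the tile size divides evenly. The function name `load_iterations` is hypothetical.

```python
def load_iterations(bn, depth, num_threads, elem_bytes=2, bytes_per_thread=16):
    """Return (bytes_per_iter, num_iters) for one cooperative tile load.

    Assumes every thread moves `bytes_per_thread` bytes per iteration and
    the tile size is an exact multiple of the per-iteration footprint.
    """
    tile_bytes = bn * depth * elem_bytes
    bytes_per_iter = num_threads * bytes_per_thread
    if tile_bytes % bytes_per_iter != 0:
        raise ValueError("tile size is not a multiple of bytes_per_iter")
    return bytes_per_iter, tile_bytes // bytes_per_iter


# HKMHAExact-style values from the parameter list: BN = 64, 512 threads.
for depth in (64, 128, 256):
    print(depth, load_iterations(64, depth, 512))
```

Under these assumptions, `bytes_per_iter` is 8192, and a 64-row tile takes 1, 2, or 4 iterations for `depth` of 64, 128, or 256.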

Fields

  • bc (AMDBufferResource): The 128-bit buffer resource descriptor for DRAM access.

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TrivialRegisterPassable

Methods

__init__

__init__(gmem_tile: TileTensor[dtype, address_space=gmem_tile.address_space, linear_idx_type=gmem_tile.linear_idx_type, element_size=gmem_tile.element_size]) -> Self

Create a loader from a DRAM tile.

Args:

load

load(self, v_smem_slot: TileTensor[dtype, address_space=AddressSpace.SHARED], v_gmem_tile: TileTensor[dtype, address_space=v_gmem_tile.address_space, linear_idx_type=v_gmem_tile.linear_idx_type, element_size=v_gmem_tile.element_size], warp_id_uniform: Int, lane_id_local: Int)

Cooperatively DMA one V tile from DRAM into LDS.

Args: