Mojo struct
SubTileLoaderLDS_HK_st_8x32
struct SubTileLoaderLDS_HK_st_8x32[dtype: DType, BN: Int, depth: Int, BK: Int, num_threads: Int]
DRAM→LDS DMA for the HipKittens (HK) st_8x32_s SMEM layout (V operand).
Mirrors the group-level cooperative load() in
~/HipKittens/include/ops/group/memory/tile/global_to_shared.cuh:
- Each thread (laneid 0..63 across all `num_threads / 64` warps) writes `bytes_per_thread = 16` bytes per iteration directly to LDS at the natural byte offset `lane_byte_offset = thread_id * 16 + iter * num_threads * 16`. `thread_id` is `warp_id * WARP_SIZE + lane_id`, so the LDS bytes cover successive subtiles in HK's row-major-by-block-col ordering (`subtile_id = subtile_row * subtiles_per_row + subtile_col`, with subtile shape 8×BK BF16).
- Each lane reads its source position in DRAM via the swizzle ↔ position bijection: `subtile_lane_byte_offset` → (row, col) within the 8×BK subtile, which then unpacks back into a (global_row, global_col) DRAM byte address. For HK st_8x32 BF16 the swizzle is the identity, so the global position is just the natural subtile-local position.
- Writes go via `rocdl.raw.ptr.buffer.load.lds` with the same `_alias_scope_attr` SubTileLoaderLDS uses, so consumer-side `ds_read_tr*` LDS reads tagged `noalias_scopes=_alias_scope_attr` can skip `s_waitcnt vmcnt(0)` (LLVM PR #74537's `SIInsertWaitcnts` vmcnt-relaxation handshake), provided the kernel maintains an explicit `s_waitcnt vmcnt(0) + s_barrier` fence at DMA/compute boundaries.
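The offset arithmetic in the first two bullets can be modelled on the host. The following is a minimal Python sketch, not the loader's implementation: the function names are hypothetical, and `dram_byte_offset` assumes a row-major source tile whose row stride is supplied by the caller, which this page does not confirm. The swizzle is treated as the identity, as stated above for st_8x32 BF16.

```python
WARP_SIZE = 64          # lanes per wave, per the description above
BYTES_PER_THREAD = 16   # one 16-byte DMA write per thread per iteration
ELEM_BYTES = 2          # BF16
SUBTILE_ROWS = 8        # st_8x32 subtile height


def lds_byte_offset(warp_id, lane_id, it, num_threads):
    """Natural LDS byte offset written by one lane in iteration `it`."""
    thread_id = warp_id * WARP_SIZE + lane_id
    return thread_id * BYTES_PER_THREAD + it * num_threads * BYTES_PER_THREAD


def lds_offset_to_global(offset, depth, bk):
    """Map an LDS byte offset to (global_row, global_col) in elements.

    No XOR step is modelled because the swizzle is the identity here.
    """
    subtile_bytes = SUBTILE_ROWS * bk * ELEM_BYTES
    row_bytes = bk * ELEM_BYTES
    subtiles_per_row = depth // bk

    # Subtiles are laid out row-major by block column in LDS.
    subtile_id, within = divmod(offset, subtile_bytes)
    subtile_row, subtile_col = divmod(subtile_id, subtiles_per_row)
    row, col_bytes = divmod(within, row_bytes)

    global_row = subtile_row * SUBTILE_ROWS + row
    global_col = subtile_col * bk + col_bytes // ELEM_BYTES
    return global_row, global_col


def dram_byte_offset(global_row, global_col, row_stride_elems):
    """Assumption: row-major source tile with the given row stride in elements."""
    return (global_row * row_stride_elems + global_col) * ELEM_BYTES


if __name__ == "__main__":
    # Example values: 512 threads, BK = 32, depth = 128.
    off = lds_byte_offset(warp_id=0, lane_id=32, it=0, num_threads=512)
    print(off, lds_offset_to_global(off, depth=128, bk=32))
```

With these values, thread 32 starts the second subtile in the first subtile row, so its 16-byte write begins at column 32 of row 0. Iterating over all threads and iterations visits each 8-element chunk of the tile exactly once, which is what the bijection in the second bullet requires.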
The layout is hard-coded to HK's BF16 st_8x32_s:
- `subtile_rows = 8`
- `subtile_cols = BK` (32 for HKMHAExact)
- No swizzle (st_8x32 BF16 returns offset unchanged in HK)
For the K operand (HK uses st_32x32_s with two-XOR swizzle), see
SubTileLoaderLDS + swizzle/swizzle2 plumbing instead.
Parameters
- dtype (`DType`): Element data type (must be BF16: HK's `st_8x32_s` specialization assumes 2-byte elements; FP32 would use a different shape).
- BN (`Int`): KV block height in elements (= 64 for HKMHAExact).
- depth (`Int`): V tile column span in elements (= D for the model's head_dim; 64, 128, or 256 for HKMHAExact).
- BK (`Int`): Subtile column span in elements (= 32 for HK st_8x32_s).
- num_threads (`Int`): Total threads in the cooperative load (= 8 warps × 64 lanes = 512 for HKMHAExact). Used to compute `bytes_per_iter`.
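To see how these parameters interact, the sketch below estimates how many cooperative iterations one BN × depth tile would need. It assumes `bytes_per_iter = num_threads * 16`, which is inferred from the per-thread offset formula rather than stated outright, and `num_iters` is a hypothetical helper name. Whether the real loader requires the tile size to be an exact multiple of `bytes_per_iter` is also an assumption.

```python
BYTES_PER_THREAD = 16  # per-thread DMA width from the struct description
ELEM_BYTES = 2         # BF16


def bytes_per_iter(num_threads):
    """Bytes moved by all threads in one cooperative iteration (assumed formula)."""
    return num_threads * BYTES_PER_THREAD


def num_iters(bn, depth, num_threads):
    """Hypothetical helper: iterations needed to cover one BN x depth tile."""
    tile_bytes = bn * depth * ELEM_BYTES
    per_iter = bytes_per_iter(num_threads)
    if tile_bytes % per_iter != 0:
        raise ValueError("tile size is not a multiple of bytes_per_iter")
    return tile_bytes // per_iter


if __name__ == "__main__":
    # HKMHAExact values from the parameter list: BN = 64, 512 threads.
    for depth in (64, 128, 256):
        print(depth, num_iters(bn=64, depth=depth, num_threads=512))
```

Under these assumptions, 512 threads move 8192 bytes per iteration, so the three HKMHAExact depths would take 1, 2, and 4 iterations respectively.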
Fields
- bc (`AMDBufferResource`): The 128-bit buffer resource descriptor for DRAM access.
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDestructible,
Movable,
RegisterPassable,
TrivialRegisterPassable
Methods
__init__
__init__(gmem_tile: TileTensor[dtype, address_space=gmem_tile.address_space, linear_idx_type=gmem_tile.linear_idx_type, element_size=gmem_tile.element_size]) -> Self
Create a loader from a DRAM tile.
Args:
- gmem_tile (`TileTensor[dtype, address_space=gmem_tile.address_space, linear_idx_type=gmem_tile.linear_idx_type, element_size=gmem_tile.element_size]`): The full DRAM tile from KVCacheIterator (carries a Scalar valid_rows for clamping bounds).
load
load(self, v_smem_slot: TileTensor[dtype, address_space=AddressSpace.SHARED], v_gmem_tile: TileTensor[dtype, address_space=v_gmem_tile.address_space, linear_idx_type=v_gmem_tile.linear_idx_type, element_size=v_gmem_tile.element_size], warp_id_uniform: Int, lane_id_local: Int)
Cooperatively DMA one V tile from DRAM into LDS.
Args:
- v_smem_slot (`TileTensor[dtype, address_space=AddressSpace.SHARED]`): Destination V SMEM tile (must hold at least `BN * depth * size_of[dtype]` bytes; the loader uses `.ptr` as the LDS base).
- v_gmem_tile (`TileTensor[dtype, address_space=v_gmem_tile.address_space, linear_idx_type=v_gmem_tile.linear_idx_type, element_size=v_gmem_tile.element_size]`): TileTensor view of the BN × depth source tile in DRAM. Used to derive the DRAM-tile-to-bc-base byte offset (the `scalar_offset` of `buffer_load_lds`).
- warp_id_uniform (`Int`): Wave-uniform warp index (0..num_warps-1). Caller must pass an SGPR-class value (e.g., `readfirstlane(warp_id())`).
- lane_id_local (`Int`): Per-lane index (0..WARP_SIZE-1) from `lane_id()`.
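The argument constraints above can be checked with a small host-side model. This Python sketch uses hypothetical helper names and only covers the documented sizes and index ranges. It cannot express the requirement that `warp_id_uniform` is an SGPR-class value, which is a property of the compiled kernel rather than of the integer itself.

```python
WARP_SIZE = 64  # lanes per wave, as in the struct description


def min_smem_bytes(bn, depth, elem_bytes=2):
    """Smallest v_smem_slot size in bytes: BN * depth * size_of[dtype]."""
    return bn * depth * elem_bytes


def split_thread_id(thread_id):
    """Hypothetical helper: (warp_id_uniform, lane_id_local) for a flat thread id."""
    return divmod(thread_id, WARP_SIZE)


def check_ids(warp_id_uniform, lane_id_local, num_threads):
    """Range checks for the documented 0..num_warps-1 and 0..WARP_SIZE-1 bounds."""
    num_warps = num_threads // WARP_SIZE
    return 0 <= warp_id_uniform < num_warps and 0 <= lane_id_local < WARP_SIZE


if __name__ == "__main__":
    # BN = 64, depth = 128, BF16, 512 threads.
    print(min_smem_bytes(64, 128))
    print(split_thread_id(130), check_ids(2, 2, 512))
```

For BN = 64 and depth = 128 in BF16, the destination slot needs at least 16384 bytes of LDS.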