Skip to main content

Mojo struct

VBufferTransposeLoads

struct VBufferTransposeLoads[dtype: DType, kv_tile_layout: TensorLayout, //, tensor_core_mma: TiledTensorCore[tensor_core_mma.out_type, tensor_core_mma.in_type, tensor_core_mma.shape, tensor_core_mma.group_size, tensor_core_mma.transpose_b], BN: Int, BK: Int, depth: Int, num_threads: Int, num_stages: Int = 1]

Fields

  • load_tile (VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].LoadTile):
  • mma_tile (VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].MmaTile):
  • smem_ptr (UnsafePointer[Scalar[dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]):
  • gmem_tile (VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].GmemTileType):
  • reg_loader (VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].RegLoaderType):
  • tile_idx (Int):
  • current_stage (Int):

Implemented traits

AnyType, ImplicitlyDestructible, KVBuffer

comptime members

depth_tile_size

comptime depth_tile_size = min(depth, 128)

GmemTileType

comptime GmemTileType = TileTensor[dtype, kv_tile_layout, ImmutAnyOrigin]

load_width

comptime load_width = 4 if (depth == 64) else VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].simd_width

loads_per_thread_per_depth_tile

comptime loads_per_thread_per_depth_tile = ((VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].depth_tile_size * BK) // (VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].load_width * num_threads))

LoadTile

comptime LoadTile = TileTensor[dtype, Layout[*?, *?], MutExternalOrigin, address_space=AddressSpace.LOCAL]

MMA_K

comptime MMA_K = tensor_core_mma.shape[2]

MMA_M

comptime MMA_M = tensor_core_mma.shape[0]

mma_tile_layout

comptime mma_tile_layout = Layout.row_major((depth // VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].MMA_M), VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].simd_width)

MmaTile

comptime MmaTile = TileTensor[dtype, Layout[*?, *?], MutExternalOrigin, address_space=AddressSpace.LOCAL]

num_depth_tiles

comptime num_depth_tiles = (depth // VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].MMA_M)

num_k_tiles

comptime num_k_tiles = ceildiv(BK, (VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].MMA_K * tensor_core_mma.group_size))

num_repeats

comptime num_repeats = (BK // VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].simd_width)

RegLoaderType

comptime RegLoaderType = RegTileLoader[dtype, row_major[4, (VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].depth_tile_size // VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].load_width)](), warp_scope=True]

simd_width

comptime simd_width = simd_width_of[dtype]()

Methods

__init__

__init__(out self, gmem_tile: TileTensor[dtype, kv_tile_layout, ImmutAnyOrigin], shared_ptr: UnsafePointer[Scalar[dtype], MutAnyOrigin, address_space=AddressSpace.SHARED])

pad

static pad[dim: Int]() -> Int

Returns:

Int

load_from_dram

load_from_dram(mut self)

get_mma_tile

get_mma_tile(self) -> LayoutTensor[dtype, VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].mma_tile_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL]

Returns:

LayoutTensor

copy_to_shared

copy_to_shared[tile_id: Int = 0](self)

load_from_shared

load_from_shared[k_mma: Int](self)

Was this page helpful?