Mojo struct
QRegisterBuffer
struct QRegisterBuffer[dtype: DType, mma_shape: IndexList[3], k_group_size: Int, WM: Int, WN: Int, BN: Int, BK: Int, depth: Int, thread_layout: Layout]
Fields
- reg_tile (LayoutTensor[dtype, QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL]): No description provided. The type is the same as the RegisterTileType alias listed under Aliases.
Implemented traits
AnyType,
RegisterBuffer,
RegisterMMABuffer,
UnknownDestructibility
Aliases
__del__is_trivial
alias __del__is_trivial = True
mma_dtype
alias mma_dtype = dtype
MMA_K
alias MMA_K = mma_shape.__getitem__[3, DType.int64, Int](2)
MMA_M
alias MMA_M = mma_shape.__getitem__[3, DType.int64, Int](0)
mma_tile_layout
alias mma_tile_layout = LayoutTensor._compute_tile_layout[True, dtype, LayoutTensor._compute_tile_layout[True, dtype, QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, MutAnyOrigin, AddressSpace.LOCAL, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, AddressSpace.LOCAL), _get_index_type(QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, AddressSpace.LOCAL), False, align_of[dtype](), (QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout.shape[0].value() // QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].num_tiles), 0]()[0], MutAnyOrigin, AddressSpace.LOCAL, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(LayoutTensor._compute_tile_layout[True, dtype, QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, MutAnyOrigin, AddressSpace.LOCAL, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, AddressSpace.LOCAL), _get_index_type(QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, AddressSpace.LOCAL), False, align_of[dtype](), (QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout.shape[0].value() // QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].num_tiles), 0]()[0], AddressSpace.LOCAL), _get_index_type(LayoutTensor._compute_tile_layout[True, dtype, QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, MutAnyOrigin, AddressSpace.LOCAL, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, 
BK, depth, thread_layout].reg_tile_layout, AddressSpace.LOCAL), _get_index_type(QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, AddressSpace.LOCAL), False, align_of[dtype](), (QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout.shape[0].value() // QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].num_tiles), 0]()[0], AddressSpace.LOCAL), False, align_of[dtype](), (LayoutTensor._compute_tile_layout[True, dtype, QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, MutAnyOrigin, AddressSpace.LOCAL, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, AddressSpace.LOCAL), _get_index_type(QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, AddressSpace.LOCAL), False, align_of[dtype](), (QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout.shape[0].value() // QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].num_tiles), 0]()[0].shape[0].value() // QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].num_k_tiles), 0]()[0]
MMATileType
alias MMATileType = LayoutTensor[dtype, LayoutTensor._compute_tile_layout[True, dtype, LayoutTensor._compute_tile_layout[True, dtype, QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, MutAnyOrigin, AddressSpace.LOCAL, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, AddressSpace.LOCAL), _get_index_type(QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, AddressSpace.LOCAL), False, align_of[dtype](), (QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout.shape[0].value() // QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].num_tiles), 0]()[0], MutAnyOrigin, AddressSpace.LOCAL, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(LayoutTensor._compute_tile_layout[True, dtype, QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, MutAnyOrigin, AddressSpace.LOCAL, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, AddressSpace.LOCAL), _get_index_type(QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, AddressSpace.LOCAL), False, align_of[dtype](), (QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout.shape[0].value() // QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].num_tiles), 0]()[0], AddressSpace.LOCAL), _get_index_type(LayoutTensor._compute_tile_layout[True, dtype, QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, MutAnyOrigin, AddressSpace.LOCAL, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(QRegisterBuffer[dtype, mma_shape, 
k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, AddressSpace.LOCAL), _get_index_type(QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, AddressSpace.LOCAL), False, align_of[dtype](), (QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout.shape[0].value() // QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].num_tiles), 0]()[0], AddressSpace.LOCAL), False, align_of[dtype](), (LayoutTensor._compute_tile_layout[True, dtype, QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, MutAnyOrigin, AddressSpace.LOCAL, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, AddressSpace.LOCAL), _get_index_type(QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, AddressSpace.LOCAL), False, align_of[dtype](), (QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout.shape[0].value() // QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].num_tiles), 0]()[0].shape[0].value() // QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].num_k_tiles), 0]()[0], MutAnyOrigin, address_space=AddressSpace.LOCAL]
num_k_tiles
alias num_k_tiles = ceildiv(BK, (QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].MMA_K * k_group_size))
num_mmas
alias num_mmas = ceildiv(WM, QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].MMA_M)
num_tiles
alias num_tiles = (depth // BK)
reg_dtype
alias reg_dtype = dtype
reg_tile_layout
alias reg_tile_layout = Layout.row_major(((QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].num_mmas * QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].num_k_tiles) * QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].num_tiles), QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].simd_width)
RegisterTileType
alias RegisterTileType = LayoutTensor[dtype, QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL]
simd_width
alias simd_width = simd_width_of[dtype]()
TiledIteratorType
alias TiledIteratorType = LayoutTensorIter[dtype, LayoutTensor._compute_tile_layout[True, dtype, QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, MutAnyOrigin, AddressSpace.LOCAL, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, AddressSpace.LOCAL), _get_index_type(QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, AddressSpace.LOCAL), False, align_of[dtype](), (QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].num_mmas * QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].num_k_tiles), QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].simd_width]()[0], MutAnyOrigin, address_space=AddressSpace.LOCAL, axis=0, layout_int_type=_get_layout_type(QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, AddressSpace.LOCAL), linear_idx_type=_get_index_type(QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, AddressSpace.LOCAL), masked=_tile_is_masked[QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, (QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].num_mmas * QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].num_k_tiles), QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].simd_width]()]
Methods
__init__
__init__(out self, tensor: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])
get_dtype
get_iter
get_iter(self) -> LayoutTensorIter[dtype, LayoutTensor._compute_tile_layout[True, dtype, QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, MutAnyOrigin, AddressSpace.LOCAL, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, AddressSpace.LOCAL), _get_index_type(QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, AddressSpace.LOCAL), False, align_of[dtype](), (QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].num_mmas * QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].num_k_tiles), QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].simd_width]()[0], MutAnyOrigin, address_space=AddressSpace.LOCAL, axis=0, layout_int_type=_get_layout_type(QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, AddressSpace.LOCAL), linear_idx_type=_get_index_type(QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, AddressSpace.LOCAL), masked=_tile_is_masked[QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, (QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].num_mmas * QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].num_k_tiles), QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].simd_width]()]
Returns:
LayoutTensorIter
get_mma_tile
get_mma_tile[tile_idx: Int, k_idx: Int](self) -> LayoutTensor[dtype, LayoutTensor._compute_tile_layout[True, dtype, LayoutTensor._compute_tile_layout[True, dtype, QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, MutAnyOrigin, AddressSpace.LOCAL, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, AddressSpace.LOCAL), _get_index_type(QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, AddressSpace.LOCAL), False, align_of[dtype](), (QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout.shape[0].value() // QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].num_tiles), 0]()[0], MutAnyOrigin, AddressSpace.LOCAL, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(LayoutTensor._compute_tile_layout[True, dtype, QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, MutAnyOrigin, AddressSpace.LOCAL, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, AddressSpace.LOCAL), _get_index_type(QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, AddressSpace.LOCAL), False, align_of[dtype](), (QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout.shape[0].value() // QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].num_tiles), 0]()[0], AddressSpace.LOCAL), _get_index_type(LayoutTensor._compute_tile_layout[True, dtype, QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, MutAnyOrigin, AddressSpace.LOCAL, Layout(IntTuple(1), IntTuple(1)), 
_get_layout_type(QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, AddressSpace.LOCAL), _get_index_type(QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, AddressSpace.LOCAL), False, align_of[dtype](), (QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout.shape[0].value() // QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].num_tiles), 0]()[0], AddressSpace.LOCAL), False, align_of[dtype](), (LayoutTensor._compute_tile_layout[True, dtype, QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, MutAnyOrigin, AddressSpace.LOCAL, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, AddressSpace.LOCAL), _get_index_type(QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, AddressSpace.LOCAL), False, align_of[dtype](), (QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout.shape[0].value() // QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].num_tiles), 0]()[0].shape[0].value() // QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].num_k_tiles), 0]()[0], MutAnyOrigin, address_space=AddressSpace.LOCAL]
Returns:
LayoutTensor
get_reg_tile
get_reg_tile(self) -> LayoutTensor[dtype, QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL]
Returns:
LayoutTensor
zero
zero(self)
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!