Skip to main content

Mojo struct

QRegisterBuffer

struct QRegisterBuffer[dtype: DType, mma_shape: IndexList[3], k_group_size: Int, WM: Int, WN: Int, BN: Int, BK: Int, depth: Int, thread_layout: Layout]

Fields

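  • reg_tile (LayoutTensor[dtype, Layout.row_major(((ceildiv(WM, mma_shape.__getitem__[3, DType.int64, Int](0)) * ceildiv(BK, (mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size))) * (depth // BK)), simd_width_of[dtype]()), MutableAnyOrigin, address_space=AddressSpace(5)]): The tensor that backs this buffer (it is the struct's only field). Its type is the same as `RegisterTileType` and as the return type of `get_reg_tile`.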

Implemented traits

AnyType, RegisterBuffer, RegisterMMABuffer, UnknownDestructibility

Aliases

__del__is_trivial

alias __del__is_trivial = True

mma_dtype

alias mma_dtype = dtype

MMA_K

alias MMA_K = mma_shape.__getitem__[3, DType.int64, Int](2)

Equivalently, `mma_shape[2]`.

MMA_M

alias MMA_M = mma_shape.__getitem__[3, DType.int64, Int](0)

Equivalently, `mma_shape[0]`.

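mma_tile_layout

Layout of a single MMA tile, built in two tiling steps (the dimension-1 tile size is passed as 0 in both):
1. Tile `reg_tile_layout` along dimension 0 into `num_tiles` pieces.
2. Tile one of those pieces along dimension 0 into `num_k_tiles` pieces.

Because `reg_tile_layout` has `num_mmas * num_k_tiles * num_tiles` rows, the result has `num_mmas` rows. The full expansion follows.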

alias mma_tile_layout = LayoutTensor._compute_tile_layout[True, dtype, LayoutTensor._compute_tile_layout[True, dtype, Layout.row_major(((ceildiv(WM, mma_shape.__getitem__[3, DType.int64, Int](0)) * ceildiv(BK, (mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size))) * (depth // BK)), simd_width_of[dtype]()), MutableAnyOrigin, AddressSpace(5), Layout(IntTuple(1), IntTuple(1)), _get_layout_type(Layout.row_major(((ceildiv(WM, mma_shape.__getitem__[3, DType.int64, Int](0)) * ceildiv(BK, (mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size))) * (depth // BK)), simd_width_of[dtype]()), AddressSpace(5)), _get_index_type(Layout.row_major(((ceildiv(WM, mma_shape.__getitem__[3, DType.int64, Int](0)) * ceildiv(BK, (mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size))) * (depth // BK)), simd_width_of[dtype]()), AddressSpace(5)), False, align_of[dtype](), (Layout.row_major(((ceildiv(WM, mma_shape.__getitem__[3, DType.int64, Int](0)) * ceildiv(BK, (mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size))) * (depth // BK)), simd_width_of[dtype]()).shape[0].value() // (depth // BK)), 0]()[0], MutableAnyOrigin, AddressSpace(5), Layout(IntTuple(1), IntTuple(1)), _get_layout_type(LayoutTensor._compute_tile_layout[True, dtype, Layout.row_major(((ceildiv(WM, mma_shape.__getitem__[3, DType.int64, Int](0)) * ceildiv(BK, (mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size))) * (depth // BK)), simd_width_of[dtype]()), MutableAnyOrigin, AddressSpace(5), Layout(IntTuple(1), IntTuple(1)), _get_layout_type(Layout.row_major(((ceildiv(WM, mma_shape.__getitem__[3, DType.int64, Int](0)) * ceildiv(BK, (mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size))) * (depth // BK)), simd_width_of[dtype]()), AddressSpace(5)), _get_index_type(Layout.row_major(((ceildiv(WM, mma_shape.__getitem__[3, DType.int64, Int](0)) * ceildiv(BK, (mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size))) * (depth // BK)), simd_width_of[dtype]()), AddressSpace(5)), 
False, align_of[dtype](), (Layout.row_major(((ceildiv(WM, mma_shape.__getitem__[3, DType.int64, Int](0)) * ceildiv(BK, (mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size))) * (depth // BK)), simd_width_of[dtype]()).shape[0].value() // (depth // BK)), 0]()[0], AddressSpace(5)), _get_index_type(LayoutTensor._compute_tile_layout[True, dtype, Layout.row_major(((ceildiv(WM, mma_shape.__getitem__[3, DType.int64, Int](0)) * ceildiv(BK, (mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size))) * (depth // BK)), simd_width_of[dtype]()), MutableAnyOrigin, AddressSpace(5), Layout(IntTuple(1), IntTuple(1)), _get_layout_type(Layout.row_major(((ceildiv(WM, mma_shape.__getitem__[3, DType.int64, Int](0)) * ceildiv(BK, (mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size))) * (depth // BK)), simd_width_of[dtype]()), AddressSpace(5)), _get_index_type(Layout.row_major(((ceildiv(WM, mma_shape.__getitem__[3, DType.int64, Int](0)) * ceildiv(BK, (mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size))) * (depth // BK)), simd_width_of[dtype]()), AddressSpace(5)), False, align_of[dtype](), (Layout.row_major(((ceildiv(WM, mma_shape.__getitem__[3, DType.int64, Int](0)) * ceildiv(BK, (mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size))) * (depth // BK)), simd_width_of[dtype]()).shape[0].value() // (depth // BK)), 0]()[0], AddressSpace(5)), False, align_of[dtype](), (LayoutTensor._compute_tile_layout[True, dtype, Layout.row_major(((ceildiv(WM, mma_shape.__getitem__[3, DType.int64, Int](0)) * ceildiv(BK, (mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size))) * (depth // BK)), simd_width_of[dtype]()), MutableAnyOrigin, AddressSpace(5), Layout(IntTuple(1), IntTuple(1)), _get_layout_type(Layout.row_major(((ceildiv(WM, mma_shape.__getitem__[3, DType.int64, Int](0)) * ceildiv(BK, (mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size))) * (depth // BK)), simd_width_of[dtype]()), AddressSpace(5)), 
_get_index_type(Layout.row_major(((ceildiv(WM, mma_shape.__getitem__[3, DType.int64, Int](0)) * ceildiv(BK, (mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size))) * (depth // BK)), simd_width_of[dtype]()), AddressSpace(5)), False, align_of[dtype](), (Layout.row_major(((ceildiv(WM, mma_shape.__getitem__[3, DType.int64, Int](0)) * ceildiv(BK, (mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size))) * (depth // BK)), simd_width_of[dtype]()).shape[0].value() // (depth // BK)), 0]()[0].shape[0].value() // ceildiv(BK, (mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size))), 0]()[0]

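MMATileType

Equivalent to `LayoutTensor[dtype, mma_tile_layout, MutableAnyOrigin, address_space=AddressSpace(5)]`. The full expansion follows.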

alias MMATileType = LayoutTensor[dtype, LayoutTensor._compute_tile_layout[True, dtype, LayoutTensor._compute_tile_layout[True, dtype, Layout.row_major(((ceildiv(WM, mma_shape.__getitem__[3, DType.int64, Int](0)) * ceildiv(BK, (mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size))) * (depth // BK)), simd_width_of[dtype]()), MutableAnyOrigin, AddressSpace(5), Layout(IntTuple(1), IntTuple(1)), _get_layout_type(Layout.row_major(((ceildiv(WM, mma_shape.__getitem__[3, DType.int64, Int](0)) * ceildiv(BK, (mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size))) * (depth // BK)), simd_width_of[dtype]()), AddressSpace(5)), _get_index_type(Layout.row_major(((ceildiv(WM, mma_shape.__getitem__[3, DType.int64, Int](0)) * ceildiv(BK, (mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size))) * (depth // BK)), simd_width_of[dtype]()), AddressSpace(5)), False, align_of[dtype](), (Layout.row_major(((ceildiv(WM, mma_shape.__getitem__[3, DType.int64, Int](0)) * ceildiv(BK, (mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size))) * (depth // BK)), simd_width_of[dtype]()).shape[0].value() // (depth // BK)), 0]()[0], MutableAnyOrigin, AddressSpace(5), Layout(IntTuple(1), IntTuple(1)), _get_layout_type(LayoutTensor._compute_tile_layout[True, dtype, Layout.row_major(((ceildiv(WM, mma_shape.__getitem__[3, DType.int64, Int](0)) * ceildiv(BK, (mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size))) * (depth // BK)), simd_width_of[dtype]()), MutableAnyOrigin, AddressSpace(5), Layout(IntTuple(1), IntTuple(1)), _get_layout_type(Layout.row_major(((ceildiv(WM, mma_shape.__getitem__[3, DType.int64, Int](0)) * ceildiv(BK, (mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size))) * (depth // BK)), simd_width_of[dtype]()), AddressSpace(5)), _get_index_type(Layout.row_major(((ceildiv(WM, mma_shape.__getitem__[3, DType.int64, Int](0)) * ceildiv(BK, (mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size))) * (depth // BK)), simd_width_of[dtype]()), 
AddressSpace(5)), False, align_of[dtype](), (Layout.row_major(((ceildiv(WM, mma_shape.__getitem__[3, DType.int64, Int](0)) * ceildiv(BK, (mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size))) * (depth // BK)), simd_width_of[dtype]()).shape[0].value() // (depth // BK)), 0]()[0], AddressSpace(5)), _get_index_type(LayoutTensor._compute_tile_layout[True, dtype, Layout.row_major(((ceildiv(WM, mma_shape.__getitem__[3, DType.int64, Int](0)) * ceildiv(BK, (mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size))) * (depth // BK)), simd_width_of[dtype]()), MutableAnyOrigin, AddressSpace(5), Layout(IntTuple(1), IntTuple(1)), _get_layout_type(Layout.row_major(((ceildiv(WM, mma_shape.__getitem__[3, DType.int64, Int](0)) * ceildiv(BK, (mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size))) * (depth // BK)), simd_width_of[dtype]()), AddressSpace(5)), _get_index_type(Layout.row_major(((ceildiv(WM, mma_shape.__getitem__[3, DType.int64, Int](0)) * ceildiv(BK, (mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size))) * (depth // BK)), simd_width_of[dtype]()), AddressSpace(5)), False, align_of[dtype](), (Layout.row_major(((ceildiv(WM, mma_shape.__getitem__[3, DType.int64, Int](0)) * ceildiv(BK, (mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size))) * (depth // BK)), simd_width_of[dtype]()).shape[0].value() // (depth // BK)), 0]()[0], AddressSpace(5)), False, align_of[dtype](), (LayoutTensor._compute_tile_layout[True, dtype, Layout.row_major(((ceildiv(WM, mma_shape.__getitem__[3, DType.int64, Int](0)) * ceildiv(BK, (mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size))) * (depth // BK)), simd_width_of[dtype]()), MutableAnyOrigin, AddressSpace(5), Layout(IntTuple(1), IntTuple(1)), _get_layout_type(Layout.row_major(((ceildiv(WM, mma_shape.__getitem__[3, DType.int64, Int](0)) * ceildiv(BK, (mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size))) * (depth // BK)), simd_width_of[dtype]()), AddressSpace(5)), 
_get_index_type(Layout.row_major(((ceildiv(WM, mma_shape.__getitem__[3, DType.int64, Int](0)) * ceildiv(BK, (mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size))) * (depth // BK)), simd_width_of[dtype]()), AddressSpace(5)), False, align_of[dtype](), (Layout.row_major(((ceildiv(WM, mma_shape.__getitem__[3, DType.int64, Int](0)) * ceildiv(BK, (mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size))) * (depth // BK)), simd_width_of[dtype]()).shape[0].value() // (depth // BK)), 0]()[0].shape[0].value() // ceildiv(BK, (mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size))), 0]()[0], MutableAnyOrigin, address_space=AddressSpace(5)]

num_k_tiles

alias num_k_tiles = ceildiv(BK, (mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size))

Equivalently, `ceildiv(BK, MMA_K * k_group_size)`.

num_mmas

alias num_mmas = ceildiv(WM, mma_shape.__getitem__[3, DType.int64, Int](0))

Equivalently, `ceildiv(WM, MMA_M)`.

num_tiles

alias num_tiles = (depth // BK)

reg_dtype

alias reg_dtype = dtype

reg_tile_layout

alias reg_tile_layout = Layout.row_major(((ceildiv(WM, mma_shape.__getitem__[3, DType.int64, Int](0)) * ceildiv(BK, (mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size))) * (depth // BK)), simd_width_of[dtype]())

Equivalently, `Layout.row_major(num_mmas * num_k_tiles * num_tiles, simd_width)`.

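RegisterTileType

Equivalent to `LayoutTensor[dtype, reg_tile_layout, MutableAnyOrigin, address_space=AddressSpace(5)]`.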

alias RegisterTileType = LayoutTensor[dtype, Layout.row_major(((ceildiv(WM, mma_shape.__getitem__[3, DType.int64, Int](0)) * ceildiv(BK, (mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size))) * (depth // BK)), simd_width_of[dtype]()), MutableAnyOrigin, address_space=AddressSpace(5)]

simd_width

alias simd_width = simd_width_of[dtype]()

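TiledIteratorType

A `LayoutTensorIter` over `reg_tile_layout`. The layout is split into tiles of `num_mmas * num_k_tiles` rows by `simd_width` columns, and the iterator advances along axis 0. The full expansion follows.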

alias TiledIteratorType = LayoutTensorIter[dtype, LayoutTensor._compute_tile_layout[True, dtype, Layout.row_major(((ceildiv(WM, mma_shape.__getitem__[3, DType.int64, Int](0)) * ceildiv(BK, (mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size))) * (depth // BK)), simd_width_of[dtype]()), MutableAnyOrigin, AddressSpace(5), Layout(IntTuple(1), IntTuple(1)), _get_layout_type(Layout.row_major(((ceildiv(WM, mma_shape.__getitem__[3, DType.int64, Int](0)) * ceildiv(BK, (mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size))) * (depth // BK)), simd_width_of[dtype]()), AddressSpace(5)), _get_index_type(Layout.row_major(((ceildiv(WM, mma_shape.__getitem__[3, DType.int64, Int](0)) * ceildiv(BK, (mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size))) * (depth // BK)), simd_width_of[dtype]()), AddressSpace(5)), False, align_of[dtype](), (ceildiv(WM, mma_shape.__getitem__[3, DType.int64, Int](0)) * ceildiv(BK, (mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size))), simd_width_of[dtype]()]()[0], MutableAnyOrigin, address_space=AddressSpace(5), axis=0, layout_int_type=_get_layout_type(Layout.row_major(((ceildiv(WM, mma_shape.__getitem__[3, DType.int64, Int](0)) * ceildiv(BK, (mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size))) * (depth // BK)), simd_width_of[dtype]()), AddressSpace(5)), linear_idx_type=_get_index_type(Layout.row_major(((ceildiv(WM, mma_shape.__getitem__[3, DType.int64, Int](0)) * ceildiv(BK, (mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size))) * (depth // BK)), simd_width_of[dtype]()), AddressSpace(5)), masked=_tile_is_masked[Layout.row_major(((ceildiv(WM, mma_shape.__getitem__[3, DType.int64, Int](0)) * ceildiv(BK, (mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size))) * (depth // BK)), simd_width_of[dtype]()), (ceildiv(WM, mma_shape.__getitem__[3, DType.int64, Int](0)) * ceildiv(BK, (mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size))), simd_width_of[dtype]()]()]

Methods

__init__

__init__(out self, tensor: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])

get_dtype

static get_dtype() -> DType

Returns:

DType

get_iter

get_iter(self) -> LayoutTensorIter[dtype, LayoutTensor._compute_tile_layout[True, dtype, Layout.row_major(((ceildiv(WM, mma_shape.__getitem__[3, DType.int64, Int](0)) * ceildiv(BK, (mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size))) * (depth // BK)), simd_width_of[dtype]()), MutableAnyOrigin, AddressSpace(5), Layout(IntTuple(1), IntTuple(1)), _get_layout_type(Layout.row_major(((ceildiv(WM, mma_shape.__getitem__[3, DType.int64, Int](0)) * ceildiv(BK, (mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size))) * (depth // BK)), simd_width_of[dtype]()), AddressSpace(5)), _get_index_type(Layout.row_major(((ceildiv(WM, mma_shape.__getitem__[3, DType.int64, Int](0)) * ceildiv(BK, (mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size))) * (depth // BK)), simd_width_of[dtype]()), AddressSpace(5)), False, align_of[dtype](), (ceildiv(WM, mma_shape.__getitem__[3, DType.int64, Int](0)) * ceildiv(BK, (mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size))), simd_width_of[dtype]()]()[0], MutableAnyOrigin, address_space=AddressSpace(5), axis=0, layout_int_type=_get_layout_type(Layout.row_major(((ceildiv(WM, mma_shape.__getitem__[3, DType.int64, Int](0)) * ceildiv(BK, (mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size))) * (depth // BK)), simd_width_of[dtype]()), AddressSpace(5)), linear_idx_type=_get_index_type(Layout.row_major(((ceildiv(WM, mma_shape.__getitem__[3, DType.int64, Int](0)) * ceildiv(BK, (mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size))) * (depth // BK)), simd_width_of[dtype]()), AddressSpace(5)), masked=_tile_is_masked[Layout.row_major(((ceildiv(WM, mma_shape.__getitem__[3, DType.int64, Int](0)) * ceildiv(BK, (mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size))) * (depth // BK)), simd_width_of[dtype]()), (ceildiv(WM, mma_shape.__getitem__[3, DType.int64, Int](0)) * ceildiv(BK, (mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size))), simd_width_of[dtype]()]()]

Returns:

LayoutTensorIter (the same type as `TiledIteratorType`)

get_mma_tile

get_mma_tile[tile_idx: Int, k_idx: Int](self) -> LayoutTensor[dtype, LayoutTensor._compute_tile_layout[True, dtype, LayoutTensor._compute_tile_layout[True, dtype, Layout.row_major(((ceildiv(WM, mma_shape.__getitem__[3, DType.int64, Int](0)) * ceildiv(BK, (mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size))) * (depth // BK)), simd_width_of[dtype]()), MutableAnyOrigin, AddressSpace(5), Layout(IntTuple(1), IntTuple(1)), _get_layout_type(Layout.row_major(((ceildiv(WM, mma_shape.__getitem__[3, DType.int64, Int](0)) * ceildiv(BK, (mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size))) * (depth // BK)), simd_width_of[dtype]()), AddressSpace(5)), _get_index_type(Layout.row_major(((ceildiv(WM, mma_shape.__getitem__[3, DType.int64, Int](0)) * ceildiv(BK, (mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size))) * (depth // BK)), simd_width_of[dtype]()), AddressSpace(5)), False, align_of[dtype](), (Layout.row_major(((ceildiv(WM, mma_shape.__getitem__[3, DType.int64, Int](0)) * ceildiv(BK, (mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size))) * (depth // BK)), simd_width_of[dtype]()).shape[0].value() // (depth // BK)), 0]()[0], MutableAnyOrigin, AddressSpace(5), Layout(IntTuple(1), IntTuple(1)), _get_layout_type(LayoutTensor._compute_tile_layout[True, dtype, Layout.row_major(((ceildiv(WM, mma_shape.__getitem__[3, DType.int64, Int](0)) * ceildiv(BK, (mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size))) * (depth // BK)), simd_width_of[dtype]()), MutableAnyOrigin, AddressSpace(5), Layout(IntTuple(1), IntTuple(1)), _get_layout_type(Layout.row_major(((ceildiv(WM, mma_shape.__getitem__[3, DType.int64, Int](0)) * ceildiv(BK, (mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size))) * (depth // BK)), simd_width_of[dtype]()), AddressSpace(5)), _get_index_type(Layout.row_major(((ceildiv(WM, mma_shape.__getitem__[3, DType.int64, Int](0)) * ceildiv(BK, (mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size))) * (depth // BK)), 
simd_width_of[dtype]()), AddressSpace(5)), False, align_of[dtype](), (Layout.row_major(((ceildiv(WM, mma_shape.__getitem__[3, DType.int64, Int](0)) * ceildiv(BK, (mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size))) * (depth // BK)), simd_width_of[dtype]()).shape[0].value() // (depth // BK)), 0]()[0], AddressSpace(5)), _get_index_type(LayoutTensor._compute_tile_layout[True, dtype, Layout.row_major(((ceildiv(WM, mma_shape.__getitem__[3, DType.int64, Int](0)) * ceildiv(BK, (mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size))) * (depth // BK)), simd_width_of[dtype]()), MutableAnyOrigin, AddressSpace(5), Layout(IntTuple(1), IntTuple(1)), _get_layout_type(Layout.row_major(((ceildiv(WM, mma_shape.__getitem__[3, DType.int64, Int](0)) * ceildiv(BK, (mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size))) * (depth // BK)), simd_width_of[dtype]()), AddressSpace(5)), _get_index_type(Layout.row_major(((ceildiv(WM, mma_shape.__getitem__[3, DType.int64, Int](0)) * ceildiv(BK, (mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size))) * (depth // BK)), simd_width_of[dtype]()), AddressSpace(5)), False, align_of[dtype](), (Layout.row_major(((ceildiv(WM, mma_shape.__getitem__[3, DType.int64, Int](0)) * ceildiv(BK, (mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size))) * (depth // BK)), simd_width_of[dtype]()).shape[0].value() // (depth // BK)), 0]()[0], AddressSpace(5)), False, align_of[dtype](), (LayoutTensor._compute_tile_layout[True, dtype, Layout.row_major(((ceildiv(WM, mma_shape.__getitem__[3, DType.int64, Int](0)) * ceildiv(BK, (mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size))) * (depth // BK)), simd_width_of[dtype]()), MutableAnyOrigin, AddressSpace(5), Layout(IntTuple(1), IntTuple(1)), _get_layout_type(Layout.row_major(((ceildiv(WM, mma_shape.__getitem__[3, DType.int64, Int](0)) * ceildiv(BK, (mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size))) * (depth // BK)), simd_width_of[dtype]()), AddressSpace(5)), 
_get_index_type(Layout.row_major(((ceildiv(WM, mma_shape.__getitem__[3, DType.int64, Int](0)) * ceildiv(BK, (mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size))) * (depth // BK)), simd_width_of[dtype]()), AddressSpace(5)), False, align_of[dtype](), (Layout.row_major(((ceildiv(WM, mma_shape.__getitem__[3, DType.int64, Int](0)) * ceildiv(BK, (mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size))) * (depth // BK)), simd_width_of[dtype]()).shape[0].value() // (depth // BK)), 0]()[0].shape[0].value() // ceildiv(BK, (mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size))), 0]()[0], MutableAnyOrigin, address_space=AddressSpace(5)]

Returns:

LayoutTensor (the same type as `MMATileType`)

get_reg_tile

get_reg_tile(self) -> LayoutTensor[dtype, Layout.row_major(((ceildiv(WM, mma_shape.__getitem__[3, DType.int64, Int](0)) * ceildiv(BK, (mma_shape.__getitem__[3, DType.int64, Int](2) * k_group_size))) * (depth // BK)), simd_width_of[dtype]()), MutableAnyOrigin, address_space=AddressSpace(5)]

Returns:

LayoutTensor (the same type as `RegisterTileType`)

zero

zero(self)

Was this page helpful?