Skip to main content

Mojo module

attention

Comptime values

ImmutTileTensor1D

comptime ImmutTileTensor1D[dtype: DType] = TileTensor[dtype, Layout[#kgen.variadic.reduce(#kgen.variadic.tabulate(len[IntTuple](Layout.row_major(VariadicList(-1)).shape), [idx: __mlir_type.index] _int_to_dim(Layout.row_major(VariadicList(-1)).shape[idx].value())), base=, reducer=[PrevV: Variadic[CoordLike], VA: Variadic[Dim], idx: __mlir_type.index] #kgen.variadic.concat(PrevV, ComptimeInt[VA[idx]._value_or_missing] if (VA[idx] != -31337) else RuntimeInt[DType.int64])), #kgen.variadic.reduce(#kgen.variadic.tabulate(len[IntTuple](Layout.row_major(VariadicList(-1)).stride), [idx: __mlir_type.index] _int_to_dim(Layout.row_major(VariadicList(-1)).stride[idx].value())), base=, reducer=[PrevV: Variadic[CoordLike], VA: Variadic[Dim], idx: __mlir_type.index] #kgen.variadic.concat(PrevV, ComptimeInt[VA[idx]._value_or_missing] if (VA[idx] != -31337) else RuntimeInt[DType.int64]))], ImmutAnyOrigin]

Parameters

KVTMATile

comptime KVTMATile[dtype: DType, swizzle_mode: TensorMapSwizzle, *, BN: Int, BK: Int] = TMATensorTile[dtype, 3, _padded_shape[3, dtype, IndexList(VariadicList(BN, 1, BK), Tuple()), swizzle_mode](), _ragged_shape[3, dtype, IndexList(VariadicList(BN, 1, BK), Tuple()), swizzle_mode]()]

Parameters

QTMATile

comptime QTMATile[dtype: DType, swizzle_mode: TensorMapSwizzle, *, BM: Int, depth: Int, group: Int, decoding: Bool, num_qk_stages: Int = 1] = TMATensorTile[dtype, 4 if decoding else 3, _padded_shape[4 if decoding else 3, dtype, q_smem_shape[dtype, swizzle_mode, BM=BM, group=group, depth=depth, decoding=decoding, num_qk_stages=num_qk_stages](), swizzle_mode](), _ragged_shape[4 if decoding else 3, dtype, q_smem_shape[dtype, swizzle_mode, BM=BM, group=group, depth=depth, decoding=decoding, num_qk_stages=num_qk_stages](), swizzle_mode]()]

Parameters

Structs

Traits

Functions

Was this page helpful?