Mojo struct
NVBlockScaledTokenFormat
struct NVBlockScaledTokenFormat[quant_dtype: DType, scales_dtype: DType, output_layout: TensorLayout, scales_offset_layout: TensorLayout, //, _hid_dim: Int, _top_k: Int, _alignment: Int = Int(0)]
Fields
- scales_tma_op (NVBlockScaledTokenFormat[_hid_dim, _top_k, _alignment].ScalesTMATensorTileType)
- output_tokens (NVBlockScaledTokenFormat[_hid_dim, _top_k, _alignment].TensorType)
- output_scales_offset (NVBlockScaledTokenFormat[_hid_dim, _top_k, _alignment].ScalesOffsetTensorType)
Implemented traits
AnyType,
Copyable,
DevicePassable,
ImplicitlyCopyable,
ImplicitlyDeletable,
Movable,
TokenFormat
comptime members
alignment
comptime alignment = _alignment if _alignment.__bool__() else get_device_alignment()
device_type
comptime device_type = NVBlockScaledTokenFormat[_hid_dim, _top_k, _alignment]
dispatch_smem_size
comptime dispatch_smem_size = align_up(Coord[Int(4), DType.int64](tma_tile_shape).product(), Int(128)) * size_of[scales_dtype]() * 32 + align_up(NVBlockScaledTokenFormat.quant_size() // dispatch_wait_tile_shape[1], Int(16)) * 32 + align_up(size_of[SharedMemBarrier]() * 32, Int(8))
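The following plain-Python sketch evaluates the dispatch_smem_size arithmetic on the host so the three terms can be checked by hand. It is not the library's code. It takes SF_ATOM_M = (32, 4) and dispatch_wait_tile_shape = (128, 2) from the values inlined on this page. The rest are assumptions for illustration: the helper names, the example hid_dim of 7168, a group size of 16, a quant_size of hid_dim // 2 bytes (two FP4 values per byte), and an 8-byte SharedMemBarrier.

```python
# Host-side sketch of the dispatch_smem_size arithmetic; not the library's code.
# SF_ATOM_M and DISPATCH_WAIT_TILE_SHAPE use the values inlined on this page.
# group_size, quant_size_bytes and barrier_bytes are caller-supplied assumptions.

SF_ATOM_M = (32, 4)
DISPATCH_WAIT_TILE_SHAPE = (128, 2)


def align_up(value: int, alignment: int) -> int:
    """Round value up to the next multiple of alignment."""
    return (value + alignment - 1) // alignment * alignment


def tma_tile_shape(hid_dim: int, group_size: int) -> tuple[int, int, int, int]:
    """Mirror of the tma_tile_shape comptime member."""
    return (
        1,
        ((hid_dim // group_size) // 4) // DISPATCH_WAIT_TILE_SHAPE[1],
        1,
        4 * SF_ATOM_M[1],
    )


def dispatch_smem_size(
    hid_dim: int,
    group_size: int,
    scales_bytes: int,
    quant_size_bytes: int,
    barrier_bytes: int = 8,  # hypothetical SharedMemBarrier size in bytes
) -> int:
    tile_elems = 1
    for dim in tma_tile_shape(hid_dim, group_size):
        tile_elems *= dim
    # Scales term: tile elements rounded up to 128, times bytes per scale, times 32.
    scales_part = align_up(tile_elems, 128) * scales_bytes * 32
    # Quant term: quant_size split by dispatch_wait_tile_shape[1], 16-byte aligned, times 32.
    quant_part = align_up(quant_size_bytes // DISPATCH_WAIT_TILE_SHAPE[1], 16) * 32
    # Barrier term: 32 barriers, rounded up to an 8-byte boundary.
    barrier_part = align_up(barrier_bytes * 32, 8)
    return scales_part + quant_part + barrier_part


# Hypothetical NVFP4 example: hid_dim 7168, 16-element groups,
# 1-byte float8_e4m3fn scales, hid_dim // 2 bytes of packed FP4 quants.
print(tma_tile_shape(7168, 16))  # (1, 56, 1, 16)
print(dispatch_smem_size(7168, 16, 1, 7168 // 2))  # 86272
```

With these assumed inputs, the scales term is 28672 bytes, the quant term is 57344 bytes, and the barrier term is 256 bytes.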
dispatch_wait_tile_shape
comptime dispatch_wait_tile_shape = Tuple(Int(128), Int(2))
group_size
comptime group_size = NVBlockScaledTokenFormat.get_group_size()
hid_dim
comptime hid_dim = _hid_dim
is_mxfp4
comptime is_mxfp4 = (scales_dtype == DType.float8_e8m0fnu) and (quant_dtype == DType.uint8)
is_mxfp8
comptime is_mxfp8 = (scales_dtype == DType.float8_e8m0fnu) and (quant_dtype == DType.float8_e4m3fn)
is_nvfp4
comptime is_nvfp4 = (scales_dtype == DType.float8_e4m3fn) and (quant_dtype == DType.uint8)
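Each format flag combines one scales-dtype check with one quant-dtype check. A compiler-rendered form such as `x if c else c` is equivalent to `c and x` for booleans. The minimal Python sketch below mirrors the three flags and also verifies that equivalence. Dtype names are plain strings here, a simplification: the real members compare DType values at compile time, and the function name `format_flags` is hypothetical.

```python
from itertools import product


def format_flags(quant_dtype: str, scales_dtype: str) -> dict[str, bool]:
    """Mirror is_mxfp4 / is_mxfp8 / is_nvfp4 using dtype names as strings."""
    e8m0_scales = scales_dtype == "float8_e8m0fnu"
    e4m3_scales = scales_dtype == "float8_e4m3fn"
    return {
        "is_mxfp4": e8m0_scales and quant_dtype == "uint8",
        "is_mxfp8": e8m0_scales and quant_dtype == "float8_e4m3fn",
        "is_nvfp4": e4m3_scales and quant_dtype == "uint8",
    }


# The ternary form `x if c else c` agrees with `c and x` for every bool pair.
for c, x in product([False, True], repeat=2):
    assert (x if c else c) == (c and x)

print(format_flags("uint8", "float8_e4m3fn"))
# {'is_mxfp4': False, 'is_mxfp8': False, 'is_nvfp4': True}
```

The scales dtype alone separates NVFP4 from the MX formats, while the uint8 quant dtype (the FP4 storage type in these flags) separates MXFP4 from MXFP8.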
ScalesOffsetTensorType
comptime ScalesOffsetTensorType = TileTensor[DType.uint32, scales_offset_layout, MutUntrackedOrigin]
ScalesTMATensorTileType
comptime ScalesTMATensorTileType = TMATensorTile[scales_dtype, Int(4), NVBlockScaledTokenFormat[_hid_dim, _top_k, _alignment].tma_tile_shape, _default_desc_shape[Int(4), scales_dtype, NVBlockScaledTokenFormat[_hid_dim, _top_k, _alignment].tma_tile_shape, TensorMapSwizzle.SWIZZLE_NONE]()]
TensorType
comptime TensorType = TileTensor[quant_dtype, output_layout, MutUntrackedOrigin]
tma_tile_shape
comptime tma_tile_shape = Index[Int, Int, Int, Int](Int(1), ((_hid_dim // NVBlockScaledTokenFormat.get_group_size()) // Int(4)) // dispatch_wait_tile_shape[1], Int(1), Int(4) * SF_ATOM_M[1])
top_k
comptime top_k = _top_k
Methods
__init__
def __init__(out self, output_tokens: TileTensor[quant_dtype, output_layout, Storage=output_tokens.Storage, address_space=output_tokens.address_space, linear_idx_type=output_tokens.linear_idx_type, element_size=output_tokens.element_size], output_scales: TileTensor[scales_dtype, Storage=output_scales.Storage, address_space=output_scales.address_space, linear_idx_type=output_scales.linear_idx_type, element_size=output_scales.element_size], output_scales_offset: TileTensor[DType.uint32, scales_offset_layout, Storage=output_scales_offset.Storage, address_space=output_scales_offset.address_space, linear_idx_type=output_scales_offset.linear_idx_type, element_size=output_scales_offset.element_size], ctx: DeviceContext)
get_group_size
get_type_name
quant_size
scales_size
token_size
scales_offset
pad_expert_offsets
def pad_expert_offsets[n_groups: Int](self, row_offsets: UnsafePointer[UInt32, address_space=row_offsets.address_space])
The Mojo NVFP4 grouped matmul does not require each group's FP4 quants to be padded. It does, however, require each group's scales to be aligned to SF_MN_GROUP_SIZE (128 rows), so that every group's scales begin in a fresh scales block. This method updates the output_scales_offset tensor to satisfy that requirement.
For example, if row_offsets is [0, 100, 300, 400], this method sets output_scales_offset to [0, 1, 1]. For group i, the index of its first scales block is row_offsets[i] // SF_MN_GROUP_SIZE + output_scales_offset[i]. Groups 0, 1, and 2 hold 100, 200, and 100 tokens, so they need 1, 2, and 1 scales blocks (ceil(tokens / 128)), respectively. The scales blocks for group 0 therefore start at 0 // 128 + output_scales_offset[0] = 0, those for group 1 at 100 // 128 + output_scales_offset[1] = 1, and those for group 2 at 300 // 128 + output_scales_offset[2] = 3.
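One way to compute output_scales_offset that reproduces the example above is to walk the groups in order and reserve ceil(tokens / 128) scales blocks for each group. The plain-Python sketch below illustrates that approach on the host under this assumption. It is not the device implementation, which takes an UnsafePointer[UInt32] and an n_groups parameter. The function name `scales_offsets` is hypothetical.

```python
SF_MN_GROUP_SIZE = 128  # scales-block height from the description above


def ceil_div(a: int, b: int) -> int:
    """Integer ceiling division for non-negative a and positive b."""
    return -(-a // b)


def scales_offsets(row_offsets: list[int]) -> list[int]:
    """Hypothetical host-side sketch: one offset per group.

    row_offsets has n_groups + 1 entries; group i owns rows
    row_offsets[i] .. row_offsets[i + 1] - 1.
    """
    offsets = []
    next_block = 0  # first scales block not yet claimed by an earlier group
    for i in range(len(row_offsets) - 1):
        start, end = row_offsets[i], row_offsets[i + 1]
        # Choose the offset so that start // 128 + offset == next_block.
        offsets.append(next_block - start // SF_MN_GROUP_SIZE)
        # The group's scales occupy whole 128-row blocks.
        next_block += ceil_div(end - start, SF_MN_GROUP_SIZE)
    return offsets


row_offsets = [0, 100, 300, 400]
offsets = scales_offsets(row_offsets)
print(offsets)  # [0, 1, 1]
print([r // SF_MN_GROUP_SIZE + o for r, o in zip(row_offsets, offsets)])  # [0, 1, 3]
```

The next free block never falls behind row_offsets[i] // 128, so every offset is non-negative. This fits the uint32 element type of output_scales_offset.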
copy_token_to_send_buf
static def copy_token_to_send_buf[src_type: DType, block_size: Int, buf_addr_space: AddressSpace = AddressSpace.GENERIC](buf_p: UnsafePointer[UInt8, address_space=buf_addr_space], src_p: UnsafePointer[Scalar[src_type], address_space=src_p.address_space], input_scale: Float32)
copy_msg_to_output_tensor
def copy_msg_to_output_tensor[buf_addr_space: AddressSpace = AddressSpace.GENERIC](self, buf_p: UnsafePointer[UInt8, address_space=buf_addr_space], token_index: Int, expert_slot: Int = Int(0), expert_start: Int = Int(0))
The NVFP4 format uses tile-based copying directly (see copy_msg_tile_to_output_tensor) instead of this per-token method.
init_smem_resources
def init_smem_resources(self)
copy_msg_tile_to_output_tensor
def copy_msg_tile_to_output_tensor[extract_topk_info_func: def(UnsafePointer[UInt8, MutUntrackedOrigin], Int) -> None, recv_buf_ptr_func: def(Int) -> UnsafePointer[UInt8, MutUntrackedOrigin], //, n_warps: Int, shared_expert_offset: Int = Int(0)](self, expert_id: Int, expert_start_pos: Int, tile_id: Int, tile_end: Int, extract_topk_info_functor: extract_topk_info_func, recv_buf_ptr_functor: recv_buf_ptr_func)