Mojo struct
NVFP4TokenFormat
struct NVFP4TokenFormat[fp4_dtype: DType, scales_dtype: DType, output_layout: TensorLayout, scales_offset_layout: TensorLayout, //, _hid_dim: Int, _top_k: Int, _alignment: Int = 0]
Fields
- scales_tma_op (NVFP4TokenFormat[_hid_dim, _top_k, _alignment].ScalesTMATensorTileType)
- output_tokens (NVFP4TokenFormat[_hid_dim, _top_k, _alignment].TensorType)
- output_scales_offset (NVFP4TokenFormat[_hid_dim, _top_k, _alignment].ScalesOffsetTensorType)
Implemented traits
AnyType,
Copyable,
DevicePassable,
ImplicitlyCopyable,
ImplicitlyDestructible,
Movable,
TokenFormat
comptime members
alignment
comptime alignment = _alignment if _alignment.__bool__() else get_device_alignment()
device_type
comptime device_type = NVFP4TokenFormat[_hid_dim, _top_k, _alignment]
dispatch_smem_size
comptime dispatch_smem_size = ((32 * ((align_up(Int[Int](Coord[4, DType.int64](NVFP4TokenFormat[_hid_dim, _top_k, _alignment].tma_tile_shape).product()), 128) * size_of[scales_dtype]()) + align_up(((_hid_dim // 2) // (load_from_mem NVFP4TokenFormat[_hid_dim, _top_k, _alignment].dispatch_wait_tile_shape.__getitem_param__[1]())), 16))) + align_up((32 * size_of[SharedMemBarrier]()), 8))
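The expression above has the shape 32 * (A + B) + C, where A is the scales tile size rounded up to 128 elements and multiplied by the element size, B is (_hid_dim // 2) // dispatch_wait_tile_shape[1] rounded up to 16, and C is the space for 32 barriers rounded up to 8. The Python sketch below only mirrors that arithmetic. The default sizes (1 byte per scale, 8 bytes per SharedMemBarrier) and the example tile shape are assumptions for illustration and are not stated on this page.

```python
def align_up(value: int, alignment: int) -> int:
    """Round value up to the next multiple of alignment."""
    return (value + alignment - 1) // alignment * alignment


def dispatch_smem_size(
    hid_dim: int,
    tma_tile_shape: tuple,
    scales_bytes: int = 1,    # assumed size_of[scales_dtype]()
    barrier_bytes: int = 8,   # assumed size_of[SharedMemBarrier]()
    wait_tile_cols: int = 2,  # dispatch_wait_tile_shape[1] from this page
) -> int:
    # Number of elements in one scales tile (product of the tile shape).
    tile_elems = 1
    for dim in tma_tile_shape:
        tile_elems *= dim

    term_a = align_up(tile_elems, 128) * scales_bytes
    term_b = align_up((hid_dim // 2) // wait_tile_cols, 16)
    term_c = align_up(32 * barrier_bytes, 8)
    return 32 * (term_a + term_b) + term_c


# Hypothetical configuration: hid_dim = 7168 with an assumed tile shape.
print(dispatch_smem_size(7168, (1, 56, 1, 16)))
```

Changing scales_bytes or barrier_bytes to the real sizes for a given instantiation keeps the rest of the calculation unchanged.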
dispatch_wait_tile_shape
comptime dispatch_wait_tile_shape = Tuple(128, 2)
group_size
comptime group_size = NVFP4_SF_VECTOR_SIZE
hid_dim
comptime hid_dim = _hid_dim
ScalesOffsetTensorType
comptime ScalesOffsetTensorType = TileTensor[DType.uint32, scales_offset_layout, MutExternalOrigin]
ScalesTMATensorTileType
comptime ScalesTMATensorTileType = TMATensorTile[scales_dtype, 4, NVFP4TokenFormat[_hid_dim, _top_k, _alignment].tma_tile_shape, _default_desc_shape[4, scales_dtype, NVFP4TokenFormat[_hid_dim, _top_k, _alignment].tma_tile_shape, TensorMapSwizzle.SWIZZLE_NONE]()]
TensorType
comptime TensorType = TileTensor[fp4_dtype, output_layout, MutExternalOrigin]
tma_tile_shape
comptime tma_tile_shape = Index[Int, Int, Int, Int](1, (((_hid_dim // 16) // 4) // (load_from_mem NVFP4TokenFormat[_hid_dim, _top_k, _alignment].dispatch_wait_tile_shape.__getitem_param__[1]())), 1, (4 * (load_from_mem SF_ATOM_M.__getitem_param__[1]())))
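Read as plain integer arithmetic, the tile shape is (1, ((_hid_dim // 16) // 4) // dispatch_wait_tile_shape[1], 1, 4 * SF_ATOM_M[1]). A minimal Python sketch of that calculation follows. The value of SF_ATOM_M[1] is not given on this page, so the default of 4 used here is an assumption, and the hid_dim value is a hypothetical example.

```python
def tma_tile_shape(
    hid_dim: int,
    wait_tile_cols: int = 2,  # dispatch_wait_tile_shape[1] from this page
    sf_atom_m1: int = 4,      # assumed value of SF_ATOM_M[1]
) -> tuple:
    # Mirrors the comptime expression term by term.
    dim1 = ((hid_dim // 16) // 4) // wait_tile_cols
    dim3 = 4 * sf_atom_m1
    return (1, dim1, 1, dim3)


# Hypothetical hidden dimension of 7168.
print(tma_tile_shape(7168))
```

Because every division is an integer floor division, the order of the three divisions matters when _hid_dim is not a multiple of 16 * 4 * 2 = 128.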
top_k
comptime top_k = _top_k
Methods
__init__
__init__(out self, output_tokens: TileTensor[fp4_dtype, output_layout, address_space=output_tokens.address_space, linear_idx_type=output_tokens.linear_idx_type, element_size=output_tokens.element_size], output_scales: TileTensor[scales_dtype, address_space=output_scales.address_space, linear_idx_type=output_scales.linear_idx_type, element_size=output_scales.element_size], output_scales_offset: TileTensor[DType.uint32, scales_offset_layout, address_space=output_scales_offset.address_space, linear_idx_type=output_scales_offset.linear_idx_type, element_size=output_scales_offset.element_size], ctx: DeviceContext)
get_type_name
fp4_quant_size
scales_size
token_size
scales_offset
pad_expert_offsets
pad_expert_offsets[n_groups: Int](self, row_offsets: UnsafePointer[UInt32, address_space=row_offsets.address_space])
The Mojo NVFP4 grouped matmul does not require padding for each group's FP4 quants, but it does require each group's scales to be aligned to SF_MN_GROUP_SIZE = 128. This function updates the output_scales_offset tensor to satisfy that requirement.
For example, if row_offsets is [0, 100, 300, 400], this function sets output_scales_offset to [0, 1, 1]. For group i, the index of its first scales block is row_offsets[i] // SF_MN_GROUP_SIZE + output_scales_offset[i]. Groups 0, 1, and 2 have 100, 200, and 100 tokens, so they need 1, 2, and 1 scales blocks respectively. The scales blocks for group 0 start at 0 // 128 + output_scales_offset[0] = 0, those for group 1 start at 100 // 128 + output_scales_offset[1] = 1, and those for group 2 start at 300 // 128 + output_scales_offset[2] = 3.
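The worked example can be reproduced with a short host-side model. This Python sketch is not the kernel's implementation; it is one way to derive offsets that satisfy the stated formula, assuming each group occupies ceil(tokens / SF_MN_GROUP_SIZE) scales blocks and that groups are packed back to back. The function name padded_scales_offsets is hypothetical.

```python
SF_MN_GROUP_SIZE = 128


def padded_scales_offsets(row_offsets: list) -> list:
    """Return one offset per group so that
    row_offsets[i] // SF_MN_GROUP_SIZE + offsets[i]
    is the index of group i's first scales block."""
    offsets = []
    next_block = 0  # first free scales block index
    for i in range(len(row_offsets) - 1):
        start, end = row_offsets[i], row_offsets[i + 1]
        offsets.append(next_block - start // SF_MN_GROUP_SIZE)
        # Ceiling division: blocks needed for this group's tokens.
        next_block += -(-(end - start) // SF_MN_GROUP_SIZE)
    return offsets


print(padded_scales_offsets([0, 100, 300, 400]))  # [0, 1, 1]
```

Under these assumptions every offset is non-negative, because the blocks consumed by earlier groups are never fewer than row_offsets[i] // SF_MN_GROUP_SIZE.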
copy_token_to_send_buf
static copy_token_to_send_buf[src_type: DType, block_size: Int, buf_addr_space: AddressSpace = AddressSpace.GENERIC](buf_p: UnsafePointer[UInt8, address_space=buf_addr_space], src_p: UnsafePointer[Scalar[src_type], address_space=src_p.address_space], input_scale: Float32)
copy_msg_to_output_tensor
copy_msg_to_output_tensor[buf_addr_space: AddressSpace = AddressSpace.GENERIC](self, buf_p: UnsafePointer[UInt8, address_space=buf_addr_space], token_index: Int)
The NVFP4 format uses tile-based copy directly.
init_smem_resources
init_smem_resources(self)
copy_msg_tile_to_output_tensor
copy_msg_tile_to_output_tensor[extract_topk_info_func: def(UnsafePointer[UInt8, MutExternalOrigin], Int) -> None, recv_buf_ptr_func: def(Int) -> UnsafePointer[UInt8, MutExternalOrigin], //, n_warps: Int, shared_expert_offset: Int = 0](self, expert_id: Int, expert_start_pos: Int, tile_id: Int, tile_end: Int, extract_topk_info_functor: extract_topk_info_func, recv_buf_ptr_functor: recv_buf_ptr_func)