IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Mojo struct

NVBlockScaledTokenFormat

struct NVBlockScaledTokenFormat[quant_dtype: DType, scales_dtype: DType, output_layout: TensorLayout, scales_offset_layout: TensorLayout, //, _hid_dim: Int, _top_k: Int, _alignment: Int = Int(0)]

Fields

  • scales_tma_op (NVBlockScaledTokenFormat[_hid_dim, _top_k, _alignment].ScalesTMATensorTileType)
  • output_tokens (NVBlockScaledTokenFormat[_hid_dim, _top_k, _alignment].TensorType)
  • output_scales_offset (NVBlockScaledTokenFormat[_hid_dim, _top_k, _alignment].ScalesOffsetTensorType)

Implemented traits

AnyType, Copyable, DevicePassable, ImplicitlyCopyable, ImplicitlyDeletable, Movable, TokenFormat

comptime members

alignment

comptime alignment = _alignment if _alignment != 0 else get_device_alignment()

device_type

comptime device_type = NVBlockScaledTokenFormat[_hid_dim, _top_k, _alignment]

dispatch_smem_size

comptime dispatch_smem_size = 32 * (align_up(Coord[4, DType.int64](Index(1, _hid_dim // group_size // 4 // dispatch_wait_tile_shape[1], 1, Tuple(32, 4)[1] * 4)).product(), 128) * size_of[scales_dtype]() + align_up(quant_size() // dispatch_wait_tile_shape[1], 16)) + align_up(32 * size_of[SharedMemBarrier](), 8)
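To make the expression above concrete, this Python sketch re-evaluates it for given inputs. It is an illustrative re-derivation, not library code. It assumes that the printed three-operand `mul` is a plain product, that `size_of[SharedMemBarrier]()` is 8 bytes, and it treats `group_size` and `quant_size` as caller-supplied inputs; the factor of 32 and the (128, 2) wait-tile shape are taken directly from the expression.

```python
import math


def align_up(value: int, alignment: int) -> int:
    """Round value up to the next multiple of alignment."""
    return (value + alignment - 1) // alignment * alignment


def dispatch_smem_size(hid_dim: int, group_size: int, quant_size: int,
                       scales_dtype_bytes: int, barrier_bytes: int = 8) -> int:
    """Re-evaluate the printed dispatch_smem_size expression (sketch).

    barrier_bytes=8 is an ASSUMED size_of[SharedMemBarrier]().
    """
    wait_tile_cols = 2  # dispatch_wait_tile_shape[1], from (128, 2)
    # Four-element shape (1, hid/group/4/2, 1, 4 * 4) taken from the expression.
    tile = (1, hid_dim // group_size // 4 // wait_tile_cols, 1, 4 * 4)
    scales_bytes = align_up(math.prod(tile), 128) * scales_dtype_bytes
    quant_bytes = align_up(quant_size // wait_tile_cols, 16)
    # Both buffers and the barrier region are scaled by 32, as printed.
    return 32 * (scales_bytes + quant_bytes) + align_up(32 * barrier_bytes, 8)


# Hypothetical NVFP4-like inputs: 16-element groups, two 4-bit values per
# byte, and 1-byte float8_e4m3fn scales.
print(dispatch_smem_size(hid_dim=7168, group_size=16, quant_size=7168 // 2,
                         scales_dtype_bytes=1))  # prints 86272
```

Changing only `group_size` or `quant_size` shows which term dominates: for these inputs the quant buffer (32 × 1792 bytes) is twice the size of the scales buffer (32 × 896 bytes).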

dispatch_wait_tile_shape

comptime dispatch_wait_tile_shape = Tuple(Int(128), Int(2))

group_size

comptime group_size = NVBlockScaledTokenFormat.get_group_size()

hid_dim

comptime hid_dim = _hid_dim

is_mxfp4

comptime is_mxfp4 = (scales_dtype == DType.float8_e8m0fnu) and (quant_dtype == DType.uint8)

is_mxfp8

comptime is_mxfp8 = (scales_dtype == DType.float8_e8m0fnu) and (quant_dtype == DType.float8_e4m3fn)

is_nvfp4

comptime is_nvfp4 = (scales_dtype == DType.float8_e4m3fn) and (quant_dtype == DType.uint8)
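Each of these flags is true only when both the scales dtype and the quant dtype match. The compiler prints a short-circuit `A and B` as `B if A else A`, which is why the raw definitions take that shape. A minimal Python sketch of the three checks follows. It models dtypes as plain strings, which is an assumption for illustration, since the real members compare `DType` values at compile time.

```python
def block_scaled_flags(quant_dtype: str, scales_dtype: str) -> dict:
    """Mirror is_mxfp4 / is_mxfp8 / is_nvfp4 with string dtypes (sketch)."""
    e8m0_scales = scales_dtype == "float8_e8m0fnu"
    e4m3_scales = scales_dtype == "float8_e4m3fn"
    return {
        # uint8 quants presumably hold packed 4-bit values.
        "is_mxfp4": e8m0_scales and quant_dtype == "uint8",
        "is_mxfp8": e8m0_scales and quant_dtype == "float8_e4m3fn",
        "is_nvfp4": e4m3_scales and quant_dtype == "uint8",
    }


# The printed "B if A else A" form agrees with "A and B" for every input.
for a in (False, True):
    for b in (False, True):
        assert (b if a else a) == (a and b)

print(block_scaled_flags("uint8", "float8_e4m3fn"))
# {'is_mxfp4': False, 'is_mxfp8': False, 'is_nvfp4': True}
```

Because the scales dtype is checked first, MXFP4 and NVFP4 can share the same uint8 quant dtype and still be told apart.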

ScalesOffsetTensorType

comptime ScalesOffsetTensorType = TileTensor[DType.uint32, scales_offset_layout, MutUntrackedOrigin]

ScalesTMATensorTileType

comptime ScalesTMATensorTileType = TMATensorTile[scales_dtype, Int(4), NVBlockScaledTokenFormat[_hid_dim, _top_k, _alignment].tma_tile_shape, _default_desc_shape[Int(4), scales_dtype, NVBlockScaledTokenFormat[_hid_dim, _top_k, _alignment].tma_tile_shape, TensorMapSwizzle.SWIZZLE_NONE]()]

TensorType

comptime TensorType = TileTensor[quant_dtype, output_layout, MutUntrackedOrigin]

tma_tile_shape

comptime tma_tile_shape = Index[Int, Int, Int, Int](1, _hid_dim // group_size // 4 // dispatch_wait_tile_shape[1], 1, 4 * SF_ATOM_M[1])
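The tile shape can be evaluated by hand. This Python sketch assumes SF_ATOM_M == (32, 4), because the printed dispatch_smem_size expression substitutes the literal tuple (32, 4) at the same position. The group sizes used in the usage line (16 and 32) are also assumptions, chosen as typical block sizes for NVFP4 and MX formats.

```python
SF_ATOM_M = (32, 4)  # ASSUMED value; only element [1] enters the shape
DISPATCH_WAIT_TILE_SHAPE = (128, 2)  # from dispatch_wait_tile_shape


def tma_tile_shape(hid_dim: int, group_size: int) -> tuple:
    """Evaluate the comptime tma_tile_shape expression (sketch)."""
    scales_per_token = hid_dim // group_size  # one scale per group
    return (
        1,
        scales_per_token // 4 // DISPATCH_WAIT_TILE_SHAPE[1],
        1,
        4 * SF_ATOM_M[1],
    )


print(tma_tile_shape(hid_dim=7168, group_size=16))  # (1, 56, 1, 16)
print(tma_tile_shape(hid_dim=7168, group_size=32))  # (1, 28, 1, 16)
```

Only the second dimension depends on the struct parameters. The last dimension stays fixed at 16 for any hid_dim under this assumption.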

top_k

comptime top_k = _top_k

Methods

__init__

def __init__(out self, output_tokens: TileTensor[quant_dtype, output_layout, Storage=output_tokens.Storage, address_space=output_tokens.address_space, linear_idx_type=output_tokens.linear_idx_type, element_size=output_tokens.element_size], output_scales: TileTensor[scales_dtype, Storage=output_scales.Storage, address_space=output_scales.address_space, linear_idx_type=output_scales.linear_idx_type, element_size=output_scales.element_size], output_scales_offset: TileTensor[DType.uint32, scales_offset_layout, Storage=output_scales_offset.Storage, address_space=output_scales_offset.address_space, linear_idx_type=output_scales_offset.linear_idx_type, element_size=output_scales_offset.element_size], ctx: DeviceContext)

get_group_size

static def get_group_size() -> Int

Returns:

Int

get_type_name

static def get_type_name() -> String

Returns:

String

quant_size

static def quant_size() -> Int

Returns:

Int

scales_size

static def scales_size() -> Int

Returns:

Int

token_size

static def token_size() -> Int

Returns:

Int

scales_offset

static def scales_offset() -> Int

Returns:

Int

pad_expert_offsets

def pad_expert_offsets[n_groups: Int](self, row_offsets: UnsafePointer[UInt32, address_space=row_offsets.address_space])

The Mojo NVFP4 grouped matmul does not require each group's FP4 quants to be padded. It does, however, require each group's scales to start on a SF_MN_GROUP_SIZE (128-row) boundary. This function updates the output_scales_offset tensor to satisfy that requirement.

For example, if the row_offsets tensor is [0, 100, 300, 400], this function updates the output_scales_offset tensor to [0, 1, 1]. For group i, the index of its first scales block is row_offsets[i] // SF_MN_GROUP_SIZE + output_scales_offset[i]. Groups 0, 1, and 2 have 100, 200, and 100 tokens, so they need 1, 2, and 1 scales blocks, respectively. Group 1's scales blocks therefore start at 100 // 128 + output_scales_offset[1] = 0 + 1 = 1, and group 2's start at 300 // 128 + output_scales_offset[2] = 2 + 1 = 3.
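The offset rule in this example can be reproduced on the host. The following Python sketch is a hypothetical reference, not the device implementation. It assumes that each group's scales are padded to whole 128-row blocks laid out back to back, so output_scales_offset[i] is the padded block start minus row_offsets[i] // SF_MN_GROUP_SIZE. The helper names are illustrative.

```python
SF_MN_GROUP_SIZE = 128


def ceil_div(a: int, b: int) -> int:
    """Integer ceiling division for non-negative a and positive b."""
    return -(-a // b)


def scales_offsets(row_offsets: list) -> list:
    """Per-group correction so that group i's scales start on a block boundary."""
    offsets = []
    padded_block = 0  # first scales block of the current group after padding
    for start, end in zip(row_offsets, row_offsets[1:]):
        # First block = start // 128 + offset, so offset = padded - unpadded.
        offsets.append(padded_block - start // SF_MN_GROUP_SIZE)
        padded_block += ceil_div(end - start, SF_MN_GROUP_SIZE)
    return offsets


print(scales_offsets([0, 100, 300, 400]))  # [0, 1, 1]
```

With these offsets, the first scales blocks of the three groups are 0 // 128 + 0 = 0, 100 // 128 + 1 = 1, and 300 // 128 + 1 = 3, which matches the worked example.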

copy_token_to_send_buf

static def copy_token_to_send_buf[src_type: DType, block_size: Int, buf_addr_space: AddressSpace = AddressSpace.GENERIC](buf_p: UnsafePointer[UInt8, address_space=buf_addr_space], src_p: UnsafePointer[Scalar[src_type], address_space=src_p.address_space], input_scale: Float32)

copy_msg_to_output_tensor

def copy_msg_to_output_tensor[buf_addr_space: AddressSpace = AddressSpace.GENERIC](self, buf_p: UnsafePointer[UInt8, address_space=buf_addr_space], token_index: Int, expert_slot: Int = Int(0), expert_start: Int = Int(0))

The NVFP4 format uses the tile-based copy directly (see copy_msg_tile_to_output_tensor).

init_smem_resources

def init_smem_resources(self)

copy_msg_tile_to_output_tensor

def copy_msg_tile_to_output_tensor[extract_topk_info_func: def(UnsafePointer[UInt8, MutUntrackedOrigin], Int) -> None, recv_buf_ptr_func: def(Int) -> UnsafePointer[UInt8, MutUntrackedOrigin], //, n_warps: Int, shared_expert_offset: Int = Int(0)](self, expert_id: Int, expert_start_pos: Int, tile_id: Int, tile_end: Int, extract_topk_info_functor: extract_topk_info_func, recv_buf_ptr_functor: recv_buf_ptr_func)