IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Mojo struct

MXFP4TokenFormat

struct MXFP4TokenFormat[fp4_dtype: DType, scales_dtype: DType, output_layout: TensorLayout, scales_layout: TensorLayout, //, _hid_dim: Int, _top_k: Int, _alignment: Int = 0]

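Fields

  • output_tokens (MXFP4TokenFormat[_hid_dim, _top_k, _alignment].TensorType): Output tensor that receives the FP4-quantized token data.
  • output_scales (MXFP4TokenFormat[_hid_dim, _top_k, _alignment].ScalesTensorType): Output tensor that receives the per-group scale factors.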

Implemented traits

AnyType, Copyable, DevicePassable, ImplicitlyCopyable, ImplicitlyDeletable, Movable, RegisterPassable, TokenFormat, TrivialRegisterPassable

comptime members

alignment

comptime alignment = _alignment if _alignment.__bool__() else get_device_alignment()

device_type

comptime device_type = MXFP4TokenFormat[_hid_dim, _top_k, _alignment]

dispatch_smem_size

comptime dispatch_smem_size = 0

dispatch_wait_tile_shape

comptime dispatch_wait_tile_shape = Tuple(128, 1)

group_size

comptime group_size = MXFP4_SF_VECTOR_SIZE
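MXFP4 (the OCP Microscaling format) stores each group of `group_size` values as 4-bit E2M1 elements that share one 8-bit power-of-two (E8M0) scale. The Python sketch below quantizes a single group to illustrate the idea. It assumes `MXFP4_SF_VECTOR_SIZE` is 32 (the OCP MX block size) and uses the spec's shared-exponent rule, but it is not the Mojo implementation: ties round toward the smaller magnitude rather than to even, and `quantize_group`/`dequantize_group` are hypothetical names.

```python
import math

# Magnitudes representable by FP4 E2M1; the list index is the 3-bit code.
E2M1_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
E2M1_EMAX = 2    # exponent of the largest E2M1 value (6.0 = 1.5 * 2**2)
GROUP_SIZE = 32  # assumption: MXFP4_SF_VECTOR_SIZE == 32


def quantize_group(values):
    """Quantize one group to (E8M0 scale byte, list of 4-bit E2M1 codes)."""
    assert len(values) == GROUP_SIZE
    amax = max(abs(v) for v in values)
    # Shared exponent per the OCP MX rule: floor(log2(amax)) - emax_elem.
    shared_exp = (math.floor(math.log2(amax)) - E2M1_EMAX) if amax > 0 else -127
    shared_exp = max(-127, min(127, shared_exp))  # keep within E8M0 range
    scale = 2.0 ** shared_exp
    codes = []
    for v in values:
        mag = min(abs(v) / scale, E2M1_VALUES[-1])  # saturate at 6.0
        # Nearest representable magnitude; min() keeps the smaller code on ties.
        idx = min(range(len(E2M1_VALUES)), key=lambda i: abs(E2M1_VALUES[i] - mag))
        sign = 0b1000 if v < 0 and idx != 0 else 0
        codes.append(sign | idx)
    return shared_exp + 127, codes  # E8M0 stores the exponent with bias 127


def dequantize_group(scale_byte, codes):
    """Reconstruct float values from an E8M0 scale byte and E2M1 codes."""
    scale = 2.0 ** (scale_byte - 127)
    return [(-1 if c & 0b1000 else 1) * E2M1_VALUES[c & 0b0111] * scale for c in codes]
```

For `[-3.0] + [0.0] * 31`, the shared exponent is floor(log2 3) - 2 = -1, so the scale byte is 126 and -3.0 becomes code 0b1111 (-6 * 2**-1).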

hid_dim

comptime hid_dim = _hid_dim

ScalesTensorType

comptime ScalesTensorType = TileTensor[scales_dtype, scales_layout, MutUntrackedOrigin]

TensorType

comptime TensorType = TileTensor[fp4_dtype, output_layout, MutUntrackedOrigin]

top_k

comptime top_k = _top_k

Methods

__init__

def __init__(output_tokens: TileTensor[fp4_dtype, output_layout, address_space=output_tokens.address_space, linear_idx_type=output_tokens.linear_idx_type, element_size=output_tokens.element_size], output_scales: TileTensor[scales_dtype, scales_layout, address_space=output_scales.address_space, linear_idx_type=output_scales.linear_idx_type, element_size=output_scales.element_size]) -> Self

get_type_name

static def get_type_name() -> String

Returns:

String

fp4_quant_size

static def fp4_quant_size() -> Int

Returns:

Int

scales_size

static def scales_size() -> Int

Returns:

Int

token_size

static def token_size() -> Int

Returns:

Int

scales_offset

static def scales_offset() -> Int

Returns:

Int
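The reference gives no formulas for these four sizes. The Python helpers below show one plausible message layout, assumed purely for illustration: the packed FP4 payload (two E2M1 values per byte) comes first, the one-byte E8M0 scales start at the next `alignment` boundary, and the whole token is padded to `alignment`. `token_layout`, `align_up`, and `resolve_alignment` are hypothetical stand-ins, and the values the Mojo struct actually returns may differ. `resolve_alignment` mirrors the `alignment` member, where a zero `_alignment` falls back to the device default (16 is only a placeholder here).

```python
def align_up(n, alignment):
    """Round n up to the next multiple of alignment (alignment > 0)."""
    return (n + alignment - 1) // alignment * alignment


def resolve_alignment(requested, device_default):
    # Mirrors `_alignment if _alignment.__bool__() else get_device_alignment()`.
    return requested if requested else device_default


def token_layout(hid_dim, group_size=32, alignment=16):
    """Hypothetical byte layout of one MXFP4 token message."""
    assert hid_dim % group_size == 0 and group_size % 2 == 0
    fp4_quant_size = hid_dim // 2          # two 4-bit values per byte
    scales_size = hid_dim // group_size    # one E8M0 byte per group
    scales_offset = align_up(fp4_quant_size, alignment)
    token_size = align_up(scales_offset + scales_size, alignment)
    return {
        "fp4_quant_size": fp4_quant_size,
        "scales_size": scales_size,
        "scales_offset": scales_offset,
        "token_size": token_size,
    }
```

Under these assumptions, `hid_dim = 7168` with 16-byte alignment gives 3584 payload bytes, 224 scale bytes at offset 3584, and a 3808-byte token.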

copy_token_to_send_buf

static def copy_token_to_send_buf[src_type: DType, block_size: Int, buf_addr_space: AddressSpace = AddressSpace.GENERIC](buf_p: UnsafePointer[UInt8, address_space=buf_addr_space], src_p: UnsafePointer[Scalar[src_type], address_space=src_p.address_space], input_scale: Float32)
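The reference does not describe how `copy_token_to_send_buf` uses `block_size` or `input_scale`. As a hedged illustration of the send-side serialization step only, the hypothetical Python function below writes one already-quantized token into a byte buffer. It packs pairs of 4-bit codes into bytes (even element in the low nibble, an assumed ordering) and places the scale bytes at `scales_offset`. Layout numbers are passed in rather than derived, and nothing here models the GPU thread block or `input_scale`.

```python
def pack_token(codes, scale_bytes, scales_offset, token_size):
    """Serialize 4-bit codes and E8M0 scale bytes into one message buffer."""
    assert len(codes) % 2 == 0
    assert scales_offset >= len(codes) // 2
    assert scales_offset + len(scale_bytes) <= token_size
    buf = bytearray(token_size)  # zero-initialized, so alignment gaps stay 0
    for i in range(0, len(codes), 2):
        # Assumed nibble order: even element in the low 4 bits.
        buf[i // 2] = (codes[i] & 0xF) | ((codes[i + 1] & 0xF) << 4)
    buf[scales_offset:scales_offset + len(scale_bytes)] = bytes(scale_bytes)
    return buf
```

For example, `pack_token([1, 2, 15, 0], [127], 2, 4)` yields the bytes `0x21, 0x0F, 127, 0`.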

copy_msg_to_output_tensor

def copy_msg_to_output_tensor[buf_addr_space: AddressSpace = AddressSpace.GENERIC](self, buf_p: UnsafePointer[UInt8, address_space=buf_addr_space], token_index: Int)
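On the receive side, `copy_msg_to_output_tensor` copies a message into the struct's `output_tokens` and `output_scales` at `token_index`. The Python sketch below is a hypothetical analogue under stated assumptions: the output tensors are modeled as lists of `bytearray` rows, the FP4 payload stays packed, packed bytes start at offset 0 of the message, and scale bytes start at `scales_offset`. `copy_msg_to_output` is an illustrative name, not the Mojo API.

```python
def copy_msg_to_output(buf, token_index, out_tokens, out_scales, scales_offset):
    """Copy one received message into row `token_index` of the outputs.

    out_tokens: list of bytearrays, one row of packed FP4 bytes per token.
    out_scales: list of bytearrays, one row of E8M0 scale bytes per token.
    """
    fp4_bytes = len(out_tokens[token_index])
    n_scales = len(out_scales[token_index])
    assert fp4_bytes <= scales_offset and scales_offset + n_scales <= len(buf)
    # Slice assignment overwrites the row in place without resizing it.
    out_tokens[token_index][:] = buf[:fp4_bytes]
    out_scales[token_index][:] = buf[scales_offset:scales_offset + n_scales]
```

Only the addressed row changes; rows for other token indices are left untouched.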