
Mojo function

flatten_leading

flatten_leading[dtype: DType, layout: TensorLayout, //](tensor: TileTensor[dtype, layout, tensor.origin, address_space=tensor.address_space, linear_idx_type=tensor.linear_idx_type, element_size=tensor.element_size]) -> TileTensor[dtype, Layout[RuntimeInt[DType.int64], layout._shape_types[(add layout.rank._mlir_value, -1)], #kgen.variadic.concat(RuntimeInt[layout._shape_types[(add layout.rank._mlir_value, -1)].DTYPE if layout._shape_types[(add layout.rank._mlir_value, -1)].is_static_value.__bool__().__invert__()._mlir_value else DType.int] if layout._shape_types[(add layout.rank._mlir_value, -1)].is_static_value.__bool__().__invert__() if layout._shape_types[(add layout.rank._mlir_value, -1)].is_static_value.__bool__().__invert__()._mlir_value else False else ComptimeInt[(layout._shape_types[(add layout.rank._mlir_value, -1)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1])], tensor.origin, address_space=tensor.address_space, linear_idx_type=tensor.linear_idx_type, element_size=tensor.element_size]

Merge the first two dimensions of a rank-3 TileTensor: (A, B, C) -> (A*B, C).

Returns a new TileTensor that shares the input's pointer, with row-major strides computed from the merged shape. This is a zero-cost operation: no data is copied.
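The shape and stride arithmetic described above can be modeled in a few lines of plain Python. This is only a sketch of the index math, not the Mojo implementation: the helper names (`row_major_strides`, `flatten_leading_model`, `offset`) are hypothetical, and it assumes the input tensor is itself contiguous and row-major, which is the case in which recomputed row-major strides address the same elements.

```python
def row_major_strides(shape):
    """Return row-major (C-order) strides, in elements, for a shape."""
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return tuple(strides)


def flatten_leading_model(shape):
    """Hypothetical model of the result layout: (A, B, C) -> (A*B, C)."""
    if len(shape) != 3:
        raise ValueError("expected a rank-3 shape")
    a, b, c = shape
    merged = (a * b, c)
    return merged, row_major_strides(merged)


def offset(index, strides):
    """Linear element offset of a multi-dimensional index."""
    return sum(i * s for i, s in zip(index, strides))


shape = (4, 3, 5)  # (A, B, C), assumed contiguous row-major
src_strides = row_major_strides(shape)
merged_shape, merged_strides = flatten_leading_model(shape)

# Element (a, b, c) of the source and element (a * B + b, c) of the merged
# view resolve to the same linear offset, so the pointer can be shared.
same = all(
    offset((a, b, c), src_strides)
    == offset((a * shape[1] + b, c), merged_strides)
    for a in range(shape[0])
    for b in range(shape[1])
    for c in range(shape[2])
)
```

Because only the shape and strides change, the merged view needs no data movement under this contiguity assumption.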

Common use case: converting 3D batched tensors (num_experts, N, K) to 2D (num_experts*N, K) for TMA descriptor creation in MoE kernels.
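As a rough illustration of the MoE case, the sketch below shows where each expert's rows land in the flattened (num_experts*N, K) view. It is plain Python index arithmetic, not Mojo or TMA API code; the helper name `expert_row_range` and the sizes used are hypothetical.

```python
def expert_row_range(expert, n):
    """Rows [start, stop) that one expert occupies in the flattened 2D view."""
    start = expert * n
    return start, start + n


num_experts, N, K = 8, 128, 64      # illustrative sizes only
flat_shape = (num_experts * N, K)   # shape after merging the leading dims

# Expert 3 owns a contiguous block of N rows in the 2D view.
start, stop = expert_row_range(3, N)
```

Each expert's (N, K) slab stays contiguous in the 2D view, so a single 2D descriptor can cover all experts and a kernel can select an expert by row offset.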

Parameters:

  • dtype (DType): Element type (inferred from tensor).
  • layout (TensorLayout): Layout type (inferred from tensor).

Args:

  • tensor (TileTensor): A rank-3 TileTensor.

Returns:

TileTensor: A rank-2 TileTensor where dim[0] = old dim[0] * old dim[1] and dim[1] = old dim[2].
