IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Mojo function

tpool_patch_merger_kernel

def tpool_patch_merger_kernel[dtype: DType, XLayout: TensorLayout, x_origin: ImmutOrigin, OutLayout: TensorLayout, out_origin: MutOrigin, GridThwLayout: TensorLayout, grid_thw_origin: ImmutOrigin, vec_width: Int, num_threads: Int](x_tile: TileTensor[dtype, XLayout, x_origin], out_tile: TileTensor[dtype, OutLayout, out_origin], grid_thws: TileTensor[DType.int64, GridThwLayout, grid_thw_origin], kH: Int, kW: Int, D: Int, n_vids: Int)

Temporal pooling patch merger kernel.

Averages x across the temporal dimension for each video, rearranging spatially according to the (kH, kW) merge kernel. Each video's output occupies H_i * W_i contiguous rows in the flat output tensor.
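To make the averaging and regrouping concrete, here is a plain-Python CPU sketch of the computation this paragraph describes. It is not the library's kernel, and it rests on several assumptions:

- x is a flat list of rows of length D that holds each video's T_i * H_i * W_i patches in (t, h, w) row-major order, with videos stored back to back.
- H_i and W_i are divisible by kH and kW.
- Within a video, output rows are ordered first by merged window (row-major over the (H_i / kH) x (W_i / kW) windows) and then by position inside the kH x kW window.

The function name tpool_patch_merger_reference is hypothetical.

```python
from typing import List, Sequence, Tuple

Row = List[float]


def tpool_patch_merger_reference(
    x: Sequence[Row],
    grid_thws: Sequence[Tuple[int, int, int]],
    kH: int,
    kW: int,
) -> List[Row]:
    # Hypothetical CPU reference, NOT the library kernel. Assumed layout:
    # x holds each video's T*H*W rows in (t, h, w) row-major order, videos
    # back to back; output rows are grouped by kH x kW window.
    D = len(x[0]) if x else 0
    out: List[Row] = []
    in_offset = 0  # first input row of the current video
    for t, h, w in grid_thws:
        # Assumption: spatial dims are divisible by the merge kernel.
        assert h % kH == 0 and w % kW == 0
        frame = h * w  # patches per temporal frame
        for wh in range(h // kH):  # window row
            for ww in range(w // kW):  # window column
                for i in range(kH):  # row inside the window
                    for j in range(kW):  # column inside the window
                        src = (wh * kH + i) * w + (ww * kW + j)
                        acc = [0.0] * D
                        for f in range(t):  # sum this patch over all frames
                            row = x[in_offset + f * frame + src]
                            for d in range(D):
                                acc[d] += row[d]
                        out.append([v / t for v in acc])  # temporal mean
        in_offset += t * frame
    return out
```

Each video contributes h * w output rows, which matches the H_i * W_i contiguous rows described above. For example, with T=2, H=2, W=4, and a 2 x 2 kernel, the first four output rows come from spatial patches 0, 1, 4, and 5 (the left window). The next four come from patches 2, 3, 6, and 7.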

Grid mapping:

- block_idx.z = video index
- block_idx.y = patch index within the video (max_pat is the upper bound)
- block_idx.x = tile index along D
- thread_idx.x = lane within the D tile
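The grid mapping can be emulated on the CPU to show how the four indices select one slice of work. This is a sketch under several assumptions that the page does not confirm:

- Each D tile spans num_threads * vec_width elements, and each lane handles vec_width consecutive elements.
- max_pat is the largest H_i * W_i across videos, so blocks with block_idx.y >= H_i * W_i do nothing.
- A video's first output row is the prefix sum of the preceding H_j * W_j.
- Output rows are grouped by kH x kW window.

The helper names launch_grid and emulate are hypothetical.

```python
from typing import List, Optional, Sequence, Tuple


def launch_grid(grid_thws, D, num_threads, vec_width):
    # Hypothetical launch shape: ((grid_x, grid_y, grid_z), block_x).
    tile = num_threads * vec_width                  # assumed D elements per tile
    grid_x = (D + tile - 1) // tile                 # ceil(D / tile) tiles along D
    max_pat = max(h * w for _, h, w in grid_thws)   # assumed per-video upper bound
    return (grid_x, max_pat, len(grid_thws)), num_threads


def emulate(x: Sequence[List[float]], grid_thws: Sequence[Tuple[int, int, int]],
            kH: int, kW: int, D: int, num_threads: int, vec_width: int):
    (gx, gy, gz), bx = launch_grid(grid_thws, D, num_threads, vec_width)
    # Prefix sums: first input row and first output row of each video.
    in_start, out_start, i_acc, o_acc = [], [], 0, 0
    for t, h, w in grid_thws:
        in_start.append(i_acc)
        out_start.append(o_acc)
        i_acc += t * h * w
        o_acc += h * w
    out: List[List[Optional[float]]] = [[None] * D for _ in range(o_acc)]
    for z in range(gz):                  # block_idx.z: video
        t, h, w = grid_thws[z]
        for y in range(gy):              # block_idx.y: output patch in video
            if y >= h * w:
                continue                 # beyond this video's patches: idle
            # Assumed order: y -> (window, position inside kH x kW window).
            win, pos = divmod(y, kH * kW)
            wh, ww = divmod(win, w // kW)
            i, j = divmod(pos, kW)
            src = (wh * kH + i) * w + (ww * kW + j)
            for xt in range(gx):         # block_idx.x: tile along D
                for lane in range(bx):   # thread_idx.x: lane in the tile
                    base = (xt * bx + lane) * vec_width
                    for d in range(base, min(base + vec_width, D)):
                        s = 0.0
                        for f in range(t):  # temporal sum, then mean
                            s += x[in_start[z] + f * h * w + src][d]
                        out[out_start[z] + y][d] = s / t
    return out
```

Because the y extent is the maximum patch count, a video with fewer patches leaves its upper y blocks idle. The last D tile is clipped to D, so every output element is written exactly once.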

Args: