
Mojo function

tpool_patch_merger_kernel

tpool_patch_merger_kernel[dtype: DType, XLayout: TensorLayout, x_origin: ImmutOrigin, OutLayout: TensorLayout, out_origin: MutOrigin, GridThwLayout: TensorLayout, grid_thw_origin: ImmutOrigin, vec_width: Int, num_threads: Int](x_tile: TileTensor[dtype, XLayout, x_origin], out_tile: TileTensor[dtype, OutLayout, out_origin], grid_thws: TileTensor[DType.int64, GridThwLayout, grid_thw_origin], kH: Int, kW: Int, D: Int, n_vids: Int)

Temporal pooling patch merger kernel.

Averages x across the temporal dimension for each video, rearranging spatially according to the (kH, kW) merge kernel. Each video's output occupies H_i * W_i contiguous rows in the flat output tensor.
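A small CPU reference helps pin down what the kernel computes. The sketch below is plain Python, not Mojo. It assumes a layout that this page does not state. First, each video's T * H * W input rows are stored frame by frame, in row-major (h, w) order within each frame. Second, videos are concatenated in grid_thws order. Third, the output writes the kH x kW patches of each merged block next to each other, with blocks in row-major order. The function name tpool_patch_merger_reference is hypothetical.

```python
from typing import List, Sequence, Tuple

Matrix = List[List[float]]


def tpool_patch_merger_reference(
    x: Matrix,
    grid_thws: Sequence[Tuple[int, int, int]],
    kH: int,
    kW: int,
) -> Matrix:
    """Hypothetical CPU reference: mean over T, then (kH, kW) block regrouping.

    Assumes input rows are ordered (t, h, w) row-major per video and output rows
    are ordered (h_block, w_block, kh, kw); neither is confirmed by the docs.
    """
    D = len(x[0]) if x else 0
    out: Matrix = []
    in_offset = 0  # first input row of the current video
    for T, H, W in grid_thws:
        assert H % kH == 0 and W % kW == 0, "assumed: H, W are multiples of kH, kW"
        for hb in range(H // kH):          # merged-block row
            for wb in range(W // kW):      # merged-block column
                for kh in range(kH):       # patch row inside the block
                    for kw in range(kW):   # patch column inside the block
                        h = hb * kH + kh
                        w = wb * kW + kw
                        acc = [0.0] * D
                        for t in range(T):  # sum the same (h, w) patch over frames
                            row = x[in_offset + (t * H + h) * W + w]
                            for d in range(D):
                                acc[d] += row[d]
                        out.append([v / T for v in acc])  # temporal mean
        in_offset += T * H * W
    return out
```

For a single video with grid (T, H, W) = (2, 2, 4), kernel (2, 2), D = 1, and x[i] = [i], this sketch yields 8 = H * W output rows, [[4.0], [5.0], [8.0], [9.0], [6.0], [7.0], [10.0], [11.0]]. Each value is the mean of the two frames' copies of a patch, and the rows are grouped 2 x 2 block by block.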

Grid mapping:

  • block_idx.z: video index.
  • block_idx.y: patch index within the video (the grid's y extent, max_pat, is an upper bound on the patch count of any video).
  • block_idx.x: tile index along D.
  • thread_idx.x: lane within the D tile.
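You can simulate this grid mapping on the CPU to see how each (block_idx.z, block_idx.y, block_idx.x, thread_idx.x) tuple picks its output elements. The Python sketch below is hypothetical and is not the kernel's launch code. It makes several assumptions:

  • A D tile spans num_threads * vec_width elements, and each lane handles vec_width consecutive elements.
  • Per-video input and output row offsets are prefix sums over grid_thws.
  • Blocks with block_idx.y >= H_i * W_i do no work.
  • Input rows are ordered (t, h, w) row-major per video, and output rows are ordered (h_block, w_block, kh, kw).

launch_dims and simulate_grid are illustrative names.

```python
from typing import List, Sequence, Tuple


def launch_dims(grid_thws: Sequence[Tuple[int, int, int]], D: int,
                num_threads: int, vec_width: int):
    """Hypothetical launch shape: ((grid_x, grid_y, grid_z), block_x)."""
    tile = num_threads * vec_width               # assumed D elements per block
    grid_x = (D + tile - 1) // tile              # ceil-div covers a partial last tile
    max_pat = max(H * W for _, H, W in grid_thws)  # upper bound for block_idx.y
    return (grid_x, max_pat, len(grid_thws)), num_threads


def simulate_grid(x: List[List[float]], grid_thws: Sequence[Tuple[int, int, int]],
                  kH: int, kW: int, num_threads: int, vec_width: int):
    D = len(x[0])
    (gx, gy, gz), bx = launch_dims(grid_thws, D, num_threads, vec_width)
    out = [[0.0] * D for _ in range(sum(H * W for _, H, W in grid_thws))]
    for vid in range(gz):                                        # block_idx.z
        T, H, W = grid_thws[vid]
        in_off = sum(a * b * c for a, b, c in grid_thws[:vid])   # assumed prefix sum
        out_off = sum(b * c for _, b, c in grid_thws[:vid])
        for p in range(gy):                                      # block_idx.y
            if p >= H * W:
                continue  # padding block: this video has fewer than max_pat patches
            hb, rem = divmod(p, (W // kW) * kH * kW)
            wb, r = divmod(rem, kH * kW)
            kh, kw = divmod(r, kW)
            h, w = hb * kH + kh, wb * kW + kw
            for tile in range(gx):                               # block_idx.x
                for lane in range(bx):                           # thread_idx.x
                    d0 = (tile * bx + lane) * vec_width
                    for d in range(d0, min(d0 + vec_width, D)):  # guard D tail
                        s = 0.0
                        for t in range(T):
                            s += x[in_off + (t * H + h) * W + w][d]
                        out[out_off + p][d] = s / T
    return out
```

In this sketch, a second video smaller than max_pat only leaves some y-blocks idle. The output still packs each video's H_i * W_i rows contiguously.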

Args:

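  • x_tile (TileTensor): Input tensor [n_tokens, D], where n_tokens is the sum of T_i * H_i * W_i over all videos.
  • out_tile (TileTensor): Contiguous output tensor [total_output_patches, D], where total_output_patches is the sum of H_i * W_i over all videos.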
  • grid_thws (TileTensor): Grid dimensions tensor [n_vids, 3] with (T, H, W) per video.
  • kH (Int): Merge kernel height.
  • kW (Int): Merge kernel width.
  • D (Int): Hidden dimension.
  • n_vids (Int): Number of videos.
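grid_thws ties these argument shapes together. The hypothetical Python helper below derives the x_tile and out_tile shapes a caller would need to allocate. It assumes each H_i and W_i must be a multiple of kH and kW, which this page does not state. Here n_vids is simply the number of rows in grid_thws.

```python
from typing import Sequence, Tuple


def expected_shapes(grid_thws: Sequence[Tuple[int, int, int]], kH: int, kW: int,
                    D: int) -> Tuple[Tuple[int, int], Tuple[int, int]]:
    """Hypothetical helper: (x_tile shape, out_tile shape) implied by grid_thws."""
    n_tokens = 0
    total_output_patches = 0
    for T, H, W in grid_thws:
        # Assumed precondition: positive T and spatial sizes divisible by the kernel.
        if T <= 0 or H % kH or W % kW:
            raise ValueError(f"bad grid (T={T}, H={H}, W={W}) for kernel ({kH}, {kW})")
        n_tokens += T * H * W            # every frame contributes H * W patches
        total_output_patches += H * W    # temporal pooling keeps one frame's worth
    return (n_tokens, D), (total_output_patches, D)
```

For example, grid_thws = [(2, 2, 4), (1, 2, 2)] with a (2, 2) kernel and D = 8 gives x_tile [20, 8] and out_tile [12, 8], with n_vids = 2.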
