Mojo function
tpool_patch_merger_kernel
tpool_patch_merger_kernel[dtype: DType, XLayout: TensorLayout, x_origin: ImmutOrigin, OutLayout: TensorLayout, out_origin: MutOrigin, GridThwLayout: TensorLayout, grid_thw_origin: ImmutOrigin, vec_width: Int, num_threads: Int](x_tile: TileTensor[dtype, XLayout, x_origin], out_tile: TileTensor[dtype, OutLayout, out_origin], grid_thws: TileTensor[DType.int64, GridThwLayout, grid_thw_origin], kH: Int, kW: Int, D: Int, n_vids: Int)
Temporal pooling patch merger kernel.
Averages x across the temporal dimension for each video, rearranging spatially according to the (kH, kW) merge kernel. Each video's output occupies H_i * W_i contiguous rows in the flat output tensor.
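The description above can be modeled on the host with a small pure-Python reference. This is a sketch, not the kernel's implementation. It assumes that each video's tokens are stored frame by frame in row-major (t, h, w) order, that H and W are divisible by kH and kW, and that output rows are ordered by merge block first and by position inside the (kH, kW) block second. The function name `tpool_patch_merger_reference` is hypothetical.

```python
def tpool_patch_merger_reference(x, grid_thws, kH, kW):
    """Host-side model of temporal pooling plus spatial regrouping.

    x         : list of n_tokens rows, each a list of D floats.
    grid_thws : list of (T, H, W) tuples, one per video.
    Returns a flat list with H_i * W_i rows per video.
    ASSUMPTION: tokens are laid out as (t, h, w) row-major per video,
    and output rows are ordered (block_h, block_w, in_block_h, in_block_w).
    """
    D = len(x[0]) if x else 0
    out = []
    offset = 0  # first input row of the current video
    for T, H, W in grid_thws:
        for bh in range(H // kH):          # merge block row
            for bw in range(W // kW):      # merge block column
                for ih in range(kH):       # row inside the block
                    for iw in range(kW):   # column inside the block
                        h = bh * kH + ih
                        w = bw * kW + iw
                        acc = [0.0] * D
                        for t in range(T):
                            row = x[offset + t * H * W + h * W + w]
                            for d in range(D):
                                acc[d] += row[d]
                        # Mean over the temporal dimension.
                        out.append([a / T for a in acc])
        offset += T * H * W
    return out
```

A reference like this is mainly useful for checking kernel output on small inputs, where the row ordering can be verified by hand.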
Grid mapping:
- block_idx.z = video index
- block_idx.y = patch index within the video (max_pat upper bound)
- block_idx.x = tile index along D
- thread_idx.x = lane within D tile
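The grid mapping can be illustrated with a host-side sketch of how launch dimensions and per-thread work might be derived. Several details here are assumptions rather than documented behavior: that max_pat is the largest H_i * W_i over all videos, that blocks whose patch index exceeds their video's patch count do no work, and that one D tile covers num_threads * vec_width columns. All function names are hypothetical.

```python
def launch_grid(grid_thws, D, vec_width, num_threads):
    """Return (grid_x, grid_y, grid_z) for the mapping described above.

    ASSUMPTION: one block covers num_threads * vec_width columns of D,
    and grid_y is the largest per-video patch count (max_pat).
    """
    tile = vec_width * num_threads
    max_pat = max(h * w for _, h, w in grid_thws)
    return ((D + tile - 1) // tile, max_pat, len(grid_thws))


def block_is_active(grid_thws, video, patch):
    """A block does useful work only if its patch index exists in its video.

    ASSUMPTION: out-of-range blocks exit early, since grid_y is only
    an upper bound shared by all videos.
    """
    _, h, w = grid_thws[video]
    return patch < h * w


def columns_for_thread(block_x, thread_x, D, vec_width, num_threads):
    """Columns of D handled by one thread (empty if past the end of D)."""
    start = (block_x * num_threads + thread_x) * vec_width
    return range(start, min(start + vec_width, D))
```

Under these assumptions, videos with fewer patches than max_pat leave some blocks idle, which trades a little launch overhead for a single uniform grid across all videos.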
Args:
- x_tile (TileTensor): Input tensor [n_tokens, D].
- out_tile (TileTensor): Contiguous output tensor [total_output_patches, D].
- grid_thws (TileTensor): Grid dimensions tensor [n_vids, 3] with (T, H, W) per video.
- kH (Int): Merge kernel height.
- kW (Int): Merge kernel width.
- D (Int): Hidden dimension.
- n_vids (Int): Number of videos.
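The tensor shapes in the argument list can be cross-checked against grid_thws before launching. The sketch below is a hypothetical helper, not part of the API. It relies on the documented fact that each video occupies H_i * W_i output rows, and on the assumption that each video contributes T_i * H_i * W_i input tokens.

```python
def expected_row_counts(grid_thws):
    """Return (n_tokens, total_output_patches, out_row_offsets).

    grid_thws : list of (T, H, W) tuples, one per video.
    ASSUMPTION: n_tokens is the sum of T * H * W over all videos.
    out_row_offsets[i] is the first output row of video i, obtained
    as a running sum of H * W.
    """
    n_tokens = 0
    total_output_patches = 0
    out_row_offsets = []
    for T, H, W in grid_thws:
        out_row_offsets.append(total_output_patches)
        n_tokens += T * H * W
        total_output_patches += H * W
    return n_tokens, total_output_patches, out_row_offsets
```

For example, two videos with grids (2, 4, 4) and (1, 2, 6) would need 44 input rows and 28 output rows under these assumptions, with the second video starting at output row 16.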