
Mojo function

tpool_patch_merger_kernel

tpool_patch_merger_kernel[dtype: DType, XLayout: TensorLayout, x_origin: ImmutOrigin, OutLayout: TensorLayout, out_origin: MutOrigin, GridThwLayout: TensorLayout, grid_thw_origin: ImmutOrigin, vec_width: Int, num_threads: Int](x_tile: TileTensor[dtype, XLayout, x_origin], out_tile: TileTensor[dtype, OutLayout, out_origin], grid_thws: TileTensor[DType.int64, GridThwLayout, grid_thw_origin], kH: Int, kW: Int, D: Int, n_vids: Int)

Temporal pooling patch merger kernel.

Averages x over the temporal dimension of each video and rearranges the resulting spatial patches according to the (kH, kW) merge kernel. Each video's output occupies H_i * W_i contiguous rows in the flat output tensor.
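The following is a minimal pure-Python sketch of the semantics described above. It is not the kernel's implementation, and it rests on several assumptions that the signature does not confirm:

- Each row of `grid_thws` is `(T_i, H_i, W_i)`.
- Input rows for a video are stored frame-major, then row-major over `(h, w)`, with `D` values per row.
- `H_i` and `W_i` are divisible by `kH` and `kW`.
- Output rows are ordered as `(H_i/kH, W_i/kW, kH, kW)`, so each merge block is contiguous.

The function name `tpool_patch_merger_ref` is hypothetical.

```python
def tpool_patch_merger_ref(x, grid_thws, kH, kW):
    """Hypothetical reference: temporal mean + (kH, kW) spatial reordering.

    x: list of rows (each a list of D floats), all videos concatenated.
    grid_thws: list of (t, h, w) per video (assumed layout).
    Returns a flat list of sum(h * w) output rows.
    """
    out = []
    base = 0  # first input row of the current video
    for t, h, w in grid_thws:
        assert h % kH == 0 and w % kW == 0, "assumes H, W divisible by kH, kW"
        new_w = w // kW
        video_out = [None] * (h * w)
        for hi in range(h):
            for wi in range(w):
                nh, kh = divmod(hi, kH)
                nw, kw = divmod(wi, kW)
                # Assumed output order: (new_h, new_w, kH, kW), row-major.
                dst = ((nh * new_w + nw) * kH + kh) * kW + kw
                # Gather spatial position (hi, wi) from every frame, then average.
                rows = [x[base + (ti * h + hi) * w + wi] for ti in range(t)]
                video_out[dst] = [sum(col) / t for col in zip(*rows)]
        out.extend(video_out)  # this video's H*W rows are contiguous
        base += t * h * w
    return out
```

For example, with `grid_thws = [(2, 2, 4), (1, 2, 2)]` and `kH = kW = 2`, the first video produces 8 rows, and each row is the mean of two frames. The second video has a single frame, so its rows are passed through. In both cases each 2x2 merge block ends up in consecutive rows.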

Grid mapping:
  block_idx.z = video index
  block_idx.y = patch index within the video (the grid's y extent, max_pat, is an upper bound on the patch count of any video)
  block_idx.x = tile index along D
  thread_idx.x = lane within the D tile
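To show how the grid mapping could cover the output, here is a hedged CPU simulation of the launch. It uses nested loops in place of `block_idx.z`, `block_idx.y`, `block_idx.x` and `thread_idx.x`. It assumes the following:

- Each block handles one output patch and `vec_width * num_threads` contiguous elements of D.
- Each lane handles `vec_width` of those elements.
- Blocks whose patch index is at or beyond `H_i * W_i` simply exit.
- The data layout and `(T_i, H_i, W_i)` interpretation of `grid_thws` are the same assumptions as for a pure reference of the averaging.

The name `simulate_grid` is hypothetical.

```python
import math


def simulate_grid(x, grid_thws, kH, kW, D, vec_width, num_threads):
    """Hypothetical CPU model of the kernel's grid mapping (assumed layout)."""
    n_vids = len(grid_thws)
    max_pat = max(h * w for _, h, w in grid_thws)  # grid.y: upper bound on patches
    tile = vec_width * num_threads                 # D elements per block (assumed)
    n_tiles = math.ceil(D / tile)                  # grid.x
    # Per-video row offsets into the flat input and output tensors.
    in_off, out_off = [], []
    acc_in = acc_out = 0
    for t, h, w in grid_thws:
        in_off.append(acc_in)
        out_off.append(acc_out)
        acc_in += t * h * w
        acc_out += h * w
    out = [[0.0] * D for _ in range(acc_out)]
    for vid in range(n_vids):                      # block_idx.z
        t, h, w = grid_thws[vid]
        for pat in range(max_pat):                 # block_idx.y
            if pat >= h * w:
                continue                           # past this video's patches: block exits
            # Invert the assumed output order (new_h, new_w, kH, kW) to (hi, wi).
            kw = pat % kW
            kh = (pat // kW) % kH
            nh, nw = divmod(pat // (kH * kW), w // kW)
            hi, wi = nh * kH + kh, nw * kW + kw
            for tx in range(n_tiles):              # block_idx.x
                for lane in range(num_threads):    # thread_idx.x
                    d0 = tx * tile + lane * vec_width
                    for d in range(d0, min(d0 + vec_width, D)):
                        s = 0.0
                        for ti in range(t):        # temporal reduction
                            s += x[in_off[vid] + (ti * h + hi) * w + wi][d]
                        out[out_off[vid] + pat][d] = s / t
    return out
```

Under these assumptions each (patch, element) pair is written by exactly one lane. Sizing grid.y to `max_pat` lets videos with different patch counts share a single launch, at the cost of some idle blocks for the smaller videos.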

Args: