Mojo function
learnable_2d_interp_pos_emb
learnable_2d_interp_pos_emb[dtype: DType](
    output: TileTensor[dtype, output.LayoutType, output.origin, address_space=output.address_space, linear_idx_type=output.linear_idx_type, element_size=output.element_size],
    x: TileTensor[dtype, x.LayoutType, x.origin, address_space=x.address_space, linear_idx_type=x.linear_idx_type, element_size=x.element_size],
    weight: TileTensor[dtype, weight.LayoutType, weight.origin, address_space=weight.address_space, linear_idx_type=weight.linear_idx_type, element_size=weight.element_size],
    grid_thws: TileTensor[DType.int64, grid_thws.LayoutType, grid_thws.origin, address_space=grid_thws.address_space, linear_idx_type=grid_thws.linear_idx_type, element_size=grid_thws.element_size],
    time_weight: TileTensor[DType.float32, time_weight.LayoutType, time_weight.origin, address_space=time_weight.address_space, linear_idx_type=time_weight.linear_idx_type, element_size=time_weight.element_size],
    ctx: DeviceContext,
)
Applies a learnable, 2D-interpolated position embedding to patch embeddings on the GPU.
For each video described by grid_thws, bicubic-interpolates weight from its (H, W) grid
to the video's (h, w) patch grid, optionally adds the temporal sincos embedding from
time_weight, and adds the result element-wise to x, writing the sum to output.
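The steps above can be sketched as a small pure-Python CPU reference, useful for checking the GPU output on toy inputs. It is not the Mojo implementation, and several details are assumptions not stated on this page: the bicubic kernel is taken to be Keys cubic convolution with a = -0.75, half-pixel source mapping and edge clamping (the convention of PyTorch's bicubic interpolation with align_corners=False); rows of x are assumed to be packed video by video, frame by frame, then row-major over (h, w); and the temporal term time_weight[f] is assumed to be added only when t > 1, which is one plausible reading of "optionally". The names cubic_weights, taps, bicubic_resize and interp_pos_emb are hypothetical.

```python
import math


def cubic_weights(t: float, a: float = -0.75) -> list[float]:
    """Keys cubic-convolution weights for taps at offsets -1, 0, +1, +2.

    a = -0.75 is an assumption (PyTorch's bicubic constant).
    """
    def near(d: float) -> float:  # kernel for |d| <= 1
        return ((a + 2) * d - (a + 3)) * d * d + 1

    def far(d: float) -> float:  # kernel for 1 < |d| < 2
        return ((a * d - 5 * a) * d + 8 * a) * d - 4 * a

    return [far(t + 1), near(t), near(1 - t), far(2 - t)]


def taps(out_idx: int, in_size: int, out_size: int) -> list[tuple[int, float]]:
    """Clamped source indices and weights for one output index (assumed half-pixel mapping)."""
    src = (out_idx + 0.5) * in_size / out_size - 0.5
    base = math.floor(src)
    ws = cubic_weights(src - base)
    return [(min(max(base + k - 1, 0), in_size - 1), ws[k]) for k in range(4)]


def bicubic_resize(grid, h: int, w: int):
    """Resize an (H, W, dim) nested-list grid to (h, w, dim), separably in H and W."""
    H, W, dim = len(grid), len(grid[0]), len(grid[0][0])
    out = [[[0.0] * dim for _ in range(w)] for _ in range(h)]
    for i in range(h):
        row_taps = taps(i, H, h)
        for j in range(w):
            acc = out[i][j]
            for si, wi in row_taps:
                for sj, wj in taps(j, W, w):
                    for d in range(dim):
                        acc[d] += wi * wj * grid[si][sj][d]
    return out


def interp_pos_emb(x, weight, grid_thws, time_weight):
    """CPU reference: returns x + position embedding, one (L, dim) row per patch."""
    out, row = [], 0
    for t, h, w in grid_thws:
        spatial = bicubic_resize(weight, h, w)
        for f in range(t):
            # Assumption: the temporal sincos term is used only for multi-frame videos.
            temporal = time_weight[f] if t > 1 else None
            for i in range(h):
                for j in range(w):
                    emb = spatial[i][j]
                    if temporal is not None:
                        emb = [e + tv for e, tv in zip(emb, temporal)]
                    out.append([xv + ev for xv, ev in zip(x[row], emb)])
                    row += 1
    assert row == len(x), "x must have sum(t*h*w) rows"
    return out
```

Because the four cubic weights always sum to 1, resizing a constant grid returns the same constant, which makes a convenient first sanity check against the GPU kernel.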
Args:
- output (TileTensor): (L, dim) output tensor.
- x (TileTensor): (L, dim) input patch embeddings.
- weight (TileTensor): (H, W, dim) learnable 2D grid.
- grid_thws (TileTensor): (N, 3) per-video (t, h, w), int64.
- time_weight (TileTensor): (num_frames, dim) 1D sincos temporal embedding, float32.
- ctx (DeviceContext): Device context for GPU dispatch.
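The argument shapes imply some bookkeeping that callers may want to validate before dispatch. The sketch below assumes, without confirmation from this page, that the N videos are packed back to back along the L axis of x and output, each owning t*h*w consecutive rows, and that t never exceeds num_frames (the row count of time_weight). The helper name video_row_ranges is hypothetical and is not part of the Mojo API.

```python
def video_row_ranges(grid_thws, num_rows: int, num_frames: int) -> list[tuple[int, int]]:
    """Return the [start, end) row range of x/output owned by each video.

    Assumes videos are packed contiguously along L, t*h*w rows each.
    """
    ranges, start = [], 0
    for t, h, w in grid_thws:
        if t > num_frames:
            # time_weight has only num_frames rows to index by frame.
            raise ValueError(f"t={t} exceeds time_weight rows ({num_frames})")
        end = start + t * h * w
        ranges.append((start, end))
        start = end
    if start != num_rows:
        raise ValueError(f"sum of t*h*w is {start}, but x has {num_rows} rows")
    return ranges
```

For example, grid_thws = [(1, 2, 2), (2, 1, 1)] describes a single 2x2 image followed by a two-frame 1x1 video, so x must have 4 + 2 = 6 rows.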