Mojo function

learnable_2d_interp_pos_emb

learnable_2d_interp_pos_emb[dtype: DType](
    output: TileTensor[dtype, output.LayoutType, output.origin, address_space=output.address_space, linear_idx_type=output.linear_idx_type, element_size=output.element_size],
    x: TileTensor[dtype, x.LayoutType, x.origin, address_space=x.address_space, linear_idx_type=x.linear_idx_type, element_size=x.element_size],
    weight: TileTensor[dtype, weight.LayoutType, weight.origin, address_space=weight.address_space, linear_idx_type=weight.linear_idx_type, element_size=weight.element_size],
    grid_thws: TileTensor[DType.int64, grid_thws.LayoutType, grid_thws.origin, address_space=grid_thws.address_space, linear_idx_type=grid_thws.linear_idx_type, element_size=grid_thws.element_size],
    time_weight: TileTensor[DType.float32, time_weight.LayoutType, time_weight.origin, address_space=time_weight.address_space, linear_idx_type=time_weight.linear_idx_type, element_size=time_weight.element_size],
    ctx: DeviceContext)

Applies a learnable 2D interpolated position embedding on the GPU.

For each video described by a (t, h, w) row of grid_thws, the function bicubically interpolates weight from its (H, W) grid to (h, w), optionally adds the temporal sincos embedding from time_weight, and adds the result element-wise to x, storing the sum in output.
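
The computation described above can be illustrated with a small CPU reference in plain Python. This is only a sketch of the semantics, not the Mojo kernel. Several details are assumptions that the description does not state: the bicubic filter is the Keys cubic with a = -0.75, half-pixel sampling, and clamped borders (the convention PyTorch uses for bicubic resizing with align_corners=False); rows of x are ordered video by video, then frame, then grid row, then grid column; and the temporal term is added only when t > 1. The helper names (cubic_weights, source_coords, bicubic_resize, reference_pos_emb) are hypothetical.

```python
import math


def cubic_weights(frac, a=-0.75):
    """Keys cubic-convolution weights for the 4 taps around a sample point."""
    def near(d):  # kernel for |d| <= 1
        return ((a + 2.0) * d - (a + 3.0)) * d * d + 1.0

    def far(d):  # kernel for 1 < |d| < 2
        return ((a * d - 5.0 * a) * d + 8.0 * a) * d - 4.0 * a

    return [far(frac + 1.0), near(frac), near(1.0 - frac), far(2.0 - frac)]


def source_coords(out_size, in_size):
    """Map each output index to (floor(source index), fractional part)."""
    scale = in_size / out_size
    coords = []
    for o in range(out_size):
        src = (o + 0.5) * scale - 0.5  # assumed half-pixel convention
        base = math.floor(src)
        coords.append((base, src - base))
    return coords


def bicubic_resize(weight, h, w):
    """Resize an (H, W, dim) nested list to (h, w, dim), clamping at borders."""
    H, W, dim = len(weight), len(weight[0]), len(weight[0][0])
    out = []
    for y0, fy in source_coords(h, H):
        wy = cubic_weights(fy)
        out_row = []
        for x0, fx in source_coords(w, W):
            wx = cubic_weights(fx)
            acc = [0.0] * dim
            for i in range(4):
                yy = min(max(y0 - 1 + i, 0), H - 1)
                for j in range(4):
                    xx = min(max(x0 - 1 + j, 0), W - 1)
                    c = wy[i] * wx[j]
                    for d in range(dim):
                        acc[d] += c * weight[yy][xx][d]
            out_row.append(acc)
        out.append(out_row)
    return out


def reference_pos_emb(x, weight, grid_thws, time_weight):
    """Return x plus the interpolated (and optionally temporal) embedding."""
    dim = len(weight[0][0])
    output = []
    row = 0
    for t, h, w in grid_thws:
        spatial = bicubic_resize(weight, h, w)
        flat = [spatial[i][j] for i in range(h) for j in range(w)]
        for frame in range(t):
            # Assumption: single-frame inputs (t == 1) get no temporal term.
            temporal = time_weight[frame] if t > 1 else [0.0] * dim
            for pos in flat:
                output.append(
                    [x[row][d] + pos[d] + temporal[d] for d in range(dim)]
                )
                row += 1
    return output
```

Under these assumptions, a video whose (h, w) equals (H, W) reuses weight unchanged, because every sample point lands exactly on a grid node and the filter reduces to a single tap with weight 1.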

Args:

  • output (TileTensor): (L, dim) output tensor.
  • x (TileTensor): (L, dim) input patch embeddings.
  • weight (TileTensor): (H, W, dim) learnable 2D grid.
  • grid_thws (TileTensor): (N, 3) per-video (t, h, w), int64.
  • time_weight (TileTensor): (num_frames, dim) 1D sincos temporal embedding, float32.
  • ctx (DeviceContext): Device context for GPU dispatch.
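
The argument shapes listed above imply some consistency conditions that are cheap to check on the host before dispatch. The sketch below is a hypothetical pre-flight check in plain Python, not part of the Mojo API. It assumes that L equals the sum of t * h * w over all rows of grid_thws and that time_weight is indexed by frame, so no video may have more frames than num_frames; neither condition is stated explicitly in the reference.

```python
def expected_rows(grid_thws):
    """Return (total_rows, offsets); offsets[i] is the first row of video i."""
    offsets = []
    total = 0
    for t, h, w in grid_thws:
        offsets.append(total)
        total += t * h * w
    return total, offsets


def check_shapes(x_shape, output_shape, weight_shape, grid_thws, time_weight_shape):
    """Return a list of human-readable shape problems (empty if none)."""
    problems = []
    L, dim = x_shape
    if tuple(output_shape) != tuple(x_shape):
        problems.append("output must have the same (L, dim) shape as x")
    if weight_shape[2] != dim:
        problems.append("weight must be (H, W, dim) with the same dim as x")
    if time_weight_shape[1] != dim:
        problems.append("time_weight must be (num_frames, dim)")
    total, _ = expected_rows(grid_thws)
    if total != L:  # assumed relation between grid_thws and L
        problems.append(f"grid_thws describes {total} rows but x has {L}")
    if any(t > time_weight_shape[0] for t, _, _ in grid_thws):
        problems.append("a video has more frames than time_weight provides")
    return problems
```

For example, two videos with grids (1, 2, 3) and (2, 2, 2) would occupy rows 0 to 5 and 6 to 13 of x under this layout assumption, for L = 14.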
