Skip to main content

Mojo function

fp8_index_naive

fp8_index_naive[dtype: DType, //, num_heads: Int, depth: Int](
    output: TileTensor[DType.float32, output.LayoutType, output.origin, address_space=output.address_space, linear_idx_type=output.linear_idx_type, element_size=output.element_size],
    q: TileTensor[dtype, q.LayoutType, q.origin, address_space=q.address_space, linear_idx_type=q.linear_idx_type, element_size=q.element_size],
    q_s: TileTensor[DType.float32, q_s.LayoutType, q_s.origin, address_space=q_s.address_space, linear_idx_type=q_s.linear_idx_type, element_size=q_s.element_size],
    k: TileTensor[dtype, k.LayoutType, k.origin, address_space=k.address_space, linear_idx_type=k.linear_idx_type, element_size=k.element_size],
    k_s: TileTensor[DType.float32, k_s.LayoutType, k_s.origin, address_space=k_s.address_space, linear_idx_type=k_s.linear_idx_type, element_size=k_s.element_size],
    valid_length: TileTensor[DType.uint32, valid_length.LayoutType, valid_length.origin, address_space=valid_length.address_space, linear_idx_type=valid_length.linear_idx_type, element_size=valid_length.element_size],
    cache_row_offsets: TileTensor[DType.uint32, cache_row_offsets.LayoutType, cache_row_offsets.origin, address_space=cache_row_offsets.address_space, linear_idx_type=cache_row_offsets.linear_idx_type, element_size=cache_row_offsets.element_size],
    batch_size: Int,
    max_seq_len: Int,
    max_num_keys: Int,
    ctx: DeviceContext
)

Was this page helpful?