Mojo function
gather
def gather[dtype: DType, indices_type: DType, //, *, axis: Int, target: StringSlice[StaticConstantOrigin] = StringSlice("cpu")](output: TileTensor[dtype, address_space=output.address_space, linear_idx_type=output.linear_idx_type, element_size=output.element_size], input: TileTensor[dtype, address_space=input.address_space, linear_idx_type=input.linear_idx_type, element_size=input.element_size], indices: TileTensor[indices_type, address_space=indices.address_space, linear_idx_type=indices.linear_idx_type, element_size=indices.element_size], *, context: DeviceContext)
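Gather operation as defined in https://github.com/onnx/onnx/blob/main/docs/Operators.md#Gather: slices of `input` are selected along `axis` by the values in `indices`, so the output shape is `input.shape[:axis] + indices.shape + input.shape[axis + 1:]`.
Note that this is NOT the same as PyTorch's default `torch.gather`, which is equivalent to ONNX GatherElements (https://github.com/onnx/onnx/blob/main/docs/Operators.md#gatherelements) and produces an output with the same shape as `indices`.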
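To make the difference between the two ONNX operators concrete, here is a small pure-Python model on nested lists. It is only an illustration of the ONNX semantics, not the Mojo implementation: `onnx_gather_reference` and `gather_elements_2d` are hypothetical helper names, and the sample data is made up for the example.

```python
def onnx_gather_reference(data, indices, axis):
    """Hypothetical pure-Python model of ONNX Gather on nested lists.

    Output shape: data.shape[:axis] + indices.shape + data.shape[axis + 1:].
    """
    if axis > 0:
        # Dimensions before `axis` are carried through unchanged.
        return [onnx_gather_reference(row, indices, axis - 1) for row in data]

    def pick(idx):
        # Each scalar index is replaced by the whole slice data[idx].
        # Python's negative indexing matches ONNX's allowed range [-s, s-1].
        if isinstance(idx, list):
            return [pick(i) for i in idx]
        return data[idx]

    return pick(indices)


def gather_elements_2d(data, indices, axis):
    """Hypothetical model of ONNX GatherElements (like torch.gather), 2-D only.

    The output has the same shape as `indices`.
    """
    rows, cols = len(indices), len(indices[0])
    if axis == 0:
        return [[data[indices[i][j]][j] for j in range(cols)] for i in range(rows)]
    return [[data[i][indices[i][j]] for j in range(cols)] for i in range(rows)]


data = [[1, 2], [3, 4]]
indices = [[0, 0], [1, 0]]
print(onnx_gather_reference(data, indices, axis=1))  # [[[1, 1], [2, 1]], [[3, 3], [4, 3]]]
print(gather_elements_2d(data, indices, axis=1))     # [[1, 1], [4, 3]]
```

With the same inputs, Gather returns a rank-3 result of shape (2, 2, 2), because the whole `indices` shape is spliced in at `axis`, while GatherElements returns a result shaped exactly like `indices`.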
def gather[dtype: DType, indices_type: DType, InputFnType: ImplicitlyCopyable & RegisterPassable & def[width: Int, rank: Int, element_alignment: Int](IndexList[rank]) -> SIMD[dtype, width], IndicesFnType: ImplicitlyCopyable & RegisterPassable & def[width: Int, rank: Int](IndexList[rank]) -> SIMD[indices_type, width], OutputFnType: ImplicitlyCopyable & RegisterPassable & def[width: Int, rank: Int, element_alignment: Int](IndexList[rank], SIMD[dtype, width]) -> None, *, prefetch_fn: OptionalReg[def[input_rank: Int, indices_rank: Int](IndexList[input_rank], IndexList[indices_rank]) capturing -> None] = None, target: StringSlice[StaticConstantOrigin] = StringSlice("cpu")](axis: Axis, input_shape: IndexList[element_type=input_shape.element_type], indices_shape: IndexList[element_type=indices_shape.element_type], output_shape: IndexList[element_type=output_shape.element_type], *, input_fn: InputFnType, indices_fn: IndicesFnType, output_fn: OutputFnType, context: DeviceContext) where (eq InputFnType.dtype, dtype) where (eq IndicesFnType.indices_type, indices_type) where (eq OutputFnType.dtype, dtype)
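Callback-based overload of the same ONNX Gather operation (https://github.com/onnx/onnx/blob/main/docs/Operators.md#Gather). Instead of tensors, it takes explicit `input_shape`, `indices_shape`, and `output_shape`, reads elements through `input_fn` and `indices_fn`, and writes results through `output_fn`.
As with the tensor overload, this is NOT PyTorch's default `torch.gather`, which is equivalent to ONNX GatherElements (https://github.com/onnx/onnx/blob/main/docs/Operators.md#gatherelements).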
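The callback-based overload can be pictured as a loop over output coordinates that asks `indices_fn` which slice to read, asks `input_fn` for the value, and hands it to `output_fn`. The sketch below models that data flow in Python under several simplifying assumptions: width is always 1 (no SIMD), `axis` is a plain non-negative int rather than an `Axis`, coordinates are tuples rather than `IndexList`, and `prefetch_fn` is ignored. `gather_with_callbacks` and `store` are hypothetical names; this is not how the Mojo kernel is actually implemented.

```python
from itertools import product


def gather_with_callbacks(axis, input_shape, indices_shape, output_shape,
                          input_fn, indices_fn, output_fn):
    """Hypothetical scalar model of the callback-based gather overload.

    Assumptions: width 1 (no SIMD), `axis` is a non-negative int,
    no prefetch_fn, and coordinates are tuples instead of IndexList.
    """
    k = len(indices_shape)
    # ONNX Gather: output rank = input rank - 1 + indices rank.
    assert len(output_shape) == len(input_shape) - 1 + k
    for out in product(*(range(d) for d in output_shape)):
        # The k output dimensions starting at `axis` address the indices tensor.
        idx = indices_fn(out[axis:axis + k])
        if idx < 0:
            idx += input_shape[axis]  # normalize a negative ONNX index
        # Replace that block of coordinates with the single gathered index.
        src = out[:axis] + (idx,) + out[axis + k:]
        output_fn(out, input_fn(src))


data = [[1, 2], [3, 4]]
indices = [[0, 0], [1, 0]]
result = {}


def store(coord, value):
    # Hypothetical output callback: record each element by its coordinate.
    result[coord] = value


gather_with_callbacks(
    axis=1, input_shape=(2, 2), indices_shape=(2, 2), output_shape=(2, 2, 2),
    input_fn=lambda c: data[c[0]][c[1]],
    indices_fn=lambda c: indices[c[0]][c[1]],
    output_fn=store,
)
print(result[(1, 1, 0)])  # 4
```

Because the loop is driven by output coordinates, each output element is computed independently and written exactly once, so in this model the iterations could be reordered or split across workers without write conflicts.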