
function

stencil

stencil[rank: Int, stencil_rank: Int, stencil_axis: StaticIntTuple[$1], simd_width: Int, type: DType, map_fn: fn(StaticIntTuple[$1]) capturing -> Tuple[StaticIntTuple[$1], StaticIntTuple[$1]], map_strides: fn(dim: Int) capturing -> Int, load_fn: fn[Int, DType](StaticIntTuple[$0], /) capturing -> SIMD[$1, $0], compute_init_fn: fn[Int]() capturing -> SIMD[$4, $0], compute_fn: fn[Int](StaticIntTuple[$0], SIMD[$4, $0], SIMD[$4, $0], /) capturing -> SIMD[$4, $0], compute_finalize_fn: fn[Int](StaticIntTuple[$0], SIMD[$4, $0], /) capturing -> None](shape: StaticIntTuple[rank], input_shape: StaticIntTuple[rank])

Computes a stencil operation in parallel.

Computes the output by processing input stencils. Each stencil is a contiguous region of the input, determined for each output point y by map_fn: map_fn(y) -> lower_bound, upper_bound. Boundary conditions for regions that fall outside the input domain are handled by load_fn.
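The relationship between map_fn and load_fn can be illustrated with a small conceptual model. The sketch below is plain Python, not the Mojo API: it works in one dimension, ignores SIMD and parallelism, and assumes that the upper bound is exclusive and that out-of-domain loads return zero. Both are illustrative assumptions, not behavior stated on this page. The name stencil_1d is hypothetical.

```python
def map_fn(y):
    # Hypothetical window: three input points centered on output point y.
    # Assumption: the region is half-open, [lower, upper).
    return y - 1, y + 2


def load_fn(data, i):
    # Boundary handling lives here: indices outside the input read as 0
    # (zero padding is an assumption chosen for this sketch).
    return data[i] if 0 <= i < len(data) else 0


def stencil_1d(data, out_len):
    out = []
    for y in range(out_len):
        lower, upper = map_fn(y)
        # Reduce over the stencil region for this output point.
        out.append(sum(load_fn(data, i) for i in range(lower, upper)))
    return out


result = stencil_1d([1, 2, 3, 4], 4)
print(result)  # [3, 6, 9, 7]
```

The first and last outputs are smaller because part of their window lies outside the input and is read as zero by load_fn.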

Parameters:

  • rank (Int): Input and output domain rank.
  • stencil_rank (Int): Rank of stencil subdomain slice.
  • stencil_axis (StaticIntTuple[$1]): Stencil subdomain axes.
  • simd_width (Int): The SIMD vector width to use.
  • type (DType): The input and output data type.
  • map_fn (fn(StaticIntTuple[$1]) capturing -> Tuple[StaticIntTuple[$1], StaticIntTuple[$1]]): A function that maps a point in the output domain to the input co-domain, returning the lower and upper bounds of the stencil region.
  • map_strides (fn(dim: Int) capturing -> Int): A function that returns the stride for the given dimension.
  • load_fn (fn[Int, DType](StaticIntTuple[$0], /) capturing -> SIMD[$1, $0]): A function that loads a vector of simd_width from input.
  • compute_init_fn (fn[Int]() capturing -> SIMD[$4, $0]): A function that initializes vector compute over the stencil.
  • compute_fn (fn[Int](StaticIntTuple[$0], SIMD[$4, $0], SIMD[$4, $0], /) capturing -> SIMD[$4, $0]): A function that processes the value computed for each point in the stencil.
  • compute_finalize_fn (fn[Int](StaticIntTuple[$0], SIMD[$4, $0], /) capturing -> None): A function that finalizes the computation of a point in the output domain given a stencil.
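How the callbacks above might fit together can be sketched with a scalar Python model. This is not the Mojo implementation: run_stencil is a hypothetical helper that runs sequentially and omits simd_width, stencil_axis, and map_strides. It also assumes that compute_fn receives the input index, the running accumulator, and the loaded value, in that order. The example models 2x2 max pooling with stride 2.

```python
from itertools import product


def run_stencil(shape, map_fn, load_fn, compute_init_fn, compute_fn,
                compute_finalize_fn):
    # Visit every point in the output domain.
    for point in product(*(range(n) for n in shape)):
        lower, upper = map_fn(point)       # stencil bounds for this point
        acc = compute_init_fn()            # start the reduction
        # Walk the stencil region (upper bound assumed exclusive).
        for idx in product(*(range(lo, hi) for lo, hi in zip(lower, upper))):
            acc = compute_fn(idx, acc, load_fn(idx))
        compute_finalize_fn(point, acc)    # store the result


grid = [[1, 2, 3, 4],
        [5, 6, 7, 8],
        [9, 10, 11, 12],
        [13, 14, 15, 16]]
output = [[None, None], [None, None]]


def pool_map(p):
    # Each output point covers a 2x2 input window with stride 2.
    return (2 * p[0], 2 * p[1]), (2 * p[0] + 2, 2 * p[1] + 2)


def pool_load(idx):
    r, c = idx
    in_bounds = 0 <= r < len(grid) and 0 <= c < len(grid[0])
    # Out-of-domain points read as -inf so they never win the max.
    return grid[r][c] if in_bounds else float("-inf")


def pool_finalize(p, acc):
    output[p[0]][p[1]] = acc


run_stencil((2, 2), pool_map, pool_load,
            lambda: float("-inf"),
            lambda idx, acc, val: max(acc, val),
            pool_finalize)
print(output)  # [[6, 8], [14, 16]]
```

In this model, swapping the reduction (for example, a sum in compute_fn and a division in compute_finalize_fn) turns the same loop into average pooling, which is the main reason for splitting the computation into init, compute, and finalize steps.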

Args:

  • shape (StaticIntTuple[rank]): The shape of the output buffer.
  • input_shape (StaticIntTuple[rank]): The shape of the input buffer.