
Mojo function

cluster_mask_base

cluster_mask_base[cluster_shape: IndexList[3], axis: Int]() -> SIMD[uint16, 1]

Computes the base mask for a cluster. The base mask for an axis covers the first CTA (cooperative thread array, i.e. thread block) in the cluster and every CTA along that axis. The following example uses cluster shape (4, 4, 1). Note that CTA rank is contiguous along the first cluster axis.

    x o o o                x x x x
    x o o o                o o o o
    x o o o                o o o o
    x o o o                o o o o
    base mask in axis 0    base mask in axis 1
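The mask construction can be modeled outside Mojo. The Python sketch below is a hypothetical reference model, not the library's implementation. The names `cta_rank`, `base_mask_model`, and `render` are illustrative, and the sketch rests on one assumption: bit i of the mask stands for the CTA whose linear rank is i, with rank = x + y * dim_x + z * dim_x * dim_y (contiguous along the first cluster axis, as noted above). Drawing rows along axis 0 and columns along axis 1 reproduces the figure.

```python
def cta_rank(x: int, y: int, z: int, cluster_shape: tuple[int, int, int]) -> int:
    # Assumed layout: rank is contiguous along the first cluster axis.
    dim_x, dim_y, _ = cluster_shape
    return x + y * dim_x + z * dim_x * dim_y


def base_mask_model(cluster_shape: tuple[int, int, int], axis: int) -> int:
    # Hypothetical model: set one bit per CTA on the line that starts at
    # CTA (0, 0, 0) and runs along `axis`.
    if axis not in (0, 1, 2):
        raise ValueError("axis must be 0, 1, or 2")
    mask = 0
    for i in range(cluster_shape[axis]):
        coord = [0, 0, 0]
        coord[axis] = i  # step along the chosen axis from the first CTA
        mask |= 1 << cta_rank(coord[0], coord[1], coord[2], cluster_shape)
    # The Mojo function returns a 16-bit value, so reject wider masks here.
    if mask >= 1 << 16:
        raise ValueError("mask does not fit in 16 bits")
    return mask


def render(mask: int, cluster_shape: tuple[int, int, int]) -> str:
    # Draw the z = 0 plane like the figure: rows follow axis 0, columns axis 1.
    dim_x, dim_y, _ = cluster_shape
    rows = []
    for x in range(dim_x):
        cells = [
            "x" if (mask >> cta_rank(x, y, 0, cluster_shape)) & 1 else "o"
            for y in range(dim_y)
        ]
        rows.append(" ".join(cells))
    return "\n".join(rows)


shape = (4, 4, 1)
for axis in (0, 1):
    mask = base_mask_model(shape, axis)
    print(f"axis {axis}: {mask:#06x}")
    print(render(mask, shape))
```

Under these assumptions, a (4, 4, 1) cluster gives 0x000f for axis 0 (ranks 0 to 3) and 0x1111 for axis 1 (ranks 0, 4, 8, 12).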

Parameters:

  • cluster_shape (IndexList[3]): The shape of the cluster.
  • axis (Int): The axis to compute the base mask for.

Returns:

The base mask for the cluster.
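The bits of a returned mask can be read back as CTA positions using the same assumed layout. The helper below is a hypothetical Python sketch. `mask_to_coords` is an illustrative name and is not part of the Mojo API. It assumes that bit i marks the CTA of rank i and that rank = x + y * dim_x + z * dim_x * dim_y. The two sample masks are the values this layout implies for the (4, 4, 1) example, not values taken from the library.

```python
def mask_to_coords(mask: int, cluster_shape: tuple[int, int, int]) -> list[tuple[int, int, int]]:
    # Hypothetical decoder: list the (x, y, z) cluster coordinates of every
    # CTA whose rank bit is set, assuming rank is contiguous along axis 0.
    dim_x, dim_y, dim_z = cluster_shape
    coords = []
    for rank in range(dim_x * dim_y * dim_z):
        if (mask >> rank) & 1:
            x = rank % dim_x
            y = (rank // dim_x) % dim_y
            z = rank // (dim_x * dim_y)
            coords.append((x, y, z))
    return coords


print(mask_to_coords(0x000F, (4, 4, 1)))  # [(0, 0, 0), (1, 0, 0), (2, 0, 0), (3, 0, 0)]
print(mask_to_coords(0x1111, (4, 4, 1)))  # [(0, 0, 0), (0, 1, 0), (0, 2, 0), (0, 3, 0)]
```

Decoding simply inverts the rank formula: x = rank % dim_x, y = (rank // dim_x) % dim_y, z = rank // (dim_x * dim_y).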
