Mojo struct

BlockScaledKernelContext

struct BlockScaledKernelContext[num_clc_pipeline_stages: Int, cta_group: Int, CLUSTER_M: Int, CLUSTER_N: Int, BM: Int, MMA_N: Int, num_pipeline_stages: Int]

Holds per-CTA state: election flags, coordinates, multicast masks, and TMEM offsets.

Fields

  • elect_one_warp (Bool)
  • elect_one_thread (Bool)
  • elect_one_cta (Bool)
  • is_first_cta_in_cluster (Bool)
  • warp_id (UInt32)
  • rank_m (UInt)
  • rank_n (UInt)
  • peer_cta_coord (Tuple[UInt, UInt, UInt])
  • a_multicast_mask (UInt16)
  • b_multicast_mask (UInt16)
  • mma_complete_mask (Int)
  • ptr_tmem_addr (LegacyUnsafePointer[UInt32, address_space=AddressSpace.SHARED])

Implemented traits

AnyType, Copyable, ImplicitlyDestructible, Movable

comptime members

__copyinit__is_trivial

comptime __copyinit__is_trivial = False

__del__is_trivial

comptime __del__is_trivial = False

__moveinit__is_trivial

comptime __moveinit__is_trivial = False

SFA_NUM_COLS

comptime SFA_NUM_COLS = (BM // 32)

SFB_NUM_COLS

comptime SFB_NUM_COLS = (MMA_N // 32)
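
SFA_NUM_COLS and SFB_NUM_COLS are both derived from the struct parameters by integer floor division by 32. The sketch below mirrors that arithmetic in Python purely for illustration. The helper name scale_factor_cols and the sample values BM = 128 and MMA_N = 256 are assumptions chosen for the example; they are not part of the Mojo API and are not taken from this page.

```python
def scale_factor_cols(bm: int, mma_n: int) -> tuple[int, int]:
    """Mirror the two comptime expressions from the struct.

    Hypothetical helper for illustration only; not part of the Mojo API.
    """
    sfa_num_cols = bm // 32      # same expression as SFA_NUM_COLS = (BM // 32)
    sfb_num_cols = mma_n // 32   # same expression as SFB_NUM_COLS = (MMA_N // 32)
    return sfa_num_cols, sfb_num_cols


# Assumed sample parameters, not values documented on this page.
sfa, sfb = scale_factor_cols(bm=128, mma_n=256)
print(sfa, sfb)  # prints "4 8"
```

Because the division is a floor division, parameter values that are not multiples of 32 round down (for example, 48 // 32 gives 1).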

TmemAddrArray

comptime TmemAddrArray = SMemArrayType[UInt32, 1]

Methods

__init__

__init__(out self, tmem_addr_ptr: LegacyUnsafePointer[UInt32, address_space=AddressSpace.SHARED])

Initializes the context and computes the election flags and multicast masks.
