Mojo function
rdna_mma
rdna_mma(a_reg: TileTensor[address_space=AddressSpace.LOCAL, linear_idx_type=a_reg.linear_idx_type, element_size=a_reg.element_size], b_reg: TileTensor[address_space=AddressSpace.LOCAL, linear_idx_type=b_reg.linear_idx_type, element_size=b_reg.element_size], c_reg: TileTensor[address_space=AddressSpace.LOCAL, linear_idx_type=c_reg.linear_idx_type, element_size=c_reg.element_size])
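Per-fragment WMMA loop. Derives the MMA counts from the operand shapes (num_m_mmas from a_reg, num_n_mmas from b_reg) and issues one WMMA per (m, n) fragment pair. The accumulator is indexed column-major over (M, N): c_idx = m + n * num_m_mmas.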
Args:
- a_reg (TileTensor[address_space=AddressSpace.LOCAL, linear_idx_type=a_reg.linear_idx_type, element_size=a_reg.element_size]): A operand tile [num_m_mmas, RDNA_AB_FRAG_SIZE].
- b_reg (TileTensor[address_space=AddressSpace.LOCAL, linear_idx_type=b_reg.linear_idx_type, element_size=b_reg.element_size]): B operand tile [num_n_mmas, RDNA_AB_FRAG_SIZE].
- c_reg (TileTensor[address_space=AddressSpace.LOCAL, linear_idx_type=c_reg.linear_idx_type, element_size=c_reg.element_size]): Accumulator tile [num_m_mmas * num_n_mmas, RDNA_CD_FRAG_SIZE], modified in place.
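The shape contract and column-major accumulator indexing can be modeled in plain Python. This is a minimal sketch, not the Mojo implementation: tiles are lists of fragments, the `mma` callback is a hypothetical stand-in for the hardware WMMA (here a scalar multiply-add on one-element fragments), and the m-outer/n-inner loop nesting is an assumption, since the reference only fixes the index formula.

```python
from typing import Callable, List, Sequence, Tuple

Fragment = List[float]


def rdna_mma_model(
    a_reg: Sequence[Fragment],
    b_reg: Sequence[Fragment],
    c_reg: List[Fragment],
    mma: Callable[[Fragment, Fragment, Fragment], None],
) -> List[Tuple[int, int, int]]:
    """Model of the per-fragment loop; `mma` is a hypothetical WMMA stand-in."""
    # MMA counts are derived from the operand shapes (number of fragments).
    num_m_mmas = len(a_reg)
    num_n_mmas = len(b_reg)
    if len(c_reg) != num_m_mmas * num_n_mmas:
        raise ValueError("c_reg must hold num_m_mmas * num_n_mmas fragments")

    visited = []
    # Assumed nesting: m outer, n inner. Only the index formula is documented.
    for m in range(num_m_mmas):
        for n in range(num_n_mmas):
            c_idx = m + n * num_m_mmas  # column-major over (M, N)
            mma(a_reg[m], b_reg[n], c_reg[c_idx])  # accumulates in place
            visited.append((m, n, c_idx))
    return visited


def scalar_mma(a: Fragment, b: Fragment, c: Fragment) -> None:
    # Placeholder for the real WMMA: c += a * b on one-element fragments.
    c[0] += a[0] * b[0]


a = [[1.0], [2.0]]                # num_m_mmas = 2
b = [[10.0], [20.0], [30.0]]      # num_n_mmas = 3
c = [[0.0] for _ in range(2 * 3)]
order = rdna_mma_model(a, b, c, scalar_mma)
print(c)  # [[10.0], [20.0], [20.0], [40.0], [30.0], [60.0]]
```

Because m is the fastest-varying term in c_idx, consecutive accumulator rows hold the M fragments of one N column: rows 0 and 1 belong to n = 0, rows 2 and 3 to n = 1, and so on.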