IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Mojo function

random_normal

def random_normal[dtype: DType, rank: Int, //, output_fn: def[width: SIMDSize, _rank: Int](idx: IndexList[_rank], val: SIMD[dtype, width]) capturing -> None, target: StringSlice[StaticConstantOrigin]](shape: IndexList[rank], mean: Float32, stddev: Float32, seed_ptr: UnsafePointer[UInt64, ImmutAnyOrigin], ctx: DeviceContext)

Call output_fn with values from a normal distribution, matching PyTorch CUDA's torch.randn element-to-counter mapping.

For element i, the function mirrors PyTorch's per-thread Philox state:

thread_id     = i mod GRID_BLOCK
within_thread = i div GRID_BLOCK   (0..3)

where GRID_BLOCK is the total thread count of PyTorch's launch grid (256 threads per block): GRID_BLOCK = 256 * min(num_SMs * blocks_per_sm, ceil(numel/256)).
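The mapping above can be sketched in Python as a plain index computation. This is an illustration only. grid_block and element_to_thread are hypothetical helper names, and the device figures in the usage lines (148 SMs, 8 blocks per SM) are assumed values chosen for the example, not values queried from hardware.

```python
BLOCK_SIZE = 256  # threads per block, matching the 256 in the GRID_BLOCK formula


def grid_block(numel: int, num_sms: int, blocks_per_sm: int) -> int:
    """GRID_BLOCK = 256 * min(num_SMs * blocks_per_sm, ceil(numel / 256))."""
    ceil_blocks = -(-numel // BLOCK_SIZE)  # integer ceil(numel / 256)
    return BLOCK_SIZE * min(num_sms * blocks_per_sm, ceil_blocks)


def element_to_thread(i: int, grid: int) -> tuple:
    """Return (thread_id, within_thread) for flat element index i."""
    thread_id = i % grid
    within_thread = i // grid  # a single Philox step covers only 0..3
    return thread_id, within_thread


# Small tensor: the grid shrinks to ceil(1000 / 256) = 4 blocks, so every
# element gets its own thread and within_thread is always 0.
print(grid_block(1000, num_sms=148, blocks_per_sm=8))  # 1024

# Larger tensor: the grid is capped at 148 * 8 blocks, so indices wrap
# around and later elements reuse a thread with a higher within_thread.
g = grid_block(1_000_000, num_sms=148, blocks_per_sm=8)
print(g)                            # 303104
print(element_to_thread(g + 5, g))  # (5, 1)
```

For small tensors the grid is sized to the data, so each element maps to a distinct thread and only lane 0 is used. Lanes 1 through 3 come into play only once the grid hits its per-device cap.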

A single Philox step at counter (0, 0, thread_id, 0) produces four normal values via std.random.NormalRandom.step_normal_4. The within_thread index selects which of those four lanes is written to output[i].
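The Philox step and lane selection can be sketched in pure Python. This is a hedged illustration, not the library's implementation. It makes three assumptions:

1. step_normal_4 follows cuRAND's curand_normal4 pattern: Philox4x32-10, followed by a Box-Muller transform on each pair of 32-bit words.
2. The 64-bit seed is split into a (low, high) key, as curand_init does.
3. The sin/cos output order matches cuRAND.

The sketch computes in float64, so it shows the data flow but will not reproduce bit-exact float32 results. The names philox4x32_10, box_muller, normals_for_thread, and reference_value are hypothetical.

```python
import math

# Philox4x32 round and key-schedule constants (Random123; cuRAND uses the same).
M0, M1 = 0xD2511F53, 0xCD9E8D57
W0, W1 = 0x9E3779B9, 0xBB67AE85
MASK32 = 0xFFFFFFFF
INV_2POW32 = 2.0 ** -32


def mulhilo(a: int, b: int):
    """Split the 64-bit product a * b into (hi, lo) 32-bit halves."""
    p = a * b
    return (p >> 32) & MASK32, p & MASK32


def philox4x32_10(ctr, key):
    """Apply 10 Philox4x32 rounds to a 4-word counter under a 2-word key."""
    c0, c1, c2, c3 = ctr
    k0, k1 = key
    for r in range(10):
        if r > 0:  # the key is bumped between rounds (9 bumps in total)
            k0 = (k0 + W0) & MASK32
            k1 = (k1 + W1) & MASK32
        hi0, lo0 = mulhilo(M0, c0)
        hi1, lo1 = mulhilo(M1, c2)
        c0, c1, c2, c3 = hi1 ^ c1 ^ k0, lo1, hi0 ^ c3 ^ k1, lo0
    return c0, c1, c2, c3


def box_muller(x: int, y: int):
    """Turn two uint32 words into two standard normals (cuRAND-style offsets)."""
    u = x * INV_2POW32 + INV_2POW32 / 2  # strictly inside (0, 1), so log(u) is finite
    v = (y * INV_2POW32 + INV_2POW32 / 2) * 2.0 * math.pi
    s = math.sqrt(-2.0 * math.log(u))
    return s * math.sin(v), s * math.cos(v)  # sin/cos order is an assumption


def normals_for_thread(seed: int, thread_id: int):
    """One Philox step at counter (0, 0, thread_id, 0) -> four normals."""
    key = (seed & MASK32, (seed >> 32) & MASK32)  # low word first (assumed)
    r0, r1, r2, r3 = philox4x32_10((0, 0, thread_id, 0), key)
    return box_muller(r0, r1) + box_muller(r2, r3)


def reference_value(i: int, seed: int, grid: int) -> float:
    """Value for output[i]: lane within_thread of thread_id's four normals."""
    thread_id, within_thread = i % grid, i // grid
    if within_thread > 3:
        raise ValueError("needs more than one Philox step per thread")
    return normals_for_thread(seed, thread_id)[within_thread]


print(reference_value(1024 + 7, seed=42, grid=1024))  # lane 1 of thread 7
```

Calling reference_value for each element, with GRID_BLOCK passed as grid, mirrors the per-element lookup described above. The ValueError marks the point where one step per thread no longer covers every element.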

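The output is bit-exact with torch.randn for numel <= 4 * GRID_BLOCK_max (≈ 1.2M elements on B200). Here GRID_BLOCK_max = 256 * num_SMs * blocks_per_sm is the largest GRID_BLOCK the device allows. Beyond that bound within_thread would exceed 3, which a single four-lane Philox step cannot cover.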
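The bound follows from simple arithmetic on the GRID_BLOCK formula. The sketch below makes two assumptions:

- blocks_per_sm equals the maximum number of resident threads per SM divided by the 256-thread block size. This is an assumption about PyTorch's launch heuristic.
- A B200 has 148 SMs and 2048 threads per SM. These are assumed figures.

grid_block_max and is_bit_exact are hypothetical names. With these figures the result lands on the quoted ≈ 1.2M limit.

```python
def grid_block_max(num_sms: int, max_threads_per_sm: int, block_size: int = 256) -> int:
    """Largest GRID_BLOCK: 256 threads * num_SMs * blocks_per_sm."""
    blocks_per_sm = max_threads_per_sm // block_size  # assumed derivation
    return block_size * num_sms * blocks_per_sm


def is_bit_exact(numel: int, num_sms: int, max_threads_per_sm: int) -> bool:
    """True when every element fits in one 4-lane Philox step per thread."""
    return numel <= 4 * grid_block_max(num_sms, max_threads_per_sm)


# Assumed B200-like figures: 148 SMs, 2048 resident threads per SM.
limit = 4 * grid_block_max(num_sms=148, max_threads_per_sm=2048)
print(limit)                               # 1212416 (about 1.2M)
print(is_bit_exact(1_000_000, 148, 2048))  # True
print(is_bit_exact(2_000_000, 148, 2048))  # False
```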

Parameters:

  • dtype (DType): The data type to generate.
  • rank (Int): The rank of the underlying buffer.
  • output_fn (def[width: SIMDSize, _rank: Int](idx: IndexList[_rank], val: SIMD[dtype, width]) capturing -> None): The function which stores the generated values.
  • target (StringSlice[StaticConstantOrigin]): The target to run on.

Args:

  • shape (IndexList[rank]): The shape of the output that output_fn stores into.
  • mean (Float32): The mean of the normal distribution.
  • stddev (Float32): The standard deviation of the normal distribution.
  • seed_ptr (UnsafePointer[UInt64, ImmutAnyOrigin]): Pointer to a single uint64 in device memory containing the Philox seed.
  • ctx (DeviceContext): The device context.