Mojo function
random_normal
def random_normal[dtype: DType, rank: Int, //, output_fn: def[width: SIMDSize, _rank: Int](idx: IndexList[_rank], val: SIMD[dtype, width]) capturing -> None, target: StringSlice[StaticConstantOrigin]](shape: IndexList[rank], mean: Float32, stddev: Float32, seed_ptr: UnsafePointer[UInt64, ImmutAnyOrigin], ctx: DeviceContext)
Call output_fn with values from a normal distribution, matching PyTorch CUDA's torch.randn element-to-counter mapping.
For element i, the function mirrors PyTorch's per-thread Philox state:
thread_id = i mod GRID_BLOCK
within_thread = i div GRID_BLOCK (0..3)
where GRID_BLOCK = 256 * min(num_SMs * blocks_per_sm, ceil(numel/256)).
A single Philox step at counter (0, 0, thread_id, 0) produces 4 normals
via std.random.NormalRandom.step_normal_4; the within_thread
index selects which lane to write to output[i].
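The index arithmetic described above can be sketched in a few lines of Python. This only illustrates the mapping and is not the Mojo implementation: the helper names grid_block and element_to_lane are hypothetical, and num_sms and blocks_per_sm are passed in by the caller because their device-specific values are not given here. The toy values in the usage lines are chosen for readability, not taken from real hardware.

```python
def grid_block(numel: int, num_sms: int, blocks_per_sm: int) -> int:
    """GRID_BLOCK = 256 * min(num_SMs * blocks_per_sm, ceil(numel / 256))."""
    blocks_needed = (numel + 255) // 256  # integer ceil(numel / 256)
    return 256 * min(num_sms * blocks_per_sm, blocks_needed)


def element_to_lane(i: int, gb: int) -> tuple[int, int]:
    """Return (thread_id, within_thread) for flat element index i."""
    thread_id = i % gb        # which per-thread Philox state is used
    within_thread = i // gb   # which of the 4 generated normals is written
    return thread_id, within_thread


# Toy configuration (hypothetical): 2 SMs, 1 block per SM, 1000 elements.
gb = grid_block(numel=1000, num_sms=2, blocks_per_sm=1)
print(gb)                        # 512
print(element_to_lane(3, gb))    # (3, 0)
print(element_to_lane(700, gb))  # (188, 1)
```

In this toy case, elements 3 and 515 share thread_id 3 and differ only in the lane they read from that thread's four normals.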
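The output is bit-exact for numel <= 4 * GRID_BLOCK_max (approximately 1.2M elements on B200), where GRID_BLOCK_max = 256 * num_SMs * blocks_per_sm is the largest value GRID_BLOCK can take.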
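As a rough check of the quoted limit, the sketch below evaluates 4 * GRID_BLOCK_max in Python. It rests on two assumptions that are not stated on this page: that the B200 exposes 148 SMs, and that blocks_per_sm is 8 (2048 resident threads per SM divided by 256 threads per block). The function name max_bit_exact_numel is hypothetical.

```python
def max_bit_exact_numel(num_sms: int, blocks_per_sm: int) -> int:
    """Largest numel for which every element index maps to a lane in 0..3."""
    grid_block_max = 256 * num_sms * blocks_per_sm
    return 4 * grid_block_max


# Assumed device figures, not taken from this page.
ASSUMED_B200_SMS = 148
ASSUMED_BLOCKS_PER_SM = 8

limit = max_bit_exact_numel(ASSUMED_B200_SMS, ASSUMED_BLOCKS_PER_SM)
print(limit)  # 1212416
```

Under these assumptions the bound works out to 1,212,416 elements, which agrees with the figure of roughly 1.2M.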
Parameters:
- dtype (DType): The data type to generate.
- rank (Int): The rank of the underlying buffer.
- output_fn (def[width: SIMDSize, _rank: Int](idx: IndexList[_rank], val: SIMD[dtype, width]) capturing -> None): The function which stores the generated values.
- target (StringSlice[StaticConstantOrigin]): The target to run on.
Args:
- shape (IndexList[rank]): The shape of the output being stored into by output_fn.
- mean (Float32): The mean of the normal distribution.
- stddev (Float32): The standard deviation of the normal distribution.
- seed_ptr (UnsafePointer[UInt64, ImmutAnyOrigin]): Pointer to a single uint64 in device memory containing the Philox seed.
- ctx (DeviceContext): The device context.