RMS Norm

Category: Normalization | Complexity: O(N) per row | Memory: 2 passes

Algorithm

RMSNorm (Zhang & Sennrich 2019) is a simplified LayerNorm used in LLaMA, Gemma, and most modern LLMs. It omits the mean-centering step:

  1. RMS: Compute root-mean-square: rms = sqrt(mean(x²) + ε)
  2. Normalize + Scale: y = (x / rms) * gamma
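The two steps above can be sketched as a plain CPU reference for a single row. This is an illustrative sketch, not part of ascend-rs: the function name `rms_norm_row` and the choice to pass `eps` as a parameter are assumptions made for the example.

```rust
/// Reference RMSNorm for a single row (CPU, f32).
/// `eps` is added inside the square root, matching the formula above.
/// Hypothetical helper for illustration; not an ascend-rs API.
pub fn rms_norm_row(x: &[f32], gamma: &[f32], eps: f32) -> Vec<f32> {
    assert_eq!(x.len(), gamma.len(), "x and gamma must have equal length");
    assert!(!x.is_empty(), "row must be non-empty");

    // Step 1: rms = sqrt(mean(x^2) + eps)
    let mean_sq = x.iter().map(|v| v * v).sum::<f32>() / x.len() as f32;
    let rms = (mean_sq + eps).sqrt();

    // Step 2: y[i] = (x[i] / rms) * gamma[i]
    x.iter().zip(gamma).map(|(v, g)| (v / rms) * g).collect()
}
```

With `gamma` set to all ones and a negligible `eps`, the output row has a mean square of approximately 1, which is a quick sanity check for any implementation.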

Compared to LayerNorm, RMSNorm skips the mean reduction and the elementwise mean subtraction, which yields roughly 15% faster inference at equal accuracy.
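To make the difference concrete, the sketch below puts a bias-free LayerNorm next to RMSNorm on the CPU. Both function names are hypothetical and the code only illustrates where the extra reduction and subtraction occur; it says nothing about the speedup itself.

```rust
/// LayerNorm without the learned bias: two reductions (mean, variance)
/// plus an elementwise subtraction. Hypothetical helper for illustration.
pub fn layer_norm_row(x: &[f32], gamma: &[f32], eps: f32) -> Vec<f32> {
    let n = x.len() as f32;
    let mean = x.iter().sum::<f32>() / n; // reduction 1
    let var = x.iter().map(|v| (v - mean) * (v - mean)).sum::<f32>() / n; // reduction 2
    let inv_std = 1.0 / (var + eps).sqrt();
    x.iter().zip(gamma).map(|(v, g)| (v - mean) * inv_std * g).collect()
}

/// RMSNorm: a single reduction and no centering.
/// Hypothetical helper for illustration.
pub fn rms_norm_simple(x: &[f32], gamma: &[f32], eps: f32) -> Vec<f32> {
    let n = x.len() as f32;
    let mean_sq = x.iter().map(|v| v * v).sum::<f32>() / n; // the only reduction
    let inv_rms = 1.0 / (mean_sq + eps).sqrt();
    x.iter().zip(gamma).map(|(v, g)| v * inv_rms * g).collect()
}
```

For a row whose mean is already zero, the two functions produce the same output, since the variance then equals the mean of squares. For other inputs they differ, because RMSNorm does not re-center.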

ascend-rs Kernel Source

RMS Norm using ascend-rs buffer API (f32):

/// RMS Norm: y[i] = (x[i] / rms) * gamma[i]
/// where rms = sqrt(mean(x²) + eps)
///
/// params: [n: u32]
#[ascend_std::aiv_kernel]
pub fn rms_norm(
    input: *const f32,
    gamma: *const f32,
    output: *mut f32,
    params: *const u32,
) {
    // SAFETY: the caller must pass a valid pointer to at least one u32 (n).
    let n = unsafe { *params };
    let eps = 1.0e-5f32;

    let in_buf = ascend_std::ascend_buf_alloc(n);
    let gamma_buf = ascend_std::ascend_buf_alloc(n);
    let work = ascend_std::ascend_buf_alloc(n);
    let rwork = ascend_std::ascend_buf_alloc(n); // scratch space for the reduction

    // Load input and gamma
    ascend_std::ascend_buf_load_f32(in_buf, input, n);
    ascend_std::ascend_buf_load_f32(gamma_buf, gamma, n);
    ascend_std::ascend_pipe_barrier();

    // Step 1: x² → work
    ascend_std::ascend_mul_f32(work, in_buf, in_buf, n);
    ascend_std::ascend_pipe_barrier();

    // Step 2: inv_rms = 1 / sqrt(mean(x²) + eps)
    let sq_sum = ascend_std::ascend_reduce_sum_f32(work, work, rwork, n);
    let inv_rms = 1.0 / (sq_sum / (n as f32) + eps).sqrt();

    // Step 3: output = (x * inv_rms) * gamma
    ascend_std::ascend_muls_f32(work, in_buf, inv_rms, n);
    ascend_std::ascend_pipe_barrier();
    ascend_std::ascend_mul_f32(work, work, gamma_buf, n);

    ascend_std::ascend_pipe_barrier();
    ascend_std::ascend_buf_store_f32(output, work, n);
}
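When checking a kernel like this from the host, a CPU reference that follows the same step order (sum of squares, reciprocal RMS, scale, gamma) can be handy. The sketch below is an assumption about how one might do that, not part of ascend-rs. The names `rms_norm_batched` and `max_abs_diff` are hypothetical, and it assumes a row-major layout where each row of width `n` corresponds to one kernel invocation.

```rust
/// Hypothetical host-side reference: RMSNorm over `rows` rows of width `n`,
/// row-major, following the kernel's step order.
pub fn rms_norm_batched(
    input: &[f32],
    gamma: &[f32],
    rows: usize,
    n: usize,
    eps: f32,
) -> Vec<f32> {
    assert!(n > 0, "row width must be non-zero");
    assert_eq!(input.len(), rows * n);
    assert_eq!(gamma.len(), n);

    let mut output = vec![0.0f32; rows * n];
    for (x, y) in input.chunks_exact(n).zip(output.chunks_exact_mut(n)) {
        // Steps 1-2: sum of squares, then inv_rms = 1 / sqrt(mean(x^2) + eps)
        let sq_sum: f32 = x.iter().map(|v| v * v).sum();
        let inv_rms = 1.0 / (sq_sum / n as f32 + eps).sqrt();
        // Step 3: y = (x * inv_rms) * gamma
        for ((yi, xi), gi) in y.iter_mut().zip(x).zip(gamma) {
            *yi = (xi * inv_rms) * gi;
        }
    }
    output
}

/// Largest absolute elementwise difference between two buffers.
pub fn max_abs_diff(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len());
    a.iter().zip(b).map(|(p, q)| (p - q).abs()).fold(0.0, f32::max)
}
```

A device result can then be compared against the reference with `max_abs_diff` and a tolerance chosen for f32. The appropriate tolerance depends on the row width and on the order in which the hardware performs the reduction.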

This buffer-API kernel runs on the Ascend AIV backend. A tile-API variant, safe::tile_rms_norm_f32, is additionally lowered by rustc_codegen_mlir to all nine backends (Ascend AIV, CUDA, Apple Metal, Vulkan SPIR-V, AWS NKI, AMD AIE, Cambricon BANG, Intel Gaudi, Google TPU). RMS Norm is one of the four “hot path” tile ops (alongside matmul, softmax, and silu) that are lowered on every backend currently targeted.

Benchmark configurations

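| Shape        | Elements | Bytes (f32) | Notes                          |
|--------------|----------|-------------|--------------------------------|
| (1, 768)     | 768      | 3 KB        | GPT-2 hidden dim, single token |
| (1, 4096)    | 4K       | 16 KB       | LLaMA-7B hidden dim            |
| (64, 768)    | 49K      | 192 KB      | Typical batch, GPT-2           |
| (64, 4096)   | 262K     | 1 MB        | Typical batch, LLaMA           |
| (1024, 4096) | 4.2M     | 16 MB       | Large batch, bandwidth-bound   |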

All benchmarks use f32.
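The element and byte counts in the table follow directly from the shapes. The small helper below (the name `shape_footprint` is hypothetical) reproduces that arithmetic. It appears that the byte column counts the input tensor only, so the output adds the same amount again.

```rust
/// Element count and f32 byte size of a (rows, cols) tensor.
/// Hypothetical helper for illustration.
pub fn shape_footprint(rows: u64, cols: u64) -> (u64, u64) {
    let elements = rows * cols;
    let bytes = elements * std::mem::size_of::<f32>() as u64; // 4 bytes per f32
    (elements, bytes)
}
```

For example, `shape_footprint(64, 4096)` gives 262,144 elements and 1,048,576 bytes (1 MB), matching the fourth row of the table.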

Results

See Leaderboard filtered to RMS Norm for the full filterable view.