Softmax

Category: Activation | Complexity: O(N) per row | Memory: 2 passes over input

Algorithm

The online 2-pass softmax (Milakov & Gimelshein 2018):

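Pass 1 (single traversal): each thread maintains a running (max, sum) pair. For each element x, with max_new = max(max_old, x), the accumulated sum is rescaled to the new maximum before the new term is added:

sum_new = sum_old * exp(max_old - max_new) + exp(x - max_new)

When the maximum does not change, the rescale factor is exp(0) = 1, so the sum simply accumulates.

Pass 2: combine the per-thread pairs into (global_max, global_sum) with the same rescaling rule, then write exp(x - global_max) / global_sum per element.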

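Per element, the online version performs 3 global-memory accesses (a read in Pass 1, a read and a write in Pass 2), versus 4 for the naive 3-pass algorithm (max, exp+sum, normalize) when its normalize pass recomputes exp(x - max) from the input. That is 25% less memory traffic; equivalently, the naive version moves about 33% more data.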
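The two passes can be sketched on the CPU in plain Rust. This is a minimal illustration, not the ascend-rs implementation: `MaxSum`, `push`, `merge`, and `online_softmax` are hypothetical names, and "threads" are simulated by splitting the row into chunks. `push` is the Pass 1 update rule; `merge` combines two per-thread partials by rescaling both sums to the larger maximum, which is the reduction needed before Pass 2.

```rust
/// Running (max, sum) pair kept by each (simulated) thread in Pass 1.
#[derive(Clone, Copy, Debug)]
struct MaxSum {
    max: f32,
    sum: f32,
}

impl MaxSum {
    /// State before any element has been seen.
    const EMPTY: MaxSum = MaxSum { max: f32::NEG_INFINITY, sum: 0.0 };

    /// Pass 1 update: sum_new = sum_old * exp(max_old - max_new) + exp(x - max_new).
    fn push(self, x: f32) -> MaxSum {
        let max_new = self.max.max(x);
        MaxSum {
            max: max_new,
            // On the first element max_old = -inf, so the rescale factor is exp(-inf) = 0.
            sum: self.sum * (self.max - max_new).exp() + (x - max_new).exp(),
        }
    }

    /// Combine two per-thread partials by rescaling both sums to the larger max.
    fn merge(self, other: MaxSum) -> MaxSum {
        // An empty partial contributes nothing (and would produce -inf - -inf = NaN).
        if self.max == f32::NEG_INFINITY {
            return other;
        }
        if other.max == f32::NEG_INFINITY {
            return self;
        }
        let m = self.max.max(other.max);
        MaxSum {
            max: m,
            sum: self.sum * (self.max - m).exp() + other.sum * (other.max - m).exp(),
        }
    }
}

/// Softmax of `x` into `y`, simulating `threads` workers that each scan one chunk.
fn online_softmax(x: &[f32], y: &mut [f32], threads: usize) {
    assert_eq!(x.len(), y.len());
    let threads = threads.max(1);
    let chunk = x.len().div_ceil(threads).max(1);
    // Pass 1: each "thread" folds its chunk, then the partials are reduced.
    let global = x
        .chunks(chunk)
        .map(|c| c.iter().fold(MaxSum::EMPTY, |s, &v| s.push(v)))
        .fold(MaxSum::EMPTY, MaxSum::merge);
    // Pass 2: exp(x - global_max) / global_sum per element.
    for (out, &v) in y.iter_mut().zip(x) {
        *out = (v - global.max).exp() / global.sum;
    }
}

fn main() {
    let x = [1.0f32, 3.0, 2.0, 5.0, -1.0, 0.5, 4.0];
    let mut y = [0.0f32; 7];
    online_softmax(&x, &mut y, 3);
    println!("{:?}", y);
}
```

Because `merge` is associative (up to floating-point rounding), the partials can be reduced in any tree order, which is what lets the per-thread formulation run in parallel.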

ascend-rs Kernel Source

Softmax in ascend-rs uses the buffer API for element-wise backends and the tile API for matrix-oriented backends:

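Scalar kernel (f32, benchmarked implementation). It uses the straightforward three-pass scheme (max, exp+sum, normalize) rather than the online algorithm above: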

#[ascend_std::aiv_kernel]
pub unsafe fn softmax(input: *const f32, output: *mut f32, len: *const u32) {
    let n = *len as usize;
    // An empty row has no first element to seed the max with.
    if n == 0 { return; }

    // Step 1: Find max for numerical stability
    let mut max_val = *input;
    let mut i = 1usize;
    loop {
        if i >= n { break; }
        let val = *input.wrapping_add(i);
        if val > max_val { max_val = val; }
        i += 1;
    }

    // Step 2: exp(x - max) and accumulate sum
    let mut sum: f32 = 0.0;
    i = 0;
    loop {
        if i >= n { break; }
        let exp_val = (*input.wrapping_add(i) - max_val).exp();
        *output.wrapping_add(i) = exp_val;
        sum += exp_val;
        i += 1;
    }

    // Step 3: Normalize (reads back the exp values written in Step 2)
    i = 0;
    loop {
        if i >= n { break; }
        *output.wrapping_add(i) = *output.wrapping_add(i) / sum;
        i += 1;
    }
}
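For comparison, a two-pass variant of the same scalar kernel using the online update might look like the sketch below. It is hypothetical and has not been benchmarked: the name `softmax_online` is made up, and the `#[ascend_std::aiv_kernel]` attribute is omitted so that it compiles as ordinary host Rust. It reads the input twice and writes the output once, instead of reading the input twice and then reading back and rewriting the output.

```rust
/// Hypothetical two-pass (online) variant of the scalar kernel above.
/// Host-side sketch: the #[ascend_std::aiv_kernel] attribute is omitted.
///
/// # Safety
/// `len` must be readable, and `input` and `output` must each be valid
/// for `*len` consecutive f32 values.
pub unsafe fn softmax_online(input: *const f32, output: *mut f32, len: *const u32) {
    let n = *len as usize;
    if n == 0 {
        return;
    }

    // Pass 1: a single traversal keeps the running (max, sum) pair.
    let mut max_val = *input;
    let mut sum: f32 = 1.0; // exp(x0 - x0) for the first element
    let mut i = 1usize;
    while i < n {
        let x = *input.add(i);
        if x > max_val {
            // New maximum: rescale the old sum, then add exp(x - x) = 1.
            sum = sum * (max_val - x).exp() + 1.0;
            max_val = x;
        } else {
            // Maximum unchanged: the rescale factor is exp(0) = 1.
            sum += (x - max_val).exp();
        }
        i += 1;
    }

    // Pass 2: exp(x - max) / sum for every element.
    i = 0;
    while i < n {
        *output.add(i) = (*input.add(i) - max_val).exp() / sum;
        i += 1;
    }
}

fn main() {
    let x = [0.5f32, -1.0, 2.0, 0.0, 2.0];
    let mut y = [0.0f32; 5];
    let n = x.len() as u32;
    // Safety: both buffers hold exactly n elements.
    unsafe { softmax_online(x.as_ptr(), y.as_mut_ptr(), &n) };
    println!("{:?}", y);
}
```

Branching on a new maximum skips one exp call per element in the common case where the maximum does not change; this is a choice made for the sketch, not a claim about what runs fastest on Ascend AIV.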

Tile API — safe entry form (lowered by rustc_codegen_mlir to all 9 backends: Ascend AIV, CUDA, Apple Metal, Vulkan SPIR-V, AWS NKI, AMD AIE, Cambricon BANG, Intel Gaudi, Google TPU):

use ascend_std::tile::{GmView, GmViewMut, safe, tile_load_view_f32, tile_store_view_f32};

#[ascend_std::aiv_kernel]
pub fn tile_softmax(
    input:  GmView<'_, 1, 1024, f32>,
    output: GmViewMut<'_, 1, 1024, f32>,
) {
    let x = tile_load_view_f32(&input);
    let y = safe::tile_softmax_f32(x);
    tile_store_view_f32(&output, y);
}

The kernel body is pure safe Rust, with no unsafe blocks. The shape is committed at the type level: rows and columns are const generics and the element type (f32) is a type parameter, so a host-side shape or dtype mismatch becomes a compile-time error. The #[aiv_kernel] attribute rewrites the emitted signature back to raw *const f32 / *mut f32, so the launcher/compiler toolchain (bisheng / ACL / nvcc) sees the same C ABI. Because GmView/GmViewMut are #[repr(transparent)], this rewrite is free at the LLVM IR level: the two forms are emitted as literal symbol aliases.
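The type-level shape and the zero-cost signature rewrite can be illustrated with a stand-in type. The definition of GmView is not shown here, so `GmViewSketch`, its fields, `LEN`, and `from_rows` are all hypothetical; the sketch only demonstrates how const generics pin the shape and how `#[repr(transparent)]` keeps the layout identical to a bare `*const T`.

```rust
use std::marker::PhantomData;
use std::mem::size_of;

/// Hypothetical stand-in for a read-only view like GmView<'a, R, C, T>.
/// R and C are const generics; T is the element type parameter.
/// repr(transparent) gives it the same layout and ABI as its one non-zero-sized
/// field, the raw pointer, so passing it costs the same as passing *const T.
#[repr(transparent)]
struct GmViewSketch<'a, const R: usize, const C: usize, T> {
    ptr: *const T,
    _borrow: PhantomData<&'a [T]>, // zero-sized: ties the view to the borrowed data
}

impl<'a, const R: usize, const C: usize, T> GmViewSketch<'a, R, C, T> {
    /// Number of elements, known at compile time.
    const LEN: usize = R * C;

    /// Only an array of exactly R rows of C elements is accepted.
    fn from_rows(data: &'a [[T; C]; R]) -> Self {
        GmViewSketch { ptr: data.as_ptr() as *const T, _borrow: PhantomData }
    }

    /// The raw pointer that a C-ABI launcher would receive.
    fn as_ptr(&self) -> *const T {
        self.ptr
    }
}

fn main() {
    let host = [[0.0f32; 1024]; 1];
    let view: GmViewSketch<'_, 1, 1024, f32> = GmViewSketch::from_rows(&host);
    // A [[f32; 512]; 1] array would be rejected here at compile time.
    assert_eq!(size_of::<GmViewSketch<'static, 1, 1024, f32>>(), size_of::<*const f32>());
    println!("{} elements at {:p}", GmViewSketch::<1, 1024, f32>::LEN, view.as_ptr());
}
```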

Kernels compile via rustc_codegen_mlir → MLIR → target-specific code. Softmax is one of the four “hot path” tile ops (alongside matmul, rms-norm and silu) that are lowered on every backend currently targeted.

Benchmark configurations

| Shape | Elements | Bytes (f32) | Notes |
|---|---|---|---|
| (1, 1024) | 1K | 4 KB | L1-resident, tests dispatch overhead |
| (64, 1024) | 64K | 256 KB | L2-resident, typical batch |
| (64, 4096) | 256K | 1 MB | Bandwidth-bound regime |
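The per-element access counts (4 for the naive 3-pass algorithm, 3 for the online one) translate into bytes for these shapes as follows. This is an idealized model that assumes every access reaches global memory once, with no cache reuse between passes; the L1- and L2-resident shapes will not behave that way, so treat the numbers as an estimate most relevant to the bandwidth-bound (64, 4096) case. The helper `traffic_bytes` is hypothetical.

```rust
/// Idealized global-memory traffic for one softmax call, assuming every access
/// goes to memory once (no cache reuse between passes) and 4-byte f32 elements.
fn traffic_bytes(rows: usize, cols: usize, accesses_per_element: usize) -> usize {
    rows * cols * accesses_per_element * std::mem::size_of::<f32>()
}

fn main() {
    // (rows, cols) from the benchmark table.
    let shapes = [(1usize, 1024usize), (64, 1024), (64, 4096)];
    for (r, c) in shapes {
        let naive = traffic_bytes(r, c, 4); // 3 reads + 1 write per element
        let online = traffic_bytes(r, c, 3); // 2 reads + 1 write per element
        println!("({r}, {c}): naive {naive} B, online {online} B");
    }
}
```

For the (64, 4096) configuration this model gives 4 MiB for the naive algorithm versus 3 MiB for the online one per call.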

Results

See the Leaderboard, filtered to Softmax, for the full results.