VQ Quantize + EMA Update

Category: Quantization | Complexity: O(N·K·D) | Fusion: L2 distance + argmin + scatter-add

Algorithm

Vector quantization maps each input vector to its nearest codebook entry, then updates the codebook via an exponential moving average (EMA). It is used in VQ-VAE training (SOKE, Jukebox, SoundStream).

Pipeline:

  1. L2 distance: For each input vector x[i] (dim D), compute ||x[i] - c[k]||² against all K codebook entries
  2. Argmin: Find nearest codebook entry k* = argmin_k ||x[i] - c[k]||²
  3. Quantize: Output q[i] = c[k*] (the nearest codebook vector)
  4. EMA scatter-add: Accumulate x[i] into codebook slot k* for EMA update: sum[k*] += x[i], count[k*] += 1
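Step 4 only accumulates sum[k] and count[k]; folding them into the codebook happens afterwards. A minimal host-side sketch of one common way to do that fold is given here, assuming a decay factor and a small epsilon guard. The function name `ema_update`, the `decay` and `eps` parameters, and the running `ema_sum` / `ema_count` arrays are illustrative assumptions, not part of the ascend-rs API or of the kernel on this page.

```rust
/// Fold one batch of accumulators into the codebook with an EMA.
/// Hypothetical helper: `decay` and `eps` are illustrative hyperparameters.
pub fn ema_update(
    codebook: &mut [f32],  // (K, D) codebook, updated in place
    ema_sum: &mut [f32],   // (K, D) running EMA of per-code vector sums
    ema_count: &mut [f32], // (K,)   running EMA of per-code assignment counts
    batch_sum: &[f32],     // (K, D) sum[k] accumulated for this batch
    batch_count: &[u32],   // (K,)   count[k] accumulated for this batch
    d: usize,              // vector dimension D
    decay: f32,
    eps: f32,
) {
    for k in 0..batch_count.len() {
        // Smooth the assignment count for code k.
        ema_count[k] = decay * ema_count[k] + (1.0 - decay) * batch_count[k] as f32;
        for j in 0..d {
            let idx = k * d + j;
            // Smooth the sum of vectors assigned to code k.
            ema_sum[idx] = decay * ema_sum[idx] + (1.0 - decay) * batch_sum[idx];
        }
        // Codes with (almost) no assignments keep their previous value,
        // which avoids dividing by a near-zero count.
        if ema_count[k] > eps {
            for j in 0..d {
                let idx = k * d + j;
                codebook[idx] = ema_sum[idx] / ema_count[k];
            }
        }
    }
}
```

With `decay = 0.5`, a code that received two vectors summing to (2, 4) from a zero-initialized state ends up at (1, 2), the mean of its assigned vectors, while a code with no assignments is left unchanged.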

Fusing all four steps into one kernel eliminates three intermediate buffers: the N×K distance matrix, the N-element index array, and the scatter workspace.
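The fused loop can be sketched on the host in plain Rust, which is handy as a reference when checking a device kernel's output. This is an illustrative sketch only: `vq_quantize_ref` is a hypothetical name, it uses slices rather than the device buffer API, and it returns the chosen indices purely for testing.

```rust
/// Host-side reference for the four pipeline steps, fused per input vector.
/// Hypothetical helper; not part of ascend-rs.
pub fn vq_quantize_ref(
    input: &[f32],        // (N, D) input vectors
    codebook: &[f32],     // (K, D) codebook
    d: usize,             // vector dimension D
    output: &mut [f32],   // (N, D) quantized output
    cb_sum: &mut [f32],   // (K, D) EMA numerator accumulator
    cb_count: &mut [u32], // (K,)   EMA denominator counter
) -> Vec<usize> {
    let n = input.len() / d;
    let k = codebook.len() / d;
    let mut indices = Vec::with_capacity(n);
    for i in 0..n {
        let x = &input[i * d..(i + 1) * d];

        // Steps 1 and 2: squared L2 distance and running argmin.
        // Only the best distance so far is kept, never a full N×K matrix.
        let mut best_k = 0usize;
        let mut best_dist = f32::MAX;
        for j in 0..k {
            let c = &codebook[j * d..(j + 1) * d];
            let dist: f32 = x.iter().zip(c).map(|(a, b)| (a - b) * (a - b)).sum();
            if dist < best_dist {
                best_dist = dist;
                best_k = j;
            }
        }

        // Step 3: emit the nearest codebook vector.
        output[i * d..(i + 1) * d].copy_from_slice(&codebook[best_k * d..(best_k + 1) * d]);

        // Step 4: scatter-add into the EMA accumulators.
        for (s, v) in cb_sum[best_k * d..(best_k + 1) * d].iter_mut().zip(x) {
            *s += *v;
        }
        cb_count[best_k] += 1;
        indices.push(best_k);
    }
    indices
}
```

Ties resolve to the lowest index because the comparison is strict, which matches the `dist < best_dist` test in the kernel on this page.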

ascend-rs Kernel Source

VQ quantize kernel using ascend-rs buffer API (f32):

#![allow(unused)]
fn main() {
/// VQ Quantize: for each input vector, find nearest codebook entry (L2),
/// output the quantized vector, and scatter-add for EMA codebook update.
///
/// params: [n_vectors: u32, n_codes: u32, dim: u32]
#[ascend_std::aiv_kernel]
pub fn vq_quantize(
    input: *const f32,      // (N, D) input vectors
    codebook: *const f32,   // (K, D) codebook
    output: *mut f32,       // (N, D) quantized output
    cb_sum: *mut f32,       // (K, D) EMA numerator accumulator
    cb_count: *mut u32,     // (K,)   EMA denominator counter
    params: *const u32,
) {
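    // SAFETY: `params` must point to three readable u32 values: [n_vectors, n_codes, dim].
    let n = unsafe { *params };                      // number of input vectors
    let k = unsafe { *params.wrapping_add(1) };      // codebook size
    let d = unsafe { *params.wrapping_add(2) };      // vector dimension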

    let buf_x = ascend_std::ascend_buf_alloc(d);     // current input vector
    let buf_c = ascend_std::ascend_buf_alloc(d);     // current codebook entry
    let buf_diff = ascend_std::ascend_buf_alloc(d);  // x - c
    let buf_work = ascend_std::ascend_buf_alloc(d);
    let buf_rwork = ascend_std::ascend_buf_alloc(d);

    let mut i: u32 = 0;
    while i < n {
        // Load input vector x[i]
        let x_ptr = input.wrapping_add((i * d) as usize);
        ascend_std::ascend_buf_load_f32(buf_x, x_ptr, d);
        ascend_std::ascend_pipe_barrier();

        // Find nearest codebook entry (L2 argmin)
        let mut best_k: u32 = 0;
        let mut best_dist: f32 = f32::MAX;

        let mut j: u32 = 0;
        while j < k {
            let c_ptr = codebook.wrapping_add((j * d) as usize);
            ascend_std::ascend_buf_load_f32(buf_c, c_ptr, d);
            ascend_std::ascend_pipe_barrier();

            // diff = x - c
            ascend_std::ascend_sub_f32(buf_diff, buf_x, buf_c, d);
            ascend_std::ascend_pipe_barrier();
            // diff² = diff * diff
            ascend_std::ascend_mul_f32(buf_diff, buf_diff, buf_diff, d);
            ascend_std::ascend_pipe_barrier();
            // dist = sum(diff²)
            let dist = ascend_std::ascend_reduce_sum_f32(
                buf_work, buf_diff, buf_rwork, d);

            if dist < best_dist {
                best_dist = dist;
                best_k = j;
            }
            j += 1;
        }

        // Output: quantized = codebook[best_k]
        let best_ptr = codebook.wrapping_add((best_k * d) as usize);
        ascend_std::ascend_buf_load_f32(buf_c, best_ptr, d);
        ascend_std::ascend_pipe_barrier();
        let out_ptr = output.wrapping_add((i * d) as usize);
        ascend_std::ascend_buf_store_f32(out_ptr, buf_c, d);

        // EMA scatter-add: cb_sum[best_k] += x[i], cb_count[best_k] += 1
        let sum_ptr = cb_sum.wrapping_add((best_k * d) as usize);
        // Reuse buf_diff as scratch instead of allocating a new buffer on every
        // iteration; it is free once the argmin loop above has finished.
        ascend_std::ascend_buf_load_f32(buf_diff, sum_ptr, d);
        ascend_std::ascend_pipe_barrier();
        ascend_std::ascend_add_f32(buf_diff, buf_diff, buf_x, d);
        ascend_std::ascend_pipe_barrier();
        ascend_std::ascend_buf_store_f32(sum_ptr, buf_diff, d);

        // SAFETY: best_k < k whenever k > 0, so the slot lies inside the (K,) counter array.
        unsafe { *cb_count.wrapping_add(best_k as usize) += 1 };

        i += 1;
    }
}
}

This buffer-API kernel runs on the Ascend AIV backend via rustc_codegen_mlir and avoids materializing the N×K distance matrix and the N-element index array. No tile-API safe::tile_vq_quantize_f32 exists yet; tile-API lowerings for all 9 backends (Ascend AIV, CUDA, Apple Metal, Vulkan SPIR-V, AWS NKI, AMD AIE, Cambricon BANG, Intel Gaudi, Google TPU) are future work. Cross-backend VQ today uses vendor kernels (aclnnMatmul, MPS GEMM, torch.cdist) followed by a separate argmin pass, not the fused Rust kernel shown above.

Benchmark configurations

| Shape (N, K, D) | FLOPs | Notes |
| --- | --- | --- |
| (256, 512, 64) | 16.8 M | Small codebook, low-latency inference |
| (1024, 512, 64) | 67.1 M | Typical VQ-VAE batch |
| (1024, 1024, 128) | 268 M | Large codebook, high-dim embeddings |
| (4096, 512, 64) | 268 M | Large batch training |

All benchmarks use f32.
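The FLOPs figures are consistent with counting 2·N·K·D per shape, that is, one multiply and one add per element of every input-to-code distance. That reading is inferred from the table values rather than stated by the source. A small sketch of the arithmetic, with hypothetical helper names `vq_flops` and `gflops`:

```rust
/// FLOP count under the assumed 2·N·K·D convention
/// (one multiply and one add per element of each distance).
pub fn vq_flops(n: u64, k: u64, d: u64) -> u64 {
    2 * n * k * d
}

/// Convert a FLOP count and a latency in microseconds to GFLOPS.
pub fn gflops(flops: u64, latency_us: f64) -> f64 {
    let seconds = latency_us * 1e-6;
    flops as f64 / seconds / 1e9
}
```

For example, `vq_flops(256, 512, 64)` gives 16,777,216 (16.8 M), and a (4096, 512, 64) run that takes 43 μs works out to roughly 6,243 GFLOPS.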

Results

| Device | Shape (N, K, D) | Latency (μs) | GFLOPS | Notes |
| --- | --- | --- | --- | --- |
| Ascend 910B | (4096, 1024, 128) | 94 | 11,411 | aclnnMatmul L2 trick |
| Ascend 910B | (1024, 1024, 128) | 31 | 8,604 | Large codebook |
| Ascend 910B | (4096, 512, 64) | 43 | 6,243 | Large batch |
| Apple M2 Max | (4096, 1024, 128) | 646 | 1,662 | MPS GEMM + CPU argmin |
| Apple M2 Max | (8192, 512, 64) | 450 | 1,193 | Large batch |
| Tesla T4 | (4096, 1024, 128) | 1,163 | 923 | torch.cdist + argmin |
| Tesla T4 | (1024, 1024, 128) | 494 | 544 | torch.cdist + argmin |
| Tesla T4 | (4096, 512, 64) | 624 | 430 | torch.cdist + argmin |

Peak: 11.4 TFLOPS on Ascend 910B (cube engine via L2 distance matmul trick). Apple M2 Max peaks at 1.7 TFLOPS via MPS. Tesla T4 peaks at 923 GFLOPS via torch.cdist.
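The "L2 distance matmul trick" presumably refers to the standard expansion ||x - c||² = ||x||² - 2·x·c + ||c||². Because ||x||² is the same for every code, the argmin over k only needs ||c[k]||² - 2·x·c[k], and the x·c terms for all pairs form one (N, D) × (D, K) matrix product that a GEMM engine can compute. The sketch below illustrates the identity in plain Rust; `nearest_via_matmul` is a hypothetical name, the naive loop stands in for a vendor GEMM, and how the benchmarked paths actually wire this up is not described in the source.

```rust
/// Nearest-code lookup using the expansion
/// ||x - c||² = ||x||² - 2·x·c + ||c||², dropping the constant ||x||².
/// Hypothetical helper; the triple loop stands in for a vendor GEMM.
pub fn nearest_via_matmul(input: &[f32], codebook: &[f32], d: usize) -> Vec<usize> {
    let n = input.len() / d;
    let k = codebook.len() / d;

    // ||c[k]||², computed once per codebook.
    let c_norm: Vec<f32> = codebook
        .chunks_exact(d)
        .map(|c| c.iter().map(|v| v * v).sum::<f32>())
        .collect();

    // xc = X · Cᵀ, an (N, K) matrix of dot products.
    // Unlike the fused kernel, this path materializes the full matrix.
    let mut xc = vec![0.0f32; n * k];
    for i in 0..n {
        for j in 0..k {
            let x = &input[i * d..(i + 1) * d];
            let c = &codebook[j * d..(j + 1) * d];
            xc[i * k + j] = x.iter().zip(c).map(|(a, b)| a * b).sum::<f32>();
        }
    }

    // Separate argmin pass over score[i][j] = ||c[j]||² - 2·x[i]·c[j].
    (0..n)
        .map(|i| {
            let mut best = 0usize;
            let mut best_score = f32::MAX;
            for j in 0..k {
                let score = c_norm[j] - 2.0 * xc[i * k + j];
                if score < best_score {
                    best_score = score;
                    best = j;
                }
            }
            best
        })
        .collect()
}
```

The trade-off is that the distance work moves onto matrix hardware at the cost of an N×K intermediate and a second pass, which is the opposite choice from the fused buffer-API kernel.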

See Leaderboard filtered to VQ Quantize for the full filterable view.