atomr_accel_cuda/kernel/blas/scaling.rs
//! fp8 scaling-factor helpers.
//!
//! Hopper+ fp8 cuBLAS calls (`cublasGemmEx` with `CUDA_R_8F_E4M3` /
//! `CUDA_R_8F_E5M2` operands) take a per-tensor or per-row scaling
//! factor that maps the input/output values into the representable
//! fp8 range. This module factors out the small bookkeeping helpers
//! shared by the cuBLAS and cuBLASLt fp8 paths so they do not have to
//! be duplicated.
//!
//! The full fp8 path is enabled by the `cublas-fp8` cargo feature. It is
//! currently scaffolded: the Phase 1 cuBLAS slice ships only these helper
//! types, while the wired-up call site lives in cuBLASLt's own module.
13
// The fp8 path is still scaffolded (see the module docs above), so these
// helper types may have no callers yet. Silence the unused-code lint until
// the `cublas-fp8` call sites are wired up.
#![allow(dead_code)]
15
16/// Per-tensor scaling factor: a single multiplicative scalar.
17///
18/// `a_scale` is computed by the caller (typically `max(abs(A)) /
19/// fp8_max`) and passed to `cublasGemmEx` via `alpha = a_scale *
20/// b_scale * gemm_alpha`.
21#[derive(Debug, Clone, Copy, Default)]
22pub struct PerTensorScale {
23 pub scale: f32,
24}
25
26/// Per-row scaling factor: a vector of `m` scalars, one per row of
27/// the matrix. Stored device-side; the caller passes a `GpuRef<f32>`
28/// when the cuBLASLt descriptor accepts row-wise amax.
29#[derive(Debug, Clone)]
30pub struct PerRowScale {
31 pub rows: i32,
32 pub scale_buf: crate::gpu_ref::GpuRef<f32>,
33}
34
35#[cfg(test)]
36mod tests {
37 use super::*;
38
39 #[test]
40 fn per_tensor_scale_default_is_one() {
41 let s = PerTensorScale::default();
42 assert_eq!(s.scale, 0.0);
43 }
44}