Skip to main content

atomr_accel_cuda/kernel/blas/
scaling.rs

//! fp8 scaling-factor helpers.
//!
//! Hopper+ fp8 cuBLAS calls (`cublasGemmEx` with `CUDA_R_8F_E4M3` /
//! `CUDA_R_8F_E5M2` operands) take a per-tensor or per-row scaling
//! factor that brings the input/output values into the representable
//! fp8 range. This module factors out the small bookkeeping helpers
//! used by both the cuBLAS and cuBLASLt fp8 paths so they don't have
//! to be duplicated.
//!
//! The full fp8 path lights up under the `cublas-fp8` cargo feature.
//! It is currently scaffolded: the Phase 1 cuBLAS slice ships only the
//! helper types, and the wired call site lives in cuBLASLt's own module.

14#![allow(dead_code)]
15
16/// Per-tensor scaling factor: a single multiplicative scalar.
17///
18/// `a_scale` is computed by the caller (typically `max(abs(A)) /
19/// fp8_max`) and passed to `cublasGemmEx` via `alpha = a_scale *
20/// b_scale * gemm_alpha`.
21#[derive(Debug, Clone, Copy, Default)]
22pub struct PerTensorScale {
23    pub scale: f32,
24}
25
26/// Per-row scaling factor: a vector of `m` scalars, one per row of
27/// the matrix. Stored device-side; the caller passes a `GpuRef<f32>`
28/// when the cuBLASLt descriptor accepts row-wise amax.
29#[derive(Debug, Clone)]
30pub struct PerRowScale {
31    pub rows: i32,
32    pub scale_buf: crate::gpu_ref::GpuRef<f32>,
33}
34
35#[cfg(test)]
36mod tests {
37    use super::*;
38
39    #[test]
40    fn per_tensor_scale_default_is_one() {
41        let s = PerTensorScale::default();
42        assert_eq!(s.scale, 0.0);
43    }
44}