Skip to main content

atomr_accel_cuda/kernel/blas_lt/
epilogue.rs

1//! `Epilogue` enum — atomr-accel's curated mapping over cuBLASLt's
2//! `cublasLtEpilogue_t`.
3//!
4//! cuBLASLt fuses post-matmul ops (bias add, activation, gradient
5//! aux/preact) into the kernel itself. The full set is large and
6//! version-dependent; we expose the variants that matter for
7//! transformer training/inference: bias, ReLU/GeLU forward + aux,
//! ReLU/GeLU backward (`drelu`/`dgelu`) plus their bias-gradient forms (the
//! pure-backward variants are currently lowered to the `_BGRAD` epilogues),
9//! and the `BGRADA`/`BGRADB` reduction-only variants used by mixed
10//! optimizer/data-parallel pipelines.
11//!
12//! Cache key compatibility: `Epilogue` derives `Hash + Eq` so
13//! `HeuristicKey` can fold it into the `(m,n,k,dtype,layout,
14//! epilogue,arch)` cache without a custom `impl`.
15
16use cudarc::cublaslt::sys::cublasLtEpilogue_t;
17
/// Curated epilogue matrix.
///
/// Each variant is translated to cuBLASLt's `cublasLtEpilogue_t` by
/// [`Epilogue::to_cublas`]. The `#[repr(u32)]` discriminants are simply the
/// declaration order (`None` is `0`) and do **not** match cuBLASLt's values,
/// so never cast an `Epilogue` with `as u32` to obtain a cuBLASLt epilogue.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u32)]
pub enum Epilogue {
    /// No fused activation, no bias. Identity epilogue.
    None,
    /// Fused ReLU.
    Relu,
    /// Bias-add only.
    Bias,
    /// Bias-add + ReLU.
    ReluBias,
    /// ReLU forward that also stores an "aux" tensor (cuBLASLt writes a ReLU
    /// mask) for the matching [`Epilogue::DRelu`] / [`Epilogue::DReluBgrad`]
    /// backward pass.
    ReluAux,
    /// Bias-add + ReLU forward + aux storage.
    ReluAuxBias,
    /// Fused GeLU.
    Gelu,
    /// GeLU forward that also stores the GeLU input (the pre-activation) as an
    /// aux tensor for the matching backward pass.
    GeluAux,
    /// Bias-add + GeLU.
    GeluBias,
    /// Bias-add + GeLU forward + aux storage.
    GeluAuxBias,
    /// ReLU backward (`drelu`): gradient w.r.t. the pre-activation, reading the
    /// aux tensor written by a `ReluAux*` forward.
    ///
    /// Currently lowered to `CUBLASLT_EPILOGUE_DRELU_BGRAD`, so the kernel also
    /// writes a bias gradient even though this variant does not ask for one.
    DRelu,
    /// ReLU backward + reduce-sum producing the bias gradient.
    DReluBgrad,
    /// GeLU backward (`dgelu`), reading the aux tensor written by a `GeluAux*`
    /// forward.
    ///
    /// Currently lowered to `CUBLASLT_EPILOGUE_DGELU_BGRAD`, so the kernel also
    /// writes a bias gradient even though this variant does not ask for one.
    DGelu,
    /// GeLU backward + reduce-sum producing the bias gradient.
    DGeluBgrad,
    /// Bias gradient only (no activation) for the `A` operand: one value per
    /// row of the output `D`, reduced over the GEMM's `k` dimension.
    BgradA,
    /// Bias gradient only (no activation) for the `B` operand: one value per
    /// column of the output `D`, reduced over the GEMM's `k` dimension.
    BgradB,
}
57
58impl Epilogue {
59    /// Map to the cuBLASLt sys-level enum value.
60    pub fn to_cublas(self) -> cublasLtEpilogue_t {
61        match self {
62            Self::None => cublasLtEpilogue_t::CUBLASLT_EPILOGUE_DEFAULT,
63            Self::Relu => cublasLtEpilogue_t::CUBLASLT_EPILOGUE_RELU,
64            Self::Bias => cublasLtEpilogue_t::CUBLASLT_EPILOGUE_BIAS,
65            Self::ReluBias => cublasLtEpilogue_t::CUBLASLT_EPILOGUE_RELU_BIAS,
66            Self::ReluAux => cublasLtEpilogue_t::CUBLASLT_EPILOGUE_RELU_AUX,
67            Self::ReluAuxBias => cublasLtEpilogue_t::CUBLASLT_EPILOGUE_RELU_AUX_BIAS,
68            Self::Gelu => cublasLtEpilogue_t::CUBLASLT_EPILOGUE_GELU,
69            Self::GeluAux => cublasLtEpilogue_t::CUBLASLT_EPILOGUE_GELU_AUX,
70            Self::GeluBias => cublasLtEpilogue_t::CUBLASLT_EPILOGUE_GELU_BIAS,
71            Self::GeluAuxBias => cublasLtEpilogue_t::CUBLASLT_EPILOGUE_GELU_AUX_BIAS,
72            // `CUBLASLT_EPILOGUE_DRELU`/`_DGELU` (pure backward, no
73            // bias-grad reduction) are only present on cudarc's
74            // CUDA ≥ 11.6 cfg branch, which we can't easily detect
75            // from this crate. Map both pure-backward variants to the
76            // BGRAD form — callers wanting the pure-backward output
77            // simply ignore the bias-grad side output.
78            Self::DRelu | Self::DReluBgrad => cublasLtEpilogue_t::CUBLASLT_EPILOGUE_DRELU_BGRAD,
79            Self::DGelu | Self::DGeluBgrad => cublasLtEpilogue_t::CUBLASLT_EPILOGUE_DGELU_BGRAD,
80            Self::BgradA => cublasLtEpilogue_t::CUBLASLT_EPILOGUE_BGRADA,
81            Self::BgradB => cublasLtEpilogue_t::CUBLASLT_EPILOGUE_BGRADB,
82        }
83    }
84
85    /// Does this epilogue read or write a bias vector?
86    pub fn uses_bias(self) -> bool {
87        matches!(
88            self,
89            Self::Bias | Self::ReluBias | Self::ReluAuxBias | Self::GeluBias | Self::GeluAuxBias
90        )
91    }
92
93    /// Does this epilogue store or consume an `epilogue_aux` tensor
94    /// (the activation preact / mask used by the matching backward)?
95    pub fn uses_aux(self) -> bool {
96        matches!(
97            self,
98            Self::ReluAux
99                | Self::ReluAuxBias
100                | Self::GeluAux
101                | Self::GeluAuxBias
102                | Self::DRelu
103                | Self::DReluBgrad
104                | Self::DGelu
105                | Self::DGeluBgrad
106        )
107    }
108
109    /// Does this epilogue produce a bias gradient as a side output
110    /// (`BGRADA`, `BGRADB`, or `D{Relu,Gelu}_BGRAD`)?
111    pub fn produces_bias_grad(self) -> bool {
112        matches!(
113            self,
114            Self::BgradA | Self::BgradB | Self::DReluBgrad | Self::DGeluBgrad
115        )
116    }
117}
118
#[cfg(test)]
mod tests {
    use super::*;

    /// Every variant paired with the `cublasLtEpilogue_t` that
    /// [`Epilogue::to_cublas`] must return. This includes the pure-backward
    /// `DRelu`/`DGelu` variants, which are deliberately lowered to their
    /// `_BGRAD` forms.
    fn expected_mappings() -> [(Epilogue, cublasLtEpilogue_t); 16] {
        [
            (Epilogue::None, cublasLtEpilogue_t::CUBLASLT_EPILOGUE_DEFAULT),
            (Epilogue::Relu, cublasLtEpilogue_t::CUBLASLT_EPILOGUE_RELU),
            (Epilogue::Bias, cublasLtEpilogue_t::CUBLASLT_EPILOGUE_BIAS),
            (Epilogue::ReluBias, cublasLtEpilogue_t::CUBLASLT_EPILOGUE_RELU_BIAS),
            (Epilogue::ReluAux, cublasLtEpilogue_t::CUBLASLT_EPILOGUE_RELU_AUX),
            (Epilogue::ReluAuxBias, cublasLtEpilogue_t::CUBLASLT_EPILOGUE_RELU_AUX_BIAS),
            (Epilogue::Gelu, cublasLtEpilogue_t::CUBLASLT_EPILOGUE_GELU),
            (Epilogue::GeluAux, cublasLtEpilogue_t::CUBLASLT_EPILOGUE_GELU_AUX),
            (Epilogue::GeluBias, cublasLtEpilogue_t::CUBLASLT_EPILOGUE_GELU_BIAS),
            (Epilogue::GeluAuxBias, cublasLtEpilogue_t::CUBLASLT_EPILOGUE_GELU_AUX_BIAS),
            (Epilogue::DRelu, cublasLtEpilogue_t::CUBLASLT_EPILOGUE_DRELU_BGRAD),
            (Epilogue::DReluBgrad, cublasLtEpilogue_t::CUBLASLT_EPILOGUE_DRELU_BGRAD),
            (Epilogue::DGelu, cublasLtEpilogue_t::CUBLASLT_EPILOGUE_DGELU_BGRAD),
            (Epilogue::DGeluBgrad, cublasLtEpilogue_t::CUBLASLT_EPILOGUE_DGELU_BGRAD),
            (Epilogue::BgradA, cublasLtEpilogue_t::CUBLASLT_EPILOGUE_BGRADA),
            (Epilogue::BgradB, cublasLtEpilogue_t::CUBLASLT_EPILOGUE_BGRADB),
        ]
    }

    /// Compile-time guard: adding an `Epilogue` variant breaks this
    /// exhaustive match, forcing `expected_mappings` to be extended too.
    fn _assert_every_variant_listed(e: Epilogue) {
        match e {
            Epilogue::None
            | Epilogue::Relu
            | Epilogue::Bias
            | Epilogue::ReluBias
            | Epilogue::ReluAux
            | Epilogue::ReluAuxBias
            | Epilogue::Gelu
            | Epilogue::GeluAux
            | Epilogue::GeluBias
            | Epilogue::GeluAuxBias
            | Epilogue::DRelu
            | Epilogue::DReluBgrad
            | Epilogue::DGelu
            | Epilogue::DGeluBgrad
            | Epilogue::BgradA
            | Epilogue::BgradB => {}
        }
    }

    /// Pin `to_cublas()` for every variant. This is a one-way mapping check
    /// against the table above, not a round trip; a changed mapping must come
    /// with an explicit update to this table.
    #[test]
    fn epilogue_round_trip() {
        let table = expected_mappings();

        // The table lists each of the 16 variants exactly once.
        let distinct: std::collections::HashSet<Epilogue> =
            table.iter().map(|&(variant, _)| variant).collect();
        assert_eq!(distinct.len(), table.len());

        for (variant, want) in table {
            assert_eq!(variant.to_cublas(), want, "{variant:?}");
        }
    }

    #[test]
    fn epilogue_capability_flags() {
        assert!(Epilogue::Bias.uses_bias());
        assert!(!Epilogue::None.uses_bias());
        assert!(Epilogue::ReluBias.uses_bias());
        assert!(Epilogue::GeluAux.uses_aux());
        assert!(Epilogue::DReluBgrad.produces_bias_grad());
        assert!(Epilogue::BgradA.produces_bias_grad());
        assert!(!Epilogue::Relu.produces_bias_grad());

        // The identity epilogue touches no side buffers at all.
        assert!(!Epilogue::None.uses_aux());
        assert!(!Epilogue::None.produces_bias_grad());

        // Backward passes consume the aux tensor saved by the forward pass;
        // reduction-only bias-grad epilogues do not.
        assert!(Epilogue::DRelu.uses_aux());
        assert!(Epilogue::DGelu.uses_aux());
        assert!(Epilogue::DGeluBgrad.uses_aux());
        assert!(!Epilogue::BgradA.uses_aux());
        assert!(!Epilogue::BgradB.uses_aux());

        // No curated epilogue both reads a bias input and writes a bias
        // gradient.
        for (variant, _) in expected_mappings() {
            assert!(
                !(variant.uses_bias() && variant.produces_bias_grad()),
                "{variant:?}"
            );
        }
    }

    #[test]
    fn epilogue_default_is_none_variant() {
        // Sanity: ensure default discriminant equality holds across
        // cudarc cuda-version cfgs (both 11.x and 12.x bindings emit
        // CUBLASLT_EPILOGUE_DEFAULT = 1).
        assert_eq!(
            Epilogue::None.to_cublas() as u32,
            cublasLtEpilogue_t::CUBLASLT_EPILOGUE_DEFAULT as u32
        );
    }
}