atomr_accel_cuda/kernel/blas_lt/epilogue.rs
1//! `Epilogue` enum — atomr-accel's curated mapping over cuBLASLt's
2//! `cublasLtEpilogue_t`.
3//!
4//! cuBLASLt fuses post-matmul ops (bias add, activation, gradient
5//! aux/preact) into the kernel itself. The full set is large and
6//! version-dependent; we expose the variants that matter for
7//! transformer training/inference: bias, ReLU/GeLU forward + aux,
8//! ReLU/GeLU backward (`drelu`/`dgelu`) with optional bias gradient,
9//! and the `BGRADA`/`BGRADB` reduction-only variants used by mixed
10//! optimizer/data-parallel pipelines.
11//!
12//! Cache key compatibility: `Epilogue` derives `Hash + Eq` so
13//! `HeuristicKey` can fold it into the `(m,n,k,dtype,layout,
14//! epilogue,arch)` cache without a custom `impl`.
15
16use cudarc::cublaslt::sys::cublasLtEpilogue_t;
17
/// The curated set of fused cuBLASLt epilogues exposed by atomr-accel.
///
/// Every variant is a plain unit value, so the type is `Copy` and can be
/// compared and hashed directly (which is what lets it sit inside a cache
/// key). The `u32` discriminants are simply declaration-order indices
/// (`None == 0`, `BgradB == 15`); they are not intended to be handed to
/// cuBLASLt as-is.
#[repr(u32)]
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub enum Epilogue {
    // ---- forward, ReLU / bias family -------------------------------
    /// Identity epilogue: neither a bias nor an activation is fused.
    None,
    /// ReLU fused into the matmul.
    Relu,
    /// Only a bias-add is fused.
    Bias,
    /// Bias-add followed by ReLU.
    ReluBias,
    /// ReLU forward that additionally stores the masked-input "aux"
    /// tensor consumed by the matching backward pass.
    ReluAux,
    /// Bias-add + ReLU forward, with aux storage.
    ReluAuxBias,

    // ---- forward, GeLU family --------------------------------------
    /// GeLU fused into the matmul.
    Gelu,
    /// GeLU forward that additionally stores the unactivated preact
    /// for the backward pass.
    GeluAux,
    /// Bias-add followed by GeLU.
    GeluBias,
    /// Bias-add + GeLU forward, with aux storage.
    GeluAuxBias,

    // ---- backward --------------------------------------------------
    /// ReLU backward (`drelu`): gradient with respect to the preact.
    DRelu,
    /// ReLU backward plus a reduce-sum that yields the bias gradient.
    DReluBgrad,
    /// GeLU backward (`dgelu`).
    DGelu,
    /// GeLU backward plus a reduce-sum that yields the bias gradient.
    DGeluBgrad,

    // ---- reduction only --------------------------------------------
    /// Reduction-only epilogue for the `A` (M) side: yields the bias
    /// gradient with no activation applied.
    BgradA,
    /// Reduction-only epilogue for the `B` (N) side.
    BgradB,
}
57
58impl Epilogue {
59 /// Map to the cuBLASLt sys-level enum value.
60 pub fn to_cublas(self) -> cublasLtEpilogue_t {
61 match self {
62 Self::None => cublasLtEpilogue_t::CUBLASLT_EPILOGUE_DEFAULT,
63 Self::Relu => cublasLtEpilogue_t::CUBLASLT_EPILOGUE_RELU,
64 Self::Bias => cublasLtEpilogue_t::CUBLASLT_EPILOGUE_BIAS,
65 Self::ReluBias => cublasLtEpilogue_t::CUBLASLT_EPILOGUE_RELU_BIAS,
66 Self::ReluAux => cublasLtEpilogue_t::CUBLASLT_EPILOGUE_RELU_AUX,
67 Self::ReluAuxBias => cublasLtEpilogue_t::CUBLASLT_EPILOGUE_RELU_AUX_BIAS,
68 Self::Gelu => cublasLtEpilogue_t::CUBLASLT_EPILOGUE_GELU,
69 Self::GeluAux => cublasLtEpilogue_t::CUBLASLT_EPILOGUE_GELU_AUX,
70 Self::GeluBias => cublasLtEpilogue_t::CUBLASLT_EPILOGUE_GELU_BIAS,
71 Self::GeluAuxBias => cublasLtEpilogue_t::CUBLASLT_EPILOGUE_GELU_AUX_BIAS,
72 // `CUBLASLT_EPILOGUE_DRELU`/`_DGELU` (pure backward, no
73 // bias-grad reduction) are only present on cudarc's
74 // CUDA ≥ 11.6 cfg branch, which we can't easily detect
75 // from this crate. Map both pure-backward variants to the
76 // BGRAD form — callers wanting the pure-backward output
77 // simply ignore the bias-grad side output.
78 Self::DRelu | Self::DReluBgrad => cublasLtEpilogue_t::CUBLASLT_EPILOGUE_DRELU_BGRAD,
79 Self::DGelu | Self::DGeluBgrad => cublasLtEpilogue_t::CUBLASLT_EPILOGUE_DGELU_BGRAD,
80 Self::BgradA => cublasLtEpilogue_t::CUBLASLT_EPILOGUE_BGRADA,
81 Self::BgradB => cublasLtEpilogue_t::CUBLASLT_EPILOGUE_BGRADB,
82 }
83 }
84
85 /// Does this epilogue read or write a bias vector?
86 pub fn uses_bias(self) -> bool {
87 matches!(
88 self,
89 Self::Bias | Self::ReluBias | Self::ReluAuxBias | Self::GeluBias | Self::GeluAuxBias
90 )
91 }
92
93 /// Does this epilogue store or consume an `epilogue_aux` tensor
94 /// (the activation preact / mask used by the matching backward)?
95 pub fn uses_aux(self) -> bool {
96 matches!(
97 self,
98 Self::ReluAux
99 | Self::ReluAuxBias
100 | Self::GeluAux
101 | Self::GeluAuxBias
102 | Self::DRelu
103 | Self::DReluBgrad
104 | Self::DGelu
105 | Self::DGeluBgrad
106 )
107 }
108
109 /// Does this epilogue produce a bias gradient as a side output
110 /// (`BGRADA`, `BGRADB`, or `D{Relu,Gelu}_BGRAD`)?
111 pub fn produces_bias_grad(self) -> bool {
112 matches!(
113 self,
114 Self::BgradA | Self::BgradB | Self::DReluBgrad | Self::DGeluBgrad
115 )
116 }
117}
118
#[cfg(test)]
mod tests {
    use super::*;

    /// Pin the exact cuBLASLt value `to_cublas()` returns for every
    /// `Epilogue` variant — including the pure-backward `DRelu`/`DGelu`
    /// variants, which share the `*_BGRAD` values — and check that
    /// `None` is the only variant that maps to `DEFAULT`.
    #[test]
    fn epilogue_round_trip() {
        use std::collections::HashSet;

        let cases = [
            (
                Epilogue::None,
                cublasLtEpilogue_t::CUBLASLT_EPILOGUE_DEFAULT,
            ),
            (Epilogue::Relu, cublasLtEpilogue_t::CUBLASLT_EPILOGUE_RELU),
            (Epilogue::Bias, cublasLtEpilogue_t::CUBLASLT_EPILOGUE_BIAS),
            (
                Epilogue::ReluBias,
                cublasLtEpilogue_t::CUBLASLT_EPILOGUE_RELU_BIAS,
            ),
            (
                Epilogue::ReluAux,
                cublasLtEpilogue_t::CUBLASLT_EPILOGUE_RELU_AUX,
            ),
            (
                Epilogue::ReluAuxBias,
                cublasLtEpilogue_t::CUBLASLT_EPILOGUE_RELU_AUX_BIAS,
            ),
            (Epilogue::Gelu, cublasLtEpilogue_t::CUBLASLT_EPILOGUE_GELU),
            (
                Epilogue::GeluAux,
                cublasLtEpilogue_t::CUBLASLT_EPILOGUE_GELU_AUX,
            ),
            (
                Epilogue::GeluBias,
                cublasLtEpilogue_t::CUBLASLT_EPILOGUE_GELU_BIAS,
            ),
            (
                Epilogue::GeluAuxBias,
                cublasLtEpilogue_t::CUBLASLT_EPILOGUE_GELU_AUX_BIAS,
            ),
            // Pure-backward variants are lowered to the BGRAD form.
            (
                Epilogue::DRelu,
                cublasLtEpilogue_t::CUBLASLT_EPILOGUE_DRELU_BGRAD,
            ),
            (
                Epilogue::DReluBgrad,
                cublasLtEpilogue_t::CUBLASLT_EPILOGUE_DRELU_BGRAD,
            ),
            (
                Epilogue::DGelu,
                cublasLtEpilogue_t::CUBLASLT_EPILOGUE_DGELU_BGRAD,
            ),
            (
                Epilogue::DGeluBgrad,
                cublasLtEpilogue_t::CUBLASLT_EPILOGUE_DGELU_BGRAD,
            ),
            (
                Epilogue::BgradA,
                cublasLtEpilogue_t::CUBLASLT_EPILOGUE_BGRADA,
            ),
            (
                Epilogue::BgradB,
                cublasLtEpilogue_t::CUBLASLT_EPILOGUE_BGRADB,
            ),
        ];

        // Guard against a row being duplicated or dropped from the table:
        // all 16 variants must appear exactly once.
        let distinct: HashSet<Epilogue> = cases.iter().map(|(e, _)| *e).collect();
        assert_eq!(distinct.len(), 16, "table must list each variant once");

        for (lhs, rhs) in cases {
            let got = lhs.to_cublas();
            assert_eq!(got, rhs, "{lhs:?}");
            // `None` is the only legitimate source of `DEFAULT`.
            assert_eq!(
                lhs == Epilogue::None,
                got == cublasLtEpilogue_t::CUBLASLT_EPILOGUE_DEFAULT,
                "{lhs:?}"
            );
        }
    }

    /// Spot-check the three capability predicates on representative
    /// variants, both positive and negative.
    #[test]
    fn epilogue_capability_flags() {
        // uses_bias
        assert!(Epilogue::Bias.uses_bias());
        assert!(!Epilogue::None.uses_bias());
        assert!(Epilogue::ReluBias.uses_bias());
        assert!(Epilogue::GeluAuxBias.uses_bias());
        assert!(!Epilogue::Gelu.uses_bias());

        // uses_aux
        assert!(Epilogue::GeluAux.uses_aux());
        assert!(Epilogue::DGelu.uses_aux());
        assert!(!Epilogue::Bias.uses_aux());

        // produces_bias_grad
        assert!(Epilogue::DReluBgrad.produces_bias_grad());
        assert!(Epilogue::DGeluBgrad.produces_bias_grad());
        assert!(Epilogue::BgradA.produces_bias_grad());
        assert!(Epilogue::BgradB.produces_bias_grad());
        assert!(!Epilogue::Relu.produces_bias_grad());
        assert!(!Epilogue::None.produces_bias_grad());
    }

    #[test]
    fn epilogue_default_is_none_variant() {
        // Sanity: ensure default discriminant equality holds across
        // cudarc cuda-version cfgs (both 11.x and 12.x bindings emit
        // CUBLASLT_EPILOGUE_DEFAULT = 1).
        assert_eq!(
            Epilogue::None.to_cublas() as u32,
            cublasLtEpilogue_t::CUBLASLT_EPILOGUE_DEFAULT as u32
        );
        // Actually pin the numeric value the comment above relies on.
        assert_eq!(cublasLtEpilogue_t::CUBLASLT_EPILOGUE_DEFAULT as u32, 1);
    }
}