atomr_accel_cuda/kernel/blas_lt/
mod.rs1use std::sync::Arc;
12
13use async_trait::async_trait;
14use atomr_core::actor::{Actor, Context, Props};
15pub use cudarc::cublaslt::Activation;
16use cudarc::cublaslt::{CudaBlasLT, MatmulConfig};
17use tokio::sync::oneshot;
18
19use crate::completion::CompletionStrategy;
20use crate::device::DeviceState;
21use crate::error::GpuError;
22use crate::gpu_ref::GpuRef;
23use crate::kernel::dispatch::{BlasLtDispatch, BlasLtDispatchCtx};
24use crate::stream::StreamAllocator;
25
26pub mod epilogue;
27pub mod heuristic;
28pub mod matmul;
29pub mod scaling;
30pub mod workspace;
31
32pub use epilogue::Epilogue;
33pub use heuristic::{HeuristicCacheRef, HeuristicEntry, HeuristicKey, DEFAULT_HEURISTIC_CAPACITY};
34pub use matmul::MatmulRequest;
35pub use scaling::ScaleSet;
36pub use workspace::{WorkspaceLease, WorkspacePool};
37
/// Library tag for this module: handed to `envelope::run_kernel` and stored in
/// the `lib` field of `GpuError::LibraryError` when a matmul call fails.
const LIB: &str = "cublaslt";
39
40pub enum BlasLtMsg {
42 Matmul(Box<dyn BlasLtDispatch>),
46
47 #[deprecated(
51 since = "0.2.0",
52 note = "use BlasLtMsg::Matmul(Box::new(MatmulRequest::<f32> { … }))"
53 )]
54 MatmulF32 {
55 cfg: MatmulConfig,
56 a: GpuRef<f32>,
57 b: GpuRef<f32>,
58 c: GpuRef<f32>,
59 bias: Option<GpuRef<f32>>,
60 activation: Option<Activation>,
61 reply: oneshot::Sender<Result<(), GpuError>>,
62 },
63}
64
65impl BlasLtMsg {
66 pub fn matmul<T>(req: MatmulRequest<T>) -> Self
69 where
70 T: crate::dtype::GemmSupported,
71 MatmulRequest<T>: BlasLtDispatch,
72 {
73 Self::Matmul(Box::new(req))
74 }
75}
76
77pub struct BlasLtActor {
78 inner: BlasLtInner,
79}
80
81enum BlasLtInner {
82 Real {
83 blas_lt: Arc<CudaBlasLT>,
84 stream: Arc<cudarc::driver::CudaStream>,
85 completion: Arc<dyn CompletionStrategy>,
86 #[allow(dead_code)]
87 state: Arc<DeviceState>,
88 workspace_pool: WorkspacePool,
89 heuristic_cache: HeuristicCacheRef,
90 sm_arch: u32,
91 },
92 Mock,
93}
94
95impl BlasLtActor {
96 pub fn props(
97 stream: Arc<cudarc::driver::CudaStream>,
98 _allocator: Arc<dyn StreamAllocator>,
99 completion: Arc<dyn CompletionStrategy>,
100 state: Arc<DeviceState>,
101 ) -> Props<Self> {
102 Props::create(move || {
103 let blas_lt = match CudaBlasLT::new(stream.clone()) {
104 Ok(b) => b,
105 Err(e) => panic!("ContextPoisoned: CudaBlasLT::new failed: {e}"),
106 };
107 let sm_arch = stream
114 .context()
115 .attribute(
116 cudarc::driver::sys::CUdevice_attribute::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR,
117 )
118 .ok()
119 .map(|m| m as u32 * 10)
120 .unwrap_or(0);
121 BlasLtActor {
122 inner: BlasLtInner::Real {
123 blas_lt: Arc::new(blas_lt),
124 stream: stream.clone(),
125 completion: completion.clone(),
126 state: state.clone(),
127 workspace_pool: WorkspacePool::new(),
128 heuristic_cache: HeuristicCacheRef::default_size(),
129 sm_arch,
130 },
131 }
132 })
133 }
134
135 pub fn mock_props() -> Props<Self> {
136 Props::create(|| BlasLtActor {
137 inner: BlasLtInner::Mock,
138 })
139 }
140}
141
142#[async_trait]
143impl Actor for BlasLtActor {
144 type Msg = BlasLtMsg;
145
146 async fn handle(&mut self, _ctx: &mut Context<Self>, msg: BlasLtMsg) {
147 match &self.inner {
148 BlasLtInner::Mock => match msg {
149 BlasLtMsg::Matmul(req) => {
150 drop(req);
156 }
157 #[allow(deprecated)]
158 BlasLtMsg::MatmulF32 { reply, .. } => {
159 let _ = reply.send(Err(GpuError::Unrecoverable(
160 "BlasLtActor in mock mode".into(),
161 )));
162 }
163 },
164 BlasLtInner::Real {
165 blas_lt,
166 stream,
167 completion,
168 workspace_pool,
169 heuristic_cache,
170 sm_arch,
171 ..
172 } => match msg {
173 BlasLtMsg::Matmul(req) => {
174 let dctx = BlasLtDispatchCtx {
175 blas_lt: blas_lt.clone(),
176 stream,
177 completion,
178 workspace: workspace_pool,
179 heuristic: heuristic_cache.clone(),
180 sm_arch: *sm_arch,
181 };
182 req.dispatch(&dctx);
183 }
184 #[allow(deprecated)]
185 BlasLtMsg::MatmulF32 {
186 cfg,
187 a,
188 b,
189 c,
190 bias,
191 activation,
192 reply,
193 } => {
194 enqueue_matmul_f32_legacy(
195 blas_lt.clone(),
196 stream,
197 completion,
198 cfg,
199 a,
200 b,
201 c,
202 bias,
203 activation,
204 reply,
205 );
206 }
207 },
208 }
209 }
210}
211
212fn enqueue_matmul_f32_legacy(
215 blas_lt: Arc<CudaBlasLT>,
216 stream: &Arc<cudarc::driver::CudaStream>,
217 completion: &Arc<dyn CompletionStrategy>,
218 cfg: MatmulConfig,
219 a: GpuRef<f32>,
220 b: GpuRef<f32>,
221 c: GpuRef<f32>,
222 bias: Option<GpuRef<f32>>,
223 activation: Option<Activation>,
224 reply: oneshot::Sender<Result<(), GpuError>>,
225) {
226 use crate::kernel::envelope;
227 use cudarc::cublaslt::Matmul;
228
229 let (a_slice, b_slice, c_slice) = match envelope::access_all_3(&a, &b, &c) {
230 Ok(t) => t,
231 Err(e) => {
232 let _ = reply.send(Err(e));
233 return;
234 }
235 };
236 let bias_slice = match bias.as_ref() {
237 None => None,
238 Some(g) => match g.access() {
239 Ok(s) => Some(s.clone()),
240 Err(e) => {
241 let _ = reply.send(Err(e));
242 return;
243 }
244 },
245 };
246 let mut c_owned = match Arc::try_unwrap(c_slice) {
247 Ok(s) => s,
248 Err(_) => {
249 let _ = reply.send(Err(GpuError::Unrecoverable(
250 "BlasLt C has multiple live references".into(),
251 )));
252 return;
253 }
254 };
255 c.record_write(stream);
256 envelope::run_kernel(LIB, stream, completion, (), reply, move || {
257 let bias_ref = bias_slice.as_ref().map(|s| &**s);
258 let act_ref = activation.as_ref();
259 let res =
261 unsafe { blas_lt.matmul(cfg, &*a_slice, &*b_slice, &mut c_owned, bias_ref, act_ref) };
262 match res {
263 Ok(()) => Ok((a_slice, b_slice, c_owned, bias_slice, blas_lt)),
264 Err(e) => Err(GpuError::LibraryError {
265 lib: LIB,
266 msg: format!("matmul: {e}"),
267 }),
268 }
269 });
270}
271
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time check that `BlasLtMsg::matmul::<f32>` coerces to a plain
    /// function pointer from `MatmulRequest<f32>` to `BlasLtMsg`. No request
    /// is built and nothing is sent on the channel.
    #[test]
    fn blas_lt_msg_matmul_constructor() {
        type Constructor = fn(MatmulRequest<f32>) -> BlasLtMsg;

        let (reply_tx, _reply_rx) = oneshot::channel::<Result<(), GpuError>>();
        let _ctor: Constructor = BlasLtMsg::matmul::<f32>;
        drop(reply_tx);
    }
}