1use std::collections::HashMap;
23use std::ffi::CString;
24use std::sync::atomic::{AtomicU64, Ordering as AtomicOrdering};
25use std::sync::Arc;
26
27use async_trait::async_trait;
28use atomr_core::actor::{Actor, Context, Props};
29use cudarc::driver::sys as driver_sys;
30use cudarc::driver::{CudaContext, CudaStream, LaunchConfig};
31use parking_lot::Mutex;
32use tokio::sync::oneshot;
33
34use crate::completion::CompletionStrategy;
35use crate::error::GpuError;
36use crate::sys::cuda_driver;
37
38#[cfg(feature = "nvrtc")]
39pub use crate::kernel::nvrtc::KernelArg;
40
/// Kernel-argument type used when the `nvrtc` feature is disabled.
///
/// With `nvrtc` enabled, `KernelArg` is instead re-exported from
/// `crate::kernel::nvrtc`. `KernelArgScratch::from_arg` converts each variant
/// into launch-time storage.
#[cfg(not(feature = "nvrtc"))]
pub enum KernelArg {
    /// Device buffer; passed to the kernel as its raw device address.
    DevSliceF32(crate::gpu_ref::GpuRef<f32>),
    /// Device buffer; passed to the kernel as its raw device address.
    DevSliceF64(crate::gpu_ref::GpuRef<f64>),
    /// Device buffer; passed to the kernel as its raw device address.
    DevSliceI32(crate::gpu_ref::GpuRef<i32>),
    /// Device buffer; passed to the kernel as its raw device address.
    DevSliceU32(crate::gpu_ref::GpuRef<u32>),
    /// Device buffer; passed to the kernel as its raw device address.
    DevSliceU8(crate::gpu_ref::GpuRef<u8>),
    /// Scalar passed by value.
    ScalarF32(f32),
    /// Scalar passed by value.
    ScalarF64(f64),
    /// Scalar passed by value.
    ScalarI32(i32),
    /// Scalar passed by value.
    ScalarU32(u32),
    /// Scalar passed by value.
    ScalarU64(u64),
    /// Host `usize` passed by value (its width is the host platform's pointer width).
    Usize(usize),
}
60
/// Library tag placed in the `lib` field of the `GpuError::LibraryError`
/// values constructed in this file.
const LIB: &str = "module";
62
/// Opaque handle to a module loaded by [`ModuleActor`].
///
/// Returned by `ModuleMsg::LoadCubin` / `ModuleMsg::LoadPtx`; `id` is the key
/// of the module in the actor's module table.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ModuleHandle {
    id: u64,
}
69
/// Handle to a kernel resolved with `ModuleMsg::GetFunction`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct FunctionHandle {
    /// `ModuleHandle::id` of the module the kernel lives in.
    module: u64,
    /// Per-module id of the resolved kernel.
    function_id: u64,
}
75
/// Requests understood by [`ModuleActor`].
///
/// Every variant carries a oneshot `reply` sender; the actor ignores a dropped
/// receiver. An actor built with `ModuleActor::mock_props` answers every
/// request with `GpuError::Unrecoverable`.
pub enum ModuleMsg {
    /// Load a module from a raw image (e.g. a cubin); the bytes are handed to
    /// the driver as-is.
    LoadCubin {
        bytes: Vec<u8>,
        reply: oneshot::Sender<Result<ModuleHandle, GpuError>>,
    },
    /// Load a module from PTX source; a trailing NUL byte is appended if missing.
    LoadPtx {
        src: String,
        reply: oneshot::Sender<Result<ModuleHandle, GpuError>>,
    },
    /// Resolve the kernel called `name` inside a loaded module.
    GetFunction {
        handle: ModuleHandle,
        name: String,
        reply: oneshot::Sender<Result<FunctionHandle, GpuError>>,
    },
    /// Enqueue a regular kernel launch on the actor's stream.
    Launch {
        function: FunctionHandle,
        cfg: LaunchConfig,
        args: Vec<KernelArg>,
        reply: oneshot::Sender<Result<(), GpuError>>,
    },
    /// Like `Launch`, but issued through `cuda_driver::launch_cooperative_kernel`.
    LaunchCooperative {
        function: FunctionHandle,
        cfg: LaunchConfig,
        args: Vec<KernelArg>,
        reply: oneshot::Sender<Result<(), GpuError>>,
    },
    /// Unload a module; `FunctionHandle`s resolved from it stop being valid.
    Unload {
        handle: ModuleHandle,
        reply: oneshot::Sender<Result<(), GpuError>>,
    },
}
107
/// One module loaded through the driver, plus the kernels resolved from it.
struct LoadedModule {
    /// Driver handle returned by `cuda_driver::module_load_data`.
    cu_module: driver_sys::CUmodule,
    /// Resolved kernels keyed by `FunctionHandle::function_id`; the `CString`
    /// is the kernel name that was looked up.
    functions: HashMap<u64, (CString, driver_sys::CUfunction)>,
    /// Id handed to the next resolved kernel; `load_image` starts it at 1.
    next_function_id: u64,
}

// SAFETY: the driver handles above are raw handles, which is why these impls
// are written by hand. In this file every `LoadedModule` is stored inside the
// `Mutex<HashMap<..>>` of `ModuleInner::Real`, so access to it is serialized.
// NOTE(review): soundness also assumes the CUDA driver permits using these
// handles from any thread once the owning context is current — confirm.
unsafe impl Send for LoadedModule {}
unsafe impl Sync for LoadedModule {}
118
/// Backing state of a [`ModuleActor`].
// `dead_code` is allowed because `completion` is stored but never read in this file.
#[allow(dead_code)]
enum ModuleInner {
    /// Talks to the CUDA driver.
    Real {
        /// Context bound to the calling thread before driver calls.
        ctx: Arc<CudaContext>,
        /// Stream every kernel launch is enqueued on.
        stream: Arc<CudaStream>,
        /// NOTE(review): not consulted by this actor yet — presumably reserved
        /// for waiting on launch completion.
        completion: Arc<dyn CompletionStrategy>,
        /// Loaded modules keyed by `ModuleHandle::id`.
        modules: Mutex<HashMap<u64, LoadedModule>>,
        /// Next module id to hand out; `ModuleActor::props` starts it at 1.
        next_module_id: AtomicU64,
    },
    /// Test stand-in: every request is answered with `GpuError::Unrecoverable`.
    Mock,
}
130
/// Actor that loads CUDA modules, resolves kernels and launches them.
///
/// Build it with [`ModuleActor::props`] (real driver) or
/// [`ModuleActor::mock_props`] (every request fails with
/// `GpuError::Unrecoverable`).
pub struct ModuleActor {
    inner: ModuleInner,
}
134
135impl ModuleActor {
136 pub fn props(
137 ctx: Arc<CudaContext>,
138 stream: Arc<CudaStream>,
139 completion: Arc<dyn CompletionStrategy>,
140 ) -> Props<Self> {
141 Props::create(move || ModuleActor {
142 inner: ModuleInner::Real {
143 ctx: ctx.clone(),
144 stream: stream.clone(),
145 completion: completion.clone(),
146 modules: Mutex::new(HashMap::new()),
147 next_module_id: AtomicU64::new(1),
148 },
149 })
150 }
151
152 pub fn mock_props() -> Props<Self> {
153 Props::create(|| ModuleActor {
154 inner: ModuleInner::Mock,
155 })
156 }
157}
158
159impl Drop for ModuleInner {
160 fn drop(&mut self) {
161 if let ModuleInner::Real { modules, .. } = self {
162 let mut g = modules.lock();
163 for (_id, m) in g.drain() {
164 let _ = cuda_driver::module_unload(m.cu_module);
165 }
166 }
167 }
168}
169
170#[async_trait]
171impl Actor for ModuleActor {
172 type Msg = ModuleMsg;
173
174 async fn handle(&mut self, _ctx: &mut Context<Self>, msg: ModuleMsg) {
175 match &self.inner {
176 ModuleInner::Mock => mock_reply(msg),
177 ModuleInner::Real {
178 ctx,
179 stream,
180 completion: _completion,
181 modules,
182 next_module_id,
183 } => handle_real(ctx, stream, modules, next_module_id, msg),
184 }
185 }
186}
187
188fn mock_reply(msg: ModuleMsg) {
189 let unrecoverable = || GpuError::Unrecoverable("ModuleActor in mock mode".into());
190 match msg {
191 ModuleMsg::LoadCubin { reply, .. } => {
192 let _ = reply.send(Err(unrecoverable()));
193 }
194 ModuleMsg::LoadPtx { reply, .. } => {
195 let _ = reply.send(Err(unrecoverable()));
196 }
197 ModuleMsg::GetFunction { reply, .. } => {
198 let _ = reply.send(Err(unrecoverable()));
199 }
200 ModuleMsg::Launch { reply, .. } => {
201 let _ = reply.send(Err(unrecoverable()));
202 }
203 ModuleMsg::LaunchCooperative { reply, .. } => {
204 let _ = reply.send(Err(unrecoverable()));
205 }
206 ModuleMsg::Unload { reply, .. } => {
207 let _ = reply.send(Err(unrecoverable()));
208 }
209 }
210}
211
212fn handle_real(
213 ctx: &Arc<CudaContext>,
214 stream: &Arc<CudaStream>,
215 modules: &Mutex<HashMap<u64, LoadedModule>>,
216 next_module_id: &AtomicU64,
217 msg: ModuleMsg,
218) {
219 match msg {
220 ModuleMsg::LoadCubin { bytes, reply } => {
221 let r = load_image(ctx, modules, next_module_id, &bytes);
222 let _ = reply.send(r);
223 }
224 ModuleMsg::LoadPtx { src, reply } => {
225 let mut text = src.into_bytes();
228 if !text.ends_with(&[0]) {
229 text.push(0);
230 }
231 let r = load_image(ctx, modules, next_module_id, &text);
232 let _ = reply.send(r);
233 }
234 ModuleMsg::GetFunction {
235 handle,
236 name,
237 reply,
238 } => {
239 let r = get_function(modules, handle, &name);
240 let _ = reply.send(r);
241 }
242 ModuleMsg::Launch {
243 function,
244 cfg,
245 args,
246 reply,
247 } => {
248 let r = launch_inner(modules, stream, function, cfg, args, false);
249 let _ = reply.send(r);
250 }
251 ModuleMsg::LaunchCooperative {
252 function,
253 cfg,
254 args,
255 reply,
256 } => {
257 let r = launch_inner(modules, stream, function, cfg, args, true);
258 let _ = reply.send(r);
259 }
260 ModuleMsg::Unload { handle, reply } => {
261 let r = unload(modules, handle);
262 let _ = reply.send(r);
263 }
264 }
265}
266
267fn load_image(
268 ctx: &Arc<CudaContext>,
269 modules: &Mutex<HashMap<u64, LoadedModule>>,
270 next_module_id: &AtomicU64,
271 bytes: &[u8],
272) -> Result<ModuleHandle, GpuError> {
273 let bind = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| ctx.bind_to_thread()));
274 match bind {
275 Ok(Ok(())) => {}
276 Ok(Err(e)) => {
277 return Err(GpuError::LibraryError {
278 lib: LIB,
279 msg: format!("bind_to_thread: {e}"),
280 });
281 }
282 Err(_) => {
283 return Err(GpuError::Unrecoverable(
284 "ModuleActor::Load: CUDA driver not loadable".into(),
285 ));
286 }
287 }
288 let m = cuda_driver::module_load_data(bytes.as_ptr() as *const _)?;
289 let id = next_module_id.fetch_add(1, AtomicOrdering::Relaxed);
290 modules.lock().insert(
291 id,
292 LoadedModule {
293 cu_module: m,
294 functions: HashMap::new(),
295 next_function_id: 1,
296 },
297 );
298 Ok(ModuleHandle { id })
299}
300
301fn get_function(
302 modules: &Mutex<HashMap<u64, LoadedModule>>,
303 handle: ModuleHandle,
304 name: &str,
305) -> Result<FunctionHandle, GpuError> {
306 let mut g = modules.lock();
307 let m = g.get_mut(&handle.id).ok_or_else(|| {
308 GpuError::Unrecoverable(format!(
309 "ModuleActor::GetFunction: unknown module {}",
310 handle.id
311 ))
312 })?;
313 let cname = CString::new(name).map_err(|e| {
314 GpuError::Unrecoverable(format!("ModuleActor::GetFunction: NUL in name: {e}"))
315 })?;
316 let f = cuda_driver::module_get_function(m.cu_module, &cname)?;
317 let function_id = m.next_function_id;
318 m.next_function_id += 1;
319 m.functions.insert(function_id, (cname, f));
320 Ok(FunctionHandle {
321 module: handle.id,
322 function_id,
323 })
324}
325
/// Looks up `function` in the module table and enqueues it on `stream`.
///
/// `cooperative` selects `cuda_driver::launch_cooperative_kernel` instead of
/// `cuda_driver::launch_kernel`. The call only enqueues work; it does not
/// wait for the kernel to finish.
///
/// # Errors
///
/// - `GpuError::Unrecoverable` for an unknown module or function id.
/// - Any error from converting an argument (`KernelArgScratch::from_arg`).
/// - Any error returned by the launch call.
fn launch_inner(
    modules: &Mutex<HashMap<u64, LoadedModule>>,
    stream: &Arc<CudaStream>,
    function: FunctionHandle,
    cfg: LaunchConfig,
    args: Vec<KernelArg>,
    cooperative: bool,
) -> Result<(), GpuError> {
    // The module table stays locked until this function returns, so the module
    // cannot be removed while the launch is being issued.
    let g = modules.lock();
    let m = g.get(&function.module).ok_or_else(|| {
        GpuError::Unrecoverable(format!(
            "ModuleActor::Launch: unknown module {}",
            function.module
        ))
    })?;
    let (_name, cu_func) = m.functions.get(&function.function_id).ok_or_else(|| {
        GpuError::Unrecoverable(format!(
            "ModuleActor::Launch: unknown function {}/{}",
            function.module, function.function_id
        ))
    })?;
    let cu_func = *cu_func;

    // One storage slot per argument. `ptrs` points into these slots, so
    // `scratch` must not grow, move or drop until the launch call has returned.
    let mut scratch: Vec<KernelArgScratch> = Vec::with_capacity(args.len());
    // NOTE(review): `from_arg` currently never pushes into `keep_alive`, so
    // this vector stays empty and does not extend any buffer's lifetime.
    let mut keep_alive: Vec<Arc<cudarc::driver::CudaSlice<u8>>> = Vec::new();
    for a in args.into_iter() {
        scratch.push(KernelArgScratch::from_arg(a, &mut keep_alive)?);
    }
    // One pointer per argument slot — presumably forwarded as the driver's
    // `kernelParams` array.
    let mut ptrs: Vec<*mut std::ffi::c_void> =
        scratch.iter_mut().map(|s| s.as_void_ptr()).collect();

    let grid = (cfg.grid_dim.0, cfg.grid_dim.1, cfg.grid_dim.2);
    let block = (cfg.block_dim.0, cfg.block_dim.1, cfg.block_dim.2);
    let res = if cooperative {
        cuda_driver::launch_cooperative_kernel(
            cu_func,
            grid,
            block,
            cfg.shared_mem_bytes,
            stream.cu_stream(),
            ptrs.as_mut_ptr(),
        )
    } else {
        cuda_driver::launch_kernel(
            cu_func,
            grid,
            block,
            cfg.shared_mem_bytes,
            stream.cu_stream(),
            ptrs.as_mut_ptr(),
        )
    };
    // Argument storage is released only after the launch call returns. This
    // assumes the driver copies parameter values during that call — confirm.
    drop(scratch);
    drop(keep_alive);
    res
}
387
388fn unload(
389 modules: &Mutex<HashMap<u64, LoadedModule>>,
390 handle: ModuleHandle,
391) -> Result<(), GpuError> {
392 let mut g = modules.lock();
393 let m = g.remove(&handle.id).ok_or_else(|| {
394 GpuError::Unrecoverable(format!("ModuleActor::Unload: unknown module {}", handle.id))
395 })?;
396 cuda_driver::module_unload(m.cu_module)
397}
398
/// Host-side storage for one kernel argument while a launch is issued.
///
/// `as_void_ptr` returns a pointer into the value held here, so a slot must
/// stay at a fixed address until the launch call has returned.
enum KernelArgScratch {
    /// Raw device address of a device-buffer argument.
    DevPtr(driver_sys::CUdeviceptr),
    /// Scalar argument stored by value.
    F32(f32),
    /// Scalar argument stored by value.
    F64(f64),
    /// Scalar argument stored by value.
    I32(i32),
    /// Scalar argument stored by value.
    U32(u32),
    /// Scalar argument stored by value.
    U64(u64),
    /// Host `usize` argument stored by value.
    Usize(usize),
}
409
410impl KernelArgScratch {
411 fn from_arg(
412 arg: KernelArg,
413 _keep_alive: &mut Vec<Arc<cudarc::driver::CudaSlice<u8>>>,
414 ) -> Result<Self, GpuError> {
415 macro_rules! retain {
421 ($g:expr) => {{
422 use cudarc::driver::DevicePtr;
423 let s = $g.access()?.clone();
424 let (ptr, _g) = s.device_ptr(_stream_for_record());
429 let _ = _g;
430 let _ = keep_alive; ptr
432 }};
433 }
434 #[allow(deprecated, unreachable_patterns)]
437 Ok(match arg {
438 KernelArg::DevSliceF32(g) => Self::DevPtr(devptr_of(g)?),
439 KernelArg::DevSliceF64(g) => Self::DevPtr(devptr_of(g)?),
440 KernelArg::DevSliceI32(g) => Self::DevPtr(devptr_of(g)?),
441 KernelArg::DevSliceU32(g) => Self::DevPtr(devptr_of(g)?),
442 KernelArg::DevSliceU8(g) => Self::DevPtr(devptr_of(g)?),
443 KernelArg::ScalarF32(v) => Self::F32(v),
444 KernelArg::ScalarF64(v) => Self::F64(v),
445 KernelArg::ScalarI32(v) => Self::I32(v),
446 KernelArg::ScalarU32(v) => Self::U32(v),
447 KernelArg::ScalarU64(v) => Self::U64(v),
448 KernelArg::Usize(v) => Self::Usize(v),
449 KernelArg::DevSlice(_) | KernelArg::Scalar(_) => {
454 return Err(GpuError::Unrecoverable(
455 "ModuleActor: KernelArg::DevSlice/Scalar dispatch not yet wired".into(),
456 ));
457 }
458 })
459 }
460
461 fn as_void_ptr(&mut self) -> *mut std::ffi::c_void {
462 match self {
463 KernelArgScratch::DevPtr(p) => p as *mut _ as *mut _,
464 KernelArgScratch::F32(v) => v as *mut _ as *mut _,
465 KernelArgScratch::F64(v) => v as *mut _ as *mut _,
466 KernelArgScratch::I32(v) => v as *mut _ as *mut _,
467 KernelArgScratch::U32(v) => v as *mut _ as *mut _,
468 KernelArgScratch::U64(v) => v as *mut _ as *mut _,
469 KernelArgScratch::Usize(v) => v as *mut _ as *mut _,
470 }
471 }
472}
473
/// Placeholder that always panics.
///
/// NOTE(review): nothing in this file calls this function. It looks like a
/// leftover and is a candidate for deletion.
#[allow(dead_code)]
fn _stream_for_record() -> &'static Arc<cudarc::driver::CudaStream> {
    panic!("not used")
}
481
/// Extracts the raw device address behind a `GpuRef`.
///
/// The address is obtained with the slice's own stream (`s.stream()`), not
/// the stream the kernel is later launched on.
///
/// NOTE(review): `g`, `s` and the `device_ptr` guard are all dropped when this
/// function returns, before `launch_inner` issues the launch. If `g` held the
/// last reference to the allocation, the returned address may already be
/// freed. Any bookkeeping the guard does on drop also happens before the
/// kernel is enqueued. Verify `GpuRef` ownership semantics.
fn devptr_of<T>(g: crate::gpu_ref::GpuRef<T>) -> Result<driver_sys::CUdeviceptr, GpuError> {
    use cudarc::driver::DevicePtr;
    let s = g.access()?.clone();
    let stream = s.stream().clone();
    let (ptr, _guard) = s.device_ptr(&stream);
    // Both `let _ = x` lines are no-ops: they neither move nor drop `x`, which
    // is still dropped at the end of the function.
    let _ = _guard;
    let _ = s;
    Ok(ptr)
}
491
#[cfg(test)]
mod tests {
    use super::*;
    use atomr_config::Config;
    use atomr_core::actor::ActorSystem;
    use std::time::Duration;

    /// Awaits the actor's reply. Fails the test if none arrives within 2s or
    /// if the actor dropped the sender without replying.
    async fn recv_reply<T>(rx: oneshot::Receiver<Result<T, GpuError>>) -> Result<T, GpuError> {
        tokio::time::timeout(Duration::from_secs(2), rx)
            .await
            .expect("ModuleActor did not reply within 2s")
            .expect("ModuleActor dropped the reply sender")
    }

    /// A mock actor must answer every message kind, always with `Unrecoverable`.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn module_msg_round_trip() {
        let sys = ActorSystem::create("module-test", Config::empty())
            .await
            .unwrap();
        let actor = sys.actor_of(ModuleActor::mock_props(), "mod").unwrap();

        let (tx, rx) = oneshot::channel();
        actor.tell(ModuleMsg::LoadCubin {
            bytes: vec![1, 2, 3, 4],
            reply: tx,
        });
        assert!(matches!(recv_reply(rx).await, Err(GpuError::Unrecoverable(_))));

        let (tx, rx) = oneshot::channel();
        actor.tell(ModuleMsg::LoadPtx {
            src: ".version 7.0".into(),
            reply: tx,
        });
        assert!(matches!(recv_reply(rx).await, Err(GpuError::Unrecoverable(_))));

        let bogus = ModuleHandle { id: 99 };
        let (tx, rx) = oneshot::channel();
        actor.tell(ModuleMsg::GetFunction {
            handle: bogus,
            name: "kern".into(),
            reply: tx,
        });
        assert!(matches!(recv_reply(rx).await, Err(GpuError::Unrecoverable(_))));

        let bogus_fn = FunctionHandle {
            module: 99,
            function_id: 1,
        };
        let (tx, rx) = oneshot::channel();
        actor.tell(ModuleMsg::Launch {
            function: bogus_fn,
            cfg: LaunchConfig::for_num_elems(64),
            args: vec![],
            reply: tx,
        });
        assert!(matches!(recv_reply(rx).await, Err(GpuError::Unrecoverable(_))));

        let (tx, rx) = oneshot::channel();
        actor.tell(ModuleMsg::LaunchCooperative {
            function: bogus_fn,
            cfg: LaunchConfig::for_num_elems(64),
            args: vec![],
            reply: tx,
        });
        assert!(matches!(recv_reply(rx).await, Err(GpuError::Unrecoverable(_))));

        let (tx, rx) = oneshot::channel();
        actor.tell(ModuleMsg::Unload {
            handle: bogus,
            reply: tx,
        });
        assert!(matches!(recv_reply(rx).await, Err(GpuError::Unrecoverable(_))));

        sys.terminate().await;
    }

    /// With `nvrtc`, `KernelArg` must be the nvrtc type itself.
    #[cfg(feature = "nvrtc")]
    #[test]
    fn launch_args_share_kernel_arg_type() {
        fn _assert<T>(_x: T) {}
        _assert::<crate::kernel::nvrtc::KernelArg>(KernelArg::Usize(7));
    }

    /// Without `nvrtc`, the local fallback enum provides the same constructors.
    #[cfg(not(feature = "nvrtc"))]
    #[test]
    fn launch_args_share_kernel_arg_type() {
        let _arg = KernelArg::Usize(7);
        let _arg2 = KernelArg::ScalarF32(1.0);
    }
}