Skip to main content

atomr_accel_cuda/module/
mod.rs

1//! `ModuleActor` — load prebuilt cubin/PTX from disk (or memory) and
2//! launch its kernels.
3//!
4//! Distinct from [`crate::kernel::nvrtc`] / `NvrtcActor`, which
5//! JIT-compiles CUDA C++ at runtime. This actor is for the
6//! "ahead-of-time-compiled, ship-the-bytes" workflow.
7//!
8//! Lifecycle:
9//! 1. `LoadCubin { bytes }` or `LoadPtx { src }` → returns a
10//!    `ModuleHandle`.
11//! 2. `GetFunction { handle, name }` → returns a `FunctionHandle`.
12//! 3. `Launch { function, cfg, args }` → enqueues a kernel call on
13//!    the actor's stream and replies after stream completion.
//! 4. `LaunchCooperative { function, cfg, args }` → same but goes
//!    through `cuLaunchCooperativeKernel`. Required for kernels that
//!    perform grid-wide synchronization (e.g. cooperative-groups
//!    `grid.sync()`); the whole grid must be co-resident on the device.
17//! 5. `Unload { handle }` → frees the `CUmodule`.
18//!
//! With the `nvrtc` feature enabled, `KernelArg` is re-exported from
//! [`crate::kernel::nvrtc`] so this actor and `NvrtcActor` share one
//! launch-arg type; without it, an equivalent standalone enum is
//! defined here.
21
22use std::collections::HashMap;
23use std::ffi::CString;
24use std::sync::atomic::{AtomicU64, Ordering as AtomicOrdering};
25use std::sync::Arc;
26
27use async_trait::async_trait;
28use atomr_core::actor::{Actor, Context, Props};
29use cudarc::driver::sys as driver_sys;
30use cudarc::driver::{CudaContext, CudaStream, LaunchConfig};
31use parking_lot::Mutex;
32use tokio::sync::oneshot;
33
34use crate::completion::CompletionStrategy;
35use crate::error::GpuError;
36use crate::sys::cuda_driver;
37
38#[cfg(feature = "nvrtc")]
39pub use crate::kernel::nvrtc::KernelArg;
40
41/// When the `nvrtc` feature is off we still need a launchable arg
42/// enum. Mirror the public surface; the variants line up so the
43/// `module::tests::launch_args_share_kernel_arg_type` test passes
44/// transparently when nvrtc is enabled, and the standalone enum
45/// covers no-feature builds.
46#[cfg(not(feature = "nvrtc"))]
47pub enum KernelArg {
48    DevSliceF32(crate::gpu_ref::GpuRef<f32>),
49    DevSliceF64(crate::gpu_ref::GpuRef<f64>),
50    DevSliceI32(crate::gpu_ref::GpuRef<i32>),
51    DevSliceU32(crate::gpu_ref::GpuRef<u32>),
52    DevSliceU8(crate::gpu_ref::GpuRef<u8>),
53    ScalarF32(f32),
54    ScalarF64(f64),
55    ScalarI32(i32),
56    ScalarU32(u32),
57    ScalarU64(u64),
58    Usize(usize),
59}
60
/// Library tag attached to `GpuError::LibraryError` values raised by this module.
const LIB: &str = "module";

/// Opaque handle to a module loaded by a `ModuleActor`.
///
/// Wraps the numeric key the issuing actor uses for its module table, so it
/// is only meaningful to that actor.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ModuleHandle {
    id: u64,
}

/// Opaque handle to a kernel resolved through `ModuleMsg::GetFunction`.
///
/// Pairs the owning module's key with a per-module function id.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct FunctionHandle {
    module: u64,
    function_id: u64,
}
75
76pub enum ModuleMsg {
77    LoadCubin {
78        bytes: Vec<u8>,
79        reply: oneshot::Sender<Result<ModuleHandle, GpuError>>,
80    },
81    LoadPtx {
82        src: String,
83        reply: oneshot::Sender<Result<ModuleHandle, GpuError>>,
84    },
85    GetFunction {
86        handle: ModuleHandle,
87        name: String,
88        reply: oneshot::Sender<Result<FunctionHandle, GpuError>>,
89    },
90    Launch {
91        function: FunctionHandle,
92        cfg: LaunchConfig,
93        args: Vec<KernelArg>,
94        reply: oneshot::Sender<Result<(), GpuError>>,
95    },
96    LaunchCooperative {
97        function: FunctionHandle,
98        cfg: LaunchConfig,
99        args: Vec<KernelArg>,
100        reply: oneshot::Sender<Result<(), GpuError>>,
101    },
102    Unload {
103        handle: ModuleHandle,
104        reply: oneshot::Sender<Result<(), GpuError>>,
105    },
106}
107
108struct LoadedModule {
109    cu_module: driver_sys::CUmodule,
110    /// Map function name → cu_function. We track FunctionHandle.id
111    /// separately so callers don't accidentally pass a stale name.
112    functions: HashMap<u64, (CString, driver_sys::CUfunction)>,
113    next_function_id: u64,
114}
115
116unsafe impl Send for LoadedModule {}
117unsafe impl Sync for LoadedModule {}
118
119#[allow(dead_code)]
120enum ModuleInner {
121    Real {
122        ctx: Arc<CudaContext>,
123        stream: Arc<CudaStream>,
124        completion: Arc<dyn CompletionStrategy>,
125        modules: Mutex<HashMap<u64, LoadedModule>>,
126        next_module_id: AtomicU64,
127    },
128    Mock,
129}
130
131pub struct ModuleActor {
132    inner: ModuleInner,
133}
134
135impl ModuleActor {
136    pub fn props(
137        ctx: Arc<CudaContext>,
138        stream: Arc<CudaStream>,
139        completion: Arc<dyn CompletionStrategy>,
140    ) -> Props<Self> {
141        Props::create(move || ModuleActor {
142            inner: ModuleInner::Real {
143                ctx: ctx.clone(),
144                stream: stream.clone(),
145                completion: completion.clone(),
146                modules: Mutex::new(HashMap::new()),
147                next_module_id: AtomicU64::new(1),
148            },
149        })
150    }
151
152    pub fn mock_props() -> Props<Self> {
153        Props::create(|| ModuleActor {
154            inner: ModuleInner::Mock,
155        })
156    }
157}
158
159impl Drop for ModuleInner {
160    fn drop(&mut self) {
161        if let ModuleInner::Real { modules, .. } = self {
162            let mut g = modules.lock();
163            for (_id, m) in g.drain() {
164                let _ = cuda_driver::module_unload(m.cu_module);
165            }
166        }
167    }
168}
169
170#[async_trait]
171impl Actor for ModuleActor {
172    type Msg = ModuleMsg;
173
174    async fn handle(&mut self, _ctx: &mut Context<Self>, msg: ModuleMsg) {
175        match &self.inner {
176            ModuleInner::Mock => mock_reply(msg),
177            ModuleInner::Real {
178                ctx,
179                stream,
180                completion: _completion,
181                modules,
182                next_module_id,
183            } => handle_real(ctx, stream, modules, next_module_id, msg),
184        }
185    }
186}
187
188fn mock_reply(msg: ModuleMsg) {
189    let unrecoverable = || GpuError::Unrecoverable("ModuleActor in mock mode".into());
190    match msg {
191        ModuleMsg::LoadCubin { reply, .. } => {
192            let _ = reply.send(Err(unrecoverable()));
193        }
194        ModuleMsg::LoadPtx { reply, .. } => {
195            let _ = reply.send(Err(unrecoverable()));
196        }
197        ModuleMsg::GetFunction { reply, .. } => {
198            let _ = reply.send(Err(unrecoverable()));
199        }
200        ModuleMsg::Launch { reply, .. } => {
201            let _ = reply.send(Err(unrecoverable()));
202        }
203        ModuleMsg::LaunchCooperative { reply, .. } => {
204            let _ = reply.send(Err(unrecoverable()));
205        }
206        ModuleMsg::Unload { reply, .. } => {
207            let _ = reply.send(Err(unrecoverable()));
208        }
209    }
210}
211
212fn handle_real(
213    ctx: &Arc<CudaContext>,
214    stream: &Arc<CudaStream>,
215    modules: &Mutex<HashMap<u64, LoadedModule>>,
216    next_module_id: &AtomicU64,
217    msg: ModuleMsg,
218) {
219    match msg {
220        ModuleMsg::LoadCubin { bytes, reply } => {
221            let r = load_image(ctx, modules, next_module_id, &bytes);
222            let _ = reply.send(r);
223        }
224        ModuleMsg::LoadPtx { src, reply } => {
225            // PTX is a NUL-terminated text string. Append a NUL if the
226            // caller didn't.
227            let mut text = src.into_bytes();
228            if !text.ends_with(&[0]) {
229                text.push(0);
230            }
231            let r = load_image(ctx, modules, next_module_id, &text);
232            let _ = reply.send(r);
233        }
234        ModuleMsg::GetFunction {
235            handle,
236            name,
237            reply,
238        } => {
239            let r = get_function(modules, handle, &name);
240            let _ = reply.send(r);
241        }
242        ModuleMsg::Launch {
243            function,
244            cfg,
245            args,
246            reply,
247        } => {
248            let r = launch_inner(modules, stream, function, cfg, args, false);
249            let _ = reply.send(r);
250        }
251        ModuleMsg::LaunchCooperative {
252            function,
253            cfg,
254            args,
255            reply,
256        } => {
257            let r = launch_inner(modules, stream, function, cfg, args, true);
258            let _ = reply.send(r);
259        }
260        ModuleMsg::Unload { handle, reply } => {
261            let r = unload(modules, handle);
262            let _ = reply.send(r);
263        }
264    }
265}
266
267fn load_image(
268    ctx: &Arc<CudaContext>,
269    modules: &Mutex<HashMap<u64, LoadedModule>>,
270    next_module_id: &AtomicU64,
271    bytes: &[u8],
272) -> Result<ModuleHandle, GpuError> {
273    let bind = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| ctx.bind_to_thread()));
274    match bind {
275        Ok(Ok(())) => {}
276        Ok(Err(e)) => {
277            return Err(GpuError::LibraryError {
278                lib: LIB,
279                msg: format!("bind_to_thread: {e}"),
280            });
281        }
282        Err(_) => {
283            return Err(GpuError::Unrecoverable(
284                "ModuleActor::Load: CUDA driver not loadable".into(),
285            ));
286        }
287    }
288    let m = cuda_driver::module_load_data(bytes.as_ptr() as *const _)?;
289    let id = next_module_id.fetch_add(1, AtomicOrdering::Relaxed);
290    modules.lock().insert(
291        id,
292        LoadedModule {
293            cu_module: m,
294            functions: HashMap::new(),
295            next_function_id: 1,
296        },
297    );
298    Ok(ModuleHandle { id })
299}
300
301fn get_function(
302    modules: &Mutex<HashMap<u64, LoadedModule>>,
303    handle: ModuleHandle,
304    name: &str,
305) -> Result<FunctionHandle, GpuError> {
306    let mut g = modules.lock();
307    let m = g.get_mut(&handle.id).ok_or_else(|| {
308        GpuError::Unrecoverable(format!(
309            "ModuleActor::GetFunction: unknown module {}",
310            handle.id
311        ))
312    })?;
313    let cname = CString::new(name).map_err(|e| {
314        GpuError::Unrecoverable(format!("ModuleActor::GetFunction: NUL in name: {e}"))
315    })?;
316    let f = cuda_driver::module_get_function(m.cu_module, &cname)?;
317    let function_id = m.next_function_id;
318    m.next_function_id += 1;
319    m.functions.insert(function_id, (cname, f));
320    Ok(FunctionHandle {
321        module: handle.id,
322        function_id,
323    })
324}
325
326fn launch_inner(
327    modules: &Mutex<HashMap<u64, LoadedModule>>,
328    stream: &Arc<CudaStream>,
329    function: FunctionHandle,
330    cfg: LaunchConfig,
331    args: Vec<KernelArg>,
332    cooperative: bool,
333) -> Result<(), GpuError> {
334    let g = modules.lock();
335    let m = g.get(&function.module).ok_or_else(|| {
336        GpuError::Unrecoverable(format!(
337            "ModuleActor::Launch: unknown module {}",
338            function.module
339        ))
340    })?;
341    let (_name, cu_func) = m.functions.get(&function.function_id).ok_or_else(|| {
342        GpuError::Unrecoverable(format!(
343            "ModuleActor::Launch: unknown function {}/{}",
344            function.module, function.function_id
345        ))
346    })?;
347    let cu_func = *cu_func;
348
349    // Build the kernel-params array. Each entry is a pointer to a
350    // value owned by the Vec<KernelArgScratch> we allocate below; the
351    // backing storage must outlive the launch call.
352    let mut scratch: Vec<KernelArgScratch> = Vec::with_capacity(args.len());
353    let mut keep_alive: Vec<Arc<cudarc::driver::CudaSlice<u8>>> = Vec::new();
354    for a in args.into_iter() {
355        scratch.push(KernelArgScratch::from_arg(a, &mut keep_alive)?);
356    }
357    let mut ptrs: Vec<*mut std::ffi::c_void> =
358        scratch.iter_mut().map(|s| s.as_void_ptr()).collect();
359
360    let grid = (cfg.grid_dim.0, cfg.grid_dim.1, cfg.grid_dim.2);
361    let block = (cfg.block_dim.0, cfg.block_dim.1, cfg.block_dim.2);
362    let res = if cooperative {
363        cuda_driver::launch_cooperative_kernel(
364            cu_func,
365            grid,
366            block,
367            cfg.shared_mem_bytes,
368            stream.cu_stream(),
369            ptrs.as_mut_ptr(),
370        )
371    } else {
372        cuda_driver::launch_kernel(
373            cu_func,
374            grid,
375            block,
376            cfg.shared_mem_bytes,
377            stream.cu_stream(),
378            ptrs.as_mut_ptr(),
379        )
380    };
381    // Hold scratch + keep_alive across the call; once the driver
382    // consumes the params, they can drop.
383    drop(scratch);
384    drop(keep_alive);
385    res
386}
387
388fn unload(
389    modules: &Mutex<HashMap<u64, LoadedModule>>,
390    handle: ModuleHandle,
391) -> Result<(), GpuError> {
392    let mut g = modules.lock();
393    let m = g.remove(&handle.id).ok_or_else(|| {
394        GpuError::Unrecoverable(format!("ModuleActor::Unload: unknown module {}", handle.id))
395    })?;
396    cuda_driver::module_unload(m.cu_module)
397}
398
399/// Backing storage for one `KernelArg` during a single launch.
400enum KernelArgScratch {
401    DevPtr(driver_sys::CUdeviceptr),
402    F32(f32),
403    F64(f64),
404    I32(i32),
405    U32(u32),
406    U64(u64),
407    Usize(usize),
408}
409
410impl KernelArgScratch {
411    fn from_arg(
412        arg: KernelArg,
413        _keep_alive: &mut Vec<Arc<cudarc::driver::CudaSlice<u8>>>,
414    ) -> Result<Self, GpuError> {
415        // We only need the device pointer for the launch; the keep_alive
416        // vec holds the Arc<CudaSlice> so the allocation stays live.
417        // We reinterpret the slice as bytes for the keep_alive list,
418        // but that requires the same `T` parameter — instead store it
419        // typed.
420        macro_rules! retain {
421            ($g:expr) => {{
422                use cudarc::driver::DevicePtr;
423                let s = $g.access()?.clone();
424                // Capture the device pointer immediately. The
425                // SyncOnDrop guard returned by device_ptr() ties the
426                // lifetime to `s` — we must keep `s` alive for the
427                // duration of the launch.
428                let (ptr, _g) = s.device_ptr(_stream_for_record());
429                let _ = _g;
430                let _ = keep_alive; // satisfy "unused" lint when no slices captured
431                ptr
432            }};
433        }
434        // We can't easily do the macro above because `s` is not
435        // type-erasable to `Arc<CudaSlice<u8>>`. Inline per-type:
436        #[allow(deprecated, unreachable_patterns)]
437        Ok(match arg {
438            KernelArg::DevSliceF32(g) => Self::DevPtr(devptr_of(g)?),
439            KernelArg::DevSliceF64(g) => Self::DevPtr(devptr_of(g)?),
440            KernelArg::DevSliceI32(g) => Self::DevPtr(devptr_of(g)?),
441            KernelArg::DevSliceU32(g) => Self::DevPtr(devptr_of(g)?),
442            KernelArg::DevSliceU8(g) => Self::DevPtr(devptr_of(g)?),
443            KernelArg::ScalarF32(v) => Self::F32(v),
444            KernelArg::ScalarF64(v) => Self::F64(v),
445            KernelArg::ScalarI32(v) => Self::I32(v),
446            KernelArg::ScalarU32(v) => Self::U32(v),
447            KernelArg::ScalarU64(v) => Self::U64(v),
448            KernelArg::Usize(v) => Self::Usize(v),
449            // Phase 0.3 boxed-dispatch variants — wired in a follow-up
450            // PR. The module-launch path can't yet thread typed-erased
451            // box payloads through the cuLaunchKernel ABI without
452            // additional plumbing.
453            KernelArg::DevSlice(_) | KernelArg::Scalar(_) => {
454                return Err(GpuError::Unrecoverable(
455                    "ModuleActor: KernelArg::DevSlice/Scalar dispatch not yet wired".into(),
456                ));
457            }
458        })
459    }
460
461    fn as_void_ptr(&mut self) -> *mut std::ffi::c_void {
462        match self {
463            KernelArgScratch::DevPtr(p) => p as *mut _ as *mut _,
464            KernelArgScratch::F32(v) => v as *mut _ as *mut _,
465            KernelArgScratch::F64(v) => v as *mut _ as *mut _,
466            KernelArgScratch::I32(v) => v as *mut _ as *mut _,
467            KernelArgScratch::U32(v) => v as *mut _ as *mut _,
468            KernelArgScratch::U64(v) => v as *mut _ as *mut _,
469            KernelArgScratch::Usize(v) => v as *mut _ as *mut _,
470        }
471    }
472}
473
474#[allow(dead_code)]
475fn _stream_for_record() -> &'static Arc<cudarc::driver::CudaStream> {
476    // Placeholder used by the macro above — never reached in practice
477    // because we use `devptr_of` directly. Keeping the symbol so
478    // future refactors can centralise the pointer-grabbing pattern.
479    panic!("not used")
480}
481
482fn devptr_of<T>(g: crate::gpu_ref::GpuRef<T>) -> Result<driver_sys::CUdeviceptr, GpuError> {
483    use cudarc::driver::DevicePtr;
484    let s = g.access()?.clone();
485    let stream = s.stream().clone();
486    let (ptr, _guard) = s.device_ptr(&stream);
487    let _ = _guard;
488    let _ = s; // hold across return — we drop right after.
489    Ok(ptr)
490}
491
#[cfg(test)]
mod tests {
    use super::*;
    use atomr_config::Config;
    use atomr_core::actor::ActorSystem;
    use std::time::Duration;

    /// Await a reply. Fails the test if the actor stays silent for 2s or
    /// drops the sender.
    async fn recv<T>(rx: oneshot::Receiver<T>) -> T {
        tokio::time::timeout(Duration::from_secs(2), rx)
            .await
            .expect("ModuleActor did not reply within 2s")
            .expect("ModuleActor dropped the reply sender")
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn module_msg_round_trip() {
        let sys = ActorSystem::create("module-test", Config::empty())
            .await
            .unwrap();
        let actor = sys.actor_of(ModuleActor::mock_props(), "mod").unwrap();

        // A mock actor must answer every request with `Unrecoverable`.
        let (tx, rx) = oneshot::channel();
        actor.tell(ModuleMsg::LoadCubin {
            bytes: vec![1, 2, 3, 4],
            reply: tx,
        });
        assert!(matches!(recv(rx).await, Err(GpuError::Unrecoverable(_))));

        let (tx, rx) = oneshot::channel();
        actor.tell(ModuleMsg::LoadPtx {
            src: ".version 7.0".into(),
            reply: tx,
        });
        assert!(matches!(recv(rx).await, Err(GpuError::Unrecoverable(_))));

        let bogus = ModuleHandle { id: 99 };
        let (tx, rx) = oneshot::channel();
        actor.tell(ModuleMsg::GetFunction {
            handle: bogus,
            name: "kern".into(),
            reply: tx,
        });
        assert!(matches!(recv(rx).await, Err(GpuError::Unrecoverable(_))));

        let bogus_fn = FunctionHandle {
            module: 99,
            function_id: 1,
        };
        let (tx, rx) = oneshot::channel();
        actor.tell(ModuleMsg::Launch {
            function: bogus_fn,
            cfg: LaunchConfig::for_num_elems(64),
            args: vec![],
            reply: tx,
        });
        assert!(matches!(recv(rx).await, Err(GpuError::Unrecoverable(_))));

        let (tx, rx) = oneshot::channel();
        actor.tell(ModuleMsg::LaunchCooperative {
            function: bogus_fn,
            cfg: LaunchConfig::for_num_elems(64),
            args: vec![],
            reply: tx,
        });
        assert!(matches!(recv(rx).await, Err(GpuError::Unrecoverable(_))));

        let (tx, rx) = oneshot::channel();
        actor.tell(ModuleMsg::Unload {
            handle: bogus,
            reply: tx,
        });
        assert!(matches!(recv(rx).await, Err(GpuError::Unrecoverable(_))));

        sys.terminate().await;
    }

    #[cfg(feature = "nvrtc")]
    #[test]
    fn launch_args_share_kernel_arg_type() {
        // Confirm that `KernelArg` re-exported here is the same type
        // as `crate::kernel::nvrtc::KernelArg`. If the re-export ever
        // breaks, this won't compile.
        fn _assert<T>(_x: T) {}
        _assert::<crate::kernel::nvrtc::KernelArg>(KernelArg::Usize(7));
    }

    #[cfg(not(feature = "nvrtc"))]
    #[test]
    fn launch_args_share_kernel_arg_type() {
        // When nvrtc is disabled, the standalone enum still has the
        // expected variants.
        let _arg = KernelArg::Usize(7);
        let _arg2 = KernelArg::ScalarF32(1.0);
    }
}