Skip to main content

atomr_accel_cuda/kernel/
envelope.rs

1//! `kernel::envelope` — shared kernel-actor body factored out of
2//! `BlasActor::enqueue_sgemm`.
3//!
4//! Every library actor (`BlasActor`, `CudnnActor`, `FftActor`,
5//! `RngActor`, …) follows the same pattern:
6//!
7//! 1. Validate every input [`GpuRef`] via [`GpuRef::access`] and turn
8//!    the strong [`Arc<CudaSlice<T>>`] into a temporary owner that
9//!    keeps the buffer alive past kernel completion.
10//! 2. Synchronously enqueue the kernel onto the actor's stream. The
11//!    enqueue body is library-specific and provided as a closure.
12//! 3. Spawn an async task that awaits the configured
13//!    [`CompletionStrategy`], delivers the reply on a `oneshot::Sender`,
14//!    and only then drops the temporary owners (so that the kernel
15//!    can't outlive its inputs).
16//!
17//! The envelope handles step 3 uniformly. Pre-launch errors are
18//! reported synchronously through the same `oneshot`. Post-launch
19//! errors arrive through the completion future.
20//!
21//! # Single-writer enforcement
22//!
23//! cudarc's library APIs typically take `&mut Dst` for the write
24//! target. cudarc 0.19 satisfies this for `CudaSlice<T>`. Since a
25//! `GpuRef<T>` wraps `Arc<CudaSlice<T>>`, callers that want write
26//! access to a buffer must hold the unique reference to that
27//! `GpuRef` (so `Arc::try_unwrap` succeeds inside the actor). Each
28//! library actor enforces this contract explicitly — the envelope
29//! does not, because some libraries (cuBLAS gemm with non-zero beta)
30//! read-modify-write the output while others (cuDNN forward conv)
31//! write to a freshly allocated output.
32//!
33//! # Observability hooks (Phase 0.7)
34//!
35//! [`KernelEnvelope`] is an opt-in builder that wraps the same
36//! pipeline with two observability surfaces:
37//!
38//! * A `KernelTrace` callback that fires four lifecycle events
39//!   (`before_enqueue`, `after_enqueue`, `before_complete`,
40//!   `after_complete`). The trait is **always compiled** — when no
41//!   trace is set, the envelope skips the calls entirely.
42//! * An optional NVTX range label. When the `nvtx` cargo feature is
43//!   on, the synchronous-enqueue body is wrapped in a
44//!   `cudarc::nvtx::safe::scoped_range` guard. When the feature is
45//!   off, the field is unused and adds no runtime cost.
46//!
//! Existing callers that use the free [`run_kernel`] function continue
//! to behave exactly as before Phase 0.7: that entry point has the
//! behaviour of a default (trace-less, nvtx-less) envelope.
50
51use std::sync::Arc;
52use std::time::{Duration, Instant};
53
54use cudarc::driver::CudaSlice;
55use futures_util::FutureExt;
56use tokio::sync::oneshot;
57use tracing::warn;
58
59use crate::completion::CompletionStrategy;
60use crate::error::GpuError;
61use crate::gpu_ref::GpuRef;
62
63/// Validate two input `GpuRef`s and return owning `Arc`s of their
64/// underlying slices. Fails fast (synchronously) with `GpuRefStale` if
65/// either is invalid.
66pub fn access_all_2<A, B>(
67    a: &GpuRef<A>,
68    b: &GpuRef<B>,
69) -> Result<(Arc<CudaSlice<A>>, Arc<CudaSlice<B>>), GpuError> {
70    let a_s = a.access()?.clone();
71    let b_s = b.access()?.clone();
72    Ok((a_s, b_s))
73}
74
75/// Validate three input `GpuRef`s and return owning `Arc`s of their
76/// underlying slices.
77pub fn access_all_3<A, B, C>(
78    a: &GpuRef<A>,
79    b: &GpuRef<B>,
80    c: &GpuRef<C>,
81) -> Result<(Arc<CudaSlice<A>>, Arc<CudaSlice<B>>, Arc<CudaSlice<C>>), GpuError> {
82    let a_s = a.access()?.clone();
83    let b_s = b.access()?.clone();
84    let c_s = c.access()?.clone();
85    Ok((a_s, b_s, c_s))
86}
87
88/// Validate four input `GpuRef`s and return owning `Arc`s. Used by
89/// cuDNN convolution which takes (input, filter, output, workspace).
90pub fn access_all_4<A, B, C, D>(
91    a: &GpuRef<A>,
92    b: &GpuRef<B>,
93    c: &GpuRef<C>,
94    d: &GpuRef<D>,
95) -> Result<
96    (
97        Arc<CudaSlice<A>>,
98        Arc<CudaSlice<B>>,
99        Arc<CudaSlice<C>>,
100        Arc<CudaSlice<D>>,
101    ),
102    GpuError,
103> {
104    let a_s = a.access()?.clone();
105    let b_s = b.access()?.clone();
106    let c_s = c.access()?.clone();
107    let d_s = d.access()?.clone();
108    Ok((a_s, b_s, c_s, d_s))
109}
110
111// ---------------------------------------------------------------------------
112// Observability hooks (Phase 0.7).
113// ---------------------------------------------------------------------------
114
/// Per-launch metadata passed to every [`KernelTrace`] callback.
///
/// `dtype` is an `Option<&str>` because some kernel actors do not
/// have a single-dtype identity (e.g. memcpy, NCCL group calls). When
/// Phase 0.1 lands a real `atomr_accel::DType`, this field will be
/// promoted to `Option<atomr_accel::DType>` without changing the trait
/// shape.
///
/// Callbacks receive it as `&KernelInfo<'_>`. The borrowed strings are
/// therefore only guaranteed to live for the duration of the call.
/// Copy them out (e.g. `to_string()`) to retain them.
#[derive(Debug, Clone, Copy)]
pub struct KernelInfo<'a> {
    /// Op identifier (e.g. `"sgemm"`, `"conv2d_forward"`).
    pub op_name: &'a str,
    /// Library tag (e.g. `"cublas"`, `"cudnn"`, `"nccl"`).
    pub library: &'a str,
    /// Stream identity. The raw CUstream pointer cast to `u64`; opaque
    /// to consumers but stable for the lifetime of the stream.
    pub stream_id: u64,
    /// Element dtype, if the op is single-dtype. `None` otherwise.
    pub dtype: Option<&'a str>,
}
134
/// Lifecycle hook receiver. All four methods have empty default
/// bodies, so a custom trace can override only the events it cares
/// about.
///
/// The trait is always compiled; no feature flag is required. When no
/// trace is attached to a [`KernelEnvelope`], the envelope skips every
/// hook call. What remains at each hook site is a check of the
/// envelope's `Option` trace handle.
///
/// Implementations must be `Send + Sync + 'static` because the
/// completion-side hooks (`before_complete`, `after_complete`) run on a
/// spawned Tokio task rather than on the calling actor's task.
pub trait KernelTrace: Send + Sync + 'static {
    /// Fires immediately before the synchronous enqueue closure runs
    /// (before any NVTX range is opened).
    fn before_enqueue(&self, info: &KernelInfo<'_>) {
        let _ = info;
    }

    /// Fires immediately after the synchronous enqueue closure
    /// returns. `result` is `Ok(())` on success or `Err(&GpuError)` if
    /// the enqueue body failed. On failure the error is the already
    /// annotated one (a `GpuError::Driver` is re-tagged as
    /// `GpuError::LibraryError`). The hook runs before that error is
    /// sent to the caller's reply channel.
    fn after_enqueue(&self, info: &KernelInfo<'_>, result: Result<(), &GpuError>) {
        let _ = (info, result);
    }

    /// Fires just before the completion future is awaited (i.e. after
    /// a successful enqueue, on the spawned Tokio task).
    fn before_complete(&self, info: &KernelInfo<'_>) {
        let _ = info;
    }

    /// Fires after the completion future resolves and before the reply
    /// is sent. `latency` is measured from just after `before_complete`
    /// returns until the completion future resolves. It is therefore
    /// host-observed completion latency, not GPU time.
    fn after_complete(
        &self,
        info: &KernelInfo<'_>,
        result: Result<(), &GpuError>,
        latency: Duration,
    ) {
        let _ = (info, result, latency);
    }
}
175
/// Builder/configuration for a single `run_kernel` invocation.
///
/// Existing actors that call the free [`run_kernel`] function are
/// unaffected: that entry point has the trace-less, NVTX-less
/// behaviour of a default envelope. Actors that want observability migrate to
/// `KernelEnvelope::new(lib).with_trace(..).with_nvtx(..).run_kernel(..)`.
///
/// `run_kernel` takes `self` by value, so an envelope is consumed by the
/// launch. `Clone` it first to reuse the same configuration. Cloning copies
/// the `&'static str` fields and bumps the trace `Arc`'s refcount.
#[derive(Clone)]
pub struct KernelEnvelope {
    // Library tag: used for error annotation and reported as `KernelInfo::library`.
    lib_tag: &'static str,
    // Op identifier reported to traces; `new` defaults it to `lib_tag`.
    op_name: &'static str,
    // Optional dtype name reported to traces.
    dtype: Option<&'static str>,
    // Optional lifecycle hook receiver; `None` means every hook is skipped.
    trace: Option<Arc<dyn KernelTrace>>,
    /// NVTX range label. When `Some(..)` and the `nvtx` feature is
    /// enabled, the envelope wraps the synchronous enqueue body in a
    /// `cudarc::nvtx::safe::scoped_range` guard. When the `nvtx`
    /// feature is disabled, the field is read but otherwise unused
    /// (zero runtime cost).
    nvtx_range_name: Option<&'static str>,
}
195
196impl KernelEnvelope {
197    /// Construct a trace-less, NVTX-less envelope tagged with the
198    /// given library. `op_name` defaults to `lib_tag` and can be
199    /// refined via [`Self::with_op_name`].
200    pub fn new(lib_tag: &'static str) -> Self {
201        Self {
202            lib_tag,
203            op_name: lib_tag,
204            dtype: None,
205            trace: None,
206            nvtx_range_name: None,
207        }
208    }
209
210    /// Override the op identifier surfaced to the trace callback (e.g.
211    /// `"sgemm"`, `"conv2d_forward"`).
212    pub fn with_op_name(mut self, op_name: &'static str) -> Self {
213        self.op_name = op_name;
214        self
215    }
216
217    /// Tag the envelope with a dtype name (e.g. `"f32"`, `"f16"`).
218    /// Surfaced to trace callbacks as `KernelInfo::dtype`.
219    pub fn with_dtype(mut self, dtype: &'static str) -> Self {
220        self.dtype = Some(dtype);
221        self
222    }
223
224    /// Attach a `KernelTrace` callback. Cloning `Arc<dyn KernelTrace>`
225    /// is cheap; the same object can be shared across many envelopes.
226    pub fn with_trace(mut self, trace: Arc<dyn KernelTrace>) -> Self {
227        self.trace = Some(trace);
228        self
229    }
230
231    /// Attach an NVTX range label. No-op unless the `nvtx` cargo
232    /// feature is enabled.
233    pub fn with_nvtx(mut self, name: &'static str) -> Self {
234        self.nvtx_range_name = Some(name);
235        self
236    }
237
238    fn info<'a>(&'a self, stream_id: u64) -> KernelInfo<'a> {
239        KernelInfo {
240            op_name: self.op_name,
241            library: self.lib_tag,
242            stream_id,
243            dtype: self.dtype,
244        }
245    }
246
247    /// Builder-style equivalent of the free [`run_kernel`] function
248    /// with observability hooks layered in.
249    ///
250    /// Behaviour without a trace and without an NVTX range is
251    /// byte-for-byte identical to [`run_kernel`].
252    pub fn run_kernel<O, KA, F>(
253        self,
254        stream: &Arc<cudarc::driver::CudaStream>,
255        completion: &Arc<dyn CompletionStrategy>,
256        output: O,
257        reply: oneshot::Sender<Result<O, GpuError>>,
258        enqueue: F,
259    ) where
260        O: Send + 'static,
261        KA: Send + 'static,
262        F: FnOnce() -> Result<KA, GpuError>,
263    {
264        let stream_id = stream.cu_stream() as usize as u64;
265        let info = self.info(stream_id);
266
267        if let Some(t) = self.trace.as_deref() {
268            t.before_enqueue(&info);
269        }
270
271        // NVTX range (if feature on and label set) wraps the enqueue
272        // closure. The guard drops at the end of this block.
273        let enqueue_result = {
274            #[cfg(feature = "nvtx")]
275            let _nvtx_guard = self.nvtx_range_name.map(cudarc::nvtx::safe::scoped_range);
276            #[cfg(not(feature = "nvtx"))]
277            let _ = self.nvtx_range_name;
278
279            enqueue()
280        };
281
282        let keep_alive = match enqueue_result {
283            Ok(ka) => {
284                if let Some(t) = self.trace.as_deref() {
285                    t.after_enqueue(&info, Ok(()));
286                }
287                ka
288            }
289            Err(e) => {
290                let annotated = annotate_error(e, self.lib_tag);
291                if let Some(t) = self.trace.as_deref() {
292                    t.after_enqueue(&info, Err(&annotated));
293                }
294                let _ = reply.send(Err(annotated));
295                return;
296            }
297        };
298
299        let fut = completion.await_completion(stream).boxed();
300        let lib_tag = self.lib_tag;
301        let op_name = self.op_name;
302        let dtype = self.dtype;
303        let trace = self.trace.clone();
304        tokio::spawn(async move {
305            // Re-build the info struct on the spawned task so the
306            // closure doesn't have to capture a self-borrowing
307            // reference.
308            let info = KernelInfo {
309                op_name,
310                library: lib_tag,
311                stream_id,
312                dtype,
313            };
314            if let Some(t) = trace.as_deref() {
315                t.before_complete(&info);
316            }
317            let started = Instant::now();
318            let result = fut.await;
319            let latency = started.elapsed();
320            match result {
321                Ok(()) => {
322                    if let Some(t) = trace.as_deref() {
323                        t.after_complete(&info, Ok(()), latency);
324                    }
325                    let _ = reply.send(Ok(output));
326                }
327                Err(e) => {
328                    warn!(lib = lib_tag, error = %e, "kernel completion failed");
329                    if let Some(t) = trace.as_deref() {
330                        t.after_complete(&info, Err(&e), latency);
331                    }
332                    let _ = reply.send(Err(e));
333                }
334            }
335            // Held until completion resolved; safe to drop now.
336            drop(keep_alive);
337        });
338    }
339}
340
341impl std::fmt::Debug for KernelEnvelope {
342    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
343        f.debug_struct("KernelEnvelope")
344            .field("lib_tag", &self.lib_tag)
345            .field("op_name", &self.op_name)
346            .field("dtype", &self.dtype)
347            .field("nvtx_range_name", &self.nvtx_range_name)
348            .field("trace", &self.trace.as_ref().map(|_| "<dyn KernelTrace>"))
349            .finish()
350    }
351}
352
353/// Run the synchronous-enqueue + async-completion-await pipeline.
354///
355/// `enqueue` runs immediately on the calling actor's task. On success
356/// it returns the **keep-alive tuple** — anything that must outlive
357/// the kernel (input `Arc<CudaSlice<T>>`s, the unwrapped write
358/// target, descriptor handles, etc.). The envelope spawns a Tokio
359/// task that awaits [`CompletionStrategy::await_completion`] for
360/// `stream`, replies via `reply`, and drops the keep-alive only
361/// after completion.
362///
363/// `lib_tag` populates the `lib` field of any error annotation. The
364/// completion future emits its own typed errors; on failure `output`
365/// is discarded.
366///
367/// This is the trace-less, NVTX-less compatibility entry point used by
368/// every actor that hasn't migrated to [`KernelEnvelope`]. Behaviour
369/// is byte-for-byte identical to the pre-Phase-0.7 implementation.
370pub fn run_kernel<O, KA, F>(
371    lib_tag: &'static str,
372    stream: &Arc<cudarc::driver::CudaStream>,
373    completion: &Arc<dyn CompletionStrategy>,
374    output: O,
375    reply: oneshot::Sender<Result<O, GpuError>>,
376    enqueue: F,
377) where
378    O: Send + 'static,
379    KA: Send + 'static,
380    F: FnOnce() -> Result<KA, GpuError>,
381{
382    let keep_alive = match enqueue() {
383        Ok(ka) => ka,
384        Err(e) => {
385            let _ = reply.send(Err(annotate_error(e, lib_tag)));
386            return;
387        }
388    };
389
390    let fut = completion.await_completion(stream).boxed();
391    tokio::spawn(async move {
392        let result = fut.await;
393        match result {
394            Ok(()) => {
395                let _ = reply.send(Ok(output));
396            }
397            Err(e) => {
398                warn!(lib = lib_tag, error = %e, "kernel completion failed");
399                let _ = reply.send(Err(e));
400            }
401        }
402        // Held until completion resolved; safe to drop now.
403        drop(keep_alive);
404    });
405}
406
407/// Tag a generic error with a library name iff it doesn't already
408/// carry a more specific classification. Pre-existing typed variants
409/// (`ContextPoisoned`, `OutOfMemory`, `GpuRefStale`, `Unrecoverable`,
410/// `Timeout`) and pre-tagged `LibraryError` pass through unchanged.
411fn annotate_error(e: GpuError, lib_tag: &'static str) -> GpuError {
412    match e {
413        GpuError::Driver(msg) => GpuError::LibraryError { lib: lib_tag, msg },
414        // Already-tagged or library-agnostic errors pass through.
415        other => other,
416    }
417}
418
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Mutex;

    #[test]
    fn annotate_error_tags_driver_failures() {
        let e = annotate_error(GpuError::Driver("oops".into()), "cudnn");
        match e {
            GpuError::LibraryError { lib, msg } => {
                assert_eq!(lib, "cudnn");
                assert_eq!(msg, "oops");
            }
            other => panic!("expected LibraryError, got {other:?}"),
        }
    }

    #[test]
    fn annotate_error_passes_through_typed_variants() {
        let e = annotate_error(GpuError::OutOfMemory("alloc".into()), "cudnn");
        assert!(matches!(e, GpuError::OutOfMemory(_)));
        let e = annotate_error(GpuError::GpuRefStale("stale"), "cudnn");
        assert!(matches!(e, GpuError::GpuRefStale(_)));
    }

    /// Exercise the synchronous enqueue-failure short-circuit without a
    /// real stream. `run_kernel` needs an `Arc<CudaStream>`, which cannot
    /// be constructed on a host without a GPU. This test therefore replays
    /// the same steps inline: run the enqueue closure, annotate its error,
    /// and send it on the reply channel. It then checks what the receiver
    /// observes. The full envelope is covered by GPU integration tests.
    #[test]
    fn pre_enqueue_error_bypasses_completion() {
        let (tx, mut rx) = oneshot::channel::<Result<u32, GpuError>>();
        let calls = AtomicU32::new(0);
        let enqueue = || -> Result<(), GpuError> {
            calls.fetch_add(1, Ordering::Relaxed);
            Err(GpuError::Driver("forced".into()))
        };
        match enqueue() {
            Ok(()) => panic!("enqueue was expected to fail"),
            Err(e) => {
                assert!(tx.send(Err(annotate_error(e, "cublas"))).is_ok());
            }
        }
        assert_eq!(calls.load(Ordering::Relaxed), 1);
        // The failure must already be on the channel: nothing async ran.
        match rx.try_recv() {
            Ok(Err(GpuError::LibraryError { lib, msg })) => {
                assert_eq!(lib, "cublas");
                assert_eq!(msg, "forced");
            }
            other => panic!("expected annotated LibraryError on the reply, got {other:?}"),
        }
    }

    /// In-memory mock `KernelTrace` that records every event for
    /// later inspection.
    #[derive(Default)]
    struct RecordingTrace {
        events: Mutex<Vec<&'static str>>,
        last_dtype: Mutex<Option<String>>,
        last_op: Mutex<Option<String>>,
        last_lib: Mutex<Option<String>>,
        enqueue_ok: AtomicU32,
        enqueue_err: AtomicU32,
    }

    impl KernelTrace for RecordingTrace {
        fn before_enqueue(&self, info: &KernelInfo<'_>) {
            self.events.lock().unwrap().push("before_enqueue");
            *self.last_op.lock().unwrap() = Some(info.op_name.to_string());
            *self.last_lib.lock().unwrap() = Some(info.library.to_string());
            *self.last_dtype.lock().unwrap() = info.dtype.map(str::to_string);
        }

        fn after_enqueue(&self, _info: &KernelInfo<'_>, result: Result<(), &GpuError>) {
            self.events.lock().unwrap().push("after_enqueue");
            match result {
                Ok(()) => {
                    self.enqueue_ok.fetch_add(1, Ordering::Relaxed);
                }
                Err(_) => {
                    self.enqueue_err.fetch_add(1, Ordering::Relaxed);
                }
            }
        }

        fn before_complete(&self, _info: &KernelInfo<'_>) {
            self.events.lock().unwrap().push("before_complete");
        }

        fn after_complete(
            &self,
            _info: &KernelInfo<'_>,
            _result: Result<(), &GpuError>,
            _latency: Duration,
        ) {
            self.events.lock().unwrap().push("after_complete");
        }
    }

    /// Internal helper used by the trace tests. Mirrors the body of
    /// `KernelEnvelope::run_kernel` minus the actual stream / Tokio
    /// spawn, so we can exercise the trace hooks on a host without a
    /// GPU.
    fn drive_envelope_trace<F>(
        env: &KernelEnvelope,
        enqueue: F,
    ) -> (Result<(), GpuError>, Result<(), GpuError>)
    where
        F: FnOnce() -> Result<(), GpuError>,
    {
        // Synthesise a stream id without touching cudarc.
        let info = env.info(0xDEAD_BEEF);
        if let Some(t) = env.trace.as_deref() {
            t.before_enqueue(&info);
        }
        let enqueue_result = enqueue();
        let enqueue_report = match &enqueue_result {
            Ok(()) => Ok(()),
            Err(e) => Err(annotate_error_clone(e, env.lib_tag)),
        };
        if let Some(t) = env.trace.as_deref() {
            match &enqueue_report {
                Ok(()) => t.after_enqueue(&info, Ok(())),
                Err(e) => t.after_enqueue(&info, Err(e)),
            }
        }
        // Pretend the completion future resolved synchronously with
        // success; that's the path Phase 9 will exercise on real
        // streams. We only care that the trace fires in the right
        // order here.
        if enqueue_report.is_ok() {
            if let Some(t) = env.trace.as_deref() {
                t.before_complete(&info);
                t.after_complete(&info, Ok(()), Duration::from_micros(1));
            }
        }
        (enqueue_result, enqueue_report)
    }

    /// Cheap clone of the error for trace inspection in tests.
    fn annotate_error_clone(e: &GpuError, lib_tag: &'static str) -> GpuError {
        match e {
            GpuError::Driver(msg) => GpuError::LibraryError {
                lib: lib_tag,
                msg: msg.clone(),
            },
            GpuError::OutOfMemory(msg) => GpuError::OutOfMemory(msg.clone()),
            GpuError::ContextPoisoned(msg) => GpuError::ContextPoisoned(msg.clone()),
            GpuError::Unrecoverable(msg) => GpuError::Unrecoverable(msg.clone()),
            GpuError::GpuRefStale(s) => GpuError::GpuRefStale(s),
            GpuError::LibraryError { lib, msg } => GpuError::LibraryError {
                lib,
                msg: msg.clone(),
            },
            // Other variants don't appear on the trace path in these
            // tests; fall back to a generic library error so the
            // helper stays compile-clean across the GpuError surface.
            other => GpuError::LibraryError {
                lib: lib_tag,
                msg: other.to_string(),
            },
        }
    }

    #[test]
    fn envelope_default_is_traceless_and_nvtxless() {
        let env = KernelEnvelope::new("cublas");
        assert!(env.trace.is_none());
        assert!(env.nvtx_range_name.is_none());
        assert_eq!(env.lib_tag, "cublas");
        assert_eq!(env.op_name, "cublas");
        assert!(env.dtype.is_none());
    }

    #[test]
    fn envelope_builder_sets_metadata() {
        let trace = Arc::new(RecordingTrace::default()) as Arc<dyn KernelTrace>;
        let env = KernelEnvelope::new("cublas")
            .with_op_name("sgemm")
            .with_dtype("f32")
            .with_trace(trace)
            .with_nvtx("blas/sgemm");
        assert_eq!(env.op_name, "sgemm");
        assert_eq!(env.dtype, Some("f32"));
        assert_eq!(env.nvtx_range_name, Some("blas/sgemm"));
        assert!(env.trace.is_some());
    }

    #[test]
    fn trace_hooks_fire_in_order_on_success() {
        let trace = Arc::new(RecordingTrace::default());
        let env = KernelEnvelope::new("cublas")
            .with_op_name("sgemm")
            .with_dtype("f32")
            .with_trace(trace.clone() as Arc<dyn KernelTrace>);

        let (enqueue_res, _) = drive_envelope_trace(&env, || Ok(()));
        assert!(enqueue_res.is_ok());
        let events = trace.events.lock().unwrap().clone();
        assert_eq!(
            events,
            vec![
                "before_enqueue",
                "after_enqueue",
                "before_complete",
                "after_complete",
            ]
        );
        assert_eq!(trace.enqueue_ok.load(Ordering::Relaxed), 1);
        assert_eq!(trace.enqueue_err.load(Ordering::Relaxed), 0);
        assert_eq!(trace.last_op.lock().unwrap().as_deref(), Some("sgemm"));
        assert_eq!(trace.last_lib.lock().unwrap().as_deref(), Some("cublas"));
        assert_eq!(trace.last_dtype.lock().unwrap().as_deref(), Some("f32"));
    }

    #[test]
    fn trace_hooks_skip_completion_on_enqueue_error() {
        let trace = Arc::new(RecordingTrace::default());
        let env = KernelEnvelope::new("cudnn")
            .with_op_name("conv2d_forward")
            .with_trace(trace.clone() as Arc<dyn KernelTrace>);

        let (enqueue_res, report) =
            drive_envelope_trace(&env, || Err(GpuError::Driver("forced".into())));
        assert!(enqueue_res.is_err());
        // Driver errors get annotated to LibraryError.
        match report {
            Err(GpuError::LibraryError { lib, msg }) => {
                assert_eq!(lib, "cudnn");
                assert_eq!(msg, "forced");
            }
            other => panic!("expected LibraryError, got {other:?}"),
        }
        let events = trace.events.lock().unwrap().clone();
        assert_eq!(events, vec!["before_enqueue", "after_enqueue"]);
        assert_eq!(trace.enqueue_ok.load(Ordering::Relaxed), 0);
        assert_eq!(trace.enqueue_err.load(Ordering::Relaxed), 1);
    }

    #[test]
    fn envelope_without_trace_is_silent() {
        let env = KernelEnvelope::new("cufft");
        let (res, _) = drive_envelope_trace(&env, || Ok(()));
        assert!(res.is_ok());
        // No trace attached, so nothing to record. The point of this
        // test is to make sure the trace-less path compiles and runs
        // through `drive_envelope_trace` without panicking.
    }
}