1use std::sync::Arc;
52use std::time::{Duration, Instant};
53
54use cudarc::driver::CudaSlice;
55use futures_util::FutureExt;
56use tokio::sync::oneshot;
57use tracing::warn;
58
59use crate::completion::CompletionStrategy;
60use crate::error::GpuError;
61use crate::gpu_ref::GpuRef;
62
63pub fn access_all_2<A, B>(
67 a: &GpuRef<A>,
68 b: &GpuRef<B>,
69) -> Result<(Arc<CudaSlice<A>>, Arc<CudaSlice<B>>), GpuError> {
70 let a_s = a.access()?.clone();
71 let b_s = b.access()?.clone();
72 Ok((a_s, b_s))
73}
74
75pub fn access_all_3<A, B, C>(
78 a: &GpuRef<A>,
79 b: &GpuRef<B>,
80 c: &GpuRef<C>,
81) -> Result<(Arc<CudaSlice<A>>, Arc<CudaSlice<B>>, Arc<CudaSlice<C>>), GpuError> {
82 let a_s = a.access()?.clone();
83 let b_s = b.access()?.clone();
84 let c_s = c.access()?.clone();
85 Ok((a_s, b_s, c_s))
86}
87
88pub fn access_all_4<A, B, C, D>(
91 a: &GpuRef<A>,
92 b: &GpuRef<B>,
93 c: &GpuRef<C>,
94 d: &GpuRef<D>,
95) -> Result<
96 (
97 Arc<CudaSlice<A>>,
98 Arc<CudaSlice<B>>,
99 Arc<CudaSlice<C>>,
100 Arc<CudaSlice<D>>,
101 ),
102 GpuError,
103> {
104 let a_s = a.access()?.clone();
105 let b_s = b.access()?.clone();
106 let c_s = c.access()?.clone();
107 let d_s = d.access()?.clone();
108 Ok((a_s, b_s, c_s, d_s))
109}
110
/// Borrowed metadata describing a single kernel launch.
///
/// A reference to this value is passed to every [`KernelTrace`] hook.
#[derive(Debug, Clone, Copy)]
pub struct KernelInfo<'a> {
    /// Name of the operation being launched. `KernelEnvelope` fills this from
    /// its own `op_name`, which defaults to the library tag.
    pub op_name: &'a str,
    /// Tag of the library the launch is attributed to.
    pub library: &'a str,
    /// Identifier of the stream the kernel was enqueued on.
    /// `KernelEnvelope::run_kernel` derives it from the raw stream handle
    /// cast to an integer.
    pub stream_id: u64,
    /// Optional element-type label; `None` when the caller did not set one.
    pub dtype: Option<&'a str>,
}
134
135pub trait KernelTrace: Send + Sync + 'static {
144 fn before_enqueue(&self, info: &KernelInfo<'_>) {
146 let _ = info;
147 }
148
149 fn after_enqueue(&self, info: &KernelInfo<'_>, result: Result<(), &GpuError>) {
153 let _ = (info, result);
154 }
155
156 fn before_complete(&self, info: &KernelInfo<'_>) {
159 let _ = info;
160 }
161
162 fn after_complete(
167 &self,
168 info: &KernelInfo<'_>,
169 result: Result<(), &GpuError>,
170 latency: Duration,
171 ) {
172 let _ = (info, result, latency);
173 }
174}
175
/// Builder-style description of how a kernel launch should be tagged,
/// traced and (optionally) wrapped in an NVTX range.
///
/// Construct with `KernelEnvelope::new`, refine with the `with_*` methods,
/// then consume with `run_kernel`.
#[derive(Clone)]
pub struct KernelEnvelope {
    /// Library tag; used to annotate enqueue errors and as the `lib` field of
    /// the completion-failure log line.
    lib_tag: &'static str,
    /// Operation name reported to trace hooks; starts out equal to `lib_tag`.
    op_name: &'static str,
    /// Optional element-type label reported to trace hooks.
    dtype: Option<&'static str>,
    /// Optional observer notified around the enqueue and completion steps.
    trace: Option<Arc<dyn KernelTrace>>,
    /// Optional NVTX range name; only consulted when the `nvtx` feature is on.
    nvtx_range_name: Option<&'static str>,
}
195
196impl KernelEnvelope {
197 pub fn new(lib_tag: &'static str) -> Self {
201 Self {
202 lib_tag,
203 op_name: lib_tag,
204 dtype: None,
205 trace: None,
206 nvtx_range_name: None,
207 }
208 }
209
210 pub fn with_op_name(mut self, op_name: &'static str) -> Self {
213 self.op_name = op_name;
214 self
215 }
216
217 pub fn with_dtype(mut self, dtype: &'static str) -> Self {
220 self.dtype = Some(dtype);
221 self
222 }
223
224 pub fn with_trace(mut self, trace: Arc<dyn KernelTrace>) -> Self {
227 self.trace = Some(trace);
228 self
229 }
230
231 pub fn with_nvtx(mut self, name: &'static str) -> Self {
234 self.nvtx_range_name = Some(name);
235 self
236 }
237
238 fn info<'a>(&'a self, stream_id: u64) -> KernelInfo<'a> {
239 KernelInfo {
240 op_name: self.op_name,
241 library: self.lib_tag,
242 stream_id,
243 dtype: self.dtype,
244 }
245 }
246
247 pub fn run_kernel<O, KA, F>(
253 self,
254 stream: &Arc<cudarc::driver::CudaStream>,
255 completion: &Arc<dyn CompletionStrategy>,
256 output: O,
257 reply: oneshot::Sender<Result<O, GpuError>>,
258 enqueue: F,
259 ) where
260 O: Send + 'static,
261 KA: Send + 'static,
262 F: FnOnce() -> Result<KA, GpuError>,
263 {
264 let stream_id = stream.cu_stream() as usize as u64;
265 let info = self.info(stream_id);
266
267 if let Some(t) = self.trace.as_deref() {
268 t.before_enqueue(&info);
269 }
270
271 let enqueue_result = {
274 #[cfg(feature = "nvtx")]
275 let _nvtx_guard = self.nvtx_range_name.map(cudarc::nvtx::safe::scoped_range);
276 #[cfg(not(feature = "nvtx"))]
277 let _ = self.nvtx_range_name;
278
279 enqueue()
280 };
281
282 let keep_alive = match enqueue_result {
283 Ok(ka) => {
284 if let Some(t) = self.trace.as_deref() {
285 t.after_enqueue(&info, Ok(()));
286 }
287 ka
288 }
289 Err(e) => {
290 let annotated = annotate_error(e, self.lib_tag);
291 if let Some(t) = self.trace.as_deref() {
292 t.after_enqueue(&info, Err(&annotated));
293 }
294 let _ = reply.send(Err(annotated));
295 return;
296 }
297 };
298
299 let fut = completion.await_completion(stream).boxed();
300 let lib_tag = self.lib_tag;
301 let op_name = self.op_name;
302 let dtype = self.dtype;
303 let trace = self.trace.clone();
304 tokio::spawn(async move {
305 let info = KernelInfo {
309 op_name,
310 library: lib_tag,
311 stream_id,
312 dtype,
313 };
314 if let Some(t) = trace.as_deref() {
315 t.before_complete(&info);
316 }
317 let started = Instant::now();
318 let result = fut.await;
319 let latency = started.elapsed();
320 match result {
321 Ok(()) => {
322 if let Some(t) = trace.as_deref() {
323 t.after_complete(&info, Ok(()), latency);
324 }
325 let _ = reply.send(Ok(output));
326 }
327 Err(e) => {
328 warn!(lib = lib_tag, error = %e, "kernel completion failed");
329 if let Some(t) = trace.as_deref() {
330 t.after_complete(&info, Err(&e), latency);
331 }
332 let _ = reply.send(Err(e));
333 }
334 }
335 drop(keep_alive);
337 });
338 }
339}
340
341impl std::fmt::Debug for KernelEnvelope {
342 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
343 f.debug_struct("KernelEnvelope")
344 .field("lib_tag", &self.lib_tag)
345 .field("op_name", &self.op_name)
346 .field("dtype", &self.dtype)
347 .field("nvtx_range_name", &self.nvtx_range_name)
348 .field("trace", &self.trace.as_ref().map(|_| "<dyn KernelTrace>"))
349 .finish()
350 }
351}
352
353pub fn run_kernel<O, KA, F>(
371 lib_tag: &'static str,
372 stream: &Arc<cudarc::driver::CudaStream>,
373 completion: &Arc<dyn CompletionStrategy>,
374 output: O,
375 reply: oneshot::Sender<Result<O, GpuError>>,
376 enqueue: F,
377) where
378 O: Send + 'static,
379 KA: Send + 'static,
380 F: FnOnce() -> Result<KA, GpuError>,
381{
382 let keep_alive = match enqueue() {
383 Ok(ka) => ka,
384 Err(e) => {
385 let _ = reply.send(Err(annotate_error(e, lib_tag)));
386 return;
387 }
388 };
389
390 let fut = completion.await_completion(stream).boxed();
391 tokio::spawn(async move {
392 let result = fut.await;
393 match result {
394 Ok(()) => {
395 let _ = reply.send(Ok(output));
396 }
397 Err(e) => {
398 warn!(lib = lib_tag, error = %e, "kernel completion failed");
399 let _ = reply.send(Err(e));
400 }
401 }
402 drop(keep_alive);
404 });
405}
406
407fn annotate_error(e: GpuError, lib_tag: &'static str) -> GpuError {
412 match e {
413 GpuError::Driver(msg) => GpuError::LibraryError { lib: lib_tag, msg },
414 other => other,
416 }
417}
418
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Mutex;

    /// A bare driver error is rewrapped as a `LibraryError` carrying the tag.
    #[test]
    fn annotate_error_tags_driver_failures() {
        match annotate_error(GpuError::Driver("oops".into()), "cudnn") {
            GpuError::LibraryError { lib, msg } => {
                assert_eq!(lib, "cudnn");
                assert_eq!(msg, "oops");
            }
            other => panic!("expected LibraryError, got {other:?}"),
        }
    }

    /// Variants other than `Driver` come back with the same variant.
    #[test]
    fn annotate_error_passes_through_typed_variants() {
        let oom = annotate_error(GpuError::OutOfMemory("alloc".into()), "cudnn");
        assert!(matches!(oom, GpuError::OutOfMemory(_)));

        let stale = annotate_error(GpuError::GpuRefStale("stale"), "cudnn");
        assert!(matches!(stale, GpuError::GpuRefStale(_)));
    }

    /// Exercises a failing enqueue closure on its own: it runs exactly once
    /// and yields its error. The reply channel is created and dropped without
    /// going through `run_kernel`.
    #[test]
    fn pre_enqueue_error_bypasses_completion() {
        let (reply_tx, reply_rx) = oneshot::channel::<Result<u32, GpuError>>();
        let calls = AtomicU32::new(0);
        let failing_enqueue = || -> Result<(), GpuError> {
            calls.fetch_add(1, Ordering::Relaxed);
            Err(GpuError::OutOfMemory("forced".into()))
        };
        let outcome = failing_enqueue();
        assert!(matches!(outcome, Err(GpuError::OutOfMemory(_))));
        assert_eq!(calls.load(Ordering::Relaxed), 1);
        drop(reply_tx);
        drop(reply_rx);
    }

    /// `KernelTrace` implementation that records which hooks fired, in order,
    /// plus the metadata seen by `before_enqueue` and enqueue outcome counts.
    #[derive(Default)]
    struct RecordingTrace {
        events: Mutex<Vec<&'static str>>,
        last_dtype: Mutex<Option<String>>,
        last_op: Mutex<Option<String>>,
        last_lib: Mutex<Option<String>>,
        enqueue_ok: AtomicU32,
        enqueue_err: AtomicU32,
    }

    impl RecordingTrace {
        /// Appends a hook name to the ordered event log.
        fn record(&self, hook: &'static str) {
            self.events.lock().unwrap().push(hook);
        }
    }

    impl KernelTrace for RecordingTrace {
        fn before_enqueue(&self, info: &KernelInfo<'_>) {
            self.record("before_enqueue");
            *self.last_op.lock().unwrap() = Some(info.op_name.to_owned());
            *self.last_lib.lock().unwrap() = Some(info.library.to_owned());
            *self.last_dtype.lock().unwrap() = info.dtype.map(String::from);
        }

        fn after_enqueue(&self, _info: &KernelInfo<'_>, result: Result<(), &GpuError>) {
            self.record("after_enqueue");
            let counter = if result.is_ok() {
                &self.enqueue_ok
            } else {
                &self.enqueue_err
            };
            counter.fetch_add(1, Ordering::Relaxed);
        }

        fn before_complete(&self, _info: &KernelInfo<'_>) {
            self.record("before_complete");
        }

        fn after_complete(
            &self,
            _info: &KernelInfo<'_>,
            _result: Result<(), &GpuError>,
            _latency: Duration,
        ) {
            self.record("after_complete");
        }
    }

    /// Replays the hook sequence of `KernelEnvelope::run_kernel` synchronously,
    /// without a CUDA stream or a Tokio task.
    ///
    /// Completion is simulated as an immediate success with a fixed latency.
    /// Returns the raw enqueue result together with the report passed to
    /// `after_enqueue` (whose error, if any, has been tagged).
    fn drive_envelope_trace<F>(
        env: &KernelEnvelope,
        enqueue: F,
    ) -> (Result<(), GpuError>, Result<(), GpuError>)
    where
        F: FnOnce() -> Result<(), GpuError>,
    {
        let info = env.info(0xDEAD_BEEF);
        let hooks = env.trace.as_deref();

        if let Some(h) = hooks {
            h.before_enqueue(&info);
        }

        let raw = enqueue();
        let report = raw
            .as_ref()
            .map(|&()| ())
            .map_err(|err| annotate_error_clone(err, env.lib_tag));

        if let Some(h) = hooks {
            h.after_enqueue(&info, report.as_ref().map(|&()| ()));
            // Completion hooks fire only when the enqueue step succeeded.
            if report.is_ok() {
                h.before_complete(&info);
                h.after_complete(&info, Ok(()), Duration::from_micros(1));
            }
        }

        (raw, report)
    }

    /// Borrowing counterpart of `annotate_error` for tests: produces an owned,
    /// tagged copy of `e`. Variants not listed explicitly are folded into a
    /// `LibraryError` carrying their display text.
    fn annotate_error_clone(e: &GpuError, lib_tag: &'static str) -> GpuError {
        match e {
            GpuError::Driver(text) => GpuError::LibraryError {
                lib: lib_tag,
                msg: text.clone(),
            },
            GpuError::LibraryError { lib, msg } => GpuError::LibraryError {
                lib,
                msg: msg.clone(),
            },
            GpuError::OutOfMemory(text) => GpuError::OutOfMemory(text.clone()),
            GpuError::ContextPoisoned(text) => GpuError::ContextPoisoned(text.clone()),
            GpuError::Unrecoverable(text) => GpuError::Unrecoverable(text.clone()),
            GpuError::GpuRefStale(what) => GpuError::GpuRefStale(what),
            unlisted => GpuError::LibraryError {
                lib: lib_tag,
                msg: unlisted.to_string(),
            },
        }
    }

    /// `new` leaves trace, NVTX and dtype unset and reuses the tag as op name.
    #[test]
    fn envelope_default_is_traceless_and_nvtxless() {
        let envelope = KernelEnvelope::new("cublas");
        assert_eq!(envelope.lib_tag, "cublas");
        assert_eq!(envelope.op_name, "cublas");
        assert!(envelope.dtype.is_none());
        assert!(envelope.trace.is_none());
        assert!(envelope.nvtx_range_name.is_none());
    }

    /// Each `with_*` builder stores its argument.
    #[test]
    fn envelope_builder_sets_metadata() {
        let hooks: Arc<dyn KernelTrace> = Arc::new(RecordingTrace::default());
        let envelope = KernelEnvelope::new("cublas")
            .with_op_name("sgemm")
            .with_dtype("f32")
            .with_trace(hooks)
            .with_nvtx("blas/sgemm");
        assert_eq!(envelope.op_name, "sgemm");
        assert_eq!(envelope.dtype, Some("f32"));
        assert_eq!(envelope.nvtx_range_name, Some("blas/sgemm"));
        assert!(envelope.trace.is_some());
    }

    /// On success all four hooks fire, in lifecycle order, with the
    /// envelope's metadata.
    #[test]
    fn trace_hooks_fire_in_order_on_success() {
        let recorder = Arc::new(RecordingTrace::default());
        let hooks: Arc<dyn KernelTrace> = recorder.clone();
        let envelope = KernelEnvelope::new("cublas")
            .with_op_name("sgemm")
            .with_dtype("f32")
            .with_trace(hooks);

        let (raw, _) = drive_envelope_trace(&envelope, || Ok(()));
        assert!(raw.is_ok());

        let seen = recorder.events.lock().unwrap().clone();
        assert_eq!(
            seen,
            [
                "before_enqueue",
                "after_enqueue",
                "before_complete",
                "after_complete",
            ]
        );
        assert_eq!(recorder.enqueue_ok.load(Ordering::Relaxed), 1);
        assert_eq!(recorder.enqueue_err.load(Ordering::Relaxed), 0);
        assert_eq!(recorder.last_op.lock().unwrap().as_deref(), Some("sgemm"));
        assert_eq!(recorder.last_lib.lock().unwrap().as_deref(), Some("cublas"));
        assert_eq!(recorder.last_dtype.lock().unwrap().as_deref(), Some("f32"));
    }

    /// On enqueue failure only the two enqueue hooks fire, and the reported
    /// error is tagged with the library name.
    #[test]
    fn trace_hooks_skip_completion_on_enqueue_error() {
        let recorder = Arc::new(RecordingTrace::default());
        let hooks: Arc<dyn KernelTrace> = recorder.clone();
        let envelope = KernelEnvelope::new("cudnn")
            .with_op_name("conv2d_forward")
            .with_trace(hooks);

        let (raw, report) =
            drive_envelope_trace(&envelope, || Err(GpuError::Driver("forced".into())));
        assert!(raw.is_err());
        match report {
            Err(GpuError::LibraryError { lib, msg }) => {
                assert_eq!(lib, "cudnn");
                assert_eq!(msg, "forced");
            }
            other => panic!("expected LibraryError, got {other:?}"),
        }

        let seen = recorder.events.lock().unwrap().clone();
        assert_eq!(seen, ["before_enqueue", "after_enqueue"]);
        assert_eq!(recorder.enqueue_ok.load(Ordering::Relaxed), 0);
        assert_eq!(recorder.enqueue_err.load(Ordering::Relaxed), 1);
    }

    /// An envelope with no trace installed still drives to a successful result.
    #[test]
    fn envelope_without_trace_is_silent() {
        let envelope = KernelEnvelope::new("cufft");
        let (raw, _) = drive_envelope_trace(&envelope, || Ok(()));
        assert!(raw.is_ok());
    }
}