Skip to main content

atomr_accel_cuda/kernel/blas/
l1.rs

1//! Typed L1 ops: axpy, dot, nrm2, scal, asum, iamax, iamin, copy,
2//! swap, rot.
3//!
4//! Each op routes through the cuBLAS *Ex entry point so we can ship
5//! the same code for f32, f64, f16, and bf16 (and later fp8) without
6//! the per-dtype `cublasS*`/`cublasD*` ladder. The wrappers live in
7//! [`crate::sys::cublas`].
8
9use std::sync::Arc;
10
11use cudarc::driver::{DevicePtr, DevicePtrMut};
12use tokio::sync::oneshot;
13
14use crate::dtype::{AxpyDotNrm2Supported, CudaDtype};
15use crate::error::GpuError;
16use crate::gpu_ref::GpuRef;
17use crate::kernel::dispatch::{BlasDispatchCtx, BlasL1Dispatch};
18use crate::kernel::envelope;
19use crate::sys::cublas as syscublas;
20
21const LIB: &str = "cublas";
22
23// ─────────────────────── AXPY ───────────────────────
24
25pub struct AxpyRequest<T: AxpyDotNrm2Supported> {
26    pub n: i32,
27    pub alpha: T::Scalar,
28    pub x: GpuRef<T>,
29    pub incx: i32,
30    pub y: GpuRef<T>,
31    pub incy: i32,
32    pub reply: oneshot::Sender<Result<(), GpuError>>,
33}
34
35fn dispatch_axpy<T>(req: AxpyRequest<T>, ctx: &BlasDispatchCtx<'_>)
36where
37    T: AxpyDotNrm2Supported,
38{
39    let AxpyRequest {
40        n,
41        alpha,
42        x,
43        incx,
44        y,
45        incy,
46        reply,
47    } = req;
48
49    let (x_slice, y_slice) = match envelope::access_all_2(&x, &y) {
50        Ok(t) => t,
51        Err(e) => {
52            let _ = reply.send(Err(e));
53            return;
54        }
55    };
56
57    let mut y_owned = match Arc::try_unwrap(y_slice) {
58        Ok(s) => s,
59        Err(_arc) => {
60            let _ = reply.send(Err(GpuError::Unrecoverable(
61                "AXPY target buffer Y has more than one live reference".into(),
62            )));
63            return;
64        }
65    };
66
67    y.record_write(ctx.stream);
68
69    let cublas = ctx.cublas.clone();
70    let stream = ctx.stream.clone();
71    envelope::run_kernel(LIB, ctx.stream, ctx.completion, (), reply, move || {
72        let res = {
73            // Scope the SyncOnDrop guards so they release the
74            // borrow on x_slice/y_owned before we move them into
75            // the keep-alive tuple.
76            let (x_ptr, _x_rec) = (*x_slice).device_ptr(&stream);
77            let (y_ptr, _y_rec) = y_owned.device_ptr_mut(&stream);
78            // SAFETY: handle valid; pointers/length come from a
79            // generation-checked GpuRef.
80            unsafe {
81                syscublas::axpy_ex(
82                    *cublas.handle(),
83                    n,
84                    (&alpha) as *const T::Scalar as *const _,
85                    scalar_data_type::<T>(),
86                    x_ptr,
87                    T::cuda_data_type(),
88                    incx,
89                    y_ptr,
90                    T::cuda_data_type(),
91                    incy,
92                    scalar_data_type::<T>(),
93                )
94            }
95        };
96        match res {
97            Ok(()) => Ok((cublas, x_slice, y_owned)),
98            Err(e) => Err(e),
99        }
100    });
101}
102
103impl BlasL1Dispatch for AxpyRequest<f32> {
104    fn dtype_name(&self) -> &'static str {
105        <f32 as atomr_accel::AccelDtype>::NAME
106    }
107    fn op_name(&self) -> &'static str {
108        "axpy"
109    }
110    fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
111        dispatch_axpy::<f32>(*self, ctx);
112    }
113}
114
115impl BlasL1Dispatch for AxpyRequest<f64> {
116    fn dtype_name(&self) -> &'static str {
117        <f64 as atomr_accel::AccelDtype>::NAME
118    }
119    fn op_name(&self) -> &'static str {
120        "axpy"
121    }
122    fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
123        dispatch_axpy::<f64>(*self, ctx);
124    }
125}
126
127#[cfg(feature = "f16")]
128impl BlasL1Dispatch for AxpyRequest<half::f16> {
129    fn dtype_name(&self) -> &'static str {
130        <half::f16 as atomr_accel::AccelDtype>::NAME
131    }
132    fn op_name(&self) -> &'static str {
133        "axpy"
134    }
135    fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
136        dispatch_axpy::<half::f16>(*self, ctx);
137    }
138}
139
140#[cfg(feature = "f16")]
141impl BlasL1Dispatch for AxpyRequest<half::bf16> {
142    fn dtype_name(&self) -> &'static str {
143        <half::bf16 as atomr_accel::AccelDtype>::NAME
144    }
145    fn op_name(&self) -> &'static str {
146        "axpy"
147    }
148    fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
149        dispatch_axpy::<half::bf16>(*self, ctx);
150    }
151}
152
153// ─────────────────────── SCAL ───────────────────────
154
155pub struct ScalRequest<T: AxpyDotNrm2Supported> {
156    pub n: i32,
157    pub alpha: T::Scalar,
158    pub x: GpuRef<T>,
159    pub incx: i32,
160    pub reply: oneshot::Sender<Result<(), GpuError>>,
161}
162
163fn dispatch_scal<T>(req: ScalRequest<T>, ctx: &BlasDispatchCtx<'_>)
164where
165    T: AxpyDotNrm2Supported,
166{
167    let ScalRequest {
168        n,
169        alpha,
170        x,
171        incx,
172        reply,
173    } = req;
174    let x_slice = match x.access() {
175        Ok(s) => s.clone(),
176        Err(e) => {
177            let _ = reply.send(Err(e));
178            return;
179        }
180    };
181    let mut x_owned = match Arc::try_unwrap(x_slice) {
182        Ok(s) => s,
183        Err(_) => {
184            let _ = reply.send(Err(GpuError::Unrecoverable(
185                "SCAL target buffer X has more than one live reference".into(),
186            )));
187            return;
188        }
189    };
190    x.record_write(ctx.stream);
191    let cublas = ctx.cublas.clone();
192    let stream = ctx.stream.clone();
193    envelope::run_kernel(LIB, ctx.stream, ctx.completion, (), reply, move || {
194        let res = {
195            let (x_ptr, _x_rec) = x_owned.device_ptr_mut(&stream);
196            unsafe {
197                syscublas::scal_ex(
198                    *cublas.handle(),
199                    n,
200                    (&alpha) as *const T::Scalar as *const _,
201                    scalar_data_type::<T>(),
202                    x_ptr,
203                    T::cuda_data_type(),
204                    incx,
205                    scalar_data_type::<T>(),
206                )
207            }
208        };
209        match res {
210            Ok(()) => Ok((cublas, x_owned)),
211            Err(e) => Err(e),
212        }
213    });
214}
215
216impl BlasL1Dispatch for ScalRequest<f32> {
217    fn dtype_name(&self) -> &'static str {
218        <f32 as atomr_accel::AccelDtype>::NAME
219    }
220    fn op_name(&self) -> &'static str {
221        "scal"
222    }
223    fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
224        dispatch_scal::<f32>(*self, ctx);
225    }
226}
227
228impl BlasL1Dispatch for ScalRequest<f64> {
229    fn dtype_name(&self) -> &'static str {
230        <f64 as atomr_accel::AccelDtype>::NAME
231    }
232    fn op_name(&self) -> &'static str {
233        "scal"
234    }
235    fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
236        dispatch_scal::<f64>(*self, ctx);
237    }
238}
239
240#[cfg(feature = "f16")]
241impl BlasL1Dispatch for ScalRequest<half::f16> {
242    fn dtype_name(&self) -> &'static str {
243        <half::f16 as atomr_accel::AccelDtype>::NAME
244    }
245    fn op_name(&self) -> &'static str {
246        "scal"
247    }
248    fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
249        dispatch_scal::<half::f16>(*self, ctx);
250    }
251}
252
253#[cfg(feature = "f16")]
254impl BlasL1Dispatch for ScalRequest<half::bf16> {
255    fn dtype_name(&self) -> &'static str {
256        <half::bf16 as atomr_accel::AccelDtype>::NAME
257    }
258    fn op_name(&self) -> &'static str {
259        "scal"
260    }
261    fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
262        dispatch_scal::<half::bf16>(*self, ctx);
263    }
264}
265
266// ─────────────────────── NRM2 ───────────────────────
267
268/// Compute `||x||_2`. The result is written to a host-side
269/// `Box<MaybeUninit<T::Scalar>>` and forwarded back through `reply`.
270pub struct Nrm2Request<T: AxpyDotNrm2Supported> {
271    pub n: i32,
272    pub x: GpuRef<T>,
273    pub incx: i32,
274    pub reply: oneshot::Sender<Result<T::Scalar, GpuError>>,
275}
276
277fn dispatch_nrm2<T>(req: Nrm2Request<T>, ctx: &BlasDispatchCtx<'_>)
278where
279    T: AxpyDotNrm2Supported,
280    T::Scalar: Default,
281{
282    let Nrm2Request { n, x, incx, reply } = req;
283    let x_slice = match x.access() {
284        Ok(s) => s.clone(),
285        Err(e) => {
286            let _ = reply.send(Err(e));
287            return;
288        }
289    };
290    let cublas = ctx.cublas.clone();
291    let stream = ctx.stream.clone();
292    let stream_for_kernel = ctx.stream.clone();
293    let completion = ctx.completion.clone();
294    // We need a host-side scalar. nrm2 with `CUBLAS_POINTER_MODE_HOST`
295    // (the default cuBLAS state) writes to host memory but blocks
296    // until the stream finishes — that defeats the actor's
297    // never-block contract. So we allocate the result on the host
298    // and let cuBLAS sync inline; the caller doesn't await the
299    // completion future, just the reply.
300    //
301    // For Phase 1 we keep the simple path: enqueue + completion.
302    // The kernel writes to a host scalar held in a Box that we
303    // keep alive past completion.
304    let mut result_box = Box::new(T::Scalar::default());
305    let result_ptr = (&mut *result_box) as *mut T::Scalar as *mut core::ffi::c_void;
306
307    let scalar_dt = scalar_data_type::<T>();
308    let exec_dt = T::cuda_data_type();
309
310    let final_reply = reply;
311    // We drive a manual variant of run_kernel so the success arm can
312    // forward the host-side scalar back to the caller.
313    let (inner_tx, inner_rx) = oneshot::channel::<Result<(), GpuError>>();
314    envelope::run_kernel(LIB, ctx.stream, ctx.completion, (), inner_tx, move || {
315        let res = {
316            let (x_ptr, _x_rec) = (*x_slice).device_ptr(&stream);
317            // SAFETY: result_ptr is a valid host pointer to
318            // T::Scalar; the stream callback fires after the
319            // kernel has populated it.
320            unsafe {
321                syscublas::nrm2_ex(
322                    *cublas.handle(),
323                    n,
324                    x_ptr,
325                    T::cuda_data_type(),
326                    incx,
327                    result_ptr,
328                    scalar_dt,
329                    exec_dt,
330                )
331            }
332        };
333        match res {
334            Ok(()) => Ok((cublas, x_slice)),
335            Err(e) => Err(e),
336        }
337    });
338    let _ = stream_for_kernel; // silence unused while the pattern matches the other ops
339    let _ = completion;
340    tokio::spawn(async move {
341        match inner_rx.await {
342            Ok(Ok(())) => {
343                let _ = final_reply.send(Ok(*result_box));
344            }
345            Ok(Err(e)) => {
346                let _ = final_reply.send(Err(e));
347            }
348            Err(_) => {
349                let _ = final_reply.send(Err(GpuError::Timeout));
350            }
351        }
352    });
353}
354
355impl BlasL1Dispatch for Nrm2Request<f32> {
356    fn dtype_name(&self) -> &'static str {
357        <f32 as atomr_accel::AccelDtype>::NAME
358    }
359    fn op_name(&self) -> &'static str {
360        "nrm2"
361    }
362    fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
363        dispatch_nrm2::<f32>(*self, ctx);
364    }
365}
366
367impl BlasL1Dispatch for Nrm2Request<f64> {
368    fn dtype_name(&self) -> &'static str {
369        <f64 as atomr_accel::AccelDtype>::NAME
370    }
371    fn op_name(&self) -> &'static str {
372        "nrm2"
373    }
374    fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
375        dispatch_nrm2::<f64>(*self, ctx);
376    }
377}
378
379// ─────────────────────── DOT ───────────────────────
380
381pub struct DotRequest<T: AxpyDotNrm2Supported> {
382    pub n: i32,
383    pub x: GpuRef<T>,
384    pub incx: i32,
385    pub y: GpuRef<T>,
386    pub incy: i32,
387    pub reply: oneshot::Sender<Result<T::Scalar, GpuError>>,
388}
389
390fn dispatch_dot<T>(req: DotRequest<T>, ctx: &BlasDispatchCtx<'_>)
391where
392    T: AxpyDotNrm2Supported,
393    T::Scalar: Default,
394{
395    let DotRequest {
396        n,
397        x,
398        incx,
399        y,
400        incy,
401        reply,
402    } = req;
403    let (x_slice, y_slice) = match envelope::access_all_2(&x, &y) {
404        Ok(t) => t,
405        Err(e) => {
406            let _ = reply.send(Err(e));
407            return;
408        }
409    };
410
411    let cublas = ctx.cublas.clone();
412    let stream = ctx.stream.clone();
413    let mut result_box = Box::new(T::Scalar::default());
414    let result_ptr = (&mut *result_box) as *mut T::Scalar as *mut core::ffi::c_void;
415    let scalar_dt = scalar_data_type::<T>();
416    let exec_dt = T::cuda_data_type();
417
418    let final_reply = reply;
419    let (inner_tx, inner_rx) = oneshot::channel::<Result<(), GpuError>>();
420    envelope::run_kernel(LIB, ctx.stream, ctx.completion, (), inner_tx, move || {
421        let res = {
422            let (x_ptr, _x_rec) = (*x_slice).device_ptr(&stream);
423            let (y_ptr, _y_rec) = (*y_slice).device_ptr(&stream);
424            unsafe {
425                syscublas::dot_ex(
426                    *cublas.handle(),
427                    n,
428                    x_ptr,
429                    T::cuda_data_type(),
430                    incx,
431                    y_ptr,
432                    T::cuda_data_type(),
433                    incy,
434                    result_ptr,
435                    scalar_dt,
436                    exec_dt,
437                )
438            }
439        };
440        match res {
441            Ok(()) => Ok((cublas, x_slice, y_slice)),
442            Err(e) => Err(e),
443        }
444    });
445    tokio::spawn(async move {
446        match inner_rx.await {
447            Ok(Ok(())) => {
448                let _ = final_reply.send(Ok(*result_box));
449            }
450            Ok(Err(e)) => {
451                let _ = final_reply.send(Err(e));
452            }
453            Err(_) => {
454                let _ = final_reply.send(Err(GpuError::Timeout));
455            }
456        }
457    });
458}
459
460impl BlasL1Dispatch for DotRequest<f32> {
461    fn dtype_name(&self) -> &'static str {
462        <f32 as atomr_accel::AccelDtype>::NAME
463    }
464    fn op_name(&self) -> &'static str {
465        "dot"
466    }
467    fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
468        dispatch_dot::<f32>(*self, ctx);
469    }
470}
471
472impl BlasL1Dispatch for DotRequest<f64> {
473    fn dtype_name(&self) -> &'static str {
474        <f64 as atomr_accel::AccelDtype>::NAME
475    }
476    fn op_name(&self) -> &'static str {
477        "dot"
478    }
479    fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
480        dispatch_dot::<f64>(*self, ctx);
481    }
482}
483
484// ─────────────────────── ASUM ───────────────────────
485
486pub struct AsumRequest<T: AxpyDotNrm2Supported> {
487    pub n: i32,
488    pub x: GpuRef<T>,
489    pub incx: i32,
490    pub reply: oneshot::Sender<Result<T::Scalar, GpuError>>,
491}
492
493fn dispatch_asum<T>(req: AsumRequest<T>, ctx: &BlasDispatchCtx<'_>)
494where
495    T: AxpyDotNrm2Supported,
496    T::Scalar: Default,
497{
498    let AsumRequest { n, x, incx, reply } = req;
499    let x_slice = match x.access() {
500        Ok(s) => s.clone(),
501        Err(e) => {
502            let _ = reply.send(Err(e));
503            return;
504        }
505    };
506    let cublas = ctx.cublas.clone();
507    let stream = ctx.stream.clone();
508    let mut result_box = Box::new(T::Scalar::default());
509    let result_ptr = (&mut *result_box) as *mut T::Scalar as *mut core::ffi::c_void;
510    let scalar_dt = scalar_data_type::<T>();
511    let exec_dt = T::cuda_data_type();
512    let final_reply = reply;
513    let (inner_tx, inner_rx) = oneshot::channel::<Result<(), GpuError>>();
514    envelope::run_kernel(LIB, ctx.stream, ctx.completion, (), inner_tx, move || {
515        let res = {
516            let (x_ptr, _x_rec) = (*x_slice).device_ptr(&stream);
517            unsafe {
518                syscublas::asum_ex(
519                    *cublas.handle(),
520                    n,
521                    x_ptr,
522                    T::cuda_data_type(),
523                    incx,
524                    result_ptr,
525                    scalar_dt,
526                    exec_dt,
527                )
528            }
529        };
530        match res {
531            Ok(()) => Ok((cublas, x_slice)),
532            Err(e) => Err(e),
533        }
534    });
535    tokio::spawn(async move {
536        match inner_rx.await {
537            Ok(Ok(())) => {
538                let _ = final_reply.send(Ok(*result_box));
539            }
540            Ok(Err(e)) => {
541                let _ = final_reply.send(Err(e));
542            }
543            Err(_) => {
544                let _ = final_reply.send(Err(GpuError::Timeout));
545            }
546        }
547    });
548}
549
550impl BlasL1Dispatch for AsumRequest<f32> {
551    fn dtype_name(&self) -> &'static str {
552        <f32 as atomr_accel::AccelDtype>::NAME
553    }
554    fn op_name(&self) -> &'static str {
555        "asum"
556    }
557    fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
558        dispatch_asum::<f32>(*self, ctx);
559    }
560}
561
562impl BlasL1Dispatch for AsumRequest<f64> {
563    fn dtype_name(&self) -> &'static str {
564        <f64 as atomr_accel::AccelDtype>::NAME
565    }
566    fn op_name(&self) -> &'static str {
567        "asum"
568    }
569    fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
570        dispatch_asum::<f64>(*self, ctx);
571    }
572}
573
574// ─────────────────────── IAMAX / IAMIN ───────────────────────
575
576pub struct IamaxRequest<T: AxpyDotNrm2Supported> {
577    pub n: i32,
578    pub x: GpuRef<T>,
579    pub incx: i32,
580    pub reply: oneshot::Sender<Result<i32, GpuError>>,
581}
582
583pub struct IaminRequest<T: AxpyDotNrm2Supported> {
584    pub n: i32,
585    pub x: GpuRef<T>,
586    pub incx: i32,
587    pub reply: oneshot::Sender<Result<i32, GpuError>>,
588}
589
590fn dispatch_iamax_impl<T>(req: IamaxRequest<T>, ctx: &BlasDispatchCtx<'_>, find_min: bool)
591where
592    T: AxpyDotNrm2Supported,
593{
594    let IamaxRequest { n, x, incx, reply } = req;
595    let x_slice = match x.access() {
596        Ok(s) => s.clone(),
597        Err(e) => {
598            let _ = reply.send(Err(e));
599            return;
600        }
601    };
602    let cublas = ctx.cublas.clone();
603    let stream = ctx.stream.clone();
604    let mut result_box = Box::new(0i32);
605    let result_ptr = (&mut *result_box) as *mut i32;
606    let final_reply = reply;
607    let (inner_tx, inner_rx) = oneshot::channel::<Result<(), GpuError>>();
608    envelope::run_kernel(LIB, ctx.stream, ctx.completion, (), inner_tx, move || {
609        let res = {
610            let (x_ptr, _x_rec) = (*x_slice).device_ptr(&stream);
611            if find_min {
612                unsafe {
613                    syscublas::iamin_ex(
614                        *cublas.handle(),
615                        n,
616                        x_ptr,
617                        T::cuda_data_type(),
618                        incx,
619                        result_ptr,
620                    )
621                }
622            } else {
623                unsafe {
624                    syscublas::iamax_ex(
625                        *cublas.handle(),
626                        n,
627                        x_ptr,
628                        T::cuda_data_type(),
629                        incx,
630                        result_ptr,
631                    )
632                }
633            }
634        };
635        match res {
636            Ok(()) => Ok((cublas, x_slice)),
637            Err(e) => Err(e),
638        }
639    });
640    tokio::spawn(async move {
641        match inner_rx.await {
642            Ok(Ok(())) => {
643                let _ = final_reply.send(Ok(*result_box));
644            }
645            Ok(Err(e)) => {
646                let _ = final_reply.send(Err(e));
647            }
648            Err(_) => {
649                let _ = final_reply.send(Err(GpuError::Timeout));
650            }
651        }
652    });
653}
654
655impl BlasL1Dispatch for IamaxRequest<f32> {
656    fn dtype_name(&self) -> &'static str {
657        <f32 as atomr_accel::AccelDtype>::NAME
658    }
659    fn op_name(&self) -> &'static str {
660        "iamax"
661    }
662    fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
663        dispatch_iamax_impl::<f32>(*self, ctx, false);
664    }
665}
666
667impl BlasL1Dispatch for IamaxRequest<f64> {
668    fn dtype_name(&self) -> &'static str {
669        <f64 as atomr_accel::AccelDtype>::NAME
670    }
671    fn op_name(&self) -> &'static str {
672        "iamax"
673    }
674    fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
675        dispatch_iamax_impl::<f64>(*self, ctx, false);
676    }
677}
678
679impl BlasL1Dispatch for IaminRequest<f32> {
680    fn dtype_name(&self) -> &'static str {
681        <f32 as atomr_accel::AccelDtype>::NAME
682    }
683    fn op_name(&self) -> &'static str {
684        "iamin"
685    }
686    fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
687        // IaminRequest<T> has the same field shape as IamaxRequest<T>;
688        // collapse via a helper.
689        let IaminRequest { n, x, incx, reply } = *self;
690        let req = IamaxRequest::<f32> { n, x, incx, reply };
691        dispatch_iamax_impl::<f32>(req, ctx, true);
692    }
693}
694
695impl BlasL1Dispatch for IaminRequest<f64> {
696    fn dtype_name(&self) -> &'static str {
697        <f64 as atomr_accel::AccelDtype>::NAME
698    }
699    fn op_name(&self) -> &'static str {
700        "iamin"
701    }
702    fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
703        let IaminRequest { n, x, incx, reply } = *self;
704        let req = IamaxRequest::<f64> { n, x, incx, reply };
705        dispatch_iamax_impl::<f64>(req, ctx, true);
706    }
707}
708
709// ─────────────────────── COPY ───────────────────────
710
711pub struct CopyRequest<T: AxpyDotNrm2Supported> {
712    pub n: i32,
713    pub x: GpuRef<T>,
714    pub incx: i32,
715    pub y: GpuRef<T>,
716    pub incy: i32,
717    pub reply: oneshot::Sender<Result<(), GpuError>>,
718}
719
720fn dispatch_copy<T>(req: CopyRequest<T>, ctx: &BlasDispatchCtx<'_>)
721where
722    T: AxpyDotNrm2Supported,
723{
724    let CopyRequest {
725        n,
726        x,
727        incx,
728        y,
729        incy,
730        reply,
731    } = req;
732    let (x_slice, y_slice) = match envelope::access_all_2(&x, &y) {
733        Ok(t) => t,
734        Err(e) => {
735            let _ = reply.send(Err(e));
736            return;
737        }
738    };
739    let mut y_owned = match Arc::try_unwrap(y_slice) {
740        Ok(s) => s,
741        Err(_) => {
742            let _ = reply.send(Err(GpuError::Unrecoverable(
743                "COPY target buffer Y has more than one live reference".into(),
744            )));
745            return;
746        }
747    };
748    y.record_write(ctx.stream);
749    let cublas = ctx.cublas.clone();
750    let stream = ctx.stream.clone();
751    envelope::run_kernel(LIB, ctx.stream, ctx.completion, (), reply, move || {
752        let res = {
753            let (x_ptr, _x_rec) = (*x_slice).device_ptr(&stream);
754            let (y_ptr, _y_rec) = y_owned.device_ptr_mut(&stream);
755            unsafe {
756                syscublas::copy_ex(
757                    *cublas.handle(),
758                    n,
759                    x_ptr,
760                    T::cuda_data_type(),
761                    incx,
762                    y_ptr,
763                    T::cuda_data_type(),
764                    incy,
765                )
766            }
767        };
768        match res {
769            Ok(()) => Ok((cublas, x_slice, y_owned)),
770            Err(e) => Err(e),
771        }
772    });
773}
774
775impl BlasL1Dispatch for CopyRequest<f32> {
776    fn dtype_name(&self) -> &'static str {
777        <f32 as atomr_accel::AccelDtype>::NAME
778    }
779    fn op_name(&self) -> &'static str {
780        "copy"
781    }
782    fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
783        dispatch_copy::<f32>(*self, ctx);
784    }
785}
786
787impl BlasL1Dispatch for CopyRequest<f64> {
788    fn dtype_name(&self) -> &'static str {
789        <f64 as atomr_accel::AccelDtype>::NAME
790    }
791    fn op_name(&self) -> &'static str {
792        "copy"
793    }
794    fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
795        dispatch_copy::<f64>(*self, ctx);
796    }
797}
798
799// ─────────────────────── SWAP ───────────────────────
800
801pub struct SwapRequest<T: AxpyDotNrm2Supported> {
802    pub n: i32,
803    pub x: GpuRef<T>,
804    pub incx: i32,
805    pub y: GpuRef<T>,
806    pub incy: i32,
807    pub reply: oneshot::Sender<Result<(), GpuError>>,
808}
809
810fn dispatch_swap<T>(req: SwapRequest<T>, ctx: &BlasDispatchCtx<'_>)
811where
812    T: AxpyDotNrm2Supported,
813{
814    let SwapRequest {
815        n,
816        x,
817        incx,
818        y,
819        incy,
820        reply,
821    } = req;
822    let x_slice = match x.access() {
823        Ok(s) => s.clone(),
824        Err(e) => {
825            let _ = reply.send(Err(e));
826            return;
827        }
828    };
829    let y_slice = match y.access() {
830        Ok(s) => s.clone(),
831        Err(e) => {
832            let _ = reply.send(Err(e));
833            return;
834        }
835    };
836    let mut x_owned = match Arc::try_unwrap(x_slice) {
837        Ok(s) => s,
838        Err(_) => {
839            let _ = reply.send(Err(GpuError::Unrecoverable(
840                "SWAP buffer X has more than one live reference".into(),
841            )));
842            return;
843        }
844    };
845    let mut y_owned = match Arc::try_unwrap(y_slice) {
846        Ok(s) => s,
847        Err(_) => {
848            let _ = reply.send(Err(GpuError::Unrecoverable(
849                "SWAP buffer Y has more than one live reference".into(),
850            )));
851            return;
852        }
853    };
854    x.record_write(ctx.stream);
855    y.record_write(ctx.stream);
856    let cublas = ctx.cublas.clone();
857    let stream = ctx.stream.clone();
858    envelope::run_kernel(LIB, ctx.stream, ctx.completion, (), reply, move || {
859        let res = {
860            let (x_ptr, _x_rec) = x_owned.device_ptr_mut(&stream);
861            let (y_ptr, _y_rec) = y_owned.device_ptr_mut(&stream);
862            unsafe {
863                syscublas::swap_ex(
864                    *cublas.handle(),
865                    n,
866                    x_ptr,
867                    T::cuda_data_type(),
868                    incx,
869                    y_ptr,
870                    T::cuda_data_type(),
871                    incy,
872                )
873            }
874        };
875        match res {
876            Ok(()) => Ok((cublas, x_owned, y_owned)),
877            Err(e) => Err(e),
878        }
879    });
880}
881
882impl BlasL1Dispatch for SwapRequest<f32> {
883    fn dtype_name(&self) -> &'static str {
884        <f32 as atomr_accel::AccelDtype>::NAME
885    }
886    fn op_name(&self) -> &'static str {
887        "swap"
888    }
889    fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
890        dispatch_swap::<f32>(*self, ctx);
891    }
892}
893
894impl BlasL1Dispatch for SwapRequest<f64> {
895    fn dtype_name(&self) -> &'static str {
896        <f64 as atomr_accel::AccelDtype>::NAME
897    }
898    fn op_name(&self) -> &'static str {
899        "swap"
900    }
901    fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
902        dispatch_swap::<f64>(*self, ctx);
903    }
904}
905
906// ─────────────────────── ROT ───────────────────────
907
908/// Givens rotation: applies a 2D rotation `(x_i, y_i) := (c·x_i +
909/// s·y_i, -s·x_i + c·y_i)` in-place across two vectors.
910///
911/// `c` and `s` are passed through `T::Scalar` to match the cuBLAS-Ex
912/// pointer-mode-host convention.
913pub struct RotRequest<T: AxpyDotNrm2Supported> {
914    pub n: i32,
915    pub x: GpuRef<T>,
916    pub incx: i32,
917    pub y: GpuRef<T>,
918    pub incy: i32,
919    pub c: T::Scalar,
920    pub s: T::Scalar,
921    pub reply: oneshot::Sender<Result<(), GpuError>>,
922}
923
924fn dispatch_rot<T>(req: RotRequest<T>, ctx: &BlasDispatchCtx<'_>)
925where
926    T: AxpyDotNrm2Supported,
927{
928    let RotRequest {
929        n,
930        x,
931        incx,
932        y,
933        incy,
934        c,
935        s,
936        reply,
937    } = req;
938    let x_slice = match x.access() {
939        Ok(s) => s.clone(),
940        Err(e) => {
941            let _ = reply.send(Err(e));
942            return;
943        }
944    };
945    let y_slice = match y.access() {
946        Ok(s) => s.clone(),
947        Err(e) => {
948            let _ = reply.send(Err(e));
949            return;
950        }
951    };
952    let mut x_owned = match Arc::try_unwrap(x_slice) {
953        Ok(s) => s,
954        Err(_) => {
955            let _ = reply.send(Err(GpuError::Unrecoverable(
956                "ROT buffer X has more than one live reference".into(),
957            )));
958            return;
959        }
960    };
961    let mut y_owned = match Arc::try_unwrap(y_slice) {
962        Ok(s) => s,
963        Err(_) => {
964            let _ = reply.send(Err(GpuError::Unrecoverable(
965                "ROT buffer Y has more than one live reference".into(),
966            )));
967            return;
968        }
969    };
970    x.record_write(ctx.stream);
971    y.record_write(ctx.stream);
972    let cublas = ctx.cublas.clone();
973    let stream = ctx.stream.clone();
974    let scalar_dt = scalar_data_type::<T>();
975    let exec_dt = T::cuda_data_type();
976    envelope::run_kernel(LIB, ctx.stream, ctx.completion, (), reply, move || {
977        let res = {
978            let (x_ptr, _x_rec) = x_owned.device_ptr_mut(&stream);
979            let (y_ptr, _y_rec) = y_owned.device_ptr_mut(&stream);
980            unsafe {
981                syscublas::rot_ex(
982                    *cublas.handle(),
983                    n,
984                    x_ptr,
985                    T::cuda_data_type(),
986                    incx,
987                    y_ptr,
988                    T::cuda_data_type(),
989                    incy,
990                    (&c) as *const T::Scalar as *const _,
991                    (&s) as *const T::Scalar as *const _,
992                    scalar_dt,
993                    exec_dt,
994                )
995            }
996        };
997        match res {
998            Ok(()) => Ok((cublas, x_owned, y_owned, c, s)),
999            Err(e) => Err(e),
1000        }
1001    });
1002}
1003
1004impl BlasL1Dispatch for RotRequest<f32> {
1005    fn dtype_name(&self) -> &'static str {
1006        <f32 as atomr_accel::AccelDtype>::NAME
1007    }
1008    fn op_name(&self) -> &'static str {
1009        "rot"
1010    }
1011    fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
1012        dispatch_rot::<f32>(*self, ctx);
1013    }
1014}
1015
1016impl BlasL1Dispatch for RotRequest<f64> {
1017    fn dtype_name(&self) -> &'static str {
1018        <f64 as atomr_accel::AccelDtype>::NAME
1019    }
1020    fn op_name(&self) -> &'static str {
1021        "rot"
1022    }
1023    fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
1024        dispatch_rot::<f64>(*self, ctx);
1025    }
1026}
1027
1028// ─────────────────────── helpers ───────────────────────
1029
1030/// Map `T::Scalar` to its `cudaDataType_t`. f32→`CUDA_R_32F`,
1031/// f64→`CUDA_R_64F`. Used as the alpha/result-precision argument to
1032/// nrm2/dot/asum/axpy/scal-Ex.
1033fn scalar_data_type<T: CudaDtype>() -> cudarc::cublas::sys::cudaDataType_t {
1034    use core::any::TypeId;
1035    if TypeId::of::<T::Scalar>() == TypeId::of::<f32>() {
1036        cudarc::cublas::sys::cudaDataType_t::CUDA_R_32F
1037    } else if TypeId::of::<T::Scalar>() == TypeId::of::<f64>() {
1038        cudarc::cublas::sys::cudaDataType_t::CUDA_R_64F
1039    } else {
1040        // Should be unreachable: every `T::Scalar` we ship is
1041        // f32 or f64.
1042        panic!(
1043            "Unrecoverable: scalar type for {} is not f32/f64",
1044            <T as atomr_accel::AccelDtype>::NAME
1045        );
1046    }
1047}
1048
#[cfg(test)]
mod tests {
    //! Construction smoke tests for the L1 request types.
    //!
    //! Each test builds a request and boxes it as a `dyn BlasL1Dispatch`.
    //! It then checks the op/dtype names the trait object reports. No
    //! request is ever dispatched.
    //!
    //! Every boxed request is leaked with `Box::leak` instead of dropped.
    //! NOTE(review): presumably the `gpu_ref_stub` handles must not run their
    //! destructors. TODO confirm against `gemm::tests_helpers::gpu_ref_stub`.
    //!
    //! Coverage gap: asum and rot requests are not exercised here yet.

    use super::super::gemm::tests_helpers::gpu_ref_stub;
    use super::*;
    use tokio::sync::oneshot;

    #[test]
    fn axpy_request_round_trip() {
        let (tx, _rx) = oneshot::channel();
        let req = AxpyRequest::<f32> {
            n: 8,
            alpha: 1.0,
            x: gpu_ref_stub::<f32>(),
            incx: 1,
            y: gpu_ref_stub::<f32>(),
            incy: 1,
            reply: tx,
        };
        let boxed: Box<dyn BlasL1Dispatch> = Box::new(req);
        assert_eq!(boxed.op_name(), "axpy");
        assert_eq!(boxed.dtype_name(), "f32");
        Box::leak(boxed);

        let (tx, _rx) = oneshot::channel();
        let req = AxpyRequest::<f64> {
            n: 8,
            alpha: 1.0,
            x: gpu_ref_stub::<f64>(),
            incx: 1,
            y: gpu_ref_stub::<f64>(),
            incy: 1,
            reply: tx,
        };
        let boxed: Box<dyn BlasL1Dispatch> = Box::new(req);
        assert_eq!(boxed.dtype_name(), "f64");
        Box::leak(boxed);
    }

    #[test]
    fn scal_request_round_trip() {
        let (tx, _rx) = oneshot::channel();
        let req = ScalRequest::<f32> {
            n: 4,
            alpha: 2.0,
            x: gpu_ref_stub::<f32>(),
            incx: 1,
            reply: tx,
        };
        let boxed: Box<dyn BlasL1Dispatch> = Box::new(req);
        assert_eq!(boxed.op_name(), "scal");
        Box::leak(boxed);
    }

    /// Covers dot, nrm2, iamax and iamin. The old name listed asum, which
    /// this test never constructs.
    #[test]
    fn dot_nrm2_iamax_iamin_request_round_trip() {
        let (tx, _rx) = oneshot::channel();
        let req = DotRequest::<f32> {
            n: 4,
            x: gpu_ref_stub::<f32>(),
            incx: 1,
            y: gpu_ref_stub::<f32>(),
            incy: 1,
            reply: tx,
        };
        let boxed: Box<dyn BlasL1Dispatch> = Box::new(req);
        assert_eq!(boxed.op_name(), "dot");
        Box::leak(boxed);

        let (tx, _rx) = oneshot::channel();
        let req = Nrm2Request::<f32> {
            n: 4,
            x: gpu_ref_stub::<f32>(),
            incx: 1,
            reply: tx,
        };
        let boxed: Box<dyn BlasL1Dispatch> = Box::new(req);
        assert_eq!(boxed.op_name(), "nrm2");
        Box::leak(boxed);

        let (tx, _rx) = oneshot::channel();
        let req = IamaxRequest::<f32> {
            n: 4,
            x: gpu_ref_stub::<f32>(),
            incx: 1,
            reply: tx,
        };
        let boxed: Box<dyn BlasL1Dispatch> = Box::new(req);
        assert_eq!(boxed.op_name(), "iamax");
        Box::leak(boxed);

        let (tx, _rx) = oneshot::channel();
        let req = IaminRequest::<f32> {
            n: 4,
            x: gpu_ref_stub::<f32>(),
            incx: 1,
            reply: tx,
        };
        let boxed: Box<dyn BlasL1Dispatch> = Box::new(req);
        assert_eq!(boxed.op_name(), "iamin");
        Box::leak(boxed);
    }
}