1use std::sync::Arc;
10
11use cudarc::driver::{DevicePtr, DevicePtrMut};
12use tokio::sync::oneshot;
13
14use crate::dtype::{AxpyDotNrm2Supported, CudaDtype};
15use crate::error::GpuError;
16use crate::gpu_ref::GpuRef;
17use crate::kernel::dispatch::{BlasDispatchCtx, BlasL1Dispatch};
18use crate::kernel::envelope;
19use crate::sys::cublas as syscublas;
20
/// Library tag passed as the first argument to every `envelope::run_kernel` call in this module.
const LIB: &str = "cublas";
22
/// Request for an AXPY update, `y <- alpha * x + y`, launched through `cublasAxpyEx`.
///
/// Consumed by `dispatch_axpy`; `y` is modified in place.
pub struct AxpyRequest<T: AxpyDotNrm2Supported> {
    /// Element count, passed to cuBLAS as `n`.
    pub n: i32,
    /// Multiplier applied to `x`; its host address is passed to cuBLAS as `alpha`.
    pub alpha: T::Scalar,
    /// Input vector (only read).
    pub x: GpuRef<T>,
    /// Stride between consecutive `x` elements, passed to cuBLAS as `incx`.
    pub incx: i32,
    /// Output vector, updated in place.
    pub y: GpuRef<T>,
    /// Stride between consecutive `y` elements, passed to cuBLAS as `incy`.
    pub incy: i32,
    /// Receives errors detected by `dispatch_axpy` itself; otherwise it is handed to
    /// `envelope::run_kernel`, which reports the kernel outcome on it.
    pub reply: oneshot::Sender<Result<(), GpuError>>,
}
34
35fn dispatch_axpy<T>(req: AxpyRequest<T>, ctx: &BlasDispatchCtx<'_>)
36where
37 T: AxpyDotNrm2Supported,
38{
39 let AxpyRequest {
40 n,
41 alpha,
42 x,
43 incx,
44 y,
45 incy,
46 reply,
47 } = req;
48
49 let (x_slice, y_slice) = match envelope::access_all_2(&x, &y) {
50 Ok(t) => t,
51 Err(e) => {
52 let _ = reply.send(Err(e));
53 return;
54 }
55 };
56
57 let mut y_owned = match Arc::try_unwrap(y_slice) {
58 Ok(s) => s,
59 Err(_arc) => {
60 let _ = reply.send(Err(GpuError::Unrecoverable(
61 "AXPY target buffer Y has more than one live reference".into(),
62 )));
63 return;
64 }
65 };
66
67 y.record_write(ctx.stream);
68
69 let cublas = ctx.cublas.clone();
70 let stream = ctx.stream.clone();
71 envelope::run_kernel(LIB, ctx.stream, ctx.completion, (), reply, move || {
72 let res = {
73 let (x_ptr, _x_rec) = (*x_slice).device_ptr(&stream);
77 let (y_ptr, _y_rec) = y_owned.device_ptr_mut(&stream);
78 unsafe {
81 syscublas::axpy_ex(
82 *cublas.handle(),
83 n,
84 (&alpha) as *const T::Scalar as *const _,
85 scalar_data_type::<T>(),
86 x_ptr,
87 T::cuda_data_type(),
88 incx,
89 y_ptr,
90 T::cuda_data_type(),
91 incy,
92 scalar_data_type::<T>(),
93 )
94 }
95 };
96 match res {
97 Ok(()) => Ok((cublas, x_slice, y_owned)),
98 Err(e) => Err(e),
99 }
100 });
101}
102
103impl BlasL1Dispatch for AxpyRequest<f32> {
104 fn dtype_name(&self) -> &'static str {
105 <f32 as atomr_accel::AccelDtype>::NAME
106 }
107 fn op_name(&self) -> &'static str {
108 "axpy"
109 }
110 fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
111 dispatch_axpy::<f32>(*self, ctx);
112 }
113}
114
115impl BlasL1Dispatch for AxpyRequest<f64> {
116 fn dtype_name(&self) -> &'static str {
117 <f64 as atomr_accel::AccelDtype>::NAME
118 }
119 fn op_name(&self) -> &'static str {
120 "axpy"
121 }
122 fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
123 dispatch_axpy::<f64>(*self, ctx);
124 }
125}
126
127#[cfg(feature = "f16")]
128impl BlasL1Dispatch for AxpyRequest<half::f16> {
129 fn dtype_name(&self) -> &'static str {
130 <half::f16 as atomr_accel::AccelDtype>::NAME
131 }
132 fn op_name(&self) -> &'static str {
133 "axpy"
134 }
135 fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
136 dispatch_axpy::<half::f16>(*self, ctx);
137 }
138}
139
140#[cfg(feature = "f16")]
141impl BlasL1Dispatch for AxpyRequest<half::bf16> {
142 fn dtype_name(&self) -> &'static str {
143 <half::bf16 as atomr_accel::AccelDtype>::NAME
144 }
145 fn op_name(&self) -> &'static str {
146 "axpy"
147 }
148 fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
149 dispatch_axpy::<half::bf16>(*self, ctx);
150 }
151}
152
/// Request for an in-place SCAL, `x <- alpha * x`, launched through `cublasScalEx`.
///
/// Consumed by `dispatch_scal`.
pub struct ScalRequest<T: AxpyDotNrm2Supported> {
    /// Element count, passed to cuBLAS as `n`.
    pub n: i32,
    /// Scale factor; its host address is passed to cuBLAS as `alpha`.
    pub alpha: T::Scalar,
    /// Vector scaled in place.
    pub x: GpuRef<T>,
    /// Stride between consecutive `x` elements, passed to cuBLAS as `incx`.
    pub incx: i32,
    /// Receives errors detected by `dispatch_scal` itself; otherwise it is handed to
    /// `envelope::run_kernel`, which reports the kernel outcome on it.
    pub reply: oneshot::Sender<Result<(), GpuError>>,
}
162
163fn dispatch_scal<T>(req: ScalRequest<T>, ctx: &BlasDispatchCtx<'_>)
164where
165 T: AxpyDotNrm2Supported,
166{
167 let ScalRequest {
168 n,
169 alpha,
170 x,
171 incx,
172 reply,
173 } = req;
174 let x_slice = match x.access() {
175 Ok(s) => s.clone(),
176 Err(e) => {
177 let _ = reply.send(Err(e));
178 return;
179 }
180 };
181 let mut x_owned = match Arc::try_unwrap(x_slice) {
182 Ok(s) => s,
183 Err(_) => {
184 let _ = reply.send(Err(GpuError::Unrecoverable(
185 "SCAL target buffer X has more than one live reference".into(),
186 )));
187 return;
188 }
189 };
190 x.record_write(ctx.stream);
191 let cublas = ctx.cublas.clone();
192 let stream = ctx.stream.clone();
193 envelope::run_kernel(LIB, ctx.stream, ctx.completion, (), reply, move || {
194 let res = {
195 let (x_ptr, _x_rec) = x_owned.device_ptr_mut(&stream);
196 unsafe {
197 syscublas::scal_ex(
198 *cublas.handle(),
199 n,
200 (&alpha) as *const T::Scalar as *const _,
201 scalar_data_type::<T>(),
202 x_ptr,
203 T::cuda_data_type(),
204 incx,
205 scalar_data_type::<T>(),
206 )
207 }
208 };
209 match res {
210 Ok(()) => Ok((cublas, x_owned)),
211 Err(e) => Err(e),
212 }
213 });
214}
215
216impl BlasL1Dispatch for ScalRequest<f32> {
217 fn dtype_name(&self) -> &'static str {
218 <f32 as atomr_accel::AccelDtype>::NAME
219 }
220 fn op_name(&self) -> &'static str {
221 "scal"
222 }
223 fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
224 dispatch_scal::<f32>(*self, ctx);
225 }
226}
227
228impl BlasL1Dispatch for ScalRequest<f64> {
229 fn dtype_name(&self) -> &'static str {
230 <f64 as atomr_accel::AccelDtype>::NAME
231 }
232 fn op_name(&self) -> &'static str {
233 "scal"
234 }
235 fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
236 dispatch_scal::<f64>(*self, ctx);
237 }
238}
239
240#[cfg(feature = "f16")]
241impl BlasL1Dispatch for ScalRequest<half::f16> {
242 fn dtype_name(&self) -> &'static str {
243 <half::f16 as atomr_accel::AccelDtype>::NAME
244 }
245 fn op_name(&self) -> &'static str {
246 "scal"
247 }
248 fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
249 dispatch_scal::<half::f16>(*self, ctx);
250 }
251}
252
253#[cfg(feature = "f16")]
254impl BlasL1Dispatch for ScalRequest<half::bf16> {
255 fn dtype_name(&self) -> &'static str {
256 <half::bf16 as atomr_accel::AccelDtype>::NAME
257 }
258 fn op_name(&self) -> &'static str {
259 "scal"
260 }
261 fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
262 dispatch_scal::<half::bf16>(*self, ctx);
263 }
264}
265
/// Request for the Euclidean norm of `x`, computed through `cublasNrm2Ex`.
///
/// Consumed by `dispatch_nrm2`.
pub struct Nrm2Request<T: AxpyDotNrm2Supported> {
    /// Element count, passed to cuBLAS as `n`.
    pub n: i32,
    /// Input vector (only read).
    pub x: GpuRef<T>,
    /// Stride between consecutive `x` elements, passed to cuBLAS as `incx`.
    pub incx: i32,
    /// Receives the computed norm, or the error that prevented it.
    pub reply: oneshot::Sender<Result<T::Scalar, GpuError>>,
}
276
277fn dispatch_nrm2<T>(req: Nrm2Request<T>, ctx: &BlasDispatchCtx<'_>)
278where
279 T: AxpyDotNrm2Supported,
280 T::Scalar: Default,
281{
282 let Nrm2Request { n, x, incx, reply } = req;
283 let x_slice = match x.access() {
284 Ok(s) => s.clone(),
285 Err(e) => {
286 let _ = reply.send(Err(e));
287 return;
288 }
289 };
290 let cublas = ctx.cublas.clone();
291 let stream = ctx.stream.clone();
292 let stream_for_kernel = ctx.stream.clone();
293 let completion = ctx.completion.clone();
294 let mut result_box = Box::new(T::Scalar::default());
305 let result_ptr = (&mut *result_box) as *mut T::Scalar as *mut core::ffi::c_void;
306
307 let scalar_dt = scalar_data_type::<T>();
308 let exec_dt = T::cuda_data_type();
309
310 let final_reply = reply;
311 let (inner_tx, inner_rx) = oneshot::channel::<Result<(), GpuError>>();
314 envelope::run_kernel(LIB, ctx.stream, ctx.completion, (), inner_tx, move || {
315 let res = {
316 let (x_ptr, _x_rec) = (*x_slice).device_ptr(&stream);
317 unsafe {
321 syscublas::nrm2_ex(
322 *cublas.handle(),
323 n,
324 x_ptr,
325 T::cuda_data_type(),
326 incx,
327 result_ptr,
328 scalar_dt,
329 exec_dt,
330 )
331 }
332 };
333 match res {
334 Ok(()) => Ok((cublas, x_slice)),
335 Err(e) => Err(e),
336 }
337 });
338 let _ = stream_for_kernel; let _ = completion;
340 tokio::spawn(async move {
341 match inner_rx.await {
342 Ok(Ok(())) => {
343 let _ = final_reply.send(Ok(*result_box));
344 }
345 Ok(Err(e)) => {
346 let _ = final_reply.send(Err(e));
347 }
348 Err(_) => {
349 let _ = final_reply.send(Err(GpuError::Timeout));
350 }
351 }
352 });
353}
354
355impl BlasL1Dispatch for Nrm2Request<f32> {
356 fn dtype_name(&self) -> &'static str {
357 <f32 as atomr_accel::AccelDtype>::NAME
358 }
359 fn op_name(&self) -> &'static str {
360 "nrm2"
361 }
362 fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
363 dispatch_nrm2::<f32>(*self, ctx);
364 }
365}
366
367impl BlasL1Dispatch for Nrm2Request<f64> {
368 fn dtype_name(&self) -> &'static str {
369 <f64 as atomr_accel::AccelDtype>::NAME
370 }
371 fn op_name(&self) -> &'static str {
372 "nrm2"
373 }
374 fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
375 dispatch_nrm2::<f64>(*self, ctx);
376 }
377}
378
/// Request for the dot product of `x` and `y`, computed through `cublasDotEx`.
///
/// Consumed by `dispatch_dot`; neither operand is modified.
pub struct DotRequest<T: AxpyDotNrm2Supported> {
    /// Element count, passed to cuBLAS as `n`.
    pub n: i32,
    /// Left operand (only read).
    pub x: GpuRef<T>,
    /// Stride between consecutive `x` elements, passed to cuBLAS as `incx`.
    pub incx: i32,
    /// Right operand (only read).
    pub y: GpuRef<T>,
    /// Stride between consecutive `y` elements, passed to cuBLAS as `incy`.
    pub incy: i32,
    /// Receives the computed product, or the error that prevented it.
    pub reply: oneshot::Sender<Result<T::Scalar, GpuError>>,
}
389
390fn dispatch_dot<T>(req: DotRequest<T>, ctx: &BlasDispatchCtx<'_>)
391where
392 T: AxpyDotNrm2Supported,
393 T::Scalar: Default,
394{
395 let DotRequest {
396 n,
397 x,
398 incx,
399 y,
400 incy,
401 reply,
402 } = req;
403 let (x_slice, y_slice) = match envelope::access_all_2(&x, &y) {
404 Ok(t) => t,
405 Err(e) => {
406 let _ = reply.send(Err(e));
407 return;
408 }
409 };
410
411 let cublas = ctx.cublas.clone();
412 let stream = ctx.stream.clone();
413 let mut result_box = Box::new(T::Scalar::default());
414 let result_ptr = (&mut *result_box) as *mut T::Scalar as *mut core::ffi::c_void;
415 let scalar_dt = scalar_data_type::<T>();
416 let exec_dt = T::cuda_data_type();
417
418 let final_reply = reply;
419 let (inner_tx, inner_rx) = oneshot::channel::<Result<(), GpuError>>();
420 envelope::run_kernel(LIB, ctx.stream, ctx.completion, (), inner_tx, move || {
421 let res = {
422 let (x_ptr, _x_rec) = (*x_slice).device_ptr(&stream);
423 let (y_ptr, _y_rec) = (*y_slice).device_ptr(&stream);
424 unsafe {
425 syscublas::dot_ex(
426 *cublas.handle(),
427 n,
428 x_ptr,
429 T::cuda_data_type(),
430 incx,
431 y_ptr,
432 T::cuda_data_type(),
433 incy,
434 result_ptr,
435 scalar_dt,
436 exec_dt,
437 )
438 }
439 };
440 match res {
441 Ok(()) => Ok((cublas, x_slice, y_slice)),
442 Err(e) => Err(e),
443 }
444 });
445 tokio::spawn(async move {
446 match inner_rx.await {
447 Ok(Ok(())) => {
448 let _ = final_reply.send(Ok(*result_box));
449 }
450 Ok(Err(e)) => {
451 let _ = final_reply.send(Err(e));
452 }
453 Err(_) => {
454 let _ = final_reply.send(Err(GpuError::Timeout));
455 }
456 }
457 });
458}
459
460impl BlasL1Dispatch for DotRequest<f32> {
461 fn dtype_name(&self) -> &'static str {
462 <f32 as atomr_accel::AccelDtype>::NAME
463 }
464 fn op_name(&self) -> &'static str {
465 "dot"
466 }
467 fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
468 dispatch_dot::<f32>(*self, ctx);
469 }
470}
471
472impl BlasL1Dispatch for DotRequest<f64> {
473 fn dtype_name(&self) -> &'static str {
474 <f64 as atomr_accel::AccelDtype>::NAME
475 }
476 fn op_name(&self) -> &'static str {
477 "dot"
478 }
479 fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
480 dispatch_dot::<f64>(*self, ctx);
481 }
482}
483
/// Request for the sum of absolute values of `x`, computed through `cublasAsumEx`.
///
/// Consumed by `dispatch_asum`.
pub struct AsumRequest<T: AxpyDotNrm2Supported> {
    /// Element count, passed to cuBLAS as `n`.
    pub n: i32,
    /// Input vector (only read).
    pub x: GpuRef<T>,
    /// Stride between consecutive `x` elements, passed to cuBLAS as `incx`.
    pub incx: i32,
    /// Receives the computed sum, or the error that prevented it.
    pub reply: oneshot::Sender<Result<T::Scalar, GpuError>>,
}
492
493fn dispatch_asum<T>(req: AsumRequest<T>, ctx: &BlasDispatchCtx<'_>)
494where
495 T: AxpyDotNrm2Supported,
496 T::Scalar: Default,
497{
498 let AsumRequest { n, x, incx, reply } = req;
499 let x_slice = match x.access() {
500 Ok(s) => s.clone(),
501 Err(e) => {
502 let _ = reply.send(Err(e));
503 return;
504 }
505 };
506 let cublas = ctx.cublas.clone();
507 let stream = ctx.stream.clone();
508 let mut result_box = Box::new(T::Scalar::default());
509 let result_ptr = (&mut *result_box) as *mut T::Scalar as *mut core::ffi::c_void;
510 let scalar_dt = scalar_data_type::<T>();
511 let exec_dt = T::cuda_data_type();
512 let final_reply = reply;
513 let (inner_tx, inner_rx) = oneshot::channel::<Result<(), GpuError>>();
514 envelope::run_kernel(LIB, ctx.stream, ctx.completion, (), inner_tx, move || {
515 let res = {
516 let (x_ptr, _x_rec) = (*x_slice).device_ptr(&stream);
517 unsafe {
518 syscublas::asum_ex(
519 *cublas.handle(),
520 n,
521 x_ptr,
522 T::cuda_data_type(),
523 incx,
524 result_ptr,
525 scalar_dt,
526 exec_dt,
527 )
528 }
529 };
530 match res {
531 Ok(()) => Ok((cublas, x_slice)),
532 Err(e) => Err(e),
533 }
534 });
535 tokio::spawn(async move {
536 match inner_rx.await {
537 Ok(Ok(())) => {
538 let _ = final_reply.send(Ok(*result_box));
539 }
540 Ok(Err(e)) => {
541 let _ = final_reply.send(Err(e));
542 }
543 Err(_) => {
544 let _ = final_reply.send(Err(GpuError::Timeout));
545 }
546 }
547 });
548}
549
550impl BlasL1Dispatch for AsumRequest<f32> {
551 fn dtype_name(&self) -> &'static str {
552 <f32 as atomr_accel::AccelDtype>::NAME
553 }
554 fn op_name(&self) -> &'static str {
555 "asum"
556 }
557 fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
558 dispatch_asum::<f32>(*self, ctx);
559 }
560}
561
562impl BlasL1Dispatch for AsumRequest<f64> {
563 fn dtype_name(&self) -> &'static str {
564 <f64 as atomr_accel::AccelDtype>::NAME
565 }
566 fn op_name(&self) -> &'static str {
567 "asum"
568 }
569 fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
570 dispatch_asum::<f64>(*self, ctx);
571 }
572}
573
/// Request for the index of the element of largest magnitude in `x` (`cublasIamaxEx`).
///
/// Consumed by `dispatch_iamax_impl` with `find_min == false`.
pub struct IamaxRequest<T: AxpyDotNrm2Supported> {
    /// Element count, passed to cuBLAS as `n`.
    pub n: i32,
    /// Input vector (only read).
    pub x: GpuRef<T>,
    /// Stride between consecutive `x` elements, passed to cuBLAS as `incx`.
    pub incx: i32,
    /// Receives the index exactly as cuBLAS reports it (1-based, per the cuBLAS I*amax
    /// convention; this module performs no conversion), or the error that prevented it.
    pub reply: oneshot::Sender<Result<i32, GpuError>>,
}

/// Request for the index of the element of smallest magnitude in `x` (`cublasIaminEx`).
///
/// Has the same fields as `IamaxRequest`. Its dispatch adapter repacks it into an
/// `IamaxRequest` and calls `dispatch_iamax_impl` with `find_min == true`.
pub struct IaminRequest<T: AxpyDotNrm2Supported> {
    /// Element count, passed to cuBLAS as `n`.
    pub n: i32,
    /// Input vector (only read).
    pub x: GpuRef<T>,
    /// Stride between consecutive `x` elements, passed to cuBLAS as `incx`.
    pub incx: i32,
    /// Receives the index exactly as cuBLAS reports it (1-based, per the cuBLAS I*amin
    /// convention; this module performs no conversion), or the error that prevented it.
    pub reply: oneshot::Sender<Result<i32, GpuError>>,
}
589
590fn dispatch_iamax_impl<T>(req: IamaxRequest<T>, ctx: &BlasDispatchCtx<'_>, find_min: bool)
591where
592 T: AxpyDotNrm2Supported,
593{
594 let IamaxRequest { n, x, incx, reply } = req;
595 let x_slice = match x.access() {
596 Ok(s) => s.clone(),
597 Err(e) => {
598 let _ = reply.send(Err(e));
599 return;
600 }
601 };
602 let cublas = ctx.cublas.clone();
603 let stream = ctx.stream.clone();
604 let mut result_box = Box::new(0i32);
605 let result_ptr = (&mut *result_box) as *mut i32;
606 let final_reply = reply;
607 let (inner_tx, inner_rx) = oneshot::channel::<Result<(), GpuError>>();
608 envelope::run_kernel(LIB, ctx.stream, ctx.completion, (), inner_tx, move || {
609 let res = {
610 let (x_ptr, _x_rec) = (*x_slice).device_ptr(&stream);
611 if find_min {
612 unsafe {
613 syscublas::iamin_ex(
614 *cublas.handle(),
615 n,
616 x_ptr,
617 T::cuda_data_type(),
618 incx,
619 result_ptr,
620 )
621 }
622 } else {
623 unsafe {
624 syscublas::iamax_ex(
625 *cublas.handle(),
626 n,
627 x_ptr,
628 T::cuda_data_type(),
629 incx,
630 result_ptr,
631 )
632 }
633 }
634 };
635 match res {
636 Ok(()) => Ok((cublas, x_slice)),
637 Err(e) => Err(e),
638 }
639 });
640 tokio::spawn(async move {
641 match inner_rx.await {
642 Ok(Ok(())) => {
643 let _ = final_reply.send(Ok(*result_box));
644 }
645 Ok(Err(e)) => {
646 let _ = final_reply.send(Err(e));
647 }
648 Err(_) => {
649 let _ = final_reply.send(Err(GpuError::Timeout));
650 }
651 }
652 });
653}
654
655impl BlasL1Dispatch for IamaxRequest<f32> {
656 fn dtype_name(&self) -> &'static str {
657 <f32 as atomr_accel::AccelDtype>::NAME
658 }
659 fn op_name(&self) -> &'static str {
660 "iamax"
661 }
662 fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
663 dispatch_iamax_impl::<f32>(*self, ctx, false);
664 }
665}
666
667impl BlasL1Dispatch for IamaxRequest<f64> {
668 fn dtype_name(&self) -> &'static str {
669 <f64 as atomr_accel::AccelDtype>::NAME
670 }
671 fn op_name(&self) -> &'static str {
672 "iamax"
673 }
674 fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
675 dispatch_iamax_impl::<f64>(*self, ctx, false);
676 }
677}
678
679impl BlasL1Dispatch for IaminRequest<f32> {
680 fn dtype_name(&self) -> &'static str {
681 <f32 as atomr_accel::AccelDtype>::NAME
682 }
683 fn op_name(&self) -> &'static str {
684 "iamin"
685 }
686 fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
687 let IaminRequest { n, x, incx, reply } = *self;
690 let req = IamaxRequest::<f32> { n, x, incx, reply };
691 dispatch_iamax_impl::<f32>(req, ctx, true);
692 }
693}
694
695impl BlasL1Dispatch for IaminRequest<f64> {
696 fn dtype_name(&self) -> &'static str {
697 <f64 as atomr_accel::AccelDtype>::NAME
698 }
699 fn op_name(&self) -> &'static str {
700 "iamin"
701 }
702 fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
703 let IaminRequest { n, x, incx, reply } = *self;
704 let req = IamaxRequest::<f64> { n, x, incx, reply };
705 dispatch_iamax_impl::<f64>(req, ctx, true);
706 }
707}
708
/// Request to copy `x` into `y` through `cublasCopyEx`.
///
/// Consumed by `dispatch_copy`; `y` is overwritten.
pub struct CopyRequest<T: AxpyDotNrm2Supported> {
    /// Element count, passed to cuBLAS as `n`.
    pub n: i32,
    /// Source vector (only read).
    pub x: GpuRef<T>,
    /// Stride between consecutive `x` elements, passed to cuBLAS as `incx`.
    pub incx: i32,
    /// Destination vector.
    pub y: GpuRef<T>,
    /// Stride between consecutive `y` elements, passed to cuBLAS as `incy`.
    pub incy: i32,
    /// Receives errors detected by `dispatch_copy` itself; otherwise it is handed to
    /// `envelope::run_kernel`, which reports the kernel outcome on it.
    pub reply: oneshot::Sender<Result<(), GpuError>>,
}
719
720fn dispatch_copy<T>(req: CopyRequest<T>, ctx: &BlasDispatchCtx<'_>)
721where
722 T: AxpyDotNrm2Supported,
723{
724 let CopyRequest {
725 n,
726 x,
727 incx,
728 y,
729 incy,
730 reply,
731 } = req;
732 let (x_slice, y_slice) = match envelope::access_all_2(&x, &y) {
733 Ok(t) => t,
734 Err(e) => {
735 let _ = reply.send(Err(e));
736 return;
737 }
738 };
739 let mut y_owned = match Arc::try_unwrap(y_slice) {
740 Ok(s) => s,
741 Err(_) => {
742 let _ = reply.send(Err(GpuError::Unrecoverable(
743 "COPY target buffer Y has more than one live reference".into(),
744 )));
745 return;
746 }
747 };
748 y.record_write(ctx.stream);
749 let cublas = ctx.cublas.clone();
750 let stream = ctx.stream.clone();
751 envelope::run_kernel(LIB, ctx.stream, ctx.completion, (), reply, move || {
752 let res = {
753 let (x_ptr, _x_rec) = (*x_slice).device_ptr(&stream);
754 let (y_ptr, _y_rec) = y_owned.device_ptr_mut(&stream);
755 unsafe {
756 syscublas::copy_ex(
757 *cublas.handle(),
758 n,
759 x_ptr,
760 T::cuda_data_type(),
761 incx,
762 y_ptr,
763 T::cuda_data_type(),
764 incy,
765 )
766 }
767 };
768 match res {
769 Ok(()) => Ok((cublas, x_slice, y_owned)),
770 Err(e) => Err(e),
771 }
772 });
773}
774
775impl BlasL1Dispatch for CopyRequest<f32> {
776 fn dtype_name(&self) -> &'static str {
777 <f32 as atomr_accel::AccelDtype>::NAME
778 }
779 fn op_name(&self) -> &'static str {
780 "copy"
781 }
782 fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
783 dispatch_copy::<f32>(*self, ctx);
784 }
785}
786
787impl BlasL1Dispatch for CopyRequest<f64> {
788 fn dtype_name(&self) -> &'static str {
789 <f64 as atomr_accel::AccelDtype>::NAME
790 }
791 fn op_name(&self) -> &'static str {
792 "copy"
793 }
794 fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
795 dispatch_copy::<f64>(*self, ctx);
796 }
797}
798
/// Request to exchange the contents of `x` and `y` through `cublasSwapEx`.
///
/// Consumed by `dispatch_swap`; both vectors are written.
pub struct SwapRequest<T: AxpyDotNrm2Supported> {
    /// Element count, passed to cuBLAS as `n`.
    pub n: i32,
    /// First vector.
    pub x: GpuRef<T>,
    /// Stride between consecutive `x` elements, passed to cuBLAS as `incx`.
    pub incx: i32,
    /// Second vector.
    pub y: GpuRef<T>,
    /// Stride between consecutive `y` elements, passed to cuBLAS as `incy`.
    pub incy: i32,
    /// Receives errors detected by `dispatch_swap` itself; otherwise it is handed to
    /// `envelope::run_kernel`, which reports the kernel outcome on it.
    pub reply: oneshot::Sender<Result<(), GpuError>>,
}
809
810fn dispatch_swap<T>(req: SwapRequest<T>, ctx: &BlasDispatchCtx<'_>)
811where
812 T: AxpyDotNrm2Supported,
813{
814 let SwapRequest {
815 n,
816 x,
817 incx,
818 y,
819 incy,
820 reply,
821 } = req;
822 let x_slice = match x.access() {
823 Ok(s) => s.clone(),
824 Err(e) => {
825 let _ = reply.send(Err(e));
826 return;
827 }
828 };
829 let y_slice = match y.access() {
830 Ok(s) => s.clone(),
831 Err(e) => {
832 let _ = reply.send(Err(e));
833 return;
834 }
835 };
836 let mut x_owned = match Arc::try_unwrap(x_slice) {
837 Ok(s) => s,
838 Err(_) => {
839 let _ = reply.send(Err(GpuError::Unrecoverable(
840 "SWAP buffer X has more than one live reference".into(),
841 )));
842 return;
843 }
844 };
845 let mut y_owned = match Arc::try_unwrap(y_slice) {
846 Ok(s) => s,
847 Err(_) => {
848 let _ = reply.send(Err(GpuError::Unrecoverable(
849 "SWAP buffer Y has more than one live reference".into(),
850 )));
851 return;
852 }
853 };
854 x.record_write(ctx.stream);
855 y.record_write(ctx.stream);
856 let cublas = ctx.cublas.clone();
857 let stream = ctx.stream.clone();
858 envelope::run_kernel(LIB, ctx.stream, ctx.completion, (), reply, move || {
859 let res = {
860 let (x_ptr, _x_rec) = x_owned.device_ptr_mut(&stream);
861 let (y_ptr, _y_rec) = y_owned.device_ptr_mut(&stream);
862 unsafe {
863 syscublas::swap_ex(
864 *cublas.handle(),
865 n,
866 x_ptr,
867 T::cuda_data_type(),
868 incx,
869 y_ptr,
870 T::cuda_data_type(),
871 incy,
872 )
873 }
874 };
875 match res {
876 Ok(()) => Ok((cublas, x_owned, y_owned)),
877 Err(e) => Err(e),
878 }
879 });
880}
881
882impl BlasL1Dispatch for SwapRequest<f32> {
883 fn dtype_name(&self) -> &'static str {
884 <f32 as atomr_accel::AccelDtype>::NAME
885 }
886 fn op_name(&self) -> &'static str {
887 "swap"
888 }
889 fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
890 dispatch_swap::<f32>(*self, ctx);
891 }
892}
893
894impl BlasL1Dispatch for SwapRequest<f64> {
895 fn dtype_name(&self) -> &'static str {
896 <f64 as atomr_accel::AccelDtype>::NAME
897 }
898 fn op_name(&self) -> &'static str {
899 "swap"
900 }
901 fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
902 dispatch_swap::<f64>(*self, ctx);
903 }
904}
905
/// Request to apply a plane rotation to the vector pair (`x`, `y`) through `cublasRotEx`.
///
/// Consumed by `dispatch_rot`; both vectors are written.
pub struct RotRequest<T: AxpyDotNrm2Supported> {
    /// Element count, passed to cuBLAS as `n`.
    pub n: i32,
    /// First vector, updated in place.
    pub x: GpuRef<T>,
    /// Stride between consecutive `x` elements, passed to cuBLAS as `incx`.
    pub incx: i32,
    /// Second vector, updated in place.
    pub y: GpuRef<T>,
    /// Stride between consecutive `y` elements, passed to cuBLAS as `incy`.
    pub incy: i32,
    /// Passed by host address to cuBLAS as `c` (the rotation's cosine term in cuBLAS ROT).
    pub c: T::Scalar,
    /// Passed by host address to cuBLAS as `s` (the rotation's sine term in cuBLAS ROT).
    pub s: T::Scalar,
    /// Receives errors detected by `dispatch_rot` itself; otherwise it is handed to
    /// `envelope::run_kernel`, which reports the kernel outcome on it.
    pub reply: oneshot::Sender<Result<(), GpuError>>,
}
923
924fn dispatch_rot<T>(req: RotRequest<T>, ctx: &BlasDispatchCtx<'_>)
925where
926 T: AxpyDotNrm2Supported,
927{
928 let RotRequest {
929 n,
930 x,
931 incx,
932 y,
933 incy,
934 c,
935 s,
936 reply,
937 } = req;
938 let x_slice = match x.access() {
939 Ok(s) => s.clone(),
940 Err(e) => {
941 let _ = reply.send(Err(e));
942 return;
943 }
944 };
945 let y_slice = match y.access() {
946 Ok(s) => s.clone(),
947 Err(e) => {
948 let _ = reply.send(Err(e));
949 return;
950 }
951 };
952 let mut x_owned = match Arc::try_unwrap(x_slice) {
953 Ok(s) => s,
954 Err(_) => {
955 let _ = reply.send(Err(GpuError::Unrecoverable(
956 "ROT buffer X has more than one live reference".into(),
957 )));
958 return;
959 }
960 };
961 let mut y_owned = match Arc::try_unwrap(y_slice) {
962 Ok(s) => s,
963 Err(_) => {
964 let _ = reply.send(Err(GpuError::Unrecoverable(
965 "ROT buffer Y has more than one live reference".into(),
966 )));
967 return;
968 }
969 };
970 x.record_write(ctx.stream);
971 y.record_write(ctx.stream);
972 let cublas = ctx.cublas.clone();
973 let stream = ctx.stream.clone();
974 let scalar_dt = scalar_data_type::<T>();
975 let exec_dt = T::cuda_data_type();
976 envelope::run_kernel(LIB, ctx.stream, ctx.completion, (), reply, move || {
977 let res = {
978 let (x_ptr, _x_rec) = x_owned.device_ptr_mut(&stream);
979 let (y_ptr, _y_rec) = y_owned.device_ptr_mut(&stream);
980 unsafe {
981 syscublas::rot_ex(
982 *cublas.handle(),
983 n,
984 x_ptr,
985 T::cuda_data_type(),
986 incx,
987 y_ptr,
988 T::cuda_data_type(),
989 incy,
990 (&c) as *const T::Scalar as *const _,
991 (&s) as *const T::Scalar as *const _,
992 scalar_dt,
993 exec_dt,
994 )
995 }
996 };
997 match res {
998 Ok(()) => Ok((cublas, x_owned, y_owned, c, s)),
999 Err(e) => Err(e),
1000 }
1001 });
1002}
1003
1004impl BlasL1Dispatch for RotRequest<f32> {
1005 fn dtype_name(&self) -> &'static str {
1006 <f32 as atomr_accel::AccelDtype>::NAME
1007 }
1008 fn op_name(&self) -> &'static str {
1009 "rot"
1010 }
1011 fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
1012 dispatch_rot::<f32>(*self, ctx);
1013 }
1014}
1015
1016impl BlasL1Dispatch for RotRequest<f64> {
1017 fn dtype_name(&self) -> &'static str {
1018 <f64 as atomr_accel::AccelDtype>::NAME
1019 }
1020 fn op_name(&self) -> &'static str {
1021 "rot"
1022 }
1023 fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
1024 dispatch_rot::<f64>(*self, ctx);
1025 }
1026}
1027
1028fn scalar_data_type<T: CudaDtype>() -> cudarc::cublas::sys::cudaDataType_t {
1034 use core::any::TypeId;
1035 if TypeId::of::<T::Scalar>() == TypeId::of::<f32>() {
1036 cudarc::cublas::sys::cudaDataType_t::CUDA_R_32F
1037 } else if TypeId::of::<T::Scalar>() == TypeId::of::<f64>() {
1038 cudarc::cublas::sys::cudaDataType_t::CUDA_R_64F
1039 } else {
1040 panic!(
1043 "Unrecoverable: scalar type for {} is not f32/f64",
1044 <T as atomr_accel::AccelDtype>::NAME
1045 );
1046 }
1047}
1048
// Metadata-only tests: each one builds requests from stub `GpuRef`s, erases them behind
// `Box<dyn BlasL1Dispatch>`, and checks `op_name` / `dtype_name`. No `dispatch` call is
// made, so no cuBLAS handle or stream is involved.
//
// Each boxed request goes to `Box::leak` instead of being dropped. NOTE(review): presumably
// this avoids running `Drop` on the stub `GpuRef`s; confirm against `gpu_ref_stub`.
#[cfg(test)]
mod tests {
    use super::super::gemm::tests_helpers::gpu_ref_stub;
    use super::*;
    use tokio::sync::oneshot;

    /// AXPY reports op "axpy" and the f32/f64 dtype names.
    #[test]
    fn axpy_request_round_trip() {
        let (tx, _rx) = oneshot::channel();
        let req = AxpyRequest::<f32> {
            n: 8,
            alpha: 1.0,
            x: gpu_ref_stub::<f32>(),
            incx: 1,
            y: gpu_ref_stub::<f32>(),
            incy: 1,
            reply: tx,
        };
        let boxed: Box<dyn BlasL1Dispatch> = Box::new(req);
        assert_eq!(boxed.op_name(), "axpy");
        assert_eq!(boxed.dtype_name(), "f32");
        Box::leak(boxed);

        let (tx, _rx) = oneshot::channel();
        let req = AxpyRequest::<f64> {
            n: 8,
            alpha: 1.0,
            x: gpu_ref_stub::<f64>(),
            incx: 1,
            y: gpu_ref_stub::<f64>(),
            incy: 1,
            reply: tx,
        };
        let boxed: Box<dyn BlasL1Dispatch> = Box::new(req);
        assert_eq!(boxed.dtype_name(), "f64");
        Box::leak(boxed);
    }

    /// SCAL reports op "scal".
    #[test]
    fn scal_request_round_trip() {
        let (tx, _rx) = oneshot::channel();
        let req = ScalRequest::<f32> {
            n: 4,
            alpha: 2.0,
            x: gpu_ref_stub::<f32>(),
            incx: 1,
            reply: tx,
        };
        let boxed: Box<dyn BlasL1Dispatch> = Box::new(req);
        assert_eq!(boxed.op_name(), "scal");
        Box::leak(boxed);
    }

    /// The reduction-style requests report their op names. Despite the test name, ASUM is
    /// not exercised here; IAMIN is, and it is checked to report "iamin" even though its
    /// dispatch reuses the IAMAX implementation.
    #[test]
    fn dot_nrm2_asum_iamax_request_round_trip() {
        let (tx, _rx) = oneshot::channel();
        let req = DotRequest::<f32> {
            n: 4,
            x: gpu_ref_stub::<f32>(),
            incx: 1,
            y: gpu_ref_stub::<f32>(),
            incy: 1,
            reply: tx,
        };
        let boxed: Box<dyn BlasL1Dispatch> = Box::new(req);
        assert_eq!(boxed.op_name(), "dot");
        Box::leak(boxed);

        let (tx, _rx) = oneshot::channel();
        let req = Nrm2Request::<f32> {
            n: 4,
            x: gpu_ref_stub::<f32>(),
            incx: 1,
            reply: tx,
        };
        let boxed: Box<dyn BlasL1Dispatch> = Box::new(req);
        assert_eq!(boxed.op_name(), "nrm2");
        Box::leak(boxed);

        let (tx, _rx) = oneshot::channel();
        let req = IamaxRequest::<f32> {
            n: 4,
            x: gpu_ref_stub::<f32>(),
            incx: 1,
            reply: tx,
        };
        let boxed: Box<dyn BlasL1Dispatch> = Box::new(req);
        assert_eq!(boxed.op_name(), "iamax");
        Box::leak(boxed);

        let (tx, _rx) = oneshot::channel();
        let req = IaminRequest::<f32> {
            n: 4,
            x: gpu_ref_stub::<f32>(),
            incx: 1,
            reply: tx,
        };
        let boxed: Box<dyn BlasL1Dispatch> = Box::new(req);
        assert_eq!(boxed.op_name(), "iamin");
        Box::leak(boxed);
    }
}