Skip to main content

atomr_accel_cuda/kernel/blas/
l2.rs

//! Typed L2 ops: gemv, ger.
//!
//! cudarc 0.19's safe-layer `Gemv<T>` trait covers f32/f64 (its `gemv`
//! method must still be called from an `unsafe` block). For `ger` we drop
//! to the local sys-level [`crate::sys::cublas::sger`] /
//! [`crate::sys::cublas::dger`] wrappers, since cudarc has no safe-layer
//! `Ger<T>` trait.
7
8use std::sync::Arc;
9
10use cudarc::cublas::sys::cublasOperation_t;
11use cudarc::cublas::{Gemv, GemvConfig};
12use cudarc::driver::{DevicePtr, DevicePtrMut};
13use tokio::sync::oneshot;
14
15use crate::dtype::{GemvSupported, GerSupported};
16use crate::error::GpuError;
17use crate::gpu_ref::GpuRef;
18use crate::kernel::dispatch::{BlasDispatchCtx, BlasL2Dispatch};
19use crate::kernel::envelope;
20use crate::sys::cublas as syscublas;
21
/// Library tag for this module: passed as the first argument to
/// `envelope::run_kernel` and stored in the `lib` field of
/// `GpuError::LibraryError` when a gemv enqueue fails.
const LIB: &str = "cublas";
23
24// ─────────────────────── GEMV ───────────────────────
25
/// One queued GEMV call, dispatched through `BlasL2Dispatch`.
///
/// The scalar and dimension fields are copied unchanged into cudarc's
/// `GemvConfig`; `a` and `x` are handed to `gemv` as read-only inputs and
/// `y` as the mutable output. In cuBLAS terms this is presumably
/// `y := alpha * op(A) * x + beta * y` — confirm against the cuBLAS
/// reference. None of the dimensions are validated in this file.
pub struct GemvRequest<T: GemvSupported> {
    /// Transpose mode, forwarded as `GemvConfig::trans`.
    pub trans: cublasOperation_t,
    /// Forwarded as `GemvConfig::m` (cuBLAS: number of rows of `A`).
    pub m: i32,
    /// Forwarded as `GemvConfig::n` (cuBLAS: number of columns of `A`).
    pub n: i32,
    /// Scalar forwarded as `GemvConfig::alpha`.
    pub alpha: T,
    /// Scalar forwarded as `GemvConfig::beta`.
    pub beta: T,
    /// Matrix operand; only read by the gemv call.
    pub a: GpuRef<T>,
    /// Forwarded as `GemvConfig::lda` (cuBLAS: leading dimension of `A`).
    pub lda: i32,
    /// Input vector; only read by the gemv call.
    pub x: GpuRef<T>,
    /// Forwarded as `GemvConfig::incx` (cuBLAS: stride between elements of `x`).
    pub incx: i32,
    /// Output vector. Dispatch requires this to be the only live reference
    /// to the underlying buffer and records a write on it.
    pub y: GpuRef<T>,
    /// Forwarded as `GemvConfig::incy` (cuBLAS: stride between elements of `y`).
    pub incy: i32,
    /// Receives `Err` immediately if buffer access or exclusive ownership of
    /// `y` fails; otherwise it is handed to `envelope::run_kernel`, which
    /// presumably resolves it once the kernel finishes — confirm there.
    pub reply: oneshot::Sender<Result<(), GpuError>>,
}
40
41fn dispatch_gemv<T>(req: GemvRequest<T>, ctx: &BlasDispatchCtx<'_>)
42where
43    T: GemvSupported + Copy,
44    cudarc::cublas::CudaBlas: Gemv<T>,
45{
46    let GemvRequest {
47        trans,
48        m,
49        n,
50        alpha,
51        beta,
52        a,
53        lda,
54        x,
55        incx,
56        y,
57        incy,
58        reply,
59    } = req;
60
61    let (a_slice, x_slice, y_slice) = match envelope::access_all_3(&a, &x, &y) {
62        Ok(t) => t,
63        Err(e) => {
64            let _ = reply.send(Err(e));
65            return;
66        }
67    };
68
69    let mut y_owned = match Arc::try_unwrap(y_slice) {
70        Ok(s) => s,
71        Err(_) => {
72            let _ = reply.send(Err(GpuError::Unrecoverable(
73                "GEMV target buffer Y has more than one live reference".into(),
74            )));
75            return;
76        }
77    };
78
79    y.record_write(ctx.stream);
80
81    let cfg = GemvConfig::<T> {
82        trans,
83        m,
84        n,
85        alpha,
86        lda,
87        incx,
88        beta,
89        incy,
90    };
91
92    let cublas = ctx.cublas.clone();
93    envelope::run_kernel(LIB, ctx.stream, ctx.completion, (), reply, move || {
94        let res = unsafe { cublas.gemv(cfg, &*a_slice, &*x_slice, &mut y_owned) };
95        match res {
96            Ok(()) => Ok((cublas, a_slice, x_slice, y_owned)),
97            Err(e) => Err(GpuError::LibraryError {
98                lib: LIB,
99                msg: format!("gemv enqueue: {e}"),
100            }),
101        }
102    });
103}
104
105impl BlasL2Dispatch for GemvRequest<f32> {
106    fn dtype_name(&self) -> &'static str {
107        <f32 as atomr_accel::AccelDtype>::NAME
108    }
109    fn op_name(&self) -> &'static str {
110        "gemv"
111    }
112    fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
113        dispatch_gemv::<f32>(*self, ctx);
114    }
115}
116
117impl BlasL2Dispatch for GemvRequest<f64> {
118    fn dtype_name(&self) -> &'static str {
119        <f64 as atomr_accel::AccelDtype>::NAME
120    }
121    fn op_name(&self) -> &'static str {
122        "gemv"
123    }
124    fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
125        dispatch_gemv::<f64>(*self, ctx);
126    }
127}
128
129// ─────────────────────── GER ───────────────────────
130
/// Rank-1 update: `A := α·x·y^T + A`.
///
/// One queued GER call, dispatched through `BlasL2Dispatch`. The dimension
/// and stride fields are passed unchanged to the sys-level `sger` / `dger`
/// wrapper; none of them are validated in this file.
pub struct GerRequest<T: GerSupported> {
    /// Passed as the wrapper's `m` argument (cuBLAS: number of rows of `A`).
    pub m: i32,
    /// Passed as the wrapper's `n` argument (cuBLAS: number of columns of `A`).
    pub n: i32,
    /// Scalar `α`; dispatch passes its host address to the wrapper.
    pub alpha: T,
    /// Input vector `x`; accessed through a read-only device pointer.
    pub x: GpuRef<T>,
    /// Passed as `incx` (cuBLAS: stride between elements of `x`).
    pub incx: i32,
    /// Input vector `y`; accessed through a read-only device pointer.
    pub y: GpuRef<T>,
    /// Passed as `incy` (cuBLAS: stride between elements of `y`).
    pub incy: i32,
    /// Matrix updated in place. Dispatch requires this to be the only live
    /// reference to the underlying buffer and records a write on it.
    pub a: GpuRef<T>,
    /// Passed as `lda` (cuBLAS: leading dimension of `A`).
    pub lda: i32,
    /// Receives `Err` immediately if buffer access or exclusive ownership of
    /// `a` fails; otherwise it is handed to `envelope::run_kernel`, which
    /// presumably resolves it once the kernel finishes — confirm there.
    pub reply: oneshot::Sender<Result<(), GpuError>>,
}
144
/// Per-dtype selector for the sys-level GER wrapper.
///
/// Lets the generic `dispatch_ger` pick `syscublas::sger` or
/// `syscublas::dger` from the element type alone.
trait GerCall: GerSupported {
    /// Call the right `cublasSger_v2` / `cublasDger_v2` wrapper.
    ///
    /// `x`, `y` and `a` are raw device pointers; `alpha` is a pointer to the
    /// scalar (`dispatch_ger` passes the address of a host-side value).
    /// All other arguments are forwarded to the wrapper unchanged.
    ///
    /// # Safety
    /// All pointers must be valid for the encoded sizes.
    unsafe fn call(
        handle: cudarc::cublas::sys::cublasHandle_t,
        m: i32,
        n: i32,
        alpha: *const Self,
        x: cudarc::driver::sys::CUdeviceptr,
        incx: i32,
        y: cudarc::driver::sys::CUdeviceptr,
        incy: i32,
        a: cudarc::driver::sys::CUdeviceptr,
        lda: i32,
    ) -> Result<(), GpuError>;
}
163
164impl GerCall for f32 {
165    unsafe fn call(
166        handle: cudarc::cublas::sys::cublasHandle_t,
167        m: i32,
168        n: i32,
169        alpha: *const Self,
170        x: cudarc::driver::sys::CUdeviceptr,
171        incx: i32,
172        y: cudarc::driver::sys::CUdeviceptr,
173        incy: i32,
174        a: cudarc::driver::sys::CUdeviceptr,
175        lda: i32,
176    ) -> Result<(), GpuError> {
177        syscublas::sger(handle, m, n, alpha, x, incx, y, incy, a, lda)
178    }
179}
180
181impl GerCall for f64 {
182    unsafe fn call(
183        handle: cudarc::cublas::sys::cublasHandle_t,
184        m: i32,
185        n: i32,
186        alpha: *const Self,
187        x: cudarc::driver::sys::CUdeviceptr,
188        incx: i32,
189        y: cudarc::driver::sys::CUdeviceptr,
190        incy: i32,
191        a: cudarc::driver::sys::CUdeviceptr,
192        lda: i32,
193    ) -> Result<(), GpuError> {
194        syscublas::dger(handle, m, n, alpha, x, incx, y, incy, a, lda)
195    }
196}
197
198fn dispatch_ger<T>(req: GerRequest<T>, ctx: &BlasDispatchCtx<'_>)
199where
200    T: GerSupported + GerCall + Copy,
201{
202    let GerRequest {
203        m,
204        n,
205        alpha,
206        x,
207        incx,
208        y,
209        incy,
210        a,
211        lda,
212        reply,
213    } = req;
214    let (a_slice, x_slice, y_slice) = match envelope::access_all_3(&a, &x, &y) {
215        Ok(t) => t,
216        Err(e) => {
217            let _ = reply.send(Err(e));
218            return;
219        }
220    };
221    let mut a_owned = match Arc::try_unwrap(a_slice) {
222        Ok(s) => s,
223        Err(_) => {
224            let _ = reply.send(Err(GpuError::Unrecoverable(
225                "GER target matrix A has more than one live reference".into(),
226            )));
227            return;
228        }
229    };
230    a.record_write(ctx.stream);
231    let cublas = ctx.cublas.clone();
232    let stream = ctx.stream.clone();
233    envelope::run_kernel(LIB, ctx.stream, ctx.completion, (), reply, move || {
234        let res = {
235            let (x_ptr, _x_rec) = (*x_slice).device_ptr(&stream);
236            let (y_ptr, _y_rec) = (*y_slice).device_ptr(&stream);
237            let (a_ptr, _a_rec) = a_owned.device_ptr_mut(&stream);
238            unsafe {
239                T::call(
240                    *cublas.handle(),
241                    m,
242                    n,
243                    (&alpha) as *const T,
244                    x_ptr,
245                    incx,
246                    y_ptr,
247                    incy,
248                    a_ptr,
249                    lda,
250                )
251            }
252        };
253        match res {
254            Ok(()) => Ok((cublas, x_slice, y_slice, a_owned)),
255            Err(e) => Err(e),
256        }
257    });
258}
259
260impl BlasL2Dispatch for GerRequest<f32> {
261    fn dtype_name(&self) -> &'static str {
262        <f32 as atomr_accel::AccelDtype>::NAME
263    }
264    fn op_name(&self) -> &'static str {
265        "ger"
266    }
267    fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
268        dispatch_ger::<f32>(*self, ctx);
269    }
270}
271
272impl BlasL2Dispatch for GerRequest<f64> {
273    fn dtype_name(&self) -> &'static str {
274        <f64 as atomr_accel::AccelDtype>::NAME
275    }
276    fn op_name(&self) -> &'static str {
277        "ger"
278    }
279    fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
280        dispatch_ger::<f64>(*self, ctx);
281    }
282}
283
/// Metadata-only tests: each request type is boxed as `dyn BlasL2Dispatch`
/// and its `op_name` / `dtype_name` are checked. `dispatch` is never called,
/// so no GPU is needed.
#[cfg(test)]
mod tests {
    use super::super::gemm::tests_helpers::gpu_ref_stub;
    use super::*;
    use tokio::sync::oneshot;

    // NOTE(review): every boxed request is leaked rather than dropped —
    // presumably because the stub `GpuRef`s must not run their destructor.
    // Confirm against `gpu_ref_stub` before changing this.

    #[test]
    fn gemv_request_round_trip() {
        let (tx, _rx) = oneshot::channel();
        let req = GemvRequest::<f32> {
            trans: cublasOperation_t::CUBLAS_OP_N,
            m: 4,
            n: 4,
            alpha: 1.0,
            beta: 0.0,
            a: gpu_ref_stub::<f32>(),
            lda: 4,
            x: gpu_ref_stub::<f32>(),
            incx: 1,
            y: gpu_ref_stub::<f32>(),
            incy: 1,
            reply: tx,
        };
        let boxed: Box<dyn BlasL2Dispatch> = Box::new(req);
        assert_eq!(boxed.op_name(), "gemv");
        assert_eq!(boxed.dtype_name(), "f32");
        Box::leak(boxed);

        let (tx, _rx) = oneshot::channel();
        let req = GemvRequest::<f64> {
            trans: cublasOperation_t::CUBLAS_OP_N,
            m: 4,
            n: 4,
            alpha: 1.0,
            beta: 0.0,
            a: gpu_ref_stub::<f64>(),
            lda: 4,
            x: gpu_ref_stub::<f64>(),
            incx: 1,
            y: gpu_ref_stub::<f64>(),
            incy: 1,
            reply: tx,
        };
        let boxed: Box<dyn BlasL2Dispatch> = Box::new(req);
        assert_eq!(boxed.op_name(), "gemv");
        assert_eq!(boxed.dtype_name(), "f64");
        Box::leak(boxed);
    }

    #[test]
    fn ger_request_round_trip() {
        let (tx, _rx) = oneshot::channel();
        let req = GerRequest::<f32> {
            m: 4,
            n: 4,
            alpha: 1.0,
            x: gpu_ref_stub::<f32>(),
            incx: 1,
            y: gpu_ref_stub::<f32>(),
            incy: 1,
            a: gpu_ref_stub::<f32>(),
            lda: 4,
            reply: tx,
        };
        let boxed: Box<dyn BlasL2Dispatch> = Box::new(req);
        assert_eq!(boxed.op_name(), "ger");
        assert_eq!(boxed.dtype_name(), "f32");
        Box::leak(boxed);

        // Mirror the gemv test: the f64 impl is a separate hand-written block
        // and needs its own coverage.
        let (tx, _rx) = oneshot::channel();
        let req = GerRequest::<f64> {
            m: 4,
            n: 4,
            alpha: 1.0,
            x: gpu_ref_stub::<f64>(),
            incx: 1,
            y: gpu_ref_stub::<f64>(),
            incy: 1,
            a: gpu_ref_stub::<f64>(),
            lda: 4,
            reply: tx,
        };
        let boxed: Box<dyn BlasL2Dispatch> = Box::new(req);
        assert_eq!(boxed.op_name(), "ger");
        assert_eq!(boxed.dtype_name(), "f64");
        Box::leak(boxed);
    }
}