1use std::sync::Arc;
9
10use cudarc::cublas::sys::cublasOperation_t;
11use cudarc::cublas::{Gemv, GemvConfig};
12use cudarc::driver::{DevicePtr, DevicePtrMut};
13use tokio::sync::oneshot;
14
15use crate::dtype::{GemvSupported, GerSupported};
16use crate::error::GpuError;
17use crate::gpu_ref::GpuRef;
18use crate::kernel::dispatch::{BlasDispatchCtx, BlasL2Dispatch};
19use crate::kernel::envelope;
20use crate::sys::cublas as syscublas;
21
/// Library tag used by this module: passed as the first argument to
/// `envelope::run_kernel` and stored in the `lib` field of
/// `GpuError::LibraryError` values built here.
const LIB: &str = "cublas";
23
/// A GEMV operation waiting to be dispatched, plus the channel its outcome is
/// reported on.
///
/// The scalar and stride fields are copied unchanged into a cudarc
/// `GemvConfig` by `dispatch_gemv`; no dimension, stride, or buffer-length
/// validation happens in this module. Field meanings below follow the cuBLAS
/// `gemv` convention (`y = alpha * op(A) * x + beta * y`).
pub struct GemvRequest<T: GemvSupported> {
    /// Operation applied to `a` (the `op` in `op(A)`).
    pub trans: cublasOperation_t,
    /// Number of rows of `a`, per the cuBLAS convention.
    pub m: i32,
    /// Number of columns of `a`, per the cuBLAS convention.
    pub n: i32,
    /// Scalar multiplying `op(A) * x`.
    pub alpha: T,
    /// Scalar multiplying the previous contents of `y`.
    pub beta: T,
    /// Handle to the input matrix buffer.
    pub a: GpuRef<T>,
    /// Leading dimension of `a`.
    pub lda: i32,
    /// Handle to the input vector buffer.
    pub x: GpuRef<T>,
    /// Stride between consecutive elements of `x`.
    pub incx: i32,
    /// Handle to the output vector buffer. Dispatch fails if the underlying
    /// buffer has more than one live reference, and a write is recorded on
    /// this handle before the kernel is submitted.
    pub y: GpuRef<T>,
    /// Stride between consecutive elements of `y`.
    pub incy: i32,
    /// Receives the result of the request. Early failures are sent directly by
    /// `dispatch_gemv`; otherwise the sender is handed to
    /// `envelope::run_kernel`.
    pub reply: oneshot::Sender<Result<(), GpuError>>,
}
40
41fn dispatch_gemv<T>(req: GemvRequest<T>, ctx: &BlasDispatchCtx<'_>)
42where
43 T: GemvSupported + Copy,
44 cudarc::cublas::CudaBlas: Gemv<T>,
45{
46 let GemvRequest {
47 trans,
48 m,
49 n,
50 alpha,
51 beta,
52 a,
53 lda,
54 x,
55 incx,
56 y,
57 incy,
58 reply,
59 } = req;
60
61 let (a_slice, x_slice, y_slice) = match envelope::access_all_3(&a, &x, &y) {
62 Ok(t) => t,
63 Err(e) => {
64 let _ = reply.send(Err(e));
65 return;
66 }
67 };
68
69 let mut y_owned = match Arc::try_unwrap(y_slice) {
70 Ok(s) => s,
71 Err(_) => {
72 let _ = reply.send(Err(GpuError::Unrecoverable(
73 "GEMV target buffer Y has more than one live reference".into(),
74 )));
75 return;
76 }
77 };
78
79 y.record_write(ctx.stream);
80
81 let cfg = GemvConfig::<T> {
82 trans,
83 m,
84 n,
85 alpha,
86 lda,
87 incx,
88 beta,
89 incy,
90 };
91
92 let cublas = ctx.cublas.clone();
93 envelope::run_kernel(LIB, ctx.stream, ctx.completion, (), reply, move || {
94 let res = unsafe { cublas.gemv(cfg, &*a_slice, &*x_slice, &mut y_owned) };
95 match res {
96 Ok(()) => Ok((cublas, a_slice, x_slice, y_owned)),
97 Err(e) => Err(GpuError::LibraryError {
98 lib: LIB,
99 msg: format!("gemv enqueue: {e}"),
100 }),
101 }
102 });
103}
104
105impl BlasL2Dispatch for GemvRequest<f32> {
106 fn dtype_name(&self) -> &'static str {
107 <f32 as atomr_accel::AccelDtype>::NAME
108 }
109 fn op_name(&self) -> &'static str {
110 "gemv"
111 }
112 fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
113 dispatch_gemv::<f32>(*self, ctx);
114 }
115}
116
117impl BlasL2Dispatch for GemvRequest<f64> {
118 fn dtype_name(&self) -> &'static str {
119 <f64 as atomr_accel::AccelDtype>::NAME
120 }
121 fn op_name(&self) -> &'static str {
122 "gemv"
123 }
124 fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
125 dispatch_gemv::<f64>(*self, ctx);
126 }
127}
128
/// A GER (rank-1 update) operation waiting to be dispatched, plus the channel
/// its outcome is reported on.
///
/// The scalar and stride fields are passed unchanged to the raw cuBLAS call by
/// `dispatch_ger`; no dimension, stride, or buffer-length validation happens
/// in this module. Field meanings below follow the cuBLAS `ger` convention
/// (`A = alpha * x * yᵀ + A`).
pub struct GerRequest<T: GerSupported> {
    /// Number of rows of `a`, per the cuBLAS convention.
    pub m: i32,
    /// Number of columns of `a`, per the cuBLAS convention.
    pub n: i32,
    /// Scalar multiplying the outer product.
    pub alpha: T,
    /// Handle to the first input vector buffer.
    pub x: GpuRef<T>,
    /// Stride between consecutive elements of `x`.
    pub incx: i32,
    /// Handle to the second input vector buffer.
    pub y: GpuRef<T>,
    /// Stride between consecutive elements of `y`.
    pub incy: i32,
    /// Handle to the matrix buffer that is updated in place. Dispatch fails if
    /// the underlying buffer has more than one live reference, and a write is
    /// recorded on this handle before the kernel is submitted.
    pub a: GpuRef<T>,
    /// Leading dimension of `a`.
    pub lda: i32,
    /// Receives the result of the request. Early failures are sent directly by
    /// `dispatch_ger`; otherwise the sender is handed to
    /// `envelope::run_kernel`.
    pub reply: oneshot::Sender<Result<(), GpuError>>,
}
144
/// Per-element-type binding to the raw cuBLAS GER entry point, so that
/// `dispatch_ger` can stay generic over the element type.
trait GerCall: GerSupported {
    /// Invokes the GER routine for `Self` with raw handle and device pointers.
    ///
    /// `alpha` is a pointer to the scalar; `dispatch_ger` passes the address
    /// of a host-side value. `x`, `y` and `a` are raw device addresses.
    ///
    /// # Safety
    ///
    /// NOTE(review): the exact contract is defined by the wrappers in
    /// `crate::sys::cublas`, which are not visible here. Presumably `handle`
    /// must be a live cuBLAS handle, `alpha` must be valid to read, and the
    /// device pointers must cover the ranges implied by `m`, `n`, `incx`,
    /// `incy` and `lda` — confirm against those wrappers.
    unsafe fn call(
        handle: cudarc::cublas::sys::cublasHandle_t,
        m: i32,
        n: i32,
        alpha: *const Self,
        x: cudarc::driver::sys::CUdeviceptr,
        incx: i32,
        y: cudarc::driver::sys::CUdeviceptr,
        incy: i32,
        a: cudarc::driver::sys::CUdeviceptr,
        lda: i32,
    ) -> Result<(), GpuError>;
}
163
/// `f32` binding: forwards every argument, in the same order, to
/// `syscublas::sger`.
impl GerCall for f32 {
    /// See `GerCall::call`; the result of `syscublas::sger` is returned
    /// unchanged.
    unsafe fn call(
        handle: cudarc::cublas::sys::cublasHandle_t,
        m: i32,
        n: i32,
        alpha: *const Self,
        x: cudarc::driver::sys::CUdeviceptr,
        incx: i32,
        y: cudarc::driver::sys::CUdeviceptr,
        incy: i32,
        a: cudarc::driver::sys::CUdeviceptr,
        lda: i32,
    ) -> Result<(), GpuError> {
        syscublas::sger(handle, m, n, alpha, x, incx, y, incy, a, lda)
    }
}
180
/// `f64` binding: forwards every argument, in the same order, to
/// `syscublas::dger`.
impl GerCall for f64 {
    /// See `GerCall::call`; the result of `syscublas::dger` is returned
    /// unchanged.
    unsafe fn call(
        handle: cudarc::cublas::sys::cublasHandle_t,
        m: i32,
        n: i32,
        alpha: *const Self,
        x: cudarc::driver::sys::CUdeviceptr,
        incx: i32,
        y: cudarc::driver::sys::CUdeviceptr,
        incy: i32,
        a: cudarc::driver::sys::CUdeviceptr,
        lda: i32,
    ) -> Result<(), GpuError> {
        syscublas::dger(handle, m, n, alpha, x, incx, y, incy, a, lda)
    }
}
197
198fn dispatch_ger<T>(req: GerRequest<T>, ctx: &BlasDispatchCtx<'_>)
199where
200 T: GerSupported + GerCall + Copy,
201{
202 let GerRequest {
203 m,
204 n,
205 alpha,
206 x,
207 incx,
208 y,
209 incy,
210 a,
211 lda,
212 reply,
213 } = req;
214 let (a_slice, x_slice, y_slice) = match envelope::access_all_3(&a, &x, &y) {
215 Ok(t) => t,
216 Err(e) => {
217 let _ = reply.send(Err(e));
218 return;
219 }
220 };
221 let mut a_owned = match Arc::try_unwrap(a_slice) {
222 Ok(s) => s,
223 Err(_) => {
224 let _ = reply.send(Err(GpuError::Unrecoverable(
225 "GER target matrix A has more than one live reference".into(),
226 )));
227 return;
228 }
229 };
230 a.record_write(ctx.stream);
231 let cublas = ctx.cublas.clone();
232 let stream = ctx.stream.clone();
233 envelope::run_kernel(LIB, ctx.stream, ctx.completion, (), reply, move || {
234 let res = {
235 let (x_ptr, _x_rec) = (*x_slice).device_ptr(&stream);
236 let (y_ptr, _y_rec) = (*y_slice).device_ptr(&stream);
237 let (a_ptr, _a_rec) = a_owned.device_ptr_mut(&stream);
238 unsafe {
239 T::call(
240 *cublas.handle(),
241 m,
242 n,
243 (&alpha) as *const T,
244 x_ptr,
245 incx,
246 y_ptr,
247 incy,
248 a_ptr,
249 lda,
250 )
251 }
252 };
253 match res {
254 Ok(()) => Ok((cublas, x_slice, y_slice, a_owned)),
255 Err(e) => Err(e),
256 }
257 });
258}
259
260impl BlasL2Dispatch for GerRequest<f32> {
261 fn dtype_name(&self) -> &'static str {
262 <f32 as atomr_accel::AccelDtype>::NAME
263 }
264 fn op_name(&self) -> &'static str {
265 "ger"
266 }
267 fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
268 dispatch_ger::<f32>(*self, ctx);
269 }
270}
271
272impl BlasL2Dispatch for GerRequest<f64> {
273 fn dtype_name(&self) -> &'static str {
274 <f64 as atomr_accel::AccelDtype>::NAME
275 }
276 fn op_name(&self) -> &'static str {
277 "ger"
278 }
279 fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
280 dispatch_ger::<f64>(*self, ctx);
281 }
282}
283
#[cfg(test)]
mod tests {
    use super::super::gemm::tests_helpers::gpu_ref_stub;
    use super::*;
    use tokio::sync::oneshot;

    // These tests only check that each request type can be boxed as a
    // `dyn BlasL2Dispatch` and reports the expected labels; nothing is
    // dispatched. Each boxed request is leaked rather than dropped —
    // presumably so the stub `GpuRef`s never run their destructors (confirm
    // against `gpu_ref_stub`).

    /// Both GEMV element types erase to `dyn BlasL2Dispatch` and report the
    /// right operation and dtype labels.
    #[test]
    fn gemv_request_round_trip() {
        let (tx, _rx) = oneshot::channel();
        let req = GemvRequest::<f32> {
            trans: cublasOperation_t::CUBLAS_OP_N,
            m: 4,
            n: 4,
            alpha: 1.0,
            beta: 0.0,
            a: gpu_ref_stub::<f32>(),
            lda: 4,
            x: gpu_ref_stub::<f32>(),
            incx: 1,
            y: gpu_ref_stub::<f32>(),
            incy: 1,
            reply: tx,
        };
        let boxed: Box<dyn BlasL2Dispatch> = Box::new(req);
        assert_eq!(boxed.op_name(), "gemv");
        assert_eq!(boxed.dtype_name(), "f32");
        Box::leak(boxed);

        let (tx, _rx) = oneshot::channel();
        let req = GemvRequest::<f64> {
            trans: cublasOperation_t::CUBLAS_OP_N,
            m: 4,
            n: 4,
            alpha: 1.0,
            beta: 0.0,
            a: gpu_ref_stub::<f64>(),
            lda: 4,
            x: gpu_ref_stub::<f64>(),
            incx: 1,
            y: gpu_ref_stub::<f64>(),
            incy: 1,
            reply: tx,
        };
        let boxed: Box<dyn BlasL2Dispatch> = Box::new(req);
        assert_eq!(boxed.op_name(), "gemv");
        assert_eq!(boxed.dtype_name(), "f64");
        Box::leak(boxed);
    }

    /// Both GER element types erase to `dyn BlasL2Dispatch` and report the
    /// right operation and dtype labels (mirrors the GEMV test).
    #[test]
    fn ger_request_round_trip() {
        let (tx, _rx) = oneshot::channel();
        let req = GerRequest::<f32> {
            m: 4,
            n: 4,
            alpha: 1.0,
            x: gpu_ref_stub::<f32>(),
            incx: 1,
            y: gpu_ref_stub::<f32>(),
            incy: 1,
            a: gpu_ref_stub::<f32>(),
            lda: 4,
            reply: tx,
        };
        let boxed: Box<dyn BlasL2Dispatch> = Box::new(req);
        assert_eq!(boxed.op_name(), "ger");
        assert_eq!(boxed.dtype_name(), "f32");
        Box::leak(boxed);

        let (tx, _rx) = oneshot::channel();
        let req = GerRequest::<f64> {
            m: 4,
            n: 4,
            alpha: 1.0,
            x: gpu_ref_stub::<f64>(),
            incx: 1,
            y: gpu_ref_stub::<f64>(),
            incy: 1,
            a: gpu_ref_stub::<f64>(),
            lda: 4,
            reply: tx,
        };
        let boxed: Box<dyn BlasL2Dispatch> = Box::new(req);
        assert_eq!(boxed.op_name(), "ger");
        assert_eq!(boxed.dtype_name(), "f64");
        Box::leak(boxed);
    }
}