atomr_accel_cuda/kernel/blas/
gemm_strided_batched.rs1use std::sync::Arc;
5
6use cudarc::cublas::sys::cublasOperation_t;
7use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
8use tokio::sync::oneshot;
9
10use crate::dtype::GemmSupported;
11use crate::error::GpuError;
12use crate::gpu_ref::GpuRef;
13use crate::kernel::dispatch::{BlasDispatchCtx, GemmStridedBatchedDispatch};
14use crate::kernel::envelope;
15
/// Library tag used by this module: it is passed to `envelope::run_kernel`
/// and stored in the `lib` field of `GpuError::LibraryError`.
const LIB: &str = "cublas";
17
18pub struct GemmStridedBatchedRequest<T: GemmSupported> {
22 pub a: GpuRef<T>,
23 pub b: GpuRef<T>,
24 pub c: GpuRef<T>,
25 pub m: i32,
26 pub n: i32,
27 pub k: i32,
28 pub alpha: T,
29 pub beta: T,
30 pub trans_a: cublasOperation_t,
31 pub trans_b: cublasOperation_t,
32 pub lda: i32,
33 pub ldb: i32,
34 pub ldc: i32,
35 pub stride_a: i64,
36 pub stride_b: i64,
37 pub stride_c: i64,
38 pub batch_size: i32,
39 pub reply: oneshot::Sender<Result<(), GpuError>>,
40}
41
42fn dispatch_strided_batched<T>(req: GemmStridedBatchedRequest<T>, ctx: &BlasDispatchCtx<'_>)
43where
44 T: GemmSupported + Copy,
45 cudarc::cublas::CudaBlas: Gemm<T>,
46{
47 let GemmStridedBatchedRequest {
48 a,
49 b,
50 c,
51 m,
52 n,
53 k,
54 alpha,
55 beta,
56 trans_a,
57 trans_b,
58 lda,
59 ldb,
60 ldc,
61 stride_a,
62 stride_b,
63 stride_c,
64 batch_size,
65 reply,
66 } = req;
67
68 let (a_slice, b_slice, c_slice) = match envelope::access_all_3(&a, &b, &c) {
69 Ok(t) => t,
70 Err(e) => {
71 let _ = reply.send(Err(e));
72 return;
73 }
74 };
75
76 let cfg = StridedBatchedConfig::<T> {
77 gemm: GemmConfig::<T> {
78 transa: trans_a,
79 transb: trans_b,
80 m,
81 n,
82 k,
83 alpha,
84 lda,
85 ldb,
86 beta,
87 ldc,
88 },
89 batch_size,
90 stride_a,
91 stride_b,
92 stride_c,
93 };
94
95 let mut c_owned = match Arc::try_unwrap(c_slice) {
96 Ok(s) => s,
97 Err(_arc) => {
98 let _ = reply.send(Err(GpuError::Unrecoverable(
99 "GEMM strided-batched target buffer C has more than one live reference; \
100 caller must hold the unique GpuRef to write to it"
101 .into(),
102 )));
103 return;
104 }
105 };
106
107 c.record_write(ctx.stream);
108
109 let cublas = ctx.cublas.clone();
110 envelope::run_kernel(LIB, ctx.stream, ctx.completion, (), reply, move || {
111 let res = unsafe { cublas.gemm_strided_batched(cfg, &*a_slice, &*b_slice, &mut c_owned) };
112 match res {
113 Ok(()) => Ok((cublas, a_slice, b_slice, c_owned)),
114 Err(e) => Err(GpuError::LibraryError {
115 lib: LIB,
116 msg: format!("gemm_strided_batched enqueue: {e}"),
117 }),
118 }
119 });
120}
121
122impl GemmStridedBatchedDispatch for GemmStridedBatchedRequest<f32> {
123 fn dtype_name(&self) -> &'static str {
124 <f32 as atomr_accel::AccelDtype>::NAME
125 }
126 fn op_name(&self) -> &'static str {
127 "gemm_strided_batched"
128 }
129 fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
130 dispatch_strided_batched::<f32>(*self, ctx);
131 }
132}
133
134impl GemmStridedBatchedDispatch for GemmStridedBatchedRequest<f64> {
135 fn dtype_name(&self) -> &'static str {
136 <f64 as atomr_accel::AccelDtype>::NAME
137 }
138 fn op_name(&self) -> &'static str {
139 "gemm_strided_batched"
140 }
141 fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
142 dispatch_strided_batched::<f64>(*self, ctx);
143 }
144}
145
/// Type-erased dispatch entry for `half::f16` strided-batched GEMM requests.
/// Compiled only when the `f16` feature is enabled.
#[cfg(feature = "f16")]
impl GemmStridedBatchedDispatch for GemmStridedBatchedRequest<half::f16> {
    /// Operation label for this request kind.
    fn op_name(&self) -> &'static str {
        "gemm_strided_batched"
    }

    /// Dtype label, taken from `half::f16`'s `AccelDtype::NAME`.
    fn dtype_name(&self) -> &'static str {
        <half::f16 as atomr_accel::AccelDtype>::NAME
    }

    /// Unboxes the request and runs it through the generic dispatcher.
    fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
        let request: Self = *self;
        dispatch_strided_batched(request, ctx);
    }
}
158
/// Type-erased dispatch entry for `half::bf16` strided-batched GEMM requests.
/// Compiled only when the `f16` feature is enabled.
#[cfg(feature = "f16")]
impl GemmStridedBatchedDispatch for GemmStridedBatchedRequest<half::bf16> {
    /// Operation label for this request kind.
    fn op_name(&self) -> &'static str {
        "gemm_strided_batched"
    }

    /// Dtype label, taken from `half::bf16`'s `AccelDtype::NAME`.
    fn dtype_name(&self) -> &'static str {
        <half::bf16 as atomr_accel::AccelDtype>::NAME
    }

    /// Unboxes the request and runs it through the generic dispatcher.
    fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
        let request: Self = *self;
        dispatch_strided_batched(request, ctx);
    }
}
171
#[cfg(test)]
mod tests {
    use super::super::gemm::tests_helpers::gpu_ref_stub;
    use super::*;
    use tokio::sync::oneshot;

    /// Builds an `f32` request from stub buffers, erases it behind the
    /// dispatch trait object, and checks the reported operation and dtype
    /// names. No kernel is dispatched.
    #[test]
    fn strided_batched_request_round_trip() {
        let (reply, _rx) = oneshot::channel();

        let boxed: Box<dyn GemmStridedBatchedDispatch> =
            Box::new(GemmStridedBatchedRequest::<f32> {
                a: gpu_ref_stub::<f32>(),
                b: gpu_ref_stub::<f32>(),
                c: gpu_ref_stub::<f32>(),
                m: 1,
                n: 1,
                k: 1,
                alpha: 1.0,
                beta: 0.0,
                trans_a: cublasOperation_t::CUBLAS_OP_N,
                trans_b: cublasOperation_t::CUBLAS_OP_N,
                lda: 1,
                ldb: 1,
                ldc: 1,
                stride_a: 1,
                stride_b: 1,
                stride_c: 1,
                batch_size: 4,
                reply,
            });

        assert_eq!(boxed.op_name(), "gemm_strided_batched");
        assert_eq!(boxed.dtype_name(), "f32");

        // The request is leaked rather than dropped.
        // NOTE(review): presumably the stub `GpuRef`s must never run their
        // destructors — confirm against `tests_helpers::gpu_ref_stub`.
        let _ = Box::leak(boxed);
    }
}