Skip to main content

atomr_accel_cuda/kernel/blas/gemm_strided_batched.rs

1//! Typed `GemmStridedBatchedRequest<T>` + `GemmStridedBatchedDispatch`
2//! impls.
3
4use std::sync::Arc;
5
6use cudarc::cublas::sys::cublasOperation_t;
7use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
8use tokio::sync::oneshot;
9
10use crate::dtype::GemmSupported;
11use crate::error::GpuError;
12use crate::gpu_ref::GpuRef;
13use crate::kernel::dispatch::{BlasDispatchCtx, GemmStridedBatchedDispatch};
14use crate::kernel::envelope;
15
16const LIB: &str = "cublas";
17
18/// Typed strided-batched gemm request. Per-batch strides describe the
19/// element offset between consecutive batch entries inside a single
20/// allocation.
21pub struct GemmStridedBatchedRequest<T: GemmSupported> {
22    pub a: GpuRef<T>,
23    pub b: GpuRef<T>,
24    pub c: GpuRef<T>,
25    pub m: i32,
26    pub n: i32,
27    pub k: i32,
28    pub alpha: T,
29    pub beta: T,
30    pub trans_a: cublasOperation_t,
31    pub trans_b: cublasOperation_t,
32    pub lda: i32,
33    pub ldb: i32,
34    pub ldc: i32,
35    pub stride_a: i64,
36    pub stride_b: i64,
37    pub stride_c: i64,
38    pub batch_size: i32,
39    pub reply: oneshot::Sender<Result<(), GpuError>>,
40}
41
42fn dispatch_strided_batched<T>(req: GemmStridedBatchedRequest<T>, ctx: &BlasDispatchCtx<'_>)
43where
44    T: GemmSupported + Copy,
45    cudarc::cublas::CudaBlas: Gemm<T>,
46{
47    let GemmStridedBatchedRequest {
48        a,
49        b,
50        c,
51        m,
52        n,
53        k,
54        alpha,
55        beta,
56        trans_a,
57        trans_b,
58        lda,
59        ldb,
60        ldc,
61        stride_a,
62        stride_b,
63        stride_c,
64        batch_size,
65        reply,
66    } = req;
67
68    let (a_slice, b_slice, c_slice) = match envelope::access_all_3(&a, &b, &c) {
69        Ok(t) => t,
70        Err(e) => {
71            let _ = reply.send(Err(e));
72            return;
73        }
74    };
75
76    let cfg = StridedBatchedConfig::<T> {
77        gemm: GemmConfig::<T> {
78            transa: trans_a,
79            transb: trans_b,
80            m,
81            n,
82            k,
83            alpha,
84            lda,
85            ldb,
86            beta,
87            ldc,
88        },
89        batch_size,
90        stride_a,
91        stride_b,
92        stride_c,
93    };
94
95    let mut c_owned = match Arc::try_unwrap(c_slice) {
96        Ok(s) => s,
97        Err(_arc) => {
98            let _ = reply.send(Err(GpuError::Unrecoverable(
99                "GEMM strided-batched target buffer C has more than one live reference; \
100                 caller must hold the unique GpuRef to write to it"
101                    .into(),
102            )));
103            return;
104        }
105    };
106
107    c.record_write(ctx.stream);
108
109    let cublas = ctx.cublas.clone();
110    envelope::run_kernel(LIB, ctx.stream, ctx.completion, (), reply, move || {
111        let res = unsafe { cublas.gemm_strided_batched(cfg, &*a_slice, &*b_slice, &mut c_owned) };
112        match res {
113            Ok(()) => Ok((cublas, a_slice, b_slice, c_owned)),
114            Err(e) => Err(GpuError::LibraryError {
115                lib: LIB,
116                msg: format!("gemm_strided_batched enqueue: {e}"),
117            }),
118        }
119    });
120}
121
122impl GemmStridedBatchedDispatch for GemmStridedBatchedRequest<f32> {
123    fn dtype_name(&self) -> &'static str {
124        <f32 as atomr_accel::AccelDtype>::NAME
125    }
126    fn op_name(&self) -> &'static str {
127        "gemm_strided_batched"
128    }
129    fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
130        dispatch_strided_batched::<f32>(*self, ctx);
131    }
132}
133
134impl GemmStridedBatchedDispatch for GemmStridedBatchedRequest<f64> {
135    fn dtype_name(&self) -> &'static str {
136        <f64 as atomr_accel::AccelDtype>::NAME
137    }
138    fn op_name(&self) -> &'static str {
139        "gemm_strided_batched"
140    }
141    fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
142        dispatch_strided_batched::<f64>(*self, ctx);
143    }
144}
145
146#[cfg(feature = "f16")]
147impl GemmStridedBatchedDispatch for GemmStridedBatchedRequest<half::f16> {
148    fn dtype_name(&self) -> &'static str {
149        <half::f16 as atomr_accel::AccelDtype>::NAME
150    }
151    fn op_name(&self) -> &'static str {
152        "gemm_strided_batched"
153    }
154    fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
155        dispatch_strided_batched::<half::f16>(*self, ctx);
156    }
157}
158
159#[cfg(feature = "f16")]
160impl GemmStridedBatchedDispatch for GemmStridedBatchedRequest<half::bf16> {
161    fn dtype_name(&self) -> &'static str {
162        <half::bf16 as atomr_accel::AccelDtype>::NAME
163    }
164    fn op_name(&self) -> &'static str {
165        "gemm_strided_batched"
166    }
167    fn dispatch(self: Box<Self>, ctx: &BlasDispatchCtx<'_>) {
168        dispatch_strided_batched::<half::bf16>(*self, ctx);
169    }
170}
171
#[cfg(test)]
mod tests {
    use super::super::gemm::tests_helpers::gpu_ref_stub;
    use super::*;

    /// A boxed f32 request must report the op and dtype names used for routing.
    #[test]
    fn strided_batched_request_round_trip() {
        let (reply, _reply_rx) = oneshot::channel();
        let request = GemmStridedBatchedRequest::<f32> {
            a: gpu_ref_stub::<f32>(),
            b: gpu_ref_stub::<f32>(),
            c: gpu_ref_stub::<f32>(),
            m: 1,
            n: 1,
            k: 1,
            alpha: 1.0,
            beta: 0.0,
            trans_a: cublasOperation_t::CUBLAS_OP_N,
            trans_b: cublasOperation_t::CUBLAS_OP_N,
            lda: 1,
            ldb: 1,
            ldc: 1,
            stride_a: 1,
            stride_b: 1,
            stride_c: 1,
            batch_size: 4,
            reply,
        };

        let dispatcher: Box<dyn GemmStridedBatchedDispatch> = Box::new(request);
        assert_eq!(dispatcher.op_name(), "gemm_strided_batched");
        assert_eq!(dispatcher.dtype_name(), "f32");

        // The request is deliberately never dropped. It holds stub `GpuRef`s,
        // which are presumably unsafe to drop outside a real device context;
        // confirm in `gemm::tests_helpers`.
        Box::leak(dispatcher);
    }
}
205}