Skip to main content

atomr_accel_cuda/graph/record/
sgemm.rs

1//! SGEMM op for [`super::super::GraphOp`].
2//!
3//! Wraps the lower-level [`crate::kernel::record::BlasSgemmOp`] /
4//! [`crate::kernel::record::BlasRecorder`] pair into a single
5//! `GraphOp`-implementing type.
6
7use crate::error::GpuError;
8use crate::gpu_ref::GpuRef;
9use crate::graph::{GraphOp, GraphRecordCtx};
10use crate::kernel::record::{BlasRecorder, BlasSgemmOp, RecordMode};
11
12/// SGEMM op for graph capture: `C := alpha · A·B + beta · C`,
13/// column-major, no transpose.
14///
15/// The op needs a cuBLAS handle on the captured stream, supplied
16/// by `GraphRecordCtx::blas`. If absent the op fails with
17/// [`GpuError::Unrecoverable`].
18///
19/// `record` consumes the held `GpuRef`s on first invocation
20/// (matching the pre-trait closed-enum semantics where the boxed
21/// op was destructured by-move). A second `record` call on the
22/// same op returns [`GpuError::Unrecoverable`].
23pub struct SgemmOp {
24    inner: Option<BlasSgemmOp>,
25}
26
27impl SgemmOp {
28    pub fn new(
29        a: GpuRef<f32>,
30        b: GpuRef<f32>,
31        c: GpuRef<f32>,
32        m: i32,
33        n: i32,
34        k: i32,
35        alpha: f32,
36        beta: f32,
37    ) -> Self {
38        Self {
39            inner: Some(BlasSgemmOp {
40                a,
41                b,
42                c,
43                m,
44                n,
45                k,
46                alpha,
47                beta,
48            }),
49        }
50    }
51}
52
53impl GraphOp for SgemmOp {
54    fn record(&mut self, ctx: &mut GraphRecordCtx<'_>) -> Result<(), GpuError> {
55        let stream = ctx.require_stream()?;
56        let blas = ctx.blas.ok_or_else(|| {
57            GpuError::Unrecoverable("SgemmOp::record: cuBLAS handle not available in ctx".into())
58        })?;
59        let op = self
60            .inner
61            .take()
62            .ok_or_else(|| GpuError::Unrecoverable("SgemmOp::record: already consumed".into()))?;
63        let mut recorder = BlasRecorder { handle: blas };
64        recorder.enqueue_record(stream, op)
65    }
66
67    fn op_name(&self) -> &'static str {
68        "graph::sgemm"
69    }
70}