atomr_accel_cuda/graph/record/
sgemm.rs1use crate::error::GpuError;
8use crate::gpu_ref::GpuRef;
9use crate::graph::{GraphOp, GraphRecordCtx};
10use crate::kernel::record::{BlasRecorder, BlasSgemmOp, RecordMode};
11
12pub struct SgemmOp {
24 inner: Option<BlasSgemmOp>,
25}
26
27impl SgemmOp {
28 pub fn new(
29 a: GpuRef<f32>,
30 b: GpuRef<f32>,
31 c: GpuRef<f32>,
32 m: i32,
33 n: i32,
34 k: i32,
35 alpha: f32,
36 beta: f32,
37 ) -> Self {
38 Self {
39 inner: Some(BlasSgemmOp {
40 a,
41 b,
42 c,
43 m,
44 n,
45 k,
46 alpha,
47 beta,
48 }),
49 }
50 }
51}
52
53impl GraphOp for SgemmOp {
54 fn record(&mut self, ctx: &mut GraphRecordCtx<'_>) -> Result<(), GpuError> {
55 let stream = ctx.require_stream()?;
56 let blas = ctx.blas.ok_or_else(|| {
57 GpuError::Unrecoverable("SgemmOp::record: cuBLAS handle not available in ctx".into())
58 })?;
59 let op = self
60 .inner
61 .take()
62 .ok_or_else(|| GpuError::Unrecoverable("SgemmOp::record: already consumed".into()))?;
63 let mut recorder = BlasRecorder { handle: blas };
64 recorder.enqueue_record(stream, op)
65 }
66
67 fn op_name(&self) -> &'static str {
68 "graph::sgemm"
69 }
70}