atomr_accel_cuda/sys/
cublaslt.rs1use std::ffi::c_void;
15use std::ptr;
16
17#[cfg(test)]
18use cudarc::cublaslt::sys::cublasLtMatmulDescOpaque_t;
19use cudarc::cublaslt::sys::{
20 cublasLtMatmulDescAttributes_t, cublasLtMatmulDesc_t, cublasLtMatmulPreferenceAttributes_t,
21 cublasLtMatmulPreferenceOpaque_t, cublasLtMatmulPreference_t, cublasStatus_t,
22};
23
24pub fn check(status: cublasStatus_t, op: &str) -> Result<(), String> {
30 match status {
31 cublasStatus_t::CUBLAS_STATUS_SUCCESS => Ok(()),
32 other => Err(format!("{op}: {other:?}")),
33 }
34}
35
36pub struct Preference {
39 pub raw: cublasLtMatmulPreference_t,
40}
41
42unsafe impl Send for Preference {}
46unsafe impl Sync for Preference {}
47
48impl Preference {
49 pub fn new() -> Result<Self, String> {
50 let mut raw: cublasLtMatmulPreference_t =
51 ptr::null_mut::<cublasLtMatmulPreferenceOpaque_t>();
52 let status = unsafe { cudarc::cublaslt::sys::cublasLtMatmulPreferenceCreate(&mut raw) };
53 check(status, "cublasLtMatmulPreferenceCreate")?;
54 Ok(Self { raw })
55 }
56
57 pub fn set_u64(
61 &self,
62 attr: cublasLtMatmulPreferenceAttributes_t,
63 value: u64,
64 ) -> Result<(), String> {
65 let status = unsafe {
66 cudarc::cublaslt::sys::cublasLtMatmulPreferenceSetAttribute(
67 self.raw,
68 attr,
69 &value as *const u64 as *const c_void,
70 std::mem::size_of::<u64>(),
71 )
72 };
73 check(status, "cublasLtMatmulPreferenceSetAttribute")
74 }
75}
76
77impl Drop for Preference {
78 fn drop(&mut self) {
79 if !self.raw.is_null() {
80 unsafe {
81 let _ = cudarc::cublaslt::sys::cublasLtMatmulPreferenceDestroy(self.raw);
82 }
83 self.raw = ptr::null_mut::<cublasLtMatmulPreferenceOpaque_t>();
84 }
85 }
86}
87
88pub unsafe fn set_desc_pointer_attr(
96 desc: cublasLtMatmulDesc_t,
97 attr: cublasLtMatmulDescAttributes_t,
98 ptr: *const c_void,
99) -> Result<(), String> {
100 let status = unsafe {
101 cudarc::cublaslt::sys::cublasLtMatmulDescSetAttribute(
102 desc,
103 attr,
104 &ptr as *const *const c_void as *const c_void,
105 std::mem::size_of::<*const c_void>(),
106 )
107 };
108 check(status, "cublasLtMatmulDescSetAttribute(pointer)")
109}
110
111pub unsafe fn set_desc_i32_attr(
119 desc: cublasLtMatmulDesc_t,
120 attr: cublasLtMatmulDescAttributes_t,
121 value: i32,
122) -> Result<(), String> {
123 let status = unsafe {
124 cudarc::cublaslt::sys::cublasLtMatmulDescSetAttribute(
125 desc,
126 attr,
127 &value as *const i32 as *const c_void,
128 std::mem::size_of::<i32>(),
129 )
130 };
131 check(status, "cublasLtMatmulDescSetAttribute(i32)")
132}
133
/// Test-only: heap-allocates a zero-filled opaque matmul descriptor and returns
/// its raw pointer as a `cublasLtMatmulDesc_t`.
///
/// The caller owns the allocation and should release it with `drop_mock_desc`.
#[cfg(test)]
pub fn mock_desc_handle() -> cublasLtMatmulDesc_t {
    // SAFETY (assumed): the opaque descriptor type is plain data for which an
    // all-zero bit pattern is a valid value — confirm against the cudarc binding.
    let blank: cublasLtMatmulDescOpaque_t = unsafe { std::mem::zeroed() };
    let boxed = Box::new(blank);
    Box::into_raw(boxed)
}
141
/// Test-only: frees a descriptor produced by `mock_desc_handle`.
///
/// A null pointer is accepted and ignored.
///
/// # Safety
/// `desc` must be null or a pointer returned by `mock_desc_handle` that has not
/// already been freed, and it must not be used after this call.
#[cfg(test)]
pub unsafe fn drop_mock_desc(desc: cublasLtMatmulDesc_t) {
    if desc.is_null() {
        return;
    }
    // SAFETY: per the contract above, `desc` came from `Box::into_raw` in
    // `mock_desc_handle` and is reclaimed exactly once here.
    drop(unsafe { Box::from_raw(desc) });
}