Skip to main content

atomr_accel_cuda/
dtype.rs

1//! `CudaDtype` — CUDA-side dtype mappings and capability markers.
2//!
3//! The backend-agnostic [`AccelDtype`] trait (in `atomr-accel`) names
4//! the dtype and gives identity values. `CudaDtype` adds the cudarc-
5//! enum mappings every kernel actor needs:
6//!
7//! - [`cuda_data_type`](CudaDtype::cuda_data_type) — `cudaDataType_t`
8//!   (consumed by cuBLAS, cuBLASLt, cuSPARSE, cuSOLVER, cuTENSOR).
9//! - [`cublas_compute_type`](CudaDtype::cublas_compute_type) — the
10//!   natural `cublasComputeType_t` for matmul accumulation.
11//! - [`cudnn_data_type`](CudaDtype::cudnn_data_type) — `cudnnDataType_t`
12//!   (cuDNN tensor descriptor element type), gated on `cudnn`.
13//! - [`nccl_data_type`](CudaDtype::nccl_data_type) — `ncclDataType_t`
14//!   (collective-op element type), gated on `nccl`.
15//! - [`cuda_type_name`](CudaDtype::cuda_type_name) — CUDA C++ type
16//!   name (`"float"`, `"__half"`, `"__nv_bfloat16"`, …) for NVRTC
17//!   kernel source generation.
18//!
19//! Capability markers ([`GemmSupported`], [`CudnnSupported`], …) are
20//! the compile-time gate keeping operations from being dispatched
21//! against unsupported dtypes — `BlasMsg::gemm::<i64>(...)` does not
22//! compile because `i64: GemmSupported` has no impl.
23
24use cudarc::cublas::sys as cublas_sys;
25#[cfg(feature = "cudnn")]
26use cudarc::cudnn::sys as cudnn_sys;
27use cudarc::driver::{DeviceRepr, ValidAsZeroBits};
28#[cfg(feature = "nccl")]
29use cudarc::nccl::sys as nccl_sys;
30
31/// Re-export `atomr_accel::DType` so existing `crate::dtype::DType`
32/// imports inside `atomr-accel-cuda` (added by Phase 0.4) keep working
33/// without changing every call site.
34pub use atomr_accel::DType;
35
36/// Re-export so `crate::dtype::AccelDtype` resolves for actor modules
37/// that prefer the unified import path.
38pub use atomr_accel::AccelDtype;
39
40/// Alias used by `BlasLtDispatch::dtype_kind` and other Phase 1 dispatchers.
41pub use atomr_accel::DType as DTypeKind;
42
// The fp8 wrappers further down (`F8E4m3` / `F8E5m2`, behind the `f8`
// feature) follow the same pattern as the complex types here: local
// `#[repr(transparent)]` newtypes (over `u8`) so this crate can
// `unsafe impl` cudarc's `DeviceRepr`, with `from_f32` / `to_f32`
// going through the backend-agnostic `atomr_accel::dtype::*` types.

/// 64-bit interleaved complex (`{re, im}` of `f32`). Same size and
/// field order as `cufft_sys::float2` (which is more strictly aligned)
/// and `numpy.complex64`. Phase 1.5++.
///
/// A local `#[repr(transparent)]` newtype over `[f32; 2]`, so that:
///
/// - this crate may `unsafe impl` cudarc's `DeviceRepr` for it (the
///   orphan rule forbids implementing a foreign trait for a foreign
///   type such as `[f32; 2]` or `num_complex::Complex<f32>`);
/// - it has exactly the size and alignment of `[f32; 2]`, which match
///   `num_complex::Complex<f32>` (`#[repr(C)]`, two `f32` fields).
///   Element-wise reinterpretation is therefore layout-safe — e.g. a
///   pointer cast from `&[C32]` to `&[Complex<f32>]`, or rebuilding a
///   vector with `Vec::from_raw_parts`. Do **not** `mem::transmute` a
///   whole `Vec<C32>`: `Vec`'s own layout is unspecified.
///
/// Maps to `cudaDataType_t::CUDA_C_32F` and `cuComplex` in CUDA C++.
#[repr(transparent)]
#[derive(Copy, Clone, Debug, Default, PartialEq)]
pub struct C32(pub [f32; 2]);

impl C32 {
    /// Construct from real / imaginary components.
    #[inline]
    pub const fn new(re: f32, im: f32) -> Self {
        Self([re, im])
    }

    /// Real component (lane 0).
    #[inline]
    pub const fn re(self) -> f32 {
        self.0[0]
    }

    /// Imaginary component (lane 1).
    #[inline]
    pub const fn im(self) -> f32 {
        self.0[1]
    }
}
76
/// 128-bit interleaved complex (`{re, im}` of `f64`). Same size and
/// field order as `cufft_sys::double2` and `numpy.complex128`.
/// Phase 1.5++.
///
/// A local `#[repr(transparent)]` newtype over `[f64; 2]`, so this crate
/// may `unsafe impl` cudarc's `DeviceRepr` for it, and so it has exactly
/// the size and alignment of `[f64; 2]` — matching
/// `num_complex::Complex<f64>` (`#[repr(C)]`, two `f64` fields).
/// Element-wise reinterpretation (slice pointer casts,
/// `Vec::from_raw_parts`) is layout-safe; transmuting a whole
/// `Vec<C64>` is not, because `Vec`'s own layout is unspecified.
///
/// Maps to `cudaDataType_t::CUDA_C_64F` and `cuDoubleComplex`.
#[repr(transparent)]
#[derive(Copy, Clone, Debug, Default, PartialEq)]
pub struct C64(pub [f64; 2]);

impl C64 {
    /// Construct from real / imaginary components.
    #[inline]
    pub const fn new(re: f64, im: f64) -> Self {
        Self([re, im])
    }

    /// Real component (lane 0).
    #[inline]
    pub const fn re(self) -> f64 {
        self.0[0]
    }

    /// Imaginary component (lane 1).
    #[inline]
    pub const fn im(self) -> f64 {
        self.0[1]
    }
}
100
// Layout bridges to the cudarc cuFFT FFI structs. `cufft_sys::float2`
// is `#[repr(C)] #[repr(align(8))] { x: f32, y: f32 }` — stricter
// alignment than `[f32; 2]` (align 4). The `From` impls are
// field-by-field copies, not transmutes, so the alignment mismatch is
// harmless on the value side. At the pointer level only one direction
// is sound: a `*const float2` is always aligned enough to be read as a
// `*const C32` (the alignment requirement shrinks), but casting
// `*const C32` → `*const float2` is _not_ sound, because a `C32` may sit
// at an address that is only 4-byte aligned. Callers needing a
// pointer-level bridge into cuFFT should keep their buffers typed as
// `cufft_sys::float2`. (Presumably the same holds for `double2` / `C64`
// if `double2` is likewise over-aligned — TODO confirm.)
//
// Gated on the `cufft` cargo feature because the cudarc::cufft module
// is itself feature-gated.
#[cfg(feature = "cufft")]
impl From<cudarc::cufft::sys::float2> for C32 {
    /// Copies `x` into the real lane and `y` into the imaginary lane.
    #[inline]
    fn from(value: cudarc::cufft::sys::float2) -> Self {
        Self::new(value.x, value.y)
    }
}
118
#[cfg(feature = "cufft")]
impl From<C32> for cudarc::cufft::sys::float2 {
    /// Real lane becomes `x`, imaginary lane becomes `y`.
    #[inline]
    fn from(value: C32) -> Self {
        let [x, y] = value.0;
        Self { x, y }
    }
}
129
#[cfg(feature = "cufft")]
impl From<cudarc::cufft::sys::double2> for C64 {
    /// Copies `x` into the real lane and `y` into the imaginary lane.
    #[inline]
    fn from(value: cudarc::cufft::sys::double2) -> Self {
        Self::new(value.x, value.y)
    }
}
137
#[cfg(feature = "cufft")]
impl From<C64> for cudarc::cufft::sys::double2 {
    /// Real lane becomes `x`, imaginary lane becomes `y`.
    #[inline]
    fn from(value: C64) -> Self {
        let [x, y] = value.0;
        Self { x, y }
    }
}
148
149// SAFETY: `C32` / `C64` are `#[repr(transparent)]` over `[f32; 2]` /
150// `[f64; 2]`, both of which are POD; cudarc allows arbitrary `Copy +
151// 'static` over device-mappable bit patterns to be `DeviceRepr`. All
152// bit patterns of `f32` / `f64` (including NaN, inf, signaling NaN)
153// are valid floats, so `ValidAsZeroBits` is sound — the all-zeros
154// pattern represents `+0.0 + 0.0i`.
155unsafe impl DeviceRepr for C32 {}
156unsafe impl ValidAsZeroBits for C32 {}
157unsafe impl DeviceRepr for C64 {}
158unsafe impl ValidAsZeroBits for C64 {}
159
160// `AccelDtype` requires a `DType` discriminant. The base atomr-accel
161// `DType` enum has no Complex variant; reuse the matching scalar lane
162// (`F32` for C32, `F64` for C64) — the same convention `FftKind`'s
163// `scalar_dtype()` already uses for cuFFT plan keys. Callers that
164// need to distinguish complex from real branch on `T` directly, not
165// on `KIND`.
166impl atomr_accel::AccelDtype for C32 {
167    type Scalar = f32;
168    const KIND: DType = DType::F32;
169    const SIZE: usize = 8;
170    const NAME: &'static str = "complex64";
171    #[inline]
172    fn zero() -> Self {
173        C32([0.0, 0.0])
174    }
175    #[inline]
176    fn one() -> Self {
177        C32([1.0, 0.0])
178    }
179    #[inline]
180    fn nan() -> Option<Self> {
181        Some(C32([f32::NAN, f32::NAN]))
182    }
183}
184
185impl atomr_accel::AccelDtype for C64 {
186    type Scalar = f64;
187    const KIND: DType = DType::F64;
188    const SIZE: usize = 16;
189    const NAME: &'static str = "complex128";
190    #[inline]
191    fn zero() -> Self {
192        C64([0.0, 0.0])
193    }
194    #[inline]
195    fn one() -> Self {
196        C64([1.0, 0.0])
197    }
198    #[inline]
199    fn nan() -> Option<Self> {
200        Some(C64([f64::NAN, f64::NAN]))
201    }
202}
203
/// 8-bit float storage for `CUDA_R_8F_E4M3`, held as raw bits.
///
/// A local `#[repr(transparent)]` wrapper over `u8` so this crate can
/// `unsafe impl` cudarc's `DeviceRepr` for it. The derived
/// `PartialEq` / `Eq` / `Hash` compare raw bits, not float values.
#[cfg(feature = "f8")]
#[repr(transparent)]
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Hash)]
pub struct F8E4m3(pub u8);

#[cfg(feature = "f8")]
impl F8E4m3 {
    /// Encode an `f32` using the backend-agnostic
    /// `atomr_accel::dtype::F8E4m3::from_f32` conversion.
    #[inline]
    pub fn from_f32(x: f32) -> Self {
        let encoded = atomr_accel::dtype::F8E4m3::from_f32(x);
        Self(encoded.0)
    }

    /// Decode to `f32` via `atomr_accel::dtype::F8E4m3::to_f32`.
    #[inline]
    pub fn to_f32(self) -> f32 {
        let Self(bits) = self;
        atomr_accel::dtype::F8E4m3(bits).to_f32()
    }
}
220
/// 8-bit float storage for `CUDA_R_8F_E5M2`, held as raw bits.
///
/// A local `#[repr(transparent)]` wrapper over `u8` so this crate can
/// `unsafe impl` cudarc's `DeviceRepr` for it. The derived
/// `PartialEq` / `Eq` / `Hash` compare raw bits, not float values.
#[cfg(feature = "f8")]
#[repr(transparent)]
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Hash)]
pub struct F8E5m2(pub u8);

#[cfg(feature = "f8")]
impl F8E5m2 {
    /// Encode an `f32` using the backend-agnostic
    /// `atomr_accel::dtype::F8E5m2::from_f32` conversion.
    #[inline]
    pub fn from_f32(x: f32) -> Self {
        let encoded = atomr_accel::dtype::F8E5m2::from_f32(x);
        Self(encoded.0)
    }

    /// Decode to `f32` via `atomr_accel::dtype::F8E5m2::to_f32`.
    #[inline]
    pub fn to_f32(self) -> f32 {
        let Self(bits) = self;
        atomr_accel::dtype::F8E5m2(bits).to_f32()
    }
}
237
#[cfg(feature = "f8")]
impl atomr_accel::AccelDtype for F8E4m3 {
    type Scalar = f32;
    const KIND: DType = DType::F8E4m3;
    // One byte: `#[repr(transparent)]` over `u8`.
    const SIZE: usize = std::mem::size_of::<Self>();
    const NAME: &'static str = "f8_e4m3";

    /// All bits clear (sign 0, exponent 0000, mantissa 000).
    #[inline]
    fn zero() -> Self {
        Self(0x00)
    }

    /// Bits `0 0111 000`: biased exponent 7, zero mantissa.
    #[inline]
    fn one() -> Self {
        Self(0x38)
    }

    /// Bits `0 1111 111`: all exponent and mantissa bits set.
    #[inline]
    fn nan() -> Option<Self> {
        Some(Self(0x7f))
    }
}
257
#[cfg(feature = "f8")]
impl atomr_accel::AccelDtype for F8E5m2 {
    type Scalar = f32;
    const KIND: DType = DType::F8E5m2;
    // One byte: `#[repr(transparent)]` over `u8`.
    const SIZE: usize = std::mem::size_of::<Self>();
    const NAME: &'static str = "f8_e5m2";

    /// All bits clear (sign 0, exponent 00000, mantissa 00).
    #[inline]
    fn zero() -> Self {
        Self(0x00)
    }

    /// Bits `0 01111 00`: biased exponent 15, zero mantissa.
    #[inline]
    fn one() -> Self {
        Self(0x3c)
    }

    /// Bits `0 11111 10`: all-ones exponent with a non-zero mantissa.
    #[inline]
    fn nan() -> Option<Self> {
        Some(Self(0x7e))
    }
}
277
278/// CUDA-specific layer over [`AccelDtype`].
279///
280/// The cudarc bounds (`DeviceRepr`, `ValidAsZeroBits`) are part of the
281/// supertrait list so dispatch payloads can call `stream.alloc_zeros::<T>`
282/// behind a single `T: CudaDtype` bound.
283pub trait CudaDtype: AccelDtype + DeviceRepr + ValidAsZeroBits {
284    fn cuda_data_type() -> cublas_sys::cudaDataType_t;
285    fn cublas_compute_type() -> cublas_sys::cublasComputeType_t;
286    /// CUDA C++ type name for NVRTC source generation.
287    fn cuda_type_name() -> &'static str;
288    #[cfg(feature = "cudnn")]
289    fn cudnn_data_type() -> cudnn_sys::cudnnDataType_t;
290    #[cfg(feature = "nccl")]
291    fn nccl_data_type() -> nccl_sys::ncclDataType_t;
292}
293
294/// Capability marker — type may be a cuBLAS GEMM operand.
295pub trait GemmSupported: CudaDtype {}
296
297/// Capability marker — type may be a cuDNN tensor element.
298pub trait CudnnSupported: CudaDtype {}
299
300/// Capability marker — type may be a cuFFT element.
301pub trait FftSupported: CudaDtype {}
302
303/// Capability marker — type may be a cuRAND distribution element
304/// (`Self` is one of the float dtypes accepted by `curandGenerate*`).
305pub trait RngFloatSupported: CudaDtype {}
306
307/// Capability marker — type may be an NCCL collective-op element.
308pub trait NcclReduceSupported: CudaDtype {}
309
310/// Capability marker — type may be a cuSOLVER dense factorization
311/// element (real or complex float).
312pub trait SolverSupported: CudaDtype {}
313
314/// Capability marker — type may be a cuSPARSE SpMV/SpMM/SpGEMM
315/// element.
316pub trait SparseSupported: CudaDtype {}
317
318/// Capability marker — type may be a cuTENSOR contraction operand.
319pub trait TensorSupported: CudaDtype {}
320
321/// Capability marker — cuRAND integer-fill operand. `curandGenerate` produces u32,
322/// `curandGenerateLongLong` produces u64. Used by `Discrete` and raw-bit paths.
323pub trait RngIntSupported: CudaDtype {}
324
325/// Phase 4 cuSPARSE index-type marker. Only `i32` and `i64` are
326/// representable cuSPARSE row/col index dtypes.
327pub trait SparseIndex: AccelDtype {
328    #[cfg(feature = "cusparse")]
329    fn cusparse_index_type() -> cudarc::cusparse::sys::cusparseIndexType_t;
330}
331
332#[cfg(feature = "cusparse")]
333impl SparseIndex for i32 {
334    fn cusparse_index_type() -> cudarc::cusparse::sys::cusparseIndexType_t {
335        cudarc::cusparse::sys::cusparseIndexType_t::CUSPARSE_INDEX_32I
336    }
337}
338#[cfg(feature = "cusparse")]
339impl SparseIndex for i64 {
340    fn cusparse_index_type() -> cudarc::cusparse::sys::cusparseIndexType_t {
341        cudarc::cusparse::sys::cusparseIndexType_t::CUSPARSE_INDEX_64I
342    }
343}
344#[cfg(not(feature = "cusparse"))]
345impl SparseIndex for i32 {}
346#[cfg(not(feature = "cusparse"))]
347impl SparseIndex for i64 {}
348
349// Phase 1 cuBLAS sub-marker traits (per-op dtype subsets).
350
351pub trait AxpyDotNrm2Supported: CudaDtype {}
352pub trait GemvSupported: CudaDtype {}
353pub trait GerSupported: CudaDtype {}
354pub trait GeamSupported: CudaDtype {}
355pub trait SyrkSupported: CudaDtype {}
356pub trait TrsmSupported: CudaDtype {}
357
/// Implements [`CudaDtype`] for a plain scalar type.
///
/// Syntax: `impl_cuda_dtype!(RustTy, CUDA_DATA_TYPE, CUBLAS_COMPUTE,
/// "c++ name", cudnn: CUDNN_VARIANT, nccl: NCCL_VARIANT);`. The NCCL
/// variant may be left empty (`nccl:` with nothing after it), in which
/// case the generated `nccl_data_type` panics.
macro_rules! impl_cuda_dtype {
    // Internal helper: an NCCL variant was supplied.
    (@nccl_arm $variant:ident) => {
        nccl_sys::ncclDataType_t::$variant
    };
    // Internal helper: no NCCL mapping for this dtype.
    (@nccl_arm) => {
        panic!("dtype not supported by NCCL")
    };
    // Entry point.
    (
        $elem:ty,
        $cuda_ty:ident,
        $compute_ty:ident,
        $cxx_name:literal,
        cudnn: $cudnn_ty:ident,
        nccl: $($nccl_ty:ident)?
    ) => {
        impl CudaDtype for $elem {
            #[inline]
            fn cuda_data_type() -> cublas_sys::cudaDataType_t {
                cublas_sys::cudaDataType_t::$cuda_ty
            }
            #[inline]
            fn cublas_compute_type() -> cublas_sys::cublasComputeType_t {
                cublas_sys::cublasComputeType_t::$compute_ty
            }
            #[inline]
            fn cuda_type_name() -> &'static str {
                $cxx_name
            }
            #[cfg(feature = "cudnn")]
            #[inline]
            fn cudnn_data_type() -> cudnn_sys::cudnnDataType_t {
                cudnn_sys::cudnnDataType_t::$cudnn_ty
            }
            #[cfg(feature = "nccl")]
            #[inline]
            fn nccl_data_type() -> nccl_sys::ncclDataType_t {
                impl_cuda_dtype!(@nccl_arm $($nccl_ty)?)
            }
        }
    };
}
393
394impl_cuda_dtype!(f32, CUDA_R_32F, CUBLAS_COMPUTE_32F, "float",
395    cudnn: CUDNN_DATA_FLOAT, nccl: ncclFloat32);
396impl_cuda_dtype!(f64, CUDA_R_64F, CUBLAS_COMPUTE_64F, "double",
397    cudnn: CUDNN_DATA_DOUBLE, nccl: ncclFloat64);
398impl_cuda_dtype!(i8, CUDA_R_8I, CUBLAS_COMPUTE_32I, "char",
399    cudnn: CUDNN_DATA_INT8, nccl: ncclInt8);
400impl_cuda_dtype!(u8, CUDA_R_8U, CUBLAS_COMPUTE_32I, "unsigned char",
401    cudnn: CUDNN_DATA_UINT8, nccl: ncclUint8);
402impl_cuda_dtype!(i32, CUDA_R_32I, CUBLAS_COMPUTE_32I, "int",
403    cudnn: CUDNN_DATA_INT32, nccl: ncclInt32);
404impl_cuda_dtype!(u32, CUDA_R_32U, CUBLAS_COMPUTE_32I, "unsigned int",
405    cudnn: CUDNN_DATA_INT32, nccl: ncclUint32);
406impl_cuda_dtype!(i64, CUDA_R_64I, CUBLAS_COMPUTE_32I, "long long",
407    cudnn: CUDNN_DATA_INT64, nccl: ncclInt64);
408impl_cuda_dtype!(u64, CUDA_R_64U, CUBLAS_COMPUTE_32I, "unsigned long long",
409    cudnn: CUDNN_DATA_INT64, nccl: ncclUint64);
410
#[cfg(feature = "f16")]
impl CudaDtype for half::f16 {
    /// CUDA C++ spelling of the half-precision type.
    #[inline]
    fn cuda_type_name() -> &'static str {
        "__half"
    }

    #[inline]
    fn cuda_data_type() -> cublas_sys::cudaDataType_t {
        cublas_sys::cudaDataType_t::CUDA_R_16F
    }

    /// Half-precision operands accumulate in the fp32 compute type.
    #[inline]
    fn cublas_compute_type() -> cublas_sys::cublasComputeType_t {
        cublas_sys::cublasComputeType_t::CUBLAS_COMPUTE_32F
    }

    #[cfg(feature = "nccl")]
    #[inline]
    fn nccl_data_type() -> nccl_sys::ncclDataType_t {
        nccl_sys::ncclDataType_t::ncclFloat16
    }

    #[cfg(feature = "cudnn")]
    #[inline]
    fn cudnn_data_type() -> cudnn_sys::cudnnDataType_t {
        cudnn_sys::cudnnDataType_t::CUDNN_DATA_HALF
    }
}
436
#[cfg(feature = "f16")]
impl CudaDtype for half::bf16 {
    /// CUDA C++ spelling of the bfloat16 type.
    #[inline]
    fn cuda_type_name() -> &'static str {
        "__nv_bfloat16"
    }

    #[inline]
    fn cuda_data_type() -> cublas_sys::cudaDataType_t {
        cublas_sys::cudaDataType_t::CUDA_R_16BF
    }

    /// bfloat16 operands accumulate in the fp32 compute type.
    #[inline]
    fn cublas_compute_type() -> cublas_sys::cublasComputeType_t {
        cublas_sys::cublasComputeType_t::CUBLAS_COMPUTE_32F
    }

    #[cfg(feature = "nccl")]
    #[inline]
    fn nccl_data_type() -> nccl_sys::ncclDataType_t {
        nccl_sys::ncclDataType_t::ncclBfloat16
    }

    #[cfg(feature = "cudnn")]
    #[inline]
    fn cudnn_data_type() -> cudnn_sys::cudnnDataType_t {
        cudnn_sys::cudnnDataType_t::CUDNN_DATA_BFLOAT16
    }
}
462
463impl GemmSupported for f32 {}
464impl GemmSupported for f64 {}
465impl GemmSupported for i8 {}
466impl GemmSupported for i32 {}
467#[cfg(feature = "f16")]
468impl GemmSupported for half::f16 {}
469#[cfg(feature = "f16")]
470impl GemmSupported for half::bf16 {}
471
/// cudarc / CUDA-library bindings for the fp8 storage wrappers.
#[cfg(feature = "f8")]
mod fp8_impls {
    use super::*;
    use cudarc::driver::{DeviceRepr, ValidAsZeroBits};

    // SAFETY: `F8E4m3` / `F8E5m2` are `#[repr(transparent)]` over `u8`:
    // one byte, no padding, no pointers, no `Drop`, so their bytes can be
    // copied to the device and passed to kernels by value like a `u8`.
    unsafe impl DeviceRepr for F8E4m3 {}
    unsafe impl DeviceRepr for F8E5m2 {}

    // SAFETY: every `u8` is a valid value of either wrapper, so in
    // particular the all-zeros byte is (it is also their `Default` and
    // what `AccelDtype::zero` returns).
    unsafe impl ValidAsZeroBits for F8E4m3 {}
    unsafe impl ValidAsZeroBits for F8E5m2 {}

    impl CudaDtype for F8E4m3 {
        #[inline]
        fn cuda_data_type() -> cublas_sys::cudaDataType_t {
            cublas_sys::cudaDataType_t::CUDA_R_8F_E4M3
        }
        /// fp8 operands accumulate in the fp32 compute type.
        #[inline]
        fn cublas_compute_type() -> cublas_sys::cublasComputeType_t {
            cublas_sys::cublasComputeType_t::CUBLAS_COMPUTE_32F
        }
        #[inline]
        fn cuda_type_name() -> &'static str {
            "__nv_fp8_e4m3"
        }
        #[cfg(feature = "cudnn")]
        #[inline]
        fn cudnn_data_type() -> cudnn_sys::cudnnDataType_t {
            cudnn_sys::cudnnDataType_t::CUDNN_DATA_FP8_E4M3
        }
        #[cfg(feature = "nccl")]
        #[inline]
        fn nccl_data_type() -> nccl_sys::ncclDataType_t {
            nccl_sys::ncclDataType_t::ncclFloat8e4m3
        }
    }

    impl CudaDtype for F8E5m2 {
        #[inline]
        fn cuda_data_type() -> cublas_sys::cudaDataType_t {
            cublas_sys::cudaDataType_t::CUDA_R_8F_E5M2
        }
        /// fp8 operands accumulate in the fp32 compute type.
        #[inline]
        fn cublas_compute_type() -> cublas_sys::cublasComputeType_t {
            cublas_sys::cublasComputeType_t::CUBLAS_COMPUTE_32F
        }
        #[inline]
        fn cuda_type_name() -> &'static str {
            "__nv_fp8_e5m2"
        }
        #[cfg(feature = "cudnn")]
        #[inline]
        fn cudnn_data_type() -> cudnn_sys::cudnnDataType_t {
            cudnn_sys::cudnnDataType_t::CUDNN_DATA_FP8_E5M2
        }
        #[cfg(feature = "nccl")]
        #[inline]
        fn nccl_data_type() -> nccl_sys::ncclDataType_t {
            nccl_sys::ncclDataType_t::ncclFloat8e5m2
        }
    }

    impl GemmSupported for F8E4m3 {}
    impl GemmSupported for F8E5m2 {}
    impl CudnnSupported for F8E4m3 {}
    impl CudnnSupported for F8E5m2 {}
    impl NcclReduceSupported for F8E4m3 {}
    impl NcclReduceSupported for F8E5m2 {}
}
539
540impl CudaDtype for C32 {
541    #[inline]
542    fn cuda_data_type() -> cublas_sys::cudaDataType_t {
543        cublas_sys::cudaDataType_t::CUDA_C_32F
544    }
545    #[inline]
546    fn cublas_compute_type() -> cublas_sys::cublasComputeType_t {
547        cublas_sys::cublasComputeType_t::CUBLAS_COMPUTE_32F
548    }
549    #[inline]
550    fn cuda_type_name() -> &'static str {
551        "cuComplex"
552    }
553    #[cfg(feature = "cudnn")]
554    #[inline]
555    fn cudnn_data_type() -> cudnn_sys::cudnnDataType_t {
556        // cuDNN has no native complex tensor element. Phase 1.5++ does
557        // not surface a CudnnSupported impl for complex dtypes; calling
558        // this method is a programmer error.
559        panic!("C32 is not a cuDNN tensor element type");
560    }
561    #[cfg(feature = "nccl")]
562    #[inline]
563    fn nccl_data_type() -> nccl_sys::ncclDataType_t {
564        // NCCL has no native complex reduce element. Same gating as
565        // above — Phase 1.5++ does not impl `NcclReduceSupported` for
566        // complex dtypes.
567        panic!("C32 is not an NCCL reduce element type");
568    }
569}
570
571impl CudaDtype for C64 {
572    #[inline]
573    fn cuda_data_type() -> cublas_sys::cudaDataType_t {
574        cublas_sys::cudaDataType_t::CUDA_C_64F
575    }
576    #[inline]
577    fn cublas_compute_type() -> cublas_sys::cublasComputeType_t {
578        cublas_sys::cublasComputeType_t::CUBLAS_COMPUTE_64F
579    }
580    #[inline]
581    fn cuda_type_name() -> &'static str {
582        "cuDoubleComplex"
583    }
584    #[cfg(feature = "cudnn")]
585    #[inline]
586    fn cudnn_data_type() -> cudnn_sys::cudnnDataType_t {
587        panic!("C64 is not a cuDNN tensor element type");
588    }
589    #[cfg(feature = "nccl")]
590    #[inline]
591    fn nccl_data_type() -> nccl_sys::ncclDataType_t {
592        panic!("C64 is not an NCCL reduce element type");
593    }
594}
595
596impl CudnnSupported for f32 {}
597impl CudnnSupported for f64 {}
598impl CudnnSupported for i8 {}
599impl CudnnSupported for u8 {}
600impl CudnnSupported for i32 {}
601impl CudnnSupported for i64 {}
602#[cfg(feature = "f16")]
603impl CudnnSupported for half::f16 {}
604#[cfg(feature = "f16")]
605impl CudnnSupported for half::bf16 {}
606
607impl FftSupported for f32 {}
608impl FftSupported for f64 {}
609impl FftSupported for C32 {}
610impl FftSupported for C64 {}
611#[cfg(feature = "f16")]
612impl FftSupported for half::f16 {}
613
614impl RngFloatSupported for f32 {}
615impl RngFloatSupported for f64 {}
616
617impl RngIntSupported for u32 {}
618impl RngIntSupported for u64 {}
619
620impl AxpyDotNrm2Supported for f32 {}
621impl AxpyDotNrm2Supported for f64 {}
622#[cfg(feature = "f16")]
623impl AxpyDotNrm2Supported for half::f16 {}
624#[cfg(feature = "f16")]
625impl AxpyDotNrm2Supported for half::bf16 {}
626
627impl GemvSupported for f32 {}
628impl GemvSupported for f64 {}
629
630impl GerSupported for f32 {}
631impl GerSupported for f64 {}
632
633impl GeamSupported for f32 {}
634impl GeamSupported for f64 {}
635
636impl SyrkSupported for f32 {}
637impl SyrkSupported for f64 {}
638
639impl TrsmSupported for f32 {}
640impl TrsmSupported for f64 {}
641
642impl NcclReduceSupported for f32 {}
643impl NcclReduceSupported for f64 {}
644impl NcclReduceSupported for i8 {}
645impl NcclReduceSupported for u8 {}
646impl NcclReduceSupported for i32 {}
647impl NcclReduceSupported for u32 {}
648impl NcclReduceSupported for i64 {}
649impl NcclReduceSupported for u64 {}
650#[cfg(feature = "f16")]
651impl NcclReduceSupported for half::f16 {}
652#[cfg(feature = "f16")]
653impl NcclReduceSupported for half::bf16 {}
654
655impl SolverSupported for f32 {}
656impl SolverSupported for f64 {}
657
658impl SparseSupported for f32 {}
659impl SparseSupported for f64 {}
660#[cfg(feature = "f16")]
661impl SparseSupported for half::f16 {}
662#[cfg(feature = "f16")]
663impl SparseSupported for half::bf16 {}
664
665impl TensorSupported for f32 {}
666impl TensorSupported for f64 {}
667#[cfg(feature = "f16")]
668impl TensorSupported for half::f16 {}
669#[cfg(feature = "f16")]
670impl TensorSupported for half::bf16 {}
671
#[cfg(test)]
mod tests {
    use super::*;
    use atomr_accel::DType;

    #[test]
    fn cuda_data_type_round_trip() {
        assert_eq!(<f32 as AccelDtype>::KIND, DType::F32);
        assert_eq!(
            <f32 as CudaDtype>::cuda_data_type(),
            cublas_sys::cudaDataType_t::CUDA_R_32F
        );
        assert_eq!(
            <f64 as CudaDtype>::cuda_data_type(),
            cublas_sys::cudaDataType_t::CUDA_R_64F
        );
        assert_eq!(<f32 as CudaDtype>::cuda_type_name(), "float");
        assert_eq!(<f64 as CudaDtype>::cuda_type_name(), "double");
    }

    #[test]
    fn integer_compute_types() {
        assert_eq!(
            <i32 as CudaDtype>::cublas_compute_type(),
            cublas_sys::cublasComputeType_t::CUBLAS_COMPUTE_32I
        );
    }

    #[cfg(feature = "f16")]
    #[test]
    fn f16_mappings() {
        assert_eq!(
            <half::f16 as CudaDtype>::cuda_data_type(),
            cublas_sys::cudaDataType_t::CUDA_R_16F
        );
        assert_eq!(
            <half::bf16 as CudaDtype>::cuda_data_type(),
            cublas_sys::cudaDataType_t::CUDA_R_16BF
        );
        assert_eq!(<half::f16 as CudaDtype>::cuda_type_name(), "__half");
        assert_eq!(<half::bf16 as CudaDtype>::cuda_type_name(), "__nv_bfloat16");
    }

    fn _assert_capability_bounds<G, C, F, N>()
    where
        G: GemmSupported,
        C: CudnnSupported,
        F: FftSupported,
        N: NcclReduceSupported,
    {
    }

    #[test]
    fn capability_compile_time_check() {
        _assert_capability_bounds::<f32, f32, f32, f32>();
        _assert_capability_bounds::<f64, f64, f64, f64>();
    }

    #[test]
    fn complex_dtype_size_and_layout() {
        assert_eq!(<C32 as AccelDtype>::SIZE, 8);
        assert_eq!(<C64 as AccelDtype>::SIZE, 16);
        assert_eq!(<C32 as AccelDtype>::KIND, DType::F32);
        assert_eq!(<C64 as AccelDtype>::KIND, DType::F64);
        assert_eq!(<C32 as AccelDtype>::NAME, "complex64");
        assert_eq!(<C64 as AccelDtype>::NAME, "complex128");

        // Layout matches `[T; 2]` exactly (transparent).
        assert_eq!(std::mem::size_of::<C32>(), 8);
        assert_eq!(std::mem::size_of::<C64>(), 16);
        assert_eq!(std::mem::align_of::<C32>(), std::mem::align_of::<f32>());
        assert_eq!(std::mem::align_of::<C64>(), std::mem::align_of::<f64>());
    }

    #[test]
    fn complex_cuda_data_type_mapping() {
        assert_eq!(
            <C32 as CudaDtype>::cuda_data_type(),
            cublas_sys::cudaDataType_t::CUDA_C_32F
        );
        assert_eq!(
            <C64 as CudaDtype>::cuda_data_type(),
            cublas_sys::cudaDataType_t::CUDA_C_64F
        );
        assert_eq!(<C32 as CudaDtype>::cuda_type_name(), "cuComplex");
        assert_eq!(<C64 as CudaDtype>::cuda_type_name(), "cuDoubleComplex");
        assert_eq!(
            <C32 as CudaDtype>::cublas_compute_type(),
            cublas_sys::cublasComputeType_t::CUBLAS_COMPUTE_32F
        );
        assert_eq!(
            <C64 as CudaDtype>::cublas_compute_type(),
            cublas_sys::cublasComputeType_t::CUBLAS_COMPUTE_64F
        );
    }

    #[test]
    fn complex_fft_supported_compile_time_check() {
        fn _check<T: FftSupported>() {}
        _check::<C32>();
        _check::<C64>();
    }

    #[test]
    fn complex_zero_one_nan_identities() {
        assert_eq!(<C32 as AccelDtype>::zero(), C32([0.0, 0.0]));
        assert_eq!(<C32 as AccelDtype>::one(), C32([1.0, 0.0]));
        assert!(<C32 as AccelDtype>::nan()
            .map(|n| n.0[0].is_nan() && n.0[1].is_nan())
            .unwrap_or(false));
        assert_eq!(<C64 as AccelDtype>::zero(), C64([0.0, 0.0]));
        assert_eq!(<C64 as AccelDtype>::one(), C64([1.0, 0.0]));
        // C64's NaN was previously untested while C32's was.
        assert!(<C64 as AccelDtype>::nan()
            .map(|n| n.0[0].is_nan() && n.0[1].is_nan())
            .unwrap_or(false));
    }

    #[test]
    fn complex_component_accessors() {
        let c = C32::new(3.0, -4.0);
        assert_eq!((c.re(), c.im()), (3.0, -4.0));
        assert_eq!(c, C32([3.0, -4.0]));
        let d = C64::new(-1.25, 0.5);
        assert_eq!((d.re(), d.im()), (-1.25, 0.5));
        assert_eq!(d, C64([-1.25, 0.5]));
    }

    // The identity bit patterns in the `AccelDtype` impls must decode to
    // the values they claim, using the backend-agnostic converters.
    #[cfg(feature = "f8")]
    #[test]
    fn fp8_identity_bit_patterns_decode() {
        assert_eq!(<F8E4m3 as AccelDtype>::zero().to_f32(), 0.0);
        assert_eq!(<F8E4m3 as AccelDtype>::one().to_f32(), 1.0);
        assert!(<F8E4m3 as AccelDtype>::nan()
            .map(|n| n.to_f32().is_nan())
            .unwrap_or(false));
        assert_eq!(F8E4m3::from_f32(1.0), <F8E4m3 as AccelDtype>::one());

        assert_eq!(<F8E5m2 as AccelDtype>::zero().to_f32(), 0.0);
        assert_eq!(<F8E5m2 as AccelDtype>::one().to_f32(), 1.0);
        assert!(<F8E5m2 as AccelDtype>::nan()
            .map(|n| n.to_f32().is_nan())
            .unwrap_or(false));
        assert_eq!(F8E5m2::from_f32(1.0), <F8E5m2 as AccelDtype>::one());
    }

    #[cfg(feature = "cufft")]
    #[test]
    fn complex_round_trips_cufft_sys() {
        use cudarc::cufft::sys as s;
        let f = s::float2 { x: 1.5, y: -2.5 };
        let c: C32 = f.into();
        assert_eq!(c, C32([1.5, -2.5]));
        let f2: s::float2 = c.into();
        assert_eq!(f2.x, 1.5);
        assert_eq!(f2.y, -2.5);

        let d = s::double2 { x: 7.0, y: 8.0 };
        let c64: C64 = d.into();
        assert_eq!(c64, C64([7.0, 8.0]));
        let d2: s::double2 = c64.into();
        assert_eq!(d2.x, 7.0);
        assert_eq!(d2.y, 8.0);
    }
}