1use cudarc::cublas::sys as cublas_sys;
25#[cfg(feature = "cudnn")]
26use cudarc::cudnn::sys as cudnn_sys;
27use cudarc::driver::{DeviceRepr, ValidAsZeroBits};
28#[cfg(feature = "nccl")]
29use cudarc::nccl::sys as nccl_sys;
30
31pub use atomr_accel::DType;
35
36pub use atomr_accel::AccelDtype;
39
40pub use atomr_accel::DType as DTypeKind;
42
/// Single-precision complex number stored as `[re, im]`.
///
/// `#[repr(transparent)]` gives the wrapper exactly the size and alignment of
/// `[f32; 2]`: two consecutive `f32`s with the real part first.
#[repr(transparent)]
#[derive(Copy, Clone, Debug, Default, PartialEq)]
pub struct C32(pub [f32; 2]);

impl C32 {
    /// Builds a complex value from its real and imaginary parts.
    #[inline]
    pub const fn new(re: f32, im: f32) -> Self {
        Self([re, im])
    }

    /// Returns the real part (element 0 of the inner array).
    ///
    /// `const` like [`C32::new`], so components can be read in constant contexts.
    #[inline]
    pub const fn re(self) -> f32 {
        self.0[0]
    }

    /// Returns the imaginary part (element 1 of the inner array).
    #[inline]
    pub const fn im(self) -> f32 {
        self.0[1]
    }
}
76
/// Double-precision complex number stored as `[re, im]`.
///
/// `#[repr(transparent)]` gives the wrapper exactly the size and alignment of
/// `[f64; 2]`: two consecutive `f64`s with the real part first.
#[repr(transparent)]
#[derive(Copy, Clone, Debug, Default, PartialEq)]
pub struct C64(pub [f64; 2]);

impl C64 {
    /// Builds a complex value from its real and imaginary parts.
    #[inline]
    pub const fn new(re: f64, im: f64) -> Self {
        Self([re, im])
    }

    /// Returns the real part (element 0 of the inner array).
    ///
    /// `const` like [`C64::new`], so components can be read in constant contexts.
    #[inline]
    pub const fn re(self) -> f64 {
        self.0[0]
    }

    /// Returns the imaginary part (element 1 of the inner array).
    #[inline]
    pub const fn im(self) -> f64 {
        self.0[1]
    }
}
100
/// cuFFT `float2` -> [`C32`]: `x` becomes the real slot, `y` the imaginary slot.
#[cfg(feature = "cufft")]
impl From<cudarc::cufft::sys::float2> for C32 {
    #[inline]
    fn from(raw: cudarc::cufft::sys::float2) -> Self {
        Self([raw.x, raw.y])
    }
}
118
/// [`C32`] -> cuFFT `float2`: element 0 goes to `x`, element 1 to `y`.
#[cfg(feature = "cufft")]
impl From<C32> for cudarc::cufft::sys::float2 {
    #[inline]
    fn from(value: C32) -> Self {
        let C32([x, y]) = value;
        Self { x, y }
    }
}
129
/// cuFFT `double2` -> [`C64`]: `x` becomes the real slot, `y` the imaginary slot.
#[cfg(feature = "cufft")]
impl From<cudarc::cufft::sys::double2> for C64 {
    #[inline]
    fn from(raw: cudarc::cufft::sys::double2) -> Self {
        Self([raw.x, raw.y])
    }
}
137
/// [`C64`] -> cuFFT `double2`: element 0 goes to `x`, element 1 to `y`.
#[cfg(feature = "cufft")]
impl From<C64> for cudarc::cufft::sys::double2 {
    #[inline]
    fn from(value: C64) -> Self {
        let C64([x, y]) = value;
        Self { x, y }
    }
}
148
// SAFETY: `C32` is `#[repr(transparent)]` over `[f32; 2]` and `C64` over
// `[f64; 2]`, so each is plain `Copy` data with no padding, pointers, or
// destructor, and the all-zero bit pattern decodes to `0.0 + 0.0i`.
// NOTE(review): the exact obligations of cudarc's `DeviceRepr` and
// `ValidAsZeroBits` are not visible in this file — confirm against cudarc's docs.
unsafe impl DeviceRepr for C32 {}
unsafe impl ValidAsZeroBits for C32 {}
unsafe impl DeviceRepr for C64 {}
unsafe impl ValidAsZeroBits for C64 {}
159
160impl atomr_accel::AccelDtype for C32 {
167 type Scalar = f32;
168 const KIND: DType = DType::F32;
169 const SIZE: usize = 8;
170 const NAME: &'static str = "complex64";
171 #[inline]
172 fn zero() -> Self {
173 C32([0.0, 0.0])
174 }
175 #[inline]
176 fn one() -> Self {
177 C32([1.0, 0.0])
178 }
179 #[inline]
180 fn nan() -> Option<Self> {
181 Some(C32([f32::NAN, f32::NAN]))
182 }
183}
184
185impl atomr_accel::AccelDtype for C64 {
186 type Scalar = f64;
187 const KIND: DType = DType::F64;
188 const SIZE: usize = 16;
189 const NAME: &'static str = "complex128";
190 #[inline]
191 fn zero() -> Self {
192 C64([0.0, 0.0])
193 }
194 #[inline]
195 fn one() -> Self {
196 C64([1.0, 0.0])
197 }
198 #[inline]
199 fn nan() -> Option<Self> {
200 Some(C64([f64::NAN, f64::NAN]))
201 }
202}
203
/// One-byte wrapper around the raw bits of an `F8E4m3` value.
///
/// `#[repr(transparent)]` over `u8`; the numeric encoding itself is delegated
/// to `atomr_accel::dtype::F8E4m3`.
#[cfg(feature = "f8")]
#[repr(transparent)]
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Hash)]
pub struct F8E4m3(pub u8);

#[cfg(feature = "f8")]
impl F8E4m3 {
    /// Encodes `x` with `atomr_accel::dtype::F8E4m3::from_f32` and keeps the
    /// resulting raw byte.
    #[inline]
    pub fn from_f32(x: f32) -> Self {
        let encoded = atomr_accel::dtype::F8E4m3::from_f32(x);
        Self(encoded.0)
    }

    /// Decodes the stored byte with `atomr_accel::dtype::F8E4m3::to_f32`.
    #[inline]
    pub fn to_f32(self) -> f32 {
        let upstream = atomr_accel::dtype::F8E4m3(self.0);
        upstream.to_f32()
    }
}
220
/// One-byte wrapper around the raw bits of an `F8E5m2` value.
///
/// `#[repr(transparent)]` over `u8`; the numeric encoding itself is delegated
/// to `atomr_accel::dtype::F8E5m2`.
#[cfg(feature = "f8")]
#[repr(transparent)]
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Hash)]
pub struct F8E5m2(pub u8);

#[cfg(feature = "f8")]
impl F8E5m2 {
    /// Encodes `x` with `atomr_accel::dtype::F8E5m2::from_f32` and keeps the
    /// resulting raw byte.
    #[inline]
    pub fn from_f32(x: f32) -> Self {
        let encoded = atomr_accel::dtype::F8E5m2::from_f32(x);
        Self(encoded.0)
    }

    /// Decodes the stored byte with `atomr_accel::dtype::F8E5m2::to_f32`.
    #[inline]
    pub fn to_f32(self) -> f32 {
        let upstream = atomr_accel::dtype::F8E5m2(self.0);
        upstream.to_f32()
    }
}
237
/// `AccelDtype` metadata for [`F8E4m3`]: one byte per element, `f32` as the
/// host-side scalar.
///
/// The bit patterns below are written as `sign_exponent_mantissa`, assuming the
/// usual 1/4/3 split of the E4M3 format.
#[cfg(feature = "f8")]
impl atomr_accel::AccelDtype for F8E4m3 {
    type Scalar = f32;

    const KIND: DType = DType::F8E4m3;
    const SIZE: usize = 1;
    const NAME: &'static str = "f8_e4m3";

    /// All bits clear (`0x00`).
    #[inline]
    fn zero() -> Self {
        Self(0b0_0000_000)
    }

    /// `0x38`.
    #[inline]
    fn one() -> Self {
        Self(0b0_0111_000)
    }

    /// `0x7f`: every exponent and mantissa bit set.
    #[inline]
    fn nan() -> Option<Self> {
        Some(Self(0b0_1111_111))
    }
}
257
/// `AccelDtype` metadata for [`F8E5m2`]: one byte per element, `f32` as the
/// host-side scalar.
///
/// The bit patterns below are written as `sign_exponent_mantissa`, assuming the
/// usual 1/5/2 split of the E5M2 format.
#[cfg(feature = "f8")]
impl atomr_accel::AccelDtype for F8E5m2 {
    type Scalar = f32;

    const KIND: DType = DType::F8E5m2;
    const SIZE: usize = 1;
    const NAME: &'static str = "f8_e5m2";

    /// All bits clear (`0x00`).
    #[inline]
    fn zero() -> Self {
        Self(0b0_00000_00)
    }

    /// `0x3c`.
    #[inline]
    fn one() -> Self {
        Self(0b0_01111_00)
    }

    /// `0x7e`: every exponent bit set with a non-zero mantissa.
    #[inline]
    fn nan() -> Option<Self> {
        Some(Self(0b0_11111_10))
    }
}
277
/// Element types that can be described to the CUDA libraries used here.
///
/// Extends [`AccelDtype`] with cudarc's `DeviceRepr` and `ValidAsZeroBits`
/// and adds one type tag per library. Every method is an associated function
/// (no `self`), so the mapping is a property of the type, not of a value.
pub trait CudaDtype: AccelDtype + DeviceRepr + ValidAsZeroBits {
    /// The `cudaDataType_t` tag for this element type.
    fn cuda_data_type() -> cublas_sys::cudaDataType_t;
    /// The `cublasComputeType_t` paired with this element type.
    fn cublas_compute_type() -> cublas_sys::cublasComputeType_t;
    /// The type's name as a string, e.g. `"float"` for `f32`.
    ///
    /// NOTE(review): presumably the CUDA C/C++ spelling — confirm with callers.
    fn cuda_type_name() -> &'static str;
    /// The `cudnnDataType_t` tag; only exists with the `cudnn` feature.
    ///
    /// # Panics
    ///
    /// The `C32` and `C64` implementations in this file panic.
    #[cfg(feature = "cudnn")]
    fn cudnn_data_type() -> cudnn_sys::cudnnDataType_t;
    /// The `ncclDataType_t` tag; only exists with the `nccl` feature.
    ///
    /// # Panics
    ///
    /// The `C32` and `C64` implementations in this file panic.
    #[cfg(feature = "nccl")]
    fn nccl_data_type() -> nccl_sys::ncclDataType_t;
}
293
// Capability markers. Each trait is empty and only refines `CudaDtype`, so it
// can serve as a compile-time bound on generic code.
// NOTE(review): which library each marker gates is inferred from its name;
// the call sites that use these bounds are outside this file.

/// Marker for element types accepted where a GEMM-capable dtype is required.
pub trait GemmSupported: CudaDtype {}

/// Marker for element types accepted where a cuDNN-capable dtype is required.
pub trait CudnnSupported: CudaDtype {}

/// Marker for element types accepted where an FFT-capable dtype is required.
pub trait FftSupported: CudaDtype {}

/// Marker for floating-point element types usable with random generation.
pub trait RngFloatSupported: CudaDtype {}

/// Marker for element types usable in NCCL reductions.
pub trait NcclReduceSupported: CudaDtype {}

/// Marker for element types accepted where a solver-capable dtype is required.
pub trait SolverSupported: CudaDtype {}

/// Marker for element types usable as sparse-matrix values.
pub trait SparseSupported: CudaDtype {}

/// Marker for element types accepted where a tensor-capable dtype is required.
pub trait TensorSupported: CudaDtype {}

/// Marker for integer element types usable with random generation.
pub trait RngIntSupported: CudaDtype {}
324
/// Integer types that can act as sparse-matrix indices.
///
/// Only requires [`AccelDtype`] (not `CudaDtype`). Without the `cusparse`
/// feature the trait has no methods and works purely as a marker.
pub trait SparseIndex: AccelDtype {
    /// The `cusparseIndexType_t` tag for this index type.
    #[cfg(feature = "cusparse")]
    fn cusparse_index_type() -> cudarc::cusparse::sys::cusparseIndexType_t;
}
331
332#[cfg(feature = "cusparse")]
333impl SparseIndex for i32 {
334 fn cusparse_index_type() -> cudarc::cusparse::sys::cusparseIndexType_t {
335 cudarc::cusparse::sys::cusparseIndexType_t::CUSPARSE_INDEX_32I
336 }
337}
338#[cfg(feature = "cusparse")]
339impl SparseIndex for i64 {
340 fn cusparse_index_type() -> cudarc::cusparse::sys::cusparseIndexType_t {
341 cudarc::cusparse::sys::cusparseIndexType_t::CUSPARSE_INDEX_64I
342 }
343}
344#[cfg(not(feature = "cusparse"))]
345impl SparseIndex for i32 {}
346#[cfg(not(feature = "cusparse"))]
347impl SparseIndex for i64 {}
348
// Per-routine BLAS capability markers; all are empty refinements of `CudaDtype`.
// NOTE(review): the routine each marker gates is inferred from its name.

/// Marker for element types usable with the axpy / dot / nrm2 routines.
pub trait AxpyDotNrm2Supported: CudaDtype {}
/// Marker for element types usable with gemv.
pub trait GemvSupported: CudaDtype {}
/// Marker for element types usable with ger.
pub trait GerSupported: CudaDtype {}
/// Marker for element types usable with geam.
pub trait GeamSupported: CudaDtype {}
/// Marker for element types usable with syrk.
pub trait SyrkSupported: CudaDtype {}
/// Marker for element types usable with trsm.
pub trait TrsmSupported: CudaDtype {}
357
// Generates `impl CudaDtype for <type>` from a row of enum variant names.
//
// Invocation shape:
//   impl_cuda_dtype!(RustType, CUDA_DATA_VARIANT, CUBLAS_COMPUTE_VARIANT, "c name",
//       cudnn: CUDNN_VARIANT, nccl: NCCL_VARIANT);
// The NCCL variant may be left out (`nccl:` followed by nothing); the generated
// `nccl_data_type` then panics when called.
macro_rules! impl_cuda_dtype {
    // Internal helper: an NCCL variant was supplied, so name it.
    (@nccl_arm $variant:ident) => {
        nccl_sys::ncclDataType_t::$variant
    };
    // Internal helper: no NCCL variant was supplied.
    (@nccl_arm) => {
        panic!("dtype not supported by NCCL")
    };
    // Public entry point.
    (
        $scalar:ty,
        $data_variant:ident,
        $compute_variant:ident,
        $c_name:literal,
        cudnn: $cudnn_variant:ident,
        nccl: $($nccl_variant:ident)?
    ) => {
        impl CudaDtype for $scalar {
            #[inline]
            fn cuda_data_type() -> cublas_sys::cudaDataType_t {
                cublas_sys::cudaDataType_t::$data_variant
            }

            #[inline]
            fn cublas_compute_type() -> cublas_sys::cublasComputeType_t {
                cublas_sys::cublasComputeType_t::$compute_variant
            }

            #[inline]
            fn cuda_type_name() -> &'static str {
                $c_name
            }

            #[cfg(feature = "cudnn")]
            #[inline]
            fn cudnn_data_type() -> cudnn_sys::cudnnDataType_t {
                cudnn_sys::cudnnDataType_t::$cudnn_variant
            }

            #[cfg(feature = "nccl")]
            #[inline]
            fn nccl_data_type() -> nccl_sys::ncclDataType_t {
                impl_cuda_dtype!(@nccl_arm $($nccl_variant)?)
            }
        }
    };
}
393
// Built-in scalar table. Columns: Rust type, `cudaDataType_t` variant,
// `cublasComputeType_t` variant, type-name string, cuDNN variant, NCCL variant.
impl_cuda_dtype!(f32, CUDA_R_32F, CUBLAS_COMPUTE_32F, "float",
    cudnn: CUDNN_DATA_FLOAT, nccl: ncclFloat32);
impl_cuda_dtype!(f64, CUDA_R_64F, CUBLAS_COMPUTE_64F, "double",
    cudnn: CUDNN_DATA_DOUBLE, nccl: ncclFloat64);
// Every integer row uses `CUBLAS_COMPUTE_32I`, including the 64-bit ones.
impl_cuda_dtype!(i8, CUDA_R_8I, CUBLAS_COMPUTE_32I, "char",
    cudnn: CUDNN_DATA_INT8, nccl: ncclInt8);
impl_cuda_dtype!(u8, CUDA_R_8U, CUBLAS_COMPUTE_32I, "unsigned char",
    cudnn: CUDNN_DATA_UINT8, nccl: ncclUint8);
impl_cuda_dtype!(i32, CUDA_R_32I, CUBLAS_COMPUTE_32I, "int",
    cudnn: CUDNN_DATA_INT32, nccl: ncclInt32);
// NOTE(review): `u32` and `u64` reuse the signed cuDNN variants. Neither type
// implements `CudnnSupported` in this file, so these look like placeholders
// needed only because the macro requires a `cudnn:` argument — confirm.
impl_cuda_dtype!(u32, CUDA_R_32U, CUBLAS_COMPUTE_32I, "unsigned int",
    cudnn: CUDNN_DATA_INT32, nccl: ncclUint32);
impl_cuda_dtype!(i64, CUDA_R_64I, CUBLAS_COMPUTE_32I, "long long",
    cudnn: CUDNN_DATA_INT64, nccl: ncclInt64);
impl_cuda_dtype!(u64, CUDA_R_64U, CUBLAS_COMPUTE_32I, "unsigned long long",
    cudnn: CUDNN_DATA_INT64, nccl: ncclUint64);
410
/// `CudaDtype` mapping for `half::f16` (requires the `f16` feature).
///
/// The data type is `CUDA_R_16F`, while the cuBLAS compute type is the 32-bit
/// `CUBLAS_COMPUTE_32F`.
#[cfg(feature = "f16")]
impl CudaDtype for half::f16 {
    /// Returns `CUDA_R_16F`.
    #[inline]
    fn cuda_data_type() -> cublas_sys::cudaDataType_t {
        cublas_sys::cudaDataType_t::CUDA_R_16F
    }
    /// Returns `CUBLAS_COMPUTE_32F`.
    #[inline]
    fn cublas_compute_type() -> cublas_sys::cublasComputeType_t {
        cublas_sys::cublasComputeType_t::CUBLAS_COMPUTE_32F
    }
    /// Returns `"__half"`.
    #[inline]
    fn cuda_type_name() -> &'static str {
        "__half"
    }
    /// Returns `CUDNN_DATA_HALF`.
    #[cfg(feature = "cudnn")]
    #[inline]
    fn cudnn_data_type() -> cudnn_sys::cudnnDataType_t {
        cudnn_sys::cudnnDataType_t::CUDNN_DATA_HALF
    }
    /// Returns `ncclFloat16`.
    #[cfg(feature = "nccl")]
    #[inline]
    fn nccl_data_type() -> nccl_sys::ncclDataType_t {
        nccl_sys::ncclDataType_t::ncclFloat16
    }
}
436
/// `CudaDtype` mapping for `half::bf16` (requires the `f16` feature).
///
/// The data type is `CUDA_R_16BF`, while the cuBLAS compute type is the 32-bit
/// `CUBLAS_COMPUTE_32F`.
#[cfg(feature = "f16")]
impl CudaDtype for half::bf16 {
    /// Returns `CUDA_R_16BF`.
    #[inline]
    fn cuda_data_type() -> cublas_sys::cudaDataType_t {
        cublas_sys::cudaDataType_t::CUDA_R_16BF
    }
    /// Returns `CUBLAS_COMPUTE_32F`.
    #[inline]
    fn cublas_compute_type() -> cublas_sys::cublasComputeType_t {
        cublas_sys::cublasComputeType_t::CUBLAS_COMPUTE_32F
    }
    /// Returns `"__nv_bfloat16"`.
    #[inline]
    fn cuda_type_name() -> &'static str {
        "__nv_bfloat16"
    }
    /// Returns `CUDNN_DATA_BFLOAT16`.
    #[cfg(feature = "cudnn")]
    #[inline]
    fn cudnn_data_type() -> cudnn_sys::cudnnDataType_t {
        cudnn_sys::cudnnDataType_t::CUDNN_DATA_BFLOAT16
    }
    /// Returns `ncclBfloat16`.
    #[cfg(feature = "nccl")]
    #[inline]
    fn nccl_data_type() -> nccl_sys::ncclDataType_t {
        nccl_sys::ncclDataType_t::ncclBfloat16
    }
}
462
// Types that satisfy a `GemmSupported` bound. Of the integer types only `i8`
// and `i32` are listed; the FP8 types are added inside `fp8_impls`.
impl GemmSupported for f32 {}
impl GemmSupported for f64 {}
impl GemmSupported for i8 {}
impl GemmSupported for i32 {}
#[cfg(feature = "f16")]
impl GemmSupported for half::f16 {}
#[cfg(feature = "f16")]
impl GemmSupported for half::bf16 {}
471
/// CUDA trait wiring for the FP8 wrappers; compiled only with the `f8` feature.
#[cfg(feature = "f8")]
mod fp8_impls {
    use super::*;
    use cudarc::driver::{DeviceRepr, ValidAsZeroBits};

    // SAFETY: both wrappers are `#[repr(transparent)]` over a single `u8`, so
    // they are one byte of plain `Copy` data and the zero byte is a valid value
    // (it is what their `AccelDtype::zero()` returns).
    // NOTE(review): the exact obligations of cudarc's `DeviceRepr` and
    // `ValidAsZeroBits` are not visible in this file — confirm against cudarc's docs.
    unsafe impl DeviceRepr for F8E4m3 {}
    unsafe impl ValidAsZeroBits for F8E4m3 {}
    unsafe impl DeviceRepr for F8E5m2 {}
    unsafe impl ValidAsZeroBits for F8E5m2 {}

    /// `CudaDtype` mapping for [`F8E4m3`]: `CUDA_R_8F_E4M3` data with the
    /// `CUBLAS_COMPUTE_32F` compute type.
    impl CudaDtype for F8E4m3 {
        /// Returns `CUDA_R_8F_E4M3`.
        #[inline]
        fn cuda_data_type() -> cublas_sys::cudaDataType_t {
            cublas_sys::cudaDataType_t::CUDA_R_8F_E4M3
        }
        /// Returns `CUBLAS_COMPUTE_32F`.
        #[inline]
        fn cublas_compute_type() -> cublas_sys::cublasComputeType_t {
            cublas_sys::cublasComputeType_t::CUBLAS_COMPUTE_32F
        }
        /// Returns `"__nv_fp8_e4m3"`.
        #[inline]
        fn cuda_type_name() -> &'static str {
            "__nv_fp8_e4m3"
        }
        /// Returns `CUDNN_DATA_FP8_E4M3`.
        #[cfg(feature = "cudnn")]
        #[inline]
        fn cudnn_data_type() -> cudnn_sys::cudnnDataType_t {
            cudnn_sys::cudnnDataType_t::CUDNN_DATA_FP8_E4M3
        }
        /// Returns `ncclFloat8e4m3`.
        #[cfg(feature = "nccl")]
        #[inline]
        fn nccl_data_type() -> nccl_sys::ncclDataType_t {
            nccl_sys::ncclDataType_t::ncclFloat8e4m3
        }
    }

    /// `CudaDtype` mapping for [`F8E5m2`]: `CUDA_R_8F_E5M2` data with the
    /// `CUBLAS_COMPUTE_32F` compute type.
    impl CudaDtype for F8E5m2 {
        /// Returns `CUDA_R_8F_E5M2`.
        #[inline]
        fn cuda_data_type() -> cublas_sys::cudaDataType_t {
            cublas_sys::cudaDataType_t::CUDA_R_8F_E5M2
        }
        /// Returns `CUBLAS_COMPUTE_32F`.
        #[inline]
        fn cublas_compute_type() -> cublas_sys::cublasComputeType_t {
            cublas_sys::cublasComputeType_t::CUBLAS_COMPUTE_32F
        }
        /// Returns `"__nv_fp8_e5m2"`.
        #[inline]
        fn cuda_type_name() -> &'static str {
            "__nv_fp8_e5m2"
        }
        /// Returns `CUDNN_DATA_FP8_E5M2`.
        #[cfg(feature = "cudnn")]
        #[inline]
        fn cudnn_data_type() -> cudnn_sys::cudnnDataType_t {
            cudnn_sys::cudnnDataType_t::CUDNN_DATA_FP8_E5M2
        }
        /// Returns `ncclFloat8e5m2`.
        #[cfg(feature = "nccl")]
        #[inline]
        fn nccl_data_type() -> nccl_sys::ncclDataType_t {
            nccl_sys::ncclDataType_t::ncclFloat8e5m2
        }
    }

    // Capability markers granted to both FP8 types: GEMM, cuDNN, NCCL reduce.
    impl GemmSupported for F8E4m3 {}
    impl GemmSupported for F8E5m2 {}
    impl CudnnSupported for F8E4m3 {}
    impl CudnnSupported for F8E5m2 {}
    impl NcclReduceSupported for F8E4m3 {}
    impl NcclReduceSupported for F8E5m2 {}
}
539
/// `CudaDtype` mapping for [`C32`]: `CUDA_C_32F` data with the
/// `CUBLAS_COMPUTE_32F` compute type.
///
/// The cuDNN and NCCL tags have no value for this type, so those two methods
/// panic. `C32` implements neither `CudnnSupported` nor `NcclReduceSupported`
/// in this file, so code bounded on those markers cannot reach them.
impl CudaDtype for C32 {
    /// Returns `CUDA_C_32F`.
    #[inline]
    fn cuda_data_type() -> cublas_sys::cudaDataType_t {
        cublas_sys::cudaDataType_t::CUDA_C_32F
    }
    /// Returns `CUBLAS_COMPUTE_32F`.
    #[inline]
    fn cublas_compute_type() -> cublas_sys::cublasComputeType_t {
        cublas_sys::cublasComputeType_t::CUBLAS_COMPUTE_32F
    }
    /// Returns `"cuComplex"`.
    #[inline]
    fn cuda_type_name() -> &'static str {
        "cuComplex"
    }
    /// # Panics
    ///
    /// Always panics.
    #[cfg(feature = "cudnn")]
    #[inline]
    fn cudnn_data_type() -> cudnn_sys::cudnnDataType_t {
        panic!("C32 is not a cuDNN tensor element type");
    }
    /// # Panics
    ///
    /// Always panics.
    #[cfg(feature = "nccl")]
    #[inline]
    fn nccl_data_type() -> nccl_sys::ncclDataType_t {
        panic!("C32 is not an NCCL reduce element type");
    }
}
570
/// `CudaDtype` mapping for [`C64`]: `CUDA_C_64F` data with the
/// `CUBLAS_COMPUTE_64F` compute type.
///
/// The cuDNN and NCCL tags have no value for this type, so those two methods
/// panic. `C64` implements neither `CudnnSupported` nor `NcclReduceSupported`
/// in this file, so code bounded on those markers cannot reach them.
impl CudaDtype for C64 {
    /// Returns `CUDA_C_64F`.
    #[inline]
    fn cuda_data_type() -> cublas_sys::cudaDataType_t {
        cublas_sys::cudaDataType_t::CUDA_C_64F
    }
    /// Returns `CUBLAS_COMPUTE_64F`.
    #[inline]
    fn cublas_compute_type() -> cublas_sys::cublasComputeType_t {
        cublas_sys::cublasComputeType_t::CUBLAS_COMPUTE_64F
    }
    /// Returns `"cuDoubleComplex"`.
    #[inline]
    fn cuda_type_name() -> &'static str {
        "cuDoubleComplex"
    }
    /// # Panics
    ///
    /// Always panics.
    #[cfg(feature = "cudnn")]
    #[inline]
    fn cudnn_data_type() -> cudnn_sys::cudnnDataType_t {
        panic!("C64 is not a cuDNN tensor element type");
    }
    /// # Panics
    ///
    /// Always panics.
    #[cfg(feature = "nccl")]
    #[inline]
    fn nccl_data_type() -> nccl_sys::ncclDataType_t {
        panic!("C64 is not an NCCL reduce element type");
    }
}
595
// cuDNN marker: floats plus `i8`, `u8`, `i32`, `i64`. `u32` and `u64` are
// deliberately absent from this list.
impl CudnnSupported for f32 {}
impl CudnnSupported for f64 {}
impl CudnnSupported for i8 {}
impl CudnnSupported for u8 {}
impl CudnnSupported for i32 {}
impl CudnnSupported for i64 {}
#[cfg(feature = "f16")]
impl CudnnSupported for half::f16 {}
#[cfg(feature = "f16")]
impl CudnnSupported for half::bf16 {}

// FFT marker: real and complex floats, plus `half::f16` (but not `bf16`).
impl FftSupported for f32 {}
impl FftSupported for f64 {}
impl FftSupported for C32 {}
impl FftSupported for C64 {}
#[cfg(feature = "f16")]
impl FftSupported for half::f16 {}

// Floating-point RNG marker.
impl RngFloatSupported for f32 {}
impl RngFloatSupported for f64 {}

// Integer RNG marker: unsigned 32- and 64-bit only.
impl RngIntSupported for u32 {}
impl RngIntSupported for u64 {}
619
// axpy / dot / nrm2 marker: the only per-routine BLAS marker in this group
// that also covers the 16-bit float types.
impl AxpyDotNrm2Supported for f32 {}
impl AxpyDotNrm2Supported for f64 {}
#[cfg(feature = "f16")]
impl AxpyDotNrm2Supported for half::f16 {}
#[cfg(feature = "f16")]
impl AxpyDotNrm2Supported for half::bf16 {}

// The remaining per-routine markers are granted to `f32` and `f64` only.
impl GemvSupported for f32 {}
impl GemvSupported for f64 {}

impl GerSupported for f32 {}
impl GerSupported for f64 {}

impl GeamSupported for f32 {}
impl GeamSupported for f64 {}

impl SyrkSupported for f32 {}
impl SyrkSupported for f64 {}

impl TrsmSupported for f32 {}
impl TrsmSupported for f64 {}
641
// NCCL reduce marker: every scalar generated by `impl_cuda_dtype!` plus the
// 16-bit floats. `C32`/`C64` are excluded (their `nccl_data_type` panics).
impl NcclReduceSupported for f32 {}
impl NcclReduceSupported for f64 {}
impl NcclReduceSupported for i8 {}
impl NcclReduceSupported for u8 {}
impl NcclReduceSupported for i32 {}
impl NcclReduceSupported for u32 {}
impl NcclReduceSupported for i64 {}
impl NcclReduceSupported for u64 {}
#[cfg(feature = "f16")]
impl NcclReduceSupported for half::f16 {}
#[cfg(feature = "f16")]
impl NcclReduceSupported for half::bf16 {}

// Solver marker: `f32` and `f64` only.
impl SolverSupported for f32 {}
impl SolverSupported for f64 {}

// Sparse value marker: 32/64-bit floats plus the 16-bit floats.
impl SparseSupported for f32 {}
impl SparseSupported for f64 {}
#[cfg(feature = "f16")]
impl SparseSupported for half::f16 {}
#[cfg(feature = "f16")]
impl SparseSupported for half::bf16 {}

// Tensor marker: same set as the sparse value marker.
impl TensorSupported for f32 {}
impl TensorSupported for f64 {}
#[cfg(feature = "f16")]
impl TensorSupported for half::f16 {}
#[cfg(feature = "f16")]
impl TensorSupported for half::bf16 {}
671
#[cfg(test)]
mod tests {
    use super::*;
    use atomr_accel::DType;

    /// `f32`/`f64` expose the expected kind, CUDA data-type tag, and type name.
    #[test]
    fn cuda_data_type_round_trip() {
        assert_eq!(<f32 as AccelDtype>::KIND, DType::F32);
        assert_eq!(
            <f32 as CudaDtype>::cuda_data_type(),
            cublas_sys::cudaDataType_t::CUDA_R_32F
        );
        assert_eq!(
            <f64 as CudaDtype>::cuda_data_type(),
            cublas_sys::cudaDataType_t::CUDA_R_64F
        );
        assert_eq!(<f32 as CudaDtype>::cuda_type_name(), "float");
        assert_eq!(<f64 as CudaDtype>::cuda_type_name(), "double");
    }

    /// Integer dtypes report the 32-bit integer cuBLAS compute type.
    #[test]
    fn integer_compute_types() {
        assert_eq!(
            <i32 as CudaDtype>::cublas_compute_type(),
            cublas_sys::cublasComputeType_t::CUBLAS_COMPUTE_32I
        );
    }

    /// The 16-bit float types map to their own data-type tags and type names.
    #[cfg(feature = "f16")]
    #[test]
    fn f16_mappings() {
        assert_eq!(
            <half::f16 as CudaDtype>::cuda_data_type(),
            cublas_sys::cudaDataType_t::CUDA_R_16F
        );
        assert_eq!(
            <half::bf16 as CudaDtype>::cuda_data_type(),
            cublas_sys::cudaDataType_t::CUDA_R_16BF
        );
        assert_eq!(<half::f16 as CudaDtype>::cuda_type_name(), "__half");
        assert_eq!(<half::bf16 as CudaDtype>::cuda_type_name(), "__nv_bfloat16");
    }

    /// Compiles only if each type parameter carries the named capability marker.
    fn _assert_capability_bounds<G, C, F, N>()
    where
        G: GemmSupported,
        C: CudnnSupported,
        F: FftSupported,
        N: NcclReduceSupported,
    {
    }

    /// `f32` and `f64` carry the GEMM, cuDNN, FFT, and NCCL markers.
    #[test]
    fn capability_compile_time_check() {
        _assert_capability_bounds::<f32, f32, f32, f32>();
        _assert_capability_bounds::<f64, f64, f64, f64>();
    }

    /// Declared metadata agrees with the actual in-memory layout.
    #[test]
    fn complex_dtype_size_and_layout() {
        assert_eq!(<C32 as AccelDtype>::SIZE, 8);
        assert_eq!(<C64 as AccelDtype>::SIZE, 16);
        assert_eq!(<C32 as AccelDtype>::KIND, DType::F32);
        assert_eq!(<C64 as AccelDtype>::KIND, DType::F64);
        assert_eq!(<C32 as AccelDtype>::NAME, "complex64");
        assert_eq!(<C64 as AccelDtype>::NAME, "complex128");

        assert_eq!(std::mem::size_of::<C32>(), 8);
        assert_eq!(std::mem::size_of::<C64>(), 16);
        assert_eq!(std::mem::align_of::<C32>(), std::mem::align_of::<f32>());
        assert_eq!(std::mem::align_of::<C64>(), std::mem::align_of::<f64>());
    }

    /// Complex types map to the complex CUDA tags, names, and compute types.
    #[test]
    fn complex_cuda_data_type_mapping() {
        assert_eq!(
            <C32 as CudaDtype>::cuda_data_type(),
            cublas_sys::cudaDataType_t::CUDA_C_32F
        );
        assert_eq!(
            <C64 as CudaDtype>::cuda_data_type(),
            cublas_sys::cudaDataType_t::CUDA_C_64F
        );
        assert_eq!(<C32 as CudaDtype>::cuda_type_name(), "cuComplex");
        assert_eq!(<C64 as CudaDtype>::cuda_type_name(), "cuDoubleComplex");
        assert_eq!(
            <C32 as CudaDtype>::cublas_compute_type(),
            cublas_sys::cublasComputeType_t::CUBLAS_COMPUTE_32F
        );
        assert_eq!(
            <C64 as CudaDtype>::cublas_compute_type(),
            cublas_sys::cublasComputeType_t::CUBLAS_COMPUTE_64F
        );
    }

    /// Both complex types satisfy an `FftSupported` bound.
    #[test]
    fn complex_fft_supported_compile_time_check() {
        fn _check<T: FftSupported>() {}
        _check::<C32>();
        _check::<C64>();
    }

    /// `zero`, `one`, and `nan` return the expected values for both widths.
    #[test]
    fn complex_zero_one_nan_identities() {
        assert_eq!(<C32 as AccelDtype>::zero(), C32([0.0, 0.0]));
        assert_eq!(<C32 as AccelDtype>::one(), C32([1.0, 0.0]));
        assert!(matches!(
            <C32 as AccelDtype>::nan(),
            Some(C32([re, im])) if re.is_nan() && im.is_nan()
        ));

        assert_eq!(<C64 as AccelDtype>::zero(), C64([0.0, 0.0]));
        assert_eq!(<C64 as AccelDtype>::one(), C64([1.0, 0.0]));
        // The NaN identity was previously asserted for `C32` only.
        assert!(matches!(
            <C64 as AccelDtype>::nan(),
            Some(C64([re, im])) if re.is_nan() && im.is_nan()
        ));
    }

    /// Conversions to and from the cuFFT sys structs preserve both components.
    #[cfg(feature = "cufft")]
    #[test]
    fn complex_round_trips_cufft_sys() {
        use cudarc::cufft::sys as s;

        let f = s::float2 { x: 1.5, y: -2.5 };
        let c: C32 = f.into();
        assert_eq!(c, C32([1.5, -2.5]));
        let f2: s::float2 = c.into();
        assert_eq!(f2.x, 1.5);
        assert_eq!(f2.y, -2.5);

        let d = s::double2 { x: 7.0, y: 8.0 };
        let c64: C64 = d.into();
        assert_eq!(c64, C64([7.0, 8.0]));
        let d2: s::double2 = c64.into();
        assert_eq!(d2.x, 7.0);
        assert_eq!(d2.y, 8.0);
    }
}