Skip to main content

atomr_accel_cuda/kernel/
mod.rs

1//! Kernel-actor wrappers around CUDA library handles (§3.2).
2//!
3//! Each library actor follows a uniform shape:
4//!
5//! * a `Real { handle, stream, completion, state, … }` variant holding
6//!   the cudarc handle plus the per-actor caches it needs;
7//! * a `Mock` variant for GPU-free tests;
8//! * a `props(stream, allocator, completion, state)` constructor that
9//!   panics with `"ContextPoisoned: <Lib>::new failed: …"` if the
10//!   handle can't be created, so the supervisor restarts;
11//! * a `mock_props()` constructor that replies `Unrecoverable("…not
12//!   supported in mock mode")` to every variant.
13//!
14//! The shared kernel-enqueue body lives in
15//! [`envelope::run_kernel`] — every library actor calls it instead of
16//! reimplementing the validate / enqueue / spawn-completion-await /
17//! reply / drop-keep-alive sequence.
18//!
19//! F2 ships: `BlasActor`, `CudnnActor`, `FftActor`, `RngActor`.
20//! F3 adds: `SolverActor`, `BlasLtActor`, `NvrtcActor`.
21//! F4 adds: `CollectiveActor` (NCCL). Also exported: `SparseActor` (cuSPARSE), `TensorActor` (cuTENSOR).
22
23pub mod dispatch;
24pub mod envelope;
25pub mod record;
26
27#[cfg(feature = "cublaslt")]
28pub use dispatch::{BlasLtDispatch, BlasLtDispatchCtx};
29#[cfg(feature = "nccl")]
30pub use dispatch::{CollectiveDispatch, CollectiveDispatchCtx};
31#[cfg(feature = "cudnn")]
32pub use dispatch::{CudnnDispatch, CudnnDispatchCtx};
33pub use dispatch::{
34    DevSliceArg, GemmDispatchCtx, NvrtcDispatchCtx, NvrtcLaunchDispatch, RngDispatch, ScalarArg,
35};
36#[cfg(feature = "cufft")]
37pub use dispatch::{FftDispatch, FftDispatchCtx};
38#[cfg(feature = "cusparse")]
39pub use dispatch::{SendSparseHandle, SparseDispatch, SparseDispatchCtx, SparseOp};
40#[cfg(feature = "cutensor")]
41pub use dispatch::{TensorDispatch, TensorDispatchCtx, WorkspacePool};
42
43pub mod blas;
44
45pub use blas::{
46    AsumRequest, AxpyRequest, BlasActor, BlasMsg, CopyRequest, DotRequest, GeamRequest,
47    GemmRequest, GemmStridedBatchedRequest, GemvRequest, GerRequest, IamaxRequest, IaminRequest,
48    Nrm2Request, RotRequest, ScalRequest, SwapRequest, SyrkRequest, TrsmRequest,
49};
50
51#[cfg(feature = "cudnn")]
52pub mod cudnn;
53#[cfg(feature = "cudnn")]
54pub use cudnn::{
55    ActivationFwdRequest, ActivationKind, ActivationRequest, AttentionMask, AttentionParams,
56    BatchNormRequest, ConvBwdDataRequest, ConvBwdFilterRequest, ConvDescParams, ConvForwardRequest,
57    ConvFwdRequest, ConvParams, CudnnActor, CudnnMsg, DropoutFwdRequest, EpilogueKind,
58    GroupNormRequest, InstanceNormRequest, LayerNormRequest, LrnFwdRequest, LrnParams,
59    MultiHeadAttnBwdRequest, MultiHeadAttnFwdRequest, NormBwdRequest, NormMode, NormPhase,
60    PoolBwdRequest, PoolFwdRequest, PoolMode, PoolParams, RnnBwdRequest, RnnDirection,
61    RnnFwdRequest, RnnMode, RnnParams, SoftmaxFwdRequest, SoftmaxMode, SoftmaxRequest,
62    TensorLayout,
63};
64
65#[cfg(feature = "cufft")]
66pub mod fft;
67#[cfg(feature = "cufft")]
68pub use fft::{
69    FftActor, FftCallbackKind, FftDirection, FftKind, FftMsg, FftPlan, FftPlanMany, FftRequest,
70    PlanKey,
71};
72
73#[cfg(feature = "curand")]
74pub mod rng;
75#[cfg(feature = "curand")]
76pub use rng::{Distribution, FillRequest, RngActor, RngGeneratorKind, RngMsg};
77
78#[cfg(feature = "cusolver")]
79pub mod solver;
80#[cfg(feature = "cusolver")]
81pub use solver::{
82    CholeskyRequest, GesvdjBatchedRequest, GetrfBatchedRequest, HegvdRequest, LuRequest,
83    LuSolveRequest, PotrfBatchedRequest, QrRequest, SolverActor, SolverDispatch, SolverMsg,
84    SvdRequest, SyevdRequest, SygvdRequest, Uplo,
85};
86#[cfg(all(feature = "cusolver", feature = "cusolver-sp"))]
87pub use solver::{SparseCholeskyRequest, SparseLuRequest, SparseQrRequest};
88
89#[cfg(feature = "cublaslt")]
90pub mod blas_lt;
91#[cfg(feature = "cublaslt")]
92pub use blas_lt::{
93    Activation, BlasLtActor, BlasLtMsg, Epilogue, HeuristicCacheRef, MatmulRequest, ScaleSet,
94    WorkspacePool as BlasLtWorkspacePool,
95};
96
97#[cfg(feature = "nvrtc")]
98pub mod nvrtc;
99#[cfg(feature = "nvrtc")]
100pub use nvrtc::{KernelArg, KernelHandle, NvrtcActor, NvrtcMsg, NvrtcOpts};
101
102#[cfg(feature = "nccl")]
103pub mod collective;
104#[cfg(feature = "nccl")]
105pub use collective::{
106    AllGatherRequest, AllReduceRequest, AllToAllRequest, AllToAllvRequest, BroadcastRequest,
107    CollectiveActor, CollectiveMsg, GroupGuard, NcclCapabilities, NcclReduceSupported, PreMulSumOp,
108    RecvRequest, ReduceOp, ReduceRequest, ReduceScatterRequest, SendRequest,
109};
110
111#[cfg(feature = "cusparse")]
112mod sparse;
113#[cfg(feature = "cusparse")]
114pub use sparse::{CsrMatrix, SparseActor, SparseMsg};
115
116#[cfg(feature = "cutensor")]
117pub mod tensor;
118#[cfg(feature = "cutensor")]
119pub use tensor::{
120    ComputeDesc, ContractRequest, ElementwiseBinaryRequest, ElementwiseTrinaryRequest, OperandSpec,
121    PermutationRequest, ReductionRequest, TensorActor, TensorMsg, TensorSpec,
122};