1
use std::fmt::Debug;
2
use std::marker::PhantomData;
3

            
4
use async_trait::async_trait;
5
use bonsaidb_core::api::{self, Api, ApiError, Infallible};
6
use bonsaidb_core::arc_bytes::serde::Bytes;
7
use bonsaidb_core::permissions::PermissionDenied;
8
use bonsaidb_core::schema::{InsertError, InvalidNameError};
9

            
10
use crate::{Backend, ConnectedClient, CustomServer, Error, NoBackend};
11

            
12
/// A trait that can dispatch requests for a [`Api`].
13
#[async_trait]
14
pub trait Handler<Api: api::Api, B: Backend = NoBackend>: Send + Sync {
15
    /// Returns a dispatcher to handle custom api requests. The parameters are
16
    /// provided so that they can be cloned if needed during the processing of
17
    /// requests.
18
    async fn handle(session: HandlerSession<'_, B>, request: Api) -> HandlerResult<Api>;
19
}
20

            
21
/// A session for a [`Handler`], providing ways to access the server and
22
/// connected client.
23
pub struct HandlerSession<'a, B: Backend = NoBackend> {
24
    /// The [`Handler`]'s server reference. This server instance is not limited
25
    /// to the permissions of the connected user.
26
    pub server: &'a CustomServer<B>,
27
    /// The connected client's server reference. This server instance will
28
    /// reject any database operations that the connected client is not
29
    /// explicitly authorized to perform based on its authentication state.
30
    pub as_client: CustomServer<B>,
31
    /// The connected client making the API request.
32
    pub client: &'a ConnectedClient<B>,
33
}
34

            
35
#[async_trait]
36
pub(crate) trait AnyHandler<B: Backend>: Send + Sync + Debug {
37
    async fn handle(&self, session: HandlerSession<'_, B>, request: &[u8]) -> Result<Bytes, Error>;
38
}
39

            
40
pub(crate) struct AnyWrapper<D: Handler<A, B>, B: Backend, A: Api>(
41
    pub(crate) PhantomData<(D, B, A)>,
42
);
43

            
44
impl<D, B, A> Debug for AnyWrapper<D, B, A>
45
where
46
    D: Handler<A, B>,
47
    B: Backend,
48
    A: Api,
49
{
50
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51
        f.debug_tuple("AnyWrapper").finish()
52
    }
53
}
54

            
55
#[async_trait]
56
impl<T, B, A> AnyHandler<B> for AnyWrapper<T, B, A>
57
where
58
    B: Backend,
59
    T: Handler<A, B>,
60
    A: Api,
61
{
62
72049
    async fn handle(&self, client: HandlerSession<'_, B>, request: &[u8]) -> Result<Bytes, Error> {
63
72049
        let request = pot::from_slice(request)?;
64
132364
        let response = match T::handle(client, request).await {
65
71482
            Ok(response) => Ok(response),
66
            Err(HandlerError::Api(err)) => Err(err),
67
563
            Err(HandlerError::Server(err)) => return Err(err),
68
        };
69
71482
        Ok(Bytes::from(pot::to_vec(&response)?))
70
216143
    }
71
}
72

            
73
/// An error that can occur inside of a [`Backend`] function.
74
#[derive(thiserror::Error, Debug)]
75
pub enum HandlerError<E: ApiError = Infallible> {
76
    /// An api-related error.
77
    #[error("api error: {0}")]
78
    Api(E),
79
    /// A server-related error.
80
    #[error("server error: {0}")]
81
    Server(#[from] Error),
82
}
83

            
84
/// Reports a permission failure as a [`HandlerError::Server`] error.
impl<E> From<PermissionDenied> for HandlerError<E>
where
    E: ApiError,
{
    fn from(denied: PermissionDenied) -> Self {
        Self::Server(denied.into())
    }
}
89

            
90
impl<E: ApiError> From<bonsaidb_core::Error> for HandlerError<E> {
91
563
    fn from(err: bonsaidb_core::Error) -> Self {
92
563
        Self::Server(Error::from(err))
93
563
    }
94
}
95

            
96
impl<E: ApiError> From<bonsaidb_local::Error> for HandlerError<E> {
97
    fn from(err: bonsaidb_local::Error) -> Self {
98
        Self::Server(Error::from(err))
99
    }
100
}
101

            
102
/// Wraps an invalid-name error as a [`HandlerError::Server`] error.
impl<E> From<InvalidNameError> for HandlerError<E>
where
    E: ApiError,
{
    fn from(name_error: InvalidNameError) -> Self {
        Self::Server(name_error.into())
    }
}
107

            
108
/// Wraps a bincode error as a [`HandlerError::Server`] error. Only available
/// with the `websockets` feature.
#[cfg(feature = "websockets")]
impl<E> From<bincode::Error> for HandlerError<E>
where
    E: ApiError,
{
    fn from(other: bincode::Error) -> Self {
        // Routed through `bonsaidb_local::Error` rather than converted
        // directly; presumably `Error` has no `From<bincode::Error>` of its
        // own — confirm before simplifying.
        let local_error = bonsaidb_local::Error::from(other);
        Self::Server(local_error.into())
    }
}
114

            
115
impl<E: ApiError> From<pot::Error> for HandlerError<E> {
116
    fn from(other: pot::Error) -> Self {
117
        Self::Server(Error::from(other))
118
    }
119
}
120

            
121
impl<E: ApiError> From<std::io::Error> for HandlerError<E> {
122
    fn from(err: std::io::Error) -> Self {
123
        Self::Server(Error::from(err))
124
    }
125
}
126

            
127
impl<T, E> From<InsertError<T>> for HandlerError<E>
128
where
129
    E: ApiError,
130
{
131
    fn from(error: InsertError<T>) -> Self {
132
        Self::Server(Error::from(error.error))
133
    }
134
}
135

            
136
/// The return type from a [`Handler`]'s [`handle()`](Handler::handle)
137
/// function.
138
pub type HandlerResult<Api> =
139
    Result<<Api as api::Api>::Response, HandlerError<<Api as api::Api>::Error>>;