use std::collections::HashMap;
use std::net::SocketAddr;
use std::ops::{Deref, DerefMut};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use async_lock::{Mutex, MutexGuard};
use bonsaidb_core::api;
use bonsaidb_core::api::ApiName;
use bonsaidb_core::arc_bytes::serde::Bytes;
use bonsaidb_core::connection::{Session, SessionId};
use bonsaidb_core::networking::MessageReceived;
use bonsaidb_core::pubsub::{Receiver, Subscriber as _};
use bonsaidb_local::Subscriber;
use bonsaidb_utils::fast_async_lock;
use derive_where::derive_where;
use flume::Sender;
use parking_lot::RwLock;
use crate::{Backend, CustomServer, Error, NoBackend};
/// The transport a client used to connect to the server.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Transport {
    /// The client connected using the `Bonsai` transport.
    Bonsai,
    /// The client connected over a WebSocket. Only available when the
    /// `websockets` feature is enabled.
    #[cfg(feature = "websockets")]
    WebSocket,
}
/// A handle to a client connected to the server.
///
/// Cloning a `ConnectedClient` is cheap: every clone shares the same
/// underlying state through an [`Arc`].
#[derive(Debug)]
#[derive_where(Clone)]
pub struct ConnectedClient<B: Backend = NoBackend> {
    // Shared state; all clones of this handle point at the same `Data`.
    data: Arc<Data<B>>,
}
/// State shared by every [`ConnectedClient`] handle for one connection.
#[derive(Debug)]
struct Data<B: Backend = NoBackend> {
    // Server-assigned client id; handed to `disconnect_client` when the
    // owning `OwnedClient` is dropped.
    id: u32,
    // Sessions keyed by session id. The `None` key holds the default session
    // inserted when the client is created.
    sessions: RwLock<HashMap<Option<SessionId>, ClientSession>>,
    // Remote address of the client.
    address: SocketAddr,
    // Transport the client connected over.
    transport: Transport,
    // Outgoing responses as (session id, api name, encoded payload).
    response_sender: Sender<(Option<SessionId>, ApiName, Bytes)>,
    // Backend-specific per-client data; `None` until something stores a value.
    client_data: Mutex<Option<B::ClientData>>,
    // `true` from creation until `set_disconnected` is called.
    connected: AtomicBool,
}
/// A session registered on a client, along with the pub/sub subscribers
/// created under it.
#[derive(Debug)]
struct ClientSession {
    session: Session,
    // Subscribers keyed by their `id()`; dropped together with the session.
    subscribers: HashMap<u64, Subscriber>,
}
impl<B: Backend> ConnectedClient<B> {
#[must_use]
pub fn address(&self) -> &SocketAddr {
&self.data.address
}
#[must_use]
pub fn transport(&self) -> &Transport {
&self.data.transport
}
#[must_use]
pub fn connected(&self) -> bool {
self.data.connected.load(Ordering::Relaxed)
}
pub(crate) fn set_disconnected(&self) {
self.data.connected.store(false, Ordering::Relaxed);
}
pub(crate) fn logged_in_as(&self, session: Session) {
let mut sessions = self.data.sessions.write();
sessions.insert(
session.id,
ClientSession {
session,
subscribers: HashMap::default(),
},
);
}
pub(crate) fn log_out(&self, session: SessionId) -> Option<Session> {
let mut sessions = self.data.sessions.write();
sessions.remove(&Some(session)).map(|cs| cs.session)
}
pub fn send<Api: api::Api>(
&self,
session: Option<&Session>,
response: &Api::Response,
) -> Result<(), Error> {
let encoded = pot::to_vec(&Result::<&Api::Response, Api::Error>::Ok(response))?;
self.data.response_sender.send((
session.and_then(|session| session.id),
Api::name(),
Bytes::from(encoded),
))?;
Ok(())
}
pub async fn client_data(&self) -> LockedClientDataGuard<'_, B::ClientData> {
LockedClientDataGuard(fast_async_lock!(self.data.client_data))
}
#[must_use]
pub fn session(&self, session_id: Option<SessionId>) -> Option<Session> {
let sessions = self.data.sessions.read();
sessions.get(&session_id).map(|data| data.session.clone())
}
#[must_use]
pub fn all_sessions<C: FromIterator<Session>>(&self) -> C {
let sessions = self.data.sessions.read();
sessions.values().map(|s| s.session.clone()).collect()
}
pub(crate) fn register_subscriber(
&self,
subscriber: Subscriber,
session_id: Option<SessionId>,
) {
let subscriber_id = subscriber.id();
let receiver = subscriber.receiver().clone();
{
let mut sessions = self.data.sessions.write();
if let Some(client_session) = sessions.get_mut(&session_id) {
client_session
.subscribers
.insert(subscriber.id(), subscriber);
} else {
return;
}
}
let task_self = self.clone();
tokio::task::spawn(async move {
task_self
.forward_notifications_for(session_id, subscriber_id, receiver)
.await;
});
}
pub async fn set_client_data(&self, data: B::ClientData) {
let mut client_data = fast_async_lock!(self.data.client_data);
*client_data = Some(data);
}
async fn forward_notifications_for(
&self,
session_id: Option<SessionId>,
subscriber_id: u64,
receiver: Receiver,
) {
let session = self.session(session_id);
while let Ok(message) = receiver.receive_async().await {
if self
.send::<MessageReceived>(
session.as_ref(),
&MessageReceived {
subscriber_id,
topic: Bytes::from(message.topic.0.into_vec()),
payload: Bytes::from(&message.payload[..]),
},
)
.is_err()
{
break;
}
}
}
pub(crate) fn subscribe_by_id(
&self,
subscriber_id: u64,
topic: Bytes,
check_session_id: Option<SessionId>,
) -> Result<(), crate::Error> {
let mut sessions = self.data.sessions.write();
if let Some(client_session) = sessions.get_mut(&check_session_id) {
if let Some(subscriber) = client_session.subscribers.get(&subscriber_id) {
subscriber.subscribe_to_bytes(topic.0)?;
Ok(())
} else {
Err(Error::other(
"bonsaidb-server pubsub",
"invalid subscriber id",
))
}
} else {
Err(Error::other("bonsaidb-server auth", "invalid session id"))
}
}
pub(crate) fn unsubscribe_by_id(
&self,
subscriber_id: u64,
topic: &[u8],
check_session_id: Option<SessionId>,
) -> Result<(), crate::Error> {
let mut sessions = self.data.sessions.write();
if let Some(client_session) = sessions.get_mut(&check_session_id) {
if let Some(subscriber) = client_session.subscribers.get(&subscriber_id) {
subscriber.unsubscribe_from_bytes(topic)?;
Ok(())
} else {
Err(Error::other(
"bonsaidb-server pubsub",
"invalid subscriber id",
))
}
} else {
Err(Error::other("bonsaidb-server auth", "invalid session id"))
}
}
pub(crate) fn unregister_subscriber_by_id(
&self,
subscriber_id: u64,
check_session_id: Option<SessionId>,
) -> Result<(), crate::Error> {
let mut sessions = self.data.sessions.write();
if let Some(client_session) = sessions.get_mut(&check_session_id) {
if client_session.subscribers.remove(&subscriber_id).is_some() {
Ok(())
} else {
Err(Error::other(
"bonsaidb-server pubsub",
"invalid subscriber id",
))
}
} else {
Err(Error::other("bonsaidb-server auth", "invalid session id"))
}
}
}
/// Exclusive access to a client's custom data, returned by
/// [`ConnectedClient::client_data`].
///
/// The lock is held until this guard is dropped. The guarded value is an
/// `Option` that is `None` until something stores data in it.
#[must_use = "the client data lock is released as soon as the guard is dropped"]
pub struct LockedClientDataGuard<'client, ClientData>(MutexGuard<'client, Option<ClientData>>);
impl<'client, ClientData> Deref for LockedClientDataGuard<'client, ClientData> {
type Target = Option<ClientData>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl<'client, ClientData> DerefMut for LockedClientDataGuard<'client, ClientData> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
/// The owning handle for a connected client.
///
/// Dropping it spawns `CustomServer::disconnect_client` for this client's id
/// on the runtime captured at construction. Non-owning [`ConnectedClient`]
/// handles can be obtained with [`OwnedClient::clone`].
#[derive(Debug)]
pub struct OwnedClient<B: Backend> {
    client: ConnectedClient<B>,
    // Runtime handle captured in `new` so `Drop` can spawn the disconnect task.
    runtime: Arc<tokio::runtime::Handle>,
    // Set to `Some` in `new`; taken by `Drop` to perform the disconnect.
    server: Option<CustomServer<B>>,
}
impl<B: Backend> OwnedClient<B> {
    /// Creates the owning handle for a newly connected client.
    ///
    /// The client starts out connected, with no client data, and with
    /// `default_session` registered under the `None` session id.
    ///
    /// # Panics
    ///
    /// Panics if called outside of a Tokio runtime, because the current
    /// runtime handle is captured via `tokio::runtime::Handle::current()`.
    pub(crate) fn new(
        id: u32,
        address: SocketAddr,
        transport: Transport,
        response_sender: Sender<(Option<SessionId>, ApiName, Bytes)>,
        server: CustomServer<B>,
        default_session: Session,
    ) -> Self {
        let default_entry = ClientSession {
            session: default_session,
            subscribers: HashMap::default(),
        };
        let sessions = HashMap::from([(None, default_entry)]);
        let shared = Data {
            id,
            address,
            transport,
            response_sender,
            sessions: RwLock::new(sessions),
            client_data: Mutex::default(),
            connected: AtomicBool::new(true),
        };
        Self {
            client: ConnectedClient {
                data: Arc::new(shared),
            },
            runtime: Arc::new(tokio::runtime::Handle::current()),
            server: Some(server),
        }
    }

    /// Returns a non-owning [`ConnectedClient`] handle that shares this
    /// client's state.
    ///
    /// Dropping the returned handle does not trigger the disconnect that
    /// dropping the `OwnedClient` does.
    pub fn clone(&self) -> ConnectedClient<B> {
        Clone::clone(&self.client)
    }
}
impl<B: Backend> Drop for OwnedClient<B> {
    /// Spawns `disconnect_client` for this client's id on the captured
    /// runtime. The spawned task is detached; its outcome is not observed.
    fn drop(&mut self) {
        let Self {
            client,
            runtime,
            server,
        } = self;
        // `new` always stores `Some(server)`, and `drop` runs only once.
        let server = server.take().unwrap();
        let client_id = client.data.id;
        runtime.spawn(async move { server.disconnect_client(client_id).await });
    }
}
impl<B: Backend> Deref for OwnedClient<B> {
    type Target = ConnectedClient<B>;

    /// Exposes the [`ConnectedClient`] API directly on the owning handle.
    fn deref(&self) -> &Self::Target {
        let Self { client, .. } = self;
        client
    }
}