use bonsaidb_core::networking::{Payload, CURRENT_PROTOCOL_VERSION};
use futures::{SinkExt, StreamExt};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_tungstenite::tungstenite::Message;
use crate::server::connected_client::OwnedClient;
use crate::server::shutdown::{ShutdownState, ShutdownStateWatcher};
use crate::{Backend, CustomServer, Error, Transport};
impl<B: Backend> CustomServer<B> {
pub async fn listen_for_websockets_on<T: tokio::net::ToSocketAddrs + Send + Sync>(
&self,
addr: T,
with_tls: bool,
) -> Result<(), Error> {
if with_tls {
self.listen_for_secure_tcp_on(addr, ()).await
} else {
self.listen_for_tcp_on(addr, ()).await
}
}
pub(crate) async fn handle_raw_websocket_connection<
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
>(
&self,
connection: S,
peer_address: std::net::SocketAddr,
) -> Result<(), Error> {
let stream = tokio_tungstenite::accept_hdr_async(connection, VersionChecker).await?;
self.handle_websocket(stream, peer_address).await;
Ok(())
}
#[cfg(feature = "hyper")]
pub fn upgrade_websocket(
&self,
peer_address: std::net::SocketAddr,
mut request: hyper::Request<hyper::Body>,
) -> hyper::Response<hyper::Body> {
use hyper::header::{
HeaderValue, CONNECTION, SEC_WEBSOCKET_ACCEPT, SEC_WEBSOCKET_KEY, UPGRADE,
};
use hyper::StatusCode;
use tokio_tungstenite::tungstenite::protocol::Role;
use tokio_tungstenite::WebSocketStream;
let mut response = hyper::Response::new(hyper::Body::empty());
if !request.headers().contains_key(UPGRADE) {
*response.status_mut() = StatusCode::BAD_REQUEST;
return response;
}
let Some(sec_websocket_key) = request.headers_mut().remove(SEC_WEBSOCKET_KEY) else {
*response.status_mut() = StatusCode::BAD_REQUEST;
return response;
};
let task_self = self.clone();
tokio::spawn(async move {
match hyper::upgrade::on(&mut request).await {
Ok(upgraded) => {
let ws = WebSocketStream::from_raw_socket(upgraded, Role::Server, None).await;
task_self.handle_websocket(ws, peer_address).await;
}
Err(err) => {
log::error!("Error upgrading websocket: {:?}", err);
}
}
});
*response.status_mut() = StatusCode::SWITCHING_PROTOCOLS;
response
.headers_mut()
.insert(UPGRADE, HeaderValue::from_static("websocket"));
response
.headers_mut()
.insert(CONNECTION, HeaderValue::from_static("upgrade"));
response.headers_mut().insert(
SEC_WEBSOCKET_ACCEPT,
compute_websocket_accept_header(sec_websocket_key.as_bytes()),
);
response
}
pub async fn handle_websocket<
S: futures::Stream<Item = Result<tokio_tungstenite::tungstenite::Message, E>>
+ futures::Sink<tokio_tungstenite::tungstenite::Message>
+ Send
+ 'static,
E: std::fmt::Debug + Send,
>(
&self,
connection: S,
peer_address: std::net::SocketAddr,
) {
let mut shutdown = self
.data
.shutdown
.watcher()
.await
.expect("watcher shut down");
let (mut sender, mut receiver) = connection.split();
let (response_sender, response_receiver) = flume::unbounded();
let (message_sender, message_receiver) = flume::unbounded();
let (api_response_sender, api_response_receiver) = flume::unbounded();
let Some(client) = self
.initialize_client(Transport::WebSocket, peer_address, api_response_sender)
.await
else {
return;
};
let task_sender = response_sender.clone();
tokio::spawn(async move {
while let Ok((session_id, name, value)) = api_response_receiver.recv_async().await {
if task_sender
.send(Payload {
id: None,
session_id,
name,
value: Ok(value),
})
.is_err()
{
break;
}
}
});
tokio::spawn(async move {
while let Ok(response) = message_receiver.recv_async().await {
if sender.send(response).await.is_err() {
break;
}
}
Result::<(), Error>::Ok(())
});
let task_sender = message_sender.clone();
tokio::spawn(async move {
while let Ok(response) = response_receiver.recv_async().await {
if task_sender
.send(Message::Binary(bincode::serialize(&response)?))
.is_err()
{
break;
}
}
Result::<(), Error>::Ok(())
});
let (request_sender, request_receiver) =
flume::bounded::<Payload>(self.data.client_simultaneous_request_limit);
self.spawn_client_request_handler(client, request_receiver, response_sender, &shutdown);
loop {
tokio::select! {
payload = receiver.next() => {
if let Some(payload) = payload {
match payload {
Ok(Message::Binary(binary)) => match bincode::deserialize::<Payload>(&binary) {
Ok(payload) => drop(request_sender.send_async(payload).await),
Err(err) => {
log::error!("[server] error decoding message: {:?}", err);
break;
}
},
Ok(Message::Close(_)) => break,
Ok(Message::Ping(payload)) => {
drop(message_sender.send(Message::Pong(payload)));
}
other => {
log::error!("[server] unexpected message: {:?}", other);
break;
}
}
} else {
return;
}
},
shutdown = shutdown.wait_for_shutdown() => {
if matches!(shutdown, ShutdownState::Shutdown) {
return;
}
}
}
}
}
fn spawn_client_request_handler(
&self,
client: OwnedClient<B>,
request_receiver: flume::Receiver<Payload>,
response_sender: flume::Sender<Payload>,
shutdown: &ShutdownStateWatcher,
) {
tokio::spawn({
let task_self = self.clone();
let shutdown = shutdown.clone();
async move {
task_self
.handle_client_requests(
client.clone(),
request_receiver,
response_sender,
shutdown,
)
.await;
}
});
}
}
/// Computes the `Sec-WebSocket-Accept` response header for a client's
/// `Sec-WebSocket-Key`.
///
/// Following RFC 6455 §4.2.2, the value is the standard-alphabet base64
/// encoding of the SHA-1 digest of the key followed by the protocol GUID.
#[cfg(feature = "hyper")]
fn compute_websocket_accept_header(key: &[u8]) -> hyper::header::HeaderValue {
    use base64::engine::general_purpose::STANDARD as BASE64;
    use base64::Engine;
    use sha1::{Digest, Sha1};

    /// GUID that RFC 6455 appends to every handshake key before hashing.
    const WEBSOCKET_GUID: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

    let mut hasher = Sha1::new();
    hasher.update(key);
    hasher.update(WEBSOCKET_GUID);
    let accept_value = BASE64.encode(hasher.finalize());

    // Base64 output is plain ASCII, which is always a legal header value.
    hyper::header::HeaderValue::from_str(&accept_value).expect("base64 is a valid value")
}
/// WebSocket handshake callback that only accepts clients speaking
/// [`CURRENT_PROTOCOL_VERSION`].
///
/// The client must list the version in a `Sec-WebSocket-Protocol` header.
/// If it does, the version is echoed back in the response. Otherwise the
/// handshake is rejected with `406 Not Acceptable`.
struct VersionChecker;

impl tokio_tungstenite::tungstenite::handshake::server::Callback for VersionChecker {
    fn on_request(
        self,
        request: &tokio_tungstenite::tungstenite::handshake::server::Request,
        mut response: tokio_tungstenite::tungstenite::handshake::server::Response,
    ) -> Result<
        tokio_tungstenite::tungstenite::handshake::server::Response,
        tokio_tungstenite::tungstenite::handshake::server::ErrorResponse,
    > {
        use tokio_tungstenite::tungstenite::http::StatusCode;

        // RFC 6455 §11.3.4 allows `Sec-WebSocket-Protocol` to appear multiple
        // times, so every occurrence is inspected, not only the first.
        // Non-UTF-8 values are skipped.
        let offers_current_version = request
            .headers()
            .get_all("Sec-WebSocket-Protocol")
            .iter()
            .filter_map(|value| value.to_str().ok())
            .flat_map(|value| value.split(','))
            .any(|protocol| protocol.trim() == CURRENT_PROTOCOL_VERSION);

        if offers_current_version {
            response.headers_mut().insert(
                "Sec-WebSocket-Protocol",
                CURRENT_PROTOCOL_VERSION
                    .try_into()
                    .expect("protocol version is a valid header value"),
            );
            Ok(response)
        } else {
            let mut err =
                tokio_tungstenite::tungstenite::handshake::server::ErrorResponse::new(None);
            *err.status_mut() = StatusCode::NOT_ACCEPTABLE;
            Err(err)
        }
    }
}