diff --git a/README.md b/README.md index 05eb17eeb..983426841 100644 --- a/README.md +++ b/README.md @@ -197,9 +197,10 @@ Ohkami::new(( ``` `.howls()` (`tls` feature only) is used to run Ohkami with TLS (HTTPS) support -with [`rustls`](https://github.com/rustls) ecosystem (described in `tls` feature section). +with `tokio` and [`rustls`](https://github.com/rustls) ecosystem +(currently `rt_tokio` only / described in `tls` feature section). -`howl(s)` supports graceful shutdown by `Ctrl-C` or `SIGTERM` signal on native runtimes. +`howl(s)` supports graceful shutdown by `Ctrl-C` ( `SIGINT` ) on native runtimes.
@@ -309,9 +310,6 @@ async fn main() { ### `"ws"` : WebSocket -Ohkami only handles `ws://`.\ -Use some reverse proxy to do with `wss://`. - ```rust,no_run use ohkami::{Ohkami, Route}; use ohkami::ws::{WebSocketContext, WebSocket, Message}; @@ -443,7 +441,7 @@ $ openssl req -x509 -newkey rsa:4096 -nodes -keyout server.key -out server.crt - [dependencies] ohkami = { version = "0.24", features = ["rt_tokio", "tls"] } tokio = { version = "1", features = ["full"] } -rustls = { version = "0.22", features = ["ring"] } +rustls = { version = "0.23", features = ["ring"] } rustls-pemfile = "2.2" ``` diff --git a/examples/websocket/.gitignore b/examples/websocket/.gitignore new file mode 100644 index 000000000..fca5b1335 --- /dev/null +++ b/examples/websocket/.gitignore @@ -0,0 +1 @@ +/*.pem diff --git a/examples/websocket/Cargo.toml b/examples/websocket/Cargo.toml index 3e2c463a3..a60e51167 100644 --- a/examples/websocket/Cargo.toml +++ b/examples/websocket/Cargo.toml @@ -4,8 +4,11 @@ version = "0.1.0" edition = "2024" [dependencies] -ohkami = { workspace = true } -tokio = { workspace = true } +ohkami = { workspace = true } +tokio = { workspace = true } +rustls = { optional = true, version = "0.23", features = ["ring"] } +rustls-pemfile = { optional = true, version = "2.2" } [features] +tls = ["ohkami/tls", "dep:rustls", "dep:rustls-pemfile"] DEBUG = ["ohkami/DEBUG"] \ No newline at end of file diff --git a/examples/websocket/README.md b/examples/websocket/README.md new file mode 100644 index 000000000..cd4a2a6c1 --- /dev/null +++ b/examples/websocket/README.md @@ -0,0 +1,16 @@ +# WebSocket Example + +## Feature flags description + +- `DEBUG`: Enables Ohkami's debug logging. +- `tls`: Enables TLS support ( https://, wss:// ). + +## Prerequisites + +If you want to run this example with TLS support, you need to have +[`mkcert`](https://github.com/FiloSottile/mkcert) and run: + +```sh +# assuming you have mkcert installed and `mkcert -install` has already executed: +mkcert -key-file key.pem -cert-file cert.pem localhost 127.0.0.1 +``` diff --git a/examples/websocket/src/main.rs b/examples/websocket/src/main.rs index 7f40d0761..fe51e5956 100644 --- a/examples/websocket/src/main.rs +++ b/examples/websocket/src/main.rs @@ -143,11 +143,54 @@ async fn main() { } } - Ohkami::new((Logger, - "/".Mount("./template").omit_extensions(&[".html"]), + #[cfg(feature="tls")] + let tls_config = { + use rustls::ServerConfig; + use rustls::pki_types::{CertificateDer, PrivateKeyDer}; + use std::fs::File; + use std::io::BufReader; + + // Initialize rustls crypto provider + rustls::crypto::ring::default_provider().install_default() + .expect("Failed to install rustls crypto provider"); + + // Load certificates and private key + let cert_file = File::open("./cert.pem").expect("Failed to open certificate file"); + let key_file = File::open("./key.pem").expect("Failed to open private key file"); + + let cert_chain = rustls_pemfile::certs(&mut BufReader::new(cert_file)) + .map(|cd| cd.map(CertificateDer::from)) + .collect::, _>>() + .expect("Failed to read certificate chain"); + + let key = rustls_pemfile::read_one(&mut BufReader::new(key_file)) + .expect("Failed to read private key") + .map(|p| match p { + rustls_pemfile::Item::Pkcs1Key(k) => PrivateKeyDer::Pkcs1(k), + rustls_pemfile::Item::Pkcs8Key(k) => PrivateKeyDer::Pkcs8(k), + _ => panic!("Unexpected private key type"), + }) + .expect("Failed to read private key"); + + // Build TLS configuration + ServerConfig::builder() + .with_no_client_auth() + .with_single_cert(cert_chain, key) + .expect("Failed to build TLS configuration") + }; + + let o = Ohkami::new(( + Logger, + "/".Mount("./template").omit_extensions(&["html"]), "/echo1".GET(echo_text), "/echo2/:name".GET(echo_text_2), "/echo3/:name".GET(echo_text_3), "/echo4/:name".GET(echo4), - )).howl("localhost:3030").await + )); + + #[cfg(not(feature="tls"))] + o.howl("localhost:3030").await; + + #[cfg(feature="tls")] + o.howls("localhost:3030", tls_config).await; } diff --git a/examples/websocket/template/index.html b/examples/websocket/template/index.html index 64683d94f..4248dc8c2 100644 --- a/examples/websocket/template/index.html +++ b/examples/websocket/template/index.html @@ -26,7 +26,13 @@ -

Echo Text

+

Echo Text Sample

+ +
@@ -49,53 +55,16 @@

Echo Text

- - \ No newline at end of file diff --git a/examples/websocket/template/main.js b/examples/websocket/template/main.js new file mode 100644 index 000000000..9c46aa5d1 --- /dev/null +++ b/examples/websocket/template/main.js @@ -0,0 +1,40 @@ +export function connect_input_and_button_to(ws_url, input_id, button_id) { + let ws = null; + + const input = document.getElementById(input_id); + input.spellcheck = false; + input.disabled = true; + + const button = document.getElementById(button_id); + button.textContent = "connect"; + + button.addEventListener( + "click", (e) => { + if (button.textContent == "connect") { + ws = new WebSocket(ws_url); + ws.addEventListener("open", (e) => { + console.log(e); + ws.send("test"); + }); + ws.addEventListener("message", (e) => { + console.log("ws got message: ", e.data); + }); + ws.addEventListener("close", (e) => { + console.log("close:", e); + + input.value = ""; + input.disabled = true; + + button.textContent = "connect"; + }); + + input.disabled = false; + + button.textContent = "send"; + } else { + console.log("sending:", input.value); + ws.send(input.value); + } + } + ); +} diff --git a/ohkami/Cargo.toml b/ohkami/Cargo.toml index 42f760ba5..8bbe7be8b 100644 --- a/ohkami/Cargo.toml +++ b/ohkami/Cargo.toml @@ -46,7 +46,7 @@ mime_guess = { version = "2.0", optional = true } ctrlc = { version = "3.4", optional = true } num_cpus = { version = "1.17", optional = true } futures-util = { version = "0.3", optional = true, default-features = false } -mews = { version = "0.2", optional = true } +mews = { version = "0.4", optional = true } rustls = { version = "0.23.23", optional = true } tokio-rustls = { version = "0.26.2", optional = true } @@ -65,7 +65,7 @@ rt_smol = ["__rt_native__", rt_nio = ["__rt_native__", "dep:nio", "dep:tokio","tokio/io-util", - "mews?/rt_nio", + #"mews?/rt_nio", ] rt_glommio = ["__rt_native__", "dep:glommio", @@ -85,7 +85,7 @@ rt_lambda = ["__rt__", nightly = [] openapi = ["dep:ohkami_openapi", "ohkami_macros/openapi"] sse = ["ohkami_lib/stream"] -ws = ["ohkami_lib/stream", "dep:mews"] +ws = ["ohkami_lib/stream", "dep:mews", "dep:futures-util", "futures-util/io", "futures-util/unstable","futures-util/bilock"] tls = ["rt_tokio", "dep:rustls", "dep:tokio-rustls"] # currently depending on tokio-rustls and works only on tokio ##### internal ##### diff --git a/ohkami/src/lib.rs b/ohkami/src/lib.rs index 921430808..3d0847c23 100644 --- a/ohkami/src/lib.rs +++ b/ohkami/src/lib.rs @@ -202,7 +202,7 @@ pub use ohkami::{Ohkami, Route}; pub mod fang; pub use fang::{handler, Fang, FangProc, FangAction}; -#[cfg(all(feature="__rt_native__", feature="rt_tokio", feature="tls"))] +#[cfg(feature="tls")] mod tls; pub mod header; diff --git a/ohkami/src/ohkami/mod.rs b/ohkami/src/ohkami/mod.rs index 38bac67ea..2e5161ee1 100644 --- a/ohkami/src/ohkami/mod.rs +++ b/ohkami/src/ohkami/mod.rs @@ -612,9 +612,9 @@ impl Ohkami { }; let session = Session::new( - router.clone(), connection, - addr.ip() + addr.ip(), + router.clone(), ); let wg = wg.add(); @@ -714,14 +714,25 @@ impl Ohkami { crate::INFO!("start serving on {}", listener.local_addr().unwrap()); while let Some(accept) = ctrl_c.until_interrupt(listener.accept()).await { - let Ok((tcp_stream, addr)) = accept else { continue }; + crate::DEBUG!("accept: {accept:?}"); - let Ok(tls_stream) = tls_acceptor.accept(tcp_stream).await else { continue }; + let Ok((connection, addr)) = accept else { continue }; + + let connection = match ctrl_c.until_interrupt(tls_acceptor.accept(connection)).await { + None => break, + Some(Ok(tls_stream)) => TlsStream(tls_stream), + Some(Err(e)) => { + crate::ERROR!("TLS accept error: {e}"); + continue; + } + }; + + crate::DEBUG!("accepted TLS connection: {connection:?}"); let session = Session::new( + connection, + addr.ip(), router.clone(), - TlsStream(tls_stream), - addr.ip() ); let wg = wg.add(); diff --git a/ohkami/src/response/mod.rs b/ohkami/src/response/mod.rs index 70f087959..455f77a83 100644 --- a/ohkami/src/response/mod.rs +++ b/ohkami/src/response/mod.rs @@ -295,7 +295,7 @@ pub(super) enum Upgrade { None, #[cfg(feature="ws")] - WebSocket(mews::WebSocket), + WebSocket(mews::WebSocket), } #[cfg(feature="__rt_native__")] impl Upgrade { diff --git a/ohkami/src/session/connection.rs b/ohkami/src/session/connection.rs new file mode 100644 index 000000000..d901bd2ab --- /dev/null +++ b/ohkami/src/session/connection.rs @@ -0,0 +1,165 @@ +#[derive(Debug)] +pub enum Connection { + Tcp(crate::__rt__::TcpStream), + #[cfg(feature="tls")] + Tls(crate::tls::TlsStream), +} + +impl From for Connection { + fn from(stream: crate::__rt__::TcpStream) -> Self { + Self::Tcp(stream) + } +} +#[cfg(feature="tls")] +impl From for Connection { + fn from(stream: crate::tls::TlsStream) -> Self { + Self::Tls(stream) + } +} + +#[cfg(feature="rt_tokio")] +const _: () = { + impl tokio::io::AsyncRead for Connection { + fn poll_read( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut tokio::io::ReadBuf<'_> + ) -> std::task::Poll> { + match std::pin::Pin::into_inner(self) { + Self::Tcp(stream) => std::pin::Pin::new(stream).poll_read(cx, buf), + #[cfg(feature="tls")] + Self::Tls(stream) => std::pin::Pin::new(stream).poll_read(cx, buf), + } + } + } + + impl tokio::io::AsyncWrite for Connection { + fn poll_write( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8] + ) -> std::task::Poll> { + match std::pin::Pin::into_inner(self) { + Self::Tcp(stream) => std::pin::Pin::new(stream).poll_write(cx, buf), + #[cfg(feature="tls")] + Self::Tls(stream) => std::pin::Pin::new(stream).poll_write(cx, buf), + } + } + + fn poll_flush( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_> + ) -> std::task::Poll> { + match std::pin::Pin::into_inner(self) { + Self::Tcp(stream) => std::pin::Pin::new(stream).poll_flush(cx), + #[cfg(feature="tls")] + Self::Tls(stream) => std::pin::Pin::new(stream).poll_flush(cx), + } + } + + fn poll_shutdown( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_> + ) -> std::task::Poll> { + match std::pin::Pin::into_inner(self) { + Self::Tcp(stream) => std::pin::Pin::new(stream).poll_shutdown(cx), + #[cfg(feature="tls")] + Self::Tls(stream) => std::pin::Pin::new(stream).poll_shutdown(cx), + } + } + } +}; + +/* + * Currently `tls` feature is only supported on `rt_tokio`. + */ + +#[cfg(feature="rt_smol")] +const _: () = { + impl smol::io::AsyncRead for Connection { + fn poll_read( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut [u8] + ) -> std::task::Poll> { + match std::pin::Pin::into_inner(self) { + Self::Tcp(stream) => std::pin::Pin::new(stream).poll_read(cx, buf), + } + } + } + + impl smol::io::AsyncWrite for Connection { + fn poll_write( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8] + ) -> std::task::Poll> { + match std::pin::Pin::into_inner(self) { + Self::Tcp(stream) => std::pin::Pin::new(stream).poll_write(cx, buf), + } + } + + fn poll_flush( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_> + ) -> std::task::Poll> { + match std::pin::Pin::into_inner(self) { + Self::Tcp(stream) => std::pin::Pin::new(stream).poll_flush(cx), + } + } + + fn poll_close( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_> + ) -> std::task::Poll> { + match std::pin::Pin::into_inner(self) { + Self::Tcp(stream) => std::pin::Pin::new(stream).poll_close(cx), + } + } + } +}; + +#[cfg(feature="rt_glommio")] +const _: () = { + impl futures_util::AsyncRead for Connection { + fn poll_read( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut [u8] + ) -> std::task::Poll> { + match std::pin::Pin::into_inner(self) { + Self::Tcp(stream) => std::pin::Pin::new(stream).poll_read(cx, buf), + } + } + } + + impl futures_util::AsyncWrite for Connection { + fn poll_write( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8] + ) -> std::task::Poll> { + match std::pin::Pin::into_inner(self) { + Self::Tcp(stream) => std::pin::Pin::new(stream).poll_write(cx, buf), + } + } + + fn poll_flush( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_> + ) -> std::task::Poll> { + match std::pin::Pin::into_inner(self) { + Self::Tcp(stream) => std::pin::Pin::new(stream).poll_flush(cx), + } + } + + fn poll_close( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_> + ) -> std::task::Poll> { + match std::pin::Pin::into_inner(self) { + Self::Tcp(stream) => std::pin::Pin::new(stream).poll_close(cx), + } + } + } +}; diff --git a/ohkami/src/session/mod.rs b/ohkami/src/session/mod.rs index 646e40c5a..db83a21d8 100644 --- a/ohkami/src/session/mod.rs +++ b/ohkami/src/session/mod.rs @@ -1,41 +1,32 @@ #![cfg(feature="__rt_native__")] +mod connection; + use std::{any::Any, pin::Pin, sync::Arc, time::Duration}; use std::panic::{AssertUnwindSafe, catch_unwind}; -use crate::__rt__::{AsyncRead, AsyncWrite}; use crate::response::Upgrade; use crate::util::timeout_in; use crate::router::r#final::Router; use crate::{Request, Response}; -pub(crate) struct Session { +pub use self::connection::Connection; + +pub(crate) struct Session { + connection: Connection, router: Arc, - connection: C, ip: std::net::IpAddr, } -pub(crate) trait Connection: AsyncRead + AsyncWrite + Unpin { - #[cfg(feature="ws")] - fn into_websocket_stream(self) -> Result; -} - -impl Connection for crate::__rt__::TcpStream { - #[cfg(feature="ws")] - fn into_websocket_stream(self) -> Result { - Ok(self) - } -} - -impl Session { +impl Session { pub(crate) fn new( + connection: impl Into, + ip: std::net::IpAddr, router: Arc, - connection: C, - ip: std::net::IpAddr ) -> Self { Self { + connection: connection.into(), + ip, router, - connection, - ip } } @@ -128,31 +119,25 @@ impl Session { Upgrade::None => { crate::DEBUG!("about to shutdown connection"); }, + #[cfg(feature="ws")] Upgrade::WebSocket(ws) => { - match self.connection.into_websocket_stream() { - Ok(tcp_stream) => { - crate::DEBUG!("WebSocket session started"); - - let aborted = ws.manage_with_timeout( - Duration::from_secs(crate::CONFIG.websocket_timeout()), - tcp_stream - ).await; - if aborted { - crate::WARNING!("\ - WebSocket session aborted by timeout. In Ohkami, \ - WebSocket timeout is set to 3600 seconds (1 hour) \ - by default and is configurable by `OHKAMI_WEBSOCKET_TIMEOUT` \ - environment variable.\ - "); - } + crate::DEBUG!("WebSocket session started"); - crate::DEBUG!("WebSocket session finished"); - }, - Err(msg) => { - crate::WARNING!("{msg}"); - } + let aborted = ws.manage_with_timeout( + Duration::from_secs(crate::CONFIG.websocket_timeout()), + self.connection + ).await; + if aborted { + crate::WARNING!("\ + WebSocket session aborted by timeout. In Ohkami, \ + WebSocket timeout is set to 3600 seconds (1 hour) \ + by default and is configurable by `OHKAMI_WEBSOCKET_TIMEOUT` \ + environment variable.\ + "); } + + crate::DEBUG!("WebSocket session finished"); }, } } diff --git a/ohkami/src/tls/mod.rs b/ohkami/src/tls/mod.rs index a6a6c6a9f..2c0062074 100644 --- a/ohkami/src/tls/mod.rs +++ b/ohkami/src/tls/mod.rs @@ -1,11 +1,12 @@ +#![cfg(feature="tls")] + use tokio::io::{AsyncRead, AsyncWrite}; pub struct TlsStream(pub tokio_rustls::server::TlsStream); -impl crate::session::Connection for TlsStream { - #[cfg(feature="ws")] - fn into_websocket_stream(self) -> Result { - Err("WebSocket connections are not supported over TLS yet") +impl std::fmt::Debug for TlsStream { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.fmt(f) } } diff --git a/ohkami/src/ws/mod.rs b/ohkami/src/ws/mod.rs index 4ee7a46fb..2c11facc0 100644 --- a/ohkami/src/ws/mod.rs +++ b/ohkami/src/ws/mod.rs @@ -12,7 +12,7 @@ pub use self::worker::*; /// # Context for WebSocket handshake /// -/// `.upgrade` performs handshake and creates a WebSocket session. +/// `.upgrade(~)` performs handshake and creates a WebSocket session. /// /// ### note /// @@ -63,4 +63,9 @@ impl<'req> WebSocketContext<'req> { pub fn new(sec_websocket_key: &'req str) -> Self { Self { sec_websocket_key } } + + /* + `.upgrade(~)` and something are implemented in + `native` or `worker` submodule + */ } diff --git a/ohkami/src/ws/native.rs b/ohkami/src/ws/native.rs index 049d8276b..8da2af4bb 100644 --- a/ohkami/src/ws/native.rs +++ b/ohkami/src/ws/native.rs @@ -6,11 +6,13 @@ pub use mews::{ Config, Connection, ReadHalf, WriteHalf, - WebSocket as Session, connection, split, }; +/// used in `crate::response::content::Content::WebSocket` +pub(crate) type Session = mews::WebSocket; + impl<'ctx> super::WebSocketContext<'ctx> { /// create a `WebSocket` with the handler and default `Config`. /// use [`upgrade_with`](crate::ws::WebSocketContext::upgrade_with) to provide a custom config. @@ -18,12 +20,12 @@ impl<'ctx> super::WebSocketContext<'ctx> { /// ## handler /// /// any 'static `FnOnce(Connection) -> {impl Future + Send} + Send + Sync` - pub fn upgrade( + pub fn upgrade( self, handler: H - ) -> WebSocket + ) -> WebSocket where - H: FnOnce(Connection) -> F + Send + Sync + 'static, + H: FnOnce(Connection) -> F + Send + Sync + 'static, F: std::future::Future + Send + 'static, { self.upgrade_with(Config::default(), handler) @@ -34,12 +36,12 @@ impl<'ctx> super::WebSocketContext<'ctx> { /// ## handler /// /// any 'static `FnOnce(Connection) -> {impl Future + Send} + Send + Sync` - pub fn upgrade_with(self, + pub fn upgrade_with(self, config: Config, handler: H - ) -> WebSocket + ) -> WebSocket where - H: FnOnce(Connection) -> F + Send + Sync + 'static, + H: FnOnce(Connection) -> F + Send + Sync + 'static, F: std::future::Future + Send + 'static, { let (sign, session) = mews::WebSocketContext::new(self.sec_websocket_key) @@ -110,9 +112,9 @@ impl<'ctx> super::WebSocketContext<'ctx> { /// }) /// } /// ``` -pub struct WebSocket { +pub struct WebSocket { sign: String, - session: Session, + session: Session, } impl crate::IntoResponse for WebSocket { fn into_response(self) -> crate::Response { diff --git a/ohkami/src/ws/worker.rs b/ohkami/src/ws/worker.rs index ca77f1dd3..3c575b9a3 100644 --- a/ohkami/src/ws/worker.rs +++ b/ohkami/src/ws/worker.rs @@ -2,7 +2,8 @@ pub use mews::{Message, CloseFrame, CloseCode}; -pub(crate) use worker::WebSocket as Session; +/// used in `crate::response::content::Content::WebSocket` +pub(crate) type Session = worker::WebSocket; use worker::{WebSocketPair, EventStream, wasm_bindgen_futures}; use std::rc::Rc;