|
9 | 9 | html_favicon_url = "https://github.com/tauri-apps/tauri/raw/dev/app-icon.png" |
10 | 10 | )] |
11 | 11 |
|
12 | | -use futures_util::{stream::SplitSink, SinkExt, StreamExt}; |
13 | | -use http::header::{HeaderName, HeaderValue}; |
14 | | -use serde::{ser::Serializer, Deserialize, Serialize}; |
| 12 | +use futures_util::{SinkExt, StreamExt, stream::SplitSink}; |
| 13 | +use http::{ |
| 14 | + HeaderMap, Request, |
| 15 | + header::{HeaderName, HeaderValue}, |
| 16 | +}; |
| 17 | +use serde::{Deserialize, Serialize, ser::Serializer}; |
15 | 18 | use tauri::{ |
| 19 | + AppHandle, Manager, Runtime, State, Url, Window, |
16 | 20 | ipc::Channel, |
17 | 21 | plugin::{Builder as PluginBuilder, TauriPlugin}, |
18 | | - Manager, Runtime, State, Window, |
19 | 22 | }; |
20 | 23 | use tokio::{net::TcpStream, sync::Mutex}; |
21 | 24 | #[cfg(any(feature = "rustls-tls", feature = "native-tls"))] |
22 | 25 | use tokio_tungstenite::connect_async_tls_with_config; |
23 | 26 | #[cfg(not(any(feature = "rustls-tls", feature = "native-tls")))] |
24 | 27 | use tokio_tungstenite::connect_async_with_config; |
25 | 28 | use tokio_tungstenite::{ |
| 29 | + Connector, MaybeTlsStream, WebSocketStream, |
26 | 30 | tungstenite::{ |
| 31 | + Message, |
27 | 32 | client::IntoClientRequest, |
28 | 33 | protocol::{CloseFrame as ProtocolCloseFrame, WebSocketConfig}, |
29 | | - Message, |
30 | 34 | }, |
31 | | - Connector, MaybeTlsStream, WebSocketStream, |
32 | 35 | }; |
33 | 36 |
|
34 | | -use std::collections::HashMap; |
35 | 37 | use std::str::FromStr; |
| 38 | +use std::{collections::HashMap, marker::PhantomData}; |
36 | 39 |
|
37 | 40 | type Id = u32; |
38 | 41 | type WebSocket = WebSocketStream<MaybeTlsStream<TcpStream>>; |
@@ -157,6 +160,10 @@ async fn connect<R: Runtime>( |
157 | 160 | } |
158 | 161 | } |
159 | 162 |
|
| 163 | + if let Some(state) = window.app_handle().try_state::<RequestCallback<R>>() { |
| 164 | + (state.inner().0)(&mut request, window.app_handle()); |
| 165 | + } |
| 166 | + |
160 | 167 | #[cfg(any(feature = "rustls-tls", feature = "native-tls"))] |
161 | 168 | let tls_connector = match window.try_state::<TlsConnector>() { |
162 | 169 | Some(tls_connector) => tls_connector.0.lock().await.clone(), |
@@ -242,31 +249,56 @@ async fn send( |
242 | 249 | } |
243 | 250 |
|
244 | 251 | pub fn init<R: Runtime>() -> TauriPlugin<R> { |
245 | | - Builder::default().build() |
| 252 | + Builder::new().build() |
246 | 253 | } |
247 | 254 |
|
248 | | -#[derive(Default)] |
249 | | -pub struct Builder { |
| 255 | +/// Struct to provide concrete type for the manager |
| 256 | +struct RequestCallback<R: Runtime>( |
| 257 | + Box<dyn Fn(&mut Request<()>, &AppHandle<R>) + Send + Sync + 'static>, |
| 258 | +); |
| 259 | + |
| 260 | +pub struct Builder<R: Runtime> { |
250 | 261 | tls_connector: Option<Connector>, |
| 262 | + merge_headers: Option<RequestCallback<R>>, |
251 | 263 | } |
252 | 264 |
|
253 | | -impl Builder { |
| 265 | +impl<R> Builder<R> |
| 266 | +where |
| 267 | + R: Runtime, |
| 268 | +{ |
254 | 269 | pub fn new() -> Self { |
255 | 270 | Self { |
256 | 271 | tls_connector: None, |
| 272 | + merge_headers: None, |
257 | 273 | } |
258 | 274 | } |
259 | 275 |
|
| 276 | + /// add a callback which is able to modify the initial headers of the http upgrade request. |
| 277 | + /// This is useful for scenarios where the frontend may not know all the required headers that must be sent. |
| 278 | + /// e.g. in the scenario of http-only cookies |
| 279 | + pub fn merge_header_callback( |
| 280 | + mut self, |
| 281 | + cb: Box<dyn Fn(&mut Request<()>, &AppHandle<R>) + Send + Sync + 'static>, |
| 282 | + ) -> Self { |
| 283 | + self.merge_headers.replace(RequestCallback(cb)); |
| 284 | + self |
| 285 | + } |
| 286 | + |
260 | 287 | pub fn tls_connector(mut self, connector: Connector) -> Self { |
261 | 288 | self.tls_connector.replace(connector); |
262 | 289 | self |
263 | 290 | } |
264 | 291 |
|
265 | | - pub fn build<R: Runtime>(self) -> TauriPlugin<R> { |
| 292 | + pub fn build(self) -> TauriPlugin<R> { |
266 | 293 | PluginBuilder::new("websocket") |
267 | 294 | .invoke_handler(tauri::generate_handler![connect, send]) |
268 | 295 | .setup(|app, _api| { |
269 | 296 | app.manage(ConnectionManager::default()); |
| 297 | + |
| 298 | + if let Some(cb) = self.merge_headers { |
| 299 | + app.manage(cb); |
| 300 | + } |
| 301 | + |
270 | 302 | #[cfg(any(feature = "rustls-tls", feature = "native-tls"))] |
271 | 303 | app.manage(TlsConnector(Mutex::new(self.tls_connector))); |
272 | 304 | Ok(()) |
|
0 commit comments