Skip to content

Commit 2465565

Browse files
committed
add merge headers method
1 parent 5092f01 commit 2465565

File tree

1 file changed

+44
-12
lines changed

1 file changed

+44
-12
lines changed

plugins/websocket/src/lib.rs

Lines changed: 44 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,30 +9,33 @@
99
html_favicon_url = "https://github.com/tauri-apps/tauri/raw/dev/app-icon.png"
1010
)]
1111

12-
use futures_util::{stream::SplitSink, SinkExt, StreamExt};
13-
use http::header::{HeaderName, HeaderValue};
14-
use serde::{ser::Serializer, Deserialize, Serialize};
12+
use futures_util::{SinkExt, StreamExt, stream::SplitSink};
13+
use http::{
14+
HeaderMap, Request,
15+
header::{HeaderName, HeaderValue},
16+
};
17+
use serde::{Deserialize, Serialize, ser::Serializer};
1518
use tauri::{
19+
AppHandle, Manager, Runtime, State, Url, Window,
1620
ipc::Channel,
1721
plugin::{Builder as PluginBuilder, TauriPlugin},
18-
Manager, Runtime, State, Window,
1922
};
2023
use tokio::{net::TcpStream, sync::Mutex};
2124
#[cfg(any(feature = "rustls-tls", feature = "native-tls"))]
2225
use tokio_tungstenite::connect_async_tls_with_config;
2326
#[cfg(not(any(feature = "rustls-tls", feature = "native-tls")))]
2427
use tokio_tungstenite::connect_async_with_config;
2528
use tokio_tungstenite::{
29+
Connector, MaybeTlsStream, WebSocketStream,
2630
tungstenite::{
31+
Message,
2732
client::IntoClientRequest,
2833
protocol::{CloseFrame as ProtocolCloseFrame, WebSocketConfig},
29-
Message,
3034
},
31-
Connector, MaybeTlsStream, WebSocketStream,
3235
};
3336

34-
use std::collections::HashMap;
3537
use std::str::FromStr;
38+
use std::{collections::HashMap, marker::PhantomData};
3639

3740
type Id = u32;
3841
type WebSocket = WebSocketStream<MaybeTlsStream<TcpStream>>;
@@ -157,6 +160,10 @@ async fn connect<R: Runtime>(
157160
}
158161
}
159162

163+
if let Some(state) = window.app_handle().try_state::<RequestCallback<R>>() {
164+
(state.inner().0)(&mut request, window.app_handle());
165+
}
166+
160167
#[cfg(any(feature = "rustls-tls", feature = "native-tls"))]
161168
let tls_connector = match window.try_state::<TlsConnector>() {
162169
Some(tls_connector) => tls_connector.0.lock().await.clone(),
@@ -242,31 +249,56 @@ async fn send(
242249
}
243250

244251
pub fn init<R: Runtime>() -> TauriPlugin<R> {
245-
Builder::default().build()
252+
Builder::new().build()
246253
}
247254

248-
#[derive(Default)]
249-
pub struct Builder {
255+
/// Struct to provide concrete type for the manager
256+
struct RequestCallback<R: Runtime>(
257+
Box<dyn Fn(&mut Request<()>, &AppHandle<R>) + Send + Sync + 'static>,
258+
);
259+
260+
pub struct Builder<R: Runtime> {
250261
tls_connector: Option<Connector>,
262+
merge_headers: Option<RequestCallback<R>>,
251263
}
252264

253-
impl Builder {
265+
impl<R> Builder<R>
266+
where
267+
R: Runtime,
268+
{
254269
pub fn new() -> Self {
255270
Self {
256271
tls_connector: None,
272+
merge_headers: None,
257273
}
258274
}
259275

276+
/// add a callback which is able to modify the initial headers of the http upgrade request.
277+
/// This is useful for scenarios where the frontend may not know all the required headers that must be sent.
278+
/// e.g. in the scenario of http-only cookies
279+
pub fn merge_header_callback(
280+
mut self,
281+
cb: Box<dyn Fn(&mut Request<()>, &AppHandle<R>) + Send + Sync + 'static>,
282+
) -> Self {
283+
self.merge_headers.replace(RequestCallback(cb));
284+
self
285+
}
286+
260287
pub fn tls_connector(mut self, connector: Connector) -> Self {
261288
self.tls_connector.replace(connector);
262289
self
263290
}
264291

265-
pub fn build<R: Runtime>(self) -> TauriPlugin<R> {
292+
pub fn build(self) -> TauriPlugin<R> {
266293
PluginBuilder::new("websocket")
267294
.invoke_handler(tauri::generate_handler![connect, send])
268295
.setup(|app, _api| {
269296
app.manage(ConnectionManager::default());
297+
298+
if let Some(cb) = self.merge_headers {
299+
app.manage(cb);
300+
}
301+
270302
#[cfg(any(feature = "rustls-tls", feature = "native-tls"))]
271303
app.manage(TlsConnector(Mutex::new(self.tls_connector)));
272304
Ok(())

0 commit comments

Comments
 (0)