Skip to content

Commit fee8b7b

Browse files
committed
chore(guard): route target via websocket protocols
1 parent 5cb4a7b commit fee8b7b

File tree

7 files changed

+106
-37
lines changed

7 files changed

+106
-37
lines changed

packages/core/guard/core/src/proxy_service.rs

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1169,7 +1169,6 @@ impl ProxyService {
11691169
}
11701170

11711171
// Handle WebSocket upgrade properly with hyper_tungstenite
1172-
// First, upgrade the client connection
11731172
tracing::debug!("Upgrading client connection to WebSocket");
11741173
let (client_response, client_websocket) = match hyper_tungstenite::upgrade(req, None) {
11751174
Result::Ok(x) => {
@@ -1928,7 +1927,15 @@ impl ProxyService {
19281927
// structure but convert it to our expected return type without modifying its content
19291928
tracing::debug!("Returning WebSocket upgrade response to client");
19301929
// Extract the parts from the response but preserve all headers and status
1931-
let (parts, _) = client_response.into_parts();
1930+
let (mut parts, _) = client_response.into_parts();
1931+
1932+
// Add Sec-WebSocket-Protocol header to the response
1933+
// Many WebSocket clients (e.g. node-ws & Cloudflare) require a protocol in the response
1934+
parts.headers.insert(
1935+
"sec-websocket-protocol",
1936+
hyper::header::HeaderValue::from_static("rivet"),
1937+
);
1938+
19321939
// Create a new response with an empty body - WebSocket upgrades don't need a body
19331940
Ok(Response::from_parts(
19341941
parts,

packages/core/guard/server/src/routing/mod.rs

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@ pub mod pegboard_gateway;
1212
mod runner;
1313

1414
pub(crate) const X_RIVET_TARGET: HeaderName = HeaderName::from_static("x-rivet-target");
15+
pub(crate) const SEC_WEBSOCKET_PROTOCOL: HeaderName =
16+
HeaderName::from_static("sec-websocket-protocol");
17+
pub(crate) const WS_PROTOCOL_TARGET: &str = "rivet_target.";
1518

1619
/// Creates the main routing function that handles all incoming requests
1720
#[tracing::instrument(skip_all)]
@@ -31,9 +34,34 @@ pub fn create_routing_function(ctx: StandaloneCtx, shared_state: SharedState) ->
3134

3235
tracing::debug!("Routing request for hostname: {host}, path: {path}");
3336

37+
// Check if this is a WebSocket upgrade request
38+
let is_websocket = headers
39+
.get("upgrade")
40+
.and_then(|v| v.to_str().ok())
41+
.map(|v| v.eq_ignore_ascii_case("websocket"))
42+
.unwrap_or(false);
43+
44+
// Extract target from WebSocket protocol or HTTP header
45+
let target = if is_websocket {
46+
// For WebSocket, parse the sec-websocket-protocol header
47+
headers
48+
.get(SEC_WEBSOCKET_PROTOCOL)
49+
.and_then(|protocols| protocols.to_str().ok())
50+
.and_then(|protocols| {
51+
// Parse protocols to find target.{value}
52+
protocols
53+
.split(',')
54+
.map(|p| p.trim())
55+
.find(|p| p.starts_with(WS_PROTOCOL_TARGET))
56+
.map(|p| &p[WS_PROTOCOL_TARGET.len()..])
57+
})
58+
} else {
59+
// For HTTP, use the x-rivet-target header
60+
headers.get(X_RIVET_TARGET).and_then(|x| x.to_str().ok())
61+
};
62+
3463
// Read target
35-
if let Some(target) = headers.get(X_RIVET_TARGET).and_then(|x| x.to_str().ok())
36-
{
64+
if let Some(target) = target {
3765
if let Some(routing_output) =
3866
runner::route_request(&ctx, target, host, path).await?
3967
{
@@ -47,6 +75,7 @@ pub fn create_routing_function(ctx: StandaloneCtx, shared_state: SharedState) ->
4775
host,
4876
path,
4977
headers,
78+
is_websocket,
5079
)
5180
.await?
5281
{

packages/core/guard/server/src/routing/pegboard_gateway.rs

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ use crate::{errors, shared_state::SharedState};
1010

1111
const ACTOR_READY_TIMEOUT: Duration = Duration::from_secs(10);
1212
pub const X_RIVET_ACTOR: HeaderName = HeaderName::from_static("x-rivet-actor");
13+
const SEC_WEBSOCKET_PROTOCOL: HeaderName = HeaderName::from_static("sec-websocket-protocol");
14+
const WS_PROTOCOL_ACTOR: &str = "rivet_actor.";
1315

1416
/// Route requests to actor services based on hostname and path
1517
#[tracing::instrument(skip_all)]
@@ -20,20 +22,48 @@ pub async fn route_request(
2022
_host: &str,
2123
path: &str,
2224
headers: &hyper::HeaderMap,
25+
is_websocket: bool,
2326
) -> Result<Option<RoutingOutput>> {
2427
// Check target
2528
if target != "actor" {
2629
return Ok(None);
2730
}
2831

32+
// Extract actor ID from WebSocket protocol or HTTP header
33+
let actor_id_str = if is_websocket {
34+
// For WebSocket, parse the sec-websocket-protocol header
35+
headers
36+
.get(SEC_WEBSOCKET_PROTOCOL)
37+
.and_then(|protocols| protocols.to_str().ok())
38+
.and_then(|protocols| {
39+
// Parse protocols to find actor.{id}
40+
protocols
41+
.split(',')
42+
.map(|p| p.trim())
43+
.find(|p| p.starts_with(WS_PROTOCOL_ACTOR))
44+
.map(|p| &p[WS_PROTOCOL_ACTOR.len()..])
45+
})
46+
.ok_or_else(|| {
47+
crate::errors::MissingHeader {
48+
header: "actor protocol in sec-websocket-protocol".to_string(),
49+
}
50+
.build()
51+
})?
52+
} else {
53+
// For HTTP, use the x-rivet-actor header
54+
headers
55+
.get(X_RIVET_ACTOR)
56+
.and_then(|x| x.to_str().ok())
57+
.ok_or_else(|| {
58+
crate::errors::MissingHeader {
59+
header: X_RIVET_ACTOR.to_string(),
60+
}
61+
.build()
62+
})?
63+
};
64+
2965
// Find actor to route to
30-
let actor_id_str = headers.get(X_RIVET_ACTOR).ok_or_else(|| {
31-
crate::errors::MissingHeader {
32-
header: X_RIVET_ACTOR.to_string(),
33-
}
34-
.build()
35-
})?;
36-
let actor_id = Id::parse(actor_id_str.to_str()?)?;
66+
let actor_id = Id::parse(actor_id_str)?;
3767

3868
// Route to peer dc where the actor lives
3969
if actor_id.label() != ctx.config().dc_label() {

packages/core/pegboard-gateway/src/lib.rs

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use bytes::Bytes;
44
use futures_util::StreamExt;
55
use gas::prelude::*;
66
use http_body_util::{BodyExt, Full};
7-
use hyper::{Request, Response, StatusCode};
7+
use hyper::{Request, Response, StatusCode, header::HeaderName};
88
use rivet_guard_core::{
99
WebSocketHandle,
1010
custom_serve::CustomServeTrait,
@@ -22,6 +22,8 @@ use crate::shared_state::{SharedState, TunnelMessageData};
2222
pub mod shared_state;
2323

2424
const UPS_REQ_TIMEOUT: Duration = Duration::from_secs(2);
25+
const SEC_WEBSOCKET_PROTOCOL: HeaderName = HeaderName::from_static("sec-websocket-protocol");
26+
const WS_PROTOCOL_ACTOR: &str = "rivet_actor.";
2527

2628
pub struct PegboardGateway {
2729
ctx: StandaloneCtx,
@@ -94,7 +96,7 @@ impl PegboardGateway {
9496
req: Request<Full<Bytes>>,
9597
_request_context: &mut RequestContext,
9698
) -> Result<Response<ResponseBody>> {
97-
// Extract actor ID for the message
99+
// Extract actor ID for the message (HTTP requests use x-rivet-actor header)
98100
let actor_id = req
99101
.headers()
100102
.get("x-rivet-actor")
@@ -200,11 +202,19 @@ impl PegboardGateway {
200202
path: &str,
201203
_request_context: &mut RequestContext,
202204
) -> Result<()> {
203-
// Extract actor ID for the message
205+
// Extract actor ID from WebSocket protocol
204206
let actor_id = headers
205-
.get("x-rivet-actor")
206-
.context("missing x-rivet-actor")?
207-
.to_str()?
207+
.get(SEC_WEBSOCKET_PROTOCOL)
208+
.and_then(|protocols| protocols.to_str().ok())
209+
.and_then(|protocols| {
210+
// Parse protocols to find actor.{id}
211+
protocols
212+
.split(',')
213+
.map(|p| p.trim())
214+
.find(|p| p.starts_with(WS_PROTOCOL_ACTOR))
215+
.map(|p| &p[WS_PROTOCOL_ACTOR.len()..])
216+
})
217+
.context("missing actor protocol in sec-websocket-protocol")?
208218
.to_string();
209219

210220
// Extract headers

packages/infra/engine/tests/common/actors.rs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -427,20 +427,20 @@ pub async fn ping_actor_websocket_via_guard(guard_port: u16, actor_id: &str) ->
427427
"testing websocket connection to actor via guard"
428428
);
429429

430-
// Build WebSocket URL and request
430+
// Build WebSocket URL and request with protocols for routing
431431
let ws_url = format!("ws://127.0.0.1:{}/ws", guard_port);
432432
let mut request = ws_url
433433
.clone()
434434
.into_client_request()
435435
.expect("Failed to create WebSocket request");
436436

437-
// Add headers for routing through guard to actor
438-
request
439-
.headers_mut()
440-
.insert("X-Rivet-Target", "actor".parse().unwrap());
441-
request
442-
.headers_mut()
443-
.insert("X-Rivet-Actor", actor_id.parse().unwrap());
437+
// Add protocols for routing through guard to actor
438+
request.headers_mut().insert(
439+
"Sec-WebSocket-Protocol",
440+
format!("rivet, rivet_target.actor, rivet_actor.{}", actor_id)
441+
.parse()
442+
.unwrap(),
443+
);
444444

445445
// Connect to WebSocket
446446
let (ws_stream, response) = connect_async(request)

scripts/tests/actor_e2e.ts

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -60,12 +60,8 @@ function testWebSocket(actorId: string): Promise<void> {
6060

6161
console.log(`Connecting WebSocket to: ${wsUrl}`);
6262

63-
const ws = new WebSocket(wsUrl, {
64-
headers: {
65-
"X-Rivet-Target": "actor",
66-
"X-Rivet-Actor": actorId,
67-
},
68-
});
63+
const protocols = ["rivet", "rivet_target.actor", `rivet_actor.${actorId}`];
64+
const ws = new WebSocket(wsUrl, protocols);
6965

7066
let pingReceived = false;
7167
let echoReceived = false;

sdks/typescript/runner/src/mod.ts

Lines changed: 3 additions & 6 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)