Skip to content

Commit a5f5632

Browse files
committed
feat(core): allow routing traffic to actors via path
1 parent 323a773 commit a5f5632

File tree

9 files changed

+941
-80
lines changed

9 files changed

+941
-80
lines changed

engine/packages/guard/src/routing/mod.rs

Lines changed: 122 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,13 @@ pub(crate) const SEC_WEBSOCKET_PROTOCOL: HeaderName =
1717
HeaderName::from_static("sec-websocket-protocol");
1818
pub(crate) const WS_PROTOCOL_TARGET: &str = "rivet_target.";
1919

20+
#[derive(Debug, Clone)]
21+
pub struct ActorPathInfo {
22+
pub actor_id: String,
23+
pub token: Option<String>,
24+
pub remaining_path: String,
25+
}
26+
2027
/// Creates the main routing function that handles all incoming requests
2128
#[tracing::instrument(skip_all)]
2229
pub fn create_routing_function(ctx: StandaloneCtx, shared_state: SharedState) -> RoutingFn {
@@ -35,17 +42,35 @@ pub fn create_routing_function(ctx: StandaloneCtx, shared_state: SharedState) ->
3542

3643
tracing::debug!("Routing request for hostname: {host}, path: {path}");
3744

38-
// Parse query parameters
39-
let query_params = parse_query_params(path);
40-
4145
// Check if this is a WebSocket upgrade request
4246
let is_websocket = headers
4347
.get("upgrade")
4448
.and_then(|v| v.to_str().ok())
4549
.map(|v| v.eq_ignore_ascii_case("websocket"))
4650
.unwrap_or(false);
4751

48-
// Extract target from WebSocket protocol, HTTP header, or query param
52+
// First, check if this is an actor path-based route
53+
if let Some(actor_path_info) = parse_actor_path(path) {
54+
tracing::debug!(?actor_path_info, "routing using path-based actor routing");
55+
56+
// Route to pegboard gateway with the extracted information
57+
if let Some(routing_output) = pegboard_gateway::route_request_path_based(
58+
&ctx,
59+
&shared_state,
60+
&actor_path_info.actor_id,
61+
actor_path_info.token.as_deref(),
62+
&actor_path_info.remaining_path,
63+
headers,
64+
is_websocket,
65+
)
66+
.await?
67+
{
68+
return Ok(routing_output);
69+
}
70+
}
71+
72+
// Fallback to header-based routing
73+
// Extract target from WebSocket protocol or HTTP header
4974
let target = if is_websocket {
5075
// For WebSocket, parse the sec-websocket-protocol header
5176
headers
@@ -58,21 +83,15 @@ pub fn create_routing_function(ctx: StandaloneCtx, shared_state: SharedState) ->
5883
.map(|p| p.trim())
5984
.find_map(|p| p.strip_prefix(WS_PROTOCOL_TARGET))
6085
})
61-
// Fallback to query parameter if protocol not provided
62-
.or_else(|| query_params.get("x_rivet_target").map(|s| s.as_str()))
6386
} else {
64-
// For HTTP, use the x-rivet-target header, fallback to query param
65-
headers
66-
.get(X_RIVET_TARGET)
67-
.and_then(|x| x.to_str().ok())
68-
.or_else(|| query_params.get("x_rivet_target").map(|s| s.as_str()))
87+
// For HTTP, use the x-rivet-target header
88+
headers.get(X_RIVET_TARGET).and_then(|x| x.to_str().ok())
6989
};
7090

7191
// Read target
7292
if let Some(target) = target {
7393
if let Some(routing_output) =
74-
runner::route_request(&ctx, target, host, path, headers, &query_params)
75-
.await?
94+
runner::route_request(&ctx, target, host, path, headers).await?
7695
{
7796
return Ok(routing_output);
7897
}
@@ -85,7 +104,6 @@ pub fn create_routing_function(ctx: StandaloneCtx, shared_state: SharedState) ->
85104
path,
86105
headers,
87106
is_websocket,
88-
&query_params,
89107
)
90108
.await?
91109
{
@@ -120,18 +138,98 @@ pub fn create_routing_function(ctx: StandaloneCtx, shared_state: SharedState) ->
120138
)
121139
}
122140

123-
/// Parse query parameters from a path string
124-
fn parse_query_params(path: &str) -> std::collections::HashMap<String, String> {
125-
let mut params = std::collections::HashMap::new();
141+
/// Parse actor routing information from path
142+
/// Matches patterns:
143+
/// - /gateway/actors/{actor_id}/tokens/{token}/route/{...path}
144+
/// - /gateway/actors/{actor_id}/route/{...path}
145+
pub fn parse_actor_path(path: &str) -> Option<ActorPathInfo> {
146+
// Find query string position (everything from ? onwards, but before fragment)
147+
let query_pos = path.find('?');
148+
let fragment_pos = path.find('#');
149+
150+
// Extract query string (excluding fragment)
151+
let query_string = match (query_pos, fragment_pos) {
152+
(Some(q), Some(f)) if q < f => &path[q..f],
153+
(Some(q), None) => &path[q..],
154+
_ => "",
155+
};
156+
157+
// Extract base path (before query and fragment)
158+
let base_path = match query_pos {
159+
Some(pos) => &path[..pos],
160+
None => match fragment_pos {
161+
Some(pos) => &path[..pos],
162+
None => path,
163+
},
164+
};
165+
166+
// Check for double slashes (invalid path)
167+
if base_path.contains("//") {
168+
return None;
169+
}
170+
171+
// Split the path into segments
172+
let segments: Vec<&str> = base_path.split('/').filter(|s| !s.is_empty()).collect();
173+
174+
// Check minimum required segments: gateway, actors, {actor_id}, route
175+
if segments.len() < 4 {
176+
return None;
177+
}
178+
179+
// Verify the fixed segments
180+
if segments[0] != "gateway" || segments[1] != "actors" {
181+
return None;
182+
}
183+
184+
// Check for empty actor_id
185+
if segments[2].is_empty() {
186+
return None;
187+
}
126188

127-
if let Some(query_start) = path.find('?') {
128-
// Strip fragment if present
129-
let query = &path[query_start + 1..].split('#').next().unwrap_or("");
130-
// Use url::form_urlencoded to properly decode query parameters
131-
for (key, value) in url::form_urlencoded::parse(query.as_bytes()) {
132-
params.insert(key.into_owned(), value.into_owned());
189+
let actor_id = segments[2].to_string();
190+
191+
// Check for token or direct route
192+
let (token, remaining_path_start_idx) =
193+
if segments.len() >= 6 && segments[3] == "tokens" && segments[5] == "route" {
194+
// Pattern with token: /gateway/actors/{actor_id}/tokens/{token}/route/{...path}
195+
// Check for empty token
196+
if segments[4].is_empty() {
197+
return None;
198+
}
199+
(Some(segments[4].to_string()), 6)
200+
} else if segments.len() >= 4 && segments[3] == "route" {
201+
// Pattern without token: /gateway/actors/{actor_id}/route/{...path}
202+
(None, 4)
203+
} else {
204+
return None;
205+
};
206+
207+
// Calculate the position in the original path where remaining path starts
208+
let mut prefix_len = 0;
209+
for (i, segment) in segments.iter().enumerate() {
210+
if i >= remaining_path_start_idx {
211+
break;
133212
}
213+
prefix_len += 1 + segment.len(); // +1 for the slash
134214
}
135215

136-
params
216+
// Extract the remaining path preserving trailing slashes
217+
let remaining_base = if prefix_len < base_path.len() {
218+
&base_path[prefix_len..]
219+
} else {
220+
"/"
221+
};
222+
223+
// Ensure remaining path starts with /
224+
let remaining_path = if remaining_base.is_empty() || !remaining_base.starts_with('/') {
225+
format!("/{}{}", remaining_base, query_string)
226+
} else {
227+
format!("{}{}", remaining_base, query_string)
228+
};
229+
230+
Some(ActorPathInfo {
231+
actor_id,
232+
token,
233+
remaining_path,
234+
})
137235
}

engine/packages/guard/src/routing/pegboard_gateway.rs

Lines changed: 37 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,30 @@ use crate::{errors, shared_state::SharedState};
1010

1111
const ACTOR_READY_TIMEOUT: Duration = Duration::from_secs(10);
1212
pub const X_RIVET_ACTOR: HeaderName = HeaderName::from_static("x-rivet-actor");
13+
pub const X_RIVET_AMESPACE: HeaderName = HeaderName::from_static("x-rivet-namespace");
1314
const WS_PROTOCOL_ACTOR: &str = "rivet_actor.";
15+
const WS_PROTOCOL_TOKEN: &str = "rivet_token.";
1416

15-
/// Route requests to actor services based on hostname and path
17+
/// Route requests to actor services using path-based routing
18+
#[tracing::instrument(skip_all)]
19+
pub async fn route_request_path_based(
20+
ctx: &StandaloneCtx,
21+
shared_state: &SharedState,
22+
actor_id_str: &str,
23+
_token: Option<&str>,
24+
path: &str,
25+
_headers: &hyper::HeaderMap,
26+
_is_websocket: bool,
27+
) -> Result<Option<RoutingOutput>> {
28+
// NOTE: Token validation implemented in EE
29+
30+
// Parse actor ID
31+
let actor_id = Id::parse(actor_id_str).context("invalid actor id in path")?;
32+
33+
route_request_inner(ctx, shared_state, actor_id, path).await
34+
}
35+
36+
/// Route requests to actor services based on headers
1637
#[tracing::instrument(skip_all)]
1738
pub async fn route_request(
1839
ctx: &StandaloneCtx,
@@ -22,14 +43,13 @@ pub async fn route_request(
2243
path: &str,
2344
headers: &hyper::HeaderMap,
2445
is_websocket: bool,
25-
query_params: &std::collections::HashMap<String, String>,
2646
) -> Result<Option<RoutingOutput>> {
2747
// Check target
2848
if target != "actor" {
2949
return Ok(None);
3050
}
3151

32-
// Extract actor ID from WebSocket protocol, HTTP header, or query param
52+
// Extract actor ID from WebSocket protocol or HTTP header
3353
let actor_id_str = if is_websocket {
3454
// For WebSocket, parse the sec-websocket-protocol header
3555
headers
@@ -42,26 +62,22 @@ pub async fn route_request(
4262
.map(|p| p.trim())
4363
.find_map(|p| p.strip_prefix(WS_PROTOCOL_ACTOR))
4464
})
45-
// Fallback to query parameter if protocol not provided
46-
.or_else(|| query_params.get("x_rivet_actor").map(|s| s.as_str()))
4765
.ok_or_else(|| {
4866
crate::errors::MissingHeader {
49-
header: "`rivet_actor.*` protocol in sec-websocket-protocol or x_rivet_actor query parameter".to_string(),
67+
header: "`rivet_actor.*` protocol in sec-websocket-protocol".to_string(),
5068
}
5169
.build()
5270
})?
5371
} else {
54-
// For HTTP, use the x-rivet-actor header, fallback to query param
72+
// For HTTP, use the x-rivet-actor header
5573
headers
5674
.get(X_RIVET_ACTOR)
5775
.map(|x| x.to_str())
5876
.transpose()
5977
.context("invalid x-rivet-actor header")?
60-
// Fallback to query parameter if header not provided
61-
.or_else(|| query_params.get("x_rivet_actor").map(|s| s.as_str()))
6278
.ok_or_else(|| {
6379
crate::errors::MissingHeader {
64-
header: format!("{} header or x_rivet_actor query parameter", X_RIVET_ACTOR),
80+
header: X_RIVET_ACTOR.to_string(),
6581
}
6682
.build()
6783
})?
@@ -70,6 +86,15 @@ pub async fn route_request(
7086
// Find actor to route to
7187
let actor_id = Id::parse(actor_id_str).context("invalid x-rivet-actor header")?;
7288

89+
route_request_inner(ctx, shared_state, actor_id, path).await
90+
}
91+
92+
async fn route_request_inner(
93+
ctx: &StandaloneCtx,
94+
shared_state: &SharedState,
95+
actor_id: Id,
96+
path: &str,
97+
) -> Result<Option<RoutingOutput>> {
7398
// Route to peer dc where the actor lives
7499
if actor_id.label() != ctx.config().dc_label() {
75100
tracing::debug!(peer_dc_label=?actor_id.label(), "re-routing actor to peer dc");
@@ -153,11 +178,12 @@ pub async fn route_request(
153178

154179
tracing::debug!(?actor_id, ?runner_id, "actor ready");
155180

156-
// Return pegboard-gateway instance
181+
// Return pegboard-gateway instance with path
157182
let gateway = pegboard_gateway::PegboardGateway::new(
158183
shared_state.pegboard_gateway.clone(),
159184
runner_id,
160185
actor_id,
186+
path.to_string(),
161187
);
162188
Ok(Some(RoutingOutput::CustomServe(std::sync::Arc::new(
163189
gateway,

engine/packages/guard/src/routing/runner.rs

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ pub async fn route_request(
1414
host: &str,
1515
path: &str,
1616
headers: &hyper::HeaderMap,
17-
query_params: &std::collections::HashMap<String, String>,
1817
) -> Result<Option<RoutingOutput>> {
1918
if target != "runner" {
2019
return Ok(None);
@@ -58,7 +57,7 @@ pub async fn route_request(
5857

5958
// Check auth (if enabled)
6059
if let Some(auth) = &ctx.config().auth {
61-
// Extract token from protocol, header, or query param
60+
// Extract token from protocol or header
6261
let token = if is_websocket {
6362
headers
6463
.get(SEC_WEBSOCKET_PROTOCOL)
@@ -69,26 +68,19 @@ pub async fn route_request(
6968
.map(|p| p.trim())
7069
.find_map(|p| p.strip_prefix(WS_PROTOCOL_TOKEN))
7170
})
72-
// Fallback to query parameter if protocol not provided
73-
.or_else(|| query_params.get("x_rivet_token").map(|s| s.as_str()))
7471
.ok_or_else(|| {
7572
crate::errors::MissingHeader {
76-
header: "`rivet_token.*` protocol in sec-websocket-protocol or x_rivet_token query parameter".to_string(),
73+
header: "`rivet_token.*` protocol in sec-websocket-protocol".to_string(),
7774
}
7875
.build()
7976
})?
8077
} else {
8178
headers
8279
.get(X_RIVET_TOKEN)
8380
.and_then(|x| x.to_str().ok())
84-
// Fallback to query parameter if header not provided
85-
.or_else(|| query_params.get("x_rivet_token").map(|s| s.as_str()))
8681
.ok_or_else(|| {
8782
crate::errors::MissingHeader {
88-
header: format!(
89-
"{} header or x_rivet_token query parameter",
90-
X_RIVET_TOKEN
91-
),
83+
header: X_RIVET_TOKEN.to_string(),
9284
}
9385
.build()
9486
})?

0 commit comments

Comments
 (0)