Skip to content

Commit

Permalink
Add parse_args
Browse files Browse the repository at this point in the history
This combines the argument parsing logic from the request handler
into a new function, in order to reduce the redundancy of the
argument parsing logic.

`parse_args` and `send_response` are declared as `pub` so that it can
be used by crates implementing custom request handlers.
  • Loading branch information
msk committed Apr 4, 2024
1 parent adb96ff commit 7264b52
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 30 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]

### Added

- `request::parse_args` to parse arguments for a request handler.

### Changed

- `SendError::MessageTooLarge` no longer contains the underlying error,
Expand Down
58 changes: 28 additions & 30 deletions src/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,6 @@ pub async fn handle<H: Handler>(
recv: &mut RecvStream,
) -> Result<(), HandlerError> {
let mut buf = Vec::new();
let codec = bincode::DefaultOptions::new();
loop {
let (code, body) = match message::recv_request_raw(recv, &mut buf).await {
Ok(res) => res,
Expand All @@ -232,19 +231,15 @@ pub async fn handle<H: Handler>(
send_response(send, &mut buf, handler.reload_config().await).await?;
}
RequestCode::ReloadTi => {
let version = codec
.deserialize::<&str>(body)
.map_err(frame::RecvError::DeserializationFailure)?;
let version = parse_args::<&str>(body)?;

Check warning on line 234 in src/request.rs

View check run for this annotation

Codecov / codecov/patch

src/request.rs#L234

Added line #L234 was not covered by tests
let result = handler.reload_ti(version).await;
send_response(send, &mut buf, result).await?;
}
RequestCode::ResourceUsage => {
send_response(send, &mut buf, handler.resource_usage().await).await?;
}
RequestCode::TorExitNodeList => {
let nodes = codec
.deserialize::<Vec<&str>>(body)
.map_err(frame::RecvError::DeserializationFailure)?;
let nodes = parse_args::<Vec<&str>>(body)?;

Check warning on line 242 in src/request.rs

View check run for this annotation

Codecov / codecov/patch

src/request.rs#L242

Added line #L242 was not covered by tests
let result = handler.tor_exit_node_list(&nodes).await;
send_response(send, &mut buf, result).await?;
}
Expand All @@ -257,10 +252,7 @@ pub async fn handle<H: Handler>(
send_response(send, &mut buf, result).await?;
}
RequestCode::TrustedDomainList => {
let domains = codec
.deserialize::<Result<Vec<&str>, String>>(body)
.map_err(frame::RecvError::DeserializationFailure)?;

let domains = parse_args::<Result<Vec<&str>, String>>(body)?;

Check warning on line 255 in src/request.rs

View check run for this annotation

Codecov / codecov/patch

src/request.rs#L255

Added line #L255 was not covered by tests
let result = if let Ok(domains) = domains {
handler.trusted_domain_list(&domains).await
} else {
Expand All @@ -269,50 +261,38 @@ pub async fn handle<H: Handler>(
send_response(send, &mut buf, result).await?;
}
RequestCode::InternalNetworkList => {
let network_list = codec
.deserialize::<HostNetworkGroup>(body)
.map_err(frame::RecvError::DeserializationFailure)?;
let network_list = parse_args::<HostNetworkGroup>(body)?;

Check warning on line 264 in src/request.rs

View check run for this annotation

Codecov / codecov/patch

src/request.rs#L264

Added line #L264 was not covered by tests
let result = handler.internal_network_list(network_list).await;
send_response(send, &mut buf, result).await?;
}
RequestCode::AllowList => {
let allow_list = codec
.deserialize::<HostNetworkGroup>(body)
.map_err(frame::RecvError::DeserializationFailure)?;
let allow_list = parse_args::<HostNetworkGroup>(body)?;

Check warning on line 269 in src/request.rs

View check run for this annotation

Codecov / codecov/patch

src/request.rs#L269

Added line #L269 was not covered by tests
let result = handler.allow_list(allow_list).await;
send_response(send, &mut buf, result).await?;
}
RequestCode::BlockList => {
let block_list = codec
.deserialize::<HostNetworkGroup>(body)
.map_err(frame::RecvError::DeserializationFailure)?;
let block_list = parse_args::<HostNetworkGroup>(body)?;

Check warning on line 274 in src/request.rs

View check run for this annotation

Codecov / codecov/patch

src/request.rs#L274

Added line #L274 was not covered by tests
let result = handler.block_list(block_list).await;
send_response(send, &mut buf, result).await?;
}
RequestCode::EchoRequest => {
send_response(send, &mut buf, Ok::<(), String>(())).await?;
}
RequestCode::TrustedUserAgentList => {
let user_agent_list = codec
.deserialize::<Vec<&str>>(body)
.map_err(frame::RecvError::DeserializationFailure)?;
let user_agent_list = parse_args::<Vec<&str>>(body)?;

Check warning on line 282 in src/request.rs

View check run for this annotation

Codecov / codecov/patch

src/request.rs#L282

Added line #L282 was not covered by tests
let result = handler.trusted_user_agent_list(&user_agent_list).await;
send_response(send, &mut buf, result).await?;
}
RequestCode::ReloadFilterRule => {
let rules = codec
.deserialize::<Vec<TrafficFilterRule>>(body)
.map_err(frame::RecvError::DeserializationFailure)?;
let rules = parse_args::<Vec<TrafficFilterRule>>(body)?;

Check warning on line 287 in src/request.rs

View check run for this annotation

Codecov / codecov/patch

src/request.rs#L287

Added line #L287 was not covered by tests
let result = handler.update_traffic_filter_rules(&rules).await;
send_response(send, &mut buf, result).await?;
}
RequestCode::GetConfig => {
send_response(send, &mut buf, handler.get_config().await).await?;
}
RequestCode::SetConfig => {
let conf = codec
.deserialize::<Config>(body)
.map_err(frame::RecvError::DeserializationFailure)?;
let conf = parse_args::<Config>(body)?;

Check warning on line 295 in src/request.rs

View check run for this annotation

Codecov / codecov/patch

src/request.rs#L295

Added line #L295 was not covered by tests
let result = handler.set_config(conf).await;
send_response(send, &mut buf, result).await?;
}
Expand All @@ -335,7 +315,25 @@ pub async fn handle<H: Handler>(
Ok(())
}

async fn send_response<T: Serialize>(
/// Parses the arguments of a request.
///
/// # Errors
///
/// Returns `frame::RecvError::DeserializationFailure`: if the arguments could
/// not be deserialized.
pub fn parse_args<'de, T: Deserialize<'de>>(args: &'de [u8]) -> Result<T, frame::RecvError> {
bincode::DefaultOptions::new()
.deserialize::<T>(args)
.map_err(frame::RecvError::DeserializationFailure)
}

Check warning on line 328 in src/request.rs

View check run for this annotation

Codecov / codecov/patch

src/request.rs#L324-L328

Added lines #L324 - L328 were not covered by tests

/// Sends a response to a request.
///
/// # Errors
///
/// * `SendError::MessageTooLarge` if `e` is too large to be serialized
/// * `SendError::WriteError` if the message could not be written
pub async fn send_response<T: Serialize>(

Check warning on line 336 in src/request.rs

View check run for this annotation

Codecov / codecov/patch

src/request.rs#L336

Added line #L336 was not covered by tests
send: &mut SendStream,
buf: &mut Vec<u8>,
body: T,
Expand Down

0 comments on commit 7264b52

Please sign in to comment.