Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions common/src/fetcher/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ mod data;
use backon::{ExponentialBuilder, Retryable};
pub use data::*;

use crate::http::calculate_retry_after_from_response_header;
use reqwest::{Client, ClientBuilder, IntoUrl, Method, Response};
use crate::http::{calculate_retry_after_from_response_header, get_client_error};
use reqwest::{Client, ClientBuilder, IntoUrl, Method, Response, StatusCode};
use std::fmt::Debug;
use std::future::Future;
use std::marker::PhantomData;
Expand All @@ -31,6 +31,8 @@ pub enum Error {
Request(#[from] reqwest::Error),
#[error("Rate limited (HTTP 429), retry after {0:?}")]
RateLimited(Duration),
#[error("Client error: {0}")]
ClientError(StatusCode),
}

/// Options for the [`Fetcher`]
Expand Down Expand Up @@ -144,6 +146,7 @@ impl Fetcher {

(|| async { self.fetch_once(url.clone(), &processor).await })
.retry(retry)
.when(|e| !matches!(e, Error::ClientError(_)))
.adjust(|e, dur| {
if let Error::RateLimited(retry_after) = e {
if let Some(dur_value) = dur
Expand Down Expand Up @@ -175,6 +178,10 @@ impl Fetcher {
log::info!("Rate limited (429), retry after: {:?}", retry_after);
return Err(Error::RateLimited(retry_after));
}
if let Some(status_code) = get_client_error(&response) {
log::info!("Client error: {}", status_code);
return Err(Error::ClientError(status_code));
}

Ok(processor.process(response).await?)
}
Expand Down
9 changes: 9 additions & 0 deletions common/src/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,12 @@ pub fn calculate_retry_after_from_response_header(
}
None
}

pub fn get_client_error(response: &Response) -> Option<StatusCode> {
let status = response.status();
if status.is_client_error() {
Some(status)
} else {
None
}
}
26 changes: 25 additions & 1 deletion common/tests/fetcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Duration;
use tokio::net::TcpListener;
use walker_common::fetcher::{Fetcher, FetcherOptions};
use walker_common::fetcher::{Error, Fetcher, FetcherOptions};

/// Test helper to start a mock HTTP server
async fn start_mock_server<F>(handler: F) -> String
Expand Down Expand Up @@ -60,6 +60,30 @@ async fn test_successful_fetch() {
assert_eq!(result, "Hello, World!");
}

#[tokio::test]
async fn test_404_should_not_retry() {
let attempt_count = Arc::new(AtomicUsize::new(0));
let attempt_count_clone = attempt_count.clone();

let server = start_mock_server(move |_req| {
attempt_count_clone.fetch_add(1, Ordering::SeqCst);
let builder = hyper::Response::builder().status(StatusCode::NOT_FOUND);
builder.body("Not found".to_string()).unwrap()
})
.await;

let fetcher = Fetcher::new(FetcherOptions::new().retries(2))
.await
.unwrap();

let result: Result<String, Error> = fetcher.fetch(&server).await;
match result {
Err(Error::ClientError(code)) => assert_eq!(code, StatusCode::NOT_FOUND),
other => panic!("expected ClientError(404), got {other:?}"),
}
assert_eq!(attempt_count.load(Ordering::SeqCst), 1);
}

#[rstest]
#[case::with_retry_after_header(Some("1"), 1)]
#[case::without_retry_after_header(None, 10)]
Expand Down