diff --git a/crates/wasi-experimental-http-wasmtime/src/lib.rs b/crates/wasi-experimental-http-wasmtime/src/lib.rs index 8c7a415..7cbd0b1 100644 --- a/crates/wasi-experimental-http-wasmtime/src/lib.rs +++ b/crates/wasi-experimental-http-wasmtime/src/lib.rs @@ -13,6 +13,7 @@ use url::Url; use wasmtime::*; const MEMORY: &str = "memory"; +const ALLOW_ALL_HOSTS: &str = "insecure:allow-all"; pub type WasiHttpHandle = u32; @@ -481,7 +482,7 @@ impl HttpState { }; let ctx = caller.as_context_mut(); - let http_ctx = get_cx(&mut ctx.data()); + let http_ctx = get_cx(ctx.data()); match HostCalls::req( st.clone(), @@ -619,13 +620,18 @@ fn is_allowed(url: &str, allowed_hosts: Option<&[String]>) -> Result { - let allowed: Result, _> = domains.iter().map(|d| Url::parse(d)).collect(); - let allowed = allowed.map_err(|_| HttpError::InvalidUrl)?; - - Ok(allowed - .iter() - .map(|u| u.host_str().unwrap()) - .any(|x| x == url_host.as_str())) + // check domains has any "insecure:allow-all" wildcard + if domains.iter().any(|domain| domain == ALLOW_ALL_HOSTS) { + Ok(true) + } else { + let allowed: Result, _> = domains.iter().map(|d| Url::parse(d)).collect(); + let allowed = allowed.map_err(|_| HttpError::InvalidUrl)?; + + Ok(allowed + .iter() + .map(|u| u.host_str().unwrap()) + .any(|x| x == url_host.as_str())) + } } None => Ok(false), } @@ -711,3 +717,47 @@ fn test_allowed_domains() { is_allowed("https://test.brigade.sh", Some(allowed_domains.as_ref())).unwrap() ); } + +#[test] +#[allow(clippy::bool_assert_comparison)] +fn test_allowed_domains_with_wildcard() { + let allowed_domains = vec![ + "https://example.com".to_string(), + ALLOW_ALL_HOSTS.to_string(), + "http://192.168.0.1".to_string(), + ]; + + assert_eq!( + true, + is_allowed( + "https://api.brigade.sh/healthz", + Some(allowed_domains.as_ref()) + ) + .unwrap() + ); + assert_eq!( + true, + is_allowed( + "https://example.com/some/path/with/more/paths", + Some(allowed_domains.as_ref()) + ) + .unwrap() + ); + assert_eq!( + true, + is_allowed("http://192.168.0.1/login", Some(allowed_domains.as_ref())).unwrap() + ); + assert_eq!( + true, + is_allowed("https://test.brigade.sh", Some(allowed_domains.as_ref())).unwrap() + ); +} + +#[test] +#[should_panic] +#[allow(clippy::bool_assert_comparison)] +fn test_url_parsing() { + let allowed_domains = vec![ALLOW_ALL_HOSTS.to_string()]; + + is_allowed("not even a url", Some(allowed_domains.as_ref())).unwrap(); +} diff --git a/tests/as/index.ts b/tests/as/index.ts index 079fdd3..d2951f6 100644 --- a/tests/as/index.ts +++ b/tests/as/index.ts @@ -16,7 +16,7 @@ export function post(): void { } export function get(): void { - let res = new RequestBuilder("https://api.brigade.sh/healthz") + let res = new RequestBuilder("https://some-random-api.ml/facts/dog") .method(Method.GET) .send(); @@ -37,7 +37,7 @@ export function concurrent(): void { } function makeReq(): Response { - return new RequestBuilder("https://api.brigade.sh/healthz") + return new RequestBuilder("https://some-random-api.ml/facts/dog") .method(Method.GET) .send(); } diff --git a/tests/integration.rs b/tests/integration.rs index 69c11e0..dfdc27c 100644 --- a/tests/integration.rs +++ b/tests/integration.rs @@ -7,6 +7,8 @@ mod tests { use wasmtime_wasi::sync::WasiCtxBuilder; use wasmtime_wasi::*; + const ALLOW_ALL_HOSTS: &str = "insecure:allow-all"; + // We run the same test in a Tokio and non-Tokio environment // in order to make sure both scenarios are working. @@ -38,7 +40,7 @@ mod tests { fn test_with_allowed_domains() { setup_tests( Some(vec![ - "https://api.brigade.sh".to_string(), + "https://some-random-api.ml".to_string(), "https://postman-echo.com".to_string(), ]), None, @@ -49,13 +51,23 @@ mod tests { async fn test_async_with_allowed_domains() { setup_tests( Some(vec![ - "https://api.brigade.sh".to_string(), + "https://some-random-api.ml".to_string(), "https://postman-echo.com".to_string(), ]), None, ); } + #[test] + fn test_with_wildcard_domain() { + setup_tests(Some(vec![ALLOW_ALL_HOSTS.to_string()]), None); + } + + #[tokio::test(flavor = "multi_thread")] + async fn test_async_with_wildcard_domain() { + setup_tests(Some(vec![ALLOW_ALL_HOSTS.to_string()]), None); + } + #[test] #[should_panic] fn test_concurrent_requests_rust() { @@ -80,7 +92,7 @@ mod tests { let func = "concurrent"; let (instance, mut store) = create_instance( module, - Some(vec!["https://api.brigade.sh".to_string()]), + Some(vec!["https://some-random-api.ml".to_string()]), Some(2), ) .unwrap(); diff --git a/tests/rust/src/lib.rs b/tests/rust/src/lib.rs index 3a24418..f9d8dea 100644 --- a/tests/rust/src/lib.rs +++ b/tests/rust/src/lib.rs @@ -2,13 +2,13 @@ use bytes::Bytes; #[no_mangle] pub extern "C" fn get() { - let url = "https://api.brigade.sh/healthz".to_string(); + let url = "https://some-random-api.ml/facts/dog".to_string(); let req = http::request::Builder::new().uri(&url).body(None).unwrap(); let mut res = wasi_experimental_http::request(req).expect("cannot make get request"); let str = std::str::from_utf8(&res.body_read_all().unwrap()) .unwrap() .to_string(); - assert_eq!(str, r#""#); + assert_eq!(str.is_empty(), false); assert_eq!(res.status_code, 200); assert!(!res .header_get("content-type".to_string()) @@ -47,7 +47,7 @@ pub extern "C" fn post() { #[allow(unused_variables)] #[no_mangle] pub extern "C" fn concurrent() { - let url = "https://api.brigade.sh/healthz".to_string(); + let url = "https://some-random-api.ml/facts/dog".to_string(); // the responses are unused to avoid dropping them. let req1 = make_req(url.clone()); let req2 = make_req(url.clone());