diff --git a/src/functions/src/supabase_functions/_async/functions_client.py b/src/functions/src/supabase_functions/_async/functions_client.py index 38ed3c61..ea2c502c 100644 --- a/src/functions/src/supabase_functions/_async/functions_client.py +++ b/src/functions/src/supabase_functions/_async/functions_client.py @@ -1,7 +1,7 @@ from typing import Any, Dict, Literal, Optional, Union from warnings import warn -from httpx import AsyncClient, HTTPError, Response +from httpx import AsyncClient, HTTPError, Response, QueryParams from ..errors import FunctionsHttpError, FunctionsRelayError from ..utils import ( @@ -73,11 +73,16 @@ async def _request( url: str, headers: Optional[Dict[str, str]] = None, json: Optional[Dict[Any, Any]] = None, + params: Optional[QueryParams] = None, ) -> Response: response = ( - await self._client.request(method, url, data=json, headers=headers) + await self._client.request( + method, url, data=json, headers=headers, params=params + ) if isinstance(json, str) - else await self._client.request(method, url, json=json, headers=headers) + else await self._client.request( + method, url, json=json, headers=headers, params=params + ) ) try: response.raise_for_status() @@ -121,8 +126,11 @@ async def invoke( if not is_valid_str_arg(function_name): raise ValueError("function_name must a valid string value.") headers = self.headers + params = QueryParams() body = None response_type = "text/plain" + url = f"{self.url}/{function_name}" + if invoke_options is not None: headers.update(invoke_options.get("headers", {})) response_type = invoke_options.get("responseType", "text/plain") @@ -135,6 +143,8 @@ async def invoke( if region.value != "any": headers["x-region"] = region.value + # Add region as query parameter + params = params.set("forceFunctionRegion", region.value) body = invoke_options.get("body") if isinstance(body, str): @@ -143,7 +153,7 @@ async def invoke( headers["Content-Type"] = "application/json" response = await self._request( - "POST", f"{self.url}/{function_name}", headers=headers, json=body + "POST", url, headers=headers, json=body, params=params ) is_relay_error = response.headers.get("x-relay-header") diff --git a/src/functions/src/supabase_functions/_sync/functions_client.py b/src/functions/src/supabase_functions/_sync/functions_client.py index 95d30e42..8063a7d3 100644 --- a/src/functions/src/supabase_functions/_sync/functions_client.py +++ b/src/functions/src/supabase_functions/_sync/functions_client.py @@ -1,7 +1,7 @@ from typing import Any, Dict, Literal, Optional, Union from warnings import warn -from httpx import Client, HTTPError, Response +from httpx import Client, HTTPError, Response, QueryParams from ..errors import FunctionsHttpError, FunctionsRelayError from ..utils import ( @@ -73,11 +73,14 @@ def _request( url: str, headers: Optional[Dict[str, str]] = None, json: Optional[Dict[Any, Any]] = None, + params: Optional[QueryParams] = None, ) -> Response: response = ( - self._client.request(method, url, data=json, headers=headers) + self._client.request(method, url, data=json, headers=headers, params=params) if isinstance(json, str) - else self._client.request(method, url, json=json, headers=headers) + else self._client.request( + method, url, json=json, headers=headers, params=params + ) ) try: response.raise_for_status() @@ -121,8 +124,11 @@ def invoke( if not is_valid_str_arg(function_name): raise ValueError("function_name must a valid string value.") headers = self.headers + params = QueryParams() body = None response_type = "text/plain" + url = f"{self.url}/{function_name}" + if invoke_options is not None: headers.update(invoke_options.get("headers", {})) response_type = invoke_options.get("responseType", "text/plain") @@ -135,6 +141,8 @@ def invoke( if region.value != "any": headers["x-region"] = region.value + # Add region as query parameter + params = params.set("forceFunctionRegion", region.value) body = invoke_options.get("body") if isinstance(body, str): @@ -142,9 +150,7 @@ def invoke( elif isinstance(body, dict): headers["Content-Type"] = "application/json" - response = self._request( - "POST", f"{self.url}/{function_name}", headers=headers, json=body - ) + response = self._request("POST", url, headers=headers, json=body, params=params) is_relay_error = response.headers.get("x-relay-header") if is_relay_error and is_relay_error == "true": diff --git a/src/functions/tests/_async/test_function_client.py b/src/functions/tests/_async/test_function_client.py index f9b7e61d..d1f47b97 100644 --- a/src/functions/tests/_async/test_function_client.py +++ b/src/functions/tests/_async/test_function_client.py @@ -100,8 +100,11 @@ async def test_invoke_with_region(client: AsyncFunctionsClient): await client.invoke("test-function", {"region": FunctionRegion("us-east-1")}) - _, kwargs = mock_request.call_args + args, kwargs = mock_request.call_args + # Check that x-region header is present assert kwargs["headers"]["x-region"] == "us-east-1" + # Check that the URL contains the forceFunctionRegion query parameter + assert kwargs["params"]["forceFunctionRegion"] == "us-east-1" async def test_invoke_with_region_string(client: AsyncFunctionsClient): @@ -118,8 +121,11 @@ async def test_invoke_with_region_string(client: AsyncFunctionsClient): with pytest.warns(UserWarning, match=r"Use FunctionRegion\(us-east-1\)"): await client.invoke("test-function", {"region": "us-east-1"}) - _, kwargs = mock_request.call_args + args, kwargs = mock_request.call_args + # Check that x-region header is present assert kwargs["headers"]["x-region"] == "us-east-1" + # Check that the URL contains the forceFunctionRegion query parameter + assert kwargs["params"]["forceFunctionRegion"] == "us-east-1" async def test_invoke_with_http_error(client: AsyncFunctionsClient): diff --git a/src/functions/tests/_sync/test_function_client.py b/src/functions/tests/_sync/test_function_client.py index 7f2b8194..6489e91a 100644 --- a/src/functions/tests/_sync/test_function_client.py +++ b/src/functions/tests/_sync/test_function_client.py @@ -94,8 +94,11 @@ def test_invoke_with_region(client: SyncFunctionsClient): client.invoke("test-function", {"region": FunctionRegion("us-east-1")}) - _, kwargs = mock_request.call_args + args, kwargs = mock_request.call_args + # Check that x-region header is present assert kwargs["headers"]["x-region"] == "us-east-1" + # Check that the URL contains the forceFunctionRegion query parameter + assert kwargs["params"]["forceFunctionRegion"] == "us-east-1" def test_invoke_with_region_string(client: SyncFunctionsClient): @@ -110,8 +113,11 @@ def test_invoke_with_region_string(client: SyncFunctionsClient): with pytest.warns(UserWarning, match=r"Use FunctionRegion\(us-east-1\)"): client.invoke("test-function", {"region": "us-east-1"}) - _, kwargs = mock_request.call_args + args, kwargs = mock_request.call_args + # Check that x-region header is present assert kwargs["headers"]["x-region"] == "us-east-1" + # Check that the URL contains the forceFunctionRegion query parameter + assert kwargs["params"]["forceFunctionRegion"] == "us-east-1" def test_invoke_with_http_error(client: SyncFunctionsClient):