diff --git a/tests/controller/test_simbot_controller.py b/tests/controller/test_simbot_controller.py index d23db6b..bd6d30d 100644 --- a/tests/controller/test_simbot_controller.py +++ b/tests/controller/test_simbot_controller.py @@ -1,18 +1,37 @@ from typing import Any -from pytest_cases import parametrize_with_cases +from pytest_cases import parametrize, parametrize_with_cases +from emma_experience_hub.api.clients.simbot import SimbotActionPredictionClient from emma_experience_hub.api.controllers import SimBotController from emma_experience_hub.common.settings import SimBotSettings from emma_experience_hub.datamodels.simbot import SimBotRequest +from tests.fixtures.clients import ( + mock_policy_response_goto_room, + mock_policy_response_search, + mock_policy_response_toggle_computer, +) from tests.fixtures.simbot_api_requests import SimBotRequestCases -@parametrize_with_cases("request_body", cases=SimBotRequestCases) +@parametrize_with_cases("request_body", cases=SimBotRequestCases.case_without_previous_actions) +@parametrize( + "mock_policy_responses", + [ + mock_policy_response_toggle_computer, + mock_policy_response_goto_room, + mock_policy_response_search, + ], +) def test_simbot_api( request_body: dict[str, Any], simbot_settings: SimBotSettings, + mock_feature_extraction_response: Any, + mock_policy_responses: Any, ) -> None: + """Test the SimBot API.""" simbot_request = SimBotRequest.parse_obj(request_body) controller = SimBotController.from_simbot_settings(simbot_settings) - controller.handle_request_from_simbot_arena(simbot_request) + response = controller.handle_request_from_simbot_arena(simbot_request) + assert response.actions[0].raw_output is not None + assert response.actions[0].raw_output == SimbotActionPredictionClient.generate() # type: ignore[call-arg] diff --git a/tests/fixtures/clients.py b/tests/fixtures/clients.py index 0befa45..1783c57 100644 --- a/tests/fixtures/clients.py +++ b/tests/fixtures/clients.py @@ -1,11 +1,19 @@ from collections.abc import Generator +from typing import Any import httpx from pydantic import AnyHttpUrl -from pytest_cases import fixture +from pytest import MonkeyPatch, fixture from pytest_httpx import HTTPXMock from emma_experience_hub.api.clients import ProfanityFilterClient +from emma_experience_hub.api.clients.simbot import ( + SimbotActionPredictionClient, + SimBotFeaturesClient, + SimBotNLUIntentClient, +) +from emma_experience_hub.datamodels import EmmaExtractedFeatures +from tests.fixtures.simbot_arena_constants import create_placeholder_features_frames @fixture(scope="session") @@ -19,3 +27,62 @@ def custom_response(request: httpx.Request) -> httpx.Response: # noqa: WPS430 yield ProfanityFilterClient( endpoint=AnyHttpUrl(url="http://localhost", scheme="http"), timeout=None ) + + +@fixture +def mock_feature_extraction_response(monkeypatch: MonkeyPatch) -> None: + """Mock get_features from the SimBotFeaturesClient.""" + + def mock_features(*args: Any, **kwargs: Any) -> list[EmmaExtractedFeatures]: # noqa: WPS430 + features = create_placeholder_features_frames() + return features + + monkeypatch.setattr(SimBotFeaturesClient, "get_features", mock_features) + + +@fixture +def mock_policy_response_goto_room(monkeypatch: MonkeyPatch) -> None: + """Mock the responses of EMMA policy.""" + + def get_nlu(*args: Any, **kwargs: Any) -> str: # noqa: WPS430 + return "" + + def get_action(*args: Any, **kwargs: Any) -> str: # noqa: WPS430 + return "goto breakroom." + + monkeypatch.setattr(SimBotNLUIntentClient, "generate", get_nlu) + monkeypatch.setattr(SimbotActionPredictionClient, "generate", get_action) + + +@fixture +def mock_policy_response_toggle_computer(monkeypatch: MonkeyPatch) -> None: + """Mock the responses of EMMA policy when the input instruction is about turning on the + computer.""" + + def get_nlu(*args: Any, **kwargs: Any) -> str: # noqa: WPS430 + return "" + + def get_action(*args: Any, **kwargs: Any) -> str: # noqa: WPS430 + return "toggle computer ." + + monkeypatch.setattr(SimBotNLUIntentClient, "generate", get_nlu) + monkeypatch.setattr(SimbotActionPredictionClient, "generate", get_action) + + +@fixture +def mock_policy_response_search(monkeypatch: MonkeyPatch) -> None: + """Mock the responses of EMMA policy when the input instruction is about searching an + object.""" + + def get_nlu(*args: Any, **kwargs: Any) -> str: # noqa: WPS430 + return "" + + def get_object(*args: Any, **kwargs: Any) -> list[str]: # noqa: WPS430 + return [" "] + + def get_target(*args: Any, **kwargs: Any) -> str: # noqa: WPS430 + return "goto object ." + + monkeypatch.setattr(SimBotNLUIntentClient, "generate", get_nlu) + monkeypatch.setattr(SimbotActionPredictionClient, "find_object_in_scene", get_object) + monkeypatch.setattr(SimbotActionPredictionClient, "generate", get_target)