Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add a pytest that mocks the clients and runs the entire api #11

Merged
merged 3 commits into from
Dec 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 22 additions & 3 deletions tests/controller/test_simbot_controller.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,37 @@
from typing import Any

from pytest_cases import parametrize_with_cases
from pytest_cases import parametrize, parametrize_with_cases

from emma_experience_hub.api.clients.simbot import SimbotActionPredictionClient
from emma_experience_hub.api.controllers import SimBotController
from emma_experience_hub.common.settings import SimBotSettings
from emma_experience_hub.datamodels.simbot import SimBotRequest
from tests.fixtures.clients import (
mock_policy_response_goto_room,
mock_policy_response_search,
mock_policy_response_toggle_computer,
)
from tests.fixtures.simbot_api_requests import SimBotRequestCases


@parametrize_with_cases("request_body", cases=SimBotRequestCases)
@parametrize_with_cases("request_body", cases=SimBotRequestCases.case_without_previous_actions)
@parametrize(
"mock_policy_responses",
[
mock_policy_response_toggle_computer,
mock_policy_response_goto_room,
mock_policy_response_search,
],
)
def test_simbot_api(
request_body: dict[str, Any],
simbot_settings: SimBotSettings,
mock_feature_extraction_response: Any,
mock_policy_responses: Any,
) -> None:
"""Test the SimBot API."""
simbot_request = SimBotRequest.parse_obj(request_body)
controller = SimBotController.from_simbot_settings(simbot_settings)
controller.handle_request_from_simbot_arena(simbot_request)
response = controller.handle_request_from_simbot_arena(simbot_request)
assert response.actions[0].raw_output is not None
assert response.actions[0].raw_output == SimbotActionPredictionClient.generate() # type: ignore[call-arg]
69 changes: 68 additions & 1 deletion tests/fixtures/clients.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,19 @@
from collections.abc import Generator
from typing import Any

import httpx
from pydantic import AnyHttpUrl
from pytest_cases import fixture
from pytest import MonkeyPatch, fixture
from pytest_httpx import HTTPXMock

from emma_experience_hub.api.clients import ProfanityFilterClient
from emma_experience_hub.api.clients.simbot import (
SimbotActionPredictionClient,
SimBotFeaturesClient,
SimBotNLUIntentClient,
)
from emma_experience_hub.datamodels import EmmaExtractedFeatures
from tests.fixtures.simbot_arena_constants import create_placeholder_features_frames


@fixture(scope="session")
Expand All @@ -19,3 +27,62 @@ def custom_response(request: httpx.Request) -> httpx.Response: # noqa: WPS430
yield ProfanityFilterClient(
endpoint=AnyHttpUrl(url="http://localhost", scheme="http"), timeout=None
)


@fixture
def mock_feature_extraction_response(monkeypatch: MonkeyPatch) -> None:
"""Mock get_features from the SimBotFeaturesClient."""

def mock_features(*args: Any, **kwargs: Any) -> list[EmmaExtractedFeatures]: # noqa: WPS430
features = create_placeholder_features_frames()
return features

monkeypatch.setattr(SimBotFeaturesClient, "get_features", mock_features)


@fixture
def mock_policy_response_goto_room(monkeypatch: MonkeyPatch) -> None:
"""Mock the responses of EMMA policy."""

def get_nlu(*args: Any, **kwargs: Any) -> str: # noqa: WPS430
return "<act><one_match>"

def get_action(*args: Any, **kwargs: Any) -> str: # noqa: WPS430
return "goto breakroom<stop>."

monkeypatch.setattr(SimBotNLUIntentClient, "generate", get_nlu)
monkeypatch.setattr(SimbotActionPredictionClient, "generate", get_action)


@fixture
def mock_policy_response_toggle_computer(monkeypatch: MonkeyPatch) -> None:
"""Mock the responses of EMMA policy when the input instruction is about turning on the
computer."""

def get_nlu(*args: Any, **kwargs: Any) -> str: # noqa: WPS430
return "<act><one_match>"

def get_action(*args: Any, **kwargs: Any) -> str: # noqa: WPS430
return "toggle computer <frame_token_1> <vis_token_1><stop>."

monkeypatch.setattr(SimBotNLUIntentClient, "generate", get_nlu)
monkeypatch.setattr(SimbotActionPredictionClient, "generate", get_action)


@fixture
def mock_policy_response_search(monkeypatch: MonkeyPatch) -> None:
"""Mock the responses of EMMA policy when the input instruction is about searching an
object."""

def get_nlu(*args: Any, **kwargs: Any) -> str: # noqa: WPS430
return "<search>"

def get_object(*args: Any, **kwargs: Any) -> list[str]: # noqa: WPS430
return ["<frame_token_1> <vis_token_1>"]

def get_target(*args: Any, **kwargs: Any) -> str: # noqa: WPS430
return "goto object <frame_token_1> <vis_token_1> <stop>."

monkeypatch.setattr(SimBotNLUIntentClient, "generate", get_nlu)
monkeypatch.setattr(SimbotActionPredictionClient, "find_object_in_scene", get_object)
monkeypatch.setattr(SimbotActionPredictionClient, "generate", get_target)