Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Changes from EMNLP evaluation and add test for SimBotController #10

Merged
merged 10 commits into from
Dec 2, 2023
2 changes: 2 additions & 0 deletions docker/docker-compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ services:
LOG_LEVEL: debug
MODEL_NAME: "heriot-watt/emma-base"
MODEL_CHECKPOINT_PATH: "/app/model/${INTENT_EXTRACTOR_MODEL}"
ENABLE_PREDICTION_PATCHING: False
ports:
- "5501:6000"
healthcheck:
Expand All @@ -69,6 +70,7 @@ services:
environment:
LOG_LEVEL: debug
MODEL_CHECKPOINT_PATH: "/app/model/${INSTRUCTION_PREDICTOR_MODEL}"
ENABLE_PREDICTION_PATCHING: False
ports:
- "5502:6000"
healthcheck:
Expand Down
1 change: 0 additions & 1 deletion src/emma_experience_hub/api/clients/emma_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ def _make_request(
emma_policy_request = EmmaPolicyRequest(
environment_history=environment_state_history,
dialogue_history=dialogue_history,
inventory=inventory_entity,
force_stop_token=force_stop_token,
)
logger.debug(f"Sending {emma_policy_request.num_images} images.")
Expand Down
6 changes: 0 additions & 6 deletions src/emma_experience_hub/api/clients/feature_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,13 @@
import torch
from loguru import logger
from numpy.typing import ArrayLike
from opentelemetry import trace
from PIL import Image

from emma_common.datamodels import TorchDataMixin
from emma_experience_hub.api.clients.client import Client
from emma_experience_hub.datamodels import EmmaExtractedFeatures


tracer = trace.get_tracer(__name__)


class FeatureExtractorClient(Client):
"""API Client for sending requests to the feature extractor server."""

Expand Down Expand Up @@ -48,7 +44,6 @@ def change_device(self, device: torch.device) -> None:
if response.status_code == httpx.codes.OK:
logger.info(f"Feature extractor model moved to device `{device}`")

@tracer.start_as_current_span("Extract features from single image")
def process_single_image(self, image: Union[Image.Image, ArrayLike]) -> EmmaExtractedFeatures:
"""Submit a request to the feature extraction server for a single image."""
image_bytes = self._convert_single_image_to_bytes(image)
Expand All @@ -70,7 +65,6 @@ def process_single_image(self, image: Union[Image.Image, ArrayLike]) -> EmmaExtr

return feature_response

@tracer.start_as_current_span("Extract features from image batch")
def process_many_images(
self, images: Union[list[Image.Image], list[ArrayLike]]
) -> list[EmmaExtractedFeatures]:
Expand Down
1 change: 1 addition & 0 deletions src/emma_experience_hub/api/clients/simbot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@
from emma_experience_hub.api.clients.simbot.placeholder_vision import SimBotPlaceholderVisionClient
from emma_experience_hub.api.clients.simbot.qa_intent import SimBotQAIntentClient
from emma_experience_hub.api.clients.simbot.session_db import SimBotSessionDbClient
from emma_experience_hub.api.clients.simbot.session_local_db import SimBotSQLLiteClient
33 changes: 13 additions & 20 deletions src/emma_experience_hub/api/clients/simbot/action_prediction.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,9 @@
from typing import Optional

from opentelemetry import trace

from emma_common.datamodels import DialogueUtterance, EnvironmentStateTurn
from emma_experience_hub.api.clients.emma_policy import EmmaPolicyClient


tracer = trace.get_tracer(__name__)


class SimbotActionPredictionClient(EmmaPolicyClient):
"""Action prediction client which interfaces with the Policy model."""

Expand All @@ -20,14 +15,13 @@ def generate(
inventory_entity: Optional[str] = None,
) -> str:
"""Generate a response from the features and provided language."""
with tracer.start_as_current_span("Generate action"):
return self._make_request(
f"{self._endpoint}/generate",
environment_state_history,
dialogue_history,
force_stop_token=force_stop_token,
inventory_entity=inventory_entity,
)
return self._make_request(
f"{self._endpoint}/generate",
environment_state_history,
dialogue_history,
force_stop_token=force_stop_token,
inventory_entity=inventory_entity,
)

def find_object_in_scene(
self,
Expand All @@ -36,10 +30,9 @@ def find_object_in_scene(
inventory_entity: Optional[str] = None,
) -> list[str]:
"""Generate a response from the features and provided language."""
with tracer.start_as_current_span("Find object in scene"):
return self._make_request(
f"{self._endpoint}/generate_find",
environment_state_history,
dialogue_history,
inventory_entity=inventory_entity,
)
return self._make_request(
f"{self._endpoint}/generate_find",
environment_state_history,
dialogue_history,
inventory_entity=inventory_entity,
)
26 changes: 1 addition & 25 deletions src/emma_experience_hub/api/clients/simbot/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
from typing import Generic, Optional, TypeVar, Union

import torch
from cloudpathlib import S3Client, S3Path
from opentelemetry import trace

from emma_experience_hub.api.clients.client import Client
from emma_experience_hub.api.clients.pydantic import PydanticClientMixin, PydanticT
Expand All @@ -25,8 +23,6 @@

T = TypeVar("T")

tracer = trace.get_tracer(__name__)


class SimBotCacheClient(Client, Generic[T]):
"""Cache client for SimBot data."""
Expand All @@ -35,16 +31,12 @@ class SimBotCacheClient(Client, Generic[T]):

def __init__(
self,
bucket_name: str,
local_cache_dir: Path,
object_prefix: Optional[str] = None,
) -> None:
self.bucket = bucket_name
self.prefix = object_prefix if object_prefix is not None else ""
self._local_cache_dir = local_cache_dir

self._s3 = S3Client(local_cache_dir=self._local_cache_dir)

def healthcheck(self) -> bool:
"""Healthcheck for the client."""
return self._local_cache_dir.exists()
Expand All @@ -61,12 +53,6 @@ def load(self, session_id: str, prediction_request_id: str) -> T:
"""Load the data, converting from bytes to the object."""
raise NotImplementedError()

def upload_to_s3(self, session_id: str, prediction_request_id: str) -> None:
"""Load the cached data to S3."""
cached_data = self._load_bytes(session_id, prediction_request_id)
s3_path = self._create_s3_path(session_id, prediction_request_id)
s3_path.write_bytes(cached_data)

def _save_bytes(self, data: bytes, session_id: str, prediction_request_id: str) -> None:
"""Save the data."""
destination_path = self._create_local_path(session_id, prediction_request_id)
Expand All @@ -84,14 +70,6 @@ def _create_local_path(self, session_id: str, prediction_request_id: str) -> Pat
Path(f"{session_id}/{str(prediction_request_id)}.{self.suffix}")
)

def _create_s3_path(self, session_id: str, prediction_request_id: str) -> S3Path:
"""Build the name of the object on S3."""
object_name = "/".join(
[self.prefix, session_id, f"{str(prediction_request_id)}.{self.suffix}"]
).lstrip("/")

return self._s3.CloudPath(f"s3://{self.bucket}/{object_name}")


class SimBotPydanticCacheClient(PydanticClientMixin[PydanticT], SimBotCacheClient[PydanticT]):
"""Cache Pydantic models for SimBot.
Expand Down Expand Up @@ -119,7 +97,7 @@ def load(self, session_id: str, prediction_request_id: str) -> PydanticT:


class SimBotAuxiliaryMetadataClient(SimBotPydanticCacheClient[SimBotAuxiliaryMetadataPayload]):
"""Cache auxiliary metadata on S3."""
"""Cache auxiliary metadata."""

model = SimBotAuxiliaryMetadataPayload
suffix = "json"
Expand All @@ -130,7 +108,6 @@ class SimBotExtractedFeaturesClient(SimBotCacheClient[list[EmmaExtractedFeatures

suffix = "pt"

@tracer.start_as_current_span("Save extracted features")
def save(
self,
data: list[EmmaExtractedFeatures],
Expand All @@ -151,7 +128,6 @@ def save(
# Write data
self._save_bytes(data_buffer.getvalue(), session_id, prediction_request_id)

@tracer.start_as_current_span("Load extracted features from file")
def load(self, session_id: str, prediction_request_id: str) -> list[EmmaExtractedFeatures]:
"""Load the extracted features from a single file."""
# Load the raw data using torch.
Expand Down
40 changes: 14 additions & 26 deletions src/emma_experience_hub/api/clients/simbot/features.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from loguru import logger
from opentelemetry import trace

from emma_experience_hub.api.clients.client import Client
from emma_experience_hub.api.clients.feature_extractor import FeatureExtractorClient
Expand All @@ -13,9 +12,6 @@
from emma_experience_hub.datamodels.simbot.payloads import SimBotAuxiliaryMetadataPayload


tracer = trace.get_tracer(__name__)


class SimBotFeaturesClient(Client):
"""Extract features and cache them."""

Expand All @@ -41,12 +37,10 @@ def healthcheck(self) -> bool:
]
)

@tracer.start_as_current_span("Check if features exist")
def check_exist(self, turn: SimBotSessionTurn) -> bool:
"""Check whether features already exist for the given turn."""
return self.features_cache_client.check_exist(turn.session_id, turn.prediction_request_id)

@tracer.start_as_current_span("Get features")
def get_features(self, turn: SimBotSessionTurn) -> list[EmmaExtractedFeatures]:
"""Get the features for the given turn."""
logger.debug("Getting features for turn...")
Expand All @@ -65,39 +59,33 @@ def get_features(self, turn: SimBotSessionTurn) -> list[EmmaExtractedFeatures]:

return features

@tracer.start_as_current_span("Get auxiliary metadata")
def get_auxiliary_metadata(self, turn: SimBotSessionTurn) -> SimBotAuxiliaryMetadataPayload:
"""Cache the auxiliary metadata for the given turn."""
# Check whether the auxiliary metadata exists within the cache
with tracer.start_as_current_span("Check auxiliary metadata cache"):
auxiliary_metadata_exists = self.auxiliary_metadata_cache_client.check_exist(
turn.session_id, turn.prediction_request_id
)
auxiliary_metadata_exists = self.auxiliary_metadata_cache_client.check_exist(
turn.session_id, turn.prediction_request_id
)

# Load the auxiliary metadata from the cache or the EFS URI
if auxiliary_metadata_exists:
with tracer.start_as_current_span("Load auxiliary metadata from cache"):
auxiliary_metadata = self.auxiliary_metadata_cache_client.load(
turn.session_id, turn.prediction_request_id
)
auxiliary_metadata = self.auxiliary_metadata_cache_client.load(
turn.session_id, turn.prediction_request_id
)
else:
with tracer.start_as_current_span("Load auxiliary metadata from EFS"):
auxiliary_metadata = SimBotAuxiliaryMetadataPayload.from_efs_uri(
uri=turn.auxiliary_metadata_uri
)
auxiliary_metadata = SimBotAuxiliaryMetadataPayload.from_efs_uri(
uri=turn.auxiliary_metadata_uri
)

# If it has not been cached, upload it to the cache
if not auxiliary_metadata_exists:
with tracer.start_as_current_span("Save auxiliary metadata to cache"):
self.auxiliary_metadata_cache_client.save(
auxiliary_metadata,
turn.session_id,
turn.prediction_request_id,
)
self.auxiliary_metadata_cache_client.save(
auxiliary_metadata,
turn.session_id,
turn.prediction_request_id,
)

return auxiliary_metadata

@tracer.start_as_current_span("Get mask for embiggenator")
def get_mask_for_embiggenator(self, turn: SimBotSessionTurn) -> list[list[int]]:
"""Try to replace the object mask with the placeholder model output if needed."""
image = next(iter(self.get_auxiliary_metadata(turn).images))
Expand Down
12 changes: 3 additions & 9 deletions src/emma_experience_hub/api/clients/simbot/hacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,12 @@
import httpx
from loguru import logger
from methodtools import lru_cache
from opentelemetry import trace
from pydantic import BaseModel

from emma_experience_hub.api.clients import Client
from emma_experience_hub.datamodels.simbot.actions import SimBotAction


tracer = trace.get_tracer(__name__)

LRU_CACHE_MAX_SIZE = 64


Expand Down Expand Up @@ -41,16 +38,14 @@ def healthcheck(self) -> bool:

def get_low_level_prediction_from_raw_text(self, utterance: str) -> Optional[str]:
"""Generate a response from the provided language."""
with tracer.start_as_current_span("Match text to template"):
response = self._get_low_level_prediction_from_raw_text(utterance)
response = self._get_low_level_prediction_from_raw_text(utterance)

logger.debug(f"Cache info: {self._get_low_level_prediction_from_raw_text.cache_info()}")
return response

def get_room_prediction_from_raw_text(self, utterance: str) -> Optional[SimBotHacksRoom]:
"""Generate a room prediction from the provided language."""
with tracer.start_as_current_span("Get room from raw text"):
response = self._get_room_prediction_from_raw_text(utterance)
response = self._get_room_prediction_from_raw_text(utterance)

logger.debug(f"Cache info: {self._get_room_prediction_from_raw_text.cache_info()}")
return response
Expand All @@ -62,8 +57,7 @@ def get_anticipator_prediction_from_action(
entity_labels: Optional[list[str]] = None,
) -> Optional[SimBotHacksAnticipator]:
"""Generate possible plan of instructions from the given action."""
with tracer.start_as_current_span("Get anticipator plan"):
response = self._get_anticipated_instructions(action, inventory_entity, entity_labels)
response = self._get_anticipated_instructions(action, inventory_entity, entity_labels)

logger.debug(f"Cache info: {self._get_room_prediction_from_raw_text.cache_info()}")
return response
Expand Down
18 changes: 6 additions & 12 deletions src/emma_experience_hub/api/clients/simbot/nlu_intent.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,9 @@
from typing import Optional

from opentelemetry import trace

from emma_common.datamodels import DialogueUtterance, EnvironmentStateTurn
from emma_experience_hub.api.clients.emma_policy import EmmaPolicyClient


tracer = trace.get_tracer(__name__)


class SimBotNLUIntentClient(EmmaPolicyClient):
"""API Client for SimBot NLU."""

Expand All @@ -19,10 +14,9 @@ def generate(
inventory_entity: Optional[str] = None,
) -> str:
"""Generate a response from the features and provided language."""
with tracer.start_as_current_span("Generate NLU intent"):
return self._make_request(
f"{self._endpoint}/generate",
environment_state_history,
dialogue_history,
inventory_entity,
)
return self._make_request(
f"{self._endpoint}/generate",
environment_state_history,
dialogue_history,
inventory_entity,
)
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
import httpx
from loguru import logger
from opentelemetry import trace
from PIL import Image

from emma_experience_hub.api.clients.feature_extractor import FeatureExtractorClient


tracer = trace.get_tracer(__name__)


class SimBotPlaceholderVisionClient(FeatureExtractorClient):
"""Run the placeholder vision client."""

Expand Down
Loading