Skip to content

Commit

Permalink
Merge pull request #10 from emma-heriot-watt/fair-eval
Browse files Browse the repository at this point in the history
feat: Changes from EMNLP evaluation and add test for SimBotController
  • Loading branch information
MalvinaNikandrou committed Dec 2, 2023
2 parents 469e83f + 42abeff commit 2926978
Show file tree
Hide file tree
Showing 40 changed files with 367 additions and 433 deletions.
2 changes: 2 additions & 0 deletions docker/docker-compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ services:
LOG_LEVEL: debug
MODEL_NAME: "heriot-watt/emma-base"
MODEL_CHECKPOINT_PATH: "/app/model/${INTENT_EXTRACTOR_MODEL}"
ENABLE_PREDICTION_PATCHING: False
ports:
- "5501:6000"
healthcheck:
Expand All @@ -69,6 +70,7 @@ services:
environment:
LOG_LEVEL: debug
MODEL_CHECKPOINT_PATH: "/app/model/${INSTRUCTION_PREDICTOR_MODEL}"
ENABLE_PREDICTION_PATCHING: False
ports:
- "5502:6000"
healthcheck:
Expand Down
1 change: 0 additions & 1 deletion src/emma_experience_hub/api/clients/emma_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ def _make_request(
emma_policy_request = EmmaPolicyRequest(
environment_history=environment_state_history,
dialogue_history=dialogue_history,
inventory=inventory_entity,
force_stop_token=force_stop_token,
)
logger.debug(f"Sending {emma_policy_request.num_images} images.")
Expand Down
6 changes: 0 additions & 6 deletions src/emma_experience_hub/api/clients/feature_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,13 @@
import torch
from loguru import logger
from numpy.typing import ArrayLike
from opentelemetry import trace
from PIL import Image

from emma_common.datamodels import TorchDataMixin
from emma_experience_hub.api.clients.client import Client
from emma_experience_hub.datamodels import EmmaExtractedFeatures


tracer = trace.get_tracer(__name__)


class FeatureExtractorClient(Client):
"""API Client for sending requests to the feature extractor server."""

Expand Down Expand Up @@ -48,7 +44,6 @@ def change_device(self, device: torch.device) -> None:
if response.status_code == httpx.codes.OK:
logger.info(f"Feature extractor model moved to device `{device}`")

@tracer.start_as_current_span("Extract features from single image")
def process_single_image(self, image: Union[Image.Image, ArrayLike]) -> EmmaExtractedFeatures:
"""Submit a request to the feature extraction server for a single image."""
image_bytes = self._convert_single_image_to_bytes(image)
Expand All @@ -70,7 +65,6 @@ def process_single_image(self, image: Union[Image.Image, ArrayLike]) -> EmmaExtr

return feature_response

@tracer.start_as_current_span("Extract features from image batch")
def process_many_images(
self, images: Union[list[Image.Image], list[ArrayLike]]
) -> list[EmmaExtractedFeatures]:
Expand Down
1 change: 1 addition & 0 deletions src/emma_experience_hub/api/clients/simbot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@
from emma_experience_hub.api.clients.simbot.placeholder_vision import SimBotPlaceholderVisionClient
from emma_experience_hub.api.clients.simbot.qa_intent import SimBotQAIntentClient
from emma_experience_hub.api.clients.simbot.session_db import SimBotSessionDbClient
from emma_experience_hub.api.clients.simbot.session_local_db import SimBotSQLLiteClient
33 changes: 13 additions & 20 deletions src/emma_experience_hub/api/clients/simbot/action_prediction.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,9 @@
from typing import Optional

from opentelemetry import trace

from emma_common.datamodels import DialogueUtterance, EnvironmentStateTurn
from emma_experience_hub.api.clients.emma_policy import EmmaPolicyClient


tracer = trace.get_tracer(__name__)


class SimbotActionPredictionClient(EmmaPolicyClient):
"""Action prediction client which interfaces with the Policy model."""

Expand All @@ -20,14 +15,13 @@ def generate(
inventory_entity: Optional[str] = None,
) -> str:
"""Generate a response from the features and provided language."""
with tracer.start_as_current_span("Generate action"):
return self._make_request(
f"{self._endpoint}/generate",
environment_state_history,
dialogue_history,
force_stop_token=force_stop_token,
inventory_entity=inventory_entity,
)
return self._make_request(
f"{self._endpoint}/generate",
environment_state_history,
dialogue_history,
force_stop_token=force_stop_token,
inventory_entity=inventory_entity,
)

def find_object_in_scene(
self,
Expand All @@ -36,10 +30,9 @@ def find_object_in_scene(
inventory_entity: Optional[str] = None,
) -> list[str]:
"""Generate a response from the features and provided language."""
with tracer.start_as_current_span("Find object in scene"):
return self._make_request(
f"{self._endpoint}/generate_find",
environment_state_history,
dialogue_history,
inventory_entity=inventory_entity,
)
return self._make_request(
f"{self._endpoint}/generate_find",
environment_state_history,
dialogue_history,
inventory_entity=inventory_entity,
)
26 changes: 1 addition & 25 deletions src/emma_experience_hub/api/clients/simbot/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
from typing import Generic, Optional, TypeVar, Union

import torch
from cloudpathlib import S3Client, S3Path
from opentelemetry import trace

from emma_experience_hub.api.clients.client import Client
from emma_experience_hub.api.clients.pydantic import PydanticClientMixin, PydanticT
Expand All @@ -25,8 +23,6 @@

T = TypeVar("T")

tracer = trace.get_tracer(__name__)


class SimBotCacheClient(Client, Generic[T]):
"""Cache client for SimBot data."""
Expand All @@ -35,16 +31,12 @@ class SimBotCacheClient(Client, Generic[T]):

def __init__(
self,
bucket_name: str,
local_cache_dir: Path,
object_prefix: Optional[str] = None,
) -> None:
self.bucket = bucket_name
self.prefix = object_prefix if object_prefix is not None else ""
self._local_cache_dir = local_cache_dir

self._s3 = S3Client(local_cache_dir=self._local_cache_dir)

def healthcheck(self) -> bool:
"""Healthcheck for the client."""
return self._local_cache_dir.exists()
Expand All @@ -61,12 +53,6 @@ def load(self, session_id: str, prediction_request_id: str) -> T:
"""Load the data, converting from bytes to the object."""
raise NotImplementedError()

def upload_to_s3(self, session_id: str, prediction_request_id: str) -> None:
"""Load the cached data to S3."""
cached_data = self._load_bytes(session_id, prediction_request_id)
s3_path = self._create_s3_path(session_id, prediction_request_id)
s3_path.write_bytes(cached_data)

def _save_bytes(self, data: bytes, session_id: str, prediction_request_id: str) -> None:
"""Save the data."""
destination_path = self._create_local_path(session_id, prediction_request_id)
Expand All @@ -84,14 +70,6 @@ def _create_local_path(self, session_id: str, prediction_request_id: str) -> Pat
Path(f"{session_id}/{str(prediction_request_id)}.{self.suffix}")
)

def _create_s3_path(self, session_id: str, prediction_request_id: str) -> S3Path:
"""Build the name of the object on S3."""
object_name = "/".join(
[self.prefix, session_id, f"{str(prediction_request_id)}.{self.suffix}"]
).lstrip("/")

return self._s3.CloudPath(f"s3://{self.bucket}/{object_name}")


class SimBotPydanticCacheClient(PydanticClientMixin[PydanticT], SimBotCacheClient[PydanticT]):
"""Cache Pydantic models for SimBot.
Expand Down Expand Up @@ -119,7 +97,7 @@ def load(self, session_id: str, prediction_request_id: str) -> PydanticT:


class SimBotAuxiliaryMetadataClient(SimBotPydanticCacheClient[SimBotAuxiliaryMetadataPayload]):
"""Cache auxiliary metadata on S3."""
"""Cache auxiliary metadata."""

model = SimBotAuxiliaryMetadataPayload
suffix = "json"
Expand All @@ -130,7 +108,6 @@ class SimBotExtractedFeaturesClient(SimBotCacheClient[list[EmmaExtractedFeatures

suffix = "pt"

@tracer.start_as_current_span("Save extracted features")
def save(
self,
data: list[EmmaExtractedFeatures],
Expand All @@ -151,7 +128,6 @@ def save(
# Write data
self._save_bytes(data_buffer.getvalue(), session_id, prediction_request_id)

@tracer.start_as_current_span("Load extracted features from file")
def load(self, session_id: str, prediction_request_id: str) -> list[EmmaExtractedFeatures]:
"""Load the extracted features from a single file."""
# Load the raw data using torch.
Expand Down
40 changes: 14 additions & 26 deletions src/emma_experience_hub/api/clients/simbot/features.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from loguru import logger
from opentelemetry import trace

from emma_experience_hub.api.clients.client import Client
from emma_experience_hub.api.clients.feature_extractor import FeatureExtractorClient
Expand All @@ -13,9 +12,6 @@
from emma_experience_hub.datamodels.simbot.payloads import SimBotAuxiliaryMetadataPayload


tracer = trace.get_tracer(__name__)


class SimBotFeaturesClient(Client):
"""Extract features and cache them."""

Expand All @@ -41,12 +37,10 @@ def healthcheck(self) -> bool:
]
)

@tracer.start_as_current_span("Check if features exist")
def check_exist(self, turn: SimBotSessionTurn) -> bool:
"""Check whether features already exist for the given turn."""
return self.features_cache_client.check_exist(turn.session_id, turn.prediction_request_id)

@tracer.start_as_current_span("Get features")
def get_features(self, turn: SimBotSessionTurn) -> list[EmmaExtractedFeatures]:
"""Get the features for the given turn."""
logger.debug("Getting features for turn...")
Expand All @@ -65,39 +59,33 @@ def get_features(self, turn: SimBotSessionTurn) -> list[EmmaExtractedFeatures]:

return features

@tracer.start_as_current_span("Get auxiliary metadata")
def get_auxiliary_metadata(self, turn: SimBotSessionTurn) -> SimBotAuxiliaryMetadataPayload:
"""Cache the auxiliary metadata for the given turn."""
# Check whether the auxiliary metadata exists within the cache
with tracer.start_as_current_span("Check auxiliary metadata cache"):
auxiliary_metadata_exists = self.auxiliary_metadata_cache_client.check_exist(
turn.session_id, turn.prediction_request_id
)
auxiliary_metadata_exists = self.auxiliary_metadata_cache_client.check_exist(
turn.session_id, turn.prediction_request_id
)

# Load the auxiliary metadata from the cache or the EFS URI
if auxiliary_metadata_exists:
with tracer.start_as_current_span("Load auxiliary metadata from cache"):
auxiliary_metadata = self.auxiliary_metadata_cache_client.load(
turn.session_id, turn.prediction_request_id
)
auxiliary_metadata = self.auxiliary_metadata_cache_client.load(
turn.session_id, turn.prediction_request_id
)
else:
with tracer.start_as_current_span("Load auxiliary metadata from EFS"):
auxiliary_metadata = SimBotAuxiliaryMetadataPayload.from_efs_uri(
uri=turn.auxiliary_metadata_uri
)
auxiliary_metadata = SimBotAuxiliaryMetadataPayload.from_efs_uri(
uri=turn.auxiliary_metadata_uri
)

# If it has not been cached, upload it to the cache
if not auxiliary_metadata_exists:
with tracer.start_as_current_span("Save auxiliary metadata to cache"):
self.auxiliary_metadata_cache_client.save(
auxiliary_metadata,
turn.session_id,
turn.prediction_request_id,
)
self.auxiliary_metadata_cache_client.save(
auxiliary_metadata,
turn.session_id,
turn.prediction_request_id,
)

return auxiliary_metadata

@tracer.start_as_current_span("Get mask for embiggenator")
def get_mask_for_embiggenator(self, turn: SimBotSessionTurn) -> list[list[int]]:
"""Try to replace the object mask with the placeholder model output if needed."""
image = next(iter(self.get_auxiliary_metadata(turn).images))
Expand Down
12 changes: 3 additions & 9 deletions src/emma_experience_hub/api/clients/simbot/hacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,12 @@
import httpx
from loguru import logger
from methodtools import lru_cache
from opentelemetry import trace
from pydantic import BaseModel

from emma_experience_hub.api.clients import Client
from emma_experience_hub.datamodels.simbot.actions import SimBotAction


tracer = trace.get_tracer(__name__)

LRU_CACHE_MAX_SIZE = 64


Expand Down Expand Up @@ -41,16 +38,14 @@ def healthcheck(self) -> bool:

def get_low_level_prediction_from_raw_text(self, utterance: str) -> Optional[str]:
"""Generate a response from the provided language."""
with tracer.start_as_current_span("Match text to template"):
response = self._get_low_level_prediction_from_raw_text(utterance)
response = self._get_low_level_prediction_from_raw_text(utterance)

logger.debug(f"Cache info: {self._get_low_level_prediction_from_raw_text.cache_info()}")
return response

def get_room_prediction_from_raw_text(self, utterance: str) -> Optional[SimBotHacksRoom]:
"""Generate a room prediction from the provided language."""
with tracer.start_as_current_span("Get room from raw text"):
response = self._get_room_prediction_from_raw_text(utterance)
response = self._get_room_prediction_from_raw_text(utterance)

logger.debug(f"Cache info: {self._get_room_prediction_from_raw_text.cache_info()}")
return response
Expand All @@ -62,8 +57,7 @@ def get_anticipator_prediction_from_action(
entity_labels: Optional[list[str]] = None,
) -> Optional[SimBotHacksAnticipator]:
"""Generate possible plan of instructions from the given action."""
with tracer.start_as_current_span("Get anticipator plan"):
response = self._get_anticipated_instructions(action, inventory_entity, entity_labels)
response = self._get_anticipated_instructions(action, inventory_entity, entity_labels)

logger.debug(f"Cache info: {self._get_room_prediction_from_raw_text.cache_info()}")
return response
Expand Down
18 changes: 6 additions & 12 deletions src/emma_experience_hub/api/clients/simbot/nlu_intent.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,9 @@
from typing import Optional

from opentelemetry import trace

from emma_common.datamodels import DialogueUtterance, EnvironmentStateTurn
from emma_experience_hub.api.clients.emma_policy import EmmaPolicyClient


tracer = trace.get_tracer(__name__)


class SimBotNLUIntentClient(EmmaPolicyClient):
"""API Client for SimBot NLU."""

Expand All @@ -19,10 +14,9 @@ def generate(
inventory_entity: Optional[str] = None,
) -> str:
"""Generate a response from the features and provided language."""
with tracer.start_as_current_span("Generate NLU intent"):
return self._make_request(
f"{self._endpoint}/generate",
environment_state_history,
dialogue_history,
inventory_entity,
)
return self._make_request(
f"{self._endpoint}/generate",
environment_state_history,
dialogue_history,
inventory_entity,
)
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
import httpx
from loguru import logger
from opentelemetry import trace
from PIL import Image

from emma_experience_hub.api.clients.feature_extractor import FeatureExtractorClient


tracer = trace.get_tracer(__name__)


class SimBotPlaceholderVisionClient(FeatureExtractorClient):
"""Run the placeholder vision client."""

Expand Down
Loading

0 comments on commit 2926978

Please sign in to comment.