Skip to content

Commit

Permalink
fix ci, formatting, flake8, and types
Browse files Browse the repository at this point in the history
  • Loading branch information
amitkparekh committed Dec 1, 2023
1 parent 798c694 commit 469e83f
Show file tree
Hide file tree
Showing 13 changed files with 30 additions and 27 deletions.
2 changes: 2 additions & 0 deletions .flake8
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ extend-ignore =
E501,
# Stop finding commented out code because it's mistaking shape annotations for code
E800,
# Don't complain about asserts
S101,
# Stop complaining about subprocess, we need it for this project
S404,S602,S603,S607,
# Stop complaining about using functions from random
Expand Down
22 changes: 11 additions & 11 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ default_language_version:
repos:
# -------------------------- Version control checks -------------------------- #
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
rev: v4.5.0
hooks:
- id: check-merge-conflict
name: Check for merge conflicts
Expand All @@ -21,15 +21,15 @@ repos:
name: Check for destroyed symlinks

- repo: https://github.com/sirosen/check-jsonschema
rev: 0.22.0
rev: 0.27.2
hooks:
- id: check-github-workflows
name: Validate GitHub workflows
types: [yaml]

# ----------------------------- Check file issues ---------------------------- #
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
rev: v4.5.0
hooks:
- id: check-toml
name: Check TOML
Expand All @@ -53,7 +53,7 @@ repos:

# ------------------------------ Python checking ----------------------------- #
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
rev: v4.5.0
hooks:
- id: debug-statements
name: Check for debugger statements
Expand All @@ -77,19 +77,19 @@ repos:

# ----------------------------- Automatic linters ---------------------------- #
- repo: https://github.com/asottile/pyupgrade
rev: v3.3.1
rev: v3.15.0
hooks:
- id: pyupgrade
name: Update syntax for newer Python
types: [python]
args: ["--py39-plus"]
- repo: https://github.com/sirosen/texthooks
rev: 0.5.0
rev: 0.6.3
hooks:
- id: fix-smartquotes
name: Fix Smart Quotes
- repo: https://github.com/asottile/yesqa
rev: v1.4.0
rev: v1.5.0
hooks:
- id: yesqa
name: Remove unnecessary `# noqa` comments
Expand All @@ -98,7 +98,7 @@ repos:

# ------------------------------ Python imports ------------------------------ #
- repo: https://github.com/hakancelik96/unimport
rev: 0.16.0
rev: 1.1.0
hooks:
- id: unimport
name: Remove any unused imports
Expand Down Expand Up @@ -128,20 +128,20 @@ repos:

# -------------------------------- Formatting -------------------------------- #
- repo: https://github.com/pre-commit/mirrors-prettier
rev: v3.0.0-alpha.6
rev: v3.1.0
hooks:
- id: prettier
name: Prettier
exclude: ^.*/?CHANGELOG\.md$
- repo: https://github.com/myint/docformatter
rev: v1.6.1
rev: v1.7.5
hooks:
- id: docformatter
name: Format docstrings
types: [python]
args: [--in-place, --wrap-summaries=99, --wrap-descriptions=99]
- repo: https://github.com/psf/black
rev: 23.3.0
rev: 23.11.0
hooks:
- id: black-jupyter
types: [python]
Expand Down
2 changes: 1 addition & 1 deletion src/emma_experience_hub/api/clients/simbot/session_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,4 +122,4 @@ def _get_all_session_turns(self, session_id: str) -> list[dict[str, Any]]:

all_response_items.extend(response["Items"])

return all_response_items
return all_response_items # type: ignore[unreachable]
4 changes: 2 additions & 2 deletions src/emma_experience_hub/commands/build_emma.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ def torch_cuda_version_callback(torch_version: Optional[str]) -> str:

try:
int(torch_version.split("+cu")[-1])
except ValueError:
except ValueError as err:
raise typer.BadParameter(
"The cuda version must be castable to an integer. For example, `113` and not `11.3`"
)
) from err

return torch_version

Expand Down
2 changes: 1 addition & 1 deletion src/emma_experience_hub/commands/teach/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def filter_edh_instances(
edh_instances = [
instance
for instance in edh_instances
if any([action.obj_interaction_action for action in instance.driver_actions_future])
if any(action.obj_interaction_action for action in instance.driver_actions_future)
]
console.log(
f"{num_instances_before_filter - len(edh_instances)} EDH instances [cyan]do not have an interaction action[/] in their future. {len(edh_instances)} EDH instances remaining..."
Expand Down
4 changes: 2 additions & 2 deletions src/emma_experience_hub/datamodels/simbot/feedback.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def calculate_rule_score(cls, score: int, values: dict[str, Any]) -> int: # noq
def is_query_suitable(self, query: dict[str, Any]) -> bool:
"""Evaluate the rule given the query and ensure it is suitable."""
try:
return self.rule.matches(query) and all([name in query for name in self.slot_names])
return self.rule.matches(query) and all(name in query for name in self.slot_names)
except Exception:
return False

Expand Down Expand Up @@ -323,7 +323,7 @@ def from_all_information(
interaction_action.payload, SimBotObjectInteractionPayload
)
if interaction_action_has_bbox:
object_area = get_area_from_compressed_mask(interaction_action.payload.object.mask)
object_area = get_area_from_compressed_mask(interaction_action.payload.object.mask) # type: ignore[union-attr]

return cls(
# Require a lightweight dialog action when the model does not decode a <stop token
Expand Down
2 changes: 1 addition & 1 deletion src/emma_experience_hub/datamodels/simbot/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,7 +613,7 @@ def try_to_update_agent_inventory(self) -> None:
viewpoint=self.previous_valid_turn.environment.get_closest_viewpoint_name(),
action=self.previous_valid_turn.actions.interaction,
inventory_entity=self.previous_valid_turn.state.inventory.entity,
action_history=[turn.actions.interaction for turn in past_turns], # type: ignore[union-attr]
action_history=[turn.actions.interaction for turn in past_turns], # type: ignore[misc]
inventory_history=[turn.state.inventory.entity for turn in past_turns],
)
self.current_state.memory.update_interaction_turn_index(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,9 @@ def __call__(

try:
decoded_action = self._separate_decoded_trajectory(decoded_trajectory)[0]
except IndexError:
except IndexError as err:
# If there is a problem when decoding the action
raise IndexError("Could not decode any actions from the trajectory")
raise IndexError("Could not decode any actions from the trajectory") from err

# Just use the first action, because if that is wrong, any future ones after it are likely
# hallucinated
Expand Down
4 changes: 2 additions & 2 deletions src/emma_experience_hub/parsers/simbot/qa_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,14 +84,14 @@ def _process_intent(self, utterance: str, intent: str, entities: list[Any]) -> s

return intent

def _is_regex_match_for_incomplete_utterance(self, utterance) -> bool:
def _is_regex_match_for_incomplete_utterance(self, utterance: str) -> bool:
if re.search(self._incomplete_utterance_regex_pattern, utterance):
logger.debug(f"found incomplete regex match for the utterance: {utterance}")
return True
logger.debug(f"No incomplete regex match for the utterance: {utterance}")
return False

def _regex_patterns_map(self) -> bool:
def _regex_patterns_map(self) -> str:
verbs = [
"find",
"search",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def run(self, session: SimBotSession) -> SimBotAgentIntents: # noqa: WPS212

# check for object-qa and return
if session.current_turn.intent.user.is_user_qa_about_object:
return self.handle_object_qa_intent(session)
return self.handle_object_qa_intent(session) # type: ignore[return-value]

# If we have received an invalid utterance, the agent does not act
should_skip_action_selection = self._should_skip_action_selection(
Expand Down
1 change: 1 addition & 0 deletions src/emma_experience_hub/pipelines/simbot/find_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ def _should_scan_found_object(
class_label=found_object_label,
extracted_features=extracted_features,
)
assert scene_object_tokens.object_index is not None

object_idx = scene_object_tokens.object_index - 1
area = extracted_features[frame_idx - 1].bbox_areas[object_idx].item()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def _utterance_only_contains_wake_word(
self, speech_recognition_payload: SimBotSpeechRecognitionPayload
) -> bool:
"""Detect whether the utterance only contains the wake word or not."""
return all([token.is_wake_word for token in speech_recognition_payload.tokens])
return all(token.is_wake_word for token in speech_recognition_payload.tokens)

@tracer.start_as_current_span("Check for empty utterance")
def _utterance_is_empty(
Expand Down
6 changes: 3 additions & 3 deletions tests/pipelines/simbot/test_language_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def test_response_slots_in_all_rules(rule_parser: SimBotFeedbackFromSessionState
]
for rule_text, slots in rules_and_slots:
rule_words = rule_text.split()
assert all([slot_name in rule_words for slot_name in slots])
assert all(slot_name in rule_words for slot_name in slots)


def test_all_response_slots_are_validated_by_rules(
Expand All @@ -71,10 +71,10 @@ def test_all_response_slots_are_validated_by_rules(
]
for rule_text, slots in rules_and_slots:
# Each slot name must be used in the rule so that it exists in some way
assert all([slot_name in rule_text for slot_name in slots])
assert all(slot_name in rule_text for slot_name in slots)

# Any slot name in the rule must not be checking for it to be equal to None
assert not any([f"{slot_name} == null" in rule_text for slot_name in slots])
assert not any(f"{slot_name} == null" in rule_text for slot_name in slots)


def test_all_rule_symbols_in_state(rule_parser: SimBotFeedbackFromSessionStateParser) -> None:
Expand Down

0 comments on commit 469e83f

Please sign in to comment.