diff --git a/.flake8 b/.flake8 index 322f2bbc..f322f864 100644 --- a/.flake8 +++ b/.flake8 @@ -14,6 +14,8 @@ extend-ignore = E501, # Stop finding commented out code because it's mistaking shape annotations for code E800, + # Don't complain about asserts + S101, # Stop complaining about subprocess, we need it for this project S404,S602,S603,S607, # Stop complaining about using functions from random diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 19c215cd..ad9b3434 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,7 +5,7 @@ default_language_version: repos: # -------------------------- Version control checks -------------------------- # - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 + rev: v4.5.0 hooks: - id: check-merge-conflict name: Check for merge conflicts @@ -21,7 +21,7 @@ repos: name: Check for destroyed symlinks - repo: https://github.com/sirosen/check-jsonschema - rev: 0.22.0 + rev: 0.27.2 hooks: - id: check-github-workflows name: Validate GitHub workflows @@ -29,7 +29,7 @@ repos: # ----------------------------- Check file issues ---------------------------- # - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 + rev: v4.5.0 hooks: - id: check-toml name: Check TOML @@ -53,7 +53,7 @@ repos: # ------------------------------ Python checking ----------------------------- # - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 + rev: v4.5.0 hooks: - id: debug-statements name: Check for debugger statements @@ -77,19 +77,19 @@ repos: # ----------------------------- Automatic linters ---------------------------- # - repo: https://github.com/asottile/pyupgrade - rev: v3.3.1 + rev: v3.15.0 hooks: - id: pyupgrade name: Update syntax for newer Python types: [python] args: ["--py39-plus"] - repo: https://github.com/sirosen/texthooks - rev: 0.5.0 + rev: 0.6.3 hooks: - id: fix-smartquotes name: Fix Smart Quotes - repo: https://github.com/asottile/yesqa - rev: v1.4.0 + rev: v1.5.0 hooks: - id: yesqa name: Remove unnecessary `# noqa` comments @@ -98,7 +98,7 @@ repos: # ------------------------------ Python imports ------------------------------ # - repo: https://github.com/hakancelik96/unimport - rev: 0.16.0 + rev: 1.1.0 hooks: - id: unimport name: Remove any unused imports @@ -128,20 +128,20 @@ repos: # -------------------------------- Formatting -------------------------------- # - repo: https://github.com/pre-commit/mirrors-prettier - rev: v3.0.0-alpha.6 + rev: v3.1.0 hooks: - id: prettier name: Prettier exclude: ^.*/?CHANGELOG\.md$ - repo: https://github.com/myint/docformatter - rev: v1.6.1 + rev: v1.7.5 hooks: - id: docformatter name: Format docstrings types: [python] args: [--in-place, --wrap-summaries=99, --wrap-descriptions=99] - repo: https://github.com/psf/black - rev: 23.3.0 + rev: 23.11.0 hooks: - id: black-jupyter types: [python] diff --git a/src/emma_experience_hub/api/clients/simbot/session_db.py b/src/emma_experience_hub/api/clients/simbot/session_db.py index ccdce81c..9cb924cf 100644 --- a/src/emma_experience_hub/api/clients/simbot/session_db.py +++ b/src/emma_experience_hub/api/clients/simbot/session_db.py @@ -122,4 +122,4 @@ def _get_all_session_turns(self, session_id: str) -> list[dict[str, Any]]: all_response_items.extend(response["Items"]) - return all_response_items + return all_response_items # type: ignore[unreachable] diff --git a/src/emma_experience_hub/commands/build_emma.py b/src/emma_experience_hub/commands/build_emma.py index 13f4010a..48c839ef 100644 --- a/src/emma_experience_hub/commands/build_emma.py +++ b/src/emma_experience_hub/commands/build_emma.py @@ -16,10 +16,10 @@ def torch_cuda_version_callback(torch_version: Optional[str]) -> str: try: int(torch_version.split("+cu")[-1]) - except ValueError: + except ValueError as err: raise typer.BadParameter( "The cuda version must be castable to an integer. For example, `113` and not `11.3`" - ) + ) from err return torch_version diff --git a/src/emma_experience_hub/commands/teach/dataset.py b/src/emma_experience_hub/commands/teach/dataset.py index 370105c3..30029b05 100644 --- a/src/emma_experience_hub/commands/teach/dataset.py +++ b/src/emma_experience_hub/commands/teach/dataset.py @@ -90,7 +90,7 @@ def filter_edh_instances( edh_instances = [ instance for instance in edh_instances - if any([action.obj_interaction_action for action in instance.driver_actions_future]) + if any(action.obj_interaction_action for action in instance.driver_actions_future) ] console.log( f"{num_instances_before_filter - len(edh_instances)} EDH instances [cyan]do not have an interaction action[/] in their future. {len(edh_instances)} EDH instances remaining..." diff --git a/src/emma_experience_hub/datamodels/simbot/feedback.py b/src/emma_experience_hub/datamodels/simbot/feedback.py index 34f998c5..17ddce4f 100644 --- a/src/emma_experience_hub/datamodels/simbot/feedback.py +++ b/src/emma_experience_hub/datamodels/simbot/feedback.py @@ -177,7 +177,7 @@ def calculate_rule_score(cls, score: int, values: dict[str, Any]) -> int: # noq def is_query_suitable(self, query: dict[str, Any]) -> bool: """Evaluate the rule given the query and ensure it is suitable.""" try: - return self.rule.matches(query) and all([name in query for name in self.slot_names]) + return self.rule.matches(query) and all(name in query for name in self.slot_names) except Exception: return False @@ -323,7 +323,7 @@ def from_all_information( interaction_action.payload, SimBotObjectInteractionPayload ) if interaction_action_has_bbox: - object_area = get_area_from_compressed_mask(interaction_action.payload.object.mask) + object_area = get_area_from_compressed_mask(interaction_action.payload.object.mask) # type: ignore[union-attr] return cls( # Require a lightweight dialog action when the model does not decode a None: viewpoint=self.previous_valid_turn.environment.get_closest_viewpoint_name(), action=self.previous_valid_turn.actions.interaction, inventory_entity=self.previous_valid_turn.state.inventory.entity, - action_history=[turn.actions.interaction for turn in past_turns], # type: ignore[union-attr] + action_history=[turn.actions.interaction for turn in past_turns], # type: ignore[misc] inventory_history=[turn.state.inventory.entity for turn in past_turns], ) self.current_state.memory.update_interaction_turn_index( diff --git a/src/emma_experience_hub/parsers/simbot/action_predictor_output.py b/src/emma_experience_hub/parsers/simbot/action_predictor_output.py index 676f5dc3..15b27588 100644 --- a/src/emma_experience_hub/parsers/simbot/action_predictor_output.py +++ b/src/emma_experience_hub/parsers/simbot/action_predictor_output.py @@ -64,9 +64,9 @@ def __call__( try: decoded_action = self._separate_decoded_trajectory(decoded_trajectory)[0] - except IndexError: + except IndexError as err: # If there is a problem when decoding the action - raise IndexError("Could not decode any actions from the trajectory") + raise IndexError("Could not decode any actions from the trajectory") from err # Just use the first action, because if that is wrong, any future ones after it are likely # hallucinated diff --git a/src/emma_experience_hub/parsers/simbot/qa_output.py b/src/emma_experience_hub/parsers/simbot/qa_output.py index 02a480cf..709347be 100644 --- a/src/emma_experience_hub/parsers/simbot/qa_output.py +++ b/src/emma_experience_hub/parsers/simbot/qa_output.py @@ -84,14 +84,14 @@ def _process_intent(self, utterance: str, intent: str, entities: list[Any]) -> s return intent - def _is_regex_match_for_incomplete_utterance(self, utterance) -> bool: + def _is_regex_match_for_incomplete_utterance(self, utterance: str) -> bool: if re.search(self._incomplete_utterance_regex_pattern, utterance): logger.debug(f"found incomplete regex match for the utterance: {utterance}") return True logger.debug(f"No incomplete regex match for the utterance: {utterance}") return False - def _regex_patterns_map(self) -> bool: + def _regex_patterns_map(self) -> str: verbs = [ "find", "search", diff --git a/src/emma_experience_hub/pipelines/simbot/agent_intent_selection.py b/src/emma_experience_hub/pipelines/simbot/agent_intent_selection.py index a0d07c6b..b702a2ca 100644 --- a/src/emma_experience_hub/pipelines/simbot/agent_intent_selection.py +++ b/src/emma_experience_hub/pipelines/simbot/agent_intent_selection.py @@ -95,7 +95,7 @@ def run(self, session: SimBotSession) -> SimBotAgentIntents: # noqa: WPS212 # check for object-qa and return if session.current_turn.intent.user.is_user_qa_about_object: - return self.handle_object_qa_intent(session) + return self.handle_object_qa_intent(session) # type: ignore[return-value] # If we have received an invalid utterance, the agent does not act should_skip_action_selection = self._should_skip_action_selection( diff --git a/src/emma_experience_hub/pipelines/simbot/find_object.py b/src/emma_experience_hub/pipelines/simbot/find_object.py index 291c5d69..5a184020 100644 --- a/src/emma_experience_hub/pipelines/simbot/find_object.py +++ b/src/emma_experience_hub/pipelines/simbot/find_object.py @@ -206,6 +206,7 @@ def _should_scan_found_object( class_label=found_object_label, extracted_features=extracted_features, ) + assert scene_object_tokens.object_index is not None object_idx = scene_object_tokens.object_index - 1 area = extracted_features[frame_idx - 1].bbox_areas[object_idx].item() diff --git a/src/emma_experience_hub/pipelines/simbot/user_utterance_verification.py b/src/emma_experience_hub/pipelines/simbot/user_utterance_verification.py index a16165af..a67c3436 100644 --- a/src/emma_experience_hub/pipelines/simbot/user_utterance_verification.py +++ b/src/emma_experience_hub/pipelines/simbot/user_utterance_verification.py @@ -103,7 +103,7 @@ def _utterance_only_contains_wake_word( self, speech_recognition_payload: SimBotSpeechRecognitionPayload ) -> bool: """Detect whether the utterance only contains the wake word or not.""" - return all([token.is_wake_word for token in speech_recognition_payload.tokens]) + return all(token.is_wake_word for token in speech_recognition_payload.tokens) @tracer.start_as_current_span("Check for empty utterance") def _utterance_is_empty( diff --git a/tests/pipelines/simbot/test_language_generator.py b/tests/pipelines/simbot/test_language_generator.py index 7cfde1bc..e350cf77 100644 --- a/tests/pipelines/simbot/test_language_generator.py +++ b/tests/pipelines/simbot/test_language_generator.py @@ -59,7 +59,7 @@ def test_response_slots_in_all_rules(rule_parser: SimBotFeedbackFromSessionState ] for rule_text, slots in rules_and_slots: rule_words = rule_text.split() - assert all([slot_name in rule_words for slot_name in slots]) + assert all(slot_name in rule_words for slot_name in slots) def test_all_response_slots_are_validated_by_rules( @@ -71,10 +71,10 @@ def test_all_response_slots_are_validated_by_rules( ] for rule_text, slots in rules_and_slots: # Each slot name must be used in the rule so that it exists in some way - assert all([slot_name in rule_text for slot_name in slots]) + assert all(slot_name in rule_text for slot_name in slots) # Any slot name in the rule must not be checking for it to be equal to None - assert not any([f"{slot_name} == null" in rule_text for slot_name in slots]) + assert not any(f"{slot_name} == null" in rule_text for slot_name in slots) def test_all_rule_symbols_in_state(rule_parser: SimBotFeedbackFromSessionStateParser) -> None: