From 60970e85b84aa002f1623bc92331fd4afa147b07 Mon Sep 17 00:00:00 2001 From: Ryan McGinty Date: Mon, 19 Feb 2024 18:08:29 -0800 Subject: [PATCH] fixing typing pylance errors --- .../dynamodb/conditions.py | 5 +++- .../dynamodb/functions.py | 24 ++++++++++--------- .../dynamodb/table.py | 24 ++++++++++--------- src/aibs_informatics_aws_utils/efs/core.py | 6 ++--- .../efs/mount_point.py | 2 +- 5 files changed, 34 insertions(+), 27 deletions(-) diff --git a/src/aibs_informatics_aws_utils/dynamodb/conditions.py b/src/aibs_informatics_aws_utils/dynamodb/conditions.py index 7bfe4f7..352e3e6 100644 --- a/src/aibs_informatics_aws_utils/dynamodb/conditions.py +++ b/src/aibs_informatics_aws_utils/dynamodb/conditions.py @@ -33,7 +33,10 @@ class ExpressionComponentsBase: @cached_property def expression_attribute_values__serialized(self) -> Dict[str, Dict[str, Any]]: serializer = TypeSerializer() - return {k: serializer.serialize(v) for k, v in self.expression_attribute_values.items()} + return { + k: cast(Dict[str, Any], serializer.serialize(v)) + for k, v in self.expression_attribute_values.items() + } @dataclass diff --git a/src/aibs_informatics_aws_utils/dynamodb/functions.py b/src/aibs_informatics_aws_utils/dynamodb/functions.py index c96c297..56275e9 100644 --- a/src/aibs_informatics_aws_utils/dynamodb/functions.py +++ b/src/aibs_informatics_aws_utils/dynamodb/functions.py @@ -61,15 +61,15 @@ def table_put_item( def table_get_item( - table_name: str, key: Mapping[str, Any], attrs: str = None + table_name: str, key: Mapping[str, Any], attrs: Optional[str] = None ) -> Optional[Dict[str, Any]]: table = table_as_resource(table_name) - props: GetItemInputRequestTypeDef = {"Key": key, "ReturnConsumedCapacity": "NONE"} + props: GetItemInputRequestTypeDef = {"Key": key, "ReturnConsumedCapacity": "NONE"} # type: ignore # we modify use of this type (no table name is needed here) if attrs is not None: props["ProjectionExpression"] = attrs - response = table.get_item(**props) + response = table.get_item(**props) # type: ignore # pylance complains about extra fields logger.info("Response from table.get_item: %s", response) @@ -79,8 +79,8 @@ def table_get_item( def table_get_items( table_name: str, keys: List[Mapping[str, Any]], - attrs: str = None, - region: str = None, + attrs: Optional[str] = None, + region: Optional[str] = None, ) -> List[Dict[str, Any]]: db = get_dynamodb_client(region=region) serializer = TypeSerializer() @@ -183,7 +183,7 @@ def table_query( key_condition_expression: ConditionBase, index_name: Optional[str] = None, filter_expression: Optional[ConditionBase] = None, - region: str = None, + region: Optional[str] = None, consistent_read: bool = False, ) -> List[Dict[str, Any]]: """Query a table @@ -252,7 +252,7 @@ def table_query( items: List[Dict[str, Any]] = [] paginator = db.get_paginator("query") logger.info(f"Performing DB 'query' on {table.name} with following parameters: {db_request}") - for i, response in enumerate(paginator.paginate(**db_request)): + for i, response in enumerate(paginator.paginate(**db_request)): # type: ignore # pylance complains about extra fields new_items = response.get("Items", []) items.extend(new_items) logger.debug(f"Iter #{i+1}: item count from table. Query: {len(new_items)}") @@ -266,7 +266,7 @@ def table_scan( table_name: str, index_name: Optional[str] = None, filter_expression: Optional[ConditionBase] = None, - region: str = None, + region: Optional[str] = None, consistent_read: bool = False, ) -> List[Dict[str, Any]]: """Scan a table @@ -319,7 +319,7 @@ def table_scan( items: List[Dict[str, Any]] = [] paginator = db.get_paginator("scan") logger.info(f"Performing DB 'scan' on {table.name} with following parameters: {db_request}") - for i, response in enumerate(paginator.paginate(**db_request)): + for i, response in enumerate(paginator.paginate(**db_request)): # type: ignore # pylance complains about extra fields new_items = response.get("Items", []) items.extend(new_items) logger.debug(f"Iter #{i+1}: item count from table. Scan: {len(new_items)}") @@ -339,7 +339,9 @@ def table_get_key_schema(table_name: str) -> Dict[str, str]: return {k["KeyType"]: k["AttributeName"] for k in table.key_schema} -def execute_partiql_statement(statement: str, region: str = None) -> List[Dict[str, Any]]: +def execute_partiql_statement( + statement: str, region: Optional[str] = None +) -> List[Dict[str, Any]]: db = get_dynamodb_client(region=region) response = db.execute_statement(Statement=statement) @@ -351,7 +353,7 @@ def execute_partiql_statement(statement: str, region: str = None) -> List[Dict[s return results -def table_as_resource(table: str, region: str = None): +def table_as_resource(table: str, region: Optional[str] = None): """Helper method to get the table as a resource for given env_label if provided. """ diff --git a/src/aibs_informatics_aws_utils/dynamodb/table.py b/src/aibs_informatics_aws_utils/dynamodb/table.py index 9f0995f..082a917 100644 --- a/src/aibs_informatics_aws_utils/dynamodb/table.py +++ b/src/aibs_informatics_aws_utils/dynamodb/table.py @@ -202,12 +202,12 @@ def build_optimized_condition_expression_set( new_condition = Key(k).eq(v) if ( k in candidate_conditions - and candidate_conditions[k]._values[1:] != new_condition._values[1:] + and candidate_conditions[k]._values[1:] != new_condition._values[1:] # type: ignore[union-attr] ): raise DBQueryException(f"Multiple values provided for attribute {k}!") candidate_conditions[k] = Key(k).eq(v) - elif len(_._values) and isinstance(_._values[0], (Key, Attr)): - attr_name = cast(str, _._values[0].name) + elif len(_._values) and isinstance(_._values[0], (Key, Attr)): # type: ignore[union-attr] + attr_name = cast(str, _._values[0].name) # type: ignore[union-attr] if attr_name not in index_all_key_names or not isinstance( _, SupportedKeyComparisonTypes ): @@ -215,7 +215,7 @@ def build_optimized_condition_expression_set( continue if ( attr_name in candidate_conditions - and candidate_conditions[attr_name]._values[1:] != _._values[1:] + and candidate_conditions[attr_name]._values[1:] != _._values[1:] # type: ignore[union-attr] ): raise DBQueryException(f"Multiple values provided for attribute {attr_name}!") candidate_conditions[attr_name] = _ @@ -228,12 +228,12 @@ def build_optimized_condition_expression_set( ): target_index = index partition_key = candidate_conditions.pop(index.key_name) - partition_key._values = (Key(index.key_name), *partition_key._values[1:]) + partition_key._values = (Key(index.key_name), *partition_key._values[1:]) # type: ignore[union-attr] if index.sort_key_name is not None and index.sort_key_name in candidate_conditions: sort_key_condition_expression = candidate_conditions.pop(index.sort_key_name) - sort_key_condition_expression._values = ( + sort_key_condition_expression._values = ( # type: ignore[union-attr] Key(index.sort_key_name), - *sort_key_condition_expression._values[1:], + *sort_key_condition_expression._values[1:], # type: ignore[union-attr] ) break @@ -315,7 +315,9 @@ def build_key( return key index = cls.index_or_default(index) return ( - index.get_primary_key(*key) if isinstance(key, tuple) else index.get_primary_key(key) + index.get_primary_key(key[0], key[1]) + if isinstance(key, tuple) + else index.get_primary_key(key) ) # -------------------------------------------------------------------------- @@ -386,8 +388,8 @@ def batch_get( items = table_get_items(table_name=self.table_name, keys=item_keys) if len(items) != len(item_keys) and not ignore_missing: missing_keys = set( - [(_[index.key_name], _.get(index.sort_key_name)) for _ in item_keys] - ).difference((_[index.key_name], _.get(index.sort_key_name)) for _ in items) + [(_[index.key_name], _.get(index.sort_key_name or "")) for _ in item_keys] + ).difference((_[index.key_name], _.get(index.sort_key_name or "")) for _ in items) raise DBReadException(f"Could not find items for {missing_keys}") entries = [self.build_entry(_, partial=partial) for _ in items] @@ -704,7 +706,7 @@ def delete( e_msg = f"{self.table_name} - Delete failed for the following primary key: {key}" try: deleted_attributes = table_delete_item( - table_name=self.table_name, key=key, return_values="ALL_OLD" + table_name=self.table_name, key=key, return_values="ALL_OLD" # type: ignore[arg-type] # expected type more general than specified here ) if not deleted_attributes: diff --git a/src/aibs_informatics_aws_utils/efs/core.py b/src/aibs_informatics_aws_utils/efs/core.py index dae1e08..43b2ce4 100644 --- a/src/aibs_informatics_aws_utils/efs/core.py +++ b/src/aibs_informatics_aws_utils/efs/core.py @@ -69,7 +69,7 @@ def list_efs_file_systems( file_systems: List[FileSystemDescriptionTypeDef] = [] paginator_kwargs = remove_null_values(dict(FileSystemId=file_system_id)) - for results in paginator.paginate(**paginator_kwargs): + for results in paginator.paginate(**paginator_kwargs): # type: ignore for fs in results["FileSystems"]: if name and fs.get("Name") != name: continue @@ -159,12 +159,12 @@ def list_efs_access_points( for fs_id in file_system_ids: response = efs.describe_access_points( - **remove_null_values(dict(AccessPointId=access_point_id, FileSystemId=fs_id)) + **remove_null_values(dict(AccessPointId=access_point_id, FileSystemId=fs_id)) # type: ignore ) access_points.extend(response["AccessPoints"]) while response.get("NextToken"): response = efs.describe_access_points( - **remove_null_values( + **remove_null_values( # type: ignore dict( AccessPointId=access_point_id, FileSystemId=fs_id, diff --git a/src/aibs_informatics_aws_utils/efs/mount_point.py b/src/aibs_informatics_aws_utils/efs/mount_point.py index 47f9f19..d4ca314 100644 --- a/src/aibs_informatics_aws_utils/efs/mount_point.py +++ b/src/aibs_informatics_aws_utils/efs/mount_point.py @@ -225,7 +225,7 @@ def is_mounted_path(self, path: StrPath) -> bool: def as_env_vars(self, name: Optional[str] = None) -> Dict[str, str]: """Converts the mount point configuration to environment variables.""" if self.access_point and self.access_point.get("AccessPointId"): - mount_point_id = self.access_point.get("AccessPointId") + mount_point_id = self.access_point["AccessPointId"] # type: ignore # pylance complains even though we checked else: mount_point_id = self.file_system["FileSystemId"]