Skip to content

Commit

Permalink
fixing typing pylance errors
Browse files Browse the repository at this point in the history
  • Loading branch information
rpmcginty committed Feb 20, 2024
1 parent 9c25071 commit 60970e8
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 27 deletions.
5 changes: 4 additions & 1 deletion src/aibs_informatics_aws_utils/dynamodb/conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,10 @@ class ExpressionComponentsBase:
@cached_property
def expression_attribute_values__serialized(self) -> Dict[str, Dict[str, Any]]:
serializer = TypeSerializer()
return {k: serializer.serialize(v) for k, v in self.expression_attribute_values.items()}
return {
k: cast(Dict[str, Any], serializer.serialize(v))
for k, v in self.expression_attribute_values.items()
}


@dataclass
Expand Down
24 changes: 13 additions & 11 deletions src/aibs_informatics_aws_utils/dynamodb/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,15 +61,15 @@ def table_put_item(


def table_get_item(
table_name: str, key: Mapping[str, Any], attrs: str = None
table_name: str, key: Mapping[str, Any], attrs: Optional[str] = None
) -> Optional[Dict[str, Any]]:
table = table_as_resource(table_name)
props: GetItemInputRequestTypeDef = {"Key": key, "ReturnConsumedCapacity": "NONE"}
props: GetItemInputRequestTypeDef = {"Key": key, "ReturnConsumedCapacity": "NONE"} # type: ignore # we modify use of this type (no table name is needed here)

if attrs is not None:
props["ProjectionExpression"] = attrs

response = table.get_item(**props)
response = table.get_item(**props) # type: ignore # pylance complains about extra fields

logger.info("Response from table.get_item: %s", response)

Expand All @@ -79,8 +79,8 @@ def table_get_item(
def table_get_items(
table_name: str,
keys: List[Mapping[str, Any]],
attrs: str = None,
region: str = None,
attrs: Optional[str] = None,
region: Optional[str] = None,
) -> List[Dict[str, Any]]:
db = get_dynamodb_client(region=region)
serializer = TypeSerializer()
Expand Down Expand Up @@ -183,7 +183,7 @@ def table_query(
key_condition_expression: ConditionBase,
index_name: Optional[str] = None,
filter_expression: Optional[ConditionBase] = None,
region: str = None,
region: Optional[str] = None,
consistent_read: bool = False,
) -> List[Dict[str, Any]]:
"""Query a table
Expand Down Expand Up @@ -252,7 +252,7 @@ def table_query(
items: List[Dict[str, Any]] = []
paginator = db.get_paginator("query")
logger.info(f"Performing DB 'query' on {table.name} with following parameters: {db_request}")
for i, response in enumerate(paginator.paginate(**db_request)):
for i, response in enumerate(paginator.paginate(**db_request)): # type: ignore # pylance complains about extra fields
new_items = response.get("Items", [])
items.extend(new_items)
logger.debug(f"Iter #{i+1}: item count from table. Query: {len(new_items)}")
Expand All @@ -266,7 +266,7 @@ def table_scan(
table_name: str,
index_name: Optional[str] = None,
filter_expression: Optional[ConditionBase] = None,
region: str = None,
region: Optional[str] = None,
consistent_read: bool = False,
) -> List[Dict[str, Any]]:
"""Scan a table
Expand Down Expand Up @@ -319,7 +319,7 @@ def table_scan(
items: List[Dict[str, Any]] = []
paginator = db.get_paginator("scan")
logger.info(f"Performing DB 'scan' on {table.name} with following parameters: {db_request}")
for i, response in enumerate(paginator.paginate(**db_request)):
for i, response in enumerate(paginator.paginate(**db_request)): # type: ignore # pylance complains about extra fields
new_items = response.get("Items", [])
items.extend(new_items)
logger.debug(f"Iter #{i+1}: item count from table. Scan: {len(new_items)}")
Expand All @@ -339,7 +339,9 @@ def table_get_key_schema(table_name: str) -> Dict[str, str]:
return {k["KeyType"]: k["AttributeName"] for k in table.key_schema}


def execute_partiql_statement(statement: str, region: str = None) -> List[Dict[str, Any]]:
def execute_partiql_statement(
statement: str, region: Optional[str] = None
) -> List[Dict[str, Any]]:
db = get_dynamodb_client(region=region)

response = db.execute_statement(Statement=statement)
Expand All @@ -351,7 +353,7 @@ def execute_partiql_statement(statement: str, region: str = None) -> List[Dict[s
return results


def table_as_resource(table: str, region: str = None):
def table_as_resource(table: str, region: Optional[str] = None):
"""Helper method to get the table as a resource for given env_label
if provided.
"""
Expand Down
24 changes: 13 additions & 11 deletions src/aibs_informatics_aws_utils/dynamodb/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,20 +202,20 @@ def build_optimized_condition_expression_set(
new_condition = Key(k).eq(v)
if (
k in candidate_conditions
and candidate_conditions[k]._values[1:] != new_condition._values[1:]
and candidate_conditions[k]._values[1:] != new_condition._values[1:] # type: ignore[union-attr]
):
raise DBQueryException(f"Multiple values provided for attribute {k}!")
candidate_conditions[k] = Key(k).eq(v)
elif len(_._values) and isinstance(_._values[0], (Key, Attr)):
attr_name = cast(str, _._values[0].name)
elif len(_._values) and isinstance(_._values[0], (Key, Attr)): # type: ignore[union-attr]
attr_name = cast(str, _._values[0].name) # type: ignore[union-attr]
if attr_name not in index_all_key_names or not isinstance(
_, SupportedKeyComparisonTypes
):
non_candidate_conditions.append(_)
continue
if (
attr_name in candidate_conditions
and candidate_conditions[attr_name]._values[1:] != _._values[1:]
and candidate_conditions[attr_name]._values[1:] != _._values[1:] # type: ignore[union-attr]
):
raise DBQueryException(f"Multiple values provided for attribute {attr_name}!")
candidate_conditions[attr_name] = _
Expand All @@ -228,12 +228,12 @@ def build_optimized_condition_expression_set(
):
target_index = index
partition_key = candidate_conditions.pop(index.key_name)
partition_key._values = (Key(index.key_name), *partition_key._values[1:])
partition_key._values = (Key(index.key_name), *partition_key._values[1:]) # type: ignore[union-attr]
if index.sort_key_name is not None and index.sort_key_name in candidate_conditions:
sort_key_condition_expression = candidate_conditions.pop(index.sort_key_name)
sort_key_condition_expression._values = (
sort_key_condition_expression._values = ( # type: ignore[union-attr]
Key(index.sort_key_name),
*sort_key_condition_expression._values[1:],
*sort_key_condition_expression._values[1:], # type: ignore[union-attr]
)
break

Expand Down Expand Up @@ -315,7 +315,9 @@ def build_key(
return key
index = cls.index_or_default(index)
return (
index.get_primary_key(*key) if isinstance(key, tuple) else index.get_primary_key(key)
index.get_primary_key(key[0], key[1])
if isinstance(key, tuple)
else index.get_primary_key(key)
)

# --------------------------------------------------------------------------
Expand Down Expand Up @@ -386,8 +388,8 @@ def batch_get(
items = table_get_items(table_name=self.table_name, keys=item_keys)
if len(items) != len(item_keys) and not ignore_missing:
missing_keys = set(
[(_[index.key_name], _.get(index.sort_key_name)) for _ in item_keys]
).difference((_[index.key_name], _.get(index.sort_key_name)) for _ in items)
[(_[index.key_name], _.get(index.sort_key_name or "")) for _ in item_keys]
).difference((_[index.key_name], _.get(index.sort_key_name or "")) for _ in items)

raise DBReadException(f"Could not find items for {missing_keys}")
entries = [self.build_entry(_, partial=partial) for _ in items]
Expand Down Expand Up @@ -704,7 +706,7 @@ def delete(
e_msg = f"{self.table_name} - Delete failed for the following primary key: {key}"
try:
deleted_attributes = table_delete_item(
table_name=self.table_name, key=key, return_values="ALL_OLD"
table_name=self.table_name, key=key, return_values="ALL_OLD" # type: ignore[arg-type] # expected type more general than specified here
)

if not deleted_attributes:
Expand Down
6 changes: 3 additions & 3 deletions src/aibs_informatics_aws_utils/efs/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def list_efs_file_systems(

file_systems: List[FileSystemDescriptionTypeDef] = []
paginator_kwargs = remove_null_values(dict(FileSystemId=file_system_id))
for results in paginator.paginate(**paginator_kwargs):
for results in paginator.paginate(**paginator_kwargs): # type: ignore
for fs in results["FileSystems"]:
if name and fs.get("Name") != name:
continue
Expand Down Expand Up @@ -159,12 +159,12 @@ def list_efs_access_points(

for fs_id in file_system_ids:
response = efs.describe_access_points(
**remove_null_values(dict(AccessPointId=access_point_id, FileSystemId=fs_id))
**remove_null_values(dict(AccessPointId=access_point_id, FileSystemId=fs_id)) # type: ignore
)
access_points.extend(response["AccessPoints"])
while response.get("NextToken"):
response = efs.describe_access_points(
**remove_null_values(
**remove_null_values( # type: ignore
dict(
AccessPointId=access_point_id,
FileSystemId=fs_id,
Expand Down
2 changes: 1 addition & 1 deletion src/aibs_informatics_aws_utils/efs/mount_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def is_mounted_path(self, path: StrPath) -> bool:
def as_env_vars(self, name: Optional[str] = None) -> Dict[str, str]:
"""Converts the mount point configuration to environment variables."""
if self.access_point and self.access_point.get("AccessPointId"):
mount_point_id = self.access_point.get("AccessPointId")
mount_point_id = self.access_point["AccessPointId"] # type: ignore # pylance complains even though we checked
else:
mount_point_id = self.file_system["FileSystemId"]

Expand Down

0 comments on commit 60970e8

Please sign in to comment.