Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

repo-review round 2 #485

Merged
merged 6 commits into from
Dec 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 44 additions & 42 deletions cf_xarray/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
import inspect
import itertools
import re
import warnings
from collections import ChainMap, namedtuple
from collections.abc import Hashable, Iterable, Mapping, MutableMapping, Sequence
from datetime import datetime
from typing import (
Any,
Callable,
Literal,
TypeVar,
Union,
cast,
Expand Down Expand Up @@ -48,6 +48,7 @@
_get_version,
_is_datetime_like,
always_iterable,
emit_user_level_warning,
invert_mappings,
parse_cell_methods_attr,
parse_cf_standard_name_table,
Expand Down Expand Up @@ -107,7 +108,7 @@ def apply_mapper(
"""

if not isinstance(key, Hashable):
if default is None:
if default is None: # type: ignore[unreachable]
raise ValueError(
"`default` must be provided when `key` is not not a valid DataArray name (of hashable type)."
)
Expand Down Expand Up @@ -224,7 +225,7 @@ def _get_custom_criteria(
try:
from regex import match as regex_match
except ImportError:
from re import match as regex_match # type: ignore
from re import match as regex_match # type: ignore[no-redef]

if isinstance(obj, DataArray):
obj = obj._to_temp_dataset()
Expand Down Expand Up @@ -363,8 +364,6 @@ def _get_measure(obj: DataArray | Dataset, key: str) -> list[str]:
if key in measures:
results.update([measures[key]])

if isinstance(results, str):
return [results]
return list(results)


Expand Down Expand Up @@ -471,7 +470,7 @@ def _get_all(obj: DataArray | Dataset, key: Hashable) -> list[Hashable]:
"""
all_mappers: tuple[Mapper] = (
_get_custom_criteria,
functools.partial(_get_custom_criteria, criteria=cf_role_criteria), # type: ignore
functools.partial(_get_custom_criteria, criteria=cf_role_criteria), # type: ignore[assignment]
functools.partial(_get_custom_criteria, criteria=grid_mapping_var_criteria),
_get_axis_coord,
_get_measure,
Expand Down Expand Up @@ -653,10 +652,10 @@ def _getattr(
):
raise AttributeError(
f"{obj.__class__.__name__+'.cf'!r} object has no attribute {attr!r}"
)
) from None
raise AttributeError(
f"{attr!r} is not a valid attribute on the underlying xarray object."
)
) from None

if isinstance(attribute, Mapping):
if not attribute:
Expand All @@ -680,7 +679,7 @@ def _getattr(
newmap.update(dict.fromkeys(inverted[key], value))
newmap.update({key: attribute[key] for key in unused_keys})

skip: dict[str, list[Hashable] | None] = {
skip: dict[str, list[Literal["coords", "measures"]] | None] = {
"data_vars": ["coords"],
"coords": None,
}
Expand All @@ -689,7 +688,7 @@ def _getattr(
newmap[key] = _getitem(accessor, key, skip=skip[attr])
return newmap

elif isinstance(attribute, Callable): # type: ignore
elif isinstance(attribute, Callable): # type: ignore[arg-type]
func: Callable = attribute

else:
Expand Down Expand Up @@ -721,7 +720,7 @@ def wrapper(*args, **kwargs):
def _getitem(
accessor: CFAccessor,
key: Hashable,
skip: list[Hashable] | None = None,
skip: list[Literal["coords", "measures"]] | None = None,
) -> DataArray:
...

Expand All @@ -730,15 +729,15 @@ def _getitem(
def _getitem(
accessor: CFAccessor,
key: Iterable[Hashable],
skip: list[Hashable] | None = None,
skip: list[Literal["coords", "measures"]] | None = None,
) -> Dataset:
...


def _getitem(
accessor,
key,
skip=None,
accessor: CFAccessor,
key: Hashable | Iterable[Hashable],
skip: list[Literal["coords", "measures"]] | None = None,
):
"""
Index into obj using key. Attaches CF associated variables.
Expand Down Expand Up @@ -789,7 +788,7 @@ def check_results(names, key):
measures = accessor._get_all_cell_measures()
except ValueError:
measures = []
warnings.warn("Ignoring bad cell_measures attribute.", UserWarning)
emit_user_level_warning("Ignoring bad cell_measures attribute.", UserWarning)

if isinstance(obj, Dataset):
grid_mapping_names = list(accessor.grid_mapping_names)
Expand Down Expand Up @@ -852,6 +851,7 @@ def check_results(names, key):
)
coords.extend(itertools.chain(*extravars.values()))

ds: Dataset
if isinstance(obj, DataArray):
ds = obj._to_temp_dataset()
else:
Expand All @@ -860,7 +860,7 @@ def check_results(names, key):
if scalar_key:
if len(allnames) == 1:
(name,) = allnames
da: DataArray = ds.reset_coords()[name] # type: ignore
da: DataArray = ds.reset_coords()[name]
if name in coords:
coords.remove(name)
for k1 in coords:
Expand All @@ -877,26 +877,27 @@ def check_results(names, key):

ds = ds.reset_coords()[varnames + coords]
if isinstance(obj, DataArray):
if scalar_key and len(ds.variables) == 1:
# single dimension coordinates
assert coords
assert not varnames
if scalar_key:
if len(ds.variables) == 1: # type: ignore[unreachable]
# single dimension coordinates
assert coords
assert not varnames

return ds[coords[0]]
return ds[coords[0]]

elif scalar_key and len(ds.variables) > 1:
raise NotImplementedError(
"Not sure what to return when given scalar key for DataArray and it has multiple values. "
"Please open an issue."
)
else:
raise NotImplementedError(
"Not sure what to return when given scalar key for DataArray and it has multiple values. "
"Please open an issue."
)

return ds.set_coords(coords)

except KeyError:
raise KeyError(
f"{kind}.cf does not understand the key {k!r}. "
f"Use 'repr({kind}.cf)' (or '{kind}.cf' in a Jupyter environment) to see a list of key names that can be interpreted."
)
) from None


def _possible_x_y_plot(obj, key, skip=None):
Expand Down Expand Up @@ -1135,7 +1136,7 @@ def _assert_valid_other_comparison(self, other):
)
return flag_dict

def __eq__(self, other) -> DataArray: # type: ignore
def __eq__(self, other) -> DataArray: # type: ignore[override]
"""
Compare flag values against ``other``.

Expand All @@ -1155,7 +1156,7 @@ def __eq__(self, other) -> DataArray: # type: ignore
"""
return self._extract_flags([other])[other].rename(self._obj.name)

def __ne__(self, other) -> DataArray: # type: ignore
def __ne__(self, other) -> DataArray: # type: ignore[override]
"""
Compare flag values against ``other``.

Expand Down Expand Up @@ -1328,7 +1329,7 @@ def curvefit(
coords_iter = coords
coords = [
apply_mapper(
[_single(_get_coords)], self._obj, v, error=False, default=[v] # type: ignore
[_single(_get_coords)], self._obj, v, error=False, default=[v] # type: ignore[arg-type]
)[0]
for v in coords_iter
]
Expand All @@ -1339,7 +1340,7 @@ def curvefit(
reduce_dims_iter = list(reduce_dims)
reduce_dims = [
apply_mapper(
[_single(_get_dims)], self._obj, v, error=False, default=[v] # type: ignore
[_single(_get_dims)], self._obj, v, error=False, default=[v] # type: ignore[arg-type]
)[0]
for v in reduce_dims_iter
]
Expand Down Expand Up @@ -1435,7 +1436,7 @@ def _rewrite_values(

# allow multiple return values here.
# these are valid for .sel, .isel, .coarsen
all_mappers = ChainMap( # type: ignore
all_mappers = ChainMap( # type: ignore[misc]
key_mappers,
dict.fromkeys(var_kws, (_get_all,)),
)
Expand Down Expand Up @@ -1531,7 +1532,7 @@ def describe(self):
Print a string repr to screen.
"""

warnings.warn(
emit_user_level_warning(
"'obj.cf.describe()' will be removed in a future version. "
"Use instead 'repr(obj.cf)' or 'obj.cf' in a Jupyter environment.",
DeprecationWarning,
Expand Down Expand Up @@ -1695,10 +1696,9 @@ def cell_measures(self) -> dict[str, list[Hashable]]:
bad_vars = list(
as_dataset.filter_by_attrs(cell_measures=attr).data_vars.keys()
)
warnings.warn(
emit_user_level_warning(
f"Ignoring bad cell_measures attribute: {attr} on {bad_vars}.",
UserWarning,
stacklevel=2,
)
measures = {
key: self._drop_missing_variables(_get_all(self._obj, key)) for key in keys
Expand Down Expand Up @@ -1816,9 +1816,9 @@ def get_associated_variable_names(
except ValueError as e:
if error:
msg = e.args[0] + " Ignore this error by passing 'error=False'"
raise ValueError(msg)
raise ValueError(msg) from None
else:
warnings.warn(
emit_user_level_warning(
f"Ignoring bad cell_measures attribute: {attrs_or_encoding['cell_measures']}",
UserWarning,
)
Expand Down Expand Up @@ -1850,7 +1850,7 @@ def get_associated_variable_names(
missing = set(allvars) - set(self._maybe_to_dataset()._variables)
if missing:
if OPTIONS["warn_on_missing_variables"]:
warnings.warn(
emit_user_level_warning(
f"Variables {missing!r} not found in object but are referred to in the CF attributes.",
UserWarning,
)
Expand Down Expand Up @@ -1963,7 +1963,7 @@ def get_renamer_and_conflicts(keydict):

# Rename and warn
if conflicts:
warnings.warn(
emit_user_level_warning(
"Conflicting variables skipped:\n"
+ "\n".join(
[
Expand Down Expand Up @@ -2684,10 +2684,12 @@ def decode_vertical_coords(self, *, outnames=None, prefix=None):
try:
zname = outnames[dim]
except KeyError:
raise KeyError("Your `outnames` need to include a key of `dim`.")
raise KeyError(
"Your `outnames` need to include a key of `dim`."
) from None

else:
warnings.warn(
emit_user_level_warning(
"`prefix` is being deprecated; use `outnames` instead.",
DeprecationWarning,
)
Expand Down
4 changes: 2 additions & 2 deletions cf_xarray/criteria.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
try:
import regex as re
except ImportError:
import re # type: ignore
import re # type: ignore[no-redef]

from collections.abc import Mapping, MutableMapping
from typing import Any
Expand Down Expand Up @@ -128,7 +128,7 @@
coordinate_criteria["time"] = coordinate_criteria["T"]

# "long_name" and "standard_name" criteria are the same. For convenience.
for coord, attrs in coordinate_criteria.items():
for coord in coordinate_criteria:
coordinate_criteria[coord]["long_name"] = coordinate_criteria[coord][
"standard_name"
]
Expand Down
24 changes: 9 additions & 15 deletions cf_xarray/formatting.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import warnings
from collections.abc import Hashable, Iterable
from functools import partial
Expand All @@ -10,7 +12,7 @@
try:
from rich.table import Table
except ImportError:
Table = None # type: ignore
Table = None # type: ignore[assignment, misc]


def _format_missing_row(row: str, rich: bool) -> str:
Expand Down Expand Up @@ -41,7 +43,7 @@ def _format_cf_name(name: str, rich: bool) -> str:
def make_text_section(
accessor,
subtitle: str,
attr: str,
attr: str | dict,
dims=None,
valid_keys=None,
valid_values=None,
Expand Down Expand Up @@ -140,10 +142,10 @@ def _maybe_panel(textgen, title: str, rich: bool):
width=100,
)
if isinstance(textgen, Table):
return Panel(textgen, padding=(0, 20), **kwargs) # type: ignore
return Panel(textgen, padding=(0, 20), **kwargs) # type: ignore[arg-type]
else:
text = "".join(textgen)
return Panel(f"[color(241)]{text.rstrip()}[/color(241)]", **kwargs) # type: ignore
return Panel(f"[color(241)]{text.rstrip()}[/color(241)]", **kwargs) # type: ignore[arg-type]
else:
text = "".join(textgen)
return title + ":\n" + text
Expand Down Expand Up @@ -220,22 +222,14 @@ def _format_flags(accessor, rich):
table.add_column("Value", justify="right")
table.add_column("Bits", justify="center")

for val, bit, (key, (mask, value)) in zip(
value_text, bit_text, flag_dict.items()
):
table.add_row(
_format_cf_name(key, rich),
val,
bit,
)
for val, bit, key in zip(value_text, bit_text, flag_dict):
table.add_row(_format_cf_name(key, rich), val, bit)

return table

else:
rows = []
for val, bit, (key, (mask, value)) in zip(
value_text, bit_text, flag_dict.items()
):
for val, bit, key in zip(value_text, bit_text, flag_dict):
rows.append(f"{TAB}{_format_cf_name(key, rich)}: {TAB} {val} {bit}")
return _print_rows("Flag Meanings", rows, rich)

Expand Down
4 changes: 2 additions & 2 deletions cf_xarray/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class set_options: # numpydoc ignore=PR01,PR02

def __init__(self, **kwargs):
self.old = {}
for k, v in kwargs.items():
for k in kwargs:
if k not in OPTIONS:
raise ValueError(
f"argument name {k!r} is not in the set of valid options {set(OPTIONS)!r}"
Expand All @@ -58,7 +58,7 @@ def __init__(self, **kwargs):

def _apply_update(self, options_dict):
options_dict = copy.deepcopy(options_dict)
for k, v in options_dict.items():
for k in options_dict:
if k == "custom_criteria":
options_dict["custom_criteria"] = always_iterable(
options_dict["custom_criteria"], allowed=(tuple, list)
Expand Down
Loading