Commit 69a7b68

ruff update stuff
AKuederle committed Aug 29, 2023
1 parent bd49c40 commit 69a7b68
Showing 12 changed files with 71 additions and 61 deletions.
3 changes: 2 additions & 1 deletion docs/conf.py
@@ -95,7 +95,7 @@ def substitute(matchobj):
     mathjax_path = ""
 else:
     extensions.append("sphinx.ext.mathjax")
-    mathjax_path = "https://cdn.jsdelivr.net/npm/mathjax@3/es5/" "tex-chtml.js"
+    mathjax_path = "https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-chtml.js"
 
 autodoc_default_options = {"members": True, "inherited-members": True, "special_members": True}
 # autodoc_typehints = 'description'  # Does not work as expected. Maybe try at future date again
@@ -202,6 +202,7 @@ def skip_properties(app, what, name, obj, skip, options):
     """This removes all properties from the documentation as they are expected to be documented in the docstring."""
     if isinstance(obj, property):
         return True
+    return None
 
 
 def setup(app):
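An aside on the `return None` added above (an illustration, not part of the commit): ruff's flake8-return rules (most likely RET503 here) require a function that explicitly returns a value on one path to return explicitly on every path. A minimal sketch of the enforced pattern, using a hypothetical function name:

    def skip_properties_sketch(obj):
        if isinstance(obj, property):
            return True  # one path returns a value...
        return None      # ...so ruff wants the fall-through path to be explicit as well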
37 changes: 19 additions & 18 deletions poetry.lock

Some generated files are not rendered by default.

10 changes: 6 additions & 4 deletions pyproject.toml
@@ -45,7 +45,7 @@ memory-profiler = "^0.58.0"
 matplotlib = "^3.4.3"
 toml = "^0.10.2"
 Sphinx = "^6.1.3"
-ruff = "^0.0.235"
+ruff = "^0.0.286"
 
 
 [[tool.poetry.source]]
@@ -156,8 +156,6 @@ ignore = [
     "EM101",
     "EM102",
     "EM103",
-    # Multiline docstring summary
-    "D213",
     # Varaibles before return
     "RET504",
     # Abstract raise into inner function
@@ -169,7 +167,11 @@ ignore = [
     # df as varaible name
     "PD901",
     # melt over stack
-    "PD013"
+    "PD013",
+    # Avoid specifying long messages outside the exception class
+    "TRY003",
+    # To many arguments
+    "PLR0913"
 ]


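For context on the two newly ignored rules (an illustrative sketch, not part of the commit): TRY003 flags long error messages built at the raise site instead of inside the exception class, and PLR0913 flags functions with too many arguments (five by default). Hypothetical examples of what each would complain about:

    def load_config(path):
        # TRY003: long message specified outside the exception class
        raise ValueError(f"Could not parse {path}: expected a TOML file with a [tool.ruff] section")

    # PLR0913: too many arguments
    def fit(data, model, learning_rate, epochs, batch_size, seed):
        ...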
2 changes: 1 addition & 1 deletion tests/test_base.py
@@ -291,7 +291,7 @@ def test_nested_mutable_algorithm_copy():
     assert (
         joblib.hash(test_instance.mutable.get_params())
         == joblib.hash(nested_instance.get_params())
-        == joblib.hash({k: f for k, f in nested_params.items()})
+        == joblib.hash(dict(nested_params.items()))
     )


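The rewritten assertion reflects a comprehension cleanup (likely ruff's flake8-comprehensions rule C416): a dict comprehension that copies keys and values unchanged is just a dict copy. A small sketch with hypothetical data:

    nested_params = {"alpha": 1, "beta": 2}
    assert {k: v for k, v in nested_params.items()} == dict(nested_params.items()) == dict(nested_params)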
4 changes: 2 additions & 2 deletions tests/test_dataset.py
@@ -61,7 +61,7 @@ def _create_valid_index(input_dict=None, columns_names=None):
 
 def _create_random_bool_map(n, seed):
     np.random.seed(seed)
-    return list(map(lambda x: x >= 0.5, np.random.rand(n)))
+    return [x >= 0.5 for x in np.random.rand(n)]
 
 
 class TestDataset:
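This change matches ruff's unnecessary-map rule (likely C417): `list(map(lambda ...))` becomes an equivalent list comprehension. A quick sketch showing the two forms agree:

    import numpy as np

    np.random.seed(0)  # hypothetical seed
    values = np.random.rand(5)
    assert list(map(lambda x: x >= 0.5, values)) == [x >= 0.5 for x in values]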
@@ -340,7 +340,7 @@ def test_getitem_error_input(self, subscript, select_lvl, what_to_expect):
         "groupby_level", (["patients"], ["patients", "tests"], ["patients", "tests", "extra with space"])
     )
     @pytest.mark.parametrize(
-        "index,is_single_level",
+        ("index", "is_single_level"),
         (
             (
                 _create_valid_index(
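The parametrize change follows ruff's flake8-pytest-style conventions (likely PT006), which prefer a tuple of argument names over a single comma-separated string. A minimal hypothetical sketch:

    import pytest

    @pytest.mark.parametrize(("index", "is_single_level"), [(1, True), (2, False)])
    def test_example(index, is_single_level):
        assert isinstance(is_single_level, bool)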
7 changes: 2 additions & 5 deletions tests/test_parameter_string_annot.py
@@ -37,9 +37,6 @@ def __init__(self, hyper: int, normal: str, custom_annotated: int, normal_no_ann
 
 
 def test_import_forward():
-    if TYPE_CHECKING:
-        pass
-
     class Test(BaseTpcpObject):
         hyper: HyperPara[int]
         normal: Para[renamed_optimize]
@@ -85,11 +82,11 @@ def __init__(self, hyper: int, normal: optimize.GridSearch, custom_annotated: in
 def test_test_str_based_forward():
     class Test(BaseTpcpObject):
         hyper: HyperPara[int]
-        normal: Para["Dataset"]
+        normal: Para[Dataset]
         custom_annotated: Annotated[HyperPara[int], "custom_metadata"]
         normal_no_annot: int
 
-        def __init__(self, hyper: int, normal: "Dataset", custom_annotated: int, normal_no_annot: int):
+        def __init__(self, hyper: int, normal: Dataset, custom_annotated: int, normal_no_annot: int):
             self.hyper = hyper
             self.normal = normal
             self.custom_annotated = custom_annotated
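Dropping the quotes around `Dataset` matches ruff's pyupgrade handling of quoted annotations (likely UP037): once the name is actually importable where the annotation is evaluated, the string form is unnecessary. A hypothetical sketch:

    class Dataset: ...

    def run_old(data: "Dataset") -> None: ...  # quoted annotation: flagged
    def run_new(data: Dataset) -> None: ...    # unquoted form: preferred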
24 changes: 12 additions & 12 deletions tpcp/_algorithm_utils.py
@@ -11,13 +11,14 @@
 
 from typing_extensions import Concatenate, ParamSpec
 
+from tpcp import Algorithm
 from tpcp._base import NOTHING, _get_annotated_fields_of_type
 from tpcp._hash import custom_hash
 from tpcp._parameters import _ParaTypes
 from tpcp.exceptions import PotentialUserErrorWarning
 
 if TYPE_CHECKING:
-    from tpcp import Algorithm, OptimizableAlgorithm, OptimizablePipeline
+    from tpcp import OptimizableAlgorithm, OptimizablePipeline
     from tpcp._algorithm import AlgorithmT
 
     OptimizableT = TypeVar("OptimizableT", OptimizablePipeline, OptimizableAlgorithm)
@@ -140,11 +141,9 @@ def _check_safe_run(algorithm: AlgorithmT, old_method: Callable, *args: Any, **k
     before_paras = algorithm.get_params()
     before_paras_hash = custom_hash(before_paras)
     output: AlgorithmT
-    if hasattr(old_method, "__self__"):
-        # In this case the method is already bound and we do not need to pass the algo as first argument
-        output = old_method(*args, **kwargs)
-    else:
-        output = old_method(algorithm, *args, **kwargs)
-
+    # In this case the method is already bound and we do not need to pass the algo as first argument
+    output = old_method(*args, **kwargs) if hasattr(old_method, "__self__") else old_method(algorithm, *args, **kwargs)
     after_paras = algorithm.get_params()
     after_paras_hash = custom_hash(after_paras)
     if not before_paras_hash == after_paras_hash:
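Collapsing the if/else into a conditional expression matches ruff's flake8-simplify rules (likely SIM108). A generic sketch of the transformation, with hypothetical names:

    def call(method, fallback_arg):
        # Before:
        # if hasattr(method, "__self__"):
        #     result = method()
        # else:
        #     result = method(fallback_arg)
        # After:
        result = method() if hasattr(method, "__self__") else method(fallback_arg)
        return result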
@@ -220,6 +219,7 @@ def safe_wrapped(self: AlgorithmT, *args: P.args, **kwargs: P.kwargs) -> Algorit
                 f"` _action_methods = ({action_method.__name__},)`\n\n"
                 "Or append it to the tuple, if it already exists.",
                 PotentialUserErrorWarning,
+                stacklevel=2,
             )
         return _check_safe_run(self, action_method, *args, **kwargs)

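The `stacklevel=2` arguments added throughout this commit address ruff's flake8-bugbear rule on explicit stacklevels (likely B028). With `stacklevel=2`, the warning is attributed to the caller of the warning-emitting function rather than to the `warnings.warn` line itself. A minimal sketch:

    import warnings

    def deprecated_helper():
        # stacklevel=2 makes the report point at deprecated_helper()'s caller.
        warnings.warn("use new_helper() instead", DeprecationWarning, stacklevel=2)

    deprecated_helper()  # the warning is reported at this line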
@@ -232,11 +232,9 @@ def _get_nested_opti_paras(algorithm: Algorithm, opti_para_names: List[str]) ->
     optimizable_paras = {}
     other_paras = {}
     for p, v in paras.items():
-        if p in opti_para_names:
-            optimizable_paras[p] = v
-        # For each optimizable parameter, we also add all children, as they are also allowed to change,
-        # if the parent is allowed to.
-        elif any(p.startswith(o + "__") for o in opti_para_names):
+        if p in opti_para_names or any(p.startswith(o + "__") for o in opti_para_names):
+            # For each optimizable parameter, we also add all children, as they are also allowed to change,
+            # if the parent is allowed to.
             optimizable_paras[p] = v
         else:
             other_paras[p] = v
@@ -249,7 +247,7 @@ def _get_nested_opti_paras(algorithm: Algorithm, opti_para_names: List[str]) ->
     return optimizable_paras, other_paras
 
 
-def _check_safe_optimize(  # noqa: C901
+def _check_safe_optimize(  # noqa: C901, PLR0912
     algorithm: OptimizableT, old_method: Callable, *args: Any, **kwargs: Any
 ) -> OptimizableT:

@@ -349,6 +347,7 @@ def _check_safe_optimize(  # noqa: C901
                 f"({optimizable_paras}). "
                 "This could indicate an implementation error of the `self_optimize` method.",
                 PotentialUserErrorWarning,
+                stacklevel=2,
             )
     if other_returns != (NOTHING, NOTHING):
         return optimized_algorithm, other_returns
@@ -402,6 +401,7 @@ def safe_wrapped(self: OptimizableT, *args: P.args, **kwargs: P.kwargs) -> Optim
             "The `make_optimize_safe` decorator is only meant for the `self_optimize` method, but you applied it "
             f"to the `{self_optimize_method.__name__}` method.",
             PotentialUserErrorWarning,
+            stacklevel=2,
         )
     try:
         return _check_safe_optimize(self, self_optimize_method, *args, **kwargs)
18 changes: 10 additions & 8 deletions tpcp/_base.py
@@ -184,7 +184,7 @@ def _retry_eval_with_missing_locals(
 def _custom_get_type_hints(cls: Type[_BaseTpcpObject]) -> Dict[str, Any]:
     """Extract type hints while avoiding issues with forward references.
 
-    We automatically skip all douple_underscore methods.
+    We automatically skip all douple-underscore methods.
     """
     hints = {}
     for base in reversed(cls.__mro__):
@@ -194,11 +194,11 @@ def _custom_get_type_hints(cls: Type[_BaseTpcpObject]) -> Dict[str, Any]:
             if name.startswith("__"):
                 continue
             if value is None:
-                value = type(None)
+                value = type(None)  # noqa: PLW2901
             elif isinstance(value, str):
                 # NOTE: This does not check if the str is a valid expression.
                 # This might not be an issue, but could lead to obscure error messages.
-                value = _retry_eval_with_missing_locals(value, base_globals)
+                value = _retry_eval_with_missing_locals(value, base_globals)  # noqa: PLW2901
             hints[name] = value
     return hints
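The `# noqa: PLW2901` markers silence pylint's redefined-loop-name rule, which fires whenever a `for` target is reassigned inside the loop body. Here the reassignments are intentional, so they are suppressed rather than rewritten. A generic sketch of what the rule flags:

    for value in ["1", "2", "3"]:
        value = int(value)  # PLW2901: loop variable `value` overwritten in the body
        print(value + 1)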

@@ -212,7 +212,7 @@ def _extract_annotations(
         origin = get_origin(v)
         if origin is ClassVar:
             # If the parameter is a ClassVar, we go one level deeper and check, if its argument was annotated.
-            v = get_args(v)[0]
+            v = get_args(v)[0]  # noqa: PLW2901
             origin = get_origin(v)
         if origin is Annotated:
             for annot in get_args(v)[1:]:
@@ -265,7 +265,8 @@ def _validate_all_parent_parameters_implemented(cls: Type[_BaseTpcpObject]):
             f"Missing parameters: {missing_params}\n"
             "This might not be a problem, but indicates bad design and you might run into actual issues with some "
             "of the validation magic `tpcp` does in the background. "
-            "We would recommend to implement all parameters of your parents in a subclass."
+            "We would recommend to implement all parameters of your parents in a subclass.",
+            stacklevel=2,
         )


@@ -495,7 +496,7 @@ def _set_comp_field(instance, field_name, params):
     # We first partition our field names to know to which index they belong
     comp_params: DefaultDict[str, Any] = defaultdict(dict)
     for key, value in params.items():
-        key, delim, sub_key = key.partition("__")
+        key, delim, sub_key = key.partition("__")  # noqa: PLW2901
         if delim:
             comp_params[key][sub_key] = value
         else:
@@ -537,7 +538,7 @@ def _set_params(instance: BaseTpcpObjectObjT, **params: Any) -> BaseTpcpObjectOb
 
     nested_params: DefaultDict[str, Any] = defaultdict(dict)  # grouped by prefix
     for key, value in params.items():
-        key, delim, sub_key = key.partition("__")
+        key, delim, sub_key = key.partition("__")  # noqa: PLW2901
         if key not in valid_params:
             raise ValueError(f"`{key}` is not a valid parameter name for {type(instance).__name__}.")

@@ -626,6 +627,7 @@ def _annotations_are_valid(
                     "Annotating a nested parameter (parameter like `nested_object__nest_para` as a simple "
                     "Parameter has no effect and the entire line should be removed.",
                     PotentialUserErrorWarning,
+                    stacklevel=2,
                 )
             elif k not in fields:
                 raise ValueError(
@@ -728,7 +730,7 @@ def clone(algorithm: T, *, safe: bool = False) -> T:
         with Path(os.devnull).open("w") as devnull, contextlib.redirect_stdout(devnull):
             return copy.deepcopy(algorithm)
     raise TypeError(
-        f"Cannot clone object '{repr(algorithm)}' (type {type(algorithm)}): "
+        f"Cannot clone object '{algorithm!r}' (type {type(algorithm)}): "
         "it does not seem to be a compatible algorithm/pipline class or general `tpcp` object as it does not "
        "inherit from `BaseTpcpObject` or `Algorithm` or `Pipeline`."
     )
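Swapping `repr(algorithm)` for the `!r` conversion flag is the idiomatic f-string spelling (likely ruff's RUF010). The two forms produce identical strings:

    algorithm = [1, 2]  # hypothetical object
    assert f"Cannot clone object '{repr(algorithm)}'" == f"Cannot clone object '{algorithm!r}'"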
4 changes: 2 additions & 2 deletions tpcp/_utils/_score.py
@@ -124,7 +124,7 @@ def _score(
     return result
 
 
-def _optimize_and_score(  # noqa: C901
+def _optimize_and_score(
     optimizer: BaseOptimize,
     scorer: Scorer,
     train_set: Dataset,
@@ -229,7 +229,7 @@ def _optimize_and_score(  # noqa: C901
     # instance of the trained pipeline.
     result["optimizer"] = optimizer
     if return_parameters:
-        result["parameters"] = {**hyperparameters, **pure_parameters} or None
+        result["parameters"] = {**hyperparameters, **pure_parameters}
     return result


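Note that dropping `or None` is a small behavior change, not just style: a merged dict that ends up empty is falsy, so the old expression returned None in that case, while the new one always returns a dict. A sketch of the difference, assuming both inputs can be empty:

    hyperparameters, pure_parameters = {}, {}  # hypothetical empty inputs
    merged = {**hyperparameters, **pure_parameters}
    assert merged == {}              # new behavior: always a dict
    assert (merged or None) is None  # old behavior: None when the merge is empty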
5 changes: 4 additions & 1 deletion tpcp/optimize/_optimize.py
@@ -97,6 +97,7 @@ def optimize(self, dataset: DatasetT, **optimize_params: Any) -> Self:  # noqa:
             "`DummyOptimize` does never call this method and skips any optimization steps! "
             "Use `Optimize` if you actually want to optimize your pipeline.",
             PotentialUserErrorWarning,
+            stacklevel=2,
         )
         self.optimized_pipeline_ = self.pipeline.clone()
         return self
@@ -830,6 +831,7 @@ def _store(key_name: str, array, weights=None, splits=False, rank=False):
         warnings.warn(
             f"One or more of the {key_name.split('_')[0]} scores are non-finite: {array_means}",
             category=UserWarning,
+            stacklevel=2,
         )
     # Weighted std is not directly available in numpy
     array_stds = np.sqrt(np.average((array - array_means[:, np.newaxis]) ** 2, axis=1, weights=weights))
@@ -914,7 +916,8 @@ def _validate_return_optimized(return_optimized, multi_metric, results) -> Tuple
             "single score."
             "`return_optimized` is set to True. "
             "The only allowed string value for `return_optimized` in a single metric case is `-score`, "
-            "to invert the metric before score selection."
+            "to invert the metric before score selection.",
+            stacklevel=2,
         )
         return reverse, "score"
     raise ValueError("`return_optimized` must be a bool or explicitly `score` or `-score` in a single metric case.")