Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed mypy sp check guidelines #4887

Open
wants to merge 7 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", ".ipynb_checkpoints"]

# Suppress warnings generated by Sphinx and/or by Sphinx extensions
suppress_warnings = []
suppress_warnings = [] # type: list[str]

# -- Options for HTML output -------------------------------------------------

Expand Down Expand Up @@ -174,7 +174,7 @@
html_title = f"{project} v{version} Manual"
html_last_updated_fmt = "%Y-%m-%d"
html_css_files = ["pybamm.css"]
html_context = {"default_mode": "light"}
html_context = {"default_mode": "light"} # type: dict[str, str | bool | None | ParameterSets]
html_use_modindex = True
html_copy_source = False
html_domain_indices = False
Expand All @@ -195,7 +195,7 @@
)

# Set canonical URL from the Read the Docs Domain
html_baseurl = os.getenv("READTHEDOCS_CANONICAL_URL", "")
html_baseurl = os.getenv("READTHEDOCS_CANONICAL_URL", "") # type: str

# Tell Jinja2 templates the build is running on Read the Docs
if os.getenv("READTHEDOCS") == "True":
Expand Down Expand Up @@ -231,7 +231,7 @@
# Latex figure (float) alignment
#
# 'figure_align': 'htbp',
}
} # type: dict[str, str]

# Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title,
Expand Down Expand Up @@ -321,7 +321,7 @@
# made to a notebook, if any.
# On local builds, the version is not set, so we use "latest".

notebooks_version = version
notebooks_version = version # type: str | None
append_to_url = f"blob/v{notebooks_version}"

if (os.environ.get("READTHEDOCS_VERSION") == "latest") or (
Expand Down
2 changes: 1 addition & 1 deletion examples/scripts/SPM_compare_particle_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
disc.process_model(model)

# solve model
solutions = [None] * len(models)
solutions = [None] * len(models) # type: Any
t_eval = np.linspace(0, 3600, 100)
for i, model in enumerate(models):
solutions[i] = model.default_solver.solve(model, t_eval)
Expand Down
21 changes: 11 additions & 10 deletions examples/scripts/SPMe_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,14 @@
time += dt

# plot
time_in_seconds = solution["Time [s]"].entries
step_time_in_seconds = step_solution["Time [s]"].entries
voltage = solution["Voltage [V]"].entries
step_voltage = step_solution["Voltage [V]"].entries
plt.plot(time_in_seconds, voltage, "b-", label="SPMe (continuous solve)")
plt.plot(step_time_in_seconds, step_voltage, "ro", label="SPMe (stepped solve)")
plt.xlabel(r"$t$")
plt.ylabel("Voltage [V]")
plt.legend()
plt.show()
if step_solution is not None:
time_in_seconds = solution["Time [s]"].entries
step_time_in_seconds = step_solution["Time [s]"].entries
voltage = solution["Voltage [V]"].entries
step_voltage = step_solution["Voltage [V]"].entries
plt.plot(time_in_seconds, voltage, "b-", label="SPMe (continuous solve)")
plt.plot(step_time_in_seconds, step_voltage, "ro", label="SPMe (stepped solve)")
plt.xlabel(r"$t$")
plt.ylabel("Voltage [V]")
plt.legend()
plt.show()
2 changes: 1 addition & 1 deletion examples/scripts/heat_equation.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def T_exact(x, t):
# Plot ------------------------------------------------------------------------
x_nodes = mesh["rod"].nodes # numerical gridpoints
xx = np.linspace(0, 2, 101) # fine mesh to plot exact solution
plot_times = np.linspace(0, 1, 5)
plot_times = np.linspace(0, 1, 5) # type: np.ndarray

plt.figure(figsize=(15, 8))
cmap = plt.get_cmap("inferno")
Expand Down
4 changes: 2 additions & 2 deletions examples/scripts/minimal_example_of_lookup_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,12 @@ def process_2D(name, data):
D_s_n_data = process_2D("Negative particle diffusivity [m2.s-1]", df)


def D_s_n(sto, T):
def D_s_n_func(sto, T):
name, (x, y) = D_s_n_data
return pybamm.Interpolant(x, y, [T, sto], name)


parameter_values["Negative particle diffusivity [m2.s-1]"] = D_s_n
parameter_values["Negative particle diffusivity [m2.s-1]"] = D_s_n_func

k_n = parameter_values["Negative electrode exchange-current density [A.m-2]"]

Expand Down
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,10 @@ concurrency = ["multiprocessing"]
ignore_missing_imports = true
allow_redefinition = true
disable_error_code = ["call-overload", "operator"]
strict = false
Copy link
Member

@Saransh-cpp Saransh-cpp Mar 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given how big this PR is, it would actually be better to split it into multiple PRs, each one adding a new config option in pyproject.toml.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok so should I go ahead then and make a PR on one of the config first? or edit this one accordingly?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can keep this PR for one config, and add other configs in subsequent PRs.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok I'll create seperate PRs for different configs and then keep this one for the end, I think that would be faster for me

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I noticed many of the warn_unreachable error depends on errors from enable_error_code config, they're related to each other and their are total of 77 errors out of which 57 are from enable_error_code so I think creating a seperate PR would still be almost as big as this one, so should I still proceed with it?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I think anything that reduces the diff and keeps this PR scoped to a specific change (or a few of them) would be great. Thanks for investigating!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've created the PR with just enable_error_code config: #4891, sorry for the delay lab tests going on

warn_unreachable = true
enable_error_code = ["ignore-without-code", "redundant-expr", "truthy-bool"]
exclude = 'build/'

[[tool.mypy.overrides]]
module = [
Expand Down
4 changes: 3 additions & 1 deletion src/pybamm/experiment/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ class Experiment:

def __init__(
self,
operating_conditions: list[str | tuple[str] | BaseStep],
operating_conditions: list[
str | tuple[str, ...] | tuple[str | BaseStep] | BaseStep
],
period: str | None = None,
temperature: float | None = None,
termination: list[str] | None = None,
Expand Down
2 changes: 1 addition & 1 deletion src/pybamm/experiment/step/base_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def __init__(
self.value = pybamm.Interpolant(
t,
y,
pybamm.t - pybamm.InputParameter("start time"),
pybamm.t - pybamm.InputParameter("start time"), # type: ignore[arg-type]
name="Drive Cycle",
)
self.period = np.diff(t).min()
Expand Down
9 changes: 4 additions & 5 deletions src/pybamm/expression_tree/binary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
# Binary operator classes
#
from __future__ import annotations
import numbers

import numpy as np
import sympy
Expand Down Expand Up @@ -33,8 +32,8 @@ def _preprocess_binary(
raise ValueError("right must be a 1D array")
right = pybamm.Vector(right)

# Check both left and right are pybamm Symbols
if not (isinstance(left, pybamm.Symbol) and isinstance(right, pybamm.Symbol)):
# Check right is pybamm Symbol
if not isinstance(right, pybamm.Symbol):
raise NotImplementedError(
f"BinaryOperator not implemented for symbols of type {type(left)} and {type(right)}"
)
Expand Down Expand Up @@ -127,7 +126,7 @@ def create_copy(
children = self._children_for_copying(new_children)

if not perform_simplifications:
out = self.__class__(children[0], children[1])
out = self.__class__(*children)
else:
# creates a new instance using the overloaded binary operator to perform
# additional simplifications, rather than just calling the constructor
Expand Down Expand Up @@ -1538,7 +1537,7 @@ def source(
corresponding to a source term in the bulk.
"""
# Broadcast if left is number
if isinstance(left, numbers.Number):
if isinstance(left, (int, float)):
left = pybamm.PrimaryBroadcast(left, "current collector")

# force type cast for mypy
Expand Down
2 changes: 1 addition & 1 deletion src/pybamm/expression_tree/broadcasts.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ def __init__(
self,
child_input: Numeric | pybamm.Symbol,
broadcast_domain: DomainType = None,
auxiliary_domains: AuxiliaryDomainType = None,
auxiliary_domains: AuxiliaryDomainType | str = None,
broadcast_domains: DomainsType = None,
name: str | None = None,
):
Expand Down
2 changes: 1 addition & 1 deletion src/pybamm/expression_tree/concatenations.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,7 @@ def __init__(self, *children, name: Optional[str] = None):
if name is None:
# Name is the intersection of the children names (should usually make sense
# if the children have been named consistently)
name = intersect(children[0].name, children[1].name)
name = intersect(children[0].name, children[1].name) or ""
for child in children[2:]:
name = intersect(name, child.name)
if len(name) == 0:
Expand Down
7 changes: 2 additions & 5 deletions src/pybamm/expression_tree/operations/serialise.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def __init__(self):
class _SymbolEncoder(json.JSONEncoder):
"""Converts PyBaMM symbols into a JSON-serialisable format"""

def default(self, node: dict):
def default(self, node: dict | pybamm.Symbol):
node_dict = {"py/object": str(type(node))[8:-2], "py/id": id(node)}
if isinstance(node, pybamm.Symbol):
node_dict.update(node.to_json()) # this doesn't include children
Expand All @@ -46,7 +46,7 @@ def default(self, node: dict):
class _MeshEncoder(json.JSONEncoder):
"""Converts PyBaMM meshes into a JSON-serialisable format"""

def default(self, node: pybamm.Mesh):
def default(self, node: pybamm.Mesh | pybamm.SubMesh):
node_dict = {"py/object": str(type(node))[8:-2], "py/id": id(node)}
if isinstance(node, pybamm.Mesh):
node_dict.update(node.to_json())
Expand All @@ -64,9 +64,6 @@ def default(self, node: pybamm.Mesh):
node_dict.update(node.to_json())
return node_dict

node_dict["json"] = json.JSONEncoder.default(self, node) # pragma: no cover
return node_dict # pragma: no cover

class _Empty:
"""A dummy class to aid deserialisation"""

Expand Down
8 changes: 6 additions & 2 deletions src/pybamm/expression_tree/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,9 @@ def create_object_of_size(size: int, typ="vector"):
return np.nan * np.ones((size, size))


def evaluate_for_shape_using_domain(domains: dict[str, list[str] | str], typ="vector"):
def evaluate_for_shape_using_domain(
domains: dict[str, list[str] | str] | list[str], typ="vector"
):
"""
Return a vector of the appropriate shape, based on the domains.
Domain 'sizes' can clash, but are unlikely to, and won't cause failures if they do.
Expand Down Expand Up @@ -964,7 +966,9 @@ def to_casadi(
"""
return pybamm.CasadiConverter(casadi_symbols).convert(self, t, y, y_dot, inputs)

def _children_for_copying(self, children: list[Symbol] | None = None) -> Symbol:
def _children_for_copying(
self, children: list[Symbol] | None = None
) -> list[Symbol]:
"""
Gets existing children for a symbol being copied if they aren't provided.
"""
Expand Down
4 changes: 2 additions & 2 deletions src/pybamm/expression_tree/unary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import sympy
import pybamm
from pybamm.util import import_optional_dependency
from pybamm.type_definitions import DomainsType
from pybamm.type_definitions import DomainsType, Numeric


class UnaryOperator(pybamm.Symbol):
Expand All @@ -31,7 +31,7 @@ class UnaryOperator(pybamm.Symbol):
def __init__(
self,
name: str,
child: pybamm.Symbol,
child: pybamm.Symbol | Numeric,
domains: DomainsType = None,
):
if isinstance(child, (float, int, np.number)):
Expand Down
10 changes: 5 additions & 5 deletions src/pybamm/expression_tree/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,12 @@ def __init__(
domains: DomainsType = None,
bounds: tuple[pybamm.Symbol] | None = None,
print_name: str | None = None,
scale: float | pybamm.Symbol | None = 1,
reference: float | pybamm.Symbol | None = 0,
scale: float | int | pybamm.Symbol | None = 1,
reference: float | int | pybamm.Symbol | None = 0,
):
if isinstance(scale, numbers.Number):
if isinstance(scale, (float, int)):
scale = pybamm.Scalar(scale)
if isinstance(reference, numbers.Number):
if isinstance(reference, (float, int)):
reference = pybamm.Scalar(reference)
self._scale = scale
self._reference = reference
Expand All @@ -88,7 +88,7 @@ def bounds(self):
return self._bounds

@bounds.setter
def bounds(self, values: tuple[Numeric, Numeric]):
def bounds(self, values: tuple[Numeric, Numeric] | None):
if values is None:
values = (-np.inf, np.inf)
else:
Expand Down
2 changes: 2 additions & 0 deletions src/pybamm/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ def __init__(self, name="Unnamed model"):
self.use_jacobian = True
self.convert_to_format = "casadi"

self.calculate_sensitivities = []

# Model is not initially discretised
self.is_discretised = False
self.y_slices = None
Expand Down
2 changes: 1 addition & 1 deletion src/pybamm/plotting/quick_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def __init__(
# Set colors, linestyles, figsize, axis limits
# call LoopList to make sure list index never runs out
if colors is None:
self.colors = LoopList(colors or ["r", "b", "k", "g", "m", "c"])
self.colors = LoopList(["r", "b", "k", "g", "m", "c"])
else:
self.colors = LoopList(colors)
self.linestyles = LoopList(linestyles or ["-", ":", "--", "-."])
Expand Down
10 changes: 5 additions & 5 deletions src/pybamm/solvers/base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ def supports_parallel_solve(self):
def requires_explicit_sensitivities(self):
return True

@root_method.setter
def root_method(self, method):
@root_method.setter # type: ignore[attr-defined, no-redef]
def root_method(self, method) -> None:
if method == "casadi":
method = pybamm.CasadiAlgebraicSolver(self.root_tol)
elif isinstance(method, str):
Expand Down Expand Up @@ -1122,7 +1122,7 @@ def _set_sens_initial_conditions_from(
"""

ninputs = len(model.calculate_sensitivities)
initial_conditions = tuple([] for _ in range(ninputs))
initial_conditions = tuple([] for _ in range(ninputs)) # type: tuple
solution = solution.last_state
for var in model.initial_conditions:
final_state = solution[var.name]
Expand All @@ -1143,10 +1143,10 @@ def _set_sens_initial_conditions_from(
slices = [y_slices[symbol][0] for symbol in model.initial_conditions.keys()]

# sort equations according to slices
concatenated_initial_conditions = [
concatenated_initial_conditions = tuple(
casadi.vertcat(*[eq for _, eq in sorted(zip(slices, init))])
for init in initial_conditions
]
)
return concatenated_initial_conditions

def process_t_interp(self, t_interp):
Expand Down
2 changes: 1 addition & 1 deletion src/pybamm/solvers/solution.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def __init__(
def has_sensitivities(self) -> bool:
if isinstance(self._all_sensitivities, bool):
return self._all_sensitivities
elif isinstance(self._all_sensitivities, dict):
else:
return len(self._all_sensitivities) > 0

def extract_explicit_sensitivities(self):
Expand Down
Loading
Loading