Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed mypy sp check guidelines #4887

Open
wants to merge 7 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
- Moved concentration inside x-averaged when calculating LLI due to LAM variables ([#4858](https://github.com/pybamm-team/PyBaMM/pull/4858))
- Fixed a bug that caused the variable `"Loss of lithium due to {domain} lithium plating"`to have the domain `"current collector"` (should not have any domain at all) if the `"x-average side reactions"` option was set to `"true"`. ([#4844](https://github.com/pybamm-team/PyBaMM/pull/4844))
- Fixed interpolation bug in `pybamm.QuickPlot` with spatial variables. ([#4841](https://github.com/pybamm-team/PyBaMM/pull/4841))
- Fixed mypy sp check guidelines ([#4887](https://github.com/pybamm-team/PyBaMM/pull/4887))

## Optimizations

Expand Down
3 changes: 2 additions & 1 deletion examples/scripts/SPM_compare_particle_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Compare different discretisations in the particle
#
import argparse
from typing import Any
import numpy as np
import pybamm
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -48,7 +49,7 @@
disc.process_model(model)

# solve model
solutions = [None] * len(models)
solutions: Any = [None] * len(models)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this be something like list[Any]?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yess, I'll change that

t_eval = np.linspace(0, 3600, 100)
for i, model in enumerate(models):
solutions[i] = model.default_solver.solve(model, t_eval)
Expand Down
21 changes: 11 additions & 10 deletions examples/scripts/SPMe_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,14 @@
time += dt

# plot
time_in_seconds = solution["Time [s]"].entries
step_time_in_seconds = step_solution["Time [s]"].entries
voltage = solution["Voltage [V]"].entries
step_voltage = step_solution["Voltage [V]"].entries
plt.plot(time_in_seconds, voltage, "b-", label="SPMe (continuous solve)")
plt.plot(step_time_in_seconds, step_voltage, "ro", label="SPMe (stepped solve)")
plt.xlabel(r"$t$")
plt.ylabel("Voltage [V]")
plt.legend()
plt.show()
if step_solution is not None:
time_in_seconds = solution["Time [s]"].entries
step_time_in_seconds = step_solution["Time [s]"].entries
voltage = solution["Voltage [V]"].entries
step_voltage = step_solution["Voltage [V]"].entries
plt.plot(time_in_seconds, voltage, "b-", label="SPMe (continuous solve)")
plt.plot(step_time_in_seconds, step_voltage, "ro", label="SPMe (stepped solve)")
plt.xlabel(r"$t$")
plt.ylabel("Voltage [V]")
plt.legend()
plt.show()
2 changes: 1 addition & 1 deletion examples/scripts/heat_equation.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def T_exact(x, t):
# Plot ------------------------------------------------------------------------
x_nodes = mesh["rod"].nodes # numerical gridpoints
xx = np.linspace(0, 2, 101) # fine mesh to plot exact solution
plot_times = np.linspace(0, 1, 5)
plot_times: np.ndarray = np.linspace(0, 1, 5)

plt.figure(figsize=(15, 8))
cmap = plt.get_cmap("inferno")
Expand Down
4 changes: 2 additions & 2 deletions examples/scripts/minimal_example_of_lookup_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,12 @@ def process_2D(name, data):
D_s_n_data = process_2D("Negative particle diffusivity [m2.s-1]", df)


def D_s_n(sto, T):
def D_s_n_func(sto, T):
name, (x, y) = D_s_n_data
return pybamm.Interpolant(x, y, [T, sto], name)


parameter_values["Negative particle diffusivity [m2.s-1]"] = D_s_n
parameter_values["Negative particle diffusivity [m2.s-1]"] = D_s_n_func

k_n = parameter_values["Negative electrode exchange-current density [A.m-2]"]

Expand Down
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,10 @@ concurrency = ["multiprocessing"]
ignore_missing_imports = true
allow_redefinition = true
disable_error_code = ["call-overload", "operator"]
strict = false
Copy link
Member

@Saransh-cpp Saransh-cpp Mar 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given how big this PR is, it would actually be better to split it into multiple PRs, each one adding a new config option in pyproject.toml.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok so should I go ahead then and make a PR on one of the config first? or edit this one accordingly?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can keep this PR for one config, and add other configs in subsequent PRs.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok I'll create seperate PRs for different configs and then keep this one for the end, I think that would be faster for me

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I noticed many of the warn_unreachable error depends on errors from enable_error_code config, they're related to each other and their are total of 77 errors out of which 57 are from enable_error_code so I think creating a seperate PR would still be almost as big as this one, so should I still proceed with it?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I think anything that reduces the diff and keeps this PR scoped to a specific change (or a few of them) would be great. Thanks for investigating!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've created the PR with just enable_error_code config: #4891, sorry for the delay lab tests going on

warn_unreachable = true
enable_error_code = ["ignore-without-code", "redundant-expr", "truthy-bool"]
exclude = "^(build/|docs/conf\\.py)$"

[[tool.mypy.overrides]]
module = [
Expand Down
4 changes: 3 additions & 1 deletion src/pybamm/experiment/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ class Experiment:

def __init__(
self,
operating_conditions: list[str | tuple[str] | BaseStep],
operating_conditions: list[
str | tuple[str, ...] | tuple[str | BaseStep] | BaseStep
],
period: str | None = None,
temperature: float | None = None,
termination: list[str] | None = None,
Expand Down
2 changes: 1 addition & 1 deletion src/pybamm/experiment/step/base_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def __init__(
self.value = pybamm.Interpolant(
t,
y,
pybamm.t - pybamm.InputParameter("start time"),
pybamm.t - pybamm.InputParameter("start time"), # type: ignore[arg-type]
name="Drive Cycle",
)
self.period = np.diff(t).min()
Expand Down
48 changes: 43 additions & 5 deletions src/pybamm/expression_tree/binary_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
# Binary operator classes
#
from __future__ import annotations
import numbers

import numpy as np
import sympy
Expand Down Expand Up @@ -33,8 +32,8 @@ def _preprocess_binary(
raise ValueError("right must be a 1D array")
right = pybamm.Vector(right)

# Check both left and right are pybamm Symbols
if not (isinstance(left, pybamm.Symbol) and isinstance(right, pybamm.Symbol)):
# Check right is pybamm Symbol
if not isinstance(right, pybamm.Symbol):
raise NotImplementedError(
f"BinaryOperator not implemented for symbols of type {type(left)} and {type(right)}"
)
Expand Down Expand Up @@ -113,6 +112,9 @@ def __str__(self):
right_str = f"{self.right!s}"
return f"{left_str} {self.name} {right_str}"

def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol:
return self.__class__(self.name, left, right) # pragma: no cover

def create_copy(
self,
new_children: list[pybamm.Symbol] | None = None,
Expand All @@ -127,7 +129,7 @@ def create_copy(
children = self._children_for_copying(new_children)

if not perform_simplifications:
out = self.__class__(children[0], children[1])
out = self._new_instance(children[0], children[1])
else:
# creates a new instance using the overloaded binary operator to perform
# additional simplifications, rather than just calling the constructor
Expand Down Expand Up @@ -224,6 +226,9 @@ def __init__(
"""See :meth:`pybamm.BinaryOperator.__init__()`."""
super().__init__("**", left, right)

def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol:
return Power(left, right)

def _diff(self, variable: pybamm.Symbol):
"""See :meth:`pybamm.Symbol._diff()`."""
# apply chain rule and power rule
Expand Down Expand Up @@ -273,6 +278,9 @@ def __init__(
"""See :meth:`pybamm.BinaryOperator.__init__()`."""
super().__init__("+", left, right)

def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol:
return Addition(left, right)

def _diff(self, variable: pybamm.Symbol):
"""See :meth:`pybamm.Symbol._diff()`."""
return self.left.diff(variable) + self.right.diff(variable)
Expand Down Expand Up @@ -300,6 +308,9 @@ def __init__(

super().__init__("-", left, right)

def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol:
return Subtraction(left, right)

def _diff(self, variable: pybamm.Symbol):
"""See :meth:`pybamm.Symbol._diff()`."""
return self.left.diff(variable) - self.right.diff(variable)
Expand Down Expand Up @@ -329,6 +340,9 @@ def __init__(

super().__init__("*", left, right)

def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol:
return Multiplication(left, right)

def _diff(self, variable: pybamm.Symbol):
"""See :meth:`pybamm.Symbol._diff()`."""
# apply product rule
Expand Down Expand Up @@ -369,6 +383,9 @@ def __init__(
"""See :meth:`pybamm.BinaryOperator.__init__()`."""
super().__init__("@", left, right)

def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol:
return MatrixMultiplication(left, right) # pragma: no cover

def diff(self, variable):
"""See :meth:`pybamm.Symbol.diff()`."""
# We shouldn't need this
Expand Down Expand Up @@ -418,6 +435,9 @@ def __init__(
"""See :meth:`pybamm.BinaryOperator.__init__()`."""
super().__init__("/", left, right)

def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol:
return Division(left, right)

def _diff(self, variable: pybamm.Symbol):
"""See :meth:`pybamm.Symbol._diff()`."""
# apply quotient rule
Expand Down Expand Up @@ -466,6 +486,9 @@ def __init__(
"""See :meth:`pybamm.BinaryOperator.__init__()`."""
super().__init__("inner product", left, right)

def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol:
return Inner(left, right) # pragma: no cover

def _diff(self, variable: pybamm.Symbol):
"""See :meth:`pybamm.Symbol._diff()`."""
# apply product rule
Expand Down Expand Up @@ -543,6 +566,9 @@ def __init__(
"""See :meth:`pybamm.BinaryOperator.__init__()`."""
super().__init__("==", left, right)

def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol:
return Equality(left, right)

def diff(self, variable):
"""See :meth:`pybamm.Symbol.diff()`."""
# Equality should always be multiplied by something else so hopefully don't
Expand Down Expand Up @@ -601,6 +627,9 @@ def __init__(
"""See :meth:`pybamm.BinaryOperator.__init__()`."""
super().__init__(name, left, right)

def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol:
return _Heaviside(left, right) # pragma: no cover

def diff(self, variable):
"""See :meth:`pybamm.Symbol.diff()`."""
# Heaviside should always be multiplied by something else so hopefully don't
Expand Down Expand Up @@ -678,6 +707,9 @@ def __init__(
):
super().__init__("%", left, right)

def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol:
return Modulo(left, right)

def _diff(self, variable: pybamm.Symbol):
"""See :meth:`pybamm.Symbol._diff()`."""
# apply chain rule and power rule
Expand Down Expand Up @@ -720,6 +752,9 @@ def __init__(
):
super().__init__("minimum", left, right)

def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol:
return Minimum(left, right)

def __str__(self):
"""See :meth:`pybamm.Symbol.__str__()`."""
return f"minimum({self.left!s}, {self.right!s})"
Expand Down Expand Up @@ -764,6 +799,9 @@ def __init__(
):
super().__init__("maximum", left, right)

def _new_instance(self, left: pybamm.Symbol, right: pybamm.Symbol) -> pybamm.Symbol:
return Maximum(left, right)

def __str__(self):
"""See :meth:`pybamm.Symbol.__str__()`."""
return f"maximum({self.left!s}, {self.right!s})"
Expand Down Expand Up @@ -1538,7 +1576,7 @@ def source(
corresponding to a source term in the bulk.
"""
# Broadcast if left is number
if isinstance(left, numbers.Number):
if isinstance(left, (int, float)):
left = pybamm.PrimaryBroadcast(left, "current collector")

# force type cast for mypy
Expand Down
17 changes: 14 additions & 3 deletions src/pybamm/expression_tree/broadcasts.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,7 @@ def _from_json(cls, snippet):
)

def _unary_new_copy(self, child: pybamm.Symbol, perform_simplifications=True):
"""See :meth:`pybamm.UnaryOperator._unary_new_copy()`."""
return self.__class__(child, self.broadcast_domain)
pass # pragma: no cover


class PrimaryBroadcast(Broadcast):
Expand Down Expand Up @@ -191,6 +190,10 @@ def reduce_one_dimension(self):
"""Reduce the broadcast by one dimension."""
return self.orphans[0]

def _unary_new_copy(self, child: pybamm.Symbol, perform_simplifications=True):
"""See :meth:`pybamm.UnaryOperator._unary_new_copy()`."""
return self.__class__(child, self.broadcast_domain)


class PrimaryBroadcastToEdges(PrimaryBroadcast):
"""A primary broadcast onto the edges of the domain."""
Expand Down Expand Up @@ -321,6 +324,10 @@ def reduce_one_dimension(self):
"""Reduce the broadcast by one dimension."""
return self.orphans[0]

def _unary_new_copy(self, child: pybamm.Symbol, perform_simplifications=True):
"""See :meth:`pybamm.UnaryOperator._unary_new_copy()`."""
return self.__class__(child, self.broadcast_domain)


class SecondaryBroadcastToEdges(SecondaryBroadcast):
"""A secondary broadcast onto the edges of a domain."""
Expand Down Expand Up @@ -438,6 +445,10 @@ def reduce_one_dimension(self):
"""Reduce the broadcast by one dimension."""
raise NotImplementedError

def _unary_new_copy(self, child: pybamm.Symbol, perform_simplifications=True):
"""See :meth:`pybamm.UnaryOperator._unary_new_copy()`."""
return self.__class__(child, self.broadcast_domain)


class TertiaryBroadcastToEdges(TertiaryBroadcast):
"""A tertiary broadcast onto the edges of a domain."""
Expand All @@ -463,7 +474,7 @@ def __init__(
self,
child_input: Numeric | pybamm.Symbol,
broadcast_domain: DomainType = None,
auxiliary_domains: AuxiliaryDomainType = None,
auxiliary_domains: AuxiliaryDomainType | str = None,
broadcast_domains: DomainsType = None,
name: str | None = None,
):
Expand Down
6 changes: 3 additions & 3 deletions src/pybamm/expression_tree/concatenations.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,7 @@ def __init__(self, *children, name: Optional[str] = None):
if name is None:
# Name is the intersection of the children names (should usually make sense
# if the children have been named consistently)
name = intersect(children[0].name, children[1].name)
name = intersect(children[0].name, children[1].name) or ""
for child in children[2:]:
name = intersect(name, child.name)
if len(name) == 0:
Expand Down Expand Up @@ -514,7 +514,7 @@ def substrings(s: str):
yield s[i : j + 1]


def intersect(s1: str, s2: str):
def intersect(s1: str, s2: str) -> str:
# find all the common strings between two strings
all_intersects = set(substrings(s1)) & set(substrings(s2))
# intersect is the longest such intercept
Expand All @@ -525,7 +525,7 @@ def intersect(s1: str, s2: str):
return intersect.lstrip().rstrip()


def simplified_concatenation(*children, name: Optional[str] = None):
def simplified_concatenation(*children, name=None):
"""Perform simplifications on a concatenation."""
# remove children that are None
children = list(filter(lambda x: x is not None, children))
Expand Down
5 changes: 3 additions & 2 deletions src/pybamm/expression_tree/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np
from scipy import special
import sympy
from typing import Callable
from typing import Callable, cast
from collections.abc import Sequence
from typing_extensions import TypeVar

Expand All @@ -32,7 +32,7 @@ class Function(pybamm.Symbol):
def __init__(
self,
function: Callable,
*children: pybamm.Symbol,
*children: pybamm.Symbol | float | int,
name: str | None = None,
differentiated_function: Callable | None = None,
):
Expand All @@ -42,6 +42,7 @@ def __init__(
if isinstance(child, (float, int, np.number)):
children[idx] = pybamm.Scalar(child)

children = cast(Sequence[pybamm.Symbol], children)
if name is not None:
self.name = name
else:
Expand Down
Loading
Loading