Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix LeastSquares for functions with more than two arguments #1016

Merged
merged 8 commits into from
Aug 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ jobs:
- uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- run: uv pip install --system -v . pytest ${{ matrix.installs }}
# python -m pip install .[test] is not used here to test minimum (faster),
# the cov workflow runs all tests.
- run: python -m pytest
# pip install .[test] is not used here to test minimum (faster)
# cov workflow runs all tests
- run: uv pip install --system . pytest pytest-xdist ${{ matrix.installs }}
- run: python -m pytest -n 3
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,13 @@ bench/*.svg
.project
.pydevproject
.settings
.coverage
.coverage*
.ipynb_checkpoints
.eggs
.pytest_cache
.mypy_cache
.ruff_cache
.nox

Untitled*.ipynb
Untitled*.py
Expand Down
49 changes: 28 additions & 21 deletions noxfile.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
"""
Noxfile for iMinuit.
Noxfile for iminuit.

Use `-R` to instantly reuse an existing environment and
to avoid rebuilding the binary.
Pass extra arguments to pytest after --
"""

import nox
import sys

sys.path.append(".")
import python_releases

nox.needs_version = ">=2024.3.2"
nox.options.default_venv_backend = "uv|virtualenv"
Expand All @@ -15,46 +18,50 @@
"COVERAGE_CORE": "sysmon", # faster coverage on Python 3.12
}

PYPROJECT = nox.project.load_toml("pyproject.toml")
MINIMUM_PYTHON = PYPROJECT["project"]["requires-python"].strip(">=")
LATEST_PYTHON = str(python_releases.latest())

nox.options.sessions = ["test", "mintest", "maxtest"]


@nox.session(reuse_venv=True)
@nox.session()
def test(session: nox.Session) -> None:
"""Run the unit and regular tests."""
"""Run all tests."""
session.install("-e.[test]")
session.run("pytest", *session.posargs, env=ENV)
session.run("pytest", "-n=auto", *session.posargs, env=ENV)


@nox.session(python="3.12", reuse_venv=True)
def maxtest(session: nox.Session) -> None:
"""Run the unit and regular tests."""
session.install("-e.", "scipy", "matplotlib", "pytest", "--pre")
session.run("pytest", *session.posargs, env=ENV)
@nox.session(python=MINIMUM_PYTHON, venv_backend="uv")
def mintest(session: nox.Session) -> None:
"""Run tests on the minimum python version."""
session.install("-e.", "--resolution=lowest-direct")
session.install("pytest", "pytest-xdist")
session.run("pytest", "-n=auto", *session.posargs)


@nox.session(python="3.9", venv_backend="uv")
def mintest(session: nox.Session) -> None:
@nox.session(python=LATEST_PYTHON)
def maxtest(session: nox.Session) -> None:
"""Run the unit and regular tests."""
session.install("-e.", "--resolution=lowest-direct")
session.install("pytest")
session.run("pytest", *session.posargs)
session.install("-e.", "scipy", "matplotlib", "pytest", "pytest-xdist", "--pre")
session.run("pytest", "-n=auto", *session.posargs, env=ENV)


@nox.session(python="pypy3.9", venv_backend="uv")
@nox.session(python="pypy3.9")
def pypy(session: nox.Session) -> None:
"""Run the unit and regular tests."""
session.install("-e.")
session.install("pytest")
session.run("pytest", *session.posargs)
session.install("pytest", "pytest-xdist")
session.run("pytest", "-n=auto", *session.posargs)


# Python-3.12 provides coverage info faster
@nox.session(python="3.12", reuse_venv=True)
@nox.session(python="3.12", venv_backend="uv")
def cov(session: nox.Session) -> None:
"""Run covage and place in 'htmlcov' directory."""
session.install("-e.[test,doc]")
session.run("coverage", "run", "-m", "pytest", env=ENV)
session.run("coverage", "html", "-d", "htmlcov")
session.run("coverage", "html", "-d", "build/htmlcov")
session.run("coverage", "report", "-m")


Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ test = [
"numba; platform_python_implementation=='CPython'",
"numba-stats; platform_python_implementation=='CPython'",
"pytest",
"pytest-xdist",
"scipy",
"tabulate",
"boost_histogram",
Expand Down
63 changes: 63 additions & 0 deletions python_releases.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
"""Get the latest Python release which is online."""

import urllib.request
import re
from html.parser import HTMLParser
import gzip
from packaging.version import Version


class PythonVersionParser(HTMLParser):
"""Specialized HTMLParser to get Python version number."""

def __init__(self):
super().__init__()
self.versions = set()
self.found_version = False

def handle_starttag(self, tag, attrs):
"""Look for the right tag and store result in an attribute."""
if tag == "a":
for attr in attrs:
if attr[0] == "href" and "/downloads/release/python-" in attr[1]:
self.found_version = True
return

def handle_data(self, data):
"""Extract Python version from entry."""
if self.found_version:
self.found_version = False
match = re.search(r"Python (\d+\.\d+)", data)
if match:
self.versions.add(Version(match.group(1)))


def versions():
"""Get all Python release versions."""
req = urllib.request.Request("https://www.python.org/downloads/")
req.add_header("Accept-Encoding", "gzip")

with urllib.request.urlopen(req) as response:
raw = response.read()
if response.info().get("Content-Encoding") == "gzip":
raw = gzip.decompress(raw)
html = raw.decode("utf-8")

parser = PythonVersionParser()
parser.feed(html)

return parser.versions


def latest():
"""Return version of latest Python release."""
return max(versions())


def main():
"""Print all discovered release versions."""
print(" ".join(str(x) for x in sorted(versions())))


if __name__ == "__main__":
main()
2 changes: 1 addition & 1 deletion src/iminuit/cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -2195,7 +2195,7 @@ def __init__(
y = _norm(y)
assert x.ndim >= 1 # guaranteed by _norm

self._ndim = x.ndim
self._ndim = x.shape[0] if x.ndim > 1 else 1
self._model = model
self._model_grad = grad
self.loss = loss
Expand Down
16 changes: 9 additions & 7 deletions src/iminuit/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,18 @@
These are used by mypy and similar tools.
"""

from typing import Protocol, Optional, List, Union, runtime_checkable, NamedTuple
from typing import (
Protocol,
Optional,
List,
Union,
runtime_checkable,
NamedTuple,
Annotated,
)
from numpy.typing import NDArray
import numpy as np
import dataclasses
import sys

if sys.version_info < (3, 9):
from typing_extensions import Annotated # noqa pragma: no cover
else:
from typing import Annotated # noqa pragma: no cover


# Key for ValueView, ErrorView, etc.
Expand Down
8 changes: 3 additions & 5 deletions src/iminuit/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,15 @@
Collection,
Sequence,
TypeVar,
Annotated,
get_args,
get_origin,
)
import abc
from time import monotonic
import warnings
import sys

if sys.version_info < (3, 9):
from typing_extensions import Annotated, get_args, get_origin # pragma: no cover
else:
from typing import Annotated, get_args, get_origin # pragma: no cover

T = TypeVar("T")

__all__ = (
Expand Down
61 changes: 48 additions & 13 deletions tests/test_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -1066,34 +1066,69 @@ def model(x, a, b):


def test_LeastSquares_2D():
x = np.array([1.0, 2.0, 3.0])
y = np.array([4.0, 5.0, 6.0])
z = 1.5 * x + 0.2 * y
ze = 1.5

def model(xy, a, b):
x, y = xy
return a * x + b * y

c = LeastSquares((x, y), z, ze, model, grad=numerical_model_gradient(model))
x = np.array([1.0, 2.0, 3.0])
y = np.array([4.0, 5.0, 6.0])
f = model((x, y), 1.5, 0.2)
fe = 1.5

c = LeastSquares((x, y), f, fe, model, grad=numerical_model_gradient(model))
assert c.ndata == 3

ref = numerical_cost_gradient(c)
assert_allclose(c.grad(1, 2), ref(1, 2))

assert_equal(c.x, (x, y))
assert_equal(c.y, z)
assert_equal(c.yerror, ze)
assert_equal(c.y, f)
assert_equal(c.yerror, fe)
assert_allclose(c(1.5, 0.2), 0.0)
assert_allclose(c(2.5, 0.2), np.sum(((z - 2.5 * x - 0.2 * y) / ze) ** 2))
assert_allclose(c(1.5, 1.2), np.sum(((z - 1.5 * x - 1.2 * y) / ze) ** 2))
assert_allclose(c(2.5, 0.2), np.sum(((f - 2.5 * x - 0.2 * y) / fe) ** 2))
assert_allclose(c(1.5, 1.2), np.sum(((f - 1.5 * x - 1.2 * y) / fe) ** 2))

c.y = 2 * z
assert_equal(c.y, 2 * z)
c.y = 2 * f
assert_equal(c.y, 2 * f)
c.x = (y, x)
assert_equal(c.x, (y, x))


def test_LeastSquares_3D():
def model(xyz, a, b):
x, y, z = xyz
return a * x + b * y + a * b * z

x = np.array([1.0, 2.0, 3.0, 4.0])
y = np.array([4.0, 5.0, 6.0, 7.0])
z = np.array([7.0, 8.0, 9.0, 10.0])

f = model((x, y, z), 1.5, 0.2)
fe = 1.5

c = LeastSquares((x, y, z), f, fe, model, grad=numerical_model_gradient(model))
assert c.ndata == 4

ref = numerical_cost_gradient(c)
assert_allclose(c.grad(1, 2), ref(1, 2))

assert_equal(c.x, (x, y, z))
assert_equal(c.y, f)
assert_equal(c.yerror, fe)
assert_allclose(c(1.5, 0.2), 0.0)
assert_allclose(
c(2.5, 0.2), np.sum(((f - 2.5 * x - 0.2 * y - 2.5 * 0.2 * z) / fe) ** 2)
)
assert_allclose(
c(1.5, 1.2), np.sum(((f - 1.5 * x - 1.2 * y - 1.5 * 1.2 * z) / fe) ** 2)
)

c.y = 2 * z
assert_equal(c.y, 2 * z)
c.x = (y, x, z)
assert_equal(c.x, (y, x, z))


def test_LeastSquares_bad_input():
with pytest.raises(ValueError, match="shape mismatch"):
LeastSquares([1, 2], [], [1], lambda x, a: 0)
Expand Down Expand Up @@ -1208,7 +1243,7 @@ def line(x, par):
def test_LeastSquares_visualize_2D():
pytest.importorskip("matplotlib")

c = LeastSquares([[1, 2]], [[2, 3]], 0.1, line)
c = LeastSquares([[1, 2], [2, 3]], [1, 2], 0.1, line)

with pytest.raises(ValueError, match="not implemented for multi-dimensional"):
c.visualize((1, 2))
Expand Down
1 change: 1 addition & 0 deletions tests/test_without_ipywidgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@


def test_interactive():
pytest.importorskip("matplotlib")
import iminuit

cost = LeastSquares([1.1, 2.2], [3.3, 4.4], 1, lambda x, a: a * x)
Expand Down
Loading