Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Accept generic ExceptionGroups for raises #13134

Merged
merged 4 commits into from
Jan 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,7 @@ Tim Hoffmann
Tim Strazny
TJ Bruno
Tobias Diez
Tobias Petersen
Tom Dalton
Tom Viner
Tomáš Gavenčiak
Expand Down
8 changes: 8 additions & 0 deletions changelog/13115.improvement.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
Allows supplying ``ExceptionGroup[Exception]`` and ``BaseExceptionGroup[BaseException]`` to ``pytest.raises`` to keep full typing on :class:`ExceptionInfo <pytest.ExceptionInfo>`:

.. code-block:: python

with pytest.raises(ExceptionGroup[Exception]) as exc_info:
some_function()

Parametrizing with other exception types remains an error - we do not check the types of child exceptions and thus do not permit code that might look like we do.
47 changes: 42 additions & 5 deletions src/_pytest/python_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,13 @@
from numbers import Complex
import pprint
import re
import sys
from types import TracebackType
from typing import Any
from typing import cast
from typing import final
from typing import get_args
from typing import get_origin
from typing import overload
from typing import TYPE_CHECKING
from typing import TypeVar
Expand All @@ -24,6 +27,10 @@
from _pytest.outcomes import fail


if sys.version_info < (3, 11):
from exceptiongroup import BaseExceptionGroup
from exceptiongroup import ExceptionGroup

if TYPE_CHECKING:
from numpy import ndarray

Expand Down Expand Up @@ -954,15 +961,45 @@ def raises(
f"Raising exceptions is already understood as failing the test, so you don't need "
f"any special code to say 'this should never raise an exception'."
)

expected_exceptions: tuple[type[E], ...]
origin_exc: type[E] | None = get_origin(expected_exception)
if isinstance(expected_exception, type):
expected_exceptions: tuple[type[E], ...] = (expected_exception,)
expected_exceptions = (expected_exception,)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why comma is here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Without a comma it's not a tuple but just a singular value.

Python 3.10.12 (main, Jan 17 2025, 14:35:34) [GCC 11.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> (1)
1
>>> (1,)
(1,)
>>> 

The original code had the same pattern.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In fact it's the parenthesis that are superfluous for everything except when creating the empty tuple

>>> 1,
(1,)

elif origin_exc and issubclass(origin_exc, BaseExceptionGroup):
expected_exceptions = (cast(type[E], expected_exception),)
else:
expected_exceptions = expected_exception
for exc in expected_exceptions:
if not isinstance(exc, type) or not issubclass(exc, BaseException):

def validate_exc(exc: type[E]) -> type[E]:
__tracebackhide__ = True
origin_exc: type[E] | None = get_origin(exc)
nicoddemus marked this conversation as resolved.
Show resolved Hide resolved
if origin_exc and issubclass(origin_exc, BaseExceptionGroup):
exc_type = get_args(exc)[0]
if (
issubclass(origin_exc, ExceptionGroup) and exc_type in (Exception, Any)
) or (
issubclass(origin_exc, BaseExceptionGroup)
and exc_type in (BaseException, Any)
):
return cast(type[E], origin_exc)
else:
raise ValueError(
f"Only `ExceptionGroup[Exception]` or `BaseExceptionGroup[BaseExeption]` "
f"are accepted as generic types but got `{exc}`. "
f"As `raises` will catch all instances of the specified group regardless of the "
f"generic argument specific nested exceptions has to be checked "
f"with `ExceptionInfo.group_contains()`"
)

elif not isinstance(exc, type) or not issubclass(exc, BaseException):
msg = "expected exception must be a BaseException type, not {}" # type: ignore[unreachable]
not_a = exc.__name__ if isinstance(exc, type) else type(exc).__name__
raise TypeError(msg.format(not_a))
else:
return exc

expected_exceptions = tuple(validate_exc(exc) for exc in expected_exceptions)

message = f"DID NOT RAISE {expected_exception}"

Expand All @@ -973,14 +1010,14 @@ def raises(
msg += ", ".join(sorted(kwargs))
msg += "\nUse context-manager form instead?"
raise TypeError(msg)
return RaisesContext(expected_exception, message, match)
return RaisesContext(expected_exceptions, message, match)
else:
func = args[0]
if not callable(func):
raise TypeError(f"{func!r} object (type: {type(func)}) must be callable")
try:
func(*args[1:], **kwargs)
except expected_exception as e:
except expected_exceptions as e:
return _pytest._code.ExceptionInfo.from_exception(e)
fail(message)

Expand Down
34 changes: 34 additions & 0 deletions testing/code/test_excinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from _pytest._code.code import TracebackStyle

if sys.version_info < (3, 11):
from exceptiongroup import BaseExceptionGroup
from exceptiongroup import ExceptionGroup


Expand Down Expand Up @@ -453,6 +454,39 @@ def test_division_zero():
result.stdout.re_match_lines([r".*__tracebackhide__ = True.*", *match])


def test_raises_accepts_generic_group() -> None:
with pytest.raises(ExceptionGroup[Exception]) as exc_info:
raise ExceptionGroup("", [RuntimeError()])
assert exc_info.group_contains(RuntimeError)


def test_raises_accepts_generic_base_group() -> None:
with pytest.raises(BaseExceptionGroup[BaseException]) as exc_info:
raise ExceptionGroup("", [RuntimeError()])
assert exc_info.group_contains(RuntimeError)


def test_raises_rejects_specific_generic_group() -> None:
with pytest.raises(ValueError):
pytest.raises(ExceptionGroup[RuntimeError])


def test_raises_accepts_generic_group_in_tuple() -> None:
with pytest.raises((ValueError, ExceptionGroup[Exception])) as exc_info:
raise ExceptionGroup("", [RuntimeError()])
assert exc_info.group_contains(RuntimeError)


def test_raises_exception_escapes_generic_group() -> None:
try:
with pytest.raises(ExceptionGroup[Exception]):
raise ValueError("my value error")
except ValueError as e:
assert str(e) == "my value error"
else:
pytest.fail("Expected ValueError to be raised")


class TestGroupContains:
def test_contains_exception_type(self) -> None:
exc_group = ExceptionGroup("", [RuntimeError()])
Expand Down
Loading