Skip to content

Commit

Permalink
factor out validate_tup_type
Browse files Browse the repository at this point in the history
  • Loading branch information
blueyed committed Oct 2, 2020
1 parent 624028e commit 8815a61
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 38 deletions.
11 changes: 2 additions & 9 deletions src/_pytest/python_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from typing import Union

import _pytest._code
from .types import validate_tup_type
from _pytest.compat import overload
from _pytest.compat import STRING_TYPES
from _pytest.compat import TYPE_CHECKING
Expand Down Expand Up @@ -671,15 +672,7 @@ def raises( # noqa: F811
"""
__tracebackhide__ = True

if isinstance(expected_exception, type):
excepted_exceptions = (expected_exception,) # type: Tuple[Type[_E], ...]
else:
excepted_exceptions = expected_exception
for exc in excepted_exceptions:
if not isinstance(exc, type) or not issubclass(exc, BaseException):
msg = "expected exception must be a BaseException type, not {}"
not_a = exc.__name__ if isinstance(exc, type) else type(exc).__name__
raise TypeError(msg.format(not_a))
validate_tup_type(expected_exception, BaseException)

message = "DID NOT RAISE {}".format(expected_exception)

Expand Down
16 changes: 2 additions & 14 deletions src/_pytest/recwarn.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from _pytest.compat import overload
from _pytest.compat import TYPE_CHECKING
from .types import validate_tup_type
from _pytest.fixtures import yield_fixture
from _pytest.outcomes import fail

Expand Down Expand Up @@ -214,20 +215,7 @@ def __init__(
) -> None:
super().__init__()

msg = "exceptions must be derived from Warning, not %s"
if expected_warning is None:
expected_warning_tup = None
elif isinstance(expected_warning, tuple):
for exc in expected_warning:
if not issubclass(exc, Warning):
raise TypeError(msg % type(exc))
expected_warning_tup = expected_warning
elif issubclass(expected_warning, Warning):
expected_warning_tup = (expected_warning,)
else:
raise TypeError(msg % type(expected_warning))

self.expected_warning = expected_warning_tup
self.expected_warning = validate_tup_type(expected_warning, Warning)
self.match_expr = match_expr

def __exit__(
Expand Down
39 changes: 39 additions & 0 deletions src/_pytest/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from .compat import TYPE_CHECKING

# from more_itertools import collapse

if TYPE_CHECKING:
from typing import Optional
from typing import Tuple
from typing import TypeVar
from typing import Union

_T = TypeVar("_T", bound="type")


def collapse_tuples(obj) -> "Tuple":
def walk(node):
if isinstance(node, tuple):
for child in node:
yield from walk(child)
else:
yield node

return tuple(x for x in walk(obj))


def validate_tup_type(
type_or_types: "Union[_T, Tuple[_T, ...]]", base_type: "_T"
) -> "Optional[Tuple[_T, ...]]":
if type_or_types is None:
return None
types = collapse_tuples(type_or_types)
for exc in types:
if not isinstance(exc, type) or not issubclass(exc, base_type):
raise TypeError(
"exceptions must be derived from {}, not {}".format(
base_type.__name__,
exc.__name__ if isinstance(exc, type) else type(exc).__name__,
)
)
return types
23 changes: 9 additions & 14 deletions testing/python/raises.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,14 +129,14 @@ def test_division(example_input, expectation):
def test_noclass_iterable(self) -> None:
with pytest.raises(
TypeError,
match="^exceptions must be derived from BaseException, not <class 'str'>$",
match="^exceptions must be derived from BaseException, not str$",
):
pytest.raises("wrong", lambda: None) # type: ignore[call-overload]

def test_noclass_noniterable(self) -> None:
with pytest.raises(
TypeError,
match="^exceptions must be derived from BaseException, not <class 'int'>$",
match="^exceptions must be derived from BaseException, not int$",
):
pytest.raises(41, lambda: None) # type: ignore[call-overload]

Expand Down Expand Up @@ -295,20 +295,15 @@ def test_raises_context_manager_with_kwargs(self):
assert "Unexpected keyword arguments" in str(excinfo.value)

def test_expected_exception_is_not_a_baseexception(self) -> None:
with pytest.raises(TypeError) as excinfo:
with pytest.raises("hello"): # type: ignore[call-overload]
pass # pragma: no cover
assert "must be a BaseException type, not str" in str(excinfo.value)
msg = "^exceptions must be derived from BaseException, not {}$"
with pytest.raises(TypeError, match=msg.format("str")):
pytest.raises("hello") # type: ignore[call-overload]

class NotAnException:
pass

with pytest.raises(TypeError) as excinfo:
with pytest.raises(NotAnException): # type: ignore[type-var]
pass # pragma: no cover
assert "must be a BaseException type, not NotAnException" in str(excinfo.value)
with pytest.raises(TypeError, match=msg.format("NotAnException")):
pytest.raises(NotAnException) # type: ignore[type-var]

with pytest.raises(TypeError) as excinfo:
with pytest.raises(("hello", NotAnException)): # type: ignore[arg-type]
pass # pragma: no cover
assert "must be a BaseException type, not str" in str(excinfo.value)
with pytest.raises(TypeError, match=msg.format("str")):
pytest.raises(("hello", NotAnException)) # type: ignore[arg-type]
8 changes: 7 additions & 1 deletion testing/test_recwarn.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,20 @@ def test_warn_stacklevel(self) -> None:
warnings.warn("test", DeprecationWarning, 2)

def test_typechecking(self) -> None:
msg = "exceptions must be derived from Warning, not <class '{}'>"
msg = "exceptions must be derived from Warning, not {}"
with pytest.raises(TypeError, match=msg.format("int")):
WarningsChecker(5)
with pytest.raises(TypeError, match=msg.format("str")):
WarningsChecker(("hi", RuntimeWarning))
with pytest.raises(TypeError, match=msg.format("list")):
WarningsChecker([DeprecationWarning, RuntimeWarning])

def test_nested_tuples(self) -> None:
wc1 = WarningsChecker((DeprecationWarning, RuntimeWarning))
wc2 = WarningsChecker(((DeprecationWarning,), (RuntimeWarning,)))
assert wc1.expected_warning == (DeprecationWarning, RuntimeWarning)
assert wc1.expected_warning == wc2.expected_warning

def test_invalid_enter_exit(self) -> None:
# wrap this test in WarningsRecorder to ensure warning state gets reset
with WarningsRecorder():
Expand Down

0 comments on commit 8815a61

Please sign in to comment.