Skip to content

Commit a7b2c4e

Browse files
committed
Improve more typing
1 parent 9f8872e commit a7b2c4e

16 files changed

Lines changed: 1778 additions & 804 deletions

File tree

pyproject.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,6 @@ module = [
8585
"testtools.monkey",
8686
"testtools.run",
8787
"testtools.runtest",
88-
"testtools.testcase",
89-
"testtools.testresult.*",
9088
"testtools.twistedsupport.*",
9189
"tests.*",
9290
]

tests/matchers/helpers.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ class MatcherTestProtocol(Protocol):
2121
class TestMatchersInterface:
2222
"""Mixin class that provides test methods for matcher interfaces."""
2323

24+
__test__ = False # Tell pytest not to collect this as a test class
25+
2426
def test_matches_match(self: MatcherTestProtocol) -> None:
2527
matcher = self.matches_matcher
2628
matches = self.matches_matches

tests/test_testresult.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,8 @@ def make_exception_info(exceptionFactory, *args, **kwargs):
146146
class TestControlContract:
147147
"""Stopping test runs."""
148148

149+
__test__ = False # Tell pytest not to collect this as a test class
150+
149151
# These are provided by the class that uses this mixin
150152
makeResult: Any
151153
assertFalse: Any
@@ -573,6 +575,8 @@ def makeResult(self):
573575

574576

575577
class TestStreamResultContract:
578+
__test__ = False # Tell pytest not to collect this as a test class
579+
576580
# These are provided by the class that uses this mixin
577581
addCleanup: Any
578582

testtools/content.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,8 @@ class TracebackContent(Content):
197197

198198
def __init__(
199199
self,
200-
err: tuple[type[BaseException], BaseException, types.TracebackType | None],
200+
err: tuple[type[BaseException], BaseException, types.TracebackType | None]
201+
| tuple[None, None, None],
201202
test: _TestCase | None,
202203
capture_locals: bool = False,
203204
) -> None:
@@ -211,6 +212,9 @@ def __init__(
211212
raise ValueError("err may not be None")
212213

213214
exctype, value, tb = err
215+
# Ensure we have a real exception, not the (None, None, None) variant
216+
assert exctype is not None, "exctype must not be None"
217+
assert value is not None, "value must not be None"
214218
# Skip test runner traceback levels
215219
if StackLinesContent.HIDE_INTERNAL_STACK:
216220
while tb and "__unittest" in tb.tb_frame.f_globals:

testtools/matchers/_basic.py

Lines changed: 56 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import re
2121
from collections.abc import Callable
2222
from pprint import pformat
23-
from typing import Any
23+
from typing import Any, Generic, TypeVar
2424

2525
from ..compat import (
2626
text_repr,
@@ -35,52 +35,64 @@
3535
Mismatch,
3636
)
3737

38+
T = TypeVar("T")
39+
U = TypeVar("U")
3840

39-
def _format(thing):
41+
42+
def _format(thing: object) -> str:
4043
"""Blocks of text with newlines are formatted as triple-quote
4144
strings. Everything else is pretty-printed.
4245
"""
4346
if isinstance(thing, (str, bytes)):
44-
return text_repr(thing)
45-
return pformat(thing)
47+
result: str = text_repr(thing)
48+
return result
49+
pformat_result: str = pformat(thing)
50+
return pformat_result
4651

4752

48-
class _BinaryComparison:
53+
class _BinaryComparison(Matcher[T]):
4954
"""Matcher that compares an object to another object."""
5055

5156
mismatch_string: str
5257
# comparator is defined by subclasses - using Any to allow different signatures
5358
comparator: Callable[..., Any]
5459

55-
def __init__(self, expected):
60+
def __init__(self, expected: T) -> None:
5661
self.expected = expected
5762

58-
def __str__(self):
63+
def __str__(self) -> str:
5964
return f"{self.__class__.__name__}({self.expected!r})"
6065

61-
def match(self, other):
66+
def match(self, other: T) -> Mismatch | None:
6267
if self.comparator(other, self.expected):
6368
return None
6469
return _BinaryMismatch(other, self.mismatch_string, self.expected)
6570

6671

67-
class _BinaryMismatch(Mismatch):
72+
class _BinaryMismatch(Mismatch, Generic[T]):
6873
"""Two things did not match."""
6974

70-
def __init__(self, actual, mismatch_string, reference, reference_on_right=True):
75+
def __init__(
76+
self,
77+
actual: T,
78+
mismatch_string: str,
79+
reference: T,
80+
reference_on_right: bool = True,
81+
) -> None:
7182
self._actual = actual
7283
self._mismatch_string = mismatch_string
7384
self._reference = reference
7485
self._reference_on_right = reference_on_right
7586

76-
def describe(self):
87+
def describe(self) -> str:
7788
# Special handling for set comparisons
7889
if (
7990
self._mismatch_string == "!="
8091
and isinstance(self._reference, set)
8192
and isinstance(self._actual, set)
8293
):
83-
return self._describe_set_difference()
94+
result: str = self._describe_set_difference()
95+
return result
8496

8597
actual = repr(self._actual)
8698
reference = repr(self._reference)
@@ -97,8 +109,12 @@ def describe(self):
97109
left, right = reference, actual
98110
return f"{left} {self._mismatch_string} {right}"
99111

100-
def _describe_set_difference(self):
112+
def _describe_set_difference(self) -> str:
101113
"""Describe the difference between two sets in a readable format."""
114+
# Type narrowing: we know these are sets from the isinstance check in describe()
115+
assert isinstance(self._reference, set)
116+
assert isinstance(self._actual, set)
117+
102118
reference_only = sorted(
103119
self._reference - self._actual, key=lambda x: (type(x).__name__, x)
104120
)
@@ -126,7 +142,7 @@ class Equals(_BinaryComparison):
126142
mismatch_string = "!="
127143

128144

129-
class _FlippedEquals:
145+
class _FlippedEquals(Matcher[T]):
130146
"""Matches if the items are equal.
131147
132148
Exactly like ``Equals`` except that the short mismatch message is "
@@ -136,10 +152,13 @@ class _FlippedEquals:
136152
the assertion.
137153
"""
138154

139-
def __init__(self, expected):
155+
def __init__(self, expected: T) -> None:
140156
self._expected = expected
141157

142-
def match(self, other):
158+
def __str__(self) -> str:
159+
return f"_FlippedEquals({self._expected!r})"
160+
161+
def match(self, other: T) -> Mismatch | None:
143162
mismatch = Equals(self._expected).match(other)
144163
if not mismatch:
145164
return None
@@ -335,25 +354,25 @@ def match(self, matchee):
335354
return None
336355

337356

338-
class IsInstance:
357+
class IsInstance(Matcher[T]):
339358
"""Matcher that wraps isinstance."""
340359

341-
def __init__(self, *types):
360+
def __init__(self, *types: type[T]) -> None:
342361
self.types = tuple(types)
343362

344-
def __str__(self):
363+
def __str__(self) -> str:
345364
return "{}({})".format(
346365
self.__class__.__name__, ", ".join(type.__name__ for type in self.types)
347366
)
348367

349-
def match(self, other):
368+
def match(self, other: T) -> Mismatch | None:
350369
if isinstance(other, self.types):
351370
return None
352371
return NotAnInstance(other, self.types)
353372

354373

355-
class NotAnInstance(Mismatch):
356-
def __init__(self, matchee, types):
374+
class NotAnInstance(Mismatch, Generic[T]):
375+
def __init__(self, matchee: T, types: tuple[type[T], ...]) -> None:
357376
"""Create a NotAnInstance Mismatch.
358377
359378
:param matchee: the thing which is not an instance of any of types.
@@ -362,7 +381,7 @@ def __init__(self, matchee, types):
362381
self.matchee = matchee
363382
self.types = types
364383

365-
def describe(self):
384+
def describe(self) -> str:
366385
if len(self.types) == 1:
367386
typestr = self.types[0].__name__
368387
else:
@@ -372,8 +391,8 @@ def describe(self):
372391
return f"'{self.matchee}' is not an instance of {typestr}"
373392

374393

375-
class DoesNotContain(Mismatch):
376-
def __init__(self, matchee, needle):
394+
class DoesNotContain(Mismatch, Generic[T, U]):
395+
def __init__(self, matchee: T, needle: U) -> None:
377396
"""Create a DoesNotContain Mismatch.
378397
379398
:param matchee: the object that did not contain needle.
@@ -382,41 +401,41 @@ def __init__(self, matchee, needle):
382401
self.matchee = matchee
383402
self.needle = needle
384403

385-
def describe(self):
404+
def describe(self) -> str:
386405
return f"{self.needle!r} not in {self.matchee!r}"
387406

388407

389-
class Contains(Matcher):
408+
class Contains(Matcher[T], Generic[T, U]):
390409
"""Checks whether something is contained in another thing."""
391410

392-
def __init__(self, needle):
411+
def __init__(self, needle: U) -> None:
393412
"""Create a Contains Matcher.
394413
395414
:param needle: the thing that needs to be contained by matchees.
396415
"""
397416
self.needle = needle
398417

399-
def __str__(self):
418+
def __str__(self) -> str:
400419
return f"Contains({self.needle!r})"
401420

402-
def match(self, matchee):
421+
def match(self, matchee: T) -> Mismatch | None:
403422
try:
404-
if self.needle not in matchee:
423+
if self.needle not in matchee: # type: ignore[operator]
405424
return DoesNotContain(matchee, self.needle)
406425
except TypeError:
407426
# e.g. 1 in 2 will raise TypeError
408427
return DoesNotContain(matchee, self.needle)
409428
return None
410429

411430

412-
class MatchesRegex:
431+
class MatchesRegex(Matcher[str]):
413432
"""Matches if the matchee is matched by a regular expression."""
414433

415-
def __init__(self, pattern, flags=0):
434+
def __init__(self, pattern: str | bytes, flags: int = 0) -> None:
416435
self.pattern = pattern
417436
self.flags = flags
418437

419-
def __str__(self):
438+
def __str__(self) -> str:
420439
args = [f"{self.pattern!r}"]
421440
flag_arg = []
422441
# dir() sorts the attributes for us, so we don't need to do it again.
@@ -428,15 +447,16 @@ def __str__(self):
428447
args.append("|".join(flag_arg))
429448
return "{}({})".format(self.__class__.__name__, ", ".join(args))
430449

431-
def match(self, value):
432-
if not re.match(self.pattern, value, self.flags):
450+
def match(self, value: str) -> Mismatch | None:
451+
if not re.match(self.pattern, value, self.flags): # type: ignore[arg-type]
433452
pattern = self.pattern
434453
if not isinstance(pattern, str):
435454
pattern = pattern.decode("latin1")
436455
pattern = pattern.encode("unicode_escape").decode("ascii")
437456
return Mismatch(
438457
"{!r} does not match /{}/".format(value, pattern.replace("\\\\", "\\"))
439458
)
459+
return None
440460

441461

442462
def has_len(x, y):

testtools/matchers/_datastructures.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,18 @@
22

33
"""Matchers that operate with knowledge of Python data structures."""
44

5+
from collections.abc import Sequence
6+
from typing import Generic, TypeVar
7+
58
from ..helpers import map_values
69
from ._higherorder import (
710
Annotate,
811
MatchesAll,
912
MismatchesAll,
1013
)
11-
from ._impl import Mismatch
14+
from ._impl import Matcher, Mismatch
15+
16+
T = TypeVar("T")
1217

1318
__all__ = [
1419
"ContainsAll",
@@ -30,7 +35,7 @@ def ContainsAll(items):
3035
return MatchesAll(*map(Contains, items), first_only=False)
3136

3237

33-
class MatchesListwise:
38+
class MatchesListwise(Matcher["Sequence[T]"], Generic[T]):
3439
"""Matches if each matcher matches the corresponding value.
3540
3641
More easily explained by example than in words:
@@ -48,7 +53,9 @@ class MatchesListwise:
4853
3 != 1
4954
"""
5055

51-
def __init__(self, matchers, first_only=False):
56+
def __init__(
57+
self, matchers: "Sequence[Matcher[T]]", first_only: bool = False
58+
) -> None:
5259
"""Construct a MatchesListwise matcher.
5360
5461
:param matchers: A list of matcher that the matched values must match.
@@ -58,7 +65,7 @@ def __init__(self, matchers, first_only=False):
5865
self.matchers = matchers
5966
self.first_only = first_only
6067

61-
def match(self, values):
68+
def match(self, values: "Sequence[T]") -> Mismatch | None:
6269
from ._basic import HasLength
6370

6471
mismatches = []
@@ -75,6 +82,7 @@ def match(self, values):
7582
mismatches.append(mismatch)
7683
if mismatches:
7784
return MismatchesAll(mismatches)
85+
return None
7886

7987

8088
class MatchesStructure:

0 commit comments

Comments
 (0)