2020import re
2121from collections .abc import Callable
2222from pprint import pformat
23- from typing import Any
23+ from typing import Any , Generic , TypeVar
2424
2525from ..compat import (
2626 text_repr ,
3535 Mismatch ,
3636)
3737
38+ T = TypeVar ("T" )
39+ U = TypeVar ("U" )
3840
39- def _format (thing ):
41+
42+ def _format (thing : object ) -> str :
4043 """Blocks of text with newlines are formatted as triple-quote
4144 strings. Everything else is pretty-printed.
4245 """
4346 if isinstance (thing , (str , bytes )):
44- return text_repr (thing )
45- return pformat (thing )
47+ result : str = text_repr (thing )
48+ return result
49+ pformat_result : str = pformat (thing )
50+ return pformat_result
4651
4752
48- class _BinaryComparison :
53+ class _BinaryComparison ( Matcher [ T ]) :
4954 """Matcher that compares an object to another object."""
5055
5156 mismatch_string : str
5257 # comparator is defined by subclasses - using Any to allow different signatures
5358 comparator : Callable [..., Any ]
5459
55- def __init__ (self , expected ) :
60+ def __init__ (self , expected : T ) -> None :
5661 self .expected = expected
5762
58- def __str__ (self ):
63+ def __str__ (self ) -> str :
5964 return f"{ self .__class__ .__name__ } ({ self .expected !r} )"
6065
61- def match (self , other ) :
66+ def match (self , other : T ) -> Mismatch | None :
6267 if self .comparator (other , self .expected ):
6368 return None
6469 return _BinaryMismatch (other , self .mismatch_string , self .expected )
6570
6671
67- class _BinaryMismatch (Mismatch ):
72+ class _BinaryMismatch (Mismatch , Generic [ T ] ):
6873 """Two things did not match."""
6974
70- def __init__ (self , actual , mismatch_string , reference , reference_on_right = True ):
75+ def __init__ (
76+ self ,
77+ actual : T ,
78+ mismatch_string : str ,
79+ reference : T ,
80+ reference_on_right : bool = True ,
81+ ) -> None :
7182 self ._actual = actual
7283 self ._mismatch_string = mismatch_string
7384 self ._reference = reference
7485 self ._reference_on_right = reference_on_right
7586
76- def describe (self ):
87+ def describe (self ) -> str :
7788 # Special handling for set comparisons
7889 if (
7990 self ._mismatch_string == "!="
8091 and isinstance (self ._reference , set )
8192 and isinstance (self ._actual , set )
8293 ):
83- return self ._describe_set_difference ()
94+ result : str = self ._describe_set_difference ()
95+ return result
8496
8597 actual = repr (self ._actual )
8698 reference = repr (self ._reference )
@@ -97,8 +109,12 @@ def describe(self):
97109 left , right = reference , actual
98110 return f"{ left } { self ._mismatch_string } { right } "
99111
100- def _describe_set_difference (self ):
112+ def _describe_set_difference (self ) -> str :
101113 """Describe the difference between two sets in a readable format."""
114+ # Type narrowing: we know these are sets from the isinstance check in describe()
115+ assert isinstance (self ._reference , set )
116+ assert isinstance (self ._actual , set )
117+
102118 reference_only = sorted (
103119 self ._reference - self ._actual , key = lambda x : (type (x ).__name__ , x )
104120 )
@@ -126,7 +142,7 @@ class Equals(_BinaryComparison):
126142 mismatch_string = "!="
127143
128144
129- class _FlippedEquals :
145+ class _FlippedEquals ( Matcher [ T ]) :
130146 """Matches if the items are equal.
131147
132148 Exactly like ``Equals`` except that the short mismatch message is "
@@ -136,10 +152,13 @@ class _FlippedEquals:
136152 the assertion.
137153 """
138154
139- def __init__ (self , expected ) :
155+ def __init__ (self , expected : T ) -> None :
140156 self ._expected = expected
141157
142- def match (self , other ):
158+ def __str__ (self ) -> str :
159+ return f"_FlippedEquals({ self ._expected !r} )"
160+
161+ def match (self , other : T ) -> Mismatch | None :
143162 mismatch = Equals (self ._expected ).match (other )
144163 if not mismatch :
145164 return None
@@ -335,25 +354,25 @@ def match(self, matchee):
335354 return None
336355
337356
338- class IsInstance :
357+ class IsInstance ( Matcher [ T ]) :
339358 """Matcher that wraps isinstance."""
340359
341- def __init__ (self , * types ) :
360+ def __init__ (self , * types : type [ T ]) -> None :
342361 self .types = tuple (types )
343362
344- def __str__ (self ):
363+ def __str__ (self ) -> str :
345364 return "{}({})" .format (
346365 self .__class__ .__name__ , ", " .join (type .__name__ for type in self .types )
347366 )
348367
349- def match (self , other ) :
368+ def match (self , other : T ) -> Mismatch | None :
350369 if isinstance (other , self .types ):
351370 return None
352371 return NotAnInstance (other , self .types )
353372
354373
355- class NotAnInstance (Mismatch ):
356- def __init__ (self , matchee , types ) :
374+ class NotAnInstance (Mismatch , Generic [ T ] ):
375+ def __init__ (self , matchee : T , types : tuple [ type [ T ], ...]) -> None :
357376 """Create a NotAnInstance Mismatch.
358377
359378 :param matchee: the thing which is not an instance of any of types.
@@ -362,7 +381,7 @@ def __init__(self, matchee, types):
362381 self .matchee = matchee
363382 self .types = types
364383
365- def describe (self ):
384+ def describe (self ) -> str :
366385 if len (self .types ) == 1 :
367386 typestr = self .types [0 ].__name__
368387 else :
@@ -372,8 +391,8 @@ def describe(self):
372391 return f"'{ self .matchee } ' is not an instance of { typestr } "
373392
374393
375- class DoesNotContain (Mismatch ):
376- def __init__ (self , matchee , needle ) :
394+ class DoesNotContain (Mismatch , Generic [ T , U ] ):
395+ def __init__ (self , matchee : T , needle : U ) -> None :
377396 """Create a DoesNotContain Mismatch.
378397
379398 :param matchee: the object that did not contain needle.
@@ -382,41 +401,41 @@ def __init__(self, matchee, needle):
382401 self .matchee = matchee
383402 self .needle = needle
384403
385- def describe (self ):
404+ def describe (self ) -> str :
386405 return f"{ self .needle !r} not in { self .matchee !r} "
387406
388407
389- class Contains (Matcher ):
408+ class Contains (Matcher [ T ], Generic [ T , U ] ):
390409 """Checks whether something is contained in another thing."""
391410
392- def __init__ (self , needle ) :
411+ def __init__ (self , needle : U ) -> None :
393412 """Create a Contains Matcher.
394413
395414 :param needle: the thing that needs to be contained by matchees.
396415 """
397416 self .needle = needle
398417
399- def __str__ (self ):
418+ def __str__ (self ) -> str :
400419 return f"Contains({ self .needle !r} )"
401420
402- def match (self , matchee ) :
421+ def match (self , matchee : T ) -> Mismatch | None :
403422 try :
404- if self .needle not in matchee :
423+ if self .needle not in matchee : # type: ignore[operator]
405424 return DoesNotContain (matchee , self .needle )
406425 except TypeError :
407426 # e.g. 1 in 2 will raise TypeError
408427 return DoesNotContain (matchee , self .needle )
409428 return None
410429
411430
412- class MatchesRegex :
431+ class MatchesRegex ( Matcher [ str ]) :
413432 """Matches if the matchee is matched by a regular expression."""
414433
415- def __init__ (self , pattern , flags = 0 ) :
434+ def __init__ (self , pattern : str | bytes , flags : int = 0 ) -> None :
416435 self .pattern = pattern
417436 self .flags = flags
418437
419- def __str__ (self ):
438+ def __str__ (self ) -> str :
420439 args = [f"{ self .pattern !r} " ]
421440 flag_arg = []
422441 # dir() sorts the attributes for us, so we don't need to do it again.
@@ -428,15 +447,16 @@ def __str__(self):
428447 args .append ("|" .join (flag_arg ))
429448 return "{}({})" .format (self .__class__ .__name__ , ", " .join (args ))
430449
431- def match (self , value ) :
432- if not re .match (self .pattern , value , self .flags ):
450+ def match (self , value : str ) -> Mismatch | None :
451+ if not re .match (self .pattern , value , self .flags ): # type: ignore[arg-type]
433452 pattern = self .pattern
434453 if not isinstance (pattern , str ):
435454 pattern = pattern .decode ("latin1" )
436455 pattern = pattern .encode ("unicode_escape" ).decode ("ascii" )
437456 return Mismatch (
438457 "{!r} does not match /{}/" .format (value , pattern .replace ("\\ \\ " , "\\ " ))
439458 )
459+ return None
440460
441461
442462def has_len (x , y ):
0 commit comments