@@ -605,7 +605,7 @@ def __pow__(self, exponent):
605605
606606 def __eq__ (self , other ):
607607 """Approximate numerical equality (not true equality)."""
608- return self .isclose (other , rel_tol = EQ_TOLERANCE , abs_tol = EQ_TOLERANCE )
608+ return self .isclose (other )
609609
610610 def __ne__ (self , other ):
611611 return not (self == other )
@@ -618,7 +618,7 @@ def __next__(self):
618618 term , coefficient = next (self ._iter )
619619 return self .__class__ (term = term , coefficient = coefficient )
620620
621- def isclose (self , other , rel_tol = EQ_TOLERANCE , abs_tol = EQ_TOLERANCE ):
621+ def isclose (self , other , tol = None , rtol = EQ_TOLERANCE , atol = EQ_TOLERANCE ):
622622 """Check if other (SymbolicOperator) is close to self.
623623
624624 Comparison is done for each term individually. Return True
@@ -627,36 +627,55 @@ def isclose(self, other, rel_tol=EQ_TOLERANCE, abs_tol=EQ_TOLERANCE):
627627
628628 Args:
629629 other(SymbolicOperator): SymbolicOperator to compare against.
630- rel_tol(float): Relative tolerance.
631- abs_tol(float): Absolute tolerance.
630+ tol(float): This parameter is deprecated since version 1.8.0.
631+ Use `rtol` and/or `atol` instead. If `tol` is provided, it
632+ is used as the value of `atol`.
633+ rtol(float): Relative tolerance used in comparing each term in
634+ self and other.
635+ atol(float): Absolute tolerance used in comparing each term in
636+ self and other.
632637 """
633638 if not isinstance (self , type (other )):
634639 return NotImplemented
635640
641+ if tol is not None :
642+ if rtol != EQ_TOLERANCE or atol != EQ_TOLERANCE :
643+ raise ValueError (
644+ 'Parameters rtol and atol are mutually exclusive with the'
645+ ' deprecated parameter tol; use either tol or the other two,'
646+ ' not in combination.'
647+ )
648+ warnings .warn (
649+ 'Parameter tol is deprecated. Use rtol and/or atol instead.' ,
650+ DeprecationWarning ,
651+ stacklevel = 2 , # Identify the location of the warning.
652+ )
653+ atol = tol
654+
636655 # terms which are in both:
637656 for term in set (self .terms ).intersection (set (other .terms )):
638657 a = self .terms [term ]
639658 b = other .terms [term ]
640659 if isinstance (a , sympy .Expr ) or isinstance (b , sympy .Expr ):
641- if self ._issmall (a - b , abs_tol ) is False :
660+ if self ._issmall (a - b , atol ) is False :
642661 return False
643- elif not abs (a - b ) <= abs_tol + rel_tol * max (abs (a ), abs (b )):
662+ elif not abs (a - b ) <= atol + rtol * max (abs (a ), abs (b )):
644663 return False
645- # terms only in one (compare to 0.0 so only abs_tol )
664+ # terms only in one (compare to 0.0 so only atol )
646665 for term in set (self .terms ).symmetric_difference (set (other .terms )):
647666 if term in self .terms :
648667 coeff = self .terms [term ]
649668 if isinstance (coeff , sympy .Expr ):
650- if self ._issmall (coeff , abs_tol ) is False :
669+ if self ._issmall (coeff , atol ) is False :
651670 return False
652- elif not abs (coeff ) <= abs_tol :
671+ elif not abs (coeff ) <= atol :
653672 return False
654673 else :
655674 coeff = other .terms [term ]
656675 if isinstance (coeff , sympy .Expr ):
657- if self ._issmall (coeff , abs_tol ) is False :
676+ if self ._issmall (coeff , atol ) is False :
658677 return False
659- elif not abs (coeff ) <= abs_tol :
678+ elif not abs (coeff ) <= atol :
660679 return False
661680 return True
662681
0 commit comments