Skip to content

Commit 9291eb3

Browse files
committed
Change names of parameters to isclose() and deprecate tol
Per comments in PR review, this renames `rel_tol` and `abs_tol` to be `rtol` and `atol`, respectively, as well ask keeps the previous `tol` parameter but marks it as deprecated.
1 parent 0929e8d commit 9291eb3

2 files changed

Lines changed: 74 additions & 14 deletions

File tree

src/openfermion/ops/operators/symbolic_operator.py

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -605,7 +605,7 @@ def __pow__(self, exponent):
605605

606606
def __eq__(self, other):
607607
"""Approximate numerical equality (not true equality)."""
608-
return self.isclose(other, rel_tol=EQ_TOLERANCE, abs_tol=EQ_TOLERANCE)
608+
return self.isclose(other)
609609

610610
def __ne__(self, other):
611611
return not (self == other)
@@ -618,7 +618,7 @@ def __next__(self):
618618
term, coefficient = next(self._iter)
619619
return self.__class__(term=term, coefficient=coefficient)
620620

621-
def isclose(self, other, rel_tol=EQ_TOLERANCE, abs_tol=EQ_TOLERANCE):
621+
def isclose(self, other, tol=None, rtol=EQ_TOLERANCE, atol=EQ_TOLERANCE):
622622
"""Check if other (SymbolicOperator) is close to self.
623623
624624
Comparison is done for each term individually. Return True
@@ -627,36 +627,55 @@ def isclose(self, other, rel_tol=EQ_TOLERANCE, abs_tol=EQ_TOLERANCE):
627627
628628
Args:
629629
other(SymbolicOperator): SymbolicOperator to compare against.
630-
rel_tol(float): Relative tolerance.
631-
abs_tol(float): Absolute tolerance.
630+
tol(float): This parameter is deprecated since version 1.8.0.
631+
Use `rtol` and/or `atol` instead. If `tol` is provided, it
632+
is used as the value of `atol`.
633+
rtol(float): Relative tolerance used in comparing each term in
634+
self and other.
635+
atol(float): Absolute tolerance used in comparing each term in
636+
self and other.
632637
"""
633638
if not isinstance(self, type(other)):
634639
return NotImplemented
635640

641+
if tol is not None:
642+
if rtol != EQ_TOLERANCE or atol != EQ_TOLERANCE:
643+
raise ValueError(
644+
'Parameters rtol and atol are mutually exclusive with the'
645+
' deprecated parameter tol; use either tol or the other two,'
646+
' not in combination.'
647+
)
648+
warnings.warn(
649+
'Parameter tol is deprecated. Use rtol and/or atol instead.',
650+
DeprecationWarning,
651+
stacklevel=2, # Identify the location of the warning.
652+
)
653+
atol = tol
654+
636655
# terms which are in both:
637656
for term in set(self.terms).intersection(set(other.terms)):
638657
a = self.terms[term]
639658
b = other.terms[term]
640659
if isinstance(a, sympy.Expr) or isinstance(b, sympy.Expr):
641-
if self._issmall(a - b, abs_tol) is False:
660+
if self._issmall(a - b, atol) is False:
642661
return False
643-
elif not abs(a - b) <= abs_tol + rel_tol * max(abs(a), abs(b)):
662+
elif not abs(a - b) <= atol + rtol * max(abs(a), abs(b)):
644663
return False
645-
# terms only in one (compare to 0.0 so only abs_tol)
664+
# terms only in one (compare to 0.0 so only atol)
646665
for term in set(self.terms).symmetric_difference(set(other.terms)):
647666
if term in self.terms:
648667
coeff = self.terms[term]
649668
if isinstance(coeff, sympy.Expr):
650-
if self._issmall(coeff, abs_tol) is False:
669+
if self._issmall(coeff, atol) is False:
651670
return False
652-
elif not abs(coeff) <= abs_tol:
671+
elif not abs(coeff) <= atol:
653672
return False
654673
else:
655674
coeff = other.terms[term]
656675
if isinstance(coeff, sympy.Expr):
657-
if self._issmall(coeff, abs_tol) is False:
676+
if self._issmall(coeff, atol) is False:
658677
return False
659-
elif not abs(coeff) <= abs_tol:
678+
elif not abs(coeff) <= atol:
660679
return False
661680
return True
662681

src/openfermion/ops/operators/symbolic_operator_test.py

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -860,7 +860,48 @@ def test_pow_high_term(self):
860860
term = DummyOperator1(ops, coeff)
861861
high = term**10
862862
expected = DummyOperator1(ops * 10, coeff**10)
863-
self.assertTrue(high.isclose(expected, rel_tol=1e-12, abs_tol=1e-12))
863+
self.assertTrue(high.isclose(expected, rtol=1e-12, atol=1e-12))
864+
865+
def test_isclose_parameter_deprecation(self):
866+
op1 = DummyOperator1('0^ 1', 1.0)
867+
op2 = DummyOperator1('0^ 1', 1.001)
868+
869+
with self.assertWarns(DeprecationWarning):
870+
op1.isclose(op2, tol=0.01)
871+
872+
with warnings.catch_warnings():
873+
warnings.simplefilter("ignore", category=DeprecationWarning)
874+
self.assertTrue(op1.isclose(op2, tol=0.001))
875+
self.assertFalse(op1.isclose(op2, tol=0.0001))
876+
877+
def test_isclose_parameter_combos(self):
878+
op1 = DummyOperator1('0^ 1', 1.0)
879+
op2 = DummyOperator1('0^ 1', 1.001)
880+
881+
with self.assertRaises(ValueError):
882+
op1.isclose(op2, tol=0.01, rtol=1e-5)
883+
884+
with self.assertRaises(ValueError):
885+
op1.isclose(op2, tol=0.01, atol=1e-5)
886+
887+
def test_isclose_atol_rtol(self):
888+
op1 = DummyOperator1('0^ 1', 1.0)
889+
op2 = DummyOperator1('0^ 1', 1.001)
890+
891+
op_a = DummyOperator1('0^ 1', 1.0)
892+
op_b = DummyOperator1('0^ 1', 1.001)
893+
self.assertTrue(op_a.isclose(op_b, atol=0.001))
894+
self.assertFalse(op_a.isclose(op_b, atol=0.0001))
895+
896+
op_c = DummyOperator1('0^ 1', 1000)
897+
op_d = DummyOperator1('0^ 1', 1001)
898+
self.assertTrue(op_c.isclose(op_d, rtol=0.001))
899+
self.assertFalse(op_c.isclose(op_d, rtol=0.0001))
900+
901+
op_e = DummyOperator1('0^ 1', 1.0)
902+
op_f = DummyOperator1('0^ 1', 1.001)
903+
self.assertTrue(op_e.isclose(op_f, rtol=1e-4, atol=1e-3))
904+
self.assertFalse(op_e.isclose(op_f, rtol=1e-4, atol=1e-5))
864905

865906
def test_isclose(self):
866907
op1 = DummyOperator1()
@@ -869,8 +910,8 @@ def test_isclose(self):
869910
op1 += DummyOperator1('2^ 3', 1)
870911
op2 += DummyOperator1('0^ 1', 1000000)
871912
op2 += DummyOperator1('2^ 3', 1.001)
872-
self.assertFalse(op1.isclose(op2, abs_tol=1e-4))
873-
self.assertTrue(op1.isclose(op2, abs_tol=1e-2))
913+
self.assertFalse(op1.isclose(op2, atol=1e-4))
914+
self.assertTrue(op1.isclose(op2, atol=1e-2))
874915

875916
# Case from https://github.com/quantumlib/OpenFermion/issues/764
876917
x = FermionOperator("0^ 0")

0 commit comments

Comments
 (0)