Skip to content

Commit 9a1be53

Browse files
committed
Avoid division by zero in adagrad
The `update` subfunction in `adagrad()` will produce division-by-zero errors. These occur when running `check/pytest` in Python 3.12 with NumPy 2.x: ``` src/openfermion/resource_estimates/pbc/thc/factorizations/thc_jax_test.py::test_kpoint_thc_helper src/openfermion/resource_estimates/pbc/thc/generate_costing_table_thc_test.py::test_generate_costing_table_thc /usr/local/google/home/mhucka/project-files/google/github/openfermion/src/openfermion/resource_estimates/thc/utils/adagrad.py:43: RuntimeWarning: divide by zero encountered in divide g_sq_inv_sqrt = np.where(g_sq > 0, 1.0 / np.sqrt(g_sq), 0.0) ``` A common method to avoid this situation is to add a small positive number to `g_sq` before taking the square root. This ensures that the value is never exactly zero, avoiding the division-by-zero problem. This approach is a common and robust method for avoiding zero denominators in gradient-based optimization. It also acts as a smoothing factor, preventing the update from becoming unstable when `g_sq` gets very small.
1 parent 7f04e57 commit 9a1be53

1 file changed

Lines changed: 4 additions & 2 deletions

File tree

  • src/openfermion/resource_estimates/thc/utils

src/openfermion/resource_estimates/thc/utils/adagrad.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,12 @@ def init(x0):
3737
m = np.zeros_like(x0)
3838
return x0, g_sq, m
3939

40-
def update(i, g, state):
40+
def update(i, g, state, eps=1e-9):
4141
x, g_sq, m = state
4242
g_sq += np.square(g)
43-
g_sq_inv_sqrt = np.where(g_sq > 0, 1.0 / np.sqrt(g_sq), 0.0)
43+
# Add a small number to avoid division by zero
44+
g_sq_safe = g_sq + eps
45+
g_sq_inv_sqrt = 1.0 / np.sqrt(g_sq_safe)
4446
m = (1.0 - momentum) * (g * g_sq_inv_sqrt) + momentum * m
4547
x = x - step_size(i) * m
4648
return x, g_sq, m

0 commit comments

Comments
 (0)