Assertion Error when using low-precision numbers #429

Description

@TheSeparatrix

Describe the bug

Hello,
I used the 1D earth mover's distance function ot.emd2_1d to measure the distance between two outputs of a PyTorch neural network. With default parameters, all numbers produced by the PyTorch model are float32.
The function raises an AssertionError that looks to me like the result of a precision error.

File ~/miniconda3/envs/disentanglement_env/lib/python3.9/site-packages/ot/lp/solver_1d.py:361, in emd2_1d(x_a, x_b, a, b, metric, p, dense, log)
    276 r"""Solves the Earth Movers distance problem between 1d measures and returns
    277 the loss
    278 
   (...)
    357     instead of the cost)
    358 """
    359 # If we do not return G (log==False), then we should not to cast it to dense
    360 # (useless overhead)
--> 361 G, log_emd = emd_1d(x_a=x_a, x_b=x_b, a=a, b=b, metric=metric, p=p,
    362                     dense=dense and log, log=True)
    363 cost = log_emd['cost']
    364 if log:

File ~/miniconda3/envs/disentanglement_env/lib/python3.9/site-packages/ot/lp/solver_1d.py:237, in emd_1d(x_a, x_b, a, b, metric, p, dense, log)
    234     b = nx.ones((x_b.shape[0],), type_as=x_b) / x_b.shape[0]
    236 # ensure that same mass
--> 237 np.testing.assert_almost_equal(
    238     nx.to_numpy(nx.sum(a, axis=0)),
    239     nx.to_numpy(nx.sum(b, axis=0)),
    240     err_msg='a and b vector must have the same sum'
    241 )
    242 b = b * nx.sum(a) / nx.sum(b)
    244 x_a_1d = nx.reshape(x_a, (-1,))

File ~/miniconda3/envs/disentanglement_env/lib/python3.9/site-packages/numpy/testing/_private/utils.py:599, in assert_almost_equal(actual, desired, decimal, err_msg, verbose)
    597     pass
    598 if abs(desired - actual) >= 1.5 * 10.0**(-decimal):
--> 599     raise AssertionError(_build_err_msg())

AssertionError: 
Arrays are not almost equal to 7 decimals a and b vector must have the same sum
 ACTUAL: 1.0000001
 DESIRED: 0.99999994
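
To make the failure concrete, here is a minimal NumPy-only sketch (it does not call POT) that takes the two sums from the error message, stores them as float32, and compares their gap with the tolerance in the `abs(desired - actual) >= 1.5 * 10.0**(-decimal)` check shown in the traceback. The variable names are illustrative, and the two values are copied from the message rather than recomputed from the original data.

```python
import numpy as np

# The two sums reported in the error message, stored as float32.
sum_a = np.float32(1.0000001)
sum_b = np.float32(0.99999994)

# float32 machine epsilon: the spacing between 1.0 and the next float32.
eps32 = np.finfo(np.float32).eps

# Gap between the two sums, evaluated in float64 so no extra rounding occurs.
gap = abs(float(sum_a) - float(sum_b))

# Tolerance applied by assert_almost_equal at its default decimal=7.
threshold = 1.5 * 10.0 ** (-7)

print(f"float32 eps: {eps32:.3e}")
print(f"gap:         {gap:.3e}")
print(f"threshold:   {threshold:.3e}")

# Run the same kind of check that emd_1d performs on the two sums.
try:
    np.testing.assert_almost_equal(sum_a, sum_b)
    check_passed = True
except AssertionError:
    check_passed = False

print("check passed:", check_passed)
```

The gap is on the order of one float32 epsilon, so two float32 sums that are each as close to 1 as the format allows can still fall outside a 7-decimal tolerance.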

Expected behavior

Is this the correct behaviour? If this 0.00000016 discrepancy changes the output of the function, then I will have to consider using higher-precision numbers. However, if it doesn't affect the result much, it might be better to make this a warning rather than an error.
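
One possible workaround, sketched below under assumptions, is to do the weight arithmetic in float64 before calling the solver, so that both weight vectors sum to 1 well within the 7-decimal tolerance. The helper `normalize_weights` is a hypothetical name, not part of POT, and the random weights only stand in for real data. Whether mixing float64 weights with float32 samples is acceptable to POT is not confirmed here, so the sketch casts the samples as well.

```python
import numpy as np


def normalize_weights(w):
    """Hypothetical helper: cast weights to float64 and rescale them to sum to 1."""
    w64 = np.asarray(w, dtype=np.float64)
    return w64 / w64.sum()


rng = np.random.default_rng(0)

# Stand-ins for float32 samples and weights coming out of a model.
x_a32 = rng.normal(size=1000).astype(np.float32)
x_b32 = rng.normal(size=1200).astype(np.float32)
a32 = rng.random(1000).astype(np.float32)
b32 = rng.random(1200).astype(np.float32)

# Cast samples to float64 and renormalize the weights in float64.
x_a = x_a32.astype(np.float64)
x_b = x_b32.astype(np.float64)
a = normalize_weights(a32)
b = normalize_weights(b32)

# The same kind of check as in emd_1d now passes with a wide margin.
np.testing.assert_almost_equal(a.sum(), b.sum())

# These arrays could then be passed on, for example:
#   ot.emd2_1d(x_a, x_b, a, b)
```

Note that the traceback also shows `b = b * nx.sum(a) / nx.sum(b)` immediately after the assertion, which suggests the solver rescales `b` to match `a` once the check has passed.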

Environment (please complete the following information):

  • OS (e.g. MacOS, Windows, Linux): observed for same code both on Linux and MacOS
  • How was POT installed (source, pip, conda): conda-forge

Output of the following code snippet:

import platform; print(platform.platform())
import sys; print("Python", sys.version)
import numpy; print("NumPy", numpy.__version__)
import scipy; print("SciPy", scipy.__version__)
import ot; print("POT", ot.__version__)

Output:

Linux-3.10.0-1160.80.1.el7.x86_64-x86_64-with-glibc2.17
Python 3.9.12 (main, Jun  1 2022, 11:38:51) 
[GCC 7.5.0]
NumPy 1.21.5
SciPy 1.7.3
POT 0.8.2
