
Assertion Error when using low-precision numbers #429

Closed
TheSeparatrix opened this issue Jan 30, 2023 · 5 comments

Comments

@TheSeparatrix

Describe the bug

Hello,
I used the 1d earth mover's distance function ot.emd2_1d to measure the distance between two outputs from a PyTorch neural network. With default parameters, all the numbers from the PyTorch model are float32.
The distance function raises an AssertionError because of what looks to me like a precision error.
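
For illustration, a minimal sketch of the kind of call involved (the arrays below are placeholders, not the actual network outputs, and whether the assert fires depends on how the float32 rounding accumulates):

import numpy as np
import ot

rng = np.random.default_rng(0)

# Placeholder float32 samples and weights standing in for the model outputs.
x_a = rng.random(1000).astype(np.float32)
x_b = rng.random(1000).astype(np.float32)
a = rng.random(1000).astype(np.float32)
b = rng.random(1000).astype(np.float32)
a /= a.sum()  # normalised in float32, so the sum is 1 only up to ~1e-7
b /= b.sum()

# emd_1d asserts that sum(a) and sum(b) agree to 7 decimals; with float32
# weights the two sums can differ by more than that and raise AssertionError.
loss = ot.emd2_1d(x_a, x_b, a, b)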

File ~/miniconda3/envs/disentanglement_env/lib/python3.9/site-packages/ot/lp/solver_1d.py:361, in emd2_1d(x_a, x_b, a, b, metric, p, dense, log)
    276 r"""Solves the Earth Movers distance problem between 1d measures and returns
    277 the loss
    278 
   (...)
    357     instead of the cost)
    358 """
    359 # If we do not return G (log==False), then we should not to cast it to dense
    360 # (useless overhead)
--> 361 G, log_emd = emd_1d(x_a=x_a, x_b=x_b, a=a, b=b, metric=metric, p=p,
    362                     dense=dense and log, log=True)
    363 cost = log_emd['cost']
    364 if log:

File ~/miniconda3/envs/disentanglement_env/lib/python3.9/site-packages/ot/lp/solver_1d.py:237, in emd_1d(x_a, x_b, a, b, metric, p, dense, log)
    234     b = nx.ones((x_b.shape[0],), type_as=x_b) / x_b.shape[0]
    236 # ensure that same mass
--> 237 np.testing.assert_almost_equal(
    238     nx.to_numpy(nx.sum(a, axis=0)),
    239     nx.to_numpy(nx.sum(b, axis=0)),
    240     err_msg='a and b vector must have the same sum'
    241 )
    242 b = b * nx.sum(a) / nx.sum(b)
    244 x_a_1d = nx.reshape(x_a, (-1,))

File ~/miniconda3/envs/disentanglement_env/lib/python3.9/site-packages/numpy/testing/_private/utils.py:599, in assert_almost_equal(actual, desired, decimal, err_msg, verbose)
    597     pass
    598 if abs(desired - actual) >= 1.5 * 10.0**(-decimal):
--> 599     raise AssertionError(_build_err_msg())

AssertionError: 
Arrays are not almost equal to 7 decimals
a and b vector must have the same sum
 ACTUAL: 1.0000001
 DESIRED: 0.99999994

Expected behavior

Is this the correct behaviour? If this 0.00000016 discrepancy changes the output of the function, then I will have to switch to higher-precision numbers. However, if it doesn't impact the result too much, maybe changing this to a warning rather than an error would be good.
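
In the meantime, one possible workaround sketch (not an official recommendation) is to promote the inputs to float64 and renormalise the weights before calling POT, so that the two marginal sums agree well within the 7-decimal tolerance:

import numpy as np
import ot

def emd2_1d_float64(x_a, x_b, a=None, b=None, **kwargs):
    """Hypothetical wrapper: cast everything to float64 and renormalise the
    weights so both marginals sum to 1 up to float64 precision.
    (For PyTorch tensors, call .detach().cpu().numpy() first; this drops gradients.)"""
    x_a = np.asarray(x_a, dtype=np.float64)
    x_b = np.asarray(x_b, dtype=np.float64)
    if a is not None:
        a = np.asarray(a, dtype=np.float64)
        a = a / a.sum()
    if b is not None:
        b = np.asarray(b, dtype=np.float64)
        b = b / b.sum()
    return ot.emd2_1d(x_a, x_b, a, b, **kwargs)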

Environment (please complete the following information):

  • OS (e.g. MacOS, Windows, Linux): observed with the same code on both Linux and macOS
  • How was POT installed (source, pip, conda): conda-forge

Output of the following code snippet:

import platform; print(platform.platform())
import sys; print("Python", sys.version)
import numpy; print("NumPy", numpy.__version__)
import scipy; print("SciPy", scipy.__version__)
import ot; print("POT", ot.__version__)

Output:

Linux-3.10.0-1160.80.1.el7.x86_64-x86_64-with-glibc2.17
Python 3.9.12 (main, Jun  1 2022, 11:38:51) 
[GCC 7.5.0]
NumPy 1.21.5
SciPy 1.7.3
POT 0.8.2
@rflamary
Collaborator

This problem comes back to bite us in the a** regularly.

Basically, the numerical solver in Cython/C requires a certain amount of precision, which is tested in the function (or should be compensated for), but the check fails in some cases that are hard to reproduce. Could you give us a simple script where this happens, please?

@alexisthual

alexisthual commented May 16, 2023

I'm seeing a similar issue with code that is unfortunately a bit hard to reproduce...
Maybe sinkhorn2() with a low regularization would be a good replacement?
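
For instance, a rough sketch of what that replacement could look like (sizes, metric and regularisation value are placeholders; as far as I can tell, sinkhorn2 does not run the strict marginal-sum assert used by emd_1d):

import numpy as np
import ot

# Placeholder 1d samples and uniform weights.
x_a = np.random.rand(100, 1)
x_b = np.random.rand(100, 1)
a = np.full(100, 1 / 100)
b = np.full(100, 1 / 100)

# Entropic OT on an explicit cost matrix instead of the exact 1d solver.
M = ot.dist(x_a, x_b, metric="euclidean")
# Low regularisation approximates the EMD; method="sinkhorn_log" is the
# numerically stable variant for small reg values.
loss = ot.sinkhorn2(a, b, M, reg=1e-2, method="sinkhorn_log")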

@Kadam-Tushar

I am also facing this issue in my PyTorch network, where after a sigmoid layer the sum of the inputs does not add up to exactly 1.0.

@rflamary
Collaborator

rflamary commented Aug 4, 2023

Hello, I'm proposing a fix here with a slightly looser assert condition and the ability to skip the test.

I'm planning on merging it shortly and then doing a new release of POT. Feel free to test the fix and reopen the issue if it is not satisfactory.
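
For anyone who wants to try the fix once it lands: assuming it exposes a keyword to skip the marginal check (the name check_marginals below is a guess based on the description above, so check the merged PR or changelog for the actual name), the call could look roughly like this:

import numpy as np
import ot

# Float32 weights whose sums differ slightly, as in the report above.
x_a = np.random.rand(500).astype(np.float32)
x_b = np.random.rand(500).astype(np.float32)
a = np.random.rand(500).astype(np.float32)
b = np.random.rand(500).astype(np.float32)
a /= a.sum()
b /= b.sum()

# `check_marginals` is an assumed name for the new skip flag; confirm the
# exact keyword against the merged PR before relying on it.
loss = ot.emd2_1d(x_a, x_b, a, b, check_marginals=False)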

@rflamary
Collaborator

rflamary commented Aug 9, 2023

The PR is merged and there is now a new release. I'm closing this issue; feel free to reopen it if you still have the problem.

@rflamary rflamary closed this as completed Aug 9, 2023