Skip to content

Commit a6ff8fb

Browse files
kshitij12345larroy
authored andcommitted
assert_allclose -> rtol=1e-10 (apache#16198)
1 parent 366fffc commit a6ff8fb

File tree

1 file changed

+25
-8
lines changed

1 file changed

+25
-8
lines changed

tests/python/unittest/test_ndarray.py

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from itertools import permutations, combinations_with_replacement
2222
import os
2323
import pickle as pkl
24+
import functools
2425
from nose.tools import assert_raises, raises
2526
from common import with_seed, assertRaises, TemporaryDirectory
2627
from mxnet.test_utils import almost_equal
@@ -1887,14 +1888,16 @@ def check_save_load(save_is_np_shape, load_is_np_shape, shapes, save_throw_excep
18871888
check_save_load(True, True, [(2, 0, 1), (0,), (), (), (0, 4), (), (3, 0, 0, 0), (2, 1), (0, 5, 0)], False, False)
18881889

18891890

1890-
@with_seed()
1891-
def test_update_ops_mutation():
1892-
def assert_mutate(x, y, op):
1891+
def _test_update_ops_mutation_impl():
1892+
assert_allclose = functools.partial(
1893+
np.testing.assert_allclose, rtol=1e-10)
1894+
1895+
def assert_mutate(x, y):
18931896
np.testing.assert_raises(
1894-
AssertionError, np.testing.assert_allclose, x, y)
1897+
AssertionError, assert_allclose, x, y)
18951898

1896-
def assert_unchanged(x, y, op):
1897-
np.testing.assert_allclose(x, y)
1899+
def assert_unchanged(x, y):
1900+
assert_allclose(x, y)
18981901

18991902
def test_op(op, num_inputs, mutated_inputs, **kwargs):
19001903
for dim in range(1, 7):
@@ -1919,9 +1922,9 @@ def test_op(op, num_inputs, mutated_inputs, **kwargs):
19191922
for idx, (pre_array, post_array) in \
19201923
enumerate(zip(pre_arrays, post_arrays)):
19211924
if idx in mutated_inputs:
1922-
assert_mutate(pre_array, post_array, op)
1925+
assert_mutate(pre_array, post_array)
19231926
else:
1924-
assert_unchanged(pre_array, post_array, op)
1927+
assert_unchanged(pre_array, post_array)
19251928

19261929
test_op(mx.nd.signsgd_update, 2, [0], **
19271930
{'rescale_grad': 0.1, 'lr': 0.01, 'wd': 1e-3,
@@ -1952,6 +1955,20 @@ def test_op(op, num_inputs, mutated_inputs, **kwargs):
19521955
{'rescale_grad': 0.1, 'lr': 0.01, 'wd': 1e-3})
19531956

19541957

1958+
@with_seed()
1959+
def test_update_ops_mutation():
1960+
_test_update_ops_mutation_impl()
1961+
1962+
1963+
# Problem :
1964+
# https://github.com/apache/incubator-mxnet/pull/15768#issuecomment-532046408
1965+
@with_seed(412298777)
1966+
def test_update_ops_mutation_failed_seed():
1967+
# The difference was -5.9604645e-08 which was
1968+
# lower than then `rtol` of 1e-07
1969+
_test_update_ops_mutation_impl()
1970+
1971+
19551972
def test_large_int_rounding():
19561973
large_integer = 50000001
19571974

0 commit comments

Comments
 (0)