21
21
from itertools import permutations , combinations_with_replacement
22
22
import os
23
23
import pickle as pkl
24
+ import functools
24
25
from nose .tools import assert_raises , raises
25
26
from common import with_seed , assertRaises , TemporaryDirectory
26
27
from mxnet .test_utils import almost_equal
@@ -1887,14 +1888,16 @@ def check_save_load(save_is_np_shape, load_is_np_shape, shapes, save_throw_excep
1887
1888
check_save_load (True , True , [(2 , 0 , 1 ), (0 ,), (), (), (0 , 4 ), (), (3 , 0 , 0 , 0 ), (2 , 1 ), (0 , 5 , 0 )], False , False )
1888
1889
1889
1890
1890
- @with_seed ()
1891
- def test_update_ops_mutation ():
1892
- def assert_mutate (x , y , op ):
1891
+ def _test_update_ops_mutation_impl ():
1892
+ assert_allclose = functools .partial (
1893
+ np .testing .assert_allclose , rtol = 1e-10 )
1894
+
1895
+ def assert_mutate (x , y ):
1893
1896
np .testing .assert_raises (
1894
- AssertionError , np . testing . assert_allclose , x , y )
1897
+ AssertionError , assert_allclose , x , y )
1895
1898
1896
- def assert_unchanged (x , y , op ):
1897
- np . testing . assert_allclose (x , y )
1899
+ def assert_unchanged (x , y ):
1900
+ assert_allclose (x , y )
1898
1901
1899
1902
def test_op (op , num_inputs , mutated_inputs , ** kwargs ):
1900
1903
for dim in range (1 , 7 ):
@@ -1919,9 +1922,9 @@ def test_op(op, num_inputs, mutated_inputs, **kwargs):
1919
1922
for idx , (pre_array , post_array ) in \
1920
1923
enumerate (zip (pre_arrays , post_arrays )):
1921
1924
if idx in mutated_inputs :
1922
- assert_mutate (pre_array , post_array , op )
1925
+ assert_mutate (pre_array , post_array )
1923
1926
else :
1924
- assert_unchanged (pre_array , post_array , op )
1927
+ assert_unchanged (pre_array , post_array )
1925
1928
1926
1929
test_op (mx .nd .signsgd_update , 2 , [0 ], **
1927
1930
{'rescale_grad' : 0.1 , 'lr' : 0.01 , 'wd' : 1e-3 ,
@@ -1952,6 +1955,20 @@ def test_op(op, num_inputs, mutated_inputs, **kwargs):
1952
1955
{'rescale_grad' : 0.1 , 'lr' : 0.01 , 'wd' : 1e-3 })
1953
1956
1954
1957
1958
+ @with_seed ()
1959
+ def test_update_ops_mutation ():
1960
+ _test_update_ops_mutation_impl ()
1961
+
1962
+
1963
+ # Problem :
1964
+ # https://github.com/apache/incubator-mxnet/pull/15768#issuecomment-532046408
1965
+ @with_seed (412298777 )
1966
+ def test_update_ops_mutation_failed_seed ():
1967
+ # The difference was -5.9604645e-08 which was
1968
+ # lower than then `rtol` of 1e-07
1969
+ _test_update_ops_mutation_impl ()
1970
+
1971
+
1955
1972
def test_large_int_rounding ():
1956
1973
large_integer = 50000001
1957
1974
0 commit comments