 from itertools import product
 import copy
 
-from mxnet.test_utils import assert_allclose
+from mxnet.test_utils import assert_almost_equal
 
 def check_unsupported_single_sym(sym):
     wrapped_sym = mx.sym.Group([mx.sym.identity(s) for s in sym])
@@ -74,15 +74,8 @@ def check_single_sym(sym, arg_params_shapes=None, aux_params_shapes=None,
 
     trt_fp32_outputs = [arr.asnumpy() for arr in trt_fp32_executor.outputs]
     for j, (orig, fp16, fp32) in enumerate(zip(orig_outputs, trt_fp16_outputs, trt_fp32_outputs)):
-        #abs_orig = abs(orig)
-        #diff32 = abs(fp32 - orig)
-        #diff16 = abs(fp16.astype('float32') - orig)
-        #_atol32 = diff32 - rtol_fp32 * abs_orig
-        #_atol16 = diff16 - rtol_fp16 * abs_orig
-        #print("{}: diff32({:.2E}) | diff16({:.2E}) | atol32({:.2E}) | atol16({:.2E}) | orig.min({:.2E})".format(
-        #    j, diff32.max(), diff16.max(), _atol32.max(), _atol16.max(), abs_orig.min()))
-        assert_allclose(fp32, orig, rtol=rtol_fp32, atol=atol_fp32)
-        assert_allclose(fp16.astype('float32'), orig, rtol=rtol_fp16, atol=atol_fp16)
+        assert_almost_equal(fp32, orig, rtol=rtol_fp32, atol=atol_fp32)
+        assert_almost_equal(fp16.astype('float32'), orig, rtol=rtol_fp16, atol=atol_fp16)
 
 def test_noop():
     data = mx.sym.Variable('data')
@@ -108,7 +101,7 @@ def test_fp16():
     executor.copy_params_from(arg_params, {})
     executor.forward(is_train=False)
     outputs = executor.outputs[0].asnumpy()
-    assert_allclose(outputs, arr, rtol=0., atol=0.)
+    assert_almost_equal(outputs, arr, rtol=0., atol=0.)
 
 def test_convolution2d():
     data = mx.sym.Variable('data')
@@ -318,15 +311,8 @@ def check_batch_norm(sym, arg_params_shapes=None, aux_params_shapes=None,
     for j, (orig, fp16, fp32) in enumerate(zip(orig_outputs,
                                                trt_fp16_outputs,
                                                trt_fp32_outputs)):
-        #abs_orig = abs(orig)
-        #diff32 = abs(fp32 - orig)
-        #diff16 = abs(fp16.astype('float32') - orig)
-        #_atol32 = diff32 - rtol_fp32 * abs_orig
-        #_atol16 = diff16 - rtol_fp16 * abs_orig
-        #print("{}: diff32({:.2E}) | diff16({:.2E}) | atol32({:.2E}) | atol16({:.2E}) | orig.min({:.2E})".format(
-        #    j, diff32.max(), diff16.max(), _atol32.max(), _atol16.max(), abs_orig.min()))
-        assert_allclose(fp32, orig, rtol=rtol_fp32, atol=atol_fp32)
-        assert_allclose(fp16.astype('float32'), orig, rtol=rtol_fp16, atol=atol_fp16)
+        assert_almost_equal(fp32, orig, rtol=rtol_fp32, atol=atol_fp32)
+        assert_almost_equal(fp16.astype('float32'), orig, rtol=rtol_fp16, atol=atol_fp16)
 
 def test_batch_norm():
     data = mx.sym.Variable('data')
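
For reference, mx.test_utils.assert_almost_equal compares two arrays elementwise under combined relative/absolute tolerances (broadly numpy.allclose-style; the exact semantics depend on the MXNet version), which is what the FP32/FP16 checks in this diff rely on. Below is a minimal usage sketch; the arrays and tolerance values are illustrative only and are not taken from the tests.

import numpy as np
from mxnet.test_utils import assert_almost_equal

# Hypothetical data: a float32 reference output and a simulated FP16 result.
orig = np.linspace(-1.0, 1.0, 8, dtype='float32')
fp16_out = orig.astype('float16')

# Upcast the FP16 result to float32 and compare against the reference
# under relative/absolute tolerances, as the TensorRT tests above do.
assert_almost_equal(fp16_out.astype('float32'), orig, rtol=1e-2, atol=1e-2)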