Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit f0c69f5

Browse files
leezuszha
authored andcommitted
Add missing default axis value to symbol.squeeze op (#15707)
* Add missing default arg * Add test * add test
1 parent f2ac85a commit f0c69f5

File tree

3 files changed

+22
-5
lines changed

3 files changed

+22
-5
lines changed

python/mxnet/symbol/symbol.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2539,7 +2539,7 @@ def softmin(self, *args, **kwargs):
25392539
"""
25402540
return op.softmin(self, *args, **kwargs)
25412541

2542-
def squeeze(self, axis, inplace=False, **kwargs): # pylint: disable=unused-argument
2542+
def squeeze(self, axis=None, inplace=False, **kwargs): # pylint: disable=unused-argument
25432543
"""Convenience fluent method for :py:func:`squeeze`.
25442544
25452545
The arguments are the same as for :py:func:`squeeze`, with

tests/python/unittest/test_gluon.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def test_parameter_dict():
115115
params1.get('w1', shape=(10, 10), stype='row_sparse')
116116
params1.load('test_parameter_dict.params', ctx)
117117
trainer1 = mx.gluon.Trainer(params1, 'sgd')
118-
118+
119119
# compare the values before and after save/load
120120
cur_w0 = params1.get('w0').data(ctx)
121121
cur_w1 = params1.get('w1').row_sparse_data(all_row_ids)
@@ -134,7 +134,7 @@ def test_parameter_dict():
134134
cur_w1 = params2.get('w1').data(ctx)
135135
mx.test_utils.assert_almost_equal(prev_w0.asnumpy(), cur_w0.asnumpy())
136136
mx.test_utils.assert_almost_equal(prev_w1.asnumpy(), cur_w1.asnumpy())
137-
137+
138138
# test the dtype casting functionality
139139
params0 = gluon.ParameterDict('')
140140
params0.get('w0', shape=(10, 10), dtype='float32')
@@ -386,7 +386,7 @@ def hybrid_forward(self, F, x):
386386
if 'conv' in param_name and 'weight' in param_name:
387387
break
388388
assert np.dtype(net_fp64.params[param_name].dtype) == np.dtype(np.float64)
389-
389+
390390
# 3.b Verify same functionnality with the imports API
391391
net_fp_64 = mx.gluon.SymbolBlock.imports(sym_file, 'data', params_file, ctx=ctx)
392392

@@ -2788,7 +2788,7 @@ def test_gluon_param_load():
27882788
net.cast('float16')
27892789
net.load_parameters('test_gluon_param_load.params', cast_dtype=True)
27902790
mx.nd.waitall()
2791-
2791+
27922792
@with_seed()
27932793
def test_gluon_param_load_dtype_source():
27942794
net = mx.gluon.nn.Dense(10, in_units=10)
@@ -2800,6 +2800,22 @@ def test_gluon_param_load_dtype_source():
28002800
assert net.weight.dtype == np.float16
28012801
mx.nd.waitall()
28022802

2803+
@with_seed()
2804+
def test_squeeze_consistency():
2805+
class Foo(gluon.HybridBlock):
2806+
def __init__(self, inplace, **kwargs):
2807+
super(Foo, self).__init__(**kwargs)
2808+
self.inplace = inplace
2809+
2810+
def forward(self, x):
2811+
return x.squeeze(inplace=self.inplace)
2812+
2813+
for inplace in (True, False):
2814+
block = Foo(inplace)
2815+
block.hybridize()
2816+
shape = (np.random.randint(1, 10), np.random.randint(1, 10), 1)
2817+
block(mx.nd.ones(shape))
2818+
28032819
if __name__ == '__main__':
28042820
import nose
28052821
nose.runmodule()

tests/python/unittest/test_symbol.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,7 @@ def check_fluent_regular(func, kwargs, shape=(5, 17, 1), equal_nan=False):
242242
check_fluent_regular('reshape', {'shape': (17, 1, 5)})
243243
check_fluent_regular('broadcast_to', {'shape': (5, 17, 47)})
244244
check_fluent_regular('squeeze', {'axis': (1, 3)}, shape=(2, 1, 3, 1, 4))
245+
check_fluent_regular('squeeze', {}, shape=(2, 1, 3, 1, 4))
245246

246247
def check_symbol_consistency(sym1, sym2, ctx, skip_grad=False, equal_nan=False):
247248
assert sym1.list_arguments() == sym2.list_arguments()

0 commit comments

Comments
 (0)