Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 8c44af4

Browse files
authored
[Gluon] Don't serialize shared parameters twice (#16582)
Add deduplicate argument (default of False) to save_parameters.
1 parent 0712f00 commit 8c44af4

File tree

2 files changed

+69
-6
lines changed

2 files changed

+69
-6
lines changed

python/mxnet/gluon/block.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
import copy
2525
import warnings
2626
import re
27-
from collections import OrderedDict
27+
from collections import OrderedDict, defaultdict
2828

2929
from ..base import mx_real_t, MXNetError
3030
from .. import symbol, ndarray, initializer
@@ -413,7 +413,7 @@ def _collect_params_with_prefix(self, prefix=''):
413413
ret.update(child._collect_params_with_prefix(prefix + name))
414414
return ret
415415

416-
def save_parameters(self, filename):
416+
def save_parameters(self, filename, deduplicate=False):
417417
"""Save parameters to file.
418418
419419
Saved parameters can only be loaded with `load_parameters`. Note that this
@@ -424,14 +424,28 @@ def save_parameters(self, filename):
424424
----------
425425
filename : str
426426
Path to file.
427+
deduplicate : bool, default False
428+
If True, save shared parameters only once. Otherwise, if a Block
429+
contains multiple sub-blocks that share parameters, each of the
430+
shared parameters will be separately saved for every sub-block.
427431
428432
References
429433
----------
430434
`Saving and Loading Gluon Models \
431435
<https://mxnet.apache.org/api/python/docs/tutorials/packages/gluon/blocks/save_load_params.html>`_
432436
"""
433437
params = self._collect_params_with_prefix()
434-
arg_dict = {key : val._reduce() for key, val in params.items()}
438+
439+
if deduplicate:
440+
# Shared parameters are stored only a single time as of MXNet 1.6.
441+
# Shared parameters are registered under multiple prefixes returned by
442+
# _collect_params_with_prefix. We select a single one and only store
443+
# it. In load_parameters, setting a shared parameter through any
444+
# one of its prefixes is sufficient.
445+
reverse_params = {v: k for k, v in params.items()}
446+
params = {v: k for k, v in reverse_params.items()}
447+
448+
arg_dict = {key: val._reduce() for key, val in params.items()}
435449
save_fn = _mx_npx.save if is_np_array() else ndarray.save
436450
save_fn(filename, arg_dict)
437451

@@ -510,15 +524,24 @@ def load_parameters(self, filename, ctx=None, allow_missing=False,
510524

511525
if not any('.' in i for i in loaded.keys()):
512526
# legacy loading
513-
del loaded
527+
loaded = None # This should be changed to `del loaded` when dropping Python 2
514528
self.collect_params().load(
515529
filename, ctx, allow_missing, ignore_extra, self.prefix,
516530
cast_dtype=cast_dtype, dtype_source=dtype_source)
517531
return
518532

519533
if not allow_missing:
520-
for name in params.keys():
521-
assert name in loaded, \
534+
# Shared parameters are stored only a single time as of MXNet 1.6.
535+
# We thus retrieve all prefixes (through _collect_params_with_prefix)
536+
# that a shared parameter is used with. A parameter counts as
537+
# missing only if none of its prefixes is present in the loaded
538+
# file.
539+
params_inv = defaultdict(list)
540+
for k, v in params.items():
541+
params_inv[v].append(k)
542+
543+
for name, param in params.items():
544+
assert any(p in loaded for p in params_inv[param]), \
522545
"Parameter '%s' is missing in file '%s', which contains parameters: %s. " \
523546
"Set allow_missing=True to ignore missing parameters."%(
524547
name, filename, _brief_print_list(loaded.keys()))

tests/python/unittest/test_gluon.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1511,6 +1511,46 @@ def forward(self, x):
15111511
net2 = Network()
15121512
net2.load_parameters('tmp.params')
15131513

1514+
@with_seed()
1515+
def test_save_load_deduplicate_with_shared_params():
1516+
class B(mx.gluon.Block):
1517+
def __init__(self, params=None):
1518+
super(B, self).__init__(params=params)
1519+
1520+
with self.name_scope():
1521+
self.weight = self.params.get('weight', shape=(10, 10))
1522+
1523+
class C(mx.gluon.Block):
1524+
def __init__(self, b1, b2):
1525+
super(C, self).__init__()
1526+
self.b1 = b1
1527+
self.b2 = b2
1528+
1529+
b1 = B()
1530+
b2 = B(b1.collect_params())
1531+
c = C(b1, b2)
1532+
c.initialize()
1533+
c.save_parameters('tmp.params', deduplicate=True)
1534+
1535+
params = mx.nd.load('tmp.params')
1536+
assert len(params) == 1 # Only a single copy of the shared parameter is saved
1537+
1538+
b1 = B()
1539+
b2 = B(b1.collect_params())
1540+
c = C(b1, b2)
1541+
c.load_parameters('tmp.params')
1542+
1543+
# Test default behavior
1544+
c.save_parameters('tmp2.params', deduplicate=False)
1545+
1546+
params = mx.nd.load('tmp2.params')
1547+
assert len(params) == 2 # Without deduplication, the shared parameter is saved once per prefix
1548+
1549+
b1 = B()
1550+
b2 = B(b1.collect_params())
1551+
c = C(b1, b2)
1552+
c.load_parameters('tmp2.params')
1553+
15141554
@with_seed()
15151555
def test_symbol_block_save_load():
15161556
class Net(gluon.HybridBlock):

0 commit comments

Comments
 (0)