24
24
import copy
25
25
import warnings
26
26
import re
27
- from collections import OrderedDict
27
+ from collections import OrderedDict , defaultdict
28
28
29
29
from ..base import mx_real_t , MXNetError
30
30
from .. import symbol , ndarray , initializer
@@ -413,7 +413,7 @@ def _collect_params_with_prefix(self, prefix=''):
413
413
ret .update (child ._collect_params_with_prefix (prefix + name ))
414
414
return ret
415
415
416
- def save_parameters (self , filename ):
416
+ def save_parameters (self , filename , deduplicate = False ):
417
417
"""Save parameters to file.
418
418
419
419
Saved parameters can only be loaded with `load_parameters`. Note that this
@@ -424,14 +424,28 @@ def save_parameters(self, filename):
424
424
----------
425
425
filename : str
426
426
Path to file.
427
+ deduplicate : bool, default False
428
+ If True, save shared parameters only once. Otherwise, if a Block
429
+ contains multiple sub-blocks that share parameters, each of the
430
+ shared parameters will be separately saved for every sub-block.
427
431
428
432
References
429
433
----------
430
434
`Saving and Loading Gluon Models \
431
435
<https://mxnet.apache.org/api/python/docs/tutorials/packages/gluon/blocks/save_load_params.html>`_
432
436
"""
433
437
params = self ._collect_params_with_prefix ()
434
- arg_dict = {key : val ._reduce () for key , val in params .items ()}
438
+
439
+ if deduplicate :
440
+ # Shared parameters are stored only a single time as of MXNet 1.6.
441
+ # Shared parameters are registered under multiple prefixes returned by
442
+ # _collect_params_with_prefix. We keep a single prefix per parameter and
443
+ # store the parameter only under that name. In load_parameters it is
444
+ # sufficient to find a shared parameter under any one of its prefixes.
445
+ reverse_params = {v : k for k , v in params .items ()}
446
+ params = {v : k for k , v in reverse_params .items ()}
447
+
448
+ arg_dict = {key : val ._reduce () for key , val in params .items ()}
435
449
save_fn = _mx_npx .save if is_np_array () else ndarray .save
436
450
save_fn (filename , arg_dict )
437
451
@@ -510,15 +524,24 @@ def load_parameters(self, filename, ctx=None, allow_missing=False,
510
524
511
525
if not any ('.' in i for i in loaded .keys ()):
512
526
# legacy loading
513
- del loaded
527
+ loaded = None # This should be changed to ` del loaded` when dropping Python 2
514
528
self .collect_params ().load (
515
529
filename , ctx , allow_missing , ignore_extra , self .prefix ,
516
530
cast_dtype = cast_dtype , dtype_source = dtype_source )
517
531
return
518
532
519
533
if not allow_missing :
520
- for name in params .keys ():
521
- assert name in loaded , \
534
+ # As of MXNet 1.6, shared parameters may be stored only a single time.
535
+ # We thus retrieve all prefixes (through _collect_params_with_prefix)
536
+ # that a shared parameter is used with. A parameter counts as missing
537
+ # only if none of the prefixes it is registered under appears in the
538
+ # loaded file.
539
+ params_inv = defaultdict (list )
540
+ for k , v in params .items ():
541
+ params_inv [v ].append (k )
542
+
543
+ for name , param in params .items ():
544
+ assert any (p in loaded for p in params_inv [param ]), \
522
545
"Parameter '%s' is missing in file '%s', which contains parameters: %s. " \
523
546
"Set allow_missing=True to ignore missing parameters." % (
524
547
name , filename , _brief_print_list (loaded .keys ()))
0 commit comments