Skip to content

Commit fbc313e

Browse files
committed
[Feature] Support using lazy object in list, tuple and set
Add pop method in ConfigList
1 parent 6c5eebb commit fbc313e

File tree

4 files changed

+187
-73
lines changed

4 files changed

+187
-73
lines changed

mmengine/config/__init__.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
2-
from .config import Config, ConfigDict, DictAction, read_base
2+
from .config import (Config, ConfigDict, ConfigList, ConfigSet, ConfigTuple,
3+
DictAction, read_base)
34

4-
__all__ = ['Config', 'ConfigDict', 'DictAction', 'read_base']
5+
__all__ = [
6+
'Config', 'ConfigDict', 'DictAction', 'read_base', 'ConfigList',
7+
'ConfigSet', 'ConfigTuple'
8+
]

mmengine/config/config.py

Lines changed: 132 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -45,19 +45,87 @@
4545
import re # type: ignore
4646

4747

48-
def _lazy2string(cfg_dict, dict_type=None):
49-
if isinstance(cfg_dict, dict):
50-
dict_type = dict_type or type(cfg_dict)
51-
return dict_type({k: _lazy2string(v) for k, v in dict.items(cfg_dict)})
52-
elif isinstance(cfg_dict, (tuple, list)):
53-
return type(cfg_dict)(_lazy2string(v) for v in cfg_dict)
54-
elif isinstance(cfg_dict, (LazyAttr, LazyObject)):
55-
return f'{cfg_dict.module}.{str(cfg_dict)}'
56-
else:
57-
return cfg_dict
48+
class LazyContainerMeta(type):
49+
50+
def __init__(self, *args, **kwargs):
51+
super().__init__(*args, **kwargs)
52+
self.lazy = False
53+
54+
55+
class LazyContainerMixin(metaclass=LazyContainerMeta):
56+
57+
def to_builtin(self, keep_lazy=False):
58+
59+
def _to_builtin(cfg):
60+
if isinstance(cfg, dict):
61+
return dict({k: _to_builtin(v) for k, v in dict.items(cfg)})
62+
elif isinstance(cfg, tuple):
63+
return tuple(_to_builtin(v) for v in tuple.__iter__(cfg))
64+
elif isinstance(cfg, list):
65+
return list(_to_builtin(v) for v in list.__iter__(cfg))
66+
elif isinstance(cfg, set):
67+
return {_to_builtin(v) for v in set.__iter__(cfg)}
68+
elif isinstance(cfg, (LazyAttr, LazyObject)):
69+
if not keep_lazy:
70+
return f'{cfg.module}.{str(cfg)}'
71+
else:
72+
return cfg
73+
else:
74+
return cfg
75+
76+
return _to_builtin(self)
77+
78+
def build_lazy(self, value: Any) -> Any:
79+
"""If class attribute ``lazy`` is False, the LazyObject will be built
80+
and returned.
81+
82+
Args:
83+
value (Any): The value to be built.
84+
85+
Returns:
86+
Any: The built value.
87+
"""
88+
if (isinstance(value, (LazyAttr, LazyObject))
89+
and not self.__class__.lazy):
90+
value = value.build()
91+
return value
92+
93+
def __deepcopy__(self, memo):
94+
return self.__class__(
95+
copy.deepcopy(item, memo) for item in super().__iter__())
96+
97+
def __copy__(self):
98+
return self.__class__(item for item in super().__iter__())
99+
100+
def __iter__(self):
101+
# Override `__iter__` to overwrite to support star unpacking
102+
# `*cfg_list`
103+
yield from map(self.build_lazy, super().__iter__())
104+
105+
def __getitem__(self, idx):
106+
try:
107+
value = self.build_lazy(super().__getitem__(idx))
108+
except Exception as e:
109+
raise e
110+
else:
111+
return value
112+
113+
def __eq__(self, other):
114+
return all(a == b for a, b in zip(self, other))
115+
116+
def __reduce_ex__(self, proto):
117+
# Override __reduce_ex__ to avoid dump the built lazy object.
118+
if digit_version(platform.python_version()) < digit_version('3.8'):
119+
return (self.__class__, (tuple(i for i in super().__iter__()), ),
120+
None, None, None)
121+
else:
122+
return (self.__class__, (tuple(i for i in super().__iter__()), ),
123+
None, None, None, None)
58124

125+
copy = __copy__
59126

60-
class ConfigDict(Dict):
127+
128+
class ConfigDict(LazyContainerMixin, Dict):
61129
"""A dictionary for config which has the same interface as python's built-
62130
in dictionary and can be used as a normal dictionary.
63131
@@ -72,7 +140,6 @@ class ConfigDict(Dict):
72140
object during configuration parsing, and it should be set to False outside
73141
the Config to ensure that users do not experience the ``LazyObject``.
74142
"""
75-
lazy = False
76143

77144
def __init__(__self, *args, **kwargs):
78145
object.__setattr__(__self, '__parent', kwargs.pop('__parent', None))
@@ -118,8 +185,14 @@ def __getattr__(self, name):
118185
@classmethod
119186
def _hook(cls, item):
120187
# avoid to convert user defined dict to ConfigDict.
121-
if type(item) in (dict, OrderedDict):
188+
if isinstance(item, ConfigDict):
189+
return item
190+
elif type(item) in (dict, OrderedDict):
122191
return cls(item)
192+
elif isinstance(item, LazyContainerMixin):
193+
return type(item)(
194+
cls._hook(elem)
195+
for elem in super(LazyContainerMixin, item).__iter__())
123196
elif isinstance(item, (list, tuple)):
124197
return type(item)(cls._hook(elem) for elem in item)
125198
return item
@@ -150,11 +223,6 @@ def __copy__(self):
150223

151224
copy = __copy__
152225

153-
def __iter__(self):
154-
# Implement `__iter__` to overwrite the unpacking operator `**cfg_dict`
155-
# to get the built lazy object
156-
return iter(self.keys())
157-
158226
def get(self, key: str, default: Optional[Any] = None) -> Any:
159227
"""Get the value of the key. If class attribute ``lazy`` is True, the
160228
LazyObject will be built and returned.
@@ -201,20 +269,6 @@ def update(self, *args, **kwargs) -> None:
201269
else:
202270
self[k].update(v)
203271

204-
def build_lazy(self, value: Any) -> Any:
205-
"""If class attribute ``lazy`` is False, the LazyObject will be built
206-
and returned.
207-
208-
Args:
209-
value (Any): The value to be built.
210-
211-
Returns:
212-
Any: The built value.
213-
"""
214-
if isinstance(value, (LazyAttr, LazyObject)) and not self.lazy:
215-
value = value.build()
216-
return value
217-
218272
def values(self):
219273
"""Yield the values of the dictionary.
220274
@@ -288,28 +342,29 @@ def __eq__(self, other):
288342
return False
289343

290344
def _to_lazy_dict(self):
291-
"""Convert the ConfigDict to a normal dictionary recursively, and keep
292-
the ``LazyObject`` or ``LazyAttr`` object not built."""
345+
# NOTE: Keep this function for backward compatibility.
293346

294-
def _to_dict(data):
295-
if isinstance(data, ConfigDict):
296-
return {
297-
key: _to_dict(value)
298-
for key, value in Dict.items(data)
299-
}
300-
elif isinstance(data, dict):
301-
return {key: _to_dict(value) for key, value in data.items()}
302-
elif isinstance(data, (list, tuple)):
303-
return type(data)(_to_dict(item) for item in data)
304-
else:
305-
return data
306-
307-
return _to_dict(self)
347+
return self.to_builtin(keep_lazy=True)
308348

309349
def to_dict(self):
310-
"""Convert the ConfigDict to a normal dictionary recursively, and
311-
convert the ``LazyObject`` or ``LazyAttr`` to string."""
312-
return _lazy2string(self, dict_type=dict)
350+
# NOTE: Keep this function for backward compatibility.
351+
return self.to_builtin()
352+
353+
354+
class ConfigList(LazyContainerMixin, list): # type: ignore
355+
356+
def pop(self, idx):
357+
return self.build_lazy(super().pop(idx))
358+
359+
360+
class ConfigTuple(LazyContainerMixin, tuple): # type: ignore
361+
...
362+
363+
364+
class ConfigSet(LazyContainerMixin, set): # type: ignore
365+
366+
def pop(self, idx):
367+
return self.build_lazy(super().pop(idx))
313368

314369

315370
def add_args(parser: ArgumentParser,
@@ -479,27 +534,30 @@ def fromfile(filename: Union[str, Path],
479534
env_variables=env_variables,
480535
)
481536
else:
482-
# Enable lazy import when parsing the config.
483-
# Using try-except to make sure ``ConfigDict.lazy`` will be reset
484-
# to False. See more details about lazy in the docstring of
485-
# ConfigDict
486-
ConfigDict.lazy = True
487-
try:
537+
with Config._lazy_context():
488538
cfg_dict, imported_names = Config._parse_lazy_import(filename)
489-
except Exception as e:
490-
raise e
491-
finally:
492-
# disable lazy import to get the real type. See more details
493-
# about lazy in the docstring of ConfigDict
494-
ConfigDict.lazy = False
495-
496539
cfg = Config(
497540
cfg_dict,
498541
filename=filename,
499542
format_python_code=format_python_code)
500543
object.__setattr__(cfg, '_imported_names', imported_names)
501544
return cfg
502545

546+
@staticmethod
547+
@contextmanager
548+
def _lazy_context():
549+
ConfigDict.lazy = True
550+
ConfigSet.lazy = True
551+
ConfigList.lazy = True
552+
ConfigTuple.lazy = True
553+
554+
yield
555+
556+
ConfigDict.lazy = False
557+
ConfigSet.lazy = False
558+
ConfigList.lazy = False
559+
ConfigTuple.lazy = False
560+
503561
@staticmethod
504562
def fromstring(cfg_str: str, file_format: str) -> 'Config':
505563
"""Build a Config instance from config text.
@@ -1110,12 +1168,12 @@ def _parse_lazy_import(filename: str) -> Tuple[ConfigDict, set]:
11101168
continue
11111169
ret[key] = value
11121170
# convert dict to ConfigDict
1113-
cfg_dict = Config._dict_to_config_dict_lazy(ret)
1171+
cfg_dict = Config._to_lazy_container(ret)
11141172

11151173
return cfg_dict, imported_names
11161174

11171175
@staticmethod
1118-
def _dict_to_config_dict_lazy(cfg: dict):
1176+
def _to_lazy_container(cfg: dict):
11191177
"""Recursively converts ``dict`` to :obj:`ConfigDict`. The only
11201178
difference between ``_dict_to_config_dict_lazy`` and
11211179
``_dict_to_config_dict_lazy`` is that the former one does not consider
@@ -1131,11 +1189,15 @@ def _dict_to_config_dict_lazy(cfg: dict):
11311189
if isinstance(cfg, dict):
11321190
cfg_dict = ConfigDict()
11331191
for key, value in cfg.items():
1134-
cfg_dict[key] = Config._dict_to_config_dict_lazy(value)
1192+
cfg_dict[key] = Config._to_lazy_container(value)
11351193
return cfg_dict
1136-
if isinstance(cfg, (tuple, list)):
1137-
return type(cfg)(
1138-
Config._dict_to_config_dict_lazy(_cfg) for _cfg in cfg)
1194+
if isinstance(cfg, list):
1195+
return ConfigList(Config._to_lazy_container(_cfg) for _cfg in cfg)
1196+
if isinstance(cfg, tuple):
1197+
return ConfigTuple(Config._to_lazy_container(_cfg) for _cfg in cfg)
1198+
if isinstance(cfg, set):
1199+
return ConfigSet(Config._to_lazy_container(_cfg) for _cfg in cfg)
1200+
11391201
return cfg
11401202

11411203
@staticmethod

tests/data/config/lazy_module_config/toy_model.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
2+
from torch.distributed.fsdp.wrap import (size_based_auto_wrap_policy,
3+
transformer_auto_wrap_policy)
4+
5+
from mmengine._strategy import ColossalAIStrategy
26
from mmengine.config import read_base
37
from mmengine.dataset import DefaultSampler
48
from mmengine.hooks import EMAHook
@@ -46,4 +50,7 @@
4650
priority=49)
4751
]
4852

53+
# illegal model wrapper config, just for unit test.
54+
strategy = dict(type=ColossalAIStrategy, model_wrapper=dict(
55+
auto_wrap_policy=(size_based_auto_wrap_policy, transformer_auto_wrap_policy)))
4956
runner_type = FlexibleRunner

tests/test_config/test_config.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import pytest
1616

1717
import mmengine
18-
from mmengine import Config, ConfigDict, DictAction
18+
from mmengine import Config, ConfigDict, ConfigList, DictAction
1919
from mmengine.config.lazy import LazyObject
2020
from mmengine.fileio import dump, load
2121
from mmengine.registry import MODELS, DefaultScope, Registry
@@ -1209,3 +1209,44 @@ def _recursive_check_lazy(self, cfg, expr):
12091209
[self._recursive_check_lazy(value, expr) for value in cfg]
12101210
else:
12111211
self.assertTrue(expr(cfg))
1212+
1213+
1214+
class TestConfigList(TestCase):
1215+
1216+
def test_getitem(self):
1217+
cfg_list = ConfigList([
1218+
1, 2,
1219+
ConfigDict(type=LazyObject('mmengine')),
1220+
LazyObject('mmengine')
1221+
])
1222+
self.assertIs(cfg_list[2]['type'], mmengine)
1223+
self.assertIs(cfg_list[3], mmengine)
1224+
1225+
def test_star(self):
1226+
1227+
def check_star(a, b, c):
1228+
self.assertIs(c, mmengine)
1229+
1230+
cfg_list = ConfigList([1, 2, LazyObject('mmengine')])
1231+
check_star(*cfg_list)
1232+
1233+
def check_for_loop(self):
1234+
cfg_list = ConfigList([LazyObject('mmengine')])
1235+
for i in cfg_list:
1236+
self.assertIs(i, mmengine)
1237+
1238+
def test_copy(self):
1239+
cfg_list = ConfigList([
1240+
1, 2,
1241+
ConfigDict(type=LazyObject('mmengine')),
1242+
LazyObject('mmengine')
1243+
])
1244+
cfg_copy = cfg_list.copy()
1245+
self.assertIsInstance(cfg_copy, ConfigList)
1246+
self.assertEqual(cfg_list, cfg_copy)
1247+
self.assertIs(cfg_list[2], cfg_copy[2])
1248+
1249+
cfg_copy = copy.deepcopy(cfg_list)
1250+
self.assertIsInstance(cfg_copy, ConfigList)
1251+
self.assertEqual(cfg_list, cfg_copy)
1252+
self.assertIsNot(cfg_list[2], cfg_copy[2])

0 commit comments

Comments
 (0)