
Commit d8eb68a

Nic-Ma, holgerroth, and monai-bot authored
5821 6303 Optimize MonaiAlgo FL based on BundleWorkflow (#6158)
part of #5821
Fixes #6303

### Description

This PR simplifies the MONAI FL `MonaiAlgo` module to leverage `BundleWorkflow`. The main point is to decouple the bundle read/write logic from the FL module and rely on the predefined required properties.

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

---------

Signed-off-by: Nic Ma <[email protected]>
Signed-off-by: monai-bot <[email protected]>
Co-authored-by: Holger Roth <[email protected]>
Co-authored-by: monai-bot <[email protected]>
1 parent c30a5b9 commit d8eb68a

15 files changed (+361, -424 lines)
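Taken together with the updated tests below, the refactor means an FL client now hands `MonaiAlgo` ready-made `BundleWorkflow` instances instead of raw config filenames. A minimal, hedged sketch of that usage (paths and client name are placeholders; check `monai/fl/client/monai_algo.py` for the authoritative argument list):

```python
from monai.bundle import ConfigWorkflow
from monai.fl.client.monai_algo import MonaiAlgo
from monai.fl.utils.constants import ExtraItems

# a train workflow built from a bundle config, instead of passing raw config filenames
train_workflow = ConfigWorkflow(
    config_file="testing_data/config_fl_train.json",  # placeholder path
    workflow="train",
    logging_file="testing_data/logging.conf",  # placeholder path
)

algo = MonaiAlgo(
    bundle_root="testing_data",  # placeholder bundle root
    train_workflow=train_workflow,
    config_evaluate_filename=None,
    config_filters_filename="testing_data/config_fl_filters.json",
)
algo.initialize(extra={ExtraItems.CLIENT_NAME: "client_1"})
weights = algo.get_weights(extra={})  # ExchangeObject holding the current local weights
```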

monai/bundle/properties.py

Lines changed: 5 additions & 0 deletions
@@ -159,6 +159,11 @@
         BundleProperty.REQUIRED: True,
         BundlePropertyConfig.ID: "device",
     },
+    "evaluator": {
+        BundleProperty.DESC: "inference / evaluation workflow engine.",
+        BundleProperty.REQUIRED: True,
+        BundlePropertyConfig.ID: "evaluator",
+    },
     "network_def": {
         BundleProperty.DESC: "network module for the inference.",
         BundleProperty.REQUIRED: True,
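A hedged sketch of what the new required `evaluator` property enables, assuming `BundleWorkflow` exposes required properties as attributes once `initialize()` has run, and using a hypothetical inference config path: the FL client can fetch the evaluation engine without knowing the bundle's internal config ids.

```python
from monai.bundle import ConfigWorkflow

# hypothetical bundle inference config; it must now define an "evaluator" component
workflow = ConfigWorkflow(config_file="configs/inference.json", workflow="infer")
workflow.initialize()  # parse the config before reading any property
evaluator = workflow.evaluator  # resolved via BundlePropertyConfig.ID == "evaluator"
device = workflow.device  # existing required property, for comparison
```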

monai/bundle/workflows.py

Lines changed: 9 additions & 3 deletions
@@ -44,15 +44,18 @@ class BundleWorkflow(ABC):

     """

+    supported_train_type: tuple = ("train", "training")
+    supported_infer_type: tuple = ("infer", "inference", "eval", "evaluation")
+
     def __init__(self, workflow: str | None = None):
         if workflow is None:
             self.properties = None
             self.workflow = None
             return
-        if workflow.lower() in ("train", "training"):
+        if workflow.lower() in self.supported_train_type:
             self.properties = TrainProperties
             self.workflow = "train"
-        elif workflow.lower() in ("infer", "inference", "eval", "evaluation"):
+        elif workflow.lower() in self.supported_infer_type:
             self.properties = InferProperties
             self.workflow = "infer"
         else:
@@ -215,6 +218,7 @@ def __init__(
         else:
             settings_ = ConfigParser.load_config_files(tracking)
             self.patch_bundle_tracking(parser=self.parser, settings=settings_)
+        self._is_initialized: bool = False

     def initialize(self) -> Any:
         """
@@ -223,6 +227,7 @@ def initialize(self) -> Any:
         """
         # reset the "reference_resolver" buffer at initialization stage
         self.parser.parse(reset=True)
+        self._is_initialized = True
         return self._run_expr(id=self.init_id)

     def run(self) -> Any:
@@ -284,7 +289,7 @@ def _get_property(self, name: str, property: dict) -> Any:
             property: other information for the target property, defined in `TrainProperties` or `InferProperties`.

         """
-        if not self.parser.ref_resolver.is_resolved():
+        if not self._is_initialized:
             raise RuntimeError("Please execute 'initialize' before getting any parsed content.")
         prop_id = self._get_prop_id(name, property)
         return self.parser.get_parsed_content(id=prop_id) if prop_id is not None else None
@@ -303,6 +308,7 @@ def _set_property(self, name: str, property: dict, value: Any) -> None:
         if prop_id is not None:
             self.parser[prop_id] = value
             # must parse the config again after changing the content
+            self._is_initialized = False
             self.parser.ref_resolver.reset()

     def _check_optional_id(self, name: str, property: dict) -> bool:
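The `_is_initialized` flag controls when property access is allowed: reads require a prior `initialize()`, and setting a property invalidates the parsed state. A sketch of the intended semantics, assuming a hypothetical train config and that required properties such as `max_epochs` are exposed as workflow attributes:

```python
from monai.bundle import ConfigWorkflow

# "training" is accepted via the new supported_train_type tuple
w = ConfigWorkflow(config_file="configs/train.json", workflow="training")

try:
    _ = w.max_epochs  # property access before initialize()
except RuntimeError as err:
    print(err)  # "Please execute 'initialize' before getting any parsed content."

w.initialize()  # parses the config and sets _is_initialized = True
w.max_epochs = 2  # _set_property resets the flag, so ...
w.initialize()  # ... re-initialize before reading properties or running
w.run()
```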

monai/fl/client/monai_algo.py

Lines changed: 151 additions & 184 deletions
Large diffs are not rendered by default.

monai/fl/utils/constants.py

Lines changed: 0 additions & 14 deletions
@@ -51,20 +51,6 @@ class FlStatistics(StrEnum):
     FEATURE_NAMES = "feature_names"


-class RequiredBundleKeys(StrEnum):
-    BUNDLE_ROOT = "bundle_root"
-
-
-class BundleKeys(StrEnum):
-    TRAINER = "train#trainer"
-    EVALUATOR = "validate#evaluator"
-    TRAIN_TRAINER_MAX_EPOCHS = "train#trainer#max_epochs"
-    VALIDATE_HANDLERS = "validate#handlers"
-    DATASET_DIR = "dataset_dir"
-    TRAIN_DATA = "train#dataset#data"
-    VALID_DATA = "validate#dataset#data"
-
-
 class FiltersType(StrEnum):
     PRE_FILTERS = "pre_filters"
     POST_WEIGHT_FILTERS = "post_weight_filters"

monai/fl/utils/filters.py

Lines changed: 1 addition & 1 deletion
@@ -38,7 +38,7 @@ def __call__(self, data: ExchangeObject, extra: dict | None = None) -> ExchangeO

 class SummaryFilter(Filter):
     """
-    Summary filter to content of ExchangeObject.
+    Summary filter to show content of ExchangeObject.
     """

     def __call__(self, data: ExchangeObject, extra: dict | None = None) -> ExchangeObject:
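A small usage sketch for the corrected docstring, assuming `ExchangeObject` accepts a `weights` dict as in the MONAI FL utilities and that `SummaryFilter` simply reports the content before passing it through:

```python
import torch

from monai.fl.utils.exchange_object import ExchangeObject
from monai.fl.utils.filters import SummaryFilter

data = ExchangeObject(weights={"layer.weight": torch.zeros(2, 2)})
data = SummaryFilter()(data)  # prints a summary of the exchanged content and returns it
```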

tests/nonconfig_workflow.py

Lines changed: 4 additions & 0 deletions
@@ -99,6 +99,8 @@ def _get_property(self, name, property):
             return self._bundle_root
         if name == "device":
             return self._device
+        if name == "evaluator":
+            return self._evaluator
         if name == "network_def":
             return self._network_def
         if name == "inferer":
@@ -115,6 +117,8 @@ def _set_property(self, name, property, value):
             self._bundle_root = value
         elif name == "device":
             self._device = value
+        elif name == "evaluator":
+            self._evaluator = value
         elif name == "network_def":
             self._network_def = value
         elif name == "inferer":

tests/test_fl_monai_algo.py

Lines changed: 75 additions & 80 deletions
@@ -13,12 +13,13 @@

 import os
 import shutil
-import tempfile
 import unittest
+from copy import deepcopy
+from os.path import join as pathjoin

 from parameterized import parameterized

-from monai.bundle import ConfigParser
+from monai.bundle import ConfigParser, ConfigWorkflow
 from monai.bundle.utils import DEFAULT_HANDLERS_ID
 from monai.fl.client.monai_algo import MonaiAlgo
 from monai.fl.utils.constants import ExtraItems
@@ -28,11 +29,14 @@

 _root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__)))
 _data_dir = os.path.join(_root_dir, "testing_data")
+_logging_file = pathjoin(_data_dir, "logging.conf")

 TEST_TRAIN_1 = [
     {
         "bundle_root": _data_dir,
-        "config_train_filename": os.path.join(_data_dir, "config_fl_train.json"),
+        "train_workflow": ConfigWorkflow(
+            config_file=os.path.join(_data_dir, "config_fl_train.json"), workflow="train", logging_file=_logging_file
+        ),
         "config_evaluate_filename": None,
         "config_filters_filename": os.path.join(_data_dir, "config_fl_filters.json"),
     }
@@ -48,68 +52,92 @@
 TEST_TRAIN_3 = [
     {
         "bundle_root": _data_dir,
-        "config_train_filename": [
-            os.path.join(_data_dir, "config_fl_train.json"),
-            os.path.join(_data_dir, "config_fl_train.json"),
-        ],
+        "train_workflow": ConfigWorkflow(
+            config_file=os.path.join(_data_dir, "config_fl_train.json"), workflow="train", logging_file=_logging_file
+        ),
         "config_evaluate_filename": None,
-        "config_filters_filename": [
-            os.path.join(_data_dir, "config_fl_filters.json"),
-            os.path.join(_data_dir, "config_fl_filters.json"),
-        ],
+        "config_filters_filename": os.path.join(_data_dir, "config_fl_filters.json"),
+    }
+]
+
+TEST_TRAIN_4 = [
+    {
+        "bundle_root": _data_dir,
+        "train_workflow": ConfigWorkflow(
+            config_file=os.path.join(_data_dir, "config_fl_train.json"), workflow="train", logging_file=_logging_file
+        ),
+        "config_evaluate_filename": None,
+        "tracking": {
+            "handlers_id": DEFAULT_HANDLERS_ID,
+            "configs": {
+                "execute_config": f"{_data_dir}/config_executed.json",
+                "trainer": {
+                    "_target_": "MLFlowHandler",
+                    "tracking_uri": path_to_uri(_data_dir) + "/mlflow_override",
+                    "output_transform": "$monai.handlers.from_engine(['loss'], first=True)",
+                    "close_on_complete": True,
+                },
+            },
+        },
+        "config_filters_filename": None,
     }
 ]

 TEST_EVALUATE_1 = [
     {
         "bundle_root": _data_dir,
         "config_train_filename": None,
-        "config_evaluate_filename": os.path.join(_data_dir, "config_fl_evaluate.json"),
+        "eval_workflow": ConfigWorkflow(
+            config_file=[
+                os.path.join(_data_dir, "config_fl_train.json"),
+                os.path.join(_data_dir, "config_fl_evaluate.json"),
+            ],
+            workflow="train",
+            logging_file=_logging_file,
+        ),
         "config_filters_filename": os.path.join(_data_dir, "config_fl_filters.json"),
     }
 ]
 TEST_EVALUATE_2 = [
     {
         "bundle_root": _data_dir,
         "config_train_filename": None,
-        "config_evaluate_filename": os.path.join(_data_dir, "config_fl_evaluate.json"),
+        "config_evaluate_filename": [
+            os.path.join(_data_dir, "config_fl_train.json"),
+            os.path.join(_data_dir, "config_fl_evaluate.json"),
+        ],
+        "eval_workflow_name": "training",
         "config_filters_filename": None,
     }
 ]
 TEST_EVALUATE_3 = [
     {
         "bundle_root": _data_dir,
         "config_train_filename": None,
-        "config_evaluate_filename": [
-            os.path.join(_data_dir, "config_fl_evaluate.json"),
-            os.path.join(_data_dir, "config_fl_evaluate.json"),
-        ],
-        "config_filters_filename": [
-            os.path.join(_data_dir, "config_fl_filters.json"),
-            os.path.join(_data_dir, "config_fl_filters.json"),
-        ],
+        "eval_workflow": ConfigWorkflow(
+            config_file=[
+                os.path.join(_data_dir, "config_fl_train.json"),
+                os.path.join(_data_dir, "config_fl_evaluate.json"),
+            ],
+            workflow="train",
+            logging_file=_logging_file,
+        ),
+        "config_filters_filename": os.path.join(_data_dir, "config_fl_filters.json"),
     }
 ]

 TEST_GET_WEIGHTS_1 = [
     {
         "bundle_root": _data_dir,
-        "config_train_filename": os.path.join(_data_dir, "config_fl_train.json"),
+        "train_workflow": ConfigWorkflow(
+            config_file=os.path.join(_data_dir, "config_fl_train.json"), workflow="train", logging_file=_logging_file
+        ),
         "config_evaluate_filename": None,
         "send_weight_diff": False,
         "config_filters_filename": os.path.join(_data_dir, "config_fl_filters.json"),
     }
 ]
 TEST_GET_WEIGHTS_2 = [
-    {
-        "bundle_root": _data_dir,
-        "config_train_filename": None,
-        "config_evaluate_filename": None,
-        "send_weight_diff": False,
-        "config_filters_filename": os.path.join(_data_dir, "config_fl_filters.json"),
-    }
-]
-TEST_GET_WEIGHTS_3 = [
     {
         "bundle_root": _data_dir,
         "config_train_filename": os.path.join(_data_dir, "config_fl_train.json"),
@@ -118,59 +146,31 @@
         "config_filters_filename": os.path.join(_data_dir, "config_fl_filters.json"),
     }
 ]
-TEST_GET_WEIGHTS_4 = [
+TEST_GET_WEIGHTS_3 = [
     {
         "bundle_root": _data_dir,
-        "config_train_filename": [
-            os.path.join(_data_dir, "config_fl_train.json"),
-            os.path.join(_data_dir, "config_fl_train.json"),
-        ],
+        "train_workflow": ConfigWorkflow(
+            config_file=os.path.join(_data_dir, "config_fl_train.json"), workflow="train", logging_file=_logging_file
+        ),
         "config_evaluate_filename": None,
         "send_weight_diff": True,
-        "config_filters_filename": [
-            os.path.join(_data_dir, "config_fl_filters.json"),
-            os.path.join(_data_dir, "config_fl_filters.json"),
-        ],
+        "config_filters_filename": os.path.join(_data_dir, "config_fl_filters.json"),
     }
 ]


 @SkipIfNoModule("ignite")
 @SkipIfNoModule("mlflow")
 class TestFLMonaiAlgo(unittest.TestCase):
-    @parameterized.expand([TEST_TRAIN_1, TEST_TRAIN_2, TEST_TRAIN_3])
+    @parameterized.expand([TEST_TRAIN_1, TEST_TRAIN_2, TEST_TRAIN_3, TEST_TRAIN_4])
     def test_train(self, input_params):
-        # get testing data dir and update train config; using the first to define data dir
-        if isinstance(input_params["config_train_filename"], list):
-            config_train_filename = [
-                os.path.join(input_params["bundle_root"], x) for x in input_params["config_train_filename"]
-            ]
-        else:
-            config_train_filename = os.path.join(input_params["bundle_root"], input_params["config_train_filename"])
-
-        data_dir = tempfile.mkdtemp()
-        # test experiment management
-        input_params["tracking"] = {
-            "handlers_id": DEFAULT_HANDLERS_ID,
-            "configs": {
-                "execute_config": f"{data_dir}/config_executed.json",
-                "trainer": {
-                    "_target_": "MLFlowHandler",
-                    "tracking_uri": path_to_uri(data_dir) + "/mlflow_override",
-                    "output_transform": "$monai.handlers.from_engine(['loss'], first=True)",
-                    "close_on_complete": True,
-                },
-            },
-        }
-
         # initialize algo
         algo = MonaiAlgo(**input_params)
         algo.initialize(extra={ExtraItems.CLIENT_NAME: "test_fl"})
         algo.abort()

         # initialize model
-        parser = ConfigParser()
-        parser.read_config(config_train_filename)
+        parser = ConfigParser(config=deepcopy(algo.train_workflow.parser.get()))
         parser.parse()
         network = parser.get_parsed_content("network")

@@ -179,27 +179,22 @@ def test_train(self, input_params):
         # test train
         algo.train(data=data, extra={})
         algo.finalize()
-        self.assertTrue(os.path.exists(f"{data_dir}/mlflow_override"))
-        self.assertTrue(os.path.exists(f"{data_dir}/config_executed.json"))
-        shutil.rmtree(data_dir)
+
+        # test experiment management
+        if "execute_config" in algo.train_workflow.parser:
+            self.assertTrue(os.path.exists(f"{_data_dir}/mlflow_override"))
+            shutil.rmtree(f"{_data_dir}/mlflow_override")
+            self.assertTrue(os.path.exists(f"{_data_dir}/config_executed.json"))
+            os.remove(f"{_data_dir}/config_executed.json")

     @parameterized.expand([TEST_EVALUATE_1, TEST_EVALUATE_2, TEST_EVALUATE_3])
     def test_evaluate(self, input_params):
-        # get testing data dir and update train config; using the first to define data dir
-        if isinstance(input_params["config_evaluate_filename"], list):
-            config_eval_filename = [
-                os.path.join(input_params["bundle_root"], x) for x in input_params["config_evaluate_filename"]
-            ]
-        else:
-            config_eval_filename = os.path.join(input_params["bundle_root"], input_params["config_evaluate_filename"])
-
         # initialize algo
         algo = MonaiAlgo(**input_params)
         algo.initialize(extra={ExtraItems.CLIENT_NAME: "test_fl"})

         # initialize model
-        parser = ConfigParser()
-        parser.read_config(config_eval_filename)
+        parser = ConfigParser(config=deepcopy(algo.eval_workflow.parser.get()))
         parser.parse()
         network = parser.get_parsed_content("network")

@@ -208,7 +203,7 @@ def test_evaluate(self, input_params):
         # test evaluate
         algo.evaluate(data=data, extra={})

-    @parameterized.expand([TEST_GET_WEIGHTS_1, TEST_GET_WEIGHTS_2, TEST_GET_WEIGHTS_3, TEST_GET_WEIGHTS_4])
+    @parameterized.expand([TEST_GET_WEIGHTS_1, TEST_GET_WEIGHTS_2, TEST_GET_WEIGHTS_3])
     def test_get_weights(self, input_params):
         # initialize algo
         algo = MonaiAlgo(**input_params)
