Skip to content

5821 6303 Optimize MonaiAlgo FL based on BundleWorkflow #6158

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 43 commits into from
Apr 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
42a45e0
Merge pull request #19 from Project-MONAI/master
Nic-Ma Feb 1, 2021
cd16a13
Merge pull request #32 from Project-MONAI/master
Nic-Ma Feb 24, 2021
6f87afd
Merge pull request #180 from Project-MONAI/dev
Nic-Ma Jul 22, 2021
f398298
Merge pull request #214 from Project-MONAI/dev
Nic-Ma Sep 8, 2021
ec463d6
Merge pull request #397 from Project-MONAI/dev
Nic-Ma Apr 4, 2022
ca62306
Merge pull request #429 from Project-MONAI/dev
Nic-Ma Jul 8, 2022
6b63f3e
Merge branch 'Project-MONAI:main' into main
Nic-Ma Jan 11, 2023
394c31b
Merge pull request #460 from Project-MONAI/dev
Nic-Ma Mar 16, 2023
c0efc01
[DLMED] update MonaiAlgoStats
Nic-Ma Mar 16, 2023
e7774cd
Merge branch 'dev' into 5821-optimize-monaialgo
Nic-Ma Mar 17, 2023
86b72be
[DLMED] update MonaiAlgo
Nic-Ma Mar 17, 2023
cba1490
[DLMED] update distributed tests
Nic-Ma Mar 17, 2023
8114e4a
Merge branch 'dev' into 5821-optimize-monaialgo
Nic-Ma Mar 18, 2023
286112e
Merge branch 'dev' into 5821-optimize-monaialgo
Nic-Ma Mar 20, 2023
f5d2a40
Merge branch 'dev' into 5821-optimize-monaialgo
Nic-Ma Mar 31, 2023
475ded4
[DLMED] update MonaiAlgoStats
Nic-Ma Mar 31, 2023
cab8367
[DLMED] change bundle_root and configs args back
Nic-Ma Mar 31, 2023
cd28fe7
[DLMED] update test cases
Nic-Ma Mar 31, 2023
4641daa
[DLMED] fix mypy
Nic-Ma Mar 31, 2023
cbd6c4d
Merge branch 'dev' into 5821-optimize-monaialgo
Nic-Ma Apr 4, 2023
4888a1c
[DLMED] remove FIXME
Nic-Ma Apr 4, 2023
f96bcd4
[DLMED] fix failed CI tests
Nic-Ma Apr 4, 2023
0030419
[DLMED] revert the change in "get_weights"
Nic-Ma Apr 4, 2023
6f8e102
Merge branch 'dev' into 5821-optimize-monaialgo
Nic-Ma Apr 6, 2023
ee284ff
[DLMED] fix 6303
Nic-Ma Apr 6, 2023
7de4ed7
add weight diff check
holgerroth Apr 6, 2023
7a46063
Merge pull request #462 from holgerroth/5821-optimize-monaialgo_diff_…
Nic-Ma Apr 7, 2023
1cc92e2
Merge branch 'dev' into 5821-optimize-monaialgo
Nic-Ma Apr 7, 2023
fab3313
[DLMED] add checkpint disable
Nic-Ma Apr 7, 2023
f131f04
[DLMED] add disable checkpoint loader
Nic-Ma Apr 7, 2023
28f19c5
[DLMED] fix test logging
Nic-Ma Apr 7, 2023
66f090e
[DLMED] fix mypy
Nic-Ma Apr 7, 2023
dd1100c
enhance dist test
holgerroth Apr 7, 2023
fff8702
Merge pull request #463 from holgerroth/5821-optimize-monaialgo_dist_…
Nic-Ma Apr 7, 2023
d068e03
[MONAI] code formatting
monai-bot Apr 9, 2023
0945aab
Merge branch 'dev' into 5821-optimize-monaialgo
Nic-Ma Apr 11, 2023
f626f83
[DLMED] add eval to dist test
Nic-Ma Apr 11, 2023
7bed38b
Merge branch 'dev' into 5821-optimize-monaialgo
Nic-Ma Apr 14, 2023
7a649d2
[DLMED] add multiple rounds test
Nic-Ma Apr 14, 2023
ceba320
Merge branch 'dev' into 5821-optimize-monaialgo
Nic-Ma Apr 18, 2023
8313a74
[DLMED] added "set_device" back
Nic-Ma Apr 18, 2023
19a10d8
[DLMED] optimize tests
Nic-Ma Apr 18, 2023
0da6603
Merge branch 'dev' into 5821-optimize-monaialgo
Nic-Ma Apr 18, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions monai/bundle/properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,11 @@
BundleProperty.REQUIRED: True,
BundlePropertyConfig.ID: "device",
},
"evaluator": {
BundleProperty.DESC: "inference / evaluation workflow engine.",
BundleProperty.REQUIRED: True,
BundlePropertyConfig.ID: "evaluator",
},
"network_def": {
BundleProperty.DESC: "network module for the inference.",
BundleProperty.REQUIRED: True,
Expand Down
12 changes: 9 additions & 3 deletions monai/bundle/workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,18 @@ class BundleWorkflow(ABC):

"""

supported_train_type: tuple = ("train", "training")
supported_infer_type: tuple = ("infer", "inference", "eval", "evaluation")

def __init__(self, workflow: str | None = None):
if workflow is None:
self.properties = None
self.workflow = None
return
if workflow.lower() in ("train", "training"):
if workflow.lower() in self.supported_train_type:
self.properties = TrainProperties
self.workflow = "train"
elif workflow.lower() in ("infer", "inference", "eval", "evaluation"):
elif workflow.lower() in self.supported_infer_type:
self.properties = InferProperties
self.workflow = "infer"
else:
Expand Down Expand Up @@ -215,6 +218,7 @@ def __init__(
else:
settings_ = ConfigParser.load_config_files(tracking)
self.patch_bundle_tracking(parser=self.parser, settings=settings_)
self._is_initialized: bool = False

def initialize(self) -> Any:
"""
Expand All @@ -223,6 +227,7 @@ def initialize(self) -> Any:
"""
# reset the "reference_resolver" buffer at initialization stage
self.parser.parse(reset=True)
self._is_initialized = True
return self._run_expr(id=self.init_id)

def run(self) -> Any:
Expand Down Expand Up @@ -284,7 +289,7 @@ def _get_property(self, name: str, property: dict) -> Any:
property: other information for the target property, defined in `TrainProperties` or `InferProperties`.

"""
if not self.parser.ref_resolver.is_resolved():
if not self._is_initialized:
raise RuntimeError("Please execute 'initialize' before getting any parsed content.")
prop_id = self._get_prop_id(name, property)
return self.parser.get_parsed_content(id=prop_id) if prop_id is not None else None
Expand All @@ -303,6 +308,7 @@ def _set_property(self, name: str, property: dict, value: Any) -> None:
if prop_id is not None:
self.parser[prop_id] = value
# must parse the config again after changing the content
self._is_initialized = False
self.parser.ref_resolver.reset()

def _check_optional_id(self, name: str, property: dict) -> bool:
Expand Down
335 changes: 151 additions & 184 deletions monai/fl/client/monai_algo.py

Large diffs are not rendered by default.

14 changes: 0 additions & 14 deletions monai/fl/utils/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,20 +51,6 @@ class FlStatistics(StrEnum):
FEATURE_NAMES = "feature_names"


class RequiredBundleKeys(StrEnum):
BUNDLE_ROOT = "bundle_root"


class BundleKeys(StrEnum):
TRAINER = "train#trainer"
EVALUATOR = "validate#evaluator"
TRAIN_TRAINER_MAX_EPOCHS = "train#trainer#max_epochs"
VALIDATE_HANDLERS = "validate#handlers"
DATASET_DIR = "dataset_dir"
TRAIN_DATA = "train#dataset#data"
VALID_DATA = "validate#dataset#data"


class FiltersType(StrEnum):
PRE_FILTERS = "pre_filters"
POST_WEIGHT_FILTERS = "post_weight_filters"
Expand Down
2 changes: 1 addition & 1 deletion monai/fl/utils/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __call__(self, data: ExchangeObject, extra: dict | None = None) -> ExchangeO

class SummaryFilter(Filter):
"""
Summary filter to content of ExchangeObject.
Summary filter to show content of ExchangeObject.
"""

def __call__(self, data: ExchangeObject, extra: dict | None = None) -> ExchangeObject:
Expand Down
4 changes: 4 additions & 0 deletions tests/nonconfig_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ def _get_property(self, name, property):
return self._bundle_root
if name == "device":
return self._device
if name == "evaluator":
return self._evaluator
if name == "network_def":
return self._network_def
if name == "inferer":
Expand All @@ -115,6 +117,8 @@ def _set_property(self, name, property, value):
self._bundle_root = value
elif name == "device":
self._device = value
elif name == "evaluator":
self._evaluator = value
elif name == "network_def":
self._network_def = value
elif name == "inferer":
Expand Down
155 changes: 75 additions & 80 deletions tests/test_fl_monai_algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,13 @@

import os
import shutil
import tempfile
import unittest
from copy import deepcopy
from os.path import join as pathjoin

from parameterized import parameterized

from monai.bundle import ConfigParser
from monai.bundle import ConfigParser, ConfigWorkflow
from monai.bundle.utils import DEFAULT_HANDLERS_ID
from monai.fl.client.monai_algo import MonaiAlgo
from monai.fl.utils.constants import ExtraItems
Expand All @@ -28,11 +29,14 @@

_root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__)))
_data_dir = os.path.join(_root_dir, "testing_data")
_logging_file = pathjoin(_data_dir, "logging.conf")

TEST_TRAIN_1 = [
{
"bundle_root": _data_dir,
"config_train_filename": os.path.join(_data_dir, "config_fl_train.json"),
"train_workflow": ConfigWorkflow(
config_file=os.path.join(_data_dir, "config_fl_train.json"), workflow="train", logging_file=_logging_file
),
"config_evaluate_filename": None,
"config_filters_filename": os.path.join(_data_dir, "config_fl_filters.json"),
}
Expand All @@ -48,68 +52,92 @@
TEST_TRAIN_3 = [
{
"bundle_root": _data_dir,
"config_train_filename": [
os.path.join(_data_dir, "config_fl_train.json"),
os.path.join(_data_dir, "config_fl_train.json"),
],
"train_workflow": ConfigWorkflow(
config_file=os.path.join(_data_dir, "config_fl_train.json"), workflow="train", logging_file=_logging_file
),
"config_evaluate_filename": None,
"config_filters_filename": [
os.path.join(_data_dir, "config_fl_filters.json"),
os.path.join(_data_dir, "config_fl_filters.json"),
],
"config_filters_filename": os.path.join(_data_dir, "config_fl_filters.json"),
}
]

# Training case with experiment tracking enabled.
# The single dict is expanded into keyword arguments by the parameterized
# `test_train` case (`MonaiAlgo(**input_params)`).
TEST_TRAIN_4 = [
{
"bundle_root": _data_dir,
# training workflow is passed as a pre-built ConfigWorkflow object rather than a config filename
"train_workflow": ConfigWorkflow(
config_file=os.path.join(_data_dir, "config_fl_train.json"), workflow="train", logging_file=_logging_file
),
# no evaluation config and (below) no filters config for this case
"config_evaluate_filename": None,
# tracking settings; both output paths below live under `_data_dir`, and
# `test_train` asserts they exist after training and then deletes them
"tracking": {
"handlers_id": DEFAULT_HANDLERS_ID,
"configs": {
# presumably where the executed config gets written -- confirm against the bundle tracking code
"execute_config": f"{_data_dir}/config_executed.json",
"trainer": {
"_target_": "MLFlowHandler",
"tracking_uri": path_to_uri(_data_dir) + "/mlflow_override",
# string starting with "$"; NOTE(review): looks like a config expression evaluated by the parser -- confirm
"output_transform": "$monai.handlers.from_engine(['loss'], first=True)",
"close_on_complete": True,
},
},
},
"config_filters_filename": None,
}
]

TEST_EVALUATE_1 = [
{
"bundle_root": _data_dir,
"config_train_filename": None,
"config_evaluate_filename": os.path.join(_data_dir, "config_fl_evaluate.json"),
"eval_workflow": ConfigWorkflow(
config_file=[
os.path.join(_data_dir, "config_fl_train.json"),
os.path.join(_data_dir, "config_fl_evaluate.json"),
],
workflow="train",
logging_file=_logging_file,
),
"config_filters_filename": os.path.join(_data_dir, "config_fl_filters.json"),
}
]
TEST_EVALUATE_2 = [
{
"bundle_root": _data_dir,
"config_train_filename": None,
"config_evaluate_filename": os.path.join(_data_dir, "config_fl_evaluate.json"),
"config_evaluate_filename": [
os.path.join(_data_dir, "config_fl_train.json"),
os.path.join(_data_dir, "config_fl_evaluate.json"),
],
"eval_workflow_name": "training",
"config_filters_filename": None,
}
]
TEST_EVALUATE_3 = [
{
"bundle_root": _data_dir,
"config_train_filename": None,
"config_evaluate_filename": [
os.path.join(_data_dir, "config_fl_evaluate.json"),
os.path.join(_data_dir, "config_fl_evaluate.json"),
],
"config_filters_filename": [
os.path.join(_data_dir, "config_fl_filters.json"),
os.path.join(_data_dir, "config_fl_filters.json"),
],
"eval_workflow": ConfigWorkflow(
config_file=[
os.path.join(_data_dir, "config_fl_train.json"),
os.path.join(_data_dir, "config_fl_evaluate.json"),
],
workflow="train",
logging_file=_logging_file,
),
"config_filters_filename": os.path.join(_data_dir, "config_fl_filters.json"),
}
]

TEST_GET_WEIGHTS_1 = [
{
"bundle_root": _data_dir,
"config_train_filename": os.path.join(_data_dir, "config_fl_train.json"),
"train_workflow": ConfigWorkflow(
config_file=os.path.join(_data_dir, "config_fl_train.json"), workflow="train", logging_file=_logging_file
),
"config_evaluate_filename": None,
"send_weight_diff": False,
"config_filters_filename": os.path.join(_data_dir, "config_fl_filters.json"),
}
]
TEST_GET_WEIGHTS_2 = [
{
"bundle_root": _data_dir,
"config_train_filename": None,
"config_evaluate_filename": None,
"send_weight_diff": False,
"config_filters_filename": os.path.join(_data_dir, "config_fl_filters.json"),
}
]
TEST_GET_WEIGHTS_3 = [
{
"bundle_root": _data_dir,
"config_train_filename": os.path.join(_data_dir, "config_fl_train.json"),
Expand All @@ -118,59 +146,31 @@
"config_filters_filename": os.path.join(_data_dir, "config_fl_filters.json"),
}
]
TEST_GET_WEIGHTS_4 = [
TEST_GET_WEIGHTS_3 = [
{
"bundle_root": _data_dir,
"config_train_filename": [
os.path.join(_data_dir, "config_fl_train.json"),
os.path.join(_data_dir, "config_fl_train.json"),
],
"train_workflow": ConfigWorkflow(
config_file=os.path.join(_data_dir, "config_fl_train.json"), workflow="train", logging_file=_logging_file
),
"config_evaluate_filename": None,
"send_weight_diff": True,
"config_filters_filename": [
os.path.join(_data_dir, "config_fl_filters.json"),
os.path.join(_data_dir, "config_fl_filters.json"),
],
"config_filters_filename": os.path.join(_data_dir, "config_fl_filters.json"),
}
]


@SkipIfNoModule("ignite")
@SkipIfNoModule("mlflow")
class TestFLMonaiAlgo(unittest.TestCase):
@parameterized.expand([TEST_TRAIN_1, TEST_TRAIN_2, TEST_TRAIN_3])
@parameterized.expand([TEST_TRAIN_1, TEST_TRAIN_2, TEST_TRAIN_3, TEST_TRAIN_4])
def test_train(self, input_params):
# get testing data dir and update train config; using the first to define data dir
if isinstance(input_params["config_train_filename"], list):
config_train_filename = [
os.path.join(input_params["bundle_root"], x) for x in input_params["config_train_filename"]
]
else:
config_train_filename = os.path.join(input_params["bundle_root"], input_params["config_train_filename"])

data_dir = tempfile.mkdtemp()
# test experiment management
input_params["tracking"] = {
"handlers_id": DEFAULT_HANDLERS_ID,
"configs": {
"execute_config": f"{data_dir}/config_executed.json",
"trainer": {
"_target_": "MLFlowHandler",
"tracking_uri": path_to_uri(data_dir) + "/mlflow_override",
"output_transform": "$monai.handlers.from_engine(['loss'], first=True)",
"close_on_complete": True,
},
},
}

# initialize algo
algo = MonaiAlgo(**input_params)
algo.initialize(extra={ExtraItems.CLIENT_NAME: "test_fl"})
algo.abort()

# initialize model
parser = ConfigParser()
parser.read_config(config_train_filename)
parser = ConfigParser(config=deepcopy(algo.train_workflow.parser.get()))
parser.parse()
network = parser.get_parsed_content("network")

Expand All @@ -179,27 +179,22 @@ def test_train(self, input_params):
# test train
algo.train(data=data, extra={})
algo.finalize()
self.assertTrue(os.path.exists(f"{data_dir}/mlflow_override"))
self.assertTrue(os.path.exists(f"{data_dir}/config_executed.json"))
shutil.rmtree(data_dir)

# test experiment management
if "execute_config" in algo.train_workflow.parser:
self.assertTrue(os.path.exists(f"{_data_dir}/mlflow_override"))
shutil.rmtree(f"{_data_dir}/mlflow_override")
self.assertTrue(os.path.exists(f"{_data_dir}/config_executed.json"))
os.remove(f"{_data_dir}/config_executed.json")

@parameterized.expand([TEST_EVALUATE_1, TEST_EVALUATE_2, TEST_EVALUATE_3])
def test_evaluate(self, input_params):
# get testing data dir and update train config; using the first to define data dir
if isinstance(input_params["config_evaluate_filename"], list):
config_eval_filename = [
os.path.join(input_params["bundle_root"], x) for x in input_params["config_evaluate_filename"]
]
else:
config_eval_filename = os.path.join(input_params["bundle_root"], input_params["config_evaluate_filename"])

# initialize algo
algo = MonaiAlgo(**input_params)
algo.initialize(extra={ExtraItems.CLIENT_NAME: "test_fl"})

# initialize model
parser = ConfigParser()
parser.read_config(config_eval_filename)
parser = ConfigParser(config=deepcopy(algo.eval_workflow.parser.get()))
parser.parse()
network = parser.get_parsed_content("network")

Expand All @@ -208,7 +203,7 @@ def test_evaluate(self, input_params):
# test evaluate
algo.evaluate(data=data, extra={})

@parameterized.expand([TEST_GET_WEIGHTS_1, TEST_GET_WEIGHTS_2, TEST_GET_WEIGHTS_3, TEST_GET_WEIGHTS_4])
@parameterized.expand([TEST_GET_WEIGHTS_1, TEST_GET_WEIGHTS_2, TEST_GET_WEIGHTS_3])
def test_get_weights(self, input_params):
# initialize algo
algo = MonaiAlgo(**input_params)
Expand Down
Loading