@@ -13,12 +13,13 @@

 import os
 import shutil
-import tempfile
 import unittest
+from copy import deepcopy
+from os.path import join as pathjoin

 from parameterized import parameterized

-from monai.bundle import ConfigParser
+from monai.bundle import ConfigParser, ConfigWorkflow
 from monai.bundle.utils import DEFAULT_HANDLERS_ID
 from monai.fl.client.monai_algo import MonaiAlgo
 from monai.fl.utils.constants import ExtraItems
@@ -28,11 +29,14 @@

 _root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__)))
 _data_dir = os.path.join(_root_dir, "testing_data")
+_logging_file = pathjoin(_data_dir, "logging.conf")

 TEST_TRAIN_1 = [
     {
         "bundle_root": _data_dir,
-        "config_train_filename": os.path.join(_data_dir, "config_fl_train.json"),
+        "train_workflow": ConfigWorkflow(
+            config_file=os.path.join(_data_dir, "config_fl_train.json"), workflow="train", logging_file=_logging_file
+        ),
         "config_evaluate_filename": None,
         "config_filters_filename": os.path.join(_data_dir, "config_fl_filters.json"),
     }
@@ -48,68 +52,92 @@
 TEST_TRAIN_3 = [
     {
         "bundle_root": _data_dir,
-        "config_train_filename": [
-            os.path.join(_data_dir, "config_fl_train.json"),
-            os.path.join(_data_dir, "config_fl_train.json"),
-        ],
+        "train_workflow": ConfigWorkflow(
+            config_file=os.path.join(_data_dir, "config_fl_train.json"), workflow="train", logging_file=_logging_file
+        ),
         "config_evaluate_filename": None,
-        "config_filters_filename": [
-            os.path.join(_data_dir, "config_fl_filters.json"),
-            os.path.join(_data_dir, "config_fl_filters.json"),
-        ],
+        "config_filters_filename": os.path.join(_data_dir, "config_fl_filters.json"),
+    }
+]
+
+TEST_TRAIN_4 = [
+    {
+        "bundle_root": _data_dir,
+        "train_workflow": ConfigWorkflow(
+            config_file=os.path.join(_data_dir, "config_fl_train.json"), workflow="train", logging_file=_logging_file
+        ),
+        "config_evaluate_filename": None,
+        "tracking": {
+            "handlers_id": DEFAULT_HANDLERS_ID,
+            "configs": {
+                "execute_config": f"{_data_dir}/config_executed.json",
+                "trainer": {
+                    "_target_": "MLFlowHandler",
+                    "tracking_uri": path_to_uri(_data_dir) + "/mlflow_override",
+                    "output_transform": "$monai.handlers.from_engine(['loss'], first=True)",
+                    "close_on_complete": True,
+                },
+            },
+        },
+        "config_filters_filename": None,
     }
 ]

 TEST_EVALUATE_1 = [
     {
         "bundle_root": _data_dir,
         "config_train_filename": None,
-        "config_evaluate_filename": os.path.join(_data_dir, "config_fl_evaluate.json"),
+        "eval_workflow": ConfigWorkflow(
+            config_file=[
+                os.path.join(_data_dir, "config_fl_train.json"),
+                os.path.join(_data_dir, "config_fl_evaluate.json"),
+            ],
+            workflow="train",
+            logging_file=_logging_file,
+        ),
         "config_filters_filename": os.path.join(_data_dir, "config_fl_filters.json"),
     }
 ]
 TEST_EVALUATE_2 = [
     {
         "bundle_root": _data_dir,
         "config_train_filename": None,
-        "config_evaluate_filename": os.path.join(_data_dir, "config_fl_evaluate.json"),
+        "config_evaluate_filename": [
+            os.path.join(_data_dir, "config_fl_train.json"),
+            os.path.join(_data_dir, "config_fl_evaluate.json"),
+        ],
+        "eval_workflow_name": "training",
         "config_filters_filename": None,
     }
 ]
 TEST_EVALUATE_3 = [
     {
         "bundle_root": _data_dir,
         "config_train_filename": None,
-        "config_evaluate_filename": [
-            os.path.join(_data_dir, "config_fl_evaluate.json"),
-            os.path.join(_data_dir, "config_fl_evaluate.json"),
-        ],
-        "config_filters_filename": [
-            os.path.join(_data_dir, "config_fl_filters.json"),
-            os.path.join(_data_dir, "config_fl_filters.json"),
-        ],
+        "eval_workflow": ConfigWorkflow(
+            config_file=[
+                os.path.join(_data_dir, "config_fl_train.json"),
+                os.path.join(_data_dir, "config_fl_evaluate.json"),
+            ],
+            workflow="train",
+            logging_file=_logging_file,
+        ),
+        "config_filters_filename": os.path.join(_data_dir, "config_fl_filters.json"),
     }
 ]

 TEST_GET_WEIGHTS_1 = [
     {
         "bundle_root": _data_dir,
-        "config_train_filename": os.path.join(_data_dir, "config_fl_train.json"),
+        "train_workflow": ConfigWorkflow(
+            config_file=os.path.join(_data_dir, "config_fl_train.json"), workflow="train", logging_file=_logging_file
+        ),
         "config_evaluate_filename": None,
         "send_weight_diff": False,
         "config_filters_filename": os.path.join(_data_dir, "config_fl_filters.json"),
     }
 ]
 TEST_GET_WEIGHTS_2 = [
-    {
-        "bundle_root": _data_dir,
-        "config_train_filename": None,
-        "config_evaluate_filename": None,
-        "send_weight_diff": False,
-        "config_filters_filename": os.path.join(_data_dir, "config_fl_filters.json"),
-    }
-]
-TEST_GET_WEIGHTS_3 = [
     {
         "bundle_root": _data_dir,
         "config_train_filename": os.path.join(_data_dir, "config_fl_train.json"),
@@ -118,59 +146,31 @@
         "config_filters_filename": os.path.join(_data_dir, "config_fl_filters.json"),
     }
 ]
-TEST_GET_WEIGHTS_4 = [
+TEST_GET_WEIGHTS_3 = [
     {
         "bundle_root": _data_dir,
-        "config_train_filename": [
-            os.path.join(_data_dir, "config_fl_train.json"),
-            os.path.join(_data_dir, "config_fl_train.json"),
-        ],
+        "train_workflow": ConfigWorkflow(
+            config_file=os.path.join(_data_dir, "config_fl_train.json"), workflow="train", logging_file=_logging_file
+        ),
         "config_evaluate_filename": None,
         "send_weight_diff": True,
-        "config_filters_filename": [
-            os.path.join(_data_dir, "config_fl_filters.json"),
-            os.path.join(_data_dir, "config_fl_filters.json"),
-        ],
+        "config_filters_filename": os.path.join(_data_dir, "config_fl_filters.json"),
     }
 ]


 @SkipIfNoModule("ignite")
 @SkipIfNoModule("mlflow")
 class TestFLMonaiAlgo(unittest.TestCase):
-    @parameterized.expand([TEST_TRAIN_1, TEST_TRAIN_2, TEST_TRAIN_3])
+    @parameterized.expand([TEST_TRAIN_1, TEST_TRAIN_2, TEST_TRAIN_3, TEST_TRAIN_4])
     def test_train(self, input_params):
-        # get testing data dir and update train config; using the first to define data dir
-        if isinstance(input_params["config_train_filename"], list):
-            config_train_filename = [
-                os.path.join(input_params["bundle_root"], x) for x in input_params["config_train_filename"]
-            ]
-        else:
-            config_train_filename = os.path.join(input_params["bundle_root"], input_params["config_train_filename"])
-
-        data_dir = tempfile.mkdtemp()
-        # test experiment management
-        input_params["tracking"] = {
-            "handlers_id": DEFAULT_HANDLERS_ID,
-            "configs": {
-                "execute_config": f"{data_dir}/config_executed.json",
-                "trainer": {
-                    "_target_": "MLFlowHandler",
-                    "tracking_uri": path_to_uri(data_dir) + "/mlflow_override",
-                    "output_transform": "$monai.handlers.from_engine(['loss'], first=True)",
-                    "close_on_complete": True,
-                },
-            },
-        }
-
         # initialize algo
         algo = MonaiAlgo(**input_params)
         algo.initialize(extra={ExtraItems.CLIENT_NAME: "test_fl"})
         algo.abort()

         # initialize model
-        parser = ConfigParser()
-        parser.read_config(config_train_filename)
+        parser = ConfigParser(config=deepcopy(algo.train_workflow.parser.get()))
         parser.parse()
         network = parser.get_parsed_content("network")

@@ -179,27 +179,22 @@ def test_train(self, input_params):
         # test train
         algo.train(data=data, extra={})
         algo.finalize()
-        self.assertTrue(os.path.exists(f"{data_dir}/mlflow_override"))
-        self.assertTrue(os.path.exists(f"{data_dir}/config_executed.json"))
-        shutil.rmtree(data_dir)
+
+        # test experiment management
+        if "execute_config" in algo.train_workflow.parser:
+            self.assertTrue(os.path.exists(f"{_data_dir}/mlflow_override"))
+            shutil.rmtree(f"{_data_dir}/mlflow_override")
+            self.assertTrue(os.path.exists(f"{_data_dir}/config_executed.json"))
+            os.remove(f"{_data_dir}/config_executed.json")

     @parameterized.expand([TEST_EVALUATE_1, TEST_EVALUATE_2, TEST_EVALUATE_3])
     def test_evaluate(self, input_params):
-        # get testing data dir and update train config; using the first to define data dir
-        if isinstance(input_params["config_evaluate_filename"], list):
-            config_eval_filename = [
-                os.path.join(input_params["bundle_root"], x) for x in input_params["config_evaluate_filename"]
-            ]
-        else:
-            config_eval_filename = os.path.join(input_params["bundle_root"], input_params["config_evaluate_filename"])
-
         # initialize algo
         algo = MonaiAlgo(**input_params)
         algo.initialize(extra={ExtraItems.CLIENT_NAME: "test_fl"})

         # initialize model
-        parser = ConfigParser()
-        parser.read_config(config_eval_filename)
+        parser = ConfigParser(config=deepcopy(algo.eval_workflow.parser.get()))
         parser.parse()
         network = parser.get_parsed_content("network")

@@ -208,7 +203,7 @@ def test_evaluate(self, input_params):
         # test evaluate
         algo.evaluate(data=data, extra={})

-    @parameterized.expand([TEST_GET_WEIGHTS_1, TEST_GET_WEIGHTS_2, TEST_GET_WEIGHTS_3, TEST_GET_WEIGHTS_4])
+    @parameterized.expand([TEST_GET_WEIGHTS_1, TEST_GET_WEIGHTS_2, TEST_GET_WEIGHTS_3])
     def test_get_weights(self, input_params):
         # initialize algo
         algo = MonaiAlgo(**input_params)
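
For orientation, here is a minimal standalone sketch (not taken from the diff) of how the workflow-style arguments exercised by these tests fit together. The client name is illustrative, and the bundle config files under testing_data are assumed to exist as in the test data above.

import os

from monai.bundle import ConfigWorkflow
from monai.fl.client.monai_algo import MonaiAlgo
from monai.fl.utils.constants import ExtraItems

# assumed test assets, matching the paths used in the test cases above
data_dir = os.path.join(os.path.dirname(__file__), "testing_data")

# build a training workflow from a bundle config, as the updated test cases do
train_workflow = ConfigWorkflow(
    config_file=os.path.join(data_dir, "config_fl_train.json"),
    workflow="train",
    logging_file=os.path.join(data_dir, "logging.conf"),
)

# pass the workflow object instead of a config_train_filename path
algo = MonaiAlgo(
    bundle_root=data_dir,
    train_workflow=train_workflow,
    config_evaluate_filename=None,
    config_filters_filename=os.path.join(data_dir, "config_fl_filters.json"),
)
algo.initialize(extra={ExtraItems.CLIENT_NAME: "example_client"})  # client name is illustrative
# algo.train(...), algo.get_weights(...), and algo.finalize() would then be driven by the
# hosting FL framework, as the tests above call them directly.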