@@ -24,7 +24,7 @@ def __init__(self, collection_manager, include_collections_names,
24
24
self .save_states_cache = {}
25
25
# todo clear cache for old steps
26
26
self .tensor_to_collection = {}
27
- self .when_nan_tensors = {}
27
+ self .prepared = False
28
28
29
29
def prepare (self ):
30
30
for c_name , c in self .collection_manager .get_collections ().items ():
@@ -47,14 +47,24 @@ def prepare(self):
47
47
48
48
if c .reduction_config is None and self .default_reduction_config is not None :
49
49
c .reduction_config = self .default_reduction_config
50
+ self .prepared = True
50
51
51
52
def _should_collection_be_saved (self , coll_name ):
52
53
return coll_name in self .include_collections_names
53
54
55
+ def _raise_error (self ):
56
+ raise ValueError ('Save Manager is not ready. '
57
+ 'Please call prepare() method '
58
+ 'before calling this method.' )
59
+
54
60
def get_all_collections_to_save (self ):
61
+ if not self .prepared :
62
+ self ._raise_error ()
55
63
return self .save_collections
56
64
57
65
def collections_to_save (self , mode , step ):
66
+ if not self .prepared :
67
+ self ._raise_error ()
58
68
if (mode , step ) not in self .save_states_cache :
59
69
collection_save_state = {}
60
70
for coll in self .save_collections :
@@ -66,21 +76,31 @@ def collections_to_save(self, mode, step):
66
76
return self .save_states_cache [(mode , step )]
67
77
68
78
def get_save_config (self , collection , mode ):
79
+ if not self .prepared :
80
+ self ._raise_error ()
69
81
return self .configs_for_collections [collection .name ].get_save_config (mode )
70
82
71
83
def get_reduction_config (self , collection ):
84
+ if not self .prepared :
85
+ self ._raise_error ()
72
86
return collection .get_reduction_config ()
73
87
74
88
def from_collections (self , tensor_name ):
75
89
# for tf this will be prepopulated because of prepare_tensors
76
90
if not tensor_name in self .tensor_to_collection :
77
91
# for mxnet it is computed and then cached
78
92
matched_colls = []
79
- for coll in self .save_collections :
80
- if tensor_name in coll .tensor_names :
93
+ for coll in self .get_all_collections_to_save ():
94
+ if tensor_name in coll .tensor_names or \
95
+ tensor_name in coll .reduction_tensor_names :
96
+ # if being matched as reduction,
97
+ # it must be in reduction_tensor_name, not with regex
81
98
matched_colls .append (coll )
82
99
elif match_inc (tensor_name , coll .get_include_regex ()):
83
- coll .tensor_names .append (tensor_name )
100
+ if self .get_reduction_config (coll ):
101
+ coll .add_reduction_tensor_name (tensor_name )
102
+ else :
103
+ coll .add_tensor_name (tensor_name )
84
104
matched_colls .append (coll )
85
105
self .tensor_to_collection [tensor_name ] = matched_colls
86
106
return self .tensor_to_collection [tensor_name ]
@@ -100,36 +120,3 @@ def should_save_tensor(self, tensorname, mode, step):
100
120
final_ss ['step' ] = final_ss ['step' ] or ss ['step' ]
101
121
final_ss ['when_nan' ] = final_ss ['when_nan' ] or ss ['when_nan' ]
102
122
return final_ss
103
-
104
- # below are used only by TF
105
- def prepare_tensors (self ):
106
- for c_name , c in self .collection_manager .get_collections ().items ():
107
- if c_name == 'when_nan' :
108
- continue
109
- if c not in self .save_collections :
110
- continue
111
- for t in c .tensors + c .reduction_tensors :
112
- self ._add_tensor_to_collection (t , c )
113
-
114
- def _add_tensor_to_collection (self , t , c ):
115
- if t .name not in self .tensor_to_collection :
116
- self .tensor_to_collection [t .name ] = [c ]
117
- else :
118
- self .tensor_to_collection [t .name ].append (c )
119
-
120
- def add_when_nan_tensor (self , collection , tensor ):
121
- self .configs_for_collections [collection .name ].add_when_nan_tensor (tensor )
122
- if tensor .name not in self .when_nan_tensors :
123
- self .when_nan_tensors [tensor .name ] = []
124
- self .when_nan_tensors [tensor .name ].append (collection )
125
- self ._add_tensor_to_collection (tensor , collection )
126
-
127
- if 'when_nan' not in self .collection_manager .collections :
128
- self .collection_manager .create_collection ('when_nan' )
129
- self .collection_manager .get ('when_nan' ).add_tensor (tensor )
130
-
131
- def is_when_nan_tensor (self , tensor_name ):
132
- return tensor_name in self .when_nan_tensors
133
-
134
- def when_nan_collections (self , tensor_name ):
135
- return self .when_nan_tensors [tensor_name ]
0 commit comments