Skip to content

Commit eb0fdef

Browse files
authored
Collections only handle tensor_names now, Cleanup (aws#60)
* make core use only names
* update branches
* try to fix boto errors
* accept yes prompt
* trigger ci
* make set
* trigger CI
* fix test
* trigger ci
* fix branch
* trigger ci
1 parent 8320cee commit eb0fdef

File tree

5 files changed

+47
-61
lines changed

5 files changed

+47
-61
lines changed

config/buildspec.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ phases:
1212
- cd $CODEBUILD_SRC_DIR && chmod +x config/protoc_downloader.sh && ./config/protoc_downloader.sh
1313
- pip install pytest
1414
- pip install wheel
15+
- pip uninstall -y boto3 && pip uninstall -y aiobotocore && pip uninstall -y botocore
1516

1617
pre_build:
1718
commands:

config/configure_branch_for_test.txt

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
tornasole_core:default
2-
tornasole_tf:default
3-
tornasole_mxnet:default
4-
tornasole_rules:default
1+
tornasole_core:redn_fix
2+
tornasole_tf:redn_fix
3+
tornasole_mxnet:redn_fix
4+
tornasole_rules:fix_mode_available
55
tornasole_pytorch:default

tests/core/test_collections.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,12 @@ def test_manager():
2828
cm = CollectionManager()
2929
cm.create_collection('default')
3030
cm.get('default').include('loss')
31+
cm.get('default').add_tensor_name('assaas')
3132
cm.add(Collection('trial1'))
3233
cm.add('trial2')
3334
cm.get('trial2').include('total_loss')
3435
assert len(cm.collections) == 3
3536
assert cm.get('default') == cm.collections['default']
3637
assert 'loss' in cm.get('default').include_regex
38+
assert len(cm.get('default').get_tensor_names()) > 0
3739
assert 'total_loss' in cm.collections['trial2'].include_regex

tornasole_core/collection.py

Lines changed: 16 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -42,19 +42,18 @@ def __init__(self, name, include_regex=None,
4242
self.reduction_config = reduction_config
4343
self.save_config = save_config
4444

45-
self.tensor_names = []
46-
self.reduction_tensor_names = []
47-
48-
# these two are internal fields only, used by TF
49-
self.tensors = []
50-
self.reduction_tensors = []
45+
self.tensor_names = set()
46+
self.reduction_tensor_names = set()
5147

5248
def get_include_regex(self):
5349
return self.include_regex
5450

5551
def get_tensor_names(self):
5652
return self.tensor_names
5753

54+
def get_reduction_tensor_names(self):
55+
return self.reduction_tensor_names
56+
5857
def include(self, t):
5958
if isinstance(t, list):
6059
for i in t:
@@ -80,20 +79,17 @@ def set_save_config(self, save_cfg):
8079
raise TypeError('Can only take an instance of SaveConfig')
8180
self.save_config = save_cfg
8281

83-
def add_tensor(self, t):
84-
if t.name not in self.tensor_names:
85-
self.tensor_names.append(t.name)
86-
self.tensors.append(t)
82+
def add_tensor_name(self, tname):
83+
if tname not in self.tensor_names:
84+
self.tensor_names.add(tname)
8785

88-
def remove_tensor(self, t):
89-
if t.name in self.tensor_names:
90-
self.tensor_names.remove(t.name)
91-
if t in self.tensors:
92-
self.tensors.remove(t)
86+
def remove_tensor_name(self, tname):
87+
if tname in self.tensor_names:
88+
self.tensor_names.remove(tname)
9389

94-
def add_reduction_tensor(self, s):
95-
self.reduction_tensor_names.append(s.name)
96-
self.reduction_tensors.append(s)
90+
def add_reduction_tensor_name(self, sname):
91+
if sname not in self.reduction_tensor_names:
92+
self.reduction_tensor_names.add(sname)
9793

9894
def export(self):
9995
# v0 export
@@ -135,8 +131,8 @@ def load(s):
135131
list_separator = ','
136132
name = parts[1]
137133
include = [x for x in parts[2].split(list_separator) if x]
138-
tensor_names = [x for x in parts[3].split(list_separator) if x]
139-
reduction_tensor_names = [x for x in parts[4].split(list_separator) if x]
134+
tensor_names = set([x for x in parts[3].split(list_separator) if x])
135+
reduction_tensor_names = set([x for x in parts[4].split(list_separator) if x])
140136
reduction_config = ReductionConfig.load(parts[5])
141137
save_config = SaveConfig.load(parts[6])
142138
c = Collection(name, include_regex=include,

tornasole_core/save_manager.py

Lines changed: 24 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def __init__(self, collection_manager, include_collections_names,
2424
self.save_states_cache = {}
2525
# todo clear cache for old steps
2626
self.tensor_to_collection = {}
27-
self.when_nan_tensors = {}
27+
self.prepared = False
2828

2929
def prepare(self):
3030
for c_name, c in self.collection_manager.get_collections().items():
@@ -47,14 +47,24 @@ def prepare(self):
4747

4848
if c.reduction_config is None and self.default_reduction_config is not None:
4949
c.reduction_config = self.default_reduction_config
50+
self.prepared = True
5051

5152
def _should_collection_be_saved(self, coll_name):
5253
return coll_name in self.include_collections_names
5354

55+
def _raise_error(self):
56+
raise ValueError('Save Manager is not ready. '
57+
'Please call prepare() method '
58+
'before calling this method.')
59+
5460
def get_all_collections_to_save(self):
61+
if not self.prepared:
62+
self._raise_error()
5563
return self.save_collections
5664

5765
def collections_to_save(self, mode, step):
66+
if not self.prepared:
67+
self._raise_error()
5868
if (mode, step) not in self.save_states_cache:
5969
collection_save_state = {}
6070
for coll in self.save_collections:
@@ -66,21 +76,31 @@ def collections_to_save(self, mode, step):
6676
return self.save_states_cache[(mode, step)]
6777

6878
def get_save_config(self, collection, mode):
79+
if not self.prepared:
80+
self._raise_error()
6981
return self.configs_for_collections[collection.name].get_save_config(mode)
7082

7183
def get_reduction_config(self, collection):
84+
if not self.prepared:
85+
self._raise_error()
7286
return collection.get_reduction_config()
7387

7488
def from_collections(self, tensor_name):
7589
# for tf this will be prepopulated because of prepare_tensors
7690
if not tensor_name in self.tensor_to_collection:
7791
# for mxnet it is computed and then cached
7892
matched_colls = []
79-
for coll in self.save_collections:
80-
if tensor_name in coll.tensor_names:
93+
for coll in self.get_all_collections_to_save():
94+
if tensor_name in coll.tensor_names or \
95+
tensor_name in coll.reduction_tensor_names:
96+
# if being matched as reduction,
97+
# it must be in reduction_tensor_name, not with regex
8198
matched_colls.append(coll)
8299
elif match_inc(tensor_name, coll.get_include_regex()):
83-
coll.tensor_names.append(tensor_name)
100+
if self.get_reduction_config(coll):
101+
coll.add_reduction_tensor_name(tensor_name)
102+
else:
103+
coll.add_tensor_name(tensor_name)
84104
matched_colls.append(coll)
85105
self.tensor_to_collection[tensor_name] = matched_colls
86106
return self.tensor_to_collection[tensor_name]
@@ -100,36 +120,3 @@ def should_save_tensor(self, tensorname, mode, step):
100120
final_ss['step'] = final_ss['step'] or ss['step']
101121
final_ss['when_nan'] = final_ss['when_nan'] or ss['when_nan']
102122
return final_ss
103-
104-
# below are used only by TF
105-
def prepare_tensors(self):
106-
for c_name, c in self.collection_manager.get_collections().items():
107-
if c_name == 'when_nan':
108-
continue
109-
if c not in self.save_collections:
110-
continue
111-
for t in c.tensors + c.reduction_tensors:
112-
self._add_tensor_to_collection(t, c)
113-
114-
def _add_tensor_to_collection(self, t, c):
115-
if t.name not in self.tensor_to_collection:
116-
self.tensor_to_collection[t.name] = [c]
117-
else:
118-
self.tensor_to_collection[t.name].append(c)
119-
120-
def add_when_nan_tensor(self, collection, tensor):
121-
self.configs_for_collections[collection.name].add_when_nan_tensor(tensor)
122-
if tensor.name not in self.when_nan_tensors:
123-
self.when_nan_tensors[tensor.name] = []
124-
self.when_nan_tensors[tensor.name].append(collection)
125-
self._add_tensor_to_collection(tensor, collection)
126-
127-
if 'when_nan' not in self.collection_manager.collections:
128-
self.collection_manager.create_collection('when_nan')
129-
self.collection_manager.get('when_nan').add_tensor(tensor)
130-
131-
def is_when_nan_tensor(self, tensor_name):
132-
return tensor_name in self.when_nan_tensors
133-
134-
def when_nan_collections(self, tensor_name):
135-
return self.when_nan_tensors[tensor_name]

0 commit comments

Comments
 (0)