Skip to content

Commit 1954283

Browse files
authored
Support Saving Tensors in Graph Mode with add_for_mode (aws#353)
1 parent eed6b4c commit 1954283

File tree

2 files changed

+79
-1
lines changed

2 files changed

+79
-1
lines changed

smdebug/tensorflow/keras.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,15 @@ def _prepare_layers(self, mode):
350350
for w in weights:
351351
self._check_and_add_layer_tensor(mode, layer, "weight", w)
352352

353+
def _prepare_non_layer_tensors(self):
    """Index the tensors of every registered collection by tensor name.

    Walks each collection returned by
    ``self.collection_manager.get_collections()`` and, for every tensor
    reference it holds, records that collection in
    ``self.tensor_to_collections`` under the key ``tensor_ref.name``.
    Existing entries are extended rather than replaced, and a collection
    already recorded for a name is not added twice.
    """
    for coll in self.collection_manager.get_collections().values():
        for tensor_ref in coll.get_tensors():
            # setdefault creates the set the first time a name is seen;
            # set.add is a no-op when the collection is already present.
            self.tensor_to_collections.setdefault(tensor_ref.name, set()).add(coll)
361+
353362
def _prepare_tensors_available_post_step(self):
354363
# for gradients, optimizer_variables
355364
custom_collections, _ = self._get_custom_and_default_collections()
@@ -359,7 +368,8 @@ def _prepare_tensors_available_post_step(self):
359368
self.get_collection(name=CollectionKeys.OUTPUTS),
360369
self.get_collection(name=CollectionKeys.INPUTS),
361370
]:
362-
for tensor_ref in coll.get_tensors():
371+
collection_values = coll.get_tensors()
372+
for tensor_ref in collection_values:
363373
if tensor_ref.name not in self.tensor_to_collections:
364374
self.tensor_to_collections[tensor_ref.name] = {coll}
365375
elif coll not in self.tensor_to_collections[tensor_ref.name]:
@@ -729,6 +739,7 @@ def _on_any_batch_begin(self, batch, mode, logs=None):
729739
self._get_exec_function(mode)
730740
):
731741
self._prepare_layers(mode)
742+
self._prepare_non_layer_tensors()
732743
self._prepare_tensors_available_post_step()
733744
self._prepared_tensors[mode] = True
734745
# below should be after tensors are processed,
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
# Third Party
2+
import tensorflow as tf
3+
from tensorflow.keras.layers import Conv2D, Dense, Dropout, Flatten, Input, MaxPooling2D
4+
5+
# First Party
6+
import smdebug.tensorflow as smd
7+
8+
9+
def create_dataset(batch_size=2, shuffle_seed=123):
    """Build batched MNIST train and test datasets.

    Loads MNIST through ``tf.keras.datasets.mnist.load_data``, divides the
    image arrays by 255.0, appends a trailing channel axis and wraps each
    split in a ``tf.data.Dataset``. Only the training split is shuffled.

    Args:
        batch_size: Number of examples per batch for both splits. Defaults
            to 2, the value previously hard-coded here.
        shuffle_seed: Seed passed to ``Dataset.shuffle`` for the training
            split. Defaults to 123, the value previously hard-coded here.

    Returns:
        A ``(train_ds, test_ds)`` tuple of batched datasets yielding
        ``(image, label)`` pairs.
    """
    # Download and load MNIST dataset.
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data("MNIST-data")
    x_train, x_test = x_train / 255.0, x_test / 255.0

    # Add a channels dimension
    x_train = x_train[..., tf.newaxis]
    x_test = x_test[..., tf.newaxis]

    train_ds = (
        tf.data.Dataset.from_tensor_slices((x_train, y_train))
        .shuffle(10000, seed=shuffle_seed)
        .batch(batch_size)
    )
    test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(batch_size)

    return train_ds, test_ds
24+
25+
26+
def test_functional_model(out_dir, tf_eager_mode):
    """Check that a tensor added with ``add_for_mode`` is saved in graph mode.

    Builds a small functional Keras CNN, registers the output of its second
    convolution in a "custom" collection for TRAIN mode, trains for 100
    steps with the hook attached, and asserts that exactly one tensor name
    is reported for that collection.

    Args:
        out_dir: Directory the hook writes to and the trial reads from.
        tf_eager_mode: The test body only runs when this is exactly
            ``False``; for any other value the function returns immediately
            without asserting anything.
    """
    # This scenario only applies to graph mode; do nothing otherwise.
    if tf_eager_mode is not False:
        return
    tf.compat.v1.disable_eager_execution()

    num_classes = 10
    # Only the training split is used by this test.
    train_ds, _ = create_dataset()

    # Input image dimensions
    img_rows, img_cols = 28, 28

    img_inputs = Input(shape=(img_rows, img_cols, 1))
    x = Conv2D(32, kernel_size=(3, 3), activation="relu")(img_inputs)
    # Keep a handle on this intermediate output: it is the tensor registered
    # in the custom collection below.
    x1 = Conv2D(64, (3, 3), activation="relu")(x)
    x = MaxPooling2D(pool_size=(2, 2))(x1)
    x = Dropout(0.25)(x)
    x = Flatten()(x)
    x = Dense(128, activation="relu")(x)
    x = Dropout(0.5)(x)
    out = Dense(num_classes, activation="softmax")(x)

    model = tf.keras.models.Model(inputs=img_inputs, outputs=out)

    smd_callback = smd.KerasHook(
        export_tensorboard=False, out_dir=out_dir, include_collections=["custom"]
    )

    smd_callback.get_collection("custom").add_for_mode([x1], mode=smd.modes.TRAIN)
    smd_callback.save_config = smd.SaveConfig(save_interval=1)
    opt = tf.keras.optimizers.Adadelta(1.0)

    model.compile(
        loss=tf.keras.losses.sparse_categorical_crossentropy,
        optimizer=opt,
        experimental_run_tf_function=False,
    )

    callbacks = [smd_callback]
    model.fit(train_ds, epochs=1, steps_per_epoch=100, callbacks=callbacks)

    trial = smd.create_trial(out_dir)
    assert len(trial.tensor_names(collection="custom")) == 1

0 commit comments

Comments
 (0)