Skip to content

Commit b8661de

Browse files
authored
Disable TB Testing (aws#275)
1 parent c66b509 commit b8661de

File tree

2 files changed

+85
-3
lines changed

2 files changed

+85
-3
lines changed

tests/utils.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# Standard Library
2+
import os
3+
import shutil
4+
from pathlib import Path
5+
6+
# First Party
7+
from smdebug.core.config_constants import (
8+
CONFIG_FILE_PATH_ENV_STR,
9+
DEFAULT_SAGEMAKER_OUTDIR,
10+
DEFAULT_SAGEMAKER_TENSORBOARD_PATH,
11+
TENSORBOARD_CONFIG_FILE_PATH_ENV_STR,
12+
)
13+
from smdebug.core.utils import remove_file_if_exists
14+
15+
16+
class SagemakerSimulator(object):
17+
"""
18+
Creates an environment variable pointing to a JSON config file, and creates the config file.
19+
Used for integration testing with zero-code-change.
20+
21+
If `disable=True`, then we still create the `out_dir` directory, but ignore the config file.
22+
"""
23+
24+
def __init__(
25+
self,
26+
json_config_path="/tmp/zcc_config.json",
27+
tensorboard_dir="/tmp/tensorboard",
28+
training_job_name="sm_job",
29+
json_file_contents="{}",
30+
enable_tb=True,
31+
cleanup=True,
32+
):
33+
self.out_dir = DEFAULT_SAGEMAKER_OUTDIR
34+
self.json_config_path = json_config_path
35+
self.tb_json_config_path = DEFAULT_SAGEMAKER_TENSORBOARD_PATH
36+
self.tensorboard_dir = tensorboard_dir
37+
self.training_job_name = training_job_name
38+
self.json_file_contents = json_file_contents
39+
self.enable_tb = enable_tb
40+
self.cleanup = cleanup
41+
42+
def __enter__(self):
43+
if self.cleanup is True:
44+
shutil.rmtree(self.out_dir, ignore_errors=True)
45+
shutil.rmtree(self.json_config_path, ignore_errors=True)
46+
tb_parent_dir = str(Path(self.tb_json_config_path).parent)
47+
shutil.rmtree(tb_parent_dir, ignore_errors=True)
48+
49+
os.environ[CONFIG_FILE_PATH_ENV_STR] = self.json_config_path
50+
os.environ["TRAINING_JOB_NAME"] = self.training_job_name
51+
with open(self.json_config_path, "w+") as my_file:
52+
# We'll just use the defaults, but the file is expected to exist
53+
my_file.write(self.json_file_contents)
54+
55+
if self.enable_tb is True:
56+
os.environ[TENSORBOARD_CONFIG_FILE_PATH_ENV_STR] = self.tb_json_config_path
57+
os.makedirs(tb_parent_dir, exist_ok=True)
58+
with open(self.tb_json_config_path, "w+") as my_file:
59+
my_file.write(
60+
f"""
61+
{{
62+
"LocalPath": "{self.tensorboard_dir}"
63+
}}
64+
"""
65+
)
66+
67+
return self
68+
69+
def __exit__(self, *args):
70+
# Throws errors when the writers try to close.
71+
# shutil.rmtree(self.out_dir, ignore_errors=True)
72+
if self.cleanup is True:
73+
remove_file_if_exists(self.json_config_path)
74+
remove_file_if_exists(self.tb_json_config_path)
75+
if CONFIG_FILE_PATH_ENV_STR in os.environ:
76+
del os.environ[CONFIG_FILE_PATH_ENV_STR]
77+
if "TRAINING_JOB_NAME" in os.environ:
78+
del os.environ["TRAINING_JOB_NAME"]
79+
if TENSORBOARD_CONFIG_FILE_PATH_ENV_STR in os.environ:
80+
del os.environ[TENSORBOARD_CONFIG_FILE_PATH_ENV_STR]

tests/zero_code_change/test_tensorflow2_integration.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,11 @@
2020
# Third Party
2121
import pytest
2222
import tensorflow.compat.v2 as tf
23+
from tests.utils import SagemakerSimulator
2324

2425
# First Party
2526
import smdebug.tensorflow as smd
2627
from smdebug.core.collection import CollectionKeys
27-
from smdebug.core.utils import SagemakerSimulator
2828

2929

3030
def get_keras_model_v2():
@@ -52,7 +52,8 @@ def helper_test_keras_v2(script_mode: bool = False, eager_mode: bool = True):
5252
tf.keras.backend.clear_session()
5353
if not eager_mode:
5454
tf.compat.v1.disable_eager_execution()
55-
with SagemakerSimulator() as sim:
55+
enable_tb = False if tf.__version__ == "2.0.2" else True
56+
with SagemakerSimulator(enable_tb=enable_tb) as sim:
5657
model = get_keras_model_v2()
5758
(x_train, y_train), (x_test, y_test) = get_keras_data()
5859
x_train, x_test = x_train / 255, x_test / 255
@@ -102,7 +103,8 @@ def helper_test_keras_v2_json_config(
102103
tf.keras.backend.clear_session()
103104
if not eager_mode:
104105
tf.compat.v1.disable_eager_execution()
105-
with SagemakerSimulator(json_file_contents=json_file_contents) as sim:
106+
enable_tb = False if tf.__version__ == "2.0.2" else True
107+
with SagemakerSimulator(json_file_contents=json_file_contents, enable_tb=enable_tb) as sim:
106108
model = get_keras_model_v2()
107109
(x_train, y_train), (x_test, y_test) = get_keras_data()
108110
x_train, x_test = x_train / 255, x_test / 255

0 commit comments

Comments
 (0)