Skip to content

Commit d277d8c

Browse files
aivanoufacebook-github-bot
authored andcommitted
Add interpret docs to example component, remove test arg from cv tr… (#255)
Summary: …ainer Pull Request resolved: #255 Reviewed By: kiukchung Differential Revision: D31660518 Pulled By: aivanou fbshipit-source-id: dc4e3c4698545a318875b5b8fe970cf7bb5b2cd4
1 parent 6236614 commit d277d8c

File tree

5 files changed

+43
-30
lines changed

5 files changed

+43
-30
lines changed

torchx/components/interpret.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,6 @@
1616
See the
1717
:ref:`examples_apps/lightning_classy_vision/interpret:Model Interpretability App Example`
1818
and the corresponding
19-
:ref:`Interpret component definition<examples_apps/lightning_classy_vision/component:Trainer Component Examples>`
19+
:ref:`Interpret component definition<examples_apps/lightning_classy_vision/component:Interpreting the Model>`
2020
for an example of how to use Captum with TorchX.
2121
"""

torchx/examples/apps/lightning_classy_vision/component.py

Lines changed: 31 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
#
4747
# torchx run --scheduler local_cwd \
4848
# ./torchx/examples/apps/lightning_classy_vision/component.py:trainer \
49-
# --output_path /tmp
49+
# --output_path /tmp/$USER
5050
#
5151
# Single trainer component code:
5252

@@ -258,44 +258,58 @@ def trainer_dist(
258258

259259

260260
# %%
261-
# Model Interpretability
261+
# Interpreting the Model
262262
# #######################
263-
# TODO(aivanou): add documentation
263+
# Defines a component that interprets the model
264+
#
265+
# Train a single trainer example: :ref:`examples_apps/lightning_classy_vision/component:Single Trainer Component`
266+
# And use the following cmd to try out:
267+
#
268+
# .. code:: bash
269+
#
270+
# torchx run --scheduler local_cwd \
271+
# ./torchx/examples/apps/lightning_classy_vision/component.py:interpret \
272+
# --output_path /tmp/aivanou/interpret --load_path /tmp/$USER/last.ckpt
264273

265274

266275
def interpret(
267-
image: str,
268276
load_path: str,
269-
data_path: str,
270277
output_path: str,
278+
data_path: Optional[str] = None,
279+
image: str = TORCHX_IMAGE,
271280
resource: Optional[str] = None,
272281
) -> torchx.AppDef:
273282
"""Runs the model interpretability app on the model outputted by the training
274283
component.
275284
276285
Args:
277-
image: image to run (e.g. foobar:latest)
278286
load_path: path to load pretrained model from
279-
data_path: path to the data to load
280287
output_path: output path for model checkpoints (e.g. file:///foo/bar)
288+
data_path: path to the data to load
289+
image: image to run (e.g. foobar:latest)
281290
resource: the resources to use
282291
"""
292+
args = [
293+
"-m",
294+
"torchx.examples.apps.lightning_classy_vision.interpret",
295+
"--load_path",
296+
load_path,
297+
"--output_path",
298+
output_path,
299+
]
300+
if data_path:
301+
args += [
302+
"--data_path",
303+
data_path,
304+
]
305+
283306
return torchx.AppDef(
284307
name="cv-interpret",
285308
roles=[
286309
torchx.Role(
287310
name="worker",
288311
entrypoint="python",
289-
args=[
290-
"-m",
291-
"torchx.examples.apps.lightning_classy_vision.interpret",
292-
"--load_path",
293-
load_path,
294-
"--data_path",
295-
data_path,
296-
"--output_path",
297-
output_path,
298-
],
312+
args=args,
299313
image=image,
300314
resource=named_resources[resource]
301315
if resource

torchx/examples/apps/lightning_classy_vision/data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ def download_data(remote_path: str, tmpdir: str) -> str:
148148
return data_path
149149

150150

151-
def create_random_data(output_path: str, num_images: int = 5) -> None:
151+
def create_random_data(output_path: str, num_images: int = 250) -> None:
152152
"""
153153
Fills the given path with randomly generated 64x64 images.
154154
This can be used for quick testing of the workflow of the model.

torchx/examples/apps/lightning_classy_vision/interpret.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,10 @@
3535
from torchx.examples.apps.lightning_classy_vision.data import (
3636
TinyImageNetDataModule,
3737
download_data,
38+
create_random_data,
3839
)
3940
from torchx.examples.apps.lightning_classy_vision.model import TinyImageNetModel
4041

41-
4242
# FIXME: captum must be imported after torch otherwise it causes python to crash
4343
if True:
4444
import numpy as np
@@ -57,8 +57,7 @@ def parse_args(argv: List[str]) -> argparse.Namespace:
5757
parser.add_argument(
5858
"--data_path",
5959
type=str,
60-
help="path to load the training data from",
61-
required=True,
60+
help="path to load the training data from, if not provided, random dataset will be created",
6261
)
6362
parser.add_argument(
6463
"--output_path",
@@ -91,7 +90,12 @@ def main(argv: List[str]) -> None:
9190
model.load_from_checkpoint(checkpoint_path=args.load_path)
9291

9392
# Download and setup the data module
94-
data_path = download_data(args.data_path, tmpdir)
93+
if not args.data_path:
94+
data_path = os.path.join(tmpdir, "data")
95+
os.makedirs(data_path)
96+
create_random_data(data_path)
97+
else:
98+
data_path = download_data(args.data_path, tmpdir)
9599
data = TinyImageNetDataModule(
96100
data_dir=data_path,
97101
batch_size=1,

torchx/examples/apps/lightning_classy_vision/train.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -56,12 +56,7 @@ def parse_args(argv: List[str]) -> argparse.Namespace:
5656
parser.add_argument(
5757
"--batch_size", type=int, default=32, help="batch size to use for training"
5858
)
59-
parser.add_argument("--num_samples", type=int, default=None, help="num_samples")
60-
parser.add_argument(
61-
"--test",
62-
help="Sets to test mode, training on a much smaller set of randomly generated images",
63-
action="store_true",
64-
)
59+
parser.add_argument("--num_samples", type=int, default=10, help="num_samples")
6560
parser.add_argument(
6661
"--data_path",
6762
type=str,
@@ -122,7 +117,7 @@ def main(argv: List[str]) -> None:
122117
data = TinyImageNetDataModule(
123118
data_dir=data_path,
124119
batch_size=args.batch_size,
125-
num_samples=5 if args.test else args.num_samples,
120+
num_samples=args.num_samples,
126121
)
127122

128123
# Setup model checkpointing

0 commit comments

Comments
 (0)