Added torch.distributed.launch module for easier multi-proc/node distributed job launching #5348

Merged
5 commits merged into pytorch:master on Mar 13, 2018

Conversation

teng-li
Contributor

@teng-li teng-li commented Feb 22, 2018

A helper module to launch multi-process distributed jobs on either a single node or multiple nodes.

$ python -m torch.distributed.launch --help
usage: launch.py [-h] [--num_node NUM_NODE] [--rank_node RANK_NODE]
                 [--nproc_per_node NPROC_PER_NODE] [--master_addr MASTER_ADDR]
                 [--master_port MASTER_PORT] [--dist_backend DIST_BACKEND]
                 training_script ...

PyTorch distributed training launch helper utility that will spawn up multiple
distributed processes

positional arguments:
  training_script       The full path to the single GPU training
                        program/script to be launched in parallel, followed by
                        all the arguments for the training script
  training_script_args

optional arguments:
  -h, --help            show this help message and exit
  --num_node NUM_NODE   The number of nodes to use for distributed training
  --rank_node RANK_NODE
                        The rank of the node for multi-node distributed
                        training
  --nproc_per_node NPROC_PER_NODE
                        The number of processes to launch on each node
  --master_addr MASTER_ADDR
                        Master node (rank 0)'s address; it should be either the
                        IP address or the hostname of node 0. For single-node
                        multi-proc training, --master_addr can simply be
                        127.0.0.1
  --master_port MASTER_PORT
                        Master node (rank 0)'s free port that needs to be used
                        for communication during distributed training

Can be used with:

pytorch/examples#306

For example, single-node multi-process training:

python -m torch.distributed.launch ./main.py -j 0 -a resnet18 --print-freq 1 --batch-size 32 --dist-url 'env://' /datasets01/imagenet_full_size/061417/ --epochs 1 --dist-backend 'nccl'

Multi-node multi-process training would be similar, using:

python -m torch.distributed.launch --num_node=2 --rank_node=0 --nproc_per_node=2 --master_addr=devfair033 ./main.py -j 0 -a resnet18 --print-freq 1 --batch-size 32 --dist-url 'env://' /datasets01/imagenet_full_size/061417/ --epochs 1 --dist-backend 'nccl'
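
For reference, here is a minimal sketch of a training script that this launcher could drive. The script name, model, and data handling below are illustrative and not part of this PR; the only contract assumed is the one described above: the launcher injects --device=X and the script initializes the process group with init_method='env://'.

# minimal_trainer.py (hypothetical example, not part of this PR)
import argparse
import torch
import torch.distributed as dist

parser = argparse.ArgumentParser()
parser.add_argument("--device", type=int, default=0,
                    help="local device index injected by torch.distributed.launch")
args = parser.parse_args()

# The launcher is expected to export MASTER_ADDR, MASTER_PORT, WORLD_SIZE and RANK,
# which is all the env:// init method needs.
dist.init_process_group(backend="nccl", init_method="env://")

torch.cuda.set_device(args.device)
model = torch.nn.Linear(10, 10).cuda(args.device)
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.device])
# ... build a DataLoader with a DistributedSampler and run the usual training loop

Such a script would then be started with, for example, python -m torch.distributed.launch --nproc_per_node=2 minimal_trainer.py.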

Contributor

@apaszke apaszke left a comment

Have you verified that this can also be used like this:

python -m torch.utils.distributed.pytorch_dist_exec ...

In your case you have a development copy of PyTorch and know where all the files are located, but this is not true for users who e.g. download the binaries. They will have to use the command above. It would also be nice to include an example command in the docs, and to make the filename less verbose (e.g. shorten it to torch.distributed.start).

Helper function parsing the command line options
@retval ArgumentParser
"""
parser = ArgumentParser(description="PyTorch Exec is a helper utiliy that "

"training currently only supports the NCCL distributed backend. "
"This utilty helper will require that training script is able to "
"parse --device=X as an argument since it will be injected by this "
"utility. "

"127.0.0.1")
parser.add_argument("--master_port", default=29500, type=int,
help="Master node (rank 0)'s free port that needs to be used for "
"communciation in distributed training")

@soumith
Member

soumith commented Feb 22, 2018

  • What Adam said.
    Additionally, change the name to:
python -m torch.distributed.launch

@teng-li teng-li changed the title Added pytorch_dist_exec utiliy for easier distributed job launching Added torch.distributed.launch module for easier multi-proc/node distributed job launching Feb 22, 2018
@teng-li teng-li force-pushed the pytorch_exec branch 4 times, most recently from 4a1e644 to 1464ca4 Compare February 23, 2018 00:30
python -m torch.distributed.launch --nproc_per_node=NUM_GPUS_YOU_HAVE
--num_node=2 --rank_node=1 --master_addr="192.168.1.1"
--master_port=1234 YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 and
all other arguments of your training script)

module:

torch.distributed.init_process_group(backend='YOUR BACKEND',
"init_method='env://')

"init_method='env://')

(4) In your training program, you are supposed to convert your model to the
DistributedDataParallel module using the following function. Please ensure

device_ids=[arg.device])

(5) For multi-node training, we currently only support nodes with an identical
number of GPUs. In other words, the number of GPUs on each node needs to be the same.

"""

parser.add_argument("--rank_node", type=int, default=0,
help="The rank of the node for multi-node distributed "
"training")
parser.add_argument("--nproc_per_node", type=int, default=-1,

args = parse_args()
num_gpus = torch.cuda.device_count()
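
Continuing from the two lines just quoted, a minimal sketch of a GPU-count fallback might look like this; it is an illustration only, prompted by the -1 default seen in one fragment above (whether the merged code behaves this way is exactly what is debated later in the thread):

# Hypothetical fallback: launch one process per visible GPU when
# --nproc_per_node was not given (the -1 default seen in one fragment above);
# use a single process on CPU-only hosts.
if args.nproc_per_node < 0:
    args.nproc_per_node = num_gpus if num_gpus > 0 else 1

world_size = args.nproc_per_node * args.num_node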

parser.add_argument("--num_node", type=int, default=1,
help="The number of nodes to use for distributed "
"training")
parser.add_argument("--rank_node", type=int, default=0,

help="The rank of the node for multi-node distributed "
"training")
parser.add_argument("--nproc_per_node", type=int, default=1,
help="The number of processes to launch on each node, "

@teng-li
Contributor Author

teng-li commented Mar 6, 2018

@ngimel I deleted it because we would like this tool to work on CPU training as well.

@apaszke, mind taking another look?

@ngimel
Collaborator

ngimel commented Mar 6, 2018

@teng-li but the next line "will default to the number of GPUs on your system if not specified" is not correct, and the single-node example in line 25 won't do what users expect it to do. I don't mind it not being set (though to me it would still feel more natural if the launcher helper took care of using all available GPUs), but it should be reflected in the documentation and examples.

@ezyang
Contributor

ezyang commented Mar 6, 2018

@pytorchbot retest this please

3 similar comments

# spawn the processes
cmd = ["python",
args.training_script,
"--device={}".format(local_rank)] + args.training_script_args

@apaszke apaszke merged commit 37059ba into pytorch:master Mar 13, 2018
1. This utility and multi-process distributed (single-node or
multi-node) GPU training currently achieves the best performance only with
the NCCL distributed backend. Thus the NCCL backend is the recommended backend to
use for GPU training.
