# Distributed example on C++ API (LibTorch) #809
Merged

Commits (14 total; the diff below shows changes from 10 of them):
- `4c4e64a` Initial commit on a distributed training example in cpp (soumyadipghosh)
- `e31e42e` removed comments and ran through clang-format (soumyadipghosh)
- `357694d` Added Readme (soumyadipghosh)
- `dc1afe6` remove hard-coded path to dataset (soumyadipghosh)
- `b028f4a` added comment regarding performance comparison with DDP (soumyadipghosh)
- `31fb2fc` fixing a bug in test set path (soumyadipghosh)
- `a0bfb9f` adding a separate file for GPU (soumyadipghosh)
- `5c3c2bd` initial attempt at using ProcessGroupMPI; code doesn't compile at thi… (soumyadipghosh)
- `44618f0` remove the GPU file to keep PR simple! (soumyadipghosh)
- `47a8fd1` fixing some syntax; code still failing (soumyadipghosh)
- `29464df` Link c10d static library and fix some compilation errors
- `4682c47` Divide numranks after allreduce is finished
- `016de29` Merge pull request #1 from lasagnaphil/dist-compile-fix (soumyadipghosh)
- `5baaba8` adding manual libtorch compilation guide (soumyadipghosh)
**CMakeLists.txt**

```cmake
cmake_minimum_required(VERSION 3.1 FATAL_ERROR)
project(dist-mnist)

find_package(Torch REQUIRED)

find_package(MPI REQUIRED)

include_directories(SYSTEM ${MPI_C_INCLUDE_PATH} ${MPI_CXX_INCLUDE_PATH})

add_executable(dist-mnist dist-mnist.cpp)
target_link_libraries(dist-mnist ${TORCH_LIBRARIES})
target_link_libraries(dist-mnist ${MPI_LIBRARIES})

if(MPI_COMPILE_FLAGS)
  set_target_properties(dist-mnist PROPERTIES
    COMPILE_FLAGS "${MPI_COMPILE_FLAGS}")
endif()

if(MPI_LINK_FLAGS)
  set_target_properties(dist-mnist PROPERTIES
    LINK_FLAGS "${MPI_LINK_FLAGS}")
endif()
```
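As an aside, on CMake 3.9 and newer the same MPI setup can be expressed more compactly through the imported `MPI::MPI_CXX` target, which bundles the include paths and compile/link flags that the file above handles manually. A minimal sketch of that alternative (not part of this PR):

```cmake
# Hypothetical alternative: the imported target carries MPI include
# directories and flags, replacing the manual *_FLAGS handling above.
find_package(MPI REQUIRED)
target_link_libraries(dist-mnist MPI::MPI_CXX)
```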
**README.md**
# Distributed Training on MNIST using PyTorch C++ Frontend (LibTorch)

This folder contains an example of data-parallel training of a convolutional neural network on the MNIST dataset. Parallelization is done with the Message Passing Interface (MPI).

The entire code is contained in `dist-mnist.cpp`.

You can find instructions on how to install MPI [here](https://www.open-mpi.org/faq/?category=building). This code was tested with Open MPI, but it should also run on other MPI implementations such as MPICH and MVAPICH.
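For example, on Debian/Ubuntu systems (one possible setup, not the only one), Open MPI is available from the system package manager:

```shell
$ sudo apt-get install openmpi-bin libopenmpi-dev
```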
To build the code, run the following commands from the terminal:

```shell
$ cd distributed
$ mkdir build
$ cd build
$ cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch ..
$ make
```

where `/path/to/libtorch` is the path to the unzipped LibTorch distribution, which you can download from the [PyTorch homepage](https://pytorch.org/get-started/locally/).
To run the code:

```shell
mpirun -np {NUM-PROCS} ./dist-mnist
```
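For example (illustrative values), launching 4 ranks gives each process an effective batch size of 64 / 4 = 16, following the `total_batch_size / numranks` computation in `dist-mnist.cpp`. Note that the program looks for the MNIST data in `../data` relative to the working directory:

```shell
$ mpirun -np 4 ./dist-mnist
```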
**dist-mnist.cpp**

```cpp
#include <c10d/ProcessGroupMPI.hpp>
#include <torch/torch.h>
#include <iostream>

// Define a Convolutional Module
struct Model : torch::nn::Module {
  Model()
      : conv1(torch::nn::Conv2dOptions(1, 10, 5)),
        conv2(torch::nn::Conv2dOptions(10, 20, 5)),
        fc1(320, 50),
        fc2(50, 10) {
    register_module("conv1", conv1);
    register_module("conv2", conv2);
    register_module("conv2_drop", conv2_drop);
    register_module("fc1", fc1);
    register_module("fc2", fc2);
  }

  torch::Tensor forward(torch::Tensor x) {
    x = torch::relu(torch::max_pool2d(conv1->forward(x), 2));
    x = torch::relu(
        torch::max_pool2d(conv2_drop->forward(conv2->forward(x)), 2));
    // Flatten: the two conv + pool stages turn the 28x28 input into
    // 20 channels of 4x4, i.e. 20 * 4 * 4 = 320 features
    x = x.view({-1, 320});
    x = torch::relu(fc1->forward(x));
    x = torch::dropout(x, 0.5, is_training());
    x = fc2->forward(x);
    return torch::log_softmax(x, 1);
  }

  torch::nn::Conv2d conv1;
  torch::nn::Conv2d conv2;
  torch::nn::Dropout2d conv2_drop;
  torch::nn::Linear fc1;
  torch::nn::Linear fc2;
};

// Wait on a list of pending collectives; abort the process group on failure
void waitWork(
    c10::intrusive_ptr<::c10d::ProcessGroupMPI> pg,
    std::vector<c10::intrusive_ptr<c10d::ProcessGroup::Work>> works) {
  for (auto& work : works) {
    try {
      work->wait();
    } catch (const std::exception& ex) {
      std::cerr << "Exception received: " << ex.what() << std::endl;
      pg->abort();
    }
  }
}

int main(int argc, char* argv[]) {
  // Creating MPI Process Group
  auto pg = c10d::ProcessGroupMPI::createProcessGroupMPI();

  // Retrieving MPI environment variables
  auto numranks = pg->getSize();
  auto rank = pg->getRank();

  // TRAINING
  // Read train dataset
  const char* kDataRoot = "../data";
  auto train_dataset =
      torch::data::datasets::MNIST(kDataRoot)
          .map(torch::data::transforms::Normalize<>(0.1307, 0.3081))
          .map(torch::data::transforms::Stack<>());

  // Distributed Random Sampler - each rank draws from its own shard of the
  // training set
  auto data_sampler = torch::data::samplers::DistributedRandomSampler(
      train_dataset.size().value(), numranks, rank, false);

  auto num_train_samples_per_proc = train_dataset.size().value() / numranks;

  // Generate dataloader
  auto total_batch_size = 64;
  auto batch_size_per_proc =
      total_batch_size / numranks; // effective batch size in each processor
  auto data_loader = torch::data::make_data_loader(
      std::move(train_dataset), data_sampler, batch_size_per_proc);

  // setting manual seed
  torch::manual_seed(0);

  auto model = std::make_shared<Model>();

  auto learning_rate = 1e-2;

  torch::optim::SGD optimizer(model->parameters(), learning_rate);

  // Number of epochs
  size_t num_epochs = 10;

  for (size_t epoch = 1; epoch <= num_epochs; ++epoch) {
    size_t num_correct = 0;

    for (auto& batch : *data_loader) {
      auto ip = batch.data;
      auto op = batch.target.squeeze();

      // convert to required formats
      ip = ip.to(torch::kF32);
      op = op.to(torch::kLong);

      // Reset gradients
      model->zero_grad();

      // Execute forward pass
      auto prediction = model->forward(ip);

      // forward() already returns log-probabilities, so nll_loss can be
      // applied directly
      auto loss = torch::nll_loss(prediction, op);

      // Backpropagation
      loss.backward();

      // Averaging the gradients of the parameters in all the processors
      // Note: This may lag behind DistributedDataParallel (DDP) in performance
      // since this synchronizes parameters after backward pass while DDP
      // overlaps synchronizing parameters and computing gradients in backward
      // pass
      std::vector<c10::intrusive_ptr<::c10d::ProcessGroup::Work>> works;
      for (auto& param : model->named_parameters()) {
        // allreduce expects a vector of tensors
        std::vector<torch::Tensor> grads = {param.value().grad()};
        works.push_back(pg->allreduce(grads));
      }

      waitWork(pg, works);

      // Divide by numranks only after the allreduce has finished
      for (auto& param : model->named_parameters()) {
        param.value().grad().data() = param.value().grad().data() / numranks;
      }

      // Update parameters
      optimizer.step();

      auto guess = prediction.argmax(1);
      num_correct += torch::sum(guess.eq_(op)).item<int64_t>();
    } // end batch loader

    auto accuracy = 100.0 * num_correct / num_train_samples_per_proc;

    std::cout << "Accuracy in rank " << rank << " in epoch " << epoch << " - "
              << accuracy << std::endl;

  } // end epoch

  // TESTING ONLY IN RANK 0
  if (rank == 0) {
    auto test_dataset =
        torch::data::datasets::MNIST(
            kDataRoot, torch::data::datasets::MNIST::Mode::kTest)
            .map(torch::data::transforms::Normalize<>(0.1307, 0.3081))
            .map(torch::data::transforms::Stack<>());

    auto num_test_samples = test_dataset.size().value();
    auto test_loader = torch::data::make_data_loader(
        std::move(test_dataset), num_test_samples);

    model->eval(); // switch to eval mode (disables dropout)

    size_t num_correct = 0;

    for (auto& batch : *test_loader) {
      auto ip = batch.data;
      auto op = batch.target.squeeze();

      // convert to required format
      ip = ip.to(torch::kF32);
      op = op.to(torch::kLong);

      auto prediction = model->forward(ip);

      // forward() already returns log-probabilities
      auto loss = torch::nll_loss(prediction, op);

      std::cout << "Test loss - " << loss.item<float>() << std::endl;

      auto guess = prediction.argmax(1);

      num_correct += torch::sum(guess.eq_(op)).item<int64_t>();

    } // end test loader

    std::cout << "Num correct - " << num_correct << std::endl;
    std::cout << "Test Accuracy - " << 100.0 * num_correct / num_test_samples
              << std::endl;
  } // end rank 0
}
```