Skip to content

Commit 2e8b5c5

Browse files
authored
Merge pull request #809 from soumyadipghosh/dist-cpp
Distributed example on C++ API (LibTorch)
2 parents 36441a8 + 5baaba8 commit 2e8b5c5

File tree

3 files changed

+234
-0
lines changed

3 files changed

+234
-0
lines changed

cpp/distributed/CMakeLists.txt

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
cmake_minimum_required(VERSION 3.1 FATAL_ERROR)
2+
project(dist-mnist)
3+
4+
find_package(Torch REQUIRED)
5+
6+
find_package(MPI REQUIRED)
7+
8+
include_directories(SYSTEM ${MPI_C_INCLUDE_PATH} ${MPI_CXX_INCLUDE_PATH})
9+
10+
add_executable(dist-mnist dist-mnist.cpp)
11+
target_link_libraries(dist-mnist ${TORCH_LIBRARIES})
12+
target_link_libraries(dist-mnist ${MPI_LIBRARIES})
13+
target_link_libraries(dist-mnist ${CMAKE_PREFIX_PATH}/lib/libc10d.a)
14+
15+
if(MPI_COMPILE_FLAGS)
16+
set_target_properties(dist-mnist PROPERTIES
17+
COMPILE_FLAGS "${MPI_COMPILE_FLAGS}")
18+
endif()
19+
20+
if(MPI_LINK_FLAGS)
21+
set_target_properties(dist-mnist PROPERTIES
22+
LINK_FLAGS "${MPI_LINK_FLAGS}")
23+
endif()

cpp/distributed/README.md

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# Distributed Training on MNIST using PyTorch C++ Frontend (Libtorch)
2+
3+
This folder contains an example of data-parallel training of a convolutional neural network on the MNIST dataset. For parallelization, Message Passing Interface (MPI) is used.
4+
5+
The entire code is contained in dist-mnist.cpp
6+
7+
You can find instructions on how to install MPI [here] (https://www.open-mpi.org/faq/?category=building). This code was tested on Open MPI but it should run on other MPI distributions as well such as MPICH, MVAPICH, etc.
8+
9+
To build the code, run the following commands from the terminal:
10+
11+
```shell
12+
$ cd distributed
13+
$ mkdir build
14+
$ cd build
15+
$ cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch ..
16+
$ make
17+
```
18+
19+
where /path/to/libtorch should be the path to the unzipped LibTorch distribution. Note that the LibTorch from the [PyTorch homepage] ((https://pytorch.org/get-started/locally/) does not include MPI headers and cannot be used for this example. You have to compile LibTorch manually - a set of guidelines is provided [here] (https://gist.github.com/lasagnaphil/3e0099816837318e8e8bcab7edcfd5d9), however this may vary for different systems.
20+
21+
To run the code,
22+
23+
```shell
24+
mpirun -np {NUM-PROCS} ./dist-mnist
25+
```
26+

cpp/distributed/dist-mnist.cpp

Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
#include <c10d/ProcessGroupMPI.hpp>
2+
#include <torch/torch.h>
3+
#include <iostream>
4+
5+
// Define a Convolutional Module
6+
struct Model : torch::nn::Module {
7+
Model()
8+
: conv1(torch::nn::Conv2dOptions(1, 10, 5)),
9+
conv2(torch::nn::Conv2dOptions(10, 20, 5)),
10+
fc1(320, 50),
11+
fc2(50, 10) {
12+
register_module("conv1", conv1);
13+
register_module("conv2", conv2);
14+
register_module("conv2_drop", conv2_drop);
15+
register_module("fc1", fc1);
16+
register_module("fc2", fc2);
17+
}
18+
19+
torch::Tensor forward(torch::Tensor x) {
20+
x = torch::relu(torch::max_pool2d(conv1->forward(x), 2));
21+
x = torch::relu(
22+
torch::max_pool2d(conv2_drop->forward(conv2->forward(x)), 2));
23+
x = x.view({-1, 320});
24+
x = torch::relu(fc1->forward(x));
25+
x = torch::dropout(x, 0.5, is_training());
26+
x = fc2->forward(x);
27+
return torch::log_softmax(x, 1);
28+
}
29+
30+
torch::nn::Conv2d conv1;
31+
torch::nn::Conv2d conv2;
32+
torch::nn::Dropout2d conv2_drop;
33+
torch::nn::Linear fc1;
34+
torch::nn::Linear fc2;
35+
};
36+
37+
void waitWork(
38+
std::shared_ptr<c10d::ProcessGroupMPI> pg,
39+
std::vector<std::shared_ptr<c10d::ProcessGroup::Work>> works) {
40+
for (auto& work : works) {
41+
try {
42+
work->wait();
43+
} catch (const std::exception& ex) {
44+
std::cerr << "Exception received: " << ex.what() << std::endl;
45+
pg->abort();
46+
}
47+
}
48+
}
49+
50+
int main(int argc, char* argv[]) {
51+
// Creating MPI Process Group
52+
auto pg = c10d::ProcessGroupMPI::createProcessGroupMPI();
53+
54+
// Retrieving MPI environment variables
55+
auto numranks = pg->getSize();
56+
auto rank = pg->getRank();
57+
58+
// TRAINING
59+
// Read train dataset
60+
const char* kDataRoot = "../data";
61+
auto train_dataset =
62+
torch::data::datasets::MNIST(kDataRoot)
63+
.map(torch::data::transforms::Normalize<>(0.1307, 0.3081))
64+
.map(torch::data::transforms::Stack<>());
65+
66+
// Distributed Random Sampler
67+
auto data_sampler = torch::data::samplers::DistributedRandomSampler(
68+
train_dataset.size().value(), numranks, rank, false);
69+
70+
auto num_train_samples_per_proc = train_dataset.size().value() / numranks;
71+
72+
// Generate dataloader
73+
auto total_batch_size = 64;
74+
auto batch_size_per_proc =
75+
total_batch_size / numranks; // effective batch size in each processor
76+
auto data_loader = torch::data::make_data_loader(
77+
std::move(train_dataset), data_sampler, batch_size_per_proc);
78+
79+
// setting manual seed
80+
torch::manual_seed(0);
81+
82+
auto model = std::make_shared<Model>();
83+
84+
auto learning_rate = 1e-2;
85+
86+
torch::optim::SGD optimizer(model->parameters(), learning_rate);
87+
88+
// Number of epochs
89+
size_t num_epochs = 10;
90+
91+
for (size_t epoch = 1; epoch <= num_epochs; ++epoch) {
92+
size_t num_correct = 0;
93+
94+
for (auto& batch : *data_loader) {
95+
auto ip = batch.data;
96+
auto op = batch.target.squeeze();
97+
98+
// convert to required formats
99+
ip = ip.to(torch::kF32);
100+
op = op.to(torch::kLong);
101+
102+
// Reset gradients
103+
model->zero_grad();
104+
105+
// Execute forward pass
106+
auto prediction = model->forward(ip);
107+
108+
auto loss = torch::nll_loss(torch::log_softmax(prediction, 1), op);
109+
110+
// Backpropagation
111+
loss.backward();
112+
113+
// Averaging the gradients of the parameters in all the processors
114+
// Note: This may lag behind DistributedDataParallel (DDP) in performance
115+
// since this synchronizes parameters after backward pass while DDP
116+
// overlaps synchronizing parameters and computing gradients in backward
117+
// pass
118+
std::vector<std::shared_ptr<::c10d::ProcessGroup::Work>> works;
119+
for (auto& param : model->named_parameters()) {
120+
std::vector<torch::Tensor> tmp = {param.value().grad()};
121+
auto work = pg->allreduce(tmp);
122+
works.push_back(std::move(work));
123+
}
124+
125+
waitWork(pg, works);
126+
127+
for (auto& param : model->named_parameters()) {
128+
param.value().grad().data() = param.value().grad().data() / numranks;
129+
}
130+
131+
// Update parameters
132+
optimizer.step();
133+
134+
auto guess = prediction.argmax(1);
135+
num_correct += torch::sum(guess.eq_(op)).item<int64_t>();
136+
} // end batch loader
137+
138+
auto accuracy = 100.0 * num_correct / num_train_samples_per_proc;
139+
140+
std::cout << "Accuracy in rank " << rank << " in epoch " << epoch << " - "
141+
<< accuracy << std::endl;
142+
143+
} // end epoch
144+
145+
// TESTING ONLY IN RANK 0
146+
if (rank == 0) {
147+
auto test_dataset =
148+
torch::data::datasets::MNIST(
149+
kDataRoot, torch::data::datasets::MNIST::Mode::kTest)
150+
.map(torch::data::transforms::Normalize<>(0.1307, 0.3081))
151+
.map(torch::data::transforms::Stack<>());
152+
153+
auto num_test_samples = test_dataset.size().value();
154+
auto test_loader = torch::data::make_data_loader(
155+
std::move(test_dataset), num_test_samples);
156+
157+
model->eval(); // enable eval mode to prevent backprop
158+
159+
size_t num_correct = 0;
160+
161+
for (auto& batch : *test_loader) {
162+
auto ip = batch.data;
163+
auto op = batch.target.squeeze();
164+
165+
// convert to required format
166+
ip = ip.to(torch::kF32);
167+
op = op.to(torch::kLong);
168+
169+
auto prediction = model->forward(ip);
170+
171+
auto loss = torch::nll_loss(torch::log_softmax(prediction, 1), op);
172+
173+
std::cout << "Test loss - " << loss.item<float>() << std::endl;
174+
175+
auto guess = prediction.argmax(1);
176+
177+
num_correct += torch::sum(guess.eq_(op)).item<int64_t>();
178+
179+
} // end test loader
180+
181+
std::cout << "Num correct - " << num_correct << std::endl;
182+
std::cout << "Test Accuracy - " << 100.0 * num_correct / num_test_samples
183+
<< std::endl;
184+
} // end rank 0
185+
}

0 commit comments

Comments
 (0)