|
| 1 | +#include <c10d/ProcessGroupMPI.hpp> |
| 2 | +#include <torch/torch.h> |
| 3 | +#include <iostream> |
| 4 | + |
| 5 | +// Define a Convolutional Module |
| 6 | +struct Model : torch::nn::Module { |
| 7 | + Model() |
| 8 | + : conv1(torch::nn::Conv2dOptions(1, 10, 5)), |
| 9 | + conv2(torch::nn::Conv2dOptions(10, 20, 5)), |
| 10 | + fc1(320, 50), |
| 11 | + fc2(50, 10) { |
| 12 | + register_module("conv1", conv1); |
| 13 | + register_module("conv2", conv2); |
| 14 | + register_module("conv2_drop", conv2_drop); |
| 15 | + register_module("fc1", fc1); |
| 16 | + register_module("fc2", fc2); |
| 17 | + } |
| 18 | + |
| 19 | + torch::Tensor forward(torch::Tensor x) { |
| 20 | + x = torch::relu(torch::max_pool2d(conv1->forward(x), 2)); |
| 21 | + x = torch::relu( |
| 22 | + torch::max_pool2d(conv2_drop->forward(conv2->forward(x)), 2)); |
| 23 | + x = x.view({-1, 320}); |
| 24 | + x = torch::relu(fc1->forward(x)); |
| 25 | + x = torch::dropout(x, 0.5, is_training()); |
| 26 | + x = fc2->forward(x); |
| 27 | + return torch::log_softmax(x, 1); |
| 28 | + } |
| 29 | + |
| 30 | + torch::nn::Conv2d conv1; |
| 31 | + torch::nn::Conv2d conv2; |
| 32 | + torch::nn::Dropout2d conv2_drop; |
| 33 | + torch::nn::Linear fc1; |
| 34 | + torch::nn::Linear fc2; |
| 35 | +}; |
| 36 | + |
| 37 | +void waitWork( |
| 38 | + std::shared_ptr<c10d::ProcessGroupMPI> pg, |
| 39 | + std::vector<std::shared_ptr<c10d::ProcessGroup::Work>> works) { |
| 40 | + for (auto& work : works) { |
| 41 | + try { |
| 42 | + work->wait(); |
| 43 | + } catch (const std::exception& ex) { |
| 44 | + std::cerr << "Exception received: " << ex.what() << std::endl; |
| 45 | + pg->abort(); |
| 46 | + } |
| 47 | + } |
| 48 | +} |
| 49 | + |
| 50 | +int main(int argc, char* argv[]) { |
| 51 | + // Creating MPI Process Group |
| 52 | + auto pg = c10d::ProcessGroupMPI::createProcessGroupMPI(); |
| 53 | + |
| 54 | + // Retrieving MPI environment variables |
| 55 | + auto numranks = pg->getSize(); |
| 56 | + auto rank = pg->getRank(); |
| 57 | + |
| 58 | + // TRAINING |
| 59 | + // Read train dataset |
| 60 | + const char* kDataRoot = "../data"; |
| 61 | + auto train_dataset = |
| 62 | + torch::data::datasets::MNIST(kDataRoot) |
| 63 | + .map(torch::data::transforms::Normalize<>(0.1307, 0.3081)) |
| 64 | + .map(torch::data::transforms::Stack<>()); |
| 65 | + |
| 66 | + // Distributed Random Sampler |
| 67 | + auto data_sampler = torch::data::samplers::DistributedRandomSampler( |
| 68 | + train_dataset.size().value(), numranks, rank, false); |
| 69 | + |
| 70 | + auto num_train_samples_per_proc = train_dataset.size().value() / numranks; |
| 71 | + |
| 72 | + // Generate dataloader |
| 73 | + auto total_batch_size = 64; |
| 74 | + auto batch_size_per_proc = |
| 75 | + total_batch_size / numranks; // effective batch size in each processor |
| 76 | + auto data_loader = torch::data::make_data_loader( |
| 77 | + std::move(train_dataset), data_sampler, batch_size_per_proc); |
| 78 | + |
| 79 | + // setting manual seed |
| 80 | + torch::manual_seed(0); |
| 81 | + |
| 82 | + auto model = std::make_shared<Model>(); |
| 83 | + |
| 84 | + auto learning_rate = 1e-2; |
| 85 | + |
| 86 | + torch::optim::SGD optimizer(model->parameters(), learning_rate); |
| 87 | + |
| 88 | + // Number of epochs |
| 89 | + size_t num_epochs = 10; |
| 90 | + |
| 91 | + for (size_t epoch = 1; epoch <= num_epochs; ++epoch) { |
| 92 | + size_t num_correct = 0; |
| 93 | + |
| 94 | + for (auto& batch : *data_loader) { |
| 95 | + auto ip = batch.data; |
| 96 | + auto op = batch.target.squeeze(); |
| 97 | + |
| 98 | + // convert to required formats |
| 99 | + ip = ip.to(torch::kF32); |
| 100 | + op = op.to(torch::kLong); |
| 101 | + |
| 102 | + // Reset gradients |
| 103 | + model->zero_grad(); |
| 104 | + |
| 105 | + // Execute forward pass |
| 106 | + auto prediction = model->forward(ip); |
| 107 | + |
| 108 | + auto loss = torch::nll_loss(torch::log_softmax(prediction, 1), op); |
| 109 | + |
| 110 | + // Backpropagation |
| 111 | + loss.backward(); |
| 112 | + |
| 113 | + // Averaging the gradients of the parameters in all the processors |
| 114 | + // Note: This may lag behind DistributedDataParallel (DDP) in performance |
| 115 | + // since this synchronizes parameters after backward pass while DDP |
| 116 | + // overlaps synchronizing parameters and computing gradients in backward |
| 117 | + // pass |
| 118 | + std::vector<std::shared_ptr<::c10d::ProcessGroup::Work>> works; |
| 119 | + for (auto& param : model->named_parameters()) { |
| 120 | + std::vector<torch::Tensor> tmp = {param.value().grad()}; |
| 121 | + auto work = pg->allreduce(tmp); |
| 122 | + works.push_back(std::move(work)); |
| 123 | + } |
| 124 | + |
| 125 | + waitWork(pg, works); |
| 126 | + |
| 127 | + for (auto& param : model->named_parameters()) { |
| 128 | + param.value().grad().data() = param.value().grad().data() / numranks; |
| 129 | + } |
| 130 | + |
| 131 | + // Update parameters |
| 132 | + optimizer.step(); |
| 133 | + |
| 134 | + auto guess = prediction.argmax(1); |
| 135 | + num_correct += torch::sum(guess.eq_(op)).item<int64_t>(); |
| 136 | + } // end batch loader |
| 137 | + |
| 138 | + auto accuracy = 100.0 * num_correct / num_train_samples_per_proc; |
| 139 | + |
| 140 | + std::cout << "Accuracy in rank " << rank << " in epoch " << epoch << " - " |
| 141 | + << accuracy << std::endl; |
| 142 | + |
| 143 | + } // end epoch |
| 144 | + |
| 145 | + // TESTING ONLY IN RANK 0 |
| 146 | + if (rank == 0) { |
| 147 | + auto test_dataset = |
| 148 | + torch::data::datasets::MNIST( |
| 149 | + kDataRoot, torch::data::datasets::MNIST::Mode::kTest) |
| 150 | + .map(torch::data::transforms::Normalize<>(0.1307, 0.3081)) |
| 151 | + .map(torch::data::transforms::Stack<>()); |
| 152 | + |
| 153 | + auto num_test_samples = test_dataset.size().value(); |
| 154 | + auto test_loader = torch::data::make_data_loader( |
| 155 | + std::move(test_dataset), num_test_samples); |
| 156 | + |
| 157 | + model->eval(); // enable eval mode to prevent backprop |
| 158 | + |
| 159 | + size_t num_correct = 0; |
| 160 | + |
| 161 | + for (auto& batch : *test_loader) { |
| 162 | + auto ip = batch.data; |
| 163 | + auto op = batch.target.squeeze(); |
| 164 | + |
| 165 | + // convert to required format |
| 166 | + ip = ip.to(torch::kF32); |
| 167 | + op = op.to(torch::kLong); |
| 168 | + |
| 169 | + auto prediction = model->forward(ip); |
| 170 | + |
| 171 | + auto loss = torch::nll_loss(torch::log_softmax(prediction, 1), op); |
| 172 | + |
| 173 | + std::cout << "Test loss - " << loss.item<float>() << std::endl; |
| 174 | + |
| 175 | + auto guess = prediction.argmax(1); |
| 176 | + |
| 177 | + num_correct += torch::sum(guess.eq_(op)).item<int64_t>(); |
| 178 | + |
| 179 | + } // end test loader |
| 180 | + |
| 181 | + std::cout << "Num correct - " << num_correct << std::endl; |
| 182 | + std::cout << "Test Accuracy - " << 100.0 * num_correct / num_test_samples |
| 183 | + << std::endl; |
| 184 | + } // end rank 0 |
| 185 | +} |
0 commit comments