
Commit a537659

Improve code readability and make number of epochs a command line argument (#1222)
* Change the cpp/dcgan
* Use an open source argparse implementation
1 parent b88d805 commit a537659

3 files changed: +50 -28 lines

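At its core, the change wires the number of epochs through p-ranav/argparse instead of a compile-time constant. The standalone program below is not part of the commit (the program name epochs-demo is made up); it is a minimal sketch of the same parsing pattern, assuming argparse has been installed as in the workflow step further down.

#include <argparse/argparse.hpp>
#include <cstdint>
#include <iostream>

int main(int argc, const char* argv[]) {
  // Hypothetical demo program, mirroring the --epochs flag added to dcgan.cpp.
  argparse::ArgumentParser parser("epochs-demo");
  parser.add_argument("--epochs")
      .help("Number of epochs to train")
      .default_value(std::int64_t{30})  // used when the flag is absent
      .scan<'i', std::int64_t>();       // parse the argument string as an integer
  try {
    parser.parse_args(argc, argv);
  } catch (const std::exception& err) {
    std::cerr << err.what() << '\n' << parser;  // error plus auto-generated usage text
    return 1;
  }
  std::cout << "epochs = " << parser.get<std::int64_t>("--epochs") << std::endl;
  return 0;
}

Invoked as ./epochs-demo --epochs 5 it prints epochs = 5; with no flag it falls back to the default of 30.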

.github/workflows/main_cpp.yml

Lines changed: 8 additions & 1 deletion
@@ -31,7 +31,14 @@ jobs:
       run: |
         sudo apt -y install libtbb-dev
         sudo apt install libopencv-dev
-
+    - name: Install argparse
+      run: |
+        git clone https://github.com/p-ranav/argparse
+        cd argparse
+        mkdir build
+        cd build
+        cmake -DARGPARSE_BUILD_SAMPLES=off -DARGPARSE_BUILD_TESTS=off ..
+        sudo make install
     # Alternatively, you can install OpenCV from source
     # - name: Install OpenCV from source
     #   run: |
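Note that no CMakeLists.txt is touched by this commit: p-ranav/argparse is header-only, so the sudo make install step above mainly copies its headers (plus CMake config files) into the system prefix, which is presumably all that is needed for the new #include <argparse/argparse.hpp> in dcgan.cpp to resolve without changes to the example's build files.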

cpp/dcgan/dcgan.cpp

Lines changed: 41 additions & 26 deletions
@@ -1,5 +1,5 @@
 #include <torch/torch.h>
-
+#include <argparse/argparse.hpp>
 #include <cmath>
 #include <cstdio>
 #include <iostream>
@@ -10,9 +10,6 @@ const int64_t kNoiseSize = 100;
 // The batch size for training.
 const int64_t kBatchSize = 64;
 
-// The number of epochs to train.
-const int64_t kNumberOfEpochs = 30;
-
 // Where to find the MNIST dataset.
 const char* kDataFolder = "./data";
 
@@ -75,7 +72,43 @@ struct DCGANGeneratorImpl : nn::Module {
 
 TORCH_MODULE(DCGANGenerator);
 
+nn::Sequential create_discriminator() {
+  return nn::Sequential(
+      // Layer 1
+      nn::Conv2d(nn::Conv2dOptions(1, 64, 4).stride(2).padding(1).bias(false)),
+      nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)),
+      // Layer 2
+      nn::Conv2d(nn::Conv2dOptions(64, 128, 4).stride(2).padding(1).bias(false)),
+      nn::BatchNorm2d(128),
+      nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)),
+      // Layer 3
+      nn::Conv2d(
+          nn::Conv2dOptions(128, 256, 4).stride(2).padding(1).bias(false)),
+      nn::BatchNorm2d(256),
+      nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)),
+      // Layer 4
+      nn::Conv2d(nn::Conv2dOptions(256, 1, 3).stride(1).padding(0).bias(false)),
+      nn::Sigmoid());
+}
+
 int main(int argc, const char* argv[]) {
+  argparse::ArgumentParser parser("cpp/dcgan example");
+  parser.add_argument("--epochs")
+      .help("Number of epochs to train")
+      .default_value(std::int64_t{30})
+      .scan<'i', int64_t>();
+  try {
+    parser.parse_args(argc, argv);
+  } catch (const std::exception& err) {
+    std::cout << err.what() << std::endl;
+    std::cout << parser;
+    std::exit(1);
+  }
+  // The number of epochs to train, default value is 30.
+  const int64_t kNumberOfEpochs = parser.get<int64_t>("--epochs");
+  std::cout << "Training with number of epochs: " << kNumberOfEpochs
+            << std::endl;
+
   torch::manual_seed(1);
 
   // Create the device we pass around based on whether CUDA is available.
@@ -88,33 +121,15 @@ int main(int argc, const char* argv[]) {
   DCGANGenerator generator(kNoiseSize);
   generator->to(device);
 
-  nn::Sequential discriminator(
-      // Layer 1
-      nn::Conv2d(
-          nn::Conv2dOptions(1, 64, 4).stride(2).padding(1).bias(false)),
-      nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)),
-      // Layer 2
-      nn::Conv2d(
-          nn::Conv2dOptions(64, 128, 4).stride(2).padding(1).bias(false)),
-      nn::BatchNorm2d(128),
-      nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)),
-      // Layer 3
-      nn::Conv2d(
-          nn::Conv2dOptions(128, 256, 4).stride(2).padding(1).bias(false)),
-      nn::BatchNorm2d(256),
-      nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)),
-      // Layer 4
-      nn::Conv2d(
-          nn::Conv2dOptions(256, 1, 3).stride(1).padding(0).bias(false)),
-      nn::Sigmoid());
+  nn::Sequential discriminator = create_discriminator();
   discriminator->to(device);
 
   // Assume the MNIST dataset is available under `kDataFolder`;
   auto dataset = torch::data::datasets::MNIST(kDataFolder)
                      .map(torch::data::transforms::Normalize<>(0.5, 0.5))
                      .map(torch::data::transforms::Stack<>());
-  const int64_t batches_per_epoch =
-      std::ceil(dataset.size().value() / static_cast<double>(kBatchSize));
+  const int64_t batches_per_epoch = static_cast<int64_t>(
+      std::ceil(dataset.size().value() / static_cast<double>(kBatchSize)));
 
   auto data_loader = torch::data::make_data_loader(
       std::move(dataset),
@@ -136,7 +151,7 @@ int main(int argc, const char* argv[]) {
   int64_t checkpoint_counter = 1;
   for (int64_t epoch = 1; epoch <= kNumberOfEpochs; ++epoch) {
     int64_t batch_index = 0;
-    for (torch::data::Example<>& batch : *data_loader) {
+    for (const torch::data::Example<>& batch : *data_loader) {
       // Train discriminator with real images.
      discriminator->zero_grad();
      torch::Tensor real_images = batch.data.to(device);
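The refactor moves the discriminator definition into create_discriminator() without changing the architecture. As a quick sanity check, the hypothetical snippet below (not part of the commit) reproduces that factory and pushes a dummy MNIST-sized batch through it to confirm the output shape and the Sigmoid range.

#include <torch/torch.h>
#include <cassert>
#include <iostream>

namespace nn = torch::nn;

// Copy of the factory introduced by this commit, repeated here so the sketch
// is self-contained.
nn::Sequential create_discriminator() {
  return nn::Sequential(
      nn::Conv2d(nn::Conv2dOptions(1, 64, 4).stride(2).padding(1).bias(false)),
      nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)),
      nn::Conv2d(nn::Conv2dOptions(64, 128, 4).stride(2).padding(1).bias(false)),
      nn::BatchNorm2d(128),
      nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)),
      nn::Conv2d(
          nn::Conv2dOptions(128, 256, 4).stride(2).padding(1).bias(false)),
      nn::BatchNorm2d(256),
      nn::LeakyReLU(nn::LeakyReLUOptions().negative_slope(0.2)),
      nn::Conv2d(nn::Conv2dOptions(256, 1, 3).stride(1).padding(0).bias(false)),
      nn::Sigmoid());
}

int main() {
  nn::Sequential discriminator = create_discriminator();
  discriminator->eval();  // run BatchNorm with its running statistics
  torch::NoGradGuard no_grad;
  // A fake batch of eight 28x28 grayscale images, the MNIST input size.
  torch::Tensor fake_batch = torch::rand({8, 1, 28, 28});
  torch::Tensor out = discriminator->forward(fake_batch);
  std::cout << out.sizes() << std::endl;  // expected: [8, 1, 1, 1]
  assert(out.min().item<float>() >= 0.0f && out.max().item<float>() <= 1.0f);
  return 0;
}

Spatially, the 28x28 input shrinks 28 -> 14 -> 7 -> 3 -> 1 across the four convolutions, so each image ends up as a single probability, exactly as the inline nn::Sequential did before the refactor.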

run_cpp_examples.sh

Lines changed: 1 addition & 1 deletion
@@ -102,7 +102,7 @@ function dcgan() {
   make
   if [ $? -eq 0 ]; then
     echo "Successfully built $EXAMPLE"
-    ./$EXAMPLE # Run the executable
+    ./$EXAMPLE --epochs 5 # Run the executable with kNumberOfEpochs = 5
     check_run_success $EXAMPLE
   else
     error "Failed to build $EXAMPLE"
