1
1
#include < torch/torch.h>
2
-
2
+ # include < argparse/argparse.hpp >
3
3
#include < cmath>
4
4
#include < cstdio>
5
5
#include < iostream>
@@ -10,9 +10,6 @@ const int64_t kNoiseSize = 100;
10
10
// The batch size for training.
11
11
const int64_t kBatchSize = 64 ;
12
12
13
- // The number of epochs to train.
14
- const int64_t kNumberOfEpochs = 30 ;
15
-
16
13
// Where to find the MNIST dataset.
17
14
const char * kDataFolder = " ./data" ;
18
15
@@ -75,7 +72,43 @@ struct DCGANGeneratorImpl : nn::Module {
75
72
76
73
TORCH_MODULE (DCGANGenerator);
77
74
75
+ nn::Sequential create_discriminator () {
76
+ return nn::Sequential (
77
+ // Layer 1
78
+ nn::Conv2d (nn::Conv2dOptions (1 , 64 , 4 ).stride (2 ).padding (1 ).bias (false )),
79
+ nn::LeakyReLU (nn::LeakyReLUOptions ().negative_slope (0.2 )),
80
+ // Layer 2
81
+ nn::Conv2d (nn::Conv2dOptions (64 , 128 , 4 ).stride (2 ).padding (1 ).bias (false )),
82
+ nn::BatchNorm2d (128 ),
83
+ nn::LeakyReLU (nn::LeakyReLUOptions ().negative_slope (0.2 )),
84
+ // Layer 3
85
+ nn::Conv2d (
86
+ nn::Conv2dOptions (128 , 256 , 4 ).stride (2 ).padding (1 ).bias (false )),
87
+ nn::BatchNorm2d (256 ),
88
+ nn::LeakyReLU (nn::LeakyReLUOptions ().negative_slope (0.2 )),
89
+ // Layer 4
90
+ nn::Conv2d (nn::Conv2dOptions (256 , 1 , 3 ).stride (1 ).padding (0 ).bias (false )),
91
+ nn::Sigmoid ());
92
+ }
93
+
78
94
int main (int argc, const char * argv[]) {
95
+ argparse::ArgumentParser parser (" cpp/dcgan example" );
96
+ parser.add_argument (" --epochs" )
97
+ .help (" Number of epochs to train" )
98
+ .default_value (std::int64_t {30 })
99
+ .scan <' i' , int64_t >();
100
+ try {
101
+ parser.parse_args (argc, argv);
102
+ } catch (const std::exception& err) {
103
+ std::cout << err.what () << std::endl;
104
+ std::cout << parser;
105
+ std::exit (1 );
106
+ }
107
+ // The number of epochs to train, default value is 30.
108
+ const int64_t kNumberOfEpochs = parser.get <int64_t >(" --epochs" );
109
+ std::cout << " Traning with number of epochs: " << kNumberOfEpochs
110
+ << std::endl;
111
+
79
112
torch::manual_seed (1 );
80
113
81
114
// Create the device we pass around based on whether CUDA is available.
@@ -88,33 +121,15 @@ int main(int argc, const char* argv[]) {
88
121
DCGANGenerator generator (kNoiseSize );
89
122
generator->to (device);
90
123
91
- nn::Sequential discriminator (
92
- // Layer 1
93
- nn::Conv2d (
94
- nn::Conv2dOptions (1 , 64 , 4 ).stride (2 ).padding (1 ).bias (false )),
95
- nn::LeakyReLU (nn::LeakyReLUOptions ().negative_slope (0.2 )),
96
- // Layer 2
97
- nn::Conv2d (
98
- nn::Conv2dOptions (64 , 128 , 4 ).stride (2 ).padding (1 ).bias (false )),
99
- nn::BatchNorm2d (128 ),
100
- nn::LeakyReLU (nn::LeakyReLUOptions ().negative_slope (0.2 )),
101
- // Layer 3
102
- nn::Conv2d (
103
- nn::Conv2dOptions (128 , 256 , 4 ).stride (2 ).padding (1 ).bias (false )),
104
- nn::BatchNorm2d (256 ),
105
- nn::LeakyReLU (nn::LeakyReLUOptions ().negative_slope (0.2 )),
106
- // Layer 4
107
- nn::Conv2d (
108
- nn::Conv2dOptions (256 , 1 , 3 ).stride (1 ).padding (0 ).bias (false )),
109
- nn::Sigmoid ());
124
+ nn::Sequential discriminator = create_discriminator ();
110
125
discriminator->to (device);
111
126
112
127
// Assume the MNIST dataset is available under `kDataFolder`;
113
128
auto dataset = torch::data::datasets::MNIST (kDataFolder )
114
129
.map (torch::data::transforms::Normalize<>(0.5 , 0.5 ))
115
130
.map (torch::data::transforms::Stack<>());
116
- const int64_t batches_per_epoch =
117
- std::ceil (dataset.size ().value () / static_cast <double >(kBatchSize ));
131
+ const int64_t batches_per_epoch = static_cast < int64_t >(
132
+ std::ceil (dataset.size ().value () / static_cast <double >(kBatchSize ))) ;
118
133
119
134
auto data_loader = torch::data::make_data_loader (
120
135
std::move (dataset),
@@ -136,7 +151,7 @@ int main(int argc, const char* argv[]) {
136
151
int64_t checkpoint_counter = 1 ;
137
152
for (int64_t epoch = 1 ; epoch <= kNumberOfEpochs ; ++epoch) {
138
153
int64_t batch_index = 0 ;
139
- for (torch::data::Example<>& batch : *data_loader) {
154
+ for (const torch::data::Example<>& batch : *data_loader) {
140
155
// Train discriminator with real images.
141
156
discriminator->zero_grad ();
142
157
torch::Tensor real_images = batch.data .to (device);
0 commit comments