 import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
-from torchvision import datasets, transforms
+from torchvision import datasets
+from torchvision.transforms import v2 as transforms
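+# the v2 transforms API operates on tensors as well as PIL images and is
+# torchvision's recommended replacement for the original transforms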
 from torch.optim.lr_scheduler import StepLR


 class Net(nn.Module):
     def __init__(self):
-        super(Net, self).__init__()
+        super().__init__()
         self.conv1 = nn.Conv2d(1, 32, 3, 1)
         self.conv2 = nn.Conv2d(32, 64, 3, 1)
         self.dropout1 = nn.Dropout(0.25)
@@ -33,19 +34,42 @@ def forward(self, x):
         return output


-def train(args, model, device, train_loader, optimizer, epoch):
+def train_amp(args, model, device, train_loader, opt, epoch, scaler):
     model.train()
     for batch_idx, (data, target) in enumerate(train_loader):
-        data, target = data.to(device), target.to(device)
-        optimizer.zero_grad()
+        data, target = data.to(device, memory_format=torch.channels_last), target.to(
+            device
+        )
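+        # channels_last (NHWC) inputs match the memory format the model is
+        # converted to in main(), avoiding layout conversions per batch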
+        opt.zero_grad()
+        with torch.autocast(device_type=device.type):
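+            # the forward pass and loss run in reduced precision where it is
+            # numerically safe to do so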
+            output = model(data)
+            loss = F.nll_loss(output, target)
+        scaler.scale(loss).backward()
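+        # scale() guards fp16 gradients against underflow; step() unscales them
+        # and skips the optimizer update if any are inf/NaN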
+        scaler.step(opt)
+        scaler.update()
+        if batch_idx % args.log_interval == 0:
+            print(
+                f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100.0 * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}"
+            )
+            if args.dry_run:
+                break
+
+
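+# plain fp32 training path, used when AMP is disabled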
+def train(args, model, device, train_loader, opt, epoch):
+    model.train()
+    for batch_idx, (data, target) in enumerate(train_loader):
+        data, target = data.to(device, memory_format=torch.channels_last), target.to(
+            device
+        )
+        opt.zero_grad()
         output = model(data)
         loss = F.nll_loss(output, target)
         loss.backward()
-        optimizer.step()
+        opt.step()
         if batch_idx % args.log_interval == 0:
-            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
-                epoch, batch_idx * len(data), len(train_loader.dataset),
-                100. * batch_idx / len(train_loader), loss.item()))
+            print(
+                f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100.0 * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}"
+            )
             if args.dry_run:
                 break

@@ -58,87 +82,179 @@ def test(model, device, test_loader):
         for data, target in test_loader:
             data, target = data.to(device), target.to(device)
             output = model(data)
-            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
-            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
+            test_loss += F.nll_loss(
+                output, target, reduction="sum"
+            ).item()  # sum up batch loss
+            pred = output.argmax(
+                dim=1, keepdim=True
+            )  # get the index of the max log-probability
             correct += pred.eq(target.view_as(pred)).sum().item()

     test_loss /= len(test_loader.dataset)

-    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
-        test_loss, correct, len(test_loader.dataset),
-        100. * correct / len(test_loader.dataset)))
+    print(
+        f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} ({100.0 * correct / len(test_loader.dataset):.0f}%)\n"
+    )


-def main():
+def parse_args():
     # Training settings
-    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
-    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
-                        help='input batch size for training (default: 64)')
-    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
-                        help='input batch size for testing (default: 1000)')
-    parser.add_argument('--epochs', type=int, default=14, metavar='N',
-                        help='number of epochs to train (default: 14)')
-    parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
-                        help='learning rate (default: 1.0)')
-    parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
-                        help='Learning rate step gamma (default: 0.7)')
-    parser.add_argument('--no-cuda', action='store_true', default=False,
-                        help='disables CUDA training')
-    parser.add_argument('--no-mps', action='store_true', default=False,
-                        help='disables macOS GPU training')
-    parser.add_argument('--dry-run', action='store_true', default=False,
-                        help='quickly check a single pass')
-    parser.add_argument('--seed', type=int, default=1, metavar='S',
-                        help='random seed (default: 1)')
-    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
-                        help='how many batches to wait before logging training status')
-    parser.add_argument('--save-model', action='store_true', default=False,
-                        help='For Saving the current Model')
+    parser = argparse.ArgumentParser(description="PyTorch MNIST Example")
+    parser.add_argument(
+        "--batch-size",
+        type=int,
+        default=64,
+        metavar="N",
+        help="input batch size for training (default: 64)",
+    )
+    parser.add_argument(
+        "--test-batch-size",
+        type=int,
+        default=1000,
+        metavar="N",
+        help="input batch size for testing (default: 1000)",
+    )
+    parser.add_argument(
+        "--epochs",
+        type=int,
+        default=14,
+        metavar="N",
+        help="number of epochs to train (default: 14)",
+    )
+    parser.add_argument(
+        "--lr",
+        type=float,
+        default=1.0,
+        metavar="LR",
+        help="learning rate (default: 1.0)",
+    )
+    parser.add_argument(
+        "--gamma",
+        type=float,
+        default=0.7,
+        metavar="M",
+        help="Learning rate step gamma (default: 0.7)",
+    )
+    parser.add_argument(
+        "--no-cuda", action="store_true", default=False, help="disables CUDA training"
+    )
+    parser.add_argument(
+        "--no-mps",
+        action="store_true",
+        default=False,
+        help="disables macOS GPU training",
+    )
+    parser.add_argument(
+        "--dry-run",
+        action="store_true",
+        default=False,
+        help="quickly check a single pass",
+    )
+    parser.add_argument(
+        "--seed", type=int, default=1, metavar="S", help="random seed (default: 1)"
+    )
+    parser.add_argument(
+        "--log-interval",
+        type=int,
+        default=10,
+        metavar="N",
+        help="how many batches to wait before logging training status",
+    )
+    parser.add_argument(
+        "--use-amp",
+        action="store_true",
+        default=False,
+        help="use automatic mixed precision",
+    )
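+    # note: action="store_true" rather than type=bool; argparse passes the raw
+    # string to type, and bool("False") is True, so type=bool cannot be disabled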
+    parser.add_argument(
+        "--compile-backend",
+        type=str,
+        default="inductor",
+        metavar="BACKEND",
+        help="backend to compile the model with",
+    )
+    parser.add_argument(
+        "--compile-mode",
+        type=str,
+        default="default",
+        metavar="MODE",
+        help="compilation mode",
+    )
+    parser.add_argument(
+        "--save-model",
+        action="store_true",
+        default=False,
+        help="For Saving the current Model",
+    )
+    parser.add_argument(
+        "--data-dir",
+        type=str,
+        default="../data",
+        metavar="DIR",
+        help="path to the data directory",
+    )
     args = parser.parse_args()
+
+    return args
+
+
+def main():
+    args = parse_args()
+
     use_cuda = not args.no_cuda and torch.cuda.is_available()
     use_mps = not args.no_mps and torch.backends.mps.is_available()

-    torch.manual_seed(args.seed)
-
     if use_cuda:
         device = torch.device("cuda")
     elif use_mps:
         device = torch.device("mps")
     else:
         device = torch.device("cpu")

-    train_kwargs = {'batch_size': args.batch_size}
-    test_kwargs = {'batch_size': args.test_batch_size}
+    train_kwargs = {"batch_size": args.batch_size}
+    test_kwargs = {"batch_size": args.test_batch_size}
     if use_cuda:
-        cuda_kwargs = {'num_workers': 1,
-                       'pin_memory': True,
-                       'shuffle': True}
+        cuda_kwargs = {"num_workers": 1, "pin_memory": True, "shuffle": True}
         train_kwargs.update(cuda_kwargs)
         test_kwargs.update(cuda_kwargs)

-    transform = transforms.Compose([
-        transforms.ToTensor(),
-        transforms.Normalize((0.1307,), (0.3081,))
-    ])
-    dataset1 = datasets.MNIST('../data', train=True, download=True,
-                              transform=transform)
-    dataset2 = datasets.MNIST('../data', train=False,
-                              transform=transform)
-    train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs)
+    transform = transforms.Compose(
+        [
+            transforms.ToImage(),
+            transforms.ToDtype(torch.float32, scale=True),
+            transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
+        ]
+    )
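+    # in transforms.v2, ToImage() + ToDtype(torch.float32, scale=True) is the
+    # recommended replacement for ToTensor()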
228
+
229
+ data_dir = args .data_dir
230
+
231
+ dataset1 = datasets .MNIST (data_dir , train = True , download = True , transform = transform )
232
+ dataset2 = datasets .MNIST (data_dir , train = False , transform = transform )
233
+ train_loader = torch .utils .data .DataLoader (dataset1 , ** train_kwargs )
128
234
test_loader = torch .utils .data .DataLoader (dataset2 , ** test_kwargs )
129
235
130
- model = Net ().to (device )
131
- optimizer = optim .Adadelta (model .parameters (), lr = args .lr )
236
+ model = Net ().to (device , memory_format = torch .channels_last )
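+    # torch.compile captures the model's graph and generates fused kernels;
+    # "inductor" is the default backend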
+    model = torch.compile(model, backend=args.compile_backend, mode=args.compile_mode)
+    optimizer = optim.Adadelta(model.parameters(), lr=torch.tensor(args.lr))
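+    # a tensor lr keeps compiled code from specializing on the float value,
+    # which would otherwise force a recompile whenever the scheduler changes it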

     scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
+
+    scaler = None
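+    # loss scaling is only needed for fp16 autocast on CUDA; CPU autocast uses
+    # bf16, which keeps the fp32 exponent range and does not underflow the same way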
+    if use_cuda and args.use_amp:
+        scaler = torch.amp.GradScaler(device.type)
+
     for epoch in range(1, args.epochs + 1):
-        train(args, model, device, train_loader, optimizer, epoch)
+        if scaler is None:
+            train(args, model, device, train_loader, optimizer, epoch)
+        else:
+            train_amp(args, model, device, train_loader, optimizer, epoch, scaler)
         test(model, device, test_loader)
         scheduler.step()

     if args.save_model:
         torch.save(model.state_dict(), "mnist_cnn.pt")


-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
+