
Commit 1e6beea

Changes to the code
1 parent 26de419


mnist/main.py

Lines changed: 174 additions & 58 deletions
@@ -3,13 +3,14 @@
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
-from torchvision import datasets, transforms
+from torchvision import datasets
+from torchvision.transforms import v2 as transforms
 from torch.optim.lr_scheduler import StepLR


 class Net(nn.Module):
     def __init__(self):
-        super(Net, self).__init__()
+        super().__init__()
         self.conv1 = nn.Conv2d(1, 32, 3, 1)
         self.conv2 = nn.Conv2d(32, 64, 3, 1)
         self.dropout1 = nn.Dropout(0.25)
@@ -33,19 +34,42 @@ def forward(self, x):
         return output


-def train(args, model, device, train_loader, optimizer, epoch):
+def train_amp(args, model, device, train_loader, opt, epoch, scaler):
     model.train()
     for batch_idx, (data, target) in enumerate(train_loader):
-        data, target = data.to(device), target.to(device)
-        optimizer.zero_grad()
+        data, target = data.to(device, memory_format=torch.channels_last), target.to(
+            device
+        )
+        opt.zero_grad()
+        with torch.autocast(device_type=device.type):
+            output = model(data)
+            loss = F.nll_loss(output, target)
+        scaler.scale(loss).backward()
+        scaler.step(opt)
+        scaler.update()
+        if batch_idx % args.log_interval == 0:
+            print(
+                f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100.0 * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}"
+            )
+        if args.dry_run:
+            break
+
+
+def train(args, model, device, train_loader, opt, epoch):
+    model.train()
+    for batch_idx, (data, target) in enumerate(train_loader):
+        data, target = data.to(device, memory_format=torch.channels_last), target.to(
+            device
+        )
+        opt.zero_grad()
         output = model(data)
         loss = F.nll_loss(output, target)
         loss.backward()
-        optimizer.step()
+        opt.step()
         if batch_idx % args.log_interval == 0:
-            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
-                epoch, batch_idx * len(data), len(train_loader.dataset),
-                100. * batch_idx / len(train_loader), loss.item()))
+            print(
+                f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100.0 * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}"
+            )
         if args.dry_run:
             break

@@ -58,87 +82,179 @@ def test(model, device, test_loader):
         for data, target in test_loader:
             data, target = data.to(device), target.to(device)
             output = model(data)
-            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
-            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
+            test_loss += F.nll_loss(
+                output, target, reduction="sum"
+            ).item()  # sum up batch loss
+            pred = output.argmax(
+                dim=1, keepdim=True
+            )  # get the index of the max log-probability
             correct += pred.eq(target.view_as(pred)).sum().item()

     test_loss /= len(test_loader.dataset)

-    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
-        test_loss, correct, len(test_loader.dataset),
-        100. * correct / len(test_loader.dataset)))
+    print(
+        f"\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} ({100.0 * correct / len(test_loader.dataset):.0f}%)\n"
+    )


-def main():
+def parse_args():
     # Training settings
-    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
-    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
-                        help='input batch size for training (default: 64)')
-    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
-                        help='input batch size for testing (default: 1000)')
-    parser.add_argument('--epochs', type=int, default=14, metavar='N',
-                        help='number of epochs to train (default: 14)')
-    parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
-                        help='learning rate (default: 1.0)')
-    parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
-                        help='Learning rate step gamma (default: 0.7)')
-    parser.add_argument('--no-cuda', action='store_true', default=False,
-                        help='disables CUDA training')
-    parser.add_argument('--no-mps', action='store_true', default=False,
-                        help='disables macOS GPU training')
-    parser.add_argument('--dry-run', action='store_true', default=False,
-                        help='quickly check a single pass')
-    parser.add_argument('--seed', type=int, default=1, metavar='S',
-                        help='random seed (default: 1)')
-    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
-                        help='how many batches to wait before logging training status')
-    parser.add_argument('--save-model', action='store_true', default=False,
-                        help='For Saving the current Model')
+    parser = argparse.ArgumentParser(description="PyTorch MNIST Example")
+    parser.add_argument(
+        "--batch-size",
+        type=int,
+        default=64,
+        metavar="N",
+        help="input batch size for training (default: 64)",
+    )
+    parser.add_argument(
+        "--test-batch-size",
+        type=int,
+        default=1000,
+        metavar="N",
+        help="input batch size for testing (default: 1000)",
+    )
+    parser.add_argument(
+        "--epochs",
+        type=int,
+        default=14,
+        metavar="N",
+        help="number of epochs to train (default: 14)",
+    )
+    parser.add_argument(
+        "--lr",
+        type=float,
+        default=1.0,
+        metavar="LR",
+        help="learning rate (default: 1.0)",
+    )
+    parser.add_argument(
+        "--gamma",
+        type=float,
+        default=0.7,
+        metavar="M",
+        help="Learning rate step gamma (default: 0.7)",
+    )
+    parser.add_argument(
+        "--no-cuda", action="store_true", default=False, help="disables CUDA training"
+    )
+    parser.add_argument(
+        "--no-mps",
+        action="store_true",
+        default=False,
+        help="disables macOS GPU training",
+    )
+    parser.add_argument(
+        "--dry-run",
+        action="store_true",
+        default=False,
+        help="quickly check a single pass",
+    )
+    parser.add_argument(
+        "--seed", type=int, default=1, metavar="S", help="random seed (default: 1)"
+    )
+    parser.add_argument(
+        "--log-interval",
+        type=int,
+        default=10,
+        metavar="N",
+        help="how many batches to wait before logging training status",
+    )
+    parser.add_argument(
+        "--use-amp",
+        type=bool,
+        default=False,
+        help="use automatic mixed precision",
+    )
+    parser.add_argument(
+        "--compile-backend",
+        type=str,
+        default="inductor",
+        metavar="BACKEND",
+        help="backend to compile the model with",
+    )
+    parser.add_argument(
+        "--compile-mode",
+        type=str,
+        default="default",
+        metavar="MODE",
+        help="compilation mode",
+    )
+    parser.add_argument(
+        "--save-model",
+        action="store_true",
+        default=False,
+        help="For Saving the current Model",
+    )
+    parser.add_argument(
+        "--data-dir",
+        type=str,
+        default="../data",
+        metavar="DIR",
+        help="path to the data directory",
+    )
     args = parser.parse_args()
+
+    return args
+
+
+def main():
+    args = parse_args()
+
     use_cuda = not args.no_cuda and torch.cuda.is_available()
     use_mps = not args.no_mps and torch.backends.mps.is_available()

-    torch.manual_seed(args.seed)
-
     if use_cuda:
         device = torch.device("cuda")
     elif use_mps:
         device = torch.device("mps")
     else:
         device = torch.device("cpu")

-    train_kwargs = {'batch_size': args.batch_size}
-    test_kwargs = {'batch_size': args.test_batch_size}
+    train_kwargs = {"batch_size": args.batch_size}
+    test_kwargs = {"batch_size": args.test_batch_size}
     if use_cuda:
-        cuda_kwargs = {'num_workers': 1,
-                       'pin_memory': True,
-                       'shuffle': True}
+        cuda_kwargs = {"num_workers": 1, "pin_memory": True, "shuffle": True}
         train_kwargs.update(cuda_kwargs)
         test_kwargs.update(cuda_kwargs)

-    transform=transforms.Compose([
-        transforms.ToTensor(),
-        transforms.Normalize((0.1307,), (0.3081,))
-        ])
-    dataset1 = datasets.MNIST('../data', train=True, download=True,
-                       transform=transform)
-    dataset2 = datasets.MNIST('../data', train=False,
-                       transform=transform)
-    train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs)
+    transform = transforms.Compose(
+        [
+            transforms.ToImage(),
+            transforms.ToDtype(torch.float32, scale=True),
+            transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
+        ]
+    )
+
+    data_dir = args.data_dir
+
+    dataset1 = datasets.MNIST(data_dir, train=True, download=True, transform=transform)
+    dataset2 = datasets.MNIST(data_dir, train=False, transform=transform)
+    train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
     test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)

-    model = Net().to(device)
-    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
+    model = Net().to(device, memory_format=torch.channels_last)
+    model = torch.compile(model, backend=args.compile_backend, mode=args.compile_mode)
+    optimizer = optim.Adadelta(model.parameters(), lr=torch.tensor(args.lr))

     scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
+
+    scaler = None
+    if use_cuda and args.use_amp:
+        scaler = torch.GradScaler(device=device)
+
     for epoch in range(1, args.epochs + 1):
-        train(args, model, device, train_loader, optimizer, epoch)
+        if scaler is None:
+            train(args, model, device, train_loader, optimizer, epoch)
+        else:
+            train_amp(args, model, device, train_loader, optimizer, epoch, scaler)
         test(model, device, test_loader)
         scheduler.step()

     if args.save_model:
         torch.save(model.state_dict(), "mnist_cnn.pt")


-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
+
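Example invocation exercising the new flags (a sketch, not part of the commit; assumes a CUDA-capable machine, since the GradScaler/AMP path only activates when CUDA is available, and assumes the command is run from inside the mnist/ directory so the default --data-dir of ../data resolves):

    python main.py --epochs 2 --use-amp True --compile-backend inductor --compile-mode default --data-dir ../data

One usage caveat: because --use-amp is declared with type=bool, argparse converts the raw string with bool(), so any non-empty value, including "False", parses as True; only an empty string yields False.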
