Skip to content

Commit ebe2c53

Browse files
committed
Update based on comments
1 parent f71aa8c commit ebe2c53

File tree

3 files changed

+13
-50
lines changed

3 files changed

+13
-50
lines changed

benchmark/citation/train_eval.py

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@ def random_planetoid_splits(data, num_classes):
3939
def run_train(dataset, model, runs, epochs, lr, weight_decay, early_stopping,
4040
profiling, use_compile, permute_masks=None, logger=None):
4141
val_losses, accs, durations = [], [], []
42+
if use_compile:
43+
model = torch_geometric.compile(model)
44+
4245
for run in range(runs):
4346
data = dataset[0]
4447
if permute_masks is not None:
@@ -97,14 +100,6 @@ def run_train(dataset, model, runs, epochs, lr, weight_decay, early_stopping,
97100
with torch_profile():
98101
train(model, optimizer, data)
99102

100-
if use_compile:
101-
print("Using torch.compile")
102-
compiled_model = torch_geometric.compile(model)
103-
train(compiled_model, optimizer, data)
104-
train(compiled_model, optimizer, data)
105-
with timeit():
106-
train(compiled_model, optimizer, data)
107-
108103

109104
@torch.no_grad()
110105
def run_inference(dataset, model, epochs, profiling, bf16, use_compile,
@@ -115,6 +110,8 @@ def run_inference(dataset, model, epochs, profiling, bf16, use_compile,
115110
data = data.to(device)
116111

117112
model.to(device).reset_parameters()
113+
if use_compile:
114+
model = torch_geometric.compile(model)
118115

119116
if torch.cuda.is_available():
120117
amp = torch.cuda.amp.autocast(enabled=False)
@@ -135,14 +132,6 @@ def run_inference(dataset, model, epochs, profiling, bf16, use_compile,
135132
with torch_profile():
136133
inference(model, data)
137134

138-
if use_compile:
139-
print("Using torch.compile")
140-
compiled_model = torch_geometric.compile(model)
141-
inference(compiled_model, data)
142-
inference(compiled_model, data)
143-
with timeit():
144-
inference(compiled_model, data)
145-
146135

147136
def run(dataset, model, runs, epochs, lr, weight_decay, early_stopping,
148137
inference, profiling, bf16, use_compile, permute_masks=None,

benchmark/kernel/main_performance.py

Lines changed: 4 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,8 @@ def run_train():
7979

8080
model = Model(dataset, num_layers, hidden).to(device)
8181
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
82-
82+
if args.compile:
83+
model = torch_geometric.compile(model)
8384
loss_list = []
8485
acc_list = []
8586
for epoch in range(1, args.epochs + 1):
@@ -106,14 +107,6 @@ def run_train():
106107
rename_profile_file(model_name, dataset_name, str(num_layers),
107108
str(hidden), 'train')
108109

109-
if args.compile:
110-
print("Using torch.compile")
111-
compiled_model = torch_geometric.compile(model)
112-
eval_acc(compiled_model, val_loader)
113-
eval_acc(compiled_model, val_loader)
114-
with timeit():
115-
eval_acc(compiled_model, val_loader)
116-
117110

118111
@torch.no_grad()
119112
def run_inference():
@@ -126,7 +119,8 @@ def run_inference():
126119
print(f'{dataset_name} - {model_name}- {num_layers} - {hidden}')
127120

128121
model = Model(dataset, num_layers, hidden).to(device)
129-
122+
if args.compile:
123+
model = torch_geometric.compile(model)
130124
with amp:
131125
for epoch in range(1, args.epochs + 1):
132126
if epoch == args.epochs:
@@ -142,14 +136,6 @@ def run_inference():
142136
str(num_layers), str(hidden),
143137
'inference')
144138

145-
if args.compile:
146-
print("Using torch.compile")
147-
compiled_model = torch_geometric.compile(model)
148-
inference_run(compiled_model, test_loader, args.bf16)
149-
inference_run(compiled_model, test_loader, args.bf16)
150-
with timeit():
151-
inference_run(compiled_model, test_loader, args.bf16)
152-
153139

154140
if not args.inference:
155141
run_train()

benchmark/points/train_eval.py

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ def run_train(train_dataset, test_dataset, model, epochs, batch_size,
1515
use_compile, lr, lr_decay_factor, lr_decay_step_size,
1616
weight_decay):
1717
model = model.to(device)
18+
if use_compile:
19+
model = torch_geometric.compile(model)
1820
optimizer = Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
1921

2022
train_loader = DataLoader(train_dataset, batch_size, shuffle=True)
@@ -41,19 +43,13 @@ def run_train(train_dataset, test_dataset, model, epochs, batch_size,
4143
for param_group in optimizer.param_groups:
4244
param_group['lr'] = lr_decay_factor * param_group['lr']
4345

44-
if use_compile:
45-
print("Using torch.compile")
46-
compiled_model = torch_geometric.compile(model)
47-
test(compiled_model, test_loader, device)
48-
test(compiled_model, test_loader, device)
49-
with timeit():
50-
test(compiled_model, test_loader, device)
51-
5246

5347
@torch.no_grad()
5448
def run_inference(test_dataset, model, epochs, batch_size, profiling, bf16,
5549
use_compile):
5650
model = model.to(device)
51+
if use_compile:
52+
model = torch_geometric.compile(model)
5753
test_loader = DataLoader(test_dataset, batch_size, shuffle=False)
5854

5955
if torch.cuda.is_available():
@@ -74,14 +70,6 @@ def run_inference(test_dataset, model, epochs, batch_size, profiling, bf16,
7470
with torch_profile():
7571
inference(model, test_loader, device, bf16)
7672

77-
if use_compile:
78-
print("Using torch.compile")
79-
compiled_model = torch_geometric.compile(model)
80-
inference(compiled_model, test_loader, device, bf16)
81-
inference(compiled_model, test_loader, device, bf16)
82-
with timeit():
83-
inference(compiled_model, test_loader, device, bf16)
84-
8573

8674
def run(train_dataset, test_dataset, model, epochs, batch_size, lr,
8775
lr_decay_factor, lr_decay_step_size, weight_decay, inference,

0 commit comments

Comments (0)