Skip to content

Commit e20876f

Browse files
authored
Merge pull request #169 from RaulPPelaez/neighbors_one_list_rules_all
Adding a cell list neighbor list module
2 parents a0ccd78 + 605a4b5 commit e20876f

14 files changed

+1975
-1
lines changed

benchmarks/neighbors.py

Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,212 @@
1+
import os
2+
import torch
3+
import numpy as np
4+
from torchmdnet.models.utils import Distance, OptimizedDistance
5+
6+
7+
def benchmark_neighbors(
8+
device, strategy, n_batches, total_num_particles, mean_num_neighbors, density
9+
):
10+
"""Benchmark the neighbor list generation.
11+
12+
Parameters
13+
----------
14+
device : str
15+
Device to use for the benchmark.
16+
strategy : str
17+
Strategy to use for the neighbor list generation (cell, brute).
18+
n_batches : int
19+
Number of batches to generate.
20+
total_num_particles : int
21+
Total number of particles.
22+
mean_num_neighbors : int
23+
Mean number of neighbors per particle.
24+
density : float
25+
Density of the system.
26+
Returns
27+
-------
28+
float
29+
Average time per batch in seconds.
30+
"""
31+
torch.random.manual_seed(12344)
32+
np.random.seed(43211)
33+
num_particles = total_num_particles // n_batches
34+
expected_num_neighbors = mean_num_neighbors
35+
cutoff = np.cbrt(3 * expected_num_neighbors / (4 * np.pi * density))
36+
n_atoms_per_batch = torch.randint(
37+
int(num_particles / 2), int(num_particles * 2), size=(n_batches,), device="cpu"
38+
)
39+
# Fix so that the total number of particles is correct. Special care if the difference is negative
40+
difference = total_num_particles - n_atoms_per_batch.sum()
41+
if difference > 0:
42+
while difference > 0:
43+
i = np.random.randint(0, n_batches)
44+
n_atoms_per_batch[i] += 1
45+
difference -= 1
46+
else:
47+
while difference < 0:
48+
i = np.random.randint(0, n_batches)
49+
if n_atoms_per_batch[i] > num_particles:
50+
n_atoms_per_batch[i] -= 1
51+
difference += 1
52+
lbox = np.cbrt(num_particles / density)
53+
batch = torch.repeat_interleave(
54+
torch.arange(n_batches, dtype=torch.int64), n_atoms_per_batch
55+
).to(device)
56+
cumsum = np.cumsum(np.concatenate([[0], n_atoms_per_batch]))
57+
pos = torch.rand(cumsum[-1], 3, device="cpu").to(device) * lbox
58+
if strategy != "distance":
59+
max_num_pairs = (expected_num_neighbors * n_atoms_per_batch.sum()).item() * 2
60+
box = torch.eye(3, device=device) * lbox
61+
nl = OptimizedDistance(
62+
cutoff_upper=cutoff,
63+
max_num_pairs=max_num_pairs,
64+
strategy=strategy,
65+
box=box,
66+
loop=False,
67+
include_transpose=True,
68+
check_errors=False,
69+
resize_to_fit=False,
70+
)
71+
else:
72+
max_num_neighbors = int(expected_num_neighbors * 5)
73+
nl = Distance(
74+
loop=False,
75+
cutoff_lower=0.0,
76+
cutoff_upper=cutoff,
77+
max_num_neighbors=max_num_neighbors,
78+
)
79+
# Warmup
80+
s = torch.cuda.Stream()
81+
s.wait_stream(torch.cuda.current_stream())
82+
with torch.cuda.stream(s):
83+
for i in range(10):
84+
neighbors, distances, distance_vecs = nl(pos, batch)
85+
torch.cuda.synchronize()
86+
nruns = 50
87+
torch.cuda.synchronize()
88+
89+
start = torch.cuda.Event(enable_timing=True)
90+
end = torch.cuda.Event(enable_timing=True)
91+
graph = torch.cuda.CUDAGraph()
92+
# record in a cuda graph
93+
if strategy != "distance":
94+
with torch.cuda.graph(graph):
95+
neighbors, distances, distance_vecs = nl(pos, batch)
96+
start.record()
97+
for i in range(nruns):
98+
graph.replay()
99+
end.record()
100+
else:
101+
start.record()
102+
for i in range(nruns):
103+
neighbors, distances, distance_vecs = nl(pos, batch)
104+
end.record()
105+
torch.cuda.synchronize()
106+
# Final time
107+
return start.elapsed_time(end) / nruns
108+
109+
110+
if __name__ == "__main__":
111+
n_particles = 32767
112+
mean_num_neighbors = min(n_particles, 64)
113+
density = 0.8
114+
print(
115+
"Benchmarking neighbor list generation for {} particles with {} neighbors on average".format(
116+
n_particles, mean_num_neighbors
117+
)
118+
)
119+
results = {}
120+
batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]
121+
for strategy in ["shared", "brute", "cell", "distance"]:
122+
for n_batches in batch_sizes:
123+
time = benchmark_neighbors(
124+
device="cuda",
125+
strategy=strategy,
126+
n_batches=n_batches,
127+
total_num_particles=n_particles,
128+
mean_num_neighbors=mean_num_neighbors,
129+
density=density,
130+
)
131+
# Store results in a dictionary
132+
results[strategy, n_batches] = time
133+
print("Summary")
134+
print("-------")
135+
print(
136+
"{:<10} {:<21} {:<21} {:<18} {:<10}".format(
137+
"Batch size", "Shared(ms)", "Brute(ms)", "Cell(ms)", "Distance(ms)"
138+
)
139+
)
140+
print(
141+
"{:<10} {:<21} {:<21} {:<18} {:<10}".format(
142+
"----------", "---------", "---------", "---------", "---------"
143+
)
144+
)
145+
# Print a column per strategy, show speedup over Distance in parenthesis
146+
for n_batches in batch_sizes:
147+
base = results["distance", n_batches]
148+
print(
149+
"{:<10} {:<4.2f} x{:<14.2f} {:<4.2f} x{:<14.2f} {:<4.2f} x{:<14.2f} {:<10.2f}".format(
150+
n_batches,
151+
results["shared", n_batches],
152+
base / results["shared", n_batches],
153+
results["brute", n_batches],
154+
base / results["brute", n_batches],
155+
results["cell", n_batches],
156+
base / results["cell", n_batches],
157+
results["distance", n_batches],
158+
)
159+
)
160+
n_particles_list = np.power(2, np.arange(8, 18))
161+
162+
for n_batches in [1, 2, 32, 64]:
163+
print(
164+
"Benchmarking neighbor list generation for {} batches with {} neighbors on average".format(
165+
n_batches, mean_num_neighbors
166+
)
167+
)
168+
results = {}
169+
for strategy in ["shared", "brute", "cell", "distance"]:
170+
for n_particles in n_particles_list:
171+
mean_num_neighbors = min(n_particles, 64)
172+
time = benchmark_neighbors(
173+
device="cuda",
174+
strategy=strategy,
175+
n_batches=n_batches,
176+
total_num_particles=n_particles,
177+
mean_num_neighbors=mean_num_neighbors,
178+
density=density,
179+
)
180+
# Store results in a dictionary
181+
results[strategy, n_particles] = time
182+
print("Summary")
183+
print("-------")
184+
print(
185+
"{:<10} {:<21} {:<21} {:<18} {:<10}".format(
186+
"N Particles", "Shared(ms)", "Brute(ms)", "Cell(ms)", "Distance(ms)"
187+
)
188+
)
189+
print(
190+
"{:<10} {:<21} {:<21} {:<18} {:<10}".format(
191+
"----------", "---------", "---------", "---------", "---------"
192+
)
193+
)
194+
# Print a column per strategy, show speedup over Distance in parenthesis
195+
for n_particles in n_particles_list:
196+
base = results["distance", n_particles]
197+
brute_speedup = base / results["brute", n_particles]
198+
if n_particles > 32000:
199+
results["brute", n_particles] = 0
200+
brute_speedup = 0
201+
print(
202+
"{:<10} {:<4.2f} x{:<14.2f} {:<4.2f} x{:<14.2f} {:<4.2f} x{:<14.2f} {:<10.2f}".format(
203+
n_particles,
204+
results["shared", n_particles],
205+
base / results["shared", n_particles],
206+
results["brute", n_particles],
207+
brute_speedup,
208+
results["cell", n_particles],
209+
base / results["cell", n_particles],
210+
results["distance", n_particles],
211+
)
212+
)

environment.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,4 @@ dependencies:
1818
- flake8
1919
- pytest
2020
- psutil
21+
- ninja

setup.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,7 @@
1515
name="torchmd-net",
1616
version=version,
1717
packages=find_packages(),
18+
package_data={"torchmdnet": ["neighbors/neighbors*", "neighbors/*.cu*"]},
19+
include_package_data=True,
1820
entry_points={"console_scripts": ["torchmd-train = torchmdnet.scripts.train:main"]},
1921
)

0 commit comments

Comments
 (0)