|
| 1 | +import os |
| 2 | +import torch |
| 3 | +import numpy as np |
| 4 | +from torchmdnet.models.utils import Distance, OptimizedDistance |
| 5 | + |
| 6 | + |
def benchmark_neighbors(
    device, strategy, n_batches, total_num_particles, mean_num_neighbors, density
):
    """Benchmark the neighbor list generation.

    Uniform random positions are drawn in a cubic box and split into
    ``n_batches`` systems of uneven size. A neighbor-list module is built for
    ``strategy`` and, after a short warmup, its forward pass is timed with
    CUDA events over a fixed number of runs.

    Parameters
    ----------
    device : str
        Device holding the positions and batch indices. The warmup and the
        timing use ``torch.cuda`` streams, events and graphs unconditionally,
        so a CUDA device is required.
    strategy : str
        ``"distance"`` selects the ``Distance`` module; any other value is
        forwarded as ``strategy`` to ``OptimizedDistance`` (the script in this
        file uses shared, brute and cell).
    n_batches : int
        Number of batches to generate.
    total_num_particles : int
        Total number of particles, summed over all batches.
    mean_num_neighbors : int
        Mean number of neighbors per particle; together with ``density`` it
        determines the cutoff radius.
    density : float
        Density of the system.

    Returns
    -------
    float
        Average time per neighbor list evaluation, in milliseconds.

    Raises
    ------
    ValueError
        If ``n_batches`` is smaller than one or larger than
        ``total_num_particles``.
    """
    if n_batches < 1 or total_num_particles < n_batches:
        raise ValueError(
            "Need at least one batch and one particle per batch on average, "
            "got n_batches={} and total_num_particles={}".format(
                n_batches, total_num_particles
            )
        )
    # Fixed seeds so that every strategy is timed on the same random systems.
    torch.random.manual_seed(12344)
    np.random.seed(43211)
    num_particles = total_num_particles // n_batches
    expected_num_neighbors = mean_num_neighbors
    # Radius of the sphere that contains expected_num_neighbors particles at
    # the requested density: (4/3) * pi * cutoff**3 * density = neighbors.
    cutoff = np.cbrt(3 * expected_num_neighbors / (4 * np.pi * density))
    n_atoms_per_batch = torch.randint(
        int(num_particles / 2), int(num_particles * 2), size=(n_batches,), device="cpu"
    )
    # Adjust the random batch sizes one particle at a time until they add up
    # to exactly total_num_particles.
    difference = int(total_num_particles) - int(n_atoms_per_batch.sum())
    while difference > 0:
        i = np.random.randint(0, n_batches)
        n_atoms_per_batch[i] += 1
        difference -= 1
    while difference < 0:
        i = np.random.randint(0, n_batches)
        # Only shrink batches that are above the mean size, so no batch can
        # be driven to a negative size.
        if n_atoms_per_batch[i] > num_particles:
            n_atoms_per_batch[i] -= 1
            difference += 1
    lbox = np.cbrt(num_particles / density)
    batch = torch.repeat_interleave(
        torch.arange(n_batches, dtype=torch.int64), n_atoms_per_batch
    ).to(device)
    total_atoms = int(n_atoms_per_batch.sum())
    pos = torch.rand(total_atoms, 3, device="cpu").to(device) * lbox
    if strategy != "distance":
        # Reserve room for twice the expected number of pairs.
        max_num_pairs = (expected_num_neighbors * n_atoms_per_batch.sum()).item() * 2
        box = torch.eye(3, device=device) * lbox
        nl = OptimizedDistance(
            cutoff_upper=cutoff,
            max_num_pairs=max_num_pairs,
            strategy=strategy,
            box=box,
            loop=False,
            include_transpose=True,
            check_errors=False,
            resize_to_fit=False,
        )
    else:
        max_num_neighbors = int(expected_num_neighbors * 5)
        nl = Distance(
            loop=False,
            cutoff_lower=0.0,
            cutoff_upper=cutoff,
            max_num_neighbors=max_num_neighbors,
        )
    # Warmup on a side stream before anything is captured or timed.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(10):
            neighbors, distances, distance_vecs = nl(pos, batch)
    torch.cuda.synchronize()
    nruns = 50
    torch.cuda.synchronize()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    if strategy != "distance":
        # Record one evaluation in a CUDA graph and time its replays.
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph):
            neighbors, distances, distance_vecs = nl(pos, batch)
        start.record()
        for _ in range(nruns):
            graph.replay()
        end.record()
    else:
        # The Distance module is timed through direct calls, without a graph.
        start.record()
        for _ in range(nruns):
            neighbors, distances, distance_vecs = nl(pos, batch)
        end.record()
    torch.cuda.synchronize()
    # elapsed_time reports milliseconds between the two recorded events.
    return start.elapsed_time(end) / nruns
| 108 | + |
| 109 | + |
# Strategies compared in every table; "distance" is the baseline for speedups.
_STRATEGIES = ("shared", "brute", "cell", "distance")
_HEADER_FORMAT = "{:<10} {:<21} {:<21} {:<18} {:<10}"
_ROW_FORMAT = (
    "{:<10} {:<4.2f} x{:<14.2f} {:<4.2f} x{:<14.2f} {:<4.2f} x{:<14.2f} {:<10.2f}"
)


def _run_sweep(cases, density):
    """Time every strategy on every case and return {(strategy, key): time}.

    ``cases`` is a sequence of ``(key, n_batches, n_particles)`` tuples. The
    mean number of neighbors is capped at 64 (or the particle count, if lower).
    """
    results = {}
    for strategy in _STRATEGIES:
        for key, n_batches, n_particles in cases:
            results[strategy, key] = benchmark_neighbors(
                device="cuda",
                strategy=strategy,
                n_batches=n_batches,
                total_num_particles=n_particles,
                mean_num_neighbors=min(n_particles, 64),
                density=density,
            )
    return results


def _print_summary(results, keys, key_label, brute_limit=None):
    """Print one table row per key: time per strategy and speedup over Distance.

    When ``brute_limit`` is given, the brute columns are shown as zero for
    keys above it.
    """
    print("Summary")
    print("-------")
    print(
        _HEADER_FORMAT.format(
            key_label, "Shared(ms)", "Brute(ms)", "Cell(ms)", "Distance(ms)"
        )
    )
    print(
        _HEADER_FORMAT.format(
            "----------", "---------", "---------", "---------", "---------"
        )
    )
    # Each strategy column shows its time followed by "x<speedup over Distance>".
    for key in keys:
        base = results["distance", key]
        if brute_limit is not None and key > brute_limit:
            # NOTE(review): presumably brute timings are not meaningful for
            # this many particles; confirm why they are blanked out.
            brute_time = 0
            brute_speedup = 0
        else:
            brute_time = results["brute", key]
            brute_speedup = base / brute_time
        print(
            _ROW_FORMAT.format(
                key,
                results["shared", key],
                base / results["shared", key],
                brute_time,
                brute_speedup,
                results["cell", key],
                base / results["cell", key],
                results["distance", key],
            )
        )


def main():
    """Run both benchmark sweeps on CUDA and print their summary tables."""
    density = 0.8

    # Sweep 1: fixed particle count, varying number of batches.
    n_particles = 32767
    mean_num_neighbors = min(n_particles, 64)
    print(
        "Benchmarking neighbor list generation for {} particles with {} neighbors on average".format(
            n_particles, mean_num_neighbors
        )
    )
    batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]
    results = _run_sweep(
        [(n_batches, n_batches, n_particles) for n_batches in batch_sizes], density
    )
    _print_summary(results, batch_sizes, "Batch size")

    # Sweep 2: fixed number of batches, particle count from 2**8 to 2**17.
    n_particles_list = np.power(2, np.arange(8, 18))
    for n_batches in [1, 2, 32, 64]:
        print(
            "Benchmarking neighbor list generation for {} batches with {} neighbors on average".format(
                n_batches, mean_num_neighbors
            )
        )
        results = _run_sweep(
            [(count, n_batches, count) for count in n_particles_list], density
        )
        _print_summary(results, n_particles_list, "N Particles", brute_limit=32000)


if __name__ == "__main__":
    main()
0 commit comments