Adding a cell list neighbor list module #169
Conversation
Commits in this PR:
- Use kInt32 for batch
- Update cell list impl.
- Document cell list implementation
- Clean up a bit
- (… behind an option) Add test to check identical outputs compared to Distance
- include_transpose (whether to include redundant pairs)
- (… triclinic boxes, cell list only rectangular) Some optimizations overall
- Update tests
- Update benchmark
- Move error checking python-side
- Update tests, update benchmarks
- Allow batch to be None (defaults to all zeros)
- Use PyTorch CUDACachingAllocator with thrust::sort and for temporary memory
Thanks for taking a look, Peter! I am aware it's a lot of code -.-
I think this would indeed be a good addition. Feel free to upstream this to NNPOps and we will then switch to that in torchmd-net; I can help too. Currently NNPOps does not support some of the things implemented here, in particular the possibility of including self-interactions and "transpose" pairs. To port this to NNPOps we need to decide if we want to put the following functionality there:
All of these can be implemented in NNPOps as additions to the interface, so it stays backward compatible. As a side note, I studied the voxel implementation you shared. One thing that I found improves traversal performance a lot is to use a single loop to go over cells: torchmd-net/torchmdnet/neighbors/neighbors_cuda_cell.cu Lines 482 to 485 in 9d5028e
Instead of three nested ones: https://github.com/openmm/openmm/blob/d6cca3903aa0be02c105aed4febcfa2747f48fc1/platforms/reference/src/SimTKReference/ReferenceNeighborList.cpp#L154 It's the classic "recompute whatever to avoid branching" CUDA tradeoff, and it also promotes unrolling. This is cool and all, but it completely destroys the triclinic traversal strategy you have in the reference implementation. Not sure how to go about it yet...
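A minimal Python sketch of the flat-loop idea (illustrative only; the real code is a CUDA kernel, and the function name here is hypothetical):

    # Decompose a flat index 0..26 into the three per-axis offsets -1..1,
    # trading a little arithmetic for a single, branch-free, unrollable loop.
    def neighbor_cells(ci, cj, ck):
        for i in range(27):
            di = i % 3 - 1
            dj = (i // 3) % 3 - 1
            dk = i // 9 - 1
            yield ci + di, cj + dj, ck + dk

    # Equivalent to three nested loops over di, dj, dk in (-1, 0, 1).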
This is exactly what my code does, but I use a 64-bit hash composed of a Morton hash and the batch index to sort, instead of just the cell linear index: torchmd-net/torchmdnet/neighbors/neighbors_cuda_cell.cu Lines 137 to 142 in 9d5028e
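For illustration, a hedged Python sketch of such a sort key (the textbook 30-bit Morton encoding; the PR's exact bit layout, including which field sits in the high bits, may differ):

    # Spread the low 10 bits of x so there are two zero bits between them.
    def part1by2(x):
        x &= 0x3FF
        x = (x | (x << 16)) & 0xFF0000FF
        x = (x | (x << 8)) & 0x0300F00F
        x = (x | (x << 4)) & 0x030C30C3
        x = (x | (x << 2)) & 0x09249249
        return x

    # 64-bit sort key: 30-bit Morton (Z-order) code plus the batch index.
    def sort_key(ci, cj, ck, batch):
        morton = part1by2(ci) | (part1by2(cj) << 1) | (part1by2(ck) << 2)
        # Morton code in the high bits -> one global cell list with batch as
        # tie-break; swapping the two fields would sort per batch first.
        return (morton << 32) | batch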
Why separate this from the rest of the CUDA implementation? Assigning the hash is a small kernel launch in C++, and I am much more familiar with bit manipulation there.
I resorted to using the radix sort implementation that comes with CUDA via cub because:
The Morton hash only uses 30 bits, so I could ignore the last two bits and use torch.long, solving 1. I chose not to do that because I suspect I can leverage these bits to improve batch handling in the future. Right now I construct a single cell list with all batches in it, but maybe a cell-list-per-batch is best? Which amounts to just switching the order here:
One painful thing about achieving this is that you do not actually know the number of batches CPU-side, so you need a cudaMemcpy to implement it. Not sure how to solve that... For 2, though, calling the low-level-ish cub::DeviceRadixSort directly takes ~20 lines of boilerplate code; I do not see the value of losing CUDA graphs for that. I could agree if all the rest could be written in PyTorch, which would make the thing work with all torch backends automatically. But since the list traversal requires a CUDA kernel anyway...
These things change, but last time I checked, the access pattern resulting from having voxelStart and voxelEnd (while taking twice the memory) was much faster. Anyhow, this is what this kernel is doing: torchmd-net/torchmdnet/neighbors/neighbors_cuda_cell.cu Lines 216 to 248 in 9d5028e
Again, I decided to use a kernel instead of trying to torchify it because I anticipate one can be smart about it to improve batch handling.
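In torch terms, what that kernel produces could be sketched roughly like this (my own hedged approximation using torch.searchsorted, not the PR's code):

    import torch

    # Given the cell index of each particle *after* sorting, find for each
    # cell the first and one-past-last particle row (voxelStart/voxelEnd).
    def voxel_bounds(sorted_cell_index, num_cells):
        cells = torch.arange(num_cells, device=sorted_cell_index.device)
        start = torch.searchsorted(sorted_cell_index, cells, right=False)
        end = torch.searchsorted(sorted_cell_index, cells, right=True)
        return start, end  # particles of cell c live in rows start[c]:end[c]

    idx = torch.tensor([0, 0, 2, 2, 2, 5])  # sorted cell index per particle
    start, end = voxel_bounds(idx, 6)       # start=[0,2,2,5,5,5], end=[2,2,5,5,5,6]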
I initially used thrust::sort and some containers/allocators from thrust; it seems I forgot to take away some headers. Thrust has many convenience functions, like min/max, that replace the std alternatives that cannot be used in device code.
I don't understand what the hashes are adding. Ultimately you want to know the list of atoms in a particular voxel. Sorting by hash rather than voxel index just adds a lot of code complexity and runtime overhead for no obvious benefit. I understand the benefit of hashes in the generic case, when you have arbitrarily sparse data scattered over an arbitrarily large volume of space. You might have billions of voxels, almost all of which are empty. But that's not the case in molecular simulations. The data is evenly distributed over a small area of space. The number of voxels is in the thousands at most, and few are empty.
You can replace a few hundred lines of CUDA code with probably 10-20 lines of Python, and it will be just as fast. This isn't the bottleneck operation.
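A hedged sketch of what such a torch-only version might look like (assuming a rectangular box and positions already wrapped into [0, box); every name here is illustrative, not the PR's API):

    import torch

    def build_cell_list(pos, box, cutoff):
        n_cells = (box / cutoff).floor().long().clamp(min=1)  # cells per axis
        cell = (pos * (n_cells / box)).long()                 # per-axis cell coords
        cell = torch.minimum(cell, n_cells - 1)               # guard pos == box edge
        flat = (cell[:, 0] * n_cells[1] + cell[:, 1]) * n_cells[2] + cell[:, 2]
        flat, order = torch.sort(flat)                        # sort atoms by voxel
        return pos[order], flat, order

    pos = torch.rand(1000, 3) * 10.0
    sorted_pos, voxel, order = build_cell_list(pos, torch.full((3,), 10.0), 2.5)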
… torch::sort for minimal performance loss. Change block size to 128
I believe I did not convey correctly a detail of my strategy: I do not only compute the cell index/hash of each position and sort the indices, I also reorder the positions and batch arrays according to this hash and use these sorted copies when traversing the cell list. This has a profound impact on performance because it increases data locality, not only in terms of cache but also by increasing coherence.
This is where the Z-order hash comes into play: since one goes over the neighboring 27 cells, a Z-order increases the chance of neighboring cells being contiguous in memory. A simple linear cell index means one has to jump n.x*n.y elements in the voxelStart array when traversing the cells at a different height. This improved things by about 20% when I first implemented it (on a GTX 980); I checked now on a 4090 and a Titan V and the effect is negligible. I guess cache sizes are crazy now :p. Or maybe the overhead of the atomic addition of neighbor pairs hides any gains from this.
Now, recording the voxel index of each particle requires some arithmetic and checks; my pytorch-fu is not enough to see how to transform this into an efficient sequence of torch calls: torchmd-net/torchmdnet/neighbors/neighbors_cuda_cell.cuh Lines 56 to 89 in 48e40e3
I am also scared to torchify just this part, since the above device functions are also used when traversing, and translating the construction to torch would require implementing this logic twice.
I do not have any hope for the cell list being a viable option for inference on small molecules. I am thinking about dense, large systems, like a water box or a protein with implicit water; something like 32^3 cells and above. To me, anything below that sounds like N^2 with extra steps. Maybe I should aim for other stuff?
We need this in production as soon as possible. After discussing with @RaulPPelaez, we decided:
In a water box, there are usually no empty voxels. So anything you do is going to be linear in the number of voxels (which is linear in the number of atoms).
Ok. Let me do one review pass through the code first.
Agreed.
@@ -77,6 +79,164 @@ def message(self, x_j, W):
        return x_j * W


class OptimizedDistance(torch.nn.Module):
Would it make sense to just replace the existing Distance class, instead of adding another class with a different name? Are there cases when someone would prefer the old class instead?
The reason I did not do this is that I am wary of some current use of Distance relying on something unexpected (maybe related to the ordering?). I am waiting for regular users of each model to tell me "I trained with the new Distance and everything is OK". Maybe a bit superstitious, but if it is all the same, I would rather do another PR to replace the current uses.
    int3 periodic_cell = cell;
    if (cell.x < 0)
        periodic_cell.x += cell_dim.x;
    if (cell.x >= cell_dim.x)
        periodic_cell.x -= cell_dim.x;
    if (cell.y < 0)
        periodic_cell.y += cell_dim.y;
    if (cell.y >= cell_dim.y)
        periodic_cell.y -= cell_dim.y;
    if (cell.z < 0)
        periodic_cell.z += cell_dim.z;
    if (cell.z >= cell_dim.z)
        periodic_cell.z -= cell_dim.z;
    return periodic_cell;
This assumes it will never be off by more than one grid width. Is that guaranteed to be true? A safer (and much simpler) implementation is
    return make_int3(cell.x % cell_dim.x, cell.y % cell_dim.y, cell.z % cell_dim.z);
Careful: in C++, -1 % 10 is -1, not 9.
Which is why the previous suggestion for getCell using modulus also does not work when particles are to the left of the main box, ugghhhh.
What am I missing here?
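A hedged aside: the usual sign-safe wrap in C-family languages adds the divisor back before a second modulus; sketched here in Python (whose own % is already floored, unlike C++'s truncated one):

    # Python's % is floored (-1 % 10 == 9), while C++/CUDA % truncates
    # toward zero (-1 % 10 == -1). The classic portable wrap, valid for
    # displacements of any size, is:
    def wrap(c, d):
        return ((c % d) + d) % d  # always lands in [0, d) for d > 0

    assert wrap(-1, 10) == 9
    assert wrap(-11, 10) == 9
    assert wrap(12, 10) == 2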
This reverts commit 5459db6.
check PBC correctness
I am going to merge this now so we can move on. @peastman, please feel free to PR if you see how to transform some of the kernels to torch ops; I gave it a try but got nowhere.
This PR includes a new module called DistanceCellList (looking for a better name), which is an alternative to the Distance module that provides three strategies to find neighbors:
The current Distance module has some drawbacks:
The new module solves all these by being:
*There is a caveat: Distance requires max_num_neighs (the maximum allowed neighbors per particle), while DistanceCellList requires max_num_pairs (the maximum total number of neighbor pairs).
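A back-of-the-envelope relation between the two limits (my own illustration, not from the PR; the numbers are made up):

    # With N particles and at most K neighbors each, N*K bounds the total
    # pair count when transpose (i, j)/(j, i) pairs are included, and about
    # half that when they are not.
    n_atoms, max_num_neighs = 10_000, 64
    max_num_pairs = n_atoms * max_num_neighs            # with include_transpose
    max_num_pairs_no_t = n_atoms * max_num_neighs // 2  # without redundant pairs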
This is the current declaration of the new module:
torchmd-net/torchmdnet/models/utils.py
Lines 81 to 129 in 730b0a1
Changes to the installation process
This operation is written in C++ and CUDA. TorchMD-Net does not currently have an in-place build system, being Python-only.
I build this as a torch cpp_extension with JIT compilation, meaning that the CUDA/C++ code is compiled transparently the first time DistanceCellList is instantiated, in a way compatible with the current "pip install -e ." workflow.
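A minimal sketch of torch's JIT extension loading that this describes (file and module names here are illustrative, not the PR's actual ones):

    import torch.utils.cpp_extension

    neighbors = torch.utils.cpp_extension.load(
        name="torchmdnet_neighbors",
        sources=["neighbors.cpp", "neighbors_cuda_cell.cu"],
        verbose=True,
    )
    # The first call compiles and caches the shared library (under
    # ~/.cache/torch_extensions by default); later runs reuse the cache.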
If a user does not use the new module, nothing is compiled and no additional dependencies or overhead are required.
On the other hand, a user constructing DistanceCellList will:
Tasks:
- … Distance module.
Challenges:
The alternative, constructing a cell list per batch, requires much more memory and cannot be done without GPU-CPU memory copies.