
Adding a cell list neighbor list module #169


Merged: 79 commits, May 29, 2023

Conversation

@RaulPPelaez (Collaborator) commented May 3, 2023

This PR includes a new module called DistanceCellList (looking for a better name), which is an alternative to the Distance module that provides three strategies to find neighbors:

  1. The O(N^2) getNeighborPairs functionality from NNPOps, referred to as brute, which required making it batch-aware. The batching functionality is actually a really minimal modification, so it could be upstreamed to NNPOps and used from there instead.
  2. A cell list distance module, using a classic spatial hashing approach (see here)
  3. An improved O(N^2) approach referred to as shared (from NVIDIA here)

The current Distance module has some drawbacks:

  1. Really slow for a single batch
  2. Does not understand periodic boundary conditions
  3. Incompatible with CUDA graphs
  4. Always returns redundant pairs (it gives you both (i,j) and (j,i)). This is required by the current message passing modules.

The new module solves all these by being:

  1. Two orders of magnitude faster for a single batch (while being at least as fast as Distance in all instances tested)
  2. CUDA graph compatible (and jit.script compatible)
  3. A drop-in replacement for Distance when using default parameters*.
  4. Compatible with periodic boundary conditions.
  5. Able to optionally skip redundant neighbors, saving time and memory.

*There is a caveat: Distance requires max_num_neighs (maximum allowed neighbors per particle), while DistanceCellList requires max_num_pairs (maximum total number of neighbor pairs).

#From DistanceCellList:
"""
        max_num_pairs : int
            Maximum number of pairs to store.
            If negative, it is interpreted as (minus) the maximum number of neighbors per atom.
"""

This is the current declaration of the new module:

class DistanceCellList(torch.nn.Module):
    def __init__(
        self,
        cutoff_lower=0.0,
        cutoff_upper=5.0,
        max_num_pairs=-32,
        return_vecs=False,
        loop=False,
        strategy="brute",
        include_transpose=True,
        resize_to_fit=True,
        check_errors=True,
        box=None,
    ):
        super(DistanceCellList, self).__init__()
        """ Compute the neighbor list for a given cutoff.
        This operation can be placed inside a CUDA graph in some cases.
        In particular, resize_to_fit and check_errors must be False.

        Parameters
        ----------
        cutoff_lower : float
            Lower cutoff for the neighbor list.
        cutoff_upper : float
            Upper cutoff for the neighbor list.
        max_num_pairs : int
            Maximum number of pairs to store.
            If negative, it is interpreted as (minus) the maximum number of neighbors per atom.
        strategy : str
            Strategy to use for computing the neighbor list. Can be one of
            ["shared", "brute", "cell"].
            Shared: An O(N^2) algorithm that leverages CUDA shared memory, best for large numbers of particles.
            Brute: A brute force O(N^2) algorithm, best for small numbers of particles.
            Cell: A cell list algorithm, best for large numbers of particles, low cutoffs and low batch sizes.
        box : Optional[torch.Tensor]
            Size of the box, shape (3,3) or None.
            If strategy is "cell", the box must be diagonal.
        loop : bool
            Whether to include self-interactions.
        include_transpose : bool
            Whether to include the transpose of the neighbor list.
        resize_to_fit : bool
            Whether to resize the neighbor list to the actual number of pairs found.
            When False, the list is padded with (-1,-1) pairs up to max_num_pairs.
            If this is True the operation is not CUDA graph compatible.
        check_errors : bool
            Whether to check for too many pairs. If this is True the operation is not CUDA graph compatible.
        return_vecs : bool
            Whether to return the distance vectors.
        """

Changes to the installation process

This operation is written in C++ and CUDA. TorchMD-Net does not currently have an in-place build system, being Python-only.
I build this as a torch cpp_extension with JIT compilation, meaning that the CUDA/C++ code is compiled transparently the first time DistanceCellList is instantiated, in a way compatible with the current "pip install -e ." workflow.
If a user does not use the new module, nothing is compiled and no additional dependencies or overhead are required.
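For context, torch's JIT extension mechanism works roughly like the following sketch of torch.utils.cpp_extension.load; the extension and source file names here are placeholders, not the actual ones in this PR:

from torch.utils.cpp_extension import load

# Placeholder names: the first call compiles the sources with NVCC and caches
# the build; subsequent instantiations reuse the cached extension.
neighbors = load(
    name="torchmdnet_neighbors",
    sources=["neighbors.cpp", "neighbors.cu"],
    verbose=True,
)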

OTOH, a user constructing DistanceCellList will:

  1. Require NVCC installed with a CUDA version ABI-compatible with torch (cudatoolkit-dev from conda-forge works)
  2. Experience a compilation time of several minutes the first time DistanceCellList is used after "pip install -e ." (torch will cache the compilation).

Tasks:

  • Adapt the neighbor functions in NNPOps to the torchmd-net "pip install -e ." installation (thanks to POC: neighbor search #61).
  • Implement a batch-aware cell list neighbor construction
    • Optimize the cell list.
    • Make the operation CUDA graph compatible
  • Add tests
  • Add benchmark
  • Feature parity with current Distance module.
    • Add the "loop" parameter (include self-interactions)
    • Add the "cutoff_lower" parameter
  • Add periodic boundary conditions
    • Support rectangular boxes
    • Support triclinic boxes
      • Brute force
      • Cell list
  • Add backwards pass
    • CPU (Autograd takes care of it)
    • GPU (common backwards pass for every strategy)
  • torch.jit.script compatibility

Challenges:

  1. The brute approach cannot handle more than 32K particles total. AFAIK, there is no way to make it work without destroying what makes it so fast. Anyhow, this strategy is not really suited for such high workloads. There is a guard that simply forces the shared strategy if the user selected brute but more than 32K particles are requested (see the sketch after this list).
  2. The performance of the cell strategy degrades quickly with the number of batches. This is because I construct a single cell list such that particles in the same cell are contiguous in memory. When traversing the cell list, one thus finds particles from all batches, forcing a lot of unnecessary checks. I mitigate this by sorting by batch inside each cell, which allows skipping some pairs, but this could surely be done in a smarter way (maybe a binary search looking for the first atom in the current particle's batch?).
    The alternative, constructing a cell list per batch, requires much more memory and cannot be done without GPU-CPU memory copies.
  3. Automatically choosing a strategy based on some heuristic. I tried this in a million ways, but jit.script is not taking it. The heuristic cannot be applied until the forward method (when positions and batch are known), and changing the function dynamically like that is just not something TorchScript supports, as far as I can tell.
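As an illustration of the guard mentioned in point 1, the logic amounts to something like this (a sketch with illustrative names, not the exact code in the PR):

MAX_BRUTE_PARTICLES = 32768  # the brute kernel cannot handle more than ~32K particles

def pick_strategy(requested: str, num_particles: int) -> str:
    # Fall back to the shared-memory O(N^2) kernel when brute would overflow.
    if requested == "brute" and num_particles > MAX_BRUTE_PARTICLES:
        return "shared"
    return requested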

@RaulPPelaez (Collaborator Author)

> I did a quick skim through the code. A thorough review is going to take a long time, given how much code there is! But I had a couple of high level thoughts about it.

Thanks for taking a look, Peter! I am aware it's a lot of code -.-

> First, I wonder if the CUDA parts would go better in NNPOps? Building a neighbor list is a really common operation. That would make it available for more than just TorchMD-Net.

I think this would indeed be a good addition. Feel free to upstream this to NNPOps and we will then switch to that in torchmd-net. I can help too. Currently NNPOps does not support some of the things that are implemented here, in particular the possibility of including self-interactions and "transpose" pairs.
These can be added after the NNPOps operation in a simple way, although I am not sure how to do so while remaining CUDA-graph compatible. I can only think of ways to go about it that cost performance and memory.

To port this to NNPOps we need to decide if we want to put the following functionality there:

  1. Add a lower cutoff
  2. Add the include_transpose and loop options
  3. How to handle the selection of the different strategies (three separate functions vs one with a parameter?)
  4. Batches.

All of these can be implemented in NNPOps as additions to the interface, so as to remain backward compatible.

As a side note, I studied the voxel implementation you shared. One thing that I found improves traversal performance a lot is to use a single loop to go over cells:

for (int i = 0; i < 27; i++) {
    const int neighbor_cell = getNeighborCellIndex(cell_i, i, cell_dim);
    addNeighborsForCell(i_atom, neighbor_cell, cell_list, box_size, list);
}

Instead of three nested ones:
https://github.com/openmm/openmm/blob/d6cca3903aa0be02c105aed4febcfa2747f48fc1/platforms/reference/src/SimTKReference/ReferenceNeighborList.cpp#L154

It's the classic "recompute whatever to avoid branching" CUDA tradeoff, and it also promotes unrolling. This is cool and all, but it completely destroys the triclinic traversal strategy you have in the reference. Not sure how to go about it yet...
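For illustration, the flat neighbor index in the single loop above is just the usual base-3 decomposition; a sketch of what a getNeighborCellIndex-style helper decodes (the actual helper in the PR may differ):

def neighbor_cell_offset(i: int) -> tuple[int, int, int]:
    # Decode a flat index in 0..26 into a cell offset in {-1, 0, 1}^3.
    dx = i % 3 - 1
    dy = (i // 3) % 3 - 1
    dz = i // 9 - 1
    return dx, dy, dz

The neighbor cell is then the current cell plus this offset, wrapped periodically.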

> Second, this code seems way more complicated than it needs to be. For example, everything related to hashing. That's a textbook implementation of a generic voxel structure that can support arbitrarily sparse data points scattered over an arbitrarily large volume of space. But for the sorts of applications we care about, it's way more complexity than we need.
>
> For molecular models where atoms are evenly distributed over a fixed volume, it's really easy to sort them into voxels.
>
> 1. Record the index of the voxel each atom is in.
>
> 2. Sort them.
>
> 3. For every voxel, record the index of the first atom in the sorted list that's in that voxel.

This is exactly what my code does, but I use a 64-bit hash, composed of a Morton hash and the batch index, to sort by instead of just the linear cell index:

auto ci = getCell(pi, box_size, cutoff);
// Calculate the hash
const uint32_t hash = hashMorton(ci);
// Create a hash combining the Morton hash and the batch index, so that atoms in the same cell
// are contiguous
const uint64_t hash_final = (static_cast<uint64_t>(hash) << 32) | i_batch;

> Step 1 can be implemented with PyTorch in just a small amount of Python code.

Why separate this from the rest of the CUDA implementation? Assigning the hash is a small kernel launch in C++, and I am much more familiar with bit manipulation there.

> Step 2 is a single call to torch.sort().

I resorted to using the radix sort implementation that comes with CUDA via cub because:

  1. Tensor (and thus torch.sort) does not support uint64, which I need in order to pack both the cell hash and the batch index into the key.
  2. I could not make torch.sort play well with CUDA graphs. I also had this problem with thrust::sort (both torch and thrust call cub::DeviceRadixSort like I do down the line). Both of them synchronize at some point, be it to allocate some temp memory, copy, or decide on the number of radix sweeps (this is a guess).

The Morton hash only uses 30 bits, so I could ignore the last two bits and use torch.long, solving point 1. I chose not to do that because I suspect I can leverage these bits to improve batch handling in the future. Right now I construct a single cell list with all batches in it, but maybe a cell-list-per-batch is best? Which amounts to just switching the order here:

const uint64_t hash_final = (static_cast<uint64_t>(hash) << 32) | i_batch;

One painful thing about achieving this is that you do not actually know the number of batches on the CPU side, so you need a cudaMemcpy to implement it. Not sure how to solve that...

For 2, though, calling the low-level-ish cub::DeviceRadixSort directly takes ~20 lines of boilerplate code; I do not see the value of losing CUDA graphs for that. I could agree if all the rest could be written in pytorch, which would make the thing work with all torch backends automatically. But since the list traversal requires a CUDA kernel anyway...
Also, cub exposes these begin_bit/end_bit parameters https://nvlabs.github.io/cub/structcub_1_1_device_radix_sort.html that thrust and torch hardcode to 8*sizeof(KeyT), but they can be used to do funny things when sorting, such as saving radix sweeps. Granted, not that important when sorting is basically instantaneous, but still useful because it saves kernel launches.
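To make the torch.long option above concrete, here is a rough sketch (not code from the PR) of packing a 30-bit cell hash and the batch index into a single signed 64-bit key that torch.sort can handle:

import torch

# Sketch, not from the PR: with only 30 bits in the cell hash the packed key
# stays positive, so it fits in torch.long (int64) and can go through torch.sort.
def pack_keys(cell_hash: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
    return (cell_hash.to(torch.int64) << 32) | batch.to(torch.int64)

# order = torch.argsort(pack_keys(cell_hash, batch))

Swapping the order of the two fields, as discussed above, would give a cell-list-per-batch layout instead.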

> Step 3 is a single linear pass through the sorted list, though it's also easy to parallelize for efficiency. And you're done. The atoms in voxel i are the ones from voxelStart[i] to voxelStart[i+1].

These things change, but last time I checked, the access pattern resulting from having both voxelStart and voxelEnd (while taking twice the memory) was much faster. Anyhow, this is what this kernel does:

template <typename scalar_t>
__global__ void fillCellOffsetsD(const Accessor<scalar_t, 2> sorted_positions,
                                 const Accessor<int32_t, 1> sorted_indices,
                                 Accessor<int32_t, 1> cell_start, Accessor<int32_t, 1> cell_end,
                                 scalar3<scalar_t> box_size, scalar_t cutoff) {
    // Since positions are sorted by cell, for a given atom, if the previous atom is in a different
    // cell, then the current atom is the first atom in its cell. We use this fact to fill the
    // cell_start and cell_end arrays
    const int32_t i_atom = blockIdx.x * blockDim.x + threadIdx.x;
    if (i_atom >= sorted_positions.size(0))
        return;
    const auto pi = fetchPosition(sorted_positions, i_atom);
    const int3 cell_dim = getCellDimensions(box_size, cutoff);
    const int icell = getCellIndex(getCell(pi, box_size, cutoff), cell_dim);
    int im1_cell;
    if (i_atom > 0) {
        int im1 = i_atom - 1;
        const auto pim1 = fetchPosition(sorted_positions, im1);
        im1_cell = getCellIndex(getCell(pim1, box_size, cutoff), cell_dim);
    } else {
        im1_cell = 0;
    }
    if (icell != im1_cell || i_atom == 0) {
        int n_cells = cell_start.size(0);
        cell_start[icell] = i_atom;
        if (i_atom > 0) {
            cell_end[im1_cell] = i_atom;
        }
    }
    if (i_atom == sorted_positions.size(0) - 1) {
        cell_end[icell] = i_atom + 1;
    }
}

Again, I decided to use a kernel instead of trying to torchify it because I anticipate one can be smart about it to improve batch handling.

> Bringing in Thrust as a dependency also seems unnecessary. It looks like mostly all you're using it for are the min() and max() functions?

I initially used thrust::sort and some containers/allocators from thrust; it seems I forgot to take out some headers. Thrust has many convenience functions, such as min/max, that replace the std alternatives that cannot be used in device code.
I do not see thrust as a dependency since it comes with CUDA: if you can include cuda.h you can include thrust. I could just roll my own thrust::min/max, but why deprive ourselves of the utilities in thrust? It's a good library and it's always there, AFAIK.
Is there a situation in which you have the CUDA headers but not thrust?

@peastman (Collaborator)

I don't understand what the hashes are adding. Ultimately you want to know the list of atoms in a particular voxel. Sorting by hash rather than voxel index just adds a lot of code complexity and runtime overhead for no obvious benefit.

I understand the benefit of hashes in the generic case, when you have arbitrarily sparse data scattered over an arbitrarily large volume of space. You might have billions of voxels, almost all of which are empty. But that's not the case in molecular simulations. The data is evenly distributed over a small area of space. The number of voxels is in the thousands at most, and few are empty.

> Why separate this from the rest of the CUDA implementation?

You can replace a few hundred lines of CUDA code with probably 10-20 lines of Python, and it will be just as fast. This isn't the bottleneck operation.
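For reference, a rough sketch of what those few lines of Python might look like, under the assumption of a cubic box of side box_size centered at the origin and a single batch (illustrative only, not code from this PR):

import torch

def build_cell_list(pos, box_size, cutoff):
    # Sketch of the three steps above (assumes a cubic box centered at the
    # origin and no batches); not code from the PR.
    n_cells_dim = int(box_size // cutoff)      # cells per dimension
    cell_size = box_size / n_cells_dim
    # 1. Voxel index of each atom: shift to [0, box_size) and bin by cell size.
    cell_xyz = ((pos + 0.5 * box_size) / cell_size).long().clamp_(0, n_cells_dim - 1)
    cell_idx = cell_xyz[:, 0] + n_cells_dim * (cell_xyz[:, 1] + n_cells_dim * cell_xyz[:, 2])
    # 2. Sort atoms by voxel index and reorder the positions accordingly.
    sorted_cells, order = torch.sort(cell_idx)
    sorted_pos = pos[order]
    # 3. First atom of each voxel in the sorted list; atoms in voxel i are
    #    sorted_pos[voxel_start[i]:voxel_start[i + 1]].
    n_cells = n_cells_dim ** 3
    voxel_start = torch.searchsorted(sorted_cells, torch.arange(n_cells + 1, device=pos.device))
    return sorted_pos, order, voxel_start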

@RaulPPelaez (Collaborator Author)

> I don't understand what the hashes are adding. Ultimately you want to know the list of atoms in a particular voxel. Sorting by hash rather than voxel index just adds a lot of code complexity and runtime overhead for no obvious benefit.

I believe I did not convey a detail of my strategy correctly: I do not only compute the cell index/hash of each position and sort the indices, I also reorder the positions and batch arrays according to this hash and use these sorted copies when traversing the cell list.

This has a profound impact on performance because it increases data locality, not only in terms of cache but also by increasing access coherence.

This is where the Z-order hash comes into play: since one goes over the 27 neighboring cells, a Z-order increases the chance of neighboring cells being contiguous in memory. A simple linear cell index means one has to jump n.x*n.y elements in the voxelStart array when traversing cells at a different height.

This improved things by like 20% when I first implemented it (on a GTX 980). I checked now on a 4090 and a Titan V and the effect is negligible. I guess cache sizes are crazy now :p. Or maybe the overhead of the atomic addition of neighbor pairs hides any gains from this.
So ok, let's leave the simpler cell index as the hash.
I also tried not including the batch in the hash, which prevents breaking early during traversal based on batch. Surprisingly, skipping the batch check actually increases performance a bit. Maybe loading chunks of other batches helps the cache overall?
I took it out, which allows me to leave the hash as an int32 and use torch::sort. Annoyingly, torch::sort returns the indices in int64, requiring an extra cast. However, this is negligible.
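(For context, the Z-order/Morton hash discussed above is the classic bit-interleaving trick. A plain-Python sketch of a 30-bit 3D Morton encode, 10 bits per axis, could look like the following; it is illustrative only and not the hashMorton kernel code.)

def expand_bits(v: int) -> int:
    # Spread the lowest 10 bits of v so that two zero bits sit between each.
    v = (v * 0x00010001) & 0xFF0000FF
    v = (v * 0x00000101) & 0x0F00F00F
    v = (v * 0x00000011) & 0xC30C30C3
    v = (v * 0x00000005) & 0x49249249
    return v

def morton_hash(cx: int, cy: int, cz: int) -> int:
    # Interleave the bits of three 10-bit cell coordinates into a 30-bit key.
    return (expand_bits(cx) << 2) | (expand_bits(cy) << 1) | expand_bits(cz)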

Now, recording the voxel index of each particle requires some arithmetic and checks, and my pytorch-fu is not enough to see how to transform this into an efficient sequence of torch calls:

/*
 * @brief Get the cell index of a point
 * @param p The point position
 * @param box_size The size of the box in each dimension
 * @param cutoff The cutoff
 * @return The cell index
 */
template <typename scalar_t>
__device__ int3 getCell(scalar3<scalar_t> p, scalar3<scalar_t> box_size, scalar_t cutoff) {
    p = rect::apply_pbc<scalar_t>(p, box_size);
    // Take to the [0, box_size] range and divide by cutoff (which is the cell size)
    int cx = floorf((p.x + scalar_t(0.5) * box_size.x) / cutoff);
    int cy = floorf((p.y + scalar_t(0.5) * box_size.y) / cutoff);
    int cz = floorf((p.z + scalar_t(0.5) * box_size.z) / cutoff);
    int3 cell_dim = getCellDimensions(box_size, cutoff);
    // Wrap around. If the position of a particle is exactly box_size, it will be in the last cell,
    // which results in an illegal access down the line.
    if (cx == cell_dim.x)
        cx = 0;
    if (cy == cell_dim.y)
        cy = 0;
    if (cz == cell_dim.z)
        cz = 0;
    return make_int3(cx, cy, cz);
}

/*
 * @brief Get the index of a cell in a 1D array of cells.
 * @param cell The cell coordinates, assumed to be in the range [0, cell_dim].
 * @param cell_dim The number of cells in each dimension
 */
__device__ int getCellIndex(int3 cell, int3 cell_dim) {
    return cell.x + cell_dim.x * (cell.y + cell_dim.y * cell.z);
}

I am also scared to torchify just this part, since the above device functions are also used when traversing, and translating the construction to torch would require implementing this logic twice.
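(As a rough idea of what such a translation could look like, here is a hedged torch sketch of getCell/getCellIndex assuming a rectangular box given as a (3,) tensor of side lengths centered at the origin; it is illustrative only and indeed duplicates the device-side logic.)

import torch

# Sketch, not from the PR: torch version of getCell/getCellIndex for a
# rectangular box centered at the origin, with box_size a (3,) tensor.
def get_cell_index(pos: torch.Tensor, box_size: torch.Tensor, cutoff: float) -> torch.Tensor:
    # Apply rectangular periodic boundary conditions.
    pos = pos - box_size * torch.round(pos / box_size)
    cell_dim = (box_size / cutoff).to(torch.long)         # cells per dimension
    # Shift to [0, box_size] and divide by the cutoff (the cell size).
    cell = ((pos + 0.5 * box_size) / cutoff).to(torch.long)
    # Wrap the edge case where a coordinate lands exactly on box_size.
    cell = torch.where(cell == cell_dim, torch.zeros_like(cell), cell)
    # Linear index: x + nx * (y + ny * z).
    return cell[:, 0] + cell_dim[0] * (cell[:, 1] + cell_dim[1] * cell[:, 2])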

> But that's not the case in molecular simulations. The data is evenly distributed over a small area of space. The number of voxels is in the thousands at most, and few are empty.

I do not have any hope for the cell list being a viable option for inference on small molecules. I am thinking about dense large systems, like a water box, or some protein with implicit water. Something like 32^3 cells and above. To me, anything below that sounds like an N^2 with extra steps. Maybe I should aim for other stuff?

@raimis (Collaborator) commented May 25, 2023

We need this in production as soon as possible. After discussing with @RaulPPelaez, we decided:

  • The cell list algorithm seems to be good enough for our purpose. Further improvements (if any) will be in a separate PR.
  • The cell list algorithm does not support triclinic cells. We don't need that at the moment. Let's just open an issue as a reminder.

@peastman (Collaborator)

> I am thinking about dense large systems, like a water box, or some protein with implicit water. Something like 32^3 cells and above.

In a water box, there are usually no empty voxels. So anything you do is going to be linear in the number of voxels (which is linear in the number of atoms).

> The cell list algorithm seems to be good enough for our purpose. Further improvements (if any) will be in a separate PR.

Ok. Let me do one review pass through the code first.

> The cell list algorithm does not support triclinic cells. We don't need that at the moment. Let's just open an issue as a reminder.

Agreed.

@@ -77,6 +79,164 @@ def message(self, x_j, W):
return x_j * W


class OptimizedDistance(torch.nn.Module):
Collaborator

Would it make sense to just replace the existing Distance class, instead of adding another class with a different name? Are there cases when someone would prefer the old class instead?

Collaborator Author

The reason I did not do this is that I am wary of some current use of Distance relying on something unexpected (maybe related to the ordering?). I am waiting for regular users of each model to tell me "I trained with the new Distance and everything is ok". Maybe a bit superstitious, but if it is all the same I would rather do another PR to replace the current uses.

Comment on lines +73 to +86
int3 periodic_cell = cell;
if (cell.x < 0)
    periodic_cell.x += cell_dim.x;
if (cell.x >= cell_dim.x)
    periodic_cell.x -= cell_dim.x;
if (cell.y < 0)
    periodic_cell.y += cell_dim.y;
if (cell.y >= cell_dim.y)
    periodic_cell.y -= cell_dim.y;
if (cell.z < 0)
    periodic_cell.z += cell_dim.z;
if (cell.z >= cell_dim.z)
    periodic_cell.z -= cell_dim.z;
return periodic_cell;
Collaborator

This assumes it will never be off by more than one grid width. Is that guaranteed to be true? A safer (and much simpler) implementation is

return make_int3(cell.x%cell_dim.x, cell.y%cell_dim.y, cell.z%cell_dim.z);

Collaborator Author

Careful, in C++ -1%10 is -1, not 9

@RaulPPelaez (Collaborator Author), May 26, 2023

Which is why the previous suggestion for getCell using modulus also does not work when particles are to the left of the main box, ugghhhh.
What am I missing here?

@RaulPPelaez (Collaborator Author)

I am going to merge this now so we can move on. @peastman please feel free to open a PR if you see how to transform some of the kernels into torch ops; I gave it a try but got nowhere.
I will open a new PR as I switch current uses of Distance to OptimizedDistance.

@RaulPPelaez RaulPPelaez merged commit e20876f into torchmd:main May 29, 2023