Skip to content

POC: neighbor search #61

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 29 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ channels:
dependencies:
- ase
- h5py
- gxx_linux-64
- matplotlib
# An official NNPOps packages still not available
- mmh::nnpops==0.2
Expand Down
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,7 @@
name="torchmd-net",
version=version,
packages=find_packages(),
package_data={"torchmdnet": ["neighbors/neighbors*"]},
include_package_data=True,
install_requires=requirements,
)
30 changes: 30 additions & 0 deletions tests/test_neighbors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from turtle import pos
import numpy as np
import pytest
import torch as pt
from torchmdnet.neighbors import get_neighbor_pairs

@pytest.mark.parametrize('num_atoms', [1, 2, 3, 4, 5, 10, 100, 1000, 10000])
@pytest.mark.parametrize('device', ['cpu', 'cuda'])
def test_neighbors(num_atoms, device):

if not pt.cuda.is_available() and device == 'cuda':
pytest.skip('No GPU')

positions = pt.randn((num_atoms, 3), device=device)
device = positions.device

ref_neighbors = np.tril_indices(num_atoms, -1)
ref_positions = positions.cpu().numpy()
ref_distances = np.linalg.norm(ref_positions[ref_neighbors[0]] - ref_positions[ref_neighbors[1]], axis=1)

neighbors, distances = get_neighbor_pairs(positions)

assert neighbors.device == device
assert distances.device == device

assert neighbors.dtype == pt.int32
assert distances.dtype == pt.float32

assert np.all(ref_neighbors == neighbors.cpu().numpy())
assert np.allclose(ref_distances, distances.cpu().numpy())
9 changes: 9 additions & 0 deletions torchmdnet/neighbors/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import os
import torch as pt
from torch.utils import cpp_extension

sources = ['neighbors.cpp', 'neighbors_cpu.cpp'] + (['neighbors_cuda.cu'] if pt.cuda.is_available() else [])
sources = [os.path.join(os.path.dirname(__file__), name) for name in sources]

cpp_extension.load(name='neighbors', sources=sources, is_python_module=False)
get_neighbor_pairs = pt.ops.neighbors.get_neighbor_pairs
166 changes: 166 additions & 0 deletions torchmdnet/neighbors/demo.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Demonstation of the neigbor search operation"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Compile and import"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch as pt\n",
"from torchmdnet.neighbors import get_neighbor_pairs"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Run"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"pos = pt.tensor([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Forward"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(tensor([[1, 2, 2],\n",
" [0, 0, 1]], dtype=torch.int32),\n",
" tensor([ 5.1962, 10.3923, 5.1962]))"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"get_neighbor_pairs(pos.to('cpu'))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(tensor([[1, 2, 2],\n",
" [0, 0, 1]], device='cuda:0', dtype=torch.int32),\n",
" tensor([ 5.1962, 10.3923, 5.1962], device='cuda:0'))"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"get_neighbor_pairs(pos.to('cuda'))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Forward and backward"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[-1.1547, -1.1547, -1.1547],\n",
" [ 0.0000, 0.0000, 0.0000],\n",
" [ 1.1547, 1.1547, 1.1547]])"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pos_ = pos.to('cpu').detach()\n",
"pos_.requires_grad = True\n",
"res = get_neighbor_pairs(pos_)\n",
"res[1].sum().backward()\n",
"pos_.grad"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"pos_ = pos.to('cuda').detach()\n",
"pos_.requires_grad = True\n",
"res = get_neighbor_pairs(pos_)\n",
"res[1].sum().backward()\n",
"pos_.grad"
]
}
],
"metadata": {
"interpreter": {
"hash": "475250b3cf807ed3cfcdd5b2e8760b9c29416cad54cd9aa950ed2b92e2b64699"
},
"kernelspec": {
"display_name": "Python 3.9.10 ('torchmd-net')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.10"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
5 changes: 5 additions & 0 deletions torchmdnet/neighbors/neighbors.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#include <torch/extension.h>

TORCH_LIBRARY(neighbors, m) {
m.def("get_neighbor_pairs(Tensor positions, Scalar cutoff, Scalar max_num_neighbors) -> (Tensor neighbors, Tensor distances)");
}
40 changes: 40 additions & 0 deletions torchmdnet/neighbors/neighbors_cpu.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
#include <torch/extension.h>
#include <tuple>

using std::tuple;
using torch::div;
using torch::index_select;
using torch::arange;
using torch::frobenius_norm;
using torch::kInt32;
using torch::Scalar;
using torch::stack;
using torch::Tensor;

static tuple<Tensor, Tensor> forward(const Tensor& positions,
const Scalar& cutoff,
const Scalar& max_num_neighbors) {

TORCH_CHECK(positions.dim() == 2, "Expected \"positions\" to have two dimensions");
TORCH_CHECK(positions.size(0) > 0, "Expected the 1nd dimension size of \"positions\" to be more than 0");
TORCH_CHECK(positions.size(1) == 3, "Expected the 2nd dimension size of \"positions\" to be 3");
TORCH_CHECK(positions.is_contiguous(), "Expected \"positions\" to be contiguous");

const int num_atoms = positions.size(0);
const int num_pairs = num_atoms * (num_atoms - 1) / 2;

const Tensor indices = arange(0, num_pairs, positions.options().dtype(kInt32));
Tensor rows = (((8 * indices + 1).sqrt() + 1) / 2).floor().to(kInt32);
rows -= (rows * (rows - 1) > 2 * indices).to(kInt32);
const Tensor columns = indices - div(rows * (rows - 1), 2, "floor");

const Tensor neighbors = stack({rows, columns});
const Tensor vectors = index_select(positions, 0, rows) - index_select(positions, 0, columns);
const Tensor distances = frobenius_norm(vectors, 1);

return {neighbors, distances};
}

TORCH_LIBRARY_IMPL(neighbors, CPU, m) {
m.impl("get_neighbor_pairs", &forward);
}
Loading