Skip to content

Accelerate the limited TorchMD_GN with NNPOps #50

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Feb 22, 2022
Merged
4 changes: 2 additions & 2 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ jobs:
key: ${{ runner.os }}-${{ env.CACHE_NUMBER }}-${{ hashFiles('environment.yml') }}
env:
# Increase this value to reset cache if environment.yml has not changed
CACHE_NUMBER: 0
CACHE_NUMBER: 1

- name: Create a conda environment
uses: conda-incubator/setup-miniconda@v2
Expand All @@ -51,4 +51,4 @@ jobs:
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics

- name: Run tests
run: pytest
run: pytest -v
520 changes: 520 additions & 0 deletions benchmarks/graph_network.ipynb

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions environment.yml
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
name: torchmd-net
channels:
- raimis
- mmh
- conda-forge
dependencies:
- ase
- h5py
- matplotlib
# An official NNPOps packages still not available
- mmh::nnpops==0.2
- pip
- python
- pytorch==1.10.0
Expand Down
72 changes: 72 additions & 0 deletions tests/test_cfconv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import pytest
from pytest import mark
import torch as pt
from torchmdnet.models.torchmd_gn import CFConv as RefCFConv
from torchmdnet.models.utils import Distance, GaussianSmearing, ShiftedSoftplus

from NNPOps.CFConv import CFConv
from NNPOps.CFConvNeighbors import CFConvNeighbors

@mark.parametrize('device', ['cpu', 'cuda'])
@mark.parametrize(['num_atoms', 'num_filters', 'num_rbfs'], [(3, 5, 7), (3, 7, 5), (5, 3, 7), (5, 7, 3), (7, 3, 5), (7, 5, 3)])
@mark.parametrize('cutoff_upper', [5.0, 10.0])
def test_cfconv(device, num_atoms, num_filters, num_rbfs, cutoff_upper):

if not pt.cuda.is_available() and device == 'cuda':
pytest.skip('No GPU')

device = pt.device(device)

# Generate random inputs
pos = (10 * pt.rand(num_atoms, 3, dtype=pt.float32, device=device) - 5).detach()
pos.requires_grad = True
input = pt.rand(num_atoms, num_filters, dtype=pt.float32, device=device)

# Construct a non-optimized CFConv object
dist = Distance(0.0, cutoff_upper).to(device)
rbf = GaussianSmearing(0.0, cutoff_upper, num_rbfs, trainable=False).to(device)
net = pt.nn.Sequential(
pt.nn.Linear(num_rbfs, num_filters),
ShiftedSoftplus(),
pt.nn.Linear(num_filters, num_filters))

# Randomize linear layers
net.requires_grad_(False)
pt.nn.init.normal_(net[0].weight)
pt.nn.init.normal_(net[0].bias)
pt.nn.init.normal_(net[2].weight)
pt.nn.init.normal_(net[2].bias)

ref_conv = RefCFConv(num_filters, num_filters, num_filters, net, 0.0, cutoff_upper).to(device)

# Disable the additional linear layers
ref_conv.requires_grad_(False)
ref_conv.lin1.weight.zero_()
ref_conv.lin1.weight.fill_diagonal_(1)
ref_conv.lin2.weight.zero_()
ref_conv.lin2.weight.fill_diagonal_(1)

# Compute with the non-optimized CFConv
edge_index, edge_weight, _ = dist(pos, batch=None)
edge_attr = rbf(edge_weight)
ref_output = ref_conv(input, edge_index, edge_weight, edge_attr)
ref_total = pt.sum(ref_output)
ref_total.backward()
ref_grad = pos.grad.clone()

pos.grad.zero_()

# Construct an optimize CFConv object
gaussianWidth = rbf.offset[1] - rbf.offset[0]
neighbors = CFConvNeighbors(cutoff_upper)
conv = CFConv(gaussianWidth, 'ssp', net[0].weight.T, net[0].bias, net[2].weight.T, net[2].bias)

# Compute with the optimized CFConv
neighbors.build(pos)
output = conv(neighbors, pos, input)
total = pt.sum(output)
total.backward()
grad = pos.grad.clone()

assert pt.allclose(ref_output, output, atol=5e-7)
assert pt.allclose(ref_grad, grad, atol=5e-7)
54 changes: 54 additions & 0 deletions tests/test_optimize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import pytest
from pytest import mark
import torch as pt
from torchmdnet.models.model import create_model
from torchmdnet.optimize import optimize

@mark.parametrize('device', ['cpu', 'cuda'])
@mark.parametrize('num_atoms', [10, 100])
def test_gn(device, num_atoms):

if not pt.cuda.is_available() and device == 'cuda':
pytest.skip('No GPU')

device = pt.device(device)

# Generate random inputs
elements = pt.randint(1, 100, (num_atoms,)).to(device)
positions = (10 * pt.rand((num_atoms, 3)) - 5).to(device)

# Crate a non-optimized model
# SchNet: TorchMD_GN(rbf_type='gauss', trainable_rbf=False, activation='ssp', neighbor_embedding=False)
args = {
'embedding_dimension': 128,
'num_layers': 6,
'num_rbf': 50,
'rbf_type': 'gauss',
'trainable_rbf': False,
'activation': 'ssp',
'neighbor_embedding': False,
'cutoff_lower': 0.0,
'cutoff_upper': 5.0,
'max_z': 100,
'max_num_neighbors': num_atoms,
'model': 'graph-network',
'aggr': 'add',
'derivative': True,
'atom_filter': -1,
'prior_model': None,
'output_model': 'Scalar',
'reduce_op': 'add'
}
ref_model = create_model(args).to(device)

# Execute the non-optimized model
ref_energy, ref_gradient = ref_model(elements, positions)

# Optimize the model
model = optimize(ref_model).to(device)

# Execute the optimize model
energy, gradient = model(elements, positions)

assert pt.allclose(ref_energy, energy, atol=5e-7)
assert pt.allclose(ref_gradient, gradient, atol=1e-5)
67 changes: 67 additions & 0 deletions torchmdnet/optimize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import torch as pt
from NNPOps.CFConv import CFConv
from NNPOps.CFConvNeighbors import CFConvNeighbors

from .models.model import TorchMD_Net
from .models.torchmd_gn import TorchMD_GN


class TorchMD_GN_optimized(pt.nn.Module):

def __init__(self, model):

if model.rbf_type != 'gauss':
raise ValueError('Only rbf_type="gauss" is supproted')
if model.trainable_rbf:
raise ValueError('trainalbe_rbf=True is not supported')
if model.activation != 'ssp':
raise ValueError('Only activation="ssp" is supported')
if model.neighbor_embedding:
raise ValueError('neighbor_embedding=True is not supported')
if model.cutoff_lower != 0.0:
raise ValueError('Only lower_cutoff=0.0 is supported')
if model.aggr != 'add':
raise ValueError('Only aggr="add" is supported')

super().__init__()
self.model = model

self.neighbors = CFConvNeighbors(self.model.cutoff_upper)

offset = self.model.distance_expansion.offset
width = offset[1] - offset[0]
self.convs = [CFConv(gaussianWidth=width, activation='ssp',
weights1=inter.mlp[0].weight.T, biases1=inter.mlp[0].bias,
weights2=inter.mlp[2].weight.T, biases2=inter.mlp[2].bias)
for inter in self.model.interactions]

def forward(self, z, pos, batch):

assert pt.all(batch == 0)

x = self.model.embedding(z)

self.neighbors.build(pos)
for inter, conv in zip(self.model.interactions, self.convs):
y = inter.conv.lin1(x)
y = conv(self.neighbors, pos, y)
y = inter.conv.lin2(y)
y = inter.act(y)
x = x + inter.lin(y)

return x, None, z, pos, batch

def __repr__(self):
return 'Optimized: ' + repr(self.model)


def optimize(model):

assert isinstance(model, TorchMD_Net)

if isinstance(model.representation_model, TorchMD_GN):
model.representation_model = TorchMD_GN_optimized(model.representation_model)
else:
raise ValueError('Unsupported model! Only TorchMD_GN is suppored.')

return model