From 37483fa24ff7541d4d9c1b94ed5de6bd04f724bc Mon Sep 17 00:00:00 2001 From: Ruunyox Date: Sun, 18 Jul 2021 22:45:51 +0200 Subject: [PATCH 1/2] added alpha scaling in expnorm and visualization function --- torchmdnet/models/utils.py | 28 +++++++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/torchmdnet/models/utils.py b/torchmdnet/models/utils.py index 37c306cad..4eeffedbe 100644 --- a/torchmdnet/models/utils.py +++ b/torchmdnet/models/utils.py @@ -5,6 +5,31 @@ from torch_geometric.nn import radius_graph, MessagePassing +def visualize_basis(basis_type, num_rbf=50, cutoff_lower=0, cutoff_upper=5): + """ + Function for quickly visualizing a specific basis. This is useful for inspecting + the distance coverage of basis functions for non-default lower and upper cutoffs. + + Args: + basis_type (str): Specifies the type of basis functions used. Can be one of + ['gauss',expnorm'] + num_rbf (int, optional): The number of basis functions. + (default: :obj:`50`) + cutoff_lower (float, optional): The lower cutoff of the basis. + (default: :obj:`0`) + cutoff_upper (float, optional): The upper cutoff of the basis. + (default: :obj:`5`) + """ + distances = torch.linspace(cutoff_lower-1, cutoff_upper+1, 1000) + basis_kwargs = {'num_rbf':num_rbf, 'cutoff_lower':cutoff_lower, 'cutoff_upper':cutoff_upper} + basis_expansion = rbf_class_mapping[basis_type](**basis_kwargs) + expanded_distances = basis_expansion(distances) + + for i in range(expanded_distances.shape[-1]): + plt.plot(distances.numpy(), expanded_distances[:,i].detach().numpy()) + plt.show() + + class NeighborEmbedding(MessagePassing): def __init__(self, hidden_channels, num_rbf, cutoff_lower, cutoff_upper, max_z=100): @@ -83,6 +108,7 @@ def __init__(self, cutoff_lower=0.0, cutoff_upper=5.0, num_rbf=50, trainable=Tru self.trainable = trainable self.cutoff_fn = CosineCutoff(0, cutoff_upper) + self.alpha = 5.0/(cutoff_upper - cutoff_lower) means, betas = self._initial_params() if trainable: @@ -107,7 +133,7 @@ def reset_parameters(self): def forward(self, dist): dist = dist.unsqueeze(-1) - return self.cutoff_fn(dist) * torch.exp(-self.betas * (torch.exp(-dist + self.cutoff_lower) - self.means) ** 2) + return self.cutoff_fn(dist) * torch.exp(-self.betas * (torch.exp(self.alpha*(-dist + self.cutoff_lower)) - self.means) ** 2) class ShiftedSoftplus(nn.Module): From c73ed5b82acce2116638867113b53095518e93a3 Mon Sep 17 00:00:00 2001 From: Ruunyox Date: Mon, 19 Jul 2021 13:51:00 +0200 Subject: [PATCH 2/2] visualize-basis plt import --- torchmdnet/models/utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torchmdnet/models/utils.py b/torchmdnet/models/utils.py index 4eeffedbe..a5cd9e7bb 100644 --- a/torchmdnet/models/utils.py +++ b/torchmdnet/models/utils.py @@ -20,6 +20,8 @@ def visualize_basis(basis_type, num_rbf=50, cutoff_lower=0, cutoff_upper=5): cutoff_upper (float, optional): The upper cutoff of the basis. (default: :obj:`5`) """ + import matplotlib.pyplot as plt + distances = torch.linspace(cutoff_lower-1, cutoff_upper+1, 1000) basis_kwargs = {'num_rbf':num_rbf, 'cutoff_lower':cutoff_lower, 'cutoff_upper':cutoff_upper} basis_expansion = rbf_class_mapping[basis_type](**basis_kwargs)