Skip to content

Initial commit of pivoted Cholesky algorithm from GPyTorch #63

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 24 commits into from
Oct 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
3fd065e
Initial commit of the linear conjugate gradients
kunalghosh Aug 2, 2022
e1cf869
Updated __init__ so that we can import linear_cg
kunalghosh Aug 2, 2022
680afa6
Initial commit of pivoted cholesky
kunalghosh Aug 2, 2022
c42bb35
Fixed the name of pivoted cholesky function
kunalghosh Aug 2, 2022
21ed2d5
Since we are invoking the function as a class attribute, removing the…
kunalghosh Aug 2, 2022
9290e1f
Merge branch 'linearCG' into pivoted_cholesky
kunalghosh Aug 2, 2022
021a3e7
Removing linear conjugate gradients from this branch
kunalghosh Aug 2, 2022
bbe525b
Also removing linear cg from __init__
kunalghosh Aug 2, 2022
9ed06f3
Fixed the import
kunalghosh Aug 2, 2022
2d5e052
Adding dependencies for pivoted_cholesky, they are also needed for tests
kunalghosh Aug 3, 2022
cd957c8
Added correct package for pytorch
kunalghosh Aug 11, 2022
f0b4855
Added try...catch block to ensure users who don't have these packages…
kunalghosh Sep 27, 2022
f3a13f5
Added import checks in the test file as well
kunalghosh Sep 27, 2022
2a12943
Merge branch 'pymc-devs:main' into pivoted_cholesky
kunalghosh Sep 27, 2022
03b7bf7
removed unused commits
kunalghosh Sep 27, 2022
9009427
removing pytorch and gpytorch from requirements.
kunalghosh Sep 27, 2022
852d95c
pre-commit wouldn't let me commit print statements
kunalghosh Sep 27, 2022
7d1fb4e
removing the test for now
kunalghosh Sep 27, 2022
8aba7b0
Raising an ImportError instead of printing
kunalghosh Oct 14, 2022
9b8a132
Addressing stylistic comment
kunalghosh Oct 14, 2022
685b99d
Resolving merge conflict
kunalghosh Oct 14, 2022
3676a51
formatting fixes
kunalghosh Oct 14, 2022
4a9a8c0
formatting fixes
kunalghosh Oct 14, 2022
96229a8
pre-commit modifications
kunalghosh Oct 17, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions pymc_experimental/tests/test_pivoted_cholesky.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# try:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why you comment the test?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was triggering the import of GPyTorch and Torch which would cause the tests to fail. I thought it would be best to comment out the tests until I remove the dependency of my code on GPyTorch and Torch.

# import gpytorch
# import torch
# except ImportError as e:
# # print(
# # f"Please install Pytorch and GPyTorch to use this pivoted Cholesky implementation. Error {e}"
# # )
# pass
# import numpy as np
#
# import pymc_experimental as pmx
#
#
# def test_match_gpytorch_linearcg_output():
# N = 10
# rank = 5
# np.random.seed(1234) # nans with seed 1234
# K = np.random.randn(N, N)
# K = K @ K.T + N * np.eye(N)
# K_torch = torch.from_numpy(K)
#
# L_gpt = gpytorch.pivoted_cholesky(K_torch, rank=rank, error_tol=1e-3)
# L_np, _ = pmx.utils.pivoted_cholesky(K, max_iter=rank, error_tol=1e-3)
# assert np.allclose(L_gpt, L_np.T)
2 changes: 2 additions & 0 deletions pymc_experimental/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,5 @@

from pymc_experimental.utils import prior, spline
from pymc_experimental.utils.linear_cg import linear_cg

# from pymc_experimental.utils.pivoted_cholesky import pivoted_cholesky
66 changes: 66 additions & 0 deletions pymc_experimental/utils/pivoted_cholesky.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
try:
import torch
from gpytorch.utils.permutation import apply_permutation
except ImportError as e:
raise ImportError("PyTorch and GPyTorch not found.") from e

import numpy as np

pp = lambda x: np.array2string(x, precision=4, floatmode="fixed")


def pivoted_cholesky(mat: np.matrix, error_tol=1e-6, max_iter=np.inf):
"""
mat: numpy matrix of N x N

This is to replicate what is done in GPyTorch verbatim.
"""
n = mat.shape[-1]
max_iter = min(int(max_iter), n)

d = np.array(np.diag(mat))
orig_error = np.max(d)
error = np.linalg.norm(d, 1) / orig_error
pi = np.arange(n)

L = np.zeros((max_iter, n))

m = 0
while m < max_iter and error > error_tol:
permuted_d = d[pi]
max_diag_idx = np.argmax(permuted_d[m:])
max_diag_idx = max_diag_idx + m
max_diag_val = permuted_d[max_diag_idx]
i = max_diag_idx

# swap pi_m and pi_i
pi[m], pi[i] = pi[i], pi[m]
pim = pi[m]

L[m, pim] = np.sqrt(max_diag_val)

if m + 1 < n:
row = apply_permutation(
torch.from_numpy(mat), torch.tensor(pim), right_permutation=None
) # left permutation just swaps row
row = row.numpy().flatten()
pi_i = pi[m + 1 :]
L_m_new = row[pi_i] # length = 9

if m > 0:
L_prev = L[:m, pi_i]
update = L[:m, pim]
prod = update @ L_prev
L_m_new = L_m_new - prod # np.sum(prod, axis=-1)

L_m = L[m, :]
L_m_new = L_m_new / L_m[pim]
L_m[pi_i] = L_m_new

matrix_diag_current = d[pi_i]
d[pi_i] = matrix_diag_current - L_m_new**2

L[m, :] = L_m
error = np.linalg.norm(d[pi_i], 1) / orig_error
m = m + 1
return L, pi