-
-
Notifications
You must be signed in to change notification settings - Fork 65
Initial commit of pivoted Cholesky algorithm from GPyTorch #63
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
24 commits
Select commit
Hold shift + click to select a range
3fd065e
Initial commit of the linear conjugate gradients
kunalghosh e1cf869
Updated __init__ so that we can import linear_cg
kunalghosh 680afa6
Initial commit of pivoted cholesky
kunalghosh c42bb35
Fixed the name of pivoted cholesky function
kunalghosh 21ed2d5
Since we are invoking the function as a class attribute, removing the…
kunalghosh 9290e1f
Merge branch 'linearCG' into pivoted_cholesky
kunalghosh 021a3e7
Removing linear conjugate gradients from this branch
kunalghosh bbe525b
Also removing linear cg from __init__
kunalghosh 9ed06f3
Fixed the import
kunalghosh 2d5e052
Adding dependencies for pivoted_cholesky, they are also needed for tests
kunalghosh cd957c8
Added correct package for pytorch
kunalghosh f0b4855
Added try...catch block to ensure users who don't have these packages…
kunalghosh f3a13f5
Added import checks in the test file as well
kunalghosh 2a12943
Merge branch 'pymc-devs:main' into pivoted_cholesky
kunalghosh 03b7bf7
removed unused commits
kunalghosh 9009427
removing pytorch and gpytorch from requirements.
kunalghosh 852d95c
pre-commit wouldn't let me commit print statements
kunalghosh 7d1fb4e
removing the test for now
kunalghosh 8aba7b0
Raising an ImportError instead of printing
kunalghosh 9b8a132
Addressing stylistic comment
kunalghosh 685b99d
Resolving merge conflict
kunalghosh 3676a51
formatting fixes
kunalghosh 4a9a8c0
formatting fixes
kunalghosh 96229a8
pre-commit modifications
kunalghosh File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
# try: | ||
# import gpytorch | ||
# import torch | ||
# except ImportError as e: | ||
# # print( | ||
# # f"Please install Pytorch and GPyTorch to use this pivoted Cholesky implementation. Error {e}" | ||
# # ) | ||
# pass | ||
# import numpy as np | ||
# | ||
# import pymc_experimental as pmx | ||
# | ||
# | ||
# def test_match_gpytorch_linearcg_output(): | ||
# N = 10 | ||
# rank = 5 | ||
# np.random.seed(1234) # nans with seed 1234 | ||
# K = np.random.randn(N, N) | ||
# K = K @ K.T + N * np.eye(N) | ||
# K_torch = torch.from_numpy(K) | ||
# | ||
# L_gpt = gpytorch.pivoted_cholesky(K_torch, rank=rank, error_tol=1e-3) | ||
# L_np, _ = pmx.utils.pivoted_cholesky(K, max_iter=rank, error_tol=1e-3) | ||
# assert np.allclose(L_gpt, L_np.T) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
try: | ||
import torch | ||
from gpytorch.utils.permutation import apply_permutation | ||
except ImportError as e: | ||
raise ImportError("PyTorch and GPyTorch not found.") from e | ||
|
||
import numpy as np | ||
|
||
pp = lambda x: np.array2string(x, precision=4, floatmode="fixed") | ||
|
||
|
||
def pivoted_cholesky(mat: np.matrix, error_tol=1e-6, max_iter=np.inf): | ||
""" | ||
mat: numpy matrix of N x N | ||
|
||
This is to replicate what is done in GPyTorch verbatim. | ||
""" | ||
n = mat.shape[-1] | ||
max_iter = min(int(max_iter), n) | ||
|
||
d = np.array(np.diag(mat)) | ||
orig_error = np.max(d) | ||
error = np.linalg.norm(d, 1) / orig_error | ||
pi = np.arange(n) | ||
|
||
L = np.zeros((max_iter, n)) | ||
|
||
m = 0 | ||
while m < max_iter and error > error_tol: | ||
permuted_d = d[pi] | ||
max_diag_idx = np.argmax(permuted_d[m:]) | ||
max_diag_idx = max_diag_idx + m | ||
max_diag_val = permuted_d[max_diag_idx] | ||
i = max_diag_idx | ||
|
||
# swap pi_m and pi_i | ||
pi[m], pi[i] = pi[i], pi[m] | ||
pim = pi[m] | ||
|
||
L[m, pim] = np.sqrt(max_diag_val) | ||
|
||
if m + 1 < n: | ||
row = apply_permutation( | ||
torch.from_numpy(mat), torch.tensor(pim), right_permutation=None | ||
) # left permutation just swaps row | ||
row = row.numpy().flatten() | ||
pi_i = pi[m + 1 :] | ||
L_m_new = row[pi_i] # length = 9 | ||
|
||
if m > 0: | ||
L_prev = L[:m, pi_i] | ||
update = L[:m, pim] | ||
prod = update @ L_prev | ||
L_m_new = L_m_new - prod # np.sum(prod, axis=-1) | ||
|
||
L_m = L[m, :] | ||
L_m_new = L_m_new / L_m[pim] | ||
L_m[pi_i] = L_m_new | ||
|
||
matrix_diag_current = d[pi_i] | ||
d[pi_i] = matrix_diag_current - L_m_new**2 | ||
|
||
L[m, :] = L_m | ||
error = np.linalg.norm(d[pi_i], 1) / orig_error | ||
m = m + 1 | ||
return L, pi |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why you comment the test?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This was triggering the import of
GPyTorch
andTorch
which would cause the tests to fail. I thought it would be best to comment out the tests until I remove the dependency of my code on GPyTorch and Torch.