Add NNCG to optimizers submodule #1661

Open · wants to merge 25 commits into master
Commits (25)
d0e22bb  Add NNCG to optimizers submodule (pratikrathore8, Feb 22, 2024)
3afed9f  Update nys_newton_cg.py (pratikrathore8, Feb 23, 2024)
081d5f6  Moved NNCG to pytorch folder (pratikrathore8, Feb 27, 2024)
03a77a1  Minor formatting changes in NNCG (pratikrathore8, Mar 1, 2024)
88d2f7e  Update nys_newton_cg.py (pratikrathore8, Mar 4, 2024)
fff6a91  Fix Codacy issues (pratikrathore8, Mar 4, 2024)
19490ea  Fix more Codacy issues (pratikrathore8, Mar 4, 2024)
ec59a99  Added NNCG to config.py and optimizers.py (pratikrathore8, Mar 11, 2024)
8995aad  Clean up NNCG integration in optimizers.py (pratikrathore8, Mar 13, 2024)
1b13a08  Fixed import order in optimizers.py (pratikrathore8, Mar 15, 2024)
2d63ba3  Made demo with NNCG and Burgers equation (pratikrathore8, Apr 8, 2024)
c356c90  Merge branch 'lululxvi:master' into master (pratikrathore8, Oct 28, 2024)
8d0210a  refactor nncg integration (pratikrathore8, Oct 29, 2024)
5c9beba  clean up commented code (pratikrathore8, Oct 29, 2024)
ce1dcc8  format with black (pratikrathore8, Oct 29, 2024)
d97ca16  remove unnecessary error checks (pratikrathore8, Oct 29, 2024)
81c4452  fix some codacy issues in nncg (pratikrathore8, Oct 29, 2024)
60d3ff3  further improvements to nncg integration (pratikrathore8, Oct 30, 2024)
6e739e1  add train_step_nncg (pratikrathore8, Oct 31, 2024)
f4e322e  improve documentation in nncg config (pratikrathore8, Nov 4, 2024)
6338f59  added doc for nncg demo (pratikrathore8, Nov 5, 2024)
838eaa8  added demo file to pinn_forward.rst (pratikrathore8, Nov 5, 2024)
e0bb44d  Merge branch 'master' into master (pratikrathore8, Nov 7, 2024)
61f08f9  Merge branch 'lululxvi:master' into master (pratikrathore8, Nov 13, 2024)
7d58317  change ordering in pinn_forward.rst (pratikrathore8, Nov 13, 2024)
312 changes: 312 additions & 0 deletions deepxde/optimizers/pytorch/nys_newton_cg.py
@@ -0,0 +1,312 @@
import torch
from torch.optim import Optimizer
from torch.func import vmap
from functools import reduce


def _armijo(f, x, gx, dx, t, alpha=0.1, beta=0.5):
"""Line search to find a step size that satisfies the Armijo condition."""
    f0 = f(x, 0, dx)  # Objective value at the current iterate
    f1 = f(x, t, dx)  # Objective value after a step of size t along dx
    # Backtrack until f(x + t * dx) <= f(x) + alpha * t * <gx, dx>
    while f1 > f0 + alpha * t * gx.dot(dx):
t *= beta
f1 = f(x, t, dx)
return t


def _apply_nys_precond_inv(U, S_mu_inv, mu, lambd_r, x):
"""Applies the inverse of the Nystrom approximation of the Hessian to a vector."""
z = U.T @ x
z = (lambd_r + mu) * (U @ (S_mu_inv * z)) + (x - U @ z)
return z


def _nystrom_pcg(hess, b, x, mu, U, S, r, tol, max_iters):
"""Solves a positive-definite linear system using NyströmPCG.

`Frangella et al. Randomized Nyström Preconditioning.
SIAM Journal on Matrix Analysis and Applications, 2023.
<https://epubs.siam.org/doi/10.1137/21M1466244>`"""
    lambd_r = S[r - 1]  # r-th largest eigenvalue of the Nyström approximation
    S_mu_inv = (S + mu) ** (-1)

    # Residual of the damped system (hess + mu * I) @ x = b at the warm start x,
    # and its preconditioned counterpart
    resid = b - (hess(x) + mu * x)
    with torch.no_grad():
        z = _apply_nys_precond_inv(U, S_mu_inv, mu, lambd_r, resid)
        p = z.clone()

i = 0

while torch.norm(resid) > tol and i < max_iters:
v = hess(p) + mu * p
with torch.no_grad():
alpha = torch.dot(resid, z) / torch.dot(p, v)
x += alpha * p

rTz = torch.dot(resid, z)
resid -= alpha * v
z = _apply_nys_precond_inv(U, S_mu_inv, mu, lambd_r, resid)
beta = torch.dot(resid, z) / rTz

p = z + beta * p

i += 1

if torch.norm(resid) > tol:
print(
f"Warning: PCG did not converge to tolerance. Tolerance was {tol} but norm of residual is {torch.norm(resid)}"
)

return x


class NysNewtonCG(Optimizer):
"""Implementation of NysNewtonCG, a damped Newton-CG method that uses Nyström preconditioning.

`Rathore et al. Challenges in Training PINNs: A Loss Landscape Perspective.
Preprint, 2024. <https://arxiv.org/abs/2402.01868>`

.. warning::
This optimizer doesn't support per-parameter options and parameter
groups (there can be only one).

NOTE: This optimizer is currently a beta version.

Our implementation is inspired by the PyTorch implementation of `L-BFGS
<https://pytorch.org/docs/stable/_modules/torch/optim/lbfgs.html#LBFGS>`.

    The parameters ``rank`` and ``mu`` will probably need to be tuned for your
    specific problem. If the optimizer is running very slowly, you can try one of
    the following (see the usage sketch below):

    - Increase the rank (this should increase the accuracy of the Nyström
      approximation in PCG)
    - Reduce cg_tol (this will allow PCG to terminate with a less accurate solution)
    - Reduce cg_max_iters (this will allow PCG to terminate after fewer iterations)

Args:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float, optional): learning rate (default: 1.0)
rank (int, optional): rank of the Nyström approximation (default: 10)
mu (float, optional): damping parameter (default: 1e-4)
chunk_size (int, optional): number of Hessian-vector products to be computed in parallel (default: 1)
cg_tol (float, optional): tolerance for PCG (default: 1e-16)
cg_max_iters (int, optional): maximum number of PCG iterations (default: 1000)
line_search_fn (str, optional): either 'armijo' or None (default: None)
verbose (bool, optional): verbosity (default: False)
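
    Illustrative usage (a minimal sketch; ``model``, ``inputs``, ``targets``,
    ``loss_fn``, and ``num_steps`` are placeholder names, not part of this
    module)::

        optimizer = NysNewtonCG(model.parameters(), rank=20, mu=1e-4)

        def closure():
            loss = loss_fn(model(inputs), targets)
            grads = torch.autograd.grad(loss, model.parameters(), create_graph=True)
            return loss, grads

        for _ in range(num_steps):
            _, grads = closure()
            # The preconditioner does not need to be refreshed at every step.
            optimizer.update_preconditioner(grads)
            optimizer.step(closure)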

"""

def __init__(
self,
params,
lr=1.0,
rank=10,
mu=1e-4,
chunk_size=1,
cg_tol=1e-16,
cg_max_iters=1000,
line_search_fn=None,
verbose=False,
):
defaults = dict(
lr=lr,
rank=rank,
chunk_size=chunk_size,
mu=mu,
cg_tol=cg_tol,
cg_max_iters=cg_max_iters,
line_search_fn=line_search_fn,
)
self.rank = rank
self.mu = mu
self.chunk_size = chunk_size
self.cg_tol = cg_tol
self.cg_max_iters = cg_max_iters
self.line_search_fn = line_search_fn
self.verbose = verbose
self.U = None
self.S = None
self.n_iters = 0
super(NysNewtonCG, self).__init__(params, defaults)

if len(self.param_groups) > 1:
raise ValueError(
"NysNewtonCG doesn't currently support per-parameter options (parameter groups)"
)

if self.line_search_fn is not None and self.line_search_fn != "armijo":
raise ValueError("NysNewtonCG only supports Armijo line search")

self._params = self.param_groups[0]["params"]
self._params_list = list(self._params)
self._numel_cache = None

def step(self, closure=None):
"""Perform a single optimization step.

        Args:
            closure (callable): A closure that reevaluates the model and returns
                (i) the loss and (ii) the gradient w.r.t. the parameters. The
                closure can compute the gradient w.r.t. the parameters by calling
                torch.autograd.grad on the loss with create_graph=True.
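
        A valid closure looks like this sketch (``model`` and ``compute_loss`` are
        placeholder names, not part of this module)::

            def closure():
                loss = compute_loss()
                grad_tuple = torch.autograd.grad(
                    loss, model.parameters(), create_graph=True
                )
                return loss, grad_tuple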
"""
if self.n_iters == 0:
# Store the previous direction for warm starting PCG
self.old_dir = torch.zeros(self._numel(), device=self._params[0].device)

# NOTE: The closure must return both the loss and the gradient
loss = None
if closure is not None:
with torch.enable_grad():
loss, grad_tuple = closure()

g = torch.cat([grad.view(-1) for grad in grad_tuple if grad is not None])

# One step update
for group_idx, group in enumerate(self.param_groups):

def hvp_temp(x):
return self._hvp(g, self._params_list, x)

# Calculate the Newton direction
d = _nystrom_pcg(
hvp_temp,
g,
self.old_dir,
self.mu,
self.U,
self.S,
self.rank,
self.cg_tol,
self.cg_max_iters,
)

# Store the previous direction for warm starting PCG
self.old_dir = d

# Check if d is a descent direction
if torch.dot(d, g) <= 0:
print("Warning: d is not a descent direction")

if self.line_search_fn == "armijo":
x_init = self._clone_param()

def obj_func(x, t, dx):
self._add_grad(t, dx)
loss = float(closure()[0])
self._set_param(x)
return loss

                # Search along -d, since the parameter update below subtracts t * d
t = _armijo(obj_func, x_init, g, -d, group["lr"])
else:
t = group["lr"]

self.state[group_idx]["t"] = t

# update parameters
ls = 0
for p in group["params"]:
np = torch.numel(p)
dp = d[ls : ls + np].view(p.shape)
ls += np
p.data.add_(-dp, alpha=t)

self.n_iters += 1

return loss, g

def update_preconditioner(self, grad_tuple):
"""Update the Nystrom approximation of the Hessian.

Args:
grad_tuple (tuple): tuple of Tensors containing the gradients of the loss w.r.t. the parameters.
This tuple can be obtained by calling torch.autograd.grad on the loss with create_graph=True.
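
        For example (``loss``, ``model``, and ``optimizer`` are placeholders for
        the loss tensor, the network, and this NysNewtonCG instance)::

            grad_tuple = torch.autograd.grad(
                loss, model.parameters(), create_graph=True
            )
            optimizer.update_preconditioner(grad_tuple)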
"""

# Flatten and concatenate the gradients
gradsH = torch.cat(
[gradient.view(-1) for gradient in grad_tuple if gradient is not None]
)

        # Generate the random test matrix (NOTE: this is the transposed test matrix)
p = gradsH.shape[0]
Phi = torch.randn((self.rank, p), device=gradsH.device) / (p**0.5)
Phi = torch.linalg.qr(Phi.t(), mode="reduced")[0].t()

Y = self._hvp_vmap(gradsH, self._params_list)(Phi)

# Calculate shift
shift = torch.finfo(Y.dtype).eps
Y_shifted = Y + shift * Phi

# Calculate Phi^T * H * Phi (w/ shift) for Cholesky
choleskytarget = torch.mm(Y_shifted, Phi.t())

        # Perform Cholesky; if it fails, do an eigendecomposition instead.
        # The new shift is the absolute value of the smallest (negative) eigenvalue
        # plus the original shift.
        try:
            C = torch.linalg.cholesky(choleskytarget)
        except RuntimeError:
            # Eigendecomposition: eigenvalues and eigenvector matrix
            eigs, eigvectors = torch.linalg.eigh(choleskytarget)
            shift = shift + torch.abs(torch.min(eigs))
            # Add the shift to the eigenvalues
            eigs = eigs + shift
            # Rebuild the matrix for Cholesky:
            # eigvectors @ diag(shifted eigenvalues) @ eigvectors^T
            C = torch.linalg.cholesky(
                torch.mm(eigvectors, torch.mm(torch.diag(eigs), eigvectors.T))
            )

        try:
            B = torch.linalg.solve_triangular(C, Y_shifted, upper=False, left=True)
        except RuntimeError:
            # Temporary fix for https://github.com/pytorch/pytorch/issues/97211:
            # fall back to solving on CPU, then move the result back
            B = torch.linalg.solve_triangular(
                C.to("cpu"), Y_shifted.to("cpu"), upper=False, left=True
            ).to(C.device)

        # B = V * S * U^T because we have been using the transposed sketch
_, S, UT = torch.linalg.svd(B, full_matrices=False)
self.U = UT.t()
self.S = torch.max(torch.square(S) - shift, torch.tensor(0.0))

self.rho = self.S[-1]

if self.verbose:
print(f"Approximate eigenvalues = {self.S}")

def _hvp_vmap(self, grad_params, params):
return vmap(
lambda v: self._hvp(grad_params, params, v),
in_dims=0,
chunk_size=self.chunk_size,
)

def _hvp(self, grad_params, params, v):
Hv = torch.autograd.grad(grad_params, params, grad_outputs=v, retain_graph=True)
Hv = tuple(Hvi.detach() for Hvi in Hv)
return torch.cat([Hvi.reshape(-1) for Hvi in Hv])

def _numel(self):
if self._numel_cache is None:
self._numel_cache = reduce(
lambda total, p: total + p.numel(), self._params, 0
)
return self._numel_cache

def _add_grad(self, step_size, update):
offset = 0
for p in self._params:
numel = p.numel()
# Avoid in-place operation by creating a new tensor
p.data = p.data.add(
update[offset : offset + numel].view_as(p), alpha=step_size
)
offset += numel
assert offset == self._numel()

def _clone_param(self):
return [p.clone(memory_format=torch.contiguous_format) for p in self._params]

def _set_param(self, params_data):
for p, pdata in zip(self._params, params_data):
# Replace the .data attribute of the tensor
p.data = pdata.data