diff --git a/deepxde/model.py b/deepxde/model.py
index 4ebdf6859..dfc0f68ae 100644
--- a/deepxde/model.py
+++ b/deepxde/model.py
@@ -363,11 +363,27 @@ def closure():
             if self.lr_scheduler is not None:
                 self.lr_scheduler.step()
 
+        def train_step_nncg(inputs, targets, auxiliary_vars):
+            def closure():
+                return get_loss_grad_nncg(inputs, targets, auxiliary_vars)
+
+            self.opt.step(closure)
+
+        def get_loss_grad_nncg(inputs, targets, auxiliary_vars):
+            losses = outputs_losses_train(inputs, targets, auxiliary_vars)[1]
+            total_loss = torch.sum(losses)
+            self.opt.zero_grad()
+            grad_tuple = torch.autograd.grad(total_loss, trainable_variables,
+                                             create_graph=True)
+            return total_loss, grad_tuple
+
         # Callables
         self.outputs = outputs
         self.outputs_losses_train = outputs_losses_train
         self.outputs_losses_test = outputs_losses_test
         self.train_step = train_step
+        self.train_step_nncg = train_step_nncg
+        self.get_loss_grad_nncg = get_loss_grad_nncg
 
     def _compile_jax(self, lr, loss_fn, decay):
         """jax"""
@@ -636,12 +652,22 @@ def train(
         self._test()
         self.callbacks.on_train_begin()
         if optimizers.is_external_optimizer(self.opt_name):
+            if self.opt_name == "NNCG" and backend_name != "pytorch":
+                raise ValueError(
+                    "The optimizer 'NNCG' is only supported for the backend PyTorch."
+                )
             if backend_name == "tensorflow.compat.v1":
                 self._train_tensorflow_compat_v1_scipy(display_every)
             elif backend_name == "tensorflow":
                 self._train_tensorflow_tfp()
             elif backend_name == "pytorch":
-                self._train_pytorch_lbfgs()
+                if self.opt_name == "L-BFGS":
+                    self._train_pytorch_lbfgs()
+                elif self.opt_name == "NNCG":
+                    self._train_pytorch_nncg(iterations, display_every)
+                else:
+                    raise ValueError("Only 'L-BFGS' and 'NNCG' are supported "
+                                     "as external optimizers for PyTorch.")
             elif backend_name == "paddle":
                 self._train_paddle_lbfgs()
             else:
@@ -785,6 +811,52 @@ def _train_pytorch_lbfgs(self):
             if self.stop_training:
                 break
 
+    def _train_pytorch_nncg(self, iterations, display_every):
+        # The training loop follows _train_pytorch_lbfgs and _train_sgd.
+        for i in range(iterations):
+            self.callbacks.on_epoch_begin()
+            self.callbacks.on_batch_begin()
+
+            # Update the preconditioner every NNCG_options["updatefreq"]
+            # iterations. get_loss_grad_nncg returns the gradient of the total
+            # loss computed with create_graph=True, which update_preconditioner
+            # needs for its Hessian-vector products.
+            if i % optimizers.NNCG_options["updatefreq"] == 0:
+                self.opt.zero_grad()
+                _, grad_tuple = self.get_loss_grad_nncg(
+                    self.train_state.X_train,
+                    self.train_state.y_train,
+                    self.train_state.train_aux_vars,
+                )
+                self.opt.update_preconditioner(grad_tuple)
+
+            # Take one NNCG step.
+            self.train_step_nncg(
+                self.train_state.X_train,
+                self.train_state.y_train,
+                self.train_state.train_aux_vars,
+            )
+
+            self.train_state.epoch += 1
+            self.train_state.step += 1
+            if self.train_state.step % display_every == 0 or i + 1 == iterations:
+                self._test()
+
+            self.callbacks.on_batch_end()
+            self.callbacks.on_epoch_end()
+
+            # Allow callbacks to stop training early.
+            if self.stop_training:
+                break
+
     def _train_paddle_lbfgs(self):
         prev_n_iter = 0
 
diff --git a/deepxde/optimizers/__init__.py b/deepxde/optimizers/__init__.py
index e1fcfced1..556761a86 100644
--- a/deepxde/optimizers/__init__.py
+++ b/deepxde/optimizers/__init__.py
@@ -1,7 +1,7 @@
 import importlib
 import sys
 
-from .config import LBFGS_options, set_LBFGS_options
+from .config import LBFGS_options, set_LBFGS_options, NNCG_options, set_NNCG_options
 from ..backend import backend_name
 
 
diff --git a/deepxde/optimizers/config.py b/deepxde/optimizers/config.py
index 01ba8bd1f..285c0fdc1 100644
--- a/deepxde/optimizers/config.py
+++ b/deepxde/optimizers/config.py
@@ -1,9 +1,10 @@
-__all__ = ["set_LBFGS_options", "set_hvd_opt_options"]
+__all__ = ["set_LBFGS_options", "set_NNCG_options", "set_hvd_opt_options"]
 
 from ..backend import backend_name
 from ..config import hvd
 
 LBFGS_options = {}
+NNCG_options = {}
 
 if hvd is not None:
     hvd_opt_options = {}
@@ -59,6 +60,55 @@ def set_LBFGS_options(
     LBFGS_options["maxfun"] = maxfun if maxfun is not None else int(maxiter * 1.25)
     LBFGS_options["maxls"] = maxls
 
 
+def set_NNCG_options(
+    lr=1,
+    rank=10,
+    mu=1e-4,
+    updatefreq=20,
+    chunksz=1,
+    cgtol=1e-16,
+    cgmaxiter=1000,
+    lsfun="armijo",
+    verbose=False,
+):
+    """Sets the hyperparameters of NysNewtonCG (NNCG).
+
+    Args:
+        lr (float): `lr` (torch).
+            Learning rate (before line search).
+        rank (int): `rank` (torch).
+            Rank of the preconditioner matrix used in preconditioned conjugate
+            gradient.
+        mu (float): `mu` (torch).
+            Hessian damping parameter.
+        updatefreq (int): How often the preconditioner matrix in preconditioned
+            conjugate gradient is updated. This parameter is not used inside
+            NNCG itself; it is used by _train_pytorch_nncg in deepxde/model.py.
+        chunksz (int): `chunk_size` (torch).
+            Number of Hessian-vector products to compute in parallel when
+            constructing the preconditioner. If `chunk_size` is 1, the
+            Hessian-vector products are computed serially.
+        cgtol (float): `cg_tol` (torch).
+            Convergence tolerance for the conjugate gradient method. The
+            iteration stops when `||r||_2 <= cgtol`, where `r` is the residual.
+            Note that this is an absolute tolerance, not a relative one.
+        cgmaxiter (int): `cg_max_iters` (torch).
+            Maximum number of iterations for the conjugate gradient method.
+        lsfun (str): `line_search_fn` (torch).
+            The line search function used to find the step size. The default
+            value is "armijo". The other option is None.
+        verbose (bool): `verbose` (torch).
+            If `True`, prints the eigenvalues of the Nyström approximation of
+            the Hessian.
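+
+    A minimal usage sketch (mirroring examples/pinn_forward/Burgers_NNCG.py,
+    which is added in this change)::
+
+        dde.optimizers.set_NNCG_options(rank=50, mu=1e-1)
+        model.compile("NNCG")
+        model.train(iterations=1000, display_every=100)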
+ """ + NNCG_options["lr"] = lr + NNCG_options["rank"] = rank + NNCG_options["mu"] = mu + NNCG_options["updatefreq"] = updatefreq + NNCG_options["chunksz"] = chunksz + NNCG_options["cgtol"] = cgtol + NNCG_options["cgmaxiter"] = cgmaxiter + NNCG_options["lsfun"] = lsfun + NNCG_options["verbose"] = verbose def set_hvd_opt_options( compression=None, @@ -91,6 +141,7 @@ def set_hvd_opt_options( set_LBFGS_options() +set_NNCG_options() if hvd is not None: set_hvd_opt_options() diff --git a/deepxde/optimizers/pytorch/nncg.py b/deepxde/optimizers/pytorch/nncg.py new file mode 100644 index 000000000..a71e7da31 --- /dev/null +++ b/deepxde/optimizers/pytorch/nncg.py @@ -0,0 +1,324 @@ +from functools import reduce + +import torch +from torch.func import vmap +from torch.optim import Optimizer + + +def _armijo(f, x, gx, dx, t, alpha=0.1, beta=0.5): + """Line search to find a step size that satisfies the Armijo condition.""" + f0 = f(x, 0, dx) + f1 = f(x, t, dx) + while f1 > f0 + alpha * t * gx.dot(dx): + t *= beta + f1 = f(x, t, dx) + return t + + +def _apply_nys_precond_inv(U, S_mu_inv, mu, lambd_r, x): + """Applies the inverse of the Nystrom approximation of the Hessian to a vector.""" + z = U.T @ x + z = (lambd_r + mu) * (U @ (S_mu_inv * z)) + (x - U @ z) + return z + + +def _nystrom_pcg(hess, b, x, mu, U, S, r, tol, max_iters): + """Solves a positive-definite linear system using NyströmPCG. + + `Frangella et al. Randomized Nyström Preconditioning. + SIAM Journal on Matrix Analysis and Applications, 2023. + ` + """ + lambd_r = S[r - 1] + S_mu_inv = (S + mu) ** (-1) + + resid = b - (hess(x) + mu * x) + with torch.no_grad(): + z = _apply_nys_precond_inv(U, S_mu_inv, mu, lambd_r, resid) + p = z.clone() + + i = 0 + + while torch.norm(resid) > tol and i < max_iters: + v = hess(p) + mu * p + with torch.no_grad(): + alpha = torch.dot(resid, z) / torch.dot(p, v) + x += alpha * p + + rTz = torch.dot(resid, z) + resid -= alpha * v + z = _apply_nys_precond_inv(U, S_mu_inv, mu, lambd_r, resid) + beta = torch.dot(resid, z) / rTz + + p = z + beta * p + + i += 1 + + if torch.norm(resid) > tol: + print( + "Warning: PCG did not converge to tolerance. " + "Tolerance was {tol} but norm of residual is {torch.norm(resid)}" + ) + + return x + + +class NNCG(Optimizer): + """Implementation of NysNewtonCG, a damped Newton-CG method + that uses Nyström preconditioning. + + `Rathore et al. Challenges in Training PINNs: A Loss Landscape Perspective. + Preprint, 2024. ` + + .. warning:: + This optimizer doesn't support per-parameter options and parameter + groups (there can be only one). + + NOTE: This optimizer is currently a beta version. + + Our implementation is inspired by the PyTorch implementation of `L-BFGS + `. + + The parameters rank and mu will probably need to be tuned for your specific problem. 
+    If the optimizer is running very slowly, you can try one of the following:
+
+    - Increase the rank (this should increase the accuracy of the Nyström
+      approximation in PCG).
+    - Reduce cg_tol (this will allow PCG to terminate with a less accurate
+      solution).
+    - Reduce cg_max_iters (this will allow PCG to terminate after fewer
+      iterations).
+
+    Args:
+        params (iterable): Iterable of parameters to optimize or dicts defining
+            parameter groups.
+        lr (float, optional): Learning rate (default: 1.0).
+        rank (int, optional): Rank of the Nyström approximation (default: 10).
+        mu (float, optional): Damping parameter (default: 1e-4).
+        chunk_size (int, optional): Number of Hessian-vector products to be
+            computed in parallel (default: 1).
+        cg_tol (float, optional): Tolerance for PCG (default: 1e-16).
+        cg_max_iters (int, optional): Maximum number of PCG iterations
+            (default: 1000).
+        line_search_fn (str, optional): Either "armijo" or None (default: None).
+        verbose (bool, optional): Verbosity (default: False).
+    """
+
+    def __init__(
+        self,
+        params,
+        lr=1.0,
+        rank=10,
+        mu=1e-4,
+        chunk_size=1,
+        cg_tol=1e-16,
+        cg_max_iters=1000,
+        line_search_fn=None,
+        verbose=False,
+    ):
+        defaults = dict(
+            lr=lr,
+            rank=rank,
+            chunk_size=chunk_size,
+            mu=mu,
+            cg_tol=cg_tol,
+            cg_max_iters=cg_max_iters,
+            line_search_fn=line_search_fn,
+        )
+        self.rank = rank
+        self.mu = mu
+        self.chunk_size = chunk_size
+        self.cg_tol = cg_tol
+        self.cg_max_iters = cg_max_iters
+        self.line_search_fn = line_search_fn
+        self.verbose = verbose
+        self.U = None
+        self.S = None
+        self.n_iters = 0
+        super(NNCG, self).__init__(params, defaults)
+
+        if len(self.param_groups) > 1:
+            raise ValueError(
+                "NNCG doesn't currently support "
+                "per-parameter options (parameter groups)"
+            )
+
+        if self.line_search_fn is not None and self.line_search_fn != "armijo":
+            raise ValueError("NNCG only supports Armijo line search")
+
+        self._params = self.param_groups[0]["params"]
+        self._params_list = list(self._params)
+        self._numel_cache = None
+
+    def step(self, closure=None):
+        """Perform a single optimization step.
+
+        Args:
+            closure (callable, optional): A closure that reevaluates the model
+                and returns (i) the loss and (ii) the gradient w.r.t. the
+                parameters. The closure can compute the gradient by calling
+                torch.autograd.grad on the loss with create_graph=True.
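+
+                A minimal closure sketch (the names here are illustrative; in
+                DeepXDE, get_loss_grad_nncg in deepxde/model.py plays this
+                role)::
+
+                    def closure():
+                        loss = compute_total_loss()  # hypothetical loss helper
+                        grads = torch.autograd.grad(loss, params, create_graph=True)
+                        return loss, grads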
+ """ + if self.n_iters == 0: + # Store the previous direction for warm starting PCG + self.old_dir = torch.zeros(self._numel(), device=self._params[0].device) + + # NOTE: The closure must return both the loss and the gradient + loss = None + if closure is not None: + with torch.enable_grad(): + loss, grad_tuple = closure() + + g = torch.cat([grad.view(-1) for grad in grad_tuple if grad is not None]) + + # One step update + for group_idx, group in enumerate(self.param_groups): + + def hvp_temp(x): + return self._hvp(g, self._params_list, x) + + # Calculate the Newton direction + d = _nystrom_pcg( + hvp_temp, + g, + self.old_dir, + self.mu, + self.U, + self.S, + self.rank, + self.cg_tol, + self.cg_max_iters, + ) + + # Store the previous direction for warm starting PCG + self.old_dir = d + + # Check if d is a descent direction + if torch.dot(d, g) <= 0: + print("Warning: d is not a descent direction") + + if self.line_search_fn == "armijo": + x_init = self._clone_param() + + def obj_func(x, t, dx): + self._add_grad(t, dx) + loss = float(closure()[0]) + self._set_param(x) + return loss + + # Use -d for convention + t = _armijo(obj_func, x_init, g, -d, group["lr"]) + else: + t = group["lr"] + + self.state[group_idx]["t"] = t + + # update parameters + ls = 0 + for p in group["params"]: + np = torch.numel(p) + dp = d[ls : ls + np].view(p.shape) + ls += np + p.data.add_(-dp, alpha=t) + + self.n_iters += 1 + + return loss, g + + def update_preconditioner(self, grad_tuple): + """Update the Nystrom approximation of the Hessian. + + Args: + grad_tuple (tuple): tuple of Tensors containing the gradients + of the loss w.r.t. the parameters. + This tuple can be obtained by calling torch.autograd.grad + on the loss with create_graph=True. + """ + + # Flatten and concatenate the gradients + gradsH = torch.cat( + [gradient.view(-1) for gradient in grad_tuple if gradient is not None] + ) + + # Generate test matrix (NOTE: This is transposed test matrix) + p = gradsH.shape[0] + Phi = torch.randn((self.rank, p), device=gradsH.device) / (p**0.5) + Phi = torch.linalg.qr(Phi.t(), mode="reduced")[0].t() + + Y = self._hvp_vmap(gradsH, self._params_list)(Phi) + + # Calculate shift + shift = torch.finfo(Y.dtype).eps + Y_shifted = Y + shift * Phi + + # Calculate Phi^T * H * Phi (w/ shift) for Cholesky + choleskytarget = torch.mm(Y_shifted, Phi.t()) + + # Perform Cholesky, if fails, do eigendecomposition + # The new shift is the abs of smallest eigenvalue (negative) + # plus the original shift + try: + C = torch.linalg.cholesky(choleskytarget) + except torch.linalg.LinAlgError: + # eigendecomposition, eigenvalues and eigenvector matrix + eigs, eigvectors = torch.linalg.eigh(choleskytarget) + shift = shift + torch.abs(torch.min(eigs)) + # add shift to eigenvalues + eigs = eigs + shift + # put back the matrix for Cholesky by eigenvector * eigenvalues + # after shift * eigenvector^T + C = torch.linalg.cholesky( + torch.mm(eigvectors, torch.mm(torch.diag(eigs), eigvectors.T)) + ) + + try: + B = torch.linalg.solve_triangular(C, Y_shifted, upper=False, left=True) + # temporary fix for issue @ https://github.com/pytorch/pytorch/issues/97211 + except RuntimeError: + B = torch.linalg.solve_triangular( + C.to("cpu"), Y_shifted.to("cpu"), upper=False, left=True + ).to(C.device) + + # B = V * S * U^T b/c we have been using transposed sketch + _, S, UT = torch.linalg.svd(B, full_matrices=False) + self.U = UT.t() + self.S = torch.max(torch.square(S) - shift, torch.tensor(0.0)) + + self.rho = self.S[-1] + + if self.verbose: + 
print(f"Approximate eigenvalues = {self.S}") + + def _hvp_vmap(self, grad_params, params): + return vmap( + lambda v: self._hvp(grad_params, params, v), + in_dims=0, + chunk_size=self.chunk_size, + ) + + def _hvp(self, grad_params, params, v): + Hv = torch.autograd.grad(grad_params, params, grad_outputs=v, retain_graph=True) + Hv = tuple(Hvi.detach() for Hvi in Hv) + return torch.cat([Hvi.reshape(-1) for Hvi in Hv]) + + def _numel(self): + if self._numel_cache is None: + self._numel_cache = reduce( + lambda total, p: total + p.numel(), self._params, 0 + ) + return self._numel_cache + + def _add_grad(self, step_size, update): + offset = 0 + for p in self._params: + numel = p.numel() + # Avoid in-place operation by creating a new tensor + p.data = p.data.add( + update[offset : offset + numel].view_as(p), alpha=step_size + ) + offset += numel + assert offset == self._numel() + + def _clone_param(self): + return [p.clone(memory_format=torch.contiguous_format) for p in self._params] + + def _set_param(self, params_data): + for p, pdata in zip(self._params, params_data): + # Replace the .data attribute of the tensor + p.data = pdata.data diff --git a/deepxde/optimizers/pytorch/optimizers.py b/deepxde/optimizers/pytorch/optimizers.py index 6329912dd..8bf795559 100644 --- a/deepxde/optimizers/pytorch/optimizers.py +++ b/deepxde/optimizers/pytorch/optimizers.py @@ -2,11 +2,12 @@ import torch -from ..config import LBFGS_options +from .nncg import NNCG +from ..config import LBFGS_options, NNCG_options def is_external_optimizer(optimizer): - return optimizer in ["L-BFGS", "L-BFGS-B"] + return optimizer in ["L-BFGS", "L-BFGS-B", "NNCG"] def get(params, optimizer, learning_rate=None, decay=None, weight_decay=0): @@ -29,6 +30,22 @@ def get(params, optimizer, learning_rate=None, decay=None, weight_decay=0): history_size=LBFGS_options["maxcor"], line_search_fn=("strong_wolfe" if LBFGS_options["maxls"] > 0 else None), ) + elif optimizer == "NNCG": + if weight_decay > 0: + raise ValueError("NNCG optimizer doesn't support weight_decay > 0") + if learning_rate is not None or decay is not None: + print("Warning: learning rate is ignored for {}".format(optimizer)) + optim = NNCG( + params, + lr=NNCG_options["lr"], + rank=NNCG_options["rank"], + mu=NNCG_options["mu"], + chunk_size=NNCG_options["chunksz"], + cg_tol=NNCG_options["cgtol"], + cg_max_iters=NNCG_options["cgmaxiter"], + line_search_fn=NNCG_options["lsfun"], + verbose=NNCG_options["verbose"], + ) else: if learning_rate is None: raise ValueError("No learning rate for {}.".format(optimizer)) diff --git a/examples/pinn_forward/Burgers_NNCG.py b/examples/pinn_forward/Burgers_NNCG.py new file mode 100644 index 000000000..d95eb1d22 --- /dev/null +++ b/examples/pinn_forward/Burgers_NNCG.py @@ -0,0 +1,67 @@ +"""Backend supported: tensorflow.compat.v1, tensorflow, pytorch, paddle""" +import deepxde as dde +import numpy as np + + +def gen_testdata(): + data = np.load("../dataset/Burgers.npz") + t, x, exact = data["t"], data["x"], data["usol"].T + xx, tt = np.meshgrid(x, t) + X = np.vstack((np.ravel(xx), np.ravel(tt))).T + y = exact.flatten()[:, None] + return X, y + + +def pde(x, y): + dy_x = dde.grad.jacobian(y, x, i=0, j=0) + dy_t = dde.grad.jacobian(y, x, i=0, j=1) + dy_xx = dde.grad.hessian(y, x, i=0, j=0) + return dy_t + y * dy_x - 0.01 / np.pi * dy_xx + + +geom = dde.geometry.Interval(-1, 1) +timedomain = dde.geometry.TimeDomain(0, 0.99) +geomtime = dde.geometry.GeometryXTime(geom, timedomain) + +bc = dde.icbc.DirichletBC(geomtime, lambda x: 0, lambda _, 
+bc = dde.icbc.DirichletBC(geomtime, lambda x: 0, lambda _, on_boundary: on_boundary)
+ic = dde.icbc.IC(
+    geomtime, lambda x: -np.sin(np.pi * x[:, 0:1]), lambda _, on_initial: on_initial
+)
+
+data = dde.data.TimePDE(
+    geomtime, pde, [bc, ic], num_domain=2540, num_boundary=80, num_initial=160
+)
+net = dde.nn.FNN([2] + [20] * 3 + [1], "tanh", "Glorot normal")
+model = dde.Model(data, net)
+
+# Run Adam + L-BFGS
+model.compile("adam", lr=1e-3)
+model.train(iterations=15000)
+
+model.compile("L-BFGS")
+losshistory, train_state = model.train()
+dde.saveplot(losshistory, train_state, issave=True, isplot=True)
+
+# Get test data
+X, y_true = gen_testdata()
+
+# Get the results after running Adam + L-BFGS
+y_pred = model.predict(X)
+f = model.predict(X, operator=pde)
+print("Mean residual after Adam+L-BFGS:", np.mean(np.absolute(f)))
+print(
+    "L2 relative error after Adam+L-BFGS:",
+    dde.metrics.l2_relative_error(y_true, y_pred),
+)
+np.savetxt("test_adam_lbfgs.dat", np.hstack((X, y_true, y_pred)))
+
+# Run NNCG after Adam + L-BFGS
+dde.optimizers.set_NNCG_options(rank=50, mu=1e-1)
+model.compile("NNCG")
+losshistory_nncg, train_state_nncg = model.train(iterations=1000, display_every=100)
+dde.saveplot(losshistory_nncg, train_state_nncg, issave=True, isplot=True)
+
+# Get the final results after running Adam + L-BFGS + NNCG
+y_pred = model.predict(X)
+f = model.predict(X, operator=pde)
+print("Mean residual after Adam+L-BFGS+NNCG:", np.mean(np.absolute(f)))
+print(
+    "L2 relative error after Adam+L-BFGS+NNCG:",
+    dde.metrics.l2_relative_error(y_true, y_pred),
+)
+np.savetxt("test_adam_lbfgs_nncg.dat", np.hstack((X, y_true, y_pred)))