From ca1a91b9af52a9e3fae68d3935b82138b2d6e575 Mon Sep 17 00:00:00 2001
From: Jiong Zhang
Date: Tue, 28 May 2024 12:07:17 -0700
Subject: [PATCH] fix sparse gradient clipping for torch>=2.0 (#288)

---
 pecos/utils/torch_util.py         | 62 +++++++++++++++++++++++++++++++
 pecos/xmc/xlinear/model.py        |  8 ++--
 pecos/xmc/xtransformer/matcher.py | 34 +++++++++++------
 setup.py                          |  3 +-
 4 files changed, 91 insertions(+), 16 deletions(-)

diff --git a/pecos/utils/torch_util.py b/pecos/utils/torch_util.py
index 0398228..a6e826b 100644
--- a/pecos/utils/torch_util.py
+++ b/pecos/utils/torch_util.py
@@ -12,6 +12,7 @@
 
 import numpy as np
 import torch
+from typing import Union, Iterable
 
 LOGGER = logging.getLogger(__name__)
 
@@ -72,3 +73,64 @@ def apply_mask(hidden_states, masks):
     hidden_dim = hidden_states.shape[-1]
     hidden_states.view(-1, hidden_dim)[~masks.view(-1).type(torch.ByteTensor), :] = 0
     return hidden_states
+
+
+def clip_grad_norm_(
+    parameters: Union[torch.Tensor, Iterable[torch.Tensor]],
+    max_norm: float,
+    norm_type: float = 2.0,
+    error_if_nonfinite: bool = False,
+) -> torch.Tensor:
+    r"""
+    Implementation of torch.nn.utils.clip_grad_norm_ from torch==1.13.
+    This is needed to support gradient clipping with sparse gradients, which torch>=2.0 no longer handles.
+    REF: https://pytorch.org/docs/1.13/_modules/torch/nn/utils/clip_grad.html#clip_grad_norm_
+
+    Clips gradient norm of an iterable of parameters.
+
+    The norm is computed over all gradients together, as if they were
+    concatenated into a single vector. Gradients are modified in-place.
+
+    Args:
+        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
+            single Tensor that will have gradients normalized
+        max_norm (float or int): max norm of the gradients
+        norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
+            infinity norm.
+        error_if_nonfinite (bool): if True, an error is thrown if the total
+            norm of the gradients from :attr:`parameters` is ``nan``,
+            ``inf``, or ``-inf``. Default: False (will switch to True in the future)
+
+    Returns:
+        Total norm of the parameter gradients (viewed as a single vector).
+    """
+    if isinstance(parameters, torch.Tensor):
+        parameters = [parameters]
+    grads = [p.grad for p in parameters if p.grad is not None]
+    max_norm = float(max_norm)
+    norm_type = float(norm_type)
+    if len(grads) == 0:
+        return torch.tensor(0.0)
+    device = grads[0].device
+    if norm_type == float("inf"):
+        norms = [g.detach().abs().max().to(device) for g in grads]
+        total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms))
+    else:
+        total_norm = torch.norm(
+            torch.stack([torch.norm(g.detach(), norm_type).to(device) for g in grads]), norm_type
+        )
+    if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):
+        raise RuntimeError(
+            f"The total norm of order {norm_type} for gradients from "
+            "`parameters` is non-finite, so it cannot be clipped. To disable "
+            "this error and scale the gradients by the non-finite norm anyway, "
+            "set `error_if_nonfinite=False`"
+        )
+    clip_coef = max_norm / (total_norm + 1e-6)
+    # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so
+    # avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization
+    # when the gradients do not reside in CPU memory.
+    clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
+    for g in grads:
+        g.detach().mul_(clip_coef_clamped.to(g.device))
+    return total_norm
diff --git a/pecos/xmc/xlinear/model.py b/pecos/xmc/xlinear/model.py
index 143ecc9..703cbf4 100644
--- a/pecos/xmc/xlinear/model.py
+++ b/pecos/xmc/xlinear/model.py
@@ -537,9 +537,11 @@ def predict(
                 Ye = self.predict(
                     X[i : i + max_pred_chunk, :],
                     pred_params=pred_params,
-                    selected_outputs_csr=selected_outputs_csr[i : i + max_pred_chunk, :]
-                    if selected_outputs_csr is not None
-                    else None,
+                    selected_outputs_csr=(
+                        selected_outputs_csr[i : i + max_pred_chunk, :]
+                        if selected_outputs_csr is not None
+                        else None
+                    ),
                     **new_kwargs,
                 )
                 Ys.append(Ye)
diff --git a/pecos/xmc/xtransformer/matcher.py b/pecos/xmc/xtransformer/matcher.py
index 00d5ae7..9ed2a82 100644
--- a/pecos/xmc/xtransformer/matcher.py
+++ b/pecos/xmc/xtransformer/matcher.py
@@ -784,18 +784,20 @@ def _predict(
         if not only_embeddings:
             text_model_W_seq, text_model_b_seq = self.text_model(
                 output_indices=inputs["label_indices"],
-                num_device=len(self.text_encoder.device_ids)
-                if hasattr(self.text_encoder, "device_ids")
-                else 1,
+                num_device=(
+                    len(self.text_encoder.device_ids)
+                    if hasattr(self.text_encoder, "device_ids")
+                    else 1
+                ),
             )
 
         outputs = self.text_encoder(
             input_ids=inputs["input_ids"],
             attention_mask=inputs["attention_mask"],
             token_type_ids=inputs["token_type_ids"],
-            label_embedding=None
-            if only_embeddings
-            else (text_model_W_seq, text_model_b_seq),
+            label_embedding=(
+                None if only_embeddings else (text_model_W_seq, text_model_b_seq)
+            ),
         )
 
         if not only_embeddings:
@@ -1088,9 +1090,11 @@ def fine_tune_encoder(self, prob, val_prob=None, val_csr_codes=None):
                 }
                 text_model_W_seq, text_model_b_seq = self.text_model(
                     output_indices=inputs["label_indices"],
-                    num_device=len(self.text_encoder.device_ids)
-                    if hasattr(self.text_encoder, "device_ids")
-                    else 1,
+                    num_device=(
+                        len(self.text_encoder.device_ids)
+                        if hasattr(self.text_encoder, "device_ids")
+                        else 1
+                    ),
                 )
                 outputs = self.text_encoder(
                     input_ids=inputs["input_ids"],
@@ -1119,9 +1123,15 @@ def fine_tune_encoder(self, prob, val_prob=None, val_csr_codes=None):
                     scheduler.step()  # update learning rate schedule
                     optimizer.zero_grad()  # clear gradient accumulation
 
-                    torch.nn.utils.clip_grad_norm_(
-                        self.text_model.parameters(), train_params.max_grad_norm
-                    )
+                    if self.text_model.is_sparse:
+                        torch_util.clip_grad_norm_(
+                            self.text_model.parameters(), train_params.max_grad_norm
+                        )
+                    else:
+                        torch.nn.utils.clip_grad_norm_(
+                            self.text_model.parameters(), train_params.max_grad_norm
+                        )
+
                     emb_optimizer.step()  # perform gradient update
                     emb_scheduler.step()  # update learning rate schedule
                     emb_optimizer.zero_grad()  # clear gradient accumulation
diff --git a/setup.py b/setup.py
index 7d83871..45cbd58 100644
--- a/setup.py
+++ b/setup.py
@@ -115,7 +115,8 @@ def get_blas_lib_dir(cls):
     install_requires = numpy_requires + [
         'scipy>=1.4.1',
         'scikit-learn>=0.24.1',
-        'torch>=1.8.0,<2.0.0',
+        'torch==1.13; python_version<"3.8"',
+        'torch>=2.0; python_version>="3.8"',
         'sentencepiece>=0.1.86,!=0.1.92',  # 0.1.92 results in error for transformers
         'transformers>=4.1.1; python_version<"3.9"',
         'transformers>=4.4.2; python_version>="3.9"'
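Note (illustrative sketch, not part of the patch): matcher.py above routes sparse text models
through the new torch_util.clip_grad_norm_ because the built-in torch.nn.utils.clip_grad_norm_
in torch>=2.0 can fail when parameters carry sparse gradients. The snippet below shows how the
helper added by this patch might be exercised directly; the toy nn.Embedding and its sizes are
made up for the example, and the import path follows the file added above.

    # Sketch only: assumes torch>=2.0 and the pecos.utils.torch_util module from this patch.
    import torch
    from pecos.utils import torch_util

    emb = torch.nn.Embedding(1000, 16, sparse=True)  # sparse=True yields sparse (COO) gradients
    loss = emb(torch.tensor([1, 2, 3])).sum()
    loss.backward()                                  # emb.weight.grad is now a sparse tensor

    # torch.nn.utils.clip_grad_norm_ may reject sparse grads on torch>=2.0,
    # so clip with the per-tensor torch==1.13-style implementation instead:
    total_norm = torch_util.clip_grad_norm_(emb.parameters(), max_norm=1.0)
    print(float(total_norm))

The is_sparse dispatch in matcher.py keeps dense models on the stock PyTorch implementation and
only falls back to this re-implementation when sparse gradients are involved.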