diff --git a/README.md b/README.md
index a04d4866..4c2cf8a3 100644
--- a/README.md
+++ b/README.md
@@ -21,7 +21,7 @@
 
 ## 📖 Overview
 
-We provide efficient and streamlined implementations of the TOFU, MUSE unlearning benchmarks while supporting 5 unlearning methods, 3+ datasets, 6+ evaluation metrics, and 7+ LLMs. Each of these can be easily extended to incorporate more variants.
+We provide efficient and streamlined implementations of the TOFU, MUSE unlearning benchmarks while supporting 6 unlearning methods, 3+ datasets, 6+ evaluation metrics, and 7+ LLMs. Each of these can be easily extended to incorporate more variants.
 
 We invite the LLM unlearning community to collaborate by adding new benchmarks, unlearning methods, datasets and evaluation metrics here to expand OpenUnlearning's features, gain feedback from wider usage and drive progress in the field.
 
@@ -35,7 +35,7 @@ We provide several variants for each of the components in the unlearning pipelin
 | **Component** | **Available Options** |
 |------------------------|----------------------|
 | **Benchmarks** | [TOFU](https://arxiv.org/abs/2401.06121), [MUSE](https://muse-bench.github.io/) |
-| **Unlearning Methods** | GradAscent, GradDiff, NPO, SimNPO, DPO |
+| **Unlearning Methods** | GradAscent, GradDiff, NPO, SimNPO, DPO, RMU |
 | **Evaluation Metrics** | Verbatim Probability, Verbatim ROUGE, QA-ROUGE, MIA Attacks, TruthRatio, Model Utility |
 | **Datasets** | MUSE-News (BBC), MUSE-Books (Harry Potter), TOFU (different splits) |
 | **Model Families** | TOFU: LLaMA-3.2, LLaMA-3.1, LLaMA-2; MUSE: LLaMA-2, ICLM; Additional: Phi-3.5, Phi-1.5, Gemma |
diff --git a/configs/experiment/unlearn/muse/default.yaml b/configs/experiment/unlearn/muse/default.yaml
index 454a84e3..b4bdbe0f 100644
--- a/configs/experiment/unlearn/muse/default.yaml
+++ b/configs/experiment/unlearn/muse/default.yaml
@@ -34,6 +34,7 @@ eval:
   muse:
     data_split: ${data_split}
     retain_logs_path: ${retain_logs_path}
+    overwrite: true
 
 trainer:
   args:
diff --git a/configs/experiment/unlearn/muse/scalability.yaml b/configs/experiment/unlearn/muse/scalability.yaml
index 11d90f50..b19e0cb5 100644
--- a/configs/experiment/unlearn/muse/scalability.yaml
+++ b/configs/experiment/unlearn/muse/scalability.yaml
@@ -34,6 +34,7 @@ eval:
   muse:
     data_split: ${data_split}
     retain_logs_path: ${retain_logs_path}
+    overwrite: true
 
 trainer:
   args:
diff --git a/configs/experiment/unlearn/muse/sustainabilty.yaml b/configs/experiment/unlearn/muse/sustainabilty.yaml
index e5d79687..9a0a03e3 100644
--- a/configs/experiment/unlearn/muse/sustainabilty.yaml
+++ b/configs/experiment/unlearn/muse/sustainabilty.yaml
@@ -34,6 +34,7 @@ eval:
   muse:
     data_split: ${data_split}
     retain_logs_path: ${retain_logs_path}
+    overwrite: true
 
 trainer:
   args:
diff --git a/configs/experiment/unlearn/tofu/default.yaml b/configs/experiment/unlearn/tofu/default.yaml
index 5f7c4757..f2e0ab1a 100644
--- a/configs/experiment/unlearn/tofu/default.yaml
+++ b/configs/experiment/unlearn/tofu/default.yaml
@@ -20,6 +20,7 @@ eval:
   tofu:
     forget_split: ${forget_split}
     retain_logs_path: ${retain_logs_path}
+    overwrite: true
 
 data:
   anchor: forget
diff --git a/configs/experiment/unlearn/tofu/idk.yaml b/configs/experiment/unlearn/tofu/idk.yaml
index 61a365d0..5fcb85df 100644
--- a/configs/experiment/unlearn/tofu/idk.yaml
+++ b/configs/experiment/unlearn/tofu/idk.yaml
@@ -20,6 +20,7 @@ eval:
   tofu:
     forget_split: ${forget_split}
     retain_logs_path: ${retain_logs_path}
+    overwrite: true
 
 data:
   anchor: forget
diff --git a/configs/trainer/RMU.yaml b/configs/trainer/RMU.yaml
new file mode 100644
index 00000000..7e1f9028
--- /dev/null
+++ b/configs/trainer/RMU.yaml
@@ -0,0 +1,14 @@
+defaults:
+  - GradDiff
+
+handler: RMU
+method_args:
+  # These parameters depend heavily on the model and dataset; tune them carefully for good results.
+  gamma: 1.0
+  steering_coeff: 2
+  retain_loss_type: EMBED_DIFF
+  alpha: 1
+  module_regex: model\.layers\.7
+  trainable_params_regex:
+    - .* # update all parameters (as done in https://github.com/tmlr-group/G-effect/blob/ef368eea3b2c6dba1e090b9ebb021ac9f047e0ae/dataloader.py#L271)
+    # - model\.layers\.(5|6|7)\.mlp\.down_proj\.weight # use this to update only these weights (as done in https://github.com/centerforaisafety/wmdp/blob/bc5e1ba0367ea826caeeeaa50656336a1e87acfb/rmu/unlearn.py#L26)
\ No newline at end of file
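Not part of the diff, but for context: a minimal sketch of how the two regexes above are interpreted. They are matched with `re.fullmatch` against module and parameter names; the LLaMA-style names used here are illustrative assumptions, not taken from this patch.

```python
import re

# Illustrative module/parameter names (assumed Hugging Face LLaMA-style naming)
module_regex = r"model\.layers\.7"
trainable_params_regex = [r"model\.layers\.(5|6|7)\.mlp\.down_proj\.weight"]

module_names = ["model.layers.6", "model.layers.7", "model.layers.7.mlp.down_proj"]
# fullmatch => only the whole name "model.layers.7" matches, i.e. exactly one decoder block
print([n for n in module_names if re.fullmatch(module_regex, n)])  # ['model.layers.7']

param_names = [
    "model.layers.7.mlp.down_proj.weight",
    "model.layers.7.self_attn.q_proj.weight",
]
# Only parameters matching one of the patterns get gradients/updates; the `.*`
# default in the config instead marks every parameter as trainable.
print([n for n in param_names
       if any(re.fullmatch(p, n) for p in trainable_params_regex)])
# ['model.layers.7.mlp.down_proj.weight']
```

The commented-out pattern restricts updates to the `down_proj` weights of layers 5-7, mirroring the original WMDP RMU setup linked in the config.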
diff --git a/docs/results.md b/docs/results.md
index 3f7cc7c9..3af7cb66 100644
--- a/docs/results.md
+++ b/docs/results.md
@@ -23,7 +23,7 @@
 For all the experiments below, we used the following setup
 
 | **Hyperparameters** | Learning Rate (lr) = 1e-5 <br> α = 1, γ = 1, β = 0.1 (where applicable) <br> Number of Epochs = 10 <br> Optimizer: [paged_adamw_32bit](https://huggingface.co/docs/bitsandbytes/main/en/reference/optim/adamw#bitsandbytes.optim.PagedAdamW) |
 
 __Note:__
-1. Results may vary even with the same effective hyperparameters when trained with modifications to the distributed training setup, including when training on a single GPU. For example: methods such as SimNPO, can be significantly improved with careful tuning. **Please use these numbers only for reproducibility purposes**.
+1. Results may vary even with the same effective hyperparameters when the distributed training setup is modified, including when training on a single GPU. For example, methods such as SimNPO & RMU can be significantly improved with careful tuning. **Please use these numbers only for reproducibility purposes**.
 2. NPO in MUSE: for NPO, the MUSE implementation is inconsistent with the [original paper](https://github.com/licong-lin/negative-preference-optimization) as discussed [here](https://github.com/jaechan-repo/muse_bench/issues/2). This inconsistency is carried over into implementations like [SimNPO](https://github.com/OPTML-Group/Unlearn-Simple/issues/5). Here, we use the original NPO implementation with the same loss function expression across datasets.
@@ -140,6 +140,18 @@
       <td>0.6</td>
       <td>3.17e-04</td>
     </tr>
+    <tr>
+      <td>RMU</td>
+      <td>0.4</td>
+      <td>0.62</td>
+      <td>0.64</td>
+      <td>9.59e-10</td>
+      <td>0.02</td>
+      <td>0.81</td>
+      <td>6.92e-21</td>
+      <td>0.03</td>
+      <td>0.81</td>
+    </tr>
@@ -257,6 +269,18 @@
       <td>0.54</td>
       <td>1.07e-05</td>
     </tr>
+    <tr>
+      <td>RMU</td>
+      <td>0.16</td>
+      <td>0.55</td>
+      <td>0.70</td>
+      <td>4.87e-10</td>
+      <td>0.58</td>
+      <td>0.77</td>
+      <td>3.15e-15</td>
+      <td>0.59</td>
+      <td>0.76</td>
+    </tr>
@@ -354,6 +378,17 @@
       <td>-54.26</td>
      <td>0.54</td>
     </tr>
+    <tr>
+      <td>RMU</td>
+      <td>0.48</td>
+      <td>0.05</td>
+      <td>56.36</td>
+      <td>0.51</td>
+      <td>0.29</td>
+      <td>0.79</td>
+      <td>-60.52</td>
+      <td>0.48</td>
+    </tr>
\ No newline at end of file
diff --git a/scripts/tofu_unlearn.sh b/scripts/tofu_unlearn.sh
index 1794c9b6..ae33189f 100644
--- a/scripts/tofu_unlearn.sh
+++ b/scripts/tofu_unlearn.sh
@@ -14,6 +14,7 @@ trainers_experiments=(
     "GradDiff unlearn/tofu/default.yaml"
     "NPO unlearn/tofu/default.yaml"
     "DPO unlearn/tofu/idk.yaml"
+    "RMU unlearn/tofu/default.yaml"
 )
 forget_retain_splits=(
     "forget01 retain99"
diff --git a/src/trainer/__init__.py b/src/trainer/__init__.py
index 1c769bf6..7e195fa9 100644
--- a/src/trainer/__init__.py
+++ b/src/trainer/__init__.py
@@ -9,6 +9,7 @@
 from trainer.unlearn.npo import NPO
 from trainer.unlearn.dpo import DPO
 from trainer.unlearn.simnpo import SimNPO
+from trainer.unlearn.rmu import RMU
 
 TRAINER_REGISTRY: Dict[str, Any] = {}
 
@@ -79,3 +80,4 @@ def load_trainer(
 _register_trainer(NPO)
 _register_trainer(DPO)
 _register_trainer(SimNPO)
+_register_trainer(RMU)
diff --git a/src/trainer/unlearn/grad_diff.py b/src/trainer/unlearn/grad_diff.py
index e11c7a71..bfecc19a 100644
--- a/src/trainer/unlearn/grad_diff.py
+++ b/src/trainer/unlearn/grad_diff.py
@@ -14,7 +14,7 @@ def __init__(self, gamma=1.0, alpha=1.0, retain_loss_type="NLL", *args, **kwargs
         self.ref_model = self._prepare_ref_model(self.model)
 
     def _prepare_ref_model(self, model):
-        ref_model = copy.deepcopy(model).to("cuda")
+        ref_model = copy.deepcopy(model).to(self.accelerator.device)
         ref_model.eval()
         if self.is_deepspeed_enabled:
             ref_model = self._prepare_deepspeed(ref_model)
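The new `RMU` trainer introduced in the next hunk reads intermediate activations by temporarily attaching a forward hook to the matched module. A minimal, self-contained sketch of that caching pattern (a toy two-layer model, not the trainer itself):

```python
import torch
import torch.nn as nn

# Toy illustration of the forward-hook caching used by RMU.forward_with_cache:
# register a hook on one submodule, run a forward pass, grab that submodule's
# output, then remove the hook so later passes are unaffected.
model = nn.Sequential(nn.Linear(8, 16), nn.Linear(16, 4))
target_module = model[0]              # stands in for e.g. model.layers.7

cache = []
def hook(module, inputs, output):
    # Transformer blocks often return tuples; keep only the hidden states
    cache.append(output[0] if isinstance(output, tuple) else output)

handle = target_module.register_forward_hook(hook)
with torch.no_grad():                 # the no_grad=True path; grads stay on for the trained model
    out = model(torch.randn(2, 8))
handle.remove()

activations = cache[0]                # (2, 16) activations of the hooked module
print(activations.shape, out.shape)
```

Removing the handle right after the pass keeps subsequent forward calls (e.g. on the retain batch or the reference model) unaffected.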
diff --git a/src/trainer/unlearn/rmu.py b/src/trainer/unlearn/rmu.py
new file mode 100644
index 00000000..391bd6ad
--- /dev/null
+++ b/src/trainer/unlearn/rmu.py
@@ -0,0 +1,142 @@
+"""Borrowed implementation from https://github.com/centerforaisafety/wmdp/blob/main/rmu/unlearn.py"""
+
+import re
+import torch
+import deepspeed
+from trainer.unlearn.grad_diff import GradDiff
+
+
+class RMU(GradDiff):
+    def __init__(self,
+                 module_regex=r"model\.layers\.7",
+                 trainable_params_regex=[r"model\.layers\.(5|6|7)\.mlp\.down_proj\.weight"],
+                 steering_coeff=20,
+                 *args, **kwargs):
+        """
+        RMU Trainer that fine-tunes only specific layers and parameters using regex-based filtering.
+
+        Args:
+            module_regex (str): Regex pattern matching the single module whose activations are steered.
+            trainable_params_regex (list of str): Regex patterns selecting the trainable parameters.
+            steering_coeff (float): Norm of the random control vector used for the forget loss.
+        """
+        super().__init__(*args, **kwargs)
+
+        # Create reference model if not already set
+        if self.ref_model is None:
+            self.ref_model = self._prepare_ref_model(self.model)
+
+        # Regex patterns deciding which parameters are unfrozen in create_optimizer
+        self.trainable_params_regex = trainable_params_regex
+
+        # Get actual module references
+        self.module_regex = module_regex  # Regex for selecting the steered module
+        self.model_module = self._get_matching_module(self.model, self.module_regex)
+        self.ref_module = self._get_matching_module(self.ref_model, self.module_regex)
+        self.steering_coeff = steering_coeff
+        self.control_vec = None
+
+    def create_optimizer(self):
+        self._freeze_all_params(self.model, False)
+        # Mark only the selected parameters as trainable so the optimizer picks up just those
+        self._set_trainable_params(self.model, self.trainable_params_regex, True)
+        super().create_optimizer()
+        # Restore requires_grad everywhere; the optimizer built above still updates only the selected params
+        self._freeze_all_params(self.model, True)
+
+    def _get_matching_module(self, model, module_regex):
+        """Returns the single module matching the given regex from a DeepSpeed/DDP-wrapped model."""
+        # Handle DeepSpeed and DDP-wrapped models by accessing the underlying module
+        if isinstance(model, deepspeed.DeepSpeedEngine):
+            model = model.module  # Extract the actual PyTorch model inside
+
+        matched_modules = {name: module for name, module in model.named_modules() if re.fullmatch(module_regex, name)}
+
+        if len(matched_modules) > 1:
+            raise ValueError(f"More than one module matched with {module_regex}: {list(matched_modules.keys())}")
+        elif not matched_modules:
+            raise ValueError(f"No module matched with {module_regex}")
+
+        return next(iter(matched_modules.values()))  # Return the single matched module
+
+    def _freeze_all_params(self, model, requires_grad=True):
+        """Set requires_grad on every parameter of the model."""
+        for param in model.parameters():
+            param.requires_grad = requires_grad
+
+    def _set_trainable_params(self, model, trainable_params_regex, requires_grad=True):
+        """Set requires_grad on the parameters that match the regex patterns."""
+        for name, param in model.named_parameters():
+            if any(re.fullmatch(pattern, name) for pattern in trainable_params_regex):
+                param.requires_grad = requires_grad
+                # print(f"{name}:requires_grad\t{requires_grad}")
+
+    def forward_with_cache(self, model, inputs, module, no_grad=True):
+        """Performs a forward pass while caching the output of the specified module."""
+        cache = []
+
+        def hook(module, input, output):
+            # Transformer blocks return tuples; keep only the hidden states
+            if isinstance(output, tuple):
+                cache.append(output[0])
+            else:
+                cache.append(output)
+
+        hook_handle = module.register_forward_hook(hook)
+        with torch.set_grad_enabled(not no_grad):
+            outputs = model(**inputs)
+        hook_handle.remove()
+        return cache[0], outputs
+
+    def get_control_vector(self, dim):
+        # Create the random control vector once and reuse it for the whole run
+        if self.control_vec is None:
+            random_vector = torch.rand(1, 1, dim)
+            self.control_vec = random_vector / torch.norm(random_vector) * self.steering_coeff
+        return self.control_vec
+
+    def compute_activation_loss(self, activation1, activation2, mask):
+        squared_diff = torch.nn.functional.mse_loss(activation1, activation2, reduction="none")  # Shape: [b, s, d]
+        expanded_mask = mask.unsqueeze(-1).expand_as(squared_diff)  # Shape: [b, s, d]
+        squared_diff_sum = (squared_diff * expanded_mask).mean(dim=2).sum(dim=1, keepdim=True)  # Shape: [b, 1]
+        num_tokens = mask.sum(dim=-1, keepdim=True)  # Sum over seq_len, Shape: [b, 1]
+        return (squared_diff_sum / num_tokens).mean()
+
+    def compute_retain_loss(self, model, retain_inputs):
+        retain_loss = 0.0
+
+        if self.retain_loss_type == "EMBED_DIFF":
+            model_retain_activations, _ = self.forward_with_cache(model, retain_inputs, module=self.model_module, no_grad=False)
+            ref_retain_activations, _ = self.forward_with_cache(self.ref_model, retain_inputs, module=self.ref_module, no_grad=True)
+            mask = (retain_inputs['labels'] != -100)  # Shape: [b, s]
+            retain_loss = self.compute_activation_loss(model_retain_activations, ref_retain_activations.to(model_retain_activations.device), mask)
+        else:
+            retain_loss = super().compute_retain_loss(model, retain_inputs)
+        return retain_loss
+
+    def compute_loss(self, model, inputs, return_outputs=False):
+        forget_inputs = inputs["forget"]
+        forget_inputs = {
+            "input_ids": forget_inputs["input_ids"],
+            "attention_mask": forget_inputs["attention_mask"],
+            "labels": forget_inputs["labels"],
+        }
+
+        model_forget_activations, forget_outputs = self.forward_with_cache(model, forget_inputs, self.model_module, no_grad=False)
+        # If multiple datasets or concepts need unlearning, pass a control vector per batch;
+        # otherwise default to the single random vector created during training.
+        control_vec = forget_inputs.get("control_vec", self.get_control_vector(model_forget_activations.shape[-1]))
+        control_vec = control_vec.to(dtype=model_forget_activations.dtype, device=model_forget_activations.device)
+        control_vec = control_vec.expand_as(model_forget_activations)
+        mask = (forget_inputs['labels'] != -100)  # Shape: [b, s]
+        forget_loss = self.compute_activation_loss(model_forget_activations, control_vec, mask)
+
+        retain_inputs = inputs["retain"]
+        retain_inputs = {
+            "input_ids": retain_inputs["input_ids"],
+            "attention_mask": retain_inputs["attention_mask"],
+            "labels": retain_inputs["labels"],
+        }
+        retain_loss = self.compute_retain_loss(model=model, retain_inputs=retain_inputs)
+
+        loss = self.gamma * forget_loss + self.alpha * retain_loss
+
+        return (loss, forget_outputs) if return_outputs else loss
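For intuition, a minimal sketch of the objective the trainer above optimizes, run on random tensors (toy shapes; the `gamma`, `alpha`, and `steering_coeff` values are illustrative): forget-batch activations are pulled toward a fixed, scaled random control vector, while retain-batch activations are kept close to the frozen reference model's.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
b, s, d = 2, 5, 8                        # toy batch size, sequence length, hidden size
gamma, alpha, steering_coeff = 1.0, 1.0, 2.0

# Fixed random control vector with norm `steering_coeff` (created once, as in get_control_vector)
rand = torch.rand(1, 1, d)
control_vec = rand / rand.norm() * steering_coeff

def masked_activation_loss(act, target, mask):
    # Same masked activation-MSE as compute_activation_loss above
    sq = F.mse_loss(act, target, reduction="none")                                  # [b, s, d]
    per_example = (sq * mask.unsqueeze(-1)).mean(dim=2).sum(dim=1, keepdim=True)    # [b, 1]
    return (per_example / mask.sum(dim=-1, keepdim=True)).mean()

# Stand-ins for activations captured at the matched module (e.g. model.layers.7)
forget_act = torch.randn(b, s, d)        # current model, forget batch
retain_act = torch.randn(b, s, d)        # current model, retain batch
ref_retain_act = torch.randn(b, s, d)    # frozen reference model, retain batch
mask = torch.ones(b, s)                  # 1.0 wherever labels != -100

forget_loss = masked_activation_loss(forget_act, control_vec.expand_as(forget_act), mask)
retain_loss = masked_activation_loss(retain_act, ref_retain_act, mask)
loss = gamma * forget_loss + alpha * retain_loss
print(loss.item())
```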