diff --git a/README.md b/README.md
index dce38e5a..a04d4866 100644
--- a/README.md
+++ b/README.md
@@ -44,7 +44,7 @@ We provide several variants for each of the components in the unlearning pipelin
## 📌 Table of Contents
- 📖 [Overview](#-overview)
-- 🗃️ [Available Components](#-available-components)
+- 🗃️ [Available Components](#%EF%B8%8F-available-components)
- ⚡ [Quickstart](#-quickstart)
- 🛠️ [Environment Setup](#-environment-setup)
- 💾 [Data Setup](#-data-setup)
@@ -56,7 +56,7 @@ We provide several variants for each of the components in the unlearning pipelin
- ➕ [How to Add New Components](#-how-to-add-new-components)
- 📚 [Further Documentation](#-further-documentation)
- 🔗 [Support & Contributors](#-support--contributors)
-- 📝 [Citing this work](#-citating-this-work)
+- 📝 [Citing this work](#-citing-this-work)
- 🤝 [Acknowledgements](#-acknowledgements)
- 📄 [License](#-license)
@@ -198,7 +198,7 @@ If you use OpenUnlearning in your research, please cite:
---
-### 🤝 Acknowledgments
+### 🤝 Acknowledgements
- This repo is inspired by [LLaMA-Factory](https://github.com/hiyouga/LLaMA-Factory).
- The [TOFU](https://github.com/locuslab/tofu) and [MUSE](https://github.com/jaechan-repo/muse_bench) benchmarks served as the foundation for our re-implementation.
diff --git a/configs/experiment/unlearn/muse/default.yaml b/configs/experiment/unlearn/muse/default.yaml
index 454a84e3..b4bdbe0f 100644
--- a/configs/experiment/unlearn/muse/default.yaml
+++ b/configs/experiment/unlearn/muse/default.yaml
@@ -34,6 +34,7 @@ eval:
muse:
data_split: ${data_split}
retain_logs_path: ${retain_logs_path}
+ overwrite: true
trainer:
args:
diff --git a/configs/experiment/unlearn/muse/scalability.yaml b/configs/experiment/unlearn/muse/scalability.yaml
index 11d90f50..b19e0cb5 100644
--- a/configs/experiment/unlearn/muse/scalability.yaml
+++ b/configs/experiment/unlearn/muse/scalability.yaml
@@ -34,6 +34,7 @@ eval:
muse:
data_split: ${data_split}
retain_logs_path: ${retain_logs_path}
+ overwrite: true
trainer:
args:
diff --git a/configs/experiment/unlearn/muse/sustainabilty.yaml b/configs/experiment/unlearn/muse/sustainabilty.yaml
index e5d79687..9a0a03e3 100644
--- a/configs/experiment/unlearn/muse/sustainabilty.yaml
+++ b/configs/experiment/unlearn/muse/sustainabilty.yaml
@@ -34,6 +34,7 @@ eval:
muse:
data_split: ${data_split}
retain_logs_path: ${retain_logs_path}
+ overwrite: true
trainer:
args:
diff --git a/configs/experiment/unlearn/tofu/default.yaml b/configs/experiment/unlearn/tofu/default.yaml
index 5f7c4757..f2e0ab1a 100644
--- a/configs/experiment/unlearn/tofu/default.yaml
+++ b/configs/experiment/unlearn/tofu/default.yaml
@@ -20,6 +20,7 @@ eval:
tofu:
forget_split: ${forget_split}
retain_logs_path: ${retain_logs_path}
+ overwrite: true
data:
anchor: forget
diff --git a/configs/experiment/unlearn/tofu/idk.yaml b/configs/experiment/unlearn/tofu/idk.yaml
index 61a365d0..5fcb85df 100644
--- a/configs/experiment/unlearn/tofu/idk.yaml
+++ b/configs/experiment/unlearn/tofu/idk.yaml
@@ -20,6 +20,7 @@ eval:
tofu:
forget_split: ${forget_split}
retain_logs_path: ${retain_logs_path}
+ overwrite: true
data:
anchor: forget
diff --git a/configs/trainer/RMU.yaml b/configs/trainer/RMU.yaml
new file mode 100644
index 00000000..dcc49e20
--- /dev/null
+++ b/configs/trainer/RMU.yaml
@@ -0,0 +1,13 @@
+defaults:
+ - GradDiff
+
+handler: RMU
+method_args:
+ # These parameters are highly model- and dataset-dependent; tune them carefully for good results.
+ gamma: 1.0
+ alpha: 1000
+ steering_coeff: 300
+ retain_loss_type: null
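+ # module_regex selects the single module whose activations are steered (forget) or matched to the reference model (retain);
+ # trainable_params_regex limits optimization to the matching parameters (see src/trainer/unlearn/rmu.py in this PR).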
+ module_regex: model\.layers\.7
+ trainable_params_regex:
+ - model\.layers\.(5|6|7)\.mlp\.down_proj\.weight
\ No newline at end of file
diff --git a/docs/results.md b/docs/results.md
index 3f7cc7c9..32a55545 100644
--- a/docs/results.md
+++ b/docs/results.md
@@ -23,7 +23,7 @@ For all the experiments below, we used the following setup
| **Hyperparameters** | Learning Rate (lr) = 1e-5
α = 1, γ = 1, β = 0.1 (where applicable)
Number of Epochs = 10
Optimizer: [paged_adamw_32bit](https://huggingface.co/docs/bitsandbytes/main/en/reference/optim/adamw#bitsandbytes.optim.PagedAdamW) |
__Note:__
-1. Results may vary even with the same effective hyperparameters when trained with modifications to the distributed training setup, including when training on a single GPU. For example: methods such as SimNPO, can be significantly improved with careful tuning. **Please use these numbers only for reproducibility purposes**.
+1. Results may vary even with the same effective hyperparameters when the distributed training setup is modified, including when training on a single GPU. For example, methods such as SimNPO and RMU can be significantly improved with careful tuning. **Please use these numbers only for reproducibility purposes**.
2. NPO in MUSE: for NPO, the MUSE implementation is inconsistent with the [original paper](https://github.com/licong-lin/negative-preference-optimization) as discussed [here]( https://github.com/jaechan-repo/muse_bench/issues/2). This inconsistency is carried over into implementations like [SimNPO](https://github.com/OPTML-Group/Unlearn-Simple/issues/5). Here, we use the original NPO implementation with the same loss function expression across datasets.
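+
+For reference, the forget loss from the original NPO paper (with π_θ the model being unlearned, π_ref the frozen reference model, and β a hyperparameter) is the expression we apply uniformly across datasets:
+
+$$\mathcal{L}_\text{NPO} = \frac{2}{\beta}\,\mathbb{E}_{(x,y)\sim\mathcal{D}_\text{forget}}\left[\log\left(1 + \left(\frac{\pi_\theta(y \mid x)}{\pi_\text{ref}(y \mid x)}\right)^{\beta}\right)\right]$$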
@@ -140,6 +140,17 @@ __Note:__
0.6 |
3.17e-04 |
+
+ | RMU | 6.76e-03 | 7.18e-04 | 0.84 | 1.21e-10 | 0 | 0.81 | 1.18e-17 | 0 | 0.8 |
@@ -257,6 +268,18 @@ __Note:__
0.54 |
1.07e-05 |
+
+ | RMU | 6.76e-03 | 0.60 | 0.47 | 2.89e-11 | 0.6 | 0.47 | 0.32 | 0.59 | 0.64 |
+
@@ -354,6 +377,17 @@ __Note:__
-54.26 |
0.54 |
+
+ | RMU | 0.67 | 0.57 | -99.81 | 0.56 | 0.47 | 1.0 | -57.35 | 0.67 |
+
\ No newline at end of file
diff --git a/scripts/tofu_unlearn.sh b/scripts/tofu_unlearn.sh
index a556bd1d..1794c9b6 100644
--- a/scripts/tofu_unlearn.sh
+++ b/scripts/tofu_unlearn.sh
@@ -13,7 +13,7 @@ trainers_experiments=(
"GradAscent unlearn/tofu/default.yaml"
"GradDiff unlearn/tofu/default.yaml"
"NPO unlearn/tofu/default.yaml"
- "DPO unlearn/tofu/default.yaml"
+ "DPO unlearn/tofu/idk.yaml"
)
forget_retain_splits=(
"forget01 retain99"
diff --git a/setup_data.py b/setup_data.py
index 48de0ad1..358779c3 100644
--- a/setup_data.py
+++ b/setup_data.py
@@ -1,8 +1,17 @@
from huggingface_hub import snapshot_download
+# Download precomputed retain-model evaluation logs (JSON), used as reference metrics during evaluation
snapshot_download(
repo_id="open-unlearning/eval",
allow_patterns="*.json",
repo_type="dataset",
local_dir="saves/eval",
)
+
+# Download the "I don't know" (IDK) responses data used by IDK-based unlearning methods (e.g., the DPO variant)
+snapshot_download(
+ repo_id="open-unlearning/idk",
+ allow_patterns="*.jsonl",
+ repo_type="dataset",
+ local_dir="data",
+)
\ No newline at end of file
diff --git a/src/trainer/__init__.py b/src/trainer/__init__.py
index 1c769bf6..7e195fa9 100644
--- a/src/trainer/__init__.py
+++ b/src/trainer/__init__.py
@@ -9,6 +9,7 @@
from trainer.unlearn.npo import NPO
from trainer.unlearn.dpo import DPO
from trainer.unlearn.simnpo import SimNPO
+from trainer.unlearn.rmu import RMU
TRAINER_REGISTRY: Dict[str, Any] = {}
@@ -79,3 +80,4 @@ def load_trainer(
_register_trainer(NPO)
_register_trainer(DPO)
_register_trainer(SimNPO)
+_register_trainer(RMU)
diff --git a/src/trainer/unlearn/grad_diff.py b/src/trainer/unlearn/grad_diff.py
index e11c7a71..bfecc19a 100644
--- a/src/trainer/unlearn/grad_diff.py
+++ b/src/trainer/unlearn/grad_diff.py
@@ -14,7 +14,7 @@ def __init__(self, gamma=1.0, alpha=1.0, retain_loss_type="NLL", *args, **kwargs
self.ref_model = self._prepare_ref_model(self.model)
def _prepare_ref_model(self, model):
- ref_model = copy.deepcopy(model).to("cuda")
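+ # Use the trainer's accelerator device instead of hard-coding "cuda", so CPU and multi-GPU setups work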
+ ref_model = copy.deepcopy(model).to(self.accelerator.device)
ref_model.eval()
if self.is_deepspeed_enabled:
ref_model = self._prepare_deepspeed(ref_model)
diff --git a/src/trainer/unlearn/rmu.py b/src/trainer/unlearn/rmu.py
new file mode 100644
index 00000000..effce154
--- /dev/null
+++ b/src/trainer/unlearn/rmu.py
@@ -0,0 +1,138 @@
+import re
+import torch
+import deepspeed
+from torch import nn
+from trainer.unlearn.grad_diff import GradDiff
+
+
+class RMU(GradDiff):
+ def __init__(self,
+ module_regex=r"model\.layers\.7",
+ trainable_params_regex=[r"model\.layers\.(5|6|7)\.mlp\.down_proj\.weight"],
+ steering_coeff=20,
+ *args, **kwargs):
+ """
+ RMU Trainer that fine-tunes only specific layers and parameters using regex-based filtering.
+
+ Args:
+ module_path (str): Regex pattern to match module names.
+ trainable_param_paths (list of str): List of regex patterns for trainable parameters.
+ """
+ super().__init__(*args, **kwargs)
+
+ # Create reference model if not already set
+ if self.ref_model is None:
+ self.ref_model = self._prepare_ref_model(self.model)
+
+ # Regex patterns selecting which parameters are trainable (applied in create_optimizer)
+ self.trainable_params_regex = trainable_params_regex
+
+ # Get actual module references
+ self.module_regex = module_regex # Regex for selecting modules
+ self.model_module = self._get_matching_module(self.model, self.module_regex)
+ self.ref_module = self._get_matching_module(self.ref_model, self.module_regex)
+ self.steering_coeff = steering_coeff
+ self.control_vec = None
+
+
+ def create_optimizer(self):
+ # Temporarily mark only the regex-selected parameters as trainable so the optimizer
+ # is built over just those, then restore requires_grad on all parameters.
+ self._set_all_params(self.model, False)
+ self._set_trainable_params(self.model, self.trainable_params_regex, True)
+ super().create_optimizer()
+ self._set_all_params(self.model, True)
+
+
+ def _get_matching_module(self, model, module_regex):
+ """Returns a single module matching the given regex from a DeepSpeed/DDP-wrapped model."""
+ # Handle DeepSpeed and DDP-wrapped models by accessing the underlying module
+ if isinstance(model, deepspeed.DeepSpeedEngine):
+ model = model.module # Extract the actual PyTorch model inside
+
+ matched_modules = {name: module for name, module in model.named_modules() if re.fullmatch(module_regex, name)}
+
+ if len(matched_modules) > 1:
+ raise ValueError(f"More than one module matched with {module_regex}: {list(matched_modules.keys())}")
+ elif not matched_modules:
+ raise ValueError(f"No module matched with {module_regex}")
+
+ return next(iter(matched_modules.values())) # Return the single matched module
+
+ def _set_all_params(self, model, requires_grad=True):
+ """Freeze all parameters in the model initially."""
+ for param in model.parameters():
+ param.requires_grad = requires_grad
+
+ def _set_trainable_params(self, model, trainable_params_regex, requires_grad=True):
+ """Unfreeze specific parameters that match the regex patterns."""
+ for name, param in model.named_parameters():
+ if any(re.fullmatch(pattern, name) for pattern in trainable_params_regex):
+ param.requires_grad = requires_grad
+ print(f"{name}:requires_grad\t{requires_grad}")
+
+ def forward_with_cache(self, model, inputs, module, no_grad=True):
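+ """Run model(**inputs) and return the output activations of `module`, captured with a forward hook."""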
+ cache = []
+ def hook(module, input, output):
+ if isinstance(output, tuple):
+ cache.append(output[0])
+ else:
+ cache.append(output)
+ return None
+
+ hook_handle = module.register_forward_hook(hook)
+ if no_grad:
+ with torch.no_grad():
+ _ = model(**inputs)
+ else:
+ _ = model(**inputs)
+ hook_handle.remove()
+ return cache[0]
+
+ def get_control_vector(self, dim):
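+ """Lazily create and cache a random control vector with norm equal to steering_coeff (reused across steps)."""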
+ if self.control_vec is None:
+ random_vector = torch.rand(1,1, dim)
+ self.control_vec = random_vector / torch.norm(random_vector) * self.steering_coeff
+ return self.control_vec
+
+
+ def compute_activation_loss(self, activation1, activation2, mask):
+ """Masked MSE between two activation tensors: mean over the hidden dim, summed over the sequence, normalized by the number of unmasked tokens per example."""
+ squared_diff = torch.nn.functional.mse_loss(activation1, activation2, reduction="none") # Shape: [b, s, d]
+ expanded_mask = mask.unsqueeze(-1).expand_as(squared_diff) # Shape: [b, s, d]
+ squared_diff_sum = (squared_diff * expanded_mask).mean(dim=2).sum(dim=1) # Mean over hidden dim, sum over seq_len; Shape: [b]
+ num_tokens = mask.sum(dim=-1) # Unmasked tokens per example; Shape: [b]
+ return (squared_diff_sum / num_tokens).mean()
+
+ def compute_retain_loss(self, model, retain_inputs):
+ model_retain_activations = self.forward_with_cache(model, retain_inputs, module=self.model_module, no_grad=False)
+ ref_retain_activations = self.forward_with_cache(self.ref_model, retain_inputs, module=self.ref_module, no_grad=True).to(model_retain_activations.device)
+ mask = (retain_inputs['labels'] != -100) # Shape: [b, s]
+ retain_loss = self.compute_activation_loss(model_retain_activations, ref_retain_activations, mask)
+ return retain_loss
+
+ def compute_loss(self, model, inputs, return_outputs=False):
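+ """RMU loss: steer forget-set activations at the target module toward a fixed random control vector, while keeping retain-set activations close to the frozen reference model's."""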
+ forget_inputs = inputs["forget"]
+ forget_inputs = {
+ "input_ids": forget_inputs["input_ids"],
+ "attention_mask": forget_inputs["attention_mask"],
+ "labels": forget_inputs["labels"],
+ }
+
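+ # Forget loss: masked distance between the model's activations and the scaled random control vector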
+ model_forget_activations = self.forward_with_cache(model, forget_inputs, self.model_module, no_grad=False)
+ control_vec = forget_inputs.get("control_vec", self.get_control_vector(model_forget_activations.shape[-1]))
+ control_vec = control_vec.to(dtype=model_forget_activations.dtype, device=model_forget_activations.device)
+ control_vec = control_vec.expand_as(model_forget_activations)
+ mask = (forget_inputs['labels'] != -100) # Shape: [b, s]
+ forget_loss = self.compute_activation_loss(model_forget_activations, control_vec, mask)
+
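+ # Retain loss: masked MSE between the model's and the reference model's activations on retain data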
+ retain_inputs = inputs["retain"]
+ retain_inputs = {
+ "input_ids": retain_inputs["input_ids"],
+ "attention_mask": retain_inputs["attention_mask"],
+ "labels": retain_inputs["labels"],
+ }
+ retain_loss = self.compute_retain_loss(model=model, retain_inputs=retain_inputs)
+
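+ # Combine the gamma-weighted forget (steering) loss with the alpha-weighted retain (preservation) loss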
+ loss = self.gamma * forget_loss + self.alpha * retain_loss
+
+ return (loss, model_forget_activations) if return_outputs else loss
\ No newline at end of file