Data Map Trainer Callback #31647

Open
nbertagnolli opened this issue Jun 26, 2024 · 1 comment
Labels
Feature request · trainer

Comments

@nbertagnolli
Contributor

Feature request

It would be nice to have a callback for the Trainer class that could create Data Maps. See the paper for more details: https://arxiv.org/pdf/2009.10795. A Data Map measures how a model's predictions on specific training examples change over the course of training.

The Callback should support:

  • Executing at each step or epoch
  • Should integrate directly with the Trainer class.
  • Should save the predictions for each training example as a matrix of the form [n_examples, n_labels] so that it can easily be stacked into [n_epochs, n_examples, n_labels]. Right now I'm saving things as a List[List[float]], but this might be suboptimal. It needs some way of getting the logged information later; a sketch of that post-processing is shown just below this list.

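To illustrate that last point, here is a minimal sketch (not part of the proposed callback) of how the saved per-epoch files could be stacked and turned into the confidence/variability statistics from the paper. The helper names load_data_map and data_map_statistics are hypothetical, it covers the multiclass case only, and it assumes the gold labels are available in the same sequential order that the predictions were saved in.

import glob
import json
import os

import numpy as np


def load_data_map(callback_dir: str) -> np.ndarray:
    """Stack the per-epoch prediction files into [n_epochs, n_examples, n_labels]."""
    paths = sorted(
        glob.glob(os.path.join(callback_dir, "*.json")),
        key=lambda p: int(os.path.splitext(os.path.basename(p))[0]),
    )
    epochs = []
    for path in paths:
        with open(path) as f:
            epochs.append(json.load(f))  # each file is [n_examples, n_labels]
    return np.array(epochs)


def data_map_statistics(logits: np.ndarray, labels: np.ndarray):
    """Compute per-example confidence and variability as in the Data Maps paper."""
    # Softmax over the label dimension to turn logits into probabilities.
    exp = np.exp(logits - logits.max(axis=-1, keepdims=True))
    probs = exp / exp.sum(axis=-1, keepdims=True)
    # Probability assigned to the gold label at each epoch: [n_epochs, n_examples].
    # labels is an integer array of gold class indices, one per training example.
    gold_probs = probs[:, np.arange(labels.shape[0]), labels]
    confidence = gold_probs.mean(axis=0)  # mean over epochs
    variability = gold_probs.std(axis=0)  # std over epochs
    return confidence, variability
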
Running this colab notebook I made will generate data map outputs for classification tasks using the Trainer, in line with what I was thinking. Here is what I have so far; it works well for multilabel and multiclass classification using transformers.

import inspect
import json
import os
import warnings
from typing import Callable, List, Optional

import torch
from torch.utils.data import BatchSampler, DataLoader, SequentialSampler

from transformers import PreTrainedModel, TrainerCallback, TrainingArguments


class DataMapCallback(TrainerCallback):
    """Trainer Callback to save DataMap data.

    Original Paper: https://arxiv.org/pdf/2009.10795.pdf.

    This callback saves the predictions of the model on each training example
    at the end of every epoch to callback_dir/{epoch}.json.
    """

    def __init__(
        self,
        log_on: str = "epoch",
        callback_dir: str = ".",
        n_log_steps: Optional[int] = None,
        prediction_fn: Optional[Callable[[PreTrainedModel, DataLoader, TrainingArguments], List[List[float]]]] = None,
    ):
        self.callback_dir = callback_dir
        self.log_on = log_on
        self.log_count = 0
        self.n_log_steps = n_log_steps
        self.prediction_fn = self._predict if prediction_fn is None else prediction_fn

        # Handle discrepancies in how we initialize the logging mode.
        if n_log_steps is not None and self.log_on != "step":
            raise ValueError(
                "n_log_steps is only valid when log_on='step'. If you want to run data maps based on steps, please specify log_on='step'."
            )
        elif n_log_steps is None and self.log_on == "step":
            warnings.warn(
                "You have not specified n_log_steps. This will result in a large number of data maps being saved; defaulting the step size to 1."
            )
            self.n_log_steps = 1

        # Create the directory if it doesn't exist.
        if not os.path.exists(self.callback_dir):
            os.makedirs(self.callback_dir, exist_ok=True)

    def _predict(self, model, train_data_loader, args):
        if train_data_loader.batch_size is None:
            batch_size = args.per_device_train_batch_size
        else:
            batch_size = train_data_loader.batch_size
        batches = BatchSampler(
            SequentialSampler(train_data_loader.dataset),
            batch_size,
            drop_last=False,
        )

        with torch.no_grad():
            predictions = []
            for batch in batches:
                # Adjust the indices to include the last element because python is [)
                start_idx, end_idx = batch[0], batch[-1] + 1

                # Extract the sample from the training dataset
                sample = train_data_loader.dataset[start_idx:end_idx]

                # Make sure to apply any data collators
                sample = train_data_loader.collate_fn(sample)

                # Move all samples to the appropriate device. We only keep the
                # keys that are both present in the dataset and accepted by the
                # model's forward signature.
                forward_args = set(inspect.getfullargspec(model.forward).args).intersection(
                    set(sample.keys())
                )
                sample = {k: torch.as_tensor(sample[k]).to(model.device) for k in forward_args}

                # Perform inference using the model
                current_preds = model(**sample)

                # Convert the predictions to a list and append them to the result
                predictions += current_preds.logits.tolist()

        return predictions

    def _save_predictions(self, model, train_data_loader, args):
        predictions = self.prediction_fn(model, train_data_loader, args)

        # Save Predictions.
        with open(os.path.join(self.callback_dir, f"{self.log_count}.json"), "w") as f:
            json.dump(predictions, f)
        self.log_count += 1

    def on_epoch_end(self, args, state, control, logs=None, **kwargs):
        """Predict on all training examples at the end of an epoch."""

        if self.log_on == "epoch":
            self._save_predictions(kwargs["model"], kwargs["train_dataloader"], args)

    def on_save(self, args, state, control, logs=None, **kwargs):
        """Predict on all training examples when a checkpoint is saved"""
        if self.log_on == "save":
            self._save_predictions(kwargs["model"], kwargs["train_dataloader"], args)

    def on_evaluate(self, args, state, control, logs=None, **kwargs):
        """Predict on all training examples when we run an evaluation loop."""
        if self.log_on == "evaluate":
            self._save_predictions(kwargs["model"], kwargs["train_dataloader"], args)

    def on_step_end(self, args, state, control, logs=None, **kwargs):
        """Predict at the end of a step."""
        if self.log_on == "step":
            self._save_predictions(kwargs["model"], kwargs["train_dataloader"], args)

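For reference, wiring this up is just a matter of passing the callback to the Trainer. This is a rough usage sketch rather than a tested recipe; the model, tokenized datasets, and directory names are placeholders assumed to already exist.

from transformers import Trainer, TrainingArguments

# model, train_dataset, and eval_dataset are placeholders assumed to exist
# (e.g. a sequence classification model and tokenized datasets).
training_args = TrainingArguments(output_dir="out", num_train_epochs=3)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    callbacks=[DataMapCallback(log_on="epoch", callback_dir="data_maps")],
)
trainer.train()

# After training, data_maps/ should contain 0.json, 1.json, ...,
# one [n_examples, n_labels] prediction matrix per epoch.
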
Motivation

Optimizing your data is as important as correctly configuring your model. Data Maps are an incredibly powerful tool that helps us understand the data we use to train specific tasks. Just as monitoring the loss in TensorBoard can surface many bugs during training, this technique can be equally valuable. In my day-to-day work it has significantly improved the performance of production models I've trained at multiple companies. I think it's really useful both for gaining insights about your data and for pushing the limits of your performance, and I'd like to see everyone get the same benefits I've seen. I wrote a blog post on doing this with sklearn if you want to see a simple example.

Your contribution

I'd love to contribute this. I have already created a working prototype in this colab notebook, which generates data map outputs for classification tasks using the Trainer, and I'm working on an example for non-classification tasks as well. If you think it's valuable and would be willing to guide me on this addition, I'd be happy to do as much of the work as possible : ).

@nbertagnolli added the Feature request label on Jun 26, 2024
@amyeroberts
Collaborator

cc @muellerzr @SunMarc
