
HyperInterval: Hypernetwork approach to training weight interval regions in continual learning


HINT: Hypernetwork approach to training weight interval regions in continual learning

Abstract

Recently, a new Continual Learning (CL) paradigm called Interval Continual Learning (InterContiNet) was proposed to control catastrophic forgetting. It relies on enforcing interval constraints on the neural network parameter space. Unfortunately, training InterContiNet is challenging because the weight space is high-dimensional, which makes intervals difficult to manage. To address this issue, we introduce HINT, a technique that employs interval arithmetic within the embedding space and uses a hypernetwork to map these intervals to the target network parameter space. We train interval embeddings for consecutive tasks and train a hypernetwork to transform these embeddings into the weights of the target network. The embedding for a given task is trained together with the hypernetwork, while the response of the target network for the previous task embeddings is preserved. Interval arithmetic therefore operates in a more manageable, lower-dimensional embedding space rather than directly on intervals in a high-dimensional weight space. Our model allows faster and more efficient training. Furthermore, HINT maintains the guarantee of not forgetting. At the end of training, we can choose one universal embedding to produce a single network dedicated to all tasks. In such a framework, the hypernetwork is used only for training and can be seen as a meta-trainer. HINT obtains significantly better results than InterContiNet and achieves state-of-the-art (SOTA) results on several benchmarks.
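To make "interval arithmetic in the embedding space" concrete, here is a minimal, framework-free sketch. It is not the repository's implementation. The sketch assumes a box-shaped task embedding [e - eps, e + eps] and propagates it through a toy one-layer "hypernetwork" using center/radius interval bound propagation. The resulting output bounds are read as bounds on target-network weights. It also shows one way to pick a universal embedding: the center of the intersection of the per-task boxes. The layer sizes, the `interval_relu` helper for deeper hypernetworks, and the center-of-intersection rule are all assumptions for illustration. The real HINT hypernetwork is a deeper network built with hypnettorch.

```python
from typing import List, Tuple

# (lower bounds, upper bounds), one entry per coordinate
Box = Tuple[List[float], List[float]]


def interval_linear(W: List[List[float]], b: List[float],
                    lower: List[float], upper: List[float]) -> Box:
    """Exact per-output bounds of y = W x + b over all x in the box [lower, upper]."""
    center = [(lo + hi) / 2 for lo, hi in zip(lower, upper)]
    radius = [(hi - lo) / 2 for lo, hi in zip(lower, upper)]
    out_lower, out_upper = [], []
    for row, bias in zip(W, b):
        mid = sum(w * c for w, c in zip(row, center)) + bias
        # |w| scales the radius: each input coordinate can push the output either way
        spread = sum(abs(w) * r for w, r in zip(row, radius))
        out_lower.append(mid - spread)
        out_upper.append(mid + spread)
    return out_lower, out_upper


def interval_relu(lower: List[float], upper: List[float]) -> Box:
    # ReLU is monotone, so applying it to each bound keeps the bounds exact
    # (used between layers of a deeper interval hypernetwork)
    return [max(0.0, v) for v in lower], [max(0.0, v) for v in upper]


def intersect(boxes: List[Box]) -> Box:
    """Coordinate-wise intersection of several boxes."""
    lower = [max(vals) for vals in zip(*(box[0] for box in boxes))]
    upper = [min(vals) for vals in zip(*(box[1] for box in boxes))]
    if any(lo > hi for lo, hi in zip(lower, upper)):
        raise ValueError("the boxes have an empty intersection")
    return lower, upper


def universal_embedding(task_boxes: List[Box]) -> List[float]:
    # Assumed rule for illustration: take the center of the region shared by all tasks
    lower, upper = intersect(task_boxes)
    return [(lo + hi) / 2 for lo, hi in zip(lower, upper)]


if __name__ == "__main__":
    eps = 0.1
    embedding = [1.0, -0.5]  # toy task embedding
    lower = [e - eps for e in embedding]
    upper = [e + eps for e in embedding]
    # Toy one-layer "hypernetwork": its outputs stand in for target-network weights
    W, b = [[2.0, -1.0], [0.5, 0.5]], [0.0, 0.1]
    w_lower, w_upper = interval_linear(W, b, lower, upper)
    print(w_lower, w_upper)
```

With these toy numbers, the two generated weights are bounded by roughly [2.2, 2.8] and [0.25, 0.45]. Any embedding inside the box yields weights inside those bounds. This is the property that lets the intervals live in the small embedding space instead of the weight space.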

Teaser

Train interval embeddings for consecutive tasks and train a hypernetwork to transform these embeddings into weights of the target network.

Scheme of the HINT training method

Environment

Use the environment.yml file to create a conda environment with the necessary libraries: conda env create -f environment.yml. The hypnettorch package is essential, as it makes it easy to create hypernetworks in PyTorch. Our implementation is based on the hypermask repository.
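After creating the environment, activate it. Its name is taken from the `name:` field of environment.yml. A quick sanity check that the key packages resolve can save a failed training run. The script below is a small helper, not part of the repository. The package list in `REQUIRED` is an assumption, and environment.yml remains the authoritative list.

```python
import importlib.util
from typing import List

# Assumed subset of key dependencies; environment.yml is the authoritative list
REQUIRED = ["torch", "hypnettorch", "numpy"]


def missing_packages(names: List[str]) -> List[str]:
    """Return the names from `names` that cannot be imported in the active environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_packages(REQUIRED)
    if missing:
        print("Missing packages:", ", ".join(missing))
    else:
        print("All required packages found.")
```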

Datasets

For the experiments and ablation study, we use 6 publicly available datasets:

The datasets may be downloaded automatically when the algorithm runs. For each dataset, the CL task division follows the corresponding papers and is specified in the supplementary materials of our paper.

Usage

The AblationResults folder contains the results of our ablation study, whereas DatasetHandlers and Utils contain handlers and functions for the datasets used in the experiments, which apply specific data augmentation policies and task divisions. Moreover, the IntervalNets folder contains interval implementations of the network architectures used in the experiments, and VanillaNets contains the basic convolutional network architectures that are used when applying the interval relaxation technique during training.

To train HINT in the Task-Incremental Learning (TIL) scenario, run python train_non_forced_method_type_scenario.py in the Training folder, where method_type is either classification or regression. To conduct a grid search in this setup, set the variable create_grid_search to True in the train_non_forced_method_type_scenario.py file and modify the lists of hyperparameters for the selected dataset in the prepare_non_forced_scenario_params.py file.
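A grid search over lists of hyperparameters typically expands to the Cartesian product of those lists. Each combination becomes one training run. The sketch below illustrates that expansion. The hyperparameter names and values in `GRID` are hypothetical; the real ones live in the prepare_*_params.py files and differ per dataset. The repository's own grid-search code may also be organized differently.

```python
import itertools
from typing import Any, Dict, Iterator, List

# Hypothetical hyperparameter lists, for illustration only
GRID: Dict[str, List[Any]] = {
    "learning_rate": [0.001, 0.0001],
    "embedding_size": [24, 48],
    "perturbation_epsilon": [0.1, 1.0],
}


def iterate_grid(grid: Dict[str, List[Any]]) -> Iterator[Dict[str, Any]]:
    """Yield one configuration dict per combination of the listed values."""
    keys = list(grid)
    for values in itertools.product(*(grid[key] for key in keys)):
        yield dict(zip(keys, values))


if __name__ == "__main__":
    configs = list(iterate_grid(GRID))
    print(len(configs))  # 8 runs: 2 * 2 * 2 combinations
```

The number of runs grows multiplicatively with every list you extend, so it is worth keeping each list short per dataset.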

To train in the Domain-Incremental Learning (DIL) scenario with nesting protocols, run python train_nested_scenario.py. This scenario currently supports classification only. To conduct a grid search in this setup, set the variable create_grid_search to True in the train_nested_scenario.py file in the Training folder and modify the lists of hyperparameters for the selected dataset in the prepare_nested_scenario_params.py file.

To train in the Class-Incremental Learning (CIL) scenario with entropy, set the variable dataset in the entropy.py file to the name of any supported dataset, e.g., dataset = "PermutedMNIST", and run python entropy.py.
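In a class-incremental setting the task identity is not given at test time. A common entropy-based workaround is to run the input through the target network generated for each task and keep the task whose prediction is most confident, meaning it has the lowest softmax entropy. The sketch below shows that selection rule on precomputed logits. It is an assumption about what entropy.py does conceptually, not a copy of it. The function names are hypothetical, and details such as which embedding is used per task may differ in the repository.

```python
import math
from typing import List, Sequence, Tuple


def softmax(logits: Sequence[float]) -> List[float]:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def entropy(probs: Sequence[float]) -> float:
    # Shannon entropy in nats; zero-probability terms contribute nothing
    return -sum(p * math.log(p) for p in probs if p > 0.0)


def infer_task(logits_per_task: Sequence[Sequence[float]]) -> int:
    """Index of the task whose output distribution has the lowest entropy."""
    entropies = [entropy(softmax(logits)) for logits in logits_per_task]
    return min(range(len(entropies)), key=entropies.__getitem__)


def predict_class(logits_per_task: Sequence[Sequence[float]]) -> Tuple[int, int]:
    """Return (inferred task, predicted class within that task)."""
    task = infer_task(logits_per_task)
    logits = logits_per_task[task]
    return task, max(range(len(logits)), key=logits.__getitem__)
```

For example, with per-task logits [[0.1, 0.0, 0.2], [5.0, 0.0, 0.0]], the first head is nearly uniform and the second is confident. So predict_class returns (1, 0).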

Citation

@inproceedings{krukowski2024HINT,
  title={HINT: Hypernetwork approach to training weight interval regions in continual learning}, 
  author={Patryk Krukowski and Anna Bielawska and Kamil Książek and Paweł Wawrzyński and Paweł Batorski and Przemysław Spurek},
  year={2024}
}

License

Copyright 2024 IDEAS NCBR https://ideas-ncbr.pl/en/ and Group of Machine Learning Research (GMUM), Faculty of Mathematics and Computer Science of Jagiellonian University https://gmum.net/.

This program is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version. You should have received a copy of the GNU General Public License along with this program. If not, see https://www.gnu.org/licenses/.
