
Feat: Various Configurable Initializations #161

Merged
52 commits merged from feat/initialization into dev_experiments on Jul 9, 2024

Conversation

@flxst (Member) commented Jun 25, 2024

What does this PR do?

This PR implements the following weight initializations (see https://arxiv.org/abs/2312.16903):

  • plain (= all weights normally distributed)
  • scaled (= same as plain, but narrower distribution for projection weights W0 & W2)
  • scaled embed (= same as scaled, but wider distribution for embedding)

A weight initialization component is introduced that modifies the model weights in place (see #168 for more details).
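
For intuition, here is a minimal PyTorch sketch of the three schemes. The function names, the base std of 0.02, the 1/sqrt(2 * num_layers) narrowing factor, the sqrt(hidden_dim) widening factor, and the GPT-2-style parameter names (c_proj, wte, wpe) are assumptions for illustration, not the PR's actual implementation:

import math

import torch.nn as nn

def init_plain(model: nn.Module, std: float = 0.02) -> None:
    # plain: all weight matrices drawn from the same normal distribution
    for param in model.parameters():
        if param.dim() >= 2:  # skip biases and norm parameters
            nn.init.normal_(param, mean=0.0, std=std)

def init_scaled(model: nn.Module, num_layers: int, std: float = 0.02) -> None:
    # scaled: like plain, but a narrower distribution for the projection
    # weights W0 & W2 (matched here via GPT-2-style "c_proj" naming)
    init_plain(model, std=std)
    proj_std = std / math.sqrt(2 * num_layers)  # narrowing factor: an assumption
    for name, param in model.named_parameters():
        if name.endswith("c_proj.weight"):
            nn.init.normal_(param, mean=0.0, std=proj_std)

def init_scaled_embed(model: nn.Module, num_layers: int, hidden_dim: int,
                      std: float = 0.02) -> None:
    # scaled embed: like scaled, but a wider distribution for the embedding
    init_scaled(model, num_layers, std=std)
    embed_std = std * math.sqrt(hidden_dim)  # widening factor: an assumption
    for name, param in model.named_parameters():
        if name.endswith(("wte.weight", "wpe.weight")):
            nn.init.normal_(param, mean=0.0, std=embed_std)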

General Changes

Breaking Changes

  • All training configs require an additional component for initialization of the raw model (i.e., the model with random weights), as shown here (a rough sketch follows below).
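
For illustration only, such an initialization block might look roughly like the following; the component key, variant key, and field names here are guesses, and the config linked above is authoritative:

model_initializer:
  component_key: model_initialization  # hypothetical keys
  variant_key: composed
  config:
    model_type: gpt2
    weight_init_type: scaled
    mean: 0.0
    std: 0.02
    num_layers: 2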

Checklist before submitting final PR

  • My PR is minimal and addresses one issue / enhancement in isolation
  • I have merged the latest version of the target branch into this feature branch
  • I have reviewed my own code w.r.t. correct implementation, missing type hints, proper documentation, etc.
  • I have run a sample config for model training
  • I have fixed all failing tests (python tests/tests.py)

@flxst flxst self-assigned this Jun 25, 2024
@flxst flxst marked this pull request as draft June 25, 2024 11:50
@flxst flxst marked this pull request as ready for review June 27, 2024 13:22
@le1nux le1nux self-requested a review June 27, 2024 14:29
@le1nux le1nux added the enhancement New feature or request label Jun 27, 2024
@le1nux (Member) left a comment

Nice work! I think the implementation is correct (checked against the paper). I only had 2-3 questions where things were not clear to me. Regarding the integration into modalities, I think we need a few more iterations to fix some issues around dependencies and coupling.

I'm not the biggest fan of having the initialization in the parent class of the models (i.e., NNModel).
The reason is that the weight initialization depends heavily on the concrete model implementation. For instance, the gpt2 model must have its c_proj parameters initialized in certain ways, whereas CoCa has differently named parameters. Currently, we do string matching in the parent class, which introduces low-level, model-specific dependencies there. The same can happen with modules: a custom linear layer introduced in a particular model would also require custom initialization in the parent class. For each of these special cases we would have to modify the parent class, creating strong coupling.

We could resolve this inverse dependency by introducing a generic WeightInitializer class that initializes the weights of a model in a generic, configurable way. The WeightInitializer would be passed to the constructor of the concrete model, and the model would then call something like weight_initializer.init_weights(self, weight_init_config). This is basically the strategy pattern, where the strategy modifies the calling object (here, the model) in place.
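
A minimal sketch of that pattern, with WeightInitializer and init_weights taken from the suggestion above; everything else (the normal-init strategy, the MyModel stub) is purely illustrative, and the init config is folded into the strategy's constructor for brevity:

from abc import ABC, abstractmethod

import torch.nn as nn

class WeightInitializer(ABC):
    # strategy interface: initializes a model's weights in place
    @abstractmethod
    def init_weights(self, model: nn.Module) -> None:
        ...

class NormalWeightInitializer(WeightInitializer):
    # generic strategy: normally distributed weights
    def __init__(self, mean: float = 0.0, std: float = 0.02):
        self.mean = mean
        self.std = std

    def init_weights(self, model: nn.Module) -> None:
        for param in model.parameters():
            if param.dim() >= 2:  # leave biases and norm scales untouched
                nn.init.normal_(param, mean=self.mean, std=self.std)

class MyModel(nn.Module):
    # sketch of a model that receives its initializer via the constructor
    def __init__(self, vocab_size: int, n_embd: int,
                 weight_initializer: WeightInitializer):
        super().__init__()
        self.wte = nn.Embedding(vocab_size, n_embd)
        # ... remaining modules ...
        weight_initializer.init_weights(self)  # modifies self in place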

We should have a generic WeightInitializer covering the general cases. For specific models, we can additionally introduce model-specific WeightInitializers. These should be instantiable as part of the hierarchical instantiation, as we do for the other components.

The config for the model and WeightInitializer would look like this:

weight_initializer:
  <weight init config...>

model:
  component_key: model
  variant_key: gpt2
  config:
    n_layer: 2
    n_head_q: 8
    n_head_kv: 4
    ffn_hidden: 128
    weight_initializer:
      instance_key: weight_initializer
      pass_type: BY_REFERENCE
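
Presumably, pass_type: BY_REFERENCE means the weight_initializer component defined at the top level is instantiated once and handed to the model's constructor as the same object, rather than being re-instantiated per consumer; that would keep a single configurable initializer reusable across components.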

Additionally, I left a bunch of mostly minor comments regarding some questions and ideas.

[10 resolved review comment threads on src/modalities/models/model.py (outdated)]
@le1nux le1nux mentioned this pull request Jul 1, 2024
@le1nux le1nux requested review from le1nux and mali-git July 3, 2024 14:16
@le1nux (Member) left a comment

LGTM :)

@le1nux le1nux merged commit 0b8dfc0 into dev_experiments Jul 9, 2024
@le1nux le1nux deleted the feat/initialization branch July 9, 2024 08:40
@le1nux le1nux mentioned this pull request Jul 9, 2024
Labels: enhancement (New feature or request)
3 participants