-
Notifications
You must be signed in to change notification settings - Fork 5
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft: Feat/initialization component #168
Conversation
… init. Need to check how we want to handle this case
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good to me generally, although I am a bit concerned about the complexity of this whole implementation. Added a few comments. The ones that contain something like "std can be auto also for initialization types other than plain" are the most important ones, I think this should definitely be fixed.
src/modalities/nn/weight_init/high_level_weight_init_factory.py
Outdated
Show resolved
Hide resolved
src/modalities/nn/weight_init/high_level_weight_init_factory.py
Outdated
Show resolved
Hide resolved
src/modalities/nn/weight_init/high_level_weight_init_factory.py
Outdated
Show resolved
Hide resolved
…or embedding during initialization
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this is another thing that needs to be fixed.
…rameterwiseNormalInitialization to support also plain initialisation
src/modalities/nn/weight_init/high_level_weight_init_factory.py
Outdated
Show resolved
Hide resolved
…en std is of type float
What does this PR do?
This PR introduces the components for weight initialisation and is based on PR #161.
In PR #161 the differenct initialization methods
plain
,scaled
andscaled_embed
(see https://arxiv.org/abs/2312.16903) were implemented and added to the abstractNNModel
class.Due to some design concerns (e.g., some GPT2 internals were called from the parent), we decided to introduce a weight initialisation component that modifies the model weights in place.
General changes
plain
,scaled
andscaled_embed
initialisation.Breaking Changes
Checklist before submitting final PR
python tests/tests.py
)