Compute MLE of biases and weights at BMM node construction #53
Draft
Cesar199999 wants to merge 1 commit into main from
Conversation
Contributor
That's indeed a shame to have the generics propagate. Maybe @jakehemmerle can jump on a live session with you to see if there's any alternative?
Contributor
Author
Sounds good! I'll book a session
The goal of this PR is to provide a solution for #29: currently, the MLEs of the weight matrix and bias vector are recomputed at each commit and prove step. The proposed solution is to store the MLEs as attributes of the BMMNode struct and compute them only once, at BMMNode construction. This saves a lot of recomputation and removes a duplicated chunk of code. Moreover, even though the memory increase can be high (roughly 8x for 256-bit fields), it shouldn't become a bottleneck in the near future: for a GPT-2-sized model (~700 MB), the cached MLEs shouldn't exceed 8 GB.
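A minimal sketch of the eager approach described above, using hypothetical simplified types (the `Mle` struct and `u64` field elements stand in for the crate's actual multilinear-extension and field types, which are not shown in this PR):

```rust
// Hypothetical simplified MLE: over the boolean hypercube, a multilinear
// extension is determined by its evaluation table, modeled here as a flat Vec.
#[derive(Clone, Debug, PartialEq)]
struct Mle {
    evals: Vec<u64>, // field elements, modeled as u64 for this sketch
}

fn mle_from_matrix(m: &[Vec<u64>]) -> Mle {
    // Placeholder: row-major flattening of the matrix as the evaluation table.
    Mle { evals: m.iter().flatten().copied().collect() }
}

fn mle_from_vector(v: &[u64]) -> Mle {
    Mle { evals: v.to_vec() }
}

struct BmmNode {
    weights: Vec<Vec<u64>>,
    bias: Vec<u64>,
    // Cached at construction: computed once, reused by every commit/prove call.
    weight_mle: Mle,
    bias_mle: Mle,
}

impl BmmNode {
    fn new(weights: Vec<Vec<u64>>, bias: Vec<u64>) -> Self {
        // The MLEs are computed exactly once, here, instead of inside
        // both commit and prove (the duplicated chunk the PR removes).
        let weight_mle = mle_from_matrix(&weights);
        let bias_mle = mle_from_vector(&bias);
        Self { weights, bias, weight_mle, bias_mle }
    }

    fn commit(&self) -> &Mle {
        // commit can keep taking &self: the cached MLE already exists.
        &self.weight_mle
    }
}
```

Note that `commit` stays an immutable borrow here, which is the main ergonomic difference from the lazy alternative discussed next.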
@Antonio95 proposed to lazily compute the MLEs at commit time and store them as optional attributes of the BMMNode struct. However, this requires changing the signature of commit for every node type to allow a mutable reference to self. A drawback of this solution is that BMM nodes must now include the field F as a type parameter; note that this would be the case for any solution that stores the MLEs as attributes. Unfortunately, the generic F bubbles up all the way to the Model struct, which adds an undesired layer of complexity. @HungryCatsStudio/maintainers, do you think there is a better solution to this problem? If not, do you think this solution is acceptable?
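For comparison, a hedged sketch of the lazy alternative (again with hypothetical simplified types), showing why `commit` needs `&mut self` and why the field type parameter `F` then propagates to any container of nodes:

```rust
// Hypothetical lazy variant: the MLE is an Option filled in on first commit.
struct LazyBmmNode<F: Clone> {
    weights: Vec<Vec<F>>,
    weight_mle: Option<Vec<F>>, // None until the first commit call
}

impl<F: Clone> LazyBmmNode<F> {
    fn new(weights: Vec<Vec<F>>) -> Self {
        Self { weights, weight_mle: None }
    }

    // Must take &mut self so the first call can write the cached MLE,
    // which is the signature change the PR description mentions.
    fn commit(&mut self) -> &Vec<F> {
        if self.weight_mle.is_none() {
            // Placeholder MLE: row-major flattening of the weight matrix.
            let evals: Vec<F> = self.weights.iter().flatten().cloned().collect();
            self.weight_mle = Some(evals);
        }
        self.weight_mle.as_ref().unwrap()
    }
}

// The generic bubbles up: any struct holding such nodes must carry F too,
// all the way to a top-level Model, as the PR description notes.
struct Model<F: Clone> {
    nodes: Vec<LazyBmmNode<F>>,
}
```

Either way the MLE lives on the node, so the `F` parameter appears in both designs; laziness only changes when the cache is filled and whether `commit` needs mutability.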