[Bert] Feature: Custom Model Outputs #31
Conversation
LGTM. Pinging @laggui for an additional review before merging.
Overall LGTM! One minor comment regarding the implementation.

But I cloned your fork with the branch changes and tried to run the example with the same arguments as the README, and it failed:

```
Model variant: roberta-base
thread 'main' panicked at src/loader.rs:304:56:
Config file present: InvalidFormat("missing field `with_pooling_layer` at line 21 column 1")
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace
```

It seems the downloaded config does not have the newly added field, so it fails to parse. We should fix that in the implementation.

Have you tried the example with your changes @bkonkle?
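For reference, a minimal sketch of one way to fix this, assuming the config is deserialized with serde: marking the new field with `#[serde(default)]` lets older downloaded configs that lack `with_pooling_layer` still parse, falling back to `false`. The struct and other field names below are illustrative, not the crate's actual definition (which may use Burn's `Config` derive instead).

```rust
use serde::{Deserialize, Serialize};

// Illustrative config struct; only `with_pooling_layer` is from the PR.
#[derive(Serialize, Deserialize)]
pub struct BertModelConfig {
    pub hidden_size: usize,
    pub vocab_size: usize,
    /// Missing in configs serialized before this change, so fall back
    /// to `false` instead of failing with `InvalidFormat`.
    #[serde(default)]
    pub with_pooling_layer: bool,
}
```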
```diff
@@ -18,13 +18,13 @@ pub struct BertEmbeddingsConfig {
 #[derive(Module, Debug)]
 pub struct BertEmbeddings<B: Backend> {
-    pad_token_idx: usize,
+    pub pad_token_idx: usize,
```
Any particular reason why this is now public?
This is specifically so that I can use it here, in the batcher for my text classification pipeline (and later for the token classification pipeline): https://github.com/bkonkle/burn-transformers/blob/0.1.0/src/pipelines/text_classification/batcher.rs#L72
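For context, here is a hedged sketch of the kind of use a batcher makes of a public `pad_token_idx`: padding every tokenized sequence in a batch to a common length. The helper name and shapes are illustrative, not the actual burn-transformers batcher code.

```rust
// Hypothetical helper, not code from burn-transformers: pads each
// tokenized sequence to the longest one in the batch using the
// model's pad token index.
fn pad_batch(mut batch: Vec<Vec<usize>>, pad_token_idx: usize) -> Vec<Vec<usize>> {
    let max_len = batch.iter().map(Vec::len).max().unwrap_or(0);
    for seq in &mut batch {
        // Extend shorter sequences with the padding token.
        seq.resize(max_len, pad_token_idx);
    }
    batch
}
```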
Thanks for the review! I tested it, but then made the last-minute change to add `with_pooling_layer`. I might be able to work up some GitHub Actions code to run those examples automatically as part of PR checks. I'll open a separate PR for that if I do.

Update: Defaulting it to `false` with …the original model.
I went with … and the examples are working again for me:

```
Model variant: roberta-base
Input: Shape { dims: [3, 63] } // (Batch Size, Seq_len)
Roberta Sentence embedding Shape { dims: [3, 768] } // (Batch Size, Embedding_dim)

Model variant: bert-base-uncased
Input: Shape { dims: [3, 64] } // (Batch Size, Seq_len)
Roberta Sentence embedding Shape { dims: [3, 768] } // (Batch Size, Embedding_dim)

Model variant: roberta-large
Input: Shape { dims: [3, 63] } // (Batch Size, Seq_len)
Roberta Sentence embedding Shape { dims: [3, 1024] } // (Batch Size, Embedding_dim)
```

I'm using this in the …
Thanks for your contribution! 🙂
I'll approve with the latest changes.
Closes #22

- Makes `pad_token_idx` public for things like the batcher to use.
- Removes `.clone()` in a few places where it isn't needed (see the sketch below).
- Adds `.envrc` to gitignore for direnv users.
- Adds `.vscode` to gitignore for VS Code users.
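An illustrative before/after of the needless `.clone()` removal mentioned above; the function and types are hypothetical, not the PR's actual code.

```rust
use std::collections::HashMap;

// Hypothetical lookup that only needs borrows of its inputs.
fn lookup(vocab: &HashMap<String, usize>, token: &str) -> Option<usize> {
    vocab.get(token).copied()
}

fn demo(vocab: HashMap<String, usize>, token: String) {
    // Before: cloning just to take a reference.
    let _ = lookup(&vocab.clone(), &token.clone());
    // After: borrow directly; no allocations.
    let _ = lookup(&vocab, &token);
}
```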