
[Bert] Feature: Custom Model Outputs #31

Merged · 2 commits · May 8, 2024

Conversation

@bkonkle (Contributor) commented May 7, 2024

Closes #22

  • Adds an optional Pooler layer for text classification using models like plain BERT instead of RoBERTa.
  • Outputs both the last hidden states and the optional Pooler output when it is enabled (see the sketch after this list).
  • Adds the pooler layer to the loader so pretrained models can be used.
  • Makes pad_token_idx public for things like the batcher to use.
  • Removes .clone() in a few places where it isn't needed.
  • Adds .envrc to .gitignore for direnv users.
  • Adds .vscode to .gitignore for VS Code users.
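
For context, a minimal sketch of what such a dual output could look like; the type and field names here are illustrative assumptions, not the exact code in this PR:

```rust
use burn::tensor::{backend::Backend, Tensor};

/// Illustrative output type: the last hidden states are always returned,
/// while the pooled representation is only present when the pooling layer
/// is enabled in the model config.
#[derive(Debug)]
pub struct BertOutput<B: Backend> {
    /// Last hidden states: [batch_size, seq_len, d_model]
    pub hidden_states: Tensor<B, 3>,
    /// Pooled output: [batch_size, d_model]; `None` when pooling is disabled.
    pub pooled_output: Option<Tensor<B, 2>>,
}
```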

@nathanielsimard (Member) left a comment

LGTM. Pinging @laggui for an additional review before merging.

@laggui (Member) left a comment

Overall LGTM! One minor comment regarding the implementation.

However, I cloned your fork with the branch changes and tried to run the example with the same arguments as in the README, and it failed:

Model variant: roberta-base
thread 'main' panicked at src/loader.rs:304:56:
Config file present: InvalidFormat("missing field `with_pooling_layer` at line 21 column 1")
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace

It seems the downloaded config does not include the newly added field, so parsing fails. We should handle that in the implementation.

Have you tried the example with your changes, @bkonkle?

@@ -18,13 +18,13 @@ pub struct BertEmbeddingsConfig {

#[derive(Module, Debug)]
pub struct BertEmbeddings<B: Backend> {
pub pad_token_idx: usize,
Member

Any particular reason why this is now public?

Contributor (Author)

This is specifically so that I can use it here, in the batcher for my text classification pipeline (and later for the token classification pipeline): https://github.com/bkonkle/burn-transformers/blob/0.1.0/src/pipelines/text_classification/batcher.rs#L72
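
As a hypothetical illustration of why a batcher wants access to the padding id (names here are made up; the linked burn-transformers batcher is the real code), padding a batch might look like this:

```rust
/// Right-pad every tokenized sequence in a batch to the longest length,
/// using the model's `pad_token_idx` as the fill value.
fn pad_batch(sequences: Vec<Vec<usize>>, pad_token_idx: usize) -> Vec<Vec<usize>> {
    let max_len = sequences.iter().map(|s| s.len()).max().unwrap_or(0);
    sequences
        .into_iter()
        .map(|mut tokens| {
            tokens.resize(max_len, pad_token_idx); // no-op if already max_len
            tokens
        })
        .collect()
}
```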

@bkonkle (Contributor, Author) commented May 7, 2024

Thanks for the review! I tested it, but then made a last-minute change to add with_pooling_layer to the config instead of passing it in as an argument. 😅 I failed to test it afterwards, and since this config property isn't present in the original BERT model config I need to default it to false. I'll fix it shortly.

I might be able to work up some GitHub Actions code to run those examples automatically as part of PR checks. I'll open a separate PR for that if I do.

Update: Defaulting it to false with #[config(default = false)] doesn't actually prevent the error, since the loader still expects the field when it attempts to parse the config from the base model file. The fact that it's not present in the base config file is why I originally structured the flag as an argument, but when Nathan suggested moving it into the config I didn't think it would be an issue. 😅 I'm working towards a solution now.

@bkonkle (Contributor, Author) commented May 7, 2024

I went with pub with_pooling_layer: Option<bool> to avoid the problems with loading the base model config, coupled with .unwrap_or(false) to resolve the wrapped value.
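
A minimal sketch of that approach, with the struct and method names assumed for illustration (the optional field lets older pretrained config files that omit the flag still deserialize):

```rust
use burn::config::Config;

#[derive(Config)]
pub struct BertModelConfig {
    /// Optional pooling layer on top of the encoder; absent in older configs.
    pub with_pooling_layer: Option<bool>,
}

impl BertModelConfig {
    /// Resolve the optional flag, treating a missing value as "no pooling".
    pub fn pooling_enabled(&self) -> bool {
        self.with_pooling_layer.unwrap_or(false)
    }
}
```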

Examples are working again for me:

Model variant: roberta-base
Input: Shape { dims: [3, 63] } // (Batch Size, Seq_len)
Roberta Sentence embedding Shape { dims: [3, 768] } // (Batch Size, Embedding_dim)

Model variant: bert-base-uncased
Input: Shape { dims: [3, 64] } // (Batch Size, Seq_len)
Roberta Sentence embedding Shape { dims: [3, 768] } // (Batch Size, Embedding_dim)

Model variant: roberta-large
Input: Shape { dims: [3, 63] } // (Batch Size, Seq_len)
Roberta Sentence embedding Shape { dims: [3, 1024] } // (Batch Size, Embedding_dim)

I'm using this in the burn-transformers library like this: https://github.com/bkonkle/burn-transformers/blob/0.2.0/src/models/bert/sequence_classification/text_classification.rs#L130-L131
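
As a hypothetical sketch of that downstream use, reusing the illustrative BertOutput type from earlier (not the actual burn-transformers code), a classifier might pick its features like this:

```rust
use burn::tensor::{backend::Backend, Tensor};

/// Prefer the pooled `[CLS]` vector when the pooling layer was enabled;
/// otherwise mean-pool the last hidden states over the sequence dimension.
fn classification_features<B: Backend>(output: BertOutput<B>) -> Tensor<B, 2> {
    match output.pooled_output {
        Some(pooled) => pooled,
        None => output.hidden_states.mean_dim(1).squeeze(1),
    }
}
```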

@laggui (Member) left a comment

Thanks for your contribution! 🙂

I'll approve with the latest changes.

@laggui merged commit 14ae737 into tracel-ai:main on May 8, 2024 (2 checks passed).
Linked issue: Custom BERT Model outputs