Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Model] Bert Embedding Model #5447

Closed
wants to merge 627 commits into from
Closed

[Model] Bert Embedding Model #5447

wants to merge 627 commits into from

Conversation

laishzh
Copy link
Contributor

@laishzh laishzh commented Jun 12, 2024

Implement Bert Embedding Model

This PR implements the Bert Embedding Model which discussed in #5179.

BEFORE SUBMITTING, PLEASE READ THE CHECKLIST BELOW AND FILL IN THE DESCRIPTION ABOVE


PR Checklist (Click to Expand)

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:

  • [Bugfix] for bug fixes.
  • [CI/Build] for build or continuous integration improvements.
  • [Doc] for documentation fixes and improvements.
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • [Frontend] For changes on the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR need to meet the following code quality standards:

  • We adhere to Google Python style guide and Google C++ style guide.
  • Pass all linter checks. Please use format.sh to format your code.
  • The code need to be well-documented to ensure future contributors can easily understand the code.
  • Include sufficient tests to ensure the project to stay correct and robust. This includes both unit tests and integration tests.
  • Please add documentation to docs/source/ if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with rfc-required and might not go through the PR.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feel confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

  • After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.
  • After the PR is assigned, the reviewer will provide status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!

) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you move this function down to the bottom of the file

Copy link
Contributor Author

@laishzh laishzh Jun 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, leave a TODO here. I will move this Class to the bottom later.

super().__init__()
self.size = config.hidden_size

self.word_embeddings = nn.Embedding(config.vocab_size,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you look into using VocabParallelEmbedding from our parallel layers?

Using the nn.Embedding will not work with tensor parallelism

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, I noticed this feature before. Already updated.

config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need dropout for inference?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All dropouts have been removed.

bias=True,
quant_config=quant_config)

self.attn = Attention(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to use bidirectional attention here rather than causal attention

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mean to implement a new type of attention in vllm? Not sure if there is a way to use the current attention implementation with different parameters.

Copy link
Contributor Author

@laishzh laishzh Jun 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@robertgshaw2-neuralmagic Hi, I'd like to hear your thoughts on this point. The framework of BERT model is almost completed. But the BertSelfAttention output differs from transformers. After diving into the Attention implementation, there are massive changes of Attention are needed to use bidirectional attention.
I also found some related discussions(#3117 (comment)). I think this PR depends on #4942. What's your opinions?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@laishzh sorry missed this note!

Yes. Pull in the BertSelfAttention from that PR. #4942 should land this week

@robertgshaw2-neuralmagic
Copy link
Collaborator

Thanks! You're on the right track here

Comment on lines 92 to 105
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)

# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue

param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this a for-else loop? Please break this up as I find these pretty confusing

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I also felt confused at the first time. I rewote this part just now. It's supposed to be easier to understand. Have any further suggestions?

config.hidden_size)
self.token_type_embeddings = nn.Embedding(config.type_vocab_size,
config.hidden_size)
self.LayerNorm = nn.LayerNorm(config.hidden_size,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Follow naming standards for layernorm

Suggested change
self.LayerNorm = nn.LayerNorm(config.hidden_size,
self.layernorm = nn.LayerNorm(config.hidden_size,

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Followed the suggestion, and added renaming of LayerNorm parameter when loading weights.

Comment on lines 358 to 360
self.LayerNorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto on layernorm and dropout

Comment on lines 387 to 389
self.LayerNorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto on layernorm and dropout

@laishzh
Copy link
Contributor Author

laishzh commented Jun 13, 2024

Thanks! You're on the right track here

Wow! Really appreciate your guidance.

Comment on lines +166 to +169
self.position_embeddings = nn.Embedding(config.max_position_embeddings,
config.hidden_size)
self.token_type_embeddings = nn.Embedding(config.type_vocab_size,
config.hidden_size)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should these be VocabParallelEmbedding as well?

@laishzh laishzh marked this pull request as ready for review August 19, 2024 07:42
@maxdebayser maxdebayser mentioned this pull request Aug 28, 2024
@laishzh
Copy link
Contributor Author

laishzh commented Sep 9, 2024

@robertgshaw2-neuralmagic @mgoin Hi, just update this work.
Because the BertEmbeddingModel is encoder-only architecture. I refactored the EmbeddingModelRunner as child class of EncoderDecoderModelRunner to reuse the encoder-decoder framework.

# Prepare PoolingMetadata.
assert model_input.seq_lens is not None
seq_lens = model_input.seq_lens\
if not self.model_config.is_encoder_model \
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@robertgshaw2-neuralmagic @mgoin Here still have a question. How to determine precisely which model is encoder model or decoder model. As an example, there is lake of fields is_decoder_model or is_encoder_model in config.json of Mistral Model which is implemented before. Here is the link: config.json of Mistral.

Signed-off-by: Max de Bayser <[email protected]>
@maxdebayser
Copy link
Contributor

Hi @laishzh, I'm working on another PR that is based on yours. In case it helps, I've solved the latest merge conflicts of your branch with main here: https://github.com/maxdebayser/vllm/tree/bert

laishzh and others added 2 commits September 26, 2024 23:23
# Conflicts:
#	vllm/inputs/data.py
@simon-mo simon-mo mentioned this pull request Oct 1, 2024
39 tasks
@maxdebayser
Copy link
Contributor

@laishzh @robertgshaw2-neuralmagic , I've solved that lastest merge conflicts here: https://github.com/maxdebayser/vllm/tree/bert

@laishzh
Copy link
Contributor Author

laishzh commented Oct 7, 2024

@robertgshaw2-neuralmagic Please draw your attention. I just revert the change to EmbeddingModelBlockManager. And there is still leaving a problem of how to distinguish the model arch(Encoder-Only, or Decoder-Only model), which is to determine the seq_len when pooling(#5447 (comment)). Please let me know if any changes are needed.

@robertgshaw2-neuralmagic
Copy link
Collaborator

Hey @laishzh - there were a few issues with this PR. Specifically, it uses the Encoder-Decoder pathway and the BertModel is not implemented properly (loading is not canonical, tensor parallelism does not work, and the pooling logic is not correct)

I need get this landed ASAP, so I finished off the PR here. #9056

I will added you as a co-author.

@laishzh
Copy link
Contributor Author

laishzh commented Oct 11, 2024

Hey @laishzh - there were a few issues with this PR. Specifically, it uses the Encoder-Decoder pathway and the BertModel is not implemented properly (loading is not canonical, tensor parallelism does not work, and the pooling logic is not correct)

I need get this landed ASAP, so I finished off the PR here. #9056

I will added you as a co-author.

I'm OK. Very glad to see this feature could be supported in vllm. I will delve into those points that you mentioned. Thanks for your work!

@DarkLight1337
Copy link
Member

Closing as superseded by #9056

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants