
Add Roberta and a few new tests for Bert #778

Closed

Conversation

erenup
Contributor

erenup commented Dec 29, 2023

Hi @ncomly-nvidia @Shixiaowei02 @kaiyux @juney-nvidia @jdemouth-nvidia and nvidia team,

First of all, thank you very much for the great work of TensorRT-LLM!

Pull Request Topic

This pull request adds support for the Roberta model and more test cases for the Bert model.

As @ncomly-nvidia's TensorRT-LLM Requests mentioned, Roberta support is expected, so I developed a Roberta model based on my understanding of the TensorRT-LLM framework.

There are also some open issues related to Bert and Roberta; I hope this PR can help address them.

Features

This PR covers two areas: Bert-related improvements and Roberta support.

Bert Improvement

For Bert model:

  • I added more test cases covering the normal "attention_mask" returned by the Huggingface transformers tokenizer. This makes real-data usage easier for beginners who use the Huggingface Bert version (see the sketch after this list).
  • When I added attention_mask, I found that the original Bert model code did not take the attention mask into account (when use_plugin=False), so I implemented it following the original Huggingface default attention-mask behavior.
  • I also implemented the BertSequenceClassification model.
  • In tests/model/test_bert.py, input_ids for the tests are randomly generated from [0, vocab_size), which covers more possible cases.
  • test_bert.py now covers 24 tests for different usages.
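
As a rough illustration of the points above, here is a minimal sketch (not the actual test code; the vocab size, batch size, and sequence length are placeholder values) of random input_ids plus a Huggingface-style 0/1 attention_mask, and the additive bias that plain (non-plugin) attention typically consumes:

```python
import torch

vocab_size, batch_size, max_len = 30522, 4, 128

# Random token ids drawn from the full vocabulary, as in the tests described above.
input_ids = torch.randint(0, vocab_size, (batch_size, max_len), dtype=torch.int32)

# 0/1 mask in the same layout a transformers tokenizer returns: per-sequence
# lengths, with positions past the length treated as padding (mask = 0).
seq_lens = torch.randint(1, max_len + 1, (batch_size,))
attention_mask = (torch.arange(max_len)[None, :] < seq_lens[:, None]).int()

# Huggingface's default approach for non-plugin attention: padded positions get
# a large negative bias so softmax effectively ignores them.
extended_mask = (1.0 - attention_mask[:, None, None, :].float()) * -1e4
```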

Roberta model:

  • I implemented a Roberta model that is similar to the Bert model. The main differences are the position_id computation and the usage of pad_token_id (see the sketch after this list).
  • For the Roberta model, RobertaModel/RobertaForQuestionAnswering/RobertaForSequenceClassification are implemented.
  • Examples for the Roberta model are added.
  • tests/model/test_roberta.py has 24 tests for different usages.
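
For context, a minimal sketch of that position_id difference, mirroring create_position_ids_from_input_ids in Huggingface transformers (the function name roberta_position_ids here is illustrative; the exact code in this PR may differ):

```python
import torch

def roberta_position_ids(input_ids: torch.Tensor, pad_token_id: int = 1) -> torch.Tensor:
    """Roberta-style position ids: offset by pad_token_id, padding keeps pad_token_id."""
    mask = input_ids.ne(pad_token_id).int()
    # Cumulative count of real tokens; multiplying by the mask keeps padding at 0,
    # and adding pad_token_id shifts the first real token to pad_token_id + 1.
    incremental_indices = torch.cumsum(mask, dim=1) * mask
    return incremental_indices + pad_token_id

# Bert, by contrast, simply uses positions 0..seq_len-1 regardless of padding.
```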

Environment and Testing

  • I tested all the code on a single 4080 GPU, so there may be some limitations.
  • All tests passed on my hardware.
  • I installed the most recent main branch in my environment (tensorrt-llm 0.7.1 in /code/tensorrt_llm of my Docker container).

Happy New Year everyone.

Thank you very much.

@juney-nvidia
Collaborator

@erenup

Thanks for the great contributions to TensorRT-LLM! Really impressed by your efforts.

We would be excited to merge your contribution into our internal repo and then release it on GitHub later, so the community can also benefit from your great work.

Currently, there are ongoing efforts to improve the TensorRT-LLM workflow, such as unifying the build logic as well as the runtime logic; Bloom/OPT, for example, have already been reimplemented with the new workflow.

And we are actively working to reimplement other models with the new workflow.

There are two mechanisms to merge your contributions:

  • Merge your PR into our internal repo based on the current workflow and reimplement it with the new workflow later.
  • Hold the merge for a while and reimplement it directly based on the new workflow.

Let's discuss which way makes the most sense.

Thanks for your great contributions to TensorRT-LLM again.

Happy new year:)

June

@erenup
Contributor Author

erenup commented Dec 30, 2023

Hi @juney-nvidia,

Thank you very much for your timely response!

A unified converting and building workflow is nice, and the new workflow is also elegant for various decoder models 👍.

However, as a beginner with TensorRT-LLM, some config terms in the new unified build.py, such as use_custom_all_reduce or use_gpt_attention_plugin, would be unfamiliar to me if I used this new script as my starting point for Bert in TensorRT-LLM. Some config terms may also not be very relevant to the Bert/Roberta models.

So I'd like to suggest that we first keep the current Bert/Roberta workflow and move to the new unified workflow later, which lets us track the changes and improvements between the current and new workflows. Also, until the new Bert/Roberta workflow arrives, the community can already benefit from the Roberta/Bert/XLM-Roberta models in TensorRT-LLM.

Thank you very much!

@juney-nvidia
Collaborator

@erenup

Hi, we had a discussion with the team, and engineers will be assigned to help merge this PR into our internal repo first and then publish it to the GitHub repo.

We will keep you posted with the progress.

Thanks

June

@erenup
Contributor Author

erenup commented Jan 4, 2024

Hi @juney-nvidia, thank you very much 👍

@erenup
Contributor Author

erenup commented Jan 4, 2024

Hi @juney-nvidia,

Previously, I also made a parallel/related PR in the TensorRT backend repo to support the deployment of TensorRT-LLM-based classification models. I hope this simplified Triton classification example can help the community deploy classification models based on your optimized transformers more easily and quickly!

I am new to both TensorRT-LLM and Triton, so there may be some misunderstandings of your framework on my side. Please feel free to adopt any useful code from these PRs at your convenience.

Thank you very much. 😊

symphonylyh requested review from symphonylyh and removed the review request for symphonylyh on January 9, 2024, 23:32
symphonylyh self-assigned this on Jan 9, 2024
@symphonylyh
Collaborator

Thanks, @erenup, for the great end-to-end contribution!
We will check your Triton backend PR as well and get your contributions into both repos. In terms of the timeline, we plan to go in two steps: first merge your RoBERTa PR here, and then the Triton one about 1-2 weeks after the first.

@erenup
Contributor Author

erenup commented Jan 11, 2024

Hi @symphonylyh Thank you very much!

@juney-nvidia
Collaborator

@erenup

Hi,

As @symphonylyh mentioned before, we have already started integrating your nice PR into our internal repo.

During the integration, we found that it may be better to add the Roberta support as a variant of the existing BERT implementation to remove duplicated code. Based on this idea, what finally gets merged into GitHub will not be exactly the same as what you have done in this PR. Your contributions will of course still be acknowledged, since you initiated the effort.
Is it okay from your perspective?

June

@erenup
Contributor Author

erenup commented Jan 12, 2024

Hi @juney-nvidia

Thank you very much.

Yes, Roberta can be a variant of BERT. As I mentioned in the first message of this PR, the "Roberta model is similar to the Bert model. The main difference is the position_id computation and the usage of pad_token_id."

However, judging by the multiple issues I have seen, this difference may not be straightforward for beginners to understand, so when we combine them, we should have a documentation or README section that tells users how to use the models correctly.

Thank you very much.

@juney-nvidia
Collaborator

Hi @erenup

Thanks. The necessary documentation will certainly be prepared to tell users how to enable it properly.

Thanks again for your contribution and great suggestions.

June

@symphonylyh
Collaborator

@erenup, we have merged the code for RoBERTa support internally, based on your PR and the refactor mentioned above (i.e., implemented as a delta against BERT instead of a standalone model).

It will be in the v0.8 release. Thanks for your contribution!

@erenup
Contributor Author

erenup commented Jan 26, 2024

Hi @symphonylyh

Looking forward to seeing the new release.

Thank you very much.

@zhangjiawei5911

Hi, I am using https://github.com/erenup/TensorRT-LLM to accelerate a four-class Bert classification model with TRT-LLM. My code runs successfully, but there are two issues that have not been resolved. The first is that the speedup of TRT-LLM on Bert is not significant; the speed is only 1.8x that of HF. The second is that the logits of the TRT-LLM-accelerated Bert differ significantly from the HF output, so I would like to ask for your guidance. Thank you and best wishes!

===========================================================
The following shows the output of the TRT-LLM-accelerated Bert (res) and the output of HF (ref). The attachment contains the relevant code:
(Pdb) res[:10]
tensor([[ 0.1119, 0.7280, 0.1652, -0.8477],
[ 1.3018, 0.4858, -0.5298, -1.4414],
[ 1.3691, 0.4497, -0.5195, -1.4941],
[ 0.9985, 0.5444, -0.1862, -1.3711],
[ 0.5024, 0.6372, 0.1122, -1.1104],
[ 0.6802, 0.6045, 0.0274, -1.1963],
[ 0.3950, 0.6201, 0.1122, -0.9956],
[ 0.5947, 0.2908, -0.0032, -0.8525],
[ 0.3560, 0.7339, 0.0871, -1.0303],
[ 0.1573, 0.3076, 0.5088, -0.6797]], device='cuda:0',
dtype=torch.float16)
(Pdb) ref[:10]
tensor([[ 0.1131, 0.7305, 0.1624, -0.8496],
[-0.3655, 0.8359, 0.4651, -0.6289],
[-0.1752, -0.0694, -0.0105, 0.1149],
[-0.9072, 0.2054, 0.8828, 0.2087],
[-0.8599, 0.2810, 1.0967, 0.0333],
[ 0.2371, 0.9204, 0.0030, -1.0273],
[-0.4199, 0.4556, 0.3325, -0.2288],
[ 1.2705, 0.5459, -0.4202, -1.5293],
[-0.4500, 0.3540, 0.6196, -0.2308],
[-0.6636, 0.3557, 0.9224, -0.1521]], device='cuda:0',
dtype=torch.float16)
trt-llm.tar.gz

@erenup
Contributor Author

erenup commented Jan 29, 2024

Hi @zhangjiawei5911,

  1. symphonylyh said this feature would be released in v0.8, so you may be able to use the official version soon; I am not sure whether your environment is the same as mine.
  2. You need to use the bert_attention plugin and other plugins for speed; all options can be found in the arguments of build.py.
  3. When you deploy it, the speed is also influenced by your GPU model and GPU utilization. Check whether your GPU utilization is above 80%; if not, you may increase the batch_size to make better use of the GPU.
  4. The logits will indeed differ between the original HF model and TensorRT because of precision. You can see more details in the Python test scripts. You can also log every layer's outputs to debug whether anything is wrong with your converted TensorRT engine, following this guide: https://github.com/NVIDIA/TensorRT-LLM/blob/release/0.5.0/docs/source/2023-05-19-how-to-debug.md (see the sketch after this list for a quick logits check).
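
As a quick sanity check on point 4, here is a minimal sketch (not from the PR or the tests; the function name compare_logits and the tolerance values are assumptions) for comparing the TRT-LLM logits against the Huggingface reference:

```python
import numpy as np

def compare_logits(res, ref, atol=1e-2, rtol=1e-2):
    """Compare TRT-LLM logits (res) against the HF reference (ref).

    Assumes both inputs are already on the CPU (e.g. tensor.cpu().numpy()).
    """
    res = np.asarray(res, dtype=np.float32)
    ref = np.asarray(ref, dtype=np.float32)
    print(f"max abs diff: {np.max(np.abs(res - ref)):.4f}")
    # With fp16, differences on the order of 1e-2 are normal; row-wise differences
    # of ~1.0, as in the paste above, usually point to a conversion issue (e.g. a
    # missing attention_mask or wrong position_ids) rather than precision loss.
    return np.allclose(res, ref, atol=atol, rtol=rtol)
```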

Thank you very much for using and testing this feature.

@zhangjiawei5911

Hi @erenup,

Thank you for your reply; your work has greatly inspired me.
I have two questions: how much speedup does TRT-LLM give for Bert compared to HF on your GPU device, and how does it compare to FasterTransformer?
Does TRT-LLM support int8 for Bert?

@erenup
Contributor Author

erenup commented Jan 29, 2024

Hi @zhangjiawei5911, Thank you very much.

I did not make many comparisons between different settings. On one 4080 GPU with a max sequence length of 128, 80%+ GPU utilization, and fp16, a 12-layer Bert with TensorRT-LLM can reach about 1k requests/s. That is fast enough for my use case.

I did not try int8 since fp16 is already enough for me.

Hope it could be useful for you.

@symphonylyh
Collaborator

Update: it will be in the official v0.8 release, and it has already been released earlier on the dev main branch, with announcement and acknowledgement: #1020.

Thanks for the contribution!

Closing for now. @erenup, please check our modified implementation based on your PR and open an issue if needed. And @zhangjiawei5911, can you please rebase your code on the latest main and, if the problem persists, file a reproducible GitHub issue? Thanks

@jayakommuru

Hi @symphonylyh, are sequence classification tasks with T5 models not supported yet?
