Using BERT/RoBERTa with the "tensorrtllm" backend directly? (no Python lib like the tensorrt-llm package) #368

Open
2 of 4 tasks
pommedeterresautee opened this issue Mar 7, 2024 · 9 comments
Labels: bug (Something isn't working), triaged (Issue has been triaged by maintainers)


@pommedeterresautee

pommedeterresautee commented Mar 7, 2024

System Info

  • OS: Ubuntu
  • GPU: A100 / RTX 3090
  • Docker image: nvcr.io/nvidia/tritonserver:24.02-trtllm-python-py3
  • Python tensorrt-llm package (version 0.9.0.dev2024030500) installed in the Docker image (no other installation)

Who can help?

As it's not obvious whether this is a doc issue or a feature request:
@ncomly-nvidia @juney-nvidia

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

I have compiled a RoBERTa-type model for sequence classification (more exactly CamemBERT, per the engine config below) with TensorRT-LLM. Accuracy is good, and so is performance.
It follows the code from the examples folder of the TensorRT-LLM repo.

If I follow the recipe from NVIDIA/TensorRT-LLM#778, Triton serves the model with the expected performance.

However, that PR relies on the tensorrt-llm Python package, which means either a custom Python environment (quite slow to load) or a custom image. If possible, I would prefer to use the vanilla image for maintenance reasons.

I tried using the tensorrtllm backend directly, but it crashes whatever I try. Here is the model configuration:

name: "tensorrt_llm"
backend: "tensorrtllm"
max_batch_size: 200

model_transaction_policy {
  decoupled: false
}

dynamic_batching {
    preferred_batch_size: [ 200 ]
    max_queue_delay_microseconds: 2000
}

input [
  {
    name: "input_ids"
    data_type: TYPE_INT32
    dims: [ -1 ]
    allow_ragged_batch: true
  },
  {
    name: "input_lengths"
    data_type: TYPE_INT32
    dims: [ 1 ]
    reshape: { shape: [ ] }
  }
]
output [
  {
    name: "output"
    data_type: TYPE_FP32
    dims: [ -1 ]
  }
]
instance_group [
  {
    count: 1
    kind: KIND_CPU
  }
]
parameters: {
  key: "FORCE_CPU_ONLY_INPUT_TENSORS"
  value: {
    string_value: "no"
  }
}
parameters {
  key: "gpt_model_path"
  value: {
    string_value: "/engines/model-ce/"
  }
}
parameters: {
  key: "gpt_model_type"
  value: {
    string_value: "v1"
  }
}
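For reference, if the model loaded, a request matching the config above could be built as in the following minimal sketch, which uses the KServe v2 JSON protocol that Triton's HTTP endpoint accepts. The URL (default HTTP port 8000, model name tensorrt_llm) and the helper names build_infer_request and infer are assumptions for illustration. Because max_batch_size is non-zero, each input carries a leading batch dimension: input_ids (dims [ -1 ]) is sent as [1, seq_len], and input_lengths (dims [ 1 ], reshaped to [ ] for the model) as [1, 1].

```python
import json
import urllib.request

# Assumed endpoint: Triton's default HTTP port and the model name from config.pbtxt.
TRITON_URL = "http://localhost:8000/v2/models/tensorrt_llm/infer"


def build_infer_request(token_ids: list[int]) -> dict:
    """Build a KServe v2 JSON request body for a single sequence.

    With max_batch_size > 0, Triton expects a leading batch dimension:
    input_ids (dims [-1]) is sent as [1, seq_len] and input_lengths
    (dims [1], reshaped to [] for the model) is sent as [1, 1].
    """
    return {
        "inputs": [
            {
                "name": "input_ids",
                "shape": [1, len(token_ids)],
                "datatype": "INT32",
                "data": token_ids,
            },
            {
                "name": "input_lengths",
                "shape": [1, 1],
                "datatype": "INT32",
                "data": [len(token_ids)],
            },
        ],
        "outputs": [{"name": "output"}],
    }


def infer(token_ids: list[int], url: str = TRITON_URL) -> list[float]:
    """POST the request and return the flattened 'output' tensor."""
    body = json.dumps(build_infer_request(token_ids)).encode("utf-8")
    req = urllib.request.Request(
        url, data=body, headers={"Content-Type": "application/json"}
    )
    with urllib.request.urlopen(req) as resp:
        result = json.load(resp)
    # KServe v2 returns outputs as a list of {name, shape, datatype, data}.
    out = next(o for o in result["outputs"] if o["name"] == "output")
    return out["data"]
```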

And /engines/model-ce/config.json contains:

{
  "builder_config": {
    "int8": false,
    "name": "CamembertForSequenceClassification",
    "precision": "float16",
    "strongly_typed": false,
    "tensor_parallel": 1,
    "use_refit": false
  },
  "plugin_config": {
    "attention_qk_half_accumulation": false,
    "bert_attention_plugin": "float16",
    "context_fmha": true,
    "context_fmha_fp32_acc": true,
    "enable_xqa": false,
    "gemm_plugin": null,
    "gpt_attention_plugin": null,
    "identity_plugin": null,
    "layernorm_quantization_plugin": null,
    "lookup_plugin": null,
    "lora_plugin": null,
    "moe_plugin": null,
    "multi_block_mode": false,
    "nccl_plugin": null,
    "paged_kv_cache": false,
    "quantize_per_token_plugin": false,
    "quantize_tensor_plugin": false,
    "remove_input_padding": false,
    "rmsnorm_quantization_plugin": null,
    "smooth_quant_gemm_plugin": null,
    "tokens_per_block": 128,
    "use_context_fmha_for_generation": false,
    "use_custom_all_reduce": false,
    "use_paged_context_fmha": false,
    "weight_only_groupwise_quant_matmul_plugin": null,
    "weight_only_quant_matmul_plugin": null
  }
}
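Since this engine was built with remove_input_padding set to false, my understanding (based on the TensorRT-LLM BERT example, not on backend documentation) is that it expects input_ids as a padded [batch, max_len] tensor together with the true length of each sequence. A minimal sketch of that preprocessing step follows. The helper name pad_batch and the default pad id of 1 (the <pad> id of RoBERTa/CamemBERT-style tokenizers) are assumptions, so read the pad id from your tokenizer.

```python
def pad_batch(
    sequences: list[list[int]], pad_id: int = 1
) -> tuple[list[list[int]], list[int]]:
    """Right-pad token-id sequences to a common length.

    pad_id=1 is an assumption (RoBERTa/CamemBERT-style <pad> id).
    Returns (padded input_ids of shape [batch, max_len], input_lengths).
    """
    if not sequences:
        raise ValueError("empty batch")
    lengths = [len(seq) for seq in sequences]
    max_len = max(lengths)
    # Append pad tokens so every row has max_len entries.
    padded = [seq + [pad_id] * (max_len - len(seq)) for seq in sequences]
    return padded, lengths
```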

However, it crashes (see below).

Is it even possible to use this backend for a BERT-like model?

With FasterTransformer development stopped and the vanilla TensorRT BERT deployment example being two years old, TensorRT-LLM seems to be the most up-to-date option for NLP models.

Expected behavior

The server prints the IP and port and serves the model.

Actual behavior

Starting the server produces these logs:

root@geantvert:/deploy-triton# tritonserver --model-repository=/configuration
W0307 19:23:28.634363 899 pinned_memory_manager.cc:271] Unable to allocate pinned system memory, pinned memory pool will not be available: no CUDA-capable device is detected
I0307 19:23:28.634412 899 cuda_memory_manager.cc:117] CUDA memory pool disabled
E0307 19:23:28.636635 899 server.cc:243] CudaDriverHelper has not been initialized.
I0307 19:23:28.638468 899 model_lifecycle.cc:469] loading: preprocessing:1
I0307 19:23:28.638521 899 model_lifecycle.cc:469] loading: tensorrt_llm:1
[TensorRT-LLM][WARNING] gpu_device_ids is not specified, will be automatically set
[TensorRT-LLM][WARNING] max_beam_width is not specified, will use default value of 1
[TensorRT-LLM][WARNING] max_tokens_in_paged_kv_cache is not specified, will use default value
[TensorRT-LLM][WARNING] Cannot find parameter with name: batch_scheduler_policy
[TensorRT-LLM][WARNING] enable_chunked_context is not specified, will be set to false.
[TensorRT-LLM][WARNING] kv_cache_free_gpu_mem_fraction is not specified, will use default value of 0.9 or max_tokens_in_paged_kv_cache
[TensorRT-LLM][WARNING] enable_trt_overlap is not specified, will be set to false
[TensorRT-LLM][WARNING] normalize_log_probs is not specified, will be set to true
[TensorRT-LLM][WARNING] exclude_input_in_output is not specified, will be set to false
[TensorRT-LLM][WARNING] max_attention_window_size is not specified, will use default value (i.e. max_sequence_length)
[TensorRT-LLM][WARNING] enable_kv_cache_reuse is not specified, will be set to false
[TensorRT-LLM][WARNING] Parameter version cannot be read from json:
[TensorRT-LLM][WARNING] [json.exception.out_of_range.403] key 'version' not found
[TensorRT-LLM][INFO] No engine version found in the config file, assuming engine(s) built by old builder API.
[TensorRT-LLM][WARNING] Parameter pipeline_parallel cannot be read from json:
[TensorRT-LLM][WARNING] [json.exception.out_of_range.403] key 'pipeline_parallel' not found
E0307 19:23:28.734453 899 backend_model.cc:691] ERROR: Failed to create instance: unexpected error when creating modelInstanceState: [json.exception.out_of_range.403] key 'num_layers' not found
E0307 19:23:28.734514 899 model_lifecycle.cc:638] failed to load 'tensorrt_llm' version 1: Internal: unexpected error when creating modelInstanceState: [json.exception.out_of_range.403] key 'num_layers' not found
I0307 19:23:28.734538 899 model_lifecycle.cc:773] failed to load 'tensorrt_llm'
...
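The log suggests the backend reads config.json as a decoder (GPT-style) engine config: it notes the missing version key, warns about pipeline_parallel, and then fails because num_layers is absent. The sketch below lists which of those three keys an engine config lacks. The key locations are an assumption (version at the top level, the other two under builder_config next to the existing tensor_parallel key), and the helper names are hypothetical. It only explains the error message; it does not make an encoder engine loadable.

```python
import json
from pathlib import Path

# Keys named in the backend log above. Their location is an assumption:
# "version" at the top level, the others under "builder_config".
TOP_LEVEL_KEYS = ["version"]
BUILDER_CONFIG_KEYS = ["pipeline_parallel", "num_layers"]


def missing_keys(config: dict) -> list[str]:
    """Return the logged keys that an engine config dict lacks."""
    missing = [key for key in TOP_LEVEL_KEYS if key not in config]
    builder = config.get("builder_config", {})
    missing += [
        f"builder_config.{key}" for key in BUILDER_CONFIG_KEYS if key not in builder
    ]
    return missing


def check_engine_dir(engine_dir: str) -> list[str]:
    """Load <engine_dir>/config.json and report the missing keys."""
    config = json.loads((Path(engine_dir) / "config.json").read_text())
    return missing_keys(config)
```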

Additional notes

N/A

@pommedeterresautee pommedeterresautee added the bug Something isn't working label Mar 7, 2024
@byshiue
Collaborator

byshiue commented Mar 13, 2024

Currently, the backend only supports decoder models.

@byshiue byshiue self-assigned this Mar 13, 2024
@byshiue byshiue added the triaged Issue has been triaged by maintainers label Mar 13, 2024
@pommedeterresautee
Author

Thanks a lot, @byshiue, for your answer.
Is encoder support planned?

Our use cases for encoder models are RAG-related.
Vectorization is compute-heavy at indexing time, and reranking is quite heavy at inference time (depending, obviously, on how many docs you rerank).
I guess that in 2024 plenty of companies are building RAG systems.
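To make the cost argument concrete, here is a toy count of encoder forward passes in a retrieve-then-rerank pipeline. The pipeline shape (embed every document once, embed each query once, then rerank top_k candidates per query with a cross-encoder) is an assumption for illustration, and the function names are hypothetical. Indexing cost scales with corpus size, while reranking cost scales with the number of queries times top_k.

```python
import math


def encoder_passes(num_queries: int, num_docs: int, top_k: int) -> dict[str, int]:
    """Count encoder forward passes for a toy retrieve-then-rerank pipeline.

    Bi-encoder: each document is embedded once at indexing time, and each
    query needs one pass. Cross-encoder reranking: each query is paired
    with each of its top_k candidates, one pass per (query, doc) pair.
    """
    return {
        "indexing (bi-encoder)": num_docs,
        "query embedding (bi-encoder)": num_queries,
        "reranking (cross-encoder)": num_queries * top_k,
    }


def cosine(u: list[float], v: list[float]) -> float:
    """Bi-encoder scoring: cosine similarity of two precomputed embeddings."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)
```

For example, 1,000 queries that each rerank 50 candidates cost 50,000 cross-encoder passes, whereas the bi-encoder side needs only 1,000 query passes once the corpus is indexed.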

FWIW, on A10 GPUs we got a 2.2x speedup over PyTorch FP16 for reranking (cross-encoder) at batch size 64 and an average sequence length of 430, and, for our data, a 3.1x speedup for indexing (bi-encoder setup).
So TRT-LLM for RAG (meaning support for encoder-only models) makes a lot of sense, and direct support in the backend would help.
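The speedups quoted above are ratios of baseline (PyTorch FP16) latency to TensorRT-LLM latency at a fixed batch shape. A minimal, hypothetical harness for this kind of comparison follows; it is not the author's benchmark code. Warmup calls absorb one-time costs such as lazy initialization, and the median is less sensitive to outliers than the mean.

```python
import statistics
import time
from typing import Callable


def median_latency(fn: Callable[[], object], warmup: int = 3, repeats: int = 20) -> float:
    """Median wall-clock seconds per call, measured after warmup calls.

    For GPU work, fn must wait for its result (synchronize); otherwise the
    timing only covers kernel launches.
    """
    for _ in range(warmup):
        fn()
    samples = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        samples.append(time.perf_counter() - start)
    return statistics.median(samples)


def speedup(baseline_s: float, candidate_s: float) -> float:
    """How many times faster the candidate is (2.2 means 2.2x)."""
    return baseline_s / candidate_s
```

With your own zero-argument callables that each run one batch (run_pytorch_batch and run_trtllm_batch are hypothetical names), speedup(median_latency(run_pytorch_batch), median_latency(run_trtllm_batch)) gives the factor.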

@vectornguyen76
Copy link

@pommedeterresautee Why don't you use TensorRT for the embedding model instead of TensorRT-LLM?

@robosina

robosina commented Mar 18, 2024

> Currently, the backend only supports decoder models.

@byshiue Can't we just chain models (an ensemble) for any encoder-decoder model? I mean, the encoder's output serves as the input for the decoder, and I guess this also applies to the cross-attention layer? What constraints prevent us from using encoder-decoder models here? Thanks in advance.
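The chaining described above can be illustrated with a toy cross-attention step: the encoder runs once, and its output serves as keys and values for the decoder's queries. This is a single-head sketch without the learned projections, masking, or multiple heads a real model has, and it says nothing about how TensorRT-LLM or a Triton ensemble would wire two engines together.

```python
import math


def softmax(scores: list[float]) -> list[float]:
    """Numerically stable softmax over a list of scores."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def cross_attention(
    decoder_states: list[list[float]], encoder_out: list[list[float]]
) -> list[list[float]]:
    """Single-head scaled dot-product cross-attention without learned weights.

    decoder_states: tgt_len vectors of size d (queries, from the decoder).
    encoder_out:    src_len vectors of size d (keys and values, computed
                    once by the encoder and reused at every decoder step).
    Returns tgt_len vectors of size d, each a weighted mix of encoder rows.
    """
    d = len(encoder_out[0])
    outputs = []
    for query in decoder_states:
        # Scaled similarity of this decoder position to every encoder position.
        scores = [
            sum(q * k for q, k in zip(query, key)) / math.sqrt(d)
            for key in encoder_out
        ]
        weights = softmax(scores)
        # Weighted sum of encoder rows (values == keys in this toy).
        outputs.append(
            [sum(w * row[i] for w, row in zip(weights, encoder_out)) for i in range(d)]
        )
    return outputs
```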

@byshiue
Collaborator

byshiue commented Mar 19, 2024

@robosina It is just not supported yet; it is not that it cannot be supported.

@jayakommuru

Hi @byshiue, is sequence classification with T5 models not supported yet?

@yaysummeriscoming

> @robosina It is just not supported yet; it is not that it cannot be supported.

I'd love to see this feature. Is there anywhere I can track it?

@WissamAntoun

@pommedeterresautee, did you notice speedups when comparing TensorRT-LLM against TensorRT (from transformer-deploy) or kernl?

@pommedeterresautee
Author

On large batches, yes, but we use custom code to reach peak performance.

7 participants