
Qwen 2 build.py multi gpu with 2 different GPU's issue #98

Open
teis-e opened this issue Mar 26, 2024 · 17 comments

Comments

teis-e commented Mar 26, 2024

Model: Qwen1.5-72B-Chat-GPTQ-Int4

python3 gptq_convert.py --hf_model_dir /root/.cache/huggingface/hub/models--Qwen--Qwen1.5-72B-Chat-GPTQ-Int4/snapshots/b8665876947e59ffb3fbf5b5caa9bd354e885a13 --tokenizer_dir /root/.cache/huggingface/hub/models--Qwen--Qwen1.5-72B-Chat-GPTQ-Int4/snapshots/b8665876947e59ffb3fbf5b5caa9bd354e885a13

Only 1 of my 2 GPUs is loading the model, and then I get a CUDA out of memory error:
torch.cuda.OutOfMemoryError: CUDA out of memory.

I already tried:

import torch
import torch.nn as nn

if args.device == "cuda":
    # model.cuda()
    device = torch.device("cuda")
    model = nn.DataParallel(model)  # added code: replicate the model across GPUs
    model.to(device)

When I run the model with transformers normally, it distributes across both GPUs correctly.

Does anyone know a fix?
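
For context: nn.DataParallel replicates the whole model on every GPU, so it does not reduce per-GPU memory. The behaviour where transformers "distributes it correctly" normally comes from accelerate sharding the layers via device_map="auto". A minimal sketch of that loading path, assuming transformers, accelerate, and a GPTQ backend such as auto-gptq/optimum are installed:

from transformers import AutoModelForCausalLM, AutoTokenizer

model_dir = "/root/.cache/huggingface/hub/models--Qwen--Qwen1.5-72B-Chat-GPTQ-Int4/snapshots/b8665876947e59ffb3fbf5b5caa9bd354e885a13"

tokenizer = AutoTokenizer.from_pretrained(model_dir)
# device_map="auto" lets accelerate place different layers on different GPUs,
# instead of putting the whole model on one of them.
model = AutoModelForCausalLM.from_pretrained(
    model_dir,
    device_map="auto",
    trust_remote_code=True,
).eval()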


teis-e commented Mar 26, 2024

Never mind, I understand now: I'm using an already quantized model. But now, with this command:

python3 build.py --use_weight_only \
    --weight_only_precision int4_gptq \
    --per_group \
    --hf_model_dir /root/.cache/huggingface/hub/models--Qwen--Qwen1.5-72B-Chat-GPTQ-Int4/snapshots/b8665876947e59ffb3fbf5b5caa9bd354e885a13 \
    --quant_ckpt_path /root/.cache/huggingface/hub/models--Qwen--Qwen1.5-72B-Chat-GPTQ-Int4/snapshots/b8665876947e59ffb3fbf5b5caa9bd354e885a13

Loading checkpoint shards: 100%|████████████████████████████████████| 11/11 [00:28<00:00,  2.59s/it]
loading attention weight...: 100%|██████████████| 80/80 [01:24<00:00,  1.06s/it]
loading other weight...:   0%|                         | 0/2963 [00:00<?, ?it/s][03/26/2024-20:19:17] [TRT-LLM] [I] converting: model.embed_tokens.weight
loading other weight...:  85%|███████████▉  | 2523/2963 [04:08<00:38, 11.36it/s]Killed

The process gets killed at a seemingly random point during weight loading.
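
A bare "Killed" during weight loading usually means the Linux OOM killer stopped the process because host RAM (not GPU memory) ran out; the 72B weights are being materialized in CPU memory at that point. One way to confirm, assuming a standard Linux environment:

# Look for an OOM-killer entry for the python3 process
dmesg | grep -i -E "out of memory|killed process" | tail -n 5

# Check available host RAM and swap
free -h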


teis-e commented Mar 26, 2024

Now, after removing .cpu() in weight.py,

I get this error after loading the weights:

[03/26/2024-21:07:58] [TRT] [W] UNSUPPORTED_STATESkipping tactic 8 due to insufficient memory on requested size of 19328401408 detected for tactic 0x000000000000001e.
[03/26/2024-21:07:58] [TRT] [E] 2: [virtualMemoryBuffer.cpp::resizePhysical::140] Error Code 2: OutOfMemory (no further information)
[03/26/2024-21:07:58] [TRT] [E] 2: [virtualMemoryBuffer.cpp::resizePhysical::140] Error Code 2: OutOfMemory (no further information)
[03/26/2024-21:07:58] [TRT] [W] Requested amount of GPU memory (19328401408 bytes) could not be allocated. There may not be enough free memory for allocation to succeed.

Although there is still enough VRAM available on my other GPU (and the model does load on both GPUs).

Then, after that, this happens:
[03/26/2024-21:08:15] [TRT] [E] 2: [globWriter.cpp::makeResizableGpuMemory::423] Error Code 2: OutOfMemory (no further information)
[03/26/2024-21:08:15] [TRT-LLM] [E] Engine building failed, please check the error log.
[03/26/2024-21:08:15] [TRT-LLM] [I] Serializing engine to /app/tensorrt_llm/examples/qwen2/trt_engines/fp16/1-gpu/rank0.engine...
Traceback (most recent call last):
  File "/app/tensorrt_llm/examples/qwen2/build.py", line 887, in <module>
    build(0, args)
  File "/app/tensorrt_llm/examples/qwen2/build.py", line 868, in build
    engine.save(args.output_dir)
  File "/usr/local/lib/python3.10/dist-packages/tensorrt_llm/runtime/engine.py", line 60, in save
    serialize_engine(
  File "/usr/local/lib/python3.10/dist-packages/tensorrt_llm/runtime/engine.py", line 18, in serialize_engine
    f.write(engine)
TypeError: a bytes-like object is required, not 'NoneType'

Even after increasing swap memory, I still get the same error.

The issue seems to be that at the beginning it recognizes both GPUs, but after the weights are loaded it only recognizes one of my GPUs.

@teis-e teis-e changed the title Qwen 2 gptq_convert.py multi gpu issue Qwen 2 build.py multi gpu issue Mar 26, 2024

Tlntin commented Mar 27, 2024

Try adding --world_size 2 --tp_size 2 --pp_size 1 to the build command.
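
For reference, the build command from earlier with those flags added would look like this:

python3 build.py --use_weight_only \
    --weight_only_precision int4_gptq \
    --per_group \
    --hf_model_dir /root/.cache/huggingface/hub/models--Qwen--Qwen1.5-72B-Chat-GPTQ-Int4/snapshots/b8665876947e59ffb3fbf5b5caa9bd354e885a13 \
    --quant_ckpt_path /root/.cache/huggingface/hub/models--Qwen--Qwen1.5-72B-Chat-GPTQ-Int4/snapshots/b8665876947e59ffb3fbf5b5caa9bd354e885a13 \
    --world_size 2 --tp_size 2 --pp_size 1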


teis-e commented Mar 27, 2024

I built it successfully now. I got 2 engine files in

/app/tensorrt_llm/examples/qwen2/trt_engines/fp16/1-gpu

Does it automatically work with both GPUs used together?


Tlntin commented Mar 27, 2024

Yes, you can run it like this:

mpirun -n 2 --allow-run-as-root  \
    python3 run.py --max_new_tokens=50 \
               --tokenizer_dir /root/.cache/huggingface/hub/models--Qwen--Qwen1.5-72B-Chat-GPTQ-Int4/snapshots/b8665876947e59ffb3fbf5b5caa9bd354e885a13  \
               --engine_dir=xxxxxx


teis-e commented Mar 27, 2024

When I run the engine in llama-index, I get this error, although I'm running it in the same Docker container:

[03/27/2024-11:09:02] [TRT] [E] 6: The engine plan file is generated on an incompatible device, expecting compute 8.0 got compute 8.9, please rebuild.

Any idea?
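
For context: a TensorRT engine is tied to the compute capability of the GPU it was built on, and here rank0.engine was built on the A100 (compute 8.0) while llama-index is trying to load it on the RTX 4090 (compute 8.9). A small sketch to see what each visible GPU reports:

import torch

# Print the compute capability of every visible GPU;
# the A100 should report 8.0 and the RTX 4090 should report 8.9.
for i in range(torch.cuda.device_count()):
    name = torch.cuda.get_device_name(i)
    major, minor = torch.cuda.get_device_capability(i)
    print(f"GPU {i}: {name} -> compute capability {major}.{minor}")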


Tlntin commented Mar 27, 2024

What's your GPU?


teis-e commented Mar 27, 2024

I know, but llama-index has a TensorRT-LLM integration. I ran a Llama-13B engine before. Or are you saying that Qwen2 is not supported in llama-index anyway?

llm = LocalTensorRTLLM(
    model_path="./model",
    engine_name="rank0.engine",
    tokenizer_dir="Qwen1.5-72B-Chat-GPTQ-Int4",
    completion_to_prompt=completion_to_prompt,
    verbose=True,
    max_new_tokens=640,
    temperature=0
)

These are my GPUs:

# nvidia-smi
Wed Mar 27 11:20:22 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.146.02             Driver Version: 535.146.02   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA A100-PCIE-40GB          Off | 00000000:01:00.0 Off |                    0 |
| N/A   40C    P0              38W / 250W |     13MiB / 40960MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA GeForce RTX 4090        Off | 00000000:03:00.0 Off |                  Off |
|  0%   51C    P8              35W / 450W |    365MiB / 24564MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                                         
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
+---------------------------------------------------------------------------------------+


teis-e commented Mar 27, 2024

This command:

mpirun -n 2 --allow-run-as-root  \
    python3 run.py --max_new_tokens=50 \
               --tokenizer_dir /root/.cache/huggingface/hub/models--Qwen--Qwen1.5-72B-Chat-GPTQ-Int4/snapshots/b8665876947e59ffb3fbf5b5caa9bd354e885a13  \
               --engine_dir=xxxxxx

Gives:
run.py: error: unrecognized arguments: --max_new_tokens=50


Tlntin commented Mar 27, 2024

Use --max_output_len=50 instead of --max_new_tokens=50.
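
So the run command from above becomes:

mpirun -n 2 --allow-run-as-root  \
    python3 run.py --max_output_len=50 \
               --tokenizer_dir /root/.cache/huggingface/hub/models--Qwen--Qwen1.5-72B-Chat-GPTQ-Int4/snapshots/b8665876947e59ffb3fbf5b5caa9bd354e885a13  \
               --engine_dir=xxxxxx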


Tlntin commented Mar 27, 2024

I found that you are using two different GPUs, an A100 and a 4090.
The code builds the engine only with the first GPU (maybe a bug?).
I think you can build twice, using a different GPU and a different output path each time by setting the CUDA_VISIBLE_DEVICES environment variable.
Then move the second file, rank1.engine, which was built on the second GPU, into the first path.


teis-e commented Mar 27, 2024

But the model only fits across both GPUs together. It's a little over 40 GB; is that a problem?


Tlntin commented Mar 28, 2024

I mean that the 4090 only needs to hold half of the model if you enable --world_size 2 --tp_size 2 --pp_size 1. However, generally speaking, identical GPUs run more stably, and different GPUs may encounter various errors.
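
As rough arithmetic based on the sizes mentioned in this thread: the int4 checkpoint is a little over 40 GB, so with --tp_size 2 each rank holds roughly 20 GB of weights. That fits inside the 4090's 24 GB, but leaves only a few GB for activations, KV cache, and the TensorRT build workspace, so it is a tight fit.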


teis-e commented Mar 28, 2024

Try adding --world_size 2 --tp_size 2 --pp_size 1 to the build command.

That is what I already did.


Tlntin commented Mar 28, 2024

I mean:

First build (with the A100):

python3 build.py --use_weight_only \
    --weight_only_precision int4_gptq \
    --per_group \
    --hf_model_dir /root/.cache/huggingface/hub/models--Qwen--Qwen1.5-72B-Chat-GPTQ-Int4/snapshots/b8665876947e59ffb3fbf5b5caa9bd354e885a13 \
    --quant_ckpt_path /root/.cache/huggingface/hub/models--Qwen--Qwen1.5-72B-Chat-GPTQ-Int4/snapshots/b8665876947e59ffb3fbf5b5caa9bd354e885a13 \
    --world_size 2 --tp_size 2 --pp_size 1 \
    --output_dir trt_engines/fp16/1-gpu/

Second build (with the 4090):

CUDA_VISIBLE_DEVICES=1 \
python3 build.py --use_weight_only \
    --weight_only_precision int4_gptq \
    --per_group \
    --hf_model_dir /root/.cache/huggingface/hub/models--Qwen--Qwen1.5-72B-Chat-GPTQ-Int4/snapshots/b8665876947e59ffb3fbf5b5caa9bd354e885a13 \
    --quant_ckpt_path /root/.cache/huggingface/hub/models--Qwen--Qwen1.5-72B-Chat-GPTQ-Int4/snapshots/b8665876947e59ffb3fbf5b5caa9bd354e885a13 \
    --world_size 2 --tp_size 2 --pp_size 1 \
    --output_dir trt_engines/fp16/2-gpu/

Move the rank1.engine from the second path into the first:

mv trt_engines/fp16/2-gpu/rank1.engine trt_engines/fp16/1-gpu/

Then run:

mpirun -n 2 --allow-run-as-root  \
    python3 run.py --max_output_len=50 \
               --tokenizer_dir /root/.cache/huggingface/hub/models--Qwen--Qwen1.5-72B-Chat-GPTQ-Int4/snapshots/b8665876947e59ffb3fbf5b5caa9bd354e885a13  \
               --engine_dir=xxxxxx
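
Before running, it may be worth checking that both rank engines ended up in the directory passed to --engine_dir, for example:

ls trt_engines/fp16/1-gpu/
# should list rank0.engine (from the A100 build) and rank1.engine (moved over from the 4090 build)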


teis-e commented Mar 28, 2024

I tried it.

Running the first command successfully builds an engine.

The second command (for the 4090) gives an error:
/usr/local/lib/python3.10/dist-packages/accelerate/utils/modeling.py:1341: UserWarning: Current model requires 5645993472 bytes of buffer for offloaded layers, which seems does not fit any GPU's remaining memory. If you are experiencing a OOM later, please consider using offload_buffers=True.
warnings.warn(
Loading checkpoint shards: 100%|████████████████████████████████████| 11/11 [00:21<00:00, 1.98s/it]
[03/28/2024-11:38:06] Some parameters are on the meta device device because they were offloaded to the cpu.

Traceback (most recent call last):
  File "/app/tensorrt_llm/examples/qwen2/build.py", line 887, in <module>
    build(0, args)
  File "/app/tensorrt_llm/examples/qwen2/build.py", line 864, in build
    engine = build_rank_engine(
  File "/app/tensorrt_llm/examples/qwen2/build.py", line 672, in build_rank_engine
    load_from_gptq_qwen(tensorrt_llm_qwen=tensorrt_llm_qwen,
  File "/app/tensorrt_llm/examples/qwen2/weight.py", line 734, in load_from_gptq_qwen
    trust_remote_code=True).eval().cpu()
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 967, in cpu
    return self._apply(lambda t: t.cpu())
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 810, in _apply
    module._apply(fn)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 810, in _apply
    module._apply(fn)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 810, in _apply
    module._apply(fn)
  [Previous line repeated 1 more time]
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 833, in _apply
    param_applied = fn(param)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 967, in <lambda>
    return self._apply(lambda t: t.cpu())
NotImplementedError: Cannot copy out of meta tensor; no data!

Also, this command starts filling all the VRAM on the A100, not the 4090; I don't know if that is supposed to happen.
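
One thing worth checking here: the indices used by CUDA_VISIBLE_DEVICES follow CUDA's own enumeration order, which is not guaranteed to match the order shown by nvidia-smi; setting CUDA_DEVICE_ORDER=PCI_BUS_ID makes the two agree. A quick sanity check of which GPU the mask actually selects:

# With the mask applied, PyTorch should see exactly one device, and it should be the 4090.
CUDA_DEVICE_ORDER=PCI_BUS_ID CUDA_VISIBLE_DEVICES=1 \
    python3 -c "import torch; print(torch.cuda.device_count(), torch.cuda.get_device_name(0))"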


teis-e commented Apr 1, 2024

Has anyone had success building with 2 different GPUs?

@Tlntin Tlntin changed the title Qwen 2 build.py multi gpu issue Qwen 2 build.py multi gpu with 2 different GPU's issue Apr 29, 2024