[Enhancement] Add option and functionality to set torch stream as the current stream (#629)

Use the torch C++ API to set the current stream to the current torch stream.

Implementation:
- Build a hidet-torch shared library that wraps the original torch C++ API (the original API involves torch-defined structures such as `CUDAStream`, so it cannot easily be dlopened and accessed at runtime).
- dlopen the newly added hidet-torch library and access torch's current stream through it (see the sketch after this list).
- Add a `use_torch_stream` option to hidet so the stream can be switched at runtime between torch's current stream and hidet's own stream.
- When hidet's CUDA graph mode is on, hidet still creates a new hidet stream and captures the graph on that stream instead of using the torch stream.

Benefits:
- Removes the overhead of querying and calling torch's current-stream API from the Python side.
- Can also reduce the overhead incurred in the Hexcute integration, since `set_to_torch_stream` is called in the launch function; the stream query/switch on the Python side can be removed.

Performance improvement (measured on an L4 with locked clocks: 6250 MHz compute / 1500 MHz memory):

1. Hexcute kernel (without CUDA graph). CUDA graph was manually disabled on the DMWL (vLLM) side, so both the prefill and decoding stages use the generic model and call the Hexcute kernel directly.

   Command: `python3 benchmark_latency.py --model hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4 --input-len 1024 --output-len 128 --batch-size 8 --num-iters-warmup 5 --num-iters 10 --max-model-len 32768 --quantization awq_hidet`

   Comparison before and after removing the stream query and stream switch before the Hexcute kernel call (CentML/DMWL#121):
   - Before: avg latency 12.624572871897545 seconds
   - After: avg latency 11.764245539499097 seconds

2. Profiling small kernels in hidet and measuring latency:
   - With CUDA graph: `python bench_op_torch_api.py --params 16x16,16x16 --mode max-autotune matmul`
     - Before: 0.27151119 seconds
     - After: 0.25410826999999997 seconds
   - Without CUDA graph: `python bench_op_torch_api.py --params 16x16,16x16 --mode max-autotune-no-cudagraphs matmul`
     - Before: 0.14555310999999999 seconds
     - After: 0.11648335 seconds

Related to #563.
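For context, a minimal sketch of the dlopen mechanism the commit message describes. The library name below is an illustrative assumption (the exported symbol name is taken from the C++ source in the diff further down); the actual loading code in hidet may differ.

```cpp
// Hypothetical sketch of loading the hidet-torch shim at runtime.
// Assumes torch's own shared libraries are already loaded into the process
// (e.g., because `import torch` ran on the Python side).
#include <dlfcn.h>
#include <cstdio>

using get_stream_fn = void *(*)();

void *get_current_torch_stream_or_null() {
    // Library name is an assumption for illustration.
    static void *handle = dlopen("libhidet_torch_wrapper.so", RTLD_NOW | RTLD_GLOBAL);
    if (handle == nullptr) {
        std::fprintf(stderr, "dlopen failed: %s\n", dlerror());
        return nullptr;
    }
    // Symbol name matches the wrapper defined in the diff below.
    static auto fn = reinterpret_cast<get_stream_fn>(
        dlsym(handle, "hidet_get_current_torch_cuda_stream"));
    return fn != nullptr ? fn() : nullptr;
}
```

Resolving the wrapper lazily through `dlsym` keeps torch an optional runtime dependency: hidet itself never links against torch's C++ libraries directly.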
1 parent c583fb2 · commit 3e6a291 · 18 changed files with 245 additions and 14 deletions.
Changed file (runtime requirements; filename not captured in this view):

```diff
@@ -52,4 +52,4 @@ lark
 scipy
 
 # for torch runtime api dependency
-torch>=2.3.0
+torch>=2.3.0
```

The before and after lines are textually identical; the hunk most likely fixes a missing trailing newline at the end of the file.
New file (header; filename not captured in this view), hunk @@ -0,0 +1,15 @@:

```cpp
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <hidet/runtime/common.h>

DLL void *hidet_get_current_torch_stream();
```
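The `DLL` macro is supplied by `hidet/runtime/common.h` and is not shown in this diff. For the symbol to be resolvable by `dlsym`, it needs C linkage and default visibility, along the lines of this sketch (an assumption, not the actual macro body):

```cpp
// Hypothetical approximation of the DLL export macro: C linkage keeps the
// symbol name unmangled, and default visibility exports it from the .so.
#define DLL extern "C" __attribute__((visibility("default")))
```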
Changed file (dev/test requirements; filename not captured in this view):

```diff
@@ -12,7 +12,6 @@ black==22.10.0
 pylint==2.13.9
 
 # for models to test
-torch>=2.3.0
 torchvision
 datasets
 diffusers
```
Changed file (requirements; filename not captured in this view):

```diff
@@ -42,3 +42,6 @@ tomlkit
 
 # for performance measurements
 scipy
+
+# for torch runtime api dependency
+torch>=2.3.0
```
New file (C++ source; filename not captured in this view), hunk @@ -0,0 +1,6 @@:

```cpp
#include <c10/cuda/CUDAStream.h>
#include <hidet/runtime/common.h>

DLL void *hidet_get_current_torch_cuda_stream() {
    return at::cuda::getCurrentCUDAStream().stream();
}
```