
Commit bdc8004

Avoid copying output from GPU to CPU
1 parent df626bd commit bdc8004

7 files changed: +164 −19 lines
Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-0123293118efb08ac4ffc4fefe9d330201465c93
+de4f3c4978b4d36cc0bb8f87c6877a4a040d7ae7

.ci/scripts/test_huggingface_optimum_model.py

Lines changed: 32 additions & 2 deletions
@@ -170,6 +170,35 @@ def test_text_generation(model_id, model_dir, recipe, *, quantize=True, run_only
     assert check_causal_lm_output_quality(model_id, generated_tokens) is True


+def get_tokenizer_path(model_dir: str, saved_files: tuple) -> str:
+    """
+    Determine the tokenizer path based on files saved by tokenizer.save_pretrained().
+
+    Args:
+        model_dir: The directory where tokenizer files were saved
+        saved_files: Tuple of file paths returned by tokenizer.save_pretrained()
+
+    Returns:
+        The path to use for loading the tokenizer (either a specific file or directory)
+
+    Raises:
+        ValueError: If no supported tokenizer file format is found
+    """
+    saved_filenames = {Path(f).name for f in saved_files}
+
+    if "tokenizer.model" in saved_filenames:
+        return f"{model_dir}/tokenizer.model"
+
+    if "tokenizer.json" in saved_filenames:
+        return model_dir
+
+    # No supported tokenizer format found
+    raise ValueError(
+        f"Unsupported tokenizer format. Expected 'tokenizer.model' (SentencePiece) "
+        f"or 'tokenizer.json' (HuggingFace) but found: {saved_filenames}"
+    )
+
+
 def test_llm_with_image_modality(
     model_id, model_dir, recipe, *, quantize=True, run_only=False
 ):
@@ -196,7 +225,8 @@ def test_llm_with_image_modality(
     cli_export(command, model_dir)

     tokenizer = AutoTokenizer.from_pretrained(model_id)
-    tokenizer.save_pretrained(model_dir)
+    saved_files = tokenizer.save_pretrained(model_dir)
+    tokenizer_path = get_tokenizer_path(model_dir, saved_files)

     # input
     processor = AutoProcessor.from_pretrained(model_id)
@@ -232,7 +262,7 @@ def test_llm_with_image_modality(

     from executorch.extension.llm.runner import GenerationConfig, MultimodalRunner

-    runner = MultimodalRunner(f"{model_dir}/model.pte", f"{model_dir}/tokenizer.model")
+    runner = MultimodalRunner(f"{model_dir}/model.pte", tokenizer_path)
     generated_text = runner.generate_text_hf(
         inputs,
         GenerationConfig(max_new_tokens=128, temperature=0, echo=False),

backends/aoti/aoti_delegate_handle.h

Lines changed: 2 additions & 0 deletions
@@ -10,6 +10,7 @@

 #include <executorch/runtime/core/error.h>
 #include <executorch/runtime/core/evalue.h>
+#include <string>

 namespace executorch {
 namespace backends {
@@ -85,6 +86,7 @@ struct AOTIDelegateHandle {
   AOTInductorModelContainerHandle container_handle;
   void* cuda_stream; // cudaStream_t stored as void* to avoid CUDA header
                      // dependency
+  std::string method_name;

   // Function pointers specific to this handle's shared library
   AOTInductorModelContainerCreateWithDeviceFunc create_with_device;

backends/cuda/runtime/cuda_backend.cpp

Lines changed: 96 additions & 14 deletions
@@ -13,8 +13,10 @@
 #include <executorch/runtime/core/exec_aten/util/tensor_util.h>
 #include <cstdio>

+#include <array>
 #include <filesystem>
 #include <fstream>
+#include <mutex>
 #include <string>
 #include <vector>

@@ -24,6 +26,7 @@
 #include <executorch/backends/cuda/runtime/platform/platform.h>
 #include <executorch/backends/cuda/runtime/shims/memory.h>
 #include <executorch/backends/cuda/runtime/utils.h>
+#include <executorch/runtime/backend/options.h>

 namespace executorch::backends::cuda {

@@ -35,20 +38,54 @@ using executorch::runtime::ArrayRef;
 using executorch::runtime::Backend;
 using executorch::runtime::BackendExecutionContext;
 using executorch::runtime::BackendInitContext;
+using executorch::runtime::BackendOption;
+using executorch::runtime::BackendOptionContext;
 using executorch::runtime::CompileSpec;
 using executorch::runtime::DelegateHandle;
 using executorch::runtime::Error;
 using executorch::runtime::EValue;
 using executorch::runtime::FreeableBuffer;
+using executorch::runtime::kMaxOptionValueLength;
 using executorch::runtime::MemoryAllocator;
 using executorch::runtime::NamedDataMap;
 using executorch::runtime::Result;
 using executorch::runtime::Span;
 using executorch::runtime::etensor::Tensor;

+namespace {
+constexpr char kSkipCopyOutputToCpuForMethod[] =
+    "skip_copy_output_to_cpu_for_method";
+}
+
 class ET_EXPERIMENTAL CudaBackend final
     : public ::executorch::runtime::BackendInterface {
  private:
+  void set_skip_copy_method(
+      const std::array<char, kMaxOptionValueLength>& raw) {
+    std::lock_guard<std::mutex> guard(skip_copy_method_mutex_);
+    skip_copy_method_ = std::string(raw.data());
+  }
+
+  std::array<char, kMaxOptionValueLength> get_skip_copy_method_as_option()
+      const {
+    std::array<char, kMaxOptionValueLength> out{};
+    std::string value;
+    {
+      std::lock_guard<std::mutex> guard(skip_copy_method_mutex_);
+      value = skip_copy_method_;
+    }
+    std::snprintf(out.data(), out.size(), "%s", value.c_str());
+    return out;
+  }
+
+  bool should_skip_copy_for_method(const std::string& method_name) const {
+    if (method_name.empty()) {
+      return false;
+    }
+    std::lock_guard<std::mutex> guard(skip_copy_method_mutex_);
+    return method_name == skip_copy_method_;
+  }
+
   Error load_function_pointers_into_handle(
       void* so_handle,
       AOTIDelegateHandle* handle) const {
@@ -91,6 +128,38 @@ class ET_EXPERIMENTAL CudaBackend final
     return 1;
   }

+  Error set_option(
+      ET_UNUSED BackendOptionContext& context,
+      const executorch::runtime::Span<BackendOption>& backend_options)
+      override {
+    for (const auto& option : backend_options) {
+      if (std::strcmp(option.key, kSkipCopyOutputToCpuForMethod) == 0) {
+        if (auto* val = std::get_if<std::array<char, kMaxOptionValueLength>>(
+                &option.value)) {
+          set_skip_copy_method(*val);
+        } else {
+          ET_LOG(
+              Error,
+              "Option %s must be a method name string.",
+              kSkipCopyOutputToCpuForMethod);
+          return Error::InvalidArgument;
+        }
+      }
+    }
+    return Error::Ok;
+  }
+
+  Error get_option(
+      ET_UNUSED BackendOptionContext& context,
+      executorch::runtime::Span<BackendOption>& backend_options) override {
+    for (auto& option : backend_options) {
+      if (std::strcmp(option.key, kSkipCopyOutputToCpuForMethod) == 0) {
+        option.value = get_skip_copy_method_as_option();
+      }
+    }
+    return Error::Ok;
+  }
+
   // Once per loaded binary blob
   Result<DelegateHandle*> init(
       BackendInitContext& context,
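
Note (not part of the diff): the set_option/get_option plumbing above only takes effect once application code routes an option to the registered backend. A minimal caller-side sketch, mirroring the extension/asr/runner/runner.cpp change later in this commit; the helper name enable_skip_copy is a placeholder, and this assumes the backend is registered under the name "CudaBackend":

#include <executorch/runtime/backend/options.h>

// Ask the CUDA backend to leave the named method's outputs on the GPU
// (dispatches to CudaBackend::set_option() above).
executorch::runtime::Error enable_skip_copy(const char* method_name) {
  executorch::runtime::BackendOptions<1> opts;
  ET_CHECK_OK_OR_RETURN_ERROR(opts.set_option(
      "skip_copy_output_to_cpu_for_method", method_name));
  return executorch::runtime::set_option("CudaBackend", opts.view());
}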
@@ -159,6 +228,7 @@ class ET_EXPERIMENTAL CudaBackend final
     AOTIDelegateHandle* handle = new AOTIDelegateHandle();
     handle->so_handle = lib_handle;
     handle->so_path = so_path.string();
+    handle->method_name = method_name;

     // Load function pointers specific to this handle's shared library
     ET_CHECK_OK_OR_RETURN_ERROR(
@@ -224,7 +294,7 @@ class ET_EXPERIMENTAL CudaBackend final

     // Process input tensors: ExecuTorch provides CPU tensors, create GPU
     // copies
-    for (int i = 0; i < n_inputs; i++) {
+    for (size_t i = 0; i < n_inputs; i++) {
       // Get tensor dimensions and properties from ExecuTorch CPU tensor
       auto cpu_tensor = &(args[i]->toTensor());
       auto sizes = cpu_tensor->sizes();
@@ -260,7 +330,7 @@ class ET_EXPERIMENTAL CudaBackend final
     }
     // Process output tensors: create GPU counterparts for ExecuTorch CPU
     // tensors
-    for (int i = 0; i < n_outputs; i++) {
+    for (size_t i = 0; i < n_outputs; i++) {
       // Get output tensor dimensions from ExecuTorch CPU tensor
       auto cpu_output_tensor = &(args[i + n_inputs]->toTensor());
       auto sizes = cpu_output_tensor->sizes();
@@ -303,18 +373,26 @@ class ET_EXPERIMENTAL CudaBackend final
         "AOTInductorModelContainerRun failed with error code %d",
         error);

-    // Copy GPU output results back to CPU output tensors
-    for (int i = 0; i < n_outputs; i++) {
-      auto cpu_output_tensor = &(args[i + n_inputs]->toTensor());
-      // For DYNAMIC_BOUND tensors we try to resize
-      ET_CHECK_OK_OR_RETURN_ERROR(
-          resize_tensor(*cpu_output_tensor, gpu_outputs[i]->sizes()),
-          "Error resizing tensor at output index %d",
-          i);
-      ET_CHECK_OK_OR_RETURN_ERROR(
-          aoti_torch_copy_(cpu_output_tensor, gpu_outputs[i], 0),
-          "Failed to copy GPU output %d back to CPU",
-          i);
+    const bool copy_outputs = !should_skip_copy_for_method(handle->method_name);
+
+    if (copy_outputs) {
+      // Copy GPU output results back to CPU output tensors
+      for (size_t i = 0; i < n_outputs; i++) {
+        auto cpu_output_tensor = &(args[i + n_inputs]->toTensor());
+        // For DYNAMIC_BOUND tensors we try to resize
+        ET_CHECK_OK_OR_RETURN_ERROR(
+            resize_tensor(*cpu_output_tensor, gpu_outputs[i]->sizes()),
+            "Error resizing tensor at output index %d",
+            i);
+        ET_CHECK_OK_OR_RETURN_ERROR(
+            aoti_torch_copy_(cpu_output_tensor, gpu_outputs[i], 0),
+            "Failed to copy GPU output %d back to CPU",
+            i);
+      }
+    } else {
+      for (size_t i = 0; i < n_outputs; i++) {
+        args[i + n_inputs]->toTensor() = *gpu_outputs[i];
+      }
     }

     return Error::Ok;
@@ -365,6 +443,10 @@ class ET_EXPERIMENTAL CudaBackend final
     delete handle;
     clear_all_tensors();
   }
+
+ private:
+  mutable std::mutex skip_copy_method_mutex_;
+  std::string skip_copy_method_;
 };

 } // namespace executorch::backends::cuda
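
Note (not part of the diff): since should_skip_copy_for_method() requires an exact, non-empty match against the stored name, clearing the option restores the default copy-back behavior for every method. A hedged sketch under the same assumptions as the caller-side sketch above (the helper name reset_skip_copy is a placeholder):

#include <executorch/runtime/backend/options.h>

// Clear the stored method name; an empty string matches no executing method,
// so every method's outputs are copied back to CPU again.
executorch::runtime::Error reset_skip_copy() {
  executorch::runtime::BackendOptions<1> opts;
  ET_CHECK_OK_OR_RETURN_ERROR(
      opts.set_option("skip_copy_output_to_cpu_for_method", ""));
  return executorch::runtime::set_option("CudaBackend", opts.view());
}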

extension/asr/runner/CMakeLists.txt

Lines changed: 16 additions & 0 deletions
@@ -35,6 +35,22 @@ set_target_properties(
   extension_asr_runner PROPERTIES POSITION_INDEPENDENT_CODE ON
 )

+# If the project is configured to build with CUDA support, try to find a CUDA
+# runtime (prefer the CUDAToolkit package). If found, expose a compile-time
+# macro so sources can conditionally compile CUDA-aware code.
+if(EXECUTORCH_BUILD_CUDA)
+  find_package(CUDAToolkit QUIET)
+  if(CUDAToolkit_FOUND)
+    target_compile_definitions(extension_asr_runner PUBLIC CUDA_AVAILABLE)
+    message(STATUS "CUDAToolkit found; defining CUDA_AVAILABLE for ASR runner")
+  else()
+    message(
+      STATUS
+        "CUDA requested (EXECUTORCH_BUILD_CUDA=ON) but no CUDA runtime found"
+    )
+  endif()
+endif()
+
 install(
   TARGETS extension_asr_runner
   EXPORT ExecuTorchTargets

extension/asr/runner/runner.cpp

Lines changed: 16 additions & 1 deletion
@@ -107,7 +107,22 @@ Error AsrRunner::load() {

   ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method(kDecoderMethodName));
   decoder_method_loaded_ = true;
-
+#ifdef CUDA_AVAILABLE
+  executorch::runtime::BackendOptions<1> backend_options;
+  // For decoder still copy output from GPU to CPU for sampling.
+  // TODO: change sampler to use a CUDA kernel to sample and then skip copying
+  // decoder output as well
+  ET_CHECK_OK_OR_RETURN_ERROR(backend_options.set_option(
+      "skip_copy_output_to_cpu_for_method", kEncoderMethodName));
+  const auto opt_err =
+      executorch::runtime::set_option("CudaBackend", backend_options.view());
+  if (opt_err != ::executorch::runtime::Error::Ok) {
+    ET_LOG(
+        Warning,
+        "Failed to set CUDA backend options: %d",
+        static_cast<int>(opt_err));
+  }
+#endif
   ET_CHECK_OK_OR_RETURN_ERROR(load_tokenizer());
   auto eos_ids = get_eos_ids(tokenizer_.get(), module_.get());
   if (!eos_ids.empty()) {
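
Note (not part of the diff): the TODO above points at on-device sampling, which would let the decoder's output copy be skipped as well. Purely illustrative and not this commit's code: greedy decoding reduces to an argmax over the device-resident logits, e.g. with Thrust; the function name, the d_logits pointer, and vocab_size are hypothetical:

#include <cstddef>
#include <cstdint>

#include <thrust/execution_policy.h>
#include <thrust/extrema.h>

// Greedy "sampling" without leaving the GPU: find the index of the largest
// logit directly over device memory (d_logits points to GPU memory).
int64_t argmax_logits_on_device(const float* d_logits, size_t vocab_size) {
  const float* max_it =
      thrust::max_element(thrust::device, d_logits, d_logits + vocab_size);
  return static_cast<int64_t>(max_it - d_logits);
}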

requirements-examples.txt

Lines changed: 1 addition & 1 deletion
@@ -4,4 +4,4 @@ datasets == 3.6.0 # 4.0.0 deprecates trust_remote_code and load scripts. For now
 timm == 1.0.7
 torchsr == 1.0.4
 torchtune >= 0.6.1
-transformers == 4.56.1
+transformers == 5.0.0rc1
