Skip to content

Commit efccca4

Browse files
make aoti_torch_empty_strided support creating incontiguous tensor (#15321)
This PR was created by the merge bot to help merge the original PR into the main branch. ghstack PR number: #15228 by @Gasoonjia ^ Please use this as the source of truth for the PR details, comments, and reviews ghstack PR base: https://github.com/pytorch/executorch/tree/gh/gasoonjia/59/base ghstack PR head: https://github.com/pytorch/executorch/tree/gh/gasoonjia/59/head Merge bot PR base: https://github.com/pytorch/executorch/tree/main Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/gasoonjia/59/orig Differential Revision: [D84938258](https://our.internmc.facebook.com/intern/diff/D84938258/) @diff-train-skip-merge Co-authored-by: gasoonjia <[email protected]> Co-authored-by: Gasoonjia <[email protected]>
1 parent f05f103 commit efccca4

File tree

8 files changed

+385
-44
lines changed

8 files changed

+385
-44
lines changed

backends/cuda/CMakeLists.txt

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,39 @@ find_package(CUDAToolkit REQUIRED)
3434
include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake)
3535
find_package_torch()
3636

37+
# CUDA tensor maker for backends that support incontiguous tensors
38+
set(_tensor_maker_sources runtime/tensor/tensor_maker.cpp)
39+
add_library(cuda_tensor_maker STATIC ${_tensor_maker_sources})
40+
target_include_directories(
41+
cuda_tensor_maker
42+
PUBLIC $<BUILD_INTERFACE:${EXECUTORCH_ROOT}> $<INSTALL_INTERFACE:include>
43+
$<BUILD_INTERFACE:${EXECUTORCH_ROOT}/..>
44+
)
45+
target_compile_options(
46+
cuda_tensor_maker
47+
PUBLIC $<$<CXX_COMPILER_ID:MSVC>:/EHsc /GR>
48+
$<$<NOT:$<CXX_COMPILER_ID:MSVC>>:-fexceptions -frtti -fPIC>
49+
)
50+
# Ensure symbols are exported properly
51+
if(APPLE)
52+
target_link_options(cuda_tensor_maker PUBLIC -Wl,-export_dynamic)
53+
else()
54+
target_link_options(
55+
cuda_tensor_maker PUBLIC
56+
$<$<NOT:$<CXX_COMPILER_ID:MSVC>>:-Wl,--export-dynamic>
57+
)
58+
endif()
59+
60+
# Link against ExecuTorch core libraries
61+
target_link_libraries(cuda_tensor_maker PUBLIC executorch ${CMAKE_DL_LIBS})
62+
executorch_target_link_options_shared_lib(cuda_tensor_maker)
63+
64+
install(
65+
TARGETS cuda_tensor_maker
66+
EXPORT ExecuTorchTargets
67+
DESTINATION lib
68+
)
69+
3770
# CUDA-specific AOTI functionality
3871
set(_aoti_cuda_sources
3972
runtime/cuda_backend.cpp
@@ -62,9 +95,10 @@ target_link_options(
6295
aoti_cuda PUBLIC $<$<NOT:$<CXX_COMPILER_ID:MSVC>>:-Wl,--export-dynamic>
6396
)
6497

65-
# Link against CUDA::cudart, common AOTI library, and PyTorch CUDA libraries
98+
# Link against CUDA::cudart, common AOTI library, cuda_tensor_maker, and PyTorch
99+
# CUDA libraries
66100
target_link_libraries(
67-
aoti_cuda PUBLIC aoti_common CUDA::cudart ${CMAKE_DL_LIBS}
101+
aoti_cuda PUBLIC aoti_common cuda_tensor_maker CUDA::cudart ${CMAKE_DL_LIBS}
68102
)
69103
# If you need other CUDA libraries, link them similarly:
70104
# target_link_libraries(aoti_cuda PUBLIC CUDA::cublas CUDA::cufft ...)

backends/cuda/runtime/TARGETS

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,25 @@ runtime.cxx_library(
2727
],
2828
)
2929

30+
runtime.cxx_library(
31+
name = "tensor_maker",
32+
srcs = [
33+
"tensor/tensor_maker.cpp",
34+
],
35+
headers = [
36+
"tensor/tensor_maker.h",
37+
],
38+
# @lint-ignore BUCKLINT: Avoid `link_whole=True` (https://fburl.com/avoid-link-whole)
39+
link_whole = True,
40+
supports_python_dlopen = True,
41+
visibility = ["@EXECUTORCH_CLIENTS"],
42+
deps = [
43+
"//executorch/runtime/core:core",
44+
"//executorch/runtime/core/exec_aten:lib",
45+
"//executorch/runtime/core/exec_aten/util:tensor_util",
46+
],
47+
)
48+
3049
runtime.cxx_library(
3150
name = "runtime_shims",
3251
srcs = [
@@ -52,8 +71,8 @@ runtime.cxx_library(
5271
compiler_flags = ["-Wno-global-constructors"],
5372
visibility = ["@EXECUTORCH_CLIENTS"],
5473
deps = [
74+
":tensor_maker",
5575
"//executorch/backends/aoti:common_shims",
56-
"//executorch/extension/tensor:tensor",
5776
"//executorch/runtime/core:core",
5877
"//executorch/runtime/core/exec_aten:lib",
5978
"//executorch/runtime/platform:platform",

backends/cuda/runtime/shims/memory.cpp

Lines changed: 38 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include <executorch/backends/cuda/runtime/platform/platform.h>
1212
#include <executorch/backends/cuda/runtime/shims/memory.h>
1313
#include <executorch/backends/cuda/runtime/shims/tensor_attribute.h>
14+
#include <executorch/backends/cuda/runtime/tensor/tensor_maker.h>
1415
#include <executorch/backends/cuda/runtime/utils.h>
1516
#include <executorch/runtime/platform/log.h>
1617
#include <cstdint>
@@ -163,9 +164,11 @@ AOTITorchError aoti_torch_create_tensor_from_blob_v2(
163164

164165
// Create ExecutorTorch tensor that wraps the existing memory
165166
// Note: We're NOT copying the data, just wrapping it
166-
auto tensor = executorch::extension::from_blob(
167-
data, // existing memory (don't copy!)
167+
// Using CUDA-specific tensor maker that supports incontiguous tensors
168+
auto tensor = make_tensor(
168169
sizes, // tensor dimensions
170+
data, // existing memory (don't copy!)
171+
{}, // dim_order (empty, will be auto-generated)
169172
strides, // tensor strides (allows different strides)
170173
dtype_to_scalar_type(dtype) // map int32_t dtype to ScalarType
171174
);
@@ -210,10 +213,6 @@ AOTITorchError aoti_torch_empty_strided(
210213

211214
// This requires us to reserve CUDA memory and put it into a ETensor
212215
void* ptr;
213-
int64_t numel = 1;
214-
for (int64_t i = 0; i < ndim; i++) {
215-
numel *= sizes_ptr[i];
216-
}
217216

218217
ET_CHECK_OK_OR_RETURN_ERROR(validate_dtype(dtype));
219218

@@ -223,7 +222,28 @@ AOTITorchError aoti_torch_empty_strided(
223222
InvalidArgument,
224223
"Invalid element size for dtype: %d",
225224
dtype);
226-
int64_t nbytes = numel * element_size;
225+
226+
// Calculate storage size based on strides, matching PyTorch's behavior
227+
// This is critical when sizes and strides don't match the expected contiguous
228+
// layout Reference: PyTorch's computeStorageNbytes in EmptyTensor.cpp
229+
int64_t storage_size = 1; // storage offset (0) + 1
230+
for (int64_t i = 0; i < ndim; i++) {
231+
if (sizes_ptr[i] == 0) {
232+
storage_size = 0;
233+
break;
234+
}
235+
// For each dimension, add stride[i] * (size[i] - 1)
236+
// This gives us the maximum offset in that dimension
237+
int64_t stride_i = (strides_ptr != nullptr) ? strides_ptr[i] : 1;
238+
if (strides_ptr == nullptr) {
239+
// Calculate contiguous stride if not provided
240+
for (int64_t j = i + 1; j < ndim; j++) {
241+
stride_i *= sizes_ptr[j];
242+
}
243+
}
244+
storage_size += stride_i * (sizes_ptr[i] - 1);
245+
}
246+
int64_t nbytes = storage_size * element_size;
227247

228248
if (device_type == static_cast<int32_t>(SupportedDevices::CUDA)) {
229249
ET_CUDA_CHECK_OR_RETURN_ERROR(
@@ -250,16 +270,20 @@ AOTITorchError aoti_torch_empty_strided(
250270
auto strides = convert_strides_to_vector(ndim, sizes_ptr, strides_ptr);
251271

252272
// ETensor creation with dynamic shape support for edge cases
253-
auto tensor = executorch::extension::from_blob(
254-
ptr, sizes, strides, dtype_to_scalar_type(dtype));
273+
// Using CUDA-specific tensor maker that supports incontiguous tensors
274+
auto tensor = make_tensor(
275+
sizes,
276+
ptr,
277+
{}, // dim_order (empty, will be auto-generated)
278+
strides,
279+
dtype_to_scalar_type(dtype));
255280

256281
// Store the tensor so it doesn't get destroyed
257282
tensors.insert(tensor);
258283
*ret_new_tensor = tensor.get();
259284

260285
// This tensor owns the memory it allocated, set reference count to 1
261286
memory_to_n_tensor[ptr] = 1;
262-
263287
return Error::Ok;
264288
}
265289

@@ -630,9 +654,11 @@ AOTITorchError aoti_torch__reinterpret_tensor(
630654

631655
// Create new tensor view that reinterprets the same memory with different
632656
// shape/strides This creates a view, not a copy - the data pointer is shared
633-
std::shared_ptr<Tensor> tensor = executorch::extension::from_blob(
634-
data_ptr, // Reuse the same memory from source tensor
657+
// Using CUDA-specific tensor maker that supports incontiguous tensors
658+
std::shared_ptr<Tensor> tensor = make_tensor(
635659
sizes, // New sizes with explicit SizesType
660+
data_ptr, // Reuse the same memory from source tensor
661+
{}, // dim_order (empty, will be auto-generated)
636662
strides, // New strides with explicit StridesType
637663
dtype_to_scalar_type(dtype) // Convert dtype with explicit type casting
638664
);

backends/cuda/runtime/shims/tensor_attribute.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88

99
#pragma once
1010

11-
#include <executorch/extension/tensor/tensor.h>
1211
#include <executorch/runtime/core/error.h>
12+
#include <executorch/runtime/core/exec_aten/exec_aten.h>
1313
#include <cstdint>
1414

1515
namespace executorch::backends::cuda {

backends/cuda/runtime/shims/tests/test_aoti_torch__reinterpret_tensor.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include <executorch/backends/cuda/runtime/shims/memory.h>
1313
#include <executorch/backends/cuda/runtime/shims/tensor_attribute.h>
1414
#include <executorch/backends/cuda/runtime/utils.h>
15+
#include <executorch/extension/tensor/tensor_ptr_maker.h>
1516
#include <executorch/runtime/core/error.h>
1617
#include <executorch/runtime/platform/platform.h>
1718
#include <gtest/gtest.h>

backends/cuda/runtime/shims/tests/test_aoti_torch_empty_strided.cpp

Lines changed: 107 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -278,30 +278,6 @@ TEST_F(AOTITorchEmptyStridedTest, LargeTensor) {
278278
EXPECT_EQ(tensor->size(2), 50);
279279
}
280280

281-
// Test error handling with memory allocation failures
282-
TEST_F(AOTITorchEmptyStridedTest, MemoryAllocationStress) {
283-
// Try to create a very large tensor that might cause allocation failure
284-
// (This test may pass or fail depending on available memory)
285-
std::vector<int64_t> huge_sizes = {10000, 10000, 100}; // ~38GB for float32
286-
Tensor* tensor;
287-
288-
AOTITorchError error = aoti_torch_empty_strided(
289-
huge_sizes.size(),
290-
huge_sizes.data(),
291-
nullptr,
292-
6, // float32
293-
1, // CUDA device
294-
0, // device index
295-
&tensor);
296-
297-
// Either succeed or fail with memory allocation error
298-
if (error == Error::Ok) {
299-
EXPECT_NE(tensor, nullptr);
300-
} else {
301-
EXPECT_EQ(error, Error::MemoryAllocationFailed);
302-
}
303-
}
304-
305281
// Test aoti_torch_empty_strided with bfloat16 dtype
306282
TEST_F(AOTITorchEmptyStridedTest, BFloat16Tensor) {
307283
// Test creating bfloat16 tensor on CUDA
@@ -509,11 +485,11 @@ TEST_F(AOTITorchEmptyStridedTest, ZeroElementTensor) {
509485
EXPECT_EQ(sizes_ptr[2], 3);
510486
}
511487

512-
// Test different data types (only float32 is currently supported)
488+
// Test different data types (currently we support bf16, fp32 and int32)
513489
TEST_F(AOTITorchEmptyStridedTest, DifferentDataTypes) {
514490
std::vector<int64_t> sizes = {2, 3};
515491

516-
// Test float32 (dtype 6) - currently the only supported type
492+
// Test float32 (dtype 6) - one of the supported types
517493
Tensor* tensor_float32;
518494
AOTITorchError error = aoti_torch_empty_strided(
519495
sizes.size(),
@@ -527,7 +503,7 @@ TEST_F(AOTITorchEmptyStridedTest, DifferentDataTypes) {
527503
EXPECT_EQ(error, Error::Ok);
528504
EXPECT_NE(tensor_float32, nullptr);
529505

530-
// Test unsupported data types should return error
506+
// Test int32 (dtype 3) - one of the supported types
531507
Tensor* tensor_int32;
532508
error = aoti_torch_empty_strided(
533509
sizes.size(),
@@ -538,7 +514,8 @@ TEST_F(AOTITorchEmptyStridedTest, DifferentDataTypes) {
538514
0, // device index
539515
&tensor_int32);
540516

541-
EXPECT_EQ(error, Error::InvalidArgument); // Should fail for unsupported dtype
517+
EXPECT_EQ(error, Error::Ok);
518+
EXPECT_NE(tensor_int32, nullptr);
542519

543520
// Test another unsupported data type
544521
Tensor* tensor_float64;
@@ -586,3 +563,105 @@ TEST_F(AOTITorchEmptyStridedTest, MultiDimensionalTensors) {
586563
EXPECT_EQ(tensor_5d->size(3), 4);
587564
EXPECT_EQ(tensor_5d->size(4), 5);
588565
}
566+
567+
// Test incontiguous tensor creation - transpose-like layout
568+
TEST_F(AOTITorchEmptyStridedTest, IncontiguousTransposeLayout) {
569+
// Create a tensor with transpose-like strides (column-major)
570+
// For a 3x4 tensor in column-major order, strides should be [1, 3]
571+
// This means each row step is 1, and each column step is 3
572+
std::vector<int64_t> sizes = {3, 4};
573+
std::vector<int64_t> strides = {1, 3}; // Column-major (incontiguous)
574+
575+
Tensor* tensor;
576+
AOTITorchError error = aoti_torch_empty_strided(
577+
sizes.size(),
578+
sizes.data(),
579+
strides.data(),
580+
static_cast<int32_t>(SupportedDTypes::FLOAT32),
581+
static_cast<int32_t>(SupportedDevices::CUDA),
582+
0, // device index
583+
&tensor);
584+
585+
EXPECT_EQ(error, Error::Ok);
586+
EXPECT_NE(tensor, nullptr);
587+
588+
// Verify tensor properties
589+
EXPECT_EQ(tensor->dim(), 2);
590+
EXPECT_EQ(tensor->size(0), 3);
591+
EXPECT_EQ(tensor->size(1), 4);
592+
593+
// Verify the strides are what we specified
594+
int64_t* strides_ptr;
595+
EXPECT_EQ(aoti_torch_get_strides(tensor, &strides_ptr), Error::Ok);
596+
EXPECT_EQ(strides_ptr[0], 1); // Column-major stride for dimension 0
597+
EXPECT_EQ(strides_ptr[1], 3); // Column-major stride for dimension 1
598+
599+
// Verify that memory was allocated correctly for incontiguous layout
600+
// Storage size should be: stride[0] * (size[0] - 1) + stride[1] * (size[1] -
601+
// 1) + 1 = 1 * (3 - 1) + 3 * (4 - 1) + 1 = 1 * 2 + 3 * 3 + 1 = 2 + 9 + 1 = 12
602+
// elements Total bytes = 12 * 4 = 48 bytes (for float32)
603+
EXPECT_EQ(tensor->numel(), 12); // numel is still 3*4=12 for logical shape
604+
605+
// The tensor should be accessible and writable
606+
void* data_ptr = tensor->mutable_data_ptr();
607+
EXPECT_NE(data_ptr, nullptr);
608+
609+
// Verify we can use CUDA to write to the memory
610+
std::vector<float> test_data(12, 1.0f);
611+
cudaError_t cuda_err = cudaMemcpy(
612+
data_ptr, test_data.data(), 12 * sizeof(float), cudaMemcpyHostToDevice);
613+
EXPECT_EQ(cuda_err, cudaSuccess);
614+
}
615+
616+
// Test incontiguous tensor creation - expanded/broadcasted stride pattern
617+
TEST_F(AOTITorchEmptyStridedTest, IncontiguousExpandedStrides) {
618+
// Create a tensor with expanded strides (simulating broadcasting)
619+
// A 2x3x4 tensor where the first dimension has stride 0 (expanded)
620+
// This creates a tensor where the first dimension is "broadcasted"
621+
std::vector<int64_t> sizes = {2, 3, 4};
622+
std::vector<int64_t> strides = {0, 4, 1}; // First dimension has stride 0
623+
624+
Tensor* tensor;
625+
AOTITorchError error = aoti_torch_empty_strided(
626+
sizes.size(),
627+
sizes.data(),
628+
strides.data(),
629+
static_cast<int32_t>(SupportedDTypes::FLOAT32),
630+
static_cast<int32_t>(SupportedDevices::CUDA),
631+
0, // device index
632+
&tensor);
633+
634+
EXPECT_EQ(error, Error::Ok);
635+
EXPECT_NE(tensor, nullptr);
636+
637+
// Verify tensor properties
638+
EXPECT_EQ(tensor->dim(), 3);
639+
EXPECT_EQ(tensor->size(0), 2);
640+
EXPECT_EQ(tensor->size(1), 3);
641+
EXPECT_EQ(tensor->size(2), 4);
642+
643+
// Verify the strides are what we specified
644+
int64_t* strides_ptr;
645+
EXPECT_EQ(aoti_torch_get_strides(tensor, &strides_ptr), Error::Ok);
646+
EXPECT_EQ(strides_ptr[0], 0); // Expanded dimension stride
647+
EXPECT_EQ(strides_ptr[1], 4);
648+
EXPECT_EQ(strides_ptr[2], 1);
649+
650+
// Verify that memory was allocated correctly for this incontiguous layout
651+
// Storage size should be: stride[0] * (size[0] - 1) + stride[1] * (size[1] -
652+
// 1) + stride[2] * (size[2] - 1) + 1 = 0 * (2 - 1) + 4 * (3 - 1) + 1 * (4 -
653+
// 1) + 1 = 0 + 8 + 3 + 1 = 12 elements Note: numel() returns logical number
654+
// of elements (2*3*4=24), not storage size
655+
EXPECT_EQ(tensor->numel(), 24); // Logical numel is 2*3*4=24
656+
657+
// The tensor should be accessible and writable
658+
void* data_ptr = tensor->mutable_data_ptr();
659+
EXPECT_NE(data_ptr, nullptr);
660+
661+
// Verify we can use CUDA to write to the allocated memory
662+
// We only need to allocate 12 elements (storage size), not 24
663+
std::vector<float> test_data(12, 2.0f);
664+
cudaError_t cuda_err = cudaMemcpy(
665+
data_ptr, test_data.data(), 12 * sizeof(float), cudaMemcpyHostToDevice);
666+
EXPECT_EQ(cuda_err, cudaSuccess);
667+
}

0 commit comments

Comments
 (0)