Commit df7f2c3

Merge branch 'main' into export-D79286076
2 parents: 499b438 + e852066

File tree

5 files changed (+956, -3 lines)

backends/cadence/aot/replace_ops.py

Lines changed: 3 additions & 3 deletions
@@ -2257,6 +2257,9 @@ class CommonReplacePasses:
         ReplaceRepeatWithCatPass,
         ReplaceFullLikeWithFullPass,
         ReplaceAtenConvolutionWithCadenceConvolutionPass,
+        ReplacePT2QuantWithCadenceQuantPass,
+        ReplacePT2DequantWithCadenceDequantPass,
+        ReplacePowWithMulPass,
     ]
 
 
@@ -2302,13 +2305,10 @@ class CadenceReplaceOpsInGraph:
         ReplaceScalarTensorWithFullPass,
         ReplaceInfArgInFullWithValuePass,
         ReplaceLogicalNotBooleanWhereWithWherePass,
-        ReplacePT2QuantWithCadenceQuantPass,
-        ReplacePT2DequantWithCadenceDequantPass,
         ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass,
         ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass,
         ReplaceAtenAvgPoolWithCadenceAvgPoolPass,
         ReplaceWhereWithFullArgsWithWhereScalar,
         ReplaceAtenApproxGeluWithApproxGeluPass,
-        ReplacePowWithMulPass,
         ReplaceMulTensorWithMulAndFullOpsPass,
     ]

backends/cuda/runtime/shims/memory.cpp

Lines changed: 117 additions & 0 deletions
@@ -25,6 +25,8 @@ namespace cuda {
 
 using executorch::aten::SizesType;
 using executorch::aten::StridesType;
+using executorch::backends::aoti::aoti_torch_get_device_index;
+using executorch::backends::aoti::aoti_torch_get_dtype;
 using executorch::backends::aoti::dtype_to_element_size;
 using executorch::backends::aoti::dtype_to_scalar_type;
 using executorch::backends::aoti::validate_storage_offset;
@@ -310,6 +312,121 @@ AOTITorchError aoti_torch_delete_tensor_object(Tensor* tensor) {
   return Error::Internal;
 }
 
+AOTITorchError aoti_torch__reinterpret_tensor(
+    Tensor* self,
+    int64_t ndim,
+    const int64_t* sizes_ptr,
+    const int64_t* strides_ptr,
+    int64_t storage_offset,
+    Tensor** ret_new_tensor) {
+  // Validate input parameters first
+  if (self == nullptr) {
+    ET_LOG(Error, "aoti_torch__reinterpret_tensor failed: self tensor is null");
+    return Error::InvalidArgument;
+  }
+
+  if (sizes_ptr == nullptr && ndim > 0) {
+    ET_LOG(Error, "aoti_torch__reinterpret_tensor failed: sizes_ptr is null");
+    return Error::InvalidArgument;
+  }
+
+  if (ret_new_tensor == nullptr) {
+    ET_LOG(
+        Error, "aoti_torch__reinterpret_tensor failed: ret_new_tensor is null");
+    return Error::InvalidArgument;
+  }
+
+  // Check if storage_offset is not 0 - return error if not
+  AOTITorchError storage_offset_error = validate_storage_offset(storage_offset);
+  if (storage_offset_error != Error::Ok) {
+    return storage_offset_error;
+  }
+
+  // Get the device info from the source tensor to perform device_index
+  // validation
+  int32_t device_type = 0;
+  int32_t device_index = 0;
+  AOTITorchError device_error = aoti_torch_get_device_type(self, &device_type);
+  if (device_error != Error::Ok) {
+    return device_error;
+  }
+
+  device_error = aoti_torch_get_device_index(self, &device_index);
+  if (device_error != Error::Ok) {
+    return device_error;
+  }
+
+  // Ensure device_index is always 0
+  if (device_index != 0) {
+    ET_LOG(Error, "device_index must be 0, got: %d", device_index);
+    return Error::InvalidArgument;
+  }
+
+  // Get the dtype from the source tensor
+  int32_t dtype = 0;
+  AOTITorchError dtype_error = aoti_torch_get_dtype(self, &dtype);
+  if (dtype_error != Error::Ok) {
+    return dtype_error;
+  }
+
+  // Validate dtype using SupportedDTypes
+  dtype_error = validate_dtype(dtype);
+  if (dtype_error != Error::Ok) {
+    return dtype_error;
+  }
+
+  // Get the original data pointer from the source tensor
+  void* data_ptr = self->mutable_data_ptr();
+  if (data_ptr == nullptr) {
+    ET_LOG(Error, "Source tensor has null data pointer");
+    return Error::InvalidArgument;
+  }
+
+  // Check if the given memory is in the map, if not return error
+  auto memory_it = memory_to_n_tensor.find(data_ptr);
+  if (memory_it == memory_to_n_tensor.end()) {
+    ET_LOG(
+        Error,
+        "Memory address %p is not being tracked by reference counting system",
+        data_ptr);
+    return Error::InvalidArgument;
+  }
+
+  // Convert sizes using utility function from utils.h
+  std::vector<SizesType> sizes = convert_sizes_to_vector(ndim, sizes_ptr);
+
+  // Convert strides using utility function from utils.h
+  std::vector<StridesType> strides =
+      convert_strides_to_vector(ndim, sizes_ptr, strides_ptr);
+
+  // Create new tensor view that reinterprets the same memory with different
+  // shape/strides. This creates a view, not a copy - the data pointer is shared
+  std::shared_ptr<Tensor> tensor = executorch::extension::from_blob(
+      data_ptr, // Reuse the same memory from source tensor
+      sizes, // New sizes with explicit SizesType
+      strides, // New strides with explicit StridesType
+      dtype_to_scalar_type(dtype) // Convert dtype with explicit type casting
+  );
+
+  if (!tensor) {
+    ET_LOG(Error, "Failed to create reinterpreted tensor view");
+    return Error::InvalidArgument;
+  }
+
+  // Store the tensor so it doesn't get destroyed
+  tensors.insert(tensor);
+
+  *ret_new_tensor = tensor.get();
+
+  // Increment the reference count for this memory address only if it is owned
+  // by tensor
+  memory_to_n_tensor[data_ptr] = memory_to_n_tensor[data_ptr] == NOT_OWN
+      ? NOT_OWN
+      : memory_to_n_tensor[data_ptr] + 1;
+
+  return Error::Ok;
+}
+
 } // extern "C"
 
 } // namespace cuda
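
The trickiest part of the implementation above is the final map update: memory the shim does not own (marked NOT_OWN) must never be freed, so its entry is left untouched, while owned memory gains one more referencing tensor. The following is a standalone sketch of that bookkeeping only; the map and sentinel are modeled locally here, and the real NOT_OWN constant and memory_to_n_tensor map are defined elsewhere in memory.cpp and are not shown in this diff.

#include <cstdint>
#include <unordered_map>

// Stand-in for the shim's NOT_OWN sentinel (assumption: the real value differs).
constexpr int64_t kNotOwn = -1;

// Applies the same update rule as the ternary in aoti_torch__reinterpret_tensor:
// un-owned memory stays NOT_OWN, owned memory gains one more referencing tensor.
void on_new_view(std::unordered_map<void*, int64_t>& ref_counts, void* data_ptr) {
  int64_t& count = ref_counts.at(data_ptr); // presence was validated earlier
  count = (count == kNotOwn) ? kNotOwn : count + 1;
}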

backends/cuda/runtime/shims/memory.h

Lines changed: 25 additions & 0 deletions
@@ -91,6 +91,31 @@ AOTITorchError aoti_torch_empty_strided(
  */
 AOTITorchError aoti_torch_delete_tensor_object(Tensor* tensor);
 
+/**
+ * Creates a tensor view that reinterprets the same underlying memory with
+ * different shape and strides without copying data.
+ *
+ * Note that the new tensor will not have the ownership of the underlying
+ * memory.
+ *
+ * @param self Input tensor whose memory will be reinterpreted
+ * @param ndim Number of dimensions for the new tensor view
+ * @param sizes_ptr Array of sizes for each dimension
+ * @param strides_ptr Array of strides for each dimension (or nullptr for
+ *                    contiguous)
+ * @param storage_offset Storage offset (must be 0)
+ * @param ret_new_tensor Output pointer to store the new tensor view
+ *
+ * @return Error::Ok on success, appropriate error code on failure
+ */
+AOTITorchError aoti_torch__reinterpret_tensor(
+    Tensor* self,
+    int64_t ndim,
+    const int64_t* sizes_ptr,
+    const int64_t* strides_ptr,
+    int64_t storage_offset,
+    Tensor** ret_new_tensor);
+
 // Function to clear all tensors from internal storage
 void clear_all_tensors();
 } // extern "C"
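
A minimal caller-side sketch of the declaration above (not code from this commit): it views a contiguous 6-element tensor as 2x3 without copying. The namespace and the visibility of Tensor, AOTITorchError, and Error are assumptions inferred from how memory.cpp uses them.

#include <executorch/backends/cuda/runtime/shims/memory.h>

using executorch::runtime::Error;            // assumed, as used in memory.cpp
using namespace executorch::backends::cuda;  // assumed namespace of the shim

// Reinterpret a contiguous 6-element tensor as a 2x3 view sharing the same storage.
AOTITorchError view_as_2x3(Tensor* base, Tensor** out_view) {
  const int64_t sizes[2] = {2, 3};
  const int64_t strides[2] = {3, 1}; // row-major, contiguous
  // storage_offset must be 0; the shim rejects non-zero offsets.
  AOTITorchError err = aoti_torch__reinterpret_tensor(
      base, /*ndim=*/2, sizes, strides, /*storage_offset=*/0, out_view);
  if (err != Error::Ok) {
    return err;
  }
  // *out_view aliases base's memory; release the view with
  // aoti_torch_delete_tensor_object once it is no longer needed.
  return Error::Ok;
}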

backends/cuda/runtime/shims/tests/targets.bzl

Lines changed: 1 addition & 0 deletions
@@ -30,3 +30,4 @@ def define_common_targets():
     cuda_shim_cpp_unittest("aoti_torch_empty_strided")
     cuda_shim_cpp_unittest("aoti_torch_delete_tensor_object")
     cuda_shim_cpp_unittest("aoti_torch_create_tensor_from_blob_v2")
+    cuda_shim_cpp_unittest("aoti_torch__reinterpret_tensor")
