@@ -25,6 +25,8 @@ namespace cuda {
using executorch::aten::SizesType;
using executorch::aten::StridesType;
+using executorch::backends::aoti::aoti_torch_get_device_index;
+using executorch::backends::aoti::aoti_torch_get_dtype;
using executorch::backends::aoti::dtype_to_element_size;
using executorch::backends::aoti::dtype_to_scalar_type;
using executorch::backends::aoti::validate_storage_offset;
@@ -310,6 +312,121 @@ AOTITorchError aoti_torch_delete_tensor_object(Tensor* tensor) {
return Error::Internal;
}
+AOTITorchError aoti_torch__reinterpret_tensor(
+    Tensor* self,
+    int64_t ndim,
+    const int64_t* sizes_ptr,
+    const int64_t* strides_ptr,
+    int64_t storage_offset,
+    Tensor** ret_new_tensor) {
+  // Validate input parameters first
+  if (self == nullptr) {
+    ET_LOG(Error, "aoti_torch__reinterpret_tensor failed: self tensor is null");
+    return Error::InvalidArgument;
+  }
+
+  if (sizes_ptr == nullptr && ndim > 0) {
+    ET_LOG(Error, "aoti_torch__reinterpret_tensor failed: sizes_ptr is null");
+    return Error::InvalidArgument;
+  }
+
+  if (ret_new_tensor == nullptr) {
+    ET_LOG(
+        Error, "aoti_torch__reinterpret_tensor failed: ret_new_tensor is null");
+    return Error::InvalidArgument;
+  }
+
+  // Return an error if storage_offset is non-zero; offsets are not supported
+  AOTITorchError storage_offset_error = validate_storage_offset(storage_offset);
+  if (storage_offset_error != Error::Ok) {
+    return storage_offset_error;
+  }
+
+  // Get the device info from the source tensor to perform device_index
+  // validation
+  int32_t device_type = 0;
+  int32_t device_index = 0;
+  AOTITorchError device_error = aoti_torch_get_device_type(self, &device_type);
+  if (device_error != Error::Ok) {
+    return device_error;
+  }
+
+  device_error = aoti_torch_get_device_index(self, &device_index);
+  if (device_error != Error::Ok) {
+    return device_error;
+  }
+
+  // Ensure device_index is always 0
+  if (device_index != 0) {
+    ET_LOG(Error, "device_index must be 0, got: %d", device_index);
+    return Error::InvalidArgument;
+  }
+
+  // Get the dtype from the source tensor
+  int32_t dtype = 0;
+  AOTITorchError dtype_error = aoti_torch_get_dtype(self, &dtype);
+  if (dtype_error != Error::Ok) {
+    return dtype_error;
+  }
+
+  // Validate dtype using SupportedDTypes
+  dtype_error = validate_dtype(dtype);
+  if (dtype_error != Error::Ok) {
+    return dtype_error;
+  }
+
+  // Get the original data pointer from the source tensor
+  void* data_ptr = self->mutable_data_ptr();
+  if (data_ptr == nullptr) {
+    ET_LOG(Error, "Source tensor has null data pointer");
+    return Error::InvalidArgument;
+  }
+
+  // The memory must already be tracked in the map; return an error if not
+  auto memory_it = memory_to_n_tensor.find(data_ptr);
+  if (memory_it == memory_to_n_tensor.end()) {
+    ET_LOG(
+        Error,
+        "Memory address %p is not being tracked by reference counting system",
+        data_ptr);
+    return Error::InvalidArgument;
+  }
+
+  // Convert sizes using the utility function from utils.h
+  std::vector<SizesType> sizes = convert_sizes_to_vector(ndim, sizes_ptr);
+
+  // Convert strides using the utility function from utils.h
+  std::vector<StridesType> strides =
+      convert_strides_to_vector(ndim, sizes_ptr, strides_ptr);
+
+  // Create a new tensor view that reinterprets the same memory with different
+  // shape/strides. This creates a view, not a copy: the data pointer is shared
+  std::shared_ptr<Tensor> tensor = executorch::extension::from_blob(
+      data_ptr, // Reuse the same memory from the source tensor
+      sizes, // New sizes with explicit SizesType
+      strides, // New strides with explicit StridesType
+      dtype_to_scalar_type(dtype) // Convert dtype with explicit type casting
+  );
+
+  if (!tensor) {
+    ET_LOG(Error, "Failed to create reinterpreted tensor view");
+    return Error::InvalidArgument;
+  }
+
+  // Store the tensor so it doesn't get destroyed
+  tensors.insert(tensor);
+
+  *ret_new_tensor = tensor.get();
+
+  // Increment the reference count for this memory address only if the memory
+  // is owned by tensors; NOT_OWN (externally managed) stays NOT_OWN
+  memory_to_n_tensor[data_ptr] = memory_to_n_tensor[data_ptr] == NOT_OWN
+      ? NOT_OWN
+      : memory_to_n_tensor[data_ptr] + 1;
+
+  return Error::Ok;
+}
+
} // extern "C"
} // namespace cuda
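Usage note (reviewer sketch, not part of this diff): the new entry point could be exercised roughly as below. `make_reinterpreted_view` is a hypothetical helper name; the sketch assumes `src` was allocated through this shim layer, so its data pointer is already tracked in `memory_to_n_tensor`, and that the resulting view is later released with `aoti_torch_delete_tensor_object`, the cleanup path visible in the second hunk.

    // Hypothetical sketch: view a contiguous 2x6 tensor as 3x4 without copying.
    Tensor* make_reinterpreted_view(Tensor* src) {
      const int64_t sizes[2] = {3, 4};
      const int64_t strides[2] = {4, 1}; // contiguous row-major strides
      Tensor* view = nullptr;
      AOTITorchError err = aoti_torch__reinterpret_tensor(
          src, /*ndim=*/2, sizes, strides, /*storage_offset=*/0, &view);
      if (err != Error::Ok) {
        return nullptr; // e.g. non-zero storage_offset or untracked memory
      }
      // `view` shares src's storage; release it with
      // aoti_torch_delete_tensor_object(view) when no longer needed.
      return view;
    }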