diff --git a/paddle/phi/api/include/compat/ATen/core/TensorBase.h b/paddle/phi/api/include/compat/ATen/core/TensorBase.h index 0b2c6fe0a69e65..1b8c92b43e5cc2 100644 --- a/paddle/phi/api/include/compat/ATen/core/TensorBase.h +++ b/paddle/phi/api/include/compat/ATen/core/TensorBase.h @@ -20,6 +20,8 @@ #include #include #include +#include +#include #include #include #include @@ -27,6 +29,7 @@ #include "paddle/phi/api/include/api.h" #include "paddle/phi/api/include/tensor.h" #include "paddle/phi/common/place.h" +#include "paddle/phi/core/dense_tensor.h" namespace at { using PaddleTensor = paddle::Tensor; @@ -208,6 +211,35 @@ class PADDLE_API TensorBase { bool defined() const { return tensor_.defined(); } + int64_t storage_offset() const { + // Paddle DenseTensor stores offset in meta_.offset (in bytes) + // We need to convert to element offset + auto dense_tensor = + std::dynamic_pointer_cast(tensor_.impl()); + if (dense_tensor) { + size_t byte_offset = dense_tensor->meta().offset; + size_t element_size = SizeOf(tensor_.dtype()); + return element_size > 0 ? static_cast(byte_offset / element_size) + : 0; + } + return 0; + } + + c10::SymInt sym_storage_offset() const { + return c10::SymInt(storage_offset()); + } + + bool has_storage() const { return tensor_.defined(); } + + const Storage storage() const { + return Storage( + std::dynamic_pointer_cast(tensor_.impl())->Holder()); + } + + bool is_alias_of(const at::TensorBase& other) const { + return this->storage().allocation() == other.storage().allocation(); + } + Layout layout() const { switch (tensor_.layout()) { case common::DataLayout::STRIDED: diff --git a/paddle/phi/api/include/compat/c10/core/Allocator.h b/paddle/phi/api/include/compat/c10/core/Allocator.h new file mode 100644 index 00000000000000..3b4daf0b94b408 --- /dev/null +++ b/paddle/phi/api/include/compat/c10/core/Allocator.h @@ -0,0 +1,110 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// #The file has been adapted from pytorch project +// #Licensed under BSD-style license - +// https://github.com/pytorch/pytorch/blob/main/LICENSE + +#pragma once + +#include +#include +#include +#include + +#include "paddle/phi/common/place.h" +#include "paddle/phi/core/allocator.h" + +namespace c10 { + +// Deleter function pointer type (compatible with LibTorch) +using DeleterFnPtr = void (*)(void*); + +// DataPtr class compatible with LibTorch's c10::DataPtr +// Wraps a pointer with associated device and deleter +class DataPtr { + public: + DataPtr() : ptr_(nullptr), device_(phi::CPUPlace()) {} + + explicit DataPtr(void* data, phi::Place device = phi::CPUPlace()) + : ptr_(data), device_(device) {} + + DataPtr(void* data, + void* ctx, + DeleterFnPtr ctx_deleter, + phi::Place device = phi::CPUPlace()) + : ptr_(data), ctx_(ctx), deleter_(ctx_deleter), device_(device) {} + + // Construct from phi::Allocation + explicit DataPtr(const std::shared_ptr& alloc) + : ptr_(alloc ? alloc->ptr() : nullptr), + device_(alloc ? alloc->place() : phi::CPUPlace()), + allocation_(alloc) {} + + DataPtr(const DataPtr&) = default; + DataPtr& operator=(const DataPtr&) = default; + DataPtr(DataPtr&&) = default; + DataPtr& operator=(DataPtr&&) = default; + + void* get() const { return ptr_; } + + void* operator->() const { return ptr_; } + + explicit operator bool() const { return ptr_ != nullptr; } + + phi::Place device() const { return device_; } + + DeleterFnPtr get_deleter() const { return deleter_; } + + void* get_context() const { return ctx_; } + + void clear() { + ptr_ = nullptr; + ctx_ = nullptr; + deleter_ = nullptr; + allocation_.reset(); + } + + // Get the underlying allocation (if available) + std::shared_ptr allocation() const { return allocation_; } + + private: + void* ptr_ = nullptr; + void* ctx_ = nullptr; + DeleterFnPtr deleter_ = nullptr; + phi::Place device_; + std::shared_ptr allocation_; +}; + +inline bool operator==(const DataPtr& dp, std::nullptr_t) noexcept { + return !dp; +} + +inline bool operator==(std::nullptr_t, const DataPtr& dp) noexcept { + return !dp; +} + +inline bool operator!=(const DataPtr& dp, std::nullptr_t) noexcept { + return static_cast(dp); +} + +inline bool operator!=(std::nullptr_t, const DataPtr& dp) noexcept { + return static_cast(dp); +} + +} // namespace c10 + +namespace at { +using DataPtr = c10::DataPtr; +} // namespace at diff --git a/paddle/phi/api/include/compat/c10/core/Storage.h b/paddle/phi/api/include/compat/c10/core/Storage.h new file mode 100644 index 00000000000000..4c1b4ed03f7652 --- /dev/null +++ b/paddle/phi/api/include/compat/c10/core/Storage.h @@ -0,0 +1,348 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// #The file has been adapted from pytorch project +// #Licensed under BSD-style license - +// https://github.com/pytorch/pytorch/blob/main/LICENSE + +#pragma once + +#include +#include + +#include "paddle/phi/common/place.h" +#include "paddle/phi/core/allocator.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/storage_properties.h" + +#include "c10/core/Allocator.h" // For DataPtr + +namespace c10 { + +struct Storage; + +// Check if two storages share the same underlying allocation +inline bool isSharedStorageAlias(const Storage& storage0, + const Storage& storage1); + +struct Storage { + public: + // Tag types for constructor disambiguation (LibTorch compatible) + struct use_byte_size_t {}; + struct unsafe_borrow_t { + unsafe_borrow_t() = default; + }; + + // Default constructor + Storage() = default; + + // Copy constructor + Storage(const Storage& other) + : allocation_(other.allocation_), + allocator_(other.allocator_), + resizable_(other.resizable_) {} + + // Copy assignment operator + Storage& operator=(const Storage& other) { + if (this != &other) { + allocation_ = other.allocation_; + allocator_ = other.allocator_; + resizable_ = other.resizable_; + } + return *this; + } + + // Move constructor + Storage(Storage&& other) noexcept + : allocation_(std::move(other.allocation_)), + allocator_(other.allocator_), + resizable_(other.resizable_) { + other.allocator_ = nullptr; + other.resizable_ = false; + } + + // Move assignment operator + Storage& operator=(Storage&& other) noexcept { + if (this != &other) { + allocation_ = std::move(other.allocation_); + allocator_ = other.allocator_; + resizable_ = other.resizable_; + other.allocator_ = nullptr; + other.resizable_ = false; + } + return *this; + } + + // Constructor with allocation and optional storage properties + Storage(std::shared_ptr alloc, + std::unique_ptr props = nullptr) + : allocation_(std::move(alloc)) {} + + // Constructor with size and allocator (LibTorch compatible) + explicit Storage(size_t size_bytes, phi::Allocator* allocator = nullptr) { + if (allocator) { + allocation_ = + std::shared_ptr(allocator->Allocate(size_bytes)); + allocator_ = allocator; + } else { + allocation_ = nullptr; + allocator_ = nullptr; + } + } + + // LibTorch compatible constructor with use_byte_size_t tag + Storage(use_byte_size_t /*use_byte_size*/, + size_t size_bytes, + phi::Allocator* allocator = nullptr, + bool resizable = false) + : allocator_(allocator), resizable_(resizable) { + if (allocator) { + allocation_ = + std::shared_ptr(allocator->Allocate(size_bytes)); + } else { + allocation_ = nullptr; + } + } + + // LibTorch compatible constructor with pre-allocated memory + Storage(use_byte_size_t /*use_byte_size*/, + size_t size_bytes, + std::shared_ptr data_ptr, + phi::Allocator* allocator = nullptr, + bool resizable = false) + : allocation_(std::move(data_ptr)), + allocator_(allocator), + resizable_(resizable) {} + + protected: + // Unsafe borrow constructor (for MaybeOwnedTraits) + explicit Storage(unsafe_borrow_t, const Storage& rhs) + : allocation_(rhs.allocation_), + allocator_(rhs.allocator_), + resizable_(rhs.resizable_) {} + + // Forward declare template and make specialization a friend + template + friend struct MaybeOwnedTraits; + + public: + // Check if storage is valid (has allocation) + bool valid() const { return allocation_ != nullptr; } + + // Boolean conversion operator (LibTorch compatible) + explicit operator bool() const { return allocation_ != nullptr; } + + // Get the number of bytes in the storage + size_t nbytes() const { return allocation_ ? allocation_->size() : 0; } + + // Set the number of bytes (for resizable storage) + void set_nbytes(size_t size_bytes) { + if (resizable_ && allocator_) { + allocation_ = + std::shared_ptr(allocator_->Allocate(size_bytes)); + } + } + + // Check if storage is resizable + bool resizable() const { return resizable_; } + + // Get mutable data pointer + void* mutable_data() const { + return allocation_ ? allocation_->ptr() : nullptr; + } + + // Get const data pointer + const void* data() const { + return allocation_ ? allocation_->ptr() : nullptr; + } + + // Get the underlying allocation as DataPtr (LibTorch compatible: data_ptr()) + DataPtr data_ptr() const { return DataPtr(allocation_); } + + // Get the underlying allocation as mutable DataPtr reference + DataPtr mutable_data_ptr() const { return DataPtr(allocation_); } + + // Get the underlying allocation + std::shared_ptr allocation() const { return allocation_; } + + // Get the allocator + phi::Allocator* allocator() const { return allocator_; } + + // Get the device/place type + phi::AllocationType device_type() const { + return allocation_ ? allocation_->place().GetType() + : phi::AllocationType::CPU; + } + + // Get the device/place + phi::Place device() const { + return allocation_ ? allocation_->place() : phi::Place(); + } + + // Check if this storage is unique (use_count == 1) + bool unique() const { return allocation_.use_count() == 1; } + + // Get the reference count + size_t use_count() const { return allocation_.use_count(); } + + // Check if this storage is an alias of another + bool is_alias_of(const Storage& other) const { + if (!allocation_ || !other.allocation_) { + return false; + } + // Check if they share the same allocation or overlapping memory + return allocation_ == other.allocation_ || + isSharedStorageAlias(*this, other); + } + + // Unsafe release of the underlying allocation (for advanced usage) + phi::Allocation* unsafeReleaseAllocation() { + auto* ptr = allocation_.get(); + allocation_.reset(); + return ptr; + } + + // Unsafe get of the underlying allocation pointer + phi::Allocation* unsafeGetAllocation() const noexcept { + return allocation_.get(); + } + + // Set data pointer (swap and return old) - accepts DataPtr + DataPtr set_data_ptr(DataPtr&& new_data_ptr) { + DataPtr old_data_ptr(allocation_); + allocation_ = new_data_ptr.allocation(); + return old_data_ptr; + } + + // Set data pointer (swap and return old) - accepts shared_ptr + std::shared_ptr set_data_ptr( + std::shared_ptr data_ptr) { + std::swap(allocation_, data_ptr); + return data_ptr; + } + + // Set data pointer (no swap) - accepts DataPtr + void set_data_ptr_noswap(DataPtr&& new_data_ptr) { + allocation_ = new_data_ptr.allocation(); + } + + // Set data pointer (no swap) - accepts shared_ptr + void set_data_ptr_noswap(std::shared_ptr data_ptr) { + allocation_ = std::move(data_ptr); + } + + private: + std::shared_ptr allocation_; + phi::Allocator* allocator_ = nullptr; + bool resizable_ = false; +}; + +// Implementation of isSharedStorageAlias +inline bool isSharedStorageAlias(const Storage& storage0, + const Storage& storage1) { + if (!storage0.valid() || !storage1.valid()) { + return false; + } + // Check if memory ranges overlap + const void* ptr0 = storage0.data(); + const void* ptr1 = storage1.data(); + size_t size0 = storage0.nbytes(); + size_t size1 = storage1.nbytes(); + + if (ptr0 == nullptr || ptr1 == nullptr || size0 == 0 || size1 == 0) { + return false; + } + + const char* start0 = static_cast(ptr0); + const char* end0 = start0 + size0; + const char* start1 = static_cast(ptr1); + const char* end1 = start1 + size1; + + // Check for overlap + return !(end0 <= start1 || end1 <= start0); +} + +// Template specialization for MaybeOwnedTraits +// Provides safe borrowing semantics for Storage objects +template +struct MaybeOwnedTraits; + +template <> +struct MaybeOwnedTraits { + using owned_type = c10::Storage; + using borrow_type = c10::Storage; + + // Create a borrowed reference from an owned Storage + static borrow_type createBorrow(const owned_type& from) { + return borrow_type(borrow_type::unsafe_borrow_t{}, from); + } + + // Assign a borrowed reference (LibTorch compatible signature with pointer) + static void assignBorrow(borrow_type* lhs, const borrow_type& rhs) { + *lhs = borrow_type(borrow_type::unsafe_borrow_t{}, rhs); + } + + // Destroy a borrowed reference (release without deallocating) + static void destroyBorrow(borrow_type* toDestroy) { + *toDestroy = Storage(); // Reset to empty state + } + + // Get a reference to the owned object from a borrow + static const owned_type& referenceFromBorrow(const borrow_type& borrow) { + return borrow; + } + + // Get a pointer to the owned object from a borrow + static const owned_type* pointerFromBorrow(const borrow_type& borrow) { + return &borrow; + } + + // Debug check if borrow is valid + static bool debugBorrowIsValid(const borrow_type& /*borrow*/) { return true; } +}; + +// Template specialization for ExclusivelyOwnedTraits +// Provides exclusive ownership semantics for Storage objects +template +struct ExclusivelyOwnedTraits; + +template <> +struct ExclusivelyOwnedTraits { + using repr_type = c10::Storage; + using pointer_type = c10::Storage*; + using const_pointer_type = const c10::Storage*; + + // Create a null/empty representation + static repr_type nullRepr() { return c10::Storage(); } + + // Create a Storage in place with given arguments + template + static repr_type createInPlace(Args&&... args) { + return c10::Storage(std::forward(args)...); + } + + // Move a Storage into the representation + static repr_type moveToRepr(c10::Storage&& x) { return std::move(x); } + + // Take ownership from a Storage pointer (LibTorch compatible) + static c10::Storage take(c10::Storage* x) { return std::move(*x); } + + // Get a pointer to the representation (mutable) + static pointer_type getImpl(repr_type* x) { return x; } + + // Get a const pointer to the representation + static const_pointer_type getImpl(const repr_type& x) { return &x; } +}; + +} // namespace c10 diff --git a/test/cpp/compat/CMakeLists.txt b/test/cpp/compat/CMakeLists.txt index ba533405a553c8..daefc89dcf53cd 100644 --- a/test/cpp/compat/CMakeLists.txt +++ b/test/cpp/compat/CMakeLists.txt @@ -1,6 +1,7 @@ if(NOT WIN32) if(WITH_GPU) nv_test(compat_basic_test SRCS compat_basic_test.cc) + nv_test(c10_storage_test SRCS c10_storage_test.cc) nv_test(compat_squeeze_test SRCS compat_squeeze_test.cc) cc_test(torch_library_test SRCS torch_library_test.cc) endif() diff --git a/test/cpp/compat/c10_storage_test.cc b/test/cpp/compat/c10_storage_test.cc new file mode 100644 index 00000000000000..dad80fdbdb02af --- /dev/null +++ b/test/cpp/compat/c10_storage_test.cc @@ -0,0 +1,381 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include +#include +#include +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +#include +#include +#endif +#include "ATen/ATen.h" +#include "gtest/gtest.h" +#include "paddle/phi/common/float16.h" +#include "torch/all.h" + +TEST(StorageTest, BasicStorageAPIs) { + // Test basic Storage APIs through TensorBase + at::TensorBase tensor = at::ones({2, 3}, at::kFloat); + + const c10::Storage& storage = tensor.storage(); + + // Test valid() + ASSERT_TRUE(storage.valid()); + + // Test nbytes() + size_t expected_nbytes = 2 * 3 * sizeof(float); + ASSERT_EQ(storage.nbytes(), expected_nbytes); + + // Test data() and mutable_data() + ASSERT_NE(storage.data(), nullptr); + ASSERT_NE(storage.mutable_data(), nullptr); + ASSERT_EQ(storage.data(), storage.mutable_data()); + + // Test allocation() + auto alloc = storage.allocation(); + ASSERT_NE(alloc, nullptr); + ASSERT_EQ(alloc->size(), expected_nbytes); + + // Test unique() and use_count() + // Note: In PaddlePaddle, DenseTensor holds a reference, Storage holds one, + // and there may be additional internal references during tensor creation + ASSERT_FALSE(storage.unique()); + ASSERT_EQ(storage.use_count(), 3); +} + +TEST(StorageTest, StorageSharing) { + // Test storage sharing between tensors + at::TensorBase tensor1 = at::ones({2, 3}, at::kFloat); + at::TensorBase tensor2 = tensor1; // Shared storage + + const c10::Storage& storage1 = tensor1.storage(); + const c10::Storage& storage2 = tensor2.storage(); + + // Test that storages are the same + ASSERT_EQ(storage1.allocation(), storage2.allocation()); + + // Test use_count + // Note: In PaddlePaddle, the count includes: + // 1. DenseTensor's internal holder_ + // 2. storage1's allocation_ + // 3. storage2's allocation_ + // Total: 3 + ASSERT_EQ(storage1.use_count(), 3); + ASSERT_EQ(storage2.use_count(), 3); + + // Test unique() is false + ASSERT_FALSE(storage1.unique()); + ASSERT_FALSE(storage2.unique()); +} + +TEST(StorageTest, StorageOffsetAPI) { + // Test storage_offset() API + at::TensorBase tensor = at::ones({2, 3}, at::kFloat); + + // Test storage_offset() - should always return 0 for PaddlePaddle + ASSERT_EQ(tensor.storage_offset(), 0); + + // Test sym_storage_offset() - should always return SymInt(0) for PaddlePaddle + c10::SymInt sym_offset = tensor.sym_storage_offset(); + ASSERT_EQ(sym_offset, c10::SymInt(0)); +} + +TEST(StorageTest, IsAliasOfAPI) { + // Test is_alias_of() API + at::TensorBase tensor1 = at::ones({2, 3}, at::kFloat); + at::TensorBase tensor2 = tensor1; // Shared storage, should be alias + at::TensorBase tensor3 = at::ones({2, 3}, at::kFloat); // Different storage + + // Test that tensor1 and tensor2 are aliases (share same storage) + ASSERT_TRUE(tensor1.is_alias_of(tensor2)); + ASSERT_TRUE(tensor2.is_alias_of(tensor1)); + + // Test that tensor1 and tensor3 are not aliases (different storage) + ASSERT_FALSE(tensor1.is_alias_of(tensor3)); + ASSERT_FALSE(tensor3.is_alias_of(tensor1)); + + // Test that tensor is alias of itself + ASSERT_TRUE(tensor1.is_alias_of(tensor1)); +} + +TEST(StorageTest, BoolConversionOperator) { + // Test operator bool() for Storage + at::TensorBase tensor = at::ones({2, 3}, at::kFloat); + const c10::Storage& storage = tensor.storage(); + + // Valid storage should convert to true + ASSERT_TRUE(static_cast(storage)); + + // Default constructed storage should convert to false + c10::Storage empty_storage; + ASSERT_FALSE(static_cast(empty_storage)); + ASSERT_FALSE(empty_storage.valid()); +} + +TEST(StorageTest, ResizableAPI) { + // Test resizable() API + at::TensorBase tensor = at::ones({2, 3}, at::kFloat); + const c10::Storage& storage = tensor.storage(); + + // Default storage from tensor should not be resizable + ASSERT_FALSE(storage.resizable()); +} + +TEST(StorageTest, DeviceAndDeviceTypeAPIs) { + // Test device() and device_type() APIs + at::TensorBase cpu_tensor = at::ones({2, 3}, at::kFloat); + const c10::Storage& cpu_storage = cpu_tensor.storage(); + + // Test device_type() returns CPU + ASSERT_EQ(cpu_storage.device_type(), phi::AllocationType::CPU); + + // Test device() returns valid place + phi::Place place = cpu_storage.device(); + ASSERT_EQ(place.GetType(), phi::AllocationType::CPU); + +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + if (at::cuda::is_available()) { + at::TensorBase cuda_tensor = at::ones( + {2, 3}, c10::TensorOptions().dtype(at::kFloat).device(at::kCUDA)); + const c10::Storage& cuda_storage = cuda_tensor.storage(); + + // Test device_type() returns CUDA/GPU + ASSERT_EQ(cuda_storage.device_type(), phi::AllocationType::GPU); + + // Test device() returns CUDA place + phi::Place cuda_place = cuda_storage.device(); + ASSERT_EQ(cuda_place.GetType(), phi::AllocationType::GPU); + } +#endif +} + +TEST(StorageTest, AllocatorAPI) { + // Test allocator() API + at::TensorBase tensor = at::ones({2, 3}, at::kFloat); + const c10::Storage& storage = tensor.storage(); + + // Allocator may be nullptr for storage obtained from tensor + // This is expected behavior in the compatibility layer + phi::Allocator* allocator = storage.allocator(); + // Note: allocator can be nullptr, this is just to verify the API works + (void)allocator; +} + +TEST(StorageTest, UnsafeAllocationAPIs) { + // Test unsafeGetAllocation() API + at::TensorBase tensor = at::ones({2, 3}, at::kFloat); + c10::Storage storage = tensor.storage(); + + // Test unsafeGetAllocation() + phi::Allocation* alloc_ptr = storage.unsafeGetAllocation(); + ASSERT_NE(alloc_ptr, nullptr); + ASSERT_EQ(alloc_ptr->size(), 2 * 3 * sizeof(float)); + + // Test that the pointer matches the data pointer + ASSERT_EQ(alloc_ptr->ptr(), storage.data()); +} + +TEST(StorageTest, SetDataPtrAPIs) { + // Test set_data_ptr() and set_data_ptr_noswap() APIs + at::TensorBase tensor1 = at::ones({2, 3}, at::kFloat); + at::TensorBase tensor2 = at::ones({4, 5}, at::kFloat); + + c10::Storage storage1 = tensor1.storage(); + c10::Storage storage2 = tensor2.storage(); + + auto alloc1 = storage1.allocation(); + auto alloc2 = storage2.allocation(); + + // Test set_data_ptr() - swaps and returns old + auto old_alloc = storage1.set_data_ptr(alloc2); + ASSERT_EQ(old_alloc, alloc1); + ASSERT_EQ(storage1.allocation(), alloc2); + + // Test set_data_ptr_noswap() + storage1.set_data_ptr_noswap(alloc1); + ASSERT_EQ(storage1.allocation(), alloc1); +} + +TEST(StorageTest, StorageCopyAndMove) { + // Test copy and move constructors/operators + at::TensorBase tensor = at::ones({2, 3}, at::kFloat); + c10::Storage original = tensor.storage(); + + // Test copy constructor + c10::Storage copied(original); + ASSERT_EQ(copied.allocation(), original.allocation()); + ASSERT_EQ(copied.nbytes(), original.nbytes()); + ASSERT_TRUE(copied.valid()); + + // Test copy assignment + c10::Storage copy_assigned; + copy_assigned = original; + ASSERT_EQ(copy_assigned.allocation(), original.allocation()); + + // Test move constructor + c10::Storage to_move = original; + auto alloc_before_move = to_move.allocation(); + c10::Storage moved(std::move(to_move)); + ASSERT_EQ(moved.allocation(), alloc_before_move); + ASSERT_TRUE(moved.valid()); + + // Test move assignment + c10::Storage to_move2 = original; + alloc_before_move = to_move2.allocation(); + c10::Storage move_assigned; + move_assigned = std::move(to_move2); + ASSERT_EQ(move_assigned.allocation(), alloc_before_move); +} + +TEST(StorageTest, DefaultConstructedStorage) { + // Test default constructed storage + c10::Storage storage; + + ASSERT_FALSE(storage.valid()); + ASSERT_FALSE(static_cast(storage)); + ASSERT_EQ(storage.nbytes(), 0); + ASSERT_EQ(storage.data(), nullptr); + ASSERT_EQ(storage.mutable_data(), nullptr); + ASSERT_EQ(storage.allocation(), nullptr); + ASSERT_EQ(storage.use_count(), 0); + ASSERT_FALSE(storage.resizable()); + ASSERT_EQ(storage.allocator(), nullptr); +} + +TEST(StorageTest, IsSharedStorageAliasFunction) { + // Test isSharedStorageAlias() function + at::TensorBase tensor1 = at::ones({2, 3}, at::kFloat); + at::TensorBase tensor2 = tensor1; // Shared storage + at::TensorBase tensor3 = at::ones({2, 3}, at::kFloat); // Different storage + + c10::Storage storage1 = tensor1.storage(); + c10::Storage storage2 = tensor2.storage(); + c10::Storage storage3 = tensor3.storage(); + + // Same allocation should return true + ASSERT_TRUE(c10::isSharedStorageAlias(storage1, storage2)); + ASSERT_TRUE(c10::isSharedStorageAlias(storage2, storage1)); + + // Different allocations should return false + ASSERT_FALSE(c10::isSharedStorageAlias(storage1, storage3)); + ASSERT_FALSE(c10::isSharedStorageAlias(storage3, storage1)); + + // Empty storage should return false + c10::Storage empty_storage; + ASSERT_FALSE(c10::isSharedStorageAlias(storage1, empty_storage)); + ASSERT_FALSE(c10::isSharedStorageAlias(empty_storage, storage1)); + ASSERT_FALSE(c10::isSharedStorageAlias(empty_storage, empty_storage)); +} + +TEST(StorageTest, StorageIsAliasOfMethod) { + // Test Storage::is_alias_of() method + at::TensorBase tensor1 = at::ones({2, 3}, at::kFloat); + at::TensorBase tensor2 = tensor1; + at::TensorBase tensor3 = at::ones({2, 3}, at::kFloat); + + c10::Storage storage1 = tensor1.storage(); + c10::Storage storage2 = tensor2.storage(); + c10::Storage storage3 = tensor3.storage(); + + // Same underlying allocation + ASSERT_TRUE(storage1.is_alias_of(storage2)); + ASSERT_TRUE(storage2.is_alias_of(storage1)); + + // Different allocations + ASSERT_FALSE(storage1.is_alias_of(storage3)); + + // Self alias + ASSERT_TRUE(storage1.is_alias_of(storage1)); + + // Empty storage + c10::Storage empty_storage; + ASSERT_FALSE(storage1.is_alias_of(empty_storage)); + ASSERT_FALSE(empty_storage.is_alias_of(storage1)); +} + +TEST(StorageTest, MaybeOwnedTraitsSpecialization) { + // Test MaybeOwnedTraits specialization + using Traits = c10::MaybeOwnedTraits; + + at::TensorBase tensor = at::ones({2, 3}, at::kFloat); + c10::Storage original = tensor.storage(); + + // Test createBorrow + Traits::borrow_type borrowed = Traits::createBorrow(original); + ASSERT_EQ(borrowed.allocation(), original.allocation()); + ASSERT_TRUE(borrowed.valid()); + + // Test referenceFromBorrow + const Traits::owned_type& ref = Traits::referenceFromBorrow(borrowed); + ASSERT_EQ(ref.allocation(), original.allocation()); + + // Test pointerFromBorrow + const Traits::owned_type* ptr = Traits::pointerFromBorrow(borrowed); + ASSERT_NE(ptr, nullptr); + ASSERT_EQ(ptr->allocation(), original.allocation()); + + // Test debugBorrowIsValid + ASSERT_TRUE(Traits::debugBorrowIsValid(borrowed)); + + // Test assignBorrow + c10::Storage another_borrow; + Traits::assignBorrow(&another_borrow, borrowed); + ASSERT_EQ(another_borrow.allocation(), original.allocation()); + + // Test destroyBorrow + Traits::destroyBorrow(&borrowed); + ASSERT_FALSE(borrowed.valid()); +} + +TEST(StorageTest, ExclusivelyOwnedTraitsSpecialization) { + // Test ExclusivelyOwnedTraits specialization + using Traits = c10::ExclusivelyOwnedTraits; + + // Test nullRepr + Traits::repr_type null_repr = Traits::nullRepr(); + ASSERT_FALSE(null_repr.valid()); + + // Test createInPlace with default constructor + Traits::repr_type created = Traits::createInPlace(); + ASSERT_FALSE(created.valid()); // Default constructed + + // Test moveToRepr + at::TensorBase tensor = at::ones({2, 3}, at::kFloat); + c10::Storage original = tensor.storage(); + auto alloc = original.allocation(); + Traits::repr_type moved = Traits::moveToRepr(std::move(original)); + ASSERT_EQ(moved.allocation(), alloc); + + // Test take + c10::Storage to_take = tensor.storage(); + alloc = to_take.allocation(); + c10::Storage taken = Traits::take(&to_take); + ASSERT_EQ(taken.allocation(), alloc); + + // Test getImpl (mutable) + Traits::pointer_type ptr = Traits::getImpl(&taken); + ASSERT_NE(ptr, nullptr); + ASSERT_EQ(ptr->allocation(), alloc); + + // Test getImpl (const) + const c10::Storage& const_taken = taken; + Traits::const_pointer_type const_ptr = Traits::getImpl(const_taken); + ASSERT_NE(const_ptr, nullptr); + ASSERT_EQ(const_ptr->allocation(), alloc); +}