diff --git a/paddle/phi/api/include/compat/ATen/core/TensorBase.h b/paddle/phi/api/include/compat/ATen/core/TensorBase.h index f314cc19e0992d..7a94bbbf345bdd 100644 --- a/paddle/phi/api/include/compat/ATen/core/TensorBase.h +++ b/paddle/phi/api/include/compat/ATen/core/TensorBase.h @@ -16,6 +16,7 @@ #include #include +#include #include #include #include @@ -207,6 +208,25 @@ class PADDLE_API TensorBase { bool defined() const { return tensor_.defined(); } + Layout layout() const { + switch (tensor_.layout()) { + case common::DataLayout::STRIDED: + case common::DataLayout::NCHW: + case common::DataLayout::NHWC: + case common::DataLayout::NCDHW: + case common::DataLayout::NDHWC: + return c10::kStrided; + case common::DataLayout::SPARSE_COO: + return c10::kSparse; + case common::DataLayout::SPARSE_CSR: + return c10::kSparseCsr; + case common::DataLayout::ONEDNN: + return c10::kMkldnn; + default: + return c10::kStrided; + } + } + // Return a `TensorAccessor` for CPU `Tensor`s. You have to specify scalar // type and // dimension. diff --git a/test/cpp/compat/compat_basic_test.cc b/test/cpp/compat/compat_basic_test.cc index 8cc4adb7248c95..acc28ec3a44925 100644 --- a/test/cpp/compat/compat_basic_test.cc +++ b/test/cpp/compat/compat_basic_test.cc @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include @@ -340,3 +341,16 @@ TEST(TestTensorOperators, SubScriptOperator) { ASSERT_EQ(tensor_2.data_ptr()[i], static_cast(i + offset)); } } + +TEST(TensorBaseTest, LayoutAPI) { + // Test layout() API for strided tensors + at::TensorBase tensor = at::ones({2, 3}, at::kFloat); + + // Default tensor should have Strided layout + ASSERT_EQ(tensor.layout(), c10::kStrided); + + // Test layout output stream operator + std::ostringstream oss; + oss << tensor.layout(); + ASSERT_EQ(oss.str(), "Strided"); +}