diff --git a/include/tvm/ffi/any.h b/include/tvm/ffi/any.h index 4ff5db02..9adbd849 100644 --- a/include/tvm/ffi/any.h +++ b/include/tvm/ffi/any.h @@ -516,6 +516,11 @@ struct Type2Str { static std::string v() { return "Any"; } }; +template <> +struct Type2Str { + static std::string v() { return "Any"; } +}; + template <> struct Type2Str { static std::string v() { return "AnyView"; } @@ -526,6 +531,11 @@ struct Type2Str { static std::string v() { return "AnyView"; } }; +template <> +struct Type2Str { + static std::string v() { return "AnyView"; } +}; + template <> struct Type2Str { static std::string v() { return "void"; } diff --git a/include/tvm/ffi/type_traits.h b/include/tvm/ffi/type_traits.h index 1fd13047..dc7f982b 100644 --- a/include/tvm/ffi/type_traits.h +++ b/include/tvm/ffi/type_traits.h @@ -36,6 +36,8 @@ namespace tvm { namespace ffi { +class Any; + /*! * \brief TypeTraits that specifies the conversion behavior from/to FFI Any. * diff --git a/tests/cpp/test_reflection.cc b/tests/cpp/test_reflection.cc index 8fe6a188..82d73a27 100644 --- a/tests/cpp/test_reflection.cc +++ b/tests/cpp/test_reflection.cc @@ -294,4 +294,60 @@ TEST(Reflection, AccessPath) { auto root_parent = root->GetParent(); EXPECT_FALSE(root_parent.has_value()); } + +struct TestObjWithAny : public Object { + Any value; + explicit TestObjWithAny(Any value) : value(std::move(value)) {} + [[maybe_unused]] static constexpr bool _type_mutable = true; + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("test.TestObjWithAny", TestObjWithAny, Object); +}; + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def(refl::init()).def_ro("value", &TestObjWithAny::value); +} + +struct TestObjWithAnyView : public Object { + Any value; + explicit TestObjWithAnyView(AnyView value) : value(value) {} + [[maybe_unused]] static constexpr bool _type_mutable = true; + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("test.TestObjWithAnyView", TestObjWithAnyView, Object); +}; + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def(refl::init()) + .def_ro("value", &TestObjWithAnyView::value); +} + +TEST(Reflection, InitWithAny) { + Function init = reflection::GetMethod("test.TestObjWithAny", "__ffi_init__"); + Any obj1 = init(42); + ASSERT_TRUE(obj1.as() != nullptr); + EXPECT_EQ(obj1.as()->value.cast(), 42); + + Any obj2 = init(3.14); + ASSERT_TRUE(obj2.as() != nullptr); + EXPECT_EQ(obj2.as()->value.cast(), 3.14); + + Any obj3 = init(String("hello")); + ASSERT_TRUE(obj3.as() != nullptr); + EXPECT_EQ(obj3.as()->value.cast(), "hello"); +} + +TEST(Reflection, InitWithAnyView) { + Function init = reflection::GetMethod("test.TestObjWithAnyView", "__ffi_init__"); + Any obj1 = init(42); + ASSERT_TRUE(obj1.as() != nullptr); + EXPECT_EQ(obj1.as()->value.cast(), 42); + + Any obj2 = init(3.14); + ASSERT_TRUE(obj2.as() != nullptr); + EXPECT_EQ(obj2.as()->value.cast(), 3.14); + + Any obj3 = init(String("hello")); + ASSERT_TRUE(obj3.as() != nullptr); + EXPECT_EQ(obj3.as()->value.cast(), "hello"); +} } // namespace