diff --git a/include/tvm/ffi/function.h b/include/tvm/ffi/function.h index d1cc6933..30437315 100644 --- a/include/tvm/ffi/function.h +++ b/include/tvm/ffi/function.h @@ -33,6 +33,10 @@ #define TVM_FFI_DLL_EXPORT_INCLUDE_METADATA 0 #endif +#if TVM_FFI_DLL_EXPORT_INCLUDE_METADATA +#include +#endif // TVM_FFI_DLL_EXPORT_INCLUDE_METADATA + #include #include #include @@ -40,7 +44,9 @@ #include #include +#include #include +#include #include #include #include @@ -165,21 +171,19 @@ class FunctionObjImpl : public FunctionObj { /*! * \brief Derived object class for constructing ffi::FunctionObj. - * \param callable The type-erased callable object (rvalue). - */ - explicit FunctionObjImpl(TCallable&& callable) : callable_(std::move(callable)) { - this->safe_call = SafeCall; - this->cpp_call = reinterpret_cast(CppCall); - } - /*! - * \brief Derived object class for constructing ffi::FunctionObj. - * \param callable The type-erased callable object (lvalue). + * \param args The arguments to construct TCallable */ - explicit FunctionObjImpl(const TCallable& callable) : callable_(callable) { + template + explicit FunctionObjImpl(Args&&... args) : callable_(std::forward(args)...) { this->safe_call = SafeCall; this->cpp_call = reinterpret_cast(CppCall); } + FunctionObjImpl(const FunctionObjImpl&) = delete; + FunctionObjImpl& operator=(const FunctionObjImpl&) = delete; + + TCallable* GetCallable() { return &callable_; } + private: // implementation of call static void CppCall(const FunctionObj* func, const AnyView* args, int32_t num_args, Any* result) { @@ -356,6 +360,29 @@ class Function : public ObjectRef { } } + /*! + * \brief Constructing a packed function from a callable type + * whose signature is consistent with `ffi::Function`. + * It will create the Callable object with the given arguments, + * and return the inplace constructed Function along with + * the pointer to the callable object. The lifetime of the callable + * object is managed by the returned Function. + * \param args The arguments to construct TCallable + * \return A tuple of (Function, TCallable*) + */ + template + static auto FromPackedInplace(Args&&... args) { + // We must ensure TCallable is a value type (decay_t) that can hold the callable object + static_assert(std::is_same_v>); + static_assert(std::is_invocable_v); + using ObjType = details::FunctionObjImpl; + Function func; + auto obj_ptr = make_object(std::forward(args)...); + auto* call_ptr = obj_ptr->GetCallable(); + func.data_ = std::move(obj_ptr); + return std::make_tuple(std::move(func), call_ptr); + } + /*! * \brief Create ffi::Function from a C style callbacks. * diff --git a/include/tvm/ffi/reflection/overload.h b/include/tvm/ffi/reflection/overload.h new file mode 100644 index 00000000..6556338a --- /dev/null +++ b/include/tvm/ffi/reflection/overload.h @@ -0,0 +1,501 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tvm/ffi/reflection/overload.h + * \brief Registry of reflection metadata, supporting function overloading. + */ +#ifndef TVM_FFI_EXTRA_OVERLOAD_H +#define TVM_FFI_EXTRA_OVERLOAD_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace ffi { + +namespace details { + +struct OverloadBase { + public: + // Try Call function pointer type, return true if matched and called + using FnPtr = bool (*)(OverloadBase*, const AnyView*, int32_t, Any*); + + explicit OverloadBase(int32_t num_args, std::optional name) + : num_args_(num_args), + name_(name ? std::move(*name) : ""), + name_ptr_(name ? &this->name_ : nullptr) {} + + virtual void Register(std::unique_ptr overload) = 0; + virtual FnPtr GetTryCallPtr() = 0; + virtual void GetMismatchMessage(std::ostringstream& os, const AnyView* args, + int32_t num_args) = 0; + + virtual ~OverloadBase() = default; + OverloadBase(const OverloadBase&) = delete; + OverloadBase& operator=(const OverloadBase&) = delete; + + public: + static constexpr int32_t kAllMatched = -1; + + // a fast cache for last matched arg index + // on 64-bit platform, this is packed in the same 8 byte with num_args_ + int32_t last_mismatch_index_{kAllMatched}; + + // some constant helper args + const int32_t num_args_; + const std::string name_; + const std::string* const name_ptr_; +}; + +template +struct CaptureTupleAux; + +template +struct CaptureTupleAux> { + using type = std::tuple>...>; +}; + +template +struct TypedOverload : OverloadBase { + public: + static_assert(std::is_same_v>, "Callable must be value type"); + + using FuncInfo = details::FunctionInfo; + using PackedArgs = typename FuncInfo::ArgType; + using Ret = typename FuncInfo::RetType; + using CaptureTuple = typename CaptureTupleAux::type; + using OverloadBase::name_; + using OverloadBase::name_ptr_; + using typename OverloadBase::FnPtr; + + static constexpr auto kNumArgs = FuncInfo::num_args; + static constexpr auto kSeq = std::make_index_sequence{}; + + explicit TypedOverload(const Callable& f, std::optional name = std::nullopt) + : OverloadBase(kNumArgs, std::move(name)), f_(f) {} + explicit TypedOverload(Callable&& f, std::optional name = std::nullopt) + : OverloadBase(kNumArgs, std::move(name)), f_(std::move(f)) {} + + bool TryCall(const AnyView* args, int32_t num_args, Any* rv) { + if (num_args != kNumArgs) return false; + CaptureTuple captures{}; + if (!TrySetAux(kSeq, captures, args)) return false; + // now all captures are set + if constexpr (std::is_same_v) { + CallAux(kSeq, captures); + return true; + } else { + *rv = CallAux(kSeq, captures); + return true; + } + } + + void Register(std::unique_ptr overload) override { + TVM_FFI_ICHECK(false) << "This should never be called."; + } + + FnPtr GetTryCallPtr() final { + // lambda without a capture can be converted to function pointer + return [](OverloadBase* base, const AnyView* args, int32_t num_args, Any* rv) -> bool { + return static_cast*>(base)->TryCall(args, num_args, rv); + }; + } + + void GetMismatchMessage(std::ostringstream& os, const AnyView* args, int32_t num_args) final { + FGetFuncSignature f_sig = FuncInfo::Sig; + if (num_args != kNumArgs) { + os << "Mismatched number of arguments when calling: `" << name_ << " " + << (f_sig == nullptr ? "" : (*f_sig)()) << "`. Expected " << kNumArgs << " arguments"; + } else { + GetMismatchMessageAux<0>(os, args, num_args); + } + } + + private: + template + void GetMismatchMessageAux(std::ostringstream& os, const AnyView* args, int32_t num_args) { + if constexpr (I < kNumArgs) { + if (this->last_mismatch_index_ == static_cast(I)) { + TVMFFIAny any_data = args[I].CopyToTVMFFIAny(); + FGetFuncSignature f_sig = FuncInfo::Sig; + using Type = std::decay_t>; + os << "Mismatched type on argument #" << I << " when calling: `" << name_ << " " + << (f_sig == nullptr ? "" : (*f_sig)()) << "`. Expected `" << Type2Str::v() + << "` but got `" << TypeTraits::GetMismatchTypeInfo(&any_data) << '`'; + } else { + GetMismatchMessageAux(os, args, num_args); + } + } + // end of recursion + } + + template + Ret CallAux(std::index_sequence, CaptureTuple& tuple) { + /// NOTE: this works for T, const T, const T&, T&& argument types + return f_(static_cast>(std::move(*std::get(tuple)))...); + } + + template + bool TrySetAux(std::index_sequence, CaptureTuple& tuple, const AnyView* args) { + return (TrySetOne(tuple, args) && ...); + } + + template + bool TrySetOne(CaptureTuple& tuple, const AnyView* args) { + using Type = std::decay_t>; + auto& capture = std::get(tuple); + if constexpr (std::is_same_v) { + capture = args[I]; + return true; + } else if constexpr (std::is_same_v) { + capture = Any(args[I]); + return true; + } else { + capture = args[I].template try_cast(); + if (capture.has_value()) return true; + // slow path: record the last mismatch index + this->last_mismatch_index_ = static_cast(I); + return false; + } + } + + protected: + Callable f_; +}; + +template +inline auto CreateNewOverload(Callable&& f, std::string name) { + using Type = TypedOverload>; + return std::make_unique(std::forward(f), std::move(name)); +} + +template +struct OverloadedFunction : TypedOverload { + public: + using TypedBase = TypedOverload; + using OverloadBase::name_; + using OverloadBase::name_ptr_; + using TypedBase::GetTryCallPtr; + using TypedBase::kNumArgs; + using TypedBase::kSeq; + using TypedBase::TypedBase; // constructors + using typename OverloadBase::FnPtr; + using typename TypedBase::Ret; + + void Register(std::unique_ptr overload) final { + const auto fptr = overload->GetTryCallPtr(); + overloads_.emplace_back(std::move(overload), fptr); + } + + void operator()(const AnyView* args, int32_t num_args, Any* rv) { + // fast path: only add a little overhead when no overloads + if (overloads_.size() == 0) { + return unpack_call(kSeq, name_ptr_, f_, args, num_args, rv); + } + + // this can be inlined by compiler, don't worry + if (this->TryCall(args, num_args, rv)) return; + + // virtual calls cannot be inlined, so we fast check the num_args first + // we also de-virtualize the fptr to reduce one more indirection + for (const auto& [overload, fptr] : overloads_) { + if (overload->num_args_ != num_args) continue; + if (fptr(overload.get(), args, num_args, rv)) return; + } + + this->HandleOverloadFailure(args, num_args); + } + + private: + void HandleOverloadFailure(const AnyView* args, int32_t num_args) { + std::ostringstream oss; + int32_t i = 0; + oss << "Overload #" << i++ << ": "; + this->GetMismatchMessage(oss, args, num_args); + for (const auto& [overload, _] : overloads_) { + oss << "\nOverload #" << i++ << ": "; + overload->GetMismatchMessage(oss, args, num_args); + } + TVM_FFI_THROW(TypeError) << "No matching overload found when calling: `" << name_ << "` with " + << num_args << " arguments:\n" + << std::move(oss).str(); + } + using TypedBase::f_; + std::vector, FnPtr>> overloads_; +}; + +} // namespace details + +/*! \brief Reflection namespace */ +namespace reflection { + +/*! + * \brief Helper to register Object's reflection metadata. + * \tparam Class The class type. + * + * \code + * namespace refl = tvm::ffi::reflection; + * refl::ObjectDef().def_ro("my_field", &MyClass::my_field); + * \endcode + */ +template +class OverloadObjectDef : private ObjectDef { + public: + using Super = ObjectDef; + /*! + * \brief Constructor + * \tparam ExtraArgs The extra arguments. + * \param extra_args The extra arguments. + */ + template + explicit OverloadObjectDef(ExtraArgs&&... extra_args) + : Super(std::forward(extra_args)...) {} + + /*! + * \brief Define a readonly field. + * + * \tparam Class The class type. + * \tparam T The field type. + * \tparam Extra The extra arguments. + * + * \param name The name of the field. + * \param field_ptr The pointer to the field. + * \param extra The extra arguments that can be docstring or default value. + * + * \return The reflection definition. + */ + template + TVM_FFI_INLINE OverloadObjectDef& def_ro(const char* name, T BaseClass::* field_ptr, + Extra&&... extra) { + /// NOTE: we don't allow properties to be overloaded + Super::def_ro(name, field_ptr, std::forward(extra)...); + return *this; + } + + /*! + * \brief Define a read-write field. + * + * \tparam Class The class type. + * \tparam T The field type. + * \tparam Extra The extra arguments. + * + * \param name The name of the field. + * \param field_ptr The pointer to the field. + * \param extra The extra arguments that can be docstring or default value. + * + * \return The reflection definition. + */ + template + TVM_FFI_INLINE OverloadObjectDef& def_rw(const char* name, T BaseClass::* field_ptr, + Extra&&... extra) { + /// NOTE: we don't allow properties to be overloaded + Super::def_rw(name, field_ptr, std::forward(extra)...); + return *this; + } + + /*! + * \brief Define a method. + * + * \tparam Func The function type. + * \tparam Extra The extra arguments. + * + * \param name The name of the method. + * \param func The function to be registered. + * \param extra The extra arguments that can be docstring. + * + * \return The reflection definition. + */ + template + TVM_FFI_INLINE OverloadObjectDef& def(const char* name, Func&& func, Extra&&... extra) { + RegisterMethod(name, false, std::forward(func), std::forward(extra)...); + return *this; + } + + /*! + * \brief Define a static method. + * + * \tparam Func The function type. + * \tparam Extra The extra arguments. + * + * \param name The name of the method. + * \param func The function to be registered. + * \param extra The extra arguments that can be docstring. + * + * \return The reflection definition. + */ + template + TVM_FFI_INLINE OverloadObjectDef& def_static(const char* name, Func&& func, Extra&&... extra) { + RegisterMethod(name, true, std::forward(func), std::forward(extra)...); + return *this; + } + + /*! + * \brief Register a constructor for this object type. + * + * This method registers a static `__init__` method that constructs an instance + * of the object with the specified argument types. The constructor can be invoked + * from Python or other FFI bindings. + * + * \tparam Args The argument types for the constructor. + * \tparam Extra Additional arguments (e.g., docstring). + * + * \param init_func An instance of `init` specifying constructor signature. + * \param extra Optional additional metadata such as docstring. + * + * \return Reference to this `ObjectDef` for method chaining. + * + * Example: + * \code + * refl::ObjectDef() + * .def(refl::init(), "Constructor docstring"); + * \endcode + */ + template + TVM_FFI_INLINE OverloadObjectDef& def([[maybe_unused]] init init_func, + Extra&&... extra) { + RegisterMethod(kInitMethodName, true, &init::template execute, + std::forward(extra)...); + return *this; + } + + private: + using ReflectionDefBase::ApplyExtraInfoTrait; + using ReflectionDefBase::WrapFunction; + using Super::kInitMethodName; + using Super::type_index_; + using Super::type_key_; + + template + static auto GetOverloadMethod(std::string name, Func&& func) { + using WrapFn = decltype(WrapFunction(std::forward(func))); + using OverloadFn = details::OverloadedFunction>; + return ffi::Function::FromPackedInplace(WrapFunction(std::forward(func)), + std::move(name)); + } + + template + static auto NewOverload(std::string name, Func&& func) { + return details::CreateNewOverload(WrapFunction(std::forward(func)), std::move(name)); + } + + template + void RegisterExtraInfo(ExtraArgs&&... extra_args) { + TVMFFITypeMetadata info; + info.total_size = sizeof(Class); + info.structural_eq_hash_kind = Class::_type_s_eq_hash_kind; + info.creator = nullptr; + info.doc = TVMFFIByteArray{nullptr, 0}; + if constexpr (std::is_default_constructible_v) { + info.creator = ReflectionDefBase::ObjectCreatorDefault; + } else if constexpr (std::is_constructible_v) { + info.creator = ReflectionDefBase::ObjectCreatorUnsafeInit; + } + // apply extra info traits + ((ApplyExtraInfoTrait(&info, std::forward(extra_args)), ...)); + TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterMetadata(type_index_, &info)); + } + + template + void RegisterField(const char* name, T BaseClass::* field_ptr, bool writable, + ExtraArgs&&... extra_args) { + static_assert(std::is_base_of_v, "BaseClass must be a base class of Class"); + FieldInfoBuilder info; + info.name = TVMFFIByteArray{name, std::char_traits::length(name)}; + info.field_static_type_index = TypeToFieldStaticTypeIndex::value; + // store byte offset and setter, getter + // so the same setter can be reused for all the same type + info.offset = GetFieldByteOffsetToObject(field_ptr); + info.size = sizeof(T); + info.alignment = alignof(T); + info.flags = 0; + if (writable) { + info.flags |= kTVMFFIFieldFlagBitMaskWritable; + } + info.getter = ReflectionDefBase::FieldGetter; + info.setter = ReflectionDefBase::FieldSetter; + // initialize default value to nullptr + info.default_value = AnyView(nullptr).CopyToTVMFFIAny(); + info.doc = TVMFFIByteArray{nullptr, 0}; + info.metadata_.emplace_back("type_schema", details::TypeSchema::v()); + // apply field info traits + ((ApplyFieldInfoTrait(&info, std::forward(extra_args)), ...)); + // call register + std::string metadata_str = Metadata::ToJSON(info.metadata_); + info.metadata = TVMFFIByteArray{metadata_str.c_str(), metadata_str.size()}; + TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterField(type_index_, &info)); + } + + // register a method + template + void RegisterMethod(const char* name, bool is_static, Func&& func, Extra&&... extra) { + using FuncInfo = details::FunctionInfo>; + MethodInfoBuilder info; + info.name = TVMFFIByteArray{name, std::char_traits::length(name)}; + info.doc = TVMFFIByteArray{nullptr, 0}; + info.flags = 0; + if (is_static) { + info.flags |= kTVMFFIFieldFlagBitMaskIsStaticMethod; + } + + auto method_name = std::string(type_key_) + "." + name; + + // if an overload method exists, register to existing overload function + if (const auto overload_it = registered_fields_.find(name); + overload_it != registered_fields_.end()) { + details::OverloadBase* overload_ptr = overload_it->second; + return overload_ptr->Register(NewOverload(std::move(method_name), std::forward(func))); + } + + // first time registering overload method + auto [method, overload_ptr] = + GetOverloadMethod(std::move(method_name), std::forward(func)); + registered_fields_.try_emplace(name, overload_ptr); + + info.method = AnyView(method).CopyToTVMFFIAny(); + info.metadata_.emplace_back("type_schema", FuncInfo::TypeSchema()); + // apply method info traits + ((ApplyMethodInfoTrait(&info, std::forward(extra)), ...)); + std::string metadata_str = Metadata::ToJSON(info.metadata_); + info.metadata = TVMFFIByteArray{metadata_str.c_str(), metadata_str.size()}; + TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterMethod(type_index_, &info)); + } + + std::unordered_map registered_fields_; +}; + +} // namespace reflection +} // namespace ffi +} // namespace tvm +#endif // TVM_FFI_EXTRA_OVERLOAD_H diff --git a/include/tvm/ffi/reflection/registry.h b/include/tvm/ffi/reflection/registry.h index 3224a9fd..1dc22aeb 100644 --- a/include/tvm/ffi/reflection/registry.h +++ b/include/tvm/ffi/reflection/registry.h @@ -28,6 +28,7 @@ #include #include #include +#include #include #include #include @@ -36,6 +37,7 @@ #include #include #include +#include #include #include @@ -94,6 +96,8 @@ class Metadata : public InfoTrait { friend class GlobalDef; template friend class ObjectDef; + template + friend class OverloadObjectDef; /*! * \brief Move metadata into a vector of key-value pairs. * \param out The output vector. @@ -270,52 +274,49 @@ class ReflectionDefBase { } } + template + TVM_FFI_INLINE static Function GetMethod(std::string name, Func&& func) { + return ffi::Function::FromTyped(WrapFunction(std::forward(func)), std::move(name)); + } + + template + TVM_FFI_INLINE static Func&& WrapFunction(Func&& func) { + return std::forward(func); + } template - TVM_FFI_INLINE static Function GetMethod(std::string name, R (Class::*func)(Args...)) { + TVM_FFI_INLINE static auto WrapFunction(R (Class::*func)(Args...)) { static_assert(std::is_base_of_v || std::is_base_of_v, "Class must be derived from ObjectRef or Object"); if constexpr (std::is_base_of_v) { - auto fwrap = [func](Class target, Args... params) -> R { + return [func](Class target, Args... params) -> R { // call method pointer return (target.*func)(std::forward(params)...); }; - return ffi::Function::FromTyped(fwrap, std::move(name)); } - if constexpr (std::is_base_of_v) { - auto fwrap = [func](const Class* target, Args... params) -> R { + return [func](const Class* target, Args... params) -> R { // call method pointer return (const_cast(target)->*func)(std::forward(params)...); }; - return ffi::Function::FromTyped(fwrap, std::move(name)); } } - template - TVM_FFI_INLINE static Function GetMethod(std::string name, R (Class::*func)(Args...) const) { + TVM_FFI_INLINE static auto WrapFunction(R (Class::*func)(Args...) const) { static_assert(std::is_base_of_v || std::is_base_of_v, "Class must be derived from ObjectRef or Object"); if constexpr (std::is_base_of_v) { - auto fwrap = [func](const Class& target, Args... params) -> R { + return [func](const Class& target, Args... params) -> R { // call method pointer return (target.*func)(std::forward(params)...); }; - return ffi::Function::FromTyped(fwrap, std::move(name)); } - if constexpr (std::is_base_of_v) { - auto fwrap = [func](const Class* target, Args... params) -> R { + return [func](const Class* target, Args... params) -> R { // call method pointer return (target->*func)(std::forward(params)...); }; - return ffi::Function::FromTyped(fwrap, std::move(name)); } } - - template - TVM_FFI_INLINE static Function GetMethod(std::string name, Func&& func) { - return ffi::Function::FromTyped(std::forward(func), std::move(name)); - } }; /// \endcond @@ -438,6 +439,8 @@ struct init { // Allow ObjectDef to access the execute function template friend class ObjectDef; + template + friend class OverloadObjectDef; /*! * \brief Constructor @@ -585,6 +588,9 @@ class ObjectDef : public ReflectionDefBase { } private: + template + friend class OverloadObjectDef; + template void RegisterExtraInfo(ExtraArgs&&... extra_args) { TVMFFITypeMetadata info; @@ -643,6 +649,7 @@ class ObjectDef : public ReflectionDefBase { if (is_static) { info.flags |= kTVMFFIFieldFlagBitMaskIsStaticMethod; } + // obtain the method function Function method = GetMethod(std::string(type_key_) + "." + name, std::forward(func)); info.method = AnyView(method).CopyToTVMFFIAny(); diff --git a/tests/cpp/test_overload.cc b/tests/cpp/test_overload.cc new file mode 100644 index 00000000..7dfb9c70 --- /dev/null +++ b/tests/cpp/test_overload.cc @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace { + +using namespace tvm::ffi; + +struct TestOverloadObj : public Object { + explicit TestOverloadObj(int32_t x) : type(Type::INT) {} + explicit TestOverloadObj(float y) : type(Type::FLOAT) {} + + static int AddOneInt(int x) { return x + 1; } + static float AddOneFloat(float x) { return x + 1.0f; } + + template + auto Holds(T) const { + if constexpr (std::is_same_v) { + return type == Type::INT; + } else if constexpr (std::is_same_v) { + return type == Type::FLOAT; + } else { + static_assert(sizeof(T) == 0, "Unsupported type"); + } + } + + enum class Type { INT, FLOAT } type; + TVM_FFI_DECLARE_OBJECT_INFO("test.TestOverloadObj", TestOverloadObj, Object); +}; + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::OverloadObjectDef() + .def(refl::init()) + .def(refl::init()) + .def("hold_same_type", &TestOverloadObj::Holds) + .def("hold_same_type", &TestOverloadObj::Holds) + .def_static("add_one_static", &TestOverloadObj::AddOneInt) + .def_static("add_one_static", &TestOverloadObj::AddOneFloat); +} + +TEST(Reflection, CallOverloadedInitMethod) { + Function init_method = reflection::GetMethod("test.TestOverloadObj", "__ffi_init__"); + Any obj_a = init_method(10); // choose the int constructor + EXPECT_TRUE(obj_a.as() != nullptr); + EXPECT_EQ(obj_a.as()->type, TestOverloadObj::Type::INT); + Any obj_b = init_method(3.14f); // choose the float constructor + EXPECT_TRUE(obj_b.as() != nullptr); + EXPECT_EQ(obj_b.as()->type, TestOverloadObj::Type::FLOAT); +} + +TEST(Reflection, CallOverloadedMethod) { + Function init_method = reflection::GetMethod("test.TestOverloadObj", "__ffi_init__"); + Function hold_same_type = reflection::GetMethod("test.TestOverloadObj", "hold_same_type"); + Any obj_a = init_method(10); // choose the int constructor + Any res_a = hold_same_type(obj_a, 20); + EXPECT_EQ(res_a.as(), true); + Any res_b = hold_same_type(obj_a, 3.14f); + EXPECT_EQ(res_b.as(), false); +} + +TEST(Reflection, CallOverloadedStaticMethod) { + Function add_one = reflection::GetMethod("test.TestOverloadObj", "add_one_static"); + Any res_a = add_one(20); + EXPECT_EQ(res_a.as(), 21); + Any res_b = add_one(1.0f); + static_assert(1.0f + 1.0f == 2.0f); + EXPECT_EQ(res_b.as(), 2.0f); +} + +} // namespace