Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 37 additions & 10 deletions include/tvm/ffi/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,20 @@
#define TVM_FFI_DLL_EXPORT_INCLUDE_METADATA 0
#endif

#if TVM_FFI_DLL_EXPORT_INCLUDE_METADATA
#include <sstream>
#endif // TVM_FFI_DLL_EXPORT_INCLUDE_METADATA

#include <tvm/ffi/any.h>
#include <tvm/ffi/base_details.h>
#include <tvm/ffi/c_api.h>
#include <tvm/ffi/error.h>
#include <tvm/ffi/function_details.h>

#include <functional>
#include <optional>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>
Expand Down Expand Up @@ -165,21 +171,19 @@ class FunctionObjImpl : public FunctionObj {

/*!
* \brief Derived object class for constructing ffi::FunctionObj.
* \param callable The type-erased callable object (rvalue).
*/
explicit FunctionObjImpl(TCallable&& callable) : callable_(std::move(callable)) {
this->safe_call = SafeCall;
this->cpp_call = reinterpret_cast<void*>(CppCall);
}
/*!
* \brief Derived object class for constructing ffi::FunctionObj.
* \param callable The type-erased callable object (lvalue).
* \param args The arguments to construct TCallable
*/
explicit FunctionObjImpl(const TCallable& callable) : callable_(callable) {
template <typename... Args>
explicit FunctionObjImpl(Args&&... args) : callable_(std::forward<Args>(args)...) {
this->safe_call = SafeCall;
this->cpp_call = reinterpret_cast<void*>(CppCall);
}

FunctionObjImpl(const FunctionObjImpl&) = delete;
FunctionObjImpl& operator=(const FunctionObjImpl&) = delete;

TCallable* GetCallable() { return &callable_; }

private:
// implementation of call
static void CppCall(const FunctionObj* func, const AnyView* args, int32_t num_args, Any* result) {
Expand Down Expand Up @@ -356,6 +360,29 @@ class Function : public ObjectRef {
}
}

/*!
* \brief Constructing a packed function from a callable type
* whose signature is consistent with `ffi::Function`.
* It will create the Callable object with the given arguments,
* and return the inplace constructed Function along with
* the pointer to the callable object. The lifetime of the callable
* object is managed by the returned Function.
* \param args The arguments to construct TCallable
* \return A tuple of (Function, TCallable*)
*/
template <typename TCallable, typename... Args>
static auto FromPackedInplace(Args&&... args) {
// We must ensure TCallable is a value type (decay_t) that can hold the callable object
static_assert(std::is_same_v<TCallable, std::decay_t<TCallable>>);
static_assert(std::is_invocable_v<TCallable, const AnyView*, int32_t, Any*>);
using ObjType = details::FunctionObjImpl<TCallable>;
Function func;
auto obj_ptr = make_object<ObjType>(std::forward<Args>(args)...);
auto* call_ptr = obj_ptr->GetCallable();
func.data_ = std::move(obj_ptr);
return std::make_tuple(std::move(func), call_ptr);
}

/*!
* \brief Create ffi::Function from a C style callbacks.
*
Expand Down
Loading