forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
builtin_function.h
88 lines (69 loc) · 2 KB
/
builtin_function.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
#pragma once
#include <ATen/core/function.h>
#include <ATen/core/ivalue.h>
#include <c10/util/Exception.h>
#include <c10/util/intrusive_ptr.h>
#include <functional>
#include <utility>
namespace torch {
namespace jit {
// A Function implementation backed by a plain std::function<void(Stack&)>.
//
// The object stores a qualified name, a schema, the callable itself and an
// optional doc string. Every way of invoking it (run, runAsync, both call
// overloads) ends up invoking the stored callable on the caller's stack.
struct BuiltinOpFunction : public Function {
  // Takes ownership of all four arguments by moving them into the members.
  // Asserts that the schema declares exactly one return value.
  BuiltinOpFunction(
      c10::QualifiedName qualname,
      c10::FunctionSchema schema,
      std::function<void(Stack&)> callable,
      std::string doc_string = "")
      : name_(std::move(qualname)),
        callable_(std::move(callable)),
        schema_(std::move(schema)),
        doc_string_(std::move(doc_string)) {
    TORCH_INTERNAL_ASSERT(schema_.returns().size() == 1);
  }

  ~BuiltinOpFunction() override = default;

  // ---- Metadata accessors --------------------------------------------------

  // Qualified name supplied at construction.
  const c10::QualifiedName& qualname() const override {
    return name_;
  }

  // View over the stored doc string; valid for as long as this object lives.
  c10::string_view doc_string() const override {
    return doc_string_;
  }

  // Current schema (the constructor's, or the latest one given to setSchema).
  const c10::FunctionSchema& getSchema() const override {
    return schema_;
  }

  // Number of arguments declared by the current schema.
  size_t num_inputs() const override {
    const auto& declared_args = schema_.arguments();
    return declared_args.size();
  }

  // Replaces the stored schema and returns *this so calls can be chained.
  // NOTE(review): unlike the constructor, this does not re-check that the new
  // schema has exactly one return value.
  Function& setSchema(c10::FunctionSchema schema) override {
    schema_ = std::move(schema);
    return *this;
  }

  // Nothing to do for this class: the callable is provided at construction.
  void ensure_defined() override {}

  // ---- Execution -----------------------------------------------------------

  // Invokes the stored callable on `stack`.
  void run(Stack& stack) override {
    callable_(stack);
  }

  // Executes synchronously through run() and returns a Future that is already
  // completed. The Future's type and value both come from the element at the
  // front of `stack` after the run; that value is moved out of the stack.
  // The task launcher argument is ignored.
  c10::intrusive_ptr<c10::ivalue::Future> runAsync(
      Stack& stack,
      TaskLauncher /* unused */) override {
    run(stack);
    auto& result = stack.front();
    auto future = c10::make_intrusive<c10::ivalue::Future>(result.type());
    future->markCompleted(std::move(result));
    return future;
  }

  // Runs the callable right away; the optional and the Code callback are
  // ignored. Always returns false.
  // NOTE(review): the meaning of the return value is defined by
  // Function::call, which is not visible from this file.
  bool call(
      Stack& stack,
      c10::optional<size_t> /* unused */,
      c10::function_ref<void(const Code&)> /* unused */) override {
    run(stack);
    return false;
  }

  // Mobile variant of the overload above: runs the callable right away,
  // ignores the callback, and always returns false.
  bool call(
      Stack& stack,
      c10::function_ref<void(const mobile::Code&)> /* unused */) override {
    run(stack);
    return false;
  }

 private:
  c10::QualifiedName name_;
  std::function<void(Stack&)> callable_;
  c10::FunctionSchema schema_;
  std::string doc_string_;
};
} // namespace jit
} // namespace torch