forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
decomposition_registry_util.cpp
108 lines (99 loc) · 3.18 KB
/
decomposition_registry_util.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
/**
* @generated
* This is an auto-generated file. Please do not modify it by hand.
* To re-generate, please run:
* cd ~/pytorch && python torchgen/decompositions/gen_jit_decompositions.py
*/
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/runtime/decomposition_registry_util.h>
#include <torch/csrc/jit/runtime/operator.h>
namespace torch {
namespace jit {
// Serialized TorchScript source for the registered JIT decompositions
// (var_decomposition and var). This text is compiled by the TorchScript
// frontend at runtime, so the raw-string contents — including every
// character of whitespace — must stay byte-identical to the generator's
// output. Do NOT edit or reflow anything inside R"( ... )"; re-run
// gen_jit_decompositions.py instead (see the header banner above).
const std::string decomp_funcs =
R"(def var_decomposition(input: Tensor,
dim: Optional[List[int]]=None,
correction: Union[float, int, NoneType, bool]=None,
keepdim: bool=False) -> Tensor:
_0 = uninitialized(float)
if torch.__is__(dim, None):
dim0 = annotate(List[int], [])
else:
dim0 = unchecked_cast(List[int], dim)
if torch.eq(torch.len(dim0), 0):
n = torch.numel(input)
else:
n0 = 1
for _1 in range(torch.len(dim0)):
dim_i = dim0[_1]
n1 = torch.mul(n0, (torch.size(input))[dim_i])
n0 = n1
n = n0
mean = torch.mean(input, dim0, True)
sub = torch.sub(input, mean)
sq = torch.mul(sub, sub)
sum = torch.sum(sq, dim0, keepdim)
if torch.__is__(correction, None):
denom = float(torch.sub(n, 1))
else:
correction0 = unchecked_cast(Union[float, int, bool], correction)
_2 = isinstance(correction0, int)
if _2:
correction1 = unchecked_cast(int, correction0)
denom0 = float(torch.sub(n, correction1))
else:
correction2 = unchecked_cast(Union[float, bool], correction0)
_3 = isinstance(correction2, float)
if _3:
correction3 = unchecked_cast(float, correction2)
denom2 = torch.sub(float(n), correction3)
denom1 = denom2
else:
ops.prim.RaiseException("correction must be int or float", "builtins.RuntimeError")
denom1 = _0
denom0 = denom1
denom = denom0
_4 = torch.div(sum, ops.prim.max(0, denom))
return _4
def var(input: Tensor,
unbiased: bool=True) -> Tensor:
if unbiased:
_0 = 1
else:
_0 = 0
_1 = uninitialized(float)
n = torch.numel(input)
mean = torch.mean(input, annotate(List[int], []), True)
sub = torch.sub(input, mean)
sq = torch.mul(sub, sub)
sum = torch.sum(sq, annotate(List[int], []))
_2 = isinstance(_0, int)
if _2:
denom = float(torch.sub(n, _0))
else:
correction = unchecked_cast(Union[float, bool], _0)
_3 = isinstance(correction, float)
if _3:
correction0 = unchecked_cast(float, correction)
denom0 = torch.sub(float(n), correction0)
else:
ops.prim.RaiseException("correction must be int or float", "builtins.RuntimeError")
denom0 = _1
denom = denom0
_4 = torch.div(sum, ops.prim.max(0, denom))
return _4
)";
// Returns the serialized TorchScript source for all registered decompositions.
// The returned reference aliases the file-level constant `decomp_funcs`, so it
// stays valid for the lifetime of the program.
const std::string& GetSerializedDecompositions() {
return decomp_funcs;
}
// Returns the mapping from an aten operator schema to the name of the
// TorchScript function (inside GetSerializedDecompositions()) that
// decomposes it. The schema strings must match the operator registry
// byte-for-byte, so they are kept exactly as the generator emitted them.
const OperatorMap<std::string>& GetDecompositionMapping() {
// clang-format off
// Function-local static: constructed once on first call (thread-safe in
// C++11 and later) and lives for the remainder of the program.
static const OperatorMap<std::string> decomposition_mapping {
{"aten::var.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> Tensor", "var_decomposition"},
{"aten::var(Tensor self, bool unbiased=True) -> Tensor", "var"},
};
// clang-format on
return decomposition_mapping;
}
} // namespace jit
} // namespace torch