FunctionalStorageImpl.cpp
#include <ATen/FunctionalStorageImpl.h>

#include <ATen/EmptyTensor.h>
#include <ATen/FunctionalTensorWrapper.h>
#include <ATen/core/LegacyTypeDispatch.h>
#include <c10/util/Exception.h>
#include <vector>

namespace at {
namespace functionalization {
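
// Returns a copy of this ViewMeta with its output index replaced by out_idx.
// (Multi-output view ops reuse the same forward/reverse functions across
// outputs; only the output index differs.)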
ViewMeta ViewMeta::to_out_idx(int64_t out_idx) {
  if (out_idx == this->out_index) return *this;
  return ViewMeta(forward_fn, reverse_fn, out_idx);
}

// Note [Functionalization: Alias Removal Part 2]
// See Note [Functionalization: Alias Removal] for more details.
// This function applies a single update from one of the views to the StorageImpl.
// We start out with <original_base> and <mutated_view>, and our goal is to end up with <mutated_base>.
// Consider this program:
//
// base = ...
// a = base.view1()
// b = a.view2()
// c = b.view3()
// c.add_(3)
//
// Then the functionalization pass will queue an update as follows:
//
// update.new_val = c # the updated value of c
// update.view_metas = [view1_meta, view2_meta, view3_meta]
//
// Syncing any of a, b or c will eventually call apply_update() on the storage, and the following will run:
//
// tmp_values = [base, a, b] # NB: c is not necessary
// t = update.new_val
// t = view3_inverse(b, t, 0) # 0 is output index, these are all single output views so it's 0
// t = view2_inverse(a, t, 0)
// t = view1_inverse(base, t, 0) # t now represents the updated storage.
// storage.base_ = t
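//
// As a concrete example: if view2() above were select(0, 1), then
// view2_inverse(a, t, 0) conceptually behaves like
// `a.clone().select(0, 1).copy_(t)` - it scatters the mutated slice back into
// a tensor with a's original shape.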
static const Tensor apply_update(const FunctionalStorageImpl::Update& update, const Tensor& base) {
  at::Tensor t = update.new_val;
  TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t));
  if (update.view_metas.empty()) return t;

  std::vector<at::Tensor> tmp_values({base});
  tmp_values.reserve(update.view_metas.size());
  for (size_t i = 0; i < update.view_metas.size() - 1; ++i) {
    at::Tensor next_view = update.view_metas[i].forward_fn(tmp_values.back(), update.view_metas[i].out_index);
    // NB: We only actually need tmp_values for ops like select/slice/diagonal/squeeze/as_strided.
    // All of these ops require additional information to recover the sizes of the original tensor.
    // If needed, we could apply that optimization and only bother computing tmp_values
    // for the view ops that actually require them.
    tmp_values.push_back(std::move(next_view));
  }
  for (int64_t i = static_cast<int64_t>(update.view_metas.size()) - 1; i >= 0; --i) {
    int64_t out_idx = update.view_metas[i].out_index;
    // Each view inverse is implemented in ViewInverses.cpp.
    t = update.view_metas[i].reverse_fn(tmp_values[i], t, out_idx);
  }
  TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t));
  return t;
}
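
// Computes the number of bytes that the (data-less) functional storage should
// report for a given base tensor.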
static c10::SymInt get_nbytes(const Tensor& value) {
  // The functionalization story when wrapping tensors that don't have storage
  // is a bit wonky, but fortunately for some models (e.g., dlrm) we never
  // actually perform mutations on these tensors, so you never really get
  // called out on it. For now, functionalization still creates "storages"
  // for these tensors (which is wrong), but we don't give them any space.
  // A more proper fix would be to have a SparseFunctionalTensorWrapper that
  // models sparse correctly.
  if (value.is_sparse()) {
    return 0;
  }
  if (value.unsafeGetTensorImpl()->has_symbolic_sizes_strides()) {
    // Today, the two implementations of SymInt are in Python (proxy tensor)
    // and lazy tensor (LTC/XLA).
    // LTC hasn't implemented SymInt support yet though.
    // Once it does, we should remove this check.
    if (value.key_set().has(c10::DispatchKey::Python)) {
      return value.storage().sym_nbytes();
    }
  }
  // XLA storage objects also do not properly track nbytes.
  return at::detail::computeStorageNbytes(value.sizes(), value.strides(), value.dtype().itemsize(), value.storage_offset());
}
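
// A FunctionalStorageImpl is a data-less "shell" storage (null DataPtr, meta
// allocator): it exists so that all aliases created by the functionalization
// pass can share one base_ tensor plus a queue of pending updates.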
FunctionalStorageImpl::FunctionalStorageImpl(const Tensor& base)
  : c10::StorageImpl(
      c10::StorageImpl::use_byte_size_t(),
      get_nbytes(base),
      DataPtr{nullptr, base.device()},
      GetAllocator(kMeta),
      /*resizeable=*/true
    ),
    base_(base)
{
  TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(base_));
}
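
// Queues a pending mutation (updated_val, as seen through the given view metas)
// against this storage, and bumps generation_ so that aliases can later detect
// that they are out of date.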
void FunctionalStorageImpl::add_update(const Tensor& updated_val, const std::vector<ViewMeta>& metas) {
  TORCH_CHECK(!frozen_, "cannot mutate tensors with frozen storage");
  updates_.push_back({updated_val, metas});
  generation_++;
}
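
// Replays every queued update onto base_ (see apply_update above) and clears the
// queue. Returns true if there was at least one update to apply.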
bool FunctionalStorageImpl::apply_updates() {
  // N.B.: none of the tensors used in this function should be FunctionalTensorWrappers at this point.
  // The only reason we currently need the TLS exclude guard here is because of functorch's DynamicLayer stack.
  // It adds the Functionalize key into TLS before redispatching to the functionalization kernels,
  // which means that we need to explicitly exclude it here before doing any other work underneath the pass.
  at::AutoDispatchSkipFunctionalize guard;
  bool any_updates = !updates_.empty();
  for (auto& update_data : updates_) {
    base_ = apply_update(update_data, base_);
  }
  updates_.clear();
  return any_updates;
}

} // namespace functionalization
} // namespace at