forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
LegacyVmapTransforms.h
182 lines (158 loc) · 7.55 KB
/
LegacyVmapTransforms.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
#pragma once
#include <ATen/LegacyBatchedTensorImpl.h>
#include <ATen/core/IListRef.h>

#include <utility>
namespace at {
// This file contains abstractions used for transforming *logical* vmap
// arguments into *physical* arguments. (Keep reading for definitions of these
// terms).
// NOTE: [Logical vs physical args]
// Consider the following vmap.
// vmap(vmap(func, in_dims=(2,)), in_dims=(0,))(torch.ones(2, 3, 4))
// This would produce a BatchedTensor wrapping a Tensor of size [2, 3, 4],
// with batch dims 0 and 2:
// BatchedTensor(ones(2, 3, 4), bdims=[(lvl=1,dim=0),(lvl=2,dim=2)])
//
// We say the *logical* view of the tensor has size [3] -- tensors inside
// `func` appear to have size [3].
// However, the *physical* underlying tensor (the one passed to vmap) has size
// [2, 3, 4].
//
// This notion of logical vs physical also extends to non-tensor arguments.
// Consider the previous tensor; let's assume the user called
// `torch.sum(tensor, dim=0)` inside of `func`. Then the logical
// dimension they are reducing over is dim 0 but the physical dim is dim 1
// (the first non-batch dimension)
// Forward declared; see NOTE: [What is a VmapPhysicalView?]
// The declaration is needed here because VmapPhysicalViewVec (just below)
// names the type before its full definition appears later in this file.
struct VmapPhysicalView;
// Most PyTorch operators take 4 or fewer inputs.
// Used as the SmallVector size parameter of VmapPhysicalViewVec.
constexpr int64_t kVmapTransformStaticInputSize = 4;
// Vector of VmapPhysicalViews; this is the return type of the tensor-list
// overloads of `logicalToPhysical` declared in this file.
using VmapPhysicalViewVec =
SmallVector<VmapPhysicalView, kVmapTransformStaticInputSize>;
// PyTorch generally advertises good performance for <= 5 dims
// (see ATen/core/DimVector.h). We add a few extra dims (~3) for vmap
// dimensions to get 8. Adjust this number as necessary.
// Used as the SmallVector size parameter of VmapDimVector.
constexpr int64_t kVmapStaticDimVecSize = 8;
// Vector of int64_t dimension indices or sizes; this is the return type of
// VmapPhysicalView::getPhysicalDims and VmapPhysicalView::getPhysicalShape.
using VmapDimVector = SmallVector<int64_t, kVmapStaticDimVecSize>;
// NOTE: [What is an VmapTransform?]
// A *VmapTransform* converts logical views of tensors to physical views.
//
// Batching rules use VmapTransforms to convert logical arguments to
// physical arguments, then call one or more at:: operators that handle the
// physical arguments, and then convert the physical result back to a logical
// argument.
// VmapTransform for operators that take tensors with multiple batch dims.
// Given one or more logical views on Tensors, `logicalToPhysical`
// permutes all of the batch dims to the front of the tensor, aligns
// and expands the batch dims to match each other (according to their `level`),
// and returns a VmapPhysicalView on the tensor(s).
//
// Both overloads are static, so no instance of this struct is required to
// call them.
struct TORCH_API MultiBatchVmapTransform {
// Single-tensor overload: returns one VmapPhysicalView for `logical_tensor`.
static VmapPhysicalView logicalToPhysical(const Tensor& logical_tensor);
// Tensor-list overload: returns a VmapPhysicalViewVec for `logical_tensors`.
// NOTE(review): presumably one view per input, in input order -- the
// definition is not in this header; confirm against the implementation.
static VmapPhysicalViewVec logicalToPhysical(ITensorListRef logical_tensors);
};
// VmapTransform for operators that broadcast all inputs.
// Given some logical views on Tensors, `logicalToPhysical`:
// - permutes all of the batch dims to the front of the tensors
// - aligns all the batch dims to the collective levels of all of the tensors.
// If a tensor does not have a batch dim for a vmap level, then it receives
// a size-one dimension for said level.
// - aligns the non-batch dims to have the same dimensionality, adding extra
// size-1 dimensions in between the batch dimensions and the non-batch
// dimensions so that the batch dimensions are lined up from the right.
//
// For example: given inputs of size (B, 2) and (B, 3, 2) where B is the batch
// dimension, BroadcastingVmapTransform returns VmapPhysicalViews that wrap
// tensors of size (B, 1, 2) and (B, 3, 2).
//
// Given inputs of size (B, 2) and (2,), BroadcastingVmapTransform returns
// VmapPhysicalViews wrapping tensors of size (B, 2) and (1, 2). We don't
// actually *need* to return a tensor of size (1, 2) for the second tensor
// because the broadcasting operation takes care of that for us, but we do
// it anyways to keep things simple.
struct TORCH_API BroadcastingVmapTransform {
// Static entry point: takes the logical tensors as a TensorList and returns a
// VmapPhysicalViewVec of the aligned physical views described above.
// NOTE(review): presumably one view per input, in input order -- the
// definition is not in this header; confirm against the implementation.
static VmapPhysicalViewVec logicalToPhysical(TensorList logical_tensors);
};
// Forward declared because VmapPhysicalView::getPhysicalToLogicalMap() returns
// this type before its full definition appears later in this file. If you're
// reading this file head to toe, don't worry about it yet.
struct VmapPhysicalToLogicalMap;
// NOTE: [What is a VmapPhysicalView?]
// VmapPhysicalView represents a physical view on a Tensor.
//
// One can use it to further convert logical dimension indices, logical shapes,
// and more to their physical variants, or convert a new (physical) tensor into
// a logical BatchedTensor. (TODO(rzou): some of these are not yet implemented).
//
// VmapPhysicalView stores a physical tensor with all of its batch dimensions at
// the front and some levels that correspond to said batch dimensions.
//
// The levels bitset specifies which vmap levels correspond to the batch
// dimensions at the front of the tensor. In particular, the number of set bits
// corresponds to the number of batch dimensions on `tensor` and the rightmost
// bit of `levels` specifies the maximum number of nested vmaps we are in at
// this point in time.
// For example, given:
// physical_view = VmapPhysicalView(tensor=ones(2, 3, 4, 5, 6), levels={1, 3})
//
// Rightmost bit of `levels` is 3 indicating the number of nested vmaps less
// than or equal to 3.
// bitset: 010100
// ^
// |
// levels: 012345
struct TORCH_API VmapPhysicalView {
  // Takes ownership of `tensor` (the physical tensor) and records the vmap
  // `levels` associated with its leading batch dimensions.
  //
  // `tensor` is a named rvalue reference, i.e. an lvalue inside this
  // constructor, so it must be explicitly moved; otherwise the member would
  // be copy-constructed. The assertion inspects `tensor_` because the
  // parameter is in a moved-from state by the time the body runs.
  //
  // The physical tensor must not itself be a BatchedTensor.
  VmapPhysicalView(Tensor&& tensor, std::bitset<kVmapNumLevels> levels)
      : levels_(levels), tensor_(std::move(tensor)) {
    TORCH_INTERNAL_ASSERT(!isBatchedTensor(tensor_));
  }

  // Mutable access to the wrapped physical tensor.
  Tensor& tensor() {
    return tensor_;
  }
  // Read-only access to the wrapped physical tensor.
  const Tensor& tensor() const {
    return tensor_;
  }

  // Maps logical dim indices to physical dim indices. Also does dim wrapping.
  //
  // For example, given:
  //   physical_view = VmapPhysicalView(tensor=ones(2, 3, 4, 5), levels={1, 3})
  //
  // Then physical_view.getPhysicalDims({0, 1}) returns {2, 3}.
  // This is because the size of levels tells us that the first two dimensions
  // of `tensor_` are batch dimensions, so a logical dim of `n` is actually
  // a physical dim of `n + 2`.
  VmapDimVector getPhysicalDims(OptionalIntArrayRef logical_dims) const;
  // Single-dimension counterpart of getPhysicalDims.
  int64_t getPhysicalDim(int64_t logical_dim) const;

  // Returns a VmapPhysicalToLogicalMap object. This can be used for
  // mapping a physical tensor to a new logical tensor (BatchedTensor).
  VmapPhysicalToLogicalMap getPhysicalToLogicalMap() const;

  // Maps a logical shape to a physical shape by pre-pending the batch
  // sizes to the logical shape.
  VmapDimVector getPhysicalShape(IntArrayRef logical_shape) const;

  // Number of batch dimensions at the front of `tensor_`.
  // NOTE(review): presumably the number of set bits in `levels_`; the
  // definition is not in this header -- confirm against the implementation.
  int64_t numBatchDims() const;

 private:
  // NOTE(review): presumably the number of non-batch dimensions of
  // `tensor_`; the definition is not in this header -- confirm.
  int64_t numLogicalDims() const;

  // Declared in this order on purpose: it matches the order of the
  // constructor's member-initializer list.
  std::bitset<kVmapNumLevels> levels_;
  Tensor tensor_;
};
// Convenience struct used for mapping a physical tensor (a non-BatchedTensor)
// to a logical one (BatchedTensor). It holds some levels that are used to do
// the mapping and assumes that the batch dimensions in the physical tensor all
// occur at the front of the tensor.
struct TORCH_API VmapPhysicalToLogicalMap {
VmapPhysicalToLogicalMap(std::bitset<kVmapNumLevels> levels)
: levels_(levels) {}
// Maps a physical tensor to a new logical tensor (BatchedTensor).
// Assumes that all of the "batch dimensions" are at the front
// of the physical tensor. For example, given:
// - x = rank-4 Tensor with size 2, 3, 5, 7
// - levels = (2, 4)
// Returns:
// - BatchedTensor(x, bdims=[(dim=0,lvl=2), (dim=1, lvl=4)])
Tensor apply(const Tensor& physical_tensor) const;
// Given a vector of physical tensors,
// 1. maps each tensor to a new logical tensor. Assumes that all of the
// "batch dimensions" are at the front of the physical tensors.
// 2. stores the new logical tensors back into the passed-in vector. This is
// to avoid additional dynamic allocations.
void applyInplace(std::vector<Tensor>& physical_tensors) const;
std::bitset<kVmapNumLevels> levels_;
};
} // namespace at