forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathShape.cu
411 lines (357 loc) · 13.6 KB
/
Shape.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/MemoryOverlap.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/Dispatch.h>
#include <c10/core/MemoryFormat.h>
#include <c10/util/Optional.h>
#include <THC/THC.h>
namespace at {
namespace native {
// Maximum number of input tensors whose metadata is staged and handed to a
// single CatArrayBatchedCopy launch; parallel_cat walks the input list in
// chunks of this size.
constexpr int CAT_ARRAY_BATCH_SIZE = 1024;
// Length of the size/stride arrays carried in OutputTensorSizeStride, and the
// largest output rank for which cat_out_cuda takes the batched kernel path.
constexpr int CAT_ARRAY_MAX_INPUT_DIMS = 4;
namespace {
// Fills `grid` with the launch grid used by the batched cat kernel.
//
// grid.x is twice the multiprocessor count of the current device (the blocks
// along x cooperate on a single input tensor); grid.y is the number of
// tensors handled by the launch. Always returns true.
inline bool getCatGrid(ptrdiff_t nTensors, dim3& grid) {
  const int smCount =
      at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
  const long long blocksPerTensor = 2LL * smCount;
  const long long tensorCount = (long long) nTensors;
  grid = dim3(blocksPerTensor, tensorCount);
  return true;
}
// Translates an element index of one input tensor into an element offset in
// the output tensor, in the same spirit as other IndexToOffset helpers used
// for copying along a given dimension.
template <typename IndexType, int Dims>
struct CatArrIndexToOffset {
  // outputSize / outputStride: sizes and strides of the output tensor.
  // dimSize:     extent of the input tensor along the concat dimension.
  // concatDim:   the concat dimension, in the same ordering as outputSize.
  // linearIndex: offset into the input tensor. For a contiguous input this is
  //              its linear index; for a channels-last input it is the linear
  //              index of the permuted, contiguous view of that tensor.
  static inline __device__ IndexType compute(
      const IndexType outputSize[Dims],
      const IndexType outputStride[Dims],
      const IndexType dimSize,
      const unsigned int concatDim,
      IndexType linearIndex) {
    IndexType result = 0;
    IndexType remaining = linearIndex;

    // Peel coordinates off from the innermost dimension outwards. Along the
    // concat dimension the input's own extent is used instead of the output's.
    #pragma unroll
    for (int d = Dims - 1; d >= 1; --d) {
      const IndexType extent = (d == concatDim) ? dimSize : outputSize[d];
      const IndexType quotient = remaining / extent;
      const IndexType coord = remaining - extent * quotient;
      result += coord * outputStride[d];
      remaining = quotient;
    }

    // Whatever is left is the coordinate along dimension 0.
    return result + remaining * outputStride[0];
  }
};
// Per-input metadata consumed by CatArrayBatchedCopy. parallel_cat fills an
// array of these on the host and copies it to the device before each launch.
template <typename T, typename IndexType>
struct CatArrInputTensor {
// Base data pointer of the input tensor.
T* input;
// Sum of the concat-dimension sizes of all inputs that precede this one.
IndexType offset;
// Size of this input along the concat dimension.
IndexType dimSize;
// Total number of elements in this input.
IndexType nElements;
};
// Fixed-capacity copy of the output tensor's sizes and strides, passed to the
// kernel by value. Only the first nDims entries are written by parallel_cat;
// for channels-last formats the entries are stored in permuted order.
template<typename IndexType, unsigned int MaxDims>
struct OutputTensorSizeStride {
IndexType outputSize[MaxDims];
IndexType outputStride[MaxDims];
};
/**
 * Kernel used to concatenate gridDim.y tensors into an output tensor. Each
 * input (selected by blockIdx.y) is copied element by element with a
 * grid-stride loop over blockIdx.x / threadIdx.x.
 *
 * output: base pointer to the storage associated with the output tensor
 * inputs: GPU-allocated array of input metadata, one entry per input handled
 *         by this launch
 * os: the size/stride vectors for the output tensor
 * concatDim: dimension along which we are concatenating
 * dimStride: the stride of the output tensor at the concatDim
 *
 * The most important assumption made is that the input tensors are contiguous.
 */
template <typename T, typename IndexType, int Dims>
#ifdef __HIP_PLATFORM_HCC__
C10_LAUNCH_BOUNDS_1(512)
#endif
__global__ void CatArrayBatchedCopy(
    T* output,
    CatArrInputTensor<T, IndexType>* inputs,
    OutputTensorSizeStride<IndexType, CAT_ARRAY_MAX_INPUT_DIMS> os,
    const int concatDim,
    IndexType dimStride) {
  const IndexType first = blockIdx.x * blockDim.x + threadIdx.x;
  const IndexType count = inputs[blockIdx.y].nElements;

  // Threads with no work for this input leave before loading more metadata.
  if (first >= count) {
    return;
  }

  T* src = inputs[blockIdx.y].input;
  const IndexType catOffset = inputs[blockIdx.y].offset;
  const IndexType catExtent = inputs[blockIdx.y].dimSize;

  // Where this input's slice begins inside the output.
  const IndexType base = catOffset * dimStride;
  const IndexType step = gridDim.x * blockDim.x;

  for (IndexType idx = first; idx < count; idx += step) {
    const IndexType dst = CatArrIndexToOffset<IndexType, Dims>::compute(
        os.outputSize, os.outputStride, catExtent, concatDim, idx);
    output[base + dst] = src[idx];
  }
}
// Verifies that `first` and `second` can be concatenated along `dimension`:
// they must have the same number of dimensions and identical sizes in every
// dimension other than `dimension`.
//
// Raises (via TORCH_CHECK) when the ranks differ or when a size mismatch is
// found. The size-mismatch message names the concat dimension that is exempt
// from the check and, separately, the dimension in which the sizes differ.
void check_shape_except_dim(const Tensor &first, const Tensor &second,
                            int dimension)
{
  const int64_t first_dims = first.dim();
  const int64_t second_dims = second.dim();
  TORCH_CHECK(first_dims == second_dims,
      "Tensors must have same number of dimensions: got ", first_dims,
      " and ", second_dims);
  for (int64_t dim = 0; dim < first_dims; dim++) {
    // Sizes along the concat dimension are allowed to differ.
    if (dim == dimension) {
      continue;
    }
    const int64_t first_dim_size = at::native::size(first, dim);
    const int64_t second_dim_size = at::native::size(second, dim);
    TORCH_CHECK(first_dim_size == second_dim_size,
        "Sizes of tensors must match except in dimension ", dimension, ". Got ",
        static_cast<long long>(first_dim_size), " and ",
        static_cast<long long>(second_dim_size),
        " in dimension ", static_cast<long long>(dim));
  }
}
// Concatenates `inputs` into `out` along `dimension` using the batched copy
// kernel CatArrayBatchedCopy.
//
// out:           destination tensor; its sizes and strides are read here
// inputs:        tensors to concatenate
// dimension:     concat dimension, expressed in the tensors' own dim order
// nDims:         number of dimensions; selects the kernel instantiation
//                (1..4) and how many size/stride entries are filled in
// memory_format: Contiguous, ChannelsLast or ChannelsLast3d; anything else
//                raises via TORCH_CHECK
//
// Inputs are processed in batches of up to CAT_ARRAY_BATCH_SIZE tensors. For
// each batch the per-input metadata is written to a freshly allocated pinned
// host buffer, copied to a device buffer with non_blocking=true, and one
// kernel is launched on the current CUDA stream.
template <typename scalar_t>
void parallel_cat(Tensor &out, const TensorList &inputs, int64_t dimension,
                  int nDims, c10::MemoryFormat memory_format) {
  // First, let's set up our kernel parameters. We start with a raw pointer to
  // the storage for the output Tensor.
  scalar_t *data = out.data_ptr<scalar_t>();

  // Kernel Parameter: device buffer holding one metadata record per input of
  // the current batch.
  long tensorMetadataSize =
    sizeof(CatArrInputTensor<scalar_t, unsigned int>) * CAT_ARRAY_BATCH_SIZE;
  auto d_inputs_storage = at::empty(
    {tensorMetadataSize}, out.options().dtype(at::kByte));
  auto d_inputs = static_cast<CatArrInputTensor<scalar_t, unsigned int> *>(
    d_inputs_storage.data_ptr());

  OutputTensorSizeStride<unsigned int, CAT_ARRAY_MAX_INPUT_DIMS> param;

  // Next, let's initialize the size, stride arrays for the output Tensor.
  if (memory_format == c10::MemoryFormat::Contiguous) {
    for (int i = 0; i < nDims; ++i) {
      param.outputSize[i] = at::native::size(out, i);
      param.outputStride[i] = out.stride(i);
    }
  } else if (memory_format == c10::MemoryFormat::ChannelsLast || memory_format == c10::MemoryFormat::ChannelsLast3d) {
    // permute the semantics of dims from NCHW to NHWC so that the input
    // tensor is now contiguous
    param.outputSize[0] = at::native::size(out, 0);
    param.outputStride[0] = out.stride(0);
    for (int i = 1; i < nDims - 1; ++i) {
      param.outputSize[i] = at::native::size(out, i + 1);
      param.outputStride[i] = out.stride(i + 1);
    }
    param.outputSize[nDims - 1] = at::native::size(out, 1);
    param.outputStride[nDims - 1] = out.stride(1);
  } else {
    TORCH_CHECK(false, "unsupported memory format");
  }

  // Map the concat dimension into the (possibly permuted) ordering used by
  // `param`. This is done exactly once and kept in its own variable:
  // `dimension` must stay in the tensors' native ordering because it is still
  // used below to read each input's size, and remapping inside the batch loop
  // would apply the permutation again for every batch after the first.
  int64_t kernelConcatDim = dimension;
  if (memory_format != c10::MemoryFormat::Contiguous) {
    switch (dimension) {
    case 0:
      break;
    case 1:
      kernelConcatDim = nDims - dimension;
      break;
    default:
      kernelConcatDim = dimension - 1;
    }
  }

  at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();

  // Now we loop
  int batchCounter = 0;
  int64_t offset = 0;
  for (int i = 0; i < inputs.size() ; i += CAT_ARRAY_BATCH_SIZE) {
    // Re-allocate stackInputs every iteration to avoid read-after-write hazard
    {
      auto stackInputs_storage = at::empty({tensorMetadataSize},
          out.options().dtype(at::kByte).device(at::kCPU).pinned_memory(true));
      auto stackInputs =
        static_cast<CatArrInputTensor<scalar_t, unsigned int> *>(
          stackInputs_storage.data_ptr());
      for (batchCounter = 0;
           batchCounter < CAT_ARRAY_BATCH_SIZE &&
             (i+batchCounter) < inputs.size();
           ++batchCounter) {
        int64_t dimSize = at::native::size(inputs[i+batchCounter], dimension);

        stackInputs[batchCounter].input =
          inputs[i+batchCounter].data_ptr<scalar_t>();
        stackInputs[batchCounter].offset = offset;
        stackInputs[batchCounter].dimSize = dimSize;
        stackInputs[batchCounter].nElements = inputs[i+batchCounter].numel();

        // update offset
        offset += dimSize;
      }
      at::native::copy_(d_inputs_storage, stackInputs_storage,
                        /* non_blocking= */ true);
    }

    // Next, let's consider how we set our kernel launch parameters.
    // We borrow from THCApply, which the kernel's internal indexing
    // is based on.
    dim3 applyBlock = dim3(32*16);

    // Grid x is sized from the SM count and grid y is the number of tensors
    // in this batch (see getCatGrid).
    dim3 catGrid;
    getCatGrid(batchCounter, catGrid);

    // Template Declarations for dim = 1, 2, 3, 4
#define HANDLE_CASE(DIMS) \
    CatArrayBatchedCopy<scalar_t, unsigned int, DIMS><<<\
        catGrid, applyBlock, 0, stream.stream()>>>(\
            data, d_inputs, param, kernelConcatDim, param.outputStride[kernelConcatDim]);
    switch (nDims) {
      case 1:
        HANDLE_CASE(1);
        break;
      case 2:
        HANDLE_CASE(2);
        break;
      case 3:
        HANDLE_CASE(3);
        break;
      case 4:
        HANDLE_CASE(4);
        break;
    }
#undef HANDLE_CASE
    // Surface launch-configuration errors from the launch above.
    AT_CUDA_CHECK(cudaGetLastError());
  }
}
} // namespace
// Concatenates `inputs` along `dimension` into a newly allocated tensor.
//
// The result starts as an empty tensor created with the options of the first
// input and is then filled (and resized) by cat_out_cuda.
//
// Raises (via TORCH_CHECK) if `inputs` is empty, since the first input is
// needed to choose the output's options.
Tensor cat_cuda(TensorList inputs, int64_t dimension) {
  TORCH_CHECK(inputs.size() > 0, "invalid number of inputs ", inputs.size());
  Tensor out = at::empty({0}, inputs.front().options());
  cat_out_cuda(out, inputs, dimension);
  return out;
}
// Chooses the memory format for the result of concatenating `inputs`.
//
// If every input suggests the same memory format, that format is returned.
// As soon as two inputs disagree, Contiguous is returned.
inline c10::MemoryFormat compute_output_memory_format(const TensorList &inputs) {
  c10::optional<c10::MemoryFormat> chosen = c10::nullopt;
  for (auto &tensor : inputs) {
    const auto suggested = tensor.suggest_memory_format();
    if (!chosen.has_value()) {
      // The first tensor seeds the candidate format.
      chosen = suggested;
    } else if (chosen.value() != suggested) {
      // Any disagreement between inputs falls back to contiguous.
      return c10::MemoryFormat::Contiguous;
    }
  }
  return chosen.value();
}
// Concatenates `inputs` along `dimension` into `out` and returns `out`.
//
// Checks performed (each raises via TORCH_CHECK on failure):
//   - `inputs` is non-empty
//   - no input overlaps (partially or fully) with `out`
//   - all inputs share the dtype of the first input
//   - `dimension` is non-negative
//   - all inputs are on the same device
//   - all non-skipped inputs agree in rank and in every size except along
//     `dimension`
//
// `out` is resized to the result shape using the memory format chosen by
// compute_output_memory_format. If every input is skipped (see should_skip)
// `out` is returned without being resized.
//
// The copy uses the batched kernel (parallel_cat) when the conditions listed
// further down hold; otherwise each input is copied into a narrowed view of
// `out`.
Tensor& cat_out_cuda(Tensor& out, TensorList inputs, int64_t dimension) {
  // Must run before anything indexes into `inputs`: the dtype check below
  // reads inputs[0].
  TORCH_CHECK(inputs.size() > 0, "invalid number of inputs ", inputs.size());

  // previously, size [0] tensors were the only possible empty tensors; thus, it
  // wasn't possible to cat empty tensors unless all the other tensors were
  // 1-dimensional, so we allowed these tensors to be "skipped".  We maintain
  // this behavior for backwards compatibility, but only for this specific size
  // (i.e. other empty sizes are not skipped).
  // FIXME: warn if this is the case
  auto should_skip = [](const Tensor &t) {
    return t.dim() == 1 && at::native::size(t, 0) == 0;
  };

  bool hasSkippedInput = false;
  const Tensor *notSkippedTensor = nullptr;  // non-owning reference
  int nDims = 0;

  // Inputs cannot alias the output tensor
  for (int i = 0; i < inputs.size(); i++) {
    auto lap = at::get_overlap_status(out, inputs[i]);
    TORCH_CHECK(lap != at::MemOverlapStatus::PARTIAL &&
        lap != at::MemOverlapStatus::FULL,
        "unsupported operation: the input tensors cannot refer to any "
        "of the output memory locations. Found overlap in input "
        "tensor ", i);
  }

  // Dtypes should be the same
  const Tensor& first_in_cat = inputs[0];
  for (int64_t i = 1; i < inputs.size(); i++) {
    TORCH_CHECK(first_in_cat.dtype() == inputs[i].dtype(),
              "Expected object of scalar type ", first_in_cat.dtype(),
              " but got scalar type ", inputs[i].dtype(),
              " for sequence element ", i, ".");
  }

  // Remember the last non-skipped input; it serves as the shape and device
  // reference for the checks below.
  for (int i = 0; i < inputs.size(); i++)
  {
    if (should_skip(inputs[i])) {
      hasSkippedInput = true;
      continue;
    }
    nDims = inputs[i].dim();
    notSkippedTensor = &inputs[i];
  }

  // If all inputs are empty tensors, return an empty tensor
  if (notSkippedTensor == nullptr) {
    return out;
  }

  TORCH_CHECK(dimension >= 0, "invalid dimension ", dimension);

  for (const Tensor& t: inputs) {
    TORCH_CHECK(t.device() == notSkippedTensor->device(),
                "All input tensors must be on the same device. Received ",
                t.device(), " and ", notSkippedTensor->device());
  }

  c10::MemoryFormat memory_format = compute_output_memory_format(inputs);

  std::vector<int64_t> size(notSkippedTensor->sizes().vec());

  // Compute size of the result in the cat dimension
  int64_t cat_dim_size = 0;
  for (int i = 0; i < inputs.size(); i++) {
    const Tensor &tensor = inputs[i];
    if (should_skip(tensor)) {
      continue;
    }
    check_shape_except_dim(*notSkippedTensor, tensor, dimension);
    cat_dim_size += at::native::size(tensor, dimension);
  }

  // Compute the size of the result
  size[dimension] = cat_dim_size;
  out.resize_(size, memory_format);
  if (out.numel() == 0) {
    return out;
  }

  // We parallelize the copy if all 6 conditions pass:
  //
  // 1. There is more than one input tensor
  // 2. No empty inputs
  // 3. The out tensor is 32-bit indexable
  // 4. The number of dimensions is <= 4
  // 5. All input tensors are contiguous (output tensor may be non-contig)
  // 6. All input tensors can use 32-bit indexing
  const bool all32BitIndexable = std::all_of(inputs.begin(), inputs.end(),
    [] (const Tensor& t) {
      return at::cuda::detail::canUse32BitIndexMath(t);
    });
  const bool allContiguous = std::all_of(inputs.begin(), inputs.end(),
    [=](const Tensor& t) {
      return !t.defined() || t.is_contiguous(memory_format);
    });

  if (inputs.size() > 1 &&
      !hasSkippedInput &&
      out.dim() <= CAT_ARRAY_MAX_INPUT_DIMS &&
      at::cuda::detail::canUse32BitIndexMath(out) &&
      allContiguous &&
      all32BitIndexable) {
    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
        at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16,
        out.scalar_type(), "cat_cuda", [&]() {
      parallel_cat<scalar_t>(out, inputs, dimension, nDims, memory_format);
    });
  } else {
    // Fallback: copy each non-skipped input into its slice of `out`.
    int64_t offset = 0;
    for (int j = 0; j < inputs.size(); j++)
    {
      if (should_skip(inputs[j])) continue;
      int64_t dimSize = at::native::size(inputs[j], dimension);
      Tensor nt = at::narrow(out, dimension, offset, dimSize);
      copy_(nt, inputs[j]);
      offset += dimSize;
    }
  }

  return out;
}
} // namespace native
} // namespace at