
Commit 64a04d2

ezyang authored and pytorchmergebot committed
Make sparse empty constructors specialize instead of fail on symbolic inputs (#129983)
Exercised in pytorch/pytorch#128327

Signed-off-by: Edward Z. Yang <[email protected]>
Pull Request resolved: pytorch/pytorch#129983
Approved by: https://github.com/anijain2305
1 parent 7350441 commit 64a04d2

File tree

4 files changed: +43 -2 lines changed


aten/src/ATen/native/native_functions.yaml (+4, -2)
@@ -2370,8 +2370,10 @@
   MPS: empty_mps
   Meta: empty_meta_symint
   MkldnnCPU: empty_mkldnn
-  SparseCPU, SparseCUDA, SparseMeta: empty_sparse
-  SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: empty_sparse_compressed
+  SparseCPU, SparseCUDA: empty_sparse
+  SparseMeta: empty_sparse_symint
+  SparseCsrCPU, SparseCsrCUDA: empty_sparse_compressed
+  SparseCsrMeta: empty_sparse_compressed_symint
   QuantizedCPU, QuantizedCUDA, QuantizedMeta: empty_unknown_quantized
   tags: core
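With this routing, the concrete sparse backends keep their IntArrayRef kernels, while the Meta entries go through new symint wrappers that guard symbolic sizes down to concrete integers instead of erroring. For illustration only (not part of the commit), a minimal C++ call that would exercise the new SparseMeta entry, assuming the generated ATen API at::empty_symint:

#include <ATen/ATen.h>

// Under tracing, the sizes here may be genuinely symbolic;
// empty_sparse_symint guards each SymInt to a concrete value
// rather than raising.
at::Tensor make_sparse_meta(c10::SymInt rows, c10::SymInt cols) {
  return at::empty_symint(
      {rows, cols},
      at::TensorOptions().device(at::kMeta).layout(at::kSparse));
}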

aten/src/ATen/native/sparse/SparseCsrTensor.cpp (+11)
@@ -644,6 +644,17 @@ SPARSE_COMPRESSED_TENSOR(csc, kSparseCsc)
 SPARSE_COMPRESSED_TENSOR(bsr, kSparseBsr)
 SPARSE_COMPRESSED_TENSOR(bsc, kSparseBsc)
 
+Tensor empty_sparse_compressed_symint(
+    SymIntArrayRef size,
+    std::optional<ScalarType> dtype,
+    std::optional<Layout> layout,
+    std::optional<Device> device,
+    std::optional<bool> pin_memory,
+    std::optional<MemoryFormat> optional_memory_format) {
+  // TODO: Don't specialize
+  return empty_sparse_compressed(C10_AS_INTARRAYREF_SLOW_ALLOC(size), dtype, layout, device, pin_memory, optional_memory_format);
+}
+
 // Warning: ideally, torch.empty(..., layout=<sparse compressed
 // format>) ought to be unsupported because it does not return a valid
 // sparse compressed tensor without initialization of compressed

aten/src/ATen/native/sparse/SparseTensor.cpp (+11)
@@ -226,6 +226,17 @@ SparseTensor new_with_dims_and_tensor_sparse_symint(
 /** Public creation API that dispatch to methods above **/
 
 /** Empty init **/
+Tensor empty_sparse_symint(
+    SymIntArrayRef size,
+    std::optional<ScalarType> dtype,
+    std::optional<Layout> layout,
+    std::optional<Device> device,
+    std::optional<bool> pin_memory,
+    std::optional<MemoryFormat> optional_memory_format) {
+  // TODO: Don't specialize
+  return empty_sparse(C10_AS_INTARRAYREF_SLOW_ALLOC(size), dtype, layout, device, pin_memory, optional_memory_format);
+}
+
 Tensor empty_sparse(
     IntArrayRef size,
     std::optional<ScalarType> dtype,
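Both wrappers follow the same convert-and-forward pattern: accept SymIntArrayRef, force every symbolic size to a concrete integer, then reuse the existing IntArrayRef kernel unchanged. As a sketch (the file and line values here are illustrative), the C10_AS_INTARRAYREF_SLOW_ALLOC call above expands roughly to the following, using the helper defined in the c10/core/SymIntArrayRef.h hunk below:

// Approximate expansion at the call site; __FILE__/__LINE__ record where
// the specialization guard was installed, which aids debugging recompiles.
c10::DimVector concrete = c10::asIntArrayRefSlowAlloc(
    size, "aten/src/ATen/native/sparse/SparseTensor.cpp", 237);
return empty_sparse(concrete, dtype, layout, device, pin_memory,
                    optional_memory_format);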

c10/core/SymIntArrayRef.h (+17)
@@ -2,8 +2,10 @@
 
 #include <c10/core/SymInt.h>
 #include <c10/util/ArrayRef.h>
+#include <c10/util/DimVector.h>
 #include <c10/util/Exception.h>
 #include <c10/util/Optional.h>
+#include <c10/util/irange.h>
 #include <cstdint>
 
 namespace c10 {
@@ -45,7 +47,22 @@ inline at::IntArrayRef asIntArrayRefSlow(
   return asIntArrayRefUnchecked(ar);
 }
 
+// Even slower than asIntArrayRefSlow, as it forces an allocation for a
+// destination int, BUT it is able to force specialization (it never errors)
+inline c10::DimVector asIntArrayRefSlowAlloc(
+    c10::SymIntArrayRef ar,
+    const char* file,
+    int64_t line) {
+  c10::DimVector res(ar.size(), 0);
+  for (const auto i : c10::irange(ar.size())) {
+    res[i] = ar[i].guard_int(file, line);
+  }
+  return res;
+}
+
 #define C10_AS_INTARRAYREF_SLOW(a) c10::asIntArrayRefSlow(a, __FILE__, __LINE__)
+#define C10_AS_INTARRAYREF_SLOW_ALLOC(a) \
+  c10::asIntArrayRefSlowAlloc(a, __FILE__, __LINE__)
 
 // Prefer using a more semantic constructor, like
 // fromIntArrayRefKnownNonNegative
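To see why the new helper "never errors", compare the two conversion strategies side by side. The following toy model is a self-contained sketch; ToySymInt and the two free functions are stand-ins, not the real c10 types, which carry the full symbolic-shape machinery:

#include <cstdint>
#include <stdexcept>
#include <vector>

// Stand-in for c10::SymInt: a trace-time hint plus a symbolic flag.
struct ToySymInt {
  int64_t hint;
  bool is_symbolic;

  // Analogue of SymInt::guard_int(file, line): always yields a concrete
  // value; the real implementation also installs a guard recording
  // (file, line) as the provenance of the specialization.
  int64_t guard_int(const char* /*file*/, int64_t /*line*/) const {
    return hint;
  }
};

// The pre-existing failure mode: symbolic input is simply an error.
int64_t as_int_or_throw(const ToySymInt& s) {
  if (s.is_symbolic) {
    throw std::runtime_error("cannot convert symbolic size to int");
  }
  return s.hint;
}

// Analogue of asIntArrayRefSlowAlloc: allocate a destination buffer and
// guard every element, so symbolic sizes specialize instead of failing.
std::vector<int64_t> as_ints_with_guards(const std::vector<ToySymInt>& sizes,
                                         const char* file, int64_t line) {
  std::vector<int64_t> res(sizes.size(), 0);
  for (size_t i = 0; i < sizes.size(); ++i) {
    res[i] = sizes[i].guard_int(file, line);
  }
  return res;
}

The allocation is what makes this path slower than asIntArrayRefSlow: the guarded values cannot alias the original SymInt storage, so they must be copied into a fresh DimVector.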
