Skip to content

Commit 17df15a

Browse files
committed
PR1: Fixing allocation logic
1. Refactor output buffer allocation to use the allocation domain instead of the logical domain. 2. Fix the projection from the allocation domain to the logical domain on the special path where projection is not possible: we now compute the correct extent instead of returning the allocation buffer as-is. This allows the layout op to return a tensor with the correct logical size, while still allocating a buffer large enough to accommodate the padding requirement.
1 parent bf85c0b commit 17df15a

File tree

2 files changed

+40
-9
lines changed

2 files changed

+40
-9
lines changed

csrc/runtime/allocations.cpp

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -273,13 +273,29 @@ KernelArgumentHolder allocateOutputs(
273273
for (auto out_idx : arange(output_infos.size())) {
274274
auto out_info = output_infos.at(out_idx);
275275
if (output_alias_to_input_map.at(out_idx) == -1) {
276-
auto alloc_tensor = at::native::empty_strided_cuda(
277-
out_info.shape_info.logical_sizes,
278-
out_info.shape_info.logical_strides,
279-
out_info.type,
280-
c10::nullopt,
281-
device,
282-
c10::nullopt);
276+
at::Tensor alloc_tensor;
277+
if (!out_info.shape_info.allocation_sizes.empty()) {
278+
// allocate based on allocation size & stride and restride with logical
279+
// size & stride afterwards.
280+
alloc_tensor = at::native::empty_strided_cuda(
281+
out_info.shape_info.allocation_sizes,
282+
out_info.shape_info.allocation_strides,
283+
out_info.type,
284+
c10::nullopt,
285+
device,
286+
c10::nullopt);
287+
alloc_tensor = alloc_tensor.as_strided_(
288+
out_info.shape_info.logical_sizes,
289+
out_info.shape_info.logical_strides);
290+
} else {
291+
alloc_tensor = at::native::empty_strided_cuda(
292+
out_info.shape_info.logical_sizes,
293+
out_info.shape_info.logical_strides,
294+
out_info.type,
295+
c10::nullopt,
296+
device,
297+
c10::nullopt);
298+
}
283299
if (shouldFillAllocationWithNan()) {
284300
fillTensorWithNan(alloc_tensor);
285301
}
@@ -741,13 +757,22 @@ at::Tensor transformFromAllocationToLogical(
741757
.run(logical, alloc);
742758
NVF_ERROR(frontier.size() == logical.size());
743759

744-
// give up on producing right shape/stride when allocation domain has
760+
// give up on producing right stride when allocation domain has
745761
// transformation that cannot be represented via permutation. This is
746762
// currently used by PreprocessGroupedMatmulInputSf, where output is padded.
747763
std::set<IterDomain*> frontier_set(frontier.begin(), frontier.end());
748764
std::set<IterDomain*> logical_set(logical.begin(), logical.end());
749765
if (frontier_set != logical_set) {
750-
return tensor;
766+
std::vector<int64_t> logical_sizes(logical.size(), 0);
767+
std::vector<int64_t> logical_strides(logical.size(), 0);
768+
int64_t cur_stride = 1;
769+
for (const auto&& [i, id] : enumerate(logical) | std::views::reverse) {
770+
int64_t cur_size = ee.evaluate(id->extent()).as<int64_t>();
771+
logical_sizes[i] = cur_size;
772+
logical_strides[i] = cur_stride;
773+
cur_stride *= cur_size;
774+
}
775+
return tensor.as_strided(logical_sizes, logical_strides);
751776
}
752777

753778
// Now that all affine transformations are handled, and frontiers should

tests/cpp/test_layout_op.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,14 @@ bool validateGroupedLayout(
2727
NVF_ERROR(BlockScalingFactorLayout::Block128x4 == layout);
2828
int num_group = expert_offsets.size(0) - 1;
2929

30+
// validate output logical shape
31+
EXPECT_EQ(out.sizes(), ref.sizes());
32+
3033
// take length of reference for un-padded k size.
3134
int k = ref.size(1);
35+
int padded_k = (k + 4 - 1) / 4 * 4;
36+
int padded_m = sf_offsets[num_group].item().to<int>();
37+
out.as_strided_({padded_m, padded_k}, {padded_k, 1});
3238

3339
// We validate each group individually
3440
for (int i = 0; i < num_group; ++i) {

0 commit comments

Comments
 (0)