Changes from all commits (34 commits, all by zasdfgbnm):
d828c9f  Use meta device tensor to infer contiguity for expr-eval segments (Jan 7, 2026)
074e947  Merge branch 'resetContiguityFromTensor' into meta-eval (Jan 13, 2026)
4ad3785  Merge branch 'resetContiguityFromTensor' into meta-eval (Jan 14, 2026)
c0b50c3  Refactor output shape inference functions and update segment output h… (Jan 14, 2026)
7ffdaa3  save (Jan 14, 2026)
7a5b0dc  save (Jan 14, 2026)
d92e5ee  save (Jan 14, 2026)
074209b  save (Jan 14, 2026)
68d52fb  save (Jan 14, 2026)
fb0572a  save (Jan 14, 2026)
08e5b6b  save (Jan 14, 2026)
784ce68  enable (Jan 14, 2026)
9c46183  Merge branch 'meta-eval' of github.com:NVIDIA/Fuser into meta-eval (Jan 14, 2026)
93a4012  save (Jan 14, 2026)
e5d4d67  fix (Jan 14, 2026)
38defa9  save (Jan 14, 2026)
f50e52f  save (Jan 14, 2026)
5fd7496  fix (Jan 14, 2026)
40782db  fix (Jan 14, 2026)
53d70fe  save (Jan 14, 2026)
87b00e8  save (Jan 14, 2026)
4afe5b1  fix (Jan 14, 2026)
447de33  fix (Jan 14, 2026)
dd41424  Merge branch 'meta-eval' of github.com:NVIDIA/Fuser into meta-eval (Jan 14, 2026)
831c777  save (Jan 14, 2026)
d246fc6  save (Jan 14, 2026)
01a011c  save (Jan 15, 2026)
03fa1f5  save (Jan 15, 2026)
0a2878b  Apply suggestion from @wujingyue (Jan 15, 2026)
72c66dd  Merge branch 'resetContiguityFromTensor' into meta-eval (Jan 15, 2026)
b151b50  save (Jan 15, 2026)
4249f8b  save (Jan 15, 2026)
40b5411  fix (Jan 16, 2026)
c81f895  Merge branch 'resetContiguityFromTensor' into meta-eval (Jan 16, 2026)
20 changes: 12 additions & 8 deletions csrc/ir/composite_nodes.cpp
@@ -88,16 +88,20 @@ std::vector<PolymorphicValue> MatmulOp::evaluate(
     matmul_out = matmul_out.unsqueeze(rfactor_did_idx);
   }
 
-  const auto& [sizes, strides] = inferShapeAndContiguousStrides(out(), ee);
-  auto meta_out = at::detail::empty_strided_meta(sizes, strides, a.dtype());
+  // Without InferContiguity, we mistakenly assume the output is contiguous.
+  if (!isOptionEnabled(EnableOption::InferContiguity)) {
+    const auto& [sizes, strides] = inferShapeAndContiguousStrides(out(), ee);
+    auto meta_out = at::detail::empty_strided_meta(sizes, strides, a.dtype());
 
-  if (meta_out.is_contiguous()) {
-    return {matmul_out};
-  }
+    if (meta_out.is_contiguous()) {
+      return {matmul_out};
+    }
 
-  auto strided_matmul_out = at::empty_strided(sizes, strides, a.options());
-  strided_matmul_out = strided_matmul_out.copy_(matmul_out);
-  return {strided_matmul_out};
+    auto strided_matmul_out = at::empty_strided(sizes, strides, a.options());
+    strided_matmul_out = strided_matmul_out.copy_(matmul_out);
+    return {strided_matmul_out};
+  }
+  return {matmul_out};
 }
 
 LinearOp::LinearOp(
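The meta-tensor trick above deserves a brief illustration. Below is a minimal standalone sketch (not nvfuser code; the helper name layoutIsContiguous is made up) of the same probe MatmulOp::evaluate performs: build a tensor descriptor on the meta device from candidate sizes/strides, which allocates no storage, and ask whether that layout is contiguous.

#include <ATen/ATen.h>

// Returns true if the given sizes/strides describe a contiguous layout.
// Meta tensors carry only sizes, strides, and dtype, so this is cheap.
bool layoutIsContiguous(
    at::IntArrayRef sizes,
    at::IntArrayRef strides,
    at::ScalarType dtype) {
  at::Tensor meta = at::empty_strided(
      sizes, strides, at::TensorOptions().device(at::kMeta).dtype(dtype));
  return meta.is_contiguous();
}

If the probe reports a non-contiguous layout, the evaluate path above falls back to at::empty_strided plus copy_ to materialize the requested strides.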
1 change: 1 addition & 0 deletions csrc/options.cpp
@@ -185,6 +185,7 @@ const std::unordered_map<std::string, EnableOption>& getEnableOptions() {
       {"p2p_protocol", EnableOption::P2pProtocol},
       {"multicast_protocol", EnableOption::MulticastProtocol},
       {"parallel_serde", EnableOption::ParallelSerde},
+      {"infer_contiguity", EnableOption::InferContiguity},
   };
   return available_options;
 }
1 change: 1 addition & 0 deletions csrc/options.h
@@ -130,6 +130,7 @@ enum class EnableOption {
   MulticastProtocol, //! Prescribe multicast protocol:
                      //! memcpy|multimem|batch_memcpy
   ParallelSerde, //! Enable deserializing FusionExecutorCache in parallel
+  InferContiguity, //! Enable contiguity inference
   EndOfOption //! Placeholder for counting the number of elements
 };
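For context on how this option is toggled: a hedged sketch based on the EnableOptionsGuard usage in the test changes later in this diff (the function name is hypothetical). The guard restores the previous option set when it goes out of scope, and runtime code queries the option through isOptionEnabled, as in the composite_nodes.cpp hunk above. Given the "infer_contiguity" key registered in options.cpp, the option is presumably also reachable through nvfuser's usual NVFUSER_ENABLE environment variable (an assumption, not verified here).

#include <options.h>

// Sketch: scoped toggling of the new option, mirroring the test fixtures in
// this PR. EnableOptionsGuard restores the previous options on destruction.
void runWithContiguityInference() {
  EnableOptionsGuard opt_guard;
  EnableOptionsGuard::getCurOptions().set(EnableOption::InferContiguity);

  // Anywhere in the runtime, the option is queried like this:
  if (isOptionEnabled(EnableOption::InferContiguity)) {
    // ... take the contiguity-inference path ...
  }
}  // previous options restored here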
4 changes: 2 additions & 2 deletions csrc/runtime/allocations.cpp
@@ -21,12 +21,12 @@
 
 namespace nvfuser {
 
-KernelArgumentHolder inferOutputShapeAndContiguousStrides(
+KernelArgumentHolder inferContiguousOutputMetaTensor(
     Fusion* fusion,
     const KernelArgumentHolder& args,
     PrecomputedValues* evaluator_precomputed_values) {
   FUSER_PERF_SCOPE(
-      "fusion_executor::allocations::inferOutputShapeAndContiguousStrides");
+      "fusion_executor::allocations::inferContiguousOutputMetaTensor");
   ExpressionEvaluator expr_eval;
 
   std::unique_ptr<PrecomputedValues> evaluator_precomputed_values_up = nullptr;
2 changes: 1 addition & 1 deletion csrc/runtime/allocations.h
@@ -46,7 +46,7 @@ struct GlobalBufferInfo {
 //! pushing scalar int 0 as a place-holder.
 //! 2. This API does not allocate output in memory, but only returns the
 //! inferred output sizes. Used in runtime/fusion_executor_cache.cpp.
-KernelArgumentHolder inferOutputShapeAndContiguousStrides(
+KernelArgumentHolder inferContiguousOutputMetaTensor(
     Fusion* fusion,
     const KernelArgumentHolder& args,
     PrecomputedValues* evaluator_precomputed_values = nullptr);
1 change: 1 addition & 0 deletions csrc/runtime/fusion_cache_utils.cpp
@@ -11,6 +11,7 @@
 
 #include <fusion_segmenter.h>
 #include <ir/all_nodes.h>
+#include <ir/utils.h>
 #include <polymorphic_value.h>
 #include <runtime/executor_kernel_arg.h>
76 changes: 68 additions & 8 deletions csrc/runtime/fusion_kernel_runtime.cpp
@@ -337,6 +337,65 @@ KernelArgumentHolder FusionKernelRuntime::runWithInputs(
   return fusion_outputs;
 }
 
+KernelArgumentHolder FusionKernelRuntime::inferOutputMetaTensor(
+    HeuristicParamsList* heuristics,
+    SegmentedGroup* group_to_run,
+    const KernelArgumentHolder& group_runtime_inputs,
+    PrecomputedValues* evaluator_precomputed_values) const {
+  FUSER_PERF_SCOPE("FusionKernelRuntime::inferOutputMetaTensor");
+  NVF_ERROR(heuristics != nullptr);
+  Fusion* fusion_to_run = group_to_run->getFusion();
+  KernelArgumentHolder group_runtime_outputs;
+  const auto& heuristic_params = heuristics->at(group_to_run->groupId());
+  const bool is_expr_eval =
+      heuristic_params->scheduler_type == SchedulerType::ExprEval;
+  if (is_expr_eval && isOptionEnabled(EnableOption::InferContiguity)) {
+    // For expr evaluated fusion, the striding rules follow that of ATen.
+    ExpressionEvaluator eval_fusion;
+    for (auto i : arange(group_to_run->inputs().size())) {
+      const auto& tensor_pv = group_runtime_inputs[i];
+      if (tensor_pv.is<at::Tensor>()) {
+        const auto& t = tensor_pv.as<at::Tensor>();
+        if (t.defined()) {
+          const auto meta_t = at::empty_strided(
+              t.sizes(),
+              t.strides(),
+              at::TensorOptions().device(at::kMeta).dtype(t.dtype()));
+          eval_fusion.bind(fusion_to_run->inputs()[i], meta_t);
+        } else {
+          eval_fusion.bind(fusion_to_run->inputs()[i], t);
+        }
+      } else {
+        eval_fusion.bind(fusion_to_run->inputs()[i], tensor_pv);
+      }
+    }
+    for (auto v : fusion_to_run->outputs()) {
+      auto result = eval_fusion.evaluate(v);
+      group_runtime_outputs.push(result);
+    }
+  } else {
+    return inferContiguousOutputMetaTensor(
+        fusion_to_run, group_runtime_inputs, evaluator_precomputed_values);
+  }
+  return group_runtime_outputs;
+}
+
+void FusionKernelRuntime::updateContiguityOfSegmentOutputs(
+    SegmentedGroup* group_to_run,
+    const KernelArgumentHolder& group_runtime_outputs) const {
+  FUSER_PERF_SCOPE("FusionKernelRuntime::updateContiguityOfSegmentOutputs");
+  if (!isOptionEnabled(EnableOption::InferContiguity)) {
+    return;
+  }
+  for (auto [i, output] : enumerate(group_to_run->outputs())) {
+    auto tv = dynamic_cast<TensorView*>(output);
+    if (tv) {
+      const at::Tensor& tensor = group_runtime_outputs[i].as<at::Tensor>();
+      ir_utils::resetContiguityFromTensor(tv, tensor);
+    }
+  }
+}
+
 std::vector<KernelArgumentHolder> FusionKernelRuntime::prepareInputs(
     const KernelArgumentHolder& args) const {
   std::vector<KernelArgumentHolder> all_runtime_inputs;
@@ -362,16 +421,14 @@ std::vector<KernelArgumentHolder> FusionKernelRuntime::prepareInputs(
       group_runtime_inputs.setCacheId(group_cache_id.value());
     }
 
-    // TODO: inferOutputShapeAndContiguousStrides doesn't seem to strictly
-    // require a Fusion for each segment. Consider using the complete fusion
-    // instead.
-    auto fusion_to_run = segmented_fusion_->makeFusion(group_to_run).second;
-    auto group_runtime_outputs = inferOutputShapeAndContiguousStrides(
-        fusion_to_run.get(), group_runtime_inputs);
+    auto group_runtime_outputs = inferOutputMetaTensor(
+        heuristics_.get(), group_to_run, group_runtime_inputs);

  [Review thread on the lines above]
  Collaborator: I'm losing track of the code. Does group_runtime_inputs
  contain meta tensors or real tensors at this moment? The setDeviceIndex
  call seems to say they are real tensors.

  Author (zasdfgbnm): IIUC, in prepareInputs, group_runtime_inputs contains
  real tensors (though inferOutputShapeAndContiguousStrides still returns
  meta tensors), but in getMaybeHeuristicsFor, group_runtime_inputs contains
  meta tensors.

  Collaborator: Got it. Should the setDeviceIndex call at line 419 be
  removed? Is it safe or necessary? (I don't think your PR changes the
  situation; just out of curiosity.)

     // map output args to tensor map
     args_manager.updateWithSegmentOutputs(
         group_to_run->outputs(), group_runtime_outputs, run_order_id);
+
+    updateContiguityOfSegmentOutputs(group_to_run, group_runtime_outputs);
  [Review thread on the line above]
  Collaborator: Is this to hide some bugs in mark_aliases_prepare or
  allocation_order_inference? The TensorViews in the complete fusion, and
  therefore in segments, ought to be correct after preseg.

  Author (zasdfgbnm): How do you define "hide a bug"? We need the correct
  contiguity eventually, which is only possible after we know the scheduler
  of segmentation. So why isn't this just writing the correct information,
  instead of hiding a bug?

  Collaborator:
  > which is only possible after we know the scheduler of segmentation

  But scheduling happens after prepareInputs:

      compileKernel(group_runtime_inputs, group_to_run);

  I'm probably missing some important details that are so obvious to you.
  Let me try to remove this line and see where things break...

  Collaborator:
      $ _bn && pytest tests/python/direct/test_python_frontend.py -k test_issue4888 -vs

  passes with the following patch:

      diff --git a/csrc/runtime/fusion_kernel_runtime.cpp b/csrc/runtime/fusion_kernel_runtime.cpp
      index e025d29d..132cba82 100644
      --- a/csrc/runtime/fusion_kernel_runtime.cpp
      +++ b/csrc/runtime/fusion_kernel_runtime.cpp
      @@ -427,8 +427,6 @@ std::vector<KernelArgumentHolder> FusionKernelRuntime::prepareInputs(
           // map output args to tensor map
           args_manager.updateWithSegmentOutputs(
               group_to_run->outputs(), group_runtime_outputs, run_order_id);
      -
      -    updateContiguityOfSegmentOutputs(group_to_run, group_runtime_outputs);
         }

         return all_runtime_inputs;

  But let me try other tests as well...

  Collaborator: I missed the other call to updateContiguityOfSegmentOutputs.
  After removing that, I see SegmentationTest.RevertPrivatizedUpcast fails.
  Let me try to understand the error...

      $ bin/test_nvfuser --gtest_filter=SegmentationTest.RevertPrivatizedUpcast
      Running main() from /opt/pytorch/nvfuser/third_party/googletest/googletest/src/gtest_main.cc
      Note: Google Test filter = SegmentationTest.RevertPrivatizedUpcast
      [==========] Running 1 test from 1 test suite.
      [----------] Global test environment set-up.
      [----------] 1 test from SegmentationTest
      [ RUN      ] SegmentationTest.RevertPrivatizedUpcast
      /opt/pytorch/nvfuser/tests/cpp/test_segmentation.cpp:855: Failure
      Expected equality of these values:
        num_upcast_ops
          Which is: 1
        2

      To reproduce: NVFUSER_TEST_RANDOM_SEED=1768609993 NVFUSER_TEST_ATEN_RANDOM_SEED=0 test_nvfuser --gtest_filter='SegmentationTest.RevertPrivatizedUpcast'
      [  FAILED  ] SegmentationTest.RevertPrivatizedUpcast (218 ms)
      [----------] 1 test from SegmentationTest (218 ms total)

      [----------] Global test environment tear-down
      [==========] 1 test from 1 test suite ran. (218 ms total)
      [  PASSED  ] 0 tests.
      [  FAILED  ] 1 test, listed below:
      [  FAILED  ] SegmentationTest.RevertPrivatizedUpcast

       1 FAILED TEST
   }
 
   return all_runtime_inputs;
@@ -599,13 +656,16 @@ std::optional<std::unique_ptr<HeuristicParamsList>> FusionKernelRuntime::
     }
 
     // Generate metadata for the fusion's outputs
-    auto group_runtime_outputs = inferOutputShapeAndContiguousStrides(
-        fusion_to_run,
+    auto group_runtime_outputs = inferOutputMetaTensor(
+        heuristics.get(),
+        group_to_run,
         group_runtime_inputs,
         evaluator_precomputed_values.get());
 
     args_manager.updateWithSegmentOutputs(
         group_to_run->outputs(), group_runtime_outputs, run_order_id);
+
+    updateContiguityOfSegmentOutputs(group_to_run, group_runtime_outputs);
   }
   return heuristics;
 }
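The core idea behind inferOutputMetaTensor is that running an ATen op on meta tensors executes only its shape/stride logic, so the outputs reflect ATen's real striding rules without allocating or computing anything. A minimal standalone sketch (illustrative sizes/strides; not nvfuser code):

#include <ATen/ATen.h>
#include <iostream>

int main() {
  auto options = at::TensorOptions().device(at::kMeta).dtype(at::kFloat);
  // Mirror a hypothetical runtime input that is transposed (column-major),
  // the same way inferOutputMetaTensor mirrors each real input above.
  at::Tensor a = at::empty_strided({4, 3}, {1, 4}, options);
  at::Tensor b = at::empty_strided({3, 5}, {5, 1}, options);

  // Only ATen's meta kernels run here; no data is touched.
  at::Tensor c = at::matmul(a, b);

  // Report the sizes/strides ATen would really produce, and hence whether
  // the output is contiguous.
  std::cout << c.sizes() << " " << c.strides() << " " << c.is_contiguous()
            << std::endl;
  return 0;
}

This is why the ExprEval path binds meta mirrors of the real inputs instead of the inputs themselves: the evaluation becomes a pure metadata computation.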
21 changes: 21 additions & 0 deletions csrc/runtime/fusion_kernel_runtime.h
@@ -173,6 +173,27 @@ class FusionKernelRuntime {
   //! Access the list of schedulers maintained in this runtime instance
   const std::vector<std::unique_ptr<HeuristicParams>>& schedulers() const;
 
+  //! Infer the output shapes and strides of the fusion as tensors on the
+  //! meta device. If the group is scheduled to be evaluated with ExprEval,
+  //! the outputs are inferred by running the expression evaluator on meta
+  //! tensors. Otherwise, the outputs are inferred assuming they are
+  //! contiguous.
+  KernelArgumentHolder inferOutputMetaTensor(
+      HeuristicParamsList* heuristics,
+      SegmentedGroup* group_to_run,
+      const KernelArgumentHolder& group_runtime_inputs,
+      PrecomputedValues* evaluator_precomputed_values = nullptr) const;
+
+  //! When a Fusion IR is constructed, all intermediate tensors are assumed
+  //! to be contiguous. Unfortunately, this assumption is not always true,
+  //! and it cannot be verified at compile time. Depending on the runtime
+  //! inputs, we may segment the fusion differently, and some segments are
+  //! executed with ATen, which does not always produce contiguous tensors.
+  //! So we have to update the contiguity of segment outputs on the fly,
+  //! based on the runtime inputs.
+  void updateContiguityOfSegmentOutputs(
+      SegmentedGroup* group_to_run,
+      const KernelArgumentHolder& group_runtime_outputs) const;
+
   // Create KernelArgumentHolders for all of the segments. Sorted in
   // the run order.
   std::vector<KernelArgumentHolder> prepareInputs(
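ir_utils::resetContiguityFromTensor itself lives in ir/utils.h and is not shown in this diff, but the underlying rule is simple. A hedged sketch of the idea (hypothetical helper; the real function must also handle nvfuser details such as broadcast dimensions and allocation-domain ordering):

#include <cstdint>
#include <vector>

// A dimension is contiguous iff its stride equals the product of the sizes
// of all dimensions to its right, scanning right to left.
std::vector<bool> contiguityFromSizesStrides(
    const std::vector<int64_t>& sizes,
    const std::vector<int64_t>& strides) {
  std::vector<bool> contiguity(sizes.size(), false);
  int64_t expected = 1;
  for (auto i = static_cast<int64_t>(sizes.size()) - 1; i >= 0; --i) {
    contiguity[i] = (strides[i] == expected);
    expected *= sizes[i];
  }
  return contiguity;
}

For example, sizes {4, 3} with strides {3, 1} yield {true, true} (contiguous), while the transposed strides {1, 4} yield {false, false}.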
3 changes: 3 additions & 0 deletions tests/cpp/test_alias.cpp
@@ -497,6 +497,9 @@ TEST_F(AliasTest, Issue1452) {
 }
 
 TEST_F(AliasTest, AliasOutputBeforeNonAliasOutput) {
+  EnableOptionsGuard opt_guard;
+  EnableOptionsGuard::getCurOptions().unset(EnableOption::InferContiguity);
+
   auto fusion = std::make_unique<Fusion>();
   FusionGuard fg(fusion.get());
2 changes: 2 additions & 0 deletions tests/cpp/test_indexing_advanced.cpp
@@ -26,13 +26,15 @@ class AdvancedIndexingTest : public NVFuserFixtureParamTest<bool> {
     } else {
       EnableOptionsGuard::getCurOptions().unset(EnableOption::IdModel);
     }
+    EnableOptionsGuard::getCurOptions().set(EnableOption::InferContiguity);
   }
 };
 
 class AdvancedIndexingIdModelTest : public NVFuserTest {
  protected:
   void SetUp() override {
     EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel);
+    EnableOptionsGuard::getCurOptions().set(EnableOption::InferContiguity);
   }
 };
1 change: 1 addition & 0 deletions tests/cpp/test_layout_op.cpp
@@ -87,6 +87,7 @@ class LayoutOpTest : public NVFuserTest {
   void SetUp() override {
     NVFuserTest::SetUp();
     EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel);
+    EnableOptionsGuard::getCurOptions().unset(EnableOption::InferContiguity);
   }
 };
1 change: 1 addition & 0 deletions tests/cpp/test_loop_domain_scheduling.cpp
@@ -41,6 +41,7 @@ class LoopDomainSchedulingTest : public NVFuserTest {
  protected:
   void SetUp() override {
     EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel);
+    EnableOptionsGuard::getCurOptions().set(EnableOption::InferContiguity);
   }
 };
8 changes: 7 additions & 1 deletion tests/cpp/test_low_precision_recipe.cpp
@@ -974,7 +974,13 @@ TEST_F(BlockQuantizationValidationTest, MergesMustBeContiguous) {
 class BlockQuantizationSchedulingTest
     : public BlackwellBase,
       public ::testing::WithParamInterface<
-          std::tuple<DataType, std::pair<int, int>, bool, bool>> {};
+          std::tuple<DataType, std::pair<int, int>, bool, bool>> {
+ protected:
+  void SetUp() override {
+    BlackwellBase::SetUp();
+    EnableOptionsGuard::getCurOptions().unset(EnableOption::InferContiguity);
+  }
+};
 
 TEST_P(BlockQuantizationSchedulingTest, AutoScheduleSingleOp) {
   const auto data_type = std::get<0>(GetParam());
33 changes: 0 additions & 33 deletions tests/cpp/test_matmul_aten_evaluation.cpp
@@ -371,37 +371,4 @@ INSTANTIATE_TEST_SUITE_P(
         testing::Values(Sizes({n, 1})),
         testing::Values(Sizes({n}))));
 
-using MatmulNodeTest = NVFuserTest;
-
-TEST_F(MatmulNodeTest, OutputStrides) {
-  auto fusion = std::make_unique<Fusion>();
-  FusionGuard fg(fusion.get());
-
-  TensorView* x = makeSymbolicTensor(2, DataType::Half);
-  TensorView* y = makeSymbolicTensor(2, DataType::Half);
-  TensorView* z = matmul(x, y);
-
-  fusion->addInput(x);
-  fusion->addInput(y);
-  fusion->addOutput(z);
-
-  z->setAllocationDomain({z->axis(1), z->axis(0), z->axis(2)}, true);
-
-  auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
-  at::Tensor x_tensor = at::randn({2, 3}, options);
-  at::Tensor y_tensor = at::randn({3, 5}, options);
-
-  FusionExecutorCache executor_cache(std::move(fusion));
-  auto outs = executor_cache.runFusionWithInputs({x_tensor, y_tensor});
-  at::Tensor z_tensor = outs[0].as<at::Tensor>();
-  testValidate(
-      executor_cache.fusion(),
-      {z_tensor},
-      {x_tensor, y_tensor},
-      __LINE__,
-      __FILE__);
-
-  EXPECT_THAT(z_tensor.strides(), ElementsAre(1, 2));
-}
-
 } // namespace nvfuser
1 change: 1 addition & 0 deletions tests/cpp/test_matmul_scheduler.cpp
@@ -2802,6 +2802,7 @@ class MatmulFusionTest
       EnableOptionsGuard::getCurOptions().set(
           EnableOption::FuseMultipleMatmuls);
     }
+    EnableOptionsGuard::getCurOptions().set(EnableOption::InferContiguity);
   }
 
   bool fusion_enabled = GetParam().first;
1 change: 1 addition & 0 deletions tests/cpp/test_pointwise.cpp
@@ -26,6 +26,7 @@ class PointwiseTest : public NVFuserTest {
  protected:
   void SetUp() override {
     EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel);
+    EnableOptionsGuard::getCurOptions().set(EnableOption::InferContiguity);
   }
 };
1 change: 1 addition & 0 deletions tests/cpp/test_rng.cpp
@@ -69,6 +69,7 @@ at::Tensor generate_normal(int64_t size, at::ScalarType dtype) {
 class RNGTest : public NVFuserTest {
   void SetUp() override {
     EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel);
+    EnableOptionsGuard::getCurOptions().set(EnableOption::InferContiguity);
   }
 };
5 changes: 4 additions & 1 deletion tests/cpp/test_segmentation.cpp
@@ -849,7 +849,10 @@ TEST_F(SegmentationTest, RevertPrivatizedUpcast) {
 
       ++num_upcast_ops;
     }
-    EXPECT_EQ(num_upcast_ops, 1);
+    // There is an unswitched IfThenElse in the generated kernel, and in each
+    // of its branches, there is an upcast op with tv1 as its producer. So we
+    // should have two upcast ops.
+    EXPECT_EQ(num_upcast_ops, 2);
   }
 }
1 change: 1 addition & 0 deletions tests/cpp/utils.cpp
@@ -59,6 +59,7 @@ void NVFuserTest::SetUp() {
     GTEST_SKIP() << "skipping tests on pre-PASCAL GPUs";
   }
   EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel);
+  EnableOptionsGuard::getCurOptions().set(EnableOption::InferContiguity);
 }
 
 NVFuserTest::~NVFuserTest() {
12 changes: 11 additions & 1 deletion tests/python/direct/conftest.py
@@ -28,6 +28,8 @@ def exec_nvfuser(
     new_fusion_expected=True,
     expected_fd_str=None,
     device=None,
+    enable_options=None,
+    disable_options=None,
     validate_results=False,
 ):
     # Copy inputs because aliased outputs can modify inputs when running
@@ -64,12 +66,20 @@ def exec_nvfuser(
     if validate_results:
         out = fd.validate(inputs)
     else:
+        if enable_options is None:
+            enable_options = []
+        if disable_options is None:
+            disable_options = []
         out = fd.execute(
             inputs,
             device=device,
+            _enable_options=enable_options,
+            _disable_options=disable_options,
         )
 
-    assert check_captured_python_definition(out, fd, inputs_captured, device)
+    assert check_captured_python_definition(
+        out, fd, inputs_captured, device, enable_options, disable_options
+    )
     assert expected_fd_str is None or expected_fd_str in repr(fd)
     return out, fd