Merged
Commits
48 commits
473191d  Add utility ir_utils::resetContiguityFromTensor (zasdfgbnm, Jan 6, 2026)
ffd6864  Merge branch 'main' into resetContiguityFromTensor (zasdfgbnm, Jan 7, 2026)
69a65a8  save (zasdfgbnm, Jan 7, 2026)
5a784d4  save (zasdfgbnm, Jan 7, 2026)
d828c9f  Use meta device tensor to infer contiguity for expr-eval segments (zasdfgbnm, Jan 7, 2026)
c95a584  Merge branch 'main' into resetContiguityFromTensor (zasdfgbnm, Jan 13, 2026)
074e947  Merge branch 'resetContiguityFromTensor' into meta-eval (zasdfgbnm, Jan 13, 2026)
6bcdb96  Merge branch 'main' into resetContiguityFromTensor (zasdfgbnm, Jan 14, 2026)
4ad3785  Merge branch 'resetContiguityFromTensor' into meta-eval (zasdfgbnm, Jan 14, 2026)
c0b50c3  Refactor output shape inference functions and update segment output h… (zasdfgbnm, Jan 14, 2026)
7ffdaa3  save (zasdfgbnm, Jan 14, 2026)
7a5b0dc  save (zasdfgbnm, Jan 14, 2026)
d92e5ee  save (zasdfgbnm, Jan 14, 2026)
074209b  save (zasdfgbnm, Jan 14, 2026)
68d52fb  save (zasdfgbnm, Jan 14, 2026)
fb0572a  save (zasdfgbnm, Jan 14, 2026)
08e5b6b  save (zasdfgbnm, Jan 14, 2026)
784ce68  enable (zasdfgbnm, Jan 14, 2026)
9c46183  Merge branch 'meta-eval' of github.com:NVIDIA/Fuser into meta-eval (zasdfgbnm, Jan 14, 2026)
93a4012  save (zasdfgbnm, Jan 14, 2026)
e5d4d67  fix (zasdfgbnm, Jan 14, 2026)
38defa9  save (zasdfgbnm, Jan 14, 2026)
f50e52f  save (zasdfgbnm, Jan 14, 2026)
5fd7496  fix (zasdfgbnm, Jan 14, 2026)
40782db  fix (zasdfgbnm, Jan 14, 2026)
53d70fe  save (zasdfgbnm, Jan 14, 2026)
87b00e8  save (zasdfgbnm, Jan 14, 2026)
4afe5b1  fix (zasdfgbnm, Jan 14, 2026)
447de33  fix (zasdfgbnm, Jan 14, 2026)
dd41424  Merge branch 'meta-eval' of github.com:NVIDIA/Fuser into meta-eval (zasdfgbnm, Jan 14, 2026)
831c777  save (zasdfgbnm, Jan 14, 2026)
d246fc6  save (zasdfgbnm, Jan 14, 2026)
01a011c  save (zasdfgbnm, Jan 15, 2026)
03fa1f5  save (zasdfgbnm, Jan 15, 2026)
0a2878b  Apply suggestion from @wujingyue (zasdfgbnm, Jan 15, 2026)
6e3b29e  Merge branch 'main' into resetContiguityFromTensor (zasdfgbnm, Jan 15, 2026)
72c66dd  Merge branch 'resetContiguityFromTensor' into meta-eval (zasdfgbnm, Jan 15, 2026)
b151b50  save (zasdfgbnm, Jan 15, 2026)
4249f8b  save (zasdfgbnm, Jan 15, 2026)
40b5411  fix (zasdfgbnm, Jan 16, 2026)
b7778bd  Merge branch 'main' into resetContiguityFromTensor (zasdfgbnm, Jan 16, 2026)
c81f895  Merge branch 'resetContiguityFromTensor' into meta-eval (zasdfgbnm, Jan 16, 2026)
8466489  Don't update contiguity (#5842) (wujingyue, Jan 21, 2026)
1500e69  Merge branch 'main' of github.com:NVIDIA/Fuser into meta-eval (zasdfgbnm, Jan 21, 2026)
7c62ef3  save (zasdfgbnm, Jan 21, 2026)
fb95173  fix (zasdfgbnm, Jan 21, 2026)
f87a0ca  Remove unused include (zasdfgbnm, Jan 23, 2026)
9665cf0  fix (zasdfgbnm, Jan 23, 2026)
1 change: 1 addition & 0 deletions csrc/options.cpp
@@ -185,6 +185,7 @@ const std::unordered_map<std::string, EnableOption>& getEnableOptions() {
{"p2p_protocol", EnableOption::P2pProtocol},
{"multicast_protocol", EnableOption::MulticastProtocol},
{"parallel_serde", EnableOption::ParallelSerde},
{"infer_contiguity", EnableOption::InferContiguity},
};
return available_options;
}
1 change: 1 addition & 0 deletions csrc/options.h
@@ -130,6 +130,7 @@ enum class EnableOption {
MulticastProtocol, //! Prescribe multicast protocol:
//! memcpy|multimem|batch_memcpy
ParallelSerde, //! Enable deserializing FusionExecutorCache in parallel
InferContiguity, //! Enable contiguity inference
EndOfOption //! Placeholder for counting the number of elements
};

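The two hunks above register the new flag end to end: a string key in getEnableOptions() and an enum value in EnableOption. As rough orientation, here is a minimal sketch of how such an option is consumed. isOptionEnabled, EnableOptionsGuard, and the option names appear elsewhere in this diff; the wrapper function and the NVFUSER_ENABLE spelling are assumptions based on how nvFuser options are usually driven.

```cpp
// Minimal sketch, not part of the PR. Assumes the usual nvFuser option
// plumbing: options registered in getEnableOptions() can be switched on via
// the environment (e.g. NVFUSER_ENABLE=infer_contiguity) or programmatically.
#include <options.h>

namespace nvfuser {

// Hypothetical helper: query the flag at a decision point, exactly as
// fusion_kernel_runtime.cpp does below with isOptionEnabled(...).
bool shouldInferContiguity() {
  return isOptionEnabled(EnableOption::InferContiguity);
}

void scopedOverrideExample() {
  // Scoped override, mirroring the test fixtures later in this diff; the
  // guard restores the previous options when it goes out of scope.
  EnableOptionsGuard guard;
  EnableOptionsGuard::getCurOptions().set(EnableOption::InferContiguity);
  // ... code that consults shouldInferContiguity() ...
}

} // namespace nvfuser
```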
4 changes: 2 additions & 2 deletions csrc/runtime/allocations.cpp
@@ -21,12 +21,12 @@

namespace nvfuser {

-KernelArgumentHolder inferOutputShapeAndContiguousStrides(
+KernelArgumentHolder inferContiguousOutputMetaTensor(
Fusion* fusion,
const KernelArgumentHolder& args,
PrecomputedValues* evaluator_precomputed_values) {
FUSER_PERF_SCOPE(
"fusion_executor::allocations::inferOutputShapeAndContiguousStrides");
"fusion_executor::allocations::inferContiguousOutputMetaTensor");
ExpressionEvaluator expr_eval;

std::unique_ptr<PrecomputedValues> evaluator_precomputed_values_up = nullptr;
2 changes: 1 addition & 1 deletion csrc/runtime/allocations.h
@@ -46,7 +46,7 @@ struct GlobalBufferInfo {
//! pushing scalar int 0 as a place-holder.
//! 2. This API does not allocate output in memory, but only returns the
//! inferred output sizes. Used in runtime/fusion_executor_cache.cpp.
-KernelArgumentHolder inferOutputShapeAndContiguousStrides(
+KernelArgumentHolder inferContiguousOutputMetaTensor(
Fusion* fusion,
const KernelArgumentHolder& args,
PrecomputedValues* evaluator_precomputed_values = nullptr);
56 changes: 48 additions & 8 deletions csrc/runtime/fusion_kernel_runtime.cpp
@@ -337,6 +337,49 @@ KernelArgumentHolder FusionKernelRuntime::runWithInputs(
return fusion_outputs;
}

KernelArgumentHolder FusionKernelRuntime::inferOutputMetaTensor(
HeuristicParamsList* heuristics,
SegmentedGroup* group_to_run,
const KernelArgumentHolder& group_runtime_inputs,
PrecomputedValues* evaluator_precomputed_values) const {
FUSER_PERF_SCOPE("FusionKernelRuntime::inferOutputMetaTensor");
NVF_ERROR(heuristics != nullptr);
Fusion* fusion_to_run = group_to_run->getFusion();
const auto& heuristic_params = heuristics->at(group_to_run->groupId());
const bool is_expr_eval =
heuristic_params->scheduler_type == SchedulerType::ExprEval;
if (!(is_expr_eval && isOptionEnabled(EnableOption::InferContiguity))) {
return inferContiguousOutputMetaTensor(
fusion_to_run, group_runtime_inputs, evaluator_precomputed_values);
}

// For expr-evaluated fusions, the striding rules follow those of ATen.
ExpressionEvaluator eval_fusion;
for (const auto& [in, tensor_pv] :
zip(fusion_to_run->inputs(), group_runtime_inputs)) {
if (tensor_pv.is<at::Tensor>()) {
const auto& t = tensor_pv.as<at::Tensor>();
if (t.defined()) {
const auto meta_t = at::empty_strided(
t.sizes(),
t.strides(),
at::TensorOptions().device(at::kMeta).dtype(t.dtype()));
eval_fusion.bind(in, meta_t);
} else {
eval_fusion.bind(in, t);
}
} else {
eval_fusion.bind(in, tensor_pv);
}
}
KernelArgumentHolder group_runtime_outputs;
for (Val* v : fusion_to_run->outputs()) {
auto result = eval_fusion.evaluate(v);
[Review comment, Contributor] logic: no error handling for evaluate() failure; if evaluation fails or returns an invalid result, it will silently continue. Consider validating the result or wrapping it in a try-catch, especially since the PR description mentions that some ATen ops on the meta device can hang due to GIL issues. (A sketch of this hardening follows this file's diff.)

group_runtime_outputs.push(result);
}
return group_runtime_outputs;
}

std::vector<KernelArgumentHolder> FusionKernelRuntime::prepareInputs(
const KernelArgumentHolder& args) const {
std::vector<KernelArgumentHolder> all_runtime_inputs;
@@ -362,12 +405,8 @@ std::vector<KernelArgumentHolder> FusionKernelRuntime::prepareInputs(
group_runtime_inputs.setCacheId(group_cache_id.value());
}

-    // TODO: inferOutputShapeAndContiguousStrides doesn't seem to strictly
-    // require a Fusion for each segment. Consider using the complete fusion
-    // instead.
-    auto fusion_to_run = segmented_fusion_->makeFusion(group_to_run).second;
-    auto group_runtime_outputs = inferOutputShapeAndContiguousStrides(
-        fusion_to_run.get(), group_runtime_inputs);
+    auto group_runtime_outputs = inferOutputMetaTensor(
+        heuristics_.get(), group_to_run, group_runtime_inputs);

// map output args to tensor map
args_manager.updateWithSegmentOutputs(
@@ -599,8 +638,9 @@ std::optional<std::unique_ptr<HeuristicParamsList>> FusionKernelRuntime::
}

// Generate metadata for the fusion's outputs
-    auto group_runtime_outputs = inferOutputShapeAndContiguousStrides(
-        fusion_to_run,
+    auto group_runtime_outputs = inferOutputMetaTensor(
+        heuristics.get(),
+        group_to_run,
group_runtime_inputs,
evaluator_precomputed_values.get());

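To make the review comment above concrete, one possible shape for the suggested hardening is sketched below. It is not part of the PR; NVF_ERROR is the macro already used in this function, and everything else reuses names from the evaluation loop above.

```cpp
// Hypothetical hardening of the output-evaluation loop, along the lines of
// the review comment: fail loudly instead of silently pushing an invalid
// result when meta-device evaluation throws.
KernelArgumentHolder group_runtime_outputs;
for (Val* v : fusion_to_run->outputs()) {
  try {
    group_runtime_outputs.push(eval_fusion.evaluate(v));
  } catch (const std::exception& e) {
    NVF_ERROR(
        false,
        "Meta-device evaluation failed for output ",
        v->toString(),
        ": ",
        e.what());
  }
}
```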
10 changes: 10 additions & 0 deletions csrc/runtime/fusion_kernel_runtime.h
@@ -173,6 +173,16 @@ class FusionKernelRuntime {
//! Access the list of schedulers maintained in this runtime instance
const std::vector<std::unique_ptr<HeuristicParams>>& schedulers() const;

//! Infer the shapes and strides of the fusion's outputs as tensors on the
//! meta device. If the group is scheduled to be evaluated with ExprEval, the
//! outputs are inferred by running ExprEval on meta-device tensors;
//! otherwise, they are inferred assuming contiguous outputs.
KernelArgumentHolder inferOutputMetaTensor(
HeuristicParamsList* heuristics,
SegmentedGroup* group_to_run,
const KernelArgumentHolder& group_runtime_inputs,
PrecomputedValues* evaluator_precomputed_values = nullptr) const;

// Create KernelArgumentHolders for all of the segments. Sorted in
// the run order.
std::vector<KernelArgumentHolder> prepareInputs(
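The "ExprEval on meta device" path described in the comment above relies on a single ATen facility: a tensor can carry sizes, strides, and dtype without owning storage. Below is a self-contained sketch of that rebinding step; at::empty_strided and at::kMeta are plain ATen, while the helper name is illustrative.

```cpp
// Standalone illustration of the meta-device trick used by
// inferOutputMetaTensor above: mirror a real tensor's sizes/strides/dtype on
// the meta device so ATen's striding rules propagate through evaluation
// without allocating GPU memory. The helper name is hypothetical.
#include <ATen/ATen.h>

at::Tensor toMetaTensor(const at::Tensor& t) {
  if (!t.defined()) {
    return t;  // undefined tensors pass through, as in the .cpp diff above
  }
  return at::empty_strided(
      t.sizes(),
      t.strides(),
      at::TensorOptions().device(at::kMeta).dtype(t.dtype()));
}

// Usage: a transposed (hence non-contiguous) input keeps its strides.
//   at::Tensor x = at::randn({4, 8}).t();  // sizes {8, 4}, strides {1, 8}
//   at::Tensor m = toMetaTensor(x);        // meta device, same strides
```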
3 changes: 3 additions & 0 deletions tests/cpp/test_alias.cpp
@@ -497,6 +497,9 @@ TEST_F(AliasTest, Issue1452) {
}

TEST_F(AliasTest, AliasOutputBeforeNonAliasOutput) {
EnableOptionsGuard opt_guard;
EnableOptionsGuard::getCurOptions().unset(EnableOption::InferContiguity);
[Review comment, Collaborator] Is this a no-op? AliasTest doesn't seem to enable InferContiguity.

[Review comment, Collaborator (author)] AliasTest is NVFuserTest, which enables InferContiguity.

[Review comment, Collaborator] Got it -- I didn't realize NVFuserTest enables the option by default. Then why do some tests enable it again, e.g., https://github.com/NVIDIA/Fuser/pull/5772/files#diff-3675636f2228bd2f8c3f308c28fa88f1d659d8eb3d869570dcfdf013f77908aaR29?


auto fusion = std::make_unique<Fusion>();
FusionGuard fg(fusion.get());

2 changes: 2 additions & 0 deletions tests/cpp/test_indexing_advanced.cpp
@@ -26,13 +26,15 @@ class AdvancedIndexingTest : public NVFuserFixtureParamTest<bool> {
} else {
EnableOptionsGuard::getCurOptions().unset(EnableOption::IdModel);
}
EnableOptionsGuard::getCurOptions().set(EnableOption::InferContiguity);
}
};

class AdvancedIndexingIdModelTest : public NVFuserTest {
protected:
void SetUp() override {
EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel);
EnableOptionsGuard::getCurOptions().set(EnableOption::InferContiguity);
}
};

1 change: 1 addition & 0 deletions tests/cpp/test_layout_op.cpp
@@ -88,6 +88,7 @@
void SetUp() override {
NVFuserTest::SetUp();
EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel);
EnableOptionsGuard::getCurOptions().unset(EnableOption::InferContiguity);
}
};

1 change: 1 addition & 0 deletions tests/cpp/test_loop_domain_scheduling.cpp
@@ -41,6 +41,7 @@
protected:
void SetUp() override {
EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel);
EnableOptionsGuard::getCurOptions().set(EnableOption::InferContiguity);
}
};

8 changes: 7 additions & 1 deletion tests/cpp/test_low_precision_recipe.cpp
@@ -972,7 +972,13 @@ TEST_F(BlockQuantizationValidationTest, MergesMustBeContiguous) {
class BlockQuantizationSchedulingTest
: public BlackwellBase,
public ::testing::WithParamInterface<
-        std::tuple<DataType, std::pair<int, int>, bool, bool>> {};
+        std::tuple<DataType, std::pair<int, int>, bool, bool>> {
protected:
void SetUp() override {
BlackwellBase::SetUp();
EnableOptionsGuard::getCurOptions().unset(EnableOption::InferContiguity);
[Review comment, Collaborator (author)] Thanks for sharing. I think we should migrate our entire codebase to prefer the ctor, instead of just this specific test. I don't think it is a good idea to mix ctor and SetUp. Because NVFuserTest sets up InferContiguity and everything else in SetUp, we need to be consistent here; otherwise whatever we set in the ctor will be overridden by NVFuserTest::SetUp. (A fixture sketch follows this file's diff.)

[Review comment, Collaborator] > I think we should migrate our entire codebase to prefer ctor, instead of just this specific test

I actually did that for NVFuserTest. In fact, most setup code for NVFuserTest is in its constructor already, except this line (https://github.com/NVIDIA/Fuser/pull/5772/files#diff-16f891fd5f846480392227c6bbf81ead352f59fdc9964e5d6e4dc6089bb622c5R61), which was added later without my notice. The only thing that should be kept in NVFuserTest::SetUp at this moment is GTEST_SKIP.

}
};

TEST_P(BlockQuantizationSchedulingTest, AutoScheduleSingleOp) {
const auto data_type = std::get<0>(GetParam());
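Since the thread above turns on where option setup belongs (constructor vs. SetUp), here is a compact sketch of the pattern the test changes in this PR follow; the fixture name and body are illustrative, while NVFuserTest, EnableOptionsGuard, and the option names are from this diff.

```cpp
// Illustrative fixture only. NVFuserTest::SetUp (tests/cpp/utils.cpp below)
// enables InferContiguity globally, so fixtures that need the legacy
// contiguous-output behavior opt out explicitly, as several tests here do.
class MyContiguityTest : public NVFuserTest {
 protected:
  void SetUp() override {
    NVFuserTest::SetUp();  // sets IdModel and InferContiguity by default
    // Must run after the base SetUp; otherwise the base class would
    // override it (the point made in the discussion above).
    EnableOptionsGuard::getCurOptions().unset(EnableOption::InferContiguity);
  }
};
```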
33 changes: 0 additions & 33 deletions tests/cpp/test_matmul_aten_evaluation.cpp
@@ -371,37 +371,4 @@ INSTANTIATE_TEST_SUITE_P(
testing::Values(Sizes({n, 1})),
testing::Values(Sizes({n}))));

using MatmulNodeTest = NVFuserTest;

TEST_F(MatmulNodeTest, OutputStrides) {
auto fusion = std::make_unique<Fusion>();
FusionGuard fg(fusion.get());

TensorView* x = makeSymbolicTensor(2, DataType::Half);
TensorView* y = makeSymbolicTensor(2, DataType::Half);
TensorView* z = matmul(x, y);

fusion->addInput(x);
fusion->addInput(y);
fusion->addOutput(z);

z->setAllocationDomain({z->axis(1), z->axis(0), z->axis(2)}, true);

auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
at::Tensor x_tensor = at::randn({2, 3}, options);
at::Tensor y_tensor = at::randn({3, 5}, options);

FusionExecutorCache executor_cache(std::move(fusion));
auto outs = executor_cache.runFusionWithInputs({x_tensor, y_tensor});
at::Tensor z_tensor = outs[0].as<at::Tensor>();
testValidate(
executor_cache.fusion(),
{z_tensor},
{x_tensor, y_tensor},
__LINE__,
__FILE__);

EXPECT_THAT(z_tensor.strides(), ElementsAre(1, 2));
}

} // namespace nvfuser
1 change: 1 addition & 0 deletions tests/cpp/test_matmul_scheduler.cpp
@@ -2802,6 +2802,7 @@ class MatmulFusionTest
EnableOptionsGuard::getCurOptions().set(
EnableOption::FuseMultipleMatmuls);
}
EnableOptionsGuard::getCurOptions().set(EnableOption::InferContiguity);
}

bool fusion_enabled = GetParam().first;
1 change: 1 addition & 0 deletions tests/cpp/test_pointwise.cpp
@@ -26,6 +26,7 @@ class PointwiseTest : public NVFuserTest {
protected:
void SetUp() override {
EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel);
EnableOptionsGuard::getCurOptions().set(EnableOption::InferContiguity);
}
};

1 change: 1 addition & 0 deletions tests/cpp/test_rng.cpp
@@ -69,6 +69,7 @@ at::Tensor generate_normal(int64_t size, at::ScalarType dtype) {
class RNGTest : public NVFuserTest {
void SetUp() override {
EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel);
EnableOptionsGuard::getCurOptions().set(EnableOption::InferContiguity);
}
};

1 change: 1 addition & 0 deletions tests/cpp/utils.cpp
@@ -59,6 +59,7 @@ void NVFuserTest::SetUp() {
GTEST_SKIP() << "skipping tests on pre-PASCAL GPUs";
}
EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel);
EnableOptionsGuard::getCurOptions().set(EnableOption::InferContiguity);
}

NVFuserTest::~NVFuserTest() {
12 changes: 11 additions & 1 deletion tests/python/direct/conftest.py
@@ -28,6 +28,8 @@ def exec_nvfuser(
new_fusion_expected=True,
expected_fd_str=None,
device=None,
enable_options=None,
disable_options=None,
validate_results=False,
):
# Copy inputs because aliased outputs can modify inputs when running
@@ -64,12 +66,20 @@
if validate_results:
out = fd.validate(inputs)
else:
if enable_options is None:
enable_options = []
if disable_options is None:
disable_options = []
out = fd.execute(
inputs,
device=device,
_enable_options=enable_options,
_disable_options=disable_options,
)

-    assert check_captured_python_definition(out, fd, inputs_captured, device)
+    assert check_captured_python_definition(
+        out, fd, inputs_captured, device, enable_options, disable_options
+    )
assert expected_fd_str is None or expected_fd_str in repr(fd)
return out, fd

98 changes: 98 additions & 0 deletions tests/python/direct/test_python_frontend.py
@@ -2763,3 +2763,101 @@ def fusion_func(fd: FusionDefinition):

out, _ = nvfuser_direct_test.exec_nvfuser(fusion_func, inputs)
nvfuser_direct_test.assertEqual(out[0], inputs[0])


def test_issue4888(nvfuser_direct_test):
# https://github.com/NVIDIA/Fuser/issues/4888
def nvfuser_fusion_id2(fd: FusionDefinition) -> None:
T0 = fd.define_tensor(
shape=[4096, 4097],
contiguity=[True, True],
dtype=DataType.BFloat16,
is_cpu=False,
stride_order=[1, 0],
)
T1 = fd.define_tensor(
shape=[4096, 4097],
contiguity=[True, True],
dtype=DataType.Bool,
is_cpu=False,
stride_order=[1, 0],
)
T2 = fd.define_tensor(
shape=[4096, 4097],
contiguity=[True, True],
dtype=DataType.Bool,
is_cpu=False,
stride_order=[1, 0],
)
T3 = fd.define_tensor(
shape=[1, 32, 4096, 4096],
contiguity=[None, True, True, True],
dtype=DataType.BFloat16,
is_cpu=False,
stride_order=[3, 2, 1, 0],
)
T4 = fd.ops.cast(T0, dtype=DataType.Float)
T5 = fd.ops.bitwise_or(T1, T2)
T6 = fd.ops.set(T5)
fd.add_output(T6, T1)
T7 = fd.ops.cast(T6, dtype=DataType.Float)
T8 = fd.ops.mul(T4, T7)
T9 = fd.ops.cast(T8, dtype=DataType.BFloat16)
T10 = fd.ops.set(T9)
fd.add_output(T10, T0)
T15 = fd.ops.broadcast_in_dim(T10, shape=[1, 4096, 4097], broadcast_dims=[1, 2])
T21 = fd.ops.broadcast_in_dim(
T15, shape=[1, 1, 4096, 4097], broadcast_dims=[0, 2, 3]
)
T27 = fd.ops.broadcast_in_dim(
T21, shape=[1, 1, 4096, 4097], broadcast_dims=[0, 1, 2, 3]
)
T43 = fd.ops.slice(
T27,
start_indices=[0, 0, 0, 0],
end_indices=[1, 1, 4096, 4096],
strides=[1, 1, 1, 1],
manual_normalization=0,
)
T49 = fd.ops.broadcast_in_dim(
T43, shape=[1, 32, 4096, 4096], broadcast_dims=[0, 1, 2, 3]
)
T50 = fd.ops.cast(T49, dtype=DataType.Float)
T51 = fd.ops.cast(T3, dtype=DataType.Float)
S52 = fd.define_scalar(0.0883883, dtype=DataType.Double)
T53 = fd.ops.mul(T51, S52)
T54 = fd.ops.add(T53, T50)
T55 = fd.ops.max(T54, dims=[3], keepdim=False, dtype=DataType.Null)
T61 = fd.ops.broadcast_in_dim(
T55, shape=[1, 32, 4096, 1], broadcast_dims=[0, 1, 2]
)
T67 = fd.ops.broadcast_in_dim(
T61, shape=[1, 32, 4096, 4096], broadcast_dims=[0, 1, 2, 3]
)
T68 = fd.ops.sub(T54, T67)
T69 = fd.ops.exp(T68)
T70 = fd.ops.sum(T69, dims=[3], keepdim=False, dtype=DataType.Null)
T76 = fd.ops.broadcast_in_dim(
T70, shape=[1, 32, 4096, 1], broadcast_dims=[0, 1, 2]
)
T82 = fd.ops.broadcast_in_dim(
T76, shape=[1, 32, 4096, 4096], broadcast_dims=[0, 1, 2, 3]
)
T83 = fd.ops.reciprocal(T82)
T84 = fd.ops.mul(T69, T83)
T85 = fd.ops.cast(T84, dtype=DataType.BFloat16)
fd.add_output(T49)
fd.add_output(T84)
fd.add_output(T85)

inputs = [
torch.testing.make_tensor((4096, 4097), dtype=torch.bfloat16, device="cuda:0"),
torch.testing.make_tensor((4096, 4097), dtype=torch.bool, device="cuda:0"),
torch.testing.make_tensor((4096, 4097), dtype=torch.bool, device="cuda:0"),
torch.testing.make_tensor(
(1, 32, 4096, 4096), dtype=torch.bfloat16, device="cuda:0"
),
]
nvfuser_direct_test.exec_nvfuser(
nvfuser_fusion_id2, inputs, enable_options=["infer_contiguity"]
)