Use meta device tensor to infer contiguity for expr-eval segments #5772
Base branch: resetContiguityFromTensor
Changes from all commits
csrc/runtime/fusion_kernel_runtime.cpp:

```diff
@@ -337,6 +337,65 @@ KernelArgumentHolder FusionKernelRuntime::runWithInputs(
   return fusion_outputs;
 }
 
+KernelArgumentHolder FusionKernelRuntime::inferOutputMetaTensor(
+    HeuristicParamsList* heuristics,
+    SegmentedGroup* group_to_run,
+    const KernelArgumentHolder& group_runtime_inputs,
+    PrecomputedValues* evaluator_precomputed_values) const {
+  FUSER_PERF_SCOPE("FusionKernelRuntime::inferOutputMetaTensor");
+  NVF_ERROR(heuristics != nullptr);
+  Fusion* fusion_to_run = group_to_run->getFusion();
+  KernelArgumentHolder group_runtime_outputs;
+  const auto& heuristic_params = heuristics->at(group_to_run->groupId());
+  const bool is_expr_eval =
+      heuristic_params->scheduler_type == SchedulerType::ExprEval;
+  if (is_expr_eval && isOptionEnabled(EnableOption::InferContiguity)) {
+    // For expr evaluated fusion, the striding rules follow that of ATen.
+    ExpressionEvaluator eval_fusion;
+    for (auto i : arange(group_to_run->inputs().size())) {
+      const auto& tensor_pv = group_runtime_inputs[i];
+      if (tensor_pv.is<at::Tensor>()) {
+        const auto& t = tensor_pv.as<at::Tensor>();
+        if (t.defined()) {
+          const auto meta_t = at::empty_strided(
+              t.sizes(),
+              t.strides(),
+              at::TensorOptions().device(at::kMeta).dtype(t.dtype()));
+          eval_fusion.bind(fusion_to_run->inputs()[i], meta_t);
+        } else {
+          eval_fusion.bind(fusion_to_run->inputs()[i], t);
+        }
+      } else {
+        eval_fusion.bind(fusion_to_run->inputs()[i], tensor_pv);
+      }
+    }
+    for (auto v : fusion_to_run->outputs()) {
+      auto result = eval_fusion.evaluate(v);
+      group_runtime_outputs.push(result);
+    }
+  } else {
+    return inferContiguousOutputMetaTensor(
+        fusion_to_run, group_runtime_inputs, evaluator_precomputed_values);
+  }
+  return group_runtime_outputs;
+}
+
+void FusionKernelRuntime::updateContiguityOfSegmentOutputs(
+    SegmentedGroup* group_to_run,
+    const KernelArgumentHolder& group_runtime_outputs) const {
+  FUSER_PERF_SCOPE("FusionKernelRuntime::updateContiguityOfSegmentOutputs");
+  if (!isOptionEnabled(EnableOption::InferContiguity)) {
+    return;
+  }
+  for (auto [i, output] : enumerate(group_to_run->outputs())) {
+    auto tv = dynamic_cast<TensorView*>(output);
+    if (tv) {
+      const at::Tensor& tensor = group_runtime_outputs[i].as<at::Tensor>();
+      ir_utils::resetContiguityFromTensor(tv, tensor);
+    }
+  }
+}
+
 std::vector<KernelArgumentHolder> FusionKernelRuntime::prepareInputs(
     const KernelArgumentHolder& args) const {
   std::vector<KernelArgumentHolder> all_runtime_inputs;
```
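The core trick above is binding ATen meta tensors (created with `at::empty_strided` on the `kMeta` device) so the expression evaluator can propagate sizes and strides without allocating or computing real data. Below is a torch-free Python sketch of that idea; the `MetaTensor` type and the elementwise stride rule are invented for illustration and are not nvFuser or ATen API:

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class MetaTensor:
    """Carries only metadata, like an ATen tensor on the `meta` device."""
    sizes: tuple
    strides: tuple
    dtype: str

def contiguous_strides(sizes):
    """Row-major (contiguous) strides for the given sizes."""
    strides, acc = [], 1
    for s in reversed(sizes):
        strides.append(acc)
        acc *= max(s, 1)
    return tuple(reversed(strides))

def eval_elementwise(a: MetaTensor, b: MetaTensor) -> MetaTensor:
    """Toy stride-propagation rule: a same-shape elementwise op keeps the
    inputs' layout when they agree, else falls back to contiguous."""
    assert a.sizes == b.sizes
    strides = a.strides if a.strides == b.strides else contiguous_strides(a.sizes)
    return MetaTensor(a.sizes, strides, a.dtype)

# Column-major (transposed-view) inputs keep their layout...
x = MetaTensor((2, 3), (1, 2), "float32")
y = eval_elementwise(x, x)
# ...while mixed layouts fall back to contiguous.
z = eval_elementwise(x, MetaTensor((2, 3), (3, 1), "float32"))
```

Real ATen stride propagation is considerably more involved (type promotion, broadcasting, memory-format rules); the point is only that metadata alone is enough to decide output strides.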
|
```diff
@@ -362,16 +421,14 @@ std::vector<KernelArgumentHolder> FusionKernelRuntime::prepareInputs(
       group_runtime_inputs.setCacheId(group_cache_id.value());
     }
 
-    // TODO: inferOutputShapeAndContiguousStrides doesn't seem to strictly
-    // require a Fusion for each segment. Consider using the complete fusion
-    // instead.
-    auto fusion_to_run = segmented_fusion_->makeFusion(group_to_run).second;
-    auto group_runtime_outputs = inferOutputShapeAndContiguousStrides(
-        fusion_to_run.get(), group_runtime_inputs);
+    auto group_runtime_outputs = inferOutputMetaTensor(
+        heuristics_.get(), group_to_run, group_runtime_inputs);
 
     // map output args to tensor map
     args_manager.updateWithSegmentOutputs(
         group_to_run->outputs(), group_runtime_outputs, run_order_id);
+
+    updateContiguityOfSegmentOutputs(group_to_run, group_runtime_outputs);
```
Review thread on the `updateContiguityOfSegmentOutputs` call in `prepareInputs`:

Collaborator: Is this to hide some bugs in mark_aliases_prepare or allocation_order_inference? The TensorViews in the complete fusion, and therefore in the segments, ought to be correct after preseg.

Author: How do you define "hiding a bug"? We eventually need the correct contiguity, which is only known after the scheduler for each segment is decided. So why isn't this just writing the correct information, rather than hiding a bug?

Collaborator: But scheduling happens after prepareInputs (csrc/runtime/fusion_kernel_runtime.cpp, line 431 at 352dcbf). I'm probably missing some important details that are obvious to you. Let me try removing this line and see where things break...

Collaborator: `$ _bn && pytest tests/python/direct/test_python_frontend.py -k test_issue4888 -vs` passes with the following patch:

```diff
diff --git a/csrc/runtime/fusion_kernel_runtime.cpp b/csrc/runtime/fusion_kernel_runtime.cpp
index e025d29d..132cba82 100644
--- a/csrc/runtime/fusion_kernel_runtime.cpp
+++ b/csrc/runtime/fusion_kernel_runtime.cpp
@@ -427,8 +427,6 @@ std::vector<KernelArgumentHolder> FusionKernelRuntime::prepareInputs(
       // map output args to tensor map
       args_manager.updateWithSegmentOutputs(
           group_to_run->outputs(), group_runtime_outputs, run_order_id);
-
-      updateContiguityOfSegmentOutputs(group_to_run, group_runtime_outputs);
     }
     return all_runtime_inputs;
```

But let me try other tests as well...

Collaborator: I missed the other call to updateContiguityOfSegmentOutputs. After removing that as well, I see …
```diff
@@ -599,13 +656,16 @@ std::optional<std::unique_ptr<HeuristicParamsList>> FusionKernelRuntime::
     }
 
     // Generate metadata for the fusion's outputs
-    auto group_runtime_outputs = inferOutputShapeAndContiguousStrides(
-        fusion_to_run,
+    auto group_runtime_outputs = inferOutputMetaTensor(
+        heuristics.get(),
+        group_to_run,
         group_runtime_inputs,
         evaluator_precomputed_values.get());
 
     args_manager.updateWithSegmentOutputs(
         group_to_run->outputs(), group_runtime_outputs, run_order_id);
+
+    updateContiguityOfSegmentOutputs(group_to_run, group_runtime_outputs);
   }
   return heuristics;
 }
```
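After inferring output metadata, `updateContiguityOfSegmentOutputs` resets each output TensorView's contiguity flags from the inferred tensor's actual strides. As a rough illustration of the per-dimension check such a reset performs, here is a hypothetical helper (a simplified rule; nvFuser's real handling of broadcast, expanded, and size-1 dimensions differs):

```python
def contiguity_flags(sizes, strides):
    """Dimension i is 'contiguous' with its inner neighbour when no padding
    sits between them: stride[i] == stride[i+1] * size[i+1]. The innermost
    dimension is contiguous when its stride is 1."""
    n = len(sizes)
    flags = []
    for i in range(n):
        if i == n - 1:
            flags.append(strides[i] == 1)
        else:
            flags.append(strides[i] == strides[i + 1] * sizes[i + 1])
    return flags

# Dense row-major layout: fully contiguous.
dense = contiguity_flags((4, 3), (3, 1))
# Rows padded from 3 to 4 elements: outer dimension not contiguous.
padded = contiguity_flags((4, 3), (4, 1))
# Transposed view: neither dimension contiguous in row-major order.
transposed = contiguity_flags((3, 4), (1, 3))
```

This is why an expr-eval segment that merely transposes or slices its input can produce an output whose flags are not all-true, which downstream segments need to know.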
Review thread on `getMaybeHeuristicsFor`:

Collaborator: I'm losing track of the code. Does group_runtime_inputs contain meta tensors or real tensors at this point? The setDeviceIndex call seems to say they are real tensors.

Author: IIUC, in prepareInputs, group_runtime_inputs contains real tensors (though inferOutputShapeAndContiguousStrides still returns meta tensors), while in getMaybeHeuristicsFor, group_runtime_inputs contains meta tensors.

Collaborator: Got it. Should the setDeviceIndex call at line 419 be removed? Is it safe, or necessary? (I don't think your PR changes the situation; just out of curiosity.)