Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 26 additions & 16 deletions csrc/evaluator_common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ void PrecomputedValues::bindParallelExtents(
auto raw_val = launch_constraint.getRawVal(it.first);
if (raw_val > 0) {
for (auto extent : it.second) {
bindValue(extent->evaluatorIndex(), raw_val);
bindValue(extent->evaluatorIndex(), raw_val, extent);
}
}
}
Expand All @@ -198,7 +198,10 @@ void PrecomputedValues::bindConcreteParallelTypeValue(
auto index_list_it = thread_dim_value_indices_.find(pt);
if (index_list_it != thread_dim_value_indices_.end()) {
for (auto index : *(index_list_it->second)) {
bindValue(index, value);
const Val* ir_node = (index >= 0 && index < (int)symbols_.size())
? symbols_[index]
: nullptr;
bindValue(index, value, ir_node);
}
}
}
Expand Down Expand Up @@ -228,7 +231,7 @@ void PrecomputedValues::bindValues(
bindTensorMetaData(tv, tensor);
}
} else {
bindValue(input->evaluatorIndex(), args[i]);
bindValue(input->evaluatorIndex(), args[i], input);
}
}
}
Expand Down Expand Up @@ -360,15 +363,19 @@ void PrecomputedValues::initializeNamedScalars() {
void PrecomputedValues::validate() {
FUSER_PERF_SCOPE("PrecomputedValuess::Validate");
using namespace PolymorphicValue_functions;
for (const auto& it : binding_log_) {
NVF_ERROR(
isSame(values_[it.first], it.second),
"Precomputed values failed to validate.",
"\nSomething unexpected changed between the compilation and "
"execution.\n",
values_[it.first],
" != ",
it.second);
for (const auto& [index, expected_value, ir_node] : binding_log_) {
if (!isSame(values_[index], expected_value)) {
std::stringstream error_msg;
error_msg << "Precomputed values failed to validate.\n"
<< "Something unexpected changed between the compilation and "
"execution.\n";
if (ir_node != nullptr) {
error_msg << "IR node: " << ir_node->toString() << "\n";
}
error_msg << "Computed value: " << values_[index] << "\n"
<< "Expected value: " << expected_value;
NVF_ERROR(false, error_msg.str());
}
}
has_valid_values_ = true;
}
Expand All @@ -391,12 +398,15 @@ void PrecomputedValues::bindTensorMetaData(
if (id->isBroadcast()) {
// DIDs are ignored for broadcast. See MultideviceShardingTest.Broadcast
// and .ExpandedBroadcast.
bindValue(id->extent()->evaluatorIndex(), 1L);
bindValue(id->extent()->evaluatorIndex(), 1L, id->extent());
if (id->hasExpandedExtent()) {
bindValue(id->expandedExtent()->evaluatorIndex(), dim_size);
bindValue(
id->expandedExtent()->evaluatorIndex(),
dim_size,
id->expandedExtent());
}
} else {
bindValue(id->extent()->evaluatorIndex(), dim_size);
bindValue(id->extent()->evaluatorIndex(), dim_size, id->extent());
}
}

Expand Down Expand Up @@ -424,7 +434,7 @@ void PrecomputedValues::bindTensorMetaData(
tv->toString(),
" with input tensor ",
tensor);
bindValue(metadata_val->evaluatorIndex(), metadata);
bindValue(metadata_val->evaluatorIndex(), metadata, metadata_val);
}

NaiveValueMachine::NaiveValueMachine(PrecomputedValues& precomputed_values)
Expand Down
15 changes: 10 additions & 5 deletions csrc/evaluator_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -211,17 +211,22 @@ class PrecomputedValues {

//! Bind concrete value to the given index
//! if the index is valid.
void bindValue_(int index, const PolymorphicValue& value) {
//! \param ir_node Used to track the original IR node for the index, only
//! used for improving error messages
void bindValue_(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a comment about the optional ir_node parameter?

int index,
const PolymorphicValue& value,
const Val* ir_node = nullptr) {
if (index < 0 || is_constant_[index]) {
return;
}
defined_[index] = true;
values_[index] = value;
binding_log_.emplace_back(index, value);
binding_log_.emplace_back(index, value, ir_node);
}
template <typename T>
void bindValue(int index, const T& value) {
bindValue_(index, PolymorphicValue(value));
void bindValue(int index, const T& value, const Val* ir_node = nullptr) {
bindValue_(index, PolymorphicValue(value), ir_node);
}

//! Invalidate all computed values in the workspace.
Expand Down Expand Up @@ -292,7 +297,7 @@ class PrecomputedValues {
//! An internal log to keep track of all the bindings
//! used in each evaluation cycle. To be used for
//! consistency check.
std::vector<std::pair<int, PolymorphicValue>> binding_log_;
std::vector<std::tuple<int, PolymorphicValue, const Val*>> binding_log_;

//! Integer runtime for realizing the values computations.
std::unique_ptr<NaiveValueMachine> value_machine_;
Expand Down