diff --git a/csrc/evaluator_common.cpp b/csrc/evaluator_common.cpp index 5e983777b04..7b71892f756 100644 --- a/csrc/evaluator_common.cpp +++ b/csrc/evaluator_common.cpp @@ -186,7 +186,7 @@ void PrecomputedValues::bindParallelExtents( auto raw_val = launch_constraint.getRawVal(it.first); if (raw_val > 0) { for (auto extent : it.second) { - bindValue(extent->evaluatorIndex(), raw_val); + bindValue(extent->evaluatorIndex(), raw_val, extent); } } } @@ -198,7 +198,10 @@ void PrecomputedValues::bindConcreteParallelTypeValue( auto index_list_it = thread_dim_value_indices_.find(pt); if (index_list_it != thread_dim_value_indices_.end()) { for (auto index : *(index_list_it->second)) { - bindValue(index, value); + const Val* ir_node = (index >= 0 && index < (int)symbols_.size()) + ? symbols_[index] + : nullptr; + bindValue(index, value, ir_node); } } } @@ -228,7 +231,7 @@ void PrecomputedValues::bindValues( bindTensorMetaData(tv, tensor); } } else { - bindValue(input->evaluatorIndex(), args[i]); + bindValue(input->evaluatorIndex(), args[i], input); } } } @@ -360,15 +363,19 @@ void PrecomputedValues::initializeNamedScalars() { void PrecomputedValues::validate() { FUSER_PERF_SCOPE("PrecomputedValuess::Validate"); using namespace PolymorphicValue_functions; - for (const auto& it : binding_log_) { - NVF_ERROR( - isSame(values_[it.first], it.second), - "Precomputed values failed to validate.", - "\nSomething unexpected changed between the compilation and " - "execution.\n", - values_[it.first], - " != ", - it.second); + for (const auto& [index, expected_value, ir_node] : binding_log_) { + if (!isSame(values_[index], expected_value)) { + std::stringstream error_msg; + error_msg << "Precomputed values failed to validate.\n" + << "Something unexpected changed between the compilation and " + "execution.\n"; + if (ir_node != nullptr) { + error_msg << "IR node: " << ir_node->toString() << "\n"; + } + error_msg << "Computed value: " << values_[index] << "\n" + << "Expected value: " << expected_value; + NVF_ERROR(false, error_msg.str()); + } } has_valid_values_ = true; } @@ -391,12 +398,15 @@ void PrecomputedValues::bindTensorMetaData( if (id->isBroadcast()) { // DIDs are ignored for broadcast. See MultideviceShardingTest.Broadcast // and .ExpandedBroadcast. - bindValue(id->extent()->evaluatorIndex(), 1L); + bindValue(id->extent()->evaluatorIndex(), 1L, id->extent()); if (id->hasExpandedExtent()) { - bindValue(id->expandedExtent()->evaluatorIndex(), dim_size); + bindValue( + id->expandedExtent()->evaluatorIndex(), + dim_size, + id->expandedExtent()); } } else { - bindValue(id->extent()->evaluatorIndex(), dim_size); + bindValue(id->extent()->evaluatorIndex(), dim_size, id->extent()); } } @@ -424,7 +434,7 @@ void PrecomputedValues::bindTensorMetaData( tv->toString(), " with input tensor ", tensor); - bindValue(metadata_val->evaluatorIndex(), metadata); + bindValue(metadata_val->evaluatorIndex(), metadata, metadata_val); } NaiveValueMachine::NaiveValueMachine(PrecomputedValues& precomputed_values) diff --git a/csrc/evaluator_common.h b/csrc/evaluator_common.h index aabf029ed4d..d674e3bf031 100644 --- a/csrc/evaluator_common.h +++ b/csrc/evaluator_common.h @@ -211,17 +211,22 @@ class PrecomputedValues { //! Bind concrete value to the given index //! if the index is valid. - void bindValue_(int index, const PolymorphicValue& value) { + //! \param ir_node Used to track the original IR node for the index, only + //! used for improving error messages + void bindValue_( + int index, + const PolymorphicValue& value, + const Val* ir_node = nullptr) { if (index < 0 || is_constant_[index]) { return; } defined_[index] = true; values_[index] = value; - binding_log_.emplace_back(index, value); + binding_log_.emplace_back(index, value, ir_node); } template - void bindValue(int index, const T& value) { - bindValue_(index, PolymorphicValue(value)); + void bindValue(int index, const T& value, const Val* ir_node = nullptr) { + bindValue_(index, PolymorphicValue(value), ir_node); } //! Invalidate all computed values in the workspace. @@ -292,7 +297,7 @@ class PrecomputedValues { //! An internal log to keep track of all the bindings //! used in each evaluation cycle. To be used for //! consistency check. - std::vector> binding_log_; + std::vector> binding_log_; //! Integer runtime for realizing the values computations. std::unique_ptr value_machine_;