Skip to content

Commit

Permalink
Merge pull request #38075 from mantidproject/ewm7196-nans-flag
Browse files Browse the repository at this point in the history
Add flag to `CompareWorkspaces` so users can specify `NaN == NaN` behavior
  • Loading branch information
KedoKudo authored Oct 1, 2024
2 parents f94d911 + b2acf19 commit db4f29d
Show file tree
Hide file tree
Showing 38 changed files with 598 additions and 194 deletions.
8 changes: 6 additions & 2 deletions Framework/API/inc/MantidAPI/Column.h
Original file line number Diff line number Diff line change
Expand Up @@ -157,9 +157,13 @@ class MANTID_API_DLL Column {
return vec;
}

virtual bool equals(const Column &, double) const { throw std::runtime_error("equals not implemented"); };
virtual bool equals(const Column &, double, bool const = false) const {
throw std::runtime_error("equals not implemented");
};

virtual bool equalsRelErr(const Column &, double) const { throw std::runtime_error("equals not implemented"); };
virtual bool equalsRelErr(const Column &, double, bool const = false) const {
throw std::runtime_error("equals not implemented");
};

protected:
/// Sets the new column size.
Expand Down
4 changes: 2 additions & 2 deletions Framework/Algorithms/inc/MantidAlgorithms/CompareWorkspaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ class MANTID_ALGORITHMS_DLL CompareWorkspaces final : public API::Algorithm {
"testing process.";
}

static bool withinAbsoluteTolerance(double x1, double x2, double atol);
static bool withinRelativeTolerance(double x1, double x2, double rtol);
static bool withinAbsoluteTolerance(double x1, double x2, double atol, bool const nanEqual = false);
static bool withinRelativeTolerance(double x1, double x2, double rtol, bool const nanEqual = false);

private:
/// Initialise algorithm
Expand Down
94 changes: 48 additions & 46 deletions Framework/Algorithms/src/CompareWorkspaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "MantidGeometry/Crystal/IPeak.h"
#include "MantidGeometry/Instrument/ComponentInfo.h"
#include "MantidGeometry/Instrument/DetectorInfo.h"
#include "MantidKernel/FloatingPointComparison.h"
#include "MantidKernel/Unit.h"

namespace Mantid::Algorithms {
Expand Down Expand Up @@ -71,23 +72,18 @@ int compareEventLists(Kernel::Logger &logger, const EventList &el1, const EventL
const auto &e1 = events1[i];
const auto &e2 = events2[i];

bool diffpulse = false;
bool difftof = false;
bool diffweight = false;
if (std::abs(e1.pulseTime().totalNanoseconds() - e2.pulseTime().totalNanoseconds()) > tolPulse) {
diffpulse = true;
++numdiffpulse;
}
if (std::abs(e1.tof() - e2.tof()) > tolTof) {
difftof = true;
++numdifftof;
}
bool diffpulse =
!withinAbsoluteDifference(e1.pulseTime().totalNanoseconds(), e2.pulseTime().totalNanoseconds(), tolPulse);
bool difftof = !withinAbsoluteDifference(e1.tof(), e2.tof(), tolTof);
bool diffweight = !withinAbsoluteDifference(e1.weight(), e2.weight(), tolWeight);
if (diffpulse && difftof)
++numdiffboth;
if (std::abs(e1.weight() - e2.weight()) > tolWeight) {
diffweight = true;
++numdiffweight;
}
numdiffboth++;
if (diffpulse)
numdiffpulse++;
if (difftof)
numdifftof++;
if (diffweight)
numdiffweight++;

bool same = (!diffpulse) && (!difftof) && (!diffweight);
if (!same) {
Expand Down Expand Up @@ -148,6 +144,8 @@ void CompareWorkspaces::init() {
"Very often such logs are huge so making it true should be "
"the last option.");

declareProperty("NaNsEqual", false, "Whether NaN values should compare as equal with other NaN values.");

declareProperty("NumberMismatchedSpectraToPrint", 1, "Number of mismatched spectra from lowest to be listed. ");

declareProperty("DetailedPrintIndex", EMPTY_INT(), "Mismatched spectra that will be printed out in details. ");
Expand All @@ -172,13 +170,14 @@ void CompareWorkspaces::exec() {
m_parallelComparison = false;

double const tolerance = getProperty("Tolerance");
bool const nanEqual = getProperty("NaNsEqual");
if (getProperty("ToleranceRelErr")) {
this->m_compare = [tolerance](double const x1, double const x2) -> bool {
return CompareWorkspaces::withinRelativeTolerance(x1, x2, tolerance);
this->m_compare = [tolerance, nanEqual](double const x1, double const x2) -> bool {
return CompareWorkspaces::withinRelativeTolerance(x1, x2, tolerance, nanEqual);
};
} else {
this->m_compare = [tolerance](double const x1, double const x2) -> bool {
return CompareWorkspaces::withinAbsoluteTolerance(x1, x2, tolerance);
this->m_compare = [tolerance, nanEqual](double const x1, double const x2) -> bool {
return CompareWorkspaces::withinAbsoluteTolerance(x1, x2, tolerance, nanEqual);
};
}

Expand Down Expand Up @@ -1049,10 +1048,11 @@ void CompareWorkspaces::doPeaksComparison(PeaksWorkspace_sptr tws1, PeaksWorkspa
}

const bool isRelErr = getProperty("ToleranceRelErr");
const bool checkAllData = getProperty("CheckAllData");
for (int i = 0; i < tws1->getNumberPeaks(); i++) {
const Peak &peak1 = tws1->getPeak(i);
const Peak &peak2 = tws2->getPeak(i);
for (size_t j = 0; j < tws1->columnCount(); j++) {
for (std::size_t j = 0; j < tws1->columnCount(); j++) {
std::shared_ptr<const API::Column> col = tws1->getColumn(j);
std::string name = col->name();
double s1 = 0.0;
Expand Down Expand Up @@ -1127,7 +1127,8 @@ void CompareWorkspaces::doPeaksComparison(PeaksWorkspace_sptr tws1, PeaksWorkspa
<< "value1 = " << s1 << "\n"
<< "value2 = " << s2 << "\n";
recordMismatch("Data mismatch");
return;
if (!checkAllData)
return;
}
}
}
Expand Down Expand Up @@ -1163,8 +1164,10 @@ void CompareWorkspaces::doLeanElasticPeaksComparison(const LeanElasticPeaksWorks

const double tolerance = getProperty("Tolerance");
const bool isRelErr = getProperty("ToleranceRelErr");
const bool checkAllData = getProperty("CheckAllData");
const bool nanEqual = getProperty("NaNsEqual");
for (int peakIndex = 0; peakIndex < ipws1->getNumberPeaks(); peakIndex++) {
for (size_t j = 0; j < ipws1->columnCount(); j++) {
for (std::size_t j = 0; j < ipws1->columnCount(); j++) {
std::shared_ptr<const API::Column> col = ipws1->getColumn(j);
const std::string name = col->name();
double s1 = 0.0;
Expand Down Expand Up @@ -1229,10 +1232,10 @@ void CompareWorkspaces::doLeanElasticPeaksComparison(const LeanElasticPeaksWorks
// bool mismatch = !m_compare(s1, s2)
// can replace this if/else, and isRelErr and tolerance can be deleted
if (isRelErr && name != "QLab" && name != "QSample") {
if (!withinRelativeTolerance(s1, s2, tolerance)) {
if (!withinRelativeTolerance(s1, s2, tolerance, nanEqual)) {
mismatch = true;
}
} else if (!withinAbsoluteTolerance(s1, s2, tolerance)) {
} else if (!withinAbsoluteTolerance(s1, s2, tolerance, nanEqual)) {
mismatch = true;
}
if (mismatch) {
Expand All @@ -1242,7 +1245,8 @@ void CompareWorkspaces::doLeanElasticPeaksComparison(const LeanElasticPeaksWorks
<< "value1 = " << s1 << "\n"
<< "value2 = " << s2 << "\n";
recordMismatch("Data mismatch");
return;
if (!checkAllData)
return;
}
}
}
Expand Down Expand Up @@ -1283,19 +1287,23 @@ void CompareWorkspaces::doTableComparison(const API::ITableWorkspace_const_sptr

const bool checkAllData = getProperty("CheckAllData");
const bool isRelErr = getProperty("ToleranceRelErr");
const bool nanEqual = getProperty("NaNsEqual");
const double tolerance = getProperty("Tolerance");
bool mismatch;
for (size_t i = 0; i < numCols; ++i) {
for (std::size_t i = 0; i < numCols; ++i) {
const auto c1 = tws1->getColumn(i);
const auto c2 = tws2->getColumn(i);

if (isRelErr) {
mismatch = !c1->equalsRelErr(*c2, tolerance);
mismatch = !c1->equalsRelErr(*c2, tolerance, nanEqual);
} else {
mismatch = !c1->equals(*c2, tolerance);
mismatch = !c1->equals(*c2, tolerance, nanEqual);
}
if (mismatch) {
g_log.debug() << "Table data mismatch at column " << i << "\n";
for (std::size_t j = 0; j < c1->size(); j++) {
g_log.debug() << "\t" << j << " | " << c1->cell<double>(j) << ", " << c2->cell<double>(j) << "\n";
}
recordMismatch("Table data mismatch");
if (!checkAllData) {
return;
Expand Down Expand Up @@ -1356,12 +1364,15 @@ this error is within the limits requested.
@param x1 -- first value to check difference
@param x2 -- second value to check difference
@param atol -- the tolerance of the comparison. Must be nonnegative
@param nanEqual -- whether two NaNs compare as equal
@returns true if absolute difference is within the tolerance; false otherwise
*/
bool CompareWorkspaces::withinAbsoluteTolerance(double const x1, double const x2, double const atol) {
// NOTE !(|x1-x2| > atol) is not the same as |x1-x2| <= atol
return !(std::abs(x1 - x2) > atol);
bool CompareWorkspaces::withinAbsoluteTolerance(double const x1, double const x2, double const atol,
bool const nanEqual) {
if (nanEqual && std::isnan(x1) && std::isnan(x2))
return true;
return Kernel::withinAbsoluteDifference(x1, x2, atol);
}

//------------------------------------------------------------------------------------------------
Expand All @@ -1371,24 +1382,15 @@ this error is within the limits requested.
@param x1 -- first value to check difference
@param x2 -- second value to check difference
@param rtol -- the tolerance of the comparison. Must be nonnegative
@param nanEqual -- whether two NaNs compare as equal
@returns true if relative difference is within the tolerance; false otherwise
@returns true if error or false if the relative value is within the limits requested
*/
bool CompareWorkspaces::withinRelativeTolerance(double const x1, double const x2, double const rtol) {
// calculate difference
double const num = std::abs(x1 - x2);
// return early if the values are equal
if (num == 0.0)
bool CompareWorkspaces::withinRelativeTolerance(double const x1, double const x2, double const rtol,
bool const nanEqual) {
if (nanEqual && std::isnan(x1) && std::isnan(x2))
return true;
// create the average magnitude for comparison
double const den = 0.5 * (std::abs(x1) + std::abs(x2));
// return early, possibly avoids a multiplication
// NOTE if den<1, then divsion will only make num larger
// NOTE if den<1 but num<=rtol, we cannot conclude anything
if (den <= 1.0 && num > rtol)
return false;
// NOTE !(num > rtol*den) is not the same as (num <= rtol*den)
return !(num > (rtol * den));
return Kernel::withinRelativeDifference(x1, x2, rtol);
}
} // namespace Mantid::Algorithms
2 changes: 1 addition & 1 deletion Framework/Algorithms/src/DetectorEfficiencyCor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ void DetectorEfficiencyCor::exec() {
int64_t numHists = m_inputWS->getNumberHistograms();
auto numHists_d = static_cast<double>(numHists);
const auto progStep = static_cast<int64_t>(ceil(numHists_d / 100.0));
auto &spectrumInfo = m_inputWS->spectrumInfo();
auto const &spectrumInfo = m_inputWS->spectrumInfo();

PARALLEL_FOR_IF(Kernel::threadSafe(*m_inputWS, *m_outputWS))
for (int64_t i = 0; i < numHists; ++i) {
Expand Down
115 changes: 115 additions & 0 deletions Framework/Algorithms/test/CompareWorkspacesTest.h
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,44 @@ class CompareWorkspacesTest : public CxxTest::TestSuite {
checker.resetProperties();
}

void test_NaNsEqual_true() {
if (!checker.isInitialized())
checker.initialize();

double const anan = std::numeric_limits<double>::quiet_NaN();

// a real and NaN are never equal
WorkspaceSingleValue_sptr ws1 = WorkspaceCreationHelper::createWorkspaceSingleValue(1.1);
WorkspaceSingleValue_sptr ws2 = WorkspaceCreationHelper::createWorkspaceSingleValue(anan);
// is not equal if NaNsEqual set true
TS_ASSERT_THROWS_NOTHING(checker.setProperty("NaNsEqual", true));
TS_ASSERT_THROWS_NOTHING(checker.setProperty("Workspace1", ws1));
TS_ASSERT_THROWS_NOTHING(checker.setProperty("Workspace2", ws2));
TS_ASSERT(checker.execute());
TS_ASSERT_EQUALS(checker.getPropertyValue("Result"), PROPERTY_VALUE_FALSE);
// is not equal if NaNsEqual set false
TS_ASSERT_THROWS_NOTHING(checker.setProperty("NaNsEqual", false));
TS_ASSERT_THROWS_NOTHING(checker.setProperty("Workspace1", ws1));
TS_ASSERT_THROWS_NOTHING(checker.setProperty("Workspace2", ws2));
TS_ASSERT(checker.execute());
TS_ASSERT_EQUALS(checker.getPropertyValue("Result"), PROPERTY_VALUE_FALSE);

// NaNs only compare equal if flag set
WorkspaceSingleValue_sptr ws3 = WorkspaceCreationHelper::createWorkspaceSingleValue(anan);
// is NOT equal if NaNsEqual set FALSE
TS_ASSERT_THROWS_NOTHING(checker.setProperty("NaNsEqual", false));
TS_ASSERT_THROWS_NOTHING(checker.setProperty("Workspace1", ws2));
TS_ASSERT_THROWS_NOTHING(checker.setProperty("Workspace2", ws3));
TS_ASSERT(checker.execute());
TS_ASSERT_EQUALS(checker.getPropertyValue("Result"), PROPERTY_VALUE_FALSE);
// ARE equal if NaNsEqual set TRUE
TS_ASSERT_THROWS_NOTHING(checker.setProperty("NaNsEqual", true));
TS_ASSERT_THROWS_NOTHING(checker.setProperty("Workspace1", ws2));
TS_ASSERT_THROWS_NOTHING(checker.setProperty("Workspace2", ws3));
TS_ASSERT(checker.execute());
TS_ASSERT_EQUALS(checker.getPropertyValue("Result"), PROPERTY_VALUE_TRUE);
}

void testPeaks_matches() {
if (!checker.isInitialized())
checker.initialize();
Expand Down Expand Up @@ -1193,6 +1231,83 @@ class CompareWorkspacesTest : public CxxTest::TestSuite {
TS_ASSERT_EQUALS(alg.getPropertyValue("Result"), PROPERTY_VALUE_TRUE);
}

void test_equal_tableworkspaces_match() {
std::string const col_type("double"), col_name("aColumn");
std::vector<double> col_values{1.0, 2.0, 3.0};
// create the table workspaces
Mantid::API::ITableWorkspace_sptr table1 = WorkspaceFactory::Instance().createTable();
table1->addColumn(col_type, col_name);
for (double val : col_values) {
TableRow newrow = table1->appendRow();
newrow << val;
}
Mantid::API::ITableWorkspace_sptr table2 = WorkspaceFactory::Instance().createTable();
table2->addColumn(col_type, col_name);
for (double val : col_values) {
TableRow newrow = table2->appendRow();
newrow << val;
}

Mantid::Algorithms::CompareWorkspaces alg;
alg.initialize();
TS_ASSERT_THROWS_NOTHING(alg.setProperty("Workspace1", table1));
TS_ASSERT_THROWS_NOTHING(alg.setProperty("Workspace2", table2));
TS_ASSERT(alg.execute());
TS_ASSERT_EQUALS(alg.getPropertyValue("Result"), PROPERTY_VALUE_TRUE);
}

void test_tableworkspace_NaNs_passes_with_flag() {
std::string const col_type("double"), col_name("aColumn");
std::vector<double> col_values{1.0, 2.0, std::numeric_limits<double>::quiet_NaN()};
// create the table workspaces
Mantid::API::ITableWorkspace_sptr table1 = WorkspaceFactory::Instance().createTable();
Mantid::API::ITableWorkspace_sptr table2 = WorkspaceFactory::Instance().createTable();
table1->addColumn(col_type, col_name);
table2->addColumn(col_type, col_name);
for (double val : col_values) {
TableRow newrow1 = table1->appendRow();
newrow1 << val;
TableRow newrow2 = table2->appendRow();
newrow2 << val;
}
Mantid::Algorithms::CompareWorkspaces alg;
alg.initialize();
TS_ASSERT_THROWS_NOTHING(alg.setProperty("Workspace1", table1));
TS_ASSERT_THROWS_NOTHING(alg.setProperty("Workspace2", table2));
TS_ASSERT_THROWS_NOTHING(alg.setProperty("NaNsEqual", true));
TS_ASSERT(alg.execute());
TS_ASSERT_EQUALS(alg.getPropertyValue("Result"), PROPERTY_VALUE_TRUE);
}

void test_tableworkspace_NaNs_fails() {
std::string const col_type("double"), col_name("aColumn");
std::vector<double> col_values1{1.0, 2.0, 3.0};
std::vector<double> col_values2{1.0, 2.0, std::numeric_limits<double>::quiet_NaN()};
// create the table workspaces
Mantid::API::ITableWorkspace_sptr table1 = WorkspaceFactory::Instance().createTable();
table1->addColumn(col_type, col_name);
for (double val : col_values1) {
TableRow newrow = table1->appendRow();
newrow << val;
}
Mantid::API::ITableWorkspace_sptr table2 = WorkspaceFactory::Instance().createTable();
table2->addColumn(col_type, col_name);
for (double val : col_values2) {
TableRow newrow = table2->appendRow();
newrow << val;
}

Mantid::Algorithms::CompareWorkspaces alg;
alg.initialize();
TS_ASSERT_THROWS_NOTHING(alg.setProperty("Workspace1", table1));
TS_ASSERT_THROWS_NOTHING(alg.setProperty("Workspace2", table2));
TS_ASSERT(alg.execute());
TS_ASSERT_EQUALS(alg.getPropertyValue("Result"), PROPERTY_VALUE_FALSE);

ITableWorkspace_sptr table = AnalysisDataService::Instance().retrieveWS<TableWorkspace>("compare_msgs");
TS_ASSERT_EQUALS(table->cell<std::string>(0, 0), "Table data mismatch");
}

void test_tableworkspace_different_column_names_fails() {
auto table1 = setupTableWorkspace();
table1->getColumn(5)->setName("SomethingElse");
Expand Down
Loading

0 comments on commit db4f29d

Please sign in to comment.