Skip to content

Commit

Permalink
address comments and harden tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mahf708 committed Jul 1, 2024
1 parent e43b59f commit 5025176
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 51 deletions.
47 changes: 27 additions & 20 deletions components/eamxx/src/diagnostics/atm_tend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@ namespace scream {
AtmTendDiag::AtmTendDiag(const ekat::Comm &comm,
const ekat::ParameterList &params)
: AtmosphereDiagnostic(comm, params) {
EKAT_REQUIRE_MSG(params.isParameter("Tend Name"),
"Error! AtmTendDiag requires 'Tend Name' in its "
EKAT_REQUIRE_MSG(params.isParameter("Field Name"),
"Error! AtmTendDiag requires 'Field Name' in its "
"input parameters.\n");

m_name = m_params.get<std::string>("Tend Name");
m_name = m_params.get<std::string>("Field Name");
}

std::string AtmTendDiag::name() const { return m_name + "_atm_tend"; }
Expand All @@ -29,48 +29,55 @@ void AtmTendDiag::set_grids(
void AtmTendDiag::initialize_impl(const RunType /*run_type*/) {
const auto &f = get_field_in(m_name);
const auto &fid = f.get_header().get_identifier();
const auto &gn = fid.get_grid_name();

// Sanity checks
using namespace ShortFieldTagsNames;
const auto &layout = fid.get_layout();
EKAT_REQUIRE_MSG(f.data_type() == DataType::RealType,
"Error! FieldAtHeight only supports Real data type field.\n"
"Error! AtmTendDiag only supports Real data type field.\n"
" - field name: " +
fid.name() +
"\n"
" - field data type: " +
e2str(f.data_type()) + "\n");

using namespace ekat::units;
// The units are the same except per second
auto diag_units = fid.get_units() / s;
// TODO: set the units string correctly by appending "/s"

// All good, create the diag output
FieldIdentifier d_fid(name(), layout.clone(), fid.get_units(),
fid.get_grid_name());
FieldIdentifier d_fid(name(), layout.clone(), diag_units, gn);
m_diagnostic_output = Field(d_fid);
m_diagnostic_output.allocate_view();

// Let's also create the previous field
FieldIdentifier prev_fid(name() + "_prev", layout.clone(), fid.get_units(),
fid.get_grid_name());
m_field_prev = Field(prev_fid);
m_field_prev.allocate_view();
FieldIdentifier prev_fid(name() + "_prev", layout.clone(), diag_units, gn);
m_f_prev = Field(prev_fid);
m_f_prev.allocate_view();
}
void AtmTendDiag::compute_diagnostic_impl() {
Real var_fill_value = constants::DefaultFillValue<Real>().value;
std::int64_t dt;
auto tts = m_diagnostic_output.get_header().get_tracking().get_time_stamp();

const auto &f = get_field_in(m_name);
const auto &f = get_field_in(m_name);
const auto &curr_ts = f.get_header().get_tracking().get_time_stamp();
const auto &prev_ts = m_f_prev.get_header().get_tracking().get_time_stamp();

if(m_ts.is_valid()) {
dt = tts - m_ts;
auto ddt = static_cast<Real>(dt);
m_ts = tts;
m_field_prev.update(f, 1 / ddt, -1 / ddt);
m_diagnostic_output.deep_copy(m_field_prev);
if(prev_ts.is_valid()) {
// This diag was called before, so we have a valid value for m_f_prev,
// and can compute the tendency
dt = curr_ts - prev_ts;
m_f_prev.update(f, 1.0 / dt, -1.0 / dt);
m_diagnostic_output.deep_copy(m_f_prev);
} else {
// This is the first time we evaluate this diag. We cannot compute a tend
// yet, so fill with an invalid value
m_diagnostic_output.deep_copy(var_fill_value);
m_ts = tts;
}
m_field_prev.deep_copy(f);
m_f_prev.deep_copy(f);
m_f_prev.get_header().get_tracking().update_time_stamp(curr_ts);
}

} // namespace scream
5 changes: 1 addition & 4 deletions components/eamxx/src/diagnostics/atm_tend.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,7 @@ class AtmTendDiag : public AtmosphereDiagnostic {
std::string m_name;

// Store the previous field
Field m_field_prev;

// Store a time stamp
util::TimeStamp m_ts;
Field m_f_prev;

}; // class AtmTendDiag

Expand Down
51 changes: 24 additions & 27 deletions components/eamxx/src/diagnostics/tests/atm_tend_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ std::shared_ptr<GridsManager> create_gm(const ekat::Comm &comm, const int ncols,
return gm;
}

TEST_CASE("extraaci") {
TEST_CASE("atm_tend") {
using namespace ShortFieldTagsNames;
using namespace ekat::units;

Expand All @@ -36,11 +36,9 @@ TEST_CASE("extraaci") {
// A time stamp
util::TimeStamp t0({2024, 1, 1}, {0, 0, 0});

const auto nondim = Units::nondimensional();

// Create a grids manager - single column for these tests
constexpr int nlevs = 5;
const int ngcols = 1 * comm.size();
constexpr int nlevs = 25;
const int ngcols = 25 * comm.size();

auto gm = create_gm(comm, ngcols, nlevs);
auto grid = gm->get_grid("Physics");
Expand All @@ -51,11 +49,11 @@ TEST_CASE("extraaci") {

Field qc(qc_fid);
qc.allocate_view();
qc.get_header().get_tracking().update_time_stamp(t0);

// Construct random number generator stuff
using RPDF = std::uniform_real_distribution<Real>;
RPDF pdf(0, 0.05);
RPDF pdf(0.0, 200.0);

auto engine = scream::setup_random_test();

// Construct the Diagnostics
Expand All @@ -65,47 +63,46 @@ TEST_CASE("extraaci") {

ekat::ParameterList params;
REQUIRE_THROWS(
diag_factory.create("AtmTendDiag", comm, params)); // No 'Tend Name'
diag_factory.create("AtmTendDiag", comm, params)); // No 'Field Name'

// TODO: The diag currently doesn't throw when given a phony name, need
// hardening! params.set<std::string>("Tend Name", "NoWay"); REQUIRE_THROWS
// (diag_factory.create("AtmTendDiag",comm,params)); // Bad 'Tend Name'
Real var_fill_value = constants::DefaultFillValue<Real>().value;

// Randomize
// Set time for qc and randomize its values
qc.get_header().get_tracking().update_time_stamp(t0);
randomize(qc, engine, pdf);

// Create and set up the diagnostic
params.set("grid_name", grid->name());
params.set<std::string>("Tend Name", "qc");
params.set<std::string>("Field Name", "qc");
auto diag = diag_factory.create("AtmTendDiag", comm, params);
diag->set_grids(gm);
diag->set_required_field(qc);
diag->initialize(t0, RunType::Initial);

auto qc_v = qc.get_view<Real **, Host>();
qc_v(0, 0) = 5.0;
qc.sync_to_dev();

// Run diag
diag->compute_diagnostic();
auto diag_f = diag->get_diagnostic();

Real var_fill_value = constants::DefaultFillValue<Real>().value;

// Check result: diag should be filled with var_fill_value
auto some_field = qc.clone();
some_field.deep_copy(var_fill_value);
REQUIRE(views_are_equal(diag_f, some_field));

util::TimeStamp t1({2024, 1, 2}, {0, 0, 0}); // a day later?
some_field.deep_copy(qc);

util::TimeStamp t1({2024, 1, 2}, {0, 0, 0}); // a day later
const Real a_day = 24.0 * 60.0 * 60.0; // seconds
qc.get_header().get_tracking().update_time_stamp(t1);
qc_v(0, 0) = 29.0;
qc.sync_to_dev();
// diag->initialize(t1, RunType::Initial);
// diag->update(t1, RunType::Initial);
randomize(qc, engine, pdf);

// Run diag again
diag->compute_diagnostic();
diag_f = diag->get_diagnostic();
auto diag_v = diag_f.get_view<Real **, Host>();
REQUIRE(diag_v(0, 0) == 1.0 / 3600.0);
some_field.update(qc, 1.0 / a_day, -1.0 / a_day);
REQUIRE(views_are_equal(diag_f, some_field));

// This should fail (return false):
some_field.update(qc, 1.0 / a_day, -1.0 / a_day);
REQUIRE_FALSE(views_are_equal(diag_f, some_field));
}

} // namespace scream

0 comments on commit 5025176

Please sign in to comment.