Skip to content

Commit

Permalink
added example on how to use checkpoint API
Browse files Browse the repository at this point in the history
  • Loading branch information
cnpetra committed Sep 13, 2024
1 parent 5bc2af6 commit a664e35
Show file tree
Hide file tree
Showing 5 changed files with 229 additions and 38 deletions.
61 changes: 59 additions & 2 deletions src/Drivers/Dense/NlpDenseConsEx1.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,14 @@
#include <cstdio>
#include <cassert>

#ifdef HIOP_USE_AXOM
#include <axom/sidre/core/DataStore.hpp>
#include <axom/sidre/core/Group.hpp>
#include <axom/sidre/core/View.hpp>
#include <axom/sidre/spio/IOManager.hpp>
using namespace axom;
#endif

using namespace hiop;

Ex1Meshing1D::Ex1Meshing1D(double a, double b, size_type glob_n, double r, MPI_Comm comm_)
Expand Down Expand Up @@ -178,10 +186,59 @@ void DiscretizedFunction::setFunctionValue(index_type i_global, const double& va
this->data_[i_local]=value;
}



/* DenseConsEx1 class implementation */

/**
 * Iteration callback of the example problem, used here to illustrate HiOp's checkpoint
 * (save state) API.
 *
 * When HIOP_USE_AXOM is defined and a solver object was registered (solver_ is not null),
 * the solver state is saved to an axom::sidre::Group every 5 iterations (iter > 0) and the
 * enclosing DataStore is written to file(s) using the base name "hiop_state_ex1".
 * The iterate information passed in (objective, primal/dual variables, residuals, etc.)
 * is not used by this example.
 *
 * @return false if saving the state to the sidre::Group failed (this example chooses to
 * stop HiOp in this case); true otherwise.
 */
bool DenseConsEx1::iterate_callback(int iter,
                                    double obj_value,
                                    double logbar_obj_value,
                                    int n,
                                    const double* x,
                                    const double* z_L,
                                    const double* z_U,
                                    int m_ineq,
                                    const double* s,
                                    int m,
                                    const double* g,
                                    const double* lambda,
                                    double inf_pr,
                                    double inf_du,
                                    double onenorm_pr,
                                    double mu,
                                    double alpha_du,
                                    double alpha_pr,
                                    int ls_trials)
{
#ifdef HIOP_USE_AXOM
  // Save state to sidre::Group every 5 iterations if a solver/algorithm object was provided.
  if(iter > 0 && (iter % 5 == 0) && nullptr != solver_) {
    //
    // Example of how to save HiOp state to axom::sidre::Group
    //

    // We first manufacture a Group. User code supposedly already has one.
    sidre::DataStore ds;
    sidre::Group* group = ds.getRoot()->createGroup("hiop state ex1");

    // The actual saving of state to the group.
    try {
      solver_->save_state_to_sidre_group(*group);
    } catch(const ::std::runtime_error& e) {
      // The user chooses the action when an error occurred in saving the state;
      // we choose to report the error and stop HiOp.
      printf("Failed to save HiOp state to sidre::Group. Error: [%s]\n", e.what());
      return false;
    }

    // User code can further inspect the Group or add additional info to the DataStore, with
    // the end goal of saving it to file before HiOp starts the next iteration. Here we just
    // save it, using as many files as there are ranks in the communicator.
    sidre::IOManager writer(comm);
    int n_files = 0;
    MPI_Comm_size(comm, &n_files);
    writer.write(ds.getRoot(), n_files, "hiop_state_ex1", sidre::Group::getDefaultIOProtocol());
  }
#endif
  return true;
}

/*set c to
* c(t) = 1-10*t, for 0<=t<=1/10,
* 0, for 1/10<=t<=1.
Expand Down
46 changes: 44 additions & 2 deletions src/Drivers/Dense/NlpDenseConsEx1.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,17 @@
#define MPI_Comm int
#endif

#ifdef HIOP_USE_AXOM
namespace axom {
namespace sidre {
// forward declarations
class DataStore;
class Group;
}
}
#endif


#include <iostream>

/* Example 1: a simple infinite-dimensional QP in the optimiz. function variable x:[0,1]->R
Expand Down Expand Up @@ -78,7 +89,7 @@ class Ex1Meshing1D
MPI_Comm comm;
int my_rank, comm_size;
index_type* col_partition;

friend class DiscretizedFunction;

private:
Expand Down Expand Up @@ -112,7 +123,9 @@ class DenseConsEx1 : public hiop::hiopInterfaceDenseConstraints
{
public:
DenseConsEx1(int n_mesh_elem=100, double mesh_ratio=1.0)
: n_vars(n_mesh_elem), comm(MPI_COMM_WORLD)
: n_vars(n_mesh_elem),
comm(MPI_COMM_WORLD),
solver_(nullptr)
{
//create the members
_mesh = new Ex1Meshing1D(0.0,1.0, n_vars, mesh_ratio, comm);
Expand Down Expand Up @@ -218,6 +231,31 @@ class DenseConsEx1 : public hiop::hiopInterfaceDenseConstraints
}
return true;
}

/// Stores the pointer to the solver/algorithm object in solver_. The pointer is only
/// stored here; this method neither takes ownership nor dereferences it.
inline void set_solver(hiop::hiopAlgFilterIPM* alg_obj)
{
  this->solver_ = alg_obj;
}

/**
 * Iteration callback. The implementation (NlpDenseConsEx1.cpp) uses only the iteration
 * number `iter`: when HIOP_USE_AXOM is defined and a solver was registered via set_solver,
 * it saves the solver state to an axom::sidre::Group every 5 iterations and writes it to file.
 *
 * The remaining arguments carry the state of the current iterate and are not used by this
 * example. NOTE(review): presumably this overrides the callback of the HiOp interface base
 * class -- confirm against hiopInterface.hpp.
 *
 * @return false if saving the state failed (the example chooses to stop HiOp), true otherwise.
 */
bool iterate_callback(int iter,
double obj_value,
double logbar_obj_value,
int n,
const double* x,
const double* z_L,
const double* z_U,
int m_ineq,
const double* s,
int m,
const double* g,
const double* lambda,
double inf_pr,
double inf_du,
double onenorm_pr,
double mu,
double alpha_du,
double alpha_pr,
int ls_trials);
private:
int n_vars;
MPI_Comm comm;
Expand All @@ -228,6 +266,10 @@ class DenseConsEx1 : public hiop::hiopInterfaceDenseConstraints
DiscretizedFunction* c;
DiscretizedFunction* x; //proxy for taking hiop's variable in and working with it as a function

/// Pointer to the solver, to be used to checkpoint
hiop::hiopAlgFilterIPM* solver_;

private:
//populates the linear term c
void set_c();
};
Expand Down
108 changes: 95 additions & 13 deletions src/Drivers/Dense/NlpDenseConsEx1Driver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,21 @@
#include <cstdlib>
#include <string>

#ifdef HIOP_USE_AXOM
#include <axom/sidre/core/DataStore.hpp>
#include <axom/sidre/core/Group.hpp>
#include <axom/sidre/core/View.hpp>
#include <axom/sidre/spio/IOManager.hpp>
using namespace axom;
#endif


using namespace hiop;

static bool self_check(size_type n, double obj_value);
static bool do_load_checkpoint_test(const size_type& mesh_size,
const double& ratio,
const double& obj_val_expected);

static bool parse_arguments(int argc, char **argv, size_type& n, double& distortion_ratio, bool& self_check)
{
Expand Down Expand Up @@ -67,24 +79,27 @@ int main(int argc, char **argv)
err = MPI_Init(&argc, &argv); assert(MPI_SUCCESS==err);
err = MPI_Comm_rank(MPI_COMM_WORLD,&rank); assert(MPI_SUCCESS==err);
err = MPI_Comm_size(MPI_COMM_WORLD,&numRanks); assert(MPI_SUCCESS==err);
if(0==rank) printf("Support for MPI is enabled\n");
if(0==rank) {
printf("Support for MPI is enabled\n");
}
#endif
bool selfCheck; size_type mesh_size; double ratio;
if(!parse_arguments(argc, argv, mesh_size, ratio, selfCheck)) { usage(argv[0]); return 1;}

bool selfCheck;
size_type mesh_size;
double ratio;
double objective = 0.;
if(!parse_arguments(argc, argv, mesh_size, ratio, selfCheck)) {
usage(argv[0]);
return 1;
}

DenseConsEx1 problem(mesh_size, ratio);
//if(rank==0) printf("interface created\n");
hiop::hiopNlpDenseConstraints nlp(problem);
//if(rank==0) printf("nlp formulation created\n");

//nlp.options->SetIntegerValue("verbosity_level", 4);
//nlp.options->SetNumericValue("tolerance", 1e-4);
//nlp.options->SetStringValue("duals_init", "zero");
//nlp.options->SetIntegerValue("max_iter", 2);

hiop::hiopAlgFilterIPM solver(&nlp);
problem.set_solver(&solver);

hiop::hiopSolveStatus status = solver.run();
double objective = solver.getObjective();
objective = solver.getObjective();

//this is used for testing when the driver is called with -selfcheck
if(selfCheck) {
Expand All @@ -97,7 +112,19 @@ int main(int argc, char **argv)
}
}

if(0==rank) printf("Objective: %18.12e\n", objective);
if(0==rank) {
printf("Objective: %18.12e\n", objective);
}

#ifdef HIOP_USE_AXOM
// example/test for HiOp's load checkpoint API.
if(!do_load_checkpoint_test(mesh_size, ratio, objective)) {
if(rank==0) {
printf("Load checkpoint and restart test failed.");
}
return -1;
}
#endif
#ifdef HIOP_USE_MPI
MPI_Finalize();
#endif
Expand Down Expand Up @@ -134,3 +161,58 @@ static bool self_check(size_type n, double objval)

return true;
}

/**
 * An illustration of how to use the load_state_from_sidre_group API method of HiOp's
 * algorithm class.
 *
 * Recreates the problem, NLP formulation, and solver objects (as a new job would), reads the
 * checkpoint file "hiop_state_ex1.root" into a sidre::DataStore, loads the solver state from
 * the group "hiop state ex1" (the group name used by DenseConsEx1::iterate_callback when
 * saving), and runs the solver from the loaded state.
 *
 * @param mesh_size number of mesh elements passed to the DenseConsEx1 constructor
 * @param ratio mesh ratio passed to the DenseConsEx1 constructor
 * @param obj_val_expected objective value the restarted solve is compared against
 *
 * @return false if the checkpoint cannot be read, the group is missing, loading the state
 * fails, or the objective obtained differs from obj_val_expected; true otherwise. Always
 * returns true when HIOP_USE_AXOM is not defined.
 */
static bool do_load_checkpoint_test(const size_type& mesh_size,
                                    const double& ratio,
                                    const double& obj_val_expected)
{
#ifdef HIOP_USE_AXOM
  // Pretend this is a new job and recreate the HiOp objects.
  DenseConsEx1 problem(mesh_size, ratio);
  hiop::hiopNlpDenseConstraints nlp(problem);

  hiop::hiopAlgFilterIPM solver(&nlp);

  //
  // Example of how to use load_state_from_sidre_group to warm-start
  //

  // Supposedly, the user code should have the group in hand before asking HiOp to load from it.
  // We manufacture it by loading a sidre checkpoint file. Here the checkpoint file
  // "hiop_state_ex1.root" was created from the interface class' iterate_callback method
  // (saved every 5 iterations).
  sidre::DataStore ds;

  try {
    sidre::IOManager reader(MPI_COMM_WORLD);
    reader.read(ds.getRoot(), "hiop_state_ex1.root", false);
  } catch(const std::exception& e) {
    printf("Failed to read checkpoint file. Error: [%s]\n", e.what());
    return false;
  }

  // The group name must match the one used when the state was saved. Guard against a missing
  // group, since the pointer is dereferenced below.
  const sidre::Group* group = ds.getRoot()->getGroup("hiop state ex1");
  if(nullptr == group) {
    printf("Failed to find group [hiop state ex1] in the checkpoint file.\n");
    return false;
  }

  // The actual API call.
  try {
    solver.load_state_from_sidre_group(*group);
  } catch(const std::runtime_error& e) {
    printf("Failed to load from sidre::group. Error: [%s]\n", e.what());
    return false;
  }

  solver.run();
  const double obj_val = solver.getObjective();
  // NOTE(review): exact floating-point comparison is kept from the original; confirm that the
  // restarted solve is expected to reproduce the objective bit-for-bit, or switch to a tolerance.
  if(obj_val != obj_val_expected) {
    return false;
  }

#endif

  return true;
}
47 changes: 26 additions & 21 deletions src/Optimization/hiopAlgFilterIPM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -922,7 +922,8 @@ void hiopAlgFilterIPMBase::displayTerminationMsg()
///////////////////////////////////////////////////////////////////////////////////////////////////
hiopAlgFilterIPMQuasiNewton::hiopAlgFilterIPMQuasiNewton(hiopNlpDenseConstraints* nlp_in,
const bool within_FR)
: hiopAlgFilterIPMBase(nlp_in, within_FR)
: hiopAlgFilterIPMBase(nlp_in, within_FR),
load_state_api_called_(false)
{
nlpdc = nlp_in;
reload_options();
Expand Down Expand Up @@ -984,36 +985,38 @@ hiopSolveStatus hiopAlgFilterIPMQuasiNewton::run()
nlp->runStats.tmOptimizTotal.start();

iter_num_ = 0;
iter_num_total_ = 0;

//
// starting point:
// - user provided (with slack adjustments and lsq eq. duals initialization)
// or
// - loaded checkpoint
// - load checkpoint API (method load_state_from_sidre_group) called before calling this method
// - checkpoint from file (option "checkpoint_load_on_start")
//
if(nlp->options->GetString("checkpoint_load_on_start") != "yes") {
if(nlp->options->GetString("checkpoint_load_on_start") != "yes" && !load_state_api_called_) {
//this also evaluates the nlp
startingProcedure(*it_curr, _f_nlp, *_c, *_d, *_grad_f, *_Jac_c, *_Jac_d);
_mu=mu0;
iter_num_total_ = 0;
} else {
//
//checkpoint load
//
//load from file: will populate it_curr, _Hess_lagr, and algorithmic parameters
auto chkpnt_ok = load_state_from_file(nlp->options->GetString("checkpoint_file"));
if(chkpnt_ok) {
//additionally: need to evaluate the nlp
if(!this->evalNlp_noHess(*it_curr, _f_nlp, *_c, *_d, *_grad_f, *_Jac_c, *_Jac_d)) {
nlp->log->printf(hovError, "Failure in evaluating user NLP functions at loaded checkpoint.");
return Error_In_User_Function;
if(!load_state_api_called_) {
//
//checkpoint load from file
//
//load from file: will populate it_curr, _Hess_lagr, and algorithmic parameters
auto chkpnt_ok = load_state_from_file(nlp->options->GetString("checkpoint_file"));
if(!chkpnt_ok) {
nlp->log->printf(hovWarning, "Using default starting procedure (no checkpoint load!).\n");
iter_num_total_ = 0;
//fall back on the default starting procedure (it also evaluates the nlp)
startingProcedure(*it_curr, _f_nlp, *_c, *_d, *_grad_f, *_Jac_c, *_Jac_d);
_mu=mu0;
iter_num_total_ = 0;
}
} else {
nlp->log->printf(hovWarning, "Using default starting procedure (no checkpoint load!).\n");
iter_num_total_ = 0;
//fall back on the default starting procedure (it also evaluates the nlp)
startingProcedure(*it_curr, _f_nlp, *_c, *_d, *_grad_f, *_Jac_c, *_Jac_d);
_mu=mu0;
}
//additionally: need to evaluate the nlp
if(!this->evalNlp_noHess(*it_curr, _f_nlp, *_c, *_d, *_grad_f, *_Jac_c, *_Jac_d)) {
nlp->log->printf(hovError, "Failure in evaluating user NLP functions at loaded checkpoint.");
return Error_In_User_Function;
}
solver_status_ = NlpSolve_SolveNotCalled;
}
Expand Down Expand Up @@ -1641,6 +1644,8 @@ void hiopAlgFilterIPMQuasiNewton::save_state_to_sidre_group(::axom::sidre::Group

void hiopAlgFilterIPMQuasiNewton::load_state_from_sidre_group(const sidre::Group& group)
{
load_state_api_called_ = true;

//metadata

//algorithmic parameters
Expand Down
5 changes: 5 additions & 0 deletions src/Optimization/hiopAlgFilterIPM.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,11 @@ class hiopAlgFilterIPMQuasiNewton : public hiopAlgFilterIPMBase

private:
hiopNlpDenseConstraints* nlpdc;
/// @brief Indicates whether the load checkpoint API (load_state_from_sidre_group) was called
/// prior to the run method.
/// Declared unconditionally (not under HIOP_USE_AXOM): the constructor's initializer list and
/// run() reference this member without a preprocessor guard, so guarding only the declaration
/// breaks builds without Axom.
bool load_state_api_called_;

private:
hiopAlgFilterIPMQuasiNewton() : hiopAlgFilterIPMBase(NULL) {};
hiopAlgFilterIPMQuasiNewton(const hiopAlgFilterIPMQuasiNewton& ) : hiopAlgFilterIPMBase(NULL){};
Expand Down

0 comments on commit a664e35

Please sign in to comment.