-
Notifications
You must be signed in to change notification settings - Fork 35
Tupek/gretl refactor #1432
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Tupek/gretl refactor #1432
Changes from 66 commits
4ccb49e
77b56ab
f7e758b
c19b7ba
4ec0b48
b72a057
ae86ce3
0746ec4
1ce7e8d
2490a51
471dfa5
a928494
449b438
bb271cf
dbe1feb
f5898c5
be0ea51
96fcff9
37c76a7
61fd848
44107f6
c90d23f
0469141
ba58d95
02c56fd
37c7340
ddee2ce
80d7b4d
fce33b1
094c1d0
0474505
5a7c418
1cb451f
fa0e7cf
9e2b7a9
feb37e1
e6ce787
6206f9e
065a335
0b72bd6
c6078f6
457a261
1cc225b
dda8d50
19a6bed
bd16a6b
6458244
ea31d01
40b10a6
b648971
0d0fb7a
1b7366d
450e7b9
3118646
13440ea
b620140
3e1423f
c7a4c66
9a67661
6cbf528
d34c80b
d8caccd
fd371bf
5e3204d
eadb07f
40fab88
e8f0ebb
29cf89a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,42 @@ | ||
|
|
||
| set(serac_gretl_sources | ||
| data_store.cpp | ||
| state_base.cpp | ||
| vector_state.cpp | ||
| ) | ||
|
|
||
| set(serac_gretl_headers | ||
| checkpoint.hpp | ||
| data_store.hpp | ||
| test_utils.hpp | ||
| state_base.hpp | ||
| state.hpp | ||
| create_state.hpp | ||
| upstream_state.hpp | ||
| double_state.hpp | ||
| vector_state.hpp | ||
| print_utils.hpp | ||
| ) | ||
|
|
||
| blt_add_library( | ||
| NAME serac_gretl | ||
| SOURCES ${serac_gretl_sources} | ||
| HEADERS ${serac_gretl_headers} | ||
| ) | ||
|
|
||
| target_include_directories(serac_gretl PUBLIC | ||
| $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/../..> | ||
| $<INSTALL_INTERFACE:include> | ||
| ) | ||
|
|
||
| install(FILES ${serac_gretl_headers} DESTINATION include/serac/gretl ) | ||
|
|
||
| install(TARGETS serac_gretl | ||
| EXPORT serac-targets | ||
| DESTINATION lib | ||
| ) | ||
|
|
||
| if(SERAC_ENABLE_TESTS) | ||
| add_subdirectory(tests) | ||
| endif() | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,237 @@ | ||||||
| // Copyright (c) Lawrence Livermore National Security, LLC and | ||||||
| // other Serac Project Developers. See the top-level LICENSE file for | ||||||
| // details. | ||||||
| // | ||||||
| // SPDX-License-Identifier: (BSD-3-Clause) | ||||||
|
|
||||||
| /** | ||||||
| * @file checkpoint.hpp | ||||||
| */ | ||||||
|
|
||||||
| #pragma once | ||||||
|
|
||||||
| #include <set> | ||||||
| #include <map> | ||||||
| #include <ostream> | ||||||
| #include <iostream> | ||||||
| #include <cassert> | ||||||
| #include <limits> | ||||||
|
|
||||||
| /// @brief gretl_assert that prints line and file info before throwing in release and halting in debug | ||||||
| #define gretl_assert(x) \ | ||||||
| if (!(x)) \ | ||||||
| throw std::runtime_error{"Error on line " + std::to_string(__LINE__) + " in file " + std::string(__FILE__)}; \ | ||||||
| assert(x); | ||||||
|
|
||||||
| /// @brief gretl_assert_msg that prints message, line and file info before throwing in release and halting in debug | ||||||
| #define gretl_assert_msg(x, msg_name_) \ | ||||||
| if (!(x)) \ | ||||||
| throw std::runtime_error{"Error on line " + std::to_string(__LINE__) + " in file " + std::string(__FILE__) + \ | ||||||
| std::string(", ") + std::string(msg_name_)}; \ | ||||||
| assert(x); | ||||||
|
|
||||||
| namespace gretl { | ||||||
|
|
||||||
| struct Unit {}; | ||||||
|
||||||
|
|
||||||
| /// @brief checkpoint struct which tracks level and step per "Minimal Repetition Dynamic Checkpointing Algorithm for | ||||||
| /// Unsteady Adjoint Calculation", Wang, et al. , 2009. | ||||||
| struct Checkpoint { | ||||||
| size_t level; ///< level | ||||||
| size_t step; ///< step | ||||||
| static constexpr size_t infinity() | ||||||
| { | ||||||
| return std::numeric_limits<size_t>::max(); | ||||||
| } ///< The largest possible step and level value | ||||||
| }; | ||||||
|
|
||||||
| /// @brief comparison operator between two checkpoints to determine which is most disposable per the dynamic | ||||||
| /// checkpointing algorithm | ||||||
| inline bool operator<(const Checkpoint& a, const Checkpoint& b) | ||||||
| { | ||||||
| if (a.level == Checkpoint::infinity() && b.level == Checkpoint::infinity()) { | ||||||
| return a.step > b.step; | ||||||
| } | ||||||
| if (a.level == Checkpoint::infinity()) return false; | ||||||
| if (b.level == Checkpoint::infinity()) return true; | ||||||
| return a.step > b.step; | ||||||
| } | ||||||
|
|
||||||
| /// @brief output stream for a single checkpoint | ||||||
| inline std::ostream& operator<<(std::ostream& stream, const Checkpoint& p); | ||||||
|
|
||||||
| /// @brief CheckpointManager class which encapsulates the logic of when and which steps should be dynamically saved a | ||||||
| /// fetched | ||||||
| struct CheckpointManager { | ||||||
| static constexpr size_t invalidCheckpointIndex = | ||||||
| std::numeric_limits<size_t>::max(); ///< magic number of invalid checkpoint | ||||||
|
|
||||||
| /// @brief utilty for checking if an index is valid. There is a magic number, invalidCheckpointIndex, which | ||||||
| /// represents an invalid checkpoint | ||||||
| static bool valid_checkpoint_index(size_t i) { return i != invalidCheckpointIndex; } | ||||||
|
|
||||||
| /// @brief returns const_iterator to currently most dispensable checkpoint step | ||||||
| std::set<gretl::Checkpoint>::const_iterator most_dispensable() const | ||||||
| { | ||||||
| size_t maxHigherTimeLevel = 0; | ||||||
| for (auto rIter = cps.begin(); rIter != cps.end(); ++rIter) { | ||||||
| if (rIter->level < maxHigherTimeLevel) { | ||||||
| return rIter; | ||||||
| } | ||||||
| maxHigherTimeLevel = std::max(rIter->level, maxHigherTimeLevel); | ||||||
| } | ||||||
| return cps.end(); | ||||||
| } | ||||||
|
|
||||||
| /// @brief this does multiple things | ||||||
| /// 1. it adds checkpoints into the database, and updates internal data structures | ||||||
| /// 2. it determines if a checkpoint needs to be removed | ||||||
| /// 3. if a checkpoint needs to be removed, it returns the index for that checkpoint | ||||||
| /// 4. otherwise, it returns zero | ||||||
| size_t add_checkpoint_and_get_index_to_remove(size_t step, bool persistent = false) | ||||||
| { | ||||||
| size_t levelupAmount = 1; //= relativeCost >= 2.0 ? 3 : 1; | ||||||
|
|
||||||
| Checkpoint nextStep{.level = levelupAmount - 1, .step = step}; | ||||||
|
|
||||||
| size_t nextEraseStep = invalidCheckpointIndex; | ||||||
|
|
||||||
| // don't include persistent data in data quota. MRT, this might change | ||||||
| if (persistent) { | ||||||
| maxNumStates++; | ||||||
| nextStep.level = Checkpoint::infinity(); | ||||||
| gretl_assert(cps.size() < maxNumStates); | ||||||
| } | ||||||
|
|
||||||
| if (cps.size() < maxNumStates) { | ||||||
| cps.insert(nextStep); | ||||||
| } else { | ||||||
| auto iterToMostDispensable = most_dispensable(); | ||||||
| if (iterToMostDispensable != cps.end()) { | ||||||
| nextEraseStep = iterToMostDispensable->step; | ||||||
| cps.erase(iterToMostDispensable); | ||||||
| cps.insert(nextStep); | ||||||
| } else { | ||||||
| nextEraseStep = cps.begin()->step; | ||||||
| nextStep.level = cps.begin()->level + levelupAmount; | ||||||
|
|
||||||
| cps.erase(cps.begin()); | ||||||
| cps.insert(nextStep); | ||||||
| } | ||||||
| } | ||||||
|
|
||||||
| return nextEraseStep; | ||||||
| } | ||||||
|
|
||||||
| /// @brief return largest currently checkpointed step | ||||||
| size_t last_checkpoint_step() const { return cps.begin()->step; } | ||||||
|
|
||||||
| /// @brief erase | ||||||
| bool erase_step(size_t stepIndex) | ||||||
| { | ||||||
| for (std::set<Checkpoint>::iterator it = cps.begin(); it != cps.end(); ++it) { | ||||||
| if (it->step == stepIndex) { | ||||||
| if (it->level != Checkpoint::infinity()) { | ||||||
| cps.erase(it); | ||||||
| return true; | ||||||
| } | ||||||
| } | ||||||
| } | ||||||
| return false; | ||||||
| } | ||||||
|
|
||||||
| /// @brief check if this step is currently checkpointed. This could potentially use performance optimization down the | ||||||
| /// way. | ||||||
| bool contains_step(size_t stepIndex) const | ||||||
| { | ||||||
| for (auto& c : cps) { | ||||||
| if (c.step == stepIndex) { | ||||||
| return true; | ||||||
| } | ||||||
| } | ||||||
| return false; | ||||||
| } | ||||||
|
|
||||||
| /// @brief erase all non persistent checkpoints | ||||||
| void reset() | ||||||
| { | ||||||
| for (auto cp_it = cps.begin(); cp_it != cps.end(); ++cp_it) { | ||||||
| if (cp_it->level == Checkpoint::infinity()) { | ||||||
| cps.erase(cps.begin(), cp_it); | ||||||
| break; | ||||||
| } | ||||||
| } | ||||||
| } | ||||||
|
|
||||||
| size_t maxNumStates = | ||||||
| 20; ///< The maximum number of non-persistent, not-in-scope states stored by the CheckpointManager | ||||||
| std::set<Checkpoint> cps; ///< Vector of checkpoints | ||||||
| }; | ||||||
|
|
||||||
| /// @brief interface to run forward with a linear graph, checkpoint, then automatically backpropagate the sensitivities | ||||||
| /// given the reverse_callback jvp. | ||||||
|
||||||
| /// given the reverse_callback jvp. | |
| /// given the reverse_callback vjp. |
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yep!