Skip to content

Commit

Permalink
some code cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
reiase committed Mar 25, 2024
1 parent 15ea7e5 commit 100486d
Show file tree
Hide file tree
Showing 4 changed files with 132 additions and 42 deletions.
102 changes: 99 additions & 3 deletions core/src/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,31 @@ use std::collections::HashSet;
use std::fmt::Debug;

use crate::storage::{
frozen_global_storage, Entry, GetOrElse, MultipleVersion, Tree, THREAD_STORAGE,
frozen_global_storage, Entry, GetOrElse, MultipleVersion, Params, THREAD_STORAGE,
};
use crate::value::{Value, EMPTY};
use crate::xxh::XXHashable;

/// ParameterScope
///
/// `ParameterScope` is a data structure that stores the current set of named parameters
/// and their values. `ParameterScope` is used to manage the scope of named parameters,
/// allowing parameters to be defined and used within a specific scope,
/// and then restored to the previous scope when the scope is exited.
///
/// The parameter scope can be used to implement a variety of features, such
/// as named parameters, default parameter values, and parameter inheritance.
#[derive(Debug, Clone)]
pub enum ParamScope {
/// No parameters are defined in the current scope.
Nothing,
Just(Tree),
/// The current scope contains a set of named parameters stored in `Params`.
Just(Params),
}

impl Default for ParamScope {
fn default() -> Self {
ParamScope::Just(Tree::new())
ParamScope::Just(Params::new())
}
}

Expand Down Expand Up @@ -505,3 +516,88 @@ mod tests {
}
}
}

// FILEPATH: /home/reiase/workspace/hyperparameter/core/src/api.rs
// BEGIN: test_code

#[cfg(test)]
mod test_param_scope {
use super::*;
use std::convert::TryInto;

#[test]
fn test_param_scope_default() {
let ps = ParamScope::default();
match ps {
ParamScope::Just(_) => assert!(true),
_ => assert!(false, "Default ParamScope should be ParamScope::Just"),
}
}

#[test]
fn test_param_scope_from_vec() {
let vec = vec!["param1=value1", "param2=value2"];
let ps: ParamScope = (&vec).into();
match ps {
ParamScope::Just(params) => {
assert_eq!(params.get(&"param1".xxh()).unwrap().value(), &Value::from("value1"));
assert_eq!(params.get(&"param2".xxh()).unwrap().value(), &Value::from("value2"));
}
_ => assert!(false, "ParamScope should be ParamScope::Just"),
}
}

#[test]
fn test_param_scope_get_with_hash() {
let mut ps = ParamScope::default();
ps.add("param=value");
let value = ps.get_with_hash("param".xxh());
assert_eq!(value, Value::from("value"));
}

#[test]
fn test_param_scope_get() {
let mut ps = ParamScope::default();
ps.add("param=value");
let value: String = ps.get("param").try_into().unwrap();
assert_eq!(value, "value");
}

#[test]
fn test_param_scope_add() {
let mut ps = ParamScope::default();
ps.add("param=value");
match ps {
ParamScope::Just(params) => {
assert_eq!(params.get(&"param".xxh()).unwrap().value(), &Value::from("value"));
}
_ => assert!(false, "ParamScope should be ParamScope::Just"),
}
}

#[test]
fn test_param_scope_keys() {
let mut ps = ParamScope::default();
ps.add("param=value");
let keys = ps.keys();
assert_eq!(keys, vec!["param"]);
}

#[test]
fn test_param_scope_enter_exit() {
let mut ps = ParamScope::default();
ps.add("param=value");
ps.enter();
match ps {
ParamScope::Nothing => assert!(true),
_ => assert!(false, "ParamScope should be ParamScope::Nothing after enter"),
}
ps.exit();
match ps {
ParamScope::Just(_) => assert!(true),
_ => assert!(false, "ParamScope should be ParamScope::Just after exit"),
}
}
}

// END: test_code
54 changes: 28 additions & 26 deletions core/src/storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,23 +40,25 @@ impl Entry {
}
}

pub type Tree = BTreeMap<u64, Entry>;
pub type Params = BTreeMap<u64, Entry>;

pub trait MultipleVersion<K> {
fn update<V: Into<Value>>(&mut self, key: K, val: V);
fn revision<V: Into<Value>>(&mut self, key: K, val: V);
fn rollback(&mut self, key: K);
}

impl MultipleVersion<u64> for Tree {
impl MultipleVersion<u64> for Params {
fn update<V: Into<Value>>(&mut self, key: u64, val: V) {
self.entry(key).and_modify(|e| {
if let Some(e) = self.get_mut(&key) {
e.val.update(val);
});
}
}

fn revision<V: Into<Value>>(&mut self, key: u64, val: V) {
self.entry(key).and_modify(|e| e.val.revision(val));
if let Some(e) = self.get_mut(&key) {
e.val.revision(val);
}
}

fn rollback(&mut self, key: u64) {
Expand All @@ -75,8 +77,8 @@ thread_local! {
fn create_thread_storage() -> RefCell<Storage> {
let ts = RefCell::new(Storage::default());
ts.borrow_mut()
.tree
.clone_from(&GLOBAL_STORAGE.lock().unwrap().tree);
.params
.clone_from(&GLOBAL_STORAGE.lock().unwrap().params);
ts
}

Expand All @@ -89,14 +91,14 @@ pub fn frozen_global_storage() {
GLOBAL_STORAGE
.lock()
.unwrap()
.tree
.clone_from(&ts.borrow().tree);
.params
.clone_from(&ts.borrow().params);
});
}

#[derive(Debug)]
pub struct Storage {
pub tree: Tree,
pub params: Params,
pub history: Vec<HashSet<u64>>,
}

Expand All @@ -105,7 +107,7 @@ unsafe impl Send for Storage {}
impl Default for Storage {
fn default() -> Self {
Storage {
tree: Tree::new(),
params: Params::new(),
history: vec![HashSet::new()],
}
}
Expand All @@ -116,30 +118,30 @@ impl Storage {
self.history.push(HashSet::new());
}

pub fn exit(&mut self) -> Tree {
let mut changes = Tree::new();
pub fn exit(&mut self) -> Params {
let mut changes = Params::new();
for key in self.history.pop().unwrap() {
changes.insert(key, self.tree.get(&key).unwrap().shallow());
self.tree.rollback(key);
changes.insert(key, self.params.get(&key).unwrap().shallow());
self.params.rollback(key);
}
changes
}

pub fn get_entry(&self, key: u64) -> Option<&Entry> {
self.tree.get(&key)
self.params.get(&key)
}

pub fn put_entry(&mut self, key: u64, entry: Entry) -> Option<Entry> {
self.tree.insert(key, entry)
self.params.insert(key, entry)
}

pub fn del_entry(&mut self, key: u64) {
self.tree.remove(&key);
self.params.remove(&key);
}

pub fn get<T: XXHashable>(&self, key: T) -> &Value {
let hkey = key.xxh();
if let Some(e) = self.tree.get(&hkey) {
if let Some(e) = self.params.get(&hkey) {
e.value()
} else {
&EMPTY
Expand All @@ -150,15 +152,15 @@ impl Storage {
let hkey = key.xxh();
let key: String = key.into();
if self.history.last().unwrap().contains(&hkey) {
self.tree.update(hkey, val);
self.params.update(hkey, val);
} else {
if let std::collections::btree_map::Entry::Vacant(e) = self.tree.entry(hkey) {
if let std::collections::btree_map::Entry::Vacant(e) = self.params.entry(hkey) {
e.insert(Entry {
key,
val: VersionedValue::from(val.into()),
});
} else {
self.tree.revision(hkey, val);
self.params.revision(hkey, val);
}
self.history.last_mut().unwrap().insert(hkey);
}
Expand All @@ -167,15 +169,15 @@ impl Storage {
pub fn del<T: XXHashable>(&mut self, key: T) {
let hkey = key.xxh();
if self.history.last().unwrap().contains(&hkey) {
self.tree.update(hkey, None::<i32>);
self.params.update(hkey, None::<i32>);
} else {
self.tree.revision(hkey, None::<i32>);
self.params.revision(hkey, None::<i32>);
self.history.last_mut().unwrap().insert(hkey);
}
}

pub fn keys(&self) -> Vec<String> {
self.tree
self.params
.values()
.filter(|x| !matches!(x.value(), Value::Empty))
.map(|x| x.key.clone())
Expand All @@ -202,7 +204,7 @@ where
T: Into<Value> + TryFrom<Value> + for<'a> TryFrom<&'a Value>,
{
fn get_or_else(&self, key: u64, dval: T) -> T {
if let Some(val) = self.tree.get(&key) {
if let Some(val) = self.params.get(&key) {
match val.value().try_into() {
Ok(v) => v,
Err(_) => dval,
Expand Down
12 changes: 3 additions & 9 deletions core/src/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,7 @@ pub const EMPTY: Value = Value::Empty;

impl<T: Into<Value>> From<Option<T>> for Value {
fn from(value: Option<T>) -> Self {
match value {
Some(x) => {
let y: Value = x.into();
y
}
None => Value::Empty,
}
value.map_or(Value::Empty, |x| x.into())
}
}

Expand All @@ -63,7 +57,7 @@ impl From<i64> for Value {

impl From<f32> for Value {
fn from(value: f32) -> Self {
Value::Float(value.into())
Value::Float(value as f64)
}
}

Expand All @@ -81,7 +75,7 @@ impl From<String> for Value {

impl From<&String> for Value {
fn from(value: &String) -> Self {
Value::Text(value.clone())
Value::Text(value.to_string())
}
}

Expand Down
6 changes: 2 additions & 4 deletions core/src/xxh.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,13 @@ impl XXHashable for &str {

impl XXHashable for CStr {
fn xxh(&self) -> u64 {
let bs = self.to_bytes();
xxhash(bs)
xxhash(self.to_bytes())
}
}

impl XXHashable for CString {
fn xxh(&self) -> u64 {
let bs = self.to_bytes();
xxhash(bs)
xxhash(self.to_bytes())
}
}

Expand Down

0 comments on commit 100486d

Please sign in to comment.