Skip to content

Commit

Permalink
fix: unsoundness with refs
Browse files Browse the repository at this point in the history
  • Loading branch information
TroyKomodo committed Jul 8, 2023
1 parent 955ec92 commit 6a6de04
Show file tree
Hide file tree
Showing 9 changed files with 242 additions and 390 deletions.
6 changes: 4 additions & 2 deletions common/src/config.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::sync::Arc;

use anyhow::Result;

use crate::logging;
Expand Down Expand Up @@ -30,8 +32,8 @@ pub struct LoggingConfig {
}

impl ::config::Config for logging::Mode {
fn graph() -> &'static ::config::KeyGraph {
&::config::KeyGraph::String
fn graph() -> Arc<::config::KeyGraph> {
Arc::new(::config::KeyGraph::String)
}
}

Expand Down
22 changes: 3 additions & 19 deletions config/config/example/derive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,26 +21,12 @@ struct LoggingConfig {
manual2: Manual,
}

#[derive(Default, Debug, PartialEq, serde::Deserialize)]
#[derive(config::Config, Default, Debug, PartialEq, serde::Deserialize)]
struct Manual {
#[config(cli(skip), env(skip))]
cycle: Box<Option<LoggingConfig>>,
}

impl config::Config for Manual {
fn graph() -> &'static config::KeyGraph {
if let Some(tree) = config::KeyGraph::get::<Self>() {
return tree;
}

let mut keys = std::collections::BTreeMap::new();

keys.insert("level".to_string(), config::Key::new(String::graph()));
keys.insert("json".to_string(), config::Key::new(bool::graph()));

config::KeyGraph::store::<Self>(config::KeyGraph::Struct(keys))
}
}

impl Default for LoggingConfig {
fn default() -> Self {
Self {
Expand All @@ -60,9 +46,7 @@ fn main() {
}

fn parse() -> Result<AppConfig, ConfigError> {
let graph = std::thread::spawn(AppConfig::graph).join().unwrap();

println!("{:#?}", graph);
dbg!(AppConfig::graph());

let mut builder = ConfigBuilder::new();
builder.add_source(sources::CliSource::new()?);
Expand Down
160 changes: 66 additions & 94 deletions config/config/src/key.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
use std::{
collections::{btree_map::Entry, BTreeMap},
collections::{BTreeMap, HashMap},
fmt::Display,
ptr::NonNull,
sync::Mutex,
sync::{Mutex, Arc, Weak}, any::TypeId,
};

#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)]
Expand Down Expand Up @@ -95,85 +94,50 @@ pub enum KeyType {
Struct(BTreeMap<String, KeyType>),
Map(Box<KeyType>, Box<KeyType>),
Seq(Box<KeyType>),
CyclicReference,
CyclicReference(&'static str),
}

static MEMO_GRAPHS: Mutex<BTreeMap<std::any::TypeId, KeyGraphRefHelper>> =
Mutex::new(BTreeMap::new());

struct KeyGraphRefHelper {
ptr: NonNull<KeyGraph>,
ref_key: NonNull<KeyGraph>,
is_building: bool,
thread_local! {
static MEMO_GRAPHS: Mutex<HashMap<TypeId, (Weak<KeyGraph>, bool)>> = Mutex::new(HashMap::new());
}

// Safety: The pointer is never null and is static.
// The pointer is also never dropped. And the pointer does not contain
// any data that can be dataraced as it is locked behind a mutex.
unsafe impl Send for KeyGraphRefHelper {}

impl Drop for KeyGraphRefHelper {
fn drop(&mut self) {
// Safety: We are the only ones with access to the pointer.
// We are dropping the pointer and therefore we can drop the value.
// In practice this should never actually be called because the `MEMO_GRAPHS` is a static
// and therefore will never be dropped.
unsafe {
let _ = Box::from_raw(self.ptr.as_ptr());
let _ = Box::from_raw(self.ref_key.as_ptr());
}
}
pub struct KeyGraphBuilder<C: std::any::Any> {
graph: Arc<KeyGraph>,
alive: bool,
building: bool,
phantom: std::marker::PhantomData<C>,
}

impl KeyGraphRefHelper {
fn new() -> Self {
// We create a pointer to some dummy value.
let ptr = NonNull::new(Box::leak(Box::new(KeyGraph::Unit))).unwrap();

// We then create a pointer to a ref to the dummy value.
// This is because the reference key is a cyclic reference.
// And we if we are calling the get function from the same path that is building the
// tree then we need to return the ref key. Otherwise we would run into an infinite
// loop and stack overflow.
// Safety: We created a non null pointer above and its static.
let ref_key =
NonNull::new(Box::leak(Box::new(KeyGraph::Ref(unsafe { ptr.as_ref() })))).unwrap();

Self {
ref_key,
is_building: true,
ptr,
impl<C: std::any::Any> KeyGraphBuilder<C> {
pub fn get(&self) -> Option<Arc<KeyGraph>> {
if self.building {
Some(Arc::new(KeyGraph::Ref(self.graph.clone(), std::any::type_name::<C>())))
} else if self.alive {
Some(self.graph.clone())
} else {
None
}
}

fn get(&mut self) -> &'static KeyGraph {
if self.is_building {
// The reason we return the ref key here is because if the get function is being
// called by the same path that is building the tree then we need to return the
// ref key because this is a cyclic reference.
// Safety: The value is non-null and is static. We also have a mutable ref to `KeyGraphRefHelper`
unsafe { self.ref_key.as_ref() }
} else {
// If we are not building then we can return the actual pointer, ie: memoization.
// Safety: The value is non-null and is static. We also have a mutable ref to `KeyGraphRefHelper`
unsafe { self.ptr.as_ref() }
pub fn build(self, graph: KeyGraph) -> Arc<KeyGraph> {
if let Some(arc) = self.get() {
return arc;
}
}

fn set(&mut self, mut tree: KeyGraph) {
// Safety: We have a mutable ref to `KeyGraphRefHelper` and therefore we are the only
// ones with access to the pointer.
// We swap the old value of the pointer with the new value. We can then drop the old
// value of the pointer. (which will be in `tree`)
// It should be impossible to call this function twice on the same type.
// However even if it is called twice it is still safe because we are swapping the
// pointer and therefore the old value will be dropped and no memory leaks will occur.
unsafe {
std::mem::swap(self.ptr.as_mut(), &mut tree);
let graph_ptr = self.graph.as_ref() as *const KeyGraph as *mut KeyGraph;
let _ = std::mem::replace(&mut *graph_ptr, graph);
}

// Since we are no longer building and are built we stop returning the ref key.
self.is_building = false;
MEMO_GRAPHS.with(|mg| {
let mut mg = mg.lock().unwrap();
let ty = TypeId::of::<C>();
if let Some((_, building)) = mg.get_mut(&ty) {
*building = false;
}
});

self.graph
}
}

Expand All @@ -193,11 +157,11 @@ pub enum KeyGraph {
Bool,
Unit,
Char,
Option(&'static KeyGraph),
Option(Arc<KeyGraph>),
Struct(BTreeMap<String, Key>),
Map(&'static KeyGraph, &'static KeyGraph),
Seq(&'static KeyGraph),
Ref(&'static KeyGraph),
Map(Arc<KeyGraph>, Arc<KeyGraph>),
Seq(Arc<KeyGraph>),
Ref(Arc<KeyGraph>, &'static str),
}

impl std::fmt::Debug for KeyGraph {
Expand All @@ -221,7 +185,7 @@ impl std::fmt::Debug for KeyGraph {
Self::Struct(map) => write!(f, "Struct({:?})", map),
Self::Map(key, value) => write!(f, "Map({:?}, {:?})", key, value),
Self::Seq(key) => write!(f, "Seq({:?})", key),
Self::Ref(_) => write!(f, "Ref(...)"),
Self::Ref(_, ty) => write!(f, "Ref(&{})", ty),
}
}
}
Expand Down Expand Up @@ -251,30 +215,38 @@ impl KeyGraph {
KeyType::Map(Box::new(key.key_type()), Box::new(value.key_type()))
}
Self::Seq(key) => KeyType::Seq(Box::new(key.key_type())),
Self::Ref(_) => KeyType::CyclicReference,
Self::Ref(_, ty) => KeyType::CyclicReference(ty),
}
}

pub fn get<T: crate::Config>() -> Option<&'static Self> {
let mut memo_graphs = MEMO_GRAPHS.lock().expect("Failed to lock memo tree");
let id = std::any::TypeId::of::<T>();

if let Entry::Vacant(e) = memo_graphs.entry(id) {
e.insert(KeyGraphRefHelper::new());
None
} else {
memo_graphs.get_mut(&id).map(|h| h.get())
}
}
pub fn builder<C: std::any::Any>() -> KeyGraphBuilder<C> {
MEMO_GRAPHS.with(|mg| {
let mut mg = mg.lock().unwrap();

let ty = TypeId::of::<C>();
if let Some((graph, building)) = mg.get(&ty) {
if let Some(graph) = graph.upgrade() {
return KeyGraphBuilder {
graph,
alive: true,
building: *building,
phantom: std::marker::PhantomData,
};
}
}

pub fn store<T: crate::Config>(tree: Self) -> &'static Self {
let mut memo_graphs = MEMO_GRAPHS.lock().expect("Failed to lock memo tree");
// Dummy value does not matter
let graph = Arc::new(KeyGraph::Unit);

let id = std::any::TypeId::of::<T>();
let helper = memo_graphs.get_mut(&id).unwrap();
mg.insert(ty, (Arc::downgrade(&graph), true));

helper.set(tree);
helper.get()
KeyGraphBuilder {
graph,
alive: false,
building: false,
phantom: std::marker::PhantomData,
}
})
}
}

Expand All @@ -283,14 +255,14 @@ impl KeyGraph {
/// - type
#[derive(Clone, Debug)]
pub struct Key {
tree: &'static KeyGraph,
tree: Arc<KeyGraph>,
skip_cli: bool,
skip_env: bool,
comment: Option<&'static str>,
}

impl Key {
pub fn new(tree: &'static KeyGraph) -> Self {
pub fn new(tree: Arc<KeyGraph>) -> Self {
Self {
tree,
skip_cli: false,
Expand Down Expand Up @@ -327,7 +299,7 @@ impl Key {
}

pub fn tree(&self) -> &KeyGraph {
self.tree
&self.tree
}

pub fn key_type(&self) -> KeyType {
Expand Down
6 changes: 3 additions & 3 deletions config/config/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,10 +187,10 @@ pub trait Config: serde::de::DeserializeOwned + 'static {
const VERSION: Option<&'static str> = None;
const AUTHOR: Option<&'static str> = None;

fn graph() -> &'static KeyGraph;
fn graph() -> Arc<KeyGraph>;

fn validate(value: Value) -> Result<Value> {
validate_from_graph(Self::graph(), value)
validate_from_graph(&Self::graph(), value)
}
}

Expand Down Expand Up @@ -290,7 +290,7 @@ pub fn validate_from_graph(tree: &KeyGraph, value: Value) -> Result<Value> {
)
}
}
KeyGraph::Ref(tree) => validate_from_graph(tree, value),
KeyGraph::Ref(tree, _) => validate_from_graph(tree, value),
}
}

Expand Down
Loading

0 comments on commit 6a6de04

Please sign in to comment.