Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(hydroflow_lang): Added state_by operator. #1469

Merged
merged 1 commit into from
Sep 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions hydroflow_lang/src/graph/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,7 @@ declare_ops![
source_stream::SOURCE_STREAM,
source_stream_serde::SOURCE_STREAM_SERDE,
state::STATE,
state_by::STATE_BY,
tee::TEE,
unique::UNIQUE,
unzip::UNZIP,
Expand Down
136 changes: 8 additions & 128 deletions hydroflow_lang/src/graph/ops/state.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
use quote::{quote_spanned, ToTokens};

use syn::parse_quote_spanned;
use super::{
OpInstGenerics, OperatorCategory, OperatorConstraints, OperatorInstance, OperatorWriteOutput,
Persistence, WriteContextArgs, RANGE_1,
OperatorCategory, OperatorConstraints,
WriteContextArgs, RANGE_1,
};
use crate::diagnostic::{Diagnostic, Level};

// TODO(mingwei): Improve example when things are more stable.
/// A lattice-based state operator, used for accumulating lattice state
Expand Down Expand Up @@ -36,132 +34,14 @@ pub const STATE: OperatorConstraints = OperatorConstraints {
ports_inn: None,
ports_out: None,
input_delaytype_fn: |_| None,
write_fn: |&WriteContextArgs {
root,
context,
hydroflow,
op_span,
ident,
inputs,
outputs,
is_pull,
singleton_output_ident,
op_name,
op_inst:
OperatorInstance {
generics:
OpInstGenerics {
type_args,
persistence_args,
..
},
..
},
..
},
write_fn: |wc @ &WriteContextArgs { op_span, .. },
diagnostics| {
let lattice_type = type_args
.first()
.map(ToTokens::to_token_stream)
.unwrap_or(quote_spanned!(op_span=> _));

let persistence = match persistence_args[..] {
[] => Persistence::Tick,
[Persistence::Mutable] => {
diagnostics.push(Diagnostic::spanned(
op_span,
Level::Error,
format!("{} does not support `'mut`.", op_name),
));
Persistence::Tick
}
[a] => a,
_ => unreachable!(),
};

let state_ident = singleton_output_ident;
let mut write_prologue = quote_spanned! {op_span=>
let #state_ident = #hydroflow.add_state(::std::cell::RefCell::new(
<#lattice_type as ::std::default::Default>::default()
));
let wc = WriteContextArgs {
arguments: &parse_quote_spanned!(op_span => ::std::convert::identity),
..wc.clone()
};
if Persistence::Tick == persistence {
write_prologue.extend(quote_spanned! {op_span=>
#hydroflow.set_state_tick_hook(#state_ident, |rcell| { rcell.take(); }); // Resets state to `Default::default()`.
});
}

// TODO(mingwei): deduplicate codegen
let write_iterator = if is_pull {
let input = &inputs[0];
quote_spanned! {op_span=>
let #ident = {
fn check_input<'a, Item, Iter, Lat>(
iter: Iter,
state_handle: #root::scheduled::state::StateHandle<::std::cell::RefCell<Lat>>,
context: &'a #root::scheduled::context::Context,
) -> impl 'a + ::std::iter::Iterator<Item = Item>
where
Item: ::std::clone::Clone,
Iter: 'a + ::std::iter::Iterator<Item = Item>,
Lat: 'static + #root::lattices::Merge<Item>,
{
iter.filter(move |item| {
let state = context.state_ref(state_handle);
let mut state = state.borrow_mut();
#root::lattices::Merge::merge(&mut *state, ::std::clone::Clone::clone(item))
})
}
check_input::<_, _, #lattice_type>(#input, #state_ident, #context)
};
}
} else if let Some(output) = outputs.first() {
quote_spanned! {op_span=>
let #ident = {
fn check_output<'a, Item, Push, Lat>(
push: Push,
state_handle: #root::scheduled::state::StateHandle<::std::cell::RefCell<Lat>>,
context: &'a #root::scheduled::context::Context,
) -> impl 'a + #root::pusherator::Pusherator<Item = Item>
where
Item: 'a + ::std::clone::Clone,
Push: #root::pusherator::Pusherator<Item = Item>,
Lat: 'static + #root::lattices::Merge<Item>,
{
#root::pusherator::filter::Filter::new(move |item| {
let state = context.state_ref(state_handle);
let mut state = state.borrow_mut();
#root::lattices::Merge::merge(&mut *state, ::std::clone::Clone::clone(item))
}, push)
}
check_output::<_, _, #lattice_type>(#output, #state_ident, #context)
};
}
} else {
quote_spanned! {op_span=>
let #ident = {
fn check_output<'a, Item, Lat>(
state_handle: #root::scheduled::state::StateHandle<::std::cell::RefCell<Lat>>,
context: &'a #root::scheduled::context::Context,
) -> impl 'a + #root::pusherator::Pusherator<Item = Item>
where
Item: 'a,
Lat: 'static + #root::lattices::Merge<Item>,
{
#root::pusherator::for_each::ForEach::new(move |item| {
let state = context.state_ref(state_handle);
let mut state = state.borrow_mut();
#root::lattices::Merge::merge(&mut *state, item);
})
}
check_output::<_, #lattice_type>(#state_ident, #context)
};
}
};
Ok(OperatorWriteOutput {
write_prologue,
write_iterator,
..Default::default()
})
(super::state_by::STATE_BY.write_fn)(&wc, diagnostics)
},
};
176 changes: 176 additions & 0 deletions hydroflow_lang/src/graph/ops/state_by.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
use quote::{quote_spanned, ToTokens};

use super::{
OpInstGenerics, OperatorCategory, OperatorConstraints, OperatorInstance, OperatorWriteOutput,
Persistence, WriteContextArgs, RANGE_1,
};
use crate::diagnostic::{Diagnostic, Level};

/// List state operator, but with a closure to map the input to the state lattice.
///
/// The emitted outputs (both the referencable singleton and the optional pass-through stream) are
/// of the same type as the inputs to the state_by operator and are not required to be a lattice
/// type. This is useful receiving pass-through context information on the output side.
///
/// ```hydroflow
/// use std::collections::HashSet;
///
/// use lattices::set_union::{CartesianProductBimorphism, SetUnionHashSet, SetUnionSingletonSet};
///
/// my_state = source_iter(0..3)
/// -> state_by::<SetUnionHashSet<usize>>(SetUnionSingletonSet::new_from);
/// ```
pub const STATE_BY: OperatorConstraints = OperatorConstraints {
name: "state_by",
categories: &[OperatorCategory::Persistence],
hard_range_inn: RANGE_1,
soft_range_inn: RANGE_1,
hard_range_out: &(0..=1),
soft_range_out: &(0..=1),
num_args: 1,
persistence_args: &(0..=1),
type_args: &(0..=1),
is_external_input: false,
has_singleton_output: true,
ports_inn: None,
ports_out: None,
input_delaytype_fn: |_| None,
write_fn: |&WriteContextArgs {
root,
context,
hydroflow,
op_span,
ident,
inputs,
outputs,
is_pull,
singleton_output_ident,
op_name,
op_inst:
OperatorInstance {
generics:
OpInstGenerics {
type_args,
persistence_args,
..
},
..
},
arguments,
..
},
diagnostics| {
let lattice_type = type_args
.first()
.map(ToTokens::to_token_stream)
.unwrap_or(quote_spanned!(op_span=> _));

let persistence = match persistence_args[..] {
[] => Persistence::Tick,
[Persistence::Mutable] => {
diagnostics.push(Diagnostic::spanned(
op_span,
Level::Error,
format!("{} does not support `'mut`.", op_name),
));
Persistence::Tick
}
[a] => a,
_ => unreachable!(),
};

let state_ident = singleton_output_ident;
let mut write_prologue = quote_spanned! {op_span=>
let #state_ident = #hydroflow.add_state(::std::cell::RefCell::new(
<#lattice_type as ::std::default::Default>::default()
));
};
if Persistence::Tick == persistence {
write_prologue.extend(quote_spanned! {op_span=>
#hydroflow.set_state_tick_hook(#state_ident, |rcell| { rcell.take(); }); // Resets state to `Default::default()`.
});
}

let func = &arguments[0];

// TODO(mingwei): deduplicate codegen
let write_iterator = if is_pull {
let input = &inputs[0];
quote_spanned! {op_span=>
let #ident = {
fn check_input<'a, Item, MappingFn, MappedItem, Iter, Lat>(
iter: Iter,
mapfn: MappingFn,
state_handle: #root::scheduled::state::StateHandle<::std::cell::RefCell<Lat>>,
context: &'a #root::scheduled::context::Context,
) -> impl 'a + ::std::iter::Iterator<Item = Item>
where
Item: ::std::clone::Clone,
MappingFn: 'a + Fn(Item) -> MappedItem,
Iter: 'a + ::std::iter::Iterator<Item = Item>,
Lat: 'static + #root::lattices::Merge<MappedItem>,
{
iter.filter(move |item| {
let state = context.state_ref(state_handle);
let mut state = state.borrow_mut();
#root::lattices::Merge::merge(&mut *state, (mapfn)(::std::clone::Clone::clone(item)))
})
}
check_input::<_, _, _, _, #lattice_type>(#input, #func, #state_ident, #context)
};
}
} else if let Some(output) = outputs.first() {
quote_spanned! {op_span=>
let #ident = {
fn check_output<'a, Item, MappingFn, MappedItem, Push, Lat>(
push: Push,
mapfn: MappingFn,
state_handle: #root::scheduled::state::StateHandle<::std::cell::RefCell<Lat>>,
context: &'a #root::scheduled::context::Context,
) -> impl 'a + #root::pusherator::Pusherator<Item = Item>
where
Item: 'a + ::std::clone::Clone,
MappingFn: 'a + Fn(Item) -> MappedItem,
Push: 'a + #root::pusherator::Pusherator<Item = Item>,
Lat: 'static + #root::lattices::Merge<MappedItem>,
{
#root::pusherator::filter::Filter::new(move |item| {
let state = context.state_ref(state_handle);
let mut state = state.borrow_mut();
#root::lattices::Merge::merge(&mut *state, (mapfn)(::std::clone::Clone::clone(item)))
}, push)
}
check_output::<_, _, _, _, #lattice_type>(#output, #func, #state_ident, #context)
};
}
} else {
quote_spanned! {op_span=>
let #ident = {
fn check_output<'a, Item, MappingFn, MappedItem, Lat>(
state_handle: #root::scheduled::state::StateHandle<::std::cell::RefCell<Lat>>,
mapfn: MappingFn,
context: &'a #root::scheduled::context::Context,
) -> impl 'a + #root::pusherator::Pusherator<Item = Item>
where
Item: 'a,
MappedItem: 'a,
MappingFn: 'a + Fn(Item) -> MappedItem,
Lat: 'static + #root::lattices::Merge<MappedItem>,
{
#root::pusherator::for_each::ForEach::new(move |item| {
let state = context.state_ref(state_handle);
let mut state = state.borrow_mut();
#root::lattices::Merge::merge(&mut *state, (mapfn)(item));
})
}
check_output::<_, _, _, #lattice_type>(#state_ident, #func, #context)
};
}
};
Ok(OperatorWriteOutput {
write_prologue,
write_iterator,
..Default::default()
})
},
};
Loading