Skip to content

Commit

Permalink
Bumping version of core Scallop library
Browse files Browse the repository at this point in the history
  • Loading branch information
Liby99 committed Feb 23, 2024
1 parent 8d5f7c7 commit bcf2d21
Show file tree
Hide file tree
Showing 75 changed files with 2,064 additions and 518 deletions.
2 changes: 1 addition & 1 deletion core/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "scallop-core"
version = "0.2.1"
version = "0.2.2"
authors = ["Ziyang Li <[email protected]>"]
edition = "2018"

Expand Down
262 changes: 239 additions & 23 deletions core/src/common/foreign_aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,12 @@ pub enum BindingTypes {
},
}

impl Default for BindingTypes {
fn default() -> Self {
Self::unit()
}
}

impl BindingTypes {
pub fn unit() -> Self {
Self::TupleType(vec![])
Expand All @@ -101,10 +107,6 @@ impl BindingTypes {
}
}

pub fn empty_tuple() -> Self {
Self::tuple(vec![])
}

pub fn tuple(elems: Vec<BindingType>) -> Self {
Self::TupleType(elems)
}
Expand Down Expand Up @@ -149,21 +151,223 @@ pub enum ParamType {
Optional(ValueType),
}

impl ParamType {
pub fn is_mandatory(&self) -> bool {
match self {
Self::Mandatory(_) => true,
_ => false,
}
}

pub fn value_type(&self) -> &ValueType {
match self {
Self::Mandatory(vt) => vt,
Self::Optional(vt) => vt,
}
}
}

/// The type of an aggregator
///
/// ``` ignore
/// OUTPUT_TYPE := AGGREGATE<PARAM: FAMILY, ...>[ARG_TYPE](INPUT_TYPE)
/// ```
#[derive(Clone, Debug)]
#[derive(Clone, Debug, Default)]
pub struct AggregateType {
pub generics: HashMap<String, GenericTypeFamily>,
pub param_types: Vec<ParamType>,
pub named_param_types: HashMap<String, ParamType>,
pub arg_type: BindingTypes,
pub input_type: BindingTypes,
pub output_type: BindingTypes,
pub allow_exclamation_mark: bool,
}

impl AggregateType {
pub fn infer_output_arity(&self, arg_arity: usize, input_arity: usize) -> Result<usize, String> {
let mut grounded_generic_arity = HashMap::new();
self.ground_input_aggregate_binding("argument", &self.arg_type, arg_arity, &mut grounded_generic_arity)?;
self.ground_input_aggregate_binding("input", &self.input_type, input_arity, &mut grounded_generic_arity)?;
self.solve_output_binding_arity(&self.output_type, &grounded_generic_arity)
}

fn ground_input_aggregate_binding(
&self,
kind: &str,
binding_types: &BindingTypes,
num_variables: usize,
grounded_generic_arity: &mut HashMap<String, usize>,
) -> Result<(), String> {
match binding_types {
BindingTypes::IfNotUnit { .. } => Err(format!("cannot have if-not-unit binding type in aggregate input")),
BindingTypes::TupleType(elems) => {
if elems.len() == 0 {
// If elems.len() is 0, it means that there should be no variable for this part of aggregation.
// We throw error if there is at least 1 variable.
// Otherwise, the type checking is done as there is no variable that needs to be unified for type
if num_variables != 0 {
Err(format!("expected 0 {kind} variables, found {num_variables}"))
} else {
Ok(())
}
} else if elems.len() == 1 {
// If elems.len() is 1, we could have that exact element to be a generic type variable or a concrete value type
match &elems[0] {
BindingType::Generic(g) => {
if let Some(grounded_type_arity) = grounded_generic_arity.get(g) {
if *grounded_type_arity != num_variables {
Err(format!("the generic type `{g}` is grounded to have {grounded_type_arity} variables, but found {num_variables}"))
} else {
Ok(())
}
} else if let Some(generic_type_family) = self.generics.get(g) {
let arity = self.solve_generic_type(kind, g, generic_type_family, num_variables)?;
grounded_generic_arity.insert(g.to_string(), arity);
Ok(())
} else {
Err(format!("unknown generic type parameter `{g}`"))
}
}
BindingType::ValueType(_) => {
if num_variables == 1 {
Ok(())
} else {
// Arity mismatch
Err(format!("expected one {kind} variable; found {num_variables}"))
}
}
}
} else {
if elems.iter().any(|e| e.is_generic()) {
Err(format!(
"cannot have generic in the {kind} of aggregate of more than 1 elements"
))
} else if elems.len() != num_variables {
Err(format!(
"expected {} {kind} variables, found {num_variables}",
elems.len()
))
} else {
Ok(())
}
}
}
}
}

fn solve_generic_type(
&self,
kind: &str,
generic_type_name: &str,
generic_type_family: &GenericTypeFamily,
num_variables: usize,
) -> Result<usize, String> {
match generic_type_family {
GenericTypeFamily::NonEmptyTuple => {
if num_variables == 0 {
Err(format!(
"arity mismatch. Expected non-empty {kind} variables, but found 0"
))
} else {
Ok(num_variables)
}
}
GenericTypeFamily::NonEmptyTupleWithElements(elem_type_families) => {
if elem_type_families.iter().any(|tf| !tf.is_type_family()) {
Err(format!(
"generic type family `{generic_type_name}` contains unsupported nested tuple"
))
} else if num_variables != elem_type_families.len() {
Err(format!(
"arity mismatch. Expected {} {kind} variables, but found 0",
elem_type_families.len()
))
} else {
Ok(num_variables)
}
}
GenericTypeFamily::UnitOr(child_generic_type_family) => {
if num_variables == 0 {
Ok(0)
} else {
self.solve_generic_type(kind, generic_type_name, &*child_generic_type_family, num_variables)
}
}
GenericTypeFamily::TypeFamily(_) => {
if num_variables != 1 {
Err(format!("arity mismatch. Expected 1 {kind} variables, but found 0"))
} else {
Ok(1)
}
}
}
}

fn solve_output_binding_arity(
&self,
binding_types: &BindingTypes,
grounded_generic_arity: &HashMap<String, usize>,
) -> Result<usize, String> {
match binding_types {
BindingTypes::IfNotUnit {
generic_type,
then_type,
else_type,
} => {
if let Some(arity) = grounded_generic_arity.get(generic_type) {
if *arity > 0 {
self.solve_output_binding_arity(then_type, grounded_generic_arity)
} else {
self.solve_output_binding_arity(else_type, grounded_generic_arity)
}
} else {
Err(format!(
"error grounding output type: unknown generic type `{generic_type}`"
))
}
}
BindingTypes::TupleType(elems) => Ok(
elems
.iter()
.map(|elem| match elem {
BindingType::Generic(g) => {
if let Some(arity) = grounded_generic_arity.get(g) {
Ok(*arity)
} else {
Err(format!("error grounding output type: unknown generic type `{g}`"))
}
}
BindingType::ValueType(_) => Ok(1),
})
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.fold(0, |acc, a| acc + a),
),
}
}
}

#[derive(Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
pub struct AggregateInfo {
pub pos_params: Vec<Value>,
pub named_params: BTreeMap<String, Value>,
pub has_exclamation_mark: bool,
pub arg_var_types: Vec<ValueType>,
pub input_var_types: Vec<ValueType>,
}

impl AggregateInfo {
pub fn with_arg_var_types(mut self, arg_var_types: Vec<ValueType>) -> Self {
self.arg_var_types = arg_var_types;
self
}

pub fn with_input_var_types(mut self, input_var_types: Vec<ValueType>) -> Self {
self.input_var_types = input_var_types;
self
}
}

pub trait Aggregate: Into<DynamicAggregate> {
/// The concrete aggregator that this aggregate is instantiated into
type Aggregator<P: Provenance>: Aggregator<P>;
Expand All @@ -175,23 +379,22 @@ pub trait Aggregate: Into<DynamicAggregate> {
fn aggregate_type(&self) -> AggregateType;

/// Instantiate the aggregate into an aggregator with the given parameters
fn instantiate<P: Provenance>(
&self,
params: Vec<Value>,
has_exclamation_mark: bool,
arg_types: Vec<ValueType>,
input_types: Vec<ValueType>,
) -> Self::Aggregator<P>;
fn instantiate<P: Provenance>(&self, aggregate_info: AggregateInfo) -> Self::Aggregator<P>;
}

/// A dynamic aggregate kind
#[derive(Clone)]
pub enum DynamicAggregate {
Avg(AvgAggregate),
Count(CountAggregate),
Disjunct(DisjunctAggregate),
Enumerate(EnumerateAggregate),
Exists(ExistsAggregate),
MinMax(MinMaxAggregate),
Normalize(NormalizeAggregate),
Rank(RankAggregate),
Sampler(DynamicSampleAggregate),
Sort(SortAggregate),
StringJoin(StringJoinAggregate),
SumProd(SumProdAggregate),
WeightedSumAvg(WeightedSumAvgAggregate),
Expand All @@ -202,12 +405,17 @@ macro_rules! match_aggregate {
match $a {
DynamicAggregate::Avg($v) => $e,
DynamicAggregate::Count($v) => $e,
DynamicAggregate::Disjunct($v) => $e,
DynamicAggregate::Enumerate($v) => $e,
DynamicAggregate::MinMax($v) => $e,
DynamicAggregate::SumProd($v) => $e,
DynamicAggregate::Normalize($v) => $e,
DynamicAggregate::Rank($v) => $e,
DynamicAggregate::Sampler($v) => $e,
DynamicAggregate::Sort($v) => $e,
DynamicAggregate::StringJoin($v) => $e,
DynamicAggregate::SumProd($v) => $e,
DynamicAggregate::WeightedSumAvg($v) => $e,
DynamicAggregate::Exists($v) => $e,
DynamicAggregate::Sampler($v) => $e,
}
};
}
Expand Down Expand Up @@ -247,6 +455,7 @@ impl AggregateRegistry {

// Register
registry.register(CountAggregate);
registry.register(DisjunctAggregate);
registry.register(MinMaxAggregate::min());
registry.register(MinMaxAggregate::max());
registry.register(MinMaxAggregate::argmin());
Expand All @@ -257,7 +466,13 @@ impl AggregateRegistry {
registry.register(WeightedSumAvgAggregate::weighted_sum());
registry.register(WeightedSumAvgAggregate::weighted_avg());
registry.register(ExistsAggregate::new());
registry.register(NormalizeAggregate::normalize());
registry.register(NormalizeAggregate::softmax());
registry.register(StringJoinAggregate::new());
registry.register(EnumerateAggregate::new());
registry.register(RankAggregate::new());
registry.register(SortAggregate::sort());
registry.register(SortAggregate::argsort());
registry.register(DynamicSampleAggregate::new(TopKSamplerAggregate::top()));
registry.register(DynamicSampleAggregate::new(TopKSamplerAggregate::unique()));
registry.register(DynamicSampleAggregate::new(CategoricalAggregate::new()));
Expand All @@ -276,23 +491,24 @@ impl AggregateRegistry {
self.registry.iter()
}

pub fn instantiate_aggregator<P: Provenance>(
&self,
name: &str,
params: Vec<Value>,
has_exclamation_mark: bool,
arg_types: Vec<ValueType>,
input_types: Vec<ValueType>,
) -> Option<DynamicAggregator<P>> {
pub fn instantiate_aggregator<P: Provenance>(&self, name: &str, info: AggregateInfo) -> Option<DynamicAggregator<P>> {
if let Some(aggregate) = self.registry.get(name) {
match_aggregate!(aggregate, a, {
let instantiated = a.instantiate::<P>(params, has_exclamation_mark, arg_types, input_types);
let instantiated = a.instantiate::<P>(info);
Some(DynamicAggregator(Box::new(instantiated)))
})
} else {
None
}
}

pub fn create_type_registry(&self) -> HashMap<String, AggregateType> {
self
.registry
.iter()
.map(|(name, agg)| (name.clone(), agg.aggregate_type()))
.collect()
}
}

pub trait Aggregator<P: Provenance>: dyn_clone::DynClone + 'static {
Expand Down
Loading

0 comments on commit bcf2d21

Please sign in to comment.