From 79cece14553e3893abab8c437188c21aa28b2054 Mon Sep 17 00:00:00 2001 From: Ziyang Li Date: Wed, 4 Oct 2023 18:16:11 -0400 Subject: [PATCH] Adding foreign aggregators and wmc_with_disjunctions --- .github/workflows/scallop-core.yml | 4 + .github/workflows/scallopy.yml | 12 +- changelog.md | 16 +- core/Cargo.toml | 3 +- core/src/common/aggregate_op.rs | 42 +- core/src/common/foreign_aggregate.rs | 319 +++++++++ core/src/common/foreign_aggregates/avg.rs | 122 ++++ .../common/foreign_aggregates/categorical.rs | 77 +++ core/src/common/foreign_aggregates/count.rs | 96 +++ core/src/common/foreign_aggregates/exists.rs | 99 +++ core/src/common/foreign_aggregates/min_max.rs | 279 ++++++++ core/src/common/foreign_aggregates/mod.rs | 25 + core/src/common/foreign_aggregates/sampler.rs | 139 ++++ .../common/foreign_aggregates/string_join.rs | 107 +++ .../src/common/foreign_aggregates/sum_prod.rs | 159 +++++ core/src/common/foreign_aggregates/top_k.rs | 158 +++++ core/src/common/foreign_aggregates/uniform.rs | 74 ++ .../foreign_aggregates/weighted_sum_avg.rs | 201 ++++++ .../foreign_functions/string_replace.rs | 4 +- .../common/foreign_tensor/external_tensor.rs | 8 +- core/src/common/foreign_tensor/registry.rs | 9 +- core/src/common/generic_tuple.rs | 62 ++ core/src/common/mod.rs | 2 + core/src/common/tuple.rs | 4 + core/src/common/tuples.rs | 86 +-- core/src/common/value.rs | 2 +- core/src/common/value_type.rs | 68 ++ core/src/compiler/back/ast.rs | 39 +- core/src/compiler/back/attr.rs | 7 +- core/src/compiler/back/b2r.rs | 53 +- core/src/compiler/back/compile.rs | 1 + core/src/compiler/back/pretty.rs | 16 +- core/src/compiler/front/analysis.rs | 9 +- .../compiler/front/analyzers/aggregation.rs | 50 +- .../front/analyzers/boundness/context.rs | 26 +- .../compiler/front/analyzers/constant_decl.rs | 17 +- .../front/analyzers/invalid_constant.rs | 4 +- core/src/compiler/front/analyzers/mod.rs | 1 - .../front/analyzers/type_inference/error.rs | 641 ++++++------------ 
.../type_inference/foreign_aggregate.rs | 30 + .../front/analyzers/type_inference/local.rs | 204 ++---- .../front/analyzers/type_inference/mod.rs | 5 + .../type_inference/type_inference.rs | 240 ++++--- .../analyzers/type_inference/type_set.rs | 6 +- .../analyzers/type_inference/unification.rs | 555 ++++++++++++--- core/src/compiler/front/ast/attr.rs | 20 +- core/src/compiler/front/ast/formula.rs | 81 +-- core/src/compiler/front/ast/mod.rs | 2 +- core/src/compiler/front/ast/relation_decl.rs | 6 +- core/src/compiler/front/ast/rule.rs | 16 +- core/src/compiler/front/ast/type_decl.rs | 14 +- core/src/compiler/front/ast/types.rs | 4 +- core/src/compiler/front/ast/utils.rs | 35 +- core/src/compiler/front/compile.rs | 15 +- core/src/compiler/front/error.rs | 63 ++ core/src/compiler/front/f2b/f2b.rs | 104 +-- core/src/compiler/front/f2b/flatten_expr.rs | 16 +- core/src/compiler/front/grammar.lalrpop | 50 +- core/src/compiler/front/pretty.rs | 47 +- .../front/transformations/adt_to_relation.rs | 22 +- .../front/transformations/conjunctive_head.rs | 24 +- .../transformations/const_var_to_const.rs | 44 +- .../transformations/desugar_arg_type_anno.rs | 6 +- .../front/transformations/desugar_case_is.rs | 28 +- .../transformations/desugar_forall_exists.rs | 14 +- .../front/transformations/desugar_range.rs | 18 +- .../transformations/forall_to_not_exists.rs | 37 +- .../transformations/implies_to_disjunction.rs | 10 +- .../front/transformations/tagged_rule.rs | 4 +- core/src/compiler/ram/ast.rs | 29 +- core/src/compiler/ram/pretty.rs | 17 +- core/src/compiler/ram/ram2rs.rs | 75 +- core/src/integrate/context.rs | 36 +- core/src/integrate/interpret.rs | 4 + .../runtime/database/extensional/database.rs | 4 +- .../runtime/database/intentional/database.rs | 1 + .../runtime/dynamic/aggregator/aggregator.rs | 113 --- core/src/runtime/dynamic/aggregator/argmax.rs | 12 - core/src/runtime/dynamic/aggregator/argmin.rs | 12 - .../dynamic/aggregator/categorical_k.rs | 18 - 
core/src/runtime/dynamic/aggregator/count.rs | 18 - core/src/runtime/dynamic/aggregator/exists.rs | 12 - core/src/runtime/dynamic/aggregator/max.rs | 12 - core/src/runtime/dynamic/aggregator/min.rs | 12 - core/src/runtime/dynamic/aggregator/mod.rs | 25 - core/src/runtime/dynamic/aggregator/prod.rs | 13 - core/src/runtime/dynamic/aggregator/sum.rs | 13 - core/src/runtime/dynamic/aggregator/top_k.rs | 12 - .../dataflow/aggregation/implicit_group.rs | 24 +- .../dataflow/aggregation/join_group.rs | 20 +- .../dynamic/dataflow/aggregation/mod.rs | 2 + .../dataflow/aggregation/single_group.rs | 11 +- .../dynamic/dataflow/batching/batch.rs | 9 +- .../dynamic/dataflow/batching/batches.rs | 11 +- .../dynamic/dataflow/dynamic_dataflow.rs | 36 +- .../dynamic/dataflow/dynamic_exclusion.rs | 7 +- .../dynamic/dataflow/dynamic_relation.rs | 6 +- core/src/runtime/dynamic/dataflow/filter.rs | 12 +- core/src/runtime/dynamic/dataflow/find.rs | 28 +- .../dataflow/foreign_predicate/constraint.rs | 8 +- .../dataflow/foreign_predicate/join.rs | 8 +- .../runtime/dynamic/dataflow/overwrite_one.rs | 10 +- core/src/runtime/dynamic/dataflow/project.rs | 18 +- core/src/runtime/dynamic/incremental.rs | 4 +- core/src/runtime/dynamic/iteration.rs | 19 +- core/src/runtime/dynamic/mod.rs | 2 - core/src/runtime/dynamic/relation.rs | 24 +- core/src/runtime/env/environment.rs | 20 +- core/src/runtime/env/random.rs | 6 + core/src/runtime/monitor/debug_runtime.rs | 5 + core/src/runtime/monitor/debug_tags.rs | 5 + core/src/runtime/monitor/dump_proofs.rs | 171 +++++ core/src/runtime/monitor/dynamic_monitors.rs | 35 +- core/src/runtime/monitor/iteration_checker.rs | 5 + core/src/runtime/monitor/logging.rs | 5 + core/src/runtime/monitor/mod.rs | 4 + core/src/runtime/monitor/monitor.rs | 23 +- core/src/runtime/monitor/registry.rs | 43 ++ .../provenance/common/as_boolean_formula.rs | 47 ++ .../runtime/provenance/common/disjunction.rs | 6 +- .../runtime/provenance/common/dnf_formula.rs | 15 + 
.../provenance/common/input_tags/boolean.rs | 2 +- .../provenance/common/input_tags/float.rs | 2 +- .../common/input_tags/input_diff_prob.rs | 2 +- .../common/input_tags/input_exclusion.rs | 2 +- .../input_tags/input_exclusive_diff_prob.rs | 2 +- .../common/input_tags/input_exclusive_prob.rs | 2 +- .../provenance/common/input_tags/natural.rs | 2 +- .../provenance/common/input_tags/unit.rs | 2 +- .../differentiable/diff_add_mult_prob.rs | 72 -- .../differentiable/diff_max_mult_prob.rs | 72 -- .../differentiable/diff_min_max_prob.rs | 270 ++------ .../differentiable/diff_nand_min_prob.rs | 72 -- .../differentiable/diff_nand_mult_prob.rs | 72 -- .../diff_top_bottom_k_clauses.rs | 123 +--- .../differentiable/diff_top_k_proofs.rs | 123 +--- .../runtime/provenance/discrete/boolean.rs | 62 +- .../runtime/provenance/discrete/natural.rs | 24 - .../src/runtime/provenance/discrete/proofs.rs | 10 - core/src/runtime/provenance/discrete/unit.rs | 49 +- .../provenance/probabilistic/add_mult_prob.rs | 87 --- .../provenance/probabilistic/min_max_prob.rs | 202 +----- .../probabilistic/top_bottom_k_clauses.rs | 123 +--- .../provenance/probabilistic/top_k_proofs.rs | 124 +--- core/src/runtime/provenance/provenance.rs | 157 ----- .../runtime/statics/aggregator/aggregator.rs | 8 +- core/src/runtime/statics/aggregator/argmax.rs | 35 +- core/src/runtime/statics/aggregator/argmin.rs | 35 +- .../src/runtime/statics/aggregator/argprod.rs | 1 + core/src/runtime/statics/aggregator/argsum.rs | 1 + core/src/runtime/statics/aggregator/count.rs | 39 +- core/src/runtime/statics/aggregator/exists.rs | 40 +- core/src/runtime/statics/aggregator/max.rs | 27 +- core/src/runtime/statics/aggregator/min.rs | 27 +- core/src/runtime/statics/aggregator/mod.rs | 4 + core/src/runtime/statics/aggregator/prod.rs | 40 +- core/src/runtime/statics/aggregator/sum.rs | 40 +- core/src/runtime/statics/aggregator/top_k.rs | 16 +- .../dataflow/aggregation/implicit_group.rs | 9 +- .../dataflow/aggregation/join_group.rs | 9 
+- .../dataflow/aggregation/single_group.rs | 15 +- core/src/runtime/statics/iteration.rs | 12 +- core/src/runtime/statics/tuple.rs | 303 ++++++++- .../runtime/statics/utils/flatten_tuple.rs | 87 +++ core/src/runtime/statics/utils/mod.rs | 2 + core/src/testing/test_compile.rs | 1 + core/src/utils/float.rs | 10 + core/tests/compiler/errors.rs | 6 +- core/tests/compiler/parse.rs | 2 +- core/tests/integrate/adt.rs | 2 +- core/tests/integrate/aggregate.rs | 93 +++ core/tests/integrate/basic.rs | 59 +- core/tests/integrate/mod.rs | 2 + core/tests/integrate/sampling.rs | 19 + core/tests/runtime/dataflow/dyn_aggregate.rs | 6 +- core/tests/runtime/dataflow/dyn_difference.rs | 3 +- .../runtime/dataflow/dyn_group_aggregate.rs | 14 +- .../runtime/dataflow/dyn_group_by_key.rs | 26 +- core/tests/runtime/interpret/iteration.rs | 42 +- core/tests/runtime/provenance/prob.rs | 54 +- core/tests/runtime/provenance/top_bottom_k.rs | 4 +- core/tests/runtime/statics/iteration.rs | 83 ++- doc/src/language/adt_and_entity.md | 2 +- doc/src/language/aggregation.md | 10 +- etc/codegen/Cargo.toml | 2 +- etc/codegen/examples/digit_sum_2_codegen.rs | 2 +- etc/codegen/tests/codegen_basic.rs | 22 +- etc/scallop-cli/setup.cfg | 2 +- etc/scallop-wasm/Cargo.toml | 2 +- etc/scallop-wasm/src/lib.rs | 4 +- etc/scallopy-ext/setup.cfg | 2 +- etc/scallopy-ext/src/scallopy_ext/registry.py | 20 +- etc/scallopy-ext/src/scallopy_ext/utils.py | 2 + etc/scallopy/Cargo.toml | 2 +- etc/scallopy/scallopy/context.py | 26 +- etc/scallopy/scallopy/forward.py | 9 +- etc/scallopy/scallopy/scallopy.pyi | 4 + etc/scallopy/src/collection.rs | 4 +- etc/scallopy/src/context.rs | 32 +- etc/scallopy/src/foreign_attribute.rs | 16 +- etc/scallopy/src/foreign_function.rs | 8 +- etc/scallopy/src/foreign_predicate.rs | 6 +- etc/scallopy/src/io.rs | 15 +- etc/scallopy/src/tag.rs | 14 +- .../src/tensor/torch/external_tensor.rs | 28 +- etc/scallopy/src/tensor/torch/registry.rs | 90 ++- etc/scallopy/src/tuple.rs | 11 +- 
etc/scallopy/tests/tensors.py | 11 +- etc/sclc/Cargo.toml | 2 +- etc/sclc/src/exec.rs | 9 +- etc/sclc/src/options.rs | 5 +- etc/sclc/src/pylib.rs | 6 +- etc/scli/Cargo.toml | 2 +- etc/scli/src/main.rs | 7 +- etc/sclrepl/Cargo.toml | 2 +- etc/vscode-scl/CHANGELOG.md | 5 +- etc/vscode-scl/examples/syntax_test.scl | 2 +- etc/vscode-scl/examples/syntax_test_2.scl | 22 +- etc/vscode-scl/package.json | 2 +- .../syntaxes/scallop.tmLanguage.json | 146 ++-- examples/datalog/equality_saturation.scl | 2 +- examples/legacy/good_scl/animal.scl | 2 +- examples/legacy/good_scl/obj_color.scl | 2 +- examples/legacy/good_scl/obj_color_2.scl | 2 +- examples/legacy/good_scl/student_grade_1.scl | 2 +- examples/legacy/invalid_scl/unbound_3.scl | 2 +- lib/astnode-derive/src/lib.rs | 148 ++-- lib/parse_relative_duration/src/lib.rs | 2 +- lib/parse_relative_duration/src/parse.rs | 546 +++++++-------- lib/parse_relative_duration/tests/basics.rs | 48 +- lib/sdd/Cargo.toml | 2 +- 231 files changed, 6116 insertions(+), 3880 deletions(-) create mode 100644 core/src/common/foreign_aggregate.rs create mode 100644 core/src/common/foreign_aggregates/avg.rs create mode 100644 core/src/common/foreign_aggregates/categorical.rs create mode 100644 core/src/common/foreign_aggregates/count.rs create mode 100644 core/src/common/foreign_aggregates/exists.rs create mode 100644 core/src/common/foreign_aggregates/min_max.rs create mode 100644 core/src/common/foreign_aggregates/mod.rs create mode 100644 core/src/common/foreign_aggregates/sampler.rs create mode 100644 core/src/common/foreign_aggregates/string_join.rs create mode 100644 core/src/common/foreign_aggregates/sum_prod.rs create mode 100644 core/src/common/foreign_aggregates/top_k.rs create mode 100644 core/src/common/foreign_aggregates/uniform.rs create mode 100644 core/src/common/foreign_aggregates/weighted_sum_avg.rs create mode 100644 core/src/compiler/front/analyzers/type_inference/foreign_aggregate.rs delete mode 100644 
core/src/runtime/dynamic/aggregator/aggregator.rs delete mode 100644 core/src/runtime/dynamic/aggregator/argmax.rs delete mode 100644 core/src/runtime/dynamic/aggregator/argmin.rs delete mode 100644 core/src/runtime/dynamic/aggregator/categorical_k.rs delete mode 100644 core/src/runtime/dynamic/aggregator/count.rs delete mode 100644 core/src/runtime/dynamic/aggregator/exists.rs delete mode 100644 core/src/runtime/dynamic/aggregator/max.rs delete mode 100644 core/src/runtime/dynamic/aggregator/min.rs delete mode 100644 core/src/runtime/dynamic/aggregator/mod.rs delete mode 100644 core/src/runtime/dynamic/aggregator/prod.rs delete mode 100644 core/src/runtime/dynamic/aggregator/sum.rs delete mode 100644 core/src/runtime/dynamic/aggregator/top_k.rs create mode 100644 core/src/runtime/monitor/dump_proofs.rs create mode 100644 core/src/runtime/monitor/registry.rs create mode 100644 core/src/runtime/statics/aggregator/argprod.rs create mode 100644 core/src/runtime/statics/aggregator/argsum.rs create mode 100644 core/src/runtime/statics/utils/flatten_tuple.rs create mode 100644 core/tests/integrate/aggregate.rs create mode 100644 core/tests/integrate/sampling.rs create mode 100644 etc/scallopy-ext/src/scallopy_ext/utils.py diff --git a/.github/workflows/scallop-core.yml b/.github/workflows/scallop-core.yml index 95591ba..2fa2aea 100644 --- a/.github/workflows/scallop-core.yml +++ b/.github/workflows/scallop-core.yml @@ -3,8 +3,12 @@ name: Scallop Core on: push: branches: [ master ] + paths: + - "**.rs" pull_request: branches: [ master ] + paths: + - "**.rs" env: CARGO_TERM_COLOR: always diff --git a/.github/workflows/scallopy.yml b/.github/workflows/scallopy.yml index a7d2ae5..6be6a05 100644 --- a/.github/workflows/scallopy.yml +++ b/.github/workflows/scallopy.yml @@ -1,6 +1,16 @@ name: Scallopy -on: [push] +on: + push: + branches: [ master ] + paths: + - "**.py" + - "**.rs" + pull_request: + branches: [ master ] + paths: + - "**.py" + - "**.rs" env: SCALLOPDIR: ${{ 
github.workspace }} diff --git a/changelog.md b/changelog.md index 8c7a682..9b30f23 100644 --- a/changelog.md +++ b/changelog.md @@ -1,4 +1,18 @@ -# (Latest) v0.2.0, Jun 11, 2023 +# v0.2.2, (WIP) + +- Adding `wmc_with_disjunctions` option for provenances that deal with boolean formulas for more accurate probability estimation + +# v0.2.1, Sep 12, 2023 + +- Democratizing foreign functions and foreign predicates so that they can be implemented in Python +- Adding foreign attributes which are higher-order functions +- Adding `scallop-ext` the extension library and multiple Scallop plugins, including `scallop-gpt`, `scallop-clip`, and so on. +- Fixed multiple bugs related foreign predicate computation +- Adding `count!` aggregator for non-probabilistic operation +- Fixed sum and product aggregator so that they can accept additional argument +- Multiple bugs fixed; performance improved + +# v0.2.0, Jun 11, 2023 - Fixing CSV loading and its performance; adding new modes to specify `keys` - Adding `Symbol` type to the language diff --git a/core/Cargo.toml b/core/Cargo.toml index 63f81c2..80b419d 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "scallop-core" -version = "0.2.0" +version = "0.2.1" authors = ["Ziyang Li "] edition = "2018" @@ -27,6 +27,7 @@ dateparser = "0.1.6" dyn-clone = "1.0.10" lazy_static = "1.4" serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" parse_relative_duration = { path = "../lib/parse_relative_duration" } rand = { version = "0.8", features = ["std_rng", "small_rng", "alloc"] } astnode-derive = { path = "../lib/astnode-derive" } diff --git a/core/src/common/aggregate_op.rs b/core/src/common/aggregate_op.rs index 2213d33..ef297a5 100644 --- a/core/src/common/aggregate_op.rs +++ b/core/src/common/aggregate_op.rs @@ -6,8 +6,8 @@ use super::value_type::*; #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] pub enum AggregateOp { Count { discrete: bool }, - Sum(ValueType), - Prod(ValueType), 
+ Sum { has_arg: bool, ty: ValueType }, + Prod { has_arg: bool, ty: ValueType }, Min, Argmin, Max, @@ -20,9 +20,27 @@ pub enum AggregateOp { impl std::fmt::Display for AggregateOp { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - Self::Count { discrete } => if *discrete { f.write_str("discrete_count") } else { f.write_str("count") }, - Self::Sum(t) => f.write_fmt(format_args!("sum<{}>", t)), - Self::Prod(t) => f.write_fmt(format_args!("prod<{}>", t)), + Self::Count { discrete } => { + if *discrete { + f.write_str("discrete_count") + } else { + f.write_str("count") + } + } + Self::Sum { has_arg, ty } => { + if *has_arg { + f.write_fmt(format_args!("sum_wa<{}>", ty)) + } else { + f.write_fmt(format_args!("sum<{}>", ty)) + } + } + Self::Prod { has_arg, ty } => { + if *has_arg { + f.write_fmt(format_args!("prod_wa<{}>", ty)) + } else { + f.write_fmt(format_args!("prod<{}>", ty)) + } + } Self::Min => f.write_str("min"), Self::Max => f.write_str("max"), Self::Argmin => f.write_str("argmin"), @@ -66,4 +84,18 @@ impl AggregateOp { pub fn categorical_k(k: usize) -> Self { Self::CategoricalK(k) } + + pub fn is_min_max(&self) -> bool { + match self { + Self::Min | Self::Max | Self::Argmin | Self::Argmax => true, + _ => false, + } + } + + pub fn is_sum_prod(&self) -> bool { + match self { + Self::Sum { .. } | Self::Prod { .. 
} => true, + _ => false, + } + } } diff --git a/core/src/common/foreign_aggregate.rs b/core/src/common/foreign_aggregate.rs new file mode 100644 index 0000000..760ba6d --- /dev/null +++ b/core/src/common/foreign_aggregate.rs @@ -0,0 +1,319 @@ +use std::collections::*; + +use crate::runtime::dynamic::*; +use crate::runtime::env::RuntimeEnvironment; +use crate::runtime::provenance::*; + +use super::type_family::*; +use super::value::*; +use super::value_type::*; + +use super::foreign_aggregates::*; + +#[derive(Clone, Debug)] +pub enum GenericTypeFamily { + /// A unit tuple or a given type family + UnitOr(Box), + + /// An arbitrary tuple, can have variable lengths > 0 + NonEmptyTuple, + + /// A tuple, with elements length > 0 + NonEmptyTupleWithElements(Vec), + + /// A base value type family + TypeFamily(TypeFamily), +} + +impl GenericTypeFamily { + pub fn unit_or(g: Self) -> Self { + Self::UnitOr(Box::new(g)) + } + + pub fn type_family(t: TypeFamily) -> Self { + Self::TypeFamily(t) + } + + pub fn non_empty_tuple() -> Self { + Self::NonEmptyTuple + } + + pub fn possibly_empty_tuple() -> Self { + Self::unit_or(Self::non_empty_tuple()) + } + + pub fn non_empty_tuple_with_elements(elems: Vec) -> Self { + assert!(elems.len() > 0, "elements must be non-empty"); + Self::NonEmptyTupleWithElements(elems) + } + + pub fn possibly_empty_tuple_with_elements(elems: Vec) -> Self { + Self::unit_or(Self::non_empty_tuple_with_elements(elems)) + } + + pub fn is_type_family(&self) -> bool { + match self { + Self::TypeFamily(_) => true, + _ => false, + } + } + + pub fn as_type_family(&self) -> Option<&TypeFamily> { + match self { + Self::TypeFamily(tf) => Some(tf), + _ => None, + } + } +} + +#[derive(Clone, Debug)] +pub enum BindingTypes { + /// A tuple type; arity-0 means unit + TupleType(Vec), + + /// Depending on whether the generic type is evaluated to be unit, choose + /// between the `then_type` or the `else_type` + IfNotUnit { + generic_type: String, + then_type: Box, + else_type: 
Box, + }, +} + +impl BindingTypes { + pub fn unit() -> Self { + Self::TupleType(vec![]) + } + + pub fn generic(s: &str) -> Self { + Self::TupleType(vec![BindingType::generic(s)]) + } + + pub fn value_type(v: ValueType) -> Self { + Self::TupleType(vec![BindingType::value_type(v)]) + } + + pub fn if_not_unit(t: &str, t1: Self, t2: Self) -> Self { + Self::IfNotUnit { + generic_type: t.to_string(), + then_type: Box::new(t1), + else_type: Box::new(t2), + } + } + + pub fn empty_tuple() -> Self { + Self::tuple(vec![]) + } + + pub fn tuple(elems: Vec) -> Self { + Self::TupleType(elems) + } +} + +#[derive(Clone, Debug)] +pub enum BindingType { + /// A generic type + Generic(String), + + /// A single value type + ValueType(ValueType), +} + +impl BindingType { + pub fn generic(s: &str) -> Self { + Self::Generic(s.to_string()) + } + + pub fn value_type(value_type: ValueType) -> Self { + Self::ValueType(value_type) + } + + pub fn is_generic(&self) -> bool { + match self { + Self::Generic(_) => false, + _ => true, + } + } + + pub fn as_value_type(&self) -> Option<&ValueType> { + match self { + Self::ValueType(vt) => Some(vt), + _ => None, + } + } +} + +#[derive(Clone, Debug)] +pub enum ParamType { + Mandatory(ValueType), + Optional(ValueType), +} + +/// The type of an aggregator +/// +/// ``` ignore +/// OUTPUT_TYPE := AGGREGATE[ARG_TYPE](INPUT_TYPE) +/// ``` +#[derive(Clone, Debug)] +pub struct AggregateType { + pub generics: HashMap, + pub param_types: Vec, + pub arg_type: BindingTypes, + pub input_type: BindingTypes, + pub output_type: BindingTypes, + pub allow_exclamation_mark: bool, +} + +pub trait Aggregate: Into { + /// The concrete aggregator that this aggregate is instantiated into + type Aggregator: Aggregator

; + + /// The name of the aggregate + fn name(&self) -> String; + + /// The type of the aggregate + fn aggregate_type(&self) -> AggregateType; + + /// Instantiate the aggregate into an aggregator with the given parameters + fn instantiate( + &self, + params: Vec, + has_exclamation_mark: bool, + arg_types: Vec, + input_types: Vec, + ) -> Self::Aggregator

; +} + +/// A dynamic aggregate kind +#[derive(Clone)] +pub enum DynamicAggregate { + Avg(AvgAggregate), + Count(CountAggregate), + Exists(ExistsAggregate), + MinMax(MinMaxAggregate), + Sampler(DynamicSampleAggregate), + StringJoin(StringJoinAggregate), + SumProd(SumProdAggregate), + WeightedSumAvg(WeightedSumAvgAggregate), +} + +macro_rules! match_aggregate { + ($a: expr, $v:ident, $e:expr) => { + match $a { + DynamicAggregate::Avg($v) => $e, + DynamicAggregate::Count($v) => $e, + DynamicAggregate::MinMax($v) => $e, + DynamicAggregate::SumProd($v) => $e, + DynamicAggregate::StringJoin($v) => $e, + DynamicAggregate::WeightedSumAvg($v) => $e, + DynamicAggregate::Exists($v) => $e, + DynamicAggregate::Sampler($v) => $e, + } + }; +} + +impl DynamicAggregate { + pub fn name(&self) -> String { + match_aggregate!(self, a, a.name()) + } + + pub fn aggregate_type(&self) -> AggregateType { + match_aggregate!(self, a, a.aggregate_type()) + } +} + +impl std::fmt::Debug for DynamicAggregate { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match_aggregate!(self, a, f.write_str(&a.name())) + } +} + +/// The registry of aggregates +#[derive(Clone, Debug)] +pub struct AggregateRegistry { + registry: HashMap, +} + +impl AggregateRegistry { + /// Create an empty registry + pub fn new() -> Self { + Self { + registry: HashMap::new(), + } + } + + pub fn std() -> Self { + let mut registry = Self::new(); + + // Register + registry.register(CountAggregate); + registry.register(MinMaxAggregate::min()); + registry.register(MinMaxAggregate::max()); + registry.register(MinMaxAggregate::argmin()); + registry.register(MinMaxAggregate::argmax()); + registry.register(SumProdAggregate::sum()); + registry.register(SumProdAggregate::prod()); + registry.register(AvgAggregate::new()); + registry.register(WeightedSumAvgAggregate::weighted_sum()); + registry.register(WeightedSumAvgAggregate::weighted_avg()); + registry.register(ExistsAggregate::new()); + 
registry.register(StringJoinAggregate::new()); + registry.register(DynamicSampleAggregate::new(TopKSamplerAggregate::top())); + registry.register(DynamicSampleAggregate::new(TopKSamplerAggregate::unique())); + registry.register(DynamicSampleAggregate::new(CategoricalAggregate::new())); + registry.register(DynamicSampleAggregate::new(UniformAggregate::new())); + + // Return + registry + } + + /// Register an aggregate into the registry + pub fn register(&mut self, agg: A) { + self.registry.entry(agg.name()).or_insert(agg.into()); + } + + pub fn iter(&self) -> impl Iterator { + self.registry.iter() + } + + pub fn instantiate_aggregator( + &self, + name: &str, + params: Vec, + has_exclamation_mark: bool, + arg_types: Vec, + input_types: Vec, + ) -> Option> { + if let Some(aggregate) = self.registry.get(name) { + match_aggregate!(aggregate, a, { + let instantiated = a.instantiate::

(params, has_exclamation_mark, arg_types, input_types); + Some(DynamicAggregator(Box::new(instantiated))) + }) + } else { + None + } + } +} + +pub trait Aggregator: dyn_clone::DynClone + 'static { + fn aggregate(&self, p: &P, env: &RuntimeEnvironment, elems: DynamicElements

) -> DynamicElements

; +} + +pub struct DynamicAggregator(Box>); + +impl Clone for DynamicAggregator

{ + fn clone(&self) -> Self { + Self(dyn_clone::clone_box(&*self.0)) + } +} + +impl DynamicAggregator { + pub fn aggregate( + &self, + prov: &Prov, + env: &RuntimeEnvironment, + elems: DynamicElements, + ) -> DynamicElements { + self.0.aggregate(prov, env, elems) + } +} diff --git a/core/src/common/foreign_aggregates/avg.rs b/core/src/common/foreign_aggregates/avg.rs new file mode 100644 index 0000000..d487502 --- /dev/null +++ b/core/src/common/foreign_aggregates/avg.rs @@ -0,0 +1,122 @@ +use itertools::Itertools; + +use crate::common::tuple::*; +use crate::common::type_family::*; +use crate::common::value::*; +use crate::common::value_type::*; +use crate::runtime::dynamic::*; +use crate::runtime::env::*; +use crate::runtime::provenance::*; + +use super::*; + +#[derive(Clone)] +pub struct AvgAggregate; + +impl AvgAggregate { + pub fn new() -> Self { + Self + } +} + +impl Into for AvgAggregate { + fn into(self) -> DynamicAggregate { + DynamicAggregate::Avg(self) + } +} + +impl Aggregate for AvgAggregate { + type Aggregator = AvgAggregator; + + fn name(&self) -> String { + "avg".to_string() + } + + fn aggregate_type(&self) -> AggregateType { + AggregateType { + generics: vec![ + ("A".to_string(), GenericTypeFamily::possibly_empty_tuple()), + ("T".to_string(), GenericTypeFamily::type_family(TypeFamily::Number)), + ] + .into_iter() + .collect(), + param_types: vec![], + arg_type: BindingTypes::generic("A"), + input_type: BindingTypes::generic("T"), + output_type: BindingTypes::generic("T"), + allow_exclamation_mark: true, + } + } + + fn instantiate( + &self, + _params: Vec, + has_exclamation_mark: bool, + arg_types: Vec, + input_types: Vec, + ) -> Self::Aggregator

{ + AvgAggregator { + non_multi_world: has_exclamation_mark, + num_args: arg_types.len(), + value_type: input_types[0].clone(), + } + } +} + +#[derive(Clone)] +pub struct AvgAggregator { + non_multi_world: bool, + num_args: usize, + value_type: ValueType, +} + +impl AvgAggregator { + pub fn avg(non_multi_world: bool) -> Self + where + ValueType: FromType, + { + Self { + non_multi_world, + num_args: 0, + value_type: ValueType::from_type(), + } + } +} + +impl AvgAggregator { + pub fn perform_avg<'a, I: Iterator>(&self, i: I) -> Tuple { + if self.num_args > 0 { + let iterator = i.map(|t| &t[self.num_args]); + self.value_type.avg(iterator) + } else { + self.value_type.avg(i) + } + } +} + +impl Aggregator

for AvgAggregator { + default fn aggregate(&self, prov: &P, _env: &RuntimeEnvironment, batch: DynamicElements

) -> DynamicElements

{ + if self.non_multi_world { + let res = self.perform_avg(batch.iter_tuples()); + vec![DynamicElement::new(res, prov.one())] + } else { + let mut result = vec![]; + for chosen_set in (0..batch.len()).powerset() { + let res = self.perform_avg(chosen_set.iter().map(|i| &batch[*i].tuple)); + let maybe_tag = batch.iter().enumerate().fold(Some(prov.one()), |maybe_acc, (i, elem)| { + maybe_acc.and_then(|acc| { + if chosen_set.contains(&i) { + Some(prov.mult(&acc, &elem.tag)) + } else { + prov.negate(&elem.tag).map(|neg_tag| prov.mult(&acc, &neg_tag)) + } + }) + }); + if let Some(tag) = maybe_tag { + result.push(DynamicElement::new(res, tag)); + } + } + result + } + } +} diff --git a/core/src/common/foreign_aggregates/categorical.rs b/core/src/common/foreign_aggregates/categorical.rs new file mode 100644 index 0000000..35bbfc5 --- /dev/null +++ b/core/src/common/foreign_aggregates/categorical.rs @@ -0,0 +1,77 @@ +use std::collections::*; + +use rand::distributions::WeightedIndex; + +use crate::common::value::*; +use crate::common::value_type::*; +use crate::runtime::env::*; + +use super::*; + +#[derive(Clone)] +pub struct CategoricalAggregate; + +impl CategoricalAggregate { + pub fn new() -> Self { + Self + } +} + +impl Into for CategoricalAggregate { + fn into(self) -> DynamicAggregate { + DynamicAggregate::Sampler(DynamicSampleAggregate::new(self)) + } +} + +impl SampleAggregate for CategoricalAggregate { + fn name(&self) -> String { + "uniform".to_string() + } + + fn param_types(&self) -> Vec { + vec![ParamType::Mandatory(ValueType::USize)] + } + + fn instantiate( + &self, + params: Vec, + _has_exclamation_mark: bool, + _arg_types: Vec, + _input_types: Vec, + ) -> DynamicSampler { + CategoricalSampler { + k: params[0].as_usize(), + } + .into() + } +} + +#[derive(Clone)] +pub struct CategoricalSampler { + k: usize, +} + +impl Into for CategoricalSampler { + fn into(self) -> DynamicSampler { + DynamicSampler::new(self) + } +} + +impl Sampler for CategoricalSampler { + 
fn sampler_type(&self) -> SamplerType { + SamplerType::WeightOnly + } + + fn sample_weight_only(&self, rt: &RuntimeEnvironment, weights: Vec) -> Vec { + if weights.len() <= self.k { + (0..weights.len()).collect() + } else { + let dist = WeightedIndex::new(&weights).unwrap(); + (0..self.k) + .map(|_| rt.random.sample_from(&dist)) + .collect::>() + .into_iter() + .collect() + } + } +} diff --git a/core/src/common/foreign_aggregates/count.rs b/core/src/common/foreign_aggregates/count.rs new file mode 100644 index 0000000..438e87d --- /dev/null +++ b/core/src/common/foreign_aggregates/count.rs @@ -0,0 +1,96 @@ +use itertools::Itertools; + +use crate::common::value::*; +use crate::common::value_type::*; +use crate::runtime::dynamic::*; +use crate::runtime::env::*; +use crate::runtime::provenance::*; + +use super::*; + +#[derive(Clone)] +pub struct CountAggregate; + +impl Into for CountAggregate { + fn into(self) -> DynamicAggregate { + DynamicAggregate::Count(self) + } +} + +impl Aggregate for CountAggregate { + type Aggregator = CountAggregator; + + fn name(&self) -> String { + "count".to_string() + } + + /// `{T: non-empty-tuple} ==> usize := count(T)` + fn aggregate_type(&self) -> AggregateType { + AggregateType { + generics: std::iter::once(("T".to_string(), GenericTypeFamily::non_empty_tuple())).collect(), + param_types: vec![], + arg_type: BindingTypes::unit(), + input_type: BindingTypes::generic("T"), + output_type: BindingTypes::value_type(ValueType::USize), + allow_exclamation_mark: true, + } + } + + fn instantiate( + &self, + _params: Vec, + has_exlamation_mark: bool, + _arg_types: Vec, + _input_types: Vec, + ) -> Self::Aggregator

{ + CountAggregator { + non_multi_world: has_exlamation_mark, + } + } +} + +#[derive(Clone)] +pub struct CountAggregator { + pub non_multi_world: bool, +} + +impl CountAggregator { + pub fn new(non_multi_world: bool) -> Self { + Self { non_multi_world } + } +} + +impl Aggregator for CountAggregator { + default fn aggregate( + &self, + prov: &Prov, + _env: &RuntimeEnvironment, + batch: DynamicElements, + ) -> DynamicElements { + if self.non_multi_world { + vec![DynamicElement::new(batch.len(), prov.one())] + } else { + let mut result = vec![]; + if batch.is_empty() { + result.push(DynamicElement::new(0usize, prov.one())); + } else { + for chosen_set in (0..batch.len()).powerset() { + let count = chosen_set.len(); + let maybe_tag = batch.iter().enumerate().fold(Some(prov.one()), |maybe_acc, (i, elem)| { + maybe_acc.and_then(|acc| { + if chosen_set.contains(&i) { + Some(prov.mult(&acc, &elem.tag)) + } else { + prov.negate(&elem.tag).map(|neg_tag| prov.mult(&acc, &neg_tag)) + } + }) + }); + if let Some(tag) = maybe_tag { + result.push(DynamicElement::new(count, tag)); + } + } + } + result + } + } +} diff --git a/core/src/common/foreign_aggregates/exists.rs b/core/src/common/foreign_aggregates/exists.rs new file mode 100644 index 0000000..eea5894 --- /dev/null +++ b/core/src/common/foreign_aggregates/exists.rs @@ -0,0 +1,99 @@ +use crate::common::value_type::*; +use crate::runtime::dynamic::*; +use crate::runtime::env::*; +use crate::runtime::provenance::*; + +use super::*; + +#[derive(Clone)] +pub struct ExistsAggregate; + +impl ExistsAggregate { + pub fn new() -> Self { + Self + } +} + +impl Into for ExistsAggregate { + fn into(self) -> DynamicAggregate { + DynamicAggregate::Exists(self) + } +} + +impl Aggregate for ExistsAggregate { + type Aggregator = ExistsAggregator; + + fn name(&self) -> String { + "exists".to_string() + } + + fn aggregate_type(&self) -> AggregateType { + AggregateType { + generics: vec![("T".to_string(), 
GenericTypeFamily::possibly_empty_tuple())] + .into_iter() + .collect(), + param_types: vec![], + arg_type: BindingTypes::empty_tuple(), + input_type: BindingTypes::generic("T"), + output_type: BindingTypes::value_type(ValueType::Bool), + allow_exclamation_mark: true, + } + } + + fn instantiate( + &self, + _params: Vec, + has_exclamation_mark: bool, + _arg_types: Vec, + _input_types: Vec, + ) -> Self::Aggregator

{ + ExistsAggregator { + non_multi_world: has_exclamation_mark, + } + } +} + +#[derive(Clone)] +pub struct ExistsAggregator { + pub non_multi_world: bool, +} + +impl ExistsAggregator { + pub fn new(non_multi_world: bool) -> Self { + Self { non_multi_world } + } +} + +impl Aggregator for ExistsAggregator { + default fn aggregate( + &self, + prov: &Prov, + _env: &RuntimeEnvironment, + batch: DynamicElements, + ) -> DynamicElements { + if self.non_multi_world { + vec![DynamicElement::new(!batch.is_empty(), prov.one())] + } else { + let mut maybe_exists_tag = None; + let mut maybe_not_exists_tag = Some(prov.one()); + for elem in batch { + maybe_exists_tag = match maybe_exists_tag { + Some(exists_tag) => Some(prov.add(&exists_tag, &elem.tag)), + None => Some(elem.tag.clone()), + }; + maybe_not_exists_tag = match maybe_not_exists_tag { + Some(net) => prov + .negate(&elem.tag) + .map(|neg_elem_tag| prov.mult(&net, &neg_elem_tag)), + None => None, + }; + } + + maybe_exists_tag + .into_iter() + .map(|t| DynamicElement::new(true, t)) + .chain(maybe_not_exists_tag.into_iter().map(|t| DynamicElement::new(false, t))) + .collect() + } + } +} diff --git a/core/src/common/foreign_aggregates/min_max.rs b/core/src/common/foreign_aggregates/min_max.rs new file mode 100644 index 0000000..d7cfbe1 --- /dev/null +++ b/core/src/common/foreign_aggregates/min_max.rs @@ -0,0 +1,279 @@ +use crate::common::tuple::*; +use crate::common::tuples::*; +use crate::common::value_type::*; +use crate::runtime::dynamic::*; +use crate::runtime::env::*; +use crate::runtime::provenance::*; + +use super::*; + +#[derive(Clone)] +pub struct MinMaxAggregate { + is_min: bool, + arg_only: bool, +} + +impl MinMaxAggregate { + pub fn min() -> Self { + Self { + is_min: true, + arg_only: false, + } + } + + pub fn max() -> Self { + Self { + is_min: false, + arg_only: false, + } + } + + pub fn argmin() -> Self { + Self { + is_min: true, + arg_only: true, + } + } + + pub fn argmax() -> Self { + Self { + is_min: false, 
+ arg_only: true, + } + } +} + +impl Into for MinMaxAggregate { + fn into(self) -> DynamicAggregate { + DynamicAggregate::MinMax(self) + } +} + +impl Aggregate for MinMaxAggregate { + type Aggregator = MinMaxAggregator; + + fn name(&self) -> String { + if self.is_min { + if self.arg_only { + "argmin".to_string() + } else { + "min".to_string() + } + } else { + if self.arg_only { + "argmax".to_string() + } else { + "max".to_string() + } + } + } + + /// `{A: Tuple?, T: Tuple} ==> (A, T) := min[A](T)` + /// `{A: Tuple?, T: Tuple} ==> (A, T) := max[A](T)` + /// `{A: Tuple, T: Tuple} ==> A := argmin[A](T)` + /// `{A: Tuple, T: Tuple} ==> A := argmax[A](T)` + fn aggregate_type(&self) -> AggregateType { + if self.arg_only { + AggregateType { + generics: vec![ + ("A".to_string(), GenericTypeFamily::non_empty_tuple()), + ("T".to_string(), GenericTypeFamily::non_empty_tuple()), + ] + .into_iter() + .collect(), + param_types: vec![], + arg_type: BindingTypes::generic("A"), + input_type: BindingTypes::generic("T"), + output_type: BindingTypes::generic("A"), + allow_exclamation_mark: true, + } + } else { + AggregateType { + generics: vec![ + ("A".to_string(), GenericTypeFamily::possibly_empty_tuple()), + ("T".to_string(), GenericTypeFamily::non_empty_tuple()), + ] + .into_iter() + .collect(), + param_types: vec![], + arg_type: BindingTypes::generic("A"), + input_type: BindingTypes::generic("T"), + output_type: BindingTypes::if_not_unit( + "A", + BindingTypes::tuple(vec![BindingType::generic("A"), BindingType::generic("T")]), + BindingTypes::generic("T"), + ), + allow_exclamation_mark: true, + } + } + } + + fn instantiate( + &self, + _params: Vec, + has_exclamation_mark: bool, + arg_types: Vec, + _input_types: Vec, + ) -> Self::Aggregator

{ + MinMaxAggregator { + is_min: self.is_min, + arg_only: self.arg_only, + non_multi_world: has_exclamation_mark, + num_args: arg_types.len(), + } + } +} + +#[derive(Clone)] +pub struct MinMaxAggregator { + pub is_min: bool, + pub arg_only: bool, + pub non_multi_world: bool, + pub num_args: usize, +} + +impl MinMaxAggregator { + pub fn min() -> Self { + Self { + is_min: true, + arg_only: false, + non_multi_world: false, + num_args: 0, + } + } + + pub fn max() -> Self { + Self { + is_min: false, + arg_only: false, + non_multi_world: false, + num_args: 0, + } + } + + pub fn argmin(num_args: usize) -> Self { + Self { + is_min: true, + arg_only: true, + non_multi_world: false, + num_args, + } + } + + pub fn argmax(num_args: usize) -> Self { + Self { + is_min: false, + arg_only: true, + non_multi_world: false, + num_args, + } + } +} + +impl MinMaxAggregator { + pub fn post_process_arg>>(&self, elems: I) -> DynamicElements

{ + elems + .map(|e| { + // Depending on whether we only need the argument, return only the arg part of the results + if self.arg_only { + let tuple = if self.num_args == 1 { + e.tuple[0].clone() + } else { + Tuple::tuple(e.tuple[..self.num_args].iter().cloned()) + }; + DynamicElement::new(tuple, e.tag) + } else { + DynamicElement::new(e.tuple, e.tag) + } + }) + .collect() + } + + pub fn discrete_min_max<'a, P: Provenance, I: Iterator>( + &self, + p: &P, + tuple_iter: I, + ) -> DynamicElements

{ + let min_max = if self.is_min { + if self.num_args > 0 { + tuple_iter.arg_minimum(self.num_args) + } else { + tuple_iter.minimum() + } + } else { + if self.num_args > 0 { + tuple_iter.arg_maximum(self.num_args) + } else { + tuple_iter.maximum() + } + }; + self.post_process_arg(min_max.into_iter().map(|t| DynamicElement::new(t.clone(), p.one()))) + } + + pub fn multi_world_min_max(&self, prov: &P, mut batch: DynamicElements

) -> DynamicElements

{ + // First, we sort all the tuples + batch.sort_by_key(|e| e.tuple[self.num_args..].iter().cloned().collect::>()); + let tagged_tuples = batch; + + // Then we compute the strata + // + // For example, for the array [0, 0, 0, 1, 1, 2 , 3, 3], + // We have 4 strata ---1--- --2-- -3- --4-- + // they are represented by their start index, [0, 3, 5, 6] + let strata = { + let (mut maybe_curr_elem, mut strata) = (None, vec![0]); + for (i, tagged_tuple) in tagged_tuples.iter().enumerate() { + if let Some(curr_elem) = maybe_curr_elem { + if &tagged_tuple.tuple[self.num_args..] > curr_elem { + strata.push(i); + } + } else { + maybe_curr_elem = Some(&tagged_tuple.tuple[self.num_args..]); + } + } + strata + }; + + // We now compute the tags + // + // Let's take `minimum` as an example. Suppose we are in a stratum. + // For an element inside of this stratum to be the minimum of the whole batch, + // It has to be the case that all the element before this stratum are FALSE (a.k.a. have a negated tag) + // the elements in and after this stratum do not matter, + let mut result = vec![]; + for (i, stratum_start) in strata.iter().copied().enumerate() { + let stratum_end = if i + 1 < strata.len() { + strata[i + 1] + } else { + tagged_tuples.len() + }; + let false_range = if self.is_min { + 0..stratum_start + } else { + stratum_end..tagged_tuples.len() + }; + let maybe_false_tag = false_range.fold(Some(prov.one()), |maybe_acc, j| { + maybe_acc.and_then(|acc| prov.negate(&tagged_tuples[j].tag).map(|neg| prov.mult(&acc, &neg))) + }); + if let Some(false_tag) = maybe_false_tag { + for j in stratum_start..stratum_end { + let and_true_tag: P::Tag = prov.mult(&false_tag, &tagged_tuples[j].tag); + result.push(DynamicElement::

::new(tagged_tuples[j].tuple.clone(), and_true_tag)); + } + } + } + + // Depending on whether we only need the argument, return only the arg part of the results + self.post_process_arg(result.into_iter()) + } +} + +impl Aggregator

for MinMaxAggregator { + default fn aggregate(&self, p: &P, _env: &RuntimeEnvironment, batch: DynamicElements

) -> DynamicElements

{ + if self.non_multi_world { + self.discrete_min_max(p, batch.iter_tuples()) + } else { + self.multi_world_min_max(p, batch) + } + } +} diff --git a/core/src/common/foreign_aggregates/mod.rs b/core/src/common/foreign_aggregates/mod.rs new file mode 100644 index 0000000..2bbce7c --- /dev/null +++ b/core/src/common/foreign_aggregates/mod.rs @@ -0,0 +1,25 @@ +mod avg; +mod categorical; +mod count; +mod exists; +mod min_max; +mod sampler; +mod string_join; +mod sum_prod; +mod top_k; +mod uniform; +mod weighted_sum_avg; + +pub use avg::*; +pub use categorical::*; +pub use count::*; +pub use exists::*; +pub use min_max::*; +pub use sampler::*; +pub use string_join::*; +pub use sum_prod::*; +pub use top_k::*; +pub use uniform::*; +pub use weighted_sum_avg::*; + +use super::foreign_aggregate::*; diff --git a/core/src/common/foreign_aggregates/sampler.rs b/core/src/common/foreign_aggregates/sampler.rs new file mode 100644 index 0000000..493d56e --- /dev/null +++ b/core/src/common/foreign_aggregates/sampler.rs @@ -0,0 +1,139 @@ +use crate::common::tuple::*; +use crate::common::value::*; +use crate::common::value_type::*; +use crate::runtime::dynamic::*; +use crate::runtime::env::*; +use crate::runtime::provenance::*; + +use super::*; + +pub trait SampleAggregate: dyn_clone::DynClone + 'static { + fn name(&self) -> String; + + fn param_types(&self) -> Vec; + + fn instantiate( + &self, + params: Vec, + has_exclamation_mark: bool, + arg_types: Vec, + input_types: Vec, + ) -> DynamicSampler; +} + +pub struct DynamicSampleAggregate(Box); + +unsafe impl Send for DynamicSampleAggregate {} +unsafe impl Sync for DynamicSampleAggregate {} + +impl DynamicSampleAggregate { + pub fn new(t: T) -> Self { + Self(Box::new(t)) + } +} + +impl Clone for DynamicSampleAggregate { + fn clone(&self) -> Self { + Self(dyn_clone::clone_box(&*self.0)) + } +} + +impl Into for DynamicSampleAggregate { + fn into(self) -> DynamicAggregate { + DynamicAggregate::Sampler(self) + } +} + +impl Aggregate for 
DynamicSampleAggregate { + type Aggregator = DynamicSampler; + + fn name(&self) -> String { + SampleAggregate::name(&*self.0) + } + + fn aggregate_type(&self) -> AggregateType { + AggregateType { + generics: vec![("T".to_string(), GenericTypeFamily::non_empty_tuple())] + .into_iter() + .collect(), + param_types: SampleAggregate::param_types(&*self.0), + arg_type: BindingTypes::unit(), + input_type: BindingTypes::generic("T"), + output_type: BindingTypes::generic("T"), + allow_exclamation_mark: false, + } + } + + fn instantiate( + &self, + params: Vec, + has_exclamation_mark: bool, + arg_types: Vec, + input_types: Vec, + ) -> Self::Aggregator

{ + SampleAggregate::instantiate(&*self.0, params, has_exclamation_mark, arg_types, input_types) + } +} + +pub enum SamplerType { + LengthOnly, + WeightOnly, + TupleOnly, + WeightAndTuple, +} + +#[allow(unused)] +pub trait Sampler: dyn_clone::DynClone + 'static { + fn sampler_type(&self) -> SamplerType; + + fn sample_length_only(&self, env: &RuntimeEnvironment, len: usize) -> Vec { + vec![] + } + + fn sample_weight_only(&self, env: &RuntimeEnvironment, elems: Vec) -> Vec { + vec![] + } + + fn sample_tuple_only(&self, env: &RuntimeEnvironment, elems: Vec<&Tuple>) -> Vec { + vec![] + } + + fn sample_weight_and_tuple(&self, env: &RuntimeEnvironment, elems: Vec<(f64, &Tuple)>) -> Vec { + vec![] + } +} + +pub struct DynamicSampler(Box); + +impl DynamicSampler { + pub fn new(t: T) -> Self { + Self(Box::new(t)) + } +} + +impl Clone for DynamicSampler { + fn clone(&self) -> Self { + Self(dyn_clone::clone_box(&*self.0)) + } +} + +impl Aggregator

for DynamicSampler { + fn aggregate(&self, p: &P, env: &RuntimeEnvironment, elems: DynamicElements

) -> DynamicElements

{ + let indices = match self.0.sampler_type() { + SamplerType::LengthOnly => self.0.sample_length_only(env, elems.len()), + SamplerType::TupleOnly => { + let to_sample = elems.iter().map(|e| &e.tuple).collect(); + self.0.sample_tuple_only(env, to_sample) + } + SamplerType::WeightOnly => { + let to_sample = elems.iter().map(|e| p.weight(&e.tag)).collect(); + self.0.sample_weight_only(env, to_sample) + } + SamplerType::WeightAndTuple => { + let to_sample = elems.iter().map(|e| (p.weight(&e.tag), &e.tuple)).collect(); + self.0.sample_weight_and_tuple(env, to_sample) + } + }; + indices.into_iter().map(|i| elems[i].clone()).collect() + } +} diff --git a/core/src/common/foreign_aggregates/string_join.rs b/core/src/common/foreign_aggregates/string_join.rs new file mode 100644 index 0000000..6302d6d --- /dev/null +++ b/core/src/common/foreign_aggregates/string_join.rs @@ -0,0 +1,107 @@ +use itertools::Itertools; + +use crate::common::tuple::*; +use crate::common::value::*; +use crate::common::value_type::*; +use crate::runtime::dynamic::*; +use crate::runtime::env::*; +use crate::runtime::provenance::*; + +use super::*; + +#[derive(Clone)] +pub struct StringJoinAggregate; + +impl StringJoinAggregate { + pub fn new() -> Self { + Self + } +} + +impl Into for StringJoinAggregate { + fn into(self) -> DynamicAggregate { + DynamicAggregate::StringJoin(self) + } +} + +impl Aggregate for StringJoinAggregate { + type Aggregator = StringJoinAggregator; + + fn name(&self) -> String { + "string_join".to_string() + } + + fn aggregate_type(&self) -> AggregateType { + AggregateType { + generics: vec![ + ("A".to_string(), GenericTypeFamily::possibly_empty_tuple()), + ] + .into_iter() + .collect(), + param_types: vec![ParamType::Optional(ValueType::String)], + arg_type: BindingTypes::generic("A"), + input_type: BindingTypes::value_type(ValueType::String), + output_type: BindingTypes::value_type(ValueType::String), + allow_exclamation_mark: true, + } + } + + fn instantiate( + &self, + 
params: Vec, + has_exclamation_mark: bool, + arg_types: Vec, + _input_types: Vec, + ) -> Self::Aggregator

{ + StringJoinAggregator { + non_multi_world: has_exclamation_mark, + num_args: arg_types.len(), + separator: params.get(0).map(|v| v.as_str().to_string()).unwrap_or("".to_string()), + } + } +} + +#[derive(Clone)] +pub struct StringJoinAggregator { + non_multi_world: bool, + num_args: usize, + separator: String, +} + +impl StringJoinAggregator { + pub fn perform_string_join<'a, I: Iterator>(&self, i: I) -> Tuple { + let strings: Vec<_> = if self.num_args > 0 { + i.map(|t| t[self.num_args].as_string()).collect() + } else { + i.map(|t| t.as_string()).collect() + }; + strings.join(&self.separator).into() + } +} + +impl Aggregator

for StringJoinAggregator { + default fn aggregate(&self, prov: &P, _env: &RuntimeEnvironment, batch: DynamicElements

) -> DynamicElements

{ + if self.non_multi_world { + let res = self.perform_string_join(batch.iter_tuples()); + vec![DynamicElement::new(res, prov.one())] + } else { + let mut result = vec![]; + for chosen_set in (0..batch.len()).powerset() { + let res = self.perform_string_join(chosen_set.iter().map(|i| &batch[*i].tuple)); + let maybe_tag = batch.iter().enumerate().fold(Some(prov.one()), |maybe_acc, (i, elem)| { + maybe_acc.and_then(|acc| { + if chosen_set.contains(&i) { + Some(prov.mult(&acc, &elem.tag)) + } else { + prov.negate(&elem.tag).map(|neg_tag| prov.mult(&acc, &neg_tag)) + } + }) + }); + if let Some(tag) = maybe_tag { + result.push(DynamicElement::new(res, tag)); + } + } + result + } + } +} diff --git a/core/src/common/foreign_aggregates/sum_prod.rs b/core/src/common/foreign_aggregates/sum_prod.rs new file mode 100644 index 0000000..d4496c8 --- /dev/null +++ b/core/src/common/foreign_aggregates/sum_prod.rs @@ -0,0 +1,159 @@ +use itertools::Itertools; + +use crate::common::tuple::*; +use crate::common::type_family::*; +use crate::common::value::*; +use crate::common::value_type::*; +use crate::runtime::dynamic::*; +use crate::runtime::env::*; +use crate::runtime::provenance::*; + +use super::*; + +#[derive(Clone)] +pub struct SumProdAggregate { + is_sum: bool, +} + +impl SumProdAggregate { + pub fn sum() -> Self { + Self { is_sum: true } + } + + pub fn prod() -> Self { + Self { is_sum: false } + } +} + +impl Into for SumProdAggregate { + fn into(self) -> DynamicAggregate { + DynamicAggregate::SumProd(self) + } +} + +impl Aggregate for SumProdAggregate { + type Aggregator = SumProdAggregator; + + fn name(&self) -> String { + if self.is_sum { + "sum".to_string() + } else { + "prod".to_string() + } + } + + fn aggregate_type(&self) -> AggregateType { + AggregateType { + generics: vec![ + ("A".to_string(), GenericTypeFamily::possibly_empty_tuple()), + ("T".to_string(), GenericTypeFamily::type_family(TypeFamily::Number)), + ] + .into_iter() + .collect(), + param_types: vec![], + 
arg_type: BindingTypes::generic("A"), + input_type: BindingTypes::generic("T"), + output_type: BindingTypes::generic("T"), + allow_exclamation_mark: true, + } + } + + fn instantiate( + &self, + _params: Vec, + has_exclamation_mark: bool, + arg_types: Vec, + input_types: Vec, + ) -> Self::Aggregator

{ + assert!( + input_types.len() == 1, + "sum/prod aggregate should take in argument of only size 1" + ); + SumProdAggregator { + is_sum: self.is_sum, + non_multi_world: has_exclamation_mark, + num_args: arg_types.len(), + value_type: input_types[0].clone(), + } + } +} + +#[derive(Clone)] +pub struct SumProdAggregator { + is_sum: bool, + non_multi_world: bool, + num_args: usize, + value_type: ValueType, +} + +impl SumProdAggregator { + pub fn sum(non_multi_world: bool) -> Self + where + ValueType: FromType, + { + Self { + is_sum: true, + non_multi_world, + num_args: 0, + value_type: ValueType::from_type(), + } + } + + pub fn prod(non_multi_world: bool) -> Self + where + ValueType: FromType, + { + Self { + is_sum: false, + non_multi_world, + num_args: 0, + value_type: ValueType::from_type(), + } + } +} + +impl SumProdAggregator { + pub fn perform_sum_prod<'a, I: Iterator>(&self, i: I) -> Tuple { + if self.num_args > 0 { + let iterator = i.map(|t| &t[self.num_args]); + if self.is_sum { + self.value_type.sum(iterator) + } else { + self.value_type.prod(iterator) + } + } else { + if self.is_sum { + self.value_type.sum(i) + } else { + self.value_type.prod(i) + } + } + } +} + +impl Aggregator

for SumProdAggregator { + default fn aggregate(&self, prov: &P, _env: &RuntimeEnvironment, batch: DynamicElements

) -> DynamicElements

{ + if self.non_multi_world { + let res = self.perform_sum_prod(batch.iter_tuples()); + vec![DynamicElement::new(res, prov.one())] + } else { + let mut result = vec![]; + for chosen_set in (0..batch.len()).powerset() { + let res = self.perform_sum_prod(chosen_set.iter().map(|i| &batch[*i].tuple)); + let maybe_tag = batch.iter().enumerate().fold(Some(prov.one()), |maybe_acc, (i, elem)| { + maybe_acc.and_then(|acc| { + if chosen_set.contains(&i) { + Some(prov.mult(&acc, &elem.tag)) + } else { + prov.negate(&elem.tag).map(|neg_tag| prov.mult(&acc, &neg_tag)) + } + }) + }); + if let Some(tag) = maybe_tag { + result.push(DynamicElement::new(res, tag)); + } + } + result + } + } +} diff --git a/core/src/common/foreign_aggregates/top_k.rs b/core/src/common/foreign_aggregates/top_k.rs new file mode 100644 index 0000000..ebbfe56 --- /dev/null +++ b/core/src/common/foreign_aggregates/top_k.rs @@ -0,0 +1,158 @@ +use std::collections::*; + +use crate::common::value::*; +use crate::common::value_type::*; +use crate::runtime::env::*; + +use super::*; + +#[derive(Clone)] +pub struct TopKSamplerAggregate { + is_unique: bool, +} + +impl TopKSamplerAggregate { + pub fn top() -> Self { + Self { is_unique: false } + } + + pub fn unique() -> Self { + Self { is_unique: true } + } +} + +impl Into for TopKSamplerAggregate { + fn into(self) -> DynamicAggregate { + DynamicAggregate::Sampler(DynamicSampleAggregate::new(self)) + } +} + +impl SampleAggregate for TopKSamplerAggregate { + fn name(&self) -> String { + if self.is_unique { + "unique".to_string() + } else { + "top".to_string() + } + } + + fn param_types(&self) -> Vec { + if self.is_unique { + vec![] + } else { + vec![ParamType::Optional(ValueType::USize)] + } + } + + fn instantiate(&self, params: Vec, _: bool, _: Vec, _: Vec) -> DynamicSampler { + if self.is_unique { + TopKSampler { k: 1 }.into() + } else { + TopKSampler { + k: params.get(0).map(|v| v.as_usize()).unwrap_or(1), + } + .into() + } + } +} + +#[derive(Clone)] +pub struct 
TopKSampler { + k: usize, +} + +impl TopKSampler { + pub fn new(k: usize) -> Self { + Self { k } + } +} + +impl Into for TopKSampler { + fn into(self) -> DynamicSampler { + DynamicSampler::new(self) + } +} + +impl Sampler for TopKSampler { + fn sampler_type(&self) -> SamplerType { + SamplerType::WeightOnly + } + + fn sample_weight_only(&self, _: &RuntimeEnvironment, elems: Vec) -> Vec { + aggregate_top_k_helper(elems.len(), self.k, |i| elems[i]) + } +} + +pub fn aggregate_top_k_helper(num_elements: usize, k: usize, weight_fn: F) -> Vec +where + F: Fn(usize) -> f64, +{ + #[derive(Clone, Debug)] + struct Element { + id: usize, + weight: f64, + } + + impl std::cmp::PartialEq for Element { + fn eq(&self, other: &Self) -> bool { + self.id == other.id + } + } + + impl std::cmp::Eq for Element {} + + impl std::cmp::PartialOrd for Element { + fn partial_cmp(&self, other: &Self) -> Option { + other.weight.partial_cmp(&self.weight) + } + } + + impl std::cmp::Ord for Element { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + if let Some(ord) = other.weight.partial_cmp(&self.weight) { + ord + } else { + other.id.cmp(&self.id) + } + } + } + + // Create a min-heap + let mut heap = BinaryHeap::new(); + + // First insert k elements into the heap + let size = k.min(num_elements); + for id in 0..size { + let elem = Element { + id, + weight: weight_fn(id), + }; + heap.push(elem); + } + + // Then iterate through all other elements + if heap.len() > 0 { + for id in size..num_elements { + let elem = Element { + id, + weight: weight_fn(id), + }; + let min_elem_in_heap = heap.peek().unwrap(); + if &elem < min_elem_in_heap { + heap.pop(); + heap.push(elem); + } + } + } + + // Return the list of ids in the heap + heap.into_iter().map(|elem| elem.id).collect() +} + +pub fn unweighted_aggregate_top_k_helper(elements: Vec, k: usize) -> Vec { + if elements.len() <= k { + elements + } else { + elements.into_iter().take(k).collect() + } +} diff --git 
a/core/src/common/foreign_aggregates/uniform.rs b/core/src/common/foreign_aggregates/uniform.rs new file mode 100644 index 0000000..13772eb --- /dev/null +++ b/core/src/common/foreign_aggregates/uniform.rs @@ -0,0 +1,74 @@ +use std::collections::*; + +use crate::common::value::*; +use crate::common::value_type::*; +use crate::runtime::env::*; + +use super::*; + +#[derive(Clone)] +pub struct UniformAggregate; + +impl UniformAggregate { + pub fn new() -> Self { + Self + } +} + +impl Into for UniformAggregate { + fn into(self) -> DynamicAggregate { + DynamicAggregate::Sampler(DynamicSampleAggregate::new(self)) + } +} + +impl SampleAggregate for UniformAggregate { + fn name(&self) -> String { + "uniform".to_string() + } + + fn param_types(&self) -> Vec { + vec![ParamType::Mandatory(ValueType::USize)] + } + + fn instantiate( + &self, + params: Vec, + _has_exclamation_mark: bool, + _arg_types: Vec, + _input_types: Vec, + ) -> DynamicSampler { + UniformSampler { + k: params[0].as_usize(), + } + .into() + } +} + +#[derive(Clone)] +pub struct UniformSampler { + k: usize, +} + +impl Into for UniformSampler { + fn into(self) -> DynamicSampler { + DynamicSampler::new(self) + } +} + +impl Sampler for UniformSampler { + fn sampler_type(&self) -> SamplerType { + SamplerType::LengthOnly + } + + fn sample_length_only(&self, rt: &RuntimeEnvironment, len: usize) -> Vec { + if len <= self.k { + (0..len).collect() + } else { + (0..self.k) + .map(|_| rt.random.random_usize(len)) + .collect::>() + .into_iter() + .collect() + } + } +} diff --git a/core/src/common/foreign_aggregates/weighted_sum_avg.rs b/core/src/common/foreign_aggregates/weighted_sum_avg.rs new file mode 100644 index 0000000..63c91c0 --- /dev/null +++ b/core/src/common/foreign_aggregates/weighted_sum_avg.rs @@ -0,0 +1,201 @@ +use itertools::Itertools; +use std::convert::*; + +use crate::common::tuple::*; +use crate::common::type_family::*; +use crate::common::value::*; +use crate::common::value_type::*; +use 
crate::runtime::dynamic::*; +use crate::runtime::env::*; +use crate::runtime::provenance::*; +use crate::utils::*; + +use super::*; + +#[derive(Clone)] +pub struct WeightedSumAvgAggregate { + is_sum: bool, +} + +impl WeightedSumAvgAggregate { + pub fn weighted_sum() -> Self { + Self { is_sum: true } + } + + pub fn weighted_avg() -> Self { + Self { is_sum: false } + } +} + +impl Into for WeightedSumAvgAggregate { + fn into(self) -> DynamicAggregate { + DynamicAggregate::WeightedSumAvg(self) + } +} + +impl Aggregate for WeightedSumAvgAggregate { + type Aggregator = WeightedSumAvgAggregator; + + fn name(&self) -> String { + if self.is_sum { + "weighted_sum".to_string() + } else { + "weighted_avg".to_string() + } + } + + fn aggregate_type(&self) -> AggregateType { + AggregateType { + generics: vec![ + ("A".to_string(), GenericTypeFamily::possibly_empty_tuple()), + ("T".to_string(), GenericTypeFamily::type_family(TypeFamily::Float)), + ] + .into_iter() + .collect(), + param_types: vec![], + arg_type: BindingTypes::generic("A"), + input_type: BindingTypes::generic("T"), + output_type: BindingTypes::generic("T"), + allow_exclamation_mark: true, + } + } + + fn instantiate( + &self, + _params: Vec, + has_exclamation_mark: bool, + arg_types: Vec, + input_types: Vec, + ) -> Self::Aggregator

{ + assert!( + input_types.len() == 1, + "sum/prod aggregate should take in argument of only size 1" + ); + WeightedSumAvgAggregator { + is_sum: self.is_sum, + non_multi_world: has_exclamation_mark, + num_args: arg_types.len(), + value_type: input_types[0].clone(), + } + } +} + +#[derive(Clone)] +pub struct WeightedSumAvgAggregator { + is_sum: bool, + non_multi_world: bool, + num_args: usize, + value_type: ValueType, +} + +impl WeightedSumAvgAggregator { + pub fn weighted_sum(non_multi_world: bool) -> Self + where + ValueType: FromType, + { + Self { + is_sum: true, + non_multi_world, + num_args: 0, + value_type: ValueType::from_type(), + } + } + + pub fn weighted_avg(non_multi_world: bool) -> Self + where + ValueType: FromType, + { + Self { + is_sum: false, + non_multi_world, + num_args: 0, + value_type: ValueType::from_type(), + } + } +} + +impl WeightedSumAvgAggregator { + pub fn perform_avg<'a, I: Iterator, F: Float + Into>(&self, i: I) -> Tuple + where + Tuple: AsTuple, + { + let (sum, sum_of_weight): (F, F) = if self.num_args > 0 { + i.fold((F::zero(), F::zero()), |(sum, sum_of_weight), (weight, tuple)| { + let mult = F::from_f64(weight) * AsTuple::::as_tuple(&tuple[self.num_args]); + (sum + mult, sum_of_weight + F::from_f64(weight)) + }) + } else { + i.fold((F::zero(), F::zero()), |(sum, sum_of_weight), (weight, tuple)| { + let mult = F::from_f64(weight) * AsTuple::::as_tuple(tuple); + (sum + mult, sum_of_weight + F::from_f64(weight)) + }) + }; + (sum / sum_of_weight).into() + } + + pub fn perform_sum<'a, I: Iterator, F: Float + Into>(&self, i: I) -> Tuple + where + Tuple: AsTuple, + { + let sum: F = if self.num_args > 0 { + i.fold(F::zero(), |sum, (weight, tuple)| { + sum + F::from_f64(weight) * AsTuple::::as_tuple(&tuple[self.num_args]) + }) + } else { + i.fold(F::zero(), |sum, (weight, tuple)| { + sum + F::from_f64(weight) * AsTuple::::as_tuple(tuple) + }) + }; + sum.into() + } + + pub fn perform_sum_avg_typed<'a, I: Iterator, F: Float + Into>(&self, i: I) 
-> Tuple + where + Tuple: AsTuple, + { + if self.is_sum { + self.perform_sum(i) + } else { + self.perform_avg(i) + } + } + + pub fn perform_sum_avg<'a, I: Iterator>(&self, i: I) -> Tuple { + match self.value_type { + ValueType::F32 => self.perform_sum_avg_typed::<_, f32>(i), + ValueType::F64 => self.perform_sum_avg_typed::<_, f64>(i), + _ => unreachable!("type checking should have confirmed that the value type is f32 or f64"), + } + } +} + +impl Aggregator

for WeightedSumAvgAggregator { + default fn aggregate(&self, prov: &P, _env: &RuntimeEnvironment, batch: DynamicElements

) -> DynamicElements

{ + if self.non_multi_world { + let res = self.perform_sum_avg(batch.iter().map(|e| (prov.weight(&e.tag), &e.tuple))); + vec![DynamicElement::new(res, prov.one())] + } else { + let mut result = vec![]; + for chosen_set in (0..batch.len()).powerset() { + let res = self.perform_sum_avg( + chosen_set + .iter() + .map(|i| (prov.weight(&batch[*i].tag), &batch[*i].tuple)), + ); + let maybe_tag = batch.iter().enumerate().fold(Some(prov.one()), |maybe_acc, (i, elem)| { + maybe_acc.and_then(|acc| { + if chosen_set.contains(&i) { + Some(prov.mult(&acc, &elem.tag)) + } else { + prov.negate(&elem.tag).map(|neg_tag| prov.mult(&acc, &neg_tag)) + } + }) + }); + if let Some(tag) = maybe_tag { + result.push(DynamicElement::new(res, tag)); + } + } + result + } + } +} diff --git a/core/src/common/foreign_functions/string_replace.rs b/core/src/common/foreign_functions/string_replace.rs index 5b61809..0d396b6 100644 --- a/core/src/common/foreign_functions/string_replace.rs +++ b/core/src/common/foreign_functions/string_replace.rs @@ -29,9 +29,7 @@ impl ForeignFunction for StringReplace { fn execute(&self, args: Vec) -> Option { assert_eq!(args.len(), 3); match (&args[0], &args[1], &args[2]) { - (Value::String(s), Value::String(pat), Value::String(replace)) => { - Some(Value::String(s.replace(pat, replace))) - }, + (Value::String(s), Value::String(pat), Value::String(replace)) => Some(Value::String(s.replace(pat, replace))), _ => panic!("Invalid argument, expected string"), } } diff --git a/core/src/common/foreign_tensor/external_tensor.rs b/core/src/common/foreign_tensor/external_tensor.rs index 12f2bc6..3c4bed6 100644 --- a/core/src/common/foreign_tensor/external_tensor.rs +++ b/core/src/common/foreign_tensor/external_tensor.rs @@ -1,5 +1,5 @@ -use std::any::Any; use serde::*; +use std::any::Any; use super::*; @@ -17,9 +17,7 @@ pub struct DynamicExternalTensor { impl DynamicExternalTensor { pub fn new(t: T) -> Self { - Self { - tensor: Box::new(t), - } + Self { tensor: Box::new(t) } } 
pub fn internal(&self) -> &dyn ExternalTensor { @@ -34,7 +32,7 @@ impl DynamicExternalTensor { impl Clone for DynamicExternalTensor { fn clone(&self) -> Self { Self { - tensor: dyn_clone::clone_box(&*self.tensor) + tensor: dyn_clone::clone_box(&*self.tensor), } } } diff --git a/core/src/common/foreign_tensor/registry.rs b/core/src/common/foreign_tensor/registry.rs index 5357687..81d4713 100644 --- a/core/src/common/foreign_tensor/registry.rs +++ b/core/src/common/foreign_tensor/registry.rs @@ -18,9 +18,7 @@ pub struct DynamicTensorRegistry { impl DynamicTensorRegistry { pub fn new() -> Self { - Self { - maybe_registry: None, - } + Self { maybe_registry: None } } pub fn set(&mut self, t: T) { @@ -36,7 +34,10 @@ impl DynamicTensorRegistry { } pub fn get(&self, int_tensor: &InternalTensorSymbol) -> Option { - self.maybe_registry.as_ref().and_then(|registry| registry.get(int_tensor).cloned()) + self + .maybe_registry + .as_ref() + .and_then(|registry| registry.get(int_tensor).cloned()) } pub fn eval(&self, value: &TensorValue) -> Option { diff --git a/core/src/common/generic_tuple.rs b/core/src/common/generic_tuple.rs index aec6563..2977a3b 100644 --- a/core/src/common/generic_tuple.rs +++ b/core/src/common/generic_tuple.rs @@ -37,6 +37,13 @@ impl GenericTuple { Self::Tuple(_) => None, } } + + pub fn drain(self, i: usize) -> Option> { + match self { + Self::Value(_) => None, + Self::Tuple(vs) => vs.into_vec().into_iter().nth(i), + } + } } impl std::ops::Index for GenericTuple { @@ -50,6 +57,61 @@ impl std::ops::Index for GenericTuple { } } +impl std::ops::Index> for GenericTuple { + type Output = [Self]; + + fn index(&self, range: std::ops::Range) -> &Self::Output { + match self { + Self::Tuple(t) => &t[range], + _ => panic!("Cannot access tuple value with `{:?}`", range), + } + } +} + +impl std::ops::Index> for GenericTuple { + type Output = [Self]; + + fn index(&self, range: std::ops::RangeTo) -> &Self::Output { + match self { + Self::Tuple(t) => &t[range], + _ => 
panic!("Cannot access tuple value with `{:?}`", range), + } + } +} + +impl std::ops::Index> for GenericTuple { + type Output = [Self]; + + fn index(&self, range: std::ops::RangeFrom) -> &Self::Output { + match self { + Self::Tuple(t) => &t[range], + _ => panic!("Cannot access tuple value with `{:?}`", range), + } + } +} + +impl std::ops::Index> for GenericTuple { + type Output = [Self]; + + fn index(&self, range: std::ops::RangeInclusive) -> &Self::Output { + match self { + Self::Tuple(t) => &t[range], + _ => panic!("Cannot access tuple value with `{:?}`", range), + } + } +} + +impl std::ops::Index for GenericTuple { + type Output = [Self]; + + fn index(&self, range: std::ops::RangeFull) -> &Self::Output { + match self { + Self::Tuple(t) => &t[range], + _ => panic!("Cannot access tuple value with `{:?}`", range), + } + } +} + impl std::ops::Index<&TupleAccessor> for GenericTuple { type Output = GenericTuple; diff --git a/core/src/common/mod.rs b/core/src/common/mod.rs index 83a4a04..85df7c8 100644 --- a/core/src/common/mod.rs +++ b/core/src/common/mod.rs @@ -8,6 +8,8 @@ pub mod duration; pub mod element; pub mod entity; pub mod expr; +pub mod foreign_aggregate; +pub mod foreign_aggregates; pub mod foreign_function; pub mod foreign_functions; pub mod foreign_predicate; diff --git a/core/src/common/tuple.rs b/core/src/common/tuple.rs index 22354a2..5726ce1 100644 --- a/core/src/common/tuple.rs +++ b/core/src/common/tuple.rs @@ -11,6 +11,10 @@ impl Tuple { TupleType::type_of(self) } + pub fn tuple>(i: I) -> Self { + Self::Tuple(i.collect()) + } + pub fn arity(&self) -> usize { match self { Self::Tuple(ts) => ts.len(), diff --git a/core/src/common/tuples.rs b/core/src/common/tuples.rs index 4e8da88..3a8a1e3 100644 --- a/core/src/common/tuples.rs +++ b/core/src/common/tuples.rs @@ -1,94 +1,102 @@ use super::tuple::*; -pub trait Tuples { - fn minimum(self) -> Vec; +pub trait Tuples<'a> { + fn sort(self, num_args: usize) -> Vec<&'a Tuple>; - fn arg_minimum(self) -> Vec; + 
fn minimum(self) -> Vec<&'a Tuple>; - fn maximum(self) -> Vec; + fn arg_minimum(self, num_args: usize) -> Vec<&'a Tuple>; - fn arg_maximum(self) -> Vec; + fn maximum(self) -> Vec<&'a Tuple>; + + fn arg_maximum(self, num_args: usize) -> Vec<&'a Tuple>; } -impl<'a, I> Tuples for I +impl<'a, I> Tuples<'a> for I where I: Iterator, { - fn minimum(self) -> Vec { + fn sort(self, num_args: usize) -> Vec<&'a Tuple> { + let mut collected = self.collect::>(); + collected.sort_by_key(|e| &e[num_args..]); + collected + } + + fn minimum(self) -> Vec<&'a Tuple> { let mut result = vec![]; let mut min_value = None; for v in self { - if let Some(m) = &min_value { + if let Some(m) = min_value { if v == m { - result.push(v.clone()); + result.push(v); } else if v < m { - min_value = Some(v.clone()); + min_value = Some(v); result.clear(); - result.push(v.clone()); + result.push(v); } } else { - min_value = Some(v.clone()); - result.push(v.clone()); + min_value = Some(v); + result.push(v); } } return result; } - fn arg_minimum(self) -> Vec { + fn arg_minimum(self, num_args: usize) -> Vec<&'a Tuple> { let mut result = vec![]; - let mut min_value = None; + let mut min_value: Option<&[Tuple]> = None; for v in self { if let Some(m) = &min_value { - if &v[1] == m { - result.push(v.clone()); - } else if &v[1] < m { - min_value = Some(v[1].clone()); + if &v[num_args..] == &m[..] { + result.push(v); + } else if &v[num_args..] 
< m { + min_value = Some(&v[num_args..]); result.clear(); - result.push(v.clone()); + result.push(v); } } else { - min_value = Some(v[1].clone()); - result.push(v.clone()); + min_value = Some(&v[num_args..]); + result.push(v); } } return result; } - fn maximum(self) -> Vec { + fn maximum(self) -> Vec<&'a Tuple> { let mut result = vec![]; let mut min_value = None; for v in self { - if let Some(m) = &min_value { + if let Some(m) = min_value { if v == m { - result.push(v.clone()); + result.push(v); } else if v > m { - min_value = Some(v.clone()); + min_value = Some(v); result.clear(); - result.push(v.clone()); + result.push(v); } } else { - min_value = Some(v.clone()); - result.push(v.clone()); + min_value = Some(v); + result.push(v); } } return result; } - fn arg_maximum(self) -> Vec { + fn arg_maximum(self, num_args: usize) -> Vec<&'a Tuple> { let mut result = vec![]; - let mut min_value = None; + let mut max_value: Option<&[Tuple]> = None; for v in self { - if let Some(m) = &min_value { - if &v[1] == m { - result.push(v.clone()); - } else if &v[1] > m { - min_value = Some(v[1].clone()); + if let Some(m) = &max_value { + if &v[num_args..] == &m[..] { + result.push(v); + } else if &v[num_args..] > &m[..] 
{ + max_value = Some(&v[num_args..]); result.clear(); - result.push(v.clone()); + result.push(v); } } else { - min_value = Some(v[1].clone()); - result.push(v.clone()); + max_value = Some(&v[num_args..]); + result.push(v); } } return result; diff --git a/core/src/common/value.rs b/core/src/common/value.rs index e4315f8..72eaaeb 100644 --- a/core/src/common/value.rs +++ b/core/src/common/value.rs @@ -3,8 +3,8 @@ use std::convert::*; use chrono::{DateTime, Utc}; use chronoutil::RelativeDuration; -use super::foreign_tensor::*; use super::duration::*; +use super::foreign_tensor::*; use super::value_type::*; #[derive(Debug, Clone, PartialEq, PartialOrd)] diff --git a/core/src/common/value_type.rs b/core/src/common/value_type.rs index af08aa1..9d391f2 100644 --- a/core/src/common/value_type.rs +++ b/core/src/common/value_type.rs @@ -285,6 +285,74 @@ impl ValueType { } } + pub fn avg<'a, I: Iterator>(&self, i: I) -> Tuple { + match self { + Self::I8 => { + let (sum, num) = i.fold((0, 0), |(a, n), v| (a + v.as_i8(), n + 1)); + (sum / num).into() + } + Self::I16 => { + let (sum, num) = i.fold((0, 0), |(a, n), v| (a + v.as_i16(), n + 1)); + (sum / num).into() + } + Self::I32 => { + let (sum, num) = i.fold((0, 0), |(a, n), v| (a + v.as_i32(), n + 1)); + (sum / num).into() + } + Self::I64 => { + let (sum, num) = i.fold((0, 0), |(a, n), v| (a + v.as_i64(), n + 1)); + (sum / num).into() + } + Self::I128 => { + let (sum, num) = i.fold((0, 0), |(a, n), v| (a + v.as_i128(), n + 1)); + (sum / num).into() + } + Self::ISize => { + let (sum, num) = i.fold((0, 0), |(a, n), v| (a + v.as_isize(), n + 1)); + (sum / num).into() + } + + // Unsigned + Self::U8 => { + let (sum, num) = i.fold((0, 0), |(a, n), v| (a + v.as_u8(), n + 1)); + (sum / num).into() + } + Self::U16 => { + let (sum, num) = i.fold((0, 0), |(a, n), v| (a + v.as_u16(), n + 1)); + (sum / num).into() + } + Self::U32 => { + let (sum, num) = i.fold((0, 0), |(a, n), v| (a + v.as_u32(), n + 1)); + (sum / num).into() + } + 
Self::U64 => { + let (sum, num) = i.fold((0, 0), |(a, n), v| (a + v.as_u64(), n + 1)); + (sum / num).into() + } + Self::U128 => { + let (sum, num) = i.fold((0, 0), |(a, n), v| (a + v.as_u128(), n + 1)); + (sum / num).into() + } + Self::USize => { + let (sum, num) = i.fold((0, 0), |(a, n), v| (a + v.as_usize(), n + 1)); + (sum / num).into() + } + + // Floating point + Self::F32 => { + let (sum, num) = i.fold((0.0, 0), |(a, n), v| (a + v.as_f32(), n + 1)); + (sum / num as f32).into() + } + Self::F64 => { + let (sum, num) = i.fold((0.0, 0), |(a, n), v| (a + v.as_f64(), n + 1)); + (sum / num as f64).into() + } + + // Others + _ => panic!("Cannot perform sum on type `{}`", self), + } + } + pub fn prod<'a, I: Iterator>(&self, i: I) -> Tuple { match self { Self::I8 => i.fold(1, |a, v| a * v.as_i8()).into(), diff --git a/core/src/compiler/back/ast.rs b/core/src/compiler/back/ast.rs index ac1a43e..da497f5 100644 --- a/core/src/compiler/back/ast.rs +++ b/core/src/compiler/back/ast.rs @@ -1,12 +1,12 @@ use std::collections::*; -use crate::common::adt_variant_registry::ADTVariantRegistry; -use crate::common::aggregate_op::AggregateOp; +use crate::common::adt_variant_registry::*; +use crate::common::foreign_aggregate::*; use crate::common::foreign_function::*; use crate::common::foreign_predicate::*; -use crate::common::input_tag::DynamicInputTag; -use crate::common::output_option::OutputOption; -use crate::common::value_type::ValueType; +use crate::common::input_tag::*; +use crate::common::output_option::*; +use crate::common::value_type::*; use crate::compiler::front; @@ -25,6 +25,7 @@ pub struct Program { pub rules: Vec, pub function_registry: ForeignFunctionRegistry, pub predicate_registry: ForeignPredicateRegistry, + pub aggregate_registry: AggregateRegistry, pub adt_variant_registry: ADTVariantRegistry, } @@ -709,19 +710,36 @@ pub struct UnaryConstraint { #[derive(Clone, Debug, PartialEq)] pub struct Reduce { - pub op: AggregateOp, + // Aggregator + pub aggregator: 
String, + pub params: Vec, + pub has_exclamation_mark: bool, + + // Concretized types of reduce arguments + pub left_var_types: Vec, + pub arg_var_types: Vec, + pub input_var_types: Vec, + + // Variables pub left_vars: Vec, pub group_by_vars: Vec, pub other_group_by_vars: Vec, pub arg_vars: Vec, pub to_aggregate_vars: Vec, + + // Bodies of reduce pub body_formula: Atom, pub group_by_formula: Option, } impl Reduce { pub fn new( - op: AggregateOp, + aggregator: String, + params: Vec, + has_exclamation_mark: bool, + left_var_types: Vec, + arg_var_types: Vec, + input_var_types: Vec, left_vars: Vec, group_by_vars: Vec, other_group_by_vars: Vec, @@ -731,7 +749,12 @@ impl Reduce { group_by_formula: Option, ) -> Self { Self { - op, + aggregator, + params, + has_exclamation_mark, + left_var_types, + arg_var_types, + input_var_types, left_vars, group_by_vars, other_group_by_vars, diff --git a/core/src/compiler/back/attr.rs b/core/src/compiler/back/attr.rs index abfa306..518d2cd 100644 --- a/core/src/compiler/back/attr.rs +++ b/core/src/compiler/back/attr.rs @@ -1,4 +1,3 @@ -use crate::common::aggregate_op::AggregateOp; use crate::common::input_file::InputFile; #[derive(Clone, Debug, PartialEq)] @@ -97,7 +96,7 @@ pub enum Attribute { impl Attribute { pub fn aggregate_body( - aggregator: AggregateOp, + aggregator: String, num_group_by_vars: usize, num_arg_vars: usize, num_key_vars: usize, @@ -124,14 +123,14 @@ impl Attribute { #[derive(Clone, Debug, PartialEq)] pub struct AggregateBodyAttribute { - pub aggregator: AggregateOp, + pub aggregator: String, pub num_group_by_vars: usize, pub num_arg_vars: usize, pub num_key_vars: usize, } impl AggregateBodyAttribute { - pub fn new(aggregator: AggregateOp, num_group_by_vars: usize, num_arg_vars: usize, num_key_vars: usize) -> Self { + pub fn new(aggregator: String, num_group_by_vars: usize, num_arg_vars: usize, num_key_vars: usize) -> Self { Self { aggregator, num_group_by_vars, diff --git a/core/src/compiler/back/b2r.rs 
b/core/src/compiler/back/b2r.rs index 166c31f..f1b263c 100644 --- a/core/src/compiler/back/b2r.rs +++ b/core/src/compiler/back/b2r.rs @@ -903,39 +903,21 @@ impl Program { r: &Reduce, prop: DataflowProp, ) -> ram::Dataflow { - // Handle different versions of reduce... + // TODO: Handle output of aggregators let lt = VariableTuple::from_vars(r.left_vars.iter().cloned(), true); + + // Handle different versions of reduce... let (var_tuple, has_group_by) = if !r.group_by_vars.is_empty() { let gbt = VariableTuple::from_vars(r.group_by_vars.iter().cloned(), true); let vt = if r.group_by_formula.is_some() { let ogbt = VariableTuple::from_vars(r.other_group_by_vars.iter().cloned(), true); - if !r.arg_vars.is_empty() { - let avt = VariableTuple::from_vars(r.arg_vars.iter().cloned(), true); - VariableTuple::from((gbt, ogbt, (avt, lt))) - } else { - // With group by, no arg - VariableTuple::from((gbt, ogbt, lt)) - } + VariableTuple::from((gbt, ogbt, lt)) } else { - if !r.arg_vars.is_empty() { - let avt = VariableTuple::from_vars(r.arg_vars.iter().cloned(), true); - VariableTuple::from((gbt, (avt, lt))) - } else { - // With group by, no arg - VariableTuple::from((gbt, lt)) - } + VariableTuple::from((gbt, lt)) }; (vt, true) } else { - if !r.arg_vars.is_empty() { - // No group by, with arg - let avt = VariableTuple::from_vars(r.arg_vars.iter().cloned(), true); - let var_tuple = VariableTuple::from((avt, lt)); - (var_tuple, false) - } else { - // No group_by, no arg - (lt, false) - } + (lt, false) }; // Get the type of group by... 
@@ -950,7 +932,15 @@ impl Program { }; // Construct the reduce and the dataflow - let agg = ram::Dataflow::reduce(r.op.clone(), r.body_formula.predicate.clone(), group_by); + let agg = ram::Dataflow::reduce( + r.aggregator.clone(), + r.params.clone(), + r.has_exclamation_mark, + r.arg_var_types.clone(), + r.input_var_types.clone(), + r.body_formula.predicate.clone(), + group_by, + ); let dataflow = ram::Dataflow::project(agg, var_tuple.projection(goal)); // Check if we need to store into temporary variable @@ -1094,7 +1084,6 @@ impl Program { // For an aggregate sub-relation let head_args = head.variable_args().into_iter().cloned().collect::>(); let num_group_by = agg_attr.num_group_by_vars; - let num_args = agg_attr.num_arg_vars; // Compute the items for aggregation let mut elems = vec![]; @@ -1102,16 +1091,8 @@ impl Program { let tuple_vars = VariableTuple::from_vars(head_args[..num_group_by].iter().cloned(), true); elems.push(tuple_vars); } - let start = num_group_by + num_args; - let to_agg_tuple = VariableTuple::from_vars(head_args[start..].iter().cloned(), true); - if num_args > 0 { - let start = num_group_by; - let end = num_group_by + num_args; - let tuple_args = VariableTuple::from_vars(head_args[start..end].iter().cloned(), true); - elems.push((tuple_args, to_agg_tuple).into()); - } else { - elems.push(to_agg_tuple); - } + let to_agg_tuple = VariableTuple::from_vars(head_args[num_group_by..].iter().cloned(), true); + elems.push(to_agg_tuple); // Combine them into a var tuple let var_tuple = if elems.len() == 1 { diff --git a/core/src/compiler/back/compile.rs b/core/src/compiler/back/compile.rs index 3a0e36b..a3d8225 100644 --- a/core/src/compiler/back/compile.rs +++ b/core/src/compiler/back/compile.rs @@ -106,6 +106,7 @@ impl Program { strata: ram_strata, function_registry: self.function_registry.clone(), predicate_registry: self.predicate_registry.clone(), + aggregate_registry: self.aggregate_registry.clone(), adt_variant_registry: 
self.adt_variant_registry.clone(), relation_to_stratum, }) diff --git a/core/src/compiler/back/pretty.rs b/core/src/compiler/back/pretty.rs index 95340f5..80f0890 100644 --- a/core/src/compiler/back/pretty.rs +++ b/core/src/compiler/back/pretty.rs @@ -301,7 +301,7 @@ impl Display for Reduce { fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { if self.left_vars.len() > 1 { f.write_fmt(format_args!( - "({}) = ", + "({}) := ", self .left_vars .iter() @@ -310,7 +310,15 @@ impl Display for Reduce { .join(", ") ))?; } else if self.left_vars.len() == 1 { - f.write_fmt(format_args!("{} = ", self.left_vars[0].name))?; + f.write_fmt(format_args!("{} := ", self.left_vars[0].name))?; + } + f.write_str(&self.aggregator)?; + if self.params.len() > 0 { + let params = self.params.iter().map(|p| p.to_string()).collect::>().join(", "); + f.write_str(¶ms)?; + } + if self.has_exclamation_mark { + f.write_str("!")?; } let group_by_vars = if self.group_by_vars.is_empty() { String::new() @@ -352,8 +360,8 @@ impl Display for Reduce { .collect::>() .join(", "); f.write_fmt(format_args!( - "{}{}{}({}: {}{})", - self.op, group_by_vars, arg_vars, to_agg_vars, self.body_formula, group_by_atom + "{}{}({}: {}{})", + group_by_vars, arg_vars, to_agg_vars, self.body_formula, group_by_atom )) } } diff --git a/core/src/compiler/front/analysis.rs b/core/src/compiler/front/analysis.rs index 6832433..81dd2f4 100644 --- a/core/src/compiler/front/analysis.rs +++ b/core/src/compiler/front/analysis.rs @@ -1,3 +1,4 @@ +use crate::common::foreign_aggregate::AggregateRegistry; use crate::common::foreign_function::ForeignFunctionRegistry; use crate::common::foreign_predicate::ForeignPredicateRegistry; @@ -24,7 +25,11 @@ pub struct Analysis { impl Analysis { /// Create a new front IR analysis object - pub fn new(function_registry: &ForeignFunctionRegistry, predicate_registry: &ForeignPredicateRegistry) -> Self { + pub fn new( + function_registry: &ForeignFunctionRegistry, + predicate_registry: 
&ForeignPredicateRegistry, + aggregate_registry: &AggregateRegistry, + ) -> Self { Self { invalid_constant: InvalidConstantAnalyzer::new(), invalid_wildcard: InvalidWildcardAnalyzer::new(), @@ -36,7 +41,7 @@ impl Analysis { constant_decl_analysis: ConstantDeclAnalysis::new(), adt_analysis: AlgebraicDataTypeAnalysis::new(), head_relation_analysis: HeadRelationAnalysis::new(predicate_registry), - type_inference: TypeInference::new(function_registry, predicate_registry), + type_inference: TypeInference::new(function_registry, predicate_registry, aggregate_registry), boundness_analysis: BoundnessAnalysis::new(predicate_registry), demand_attr_analysis: DemandAttributeAnalysis::new(), } diff --git a/core/src/compiler/front/analyzers/aggregation.rs b/core/src/compiler/front/analyzers/aggregation.rs index ee3d8e5..68f18f4 100644 --- a/core/src/compiler/front/analyzers/aggregation.rs +++ b/core/src/compiler/front/analyzers/aggregation.rs @@ -17,9 +17,8 @@ impl AggregationAnalysis { impl NodeVisitor for AggregationAnalysis { fn visit(&mut self, reduce: &Reduce) { // Check max/min arg - match reduce.operator().internal() { - _ReduceOp::Max | _ReduceOp::Min => {} - _ReduceOp::Forall => { + match reduce.operator().name().name().as_str() { + "forall" => { // Check the body of forall expression match reduce.body() { Formula::Implies(_) => {} @@ -28,40 +27,14 @@ impl NodeVisitor for AggregationAnalysis { }), } } - _ReduceOp::Unknown(a) => self.errors.push(AggregationAnalysisError::UnknownAggregator { - agg: a.clone(), - loc: reduce.location().clone(), - }), - _ => { - if !reduce.args().is_empty() { - self - .errors - .push(AggregationAnalysisError::NonMinMaxAggregationHasArgument { - op: reduce.operator().clone(), - }) - } - } - } - - // Check the binding variables - if reduce.bindings().is_empty() { - match reduce.operator().internal() { - _ReduceOp::Exists | _ReduceOp::Forall | _ReduceOp::Unknown(_) => {} - _ => self.errors.push(AggregationAnalysisError::EmptyBinding { - agg: 
reduce.operator().to_string(), - loc: reduce.location().clone(), - }), - } + _ => {} } } } #[derive(Debug, Clone)] pub enum AggregationAnalysisError { - NonMinMaxAggregationHasArgument { op: ReduceOp }, - UnknownAggregator { agg: String, loc: Loc }, ForallBodyNotImplies { loc: Loc }, - EmptyBinding { agg: String, loc: Loc }, } impl FrontCompileErrorTrait for AggregationAnalysisError { @@ -71,29 +44,12 @@ impl FrontCompileErrorTrait for AggregationAnalysisError { fn report(&self, src: &Sources) -> String { match self { - Self::NonMinMaxAggregationHasArgument { op } => { - format!( - "{} aggregation cannot have arguments\n{}", - op.to_string(), - op.location().report(src) - ) - } - Self::UnknownAggregator { agg, loc } => { - format!("unknown aggregator `{}`\n{}", agg, loc.report(src)) - } Self::ForallBodyNotImplies { loc } => { format!( "the body of forall aggregation must be an `implies` formula\n{}", loc.report(src) ) } - Self::EmptyBinding { agg, loc } => { - format!( - "the binding variables of `{}` aggregation cannot be empty\n{}", - agg, - loc.report(src), - ) - } } } } diff --git a/core/src/compiler/front/analyzers/boundness/context.rs b/core/src/compiler/front/analyzers/boundness/context.rs index 3afb8e4..490d294 100644 --- a/core/src/compiler/front/analyzers/boundness/context.rs +++ b/core/src/compiler/front/analyzers/boundness/context.rs @@ -55,7 +55,11 @@ pub struct DisjunctionContext { impl DisjunctionContext { pub fn from_formula(formula: &Formula) -> Self { let conjuncts: Vec = match formula { - Formula::Disjunction(d) => d.iter_args().map(|a| Self::from_formula(a).conjuncts).flatten().collect(), + Formula::Disjunction(d) => d + .iter_args() + .map(|a| Self::from_formula(a).conjuncts) + .flatten() + .collect(), Formula::Conjunction(c) => { let ctxs = c.iter_args().map(|a| Self::from_formula(a).conjuncts); let cp = ctxs.multi_cartesian_product(); @@ -193,6 +197,7 @@ impl ConjunctionContext { #[derive(Debug, Clone)] pub struct AggregationContext { + pub 
result_var_or_wildcards: Vec<(Loc, Option)>, pub result_vars: Vec, pub binding_vars: Vec, pub arg_vars: Vec, @@ -205,23 +210,23 @@ pub struct AggregationContext { } impl AggregationContext { - pub fn left_variable_names(&self) -> BTreeSet { + pub fn left_variable_names(&self) -> Vec { self.result_vars.iter().map(|v| v.name().to_string()).collect() } - pub fn binding_variable_names(&self) -> BTreeSet { + pub fn binding_variable_names(&self) -> Vec { self.binding_vars.iter().cloned().collect() } - pub fn argument_variable_names(&self) -> BTreeSet { + pub fn argument_variable_names(&self) -> Vec { self.arg_vars.iter().map(|n| n.name().to_string()).collect() } - pub fn group_by_head_variable_names(&self) -> BTreeSet { + pub fn group_by_head_variable_names(&self) -> Vec { if let Some((_, vars, _)) = &self.group_by { vars.iter().map(|n| n.name().to_string()).collect() } else { - BTreeSet::new() + Vec::new() } } @@ -247,6 +252,15 @@ impl AggregationContext { // Construct self Self { + result_var_or_wildcards: reduce + .iter_left() + .map(|vow| { + ( + vow.location().clone(), + vow.as_variable().map(|v| v.name().name().clone()), + ) + }) + .collect(), result_vars: reduce.iter_left_variables().cloned().collect(), binding_vars: reduce.iter_binding_names().map(|n| n.to_string()).collect(), arg_vars: reduce.args().clone(), diff --git a/core/src/compiler/front/analyzers/constant_decl.rs b/core/src/compiler/front/analyzers/constant_decl.rs index d37dc70..2031294 100644 --- a/core/src/compiler/front/analyzers/constant_decl.rs +++ b/core/src/compiler/front/analyzers/constant_decl.rs @@ -154,7 +154,8 @@ impl NodeVisitor for ConstantDeclAnalysis { let entity = ca.value(); // Then we make sure that the entity is indeed a constant - if let Some(var_loc) = entity.get_first_non_constant_location(&|v| self.variables.contains_key(v.variable_name())) { + if let Some(var_loc) = entity.get_first_non_constant_location(&|v| self.variables.contains_key(v.variable_name())) + { 
self.errors.push(ConstantDeclError::EntityContainsNonConstant { const_decl_loc: ca.location().clone(), var_loc: var_loc.clone(), @@ -169,21 +170,15 @@ impl NodeVisitor for ConstantDeclAnalysis { // Process the entity into a set of entity facts and one final constant value let (entity_facts, constant) = - entity.to_facts_with_constant_variables(|v| - self - .variables - .get(v.name().name()) - .map(|(_, _, c)| c.clone()) - ); + entity.to_facts_with_constant_variables(|v| self.variables.get(v.name().name()).map(|(_, _, c)| c.clone())); // Extend the entity facts with the storage self.entity_facts.extend(entity_facts); // Store the variable - self.variables.insert( - ca.name().name().to_string(), - (ca.location().clone(), ty, constant), - ); + self + .variables + .insert(ca.name().name().to_string(), (ca.location().clone(), ty, constant)); } } } diff --git a/core/src/compiler/front/analyzers/invalid_constant.rs b/core/src/compiler/front/analyzers/invalid_constant.rs index 51712b6..a0e28fa 100644 --- a/core/src/compiler/front/analyzers/invalid_constant.rs +++ b/core/src/compiler/front/analyzers/invalid_constant.rs @@ -14,7 +14,7 @@ impl InvalidConstantAnalyzer { impl NodeVisitor for InvalidConstantAnalyzer { fn visit(&mut self, datetime: &DateTimeLiteral) { match crate::utils::parse_date_time_string(datetime.datetime()) { - Some(_) => {}, + Some(_) => {} None => { self.errors.push(InvalidConstantError::InvalidConstant { loc: datetime.location().clone(), @@ -28,7 +28,7 @@ impl NodeVisitor for InvalidConstantAnalyzer { impl NodeVisitor for InvalidConstantAnalyzer { fn visit(&mut self, duration: &DurationLiteral) { match crate::utils::parse_duration_string(duration.duration()) { - Some(_) => {}, + Some(_) => {} None => { self.errors.push(InvalidConstantError::InvalidConstant { loc: duration.location().clone(), diff --git a/core/src/compiler/front/analyzers/mod.rs b/core/src/compiler/front/analyzers/mod.rs index 72faf60..d19caf4 100644 --- 
a/core/src/compiler/front/analyzers/mod.rs +++ b/core/src/compiler/front/analyzers/mod.rs @@ -37,5 +37,4 @@ pub mod errors { pub use super::invalid_constant::InvalidConstantError; pub use super::invalid_wildcard::InvalidWildcardError; pub use super::output_files::OutputFilesError; - pub use super::type_inference::TypeInferenceError; } diff --git a/core/src/compiler/front/analyzers/type_inference/error.rs b/core/src/compiler/front/analyzers/type_inference/error.rs index ec46122..cb5607f 100644 --- a/core/src/compiler/front/analyzers/type_inference/error.rs +++ b/core/src/compiler/front/analyzers/type_inference/error.rs @@ -3,442 +3,229 @@ use crate::compiler::front::*; use super::*; -#[derive(Clone, Debug)] -pub enum TypeInferenceError { - UnknownRelation { - relation: String, - }, - DuplicateTypeDecl { - type_name: String, - source_decl_loc: NodeLocation, - duplicate_decl_loc: NodeLocation, - }, - DuplicateRelationTypeDecl { - predicate: String, - source_decl_loc: NodeLocation, - duplicate_decl_loc: NodeLocation, - }, - UnknownADTVariant { - predicate: String, - loc: NodeLocation, - }, - InvalidSubtype { - source_type: String, - source_type_loc: NodeLocation, - }, - UnknownCustomType { - type_name: String, - loc: NodeLocation, - }, - UnknownQueryRelationType { - predicate: String, - loc: NodeLocation, - }, - UnknownFunctionType { - function_name: String, - loc: NodeLocation, - }, - UnknownVariable { - variable: String, - loc: NodeLocation, - }, - ArityMismatch { - predicate: String, - expected: usize, - actual: usize, - source_loc: NodeLocation, - mismatch_loc: NodeLocation, - }, - FunctionArityMismatch { - function: String, - actual: usize, - loc: NodeLocation, - }, - ADTVariantArityMismatch { - variant: String, - expected: usize, - actual: usize, - loc: NodeLocation, - }, - EntityTupleArityMismatch { - predicate: String, - expected: usize, - actual: usize, - source_loc: NodeLocation, - }, - InvalidArgIndex { - predicate: String, - index: usize, - source_loc: 
NodeLocation, - access_loc: NodeLocation, - }, - InvalidForeignPredicateArgIndex { - predicate: String, - index: usize, - access_loc: NodeLocation, - }, - ConstantSetArityMismatch { - predicate: String, - decl_loc: NodeLocation, - mismatch_tuple_loc: NodeLocation, - }, - ConstantTypeMismatch { - expected: ValueType, - found: TypeSet, - }, - BadEnumValueKind { - found: &'static str, - loc: NodeLocation, - }, - NegativeEnumValue { - found: i64, - loc: NodeLocation, - }, - CannotUnifyTypes { - t1: TypeSet, - t2: TypeSet, - loc: Option, - }, - CannotUnifyForeignPredicateArgument { +impl FrontCompileErrorMessage { + pub fn unknown_relation(relation: String) -> Self { + Self::error().msg(format!("unknown relation `{relation}`")) + } + + pub fn duplicate_type_decl(type_name: String, source_decl_loc: Loc, duplicate_decl_loc: Loc) -> Self { + Self::error() + .msg(format!( + "duplicated type declaration found for `{type_name}`. It is originally defined here:" + )) + .src(source_decl_loc) + .msg("while we find a duplicated declaration here:") + .src(duplicate_decl_loc) + } + + pub fn duplicate_relation_type_decl(predicate: String, source_decl_loc: Loc, duplicate_decl_loc: Loc) -> Self { + Self::error() + .msg(format!( + "duplicated relation type declaration found for `{predicate}`. 
It is originally defined here:" + )) + .src(source_decl_loc) + .msg("while we find a duplicated declaration here:") + .src(duplicate_decl_loc) + } + + pub fn unknown_adt_variant(var: String, loc: Loc) -> Self { + Self::error() + .msg(format!("unknown algebraic data type variant `{var}`:")) + .src(loc) + } + + pub fn invalid_subtype(source_type: String, loc: Loc) -> Self { + Self::error() + .msg(format!("cannot create subtype from `{source_type}`")) + .src(loc) + } + + pub fn unknown_custom_type(type_name: String, loc: Loc) -> Self { + Self::error().msg(format!("unknown custom type `{type_name}`")).src(loc) + } + + pub fn unknown_query_relation_type(predicate: String, loc: Loc) -> Self { + Self::error() + .msg(format!("unknown relation `{predicate}` used in query")) + .src(loc) + } + + pub fn unknown_function_type(func_name: String, loc: Loc) -> Self { + Self::error().msg(format!("unknown function `{func_name}`")).src(loc) + } + + pub fn unknown_variable(var: String, loc: Loc) -> Self { + Self::error() + .msg(format!("unknown variable `{var}` in the rule")) + .src(loc) + } + + pub fn unknown_aggregate(agg: String, loc: Loc) -> Self { + Self::error().msg(format!("unknown aggregate `{agg}`")).src(loc) + } + + pub fn arity_mismatch(pred: String, expected: usize, actual: usize, mismatch_loc: Loc) -> Self { + Self::error() + .msg(format!( + "arity mismatch for relation `{pred}`. Expected {expected}, found {actual}:" + )) + .src(mismatch_loc) + } + + pub fn function_arity_mismatch(func: String, actual: usize, loc: Loc) -> Self { + Self::error() + .msg(format!( + "bad number of arguments for function `{func}`, found {actual}:" + )) + .src(loc) + } + + pub fn adt_variant_arity_mismatch(var: String, expected: usize, actual: usize, loc: Loc) -> Self { + Self::error() + .msg(format!( + "arity mismatch for algebraic data type variant `{var}`. 
Expected {expected}, found {actual}:" + )) + .src(loc) + } + + pub fn entity_tuple_arity_mismatch(pred: String, expected: usize, actual: usize, loc: Loc) -> Self { + Self::error() + .msg(format!( + "incorrect number of arguments in entity tuple for `{pred}`. Expected {expected}, found {actual}:" + )) + .src(loc) + } + + pub fn invalid_predicate_arg_index(pred: String, index: usize, source_loc: Loc, access_loc: Loc) -> Self { + Self::error() + .msg(format!( + "Unexpected {index}-th argument for relation `{pred}`. The relation type is inferred here:" + )) + .src(source_loc) + .msg("erroneous access happens here:") + .src(access_loc) + } + + pub fn invalid_foreign_predicate_arg_index(pred: String, index: usize, access_loc: Loc) -> Self { + Self::error() + .msg(format!( + "Unexpected {index}-th argument for foreign predicate `{pred}`:" + )) + .src(access_loc) + } + + pub fn constant_set_arity_mismatch(predicate: String, mismatch_loc: Loc) -> Self { + Self::error() + .msg(format!("arity mismatch in set for relation `{predicate}`:")) + .src(mismatch_loc) + } + + pub fn constant_type_mismatch(expected: ValueType, found: TypeSet) -> Self { + Self::error() + .msg(format!( + "type mismatch for constant. Expected `{expected}`, found `{found}`" + )) + .src(found.location().clone()) + } + + pub fn bad_enum_value_kind(found: &'static str, loc: Loc) -> Self { + Self::error() + .msg(format!("bad enum value. Expected unsigned integers, found `{found}`")) + .src(loc) + } + + pub fn negative_enum_value(found: i64, loc: Loc) -> Self { + Self::error() + .msg(format!( + "enum value `{found}` found to be negative. 
Expected unsigned integers" + )) + .src(loc) + } + + pub fn cannot_unify_foreign_predicate_arg( pred: String, i: usize, expected_ty: TypeSet, actual_ty: TypeSet, - loc: NodeLocation, - }, - NoMatchingTripletRule { - op1_ty: TypeSet, - op2_ty: TypeSet, - e_ty: TypeSet, - location: NodeLocation, - }, - CannotUnifyVariables { - v1: String, - t1: TypeSet, - v2: String, - t2: TypeSet, - loc: NodeLocation, - }, - CannotTypeCast { - t1: TypeSet, - t2: ValueType, - loc: NodeLocation, - }, - ConstraintNotBoolean { - ty: TypeSet, - loc: NodeLocation, - }, - InvalidReduceOutput { - op: String, - expected: usize, - found: usize, - loc: NodeLocation, - }, - InvalidReduceBindingVar { - op: String, - expected: usize, - found: usize, - loc: NodeLocation, - }, - InvalidUniqueNumParams { - num_output_vars: usize, - num_binding_vars: usize, - loc: NodeLocation, - }, - CannotRedefineForeignPredicate { - pred: String, - loc: NodeLocation, - }, - CannotQueryForeignPredicate { - pred: String, - loc: NodeLocation, - }, - Internal { - error_string: String, - }, -} + loc: Loc, + ) -> Self { + Self::error() + .msg(format!("cannot unify the type of {i}-th argument of foreign predicate `{pred}`, expected type `{expected_ty}`, found `{actual_ty}`:")) + .src(loc) + } -impl TypeInferenceError { - pub fn annotate_location(&mut self, new_location: &NodeLocation) { - match self { - Self::CannotUnifyTypes { loc, .. 
} => { - *loc = Some(new_location.clone()); - } - _ => {} - } + pub fn cannot_unify_variables(v1: String, t1: TypeSet, v2: String, t2: TypeSet, loc: Loc) -> Self { + Self::error() + .msg(format!( + "cannot unify variable types: `{v1}` has `{t1}` type, `{v2}` has `{t2}` type, but they should be unified:" + )) + .src(loc) } -} -impl FrontCompileErrorTrait for TypeInferenceError { - fn error_type(&self) -> FrontCompileErrorType { - FrontCompileErrorType::Error + pub fn no_matching_triplet_rule(op1_ty: TypeSet, op2_ty: TypeSet, e_ty: TypeSet, loc: Loc) -> Self { + Self::error() + .msg(format!("no matching rule found; two operands have type `{op1_ty}` and `{op2_ty}`, while the expression has type `{e_ty}`:")) + .src(loc) } - fn report(&self, src: &Sources) -> String { - match self { - Self::UnknownRelation { relation } => { - format!("unknown relation `{relation}`") - } - Self::DuplicateTypeDecl { - type_name, - source_decl_loc, - duplicate_decl_loc, - } => { - format!( - "duplicated type declaration found for `{}`. It is originally defined here:\n{}\nwhile we find a duplicated declaration here:\n{}", - type_name, source_decl_loc.report(src), duplicate_decl_loc.report(src), - ) - } - Self::DuplicateRelationTypeDecl { - predicate, - source_decl_loc, - duplicate_decl_loc, - } => { - format!( - "duplicated relation type declaration found for `{}`. 
It is originally defined here:\n{}\nwhile we find a duplicated declaration here:\n{}", - predicate, source_decl_loc.report(src), duplicate_decl_loc.report(src) - ) - } - Self::UnknownADTVariant { predicate, loc } => { - format!( - "unknown algebraic data type variant `{predicate}`:\n{}", - loc.report(src) - ) - } - Self::InvalidSubtype { - source_type, - source_type_loc, - } => { - format!( - "cannot create subtype from `{}`\n{}", - source_type, - source_type_loc.report(src) - ) - } - Self::UnknownCustomType { type_name, loc } => { - format!("unknown custom type `{}`\n{}", type_name, loc.report(src)) - } - Self::UnknownQueryRelationType { predicate, loc } => { - format!("unknown relation `{}` used in query\n{}", predicate, loc.report(src)) - } - Self::UnknownFunctionType { function_name, loc } => { - format!("unknown function `{}`\n{}", function_name, loc.report(src)) - } - Self::UnknownVariable { variable, loc } => { - format!("unknown variable `{}` in the rule\n{}", variable, loc.report(src)) - } - Self::ArityMismatch { - predicate, - expected, - actual, - mismatch_loc, - .. - } => { - format!( - "arity mismatch for relation `{predicate}`. Expected {expected}, found {actual}:\n{}", - mismatch_loc.report(src) - ) - } - Self::FunctionArityMismatch { function, actual, loc } => { - format!( - "wrong number of arguments for function `{}`, found {}:\n{}", - function, - actual, - loc.report(src) - ) - } - Self::ADTVariantArityMismatch { - variant, - expected, - actual, - loc, - } => { - format!( - "arity mismatch for algebraic data type variant `{variant}`. Expected {expected}, found {actual}:\n{}", - loc.report(src) - ) - } - Self::EntityTupleArityMismatch { - predicate, - expected, - actual, - source_loc, - } => { - format!( - "incorrect number of arguments in entity tuple for `{predicate}`. 
Expected {expected}, found {actual}:\n{}", - source_loc.report(src) - ) - } - Self::InvalidArgIndex { - predicate, - index, - source_loc, - access_loc, - } => { - format!( - "Unexpected {}-th argument for relation `{}`. The relation type is inferred here:\n{}\nerroneous access happens here:\n{}", - index, predicate, source_loc.report(src), access_loc.report(src) - ) - } - Self::InvalidForeignPredicateArgIndex { - predicate, - index, - access_loc, - } => { - format!( - "Unexpected {}-th argument for foreign predicate `{}`:\n{}", - index, - predicate, - access_loc.report(src) - ) - } - Self::ConstantSetArityMismatch { - predicate, - mismatch_tuple_loc, - .. - } => { - format!( - "arity mismatch in constant set declaration for relation `{}`:\n{}", - predicate, - mismatch_tuple_loc.report(src) - ) - } - Self::ConstantTypeMismatch { expected, found } => { - format!( - "type mismatch for constant. Expected `{}`, found `{}`\n{}", - expected, - found, - found.location().report(src) - ) - } - Self::BadEnumValueKind { found, loc } => { - format!( - "bad enum value. Expected unsigned integers, found `{}`\n{}", - found, - loc.report(src), - ) - } - Self::NegativeEnumValue { found, loc } => { - format!( - "enum value `{}` found to be negative. 
Expected unsigned integers\n{}", - found, - loc.report(src), - ) - } - Self::CannotUnifyTypes { t1, t2, loc } => match loc { - Some(l) => { - format!("cannot unify types `{}` and `{}` in\n{}\nwhere the first is inferred here\n{}\nand the second is inferred here\n{}", t1, t2, l.report(src), t1.location().report(src), t2.location().report(src)) - } - None => { - format!( - "cannot unify types `{}` and `{}`, where the first is declared here\n{}\nand the second is declared here\n{}", - t1, t2, t1.location().report(src), t2.location().report(src) - ) - } - }, - Self::CannotUnifyForeignPredicateArgument { - pred, - i, - expected_ty, - actual_ty, - loc, - } => { - format!( - "cannot unify the type of {}-th argument of foreign predicate `{}`, expected type `{}`, found `{}`:\n{}", - i, - pred, - expected_ty, - actual_ty, - loc.report(src), - ) - } - Self::CannotUnifyVariables { v1, t1, v2, t2, loc } => { - format!( - "cannot unify variable types: `{}` has `{}` type, `{}` has `{}` type, but they should be unified\n{}", - v1, - t1, - v2, - t2, - loc.report(src) - ) - } - Self::NoMatchingTripletRule { - op1_ty, - op2_ty, - e_ty, - location, - } => { - format!( - "no matching rule found; two operands have type `{}` and `{}`, while the expression has type `{}`:\n{}", - op1_ty, - op2_ty, - e_ty, - location.report(src) - ) - } - Self::CannotTypeCast { t1, t2, loc } => { - format!("cannot cast type from `{}` to `{}`\n{}", t1, t2, loc.report(src)) - } - Self::ConstraintNotBoolean { ty, loc } => { - format!( - "constraint must have `bool` type, found `{}` type\n{}", - ty, - loc.report(src) - ) - } - Self::InvalidReduceOutput { - op, - expected, - found, - loc, - } => { - format!( - "invalid amount of output for `{}`. Expected {}, found {}\n{}", - op, - expected, - found, - loc.report(src) - ) - } - Self::InvalidReduceBindingVar { - op, - expected, - found, - loc, - } => { - format!( - "invalid amount of binding variables for `{}`. 
Expected {}, found {}\n{}", - op, - expected, - found, - loc.report(src) - ) - } - Self::InvalidUniqueNumParams { - num_output_vars, - num_binding_vars, - loc, - } => { - format!( - "expected same amount of output variables and binding variables for aggregation `unique`, but found {} output variables and {} binding variables\n{}", - num_output_vars, num_binding_vars, loc.report(src) - ) - } - Self::CannotRedefineForeignPredicate { pred, loc } => { - format!( - "the predicate `{}` is being defined here, but it is also a foreign predicate which cannot be populated\n{}", - pred, - loc.report(src), - ) - } - Self::CannotQueryForeignPredicate { pred, loc } => { - format!( - "the foreign predicate `{}` cannot be queried\n{}", - pred, - loc.report(src), - ) - } - Self::Internal { error_string } => error_string.clone(), + pub fn cannot_type_cast(t1: TypeSet, t2: ValueType, loc: Loc) -> Self { + Self::error() + .msg(format!("cannot cast type from `{t1}` to `{t2}`")) + .src(loc) + } + + pub fn constraint_not_boolean(ty: TypeSet, loc: Loc) -> Self { + Self::error() + .msg(format!("constraint must have `bool` type, but found `{ty}` type")) + .src(loc) + } + + pub fn cannot_redefine_foreign_predicate(pred: String, loc: Loc) -> Self { + Self::error() + .msg(format!( + "the predicate `{pred}` is being defined here, but it is also a foreign predicate which cannot be populated" + )) + .src(loc) + } + + pub fn cannot_query_foreign_predicate(pred: String, loc: Loc) -> Self { + Self::error() + .msg(format!("the foreign predicate `{pred}` cannot be queried:")) + .src(loc) + } +} + +pub struct CannotUnifyTypes { + pub t1: TypeSet, + pub t2: TypeSet, + pub loc: Option, +} + +impl Into for CannotUnifyTypes { + fn into(self) -> FrontCompileErrorMessage { + if let Some(l) = self.loc { + FrontCompileErrorMessage::error() + .msg(format!("cannot unify types `{}` and `{}` in", &self.t1, &self.t2)) + .src(l) + .msg("where the first is inferred here") + .src(self.t1.location().clone()) + 
.msg("and the second is inferred here") + .src(self.t2.location().clone()) + } else { + FrontCompileErrorMessage::error() + .msg(format!( + "cannot unify types `{}` and `{}`, where the first is declared here", + &self.t1, &self.t2 + )) + .src(self.t1.location().clone()) + .msg("and the second is declared here") + .src(self.t2.location().clone()) } } } + +impl CannotUnifyTypes { + pub fn annotate_location(&mut self, new_location: &NodeLocation) { + self.loc = Some(new_location.clone()); + } +} diff --git a/core/src/compiler/front/analyzers/type_inference/foreign_aggregate.rs b/core/src/compiler/front/analyzers/type_inference/foreign_aggregate.rs new file mode 100644 index 0000000..7c7af62 --- /dev/null +++ b/core/src/compiler/front/analyzers/type_inference/foreign_aggregate.rs @@ -0,0 +1,30 @@ +use std::collections::*; + +use crate::common::foreign_aggregate::*; + +#[derive(Debug, Clone)] +pub struct AggregateTypeRegistry { + pub aggregate_types: HashMap, +} + +impl AggregateTypeRegistry { + pub fn empty() -> Self { + Self { + aggregate_types: HashMap::new(), + } + } + + pub fn from_aggregate_registry(far: &AggregateRegistry) -> Self { + let mut registry = Self::empty(); + for (_, fa) in far.iter() { + let name = fa.name(); + let agg_type = fa.aggregate_type(); + registry.aggregate_types.insert(name, agg_type); + } + registry + } + + pub fn get(&self, agg_name: &str) -> Option<&AggregateType> { + self.aggregate_types.get(agg_name) + } +} diff --git a/core/src/compiler/front/analyzers/type_inference/local.rs b/core/src/compiler/front/analyzers/type_inference/local.rs index 9d2ffef..d1e300c 100644 --- a/core/src/compiler/front/analyzers/type_inference/local.rs +++ b/core/src/compiler/front/analyzers/type_inference/local.rs @@ -13,7 +13,7 @@ pub struct LocalTypeInferenceContext { pub vars_of_same_type: Vec<(String, String)>, pub var_types: HashMap, pub constraints: Vec, - pub errors: Vec, + pub errors: Vec, } impl LocalTypeInferenceContext { @@ -45,7 +45,7 @@ impl 
LocalTypeInferenceContext { &self, predicate_registry: &PredicateTypeRegistry, inferred_relation_types: &mut HashMap, Loc)>, - ) -> Result<(), TypeInferenceError> { + ) -> Result<(), Error> { for (pred, arities) in &self.atom_arities { // Skip foreign predicates if predicate_registry.contains_predicate(pred) { @@ -60,16 +60,10 @@ impl LocalTypeInferenceContext { } // Make sure the arity matches - let (tys, source_loc) = &inferred_relation_types[pred]; + let (tys, _) = &inferred_relation_types[pred]; for (arity, atom_loc) in arities { if arity != &tys.len() { - return Err(TypeInferenceError::ArityMismatch { - predicate: pred.clone(), - expected: tys.len(), - actual: *arity, - source_loc: source_loc.clone(), - mismatch_loc: atom_loc.clone(), - }); + return Err(Error::arity_mismatch(pred.clone(), tys.len(), *arity, atom_loc.clone())); } } } @@ -120,8 +114,9 @@ impl LocalTypeInferenceContext { inferred_relation_types: &HashMap, Loc)>, function_type_registry: &FunctionTypeRegistry, predicate_type_registry: &PredicateTypeRegistry, + aggregate_type_registry: &AggregateTypeRegistry, inferred_expr_types: &mut HashMap, - ) -> Result<(), TypeInferenceError> { + ) -> Result<(), Error> { for unif in &self.unifications { unif.unify( custom_types, @@ -129,6 +124,7 @@ impl LocalTypeInferenceContext { inferred_relation_types, function_type_registry, predicate_type_registry, + aggregate_type_registry, inferred_expr_types, )?; } @@ -139,7 +135,7 @@ impl LocalTypeInferenceContext { &self, inferred_var_expr: &mut HashMap>>, inferred_expr_types: &mut HashMap, - ) -> Result<(), TypeInferenceError> { + ) -> Result<(), Error> { let mut var_tys = inferred_var_expr .entry(self.rule_loc.clone()) .or_default() @@ -150,14 +146,11 @@ impl LocalTypeInferenceContext { .filter_map(|e| inferred_expr_types.get(e)) .collect::>(); if tys.is_empty() { - Err(TypeInferenceError::UnknownVariable { - variable: var.clone(), - loc: self.rule_loc.clone(), - }) + Err(Error::unknown_variable(var.clone(), 
self.rule_loc.clone())) } else { match TypeSet::unify_type_sets(tys) { Ok(ty) => Ok((var.clone(), ty)), - Err(err) => Err(err), + Err(err) => Err(err.into()), } } }) @@ -169,17 +162,13 @@ impl LocalTypeInferenceContext { let v2_ty = &var_tys[v2]; match v1_ty.unify(v2_ty) { Err(err) => { - let new_err = match err { - TypeInferenceError::CannotUnifyTypes { t1, t2, .. } => TypeInferenceError::CannotUnifyVariables { - v1: v1.clone(), - t1, - v2: v2.clone(), - t2, - loc: self.rule_loc.clone(), - }, - err => err, - }; - return Err(new_err); + return Err(Error::cannot_unify_variables( + v1.clone(), + err.t1, + v2.clone(), + err.t2, + self.rule_loc.clone(), + )); } Ok(new_ty) => { var_tys.insert(v1.clone(), new_ty.clone()); @@ -191,7 +180,7 @@ impl LocalTypeInferenceContext { // Check variable type constraints for (var, (ty, _)) in &self.var_types { let curr_ty = &var_tys[var]; - let new_ty = ty.unify(curr_ty)?; + let new_ty = ty.unify(curr_ty).map_err(|e| e.into())?; var_tys.insert(var.clone(), new_ty); } @@ -211,7 +200,7 @@ impl LocalTypeInferenceContext { inferred_relation_expr: &HashMap<(String, usize), BTreeSet>, inferred_expr_types: &HashMap, inferred_relation_types: &mut HashMap, Loc)>, - ) -> Result<(), TypeInferenceError> { + ) -> Result<(), Error> { // Propagate inferred relation types for ((predicate, i), exprs) in inferred_relation_expr { let tys = exprs @@ -219,7 +208,7 @@ impl LocalTypeInferenceContext { .filter_map(|e| inferred_expr_types.get(e)) .collect::>(); if !tys.is_empty() { - let ty = TypeSet::unify_type_sets(tys)?; + let ty = TypeSet::unify_type_sets(tys).map_err(|e| e.into())?; if let Some((arg_types, _)) = inferred_relation_types.get_mut(predicate) { arg_types[*i] = ty; } @@ -233,7 +222,7 @@ impl LocalTypeInferenceContext { &self, custom_types: &HashMap, inferred_expr_types: &HashMap, - ) -> Result<(), TypeInferenceError> { + ) -> Result<(), Error> { // Check if type cast can happen for unif in &self.unifications { match unif { @@ -241,11 +230,7 
@@ impl LocalTypeInferenceContext { let target_base_ty = find_value_type(custom_types, ty).unwrap(); let op1_ty = &inferred_expr_types[op1]; if !op1_ty.can_type_cast(&target_base_ty) { - return Err(TypeInferenceError::CannotTypeCast { - t1: op1_ty.clone(), - t2: target_base_ty, - loc: e.clone(), - }); + return Err(Error::cannot_type_cast(op1_ty.clone(), target_base_ty, e.clone())); } } _ => {} @@ -254,15 +239,12 @@ impl LocalTypeInferenceContext { Ok(()) } - pub fn check_constraint(&self, inferred_expr_types: &HashMap) -> Result<(), TypeInferenceError> { + pub fn check_constraint(&self, inferred_expr_types: &HashMap) -> Result<(), Error> { // Check if constraints are all boolean for constraint_expr in &self.constraints { let ty = &inferred_expr_types[constraint_expr]; if !ty.is_boolean() { - return Err(TypeInferenceError::ConstraintNotBoolean { - ty: ty.clone(), - loc: constraint_expr.clone(), - }); + return Err(Error::constraint_not_boolean(ty.clone(), constraint_expr.clone())); } } Ok(()) @@ -304,124 +286,26 @@ impl NodeVisitor for LocalTypeInferenceContext { impl NodeVisitor for LocalTypeInferenceContext { fn visit(&mut self, r: &Reduce) { - // First check the output validity - let vars = r.left(); - if let Some(arity) = r.operator().output_arity() { - if vars.len() != arity { - self.errors.push(TypeInferenceError::InvalidReduceOutput { - op: r.operator().to_string().to_string(), - expected: arity, - found: r.left().len(), - loc: r.location().clone(), - }); - return; - } - } - - // Then check the number of bindings - let maybe_num_bindings = r.operator().num_bindings(); - let bindings = r.bindings(); - if let Some(num_bindings) = maybe_num_bindings { - if bindings.len() != num_bindings { - self.errors.push(TypeInferenceError::InvalidReduceBindingVar { - op: r.operator().to_string().to_string(), - expected: num_bindings, - found: bindings.len(), - loc: r.location().clone(), - }); - return; - } - } - - // Then propagate the variables - match 
r.operator().internal() { - _ReduceOp::Count(_) => { - if let Some(n) = vars[0].name() { - let loc = vars[0].location(); - let ty = TypeSet::BaseType(ValueType::USize, loc.clone()); - self.var_types.insert(n.to_string(), (ty, loc.clone())); - } - } - _ReduceOp::Sum => { - if let Some(n) = vars[0].name() { - let loc = vars[0].location(); - let ty = TypeSet::Numeric(loc.clone()); - self.var_types.insert(n.to_string(), (ty, loc.clone())); - - // Result var and binding var should have the same type - self - .vars_of_same_type - .push((n.to_string(), bindings[0].name().to_string())); - } - } - _ReduceOp::Prod => { - if let Some(n) = vars[0].name() { - let loc = vars[0].location(); - let ty = TypeSet::Numeric(loc.clone()); - self.var_types.insert(n.to_string(), (ty, loc.clone())); - - // Result var and binding var should have the same type - self - .vars_of_same_type - .push((n.to_string(), bindings[0].name().to_string())); - } - } - _ReduceOp::Min => { - if let Some(n) = vars[0].name() { - let loc = vars[0].location(); - let ty = TypeSet::Numeric(loc.clone()); - self.var_types.insert(n.to_string(), (ty, loc.clone())); - - // Result var and binding var should have the same type - self - .vars_of_same_type - .push((n.to_string(), bindings[0].name().to_string())); - } - } - _ReduceOp::Max => { - if let Some(n) = vars[0].name() { - let loc = vars[0].location(); - let ty = TypeSet::Numeric(loc.clone()); - self.var_types.insert(n.to_string(), (ty, loc.clone())); - - // Result var and binding var should have the same type - self - .vars_of_same_type - .push((n.to_string(), bindings[0].name().to_string())); - } - } - _ReduceOp::Exists => { - if let Some(n) = vars[0].name() { - let loc = vars[0].location(); - let ty = TypeSet::BaseType(ValueType::Bool, loc.clone()); - self.var_types.insert(n.to_string(), (ty, loc.clone())); - } - } - _ReduceOp::Forall => { - if let Some(n) = vars[0].name() { - let loc = vars[0].location(); - let ty = TypeSet::BaseType(ValueType::Bool, 
loc.clone()); - self.var_types.insert(n.to_string(), (ty, loc.clone())); - } - } - _ReduceOp::Unique | _ReduceOp::TopK(_) | _ReduceOp::CategoricalK(_) => { - if vars.len() == bindings.len() { - for (var, binding) in vars.iter().zip(bindings.iter()) { - if let Some(n) = var.name() { - self.vars_of_same_type.push((n.to_string(), binding.name().to_string())); - } - } - } else { - self.errors.push(TypeInferenceError::InvalidUniqueNumParams { - num_output_vars: vars.len(), - num_binding_vars: bindings.len(), - loc: r.location().clone(), - }); - return; - } - } - _ReduceOp::Unknown(_) => {} - } + // First get the aggregate type + let agg_op = r.operator(); + let agg_name = agg_op.name().name(); + let has_exclamation_mark = agg_op.has_exclaimation_mark().clone(); + + // Add the aggregation unification + self.unifications.push(Unification::Aggregate( + r.iter_left().map(|vow| vow.location()).cloned().collect(), + agg_name.clone(), + agg_op.location().clone(), + agg_op + .parameters() + .iter() + .map(|param| param.location()) + .cloned() + .collect(), + r.iter_args().map(|a| a.location()).cloned().collect(), + r.iter_bindings().map(|a| a.location()).cloned().collect(), + has_exclamation_mark, + )); } } diff --git a/core/src/compiler/front/analyzers/type_inference/mod.rs b/core/src/compiler/front/analyzers/type_inference/mod.rs index 7f2c4d8..fc307ed 100644 --- a/core/src/compiler/front/analyzers/type_inference/mod.rs +++ b/core/src/compiler/front/analyzers/type_inference/mod.rs @@ -1,6 +1,7 @@ //! 
# Type inference analysis mod error; +mod foreign_aggregate; mod foreign_function; mod foreign_predicate; mod local; @@ -10,8 +11,10 @@ mod type_set; mod unification; use super::super::utils::*; +use super::super::FrontCompileErrorMessage; pub use error::*; +pub use foreign_aggregate::*; pub use foreign_function::*; pub use foreign_predicate::*; pub use local::*; @@ -19,3 +22,5 @@ pub use operator_rules::*; pub use type_inference::*; pub use type_set::*; pub use unification::*; + +type Error = FrontCompileErrorMessage; diff --git a/core/src/compiler/front/analyzers/type_inference/type_inference.rs b/core/src/compiler/front/analyzers/type_inference/type_inference.rs index a264576..e50fb89 100644 --- a/core/src/compiler/front/analyzers/type_inference/type_inference.rs +++ b/core/src/compiler/front/analyzers/type_inference/type_inference.rs @@ -1,6 +1,7 @@ use std::collections::*; use crate::common::adt_variant_registry::ADTVariantRegistry; +use crate::common::foreign_aggregate::AggregateRegistry; use crate::common::foreign_function::*; use crate::common::foreign_predicate::*; use crate::common::tuple_type::*; @@ -17,6 +18,9 @@ pub struct TypeInference { /// A mapping from referred constant variables' location to its declared type, if the type is specified in the constant declaration pub constant_types: HashMap, + /// Foreign aggregate type registry + pub foreign_aggregate_type_registry: AggregateTypeRegistry, + /// Foreign function types pub foreign_function_type_registry: FunctionTypeRegistry, @@ -48,16 +52,21 @@ pub struct TypeInference { pub expr_types: HashMap, /// A list of errors obtained from the type inference process - pub errors: Vec, + pub errors: Vec, } impl TypeInference { - pub fn new(function_registry: &ForeignFunctionRegistry, predicate_registry: &ForeignPredicateRegistry) -> Self { + pub fn new( + function_registry: &ForeignFunctionRegistry, + predicate_registry: &ForeignPredicateRegistry, + aggregate_registry: &AggregateRegistry, + ) -> Self { Self 
{ custom_types: HashMap::new(), constant_types: HashMap::new(), foreign_function_type_registry: FunctionTypeRegistry::from_foreign_function_registry(function_registry), foreign_predicate_type_registry: PredicateTypeRegistry::from_foreign_predicate_registry(predicate_registry), + foreign_aggregate_type_registry: AggregateTypeRegistry::from_aggregate_registry(aggregate_registry), relation_type_decl_loc: HashMap::new(), relation_field_names: HashMap::new(), adt_relations: HashMap::new(), @@ -70,6 +79,14 @@ impl TypeInference { } } + pub fn loc_value_type(&self, loc: &Loc) -> Option { + self.expr_types.get(loc).map(TypeSet::to_default_value_type) + } + + pub fn loc_value_types<'a, I: Iterator>(&self, i: I) -> Option> { + i.map(|loc| self.loc_value_type(loc)).collect() + } + pub fn expr_value_type(&self, t: &T) -> Option where T: AstNode, @@ -138,18 +155,18 @@ impl TypeInference { .collect() } - pub fn find_value_type(&self, ty: &Type) -> Result { + pub fn find_value_type(&self, ty: &Type) -> Result { find_value_type(&self.custom_types, ty) } pub fn check_and_add_custom_type(&mut self, name: &str, ty: &Type, loc: &Loc) { if self.custom_types.contains_key(name) { let (_, source_loc) = &self.custom_types[name]; - self.errors.push(TypeInferenceError::DuplicateTypeDecl { - type_name: name.to_string(), - source_decl_loc: source_loc.clone(), - duplicate_decl_loc: loc.clone(), - }); + self.errors.push(FrontCompileErrorMessage::duplicate_type_decl( + name.to_string(), + source_loc.clone(), + loc.clone(), + )); } else { match self.find_value_type(ty) { Ok(base_ty) => { @@ -166,11 +183,11 @@ impl TypeInference { // Check if the relation has been declared if self.relation_type_decl_loc.contains_key(predicate) { let source_loc = &self.relation_type_decl_loc[predicate]; - self.errors.push(TypeInferenceError::DuplicateRelationTypeDecl { - predicate: predicate.to_string(), - source_decl_loc: source_loc.clone(), - duplicate_decl_loc: loc.clone(), - }); + 
self.errors.push(Error::duplicate_relation_type_decl( + predicate.to_string(), + source_loc.clone(), + loc.clone(), + )); return; } @@ -207,7 +224,7 @@ impl TypeInference { match maybe_new_type_sets { Ok(new_type_sets) => new_type_sets, Err(err) => { - self.errors.push(err); + self.errors.push(err.into()); return; } } @@ -226,7 +243,7 @@ impl TypeInference { } } - pub fn resolve_constant_type(&self, c: &Constant) -> Result { + pub fn resolve_constant_type(&self, c: &Constant) -> Result { if let Some(ty) = self.constant_types.get(c.location()) { let val_ty = find_value_type(&self.custom_types, ty)?; Ok(TypeSet::BaseType(val_ty, ty.location().clone())) @@ -238,10 +255,9 @@ impl TypeInference { pub fn check_query_predicates(&mut self) { for (pred, loc) in &self.query_relations { if !self.inferred_relation_types.contains_key(pred) { - self.errors.push(TypeInferenceError::UnknownQueryRelationType { - predicate: pred.clone(), - loc: loc.clone(), - }); + self + .errors + .push(Error::unknown_query_relation_type(pred.clone(), loc.clone())); } } } @@ -252,7 +268,7 @@ impl TypeInference { } } - fn infer_types_helper(&mut self) -> Result<(), TypeInferenceError> { + fn infer_types_helper(&mut self) -> Result<(), Error> { // Mapping from variable to set of expressions let mut inferred_var_expr = HashMap::>>::new(); @@ -286,6 +302,7 @@ impl TypeInference { &self.inferred_relation_types, &self.foreign_function_type_registry, &self.foreign_predicate_type_registry, + &self.foreign_aggregate_type_registry, &mut inferred_expr_types, )?; ctx.propagate_variable_types(&mut inferred_var_expr, &mut inferred_expr_types)?; @@ -333,7 +350,11 @@ impl TypeInference { impl NodeVisitor for TypeInference { fn visit(&mut self, subtype_decl: &SubtypeDecl) { - self.check_and_add_custom_type(subtype_decl.name().name(), subtype_decl.subtype_of(), subtype_decl.location()); + self.check_and_add_custom_type( + subtype_decl.name().name(), + subtype_decl.subtype_of(), + subtype_decl.location(), + ); } } @@ 
-387,10 +408,10 @@ impl NodeVisitor for TypeInference { // Check if the relation is a foreign predicate let predicate = relation_type.predicate_name(); if self.foreign_predicate_type_registry.contains_predicate(predicate) { - self.errors.push(TypeInferenceError::CannotRedefineForeignPredicate { - pred: predicate.to_string(), - loc: relation_type.location().clone(), - }); + self.errors.push(Error::cannot_redefine_foreign_predicate( + predicate.to_string(), + relation_type.location().clone(), + )); return; } @@ -415,16 +436,14 @@ impl NodeVisitor for TypeInference { Some(c) => match c { Constant::Integer(i) => { if i.int() < &0 { - self.errors.push(TypeInferenceError::NegativeEnumValue { - found: i.int().clone(), - loc: c.location().clone(), - }) + self + .errors + .push(Error::negative_enum_value(i.int().clone(), c.location().clone())) } } - _ => self.errors.push(TypeInferenceError::BadEnumValueKind { - found: c.kind(), - loc: c.location().clone(), - }), + _ => self + .errors + .push(Error::bad_enum_value_kind(c.kind(), c.location().clone())), }, _ => {} } @@ -433,23 +452,31 @@ impl NodeVisitor for TypeInference { } impl NodeVisitor for TypeInference { - fn visit(&mut self, _: &FunctionTypeDecl) { + fn visit(&mut self, f: &FunctionTypeDecl) { // TODO - println!("[Warning] Cannot handle function type declaration yet; the declarations should be processed by external attributes") + self.errors.push( + Error::warning() + .msg("the compiler cannot handle function type declaration yet; should be processed by external attributes") + .src(f.location().clone()), + ); } } impl NodeVisitor for TypeInference { fn visit(&mut self, const_assign: &ConstAssignment) { if let Some(raw_type) = const_assign.ty() { - let result = find_value_type(&self.custom_types, raw_type).and_then(|ty| { - let ts = TypeSet::from_constant(const_assign.value().as_constant().expect("[Internal Error] During type inference, all entities should be normalized to constant. 
This is probably an internal error.")); - ts.unify(&TypeSet::BaseType(ty, raw_type.location().clone())) - }); - match result { - Ok(_) => {} - Err(mut err) => { - err.annotate_location(const_assign.location()); + match find_value_type(&self.custom_types, raw_type) { + Ok(ty) => { + let ts = TypeSet::from_constant(const_assign.value().as_constant().expect("[Internal Error] During type inference, all entities should be normalized to constant. This is probably an internal error.")); + match ts.unify(&TypeSet::BaseType(ty, raw_type.location().clone())) { + Ok(_) => {} + Err(mut err) => { + err.annotate_location(const_assign.location()); + self.errors.push(err.into()); + } + } + } + Err(err) => { self.errors.push(err); } } @@ -463,10 +490,10 @@ impl NodeVisitor for TypeInference { // Check if the relation is a foreign predicate if self.foreign_predicate_type_registry.contains_predicate(pred) { - self.errors.push(TypeInferenceError::CannotRedefineForeignPredicate { - pred: pred.to_string(), - loc: constant_set_decl.location().clone(), - }); + self.errors.push(Error::cannot_redefine_foreign_predicate( + pred.to_string(), + constant_set_decl.location().clone(), + )); return; } @@ -478,25 +505,27 @@ impl NodeVisitor for TypeInference { // First get the arity of the constant set. 
let arity = { // Compute the arity from the set - let maybe_arity = constant_set_decl.set().iter_tuples().fold(Ok(None), |acc, tuple| match acc { - Ok(maybe_arity) => { - let current_arity = tuple.arity(); - if let Some(previous_arity) = &maybe_arity { - if previous_arity != ¤t_arity { - return Err(TypeInferenceError::ConstantSetArityMismatch { - predicate: pred.clone(), - decl_loc: constant_set_decl.location().clone(), - mismatch_tuple_loc: tuple.location().clone(), - }); + let maybe_arity = constant_set_decl + .set() + .iter_tuples() + .fold(Ok(None), |acc, tuple| match acc { + Ok(maybe_arity) => { + let current_arity = tuple.arity(); + if let Some(previous_arity) = &maybe_arity { + if previous_arity != ¤t_arity { + return Err(Error::constant_set_arity_mismatch( + pred.clone(), + tuple.location().clone(), + )); + } else { + Ok(Some(current_arity)) + } } else { Ok(Some(current_arity)) } - } else { - Ok(Some(current_arity)) } - } - Err(err) => Err(err), - }); + Err(err) => Err(err), + }); // If there is arity mismatch inside the set, add the error and stop match maybe_arity { @@ -521,13 +550,12 @@ impl NodeVisitor for TypeInference { for tuple in constant_set_decl.set().iter_tuples() { // Check if the arity of the tuple matches the defined ones if tuple.arity() != type_sets.len() { - self.errors.push(TypeInferenceError::ArityMismatch { - predicate: pred.clone(), - expected: type_sets.len(), - actual: tuple.arity(), - source_loc: loc.clone(), - mismatch_loc: tuple.location().clone(), - }); + self.errors.push(Error::arity_mismatch( + pred.clone(), + type_sets.len(), + tuple.arity(), + tuple.location().clone(), + )); continue; } @@ -547,7 +575,7 @@ impl NodeVisitor for TypeInference { *ts = new_ts; } Err(err) => { - self.errors.push(err); + self.errors.push(err.into()); return; } } @@ -569,10 +597,10 @@ impl NodeVisitor for TypeInference { // Check if the relation is a foreign predicate if self.foreign_predicate_type_registry.contains_predicate(&pred) { - 
self.errors.push(TypeInferenceError::CannotRedefineForeignPredicate { - pred: pred.to_string(), - loc: fact_decl.location().clone(), - }); + self.errors.push(Error::cannot_redefine_foreign_predicate( + pred.to_string(), + fact_decl.location().clone(), + )); return; } @@ -580,10 +608,11 @@ impl NodeVisitor for TypeInference { if pred.contains("adt#") { // Make sure that the predicate is an existing ADT relation if !self.adt_relations.contains_key(pred) { - self.errors.push(TypeInferenceError::UnknownADTVariant { - predicate: pred[4..].to_string(), - loc: fact_decl.atom().predicate().location().clone(), - }) + self.errors.push(Error::unknown_adt_variant( + pred[4..].to_string(), + fact_decl.atom().predicate().location().clone(), + )); + return; } } @@ -606,25 +635,24 @@ impl NodeVisitor for TypeInference { // Check the type if self.inferred_relation_types.contains_key(pred) { - let (original_type_sets, original_type_def_loc) = &self.inferred_relation_types[pred]; + let (original_type_sets, _) = &self.inferred_relation_types[pred]; // First check if the arity matches if curr_type_sets.len() != original_type_sets.len() { if let Some((variant_name, _)) = self.adt_relations.get(pred) { - self.errors.push(TypeInferenceError::ADTVariantArityMismatch { - variant: variant_name.clone(), - expected: original_type_sets.len() - 1, - actual: curr_type_sets.len() - 1, - loc: fact_decl.atom().location().clone(), - }); + self.errors.push(Error::adt_variant_arity_mismatch( + variant_name.clone(), + original_type_sets.len() - 1, + curr_type_sets.len() - 1, + fact_decl.atom().location().clone(), + )); } else { - self.errors.push(TypeInferenceError::ArityMismatch { - predicate: pred.clone(), - expected: original_type_sets.len(), - actual: curr_type_sets.len(), - source_loc: original_type_def_loc.clone(), - mismatch_loc: fact_decl.atom().location().clone(), - }); + self.errors.push(Error::arity_mismatch( + pred.clone(), + original_type_sets.len(), + curr_type_sets.len(), + 
fact_decl.atom().location().clone(), + )); } return; } @@ -639,7 +667,7 @@ impl NodeVisitor for TypeInference { Ok(new_type_sets) => { self.inferred_relation_types.get_mut(pred).unwrap().0 = new_type_sets; } - Err(err) => self.errors.push(err), + Err(err) => self.errors.push(err.into()), } } else { self @@ -654,10 +682,10 @@ impl NodeVisitor for TypeInference { for pred in rule.head().iter_predicates() { // Check if a head predicate is a foreign predicate if self.foreign_predicate_type_registry.contains_predicate(&pred) { - self.errors.push(TypeInferenceError::CannotRedefineForeignPredicate { - pred: pred.to_string(), - loc: rule.location().clone(), - }); + self.errors.push(Error::cannot_redefine_foreign_predicate( + pred.to_string(), + rule.location().clone(), + )); return; } } @@ -687,10 +715,10 @@ impl NodeVisitor for TypeInference { // Check if the relation is a foreign predicate let pred = query.formatted_predicate(); if self.foreign_predicate_type_registry.contains_predicate(&pred) { - self.errors.push(TypeInferenceError::CannotQueryForeignPredicate { - pred: pred.to_string(), - loc: query.location().clone(), - }); + self.errors.push(Error::cannot_query_foreign_predicate( + pred.to_string(), + query.location().clone(), + )); return; } @@ -723,10 +751,7 @@ impl NodeVisitor for TypeInference { } } -pub fn find_value_type( - custom_types: &HashMap, - ty: &Type, -) -> Result { +pub fn find_value_type(custom_types: &HashMap, ty: &Type) -> Result { match ty.to_value_type() { Ok(base_ty) => Ok(base_ty), Err(other_name) => { @@ -734,10 +759,7 @@ pub fn find_value_type( let base_ty = custom_types[&other_name].0.clone(); Ok(base_ty) } else { - Err(TypeInferenceError::UnknownCustomType { - type_name: other_name, - loc: ty.location().clone(), - }) + Err(Error::unknown_custom_type(other_name, ty.location().clone())) } } } diff --git a/core/src/compiler/front/analyzers/type_inference/type_set.rs b/core/src/compiler/front/analyzers/type_inference/type_set.rs index 
94619a2..cc91800 100644 --- a/core/src/compiler/front/analyzers/type_inference/type_set.rs +++ b/core/src/compiler/front/analyzers/type_inference/type_set.rs @@ -267,7 +267,7 @@ impl TypeSet { } } - pub fn unify_type_sets(tss: Vec<&TypeSet>) -> Result { + pub fn unify_type_sets(tss: Vec<&TypeSet>) -> Result { let mut ty = tss[0].clone(); for curr_ty in tss { match ty.unify(curr_ty) { @@ -278,12 +278,12 @@ impl TypeSet { Ok(ty) } - pub fn unify(&self, other: &Self) -> Result { + pub fn unify(&self, other: &Self) -> Result { use std::cmp::Ordering::*; match self.partial_cmp(other) { Some(Equal) | Some(Less) => Ok(self.clone()), Some(Greater) => Ok(other.clone()), - None => Err(TypeInferenceError::CannotUnifyTypes { + None => Err(CannotUnifyTypes { t1: self.clone(), t2: other.clone(), loc: None, diff --git a/core/src/compiler/front/analyzers/type_inference/unification.rs b/core/src/compiler/front/analyzers/type_inference/unification.rs index 31ab6f6..7f4661c 100644 --- a/core/src/compiler/front/analyzers/type_inference/unification.rs +++ b/core/src/compiler/front/analyzers/type_inference/unification.rs @@ -2,9 +2,12 @@ use std::collections::*; use super::*; +use crate::common::foreign_aggregate::*; use crate::common::value_type::*; use crate::compiler::front::*; +type Error = FrontCompileErrorMessage; + /// The structure storing unification relationships #[derive(Clone, Debug)] pub enum Unification { @@ -56,6 +59,10 @@ pub enum Unification { /// f, ops*, $f(ops*) Call(String, Vec, Loc), + /// var* := AGGREGATE![arg*](in_var*: ...) 
+ /// var*, AGGREGATE name, AGGREGATE (loc), param*, arg*, in_var*, has_exclamation_mark + Aggregate(Vec, String, Loc, Vec, Vec, Vec, bool), + /// C, ops*, new C(ops*) New(String, Vec, Loc), } @@ -69,8 +76,9 @@ impl Unification { inferred_relation_types: &HashMap, Loc)>, function_type_registry: &FunctionTypeRegistry, predicate_type_registry: &PredicateTypeRegistry, + aggregate_type_registry: &AggregateTypeRegistry, inferred_expr_types: &mut HashMap, - ) -> Result<(), TypeInferenceError> { + ) -> Result<(), Error> { match self { Self::IthArgOfRelation(e, p, i) => { if let Some(tys) = predicate_type_registry.get(p) { @@ -81,35 +89,35 @@ impl Unification { // Unify the type match unify_ty(e, ty.clone(), inferred_expr_types) { Ok(_) => Ok(()), - Err(_) => Err(TypeInferenceError::CannotUnifyForeignPredicateArgument { - pred: p.clone(), - i: *i, - expected_ty: ty, - actual_ty: inferred_expr_types.get(e).unwrap().clone(), - loc: e.clone(), - }), + Err(_) => Err(Error::cannot_unify_foreign_predicate_arg( + p.clone(), + *i, + ty, + inferred_expr_types.get(e).unwrap().clone(), + e.clone(), + )), } } else { - Err(TypeInferenceError::InvalidForeignPredicateArgIndex { - predicate: p.clone(), - index: i.clone(), - access_loc: e.clone(), - }) + Err(Error::invalid_foreign_predicate_arg_index( + p.clone(), + i.clone(), + e.clone(), + )) } } else { // It is user defined predicate let (tys, loc) = inferred_relation_types.get(p).unwrap(); if i < &tys.len() { let ty = tys[*i].clone(); - unify_ty(e, ty, inferred_expr_types)?; + unify_ty(e, ty, inferred_expr_types).map_err(|e| e.into())?; Ok(()) } else { - Err(TypeInferenceError::InvalidArgIndex { - predicate: p.clone(), - index: i.clone(), - source_loc: loc.clone(), - access_loc: e.clone(), - }) + Err(Error::invalid_predicate_arg_index( + p.clone(), + i.clone(), + loc.clone(), + e.clone(), + )) } } } @@ -128,7 +136,7 @@ impl Unification { Ok(t) => t, Err(mut err) => { err.annotate_location(e); - return Err(err); + return 
Err(err.into()); } }; @@ -137,7 +145,7 @@ impl Unification { Ok(()) } else { // If the constant is not typed, simply check the inferred expression types - unify_ty(e, ty.clone(), inferred_expr_types)?; + unify_ty(e, ty.clone(), inferred_expr_types).map_err(|e| e.into())?; Ok(()) } } @@ -175,13 +183,13 @@ impl Unification { } Err(mut err) => { err.annotate_location(e); - Err(err) + Err(err.into()) } } } Self::EqNeq(op1, op2, e) => { // The type of e is boolean - unify_boolean(e, inferred_expr_types)?; + unify_boolean(e, inferred_expr_types).map_err(|e| e.into())?; // The two operators are of the same type let op_ty = TypeSet::Any(op1.clone()); @@ -195,15 +203,15 @@ impl Unification { } Err(mut err) => { err.annotate_location(e); - Err(err) + Err(err.into()) } } } Self::AndOrXor(op1, op2, e) => { // All e, op1, and op2 are boolean - unify_boolean(e, inferred_expr_types)?; - unify_boolean(op1, inferred_expr_types)?; - unify_boolean(op2, inferred_expr_types)?; + unify_boolean(e, inferred_expr_types).map_err(|e| e.into())?; + unify_boolean(op1, inferred_expr_types).map_err(|e| e.into())?; + unify_boolean(op2, inferred_expr_types).map_err(|e| e.into())?; Ok(()) } @@ -227,20 +235,20 @@ impl Unification { } Err(mut err) => { err.annotate_location(e); - Err(err) + Err(err.into()) } } } Self::Not(op1, e) => { // e and op1 should both be boolean - unify_boolean(e, inferred_expr_types)?; - unify_boolean(op1, inferred_expr_types)?; + unify_boolean(e, inferred_expr_types).map_err(|e| e.into())?; + unify_boolean(op1, inferred_expr_types).map_err(|e| e.into())?; Ok(()) } Self::IfThenElse(e, cond, then_br, else_br) => { // cond should be boolean - unify_boolean(cond, inferred_expr_types)?; + unify_boolean(cond, inferred_expr_types).map_err(|e| e.into())?; // Make sure that the expression, the then branch, and the else branch all have the same type let e_ty = get_or_insert_ty(e, TypeSet::Any(e.clone()), inferred_expr_types); @@ -255,7 +263,7 @@ impl Unification { } Err(mut err) => { 
err.annotate_location(e); - Err(err) + Err(err.into()) } } } @@ -265,10 +273,10 @@ impl Unification { Ok(base_ty) => TypeSet::BaseType(base_ty, ty.location().clone()), Err(err) => return Err(err), }; - unify_ty(e, ts, inferred_expr_types)?; + unify_ty(e, ts, inferred_expr_types).map_err(|e| e.into())?; // op1 can be any type (for now) - unify_any(op1, inferred_expr_types)?; + unify_any(op1, inferred_expr_types).map_err(|e| e.into())?; Ok(()) } @@ -294,7 +302,7 @@ impl Unification { } FunctionArgumentType::TypeSet(ts) => { // Unify arg type for non-generic ones - unify_ty(arg, ts, inferred_expr_types)?; + unify_ty(arg, ts, inferred_expr_types).map_err(|e| e.into())?; } } } @@ -311,7 +319,7 @@ impl Unification { FunctionReturnType::BaseType(t) => { // Unify the return type with the base type let ts = TypeSet::base(t.clone()); - unify_ty(e, ts, inferred_expr_types)?; + unify_ty(e, ts, inferred_expr_types).map_err(|e| e.into())?; } } @@ -325,15 +333,17 @@ impl Unification { if let Some(instances) = generic_type_param_instances.get(&i) { if instances.len() >= 1 { // Keep an aggregated unified ts starting from the first instance - let mut agg_unified_ts = unify_ty(&instances[0], generic_type_family.clone(), inferred_expr_types)?; + let mut agg_unified_ts = + unify_ty(&instances[0], generic_type_family.clone(), inferred_expr_types).map_err(|e| e.into())?; // Iterate from the next instance for j in 1..instances.len() { // Make sure the current type conform to the generic type parameter - let curr_unified_ts = unify_ty(&instances[j], generic_type_family.clone(), inferred_expr_types)?; + let curr_unified_ts = unify_ty(&instances[j], generic_type_family.clone(), inferred_expr_types) + .map_err(|e| e.into())?; // Unify with the aggregated type set - agg_unified_ts = agg_unified_ts.unify(&curr_unified_ts)?; + agg_unified_ts = agg_unified_ts.unify(&curr_unified_ts).map_err(|e| e.into())?; } // At the end, update all instances to have the `agg_unified_ts` type @@ -347,24 +357,110 
@@ impl Unification { // No more error Ok(()) } else { - Err(TypeInferenceError::FunctionArityMismatch { - function: function.clone(), - actual: args.len(), - loc: e.clone(), - }) + Err(Error::function_arity_mismatch(function.clone(), args.len(), e.clone())) } } else { - Err(TypeInferenceError::UnknownFunctionType { - function_name: function.clone(), - loc: e.clone(), - }) + Err(Error::unknown_function_type(function.clone(), e.clone())) + } + } + Self::Aggregate(out_vars, agg, agg_loc, param_consts, arg_vars, in_vars, has_exclamation) => { + if let Some(agg_type) = aggregate_type_registry.get(agg) { + // 1. check the parameters length match + let mut has_optional = false; + let mut curr_param_const_id = 0; + for (i, param_type) in agg_type.param_types.iter().enumerate() { + match param_type { + ParamType::Mandatory(vt) => { + if has_optional { + return Err(Error::error() + .msg(format!("error in aggregate `{agg}`: mandatory parameter must occur before optional parameter"))); + } else if let Some(curr_param) = param_consts.get(curr_param_const_id) { + unify_ty(curr_param, TypeSet::base(vt.clone()), inferred_expr_types).map_err(|e| e.into())?; + curr_param_const_id += 1; + } else { + return Err(Error::error() + .msg(format!("mandatory {i}-th {vt} parameter not found for aggregate `{agg}`:")) + .src(agg_loc.clone())) + } + } + ParamType::Optional(vt) => { + has_optional = true; + if let Some(curr_param) = param_consts.get(curr_param_const_id) { + match unify_ty(curr_param, TypeSet::base(vt.clone()), inferred_expr_types) { + Ok(_) => { + curr_param_const_id += 1; + } + Err(_) => {} + } + } + } + } + } + + // 2. check if there is any extra parameter not scanned + if curr_param_const_id + 1 < param_consts.len() { + return Err(Error::error() + .msg(format!("expected at most {} parameters, found {} parameters", agg_type.param_types.len(), param_consts.len())) + .src(agg_loc.clone())) + } + + // 3. 
check the exclamation mark + if *has_exclamation && !agg_type.allow_exclamation_mark { + return Err( + Error::error() + .msg(format!("aggregator `{agg}` does not support exclamation mark")) + .src(agg_loc.clone()), + ); + } + + // 4. trying to ground generics for the arg_vars or in_vars + let mut grounded_generic_types = HashMap::new(); + ground_input_aggregate_binding_type( + "argument", + agg, + agg_loc, + &agg_type.arg_type, + arg_vars, + &agg_type.generics, + &mut grounded_generic_types, + inferred_expr_types, + )?; + ground_input_aggregate_binding_type( + "input", + agg, + agg_loc, + &agg_type.input_type, + in_vars, + &agg_type.generics, + &mut grounded_generic_types, + inferred_expr_types, + )?; + + // 5. trying to unify the out_vars + ground_output_aggregate_binding_type( + agg, + agg_loc, + &agg_type.output_type, + out_vars, + &grounded_generic_types, + inferred_expr_types, + )?; + + // If passed all the tests, all good! + Ok(()) + } else { + Err( + Error::error() + .msg(format!("unknown aggregate `{agg}`")) + .src(agg_loc.clone()), + ) } } Self::New(functor, args, e) => { let adt_variant_relation_name = format!("adt#{functor}"); // cond should be boolean - unify_entity(e, inferred_expr_types)?; + unify_entity(e, inferred_expr_types).map_err(|e| e.into())?; // Get the functor/relation if let Some((types, _)) = inferred_relation_types.get(&adt_variant_relation_name) { @@ -377,24 +473,21 @@ impl Unification { } Err(mut err) => { err.annotate_location(arg); - return Err(err); + return Err(err.into()); } } } Ok(()) } else { - Err(TypeInferenceError::ADTVariantArityMismatch { - variant: functor.clone(), - expected: types.len() - 1, - actual: args.len(), - loc: e.clone(), - }) + Err(Error::adt_variant_arity_mismatch( + functor.clone(), + types.len() - 1, + args.len(), + e.clone(), + )) } } else { - Err(TypeInferenceError::UnknownADTVariant { - predicate: functor.clone(), - loc: e.clone(), - }) + Err(Error::unknown_adt_variant(functor.clone(), e.clone())) } } } @@ 
-431,11 +524,11 @@ fn unify_polymorphic_binary_expression( e: &Loc, inferred_expr_types: &mut HashMap, rules: &[(ValueType, ValueType, ValueType)], -) -> Result<(), TypeInferenceError> { +) -> Result<(), Error> { // First get the already inferred types of op1, op2, and e - let op1_ty = unify_any(op1, inferred_expr_types)?; - let op2_ty = unify_any(op2, inferred_expr_types)?; - let e_ty = unify_any(e, inferred_expr_types)?; + let op1_ty = unify_any(op1, inferred_expr_types).map_err(|e| e.into())?; + let op2_ty = unify_any(op2, inferred_expr_types).map_err(|e| e.into())?; + let e_ty = unify_any(e, inferred_expr_types).map_err(|e| e.into())?; // Then iterate through all the rules to see if any could be applied let mut applied_rules = AppliedRules::new(); @@ -449,18 +542,13 @@ fn unify_polymorphic_binary_expression( match applied_rules { AppliedRules::None => { // If no rule can be applied, then the type inference is failed - Err(TypeInferenceError::NoMatchingTripletRule { - op1_ty, - op2_ty, - e_ty, - location: e.clone(), - }) + Err(Error::no_matching_triplet_rule(op1_ty, op2_ty, e_ty, e.clone())) } AppliedRules::One((t1, t2, te)) => { // If there is exactly one rule that can be applied, then unify them with the exact types - unify_ty(op1, TypeSet::BaseType(t1, e.clone()), inferred_expr_types)?; - unify_ty(op2, TypeSet::BaseType(t2, e.clone()), inferred_expr_types)?; - unify_ty(e, TypeSet::BaseType(te, e.clone()), inferred_expr_types)?; + unify_ty(op1, TypeSet::BaseType(t1, e.clone()), inferred_expr_types).map_err(|e| e.into())?; + unify_ty(op2, TypeSet::BaseType(t2, e.clone()), inferred_expr_types).map_err(|e| e.into())?; + unify_ty(e, TypeSet::BaseType(te, e.clone()), inferred_expr_types).map_err(|e| e.into())?; Ok(()) } AppliedRules::Multiple => { @@ -477,13 +565,13 @@ fn unify_comparison_expression( e: &Loc, inferred_expr_types: &mut HashMap, rules: &[(ValueType, ValueType)], -) -> Result<(), TypeInferenceError> { +) -> Result<(), Error> { // The result should be 
a boolean - let e_ty = unify_boolean(e, inferred_expr_types)?; + let e_ty = unify_boolean(e, inferred_expr_types).map_err(|e| e.into())?; // First get the already inferred types of op1, op2, and e - let op1_ty = unify_any(op1, inferred_expr_types)?; - let op2_ty = unify_any(op2, inferred_expr_types)?; + let op1_ty = unify_any(op1, inferred_expr_types).map_err(|e| e.into())?; + let op2_ty = unify_any(op2, inferred_expr_types).map_err(|e| e.into())?; // Then iterate through all the rules to see if any could be applied let mut applied_rules = AppliedRules::new(); @@ -497,17 +585,12 @@ fn unify_comparison_expression( match applied_rules { AppliedRules::None => { // If no rule can be applied, then the type inference is failed - Err(TypeInferenceError::NoMatchingTripletRule { - op1_ty, - op2_ty, - e_ty, - location: e.clone(), - }) + Err(Error::no_matching_triplet_rule(op1_ty, op2_ty, e_ty, e.clone())) } AppliedRules::One((t1, t2)) => { // If there is exactly one rule that can be applied, then unify them with the exact types - unify_ty(op1, TypeSet::BaseType(t1, e.clone()), inferred_expr_types)?; - unify_ty(op2, TypeSet::BaseType(t2, e.clone()), inferred_expr_types)?; + unify_ty(op1, TypeSet::BaseType(t1, e.clone()), inferred_expr_types).map_err(|e| e.into())?; + unify_ty(op2, TypeSet::BaseType(t2, e.clone()), inferred_expr_types).map_err(|e| e.into())?; Ok(()) } AppliedRules::Multiple => { @@ -522,7 +605,7 @@ fn unify_ty( e: &Loc, ty: TypeSet, inferred_expr_types: &mut HashMap, -) -> Result { +) -> Result { let old_e_ty = inferred_expr_types.entry(e.clone()).or_insert(ty.clone()); match ty.unify(old_e_ty) { Ok(new_e_ty) => { @@ -536,17 +619,309 @@ fn unify_ty( } } -fn unify_any(e: &Loc, inferred_expr_types: &mut HashMap) -> Result { +fn unify_any(e: &Loc, inferred_expr_types: &mut HashMap) -> Result { let e_ty = TypeSet::Any(e.clone()); unify_ty(e, e_ty, inferred_expr_types) } -fn unify_boolean(e: &Loc, inferred_expr_types: &mut HashMap) -> Result { +fn unify_boolean(e: 
&Loc, inferred_expr_types: &mut HashMap) -> Result { let e_ty = TypeSet::BaseType(ValueType::Bool, e.clone()); unify_ty(e, e_ty, inferred_expr_types) } -fn unify_entity(e: &Loc, inferred_expr_types: &mut HashMap) -> Result { +fn unify_entity(e: &Loc, inferred_expr_types: &mut HashMap) -> Result { let e_ty = TypeSet::BaseType(ValueType::Entity, e.clone()); unify_ty(e, e_ty, inferred_expr_types) } + +/// Given a binding type of an aggregate and the concrete variables for the aggregate, check the variable types and +/// potentially ground the generic types if they present +fn ground_input_aggregate_binding_type( + kind: &str, + aggregate: &str, + aggregate_loc: &Loc, + binding_types: &BindingTypes, + variables: &Vec, + generic_type_families: &HashMap, + grounded_generic_types: &mut HashMap>, + inferred_expr_types: &mut HashMap, +) -> Result<(), Error> { + // First match on binding types + match binding_types { + BindingTypes::IfNotUnit { .. } => { + // Input binding types cannot have if-not-unit expression + Err(Error::error().msg(format!( + "error in aggregate `{aggregate}`: cannot have if-not-unit binding type in aggregate {kind}" + ))) + } + BindingTypes::TupleType(elems) => { + if elems.len() == 0 { + // If elems.len() is 0, it means that there should be no variable for this part of aggregation. + // We throw error if there is at least 1 variable. + // Otherwise, the type checking is done as there is no variable that needs to be unified for type + if variables.len() != 0 { + Err( + Error::error() + .msg(format!( + "unexpected {kind} variables in aggregate `{aggregate}`. 
Expected 0, found {}", + variables.len() + )) + .src(aggregate_loc.clone()), + ) + } else { + Ok(()) + } + } else if elems.len() == 1 { + // If elems.len() is 1, we could have that exact element to be a generic type variable or a concrete value type + match &elems[0] { + BindingType::Generic(g) => { + if let Some(grounded_type_sets) = grounded_generic_types.get(g) { + if grounded_type_sets.len() != variables.len() { + Err( + Error::error() + .msg( + format!( + "the generic type `{g}` in aggregate `{aggregate}` is grounded to have {} variables; however, it is unified with a set of {} variables:", + grounded_type_sets.len(), + variables.len() + ) + ) + .src(aggregate_loc.clone()) + ) + } else { + for (grounded_type_set, variable_loc) in grounded_type_sets.iter().zip(variables.iter()) { + unify_ty(variable_loc, grounded_type_set.clone(), inferred_expr_types).map_err(|e| e.into())?; + } + Ok(()) + } + } else if let Some(generic_type_family) = generic_type_families.get(g) { + let grounded_type_sets = solve_generic_type( + kind, + aggregate, + aggregate_loc, + g, + generic_type_family, + variables, + inferred_expr_types, + )?; + grounded_generic_types.insert(g.to_string(), grounded_type_sets); + Ok(()) + } else { + Err(Error::error().msg(format!( + "error processing aggregate `{aggregate}`: unknown generic type parameter `{g}`" + ))) + } + } + BindingType::ValueType(v) => { + if variables.len() == 1 { + unify_ty(&variables[0], TypeSet::base(v.clone()), inferred_expr_types).map_err(|e| e.into())?; + Ok(()) + } else { + // Arity mismatch + if variables.len() == 0 { + Err( + Error::error() + .msg(format!( + "expected exactly one {v} {kind} variable in aggregate `{aggregate}`, found {}", + variables.len() + )) + .src(aggregate_loc.clone()), + ) + } else { + Err( + Error::error() + .msg(format!( + "expected exactly one {v} {kind} variable in aggregate `{aggregate}`, found {}", + variables.len() + )) + .src(variables[1].clone()), + ) + } + } + } + } + } else { + if 
elems.iter().any(|e| e.is_generic()) { + Err(Error::error().msg(format!( + "error in aggregate `{aggregate}`: cannot have generic in the {kind} of aggregate of more than 1 elements" + ))) + } else if elems.len() != variables.len() { + Err( + Error::error() + .msg(format!( + "expected {} {kind} variables in aggregate `{aggregate}`, found {}", + elems.len(), + variables.len() + )) + .src(aggregate_loc.clone()), + ) + } else { + for (elem_binding_type, variable_loc) in elems.iter().zip(variables.iter()) { + let elem_value_type = elem_binding_type.as_value_type().unwrap(); // unwrap is ok since we have checked that no element is generic + unify_ty( + variable_loc, + TypeSet::base(elem_value_type.clone()), + inferred_expr_types, + ) + .map_err(|e| e.into())?; + } + Ok(()) + } + } + } + } +} + +fn solve_generic_type( + kind: &str, + aggregate: &str, + aggregate_loc: &Loc, + generic_type_name: &str, + generic_type_family: &GenericTypeFamily, + variables: &Vec, + inferred_expr_types: &mut HashMap, +) -> Result, Error> { + match generic_type_family { + GenericTypeFamily::NonEmptyTuple => { + if variables.len() == 0 { + Err( + Error::error() + .msg(format!( + "arity mismatch on aggregate `{aggregate}`. Expected non-empty {kind} variables, but found 0" + )) + .src(aggregate_loc.clone()), + ) + } else { + variables + .iter() + .map(|var_loc| unify_ty(var_loc, TypeSet::any(), inferred_expr_types).map_err(|e| e.into())) + .collect::, _>>() + } + } + GenericTypeFamily::NonEmptyTupleWithElements(elem_type_families) => { + if elem_type_families.iter().any(|tf| !tf.is_type_family()) { + Err(Error::error().msg(format!("error in aggregate `{aggregate}`: generic type family `{generic_type_name}` contains unsupported nested tuple"))) + } else if variables.len() != elem_type_families.len() { + Err( + Error::error() + .msg(format!( + "arity mismatch on aggregate `{aggregate}`. 
Expected {} {kind} variables, but found 0", + elem_type_families.len() + )) + .src(aggregate_loc.clone()), + ) + } else { + variables + .iter() + .zip(elem_type_families.iter()) + .map(|(var_loc, elem_type_family)| { + let type_family = elem_type_family.as_type_family().unwrap(); // unwrap is okay since we have checked that every elem is a base type family + unify_ty(var_loc, TypeSet::from(type_family.clone()), inferred_expr_types).map_err(|e| e.into()) + }) + .collect::, _>>() + } + } + GenericTypeFamily::UnitOr(child_generic_type_family) => { + if variables.len() == 0 { + Ok(vec![]) + } else { + solve_generic_type( + kind, + aggregate, + aggregate_loc, + generic_type_name, + child_generic_type_family, + variables, + inferred_expr_types, + ) + } + } + GenericTypeFamily::TypeFamily(tf) => { + if variables.len() != 1 { + Err( + Error::error() + .msg(format!( + "arity mismatch on aggregate `{aggregate}`. Expected 1 {kind} variables, but found 0" + )) + .src(aggregate_loc.clone()), + ) + } else { + let ts = unify_ty(&variables[0], TypeSet::from(tf.clone()), inferred_expr_types).map_err(|e| e.into())?; + Ok(vec![ts]) + } + } + } +} + +fn ground_output_aggregate_binding_type( + aggregate: &str, + aggregate_loc: &Loc, + binding_types: &BindingTypes, + variables: &Vec, + grounded_generic_types: &HashMap>, + inferred_expr_types: &mut HashMap, +) -> Result<(), Error> { + let expected_variable_types = solve_binding_types(aggregate, binding_types, grounded_generic_types)?; + if expected_variable_types.len() != variables.len() { + Err( + Error::error() + .msg(format!( + "in aggregate `{aggregate}`, {} output argument(s) is expected, found {}", + expected_variable_types.len(), + variables.len() + )) + .src(aggregate_loc.clone()), + ) + } else { + for (expected_var_type, variable_loc) in expected_variable_types.into_iter().zip(variables.iter()) { + unify_ty(variable_loc, expected_var_type, inferred_expr_types).map_err(|e| e.into())?; + } + Ok(()) + } +} + +fn 
solve_binding_types( + aggregate: &str, + binding_types: &BindingTypes, + grounded_generic_types: &HashMap>, +) -> Result, Error> { + match binding_types { + BindingTypes::IfNotUnit { + generic_type, + then_type, + else_type, + } => { + if let Some(type_sets) = grounded_generic_types.get(generic_type) { + if type_sets.len() > 0 { + solve_binding_types(aggregate, then_type, grounded_generic_types) + } else { + solve_binding_types(aggregate, else_type, grounded_generic_types) + } + } else { + Err(Error::error().msg(format!( + "error grounding output type of aggregate `{aggregate}`: unknown generic type `{generic_type}`" + ))) + } + } + BindingTypes::TupleType(elems) => Ok( + elems + .iter() + .map(|elem| match elem { + BindingType::Generic(g) => { + if let Some(type_sets) = grounded_generic_types.get(g) { + Ok(type_sets.clone()) + } else { + Err(Error::error().msg(format!( + "error grounding output type of aggregate `{aggregate}`: unknown generic type `{g}`" + ))) + } + } + BindingType::ValueType(v) => Ok(vec![TypeSet::base(v.clone())]), + }) + .collect::, _>>()? 
+ .into_iter() + .flat_map(|es| es) + .collect(), + ), + } +} diff --git a/core/src/compiler/front/ast/attr.rs b/core/src/compiler/front/ast/attr.rs index 05c5ca5..5d2650a 100644 --- a/core/src/compiler/front/ast/attr.rs +++ b/core/src/compiler/front/ast/attr.rs @@ -67,11 +67,17 @@ impl AttributeValue { } pub fn as_boolean(&self) -> Option { - self.as_constant().and_then(|c| c.as_boolean()).map(|b| b.value().clone()) + self + .as_constant() + .and_then(|c| c.as_boolean()) + .map(|b| b.value().clone()) } pub fn as_string(&self) -> Option { - self.as_constant().and_then(|c| c.as_string()).map(|b| b.string().clone()) + self + .as_constant() + .and_then(|c| c.as_string()) + .map(|b| b.string().clone()) } } @@ -114,7 +120,10 @@ impl Attribute { } pub fn num_pos_args(&self) -> usize { - self.iter_args().filter(|a| AttributeArg::is_pos(a)).fold(0, |acc, _| acc + 1) + self + .iter_args() + .filter(|a| AttributeArg::is_pos(a)) + .fold(0, |acc, _| acc + 1) } pub fn iter_pos_args(&self) -> impl Iterator { @@ -157,7 +166,10 @@ impl Attribute { } pub fn num_kw_args(&self) -> usize { - self.iter_args().filter(|a| AttributeArg::is_kw(a)).fold(0, |acc, _| acc + 1) + self + .iter_args() + .filter(|a| AttributeArg::is_kw(a)) + .fold(0, |acc, _| acc + 1) } pub fn iter_kw_args(&self) -> impl Iterator { diff --git a/core/src/compiler/front/ast/formula.rs b/core/src/compiler/front/ast/formula.rs index a31fccb..7f21cf9 100644 --- a/core/src/compiler/front/ast/formula.rs +++ b/core/src/compiler/front/ast/formula.rs @@ -18,9 +18,7 @@ pub enum Formula { impl Formula { pub fn negate(&self) -> Self { match self { - Self::Atom(a) => { - Self::NegAtom(NegAtom::new(a.clone())) - }, + Self::Atom(a) => Self::NegAtom(NegAtom::new(a.clone())), Self::NegAtom(n) => Self::Atom(n.atom().clone()), Self::Case(_) => { // TODO @@ -181,79 +179,10 @@ impl Reduce { #[derive(Clone, Debug, PartialEq, Serialize, AstNode)] #[doc(hidden)] -pub enum _ReduceOp { - Count(bool), - Sum, - Prod, - Min, - Max, - Exists, - 
Forall, - Unique, - TopK(usize), - CategoricalK(usize), - Unknown(String), -} - -impl ToString for _ReduceOp { - fn to_string(&self) -> String { - match self { - Self::Count(discrete) => if *discrete { - "count!".to_string() - } else { - "count".to_string() - }, - Self::Sum => "sum".to_string(), - Self::Prod => "prod".to_string(), - Self::Min => "min".to_string(), - Self::Max => "max".to_string(), - Self::Exists => "exists".to_string(), - Self::Forall => "forall".to_string(), - Self::Unique => "unique".to_string(), - Self::TopK(k) => format!("top<{}>", k), - Self::CategoricalK(k) => format!("categorical<{}>", k), - Self::Unknown(_) => "unknown".to_string(), - } - } -} - -impl ReduceOp { - pub fn output_arity(&self) -> Option { - match self.internal() { - _ReduceOp::Count(_) => Some(1), - _ReduceOp::Sum => Some(1), - _ReduceOp::Prod => Some(1), - _ReduceOp::Min => Some(1), - _ReduceOp::Max => Some(1), - _ReduceOp::Exists => Some(1), - _ReduceOp::Forall => Some(1), - _ReduceOp::Unique => None, - _ReduceOp::TopK(_) => None, - _ReduceOp::CategoricalK(_) => None, - _ReduceOp::Unknown(_) => None, - } - } - - pub fn num_bindings(&self) -> Option { - match self.internal() { - _ReduceOp::Count(_) => None, - _ReduceOp::Sum => Some(1), - _ReduceOp::Prod => Some(1), - _ReduceOp::Min => Some(1), - _ReduceOp::Max => Some(1), - _ReduceOp::Exists => None, - _ReduceOp::Forall => None, - _ReduceOp::Unique => None, - _ReduceOp::TopK(_) => None, - _ => None, - } - } -} - -impl ToString for ReduceOp { - fn to_string(&self) -> String { - self.internal().to_string() - } +pub struct _ReduceOp { + pub name: Identifier, + pub parameters: Vec, + pub has_exclaimation_mark: bool, } /// An syntax sugar for forall/exists operation, e.g. `forall(p: person(p) => father(p, _))`. 
diff --git a/core/src/compiler/front/ast/mod.rs b/core/src/compiler/front/ast/mod.rs index 14f3de9..c9cbec2 100644 --- a/core/src/compiler/front/ast/mod.rs +++ b/core/src/compiler/front/ast/mod.rs @@ -28,5 +28,5 @@ pub use type_decl::*; pub use types::*; pub use utils::*; -use serde::*; use astnode_derive::*; +use serde::*; diff --git a/core/src/compiler/front/ast/relation_decl.rs b/core/src/compiler/front/ast/relation_decl.rs index a54287e..b8dbcfe 100644 --- a/core/src/compiler/front/ast/relation_decl.rs +++ b/core/src/compiler/front/ast/relation_decl.rs @@ -95,7 +95,11 @@ pub struct _RuleDecl { impl RuleDecl { pub fn rule_tag_predicate(&self) -> String { if let Some(head_atom) = self.rule().head().as_atom() { - format!("rt#{}#{}", head_atom.predicate(), self.location_id().expect("location id has not been tagged yet")) + format!( + "rt#{}#{}", + head_atom.predicate(), + self.location_id().expect("location id has not been tagged yet") + ) } else { unimplemented!("Rule head is not an atom") } diff --git a/core/src/compiler/front/ast/rule.rs b/core/src/compiler/front/ast/rule.rs index 0517f92..1e38d56 100644 --- a/core/src/compiler/front/ast/rule.rs +++ b/core/src/compiler/front/ast/rule.rs @@ -9,17 +9,11 @@ pub struct _Rule { impl Into> for Rule { fn into(self) -> Vec { - vec![ - Item::RelationDecl( - RelationDecl::Rule( - RuleDecl::new( - Attributes::new(), - Tag::none(), - self, - ), - ), - ), - ] + vec![Item::RelationDecl(RelationDecl::Rule(RuleDecl::new( + Attributes::new(), + Tag::none(), + self, + )))] } } diff --git a/core/src/compiler/front/ast/type_decl.rs b/core/src/compiler/front/ast/type_decl.rs index 699e194..57eaf55 100644 --- a/core/src/compiler/front/ast/type_decl.rs +++ b/core/src/compiler/front/ast/type_decl.rs @@ -75,16 +75,10 @@ pub struct _RelationType { impl Into> for RelationType { fn into(self) -> Vec { - vec![ - Item::TypeDecl( - TypeDecl::Relation( - RelationTypeDecl::new( - Attributes::new(), - vec![self] - ), - ), - ), - ] + 
vec![Item::TypeDecl(TypeDecl::Relation(RelationTypeDecl::new( + Attributes::new(), + vec![self], + )))] } } diff --git a/core/src/compiler/front/ast/types.rs b/core/src/compiler/front/ast/types.rs index b80fa0c..9d1e7f6 100644 --- a/core/src/compiler/front/ast/types.rs +++ b/core/src/compiler/front/ast/types.rs @@ -1,4 +1,4 @@ -use serde::ser::{Serialize, Serializer, SerializeStruct}; +use serde::ser::{Serialize, SerializeStruct, Serializer}; use crate::common::value_type::*; @@ -36,7 +36,7 @@ pub enum _Type { impl Serialize for _Type { fn serialize(&self, serializer: S) -> Result where - S: Serializer, + S: Serializer, { // 3 is the number of fields in the struct. let mut state = serializer.serialize_struct("_Type", 2)?; diff --git a/core/src/compiler/front/ast/utils.rs b/core/src/compiler/front/ast/utils.rs index 7154cb6..ffe519b 100644 --- a/core/src/compiler/front/ast/utils.rs +++ b/core/src/compiler/front/ast/utils.rs @@ -1,7 +1,7 @@ use colored::*; -use super::*; use super::super::*; +use super::*; #[derive(Debug, Default, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize)] pub struct CharLocation { @@ -323,7 +323,10 @@ derive_ast_walker!(String); derive_ast_walker!(crate::common::input_tag::DynamicInputTag); derive_ast_walker!(crate::common::binary_op::BinaryOp); -impl AstWalker for Vec where T: AstWalker { +impl AstWalker for Vec +where + T: AstWalker, +{ fn walk(&self, v: &mut V) { for child in self { child.walk(v) @@ -337,7 +340,10 @@ impl AstWalker for Vec where T: AstWalker { } } -impl AstWalker for Option where T: AstWalker { +impl AstWalker for Option +where + T: AstWalker, +{ fn walk(&self, v: &mut V) { if let Some(n) = self { n.walk(v) @@ -351,7 +357,10 @@ impl AstWalker for Option where T: AstWalker { } } -impl AstWalker for Box where T: AstWalker { +impl AstWalker for Box +where + T: AstWalker, +{ fn walk(&self, v: &mut V) { (&**self).walk(v) } @@ -361,7 +370,10 @@ impl AstWalker for Box where T: AstWalker { } } -impl AstWalker for (A,) where A: 
AstWalker { +impl AstWalker for (A,) +where + A: AstWalker, +{ fn walk(&self, v: &mut V) { self.0.walk(v); } @@ -371,7 +383,11 @@ impl AstWalker for (A,) where A: AstWalker { } } -impl AstWalker for (A, B) where A: AstWalker, B: AstWalker { +impl AstWalker for (A, B) +where + A: AstWalker, + B: AstWalker, +{ fn walk(&self, v: &mut V) { self.0.walk(v); self.1.walk(v); @@ -383,7 +399,12 @@ impl AstWalker for (A, B) where A: AstWalker, B: AstWalker { } } -impl AstWalker for (A, B, C) where A: AstWalker, B: AstWalker, C: AstWalker { +impl AstWalker for (A, B, C) +where + A: AstWalker, + B: AstWalker, + C: AstWalker, +{ fn walk(&self, v: &mut V) { self.0.walk(v); self.1.walk(v); diff --git a/core/src/compiler/front/compile.rs b/core/src/compiler/front/compile.rs index a1b7063..212f7ff 100644 --- a/core/src/compiler/front/compile.rs +++ b/core/src/compiler/front/compile.rs @@ -7,6 +7,7 @@ use super::analyzers::*; use super::attribute::*; use super::*; +use crate::common::foreign_aggregate::*; use crate::common::foreign_function::*; use crate::common::foreign_predicate::*; use crate::common::tuple_type::*; @@ -33,6 +34,9 @@ pub struct FrontContext { /// Foreign predicate registry holding all foreign predicates pub foreign_predicate_registry: ForeignPredicateRegistry, + /// Foreign aggregate registry holding all foreign aggregates + pub foreign_aggregate_registry: AggregateRegistry, + /// Attribute processor registry holding all attribute processors pub attribute_processor_registry: AttributeProcessorRegistry, @@ -47,13 +51,15 @@ impl FrontContext { pub fn new() -> Self { let function_registry = ForeignFunctionRegistry::std(); let predicate_registry = ForeignPredicateRegistry::std(); + let aggregate_registry = AggregateRegistry::std(); let attribute_registry = AttributeProcessorRegistry::new(); - let analysis = Analysis::new(&function_registry, &predicate_registry); + let analysis = Analysis::new(&function_registry, &predicate_registry, &aggregate_registry); Self { 
sources: Sources::new(), items: Vec::new(), foreign_function_registry: function_registry, foreign_predicate_registry: predicate_registry, + foreign_aggregate_registry: aggregate_registry, attribute_processor_registry: attribute_registry, imported_files: HashSet::new(), node_id_annotator: NodeIdAnnotator::new(), @@ -357,7 +363,12 @@ impl FrontContext { } pub fn relations(&self) -> Vec { - self.type_inference().relations().into_iter().filter(|r| !self.is_hidden_relation(r)).collect() + self + .type_inference() + .relations() + .into_iter() + .filter(|r| !self.is_hidden_relation(r)) + .collect() } pub fn is_hidden_relation(&self, r: &str) -> bool { diff --git a/core/src/compiler/front/error.rs b/core/src/compiler/front/error.rs index 2167890..557cbdf 100644 --- a/core/src/compiler/front/error.rs +++ b/core/src/compiler/front/error.rs @@ -5,6 +5,7 @@ use crate::common::value_type::ValueParseError; use super::*; +#[derive(Clone, Debug)] pub enum FrontCompileErrorType { Warning, Error, @@ -126,6 +127,68 @@ impl FrontCompileError { } } +#[derive(Debug, Clone)] +pub struct FrontCompileErrorMessage { + pub error_type: FrontCompileErrorType, + pub parts: Vec, +} + +#[derive(Debug, Clone)] +pub enum FrontCompileErrorPart { + Message(String), + Source(NodeLocation), +} + +impl FrontCompileErrorMessage { + pub fn error() -> Self { + Self { + error_type: FrontCompileErrorType::Error, + parts: vec![], + } + } + + pub fn warning() -> Self { + Self { + error_type: FrontCompileErrorType::Warning, + parts: vec![], + } + } + + pub fn msg(mut self, s: S) -> Self { + self.parts.push(FrontCompileErrorPart::Message(s.to_string())); + self + } + + pub fn src(mut self, loc: NodeLocation) -> Self { + self.parts.push(FrontCompileErrorPart::Source(loc)); + self + } +} + +impl FrontCompileErrorTrait for FrontCompileErrorMessage { + fn error_type(&self) -> FrontCompileErrorType { + self.error_type.clone() + } + + fn report(&self, src: &Sources) -> String { + let mut whole_message = 
String::new(); + for (i, part) in self.parts.iter().enumerate() { + if i > 0 { + whole_message += "\n"; + } + match part { + FrontCompileErrorPart::Message(msg) => { + whole_message += msg; + } + FrontCompileErrorPart::Source(loc) => { + whole_message += &loc.report(src); + } + } + } + whole_message + } +} + pub trait FrontCompileErrorTrait: DynClone + std::fmt::Debug { /// Get the error type of this error (warning/error) fn error_type(&self) -> FrontCompileErrorType; diff --git a/core/src/compiler/front/f2b/f2b.rs b/core/src/compiler/front/f2b/f2b.rs index c999cb0..8d711f9 100644 --- a/core/src/compiler/front/f2b/f2b.rs +++ b/core/src/compiler/front/f2b/f2b.rs @@ -1,6 +1,5 @@ use std::collections::*; -use crate::common::aggregate_op::*; use crate::common::output_option::OutputOption; use crate::common::value_type::ValueType; use crate::compiler::back; @@ -10,7 +9,7 @@ use super::super::ast as front; use super::super::compile::*; use super::FlattenExprContext; -use front::{AstNode, NodeLocation, AstWalker}; +use front::{AstNode, AstWalker, NodeLocation}; impl FrontContext { pub fn to_back_program(&self) -> back::Program { @@ -47,6 +46,7 @@ impl FrontContext { rules, function_registry: self.foreign_function_registry.clone(), predicate_registry: self.foreign_predicate_registry.clone(), + aggregate_registry: self.foreign_aggregate_registry.clone(), adt_variant_registry: self.type_inference().create_adt_variant_registry(), } } @@ -469,56 +469,62 @@ impl FrontContext { .chain(to_agg_var_names.iter()) .cloned() .collect::>(); - let body_tys = self.type_inference().variable_types(src_rule_loc, body_args.iter()); - let body_terms = self.back_terms_with_types(body_args.clone(), body_tys.clone()); + let body_arg_tys = self.type_inference().variable_types(src_rule_loc, body_args.iter()); + let body_terms = self.back_terms_with_types(body_args.clone(), body_arg_tys.clone()); - // Get the reduce literal + // Get the variables for reduce literal let group_by_vars = 
self.back_vars(src_rule_loc, group_by_vars.into_iter().collect()); let other_group_by_vars = self.back_vars(src_rule_loc, other_group_by_vars); - let to_agg_vars = self.back_vars(src_rule_loc, to_agg_var_names.into_iter().collect()); - let left_vars = self.back_vars(src_rule_loc, agg_ctx.left_variable_names().into_iter().collect()); - let arg_vars = self.back_vars(src_rule_loc, arg_var_names.into_iter().collect()); + let to_agg_vars = self.back_vars(src_rule_loc, to_agg_var_names.clone()); + let left_vars = agg_ctx + .result_var_or_wildcards + .iter() + .map(|(loc, maybe_name)| { + if let Some(name) = maybe_name { + back::Variable::new(name.clone(), self.type_inference().variable_type(src_rule_loc, name)) + } else { + let name = format!("agg#wc#{}", loc.id.expect("[internal error] should have id")); + let ty = self + .type_inference() + .loc_value_type(loc) + .expect("[internal error] should have inferred type"); + back::Variable::new(name, ty) + } + }) + .collect(); + let arg_vars = self.back_vars(src_rule_loc, arg_var_names.iter().cloned().collect()); + + // Get the types of the variables for reduce literal + let left_var_types = self + .type_inference() + .loc_value_types(agg_ctx.result_var_or_wildcards.iter().map(|(l, _)| l)) + .expect("[internal error] detected result of reduce without inferred type"); + let arg_var_types = self.type_inference().variable_types(src_rule_loc, arg_var_names.iter()); + let to_agg_var_types = self + .type_inference() + .variable_types(src_rule_loc, to_agg_var_names.iter()); // Generate the internal aggregate operator - let op = match &agg_ctx.aggregate_op { - front::_ReduceOp::Count(discrete) => { - AggregateOp::Count { discrete: *discrete } - }, - front::_ReduceOp::Sum => { - assert_eq!( - left_vars.len(), - 1, - "[Internal Error] There should be only one var for summation" - ); - AggregateOp::Sum(left_vars[0].ty.clone()) - } - front::_ReduceOp::Prod => { - assert_eq!( - left_vars.len(), - 1, - "[Internal Error] There should be 
only one var for production" - ); - AggregateOp::Prod(left_vars[0].ty.clone()) - } - front::_ReduceOp::Min => AggregateOp::min(!arg_vars.is_empty()), - front::_ReduceOp::Max => AggregateOp::max(!arg_vars.is_empty()), - front::_ReduceOp::Exists => AggregateOp::Exists, - front::_ReduceOp::Unique => AggregateOp::top_k(1), - front::_ReduceOp::TopK(k) => AggregateOp::top_k(k.clone()), - front::_ReduceOp::CategoricalK(k) => AggregateOp::categorical_k(k.clone()), - front::_ReduceOp::Forall => { - panic!("[Internal Error] There should be no forall aggregator op. This is a bug"); - } - front::_ReduceOp::Unknown(_) => { - panic!("[Internal Error] There should be no unknown aggregator op. This is a bug"); - } - }; + let op = agg_ctx.aggregate_op.name.name().to_string(); + let params = agg_ctx + .aggregate_op + .parameters + .iter() + .map(|p| { + let value_type = self + .type_inference() + .expr_value_type(p) + .expect("[internal error] aggregate param not grounded with value type"); + p.to_value(&value_type) + }) + .collect::>(); + let has_exclamation_mark = agg_ctx.aggregate_op.has_exclaimation_mark; // Get the body to-aggregate relation let body_attr = back::AggregateBodyAttribute::new(op.clone(), group_by_vars.len(), arg_vars.len(), to_agg_vars.len()); let body_attrs = back::Attributes::singleton(body_attr); - let body_relation = back::Relation::new_with_attrs(body_attrs, body_predicate.clone(), body_tys.clone()); + let body_relation = back::Relation::new_with_attrs(body_attrs, body_predicate.clone(), body_arg_tys.clone()); temp_relations.push(body_relation); // Get the rules for body @@ -538,12 +544,21 @@ impl FrontContext { // Get the literal let body_atom = back::Atom::new(body_predicate.clone(), body_terms); let reduce_literal = back::Reduce::new( + // Aggregator op, + params, + has_exclamation_mark, + // types + left_var_types, + arg_var_types, + to_agg_var_types, + // Variables left_vars, group_by_vars, other_group_by_vars, arg_vars, to_agg_vars, + // Bodies 
body_atom, group_by_atom, ); @@ -560,11 +575,6 @@ impl FrontContext { .collect() } - // fn back_terms(&self, src_rule_loc: &NodeLocation, var_names: Vec) -> Vec { - // let var_tys = self.type_inference().variable_types(src_rule_loc, var_names.iter()); - // self.back_terms_with_types(var_names, var_tys) - // } - fn back_vars_with_types(&self, var_names: Vec, var_tys: Vec) -> Vec { var_names .into_iter() diff --git a/core/src/compiler/front/f2b/flatten_expr.rs b/core/src/compiler/front/f2b/flatten_expr.rs index 2421536..ea0b935 100644 --- a/core/src/compiler/front/f2b/flatten_expr.rs +++ b/core/src/compiler/front/f2b/flatten_expr.rs @@ -88,7 +88,10 @@ impl<'a> FlattenExprContext<'a> { } else if let Some(leaf) = self.leaf.get(loc) { leaf.clone() } else { - panic!("[Internal Error] Cannot find loc {:?} from the context, should not happen", loc) + panic!( + "[Internal Error] Cannot find loc {:?} from the context, should not happen", + loc + ) } } @@ -277,11 +280,7 @@ impl<'a> FlattenExprContext<'a> { let mut literals = vec![]; // First get the atom - let back_atom_args = neg_atom - .atom() - .iter_args() - .map(|a| self.get_expr_term(a)) - .collect(); + let back_atom_args = neg_atom.atom().iter_args().map(|a| self.get_expr_term(a)).collect(); let back_atom = back::NegAtom { atom: back::Atom { predicate: neg_atom.atom().formatted_predicate().clone(), @@ -335,7 +334,10 @@ impl<'a> FlattenExprContext<'a> { curr_literals } else { - panic!("[Internal Error] Cannot use `{}` for unary constraint", u.op().internal()); + panic!( + "[Internal Error] Cannot use `{}` for unary constraint", + u.op().internal() + ); } } diff --git a/core/src/compiler/front/grammar.lalrpop b/core/src/compiler/front/grammar.lalrpop index 17da069..8826815 100644 --- a/core/src/compiler/front/grammar.lalrpop +++ b/core/src/compiler/front/grammar.lalrpop @@ -495,40 +495,18 @@ VariableOrWildcard: VariableOrWildcard = { ReduceOp = Spanned<_ReduceOp>; _ReduceOp: _ReduceOp = { - "exists" => _ReduceOp::Exists, 
- "forall" => _ReduceOp::Forall, - => { - match n.name().as_str() { - "count" => _ReduceOp::Count(false), - "sum" => _ReduceOp::Sum, - "prod" => _ReduceOp::Prod, - "min" => _ReduceOp::Min, - "max" => _ReduceOp::Max, - "unique" => _ReduceOp::Unique, - x => _ReduceOp::Unknown(x.to_string()), - } + "exists" => { + _ReduceOp::new(Identifier::new_with_span("exists".to_string(), l, r), vec![], has_exclamation_mark.is_some()) }, - "!" => { - match n.name().as_str() { - "count" => _ReduceOp::Count(true), - x => _ReduceOp::Unknown(x.to_string()), - } + "forall" => { + _ReduceOp::new(Identifier::new_with_span("forall".to_string(), l, r), vec![], has_exclamation_mark.is_some()) + }, + => { + _ReduceOp::new(n, vec![], has_exclamation_mark.is_some()) + }, + "<" > ">" => { + _ReduceOp::new(n, cs, has_exclamation_mark.is_some()) }, - "<" ">" => { - match n.name().as_str() { - "top" => if k > 0 { - _ReduceOp::TopK(k as usize) - } else { - _ReduceOp::Unknown(format!("top<{}>", k)) - }, - "categorical" => if k > 0 { - _ReduceOp::CategoricalK(k as usize) - } else { - _ReduceOp::Unknown(format!("categorical<{}>", k)) - }, - x => _ReduceOp::Unknown(x.to_string()), - } - } } ReduceArgsFull: Vec = { @@ -567,8 +545,12 @@ _Reduce: _Reduce = { ForallExistsReduceOp = Spanned<_ForallExistsReduceOp>; _ForallExistsReduceOp: _ReduceOp = { - "exists" => _ReduceOp::Exists, - "forall" => _ReduceOp::Forall, + "exists" => { + _ReduceOp::new(Identifier::new_with_span("exists".to_string(), l, r), vec![], has_exclamation_mark.is_some()) + }, + "forall" => { + _ReduceOp::new(Identifier::new_with_span("forall".to_string(), l, r), vec![], has_exclamation_mark.is_some()) + }, } ForallExistsReduce = Spanned<_ForallExistsReduce>; diff --git a/core/src/compiler/front/pretty.rs b/core/src/compiler/front/pretty.rs index c5a9b4b..029d39c 100644 --- a/core/src/compiler/front/pretty.rs +++ b/core/src/compiler/front/pretty.rs @@ -116,9 +116,7 @@ impl Display for Query { impl Display for AttributeValue { fn 
fmt(&self, f: &mut Formatter<'_>) -> Result { match &self { - AttributeValue::Constant(c) => { - c.fmt(f) - }, + AttributeValue::Constant(c) => c.fmt(f), AttributeValue::List(l) => { f.write_str("[")?; for (i, v) in l.iter().enumerate() { @@ -128,7 +126,7 @@ impl Display for AttributeValue { v.fmt(f)?; } f.write_str("]") - }, + } AttributeValue::Tuple(l) => { f.write_str("(")?; for (i, v) in l.iter().enumerate() { @@ -138,7 +136,7 @@ impl Display for AttributeValue { v.fmt(f)?; } f.write_str(")") - }, + } } } } @@ -155,7 +153,7 @@ impl Display for Attribute { match arg { AttributeArg::Pos(p) => { f.write_fmt(format_args!("{}", p))?; - }, + } AttributeArg::Kw(kw) => { f.write_fmt(format_args!("{} = {}", kw.name(), kw.value()))?; } @@ -278,11 +276,7 @@ impl Display for AlgebraicDataTypeDecl { impl Display for AlgebraicDataTypeVariant { fn fmt(&self, f: &mut Formatter<'_>) -> Result { let name = self.constructor_name(); - let args = self - .iter_args() - .map(|t| format!("{t}")) - .collect::>() - .join(", "); + let args = self.iter_args().map(|t| format!("{t}")).collect::>().join(", "); f.write_fmt(format_args!("{name}({args})")) } } @@ -546,7 +540,11 @@ impl Display for Disjunction { fn fmt(&self, f: &mut Formatter<'_>) -> Result { f.write_fmt(format_args!( "({})", - self.iter_args().map(|a| format!("{}", a)).collect::>().join(" or ") + self + .iter_args() + .map(|a| format!("{}", a)) + .collect::>() + .join(" or ") )) } } @@ -555,7 +553,11 @@ impl Display for Conjunction { fn fmt(&self, f: &mut Formatter<'_>) -> Result { f.write_fmt(format_args!( "({})", - self.iter_args().map(|a| format!("{}", a)).collect::>().join(" and ") + self + .iter_args() + .map(|a| format!("{}", a)) + .collect::>() + .join(" and ") )) } } @@ -574,6 +576,7 @@ impl Display for Constraint { impl Display for Reduce { fn fmt(&self, f: &mut Formatter<'_>) -> Result { + println!("Printing reduce"); if self.left().len() > 1 { f.write_fmt(format_args!( "({})", @@ -587,7 +590,7 @@ impl Display for 
Reduce { } else { Display::fmt(self.left().iter().next().unwrap(), f)?; } - f.write_str(" = ")?; + f.write_str(" := ")?; self.operator().fmt(f)?; if !self.args().is_empty() { f.write_fmt(format_args!( @@ -634,7 +637,21 @@ impl Display for ForallExistsReduce { impl Display for ReduceOp { fn fmt(&self, f: &mut Formatter<'_>) -> Result { - f.write_str(&self.to_string()) + f.write_str(self.name().name())?; + if self.has_parameters() { + f.write_str("<")?; + for (i, param) in self.iter_parameters().enumerate() { + if i > 0 { + f.write_str(", ")?; + } + param.fmt(f)?; + } + f.write_str(">")?; + } + if *self.has_exclaimation_mark() { + f.write_str("!")?; + } + Ok(()) } } diff --git a/core/src/compiler/front/transformations/adt_to_relation.rs b/core/src/compiler/front/transformations/adt_to_relation.rs index e42a6b2..c859bb2 100644 --- a/core/src/compiler/front/transformations/adt_to_relation.rs +++ b/core/src/compiler/front/transformations/adt_to_relation.rs @@ -31,7 +31,10 @@ impl<'a> TransformAlgebraicDataType<'a> { .adt_variants .iter() .map(|(variant_name, variant_info)| { - let rel_name = variant_info.name.clone_without_location_id().map(|n| format!("adt#{n}")); + let rel_name = variant_info + .name + .clone_without_location_id() + .map(|n| format!("adt#{n}")); // Generate the args including the first ID type let first_arg = Type::named(variant_info.belongs_to_type.name().to_string()); @@ -57,24 +60,17 @@ impl<'a> TransformAlgebraicDataType<'a> { .collect(); let adt_attr = Attribute::new( Identifier::new("adt".to_string()), - vec![ - AttributeValue::string(variant_name.clone()).into(), - is_entity.into(), - ], + vec![AttributeValue::string(variant_name.clone()).into(), is_entity.into()], ); // Generate another attribute `@hidden` let hidden_attr = Attribute::new(Identifier::new("hidden".to_string()), vec![]); // Generate a type declaration item - Item::TypeDecl( - TypeDecl::Relation( - RelationTypeDecl::new( - vec![adt_attr, hidden_attr], - 
vec![RelationType::new(rel_name, arg_types)], - ), - ), - ) + Item::TypeDecl(TypeDecl::Relation(RelationTypeDecl::new( + vec![adt_attr, hidden_attr], + vec![RelationType::new(rel_name, arg_types)], + ))) }) .collect(); diff --git a/core/src/compiler/front/transformations/conjunctive_head.rs b/core/src/compiler/front/transformations/conjunctive_head.rs index 2ab2471..897bfcd 100644 --- a/core/src/compiler/front/transformations/conjunctive_head.rs +++ b/core/src/compiler/front/transformations/conjunctive_head.rs @@ -33,21 +33,17 @@ impl NodeVisitor for TransformConjunctiveHead { match rule.head() { RuleHead::Conjunction(c) => { for atom in c.iter_atoms() { - self.to_add_items.push( - Item::RelationDecl( - RelationDecl::Rule( - RuleDecl::new( - Attributes::new(), - Tag::none(), - Rule::new_with_loc( - RuleHead::atom(atom.clone()), - rule.body().clone(), - rule.location().clone_without_id(), - ), - ) + self + .to_add_items + .push(Item::RelationDecl(RelationDecl::Rule(RuleDecl::new( + Attributes::new(), + Tag::none(), + Rule::new_with_loc( + RuleHead::atom(atom.clone()), + rule.body().clone(), + rule.location().clone_without_id(), ), - ), - ); + )))); } } _ => {} diff --git a/core/src/compiler/front/transformations/const_var_to_const.rs b/core/src/compiler/front/transformations/const_var_to_const.rs index dd9474e..2182ccf 100644 --- a/core/src/compiler/front/transformations/const_var_to_const.rs +++ b/core/src/compiler/front/transformations/const_var_to_const.rs @@ -18,31 +18,27 @@ impl<'a> TransformConstVarToConst<'a> { .entity_facts .iter() .map(|entity_fact| { - Item::RelationDecl( - RelationDecl::Fact( - FactDecl::new( - Attributes::new(), - Tag::new(DynamicInputTag::None), - Atom::new_with_loc( - { - entity_fact - .functor - .clone_without_location_id() - .map(|n| format!("adt#{n}")) - }, - vec![], - { - std::iter::once(&entity_fact.id) - .chain(entity_fact.args.iter()) - .cloned() - .map(Expr::Constant) - .collect() - }, - entity_fact.loc.clone(), - ), - ), + 
Item::RelationDecl(RelationDecl::Fact(FactDecl::new( + Attributes::new(), + Tag::new(DynamicInputTag::None), + Atom::new_with_loc( + { + entity_fact + .functor + .clone_without_location_id() + .map(|n| format!("adt#{n}")) + }, + vec![], + { + std::iter::once(&entity_fact.id) + .chain(entity_fact.args.iter()) + .cloned() + .map(Expr::Constant) + .collect() + }, + entity_fact.loc.clone(), ), - ) + ))) }) .collect() } diff --git a/core/src/compiler/front/transformations/desugar_arg_type_anno.rs b/core/src/compiler/front/transformations/desugar_arg_type_anno.rs index 9ba80c4..2dbd830 100644 --- a/core/src/compiler/front/transformations/desugar_arg_type_anno.rs +++ b/core/src/compiler/front/transformations/desugar_arg_type_anno.rs @@ -21,7 +21,11 @@ impl DesugarArgTypeAdornment { pub fn retain_relation(&mut self, rel_type: &RelationType, existing_attrs: &Vec) -> bool { if rel_type.has_adornment() { let demand_attr = Self::generate_demand_attribute(rel_type); - let attrs: Vec<_> = existing_attrs.iter().cloned().chain(std::iter::once(demand_attr)).collect(); + let attrs: Vec<_> = existing_attrs + .iter() + .cloned() + .chain(std::iter::once(demand_attr)) + .collect(); let item = Item::TypeDecl(TypeDecl::Relation(RelationTypeDecl::new(attrs, vec![rel_type.clone()]))); self.new_items.push(item); false diff --git a/core/src/compiler/front/transformations/desugar_case_is.rs b/core/src/compiler/front/transformations/desugar_case_is.rs index af3f153..132b634 100644 --- a/core/src/compiler/front/transformations/desugar_case_is.rs +++ b/core/src/compiler/front/transformations/desugar_case_is.rs @@ -13,22 +13,20 @@ impl DesugarCaseIs { match &case.entity() { Entity::Expr(e) => { // If the entity is directly an expression, the formula is a constraint - Formula::Constraint( - Constraint::new_with_loc( - Expr::binary( - BinaryExpr::new( - BinaryOp::new_eq(), - Expr::Variable(case.variable().clone()), - e.clone(), - ), - ), - case.location().clone() - ), - ) + 
Formula::Constraint(Constraint::new_with_loc( + Expr::binary(BinaryExpr::new( + BinaryOp::new_eq(), + Expr::Variable(case.variable().clone()), + e.clone(), + )), + case.location().clone(), + )) } Entity::Object(o) => { // If the entity is an object, the formula is a conjunction of atoms - let parent_id = case.location_id().expect("Case location id is not populated prior to desugar case is transformation"); + let parent_id = case + .location_id() + .expect("Case location id is not populated prior to desugar case is transformation"); let variable = case.variable().clone(); let mut variable_counter = IdAllocator::new(); let mut formulas = vec![]; @@ -85,9 +83,7 @@ impl DesugarCaseIs { impl NodeVisitor for DesugarCaseIs { fn visit_mut(&mut self, formula: &mut Formula) { match formula { - Formula::Case(c) => { - *formula = self.transform_case_is_to_formula(c) - }, + Formula::Case(c) => *formula = self.transform_case_is_to_formula(c), _ => {} } } diff --git a/core/src/compiler/front/transformations/desugar_forall_exists.rs b/core/src/compiler/front/transformations/desugar_forall_exists.rs index 57aa43d..6c7af44 100644 --- a/core/src/compiler/front/transformations/desugar_forall_exists.rs +++ b/core/src/compiler/front/transformations/desugar_forall_exists.rs @@ -47,15 +47,11 @@ impl NodeVisitor for DesugarForallExists { let reduce_formula = Formula::Reduce(reduce); // Create the constraint formula - let constraint = Constraint::new( - Expr::binary( - BinaryExpr::new( - BinaryOp::new_eq(), - Expr::variable(boolean_var.clone()), - Expr::constant(Constant::boolean(BoolLiteral::new(goal))), - ) - ) - ); + let constraint = Constraint::new(Expr::binary(BinaryExpr::new( + BinaryOp::new_eq(), + Expr::variable(boolean_var.clone()), + Expr::constant(Constant::boolean(BoolLiteral::new(goal))), + ))); let constraint_formula = Formula::Constraint(constraint); // Create the conjunction of the two diff --git a/core/src/compiler/front/transformations/desugar_range.rs 
b/core/src/compiler/front/transformations/desugar_range.rs index 75dd9ef..7c80222 100644 --- a/core/src/compiler/front/transformations/desugar_range.rs +++ b/core/src/compiler/front/transformations/desugar_range.rs @@ -32,20 +32,14 @@ impl NodeVisitor for DesugarRange { match r.inclusive() { true => vec![ r.lower().clone(), - Expr::binary( - BinaryExpr::new( - BinaryOp::new_add(), - r.upper().clone(), - Expr::Constant(Constant::integer(IntLiteral::new(1))), - ) - ), - Expr::Variable(r.left().clone()), - ], - false => vec![ - r.lower().clone(), - r.upper().clone(), + Expr::binary(BinaryExpr::new( + BinaryOp::new_add(), + r.upper().clone(), + Expr::Constant(Constant::integer(IntLiteral::new(1))), + )), Expr::Variable(r.left().clone()), ], + false => vec![r.lower().clone(), r.upper().clone(), Expr::Variable(r.left().clone())], }, r.location().clone(), ); diff --git a/core/src/compiler/front/transformations/forall_to_not_exists.rs b/core/src/compiler/front/transformations/forall_to_not_exists.rs index 90be3b9..e06536c 100644 --- a/core/src/compiler/front/transformations/forall_to_not_exists.rs +++ b/core/src/compiler/front/transformations/forall_to_not_exists.rs @@ -23,12 +23,9 @@ impl TransformForall { fn transform_forall(&mut self, r: &Reduce) -> Option { // First check if this reduce is a forall aggregation - if r.operator().is_forall() && r.left().len() == 1 { - // Get the left variable - let left_var = r.left()[0].clone(); - - // If the left variable is a wildcard, discard this transformation - if let VariableOrWildcard::Variable(left_var) = left_var { + if r.operator().name().name() == "forall" && r.num_left() == 1 { + // Unwrap ok since num_left is 1 + if let VariableOrWildcard::Variable(left_var) = &r.left().get(0).unwrap() { // Do the transformation match r.body() { Formula::Implies(i) => { @@ -37,7 +34,8 @@ impl TransformForall { let temp_var = Variable::new(Identifier::new(temp_var_name)); let not_temp_var = Expr::unary(UnaryExpr::new(UnaryOp::not(), 
Expr::Variable(temp_var.clone()))); let left_var_expr = Expr::Variable(left_var.clone()); - let left_var_eq_not_temp_var = Expr::binary(BinaryExpr::new(BinaryOp::new_eq(), left_var_expr, not_temp_var)); + let left_var_eq_not_temp_var = + Expr::binary(BinaryExpr::new(BinaryOp::new_eq(), left_var_expr, not_temp_var)); let constraint = Constraint::new(left_var_eq_not_temp_var); // Create exists aggregation literal @@ -46,24 +44,25 @@ impl TransformForall { i.location().clone_without_id(), )); let reduce = Reduce::new_with_loc( - vec![ - VariableOrWildcard::Variable(temp_var) - ], // Left - ReduceOp::exists_with_loc(r.operator().location().clone_without_id()), // Reduce op - r.args().clone(), // args - r.bindings().clone(), // bindings + vec![VariableOrWildcard::Variable(temp_var)], // Left + ReduceOp::new_with_loc( + Identifier::new_with_loc("exists".to_string(), r.operator().name().location().clone_without_id()), + vec![], + r.operator().has_exclaimation_mark().clone(), + r.operator().location().clone_without_id(), + ), // Reduce op + r.args().clone(), // args + r.bindings().clone(), // bindings left_and_not_right, r.group_by().clone(), i.location().clone_without_id(), ); // Conjunction of both - let result = Formula::Conjunction( - Conjunction::new_with_loc( - vec![Formula::Constraint(constraint), Formula::Reduce(reduce)], - r.location().clone_without_id(), - ), - ); + let result = Formula::Conjunction(Conjunction::new_with_loc( + vec![Formula::Constraint(constraint), Formula::Reduce(reduce)], + r.location().clone_without_id(), + )); Some(result) } _ => None, diff --git a/core/src/compiler/front/transformations/implies_to_disjunction.rs b/core/src/compiler/front/transformations/implies_to_disjunction.rs index ab96ab9..abdbffa 100644 --- a/core/src/compiler/front/transformations/implies_to_disjunction.rs +++ b/core/src/compiler/front/transformations/implies_to_disjunction.rs @@ -8,12 +8,10 @@ impl NodeVisitor for TransformImplies { fn visit_mut(&mut self, formula: 
&mut Formula) { match formula { Formula::Implies(i) => { - let rewrite = Formula::Disjunction( - Disjunction::new_with_loc( - vec![i.left().negate(), i.right().clone()], - i.location().clone(), - ) - ); + let rewrite = Formula::Disjunction(Disjunction::new_with_loc( + vec![i.left().negate(), i.right().clone()], + i.location().clone(), + )); *formula = rewrite; } _ => {} diff --git a/core/src/compiler/front/transformations/tagged_rule.rs b/core/src/compiler/front/transformations/tagged_rule.rs index 7f5f5bb..ca76b12 100644 --- a/core/src/compiler/front/transformations/tagged_rule.rs +++ b/core/src/compiler/front/transformations/tagged_rule.rs @@ -55,9 +55,7 @@ impl NodeVisitor for TransformTaggedRule { let pred = Self::transform(rule_decl); // Store this probability for later - self - .to_add_tags - .push((pred.clone(), rule_decl.tag().tag().clone())); + self.to_add_tags.push((pred.clone(), rule_decl.tag().tag().clone())); } else if Self::has_prob_attr(rule_decl) { // If the rule is annotated with `@probabilistic` Self::transform(rule_decl); diff --git a/core/src/compiler/ram/ast.rs b/core/src/compiler/ram/ast.rs index 91b6f9e..c056f9d 100644 --- a/core/src/compiler/ram/ast.rs +++ b/core/src/compiler/ram/ast.rs @@ -1,8 +1,8 @@ use std::collections::*; -use crate::common::adt_variant_registry::ADTVariantRegistry; -use crate::common::aggregate_op::AggregateOp; +use crate::common::adt_variant_registry::*; use crate::common::expr::*; +use crate::common::foreign_aggregate::*; use crate::common::foreign_function::*; use crate::common::foreign_predicate::*; use crate::common::input_file::InputFile; @@ -11,12 +11,14 @@ use crate::common::output_option::OutputOption; use crate::common::tuple::{AsTuple, Tuple}; use crate::common::tuple_type::TupleType; use crate::common::value::Value; +use crate::common::value_type::ValueType; #[derive(Debug, Clone)] pub struct Program { pub strata: Vec, pub function_registry: ForeignFunctionRegistry, pub predicate_registry: 
ForeignPredicateRegistry, + pub aggregate_registry: AggregateRegistry, pub adt_variant_registry: ADTVariantRegistry, pub relation_to_stratum: HashMap, } @@ -27,6 +29,7 @@ impl Program { strata: Vec::new(), function_registry: ForeignFunctionRegistry::new(), predicate_registry: ForeignPredicateRegistry::new(), + aggregate_registry: AggregateRegistry::new(), adt_variant_registry: ADTVariantRegistry::new(), relation_to_stratum: HashMap::new(), } @@ -297,9 +300,21 @@ impl Dataflow { Self::ForeignPredicateJoin(Box::new(self), predicate, args) } - pub fn reduce(op: AggregateOp, predicate: S, group_by: ReduceGroupByType) -> Self { + pub fn reduce( + aggregator: String, + params: Vec, + has_exclamation_mark: bool, + arg_var_types: Vec, + input_var_types: Vec, + predicate: S, + group_by: ReduceGroupByType, + ) -> Self { Self::Reduce(Reduce { - op, + aggregator, + params, + has_exclamation_mark, + arg_var_types, + input_var_types, predicate: predicate.to_string(), group_by, }) @@ -355,7 +370,11 @@ impl ReduceGroupByType { #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] pub struct Reduce { - pub op: AggregateOp, + pub aggregator: String, + pub params: Vec, + pub has_exclamation_mark: bool, + pub arg_var_types: Vec, + pub input_var_types: Vec, pub predicate: String, pub group_by: ReduceGroupByType, } diff --git a/core/src/compiler/ram/pretty.rs b/core/src/compiler/ram/pretty.rs index dc3217b..3bfe0c6 100644 --- a/core/src/compiler/ram/pretty.rs +++ b/core/src/compiler/ram/pretty.rs @@ -138,9 +138,22 @@ impl Dataflow { ReduceGroupByType::Implicit => format!(" implicit group"), _ => format!(""), }; + let params = if !r.params.is_empty() { + format!( + "<{}>", + r.params.iter().map(|p| p.to_string()).collect::>().join(", ") + ) + } else { + format!("") + }; + let exclamation_mark = if r.has_exclamation_mark { + format!("!") + } else { + format!("") + }; f.write_fmt(format_args!( - "Aggregation {}({}{})", - r.op, r.predicate, group_by_predicate + "Aggregation {}{}{}({}{})", 
+ r.aggregator, params, exclamation_mark, r.predicate, group_by_predicate )) } diff --git a/core/src/compiler/ram/ram2rs.rs b/core/src/compiler/ram/ram2rs.rs index dd03a56..814dd0c 100644 --- a/core/src/compiler/ram/ram2rs.rs +++ b/core/src/compiler/ram/ram2rs.rs @@ -3,7 +3,6 @@ use quote::{format_ident, quote}; use std::collections::*; use super::*; -use crate::common::aggregate_op::*; use crate::common::binary_op::*; use crate::common::expr::*; use crate::common::input_tag::*; @@ -66,7 +65,7 @@ impl ast::Program { quote! { &#dep_strat_name, } }) .collect::>(); - quote! { let #curr_strat_name = #curr_strat_run_name(ctx, &mut edb, #(#args)*); } + quote! { let #curr_strat_name = #curr_strat_run_name(&rt, ctx, &mut edb, #(#args)*); } }) .collect::>(); @@ -95,8 +94,8 @@ impl ast::Program { run_with_edb(ctx, ExtensionalDatabase::new()) } pub fn run_with_edb(ctx: &mut C, mut edb: ExtensionalDatabase) -> OutputRelations { - let runtime_env = RuntimeEnvironment::default(); - edb.internalize(&runtime_env, ctx); + let rt = RuntimeEnvironment::default(); + edb.internalize(&rt, ctx); #(#exec_strata)* #output_relations } @@ -235,6 +234,21 @@ impl ast::Stratum { .map(|update| update.to_rs_insert(id, rel_to_strat_map)) .collect::>(); + // 2.5. complete statements + let complete_stmts = self + .relations + .iter() + .map(|(_, r)| { + let field_name = relation_name_to_rs_field_name(&r.predicate); + quote! { + let #field_name = iter.complete(&#field_name); + } + // let relation_name = r.predicate.clone(); + // quote! { println!("{}: {:?}", #relation_name, #field_name); } + }) + .collect::>(); + let complete_stmts = quote! { #(#complete_stmts)* }; + // 3. Ensemble final result let ensemble_result_fields = self .relations @@ -242,7 +256,7 @@ impl ast::Stratum { .filter_map(|(_, r)| { if r.output.is_not_hidden() || inters_dep.contains(&r.predicate) { let rs_name = relation_name_to_rs_field_name(&r.predicate); - Some(quote! { #rs_name: iter.complete(&#rs_name), }) + Some(quote! 
{ #rs_name, }) } else { None } @@ -252,13 +266,14 @@ impl ast::Stratum { // Final function quote! { - fn #fn_name(ctx: &mut C, edb: &mut ExtensionalDatabase, #(#args)*) -> #ret_ty { - let mut iter = StaticIteration::::new(ctx); + fn #fn_name(rt: &RuntimeEnvironment, ctx: &mut C, edb: &mut ExtensionalDatabase, #(#args)*) -> #ret_ty { + let mut iter = StaticIteration::::new(rt, ctx); #(#create_relation_stmts)* while iter.changed() || iter.is_first_iteration() { #(#updates)* iter.step(); } + #complete_stmts #ensemble_result } } @@ -366,21 +381,37 @@ impl ast::Dataflow { let to_agg_col = get_col(&r.predicate); // Get the aggregator - let agg = match &r.op { - AggregateOp::Count { discrete } => if *discrete { - unimplemented! {} - } else { - quote! { CountAggregator::new() } - }, - AggregateOp::Sum(_) => quote! { SumAggregator::new() }, - AggregateOp::Prod(_) => quote! { ProdAggregator::new() }, - AggregateOp::Max => quote! { MaxAggregator::new() }, - AggregateOp::Min => quote! { MinAggregator::new() }, - AggregateOp::Argmax => quote! { ArgmaxAggregator::new() }, - AggregateOp::Argmin => quote! { ArgminAggregator::new() }, - AggregateOp::Exists => quote! { ExistsAggregator::new() }, - AggregateOp::TopK(k) => quote! { TopKAggregator::new(#k) }, - AggregateOp::CategoricalK(_) => unimplemented! {}, + let non_multi_world = r.has_exclamation_mark; + let agg = match r.aggregator.as_str() { + "count" => quote! { CountAggregator::new(#non_multi_world) }, + "sum" => quote! { SumAggregator::new() }, + "prod" => quote! { ProdAggregator::new() }, + "max" => { + if r.arg_var_types.is_empty() { + quote! { MaxAggregator::new() } + } else { + unimplemented!("Implicit max with argument") + } + } + "min" => { + if r.arg_var_types.is_empty() { + quote! { MinAggregator::new() } + } else { + unimplemented!("Implicit min with argument") + } + } + "argmax" => { + quote! { ArgmaxAggregator::new() } + } + "argmin" => { + quote! { ArgminAggregator::new() } + } + "exists" => quote!
{ ExistsAggregator::new(#non_multi_world) }, + "top" => { + let k = r.params[0].as_usize(); + quote! { TopKAggregator::new(#k) } + } + _ => unimplemented!(), }; // Get the dataflow diff --git a/core/src/integrate/context.rs b/core/src/integrate/context.rs index 0a91a30..073b97c 100644 --- a/core/src/integrate/context.rs +++ b/core/src/integrate/context.rs @@ -27,6 +27,9 @@ pub struct IntegrateContext { /// This is for incremental compilation front_ctx: compiler::front::FrontContext, + /// The monitors + monitors: DynamicMonitors, + /// Flag denoting whether the Front-IR has changed or not; initialized to not changed. /// Once the front is compiled and stayed unchanged, no further analysis will be performed /// on the front compilation context @@ -50,6 +53,7 @@ impl IntegrateContext { ..Default::default() }), }, + monitors: DynamicMonitors::new(), front_has_changed: false, } } @@ -58,6 +62,7 @@ impl IntegrateContext { Self { options: compiler::CompileOptions::default(), front_ctx: compiler::front::FrontContext::new(), + monitors: DynamicMonitors::new(), internal: InternalIntegrateContext { prov_ctx, runtime_env: RuntimeEnvironment::default(), @@ -76,6 +81,7 @@ impl IntegrateContext { Self { options: options.compiler_options, front_ctx: compiler::front::FrontContext::new(), + monitors: DynamicMonitors::new(), internal: InternalIntegrateContext { prov_ctx, runtime_env: RuntimeEnvironment::default(), @@ -96,6 +102,7 @@ impl IntegrateContext { IntegrateContext { options: self.options.clone(), front_ctx: self.front_ctx.clone(), + monitors: DynamicMonitors::new(), internal: InternalIntegrateContext { prov_ctx: new_prov, runtime_env: self.internal.runtime_env.clone(), @@ -135,6 +142,12 @@ impl IntegrateContext { Ok(()) } + pub fn add_monitors(&mut self, monitors: &[&str]) { + let reg = MonitorRegistry::::std(); + let monitors = reg.load_monitors(&monitors); + self.monitors.extend(monitors) + } + /// Add a program string pub fn add_program(&mut self, program: &str) -> 
Result<(), IntegrateError> { let source = compiler::front::StringSource::new(program.to_string()); @@ -479,7 +492,11 @@ impl IntegrateContext { self.compile()?; // Finally execute the ram - self.internal.run() + if self.monitors.is_empty() { + self.internal.run() + } else { + self.internal.run_with_monitor(&self.monitors) + } } /// Get the relation type @@ -524,7 +541,11 @@ impl IntegrateContext { /// Get the relation output collection of a given relation pub fn computed_relation(&mut self, relation: &str) -> Option>> { - self.internal.computed_relation(relation) + if self.monitors.is_empty() { + self.internal.computed_relation(relation) + } else { + self.internal.computed_relation_with_monitor(relation, &self.monitors) + } } /// Get the relation output collection of a given relation @@ -671,6 +692,17 @@ impl InternalIntegrateContext { self.exec_ctx.relation_ref(relation) } + pub fn computed_relation_ref_with_monitor>( + &mut self, + relation: &str, + m: &M, + ) -> Option<&dynamic::DynamicOutputCollection> { + self + .exec_ctx + .recover_with_monitor(relation, &self.runtime_env, &self.prov_ctx, m); + self.exec_ctx.relation_ref(relation) + } + /// Get the RC'ed output collection of a given relation pub fn computed_relation(&mut self, relation: &str) -> Option>> { self.exec_ctx.recover(relation, &self.runtime_env, &self.prov_ctx); diff --git a/core/src/integrate/interpret.rs b/core/src/integrate/interpret.rs index 33c4258..d5ee1ae 100644 --- a/core/src/integrate/interpret.rs +++ b/core/src/integrate/interpret.rs @@ -126,11 +126,15 @@ impl InterpretContext {} OutputOption::Default => { // Recover + m.observe_recovering_relation(predicate); relation.recover_with_monitor(&self.runtime_env, &self.provenance, m, true); + m.observe_finish_recovering_relation(predicate); } OutputOption::File(f) => { // Recover and export the file + m.observe_recovering_relation(predicate); relation.recover_with_monitor(&self.runtime_env, &self.provenance, m, true); + 
m.observe_finish_recovering_relation(predicate); database::io::store_file(f, relation).map_err(IntegrateError::io)?; } } diff --git a/core/src/runtime/database/extensional/database.rs b/core/src/runtime/database/extensional/database.rs index 824a148..5dac44f 100644 --- a/core/src/runtime/database/extensional/database.rs +++ b/core/src/runtime/database/extensional/database.rs @@ -261,7 +261,9 @@ impl ExtensionalDatabase { } pub fn internalize_with_monitor>(&mut self, env: &RuntimeEnvironment, ctx: &Prov, m: &M) { - for (_, relation) in &mut self.extensional_relations { + for (name, relation) in &mut self.extensional_relations { + m.observe_loading_relation(name); + m.observe_loading_relation_from_edb(name); relation.internalize_with_monitor(env, ctx, m); } self.internalized = true diff --git a/core/src/runtime/database/intentional/database.rs b/core/src/runtime/database/intentional/database.rs index c66e0ad..e9cab6d 100644 --- a/core/src/runtime/database/intentional/database.rs +++ b/core/src/runtime/database/intentional/database.rs @@ -118,6 +118,7 @@ impl IntentionalDatabase { // !SPECIAL MONITORING! 
m.observe_recovering_relation(relation); r.recover_with_monitor(env, ctx, m, drain); + m.observe_finish_recovering_relation(relation); } } diff --git a/core/src/runtime/dynamic/aggregator/aggregator.rs b/core/src/runtime/dynamic/aggregator/aggregator.rs deleted file mode 100644 index 27881da..0000000 --- a/core/src/runtime/dynamic/aggregator/aggregator.rs +++ /dev/null @@ -1,113 +0,0 @@ -use crate::common::aggregate_op::AggregateOp; -use crate::common::value_type::*; -use crate::runtime::env::*; -use crate::runtime::provenance::*; - -use super::*; - -#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] -pub enum DynamicAggregator { - Count(DynamicCount), - Sum(DynamicSum), - Prod(DynamicProd), - Min(DynamicMin), - Max(DynamicMax), - Argmin(DynamicArgmin), - Argmax(DynamicArgmax), - Exists(DynamicExists), - TopK(DynamicTopK), - CategoricalK(DynamicCategoricalK), -} - -impl From for DynamicAggregator { - fn from(o: AggregateOp) -> Self { - match o { - AggregateOp::Count { discrete } => Self::count(discrete), - AggregateOp::Sum(t) => Self::sum(t), - AggregateOp::Prod(t) => Self::prod(t), - AggregateOp::Min => Self::min(), - AggregateOp::Max => Self::max(), - AggregateOp::Argmin => Self::argmin(), - AggregateOp::Argmax => Self::argmax(), - AggregateOp::Exists => Self::exists(), - AggregateOp::TopK(k) => Self::top_k(k), - AggregateOp::CategoricalK(k) => Self::categorical_k(k), - } - } -} - -impl DynamicAggregator { - pub fn count(discrete: bool) -> Self { - Self::Count(DynamicCount { discrete }) - } - - pub fn sum(ty: ValueType) -> Self { - Self::Sum(DynamicSum(ty)) - } - - pub fn sum_with_ty() -> Self - where - ValueType: FromType, - { - Self::Sum(DynamicSum(>::from_type())) - } - - pub fn prod(ty: ValueType) -> Self { - Self::Prod(DynamicProd(ty)) - } - - pub fn prod_with_ty() -> Self - where - ValueType: FromType, - { - Self::Prod(DynamicProd(>::from_type())) - } - - pub fn min() -> Self { - Self::Min(DynamicMin) - } - - pub fn max() -> Self { - 
Self::Max(DynamicMax) - } - - pub fn argmin() -> Self { - Self::Argmin(DynamicArgmin) - } - - pub fn argmax() -> Self { - Self::Argmax(DynamicArgmax) - } - - pub fn exists() -> Self { - Self::Exists(DynamicExists) - } - - pub fn top_k(k: usize) -> Self { - Self::TopK(DynamicTopK(k)) - } - - pub fn categorical_k(k: usize) -> Self { - Self::CategoricalK(DynamicCategoricalK(k)) - } - - pub fn aggregate( - &self, - batch: DynamicElements, - ctx: &Prov, - rt: &RuntimeEnvironment, - ) -> DynamicElements { - match self { - Self::Count(c) => c.aggregate(batch, ctx), - Self::Sum(s) => s.aggregate(batch, ctx), - Self::Prod(p) => p.aggregate(batch, ctx), - Self::Min(m) => m.aggregate(batch, ctx), - Self::Max(m) => m.aggregate(batch, ctx), - Self::Argmin(m) => m.aggregate(batch, ctx), - Self::Argmax(m) => m.aggregate(batch, ctx), - Self::Exists(e) => e.aggregate(batch, ctx), - Self::TopK(t) => t.aggregate(batch, ctx), - Self::CategoricalK(c) => c.aggregate(batch, ctx, rt), - } - } -} diff --git a/core/src/runtime/dynamic/aggregator/argmax.rs b/core/src/runtime/dynamic/aggregator/argmax.rs deleted file mode 100644 index e252f91..0000000 --- a/core/src/runtime/dynamic/aggregator/argmax.rs +++ /dev/null @@ -1,12 +0,0 @@ -use crate::runtime::provenance::*; - -use super::*; - -#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] -pub struct DynamicArgmax; - -impl DynamicArgmax { - pub fn aggregate(&self, batch: DynamicElements, ctx: &Prov) -> DynamicElements { - ctx.dynamic_argmax(batch) - } -} diff --git a/core/src/runtime/dynamic/aggregator/argmin.rs b/core/src/runtime/dynamic/aggregator/argmin.rs deleted file mode 100644 index c35195a..0000000 --- a/core/src/runtime/dynamic/aggregator/argmin.rs +++ /dev/null @@ -1,12 +0,0 @@ -use crate::runtime::provenance::*; - -use super::*; - -#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] -pub struct DynamicArgmin; - -impl DynamicArgmin { - pub fn aggregate(&self, batch: DynamicElements, ctx: &Prov) -> DynamicElements { - 
ctx.dynamic_argmin(batch) - } -} diff --git a/core/src/runtime/dynamic/aggregator/categorical_k.rs b/core/src/runtime/dynamic/aggregator/categorical_k.rs deleted file mode 100644 index c1cb1fd..0000000 --- a/core/src/runtime/dynamic/aggregator/categorical_k.rs +++ /dev/null @@ -1,18 +0,0 @@ -use crate::runtime::env::*; -use crate::runtime::provenance::*; - -use super::*; - -#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] -pub struct DynamicCategoricalK(pub usize); - -impl DynamicCategoricalK { - pub fn aggregate( - &self, - batch: DynamicElements, - ctx: &Prov, - rt: &RuntimeEnvironment, - ) -> DynamicElements { - ctx.dynamic_categorical_k(self.0, batch, rt) - } -} diff --git a/core/src/runtime/dynamic/aggregator/count.rs b/core/src/runtime/dynamic/aggregator/count.rs deleted file mode 100644 index 1dd1ab4..0000000 --- a/core/src/runtime/dynamic/aggregator/count.rs +++ /dev/null @@ -1,18 +0,0 @@ -use crate::runtime::provenance::*; - -use super::*; - -#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] -pub struct DynamicCount { - pub discrete: bool, -} - -impl DynamicCount { - pub fn aggregate(&self, batch: DynamicElements, ctx: &Prov) -> DynamicElements { - if self.discrete { - ctx.dynamic_discrete_count(batch) - } else { - ctx.dynamic_count(batch) - } - } -} diff --git a/core/src/runtime/dynamic/aggregator/exists.rs b/core/src/runtime/dynamic/aggregator/exists.rs deleted file mode 100644 index 948d933..0000000 --- a/core/src/runtime/dynamic/aggregator/exists.rs +++ /dev/null @@ -1,12 +0,0 @@ -use crate::runtime::provenance::*; - -use super::*; - -#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] -pub struct DynamicExists; - -impl DynamicExists { - pub fn aggregate(&self, batch: DynamicElements, ctx: &Prov) -> DynamicElements { - ctx.dynamic_exists(batch) - } -} diff --git a/core/src/runtime/dynamic/aggregator/max.rs b/core/src/runtime/dynamic/aggregator/max.rs deleted file mode 100644 index 5b0024c..0000000 --- 
a/core/src/runtime/dynamic/aggregator/max.rs +++ /dev/null @@ -1,12 +0,0 @@ -use crate::runtime::provenance::*; - -use super::*; - -#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] -pub struct DynamicMax; - -impl DynamicMax { - pub fn aggregate(&self, batch: DynamicElements, ctx: &Prov) -> DynamicElements { - ctx.dynamic_max(batch) - } -} diff --git a/core/src/runtime/dynamic/aggregator/min.rs b/core/src/runtime/dynamic/aggregator/min.rs deleted file mode 100644 index 8ca2230..0000000 --- a/core/src/runtime/dynamic/aggregator/min.rs +++ /dev/null @@ -1,12 +0,0 @@ -use crate::runtime::provenance::*; - -use super::*; - -#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] -pub struct DynamicMin; - -impl DynamicMin { - pub fn aggregate(&self, batch: DynamicElements, ctx: &Prov) -> DynamicElements { - ctx.dynamic_min(batch) - } -} diff --git a/core/src/runtime/dynamic/aggregator/mod.rs b/core/src/runtime/dynamic/aggregator/mod.rs deleted file mode 100644 index 0cb4a8b..0000000 --- a/core/src/runtime/dynamic/aggregator/mod.rs +++ /dev/null @@ -1,25 +0,0 @@ -mod aggregator; -mod argmax; -mod argmin; -mod categorical_k; -mod count; -mod exists; -mod max; -mod min; -mod prod; -mod sum; -mod top_k; - -pub use aggregator::*; -pub use argmax::*; -pub use argmin::*; -pub use categorical_k::*; -pub use count::*; -pub use exists::*; -pub use max::*; -pub use min::*; -pub use prod::*; -pub use sum::*; -pub use top_k::*; - -use super::*; diff --git a/core/src/runtime/dynamic/aggregator/prod.rs b/core/src/runtime/dynamic/aggregator/prod.rs deleted file mode 100644 index ddd3733..0000000 --- a/core/src/runtime/dynamic/aggregator/prod.rs +++ /dev/null @@ -1,13 +0,0 @@ -use crate::common::value_type::*; -use crate::runtime::provenance::*; - -use super::*; - -#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] -pub struct DynamicProd(pub ValueType); - -impl DynamicProd { - pub fn aggregate(&self, batch: DynamicElements, ctx: &Prov) -> DynamicElements { - 
ctx.dynamic_prod(&self.0, batch) - } -} diff --git a/core/src/runtime/dynamic/aggregator/sum.rs b/core/src/runtime/dynamic/aggregator/sum.rs deleted file mode 100644 index 5187ccc..0000000 --- a/core/src/runtime/dynamic/aggregator/sum.rs +++ /dev/null @@ -1,13 +0,0 @@ -use crate::common::value_type::*; -use crate::runtime::provenance::*; - -use super::*; - -#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] -pub struct DynamicSum(pub ValueType); - -impl DynamicSum { - pub fn aggregate(&self, batch: DynamicElements, ctx: &Prov) -> DynamicElements { - ctx.dynamic_sum(&self.0, batch) - } -} diff --git a/core/src/runtime/dynamic/aggregator/top_k.rs b/core/src/runtime/dynamic/aggregator/top_k.rs deleted file mode 100644 index 3da07da..0000000 --- a/core/src/runtime/dynamic/aggregator/top_k.rs +++ /dev/null @@ -1,12 +0,0 @@ -use crate::runtime::provenance::*; - -use super::*; - -#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] -pub struct DynamicTopK(pub usize); - -impl DynamicTopK { - pub fn aggregate(&self, batch: DynamicElements, ctx: &Prov) -> DynamicElements { - ctx.dynamic_top_k(self.0, batch) - } -} diff --git a/core/src/runtime/dynamic/dataflow/aggregation/implicit_group.rs b/core/src/runtime/dynamic/dataflow/aggregation/implicit_group.rs index bb7f811..e107f94 100644 --- a/core/src/runtime/dynamic/dataflow/aggregation/implicit_group.rs +++ b/core/src/runtime/dynamic/dataflow/aggregation/implicit_group.rs @@ -3,7 +3,7 @@ use crate::common::tuple::*; use super::*; pub struct DynamicAggregationImplicitGroupDataflow<'a, Prov: Provenance> { - pub agg: DynamicAggregator, + pub agg: DynamicAggregator, pub d: DynamicDataflow<'a, Prov>, pub ctx: &'a Prov, pub runtime: &'a RuntimeEnvironment, @@ -21,7 +21,12 @@ impl<'a, Prov: Provenance> Clone for DynamicAggregationImplicitGroupDataflow<'a, } impl<'a, Prov: Provenance> DynamicAggregationImplicitGroupDataflow<'a, Prov> { - pub fn new(agg: DynamicAggregator, d: DynamicDataflow<'a, Prov>, ctx: &'a Prov, runtime: 
&'a RuntimeEnvironment) -> Self { + pub fn new( + agg: DynamicAggregator, + d: DynamicDataflow<'a, Prov>, + ctx: &'a Prov, + runtime: &'a RuntimeEnvironment, + ) -> Self { Self { agg, d, ctx, runtime } } } @@ -36,14 +41,13 @@ impl<'a, Prov: Provenance> Dataflow<'a, Prov> for DynamicAggregationImplicitGrou }; // Temporary function to aggregate the group and populate the result - let consolidate_group = - |result: &mut DynamicElements, agg_key: Tuple, agg_group| { - let agg_results = self.agg.aggregate(agg_group, self.ctx, self.runtime); - let joined_results = agg_results - .into_iter() - .map(|agg_result| DynamicElement::new((agg_key.clone(), agg_result.tuple.clone()), agg_result.tag)); - result.extend(joined_results); - }; + let consolidate_group = |result: &mut DynamicElements, agg_key: Tuple, agg_group| { + let agg_results = self.agg.aggregate(self.ctx, self.runtime, agg_group); + let joined_results = agg_results + .into_iter() + .map(|agg_result| DynamicElement::new((agg_key.clone(), agg_result.tuple.clone()), agg_result.tag)); + result.extend(joined_results); + }; // Get the first element from the batch; otherwise, return empty let first_elem = if let Some(e) = batch.next_elem() { diff --git a/core/src/runtime/dynamic/dataflow/aggregation/join_group.rs b/core/src/runtime/dynamic/dataflow/aggregation/join_group.rs index 2a4e861..7545847 100644 --- a/core/src/runtime/dynamic/dataflow/aggregation/join_group.rs +++ b/core/src/runtime/dynamic/dataflow/aggregation/join_group.rs @@ -3,7 +3,7 @@ use itertools::*; use super::*; pub struct DynamicAggregationJoinGroupDataflow<'a, Prov: Provenance> { - pub agg: DynamicAggregator, + pub agg: DynamicAggregator, pub d1: DynamicDataflow<'a, Prov>, pub d2: DynamicDataflow<'a, Prov>, pub ctx: &'a Prov, @@ -23,8 +23,20 @@ impl<'a, Prov: Provenance> Clone for DynamicAggregationJoinGroupDataflow<'a, Pro } impl<'a, Prov: Provenance> DynamicAggregationJoinGroupDataflow<'a, Prov> { - pub fn new(agg: DynamicAggregator, d1: 
DynamicDataflow<'a, Prov>, d2: DynamicDataflow<'a, Prov>, ctx: &'a Prov, runtime: &'a RuntimeEnvironment) -> Self { - Self { agg, d1, d2, ctx, runtime } + pub fn new( + agg: DynamicAggregator, + d1: DynamicDataflow<'a, Prov>, + d2: DynamicDataflow<'a, Prov>, + ctx: &'a Prov, + runtime: &'a RuntimeEnvironment, + ) -> Self { + Self { + agg, + d1, + d2, + ctx, + runtime, + } } } @@ -105,7 +117,7 @@ impl<'a, Prov: Provenance> Dataflow<'a, Prov> for DynamicAggregationJoinGroupDat .iter() .map(|e| DynamicElement::new(e.tuple[1].clone(), e.tag.clone())) .collect::>(); - let agg_results = self.agg.aggregate(to_agg_tups, self.ctx, self.runtime); + let agg_results = self.agg.aggregate(self.ctx, self.runtime, to_agg_tups); iproduct!(group_by_vals, agg_results) .map(|((tag, t1), agg_result)| { DynamicElement::new( diff --git a/core/src/runtime/dynamic/dataflow/aggregation/mod.rs b/core/src/runtime/dynamic/dataflow/aggregation/mod.rs index f4113e6..3150ba8 100644 --- a/core/src/runtime/dynamic/dataflow/aggregation/mod.rs +++ b/core/src/runtime/dynamic/dataflow/aggregation/mod.rs @@ -7,3 +7,5 @@ pub use join_group::*; pub use single_group::*; use super::*; + +use crate::common::foreign_aggregate::*; diff --git a/core/src/runtime/dynamic/dataflow/aggregation/single_group.rs b/core/src/runtime/dynamic/dataflow/aggregation/single_group.rs index a8a8ff6..1875be4 100644 --- a/core/src/runtime/dynamic/dataflow/aggregation/single_group.rs +++ b/core/src/runtime/dynamic/dataflow/aggregation/single_group.rs @@ -1,7 +1,7 @@ use super::*; pub struct DynamicAggregationSingleGroupDataflow<'a, Prov: Provenance> { - pub agg: DynamicAggregator, + pub agg: DynamicAggregator, pub d: DynamicDataflow<'a, Prov>, pub ctx: &'a Prov, pub runtime: &'a RuntimeEnvironment, @@ -19,7 +19,12 @@ impl<'a, Prov: Provenance> Clone for DynamicAggregationSingleGroupDataflow<'a, P } impl<'a, Prov: Provenance> DynamicAggregationSingleGroupDataflow<'a, Prov> { - pub fn new(agg: DynamicAggregator, d: 
DynamicDataflow<'a, Prov>, ctx: &'a Prov, runtime: &'a RuntimeEnvironment) -> Self { + pub fn new( + agg: DynamicAggregator, + d: DynamicDataflow<'a, Prov>, + ctx: &'a Prov, + runtime: &'a RuntimeEnvironment, + ) -> Self { Self { agg, d, ctx, runtime } } } @@ -28,7 +33,7 @@ impl<'a, Prov: Provenance> Dataflow<'a, Prov> for DynamicAggregationSingleGroupD fn iter_recent(&self) -> DynamicBatches<'a, Prov> { if let Some(b) = self.d.iter_recent().next_batch() { let batch = b.collect::>(); - DynamicBatches::single(ElementsBatch::new(self.agg.aggregate(batch, self.ctx, self.runtime))) + DynamicBatches::single(ElementsBatch::new(self.agg.aggregate(self.ctx, self.runtime, batch))) } else { DynamicBatches::empty() } diff --git a/core/src/runtime/dynamic/dataflow/batching/batch.rs b/core/src/runtime/dynamic/dataflow/batching/batch.rs index c54e0bb..063bd31 100644 --- a/core/src/runtime/dynamic/dataflow/batching/batch.rs +++ b/core/src/runtime/dynamic/dataflow/batching/batch.rs @@ -56,7 +56,10 @@ impl<'a, Prov: Provenance> DynamicBatch<'a, Prov> { } pub fn filter) -> bool + Clone + 'a>(self, f: F) -> Self { - Self(Box::new(FilterBatch { child: self, filter_fn: f })) + Self(Box::new(FilterBatch { + child: self, + filter_fn: f, + })) } } @@ -75,9 +78,7 @@ pub struct RefElementsBatch<'a, Prov: Provenance> { impl<'a, Prov: Provenance> RefElementsBatch<'a, Prov> { pub fn new(elements: &'a DynamicElements) -> Self { - Self { - iter: elements.iter(), - } + Self { iter: elements.iter() } } } diff --git a/core/src/runtime/dynamic/dataflow/batching/batches.rs b/core/src/runtime/dynamic/dataflow/batching/batches.rs index 66f45b3..6c9ab34 100644 --- a/core/src/runtime/dynamic/dataflow/batching/batches.rs +++ b/core/src/runtime/dynamic/dataflow/batching/batches.rs @@ -95,10 +95,7 @@ pub struct DynamicBatchesChain<'a, Prov: Provenance> { impl<'a, Prov: Provenance> DynamicBatchesChain<'a, Prov> { pub fn new(bs: Vec>) -> Self { - Self { - bs, - id: 0, - } + Self { bs, id: 0 } } } @@ -125,7 
+122,11 @@ pub struct DynamicBatchesBinary<'a, Prov: Provenance> { } impl<'a, Prov: Provenance> DynamicBatchesBinary<'a, Prov> { - pub fn new>(mut b1: DynamicBatches<'a, Prov>, b2: DynamicBatches<'a, Prov>, op: Op) -> Self { + pub fn new>( + mut b1: DynamicBatches<'a, Prov>, + b2: DynamicBatches<'a, Prov>, + op: Op, + ) -> Self { let b1_curr = b1.next_batch(); let b2_source = b2.clone(); Self { diff --git a/core/src/runtime/dynamic/dataflow/dynamic_dataflow.rs b/core/src/runtime/dynamic/dataflow/dynamic_dataflow.rs index c6ae984..53ca874 100644 --- a/core/src/runtime/dynamic/dataflow/dynamic_dataflow.rs +++ b/core/src/runtime/dynamic/dataflow/dynamic_dataflow.rs @@ -80,11 +80,19 @@ impl<'a, Prov: Provenance> DynamicDataflow<'a, Prov> { } pub fn project(self, expression: Expr, runtime: &'a RuntimeEnvironment) -> Self { - Self::new(DynamicProjectDataflow { source: self, expression, runtime }) + Self::new(DynamicProjectDataflow { + source: self, + expression, + runtime, + }) } pub fn filter(self, filter: Expr, runtime: &'a RuntimeEnvironment) -> Self { - Self::new(DynamicFilterDataflow { source: self, filter, runtime }) + Self::new(DynamicFilterDataflow { + source: self, + filter, + runtime, + }) } pub fn find(self, key: Tuple) -> Self { @@ -115,7 +123,13 @@ impl<'a, Prov: Provenance> DynamicDataflow<'a, Prov> { Self::new(DynamicAntijoinDataflow { d1: self, d2, ctx }) } - pub fn foreign_predicate_ground(pred: String, bounded: Vec, first_iter: bool, ctx: &'a Prov, runtime: &'a RuntimeEnvironment) -> Self { + pub fn foreign_predicate_ground( + pred: String, + bounded: Vec, + first_iter: bool, + ctx: &'a Prov, + runtime: &'a RuntimeEnvironment, + ) -> Self { Self::new(ForeignPredicateGroundDataflow { foreign_predicate: pred, bounded_constants: bounded, @@ -125,7 +139,13 @@ impl<'a, Prov: Provenance> DynamicDataflow<'a, Prov> { }) } - pub fn foreign_predicate_constraint(self, pred: String, args: Vec, ctx: &'a Prov, runtime: &'a RuntimeEnvironment) -> Self { + pub fn 
foreign_predicate_constraint( + self, + pred: String, + args: Vec, + ctx: &'a Prov, + runtime: &'a RuntimeEnvironment, + ) -> Self { Self::new(ForeignPredicateConstraintDataflow { dataflow: self, foreign_predicate: pred, @@ -135,7 +155,13 @@ impl<'a, Prov: Provenance> DynamicDataflow<'a, Prov> { }) } - pub fn foreign_predicate_join(self, pred: String, args: Vec, ctx: &'a Prov, runtime: &'a RuntimeEnvironment) -> Self { + pub fn foreign_predicate_join( + self, + pred: String, + args: Vec, + ctx: &'a Prov, + runtime: &'a RuntimeEnvironment, + ) -> Self { Self::new(ForeignPredicateJoinDataflow { left: self, foreign_predicate: pred, diff --git a/core/src/runtime/dynamic/dataflow/dynamic_exclusion.rs b/core/src/runtime/dynamic/dataflow/dynamic_exclusion.rs index e8c940a..f2a2ccb 100644 --- a/core/src/runtime/dynamic/dataflow/dynamic_exclusion.rs +++ b/core/src/runtime/dynamic/dataflow/dynamic_exclusion.rs @@ -29,7 +29,12 @@ pub struct DynamicExclusionDataflow<'a, Prov: Provenance> { } impl<'a, Prov: Provenance> DynamicExclusionDataflow<'a, Prov> { - pub fn new(left: DynamicDataflow<'a, Prov>, right: DynamicDataflow<'a, Prov>, ctx: &'a Prov, runtime: &'a RuntimeEnvironment) -> Self { + pub fn new( + left: DynamicDataflow<'a, Prov>, + right: DynamicDataflow<'a, Prov>, + ctx: &'a Prov, + runtime: &'a RuntimeEnvironment, + ) -> Self { Self { left, right, diff --git a/core/src/runtime/dynamic/dataflow/dynamic_relation.rs b/core/src/runtime/dynamic/dataflow/dynamic_relation.rs index e0806f6..71615c9 100644 --- a/core/src/runtime/dynamic/dataflow/dynamic_relation.rs +++ b/core/src/runtime/dynamic/dataflow/dynamic_relation.rs @@ -152,11 +152,7 @@ impl<'a, Prov: Provenance> Batch<'a, Prov> for DynamicRelationRecentBatch<'a, Pr } } -fn search_ahead_variable_helper( - collection: &DynamicCollection, - elem_id: &mut usize, - mut cmp: F, -) -> bool +fn search_ahead_variable_helper(collection: &DynamicCollection, elem_id: &mut usize, mut cmp: F) -> bool where Prov: Provenance, F: 
FnMut(&Tuple) -> bool, diff --git a/core/src/runtime/dynamic/dataflow/filter.rs b/core/src/runtime/dynamic/dataflow/filter.rs index ab3708d..2656000 100644 --- a/core/src/runtime/dynamic/dataflow/filter.rs +++ b/core/src/runtime/dynamic/dataflow/filter.rs @@ -10,11 +10,19 @@ pub struct DynamicFilterDataflow<'a, Prov: Provenance> { impl<'a, Prov: Provenance> Dataflow<'a, Prov> for DynamicFilterDataflow<'a, Prov> { fn iter_stable(&self) -> DynamicBatches<'a, Prov> { - DynamicBatches::new(DynamicFilterBatches { runtime: self.runtime, source: self.source.iter_stable(), filter: self.filter.clone() }) + DynamicBatches::new(DynamicFilterBatches { + runtime: self.runtime, + source: self.source.iter_stable(), + filter: self.filter.clone(), + }) } fn iter_recent(&self) -> DynamicBatches<'a, Prov> { - DynamicBatches::new(DynamicFilterBatches { runtime: self.runtime, source: self.source.iter_recent(), filter: self.filter.clone() }) + DynamicBatches::new(DynamicFilterBatches { + runtime: self.runtime, + source: self.source.iter_recent(), + filter: self.filter.clone(), + }) } } diff --git a/core/src/runtime/dynamic/dataflow/find.rs b/core/src/runtime/dynamic/dataflow/find.rs index 2a1229d..9c277e2 100644 --- a/core/src/runtime/dynamic/dataflow/find.rs +++ b/core/src/runtime/dynamic/dataflow/find.rs @@ -9,11 +9,17 @@ pub struct DynamicFindDataflow<'a, Prov: Provenance> { impl<'a, Prov: Provenance> Dataflow<'a, Prov> for DynamicFindDataflow<'a, Prov> { fn iter_stable(&self) -> DynamicBatches<'a, Prov> { - DynamicBatches::new(DynamicFindBatches { source: self.source.iter_stable(), key: self.key.clone() }) + DynamicBatches::new(DynamicFindBatches { + source: self.source.iter_stable(), + key: self.key.clone(), + }) } fn iter_recent(&self) -> DynamicBatches<'a, Prov> { - DynamicBatches::new(DynamicFindBatches { source: self.source.iter_recent(), key: self.key.clone() }) + DynamicBatches::new(DynamicFindBatches { + source: self.source.iter_recent(), + key: self.key.clone(), + }) } } @@ 
-49,17 +55,15 @@ impl<'a, Prov: Provenance> Batch<'a, Prov> for DynamicFindBatch<'a, Prov> { let key = self.key.clone(); loop { match &self.curr_elem { - Some(elem) => { - match elem.tuple[0].cmp(&key) { - Ordering::Less => self.curr_elem = self.source.search_elem_0_until(&key), - Ordering::Equal => { - let result = elem.clone(); - self.curr_elem = self.source.next_elem(); - return Some(result); - } - Ordering::Greater => return None, + Some(elem) => match elem.tuple[0].cmp(&key) { + Ordering::Less => self.curr_elem = self.source.search_elem_0_until(&key), + Ordering::Equal => { + let result = elem.clone(); + self.curr_elem = self.source.next_elem(); + return Some(result); } - } + Ordering::Greater => return None, + }, None => return None, } } diff --git a/core/src/runtime/dynamic/dataflow/foreign_predicate/constraint.rs b/core/src/runtime/dynamic/dataflow/foreign_predicate/constraint.rs index 472fecd..490ab5f 100644 --- a/core/src/runtime/dynamic/dataflow/foreign_predicate/constraint.rs +++ b/core/src/runtime/dynamic/dataflow/foreign_predicate/constraint.rs @@ -19,12 +19,13 @@ pub struct ForeignPredicateConstraintDataflow<'a, Prov: Provenance> { pub ctx: &'a Prov, /// Runtime environment - pub runtime: &'a RuntimeEnvironment + pub runtime: &'a RuntimeEnvironment, } impl<'a, Prov: Provenance> Dataflow<'a, Prov> for ForeignPredicateConstraintDataflow<'a, Prov> { fn iter_stable(&self) -> DynamicBatches<'a, Prov> { - let fp = self.runtime + let fp = self + .runtime .predicate_registry .get(&self.foreign_predicate) .expect("Foreign predicate not found"); @@ -38,7 +39,8 @@ impl<'a, Prov: Provenance> Dataflow<'a, Prov> for ForeignPredicateConstraintData } fn iter_recent(&self) -> DynamicBatches<'a, Prov> { - let fp = self.runtime + let fp = self + .runtime .predicate_registry .get(&self.foreign_predicate) .expect("Foreign predicate not found"); diff --git a/core/src/runtime/dynamic/dataflow/foreign_predicate/join.rs 
b/core/src/runtime/dynamic/dataflow/foreign_predicate/join.rs index d06a201..aa16a91 100644 --- a/core/src/runtime/dynamic/dataflow/foreign_predicate/join.rs +++ b/core/src/runtime/dynamic/dataflow/foreign_predicate/join.rs @@ -20,7 +20,7 @@ pub struct ForeignPredicateJoinDataflow<'a, Prov: Provenance> { pub ctx: &'a Prov, /// Runtime environment - pub runtime: &'a RuntimeEnvironment + pub runtime: &'a RuntimeEnvironment, } impl<'a, Prov: Provenance> Clone for ForeignPredicateJoinDataflow<'a, Prov> { @@ -39,7 +39,8 @@ impl<'a, Prov: Provenance> Dataflow<'a, Prov> for ForeignPredicateJoinDataflow<' fn iter_stable(&self) -> DynamicBatches<'a, Prov> { DynamicBatches::new(ForeignPredicateJoinBatches { batches: self.left.iter_stable(), - foreign_predicate: self.runtime + foreign_predicate: self + .runtime .predicate_registry .get(&self.foreign_predicate) .expect("Foreign predicate not found") @@ -53,7 +54,8 @@ impl<'a, Prov: Provenance> Dataflow<'a, Prov> for ForeignPredicateJoinDataflow<' fn iter_recent(&self) -> DynamicBatches<'a, Prov> { DynamicBatches::new(ForeignPredicateJoinBatches { batches: self.left.iter_recent(), - foreign_predicate: self.runtime + foreign_predicate: self + .runtime .predicate_registry .get(&self.foreign_predicate) .expect("Foreign predicate not found") diff --git a/core/src/runtime/dynamic/dataflow/overwrite_one.rs b/core/src/runtime/dynamic/dataflow/overwrite_one.rs index ec71071..db80c7d 100644 --- a/core/src/runtime/dynamic/dataflow/overwrite_one.rs +++ b/core/src/runtime/dynamic/dataflow/overwrite_one.rs @@ -16,11 +16,17 @@ impl<'a, Prov: Provenance> Clone for DynamicOverwriteOneDataflow<'a, Prov> { impl<'a, Prov: Provenance> Dataflow<'a, Prov> for DynamicOverwriteOneDataflow<'a, Prov> { fn iter_stable(&self) -> DynamicBatches<'a, Prov> { - DynamicBatches::new(DynamicOverwriteOneBatches { source: self.source.iter_stable(), ctx: self.ctx }) + DynamicBatches::new(DynamicOverwriteOneBatches { + source: self.source.iter_stable(), + ctx: 
self.ctx, + }) } fn iter_recent(&self) -> DynamicBatches<'a, Prov> { - DynamicBatches::new(DynamicOverwriteOneBatches { source: self.source.iter_recent(), ctx: self.ctx }) + DynamicBatches::new(DynamicOverwriteOneBatches { + source: self.source.iter_recent(), + ctx: self.ctx, + }) } } diff --git a/core/src/runtime/dynamic/dataflow/project.rs b/core/src/runtime/dynamic/dataflow/project.rs index cadbd3f..72d7e8a 100644 --- a/core/src/runtime/dynamic/dataflow/project.rs +++ b/core/src/runtime/dynamic/dataflow/project.rs @@ -10,11 +10,19 @@ pub struct DynamicProjectDataflow<'a, Prov: Provenance> { impl<'a, Prov: Provenance> Dataflow<'a, Prov> for DynamicProjectDataflow<'a, Prov> { fn iter_stable(&self) -> DynamicBatches<'a, Prov> { - DynamicBatches::new(DynamicProjectBatches::new(self.runtime, self.source.iter_stable(), self.expression.clone())) + DynamicBatches::new(DynamicProjectBatches::new( + self.runtime, + self.source.iter_stable(), + self.expression.clone(), + )) } fn iter_recent(&self) -> DynamicBatches<'a, Prov> { - DynamicBatches::new(DynamicProjectBatches::new(self.runtime, self.source.iter_recent(), self.expression.clone())) + DynamicBatches::new(DynamicProjectBatches::new( + self.runtime, + self.source.iter_recent(), + self.expression.clone(), + )) } } @@ -27,7 +35,11 @@ pub struct DynamicProjectBatches<'a, Prov: Provenance> { impl<'a, Prov: Provenance> DynamicProjectBatches<'a, Prov> { pub fn new(runtime: &'a RuntimeEnvironment, source: DynamicBatches<'a, Prov>, expression: Expr) -> Self { - Self { runtime, source, expression } + Self { + runtime, + source, + expression, + } } } diff --git a/core/src/runtime/dynamic/incremental.rs b/core/src/runtime/dynamic/incremental.rs index 298189d..0c03386 100644 --- a/core/src/runtime/dynamic/incremental.rs +++ b/core/src/runtime/dynamic/incremental.rs @@ -320,7 +320,9 @@ impl DynamicExecutionContext { where M: Monitor, { - self.incremental_execute_with_monitor_helper(None, runtime, ctx, m) + 
self.incremental_execute_with_monitor_helper(None, runtime, ctx, m)?; + m.observe_finish_execution(); + Ok(()) } pub fn incremental_execute_with_monitor( diff --git a/core/src/runtime/dynamic/iteration.rs b/core/src/runtime/dynamic/iteration.rs index 05eb5e3..6ec4290 100644 --- a/core/src/runtime/dynamic/iteration.rs +++ b/core/src/runtime/dynamic/iteration.rs @@ -307,9 +307,11 @@ impl<'a, Prov: Provenance> DynamicIteration<'a, Prov> { .foreign_predicate_join(p.clone(), a.clone(), ctx, env) } ram::Dataflow::OverwriteOne(d) => self.build_dynamic_dataflow(env, ctx, d).overwrite_one(ctx), - ram::Dataflow::Exclusion(d1, d2) => self - .build_dynamic_dataflow(env, ctx, d1) - .dynamic_exclusion(self.build_dynamic_dataflow(env, ctx, d2), ctx, env), + ram::Dataflow::Exclusion(d1, d2) => { + self + .build_dynamic_dataflow(env, ctx, d1) + .dynamic_exclusion(self.build_dynamic_dataflow(env, ctx, d2), ctx, env) + } ram::Dataflow::Filter(d, e) => { let internal_filter = env .internalize_expr(e) @@ -359,7 +361,16 @@ impl<'a, Prov: Provenance> DynamicIteration<'a, Prov> { r1.antijoin(r2, ctx) } ram::Dataflow::Reduce(a) => { - let op = a.op.clone().into(); + let op = env + .aggregate_registry + .instantiate_aggregator::( + &a.aggregator, + a.params.clone(), + a.has_exclamation_mark, + a.arg_var_types.clone(), + a.input_var_types.clone(), + ) + .expect(&format!("cannot instantiate aggregator `{}`", a.aggregator)); match &a.group_by { ram::ReduceGroupByType::None => { let c = self.build_dynamic_collection(&a.predicate); diff --git a/core/src/runtime/dynamic/mod.rs b/core/src/runtime/dynamic/mod.rs index 0123942..19fe448 100644 --- a/core/src/runtime/dynamic/mod.rs +++ b/core/src/runtime/dynamic/mod.rs @@ -1,4 +1,3 @@ -mod aggregator; mod collection; pub mod dataflow; mod element; @@ -7,7 +6,6 @@ mod iteration; mod output_collection; mod relation; -pub use aggregator::*; pub use collection::*; pub use element::*; pub use incremental::*; diff --git 
a/core/src/runtime/dynamic/relation.rs b/core/src/runtime/dynamic/relation.rs index a6ea95c..1644096 100644 --- a/core/src/runtime/dynamic/relation.rs +++ b/core/src/runtime/dynamic/relation.rs @@ -102,7 +102,10 @@ impl DynamicRelation { .map(|(info, tuple)| DynamicElement::new(tuple.into(), ctx.tagging_optional_fn(info))) .collect::>(); - self.to_add.borrow_mut().push(DynamicCollection::from_vec(elements, ctx)); + self + .to_add + .borrow_mut() + .push(DynamicCollection::from_vec(elements, ctx)); } pub fn insert_tagged_with_monitor(&self, ctx: &Prov, data: Vec<(Option>, Tup)>, m: &M) @@ -120,7 +123,10 @@ impl DynamicRelation { }) .collect::>(); - self.to_add.borrow_mut().push(DynamicCollection::from_vec(elements, ctx)); + self + .to_add + .borrow_mut() + .push(DynamicCollection::from_vec(elements, ctx)); } pub fn num_stable(&self) -> usize { @@ -221,7 +227,12 @@ impl DynamicRelation { !self.recent.borrow().is_empty() } - pub fn insert_dataflow_recent<'a>(&'a self, ctx: &'a Prov, d: &DynamicDataflow<'a, Prov>, runtime: &'a RuntimeEnvironment) { + pub fn insert_dataflow_recent<'a>( + &'a self, + ctx: &'a Prov, + d: &DynamicDataflow<'a, Prov>, + runtime: &'a RuntimeEnvironment, + ) { for batch in d.iter_recent() { let data = if runtime.early_discard { batch.filter(move |e| !ctx.discard(&e.tag)).collect::>() @@ -232,7 +243,12 @@ impl DynamicRelation { } } - pub fn insert_dataflow_stable<'a>(&'a self, ctx: &'a Prov, d: &DynamicDataflow<'a, Prov>, runtime: &'a RuntimeEnvironment) { + pub fn insert_dataflow_stable<'a>( + &'a self, + ctx: &'a Prov, + d: &DynamicDataflow<'a, Prov>, + runtime: &'a RuntimeEnvironment, + ) { for batch in d.iter_stable() { let data = if runtime.early_discard { batch.filter(move |e| !ctx.discard(&e.tag)).collect::>() diff --git a/core/src/runtime/env/environment.rs b/core/src/runtime/env/environment.rs index 3a913a1..f95c320 100644 --- a/core/src/runtime/env/environment.rs +++ b/core/src/runtime/env/environment.rs @@ -3,6 +3,7 @@ use 
std::collections::*; use crate::common::constants::*; use crate::common::entity; use crate::common::expr::*; +use crate::common::foreign_aggregate::*; use crate::common::foreign_function::*; use crate::common::foreign_predicate::*; use crate::common::foreign_tensor; @@ -34,6 +35,9 @@ pub struct RuntimeEnvironment { /// Foreign predicate registry pub predicate_registry: ForeignPredicateRegistry, + /// Foreign aggregate registry + pub aggregate_registry: AggregateRegistry, + /// Mutual exclusion ID allocator pub exclusion_id_allocator: IdAllocator2, @@ -62,6 +66,7 @@ impl RuntimeEnvironment { iter_limit: None, function_registry: ForeignFunctionRegistry::std(), predicate_registry: ForeignPredicateRegistry::std(), + aggregate_registry: AggregateRegistry::std(), exclusion_id_allocator: IdAllocator2::new(), symbol_registry: SymbolRegistry2::new(), dynamic_entity_store: DynamicEntityStorage2::new(), @@ -77,6 +82,7 @@ impl RuntimeEnvironment { iter_limit: None, function_registry: ForeignFunctionRegistry::std(), predicate_registry: ForeignPredicateRegistry::std(), + aggregate_registry: AggregateRegistry::std(), exclusion_id_allocator: IdAllocator2::new(), symbol_registry: SymbolRegistry2::new(), dynamic_entity_store: DynamicEntityStorage2::new(), @@ -84,7 +90,7 @@ impl RuntimeEnvironment { } } - pub fn new(ffr: ForeignFunctionRegistry, fpr: ForeignPredicateRegistry) -> Self { + pub fn new(ffr: ForeignFunctionRegistry, fpr: ForeignPredicateRegistry, far: AggregateRegistry) -> Self { Self { random_seed: DEFAULT_RANDOM_SEED, random: Random::new(DEFAULT_RANDOM_SEED), @@ -92,6 +98,7 @@ impl RuntimeEnvironment { iter_limit: None, function_registry: ffr, predicate_registry: fpr, + aggregate_registry: far, exclusion_id_allocator: IdAllocator2::new(), symbol_registry: SymbolRegistry2::new(), dynamic_entity_store: DynamicEntityStorage2::new(), @@ -107,6 +114,7 @@ impl RuntimeEnvironment { iter_limit: None, function_registry: ffr, predicate_registry: ForeignPredicateRegistry::std(), + 
aggregate_registry: AggregateRegistry::std(), exclusion_id_allocator: IdAllocator2::new(), symbol_registry: SymbolRegistry2::new(), dynamic_entity_store: DynamicEntityStorage2::new(), @@ -122,6 +130,7 @@ impl RuntimeEnvironment { iter_limit: options.iter_limit, function_registry: ForeignFunctionRegistry::std(), predicate_registry: ForeignPredicateRegistry::std(), + aggregate_registry: AggregateRegistry::std(), exclusion_id_allocator: IdAllocator2::new(), symbol_registry: SymbolRegistry2::new(), dynamic_entity_store: DynamicEntityStorage2::new(), @@ -148,6 +157,7 @@ impl RuntimeEnvironment { pub fn load_from_ram_program(&mut self, ram_program: &ram::Program) { self.function_registry = ram_program.function_registry.clone(); self.predicate_registry = ram_program.predicate_registry.clone(); + self.aggregate_registry = ram_program.aggregate_registry.clone(); self .dynamic_entity_store .update_variant_registry(ram_program.adt_variant_registry.clone()); @@ -212,9 +222,11 @@ impl RuntimeEnvironment { pub fn externalize_tuple(&self, tup: &Tuple) -> Option { match tup { - Tuple::Tuple(ts) => { - Some(Tuple::Tuple(ts.iter().map(|t| self.externalize_tuple(t)).collect::>>()?)) - }, + Tuple::Tuple(ts) => Some(Tuple::Tuple( + ts.iter() + .map(|t| self.externalize_tuple(t)) + .collect::>>()?, + )), Tuple::Value(v) => self.externalize_value(v).map(Tuple::Value), } } diff --git a/core/src/runtime/env/random.rs b/core/src/runtime/env/random.rs index a6027e8..efb1c73 100644 --- a/core/src/runtime/env/random.rs +++ b/core/src/runtime/env/random.rs @@ -1,5 +1,6 @@ use std::sync::*; +use rand::prelude::*; use rand::rngs::SmallRng; use rand::SeedableRng; @@ -20,4 +21,9 @@ impl Random { pub fn sample_from>(&self, dist: &D) -> T { dist.sample(&mut *self.rng.lock().unwrap()) } + + /// Sample a number between 0 and `n` (exclusive) + pub fn random_usize(&self, n: usize) -> usize { + self.rng.lock().unwrap().gen_range(0..n) + } } diff --git a/core/src/runtime/monitor/debug_runtime.rs 
b/core/src/runtime/monitor/debug_runtime.rs index b3c976b..b759eaf 100644 --- a/core/src/runtime/monitor/debug_runtime.rs +++ b/core/src/runtime/monitor/debug_runtime.rs @@ -2,9 +2,14 @@ use crate::runtime::provenance::Provenance; use super::*; +#[derive(Clone)] pub struct DebugRuntimeMonitor; impl Monitor for DebugRuntimeMonitor { + fn name(&self) -> &'static str { + "debug-runtime" + } + fn observe_executing_stratum(&self, stratum_id: usize) { println!("[Executing Stratum #{}]", stratum_id) } diff --git a/core/src/runtime/monitor/debug_tags.rs b/core/src/runtime/monitor/debug_tags.rs index 86d6d7a..570a6b4 100644 --- a/core/src/runtime/monitor/debug_tags.rs +++ b/core/src/runtime/monitor/debug_tags.rs @@ -3,9 +3,14 @@ use crate::runtime::provenance::Provenance; use super::*; +#[derive(Clone)] pub struct DebugTagsMonitor; impl Monitor for DebugTagsMonitor { + fn name(&self) -> &'static str { + "debug-tags" + } + fn observe_loading_relation(&self, relation: &str) { println!("[Tagging Relation] {}", relation) } diff --git a/core/src/runtime/monitor/dump_proofs.rs b/core/src/runtime/monitor/dump_proofs.rs new file mode 100644 index 0000000..ae18885 --- /dev/null +++ b/core/src/runtime/monitor/dump_proofs.rs @@ -0,0 +1,171 @@ +use std::collections::*; +use std::env; +use std::fs; + +use crate::common::foreign_tensor::*; +use crate::common::tuple::*; +use crate::runtime::provenance::*; +use crate::utils::*; + +use super::*; + +#[derive(Clone)] +pub struct DumpProofsInternal { + current_tagging_relation: Option, + tagged_tuples: HashMap, String)>>, + current_recovering_relation: Option, + recovered_tuples: HashMap>)>>, +} + +pub struct DumpProofsMonitor { + internal: ::RcCell, +} + +impl DumpProofsMonitor { + pub fn new() -> Self { + Self { + internal: ::new_rc_cell(DumpProofsInternal { + current_tagging_relation: None, + tagged_tuples: HashMap::new(), + current_recovering_relation: None, + recovered_tuples: HashMap::new(), + }), + } + } + + pub fn 
set_current_tagging_relation(&self, relation: &str) { + ArcFamily::get_rc_cell_mut(&self.internal, |c| { + c.current_tagging_relation = Some(relation.to_string()); + }) + } + + pub fn record_tag(&self, tup: &Tuple, prob: Option, tag: &DNFFormula) { + ArcFamily::get_rc_cell_mut(&self.internal, |c| { + if let Some(r) = &c.current_tagging_relation { + if let Some(id) = tag.get_singleton_id() { + c.tagged_tuples + .entry(r.to_string()) + .or_default() + .push((id, prob, tup.to_string())); + } + } + }) + } + + pub fn set_current_recovering_relation(&self, relation: &str) { + ArcFamily::get_rc_cell_mut(&self.internal, |c| { + c.current_recovering_relation = Some(relation.to_string()); + }) + } + + pub fn record_recover(&self, tup: &Tuple, prob: f64, tag: &DNFFormula) { + ArcFamily::get_rc_cell_mut(&self.internal, |c| { + if let Some(r) = &c.current_recovering_relation { + let proofs = tag + .clauses + .iter() + .map(|clause| { + clause + .literals + .iter() + .map(|literal| match literal { + Literal::Pos(id) => (true, *id), + Literal::Neg(id) => (false, *id), + }) + .collect::>() + }) + .collect::>(); + c.recovered_tuples + .entry(r.to_string()) + .or_default() + .push((tup.to_string(), prob, proofs)); + } + }) + } + + pub fn dump_relation(&self, relation: &str) { + ArcFamily::get_rc_cell_mut(&self.internal, |c| { + let dir = env::var("SCALLOP_DUMP_PROOFS_DIR").unwrap_or(".tmp/dumped-tuples".to_string()); + for (relation, tuples) in &c.tagged_tuples { + let js = serde_json::to_string(&tuples).expect("Cannot serialize tuples"); + fs::write(&format!("{dir}/{relation}.json"), js).expect("Unable to write file"); + } + if let Some(tuples) = c.recovered_tuples.get(relation) { + let js = serde_json::to_string(&tuples).expect("Cannot serialize tuples"); + fs::write(&format!("{dir}/{relation}.json"), js).expect("Unable to write file"); + } + }) + } +} + +impl Clone for DumpProofsMonitor { + fn clone(&self) -> Self { + Self { + internal: ArcFamily::get_rc_cell(&self.internal, |x| 
ArcFamily::new_rc_cell(x.clone())), + } + } +} + +impl Monitor for DumpProofsMonitor { + default fn name(&self) -> &'static str { + "dump-proofs" + } + default fn observe_finish_execution(&self) {} + default fn observe_loading_relation(&self, _: &str) {} + default fn observe_tagging(&self, _: &Tuple, _: &Option, _: &Prov::Tag) {} + default fn observe_recovering_relation(&self, _: &str) {} + default fn observe_recover(&self, _: &Tuple, _: &Prov::Tag, _: &Prov::OutputTag) {} + default fn observe_finish_recovering_relation(&self, _: &str) {} +} + +impl Monitor> for DumpProofsMonitor { + fn name(&self) -> &'static str { + "dump-proofs" + } + + fn observe_loading_relation(&self, relation: &str) { + self.set_current_tagging_relation(relation) + } + + fn observe_tagging(&self, tup: &Tuple, input_tag: &Option, tag: &DNFFormula) { + self.record_tag(tup, input_tag.as_ref().map(|v| v.prob), tag) + } + + fn observe_recovering_relation(&self, relation: &str) { + self.set_current_recovering_relation(relation) + } + + fn observe_recover(&self, tup: &Tuple, tag: &DNFFormula, output_tag: &f64) { + self.record_recover(tup, output_tag.clone(), tag) + } + + fn observe_finish_recovering_relation(&self, relation: &str) { + self.dump_relation(relation) + } +} + +impl Monitor> for DumpProofsMonitor { + fn name(&self) -> &'static str { + "dump-proofs" + } + + fn observe_loading_relation(&self, relation: &str) { + self.set_current_tagging_relation(relation) + } + + fn observe_tagging(&self, tup: &Tuple, input_tag: &Option>, tag: &DNFFormula) { + self.record_tag(tup, input_tag.as_ref().map(|v| v.prob), tag) + } + + fn observe_recovering_relation(&self, relation: &str) { + self.set_current_recovering_relation(relation) + } + + fn observe_recover(&self, tup: &Tuple, tag: &DNFFormula, output_tag: &OutputDiffProb) { + self.record_recover(tup, output_tag.0.clone(), tag) + } + + fn observe_finish_recovering_relation(&self, relation: &str) { + self.dump_relation(relation) + } +} diff --git 
a/core/src/runtime/monitor/dynamic_monitors.rs b/core/src/runtime/monitor/dynamic_monitors.rs index f05ac49..e4273f8 100644 --- a/core/src/runtime/monitor/dynamic_monitors.rs +++ b/core/src/runtime/monitor/dynamic_monitors.rs @@ -4,7 +4,15 @@ use crate::runtime::provenance::*; use super::*; pub struct DynamicMonitors { - monitors: Vec>>, + pub monitors: Vec>>, +} + +impl Clone for DynamicMonitors { + fn clone(&self) -> Self { + Self { + monitors: self.monitors.iter().map(|m| dyn_clone::clone_box(&**m)).collect(), + } + } } impl DynamicMonitors { @@ -12,6 +20,10 @@ impl DynamicMonitors { Self { monitors: vec![] } } + pub fn monitor_names(&self) -> Vec<&str> { + self.monitors.iter().map(|m| m.name()).collect() + } + pub fn is_empty(&self) -> bool { self.monitors.is_empty() } @@ -25,6 +37,12 @@ impl DynamicMonitors { self.add(m); self } + + pub fn extend(&mut self, other: Self) { + for m in other.monitors { + self.monitors.push(m); + } + } } macro_rules! dynamic_monitors_observe_event { @@ -38,6 +56,9 @@ macro_rules! 
dynamic_monitors_observe_event { } impl Monitor for DynamicMonitors { + fn name(&self) -> &'static str { + "multiple" + } dynamic_monitors_observe_event!(observe_executing_stratum, (stratum_id: usize)); dynamic_monitors_observe_event!(observe_stratum_iteration, (iteration_count: usize)); dynamic_monitors_observe_event!(observe_hitting_iteration_limit, ()); @@ -45,13 +66,9 @@ impl Monitor for DynamicMonitors { dynamic_monitors_observe_event!(observe_loading_relation, (relation: &str)); dynamic_monitors_observe_event!(observe_loading_relation_from_edb, (relation: &str)); dynamic_monitors_observe_event!(observe_loading_relation_from_idb, (relation: &str)); - dynamic_monitors_observe_event!( - observe_tagging, - (tup: &Tuple, input_tag: &Option, tag: &Prov::Tag) - ); + dynamic_monitors_observe_event!(observe_tagging, (tup: &Tuple, input_tag: &Option, tag: &Prov::Tag)); + dynamic_monitors_observe_event!(observe_finish_execution, ()); dynamic_monitors_observe_event!(observe_recovering_relation, (relation: &str)); - dynamic_monitors_observe_event!( - observe_recover, - (tup: &Tuple, tag: &Prov::Tag, output_tag: &Prov::OutputTag) - ); + dynamic_monitors_observe_event!(observe_recover, (tup: &Tuple, tag: &Prov::Tag, output_tag: &Prov::OutputTag)); + dynamic_monitors_observe_event!(observe_finish_recovering_relation, (relation: &str)); } diff --git a/core/src/runtime/monitor/iteration_checker.rs b/core/src/runtime/monitor/iteration_checker.rs index b9fed43..dde0276 100644 --- a/core/src/runtime/monitor/iteration_checker.rs +++ b/core/src/runtime/monitor/iteration_checker.rs @@ -6,6 +6,7 @@ use super::*; /// /// A monitor with iteration limit; will panic if the execution uses an iteration /// over the limit +#[derive(Clone)] pub struct IterationCheckingMonitor { iter_limit: usize, } @@ -17,6 +18,10 @@ impl IterationCheckingMonitor { } impl Monitor for IterationCheckingMonitor { + fn name(&self) -> &'static str { + "iteration-checking" + } + fn observe_stratum_iteration(&self, 
iteration_count: usize) { if iteration_count > self.iter_limit { panic!( diff --git a/core/src/runtime/monitor/logging.rs b/core/src/runtime/monitor/logging.rs index 8b2f8d1..3de7153 100644 --- a/core/src/runtime/monitor/logging.rs +++ b/core/src/runtime/monitor/logging.rs @@ -4,6 +4,7 @@ use crate::runtime::provenance::Provenance; use super::*; +#[derive(Clone)] pub struct LoggingMonitor; impl LoggingMonitor { @@ -21,6 +22,10 @@ impl LoggingMonitor { } impl Monitor for LoggingMonitor { + fn name(&self) -> &'static str { + "logging" + } + fn observe_executing_stratum(&self, stratum_id: usize) { self.info(&format!("executing stratum #{}", stratum_id)) } diff --git a/core/src/runtime/monitor/mod.rs b/core/src/runtime/monitor/mod.rs index c55ee34..9eb06ce 100644 --- a/core/src/runtime/monitor/mod.rs +++ b/core/src/runtime/monitor/mod.rs @@ -1,13 +1,17 @@ mod debug_runtime; mod debug_tags; +mod dump_proofs; mod dynamic_monitors; mod iteration_checker; mod logging; mod monitor; +mod registry; pub use debug_runtime::*; pub use debug_tags::*; +pub use dump_proofs::*; pub use dynamic_monitors::*; pub use iteration_checker::*; pub use logging::*; pub use monitor::*; +pub use registry::*; diff --git a/core/src/runtime/monitor/monitor.rs b/core/src/runtime/monitor/monitor.rs index dafee0d..fb7d4da 100644 --- a/core/src/runtime/monitor/monitor.rs +++ b/core/src/runtime/monitor/monitor.rs @@ -1,7 +1,9 @@ use crate::common::tuple::Tuple; use crate::runtime::provenance::*; -pub trait Monitor { +pub trait Monitor: dyn_clone::DynClone + 'static { + fn name(&self) -> &'static str; + /// Observe stratum iteration #[allow(unused_variables)] fn observe_executing_stratum(&self, stratum_id: usize) {} @@ -34,6 +36,10 @@ pub trait Monitor { #[allow(unused_variables)] fn observe_tagging(&self, tup: &Tuple, input_tag: &Option, tag: &Prov::Tag) {} + /// Observe that the execution is finished + #[allow(unused_variables)] + fn observe_finish_execution(&self) {} + /// Observe recovering output 
tags of a relation #[allow(unused_variables)] fn observe_recovering_relation(&self, relation: &str) {} @@ -41,9 +47,17 @@ pub trait Monitor { /// Observe a call on recover function #[allow(unused_variables)] fn observe_recover(&self, tup: &Tuple, tag: &Prov::Tag, output_tag: &Prov::OutputTag) {} + + /// Observe the completion of recovering output tags of a relation + #[allow(unused_variables)] + fn observe_finish_recovering_relation(&self, relation: &str) {} } -impl Monitor for () {} +impl Monitor for () { + fn name(&self) -> &'static str { + "empty" + } +} macro_rules! monitor_observe_event { ($func:ident, ($($arg:ident),*), $elem:ident) => { @@ -66,9 +80,10 @@ macro_rules! impl_monitor { ( $($elem:ident),* ) => { impl<$($elem),*, Prov> Monitor for ($($elem,)*) where - $($elem: Monitor,)* + $($elem: Monitor + Clone,)* Prov: Provenance, { + fn name(&self) -> &'static str { "multiple" } monitor_observe_event!(observe_executing_stratum, ($($elem),*), (stratum_id: usize)); monitor_observe_event!(observe_stratum_iteration, ($($elem),*), (iteration_count: usize)); monitor_observe_event!(observe_hitting_iteration_limit, ($($elem),*), ()); @@ -77,8 +92,10 @@ macro_rules!
impl_monitor { monitor_observe_event!(observe_loading_relation_from_edb, ($($elem),*), (relation: &str)); monitor_observe_event!(observe_loading_relation_from_idb, ($($elem),*), (relation: &str)); monitor_observe_event!(observe_tagging, ($($elem),*), (tup: &Tuple, input_tag: &Option, tag: &Prov::Tag)); + monitor_observe_event!(observe_finish_execution, ($($elem),*), ()); monitor_observe_event!(observe_recovering_relation, ($($elem),*), (relation: &str)); monitor_observe_event!(observe_recover, ($($elem),*), (tup: &Tuple, tag: &Prov::Tag, output_tag: &Prov::OutputTag)); + monitor_observe_event!(observe_finish_recovering_relation, ($($elem),*), (relation: &str)); } } } diff --git a/core/src/runtime/monitor/registry.rs b/core/src/runtime/monitor/registry.rs new file mode 100644 index 0000000..2db2416 --- /dev/null +++ b/core/src/runtime/monitor/registry.rs @@ -0,0 +1,43 @@ +use std::collections::*; + +use crate::runtime::provenance::*; + +use super::*; + +pub struct MonitorRegistry { + monitors: HashMap>>, +} + +impl MonitorRegistry { + pub fn new() -> Self { + Self { + monitors: HashMap::new(), + } + } + + pub fn std() -> Self { + let mut m = Self::new(); + m.register(DebugRuntimeMonitor); + m.register(DebugTagsMonitor); + m.register(LoggingMonitor); + m.register(DumpProofsMonitor::new()); + m + } + + pub fn register>(&mut self, m: M) { + self.monitors.entry(m.name().to_string()).or_insert(Box::new(m)); + } + + pub fn get(&self, name: &str) -> Option<&Box>> { + self.monitors.get(name) + } + + pub fn load_monitors(&self, names: &[&str]) -> DynamicMonitors { + DynamicMonitors { + monitors: names + .iter() + .filter_map(|name| self.get(name).map(|m| dyn_clone::clone_box(&**m))) + .collect(), + } + } +} diff --git a/core/src/runtime/provenance/common/as_boolean_formula.rs b/core/src/runtime/provenance/common/as_boolean_formula.rs index 702f89f..6d6cf2e 100644 --- a/core/src/runtime/provenance/common/as_boolean_formula.rs +++ 
b/core/src/runtime/provenance/common/as_boolean_formula.rs @@ -1,5 +1,7 @@ use sdd::Semiring; +use super::Disjunctions; + pub trait AsBooleanFormula { /// Implement as boolean formula fn as_boolean_formula(&self) -> sdd::BooleanFormula; @@ -16,4 +18,49 @@ pub trait AsBooleanFormula { let sdd = sdd_builder.build(&formula); sdd.eval_t(v, s) } + + /// Performs weighted model counting (WMC) with the given disjunction information incorporated into the boolean formula; can be used to build an SDD + fn wmc_with_disjunctions(&self, s: &S, v: &V, disj: &Disjunctions) -> S::Element + where + S: Semiring, + V: Fn(&usize) -> S::Element, + { + let formula = self.as_boolean_formula(); + + // Adding disjunction as part of formula + let formula_with_disj = sdd::bf_disjunction( + std::iter::once(formula) + .chain( + disj + .disjunctions + .iter() + .map(|disj| { + sdd::bf_conjunction( + disj + .facts + .iter() + .map(|to_be_neg_fact_id| { + sdd::bf_disjunction( + disj + .facts + .iter() + .map(|fact_id| { + if fact_id == to_be_neg_fact_id { + sdd::bf_neg(fact_id.clone()) + } else { + sdd::bf_pos(fact_id.clone()) + } + }) + ) + }) + ) + }) + ) + ); + + let sdd_config = sdd::bottom_up::SDDBuilderConfig::with_formula(&formula_with_disj); + let sdd_builder = sdd::bottom_up::SDDBuilder::with_config(sdd_config); + let sdd = sdd_builder.build(&formula_with_disj); + sdd.eval_t(v, s) + } } diff --git a/core/src/runtime/provenance/common/disjunction.rs b/core/src/runtime/provenance/common/disjunction.rs index 08140bb..713cafa 100644 --- a/core/src/runtime/provenance/common/disjunction.rs +++ b/core/src/runtime/provenance/common/disjunction.rs @@ -3,7 +3,7 @@ pub use std::iter::FromIterator; #[derive(Clone, Debug)] pub struct Disjunction { - facts: BTreeSet, + pub facts: BTreeSet, } impl Disjunction { @@ -70,8 +70,8 @@ impl FromIterator for Disjunction { #[derive(Clone, Debug, Default)] pub struct Disjunctions { - id_map: HashMap, - disjunctions: Vec, + pub id_map: HashMap, + pub disjunctions: Vec, } impl Disjunctions { diff --git a/core/src/runtime/provenance/common/dnf_formula.rs
b/core/src/runtime/provenance/common/dnf_formula.rs index 14e13d8..0161843 100644 --- a/core/src/runtime/provenance/common/dnf_formula.rs +++ b/core/src/runtime/provenance/common/dnf_formula.rs @@ -13,6 +13,21 @@ impl DNFFormula { Self { clauses } } + pub fn get_singleton_id(&self) -> Option { + if self.clauses.len() == 1 { + if self.clauses[0].literals.len() == 1 { + match &self.clauses[0].literals[0] { + Literal::Pos(id) => Some(*id), + _ => None, + } + } else { + None + } + } else { + None + } + } + pub fn is_empty(&self) -> bool { self.clauses.is_empty() } diff --git a/core/src/runtime/provenance/common/input_tags/boolean.rs b/core/src/runtime/provenance/common/input_tags/boolean.rs index 6ed82b0..c3c360f 100644 --- a/core/src/runtime/provenance/common/input_tags/boolean.rs +++ b/core/src/runtime/provenance/common/input_tags/boolean.rs @@ -1,5 +1,5 @@ -use crate::common::input_tag::*; use crate::common::foreign_tensor::*; +use crate::common::input_tag::*; use super::*; diff --git a/core/src/runtime/provenance/common/input_tags/float.rs b/core/src/runtime/provenance/common/input_tags/float.rs index 48406ad..2625f80 100644 --- a/core/src/runtime/provenance/common/input_tags/float.rs +++ b/core/src/runtime/provenance/common/input_tags/float.rs @@ -1,5 +1,5 @@ -use crate::common::input_tag::*; use crate::common::foreign_tensor::*; +use crate::common::input_tag::*; use super::*; diff --git a/core/src/runtime/provenance/common/input_tags/input_diff_prob.rs b/core/src/runtime/provenance/common/input_tags/input_diff_prob.rs index 3ca1558..3424acd 100644 --- a/core/src/runtime/provenance/common/input_tags/input_diff_prob.rs +++ b/core/src/runtime/provenance/common/input_tags/input_diff_prob.rs @@ -1,5 +1,5 @@ -use crate::common::input_tag::*; use crate::common::foreign_tensor::*; +use crate::common::input_tag::*; use super::*; diff --git a/core/src/runtime/provenance/common/input_tags/input_exclusion.rs b/core/src/runtime/provenance/common/input_tags/input_exclusion.rs 
index 70e8f5a..16637e2 100644 --- a/core/src/runtime/provenance/common/input_tags/input_exclusion.rs +++ b/core/src/runtime/provenance/common/input_tags/input_exclusion.rs @@ -1,5 +1,5 @@ -use crate::common::input_tag::*; use crate::common::foreign_tensor::*; +use crate::common::input_tag::*; use super::*; diff --git a/core/src/runtime/provenance/common/input_tags/input_exclusive_diff_prob.rs b/core/src/runtime/provenance/common/input_tags/input_exclusive_diff_prob.rs index e38692a..2052e40 100644 --- a/core/src/runtime/provenance/common/input_tags/input_exclusive_diff_prob.rs +++ b/core/src/runtime/provenance/common/input_tags/input_exclusive_diff_prob.rs @@ -1,5 +1,5 @@ -use crate::common::input_tag::*; use crate::common::foreign_tensor::*; +use crate::common::input_tag::*; use super::*; diff --git a/core/src/runtime/provenance/common/input_tags/input_exclusive_prob.rs b/core/src/runtime/provenance/common/input_tags/input_exclusive_prob.rs index 4f321ed..57efe72 100644 --- a/core/src/runtime/provenance/common/input_tags/input_exclusive_prob.rs +++ b/core/src/runtime/provenance/common/input_tags/input_exclusive_prob.rs @@ -1,5 +1,5 @@ -use crate::common::input_tag::*; use crate::common::foreign_tensor::*; +use crate::common::input_tag::*; use super::*; diff --git a/core/src/runtime/provenance/common/input_tags/natural.rs b/core/src/runtime/provenance/common/input_tags/natural.rs index 949f573..18e32e2 100644 --- a/core/src/runtime/provenance/common/input_tags/natural.rs +++ b/core/src/runtime/provenance/common/input_tags/natural.rs @@ -1,5 +1,5 @@ -use crate::common::input_tag::*; use crate::common::foreign_tensor::*; +use crate::common::input_tag::*; use super::*; diff --git a/core/src/runtime/provenance/common/input_tags/unit.rs b/core/src/runtime/provenance/common/input_tags/unit.rs index e2bd466..95b692c 100644 --- a/core/src/runtime/provenance/common/input_tags/unit.rs +++ b/core/src/runtime/provenance/common/input_tags/unit.rs @@ -1,5 +1,5 @@ -use 
crate::common::input_tag::*; use crate::common::foreign_tensor::*; +use crate::common::input_tag::*; use super::*; diff --git a/core/src/runtime/provenance/differentiable/diff_add_mult_prob.rs b/core/src/runtime/provenance/differentiable/diff_add_mult_prob.rs index 193b409..76f9169 100644 --- a/core/src/runtime/provenance/differentiable/diff_add_mult_prob.rs +++ b/core/src/runtime/provenance/differentiable/diff_add_mult_prob.rs @@ -1,10 +1,6 @@ -use itertools::Itertools; - use super::*; use crate::common::element::*; use crate::common::foreign_tensor::*; -use crate::runtime::dynamic::*; -use crate::runtime::statics::*; use crate::utils::PointerFamily; pub struct DiffAddMultProbProvenance { @@ -120,72 +116,4 @@ impl Provenance for DiffAddMultProbProvenance f64 { t.real } - - fn dynamic_count(&self, batch: DynamicElements) -> DynamicElements { - let mut result = vec![]; - if batch.is_empty() { - result.push(DynamicElement::new(0usize, self.one())); - } else { - for chosen_set in (0..batch.len()).powerset() { - let count = chosen_set.len(); - let tag = self.tag_of_chosen_set(&batch, &chosen_set); - result.push(DynamicElement::new(count, tag)); - } - } - result - } - - fn dynamic_exists(&self, batch: DynamicElements) -> DynamicElements { - let mut max_prob = 0.0; - let mut max_info = None; - for elem in batch { - let prob = elem.tag.real; - if prob > max_prob { - max_prob = prob; - max_info = Some(elem.tag.clone()); - } - } - if let Some(tag) = max_info { - let f = DynamicElement::new(false, self.negate(&tag).unwrap()); - let t = DynamicElement::new(true, tag); - vec![f, t] - } else { - let e = DynamicElement::new(false, self.one()); - vec![e] - } - } - - fn static_count(&self, batch: StaticElements) -> StaticElements { - let mut result = vec![]; - if batch.is_empty() { - result.push(StaticElement::new(0usize, self.one())); - } else { - for chosen_set in (0..batch.len()).powerset() { - let count = chosen_set.len(); - let tag = self.tag_of_chosen_set(&batch, 
&chosen_set); - result.push(StaticElement::new(count, tag)); - } - } - result - } - - fn static_exists(&self, batch: StaticElements) -> StaticElements { - let mut max_prob = 0.0; - let mut max_info = None; - for elem in batch { - let prob = elem.tag.real; - if prob > max_prob { - max_prob = prob; - max_info = Some(elem.tag.clone()); - } - } - if let Some(tag) = max_info { - let f = StaticElement::new(false, self.negate(&tag).unwrap()); - let t = StaticElement::new(true, tag); - vec![f, t] - } else { - let e = StaticElement::new(false, self.one()); - vec![e] - } - } } diff --git a/core/src/runtime/provenance/differentiable/diff_max_mult_prob.rs b/core/src/runtime/provenance/differentiable/diff_max_mult_prob.rs index 9a60585..5f2bc9b 100644 --- a/core/src/runtime/provenance/differentiable/diff_max_mult_prob.rs +++ b/core/src/runtime/provenance/differentiable/diff_max_mult_prob.rs @@ -1,10 +1,6 @@ -use itertools::Itertools; - use super::*; use crate::common::element::*; use crate::common::foreign_tensor::*; -use crate::runtime::dynamic::*; -use crate::runtime::statics::*; use crate::utils::PointerFamily; pub struct DiffMaxMultProbProvenance { @@ -118,72 +114,4 @@ impl Provenance for DiffMaxMultProbProvenance f64 { t.real } - - fn dynamic_count(&self, batch: DynamicElements) -> DynamicElements { - let mut result = vec![]; - if batch.is_empty() { - result.push(DynamicElement::new(0usize, self.one())); - } else { - for chosen_set in (0..batch.len()).powerset() { - let count = chosen_set.len(); - let tag = self.tag_of_chosen_set(&batch, &chosen_set); - result.push(DynamicElement::new(count, tag)); - } - } - result - } - - fn dynamic_exists(&self, batch: DynamicElements) -> DynamicElements { - let mut max_prob = 0.0; - let mut max_info = None; - for elem in batch { - let prob = elem.tag.real; - if prob > max_prob { - max_prob = prob; - max_info = Some(elem.tag.clone()); - } - } - if let Some(tag) = max_info { - let f = DynamicElement::new(false, 
self.negate(&tag).unwrap()); - let t = DynamicElement::new(true, tag); - vec![f, t] - } else { - let e = DynamicElement::new(false, self.one()); - vec![e] - } - } - - fn static_count(&self, batch: StaticElements) -> StaticElements { - let mut result = vec![]; - if batch.is_empty() { - result.push(StaticElement::new(0usize, self.one())); - } else { - for chosen_set in (0..batch.len()).powerset() { - let count = chosen_set.len(); - let tag = self.tag_of_chosen_set(&batch, &chosen_set); - result.push(StaticElement::new(count, tag)); - } - } - result - } - - fn static_exists(&self, batch: StaticElements) -> StaticElements { - let mut max_prob = 0.0; - let mut max_info = None; - for elem in batch { - let prob = elem.tag.real; - if prob > max_prob { - max_prob = prob; - max_info = Some(elem.tag.clone()); - } - } - if let Some(tag) = max_info { - let f = StaticElement::new(false, self.negate(&tag).unwrap()); - let t = StaticElement::new(true, tag); - vec![f, t] - } else { - let e = StaticElement::new(false, self.one()); - vec![e] - } - } } diff --git a/core/src/runtime/provenance/differentiable/diff_min_max_prob.rs b/core/src/runtime/provenance/differentiable/diff_min_max_prob.rs index 2c138bc..8629c8e 100644 --- a/core/src/runtime/provenance/differentiable/diff_min_max_prob.rs +++ b/core/src/runtime/provenance/differentiable/diff_min_max_prob.rs @@ -1,13 +1,13 @@ -use itertools::Itertools; - -use super::*; use crate::common::element::*; +use crate::common::foreign_aggregate::*; +use crate::common::foreign_aggregates::*; use crate::common::foreign_tensor::*; -use crate::common::value_type::*; use crate::runtime::dynamic::*; -use crate::runtime::statics::*; +use crate::runtime::env::*; use crate::utils::*; +use super::*; + #[derive(Clone)] pub enum Derivative { Pos(usize), @@ -46,6 +46,26 @@ impl std::fmt::Debug for Prob { } } +impl std::cmp::PartialEq for Prob { + fn eq(&self, other: &Self) -> bool { + self.0 == other.0 + } +} + +impl std::cmp::Eq for Prob {} + +impl 
std::cmp::PartialOrd for Prob { + fn partial_cmp(&self, other: &Self) -> Option { + self.0.partial_cmp(&other.0) + } +} + +impl std::cmp::Ord for Prob { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.0.total_cmp(&other.0) + } +} + impl Tag for Prob {} pub struct DiffMinMaxProbProvenance { @@ -62,48 +82,6 @@ impl Clone for DiffMinMaxProbProvenance { } } -impl DiffMinMaxProbProvenance { - pub fn collect_chosen_elements<'a, E>(&self, all: &'a Vec, chosen_ids: &Vec) -> Vec<&'a E> - where - E: Element, - { - all - .iter() - .enumerate() - .filter(|(i, _)| chosen_ids.contains(i)) - .map(|(_, e)| e) - .collect::>() - } - - pub fn min_tag_of_chosen_set>(&self, all: &Vec, chosen_ids: &Vec) -> Prob { - all - .iter() - .enumerate() - .map(|(id, elem)| { - if chosen_ids.contains(&id) { - elem.tag().clone() - } else { - self.negate(elem.tag()).unwrap() - } - }) - .fold(self.one(), |a, b| self.mult(&a, &b)) - } - - fn max_min_prob_of_k_count>(&self, sorted_set: &Vec, k: usize) -> Prob { - sorted_set - .iter() - .enumerate() - .map(|(id, elem)| { - if id < k { - elem.tag().clone() - } else { - self.negate(elem.tag()).unwrap() - } - }) - .fold(self.one(), |a, b| self.mult(&a, &b)) - } -} - impl Default for DiffMinMaxProbProvenance { fn default() -> Self { Self { @@ -207,171 +185,49 @@ impl Provenance for DiffMinMaxProbProvenance f64 { t.0 } +} - fn dynamic_count(&self, mut batch: DynamicElements) -> DynamicElements { - if batch.is_empty() { - vec![DynamicElement::new(0usize, self.one())] - } else { - batch.sort_by(|a, b| b.tag.0.total_cmp(&a.tag.0)); - let mut elems = vec![]; - for k in 0..=batch.len() { - let prob = self.max_min_prob_of_k_count(&batch, k); - elems.push(DynamicElement::new(k, prob)); - } - elems - } - } - - fn dynamic_sum(&self, ty: &ValueType, batch: DynamicElements) -> DynamicElements { - let mut elems = vec![]; - for chosen_set in (0..batch.len()).powerset() { - let chosen_elements = self.collect_chosen_elements(&batch, &chosen_set); - let sum = 
ty.sum(chosen_elements.iter_tuples()); - let prob = self.min_tag_of_chosen_set(&batch, &chosen_set); - elems.push(DynamicElement::new(sum, prob)); - } - elems - } - - fn dynamic_prod(&self, ty: &ValueType, batch: DynamicElements) -> DynamicElements { - let mut elems = vec![]; - for chosen_set in (0..batch.len()).powerset() { - let chosen_elements = self.collect_chosen_elements(&batch, &chosen_set); - let sum = ty.prod(chosen_elements.iter_tuples()); - let prob = self.min_tag_of_chosen_set(&batch, &chosen_set); - elems.push(DynamicElement::new(sum, prob)); - } - elems - } - - fn dynamic_min(&self, batch: DynamicElements) -> DynamicElements { - let mut elems = vec![]; - for i in 0..batch.len() { - let min_elem = batch[i].tuple.clone(); - let mut agg_tag = self.one(); - for j in 0..i { - agg_tag = self.mult(&agg_tag, &self.negate(&batch[j].tag).unwrap()); - } - agg_tag = self.mult(&agg_tag, &batch[i].tag); - elems.push(DynamicElement::new(min_elem, agg_tag)); - } - elems - } - - fn dynamic_max(&self, batch: DynamicElements) -> DynamicElements { - let mut elems = vec![]; - for i in 0..batch.len() { - let max_elem = batch[i].tuple.clone(); - let mut agg_tag = batch[i].tag.clone(); - for j in i + 1..batch.len() { - agg_tag = self.mult(&agg_tag, &self.negate(&batch[j].tag).unwrap()); - } - elems.push(DynamicElement::new(max_elem, agg_tag)); - } - elems - } - - fn dynamic_exists(&self, batch: DynamicElements) -> DynamicElements { - let mut max_prob = 0.0; - let mut max_deriv = None; - for elem in batch { - let prob = elem.tag.0; - if prob > max_prob { - max_prob = prob; - max_deriv = Some(elem.tag.1); - } - } - if let Some(deriv) = max_deriv { - let f = DynamicElement::new(false, Self::Tag::new(1.0 - max_prob, deriv.negate())); - let t = DynamicElement::new(true, Self::Tag::new(max_prob, deriv)); - vec![f, t] - } else { - vec![DynamicElement::new(false, self.one())] - } - } - - fn static_count(&self, mut batch: StaticElements) -> StaticElements { - if batch.is_empty() { - 
vec![StaticElement::new(0usize, self.one())] - } else { - batch.sort_by(|a, b| b.tag.0.total_cmp(&a.tag.0)); - let mut elems = vec![]; - for k in 0..=batch.len() { - let prob = self.max_min_prob_of_k_count(&batch, k); - elems.push(StaticElement::new(k, prob)); - } - elems - } - } - - fn static_sum(&self, batch: StaticElements) -> StaticElements { - let mut elems = vec![]; - for chosen_set in (0..batch.len()).powerset() { - let chosen_elements = self.collect_chosen_elements(&batch, &chosen_set); - let sum = Tup::sum(chosen_elements.iter_tuples().cloned()); - let prob = self.min_tag_of_chosen_set(&batch, &chosen_set); - elems.push(StaticElement::new(sum, prob)); - } - elems - } - - fn static_prod( +impl Aggregator> for CountAggregator { + fn aggregate( &self, - batch: StaticElements, - ) -> StaticElements { - let mut elems = vec![]; - for chosen_set in (0..batch.len()).powerset() { - let chosen_elements = self.collect_chosen_elements(&batch, &chosen_set); - let prod = Tup::prod(chosen_elements.iter_tuples().cloned()); - let prob = self.min_tag_of_chosen_set(&batch, &chosen_set); - elems.push(StaticElement::new(prod, prob)); - } - elems - } - - fn static_min(&self, batch: StaticElements) -> StaticElements { - let mut elems = vec![]; - for i in 0..batch.len() { - let min_elem = batch[i].tuple.get().clone(); - let mut agg_tag = self.one(); - for j in 0..i { - agg_tag = self.mult(&agg_tag, &self.negate(&batch[j].tag).unwrap()); - } - agg_tag = self.mult(&agg_tag, &batch[i].tag); - elems.push(StaticElement::new(min_elem, agg_tag)); - } - elems - } - - fn static_max(&self, batch: StaticElements) -> StaticElements { - let mut elems = vec![]; - for i in 0..batch.len() { - let max_elem = batch[i].tuple.get().clone(); - let mut agg_tag = batch[i].tag.clone(); - for j in i + 1..batch.len() { - agg_tag = self.mult(&agg_tag, &self.negate(&batch[j].tag).unwrap()); + p: &DiffMinMaxProbProvenance, + _env: &RuntimeEnvironment, + mut batch: DynamicElements>, + ) -> DynamicElements> { 
+ if self.non_multi_world { + vec![DynamicElement::new(batch.len(), p.one())] + } else { + if batch.is_empty() { + vec![DynamicElement::new(0usize, p.one())] + } else { + batch.sort_by(|a, b| b.tag.0.total_cmp(&a.tag.0)); + let mut elems = vec![]; + for k in 0..=batch.len() { + let prob = max_min_prob_of_k_count(&batch, k); + elems.push(DynamicElement::new(k, prob)); + } + elems } - elems.push(StaticElement::new(max_elem, agg_tag)); } - elems } +} - fn static_exists(&self, batch: StaticElements) -> StaticElements { - let mut max_prob = 0.0; - let mut max_deriv = None; - for elem in batch { - let prob = elem.tag.0; - if prob > max_prob { - max_prob = prob; - max_deriv = Some(elem.tag.1); +fn max_min_prob_of_k_count(sorted_set: &Vec, k: usize) -> Prob +where + T: FromTensor, + P: PointerFamily, + E: Element>, +{ + let prob = sorted_set + .iter() + .enumerate() + .map(|(id, elem)| { + if id < k { + elem.tag().clone() + } else { + Prob(1.0 - elem.tag().0, elem.tag().1.negate()) } - } - if let Some(deriv) = max_deriv { - let f = StaticElement::new(false, Self::Tag::new(1.0 - max_prob, deriv.negate())); - let t = StaticElement::new(true, Self::Tag::new(max_prob, deriv)); - vec![f, t] - } else { - vec![StaticElement::new(false, self.one())] - } - } + }) + .fold(Prob(f64::INFINITY, Derivative::Zero), |a, b| a.min(b)); + prob.into() } diff --git a/core/src/runtime/provenance/differentiable/diff_nand_min_prob.rs b/core/src/runtime/provenance/differentiable/diff_nand_min_prob.rs index 5886809..b457d80 100644 --- a/core/src/runtime/provenance/differentiable/diff_nand_min_prob.rs +++ b/core/src/runtime/provenance/differentiable/diff_nand_min_prob.rs @@ -1,10 +1,6 @@ -use itertools::Itertools; - use super::*; use crate::common::element::*; use crate::common::foreign_tensor::*; -use crate::runtime::dynamic::*; -use crate::runtime::statics::*; use crate::utils::PointerFamily; pub struct DiffNandMinProbProvenance { @@ -118,72 +114,4 @@ impl Provenance for DiffNandMinProbProvenance 
f64 { t.real } - - fn dynamic_count(&self, batch: DynamicElements) -> DynamicElements { - let mut result = vec![]; - if batch.is_empty() { - result.push(DynamicElement::new(0usize, self.one())); - } else { - for chosen_set in (0..batch.len()).powerset() { - let count = chosen_set.len(); - let tag = self.tag_of_chosen_set(&batch, &chosen_set); - result.push(DynamicElement::new(count, tag)); - } - } - result - } - - fn dynamic_exists(&self, batch: DynamicElements) -> DynamicElements { - let mut max_prob = 0.0; - let mut max_info = None; - for elem in batch { - let prob = elem.tag.real; - if prob > max_prob { - max_prob = prob; - max_info = Some(elem.tag.clone()); - } - } - if let Some(tag) = max_info { - let f = DynamicElement::new(false, self.negate(&tag).unwrap()); - let t = DynamicElement::new(true, tag); - vec![f, t] - } else { - let e = DynamicElement::new(false, self.one()); - vec![e] - } - } - - fn static_count(&self, batch: StaticElements) -> StaticElements { - let mut result = vec![]; - if batch.is_empty() { - result.push(StaticElement::new(0usize, self.one())); - } else { - for chosen_set in (0..batch.len()).powerset() { - let count = chosen_set.len(); - let tag = self.tag_of_chosen_set(&batch, &chosen_set); - result.push(StaticElement::new(count, tag)); - } - } - result - } - - fn static_exists(&self, batch: StaticElements) -> StaticElements { - let mut max_prob = 0.0; - let mut max_info = None; - for elem in batch { - let prob = elem.tag.real; - if prob > max_prob { - max_prob = prob; - max_info = Some(elem.tag.clone()); - } - } - if let Some(tag) = max_info { - let f = StaticElement::new(false, self.negate(&tag).unwrap()); - let t = StaticElement::new(true, tag); - vec![f, t] - } else { - let e = StaticElement::new(false, self.one()); - vec![e] - } - } } diff --git a/core/src/runtime/provenance/differentiable/diff_nand_mult_prob.rs b/core/src/runtime/provenance/differentiable/diff_nand_mult_prob.rs index ba12278..a40e024 100644 --- 
a/core/src/runtime/provenance/differentiable/diff_nand_mult_prob.rs +++ b/core/src/runtime/provenance/differentiable/diff_nand_mult_prob.rs @@ -1,10 +1,6 @@ -use itertools::Itertools; - use super::*; use crate::common::element::*; use crate::common::foreign_tensor::*; -use crate::runtime::dynamic::*; -use crate::runtime::statics::*; use crate::utils::PointerFamily; pub struct DiffNandMultProbProvenance { @@ -118,72 +114,4 @@ impl Provenance for DiffNandMultProbProvenance< fn weight(&self, t: &Self::Tag) -> f64 { t.real } - - fn dynamic_count(&self, batch: DynamicElements) -> DynamicElements { - let mut result = vec![]; - if batch.is_empty() { - result.push(DynamicElement::new(0usize, self.one())); - } else { - for chosen_set in (0..batch.len()).powerset() { - let count = chosen_set.len(); - let tag = self.tag_of_chosen_set(&batch, &chosen_set); - result.push(DynamicElement::new(count, tag)); - } - } - result - } - - fn dynamic_exists(&self, batch: DynamicElements) -> DynamicElements { - let mut max_prob = 0.0; - let mut max_info = None; - for elem in batch { - let prob = elem.tag.real; - if prob > max_prob { - max_prob = prob; - max_info = Some(elem.tag.clone()); - } - } - if let Some(tag) = max_info { - let f = DynamicElement::new(false, self.negate(&tag).unwrap()); - let t = DynamicElement::new(true, tag); - vec![f, t] - } else { - let e = DynamicElement::new(false, self.one()); - vec![e] - } - } - - fn static_count(&self, batch: StaticElements) -> StaticElements { - let mut result = vec![]; - if batch.is_empty() { - result.push(StaticElement::new(0usize, self.one())); - } else { - for chosen_set in (0..batch.len()).powerset() { - let count = chosen_set.len(); - let tag = self.tag_of_chosen_set(&batch, &chosen_set); - result.push(StaticElement::new(count, tag)); - } - } - result - } - - fn static_exists(&self, batch: StaticElements) -> StaticElements { - let mut max_prob = 0.0; - let mut max_info = None; - for elem in batch { - let prob = elem.tag.real; - if prob 
> max_prob { - max_prob = prob; - max_info = Some(elem.tag.clone()); - } - } - if let Some(tag) = max_info { - let f = StaticElement::new(false, self.negate(&tag).unwrap()); - let t = StaticElement::new(true, tag); - vec![f, t] - } else { - let e = StaticElement::new(false, self.one()); - vec![e] - } - } } diff --git a/core/src/runtime/provenance/differentiable/diff_top_bottom_k_clauses.rs b/core/src/runtime/provenance/differentiable/diff_top_bottom_k_clauses.rs index e92fcf4..473e148 100644 --- a/core/src/runtime/provenance/differentiable/diff_top_bottom_k_clauses.rs +++ b/core/src/runtime/provenance/differentiable/diff_top_bottom_k_clauses.rs @@ -1,10 +1,6 @@ use std::collections::*; -use itertools::Itertools; - use crate::common::foreign_tensor::*; -use crate::runtime::dynamic::*; -use crate::runtime::statics::*; use crate::utils::*; use super::*; @@ -13,6 +9,7 @@ pub struct DiffTopBottomKClausesProvenance, pub disjunctions: P::Cell, + pub wmc_with_disjunctions: bool, } impl Clone for DiffTopBottomKClausesProvenance { @@ -21,16 +18,18 @@ impl Clone for DiffTopBottomKClausesProvenance< k: self.k, storage: self.storage.clone_internal(), disjunctions: P::clone_cell(&self.disjunctions), + wmc_with_disjunctions: self.wmc_with_disjunctions, } } } impl DiffTopBottomKClausesProvenance { - pub fn new(k: usize) -> Self { + pub fn new(k: usize, wmc_with_disjunctions: bool) -> Self { Self { k, storage: DiffProbStorage::new(), disjunctions: P::new_cell(Disjunctions::new()), + wmc_with_disjunctions, } } @@ -97,7 +96,13 @@ impl Provenance for DiffTopBottomKClausesProven s.constant(real.clone()) } }; - let wmc_result = t.wmc(&s, &v); + let wmc_result = if self.wmc_with_disjunctions { + P::get_cell(&self.disjunctions, |disj| { + t.wmc_with_disjunctions(&s, &v, disj) + }) + } else { + t.wmc(&s, &v) + }; let prob = wmc_result.real; let deriv = wmc_result .deriv @@ -139,110 +144,4 @@ impl Provenance for DiffTopBottomKClausesProven let v = |i: &usize| self.storage.get_prob(i); 
t.wmc(&RealSemiring::new(), &v) } - - fn dynamic_count(&self, batch: DynamicElements) -> DynamicElements { - if batch.is_empty() { - vec![DynamicElement::new(0usize, self.one())] - } else { - let mut elems = vec![]; - for chosen_set in (0..batch.len()).powerset() { - let count = chosen_set.len(); - let tag = self.top_bottom_k_tag_of_chosen_set(batch.iter().map(|e| &e.tag), &chosen_set, self.k); - elems.push(DynamicElement::new(count, tag)); - } - elems - } - } - - fn dynamic_min(&self, batch: DynamicElements) -> DynamicElements { - let mut elems = vec![]; - for i in 0..batch.len() { - let min_elem = batch[i].tuple.clone(); - let mut agg_tag = self.one(); - for j in 0..i { - agg_tag = self.mult(&agg_tag, &self.negate(&batch[j].tag).unwrap()); - } - agg_tag = self.mult(&agg_tag, &batch[i].tag); - elems.push(DynamicElement::new(min_elem, agg_tag)); - } - elems - } - - fn dynamic_max(&self, batch: DynamicElements) -> DynamicElements { - let mut elems = vec![]; - for i in 0..batch.len() { - let max_elem = batch[i].tuple.clone(); - let mut agg_tag = batch[i].tag.clone(); - for j in i + 1..batch.len() { - agg_tag = self.mult(&agg_tag, &self.negate(&batch[j].tag).unwrap()); - } - elems.push(DynamicElement::new(max_elem, agg_tag)); - } - elems - } - - fn dynamic_exists(&self, batch: DynamicElements) -> DynamicElements { - let mut exists_tag = self.zero(); - let mut not_exists_tag = self.one(); - for elem in batch { - exists_tag = self.add(&exists_tag, &elem.tag); - not_exists_tag = self.mult(¬_exists_tag, &self.negate(&elem.tag).unwrap()); - } - let t = DynamicElement::new(true, exists_tag); - let f = DynamicElement::new(false, not_exists_tag); - vec![t, f] - } - - fn static_count(&self, batch: StaticElements) -> StaticElements { - if batch.is_empty() { - vec![StaticElement::new(0, self.one())] - } else { - let mut elems = vec![]; - for chosen_set in (0..batch.len()).powerset() { - let count = chosen_set.len(); - let tag = 
self.top_bottom_k_tag_of_chosen_set(batch.iter().map(|e| &e.tag), &chosen_set, self.k); - elems.push(StaticElement::new(count, tag)); - } - elems - } - } - - fn static_min(&self, batch: StaticElements) -> StaticElements { - let mut elems = vec![]; - for i in 0..batch.len() { - let min_elem = batch[i].tuple.get().clone(); - let mut agg_tag = self.one(); - for j in 0..i { - agg_tag = self.mult(&agg_tag, &self.negate(&batch[j].tag).unwrap()); - } - agg_tag = self.mult(&agg_tag, &batch[i].tag); - elems.push(StaticElement::new(min_elem, agg_tag)); - } - elems - } - - fn static_max(&self, batch: StaticElements) -> StaticElements { - let mut elems = vec![]; - for i in 0..batch.len() { - let max_elem = batch[i].tuple.get().clone(); - let mut agg_tag = batch[i].tag.clone(); - for j in i + 1..batch.len() { - agg_tag = self.mult(&agg_tag, &self.negate(&batch[j].tag).unwrap()); - } - elems.push(StaticElement::new(max_elem, agg_tag)); - } - elems - } - - fn static_exists(&self, batch: StaticElements) -> StaticElements { - let mut exists_tag = self.zero(); - let mut not_exists_tag = self.one(); - for elem in batch { - exists_tag = self.add(&exists_tag, &elem.tag); - not_exists_tag = self.mult(¬_exists_tag, &self.negate(&elem.tag).unwrap()); - } - let t = StaticElement::new(true, exists_tag); - let f = StaticElement::new(false, not_exists_tag); - vec![t, f] - } } diff --git a/core/src/runtime/provenance/differentiable/diff_top_k_proofs.rs b/core/src/runtime/provenance/differentiable/diff_top_k_proofs.rs index b0aafe3..d9984cc 100644 --- a/core/src/runtime/provenance/differentiable/diff_top_k_proofs.rs +++ b/core/src/runtime/provenance/differentiable/diff_top_k_proofs.rs @@ -1,8 +1,4 @@ -use itertools::Itertools; - use crate::common::foreign_tensor::*; -use crate::runtime::dynamic::*; -use crate::runtime::statics::*; use crate::utils::*; use super::*; @@ -11,6 +7,7 @@ pub struct DiffTopKProofsProvenance { pub k: usize, pub storage: DiffProbStorage, pub disjunctions: P::Cell, + pub 
wmc_with_disjunctions: bool, } impl Clone for DiffTopKProofsProvenance { @@ -19,16 +16,18 @@ impl Clone for DiffTopKProofsProvenance { k: self.k, storage: self.storage.clone_internal(), disjunctions: P::clone_cell(&self.disjunctions), + wmc_with_disjunctions: self.wmc_with_disjunctions, } } } impl DiffTopKProofsProvenance { - pub fn new(k: usize) -> Self { + pub fn new(k: usize, wmc_with_disjunctions: bool) -> Self { Self { k, storage: DiffProbStorage::new(), disjunctions: P::new_cell(Disjunctions::new()), + wmc_with_disjunctions, } } @@ -95,7 +94,13 @@ impl Provenance for DiffTopKProofsProvenance Provenance for DiffTopKProofsProvenance) -> DynamicElements { - if batch.is_empty() { - vec![DynamicElement::new(0usize, self.one())] - } else { - let mut elems = vec![]; - for chosen_set in (0..batch.len()).powerset() { - let count = chosen_set.len(); - let tag = self.top_k_tag_of_chosen_set(batch.iter().map(|e| &e.tag), &chosen_set, self.k); - elems.push(DynamicElement::new(count, tag)); - } - elems - } - } - - fn dynamic_min(&self, batch: DynamicElements) -> DynamicElements { - let mut elems = vec![]; - for i in 0..batch.len() { - let min_elem = batch[i].tuple.clone(); - let mut agg_tag = self.one(); - for j in 0..i { - agg_tag = self.mult(&agg_tag, &self.negate(&batch[j].tag).unwrap()); - } - agg_tag = self.mult(&agg_tag, &batch[i].tag); - elems.push(DynamicElement::new(min_elem, agg_tag)); - } - elems - } - - fn dynamic_max(&self, batch: DynamicElements) -> DynamicElements { - let mut elems = vec![]; - for i in 0..batch.len() { - let max_elem = batch[i].tuple.clone(); - let mut agg_tag = batch[i].tag.clone(); - for j in i + 1..batch.len() { - agg_tag = self.mult(&agg_tag, &self.negate(&batch[j].tag).unwrap()); - } - elems.push(DynamicElement::new(max_elem, agg_tag)); - } - elems - } - - fn dynamic_exists(&self, batch: DynamicElements) -> DynamicElements { - let mut exists_tag = self.zero(); - let mut not_exists_tag = self.one(); - for elem in batch { - exists_tag = 
self.add(&exists_tag, &elem.tag); - not_exists_tag = self.mult(¬_exists_tag, &self.negate(&elem.tag).unwrap()); - } - let t = DynamicElement::new(true, exists_tag); - let f = DynamicElement::new(false, not_exists_tag); - vec![t, f] - } - - fn static_count(&self, batch: StaticElements) -> StaticElements { - if batch.is_empty() { - vec![StaticElement::new(0, self.one())] - } else { - let mut elems = vec![]; - for chosen_set in (0..batch.len()).powerset() { - let count = chosen_set.len(); - let tag = self.top_k_tag_of_chosen_set(batch.iter().map(|e| &e.tag), &chosen_set, self.k); - elems.push(StaticElement::new(count, tag)); - } - elems - } - } - - fn static_min(&self, batch: StaticElements) -> StaticElements { - let mut elems = vec![]; - for i in 0..batch.len() { - let min_elem = batch[i].tuple.get().clone(); - let mut agg_tag = self.one(); - for j in 0..i { - agg_tag = self.mult(&agg_tag, &self.negate(&batch[j].tag).unwrap()); - } - agg_tag = self.mult(&agg_tag, &batch[i].tag); - elems.push(StaticElement::new(min_elem, agg_tag)); - } - elems - } - - fn static_max(&self, batch: StaticElements) -> StaticElements { - let mut elems = vec![]; - for i in 0..batch.len() { - let max_elem = batch[i].tuple.get().clone(); - let mut agg_tag = batch[i].tag.clone(); - for j in i + 1..batch.len() { - agg_tag = self.mult(&agg_tag, &self.negate(&batch[j].tag).unwrap()); - } - elems.push(StaticElement::new(max_elem, agg_tag)); - } - elems - } - - fn static_exists(&self, batch: StaticElements) -> StaticElements { - let mut exists_tag = self.zero(); - let mut not_exists_tag = self.one(); - for elem in batch { - exists_tag = self.add(&exists_tag, &elem.tag); - not_exists_tag = self.mult(¬_exists_tag, &self.negate(&elem.tag).unwrap()); - } - let t = StaticElement::new(true, exists_tag); - let f = StaticElement::new(false, not_exists_tag); - vec![t, f] - } } diff --git a/core/src/runtime/provenance/discrete/boolean.rs b/core/src/runtime/provenance/discrete/boolean.rs index 
e9ed8e4..da085a4 100644 --- a/core/src/runtime/provenance/discrete/boolean.rs +++ b/core/src/runtime/provenance/discrete/boolean.rs @@ -1,6 +1,9 @@ -use super::*; +use crate::common::foreign_aggregate::*; +use crate::common::foreign_aggregates::*; use crate::runtime::dynamic::*; -use crate::runtime::statics::*; +use crate::runtime::env::*; + +use super::*; pub type Boolean = bool; @@ -53,26 +56,53 @@ impl Provenance for BooleanProvenance { fn saturated(&self, t_old: &Self::Tag, t_new: &Self::Tag) -> bool { t_old == t_new } +} - fn dynamic_count(&self, batch: DynamicElements) -> DynamicElements { - let count = batch - .into_iter() - .fold(0usize, |acc, e| if e.tag { acc + 1 } else { acc }); - vec![DynamicElement::new(count, self.one())] +impl Aggregator for CountAggregator { + fn aggregate( + &self, + _p: &BooleanProvenance, + _env: &RuntimeEnvironment, + elems: DynamicElements, + ) -> DynamicElements { + let cnt = elems.iter().fold(0usize, |c, e| if e.tag { c + 1 } else { c }); + vec![DynamicElement::new(cnt, true)] } +} - fn static_count(&self, batch: StaticElements) -> StaticElements { - let count = batch - .into_iter() - .fold(0usize, |acc, e| if e.tag { acc + 1 } else { acc }); - vec![StaticElement::new(count, self.one())] +impl Aggregator for ExistsAggregator { + fn aggregate( + &self, + _p: &BooleanProvenance, + _env: &RuntimeEnvironment, + elems: DynamicElements, + ) -> DynamicElements { + let exist = elems.iter().any(|e| e.tag); + vec![DynamicElement::new(exist, true)] } +} - fn dynamic_top_k(&self, k: usize, batch: DynamicElements) -> DynamicElements { - unweighted_aggregate_top_k_helper(batch, k) +impl Aggregator for MinMaxAggregator { + fn aggregate( + &self, + p: &BooleanProvenance, + _env: &RuntimeEnvironment, + batch: DynamicElements, + ) -> DynamicElements { + let elems = batch.iter().filter_map(|e| if e.tag { Some(&e.tuple) } else { None }); + self.discrete_min_max(p, elems) } +} - fn static_top_k(&self, k: usize, batch: StaticElements) -> 
StaticElements { - unweighted_aggregate_top_k_helper(batch, k) +impl Aggregator for SumProdAggregator { + fn aggregate( + &self, + _p: &BooleanProvenance, + _env: &RuntimeEnvironment, + batch: DynamicElements, + ) -> DynamicElements { + let elems = batch.iter().filter_map(|e| if e.tag { Some(&e.tuple) } else { None }); + let res = self.perform_sum_prod(elems); + vec![DynamicElement::new(res, true)] } } diff --git a/core/src/runtime/provenance/discrete/natural.rs b/core/src/runtime/provenance/discrete/natural.rs index 8062262..497ac7a 100644 --- a/core/src/runtime/provenance/discrete/natural.rs +++ b/core/src/runtime/provenance/discrete/natural.rs @@ -1,6 +1,4 @@ use super::*; -use crate::runtime::dynamic::*; -use crate::runtime::statics::*; pub type Natural = usize; @@ -49,26 +47,4 @@ impl Provenance for NaturalProvenance { fn saturated(&self, _: &Self::Tag, _: &Self::Tag) -> bool { true } - - fn dynamic_count(&self, batch: DynamicElements) -> DynamicElements { - let count = batch - .into_iter() - .fold(0usize, |acc, e| if e.tag > 0 { acc + 1 } else { acc }); - vec![DynamicElement::new(count, self.one())] - } - - fn static_count(&self, batch: StaticElements) -> StaticElements { - let count = batch - .into_iter() - .fold(0usize, |acc, e| if e.tag > 0 { acc + 1 } else { acc }); - vec![StaticElement::new(count, self.one())] - } - - fn dynamic_top_k(&self, k: usize, batch: DynamicElements) -> DynamicElements { - unweighted_aggregate_top_k_helper(batch, k) - } - - fn static_top_k(&self, k: usize, batch: StaticElements) -> StaticElements { - unweighted_aggregate_top_k_helper(batch, k) - } } diff --git a/core/src/runtime/provenance/discrete/proofs.rs b/core/src/runtime/provenance/discrete/proofs.rs index 62aab2a..674ba95 100644 --- a/core/src/runtime/provenance/discrete/proofs.rs +++ b/core/src/runtime/provenance/discrete/proofs.rs @@ -2,8 +2,6 @@ use std::collections::*; use itertools::iproduct; -use crate::runtime::dynamic::*; -use crate::runtime::statics::*; use 
crate::utils::*; use super::*; @@ -215,12 +213,4 @@ impl Provenance for ProofsProvenance

{ fn saturated(&self, t_old: &Self::Tag, t_new: &Self::Tag) -> bool { t_old == t_new } - - fn dynamic_top_k(&self, k: usize, batch: DynamicElements) -> DynamicElements { - unweighted_aggregate_top_k_helper(batch, k) - } - - fn static_top_k(&self, k: usize, batch: StaticElements) -> StaticElements { - unweighted_aggregate_top_k_helper(batch, k) - } } diff --git a/core/src/runtime/provenance/discrete/unit.rs b/core/src/runtime/provenance/discrete/unit.rs index e172d36..597034f 100644 --- a/core/src/runtime/provenance/discrete/unit.rs +++ b/core/src/runtime/provenance/discrete/unit.rs @@ -1,5 +1,7 @@ +use crate::common::foreign_aggregate::*; +use crate::common::foreign_aggregates::*; use crate::runtime::dynamic::*; -use crate::runtime::statics::*; +use crate::runtime::env::*; use super::*; @@ -69,12 +71,49 @@ impl Provenance for UnitProvenance { fn saturated(&self, _: &Self::Tag, _: &Self::Tag) -> bool { true } +} + +impl Aggregator for CountAggregator { + fn aggregate( + &self, + _p: &UnitProvenance, + _env: &RuntimeEnvironment, + elems: DynamicElements, + ) -> DynamicElements { + vec![DynamicElement::new(elems.len(), Unit)] + } +} - fn dynamic_top_k(&self, k: usize, batch: DynamicElements) -> DynamicElements { - unweighted_aggregate_top_k_helper(batch, k) +impl Aggregator for ExistsAggregator { + fn aggregate( + &self, + _p: &UnitProvenance, + _env: &RuntimeEnvironment, + elems: DynamicElements, + ) -> DynamicElements { + vec![DynamicElement::new(!elems.is_empty(), Unit)] } +} + +impl Aggregator for MinMaxAggregator { + fn aggregate( + &self, + p: &UnitProvenance, + _env: &RuntimeEnvironment, + batch: DynamicElements, + ) -> DynamicElements { + self.discrete_min_max(p, batch.iter_tuples()) + } +} - fn static_top_k(&self, k: usize, batch: StaticElements) -> StaticElements { - unweighted_aggregate_top_k_helper(batch, k) +impl Aggregator for SumProdAggregator { + fn aggregate( + &self, + _p: &UnitProvenance, + _env: &RuntimeEnvironment, + batch: DynamicElements, + ) -> 
DynamicElements { + let res = self.perform_sum_prod(batch.iter_tuples()); + vec![DynamicElement::new(res, Unit)] } } diff --git a/core/src/runtime/provenance/probabilistic/add_mult_prob.rs b/core/src/runtime/provenance/probabilistic/add_mult_prob.rs index f8be209..64c6ac0 100644 --- a/core/src/runtime/provenance/probabilistic/add_mult_prob.rs +++ b/core/src/runtime/provenance/probabilistic/add_mult_prob.rs @@ -1,9 +1,4 @@ -use itertools::Itertools; - use super::*; -use crate::common::element::*; -use crate::runtime::dynamic::*; -use crate::runtime::statics::*; #[derive(Clone, Debug)] pub struct AddMultProbProvenance { @@ -11,20 +6,6 @@ pub struct AddMultProbProvenance { } impl AddMultProbProvenance { - fn tag_of_chosen_set>(&self, all: &Vec, chosen_ids: &Vec) -> f64 { - all - .iter() - .enumerate() - .map(|(id, elem)| { - if chosen_ids.contains(&id) { - elem.tag().clone() - } else { - self.negate(elem.tag()).unwrap() - } - }) - .fold(self.one(), |a, b| self.mult(&a, &b)) - } - /// The soft comparison between two probabilities /// /// This function is commonly used for testing purpose @@ -91,72 +72,4 @@ impl Provenance for AddMultProbProvenance { fn weight(&self, t: &Self::Tag) -> f64 { *t as f64 } - - fn dynamic_count(&self, batch: DynamicElements) -> DynamicElements { - if batch.is_empty() { - vec![DynamicElement::new(0usize, self.one())] - } else { - let mut result = vec![]; - for chosen_set in (0..batch.len()).powerset() { - let count = chosen_set.len(); - let tag = self.tag_of_chosen_set(&batch, &chosen_set); - result.push(DynamicElement::new(count, tag)); - } - result - } - } - - fn dynamic_exists(&self, batch: DynamicElements) -> DynamicElements { - let mut max_prob = 0.0; - let mut max_info = None; - for elem in batch { - let prob = elem.tag; - if prob > max_prob { - max_prob = prob; - max_info = Some(elem.tag.clone()); - } - } - if let Some(tag) = max_info { - let f = DynamicElement::new(false, self.negate(&tag).unwrap()); - let t = 
DynamicElement::new(true, tag); - vec![f, t] - } else { - let e = DynamicElement::new(false, self.one()); - vec![e] - } - } - - fn static_count(&self, batch: StaticElements) -> StaticElements { - let mut result = vec![]; - if batch.is_empty() { - result.push(StaticElement::new(0usize, self.one())); - } else { - for chosen_set in (0..batch.len()).powerset() { - let count = chosen_set.len(); - let tag = self.tag_of_chosen_set(&batch, &chosen_set); - result.push(StaticElement::new(count, tag)); - } - } - result - } - - fn static_exists(&self, batch: StaticElements) -> StaticElements { - let mut max_prob = 0.0; - let mut max_info = None; - for elem in batch { - let prob = elem.tag; - if prob > max_prob { - max_prob = prob; - max_info = Some(elem.tag.clone()); - } - } - if let Some(tag) = max_info { - let f = StaticElement::new(false, self.negate(&tag).unwrap()); - let t = StaticElement::new(true, tag); - vec![f, t] - } else { - let e = StaticElement::new(false, self.one()); - vec![e] - } - } } diff --git a/core/src/runtime/provenance/probabilistic/min_max_prob.rs b/core/src/runtime/provenance/probabilistic/min_max_prob.rs index 5f0a7a5..aa89995 100644 --- a/core/src/runtime/provenance/probabilistic/min_max_prob.rs +++ b/core/src/runtime/provenance/probabilistic/min_max_prob.rs @@ -1,9 +1,9 @@ -use itertools::Itertools; - use crate::common::element::*; -use crate::common::value_type::*; +use crate::common::foreign_aggregate::*; +use crate::common::foreign_aggregates::*; + use crate::runtime::dynamic::*; -use crate::runtime::statics::*; +use crate::runtime::env::*; use super::*; @@ -80,185 +80,31 @@ impl Provenance for MinMaxProbProvenance { fn weight(&self, t: &Self::Tag) -> f64 { *t } +} - fn dynamic_count(&self, mut batch: DynamicElements) -> DynamicElements { - if batch.is_empty() { - vec![DynamicElement::new(0usize, self.one())] - } else { - batch.sort_by(|a, b| b.tag.total_cmp(&a.tag)); - let mut elems = vec![]; - for k in 0..=batch.len() { - let prob = 
max_min_prob_of_k_count(&batch, k); - elems.push(DynamicElement::new(k, prob)); - } - elems - } - } - - fn dynamic_sum(&self, ty: &ValueType, batch: DynamicElements) -> DynamicElements { - let mut elems = vec![]; - for chosen_set in (0..batch.len()).powerset() { - let chosen_elements = collect_chosen_elements(&batch, &chosen_set); - let sum = ty.sum(chosen_elements.iter_tuples()); - let prob = min_prob_of_chosen_set(&batch, &chosen_set); - elems.push(DynamicElement::new(sum, prob)); - } - elems - } - - fn dynamic_prod(&self, ty: &ValueType, batch: DynamicElements) -> DynamicElements { - let mut elems = vec![]; - for chosen_set in (0..batch.len()).powerset() { - let chosen_elements = collect_chosen_elements(&batch, &chosen_set); - let sum = ty.prod(chosen_elements.iter_tuples()); - let prob = min_prob_of_chosen_set(&batch, &chosen_set); - elems.push(DynamicElement::new(sum, prob)); - } - elems - } - - fn dynamic_min(&self, batch: DynamicElements) -> DynamicElements { - let mut elems = vec![]; - for i in 0..batch.len() { - let min_elem = batch[i].tuple.clone(); - let mut agg_tag = self.one(); - for j in 0..i { - agg_tag = self.mult(&agg_tag, &self.negate(&batch[j].tag).unwrap()); - } - agg_tag = self.mult(&agg_tag, &batch[i].tag); - elems.push(DynamicElement::new(min_elem, agg_tag)); - } - elems - } - - fn dynamic_max(&self, batch: DynamicElements) -> DynamicElements { - let mut elems = vec![]; - for i in 0..batch.len() { - let max_elem = batch[i].tuple.clone(); - let mut agg_tag = batch[i].tag.clone(); - for j in i + 1..batch.len() { - agg_tag = self.mult(&agg_tag, &self.negate(&batch[j].tag).unwrap()); - } - elems.push(DynamicElement::new(max_elem, agg_tag)); - } - elems - } - - fn dynamic_exists(&self, batch: DynamicElements) -> DynamicElements { - let mut exists_tag = self.zero(); - let mut not_exists_tag = self.one(); - for elem in batch { - exists_tag = self.add(&exists_tag, &elem.tag); - not_exists_tag = self.mult(¬_exists_tag, 
&self.negate(&elem.tag).unwrap()); - } - vec![ - DynamicElement::new(true, exists_tag), - DynamicElement::new(false, not_exists_tag), - ] - } - - fn static_count(&self, mut batch: StaticElements) -> StaticElements { - if batch.is_empty() { - vec![StaticElement::new(0usize, self.one())] - } else { - batch.sort_by(|a, b| b.tag.total_cmp(&a.tag)); - let mut elems = vec![]; - for k in 0..=batch.len() { - let prob = max_min_prob_of_k_count(&batch, k); - elems.push(StaticElement::new(k, prob)); - } - elems - } - } - - fn static_sum(&self, batch: StaticElements) -> StaticElements { - let mut elems = vec![]; - for chosen_set in (0..batch.len()).powerset() { - let chosen_elements = collect_chosen_elements(&batch, &chosen_set); - let sum = Tup::sum(chosen_elements.iter_tuples().cloned()); - let prob = min_prob_of_chosen_set(&batch, &chosen_set); - elems.push(StaticElement::new(sum, prob)); - } - elems - } - - fn static_prod( +impl Aggregator for CountAggregator { + fn aggregate( &self, - batch: StaticElements, - ) -> StaticElements { - let mut elems = vec![]; - for chosen_set in (0..batch.len()).powerset() { - let chosen_elements = collect_chosen_elements(&batch, &chosen_set); - let prod = Tup::prod(chosen_elements.iter_tuples().cloned()); - let prob = min_prob_of_chosen_set(&batch, &chosen_set); - elems.push(StaticElement::new(prod, prob)); - } - elems - } - - fn static_min(&self, batch: StaticElements) -> StaticElements { - let mut elems = vec![]; - for i in 0..batch.len() { - let min_elem = batch[i].tuple.get().clone(); - let mut agg_tag = self.one(); - for j in 0..i { - agg_tag = self.mult(&agg_tag, &self.negate(&batch[j].tag).unwrap()); - } - agg_tag = self.mult(&agg_tag, &batch[i].tag); - elems.push(StaticElement::new(min_elem, agg_tag)); - } - elems - } - - fn static_max(&self, batch: StaticElements) -> StaticElements { - let mut elems = vec![]; - for i in 0..batch.len() { - let max_elem = batch[i].tuple.get().clone(); - let mut agg_tag = batch[i].tag.clone(); - for j 
in i + 1..batch.len() { - agg_tag = self.mult(&agg_tag, &self.negate(&batch[j].tag).unwrap()); - } - elems.push(StaticElement::new(max_elem, agg_tag)); - } - elems - } - - fn static_exists(&self, batch: StaticElements) -> StaticElements { - let mut max_prob = 0.0; - let mut max_id = None; - for elem in batch { - let prob = elem.tag; - if prob > max_prob { - max_prob = prob; - max_id = Some(elem.tag.clone()); - } - } - if let Some(tag) = max_id { - let f = StaticElement::new(false, self.negate(&tag).unwrap()); - let t = StaticElement::new(true, tag); - vec![t, f] + p: &MinMaxProbProvenance, + _env: &RuntimeEnvironment, + mut batch: DynamicElements, + ) -> DynamicElements { + if self.non_multi_world { + vec![DynamicElement::new(batch.len(), p.one())] } else { - vec![StaticElement::new(false, self.one())] - } - } -} - -fn min_prob_of_chosen_set(all: &Vec, chosen_ids: &Vec) -> f64 -where - E: Element, -{ - let prob = all - .iter() - .enumerate() - .map(|(id, elem)| { - if chosen_ids.contains(&id) { - *elem.tag() + if batch.is_empty() { + vec![DynamicElement::new(0usize, p.one())] } else { - 1.0 - *elem.tag() + batch.sort_by(|a, b| b.tag.total_cmp(&a.tag)); + let mut elems = vec![]; + for k in 0..=batch.len() { + let prob = max_min_prob_of_k_count(&batch, k); + elems.push(DynamicElement::new(k, prob)); + } + elems } - }) - .fold(f64::INFINITY, |a, b| a.min(b)); - prob.into() + } + } } fn max_min_prob_of_k_count(sorted_set: &Vec, k: usize) -> f64 diff --git a/core/src/runtime/provenance/probabilistic/top_bottom_k_clauses.rs b/core/src/runtime/provenance/probabilistic/top_bottom_k_clauses.rs index 32033e5..eace951 100644 --- a/core/src/runtime/provenance/probabilistic/top_bottom_k_clauses.rs +++ b/core/src/runtime/provenance/probabilistic/top_bottom_k_clauses.rs @@ -1,10 +1,6 @@ use std::collections::*; -use itertools::Itertools; - use super::*; -use crate::runtime::dynamic::*; -use crate::runtime::statics::*; use crate::utils::{PointerFamily, RcFamily}; #[derive(Debug)] 
@@ -12,6 +8,7 @@ pub struct TopBottomKClausesProvenance { pub k: usize, pub probs: P::Cell>, pub disjunctions: P::Cell, + pub wmc_with_disjunctions: bool, } impl Clone for TopBottomKClausesProvenance

{ @@ -20,16 +17,18 @@ impl Clone for TopBottomKClausesProvenance

{ k: self.k, probs: P::clone_cell(&self.probs), disjunctions: P::clone_cell(&self.disjunctions), + wmc_with_disjunctions: self.wmc_with_disjunctions, } } } impl TopBottomKClausesProvenance

{ - pub fn new(k: usize) -> Self { + pub fn new(k: usize, wmc_with_disjunctions: bool) -> Self { Self { k, probs: P::new_cell(Vec::new()), disjunctions: P::new_cell(Disjunctions::new()), + wmc_with_disjunctions, } } @@ -76,7 +75,13 @@ impl Provenance for TopBottomKClausesProvenance

{ fn recover_fn(&self, t: &Self::Tag) -> Self::OutputTag { let s = RealSemiring; let v = |i: &usize| -> f64 { self.fact_probability(i) }; - t.wmc(&s, &v) + if self.wmc_with_disjunctions { + P::get_cell(&self.disjunctions, |disj| { + t.wmc_with_disjunctions(&s, &v, disj) + }) + } else { + t.wmc(&s, &v) + } } fn discard(&self, t: &Self::Tag) -> bool { @@ -112,110 +117,4 @@ impl Provenance for TopBottomKClausesProvenance

{ let v = |i: &usize| -> f64 { self.fact_probability(i) }; t.wmc(&s, &v) } - - fn dynamic_count(&self, batch: DynamicElements) -> DynamicElements { - if batch.is_empty() { - vec![DynamicElement::new(0usize, self.one())] - } else { - let mut elems = vec![]; - for chosen_set in (0..batch.len()).powerset() { - let count = chosen_set.len(); - let tag = self.top_bottom_k_tag_of_chosen_set(batch.iter().map(|e| &e.tag), &chosen_set, self.k); - elems.push(DynamicElement::new(count, tag)); - } - elems - } - } - - fn dynamic_min(&self, batch: DynamicElements) -> DynamicElements { - let mut elems = vec![]; - for i in 0..batch.len() { - let min_elem = batch[i].tuple.clone(); - let mut agg_tag = self.one(); - for j in 0..i { - agg_tag = self.mult(&agg_tag, &self.negate(&batch[j].tag).unwrap()); - } - agg_tag = self.mult(&agg_tag, &batch[i].tag); - elems.push(DynamicElement::new(min_elem, agg_tag)); - } - elems - } - - fn dynamic_max(&self, batch: DynamicElements) -> DynamicElements { - let mut elems = vec![]; - for i in 0..batch.len() { - let max_elem = batch[i].tuple.clone(); - let mut agg_tag = batch[i].tag.clone(); - for j in i + 1..batch.len() { - agg_tag = self.mult(&agg_tag, &self.negate(&batch[j].tag).unwrap()); - } - elems.push(DynamicElement::new(max_elem, agg_tag)); - } - elems - } - - fn dynamic_exists(&self, batch: DynamicElements) -> DynamicElements { - let mut exists_tag = self.zero(); - let mut not_exists_tag = self.one(); - for elem in batch { - exists_tag = self.add(&exists_tag, &elem.tag); - not_exists_tag = self.mult(¬_exists_tag, &self.negate(&elem.tag).unwrap()); - } - let t = DynamicElement::new(true, exists_tag); - let f = DynamicElement::new(false, not_exists_tag); - vec![t, f] - } - - fn static_count(&self, batch: StaticElements) -> StaticElements { - if batch.is_empty() { - vec![StaticElement::new(0, self.one())] - } else { - let mut elems = vec![]; - for chosen_set in (0..batch.len()).powerset() { - let count = chosen_set.len(); - let tag = 
self.top_bottom_k_tag_of_chosen_set(batch.iter().map(|e| &e.tag), &chosen_set, self.k); - elems.push(StaticElement::new(count, tag)); - } - elems - } - } - - fn static_min(&self, batch: StaticElements) -> StaticElements { - let mut elems = vec![]; - for i in 0..batch.len() { - let min_elem = batch[i].tuple.get().clone(); - let mut agg_tag = self.one(); - for j in 0..i { - agg_tag = self.mult(&agg_tag, &self.negate(&batch[j].tag).unwrap()); - } - agg_tag = self.mult(&agg_tag, &batch[i].tag); - elems.push(StaticElement::new(min_elem, agg_tag)); - } - elems - } - - fn static_max(&self, batch: StaticElements) -> StaticElements { - let mut elems = vec![]; - for i in 0..batch.len() { - let max_elem = batch[i].tuple.get().clone(); - let mut agg_tag = batch[i].tag.clone(); - for j in i + 1..batch.len() { - agg_tag = self.mult(&agg_tag, &self.negate(&batch[j].tag).unwrap()); - } - elems.push(StaticElement::new(max_elem, agg_tag)); - } - elems - } - - fn static_exists(&self, batch: StaticElements) -> StaticElements { - let mut exists_tag = self.zero(); - let mut not_exists_tag = self.one(); - for elem in batch { - exists_tag = self.add(&exists_tag, &elem.tag); - not_exists_tag = self.mult(¬_exists_tag, &self.negate(&elem.tag).unwrap()); - } - let t = StaticElement::new(true, exists_tag); - let f = StaticElement::new(false, not_exists_tag); - vec![t, f] - } } diff --git a/core/src/runtime/provenance/probabilistic/top_k_proofs.rs b/core/src/runtime/provenance/probabilistic/top_k_proofs.rs index c3ffdc1..54ee755 100644 --- a/core/src/runtime/provenance/probabilistic/top_k_proofs.rs +++ b/core/src/runtime/provenance/probabilistic/top_k_proofs.rs @@ -1,14 +1,11 @@ -use itertools::Itertools; - use super::*; -use crate::runtime::dynamic::*; -use crate::runtime::statics::*; use crate::utils::*; pub struct TopKProofsProvenance { pub k: usize, pub probs: P::Cell>, pub disjunctions: P::Cell, + pub wmc_with_disjunctions: bool, } impl Default for TopKProofsProvenance

{ @@ -17,6 +14,7 @@ impl Default for TopKProofsProvenance

{ k: 3, probs: P::new_cell(Vec::new()), disjunctions: P::new_cell(Disjunctions::new()), + wmc_with_disjunctions: false, } } } @@ -27,16 +25,18 @@ impl Clone for TopKProofsProvenance

{ k: self.k, probs: P::clone_cell(&self.probs), disjunctions: P::clone_cell(&self.disjunctions), + wmc_with_disjunctions: self.wmc_with_disjunctions, } } } impl TopKProofsProvenance

{ - pub fn new(k: usize) -> Self { + pub fn new(k: usize, wmc_with_disjunctions: bool) -> Self { Self { k, probs: P::new_cell(Vec::new()), disjunctions: P::new_cell(Disjunctions::new()), + wmc_with_disjunctions, } } @@ -87,7 +87,13 @@ impl Provenance for TopKProofsProvenance

{ fn recover_fn(&self, t: &Self::Tag) -> Self::OutputTag { let s = RealSemiring; let v = |i: &usize| -> f64 { self.fact_probability(i) }; - t.wmc(&s, &v) + if self.wmc_with_disjunctions { + P::get_cell(&self.disjunctions, |disj| { + t.wmc_with_disjunctions(&s, &v, disj) + }) + } else { + t.wmc(&s, &v) + } } fn discard(&self, t: &Self::Tag) -> bool { @@ -123,110 +129,4 @@ impl Provenance for TopKProofsProvenance

{ let v = |i: &usize| -> f64 { self.fact_probability(i) }; t.wmc(&s, &v) } - - fn dynamic_count(&self, batch: DynamicElements) -> DynamicElements { - if batch.is_empty() { - vec![DynamicElement::new(0usize, self.one())] - } else { - let mut elems = vec![]; - for chosen_set in (0..batch.len()).powerset() { - let count = chosen_set.len(); - let tag = self.top_k_tag_of_chosen_set(batch.iter().map(|e| &e.tag), &chosen_set, self.k); - elems.push(DynamicElement::new(count, tag)); - } - elems - } - } - - fn dynamic_min(&self, batch: DynamicElements) -> DynamicElements { - let mut elems = vec![]; - for i in 0..batch.len() { - let min_elem = batch[i].tuple.clone(); - let mut agg_tag = self.one(); - for j in 0..i { - agg_tag = self.mult(&agg_tag, &self.negate(&batch[j].tag).unwrap()); - } - agg_tag = self.mult(&agg_tag, &batch[i].tag); - elems.push(DynamicElement::new(min_elem, agg_tag)); - } - elems - } - - fn dynamic_max(&self, batch: DynamicElements) -> DynamicElements { - let mut elems = vec![]; - for i in 0..batch.len() { - let max_elem = batch[i].tuple.clone(); - let mut agg_tag = batch[i].tag.clone(); - for j in i + 1..batch.len() { - agg_tag = self.mult(&agg_tag, &self.negate(&batch[j].tag).unwrap()); - } - elems.push(DynamicElement::new(max_elem, agg_tag)); - } - elems - } - - fn dynamic_exists(&self, batch: DynamicElements) -> DynamicElements { - let mut exists_tag = self.zero(); - let mut not_exists_tag = self.one(); - for elem in batch { - exists_tag = self.add(&exists_tag, &elem.tag); - not_exists_tag = self.mult(¬_exists_tag, &self.negate(&elem.tag).unwrap()); - } - let t = DynamicElement::new(true, exists_tag); - let f = DynamicElement::new(false, not_exists_tag); - vec![t, f] - } - - fn static_count(&self, batch: StaticElements) -> StaticElements { - if batch.is_empty() { - vec![StaticElement::new(0, self.one())] - } else { - let mut elems = vec![]; - for chosen_set in (0..batch.len()).powerset() { - let count = chosen_set.len(); - let tag = 
self.top_k_tag_of_chosen_set(batch.iter().map(|e| &e.tag), &chosen_set, self.k); - elems.push(StaticElement::new(count, tag.into())); - } - elems - } - } - - fn static_min(&self, batch: StaticElements) -> StaticElements { - let mut elems = vec![]; - for i in 0..batch.len() { - let min_elem = batch[i].tuple.get().clone(); - let mut agg_tag = self.one(); - for j in 0..i { - agg_tag = self.mult(&agg_tag, &self.negate(&batch[j].tag).unwrap()); - } - agg_tag = self.mult(&agg_tag, &batch[i].tag); - elems.push(StaticElement::new(min_elem, agg_tag)); - } - elems - } - - fn static_max(&self, batch: StaticElements) -> StaticElements { - let mut elems = vec![]; - for i in 0..batch.len() { - let max_elem = batch[i].tuple.get().clone(); - let mut agg_tag = batch[i].tag.clone(); - for j in i + 1..batch.len() { - agg_tag = self.mult(&agg_tag, &self.negate(&batch[j].tag).unwrap()); - } - elems.push(StaticElement::new(max_elem, agg_tag)); - } - elems - } - - fn static_exists(&self, batch: StaticElements) -> StaticElements { - let mut exists_tag = self.zero(); - let mut not_exists_tag = self.one(); - for elem in batch { - exists_tag = self.add(&exists_tag, &elem.tag); - not_exists_tag = self.mult(¬_exists_tag, &self.negate(&elem.tag).unwrap()); - } - let t = StaticElement::new(true, exists_tag); - let f = StaticElement::new(false, not_exists_tag); - vec![t, f] - } } diff --git a/core/src/runtime/provenance/provenance.rs b/core/src/runtime/provenance/provenance.rs index 7899e2b..7faace4 100644 --- a/core/src/runtime/provenance/provenance.rs +++ b/core/src/runtime/provenance/provenance.rs @@ -1,16 +1,7 @@ -use std::collections::HashSet; use std::fmt::{Debug, Display}; -use rand::distributions::WeightedIndex; - use super::*; -use crate::common::tuples::*; -use crate::common::value_type::*; -use crate::runtime::dynamic::*; -use crate::runtime::env::*; -use crate::runtime::statics::*; - /// A provenance pub trait Provenance: Clone + 'static { /// The input tag space of the provenance @@ 
-81,154 +72,6 @@ pub trait Provenance: Clone + 'static { fn weight(&self, tag: &Self::Tag) -> f64 { 1.0 } - - fn dynamic_count(&self, batch: DynamicElements) -> DynamicElements { - vec![DynamicElement::new(batch.len(), self.one())] - } - - fn dynamic_discrete_count(&self, batch: DynamicElements) -> DynamicElements { - vec![DynamicElement::new(batch.len(), self.one())] - } - - fn dynamic_sum(&self, ty: &ValueType, batch: DynamicElements) -> DynamicElements { - let s = ty.sum(batch.iter_tuples()); - vec![DynamicElement::new(s, self.one())] - } - - fn dynamic_prod(&self, ty: &ValueType, batch: DynamicElements) -> DynamicElements { - let p = ty.prod(batch.iter_tuples()); - vec![DynamicElement::new(p, self.one())] - } - - fn dynamic_min(&self, batch: DynamicElements) -> DynamicElements { - batch.first().into_iter().cloned().collect() - } - - fn dynamic_max(&self, batch: DynamicElements) -> DynamicElements { - batch.last().into_iter().cloned().collect() - } - - fn dynamic_argmin(&self, batch: DynamicElements) -> DynamicElements { - batch - .iter_tuples() - .arg_minimum() - .into_iter() - .map(|t| DynamicElement::new(t, self.one())) - .collect() - } - - fn dynamic_argmax(&self, batch: DynamicElements) -> DynamicElements { - batch - .iter_tuples() - .arg_maximum() - .into_iter() - .map(|t| DynamicElement::new(t, self.one())) - .collect() - } - - fn dynamic_exists(&self, batch: DynamicElements) -> DynamicElements { - vec![DynamicElement::new(!batch.is_empty(), self.one())] - } - - fn dynamic_top_k(&self, k: usize, batch: DynamicElements) -> DynamicElements { - let ids = aggregate_top_k_helper(batch.len(), k, |id| self.weight(&batch[id].tag)); - ids.into_iter().map(|id| batch[id].clone()).collect() - } - - fn dynamic_categorical_k( - &self, - k: usize, - batch: DynamicElements, - rt: &RuntimeEnvironment, - ) -> DynamicElements { - if batch.len() <= k { - batch - } else { - let weights = batch.iter().map(|e| self.weight(&e.tag)).collect::>(); - let dist = 
WeightedIndex::new(&weights).unwrap(); - let sampled_ids = (0..k).map(|_| rt.random.sample_from(&dist)).collect::>(); - batch - .into_iter() - .enumerate() - .filter_map(|(i, e)| if sampled_ids.contains(&i) { Some(e) } else { None }) - .collect() - } - } - - fn static_count(&self, batch: StaticElements) -> StaticElements { - vec![StaticElement::new(batch.len(), self.one())] - } - - fn static_sum(&self, batch: StaticElements) -> StaticElements { - vec![StaticElement::new( - ::sum(batch.iter_tuples().cloned()), - self.one(), - )] - } - - fn static_prod(&self, batch: StaticElements) -> StaticElements { - vec![StaticElement::new( - ::prod(batch.iter_tuples().cloned()), - self.one(), - )] - } - - fn static_max(&self, batch: StaticElements) -> StaticElements { - batch.last().into_iter().cloned().collect() - } - - fn static_min(&self, batch: StaticElements) -> StaticElements { - batch.first().into_iter().cloned().collect() - } - - fn static_argmax( - &self, - batch: StaticElements<(T1, T2), Self>, - ) -> StaticElements<(T1, T2), Self> { - static_argmax(batch.into_iter().map(|e| e.tuple())) - .into_iter() - .map(|t| StaticElement::new(t, self.one())) - .collect() - } - - fn static_argmin( - &self, - batch: StaticElements<(T1, T2), Self>, - ) -> StaticElements<(T1, T2), Self> { - static_argmin(batch.into_iter().map(|e| e.tuple())) - .into_iter() - .map(|t| StaticElement::new(t, self.one())) - .collect() - } - - fn static_exists(&self, batch: StaticElements) -> StaticElements { - vec![StaticElement::new(!batch.is_empty(), self.one())] - } - - fn static_top_k(&self, k: usize, batch: StaticElements) -> StaticElements { - let ids = aggregate_top_k_helper(batch.len(), k, |id| self.weight(&batch[id].tag)); - ids.into_iter().map(|id| batch[id].clone()).collect() - } - - fn static_categorical_k( - &self, - k: usize, - batch: StaticElements, - rt: &RuntimeEnvironment, - ) -> StaticElements { - if batch.len() <= k { - batch - } else { - let weights = batch.iter().map(|e| 
self.weight(&e.tag)).collect::>(); - let dist = WeightedIndex::new(&weights).unwrap(); - let sampled_ids = (0..k).map(|_| rt.random.sample_from(&dist)).collect::>(); - batch - .into_iter() - .enumerate() - .filter_map(|(i, e)| if sampled_ids.contains(&i) { Some(e) } else { None }) - .collect() - } - } } pub type OutputTagOf = ::OutputTag; diff --git a/core/src/runtime/statics/aggregator/aggregator.rs b/core/src/runtime/statics/aggregator/aggregator.rs index c5d089d..8a7633f 100644 --- a/core/src/runtime/statics/aggregator/aggregator.rs +++ b/core/src/runtime/statics/aggregator/aggregator.rs @@ -1,8 +1,14 @@ +use crate::runtime::env::*; use crate::runtime::provenance::*; use crate::runtime::statics::*; pub trait Aggregator: Clone { type Output: StaticTupleTrait; - fn aggregate(&self, tuples: StaticElements, ctx: &Prov) -> StaticElements; + fn aggregate( + &self, + tuples: StaticElements, + rt: &RuntimeEnvironment, + ctx: &Prov, + ) -> StaticElements; } diff --git a/core/src/runtime/statics/aggregator/argmax.rs b/core/src/runtime/statics/aggregator/argmax.rs index 5388865..0f90dde 100644 --- a/core/src/runtime/statics/aggregator/argmax.rs +++ b/core/src/runtime/statics/aggregator/argmax.rs @@ -1,8 +1,13 @@ use std::marker::PhantomData; +use crate::runtime::dynamic::*; +use crate::runtime::env::*; use crate::runtime::provenance::*; use crate::runtime::statics::*; +use crate::common::foreign_aggregate::Aggregator as DynamicAggregator; +use crate::common::foreign_aggregates::MinMaxAggregator as DynamicMinMaxAggregator; + pub struct ArgmaxAggregator { phantom: PhantomData<(T1, T2, Prov)>, } @@ -13,16 +18,36 @@ impl ArgmaxAggrega } } -impl Aggregator<(T1, T2), Prov> for ArgmaxAggregator +impl Aggregator for ArgmaxAggregator where - T1: StaticTupleTrait, + T1: StaticTupleTrait + TupleLength, T2: StaticTupleTrait, + T: StaticTupleTrait, + (T1, T2): FlattenTuple, Prov: Provenance, { - type Output = (T1, T2); + type Output = T1; - fn aggregate(&self, tuples: 
StaticElements<(T1, T2), Prov>, ctx: &Prov) -> StaticElements { - ctx.static_argmax(tuples) + fn aggregate( + &self, + tuples: StaticElements, + rt: &RuntimeEnvironment, + ctx: &Prov, + ) -> StaticElements { + let agg = DynamicMinMaxAggregator::argmax(::len()); + let dyn_elems = tuples + .into_iter() + .map(|e| { + let tag = e.tag.clone(); + DynamicElement::new(T::into_dyn_tuple(e.tuple()), tag) + }) + .collect(); + let results = agg.aggregate(ctx, rt, dyn_elems); + let stat_elems = results + .into_iter() + .map(|e| StaticElement::new(T1::from_dyn_tuple(e.tuple), e.tag)) + .collect(); + stat_elems } } diff --git a/core/src/runtime/statics/aggregator/argmin.rs b/core/src/runtime/statics/aggregator/argmin.rs index c7c8466..d472e94 100644 --- a/core/src/runtime/statics/aggregator/argmin.rs +++ b/core/src/runtime/statics/aggregator/argmin.rs @@ -1,8 +1,13 @@ use std::marker::PhantomData; +use crate::runtime::dynamic::*; +use crate::runtime::env::*; use crate::runtime::provenance::*; use crate::runtime::statics::*; +use crate::common::foreign_aggregate::Aggregator as DynamicAggregator; +use crate::common::foreign_aggregates::MinMaxAggregator as DynamicMinMaxAggregator; + pub struct ArgminAggregator { phantom: PhantomData<(T1, T2, Prov)>, } @@ -13,16 +18,36 @@ impl ArgminAggrega } } -impl Aggregator<(T1, T2), Prov> for ArgminAggregator +impl Aggregator for ArgminAggregator where - T1: StaticTupleTrait, + T1: StaticTupleTrait + TupleLength, T2: StaticTupleTrait, + T: StaticTupleTrait, + (T1, T2): FlattenTuple, Prov: Provenance, { - type Output = (T1, T2); + type Output = T1; - fn aggregate(&self, tuples: StaticElements<(T1, T2), Prov>, ctx: &Prov) -> StaticElements { - ctx.static_argmin(tuples) + fn aggregate( + &self, + tuples: StaticElements, + rt: &RuntimeEnvironment, + ctx: &Prov, + ) -> StaticElements { + let agg = DynamicMinMaxAggregator::argmin(::len()); + let dyn_elems = tuples + .into_iter() + .map(|e| { + let tag = e.tag.clone(); + 
DynamicElement::new(T::into_dyn_tuple(e.tuple()), tag) + }) + .collect(); + let results = agg.aggregate(ctx, rt, dyn_elems); + let stat_elems = results + .into_iter() + .map(|e| StaticElement::new(T1::from_dyn_tuple(e.tuple), e.tag)) + .collect(); + stat_elems } } diff --git a/core/src/runtime/statics/aggregator/argprod.rs b/core/src/runtime/statics/aggregator/argprod.rs new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/core/src/runtime/statics/aggregator/argprod.rs @@ -0,0 +1 @@ + diff --git a/core/src/runtime/statics/aggregator/argsum.rs b/core/src/runtime/statics/aggregator/argsum.rs new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/core/src/runtime/statics/aggregator/argsum.rs @@ -0,0 +1 @@ + diff --git a/core/src/runtime/statics/aggregator/count.rs b/core/src/runtime/statics/aggregator/count.rs index 9969923..72e2ab7 100644 --- a/core/src/runtime/statics/aggregator/count.rs +++ b/core/src/runtime/statics/aggregator/count.rs @@ -1,15 +1,23 @@ use std::marker::PhantomData; +use crate::common::foreign_aggregate::Aggregator as DynamicAggregator; +use crate::common::foreign_aggregates::CountAggregator as DynamicCountAggregator; +use crate::runtime::dynamic::*; +use crate::runtime::env::*; use crate::runtime::provenance::*; use crate::runtime::statics::*; pub struct CountAggregator { + non_multi_world: bool, phantom: PhantomData<(Tup, Prov)>, } impl CountAggregator { - pub fn new() -> Self { - Self { phantom: PhantomData } + pub fn new(non_multi_world: bool) -> Self { + Self { + non_multi_world, + phantom: PhantomData, + } } } @@ -20,8 +28,26 @@ where { type Output = usize; - fn aggregate(&self, tuples: StaticElements, ctx: &Prov) -> StaticElements { - ctx.static_count(tuples) + fn aggregate( + &self, + tuples: StaticElements, + rt: &RuntimeEnvironment, + ctx: &Prov, + ) -> StaticElements { + let agg = DynamicCountAggregator::new(self.non_multi_world); + let dyn_elems = tuples + .into_iter() + .map(|e| { + let tag = e.tag.clone(); + 
DynamicElement::new(Tup::into_dyn_tuple(e.tuple()), tag) + }) + .collect(); + let results = agg.aggregate(ctx, rt, dyn_elems); + let stat_elems = results + .into_iter() + .map(|e| StaticElement::new(e.tuple.as_usize(), e.tag)) + .collect(); + stat_elems } } @@ -31,6 +57,9 @@ where Prov: Provenance, { fn clone(&self) -> Self { - Self { phantom: PhantomData } + Self { + non_multi_world: self.non_multi_world, + phantom: PhantomData, + } } } diff --git a/core/src/runtime/statics/aggregator/exists.rs b/core/src/runtime/statics/aggregator/exists.rs index c7bc49c..1cd74c1 100644 --- a/core/src/runtime/statics/aggregator/exists.rs +++ b/core/src/runtime/statics/aggregator/exists.rs @@ -1,15 +1,24 @@ use std::marker::PhantomData; +use crate::runtime::dynamic::*; +use crate::runtime::env::*; use crate::runtime::provenance::*; use crate::runtime::statics::*; +use crate::common::foreign_aggregate::Aggregator as DynamicAggregator; +use crate::common::foreign_aggregates::ExistsAggregator as DynamicExistsAggregator; + pub struct ExistsAggregator { + non_multi_world: bool, phantom: PhantomData<(Tup, Prov)>, } impl ExistsAggregator { - pub fn new() -> Self { - Self { phantom: PhantomData } + pub fn new(non_multi_world: bool) -> Self { + Self { + non_multi_world, + phantom: PhantomData, + } } } @@ -20,8 +29,26 @@ where { type Output = bool; - fn aggregate(&self, tuples: StaticElements, ctx: &Prov) -> StaticElements { - ctx.static_exists(tuples) + fn aggregate( + &self, + tuples: StaticElements, + rt: &RuntimeEnvironment, + ctx: &Prov, + ) -> StaticElements { + let agg = DynamicExistsAggregator::new(self.non_multi_world); + let dyn_elems = tuples + .into_iter() + .map(|e| { + let tag = e.tag.clone(); + DynamicElement::new(Tup::into_dyn_tuple(e.tuple()), tag) + }) + .collect(); + let results = agg.aggregate(ctx, rt, dyn_elems); + let stat_elems = results + .into_iter() + .map(|e| StaticElement::new(e.tuple.as_bool(), e.tag)) + .collect(); + stat_elems } } @@ -31,6 +58,9 @@ where Prov: 
Provenance, { fn clone(&self) -> Self { - Self { phantom: PhantomData } + Self { + non_multi_world: self.non_multi_world, + phantom: PhantomData, + } } } diff --git a/core/src/runtime/statics/aggregator/max.rs b/core/src/runtime/statics/aggregator/max.rs index 16d0253..6bb0d09 100644 --- a/core/src/runtime/statics/aggregator/max.rs +++ b/core/src/runtime/statics/aggregator/max.rs @@ -1,8 +1,13 @@ use std::marker::PhantomData; +use crate::runtime::dynamic::*; +use crate::runtime::env::*; use crate::runtime::provenance::*; use crate::runtime::statics::*; +use crate::common::foreign_aggregate::Aggregator as DynamicAggregator; +use crate::common::foreign_aggregates::MinMaxAggregator as DynamicMinMaxAggregator; + pub struct MaxAggregator { phantom: PhantomData<(Tup, Prov)>, } @@ -20,8 +25,26 @@ where { type Output = Tup; - fn aggregate(&self, tuples: StaticElements, ctx: &Prov) -> StaticElements { - ctx.static_max(tuples) + fn aggregate( + &self, + tuples: StaticElements, + rt: &RuntimeEnvironment, + ctx: &Prov, + ) -> StaticElements { + let agg = DynamicMinMaxAggregator::max(); + let dyn_elems = tuples + .into_iter() + .map(|e| { + let tag = e.tag.clone(); + DynamicElement::new(Tup::into_dyn_tuple(e.tuple()), tag) + }) + .collect(); + let results = agg.aggregate(ctx, rt, dyn_elems); + let stat_elems = results + .into_iter() + .map(|e| StaticElement::new(Tup::from_dyn_tuple(e.tuple), e.tag)) + .collect(); + stat_elems } } diff --git a/core/src/runtime/statics/aggregator/min.rs b/core/src/runtime/statics/aggregator/min.rs index 5969bd9..7f97e70 100644 --- a/core/src/runtime/statics/aggregator/min.rs +++ b/core/src/runtime/statics/aggregator/min.rs @@ -1,8 +1,13 @@ use std::marker::PhantomData; +use crate::runtime::dynamic::*; +use crate::runtime::env::*; use crate::runtime::provenance::*; use crate::runtime::statics::*; +use crate::common::foreign_aggregate::Aggregator as DynamicAggregator; +use crate::common::foreign_aggregates::MinMaxAggregator as 
DynamicMinMaxAggregator; + pub struct MinAggregator { phantom: PhantomData<(Tup, Prov)>, } @@ -20,8 +25,26 @@ where { type Output = Tup; - fn aggregate(&self, tuples: StaticElements, ctx: &Prov) -> StaticElements { - ctx.static_min(tuples) + fn aggregate( + &self, + tuples: StaticElements, + rt: &RuntimeEnvironment, + ctx: &Prov, + ) -> StaticElements { + let agg = DynamicMinMaxAggregator::min(); + let dyn_elems = tuples + .into_iter() + .map(|e| { + let tag = e.tag.clone(); + DynamicElement::new(Tup::into_dyn_tuple(e.tuple()), tag) + }) + .collect(); + let results = agg.aggregate(ctx, rt, dyn_elems); + let stat_elems = results + .into_iter() + .map(|e| StaticElement::new(Tup::from_dyn_tuple(e.tuple), e.tag)) + .collect(); + stat_elems } } diff --git a/core/src/runtime/statics/aggregator/mod.rs b/core/src/runtime/statics/aggregator/mod.rs index bb4f006..cacac37 100644 --- a/core/src/runtime/statics/aggregator/mod.rs +++ b/core/src/runtime/statics/aggregator/mod.rs @@ -1,6 +1,8 @@ mod aggregator; mod argmax; mod argmin; +mod argprod; +mod argsum; mod count; mod exists; mod max; @@ -12,6 +14,8 @@ mod top_k; pub use aggregator::*; pub use argmax::*; pub use argmin::*; +pub use argprod::*; +pub use argsum::*; pub use count::*; pub use exists::*; pub use max::*; diff --git a/core/src/runtime/statics/aggregator/prod.rs b/core/src/runtime/statics/aggregator/prod.rs index 81b0862..6bcbac5 100644 --- a/core/src/runtime/statics/aggregator/prod.rs +++ b/core/src/runtime/statics/aggregator/prod.rs @@ -1,13 +1,25 @@ use std::marker::PhantomData; +use crate::common::value_type::*; +use crate::runtime::dynamic::*; +use crate::runtime::env::*; use crate::runtime::provenance::*; use crate::runtime::statics::*; -pub struct ProdAggregator { +use crate::common::foreign_aggregate::Aggregator as DynamicAggregator; +use crate::common::foreign_aggregates::SumProdAggregator as DynamicSumProdAggregator; + +pub struct ProdAggregator +where + ValueType: FromType, +{ phantom: PhantomData<(Tup, 
Prov)>, } -impl ProdAggregator { +impl ProdAggregator +where + ValueType: FromType, +{ pub fn new() -> Self { Self { phantom: PhantomData } } @@ -17,11 +29,30 @@ impl Aggregator for ProdAggregator where Tup: StaticTupleTrait + ProdType, Prov: Provenance, + ValueType: FromType, { type Output = Tup; - fn aggregate(&self, tuples: StaticElements, ctx: &Prov) -> StaticElements { - ctx.static_prod(tuples) + fn aggregate( + &self, + tuples: StaticElements, + rt: &RuntimeEnvironment, + ctx: &Prov, + ) -> StaticElements { + let agg = DynamicSumProdAggregator::prod::(false); + let dyn_elems = tuples + .into_iter() + .map(|e| { + let tag = e.tag.clone(); + DynamicElement::new(Tup::into_dyn_tuple(e.tuple()), tag) + }) + .collect(); + let results = agg.aggregate(ctx, rt, dyn_elems); + let stat_elems = results + .into_iter() + .map(|e| StaticElement::new(Tup::from_dyn_tuple(e.tuple), e.tag)) + .collect(); + stat_elems } } @@ -29,6 +60,7 @@ impl Clone for ProdAggregator where Tup: StaticTupleTrait + ProdType, Prov: Provenance, + ValueType: FromType, { fn clone(&self) -> Self { Self { phantom: PhantomData } diff --git a/core/src/runtime/statics/aggregator/sum.rs b/core/src/runtime/statics/aggregator/sum.rs index d47977e..878c9af 100644 --- a/core/src/runtime/statics/aggregator/sum.rs +++ b/core/src/runtime/statics/aggregator/sum.rs @@ -1,13 +1,25 @@ use std::marker::PhantomData; +use crate::common::value_type::*; +use crate::runtime::dynamic::*; +use crate::runtime::env::*; use crate::runtime::provenance::*; use crate::runtime::statics::*; -pub struct SumAggregator { +use crate::common::foreign_aggregate::Aggregator as DynamicAggregator; +use crate::common::foreign_aggregates::SumProdAggregator as DynamicSumProdAggregator; + +pub struct SumAggregator +where + ValueType: FromType, +{ phantom: PhantomData<(Tup, Prov)>, } -impl SumAggregator { +impl SumAggregator +where + ValueType: FromType, +{ pub fn new() -> Self { Self { phantom: PhantomData } } @@ -17,11 +29,30 @@ impl 
Aggregator for SumAggregator where Tup: StaticTupleTrait + SumType, Prov: Provenance, + ValueType: FromType, { type Output = Tup; - fn aggregate(&self, tuples: StaticElements, ctx: &Prov) -> StaticElements { - ctx.static_sum(tuples) + fn aggregate( + &self, + tuples: StaticElements, + rt: &RuntimeEnvironment, + ctx: &Prov, + ) -> StaticElements { + let agg = DynamicSumProdAggregator::sum::(false); + let dyn_elems = tuples + .into_iter() + .map(|e| { + let tag = e.tag.clone(); + DynamicElement::new(Tup::into_dyn_tuple(e.tuple()), tag) + }) + .collect(); + let results = agg.aggregate(ctx, rt, dyn_elems); + let stat_elems = results + .into_iter() + .map(|e| StaticElement::new(Tup::from_dyn_tuple(e.tuple), e.tag)) + .collect(); + stat_elems } } @@ -29,6 +60,7 @@ impl Clone for SumAggregator where Tup: StaticTupleTrait + SumType, Prov: Provenance, + ValueType: FromType, { fn clone(&self) -> Self { Self { phantom: PhantomData } diff --git a/core/src/runtime/statics/aggregator/top_k.rs b/core/src/runtime/statics/aggregator/top_k.rs index 9a5e5cf..7537699 100644 --- a/core/src/runtime/statics/aggregator/top_k.rs +++ b/core/src/runtime/statics/aggregator/top_k.rs @@ -1,8 +1,11 @@ use std::marker::PhantomData; +use crate::runtime::env::*; use crate::runtime::provenance::*; use crate::runtime::statics::*; +use crate::common::foreign_aggregates::*; + pub struct TopKAggregator { k: usize, phantom: PhantomData<(Tup, Prov)>, @@ -24,8 +27,17 @@ where { type Output = Tup; - fn aggregate(&self, tuples: StaticElements, ctx: &Prov) -> StaticElements { - ctx.static_top_k(self.k, tuples) + fn aggregate( + &self, + tuples: StaticElements, + rt: &RuntimeEnvironment, + ctx: &Prov, + ) -> StaticElements { + let agg = TopKSampler::new(self.k); + let weights = tuples.iter().map(|e| ctx.weight(&e.tag)).collect(); + let indices = agg.sample_weight_only(rt, weights); + let stat_elems = indices.into_iter().map(|i| tuples[i].clone()).collect(); + stat_elems } } diff --git 
a/core/src/runtime/statics/dataflow/aggregation/implicit_group.rs b/core/src/runtime/statics/dataflow/aggregation/implicit_group.rs index b7d202b..9817b3a 100644 --- a/core/src/runtime/statics/dataflow/aggregation/implicit_group.rs +++ b/core/src/runtime/statics/dataflow/aggregation/implicit_group.rs @@ -1,5 +1,6 @@ use std::marker::PhantomData; +use crate::runtime::env::*; use crate::runtime::provenance::*; use crate::runtime::statics::*; @@ -15,6 +16,7 @@ where { agg: A, d: D, + rt: &'a RuntimeEnvironment, ctx: &'a Prov, phantom: PhantomData<(K, T1)>, } @@ -27,10 +29,11 @@ where A: Aggregator, Prov: Provenance, { - pub fn new(agg: A, d: D, ctx: &'a Prov) -> Self { + pub fn new(agg: A, d: D, rt: &'a RuntimeEnvironment, ctx: &'a Prov) -> Self { Self { agg, d, + rt, ctx, phantom: PhantomData, } @@ -63,12 +66,13 @@ where // Cache the context let agg = self.agg; + let rt = self.rt; let ctx = self.ctx; // Temporary function to aggregate the group and populate the result let consolidate_group = |result: &mut StaticElements<(K, A::Output), Prov>, agg_key: K, agg_group: StaticElements| { - let agg_results = agg.aggregate(agg_group, ctx); + let agg_results = agg.aggregate(agg_group, rt, ctx); let joined_results = agg_results .into_iter() .map(|agg_result| StaticElement::new((agg_key.clone(), agg_result.tuple.get().clone()), agg_result.tag)); @@ -123,6 +127,7 @@ where Self { agg: self.agg.clone(), d: self.d.clone(), + rt: self.rt, ctx: self.ctx, phantom: PhantomData, } diff --git a/core/src/runtime/statics/dataflow/aggregation/join_group.rs b/core/src/runtime/statics/dataflow/aggregation/join_group.rs index ecc105a..7452f03 100644 --- a/core/src/runtime/statics/dataflow/aggregation/join_group.rs +++ b/core/src/runtime/statics/dataflow/aggregation/join_group.rs @@ -2,6 +2,7 @@ use std::marker::PhantomData; use itertools::iproduct; +use crate::runtime::env::*; use crate::runtime::provenance::*; use crate::runtime::statics::*; @@ -20,6 +21,7 @@ where agg: A, d1: D1, d2: D2, + 
rt: &'a RuntimeEnvironment, ctx: &'a Prov, phantom: PhantomData<(K, T1, T2)>, } @@ -34,11 +36,12 @@ where A: Aggregator, Prov: Provenance, { - pub fn new(agg: A, d1: D1, d2: D2, ctx: &'a Prov) -> Self { + pub fn new(agg: A, d1: D1, d2: D2, rt: &'a RuntimeEnvironment, ctx: &'a Prov) -> Self { Self { agg, d1, d2, + rt, ctx, phantom: PhantomData, } @@ -78,6 +81,7 @@ where }; let agg = self.agg; + let rt = self.rt; let ctx = self.ctx; let mut groups = vec![]; @@ -143,7 +147,7 @@ where .iter() .map(|e| StaticElement::new(e.tuple.1.clone(), e.tag.clone())) .collect::>(); - let agg_results = agg.aggregate(to_agg_tups, ctx); + let agg_results = agg.aggregate(to_agg_tups, rt, ctx); iproduct!(group_by_vals, agg_results) .map(|((tag, t1), agg_result)| { StaticElement::new( @@ -175,6 +179,7 @@ where agg: self.agg.clone(), d1: self.d1.clone(), d2: self.d2.clone(), + rt: self.rt, ctx: self.ctx, phantom: PhantomData, } diff --git a/core/src/runtime/statics/dataflow/aggregation/single_group.rs b/core/src/runtime/statics/dataflow/aggregation/single_group.rs index 66a3328..89577f9 100644 --- a/core/src/runtime/statics/dataflow/aggregation/single_group.rs +++ b/core/src/runtime/statics/dataflow/aggregation/single_group.rs @@ -1,5 +1,6 @@ use std::marker::PhantomData; +use crate::runtime::env::*; use crate::runtime::provenance::*; use crate::runtime::statics::*; @@ -14,6 +15,7 @@ where { agg: A, d: D, + rt: &'a RuntimeEnvironment, ctx: &'a Prov, phantom: PhantomData, } @@ -25,10 +27,11 @@ where A: Aggregator, Prov: Provenance, { - pub fn new(agg: A, d: D, ctx: &'a Prov) -> Self { + pub fn new(agg: A, d: D, rt: &'a RuntimeEnvironment, ctx: &'a Prov) -> Self { Self { agg, d, + rt, ctx, phantom: PhantomData, } @@ -53,18 +56,13 @@ where fn iter_recent(self) -> Self::Recent { // Sanitize input relation let batch = if let Some(b) = self.d.iter_recent().next() { - let result = b.collect::>(); - if result.is_empty() { - return Self::Recent::empty(); - } else { - result - } + b.collect::>() } 
else { return Self::Recent::empty(); }; // Aggregate the result using aggregator - let result = self.agg.aggregate(batch, self.ctx); + let result = self.agg.aggregate(batch, self.rt, self.ctx); Self::Recent::singleton(result.into_iter()) } } @@ -80,6 +78,7 @@ where Self { agg: self.agg.clone(), d: self.d.clone(), + rt: self.rt, ctx: self.ctx, phantom: PhantomData, } diff --git a/core/src/runtime/statics/iteration.rs b/core/src/runtime/statics/iteration.rs index 82ac972..3a3bf8f 100644 --- a/core/src/runtime/statics/iteration.rs +++ b/core/src/runtime/statics/iteration.rs @@ -1,19 +1,23 @@ use super::*; + +use crate::runtime::env::*; use crate::runtime::provenance::*; pub struct StaticIteration<'a, Prov: Provenance> { pub iter_num: usize, pub early_discard: bool, pub relations: Vec>>, + pub runtime_environment: &'a RuntimeEnvironment, pub provenance_context: &'a mut Prov, } impl<'a, Prov: Provenance> StaticIteration<'a, Prov> { /// Create a new Iteration - pub fn new(provenance_context: &'a mut Prov) -> Self { + pub fn new(runtime_environment: &'a RuntimeEnvironment, provenance_context: &'a mut Prov) -> Self { Self { iter_num: 0, early_discard: true, + runtime_environment, provenance_context, relations: Vec::new(), } @@ -122,7 +126,7 @@ impl<'a, Prov: Provenance> StaticIteration<'a, Prov> { T1: StaticTupleTrait, D: dataflow::Dataflow, { - dataflow::AggregationSingleGroup::new(agg, d, &self.provenance_context) + dataflow::AggregationSingleGroup::new(agg, d, self.runtime_environment, &self.provenance_context) } pub fn aggregate_implicit_group( @@ -136,7 +140,7 @@ impl<'a, Prov: Provenance> StaticIteration<'a, Prov> { T1: StaticTupleTrait, D: dataflow::Dataflow<(K, T1), Prov>, { - dataflow::AggregationImplicitGroup::new(agg, d, &self.provenance_context) + dataflow::AggregationImplicitGroup::new(agg, d, self.runtime_environment, &self.provenance_context) } pub fn aggregate_join_group( @@ -153,7 +157,7 @@ impl<'a, Prov: Provenance> StaticIteration<'a, Prov> { D1: 
dataflow::Dataflow<(K, T1), Prov>, D2: dataflow::Dataflow<(K, T2), Prov>, { - dataflow::AggregationJoinGroup::new(agg, v1, v2, &self.provenance_context) + dataflow::AggregationJoinGroup::new(agg, v1, v2, self.runtime_environment, &self.provenance_context) } pub fn complete(&self, r: &StaticRelation) -> StaticCollection diff --git a/core/src/runtime/statics/tuple.rs b/core/src/runtime/statics/tuple.rs index f59bc0a..f0022be 100644 --- a/core/src/runtime/statics/tuple.rs +++ b/core/src/runtime/statics/tuple.rs @@ -1,6 +1,305 @@ -pub trait StaticTupleTrait: 'static + Sized + Clone + std::fmt::Debug + std::cmp::PartialOrd {} +use crate::common::tuple::*; +use crate::common::value::*; -impl StaticTupleTrait for T where T: 'static + Sized + Clone + std::fmt::Debug + std::cmp::PartialOrd {} +pub trait StaticTupleTrait: 'static + Sized + Clone + std::fmt::Debug + std::cmp::PartialOrd { + fn from_dyn_tuple(dyn_tuple: Tuple) -> Self; + + fn into_dyn_tuple(self) -> Tuple; +} + +impl StaticTupleTrait for () { + fn from_dyn_tuple(_: Tuple) -> Self { + () + } + fn into_dyn_tuple(self) -> Tuple { + Tuple::Tuple(Box::new([])) + } +} + +impl StaticTupleTrait for i8 { + fn from_dyn_tuple(dyn_tuple: Tuple) -> Self { + dyn_tuple.as_i8() + } + fn into_dyn_tuple(self) -> Tuple { + Tuple::Value(Value::I8(self)) + } +} + +impl StaticTupleTrait for i16 { + fn from_dyn_tuple(dyn_tuple: Tuple) -> Self { + dyn_tuple.as_i16() + } + fn into_dyn_tuple(self) -> Tuple { + Tuple::Value(Value::I16(self)) + } +} + +impl StaticTupleTrait for i32 { + fn from_dyn_tuple(dyn_tuple: Tuple) -> Self { + dyn_tuple.as_i32() + } + fn into_dyn_tuple(self) -> Tuple { + Tuple::Value(Value::I32(self)) + } +} + +impl StaticTupleTrait for i64 { + fn from_dyn_tuple(dyn_tuple: Tuple) -> Self { + dyn_tuple.as_i64() + } + fn into_dyn_tuple(self) -> Tuple { + Tuple::Value(Value::I64(self)) + } +} + +impl StaticTupleTrait for i128 { + fn from_dyn_tuple(dyn_tuple: Tuple) -> Self { + dyn_tuple.as_i128() + } + fn 
into_dyn_tuple(self) -> Tuple { + Tuple::Value(Value::I128(self)) + } +} + +impl StaticTupleTrait for isize { + fn from_dyn_tuple(dyn_tuple: Tuple) -> Self { + dyn_tuple.as_isize() + } + fn into_dyn_tuple(self) -> Tuple { + Tuple::Value(Value::ISize(self)) + } +} + +impl StaticTupleTrait for u8 { + fn from_dyn_tuple(dyn_tuple: Tuple) -> Self { + dyn_tuple.as_u8() + } + fn into_dyn_tuple(self) -> Tuple { + Tuple::Value(Value::U8(self)) + } +} + +impl StaticTupleTrait for u16 { + fn from_dyn_tuple(dyn_tuple: Tuple) -> Self { + dyn_tuple.as_u16() + } + fn into_dyn_tuple(self) -> Tuple { + Tuple::Value(Value::U16(self)) + } +} + +impl StaticTupleTrait for u32 { + fn from_dyn_tuple(dyn_tuple: Tuple) -> Self { + dyn_tuple.as_u32() + } + fn into_dyn_tuple(self) -> Tuple { + Tuple::Value(Value::U32(self)) + } +} + +impl StaticTupleTrait for u64 { + fn from_dyn_tuple(dyn_tuple: Tuple) -> Self { + dyn_tuple.as_u64() + } + fn into_dyn_tuple(self) -> Tuple { + Tuple::Value(Value::U64(self)) + } +} + +impl StaticTupleTrait for u128 { + fn from_dyn_tuple(dyn_tuple: Tuple) -> Self { + dyn_tuple.as_u128() + } + fn into_dyn_tuple(self) -> Tuple { + Tuple::Value(Value::U128(self)) + } +} + +impl StaticTupleTrait for usize { + fn from_dyn_tuple(dyn_tuple: Tuple) -> Self { + dyn_tuple.as_usize() + } + fn into_dyn_tuple(self) -> Tuple { + Tuple::Value(Value::USize(self)) + } +} + +impl StaticTupleTrait for f32 { + fn from_dyn_tuple(dyn_tuple: Tuple) -> Self { + dyn_tuple.as_f32() + } + fn into_dyn_tuple(self) -> Tuple { + Tuple::Value(Value::F32(self)) + } +} + +impl StaticTupleTrait for f64 { + fn from_dyn_tuple(dyn_tuple: Tuple) -> Self { + dyn_tuple.as_f64() + } + fn into_dyn_tuple(self) -> Tuple { + Tuple::Value(Value::F64(self)) + } +} + +impl StaticTupleTrait for bool { + fn from_dyn_tuple(dyn_tuple: Tuple) -> Self { + dyn_tuple.as_bool() + } + fn into_dyn_tuple(self) -> Tuple { + Tuple::Value(Value::Bool(self)) + } +} + +impl StaticTupleTrait for char { + fn 
from_dyn_tuple(dyn_tuple: Tuple) -> Self { + dyn_tuple.as_char() + } + fn into_dyn_tuple(self) -> Tuple { + Tuple::Value(Value::Char(self)) + } +} + +impl StaticTupleTrait for &'static str { + fn from_dyn_tuple(dyn_tuple: Tuple) -> Self { + dyn_tuple.as_str() + } + fn into_dyn_tuple(self) -> Tuple { + Tuple::Value(Value::Str(self)) + } +} + +impl StaticTupleTrait for String { + fn from_dyn_tuple(dyn_tuple: Tuple) -> Self { + dyn_tuple.as_string() + } + fn into_dyn_tuple(self) -> Tuple { + Tuple::Value(Value::String(self)) + } +} + +impl StaticTupleTrait for (T1,) +where + T1: StaticTupleTrait, +{ + fn from_dyn_tuple(dyn_tuple: Tuple) -> Self { + match dyn_tuple { + Tuple::Tuple(elems) => (T1::from_dyn_tuple(elems[0].clone()),), + _ => panic!("expected dyn tuple"), + } + } + + fn into_dyn_tuple(self) -> Tuple { + Tuple::Tuple(Box::new([self.0.into_dyn_tuple()])) + } +} + +impl StaticTupleTrait for (T1, T2) +where + T1: StaticTupleTrait, + T2: StaticTupleTrait, +{ + fn from_dyn_tuple(dyn_tuple: Tuple) -> Self { + match dyn_tuple { + Tuple::Tuple(elems) => ( + T1::from_dyn_tuple(elems[0].clone()), + T2::from_dyn_tuple(elems[1].clone()), + ), + _ => panic!("expected dyn tuple"), + } + } + + fn into_dyn_tuple(self) -> Tuple { + Tuple::Tuple(Box::new([self.0.into_dyn_tuple(), self.1.into_dyn_tuple()])) + } +} + +impl StaticTupleTrait for (T1, T2, T3) +where + T1: StaticTupleTrait, + T2: StaticTupleTrait, + T3: StaticTupleTrait, +{ + fn from_dyn_tuple(dyn_tuple: Tuple) -> Self { + match dyn_tuple { + Tuple::Tuple(elems) => ( + T1::from_dyn_tuple(elems[0].clone()), + T2::from_dyn_tuple(elems[1].clone()), + T3::from_dyn_tuple(elems[2].clone()), + ), + _ => panic!("expected dyn tuple"), + } + } + + fn into_dyn_tuple(self) -> Tuple { + Tuple::Tuple(Box::new([ + self.0.into_dyn_tuple(), + self.1.into_dyn_tuple(), + self.2.into_dyn_tuple(), + ])) + } +} + +impl StaticTupleTrait for (T1, T2, T3, T4) +where + T1: StaticTupleTrait, + T2: StaticTupleTrait, + T3: StaticTupleTrait, + 
T4: StaticTupleTrait, +{ + fn from_dyn_tuple(dyn_tuple: Tuple) -> Self { + match dyn_tuple { + Tuple::Tuple(elems) => ( + T1::from_dyn_tuple(elems[0].clone()), + T2::from_dyn_tuple(elems[1].clone()), + T3::from_dyn_tuple(elems[2].clone()), + T4::from_dyn_tuple(elems[3].clone()), + ), + _ => panic!("expected dyn tuple"), + } + } + + fn into_dyn_tuple(self) -> Tuple { + Tuple::Tuple(Box::new([ + self.0.into_dyn_tuple(), + self.1.into_dyn_tuple(), + self.2.into_dyn_tuple(), + self.3.into_dyn_tuple(), + ])) + } +} + +impl StaticTupleTrait for (T1, T2, T3, T4, T5) +where + T1: StaticTupleTrait, + T2: StaticTupleTrait, + T3: StaticTupleTrait, + T4: StaticTupleTrait, + T5: StaticTupleTrait, +{ + fn from_dyn_tuple(dyn_tuple: Tuple) -> Self { + match dyn_tuple { + Tuple::Tuple(elems) => ( + T1::from_dyn_tuple(elems[0].clone()), + T2::from_dyn_tuple(elems[1].clone()), + T3::from_dyn_tuple(elems[2].clone()), + T4::from_dyn_tuple(elems[3].clone()), + T5::from_dyn_tuple(elems[4].clone()), + ), + _ => panic!("expected dyn tuple"), + } + } + + fn into_dyn_tuple(self) -> Tuple { + Tuple::Tuple(Box::new([ + self.0.into_dyn_tuple(), + self.1.into_dyn_tuple(), + self.2.into_dyn_tuple(), + self.3.into_dyn_tuple(), + self.4.into_dyn_tuple(), + ])) + } +} #[derive(Clone, PartialEq, PartialOrd)] pub struct StaticTupleWrapper(T); diff --git a/core/src/runtime/statics/utils/flatten_tuple.rs b/core/src/runtime/statics/utils/flatten_tuple.rs new file mode 100644 index 0000000..a5034ce --- /dev/null +++ b/core/src/runtime/statics/utils/flatten_tuple.rs @@ -0,0 +1,87 @@ +pub trait BaseType {} +impl BaseType for i8 {} +impl BaseType for i16 {} +impl BaseType for i32 {} +impl BaseType for i64 {} +impl BaseType for i128 {} +impl BaseType for isize {} +impl BaseType for u8 {} +impl BaseType for u16 {} +impl BaseType for u32 {} +impl BaseType for u64 {} +impl BaseType for u128 {} +impl BaseType for usize {} +impl BaseType for f32 {} +impl BaseType for f64 {} +impl BaseType for bool {} +impl 
BaseType for char {} +impl BaseType for String {} +impl BaseType for &'static str {} + +pub trait TupleLength { + fn len() -> usize; +} + +impl TupleLength for T { + fn len() -> usize { + 1 + } +} +impl TupleLength for () { + fn len() -> usize { + 0 + } +} +impl TupleLength for (T,) { + fn len() -> usize { + 1 + } +} +impl TupleLength for (T1, T2) { + fn len() -> usize { + 2 + } +} +impl TupleLength for (T1, T2, T3) { + fn len() -> usize { + 3 + } +} +impl TupleLength for (T1, T2, T3, T4) { + fn len() -> usize { + 4 + } +} +impl TupleLength for (T1, T2, T3, T4, T5) { + fn len() -> usize { + 5 + } +} +impl TupleLength for (T1, T2, T3, T4, T5, T6) { + fn len() -> usize { + 6 + } +} +impl TupleLength for (T1, T2, T3, T4, T5, T6, T7) { + fn len() -> usize { + 7 + } +} + +pub trait FlattenTuple { + type Output; + + fn flatten(self) -> Self::Output; + + fn unflatten(other: Self::Output) -> Self; +} + +impl FlattenTuple for (T1, T2) { + type Output = (T1, T2); + fn flatten(self) -> Self::Output { + self + } + fn unflatten(other: Self::Output) -> Self { + other + } +} diff --git a/core/src/runtime/statics/utils/mod.rs b/core/src/runtime/statics/utils/mod.rs index 19f4929..3a862ba 100644 --- a/core/src/runtime/statics/utils/mod.rs +++ b/core/src/runtime/statics/utils/mod.rs @@ -1,9 +1,11 @@ mod argmax; mod argmin; +mod flatten_tuple; mod prod_type; mod sum_type; pub use argmax::*; pub use argmin::*; +pub use flatten_tuple::*; pub use prod_type::*; pub use sum_type::*; diff --git a/core/src/testing/test_compile.rs b/core/src/testing/test_compile.rs index 259a6cb..14f7e11 100644 --- a/core/src/testing/test_compile.rs +++ b/core/src/testing/test_compile.rs @@ -38,6 +38,7 @@ where match e { compiler::CompileError::Front(e) => { let e = format!("{}", e); + println!("{e}"); if f(e) { return; } diff --git a/core/src/utils/float.rs b/core/src/utils/float.rs index 07ced4f..72a4cc2 100644 --- a/core/src/utils/float.rs +++ b/core/src/utils/float.rs @@ -16,6 +16,8 @@ pub trait Float: 
fn zero() -> Self; fn one() -> Self; + + fn from_f64(f: f64) -> Self; } impl Float for f32 { @@ -26,6 +28,10 @@ impl Float for f32 { fn one() -> Self { 1.0 } + + fn from_f64(f: f64) -> Self { + f as f32 + } } impl Float for f64 { @@ -36,4 +42,8 @@ impl Float for f64 { fn one() -> Self { 1.0 } + + fn from_f64(f: f64) -> Self { + f + } } diff --git a/core/tests/compiler/errors.rs b/core/tests/compiler/errors.rs index 012031a..bde7204 100644 --- a/core/tests/compiler/errors.rs +++ b/core/tests/compiler/errors.rs @@ -133,9 +133,9 @@ fn bad_enum_type_decl() { fn bad_no_binding_agg_1() { expect_front_compile_failure( r#" - rel r() = x := count(edge(1, 3)) + rel r() = x := sum(edge(1, 3)) "#, - |e| e.contains("binding variables of `count` aggregation cannot be empty"), + |e| e.contains("arity mismatch on aggregate `sum`. Expected 1 input variables, but found 0"), ) } @@ -145,6 +145,6 @@ fn issue_96() { r#" type semantic_parse(bound q: String, e: Expr) "#, - |e| e.contains("unknown custom type `Expr`") + |e| e.contains("unknown custom type `Expr`"), ) } diff --git a/core/tests/compiler/parse.rs b/core/tests/compiler/parse.rs index 2cff902..fb40757 100644 --- a/core/tests/compiler/parse.rs +++ b/core/tests/compiler/parse.rs @@ -1,5 +1,5 @@ -use scallop_core::compiler::front::parser::*; use scallop_core::compiler::front::ast::*; +use scallop_core::compiler::front::parser::*; #[test] fn parse_type_decl() { diff --git a/core/tests/integrate/adt.rs b/core/tests/integrate/adt.rs index 6a59a0a..30cfe04 100644 --- a/core/tests/integrate/adt.rs +++ b/core/tests/integrate/adt.rs @@ -160,7 +160,7 @@ const EQSAT_1_PROGRAM: &'static str = r#" rel equiv_programs(sp) = input_program(p) and equivalent(p, sp) // Find the best program (minimum weight) among all programs equivalent to p - rel best_program(p) = w := min[p](w: equiv_programs(p) and weight(p, w)) + rel best_program(p) = p := argmin[p](w: equiv_programs(p) and weight(p, w)) rel best_program_str(s) = best_program(best_prog) and 
to_string(best_prog, s) query best_program_str "#; diff --git a/core/tests/integrate/aggregate.rs b/core/tests/integrate/aggregate.rs new file mode 100644 index 0000000..2e11c2e --- /dev/null +++ b/core/tests/integrate/aggregate.rs @@ -0,0 +1,93 @@ +use scallop_core::runtime::provenance::*; +use scallop_core::testing::*; + +#[test] +fn test_avg_1() { + expect_interpret_result( + r#" + rel scores = {(0, 55.0), (1, 45.0), (2, 50.0)} + rel avg_score(a) = a := avg[i](s: scores(i, s)) + "#, + ("avg_score", vec![(50.0f32,)]), + ); +} + +#[test] +fn test_weighted_avg_1() { + expect_interpret_result( + r#" + rel scores = {(0, 55.0), (1, 45.0), (2, 50.0)} + rel avg_score(a) = a := weighted_avg[i](s: scores(i, s)) + "#, + ("avg_score", vec![(50.0f32,)]), + ); +} + +#[test] +fn test_weighted_avg_2() { + let prov = min_max_prob::MinMaxProbProvenance::new(); + expect_interpret_result_with_tag( + r#" + rel scores = {0.5::(0, 55.0), 1.0::(1, 45.0), 0.5::(2, 50.0)} + rel avg_score(a) = a := weighted_avg![i](s: scores(i, s)) + "#, + prov, + ("avg_score", vec![(1.0, (48.75f32,))]), + min_max_prob::MinMaxProbProvenance::cmp, + ); +} + +#[test] +fn test_string_join_1() { + expect_interpret_result( + r#" + rel my_strings = {"hello", "world"} + rel result(j) = j := string_join(s: my_strings(s)) + "#, + ("result", vec![("helloworld".to_string(),)]), + ); +} + +#[test] +fn test_string_join_2() { + expect_interpret_result( + r#" + rel my_strings = {"hello", "world"} + rel result(j) = j := string_join<" ">(s: my_strings(s)) + "#, + ("result", vec![("hello world".to_string(),)]), + ); +} + +#[test] +fn test_string_join_3() { + expect_interpret_result( + r#" + rel my_strings = {(2, "hello"), (1, "world")} + rel result(j) = j := string_join<" ">[i](s: my_strings(i, s)) + "#, + ("result", vec![("world hello".to_string(),)]), + ); +} + +#[test] +fn test_aggregate_compile_fail_1() { + expect_front_compile_failure( + r#" + rel num = {1, 2, 3, 4} + rel my_top(x) = x := top<1, 3, 5, 7>(y: num(y)) + 
"#, + |s| s.contains("expected at most 1 parameter"), + ) +} + +#[test] +fn test_aggregate_compile_fail_2() { + expect_front_compile_failure( + r#" + rel num = {1, 2, 3, 4} + rel my_top(x) = x := argmin(y: num(y)) + "#, + |s| s.contains("Expected non-empty argument variables"), + ) +} diff --git a/core/tests/integrate/basic.rs b/core/tests/integrate/basic.rs index 03f0bb4..f464ee6 100644 --- a/core/tests/integrate/basic.rs +++ b/core/tests/integrate/basic.rs @@ -288,6 +288,17 @@ fn fib_dt_test_1() { ); } +#[test] +fn aggregate_argmax_1() { + expect_interpret_result( + r#" + rel exam_grades = {("tom", 50.0), ("mary", 60.0)} + rel best_student(n) = n := argmax[n](s: exam_grades(n, s)) + "#, + ("best_student", vec![("mary".to_string(),)]), + ) +} + #[test] fn obj_color_test_1() { expect_interpret_result( @@ -299,8 +310,8 @@ fn obj_color_test_1() { rel object_color(4, "green") rel object_color(5, "red") - rel color_count(c, n) :- n = count(o: object_color(o, c)) - rel max_color(c) :- _ = max[c](n: color_count(c, n)) + rel color_count(c, n) :- n := count(o: object_color(o, c)) + rel max_color(c) :- c := argmax[c](n: color_count(c, n)) "#, ("max_color", vec![("green".to_string(),)]), ); @@ -318,7 +329,7 @@ fn obj_color_test_2() { (4, "green"), (5, "blue"), } - rel max_color(c) = _ = max[c](n: n = count(o: object_color(o, c))) + rel max_color(c) = c := argmax[c](n: n := count(o: object_color(o, c))) "#, ("max_color", vec![("blue".to_string(),), ("green".to_string(),)]), ); @@ -476,7 +487,7 @@ fn class_student_grade_1() { (1, "frank", 30), } - rel class_top_student(c, s) = _ = max[s](g: class_student_grade(c, s, g)) + rel class_top_student(c, s) = s := argmax[s](g: class_student_grade(c, s, g)) "#, ( "class_top_student", @@ -499,13 +510,49 @@ fn class_student_grade_2() { } rel avg_score((s as f32) / (n as f32)) = - s = sum(x: class_student_grade(_, _, x)), - n = count(a, b, c: class_student_grade(a, b, c)) + s := sum[class, name](x: class_student_grade(class, name, x)), + 
n := count(class, name: class_student_grade(class, name, _)) "#, ("avg_score", vec![(63.333f32,)]), ) } +#[test] +fn aggr_sum_test_1() { + expect_interpret_result( + r#" + rel sales = {(0, 100), (1, 100), (2, 200)} + rel total_sale(t) = t := sum[p](s: sales(p, s)) + "#, + ("total_sale", vec![(400i32,)]), + ) +} + +#[test] +fn aggr_sum_test_2() { + expect_interpret_result( + r#" + rel sales = {("market", "tom", 100), ("ads", "tom", 100), ("ads", "jenny", 200)} + rel total_sale(t) = t := sum[d, p](s: sales(d, p, s)) + "#, + ("total_sale", vec![(400i32,)]), + ) +} + +#[test] +fn aggr_sum_test_3() { + expect_interpret_result( + r#" + rel sales = {("market", "tom", 100), ("ads", "tom", 100), ("ads", "jenny", 100)} + rel dp_sale(d, t) = t := sum[p](s: sales(d, p, s)) + "#, + ( + "dp_sale", + vec![("market".to_string(), 100i32), ("ads".to_string(), 200i32)], + ), + ) +} + #[test] fn unused_relation_1() { expect_interpret_result( diff --git a/core/tests/integrate/mod.rs b/core/tests/integrate/mod.rs index 94e91cb..8eae8f8 100644 --- a/core/tests/integrate/mod.rs +++ b/core/tests/integrate/mod.rs @@ -1,4 +1,5 @@ mod adt; +mod aggregate; mod attr; mod basic; mod bug; @@ -10,4 +11,5 @@ mod incr; mod io; mod iter; mod prob; +mod sampling; mod time; diff --git a/core/tests/integrate/sampling.rs b/core/tests/integrate/sampling.rs new file mode 100644 index 0000000..5e0acb1 --- /dev/null +++ b/core/tests/integrate/sampling.rs @@ -0,0 +1,19 @@ +use scallop_core::integrate::*; + +#[test] +fn test_uniform_sample() { + let result = interpret_string( + r#" + rel numbers = {0, 1, 2, 3} + rel sampled_number(x) = x := uniform<1>(x: numbers(x)) + "# + .to_string(), + ) + .expect("Failed executing"); + let sampled_number = result + .get_output_collection("sampled_number") + .expect("Cannot get `sampled_number` relation"); + assert_eq!(sampled_number.len(), 1, "There should be only one number being sampled"); + let number = sampled_number.ith_tuple(0).expect("There should be one 
number")[0].as_i32(); + assert!(number >= 0 && number < 4, "number should be anything between 0 and 3"); +} diff --git a/core/tests/runtime/dataflow/dyn_aggregate.rs b/core/tests/runtime/dataflow/dyn_aggregate.rs index c6a7df3..5ccda0c 100644 --- a/core/tests/runtime/dataflow/dyn_aggregate.rs +++ b/core/tests/runtime/dataflow/dyn_aggregate.rs @@ -1,4 +1,4 @@ -use scallop_core::common::aggregate_op::AggregateOp; +use scallop_core::common::value_type::ValueType; use scallop_core::runtime::dynamic::dataflow::*; use scallop_core::runtime::dynamic::*; use scallop_core::runtime::env::*; @@ -36,7 +36,9 @@ fn test_dynamic_aggregate_count_1() { agg.insert_dataflow_recent( &ctx, &DynamicDataflow::new(DynamicAggregationSingleGroupDataflow::new( - AggregateOp::count().into(), + rt.aggregate_registry + .instantiate_aggregator("count", vec![], false, vec![], vec![ValueType::I8, ValueType::I8]) + .unwrap(), DynamicDataflow::dynamic_collection(&completed_target, first_time), &ctx, &rt, diff --git a/core/tests/runtime/dataflow/dyn_difference.rs b/core/tests/runtime/dataflow/dyn_difference.rs index db79142..b7b5678 100644 --- a/core/tests/runtime/dataflow/dyn_difference.rs +++ b/core/tests/runtime/dataflow/dyn_difference.rs @@ -38,7 +38,8 @@ where while source_1.changed(&ctx) || target.changed(&ctx) { target.insert_dataflow_recent( &ctx, - &DynamicDataflow::dynamic_relation(&source_1).difference(DynamicDataflow::dynamic_recent_collection(&source_2_coll), &ctx), + &DynamicDataflow::dynamic_relation(&source_1) + .difference(DynamicDataflow::dynamic_recent_collection(&source_2_coll), &ctx), &mut rt, ) } diff --git a/core/tests/runtime/dataflow/dyn_group_aggregate.rs b/core/tests/runtime/dataflow/dyn_group_aggregate.rs index cb6d7e6..99dc465 100644 --- a/core/tests/runtime/dataflow/dyn_group_aggregate.rs +++ b/core/tests/runtime/dataflow/dyn_group_aggregate.rs @@ -1,5 +1,5 @@ -use scallop_core::common::aggregate_op::AggregateOp; use scallop_core::common::expr::*; +use 
scallop_core::common::value_type::*; use scallop_core::runtime::dynamic::dataflow::*; use scallop_core::runtime::dynamic::*; use scallop_core::runtime::env::*; @@ -51,7 +51,9 @@ fn test_dynamic_group_and_count_1() { color_count.insert_dataflow_recent( &ctx, &DynamicDataflow::new(DynamicAggregationImplicitGroupDataflow::new( - AggregateOp::count().into(), + rt.aggregate_registry + .instantiate_aggregator("count", vec![], false, vec![], vec![ValueType::USize]) + .unwrap(), DynamicDataflow::dynamic_collection(&completed_rev_color, first_time), &ctx, &rt, @@ -112,7 +114,9 @@ fn test_dynamic_group_count_max_1() { color_count.insert_dataflow_recent( &ctx, &DynamicDataflow::new(DynamicAggregationImplicitGroupDataflow::new( - AggregateOp::count().into(), + rt.aggregate_registry + .instantiate_aggregator("count", vec![], false, vec![], vec![ValueType::USize]) + .unwrap(), DynamicDataflow::dynamic_collection(&completed_rev_color, iter_1_first_time), &ctx, &rt, @@ -132,7 +136,9 @@ fn test_dynamic_group_count_max_1() { max_count_color.insert_dataflow_recent( &ctx, &DynamicDataflow::new(DynamicAggregationSingleGroupDataflow::new( - AggregateOp::Argmax.into(), + rt.aggregate_registry + .instantiate_aggregator("max", vec![], false, vec![ValueType::Str], vec![ValueType::USize]) + .unwrap(), DynamicDataflow::dynamic_collection(&completed_color_count, iter_2_first_time), &ctx, &rt, diff --git a/core/tests/runtime/dataflow/dyn_group_by_key.rs b/core/tests/runtime/dataflow/dyn_group_by_key.rs index 346d414..ae27c1e 100644 --- a/core/tests/runtime/dataflow/dyn_group_by_key.rs +++ b/core/tests/runtime/dataflow/dyn_group_by_key.rs @@ -1,5 +1,5 @@ -use scallop_core::common::aggregate_op::AggregateOp; use scallop_core::common::expr::*; +use scallop_core::common::value_type::*; use scallop_core::compiler::ram::*; use scallop_core::runtime::dynamic::*; use scallop_core::runtime::env::*; @@ -10,8 +10,8 @@ fn test_group_by_key_1() -> DynamicCollection where Prov: Provenance + Default, { - let 
mut ctx = Prov::default(); - let mut rt = RuntimeEnvironment::new_std(); + let ctx = Prov::default(); + let rt = RuntimeEnvironment::new_std(); let result_1 = { let mut strata_1 = DynamicIteration::::new(); @@ -21,10 +21,10 @@ where strata_1.create_dynamic_relation("_colors_key"); strata_1 .get_dynamic_relation_unsafe("color") - .insert_untagged(&mut ctx, vec![(0usize, "blue"), (1, "green"), (2, "blue")]); + .insert_untagged(&ctx, vec![(0usize, "blue"), (1, "green"), (2, "blue")]); strata_1 .get_dynamic_relation_unsafe("colors") - .insert_untagged(&mut ctx, vec![("blue",), ("green",), ("red",)]); + .insert_untagged(&ctx, vec![("blue",), ("green",), ("red",)]); strata_1.add_update_dataflow( "_color_rev", Dataflow::relation("color".to_string()).project((Expr::access(1), Expr::access(0))), @@ -35,7 +35,7 @@ where ); strata_1.add_output_relation("_color_rev"); strata_1.add_output_relation("_colors_key"); - strata_1.run(&ctx, &mut rt) + strata_1.run(&ctx, &rt) }; let mut result_2 = { @@ -45,11 +45,19 @@ where strata_2.add_input_dynamic_collection("_colors_key", &result_1["_colors_key"]); strata_2.add_update_dataflow( "color_count", - Dataflow::reduce(AggregateOp::count(), "_color_rev", ReduceGroupByType::join("_colors_key")) - .project((Expr::access(0), Expr::access(2))), + Dataflow::reduce( + "count".to_string(), + vec![], + false, + vec![], + vec![ValueType::USize], + "_color_rev", + ReduceGroupByType::join("_colors_key"), + ) + .project((Expr::access(0), Expr::access(2))), ); strata_2.add_output_relation("color_count"); - strata_2.run(&ctx, &mut rt) + strata_2.run(&ctx, &rt) }; result_2.remove("color_count").unwrap() diff --git a/core/tests/runtime/interpret/iteration.rs b/core/tests/runtime/interpret/iteration.rs index de16eb7..089961b 100644 --- a/core/tests/runtime/interpret/iteration.rs +++ b/core/tests/runtime/interpret/iteration.rs @@ -1,5 +1,5 @@ -use scallop_core::common::aggregate_op::AggregateOp; use scallop_core::common::expr::*; +use 
scallop_core::common::value_type::ValueType; use scallop_core::compiler::ram::*; use scallop_core::runtime::dynamic::*; use scallop_core::runtime::env::*; @@ -10,9 +10,9 @@ fn test_iteration_1() -> DynamicCollection where Prov: Provenance + Default, { - let mut ctx = Prov::default(); + let ctx = Prov::default(); + let rt = RuntimeEnvironment::default(); let mut iter = DynamicIteration::::new(); - let mut rt = RuntimeEnvironment::default(); // First create relations iter.create_dynamic_relation("edge"); @@ -22,7 +22,7 @@ where // Insert EDB facts iter .get_dynamic_relation_unsafe("edge") - .insert_untagged(&mut ctx, vec![(0, 1), (1, 2), (1, 3)]); + .insert_untagged(&ctx, vec![(0, 1), (1, 2), (1, 3)]); // Insert updates iter.add_update_dataflow("path", Dataflow::relation("edge")); @@ -42,7 +42,7 @@ where iter.add_output_relation("edge"); // Run the iteration - let mut result = iter.run(&ctx, &mut rt); + let mut result = iter.run(&ctx, &rt); // Test the result expect_collection(&result["path"], vec![(0, 1), (1, 2), (0, 2), (1, 3), (0, 3)]); @@ -71,15 +71,15 @@ fn test_iteration_2() -> DynamicCollection where Prov: Provenance + Default, { - let mut ctx = Prov::default(); - let mut rt = RuntimeEnvironment::default(); + let ctx = Prov::default(); + let rt = RuntimeEnvironment::default(); let result_1 = { let mut strata_1 = DynamicIteration::::new(); strata_1.create_dynamic_relation("color"); strata_1.create_dynamic_relation("_color_rev"); strata_1.get_dynamic_relation_unsafe("color").insert_untagged( - &mut ctx, + &ctx, vec![ (0, "blue"), (1, "green"), @@ -94,7 +94,7 @@ where Dataflow::relation("color").project((Expr::access(1), Expr::access(0))), ); strata_1.add_output_relation("_color_rev"); - strata_1.run(&ctx, &mut rt) + strata_1.run(&ctx, &rt) }; let result_2 = { @@ -103,10 +103,18 @@ where strata_2.add_input_dynamic_collection("_color_rev", &result_1["_color_rev"]); strata_2.add_update_dataflow( "color_count", - Dataflow::reduce(AggregateOp::count(), "_color_rev", 
ReduceGroupByType::Implicit), + Dataflow::reduce( + "count".to_string(), + vec![], + false, + vec![], + vec![ValueType::I32], + "_color_rev", + ReduceGroupByType::Implicit, + ), ); strata_2.add_output_relation("color_count"); - strata_2.run(&ctx, &mut rt) + strata_2.run(&ctx, &rt) }; let mut result_3 = { @@ -115,10 +123,18 @@ where strata_3.add_input_dynamic_collection("color_count", &result_2["color_count"]); strata_3.add_update_dataflow( "max_color_count", - Dataflow::reduce(AggregateOp::Argmax, "color_count", ReduceGroupByType::None), + Dataflow::reduce( + "max".to_string(), + vec![], + false, + vec![ValueType::Str], + vec![ValueType::USize], + "color_count", + ReduceGroupByType::None, + ), ); strata_3.add_output_relation("max_color_count"); - strata_3.run(&ctx, &mut rt) + strata_3.run(&ctx, &rt) }; expect_collection(&result_3["max_color_count"], vec![("blue", 3usize)]); diff --git a/core/tests/runtime/provenance/prob.rs b/core/tests/runtime/provenance/prob.rs index e361d6f..cf1b0f5 100644 --- a/core/tests/runtime/provenance/prob.rs +++ b/core/tests/runtime/provenance/prob.rs @@ -1,5 +1,5 @@ -use scallop_core::common::aggregate_op::AggregateOp; use scallop_core::common::expr::*; +use scallop_core::common::value_type::ValueType; use scallop_core::compiler::ram::*; use scallop_core::runtime::dynamic::*; use scallop_core::runtime::env::*; @@ -7,15 +7,15 @@ use scallop_core::runtime::provenance::*; #[test] fn test_simple_probability_count() { - let mut ctx = min_max_prob::MinMaxProbProvenance::default(); - let mut rt = RuntimeEnvironment::default(); + let ctx = min_max_prob::MinMaxProbProvenance::default(); + let rt = RuntimeEnvironment::default(); let result_1 = { let mut strata_1 = DynamicIteration::::new(); strata_1.create_dynamic_relation("color"); strata_1.create_dynamic_relation("_color_rev"); strata_1.get_dynamic_relation_unsafe("color").insert_tagged( - &mut ctx, + &ctx, vec![ (Some(0.5), (0usize, "blue")), (Some(0.8), (1, "green")), @@ -30,7 +30,7 @@ fn 
test_simple_probability_count() { Dataflow::relation("color").project((Expr::access(1), Expr::access(0))), ); strata_1.add_output_relation("_color_rev"); - strata_1.run(&ctx, &mut rt) + strata_1.run(&ctx, &rt) }; let result_2 = { @@ -39,10 +39,18 @@ fn test_simple_probability_count() { strata_2.add_input_dynamic_collection("_color_rev", &result_1["_color_rev"]); strata_2.add_update_dataflow( "color_count", - Dataflow::reduce(AggregateOp::count(), "_color_rev", ReduceGroupByType::Implicit), + Dataflow::reduce( + "count".to_string(), + vec![], + false, + vec![], + vec![ValueType::USize], + "_color_rev", + ReduceGroupByType::Implicit, + ), ); strata_2.add_output_relation("color_count"); - strata_2.run(&ctx, &mut rt) + strata_2.run(&ctx, &rt) }; println!("{:?}", result_2) @@ -50,15 +58,15 @@ fn test_simple_probability_count() { #[test] fn test_min_max_prob_count_max() { - let mut ctx = min_max_prob::MinMaxProbProvenance::default(); - let mut rt = RuntimeEnvironment::default(); + let ctx = min_max_prob::MinMaxProbProvenance::default(); + let rt = RuntimeEnvironment::default(); let result_1 = { let mut strata_1 = DynamicIteration::::new(); strata_1.create_dynamic_relation("color"); strata_1.create_dynamic_relation("_color_rev"); strata_1.get_dynamic_relation_unsafe("color").insert_tagged( - &mut ctx, + &ctx, vec![ (Some(0.6), (0usize, "blue")), (Some(0.4), (0, "green")), @@ -73,7 +81,7 @@ fn test_min_max_prob_count_max() { Dataflow::relation("color").project((Expr::access(1), Expr::access(0))), ); strata_1.add_output_relation("_color_rev"); - strata_1.run(&ctx, &mut rt) + strata_1.run(&ctx, &rt) }; let result_2 = { @@ -82,10 +90,18 @@ fn test_min_max_prob_count_max() { strata_2.add_input_dynamic_collection("_color_rev", &result_1["_color_rev"]); strata_2.add_update_dataflow( "color_count", - Dataflow::reduce(AggregateOp::count(), "_color_rev", ReduceGroupByType::Implicit), + Dataflow::reduce( + "count".to_string(), + vec![], + false, + vec![], + vec![ValueType::USize], + 
"_color_rev", + ReduceGroupByType::Implicit, + ), ); strata_2.add_output_relation("color_count"); - strata_2.run(&ctx, &mut rt) + strata_2.run(&ctx, &rt) }; println!("{:?}", result_2); @@ -96,10 +112,18 @@ fn test_min_max_prob_count_max() { strata_3.add_input_dynamic_collection("color_count", &result_2["color_count"]); strata_3.add_update_dataflow( "max_color", - Dataflow::reduce(AggregateOp::Argmax, "color_count", ReduceGroupByType::None), + Dataflow::reduce( + "max".to_string(), + vec![], + false, + vec![ValueType::Str], + vec![ValueType::USize], + "color_count", + ReduceGroupByType::None, + ), ); strata_3.add_output_relation("max_color"); - strata_3.run(&ctx, &mut rt) + strata_3.run(&ctx, &rt) }; println!("{:?}", result_2["color_count"]); diff --git a/core/tests/runtime/provenance/top_bottom_k.rs b/core/tests/runtime/provenance/top_bottom_k.rs index 5b265f1..c35ce2d 100644 --- a/core/tests/runtime/provenance/top_bottom_k.rs +++ b/core/tests/runtime/provenance/top_bottom_k.rs @@ -7,7 +7,7 @@ mod diff { #[test] fn test_diff_top_bottom_k_clauses_1() { - let ctx = DiffTopBottomKClausesProvenance::<(), RcFamily>::new(1); + let ctx = DiffTopBottomKClausesProvenance::<(), RcFamily>::new(1, false); // Create a few tags let a = ctx.tagging_fn((0.9, (), None).into()); @@ -29,7 +29,7 @@ mod diff { #[test] fn test_diff_top_bottom_k_clauses_2() { - let ctx = DiffTopBottomKClausesProvenance::<(), RcFamily>::new(1); + let ctx = DiffTopBottomKClausesProvenance::<(), RcFamily>::new(1, false); // Create a few tags let a = ctx.tagging_fn((0.1, (), None).into()); diff --git a/core/tests/runtime/statics/iteration.rs b/core/tests/runtime/statics/iteration.rs index 9155e4a..375bdd5 100644 --- a/core/tests/runtime/statics/iteration.rs +++ b/core/tests/runtime/statics/iteration.rs @@ -1,11 +1,13 @@ +use scallop_core::runtime::env::*; use scallop_core::runtime::provenance::*; use scallop_core::runtime::statics::*; use scallop_core::testing::*; #[test] fn test_static_iter_edge_path() { + 
let env = RuntimeEnvironment::default(); let mut prov = unit::UnitProvenance::default(); - let mut iter = StaticIteration::::new(&mut prov); + let mut iter = StaticIteration::::new(&env, &mut prov); // Add relations let edge = iter.create_relation::<(usize, usize)>(); @@ -30,6 +32,7 @@ fn test_static_iter_edge_path() { #[test] fn test_static_iter_odd_even_3() { + let env = RuntimeEnvironment::default(); let mut prov = unit::UnitProvenance::default(); struct Stratum0Result { @@ -38,8 +41,8 @@ fn test_static_iter_odd_even_3() { _numbers_perm_0_0: StaticCollection<(i32, i32), C>, } - fn stratum_0(prov: &mut C) -> Stratum0Result { - let mut iter = StaticIteration::::new(prov); + fn stratum_0(env: &RuntimeEnvironment, prov: &mut C) -> Stratum0Result { + let mut iter = StaticIteration::::new(env, prov); let numbers = iter.create_relation::<(i32,)>(); let _numbers_perm_0_ = iter.create_relation::<(i32, ())>(); let _numbers_perm_0_0 = iter.create_relation::<(i32, i32)>(); @@ -70,8 +73,12 @@ fn test_static_iter_odd_even_3() { odd: StaticCollection<(i32,), C>, } - fn stratum_1(prov: &mut C, stratum_0_result: &Stratum0Result) -> Stratum1Result { - let mut iter = StaticIteration::::new(prov); + fn stratum_1( + env: &RuntimeEnvironment, + prov: &mut C, + stratum_0_result: &Stratum0Result, + ) -> Stratum1Result { + let mut iter = StaticIteration::::new(env, prov); let _temp_0 = iter.create_relation::<(i32, i32)>(); let odd = iter.create_relation::<(i32,)>(); let _odd_perm_0 = iter.create_relation::<(i32, ())>(); @@ -113,11 +120,12 @@ fn test_static_iter_odd_even_3() { } fn stratum_2( + env: &RuntimeEnvironment, prov: &mut C, stratum_0_result: &Stratum0Result, stratum_1_result: &Stratum1Result, ) -> Stratum2Result { - let mut iter = StaticIteration::::new(prov); + let mut iter = StaticIteration::::new(env, prov); let even = iter.create_relation::<(i32,)>(); while iter.changed() || iter.is_first_iteration() { iter.insert_dataflow( @@ -141,9 +149,9 @@ fn 
test_static_iter_odd_even_3() { } // Execute - let stratum_0_result = stratum_0(&mut prov); - let stratum_1_result = stratum_1(&mut prov, &stratum_0_result); - let stratum_2_result = stratum_2(&mut prov, &stratum_0_result, &stratum_1_result); + let stratum_0_result = stratum_0(&env, &mut prov); + let stratum_1_result = stratum_1(&env, &mut prov, &stratum_0_result); + let stratum_2_result = stratum_2(&env, &mut prov, &stratum_0_result, &stratum_1_result); // Check result expect_static_collection(&stratum_1_result.odd, vec![(1,), (3,), (5,), (7,), (9,)]); @@ -157,8 +165,8 @@ fn test_static_out_degree_join() { edge: StaticCollection<(usize, usize), C>, } - fn stratum_0(prov: &mut C) -> Stratum0Result { - let mut iter = StaticIteration::::new(prov); + fn stratum_0(env: &RuntimeEnvironment, prov: &mut C) -> Stratum0Result { + let mut iter = StaticIteration::::new(env, prov); // Add relations let node = iter.create_relation::<(usize,)>(); @@ -186,8 +194,12 @@ fn test_static_out_degree_join() { out_degree: StaticCollection<(usize, usize), C>, } - fn stratum_1(prov: &mut C, stratum_0_result: &Stratum0Result) -> Stratum1Result { - let mut iter = StaticIteration::::new(prov); + fn stratum_1( + env: &RuntimeEnvironment, + prov: &mut C, + stratum_0_result: &Stratum0Result, + ) -> Stratum1Result { + let mut iter = StaticIteration::::new(env, prov); // Add relations let out_degree = iter.create_relation::<(usize, usize)>(); @@ -198,7 +210,7 @@ fn test_static_out_degree_join() { &out_degree, dataflow::project( iter.aggregate_join_group( - CountAggregator::new(), + CountAggregator::new(false), dataflow::collection(&stratum_0_result.node_temp, iter.is_first_iteration()), dataflow::collection(&stratum_0_result.edge, iter.is_first_iteration()), ), @@ -213,11 +225,12 @@ fn test_static_out_degree_join() { } } + let env = RuntimeEnvironment::default(); let mut prov = unit::UnitProvenance::default(); // Execute - let stratum_0_result = stratum_0(&mut prov); - let stratum_1_result = 
stratum_1(&mut prov, &stratum_0_result); + let stratum_0_result = stratum_0(&env, &mut prov); + let stratum_1_result = stratum_1(&env, &mut prov, &stratum_0_result); // Check result expect_static_collection(&stratum_1_result.out_degree, vec![(0, 1), (1, 1), (2, 0)]); @@ -229,8 +242,8 @@ fn test_static_out_degree_implicit_group() { edge: StaticCollection<(usize, usize), C>, } - fn stratum_0(prov: &mut C) -> Stratum0Result { - let mut iter = StaticIteration::::new(prov); + fn stratum_0(env: &RuntimeEnvironment, prov: &mut C) -> Stratum0Result { + let mut iter = StaticIteration::::new(env, prov); // Add relations let edge = iter.create_relation::<(usize, usize)>(); @@ -253,8 +266,12 @@ fn test_static_out_degree_implicit_group() { out_degree: StaticCollection<(usize, usize), C>, } - fn stratum_1(prov: &mut C, stratum_0_result: &Stratum0Result) -> Stratum1Result { - let mut iter = StaticIteration::::new(prov); + fn stratum_1( + env: &RuntimeEnvironment, + prov: &mut C, + stratum_0_result: &Stratum0Result, + ) -> Stratum1Result { + let mut iter = StaticIteration::::new(env, prov); // Add relations let out_degree = iter.create_relation::<(usize, usize)>(); @@ -264,7 +281,7 @@ fn test_static_out_degree_implicit_group() { iter.insert_dataflow( &out_degree, iter.aggregate_implicit_group( - CountAggregator::new(), + CountAggregator::new(false), dataflow::collection(&stratum_0_result.edge, iter.is_first_iteration()), ), ); @@ -276,11 +293,12 @@ fn test_static_out_degree_implicit_group() { } } + let env = RuntimeEnvironment::default(); let mut prov = unit::UnitProvenance::default(); // Execute - let stratum_0_result = stratum_0(&mut prov); - let stratum_1_result = stratum_1(&mut prov, &stratum_0_result); + let stratum_0_result = stratum_0(&env, &mut prov); + let stratum_1_result = stratum_1(&env, &mut prov, &stratum_0_result); // Check result expect_static_collection(&stratum_1_result.out_degree, vec![(0, 1), (1, 1)]); @@ -292,8 +310,8 @@ fn test_static_num_edges() { edge: 
StaticCollection<(usize, usize), C>, } - fn stratum_0(prov: &mut C) -> Stratum0Result { - let mut iter = StaticIteration::::new(prov); + fn stratum_0(env: &RuntimeEnvironment, prov: &mut C) -> Stratum0Result { + let mut iter = StaticIteration::::new(env, prov); // Add relations let edge = iter.create_relation::<(usize, usize)>(); @@ -316,8 +334,12 @@ fn test_static_num_edges() { num_edges: StaticCollection, } - fn stratum_1(prov: &mut C, stratum_0_result: &Stratum0Result) -> Stratum1Result { - let mut iter = StaticIteration::::new(prov); + fn stratum_1( + env: &RuntimeEnvironment, + prov: &mut C, + stratum_0_result: &Stratum0Result, + ) -> Stratum1Result { + let mut iter = StaticIteration::::new(env, prov); // Add relations let num_edges = iter.create_relation::(); @@ -327,7 +349,7 @@ fn test_static_num_edges() { iter.insert_dataflow( &num_edges, iter.aggregate( - CountAggregator::new(), + CountAggregator::new(false), dataflow::collection(&stratum_0_result.edge, iter.is_first_iteration()), ), ); @@ -339,11 +361,12 @@ fn test_static_num_edges() { } } + let env = RuntimeEnvironment::default(); let mut prov = unit::UnitProvenance::default(); // Execute - let stratum_0_result = stratum_0(&mut prov); - let stratum_1_result = stratum_1(&mut prov, &stratum_0_result); + let stratum_0_result = stratum_0(&env, &mut prov); + let stratum_1_result = stratum_1(&env, &mut prov, &stratum_0_result); // Check result expect_static_collection(&stratum_1_result.num_edges, vec![3]); diff --git a/doc/src/language/adt_and_entity.md b/doc/src/language/adt_and_entity.md index 8c268c4..756c7dd 100644 --- a/doc/src/language/adt_and_entity.md +++ b/doc/src/language/adt_and_entity.md @@ -445,7 +445,7 @@ Lastly, we use the aggregation to find the equivalent programs with the minimum Note that we have used an `argmax` aggregation denoted by `min[p]` here: ``` scl -rel best_program(p) = _ := min[p](w: input_expr(e) and equivalent(e, p) and weight(p, w)) +rel best_program(p) = p := argmin[p](w: 
input_expr(e) and equivalent(e, p) and weight(p, w)) ``` If we query for the best program and turn it into string, we will get our expected output, a single variable `"a"`! diff --git a/doc/src/language/aggregation.md b/doc/src/language/aggregation.md index 53cbe2e..f564082 100644 --- a/doc/src/language/aggregation.md +++ b/doc/src/language/aggregation.md @@ -154,7 +154,7 @@ It is also possible to get argmax/argmin. Suppose we want to get the person (along with their grade) who scored the best, we write: ``` scl -rel best_student(n, s) = s := max[n](s: exam_grades(n, s)) +rel best_student(n, s) = (n, s) := max[n](s: exam_grades(n, s)) ``` Here, we are still finding the maximum score `s`, but along with `max` we have specified the "arg" (`[n]`) which associates with the maximum score. @@ -163,8 +163,14 @@ The arg variable is grounded by the aggregation body, and can be directly used i If we do not care about the grade and just want to know who has the best grade, we can use wildcard `_` to ignore the result variable, like +``` scl +rel best_student(n) = (n, _) := max[n](s: exam_grades(n, s)) ``` -rel best_student(n) = _ := max[n](s: exam_grades(n, s)) + +Alternatively, we can also use `argmax`: + +``` scl +rel best_student(n) = n := argmax[n](s: exam_grades(n, s)) ``` ## Exists and Forall diff --git a/etc/codegen/Cargo.toml b/etc/codegen/Cargo.toml index 5851a0a..a4699bf 100644 --- a/etc/codegen/Cargo.toml +++ b/etc/codegen/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "scallop-codegen" -version = "0.2.0" +version = "0.2.1" authors = ["Ziyang Li "] edition = "2018" diff --git a/etc/codegen/examples/digit_sum_2_codegen.rs b/etc/codegen/examples/digit_sum_2_codegen.rs index d2c4e92..9d46c33 100644 --- a/etc/codegen/examples/digit_sum_2_codegen.rs +++ b/etc/codegen/examples/digit_sum_2_codegen.rs @@ -12,7 +12,7 @@ mod sum_2 { fn main() { // First set the top-k-proofs provenance context - let mut ctx = top_k_proofs::TopKProofsProvenance::new(3); + let mut ctx = 
top_k_proofs::TopKProofsProvenance::new(3, false); // Then create an edb and populate facts inside of it let mut edb = sum_2::create_edb::(); diff --git a/etc/codegen/tests/codegen_basic.rs b/etc/codegen/tests/codegen_basic.rs index cb172c4..4853f4c 100644 --- a/etc/codegen/tests/codegen_basic.rs +++ b/etc/codegen/tests/codegen_basic.rs @@ -300,6 +300,24 @@ fn codegen_out_degree_2() { expect_static_output_collection(&result.out_degree, vec![(0, 1), (1, 1), (2, 0)]); } +#[test] +fn codegen_exists_1() { + mod exists_1 { + use scallop_codegen::scallop; + scallop! { + rel edge = {(0, 1), (1, 2)} + rel path(x, y) = edge(x, y) or (path(x, z) and edge(z, y)) + rel result1(b) = b := exists(path(0, 2)) + rel result2(b) = b := exists(path(0, 3)) + } + } + + let mut ctx = unit::UnitProvenance::default(); + let result = exists_1::run(&mut ctx); + expect_static_output_collection(&result.result1, vec![(true,)]); + expect_static_output_collection(&result.result2, vec![(false,)]); +} + #[test] fn codegen_sum_1() { mod sum_1 { @@ -506,7 +524,7 @@ fn codegen_srl_1() { noun(n, _, "Person"), synonym(vid, vid0) - rel how_many_play_soccer(c) = c = count(n: plays_soccer(n)) + rel how_many_play_soccer(c) = c := count(n: plays_soccer(n)) } } let mut ctx = unit::UnitProvenance::default(); @@ -528,7 +546,7 @@ fn codegen_class_student_grade_1() { (1, "frank", 30), } - rel class_top_student(c, s) = _ = max[s](g: class_student_grade(c, s, g)) + rel class_top_student(c, s) = s := argmax[s](g: class_student_grade(c, s, g)) } } let mut ctx = unit::UnitProvenance::default(); diff --git a/etc/scallop-cli/setup.cfg b/etc/scallop-cli/setup.cfg index 56ebd2b..88b4401 100644 --- a/etc/scallop-cli/setup.cfg +++ b/etc/scallop-cli/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = scallop -version = 0.2.0 +version = 0.2.1 author = Ziyang Li author_email = liby99@seas.upenn.edu description = Scallop CLI diff --git a/etc/scallop-wasm/Cargo.toml b/etc/scallop-wasm/Cargo.toml index 59cae89..97e915f 100644 --- 
a/etc/scallop-wasm/Cargo.toml +++ b/etc/scallop-wasm/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "scallop-wasm" -version = "0.2.0" +version = "0.2.1" authors = ["Ziyang Li"] edition = "2018" diff --git a/etc/scallop-wasm/src/lib.rs b/etc/scallop-wasm/src/lib.rs index 6e5dbcf..be3cc86 100644 --- a/etc/scallop-wasm/src/lib.rs +++ b/etc/scallop-wasm/src/lib.rs @@ -27,7 +27,7 @@ pub fn interpret_with_minmaxprob(source: String) -> String { #[wasm_bindgen] pub fn interpret_with_topkproofs(source: String, top_k: usize) -> String { - let ctx = top_k_proofs::TopKProofsProvenance::::new(top_k); + let ctx = top_k_proofs::TopKProofsProvenance::::new(top_k, false); match interpret_string_with_ctx(source, ctx) { Ok(result) => result .into_iter() @@ -42,7 +42,7 @@ pub fn interpret_with_topkproofs(source: String, top_k: usize) -> String { #[wasm_bindgen] pub fn interpret_with_topbottomkclauses(source: String, k: usize) -> String { - let ctx = top_bottom_k_clauses::TopBottomKClausesProvenance::::new(k); + let ctx = top_bottom_k_clauses::TopBottomKClausesProvenance::::new(k, false); match interpret_string_with_ctx(source, ctx) { Ok(result) => result .into_iter() diff --git a/etc/scallopy-ext/setup.cfg b/etc/scallopy-ext/setup.cfg index 71bc5aa..12ada4d 100644 --- a/etc/scallopy-ext/setup.cfg +++ b/etc/scallopy-ext/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = scallopy_ext -version = 0.2.0 +version = 0.2.1 author = Ziyang Li author_email = liby99@seas.upenn.edu description = Scallopy Extension diff --git a/etc/scallopy-ext/src/scallopy_ext/registry.py b/etc/scallopy-ext/src/scallopy_ext/registry.py index 48e2956..87bb160 100644 --- a/etc/scallopy-ext/src/scallopy_ext/registry.py +++ b/etc/scallopy-ext/src/scallopy_ext/registry.py @@ -4,25 +4,27 @@ import scallopy +from . 
import utils + class PluginRegistry: def __init__(self): - self.setup_argparse_functions = entry_points(group="scallop.plugin.setup_arg_parser") - self.configure_functions = entry_points(group="scallop.plugin.configure") - self.loading_functions = entry_points(group="scallop.plugin.load_into_context") + self.setup_argparse_functions = utils.dedup(entry_points(group="scallop.plugin.setup_arg_parser")) + self.configure_functions = utils.dedup(entry_points(group="scallop.plugin.configure")) + self.loading_functions = utils.dedup(entry_points(group="scallop.plugin.load_into_context")) self.unknown_args = {} def loaded_plugins(self) -> List[str]: all_plugins = set() for setup_module in self.setup_argparse_functions: - all_plugins.add(setup_module.name) - for setup_module in self.setup_argparse_functions: - all_plugins.add(setup_module.name) - for setup_module in self.setup_argparse_functions: - all_plugins.add(setup_module.name) + all_plugins.add(f"{setup_module.name}::setup_arg_parser") + for config_module in self.configure_functions: + all_plugins.add(f"{config_module.name}::configure") + for loading_module in self.loading_functions: + all_plugins.add(f"{loading_module.name}::load_into_context") return list(all_plugins) def dump_loaded_plugins(self): - print("[scallopy-ext] Loaded plugins:", self.loaded_plugins()) + print("[scallopy-ext] Loaded plugins:", ", ".join(self.loaded_plugins())) def setup_argument_parser(self, parser: ArgumentParser): for setup_module in self.setup_argparse_functions: diff --git a/etc/scallopy-ext/src/scallopy_ext/utils.py b/etc/scallopy-ext/src/scallopy_ext/utils.py new file mode 100644 index 0000000..54f4672 --- /dev/null +++ b/etc/scallopy-ext/src/scallopy_ext/utils.py @@ -0,0 +1,2 @@ +def dedup(elems): + return list(set(elems)) diff --git a/etc/scallopy/Cargo.toml b/etc/scallopy/Cargo.toml index dfab86e..04f6f6a 100644 --- a/etc/scallopy/Cargo.toml +++ b/etc/scallopy/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "scallopy" -version = 
"0.2.0" +version = "0.2.1" edition = "2018" [lib] diff --git a/etc/scallopy/scallopy/context.py b/etc/scallopy/scallopy/context.py index 0b65cfe..8d4de89 100644 --- a/etc/scallopy/scallopy/context.py +++ b/etc/scallopy/scallopy/context.py @@ -60,10 +60,12 @@ def __init__( provenance: str = "unit", custom_provenance: Optional[ScallopProvenance] = None, k: int = 3, + wmc_with_disjunctions: bool = False, train_k: Optional[int] = None, test_k: Optional[int] = None, fork_from: Optional[ScallopContext] = None, no_stdlib: bool = False, + monitors: List[str] = [], ): super(ScallopContext, self).__init__() @@ -98,15 +100,21 @@ def __init__( self._mutual_exclusion_counter = Counter() self._sample_facts = {} self._k = k + self._wmc_with_disjunctions = wmc_with_disjunctions self._train_k = train_k self._test_k = test_k self._history_actions: List[HistoryAction] = [] + self._monitors = monitors self._internal = InternalScallopContext(provenance=provenance, custom_provenance=custom_provenance, k=k) # Load stdlib self._internal.enable_tensor_registry() # Always enable tensor registry for now if not no_stdlib: self.load_stdlib() + + # Load monitors + if len(self._monitors) > 0: + self._internal.add_monitors(self._monitors) else: # Fork from an existing context self.provenance = deepcopy(fork_from.provenance) @@ -118,9 +126,11 @@ def __init__( self._mutual_exclusion_counter = deepcopy(fork_from._mutual_exclusion_counter) self._sample_facts = deepcopy(fork_from._sample_facts) self._k = deepcopy(fork_from._k) + self._wmc_with_disjunctions = deepcopy(fork_from._wmc_with_disjunctions) self._train_k = deepcopy(fork_from._train_k) self._test_k = deepcopy(fork_from._test_k) self._history_actions = deepcopy(fork_from._history_actions) + self._monitors = deepcopy(fork_from._monitors) self._internal = fork_from._internal.clone() def __getstate__(self): @@ -142,14 +152,19 @@ def __setstate__(self, state): self._history_actions: List[HistoryAction] = [] # Internal scallop context - 
self._internal = InternalScallopContext(provenance=self.provenance, custom_provenance=self._custom_provenance, k=self._k) + self._internal = InternalScallopContext(provenance=self.provenance, custom_provenance=self._custom_provenance, k=self._k, wmc_with_disjunctions=self._wmc_with_disjunctions) # Restore from history actions for history_action in state["_history_actions"]: function = getattr(self, history_action.func_name) function(*history_action.pos_args, **history_action.kw_args) - def clone(self, provenance=None, k=None) -> ScallopContext: + def clone( + self, + provenance: Optional[str] = None, + k: Optional[int] = None, + wmc_with_disjunctions: Optional[bool] = None, + ) -> ScallopContext: """ Clone the current context. This is useful for incremental execution: @@ -175,7 +190,8 @@ def clone(self, provenance=None, k=None) -> ScallopContext: # Clone internal context; this process may fail if the provenance is not compatible new_k = k if k is not None else self._k - new_ctx._internal = new_ctx._internal.clone_with_new_provenance(provenance, new_k) + new_wmc_with_disjunctions = wmc_with_disjunctions if wmc_with_disjunctions is not None else self._wmc_with_disjunctions + new_ctx._internal = new_ctx._internal.clone_with_new_provenance(provenance, new_k, new_wmc_with_disjunctions) # Update parameters related to provenance new_ctx.provenance = provenance @@ -552,7 +568,7 @@ def get_front_ir(self): """ return self._internal.get_front_ir() - def relation(self, relation: str, debug: bool = False) -> ScallopCollection: + def relation(self, relation: str) -> ScallopCollection: """ Inspect the (computed) relation in the context. Will return a `ScallopCollection` which is iterable. 
@@ -564,7 +580,7 @@ def relation(self, relation: str, debug: bool = False) -> ScallopCollection: :param relation: the name of the relation """ - int_col = self._internal.relation(relation) if not debug else self._internal.relation_with_debug_tag(relation) + int_col = self._internal.relation(relation) return ScallopCollection(self.provenance, int_col) def has_relation(self, relation: str) -> bool: diff --git a/etc/scallopy/scallopy/forward.py b/etc/scallopy/scallopy/forward.py index 381a7aa..bcd9aa8 100644 --- a/etc/scallopy/scallopy/forward.py +++ b/etc/scallopy/scallopy/forward.py @@ -38,11 +38,18 @@ def __init__( jit_name: str = "", jit_recompile: bool = False, dispatch: str = "parallel", + monitors: List[str] = [], ): super(ScallopForwardFunction, self).__init__() # Setup the context - self.ctx = ScallopContext(provenance=provenance, custom_provenance=custom_provenance, k=k, train_k=train_k, test_k=test_k) + self.ctx = ScallopContext( + provenance=provenance, + custom_provenance=custom_provenance, + k=k, + train_k=train_k, + test_k=test_k, + monitors=monitors) # Import the file if specified if file is not None: diff --git a/etc/scallopy/scallopy/scallopy.pyi b/etc/scallopy/scallopy/scallopy.pyi index aed56f0..b20edcf 100644 --- a/etc/scallopy/scallopy/scallopy.pyi +++ b/etc/scallopy/scallopy/scallopy.pyi @@ -16,6 +16,7 @@ class InternalScallopContext: self, provenance: str = "unit", k: int = 3, + wmc_with_disjunctions: bool = False, custom_provenance: Optional[ScallopProvenance] = None, ) -> None: ... @@ -28,6 +29,7 @@ class InternalScallopContext: provenance: str, custom_provenance: Any, k: int, + wmc_with_disjunctions: bool, ) -> InternalScallopContext: ... def set_non_incremental(self): ... @@ -44,6 +46,8 @@ class InternalScallopContext: def remove_iter_limit(self): ... + def add_monitors(self, monitors: List[str]): ... + def run(self, iter_limit: Optional[int]) -> None: ... def run_with_debug_tag(self, iter_limit: Optional[int]) -> None: ... 
diff --git a/etc/scallopy/src/collection.rs b/etc/scallopy/src/collection.rs index d0b8718..d5115c0 100644 --- a/etc/scallopy/src/collection.rs +++ b/etc/scallopy/src/collection.rs @@ -213,13 +213,13 @@ impl CollectionIterator { slf.current_index += 1; if slf.collection.has_empty_tag() { if let Some(tuple) = to_python_tuple(slf.collection.ith_tuple(i), &slf.env) { - return IterNextOutput::Yield(tuple) + return IterNextOutput::Yield(tuple); } } else { let tuple = to_python_tuple(slf.collection.ith_tuple(i), &slf.env); let tag = slf.collection.ith_tag(i); let elem = Python::with_gil(|py| (tag, tuple).to_object(py)); - return IterNextOutput::Yield(elem) + return IterNextOutput::Yield(elem); } } IterNextOutput::Return("Ended") diff --git a/etc/scallopy/src/context.rs b/etc/scallopy/src/context.rs index 586e2c2..24912a9 100644 --- a/etc/scallopy/src/context.rs +++ b/etc/scallopy/src/context.rs @@ -107,8 +107,13 @@ impl Context { /// * `k` - an unsigned integer serving as the hyper-parameter for provenance such as `"topkproofs"` /// * `custom_provenance` - an optional python object serving as the provenance context #[new] - #[pyo3(signature=(provenance="unit", k=3, custom_provenance=None))] - fn new(provenance: &str, k: usize, custom_provenance: Option>) -> Result { + #[pyo3(signature=(provenance="unit", k=3, wmc_with_disjunctions=false, custom_provenance=None))] + fn new( + provenance: &str, + k: usize, + wmc_with_disjunctions: bool, + custom_provenance: Option>, + ) -> Result { // Check provenance type match provenance { "unit" => Ok(Self { @@ -129,12 +134,12 @@ impl Context { }), "topkproofs" => Ok(Self { ctx: ContextEnum::TopKProofs(IntegrateContext::new_incremental( - top_k_proofs::TopKProofsProvenance::new(k), + top_k_proofs::TopKProofsProvenance::new(k, wmc_with_disjunctions), )), }), "topbottomkclauses" => Ok(Self { ctx: ContextEnum::TopBottomKClauses(IntegrateContext::new_incremental( - top_bottom_k_clauses::TopBottomKClausesProvenance::new(k), + 
top_bottom_k_clauses::TopBottomKClausesProvenance::new(k, wmc_with_disjunctions), )), }), "diffminmaxprob" => Ok(Self { @@ -169,12 +174,12 @@ impl Context { }), "difftopkproofs" => Ok(Self { ctx: ContextEnum::DiffTopKProofs(IntegrateContext::new_incremental( - diff_top_k_proofs::DiffTopKProofsProvenance::new(k), + diff_top_k_proofs::DiffTopKProofsProvenance::new(k, wmc_with_disjunctions), )), }), "difftopbottomkclauses" => Ok(Self { ctx: ContextEnum::DiffTopBottomKClauses(IntegrateContext::new_incremental( - diff_top_bottom_k_clauses::DiffTopBottomKClausesProvenance::new(k), + diff_top_bottom_k_clauses::DiffTopBottomKClausesProvenance::new(k, wmc_with_disjunctions), )), }), "custom" => { @@ -196,7 +201,7 @@ impl Context { } /// Create a new scallop context with a different provenance as the current context - fn clone_with_new_provenance(&self, provenance: &str, k: usize) -> Result { + fn clone_with_new_provenance(&self, provenance: &str, k: usize, wmc_with_disjunctions: bool) -> Result { // Check provenance type match provenance { "unit" => Ok(Self { @@ -231,14 +236,14 @@ impl Context { ctx: ContextEnum::TopKProofs(match_context_except_custom!( &self.ctx, c, - c.clone_with_new_provenance(top_k_proofs::TopKProofsProvenance::new(k)) + c.clone_with_new_provenance(top_k_proofs::TopKProofsProvenance::new(k, wmc_with_disjunctions)) )?), }), "topbottomkclauses" => Ok(Self { ctx: ContextEnum::TopBottomKClauses(match_context_except_custom!( &self.ctx, c, - c.clone_with_new_provenance(top_bottom_k_clauses::TopBottomKClausesProvenance::new(k),) + c.clone_with_new_provenance(top_bottom_k_clauses::TopBottomKClausesProvenance::new(k, wmc_with_disjunctions),) )?), }), "diffminmaxprob" => Ok(Self { @@ -287,14 +292,14 @@ impl Context { ctx: ContextEnum::DiffTopKProofs(match_context_except_custom!( &self.ctx, c, - c.clone_with_new_provenance(diff_top_k_proofs::DiffTopKProofsProvenance::new(k),) + c.clone_with_new_provenance(diff_top_k_proofs::DiffTopKProofsProvenance::new(k, 
wmc_with_disjunctions),) )?), }), "difftopbottomkclauses" => Ok(Self { ctx: ContextEnum::DiffTopBottomKClauses(match_context_except_custom!( &self.ctx, c, - c.clone_with_new_provenance(diff_top_bottom_k_clauses::DiffTopBottomKClausesProvenance::new(k),) + c.clone_with_new_provenance(diff_top_bottom_k_clauses::DiffTopBottomKClausesProvenance::new(k, wmc_with_disjunctions),) )?), }), "custom" => Err(BindingError::CustomProvenanceUnsupported), @@ -338,6 +343,11 @@ impl Context { match_context!(&mut self.ctx, c, c.remove_iter_limit()) } + /// Add monitors to the system + fn add_monitors(&mut self, monitors: Vec<&str>) { + match_context!(&mut self.ctx, c, c.add_monitors(&monitors)) + } + /// Compile the surface program stored in the scallopy context into the ram program. /// /// This function is usually used before creating a forward function. diff --git a/etc/scallopy/src/foreign_attribute.rs b/etc/scallopy/src/foreign_attribute.rs index c7d22e4..1ac5d69 100644 --- a/etc/scallopy/src/foreign_attribute.rs +++ b/etc/scallopy/src/foreign_attribute.rs @@ -73,19 +73,17 @@ impl AttributeProcessor for PythonForeignAttribute { fn apply(&self, item: &ast::Item, attr: &ast::Attribute) -> Result { Python::with_gil(|py| { - let item_py = pythonize::pythonize(py, item) - .map_err(|e| AttributeError::Custom { - msg: format!("Error pythonizing item: {e}") - })?; - let attr_py = pythonize::pythonize(py, attr) - .map_err(|e| AttributeError::Custom { - msg: format!("Error pythonizing attribute: {e}") - })?; + let item_py = pythonize::pythonize(py, item).map_err(|e| AttributeError::Custom { + msg: format!("Error pythonizing item: {e}"), + })?; + let attr_py = pythonize::pythonize(py, attr).map_err(|e| AttributeError::Custom { + msg: format!("Error pythonizing attribute: {e}"), + })?; let args = PyTuple::new(py, vec![item_py, attr_py]); let result = self.py_attr.call_method(py, "apply", args, None).map_err(|e| { e.print(py); AttributeError::Custom { - msg: format!("Error applying 
attribute: {e}") + msg: format!("Error applying attribute: {e}"), } })?; Ok(self.process_action(py, result)) diff --git a/etc/scallopy/src/foreign_function.rs b/etc/scallopy/src/foreign_function.rs index 9aa58f3..945b3b7 100644 --- a/etc/scallopy/src/foreign_function.rs +++ b/etc/scallopy/src/foreign_function.rs @@ -22,17 +22,13 @@ impl PythonForeignFunction { /// Create a new PythonForeignFunction pub fn new(ff: PyObject) -> Self { let suppress_warning = Python::with_gil(|py| { - ff - .getattr(py, "suppress_warning") + ff.getattr(py, "suppress_warning") .expect("Cannot get foreign function generic type parameters") .extract(py) .expect("`suppress_warning` cannot be extracted into boolean") }); - Self { - ff, - suppress_warning, - } + Self { ff, suppress_warning } } } diff --git a/etc/scallopy/src/foreign_predicate.rs b/etc/scallopy/src/foreign_predicate.rs index 81486f8..42b5a2a 100644 --- a/etc/scallopy/src/foreign_predicate.rs +++ b/etc/scallopy/src/foreign_predicate.rs @@ -25,12 +25,14 @@ pub struct PythonForeignPredicate { impl PythonForeignPredicate { pub fn new(fp: PyObject) -> Self { Python::with_gil(|py| { - let name = fp.getattr(py, "name") + let name = fp + .getattr(py, "name") .expect("Cannot get foreign predicate name") .extract(py) .expect("Foreign predicate name cannot be extracted into String"); - let suppress_warning = fp.getattr(py, "suppress_warning") + let suppress_warning = fp + .getattr(py, "suppress_warning") .expect("Cannot get foreign predicate `suppress_warning`") .extract(py) .expect("Foreign predicate `suppress_warning` cannot be extracted into bool"); diff --git a/etc/scallopy/src/io.rs b/etc/scallopy/src/io.rs index 83189b7..3c7c253 100644 --- a/etc/scallopy/src/io.rs +++ b/etc/scallopy/src/io.rs @@ -33,13 +33,22 @@ impl Into for CSVFileOptions { args.push(AttributeArgument::named_bool("has_header", self.has_header)); args.push(AttributeArgument::named_bool("has_probability", self.has_probability)); if let Some(keys) = self.keys { - 
args.push(AttributeArgument::named_list("keys", keys.iter().cloned().map(AttributeValue::string).collect())); + args.push(AttributeArgument::named_list( + "keys", + keys.iter().cloned().map(AttributeValue::string).collect(), + )); } if let Some(fields) = self.fields { - args.push(AttributeArgument::named_list("fields", fields.iter().cloned().map(AttributeValue::string).collect())); + args.push(AttributeArgument::named_list( + "fields", + fields.iter().cloned().map(AttributeValue::string).collect(), + )); } // Get attribute - Attribute { name: "file".to_string(), args } + Attribute { + name: "file".to_string(), + args, + } } } diff --git a/etc/scallopy/src/tag.rs b/etc/scallopy/src/tag.rs index a28207a..b5aed3f 100644 --- a/etc/scallopy/src/tag.rs +++ b/etc/scallopy/src/tag.rs @@ -1,10 +1,10 @@ use pyo3::types::*; -use scallop_core::common::input_tag::*; use scallop_core::common::foreign_tensor::*; +use scallop_core::common::input_tag::*; -use super::tensor::*; use super::error::*; +use super::tensor::*; pub fn from_python_input_tag(ty: &str, tag: &PyAny) -> Result { match ty { @@ -16,10 +16,10 @@ pub fn from_python_input_tag(ty: &str, tag: &PyAny) -> Result { let (prob, exc_id): (f64, usize) = tag.extract()?; Ok(DynamicInputTag::ExclusiveFloat(prob, exc_id)) - }, - "diff-prob" => { - Ok(DynamicInputTag::Tensor(DynamicExternalTensor::new(Tensor::from_py_value(tag.into())))) - }, - _ => Err(BindingError::InvalidInputTag) + } + "diff-prob" => Ok(DynamicInputTag::Tensor(DynamicExternalTensor::new( + Tensor::from_py_value(tag.into()), + ))), + _ => Err(BindingError::InvalidInputTag), } } diff --git a/etc/scallopy/src/tensor/torch/external_tensor.rs b/etc/scallopy/src/tensor/torch/external_tensor.rs index ca57e25..8a7ac28 100644 --- a/etc/scallopy/src/tensor/torch/external_tensor.rs +++ b/etc/scallopy/src/tensor/torch/external_tensor.rs @@ -1,5 +1,5 @@ -use std::any::Any; use pyo3::prelude::*; +use std::any::Any; use scallop_core::common::foreign_tensor::*; @@ -12,9 
+12,7 @@ pub struct TorchTensor { impl TorchTensor { pub fn new(p: Py) -> Self { - Self { - internal: p, - } + Self { internal: p } } pub fn internal(&self) -> Py { @@ -25,25 +23,35 @@ impl TorchTensor { impl ExternalTensor for TorchTensor { fn shape(&self) -> TensorShape { Python::with_gil(|py| { - let shape_tuple: Vec = self.internal.getattr(py, "shape").expect("Cannot get `.shape` from object").extract(py).expect("`.shape` is not a tuple"); + let shape_tuple: Vec = self + .internal + .getattr(py, "shape") + .expect("Cannot get `.shape` from object") + .extract(py) + .expect("`.shape` is not a tuple"); TensorShape::from(shape_tuple) }) } fn get_f64(&self) -> f64 { Python::with_gil(|py| { - self.internal.call_method0(py, "item").expect("Cannot call function `.item()`").extract(py).expect("Cannot turn `.item()` into f64") + self + .internal + .call_method0(py, "item") + .expect("Cannot call function `.item()`") + .extract(py) + .expect("Cannot turn `.item()` into f64") }) } - fn as_any(&self) -> &dyn Any { self } + fn as_any(&self) -> &dyn Any { + self + } } impl PyExternalTensor for TorchTensor { fn from_py_value(p: &PyAny) -> Self { - Self { - internal: p.into(), - } + Self { internal: p.into() } } fn to_py_value(&self) -> Py { diff --git a/etc/scallopy/src/tensor/torch/registry.rs b/etc/scallopy/src/tensor/torch/registry.rs index 8473e91..69e5922 100644 --- a/etc/scallopy/src/tensor/torch/registry.rs +++ b/etc/scallopy/src/tensor/torch/registry.rs @@ -1,5 +1,5 @@ -use std::collections::*; use pyo3::prelude::*; +use std::collections::*; use scallop_core::common::foreign_tensor::*; @@ -24,34 +24,66 @@ impl TorchTensorRegistry { } fn eval_expr_torch(&self, value: &TensorExpr) -> TorchTensor { - Python::with_gil(|py| { - match value { - TensorExpr::Symbol(s) => self.get_torch(s), - TensorExpr::Float(f) => { - let builtins = PyModule::import(py, "torch").expect("Cannot import torch"); - let result: Py = builtins.getattr("tensor").expect("Cannot get 
tensor").call1((*f,)).expect("Cannot create tensor").extract().expect("Cannot convert"); - TorchTensor::new(result) - } - TensorExpr::Add(a, b) => { - let (ta, tb) = (self.eval_expr_torch(&a), self.eval_expr_torch(&b)); - let result: Py = ta.internal().getattr(py, "__add__").expect("Cannot sum").call1(py, (tb.internal(),)).expect("Cannot sum").extract(py).expect("Cannot convert"); - TorchTensor::new(result) - } - TensorExpr::Sub(a, b) => { - let (ta, tb) = (self.eval_expr_torch(&a), self.eval_expr_torch(&b)); - let result: Py = ta.internal().getattr(py, "__sub__").expect("Cannot sub").call1(py, (tb.internal(),)).expect("Cannot sub").extract(py).expect("Cannot convert"); - TorchTensor::new(result) - } - TensorExpr::Mul(a, b) => { - let (ta, tb) = (self.eval_expr_torch(&a), self.eval_expr_torch(&b)); - let result: Py = ta.internal().getattr(py, "__mul__").expect("Cannot mul").call1(py, (tb.internal(),)).expect("Cannot mul").extract(py).expect("Cannot convert"); - TorchTensor::new(result) - } - TensorExpr::Dot(a, b) => { - let (ta, tb) = (self.eval_expr_torch(&a), self.eval_expr_torch(&b)); - let result: Py = ta.internal().getattr(py, "dot").expect("Cannot dot").call1(py, (tb.internal(),)).expect("Cannot dot").extract(py).expect("Cannot convert"); - TorchTensor::new(result) - } + Python::with_gil(|py| match value { + TensorExpr::Symbol(s) => self.get_torch(s), + TensorExpr::Float(f) => { + let builtins = PyModule::import(py, "torch").expect("Cannot import torch"); + let result: Py = builtins + .getattr("tensor") + .expect("Cannot get tensor") + .call1((*f,)) + .expect("Cannot create tensor") + .extract() + .expect("Cannot convert"); + TorchTensor::new(result) + } + TensorExpr::Add(a, b) => { + let (ta, tb) = (self.eval_expr_torch(&a), self.eval_expr_torch(&b)); + let result: Py = ta + .internal() + .getattr(py, "__add__") + .expect("Cannot sum") + .call1(py, (tb.internal(),)) + .expect("Cannot sum") + .extract(py) + .expect("Cannot convert"); + 
TorchTensor::new(result) + } + TensorExpr::Sub(a, b) => { + let (ta, tb) = (self.eval_expr_torch(&a), self.eval_expr_torch(&b)); + let result: Py = ta + .internal() + .getattr(py, "__sub__") + .expect("Cannot sub") + .call1(py, (tb.internal(),)) + .expect("Cannot sub") + .extract(py) + .expect("Cannot convert"); + TorchTensor::new(result) + } + TensorExpr::Mul(a, b) => { + let (ta, tb) = (self.eval_expr_torch(&a), self.eval_expr_torch(&b)); + let result: Py = ta + .internal() + .getattr(py, "__mul__") + .expect("Cannot mul") + .call1(py, (tb.internal(),)) + .expect("Cannot mul") + .extract(py) + .expect("Cannot convert"); + TorchTensor::new(result) + } + TensorExpr::Dot(a, b) => { + let (ta, tb) = (self.eval_expr_torch(&a), self.eval_expr_torch(&b)); + let result: Py = ta + .internal() + .getattr(py, "dot") + .expect("Cannot dot") + .call1(py, (tb.internal(),)) + .expect("Cannot dot") + .extract(py) + .expect("Cannot convert"); + TorchTensor::new(result) } }) } diff --git a/etc/scallopy/src/tuple.rs b/etc/scallopy/src/tuple.rs index be05876..8f55ec9 100644 --- a/etc/scallopy/src/tuple.rs +++ b/etc/scallopy/src/tuple.rs @@ -104,12 +104,14 @@ pub fn from_python_value(v: &PyAny, ty: &ValueType, env: &PythonRuntimeEnvironme } ValueType::DateTime => { let string = v.extract()?; - let dt = utils::parse_date_time_string(string).ok_or(PyTypeError::new_err(format!("Cannot parse into DateTime: {}", string)))?; + let dt = utils::parse_date_time_string(string) + .ok_or(PyTypeError::new_err(format!("Cannot parse into DateTime: {}", string)))?; Ok(Value::DateTime(dt)) } ValueType::Duration => { let string = v.extract()?; - let dt = utils::parse_duration_string(string).ok_or(PyTypeError::new_err(format!("Cannot parse into Duration: {}", string)))?; + let dt = utils::parse_duration_string(string) + .ok_or(PyTypeError::new_err(format!("Cannot parse into Duration: {}", string)))?; Ok(Value::Duration(dt)) } ValueType::Entity => { @@ -122,7 +124,10 @@ pub fn from_python_value(v: 
&PyAny, ty: &ValueType, env: &PythonRuntimeEnvironme fn tensor_from_py_object(pyobj: &PyAny, env: &PythonRuntimeEnvironment) -> PyResult { let py_tensor = Tensor::from_py_value(pyobj); - let symbol = env.tensor_registry.register(DynamicExternalTensor::new(py_tensor)).ok_or(BindingError::CannotRegisterTensor)?; + let symbol = env + .tensor_registry + .register(DynamicExternalTensor::new(py_tensor)) + .ok_or(BindingError::CannotRegisterTensor)?; Ok(Value::TensorValue(symbol.into())) } diff --git a/etc/scallopy/tests/tensors.py b/etc/scallopy/tests/tensors.py index c84cc31..d8cccd1 100644 --- a/etc/scallopy/tests/tensors.py +++ b/etc/scallopy/tests/tensors.py @@ -52,7 +52,7 @@ def test_tensor_3(self): def test_tensor_backprop_4(self): x = torch.randn(10, requires_grad=True) y = torch.randn(10) - opt = torch.optim.Adam(params=[x], lr=0.1) + opt = torch.optim.Adam(params=[x], lr=0.01) gt_initial_sim = x.dot(y) / (x.norm() * y.norm()) + 1.0 / 2.0 ctx = scallopy.Context(provenance="difftopkproofs") @@ -75,11 +75,12 @@ def test_tensor_backprop_4(self): @unittest.skipIf(not scallopy.torch_tensor_enabled(), "not supported in this scallopy version") def test_tensor_forward_backprop_1(self): + torch.manual_seed(1357) batch_size = 16 x = torch.randn((batch_size, 10), requires_grad=True) y = torch.randn((batch_size, 10), requires_grad=True) - opt = torch.optim.Adam(params=[x, y], lr=0.1) + opt = torch.optim.Adam(params=[x, y], lr=0.01) scl_module = scallopy.Module( program=""" @@ -102,7 +103,7 @@ def step() -> float: return l.item() curr_loss = step() - for i in range(4): + for j in range(4): next_loss = step() - assert next_loss < curr_loss - curr_loss = next_loss + assert next_loss <= curr_loss + curr_loss = next_loss diff --git a/etc/sclc/Cargo.toml b/etc/sclc/Cargo.toml index 19b661e..ec4a821 100644 --- a/etc/sclc/Cargo.toml +++ b/etc/sclc/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "sclc-core" -version = "0.2.0" +version = "0.2.1" authors = ["Ziyang Li "] edition = "2018" 
diff --git a/etc/sclc/src/exec.rs b/etc/sclc/src/exec.rs index 0707760..7c348b8 100644 --- a/etc/sclc/src/exec.rs +++ b/etc/sclc/src/exec.rs @@ -154,14 +154,15 @@ fn run_function(output_code: TokenStream) -> TokenStream { fn main_body(opt: &Options) -> TokenStream { if let Some(p) = &opt.provenance { let top_k = opt.top_k; + let wmc_with_disjunctions = opt.wmc_with_disjunctions; match p.as_str() { "unit" => quote! { run(unit::UnitProvenance::default()); }, "bool" => quote! { run(boolean::BooleanProvenance::default()); }, "minmaxprob" => quote! { run(min_max_prob::MinMaxProbProvenance::default()); }, "addmultprob" => quote! { run(add_mult_prob::AddMultProbProvenance::default()); }, - "topkproofs" => quote! { run(top_k_proofs::TopKProofsProvenance::::new(#top_k)); }, + "topkproofs" => quote! { run(top_k_proofs::TopKProofsProvenance::::new(#top_k, #wmc_with_disjunctions)); }, "samplekproofs" => quote! { run(sample_k_proofs::SampleKProofsContext::new(#top_k)); }, - "topbottomkclauses" => quote! { run(top_bottom_k_clauses::TopBottomKClausesContext::::new(#top_k)); }, + "topbottomkclauses" => quote! { run(top_bottom_k_clauses::TopBottomKClausesContext::::new(#top_k, #wmc_with_disjunctions)); }, p => panic!("Unknown provenance `{}`. 
Aborting", p), } } else { @@ -172,9 +173,9 @@ fn main_body(opt: &Options) -> TokenStream { "bool" => run(boolean::BooleanProvenance::default()), "minmaxprob" => run(min_max_prob::MinMaxProbProvenance::default()), "addmultprob" => run(add_mult_prob::AddMultProbProvenance::default()), - "topkproofs" => run(top_k_proofs::TopKProofsProvenance::::new(opt.top_k)), + "topkproofs" => run(top_k_proofs::TopKProofsProvenance::::new(opt.top_k, opt.wmc_with_disjunctions)), "samplekproofs" => run(sample_k_proofs::SampleKProofsProvenance::new(opt.top_k)), - "topbottomkclauses" => run(top_bottom_k_clauses::TopBottomKClausesProvenance::::new(opt.top_k)), + "topbottomkclauses" => run(top_bottom_k_clauses::TopBottomKClausesProvenance::::new(opt.top_k, opt.wmc_with_disjunctions)), p => println!("Unknown provenance `{}`. Aborting", p), } } diff --git a/etc/sclc/src/options.rs b/etc/sclc/src/options.rs index e564641..504d634 100644 --- a/etc/sclc/src/options.rs +++ b/etc/sclc/src/options.rs @@ -27,8 +27,11 @@ pub struct Options { #[structopt(long)] pub provenance: Option, - #[structopt(long, default_value = "3")] + #[structopt(short = "k", long, default_value = "3")] pub top_k: usize, + + #[structopt(long)] + pub wmc_with_disjunctions: bool, } impl From<&Options> for compiler::CompileOptions { diff --git a/etc/sclc/src/pylib.rs b/etc/sclc/src/pylib.rs index 9e41c83..4c8d1dc 100644 --- a/etc/sclc/src/pylib.rs +++ b/etc/sclc/src/pylib.rs @@ -293,15 +293,15 @@ fn generate_context_code() -> TokenStream { #[pymethods] impl Context { #[new] - #[args(provenance = "\"unit\"", top_k = "None")] - fn new(provenance: &str, top_k: Option) -> PyResult { + #[args(provenance = "\"unit\"", top_k = "None", wmc_with_disjunctions = "false")] + fn new(provenance: &str, top_k: Option, wmc_with_disjunctions: bool) -> PyResult { let top_k = top_k.unwrap_or(3); match provenance { "unit" => Ok(Self { ctx: ContextEnum::Unit(StaticContext::new(unit::UnitProvenance::default())) }), "minmaxprob" => Ok(Self { ctx:
ContextEnum::MinMaxProb(StaticContext::new(min_max_prob::MinMaxProbProvenance::default())) }), "addmultprob" => Ok(Self { ctx: ContextEnum::AddMultProb(StaticContext::new(add_mult_prob::AddMultProbProvenance::default())) }), "diffminmaxprob" => Ok(Self { ctx: ContextEnum::DiffMinMaxProb(StaticContext::new(diff_min_max_prob::DiffMinMaxProbProvenance::default())) }), - "difftopkproofs" => Ok(Self { ctx: ContextEnum::DiffTopKProofs(StaticContext::new(diff_top_k_proofs::DiffTopKProofsProvenance::new(top_k))) }), + "difftopkproofs" => Ok(Self { ctx: ContextEnum::DiffTopKProofs(StaticContext::new(diff_top_k_proofs::DiffTopKProofsProvenance::new(top_k, wmc_with_disjunctions))) }), - "difftopbottomkclauses" => Ok(Self { ctx: ContextEnum::DiffTopBottomKClauses(StaticContext::new(diff_top_bottom_k_clauses::DiffTopBottomKClausesProvenance::new(top_k))) }), + "difftopbottomkclauses" => Ok(Self { ctx: ContextEnum::DiffTopBottomKClauses(StaticContext::new(diff_top_bottom_k_clauses::DiffTopBottomKClausesProvenance::new(top_k, wmc_with_disjunctions))) }), p => Err(PyErr::from(BindingError(format!("Unknown provenance `{}`", p.to_string())))), } diff --git a/etc/scli/Cargo.toml b/etc/scli/Cargo.toml index 5a84284..11767e2 100644 --- a/etc/scli/Cargo.toml +++ b/etc/scli/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "scli" -version = "0.2.0" +version = "0.2.1" authors = ["Ziyang Li "] edition = "2018" diff --git a/etc/scli/src/main.rs b/etc/scli/src/main.rs index 1867338..1580b6d 100644 --- a/etc/scli/src/main.rs +++ b/etc/scli/src/main.rs @@ -23,6 +23,9 @@ struct Options { #[structopt(short = "k", long, default_value = "3")] top_k: usize, + #[structopt(long)] + wmc_with_disjunctions: bool, + #[structopt(short = "q", long)] query: Option, @@ -160,11 +163,11 @@ fn main() -> Result<(), String> { interpret(ctx, &opt.input, integrate_opt, predicate_set, monitor_options) } "topkproofs" => { - let ctx = provenance::top_k_proofs::TopKProofsProvenance::::new(opt.top_k); + let ctx = provenance::top_k_proofs::TopKProofsProvenance::::new(opt.top_k, opt.wmc_with_disjunctions); interpret(ctx, &opt.input, integrate_opt, predicate_set, monitor_options) } "topbottomkclauses" => { - let ctx =
provenance::top_bottom_k_clauses::TopBottomKClausesProvenance::::new(opt.top_k); + let ctx = provenance::top_bottom_k_clauses::TopBottomKClausesProvenance::::new(opt.top_k, opt.wmc_with_disjunctions); interpret(ctx, &opt.input, integrate_opt, predicate_set, monitor_options) } _ => Err(format!("Unknown provenance semiring `{}`", opt.provenance)), diff --git a/etc/sclrepl/Cargo.toml b/etc/sclrepl/Cargo.toml index 9c74027..857d6c1 100644 --- a/etc/sclrepl/Cargo.toml +++ b/etc/sclrepl/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "sclrepl" -version = "0.2.0" +version = "0.2.1" authors = ["Ziyang Li "] edition = "2018" diff --git a/etc/vscode-scl/CHANGELOG.md b/etc/vscode-scl/CHANGELOG.md index 508fbea..3e76401 100644 --- a/etc/vscode-scl/CHANGELOG.md +++ b/etc/vscode-scl/CHANGELOG.md @@ -1,5 +1,6 @@ # Change Log -## [Unreleased] +## 0.1.1, Sep 23, 2023 -- Initial release +- Improved aggregation syntax to allow for arbitrary aggregates +- Fixed disjunctive head syntax diff --git a/etc/vscode-scl/examples/syntax_test.scl b/etc/vscode-scl/examples/syntax_test.scl index 65022b5..264d691 100644 --- a/etc/vscode-scl/examples/syntax_test.scl +++ b/etc/vscode-scl/examples/syntax_test.scl @@ -52,7 +52,7 @@ rel fib = {(0, 1), (1, 1)} rel fib(x, a + b) = fib(x - 1, a), fib(x - 2, b), x > 1 // Reduce -rel how_many_play_soccer(c) = c = count(n: plays_soccer(n)) +rel how_many_play_soccer(c) = c := count(n: plays_soccer(n)) // Logical operators rel something(a && b) = b || c && !d && x == y || x != y diff --git a/etc/vscode-scl/examples/syntax_test_2.scl b/etc/vscode-scl/examples/syntax_test_2.scl index f1d834f..1ad2823 100644 --- a/etc/vscode-scl/examples/syntax_test_2.scl +++ b/etc/vscode-scl/examples/syntax_test_2.scl @@ -119,20 +119,30 @@ rel expr_to_string(e, $format("({} {} {})", op1_str, op_str, op2_str)) = // Rules with aggregation inside rel num_students(n) = n := count(a: student(p)) -rel num_paths(n) = n := count(a, b: path(a, b)) +rel num_paths(n) = n := count!(a, b: 
path(a, b)) rel my_rule(x) = _ := count(a, b: path(a, b) where c: color(c)) rel my_rule_2() = forall(a, b: path(a, b) and path(a, b) implies edge(a, b)) -rel sample_rule(x) = x := top<3>(a, b: path(a, b)) +rel sample_rule(x) = (x, y) := top<3>(a, b: path(a, b)) rel sample_rule(x) = x := categorical<3>(a, b: path(a, b)) -rel sample_rule(x) = x := uniform<3>(a, b: path(a, b)) -rel sample_rule(x) = x := min[a](a, b: path(a, b)) -rel sample_rule(x) = _ := min[x](s: score(x, s)) +rel sample_rule(x) = x := uniform!<3>(a, b: path(a, b)) +rel sample_rule(x) = (a, x) := min![a](a, b: path(a, b, v, d)) +rel sample_rule(x) = (abhb, xjh, yjhhk, zkjjk, dkhbkhb) := min[a](a, b: path(a, b) where c, d, e, f, g: some_relation(3851)) +rel sample_rule(x) = x := argmin![a, g, h, u, h](a, c, d, b: path(a, b)) +rel sample_rule(x) = x := argmin![a, g, h, u, h](a, c, d, b: path(a, b) /* something like this */ ) +rel sample_rule(x) = x := argmin!(a, c, d, b: + path(a, b) and // new line + edge(a, b) and // new line again + a == b +) +rel sample_rule(x) = _ := min![x](s: score(x, s)) +rel sample_rule(x) = _ := sum[p](s: score(x, s)) +rel sample_rule(x) = _ := cross_entropy[i](y_pred, y: pred(i, y_pred) and ground_truth(i, y)) // Nested aggregation rel nested_agg(x) = x := count(x: m := max(x: relation(x))) // Disjunctive datalog -rel { assign(x, true); assign(x, false) } = var(x) +rel { assign(x, true); assign(x, false) } = var(x, y, z, a, b, c, d, e, f, g, h, i, j, a + 1) // Relations with generic arguments rel grid(x, y) = range(0, 5, x) and range(0, 5, y) diff --git a/etc/vscode-scl/package.json b/etc/vscode-scl/package.json index dcd8895..605f30c 100644 --- a/etc/vscode-scl/package.json +++ b/etc/vscode-scl/package.json @@ -2,7 +2,7 @@ "name": "scallop", "displayName": "Scallop Language Syntax Highlight", "description": "Scallop Language Support", - "version": "0.0.9", + "version": "0.1.1", "repository": { "type": "git", "url": "https://github.com/liby99/scallop-v2" diff --git 
a/etc/vscode-scl/syntaxes/scallop.tmLanguage.json b/etc/vscode-scl/syntaxes/scallop.tmLanguage.json index 36901c5..f59dd5a 100644 --- a/etc/vscode-scl/syntaxes/scallop.tmLanguage.json +++ b/etc/vscode-scl/syntaxes/scallop.tmLanguage.json @@ -342,11 +342,12 @@ }, "relation_decl": { "patterns": [ - { "include": "#fact_set_decl" } + { "include": "#fact_set_decl" }, + { "include": "#disj_head_decl" } ] }, "fact_set_decl": { - "begin": "(rel)\\s+([a-zA-Z][a-zA-Z0-9_]*)\\s*(=)\\s*\\{", + "begin": "(rel)\\s+([a-zA-Z][a-zA-Z0-9_]*)\\s*(=|:-)\\s*\\{", "beginCaptures": { "1": { "name": "keyword.control.scallop" @@ -376,6 +377,23 @@ } } }, + "disj_head_decl": { + "begin": "(rel)\\s+\\{", + "end": "\\}\\s*(=|:-)", + "beginCaptures": { + "1": { + "name": "keyword.control.scallop" + } + }, + "endCaptures": { + "2": { + "name": "keyword.operator.scallop" + } + }, + "patterns": [ + { "include": "#atom" } + ] + }, "query_decl": { "match": "(query)\\s+([a-zA-Z][a-zA-Z0-9_]*)\\n", "captures": { @@ -389,154 +407,102 @@ }, "aggregation": { "patterns": [ - { "include": "#simple_aggregation" }, + { "include": "#basic_aggregation" }, { "include": "#sample_aggregation" }, - { "include": "#argminmax_aggregation" }, + { "include": "#arg_aggr_aggregation" }, { "include": "#forall_exists_aggregation" } ] }, - "simple_aggregation": { - "begin": "([a-zA-Z_][a-zA-Z0-9_]*)(\\s*,\\s*([a-zA-Z_][a-zA-Z0-9_]*))*\\s*(:=|=)\\s*(count|sum|prod|min|max|exists|forall|unique)\\(([a-zA-Z][a-zA-Z0-9_]*)(\\s*,\\s*([a-zA-Z][a-zA-Z0-9_]*))*\\s*:", - "end": "\\)", + "basic_aggregation": { + "begin": "(:=)\\s*([a-zA-Z][a-zA-Z0-9_]*)(!?)\\(", + "end": ":", "beginCaptures": { "1": { - "name": "variable.parameter.scallop" - }, - "3": { - "name": "variable.parameter.scallop" - }, - "4": { "name": "keyword.operator.scallop" }, - "5": { + "2": { "name": "keyword.other.scallop" }, - "6": { - "name": "variable.parameter.scallop" - }, - "8": { - "name": "variable.parameter.scallop" + "3": { + "name": 
"keyword.operator.scallop" } }, "patterns": [ { "include": "#comment" }, - { "include": "#where_clause" }, - { "include": "#formula" } + { "include": "#variable" } ] }, "sample_aggregation": { - "begin": "([a-zA-Z_][a-zA-Z0-9_]*)(\\s*,\\s*([a-zA-Z_][a-zA-Z0-9_]*))*\\s*(:=|=)\\s*(top|categorical|uniform)(<)(\\d+)(>)\\(([a-zA-Z][a-zA-Z0-9_]*)(\\s*,\\s*([a-zA-Z][a-zA-Z0-9_]*))*\\s*:", - "end": "\\)", + "begin": "(:=)\\s*([a-zA-Z][a-zA-Z0-9_]*)(!?)(<)(\\d+)(>)\\(", + "end": ":", "beginCaptures": { "1": { - "name": "variable.parameter.scallop" + "name": "keyword.operator.scallop" + }, + "2": { + "name": "keyword.other.scallop" }, "3": { - "name": "variable.parameter.scallop" + "name": "keyword.operator.scallop" }, "4": { "name": "keyword.operator.scallop" }, "5": { - "name": "keyword.other.scallop" - }, - "6": { - "name": "punctuation.brackets.angle.scallop" - }, - "7": { "name": "constant.numeric.scallop" }, - "8": { - "name": "punctuation.brackets.angle.scallop" - }, - "9": { - "name": "variable.parameter.scallop" - }, - "11": { - "name": "variable.parameter.scallop" + "6": { + "name": "keyword.operator.scallop" } }, "patterns": [ { "include": "#comment" }, - { "include": "#where_clause" }, - { "include": "#formula" } + { "include": "#variable" } ] }, - "argminmax_aggregation": { - "begin": "([a-zA-Z_][a-zA-Z0-9_]*)(\\s*,\\s*([a-zA-Z_][a-zA-Z0-9_]*))*\\s*(:=|=)\\s*(min|max)(\\[)([a-zA-Z][a-zA-Z0-9_]*)(\\s*,\\s*([a-zA-Z][a-zA-Z0-9_]*))*(\\])\\(([a-zA-Z][a-zA-Z0-9_]*)(\\s*,\\s*([a-zA-Z][a-zA-Z0-9_]*))*\\s*:", - "end": "\\)", + "arg_aggr_aggregation": { + "begin": "(:=)\\s*([a-zA-Z][a-zA-Z0-9_]*)(!?)(\\[)", + "end": "\\]", "beginCaptures": { "1": { - "name": "variable.parameter.scallop" - }, - "3": { - "name": "variable.parameter.scallop" - }, - "4": { "name": "keyword.operator.scallop" }, - "5": { + "2": { "name": "keyword.other.scallop" }, - "6": { - "name": "punctuation.brackets.angle.scallop" - }, - "7": { - "name": "variable.parameter.scallop" - }, - "9": { - 
"name": "variable.parameter.scallop" + "3": { + "name": "keyword.operator.scallop" }, - "10": { + "4": { + "name": "punctuation.brackets.angle.scallop" + } + }, + "endCaptures": { + "0": { "name": "punctuation.brackets.angle.scallop" - }, - "11": { - "name": "variable.parameter.scallop" - }, - "13": { - "name": "variable.parameter.scallop" } }, "patterns": [ { "include": "#comment" }, - { "include": "#where_clause" }, - { "include": "#formula" } + { "include": "#variable" } ] }, "forall_exists_aggregation": { - "begin": "(forall|exists)\\(([a-zA-Z][a-zA-Z0-9_]*)(\\s*,\\s*([a-zA-Z][a-zA-Z0-9_]*))*\\s*:", - "end": "\\)", + "begin": "(forall|exists)(!?)\\(", + "end": ":", "beginCaptures": { "1": { "name": "keyword.other.scallop" }, "2": { - "name": "variable.parameter.scallop" - }, - "4": { - "name": "variable.parameter.scallop" + "name": "keyword.operator.scallop" } }, "patterns": [ { "include": "#comment" }, - { "include": "#where_clause" }, - { "include": "#formula" } + { "include": "#variable" } ] }, - "where_clause": { - "match": "(where)\\s+([a-zA-Z][a-zA-Z0-9_]*)(\\s*,\\s*([a-zA-Z][a-zA-Z0-9_]*))*\\s*:", - "captures": { - "1": { - "name": "keyword.control.scallop" - }, - "2": { - "name": "variable.parameter.scallop" - }, - "4": { - "name": "variable.parameter.scallop" - } - } - }, "case_is": { "begin": "(case)\\s+([a-zA-Z][a-zA-Z0-9_]*)\\s+(is)\\s+([a-zA-Z][a-zA-Z0-9_]*)\\s*\\(", "end": "\\)", diff --git a/examples/datalog/equality_saturation.scl b/examples/datalog/equality_saturation.scl index 7372265..606f484 100644 --- a/examples/datalog/equality_saturation.scl +++ b/examples/datalog/equality_saturation.scl @@ -32,7 +32,7 @@ rel weight(p, w1 + w2 + 1) = case p is Add(p1, p2) and weight(p1, w1) and weight rel equiv_programs(sp) = input_program(p) and equivalent(p, sp) // Find the best program (minimum weight) among all programs equivalent to p -rel best_program(p) = _ := min[p](w: equiv_programs(p) and weight(p, w)) +rel best_program(p) = p := argmin[p](w: 
equiv_programs(p) and weight(p, w)) rel best_program_str(s) = best_program(best_prog) and to_string(best_prog, s) // ======================================== diff --git a/examples/legacy/good_scl/animal.scl b/examples/legacy/good_scl/animal.scl index ffc5dc3..7dbe167 100644 --- a/examples/legacy/good_scl/animal.scl +++ b/examples/legacy/good_scl/animal.scl @@ -20,4 +20,4 @@ rel num_things("animal", n) :- n = count(o: name(o, "animal")) rel num_things("tiger", n) :- n = count(o: name(o, "tiger")) // Comparing -rel more_animal_or_tiger(t) :- _ = max[t](n: num_things(t, n)) +rel more_animal_or_tiger(t) :- t = argmax[t](n: num_things(t, n)) diff --git a/examples/legacy/good_scl/obj_color.scl b/examples/legacy/good_scl/obj_color.scl index 423b62e..5f1a60f 100644 --- a/examples/legacy/good_scl/obj_color.scl +++ b/examples/legacy/good_scl/obj_color.scl @@ -6,4 +6,4 @@ rel object_color(4, "green") rel object_color(5, "red") rel color_count(c, n) :- n = count(o: object_color(o, c)) -rel max_color(c) :- _ = max[c](n: color_count(c, n)) +rel max_color(c) :- c := argmax[c](n: color_count(c, n)) diff --git a/examples/legacy/good_scl/obj_color_2.scl b/examples/legacy/good_scl/obj_color_2.scl index 939df92..1ecd79a 100644 --- a/examples/legacy/good_scl/obj_color_2.scl +++ b/examples/legacy/good_scl/obj_color_2.scl @@ -5,4 +5,4 @@ rel object_color(3, "green") rel object_color(4, "green") rel object_color(5, "red") -rel max_color(c) = _ = max[c](n: n = count(o: object_color(o, c))) +rel max_color(c) = c := argmax[c](n: n = count(o: object_color(o, c))) diff --git a/examples/legacy/good_scl/student_grade_1.scl b/examples/legacy/good_scl/student_grade_1.scl index 51592bf..8d2b73c 100644 --- a/examples/legacy/good_scl/student_grade_1.scl +++ b/examples/legacy/good_scl/student_grade_1.scl @@ -7,4 +7,4 @@ rel class_student_grade = { (1, "frank", 30), } -rel class_top_student(c, s) = _ = max[s](g: class_student_grade(c, s, g)) +rel class_top_student(c, s) = s = argmax[s](g: 
class_student_grade(c, s, g)) diff --git a/examples/legacy/invalid_scl/unbound_3.scl b/examples/legacy/invalid_scl/unbound_3.scl index e94912c..a0f463e 100644 --- a/examples/legacy/invalid_scl/unbound_3.scl +++ b/examples/legacy/invalid_scl/unbound_3.scl @@ -1 +1 @@ -rel R(a) = _ = max[a](o: Q(_, a)) +rel R(a) = (_, _) = max[a](o: Q(_, a)) diff --git a/lib/astnode-derive/src/lib.rs b/lib/astnode-derive/src/lib.rs index c51b720..0b475e9 100644 --- a/lib/astnode-derive/src/lib.rs +++ b/lib/astnode-derive/src/lib.rs @@ -119,21 +119,21 @@ fn skip_decorators(token_list: Vec) -> Vec { match &token_list[i] { TokenTree::Punct(p) if p.as_char() == '#' => { i += 2; - }, - _ => { - break token_list[i..].to_vec() - }, + } + _ => break token_list[i..].to_vec(), } } } fn get_has_pub(token_list: &Vec) -> bool { match &token_list[0] { - TokenTree::Ident(i) => if i.to_string() == "pub" { - true - } else { - false - }, + TokenTree::Ident(i) => { + if i.to_string() == "pub" { + true + } else { + false + } + } _ => false, } } @@ -145,11 +145,9 @@ fn get_is_struct(has_pub: bool, token_list: &Vec) -> bool { if i.to_string() == "struct" { match &token_list[offset + 2] { TokenTree::Group(g) => match g.delimiter() { - Delimiter::Brace => { - true - } - _ => panic!("AstNode only support decorating struct with fields") - } + Delimiter::Brace => true, + _ => panic!("AstNode only support decorating struct with fields"), + }, TokenTree::Punct(p) if p.as_char() == ';' => true, t => panic!("Unknown token tree {}", t), } @@ -177,7 +175,7 @@ fn get_type_name(has_pub: bool, is_struct: bool, token_list: &Vec) -> } else { full_name } - }, + } other => panic!("Unexpected token tree {:?}", other), } } @@ -202,22 +200,20 @@ fn get_struct_fields(has_pub: bool, token_list: &Vec) -> Vec<(bool, S curr_is_pub = Some(false); curr_field_name = Some(i.to_string()); } - }, - _ => panic!("Cannot parse") + } + _ => panic!("Cannot parse"), } } else if curr_is_pub.is_some() && curr_field_name.is_none() { match token 
{ TokenTree::Ident(i) => { curr_field_name = Some(i.to_string()); } - _ => panic!("Cannot parse") + _ => panic!("Cannot parse"), } } else if curr_field_name.is_some() && curr_field_type.is_none() { match token { - TokenTree::Punct(p) if p.as_char() == ':' => { - curr_field_type = Some("".to_string()) - } - _ => panic!("Cannot parse") + TokenTree::Punct(p) if p.as_char() == ':' => curr_field_type = Some("".to_string()), + _ => panic!("Cannot parse"), } } else if curr_field_type.is_some() { match token { @@ -255,7 +251,7 @@ fn get_struct_fields(has_pub: bool, token_list: &Vec) -> Vec<(bool, S } } } - _ => panic!("Cannot parse") + _ => panic!("Cannot parse"), } } } @@ -270,7 +266,7 @@ fn get_struct_fields(has_pub: bool, token_list: &Vec) -> Vec<(bool, S fields } TokenTree::Punct(_) => vec![], - _ => panic!("Not a group") + _ => panic!("Not a group"), } } @@ -290,13 +286,19 @@ fn get_enum_variants(has_pub: bool, token_list: &Vec) -> Vec<(String, TokenTree::Ident(i) => { curr_variant_name = Some(i.to_string()); } - _ => panic!("Cannot parse") + _ => panic!("Cannot parse"), } } else if curr_variant_name.is_some() && curr_variant_type.is_none() { match token { TokenTree::Group(g) => { if g.delimiter() == Delimiter::Parenthesis { - curr_variant_type = Some(Some(g.stream().into_iter().map(|t| t.to_string()).collect::>().join(""))); + curr_variant_type = Some(Some( + g.stream() + .into_iter() + .map(|t| t.to_string()) + .collect::>() + .join(""), + )); } else { panic!("AstNode enum variant cannot be a struct") } @@ -306,7 +308,7 @@ fn get_enum_variants(has_pub: bool, token_list: &Vec) -> Vec<(String, curr_variant_name = None; curr_variant_type = None; } - _ => panic!("Cannot parse") + _ => panic!("Cannot parse"), } } else if curr_variant_name.is_some() && curr_variant_type.is_some() { match token { @@ -315,7 +317,7 @@ fn get_enum_variants(has_pub: bool, token_list: &Vec) -> Vec<(String, curr_variant_name = None; curr_variant_type = None; } - _ => panic!("Cannot parse") + _ 
=> panic!("Cannot parse"), } } else { panic!("Cannot parse") @@ -334,16 +336,11 @@ fn get_enum_variants(has_pub: bool, token_list: &Vec) -> Vec<(String, panic!("AstNode enum has to have at least one variant") } } - _ => panic!("Not a group") + _ => panic!("Not a group"), } } -fn derive_struct( - pub_kw: &str, - name: String, - has_pub: bool, - token_list: &Vec, -) -> String { +fn derive_struct(pub_kw: &str, name: String, has_pub: bool, token_list: &Vec) -> String { let struct_def = format!(r#"{pub_kw} type {name} = AstNodeWrapper<_{name}>;"#); let fields = get_struct_fields(has_pub, &token_list); @@ -372,7 +369,8 @@ fn derive_struct( .collect::(); // Constructor - let new_impl = format!(r#" + let new_impl = format!( + r#" impl _{name} {{ pub fn new({fn_args}) -> Self {{ Self {{ {constructor_args} }} }} pub fn with_location(self, loc: NodeLocation) -> {name} {{ {name} {{ _loc: loc, _node: self }} }} @@ -387,7 +385,8 @@ fn derive_struct( pub fn internal(&self) -> &_{name} {{ &self._node }} pub fn internal_mut(&mut self) -> &mut _{name} {{ &mut self._node }} }} - "#); + "# + ); // Walker let (walker, walker_mut): (String, String) = fields @@ -398,12 +397,14 @@ fn derive_struct( (walk, walk_mut) }) .unzip(); - let impl_walker = format!(r#" + let impl_walker = format!( + r#" impl AstWalker for {name} {{ fn walk(&self, v: &mut V) {{ v.visit(self); v.visit(&self._loc); {walker} }} fn walk_mut(&mut self, v: &mut V) {{ v.visit_mut(self); v.visit_mut(&mut self._loc); {walker_mut} }} }} - "#); + "# + ); // Accessor let fields_accessor = fields @@ -444,12 +445,7 @@ fn derive_struct( .join("\n"); let fields_impl = format!(r#"impl {name} {{ {fields_accessor} }}"#); - vec![ - struct_def, - new_impl, - impl_walker, - fields_impl, - ].join("\n") + vec![struct_def, new_impl, impl_walker, fields_impl].join("\n") } fn derive_enum(pub_kw: &str, name: String, has_pub: bool, token_list: &Vec) -> String { @@ -462,17 +458,31 @@ fn derive_enum(pub_kw: &str, name: String, has_pub: bool, 
token_list: &Vec)>) -> String { - let imut_cases = variants.iter().map(|(name, _)| format!("Self::{name}(v) => v.location()")).collect::>().join(","); - let mut_cases = variants.iter().map(|(name, _)| format!("Self::{name}(v) => v.location_mut()")).collect::>().join(","); - let clone_cases = variants.iter().map(|(name, _)| format!("Self::{name}(v) => Self::{name}(v.clone_with_loc(loc))")).collect::>().join(","); + let imut_cases = variants + .iter() + .map(|(name, _)| format!("Self::{name}(v) => v.location()")) + .collect::>() + .join(","); + let mut_cases = variants + .iter() + .map(|(name, _)| format!("Self::{name}(v) => v.location_mut()")) + .collect::>() + .join(","); + let clone_cases = variants + .iter() + .map(|(name, _)| format!("Self::{name}(v) => Self::{name}(v.clone_with_loc(loc))")) + .collect::>() + .join(","); - let ast_node_impl = format!(r#" + let ast_node_impl = format!( + r#" impl AstNode for {name} {{ fn location(&self) -> &NodeLocation {{ match self {{ {imut_cases} }} }} fn location_mut(&mut self) -> &mut NodeLocation {{ match self {{ {mut_cases} }} }} fn clone_with_loc(&self, loc: NodeLocation) -> Self {{ match self {{ {clone_cases} }} }} }} - "#); + "# + ); let helpers = variants .iter() @@ -501,26 +511,28 @@ fn derive_variant_enum(name: String, variants: Vec<(String, Option)>) -> ) }) .unzip(); - let impl_walker = format!(r#" + let impl_walker = format!( + r#" impl AstWalker for {name} {{ fn walk(&self, v: &mut V) {{ v.visit(self); match self {{ {walkers} }} }} fn walk_mut(&mut self, v: &mut V) {{ v.visit_mut(self); match self {{ {walker_muts} }} }} }} - "#); + "# + ); - vec![ - ast_node_impl, - match_helper_impl, - impl_walker, - ].join("\n") + vec![ast_node_impl, match_helper_impl, impl_walker].join("\n") } fn derive_const_enum(pub_kw: &str, name: String, variants: Vec<(String, Option)>) -> String { - assert!(name.chars().nth(0) == Some('_'), "The first character of the name needs to be an underscore `_`"); + assert!( + name.chars().nth(0) 
== Some('_'), + "The first character of the name needs to be an underscore `_`" + ); let name = name[1..].to_string(); let type_def = format!(r#"{pub_kw} type {name} = AstNodeWrapper<_{name}>;"#); - let universal_helper = format!(r#" + let universal_helper = format!( + r#" impl _{name} {{ pub fn with_location(self, loc: NodeLocation) -> {name} {{ {name} {{ _loc: loc, _node: self }} }} pub fn with_span(self, start: usize, end: usize) -> {name} {{ {name} {{ _loc: NodeLocation::from_span(start, end), _node: self }} }} @@ -529,7 +541,8 @@ fn derive_const_enum(pub_kw: &str, name: String, variants: Vec<(String, Option &_{name} {{ &self._node }} pub fn internal_mut(&mut self) -> &mut _{name} {{ &mut self._node }} }} - "#); + "# + ); let helpers = variants .iter() @@ -556,17 +569,14 @@ fn derive_const_enum(pub_kw: &str, name: String, variants: Vec<(String, Option(&self, v: &mut V) {{ v.visit(self); v.visit(&self._loc); }} fn walk_mut(&mut self, v: &mut V) {{ v.visit_mut(self); v.visit_mut(&mut self._loc); }} }} - "#); - - vec![ - type_def, - universal_helper, - match_helper_impl, - impl_walker, - ].join("\n") + "# + ); + + vec![type_def, universal_helper, match_helper_impl, impl_walker].join("\n") } diff --git a/lib/parse_relative_duration/src/lib.rs b/lib/parse_relative_duration/src/lib.rs index f809a8c..266f8b8 100644 --- a/lib/parse_relative_duration/src/lib.rs +++ b/lib/parse_relative_duration/src/lib.rs @@ -154,9 +154,9 @@ extern crate regex; #[macro_use] extern crate lazy_static; -extern crate num; extern crate chrono; extern crate chronoutil; +extern crate num; /// This module contains the parse function and the error `enum`. 
/// diff --git a/lib/parse_relative_duration/src/parse.rs b/lib/parse_relative_duration/src/parse.rs index e621a60..b2d03d1 100644 --- a/lib/parse_relative_duration/src/parse.rs +++ b/lib/parse_relative_duration/src/parse.rs @@ -1,135 +1,125 @@ +use chrono::Duration; +use chronoutil::RelativeDuration; use num::pow::pow; use num::{BigInt, ToPrimitive}; use regex::Regex; use std::error::Error as ErrorTrait; use std::fmt; -use chrono::Duration; -use chronoutil::RelativeDuration; #[derive(Debug, PartialEq, Eq, Clone)] /// An enumeration of the possible errors while parsing. pub enum Error { - // When I switch exponents to use `BigInt`, this variant should be impossible. - // Right now it'll return this error with things like "1e123456781234567812345678" - // where the exponent can't be parsed into an `isize`. - /// An exponent failed to be parsed as an `isize`. - ParseInt(String), - /// An unrecognized unit was found. - UnknownUnit(String), - /// A `BigInt` was too big to be converted into a `i64` or was negative. - OutOfBounds(BigInt), - /// A value without a unit was found. - NoUnitFound(String), - /// No value at all was found. - NoValueFound(String), + // When I switch exponents to use `BigInt`, this variant should be impossible. + // Right now it'll return this error with things like "1e123456781234567812345678" + // where the exponent can't be parsed into an `isize`. + /// An exponent failed to be parsed as an `isize`. + ParseInt(String), + /// An unrecognized unit was found. + UnknownUnit(String), + /// A `BigInt` was too big to be converted into a `i64` or was negative. + OutOfBounds(BigInt), + /// A value without a unit was found. + NoUnitFound(String), + /// No value at all was found. 
+ NoValueFound(String), } impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match *self { - Error::ParseInt(ref s) => { - write!(f, "ParseIntError: Failed to parse \"{}\" as an integer", s) - } - Error::UnknownUnit(ref s) => { - write!(f, "UnknownUnitError: \"{}\" is not a known unit", s) - } - Error::OutOfBounds(ref b) => { - write!(f, "OutOfBoundsError: \"{}\" cannot be converted to i64", b) - } - Error::NoUnitFound(ref s) => { - write!(f, "NoUnitFoundError: no unit found for the value \"{}\"", s) - } - Error::NoValueFound(ref s) => write!( - f, - "NoValueFoundError: no value found in the string \"{}\"", - s - ), - } + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + Error::ParseInt(ref s) => { + write!(f, "ParseIntError: Failed to parse \"{}\" as an integer", s) + } + Error::UnknownUnit(ref s) => { + write!(f, "UnknownUnitError: \"{}\" is not a known unit", s) + } + Error::OutOfBounds(ref b) => { + write!(f, "OutOfBoundsError: \"{}\" cannot be converted to i64", b) + } + Error::NoUnitFound(ref s) => { + write!(f, "NoUnitFoundError: no unit found for the value \"{}\"", s) + } + Error::NoValueFound(ref s) => write!(f, "NoValueFoundError: no value found in the string \"{}\"", s), } + } } impl ErrorTrait for Error { - fn description(&self) -> &str { - match *self { - Error::ParseInt(_) => "Failed to parse a string into an integer", - Error::UnknownUnit(_) => "An unknown unit was used", - Error::OutOfBounds(_) => "An integer was too large to convert into a i64", - Error::NoUnitFound(_) => "A value without a unit was found", - Error::NoValueFound(_) => "No value was found", - } + fn description(&self) -> &str { + match *self { + Error::ParseInt(_) => "Failed to parse a string into an integer", + Error::UnknownUnit(_) => "An unknown unit was used", + Error::OutOfBounds(_) => "An integer was too large to convert into a i64", + Error::NoUnitFound(_) => "A value without a unit was found", + 
Error::NoValueFound(_) => "No value was found", } + } } /// A `ProtoDuration` is a duration with arbitrarily large fields. /// It can be conditionally converted into a RelativeDuration, if the fields are small enough. #[derive(Default)] struct ProtoDuration { - /// The number of nanoseconds in the `ProtoDuration`. May be negative. - nanoseconds: BigInt, - /// The number of microseconds in the `ProtoDuration`. May be negative. - microseconds: BigInt, - /// The number of milliseconds in the `ProtoDuration`. May be negative. - milliseconds: BigInt, - /// The number of seconds in the `ProtoDuration`. May be negative. - seconds: BigInt, - /// The number of minutes in the `ProtoDuration`. May be negative. - minutes: BigInt, - /// The number of hours in the `ProtoDuration`. May be negative. - hours: BigInt, - /// The number of days in the `ProtoDuration`. May be negative. - days: BigInt, - /// The number of weeks in the `ProtoDuration`. May be negative. - weeks: BigInt, - /// The number of months in the `ProtoDuration`. May be negative. - months: BigInt, - /// The number of years in the `ProtoDuration`. May be negative. - years: BigInt, + /// The number of nanoseconds in the `ProtoDuration`. May be negative. + nanoseconds: BigInt, + /// The number of microseconds in the `ProtoDuration`. May be negative. + microseconds: BigInt, + /// The number of milliseconds in the `ProtoDuration`. May be negative. + milliseconds: BigInt, + /// The number of seconds in the `ProtoDuration`. May be negative. + seconds: BigInt, + /// The number of minutes in the `ProtoDuration`. May be negative. + minutes: BigInt, + /// The number of hours in the `ProtoDuration`. May be negative. + hours: BigInt, + /// The number of days in the `ProtoDuration`. May be negative. + days: BigInt, + /// The number of weeks in the `ProtoDuration`. May be negative. + weeks: BigInt, + /// The number of months in the `ProtoDuration`. May be negative. 
+ months: BigInt, + /// The number of years in the `ProtoDuration`. May be negative. + years: BigInt, } impl ProtoDuration { - /// Try to convert a `ProtoDuration` into a `RelativeDuration`. - /// This may fail if the `ProtoDuration` is too long or it ends up having a negative total duration. - fn into_duration(self) -> Result { - let mut nanoseconds = - self.nanoseconds + 1_000_i64 * self.microseconds + 1_000_000_i64 * self.milliseconds; - let mut seconds = self.seconds - + 60_i64 * self.minutes - + 3_600_i64 * self.hours - + 86_400_i64 * self.days - + 604_800_i64 * self.weeks; - let months = self.months + 12_i32 * self.years; + /// Try to convert a `ProtoDuration` into a `RelativeDuration`. + /// This may fail if the `ProtoDuration` is too long or it ends up having a negative total duration. + fn into_duration(self) -> Result { + let mut nanoseconds = self.nanoseconds + 1_000_i64 * self.microseconds + 1_000_000_i64 * self.milliseconds; + let mut seconds = + self.seconds + 60_i64 * self.minutes + 3_600_i64 * self.hours + 86_400_i64 * self.days + 604_800_i64 * self.weeks; + let months = self.months + 12_i32 * self.years; - seconds += &nanoseconds / 1_000_000_000_i64; - nanoseconds %= 1_000_000_000_i64; + seconds += &nanoseconds / 1_000_000_000_i64; + nanoseconds %= 1_000_000_000_i64; - let seconds = - ::to_i64(&seconds).ok_or_else(|| Error::OutOfBounds(seconds))?; - let nanoseconds = ::to_i64(&nanoseconds).ok_or_else(|| { - // This shouldn't happen since nanoseconds is less than 1 billion. - Error::OutOfBounds(nanoseconds) - })?; - let months = - ::to_i32(&months).ok_or_else(|| Error::OutOfBounds(months))?; + let seconds = ::to_i64(&seconds).ok_or_else(|| Error::OutOfBounds(seconds))?; + let nanoseconds = ::to_i64(&nanoseconds).ok_or_else(|| { + // This shouldn't happen since nanoseconds is less than 1 billion. 
+ Error::OutOfBounds(nanoseconds) + })?; + let months = ::to_i32(&months).ok_or_else(|| Error::OutOfBounds(months))?; - Ok(RelativeDuration::months(months).with_duration(Duration::seconds(seconds) + Duration::nanoseconds(nanoseconds))) - } + Ok(RelativeDuration::months(months).with_duration(Duration::seconds(seconds) + Duration::nanoseconds(nanoseconds))) + } } lazy_static! { - static ref NUMBER_RE: Regex = Regex::new( - r"(?x) + static ref NUMBER_RE: Regex = Regex::new( + r"(?x) ^ [^\w-]* # any non-word characters, except '-' (for negatives - may add '.' for decimals) (-?\d+) # a possible negative sign and some positive number of digits [^\w-]* # more non-word characters $" - ) - .expect("Compiling a regex went wrong"); + ) + .expect("Compiling a regex went wrong"); } lazy_static! { - static ref DURATION_RE: Regex = Regex::new( - r"(?x)(?i) + static ref DURATION_RE: Regex = Regex::new( + r"(?x)(?i) (?P-?\d+) # the integer part \.?(?:(?P\d+))? # an optional decimal part # note: the previous part will eat any decimals @@ -142,228 +132,214 @@ lazy_static! { (?P[\w&&[^\d]]+) # a word with no digits )? ", - ) - .expect("Compiling a regex went wrong"); + ) + .expect("Compiling a regex went wrong"); } /// Convert some unit abbreviations to their full form. /// See the [module level documentation](index.html) for more information about which abbreviations are accepted. // TODO: return an `enum`. 
fn parse_unit(unit: &str) -> &str { - let unit_casefold = unit.to_lowercase(); + let unit_casefold = unit.to_lowercase(); - if unit_casefold.starts_with('n') - && ("nanoseconds".starts_with(&unit_casefold) || "nsecs".starts_with(&unit_casefold)) - { - "nanoseconds" - } else if unit_casefold.starts_with("mic") && "microseconds".starts_with(&unit_casefold) - || unit_casefold.starts_with('u') && "usecs".starts_with(&unit_casefold) - || unit_casefold.starts_with('μ') && "\u{3bc}secs".starts_with(&unit_casefold) - { - "microseconds" - } else if unit_casefold.starts_with("mil") && "milliseconds".starts_with(&unit_casefold) - || unit_casefold.starts_with("ms") && "msecs".starts_with(&unit_casefold) - { - "milliseconds" - } else if unit_casefold.starts_with('s') - && ("seconds".starts_with(&unit_casefold) || "secs".starts_with(&unit_casefold)) - { - "seconds" - } else if (unit_casefold.starts_with("min") || unit.starts_with('m')) - && ("minutes".starts_with(&unit_casefold) || "mins".starts_with(&unit_casefold)) - { - "minutes" - } else if unit_casefold.starts_with('h') - && ("hours".starts_with(&unit_casefold) || "hrs".starts_with(&unit_casefold)) - { - "hours" - } else if unit_casefold.starts_with('d') && "days".starts_with(&unit_casefold) { - "days" - } else if unit_casefold.starts_with('w') && "weeks".starts_with(&unit_casefold) { - "weeks" - } else if (unit_casefold.starts_with("mo") || unit.starts_with('M')) - && "months".starts_with(&unit_casefold) - { - "months" - } else if unit_casefold.starts_with('y') - && ("years".starts_with(&unit_casefold) || "yrs".starts_with(&unit_casefold)) - { - "years" - } else { - unit - } + if unit_casefold.starts_with('n') + && ("nanoseconds".starts_with(&unit_casefold) || "nsecs".starts_with(&unit_casefold)) + { + "nanoseconds" + } else if unit_casefold.starts_with("mic") && "microseconds".starts_with(&unit_casefold) + || unit_casefold.starts_with('u') && "usecs".starts_with(&unit_casefold) + || unit_casefold.starts_with('μ') && 
"\u{3bc}secs".starts_with(&unit_casefold) + { + "microseconds" + } else if unit_casefold.starts_with("mil") && "milliseconds".starts_with(&unit_casefold) + || unit_casefold.starts_with("ms") && "msecs".starts_with(&unit_casefold) + { + "milliseconds" + } else if unit_casefold.starts_with('s') + && ("seconds".starts_with(&unit_casefold) || "secs".starts_with(&unit_casefold)) + { + "seconds" + } else if (unit_casefold.starts_with("min") || unit.starts_with('m')) + && ("minutes".starts_with(&unit_casefold) || "mins".starts_with(&unit_casefold)) + { + "minutes" + } else if unit_casefold.starts_with('h') && ("hours".starts_with(&unit_casefold) || "hrs".starts_with(&unit_casefold)) + { + "hours" + } else if unit_casefold.starts_with('d') && "days".starts_with(&unit_casefold) { + "days" + } else if unit_casefold.starts_with('w') && "weeks".starts_with(&unit_casefold) { + "weeks" + } else if (unit_casefold.starts_with("mo") || unit.starts_with('M')) && "months".starts_with(&unit_casefold) { + "months" + } else if unit_casefold.starts_with('y') && ("years".starts_with(&unit_casefold) || "yrs".starts_with(&unit_casefold)) + { + "years" + } else { + unit + } } /// Parse a string into a duration object. /// /// See the [module level documentation](index.html) for more. pub fn parse(input: &str) -> Result { - if let Some(int) = NUMBER_RE.captures(input) { - // This means it's just a value - // Since the regex matched, the first group exists, so we can unwrap. - let seconds = BigInt::parse_bytes(int.get(1).unwrap().as_str().as_bytes(), 10) - .ok_or_else(|| Error::ParseInt(int.get(1).unwrap().as_str().to_owned()))?; - Ok(RelativeDuration::from(Duration::seconds( - seconds - .to_i64() - .ok_or_else(|| Error::OutOfBounds(seconds))? - ))) - } else if DURATION_RE.is_match(input) { - // This means we have at least one "unit" (or plain word) and one value. 
- let mut duration = ProtoDuration::default(); - for capture in DURATION_RE.captures_iter(input) { - match ( - capture.name("int"), - capture.name("dec"), - capture.name("exp"), - capture.name("unit"), - ) { - // capture.get(0) is *always* the actual match, so unwrapping causes no problems - (.., None) => { - return Err(Error::NoUnitFound( - capture.get(0).unwrap().as_str().to_owned(), - )) - } - (None, ..) => { - return Err(Error::NoValueFound( - capture.get(0).unwrap().as_str().to_owned(), - )) - } - (Some(int), None, None, Some(unit)) => { - let int = BigInt::parse_bytes(int.as_str().as_bytes(), 10) - .ok_or_else(|| Error::ParseInt(int.as_str().to_owned()))?; + if let Some(int) = NUMBER_RE.captures(input) { + // This means it's just a value + // Since the regex matched, the first group exists, so we can unwrap. + let seconds = BigInt::parse_bytes(int.get(1).unwrap().as_str().as_bytes(), 10) + .ok_or_else(|| Error::ParseInt(int.get(1).unwrap().as_str().to_owned()))?; + Ok(RelativeDuration::from(Duration::seconds( + seconds.to_i64().ok_or_else(|| Error::OutOfBounds(seconds))?, + ))) + } else if DURATION_RE.is_match(input) { + // This means we have at least one "unit" (or plain word) and one value. + let mut duration = ProtoDuration::default(); + for capture in DURATION_RE.captures_iter(input) { + match ( + capture.name("int"), + capture.name("dec"), + capture.name("exp"), + capture.name("unit"), + ) { + // capture.get(0) is *always* the actual match, so unwrapping causes no problems + (.., None) => return Err(Error::NoUnitFound(capture.get(0).unwrap().as_str().to_owned())), + (None, ..) 
=> return Err(Error::NoValueFound(capture.get(0).unwrap().as_str().to_owned())), + (Some(int), None, None, Some(unit)) => { + let int = + BigInt::parse_bytes(int.as_str().as_bytes(), 10).ok_or_else(|| Error::ParseInt(int.as_str().to_owned()))?; - match parse_unit(unit.as_str()) { - "nanoseconds" => duration.nanoseconds += int, - "microseconds" => duration.microseconds += int, - "milliseconds" => duration.milliseconds += int, - "seconds" => duration.seconds += int, - "minutes" => duration.minutes += int, - "hours" => duration.hours += int, - "days" => duration.days += int, - "weeks" => duration.weeks += int, - "months" => duration.months += int, - "years" => duration.years += int, - s => return Err(Error::UnknownUnit(s.to_owned())), - } - } - (Some(int), Some(dec), None, Some(unit)) => { - let int = BigInt::parse_bytes(int.as_str().as_bytes(), 10) - .ok_or_else(|| Error::ParseInt(int.as_str().to_owned()))?; + match parse_unit(unit.as_str()) { + "nanoseconds" => duration.nanoseconds += int, + "microseconds" => duration.microseconds += int, + "milliseconds" => duration.milliseconds += int, + "seconds" => duration.seconds += int, + "minutes" => duration.minutes += int, + "hours" => duration.hours += int, + "days" => duration.days += int, + "weeks" => duration.weeks += int, + "months" => duration.months += int, + "years" => duration.years += int, + s => return Err(Error::UnknownUnit(s.to_owned())), + } + } + (Some(int), Some(dec), None, Some(unit)) => { + let int = + BigInt::parse_bytes(int.as_str().as_bytes(), 10).ok_or_else(|| Error::ParseInt(int.as_str().to_owned()))?; - let exp = dec.as_str().len(); + let exp = dec.as_str().len(); - let dec = BigInt::parse_bytes(dec.as_str().as_bytes(), 10) - .ok_or_else(|| Error::ParseInt(dec.as_str().to_owned()))?; + let dec = + BigInt::parse_bytes(dec.as_str().as_bytes(), 10).ok_or_else(|| Error::ParseInt(dec.as_str().to_owned()))?; - // boosted_int is value * 10^exp * unit - let mut boosted_int = int * pow(BigInt::from(10), exp) 
+ dec; + // boosted_int is value * 10^exp * unit + let mut boosted_int = int * pow(BigInt::from(10), exp) + dec; - // boosted_int is now value * 10^exp * nanoseconds - match parse_unit(unit.as_str()) { - "nanoseconds" => boosted_int = boosted_int, - "microseconds" => boosted_int = 1_000_i64 * boosted_int, - "milliseconds" => boosted_int = 1_000_000_i64 * boosted_int, - "seconds" => boosted_int = 1_000_000_000_i64 * boosted_int, - "minutes" => boosted_int = 60_000_000_000_i64 * boosted_int, - "hours" => boosted_int = 3_600_000_000_000_i64 * boosted_int, - "days" => boosted_int = 86_400_000_000_000_i64 * boosted_int, - "weeks" => boosted_int = 604_800_000_000_000_i64 * boosted_int, - "months" => boosted_int = 2_629_746_000_000_000_i64 * boosted_int, - "years" => boosted_int = 31_556_952_000_000_000_i64 * boosted_int, - s => return Err(Error::UnknownUnit(s.to_owned())), - } + // boosted_int is now value * 10^exp * nanoseconds + match parse_unit(unit.as_str()) { + "nanoseconds" => boosted_int = boosted_int, + "microseconds" => boosted_int = 1_000_i64 * boosted_int, + "milliseconds" => boosted_int = 1_000_000_i64 * boosted_int, + "seconds" => boosted_int = 1_000_000_000_i64 * boosted_int, + "minutes" => boosted_int = 60_000_000_000_i64 * boosted_int, + "hours" => boosted_int = 3_600_000_000_000_i64 * boosted_int, + "days" => boosted_int = 86_400_000_000_000_i64 * boosted_int, + "weeks" => boosted_int = 604_800_000_000_000_i64 * boosted_int, + "months" => boosted_int = 2_629_746_000_000_000_i64 * boosted_int, + "years" => boosted_int = 31_556_952_000_000_000_i64 * boosted_int, + s => return Err(Error::UnknownUnit(s.to_owned())), + } - // boosted_int is now value * nanoseconds (rounding down) - boosted_int /= pow(BigInt::from(10), exp); - duration.nanoseconds += boosted_int; - } - (Some(int), None, Some(exp), Some(unit)) => { - let int = BigInt::parse_bytes(int.as_str().as_bytes(), 10) - .ok_or_else(|| Error::ParseInt(int.as_str().to_owned()))?; + // boosted_int is now 
value * nanoseconds (rounding down) + boosted_int /= pow(BigInt::from(10), exp); + duration.nanoseconds += boosted_int; + } + (Some(int), None, Some(exp), Some(unit)) => { + let int = + BigInt::parse_bytes(int.as_str().as_bytes(), 10).ok_or_else(|| Error::ParseInt(int.as_str().to_owned()))?; - let exp = exp - .as_str() - .parse::() - .or_else(|_| Err(Error::ParseInt(exp.as_str().to_owned())))?; + let exp = exp + .as_str() + .parse::() + .or_else(|_| Err(Error::ParseInt(exp.as_str().to_owned())))?; - // boosted_int is value * 10^-exp * unit - let mut boosted_int = int; + // boosted_int is value * 10^-exp * unit + let mut boosted_int = int; - // boosted_int is now value * 10^-exp * nanoseconds - match parse_unit(unit.as_str()) { - "nanoseconds" => boosted_int = boosted_int, - "microseconds" => boosted_int = 1_000_i64 * boosted_int, - "milliseconds" => boosted_int = 1_000_000_i64 * boosted_int, - "seconds" => boosted_int = 1_000_000_000_i64 * boosted_int, - "minutes" => boosted_int = 60_000_000_000_i64 * boosted_int, - "hours" => boosted_int = 3_600_000_000_000_i64 * boosted_int, - "days" => boosted_int = 86_400_000_000_000_i64 * boosted_int, - "weeks" => boosted_int = 604_800_000_000_000_i64 * boosted_int, - "months" => boosted_int = 2_629_746_000_000_000_i64 * boosted_int, - "years" => boosted_int = 31_556_952_000_000_000_i64 * boosted_int, - s => return Err(Error::UnknownUnit(s.to_owned())), - } + // boosted_int is now value * 10^-exp * nanoseconds + match parse_unit(unit.as_str()) { + "nanoseconds" => boosted_int = boosted_int, + "microseconds" => boosted_int = 1_000_i64 * boosted_int, + "milliseconds" => boosted_int = 1_000_000_i64 * boosted_int, + "seconds" => boosted_int = 1_000_000_000_i64 * boosted_int, + "minutes" => boosted_int = 60_000_000_000_i64 * boosted_int, + "hours" => boosted_int = 3_600_000_000_000_i64 * boosted_int, + "days" => boosted_int = 86_400_000_000_000_i64 * boosted_int, + "weeks" => boosted_int = 604_800_000_000_000_i64 * boosted_int, + 
"months" => boosted_int = 2_629_746_000_000_000_i64 * boosted_int, + "years" => boosted_int = 31_556_952_000_000_000_i64 * boosted_int, + s => return Err(Error::UnknownUnit(s.to_owned())), + } - // boosted_int is now value * nanoseconds - // x.wrapping_abs() as usize will always give the intended result - // This is because isize::MIN as usize == abs(isize::MIN) (as a usize) - if exp < 0 { - boosted_int /= pow(BigInt::from(10), exp.wrapping_abs() as usize); - } else { - boosted_int *= pow(BigInt::from(10), exp.wrapping_abs() as usize); - } - duration.nanoseconds += boosted_int; - } - (Some(int), Some(dec), Some(exp), Some(unit)) => { - let int = BigInt::parse_bytes(int.as_str().as_bytes(), 10) - .ok_or_else(|| Error::ParseInt(int.as_str().to_owned()))?; + // boosted_int is now value * nanoseconds + // x.wrapping_abs() as usize will always give the intended result + // This is because isize::MIN as usize == abs(isize::MIN) (as a usize) + if exp < 0 { + boosted_int /= pow(BigInt::from(10), exp.wrapping_abs() as usize); + } else { + boosted_int *= pow(BigInt::from(10), exp.wrapping_abs() as usize); + } + duration.nanoseconds += boosted_int; + } + (Some(int), Some(dec), Some(exp), Some(unit)) => { + let int = + BigInt::parse_bytes(int.as_str().as_bytes(), 10).ok_or_else(|| Error::ParseInt(int.as_str().to_owned()))?; - let dec_exp = dec.as_str().len(); + let dec_exp = dec.as_str().len(); - let exp = exp - .as_str() - .parse::() - .or_else(|_| Err(Error::ParseInt(exp.as_str().to_owned())))? - - (BigInt::from(dec_exp)); - let exp = exp.to_isize().ok_or_else(|| Error::OutOfBounds(exp))?; + let exp = exp + .as_str() + .parse::() + .or_else(|_| Err(Error::ParseInt(exp.as_str().to_owned())))? 
+ - (BigInt::from(dec_exp)); + let exp = exp.to_isize().ok_or_else(|| Error::OutOfBounds(exp))?; - let dec = BigInt::parse_bytes(dec.as_str().as_bytes(), 10) - .ok_or_else(|| Error::ParseInt(dec.as_str().to_owned()))?; + let dec = + BigInt::parse_bytes(dec.as_str().as_bytes(), 10).ok_or_else(|| Error::ParseInt(dec.as_str().to_owned()))?; - // boosted_int is value * 10^-exp * unit - let mut boosted_int = int * pow(BigInt::from(10), dec_exp) + dec; + // boosted_int is value * 10^-exp * unit + let mut boosted_int = int * pow(BigInt::from(10), dec_exp) + dec; - // boosted_int is now value * 10^-exp * nanoseconds - match parse_unit(unit.as_str()) { - "nanoseconds" => boosted_int = boosted_int, - "microseconds" => boosted_int *= 1_000_i64, - "milliseconds" => boosted_int *= 1_000_000_i64, - "seconds" => boosted_int *= 1_000_000_000_i64, - "minutes" => boosted_int *= 60_000_000_000_i64, - "hours" => boosted_int *= 3_600_000_000_000_i64, - "days" => boosted_int *= 86_400_000_000_000_i64, - "weeks" => boosted_int *= 604_800_000_000_000_i64, - "months" => boosted_int *= 2_629_746_000_000_000_i64, - "years" => boosted_int *= 31_556_952_000_000_000_i64, - s => return Err(Error::UnknownUnit(s.to_owned())), - } + // boosted_int is now value * 10^-exp * nanoseconds + match parse_unit(unit.as_str()) { + "nanoseconds" => boosted_int = boosted_int, + "microseconds" => boosted_int *= 1_000_i64, + "milliseconds" => boosted_int *= 1_000_000_i64, + "seconds" => boosted_int *= 1_000_000_000_i64, + "minutes" => boosted_int *= 60_000_000_000_i64, + "hours" => boosted_int *= 3_600_000_000_000_i64, + "days" => boosted_int *= 86_400_000_000_000_i64, + "weeks" => boosted_int *= 604_800_000_000_000_i64, + "months" => boosted_int *= 2_629_746_000_000_000_i64, + "years" => boosted_int *= 31_556_952_000_000_000_i64, + s => return Err(Error::UnknownUnit(s.to_owned())), + } - // boosted_int is now value * nanoseconds (potentially rounded down) - // x.wrapping_abs() as usize will always give the 
intended result - // This is because isize::MIN as usize == abs(isize::MIN) (as a usize) - if exp < 0 { - boosted_int /= pow(BigInt::from(10), exp.wrapping_abs() as usize); - } else { - boosted_int *= pow(BigInt::from(10), exp.wrapping_abs() as usize); - } - duration.nanoseconds += boosted_int; - } - } + // boosted_int is now value * nanoseconds (potentially rounded down) + // x.wrapping_abs() as usize will always give the intended result + // This is because isize::MIN as usize == abs(isize::MIN) (as a usize) + if exp < 0 { + boosted_int /= pow(BigInt::from(10), exp.wrapping_abs() as usize); + } else { + boosted_int *= pow(BigInt::from(10), exp.wrapping_abs() as usize); + } + duration.nanoseconds += boosted_int; } - duration.into_duration() - } else { - // Just a unit or nothing at all - Err(Error::NoValueFound(input.to_owned())) + } } + duration.into_duration() + } else { + // Just a unit or nothing at all + Err(Error::NoValueFound(input.to_owned())) + } } diff --git a/lib/parse_relative_duration/tests/basics.rs b/lib/parse_relative_duration/tests/basics.rs index 31bdcbf..52c017e 100644 --- a/lib/parse_relative_duration/tests/basics.rs +++ b/lib/parse_relative_duration/tests/basics.rs @@ -1,32 +1,36 @@ -extern crate num; extern crate chrono; extern crate chronoutil; +extern crate num; extern crate parse_relative_duration; -use num::BigInt; use chrono::Duration; use chronoutil::RelativeDuration; +use num::BigInt; use parse_relative_duration::parse; macro_rules! 
test_parse { - (fn $fun:ident($string: expr, $months: expr, $seconds: expr, $nanoseconds: expr)) => { - #[test] - fn $fun() { - assert_eq!(parse($string), Ok( - RelativeDuration::months($months).with_duration(Duration::seconds($seconds) + Duration::nanoseconds($nanoseconds)) - )) - } - }; + (fn $fun:ident($string: expr, $months: expr, $seconds: expr, $nanoseconds: expr)) => { + #[test] + fn $fun() { + assert_eq!( + parse($string), + Ok( + RelativeDuration::months($months) + .with_duration(Duration::seconds($seconds) + Duration::nanoseconds($nanoseconds)) + ) + ) + } + }; } macro_rules! test_invalid { - (fn $fun:ident($string: expr, $error: expr)) => { - #[test] - fn $fun() { - assert_eq!(parse($string), Err($error)); - } - }; + (fn $fun:ident($string: expr, $error: expr)) => { + #[test] + fn $fun() { + assert_eq!(parse($string), Err($error)); + } + }; } test_parse!(fn nano1("1nsec", 0, 0, 1)); @@ -170,12 +174,12 @@ test_invalid!(fn wrong_order("year15", parse::Error::NoUnitFound("15".to_string( #[test] fn number_too_big() { - assert_eq!( - Ok(parse("123456789012345678901234567890 seconds")), - "123456789012345678901234567890" - .parse::() - .map(|int| Err(parse::Error::OutOfBounds(int))) - ); + assert_eq!( + Ok(parse("123456789012345678901234567890 seconds")), + "123456789012345678901234567890" + .parse::() + .map(|int| Err(parse::Error::OutOfBounds(int))) + ); } test_invalid!(fn not_enough_units("16 17 seconds", parse::Error::NoUnitFound("16".to_string()))); diff --git a/lib/sdd/Cargo.toml b/lib/sdd/Cargo.toml index 739afef..938b845 100644 --- a/lib/sdd/Cargo.toml +++ b/lib/sdd/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "sdd" -version = "0.1.0" +version = "0.1.1" authors = ["Ziyang Li "] edition = "2018"