From 63e3b0ba166cfe4fedc7c4c9a10f9b9d6df04706 Mon Sep 17 00:00:00 2001 From: Ziyang Li Date: Sun, 2 Apr 2023 22:56:07 -0400 Subject: [PATCH] Bumping version --- Cargo.toml | 1 + changelog.md | 16 +- core/Cargo.toml | 10 +- core/src/common/foreign_function.rs | 553 ++---------------- core/src/common/foreign_functions/abs.rs | 62 ++ core/src/common/foreign_functions/cos.rs | 23 + .../common/foreign_functions/datetime_day.rs | 37 ++ .../foreign_functions/datetime_month.rs | 37 ++ .../foreign_functions/datetime_month0.rs | 37 ++ .../common/foreign_functions/datetime_year.rs | 37 ++ core/src/common/foreign_functions/hash.rs | 35 ++ core/src/common/foreign_functions/max.rs | 71 +++ core/src/common/foreign_functions/min.rs | 71 +++ core/src/common/foreign_functions/mod.rs | 40 ++ core/src/common/foreign_functions/sin.rs | 23 + .../foreign_functions/string_char_at.rs | 39 ++ .../common/foreign_functions/string_concat.rs | 40 ++ .../common/foreign_functions/string_length.rs | 36 ++ .../src/common/foreign_functions/substring.rs | 53 ++ core/src/common/foreign_functions/tan.rs | 23 + core/src/common/foreign_predicate.rs | 313 ++++++++++ .../src/common/foreign_predicates/float_eq.rs | 63 ++ core/src/common/foreign_predicates/mod.rs | 28 + core/src/common/foreign_predicates/range.rs | 95 +++ .../src/common/foreign_predicates/soft_cmp.rs | 93 +++ core/src/common/foreign_predicates/soft_eq.rs | 100 ++++ core/src/common/foreign_predicates/soft_gt.rs | 98 ++++ core/src/common/foreign_predicates/soft_lt.rs | 98 ++++ .../src/common/foreign_predicates/soft_neq.rs | 100 ++++ .../common/foreign_predicates/string_chars.rs | 59 ++ core/src/common/generic_tuple.rs | 6 + core/src/common/input_tag.rs | 24 +- core/src/common/mod.rs | 2 + core/src/common/tuple.rs | 32 +- core/src/common/value.rs | 34 ++ core/src/common/value_type.rs | 94 +++ core/src/compiler/back/ast.rs | 40 +- core/src/compiler/back/b2r.rs | 100 +++- core/src/compiler/back/compile.rs | 1 + .../back/optimizations/empty_rule_to_fact.rs | 4 +- core/src/compiler/back/query_plan.rs | 516 +++++++++++++--- core/src/compiler/back/scc.rs | 15 +- core/src/compiler/front/analysis.rs | 18 +- .../analyzers/boundness/boundness_analysis.rs | 12 +- .../front/analyzers/boundness/context.rs | 40 +- .../front/analyzers/boundness/dependency.rs | 3 + .../front/analyzers/boundness/foreign.rs | 32 + .../front/analyzers/boundness/local.rs | 31 +- .../compiler/front/analyzers/boundness/mod.rs | 2 + .../compiler/front/analyzers/constant_decl.rs | 101 ++++ .../compiler/front/analyzers/head_relation.rs | 13 +- .../front/analyzers/invalid_constant.rs | 48 ++ core/src/compiler/front/analyzers/mod.rs | 3 + .../compiler/front/analyzers/output_files.rs | 2 +- .../front/analyzers/type_inference/error.rs | 93 ++- .../type_inference/foreign_function.rs | 172 ++++++ .../type_inference/foreign_predicate.rs | 77 +++ .../front/analyzers/type_inference/local.rs | 41 +- .../front/analyzers/type_inference/mod.rs | 8 +- .../type_inference/operator_rules.rs | 137 +++++ .../type_inference/type_inference.rs | 98 +++- .../analyzers/type_inference/type_set.rs | 39 +- .../analyzers/type_inference/unification.rs | 222 +++++-- core/src/compiler/front/ast/constant.rs | 25 +- core/src/compiler/front/ast/formula.rs | 12 +- core/src/compiler/front/ast/query.rs | 16 +- core/src/compiler/front/ast/type_decl.rs | 61 ++ core/src/compiler/front/ast/types.rs | 20 + core/src/compiler/front/compile.rs | 40 +- core/src/compiler/front/f2b/f2b.rs | 28 +- core/src/compiler/front/f2b/flatten_expr.rs | 47 +- 
core/src/compiler/front/grammar.lalrpop | 76 ++- core/src/compiler/front/pretty.rs | 36 +- .../front/transformations/tagged_rule.rs | 5 +- core/src/compiler/front/visitor.rs | 23 + core/src/compiler/front/visitor_mut.rs | 23 + core/src/compiler/ram/ast.rs | 39 +- core/src/compiler/ram/dependency.rs | 7 + .../ram/optimizations/project_cascade.rs | 7 +- core/src/compiler/ram/pretty.rs | 14 + core/src/compiler/ram/ram2rs.rs | 19 +- core/src/compiler/ram/transform.rs | 1 + core/src/integrate/context.rs | 64 +- .../runtime/database/extensional/database.rs | 4 +- .../runtime/database/extensional/relation.rs | 16 +- .../runtime/database/intentional/database.rs | 8 +- .../runtime/database/intentional/relation.rs | 22 +- .../dynamic/dataflow/batching/batch.rs | 12 +- .../dynamic/dataflow/batching/batches.rs | 4 + .../dynamic/dataflow/dynamic_dataflow.rs | 58 +- .../dataflow/foreign_predicate/constraint.rs | 107 ++++ .../dataflow/foreign_predicate/ground.rs | 65 ++ .../dataflow/foreign_predicate/join.rs | 146 +++++ .../dynamic/dataflow/foreign_predicate/mod.rs | 9 + core/src/runtime/dynamic/dataflow/mod.rs | 2 + core/src/runtime/dynamic/incremental.rs | 2 +- core/src/runtime/dynamic/io.rs | 12 +- core/src/runtime/dynamic/iteration.rs | 9 + core/src/runtime/dynamic/relation.rs | 11 +- core/src/runtime/env/environment.rs | 77 ++- core/src/runtime/env/options.rs | 2 + core/src/runtime/error/error.rs | 3 + .../provenance/common/diff_prob_storage.rs | 77 +++ .../runtime/provenance/common/dual_number.rs | 8 + .../provenance/common/dual_number_2.rs | 7 + .../provenance/common/input_diff_prob.rs | 27 +- .../common/input_exclusive_diff_prob.rs | 20 +- .../provenance/common/input_exclusive_prob.rs | 4 +- core/src/runtime/provenance/common/mod.rs | 2 + .../provenance/common/output_diff_prob.rs | 10 +- .../differentiable/diff_add_mult_prob.rs | 27 +- .../differentiable/diff_max_mult_prob.rs | 27 +- .../differentiable/diff_min_max_prob.rs | 138 ++--- .../differentiable/diff_nand_min_prob.rs | 27 +- .../differentiable/diff_nand_mult_prob.rs | 27 +- .../differentiable/diff_sample_k_proofs.rs | 47 +- .../diff_top_bottom_k_clauses.rs | 50 +- .../differentiable/diff_top_k_proofs.rs | 47 +- .../differentiable/diff_top_k_proofs_indiv.rs | 40 +- .../runtime/provenance/discrete/boolean.rs | 2 +- .../runtime/provenance/discrete/natural.rs | 5 +- .../src/runtime/provenance/discrete/proofs.rs | 41 +- core/src/runtime/provenance/discrete/unit.rs | 14 +- .../provenance/probabilistic/add_mult_prob.rs | 2 +- .../provenance/probabilistic/min_max_prob.rs | 2 +- .../provenance/probabilistic/prob_proofs.rs | 43 +- .../probabilistic/sample_k_proofs.rs | 52 +- .../probabilistic/top_bottom_k_clauses.rs | 28 +- .../provenance/probabilistic/top_k_proofs.rs | 36 +- core/src/runtime/provenance/provenance.rs | 7 +- core/src/runtime/statics/relation.rs | 8 +- core/src/testing/test_collection.rs | 78 ++- core/src/testing/test_interpret.rs | 8 +- core/src/utils/chrono.rs | 10 + core/src/utils/copy_on_write.rs | 8 +- core/src/utils/float.rs | 39 ++ core/src/utils/integer.rs | 2 + core/src/utils/mod.rs | 4 + core/src/utils/pointer_family.rs | 144 ++++- core/tests/compiler/errors.rs | 10 + core/tests/integrate/basic.rs | 138 +++-- core/tests/integrate/dt.rs | 4 +- core/tests/integrate/edb.rs | 6 + core/tests/integrate/ff.rs | 107 ++++ core/tests/integrate/fp.rs | 91 +++ core/tests/integrate/incr.rs | 9 +- core/tests/integrate/mod.rs | 3 +- core/tests/integrate/prob.rs | 2 +- core/tests/integrate/time.rs | 127 ++++ 
.../runtime/dataflow/dyn_foreign_predicate.rs | 77 +++ .../runtime/dataflow/dyn_group_aggregate.rs | 2 +- .../runtime/dataflow/dyn_group_by_key.rs | 2 +- core/tests/runtime/dataflow/dyn_intersect.rs | 2 +- core/tests/runtime/dataflow/dyn_join.rs | 2 +- core/tests/runtime/dataflow/dyn_product.rs | 2 +- core/tests/runtime/dataflow/dyn_project.rs | 2 +- core/tests/runtime/dataflow/mod.rs | 1 + core/tests/runtime/provenance/top_bottom_k.rs | 4 +- core/tests/tests.rs | 1 + core/tests/utils/mod.rs | 1 + core/tests/utils/value.rs | 17 + doc/.gitignore | 1 + doc/book.toml | 25 + doc/js/hljs_scallop.js | 32 + doc/readme.md | 21 + docs/readme.md => doc/src/crash_course.md | 0 doc/src/developer/binding.md | 0 doc/src/developer/index.md | 0 doc/src/developer/language_construct.md | 0 doc/src/grammar.md | 0 doc/src/installation.md | 30 + doc/src/introduction.md | 3 + doc/src/language/aggregation.md | 1 + doc/src/language/facts.md | 0 doc/src/language/foreign_functions.md | 1 + doc/src/language/foreign_predicates.md | 27 + doc/src/language/index.md | 0 doc/src/language/negation.md | 1 + doc/src/language/provenance.md | 1 + doc/src/language/query.md | 0 doc/src/language/recursion.md | 1 + doc/src/language/rules.md | 0 doc/src/misc/contributors.md | 0 doc/src/probabilistic/facts.md | 1 + doc/src/probabilistic/index.md | 0 doc/src/probabilistic/logic.md | 1 + doc/src/probabilistic/reasoning.md | 1 + doc/src/probabilistic/sampling.md | 1 + doc/src/readme.md | 0 doc/src/scallopy/branching.md | 1 + doc/src/scallopy/context.md | 1 + doc/src/scallopy/index.md | 0 doc/src/scallopy/provenance.md | 1 + doc/src/summary.md | 41 ++ docs/design/grammar.md | 83 --- docs/design/group_by.md | 53 -- docs/icons/scallop-logo-transp-128.png | Bin 6969 -> 0 bytes docs/icons/scallop-logo-ws-512.png | Bin 24610 -> 0 bytes etc/codegen/Cargo.toml | 2 +- etc/codegen/tests/codegen_basic.rs | 2 +- etc/codegen/tests/codegen_edb.rs | 4 +- etc/scallop-wasm/Cargo.toml | 2 +- etc/scallopy/Cargo.toml | 2 +- etc/scallopy/examples/foreign_predicate.py | 26 + etc/scallopy/scallopy/__init__.py | 5 + etc/scallopy/scallopy/context.py | 103 ++-- etc/scallopy/scallopy/forward.py | 49 +- etc/scallopy/scallopy/function.py | 10 +- etc/scallopy/scallopy/input_mapping.py | 398 +++++++++++++ etc/scallopy/scallopy/predicate.py | 174 ++++++ etc/scallopy/scallopy/scallopy.pyi | 6 + etc/scallopy/scallopy/utils.py | 10 + etc/scallopy/src/collection.rs | 112 ++-- etc/scallopy/src/context.rs | 37 +- etc/scallopy/src/custom_tag.rs | 2 +- etc/scallopy/src/error.rs | 2 + etc/scallopy/src/foreign_predicate.rs | 148 +++++ etc/scallopy/src/lib.rs | 2 + etc/scallopy/src/provenance.rs | 58 +- etc/scallopy/src/tag.rs | 36 ++ etc/scallopy/src/tuple.rs | 13 +- etc/scallopy/tests/basics.py | 11 - etc/scallopy/tests/configurations.py | 23 + etc/scallopy/tests/forward.py | 58 +- etc/scallopy/tests/input_mapping.py | 220 +++++++ etc/scallopy/tests/test.py | 2 + etc/sclc/Cargo.toml | 2 +- etc/sclc/src/pylib.rs | 12 +- etc/scli/Cargo.toml | 2 +- etc/scli/src/main.rs | 11 +- etc/sclrepl/Cargo.toml | 2 +- etc/sclrepl/src/main.rs | 2 +- etc/vscode-scl/language-configuration.json | 6 +- etc/vscode-scl/package.json | 2 +- .../syntaxes/scallop.tmLanguage.json | 28 + lib/ram/Cargo.toml | 8 + lib/ram/src/generic_tuple.rs | 132 +++++ lib/ram/src/language.rs | 143 +++++ lib/ram/src/lib.rs | 6 + lib/ram/src/tuple_type.rs | 170 ++++++ lib/ram/src/value_type.rs | 218 +++++++ lib/ram/tests/test_optim.rs | 37 ++ makefile | 14 +- readme.md | 2 +- 244 files changed, 8228 insertions(+), 1625 
deletions(-) create mode 100644 core/src/common/foreign_functions/abs.rs create mode 100644 core/src/common/foreign_functions/cos.rs create mode 100644 core/src/common/foreign_functions/datetime_day.rs create mode 100644 core/src/common/foreign_functions/datetime_month.rs create mode 100644 core/src/common/foreign_functions/datetime_month0.rs create mode 100644 core/src/common/foreign_functions/datetime_year.rs create mode 100644 core/src/common/foreign_functions/hash.rs create mode 100644 core/src/common/foreign_functions/max.rs create mode 100644 core/src/common/foreign_functions/min.rs create mode 100644 core/src/common/foreign_functions/mod.rs create mode 100644 core/src/common/foreign_functions/sin.rs create mode 100644 core/src/common/foreign_functions/string_char_at.rs create mode 100644 core/src/common/foreign_functions/string_concat.rs create mode 100644 core/src/common/foreign_functions/string_length.rs create mode 100644 core/src/common/foreign_functions/substring.rs create mode 100644 core/src/common/foreign_functions/tan.rs create mode 100644 core/src/common/foreign_predicates/float_eq.rs create mode 100644 core/src/common/foreign_predicates/mod.rs create mode 100644 core/src/common/foreign_predicates/range.rs create mode 100644 core/src/common/foreign_predicates/soft_cmp.rs create mode 100644 core/src/common/foreign_predicates/soft_eq.rs create mode 100644 core/src/common/foreign_predicates/soft_gt.rs create mode 100644 core/src/common/foreign_predicates/soft_lt.rs create mode 100644 core/src/common/foreign_predicates/soft_neq.rs create mode 100644 core/src/common/foreign_predicates/string_chars.rs create mode 100644 core/src/compiler/front/analyzers/boundness/foreign.rs create mode 100644 core/src/compiler/front/analyzers/invalid_constant.rs create mode 100644 core/src/compiler/front/analyzers/type_inference/foreign_function.rs create mode 100644 core/src/compiler/front/analyzers/type_inference/foreign_predicate.rs create mode 100644 core/src/compiler/front/analyzers/type_inference/operator_rules.rs create mode 100644 core/src/runtime/dynamic/dataflow/foreign_predicate/constraint.rs create mode 100644 core/src/runtime/dynamic/dataflow/foreign_predicate/ground.rs create mode 100644 core/src/runtime/dynamic/dataflow/foreign_predicate/join.rs create mode 100644 core/src/runtime/dynamic/dataflow/foreign_predicate/mod.rs create mode 100644 core/src/runtime/provenance/common/diff_prob_storage.rs create mode 100644 core/src/utils/chrono.rs create mode 100644 core/src/utils/float.rs create mode 100644 core/tests/integrate/fp.rs create mode 100644 core/tests/integrate/time.rs create mode 100644 core/tests/runtime/dataflow/dyn_foreign_predicate.rs create mode 100644 core/tests/utils/mod.rs create mode 100644 core/tests/utils/value.rs create mode 100644 doc/.gitignore create mode 100644 doc/book.toml create mode 100644 doc/js/hljs_scallop.js create mode 100644 doc/readme.md rename docs/readme.md => doc/src/crash_course.md (100%) create mode 100644 doc/src/developer/binding.md create mode 100644 doc/src/developer/index.md create mode 100644 doc/src/developer/language_construct.md create mode 100644 doc/src/grammar.md create mode 100644 doc/src/installation.md create mode 100644 doc/src/introduction.md create mode 100644 doc/src/language/aggregation.md create mode 100644 doc/src/language/facts.md create mode 100644 doc/src/language/foreign_functions.md create mode 100644 doc/src/language/foreign_predicates.md create mode 100644 doc/src/language/index.md create mode 100644 
doc/src/language/negation.md create mode 100644 doc/src/language/provenance.md create mode 100644 doc/src/language/query.md create mode 100644 doc/src/language/recursion.md create mode 100644 doc/src/language/rules.md create mode 100644 doc/src/misc/contributors.md create mode 100644 doc/src/probabilistic/facts.md create mode 100644 doc/src/probabilistic/index.md create mode 100644 doc/src/probabilistic/logic.md create mode 100644 doc/src/probabilistic/reasoning.md create mode 100644 doc/src/probabilistic/sampling.md create mode 100644 doc/src/readme.md create mode 100644 doc/src/scallopy/branching.md create mode 100644 doc/src/scallopy/context.md create mode 100644 doc/src/scallopy/index.md create mode 100644 doc/src/scallopy/provenance.md create mode 100644 doc/src/summary.md delete mode 100644 docs/design/grammar.md delete mode 100644 docs/design/group_by.md delete mode 100644 docs/icons/scallop-logo-transp-128.png delete mode 100644 docs/icons/scallop-logo-ws-512.png create mode 100644 etc/scallopy/examples/foreign_predicate.py create mode 100644 etc/scallopy/scallopy/input_mapping.py create mode 100644 etc/scallopy/scallopy/predicate.py create mode 100644 etc/scallopy/src/foreign_predicate.rs create mode 100644 etc/scallopy/src/tag.rs create mode 100644 etc/scallopy/tests/configurations.py create mode 100644 etc/scallopy/tests/input_mapping.py create mode 100644 lib/ram/Cargo.toml create mode 100644 lib/ram/src/generic_tuple.rs create mode 100644 lib/ram/src/language.rs create mode 100644 lib/ram/src/lib.rs create mode 100644 lib/ram/src/tuple_type.rs create mode 100644 lib/ram/src/value_type.rs create mode 100644 lib/ram/tests/test_optim.rs diff --git a/Cargo.toml b/Cargo.toml index e0fb5ae..431eb2b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,6 +10,7 @@ members = [ "etc/scallop-wasm", "lib/sdd", "lib/rsat", + "lib/ram", ] default-members = [ diff --git a/changelog.md b/changelog.md index b697a43..4f956ae 100644 --- a/changelog.md +++ b/changelog.md @@ -1,4 +1,18 @@ -# v0.1.7, Jan 12, 2022 +# latest + +- Fixed a bug so that NaN would not appear in the value computation +- Fixed a bug in `scallopy` where disjunctive facts are not processed correctly + +# v0.1.8, Mar 27, 2023 + +- Add foreign predicates including soft comparisons, +- Add `DateTime` and `Duration` support +- Add input-mapping support for multi-dimensional inputs in `scallopy` +- Fixing floating point support by rejecting `NaN` values +- Adding back iteration-limit to runtime environment +- Add Scallop book repository + +# v0.1.7, Jan 12, 2023 - Better integration with Extensional Databases (EDB) and memory optimizations - Better handling of mutual exclusive facts in probabilistic/differentiable reasoning diff --git a/core/Cargo.toml b/core/Cargo.toml index 08d53d0..0374bba 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "scallop-core" -version = "0.1.7" +version = "0.1.8" authors = ["Ziyang Li "] edition = "2018" @@ -8,10 +8,10 @@ edition = "2018" crate-type = ["rlib"] [build-dependencies] -lalrpop = { version = "0.19.0", features = ["lexer"] } +lalrpop = { version = "0.19.9", features = ["lexer"] } [dependencies] -lalrpop-util = "0.19.0" +lalrpop-util = "0.19.9" proc-macro2 = { version = "1.0", features = ["span-locations"] } quote = "1.0" syn = "1.0" @@ -21,6 +21,10 @@ colored = "2.0" petgraph = "0.6" csv = "1.1" sprs = "0.11" +chrono = "0.4" +dateparser = "0.1.6" +parse_duration = "2.1.1" dyn-clone = "1.0.10" +lazy_static = "1.4" rand = { version = "0.8", features = ["std_rng", 
"small_rng", "alloc"] } sdd = { path = "../lib/sdd" } diff --git a/core/src/common/foreign_function.rs b/core/src/common/foreign_function.rs index 5ac8b4f..a65a9f0 100644 --- a/core/src/common/foreign_function.rs +++ b/core/src/common/foreign_function.rs @@ -64,34 +64,21 @@ //! Make sure it can be cloned, and then we can implement the `ForeignFunction` trait //! for it. //! -//! ``` ignore -//! #[derive(Clone)] -//! pub struct YOUR_FF; -//! -//! impl ForeignFunction for YOUR_FF { -//! ... -//! } -//! ``` -//! //! Make sure the type information and the function implementation is correctly //! specified. //! Then we add the function to the standard library. -//! In `ForeignFunctionRegistry::std`, we add the following line -//! -//! ``` ignore -//! registry.register(ffs::YOUR_FF); -//! ``` +//! In `ForeignFunctionRegistry::std`, we add the following line `registry.register(ffs::YOUR_FF);`. //! //! After this, the function will be available in the standard library use std::collections::*; -use std::convert::*; use dyn_clone::DynClone; use super::type_family::*; use super::value::*; use super::value_type::*; +use super::foreign_functions as ffs; /// A type used for defining a foreign function. /// @@ -507,16 +494,30 @@ impl ForeignFunctionRegistry { // 1. we are starting from fresh registry; // 2. that all functions here have distinct names; // 3. all our functions are checked to be have correct types. + + // Arithmetic registry.register(ffs::Abs).unwrap(); registry.register(ffs::Sin).unwrap(); registry.register(ffs::Cos).unwrap(); registry.register(ffs::Tan).unwrap(); + + // Min/Max registry.register(ffs::Max).unwrap(); registry.register(ffs::Min).unwrap(); + + // String operations registry.register(ffs::StringConcat).unwrap(); registry.register(ffs::StringLength).unwrap(); registry.register(ffs::StringCharAt).unwrap(); registry.register(ffs::Substring).unwrap(); + + // DateTime operations + registry.register(ffs::DateTimeDay).unwrap(); + registry.register(ffs::DateTimeMonth).unwrap(); + registry.register(ffs::DateTimeMonth0).unwrap(); + registry.register(ffs::DateTimeYear).unwrap(); + + // Hashing operation registry.register(ffs::Hash).unwrap(); registry @@ -561,517 +562,47 @@ impl<'a> IntoIterator for &'a ForeignFunctionRegistry { } } -/// A library of pre-implemented foreign functions -pub mod ffs { - use super::*; - - /// Absolute value foreign function - /// - /// ``` scl - /// extern fn $abs(x: T) -> T - /// ``` - #[derive(Clone)] - pub struct Abs; - - impl ForeignFunction for Abs { - fn name(&self) -> String { - "abs".to_string() - } - - fn num_generic_types(&self) -> usize { - 1 - } - - fn generic_type_family(&self, i: usize) -> TypeFamily { - assert_eq!(i, 0); - TypeFamily::Number - } - - fn num_static_arguments(&self) -> usize { - 1 - } - - fn static_argument_type(&self, i: usize) -> ForeignFunctionParameterType { - assert_eq!(i, 0); - ForeignFunctionParameterType::Generic(0) - } - - fn return_type(&self) -> ForeignFunctionParameterType { - ForeignFunctionParameterType::Generic(0) - } - - fn execute(&self, args: Vec) -> Option { - match args[0] { - // Signed integers, take absolute - Value::I8(f) => Some(Value::I8(f.abs())), - Value::I16(f) => Some(Value::I16(f.abs())), - Value::I32(f) => Some(Value::I32(f.abs())), - Value::I64(f) => Some(Value::I64(f.abs())), - Value::I128(f) => Some(Value::I128(f.abs())), - Value::ISize(f) => Some(Value::ISize(f.abs())), - - // Unsigned integers, directly return - Value::U8(f) => Some(Value::U8(f)), - Value::U16(f) => Some(Value::U16(f)), - 
Value::U32(f) => Some(Value::U32(f)), - Value::U64(f) => Some(Value::U64(f)), - Value::U128(f) => Some(Value::U128(f)), - Value::USize(f) => Some(Value::USize(f)), - - // Floating points, take absolute - Value::F32(f) => Some(Value::F32(f.abs())), - Value::F64(f) => Some(Value::F64(f.abs())), - _ => panic!("should not happen; input variable to abs should be a number"), - } - } - } - - /// Floating point function - pub trait UnaryFloatFunction: Clone { - fn name(&self) -> String; - - fn execute_f32(&self, arg: f32) -> f32; - - fn execute_f64(&self, arg: f64) -> f64; - } - - impl ForeignFunction for F { - fn name(&self) -> String { - UnaryFloatFunction::name(self) - } - - fn num_generic_types(&self) -> usize { - 1 - } - - fn generic_type_family(&self, i: usize) -> TypeFamily { - assert_eq!(i, 0); - TypeFamily::Float - } - - fn num_static_arguments(&self) -> usize { - 1 - } - - fn static_argument_type(&self, i: usize) -> ForeignFunctionParameterType { - assert_eq!(i, 0); - ForeignFunctionParameterType::Generic(0) - } - - fn return_type(&self) -> ForeignFunctionParameterType { - ForeignFunctionParameterType::Generic(0) - } - - fn execute(&self, args: Vec) -> Option { - match args[0] { - Value::F32(f) => Some(Value::F32(self.execute_f32(f))), - Value::F64(f) => Some(Value::F64(self.execute_f64(f))), - _ => panic!("Expect floating point input"), - } - } - } - - /// Sin value foreign function - /// - /// ``` scl - /// extern fn $sin(x: T) -> T - /// ``` - #[derive(Clone)] - pub struct Sin; - - impl UnaryFloatFunction for Sin { - fn name(&self) -> String { - "sin".to_string() - } - - fn execute_f32(&self, arg: f32) -> f32 { - arg.sin() - } - - fn execute_f64(&self, arg: f64) -> f64 { - arg.sin() - } - } - - /// Cos value foreign function - /// - /// ``` scl - /// extern fn $cos(x: T) -> T - /// ``` - #[derive(Clone)] - pub struct Cos; - - impl UnaryFloatFunction for Cos { - fn name(&self) -> String { - "cos".to_string() - } - - fn execute_f32(&self, arg: f32) -> f32 { - arg.cos() - } - - fn execute_f64(&self, arg: f64) -> f64 { - arg.cos() - } - } - - /// Tan value foreign function - /// - /// ``` scl - /// extern fn $tan(x: T) -> T - /// ``` - #[derive(Clone)] - pub struct Tan; - - impl UnaryFloatFunction for Tan { - fn name(&self) -> String { - "tan".to_string() - } - - fn execute_f32(&self, arg: f32) -> f32 { - arg.tan() - } - - fn execute_f64(&self, arg: f64) -> f64 { - arg.tan() - } - } - - /// Substring - /// - /// ``` scl - /// extern fn $substring(s: String, begin: usize, end: usize?) 
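As a companion to the registration steps described in the module documentation above, the following is a minimal sketch of what a user-defined foreign function could look like against the `ForeignFunction` trait in this file. The `$cube` function, its name, and the `use super::*;` placement (alongside the other files in `core/src/common/foreign_functions/`) are illustrative assumptions and are not part of this patch.

``` rust
use super::*;

/// Hypothetical cube foreign function: `extern fn $cube(x: T) -> T`
#[derive(Clone)]
pub struct Cube;

impl ForeignFunction for Cube {
  fn name(&self) -> String {
    "cube".to_string()
  }

  fn num_generic_types(&self) -> usize {
    1
  }

  fn generic_type_family(&self, i: usize) -> TypeFamily {
    assert_eq!(i, 0);
    TypeFamily::Float
  }

  fn num_static_arguments(&self) -> usize {
    1
  }

  fn static_argument_type(&self, i: usize) -> ForeignFunctionParameterType {
    assert_eq!(i, 0);
    ForeignFunctionParameterType::Generic(0)
  }

  fn return_type(&self) -> ForeignFunctionParameterType {
    ForeignFunctionParameterType::Generic(0)
  }

  fn execute(&self, args: Vec<Value>) -> Option<Value> {
    match args[0] {
      // Only floating point values reach here, per the generic type family
      Value::F32(f) => Some(Value::F32(f * f * f)),
      Value::F64(f) => Some(Value::F64(f * f * f)),
      _ => panic!("Expect floating point input"),
    }
  }
}
```

Registering it then follows the line shown in `ForeignFunctionRegistry::std`, e.g. `registry.register(Cube).unwrap();` on a mutable registry.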
-> String - /// ``` - #[derive(Clone)] - pub struct Substring; - - impl ForeignFunction for Substring { - fn name(&self) -> String { - "substring".to_string() - } - - fn num_static_arguments(&self) -> usize { - 2 - } - - fn static_argument_type(&self, i: usize) -> ForeignFunctionParameterType { - match i { - 0 => ForeignFunctionParameterType::BaseType(ValueType::String), - 1 => ForeignFunctionParameterType::BaseType(ValueType::USize), - _ => panic!("No argument {}", i), - } - } - - fn num_optional_arguments(&self) -> usize { - 1 - } - - fn optional_argument_type(&self, _: usize) -> ForeignFunctionParameterType { - ForeignFunctionParameterType::BaseType(ValueType::USize) - } - - fn return_type(&self) -> ForeignFunctionParameterType { - ForeignFunctionParameterType::BaseType(ValueType::String) - } - - fn execute(&self, args: Vec) -> Option { - if args.len() == 2 { - match (&args[0], &args[1]) { - (Value::String(s), Value::USize(i)) => Some(Value::String(s[*i..].to_string())), - _ => panic!("Invalid arguments"), - } - } else { - match (&args[0], &args[1], &args[2]) { - (Value::String(s), Value::USize(i), Value::USize(j)) => Some(Value::String(s[*i..*j].to_string())), - _ => panic!("Invalid arguments"), - } - } - } - } - - /// String concat - /// - /// ``` scl - /// extern fn $string_concat(s: String...) -> String - /// ``` - #[derive(Clone)] - pub struct StringConcat; - - impl ForeignFunction for StringConcat { - fn name(&self) -> String { - "string_concat".to_string() - } - - fn has_variable_arguments(&self) -> bool { - true - } - - fn variable_argument_type(&self) -> ForeignFunctionParameterType { - ForeignFunctionParameterType::BaseType(ValueType::String) - } - - fn return_type(&self) -> ForeignFunctionParameterType { - ForeignFunctionParameterType::BaseType(ValueType::String) - } - - fn execute(&self, args: Vec) -> Option { - let mut result = "".to_string(); - for arg in args { - match arg { - Value::String(s) => { - result += &s; - } - _ => panic!("Argument is not string"), - } - } - Some(Value::String(result)) - } - } - - /// String length - /// - /// ``` scl - /// extern fn $string_length(s: String) -> usize - /// ``` - #[derive(Clone)] - pub struct StringLength; - - impl ForeignFunction for StringLength { - fn name(&self) -> String { - "string_length".to_string() - } - - fn num_static_arguments(&self) -> usize { - 1 - } +/// Floating point function +pub trait UnaryFloatFunction: Clone { + fn name(&self) -> String; - fn static_argument_type(&self, i: usize) -> ForeignFunctionParameterType { - assert_eq!(i, 0); - ForeignFunctionParameterType::BaseType(ValueType::String) - } + fn execute_f32(&self, arg: f32) -> f32; - fn return_type(&self) -> ForeignFunctionParameterType { - ForeignFunctionParameterType::BaseType(ValueType::USize) - } + fn execute_f64(&self, arg: f64) -> f64; +} - fn execute(&self, args: Vec) -> Option { - match &args[0] { - Value::String(s) => Some(Value::USize(s.len())), - Value::Str(s) => Some(Value::USize(s.len())), - _ => None, - } - } +impl ForeignFunction for F { + fn name(&self) -> String { + UnaryFloatFunction::name(self) } - /// String char at - /// - /// ``` scl - /// extern fn $string_chat_at(s: String, i: usize) -> char - /// ``` - #[derive(Clone)] - pub struct StringCharAt; - - impl ForeignFunction for StringCharAt { - fn name(&self) -> String { - "string_char_at".to_string() - } - - fn num_static_arguments(&self) -> usize { - 2 - } - - fn static_argument_type(&self, i: usize) -> ForeignFunctionParameterType { - match i { - 0 => 
ForeignFunctionParameterType::BaseType(ValueType::String), - 1 => ForeignFunctionParameterType::BaseType(ValueType::USize), - _ => panic!("Invalid {}-th argument", i), - } - } - - fn return_type(&self) -> ForeignFunctionParameterType { - ForeignFunctionParameterType::BaseType(ValueType::Char) - } - - fn execute(&self, args: Vec) -> Option { - match (&args[0], &args[1]) { - (Value::String(s), Value::USize(i)) => s.chars().skip(*i).next().map(Value::Char), - (Value::Str(s), Value::USize(i)) => s.chars().skip(*i).next().map(Value::Char), - _ => None, - } - } + fn num_generic_types(&self) -> usize { + 1 } - /// Hash - /// - /// ``` scl - /// extern fn $hash(x: Any...) -> u64 - /// ``` - #[derive(Clone)] - pub struct Hash; - - impl ForeignFunction for Hash { - fn name(&self) -> String { - "hash".to_string() - } - - fn has_variable_arguments(&self) -> bool { - true - } - - fn variable_argument_type(&self) -> ForeignFunctionParameterType { - ForeignFunctionParameterType::TypeFamily(TypeFamily::Any) - } - - fn return_type(&self) -> ForeignFunctionParameterType { - ForeignFunctionParameterType::BaseType(ValueType::U64) - } - - fn execute(&self, args: Vec) -> Option { - use std::collections::hash_map::DefaultHasher; - use std::hash::{Hash, Hasher}; - let mut s = DefaultHasher::new(); - args.hash(&mut s); - Some(s.finish().into()) - } + fn generic_type_family(&self, i: usize) -> TypeFamily { + assert_eq!(i, 0); + TypeFamily::Float } - /// Max - /// - /// ``` scl - /// extern fn $max(x: T...) -> T - /// ``` - #[derive(Clone)] - pub struct Max; - - impl Max { - fn dyn_max(args: Vec) -> Option where Value: TryInto { - let mut iter = args.into_iter(); - let mut curr_max: T = iter.next()?.try_into().ok()?; - while let Some(next_elem) = iter.next() { - let next_elem = next_elem.try_into().ok()?; - if next_elem > curr_max { - curr_max = next_elem; - } - } - Some(curr_max) - } + fn num_static_arguments(&self) -> usize { + 1 } - impl ForeignFunction for Max { - fn name(&self) -> String { - "max".to_string() - } - - fn num_generic_types(&self) -> usize { - 1 - } - - fn generic_type_family(&self, i: usize) -> TypeFamily { - assert_eq!(i, 0); - TypeFamily::Number - } - - fn has_variable_arguments(&self) -> bool { - true - } - - fn variable_argument_type(&self) -> ForeignFunctionParameterType { - ForeignFunctionParameterType::Generic(0) - } - - fn return_type(&self) -> ForeignFunctionParameterType { - ForeignFunctionParameterType::Generic(0) - } - - fn execute(&self, args: Vec) -> Option { - let rt = self.infer_return_type(&args); - match rt { - ValueType::I8 => Self::dyn_max(args).map(Value::I8), - ValueType::I16 => Self::dyn_max(args).map(Value::I16), - ValueType::I32 => Self::dyn_max(args).map(Value::I32), - ValueType::I64 => Self::dyn_max(args).map(Value::I64), - ValueType::I128 => Self::dyn_max(args).map(Value::I128), - ValueType::ISize => Self::dyn_max(args).map(Value::ISize), - ValueType::U8 => Self::dyn_max(args).map(Value::U8), - ValueType::U16 => Self::dyn_max(args).map(Value::U16), - ValueType::U32 => Self::dyn_max(args).map(Value::U32), - ValueType::U64 => Self::dyn_max(args).map(Value::U64), - ValueType::U128 => Self::dyn_max(args).map(Value::U128), - ValueType::USize => Self::dyn_max(args).map(Value::USize), - ValueType::F32 => Self::dyn_max(args).map(Value::F32), - ValueType::F64 => Self::dyn_max(args).map(Value::F64), - _ => None, - } - } + fn static_argument_type(&self, i: usize) -> ForeignFunctionParameterType { + assert_eq!(i, 0); + ForeignFunctionParameterType::Generic(0) } - /// Min - /// - 
/// ``` scl - /// extern fn $min(x: T...) -> T - /// ``` - #[derive(Clone)] - pub struct Min; - - impl Min { - fn dyn_min(args: Vec) -> Option where Value: TryInto { - let mut iter = args.into_iter(); - let mut curr_min: T = iter.next()?.try_into().ok()?; - while let Some(next_elem) = iter.next() { - let next_elem = next_elem.try_into().ok()?; - if next_elem < curr_min { - curr_min = next_elem; - } - } - Some(curr_min) - } + fn return_type(&self) -> ForeignFunctionParameterType { + ForeignFunctionParameterType::Generic(0) } - impl ForeignFunction for Min { - fn name(&self) -> String { - "min".to_string() - } - - fn num_generic_types(&self) -> usize { - 1 - } - - fn generic_type_family(&self, i: usize) -> TypeFamily { - assert_eq!(i, 0); - TypeFamily::Number - } - - fn has_variable_arguments(&self) -> bool { - true - } - - fn variable_argument_type(&self) -> ForeignFunctionParameterType { - ForeignFunctionParameterType::Generic(0) - } - - fn return_type(&self) -> ForeignFunctionParameterType { - ForeignFunctionParameterType::Generic(0) - } - - fn execute(&self, args: Vec) -> Option { - let rt = self.infer_return_type(&args); - match rt { - ValueType::I8 => Self::dyn_min(args).map(Value::I8), - ValueType::I16 => Self::dyn_min(args).map(Value::I16), - ValueType::I32 => Self::dyn_min(args).map(Value::I32), - ValueType::I64 => Self::dyn_min(args).map(Value::I64), - ValueType::I128 => Self::dyn_min(args).map(Value::I128), - ValueType::ISize => Self::dyn_min(args).map(Value::ISize), - ValueType::U8 => Self::dyn_min(args).map(Value::U8), - ValueType::U16 => Self::dyn_min(args).map(Value::U16), - ValueType::U32 => Self::dyn_min(args).map(Value::U32), - ValueType::U64 => Self::dyn_min(args).map(Value::U64), - ValueType::U128 => Self::dyn_min(args).map(Value::U128), - ValueType::USize => Self::dyn_min(args).map(Value::USize), - ValueType::F32 => Self::dyn_min(args).map(Value::F32), - ValueType::F64 => Self::dyn_min(args).map(Value::F64), - _ => None, - } + fn execute(&self, args: Vec) -> Option { + match args[0] { + Value::F32(f) => Some(Value::F32(self.execute_f32(f))), + Value::F64(f) => Some(Value::F64(self.execute_f64(f))), + _ => panic!("Expect floating point input"), } } } diff --git a/core/src/common/foreign_functions/abs.rs b/core/src/common/foreign_functions/abs.rs new file mode 100644 index 0000000..1846766 --- /dev/null +++ b/core/src/common/foreign_functions/abs.rs @@ -0,0 +1,62 @@ +use super::*; + +/// Absolute value foreign function +/// +/// ``` scl +/// extern fn $abs(x: T) -> T +/// ``` +#[derive(Clone)] +pub struct Abs; + +impl ForeignFunction for Abs { + fn name(&self) -> String { + "abs".to_string() + } + + fn num_generic_types(&self) -> usize { + 1 + } + + fn generic_type_family(&self, i: usize) -> TypeFamily { + assert_eq!(i, 0); + TypeFamily::Number + } + + fn num_static_arguments(&self) -> usize { + 1 + } + + fn static_argument_type(&self, i: usize) -> ForeignFunctionParameterType { + assert_eq!(i, 0); + ForeignFunctionParameterType::Generic(0) + } + + fn return_type(&self) -> ForeignFunctionParameterType { + ForeignFunctionParameterType::Generic(0) + } + + fn execute(&self, args: Vec) -> Option { + match args[0] { + // Signed integers, take absolute + Value::I8(f) => Some(Value::I8(f.abs())), + Value::I16(f) => Some(Value::I16(f.abs())), + Value::I32(f) => Some(Value::I32(f.abs())), + Value::I64(f) => Some(Value::I64(f.abs())), + Value::I128(f) => Some(Value::I128(f.abs())), + Value::ISize(f) => Some(Value::ISize(f.abs())), + + // Unsigned integers, directly return + 
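Since the `UnaryFloatFunction` helper trait kept above has a blanket `ForeignFunction` implementation, a new unary float function only needs the three trait methods. The `$exp` function below is a hypothetical sketch (not part of this patch), assuming it lives next to the other foreign function files so that `use super::*;` resolves.

``` rust
use super::*;

/// Hypothetical exponential foreign function: `extern fn $exp(x: T) -> T`
#[derive(Clone)]
pub struct Exp;

impl UnaryFloatFunction for Exp {
  fn name(&self) -> String {
    "exp".to_string()
  }

  fn execute_f32(&self, arg: f32) -> f32 {
    arg.exp()
  }

  fn execute_f64(&self, arg: f64) -> f64 {
    arg.exp()
  }
}
```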
Value::U8(f) => Some(Value::U8(f)), + Value::U16(f) => Some(Value::U16(f)), + Value::U32(f) => Some(Value::U32(f)), + Value::U64(f) => Some(Value::U64(f)), + Value::U128(f) => Some(Value::U128(f)), + Value::USize(f) => Some(Value::USize(f)), + + // Floating points, take absolute + Value::F32(f) => Some(Value::F32(f.abs())), + Value::F64(f) => Some(Value::F64(f.abs())), + _ => panic!("should not happen; input variable to abs should be a number"), + } + } +} diff --git a/core/src/common/foreign_functions/cos.rs b/core/src/common/foreign_functions/cos.rs new file mode 100644 index 0000000..bb91b77 --- /dev/null +++ b/core/src/common/foreign_functions/cos.rs @@ -0,0 +1,23 @@ +use super::*; + +/// Cos value foreign function +/// +/// ``` scl +/// extern fn $cos(x: T) -> T +/// ``` +#[derive(Clone)] +pub struct Cos; + +impl UnaryFloatFunction for Cos { + fn name(&self) -> String { + "cos".to_string() + } + + fn execute_f32(&self, arg: f32) -> f32 { + arg.cos() + } + + fn execute_f64(&self, arg: f64) -> f64 { + arg.cos() + } +} diff --git a/core/src/common/foreign_functions/datetime_day.rs b/core/src/common/foreign_functions/datetime_day.rs new file mode 100644 index 0000000..2c5de0d --- /dev/null +++ b/core/src/common/foreign_functions/datetime_day.rs @@ -0,0 +1,37 @@ +use chrono::Datelike; + +use super::*; + +/// Get the day of the month starting from 1 +/// +/// ``` scl +/// extern fn $datetime_day(d: DateTime) -> u32 +/// ``` +#[derive(Clone)] +pub struct DateTimeDay; + +impl ForeignFunction for DateTimeDay { + fn name(&self) -> String { + "datetime_day".to_string() + } + + fn num_static_arguments(&self) -> usize { + 1 + } + + fn static_argument_type(&self, i: usize) -> ForeignFunctionParameterType { + assert_eq!(i, 0); + ForeignFunctionParameterType::BaseType(ValueType::DateTime) + } + + fn return_type(&self) -> ForeignFunctionParameterType { + ForeignFunctionParameterType::BaseType(ValueType::U32) + } + + fn execute(&self, args: Vec) -> Option { + match &args[0] { + Value::DateTime(d) => Some(Value::U32(d.day())), + _ => None, + } + } +} diff --git a/core/src/common/foreign_functions/datetime_month.rs b/core/src/common/foreign_functions/datetime_month.rs new file mode 100644 index 0000000..8c3e0d2 --- /dev/null +++ b/core/src/common/foreign_functions/datetime_month.rs @@ -0,0 +1,37 @@ +use chrono::Datelike; + +use super::*; + +/// Get the month of the year starting from 1 +/// +/// ``` scl +/// extern fn $datetime_month(d: DateTime) -> u32 +/// ``` +#[derive(Clone)] +pub struct DateTimeMonth; + +impl ForeignFunction for DateTimeMonth { + fn name(&self) -> String { + "datetime_month".to_string() + } + + fn num_static_arguments(&self) -> usize { + 1 + } + + fn static_argument_type(&self, i: usize) -> ForeignFunctionParameterType { + assert_eq!(i, 0); + ForeignFunctionParameterType::BaseType(ValueType::DateTime) + } + + fn return_type(&self) -> ForeignFunctionParameterType { + ForeignFunctionParameterType::BaseType(ValueType::U32) + } + + fn execute(&self, args: Vec) -> Option { + match &args[0] { + Value::DateTime(d) => Some(Value::U32(d.month())), + _ => None, + } + } +} diff --git a/core/src/common/foreign_functions/datetime_month0.rs b/core/src/common/foreign_functions/datetime_month0.rs new file mode 100644 index 0000000..249b638 --- /dev/null +++ b/core/src/common/foreign_functions/datetime_month0.rs @@ -0,0 +1,37 @@ +use chrono::Datelike; + +use super::*; + +/// Get the month of the year starting from 0 +/// +/// ``` scl +/// extern fn $datetime_month(d: DateTime) -> u32 +/// ``` 
+#[derive(Clone)] +pub struct DateTimeMonth0; + +impl ForeignFunction for DateTimeMonth0 { + fn name(&self) -> String { + "datetime_month0".to_string() + } + + fn num_static_arguments(&self) -> usize { + 1 + } + + fn static_argument_type(&self, i: usize) -> ForeignFunctionParameterType { + assert_eq!(i, 0); + ForeignFunctionParameterType::BaseType(ValueType::DateTime) + } + + fn return_type(&self) -> ForeignFunctionParameterType { + ForeignFunctionParameterType::BaseType(ValueType::U32) + } + + fn execute(&self, args: Vec) -> Option { + match &args[0] { + Value::DateTime(d) => Some(Value::U32(d.month0())), + _ => None, + } + } +} diff --git a/core/src/common/foreign_functions/datetime_year.rs b/core/src/common/foreign_functions/datetime_year.rs new file mode 100644 index 0000000..e43474c --- /dev/null +++ b/core/src/common/foreign_functions/datetime_year.rs @@ -0,0 +1,37 @@ +use chrono::Datelike; + +use super::*; + +/// Get the year (signed integer) in the calendar date +/// +/// ``` scl +/// extern fn $datetime_year(d: DateTime) -> i32 +/// ``` +#[derive(Clone)] +pub struct DateTimeYear; + +impl ForeignFunction for DateTimeYear { + fn name(&self) -> String { + "datetime_year".to_string() + } + + fn num_static_arguments(&self) -> usize { + 1 + } + + fn static_argument_type(&self, i: usize) -> ForeignFunctionParameterType { + assert_eq!(i, 0); + ForeignFunctionParameterType::BaseType(ValueType::DateTime) + } + + fn return_type(&self) -> ForeignFunctionParameterType { + ForeignFunctionParameterType::BaseType(ValueType::I32) + } + + fn execute(&self, args: Vec) -> Option { + match &args[0] { + Value::DateTime(d) => Some(Value::I32(d.year())), + _ => None, + } + } +} diff --git a/core/src/common/foreign_functions/hash.rs b/core/src/common/foreign_functions/hash.rs new file mode 100644 index 0000000..d6c1455 --- /dev/null +++ b/core/src/common/foreign_functions/hash.rs @@ -0,0 +1,35 @@ +use super::*; + +/// Hash +/// +/// ``` scl +/// extern fn $hash(x: Any...) -> u64 +/// ``` +#[derive(Clone)] +pub struct Hash; + +impl ForeignFunction for Hash { + fn name(&self) -> String { + "hash".to_string() + } + + fn has_variable_arguments(&self) -> bool { + true + } + + fn variable_argument_type(&self) -> ForeignFunctionParameterType { + ForeignFunctionParameterType::TypeFamily(TypeFamily::Any) + } + + fn return_type(&self) -> ForeignFunctionParameterType { + ForeignFunctionParameterType::BaseType(ValueType::U64) + } + + fn execute(&self, args: Vec) -> Option { + use std::collections::hash_map::DefaultHasher; + use std::hash::{Hash, Hasher}; + let mut s = DefaultHasher::new(); + args.hash(&mut s); + Some(s.finish().into()) + } +} diff --git a/core/src/common/foreign_functions/max.rs b/core/src/common/foreign_functions/max.rs new file mode 100644 index 0000000..e39b911 --- /dev/null +++ b/core/src/common/foreign_functions/max.rs @@ -0,0 +1,71 @@ +use super::*; + +/// Max +/// +/// ``` scl +/// extern fn $max(x: T...) 
-> T +/// ``` +#[derive(Clone)] +pub struct Max; + +impl Max { + fn dyn_max(args: Vec) -> Option where Value: TryInto { + let mut iter = args.into_iter(); + let mut curr_max: T = iter.next()?.try_into().ok()?; + while let Some(next_elem) = iter.next() { + let next_elem = next_elem.try_into().ok()?; + if next_elem > curr_max { + curr_max = next_elem; + } + } + Some(curr_max) + } +} + +impl ForeignFunction for Max { + fn name(&self) -> String { + "max".to_string() + } + + fn num_generic_types(&self) -> usize { + 1 + } + + fn generic_type_family(&self, i: usize) -> TypeFamily { + assert_eq!(i, 0); + TypeFamily::Number + } + + fn has_variable_arguments(&self) -> bool { + true + } + + fn variable_argument_type(&self) -> ForeignFunctionParameterType { + ForeignFunctionParameterType::Generic(0) + } + + fn return_type(&self) -> ForeignFunctionParameterType { + ForeignFunctionParameterType::Generic(0) + } + + fn execute(&self, args: Vec) -> Option { + let rt = self.infer_return_type(&args); + match rt { + ValueType::I8 => Self::dyn_max(args).map(Value::I8), + ValueType::I16 => Self::dyn_max(args).map(Value::I16), + ValueType::I32 => Self::dyn_max(args).map(Value::I32), + ValueType::I64 => Self::dyn_max(args).map(Value::I64), + ValueType::I128 => Self::dyn_max(args).map(Value::I128), + ValueType::ISize => Self::dyn_max(args).map(Value::ISize), + ValueType::U8 => Self::dyn_max(args).map(Value::U8), + ValueType::U16 => Self::dyn_max(args).map(Value::U16), + ValueType::U32 => Self::dyn_max(args).map(Value::U32), + ValueType::U64 => Self::dyn_max(args).map(Value::U64), + ValueType::U128 => Self::dyn_max(args).map(Value::U128), + ValueType::USize => Self::dyn_max(args).map(Value::USize), + ValueType::F32 => Self::dyn_max(args).map(Value::F32), + ValueType::F64 => Self::dyn_max(args).map(Value::F64), + _ => None, + } + } +} diff --git a/core/src/common/foreign_functions/min.rs b/core/src/common/foreign_functions/min.rs new file mode 100644 index 0000000..b3684e9 --- /dev/null +++ b/core/src/common/foreign_functions/min.rs @@ -0,0 +1,71 @@ +use super::*; + +/// Min +/// +/// ``` scl +/// extern fn $min(x: T...) 
-> T +/// ``` +#[derive(Clone)] +pub struct Min; + +impl Min { + fn dyn_min(args: Vec) -> Option where Value: TryInto { + let mut iter = args.into_iter(); + let mut curr_min: T = iter.next()?.try_into().ok()?; + while let Some(next_elem) = iter.next() { + let next_elem = next_elem.try_into().ok()?; + if next_elem < curr_min { + curr_min = next_elem; + } + } + Some(curr_min) + } +} + +impl ForeignFunction for Min { + fn name(&self) -> String { + "min".to_string() + } + + fn num_generic_types(&self) -> usize { + 1 + } + + fn generic_type_family(&self, i: usize) -> TypeFamily { + assert_eq!(i, 0); + TypeFamily::Number + } + + fn has_variable_arguments(&self) -> bool { + true + } + + fn variable_argument_type(&self) -> ForeignFunctionParameterType { + ForeignFunctionParameterType::Generic(0) + } + + fn return_type(&self) -> ForeignFunctionParameterType { + ForeignFunctionParameterType::Generic(0) + } + + fn execute(&self, args: Vec) -> Option { + let rt = self.infer_return_type(&args); + match rt { + ValueType::I8 => Self::dyn_min(args).map(Value::I8), + ValueType::I16 => Self::dyn_min(args).map(Value::I16), + ValueType::I32 => Self::dyn_min(args).map(Value::I32), + ValueType::I64 => Self::dyn_min(args).map(Value::I64), + ValueType::I128 => Self::dyn_min(args).map(Value::I128), + ValueType::ISize => Self::dyn_min(args).map(Value::ISize), + ValueType::U8 => Self::dyn_min(args).map(Value::U8), + ValueType::U16 => Self::dyn_min(args).map(Value::U16), + ValueType::U32 => Self::dyn_min(args).map(Value::U32), + ValueType::U64 => Self::dyn_min(args).map(Value::U64), + ValueType::U128 => Self::dyn_min(args).map(Value::U128), + ValueType::USize => Self::dyn_min(args).map(Value::USize), + ValueType::F32 => Self::dyn_min(args).map(Value::F32), + ValueType::F64 => Self::dyn_min(args).map(Value::F64), + _ => None, + } + } +} diff --git a/core/src/common/foreign_functions/mod.rs b/core/src/common/foreign_functions/mod.rs new file mode 100644 index 0000000..f897202 --- /dev/null +++ b/core/src/common/foreign_functions/mod.rs @@ -0,0 +1,40 @@ +//! 
A library of foreign functions + +use super::value::*; +use super::value_type::*; +use super::type_family::*; +use super::foreign_function::*; + +use std::convert::*; + +mod abs; +mod cos; +mod datetime_day; +mod datetime_month; +mod datetime_month0; +mod datetime_year; +mod hash; +mod max; +mod min; +mod sin; +mod string_char_at; +mod string_concat; +mod string_length; +mod substring; +mod tan; + +pub use abs::*; +pub use cos::*; +pub use datetime_day::*; +pub use datetime_month::*; +pub use datetime_month0::*; +pub use datetime_year::*; +pub use hash::*; +pub use max::*; +pub use min::*; +pub use sin::*; +pub use string_char_at::*; +pub use string_concat::*; +pub use string_length::*; +pub use substring::*; +pub use tan::*; diff --git a/core/src/common/foreign_functions/sin.rs b/core/src/common/foreign_functions/sin.rs new file mode 100644 index 0000000..c68d829 --- /dev/null +++ b/core/src/common/foreign_functions/sin.rs @@ -0,0 +1,23 @@ +use super::*; + +/// Sin value foreign function +/// +/// ``` scl +/// extern fn $sin(x: T) -> T +/// ``` +#[derive(Clone)] +pub struct Sin; + +impl UnaryFloatFunction for Sin { + fn name(&self) -> String { + "sin".to_string() + } + + fn execute_f32(&self, arg: f32) -> f32 { + arg.sin() + } + + fn execute_f64(&self, arg: f64) -> f64 { + arg.sin() + } +} diff --git a/core/src/common/foreign_functions/string_char_at.rs b/core/src/common/foreign_functions/string_char_at.rs new file mode 100644 index 0000000..5b5917e --- /dev/null +++ b/core/src/common/foreign_functions/string_char_at.rs @@ -0,0 +1,39 @@ +use super::*; + +/// String char at +/// +/// ``` scl +/// extern fn $string_chat_at(s: String, i: usize) -> char +/// ``` +#[derive(Clone)] +pub struct StringCharAt; + +impl ForeignFunction for StringCharAt { + fn name(&self) -> String { + "string_char_at".to_string() + } + + fn num_static_arguments(&self) -> usize { + 2 + } + + fn static_argument_type(&self, i: usize) -> ForeignFunctionParameterType { + match i { + 0 => ForeignFunctionParameterType::BaseType(ValueType::String), + 1 => ForeignFunctionParameterType::BaseType(ValueType::USize), + _ => panic!("Invalid {}-th argument", i), + } + } + + fn return_type(&self) -> ForeignFunctionParameterType { + ForeignFunctionParameterType::BaseType(ValueType::Char) + } + + fn execute(&self, args: Vec) -> Option { + match (&args[0], &args[1]) { + (Value::String(s), Value::USize(i)) => s.chars().skip(*i).next().map(Value::Char), + (Value::Str(s), Value::USize(i)) => s.chars().skip(*i).next().map(Value::Char), + _ => None, + } + } +} diff --git a/core/src/common/foreign_functions/string_concat.rs b/core/src/common/foreign_functions/string_concat.rs new file mode 100644 index 0000000..2ffbd23 --- /dev/null +++ b/core/src/common/foreign_functions/string_concat.rs @@ -0,0 +1,40 @@ +use super::*; + +/// String concat +/// +/// ``` scl +/// extern fn $string_concat(s: String...) 
-> String +/// ``` +#[derive(Clone)] +pub struct StringConcat; + +impl ForeignFunction for StringConcat { + fn name(&self) -> String { + "string_concat".to_string() + } + + fn has_variable_arguments(&self) -> bool { + true + } + + fn variable_argument_type(&self) -> ForeignFunctionParameterType { + ForeignFunctionParameterType::BaseType(ValueType::String) + } + + fn return_type(&self) -> ForeignFunctionParameterType { + ForeignFunctionParameterType::BaseType(ValueType::String) + } + + fn execute(&self, args: Vec) -> Option { + let mut result = "".to_string(); + for arg in args { + match arg { + Value::String(s) => { + result += &s; + } + _ => panic!("Argument is not string"), + } + } + Some(Value::String(result)) + } +} diff --git a/core/src/common/foreign_functions/string_length.rs b/core/src/common/foreign_functions/string_length.rs new file mode 100644 index 0000000..dc86f36 --- /dev/null +++ b/core/src/common/foreign_functions/string_length.rs @@ -0,0 +1,36 @@ +use super::*; + +/// String length +/// +/// ``` scl +/// extern fn $string_length(s: String) -> usize +/// ``` +#[derive(Clone)] +pub struct StringLength; + +impl ForeignFunction for StringLength { + fn name(&self) -> String { + "string_length".to_string() + } + + fn num_static_arguments(&self) -> usize { + 1 + } + + fn static_argument_type(&self, i: usize) -> ForeignFunctionParameterType { + assert_eq!(i, 0); + ForeignFunctionParameterType::BaseType(ValueType::String) + } + + fn return_type(&self) -> ForeignFunctionParameterType { + ForeignFunctionParameterType::BaseType(ValueType::USize) + } + + fn execute(&self, args: Vec) -> Option { + match &args[0] { + Value::String(s) => Some(Value::USize(s.len())), + Value::Str(s) => Some(Value::USize(s.len())), + _ => None, + } + } +} diff --git a/core/src/common/foreign_functions/substring.rs b/core/src/common/foreign_functions/substring.rs new file mode 100644 index 0000000..1b5062a --- /dev/null +++ b/core/src/common/foreign_functions/substring.rs @@ -0,0 +1,53 @@ +use super::*; + +/// Substring +/// +/// ``` scl +/// extern fn $substring(s: String, begin: usize, end: usize?) 
-> String +/// ``` +#[derive(Clone)] +pub struct Substring; + +impl ForeignFunction for Substring { + fn name(&self) -> String { + "substring".to_string() + } + + fn num_static_arguments(&self) -> usize { + 2 + } + + fn static_argument_type(&self, i: usize) -> ForeignFunctionParameterType { + match i { + 0 => ForeignFunctionParameterType::BaseType(ValueType::String), + 1 => ForeignFunctionParameterType::BaseType(ValueType::USize), + _ => panic!("No argument {}", i), + } + } + + fn num_optional_arguments(&self) -> usize { + 1 + } + + fn optional_argument_type(&self, _: usize) -> ForeignFunctionParameterType { + ForeignFunctionParameterType::BaseType(ValueType::USize) + } + + fn return_type(&self) -> ForeignFunctionParameterType { + ForeignFunctionParameterType::BaseType(ValueType::String) + } + + fn execute(&self, args: Vec) -> Option { + if args.len() == 2 { + match (&args[0], &args[1]) { + (Value::String(s), Value::USize(i)) => Some(Value::String(s[*i..].to_string())), + _ => panic!("Invalid arguments"), + } + } else { + match (&args[0], &args[1], &args[2]) { + (Value::String(s), Value::USize(i), Value::USize(j)) => Some(Value::String(s[*i..*j].to_string())), + _ => panic!("Invalid arguments"), + } + } + } +} diff --git a/core/src/common/foreign_functions/tan.rs b/core/src/common/foreign_functions/tan.rs new file mode 100644 index 0000000..58fa97f --- /dev/null +++ b/core/src/common/foreign_functions/tan.rs @@ -0,0 +1,23 @@ +use super::*; + +/// Tan value foreign function +/// +/// ``` scl +/// extern fn $tan(x: T) -> T +/// ``` +#[derive(Clone)] +pub struct Tan; + +impl UnaryFloatFunction for Tan { + fn name(&self) -> String { + "tan".to_string() + } + + fn execute_f32(&self, arg: f32) -> f32 { + arg.tan() + } + + fn execute_f64(&self, arg: f64) -> f64 { + arg.tan() + } +} diff --git a/core/src/common/foreign_predicate.rs b/core/src/common/foreign_predicate.rs index 8b13789..3101e12 100644 --- a/core/src/common/foreign_predicate.rs +++ b/core/src/common/foreign_predicate.rs @@ -1 +1,314 @@ +use std::collections::*; +use dyn_clone::*; + +use super::value::*; +use super::value_type::*; +use super::input_tag::*; +use super::foreign_predicates as fps; + +#[derive(Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)] +pub enum Binding { + Free, + Bound, +} + +impl std::fmt::Display for Binding { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Free => f.write_str("f"), + Self::Bound => f.write_str("b"), + } + } +} + +impl Binding { + pub fn is_bound(&self) -> bool { + match self { + Self::Bound => true, + _ => false, + } + } + + pub fn is_free(&self) -> bool { + match self { + Self::Free => true, + _ => false, + } + } +} + +/// The identifier of a foreign predicate in a registry +#[derive(Clone, Debug, Hash, PartialEq, Eq)] +pub struct ForeignPredicateIdentifier { + identifier: String, + types: Box<[ValueType]>, + binding_pattern: BindingPattern, +} + +impl std::fmt::Display for ForeignPredicateIdentifier { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_fmt(format_args!( + "pred {}[{}]({})", + self.identifier, + self.binding_pattern, + self.types.iter().map(|t| format!("{}", t)).collect::>().join(", ") + )) + } +} + +/// A binding pattern for a predicate, e.g. 
bbf +#[derive(Clone, Debug, Hash, PartialEq, Eq)] +pub struct BindingPattern { + pattern: Box<[Binding]>, +} + +impl std::fmt::Display for BindingPattern { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + for binding in &*self.pattern { + binding.fmt(f)?; + } + Ok(()) + } +} + +impl BindingPattern { + /// Create a new binding pattern given the predicate arity and number of bounded variables + pub fn new(arity: usize, num_bounded: usize) -> Self { + assert!(num_bounded <= arity); + Self { + pattern: (0..arity).map(|i| { + if i < num_bounded { Binding::Bound } + else { Binding::Free } + }).collect() + } + } + + /// Check if all argument needs to be bounded + pub fn is_bounded(&self) -> bool { + self.pattern.iter().all(|p| p.is_bound()) + } + + /// Check if all arguments are free + pub fn is_free(&self) -> bool { + self.pattern.iter().all(|p| p.is_free()) + } + + pub fn iter(&self) -> std::slice::Iter { + self.pattern.iter() + } +} + +impl std::ops::Index for BindingPattern { + type Output = Binding; + + fn index(&self, index: usize) -> &Self::Output { + &self.pattern[index] + } +} + +/// The foreign predicate for a runtime implementation +/// +/// The arguments of a foreign predicate can be marked as "need to be bounded" +/// or free. +/// We assume the bounded arguments are always placed before the free arguments. +/// During runtime, we expect the foreign predicate to take in all bounded +/// variables as input, and produce the free variables, along with a tag associated +/// with the tuple. +pub trait ForeignPredicate: DynClone { + /// The name of the predicate + fn name(&self) -> String; + + /// The arity of the predicate (i.e. number of arguments) + fn arity(&self) -> usize; + + /// The type of the `i`-th argument + /// + /// Should panic if `i` is larger than or equal to the arity + fn argument_type(&self, i: usize) -> ValueType; + + /// The number of bounded arguments + fn num_bounded(&self) -> usize; + + /// The number of free arguments + fn num_free(&self) -> usize { + self.arity() - self.num_bounded() + } + + /// Get a vector of the argument types + fn argument_types(&self) -> Vec { + (0..self.arity()).map(|i| self.argument_type(i)).collect() + } + + /// Get an identifier for this predicate + fn binding_pattern(&self) -> BindingPattern { + BindingPattern::new(self.arity(), self.num_bounded()) + } + + /// Evaluate the foreign predicate given a tuple containing bounded variables + /// + /// The `bounded` tuple (`Vec`) should have arity (length) `self.num_bounded()`. 
+ /// The function returns a sequence of (dynamically) tagged-tuples where the arity is `self.num_free()` + fn evaluate(&self, bounded: &[Value]) -> Vec<(DynamicInputTag, Vec)>; +} + +/// The dynamic foreign predicate +pub struct DynamicForeignPredicate { + fp: Box, +} + +impl DynamicForeignPredicate { + pub fn new(fp: P) -> Self { + Self { fp: Box::new(fp) } + } +} + +impl Clone for DynamicForeignPredicate { + fn clone(&self) -> Self { + Self { + fp: dyn_clone::clone_box(&*self.fp), + } + } +} + +impl ForeignPredicate for DynamicForeignPredicate { + fn name(&self) -> String { + self.fp.name() + } + + fn arity(&self) -> usize { + self.fp.arity() + } + + fn argument_type(&self, i: usize) -> ValueType { + self.fp.argument_type(i) + } + + fn num_bounded(&self) -> usize { + self.fp.num_bounded() + } + + fn evaluate(&self, bounded: &[Value]) -> Vec<(DynamicInputTag, Vec)> { + self.fp.evaluate(bounded) + } +} + +impl std::fmt::Debug for DynamicForeignPredicate { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f + .debug_struct("ForeignPredicate") + .field("name", &self.name()) + .field("types", &(0..self.arity()).map(|i| self.argument_type(i)).collect::>()) + .field("num_bounded", &self.num_bounded()) + .finish() + } +} + +impl std::fmt::Display for DynamicForeignPredicate { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_fmt(format_args!("pred {}(", self.name()))?; + for i in 0..self.arity() { + if i > 0 { + f.write_str(", ")?; + } + self.argument_type(i).fmt(f)?; + if i < self.num_bounded() { + f.write_str(" [b]")?; + } else { + f.write_str(" [f]")?; + } + } + f.write_str(")") + } +} + +/// A foreign predicate registry +#[derive(Clone, Debug)] +pub struct ForeignPredicateRegistry { + registry: HashMap, +} + +impl ForeignPredicateRegistry { + /// Create an empty foreign predicate registry + pub fn new() -> Self { + Self { + registry: HashMap::new() + } + } + + /// Create a Standard Library foreign predicate registry + pub fn std() -> Self { + let mut reg = Self::new(); + + // Register all predicates + + // Range + for value_type in ValueType::integers() { + reg.register(fps::RangeBBF::new(value_type.clone())).unwrap(); + } + + // Soft comparison operators + for value_type in ValueType::floats() { + reg.register(fps::FloatEq::new(value_type.clone())).unwrap(); + reg.register(fps::SoftNumberEq::new(value_type.clone())).unwrap(); + reg.register(fps::SoftNumberNeq::new(value_type.clone())).unwrap(); + reg.register(fps::SoftNumberGt::new(value_type.clone())).unwrap(); + reg.register(fps::SoftNumberLt::new(value_type.clone())).unwrap(); + } + + // String operations + reg.register(fps::StringCharsBFF::new()).unwrap(); + + reg + } + + /// Register a new foreign predicate in the registry + pub fn register(&mut self, p: P) -> Result<(), ForeignPredicateError> { + let id = p.name(); + if self.contains(&id) { + Err(ForeignPredicateError::AlreadyExisted { id: format!("{}", id) }) + } else { + let p = DynamicForeignPredicate::new(p); + self.registry.insert(id, p); + Ok(()) + } + } + + /// Check if the registry contains a foreign predicate by using its identifier + pub fn contains(&self, id: &str) -> bool { + self.registry.contains_key(id) + } + + /// Get the foreign predicate + pub fn get(&self, id: &str) -> Option<&DynamicForeignPredicate> { + self.registry.get(id) + } + + pub fn iter<'a>(&'a self) -> hash_map::Iter<'a, String, DynamicForeignPredicate> { + self.into_iter() + } +} + +impl<'a> IntoIterator for &'a ForeignPredicateRegistry { + type 
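To illustrate the bounded/free argument convention described above, here is a hedged sketch of a minimal custom foreign predicate with binding pattern `bf`: given a bounded `x`, it produces the single free value `y = x + 1` as an untagged fact. The predicate name `succ_u32`, the `use super::*;` placement, and the registration snippet are illustrative assumptions, not part of this patch.

``` rust
use super::*;

/// Hypothetical successor predicate: `extern pred succ_u32(x: u32, y: u32)[bf]`
#[derive(Clone)]
pub struct SuccU32;

impl ForeignPredicate for SuccU32 {
  fn name(&self) -> String {
    "succ_u32".to_string()
  }

  fn arity(&self) -> usize {
    2
  }

  fn argument_type(&self, i: usize) -> ValueType {
    assert!(i < 2);
    ValueType::U32
  }

  fn num_bounded(&self) -> usize {
    1
  }

  fn evaluate(&self, bounded: &[Value]) -> Vec<(DynamicInputTag, Vec<Value>)> {
    match bounded {
      // One bounded argument comes in; one free argument goes out,
      // tagged with `DynamicInputTag::None` (i.e. an untagged fact)
      [Value::U32(x)] => vec![(DynamicInputTag::None, vec![Value::U32(x + 1)])],
      _ => vec![],
    }
  }
}
```

It can then be added to a registry in the same way as the standard predicates, e.g. `reg.register(SuccU32).unwrap();`.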
IntoIter = hash_map::Iter<'a, String, DynamicForeignPredicate>; + + type Item = (&'a String, &'a DynamicForeignPredicate); + + fn into_iter(self) -> Self::IntoIter { + self.registry.iter() + } +} + +/// THe errors happening when handling foreign predicates +#[derive(Clone, Debug)] +pub enum ForeignPredicateError { + AlreadyExisted { id: String }, +} + +impl std::fmt::Display for ForeignPredicateError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::AlreadyExisted { id } => write!(f, "Foreign predicate `{}` already existed", id), + } + } +} diff --git a/core/src/common/foreign_predicates/float_eq.rs b/core/src/common/foreign_predicates/float_eq.rs new file mode 100644 index 0000000..27dcc9f --- /dev/null +++ b/core/src/common/foreign_predicates/float_eq.rs @@ -0,0 +1,63 @@ +//! The floating point equality predicate + +use super::*; + +#[derive(Clone, Debug)] +pub struct FloatEq { + /// The floating point type + pub ty: ValueType, + + /// The type of the operands + pub threshold: f64, +} + +impl FloatEq { + pub fn new(ty: ValueType) -> Self { + assert!(ty.is_float()); + Self { + ty, + threshold: 0.001, + } + } + + pub fn new_with_threshold(ty: ValueType, threshold: f64) -> Self { + Self { + ty, + threshold, + } + } +} + +impl ForeignPredicate for FloatEq { + fn name(&self) -> String { + format!("float_eq_{}", self.ty) + } + + fn arity(&self) -> usize { + 2 + } + + fn argument_type(&self, i: usize) -> ValueType { + assert!(i < 2); + self.ty.clone() + } + + fn num_bounded(&self) -> usize { + 2 + } + + fn evaluate(&self, bounded: &[Value]) -> Vec<(DynamicInputTag, Vec)> { + assert_eq!(bounded.len(), 2); + let lhs = &bounded[0]; + let rhs = &bounded[1]; + match (&self.ty, lhs, rhs) { + (ValueType::F32, Value::F32(l), Value::F32(r)) if (l - r).abs() < (self.threshold as f32) => { + vec![(DynamicInputTag::None, vec![])] + }, + (ValueType::F64, Value::F64(l), Value::F64(r)) if (l - r).abs() < self.threshold => { + vec![(DynamicInputTag::None, vec![])] + }, + _ => vec![], + } + } +} diff --git a/core/src/common/foreign_predicates/mod.rs b/core/src/common/foreign_predicates/mod.rs new file mode 100644 index 0000000..616774e --- /dev/null +++ b/core/src/common/foreign_predicates/mod.rs @@ -0,0 +1,28 @@ +//! Library of foreign predicates + +use std::convert::*; + +use crate::utils::*; + +use super::input_tag::*; +use super::foreign_predicate::*; +use super::value::*; +use super::value_type::*; + +mod float_eq; +mod range; +mod string_chars; +mod soft_cmp; +mod soft_eq; +mod soft_gt; +mod soft_lt; +mod soft_neq; + +pub use float_eq::*; +pub use range::*; +pub use string_chars::*; +pub use soft_cmp::*; +pub use soft_eq::*; +pub use soft_gt::*; +pub use soft_lt::*; +pub use soft_neq::*; diff --git a/core/src/common/foreign_predicates/range.rs b/core/src/common/foreign_predicates/range.rs new file mode 100644 index 0000000..f890f6a --- /dev/null +++ b/core/src/common/foreign_predicates/range.rs @@ -0,0 +1,95 @@ +use super::*; + +/// Range foreign predicate +/// +/// ``` scl +/// extern pred range(begin: T, end: T, i: T)[bbf] +/// ``` +/// +/// The first two arguments `begin` and `end` are bounded to generate `i` as the free variable. +/// The generated number `i` will be sorted from `begin` (inclusive) to `end` (exclusive). 
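For reference, a brief usage sketch of the registry API above: it looks up one of the typed `range` predicates registered by `ForeignPredicateRegistry::std()` and evaluates it on its two bounded arguments. The key `"range_i32"` assumes `ValueType::I32` displays as `i32` (the name is built with `format!("range_{}", self.ty)`), and the `demo` function is illustrative only.

``` rust
// Illustrative only: querying and evaluating a registered foreign predicate.
fn demo(registry: &ForeignPredicateRegistry) {
  // `get` returns the type-erased `DynamicForeignPredicate`, which itself
  // implements `ForeignPredicate`.
  let range = registry.get("range_i32").expect("registered by `std()`");
  assert_eq!(range.arity(), 3);
  assert_eq!(range.num_bounded(), 2);

  // With begin = 1 and end = 4 bounded, the free argument enumerates 1, 2, 3;
  // each result pairs a tag with the values of the free arguments.
  let results = range.evaluate(&[Value::I32(1), Value::I32(4)]);
  assert_eq!(results.len(), 3);
  for (_tag, free_args) in &results {
    println!("free args: {:?}", free_args);
  }
}
```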
+#[derive(Clone)] +pub struct RangeBBF { + /// The type of the range operator + pub ty: ValueType, +} + +impl RangeBBF { + /// Create a new range (bbf) foreign predicate + pub fn new(ty: ValueType) -> Self { + Self { ty } + } + + /// Compute the numbers between + fn range(begin: &Value, end: &Value) -> impl Iterator where Value: TryInto { + pub struct StepIterator { + curr: T, + end: T, + } + + impl Iterator for StepIterator { + type Item = T; + + fn next(&mut self) -> Option { + if self.curr >= self.end { + None + } else { + let result = Some(self.curr); + self.curr = self.curr + T::one(); + result + } + } + } + + // Cast value into integer type + let begin: T = begin.clone().try_into().unwrap_or(T::zero()); + let end: T = end.clone().try_into().unwrap_or(T::zero()); + + // Finally generate the list + StepIterator { curr: begin, end } + } + + fn dyn_range>(begin: &Value, end: &Value) -> Vec<(DynamicInputTag, Vec)> where Value: TryInto { + Self::range::(begin, end).map(|i| (DynamicInputTag::None, vec![i.into()])).collect() + } +} + +impl ForeignPredicate for RangeBBF { + fn name(&self) -> String { + format!("range_{}", self.ty) + } + + fn arity(&self) -> usize { + 3 + } + + fn argument_type(&self, i: usize) -> ValueType { + assert!(i < 3); + self.ty.clone() + } + + fn num_bounded(&self) -> usize { + 2 + } + + fn evaluate(&self, bounded: &[Value]) -> Vec<(DynamicInputTag, Vec)> { + assert_eq!(bounded.len(), 2); + let begin = &bounded[0]; + let end = &bounded[1]; + match &self.ty { + ValueType::I8 => Self::dyn_range::(begin, end), + ValueType::I16 => Self::dyn_range::(begin, end), + ValueType::I32 => Self::dyn_range::(begin, end), + ValueType::I64 => Self::dyn_range::(begin, end), + ValueType::I128 => Self::dyn_range::(begin, end), + ValueType::ISize => Self::dyn_range::(begin, end), + ValueType::U8 => Self::dyn_range::(begin, end), + ValueType::U16 => Self::dyn_range::(begin, end), + ValueType::U32 => Self::dyn_range::(begin, end), + ValueType::U64 => Self::dyn_range::(begin, end), + ValueType::U128 => Self::dyn_range::(begin, end), + ValueType::USize => Self::dyn_range::(begin, end), + _ => vec![], + } + } +} diff --git a/core/src/common/foreign_predicates/soft_cmp.rs b/core/src/common/foreign_predicates/soft_cmp.rs new file mode 100644 index 0000000..3d6ab2d --- /dev/null +++ b/core/src/common/foreign_predicates/soft_cmp.rs @@ -0,0 +1,93 @@ +//! The sigmoid functions defined for soft comparison +//! +//! 
Reference: , Section 4.2 + +/// Sigmoid (s-shaped) function +#[derive(Clone, Debug)] +pub enum SigmoidFunction { + Logistic { beta: f64 }, + Reciprocal { beta: f64 }, + Cauchy { beta: f64 }, + OptimalMonotonic { beta: f64 }, +} + +impl Default for SigmoidFunction { + fn default() -> Self { + Self::Logistic { beta: 1.0 } + } +} + +impl SigmoidFunction { + /// Create a logistic sigmoid function with beta as scaling parameter + pub fn logistic(beta: f64) -> Self { + Self::Logistic { beta } + } + + /// Create a reciprocal sigmoid function with beta as scaling parameter + pub fn reciprocal(beta: f64) -> Self { + Self::Reciprocal { beta } + } + + /// Create a cauchy sigmoid function with beta as scaling parameter + pub fn cauchy(beta: f64) -> Self { + Self::Cauchy { beta } + } + + /// Create a optimal sigmoid function with beta as scaling parameter + pub fn optimal_monotonic(beta: f64) -> Self { + Self::OptimalMonotonic { beta } + } + + /// Evaluate the sigmoid function on input `x`, returns a value in [0, 1] + pub fn eval(&self, x: f64) -> f64 { + match self { + Self::Logistic { beta } => { + // \frac{1}{1 + e^{- \beta x}} + 1.0 / (1.0 + (-beta * x).exp()) + } + Self::Reciprocal { beta } => { + // \frac{\beta x}{1 + 2 \beta |x|} + \frac{1}{2} + (beta * x) / (1.0 + 2.0 * beta * x.abs()) + 0.5 + } + Self::Cauchy { beta } => { + // \frac{1}{\pi} \text{arctan}(\beta x) + \frac{1}{2} + std::f64::consts::FRAC_1_PI * (beta * x).atan() + 0.5 + } + Self::OptimalMonotonic { beta } => { + let xp = beta * x; + if xp < -0.25 { + -1.0 / (16.0 * xp) + } else if xp > 0.25 { + 1.0 - 1.0 / (16.0 * xp) + } else { + xp + 0.5 + } + } + } + } + + /// Evaluate the normalized derivative of the sigmoid function + /// + /// When x = 0, returns 1; + /// When x goes to +-infinity, returns 0 + pub fn eval_deriv(&self, x: f64) -> f64 { + match self { + Self::Logistic { beta } => { + // sech^2(\frac{\beta x}{2}) + let v = 0.5 * beta * x; + let epv = std::f64::consts::E.powf(v); + let epv2 = epv.powi(2); + 2.0 * epv / (epv2 + 1.0) + } + Self::Reciprocal { .. } => { + unimplemented!() + } + Self::Cauchy { .. } => { + unimplemented!() + } + Self::OptimalMonotonic { .. } => { + unimplemented!() + } + } + } +} diff --git a/core/src/common/foreign_predicates/soft_eq.rs b/core/src/common/foreign_predicates/soft_eq.rs new file mode 100644 index 0000000..e4eb1da --- /dev/null +++ b/core/src/common/foreign_predicates/soft_eq.rs @@ -0,0 +1,100 @@ +//! 
The soft equality predicate + +use super::*; + +/// Soft EQ foreign predicate +/// +/// ``` scl +/// extern pred soft_eq(lhs: T, rhs: T) +/// ``` +/// +/// It is going to output probability signal +#[derive(Clone, Debug)] +pub struct SoftNumberEq { + /// The type of the operands + pub ty: ValueType, + + /// The function chosen as the sigmoid + pub sigmoid: SigmoidFunction, +} + +impl SoftNumberEq { + pub fn new(ty: ValueType) -> Self { + Self { + ty, + sigmoid: SigmoidFunction::default(), + } + } + + pub fn new_with_sigmoid_fn(ty: ValueType, sigmoid: SigmoidFunction) -> Self { + Self { + ty, + sigmoid, + } + } + + fn soft_eq(&self, lhs: &Value, rhs: &Value) -> Option + where + T: std::ops::Sub + TryInto, + Value: TryInto, + { + let lhs: T = lhs.clone().try_into().ok()?; + let rhs: T = rhs.clone().try_into().ok()?; + let diff: f64 = (lhs - rhs).try_into().ok()?; + Some(self.sigmoid.eval_deriv(diff)) + } + + fn soft_eq_wrapper(&self, lhs: &Value, rhs: &Value) -> Vec<(DynamicInputTag, Vec)> + where + T: std::ops::Sub + TryInto, + Value: TryInto, + { + if let Some(prob) = self.soft_eq(lhs, rhs) { + vec![(DynamicInputTag::Float(prob), vec![])] + } else { + vec![] + } + } +} + +impl ForeignPredicate for SoftNumberEq { + fn name(&self) -> String { + format!("soft_eq_{}", self.ty) + } + + fn arity(&self) -> usize { + 2 + } + + fn argument_type(&self, i: usize) -> ValueType { + assert!(i < 2); + self.ty.clone() + } + + fn num_bounded(&self) -> usize { + 2 + } + + fn evaluate(&self, bounded: &[Value]) -> Vec<(DynamicInputTag, Vec)> { + assert_eq!(bounded.len(), 2); + let lhs = &bounded[0]; + let rhs = &bounded[1]; + match &self.ty { + ValueType::I8 => self.soft_eq_wrapper::(lhs, rhs), + ValueType::I16 => self.soft_eq_wrapper::(lhs, rhs), + ValueType::I32 => self.soft_eq_wrapper::(lhs, rhs), + // ValueType::I64 => self.soft_gt_wrapper::(lhs, rhs), + // ValueType::I128 => self.soft_gt_wrapper::(lhs, rhs), + // ValueType::ISize => self.soft_gt_wrapper::(lhs, rhs), + ValueType::U8 => self.soft_eq_wrapper::(lhs, rhs), + ValueType::U16 => self.soft_eq_wrapper::(lhs, rhs), + ValueType::U32 => self.soft_eq_wrapper::(lhs, rhs), + // ValueType::U64 => self.soft_gt_wrapper::(lhs, rhs), + // ValueType::U128 => self.soft_gt_wrapper::(lhs, rhs), + // ValueType::USize => self.soft_gt_wrapper::(lhs, rhs), + ValueType::F32 => self.soft_eq_wrapper::(lhs, rhs), + ValueType::F64 => self.soft_eq_wrapper::(lhs, rhs), + _ => vec![], + } + } +} diff --git a/core/src/common/foreign_predicates/soft_gt.rs b/core/src/common/foreign_predicates/soft_gt.rs new file mode 100644 index 0000000..9a72fcf --- /dev/null +++ b/core/src/common/foreign_predicates/soft_gt.rs @@ -0,0 +1,98 @@ +//! 
The soft greater-than predicate + +use super::*; + +/// Soft Greater Than foreign predicate +/// +/// ``` scl +/// extern pred `>`(lhs: T, rhs: T) +/// ``` +#[derive(Clone, Debug)] +pub struct SoftNumberGt { + /// The type of the operands + pub ty: ValueType, + + /// The function chosen as the sigmoid + pub sigmoid: SigmoidFunction, +} + +impl SoftNumberGt { + pub fn new(ty: ValueType) -> Self { + Self { + ty, + sigmoid: SigmoidFunction::default(), + } + } + + pub fn new_with_sigmoid_fn(ty: ValueType, sigmoid: SigmoidFunction) -> Self { + Self { + ty, + sigmoid, + } + } + + fn soft_gt(&self, lhs: &Value, rhs: &Value) -> Option + where + T: std::ops::Sub + TryInto, + Value: TryInto, + { + let lhs: T = lhs.clone().try_into().ok()?; + let rhs: T = rhs.clone().try_into().ok()?; + let diff: f64 = (lhs - rhs).try_into().ok()?; + Some(self.sigmoid.eval(diff)) + } + + fn soft_gt_wrapper(&self, lhs: &Value, rhs: &Value) -> Vec<(DynamicInputTag, Vec)> + where + T: std::ops::Sub + TryInto, + Value: TryInto, + { + if let Some(prob) = self.soft_gt(lhs, rhs) { + vec![(DynamicInputTag::Float(prob), vec![])] + } else { + vec![] + } + } +} + +impl ForeignPredicate for SoftNumberGt { + fn name(&self) -> String { + format!("soft_gt_{}", self.ty) + } + + fn arity(&self) -> usize { + 2 + } + + fn argument_type(&self, i: usize) -> ValueType { + assert!(i < 2); + self.ty.clone() + } + + fn num_bounded(&self) -> usize { + 2 + } + + fn evaluate(&self, bounded: &[Value]) -> Vec<(DynamicInputTag, Vec)> { + assert_eq!(bounded.len(), 2); + let lhs = &bounded[0]; + let rhs = &bounded[1]; + match &self.ty { + ValueType::I8 => self.soft_gt_wrapper::(lhs, rhs), + ValueType::I16 => self.soft_gt_wrapper::(lhs, rhs), + ValueType::I32 => self.soft_gt_wrapper::(lhs, rhs), + // ValueType::I64 => self.soft_gt_wrapper::(lhs, rhs), + // ValueType::I128 => self.soft_gt_wrapper::(lhs, rhs), + // ValueType::ISize => self.soft_gt_wrapper::(lhs, rhs), + ValueType::U8 => self.soft_gt_wrapper::(lhs, rhs), + ValueType::U16 => self.soft_gt_wrapper::(lhs, rhs), + ValueType::U32 => self.soft_gt_wrapper::(lhs, rhs), + // ValueType::U64 => self.soft_gt_wrapper::(lhs, rhs), + // ValueType::U128 => self.soft_gt_wrapper::(lhs, rhs), + // ValueType::USize => self.soft_gt_wrapper::(lhs, rhs), + ValueType::F32 => self.soft_gt_wrapper::(lhs, rhs), + ValueType::F64 => self.soft_gt_wrapper::(lhs, rhs), + _ => vec![], + } + } +} diff --git a/core/src/common/foreign_predicates/soft_lt.rs b/core/src/common/foreign_predicates/soft_lt.rs new file mode 100644 index 0000000..e2b9363 --- /dev/null +++ b/core/src/common/foreign_predicates/soft_lt.rs @@ -0,0 +1,98 @@ +//! 
The soft less-than predicate + +use super::*; + +/// Soft Less Than foreign predicate +/// +/// ``` scl +/// extern pred `<`(lhs: T, rhs: T) +/// ``` +#[derive(Clone, Debug)] +pub struct SoftNumberLt { + /// The type of the operands + pub ty: ValueType, + + /// The function chosen as the sigmoid + pub sigmoid: SigmoidFunction, +} + +impl SoftNumberLt { + pub fn new(ty: ValueType) -> Self { + Self { + ty, + sigmoid: SigmoidFunction::default() + } + } + + pub fn new_with_sigmoid_fn(ty: ValueType, sigmoid: SigmoidFunction) -> Self { + Self { + ty, + sigmoid, + } + } + + fn soft_lt(&self, lhs: &Value, rhs: &Value) -> Option + where + T: std::ops::Sub + TryInto, + Value: TryInto, + { + let lhs: T = lhs.clone().try_into().ok()?; + let rhs: T = rhs.clone().try_into().ok()?; + let diff: f64 = (rhs - lhs).try_into().ok()?; + Some(self.sigmoid.eval(diff)) + } + + fn soft_lt_wrapper(&self, lhs: &Value, rhs: &Value) -> Vec<(DynamicInputTag, Vec)> + where + T: std::ops::Sub + TryInto, + Value: TryInto, + { + if let Some(prob) = self.soft_lt(lhs, rhs) { + vec![(DynamicInputTag::Float(prob), vec![])] + } else { + vec![] + } + } +} + +impl ForeignPredicate for SoftNumberLt { + fn name(&self) -> String { + format!("soft_lt_{}", self.ty) + } + + fn arity(&self) -> usize { + 2 + } + + fn argument_type(&self, i: usize) -> ValueType { + assert!(i < 2); + self.ty.clone() + } + + fn num_bounded(&self) -> usize { + 2 + } + + fn evaluate(&self, bounded: &[Value]) -> Vec<(DynamicInputTag, Vec)> { + assert_eq!(bounded.len(), 2); + let lhs = &bounded[0]; + let rhs = &bounded[1]; + match &self.ty { + ValueType::I8 => self.soft_lt_wrapper::(lhs, rhs), + ValueType::I16 => self.soft_lt_wrapper::(lhs, rhs), + ValueType::I32 => self.soft_lt_wrapper::(lhs, rhs), + // ValueType::I64 => self.soft_lt_wrapper::(lhs, rhs), + // ValueType::I128 => self.soft_lt_wrapper::(lhs, rhs), + // ValueType::ISize => self.soft_lt_wrapper::(lhs, rhs), + ValueType::U8 => self.soft_lt_wrapper::(lhs, rhs), + ValueType::U16 => self.soft_lt_wrapper::(lhs, rhs), + ValueType::U32 => self.soft_lt_wrapper::(lhs, rhs), + // ValueType::U64 => self.soft_lt_wrapper::(lhs, rhs), + // ValueType::U128 => self.soft_lt_wrapper::(lhs, rhs), + // ValueType::USize => self.soft_lt_wrapper::(lhs, rhs), + ValueType::F32 => self.soft_lt_wrapper::(lhs, rhs), + ValueType::F64 => self.soft_lt_wrapper::(lhs, rhs), + _ => vec![], + } + } +} diff --git a/core/src/common/foreign_predicates/soft_neq.rs b/core/src/common/foreign_predicates/soft_neq.rs new file mode 100644 index 0000000..9593d45 --- /dev/null +++ b/core/src/common/foreign_predicates/soft_neq.rs @@ -0,0 +1,100 @@ +//! 
The soft equality predicate + +use super::*; + +/// Soft EQ foreign predicate +/// +/// ``` scl +/// extern pred soft_neq(lhs: T, rhs: T) +/// ``` +/// +/// It is going to output probability signal +#[derive(Clone, Debug)] +pub struct SoftNumberNeq { + /// The type of the operands + pub ty: ValueType, + + /// The function chosen as the sigmoid + pub sigmoid: SigmoidFunction, +} + +impl SoftNumberNeq { + pub fn new(ty: ValueType) -> Self { + Self { + ty, + sigmoid: SigmoidFunction::default(), + } + } + + pub fn new_with_sigmoid_fn(ty: ValueType, sigmoid: SigmoidFunction) -> Self { + Self { + ty, + sigmoid, + } + } + + fn soft_neq(&self, lhs: &Value, rhs: &Value) -> Option + where + T: std::ops::Sub + TryInto, + Value: TryInto, + { + let lhs: T = lhs.clone().try_into().ok()?; + let rhs: T = rhs.clone().try_into().ok()?; + let diff: f64 = (lhs - rhs).try_into().ok()?; + Some(1.0 - self.sigmoid.eval_deriv(diff)) + } + + fn soft_neq_wrapper(&self, lhs: &Value, rhs: &Value) -> Vec<(DynamicInputTag, Vec)> + where + T: std::ops::Sub + TryInto, + Value: TryInto, + { + if let Some(prob) = self.soft_neq(lhs, rhs) { + vec![(DynamicInputTag::Float(prob), vec![])] + } else { + vec![] + } + } +} + +impl ForeignPredicate for SoftNumberNeq { + fn name(&self) -> String { + format!("soft_neq_{}", self.ty) + } + + fn arity(&self) -> usize { + 2 + } + + fn argument_type(&self, i: usize) -> ValueType { + assert!(i < 2); + self.ty.clone() + } + + fn num_bounded(&self) -> usize { + 2 + } + + fn evaluate(&self, bounded: &[Value]) -> Vec<(DynamicInputTag, Vec)> { + assert_eq!(bounded.len(), 2); + let lhs = &bounded[0]; + let rhs = &bounded[1]; + match &self.ty { + ValueType::I8 => self.soft_neq_wrapper::(lhs, rhs), + ValueType::I16 => self.soft_neq_wrapper::(lhs, rhs), + ValueType::I32 => self.soft_neq_wrapper::(lhs, rhs), + // ValueType::I64 => self.soft_neq_wrapper::(lhs, rhs), + // ValueType::I128 => self.soft_neq_wrapper::(lhs, rhs), + // ValueType::ISize => self.soft_neq_wrapper::(lhs, rhs), + ValueType::U8 => self.soft_neq_wrapper::(lhs, rhs), + ValueType::U16 => self.soft_neq_wrapper::(lhs, rhs), + ValueType::U32 => self.soft_neq_wrapper::(lhs, rhs), + // ValueType::U64 => self.soft_neq_wrapper::(lhs, rhs), + // ValueType::U128 => self.soft_neq_wrapper::(lhs, rhs), + // ValueType::USize => self.soft_neq_wrapper::(lhs, rhs), + ValueType::F32 => self.soft_neq_wrapper::(lhs, rhs), + ValueType::F64 => self.soft_neq_wrapper::(lhs, rhs), + _ => vec![], + } + } +} diff --git a/core/src/common/foreign_predicates/string_chars.rs b/core/src/common/foreign_predicates/string_chars.rs new file mode 100644 index 0000000..5f726b4 --- /dev/null +++ b/core/src/common/foreign_predicates/string_chars.rs @@ -0,0 +1,59 @@ +use super::*; + +/// Range foreign predicate +/// +/// ``` scl +/// extern pred string_chars(s: String, id: usize, c: char)[bff] +/// ``` +#[derive(Clone)] +pub struct StringCharsBFF; + +impl Default for StringCharsBFF { + fn default() -> Self { + Self + } +} + +impl StringCharsBFF { + pub fn new() -> Self { + Self + } +} + +impl ForeignPredicate for StringCharsBFF { + fn name(&self) -> String { + "string_chars".to_string() + } + + fn arity(&self) -> usize { + 3 + } + + fn argument_type(&self, i: usize) -> ValueType { + match i { + 0 => ValueType::String, + 1 => ValueType::USize, + 2 => ValueType::Char, + _ => panic!("Invalid argument ID `{}`", i), + } + } + + fn num_bounded(&self) -> usize { + 1 + } + + fn evaluate(&self, bounded: &[Value]) -> Vec<(DynamicInputTag, Vec)> { + assert_eq!(bounded.len(), 1); + 
let s = &bounded[0]; + match s { + Value::String(s) => { + s + .chars() + .enumerate() + .map(|(i, c)| (DynamicInputTag::None, vec![Value::from(i), Value::from(c)])) + .collect() + } + _ => panic!("Bounded argument is not string") + } + } +} diff --git a/core/src/common/generic_tuple.rs b/core/src/common/generic_tuple.rs index 2fc4abb..37ec4c8 100644 --- a/core/src/common/generic_tuple.rs +++ b/core/src/common/generic_tuple.rs @@ -47,6 +47,12 @@ impl std::ops::Index<&TupleAccessor> for GenericTuple { } } +impl std::iter::FromIterator for GenericTuple { + fn from_iter>(iter: I) -> Self { + Self::Tuple(iter.into_iter().map(|v| Self::Value(v)).collect()) + } +} + impl From<()> for GenericTuple { fn from(_: ()) -> Self { Self::Tuple(Box::new([])) diff --git a/core/src/common/input_tag.rs b/core/src/common/input_tag.rs index f76727f..dad19f9 100644 --- a/core/src/common/input_tag.rs +++ b/core/src/common/input_tag.rs @@ -7,8 +7,6 @@ pub enum DynamicInputTag { ExclusiveFloat(f64, usize), } -pub type InputTag = DynamicInputTag; - impl DynamicInputTag { pub fn is_some(&self) -> bool { match self { @@ -52,18 +50,18 @@ impl Default for DynamicInputTag { } } -pub trait FromInputTag: Sized { - fn from_input_tag(t: &DynamicInputTag) -> Option; +pub trait StaticInputTag: Sized { + fn from_dynamic_input_tag(_: &DynamicInputTag) -> Option; } -impl FromInputTag for T { - default fn from_input_tag(_: &DynamicInputTag) -> Option { +impl StaticInputTag for T { + default fn from_dynamic_input_tag(_: &DynamicInputTag) -> Option { None } } -impl FromInputTag for bool { - fn from_input_tag(t: &DynamicInputTag) -> Option { +impl StaticInputTag for bool { + fn from_dynamic_input_tag(t: &DynamicInputTag) -> Option { match t { DynamicInputTag::Bool(b) => Some(b.clone()), _ => None, @@ -71,18 +69,20 @@ impl FromInputTag for bool { } } -impl FromInputTag for f64 { - fn from_input_tag(t: &DynamicInputTag) -> Option { +impl StaticInputTag for f64 { + fn from_dynamic_input_tag(t: &DynamicInputTag) -> Option { match t { DynamicInputTag::Float(f) => Some(f.clone()), + DynamicInputTag::ExclusiveFloat(f, _) => Some(f.clone()), _ => None, } } } -impl FromInputTag for (f64, Option) { - fn from_input_tag(t: &DynamicInputTag) -> Option<(f64, Option)> { +impl StaticInputTag for (f64, Option) { + fn from_dynamic_input_tag(t: &DynamicInputTag) -> Option { match t { + DynamicInputTag::Exclusive(i) => Some((1.0, Some(i.clone()))), DynamicInputTag::Float(f) => Some((f.clone(), None)), DynamicInputTag::ExclusiveFloat(f, u) => Some((f.clone(), Some(u.clone()))), _ => None, diff --git a/core/src/common/mod.rs b/core/src/common/mod.rs index f811f9e..1c97bb4 100644 --- a/core/src/common/mod.rs +++ b/core/src/common/mod.rs @@ -6,7 +6,9 @@ pub mod constants; pub mod element; pub mod expr; pub mod foreign_function; +pub mod foreign_functions; pub mod foreign_predicate; +pub mod foreign_predicates; pub mod generic_tuple; pub mod input_file; pub mod input_tag; diff --git a/core/src/common/tuple.rs b/core/src/common/tuple.rs index 0a86607..a7049d6 100644 --- a/core/src/common/tuple.rs +++ b/core/src/common/tuple.rs @@ -7,17 +7,33 @@ use super::value::Value; pub type Tuple = GenericTuple; impl Tuple { - pub fn from_primitives(prims: Vec) -> Self { - Self::Tuple(prims.into_iter().map(Self::Value).collect()) - } - pub fn tuple_type(&self) -> TupleType { TupleType::type_of(self) } + pub fn arity(&self) -> usize { + match self { + Self::Tuple(ts) => ts.len(), + _ => 0, + } + } + + pub fn as_values(&self) -> Vec { + match self { + Self::Value(_) => 
panic!("Not a tuple"), + Self::Tuple(t) => t + .iter() + .map(|t| match t { + Self::Value(v) => v.clone(), + _ => panic!("Not a value"), + }) + .collect(), + } + } + pub fn as_ref_values(&self) -> Vec<&Value> { match self { - Self::Value(p) => vec![p], + Self::Value(_) => panic!("Not a tuple"), Self::Tuple(t) => t .iter() .map(|t| match t { @@ -162,6 +178,12 @@ where } } +impl From> for Tuple { + fn from(v: Vec) -> Self { + Self::Tuple(v.into_iter().map(|v| v.into()).collect()) + } +} + pub trait AsTuple { fn as_tuple(&self) -> T; } diff --git a/core/src/common/value.rs b/core/src/common/value.rs index f6b7abe..981549b 100644 --- a/core/src/common/value.rs +++ b/core/src/common/value.rs @@ -3,6 +3,7 @@ use std::convert::*; use super::value_type::*; +use chrono::{DateTime, Duration, Utc}; #[derive(Debug, Clone, PartialEq, PartialOrd)] pub enum Value { @@ -24,6 +25,8 @@ pub enum Value { Bool(bool), Str(&'static str), String(String), + DateTime(DateTime), + Duration(Duration), // RcString(Rc), } @@ -32,6 +35,20 @@ impl Value { ValueType::type_of(self) } + pub fn as_date_time(&self) -> DateTime { + match self { + Self::DateTime(d) => d.clone(), + _ => panic!("Not a DateTime") + } + } + + pub fn as_duration(&self) -> Duration { + match self { + Self::Duration(d) => d.clone(), + _ => panic!("Not a Duration") + } + } + pub fn as_usize(&self) -> usize { match self { Self::USize(u) => *u, @@ -80,6 +97,8 @@ impl std::hash::Hash for Value { Self::Bool(b) => b.hash(state), Self::Str(s) => s.hash(state), Self::String(s) => s.hash(state), + Self::DateTime(d) => d.hash(state), + Self::Duration(d) => d.hash(state), } } } @@ -105,6 +124,8 @@ impl std::fmt::Display for Value { Self::Bool(i) => f.write_fmt(format_args!("{}", i)), Self::Str(i) => f.write_fmt(format_args!("{:?}", i)), Self::String(i) => f.write_fmt(format_args!("{:?}", i)), + Self::DateTime(i) => f.write_fmt(format_args!("t\"{}\"", i)), + Self::Duration(i) => f.write_fmt(format_args!("d\"{}\"", i)), // Self::RcString(i) => f.write_fmt(format_args!("{:?}", i)), } } @@ -218,6 +239,18 @@ impl From for Value { } } +impl From> for Value { + fn from(dt: DateTime) -> Self { + Self::DateTime(dt) + } +} + +impl From for Value { + fn from(d: Duration) -> Self { + Self::Duration(d) + } +} + // impl From> for Value { // fn from(s: Rc) -> Self { // Self::RcString(s) @@ -239,6 +272,7 @@ macro_rules! 
impl_try_into { }; } +#[derive(Clone, Debug, Default)] pub struct ValueConversionError; impl_try_into!(i8, I8); diff --git a/core/src/common/value_type.rs b/core/src/common/value_type.rs index 837ec3a..3ce1c86 100644 --- a/core/src/common/value_type.rs +++ b/core/src/common/value_type.rs @@ -1,5 +1,7 @@ // use std::rc::Rc; +use crate::utils; + use super::tuple::*; use super::value::*; @@ -23,6 +25,8 @@ pub enum ValueType { Bool, Str, String, + DateTime, + Duration, // RcString, } @@ -48,6 +52,8 @@ impl ValueType { Bool(_) => Self::Bool, Str(_) => Self::Str, String(_) => Self::String, + DateTime(_) => Self::DateTime, + Duration(_) => Self::Duration, // RcString(_) => Self::RcString, } } @@ -125,6 +131,13 @@ impl ValueType { } } + pub fn is_unsigned_integer(&self) -> bool { + match self { + Self::U8 | Self::U16 | Self::U32 | Self::U64 | Self::U128 | Self::USize => true, + _ => false, + } + } + pub fn is_float(&self) -> bool { match self { Self::F32 | Self::F64 => true, @@ -153,6 +166,20 @@ impl ValueType { } } + pub fn is_datetime(&self) -> bool { + match self { + Self::DateTime => true, + _ => false, + } + } + + pub fn is_duration(&self) -> bool { + match self { + Self::Duration => true, + _ => false, + } + } + pub fn can_type_cast(&self, target: &Self) -> bool { if self.is_numeric() && target.is_numeric() { true @@ -201,6 +228,10 @@ impl ValueType { Self::Str => panic!("Cannot parse into a static string"), Self::String => Ok(Value::String(s.to_string())), // Self::RcString => Ok(Value::RcString(Rc::new(s.to_string()))), + + // DateTime and Duration + Self::DateTime => Ok(Value::DateTime(utils::parse_date_time_string(s).ok_or_else(|| ValueParseError::new(s, self))?)), + Self::Duration => Ok(Value::Duration(utils::parse_duration_string(s).ok_or_else(|| ValueParseError::new(s, self))?)), } } @@ -255,6 +286,56 @@ impl ValueType { _ => panic!("Cannot perform sum on type `{}`", self), } } + + /// Get all integer types + pub fn integers() -> &'static [ValueType] { + &[ + ValueType::I8, + ValueType::I16, + ValueType::I32, + ValueType::I64, + ValueType::I128, + ValueType::ISize, + ValueType::U8, + ValueType::U16, + ValueType::U32, + ValueType::U64, + ValueType::U128, + ValueType::USize, + ] + } + + /// Get all signed integer types + pub fn signed_integers() -> &'static [ValueType] { + &[ + ValueType::I8, + ValueType::I16, + ValueType::I32, + ValueType::I64, + ValueType::I128, + ValueType::ISize, + ] + } + + /// Get all unsigned integer types + pub fn unsigned_integers() -> &'static [ValueType] { + &[ + ValueType::U8, + ValueType::U16, + ValueType::U32, + ValueType::U64, + ValueType::U128, + ValueType::USize, + ] + } + + /// Get all floating point number types + pub fn floats() -> &'static [ValueType] { + &[ + ValueType::F32, + ValueType::F64, + ] + } } #[derive(Clone, Debug)] @@ -301,6 +382,8 @@ impl std::fmt::Display for ValueType { Str => f.write_str("&str"), String => f.write_str("String"), // RcString => f.write_str("Rc"), + DateTime => f.write_str("DateTime"), + Duration => f.write_str("Duration"), } } } @@ -417,6 +500,17 @@ impl FromType for ValueType { } } +impl FromType> for ValueType { + fn from_type() -> Self { + Self::DateTime + } +} +impl FromType for ValueType { + fn from_type() -> Self { + Self::Duration + } +} + // impl FromType> for ValueType { // fn from_type() -> Self { // Self::RcString diff --git a/core/src/compiler/back/ast.rs b/core/src/compiler/back/ast.rs index 5dae5c3..80ed32a 100644 --- a/core/src/compiler/back/ast.rs +++ b/core/src/compiler/back/ast.rs @@ -3,7 +3,8 @@ use 
std::collections::*; use super::Attributes; use crate::common::aggregate_op::AggregateOp; use crate::common::foreign_function::ForeignFunctionRegistry; -use crate::common::input_tag::InputTag; +use crate::common::foreign_predicate::ForeignPredicateRegistry; +use crate::common::input_tag::DynamicInputTag; use crate::common::output_option::OutputOption; use crate::compiler::front; @@ -19,6 +20,7 @@ pub struct Program { pub disjunctive_facts: Vec>, pub rules: Vec, pub function_registry: ForeignFunctionRegistry, + pub predicate_registry: ForeignPredicateRegistry, } impl Program { @@ -55,7 +57,7 @@ impl Program { #[derive(Clone, Debug, PartialEq)] pub struct Fact { - pub tag: InputTag, + pub tag: DynamicInputTag, pub predicate: String, pub args: Vec, } @@ -121,11 +123,13 @@ impl Head { } } +/// A conjunction of literals #[derive(Clone, Debug, PartialEq)] pub struct Conjunction { pub args: Vec, } +/// A term is the argument of a literal #[derive(Clone, Debug, PartialEq)] pub enum Term { Variable(Variable), @@ -133,10 +137,12 @@ pub enum Term { } impl Term { + /// Create a new variable term using the given name and type pub fn variable(name: String, ty: Type) -> Self { Self::Variable(Variable { name, ty }) } + /// Check if the term is a variable pub fn is_variable(&self) -> bool { match self { Self::Variable(_) => true, @@ -144,6 +150,7 @@ impl Term { } } + /// Check if the term is a constant pub fn is_constant(&self) -> bool { match self { Self::Constant(_) => true, @@ -151,6 +158,7 @@ impl Term { } } + /// Get the variable if the term is a variable pub fn as_variable(&self) -> Option<&Variable> { match self { Self::Variable(v) => Some(v), @@ -158,6 +166,7 @@ impl Term { } } + /// Get the constant if the term is a constant pub fn as_constant(&self) -> Option<&Constant> { match self { Self::Constant(c) => Some(c), @@ -204,6 +213,7 @@ impl std::fmt::Debug for Literal { } impl Literal { + /// Create a new assignment of binary expression pub fn binary_expr(left: Variable, op: BinaryExprOp, op1: Term, op2: Term) -> Self { Self::Assign(Assign { left, @@ -211,6 +221,7 @@ impl Literal { }) } + /// Create a new assignment of unary expression pub fn unary_expr(left: Variable, op: UnaryExprOp, op1: Term) -> Self { Self::Assign(Assign { left, @@ -218,6 +229,7 @@ impl Literal { }) } + /// Create a new assignment of if-then-else expression pub fn if_then_else_expr(left: Variable, cond: Term, then_br: Term, else_br: Term) -> Self { Self::Assign(Assign { left, @@ -225,6 +237,7 @@ impl Literal { }) } + /// Create a new assignment of call expression pub fn call_expr(left: Variable, function: String, args: Vec) -> Self { Self::Assign(Assign { left, @@ -232,6 +245,7 @@ impl Literal { }) } + /// Create a new assignment of if-then-else expression pub fn binary_constraint(op: BinaryConstraintOp, op1: Term, op2: Term) -> Self { Self::Constraint(Constraint::Binary(BinaryConstraint { op, op1, op2 })) } @@ -248,6 +262,7 @@ pub struct Atom { } impl Atom { + /// Create a new atom pub fn new(predicate: String, args: Vec) -> Self { Self { predicate, args } } @@ -276,6 +291,7 @@ impl Atom { return true; } + /// Get the atom's arguments which are variables pub fn variable_args(&self) -> impl Iterator { self.args.iter().filter_map(|a| match a { Term::Variable(v) => Some(v), @@ -283,14 +299,17 @@ impl Atom { }) } + /// Get a set of unique variables in the atom's arguments pub fn unique_variable_args(&self) -> impl Iterator { self.variable_args().cloned().collect::>().into_iter() } + /// Check if the atom has constant arguments 
pub fn has_constant_arg(&self) -> bool { self.args.iter().any(|a| a.is_constant()) } + /// Create a partition of the atom's arguments into constant and variable pub fn const_var_partition(&self) -> (Vec<(usize, &Constant)>, Vec<(usize, &Variable)>) { let (constants, variables): (Vec<_>, Vec<_>) = self.args.iter().enumerate().partition(|(_, t)| t.is_constant()); let constants = constants @@ -386,6 +405,7 @@ pub struct CallExpr { pub args: Vec, } +/// A constraint literal which is either a binary or unary constraint #[derive(Clone, Debug, PartialEq)] pub enum Constraint { Binary(BinaryConstraint), @@ -393,6 +413,22 @@ pub enum Constraint { } impl Constraint { + /// Create a new equality constraint using two terms + pub fn eq(op1: Term, op2: Term) -> Self { + Self::Binary(BinaryConstraint { op: BinaryConstraintOp::Eq, op1, op2 }) + } + + /// Create a new inequality constraint using two terms + pub fn neq(op1: Term, op2: Term) -> Self { + Self::Binary(BinaryConstraint { op: BinaryConstraintOp::Neq, op1, op2 }) + } + + /// Create a new binary constraint using an operator and two terms + pub fn binary(op: BinaryConstraintOp, op1: Term, op2: Term) -> Self { + Self::Binary(BinaryConstraint { op, op1, op2 }) + } + + /// Find the variable arguments occurred in this constraint pub fn variable_args(&self) -> Vec<&Variable> { let mut args = vec![]; match self { diff --git a/core/src/compiler/back/b2r.rs b/core/src/compiler/back/b2r.rs index 7e8792b..07d17a7 100644 --- a/core/src/compiler/back/b2r.rs +++ b/core/src/compiler/back/b2r.rs @@ -5,6 +5,7 @@ use itertools::Itertools; use super::*; use crate::common::binary_op::BinaryOp; use crate::common::expr::Expr; +use crate::common::foreign_predicate::ForeignPredicate; use crate::common::output_option::OutputOption; use crate::common::tuple::Tuple; use crate::common::tuple_access::TupleAccessor; @@ -162,7 +163,7 @@ impl Program { // All the updates for predicate in &stratum.predicates { for rule in self.rules_of_predicate(predicate.clone()) { - let ctx = QueryPlanContext::from_rule(stratum, rule); + let ctx = QueryPlanContext::from_rule(stratum, &self.predicate_registry, rule); let plan = ctx.query_plan(); updates.push(self.plan_to_ram_update(&mut b2r_context, &rule.head, &plan)); } @@ -343,6 +344,9 @@ impl Program { HighRamNode::Join(d1, d2) => self.join_plan_to_ram_dataflow(ctx, goal, &*d1, &*d2, prop), HighRamNode::Antijoin(d1, d2) => self.antijoin_plan_to_ram_dataflow(ctx, goal, &*d1, &*d2, prop), HighRamNode::Reduce(r) => self.reduce_plan_to_ram_dataflow(ctx, goal, r, prop), + HighRamNode::ForeignPredicateGround(a) => self.fp_ground_plan_to_ram_dataflow(ctx, goal, a, prop), + HighRamNode::ForeignPredicateConstraint(d, a) => self.fp_constraint_plan_to_ram_dataflow(ctx, goal, &*d, a, prop), + HighRamNode::ForeignPredicateJoin(d, a) => self.fp_join_plan_to_ram_dataflow(ctx, goal, &*d, a, prop), } } @@ -793,6 +797,100 @@ impl Program { Self::process_dataflow(ctx, goal, dataflow, prop) } + fn fp_ground_plan_to_ram_dataflow( + &self, + ctx: &mut B2RContext, + goal: &VariableTuple, + atom: &Atom, + prop: DataflowProp, + ) -> ram::Dataflow { + // Find the foreign predicate in the registry + let fp = self.predicate_registry.get(&atom.predicate).unwrap(); + + // Get information from the atom + let pred: String = atom.predicate.clone(); + let inputs: Vec = atom.args.iter().take(fp.num_bounded()).map(|arg| arg.as_constant().unwrap()).cloned().collect(); + let ground_dataflow = ram::Dataflow::ForeignPredicateGround(pred, inputs); + + // Get the projection onto the 
variable tuple + let var_tuple = atom.args.iter().skip(fp.num_bounded()).map(|arg| arg.as_variable().unwrap()).cloned(); + let project = VariableTuple::from_vars(var_tuple, false).projection(goal); + + // Project the dataflow + let dataflow = ram::Dataflow::project(ground_dataflow, project); + Self::process_dataflow(ctx, goal, dataflow, prop) + } + + fn fp_constraint_plan_to_ram_dataflow( + &self, + ctx: &mut B2RContext, + goal: &VariableTuple, + d: &Plan, + atom: &Atom, + prop: DataflowProp, + ) -> ram::Dataflow { + // Generate a sub-dataflow + let sub_goal = VariableTuple::from_vars(d.bounded_vars.iter().cloned(), false); + let sub_dataflow: ram::Dataflow = self.plan_to_ram_dataflow(ctx, &sub_goal, d, prop.with_need_sorted(false)); + + // Generate information for foreign predicate constraint + let pred: String = atom.predicate.clone(); + let exprs: Vec = atom.args.iter().map(|arg| { + match arg { + Term::Constant(c) => Expr::Constant(c.clone()), + Term::Variable(v) => Expr::Access(sub_goal.accessor_of(v).unwrap()), + } + }).collect(); + + // Return a foreign predicate constraint dataflow + let dataflow = sub_dataflow.foreign_predicate_constraint(pred, exprs); + + // Get the projection onto the variable tuple + let projection = sub_goal.projection(goal); + let dataflow = ram::Dataflow::project(dataflow, projection); + Self::process_dataflow(ctx, goal, dataflow, prop) + } + + fn fp_join_plan_to_ram_dataflow( + &self, + ctx: &mut B2RContext, + goal: &VariableTuple, + d: &Plan, + atom: &Atom, + prop: DataflowProp, + ) -> ram::Dataflow { + // Find the foreign predicate in the registry + let fp = self.predicate_registry.get(&atom.predicate).unwrap(); + + // Generate a sub-dataflow + let left_goal = VariableTuple::from_vars(d.bounded_vars.iter().cloned(), false); + let left_dataflow: ram::Dataflow = self.plan_to_ram_dataflow(ctx, &left_goal, d, prop.with_need_sorted(false)); + + // Generate information for foreign predicate constraint + let pred: String = atom.predicate.clone(); + let exprs: Vec = atom.args.iter().take(fp.num_bounded()).map(|arg| { + match arg { + Term::Constant(c) => Expr::Constant(c.clone()), + Term::Variable(v) => { + Expr::Access(left_goal.accessor_of(v).unwrap()) + }, + } + }).collect(); + + // Generate the joint dataflow + let join_dataflow = left_dataflow.foreign_predicate_join(pred, exprs); + + // Get the variable tuple of the joined output + let free_vars: Vec<_> = atom.args.iter().skip(fp.num_bounded()).map(|arg| arg.as_variable().unwrap()).cloned().collect(); + let right_tuple = VariableTuple::from_vars(free_vars.iter().cloned(), false); + let var_tuple = VariableTuple::from((left_goal, right_tuple)); + + // Project the dataflow + let projection = var_tuple.projection(goal); + let project_dataflow = ram::Dataflow::project(join_dataflow, projection); + Self::process_dataflow(ctx, goal, project_dataflow, prop) + } + fn process_dataflow( ctx: &mut B2RContext, goal: &VariableTuple, diff --git a/core/src/compiler/back/compile.rs b/core/src/compiler/back/compile.rs index e665189..69aac7a 100644 --- a/core/src/compiler/back/compile.rs +++ b/core/src/compiler/back/compile.rs @@ -105,6 +105,7 @@ impl Program { Ok(ram::Program { strata: ram_strata, function_registry: self.function_registry.clone(), + predicate_registry: self.predicate_registry.clone(), relation_to_stratum, }) } diff --git a/core/src/compiler/back/optimizations/empty_rule_to_fact.rs b/core/src/compiler/back/optimizations/empty_rule_to_fact.rs index e742c5d..2260e29 100644 --- 
a/core/src/compiler/back/optimizations/empty_rule_to_fact.rs +++ b/core/src/compiler/back/optimizations/empty_rule_to_fact.rs @@ -1,4 +1,4 @@ -use crate::common::input_tag::InputTag; +use crate::common::input_tag::*; use super::super::*; @@ -7,7 +7,7 @@ pub fn empty_rule_to_fact(rules: &mut Vec, facts: &mut Vec) { if rule.body.args.is_empty() { // Create fact let fact = Fact { - tag: InputTag::None, + tag: DynamicInputTag::None, predicate: rule.head.predicate.clone(), args: rule .head diff --git a/core/src/compiler/back/query_plan.rs b/core/src/compiler/back/query_plan.rs index 90bc6db..e948c6a 100644 --- a/core/src/compiler/back/query_plan.rs +++ b/core/src/compiler/back/query_plan.rs @@ -2,21 +2,33 @@ use std::collections::*; use itertools::Itertools; +use crate::common::foreign_predicate::*; +use crate::common::value_type::*; + use super::*; +/// The context for constructing a query plan #[derive(Clone, Debug)] pub struct QueryPlanContext<'a> { pub head_vars: HashSet, pub reduces: Vec, pub pos_atoms: Vec, - pub neg_atoms: Vec, + pub neg_atoms: Vec, pub assigns: Vec, pub constraints: Vec, + pub foreign_predicate_pos_atoms: Vec, pub stratum: &'a Stratum, + pub foreign_predicate_registry: &'a ForeignPredicateRegistry, } impl<'a> QueryPlanContext<'a> { - pub fn from_rule(stratum: &'a Stratum, rule: &Rule) -> Self { + /// Create a new query plan context from a rule + pub fn from_rule( + stratum: &'a Stratum, + foreign_predicate_registry: &'a ForeignPredicateRegistry, + rule: &Rule, + ) -> Self { + // First create an empty context let mut ctx = Self { head_vars: rule.head.variable_args().cloned().collect(), reduces: vec![], @@ -24,13 +36,23 @@ impl<'a> QueryPlanContext<'a> { neg_atoms: vec![], assigns: vec![], constraints: vec![], + foreign_predicate_pos_atoms: vec![], stratum, + foreign_predicate_registry, }; + + // Then fill it with the literals extracted from the rule for literal in rule.body_literals() { match literal { Literal::Reduce(r) => ctx.reduces.push(r.clone()), - Literal::Atom(a) => ctx.pos_atoms.push(a.clone()), - Literal::NegAtom(n) => ctx.neg_atoms.push(n.clone()), + Literal::Atom(a) => { + if foreign_predicate_registry.contains(&a.predicate) { + ctx.foreign_predicate_pos_atoms.push(a.clone()); + } else { + ctx.pos_atoms.push(a.clone()) + } + }, + Literal::NegAtom(n) => ctx.neg_atoms.push(n.atom.clone()), Literal::Assign(a) => ctx.assigns.push(a.clone()), Literal::Constraint(c) => ctx.constraints.push(c.clone()), Literal::True => {} @@ -42,8 +64,11 @@ impl<'a> QueryPlanContext<'a> { ctx } + /// Find all the bounded arguments given the set of positive atoms pub fn bounded_args_from_pos_atoms_set(&self, set: &Vec<&usize>) -> HashSet { let mut base_bounded_args = HashSet::new(); + + // Add the base cases: all the arguments in the positive atoms form the base bounded args for atom in self .pos_atoms .iter() @@ -61,6 +86,8 @@ impl<'a> QueryPlanContext<'a> { // Fix point iteration loop { let mut new_bounded_args = base_bounded_args.clone(); + + // Find the bounded args from the assigns for assign in &self.assigns { let can_bound = match &assign.right { AssignExpr::Binary(b) => { @@ -82,6 +109,25 @@ impl<'a> QueryPlanContext<'a> { } } + // Find the bounded args from the foreign predicate atoms + for atom in &self.foreign_predicate_pos_atoms { + // First find the predicate from the registry + let predicate = self.foreign_predicate_registry.get(&atom.predicate).unwrap(); + + // Then check if all the to-bound arguments are bounded + let can_bound = 
atom.args.iter().take(predicate.num_bounded()).all(|a| term_is_bounded(&new_bounded_args, a)); + + // If it can be bounded, add the rest of the arguments to the bounded args + if can_bound { + for arg in atom.args.iter().skip(predicate.num_bounded()) { + if let Term::Variable(v) = arg { + new_bounded_args.insert(v.clone()); + } + } + } + } + + // Check if the fix point is reached if new_bounded_args == base_bounded_args { break new_bounded_args; } else { @@ -91,13 +137,19 @@ impl<'a> QueryPlanContext<'a> { } fn pos_atom_arcs(&self, beam_size: usize) -> State { + // If there is no positive atom, return an empty state if self.pos_atoms.is_empty() { return State::new(); } + // Maintain a priority queue of searching states let mut priority_queue = BinaryHeap::new(); priority_queue.push(State::new()); + + // Maintain a set of final states let mut final_states = BinaryHeap::new(); + + // Start the (beam) search process while !priority_queue.is_empty() { let mut temp_queue = BinaryHeap::new(); @@ -127,81 +179,343 @@ impl<'a> QueryPlanContext<'a> { final_states.pop().unwrap() } - /// The main entry function that computes a query plan from a sequence of arcs - fn get_query_plan(&self, arcs: &Vec) -> Plan { - // Stage 1: Helper Functions (Closures) + fn try_apply_constraint(&self, applied_constraints: &mut HashSet, fringe: Plan) -> Plan { + // Apply as many constraints as possible + let node = fringe; + let mut to_apply_constraints = vec![]; + for (i, constraint) in self.constraints.iter().enumerate() { + if !applied_constraints.contains(&i) { + let can_apply = constraint.variable_args().iter().all(|v| node.bounded_vars.contains(v)); + if can_apply { + applied_constraints.insert(i); + to_apply_constraints.push(constraint.clone()); + } + } + } + if to_apply_constraints.is_empty() { + node + } else { + let new_bounded_vars = node.bounded_vars.clone(); - // Store the applied constraints - let mut applied_constraints = HashSet::new(); - let mut try_apply_constraint = |fringe: Plan| -> Plan { - // Apply as many constraints as possible - let node = fringe; - let mut to_apply_constraints = vec![]; - for (i, constraint) in self.constraints.iter().enumerate() { - if !applied_constraints.contains(&i) { - let can_apply = constraint.variable_args().iter().all(|v| node.bounded_vars.contains(v)); - if can_apply { - applied_constraints.insert(i); - to_apply_constraints.push(constraint.clone()); - } + Plan { + bounded_vars: new_bounded_vars, + ram_node: HighRamNode::Filter(Box::new(node), to_apply_constraints), + } + } + } + + /// Try applying as many assigns as possible + fn try_apply_assigns(&self, applied_assigns: &mut HashSet, mut fringe: Plan) -> Plan { + // Find all the assigns that are needed to bound the need_projection_vars + let mut bounded_vars = fringe.bounded_vars.clone(); + loop { + let mut new_projections = Vec::new(); + + // Check if we can apply more assigns + for (i, assign) in self.assigns.iter().enumerate() { + if !applied_assigns.contains(&i) + && !bounded_vars.contains(&assign.left) + && assign.variable_args().into_iter().all(|v| bounded_vars.contains(v)) + { + applied_assigns.insert(i); + new_projections.push(assign.clone()); } } - if to_apply_constraints.is_empty() { - node + + // Create projected left node + if new_projections.is_empty() { + break fringe; } else { - let new_bounded_vars = node.bounded_vars.clone(); + bounded_vars.extend(new_projections.iter().map(|i| i.left.clone())); + fringe = Plan { + bounded_vars: bounded_vars.clone(), + ram_node: HighRamNode::Project(Box::new(fringe), 
new_projections), + }; + } + } + } - Plan { - bounded_vars: new_bounded_vars, - ram_node: HighRamNode::Filter(Box::new(node), to_apply_constraints), - } + /// Get the essential information for analyzing a foreign predicate atom + fn foreign_predicate_atom_info<'b, 'c>(&'b self, atom: &'c Atom) -> (&'b DynamicForeignPredicate, Vec<(usize, &'c Term)>, Vec<(usize, &'c Term)>) { + let pred = self.foreign_predicate_registry.get(&atom.predicate).unwrap(); + let (to_bound_arguments, free_arguments): (Vec<_>, Vec<_>) = atom + .args + .iter() + .enumerate() + .partition(|(i, _)| *i < pred.num_bounded()); + (pred, to_bound_arguments, free_arguments) + } + + fn foreign_predicate_constant_constraints(&self, atom: &Atom, arguments: &Vec<(usize, &Term)>) -> Vec { + arguments.iter().filter_map(|(i, a)| match a { + Term::Constant(c) => { + let op1 = Term::variable(format!("c#{}#{}", &atom.predicate, i), ValueType::type_of(c)); + let op2 = Term::Constant(c.clone()); + Some(Constraint::eq(op1, op2)) } - }; + Term::Variable(_) => { + None + } + }).collect() + } - // Store the applied assigns - let mut applied_assigns = HashSet::new(); - let mut try_apply_assigns = |mut fringe: Plan| -> Plan { - // Find all the assigns that are needed to bound the need_projection_vars - let mut bounded_vars = fringe.bounded_vars.clone(); - loop { - let mut new_projections = Vec::new(); - - // Check if we can apply more assigns - for (i, assign) in self.assigns.iter().enumerate() { - if !applied_assigns.contains(&i) - && !bounded_vars.contains(&assign.left) - && assign.variable_args().into_iter().all(|v| bounded_vars.contains(v)) - { - applied_assigns.insert(i); - new_projections.push(assign.clone()); + fn foreign_predicate_equality_constraints(&self, var_eq: &Vec<(Variable, Variable)>) -> Vec { + var_eq.iter().map(|(v1, v2)| Constraint::eq(Term::Variable(v1.clone()), Term::Variable(v2.clone()))).collect() + } + + /// Given a list of free arguments, rename them to avoid name conflicts + /// + /// Specifically, the renaming is done in the following way: + /// + /// - If an argument is a variable + /// - If the variable only occurs once, then we do not rename it + /// - If the variable occurs more than once, then we rename it to `var#i` where `i` is the number of occurrences + /// - If an argument is a constant, then we rename it to `c#predicate#i` where `i` is the position of the argument + /// + /// In this case, all the arguments become distinct variables, where original variable names are preserved + fn rename_free_arguments( + &self, + predicate: &str, + arguments: &Vec<(usize, &Term)>, + occurred: &HashSet, + ) -> (Vec, Vec<(Variable, Variable)>) { + // First build a map from variable names to the number of occurrences + let mut var_occurrences: HashMap = occurred.iter().map(|v| (v.name.clone(), (v.clone(), 0))).collect(); + let mut var_equivalences: Vec<(Variable, Variable)> = Vec::new(); + let vars = arguments + .iter() + .map(|(i, a)| match a { + Term::Variable(v) => { + if let Some((old_var, occ)) = var_occurrences.get_mut(&v.name) { + // Create a new variable different than the previous argument + let new_var = Variable::new(format!("{}#{}", v.name, occ), v.ty.clone()); + + // Add to the list of equivalences + var_equivalences.push((old_var.clone(), new_var.clone())); + + // Update the information stored in the variable occurrences + *occ += 1; + *old_var = new_var.clone(); + + // Return the new variable + new_var + } else { + // Insert into variable occurrences + var_occurrences.insert(v.name.clone(), (v.clone(), 
1)); + + // Return the variable + v.clone() } + }, + Term::Constant(c) => { + Variable::new(format!("c#{}#{}", predicate, i), ValueType::type_of(c)) } + }) + .collect::>(); + (vars, var_equivalences) + } - // Create projected left node - if new_projections.is_empty() { - break fringe; - } else { - bounded_vars.extend(new_projections.iter().map(|i| i.left.clone())); - fringe = Plan { - bounded_vars: bounded_vars.clone(), - ram_node: HighRamNode::Project(Box::new(fringe), new_projections), - }; + /// + fn compute_foreign_predicate_ground_atom( + &self, + atom: &Atom, + pred: &DynamicForeignPredicate, + free_arguments: &Vec<(usize, &Term)>, + ) -> Plan { + // The atom is grounded + let (all_vars, var_eq) = self.rename_free_arguments(&atom.predicate, free_arguments, &HashSet::new()); + + // Create a ground plan + let args = atom.args.iter().take(pred.num_bounded()).cloned().chain(all_vars.iter().cloned().map(Term::Variable)).collect(); + let ground_atom = Atom::new(atom.predicate.clone(), args); + let ground_plan = Plan { + bounded_vars: all_vars.iter().cloned().collect(), + ram_node: HighRamNode::ForeignPredicateGround(ground_atom), + }; + + // Get all the constraints + let const_constraints = self.foreign_predicate_constant_constraints(atom, &free_arguments); + let eq_constraints = self.foreign_predicate_equality_constraints(&var_eq); + let constraints = vec![const_constraints, eq_constraints].concat(); + + // Create a plan; + // if there are constraints, we need to create a filter plan on top of the ground plan + // Otherwise, we can just return the ground plan + if !constraints.is_empty() { + // Find the bounded vars + // Note that we do not use all the variables occurring in the constraints + let new_bounded_vars = free_arguments + .iter() + .filter_map(|(_, a)| match a { + Term::Variable(v) => Some(v.clone()), + _ => None + }) + .collect(); + + // Create a plan with filters on the constants + let filter_plan = Plan { + bounded_vars: new_bounded_vars, + ram_node: HighRamNode::filter(ground_plan, constraints) + }; + + filter_plan + } else { + ground_plan + } + } + + fn compute_foreign_predicate_join_atom( + &self, + left: Plan, + atom: &Atom, + pred: &DynamicForeignPredicate, + free_arguments: &Vec<(usize, &Term)>, + ) -> Plan { + // Get the arguments to the foreign predicate + let occurred_variables = left.bounded_vars.iter().cloned().collect(); + let (free_vars, var_eq) = self.rename_free_arguments(&atom.predicate, free_arguments, &occurred_variables); + + // Create an atom + let args = atom.args.iter().take(pred.num_bounded()).cloned().chain(free_vars.iter().cloned().map(Term::Variable)).collect(); + let to_join_atom = Atom::new(atom.predicate.clone(), args); + + // Create the join plan + let join_plan = Plan { + bounded_vars: left.bounded_vars.iter().cloned().chain(free_vars.iter().cloned()).collect(), + ram_node: HighRamNode::foreign_predicate_join(left.clone(), to_join_atom), + }; + + // Get all the constraints + let const_constraints = self.foreign_predicate_constant_constraints(atom, &free_arguments); + let eq_constraints = self.foreign_predicate_equality_constraints(&var_eq); + let constraints = vec![const_constraints, eq_constraints].concat(); + + // Create a plan; + // if there are constraints, we need to create a filter plan on top of the ground plan + // Otherwise, we can just return the ground plan + if !constraints.is_empty() { + // Find the bounded vars + // Note that we do not use all the variables occurring in the constraints + let right_bounded_vars: Vec<_> = 
free_arguments + .iter() + .filter_map(|(_, a)| match a { + Term::Variable(v) => Some(v.clone()), + _ => None + }) + .collect(); + + // Create a plan with filters on the constants + let filter_plan = Plan { + bounded_vars: left.bounded_vars.iter().chain(right_bounded_vars.iter()).cloned().collect(), + ram_node: HighRamNode::filter(join_plan, constraints) + }; + + filter_plan + } else { + join_plan + } + } + + /// Try to apply the foreign predicate atoms in the context. + /// Will scan all the foreign predicate atoms and see if there are any that can be applied. + /// + /// - `applied_foreign_predicate_atoms`: The set of already applied foreign predicate atoms, represented by their index + /// - `fringe`: The current Plan to apply the foreign predicate atoms + fn try_apply_foreign_predicate_atom(&self, applied_foreign_predicate_atoms: &mut HashSet, mut fringe: Plan) -> Plan { + // Find all the foreign predicate atoms + let bounded_vars = fringe.bounded_vars.clone(); + loop { + let mut applied = false; + + // Check if we can apply more foreign predicate atoms + for (i, atom) in self.foreign_predicate_pos_atoms.iter().enumerate() { + if !applied_foreign_predicate_atoms.contains(&i) { + // Get the foreign predicate information + let (pred, to_bound_arguments, free_arguments) = self.foreign_predicate_atom_info(atom); + + // Check if all the to-bound arguments are bounded; if so, it means that we can apply the atom + if to_bound_arguments.iter().all(|(_, a)| term_is_bounded(&bounded_vars, a)) { + // Mark the atom as applied + applied_foreign_predicate_atoms.insert(i); + + // There are 3 kinds of foreign predicate atoms: + // 1. There are no free arguments (i.e. all arguments are bounded) + // 2. Ground atom (i.e. all bounded arguments are constants) + // 3. Joining atom (i.e. 
some bounded arguments are variables) + // For each of these cases, we need to create a different plan + if free_arguments.is_empty() { + // The atom is completely bounded; we add a foreign predicate constraint plan + fringe = Plan { + bounded_vars: bounded_vars.clone(), + ram_node: HighRamNode::ForeignPredicateConstraint( + Box::new(fringe), + atom.clone(), + ), + }; + } else if to_bound_arguments.iter().all(|(_, a)| a.is_constant()) { + let plan = self.compute_foreign_predicate_ground_atom(atom, pred, &free_arguments); + + // Connect it with the existing plan + fringe = Plan { + bounded_vars: fringe.bounded_vars.union(&plan.bounded_vars).cloned().collect(), + ram_node: HighRamNode::join(fringe, plan), + }; + } else { + // The atom is bounded and new values can be generated + fringe = self.compute_foreign_predicate_join_atom(fringe, atom, pred, &free_arguments); + } + + // Found an atom that can be applied + applied = true; + } } } - }; - // Stage 2: Building the RAM tree bottom-up, starting with reduces + // Break the loop if no more foreign predicate atoms can be applied + if !applied { + break fringe; + } + } + } + + fn is_ground_foreign_atom(&self, atom: &Atom) -> bool { + let pred = self.foreign_predicate_registry.get(&atom.predicate).unwrap(); + atom.args.iter().take(pred.num_bounded()).all(|a| a.is_constant()) + } + + /// The main entry function that computes a query plan from a sequence of arcs + fn get_query_plan(&self, arcs: &Vec) -> Plan { + // ==== Stage 1: Helper Functions (Closures) ==== + + // Store the applied constraints + let mut applied_constraints = HashSet::new(); + let mut applied_assigns = HashSet::new(); + let mut applied_foreign_predicates = HashSet::new(); + + // ==== Stage 2: Building the RAM tree bottom-up, starting with reduces ==== // Build the first fringe - let (mut fringe, start_id) = if self.reduces.is_empty() { + let (mut fringe, start_arc_id) = if self.reduces.is_empty() { + // There is no reduce if arcs.is_empty() { - let node = Plan { - bounded_vars: HashSet::new(), - ram_node: HighRamNode::Unit, - }; - (node, 0) + // There is no arc + if self.foreign_predicate_pos_atoms.is_empty() { + // There is no reduce and there is no arc and there is no foreign predicate atom + let node = Plan::unit(); + (node, 0) + } else { + // Find the foreign predicate atom + if let Some((i, atom)) = self.foreign_predicate_pos_atoms.iter().enumerate().find(|(_, a)| self.is_ground_foreign_atom(a)) { + applied_foreign_predicates.insert(i); // Mark the atom as applied + let (pred, _, free_arguments) = self.foreign_predicate_atom_info(atom); + let plan = self.compute_foreign_predicate_ground_atom(atom, pred, &free_arguments); + (plan, 0) + } else { + panic!("[Internal Error] No foreign predicate atom is ground; should not happen"); + } + } } else { - // If there is no reduce, get it from the first arc + // If there is no reduce, find the first arc let first_arc = &arcs[0]; let node = Plan { bounded_vars: self.pos_atoms[first_arc.right].variable_args().cloned().collect(), @@ -209,7 +523,10 @@ impl<'a> QueryPlanContext<'a> { }; // Note: We always apply constraint first and then assigns - (try_apply_constraint(try_apply_assigns(node)), 1) + let node = self.try_apply_assigns(&mut applied_assigns, node); + let node = self.try_apply_constraint(&mut applied_constraints, node); + let node = self.try_apply_foreign_predicate_atom(&mut applied_foreign_predicates, node); + (node, 1) } } else { // If there is reduce, create a joined reduce @@ -229,16 +546,19 @@ impl<'a> QueryPlanContext<'a> 
{ }; node = Plan { bounded_vars: left.bounded_vars.union(&right_bounded_vars).cloned().collect(), - ram_node: HighRamNode::Join(Box::new(left), Box::new(right)), + ram_node: HighRamNode::join(left, right), }; } // Note: We always apply constraint first and then assigns - (try_apply_constraint(try_apply_assigns(node)), 0) + let node = self.try_apply_assigns(&mut applied_assigns, node); + let node = self.try_apply_constraint(&mut applied_constraints, node); + let node = self.try_apply_foreign_predicate_atom(&mut applied_foreign_predicates, node); + (node, 0) }; - // Stage 3. Iterate through all the arcs, build the tree from bottom-up - for arc in &arcs[start_id..] { + // ==== Stage 3. Iterate through all the arcs, build the tree from bottom-up ==== + for arc in &arcs[start_arc_id..] { // Build the simple tree if arc.left.is_empty() { // A node that is not related to any of the node before; need product @@ -248,7 +568,7 @@ impl<'a> QueryPlanContext<'a> { ram_node: HighRamNode::Ground(self.pos_atoms[arc.right].clone()), }; let new_bounded_vars = left.bounded_vars.union(&right.bounded_vars).cloned().collect(); - let new_ram_node = HighRamNode::Join(Box::new(left), Box::new(right)); + let new_ram_node = HighRamNode::join(left, right); fringe = Plan { bounded_vars: new_bounded_vars, ram_node: new_ram_node, @@ -269,23 +589,25 @@ impl<'a> QueryPlanContext<'a> { // Create joined node fringe = Plan { bounded_vars: left.bounded_vars.union(&right.bounded_vars).cloned().collect(), - ram_node: HighRamNode::Join(Box::new(left), Box::new(right)), + ram_node: HighRamNode::join(left, right), }; } // Note: We always apply constraint first and then assigns - fringe = try_apply_constraint(try_apply_assigns(fringe)); + fringe = self.try_apply_assigns(&mut applied_assigns, fringe); + fringe = self.try_apply_constraint(&mut applied_constraints, fringe); + fringe = self.try_apply_foreign_predicate_atom(&mut applied_foreign_predicates, fringe); } - // Apply negative atoms + // ==== Stage 4: Apply negative atoms ==== for neg_atom in &self.neg_atoms { let neg_node = Plan { - bounded_vars: neg_atom.atom.variable_args().cloned().collect(), - ram_node: HighRamNode::Ground(neg_atom.atom.clone()), + bounded_vars: neg_atom.variable_args().cloned().collect(), + ram_node: HighRamNode::Ground(neg_atom.clone()), }; fringe = Plan { bounded_vars: fringe.bounded_vars.clone(), - ram_node: HighRamNode::Antijoin(Box::new(fringe), Box::new(neg_node)), + ram_node: HighRamNode::antijoin(fringe, neg_node), }; } @@ -338,10 +660,7 @@ impl State { vec![] } else { let mut next_states: Vec = vec![]; - for (id, atom) in (0..ctx.pos_atoms.len()) - .filter(|i| !self.visited_atoms.contains(i)) - .map(|i| (i, &ctx.pos_atoms[i])) - { + for (id, atom) in ctx.pos_atoms.iter().enumerate().filter(|(i, _)| !self.visited_atoms.contains(i)) { for set in self.visited_atoms.iter().powerset() { let all_bounded_args = ctx.bounded_args_from_pos_atoms_set(&set); let bounded_vars = atom @@ -395,6 +714,13 @@ pub struct Plan { } impl Plan { + pub fn unit() -> Self { + Self { + bounded_vars: HashSet::new(), + ram_node: HighRamNode::Unit, + } + } + pub fn pretty_print(&self) { self.pretty_print_helper(0); } @@ -437,6 +763,17 @@ impl Plan { x.pretty_print_helper(depth + 1); y.pretty_print_helper(depth + 1); } + HighRamNode::ForeignPredicateGround(a) => { + println!("{}ForeignPredicateGround {{{}}}", prefix, a); + } + HighRamNode::ForeignPredicateConstraint(x, a) => { + println!("{}ForeignPredicateConstraint {{{}}}", prefix, a); + x.pretty_print_helper(depth + 1); + } 
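      // Editor's sketch (not part of this patch): the three foreign-predicate
      // variants printed here correspond to the three placement cases handled
      // in `get_query_plan` above, assuming a hypothetical built-in predicate
      // `range(lo, hi, x)` whose first two arguments are bounded:
      //
      //   // all arguments already bounded -> pure filter over the fringe
      //   let constrained = Plan {
      //     bounded_vars: fringe.bounded_vars.clone(),
      //     ram_node: HighRamNode::ForeignPredicateConstraint(Box::new(fringe), atom.clone()),
      //   };
      //
      //   // bounded arguments are all constants -> standalone generator
      //   let ground = Plan {
      //     bounded_vars: free_vars.clone(),
      //     ram_node: HighRamNode::ForeignPredicateGround(atom.clone()),
      //   };
      //
      //   // bounded arguments come from earlier atoms -> join against the fringe
      //   let joined = Plan {
      //     bounded_vars: fringe.bounded_vars.union(&free_vars).cloned().collect(),
      //     ram_node: HighRamNode::foreign_predicate_join(fringe, atom.clone()),
      //   };
      //
      // `fringe`, `atom`, and `free_vars` are stand-ins for the values the
      // planner threads through `get_query_plan`; the exact field values are
      // illustrative only.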
+ HighRamNode::ForeignPredicateJoin(x, a) => { + println!("{}ForeignPredicateJoin {{{}}}", prefix, a); + x.pretty_print_helper(depth + 1); + } } } } @@ -450,9 +787,32 @@ pub enum HighRamNode { Filter(Box, Vec), Join(Box, Box), Antijoin(Box, Box), + ForeignPredicateGround(Atom), + ForeignPredicateConstraint(Box, Atom), + ForeignPredicateJoin(Box, Atom), } impl HighRamNode { + /// Create a new FILTER high level ram node + pub fn filter(p1: Plan, cs: Vec) -> Self { + Self::Filter(Box::new(p1), cs) + } + + /// Create a new JOIN high level ram node + pub fn join(p1: Plan, p2: Plan) -> Self { + Self::Join(Box::new(p1), Box::new(p2)) + } + + /// Create a new ANTIJOIN high level ram node + pub fn antijoin(p1: Plan, p2: Plan) -> Self { + Self::Antijoin(Box::new(p1), Box::new(p2)) + } + + /// Create a new Foreign Predicate Join high level ram node + pub fn foreign_predicate_join(p1: Plan, a: Atom) -> Self { + Self::ForeignPredicateJoin(Box::new(p1), a) + } + pub fn direct_atom(&self) -> Option<&Atom> { match self { Self::Ground(a) => Some(a), diff --git a/core/src/compiler/back/scc.rs b/core/src/compiler/back/scc.rs index 5521edb..153f003 100644 --- a/core/src/compiler/back/scc.rs +++ b/core/src/compiler/back/scc.rs @@ -8,6 +8,7 @@ use std::collections::*; use super::{ast::*, BackCompileError}; +/// The type of a dependency graph edge: positive, negative, or aggregation #[derive(Debug, Clone)] pub enum DependencyGraphEdge { Positive, @@ -26,6 +27,7 @@ impl std::fmt::Display for DependencyGraphEdge { } impl DependencyGraphEdge { + /// Check if the edge needs to be stratified pub fn needs_stratification(&self) -> bool { match self { Self::Positive => false, @@ -34,6 +36,7 @@ impl DependencyGraphEdge { } } +/// A graph storing the dependencies between predicates #[derive(Clone)] pub struct DependencyGraph { graph: Graph, @@ -42,6 +45,7 @@ pub struct DependencyGraph { } impl DependencyGraph { + /// Create an empty dependency graph pub fn new() -> Self { Self { graph: Graph::new(), @@ -50,6 +54,7 @@ impl DependencyGraph { } } + /// Add a predicate to the dependency graph pub fn add_predicate(&mut self, predicate: &String) -> NodeIndex { if let Some(ni) = self.predicate_to_node_id.get(predicate) { ni.clone() @@ -60,10 +65,12 @@ impl DependencyGraph { } } + /// Get the node id of a predicate pub fn predicate_node(&self, predicate: &String) -> NodeIndex { self.predicate_to_node_id.get(predicate).unwrap().clone() } + /// Add a dependency between two predicates pub fn add_dependency(&mut self, src: &String, dst: &String, edge: DependencyGraphEdge) -> EdgeIndex { let src_id = self.predicate_to_node_id[src]; let dst_id = self.predicate_to_node_id[dst]; @@ -242,11 +249,15 @@ impl Program { match atom { Literal::Atom(a) => { let atom_predicate = &a.predicate; - graph.add_dependency(head_predicate, atom_predicate, E::Positive); + if !self.predicate_registry.contains(atom_predicate) { + graph.add_dependency(head_predicate, atom_predicate, E::Positive); + } } Literal::NegAtom(a) => { let atom_predicate = &a.atom.predicate; - graph.add_dependency(head_predicate, atom_predicate, E::Negative); + if !self.predicate_registry.contains(atom_predicate) { + graph.add_dependency(head_predicate, atom_predicate, E::Negative); + } } Literal::Reduce(r) => { let reduce_predicate = &r.body_formula.predicate; diff --git a/core/src/compiler/front/analysis.rs b/core/src/compiler/front/analysis.rs index db42757..65e89e1 100644 --- a/core/src/compiler/front/analysis.rs +++ b/core/src/compiler/front/analysis.rs @@ -1,10 +1,13 @@ use 
crate::common::foreign_function::ForeignFunctionRegistry; +use crate::common::foreign_predicate::ForeignPredicateRegistry; use super::analyzers::*; use super::*; +/// The front analysis object that stores all the analysis results and errors #[derive(Clone, Debug)] pub struct Analysis { + pub invalid_constant: InvalidConstantAnalyzer, pub invalid_wildcard: InvalidWildcardAnalyzer, pub input_files_analysis: InputFilesAnalysis, pub output_files_analysis: OutputFilesAnalysis, @@ -19,8 +22,13 @@ pub struct Analysis { } impl Analysis { - pub fn new(function_registry: &ForeignFunctionRegistry) -> Self { + /// Create a new front IR analysis object + pub fn new( + function_registry: &ForeignFunctionRegistry, + predicate_registry: &ForeignPredicateRegistry, + ) -> Self { Self { + invalid_constant: InvalidConstantAnalyzer::new(), invalid_wildcard: InvalidWildcardAnalyzer::new(), input_files_analysis: InputFilesAnalysis::new(), output_files_analysis: OutputFilesAnalysis::new(), @@ -28,9 +36,9 @@ impl Analysis { aggregation_analysis: AggregationAnalysis::new(), character_literal_analysis: CharacterLiteralAnalysis::new(), constant_decl_analysis: ConstantDeclAnalysis::new(), - head_relation_analysis: HeadRelationAnalysis::new(), - type_inference: TypeInference::new(function_registry), - boundness_analysis: BoundnessAnalysis::new(), + head_relation_analysis: HeadRelationAnalysis::new(predicate_registry), + type_inference: TypeInference::new(function_registry, predicate_registry), + boundness_analysis: BoundnessAnalysis::new(predicate_registry), demand_attr_analysis: DemandAttributeAnalysis::new(), } } @@ -43,6 +51,7 @@ impl Analysis { &mut self.aggregation_analysis, &mut self.character_literal_analysis, &mut self.constant_decl_analysis, + &mut self.invalid_constant, &mut self.invalid_wildcard, ); analyzers.walk_items(items); @@ -71,6 +80,7 @@ impl Analysis { pub fn dump_errors(&mut self, error_ctx: &mut FrontCompileError) { error_ctx.extend(&mut self.input_files_analysis.errors); + error_ctx.extend(&mut self.invalid_constant.errors); error_ctx.extend(&mut self.invalid_wildcard.errors); error_ctx.extend(&mut self.aggregation_analysis.errors); error_ctx.extend(&mut self.character_literal_analysis.errors); diff --git a/core/src/compiler/front/analyzers/boundness/boundness_analysis.rs b/core/src/compiler/front/analyzers/boundness/boundness_analysis.rs index 3500300..70b1a70 100644 --- a/core/src/compiler/front/analyzers/boundness/boundness_analysis.rs +++ b/core/src/compiler/front/analyzers/boundness/boundness_analysis.rs @@ -2,22 +2,30 @@ use std::collections::*; use super::super::*; use super::*; + +use crate::common::foreign_predicate::*; use crate::compiler::front::*; #[derive(Clone, Debug)] pub struct BoundnessAnalysis { + pub predicate_bindings: ForeignPredicateBindings, pub rule_contexts: HashMap, pub errors: Vec, } impl BoundnessAnalysis { - pub fn new() -> Self { + pub fn new(registry: &ForeignPredicateRegistry) -> Self { Self { + predicate_bindings: registry.into(), rule_contexts: HashMap::new(), errors: Vec::new(), } } + pub fn add_foreign_predicate(&mut self, fp: &F) { + self.predicate_bindings.add(fp) + } + pub fn get_rule_context(&self, loc: &Loc) -> Option<&RuleContext> { self.rule_contexts.get(loc).map(|(_, ctx, _)| ctx) } @@ -42,7 +50,7 @@ impl BoundnessAnalysis { }; // Compute the boundness - match ctx.compute_boundness(&bounded_exprs) { + match ctx.compute_boundness(&self.predicate_bindings, &bounded_exprs) { Ok(_) => {} Err(errs) => { self.errors.extend(errs); diff --git 
a/core/src/compiler/front/analyzers/boundness/context.rs b/core/src/compiler/front/analyzers/boundness/context.rs index 1f516eb..61bcc33 100644 --- a/core/src/compiler/front/analyzers/boundness/context.rs +++ b/core/src/compiler/front/analyzers/boundness/context.rs @@ -32,8 +32,12 @@ impl RuleContext { Self { head_vars, body } } - pub fn compute_boundness(&self, bounded_exprs: &Vec) -> Result, Vec> { - let bounded_vars = self.body.compute_boundness(bounded_exprs)?; + pub fn compute_boundness( + &self, + predicate_bindings: &ForeignPredicateBindings, + bounded_exprs: &Vec, + ) -> Result, Vec> { + let bounded_vars = self.body.compute_boundness(predicate_bindings, bounded_exprs)?; for (var_name, var_loc) in &self.head_vars { if !bounded_vars.contains(var_name) { let err = BoundnessAnalysisError::HeadExprUnbound { loc: var_loc.clone() }; @@ -72,16 +76,20 @@ impl DisjunctionContext { Self { conjuncts } } - pub fn compute_boundness(&self, bounded_exprs: &Vec) -> Result, Vec> { + pub fn compute_boundness( + &self, + predicate_bindings: &ForeignPredicateBindings, + bounded_exprs: &Vec, + ) -> Result, Vec> { if self.conjuncts.is_empty() { Ok(BTreeSet::new()) } else if self.conjuncts.len() == 1 { - self.conjuncts[0].compute_boundness(bounded_exprs) + self.conjuncts[0].compute_boundness(predicate_bindings, bounded_exprs) } else { - let set1 = self.conjuncts[0].compute_boundness(bounded_exprs)?; + let set1 = self.conjuncts[0].compute_boundness(predicate_bindings, bounded_exprs)?; let other_sets = self.conjuncts[1..] .iter() - .map(|c| c.compute_boundness(bounded_exprs)) + .map(|c| c.compute_boundness(predicate_bindings, bounded_exprs)) .collect::>, _>>()?; Ok( set1 @@ -141,13 +149,17 @@ impl ConjunctionContext { } } - pub fn compute_boundness(&self, bounded_exprs: &Vec) -> Result, Vec> { - let mut local_ctx = LocalBoundnessAnalysisContext::new(); + pub fn compute_boundness( + &self, + predicate_bindings: &ForeignPredicateBindings, + bounded_exprs: &Vec, + ) -> Result, Vec> { + let mut local_ctx = LocalBoundnessAnalysisContext::new(predicate_bindings); // First check if the aggregation's boundness is okay for agg_context in &self.agg_contexts { // The bounded variables inside the aggregation is part of the bounded vars - let bounded_args = agg_context.compute_boundness(bounded_exprs)?; + let bounded_args = agg_context.compute_boundness(predicate_bindings, bounded_exprs)?; local_ctx.bounded_variables.extend(bounded_args); } @@ -240,17 +252,21 @@ impl AggregationContext { } } - pub fn compute_boundness(&self, bounded_exprs: &Vec) -> Result, Vec> { + pub fn compute_boundness( + &self, + predicate_bindings: &ForeignPredicateBindings, + bounded_exprs: &Vec, + ) -> Result, Vec> { // Construct the bounded let mut bounded = HashSet::new(); // If group_by is presented, check the gruop_by binding variables are properly bounded if let Some((group_by_ctx, _, _)) = &self.group_by { - group_by_ctx.compute_boundness(bounded_exprs)?; + group_by_ctx.compute_boundness(predicate_bindings, bounded_exprs)?; } // Add all the bounded variables in the aggregation body - bounded.extend(self.joined_body.compute_boundness(bounded_exprs)?); + bounded.extend(self.joined_body.compute_boundness(predicate_bindings, bounded_exprs)?); // Remove the qualified variable for binding_name in &self.binding_vars { diff --git a/core/src/compiler/front/analyzers/boundness/dependency.rs b/core/src/compiler/front/analyzers/boundness/dependency.rs index c4f53a7..5316330 100644 --- a/core/src/compiler/front/analyzers/boundness/dependency.rs +++ 
b/core/src/compiler/front/analyzers/boundness/dependency.rs @@ -5,6 +5,9 @@ pub enum BoundnessDependency { /// Argument to a relation RelationArg(Loc), + /// Foreign predicate arguments: predicate, bounded arguments, and to-bound arguments + ForeignPredicateArgs(Vec, Vec), + /// Constant loc, is bounded Constant(Loc), diff --git a/core/src/compiler/front/analyzers/boundness/foreign.rs b/core/src/compiler/front/analyzers/boundness/foreign.rs new file mode 100644 index 0000000..002389b --- /dev/null +++ b/core/src/compiler/front/analyzers/boundness/foreign.rs @@ -0,0 +1,32 @@ +use std::collections::*; + +use crate::common::foreign_predicate::*; + +#[derive(Clone, Debug)] +pub struct ForeignPredicateBindings { + bindings: HashMap, +} + +impl ForeignPredicateBindings { + pub fn contains(&self, name: &str) -> bool { + self.bindings.contains_key(name) + } + + pub fn add(&mut self, fp: &F) { + self.bindings.insert(fp.name(), fp.binding_pattern()); + } + + pub fn get(&self, name: &str) -> Option<&BindingPattern> { + self.bindings.get(name) + } +} + +impl From<&ForeignPredicateRegistry> for ForeignPredicateBindings { + fn from(registry: &ForeignPredicateRegistry) -> Self { + let bindings = registry + .iter() + .map(|(name, pred)| (name.clone(), pred.binding_pattern())) + .collect(); + Self { bindings } + } +} diff --git a/core/src/compiler/front/analyzers/boundness/local.rs b/core/src/compiler/front/analyzers/boundness/local.rs index a0f81eb..abcf5ad 100644 --- a/core/src/compiler/front/analyzers/boundness/local.rs +++ b/core/src/compiler/front/analyzers/boundness/local.rs @@ -1,11 +1,13 @@ use std::collections::*; use super::*; + use crate::compiler::front::ast::*; use crate::compiler::front::visitor::*; #[derive(Clone, Debug)] -pub struct LocalBoundnessAnalysisContext { +pub struct LocalBoundnessAnalysisContext<'a> { + pub foreign_predicate_bindings: &'a ForeignPredicateBindings, pub expr_boundness: HashMap, pub dependencies: Vec, pub variable_locations: HashMap>, @@ -14,12 +16,19 @@ pub struct LocalBoundnessAnalysisContext { pub errors: Vec, } -impl NodeVisitor for LocalBoundnessAnalysisContext { +impl<'a> NodeVisitor for LocalBoundnessAnalysisContext<'a> { fn visit_atom(&mut self, atom: &Atom) { - for arg in atom.iter_arguments() { - let loc = arg.location().clone(); - let dep = BoundnessDependency::RelationArg(loc); + if let Some(binding) = self.foreign_predicate_bindings.get(atom.predicate()) { + let bounded = atom.iter_arguments().enumerate().filter_map(|(i, a)| if binding[i].is_bound() { Some(a.location().clone()) } else { None } ).collect(); + let to_bound = atom.iter_arguments().enumerate().filter_map(|(i, a)| if binding[i].is_free() { Some(a.location().clone()) } else { None } ).collect(); + let dep = BoundnessDependency::ForeignPredicateArgs(bounded, to_bound); self.dependencies.push(dep); + } else { + for arg in atom.iter_arguments() { + let loc = arg.location().clone(); + let dep = BoundnessDependency::RelationArg(loc); + self.dependencies.push(dep); + } } } @@ -94,9 +103,10 @@ impl NodeVisitor for LocalBoundnessAnalysisContext { } } -impl LocalBoundnessAnalysisContext { - pub fn new() -> Self { +impl<'a> LocalBoundnessAnalysisContext<'a> { + pub fn new(foreign_predicate_bindings: &'a ForeignPredicateBindings) -> Self { Self { + foreign_predicate_bindings, expr_boundness: HashMap::new(), dependencies: Vec::new(), variable_locations: HashMap::new(), @@ -130,6 +140,13 @@ impl LocalBoundnessAnalysisContext { RelationArg(l) => { update(&mut self.expr_boundness, l, true); } + 
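        // Editor's sketch (not part of this patch): `ForeignPredicateArgs`
        // records the binding pattern of a foreign predicate as a boundness
        // dependency. Assuming a built-in `string_chars(s, i, c)` whose first
        // argument must be bound (pattern "bff", with `i` and `c` free),
        // `visit_atom` above would record roughly:
        //
        //   let dep = BoundnessDependency::ForeignPredicateArgs(
        //     vec![loc_of_s],            // bounded argument locations
        //     vec![loc_of_i, loc_of_c],  // to-bound argument locations
        //   );
        //
        // so that, during propagation, `i` and `c` only become bounded once
        // `s` is known to be bounded. The location names and the exact
        // signature of `string_chars` are assumptions for illustration; the
        // propagation itself is the match arm that follows.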
ForeignPredicateArgs(bounded_args, to_bound_args) => { + if bounded_args.iter().all(|l| get(&self.expr_boundness, l)) { + for to_bound_arg in to_bound_args { + update(&mut self.expr_boundness, to_bound_arg, true) + } + } + } Constant(l) => { update(&mut self.expr_boundness, l, true); } diff --git a/core/src/compiler/front/analyzers/boundness/mod.rs b/core/src/compiler/front/analyzers/boundness/mod.rs index b2da031..cb92477 100644 --- a/core/src/compiler/front/analyzers/boundness/mod.rs +++ b/core/src/compiler/front/analyzers/boundness/mod.rs @@ -2,6 +2,7 @@ mod boundness_analysis; mod context; mod dependency; mod error; +mod foreign; mod local; use super::super::utils::*; @@ -9,4 +10,5 @@ pub use boundness_analysis::*; pub use context::*; pub use dependency::*; pub use error::*; +pub use foreign::*; pub use local::*; diff --git a/core/src/compiler/front/analyzers/constant_decl.rs b/core/src/compiler/front/analyzers/constant_decl.rs index f426f1d..48d31d1 100644 --- a/core/src/compiler/front/analyzers/constant_decl.rs +++ b/core/src/compiler/front/analyzers/constant_decl.rs @@ -5,6 +5,12 @@ use super::super::error::*; use super::super::utils::*; use super::super::*; +/// Constant declaration analysis +/// +/// Analyzes the constant declarations coming from `ConstAssignment` and `EnumTypeDecl`. +/// After walking through AST, the analysis checks whether there is duplicated constant +/// declarations, unknown constants, and etc. +/// It stores the locations and other information where a constant is used and declared. #[derive(Clone, Debug)] pub struct ConstantDeclAnalysis { pub variables: HashMap, Constant)>, @@ -13,6 +19,7 @@ pub struct ConstantDeclAnalysis { } impl ConstantDeclAnalysis { + /// Create a new analysis pub fn new() -> Self { Self { variables: HashMap::new(), @@ -21,10 +28,15 @@ impl ConstantDeclAnalysis { } } + /// Get the variable information stored in the analysis, including + /// its declaration location, its type, and the constant it is associated with. + /// `None` is returned if such variable does not exist. pub fn get_variable(&self, var: &str) -> Option<&(Loc, Option, Constant)> { self.variables.get(var) } + /// Given a location where a constant variable is used, find the type of that variable. + /// `None` is returned if this location is not recorded or the variable is not annotated with a type. pub fn loc_of_const_type(&self, loc: &Loc) -> Option { self .variable_use @@ -46,6 +58,81 @@ impl ConstantDeclAnalysis { }) .collect() } + + pub fn process_enum_type_decl(&mut self, etd: &ast::EnumTypeDecl) -> Result<(), ConstantDeclError> { + let extract_value = |member: &ast::EnumTypeMember, prev_max: Option| -> Result { + // First check if there is an integer number assignment to the enum + match member.assigned_number() { + Some(c) => match &c.node { + // If there is, we check if the integer is greater than or equal to zero and greater than the previous maximum + ast::ConstantNode::Integer(i) if *i >= 0 => { + let i = *i; + // Check if we have a previous number already + if let Some(prev_max) = prev_max { + if i > prev_max { + // If the number is greater than previous number, then ok to directly assign the number + return Ok(i); + } else { + // If the number is not greater, then this enum value ID is invalid + return Err(ConstantDeclError::EnumIDAlreadyAssigned { + curr_name: member.name().to_string(), + id: i, + loc: member.location().clone(), + }); + } + } else { + // If there is no previous max, then directly give it `i`. 
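          // Editor's note (illustrative, not part of this patch): with these
          // rules, and assuming an enum declaration written along the lines of
          //
          //   type Color = RED | GREEN = 5 | BLUE
          //
          // the members are numbered RED = 0 (no previous maximum), GREEN = 5
          // (explicitly assigned and larger than 0), and BLUE = 6 (implicit
          // increment of the running maximum). A declaration such as
          // `A = 3 | B = 2` is rejected with `EnumIDAlreadyAssigned`, since 2
          // is not greater than the running maximum 3.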
+ return Ok(i) + } + } + _ => { + // We don't care other cases + } + } + _ => {} + }; + + // If the assignment is not presented, we simply increment the previous maximum value + if let Some(prev_max) = prev_max { + return Ok(prev_max + 1); + } else { + return Ok(0); + } + }; + + let mut process_member = |member: &ast::EnumTypeMember, id: i64| -> Result<(), ConstantDeclError> { + if let Some((first_decl_loc, _, _)) = self.variables.get(member.name()) { + Err(ConstantDeclError::DuplicatedConstant { + name: member.name().to_string(), + first_decl: first_decl_loc.clone(), + second_decl: member.location().clone(), + }) + } else { + // Then store the variable into the storage + self.variables.insert( + member.name().to_string(), + (member.location().clone(), Some(Type::usize()), Constant::integer(id as i64)) + ); + Ok(()) + } + }; + + // Go through all the members + let mut members_iterator = etd.iter_members(); + + // First process the first member + let first_member = members_iterator.next().unwrap(); // Unwrap is ok since there has to be at least two components + let mut curr_id = extract_value(first_member, None)?; + process_member(first_member, curr_id)?; + + // Then process the rest + while let Some(curr_member) = members_iterator.next() { + curr_id = extract_value(curr_member, Some(curr_id))?; + process_member(curr_member, curr_id)?; + } + + Ok(()) + } } impl NodeVisitor for ConstantDeclAnalysis { @@ -66,6 +153,12 @@ impl NodeVisitor for ConstantDeclAnalysis { } } + fn visit_enum_type_decl(&mut self, etd: &ast::EnumTypeDecl) { + if let Err(e) = self.process_enum_type_decl(etd) { + self.errors.push(e); + } + } + fn visit_constant_set_tuple(&mut self, cst: &ConstantSetTuple) { for c in cst.iter_constants() { if let Some(v) = c.variable() { @@ -132,6 +225,11 @@ pub enum ConstantDeclError { name: String, loc: Loc, }, + EnumIDAlreadyAssigned { + curr_name: String, + id: i64, + loc: Loc, + }, } impl FrontCompileErrorTrait for ConstantDeclError { @@ -163,6 +261,9 @@ impl FrontCompileErrorTrait for ConstantDeclError { Self::UnknownConstantVariable { name, loc } => { format!("unknown variable `{}`:\n{}", name, loc.report(src)) } + Self::EnumIDAlreadyAssigned { curr_name, id, loc } => { + format!("the enum ID `{}` for variant `{}` has already been assigned\n{}", id, curr_name, loc.report(src)) + } } } } diff --git a/core/src/compiler/front/analyzers/head_relation.rs b/core/src/compiler/front/analyzers/head_relation.rs index 3c06a6c..520450a 100644 --- a/core/src/compiler/front/analyzers/head_relation.rs +++ b/core/src/compiler/front/analyzers/head_relation.rs @@ -1,5 +1,7 @@ use std::collections::*; +use crate::common::foreign_predicate::*; + use super::super::utils::*; use super::super::*; @@ -11,14 +13,19 @@ pub struct HeadRelationAnalysis { } impl HeadRelationAnalysis { - pub fn new() -> Self { + pub fn new(foreign_predicate_registry: &ForeignPredicateRegistry) -> Self { + let declared_relations = foreign_predicate_registry.iter().map(|(_, p)| p.name().to_string()).collect(); Self { errors: vec![], used_relations: HashMap::new(), - declared_relations: HashSet::new(), + declared_relations, } } + pub fn add_foreign_predicate(&mut self, fp: &F) { + self.declared_relations.insert(fp.name().to_string()); + } + pub fn compute_errors(&mut self) { let used_relations_set = self.used_relations.keys().cloned().collect::>(); for r in used_relations_set.difference(&self.declared_relations) { @@ -52,7 +59,7 @@ impl NodeVisitor for HeadRelationAnalysis { fn visit_query(&mut self, qd: &ast::Query) { self 
.used_relations - .insert(qd.relation_name().to_string(), qd.location().clone()); + .insert(qd.create_relation_name().to_string(), qd.location().clone()); } fn visit_atom(&mut self, a: &ast::Atom) { diff --git a/core/src/compiler/front/analyzers/invalid_constant.rs b/core/src/compiler/front/analyzers/invalid_constant.rs new file mode 100644 index 0000000..f2cd34f --- /dev/null +++ b/core/src/compiler/front/analyzers/invalid_constant.rs @@ -0,0 +1,48 @@ +use super::super::*; + +#[derive(Clone, Debug)] +pub struct InvalidConstantAnalyzer { + pub errors: Vec, +} + +impl InvalidConstantAnalyzer { + pub fn new() -> Self { + Self { errors: Vec::new() } + } +} + +impl NodeVisitor for InvalidConstantAnalyzer { + fn visit_constant(&mut self, constant: &Constant) { + match &constant.node { + ConstantNode::Invalid(message) => { + self.errors.push(InvalidConstantError::InvalidConstant { + loc: constant.location().clone(), + message: message.clone(), + }); + } + _ => {} + } + } +} + +#[derive(Clone, Debug)] +pub enum InvalidConstantError { + InvalidConstant { + loc: AstNodeLocation, + message: String, + }, +} + +impl FrontCompileErrorTrait for InvalidConstantError { + fn error_type(&self) -> FrontCompileErrorType { + FrontCompileErrorType::Error + } + + fn report(&self, src: &Sources) -> String { + match self { + Self::InvalidConstant { loc, message } => { + format!("Invalid constant: {}\n{}", message, loc.report(src)) + } + } + } +} diff --git a/core/src/compiler/front/analyzers/mod.rs b/core/src/compiler/front/analyzers/mod.rs index e07d550..96b4a35 100644 --- a/core/src/compiler/front/analyzers/mod.rs +++ b/core/src/compiler/front/analyzers/mod.rs @@ -6,6 +6,7 @@ pub mod demand_attr; pub mod head_relation; pub mod hidden_relation; pub mod input_files; +pub mod invalid_constant; pub mod invalid_wildcard; pub mod output_files; pub mod type_inference; @@ -18,6 +19,7 @@ pub use demand_attr::DemandAttributeAnalysis; pub use head_relation::HeadRelationAnalysis; pub use hidden_relation::HiddenRelationAnalysis; pub use input_files::InputFilesAnalysis; +pub use invalid_constant::InvalidConstantAnalyzer; pub use invalid_wildcard::InvalidWildcardAnalyzer; pub use output_files::OutputFilesAnalysis; pub use type_inference::TypeInference; @@ -29,6 +31,7 @@ pub mod errors { pub use super::demand_attr::DemandAttributeError; pub use super::head_relation::HeadRelationError; pub use super::input_files::InputFilesError; + pub use super::invalid_constant::InvalidConstantError; pub use super::invalid_wildcard::InvalidWildcardError; pub use super::output_files::OutputFilesError; pub use super::type_inference::TypeInferenceError; diff --git a/core/src/compiler/front/analyzers/output_files.rs b/core/src/compiler/front/analyzers/output_files.rs index 4db6dc7..bb1456e 100644 --- a/core/src/compiler/front/analyzers/output_files.rs +++ b/core/src/compiler/front/analyzers/output_files.rs @@ -98,7 +98,7 @@ impl OutputFilesAnalysis { impl NodeVisitor for OutputFilesAnalysis { fn visit_query_decl(&mut self, qd: &QueryDecl) { - self.process_attributes(qd.query().relation_name(), qd.attributes()); + self.process_attributes(qd.query().create_relation_name(), qd.attributes()); } } diff --git a/core/src/compiler/front/analyzers/type_inference/error.rs b/core/src/compiler/front/analyzers/type_inference/error.rs index 779ba17..c8f9ae7 100644 --- a/core/src/compiler/front/analyzers/type_inference/error.rs +++ b/core/src/compiler/front/analyzers/type_inference/error.rs @@ -53,6 +53,11 @@ pub enum TypeInferenceError { source_loc: 
AstNodeLocation, access_loc: AstNodeLocation, }, + InvalidForeignPredicateArgIndex { + predicate: String, + index: usize, + access_loc: AstNodeLocation, + }, ConstantSetArityMismatch { predicate: String, decl_loc: AstNodeLocation, @@ -62,11 +67,32 @@ pub enum TypeInferenceError { expected: ValueType, found: TypeSet, }, + BadEnumValueKind { + found: &'static str, + loc: AstNodeLocation, + }, + NegativeEnumValue { + found: i64, + loc: AstNodeLocation, + }, CannotUnifyTypes { t1: TypeSet, t2: TypeSet, loc: Option, }, + CannotUnifyForeignPredicateArgument { + pred: String, + i: usize, + expected_ty: TypeSet, + actual_ty: TypeSet, + loc: AstNodeLocation, + }, + NoMatchingTripletRule { + op1_ty: TypeSet, + op2_ty: TypeSet, + e_ty: TypeSet, + location: AstNodeLocation, + }, CannotUnifyVariables { v1: String, t1: TypeSet, @@ -100,6 +126,14 @@ pub enum TypeInferenceError { num_binding_vars: usize, loc: AstNodeLocation, }, + CannotRedefineForeignPredicate { + pred: String, + loc: AstNodeLocation, + }, + CannotQueryForeignPredicate { + pred: String, + loc: AstNodeLocation, + }, } impl TypeInferenceError { @@ -196,6 +230,16 @@ impl FrontCompileErrorTrait for TypeInferenceError { index, predicate, source_loc.report(src), access_loc.report(src) ) } + Self::InvalidForeignPredicateArgIndex { + predicate, + index, + access_loc, + } => { + format!( + "Invalid `{}`-th argument for foreign predicate `{}`:\n{}", + index, predicate, access_loc.report(src) + ) + } Self::ConstantSetArityMismatch { predicate, mismatch_tuple_loc, @@ -215,6 +259,20 @@ impl FrontCompileErrorTrait for TypeInferenceError { found.location().report(src) ) } + Self::BadEnumValueKind { found, loc } => { + format!( + "bad enum value. Expected unsigned integers, found `{}`\n{}", + found, + loc.report(src), + ) + } + Self::NegativeEnumValue { found, loc } => { + format!( + "enum value `{}` found to be negative. 
Expected unsigned integers\n{}", + found, + loc.report(src), + ) + } Self::CannotUnifyTypes { t1, t2, loc } => match loc { Some(l) => { format!("cannot unify types `{}` and `{}` in\n{}\nwhere the first is inferred here\n{}\nand the second is inferred here\n{}", t1, t2, l.report(src), t1.location().report(src), t2.location().report(src)) @@ -225,7 +283,17 @@ impl FrontCompileErrorTrait for TypeInferenceError { t1, t2, t1.location().report(src), t2.location().report(src) ) } - }, + } + Self::CannotUnifyForeignPredicateArgument { pred, i, expected_ty, actual_ty, loc } => { + format!( + "cannot unify the type of {}-th argument of foreign predicate `{}`, expected type `{}`, found `{}`:\n{}", + i, + pred, + expected_ty, + actual_ty, + loc.report(src), + ) + } Self::CannotUnifyVariables { v1, t1, v2, t2, loc } => { format!( "cannot unify variable types: `{}` has `{}` type, `{}` has `{}` type, but they should be unified\n{}", @@ -236,6 +304,15 @@ impl FrontCompileErrorTrait for TypeInferenceError { loc.report(src) ) } + Self::NoMatchingTripletRule { op1_ty, op2_ty, e_ty, location } => { + format!( + "no matching rule found; two operands have type `{}` and `{}`, while the expression has type `{}`:\n{}", + op1_ty, + op2_ty, + e_ty, + location.report(src) + ) + } Self::CannotTypeCast { t1, t2, loc } => { format!("cannot cast type from `{}` to `{}`\n{}", t1, t2, loc.report(src)) } @@ -284,6 +361,20 @@ impl FrontCompileErrorTrait for TypeInferenceError { num_output_vars, num_binding_vars, loc.report(src) ) } + Self::CannotRedefineForeignPredicate { pred, loc } => { + format!( + "the predicate `{}` is being defined here, but it is also a foreign predicate which cannot be populated\n{}", + pred, + loc.report(src), + ) + } + Self::CannotQueryForeignPredicate { pred, loc } => { + format!( + "the foreign predicate `{}` cannot be queried\n{}", + pred, + loc.report(src), + ) + } } } } diff --git a/core/src/compiler/front/analyzers/type_inference/foreign_function.rs b/core/src/compiler/front/analyzers/type_inference/foreign_function.rs new file mode 100644 index 0000000..7d94ce8 --- /dev/null +++ b/core/src/compiler/front/analyzers/type_inference/foreign_function.rs @@ -0,0 +1,172 @@ +use std::collections::*; + +use crate::common::foreign_function::*; +use crate::common::value_type::*; + +use super::*; + +/// Argument type of a function, which could be a generic type parameter or a type set (including base type) +#[derive(Clone, Debug)] +pub enum FunctionArgumentType { + Generic(usize), + TypeSet(TypeSet), +} + +impl From for FunctionArgumentType { + fn from(value: ForeignFunctionParameterType) -> Self { + match value { + ForeignFunctionParameterType::BaseType(ty) => Self::TypeSet(TypeSet::base(ty)), + ForeignFunctionParameterType::TypeFamily(ty) => Self::TypeSet(TypeSet::from(ty)), + ForeignFunctionParameterType::Generic(i) => Self::Generic(i), + } + } +} + +impl FunctionArgumentType { + pub fn is_generic(&self) -> bool { + match self { + Self::Generic(_) => false, + _ => true, + } + } +} + +/// Return type of a function, which need to be a generic type parameter or a base type (cannot be type set) +#[derive(Clone, Debug)] +pub enum FunctionReturnType { + Generic(usize), + BaseType(ValueType), +} + +impl From for FunctionReturnType { + fn from(value: ForeignFunctionParameterType) -> Self { + match value { + ForeignFunctionParameterType::BaseType(ty) => Self::BaseType(ty), + ForeignFunctionParameterType::Generic(i) => Self::Generic(i), + _ => panic!("Return type cannot be of type family"), + } + } +} + +/// 
The function type +#[derive(Clone, Debug)] +pub struct FunctionType { + /// Generic type parameters + pub generic_type_parameters: Vec, + + /// Static argument types + pub static_argument_types: Vec, + + /// Optional argument types + pub optional_argument_types: Vec, + + /// Variable argument type + pub variable_argument_type: Option, + + /// Return type + pub return_type: FunctionReturnType, +} + +impl From<&F> for FunctionType { + fn from(f: &F) -> Self { + Self { + generic_type_parameters: f.generic_type_parameters().into_iter().map(TypeSet::from).collect(), + static_argument_types: f + .static_argument_types() + .into_iter() + .map(FunctionArgumentType::from) + .collect(), + optional_argument_types: f + .optional_argument_types() + .into_iter() + .map(FunctionArgumentType::from) + .collect(), + variable_argument_type: f.optional_variable_argument_type().map(FunctionArgumentType::from), + return_type: FunctionReturnType::from(f.return_type()), + } + } +} + +impl FunctionType { + /// Get the number of static arguments + pub fn num_static_arguments(&self) -> usize { + self.static_argument_types.len() + } + + /// Get the number of optional arguments + pub fn num_optional_arguments(&self) -> usize { + self.optional_argument_types.len() + } + + /// Check whether this function has variable arguments + pub fn has_variable_arguments(&self) -> bool { + self.variable_argument_type.is_some() + } + + /// Check if the given `num_args` is acceptable by the function type + pub fn is_valid_num_args(&self, num_args: usize) -> bool { + // First, there should be at least `len(static_argument_types)` arguments + if num_args < self.num_static_arguments() { + return false; + } + + // Then, we compute if there is a right amount of optional arguments + let num_optional_args = num_args - self.num_static_arguments(); + if !self.has_variable_arguments() { + // If there is no variable arguments, then the #provided optional arguments should be <= #expected optional arguments + if num_optional_args > self.num_optional_arguments() { + return false; + } + } + + true + } + + /// Get the type of i-th argument + pub fn type_of_ith_argument(&self, i: usize) -> Option { + if i < self.num_static_arguments() { + Some(self.static_argument_types[i].clone()) + } else { + let optional_argument_id = i - self.num_static_arguments(); + if optional_argument_id < self.num_optional_arguments() { + Some(self.optional_argument_types[optional_argument_id].clone()) + } else if let Some(var_arg_type) = &self.variable_argument_type { + Some(var_arg_type.clone()) + } else { + None + } + } + } +} + +/// The registry holding all the foreign function types +#[derive(Clone, Debug)] +pub struct FunctionTypeRegistry { + pub function_types: HashMap, +} + +impl FunctionTypeRegistry { + pub fn empty() -> Self { + Self { + function_types: HashMap::new(), + } + } + + pub fn from_foreign_function_registry(foreign_function_registry: &ForeignFunctionRegistry) -> Self { + let mut type_registry = Self::empty(); + for (_, ff) in foreign_function_registry { + let name = ff.name(); + let func_type = FunctionType::from(ff); + type_registry.add_function_type(name, func_type); + } + type_registry + } + + pub fn add_function_type(&mut self, name: String, f: FunctionType) { + self.function_types.insert(name, f); + } + + pub fn get(&self, function_name: &str) -> Option<&FunctionType> { + self.function_types.get(function_name) + } +} diff --git a/core/src/compiler/front/analyzers/type_inference/foreign_predicate.rs 
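// Editor's sketch (not part of this patch): how `FunctionType` answers the two
// questions type inference asks of a foreign function. Assuming `$substring`
// were registered with two static arguments (`String`, `usize`), one optional
// `usize` argument, and a `String` return type, then:
//
//   let ty = registry.get("substring").unwrap();
//   assert!(ty.is_valid_num_args(2) && ty.is_valid_num_args(3));
//   assert!(!ty.is_valid_num_args(1));        // fewer than the static arguments
//   assert!(!ty.is_valid_num_args(4));        // more than static + optional
//   let third = ty.type_of_ith_argument(2);   // the optional `usize` argument
//
// A variadic function such as `$hash` would instead report
// `has_variable_arguments() == true`, so every argument index beyond the
// static ones resolves to the variable-argument type. The `$substring`
// signature above is an assumption made purely for illustration.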
b/core/src/compiler/front/analyzers/type_inference/foreign_predicate.rs new file mode 100644 index 0000000..932aad7 --- /dev/null +++ b/core/src/compiler/front/analyzers/type_inference/foreign_predicate.rs @@ -0,0 +1,77 @@ +use std::collections::*; + +use crate::common::value_type::*; +use crate::common::foreign_predicate::*; + +/// The type of a foreign predicate. +/// Essentially a list of basic types. +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub struct PredicateType { + pub arguments: Vec, +} + +impl From<&P> for PredicateType { + fn from(p: &P) -> Self { + Self { + arguments: p.argument_types(), + } + } +} + +impl std::ops::Index for PredicateType { + type Output = ValueType; + + fn index(&self, index: usize) -> &Self::Output { + &self.arguments[index] + } +} + +impl PredicateType { + pub fn len(&self) -> usize { + self.arguments.len() + } + + pub fn iter<'a>(&'a self) -> std::slice::Iter<'a, ValueType> { + self.arguments.iter() + } +} + +/// Predicate type registry that stores information about foreign +/// predicates and their types +#[derive(Clone, Debug)] +pub struct PredicateTypeRegistry { + pub predicate_types: HashMap, +} + +impl PredicateTypeRegistry { + /// Create a new empty predicate type registry + pub fn empty() -> Self { + Self { + predicate_types: HashMap::new(), + } + } + + /// Create a new predicate type registry + pub fn from_foreign_predicate_registry(foreign_predicate_registry: &ForeignPredicateRegistry) -> Self { + let mut type_registry = Self::empty(); + for (_, fp) in foreign_predicate_registry { + type_registry.add_foreign_predicate(fp) + } + type_registry + } + + /// Add a new foreign predicate to the predicate type registry + pub fn add_foreign_predicate(&mut self, p: &P) { + self.predicate_types.insert(p.name(), PredicateType::from(p)); + } + + /// Check if the registry contains a predicate + pub fn contains_predicate(&self, p: &str) -> bool { + self.predicate_types.contains_key(p) + } + + /// Get a predicate type + pub fn get(&self, p: &str) -> Option<&PredicateType> { + self.predicate_types.get(p) + } +} diff --git a/core/src/compiler/front/analyzers/type_inference/local.rs b/core/src/compiler/front/analyzers/type_inference/local.rs index b3d272e..bc428fa 100644 --- a/core/src/compiler/front/analyzers/type_inference/local.rs +++ b/core/src/compiler/front/analyzers/type_inference/local.rs @@ -1,6 +1,7 @@ use std::collections::*; use super::*; +use crate::common::binary_op::BinaryOp; use crate::common::value_type::*; use crate::compiler::front::*; @@ -112,6 +113,7 @@ impl LocalTypeInferenceContext { constant_types: &HashMap, inferred_relation_types: &HashMap, Loc)>, function_type_registry: &FunctionTypeRegistry, + predicate_type_registry: &PredicateTypeRegistry, inferred_expr_types: &mut HashMap, ) -> Result<(), TypeInferenceError> { for unif in &self.unifications { @@ -120,6 +122,7 @@ impl LocalTypeInferenceContext { constant_types, inferred_relation_types, function_type_registry, + predicate_type_registry, inferred_expr_types, )?; } @@ -212,7 +215,7 @@ impl LocalTypeInferenceContext { if !tys.is_empty() { let ty = TypeSet::unify_type_sets(tys)?; let arg_types = &mut inferred_relation_types.get_mut(predicate).unwrap().0; - arg_types[(*i)] = ty; + arg_types[*i] = ty; } } @@ -425,30 +428,18 @@ impl NodeVisitor for LocalTypeInferenceContext { } fn visit_binary_expr(&mut self, b: &BinaryExpr) { - let unif = if b.op().is_arith() { - Unification::AddSubMulDivMod( - b.op1().location().clone(), - b.op2().location().clone(), - b.location().clone(), - ) - } else 
if b.op().is_logical() { - Unification::AndOrXor( - b.op1().location().clone(), - b.op2().location().clone(), - b.location().clone(), - ) - } else if b.op().is_eq_neq() { - Unification::EqNeq( - b.op1().location().clone(), - b.op2().location().clone(), - b.location().clone(), - ) - } else { - Unification::LtLeqGtGeq( - b.op1().location().clone(), - b.op2().location().clone(), - b.location().clone(), - ) + let op1 = b.op1().location().clone(); + let op2 = b.op2().location().clone(); + let loc = b.location().clone(); + let unif = match b.op().node { + BinaryOp::Add => Unification::Add(op1, op2, loc), + BinaryOp::Sub => Unification::Sub(op1, op2, loc), + BinaryOp::Mul => Unification::Mult(op1, op2, loc), + BinaryOp::Div => Unification::Div(op1, op2, loc), + BinaryOp::Mod => Unification::Mod(op1, op2, loc), + BinaryOp::And | BinaryOp::Or | BinaryOp::Xor => Unification::AndOrXor(op1, op2, loc), + BinaryOp::Eq | BinaryOp::Neq => Unification::EqNeq(op1, op2, loc), + BinaryOp::Lt | BinaryOp::Leq | BinaryOp::Gt | BinaryOp::Geq => Unification::LtLeqGtGeq(op1, op2, loc), }; self.unifications.push(unif); } diff --git a/core/src/compiler/front/analyzers/type_inference/mod.rs b/core/src/compiler/front/analyzers/type_inference/mod.rs index 41ee9d0..7f2c4d8 100644 --- a/core/src/compiler/front/analyzers/type_inference/mod.rs +++ b/core/src/compiler/front/analyzers/type_inference/mod.rs @@ -1,8 +1,10 @@ //! # Type inference analysis mod error; -mod function; +mod foreign_function; +mod foreign_predicate; mod local; +mod operator_rules; mod type_inference; mod type_set; mod unification; @@ -10,8 +12,10 @@ mod unification; use super::super::utils::*; pub use error::*; -pub use function::*; +pub use foreign_function::*; +pub use foreign_predicate::*; pub use local::*; +pub use operator_rules::*; pub use type_inference::*; pub use type_set::*; pub use unification::*; diff --git a/core/src/compiler/front/analyzers/type_inference/operator_rules.rs b/core/src/compiler/front/analyzers/type_inference/operator_rules.rs new file mode 100644 index 0000000..6d36125 --- /dev/null +++ b/core/src/compiler/front/analyzers/type_inference/operator_rules.rs @@ -0,0 +1,137 @@ +use lazy_static::lazy_static; + +use crate::common::value_type::*; + +lazy_static! 
{ + pub static ref ADD_TYPING_RULES: Vec<(ValueType, ValueType, ValueType)> = { + use ValueType::*; + vec![ + (I8, I8, I8), + (I16, I16, I16), + (I32, I32, I32), + (I64, I64, I64), + (I128, I128, I128), + (ISize, ISize, ISize), + (U8, U8, U8), + (U16, U16, U16), + (U32, U32, U32), + (U64, U64, U64), + (U128, U128, U128), + (USize, USize, USize), + (F32, F32, F32), + (F64, F64, F64), + (String, String, String), + (DateTime, Duration, DateTime), + (Duration, DateTime, DateTime), + (Duration, Duration, Duration), + ] + }; + + pub static ref SUB_TYPING_RULES: Vec<(ValueType, ValueType, ValueType)> = { + use ValueType::*; + vec![ + (I8, I8, I8), + (I16, I16, I16), + (I32, I32, I32), + (I64, I64, I64), + (I128, I128, I128), + (ISize, ISize, ISize), + (U8, U8, U8), + (U16, U16, U16), + (U32, U32, U32), + (U64, U64, U64), + (U128, U128, U128), + (USize, USize, USize), + (F32, F32, F32), + (F64, F64, F64), + (DateTime, Duration, DateTime), + (DateTime, DateTime, Duration), + (Duration, Duration, Duration), + ] + }; + + pub static ref MULT_TYPING_RULES: Vec<(ValueType, ValueType, ValueType)> = { + use ValueType::*; + vec![ + (I8, I8, I8), + (I16, I16, I16), + (I32, I32, I32), + (I64, I64, I64), + (I128, I128, I128), + (ISize, ISize, ISize), + (U8, U8, U8), + (U16, U16, U16), + (U32, U32, U32), + (U64, U64, U64), + (U128, U128, U128), + (USize, USize, USize), + (F32, F32, F32), + (F64, F64, F64), + (Duration, I32, Duration), + (I32, Duration, Duration), + ] + }; + + pub static ref DIV_TYPING_RULES: Vec<(ValueType, ValueType, ValueType)> = { + use ValueType::*; + vec![ + (I8, I8, I8), + (I16, I16, I16), + (I32, I32, I32), + (I64, I64, I64), + (I128, I128, I128), + (ISize, ISize, ISize), + (U8, U8, U8), + (U16, U16, U16), + (U32, U32, U32), + (U64, U64, U64), + (U128, U128, U128), + (USize, USize, USize), + (F32, F32, F32), + (F64, F64, F64), + (Duration, I32, Duration), + ] + }; + + pub static ref MOD_TYPING_RULES: Vec<(ValueType, ValueType, ValueType)> = { + use ValueType::*; + vec![ + (I8, I8, I8), + (I16, I16, I16), + (I32, I32, I32), + (I64, I64, I64), + (I128, I128, I128), + (ISize, ISize, ISize), + (U8, U8, U8), + (U16, U16, U16), + (U32, U32, U32), + (U64, U64, U64), + (U128, U128, U128), + (USize, USize, USize), + (F32, F32, F32), + (F64, F64, F64), + ] + }; + + pub static ref COMPARE_TYPING_RULES: Vec<(ValueType, ValueType)> = { + use ValueType::*; + vec![ + (I8, I8), + (I16, I16), + (I32, I32), + (I64, I64), + (I128, I128), + (ISize, ISize), + (U8, U8), + (U16, U16), + (U32, U32), + (U64, U64), + (U128, U128), + (USize, USize), + (F32, F32), + (F64, F64), + (Duration, Duration), + (DateTime, DateTime), + ] + }; +} diff --git a/core/src/compiler/front/analyzers/type_inference/type_inference.rs b/core/src/compiler/front/analyzers/type_inference/type_inference.rs index 53890f4..f350fcc 100644 --- a/core/src/compiler/front/analyzers/type_inference/type_inference.rs +++ b/core/src/compiler/front/analyzers/type_inference/type_inference.rs @@ -1,6 +1,7 @@ use std::collections::*; -use crate::common::foreign_function::ForeignFunctionRegistry; +use crate::common::foreign_function::*; +use crate::common::foreign_predicate::*; use crate::common::tuple_type::*; use crate::common::value_type::*; use crate::compiler::front::*; @@ -11,7 +12,8 @@ use super::*; pub struct TypeInference { pub custom_types: HashMap, pub constant_types: HashMap, - pub function_type_registry: FunctionTypeRegistry, + pub foreign_function_type_registry: FunctionTypeRegistry, + pub foreign_predicate_type_registry: 
PredicateTypeRegistry, pub relation_type_decl_loc: HashMap, pub inferred_relation_types: HashMap, Loc)>, pub rule_variable_type: HashMap>, @@ -22,11 +24,15 @@ pub struct TypeInference { } impl TypeInference { - pub fn new(function_registry: &ForeignFunctionRegistry) -> Self { + pub fn new( + function_registry: &ForeignFunctionRegistry, + predicate_registry: &ForeignPredicateRegistry, + ) -> Self { Self { custom_types: HashMap::new(), constant_types: HashMap::new(), - function_type_registry: FunctionTypeRegistry::from_foreign_function_registry(function_registry), + foreign_function_type_registry: FunctionTypeRegistry::from_foreign_function_registry(function_registry), + foreign_predicate_type_registry: PredicateTypeRegistry::from_foreign_predicate_registry(predicate_registry), relation_type_decl_loc: HashMap::new(), inferred_relation_types: HashMap::new(), rule_variable_type: HashMap::new(), @@ -234,7 +240,8 @@ impl TypeInference { &self.custom_types, &self.constant_types, &self.inferred_relation_types, - &self.function_type_registry, + &self.foreign_function_type_registry, + &self.foreign_predicate_type_registry, &mut inferred_expr_types, )?; ctx.propagate_variable_types(&mut inferred_var_expr, &mut inferred_expr_types)?; @@ -278,6 +285,16 @@ impl NodeVisitor for TypeInference { } fn visit_relation_type(&mut self, relation_type: &RelationType) { + // Check if the relation is a foreign predicate + let predicate = relation_type.predicate(); + if self.foreign_predicate_type_registry.contains_predicate(predicate) { + self.errors.push(TypeInferenceError::CannotRedefineForeignPredicate { + pred: predicate.to_string(), + loc: relation_type.location().clone(), + }); + return; + } + self.check_and_add_relation_type( relation_type.predicate(), relation_type.arg_types(), @@ -285,6 +302,35 @@ impl NodeVisitor for TypeInference { ); } + fn visit_enum_type_decl(&mut self, enum_type_decl: &ast::EnumTypeDecl) { + // First add the enum type + let ty = Type::usize(); + self.check_and_add_custom_type(enum_type_decl.name(), &ty, enum_type_decl.location()); + + // And then declare all the constant types + for member in enum_type_decl.iter_members() { + match member.assigned_number() { + Some(c) => match &c.node { + ConstantNode::Integer(i) => { + if *i < 0 { + self.errors.push(TypeInferenceError::NegativeEnumValue { + found: *i, + loc: c.location().clone(), + }) + } + } + _ => { + self.errors.push(TypeInferenceError::BadEnumValueKind { + found: c.kind(), + loc: c.location().clone(), + }) + } + } + _ => {} + } + } + } + fn visit_const_assignment(&mut self, const_assign: &ConstAssignment) { if let Some(raw_type) = const_assign.ty() { let result = find_value_type(&self.custom_types, raw_type).and_then(|ty| { @@ -304,6 +350,15 @@ impl NodeVisitor for TypeInference { fn visit_constant_set_decl(&mut self, constant_set_decl: &ConstantSetDecl) { let pred = constant_set_decl.predicate(); + // Check if the relation is a foreign predicate + if self.foreign_predicate_type_registry.contains_predicate(pred) { + self.errors.push(TypeInferenceError::CannotRedefineForeignPredicate { + pred: pred.to_string(), + loc: constant_set_decl.location().clone(), + }); + return; + } + // There's nothing we can check if there is no tuple inside the set if constant_set_decl.num_tuples() == 0 { return; @@ -398,6 +453,16 @@ impl NodeVisitor for TypeInference { fn visit_fact_decl(&mut self, fact_decl: &FactDecl) { let pred = fact_decl.predicate(); + + // Check if the relation is a foreign predicate + if 
self.foreign_predicate_type_registry.contains_predicate(pred) { + self.errors.push(TypeInferenceError::CannotRedefineForeignPredicate { + pred: pred.to_string(), + loc: fact_decl.location().clone(), + }); + return; + } + let maybe_curr_type_sets = fact_decl .iter_arguments() .map(|arg| match arg { @@ -451,6 +516,18 @@ impl NodeVisitor for TypeInference { } fn visit_rule(&mut self, rule: &Rule) { + let pred = rule.head().predicate(); + + // Check if the relation is a foreign predicate + if self.foreign_predicate_type_registry.contains_predicate(pred) { + self.errors.push(TypeInferenceError::CannotRedefineForeignPredicate { + pred: pred.to_string(), + loc: rule.location().clone(), + }); + return; + } + + // Otherwise, create a rule inference context let ctx = LocalTypeInferenceContext::from_rule(rule); // Check if context has error already @@ -470,6 +547,17 @@ impl NodeVisitor for TypeInference { } fn visit_query(&mut self, query: &Query) { + // Check if the relation is a foreign predicate + let pred = query.predicate(); + if self.foreign_predicate_type_registry.contains_predicate(&pred) { + self.errors.push(TypeInferenceError::CannotQueryForeignPredicate { + pred: pred.to_string(), + loc: query.location().clone(), + }); + return; + } + + // Check the query match &query.node { QueryNode::Atom(atom) => { let ctx = LocalTypeInferenceContext::from_atom(atom); diff --git a/core/src/compiler/front/analyzers/type_inference/type_set.rs b/core/src/compiler/front/analyzers/type_inference/type_set.rs index d6bc72a..b79764a 100644 --- a/core/src/compiler/front/analyzers/type_inference/type_set.rs +++ b/core/src/compiler/front/analyzers/type_inference/type_set.rs @@ -6,13 +6,13 @@ use crate::compiler::front::*; #[derive(Clone, Debug)] pub enum TypeSet { BaseType(ValueType, AstNodeLocation), // Concrete base type - Numeric(AstNodeLocation), // Contains integer and float, default usize + Numeric(AstNodeLocation), // Contains integer and float, default i32 Arith(AstNodeLocation), // Numeric but with arithmetics, default integer i32 - Integer(AstNodeLocation), // Only integer, default i32 - SignedInteger(AstNodeLocation), // Only signed integer, default i32 - UnsignedInteger(AstNodeLocation), // Only unsigned integer, default u32 - Float(AstNodeLocation), // Only float, default f32 - String(AstNodeLocation), // Only string, default String + Integer(AstNodeLocation), // integer, default `i32` + SignedInteger(AstNodeLocation), // signed integer, default `i32` + UnsignedInteger(AstNodeLocation), // unsigned integer, default `u32` + Float(AstNodeLocation), // float, default `f32` + String(AstNodeLocation), // string, default `String` Any(AstNodeLocation), // Any type, default i32 } @@ -165,6 +165,10 @@ impl TypeSet { Self::Any(AstNodeLocation::default()) } + pub fn numeric() -> Self { + Self::Numeric(AstNodeLocation::default()) + } + pub fn arith() -> Self { Self::Arith(AstNodeLocation::default()) } @@ -186,6 +190,9 @@ impl TypeSet { ConstantNode::Char(_) => Self::BaseType(ValueType::Char, c.location().clone()), ConstantNode::Boolean(_) => Self::BaseType(ValueType::Bool, c.location().clone()), ConstantNode::String(_) => Self::String(c.location().clone()), + ConstantNode::DateTime(_) => Self::BaseType(ValueType::DateTime, c.location().clone()), + ConstantNode::Duration(_) => Self::BaseType(ValueType::Duration, c.location().clone()), + ConstantNode::Invalid(_) => panic!("[Internal Error] Should not be called with invalid constant"), } } @@ -205,7 +212,7 @@ impl TypeSet { (Self::UnsignedInteger(_), base_ty) => 
base_ty.is_numeric(), (Self::Float(_), base_ty) => base_ty.is_numeric(), (Self::String(_), base_ty) => base_ty.is_string() || base_ty.is_numeric(), - (Self::Any(_), _) => false, + (Self::Any(_), base_ty) => base_ty.is_numeric(), } } @@ -233,14 +240,14 @@ impl TypeSet { pub fn to_default_value_type(&self) -> ValueType { match self { Self::BaseType(b, _) => b.clone(), - Self::Numeric(_) => ValueType::USize, + Self::Numeric(_) => ValueType::I32, Self::Arith(_) => ValueType::I32, Self::Integer(_) => ValueType::I32, Self::SignedInteger(_) => ValueType::I32, Self::UnsignedInteger(_) => ValueType::U32, Self::Float(_) => ValueType::F32, Self::String(_) => ValueType::String, - Self::Any(_) => ValueType::USize, + Self::Any(_) => ValueType::I32, } } @@ -282,4 +289,18 @@ impl TypeSet { }), } } + + pub fn contains_value_type(&self, value_type: &ValueType) -> bool { + match self { + Self::BaseType(b, _) => b == value_type, + Self::Numeric(_) => value_type.is_numeric(), + Self::Arith(_) => value_type.is_numeric(), + Self::Integer(_) => value_type.is_integer(), + Self::SignedInteger(_) => value_type.is_signed_integer(), + Self::UnsignedInteger(_) => value_type.is_unsigned_integer(), + Self::Float(_) => value_type.is_float(), + Self::String(_) => value_type.is_string(), + Self::Any(_) => true, + } + } } diff --git a/core/src/compiler/front/analyzers/type_inference/unification.rs b/core/src/compiler/front/analyzers/type_inference/unification.rs index 25063d5..5a0c9a1 100644 --- a/core/src/compiler/front/analyzers/type_inference/unification.rs +++ b/core/src/compiler/front/analyzers/type_inference/unification.rs @@ -1,13 +1,14 @@ use std::collections::*; use super::*; + use crate::common::value_type::*; use crate::compiler::front::*; /// The structure storing unification relationships #[derive(Clone, Debug)] pub enum Unification { - /// The i-th element of a relation: arg, relation, argument ID + /// The i-th element of a relation: arg, relation name, argument ID IthArgOfRelation(Loc, String, usize), /// V, Variable Name @@ -16,8 +17,20 @@ pub enum Unification { /// C, Type Set of C OfConstant(Loc, TypeSet), - /// op1, op2, op1 X op2 - AddSubMulDivMod(Loc, Loc, Loc), + /// op1, op2, op1 + op2 + Add(Loc, Loc, Loc), + + /// op1, op2, op1 - op2 + Sub(Loc, Loc, Loc), + + /// op1, op2, op1 * op2 + Mult(Loc, Loc, Loc), + + /// op1, op2, op1 / op2 + Div(Loc, Loc, Loc), + + /// op1, op2, op1 % op2 + Mod(Loc, Loc, Loc), /// op1, op2, op1 == op2 EqNeq(Loc, Loc, Loc), @@ -52,22 +65,51 @@ impl Unification { constant_types: &HashMap, inferred_relation_types: &HashMap, Loc)>, function_type_registry: &FunctionTypeRegistry, + predicate_type_registry: &PredicateTypeRegistry, inferred_expr_types: &mut HashMap, ) -> Result<(), TypeInferenceError> { match self { Self::IthArgOfRelation(e, p, i) => { - let (tys, loc) = inferred_relation_types.get(p).unwrap(); - if i < &tys.len() { - let ty = tys[(*i)].clone(); - unify_ty(e, ty, inferred_expr_types)?; - Ok(()) + if let Some(tys) = predicate_type_registry.get(p) { + if i < &tys.len() { + // It is a foreign predicate in the registry; we get the i-th type + let ty = TypeSet::base(tys[*i].clone()); + + // Unify the type + match unify_ty(e, ty.clone(), inferred_expr_types) { + Ok(_) => Ok(()), + Err(_) => { + Err(TypeInferenceError::CannotUnifyForeignPredicateArgument { + pred: p.clone(), + i: *i, + expected_ty: ty, + actual_ty: inferred_expr_types.get(e).unwrap().clone(), + loc: e.clone(), + }) + } + } + } else { + Err(TypeInferenceError::InvalidForeignPredicateArgIndex { + predicate: 
p.clone(), + index: i.clone(), + access_loc: e.clone(), + }) + } } else { - Err(TypeInferenceError::InvalidArgIndex { - predicate: p.clone(), - index: i.clone(), - source_loc: loc.clone(), - access_loc: e.clone(), - }) + // It is user defined predicate + let (tys, loc) = inferred_relation_types.get(p).unwrap(); + if i < &tys.len() { + let ty = tys[*i].clone(); + unify_ty(e, ty, inferred_expr_types)?; + Ok(()) + } else { + Err(TypeInferenceError::InvalidArgIndex { + predicate: p.clone(), + index: i.clone(), + source_loc: loc.clone(), + access_loc: e.clone(), + }) + } } } Self::OfVariable(_, _) => Ok(()), @@ -98,7 +140,19 @@ impl Unification { Ok(()) } } - Self::AddSubMulDivMod(op1, op2, e) => { + Self::Add(op1, op2, e) => { + unify_polymorphic_binary_expression(op1, op2, e, inferred_expr_types, &ADD_TYPING_RULES) + } + Self::Sub(op1, op2, e) => { + unify_polymorphic_binary_expression(op1, op2, e, inferred_expr_types, &SUB_TYPING_RULES) + } + Self::Mult(op1, op2, e) => { + unify_polymorphic_binary_expression(op1, op2, e, inferred_expr_types, &MULT_TYPING_RULES) + } + Self::Div(op1, op2, e) => { + unify_polymorphic_binary_expression(op1, op2, e, inferred_expr_types, &DIV_TYPING_RULES) + } + Self::Mod(op1, op2, e) => { let e_ty = inferred_expr_types .entry(e.clone()) .or_insert(TypeSet::Arith(e.clone())) @@ -153,25 +207,7 @@ impl Unification { Ok(()) } Self::LtLeqGtGeq(op1, op2, e) => { - // e should be boolean - unify_boolean(e, inferred_expr_types)?; - - // op1 and op2 are numeric - let t1 = unify_arith(op1, inferred_expr_types)?; - let t2 = unify_arith(op2, inferred_expr_types)?; - - // op1 and op2 are of the same type - match t1.unify(&t2) { - Ok(new_ty) => { - inferred_expr_types.insert(op1.clone(), new_ty.clone()); - inferred_expr_types.insert(op2.clone(), new_ty.clone()); - Ok(()) - } - Err(mut err) => { - err.annotate_location(e); - Err(err) - } - } + unify_comparison_expression(op1, op2, e, inferred_expr_types, &COMPARE_TYPING_RULES) } Self::PosNeg(op1, e) => { let e_ty = inferred_expr_types @@ -327,10 +363,123 @@ impl Unification { } } +enum AppliedRules { + None, + One(T), + Multiple, +} + +impl AppliedRules { + fn new() -> Self { + Self::None + } + + fn add(self, rule: T) -> Self { + match self { + Self::None => Self::One(rule), + Self::One(_) => Self::Multiple, + Self::Multiple => Self::Multiple, + } + } +} + fn get_or_insert_ty(e: &Loc, ty: TypeSet, inferred_expr_types: &mut HashMap) -> TypeSet { inferred_expr_types.entry(e.clone()).or_insert(ty).clone() } +fn unify_polymorphic_binary_expression( + op1: &Loc, + op2: &Loc, + e: &Loc, + inferred_expr_types: &mut HashMap, + rules: &[(ValueType, ValueType, ValueType)], +) -> Result<(), TypeInferenceError> { + // First get the already inferred types of op1, op2, and e + let op1_ty = unify_any(op1, inferred_expr_types)?; + let op2_ty = unify_any(op2, inferred_expr_types)?; + let e_ty = unify_any(e, inferred_expr_types)?; + + // Then iterate through all the rules to see if any could be applied + let mut applied_rules = AppliedRules::new(); + for (t1, t2, te) in rules { + if op1_ty.contains_value_type(t1) && op2_ty.contains_value_type(t2) && e_ty.contains_value_type(te) { + applied_rules = applied_rules.add((t1.clone(), t2.clone(), te.clone())); + } + } + + // Finally, check if there is any rule applied + match applied_rules { + AppliedRules::None => { + // If no rule can be applied, then the type inference is failed + Err(TypeInferenceError::NoMatchingTripletRule { + op1_ty, + op2_ty, + e_ty, + location: e.clone(), + }) + }, + 
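For context on the rule tables consulted here: each binary operator now carries its own list of (lhs, rhs, result) value-type triplets (ADD_TYPING_RULES, SUB_TYPING_RULES, and so on, defined elsewhere in this patch), and a triplet stays a candidate only if every operand's current TypeSet passes contains_value_type. A hypothetical excerpt, purely to illustrate the shape of such a table:

// Illustration only -- not the actual ADD_TYPING_RULES definition.
const EXAMPLE_ADD_RULES: &[(ValueType, ValueType, ValueType)] = &[
  (ValueType::I32, ValueType::I32, ValueType::I32),
  (ValueType::I64, ValueType::I64, ValueType::I64),
  (ValueType::F32, ValueType::F32, ValueType::F32),
];
// If op1 is already known to be f32, only the F32 triplet survives the
// contains_value_type filter, so AppliedRules::One pins op2 and the whole
// expression to f32; if several triplets survive, AppliedRules::Multiple
// leaves the types open for a later round of unification.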
AppliedRules::One((t1, t2, te)) => { + // If there is exactly one rule that can be applied, then unify them with the exact types + unify_ty(op1, TypeSet::BaseType(t1, e.clone()), inferred_expr_types)?; + unify_ty(op2, TypeSet::BaseType(t2, e.clone()), inferred_expr_types)?; + unify_ty(e, TypeSet::BaseType(te, e.clone()), inferred_expr_types)?; + Ok(()) + } + AppliedRules::Multiple => { + // If there are multiple rules that can be applied, we are not sure about the exact types, + // but the type inference is still successful + Ok(()) + }, + } +} + +fn unify_comparison_expression( + op1: &Loc, + op2: &Loc, + e: &Loc, + inferred_expr_types: &mut HashMap, + rules: &[(ValueType, ValueType)], +) -> Result<(), TypeInferenceError> { + // The result should be a boolean + let e_ty = unify_boolean(e, inferred_expr_types)?; + + // First get the already inferred types of op1, op2, and e + let op1_ty = unify_any(op1, inferred_expr_types)?; + let op2_ty = unify_any(op2, inferred_expr_types)?; + + // Then iterate through all the rules to see if any could be applied + let mut applied_rules = AppliedRules::new(); + for (t1, t2) in rules { + if op1_ty.contains_value_type(t1) && op2_ty.contains_value_type(t2) { + applied_rules = applied_rules.add((t1.clone(), t2.clone())); + } + } + + // Finally, check if there is any rule applied + match applied_rules { + AppliedRules::None => { + // If no rule can be applied, then type inference fails + Err(TypeInferenceError::NoMatchingTripletRule { + op1_ty, + op2_ty, + e_ty, + location: e.clone(), + }) + }, + AppliedRules::One((t1, t2)) => { + // If there is exactly one rule that can be applied, then unify them with the exact types + unify_ty(op1, TypeSet::BaseType(t1, e.clone()), inferred_expr_types)?; + unify_ty(op2, TypeSet::BaseType(t2, e.clone()), inferred_expr_types)?; + Ok(()) + } + AppliedRules::Multiple => { + // If there are multiple rules that can be applied, we are not sure about the exact types, + // but the type inference is still successful + Ok(()) + }, + } +} + fn unify_ty( e: &Loc, ty: TypeSet, @@ -354,11 +503,6 @@ fn unify_any(e: &Loc, inferred_expr_types: &mut HashMap) -> Result unify_ty(e, e_ty, inferred_expr_types) } -fn unify_arith(e: &Loc, inferred_expr_types: &mut HashMap) -> Result { - let e_ty = TypeSet::Arith(e.clone()); - unify_ty(e, e_ty, inferred_expr_types) -} - fn unify_boolean(e: &Loc, inferred_expr_types: &mut HashMap) -> Result { let e_ty = TypeSet::BaseType(ValueType::Bool, e.clone()); unify_ty(e, e_ty, inferred_expr_types) diff --git a/core/src/compiler/front/ast/constant.rs b/core/src/compiler/front/ast/constant.rs index a3ef784..57574e0 100644 --- a/core/src/compiler/front/ast/constant.rs +++ b/core/src/compiler/front/ast/constant.rs @@ -1,21 +1,23 @@ // use std::rc::Rc; use super::*; -use crate::common::{input_tag::InputTag, value::Value, value_type::ValueType}; +use crate::common::input_tag::DynamicInputTag; +use crate::common::value::Value; +use crate::common::value_type::ValueType; #[derive(Clone, Debug, PartialEq)] #[doc(hidden)] -pub struct TagNode(pub InputTag); +pub struct TagNode(pub DynamicInputTag); /// A tag associated with a fact pub type Tag = AstNode; impl Tag { pub fn default_none() -> Self { - Self::default(TagNode(InputTag::None)) + Self::default(TagNode(DynamicInputTag::None)) } - pub fn input_tag(&self) -> &InputTag { + pub fn input_tag(&self) -> &DynamicInputTag { &self.node.0 } @@ -32,12 +34,22 @@ pub enum ConstantNode { Char(String), Boolean(bool), String(String), + DateTime(chrono::DateTime), +
Duration(chrono::Duration), + + /// Invalid is used to represent a constant that could not be parsed; the string is the error message + Invalid(String), } /// A constant, which could be an integer, floating point, character, boolean, or string. pub type Constant = AstNode; impl Constant { + /// Create a new constant integer AST node + pub fn integer(i: i64) -> Self { + Self::default(ConstantNode::Integer(i)) + } + pub fn to_value(&self, ty: &ValueType) -> Value { use ConstantNode::*; match (&self.node, ty) { @@ -62,6 +74,8 @@ impl Constant { (String(_), ValueType::Str) => panic!("Cannot cast dynamic string into static string"), (String(s), ValueType::String) => Value::String(s.clone()), // (String(s), ValueType::RcString) => Value::RcString(Rc::new(s.clone())), + (DateTime(d), ValueType::DateTime) => Value::DateTime(d.clone()), + (Duration(d), ValueType::Duration) => Value::Duration(d.clone()), _ => panic!("Cannot convert front Constant `{:?}` to Type `{}`", self, ty), } } @@ -74,6 +88,9 @@ impl Constant { String(_) => "string", Char(_) => "char", Boolean(_) => "boolean", + DateTime(_) => "datetime", + Duration(_) => "duration", + Invalid(_) => "invalid", } } } diff --git a/core/src/compiler/front/ast/formula.rs b/core/src/compiler/front/ast/formula.rs index a946867..72bd2e2 100644 --- a/core/src/compiler/front/ast/formula.rs +++ b/core/src/compiler/front/ast/formula.rs @@ -20,8 +20,12 @@ impl Formula { pub fn negate(&self) -> Self { match self { - Self::Atom(a) => Self::NegAtom(NegAtom::new(a.location().clone(), NegAtomNode { atom: a.clone() })), - Self::NegAtom(n) => Self::Atom(n.atom().clone()), + Self::Atom(a) => { + Self::NegAtom(NegAtom::new(a.location().clone(), NegAtomNode { atom: a.clone() })) + } + Self::NegAtom(n) => { + Self::Atom(n.atom().clone()) + }, Self::Disjunction(d) => Self::Conjunction(Conjunction::new( d.location().clone(), ConjunctionNode { @@ -90,6 +94,10 @@ impl NegAtom { pub fn atom(&self) -> &Atom { &self.node.atom } + + pub fn predicate(&self) -> &String { + self.atom().predicate() + } } #[derive(Clone, Debug, PartialEq)] diff --git a/core/src/compiler/front/ast/query.rs b/core/src/compiler/front/ast/query.rs index 662e4fc..fa705b0 100644 --- a/core/src/compiler/front/ast/query.rs +++ b/core/src/compiler/front/ast/query.rs @@ -10,7 +10,21 @@ pub enum QueryNode { pub type Query = AstNode; impl Query { - pub fn relation_name(&self) -> String { + pub fn predicate(&self) -> String { + match &self.node { + QueryNode::Predicate(p) => { + let n = p.name(); + if let Some(id) = n.find("(") { + n[..id].to_string() + } else { + n.to_string() + } + }, + QueryNode::Atom(a) => a.predicate().to_string(), + } + } + + pub fn create_relation_name(&self) -> String { match &self.node { QueryNode::Predicate(p) => p.name().to_string(), QueryNode::Atom(a) => format!("{}", a), diff --git a/core/src/compiler/front/ast/type_decl.rs b/core/src/compiler/front/ast/type_decl.rs index d214db0..7a315ee 100644 --- a/core/src/compiler/front/ast/type_decl.rs +++ b/core/src/compiler/front/ast/type_decl.rs @@ -6,6 +6,7 @@ pub enum TypeDeclNode { Subtype(SubtypeDecl), Alias(AliasTypeDecl), Relation(RelationTypeDecl), + Enum(EnumTypeDecl), } pub type TypeDecl = AstNode; @@ -16,6 +17,7 @@ impl TypeDecl { TypeDeclNode::Subtype(s) => s.attributes(), TypeDeclNode::Alias(a) => a.attributes(), TypeDeclNode::Relation(r) => r.attributes(), + TypeDeclNode::Enum(e) => e.attributes(), } } @@ -24,6 +26,7 @@ impl TypeDecl { TypeDeclNode::Subtype(s) => s.attributes_mut(), TypeDeclNode::Alias(a) => 
a.attributes_mut(), TypeDeclNode::Relation(r) => r.attributes_mut(), + TypeDeclNode::Enum(e) => e.attributes_mut(), } } } @@ -167,3 +170,61 @@ impl RelationTypeDecl { &mut self.node.attrs } } + +#[derive(Clone, Debug, PartialEq)] +#[doc(hidden)] +pub struct EnumTypeDeclNode { + pub attrs: Attributes, + pub name: Identifier, + pub members: Vec, +} + +pub type EnumTypeDecl = AstNode; + +impl EnumTypeDecl { + pub fn attributes(&self) -> &Vec { + &self.node.attrs + } + + pub fn attributes_mut(&mut self) -> &mut Attributes { + &mut self.node.attrs + } + + pub fn name(&self) -> &str { + self.node.name.name() + } + + pub fn members(&self) -> &[EnumTypeMember] { + &self.node.members + } + + pub fn iter_members(&self) -> impl Iterator { + self.node.members.iter() + } + + pub fn iter_members_mut(&mut self) -> impl Iterator { + self.node.members.iter_mut() + } +} + +#[derive(Clone, Debug, PartialEq)] +pub struct EnumTypeMemberNode { + pub name: Identifier, + pub assigned_num: Option, +} + +pub type EnumTypeMember = AstNode; + +impl EnumTypeMember { + pub fn name(&self) -> &str { + self.node.name.name() + } + + pub fn assigned_number(&self) -> Option<&Constant> { + self.node.assigned_num.as_ref() + } + + pub fn assigned_number_mut(&mut self) -> Option<&mut Constant> { + self.node.assigned_num.as_mut() + } +} diff --git a/core/src/compiler/front/ast/types.rs b/core/src/compiler/front/ast/types.rs index 8447871..c147db7 100644 --- a/core/src/compiler/front/ast/types.rs +++ b/core/src/compiler/front/ast/types.rs @@ -23,6 +23,8 @@ pub enum TypeNode { Str, String, // RcString, + DateTime, + Duration, Named(Identifier), } @@ -48,6 +50,8 @@ impl std::fmt::Display for TypeNode { Self::Str => f.write_str("&str"), Self::String => f.write_str("String"), // Self::RcString => f.write_str("Rc"), + Self::DateTime => f.write_str("DateTime"), + Self::Duration => f.write_str("Duration"), Self::Named(i) => f.write_str(&i.node.name), } } @@ -56,6 +60,20 @@ impl std::fmt::Display for TypeNode { pub type Type = AstNode; impl Type { + /// Create a new `i8` type AST node + pub fn i8() -> Self { + Self::default(TypeNode::I8) + } + + /// Create a new `usize` type AST node + pub fn usize() -> Self { + Self::default(TypeNode::USize) + } + + /// Convert the type AST node to a value type + /// + /// Returns `Ok` if the node itself is a base type; + /// `Err` if the node is a `Named` type and not normalized to base type pub fn to_value_type(&self) -> Result { match &self.node { TypeNode::I8 => Ok(ValueType::I8), @@ -77,6 +95,8 @@ impl Type { TypeNode::Str => Ok(ValueType::Str), TypeNode::String => Ok(ValueType::String), // TypeNode::RcString => Ok(ValueType::RcString), + TypeNode::DateTime => Ok(ValueType::DateTime), + TypeNode::Duration => Ok(ValueType::Duration), TypeNode::Named(s) => Err(s.name().to_string()), } } diff --git a/core/src/compiler/front/compile.rs b/core/src/compiler/front/compile.rs index 68e7686..24179d2 100644 --- a/core/src/compiler/front/compile.rs +++ b/core/src/compiler/front/compile.rs @@ -7,6 +7,7 @@ use super::analyzers::*; use super::*; use crate::common::foreign_function::*; +use crate::common::foreign_predicate::*; use crate::common::tuple_type::*; use crate::common::value_type::*; use crate::utils::CopyOnWrite; @@ -28,6 +29,9 @@ pub struct FrontContext { /// Foreign function registry holding all foreign functions pub foreign_function_registry: ForeignFunctionRegistry, + /// Foreign predicate registry holding all foreign predicates + pub foreign_predicate_registry: ForeignPredicateRegistry, + /// Node 
ID annotator for giving AST node IDs. pub node_id_annotator: NodeIdAnnotator, @@ -38,11 +42,13 @@ pub struct FrontContext { impl FrontContext { pub fn new() -> Self { let function_registry = ForeignFunctionRegistry::std(); - let analysis = Analysis::new(&function_registry); + let predicate_registry = ForeignPredicateRegistry::std(); + let analysis = Analysis::new(&function_registry, &predicate_registry); Self { sources: Sources::new(), items: Vec::new(), foreign_function_registry: function_registry, + foreign_predicate_registry: predicate_registry, imported_files: HashSet::new(), node_id_annotator: NodeIdAnnotator::new(), analysis: CopyOnWrite::new(analysis), @@ -80,13 +86,43 @@ impl FrontContext { self.analysis.modify(|analysis| { analysis .type_inference - .function_type_registry + .foreign_function_type_registry .add_function_type(func_name, func_type) }); Ok(()) } + pub fn register_foreign_predicate(&mut self, f: F) -> Result<(), ForeignPredicateError> + where + F: ForeignPredicate + Send + Sync + Clone + 'static + { + // Check if the predicate name has already be defined before + if self.type_inference().has_relation(&f.name()) { + return Err(ForeignPredicateError::AlreadyExisted { id: f.name() }); + } + + // Add the predicate to the registry + self.foreign_predicate_registry.register(f.clone())?; + + // If succeeded, we add it to the type inference module + self.analysis.modify(|analysis| { + // Update the type inference module + analysis + .type_inference + .foreign_predicate_type_registry + .add_foreign_predicate(&f); + + // Update the head analysis module + analysis.head_relation_analysis.add_foreign_predicate(&f); + + // Update the boundness analysis module + analysis.boundness_analysis.add_foreign_predicate(&f); + }); + + Ok(()) + } + pub fn compile_source(&mut self, s: S) -> Result { self.compile_source_with_parser(s, parser::str_to_items) } diff --git a/core/src/compiler/front/f2b/f2b.rs b/core/src/compiler/front/f2b/f2b.rs index 4617db1..f3e0e97 100644 --- a/core/src/compiler/front/f2b/f2b.rs +++ b/core/src/compiler/front/f2b/f2b.rs @@ -1,6 +1,6 @@ use std::collections::*; -use super::super::analyzers::boundness::{AggregationContext, RuleContext}; +use super::super::analyzers::boundness::{AggregationContext, RuleContext, ForeignPredicateBindings}; use super::super::ast as front; use super::super::ast::{AstNodeLocation, WithLocation}; use super::super::compile::*; @@ -44,6 +44,7 @@ impl FrontContext { disjunctive_facts, rules, function_registry: self.foreign_function_registry.clone(), + predicate_registry: self.foreign_predicate_registry.clone(), } } @@ -53,7 +54,7 @@ impl FrontContext { .iter() .filter_map(|item| match item { front::Item::QueryDecl(q) => { - let name = q.node.query.relation_name().clone(); + let name = q.node.query.create_relation_name().clone(); if let Some(file) = self.analysis.borrow().output_files_analysis.output_file(&name) { Some((name, OutputOption::File(file.clone()))) } else { @@ -72,12 +73,16 @@ impl FrontContext { .type_inference .inferred_relation_types .iter() - .map(|(pred, (tys, _))| { - let arg_types = tys.iter().map(|type_set| type_set.to_default_value_type()).collect(); - back::Relation { - attributes: self.back_relation_attributes(pred), - predicate: pred.clone(), - arg_types, + .filter_map(|(pred, (tys, _))| { + if self.foreign_predicate_registry.contains(pred) { + None + } else { + let arg_types = tys.iter().map(|type_set| type_set.to_default_value_type()).collect(); + Some(back::Relation { + attributes: 
self.back_relation_attributes(pred), + predicate: pred.clone(), + arg_types, + }) } }) .collect::>() @@ -192,7 +197,7 @@ impl FrontContext { let attributes = back::Attributes::new(); // Collect information for flattening - let mut flatten_expr = FlattenExprContext::new(&analysis.type_inference); + let mut flatten_expr = FlattenExprContext::new(&analysis.type_inference, &self.foreign_predicate_registry); flatten_expr.walk_atom(src_rule.head()); // Create the flattened expression that the head needs @@ -310,9 +315,10 @@ impl FrontContext { temp_rules: &mut Vec, ) -> back::Literal { // unwrap is ok because the success of compute boundness is checked already - let body_bounded_vars = agg_ctx.body.compute_boundness(&vec![]).unwrap(); + let pred_bindings = ForeignPredicateBindings::from(&self.foreign_predicate_registry); + let body_bounded_vars = agg_ctx.body.compute_boundness(&pred_bindings, &vec![]).unwrap(); let group_by_bounded_vars = agg_ctx.group_by.as_ref().map_or(BTreeSet::new(), |(ctx, _, _)| { - ctx.compute_boundness(&vec![]).unwrap().into_iter().collect() + ctx.compute_boundness(&pred_bindings, &vec![]).unwrap().into_iter().collect() }); let all_bounded_vars = body_bounded_vars .union(&group_by_bounded_vars) diff --git a/core/src/compiler/front/f2b/flatten_expr.rs b/core/src/compiler/front/f2b/flatten_expr.rs index 66982f9..11ecb94 100644 --- a/core/src/compiler/front/f2b/flatten_expr.rs +++ b/core/src/compiler/front/f2b/flatten_expr.rs @@ -1,7 +1,8 @@ use std::collections::*; +use crate::common::foreign_predicate::*; use crate::compiler::back; -use crate::compiler::front::analyzers::TypeInference; +use crate::compiler::front::analyzers::*; use crate::compiler::front::utils::*; use crate::compiler::front::*; use crate::utils::IdAllocator; @@ -9,6 +10,7 @@ use crate::utils::IdAllocator; #[derive(Clone, Debug)] pub struct FlattenExprContext<'a> { pub type_inference: &'a TypeInference, + pub foreign_predicate_registry: &'a ForeignPredicateRegistry, pub id_allocator: IdAllocator, pub ignore_exprs: HashSet, pub internal: HashMap, @@ -56,9 +58,13 @@ impl FlattenedNode { pub type FlattenedLeaf = back::Term; impl<'a> FlattenExprContext<'a> { - pub fn new(type_inference: &'a TypeInference) -> Self { + pub fn new( + type_inference: &'a TypeInference, + foreign_predicate_registry: &'a ForeignPredicateRegistry, + ) -> Self { Self { type_inference, + foreign_predicate_registry, id_allocator: IdAllocator::default(), ignore_exprs: HashSet::new(), internal: HashMap::new(), @@ -225,11 +231,14 @@ impl<'a> FlattenExprContext<'a> { // First get the atom let back_atom_args = atom.iter_arguments().map(|a| self.get_expr_term(a)).collect(); - let back_atom = back::Literal::Atom(back::Atom { + let back_atom = back::Atom { predicate: atom.predicate().clone(), args: back_atom_args, - }); - literals.push(back_atom); + }; + + // Depending on whether the atom is foreign, add the literal differently + let back_literal = back::Literal::Atom(back_atom); + literals.push(back_literal); // Then collect all the intermediate variables for arg in atom.iter_arguments() { @@ -242,6 +251,34 @@ impl<'a> FlattenExprContext<'a> { pub fn neg_atom_to_back_literals(&self, neg_atom: &NegAtom) -> Vec { let mut literals = vec![]; + // First get the atom + let back_atom_args = neg_atom + .atom() + .iter_arguments() + .map(|a| self.get_expr_term(a)) + .collect(); + let back_atom = back::NegAtom { + atom: back::Atom { + predicate: neg_atom.predicate().clone(), + args: back_atom_args, + }, + }; + + // Then generate a literal + let 
back_literal = back::Literal::NegAtom(back_atom); + literals.push(back_literal); + + // Then collect all the intermediate variables + for arg in neg_atom.atom().iter_arguments() { + literals.extend(self.collect_flattened_literals(arg.location())); + } + + literals + } + + pub fn domestic_neg_atom_to_back_literals(&self, neg_atom: &NegAtom) -> Vec { + let mut literals = vec![]; + // First get the atom let back_atom_args = neg_atom .atom() diff --git a/core/src/compiler/front/grammar.lalrpop b/core/src/compiler/front/grammar.lalrpop index 54e0049..c9c39fc 100644 --- a/core/src/compiler/front/grammar.lalrpop +++ b/core/src/compiler/front/grammar.lalrpop @@ -1,7 +1,8 @@ use std::str::FromStr; use super::ast::*; -use crate::common::input_tag::InputTag; +use crate::common::input_tag::DynamicInputTag; +use crate::utils; grammar; @@ -85,6 +86,8 @@ match { "&str", "String", "Rc", + "DateTime", + "Duration", // Boolean keywords "true", @@ -106,6 +109,8 @@ match { r"-?[0-9]+" => int, r"-?\d+(\.\d+)(e-?\d+)?" => float, r#""[^"]*""# => string, + r#"t"[^"]*""# => date_time_string, + r#"d"[^"]*""# => duration_string, r#"'[^']*'"# => character, // Comments and Whitespaces @@ -172,6 +177,8 @@ TypeNode: TypeNode = { "&str" => TypeNode::Str, "String" => TypeNode::String, // "Rc" => TypeNode::RcString, + "DateTime" => TypeNode::DateTime, + "Duration" => TypeNode::Duration, => TypeNode::Named(n), } @@ -234,10 +241,40 @@ RelationTypeDeclNode: RelationTypeDeclNode = { RelationTypeDecl = Spanned; +EnumTypeDeclNode: EnumTypeDeclNode = { + "type" "=" > => { + EnumTypeDeclNode { + attrs, + name: n, + members: ms, + } + } +} + +EnumTypeDecl = Spanned; + +EnumTypeMemberNode: EnumTypeMemberNode = { + => { + EnumTypeMemberNode { + name: n, + assigned_num: None, + } + }, + "=" => { + EnumTypeMemberNode { + name: n, + assigned_num: Some(c), + } + } +} + +EnumTypeMember = Spanned; + TypeDeclNode: TypeDeclNode = { => TypeDeclNode::Subtype(s), => TypeDeclNode::Alias(a), => TypeDeclNode::Relation(r), + => TypeDeclNode::Enum(e), } TypeDecl: TypeDecl = Spanned; @@ -296,8 +333,8 @@ RelationDeclNode: RelationDeclNode = { RelationDecl: RelationDecl = Spanned; TagNode: TagNode = { - => TagNode(InputTag::Float(f)), - => TagNode(InputTag::Bool(b)), + => TagNode(DynamicInputTag::Float(f)), + => TagNode(DynamicInputTag::Bool(b)), } Tag: Tag = Spanned; @@ -307,6 +344,18 @@ ConstantNode: ConstantNode = { => ConstantNode::Integer(i), => ConstantNode::Float(f), => ConstantNode::String(s), + => { + match utils::parse_date_time_string(&s) { + Some(v) => ConstantNode::DateTime(v), + None => ConstantNode::Invalid(format!("Cannot parse date time `{}`", s)), + } + }, + => { + match utils::parse_duration_string(&s) { + Some(v) => ConstantNode::Duration(v), + None => ConstantNode::Invalid(format!("Cannot parse duration `{}`", s)), + } + }, => ConstantNode::Char(c), } @@ -393,7 +442,7 @@ ConjDisjFormula = { } CommaConjunctionNode: ConjunctionNode = { - > => ConjunctionNode { args } + > => ConjunctionNode { args } } CommaConjunction = Spanned; @@ -438,7 +487,7 @@ DisjunctionFormula = { ConjunctionKeyword = { "/\\", "and" } ConjunctionNode: ConjunctionNode = { - > => { + > => { ConjunctionNode { args } } } @@ -447,17 +496,18 @@ Conjunction = Spanned; ConjunctionFormula: Formula = { => Formula::Conjunction(c), - NegAtomFormula, + AnnotatedAtomFormula, } +NegateKeyword = { "~", "not" } + NegAtomNode: NegAtomNode = { - "~" => NegAtomNode { atom: a }, - "not" => NegAtomNode { atom: a }, + NegateKeyword => NegAtomNode { atom: a }, } NegAtom = Spanned; 
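Taken together, the grammar changes in this file add t"..." datetime literals, d"..." duration literals, enum type declarations, and the word forms not/and/or next to the symbolic connectives. The fragment below is only an illustration of the new surface syntax, held in a Rust raw string; the exact literal formats accepted by utils::parse_date_time_string and utils::parse_duration_string are defined elsewhere in this patch, and the fact-set syntax shown is pre-existing.

// Illustrative Scallop source exercising the new syntax (assumed literal formats).
const EXAMPLE_SOURCE: &str = r#"
  type Status = ACTIVE | RETIRED = 10
  rel deadline = {(t"2023-04-01", d"7 days")}
  rel overdue(x) = deadline(x, _) and not extended(x)
"#;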
-NegAtomFormula: Formula = { +AnnotatedAtomFormula: Formula = { => Formula::NegAtom(n), UnitFormula, } @@ -876,6 +926,14 @@ StringLiteral: String = => { s[1..s.len() - 1].replace("\\t", "\t").replace("\\n", "\n").replace("\\\\", "\\").into() }; +DateTimeLiteral: String = => { + s[2..s.len() - 1].replace("\\t", "\t").replace("\\n", "\n").replace("\\\\", "\\").into() +}; + +DurationLiteral: String = => { + s[2..s.len() - 1].replace("\\t", "\t").replace("\\n", "\n").replace("\\\\", "\\").into() +}; + CharLiteral: String = => { s[1..s.len() - 1].replace("\\t", "\t").replace("\\n", "\n").replace("\\'", "'").replace("\\\\", "\\").into() }; diff --git a/core/src/compiler/front/pretty.rs b/core/src/compiler/front/pretty.rs index adffe90..1fb930f 100644 --- a/core/src/compiler/front/pretty.rs +++ b/core/src/compiler/front/pretty.rs @@ -29,6 +29,7 @@ impl Display for TypeDecl { TypeDeclNode::Subtype(s) => s.fmt(f), TypeDeclNode::Alias(s) => s.fmt(f), TypeDeclNode::Relation(s) => s.fmt(f), + TypeDeclNode::Enum(e) => e.fmt(f), } } } @@ -184,6 +185,32 @@ impl Display for RelationType { } } +impl Display for EnumTypeDecl { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + for attr in self.attributes() { + attr.fmt(f)?; + } + f.write_fmt(format_args!("type {} = ", self.name()))?; + for (i, member) in self.iter_members().enumerate() { + if i > 0 { + f.write_str(" | ")?; + } + member.fmt(f)?; + } + Ok(()) + } +} + +impl Display for EnumTypeMember { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + self.name().fmt(f)?; + if let Some(assigned_num) = self.assigned_number() { + f.write_fmt(format_args!(" = {}", assigned_num))?; + } + Ok(()) + } +} + impl Display for Identifier { fn fmt(&self, f: &mut Formatter<'_>) -> Result { f.write_str(self.name()) @@ -280,6 +307,9 @@ impl std::fmt::Display for Constant { ConstantNode::Char(c) => f.write_fmt(format_args!("'{}'", c)), ConstantNode::Boolean(b) => f.write_fmt(format_args!("{}", b)), ConstantNode::String(s) => f.write_fmt(format_args!("\"{}\"", s)), + ConstantNode::DateTime(s) => f.write_fmt(format_args!("\"{}\"", s)), + ConstantNode::Duration(s) => f.write_fmt(format_args!("\"{}\"", s)), + ConstantNode::Invalid(_) => f.write_str("invalid"), } } } @@ -307,7 +337,7 @@ impl Display for Formula { impl Display for NegAtom { fn fmt(&self, f: &mut Formatter<'_>) -> Result { - f.write_fmt(format_args!("~{}", self.atom())) + f.write_fmt(format_args!("not {}", self.atom())) } } @@ -315,7 +345,7 @@ impl Display for Disjunction { fn fmt(&self, f: &mut Formatter<'_>) -> Result { f.write_fmt(format_args!( "({})", - self.args().map(|a| format!("{}", a)).collect::>().join(" \\/ ") + self.args().map(|a| format!("{}", a)).collect::>().join(" or ") )) } } @@ -324,7 +354,7 @@ impl Display for Conjunction { fn fmt(&self, f: &mut Formatter<'_>) -> Result { f.write_fmt(format_args!( "({})", - self.args().map(|a| format!("{}", a)).collect::>().join(" /\\ ") + self.args().map(|a| format!("{}", a)).collect::>().join(" and ") )) } } diff --git a/core/src/compiler/front/transformations/tagged_rule.rs b/core/src/compiler/front/transformations/tagged_rule.rs index 754ebf3..da6b37c 100644 --- a/core/src/compiler/front/transformations/tagged_rule.rs +++ b/core/src/compiler/front/transformations/tagged_rule.rs @@ -1,8 +1,9 @@ -use crate::{common::input_tag::InputTag, compiler::front::*}; +use crate::common::input_tag::*; +use crate::compiler::front::*; #[derive(Clone, Debug)] pub struct TransformTaggedRule { - pub to_add_tags: Vec<(String, InputTag)>, + pub to_add_tags: Vec<(String, 
DynamicInputTag)>, } impl TransformTaggedRule { diff --git a/core/src/compiler/front/visitor.rs b/core/src/compiler/front/visitor.rs index 43458bb..667b245 100644 --- a/core/src/compiler/front/visitor.rs +++ b/core/src/compiler/front/visitor.rs @@ -19,6 +19,8 @@ pub trait NodeVisitor { node_visitor_func_def!(visit_subtype_decl, SubtypeDecl); node_visitor_func_def!(visit_relation_type_decl, RelationTypeDecl); node_visitor_func_def!(visit_relation_type, RelationType); + node_visitor_func_def!(visit_enum_type_decl, EnumTypeDecl); + node_visitor_func_def!(visit_enum_type_member, EnumTypeMember); node_visitor_func_def!(visit_const_decl, ConstDecl); node_visitor_func_def!(visit_const_assignment, ConstAssignment); node_visitor_func_def!(visit_relation_decl, RelationDecl); @@ -103,6 +105,7 @@ pub trait NodeVisitor { TypeDeclNode::Alias(a) => self.walk_alias_type_decl(a), TypeDeclNode::Subtype(s) => self.walk_subtype_decl(s), TypeDeclNode::Relation(r) => self.walk_relation_type_decl(r), + TypeDeclNode::Enum(e) => self.walk_enum_type_decl(e), } } @@ -140,6 +143,24 @@ pub trait NodeVisitor { } } + fn walk_enum_type_decl(&mut self, enum_type_decl: &EnumTypeDecl) { + self.visit_enum_type_decl(enum_type_decl); + self.visit_location(&enum_type_decl.loc); + self.walk_identifier(&enum_type_decl.node.name); + for member in enum_type_decl.members() { + self.walk_enum_type_member(member); + } + } + + fn walk_enum_type_member(&mut self, enum_type_member: &EnumTypeMember) { + self.visit_enum_type_member(enum_type_member); + self.visit_location(&enum_type_member.loc); + self.walk_identifier(&enum_type_member.node.name); + if let Some(assigned_num) = enum_type_member.assigned_number() { + self.walk_constant(assigned_num); + } + } + fn walk_const_decl(&mut self, const_decl: &ConstDecl) { self.visit_const_decl(const_decl); self.visit_location(const_decl.location()); @@ -491,6 +512,8 @@ macro_rules! 
impl_node_visitor_tuple { node_visitor_visit_node!(visit_subtype_decl, SubtypeDecl, ($($id),*)); node_visitor_visit_node!(visit_relation_type_decl, RelationTypeDecl, ($($id),*)); node_visitor_visit_node!(visit_relation_type, RelationType, ($($id),*)); + node_visitor_visit_node!(visit_enum_type_decl, EnumTypeDecl, ($($id),*)); + node_visitor_visit_node!(visit_enum_type_member, EnumTypeMember, ($($id),*)); node_visitor_visit_node!(visit_const_decl, ConstDecl, ($($id),*)); node_visitor_visit_node!(visit_const_assignment, ConstAssignment, ($($id),*)); node_visitor_visit_node!(visit_relation_decl, RelationDecl, ($($id),*)); diff --git a/core/src/compiler/front/visitor_mut.rs b/core/src/compiler/front/visitor_mut.rs index 665e7f7..62b6f29 100644 --- a/core/src/compiler/front/visitor_mut.rs +++ b/core/src/compiler/front/visitor_mut.rs @@ -19,6 +19,8 @@ pub trait NodeVisitorMut { node_visitor_mut_func_def!(visit_subtype_decl, SubtypeDecl); node_visitor_mut_func_def!(visit_relation_type_decl, RelationTypeDecl); node_visitor_mut_func_def!(visit_relation_type, RelationType); + node_visitor_mut_func_def!(visit_enum_type_decl, EnumTypeDecl); + node_visitor_mut_func_def!(visit_enum_type_member, EnumTypeMember); node_visitor_mut_func_def!(visit_const_decl, ConstDecl); node_visitor_mut_func_def!(visit_const_assignment, ConstAssignment); node_visitor_mut_func_def!(visit_relation_decl, RelationDecl); @@ -103,6 +105,7 @@ pub trait NodeVisitorMut { TypeDeclNode::Alias(a) => self.walk_alias_type_decl(a), TypeDeclNode::Subtype(s) => self.walk_subtype_decl(s), TypeDeclNode::Relation(r) => self.walk_relation_type_decl(r), + TypeDeclNode::Enum(e) => self.walk_enum_type_decl(e), } } @@ -140,6 +143,24 @@ pub trait NodeVisitorMut { } } + fn walk_enum_type_decl(&mut self, enum_type_decl: &mut EnumTypeDecl) { + self.visit_enum_type_decl(enum_type_decl); + self.visit_location(&mut enum_type_decl.loc); + self.walk_identifier(&mut enum_type_decl.node.name); + for member in enum_type_decl.iter_members_mut() { + self.walk_enum_type_member(member); + } + } + + fn walk_enum_type_member(&mut self, enum_type_member: &mut EnumTypeMember) { + self.visit_enum_type_member(enum_type_member); + self.visit_location(&mut enum_type_member.loc); + self.walk_identifier(&mut enum_type_member.node.name); + if let Some(assigned_num) = enum_type_member.assigned_number_mut() { + self.walk_constant(assigned_num); + } + } + fn walk_const_decl(&mut self, const_decl: &mut ConstDecl) { self.visit_const_decl(const_decl); self.visit_location(const_decl.location_mut()); @@ -489,6 +510,8 @@ macro_rules! 
impl_node_visitor_mut_tuple { node_visitor_mut_visit_node!(visit_subtype_decl, SubtypeDecl, ($($id),*)); node_visitor_mut_visit_node!(visit_relation_type_decl, RelationTypeDecl, ($($id),*)); node_visitor_mut_visit_node!(visit_relation_type, RelationType, ($($id),*)); + node_visitor_mut_visit_node!(visit_enum_type_decl, EnumTypeDecl, ($($id),*)); + node_visitor_mut_visit_node!(visit_enum_type_member, EnumTypeMember, ($($id),*)); node_visitor_mut_visit_node!(visit_const_decl, ConstDecl, ($($id),*)); node_visitor_mut_visit_node!(visit_const_assignment, ConstAssignment, ($($id),*)); node_visitor_mut_visit_node!(visit_relation_decl, RelationDecl, ($($id),*)); diff --git a/core/src/compiler/ram/ast.rs b/core/src/compiler/ram/ast.rs index cb03911..2665098 100644 --- a/core/src/compiler/ram/ast.rs +++ b/core/src/compiler/ram/ast.rs @@ -3,16 +3,19 @@ use std::collections::*; use crate::common::aggregate_op::AggregateOp; use crate::common::expr::*; use crate::common::foreign_function::*; +use crate::common::foreign_predicate::*; use crate::common::input_file::InputFile; -use crate::common::input_tag::InputTag; +use crate::common::input_tag::DynamicInputTag; use crate::common::output_option::OutputOption; use crate::common::tuple::{AsTuple, Tuple}; use crate::common::tuple_type::TupleType; +use crate::common::value::Value; #[derive(Debug, Clone)] pub struct Program { pub strata: Vec, pub function_registry: ForeignFunctionRegistry, + pub predicate_registry: ForeignPredicateRegistry, pub relation_to_stratum: HashMap, } @@ -21,14 +24,7 @@ impl Program { Self { strata: Vec::new(), function_registry: ForeignFunctionRegistry::new(), - relation_to_stratum: HashMap::new(), - } - } - - pub fn new_with_function_registry(function_registry: ForeignFunctionRegistry) -> Self { - Self { - strata: Vec::new(), - function_registry, + predicate_registry: ForeignPredicateRegistry::new(), relation_to_stratum: HashMap::new(), } } @@ -160,7 +156,7 @@ impl std::cmp::Ord for Relation { #[derive(Debug, Clone, PartialEq, PartialOrd)] pub struct Fact { - pub tag: InputTag, + pub tag: DynamicInputTag, pub tuple: Tuple, } @@ -194,23 +190,30 @@ pub enum Dataflow { Filter(Box, Expr), Find(Box, Tuple), OverwriteOne(Box), + ForeignPredicateGround(String, Vec), + ForeignPredicateConstraint(Box, String, Vec), + ForeignPredicateJoin(Box, String, Vec), Reduce(Reduce), Relation(String), } impl Dataflow { + /// Create a new unit dataflow given a tuple type pub fn unit(tuple_type: TupleType) -> Self { Self::Unit(tuple_type) } + /// Create a union dataflow pub fn union(self, d2: Dataflow) -> Self { Self::Union(Box::new(self), Box::new(d2)) } + /// Create a join-ed dataflow from two dataflows pub fn join(self, d2: Dataflow) -> Self { Self::Join(Box::new(self), Box::new(d2)) } + /// Create an intersection dataflow from two dataflows pub fn intersect(self, d2: Dataflow) -> Self { Self::Intersect(Box::new(self), Box::new(d2)) } @@ -243,6 +246,14 @@ impl Dataflow { Self::OverwriteOne(Box::new(self)) } + pub fn foreign_predicate_constraint(self, predicate: String, args: Vec) -> Self { + Self::ForeignPredicateConstraint(Box::new(self), predicate, args) + } + + pub fn foreign_predicate_join(self, predicate: String, args: Vec) -> Self { + Self::ForeignPredicateJoin(Box::new(self), predicate, args) + } + pub fn reduce(op: AggregateOp, predicate: S, group_by: ReduceGroupByType) -> Self { Self::Reduce(Reduce { op, @@ -264,9 +275,15 @@ impl Dataflow { | Self::Product(d1, d2) | Self::Antijoin(d1, d2) | Self::Difference(d1, d2) => 
d1.source_relations().union(&d2.source_relations()).cloned().collect(), - Self::Project(d, _) | Self::Filter(d, _) | Self::Find(d, _) | Self::OverwriteOne(d) => d.source_relations(), + Self::Project(d, _) + | Self::Filter(d, _) + | Self::Find(d, _) + | Self::OverwriteOne(d) + | Self::ForeignPredicateConstraint(d, _, _) + | Self::ForeignPredicateJoin(d, _, _) => d.source_relations(), Self::Reduce(r) => std::iter::once(r.source_relation()).collect(), Self::Relation(r) => std::iter::once(r).collect(), + Self::ForeignPredicateGround(_, _) => HashSet::new(), } } } diff --git a/core/src/compiler/ram/dependency.rs b/core/src/compiler/ram/dependency.rs index d3953c8..2fe8601 100644 --- a/core/src/compiler/ram/dependency.rs +++ b/core/src/compiler/ram/dependency.rs @@ -114,6 +114,13 @@ impl Dataflow { d1.collect_dependency(preds); d2.collect_dependency(preds); } + Self::ForeignPredicateGround(_, _) => {} + Self::ForeignPredicateConstraint(d, _, _) => { + d.collect_dependency(preds); + } + Self::ForeignPredicateJoin(d, _, _) => { + d.collect_dependency(preds); + } } } } diff --git a/core/src/compiler/ram/optimizations/project_cascade.rs b/core/src/compiler/ram/optimizations/project_cascade.rs index 92927ef..a057a41 100644 --- a/core/src/compiler/ram/optimizations/project_cascade.rs +++ b/core/src/compiler/ram/optimizations/project_cascade.rs @@ -57,6 +57,11 @@ fn project_cascade_on_dataflow(d0: &mut Dataflow) -> bool { Dataflow::Filter(d, _) => project_cascade_on_dataflow(&mut **d), Dataflow::Find(d, _) => project_cascade_on_dataflow(&mut **d), Dataflow::OverwriteOne(d) => project_cascade_on_dataflow(&mut **d), - Dataflow::Unit(_) | Dataflow::Relation(_) | Dataflow::Reduce(_) => false, + Dataflow::ForeignPredicateConstraint(d, _, _) => project_cascade_on_dataflow(&mut **d), + Dataflow::ForeignPredicateJoin(d, _, _) => project_cascade_on_dataflow(&mut **d), + Dataflow::ForeignPredicateGround(_, _) + | Dataflow::Unit(_) + | Dataflow::Relation(_) + | Dataflow::Reduce(_) => false, } } diff --git a/core/src/compiler/ram/pretty.rs b/core/src/compiler/ram/pretty.rs index 7c499ca..9253076 100644 --- a/core/src/compiler/ram/pretty.rs +++ b/core/src/compiler/ram/pretty.rs @@ -84,6 +84,20 @@ impl Dataflow { r.op, r.predicate, group_by_predicate )) } + Self::ForeignPredicateGround(pred, args) => { + let args = args.iter().map(|a| format!("{:?}", a)).collect::>(); + f.write_fmt(format_args!("ForeignPredicateGround[{}({})]", pred, args.join(", "))) + } + Self::ForeignPredicateConstraint(d, pred, args) => { + let args = args.iter().map(|a| format!("{:?}", a)).collect::>(); + f.write_fmt(format_args!("ForeignPredicateConstraint[{}({})]\n{}", pred, args.join(", "), padding))?; + d.pretty_print(f, next_indent, indent_size) + } + Self::ForeignPredicateJoin(d, pred, args) => { + let args = args.iter().map(|a| format!("{:?}", a)).collect::>(); + f.write_fmt(format_args!("ForeignPredicateJoin[{}({})]\n{}", pred, args.join(", "), padding))?; + d.pretty_print(f, next_indent, indent_size) + } Self::OverwriteOne(d) => { f.write_fmt(format_args!("OverwriteOne\n{}", padding))?; d.pretty_print(f, next_indent, indent_size) diff --git a/core/src/compiler/ram/ram2rs.rs b/core/src/compiler/ram/ram2rs.rs index b45e5b2..c150bd6 100644 --- a/core/src/compiler/ram/ram2rs.rs +++ b/core/src/compiler/ram/ram2rs.rs @@ -346,6 +346,9 @@ impl ast::Dataflow { let rs_d1 = d1.to_rs_dataflow(curr_strat_id, rel_to_strat_map); quote! 
{ dataflow::overwrite_one(#rs_d1) } } + Self::ForeignPredicateGround(_, _) => unimplemented!(), + Self::ForeignPredicateConstraint(_, _, _) => unimplemented!(), + Self::ForeignPredicateJoin(_, _, _) => unimplemented!(), Self::Reduce(r) => { let get_col = |r| { let rel_ident = relation_name_to_rs_field_name(r); @@ -446,6 +449,8 @@ fn value_type_to_rs_type(ty: &ValueType) -> TokenStream { ValueType::Char => quote! { char }, ValueType::Str => quote! { &'static str }, ValueType::String => quote! { String }, + ValueType::DateTime => quote! { DateTime }, + ValueType::Duration => quote! { Duration }, // ValueType::RcString => quote! { Rc }, } } @@ -499,13 +504,13 @@ fn expr_to_rs_expr(expr: &Expr) -> TokenStream { } } -fn input_tag_to_rs_input_tag(tag: &InputTag) -> TokenStream { +fn input_tag_to_rs_input_tag(tag: &DynamicInputTag) -> TokenStream { match tag { - InputTag::None => quote! { InputTag::None }, - InputTag::Exclusive(i) => quote! { InputTag::Exclusive(#i) }, - InputTag::Bool(b) => quote! { InputTag::Bool(#b) }, - InputTag::Float(f) => quote! { InputTag::Float(#f) }, - InputTag::ExclusiveFloat(f, u) => quote! { InputTag::ExclusiveFloat(#f, #u) }, + DynamicInputTag::None => quote! { DynamicInputTag::None }, + DynamicInputTag::Exclusive(i) => quote! { DynamicInputTag::Exclusive(#i) }, + DynamicInputTag::Bool(b) => quote! { DynamicInputTag::Bool(#b) }, + DynamicInputTag::Float(f) => quote! { DynamicInputTag::Float(#f) }, + DynamicInputTag::ExclusiveFloat(f, u) => quote! { DynamicInputTag::ExclusiveFloat(#f, #u) }, } } @@ -562,6 +567,8 @@ fn value_to_rs_value(value: &Value) -> TokenStream { Str(s) => quote! { #s }, String(s) => quote! { String::from(#s) }, // RcString(s) => quote! { Rc::new(String::from(#s)) }, + DateTime(_) => unimplemented!(), + Duration(_) => unimplemented!(), } } diff --git a/core/src/compiler/ram/transform.rs b/core/src/compiler/ram/transform.rs index d5ccb3f..60afaf8 100644 --- a/core/src/compiler/ram/transform.rs +++ b/core/src/compiler/ram/transform.rs @@ -3,6 +3,7 @@ use crate::common::output_option::OutputOption; use super::*; impl Stratum { + /// Set the output option to `default` for all relations in the stratum pub fn output_all(&mut self) { self.relations.iter_mut().for_each(|(_, r)| { r.output = OutputOption::default(); diff --git a/core/src/integrate/context.rs b/core/src/integrate/context.rs index ec6d47e..a276182 100644 --- a/core/src/integrate/context.rs +++ b/core/src/integrate/context.rs @@ -1,9 +1,10 @@ use crate::common::foreign_function::*; -use crate::common::input_tag::*; +use crate::common::foreign_predicate::*; use crate::common::tuple::*; use crate::common::tuple_type::*; + use crate::compiler; -use crate::runtime::database::extensional::ExtensionalDatabase; +use crate::runtime::database::extensional::*; use crate::runtime::database::*; use crate::runtime::dynamic; use crate::runtime::env::*; @@ -270,11 +271,44 @@ impl IntegrateContext { Ok(()) } + /// Register a foreign predicate to the context + pub fn register_foreign_predicate(&mut self, fp: F) -> Result<(), IntegrateError> + where + F: ForeignPredicate + Send + Sync + Clone + 'static, + { + // Add the predicate to front compilation context + self + .front_ctx + .register_foreign_predicate(fp) + .map_err(|e| IntegrateError::Runtime(RuntimeError::ForeignPredicate(e)))?; + + // If goes through, then the front context has changed + self.front_has_changed = true; + + // Return Ok + Ok(()) + } + /// Set the context to be non-incremental anymore pub fn set_non_incremental(&mut self) { 
self.internal.exec_ctx.set_non_incremental(); } + /// Set whether to perform early discard + pub fn set_early_discard(&mut self, early_discard: bool) { + self.internal.runtime_env.set_early_discard(early_discard) + } + + /// Set the iteration limit + pub fn set_iter_limit(&mut self, k: usize) { + self.internal.runtime_env.set_iter_limit(k) + } + + /// Remove the iteration limit + pub fn remove_iter_limit(&mut self) { + self.internal.runtime_env.remove_iter_limit() + } + /// Get a mutable refernce to the Extensional Database (EDB) pub fn edb(&mut self) -> &mut ExtensionalDatabase { &mut self.internal.exec_ctx.edb @@ -331,7 +365,6 @@ impl IntegrateContext { /// Execute the program in its current state, with a limit set on iteration count pub fn run_with_monitor(&mut self, m: &M) -> Result<(), IntegrateError> where - Prov::InputTag: FromInputTag, M: Monitor, { // First compile the code @@ -342,10 +375,7 @@ impl IntegrateContext { } /// Execute the program in its current state, with a limit set on iteration count - pub fn run(&mut self) -> Result<(), IntegrateError> - where - Prov::InputTag: FromInputTag, - { + pub fn run(&mut self) -> Result<(), IntegrateError> { // First compile the code self.compile()?; @@ -394,7 +424,7 @@ impl IntegrateContext { } /// Get the relation output collection of a given relation - pub fn computed_relation(&mut self, relation: &str) -> Option>> { + pub fn computed_relation(&mut self, relation: &str) -> Option>> { self.internal.computed_relation(relation) } @@ -403,7 +433,7 @@ impl IntegrateContext { &mut self, relation: &str, m: &M, - ) -> Option>> + ) -> Option>> where M: Monitor, { @@ -487,11 +517,11 @@ impl InternalIntegrateContext { /// Execute the program in its current state, with a limit set on iteration count pub fn run_with_monitor(&mut self, m: &M) -> Result<(), IntegrateError> where - Prov::InputTag: FromInputTag, M: Monitor, { - // Populate the runtime foreign function registry + // Populate the runtime foreign function/predicate registry self.runtime_env.function_registry = self.ram_program.function_registry.clone(); + self.runtime_env.predicate_registry = self.ram_program.predicate_registry.clone(); // Finally execute the ram self @@ -504,12 +534,10 @@ impl InternalIntegrateContext { } /// Execute the program in its current state, with a limit set on iteration count - pub fn run(&mut self) -> Result<(), IntegrateError> - where - Prov::InputTag: FromInputTag, - { - // Populate the runtime foreign function registry + pub fn run(&mut self) -> Result<(), IntegrateError> { + // Populate the runtime foreign function/predicate registry self.runtime_env.function_registry = self.ram_program.function_registry.clone(); + self.runtime_env.predicate_registry = self.ram_program.predicate_registry.clone(); // Finally execute the ram self @@ -542,7 +570,7 @@ impl InternalIntegrateContext { } /// Get the RC'ed output collection of a given relation - pub fn computed_relation(&mut self, relation: &str) -> Option>> { + pub fn computed_relation(&mut self, relation: &str) -> Option>> { self.exec_ctx.recover(relation, &self.prov_ctx); self.exec_ctx.relation(relation) } @@ -552,7 +580,7 @@ impl InternalIntegrateContext { &mut self, relation: &str, m: &M, - ) -> Option>> { + ) -> Option>> { self.exec_ctx.recover_with_monitor(relation, &self.prov_ctx, m); self.exec_ctx.relation(relation) } diff --git a/core/src/runtime/database/extensional/database.rs b/core/src/runtime/database/extensional/database.rs index df7d710..7df63c8 100644 --- 
a/core/src/runtime/database/extensional/database.rs +++ b/core/src/runtime/database/extensional/database.rs @@ -1,6 +1,6 @@ use std::collections::*; -use crate::common::input_tag::InputTag; +use crate::common::input_tag::*; use crate::common::tuple::*; use crate::common::tuple_type::*; use crate::compiler::ram; @@ -89,7 +89,7 @@ impl ExtensionalDatabase { self.extensional_relations.contains_key(relation) } - pub fn add_dynamic_input_facts(&mut self, relation: &str, facts: Vec<(InputTag, T)>) -> Result<(), DatabaseError> + pub fn add_dynamic_input_facts(&mut self, relation: &str, facts: Vec<(DynamicInputTag, T)>) -> Result<(), DatabaseError> where T: Into, { diff --git a/core/src/runtime/database/extensional/relation.rs b/core/src/runtime/database/extensional/relation.rs index 21285b6..0f44b9b 100644 --- a/core/src/runtime/database/extensional/relation.rs +++ b/core/src/runtime/database/extensional/relation.rs @@ -7,14 +7,14 @@ use crate::runtime::provenance::*; #[derive(Clone, Debug)] pub struct ExtensionalRelation { /// The facts from the program - program_facts: Vec<(InputTag, Tuple)>, + program_facts: Vec<(DynamicInputTag, Tuple)>, /// Whether we have internalized the program facts; we only allow a single /// round of internalization of program facts pub internalized_program_facts: bool, /// Dynamically tagged input facts - dynamic_input: Vec<(InputTag, Tuple)>, + dynamic_input: Vec<(DynamicInputTag, Tuple)>, /// Statically tagged input facts static_input: Vec<(Option, Tuple)>, @@ -54,7 +54,7 @@ impl ExtensionalRelation { pub fn add_program_facts(&mut self, i: I) where - I: Iterator, + I: Iterator, { self.program_facts.extend(i) } @@ -67,7 +67,7 @@ impl ExtensionalRelation { self.static_input.extend(facts.into_iter().map(|tup| (None, tup))) } - pub fn add_dynamic_input_facts(&mut self, facts: Vec<(InputTag, Tuple)>) { + pub fn add_dynamic_input_facts(&mut self, facts: Vec<(DynamicInputTag, Tuple)>) { if !facts.is_empty() { self.internalized = false; } @@ -90,7 +90,7 @@ impl ExtensionalRelation { if !self.program_facts.is_empty() { // Iterate (not drain) the program facts elems.extend(self.program_facts.iter().map(|(tag, tup)| { - let maybe_input_tag: Option = FromInputTag::from_input_tag(&tag); + let maybe_input_tag = StaticInputTag::from_dynamic_input_tag(&tag); let tag = ctx.tagging_optional_fn(maybe_input_tag); DynamicElement::new(tup.clone(), tag) })); @@ -101,7 +101,7 @@ impl ExtensionalRelation { // First internalize dynamic input facts elems.extend(self.dynamic_input.drain(..).map(|(tag, tup)| { - let maybe_input_tag: Option = FromInputTag::from_input_tag(&tag); + let maybe_input_tag = StaticInputTag::from_dynamic_input_tag(&tag); let tag = ctx.tagging_optional_fn(maybe_input_tag); DynamicElement::new(tup, tag) })); @@ -127,7 +127,7 @@ impl ExtensionalRelation { if !self.program_facts.is_empty() { // Iterate (not drain) the program facts elems.extend(self.program_facts.iter().map(|(tag, tup)| { - let maybe_input_tag: Option = FromInputTag::from_input_tag(&tag); + let maybe_input_tag = StaticInputTag::from_dynamic_input_tag(&tag); let tag = ctx.tagging_optional_fn(maybe_input_tag.clone()); // !SPECIAL MONITORING! @@ -142,7 +142,7 @@ impl ExtensionalRelation { // First internalize dynamic input facts elems.extend(self.dynamic_input.drain(..).map(|(tag, tup)| { - let maybe_input_tag: Option = FromInputTag::from_input_tag(&tag); + let maybe_input_tag = StaticInputTag::from_dynamic_input_tag(&tag); let tag = ctx.tagging_optional_fn(maybe_input_tag.clone()); // !SPECIAL MONITORING! 
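In both internalization paths above, a front-end DynamicInputTag is converted in two steps: StaticInputTag::from_dynamic_input_tag narrows it to the provenance's optional input tag, and tagging_optional_fn then turns that into the internal tag stored on the element. A minimal sketch of the flow, assuming a provenance context `prov` whose InputTag is f64 (for example a probability) and assuming the usual Into<Tuple> conversion for small Rust tuples:

// Sketch only; `prov` stands for an assumed probabilistic provenance context.
let tag = DynamicInputTag::Float(0.9);
let tuple: Tuple = ("alice", "bob").into(); // assumed Into<Tuple> impl
let maybe_input_tag: Option<f64> = StaticInputTag::from_dynamic_input_tag(&tag);
let internal_tag = prov.tagging_optional_fn(maybe_input_tag);
let elem = DynamicElement::new(tuple, internal_tag);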
diff --git a/core/src/runtime/database/intentional/database.rs b/core/src/runtime/database/intentional/database.rs index 5dd76da..ade9371 100644 --- a/core/src/runtime/database/intentional/database.rs +++ b/core/src/runtime/database/intentional/database.rs @@ -74,7 +74,7 @@ impl IntentionalDatabase { IntentionalRelation { recovered: true, internal_facts: DynamicCollection::empty(), - recovered_facts: Ptr::new(DynamicOutputCollection::from( + recovered_facts: Ptr::new_rc(DynamicOutputCollection::from( edb_relation .internal .iter() @@ -109,17 +109,17 @@ impl IntentionalDatabase { pub fn get_output_collection_ref(&self, relation: &str) -> Option<&DynamicOutputCollection> { self.intentional_relations.get(relation).and_then(|r| { if r.recovered { - Some(Ptr::get(&r.recovered_facts)) + Some(Ptr::get_rc(&r.recovered_facts)) } else { None } }) } - pub fn get_output_collection(&self, relation: &str) -> Option>> { + pub fn get_output_collection(&self, relation: &str) -> Option>> { self.intentional_relations.get(relation).and_then(|r| { if r.recovered { - Some(Ptr::clone_ptr(&r.recovered_facts)) + Some(Ptr::clone_rc(&r.recovered_facts)) } else { None } diff --git a/core/src/runtime/database/intentional/relation.rs b/core/src/runtime/database/intentional/relation.rs index d51c788..4650a3b 100644 --- a/core/src/runtime/database/intentional/relation.rs +++ b/core/src/runtime/database/intentional/relation.rs @@ -11,7 +11,7 @@ pub struct IntentionalRelation { pub internal_facts: DynamicCollection, /// Recovered facts - pub recovered_facts: Ptr::Pointer>, + pub recovered_facts: Ptr::Rc>, } impl Default for IntentionalRelation { @@ -25,7 +25,7 @@ impl Clone for IntentionalRelation std::fmt::Debug for IntentionalRelati fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("IDBRelation") .field("internal", &self.internal_facts) - .field("recovered", &Ptr::get(&self.recovered_facts)) + .field("recovered", &Ptr::get_rc(&self.recovered_facts)) .finish() } } impl std::fmt::Display for IntentionalRelation { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - Ptr::get(&self.recovered_facts).fmt(f) + Ptr::get_rc(&self.recovered_facts).fmt(f) } } @@ -50,7 +50,7 @@ impl IntentionalRelation { Self { recovered: false, internal_facts: DynamicCollection::empty(), - recovered_facts: Ptr::new(DynamicOutputCollection::empty()), + recovered_facts: Ptr::new_rc(DynamicOutputCollection::empty()), } } @@ -58,7 +58,7 @@ impl IntentionalRelation { Self { recovered: false, internal_facts: collection, - recovered_facts: Ptr::new(DynamicOutputCollection::empty()), + recovered_facts: Ptr::new_rc(DynamicOutputCollection::empty()), } } @@ -66,7 +66,7 @@ impl IntentionalRelation { Self { recovered: true, internal_facts: DynamicCollection::empty(), - recovered_facts: Ptr::new(collection), + recovered_facts: Ptr::new_rc(collection), } } @@ -75,14 +75,14 @@ impl IntentionalRelation { if !self.recovered && !self.internal_facts.is_empty() { if drain { // Add internal facts to recovered facts, and remove the internal facts - Ptr::get_mut(&mut self.recovered_facts).extend(self.internal_facts.drain().map(|elem| { + Ptr::get_rc_mut(&mut self.recovered_facts).extend(self.internal_facts.drain().map(|elem| { let output_tag = ctx.recover_fn(&elem.tag); m.observe_recover(&elem.tuple, &elem.tag, &output_tag); (output_tag, elem.tuple) })); } else { // Add internal facts to recover facts, do not remove the internal facts - Ptr::get_mut(&mut self.recovered_facts).extend(self.internal_facts.iter().map(|elem| 
{ + Ptr::get_rc_mut(&mut self.recovered_facts).extend(self.internal_facts.iter().map(|elem| { let output_tag = ctx.recover_fn(&elem.tag); m.observe_recover(&elem.tuple, &elem.tag, &output_tag); (output_tag, elem.tuple.clone()) @@ -106,13 +106,13 @@ impl IntentionalRelation { // Check if we need to drain the internal facts if drain { // Add internal facts to recovered facts, and remove the internal facts - Ptr::get_mut(&mut self.recovered_facts).extend(self.internal_facts.drain().map(|elem| { + Ptr::get_rc_mut(&mut self.recovered_facts).extend(self.internal_facts.drain().map(|elem| { let output_tag = ctx.recover_fn(&elem.tag); (output_tag, elem.tuple) })); } else { // Add internal facts to recover facts, do not remove the internal facts - Ptr::get_mut(&mut self.recovered_facts).extend(self.internal_facts.iter().map(|elem| { + Ptr::get_rc_mut(&mut self.recovered_facts).extend(self.internal_facts.iter().map(|elem| { let output_tag = ctx.recover_fn(&elem.tag); (output_tag, elem.tuple.clone()) })); diff --git a/core/src/runtime/dynamic/dataflow/batching/batch.rs b/core/src/runtime/dynamic/dataflow/batching/batch.rs index 27669ec..11b22df 100644 --- a/core/src/runtime/dynamic/dataflow/batching/batch.rs +++ b/core/src/runtime/dynamic/dataflow/batching/batch.rs @@ -19,6 +19,8 @@ pub enum DynamicBatch<'a, Prov: Provenance> { Product(DynamicProductBatch<'a, Prov>), Difference(DynamicDifferenceBatch<'a, Prov>), Antijoin(DynamicAntijoinBatch<'a, Prov>), + ForeignPredicateConstraint(ForeignPredicateConstraintBatch<'a, Prov>), + ForeignPredicateJoin(ForeignPredicateJoinBatch<'a, Prov>), } impl<'a, Prov: Provenance> DynamicBatch<'a, Prov> { @@ -30,14 +32,6 @@ impl<'a, Prov: Provenance> DynamicBatch<'a, Prov> { Self::SourceVec(v.into_iter()) } - // pub fn aggregate( - // source: DynamicGroupsIterator, - // agg: DynamicAggregateOp, - // ctx: &'a Prov, - // ) -> Self { - // Self::Aggregation(DynamicAggregationBatch::new(source, agg, ctx)) - // } - pub fn step(&mut self, u: usize) { match self { Self::DynamicRelationStable(s) => s.elem_id += u, @@ -124,6 +118,8 @@ impl<'a, Prov: Provenance> Iterator for DynamicBatch<'a, Prov> { Self::Product(p) => p.next(), Self::Difference(d) => d.next(), Self::Antijoin(a) => a.next(), + Self::ForeignPredicateConstraint(b) => b.next(), + Self::ForeignPredicateJoin(b) => b.next(), } } } diff --git a/core/src/runtime/dynamic/dataflow/batching/batches.rs b/core/src/runtime/dynamic/dataflow/batching/batches.rs index ed3c982..735bbb8 100644 --- a/core/src/runtime/dynamic/dataflow/batching/batches.rs +++ b/core/src/runtime/dynamic/dataflow/batching/batches.rs @@ -16,6 +16,8 @@ pub enum DynamicBatches<'a, Prov: Provenance> { Find(DynamicFindBatches<'a, Prov>), OverwriteOne(DynamicOverwriteOneBatches<'a, Prov>), Binary(DynamicBatchesBinary<'a, Prov>), + ForeignPredicateConstraint(ForeignPredicateConstraintBatches<'a, Prov>), + ForeignPredicateJoin(ForeignPredicateJoinBatches<'a, Prov>), } impl<'a, Prov: Provenance> DynamicBatches<'a, Prov> { @@ -78,6 +80,8 @@ impl<'a, Prov: Provenance> Iterator for DynamicBatches<'a, Prov> { Self::Find(f) => f.next(), Self::OverwriteOne(o) => o.next(), Self::Binary(b) => b.next(), + Self::ForeignPredicateConstraint(b) => b.next(), + Self::ForeignPredicateJoin(b) => b.next(), } } } diff --git a/core/src/runtime/dynamic/dataflow/dynamic_dataflow.rs b/core/src/runtime/dynamic/dataflow/dynamic_dataflow.rs index 40bbe39..316e82c 100644 --- a/core/src/runtime/dynamic/dataflow/dynamic_dataflow.rs +++ b/core/src/runtime/dynamic/dataflow/dynamic_dataflow.rs 
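The three dynamic dataflow variants added in this file mirror the RAM-level constructors introduced in core/src/compiler/ram/ast.rs earlier in this patch, split by binding pattern: ForeignPredicateGround has no upstream dataflow, its bound arguments are constants and the predicate enumerates the remaining columns; ForeignPredicateConstraint evaluates the predicate on arguments fully computable from each upstream tuple and either drops the tuple or keeps it with a re-multiplied tag (see constraint.rs below); ForeignPredicateJoin, by its name and constructor, extends each upstream tuple with bindings for the still-free arguments. A rough sketch using the RAM-level constructors; the predicate names and argument vectors are placeholders, and the real choice among the three is made by the back-end query planner:

// Placeholder predicate names and arguments, for illustration only.
let constraint_args: Vec<Expr> = vec![]; // normally accessors into the upstream tuple
let join_args: Vec<Expr> = vec![];       // normally a mix of accessors and constants
let ground   = Dataflow::ForeignPredicateGround("my_source_fp".to_string(), vec![Value::I32(10)]);
let filtered = Dataflow::Relation("edge".to_string())
  .foreign_predicate_constraint("my_check_fp".to_string(), constraint_args);
let joined   = Dataflow::Relation("edge".to_string())
  .foreign_predicate_join("my_gen_fp".to_string(), join_args);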
@@ -1,6 +1,7 @@ -use crate::common::expr::Expr; -use crate::common::tuple::Tuple; -use crate::common::tuple_type::TupleType; +use crate::common::expr::*; +use crate::common::value::*; +use crate::common::tuple::*; +use crate::common::tuple_type::*; use super::*; @@ -23,6 +24,9 @@ pub enum DynamicDataflow<'a, Prov: Provenance> { Difference(DynamicDifferenceDataflow<'a, Prov>), Antijoin(DynamicAntijoinDataflow<'a, Prov>), Aggregate(DynamicAggregationDataflow<'a, Prov>), + ForeignPredicateGround(ForeignPredicateGroundDataflow<'a, Prov>), + ForeignPredicateConstraint(ForeignPredicateConstraintDataflow<'a, Prov>), + ForeignPredicateJoin(ForeignPredicateJoinDataflow<'a, Prov>), } impl<'a, Prov: Provenance> DynamicDataflow<'a, Prov> { @@ -133,6 +137,48 @@ impl<'a, Prov: Provenance> DynamicDataflow<'a, Prov> { }) } + pub fn foreign_predicate_ground( + pred: String, + bounded: Vec, + first_iter: bool, + ctx: &'a Prov, + ) -> Self { + Self::ForeignPredicateGround(ForeignPredicateGroundDataflow { + foreign_predicate: pred, + bounded_constants: bounded, + first_iteration: first_iter, + ctx, + }) + } + + pub fn foreign_predicate_constraint( + self, + pred: String, + args: Vec, + ctx: &'a Prov, + ) -> Self { + Self::ForeignPredicateConstraint(ForeignPredicateConstraintDataflow { + dataflow: Box::new(self), + foreign_predicate: pred, + args, + ctx, + }) + } + + pub fn foreign_predicate_join( + self, + pred: String, + args: Vec, + ctx: &'a Prov, + ) -> Self { + Self::ForeignPredicateJoin(ForeignPredicateJoinDataflow { + left: Box::new(self), + foreign_predicate: pred, + args, + ctx, + }) + } + pub fn iter_stable(&self, runtime: &'a RuntimeEnvironment) -> DynamicBatches<'a, Prov> { match self { Self::StableUnit(i) => i.iter_stable(runtime), @@ -152,6 +198,9 @@ impl<'a, Prov: Provenance> DynamicDataflow<'a, Prov> { Self::Difference(d) => d.iter_stable(runtime), Self::Antijoin(a) => a.iter_stable(runtime), Self::Aggregate(a) => a.iter_stable(runtime), + Self::ForeignPredicateGround(d) => d.iter_stable(runtime), + Self::ForeignPredicateConstraint(d) => d.iter_stable(runtime), + Self::ForeignPredicateJoin(d) => d.iter_stable(runtime), } } @@ -174,6 +223,9 @@ impl<'a, Prov: Provenance> DynamicDataflow<'a, Prov> { Self::Difference(d) => d.iter_recent(runtime), Self::Antijoin(a) => a.iter_recent(runtime), Self::Aggregate(a) => a.iter_recent(runtime), + Self::ForeignPredicateGround(d) => d.iter_recent(runtime), + Self::ForeignPredicateConstraint(d) => d.iter_recent(runtime), + Self::ForeignPredicateJoin(d) => d.iter_recent(runtime), } } } diff --git a/core/src/runtime/dynamic/dataflow/foreign_predicate/constraint.rs b/core/src/runtime/dynamic/dataflow/foreign_predicate/constraint.rs new file mode 100644 index 0000000..183a2e9 --- /dev/null +++ b/core/src/runtime/dynamic/dataflow/foreign_predicate/constraint.rs @@ -0,0 +1,107 @@ +use crate::common::expr::*; +use crate::common::foreign_predicate::*; +use crate::common::input_tag::*; +use crate::runtime::provenance::*; + +use super::*; + +#[derive(Clone)] +pub struct ForeignPredicateConstraintDataflow<'a, Prov: Provenance> { + /// Sub-dataflow + pub dataflow: Box>, + + /// The foreign predicate + pub foreign_predicate: String, + + /// The arguments to the foreign predicate + pub args: Vec, + + /// Provenance context + pub ctx: &'a Prov, +} + +impl<'a, Prov: Provenance> ForeignPredicateConstraintDataflow<'a, Prov> { + pub fn iter_stable(&self, runtime: &'a RuntimeEnvironment) -> DynamicBatches<'a, Prov> { + let fp = 
runtime.predicate_registry.get(&self.foreign_predicate).expect("Foreign predicate not found"); + DynamicBatches::ForeignPredicateConstraint(ForeignPredicateConstraintBatches { + batches: Box::new(self.dataflow.iter_stable(runtime)), + foreign_predicate: fp.clone(), + args: self.args.clone(), + ctx: self.ctx, + }) + } + + pub fn iter_recent(&self, runtime: &'a RuntimeEnvironment) -> DynamicBatches<'a, Prov> { + let fp = runtime.predicate_registry.get(&self.foreign_predicate).expect("Foreign predicate not found"); + DynamicBatches::ForeignPredicateConstraint(ForeignPredicateConstraintBatches { + batches: Box::new(self.dataflow.iter_recent(runtime)), + foreign_predicate: fp.clone(), + args: self.args.clone(), + ctx: self.ctx, + }) + } +} + +#[derive(Clone)] +pub struct ForeignPredicateConstraintBatches<'a, Prov: Provenance> { + pub batches: Box>, + pub foreign_predicate: DynamicForeignPredicate, + pub args: Vec, + pub ctx: &'a Prov, +} + +impl<'a, Prov: Provenance> Iterator for ForeignPredicateConstraintBatches<'a, Prov> { + type Item = DynamicBatch<'a, Prov>; + + fn next(&mut self) -> Option { + self.batches.next().map(|batch| { + DynamicBatch::ForeignPredicateConstraint(ForeignPredicateConstraintBatch { + batch: Box::new(batch), + foreign_predicate: self.foreign_predicate.clone(), + args: self.args.clone(), + ctx: self.ctx, + }) + }) + } +} + +#[derive(Clone)] +pub struct ForeignPredicateConstraintBatch<'a, Prov: Provenance> { + pub batch: Box>, + pub foreign_predicate: DynamicForeignPredicate, + pub args: Vec, + pub ctx: &'a Prov, +} + +impl<'a, Prov: Provenance> Iterator for ForeignPredicateConstraintBatch<'a, Prov> { + type Item = DynamicElement; + + fn next(&mut self) -> Option { + while let Some(elem) = self.batch.next() { + let Tagged { tuple, tag } = elem; + + // Try evaluate the arguments; if failed, continue to the next element in the batch + let values = self.args.iter().map(|arg| { + match arg { + Expr::Access(a) => tuple[a].as_value(), + Expr::Constant(c) => c.clone(), + _ => panic!("Invalid argument to bounded foreign predicate") + } + }).collect::>(); + + // Evaluate the foreign predicate to produce a list of output tags + // Note that there will be at most one output tag since the foreign predicate is bounded + let result = self.foreign_predicate.evaluate(&values); + + // Check if the foreign predicate returned a tag + if !result.is_empty() { + assert_eq!(result.len(), 1, "Bounded foreign predicate should return at most one element per evaluation"); + let input_tag = Prov::InputTag::from_dynamic_input_tag(&result[0].0); + let new_tag = self.ctx.tagging_optional_fn(input_tag); + let combined_tag = self.ctx.mult(&tag, &new_tag); + return Some(DynamicElement::new(tuple, combined_tag)); + } + } + None + } +} diff --git a/core/src/runtime/dynamic/dataflow/foreign_predicate/ground.rs b/core/src/runtime/dynamic/dataflow/foreign_predicate/ground.rs new file mode 100644 index 0000000..81816c1 --- /dev/null +++ b/core/src/runtime/dynamic/dataflow/foreign_predicate/ground.rs @@ -0,0 +1,65 @@ +use crate::common::foreign_predicate::*; +use crate::common::input_tag::*; +use crate::common::value::*; +use crate::common::tuple::*; +use crate::runtime::provenance::*; +use crate::runtime::env::*; + +use super::*; + +#[derive(Clone, Debug)] +pub struct ForeignPredicateGroundDataflow<'a, Prov: Provenance> { + /// The foreign predicate + pub foreign_predicate: String, + + /// The already bounded constants (in order to make this Dataflow free) + pub bounded_constants: Vec, + + /// Whether this 
Dataflow is running on first iteration + pub first_iteration: bool, + + /// Provenance context + pub ctx: &'a Prov, +} + +impl<'a, Prov: Provenance> ForeignPredicateGroundDataflow<'a, Prov> { + /// Generate a batch from the foreign predicate + fn generate_batch(&self, runtime: &RuntimeEnvironment) -> DynamicBatch<'a, Prov> { + // Fetch the foreign predicate + let foreign_predicate = runtime + .predicate_registry + .get(&self.foreign_predicate) + .expect("Foreign predicate not found"); + + // Evaluate the foreign predicate + let elements = foreign_predicate + .evaluate(&self.bounded_constants) + .into_iter() + .map(|(input_tag, values)| { + let input_tag = StaticInputTag::from_dynamic_input_tag(&input_tag); + let tag = self.ctx.tagging_optional_fn(input_tag); + let tuple = Tuple::from(values); + DynamicElement::new(tuple, tag) + }) + .collect::>(); + DynamicBatch::source_vec(elements) + } +} + +impl<'a, Prov: Provenance> ForeignPredicateGroundDataflow<'a, Prov> { + pub fn iter_stable(&self, runtime: &'a RuntimeEnvironment) -> DynamicBatches<'a, Prov> { + if self.first_iteration { + DynamicBatches::empty() + } else { + DynamicBatches::single(self.generate_batch(runtime)) + } + } + + pub fn iter_recent(&self, runtime: &'a RuntimeEnvironment) -> DynamicBatches<'a, Prov> { + if self.first_iteration { + DynamicBatches::single(self.generate_batch(runtime)) + } else { + DynamicBatches::empty() + } + } +} diff --git a/core/src/runtime/dynamic/dataflow/foreign_predicate/join.rs b/core/src/runtime/dynamic/dataflow/foreign_predicate/join.rs new file mode 100644 index 0000000..837083a --- /dev/null +++ b/core/src/runtime/dynamic/dataflow/foreign_predicate/join.rs @@ -0,0 +1,146 @@ +use crate::common::foreign_predicate::*; +use crate::common::input_tag::*; +use crate::common::expr::*; +use crate::common::tuple::*; +use crate::common::value::*; +use crate::runtime::provenance::*; +use crate::runtime::env::*; + +use super::*; + +pub struct ForeignPredicateJoinDataflow<'a, Prov: Provenance> { + pub left: Box>, + + /// The foreign predicate + pub foreign_predicate: String, + + /// The already bounded constants (in order to make this Dataflow free) + pub args: Vec, + + /// Provenance context + pub ctx: &'a Prov, +} + +impl<'a, Prov: Provenance> Clone for ForeignPredicateJoinDataflow<'a, Prov> { + fn clone(&self) -> Self { + Self { + left: self.left.clone(), + foreign_predicate: self.foreign_predicate.clone(), + args: self.args.clone(), + ctx: self.ctx, + } + } +} + +impl<'a, Prov: Provenance> ForeignPredicateJoinDataflow<'a, Prov> { + pub fn iter_stable(&self, runtime: &'a RuntimeEnvironment) -> DynamicBatches<'a, Prov> { + DynamicBatches::ForeignPredicateJoin(ForeignPredicateJoinBatches { + batches: Box::new(self.left.iter_stable(runtime)), + foreign_predicate: runtime.predicate_registry.get(&self.foreign_predicate).expect("Foreign predicate not found").clone(), + args: self.args.clone(), + ctx: self.ctx, + }) + } + + pub fn iter_recent(&self, runtime: &'a RuntimeEnvironment) -> DynamicBatches<'a, Prov> { + DynamicBatches::ForeignPredicateJoin(ForeignPredicateJoinBatches { + batches: Box::new(self.left.iter_recent(runtime)), + foreign_predicate: runtime.predicate_registry.get(&self.foreign_predicate).expect("Foreign predicate not found").clone(), + args: self.args.clone(), + ctx: self.ctx, + }) + } +} + +#[derive(Clone)] +pub struct ForeignPredicateJoinBatches<'a, Prov: Provenance> { + pub batches: Box>, + pub foreign_predicate: DynamicForeignPredicate, + pub args: Vec, + pub ctx: &'a Prov, +} + +impl<'a, 
Prov: Provenance> Iterator for ForeignPredicateJoinBatches<'a, Prov> { + type Item = DynamicBatch<'a, Prov>; + + fn next(&mut self) -> Option { + // First, try to get a batch from the set of batches + self.batches.next().map(|mut batch| { + // Then, try to get the first element inside of this batch; + // if there is an element, we need to evaluate the foreign predicate and produce a current output batch + let first_output_batch = batch.next().map(|elem| { + eval_foreign_predicate(elem, &self.foreign_predicate, &self.args, self.ctx) + }); + + // Generate a new batch + DynamicBatch::ForeignPredicateJoin(ForeignPredicateJoinBatch { + batch: Box::new(batch), + foreign_predicate: self.foreign_predicate.clone(), + args: self.args.clone(), + current_output_batch: first_output_batch, + ctx: self.ctx, + }) + }) + } +} + +#[derive(Clone)] +pub struct ForeignPredicateJoinBatch<'a, Prov: Provenance> { + pub batch: Box>, + pub foreign_predicate: DynamicForeignPredicate, + pub args: Vec, + pub current_output_batch: Option<(DynamicElement, std::vec::IntoIter>)>, + pub ctx: &'a Prov, +} + +impl<'a, Prov: Provenance> Iterator for ForeignPredicateJoinBatch<'a, Prov> { + type Item = DynamicElement; + + fn next(&mut self) -> Option { + while let Some((left_elem, current_output_batch)) = &mut self.current_output_batch { + if let Some(right_elem) = current_output_batch.next() { + let tuple = (left_elem.tuple.clone(), right_elem.tuple); + let new_tag = self.ctx.mult(&left_elem.tag, &right_elem.tag); + return Some(DynamicElement::new(tuple, new_tag)) + } else { + self.current_output_batch = self.batch.next().map(|elem| { + eval_foreign_predicate(elem, &self.foreign_predicate, &self.args, self.ctx) + }); + } + } + None + } +} + +/// Evaluate the foreign predicate on the given element +fn eval_foreign_predicate( + elem: DynamicElement, + fp: &DynamicForeignPredicate, + args: &Vec, + ctx: &Prov, +) -> (DynamicElement, std::vec::IntoIter>) { + // First get the arguments to pass to the foreign predicate + let args_to_fp: Vec = args.iter().map(|arg| { + match arg { + Expr::Access(a) => elem.tuple[a].as_value(), + Expr::Constant(c) => c.clone(), + _ => panic!("Foreign predicate join only supports constant and access arguments"), + } + }).collect(); + + // Then evaluate the foreign predicate on these arguments + let outputs: Vec<_> = fp.evaluate(&args_to_fp).into_iter().map(|(tag, values)| { + // Make sure to tag the output elements + let input_tag = Prov::InputTag::from_dynamic_input_tag(&tag); + let new_tag = ctx.tagging_optional_fn(input_tag); + + // Generate a tuple from the values produced by the foreign predicate + let tuple = Tuple::from(values); + + // Generate the output element + DynamicElement::new(tuple, new_tag) + }).collect(); + + // Return the input element and output elements pair + (elem, outputs.into_iter()) +} diff --git a/core/src/runtime/dynamic/dataflow/foreign_predicate/mod.rs b/core/src/runtime/dynamic/dataflow/foreign_predicate/mod.rs new file mode 100644 index 0000000..c3022a6 --- /dev/null +++ b/core/src/runtime/dynamic/dataflow/foreign_predicate/mod.rs @@ -0,0 +1,9 @@ +mod constraint; +mod ground; +mod join; + +pub use constraint::*; +pub use ground::*; +pub use join::*; + +use super::*; diff --git a/core/src/runtime/dynamic/dataflow/mod.rs b/core/src/runtime/dynamic/dataflow/mod.rs index c86a4fd..843458b 100644 --- a/core/src/runtime/dynamic/dataflow/mod.rs +++ b/core/src/runtime/dynamic/dataflow/mod.rs @@ -9,6 +9,7 @@ mod dynamic_dataflow; mod dynamic_relation; mod filter; mod find; +mod 
foreign_predicate; mod intersect; mod join; mod overwrite_one; @@ -35,6 +36,7 @@ pub use dynamic_dataflow::*; use dynamic_relation::*; use filter::*; use find::*; +use foreign_predicate::*; use intersect::*; use join::*; use overwrite_one::*; diff --git a/core/src/runtime/dynamic/incremental.rs b/core/src/runtime/dynamic/incremental.rs index 9734c39..63a5778 100644 --- a/core/src/runtime/dynamic/incremental.rs +++ b/core/src/runtime/dynamic/incremental.rs @@ -619,7 +619,7 @@ impl DynamicExecutionContext { self.idb.get_output_collection_ref(r) } - pub fn relation(&self, r: &str) -> Option>> { + pub fn relation(&self, r: &str) -> Option>> { self.idb.get_output_collection(r) } diff --git a/core/src/runtime/dynamic/io.rs b/core/src/runtime/dynamic/io.rs index faf51b0..a2e3f7b 100644 --- a/core/src/runtime/dynamic/io.rs +++ b/core/src/runtime/dynamic/io.rs @@ -3,7 +3,7 @@ use std::fs::File; use std::path::PathBuf; use crate::common::input_file::InputFile; -use crate::common::input_tag::InputTag; +use crate::common::input_tag::DynamicInputTag; use crate::common::output_option::OutputFile; use crate::common::tuple::Tuple; use crate::common::tuple_type::TupleType; @@ -11,7 +11,7 @@ use crate::common::value_type::ValueType; use crate::runtime::error::*; -pub fn load(input_file: &InputFile, types: &TupleType) -> Result, IOError> { +pub fn load(input_file: &InputFile, types: &TupleType) -> Result, IOError> { match input_file { InputFile::Csv { file_path, @@ -29,7 +29,7 @@ pub fn load_csv( has_header: bool, has_probability: bool, types: &TupleType, -) -> Result, IOError> { +) -> Result, IOError> { // First parse the value types let value_types = get_value_types(types)?; @@ -60,10 +60,10 @@ pub fn load_csv( let tag = if has_probability { let s = record.get(0).unwrap(); - s.parse::() + s.parse::() .map_err(|_| IOError::CannotParseProbability { value: s.to_string() })? 
} else { - InputTag::None + DynamicInputTag::None }; let values = record @@ -73,7 +73,7 @@ pub fn load_csv( .map(|(r, t)| t.parse(r).map_err(|e| IOError::ValueParseError { error: e })) .collect::, _>>()?; - let tuple = Tuple::from_primitives(values); + let tuple = Tuple::from(values); result.push((tag, tuple)); } diff --git a/core/src/runtime/dynamic/iteration.rs b/core/src/runtime/dynamic/iteration.rs index 67030b5..1d24cf2 100644 --- a/core/src/runtime/dynamic/iteration.rs +++ b/core/src/runtime/dynamic/iteration.rs @@ -241,6 +241,15 @@ impl<'a, Prov: Provenance> DynamicIteration<'a, Prov> { self.unsafe_get_dynamic_relation(c).into() } } + Dataflow::ForeignPredicateGround(p, a) => { + DynamicDataflow::foreign_predicate_ground(p.clone(), a.clone(), self.is_first_iteration(), ctx) + } + Dataflow::ForeignPredicateConstraint(d, p, a) => { + self.build_dynamic_dataflow(ctx, d).foreign_predicate_constraint(p.clone(), a.clone(), ctx) + } + Dataflow::ForeignPredicateJoin(d, p, a) => { + self.build_dynamic_dataflow(ctx, d).foreign_predicate_join(p.clone(), a.clone(), ctx) + } Dataflow::OverwriteOne(d) => self.build_dynamic_dataflow(ctx, d).overwrite_one(ctx), Dataflow::Filter(d, e) => self.build_dynamic_dataflow(ctx, d).filter(e.clone()), Dataflow::Find(d, k) => self.build_dynamic_dataflow(ctx, d).find(k.clone()), diff --git a/core/src/runtime/dynamic/relation.rs b/core/src/runtime/dynamic/relation.rs index 35614f0..9f7fc64 100644 --- a/core/src/runtime/dynamic/relation.rs +++ b/core/src/runtime/dynamic/relation.rs @@ -4,8 +4,7 @@ use std::rc::Rc; use super::dataflow::*; use super::*; -use crate::common::input_tag::FromInputTag; -use crate::common::input_tag::InputTag; +use crate::common::input_tag::*; use crate::common::tuple::Tuple; use crate::runtime::env::*; use crate::runtime::monitor::*; @@ -65,21 +64,21 @@ impl DynamicRelation { self.insert_tagged_with_monitor(ctx, vec![(input_tag, tuple)], m); } - pub fn insert_dynamically_tagged(&self, ctx: &mut Prov, data: Vec<(InputTag, Tup)>) + pub fn insert_dynamically_tagged(&self, ctx: &mut Prov, data: Vec<(DynamicInputTag, Tup)>) where Tup: Into, { let elements = data .into_iter() .map(|(tag, tup)| { - let input_tag = FromInputTag::from_input_tag(&tag); + let input_tag = StaticInputTag::from_dynamic_input_tag(&tag); (input_tag, tup) }) .collect(); self.insert_tagged(ctx, elements); } - pub fn insert_dynamically_tagged_with_monitor(&self, ctx: &mut Prov, data: Vec<(InputTag, Tup)>, m: &M) + pub fn insert_dynamically_tagged_with_monitor(&self, ctx: &mut Prov, data: Vec<(DynamicInputTag, Tup)>, m: &M) where Tup: Into, M: Monitor, @@ -87,7 +86,7 @@ impl DynamicRelation { let elements = data .into_iter() .map(|(tag, tup)| { - let input_tag = FromInputTag::from_input_tag(&tag); + let input_tag = StaticInputTag::from_dynamic_input_tag(&tag); (input_tag, tup) }) .collect(); diff --git a/core/src/runtime/env/environment.rs b/core/src/runtime/env/environment.rs index 1f5093e..afc2bc0 100644 --- a/core/src/runtime/env/environment.rs +++ b/core/src/runtime/env/environment.rs @@ -6,6 +6,7 @@ use rand::SeedableRng; use crate::common::constants::*; use crate::common::expr::*; use crate::common::foreign_function::*; +use crate::common::foreign_predicate::*; use crate::common::tuple::*; use crate::common::value_type::*; @@ -25,22 +26,26 @@ pub struct RuntimeEnvironment { /// Foreign function registry pub function_registry: ForeignFunctionRegistry, + + /// Foreign predicate registry + pub predicate_registry: ForeignPredicateRegistry, } impl Default for 
RuntimeEnvironment { fn default() -> Self { - Self::new() + Self::new_std() } } impl RuntimeEnvironment { - pub fn new() -> Self { + pub fn new_std() -> Self { Self { random_seed: DEFAULT_RANDOM_SEED, rng: Arc::new(Mutex::new(SmallRng::seed_from_u64(DEFAULT_RANDOM_SEED))), early_discard: true, iter_limit: None, function_registry: ForeignFunctionRegistry::std(), + predicate_registry: ForeignPredicateRegistry::std(), } } @@ -51,6 +56,21 @@ impl RuntimeEnvironment { early_discard: true, iter_limit: None, function_registry: ForeignFunctionRegistry::std(), + predicate_registry: ForeignPredicateRegistry::std(), + } + } + + pub fn new( + ffr: ForeignFunctionRegistry, + fpr: ForeignPredicateRegistry, + ) -> Self { + Self { + random_seed: DEFAULT_RANDOM_SEED, + rng: Arc::new(Mutex::new(SmallRng::seed_from_u64(DEFAULT_RANDOM_SEED))), + early_discard: true, + iter_limit: None, + function_registry: ffr, + predicate_registry: fpr, } } @@ -61,9 +81,22 @@ impl RuntimeEnvironment { early_discard: true, iter_limit: None, function_registry: ffr, + predicate_registry: ForeignPredicateRegistry::std(), } } + pub fn set_early_discard(&mut self, early_discard: bool) { + self.early_discard = early_discard + } + + pub fn set_iter_limit(&mut self, k: usize) { + self.iter_limit = Some(k); + } + + pub fn remove_iter_limit(&mut self) { + self.iter_limit = None; + } + pub fn eval(&self, expr: &Expr, tuple: &Tuple) -> Option { match expr { Expr::Tuple(t) => Some(Tuple::Tuple( @@ -103,6 +136,10 @@ impl RuntimeEnvironment { (Add, Tuple::Value(USize(i1)), Tuple::Value(USize(i2))) => Tuple::Value(USize(i1 + i2)), (Add, Tuple::Value(F32(i1)), Tuple::Value(F32(i2))) => Tuple::Value(F32(i1 + i2)), (Add, Tuple::Value(F64(i1)), Tuple::Value(F64(i2))) => Tuple::Value(F64(i1 + i2)), + (Add, Tuple::Value(String(s1)), Tuple::Value(String(s2))) => Tuple::Value(String(format!("{}{}", s1, s2))), + (Add, Tuple::Value(DateTime(i1)), Tuple::Value(Duration(i2))) => Tuple::Value(DateTime(i1 + i2)), + (Add, Tuple::Value(Duration(i1)), Tuple::Value(DateTime(i2))) => Tuple::Value(DateTime(i2 + i1)), + (Add, Tuple::Value(Duration(i1)), Tuple::Value(Duration(i2))) => Tuple::Value(Duration(i1 + i2)), (Add, b1, b2) => panic!("Cannot perform ADD on {:?} and {:?}", b1, b2), // Subtraction @@ -120,6 +157,9 @@ impl RuntimeEnvironment { (Sub, Tuple::Value(USize(i1)), Tuple::Value(USize(i2))) => Tuple::Value(USize(i1 - i2)), (Sub, Tuple::Value(F32(i1)), Tuple::Value(F32(i2))) => Tuple::Value(F32(i1 - i2)), (Sub, Tuple::Value(F64(i1)), Tuple::Value(F64(i2))) => Tuple::Value(F64(i1 - i2)), + (Sub, Tuple::Value(DateTime(i1)), Tuple::Value(Duration(i2))) => Tuple::Value(DateTime(i1 - i2)), + (Sub, Tuple::Value(DateTime(i1)), Tuple::Value(DateTime(i2))) => Tuple::Value(Duration(i1 - i2)), + (Sub, Tuple::Value(Duration(i1)), Tuple::Value(Duration(i2))) =>Tuple::Value(Duration(i1 - i2)), (Sub, b1, b2) => panic!("Cannot perform SUB on {:?} and {:?}", b1, b2), // Multiplication @@ -137,6 +177,8 @@ impl RuntimeEnvironment { (Mul, Tuple::Value(USize(i1)), Tuple::Value(USize(i2))) => Tuple::Value(USize(i1 * i2)), (Mul, Tuple::Value(F32(i1)), Tuple::Value(F32(i2))) => Tuple::Value(F32(i1 * i2)), (Mul, Tuple::Value(F64(i1)), Tuple::Value(F64(i2))) => Tuple::Value(F64(i1 * i2)), + (Mul, Tuple::Value(Duration(i1)), Tuple::Value(I32(i2))) => Tuple::Value(Duration(i1 * i2)), + (Mul, Tuple::Value(I32(i1)), Tuple::Value(Duration(i2))) => Tuple::Value(Duration(i2 * i1)), (Mul, b1, b2) => panic!("Cannot perform MUL on {:?} and {:?}", b1, b2), // Division @@ -152,8 +194,23 @@ 
impl RuntimeEnvironment { (Div, Tuple::Value(U64(i1)), Tuple::Value(U64(i2))) => Tuple::Value(U64(i1 / i2)), (Div, Tuple::Value(U128(i1)), Tuple::Value(U128(i2))) => Tuple::Value(U128(i1 / i2)), (Div, Tuple::Value(USize(i1)), Tuple::Value(USize(i2))) => Tuple::Value(USize(i1 / i2)), - (Div, Tuple::Value(F32(i1)), Tuple::Value(F32(i2))) => Tuple::Value(F32(i1 / i2)), - (Div, Tuple::Value(F64(i1)), Tuple::Value(F64(i2))) => Tuple::Value(F64(i1 / i2)), + (Div, Tuple::Value(F32(i1)), Tuple::Value(F32(i2))) => { + let r = i1 / i2; + if r.is_nan() { + return None; + } else { + Tuple::Value(F32(r)) + } + }, + (Div, Tuple::Value(F64(i1)), Tuple::Value(F64(i2))) => { + let r = i1 / i2; + if r.is_nan() { + return None; + } else { + Tuple::Value(F64(r)) + } + }, + (Div, Tuple::Value(Duration(i1)), Tuple::Value(I32(i2))) => Tuple::Value(Duration(i1 / i2)), (Div, b1, b2) => panic!("Cannot perform DIV on {:?} and {:?}", b1, b2), // Mod @@ -199,6 +256,8 @@ impl RuntimeEnvironment { (Eq, Tuple::Value(Str(i1)), Tuple::Value(Str(i2))) => Tuple::Value(Bool(i1 == i2)), (Eq, Tuple::Value(String(i1)), Tuple::Value(String(i2))) => Tuple::Value(Bool(i1 == i2)), // (Eq, Tuple::Value(RcString(i1)), Tuple::Value(RcString(i2))) => Tuple::Value(Bool(i1 == i2)), + (Eq, Tuple::Value(DateTime(i1)), Tuple::Value(DateTime(i2))) => Tuple::Value(Bool(i1 == i2)), + (Eq, Tuple::Value(Duration(i1)), Tuple::Value(Duration(i2))) => Tuple::Value(Bool(i1 == i2)), (Eq, b1, b2) => panic!("Cannot perform EQ on {:?} and {:?}", b1, b2), // Not equal to @@ -221,6 +280,8 @@ impl RuntimeEnvironment { (Neq, Tuple::Value(Str(i1)), Tuple::Value(Str(i2))) => Tuple::Value(Bool(i1 != i2)), (Neq, Tuple::Value(String(i1)), Tuple::Value(String(i2))) => Tuple::Value(Bool(i1 != i2)), // (Neq, Tuple::Value(RcString(i1)), Tuple::Value(RcString(i2))) => Tuple::Value(Bool(i1 != i2)), + (Neq, Tuple::Value(DateTime(i1)), Tuple::Value(DateTime(i2))) => Tuple::Value(Bool(i1 != i2)), + (Neq, Tuple::Value(Duration(i1)), Tuple::Value(Duration(i2))) => Tuple::Value(Bool(i1 != i2)), (Neq, b1, b2) => panic!("Cannot perform NEQ on {:?} and {:?}", b1, b2), // Greater than @@ -238,6 +299,8 @@ impl RuntimeEnvironment { (Gt, Tuple::Value(USize(i1)), Tuple::Value(USize(i2))) => Tuple::Value(Bool(i1 > i2)), (Gt, Tuple::Value(F32(i1)), Tuple::Value(F32(i2))) => Tuple::Value(Bool(i1 > i2)), (Gt, Tuple::Value(F64(i1)), Tuple::Value(F64(i2))) => Tuple::Value(Bool(i1 > i2)), + (Gt, Tuple::Value(DateTime(i1)), Tuple::Value(DateTime(i2))) => Tuple::Value(Bool(i1 > i2)), + (Gt, Tuple::Value(Duration(i1)), Tuple::Value(Duration(i2))) => Tuple::Value(Bool(i1 > i2)), (Gt, b1, b2) => panic!("Cannot perform GT on {:?} and {:?}", b1, b2), // Greater than or equal to @@ -255,6 +318,8 @@ impl RuntimeEnvironment { (Geq, Tuple::Value(USize(i1)), Tuple::Value(USize(i2))) => Tuple::Value(Bool(i1 >= i2)), (Geq, Tuple::Value(F32(i1)), Tuple::Value(F32(i2))) => Tuple::Value(Bool(i1 >= i2)), (Geq, Tuple::Value(F64(i1)), Tuple::Value(F64(i2))) => Tuple::Value(Bool(i1 >= i2)), + (Geq, Tuple::Value(DateTime(i1)), Tuple::Value(DateTime(i2))) => Tuple::Value(Bool(i1 >= i2)), + (Geq, Tuple::Value(Duration(i1)), Tuple::Value(Duration(i2))) => Tuple::Value(Bool(i1 >= i2)), (Geq, b1, b2) => panic!("Cannot perform GEQ on {:?} and {:?}", b1, b2), // Less than @@ -272,6 +337,8 @@ impl RuntimeEnvironment { (Lt, Tuple::Value(USize(i1)), Tuple::Value(USize(i2))) => Tuple::Value(Bool(i1 < i2)), (Lt, Tuple::Value(F32(i1)), Tuple::Value(F32(i2))) => Tuple::Value(Bool(i1 < i2)), (Lt, Tuple::Value(F64(i1)), 
Tuple::Value(F64(i2))) => Tuple::Value(Bool(i1 < i2)), + (Lt, Tuple::Value(DateTime(i1)), Tuple::Value(DateTime(i2))) => Tuple::Value(Bool(i1 < i2)), + (Lt, Tuple::Value(Duration(i1)), Tuple::Value(Duration(i2))) => Tuple::Value(Bool(i1 < i2)), (Lt, b1, b2) => panic!("Cannot perform LT on {:?} and {:?}", b1, b2), // Less than or equal to @@ -289,6 +356,8 @@ impl RuntimeEnvironment { (Leq, Tuple::Value(USize(i1)), Tuple::Value(USize(i2))) => Tuple::Value(Bool(i1 <= i2)), (Leq, Tuple::Value(F32(i1)), Tuple::Value(F32(i2))) => Tuple::Value(Bool(i1 <= i2)), (Leq, Tuple::Value(F64(i1)), Tuple::Value(F64(i2))) => Tuple::Value(Bool(i1 <= i2)), + (Leq, Tuple::Value(DateTime(i1)), Tuple::Value(DateTime(i2))) => Tuple::Value(Bool(i1 <= i2)), + (Leq, Tuple::Value(Duration(i1)), Tuple::Value(Duration(i2))) => Tuple::Value(Bool(i1 <= i2)), (Leq, b1, b2) => panic!("Cannot perform LEQ on {:?} and {:?}", b1, b2), }; Some(result) diff --git a/core/src/runtime/env/options.rs b/core/src/runtime/env/options.rs index b0a657d..f59dd83 100644 --- a/core/src/runtime/env/options.rs +++ b/core/src/runtime/env/options.rs @@ -6,6 +6,7 @@ use rand::SeedableRng; use super::*; use crate::common::constants::*; use crate::common::foreign_function::*; +use crate::common::foreign_predicate::*; /// The options to create a runtime environment #[derive(Clone, Debug)] @@ -39,6 +40,7 @@ impl RuntimeEnvironmentOptions { early_discard: self.early_discard, iter_limit: self.iter_limit, function_registry: ForeignFunctionRegistry::std(), + predicate_registry: ForeignPredicateRegistry::std(), } } } diff --git a/core/src/runtime/error/error.rs b/core/src/runtime/error/error.rs index 473b52c..915fb11 100644 --- a/core/src/runtime/error/error.rs +++ b/core/src/runtime/error/error.rs @@ -1,11 +1,13 @@ use super::io::IOError; use crate::common::foreign_function::ForeignFunctionError; +use crate::common::foreign_predicate::ForeignPredicateError; use crate::runtime::database::DatabaseError; #[derive(Clone, Debug)] pub enum RuntimeError { IO(IOError), ForeignFunction(ForeignFunctionError), + ForeignPredicate(ForeignPredicateError), Database(DatabaseError), } @@ -14,6 +16,7 @@ impl std::fmt::Display for RuntimeError { match self { Self::IO(e) => e.fmt(f), Self::ForeignFunction(e) => e.fmt(f), + Self::ForeignPredicate(e) => e.fmt(f), Self::Database(e) => e.fmt(f), } } diff --git a/core/src/runtime/provenance/common/diff_prob_storage.rs b/core/src/runtime/provenance/common/diff_prob_storage.rs new file mode 100644 index 0000000..029d416 --- /dev/null +++ b/core/src/runtime/provenance/common/diff_prob_storage.rs @@ -0,0 +1,77 @@ +use crate::utils::*; + +/// The differentiable probability storage that offers interior mutability +pub struct DiffProbStorage { + pub storage: P::RcCell)>>, + pub num_requires_grad: P::Cell, +} + +impl DiffProbStorage { + pub fn new() -> Self { + Self { + storage: P::new_rc_cell(Vec::new()), + num_requires_grad: P::new_cell(0), + } + } + + /// Clone the internal storage + pub fn clone_internal(&self) -> Self { + Self { + storage: P::new_rc_cell(P::get_rc_cell(&self.storage, |s| s.clone())), + num_requires_grad: P::clone_cell(&self.num_requires_grad), + } + } + + /// Clone the reference counter + pub fn clone_rc(&self) -> Self { + Self { + storage: P::clone_rc_cell(&self.storage), + num_requires_grad: P::clone_cell(&self.num_requires_grad), + } + } + + pub fn add_prob(&self, prob: f64, external_tag: Option) -> usize { + // Store the fact id + let fact_id = P::get_rc_cell(&self.storage, |s| s.len()); + + // Increment the 
`num_requires_grad` if the external tag is provided + if external_tag.is_some() { + P::get_cell_mut(&self.num_requires_grad, |n| *n += 1); + } + + // Push this element into the storage + P::get_rc_cell_mut(&self.storage, |s| s.push((prob, external_tag))); + + // Return the id + fact_id + } + + pub fn get_diff_prob(&self, id: &usize) -> (f64, Option) { + P::get_rc_cell(&self.storage, |d| d[id.clone()].clone()) + } + + pub fn get_prob(&self, id: &usize) -> f64 { + P::get_rc_cell(&self.storage, |d| d[id.clone()].0) + } + + pub fn input_tags(&self) -> Vec { + P::get_rc_cell(&self.storage, |s| { + s.iter().filter_map(|(_, t)| t.clone()).collect() + }) + } + + pub fn num_input_tags(&self) -> usize { + P::get_cell(&self.num_requires_grad, |i| *i) + } + + pub fn fact_probability(&self, id: &usize) -> f64 { + P::get_rc_cell(&self.storage, |d| d[*id].0) + } +} + +impl Clone for DiffProbStorage { + /// Clone the reference counter of this storage (shallow copy) + fn clone(&self) -> Self { + self.clone_rc() + } +} diff --git a/core/src/runtime/provenance/common/dual_number.rs b/core/src/runtime/provenance/common/dual_number.rs index 20090ea..de0bd1a 100644 --- a/core/src/runtime/provenance/common/dual_number.rs +++ b/core/src/runtime/provenance/common/dual_number.rs @@ -22,6 +22,14 @@ impl DualNumberSemiring { deriv: CsVec::new(self.dim, vec![id], vec![1.0]), } } + + /// Create a constant dual number + pub fn constant(&self, real: f64) -> DualNumber { + DualNumber { + real, + deriv: CsVec::new(self.dim, vec![], vec![]), + } + } } impl sdd::Semiring for DualNumberSemiring { diff --git a/core/src/runtime/provenance/common/dual_number_2.rs b/core/src/runtime/provenance/common/dual_number_2.rs index 0e96639..5b64ba7 100644 --- a/core/src/runtime/provenance/common/dual_number_2.rs +++ b/core/src/runtime/provenance/common/dual_number_2.rs @@ -39,6 +39,13 @@ impl DualNumber2 { } } + pub fn constant(real: f64) -> Self { + Self { + real, + gradient: Gradient::empty(), + } + } + pub fn clamp_real(&mut self) { self.real = self.real.clamp(0.0, 1.0); } diff --git a/core/src/runtime/provenance/common/input_diff_prob.rs b/core/src/runtime/provenance/common/input_diff_prob.rs index 12d9a48..e5a37a6 100644 --- a/core/src/runtime/provenance/common/input_diff_prob.rs +++ b/core/src/runtime/provenance/common/input_diff_prob.rs @@ -1,5 +1,16 @@ +use crate::common::input_tag::*; + +/// An input differentiable probability. +/// +/// It contains two elements. +/// The first is an `f64` which represents the probability of the tag. +/// The second is an `Option` which is the original differentiable object. +/// Note that if the second element is provided as `None` then it means we +/// do not treat the object as differentiable and thus we do not need to +/// back-propagate gradients into it. +/// In such case the probability is treated as a constant. 
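A concrete instance may make the comment above easier to follow. The sketch below mirrors the struct defined right after this comment, with `TensorHandle` as a hypothetical stand-in for the external differentiable object `T`; it is illustrative only and not part of the patch.

// Hypothetical stand-in for an external differentiable object (e.g. a tensor handle).
#[derive(Clone, Debug)]
struct TensorHandle(usize);

// Simplified mirror of the input probability described above: a probability plus an
// optional external object; `None` marks the probability as a non-differentiable constant.
#[derive(Clone, Debug)]
struct InputDiffProb<T: Clone>(f64, Option<T>);

fn main() {
  // Differentiable input: gradients should flow back into TensorHandle(0).
  let learned = InputDiffProb(0.9, Some(TensorHandle(0)));
  // Constant input: no external object, so nothing to back-propagate into.
  let constant = InputDiffProb(0.3, None::<TensorHandle>);
  println!("{:?} {:?}", learned, constant);
}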
#[derive(Clone)] -pub struct InputDiffProb(pub f64, pub T); +pub struct InputDiffProb(pub f64, pub Option); impl std::fmt::Debug for InputDiffProb { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -7,8 +18,18 @@ impl std::fmt::Debug for InputDiffProb { } } -impl From<(f64, T)> for InputDiffProb { - fn from((p, t): (f64, T)) -> Self { +impl From<(f64, Option)> for InputDiffProb { + fn from((p, t): (f64, Option)) -> Self { Self(p, t) } } + +impl StaticInputTag for InputDiffProb { + fn from_dynamic_input_tag(t: &DynamicInputTag) -> Option { + match t { + DynamicInputTag::ExclusiveFloat(f, _) => Some(Self(f.clone(), None)), + DynamicInputTag::Float(f) => Some(Self(f.clone(), None)), + _ => None, + } + } +} diff --git a/core/src/runtime/provenance/common/input_exclusive_diff_prob.rs b/core/src/runtime/provenance/common/input_exclusive_diff_prob.rs index b1915ec..bdef489 100644 --- a/core/src/runtime/provenance/common/input_exclusive_diff_prob.rs +++ b/core/src/runtime/provenance/common/input_exclusive_diff_prob.rs @@ -1,10 +1,12 @@ +use crate::common::input_tag::*; + #[derive(Clone)] pub struct InputExclusiveDiffProb { /// The probability of the tag pub prob: f64, /// The external tag for differentiability - pub tag: T, + pub external_tag: Option, /// An optional identifier of the mutual exclusion pub exclusion: Option, @@ -12,7 +14,7 @@ pub struct InputExclusiveDiffProb { impl InputExclusiveDiffProb { pub fn new(prob: f64, tag: T, exclusion: Option) -> Self { - Self { prob, tag, exclusion } + Self { prob, external_tag: Some(tag), exclusion } } } @@ -24,6 +26,18 @@ impl std::fmt::Debug for InputExclusiveDiffProb { impl From<(f64, T, Option)> for InputExclusiveDiffProb { fn from((prob, tag, exclusion): (f64, T, Option)) -> Self { - Self { prob, tag, exclusion } + Self { prob, external_tag: Some(tag), exclusion } + } +} + +impl StaticInputTag for InputExclusiveDiffProb { + fn from_dynamic_input_tag(t: &DynamicInputTag) -> Option { + match t { + DynamicInputTag::None => None, + DynamicInputTag::Bool(b) => Some(Self { prob: if *b { 1.0 } else { 0.0 }, external_tag: None, exclusion: None }), + DynamicInputTag::Exclusive(i) => Some(Self { prob: 1.0, external_tag: None, exclusion: Some(i.clone()) }), + DynamicInputTag::Float(prob) => Some(Self { prob: prob.clone(), external_tag: None, exclusion: None }), + DynamicInputTag::ExclusiveFloat(prob, i) => Some(Self { prob: prob.clone(), external_tag: None, exclusion: Some(i.clone()) }), + } } } diff --git a/core/src/runtime/provenance/common/input_exclusive_prob.rs b/core/src/runtime/provenance/common/input_exclusive_prob.rs index bcb2737..33effb6 100644 --- a/core/src/runtime/provenance/common/input_exclusive_prob.rs +++ b/core/src/runtime/provenance/common/input_exclusive_prob.rs @@ -45,8 +45,8 @@ impl From<(f64, usize)> for InputExclusiveProb { } } -impl FromInputTag for InputExclusiveProb { - fn from_input_tag(t: &DynamicInputTag) -> Option { +impl StaticInputTag for InputExclusiveProb { + fn from_dynamic_input_tag(t: &DynamicInputTag) -> Option { match t { DynamicInputTag::Float(f) => Some(Self::new(f.clone(), None)), DynamicInputTag::ExclusiveFloat(f, id) => Some(Self::new(f.clone(), Some(id.clone()))), diff --git a/core/src/runtime/provenance/common/mod.rs b/core/src/runtime/provenance/common/mod.rs index a5d0845..7b5cfa7 100644 --- a/core/src/runtime/provenance/common/mod.rs +++ b/core/src/runtime/provenance/common/mod.rs @@ -3,6 +3,7 @@ mod chosen_elements; mod clause; mod cnf_dnf_context; mod cnf_dnf_formula; +mod 
diff_prob_storage; mod disjunction; mod dnf_context; mod dnf_formula; @@ -22,6 +23,7 @@ pub use chosen_elements::*; pub use clause::*; pub use cnf_dnf_context::*; pub use cnf_dnf_formula::*; +pub use diff_prob_storage::*; pub use disjunction::*; pub use dnf_context::*; pub use dnf_formula::*; diff --git a/core/src/runtime/provenance/common/output_diff_prob.rs b/core/src/runtime/provenance/common/output_diff_prob.rs index 7ca9403..c6c6b9e 100644 --- a/core/src/runtime/provenance/common/output_diff_prob.rs +++ b/core/src/runtime/provenance/common/output_diff_prob.rs @@ -1,20 +1,20 @@ #[derive(Clone)] -pub struct OutputDiffProb(pub f64, pub Vec<(usize, f64, T)>); +pub struct OutputDiffProb(pub f64, pub Vec<(usize, f64)>); -impl std::fmt::Debug for OutputDiffProb { +impl std::fmt::Debug for OutputDiffProb { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_tuple("") .field(&self.0) - .field(&self.1.iter().map(|(id, weight, _)| (id, weight)).collect::>()) + .field(&self.1.iter().map(|(id, weight)| (id, weight)).collect::>()) .finish() } } -impl std::fmt::Display for OutputDiffProb { +impl std::fmt::Display for OutputDiffProb { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_tuple("") .field(&self.0) - .field(&self.1.iter().map(|(id, weight, _)| (id, weight)).collect::>()) + .field(&self.1.iter().map(|(id, weight)| (id, weight)).collect::>()) .finish() } } diff --git a/core/src/runtime/provenance/differentiable/diff_add_mult_prob.rs b/core/src/runtime/provenance/differentiable/diff_add_mult_prob.rs index 29afc45..b0dc877 100644 --- a/core/src/runtime/provenance/differentiable/diff_add_mult_prob.rs +++ b/core/src/runtime/provenance/differentiable/diff_add_mult_prob.rs @@ -7,24 +7,22 @@ use crate::runtime::statics::*; use crate::utils::PointerFamily; pub struct DiffAddMultProbProvenance { - pub warned_disjunction: bool, pub valid_threshold: f64, - pub storage: P::Pointer>, + pub storage: P::RcCell>, } impl Clone for DiffAddMultProbProvenance { fn clone(&self) -> Self { Self { - warned_disjunction: self.warned_disjunction, valid_threshold: self.valid_threshold, - storage: P::new((&*self.storage).clone()), + storage: P::new_rc_cell(P::get_rc_cell(&self.storage, |s| s.clone())), } } } impl DiffAddMultProbProvenance { pub fn input_tags(&self) -> Vec { - self.storage.iter().cloned().collect() + P::get_rc_cell(&self.storage, |s| s.clone()) } pub fn tag_of_chosen_set(&self, all: &Vec, chosen_ids: &Vec) -> DualNumber2 @@ -48,9 +46,8 @@ impl DiffAddMultProbProvenance { impl Default for DiffAddMultProbProvenance { fn default() -> Self { Self { - warned_disjunction: false, valid_threshold: 0.0000, - storage: P::new(Vec::new()), + storage: P::new_rc_cell(Vec::new()), } } } @@ -60,17 +57,21 @@ impl Provenance for DiffAddMultProbProvena type InputTag = InputDiffProb; - type OutputTag = OutputDiffProb; + type OutputTag = OutputDiffProb; fn name() -> &'static str { "diffaddmultprob" } - fn tagging_fn(&mut self, input_tag: Self::InputTag) -> Self::Tag { + fn tagging_fn(&self, input_tag: Self::InputTag) -> Self::Tag { let InputDiffProb(p, t) = input_tag; - let pos_id = self.storage.len(); - P::get_mut(&mut self.storage).push(t); - DualNumber2::new(pos_id, p) + if let Some(external_input_tag) = t { + let pos_id = P::get_rc_cell(&self.storage, |s| s.len()); + P::get_rc_cell_mut(&self.storage, |s| s.push(external_input_tag)); + DualNumber2::new(pos_id, p) + } else { + DualNumber2::constant(p) + } } fn recover_fn(&self, p: &Self::Tag) -> Self::OutputTag { @@ -80,7 
+81,7 @@ impl Provenance for DiffAddMultProbProvena .indices .iter() .zip(p.gradient.values.iter()) - .map(|(i, v)| (*i, *v, self.storage[*i].clone())) + .map(|(i, v)| (*i, *v)) .collect::>(); OutputDiffProb(prob, deriv) } diff --git a/core/src/runtime/provenance/differentiable/diff_max_mult_prob.rs b/core/src/runtime/provenance/differentiable/diff_max_mult_prob.rs index aac431e..a827dfc 100644 --- a/core/src/runtime/provenance/differentiable/diff_max_mult_prob.rs +++ b/core/src/runtime/provenance/differentiable/diff_max_mult_prob.rs @@ -7,24 +7,22 @@ use crate::runtime::statics::*; use crate::utils::PointerFamily; pub struct DiffMaxMultProbProvenance { - pub warned_disjunction: bool, pub valid_threshold: f64, - pub storage: P::Pointer>, + pub storage: P::RcCell>, } impl Clone for DiffMaxMultProbProvenance { fn clone(&self) -> Self { Self { - warned_disjunction: self.warned_disjunction, valid_threshold: self.valid_threshold, - storage: P::new((&*self.storage).clone()), + storage: P::new_rc_cell(P::get_rc_cell(&self.storage, |s| s.clone())), } } } impl DiffMaxMultProbProvenance { pub fn input_tags(&self) -> Vec { - self.storage.iter().cloned().collect() + P::get_rc_cell(&self.storage, |s| s.clone()) } pub fn tag_of_chosen_set(&self, all: &Vec, chosen_ids: &Vec) -> DualNumber2 @@ -48,9 +46,8 @@ impl DiffMaxMultProbProvenance { impl Default for DiffMaxMultProbProvenance { fn default() -> Self { Self { - warned_disjunction: false, valid_threshold: 0.0000, - storage: P::new(Vec::new()), + storage: P::new_rc_cell(Vec::new()), } } } @@ -60,17 +57,21 @@ impl Provenance for DiffMaxMultProbProvena type InputTag = InputDiffProb; - type OutputTag = OutputDiffProb; + type OutputTag = OutputDiffProb; fn name() -> &'static str { "diffmaxmultprob" } - fn tagging_fn(&mut self, input_tag: Self::InputTag) -> Self::Tag { + fn tagging_fn(&self, input_tag: Self::InputTag) -> Self::Tag { let InputDiffProb(p, t) = input_tag; - let pos_id = self.storage.len(); - P::get_mut(&mut self.storage).push(t); - DualNumber2::new(pos_id, p) + if let Some(external_input_tag) = t { + let pos_id = P::get_rc_cell(&self.storage, |s| s.len()); + P::get_rc_cell_mut(&self.storage, |s| s.push(external_input_tag)); + DualNumber2::new(pos_id, p) + } else { + DualNumber2::constant(p) + } } fn recover_fn(&self, p: &Self::Tag) -> Self::OutputTag { @@ -80,7 +81,7 @@ impl Provenance for DiffMaxMultProbProvena .indices .iter() .zip(p.gradient.values.iter()) - .map(|(i, v)| (*i, *v, self.storage[*i].clone())) + .map(|(i, v)| (*i, *v)) .collect::>(); OutputDiffProb(prob, deriv) } diff --git a/core/src/runtime/provenance/differentiable/diff_min_max_prob.rs b/core/src/runtime/provenance/differentiable/diff_min_max_prob.rs index c08d81d..53a2e63 100644 --- a/core/src/runtime/provenance/differentiable/diff_min_max_prob.rs +++ b/core/src/runtime/provenance/differentiable/diff_min_max_prob.rs @@ -5,19 +5,31 @@ use crate::common::element::*; use crate::common::value_type::*; use crate::runtime::dynamic::*; use crate::runtime::statics::*; -use crate::utils::PointerFamily; +use crate::utils::*; -pub struct Prob(pub usize); +#[derive(Clone)] +pub enum Derivative { + Pos(usize), + Zero, + Neg(usize), +} -impl Prob { - fn new(id: usize) -> Self { - Self(id) +impl Derivative { + pub fn negate(&self) -> Self { + match self { + Self::Pos(i) => Self::Neg(i.clone()), + Self::Zero => Self::Zero, + Self::Neg(i) => Self::Pos(i.clone()), + } } } -impl Clone for Prob { - fn clone(&self) -> Self { - Self(self.0) +#[derive(Clone)] +pub struct Prob(pub f64, pub 
Derivative); + +impl Prob { + pub fn new(p: f64, d: Derivative) -> Self { + Self(p, d) } } @@ -36,32 +48,20 @@ impl std::fmt::Debug for Prob { impl Tag for Prob {} pub struct DiffMinMaxProbProvenance { - pub warned_disjunction: bool, + pub storage: P::RcCell>, pub valid_threshold: f64, - pub zero_index: usize, - pub one_index: usize, - pub diff_probs: P::Pointer)>>, - pub negates: Vec, } impl Clone for DiffMinMaxProbProvenance { fn clone(&self) -> Self { Self { - warned_disjunction: self.warned_disjunction, valid_threshold: self.valid_threshold, - zero_index: self.zero_index, - one_index: self.one_index, - diff_probs: P::new((&*self.diff_probs).clone()), - negates: self.negates.clone(), + storage: P::new_rc_cell(P::get_rc_cell(&self.storage, |s| s.clone())), } } } impl DiffMinMaxProbProvenance { - pub fn probability(&self, id: usize) -> f64 { - self.diff_probs[id].0 - } - pub fn collect_chosen_elements<'a, E>(&self, all: &'a Vec, chosen_ids: &Vec) -> Vec<&'a E> where E: Element, @@ -105,30 +105,13 @@ impl DiffMinMaxProbProvenance { impl Default for DiffMinMaxProbProvenance { fn default() -> Self { - let mut diff_probs = Vec::new(); - diff_probs.push((0.0, Derivative::Zero)); - diff_probs.push((1.0, Derivative::Zero)); - let mut negates = Vec::new(); - negates.push(1); - negates.push(0); Self { - warned_disjunction: false, valid_threshold: -0.0001, - zero_index: 0, - one_index: 1, - diff_probs: P::new(diff_probs), - negates, + storage: P::new_rc_cell(Vec::new()), } } } -#[derive(Clone)] -pub enum Derivative { - Pos(T), - Zero, - Neg(T), -} - #[derive(Clone)] pub struct OutputDiffProb(pub f64, pub usize, pub i32, pub Option); @@ -155,42 +138,39 @@ impl Provenance for DiffMinMaxProbProvenan "diffminmaxprob" } - fn tagging_fn(&mut self, input_tag: Self::InputTag) -> Self::Tag { + fn tagging_fn(&self, input_tag: Self::InputTag) -> Self::Tag { let InputDiffProb(p, t) = input_tag; - let pos_id = self.diff_probs.len(); - let neg_id = pos_id + 1; - P::get_mut(&mut self.diff_probs).extend(vec![ - (p, Derivative::Pos(t.clone())), - (1.0 - p, Derivative::Neg(t.clone())), - ]); - self.negates.push(neg_id); - self.negates.push(pos_id); - Self::Tag::new(pos_id) + if let Some(external_tag) = t { + let fact_id = P::get_rc_cell(&self.storage, |s| s.len()); + P::get_rc_cell_mut(&self.storage, |s| s.push(external_tag)); + Self::Tag::new(p, Derivative::Pos(fact_id)) + } else { + Self::Tag::new(p, Derivative::Zero) + } } fn recover_fn(&self, t: &Self::Tag) -> Self::OutputTag { - let (p, der) = &self.diff_probs[t.0]; - match der { - Derivative::Pos(s) => OutputDiffProb(*p, t.0, 1, Some(s.clone())), - Derivative::Zero => OutputDiffProb(*p, 0, 0, None), - Derivative::Neg(s) => OutputDiffProb(*p, t.0, -1, Some(s.clone())), + match &t.1 { + Derivative::Pos(fact_id) => OutputDiffProb(t.0, *fact_id, 1, Some(P::get_rc_cell(&self.storage, |s| s[*fact_id].clone()))), + Derivative::Zero => OutputDiffProb(t.0, 0, 0, None), + Derivative::Neg(fact_id) => OutputDiffProb(t.0, *fact_id, -1, Some(P::get_rc_cell(&self.storage, |s| s[*fact_id].clone()))), } } fn discard(&self, p: &Self::Tag) -> bool { - self.probability(p.0) <= self.valid_threshold + p.0 <= self.valid_threshold } fn zero(&self) -> Self::Tag { - Self::Tag::new(self.zero_index) + Self::Tag::new(0.0, Derivative::Zero) } fn one(&self) -> Self::Tag { - Self::Tag::new(self.one_index) + Self::Tag::new(1.0, Derivative::Zero) } fn add(&self, t1: &Self::Tag, t2: &Self::Tag) -> Self::Tag { - if self.probability(t1.0) > self.probability(t2.0) { + if t1.0 > t2.0 { t1.clone() } 
else { t2.clone() @@ -198,11 +178,11 @@ impl Provenance for DiffMinMaxProbProvenan } fn saturated(&self, t_old: &Self::Tag, t_new: &Self::Tag) -> bool { - self.probability(t_old.0) == self.probability(t_new.0) + t_old.0 == t_new.0 } fn mult(&self, t1: &Self::Tag, t2: &Self::Tag) -> Self::Tag { - if self.probability(t1.0) > self.probability(t2.0) { + if t1.0 > t2.0 { t2.clone() } else { t1.clone() @@ -210,18 +190,18 @@ impl Provenance for DiffMinMaxProbProvenan } fn negate(&self, p: &Self::Tag) -> Option { - Some(Self::Tag::new(self.negates[p.0])) + Some(Self::Tag::new(1.0 - p.0, p.1.negate())) } fn weight(&self, t: &Self::Tag) -> f64 { - self.probability(t.0) + t.0 } fn dynamic_count(&self, mut batch: DynamicElements) -> DynamicElements { if batch.is_empty() { vec![DynamicElement::new(0usize, self.one())] } else { - batch.sort_by(|a, b| self.probability(b.tag.0).total_cmp(&self.probability(a.tag.0))); + batch.sort_by(|a, b| b.tag.0.total_cmp(&a.tag.0)); let mut elems = vec![]; for k in 0..=batch.len() { let prob = self.max_min_prob_of_k_count(&batch, k); @@ -282,18 +262,18 @@ impl Provenance for DiffMinMaxProbProvenan fn dynamic_exists(&self, batch: DynamicElements) -> DynamicElements { let mut max_prob = 0.0; - let mut max_id = None; + let mut max_deriv = None; for elem in batch { - let prob = self.probability(elem.tag.0); + let prob = elem.tag.0; if prob > max_prob { max_prob = prob; - max_id = Some(elem.tag.0); + max_deriv = Some(elem.tag.1); } } - if let Some(id) = max_id { - let t = DynamicElement::new(true, Self::Tag::new(id)); - let f = DynamicElement::new(false, Self::Tag::new(self.negates[id])); - vec![t, f] + if let Some(deriv) = max_deriv { + let f = DynamicElement::new(false, Self::Tag::new(1.0 - max_prob, deriv.negate())); + let t = DynamicElement::new(true, Self::Tag::new(max_prob, deriv)); + vec![f, t] } else { vec![DynamicElement::new(false, self.one())] } @@ -303,7 +283,7 @@ impl Provenance for DiffMinMaxProbProvenan if batch.is_empty() { vec![StaticElement::new(0usize, self.one())] } else { - batch.sort_by(|a, b| self.probability(b.tag.0).total_cmp(&self.probability(a.tag.0))); + batch.sort_by(|a, b| b.tag.0.total_cmp(&a.tag.0)); let mut elems = vec![]; for k in 0..=batch.len() { let prob = self.max_min_prob_of_k_count(&batch, k); @@ -367,18 +347,18 @@ impl Provenance for DiffMinMaxProbProvenan fn static_exists(&self, batch: StaticElements) -> StaticElements { let mut max_prob = 0.0; - let mut max_id = None; + let mut max_deriv = None; for elem in batch { - let prob = self.probability(elem.tag.0); + let prob = elem.tag.0; if prob > max_prob { max_prob = prob; - max_id = Some(elem.tag.0); + max_deriv = Some(elem.tag.1); } } - if let Some(id) = max_id { - let t = StaticElement::new(true, Self::Tag::new(id)); - let f = StaticElement::new(false, Self::Tag::new(self.negates[id])); - vec![t, f] + if let Some(deriv) = max_deriv { + let f = StaticElement::new(false, Self::Tag::new(1.0 - max_prob, deriv.negate())); + let t = StaticElement::new(true, Self::Tag::new(max_prob, deriv)); + vec![f, t] } else { vec![StaticElement::new(false, self.one())] } diff --git a/core/src/runtime/provenance/differentiable/diff_nand_min_prob.rs b/core/src/runtime/provenance/differentiable/diff_nand_min_prob.rs index cf7c966..9b7d6ea 100644 --- a/core/src/runtime/provenance/differentiable/diff_nand_min_prob.rs +++ b/core/src/runtime/provenance/differentiable/diff_nand_min_prob.rs @@ -7,24 +7,22 @@ use crate::runtime::statics::*; use crate::utils::PointerFamily; pub struct DiffNandMinProbProvenance { 
- pub warned_disjunction: bool, pub valid_threshold: f64, - pub storage: P::Pointer>, + pub storage: P::RcCell>, } impl Clone for DiffNandMinProbProvenance { fn clone(&self) -> Self { Self { - warned_disjunction: self.warned_disjunction, valid_threshold: self.valid_threshold, - storage: P::new((&*self.storage).clone()), + storage: P::new_rc_cell(P::get_rc_cell(&self.storage, |s| s.clone())), } } } impl DiffNandMinProbProvenance { pub fn input_tags(&self) -> Vec { - self.storage.iter().cloned().collect() + P::get_rc_cell(&self.storage, |s| s.clone()) } pub fn tag_of_chosen_set(&self, all: &Vec, chosen_ids: &Vec) -> DualNumber2 @@ -48,9 +46,8 @@ impl DiffNandMinProbProvenance { impl Default for DiffNandMinProbProvenance { fn default() -> Self { Self { - warned_disjunction: false, valid_threshold: 0.0000, - storage: P::new(Vec::new()), + storage: P::new_rc_cell(Vec::new()), } } } @@ -60,17 +57,21 @@ impl Provenance for DiffNandMinProbProvena type InputTag = InputDiffProb; - type OutputTag = OutputDiffProb; + type OutputTag = OutputDiffProb; fn name() -> &'static str { "diffnandminprob" } - fn tagging_fn(&mut self, input_tag: Self::InputTag) -> Self::Tag { + fn tagging_fn(&self, input_tag: Self::InputTag) -> Self::Tag { let InputDiffProb(p, t) = input_tag; - let pos_id = self.storage.len(); - P::get_mut(&mut self.storage).push(t); - DualNumber2::new(pos_id, p) + if let Some(external_input_tag) = t { + let pos_id = P::get_rc_cell(&self.storage, |s| s.len()); + P::get_rc_cell_mut(&self.storage, |s| s.push(external_input_tag)); + DualNumber2::new(pos_id, p) + } else { + DualNumber2::constant(p) + } } fn recover_fn(&self, p: &Self::Tag) -> Self::OutputTag { @@ -80,7 +81,7 @@ impl Provenance for DiffNandMinProbProvena .indices .iter() .zip(p.gradient.values.iter()) - .map(|(i, v)| (*i, *v, self.storage[*i].clone())) + .map(|(i, v)| (*i, *v)) .collect::>(); OutputDiffProb(prob, deriv) } diff --git a/core/src/runtime/provenance/differentiable/diff_nand_mult_prob.rs b/core/src/runtime/provenance/differentiable/diff_nand_mult_prob.rs index 58fc090..90e49bd 100644 --- a/core/src/runtime/provenance/differentiable/diff_nand_mult_prob.rs +++ b/core/src/runtime/provenance/differentiable/diff_nand_mult_prob.rs @@ -7,24 +7,22 @@ use crate::runtime::statics::*; use crate::utils::PointerFamily; pub struct DiffNandMultProbProvenance { - pub warned_disjunction: bool, pub valid_threshold: f64, - pub storage: P::Pointer>, + pub storage: P::RcCell>, } impl Clone for DiffNandMultProbProvenance { fn clone(&self) -> Self { Self { - warned_disjunction: self.warned_disjunction, valid_threshold: self.valid_threshold, - storage: P::new((&*self.storage).clone()), + storage: P::new_rc_cell(P::get_rc_cell(&self.storage, |s| s.clone())), } } } impl DiffNandMultProbProvenance { pub fn input_tags(&self) -> Vec { - self.storage.iter().cloned().collect() + P::get_rc_cell(&self.storage, |s| s.clone()) } pub fn tag_of_chosen_set(&self, all: &Vec, chosen_ids: &Vec) -> DualNumber2 @@ -48,9 +46,8 @@ impl DiffNandMultProbProvenance { impl Default for DiffNandMultProbProvenance { fn default() -> Self { Self { - warned_disjunction: false, valid_threshold: 0.0000, - storage: P::new(Vec::new()), + storage: P::new_rc_cell(Vec::new()), } } } @@ -60,17 +57,21 @@ impl Provenance for DiffNandMultProbProven type InputTag = InputDiffProb; - type OutputTag = OutputDiffProb; + type OutputTag = OutputDiffProb; fn name() -> &'static str { "diffnandmultprob" } - fn tagging_fn(&mut self, input_tag: Self::InputTag) -> Self::Tag { + fn tagging_fn(&self, 
input_tag: Self::InputTag) -> Self::Tag { let InputDiffProb(p, t) = input_tag; - let pos_id = self.storage.len(); - P::get_mut(&mut self.storage).push(t); - DualNumber2::new(pos_id, p) + if let Some(external_input_tag) = t { + let pos_id = P::get_rc_cell(&self.storage, |s| s.len()); + P::get_rc_cell_mut(&self.storage, |s| s.push(external_input_tag)); + DualNumber2::new(pos_id, p) + } else { + DualNumber2::constant(p) + } } fn recover_fn(&self, p: &Self::Tag) -> Self::OutputTag { @@ -80,7 +81,7 @@ impl Provenance for DiffNandMultProbProven .indices .iter() .zip(p.gradient.values.iter()) - .map(|(i, v)| (*i, *v, self.storage[*i].clone())) + .map(|(i, v)| (*i, *v)) .collect::>(); OutputDiffProb(prob, deriv) } diff --git a/core/src/runtime/provenance/differentiable/diff_sample_k_proofs.rs b/core/src/runtime/provenance/differentiable/diff_sample_k_proofs.rs index 738158f..a98eb47 100644 --- a/core/src/runtime/provenance/differentiable/diff_sample_k_proofs.rs +++ b/core/src/runtime/provenance/differentiable/diff_sample_k_proofs.rs @@ -7,8 +7,8 @@ use crate::utils::PointerFamily; pub struct DiffSampleKProofsProvenance { pub k: usize, pub sampler: P::Cell, - pub diff_probs: P::Pointer>, - pub disjunctions: Disjunctions, + pub storage: DiffProbStorage, + pub disjunctions: P::Cell, } impl Clone for DiffSampleKProofsProvenance { @@ -16,8 +16,8 @@ impl Clone for DiffSampleKProofsProvenance { Self { k: self.k, sampler: P::clone_cell(&self.sampler), - diff_probs: P::new((&*self.diff_probs).clone()), - disjunctions: self.disjunctions.clone(), + storage: self.storage.clone_internal(), + disjunctions: P::clone_cell(&self.disjunctions), } } } @@ -31,13 +31,13 @@ impl DiffSampleKProofsProvenance { Self { k, sampler: P::new_cell(StdRng::seed_from_u64(seed)), - diff_probs: P::new(Vec::new()), - disjunctions: Disjunctions::new(), + storage: DiffProbStorage::new(), + disjunctions: P::new_cell(Disjunctions::new()), } } pub fn input_tags(&self) -> Vec { - self.diff_probs.iter().map(|(_, t)| t.clone()).collect() + self.storage.input_tags() } pub fn set_k(&mut self, new_k: usize) { @@ -47,11 +47,11 @@ impl DiffSampleKProofsProvenance { impl DNFContextTrait for DiffSampleKProofsProvenance { fn fact_probability(&self, id: &usize) -> f64 { - self.diff_probs[id.clone()].0 + self.storage.fact_probability(id) } fn has_disjunction_conflict(&self, pos_facts: &std::collections::BTreeSet) -> bool { - self.disjunctions.has_conflict(pos_facts) + P::get_cell(&self.disjunctions, |d| d.has_conflict(pos_facts)) } } @@ -60,22 +60,21 @@ impl Provenance for DiffSampleKProofsProve type InputTag = InputExclusiveDiffProb; - type OutputTag = OutputDiffProb; + type OutputTag = OutputDiffProb; fn name() -> &'static str { "diff-sample-k-proofs" } - fn tagging_fn(&mut self, input_tag: Self::InputTag) -> Self::Tag { - let InputExclusiveDiffProb { prob, tag, exclusion } = input_tag; + fn tagging_fn(&self, input_tag: Self::InputTag) -> Self::Tag { + let InputExclusiveDiffProb { prob, external_tag, exclusion } = input_tag; // First store the probability and generate the id - let fact_id = self.diff_probs.len(); - P::get_mut(&mut self.diff_probs).push((prob, tag)); + let fact_id = self.storage.add_prob(prob, external_tag); // Store the mutual exclusivity if let Some(disjunction_id) = exclusion { - self.disjunctions.add_disjunction(disjunction_id, fact_id); + P::get_cell_mut(&self.disjunctions, |d| d.add_disjunction(disjunction_id, fact_id)); } // Finally return the formula @@ -83,17 +82,25 @@ impl Provenance for DiffSampleKProofsProve } fn 
recover_fn(&self, t: &Self::Tag) -> Self::OutputTag { - let s = DualNumberSemiring::new(self.diff_probs.len()); + // Get the number of variables that requires grad + let num_var_requires_grad = self.storage.num_input_tags(); + let s = DualNumberSemiring::new(num_var_requires_grad); let v = |i: &usize| { - let (real, _) = &self.diff_probs[i.clone()]; - s.singleton(real.clone(), i.clone()) + let (real, external_tag) = self.storage.get_diff_prob(i); + + // Check if this variable `i` requires grad or not + if external_tag.is_some() { + s.singleton(real.clone(), i.clone()) + } else { + s.constant(real.clone()) + } }; let wmc_result = t.wmc(&s, &v); let prob = wmc_result.real; let deriv = wmc_result .deriv .iter() - .map(|(id, weight)| (id, *weight, self.diff_probs[id].1.clone())) + .map(|(id, weight)| (id, *weight)) .collect::>(); OutputDiffProb(prob, deriv) } @@ -126,7 +133,7 @@ impl Provenance for DiffSampleKProofsProve } fn weight(&self, t: &Self::Tag) -> f64 { - let v = |i: &usize| self.diff_probs[i.clone()].0; + let v = |i: &usize| self.storage.get_prob(i); t.wmc(&RealSemiring::new(), &v) } diff --git a/core/src/runtime/provenance/differentiable/diff_top_bottom_k_clauses.rs b/core/src/runtime/provenance/differentiable/diff_top_bottom_k_clauses.rs index f567523..355cfc2 100644 --- a/core/src/runtime/provenance/differentiable/diff_top_bottom_k_clauses.rs +++ b/core/src/runtime/provenance/differentiable/diff_top_bottom_k_clauses.rs @@ -5,21 +5,20 @@ use itertools::Itertools; use super::*; use crate::runtime::dynamic::*; use crate::runtime::statics::*; -use crate::utils::{PointerFamily, RcFamily}; +use crate::utils::*; -#[derive(Debug)] pub struct DiffTopBottomKClausesProvenance { pub k: usize, - pub diff_probs: P::Pointer>, - pub disjunctions: Disjunctions, + pub storage: DiffProbStorage, + pub disjunctions: P::Cell, } impl Clone for DiffTopBottomKClausesProvenance { fn clone(&self) -> Self { Self { k: self.k, - diff_probs: P::new((&*self.diff_probs).clone()), - disjunctions: self.disjunctions.clone(), + storage: self.storage.clone_internal(), + disjunctions: P::clone_cell(&self.disjunctions), } } } @@ -28,8 +27,8 @@ impl DiffTopBottomKClausesProvenance pub fn new(k: usize) -> Self { Self { k, - diff_probs: P::new(Vec::new()), - disjunctions: Disjunctions::new(), + storage: DiffProbStorage::new(), + disjunctions: P::new_cell(Disjunctions::new()), } } @@ -38,17 +37,17 @@ impl DiffTopBottomKClausesProvenance } pub fn input_tags(&self) -> Vec { - self.diff_probs.iter().map(|(_, t)| t.clone()).collect() + self.storage.input_tags() } } impl CNFDNFContextTrait for DiffTopBottomKClausesProvenance { fn fact_probability(&self, id: &usize) -> f64 { - self.diff_probs[*id].0 + self.storage.fact_probability(id) } fn has_disjunction_conflict(&self, pos_facts: &BTreeSet) -> bool { - self.disjunctions.has_conflict(pos_facts) + P::get_cell(&self.disjunctions, |d| d.has_conflict(pos_facts)) } } @@ -57,22 +56,21 @@ impl Provenance for DiffTopBottomKClausesP type InputTag = InputExclusiveDiffProb; - type OutputTag = OutputDiffProb; + type OutputTag = OutputDiffProb; fn name() -> &'static str { "diff-top-bottom-k-clauses" } - fn tagging_fn(&mut self, input_tag: Self::InputTag) -> Self::Tag { - let InputExclusiveDiffProb { prob, tag, exclusion } = input_tag; + fn tagging_fn(&self, input_tag: Self::InputTag) -> Self::Tag { + let InputExclusiveDiffProb { prob, external_tag, exclusion } = input_tag; // First store the probability and generate the id - let fact_id = self.diff_probs.len(); - P::get_mut(&mut 
self.diff_probs).push((prob, tag)); + let fact_id = self.storage.add_prob(prob, external_tag); // Store the mutual exclusivity if let Some(disjunction_id) = exclusion { - self.disjunctions.add_disjunction(disjunction_id, fact_id); + P::get_cell_mut(&self.disjunctions, |d| d.add_disjunction(disjunction_id, fact_id)); } // Finally return the formula @@ -80,17 +78,25 @@ impl Provenance for DiffTopBottomKClausesP } fn recover_fn(&self, t: &Self::Tag) -> Self::OutputTag { - let s = DualNumberSemiring::new(self.diff_probs.len()); + // Get the number of variables that requires grad + let num_var_requires_grad = self.storage.num_input_tags(); + let s = DualNumberSemiring::new(num_var_requires_grad); let v = |i: &usize| { - let (real, _) = &self.diff_probs[i.clone()]; - s.singleton(real.clone(), i.clone()) + let (real, external_tag) = self.storage.get_diff_prob(i); + + // Check if this variable `i` requires grad or not + if external_tag.is_some() { + s.singleton(real.clone(), i.clone()) + } else { + s.constant(real.clone()) + } }; let wmc_result = t.wmc(&s, &v); let prob = wmc_result.real; let deriv = wmc_result .deriv .iter() - .map(|(id, weight)| (id, *weight, self.diff_probs[id].1.clone())) + .map(|(id, weight)| (id, *weight)) .collect::>(); OutputDiffProb(prob, deriv) } @@ -124,7 +130,7 @@ impl Provenance for DiffTopBottomKClausesP } fn weight(&self, t: &Self::Tag) -> f64 { - let v = |i: &usize| self.diff_probs[i.clone()].0; + let v = |i: &usize| self.storage.get_prob(i); t.wmc(&RealSemiring::new(), &v) } diff --git a/core/src/runtime/provenance/differentiable/diff_top_k_proofs.rs b/core/src/runtime/provenance/differentiable/diff_top_k_proofs.rs index d51531b..694716c 100644 --- a/core/src/runtime/provenance/differentiable/diff_top_k_proofs.rs +++ b/core/src/runtime/provenance/differentiable/diff_top_k_proofs.rs @@ -7,16 +7,16 @@ use crate::utils::*; pub struct DiffTopKProofsProvenance { pub k: usize, - pub diff_probs: P::Pointer>, - pub disjunctions: Disjunctions, + pub storage: DiffProbStorage, + pub disjunctions: P::Cell, } impl Clone for DiffTopKProofsProvenance { fn clone(&self) -> Self { Self { k: self.k, - diff_probs: P::new((&*self.diff_probs).clone()), - disjunctions: self.disjunctions.clone(), + storage: self.storage.clone_internal(), + disjunctions: P::clone_cell(&self.disjunctions), } } } @@ -25,13 +25,13 @@ impl DiffTopKProofsProvenance { pub fn new(k: usize) -> Self { Self { k, - diff_probs: P::new(Vec::new()), - disjunctions: Disjunctions::new(), + storage: DiffProbStorage::new(), + disjunctions: P::new_cell(Disjunctions::new()), } } pub fn input_tags(&self) -> Vec { - self.diff_probs.iter().map(|(_, t)| t.clone()).collect() + self.storage.input_tags() } pub fn set_k(&mut self, k: usize) { @@ -41,11 +41,11 @@ impl DiffTopKProofsProvenance { impl DNFContextTrait for DiffTopKProofsProvenance { fn fact_probability(&self, id: &usize) -> f64 { - self.diff_probs[*id].0 + self.storage.fact_probability(id) } fn has_disjunction_conflict(&self, pos_facts: &std::collections::BTreeSet) -> bool { - self.disjunctions.has_conflict(pos_facts) + P::get_cell(&self.disjunctions, |d| d.has_conflict(pos_facts)) } } @@ -54,22 +54,21 @@ impl Provenance for DiffTopKProofsProvenan type InputTag = InputExclusiveDiffProb; - type OutputTag = OutputDiffProb; + type OutputTag = OutputDiffProb; fn name() -> &'static str { "diff-top-k-proofs" } - fn tagging_fn(&mut self, input_tag: Self::InputTag) -> Self::Tag { - let InputExclusiveDiffProb { prob, tag, exclusion } = input_tag; + fn tagging_fn(&self, input_tag: 
Self::InputTag) -> Self::Tag { + let InputExclusiveDiffProb { prob, external_tag, exclusion } = input_tag; // First store the probability and generate the id - let fact_id = self.diff_probs.len(); - P::get_mut(&mut self.diff_probs).push((prob, tag)); + let fact_id = self.storage.add_prob(prob, external_tag); // Store the mutual exclusivity if let Some(disjunction_id) = exclusion { - self.disjunctions.add_disjunction(disjunction_id, fact_id); + P::get_cell_mut(&self.disjunctions, |d| d.add_disjunction(disjunction_id, fact_id)); } // Finally return the formula @@ -77,17 +76,25 @@ impl Provenance for DiffTopKProofsProvenan } fn recover_fn(&self, t: &Self::Tag) -> Self::OutputTag { - let s = DualNumberSemiring::new(self.diff_probs.len()); + // Get the number of variables that requires grad + let num_var_requires_grad = self.storage.num_input_tags(); + let s = DualNumberSemiring::new(num_var_requires_grad); let v = |i: &usize| { - let (real, _) = &self.diff_probs[i.clone()]; - s.singleton(real.clone(), i.clone()) + let (real, external_tag) = self.storage.get_diff_prob(i); + + // Check if this variable `i` requires grad or not + if external_tag.is_some() { + s.singleton(real.clone(), i.clone()) + } else { + s.constant(real.clone()) + } }; let wmc_result = t.wmc(&s, &v); let prob = wmc_result.real; let deriv = wmc_result .deriv .iter() - .map(|(id, weight)| (id, *weight, self.diff_probs[id].1.clone())) + .map(|(id, weight)| (id, *weight)) .collect::>(); OutputDiffProb(prob, deriv) } @@ -121,7 +128,7 @@ impl Provenance for DiffTopKProofsProvenan } fn weight(&self, t: &Self::Tag) -> f64 { - let v = |i: &usize| self.diff_probs[i.clone()].0; + let v = |i: &usize| self.storage.get_prob(i); t.wmc(&RealSemiring::new(), &v) } diff --git a/core/src/runtime/provenance/differentiable/diff_top_k_proofs_indiv.rs b/core/src/runtime/provenance/differentiable/diff_top_k_proofs_indiv.rs index c1c78c1..e7836a6 100644 --- a/core/src/runtime/provenance/differentiable/diff_top_k_proofs_indiv.rs +++ b/core/src/runtime/provenance/differentiable/diff_top_k_proofs_indiv.rs @@ -8,7 +8,7 @@ use crate::utils::PointerFamily; #[derive(Clone)] pub struct OutputIndivDiffProb { pub k: usize, - pub proofs: Vec>, + pub proofs: Vec)>>, } impl std::fmt::Debug for OutputIndivDiffProb { @@ -61,16 +61,16 @@ impl std::fmt::Display for OutputIndivDiffProb { pub struct DiffTopKProofsIndivProvenance { pub k: usize, - pub diff_probs: P::Pointer>, - pub disjunctions: Disjunctions, + pub storage: DiffProbStorage, + pub disjunctions: P::Cell, } impl Clone for DiffTopKProofsIndivProvenance { fn clone(&self) -> Self { Self { k: self.k, - diff_probs: P::new((&*self.diff_probs).clone()), - disjunctions: self.disjunctions.clone(), + storage: self.storage.clone_internal(), + disjunctions: P::clone_cell(&self.disjunctions), } } } @@ -79,17 +79,17 @@ impl DiffTopKProofsIndivProvenance { pub fn new(k: usize) -> Self { Self { k, - diff_probs: P::new(Vec::new()), - disjunctions: Disjunctions::new(), + storage: DiffProbStorage::new(), + disjunctions: P::new_cell(Disjunctions::new()), } } pub fn input_tags(&self) -> Vec { - self.diff_probs.iter().map(|(_, t)| t.clone()).collect() + self.storage.input_tags() } - pub fn input_tag_of_fact_id(&self, i: usize) -> T { - self.diff_probs[i].1.clone() + pub fn input_tag_of_fact_id(&self, i: usize) -> Option { + self.storage.get_diff_prob(&i).1 } pub fn set_k(&mut self, k: usize) { @@ -99,11 +99,11 @@ impl DiffTopKProofsIndivProvenance { impl DNFContextTrait for DiffTopKProofsIndivProvenance { fn 
fact_probability(&self, id: &usize) -> f64 { - self.diff_probs[*id].0 + self.storage.get_prob(id) } fn has_disjunction_conflict(&self, pos_facts: &std::collections::BTreeSet) -> bool { - self.disjunctions.has_conflict(pos_facts) + P::get_cell(&self.disjunctions, |d| d.has_conflict(pos_facts)) } } @@ -118,16 +118,15 @@ impl Provenance for DiffTopKProofsIndivPro "diff-top-k-proofs-indiv" } - fn tagging_fn(&mut self, input_tag: Self::InputTag) -> Self::Tag { - let InputExclusiveDiffProb { prob, tag, exclusion } = input_tag; + fn tagging_fn(&self, input_tag: Self::InputTag) -> Self::Tag { + let InputExclusiveDiffProb { prob, external_tag, exclusion } = input_tag; // First store the probability and generate the id - let fact_id = self.diff_probs.len(); - P::get_mut(&mut self.diff_probs).push((prob, tag)); + let fact_id = self.storage.add_prob(prob, external_tag); // Store the mutual exclusivity if let Some(disjunction_id) = exclusion { - self.disjunctions.add_disjunction(disjunction_id, fact_id); + P::get_cell_mut(&self.disjunctions, |d| d.add_disjunction(disjunction_id, fact_id)); } // Finally return the formula @@ -148,7 +147,7 @@ impl Provenance for DiffTopKProofsIndivPro ( self.fact_probability(&fact_id), literal.sign(), - self.input_tag_of_fact_id(fact_id), + self.input_tag_of_fact_id(fact_id) ) }) .collect::>() @@ -186,9 +185,8 @@ impl Provenance for DiffTopKProofsIndivPro } fn weight(&self, t: &Self::Tag) -> f64 { - let s = RealSemiring::new(); - let v = |i: &usize| self.diff_probs[i.clone()].0; - t.wmc(&s, &v) + let v = |i: &usize| self.storage.get_prob(i); + t.wmc(&RealSemiring::new(), &v) } fn dynamic_count(&self, batch: DynamicElements) -> DynamicElements { diff --git a/core/src/runtime/provenance/discrete/boolean.rs b/core/src/runtime/provenance/discrete/boolean.rs index 635e947..e9ed8e4 100644 --- a/core/src/runtime/provenance/discrete/boolean.rs +++ b/core/src/runtime/provenance/discrete/boolean.rs @@ -18,7 +18,7 @@ impl Provenance for BooleanProvenance { "boolean" } - fn tagging_fn(&mut self, ext_tag: Self::InputTag) -> Self::Tag { + fn tagging_fn(&self, ext_tag: Self::InputTag) -> Self::Tag { ext_tag } diff --git a/core/src/runtime/provenance/discrete/natural.rs b/core/src/runtime/provenance/discrete/natural.rs index e433e72..743cc59 100644 --- a/core/src/runtime/provenance/discrete/natural.rs +++ b/core/src/runtime/provenance/discrete/natural.rs @@ -1,9 +1,12 @@ use super::*; use crate::runtime::dynamic::*; use crate::runtime::statics::*; +use crate::common::input_tag::*; pub type Natural = usize; +impl StaticInputTag for Natural {} + #[derive(Clone, Debug, Default)] pub struct NaturalProvenance; @@ -18,7 +21,7 @@ impl Provenance for NaturalProvenance { "natural" } - fn tagging_fn(&mut self, t: Self::InputTag) -> Self::Tag { + fn tagging_fn(&self, t: Self::InputTag) -> Self::Tag { t } diff --git a/core/src/runtime/provenance/discrete/proofs.rs b/core/src/runtime/provenance/discrete/proofs.rs index de6bcd9..1517921 100644 --- a/core/src/runtime/provenance/discrete/proofs.rs +++ b/core/src/runtime/provenance/discrete/proofs.rs @@ -5,7 +5,7 @@ use itertools::iproduct; use crate::common::input_tag::*; use crate::runtime::dynamic::*; use crate::runtime::statics::*; -use crate::utils::IdAllocator; +use crate::utils::*; use super::*; @@ -129,8 +129,8 @@ pub enum ProofsInputTag { Exclusive(usize), } -impl FromInputTag for ProofsInputTag { - fn from_input_tag(t: &DynamicInputTag) -> Option { +impl StaticInputTag for ProofsInputTag { + fn from_dynamic_input_tag(t: &DynamicInputTag) -> 
Option { match t { DynamicInputTag::Exclusive(e) => Some(ProofsInputTag::Exclusive(e.clone())), DynamicInputTag::ExclusiveFloat(_, e) => Some(ProofsInputTag::Exclusive(e.clone())), @@ -139,13 +139,30 @@ impl FromInputTag for ProofsInputTag { } } -#[derive(Clone, Default)] -pub struct ProofsProvenance { - id_allocator: IdAllocator, - disjunctions: Disjunctions, +pub struct ProofsProvenance { + id_allocator: P::Cell, + disjunctions: P::Cell, } -impl Provenance for ProofsProvenance { +impl Default for ProofsProvenance
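The struct and impl headers in this file are generic over a pointer family; the angle-bracketed parameters are not visible above, so here is a sketch of how the new ProofsProvenance declarations presumably read in full (the bounds are inferred from the `P::Cell` fields and the `PointerFamily` trait and should be treated as an assumption):

pub struct ProofsProvenance<P: PointerFamily> {
  id_allocator: P::Cell<IdAllocator>,
  disjunctions: P::Cell<Disjunctions>,
}

impl<P: PointerFamily> Default for ProofsProvenance<P> { /* as in the hunk below */ }
impl<P: PointerFamily> Clone for ProofsProvenance<P> { /* as in the hunk below */ }
impl<P: PointerFamily> Provenance for ProofsProvenance<P> { /* as in the hunk below */ }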

{ + fn default() -> Self { + Self { + id_allocator: P::new_cell(IdAllocator::default()), + disjunctions: P::new_cell(Disjunctions::default()), + } + } +} + +impl Clone for ProofsProvenance

{ + fn clone(&self) -> Self { + Self { + id_allocator: P::clone_cell(&self.id_allocator), + disjunctions: P::clone_cell(&self.disjunctions), + } + } +} + +impl Provenance for ProofsProvenance

{ type Tag = Proofs; type InputTag = ProofsInputTag; @@ -156,12 +173,12 @@ impl Provenance for ProofsProvenance { "proofs" } - fn tagging_fn(&mut self, exclusion: Self::InputTag) -> Self::Tag { - let fact_id = self.id_allocator.alloc(); + fn tagging_fn(&self, exclusion: Self::InputTag) -> Self::Tag { + let fact_id = P::get_cell_mut(&self.id_allocator, |a| a.alloc()); // Disjunction id if let ProofsInputTag::Exclusive(disjunction_id) = exclusion { - self.disjunctions.add_disjunction(disjunction_id, fact_id) + P::get_cell_mut(&self.disjunctions, |d| d.add_disjunction(disjunction_id, fact_id)); } // Return the proof @@ -192,7 +209,7 @@ impl Provenance for ProofsProvenance { let mut prod = Self::Tag::cartesian_product(t1, t2); prod .proofs - .retain(|proof| !self.disjunctions.has_conflict(&proof.facts)); + .retain(|proof| !P::get_cell(&self.disjunctions, |d| d.has_conflict(&proof.facts))); prod } diff --git a/core/src/runtime/provenance/discrete/unit.rs b/core/src/runtime/provenance/discrete/unit.rs index fd54fe1..1fd8d22 100644 --- a/core/src/runtime/provenance/discrete/unit.rs +++ b/core/src/runtime/provenance/discrete/unit.rs @@ -1,7 +1,9 @@ -use super::*; +use crate::common::input_tag::*; use crate::runtime::dynamic::*; use crate::runtime::statics::*; +use super::*; + #[derive(Clone, Debug, Default)] pub struct Unit; @@ -13,9 +15,17 @@ impl std::fmt::Display for Unit { impl Tag for Unit {} +impl StaticInputTag for () {} + #[derive(Clone, Debug, Default)] pub struct UnitProvenance; +impl UnitProvenance { + pub fn new() -> Self { + Self + } +} + impl Provenance for UnitProvenance { type Tag = Unit; @@ -27,7 +37,7 @@ impl Provenance for UnitProvenance { "unit" } - fn tagging_fn(&mut self, _: Self::InputTag) -> Self::Tag { + fn tagging_fn(&self, _: Self::InputTag) -> Self::Tag { Unit } diff --git a/core/src/runtime/provenance/probabilistic/add_mult_prob.rs b/core/src/runtime/provenance/probabilistic/add_mult_prob.rs index d1fb93b..f8be209 100644 --- a/core/src/runtime/provenance/probabilistic/add_mult_prob.rs +++ b/core/src/runtime/provenance/probabilistic/add_mult_prob.rs @@ -52,7 +52,7 @@ impl Provenance for AddMultProbProvenance { "addmultprob" } - fn tagging_fn(&mut self, p: Self::InputTag) -> Self::Tag { + fn tagging_fn(&self, p: Self::InputTag) -> Self::Tag { p.into() } diff --git a/core/src/runtime/provenance/probabilistic/min_max_prob.rs b/core/src/runtime/provenance/probabilistic/min_max_prob.rs index 323c14c..5f0a7a5 100644 --- a/core/src/runtime/provenance/probabilistic/min_max_prob.rs +++ b/core/src/runtime/provenance/probabilistic/min_max_prob.rs @@ -41,7 +41,7 @@ impl Provenance for MinMaxProbProvenance { "minmaxprob" } - fn tagging_fn(&mut self, p: Self::InputTag) -> Self::Tag { + fn tagging_fn(&self, p: Self::InputTag) -> Self::Tag { p.into() } diff --git a/core/src/runtime/provenance/probabilistic/prob_proofs.rs b/core/src/runtime/provenance/probabilistic/prob_proofs.rs index 30e20fa..2d568e5 100644 --- a/core/src/runtime/provenance/probabilistic/prob_proofs.rs +++ b/core/src/runtime/provenance/probabilistic/prob_proofs.rs @@ -2,6 +2,8 @@ use std::collections::*; use itertools::iproduct; +use crate::utils::*; + use super::*; #[derive(Clone, Debug, PartialEq, Eq)] @@ -99,13 +101,28 @@ impl std::fmt::Display for ProbProofs { impl Tag for ProbProofs {} -#[derive(Clone, Default)] -pub struct ProbProofsProvenance { - probabilities: Vec, - disjunctions: Disjunctions, +#[derive(Default)] +pub struct ProbProofsProvenance { + probs: P::Cell>, + disjunctions: P::Cell, +} + +impl Clone 
for ProbProofsProvenance<P>

{ + fn clone(&self) -> Self { + Self { + probs: P::clone_cell(&self.probs), + disjunctions: P::clone_cell(&self.disjunctions), + } + } +} + +impl ProbProofsProvenance

{ + fn fact_probability(&self, i: &usize) -> f64 { + P::get_cell(&self.probs, |p| p[*i]) + } } -impl Provenance for ProbProofsProvenance { +impl Provenance for ProbProofsProvenance

{ type Tag = ProbProofs; type InputTag = InputExclusiveProb; @@ -116,14 +133,14 @@ impl Provenance for ProbProofsProvenance { "prob-proofs" } - fn tagging_fn(&mut self, input_tag: Self::InputTag) -> Self::Tag { + fn tagging_fn(&self, input_tag: Self::InputTag) -> Self::Tag { // First generate id and push the probability into the list - let fact_id = self.probabilities.len(); - self.probabilities.push(input_tag.prob); + let fact_id = P::get_cell(&self.probs, |p| p.len()); + P::get_cell_mut(&self.probs, |p| p.push(input_tag.prob)); // Add exlusion if needed if let Some(disj_id) = input_tag.exclusion { - self.disjunctions.add_disjunction(disj_id, fact_id); + P::get_cell_mut(&self.disjunctions, |d| d.add_disjunction(disj_id, fact_id)); } // Lastly return a tag @@ -132,7 +149,7 @@ impl Provenance for ProbProofsProvenance { fn recover_fn(&self, t: &Self::Tag) -> Self::OutputTag { let s = RealSemiring; - let v = |i: &usize| -> f64 { self.probabilities[*i] }; + let v = |i: &usize| -> f64 { self.fact_probability(i) }; AsBooleanFormula::wmc(t, &s, &v) } @@ -156,7 +173,9 @@ impl Provenance for ProbProofsProvenance { let mut prod = Self::Tag::cartesian_product(t1, t2); prod .proofs - .retain(|proof| !self.disjunctions.has_conflict(&proof.facts)); + .retain(|proof| { + P::get_cell(&self.disjunctions, |d| !d.has_conflict(&proof.facts)) + }); prod } @@ -174,7 +193,7 @@ impl Provenance for ProbProofsProvenance { fn weight(&self, t: &Self::Tag) -> f64 { let s = RealSemiring; - let v = |i: &usize| -> f64 { self.probabilities[*i] }; + let v = |i: &usize| -> f64 { self.fact_probability(i) }; AsBooleanFormula::wmc(t, &s, &v) } } diff --git a/core/src/runtime/provenance/probabilistic/sample_k_proofs.rs b/core/src/runtime/provenance/probabilistic/sample_k_proofs.rs index bdfaf72..a68f02f 100644 --- a/core/src/runtime/provenance/probabilistic/sample_k_proofs.rs +++ b/core/src/runtime/provenance/probabilistic/sample_k_proofs.rs @@ -1,31 +1,31 @@ -use std::cell::RefCell; use std::collections::*; -use std::rc::Rc; use rand::prelude::*; use rand::rngs::StdRng; +use crate::utils::*; + use super::*; -pub struct SampleKProofsProvenance { +pub struct SampleKProofsProvenance { pub k: usize, - pub sampler: Rc>, - pub probs: Rc>, - pub disjunctions: Disjunctions, + pub sampler: P::Cell, + pub probs: P::Cell>, + pub disjunctions: P::Cell, } -impl Clone for SampleKProofsProvenance { +impl Clone for SampleKProofsProvenance
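With `tagging_fn` now taking `&self`, every mutation of per-context state goes through the cell accessors, as in the ProbProofsProvenance hunks above. A minimal sketch of that pattern (the helper name and the `P::Cell<T>` generic shape are assumptions; for `RcFamily` the cell is a `RefCell`, for `ArcFamily` a `Mutex`, per the `pointer_family.rs` changes later in this patch):

use scallop_core::utils::PointerFamily;

fn record_fact<P: PointerFamily>(probs: &P::Cell<Vec<f64>>, prob: f64) -> usize {
  // read-only access: the closure's result is returned
  let fact_id = P::get_cell(probs, |p| p.len());
  // mutation through a shared reference (RefCell borrow or Mutex lock underneath)
  P::get_cell_mut(probs, |p| p.push(prob));
  fact_id
}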

{ fn clone(&self) -> Self { Self { k: self.k, - sampler: self.sampler.clone(), - probs: Rc::new((&*self.probs).clone()), - disjunctions: self.disjunctions.clone(), + sampler: P::clone_cell(&self.sampler), + probs: P::clone_cell(&self.probs), + disjunctions: P::clone_cell(&self.disjunctions), } } } -impl SampleKProofsProvenance { +impl SampleKProofsProvenance

{ pub fn new(k: usize) -> Self { Self::new_with_seed(k, 12345678) } @@ -33,9 +33,9 @@ impl SampleKProofsProvenance { pub fn new_with_seed(k: usize, seed: u64) -> Self { Self { k, - sampler: Rc::new(RefCell::new(StdRng::seed_from_u64(seed))), - probs: Rc::new(Vec::new()), - disjunctions: Disjunctions::new(), + sampler: P::new_cell(StdRng::seed_from_u64(seed)), + probs: P::new_cell(Vec::new()), + disjunctions: P::new_cell(Disjunctions::new()), } } @@ -44,17 +44,17 @@ impl SampleKProofsProvenance { } } -impl DNFContextTrait for SampleKProofsProvenance { +impl DNFContextTrait for SampleKProofsProvenance

{ fn fact_probability(&self, id: &usize) -> f64 { - self.probs[*id] + P::get_cell(&self.probs, |p| p[*id]) } fn has_disjunction_conflict(&self, pos_facts: &BTreeSet) -> bool { - self.disjunctions.has_conflict(pos_facts) + P::get_cell(&self.disjunctions, |d| d.has_conflict(pos_facts)) } } -impl Provenance for SampleKProofsProvenance { +impl Provenance for SampleKProofsProvenance

{ type Tag = DNFFormula; type InputTag = InputExclusiveProb; @@ -65,14 +65,14 @@ impl Provenance for SampleKProofsProvenance { "sample-k-proofs" } - fn tagging_fn(&mut self, input_tag: Self::InputTag) -> Self::Tag { + fn tagging_fn(&self, input_tag: Self::InputTag) -> Self::Tag { // First generate id and push the probability into the list - let fact_id = self.probs.len(); - Rc::get_mut(&mut self.probs).unwrap().push(input_tag.prob); + let fact_id = P::get_cell(&self.probs, |p| p.len()); + P::get_cell_mut(&self.probs, |p| p.push(input_tag.prob)); // Add exlusion if needed if let Some(disj_id) = input_tag.exclusion { - self.disjunctions.add_disjunction(disj_id, fact_id); + P::get_cell_mut(&self.disjunctions, |d| d.add_disjunction(disj_id, fact_id)); } // Lastly return a tag @@ -81,7 +81,7 @@ impl Provenance for SampleKProofsProvenance { fn recover_fn(&self, t: &Self::Tag) -> Self::OutputTag { let s = RealSemiring; - let v = |i: &usize| -> f64 { self.probs[*i] }; + let v = |i: &usize| -> f64 { self.fact_probability(i) }; t.wmc(&s, &v) } @@ -99,7 +99,7 @@ impl Provenance for SampleKProofsProvenance { fn add(&self, t1: &Self::Tag, t2: &Self::Tag) -> Self::Tag { let tag = t1.or(t2); - let sampled_clauses = self.sample_k_clauses(tag.clauses, self.k, &mut self.sampler.borrow_mut()); + let sampled_clauses = P::get_cell_mut(&self.sampler, |s| self.sample_k_clauses(tag.clauses, self.k, s)); DNFFormula { clauses: sampled_clauses, } @@ -112,7 +112,7 @@ impl Provenance for SampleKProofsProvenance { fn mult(&self, t1: &Self::Tag, t2: &Self::Tag) -> Self::Tag { let mut tag = t1.or(t2); self.retain_no_conflict(&mut tag.clauses); - let sampled_clauses = self.sample_k_clauses(tag.clauses, self.k, &mut self.sampler.borrow_mut()); + let sampled_clauses = P::get_cell_mut(&self.sampler, |s| self.sample_k_clauses(tag.clauses, self.k, s)); DNFFormula { clauses: sampled_clauses, } @@ -128,7 +128,7 @@ impl Provenance for SampleKProofsProvenance { fn weight(&self, t: &Self::Tag) -> f64 { let s = RealSemiring; - let v = |i: &usize| -> f64 { self.probs[*i] }; + let v = |i: &usize| -> f64 { self.fact_probability(i) }; t.wmc(&s, &v) } } diff --git a/core/src/runtime/provenance/probabilistic/top_bottom_k_clauses.rs b/core/src/runtime/provenance/probabilistic/top_bottom_k_clauses.rs index df3cb5b..32033e5 100644 --- a/core/src/runtime/provenance/probabilistic/top_bottom_k_clauses.rs +++ b/core/src/runtime/provenance/probabilistic/top_bottom_k_clauses.rs @@ -10,16 +10,16 @@ use crate::utils::{PointerFamily, RcFamily}; #[derive(Debug)] pub struct TopBottomKClausesProvenance { pub k: usize, - pub probs: P::Pointer>, - pub disjunctions: Disjunctions, + pub probs: P::Cell>, + pub disjunctions: P::Cell, } impl Clone for TopBottomKClausesProvenance
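A note on the cloning behavior that the `Clone` impls above and below rely on: per the `pointer_family.rs` changes later in this patch, `P::clone_cell` copies the contents into a fresh cell, so a cloned provenance gets independent probability and disjunction tables, whereas `P::clone_rc_cell` only bumps a reference count, so clones share one underlying store. A small sketch, assuming the elided generic parameters follow the usual `Self::Cell<T>` / `Self::RcCell<T>` shape:

use scallop_core::utils::{PointerFamily, RcFamily};

fn clone_semantics_demo() {
  // P::Cell: cloning copies the contents
  let probs = RcFamily::new_cell(vec![0.9f64]);
  let copy = RcFamily::clone_cell(&probs);
  RcFamily::get_cell_mut(&copy, |p| p.push(0.1));
  assert_eq!(RcFamily::get_cell(&probs, |p| p.len()), 1); // original untouched

  // P::RcCell: cloning shares the contents
  let shared = RcFamily::new_rc_cell(vec![0.9f64]);
  let alias = RcFamily::clone_rc_cell(&shared);
  RcFamily::get_rc_cell_mut(&alias, |p| p.push(0.1));
  assert_eq!(RcFamily::get_rc_cell(&shared, |p| p.len()), 2); // both handles see the push
}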

{ fn clone(&self) -> Self { Self { k: self.k, - probs: P::new((&*self.probs).clone()), - disjunctions: self.disjunctions.clone(), + probs: P::clone_cell(&self.probs), + disjunctions: P::clone_cell(&self.disjunctions), } } } @@ -28,8 +28,8 @@ impl TopBottomKClausesProvenance

{ pub fn new(k: usize) -> Self { Self { k, - probs: P::new(Vec::new()), - disjunctions: Disjunctions::new(), + probs: P::new_cell(Vec::new()), + disjunctions: P::new_cell(Disjunctions::new()), } } @@ -40,11 +40,11 @@ impl TopBottomKClausesProvenance

{ impl CNFDNFContextTrait for TopBottomKClausesProvenance

{ fn fact_probability(&self, id: &usize) -> f64 { - self.probs[*id] + P::get_cell(&self.probs, |p| p[*id]) } fn has_disjunction_conflict(&self, pos_facts: &BTreeSet) -> bool { - self.disjunctions.has_conflict(pos_facts) + P::get_cell(&self.disjunctions, |d| d.has_conflict(pos_facts)) } } @@ -59,14 +59,14 @@ impl Provenance for TopBottomKClausesProvenance

{ "top-bottom-k-clauses" } - fn tagging_fn(&mut self, input_tag: Self::InputTag) -> Self::Tag { + fn tagging_fn(&self, input_tag: Self::InputTag) -> Self::Tag { // First generate id and push the probability into the list - let fact_id = self.probs.len(); - P::get_mut(&mut self.probs).push(input_tag.prob); + let fact_id = P::get_cell(&self.probs, |p| p.len()); + P::get_cell_mut(&self.probs, |p| p.push(input_tag.prob)); // Add exlusion if needed if let Some(disj_id) = input_tag.exclusion { - self.disjunctions.add_disjunction(disj_id, fact_id); + P::get_cell_mut(&self.disjunctions, |d| d.add_disjunction(disj_id, fact_id)); } // Return the formula @@ -75,7 +75,7 @@ impl Provenance for TopBottomKClausesProvenance

{ fn recover_fn(&self, t: &Self::Tag) -> Self::OutputTag { let s = RealSemiring; - let v = |i: &usize| -> f64 { self.probs[*i] }; + let v = |i: &usize| -> f64 { self.fact_probability(i) }; t.wmc(&s, &v) } @@ -109,7 +109,7 @@ impl Provenance for TopBottomKClausesProvenance

{ fn weight(&self, t: &Self::Tag) -> f64 { let s = RealSemiring; - let v = |i: &usize| -> f64 { self.probs[*i] }; + let v = |i: &usize| -> f64 { self.fact_probability(i) }; t.wmc(&s, &v) } diff --git a/core/src/runtime/provenance/probabilistic/top_k_proofs.rs b/core/src/runtime/provenance/probabilistic/top_k_proofs.rs index 2826aeb..c3ffdc1 100644 --- a/core/src/runtime/provenance/probabilistic/top_k_proofs.rs +++ b/core/src/runtime/provenance/probabilistic/top_k_proofs.rs @@ -7,16 +7,16 @@ use crate::utils::*; pub struct TopKProofsProvenance { pub k: usize, - pub probs: P::Pointer>, - pub disjunctions: Disjunctions, + pub probs: P::Cell>, + pub disjunctions: P::Cell, } impl Default for TopKProofsProvenance

{ fn default() -> Self { Self { k: 3, - probs: P::new(Vec::new()), - disjunctions: Disjunctions::new(), + probs: P::new_cell(Vec::new()), + disjunctions: P::new_cell(Disjunctions::new()), } } } @@ -25,8 +25,8 @@ impl Clone for TopKProofsProvenance

{ fn clone(&self) -> Self { Self { k: self.k, - probs: P::new((&*self.probs).clone()), - disjunctions: self.disjunctions.clone(), + probs: P::clone_cell(&self.probs), + disjunctions: P::clone_cell(&self.disjunctions), } } } @@ -35,11 +35,15 @@ impl TopKProofsProvenance

{ pub fn new(k: usize) -> Self { Self { k, - probs: P::new(Vec::new()), - disjunctions: Disjunctions::new(), + probs: P::new_cell(Vec::new()), + disjunctions: P::new_cell(Disjunctions::new()), } } + pub fn num_facts(&self) -> usize { + P::get_cell(&self.probs, |p| p.len()) + } + pub fn set_k(&mut self, k: usize) { self.k = k; } @@ -47,11 +51,11 @@ impl TopKProofsProvenance

{ impl DNFContextTrait for TopKProofsProvenance

{ fn fact_probability(&self, id: &usize) -> f64 { - self.probs[*id] + P::get_cell(&self.probs, |p| p[*id]) } fn has_disjunction_conflict(&self, pos_facts: &std::collections::BTreeSet) -> bool { - self.disjunctions.has_conflict(pos_facts) + P::get_cell(&self.disjunctions, |d| d.has_conflict(pos_facts)) } } @@ -66,14 +70,14 @@ impl Provenance for TopKProofsProvenance

{ "top-k-proofs" } - fn tagging_fn(&mut self, input_tag: Self::InputTag) -> Self::Tag { + fn tagging_fn(&self, input_tag: Self::InputTag) -> Self::Tag { // First generate id and push the probability into the list - let fact_id = self.probs.len(); - P::get_mut(&mut self.probs).push(input_tag.prob); + let fact_id = self.num_facts(); + P::get_cell_mut(&self.probs, |p| p.push(input_tag.prob)); // Add exlusion if needed if let Some(disj_id) = input_tag.exclusion { - self.disjunctions.add_disjunction(disj_id, fact_id); + P::get_cell_mut(&self.disjunctions, |d| d.add_disjunction(disj_id, fact_id)); } // Lastly return a tag @@ -82,7 +86,7 @@ impl Provenance for TopKProofsProvenance

{ fn recover_fn(&self, t: &Self::Tag) -> Self::OutputTag { let s = RealSemiring; - let v = |i: &usize| -> f64 { self.probs[*i] }; + let v = |i: &usize| -> f64 { self.fact_probability(i) }; t.wmc(&s, &v) } @@ -116,7 +120,7 @@ impl Provenance for TopKProofsProvenance

{ fn weight(&self, t: &Self::Tag) -> f64 { let s = RealSemiring; - let v = |i: &usize| -> f64 { self.probs[*i] }; + let v = |i: &usize| -> f64 { self.fact_probability(i) }; t.wmc(&s, &v) } diff --git a/core/src/runtime/provenance/provenance.rs b/core/src/runtime/provenance/provenance.rs index 1b97fdd..3b9071c 100644 --- a/core/src/runtime/provenance/provenance.rs +++ b/core/src/runtime/provenance/provenance.rs @@ -7,6 +7,7 @@ use rand::prelude::*; use super::*; use crate::common::tuples::*; +use crate::common::input_tag::*; use crate::common::value_type::*; use crate::runtime::dynamic::*; use crate::runtime::env::*; @@ -15,15 +16,15 @@ use crate::runtime::statics::*; pub trait Provenance: Clone + 'static { type Tag: Tag; - type InputTag: Clone + Debug; + type InputTag: Clone + Debug + StaticInputTag; type OutputTag: Clone + Debug + Display; fn name() -> &'static str; - fn tagging_fn(&mut self, ext_tag: Self::InputTag) -> Self::Tag; + fn tagging_fn(&self, ext_tag: Self::InputTag) -> Self::Tag; - fn tagging_optional_fn(&mut self, ext_tag: Option) -> Self::Tag { + fn tagging_optional_fn(&self, ext_tag: Option) -> Self::Tag { match ext_tag { Some(et) => self.tagging_fn(et), None => self.one(), diff --git a/core/src/runtime/statics/relation.rs b/core/src/runtime/statics/relation.rs index fc3f1fc..467035f 100644 --- a/core/src/runtime/statics/relation.rs +++ b/core/src/runtime/statics/relation.rs @@ -59,22 +59,22 @@ impl StaticRelation { self.insert_tagged_with_monitor(ctx, vec![(info, tuple)], m); } - pub fn insert_dynamically_tagged(&self, ctx: &mut Prov, data: Vec<(InputTag, Tup)>) { + pub fn insert_dynamically_tagged(&self, ctx: &mut Prov, data: Vec<(DynamicInputTag, Tup)>) { let elements = data .into_iter() - .map(|(input_tag, tuple)| (FromInputTag::from_input_tag(&input_tag), tuple)) + .map(|(input_tag, tuple)| (StaticInputTag::from_dynamic_input_tag(&input_tag), tuple)) .collect::>(); self.insert_tagged(ctx, elements); } - pub fn insert_dynamically_tagged_with_monitor(&self, ctx: &mut Prov, data: Vec<(InputTag, Tup)>, m: &M) + pub fn insert_dynamically_tagged_with_monitor(&self, ctx: &mut Prov, data: Vec<(DynamicInputTag, Tup)>, m: &M) where Tuple: From, M: Monitor, { let elements = data .into_iter() - .map(|(input_tag, tuple)| (FromInputTag::from_input_tag(&input_tag), tuple)) + .map(|(input_tag, tuple)| (StaticInputTag::from_dynamic_input_tag(&input_tag), tuple)) .collect::>(); self.insert_tagged_with_monitor(ctx, elements, m); } diff --git a/core/src/testing/test_collection.rs b/core/src/testing/test_collection.rs index 671d5bf..5a4b7b0 100644 --- a/core/src/testing/test_collection.rs +++ b/core/src/testing/test_collection.rs @@ -42,9 +42,31 @@ impl> From, Tup)>> for pub fn test_equals(t1: &Tuple, t2: &Tuple) -> bool { match (t1, t2) { - (Tuple::Tuple(ts1), Tuple::Tuple(ts2)) => ts1.iter().zip(ts2.iter()).all(|(s1, s2)| test_equals(s1, s2)), - (Tuple::Value(Value::F32(f1)), Tuple::Value(Value::F32(f2))) => (f1 - f2).abs() < 0.001, - (Tuple::Value(Value::F64(f1)), Tuple::Value(Value::F64(f2))) => (f1 - f2).abs() < 0.001, + (Tuple::Tuple(ts1), Tuple::Tuple(ts2)) => { + ts1.iter().zip(ts2.iter()).all(|(s1, s2)| test_equals(s1, s2)) + }, + (Tuple::Value(Value::F32(t1)), Tuple::Value(Value::F32(t2))) => { + if t1.is_infinite() && t1.is_sign_positive() && t2.is_infinite() && t2.is_sign_positive() { + true + } else if t1.is_infinite() && t1.is_sign_negative() && t2.is_infinite() && t2.is_sign_negative() { + true + } else if t1.is_nan() || t2.is_nan() { + false + } else { + (t1 - t2).abs() < 
0.001 + } + }, + (Tuple::Value(Value::F64(t1)), Tuple::Value(Value::F64(t2))) => { + if t1.is_infinite() && t1.is_sign_positive() && t2.is_infinite() && t2.is_sign_positive() { + true + } else if t1.is_infinite() && t1.is_sign_negative() && t2.is_infinite() && t2.is_sign_negative() { + true + } else if t1.is_nan() || t2.is_nan() { + false + } else { + (t1 - t2).abs() < 0.001 + } + }, _ => t1 == t2, } } @@ -78,8 +100,11 @@ where } } -pub fn expect_output_collection(actual: &DynamicOutputCollection, expected: C) -where +pub fn expect_output_collection( + name: &str, + actual: &DynamicOutputCollection, + expected: C +) where Prov: Provenance, Prov::Tag: std::fmt::Debug, C: Into, @@ -90,7 +115,7 @@ where for e in &expected.elements { let te = e.clone().into(); let pos = actual.iter().position(|(_, tuple)| test_equals(&tuple, &te)); - assert!(pos.is_some(), "Tuple {:?} not found in collection {:?}", te, actual) + assert!(pos.is_some(), "Tuple {:?} not found in `{}` collection {:?}", te, name, actual) } // Then check everything in actual is in expected @@ -101,14 +126,19 @@ where .position(|e| test_equals(&e.clone().into(), &elem.1)); assert!( pos.is_some(), - "Tuple {:?} is derived in collection but not found in expected set", - elem + "Tuple {:?} is derived in collection `{}` but not found in expected set", + elem, + name, ) } } -pub fn expect_output_collection_with_tag(actual: &DynamicOutputCollection, expected: C, cmp: F) -where +pub fn expect_output_collection_with_tag( + name: &str, + actual: &DynamicOutputCollection, + expected: C, + cmp: F, +) where Prov: Provenance, Prov::Tag: std::fmt::Debug, C: Into>, @@ -124,8 +154,9 @@ where .position(|(tag, tuple)| test_equals(&tuple, &te) && cmp(&tage, tag)); assert!( pos.is_some(), - "Tagged Tuple {:?} not found in collection {:?}", + "Tagged Tuple {:?} not found in `{}` collection {:?}", (tage, te), + name, actual ) } @@ -138,8 +169,9 @@ where .position(|(tag, tup)| test_equals(&tup.clone().into(), &elem.1) && cmp(tag, &elem.0)); assert!( pos.is_some(), - "Tagged Tuple {:?} is derived in collection but not found in expected set", - elem + "Tagged Tuple {:?} is derived in `{}` collection but not found in expected set", + elem, + name, ) } } @@ -233,13 +265,29 @@ impl_static_equals_for_value_type!(String); impl StaticEquals for f32 { fn test_static_equals(t1: &Self, t2: &Self) -> bool { - (t1 - t2).abs() < 0.001 + if t1.is_infinite() && t1.is_sign_positive() && t2.is_infinite() && t2.is_sign_positive() { + true + } else if t1.is_infinite() && t1.is_sign_negative() && t2.is_infinite() && t2.is_sign_negative() { + true + } else if t1.is_nan() || t2.is_nan() { + false + } else { + (t1 - t2).abs() < 0.001 + } } } impl StaticEquals for f64 { fn test_static_equals(t1: &Self, t2: &Self) -> bool { - (t1 - t2).abs() < 0.001 + if t1.is_infinite() && t1.is_sign_positive() && t2.is_infinite() && t2.is_sign_positive() { + true + } else if t1.is_infinite() && t1.is_sign_negative() && t2.is_infinite() && t2.is_sign_negative() { + true + } else if t1.is_nan() || t2.is_nan() { + false + } else { + (t1 - t2).abs() < 0.001 + } } } diff --git a/core/src/testing/test_interpret.rs b/core/src/testing/test_interpret.rs index ac7a796..5d69f20 100644 --- a/core/src/testing/test_interpret.rs +++ b/core/src/testing/test_interpret.rs @@ -9,7 +9,7 @@ use super::*; pub fn expect_interpret_result + Clone>(s: &str, (p, e): (&str, Vec)) { let actual = interpret_string(s.to_string()).expect("Compile Error"); - expect_output_collection(actual.get_output_collection_ref(p).unwrap(), 
e); + expect_output_collection(p, actual.get_output_collection_ref(p).unwrap(), e); } pub fn expect_interpret_result_with_setup(s: &str, f: F, (p, e): (&str, Vec)) @@ -22,7 +22,7 @@ where f(interpret_ctx.edb()); interpret_ctx.run().expect("Runtime error"); let idb = interpret_ctx.idb(); - expect_output_collection(idb.get_output_collection_ref(p).unwrap(), e); + expect_output_collection(p, idb.get_output_collection_ref(p).unwrap(), e); } pub fn expect_interpret_result_with_tag(s: &str, ctx: Prov, (p, e): (&str, Vec<(Prov::OutputTag, T)>), f: F) @@ -32,7 +32,7 @@ where F: Fn(&Prov::OutputTag, &Prov::OutputTag) -> bool, { let actual = interpret_string_with_ctx(s.to_string(), ctx).expect("Interpret Error"); - expect_output_collection_with_tag(actual.get_output_collection_ref(p).unwrap(), e, f); + expect_output_collection_with_tag(p, actual.get_output_collection_ref(p).unwrap(), e, f); } /// Expect the given program to produce an empty relation `p` @@ -55,7 +55,7 @@ pub fn expect_interpret_empty_result(s: &str, p: &str) { pub fn expect_interpret_multi_result(s: &str, expected: Vec<(&str, TestCollection)>) { let actual = interpret_string(s.to_string()).expect("Compile Error"); for (p, a) in expected { - expect_output_collection(actual.get_output_collection_ref(p).unwrap(), a); + expect_output_collection(p, actual.get_output_collection_ref(p).unwrap(), a); } } diff --git a/core/src/utils/chrono.rs b/core/src/utils/chrono.rs new file mode 100644 index 0000000..cbb22c9 --- /dev/null +++ b/core/src/utils/chrono.rs @@ -0,0 +1,10 @@ +/// Parse a string into a chrono DateTime +pub fn parse_date_time_string(d: &str) -> Option> { + dateparser::parse(d).ok() +} + +/// Parse a string into a chrono Duration +pub fn parse_duration_string(d: &str) -> Option { + let d1 = parse_duration::parse(d).ok()?; + chrono::Duration::from_std(d1).ok() +} diff --git a/core/src/utils/copy_on_write.rs b/core/src/utils/copy_on_write.rs index 564a82a..2af2c23 100644 --- a/core/src/utils/copy_on_write.rs +++ b/core/src/utils/copy_on_write.rs @@ -1,10 +1,10 @@ use super::{PointerFamily, RcFamily}; -pub struct CopyOnWrite(P::Pointer); +pub struct CopyOnWrite(P::Rc); impl CopyOnWrite { pub fn new(t: T) -> Self { - Self(P::new(t)) + Self(P::new_rc(t)) } pub fn borrow(&self) -> &T { @@ -17,13 +17,13 @@ impl CopyOnWrite { { let mut new_inner = (*self.0).clone(); f(&mut new_inner); - *self = Self(P::new(new_inner)); + *self = Self(P::new_rc(new_inner)); } } impl Clone for CopyOnWrite { fn clone(&self) -> Self { - Self(P::clone_ptr(&self.0)) + Self(P::clone_rc(&self.0)) } } diff --git a/core/src/utils/float.rs b/core/src/utils/float.rs new file mode 100644 index 0000000..e4e12e9 --- /dev/null +++ b/core/src/utils/float.rs @@ -0,0 +1,39 @@ +/// Floating Point trait (f32, f64) +pub trait Float: + Sized + + Copy + + Clone + + PartialEq + + PartialOrd + + std::fmt::Debug + + std::fmt::Display + + std::ops::Add + + std::ops::Sub + + std::ops::Mul + + std::ops::Div + + std::convert::TryInto +{ + fn zero() -> Self; + + fn one() -> Self; +} + +impl Float for f32 { + fn zero() -> Self { + 0.0 + } + + fn one() -> Self { + 1.0 + } +} + +impl Float for f64 { + fn zero() -> Self { + 0.0 + } + + fn one() -> Self { + 1.0 + } +} diff --git a/core/src/utils/integer.rs b/core/src/utils/integer.rs index c9dae4f..05b24a8 100644 --- a/core/src/utils/integer.rs +++ b/core/src/utils/integer.rs @@ -7,6 +7,8 @@ pub trait Integer: Eq + PartialOrd + Ord + + std::fmt::Debug + + std::fmt::Display + std::ops::Add + std::ops::Sub + std::ops::Mul + diff --git 
a/core/src/utils/mod.rs b/core/src/utils/mod.rs index 4c7e362..643f9d5 100644 --- a/core/src/utils/mod.rs +++ b/core/src/utils/mod.rs @@ -1,11 +1,15 @@ //! Utilities +mod chrono; mod copy_on_write; +mod float; mod id_allocator; mod integer; mod pointer_family; +pub use self::chrono::*; pub(crate) use copy_on_write::*; +pub use float::*; pub(crate) use id_allocator::*; pub use integer::*; pub use pointer_family::*; diff --git a/core/src/utils/pointer_family.rs b/core/src/utils/pointer_family.rs index be37107..3474993 100644 --- a/core/src/utils/pointer_family.rs +++ b/core/src/utils/pointer_family.rs @@ -5,36 +5,40 @@ use std::sync::{Arc, Mutex}; /// Pointer Family is a trait to generalize reference counted pointers /// such as `Rc` and `Arc`. -/// Each Pointer Family defines two types: `Pointer` and `Cell`, +/// Each Pointer Family defines three types: `Pointer` and `Cell`, /// where `Pointer` is the simple reference counted pointer and `Cell` /// contains a internally mutable pointer. pub trait PointerFamily: Clone + PartialEq + 'static { + + /* ==================== Ref Counted ==================== */ + /// Reference counted pointer - type Pointer: Deref; + type Rc: Deref; /// Create a new `Pointer` - fn new(value: T) -> Self::Pointer; + fn new_rc(value: T) -> Self::Rc; - /// Clone a `Pointer`. Only the reference counter will increase; + /// Clone a `Rc`. Only the reference counter will increase; /// the content will not be cloned - fn clone_ptr(ptr: &Self::Pointer) -> Self::Pointer; + fn clone_rc(ptr: &Self::Rc) -> Self::Rc; - /// Get an immutable reference to the content pointed by the `Pointer` - fn get(ptr: &Self::Pointer) -> &T; + /// Get an immutable reference to the content pointed by the `Rc` + fn get_rc(ptr: &Self::Rc) -> &T; - /// Get a mutable reference to the content pointed by the `Pointer`. - /// Note that the `Pointer` itself needs to be mutable here - fn get_mut(ptr: &mut Self::Pointer) -> &mut T; + /// Get a mutable reference to the content pointed by the `Rc`. + /// Note that the `Rc` itself needs to be mutable here + fn get_rc_mut(ptr: &mut Self::Rc) -> &mut T; - /// Reference counted Cell + /* ==================== Cell ==================== */ + + /// Cell type Cell; /// Create a new `Cell` fn new_cell(value: T) -> Self::Cell; - /// Clone a `Cell`. Only the reference counter will increase; - /// the content will not be cloned - fn clone_cell(ptr: &Self::Cell) -> Self::Cell; + /// Clone a `Cell` + fn clone_cell(ptr: &Self::Cell) -> Self::Cell; /// Apply function `f` to the immutable content in the cell fn get_cell(ptr: &Self::Cell, f: F) -> O @@ -45,48 +49,100 @@ pub trait PointerFamily: Clone + PartialEq + 'static { fn get_cell_mut(ptr: &Self::Cell, f: F) -> O where F: FnOnce(&mut T) -> O; + + /* ==================== Ref Counted Cell ==================== */ + + /// Reference counted Cell + type RcCell; + + /// Create a new `RcCell` + fn new_rc_cell(value: T) -> Self::RcCell; + + /// Clone a `RcCell`. 
Only the reference counter will increase; + /// the content will not be cloned + fn clone_rc_cell(ptr: &Self::RcCell) -> Self::RcCell; + + /// Clone the internal of the ref counted cell + fn clone_rc_cell_internal(ptr: &Self::RcCell) -> T { + Self::get_rc_cell(ptr, |x| x.clone()) + } + + /// Apply function `f` to the immutable content in the cell + fn get_rc_cell(ptr: &Self::RcCell, f: F) -> O + where + F: FnOnce(&T) -> O; + + /// Apply function `f` to the mutable content in the cell + fn get_rc_cell_mut(ptr: &Self::RcCell, f: F) -> O + where + F: FnOnce(&mut T) -> O; } +/// The Arc pointer family, mainly used for multi-threaded program #[derive(Clone, Debug, PartialEq)] pub struct ArcFamily; impl PointerFamily for ArcFamily { - type Pointer = Arc; + type Rc = Arc; - fn new(value: T) -> Self::Pointer { + fn new_rc(value: T) -> Self::Rc { Arc::new(value) } - fn clone_ptr(ptr: &Self::Pointer) -> Self::Pointer { + fn clone_rc(ptr: &Self::Rc) -> Self::Rc { Arc::clone(ptr) } - fn get(ptr: &Self::Pointer) -> &T { + fn get_rc(ptr: &Self::Rc) -> &T { &*ptr } - fn get_mut(ptr: &mut Self::Pointer) -> &mut T { + fn get_rc_mut(ptr: &mut Self::Rc) -> &mut T { Arc::get_mut(ptr).unwrap() } - type Cell = Arc>; + type Cell = Mutex; fn new_cell(value: T) -> Self::Cell { + Self::Cell::new(value) + } + + fn clone_cell(ptr: &Self::Cell) -> Self::Cell { + Self::Cell::new(ptr.lock().unwrap().clone()) + } + + fn get_cell(ptr: &Self::Cell, f: F) -> O + where + F: FnOnce(&T) -> O + { + f(&ptr.lock().unwrap()) + } + + fn get_cell_mut(ptr: &Self::Cell, f: F) -> O + where + F: FnOnce(&mut T) -> O + { + f(&mut ptr.lock().unwrap()) + } + + type RcCell = Arc>; + + fn new_rc_cell(value: T) -> Self::RcCell { Arc::new(Mutex::new(value)) } - fn clone_cell(ptr: &Self::Cell) -> Self::Cell { + fn clone_rc_cell(ptr: &Self::RcCell) -> Self::RcCell { ptr.clone() } - fn get_cell(ptr: &Self::Cell, f: F) -> O + fn get_rc_cell(ptr: &Self::RcCell, f: F) -> O where F: FnOnce(&T) -> O, { f(&*ptr.lock().unwrap()) } - fn get_cell_mut(ptr: &Self::Cell, f: F) -> O + fn get_rc_cell_mut(ptr: &Self::RcCell, f: F) -> O where F: FnOnce(&mut T) -> O, { @@ -98,32 +154,32 @@ impl PointerFamily for ArcFamily { pub struct RcFamily; impl PointerFamily for RcFamily { - type Pointer = Rc; + type Rc = Rc; - fn new(value: T) -> Self::Pointer { + fn new_rc(value: T) -> Self::Rc { Rc::new(value) } - fn clone_ptr(ptr: &Self::Pointer) -> Self::Pointer { + fn clone_rc(ptr: &Self::Rc) -> Self::Rc { Rc::clone(ptr) } - fn get(ptr: &Self::Pointer) -> &T { + fn get_rc(ptr: &Self::Rc) -> &T { &*ptr } - fn get_mut(ptr: &mut Self::Pointer) -> &mut T { + fn get_rc_mut(ptr: &mut Self::Rc) -> &mut T { Rc::get_mut(ptr).unwrap() } - type Cell = Rc>; + type Cell = RefCell; fn new_cell(value: T) -> Self::Cell { - Rc::new(RefCell::new(value)) + RefCell::new(value) } - fn clone_cell(ptr: &Self::Cell) -> Self::Cell { - ptr.clone() + fn clone_cell(ptr: &Self::Cell) -> Self::Cell { + RefCell::new(ptr.borrow().clone()) } fn get_cell(ptr: &Self::Cell, f: F) -> O @@ -139,4 +195,28 @@ impl PointerFamily for RcFamily { { f(&mut (*ptr.borrow_mut())) } + + type RcCell = Rc>; + + fn new_rc_cell(value: T) -> Self::RcCell { + Rc::new(RefCell::new(value)) + } + + fn clone_rc_cell(ptr: &Self::RcCell) -> Self::RcCell { + ptr.clone() + } + + fn get_rc_cell(ptr: &Self::RcCell, f: F) -> O + where + F: FnOnce(&T) -> O, + { + f(&*ptr.borrow()) + } + + fn get_rc_cell_mut(ptr: &Self::RcCell, f: F) -> O + where + F: FnOnce(&mut T) -> O, + { + f(&mut (*ptr.borrow_mut())) + } } diff --git 
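Taken together, the reworked trait exposes three pointer kinds per family: a plain reference-counted pointer, a non-shared interior-mutability cell, and a reference-counted cell. A minimal usage sketch (the method shapes assume the generic parameters elided in this rendering; `ArcFamily` substitutes `Arc`/`Mutex` for multi-threaded use, `RcFamily` uses `Rc`/`RefCell`):

use scallop_core::utils::{ArcFamily, PointerFamily, RcFamily};

fn pointer_family_sketch() {
  // 1. plain reference-counted pointer
  let rc = RcFamily::new_rc(vec![1usize, 2, 3]);
  assert_eq!(RcFamily::get_rc(&rc).len(), 3);

  // 2. interior-mutability cell (RefCell here, Mutex under ArcFamily)
  let cell = RcFamily::new_cell(0usize);
  RcFamily::get_cell_mut(&cell, |v| *v += 1);
  assert_eq!(RcFamily::get_cell(&cell, |v| *v), 1);

  // 3. reference-counted cell: shared between clones and internally mutable
  let rc_cell = ArcFamily::new_rc_cell(String::new());
  ArcFamily::get_rc_cell_mut(&rc_cell, |s| s.push_str("hi"));
  assert_eq!(ArcFamily::get_rc_cell(&rc_cell, |s| s.len()), 2);
}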
a/core/tests/compiler/errors.rs b/core/tests/compiler/errors.rs index 2c9c780..7a4120d 100644 --- a/core/tests/compiler/errors.rs +++ b/core/tests/compiler/errors.rs @@ -118,3 +118,13 @@ fn conflicting_constant_decl_type_3() { |e| e.contains("cannot unify"), ) } + +#[test] +fn bad_enum_type_decl() { + expect_front_compile_failure( + r#" + type K = A = 3 | B | C = 4 | D + "#, + |e| e.contains("has already been assigned"), + ) +} diff --git a/core/tests/integrate/basic.rs b/core/tests/integrate/basic.rs index 83a7e9e..fbaf6d0 100644 --- a/core/tests/integrate/basic.rs +++ b/core/tests/integrate/basic.rs @@ -1,4 +1,5 @@ use scallop_core::runtime::provenance::*; +use scallop_core::utils::*; use scallop_core::testing::*; #[test] @@ -9,7 +10,7 @@ fn basic_edge_path_left_recursion() { rel path(a, b) = edge(a, b) \/ path(a, c) /\ edge(c, b) query path "#, - ("path", vec![(0usize, 2usize), (1, 2), (0, 3), (1, 3), (2, 3)]), + ("path", vec![(0, 2), (1, 2), (0, 3), (1, 3), (2, 3)]), ); } @@ -21,7 +22,7 @@ fn basic_edge_path_right_recursion() { rel path(a, b) = edge(a, b) \/ edge(a, c) /\ path(c, b) query path "#, - ("path", vec![(0usize, 2usize), (1, 2), (0, 3), (1, 3), (2, 3)]), + ("path", vec![(0, 2), (1, 2), (0, 3), (1, 3), (2, 3)]), ); } @@ -33,7 +34,7 @@ fn basic_edge_path_binary_recursion() { rel path(a, b) = edge(a, b) \/ path(a, c) /\ path(c, b) query path "#, - ("path", vec![(0usize, 2usize), (1, 2), (0, 3), (1, 3), (2, 3)]), + ("path", vec![(0, 2), (1, 2), (0, 3), (1, 3), (2, 3)]), ); } @@ -61,7 +62,7 @@ fn basic_difference_1() { rel s(x, y) = a(x, y), ~b(x, y) query s "#, - ("s", vec![(0usize, 1usize)]), + ("s", vec![(0, 1)]), ); } @@ -81,9 +82,9 @@ fn bmi_test_1() { rel bmi(id, w as f32 / ((h * h) as f32 / 10000.0)) = height(id, h), weight(id, w) "#, vec![ - ("height", vec![(1usize, 185i32), (2, 175), (3, 165)].into()), - ("weight", vec![(1usize, 80usize), (2, 70), (3, 55)].into()), - ("bmi", vec![(1usize, 23.374f32), (2, 22.857), (3, 20.202)].into()), + ("height", vec![(1, 185), (2, 175), (3, 165)].into()), + ("weight", vec![(1, 80), (2, 70), (3, 55)].into()), + ("bmi", vec![(1, 23.374f32), (2, 22.857), (3, 20.202)].into()), ], ) } @@ -120,7 +121,7 @@ fn const_fold_test_1() { rel E(1) rel R(s, a) = s == x + z, x == y + 1, y == z + 1, z == 1, E(a) "#, - ("R", vec![(4i32, 1usize)]), + ("R", vec![(4, 1)]), ); } @@ -130,7 +131,7 @@ fn const_fold_test_2() { r#" rel R(s) = s == x + z, x == y + 1, y == z + 1, z == 1 "#, - ("R", vec![(4i32,)]), + ("R", vec![(4,)]), ); } @@ -169,7 +170,7 @@ fn topk_test_1() { rel r1 = {(0, "x"), (0, "y"), (1, "y"), (1, "z")} rel r2(id, sym) :- sym = top<1>(s: r1(id, s)) "#, - ("r2", vec![(0usize, "x".to_string()), (1, "y".to_string())]), + ("r2", vec![(0, "x".to_string()), (1, "y".to_string())]), ); } @@ -186,7 +187,7 @@ fn digit_sum_test_1() { ( "sum_2", vec![ - (0usize, 1usize, 0i32), + (0, 1, 0), (0, 1, 1), (0, 1, 2), (0, 1, 3), @@ -213,7 +214,7 @@ fn digit_sum_test_2() { ( "sum_2", vec![ - (0usize, 1usize, 0i32), + (0, 1, 0), (0, 1, 1), (0, 1, 2), (0, 1, 3), @@ -330,7 +331,7 @@ fn simple_test_1() { rel edge = {(0, 1), (1, 2)} rel path(a, b) = edge(a, b) "#, - ("path", vec![(0usize, 1usize), (1, 2)]), + ("path", vec![(0, 1), (1, 2)]), ); } @@ -341,7 +342,7 @@ fn simple_test_2() { rel edge = {(0, 1), (1, 2), (2, 2)} rel self_edge(a, a) :- edge(a, a) "#, - ("self_edge", vec![(2usize, 2usize)]), + ("self_edge", vec![(2, 2)]), ); } @@ -352,7 +353,7 @@ fn simple_test_3() { rel edge = {(0, 1), (1, 2)} rel something(a, 2) :- edge(a, b) "#, - ("something", 
vec![(0usize, 2usize), (1, 2)]), + ("something", vec![(0, 2), (1, 2)]), ); } @@ -363,7 +364,7 @@ fn simple_test_4() { rel edge = {(0, 1), (1, 2)} rel something(a, 2) :- edge(a, b), b > 1 "#, - ("something", vec![(1usize, 2usize)]), + ("something", vec![(1, 2)]), ); } @@ -375,7 +376,7 @@ fn simple_test_5() { rel R = {(1, 2), (4, 3)} rel O(a, b) = S(b, a), R(a, b) "#, - ("O", vec![(4usize, 3usize)]), + ("O", vec![(4, 3)]), ); } @@ -387,7 +388,7 @@ fn simple_test_6() { rel R = {(1, 2), (3, 4), (4, 3)} rel O(a, b) = S(b, a), S(a, b), R(a, b), R(b, a) "#, - ("O", vec![(3usize, 4usize), (4, 3)]), + ("O", vec![(3, 4), (4, 3)]), ); } @@ -399,7 +400,7 @@ fn simple_test_7() { rel R = {(1), (2)} rel O(a, b) = S(a, b), ~R(b) "#, - ("O", vec![(2usize, 3usize)]), + ("O", vec![(2, 3)]), ); } @@ -411,7 +412,7 @@ fn simple_test_8() { rel R = {(1, 2), (2, 3)} rel O(a, b) = S(a, b), ~R(b, c) "#, - ("O", vec![(2usize, 3usize)]), + ("O", vec![(2, 3)]), ) } @@ -423,7 +424,7 @@ fn simple_test_9() { rel R = {(1, 2), (2, 3), (2, 2)} rel O(a, b) = S(a, b), ~R(a, a) "#, - ("O", vec![(0usize, 1usize), (1, 2)]), + ("O", vec![(0, 1), (1, 2)]), ) } @@ -479,7 +480,7 @@ fn class_student_grade_1() { "#, ( "class_top_student", - vec![(0usize, "jerry".to_string()), (1, "sherry".to_string())], + vec![(0, "jerry".to_string()), (1, "sherry".to_string())], ), ) } @@ -514,7 +515,7 @@ fn unused_relation_1() { rel S(b, 1) = B(b) query S "#, - ("S", vec![("haha".to_string(), 1usize), ("wow".to_string(), 1)]), + ("S", vec![("haha".to_string(), 1), ("wow".to_string(), 1)]), ) } @@ -527,8 +528,8 @@ fn atomic_query_1() { query S(_, 2) "#, vec![ - ("S(0, _)", vec![(0usize, 1usize), (0, 2)].into()), - ("S(_, 2)", vec![(0usize, 2usize), (1, 2)].into()), + ("S(0, _)", vec![(0, 1), (0, 2)].into()), + ("S(_, 2)", vec![(0, 2), (1, 2)].into()), ], ) } @@ -624,7 +625,7 @@ fn equal_v1_v2() { rel S = {(0, 1), (1, 2)} rel Q(a, b) = S(a, b), a == a "#, - vec![("Q", vec![(0usize, 1usize), (1, 2)].into())], + vec![("Q", vec![(0, 1), (1, 2)].into())], ) } @@ -654,7 +655,7 @@ fn test_count_with_where_clause() { "#, vec![( "count_enroll_cs_in_class", - vec![(0usize, 1usize), (1, 2), (2, 0)].into(), + vec![(0, 1usize), (1, 2), (2, 0)].into(), )], ) } @@ -741,7 +742,7 @@ fn implies_1() { // The object `o` such that `o` is cube implies that `o` is blue rel answer(o) = obj(o) and (shape(o, "cube") => color(o, "blue")) "#, - ("answer", vec![(1usize,)]), + ("answer", vec![(1,)]), ) } @@ -756,7 +757,7 @@ fn implies_2() { // The object `o` such that `o` is cube implies that `o` is blue rel answer(o) = obj(o) and (shape(o, "cube") => color(o, "blue")) "#, - ("answer", vec![(1usize,), (2,)]), + ("answer", vec![(1,), (2,)]), ) } @@ -923,7 +924,7 @@ fn ff_max_1() { r#" rel output($max(0, 1, 2, 3)) "#, - ("output", vec![(3usize,)]), + ("output", vec![(3,)]), ) } @@ -935,7 +936,7 @@ fn ff_max_2() { rel S = {3} rel output($max(a, b)) = R(a), S(b) "#, - ("output", vec![(3usize,)]), + ("output", vec![(3,)]), ) } @@ -947,7 +948,7 @@ fn const_variable_1() { const VAR2 = 246 rel r(VAR1, VAR2) "#, - ("r", vec![(135usize, 246usize)]), + ("r", vec![(135, 246)]), ) } @@ -994,7 +995,7 @@ fn const_variable_5() { const UP = 0, RIGHT = 1, DOWN = 2, LEFT = 3 rel r(UP, RIGHT, DOWN, LEFT) "#, - ("r", vec![(0usize, 1usize, 2usize, 3usize)]), + ("r", vec![(0, 1, 2, 3)]), ) } @@ -1009,9 +1010,53 @@ fn const_variable_6() { ) } +#[test] +fn const_variable_7() { + expect_interpret_result( + r#" + type Action = UP | RIGHT | DOWN | LEFT + rel r(UP, RIGHT, DOWN, LEFT) + "#, + ("r", 
vec![(0usize, 1usize, 2usize, 3usize)]), + ) +} + +#[test] +fn const_variable_8() { + expect_interpret_result( + r#" + type Action = UP = 10 | RIGHT | DOWN | LEFT + rel r(UP, RIGHT, DOWN, LEFT) + "#, + ("r", vec![(10usize, 11usize, 12usize, 13usize)]), + ) +} + +#[test] +fn const_variable_9() { + expect_interpret_result( + r#" + type Action = UP | RIGHT | DOWN = 10 | LEFT + rel r(UP, RIGHT, DOWN, LEFT) + "#, + ("r", vec![(0usize, 1usize, 10usize, 11usize)]), + ) +} + +#[test] +fn const_variable_10() { + expect_interpret_result( + r#" + type Action = UP = 3 | RIGHT | DOWN = 10 | LEFT + rel r(UP, RIGHT, DOWN, LEFT) + "#, + ("r", vec![(3usize, 4usize, 10usize, 11usize)]), + ) +} + #[test] fn sat_1() { - let ctx = proofs::ProofsProvenance::default(); + let ctx = proofs::ProofsProvenance::::default(); expect_interpret_result_with_tag( r#" type assign(String, bool) @@ -1046,7 +1091,7 @@ fn sat_1() { #[test] fn sat_2() { - let ctx = proofs::ProofsProvenance::default(); + let ctx = proofs::ProofsProvenance::::default(); expect_interpret_result_with_tag( r#" type assign(String, bool) @@ -1074,3 +1119,26 @@ fn sat_2() { |_, _| true, ) } + +#[test] +fn no_nan_1() { + expect_interpret_result( + r#" + rel R = {0.0, 3.0, 5.0} + rel P = {0.0, 1.0} + rel Q(a / b) = R(a) and P(b) + "#, + ("Q", vec![(0.0f32,), (3.0f32,), (5.0f32,), (std::f32::INFINITY,)]), + ) +} + +#[test] +fn string_plus_string_1() { + expect_interpret_result( + r#" + rel first_last_name = {("Alice", "Lee")} + rel full_name(first + " " + last) = first_last_name(first, last) + "#, + ("full_name", vec![("Alice Lee".to_string(),)]), + ) +} diff --git a/core/tests/integrate/dt.rs b/core/tests/integrate/dt.rs index 61ef0b2..d6f9730 100644 --- a/core/tests/integrate/dt.rs +++ b/core/tests/integrate/dt.rs @@ -35,7 +35,7 @@ fn dt_edge_path_1() { rel path(a, b) = edge(a, b) \/ path(a, c) /\ edge(c, b) query path(_, 3) "#, - ("path(_, 3)", vec![(0usize, 3usize), (1, 3), (2, 3)]), + ("path(_, 3)", vec![(0, 3), (1, 3), (2, 3)]), ); } @@ -48,6 +48,6 @@ fn dt_edge_path_2() { rel path(a, b) = edge(a, b) \/ path(a, c) /\ edge(c, b) query path(0, _) "#, - ("path(0, _)", vec![(0usize, 1usize), (0, 2), (0, 3)]), + ("path(0, _)", vec![(0, 1), (0, 2), (0, 3)]), ); } diff --git a/core/tests/integrate/edb.rs b/core/tests/integrate/edb.rs index e32158f..4e13d17 100644 --- a/core/tests/integrate/edb.rs +++ b/core/tests/integrate/edb.rs @@ -113,6 +113,7 @@ fn edb_edge_path_incremental_update() { // Check the result expect_output_collection( + "path", ctx.computed_relation_ref("path").unwrap(), vec![(0usize, 1usize), (0, 2), (1, 2)], ); @@ -128,6 +129,7 @@ fn edb_edge_path_incremental_update() { // Check the result expect_output_collection( + "path", ctx.computed_relation_ref("path").unwrap(), vec![(0usize, 1usize), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)], ); @@ -189,10 +191,12 @@ fn edb_edge_path_persistent_relation() { // Check the result expect_output_collection( + "path_1", ctx.computed_relation_ref("path_1").unwrap(), vec![(0usize, 1usize), (0, 2), (1, 2)], ); expect_output_collection( + "path_2", ctx.computed_relation_ref("path_2").unwrap(), vec![(0usize, 1usize), (0, 2), (1, 2)], ); @@ -213,10 +217,12 @@ fn edb_edge_path_persistent_relation() { // Check the result expect_output_collection( + "path_1", ctx.computed_relation_ref("path_1").unwrap(), vec![(0usize, 1usize), (0, 2), (1, 2)], ); expect_output_collection( + "path_2", ctx.computed_relation_ref("path_2").unwrap(), vec![(0usize, 1usize), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)], ); diff --git 
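A reading note for the sat_1/sat_2 changes above: `ProofsProvenance` is now parameterized by a pointer family, so these tests construct it as `proofs::ProofsProvenance::<RcFamily>::default()` (the turbofish argument is not visible in this rendering, leaving the bare `::::default()`), which is presumably also why `scallop_core::utils::*` is newly imported at the top of this test file. A sketch of the two ways a caller might now instantiate it (the `ArcFamily` variant is an assumption by analogy):

use scallop_core::runtime::provenance::*;
use scallop_core::utils::{ArcFamily, RcFamily};

fn build_proof_contexts() {
  let rc_ctx = proofs::ProofsProvenance::<RcFamily>::default();   // as used in sat_1 / sat_2
  let arc_ctx = proofs::ProofsProvenance::<ArcFamily>::default(); // Arc/Mutex-backed variant
  let _ = (rc_ctx, arc_ctx);
}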
a/core/tests/integrate/ff.rs b/core/tests/integrate/ff.rs index a1c8599..0cb8b01 100644 --- a/core/tests/integrate/ff.rs +++ b/core/tests/integrate/ff.rs @@ -94,7 +94,114 @@ fn test_fib_ff() { // Result expect_output_collection( + "S", ctx.computed_relation_ref("S").unwrap(), vec![(0i32, 1i32), (3, 2), (5, 5), (8, 21)], ); } + +#[test] +fn ff_string_length_1() { + expect_interpret_result( + r#" + rel strings = {"hello", "world!"} + rel lengths(x, $string_length(x)) = strings(x) + "#, + ( + "lengths", + vec![("hello".to_string(), 5usize), ("world!".to_string(), 6)], + ), + ); +} + +#[test] +fn ff_string_length_2() { + expect_interpret_result( + r#" + rel strings = {"hello", "world!"} + rel lengths(x, y) = strings(x), y == $string_length(x) + "#, + ( + "lengths", + vec![("hello".to_string(), 5usize), ("world!".to_string(), 6)], + ), + ); +} + +#[test] +fn ff_string_concat_2() { + expect_interpret_result( + r#" + rel strings = {"hello", "world!"} + rel cat(x) = strings(a), strings(b), a != b, x == $string_concat(a, " ", b) + "#, + ( + "cat", + vec![("hello world!".to_string(),), ("world! hello".to_string(),)], + ), + ); +} + +#[test] +fn ff_hash_1() { + expect_interpret_result( + r#" + rel result(x) = x == $hash(1, 3) + "#, + ("result", vec![(7198375873285174811u64,)]), + ); +} + +#[test] +fn ff_hash_2() { + expect_interpret_result( + r#" + rel result($hash(1, 3)) + "#, + ("result", vec![(7198375873285174811u64,)]), + ); +} + +#[test] +fn ff_abs_1() { + expect_interpret_result( + r#" + rel my_rel = {-1, 3, 5, -6} + rel abs_result($abs(x)) = my_rel(x) + "#, + ("abs_result", vec![(1i32,), (3,), (5,), (6,)]), + ); +} + +#[test] +fn ff_abs_2() { + expect_interpret_result( + r#" + rel my_rel = {-1.5, 3.3, 5.0, -6.5} + rel abs_result($abs(x)) = my_rel(x) + "#, + ("abs_result", vec![(1.5f32,), (3.3,), (5.0,), (6.5,)]), + ); +} + +#[test] +fn ff_substring_1() { + expect_interpret_result( + r#" + rel my_rel = {"hello world!"} + rel result($substring(x, 0, 5)) = my_rel(x) + "#, + ("result", vec![("hello".to_string(),)]), + ); +} + +#[test] +fn ff_substring_2() { + expect_interpret_result( + r#" + rel my_rel = {"hello world!"} + rel result($substring(x, 6)) = my_rel(x) + "#, + ("result", vec![("world!".to_string(),)]), + ); +} diff --git a/core/tests/integrate/fp.rs b/core/tests/integrate/fp.rs new file mode 100644 index 0000000..4bd3cf5 --- /dev/null +++ b/core/tests/integrate/fp.rs @@ -0,0 +1,91 @@ +use scallop_core::testing::*; + +#[test] +fn range_free_1() { + expect_interpret_result( + r#" + rel result(y) = range_usize(0, 5, y) + "#, + ("result", vec![(0usize,), (1,), (2,), (3,), (4,)]), + ); +} + +#[test] +fn range_constraint_1() { + expect_interpret_result( + r#" + rel base = {("A", "B", 3.0)} + rel result(a, b) = base(a, b, x) and soft_eq_f32(x, 3.0) + "#, + ("result", vec![("A".to_string(), "B".to_string())]), + ); +} + +#[test] +fn range_join_1() { + expect_interpret_result( + r#" + rel base = {3} + rel result(y) = base(x) and range_usize(0, x, y) + "#, + ("result", vec![(0usize,), (1,), (2,)]), + ); +} + +#[test] +fn range_join_2() { + expect_interpret_result( + r#" + rel base = {3} + rel result() = base(x) and range_usize(0, x, 2) + "#, + ("result", vec![()]), + ); +} + +#[test] +fn range_join_3() { + expect_interpret_empty_result( + r#" + rel base = {3} + rel result() = base(x) and range_usize(0, x, 100) + "#, + "result", + ); +} + +#[test] +fn range_join_4() { + expect_interpret_result( + r#" + rel base = {3, 10} + rel result(x) = base(x) and range_usize(0, x, 5) + "#, + ("result", 
vec![(10usize,)]), + ); +} + +#[test] +fn string_chars_1() { + expect_interpret_result( + r#" + rel string = {"hello"} + rel result(i, c) = string(s), string_chars(s, i, c) + "#, + ("result", vec![(0usize, 'h'), (1, 'e'), (2, 'l'), (3, 'l'), (4, 'o')]), + ); +} + +#[test] +fn floating_point_eq_1() { + expect_interpret_multi_result( + r#" + rel result_1() = float_eq_f32(3.000001, 1.000001 + 2.000001) + rel result_2() = 3.000001 == 1.000001 + 2.000001 + "#, + vec![ + ("result_1", vec![()].into()), + ("result_2", TestCollection::empty()) + ], + ) +} diff --git a/core/tests/integrate/incr.rs b/core/tests/integrate/incr.rs index 090f8e2..b7969b9 100644 --- a/core/tests/integrate/incr.rs +++ b/core/tests/integrate/incr.rs @@ -28,6 +28,7 @@ fn incr_edge_path_left_recursion() { // Result expect_output_collection( + "path", ctx.computed_relation_ref("path").unwrap(), vec![(0usize, 1usize), (0, 2), (1, 2)], ); @@ -49,6 +50,7 @@ fn incr_edge_path_left_branching_1() { .unwrap(); ctx.run().unwrap(); expect_output_collection( + "edge", ctx.computed_relation_ref("edge").unwrap(), vec![(0usize, 1usize), (1, 2)], ); @@ -60,6 +62,7 @@ fn incr_edge_path_left_branching_1() { .unwrap(); first_branch.run().unwrap(); expect_output_collection( + "path", first_branch.computed_relation_ref("path").unwrap(), vec![(0usize, 1usize), (0, 2), (1, 2)], ); @@ -71,14 +74,16 @@ fn incr_edge_path_left_branching_1() { .unwrap(); second_branch.run().unwrap(); expect_output_collection( + "path", second_branch.computed_relation_ref("path").unwrap(), vec![(0usize, 1usize), (0, 2), (1, 2)], ); // Second branch, continuation - second_branch.add_rule(r#"result(1, y) = path(1, y)"#).unwrap(); + second_branch.add_rule(r#"result(x, y) = path(x, y) and x == 1"#).unwrap(); second_branch.run().unwrap(); expect_output_collection( + "result", second_branch.computed_relation_ref("result").unwrap(), vec![(1usize, 2usize)], ); @@ -101,6 +106,7 @@ fn incr_fib_test_0() { ctx.run().expect("Runtime error"); expect_output_collection( + "fib", ctx.computed_relation_ref("fib").unwrap(), vec![(0i32, 1i32), (1, 1), (2, 2), (3, 3), (4, 5), (5, 8)], ); @@ -123,6 +129,7 @@ fn incr_fib_test_1() { ctx.run().expect("Runtime error"); expect_output_collection( + "fib", ctx.computed_relation_ref("fib").unwrap(), vec![(0i32, 1i32), (1, 1), (2, 2), (3, 3), (4, 5), (5, 8)], ); diff --git a/core/tests/integrate/mod.rs b/core/tests/integrate/mod.rs index d4e9b05..17f35a7 100644 --- a/core/tests/integrate/mod.rs +++ b/core/tests/integrate/mod.rs @@ -3,7 +3,8 @@ mod bug; mod dt; mod edb; mod ff; +mod fp; mod incr; mod iter; -mod prim; mod prob; +mod time; diff --git a/core/tests/integrate/prob.rs b/core/tests/integrate/prob.rs index 050458d..da2789b 100644 --- a/core/tests/integrate/prob.rs +++ b/core/tests/integrate/prob.rs @@ -41,7 +41,7 @@ fn test_min_max_with_recursion() { query path(3, 1, 3, 3) "#, ctx, - ("path(3, 1, 3, 3)", vec![(0.9, (3usize, 1usize, 3usize, 3usize))]), + ("path(3, 1, 3, 3)", vec![(0.9, (3, 1, 3, 3))]), min_max_prob::MinMaxProbProvenance::cmp, ) } diff --git a/core/tests/integrate/time.rs b/core/tests/integrate/time.rs new file mode 100644 index 0000000..0d2f955 --- /dev/null +++ b/core/tests/integrate/time.rs @@ -0,0 +1,127 @@ +use chrono::*; + +use scallop_core::testing::*; + +#[test] +fn date_type_1() { + expect_compile( + r#" + type r(DateTime) + rel r = {t"2019-01-01T00:00:00Z"} + "#, + ) +} + +#[test] +fn duration_type_1() { + expect_compile( + r#" + type r(Duration) + rel r = {d"1y5d"} + "#, + ) +} + +#[test] +fn date_1() { + 
expect_interpret_result( + r#" + rel r = {t"2019-01-01T00:00:00Z"} + "#, + ("r", vec![(Utc.with_ymd_and_hms(2019, 01, 01, 0, 0, 0).unwrap(),)]) + ) +} + +#[test] +fn date_2() { + expect_interpret_result( + r#" + rel r = {(0, t"2019-01-01T00:00:00Z")} + "#, + ("r", vec![(0, Utc.with_ymd_and_hms(2019, 01, 01, 0, 0, 0).unwrap())]) + ) +} + +#[test] +fn bad_date_1() { + expect_front_compile_failure( + r#"rel r = {t"ABCDEF"}"#, + |e| e.contains("Cannot parse date time `ABCDEF`") + ) +} + +#[test] +fn bad_duration_1() { + expect_front_compile_failure( + r#"rel r = {d"ABCDEF"}"#, + |e| e.contains("Cannot parse duration `ABCDEF`") + ) +} + +#[test] +fn date_plus_duration_1() { + expect_interpret_result( + r#" + rel p = {t"2019-01-01T00:00:00Z"} + rel q = {d"3d"} + rel r(date + duration) = p(date) and q(duration) + "#, + ("r", vec![(Utc.with_ymd_and_hms(2019, 01, 04, 0, 0, 0).unwrap(),)]) + ) +} + +#[test] +fn date_minus_duration_1() { + expect_interpret_result( + r#" + rel p = {t"2019-01-04T00:00:00Z"} + rel q = {d"3d"} + rel r(date - duration) = p(date) and q(duration) + "#, + ("r", vec![(Utc.with_ymd_and_hms(2019, 01, 01, 0, 0, 0).unwrap(),)]) + ) +} + +#[test] +fn duration_plus_duration_1() { + expect_interpret_result( + r#" + rel p = {(d"3d", d"2d")} + rel r(d1 + d2) = p(d1, d2) + "#, + ("r", vec![(Duration::days(5),)]) + ) +} + +#[test] +fn get_year_1() { + expect_interpret_result( + r#" + rel p = {t"2019-01-04T00:00:00Z"} + rel r($datetime_year(d)) = p(d) + "#, + ("r", vec![(2019i32,)]) + ) +} + +#[test] +fn get_month_1() { + expect_interpret_result( + r#" + rel p = {t"2019-01-04T00:00:00Z"} + rel r($datetime_month(d)) = p(d) + "#, + ("r", vec![(1u32,)]) + ) +} + +#[test] +fn get_month0_1() { + expect_interpret_result( + r#" + rel p = {t"2019-01-04T00:00:00Z"} + rel r($datetime_month0(d)) = p(d) + "#, + ("r", vec![(0u32,)]) + ) +} diff --git a/core/tests/runtime/dataflow/dyn_foreign_predicate.rs b/core/tests/runtime/dataflow/dyn_foreign_predicate.rs new file mode 100644 index 0000000..8a86085 --- /dev/null +++ b/core/tests/runtime/dataflow/dyn_foreign_predicate.rs @@ -0,0 +1,77 @@ +use scallop_core::common::expr::*; +use scallop_core::common::value::*; +use scallop_core::common::tuple::*; +use scallop_core::runtime::dynamic::dataflow::*; +use scallop_core::runtime::dynamic::*; +use scallop_core::runtime::env::*; +use scallop_core::runtime::provenance::*; + +#[test] +fn test_dyn_dataflow_free_range() { + let runtime = RuntimeEnvironment::new_std(); + let ctx = unit::UnitProvenance::new(); + let df = DynamicDataflow::foreign_predicate_ground( + "range_usize".to_string(), + vec![Value::USize(1), Value::USize(5)], + true, + &ctx, + ); + let batch = df.iter_recent(&runtime).next().unwrap().collect::>(); + for i in 1..5 { + match &batch[i - 1].tuple { + Tuple::Tuple(vs) => { + assert_eq!(vs.len(), 1); + match vs[0] { + Tuple::Value(Value::USize(x)) if x == i => { + // Good + } + _ => assert!(false), + } + } + _ => assert!(false), + } + } +} + +#[test] +fn test_dyn_dataflow_soft_lt_1() { + let runtime = RuntimeEnvironment::new_std(); + let ctx = min_max_prob::MinMaxProbProvenance::new(); + let source_df = vec![ + DynamicElement::new((1.0, -1.0), 1.0), + DynamicElement::new((1.0, 1.0), 1.0), + DynamicElement::new((1.0, 5.0), 1.0), + ]; + let df = DynamicDataflow::vec(&source_df).foreign_predicate_constraint( + "soft_lt_f64".to_string(), + vec![Expr::access(0), Expr::access(1)], + &ctx, + ); + let batch = df.iter_recent(&runtime).next().unwrap().collect::>(); + for elem in batch { + let tup: (f64, 
f64) = elem.tuple.as_tuple(); + if tup.0 < tup.1 { + assert!(elem.tag > 0.5); + } else if tup.0 == tup.1 { + assert!(elem.tag == 0.5); + } else { + assert!(elem.tag < 0.5); + } + } +} + +#[test] +fn test_dyn_dataflow_join_range() { + let runtime = RuntimeEnvironment::new_std(); + let ctx = unit::UnitProvenance::new(); + let data = vec![ + DynamicElement::new((1usize, 3usize), unit::Unit), + DynamicElement::new((10usize, 10usize), unit::Unit), + DynamicElement::new((100usize, 101usize), unit::Unit), + ]; + let df = DynamicDataflow::vec(&data).foreign_predicate_join("range_usize".to_string(), vec![Expr::access(0), Expr::access(1)], &ctx); + let batch = df.iter_recent(&runtime).next().unwrap().collect::>(); + assert_eq!(AsTuple::<((usize, usize), (usize,))>::as_tuple(&batch[0].tuple), ((1usize, 3usize), (1usize,))); + assert_eq!(AsTuple::<((usize, usize), (usize,))>::as_tuple(&batch[1].tuple), ((1usize, 3usize), (2usize,))); + assert_eq!(AsTuple::<((usize, usize), (usize,))>::as_tuple(&batch[2].tuple), ((100usize, 101usize), (100usize,))); +} diff --git a/core/tests/runtime/dataflow/dyn_group_aggregate.rs b/core/tests/runtime/dataflow/dyn_group_aggregate.rs index f7851ec..4a6f319 100644 --- a/core/tests/runtime/dataflow/dyn_group_aggregate.rs +++ b/core/tests/runtime/dataflow/dyn_group_aggregate.rs @@ -9,7 +9,7 @@ use scallop_core::testing::*; #[test] fn test_dynamic_group_and_count_1() { let mut ctx = unit::UnitProvenance; - let mut rt = RuntimeEnvironment::new(); + let mut rt = RuntimeEnvironment::new_std(); // Relations let mut color = DynamicRelation::::new(); diff --git a/core/tests/runtime/dataflow/dyn_group_by_key.rs b/core/tests/runtime/dataflow/dyn_group_by_key.rs index 6194ef9..58728a6 100644 --- a/core/tests/runtime/dataflow/dyn_group_by_key.rs +++ b/core/tests/runtime/dataflow/dyn_group_by_key.rs @@ -11,7 +11,7 @@ where Prov: Provenance + Default, { let mut ctx = Prov::default(); - let mut rt = RuntimeEnvironment::new(); + let mut rt = RuntimeEnvironment::new_std(); let result_1 = { let mut strata_1 = DynamicIteration::::new(); diff --git a/core/tests/runtime/dataflow/dyn_intersect.rs b/core/tests/runtime/dataflow/dyn_intersect.rs index 687a105..4c9876c 100644 --- a/core/tests/runtime/dataflow/dyn_intersect.rs +++ b/core/tests/runtime/dataflow/dyn_intersect.rs @@ -7,7 +7,7 @@ use scallop_core::testing::*; #[test] fn test_dynamic_intersect_1() { let mut ctx = unit::UnitProvenance; - let mut rt = RuntimeEnvironment::new(); + let mut rt = RuntimeEnvironment::new_std(); // Relations let mut source_1 = DynamicRelation::::new(); diff --git a/core/tests/runtime/dataflow/dyn_join.rs b/core/tests/runtime/dataflow/dyn_join.rs index 827ed0e..73000f1 100644 --- a/core/tests/runtime/dataflow/dyn_join.rs +++ b/core/tests/runtime/dataflow/dyn_join.rs @@ -7,7 +7,7 @@ use scallop_core::testing::*; #[test] fn test_dynamic_join_1() { let mut ctx = unit::UnitProvenance; - let mut rt = RuntimeEnvironment::new(); + let mut rt = RuntimeEnvironment::new_std(); // Relations let mut source_1 = DynamicRelation::::new(); diff --git a/core/tests/runtime/dataflow/dyn_product.rs b/core/tests/runtime/dataflow/dyn_product.rs index 37d4959..28c346d 100644 --- a/core/tests/runtime/dataflow/dyn_product.rs +++ b/core/tests/runtime/dataflow/dyn_product.rs @@ -7,7 +7,7 @@ use scallop_core::testing::*; #[test] fn test_dynamic_product_1() { let mut ctx = unit::UnitProvenance; - let mut rt = RuntimeEnvironment::new(); + let mut rt = RuntimeEnvironment::new_std(); // Relations let mut source_1 = DynamicRelation::::new(); 
diff --git a/core/tests/runtime/dataflow/dyn_project.rs b/core/tests/runtime/dataflow/dyn_project.rs index 01d6fe6..8ca0b68 100644 --- a/core/tests/runtime/dataflow/dyn_project.rs +++ b/core/tests/runtime/dataflow/dyn_project.rs @@ -8,7 +8,7 @@ use scallop_core::testing::*; #[test] fn test_dyn_project_1() { let mut ctx = unit::UnitProvenance; - let mut rt = RuntimeEnvironment::new(); + let mut rt = RuntimeEnvironment::new_std(); // Relations let mut source = DynamicRelation::::new(); diff --git a/core/tests/runtime/dataflow/mod.rs b/core/tests/runtime/dataflow/mod.rs index 6480e43..3d415eb 100644 --- a/core/tests/runtime/dataflow/mod.rs +++ b/core/tests/runtime/dataflow/mod.rs @@ -2,6 +2,7 @@ mod dyn_aggregate; mod dyn_difference; mod dyn_filter; mod dyn_find; +mod dyn_foreign_predicate; mod dyn_group_aggregate; mod dyn_group_by_key; mod dyn_intersect; diff --git a/core/tests/runtime/provenance/top_bottom_k.rs b/core/tests/runtime/provenance/top_bottom_k.rs index b6a961c..5b265f1 100644 --- a/core/tests/runtime/provenance/top_bottom_k.rs +++ b/core/tests/runtime/provenance/top_bottom_k.rs @@ -7,7 +7,7 @@ mod diff { #[test] fn test_diff_top_bottom_k_clauses_1() { - let mut ctx = DiffTopBottomKClausesProvenance::<(), RcFamily>::new(1); + let ctx = DiffTopBottomKClausesProvenance::<(), RcFamily>::new(1); // Create a few tags let a = ctx.tagging_fn((0.9, (), None).into()); @@ -29,7 +29,7 @@ mod diff { #[test] fn test_diff_top_bottom_k_clauses_2() { - let mut ctx = DiffTopBottomKClausesProvenance::<(), RcFamily>::new(1); + let ctx = DiffTopBottomKClausesProvenance::<(), RcFamily>::new(1); // Create a few tags let a = ctx.tagging_fn((0.1, (), None).into()); diff --git a/core/tests/tests.rs b/core/tests/tests.rs index 0b008d3..6378d05 100644 --- a/core/tests/tests.rs +++ b/core/tests/tests.rs @@ -1,3 +1,4 @@ mod compiler; mod integrate; mod runtime; +mod utils; diff --git a/core/tests/utils/mod.rs b/core/tests/utils/mod.rs new file mode 100644 index 0000000..0a15cad --- /dev/null +++ b/core/tests/utils/mod.rs @@ -0,0 +1 @@ +mod value; diff --git a/core/tests/utils/value.rs b/core/tests/utils/value.rs new file mode 100644 index 0000000..059c328 --- /dev/null +++ b/core/tests/utils/value.rs @@ -0,0 +1,17 @@ +use std::convert::*; + +use scallop_core::common::value::*; + +#[test] +fn value_try_into_1() { + let v = Value::USize(10); + let p: usize = v.try_into().unwrap_or(0); + assert_eq!(p, 10); +} + +#[test] +fn value_try_into_2() { + let v = Value::I8(10); + let p: usize = v.try_into().unwrap_or(0); + assert_eq!(p, 0); +} diff --git a/doc/.gitignore b/doc/.gitignore new file mode 100644 index 0000000..7585238 --- /dev/null +++ b/doc/.gitignore @@ -0,0 +1 @@ +book diff --git a/doc/book.toml b/doc/book.toml new file mode 100644 index 0000000..58078f3 --- /dev/null +++ b/doc/book.toml @@ -0,0 +1,25 @@ +[book] +title = "Scallop Book" +description = "Use Scallop for Logic Programming and Neuro-symbolic Programming" +authors = ["Ziyang Li"] +language = "en" + +[rust] +edition = "2018" + +[output.html] +mathjax-support = true +site-url = "/scallop/" +additional-js = ["js/hljs_scallop.js"] + +[output.html.playground] +line-numbers = true + +[output.html.search] +limit-results = 20 +use-boolean-and = true +boost-title = 2 +boost-hierarchy = 2 +boost-paragraph = 1 +expand = true +heading-split-level = 2 diff --git a/doc/js/hljs_scallop.js b/doc/js/hljs_scallop.js new file mode 100644 index 0000000..67ba819 --- /dev/null +++ b/doc/js/hljs_scallop.js @@ -0,0 +1,32 @@ +hljs.registerLanguage("pen", (hljs) => ({ 
+ name: "Scallop", + aliases: ["scl", "scallop"], + keywords: { + keyword: "import type rel relation query if then else where", + type: "i8 i16 i32 i64 i128 isize u8 u16 u32 u64 u128 usize f32 f64 bool char String", + literal: "true false", + built_in: "count sum prod min max exists forall unique top", + }, + contains: [ + hljs.C_LINE_COMMENT_MODE, + hljs.C_BLOCK_COMMENT_MODE, + { + className: "string", + variants: [ + hljs.QUOTE_STRING_MODE, + ] + }, + { + className: "number", + variants: [ + { + begin: hljs.C_NUMBER_RE + "[i]", + relevance: 1 + }, + hljs.C_NUMBER_MODE + ] + } + ], +})); + +hljs.initHighlightingOnLoad(); diff --git a/doc/readme.md b/doc/readme.md new file mode 100644 index 0000000..d13ce3d --- /dev/null +++ b/doc/readme.md @@ -0,0 +1,21 @@ +# The Scallop book + +Scallop book aims to provide a comprehensive guide on Scallop and its toolchain. +This book is mostly written in markdown and will be compiled and built by `mdbook`. + +## For contribution + +To develop the book, please first install the `mdbook` (assuming you have Rust and `cargo` properly installed): + +``` bash +$ cargo install mdbook +``` + +Then, type the following command to start serving the book at localhost: + +``` bash +$ make serve-book +``` + +While the server is running, go to your browser and type `localhost:3000` to start reading the book. +Note that when you edit the markdown files in the source, the page on your browser should automatically refresh. diff --git a/docs/readme.md b/doc/src/crash_course.md similarity index 100% rename from docs/readme.md rename to doc/src/crash_course.md diff --git a/doc/src/developer/binding.md b/doc/src/developer/binding.md new file mode 100644 index 0000000..e69de29 diff --git a/doc/src/developer/index.md b/doc/src/developer/index.md new file mode 100644 index 0000000..e69de29 diff --git a/doc/src/developer/language_construct.md b/doc/src/developer/language_construct.md new file mode 100644 index 0000000..e69de29 diff --git a/doc/src/grammar.md b/doc/src/grammar.md new file mode 100644 index 0000000..e69de29 diff --git a/doc/src/installation.md b/doc/src/installation.md new file mode 100644 index 0000000..9d72152 --- /dev/null +++ b/doc/src/installation.md @@ -0,0 +1,30 @@ +# Installation + +There are many ways in which you can use Scallop, forming a complete toolchain. +We specify how to installing the toolchain from source. +The following instructions assume you have access to the Scallop source code and have basic pre-requisites installed. + +## Requirements + +- Rust - nightly 2023-03-07 (please visit [here](https://rust-lang.github.io/rustup/concepts/channels.html) to learn more about Rust nightly and how to install them) +- Python 3.7+ (for connecting Scallop with Python and [PyTorch](https://pytorch.org)) + +## Scallop Interpreter + +The interpreter of Scallop is named `scli`. 
+To install it, please run:
+
+``` bash
+$ make install-scli
+```
+
+From here, you can use `scli` to test and run simple programs:
+
+``` bash
+$ scli examples/datalog/edge_path.scl
+```
+
+## Scallop Interactive Shell
+
+
+## Scallop Python Interface
diff --git a/doc/src/introduction.md b/doc/src/introduction.md
new file mode 100644
index 0000000..765407d
--- /dev/null
+++ b/doc/src/introduction.md
@@ -0,0 +1,3 @@
+``` scl
+rel hello = {"world"}
+```
diff --git a/doc/src/language/aggregation.md b/doc/src/language/aggregation.md
new file mode 100644
index 0000000..758eaf1
--- /dev/null
+++ b/doc/src/language/aggregation.md
@@ -0,0 +1 @@
+# Rules with Aggregations
diff --git a/doc/src/language/facts.md b/doc/src/language/facts.md
new file mode 100644
index 0000000..e69de29
diff --git a/doc/src/language/foreign_functions.md b/doc/src/language/foreign_functions.md
new file mode 100644
index 0000000..92f76f9
--- /dev/null
+++ b/doc/src/language/foreign_functions.md
@@ -0,0 +1 @@
+# Foreign Functions
diff --git a/doc/src/language/foreign_predicates.md b/doc/src/language/foreign_predicates.md
new file mode 100644
index 0000000..fc58c10
--- /dev/null
+++ b/doc/src/language/foreign_predicates.md
@@ -0,0 +1,27 @@
+# Foreign Predicates
+
+Foreign predicates aim to provide programmers with extra capabilities on top of relational predicates.
+A traditional Datalog program defines relational predicates using only Horn rules.
+Given the assumption that the input database is finite, these derived relational predicates will also be finite.
+However, there are many relational predicates that are infinite and cannot be easily expressed by Horn rules.
+One such example is the `range` relation.
+Suppose it is defined as `range(begin, end, i)`, where `i` ranges from `begin` (inclusive) to `end` (exclusive).
+There could be infinitely many such triplets, and we cannot simply enumerate all of them.
+But if the first two arguments `begin` and `end` are given, we can reasonably enumerate the values of `i`.
+
+In Scallop, `range` is available to be used as a **foreign predicate**.
+Note that Scallop's foreign predicates are currently statically typed and do not support signature overloading.
+For example, to use `range` on `i32` data, we need to invoke `range_i32`:
+
+``` scl
+rel result(x) = range_i32(0, 5, x)
+```
+
+Here we enumerate the value of `x` from 0 (inclusive) to 5 (exclusive), meaning that we obtain `result = {0, 1, 2, 3, 4}`.
+For the rest of this section, we describe in detail how foreign predicates are constructed in Scallop and why they are useful.
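To make the bound/free behavior described above concrete, here is a minimal sketch of an interpreter-level test in the style of the tests added in `core/tests/integrate/fp.rs`. It relies only on the `expect_interpret_result` testing helper and the `range_usize` foreign predicate that appear elsewhere in this patch; the test and relation names are illustrative.

```rust
use scallop_core::testing::*;

// Minimal sketch: `begin` and `end` are bound to the constants 0 and 5,
// so the foreign predicate enumerates the free third argument `x`.
#[test]
fn range_usize_bound_bound_free() {
  expect_interpret_result(
    r#"
      rel result(x) = range_usize(0, 5, x)
    "#,
    ("result", vec![(0usize,), (1,), (2,), (3,), (4,)]),
  );
}
```

If the third argument is instead already bound, the same predicate acts as a filter rather than an enumerator; the `range_join_2` and `range_join_3` tests above exercise exactly those two outcomes.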
+ +## Bound and Free Pattern + +## Foreign Predicates are Statically Typed + +## Available Foreign Predicates in `std` diff --git a/doc/src/language/index.md b/doc/src/language/index.md new file mode 100644 index 0000000..e69de29 diff --git a/doc/src/language/negation.md b/doc/src/language/negation.md new file mode 100644 index 0000000..d49ceb6 --- /dev/null +++ b/doc/src/language/negation.md @@ -0,0 +1 @@ +# Rules with Negations diff --git a/doc/src/language/provenance.md b/doc/src/language/provenance.md new file mode 100644 index 0000000..2216e2c --- /dev/null +++ b/doc/src/language/provenance.md @@ -0,0 +1 @@ +# Tags and Provenance diff --git a/doc/src/language/query.md b/doc/src/language/query.md new file mode 100644 index 0000000..e69de29 diff --git a/doc/src/language/recursion.md b/doc/src/language/recursion.md new file mode 100644 index 0000000..30a25cc --- /dev/null +++ b/doc/src/language/recursion.md @@ -0,0 +1 @@ +# Recursive Rules diff --git a/doc/src/language/rules.md b/doc/src/language/rules.md new file mode 100644 index 0000000..e69de29 diff --git a/doc/src/misc/contributors.md b/doc/src/misc/contributors.md new file mode 100644 index 0000000..e69de29 diff --git a/doc/src/probabilistic/facts.md b/doc/src/probabilistic/facts.md new file mode 100644 index 0000000..51d3186 --- /dev/null +++ b/doc/src/probabilistic/facts.md @@ -0,0 +1 @@ +# Fact with Probability diff --git a/doc/src/probabilistic/index.md b/doc/src/probabilistic/index.md new file mode 100644 index 0000000..e69de29 diff --git a/doc/src/probabilistic/logic.md b/doc/src/probabilistic/logic.md new file mode 100644 index 0000000..031dd41 --- /dev/null +++ b/doc/src/probabilistic/logic.md @@ -0,0 +1 @@ +# Logic and Probability diff --git a/doc/src/probabilistic/reasoning.md b/doc/src/probabilistic/reasoning.md new file mode 100644 index 0000000..c17b54a --- /dev/null +++ b/doc/src/probabilistic/reasoning.md @@ -0,0 +1 @@ +# Aggregation and Probability diff --git a/doc/src/probabilistic/sampling.md b/doc/src/probabilistic/sampling.md new file mode 100644 index 0000000..f40a645 --- /dev/null +++ b/doc/src/probabilistic/sampling.md @@ -0,0 +1 @@ +# Sampling with Probability diff --git a/doc/src/readme.md b/doc/src/readme.md new file mode 100644 index 0000000..e69de29 diff --git a/doc/src/scallopy/branching.md b/doc/src/scallopy/branching.md new file mode 100644 index 0000000..930da15 --- /dev/null +++ b/doc/src/scallopy/branching.md @@ -0,0 +1 @@ +# Branching Executions diff --git a/doc/src/scallopy/context.md b/doc/src/scallopy/context.md new file mode 100644 index 0000000..36ef296 --- /dev/null +++ b/doc/src/scallopy/context.md @@ -0,0 +1 @@ +# Creating Context diff --git a/doc/src/scallopy/index.md b/doc/src/scallopy/index.md new file mode 100644 index 0000000..e69de29 diff --git a/doc/src/scallopy/provenance.md b/doc/src/scallopy/provenance.md new file mode 100644 index 0000000..e323640 --- /dev/null +++ b/doc/src/scallopy/provenance.md @@ -0,0 +1 @@ +# Configuring Provenance diff --git a/doc/src/summary.md b/doc/src/summary.md new file mode 100644 index 0000000..34c6dd5 --- /dev/null +++ b/doc/src/summary.md @@ -0,0 +1,41 @@ +# Summary + +[Introduction](introduction.md) + +# Getting Started + +- [Installation](installation.md) +- [Crash Course](crash_course.md) + +# Reference Guide + +- [Scallop and Logic Programming](language/index.md) + - [Types, Relations, and Facts](language/facts.md) + - [Writing Simple Rules](language/rules.md) + - [Writing a Query](language/query.md) + - [Recursive 
Rules](language/recursion.md) + - [Rules with Negations](language/negation.md) + - [Rules with Aggregations](language/aggregation.md) + - [Foreign Functions](language/foreign_functions.md) + - [Foreign Predicates](language/foreign_predicates.md) +- [Scallop and Probabilistic Programming](probabilistic/index.md) + - [Fact with Probability](probabilistic/facts.md) + - [Logic and Probability](probabilistic/logic.md) + - [Aggregation and Probability](probabilistic/reasoning.md) + - [Sampling with Probability](probabilistic/sampling.md) + - [Tags and Provenance](language/provenance.md) +- [Python and Scallop](scallopy/index.md) + - [Creating Context](scallopy/context.md) + - [Branching Executions](scallopy/branching.md) + - [Configuring Provenance](scallopy/provenance.md) +- [For Developers](developer/index.md) + - [New Language Construct](developer/language_construct.md) + - [New Binding](developer/binding.md) + +# Resources + +- [Full Scallop Grammar](grammar.md) + +----------- + +[Contributors](misc/contributors.md) diff --git a/docs/design/grammar.md b/docs/design/grammar.md deleted file mode 100644 index 2e9c6e9..0000000 --- a/docs/design/grammar.md +++ /dev/null @@ -1,83 +0,0 @@ -# Grammar - -``` -SCALLOP_PROGRAM ::= ITEM* - -ITEM ::= TYPE_DECL - | RELATION_DECL - | QUERY_DECL - -TYPE ::= u8 | u16 | u32 | u64 | u128 | usize - | i8 | i16 | i32 | i64 | i128 | isize - | f32 | f64 | char | bool - | &str | String - | CUSTOM_TYPE_NAME - -TYPE_DECL ::= type CUSTOM_TYPE_NAME = TYPE - | type CUSTOM_TYPE_NAME <: TYPE - | type RELATION_NAME(TYPE*) - | type RELATION_NAME(VAR1: TYPE1, VAR2: TYPE2, ...) - -CONST_DECL ::= const CONSTANT_NAME : TYPE = CONSTANT - | const CONSTANT_NAME = CONSTANT - -RELATION_DECL ::= FACT_DECL - | FACTS_SET_DECL - | RULE_DECL - -CONSTANT ::= true | false | NUMBER_LITERAL | STRING_LITERAL - -CONST_TUPLE ::= CONSTANT | (CONSTANT1, CONSTANT2, ...) - -FOREIGN_FN ::= hash | string_length | string_concat | substring | abs - -BIN_OP ::= + | - | * | / | % | == | != | <= | < | >= | > | && | || | ^ - -UNARY_OP ::= ! | - - -CONST_EXPR ::= CONSTANT - | CONST_EXPR BIN_OP CONST_EXPR | UNARY_OP CONST_EXPR - | $ FOREIGN_FN(CONST_EXPR*) - | if CONST_EXPR then CONST_EXPR else CONST_EXPR - | ( CONST_EXPR ) - -TAG ::= true | false | NUMBER_LITERAL // true/false is for boolean tags; NUMBER_LITERAL is used for probabilities - -FACT_DECL ::= rel RELATION_NAME(CONST_EXPR*) // Untagged fact - | rel TAG :: RELATION_NAME(CONST_EXPR*) // Tagged fact - -FACTS_SET_DECL ::= rel RELATION_NAME = {CONST_TUPLE1, CONST_TUPLE2, ...} // Untagged tuples - | rel RELATION_NAME = {TAG1 :: CONST_TUPLE1, TAG2 :: CONST_TUPLE2, ...} // Tagged tuples - | rel RELATION_NAME = {TAG1 :: CONST_TUPLE1; TAG2 :: CONST_TUPLE2; ...} // Tagged tuples forming annotated disjunction - -EXPR ::= VARIABLE | CONSTANT - | EXPR BIN_OP EXPR | UNARY_OP EXPR - | $ FOREIGN_FN(EXPR*) - | if EXPR then EXPR else EXPR - | ( EXPR ) - -ATOM ::= RELATION_NAME(EXPR*) - -RULE_DECL ::= rel ATOM :- FORMULA | rel ATOM = FORMULA // Normal rule - | rel TAG :: ATOM :- FORMULA | rel TAG :: ATOM = FORMULA // Tagged rule - -FORMULA ::= ATOM - | not ATOM | ~ ATOM // negation - | FORMULA1, FORMULA2, ... 
| FORMULA and FORMULA | FORMULA /\ FORMULA // conjunction - | FORMULA or FORMULA | FORMULA \/ FORMULA // disjunction - | FORMULA implies FORMULA | FORMULA => FORMULA // implies - | CONSTRAINT | AGGREGATION - | ( FORMULA ) - -CONSTRAINT ::= EXPR // When expression returns a boolean value - -AGGREGATOR ::= count | sum | prod | min | max | exists | forall | unique - -AGGREGATION ::= VAR* = AGGREGATOR(VAR* : FORMULA) // Normal aggregation - | VAR* = AGGREGATOR(VAR* : FORMULA where VAR* : FORMULA) // Aggregation with group-by condition - | VAR* = AGGREGATOR[VAR*](VAR* : FORMULA) // Aggregation with arg (only applied to AGGREGATOR = min or max) - | VAR* = AGGREGATOR[VAR*](VAR* : FORMULA where VAR* : FORMULA) // Aggregation with arg and group-by condition (only applied to AGGREGATOR = min or max) - -QUERY_DECL ::= query RELATION_NAME - | query ATOM -``` diff --git a/docs/design/group_by.md b/docs/design/group_by.md deleted file mode 100644 index f0d6678..0000000 --- a/docs/design/group_by.md +++ /dev/null @@ -1,53 +0,0 @@ -# Design of Group By - -## Examples - -### Example 1 - -The following count does not have a group by variable - -``` scl -rel num_cars(n) :- n = count(o: is_a(o, "car")) -``` - -### Example 2 - -The following count does have a group by variable `c`. -The body of the rule `is_a(o, "car"), color(o, c)` bounds two variables: `o` and `c`. -We want the variables that occur in the head that is not "to-aggregate" or "argument" values. - -``` scl -rel num_cars_of_color(c, n) :- n = count(o: is_a(o, "car"), color(o, c)) -``` - -### Example 3 - -The following count does have a group by variable `c`, note that `s` is not a group-by variable: - -``` scl -rel num_cars_of_color(c, n) :- n = count(o: is_a(o, "car"), color(o, c), shape(o, s)) -``` - -Although we have `shape(o, s)`, but we are not storing `s` in the head. -Therefore we do not treat `s` as a group-by variable. 
-
-### Example 4
-
-``` scl
-rel num_cars_of_color(c, n) :- n = count(o: is_a(o, "car"), color(o, c) where c: all_colors(c))
-```
-
-body_bounded_vars: o, c
-group_by_bounded_vars: c
-group_by_vars: c
-
-### Example 5
-
-``` scl
-rel eval_yn(e, b) :- b = exists(o: eval_obj(f, o) where e: exists_expr(e, f))
-```
-
-body_bounded_vars: f, o
-group_by_bounded_vars: e, f
-group_by_vars: e
-to_agg_vars: o
diff --git a/docs/icons/scallop-logo-transp-128.png b/docs/icons/scallop-logo-transp-128.png
deleted file mode 100644
index 456127e5ea2162e81e5d7479b457c4175b7e3343..0000000000000000000000000000000000000000
GIT binary patch
(base85-encoded binary image data omitted)
diff --git a/etc/codegen/Cargo.toml b/etc/codegen/Cargo.toml
index 4a74d53..f9dd385 100644
--- a/etc/codegen/Cargo.toml
+++ b/etc/codegen/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "scallop-codegen"
-version = "0.1.7"
+version = "0.1.8"
 authors = ["Ziyang Li "]
 edition = "2018"
diff --git a/etc/codegen/tests/codegen_basic.rs b/etc/codegen/tests/codegen_basic.rs
index 441db18..cb172c4 100644
--- a/etc/codegen/tests/codegen_basic.rs
+++ b/etc/codegen/tests/codegen_basic.rs
@@ -149,7 +149,7 @@ fn codegen_const_fold_test_1() {
   let mut ctx = unit::UnitProvenance::default();
   let result = const_fold_test_1::run(&mut ctx);
-  expect_static_output_collection(&result.R, vec![(4i32, 1usize)]);
+  expect_static_output_collection(&result.R, vec![(4i32, 1i32)]);
 }
 
 #[test]
diff --git a/etc/codegen/tests/codegen_edb.rs b/etc/codegen/tests/codegen_edb.rs
index ee97da8..1acfbf6 100644
--- a/etc/codegen/tests/codegen_edb.rs
+++ b/etc/codegen/tests/codegen_edb.rs
@@ -17,11 +17,11 @@ fn codegen_edge_path_with_edb_1() {
   let edb = edge_path::create_edb::();
   assert_eq!(
     edb.type_of("edge").unwrap(),
-    >::from_type()
+    >::from_type()
   );
   assert_eq!(
     edb.type_of("path").unwrap(),
-    >::from_type()
+    >::from_type()
   );
 }
diff --git a/etc/scallop-wasm/Cargo.toml b/etc/scallop-wasm/Cargo.toml
index 2ad1b5e..b8f4e7c 100644
--- a/etc/scallop-wasm/Cargo.toml
+++ b/etc/scallop-wasm/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "scallop-wasm"
-version = "0.1.7"
+version = "0.1.8"
 authors = ["Ziyang Li"]
 edition = "2018"
diff --git a/etc/scallopy/Cargo.toml b/etc/scallopy/Cargo.toml
index ba62279..aad9f1b 100644
--- a/etc/scallopy/Cargo.toml
+++ b/etc/scallopy/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "scallopy"
-version = "0.1.7"
+version = "0.1.8"
 edition = "2018"
 
 [lib]
diff --git a/etc/scallopy/examples/foreign_predicate.py b/etc/scallopy/examples/foreign_predicate.py
new file mode 100644
index 0000000..166a385
--- /dev/null
+++ b/etc/scallopy/examples/foreign_predicate.py
@@ -0,0 +1,26 @@
+from typing import *
+
+import scallopy
+from scallopy import foreign_predicate, Generator
+from scallopy.types import *
+
+
+@foreign_predicate
+def string_semantic_eq(s1: str, s2: str) -> Generator[float, Tuple]:
+  if s1 == "mom" and s2 == "mother":
+    yield (0.99, ())
+  elif s1 == "mother" and s2 == "mother":
+    yield (1.0, ())
+
+
+ctx = scallopy.Context(provenance="minmaxprob")
+ctx.register_foreign_predicate(string_semantic_eq)
+ctx.add_relation("kinship", (str, str, str))
+ctx.add_facts("kinship", [
+  (1.0, ("alice", "mom", "bob")),
+  (1.0, ("alice", "mother",
"cassey")), +]) +ctx.add_rule("parent(a, b) = kinship(a, r, b) and string_semantic_eq(r, \"mother\")") +ctx.add_rule("sibling(a, b) = parent(c, a) and parent(c, b) and a != b") +ctx.run() +print(list(ctx.relation("sibling"))) diff --git a/etc/scallopy/scallopy/__init__.py b/etc/scallopy/scallopy/__init__.py index bc1ae99..8f5e957 100644 --- a/etc/scallopy/scallopy/__init__.py +++ b/etc/scallopy/scallopy/__init__.py @@ -2,11 +2,16 @@ from .forward import ScallopForwardFunction from .provenance import ScallopProvenance from .function import GenericTypeParameter, foreign_function +from .predicate import Generator, foreign_predicate from .types import * +from .input_mapping import InputMapping # Provide a few aliases Context = ScallopContext ForwardFunction = ScallopForwardFunction +Module = ScallopForwardFunction Provenance = ScallopProvenance Generic = GenericTypeParameter ff = foreign_function +fp = foreign_predicate +Map = InputMapping diff --git a/etc/scallopy/scallopy/context.py b/etc/scallopy/scallopy/context.py index dbc0ced..1694baf 100644 --- a/etc/scallopy/scallopy/context.py +++ b/etc/scallopy/scallopy/context.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import List, Union, Tuple, Optional, Any, Dict, Callable +from typing import * from copy import deepcopy # Try import torch; if not there, delegate to something else @@ -10,10 +10,12 @@ from .collection import ScallopCollection from .provenance import ScallopProvenance, DiffAddMultProb2Semiring, DiffNandMultProb2Semiring, DiffMaxMultProb2Semiring from .io import CSVFileOptions -from .utils import _mapping_tuple +from .input_mapping import InputMapping from .function import ForeignFunction +from .predicate import ForeignPredicate from .history import HistoryAction, record_history from .sample_type import SAMPLE_TYPE_TOP_K +from .utils import Counter # Main context class ScallopContext(Context): @@ -87,7 +89,7 @@ def __init__( self._input_retain_topk = {} self._input_non_probabilistic = {} self._input_is_singleton = {} - self._input_mutual_exclusions = 0 + self._mutual_exclusion_counter = Counter() self._sample_facts = {} self._k = k self._train_k = train_k @@ -102,7 +104,7 @@ def __init__( self._input_retain_topk = deepcopy(fork_from._input_retain_topk) self._input_non_probabilistic = deepcopy(fork_from._input_non_probabilistic) self._input_is_singleton = deepcopy(fork_from._input_is_singleton) - self._input_mutual_exclusions = deepcopy(fork_from._input_mutual_exclusions) + self._mutual_exclusion_counter = deepcopy(fork_from._mutual_exclusion_counter) self._sample_facts = deepcopy(fork_from._sample_facts) self._k = deepcopy(fork_from._k) self._train_k = deepcopy(fork_from._train_k) @@ -155,6 +157,21 @@ def clone(self) -> ScallopContext: """ return ScallopContext(fork_from=self) + def set_early_discard(self, early_discard: bool = True): + """ + Configure the current context to perform early discard (or not) + """ + self._internal.set_early_discard(early_discard) + + def set_iter_limit(self, iter_limit: Optional[int] = None): + """ + Configure the current context to have limit on iteration (or not) + """ + if iter_limit is None: + self._internal.remove_iter_limit() + else: + self._internal.set_iter_limit(iter_limit) + def run(self): """ Execute the code under the current context. This operation is incremental @@ -202,6 +219,13 @@ def register_foreign_function(self, foreign_function: ForeignFunction): else: raise Exception("Registering non-foreign-function. 
Consider decorating the function with @scallopy.foreign_function") + @record_history + def register_foreign_predicate(self, foreign_predicate: ForeignPredicate): + if type(foreign_predicate) == ForeignPredicate: + self._internal.register_foreign_predicate(foreign_predicate) + else: + raise Exception("Registering non-foreign-predicate. Consider decorating the function with @scallopy.foreign_predicate") + def forward_function( self, output: Optional[str] = None, @@ -260,7 +284,7 @@ def add_relation( self, relation_name: str, relation_types: Union[Tuple, type, str], - input_mapping: Optional[Union[List[Tuple], Tuple]] = None, + input_mapping: Any = None, retain_topk: Optional[int] = None, non_probabilistic: bool = False, load_csv: Optional[Union[CSVFileOptions, str]] = None, @@ -512,36 +536,46 @@ def _set_relation_non_probabilistic(self, relation: str, non_probabilistic: bool else: raise Exception(f"Unknown relation {relation}") - def set_input_mapping(self, relation: str, input_mapping: Union[List[Tuple], Tuple]): + def set_input_mapping( + self, + relation: str, + input_mapping: Any, + disjunctive: bool = False, + disjunctive_dim: Optional[int] = None, + retain_threshold: Optional[float] = None, + retain_k: Optional[int] = None, + sample_dim: Optional[int] = None, + sample_strategy: Optional[Literal["top", "categorical"]] = "top", + ): """ Set the input mapping for the given relation """ - if type(input_mapping) == list: - (preproc_input_mapping, is_singleton) = ([_mapping_tuple(t) for t in input_mapping], False) - elif type(input_mapping) == tuple: - (preproc_input_mapping, is_singleton) = ([_mapping_tuple(input_mapping)], True) - else: - raise Exception(f"Unknown input mapping type `{type(input_mapping)}`. Must be a list or tuple") + + # Try to create an input mapping, pass all configurations + mapping = InputMapping( + input_mapping, + disjunctive=disjunctive, + disjunctive_dim=disjunctive_dim, + retain_threshold=retain_threshold, + retain_k=retain_k, + sample_dim=sample_dim, + sample_strategy=sample_strategy, + supports_disjunctions=self.supports_disjunctions(), + ) # Check if the tuples in the input mapping matches - for tuple in preproc_input_mapping: - if not self._internal.check_tuple(relation, tuple): - raise Exception(f"The tuple {tuple} in the input mapping does not match the type of the relation `{relation}`") + for t in mapping.all_tuples(): + if not self._internal.check_tuple(relation, t): + raise Exception(f"The tuple {t} in the input mapping does not match the type of the relation `{relation}`") # If there is no problem, set the input_mapping property - self._input_mappings[relation] = (preproc_input_mapping, is_singleton) + self._input_mappings[relation] = mapping - def set_sample_topk_facts(self, relation: str, amount: int): - self._input_retain_topk[relation] = amount - - def set_sample_facts(self, relation: str, amount: int, sample_type: str = SAMPLE_TYPE_TOP_K): - """ - When forward function is called, sample from the given facts by the amount - - :param relation: the relation within which the facts need to be sampled - :param sample_type: can be chosen from "categorical" or "" - """ - self._sample_facts[relation] = (sample_type, amount) + def has_input_mapping(self, relation: str) -> bool: + if relation in self._input_mappings: + if self._input_mappings[relation] is not None and self._input_mappings[relation].kind is not None: + return True + return False def requires_tag(self) -> bool: """ @@ -571,6 +605,8 @@ def supports_disjunctions(self) -> bool: return 
self.provenance in PROVENANCE_SUPPORTING_DISJUNCTIONS def _process_disjunctive_elements(self, elems, disjunctions): + processed_elems = [e for e in elems] + # Check if we the provenance supports handling disjunctions if self.supports_disjunctions(): @@ -582,26 +618,25 @@ def _process_disjunctive_elements(self, elems, disjunctions): # Go through each disjunction for disjunction in disjunctions: # Assign a new disjunction id for this one - disjunction_id = self._input_mutual_exclusions - self._input_mutual_exclusions += 1 # Increment the mutual exclusion count + disjunction_id = self._mutual_exclusion_counter.get_and_increment() # Go through the facts and update the tag to include disjunction id for fact_id in disjunction: visited_fact_ids.add(fact_id) if self.requires_tag(): (tag, tup) = elems[fact_id] - elems[fact_id] = ((tag, disjunction_id), tup) + processed_elems[fact_id] = ((tag, disjunction_id), tup) else: - elems[fact_id] = (disjunction_id, elems[fact_id]) + processed_elems[fact_id] = (disjunction_id, elems[fact_id]) # Update facts who is not inside of any disjunction for (fact_id, fact) in enumerate(elems): if fact_id not in visited_fact_ids: if self.requires_tag(): (tag, tup) = fact - elems[fact_id] = ((tag, None), tup) + processed_elems[fact_id] = ((tag, None), tup) else: - elems[fact_id] = (None, fact) + processed_elems[fact_id] = (None, fact) # Return the processed elements - return elems + return processed_elems diff --git a/etc/scallopy/scallopy/forward.py b/etc/scallopy/scallopy/forward.py index 83425bc..441510d 100644 --- a/etc/scallopy/scallopy/forward.py +++ b/etc/scallopy/scallopy/forward.py @@ -21,13 +21,15 @@ def __init__( provenance: str = "difftopkproofs", custom_provenance: Optional[ScallopProvenance] = None, non_probabilistic: Optional[List[str]] = None, - input_mappings: Optional[Dict[str, List]] = None, + input_mappings: Optional[Dict[str, Any]] = None, output_relation: Optional[str] = None, output_mapping: Optional[List] = None, output_mappings: Optional[Dict[str, List]] = None, k: int = 3, train_k: Optional[int] = None, test_k: Optional[int] = None, + early_discard: Optional[bool] = None, + iter_limit: Optional[int] = None, retain_graph: bool = False, jit: bool = False, jit_name: str = "", @@ -71,6 +73,14 @@ def __init__( for (relation, elems) in facts.items(): self.ctx.add_facts(relation, elems) + # Configurations: iteration limit + if iter_limit is not None: + self.ctx.set_iter_limit(iter_limit) + + # Configurations: early discarding + if early_discard is not None: + self.ctx.set_early_discard(early_discard) + # Create the forward function self.forward_fn = self.ctx.forward_function( output=output_relation, @@ -381,6 +391,8 @@ def _process_input_facts(self, rela, rela_facts, disjunctions) -> List[Tuple]: # Process the facts ty = type(rela_facts) # The type of relation facts index_mapping = None # The index mapping of given facts and preprocessed facts if there is removal of facts + + # If the facts are directly provided as list if ty == list: facts = rela_facts if rela in self.ctx._input_retain_topk: @@ -391,32 +403,23 @@ def _process_input_facts(self, rela, rela_facts, disjunctions) -> List[Tuple]: facts = [f for (_, f) in indexed_facts] else: facts = sorted(facts, key=lambda x: x[0].item(), reverse=True)[:k] + + # Remap disjunction + remapped_disjs = [[index_mapping[i] for i in d if i in index_mapping] for d in disjunctions] if index_mapping is not None else disjunctions + + # Process elements with this disjunction + facts = 
self.ctx._process_disjunctive_elements(facts, remapped_disjs) + + # If the facts are provided as Tensor elif ty == Tensor: if rela not in self.ctx._input_mappings: raise Exception(f"scallopy.forward receives vectorized Tensor input. However there is no `input_mapping` provided for relation `{rela}`") - probs = rela_facts - single_element = self.ctx._input_mappings[rela][1] - if single_element: - fact = self.ctx._input_mappings[rela][0][0] - facts = [(probs, fact)] - else: - if rela in self.ctx._input_retain_topk and self.ctx._input_retain_topk[rela]: - k = min(self.ctx._input_retain_topk[rela], len(probs)) - (top_probs, top_prob_ids) = torch.topk(probs, k) - facts = [(p, self.ctx._input_mappings[rela][0][i]) for (p, i) in zip(top_probs, top_prob_ids)] - if disjunctions is not None: - index_mapping = {j: i for (i, j) in enumerate(top_prob_ids)} - else: - facts = list(zip(probs, self.ctx._input_mappings[rela][0])) + + # Use the input mapping to process + facts = self.ctx._input_mappings[rela].process_tensor(rela_facts) else: raise Exception(f"Unknown input facts type. Expected Tensor or List, found {ty}") - # Remap disjunction - remapped_disjs = [[index_mapping[i] for i in d if i in index_mapping] for d in disjunctions] if index_mapping is not None else disjunctions - - # Process elements with this disjunction - facts = self.ctx._process_disjunctive_elements(facts, remapped_disjs) - # Add the facts return facts @@ -436,7 +439,7 @@ def _run_single(self, task_id, all_inputs, output_relations): if not self.ctx.has_relation(rela): raise Exception(f"Unknown relation `{rela}`") facts = rela_inputs[task_id] - temp_ctx.add_facts(rela, facts) + temp_ctx._internal.add_facts(rela, facts) # Execute the context if self.debug_provenance: @@ -707,7 +710,7 @@ def pad_input(l): for (output_id, output_tag) in enumerate(task_result): # output_size if output_tag is not None: (_, deriv) = output_tag - for (input_id, weight, _) in deriv: + for (input_id, weight) in deriv: mat_w[batch_id, output_id, input_id] = weight # backward hook diff --git a/etc/scallopy/scallopy/function.py b/etc/scallopy/scallopy/function.py index 4b08e17..a5ba65d 100644 --- a/etc/scallopy/scallopy/function.py +++ b/etc/scallopy/scallopy/function.py @@ -1,10 +1,10 @@ from typing import * import inspect - +# Generic type parameters class GenericTypeParameter: """ - A generic type parameter used for Scallop foreign function + A generic type parameter used for Scallop foreign function and predicate """ COUNTER = 0 @@ -31,6 +31,7 @@ def __repr__(self): return f"T{self.id}({self.type_family})" +# Scallop Data Type class Type: def __init__(self, value): if type(value) == GenericTypeParameter: @@ -82,6 +83,9 @@ def __repr__(self): else: raise Exception(f"Unknown parameter kind {self.kind}") + def is_base(self): + return self.kind == "base" + def is_generic(self): return self.kind == "generic" @@ -168,7 +172,7 @@ def foreign_function(func): A decorator to create a Scallop foreign function, for example ``` python - @scallop_function + @scallopy.foreign_function def string_index_of(s1: str, s2: str) -> usize: return s1.index(s2) ``` diff --git a/etc/scallopy/scallopy/input_mapping.py b/etc/scallopy/scallopy/input_mapping.py new file mode 100644 index 0000000..df5aeed --- /dev/null +++ b/etc/scallopy/scallopy/input_mapping.py @@ -0,0 +1,398 @@ +from typing import * +from functools import reduce +import itertools +from copy import deepcopy + +from .torch_importer import * +from .utils import Counter + +class InputMapping: + """ + An input mapping for 
converting tensors into probabilistic symbols. + + Input mappings in Scallop can be constructed by many different ways, + such as through single values, tuples, lists, iterators, and + dictionaries. + """ + + def __init__( + self, + mapping, + disjunctive: bool = False, + disjunctive_dim: Optional[int] = None, + retain_threshold: Optional[float] = None, + retain_k: Optional[int] = None, + sample_dim: Optional[int] = None, + sample_strategy: Optional[Literal["top", "categorical"]] = "top", + supports_disjunctions: bool = False, + ): + """Create a new input mapping""" + + # Initialize the mapping information + ty = type(mapping) + + # First check if we are directly copying from another InputMapping + if ty == InputMapping: + self.__dict__ = deepcopy(mapping.__dict__) + self.supports_disjunctions = supports_disjunctions + return + + # Otherwise, process from beginning + if ty == list or ty == range: + self._initialize_list_mapping(mapping) + elif ty == dict: + self._initialize_dict_mapping(mapping) + elif ty == tuple: + self._initialize_tuple_mapping(mapping) + elif self._is_primitive(mapping): + self._initialize_value_mapping(mapping) + elif mapping == None: + self._kind = None + else: + raise Exception(f"Unknown input mapping type `{type(mapping)}`") + + # Initialize the properties + self.disjunctive = disjunctive or (disjunctive_dim is not None) + self.disjunctive_dim = disjunctive_dim + self.retain_threshold = retain_threshold + self.retain_k = retain_k + self.sample_dim = sample_dim + self.sample_strategy = sample_strategy + self.supports_disjunctions = supports_disjunctions + + # Validate the configurations + if self.disjunctive_dim is not None: + if not (0 <= self.disjunctive_dim < self.dimension): + raise Exception(f"Invalid disjunction dimension {self.disjunctive_dim}; total dimension is {self.dimension}") + if self.sample_dim is not None: + if not (0 <= self.sample_dim < self.dimension): + raise Exception(f"Invalid sampling dimension {self.sample_dim}; total dimension is {self.dimension}") + + def __getitem__(self, index) -> Tuple: + """Get the tuple of the input mapping from an index""" + if self._kind == "dict": + return tuple([self._mapping[i][j] for (i, j) in enumerate(index)]) + else: + return self._mapping[self._mult_dim_index_to_index(index)] + + def all_tuples(self) -> Iterator[Tuple]: + """Iterate over all the tuples in the input mapping""" + if self._kind == "list" or self._kind == "tuple" or self.kind == "value": + for element in self._mapping: + yield element + elif self.kind == "dict": + for element in itertools.product(*self._mapping): + yield element + + def all_indices(self) -> Iterator[Tuple]: + return itertools.product(*[list(range(x)) for x in self.shape]) + + def process_tensor(self, tensor: Tensor, batched=False, mutual_exclusion_counter=None) -> List: + """Process a tensor to produce a list of probabilistic symbols""" + + # Check the kind, if there is none + if self._kind == None: + raise Exception("Cannot apply None mapping to a tensor") + + # Create a new mutual exclusion counter if needed + if self.supports_disjunctions and mutual_exclusion_counter is None: + mutual_exclusion_counter = Counter() + + # Check the shape to decide whether to process a batched input + if tensor.shape == self._shape: + facts = self._process_one_tensor(tensor, mutual_exclusion_counter) + return [facts] if batched else facts + elif tensor.shape[1:] == self._shape: + return [self._process_one_tensor(item, mutual_exclusion_counter) for item in tensor] + else: + raise 
Exception(f"Tensor shape mismatch: expected {self._shape}, got {tensor.shape}") + + def _process_one_tensor(self, tensor: Tensor, mutual_exclusion_counter: Counter) -> List: + inc, exc = InclusiveSet(), ExclusiveSet() + + # Do sampling + import torch + if self.retain_k is not None: + if self.sample_dim is not None: + if self.sample_strategy == "categorical": + raise NotImplementedError() + elif self.sample_strategy == "top": + topk_result = torch.topk(tensor, self.retain_k, dim=self.sample_dim) + for index in self._convert_sampled_indices(topk_result.indices, self.sample_dim): + inc.add(index) + else: + raise Exception(f"Unknown sample strategy `{self.sample_strategy}`") + else: + flat_tensor = torch.flatten(tensor) + if self.sample_strategy == "categorical": + raise NotImplementedError() + elif self.sample_strategy == "top": + topk_result = torch.topk(flat_tensor, self.retain_k) + for index in topk_result.indices: + mult_dim_index = self._index_to_mult_dim_index(int(index)) + inc.add(mult_dim_index) + else: + raise Exception(f"Unknown sample strategy `{self.sample_strategy}`") + + # Do thresholding; if the probability is less than the threshold, add the index to the exclusive map + if self.retain_threshold is not None: + for index in self.all_indices(): + if tensor[index] < self.retain_threshold: + exc.exclude(index) + + # Get a set of filtered indices + filtered_indices = [index for index in self.all_indices() if inc.contains(index) and exc.contains(index)] + + # Add disjunctions + if self.supports_disjunctions: + if self.disjunctive: + if self.disjunctive_dim is not None: + partial_indices = itertools.product(*[range(d) for (i, d) in enumerate(self.shape) if i != self.disjunctive_dim]) + disj_map = {index: mutual_exclusion_counter.get_and_increment() for index in partial_indices} + get_partial_index = lambda index: index[:self.disjunctive_dim] + index[self.disjunctive_dim + 1:] + facts = [((tensor[index], disj_map[get_partial_index(index)]), self[index]) for index in filtered_indices] + else: + disj_id = mutual_exclusion_counter.get_and_increment() + facts = [((tensor[index], disj_id), self[index]) for index in filtered_indices] + else: + facts = [((tensor[index], None), self[index]) for index in filtered_indices] + else: + facts = [(tensor[index], self[index]) for index in filtered_indices] + + # Return facts + return facts + + def _initialize_list_mapping(self, mapping): + """ + Initialize an input mapping using a list + + The list can be nested, in which case it will be treated as a multi-dimensional mapping + """ + + # Need to make sure that a mapping is not empty + if len(mapping) == 0: + raise Exception("Invalid input mapping: a mapping list cannot be empty") + + # To process list of elements + curr_elements = [e for e in mapping] + shape = [len(curr_elements)] + is_singleton = None + while True: + first_element = curr_elements[0] + first_element_ty = type(first_element) + if self._is_primitive_or_tuple(first_element): + tuple_size = None if self._is_primitive(first_element) else len(first_element) + + # The first element is primitive or tuple; that means all elements need to be value or tuple + for (i, element) in enumerate(curr_elements): + if not self._is_primitive_or_tuple(element): + raise Exception(f"Invalid input mapping: expected terminal value/tuple at dimension {len(shape) + 1}, found {type(element)}") + + # Check the consistency of the tuple + if tuple_size == None: + if not self._is_primitive(element): + raise Exception(f"Invalid input mapping: expected singleton 
value, found tuple") + curr_elements[i] = (element,) + else: + if len(element) != tuple_size: + raise Exception(f"Invalid input mapping: expected tuple size {tuple_size}, found {len(element)}") + + # If all checks out, then set is_singleton value + is_singleton = tuple_size is None + + # Hit a good base case + break + + elif first_element_ty == list or first_element_ty == range: + # The first element is a list; first, the length of the list cannot be zero + next_size = len(first_element) + if next_size == 0: + raise Exception(f"Invalid input mapping: having potential empty dimension {len(shape) + 1}") + + # that means all elements need to be list of the same size + for element in curr_elements: + if type(element) != list: + if type(element) == range: + element = list(element) + else: + raise Exception(f"Invalid input mapping: expected a list or range, found {type(element)} at dimension {len(shape) + 1}") + if len(element) != next_size: + raise Exception(f"Invalid input mapping: shape mismatch at dimension {len(shape) + 1}, expected {next_size}, found {len(element)}") + + # Hit a good recursive case + shape.append(next_size) + curr_elements = [grand_child for child in curr_elements for grand_child in child] + + else: + # Not okay + raise Exception(f"Invalid input mapping: encountered element {type(first_element)}") + + # Upon success, set the basic information + self._kind = "list" + self._shape = tuple(shape) + self._is_singleton = is_singleton + self._source = mapping + self._mapping = curr_elements + + def _initialize_dict_mapping(self, mapping: Dict[int, List]): + """Initialize the input mapping with a dictionary""" + + # Mapping cannot be empty + if len(mapping) == 0: + raise Exception("Invalid input mapping: cannot be empty") + + # Check if the keys in the mapping are proper dimension integers + for dim_num in mapping.keys(): + if type(dim_num) != int: + raise Exception(f"Invalid input mapping: invalid dimension {dim_num}") + if dim_num < 0: + raise Exception("Invalid input mapping: cannot have negative dimension number") + + # Check the maximum dimensions and all dimensions are presented + max_dim_num = max(mapping.keys()) + if len(mapping) != max_dim_num + 1: + for i in range(max_dim_num + 1): + if i not in mapping: + raise Exception(f"Invalid input mapping: missing dimension {i}") + + # Get a shape + dimension = max_dim_num + 1 + + # Check that the values of the dictionary are all lists; create the shape list + processed_mapping = [] + shape = [] + for i in range(dimension): + to_check = mapping[i] + if type(to_check) != list: + if type(to_check) == range: + to_check = list(to_check) + else: + raise Exception("Invalid input mapping: value of dictionary must be a list or a range") + if len(to_check) == 0: + raise Exception(f"Invalid input mapping: empty dimension {i}") + for element in to_check: + if not self._is_primitive(element): + raise Exception("Invalid input mapping: element of dictionary value must be a primitive") + shape.append(len(to_check)) + processed_mapping.append(to_check) + + # Success! + self._kind = "dict" + self._shape = tuple(shape) + self._is_singleton = False + self._source = mapping + self._mapping = processed_mapping + + def _initialize_tuple_mapping(self, mapping): + """Initialize the input mapping with a tuple.""" + + # Check if all the elements inside of the tuple is value + for element in mapping: + if not self._is_primitive(element): + raise Exception("Invalid input mapping: elements of tuple must be a primitive") + + # Success! 
+ self._kind = "tuple" + self._shape = tuple() + self._is_singleton = False + self._source = mapping + self._mapping = [mapping] + + def _initialize_value_mapping(self, mapping): + """Initialize the input mapping with a value""" + self._kind = "value" + self._shape = tuple() + self._is_singleton = True + self._source = mapping + self._mapping = [(mapping,)] + + def _is_primitive_or_tuple(self, e) -> bool: + ty = type(e) + if ty == tuple: + for ce in e: + if not self._is_primitive(ce): + return False + return True + else: + return self._is_primitive(e) + + def _mult_dim_index_to_index(self, mult_dim_index): + summed_index = 0 + for i in range(len(self._shape)): + summed_index += mult_dim_index[i] * reduce(lambda acc, i: acc * i, self._shape[i + 1:], 1) + return summed_index + + def _index_to_mult_dim_index(self, index): + acc_index = index + mult_dim_index = [] + for i in range(len(self._shape), 0, -1): + mult_dim_index.append(acc_index % self._shape[i - 1]) + acc_index = acc_index // self._shape[i - 1] + return tuple(reversed(mult_dim_index)) + + def _convert_sampled_indices(self, sampled_indices, sample_dim): + for partial_index in itertools.product(*[range(d) for (i, d) in enumerate(self._shape) if i != sample_dim]): + before, after = partial_index[:sample_dim], partial_index[sample_dim:] + indexing_slice = before + (slice(None),) + after + for sampled_k_index in sampled_indices[indexing_slice]: + yield before + (int(sampled_k_index),) + after + + def _is_primitive(self, e) -> bool: + ty = type(e) + return ty == int or ty == float or ty == str or ty == bool + + def _get_kind(self) -> str: + return self._kind + + def _set_kind(self): + raise Exception("Cannot set kind of an input mapping") + + kind = property(_get_kind, _set_kind) + + def _get_shape(self) -> Tuple[int]: + return self._shape + + def _set_shape(self): + raise Exception("Cannot set shape of an input mapping") + + shape = property(_get_shape, _set_shape) + + def _get_dimension(self) -> int: + return len(self._shape) + + def _set_dimension(self): + raise Exception("Cannot set dimension of an input mapping") + + dimension = property(_get_dimension, _set_dimension) + + def _get_is_singleton(self) -> int: + return self._is_singleton + + def _set_is_singleton(self): + raise Exception("Cannot set is_singleton of an input mapping") + + is_singleton = property(_get_is_singleton, _set_is_singleton) + + +class InclusiveSet: + def __init__(self, init_include_all=True): + self.include_all = init_include_all + self.included_set = set() + + def add(self, index: Tuple): + self.include_all = False + self.included_set.add(index) + + def contains(self, index: Tuple): + return self.include_all or (index in self.included_set) + + +class ExclusiveSet: + def __init__(self): + self.excluded_set = set() + + def exclude(self, index: Tuple): + self.excluded_set.add(index) + + def contains(self, index: Tuple) -> bool: + return index not in self.excluded_set diff --git a/etc/scallopy/scallopy/predicate.py b/etc/scallopy/scallopy/predicate.py new file mode 100644 index 0000000..27e82ed --- /dev/null +++ b/etc/scallopy/scallopy/predicate.py @@ -0,0 +1,174 @@ +from typing import * +import inspect + +# Predicate Data Type +class Type: + def __init__(self, value): + if isinstance(value, ForwardRef): + value = value.__forward_arg__ + if value == float: + self.type = "f32" + elif value == int: + self.type = "i32" + elif value == bool: + self.type = "bool" + elif value == str: + self.type = "String" + elif value == "i8" or value == "i16" or value == "i32" or 
value == "i64" or value == "i128" or value == "isize" or \ + value == "u8" or value == "u16" or value == "u32" or value == "u64" or value == "u128" or value == "usize" or \ + value == "f32" or value == "f64" or \ + value == "bool" or value == "char" or value == "String" or \ + value == "DateTime" or value == "Duration": + self.type = value + else: + raise Exception(f"Unknown scallop predicate type annotation `{value}`") + + def __repr__(self): + return self.type + + +class Generator(Generic[TypeVar("TagType"), TypeVar("TupleType")]): + pass + + +class ForeignPredicate: + """ + Scallop foreign predicate + """ + def __init__( + self, + func: Callable, + name: str, + input_arg_types: List[Type], + output_arg_types: List[Type], + tag_type: Any, + ): + self.func = func + self.name = name + self.input_arg_types = input_arg_types + self.output_arg_types = output_arg_types + self.tag_type = tag_type + + def __repr__(self): + r = f"extern pred {self.name}[{self.pattern()}](" + first = True + for arg in self.input_arg_types: + if first: + first = False + else: + r += ", " + r += f"{arg}" + for arg in self.output_arg_types: + if first: + first = False + else: + r += ", " + r += f"{arg}" + r += ")" + return r + + def __call__(self, *args): + if self.does_output_tag(): + return [f for f in self.func(*args)] + else: + return [(None, f) for f in self.func(*args)] + + def arity(self): + return len(self.input_arg_types) + len(self.output_arg_types) + + def num_bounded(self): + return len(self.input_arg_types) + + def all_argument_types(self): + return self.input_arg_types + self.output_arg_types + + def pattern(self): + return "b" * len(self.input_arg_types) + "f" * len(self.output_arg_types) + + def does_output_tag(self): + return self.tag_type is not None + + +def foreign_predicate(func: Callable): + """ + A decorator to create a Scallop foreign predicate, for example + + ``` python + @scallopy.foreign_function + def string_chars(s: str) -> scallopy.Generator[Tuple[int, char]]: + for (i, c) in enumerate(s): + yield (i, c) + ``` + """ + + # Get the function name + func_name = func.__name__ + + # Get the function signature + signature = inspect.signature(func) + + # Store all the argument types + argument_types = [] + + # Find argument types + for (arg_name, item) in signature.parameters.items(): + optional = item.default != inspect.Parameter.empty + if item.annotation is None: + raise Exception(f"Argument {arg_name} type annotation not provided") + if item.kind == inspect.Parameter.VAR_POSITIONAL: + raise Exception(f"Cannot have variable arguments in foreign predicate") + elif not optional: + ty = Type(item.annotation) + argument_types.append(ty) + else: + raise Exception(f"Cannot have optional argument in foreign predicate") + + # Find return type + if signature.return_annotation is None: + raise Exception(f"Return type annotation not provided") + elif signature.return_annotation.__dict__["__origin__"] != Generator: + raise Exception(f"Return type must be Generator") + else: + args = signature.return_annotation.__dict__["__args__"] + if len(args) != 2: + raise Exception(f"Generator must have 2 type arguments") + + # Produce return tuple type, and check that they are all base type + return_tuple_type = _extract_return_tuple_type(args[0]) + + # Produce return tag type + return_tag_type = _extract_return_tag_type(args[1]) + + # Create the foreign predicate + return ForeignPredicate( + func=func, + name=func_name, + input_arg_types=argument_types, + output_arg_types=return_tuple_type, + 
tag_type=return_tag_type, + ) + + +def _extract_return_tuple_type(tuple_type) -> List[Type]: + # First check if it is a None type (i.e. returning zero-tuple) + if tuple_type == None: + return [] + + # Then try to convert it to a base type + try: + ty = Type(tuple_type) + return [ty] + except: pass + + # If not, it must be a tuple of base types + if "__origin__" in tuple_type.__dict__ and tuple_type.__dict__["__origin__"] == tuple: + if "__args__" in tuple_type.__dict__: + return [Type(t) for t in tuple_type.__dict__["__args__"]] + else: + return [] + else: + raise Exception(f"Return tuple type must be a base type or a tuple of base types") + + +def _extract_return_tag_type(tag_type): + return tag_type diff --git a/etc/scallopy/scallopy/scallopy.pyi b/etc/scallopy/scallopy/scallopy.pyi index 4e1bb89..8064bf8 100644 --- a/etc/scallopy/scallopy/scallopy.pyi +++ b/etc/scallopy/scallopy/scallopy.pyi @@ -21,6 +21,12 @@ class InternalScallopContext: def set_k(self, k: int): ... + def set_early_discard(self, early_discard: bool = True): ... + + def set_iter_limit(self, k: int): ... + + def remove_iter_limit(self): ... + def run(self, iter_limit: Optional[int]) -> None: ... def run_with_debug_tag(self, iter_limit: Optional[int]) -> None: ... diff --git a/etc/scallopy/scallopy/utils.py b/etc/scallopy/scallopy/utils.py index cf0a9fa..0a7bf2f 100644 --- a/etc/scallopy/scallopy/utils.py +++ b/etc/scallopy/scallopy/utils.py @@ -3,3 +3,13 @@ # - Otherwise return the tuple directly def _mapping_tuple(t): return t if type(t) == tuple else (t,) + + +class Counter: + def __init__(self): + self.count = 0 + + def get_and_increment(self) -> int: + result = self.count + self.count += 1 + return result diff --git a/etc/scallopy/src/collection.rs b/etc/scallopy/src/collection.rs index e6ef526..9aa3615 100644 --- a/etc/scallopy/src/collection.rs +++ b/etc/scallopy/src/collection.rs @@ -1,79 +1,73 @@ -use std::sync::Arc; - use pyo3::class::iter::IterNextOutput; use pyo3::prelude::*; use scallop_core::common; use scallop_core::runtime::dynamic::*; use scallop_core::runtime::provenance::*; -use scallop_core::utils::ArcFamily; +use scallop_core::utils::*; use super::custom_tag; use super::provenance::*; use super::tuple::*; #[derive(Clone)] -pub enum CollectionEnum { +pub enum CollectionEnum { Unit { - collection: Arc>, + collection: P::Rc>, }, Proofs { - collection: Arc>, + collection: P::Rc>>, }, MinMaxProb { - collection: Arc>, + collection: P::Rc>, }, AddMultProb { - collection: Arc>, + collection: P::Rc>, }, TopKProofs { - collection: Arc>>, - tags: Arc>, + collection: P::Rc>>, }, TopBottomKClauses { - collection: Arc>>, - tags: Arc>, + collection: P::Rc>>, }, DiffMinMaxProb { - collection: Arc, ArcFamily>>>, - tags: Arc>)>>, + collection: P::Rc, P>>>, + tags: P::RcCell>>, }, DiffAddMultProb { - collection: Arc, ArcFamily>>>, - tags: Arc>>, + collection: P::Rc, P>>>, + tags: P::RcCell>>, }, DiffNandMultProb { - collection: Arc, ArcFamily>>>, - tags: Arc>>, + collection: P::Rc, P>>>, + tags: P::RcCell>>, }, DiffMaxMultProb { - collection: Arc, ArcFamily>>>, - tags: Arc>>, + collection: P::Rc, P>>>, + tags: P::RcCell>>, }, DiffNandMinProb { - collection: Arc, ArcFamily>>>, - tags: Arc>>, + collection: P::Rc, P>>>, + tags: P::RcCell>>, }, DiffSampleKProofs { - collection: Arc, ArcFamily>>>, - tags: Arc)>>, + collection: P::Rc, P>>>, + tags: DiffProbStorage, P>, }, DiffTopKProofs { - collection: Arc, ArcFamily>>>, - tags: Arc)>>, + collection: P::Rc, P>>>, + tags: DiffProbStorage, P>, }, DiffTopKProofsIndiv { - 
collection: - Arc, ArcFamily>>>, - tags: Arc)>>, + collection: P::Rc, P>>>, + tags: DiffProbStorage, P>, }, DiffTopBottomKClauses { - collection: - Arc, ArcFamily>>>, - tags: Arc)>>, + collection: P::Rc, P>>>, + tags: DiffProbStorage, P>, }, Custom { - collection: Arc>, + collection: P::Rc>, }, } @@ -100,24 +94,24 @@ macro_rules! match_collection { }; } -impl CollectionEnum { +impl CollectionEnum { pub fn num_input_facts(&self) -> Option { match self { Self::Unit { .. } => None, Self::Proofs { .. } => None, Self::MinMaxProb { .. } => None, Self::AddMultProb { .. } => None, - Self::TopKProofs { tags, .. } => Some(tags.len()), - Self::TopBottomKClauses { tags, .. } => Some(tags.len()), - Self::DiffMinMaxProb { tags, .. } => Some(tags.len()), - Self::DiffAddMultProb { tags, .. } => Some(tags.len()), - Self::DiffNandMultProb { tags, .. } => Some(tags.len()), - Self::DiffMaxMultProb { tags, .. } => Some(tags.len()), - Self::DiffNandMinProb { tags, .. } => Some(tags.len()), - Self::DiffSampleKProofs { tags, .. } => Some(tags.len()), - Self::DiffTopKProofs { tags, .. } => Some(tags.len()), - Self::DiffTopKProofsIndiv { tags, .. } => Some(tags.len()), - Self::DiffTopBottomKClauses { tags, .. } => Some(tags.len()), + Self::TopKProofs { .. } => None, + Self::TopBottomKClauses { .. } => None, + Self::DiffMinMaxProb { tags, .. } => Some(ArcFamily::get_rc_cell(tags, |t| t.len())), + Self::DiffAddMultProb { tags, .. } => Some(ArcFamily::get_rc_cell(tags, |t| t.len())), + Self::DiffNandMultProb { tags, .. } => Some(ArcFamily::get_rc_cell(tags, |t| t.len())), + Self::DiffMaxMultProb { tags, .. } => Some(ArcFamily::get_rc_cell(tags, |t| t.len())), + Self::DiffNandMinProb { tags, .. } => Some(ArcFamily::get_rc_cell(tags, |t| t.len())), + Self::DiffSampleKProofs { tags, .. } => Some(tags.num_input_tags()), + Self::DiffTopKProofs { tags, .. } => Some(tags.num_input_tags()), + Self::DiffTopKProofsIndiv { tags, .. } => Some(tags.num_input_tags()), + Self::DiffTopBottomKClauses { tags, .. } => Some(tags.num_input_tags()), Self::Custom { .. } => None, } } @@ -131,14 +125,14 @@ impl CollectionEnum { Self::TopKProofs { .. } => None, Self::TopBottomKClauses { .. } => None, Self::DiffMinMaxProb { .. } => None, - Self::DiffAddMultProb { tags, .. } => Some(tags.iter().cloned().collect()), - Self::DiffNandMultProb { tags, .. } => Some(tags.iter().cloned().collect()), - Self::DiffMaxMultProb { tags, .. } => Some(tags.iter().cloned().collect()), - Self::DiffNandMinProb { tags, .. } => Some(tags.iter().cloned().collect()), - Self::DiffSampleKProofs { tags, .. } => Some(tags.iter().map(|(_, t)| t.clone()).collect()), - Self::DiffTopKProofs { tags, .. } => Some(tags.iter().map(|(_, t)| t.clone()).collect()), - Self::DiffTopKProofsIndiv { tags, .. } => Some(tags.iter().map(|(_, t)| t.clone()).collect()), - Self::DiffTopBottomKClauses { tags, .. } => Some(tags.iter().map(|(_, t)| t.clone()).collect()), + Self::DiffAddMultProb { tags, .. } => Some(ArcFamily::get_rc_cell(tags, |t| t.clone())), + Self::DiffNandMultProb { tags, .. } => Some(ArcFamily::clone_rc_cell_internal(tags)), + Self::DiffMaxMultProb { tags, .. } => Some(ArcFamily::clone_rc_cell_internal(tags)), + Self::DiffNandMinProb { tags, .. } => Some(ArcFamily::clone_rc_cell_internal(tags)), + Self::DiffSampleKProofs { tags, .. } => Some(tags.input_tags()), + Self::DiffTopKProofs { tags, .. } => Some(tags.input_tags()), + Self::DiffTopKProofsIndiv { tags, .. } => Some(tags.input_tags()), + Self::DiffTopBottomKClauses { tags, .. 
} => Some(tags.input_tags()), Self::Custom { .. } => None, } } @@ -160,20 +154,20 @@ impl CollectionEnum { } fn ith_tag(&self, i: usize) -> Py { - fn ith_tag_helper(c: &Arc>, i: usize) -> Py + fn ith_tag_helper(c: &DynamicOutputCollection, i: usize) -> Py where Prov: PythonProvenance, { Prov::to_output_py_tag(c.ith_tag(i).unwrap()) } - match_collection!(self, c, ith_tag_helper(c, i)) + match_collection!(self, c, ith_tag_helper(ArcFamily::get_rc(c), i)) } } #[pyclass(unsendable, name = "InternalScallopCollection")] pub struct Collection { - pub collection: CollectionEnum, + pub collection: CollectionEnum, } #[pymethods] @@ -194,15 +188,15 @@ impl Collection { } } -impl From for Collection { - fn from(collection: CollectionEnum) -> Self { +impl From> for Collection { + fn from(collection: CollectionEnum) -> Self { Self { collection } } } #[pyclass(unsendable, name = "InternalScallopCollectionIterator")] pub struct CollectionIterator { - collection: CollectionEnum, + collection: CollectionEnum, current_index: usize, } diff --git a/etc/scallopy/src/context.rs b/etc/scallopy/src/context.rs index af92e49..c73abd8 100644 --- a/etc/scallopy/src/context.rs +++ b/etc/scallopy/src/context.rs @@ -18,6 +18,7 @@ use crate::custom_tag; use super::collection::*; use super::error::*; use super::foreign_function::*; +use super::foreign_predicate::*; use super::io::*; use super::provenance::*; use super::tuple::*; @@ -27,7 +28,7 @@ type AF = ArcFamily; #[derive(Clone)] pub enum ContextEnum { Unit(IntegrateContext), - Proofs(IntegrateContext), + Proofs(IntegrateContext, AF>), MinMaxProb(IntegrateContext), AddMultProb(IntegrateContext), TopKProofs(IntegrateContext, AF>), @@ -40,9 +41,7 @@ pub enum ContextEnum { DiffSampleKProofs(IntegrateContext, AF>, AF>), DiffTopKProofs(IntegrateContext, AF>, AF>), DiffTopKProofsIndiv(IntegrateContext, AF>, AF>), - DiffTopBottomKClauses( - IntegrateContext, AF>, AF>, - ), + DiffTopBottomKClauses(IntegrateContext, AF>, AF>), Custom(IntegrateContext), } @@ -182,6 +181,21 @@ impl Context { match_context!(&mut self.ctx, c, c.set_non_incremental()) } + /// Set early discard + fn set_early_discard(&mut self, early_discard: bool) { + match_context!(&mut self.ctx, c, c.set_early_discard(early_discard)) + } + + /// Set the iteration limit to be `k` + fn set_iter_limit(&mut self, k: usize) { + match_context!(&mut self.ctx, c, c.set_iter_limit(k)) + } + + /// Remove the iteration limit + fn remove_iter_limit(&mut self) { + match_context!(&mut self.ctx, c, c.remove_iter_limit()) + } + /// Compile the surface program stored in the scallopy context into the ram program. /// /// This function is usually used before creating a forward function. @@ -370,6 +384,13 @@ impl Context { Ok(()) } + /// Register a foreign predicate + fn register_foreign_predicate(&mut self, f: PyObject) -> Result<(), BindingError> { + let fp = PythonForeignPredicate::new(f); + match_context!(&mut self.ctx, c, c.register_foreign_predicate(fp)?); + Ok(()) + } + /// Execute the program /// /// If the context's ram program is already compiled, the program will be directly executed. 
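The hunk above exposes the new runtime knobs (`set_early_discard`, `set_iter_limit`, `remove_iter_limit`) and `register_foreign_predicate` on the Rust `Context` binding; they are reached from Python through `ScallopContext`. A minimal sketch of that Python-side usage, assuming the default provenance and an illustrative `edge`/`path` program (neither is part of this hunk):

```python
import scallopy

ctx = scallopy.ScallopContext()
ctx.set_iter_limit(2)         # forwarded to InternalScallopContext.set_iter_limit
ctx.set_early_discard(False)  # forwarded to InternalScallopContext.set_early_discard

ctx.add_relation("edge", (int, int))
ctx.add_facts("edge", [(0, 1), (1, 2), (2, 3)])
ctx.add_rule("path(a, c) = edge(a, c) or path(a, b) and edge(b, c)")
ctx.run()                     # stops early because of the two-iteration limit
print(list(ctx.relation("path")))

ctx.set_iter_limit(None)      # maps to remove_iter_limit on the Rust side
```

The same two settings also surface as the `iter_limit` and `early_discard` keyword arguments of `ScallopForwardFunction` in the `scallopy/forward.py` changes.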
@@ -651,14 +672,14 @@ fn is_all_equal + Clone>(i: T) -> bool { i.clone().min() == i.max() } -fn get_output_collection(c: &mut IntegrateContext, r: &str) -> Result +fn get_output_collection(c: &mut IntegrateContext, r: &str) -> Result, BindingError> where C: PythonProvenance, { if c.has_relation(r) { if let Some(collection) = c.computed_relation(r) { Ok(C::to_collection_enum( - ArcFamily::clone_ptr(&collection), + ArcFamily::clone_rc(&collection), c.provenance_context(), )) } else { @@ -673,7 +694,7 @@ fn get_output_collection_monitor( c: &mut IntegrateContext, m: &M, r: &str, -) -> Result +) -> Result, BindingError> where C: PythonProvenance, M: monitor::Monitor, @@ -681,7 +702,7 @@ where if c.has_relation(r) { if let Some(collection) = c.computed_relation_with_monitor(r, m) { Ok(C::to_collection_enum( - ArcFamily::clone_ptr(&collection), + ArcFamily::clone_rc(&collection), c.provenance_context(), )) } else { diff --git a/etc/scallopy/src/custom_tag.rs b/etc/scallopy/src/custom_tag.rs index 976008d..6f362e9 100644 --- a/etc/scallopy/src/custom_tag.rs +++ b/etc/scallopy/src/custom_tag.rs @@ -32,7 +32,7 @@ impl provenance::Provenance for CustomProvenance { "scallopy-custom" } - fn tagging_fn(&mut self, i: Self::InputTag) -> Self::Tag { + fn tagging_fn(&self, i: Self::InputTag) -> Self::Tag { Python::with_gil(|py| { let result = self.0.call_method(py, "tagging_fn", (i,), None).unwrap(); Self::Tag::new(result) diff --git a/etc/scallopy/src/error.rs b/etc/scallopy/src/error.rs index 64f875c..6f0c02f 100644 --- a/etc/scallopy/src/error.rs +++ b/etc/scallopy/src/error.rs @@ -14,6 +14,7 @@ pub enum BindingError { InvalidLoadCSVArg, InvalidBatchSize, EmptyBatchInput, + InvalidInputTag, PyErr(PyErr), } @@ -31,6 +32,7 @@ impl std::fmt::Display for BindingError { Self::InvalidLoadCSVArg => f.write_str("Invalid argument for `load_csv`"), Self::InvalidBatchSize => f.write_str("Invalid batch size"), Self::EmptyBatchInput => f.write_str("Empty batched input"), + Self::InvalidInputTag => f.write_str("Invalid input tag"), Self::PyErr(e) => std::fmt::Display::fmt(e, f), } } diff --git a/etc/scallopy/src/foreign_predicate.rs b/etc/scallopy/src/foreign_predicate.rs new file mode 100644 index 0000000..bc9658a --- /dev/null +++ b/etc/scallopy/src/foreign_predicate.rs @@ -0,0 +1,148 @@ +use pyo3::types::*; +use pyo3::*; + +use scallop_core::common::tuple_type::*; +use scallop_core::common::value_type::*; +use scallop_core::common::value::*; +use scallop_core::common::input_tag::*; +use scallop_core::common::foreign_predicate::*; + +use super::tuple::*; +use super::tag::*; + +#[derive(Clone)] +pub struct PythonForeignPredicate { + fp: PyObject, + name: String, + types: Vec, + num_bounded: usize, +} + +impl PythonForeignPredicate { + pub fn new(fp: PyObject) -> Self { + let name = Python::with_gil(|py| { + fp + .getattr(py, "name") + .expect("Cannot get foreign predicate name") + .extract(py) + .expect("Foreign predicate name cannot be extracted into String") + }); + + let types = Python::with_gil(|py| { + // Call `all_argument_types` function of the Python object + let func: PyObject = fp + .getattr(py, "all_argument_types") + .expect("Cannot get all_argument_types function") + .extract(py) + .expect("Cannot extract function into PyObject"); + + // Invoke the function + let py_types: Vec = func.call0(py).expect("Cannot call function").extract(py).expect("Cannot extract into PyList"); + + // Convert the Python types into Scallop types + py_types.into_iter().map(|py_type| py_param_type_to_fp_param_type(py_type, 
py)).collect() + }); + + let num_bounded: usize = Python::with_gil(|py| { + let func: PyObject = fp + .getattr(py, "num_bounded") + .expect("Cannot get num_bounded function") + .extract(py) + .expect("Cannot extract function into PyObject"); + + // Invoke the function + func.call0(py).expect("Cannot call function").extract(py).expect("Cannot extract into usize") + }); + + Self { + fp, + name, + types, + num_bounded, + } + } + + fn output_tuple_type(&self) -> TupleType { + self.types.iter().skip(self.num_bounded).cloned().collect() + } +} + +impl ForeignPredicate for PythonForeignPredicate { + fn name(&self) -> String { + self.name.clone() + } + + fn arity(&self) -> usize { + self.types.len() + } + + fn argument_type(&self, i: usize) -> ValueType { + self.types[i].clone() + } + + fn num_bounded(&self) -> usize { + self.num_bounded + } + + fn evaluate(&self, bounded: &[Value]) -> Vec<(DynamicInputTag, Vec)> { + Python::with_gil(|py| { + // Construct the arguments + let args: Vec> = bounded.iter().map(to_python_value).collect(); + let args_tuple = PyTuple::new(py, args); + + // Invoke the function + let maybe_result = self.fp.call1(py, args_tuple).ok(); + + // Turn the result back to Scallop values + if let Some(result) = maybe_result { + let output_tuple_type = self.output_tuple_type(); + let elements: Vec<(&PyAny, &PyAny)> = result.extract(py).expect("Cannot extract into list of elements"); + let internal: Option> = elements + .into_iter() + .map(|(py_tag, py_tup)| { + let tag = from_python_input_tag(py_tag).ok()?; + let tuple = from_python_tuple(py_tup, &output_tuple_type).ok()?; + Some((tag, tuple.as_values())) + }) + .collect(); + if let Some(e) = internal { + e + } else { + vec![] + } + } else { + vec![] + } + }) + } +} + +fn py_param_type_to_fp_param_type(obj: PyObject, py: Python<'_>) -> ValueType { + let param_type: String = obj + .getattr(py, "type") + .expect("Cannot get param type") + .extract(py) + .expect("Cannot extract into String"); + match param_type.as_str() { + "i8" => ValueType::I8, + "i16" => ValueType::I16, + "i32" => ValueType::I32, + "i64" => ValueType::I64, + "i128" => ValueType::I128, + "isize" => ValueType::ISize, + "u8" => ValueType::U8, + "u16" => ValueType::U16, + "u32" => ValueType::U32, + "u64" => ValueType::U64, + "u128" => ValueType::U128, + "usize" => ValueType::USize, + "f32" => ValueType::F32, + "f64" => ValueType::F64, + "bool" => ValueType::Bool, + "char" => ValueType::Char, + "String" => ValueType::String, + "DateTime" => ValueType::DateTime, + "Duration" => ValueType::Duration, + _ => panic!("Unknown type {}", param_type), + } +} diff --git a/etc/scallopy/src/lib.rs b/etc/scallopy/src/lib.rs index d56f113..1deecc0 100644 --- a/etc/scallopy/src/lib.rs +++ b/etc/scallopy/src/lib.rs @@ -3,8 +3,10 @@ mod context; mod custom_tag; mod error; mod foreign_function; +mod foreign_predicate; mod io; mod provenance; +mod tag; mod tuple; use pyo3::prelude::*; diff --git a/etc/scallopy/src/provenance.rs b/etc/scallopy/src/provenance.rs index 596bee2..26416f5 100644 --- a/etc/scallopy/src/provenance.rs +++ b/etc/scallopy/src/provenance.rs @@ -50,7 +50,7 @@ pub trait PythonProvenance: Provenance { fn process_py_tag(tag: &PyAny) -> PyResult>; /// Convert an output collection into a python collection enum - fn to_collection_enum(col: Arc>, ctx: &Self) -> CollectionEnum; + fn to_collection_enum(col: Arc>, ctx: &Self) -> CollectionEnum; /// Convert an output tag into a python object fn to_output_py_tag(tag: &Self::OutputTag) -> Py; @@ -65,7 +65,7 @@ impl PythonProvenance 
for unit::UnitProvenance { Ok(None) } - fn to_collection_enum(col: Arc>, _: &Self) -> CollectionEnum { + fn to_collection_enum(col: Arc>, _: &Self) -> CollectionEnum { CollectionEnum::Unit { collection: col } } @@ -74,13 +74,13 @@ impl PythonProvenance for unit::UnitProvenance { } } -impl PythonProvenance for proofs::ProofsProvenance { +impl PythonProvenance for proofs::ProofsProvenance { fn process_py_tag(disj_id: &PyAny) -> PyResult> { let disj_id: usize = disj_id.extract()?; Ok(Some(proofs::ProofsInputTag::Exclusive(disj_id))) } - fn to_collection_enum(col: Arc>, _: &Self) -> CollectionEnum { + fn to_collection_enum(col: Arc>, _: &Self) -> CollectionEnum { CollectionEnum::Proofs { collection: col } } @@ -101,7 +101,7 @@ impl PythonProvenance for min_max_prob::MinMaxProbProvenance { tag.extract().map(Some) } - fn to_collection_enum(col: Arc>, _: &Self) -> CollectionEnum { + fn to_collection_enum(col: Arc>, _: &Self) -> CollectionEnum { CollectionEnum::MinMaxProb { collection: col } } @@ -115,7 +115,7 @@ impl PythonProvenance for add_mult_prob::AddMultProbProvenance { tag.extract().map(Some) } - fn to_collection_enum(col: Arc>, _: &Self) -> CollectionEnum { + fn to_collection_enum(col: Arc>, _: &Self) -> CollectionEnum { CollectionEnum::AddMultProb { collection: col } } @@ -134,10 +134,9 @@ impl PythonProvenance for top_k_proofs::TopKProofsProvenance { } } - fn to_collection_enum(col: Arc>, ctx: &Self) -> CollectionEnum { + fn to_collection_enum(col: Arc>, _: &Self) -> CollectionEnum { CollectionEnum::TopKProofs { collection: col.clone(), - tags: ctx.probs.clone(), } } @@ -156,10 +155,9 @@ impl PythonProvenance for top_bottom_k_clauses::TopBottomKClausesProvenance>, ctx: &Self) -> CollectionEnum { + fn to_collection_enum(col: Arc>, _: &Self) -> CollectionEnum { CollectionEnum::TopBottomKClauses { collection: col.clone(), - tags: ctx.probs.clone(), } } @@ -172,13 +170,13 @@ impl PythonProvenance for diff_min_max_prob::DiffMinMaxProbProvenance, fn process_py_tag(tag: &PyAny) -> PyResult> { let prob: f64 = tag.extract()?; let tag: Py = tag.into(); - Ok(Some((prob, tag).into())) + Ok(Some((prob, Some(tag)).into())) } - fn to_collection_enum(col: Arc>, ctx: &Self) -> CollectionEnum { + fn to_collection_enum(col: Arc>, ctx: &Self) -> CollectionEnum { CollectionEnum::DiffMinMaxProb { collection: col, - tags: ctx.diff_probs.clone(), + tags: ctx.storage.clone(), } } @@ -195,10 +193,10 @@ impl PythonProvenance for diff_add_mult_prob::DiffAddMultProbProvenance PyResult> { let prob: f64 = tag.extract()?; let tag: Py = tag.into(); - Ok(Some((prob, tag).into())) + Ok(Some((prob, Some(tag)).into())) } - fn to_collection_enum(col: Arc>, ctx: &Self) -> CollectionEnum { + fn to_collection_enum(col: Arc>, ctx: &Self) -> CollectionEnum { CollectionEnum::DiffAddMultProb { collection: col, tags: ctx.storage.clone(), @@ -214,10 +212,10 @@ impl PythonProvenance for diff_nand_mult_prob::DiffNandMultProbProvenance PyResult> { let prob: f64 = tag.extract()?; let tag: Py = tag.into(); - Ok(Some((prob, tag).into())) + Ok(Some((prob, Some(tag)).into())) } - fn to_collection_enum(col: Arc>, ctx: &Self) -> CollectionEnum { + fn to_collection_enum(col: Arc>, ctx: &Self) -> CollectionEnum { CollectionEnum::DiffNandMultProb { collection: col, tags: ctx.storage.clone(), @@ -233,10 +231,10 @@ impl PythonProvenance for diff_max_mult_prob::DiffMaxMultProbProvenance PyResult> { let prob: f64 = tag.extract()?; let tag: Py = tag.into(); - Ok(Some((prob, tag).into())) + Ok(Some((prob, Some(tag)).into())) } - fn 
to_collection_enum(col: Arc>, ctx: &Self) -> CollectionEnum { + fn to_collection_enum(col: Arc>, ctx: &Self) -> CollectionEnum { CollectionEnum::DiffMaxMultProb { collection: col, tags: ctx.storage.clone(), @@ -252,10 +250,10 @@ impl PythonProvenance for diff_nand_min_prob::DiffNandMinProbProvenance PyResult> { let prob: f64 = tag.extract()?; let tag: Py = tag.into(); - Ok(Some((prob, tag).into())) + Ok(Some((prob, Some(tag)).into())) } - fn to_collection_enum(col: Arc>, ctx: &Self) -> CollectionEnum { + fn to_collection_enum(col: Arc>, ctx: &Self) -> CollectionEnum { CollectionEnum::DiffNandMinProb { collection: col, tags: ctx.storage.clone(), @@ -278,10 +276,10 @@ impl PythonProvenance for diff_sample_k_proofs::DiffSampleKProofsProvenance>, ctx: &Self) -> CollectionEnum { + fn to_collection_enum(col: Arc>, ctx: &Self) -> CollectionEnum { CollectionEnum::DiffSampleKProofs { collection: col, - tags: ctx.diff_probs.clone(), + tags: ctx.storage.clone_rc(), } } @@ -301,10 +299,10 @@ impl PythonProvenance for diff_top_k_proofs::DiffTopKProofsProvenance, } } - fn to_collection_enum(col: Arc>, ctx: &Self) -> CollectionEnum { + fn to_collection_enum(col: Arc>, ctx: &Self) -> CollectionEnum { CollectionEnum::DiffTopKProofs { collection: col, - tags: ctx.diff_probs.clone(), + tags: ctx.storage.clone_rc(), } } @@ -324,10 +322,10 @@ impl PythonProvenance for diff_top_k_proofs_indiv::DiffTopKProofsIndivProvenance } } - fn to_collection_enum(col: Arc>, ctx: &Self) -> CollectionEnum { + fn to_collection_enum(col: Arc>, ctx: &Self) -> CollectionEnum { CollectionEnum::DiffTopKProofsIndiv { collection: col, - tags: ctx.diff_probs.clone(), + tags: ctx.storage.clone_rc(), } } @@ -347,10 +345,10 @@ impl PythonProvenance for diff_top_bottom_k_clauses::DiffTopBottomKClausesProven } } - fn to_collection_enum(col: Arc>, ctx: &Self) -> CollectionEnum { + fn to_collection_enum(col: Arc>, ctx: &Self) -> CollectionEnum { CollectionEnum::DiffTopBottomKClauses { collection: col, - tags: ctx.diff_probs.clone(), + tags: ctx.storage.clone_rc(), } } @@ -365,7 +363,7 @@ impl PythonProvenance for custom_tag::CustomProvenance { Ok(Some(tag)) } - fn to_collection_enum(col: Arc>, _: &Self) -> CollectionEnum { + fn to_collection_enum(col: Arc>, _: &Self) -> CollectionEnum { CollectionEnum::Custom { collection: col } } diff --git a/etc/scallopy/src/tag.rs b/etc/scallopy/src/tag.rs new file mode 100644 index 0000000..c15c804 --- /dev/null +++ b/etc/scallopy/src/tag.rs @@ -0,0 +1,36 @@ +use pyo3::types::*; +use pyo3::*; + +use scallop_core::common::input_tag::*; + +use super::error::*; + +#[derive(FromPyObject)] +enum PythonInputTag<'a> { + /// Boolean tag + Bool(bool), + + /// Exclusion id tag + Exclusive(usize), + + /// Float tag + Float(f64), + + /// Tuple of (f64, usize) where `usize` is the exclusion id + ExclusiveFloat(f64, usize), + + /// Catch all tag + #[pyo3(transparent)] + CatchAll(&'a PyAny), +} + +pub fn from_python_input_tag(tag: &PyAny) -> Result { + let py_input_tag: PythonInputTag = tag.extract()?; + match py_input_tag { + PythonInputTag::Bool(b) => Ok(DynamicInputTag::Bool(b)), + PythonInputTag::Exclusive(e) => Ok(DynamicInputTag::Exclusive(e)), + PythonInputTag::Float(f) => Ok(DynamicInputTag::Float(f)), + PythonInputTag::ExclusiveFloat(f, e) => Ok(DynamicInputTag::ExclusiveFloat(f, e)), + PythonInputTag::CatchAll(_) => Err(BindingError::InvalidInputTag), + } +} diff --git a/etc/scallopy/src/tuple.rs b/etc/scallopy/src/tuple.rs index 1c936c4..8cd205c 100644 --- a/etc/scallopy/src/tuple.rs +++ 
b/etc/scallopy/src/tuple.rs @@ -1,12 +1,13 @@ // use std::rc::Rc; -use pyo3::exceptions::PyIndexError; +use pyo3::exceptions::{PyIndexError, PyTypeError}; use pyo3::{prelude::*, types::PyTuple}; use scallop_core::common::tuple::Tuple; use scallop_core::common::tuple_type::TupleType; use scallop_core::common::value::Value; use scallop_core::common::value_type::ValueType; +use scallop_core::utils; pub fn from_python_tuple(v: &PyAny, ty: &TupleType) -> PyResult { match ty { @@ -61,6 +62,8 @@ pub fn to_python_value(val: &Value) -> Py { Str(s) => Python::with_gil(|py| s.to_object(py)), String(s) => Python::with_gil(|py| s.to_object(py)), // RcString(s) => Python::with_gil(|py| s.to_object(py)), + DateTime(d) => Python::with_gil(|py| d.to_string().to_object(py)), + Duration(d) => Python::with_gil(|py| d.to_string().to_object(py)), } } @@ -87,5 +90,13 @@ pub fn from_python_value(v: &PyAny, ty: &ValueType) -> PyResult { // ValueType::RcString => Ok(Tuple::Value(Value::RcString(Rc::new( // v.extract::()?, // )))), + ValueType::DateTime => { + let dt = utils::parse_date_time_string(v.extract()?).ok_or(PyTypeError::new_err("Cannot parse into DateTime"))?; + Ok(Value::DateTime(dt)) + } + ValueType::Duration => { + let dt = utils::parse_duration_string(v.extract()?).ok_or(PyTypeError::new_err("Cannot parse into Duration"))?; + Ok(Value::Duration(dt)) + } } } diff --git a/etc/scallopy/tests/basics.py b/etc/scallopy/tests/basics.py index 6b97e64..f833353 100644 --- a/etc/scallopy/tests/basics.py +++ b/etc/scallopy/tests/basics.py @@ -92,17 +92,6 @@ def test_top_k_proofs_disjunction(self): ctx.run() self.assertEqual(list(ctx.relation("result")), [(1.0, (False,))]) - def test_edge_path_prob_with_sugar(self): - from scallopy.sugar import Relation - ctx = scallopy.ScallopContext(provenance="minmaxprob") - edge = Relation(ctx, (int, int)) - path = Relation(ctx, (int, int)) - edge |= [(0.5, (0, 1)), (0.5, (1, 2))] - path["a", "c"] |= edge["a", "c"] - path["a", "c"] |= edge["a", "b"] & path["b", "c"] - ctx.run() - self.assertEqual(list(path), [(0.5, (0, 1)), (0.5, (0, 2)), (0.5, (1, 2))]) - if __name__ == "__main__": unittest.main() diff --git a/etc/scallopy/tests/configurations.py b/etc/scallopy/tests/configurations.py new file mode 100644 index 0000000..b2b7be0 --- /dev/null +++ b/etc/scallopy/tests/configurations.py @@ -0,0 +1,23 @@ +import unittest + +import scallopy + +class ConfigurationTests(unittest.TestCase): + def test_iter_limit_1(self): + ctx = scallopy.ScallopContext() + ctx.set_iter_limit(2) + ctx.add_relation("edge", (int, int)) + ctx.add_facts("edge", [(0, 1), (1, 2), (2, 3), (3, 4)]) + ctx.add_rule("path(a, c) = edge(a, c) or path(a, b) and edge(b, c)") + ctx.run() + assert list(ctx.relation("path")) == [(0, 1), (0, 2), (1, 2), (1, 3), (2, 3), (2, 4), (3, 4)] + + def test_iter_limit_2(self): + ctx = scallopy.ScallopContext() + ctx.set_iter_limit(1) + ctx.set_iter_limit(None) + ctx.add_relation("edge", (int, int)) + ctx.add_facts("edge", [(0, 1), (1, 2), (2, 3)]) + ctx.add_rule("path(a, c) = edge(a, c) or path(a, b) and edge(b, c)") + ctx.run() + assert list(ctx.relation("path")) == [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)] diff --git a/etc/scallopy/tests/forward.py b/etc/scallopy/tests/forward.py index 110f9ed..5b1f859 100644 --- a/etc/scallopy/tests/forward.py +++ b/etc/scallopy/tests/forward.py @@ -28,8 +28,8 @@ def test_top_k_sample(self): class TestDigitForward(unittest.TestCase): def setUp(self): self.ctx = scallopy.ScallopContext(provenance="diffminmaxprob") - 
self.ctx.add_relation("digit_1", int, list(range(10))) - self.ctx.add_relation("digit_2", int, list(range(10))) + self.ctx.add_relation("digit_1", int, range(10)) + self.ctx.add_relation("digit_2", int, range(10)) self.ctx.add_rule("sum_2(a + b) = digit_1(a) and digit_2(b)") self.ctx.add_rule("mult_2(a * b) = digit_1(a) and digit_2(b)") @@ -39,39 +39,39 @@ def test_unknown_relation_1(self): def test_normal(self): forward = self.ctx.forward_function("sum_2", list(range(19))) - digit_1 = torch.randn((16, 10)) - digit_2 = torch.randn((16, 10)) + digit_1 = torch.softmax(torch.randn((16, 10)), dim=1) + digit_2 = torch.softmax(torch.randn((16, 10)), dim=1) sum_2 = forward(digit_1=digit_1, digit_2=digit_2) self.assertEqual(sum_2.shape, (16, 19)) def test_no_output_mapping(self): forward = self.ctx.forward_function("sum_2") - digit_1 = torch.randn((16, 10)) - digit_2 = torch.randn((16, 10)) + digit_1 = torch.softmax(torch.randn((16, 10)), dim=1) + digit_2 = torch.softmax(torch.randn((16, 10)), dim=1) (result_mapping, result_tensor) = forward(digit_1=digit_1, digit_2=digit_2) self.assertEqual(set(result_mapping), set([(i,) for i in range(19)])) self.assertEqual(result_tensor.shape, (16, 19)) def test_single_dispatch(self): forward = self.ctx.forward_function("sum_2", dispatch="single") - digit_1 = torch.randn((16, 10)) - digit_2 = torch.randn((16, 10)) + digit_1 = torch.softmax(torch.randn((16, 10)), dim=1) + digit_2 = torch.softmax(torch.randn((16, 10)), dim=1) (result_mapping, result_tensor) = forward(digit_1=digit_1, digit_2=digit_2) self.assertEqual(set(result_mapping), set([(i,) for i in range(19)])) self.assertEqual(result_tensor.shape, (16, 19)) def test_serial_dispatch(self): forward = self.ctx.forward_function("sum_2", dispatch="serial") - digit_1 = torch.randn((16, 10)) - digit_2 = torch.randn((16, 10)) + digit_1 = torch.softmax(torch.randn((16, 10)), dim=1) + digit_2 = torch.softmax(torch.randn((16, 10)), dim=1) (result_mapping, result_tensor) = forward(digit_1=digit_1, digit_2=digit_2) self.assertEqual(set(result_mapping), set([(i,) for i in range(19)])) self.assertEqual(result_tensor.shape, (16, 19)) def test_multi_result(self): forward = self.ctx.forward_function(output_mappings={"sum_2": list(range(19)), "mult_2": list(range(100))}) - digit_1 = torch.randn((16, 10)) - digit_2 = torch.randn((16, 10)) + digit_1 = torch.softmax(torch.randn((16, 10)), dim=1) + digit_2 = torch.softmax(torch.randn((16, 10)), dim=1) result = forward(digit_1=digit_1, digit_2=digit_2) sum_2 = result["sum_2"] mult_2 = result["mult_2"] @@ -80,8 +80,8 @@ def test_multi_result(self): def test_multi_result_single_dispatch(self): forward = self.ctx.forward_function(output_mappings={"sum_2": list(range(19)), "mult_2": list(range(100))}, dispatch="single") - digit_1 = torch.randn((16, 10)) - digit_2 = torch.randn((16, 10)) + digit_1 = torch.softmax(torch.randn((16, 10)), dim=1) + digit_2 = torch.softmax(torch.randn((16, 10)), dim=1) result = forward(digit_1=digit_1, digit_2=digit_2) sum_2 = result["sum_2"] mult_2 = result["mult_2"] @@ -90,8 +90,8 @@ def test_multi_result_single_dispatch(self): def test_multi_result_non_parallel_dispatch(self): forward = self.ctx.forward_function(output_mappings={"sum_2": list(range(19)), "mult_2": list(range(100))}, dispatch="serial") - digit_1 = torch.randn((16, 10)) - digit_2 = torch.randn((16, 10)) + digit_1 = torch.softmax(torch.randn((16, 10)), dim=1) + digit_2 = torch.softmax(torch.randn((16, 10)), dim=1) result = forward(digit_1=digit_1, digit_2=digit_2) sum_2 = 
result["sum_2"] mult_2 = result["mult_2"] @@ -100,8 +100,8 @@ def test_multi_result_non_parallel_dispatch(self): def test_multi_result_maybe_with_output_mapping(self): forward = self.ctx.forward_function(output_mappings={"sum_2": None, "mult_2": list(range(100))}) - digit_1 = torch.randn((16, 10)) - digit_2 = torch.randn((16, 10)) + digit_1 = torch.softmax(torch.randn((16, 10)), dim=1) + digit_2 = torch.softmax(torch.randn((16, 10)), dim=1) result = forward(digit_1=digit_1, digit_2=digit_2) (sum_2_mapping, sum_2_tensor) = result["sum_2"] mult_2 = result["mult_2"] @@ -182,6 +182,28 @@ def test_forward(self): result = compute_sum_2(digit_a=digit_a, digit_b=digit_b) self.assertEqual(result.shape, (16, 19)) + def test_forward_with_probabilities(self): + def process_input(digit_tensor): + r = [] + (batch_size, _) = digit_tensor.shape + for task_id in range(batch_size): + r.append([(p, (i,)) for (i, p) in enumerate(digit_tensor[task_id])]) + return r + + sum_2_program = """ + type digit_a(usize), digit_b(usize) + rel sum_2(a + b) = digit_a(a), digit_b(b) + """ + compute_sum_2 = scallopy.ScallopForwardFunction( + program=sum_2_program, + provenance="difftopkproofs", + input_mappings={"digit_a": None, "digit_b": None}, + output_mappings={"sum_2": list(range(19))}, + ) + digit_a, digit_b = torch.randn((16, 10)), torch.randn((16, 10)) + result = compute_sum_2(digit_a=process_input(digit_a), digit_b=process_input(digit_b)) + self.assertEqual(result.shape, (16, 19)) + def test_forward_with_non_probabilistic(self): edge_path_program = """ type edge(usize, usize) diff --git a/etc/scallopy/tests/input_mapping.py b/etc/scallopy/tests/input_mapping.py new file mode 100644 index 0000000..d6dc401 --- /dev/null +++ b/etc/scallopy/tests/input_mapping.py @@ -0,0 +1,220 @@ +import unittest +import torch + +import scallopy + +class TestInputMapping(unittest.TestCase): + def test_construct_list_1(self): + im = scallopy.InputMapping(list(range(10))) + assert im.kind == "list" + assert im.shape == (10,) + assert im.dimension == 1 + assert im.is_singleton == True + + def test_construct_list_1_1(self): + im = scallopy.InputMapping(range(10)) + assert im.kind == "list" + assert im.shape == (10,) + assert im.dimension == 1 + assert im.is_singleton == True + + def test_construct_list_2(self): + im = scallopy.InputMapping([(i,) for i in range(10)]) + assert im.kind == "list" + assert im.shape == (10,) + assert im.dimension == 1 + assert im.is_singleton == False + + def test_construct_list_3(self): + im = scallopy.InputMapping([(i, j) for i in range(5) for j in range(5)]) + assert im.kind == "list" + assert im.shape == (25,) + assert im.dimension == 1 + assert im.is_singleton == False + + def test_construct_list_4(self): + im = scallopy.InputMapping([[(i, j) for i in range(5)] for j in range(5)]) + assert im.kind == "list" + assert im.shape == (5, 5) + assert im.dimension == 2 + assert im.is_singleton == False + + @unittest.expectedFailure + def test_construct_list_failure_1(self): + # Tuple size mismatch + _ = scallopy.InputMapping([(1, 2), (1,)]) + + @unittest.expectedFailure + def test_construct_list_failure_2(self): + # Empty mapping + _ = scallopy.InputMapping([]) + + @unittest.expectedFailure + def test_construct_list_failure_3(self): + # Empty mapping on other dimensions + _ = scallopy.InputMapping([[], [], []]) + + @unittest.expectedFailure + def test_construct_list_failure_4(self): + # Unmatched dimensions + _ = scallopy.InputMapping([[3, 5, 8], [3, 5, 9], [3, 5]]) + + @unittest.expectedFailure + def 
test_cannot_set_property(self): + im = scallopy.InputMapping(range(10)) + im.dimension = (1, 1, 1, 1) + + def test_construct_tuple_1(self): + im = scallopy.InputMapping(()) + assert im.kind == "tuple" + assert im.shape == () + assert im.dimension == 0 + assert im.is_singleton == False + + def test_construct_tuple_2(self): + im = scallopy.InputMapping((3, 5)) + assert im.kind == "tuple" + assert im.shape == () + assert im.dimension == 0 + assert im.is_singleton == False + + @unittest.expectedFailure + def test_construct_tuple_failure_1(self): + _ = scallopy.InputMapping((3, 5, [])) + + def test_construct_value_1(self): + im = scallopy.InputMapping(3) + assert im.kind == "value" + assert im.shape == () + assert im.dimension == 0 + assert im.is_singleton == True + + def test_construct_dict_1(self): + im = scallopy.InputMapping({0: range(5), 1: range(5)}) + assert im.kind == "dict" + assert im.shape == (5, 5) + assert im.dimension == 2 + assert im.is_singleton == False + + def test_construct_dict_2(self): + im = scallopy.InputMapping({0: range(5), 1: range(5), 2: range(2)}) + assert im.kind == "dict" + assert im.shape == (5, 5, 2) + assert im.dimension == 3 + assert im.is_singleton == False + + def test_construct_dict_3(self): + im = scallopy.InputMapping({0: range(3), 1: ["red", "green", "blue"]}) + assert im.kind == "dict" + assert im.shape == (3, 3) + assert im.dimension == 2 + assert im.is_singleton == False + + @unittest.expectedFailure + def test_construct_dict_failure_1(self): + _ = scallopy.InputMapping({}) + + @unittest.expectedFailure + def test_construct_dict_failure_2(self): + _ = scallopy.InputMapping({1: range(3)}) + + @unittest.expectedFailure + def test_construct_dict_failure_3(self): + _ = scallopy.InputMapping({0: range(3), 1: []}) + + @unittest.expectedFailure + def test_construct_dict_failure_4(self): + _ = scallopy.InputMapping({0: [(1, 3, 5)]}) + + @unittest.expectedFailure + def test_construct_dict_failure_5(self): + _ = scallopy.InputMapping({"1": [3]}) + + @unittest.expectedFailure + def test_construct_dict_failure_6(self): + _ = scallopy.InputMapping({-10: [3]}) + + def test_process_tensor_1(self): + im = scallopy.InputMapping({0: range(5), 1: range(5)}) + r = im.process_tensor(torch.zeros((5, 5))) + assert len(r) == 25 + + def test_process_tensor_2(self): + im = scallopy.InputMapping([[(i, j) for j in range(5)] for i in range(5)]) + r = im.process_tensor(torch.zeros((5, 5)), batched=True) + assert len(r) == 1 + assert len(r[0]) == 25 + + def test_process_tensor_3(self): + im = scallopy.InputMapping(range(10)) + r = im.process_tensor(torch.randn((16, 10))) + assert len(r) == 16 + assert len(r[0]) == 10 + + def test_retain_k_1(self): + im = scallopy.InputMapping(range(10), retain_k=3) + r = im.process_tensor(torch.randn((10,))) + assert len(r) == 3 + + @unittest.expectedFailure + def test_retain_k_2(self): + _ = scallopy.InputMapping(range(10), retain_k=3, sample_dim=1) + + @unittest.expectedFailure + def test_retain_k_3(self): + _ = scallopy.InputMapping(range(10), retain_k=3, sample_dim=-10) + + def test_mult_dim_retain_k_1(self): + im = scallopy.InputMapping({0: range(5), 1: range(5)}, retain_k=3) + r = im.process_tensor(torch.randn((5, 5))) + assert len(r) == 3 + + def test_mult_dim_retain_k_2(self): + im = scallopy.InputMapping({0: range(5), 1: range(3)}, retain_k=2, sample_dim=1) + r = im.process_tensor(torch.randn((5, 3))) + assert len(r) == 10 + + def test_mult_dim_retain_k_3(self): + im = scallopy.InputMapping({0: range(5), 1: range(3)}, retain_k=2, 
sample_dim=0) + r = im.process_tensor(torch.randn((5, 3))) + assert len(r) == 6 + + def test_retain_threshold_1(self): + im = scallopy.InputMapping(range(10), retain_threshold=0.5) + t = torch.randn(10) + r = im.process_tensor(t) + assert len(r) == len(t[t > 0.5]) + + def test_retain_threshold_2(self): + im = scallopy.InputMapping({0: range(5), 1: range(3)}, retain_threshold=0.5) + t = torch.randn((5, 3)) + r = im.process_tensor(t) + assert len(r) == len(t[t > 0.5]) + + def test_disjunction_1(self): + im = scallopy.InputMapping(range(10), disjunctive=True, supports_disjunctions=True) + t = torch.randn((10,)) + r = im.process_tensor(t) + for ((_, did), _) in r: + assert did == 0 + + def test_disjunction_2(self): + im = scallopy.InputMapping({0: range(5), 1: range(5)}, disjunctive=True, supports_disjunctions=True) + t = torch.randn((5, 5)) + r = im.process_tensor(t) + for ((_, did), _) in r: + assert did == 0 + + def test_disjunction_3(self): + im = scallopy.InputMapping({0: range(5), 1: range(5)}, disjunctive_dim=1, supports_disjunctions=True) + t = torch.randn((5, 5)) + r = im.process_tensor(t) + for ((_, did), (i, _)) in r: + assert did == i + + def test_disjunction_4(self): + im = scallopy.InputMapping({0: range(5), 1: range(5)}, disjunctive_dim=0, supports_disjunctions=True) + t = torch.randn((5, 5)) + r = im.process_tensor(t) + for ((_, did), (_, j)) in r: + assert did == j diff --git a/etc/scallopy/tests/test.py b/etc/scallopy/tests/test.py index be77a70..189b695 100644 --- a/etc/scallopy/tests/test.py +++ b/etc/scallopy/tests/test.py @@ -1,9 +1,11 @@ import unittest from basics import * +from configurations import * from failure import * from forward import * from foreign_function import * +from input_mapping import * if __name__ == '__main__': unittest.main() diff --git a/etc/sclc/Cargo.toml b/etc/sclc/Cargo.toml index 78c3e7c..68be7b9 100644 --- a/etc/sclc/Cargo.toml +++ b/etc/sclc/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "sclc-core" -version = "0.1.7" +version = "0.1.8" authors = ["Ziyang Li "] edition = "2018" diff --git a/etc/sclc/src/pylib.rs b/etc/sclc/src/pylib.rs index 168ae1c..71021a7 100644 --- a/etc/sclc/src/pylib.rs +++ b/etc/sclc/src/pylib.rs @@ -613,7 +613,7 @@ fn generate_helper_functions() -> TokenStream { } } - impl PythonProvenance for proofs::ProofsProvenance { + impl PythonProvenance for proofs::ProofsProvenance { fn process_py_tag(disj_id: &PyAny) -> Result, BindingError> { let disj_id: usize = disj_id.extract()?; Ok(Some(proofs::ProofsInputTag::Exclusive(disj_id))) @@ -685,7 +685,7 @@ fn generate_helper_functions() -> TokenStream { fn process_py_tag(tag: &PyAny) -> Result, BindingError> { let prob: f64 = tag.extract().map_err(BindingError::from)?; let tag: Py = tag.into(); - Ok(Some((prob, tag).into())) + Ok(Some((prob, Some(tag)).into())) } fn to_output_py_tag(tag: &Self::OutputTag) -> Py { @@ -701,7 +701,7 @@ fn generate_helper_functions() -> TokenStream { fn process_py_tag(tag: &PyAny) -> Result, BindingError> { let prob: f64 = tag.extract().map_err(BindingError::from)?; let tag: Py = tag.into(); - Ok(Some((prob, tag).into())) + Ok(Some((prob, Some(tag)).into())) } fn to_output_py_tag(tag: &Self::OutputTag) -> Py { @@ -717,7 +717,7 @@ fn generate_helper_functions() -> TokenStream { fn process_py_tag(tag: &PyAny) -> Result, BindingError> { let prob: f64 = tag.extract().map_err(BindingError::from)?; let tag: Py = tag.into(); - Ok(Some((prob, tag).into())) + Ok(Some((prob, Some(tag)).into())) } fn to_output_py_tag(tag: &Self::OutputTag) -> Py { @@ -733,7 
+733,7 @@ fn generate_helper_functions() -> TokenStream { fn process_py_tag(tag: &PyAny) -> Result, BindingError> { let prob: f64 = tag.extract().map_err(BindingError::from)?; let tag: Py = tag.into(); - Ok(Some((prob, tag).into())) + Ok(Some((prob, Some(tag)).into())) } fn to_output_py_tag(tag: &Self::OutputTag) -> Py { @@ -749,7 +749,7 @@ fn generate_helper_functions() -> TokenStream { fn process_py_tag(tag: &PyAny) -> Result, BindingError> { let prob: f64 = tag.extract()?; let tag: Py = tag.into(); - Ok(Some((prob, tag).into())) + Ok(Some((prob, Some(tag)).into())) } fn to_output_py_tag(tag: &Self::OutputTag) -> Py { diff --git a/etc/scli/Cargo.toml b/etc/scli/Cargo.toml index f4ce5d8..0c966fe 100644 --- a/etc/scli/Cargo.toml +++ b/etc/scli/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "scli" -version = "0.1.7" +version = "0.1.8" authors = ["Ziyang Li "] edition = "2018" diff --git a/etc/scli/src/main.rs b/etc/scli/src/main.rs index 5009350..7ad4f0d 100644 --- a/etc/scli/src/main.rs +++ b/etc/scli/src/main.rs @@ -148,7 +148,7 @@ fn main() { interpret(ctx, &opt.input, integrate_opt, predicate_set, monitor_options); } "proofs" => { - let ctx = provenance::proofs::ProofsProvenance::default(); + let ctx = provenance::proofs::ProofsProvenance::::default(); interpret(ctx, &opt.input, integrate_opt, predicate_set, monitor_options); } "minmaxprob" => { @@ -181,8 +181,13 @@ fn interpret( predicate_set: PredicateSet, monitor_options: MonitorOptions, ) { - let mut interpret_ctx = integrate::InterpretContext::<_, RcFamily>::new_from_file_with_options(file_name, prov, opt) - .expect("Initialization Error"); + let mut interpret_ctx = match integrate::InterpretContext::<_, RcFamily>::new_from_file_with_options(file_name, prov, opt) { + Ok(ctx) => ctx, + Err(err) => { + println!("{}", err); + return; + } + }; // Check if we have any specified monitors, and run the program if !monitor_options.needs_monitor() { diff --git a/etc/sclrepl/Cargo.toml b/etc/sclrepl/Cargo.toml index 3674535..49f2938 100644 --- a/etc/sclrepl/Cargo.toml +++ b/etc/sclrepl/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "sclrepl" -version = "0.1.7" +version = "0.1.8" authors = ["Ziyang Li "] edition = "2018" diff --git a/etc/sclrepl/src/main.rs b/etc/sclrepl/src/main.rs index 006f033..b8fe75a 100644 --- a/etc/sclrepl/src/main.rs +++ b/etc/sclrepl/src/main.rs @@ -113,7 +113,7 @@ where .iter() .filter_map(|item| { if let compiler::front::Item::QueryDecl(q) = item { - Some(q.query().relation_name()) + Some(q.query().create_relation_name()) } else { None } diff --git a/etc/vscode-scl/language-configuration.json b/etc/vscode-scl/language-configuration.json index 8f162a0..7301eff 100644 --- a/etc/vscode-scl/language-configuration.json +++ b/etc/vscode-scl/language-configuration.json @@ -17,6 +17,8 @@ ["[", "]"], ["(", ")"], ["\"", "\""], + ["d\"", "\""], + ["t\"", "\""], ["'", "'"] ], // symbols that can be used to surround a selection @@ -25,6 +27,8 @@ ["[", "]"], ["(", ")"], ["\"", "\""], + ["d\"", "\""], + ["t\"", "\""], ["'", "'"] ] -} \ No newline at end of file +} diff --git a/etc/vscode-scl/package.json b/etc/vscode-scl/package.json index 371b033..598bfa2 100644 --- a/etc/vscode-scl/package.json +++ b/etc/vscode-scl/package.json @@ -2,7 +2,7 @@ "name": "scallop", "displayName": "Scallop Language Syntax Highlight", "description": "Scallop Language Support", - "version": "0.0.5", + "version": "0.0.7", "repository": { "type": "git", "url": "https://github.com/liby99/scallop-v2" diff --git a/etc/vscode-scl/syntaxes/scallop.tmLanguage.json 
b/etc/vscode-scl/syntaxes/scallop.tmLanguage.json index fc29043..3d294cc 100644 --- a/etc/vscode-scl/syntaxes/scallop.tmLanguage.json +++ b/etc/vscode-scl/syntaxes/scallop.tmLanguage.json @@ -9,6 +9,12 @@ { "include": "#strings" }, + { + "include": "#datetime" + }, + { + "include": "#duration" + }, { "include": "#chars" }, @@ -234,6 +240,28 @@ } ] }, + "datetime": { + "name": "string.quoted.double.scallop", + "begin": "t\"", + "end": "\"", + "patterns": [ + { + "name": "constant.character.escape.scallop", + "match": "\\\\." + } + ] + }, + "duration": { + "name": "string.quoted.double.scallop", + "begin": "d\"", + "end": "\"", + "patterns": [ + { + "name": "constant.character.escape.scallop", + "match": "\\\\." + } + ] + }, "chars": { "name": "string.quoted.single.scallop", "begin": "'", diff --git a/lib/ram/Cargo.toml b/lib/ram/Cargo.toml new file mode 100644 index 0000000..da05f06 --- /dev/null +++ b/lib/ram/Cargo.toml @@ -0,0 +1,8 @@ +[package] +name = "ram" +version = "0.1.0" +authors = ["Ziyang Li "] +edition = "2018" + +[dependencies] +egg = "0.9" diff --git a/lib/ram/src/generic_tuple.rs b/lib/ram/src/generic_tuple.rs new file mode 100644 index 0000000..ea6b005 --- /dev/null +++ b/lib/ram/src/generic_tuple.rs @@ -0,0 +1,132 @@ +#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum GenericTuple { + Value(T), + Tuple(Box<[GenericTuple]>), +} + +impl GenericTuple { + pub fn unit() -> Self { + Self::Tuple(Box::new([])) + } + + pub fn empty() -> Self { + Self::Tuple(Box::new([])) + } + + pub fn is_empty(&self) -> bool { + if let Self::Tuple(ts) = self { + ts.is_empty() + } else { + false + } + } +} + +impl std::ops::Index for GenericTuple { + type Output = Self; + + fn index(&self, i: usize) -> &Self::Output { + match self { + Self::Tuple(t) => &t[i], + _ => panic!("Cannot access tuple value with `{:?}`", i), + } + } +} + +impl From<()> for GenericTuple { + fn from(_: ()) -> Self { + Self::Tuple(Box::new([])) + } +} + +impl From<(A,)> for GenericTuple +where + A: Into>, +{ + fn from((a,): (A,)) -> Self { + Self::Tuple(Box::new([a.into()])) + } +} + +impl From<(A, B)> for GenericTuple +where + A: Into>, + B: Into>, +{ + fn from((a, b): (A, B)) -> Self { + Self::Tuple(Box::new([a.into(), b.into()])) + } +} + +impl From<(A, B, C)> for GenericTuple +where + A: Into>, + B: Into>, + C: Into>, +{ + fn from((a, b, c): (A, B, C)) -> Self { + Self::Tuple(Box::new([a.into(), b.into(), c.into()])) + } +} + +impl From<(A, B, C, D)> for GenericTuple +where + A: Into>, + B: Into>, + C: Into>, + D: Into>, +{ + fn from((a, b, c, d): (A, B, C, D)) -> Self { + Self::Tuple(Box::new([a.into(), b.into(), c.into(), d.into()])) + } +} + +impl From<(A, B, C, D, E)> for GenericTuple +where + A: Into>, + B: Into>, + C: Into>, + D: Into>, + E: Into>, +{ + fn from((a, b, c, d, e): (A, B, C, D, E)) -> Self { + Self::Tuple(Box::new([a.into(), b.into(), c.into(), d.into(), e.into()])) + } +} + +impl From<(A, B, C, D, E, F)> for GenericTuple +where + A: Into>, + B: Into>, + C: Into>, + D: Into>, + E: Into>, + F: Into>, +{ + fn from((a, b, c, d, e, f): (A, B, C, D, E, F)) -> Self { + Self::Tuple(Box::new([a.into(), b.into(), c.into(), d.into(), e.into(), f.into()])) + } +} + +impl From<(A, B, C, D, E, F, G)> for GenericTuple +where + A: Into>, + B: Into>, + C: Into>, + D: Into>, + E: Into>, + F: Into>, + G: Into>, +{ + fn from((a, b, c, d, e, f, g): (A, B, C, D, E, F, G)) -> Self { + Self::Tuple(Box::new([ + a.into(), + b.into(), + c.into(), + d.into(), + e.into(), + f.into(), + g.into(), + ])) + } +} diff 
--git a/lib/ram/src/language.rs b/lib/ram/src/language.rs new file mode 100644 index 0000000..463753e --- /dev/null +++ b/lib/ram/src/language.rs @@ -0,0 +1,143 @@ +use egg::{rewrite as rw, *}; + +use super::tuple_type::TupleType; + +// Define the RAM language +define_language! { + pub enum Ram { + // Relational Predicate + Predicate(String, Id), + + // Relational Algebra Operations + "empty" = Empty, + "filter" = Filter([Id; 2]), + "project" = Project([Id; 2]), + "sorted" = Sorted(Id), + "product" = Product([Id; 2]), + "join" = Join([Id; 2]), + + // Tuple operations + "apply" = Apply([Id; 2]), + "cons" = Cons([Id; 2]), + "nil" = Nil, + + // Value operations + "+" = Add([Id; 2]), + "-" = Sub([Id; 2]), + "*" = Mult([Id; 2]), + "/" = Div([Id; 2]), + "&&" = And([Id; 2]), + "||" = Or([Id; 2]), + "!" = Not(Id), + + // Any symbol + Bool(bool), + Number(i32), + Symbol(Symbol), + } +} + +pub type EGraph = egg::EGraph; + +fn var(s: &str) -> Var { + s.parse().unwrap() +} + +fn is_constant(_v: Var) -> impl Fn(&mut EGraph, Id, &Subst) -> bool { + move |_egraph, _, _subst| { + false + } +} + +/// All the rewrite rules for the language +pub fn ram_rewrite_rules() -> Vec> { + vec![ + // Relational level rewrites + rw!("filter-cascade"; "(filter (filter ?d ?a) ?b)" => "(filter ?d (&& ?a ?b))"), + rw!("filter-true"; "(filter ?d true)" => "?d"), + rw!("filter-false"; "(filter ?d false)" => "empty"), + rw!("project-cascade"; "(project (project ?d ?a) ?b)" => "(project ?d (apply ?b ?a))"), + + // Tuple level application rewrites + rw!("access-nil"; "(apply nil ?a)" => "?a"), + rw!("access-tuple-base"; "(apply (cons 0 ?x) (cons ?a ?b))" => "(apply ?x ?a)"), + rw!("access-tuple-ind"; "(apply (cons ?n ?x) (cons ?a ?b))" => "(apply (cons (- ?n 1) ?x) ?b)"), + rw!("apply-tuple-nil"; "(apply nil ?a)" => "nil"), + rw!("apply-tuple-cons"; "(apply (cons ?a ?b) ?c)" => "(cons (apply ?a ?c) (apply ?b ?c))"), + + // Expression level application rewrites + rw!("apply-add"; "(apply (+ ?a ?b) ?t)" => "(+ (apply ?a ?t) (apply ?b ?t))"), + rw!("apply-sub"; "(apply (- ?a ?b) ?t)" => "(- (apply ?a ?t) (apply ?b ?t))"), + rw!("apply-mult"; "(apply (* ?a ?b) ?t)" => "(* (apply ?a ?t) (apply ?b ?t))"), + rw!("apply-div"; "(apply (/ ?a ?b) ?t)" => "(/ (apply ?a ?t) (apply ?b ?t))"), + rw!("apply-and"; "(apply (&& ?a ?b) ?t)" => "(&& (apply ?a ?t) (apply ?b ?t))"), + rw!("apply-or"; "(apply (|| ?a ?b) ?t)" => "(|| (apply ?a ?t) (apply ?b ?t))"), + rw!("apply-not"; "(apply (! ?a) ?t)" => "(! (apply ?a ?t))"), + rw!("apply-const"; "(apply ?e ?t)" => "?e" if is_constant(var("?e"))), + + // Value level rewrites + rw!("add-comm"; "(+ ?a ?b)" => "(+ ?b ?a)"), + rw!("add-identity"; "(+ ?a 0)" => "?a"), + rw!("mult-comm"; "(* ?a ?b)" => "(* ?b ?a)"), + rw!("mult-identity"; "(* ?a 1)" => "?a"), + rw!("and-comm"; "(&& ?a ?b)" => "(&& ?b ?a)"), + rw!("and-identity"; "(&& ?a true)" => "?a"), + rw!("and-idempotent"; "(&& ?a ?a)" => "?a"), + rw!("or-comm"; "(|| ?a ?b)" => "(|| ?b ?a)"), + rw!("or-identity"; "(|| ?a false)" => "?a"), + rw!("or-idempotent"; "(|| ?a ?a)" => "?a"), + rw!("not-true"; "(! true)" => "false"), + rw!("not-false"; "(! false)" => "true"), + rw!("not-not"; "(! (! 
?a))" => "?a"), + + // Simple arithmetic rewrites for index calculations + rw!("dec-1"; "(- 1 1)" => "0"), + rw!("dec-2"; "(- 2 1)" => "1"), + rw!("dec-3"; "(- 3 1)" => "2"), + rw!("dec-4"; "(- 4 1)" => "3"), + rw!("dec-5"; "(- 5 1)" => "4"), + ] +} + +struct RamCostFunction; + +impl CostFunction for RamCostFunction { + type Cost = i32; + + fn cost(&mut self, enode: &Ram, mut costs: C) -> Self::Cost + where + C: FnMut(Id) -> Self::Cost + { + let op_cost = match enode { + Ram::Filter(_) => 100, + Ram::Project(_) => 100, + Ram::Sorted(_) => 100, + Ram::Apply(_) => 10, + _ => 1, + }; + enode.fold(op_cost, |sum, id| sum + costs(id)) + } +} + +pub struct RamNodeData { + pub tuple_type: Option, +} + +/// parse an expression, simplify it using egg, and pretty print it back out +pub fn simplify(s: &str) -> String { + // parse the expression, the type annotation tells it which Language to use + let expr: RecExpr = s.parse().unwrap(); + + // simplify the expression using a Runner, which creates an e-graph with + // the given expression and runs the given rules over it + let rules = ram_rewrite_rules(); + let runner = Runner::default().with_expr(&expr).run(&rules); + + // the Runner knows which e-class the expression given with `with_expr` is in + let root = runner.roots[0]; + + // use an Extractor to pick the best element of the root eclass + let extractor = Extractor::new(&runner.egraph, RamCostFunction); + let (_, best) = extractor.find_best(root); + best.to_string() +} diff --git a/lib/ram/src/lib.rs b/lib/ram/src/lib.rs new file mode 100644 index 0000000..791daf7 --- /dev/null +++ b/lib/ram/src/lib.rs @@ -0,0 +1,6 @@ +pub mod generic_tuple; +pub mod tuple_type; +mod language; +pub mod value_type; + +pub use language::*; diff --git a/lib/ram/src/tuple_type.rs b/lib/ram/src/tuple_type.rs new file mode 100644 index 0000000..cac2512 --- /dev/null +++ b/lib/ram/src/tuple_type.rs @@ -0,0 +1,170 @@ +use super::generic_tuple::GenericTuple; +use super::value_type::{FromType, ValueType}; + +pub type TupleType = GenericTuple; + +impl std::fmt::Debug for TupleType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Tuple(ts) => f.write_fmt(format_args!( + "({})", + ts.iter().map(|t| format!("{t:?}")).collect::>().join(", ") + )), + Self::Value(v) => std::fmt::Debug::fmt(v, f), + } + } +} + +impl std::fmt::Display for TupleType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Tuple(ts) => f.write_fmt(format_args!( + "({})", + ts.iter().map(|t| format!("{}", t)).collect::>().join(", ") + )), + Self::Value(v) => std::fmt::Debug::fmt(v, f), + } + } +} + +impl TupleType { + pub fn from_types(types: &[ValueType], can_be_singleton: bool) -> Self { + if types.len() == 1 && can_be_singleton { + Self::Value(types[0].clone()) + } else { + Self::Tuple(types.iter().cloned().map(Self::Value).collect()) + } + } +} + +impl FromType for TupleType +where + ValueType: FromType, +{ + fn from_type() -> Self { + Self::Value(>::from_type()) + } +} + +impl FromType<()> for TupleType { + fn from_type() -> Self { + Self::Tuple(Box::new([])) + } +} + +impl FromType<(A,)> for TupleType +where + TupleType: FromType, +{ + fn from_type() -> Self { + Self::Tuple(Box::new([>::from_type()])) + } +} + +impl FromType<(A, B)> for TupleType +where + TupleType: FromType, + TupleType: FromType, +{ + fn from_type() -> Self { + Self::Tuple(Box::new([ + >::from_type(), + >::from_type(), + ])) + } +} + +impl FromType<(A, B, C)> for TupleType +where + TupleType: 
FromType, + TupleType: FromType, + TupleType: FromType, +{ + fn from_type() -> Self { + Self::Tuple(Box::new([ + >::from_type(), + >::from_type(), + >::from_type(), + ])) + } +} + +impl FromType<(A, B, C, D)> for TupleType +where + TupleType: FromType, + TupleType: FromType, + TupleType: FromType, + TupleType: FromType, +{ + fn from_type() -> Self { + Self::Tuple(Box::new([ + >::from_type(), + >::from_type(), + >::from_type(), + >::from_type(), + ])) + } +} + +impl FromType<(A, B, C, D, E)> for TupleType +where + TupleType: FromType, + TupleType: FromType, + TupleType: FromType, + TupleType: FromType, + TupleType: FromType, +{ + fn from_type() -> Self { + Self::Tuple(Box::new([ + >::from_type(), + >::from_type(), + >::from_type(), + >::from_type(), + >::from_type(), + ])) + } +} + +impl FromType<(A, B, C, D, E, F)> for TupleType +where + TupleType: FromType, + TupleType: FromType, + TupleType: FromType, + TupleType: FromType, + TupleType: FromType, + TupleType: FromType, +{ + fn from_type() -> Self { + Self::Tuple(Box::new([ + >::from_type(), + >::from_type(), + >::from_type(), + >::from_type(), + >::from_type(), + >::from_type(), + ])) + } +} + +impl FromType<(A, B, C, D, E, F, G)> for TupleType +where + TupleType: FromType, + TupleType: FromType, + TupleType: FromType, + TupleType: FromType, + TupleType: FromType, + TupleType: FromType, + TupleType: FromType, +{ + fn from_type() -> Self { + Self::Tuple(Box::new([ + >::from_type(), + >::from_type(), + >::from_type(), + >::from_type(), + >::from_type(), + >::from_type(), + >::from_type(), + ])) + } +} diff --git a/lib/ram/src/value_type.rs b/lib/ram/src/value_type.rs new file mode 100644 index 0000000..e256f44 --- /dev/null +++ b/lib/ram/src/value_type.rs @@ -0,0 +1,218 @@ +#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum ValueType { + I8, + I16, + I32, + I64, + I128, + ISize, + U8, + U16, + U32, + U64, + U128, + USize, + F32, + F64, + Char, + Bool, + Str, + String, +} + +impl ValueType { + pub fn is_numeric(&self) -> bool { + self.is_integer() || self.is_float() + } + + pub fn is_integer(&self) -> bool { + match self { + Self::I8 + | Self::I16 + | Self::I32 + | Self::I64 + | Self::I128 + | Self::ISize + | Self::U8 + | Self::U16 + | Self::U32 + | Self::U64 + | Self::U128 + | Self::USize => true, + _ => false, + } + } + + pub fn is_signed_integer(&self) -> bool { + match self { + Self::I8 | Self::I16 | Self::I32 | Self::I64 | Self::I128 | Self::ISize => true, + _ => false, + } + } + + pub fn is_float(&self) -> bool { + match self { + Self::F32 | Self::F64 => true, + _ => false, + } + } + + pub fn is_boolean(&self) -> bool { + match self { + Self::Bool => true, + _ => false, + } + } + + pub fn is_char(&self) -> bool { + match self { + Self::Char => true, + _ => false, + } + } + + pub fn is_string(&self) -> bool { + match self { + Self::Str | Self::String /* | Self::RcString */ => true, + _ => false, + } + } +} + +impl std::fmt::Display for ValueType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + use ValueType::*; + match self { + I8 => f.write_str("i8"), + I16 => f.write_str("i16"), + I32 => f.write_str("i32"), + I64 => f.write_str("i64"), + I128 => f.write_str("i128"), + ISize => f.write_str("isize"), + U8 => f.write_str("u8"), + U16 => f.write_str("u16"), + U32 => f.write_str("u32"), + U64 => f.write_str("u64"), + U128 => f.write_str("u128"), + USize => f.write_str("usize"), + F32 => f.write_str("f32"), + F64 => f.write_str("f64"), + Char => f.write_str("char"), + Bool => 
f.write_str("bool"), + Str => f.write_str("&str"), + String => f.write_str("String"), + } + } +} + +pub trait FromType { + fn from_type() -> Self; +} + +impl FromType for ValueType { + fn from_type() -> Self { + Self::I8 + } +} + +impl FromType for ValueType { + fn from_type() -> Self { + Self::I16 + } +} + +impl FromType for ValueType { + fn from_type() -> Self { + Self::I32 + } +} + +impl FromType for ValueType { + fn from_type() -> Self { + Self::I64 + } +} + +impl FromType for ValueType { + fn from_type() -> Self { + Self::I128 + } +} + +impl FromType for ValueType { + fn from_type() -> Self { + Self::ISize + } +} + +impl FromType for ValueType { + fn from_type() -> Self { + Self::U8 + } +} + +impl FromType for ValueType { + fn from_type() -> Self { + Self::U16 + } +} + +impl FromType for ValueType { + fn from_type() -> Self { + Self::U32 + } +} + +impl FromType for ValueType { + fn from_type() -> Self { + Self::U64 + } +} + +impl FromType for ValueType { + fn from_type() -> Self { + Self::U128 + } +} + +impl FromType for ValueType { + fn from_type() -> Self { + Self::USize + } +} + +impl FromType for ValueType { + fn from_type() -> Self { + Self::F32 + } +} + +impl FromType for ValueType { + fn from_type() -> Self { + Self::F64 + } +} + +impl FromType for ValueType { + fn from_type() -> Self { + Self::Char + } +} + +impl FromType for ValueType { + fn from_type() -> Self { + Self::Bool + } +} + +impl FromType<&'static str> for ValueType { + fn from_type() -> Self { + Self::Str + } +} + +impl FromType for ValueType { + fn from_type() -> Self { + Self::String + } +} diff --git a/lib/ram/tests/test_optim.rs b/lib/ram/tests/test_optim.rs new file mode 100644 index 0000000..b4c99d4 --- /dev/null +++ b/lib/ram/tests/test_optim.rs @@ -0,0 +1,37 @@ +use ram::simplify; + +#[test] +fn test_filter_cascade_1() { + assert_eq!(simplify("(filter (filter ?d ?a) ?b)"), "(filter ?d (&& ?a ?b))") +} + +#[test] +fn test_filter_cascade_2() { + assert_eq!(simplify("(filter (filter (filter ?d ?a) ?b) ?c)"), "(filter ?d (&& ?c (&& ?a ?b)))") +} + +#[test] +fn test_project_cascade_1() { + assert_eq!(simplify("(project (project ?d ?a) ?b)"), "(project ?d (apply ?b ?a))") +} + +#[test] +fn test_project_cascade_2() { + assert_eq!(simplify(r#" + (project + (project + ?d + (tuple-cons + (index-cons 1 index-nil) + (tuple-cons + (index-cons 0 index-nil) + tuple-nil + ) + ) + ) + (- + (index-cons 0 index-nil) + (index-cons 1 index-nil) + ) + )"#), "(project ?d (- (index-cons 1 index-nil) (index-cons 0 index-nil)))") +} diff --git a/makefile b/makefile index d274300..bda871f 100644 --- a/makefile +++ b/makefile @@ -24,8 +24,13 @@ wasm-demo: run-wasm-demo: cd etc/scallop-wasm/demo; python3 -m http.server -py-venv: +init-venv: python3 -m venv .env + .env/bin/pip install --upgrade pip + .env/bin/pip install maturin torch torchvision transformers gym scikit-learn opencv-python tqdm matplotlib + +clear-venv: + rm -rf .env vscode-plugin: make -C etc/vscode-scl @@ -68,3 +73,10 @@ serve-doc: stop-serve-doc: @echo "Stopping documentation server on port 8192..." @lsof -t -i:8192 | xargs kill + +serve-book: + mdbook serve -p 8193 doc/ + +stop-serve-book: + @echo "Stopping book server on port 8193..." + @lsof -t -i:8193 | xargs kill diff --git a/readme.md b/readme.md index 00f338a..91e5354 100644 --- a/readme.md +++ b/readme.md @@ -1,7 +1,7 @@ # Scallop


Scallop is a language based on Datalog that supports differentiable logical and relational reasoning.
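
To make that one-line description concrete, here is a minimal scallopy sketch that uses only the API surface exercised by the tests touched in this patch (`ScallopContext`, `add_relation`, `add_facts`, `add_rule`, `run`, `relation`). The `(probability, tuple)` fact format and the example output are assumptions inferred from the test code, not something this patch guarantees.

```python
# Minimal sketch: probabilistic edge/path reachability with scallopy.
# Assumption: probabilistic provenances accept facts as (prob, tuple) pairs,
# as suggested by the tests in this patch.
import scallopy

ctx = scallopy.ScallopContext(provenance="minmaxprob")
ctx.add_relation("edge", (int, int))
ctx.add_facts("edge", [(0.9, (0, 1)), (0.8, (1, 2))])
ctx.add_rule("path(a, c) = edge(a, c) or path(a, b) and edge(b, c)")
ctx.run()

# Under min-max probability semantics, path(0, 2) is tagged min(0.9, 0.8) = 0.8.
print(list(ctx.relation("path")))  # e.g. [(0.9, (0, 1)), (0.8, (0, 2)), (0.8, (1, 2))]
```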