diff --git a/.github/workflows/scallop-core.yml b/.github/workflows/scallop-core.yml new file mode 100644 index 0000000..95591ba --- /dev/null +++ b/.github/workflows/scallop-core.yml @@ -0,0 +1,23 @@ +name: Scallop Core + +on: + push: + branches: [ master ] + pull_request: + branches: [ master ] + +env: + CARGO_TERM_COLOR: always + +jobs: + build-and-test: + name: Build and Test + runs-on: ubuntu-latest + strategy: + matrix: + toolchain: + - nightly + steps: + - uses: actions/checkout@v3 + - run: rustup update ${{ matrix.toolchain }} && rustup default ${{ matrix.toolchain }} + - run: cargo test --verbose --workspace --release diff --git a/.github/workflows/scallopy-torch.yml b/.github/workflows/scallopy-torch.yml new file mode 100644 index 0000000..a479453 --- /dev/null +++ b/.github/workflows/scallopy-torch.yml @@ -0,0 +1,60 @@ +name: Scallopy with torch-tensor feature + +on: + push: + branches: [ master ] + pull_request: + branches: [ master ] + +env: + SCALLOPDIR: ${{ github.workspace }} + LD_LIBRARY_PATH: "/opt/intel/oneapi/mkl/latest/lib/intel64:${{ github.env.LD_LIBRARY_PATH }}" + +jobs: + build-and-test: + name: Build and Test + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest] + rust: [nightly] + python: ["3.10"] + steps: + - uses: actions/checkout@v3 + + - name: Set up Python ${{ matrix.python }} + uses: actions/setup-python@v3 + with: + python-version: ${{ matrix.python }} + + - name: Add conda to system path + run: echo $CONDA/bin >> $GITHUB_PATH + + - name: Install PyTorch + run: conda install pytorch::pytorch=2.0.0 -c pytorch + + - name: Install other dependencies + run: conda install maturin tqdm + + - name: Install MKL + run: | + wget https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB + sudo apt-key add GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB + rm GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB + sudo echo "deb https://apt.repos.intel.com/oneapi all main" | sudo tee /etc/apt/sources.list.d/oneAPI.list + sudo apt-get update 
+ sudo apt-get install -y intel-oneapi-mkl + python3 -m pip install mkl + ls /opt/intel/oneapi/mkl/latest/lib/intel64 # Investigating whether the library is successfully installed + + - uses: actions-rs/toolchain@v1 + with: + profile: minimal + toolchain: ${{ matrix.rust }} + override: true + + - name: Install scallopy with torch tensor features + run: make install-scallopy-torch + + - name: Test + run: python3 etc/scallopy/tests/test.py diff --git a/.github/workflows/scallopy.yml b/.github/workflows/scallopy.yml new file mode 100644 index 0000000..a7d2ae5 --- /dev/null +++ b/.github/workflows/scallopy.yml @@ -0,0 +1,46 @@ +name: Scallopy + +on: [push] + +env: + SCALLOPDIR: ${{ github.workspace }} + +jobs: + build-and-test: + name: Build and Test + runs-on: ubuntu-latest + strategy: + max-parallel: 5 + matrix: + python-version: + - "3.8" + - "3.9" + - "3.10" + + steps: + - uses: actions/checkout@v3 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v3 + with: + python-version: ${{ matrix.python-version }} + + - name: Add conda to system path + run: echo $CONDA/bin >> $GITHUB_PATH + + - name: Install PyTorch + run: conda install pytorch::pytorch torchvision torchaudio -c pytorch + + - name: Install other dependencies + run: conda install maturin tqdm + + - name: Setup rustup toolchain + run: | + rustup update nightly + rustup default nightly + + - name: Install scallopy + run: make install-scallopy + + - name: Test + run: python3 etc/scallopy/tests/test.py diff --git a/.gitignore b/.gitignore index 3b56b33..80f9a64 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,6 @@ +# Rust and compilation target +/.tmp /Cargo.lock # MacOS diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..3b1c0c6 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 Ziyang Li + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the 
"Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/changelog.md b/changelog.md index 61f1667..8c7a682 100644 --- a/changelog.md +++ b/changelog.md @@ -1,4 +1,14 @@ -# (Latest) v0.1.9, Apr 24, 2023 +# (Latest) v0.2.0, Jun 11, 2023 + +- Fixing CSV loading and its performance; adding new modes to specify `keys` +- Adding `Symbol` type to the language +- Adding algebraic data types to the language and supports for entities +- Adding tensors to the language which can be accessed from `scallopy` +- Adding `scallop` CLI (command-line interface) with OpenAI plugins for invoking LLMs +- Adding more documentations +- Multiple bugs fixed + +# v0.1.9, Apr 24, 2023 - Supporting (partial) disjunctive Datalog - Fixed custom provenance's default implementations and dispatcher fallback diff --git a/core/Cargo.toml b/core/Cargo.toml index 629ad95..00226fc 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "scallop-core" -version = "0.1.9" +version = "0.2.0" authors = ["Ziyang Li "] edition = "2018" @@ -21,10 +21,17 @@ colored = "2.0" petgraph = "0.6" csv = "1.1" sprs = "0.11" -chrono = "0.4" +chrono = 
{ version = "0.4", features = ["serde"] } dateparser = "0.1.6" parse_duration = "2.1.1" dyn-clone = "1.0.10" lazy_static = "1.4" +serde = { version = "1.0", features = ["derive"] } rand = { version = "0.8", features = ["std_rng", "small_rng", "alloc"] } sdd = { path = "../lib/sdd" } + +# Optional ones +tch = { version = "0.13.0", optional = true } + +[features] +torch-tensor = ["dep:tch"] diff --git a/core/res/.keep b/core/res/.keep new file mode 100644 index 0000000..e69de29 diff --git a/core/res/testing/csv/edge.csv b/core/res/testing/csv/edge.csv new file mode 100644 index 0000000..f947923 --- /dev/null +++ b/core/res/testing/csv/edge.csv @@ -0,0 +1,3 @@ +0,1 +1,2 +2,3 diff --git a/core/res/testing/csv/edge_with_deliminator.csv b/core/res/testing/csv/edge_with_deliminator.csv new file mode 100644 index 0000000..f0a6a16 --- /dev/null +++ b/core/res/testing/csv/edge_with_deliminator.csv @@ -0,0 +1,3 @@ +0 1 +1 2 +2 3 diff --git a/core/res/testing/csv/edge_with_deliminator_and_header.csv b/core/res/testing/csv/edge_with_deliminator_and_header.csv new file mode 100644 index 0000000..9f8289f --- /dev/null +++ b/core/res/testing/csv/edge_with_deliminator_and_header.csv @@ -0,0 +1,4 @@ +from to +0 1 +1 2 +2 3 diff --git a/core/res/testing/csv/edge_with_header.csv b/core/res/testing/csv/edge_with_header.csv new file mode 100644 index 0000000..f5a2306 --- /dev/null +++ b/core/res/testing/csv/edge_with_header.csv @@ -0,0 +1,4 @@ +from,to +0,1 +1,2 +2,3 diff --git a/core/res/testing/csv/edge_with_prob.csv b/core/res/testing/csv/edge_with_prob.csv new file mode 100644 index 0000000..f1d76ab --- /dev/null +++ b/core/res/testing/csv/edge_with_prob.csv @@ -0,0 +1,3 @@ +0.01,0,1 +0.50,1,2 +0.91,2,3 diff --git a/core/res/testing/csv/enrollment.csv b/core/res/testing/csv/enrollment.csv new file mode 100644 index 0000000..e984ba3 --- /dev/null +++ b/core/res/testing/csv/enrollment.csv @@ -0,0 +1,4 @@ +student_id,course_id,semester,year,grade +1,cse100,fa,2020,a +1,cse102,sp,2021,a 
+2,cse100,sp,2020,b diff --git a/core/res/testing/csv/student.csv b/core/res/testing/csv/student.csv new file mode 100644 index 0000000..99e4de2 --- /dev/null +++ b/core/res/testing/csv/student.csv @@ -0,0 +1,3 @@ +id,name,year,gender +1,alice,2022,female +2,bob,2023,male diff --git a/core/res/testing/json/.keep b/core/res/testing/json/.keep new file mode 100644 index 0000000..e69de29 diff --git a/core/src/common/binary_op.rs b/core/src/common/binary_op.rs index 4af6fee..1a6f056 100644 --- a/core/src/common/binary_op.rs +++ b/core/src/common/binary_op.rs @@ -1,6 +1,8 @@ //! # Binary Operations -#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] +use serde::*; + +#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize)] pub enum BinaryOp { Add, Sub, diff --git a/core/src/common/entity.rs b/core/src/common/entity.rs new file mode 100644 index 0000000..dab93da --- /dev/null +++ b/core/src/common/entity.rs @@ -0,0 +1,11 @@ +use std::collections::hash_map::DefaultHasher; +use std::hash::{Hash, Hasher}; + +pub fn encode_entity>(functor: &str, args: I) -> u64 { + let mut s = DefaultHasher::new(); + functor.hash(&mut s); + for arg in args { + arg.hash(&mut s); + } + s.finish() +} diff --git a/core/src/common/expr.rs b/core/src/common/expr.rs index 7ffd270..b73ca02 100644 --- a/core/src/common/expr.rs +++ b/core/src/common/expr.rs @@ -14,6 +14,7 @@ pub enum Expr { Unary(UnaryExpr), IfThenElse(IfThenElseExpr), Call(CallExpr), + New(NewExpr), } impl Expr { @@ -41,6 +42,26 @@ impl Expr { Self::Call(CallExpr { function, args }) } + pub fn new(functor: String, args: Vec) -> Self { + Self::New(NewExpr { functor, args }) + } + + pub fn eq(self, other: Expr) -> Self { + Self::Binary(BinaryExpr { + op: BinaryOp::Eq, + op1: Box::new(self), + op2: Box::new(other), + }) + } + + pub fn neq(self, other: Expr) -> Self { + Self::Binary(BinaryExpr { + op: BinaryOp::Neq, + op1: Box::new(self), + op2: Box::new(other), + }) + } + pub fn lt(self, other: Expr) -> Self { 
Self::Binary(BinaryExpr { op: BinaryOp::Lt, @@ -98,6 +119,7 @@ impl Expr { (Self::Unary(u), e) => Self::unary(u.op.clone(), u.op1.compose(e)), (Self::IfThenElse(i), e) => Self::ite(i.cond.compose(e), i.then_br.compose(e), i.else_br.compose(e)), (Self::Call(c), e) => Self::call(c.function.clone(), c.args.iter().map(|a| a.compose(e)).collect()), + (Self::New(n), e) => Self::new(n.functor.clone(), n.args.iter().map(|a| a.compose(e)).collect()), } } } @@ -291,3 +313,9 @@ pub struct CallExpr { pub function: String, pub args: Vec, } + +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] +pub struct NewExpr { + pub functor: String, + pub args: Vec, +} diff --git a/core/src/common/foreign_function.rs b/core/src/common/foreign_function.rs index a65a9f0..bcc222b 100644 --- a/core/src/common/foreign_function.rs +++ b/core/src/common/foreign_function.rs @@ -75,10 +75,12 @@ use std::collections::*; use dyn_clone::DynClone; +use super::foreign_functions as ffs; use super::type_family::*; use super::value::*; use super::value_type::*; -use super::foreign_functions as ffs; + +use crate::runtime::env::*; /// A type used for defining a foreign function. /// @@ -237,7 +239,19 @@ pub trait ForeignFunction: DynClone { /// /// We assume that the given arguments obey the type declaration. /// In case error happens, we return `None` as the result. 
- fn execute(&self, args: Vec) -> Option; + #[allow(unused_variables)] + fn execute(&self, args: Vec) -> Option { + panic!( + "[Internal Error] Missing execute function in the foreign function `{}`", + self.name() + ) + } + + /// Execute the function given arguments and a runtime environment + #[allow(unused_variables)] + fn execute_with_env(&self, env: &RuntimeEnvironment, args: Vec) -> Option { + self.execute(args) + } /// Get all the arguments fn arguments(&self) -> Vec<(ArgumentKind, ForeignFunctionParameterType)> { @@ -467,6 +481,10 @@ impl ForeignFunction for DynamicForeignFunction { fn execute(&self, args: Vec) -> Option { self.ff.execute(args) } + + fn execute_with_env(&self, env: &RuntimeEnvironment, args: Vec) -> Option { + self.ff.execute_with_env(env, args) + } } /// Dynamic foreign function registry @@ -497,9 +515,22 @@ impl ForeignFunctionRegistry { // Arithmetic registry.register(ffs::Abs).unwrap(); + registry.register(ffs::Floor).unwrap(); + registry.register(ffs::Ceil).unwrap(); + registry.register(ffs::Exp).unwrap(); + registry.register(ffs::Exp2).unwrap(); + registry.register(ffs::Log).unwrap(); + registry.register(ffs::Log2).unwrap(); + registry.register(ffs::Pow).unwrap(); + registry.register(ffs::Powf).unwrap(); registry.register(ffs::Sin).unwrap(); registry.register(ffs::Cos).unwrap(); registry.register(ffs::Tan).unwrap(); + registry.register(ffs::Asin).unwrap(); + registry.register(ffs::Acos).unwrap(); + registry.register(ffs::Atan).unwrap(); + registry.register(ffs::Atan2).unwrap(); + registry.register(ffs::Sign).unwrap(); // Min/Max registry.register(ffs::Max).unwrap(); @@ -510,6 +541,11 @@ impl ForeignFunctionRegistry { registry.register(ffs::StringLength).unwrap(); registry.register(ffs::StringCharAt).unwrap(); registry.register(ffs::Substring).unwrap(); + registry.register(ffs::Format).unwrap(); + registry.register(ffs::StringUpper).unwrap(); + registry.register(ffs::StringLower).unwrap(); + 
registry.register(ffs::StringIndexOf).unwrap(); + registry.register(ffs::StringTrim).unwrap(); // DateTime operations registry.register(ffs::DateTimeDay).unwrap(); @@ -520,6 +556,9 @@ impl ForeignFunctionRegistry { // Hashing operation registry.register(ffs::Hash).unwrap(); + // Tensor operation + registry.register(ffs::Dot).unwrap(); + registry } @@ -566,9 +605,23 @@ impl<'a> IntoIterator for &'a ForeignFunctionRegistry { pub trait UnaryFloatFunction: Clone { fn name(&self) -> String; - fn execute_f32(&self, arg: f32) -> f32; + #[allow(unused_variables)] + fn execute_f32(&self, arg: f32) -> f32 { + 0.0 + } + + fn execute_f32_partial(&self, arg: f32) -> Option { + Some(self.execute_f32(arg)) + } - fn execute_f64(&self, arg: f64) -> f64; + #[allow(unused_variables)] + fn execute_f64(&self, arg: f64) -> f64 { + 0.0 + } + + fn execute_f64_partial(&self, arg: f64) -> Option { + Some(self.execute_f64(arg)) + } } impl ForeignFunction for F { @@ -600,8 +653,8 @@ impl ForeignFunction for F { fn execute(&self, args: Vec) -> Option { match args[0] { - Value::F32(f) => Some(Value::F32(self.execute_f32(f))), - Value::F64(f) => Some(Value::F64(self.execute_f64(f))), + Value::F32(f) => self.execute_f32_partial(f).map(Value::F32), + Value::F64(f) => self.execute_f64_partial(f).map(Value::F64), _ => panic!("Expect floating point input"), } } diff --git a/core/src/common/foreign_functions/acos.rs b/core/src/common/foreign_functions/acos.rs new file mode 100644 index 0000000..8a05751 --- /dev/null +++ b/core/src/common/foreign_functions/acos.rs @@ -0,0 +1,31 @@ +use super::*; + +/// Arccosine foreign function +/// +/// ``` scl +/// extern fn $acos(x: T) -> T +/// ``` +#[derive(Clone)] +pub struct Acos; + +impl UnaryFloatFunction for Acos { + fn name(&self) -> String { + "acos".to_string() + } + + fn execute_f32_partial(&self, arg: f32) -> Option { + if arg >= -1.0 && arg <= 1.0 { + Some(arg.acos()) + } else { + None + } + } + + fn execute_f64_partial(&self, arg: f64) -> Option { + 
if arg >= -1.0 && arg <= 1.0 { + Some(arg.acos()) + } else { + None + } + } +} diff --git a/core/src/common/foreign_functions/asin.rs b/core/src/common/foreign_functions/asin.rs new file mode 100644 index 0000000..f533044 --- /dev/null +++ b/core/src/common/foreign_functions/asin.rs @@ -0,0 +1,31 @@ +use super::*; + +/// Arcsine foreign function +/// +/// ``` scl +/// extern fn $asin(x: T) -> T +/// ``` +#[derive(Clone)] +pub struct Asin; + +impl UnaryFloatFunction for Asin { + fn name(&self) -> String { + "asin".to_string() + } + + fn execute_f32_partial(&self, arg: f32) -> Option { + if arg >= -1.0 && arg <= 1.0 { + Some(arg.asin()) + } else { + None + } + } + + fn execute_f64_partial(&self, arg: f64) -> Option { + if arg >= -1.0 && arg <= 1.0 { + Some(arg.asin()) + } else { + None + } + } +} diff --git a/core/src/common/foreign_functions/atan.rs b/core/src/common/foreign_functions/atan.rs new file mode 100644 index 0000000..fd2d8d7 --- /dev/null +++ b/core/src/common/foreign_functions/atan.rs @@ -0,0 +1,23 @@ +use super::*; + +/// Arctangent foreign function +/// +/// ``` scl +/// extern fn $atan(x: T) -> T +/// ``` +#[derive(Clone)] +pub struct Atan; + +impl UnaryFloatFunction for Atan { + fn name(&self) -> String { + "atan".to_string() + } + + fn execute_f32(&self, arg: f32) -> f32 { + arg.atan() + } + + fn execute_f64(&self, arg: f64) -> f64 { + arg.atan() + } +} diff --git a/core/src/common/foreign_functions/atan2.rs b/core/src/common/foreign_functions/atan2.rs new file mode 100644 index 0000000..7d408bb --- /dev/null +++ b/core/src/common/foreign_functions/atan2.rs @@ -0,0 +1,55 @@ +use super::*; + +/// Arctangent2 foreign function +/// +/// ``` scl +/// extern fn $atan2(y: T, x: T) -> T +/// ``` +#[derive(Clone)] +pub struct Atan2; + +impl ForeignFunction for Atan2 { + fn name(&self) -> String { + "atan2".to_string() + } + + fn num_generic_types(&self) -> usize { + 1 + } + + fn generic_type_family(&self, i: usize) -> TypeFamily { + match i { + 0 => 
TypeFamily::Float, + _ => panic!("No argument {}", i), + } + } + + fn num_static_arguments(&self) -> usize { + 2 + } + + fn static_argument_type(&self, i: usize) -> ForeignFunctionParameterType { + match i { + 0 | 1 => ForeignFunctionParameterType::Generic(0), + _ => panic!("No argument {}", i), + } + } + + fn return_type(&self) -> ForeignFunctionParameterType { + ForeignFunctionParameterType::Generic(0) + } + + fn execute(&self, args: Vec) -> Option { + match args[0] { + Value::F32(y) => match args[1] { + Value::F32(x) => Some(Value::F32(y.atan2(x))), + _ => panic!("Invalid arguments, should be floats of same bitsize"), + }, + Value::F64(y) => match args[1] { + Value::F64(x) => Some(Value::F64(y.atan2(x))), + _ => panic!("Invalid arguments, should be floats of same bitsize"), + }, + _ => panic!("Invalid arguments, should be floats of same bitsize"), + } + } +} diff --git a/core/src/common/foreign_functions/ceil.rs b/core/src/common/foreign_functions/ceil.rs new file mode 100644 index 0000000..89e9be6 --- /dev/null +++ b/core/src/common/foreign_functions/ceil.rs @@ -0,0 +1,60 @@ +use super::*; + +/// Ceiling foreign function +/// +/// ``` scl +/// extern fn $ceil(x: T) -> T +/// ``` +#[derive(Clone)] +pub struct Ceil; + +impl ForeignFunction for Ceil { + fn name(&self) -> String { + "ceil".to_string() + } + + fn num_generic_types(&self) -> usize { + 1 + } + + fn generic_type_family(&self, i: usize) -> TypeFamily { + assert_eq!(i, 0); + TypeFamily::Number + } + + fn num_static_arguments(&self) -> usize { + 1 + } + + fn static_argument_type(&self, i: usize) -> ForeignFunctionParameterType { + assert_eq!(i, 0); + ForeignFunctionParameterType::Generic(0) + } + + fn return_type(&self) -> ForeignFunctionParameterType { + ForeignFunctionParameterType::Generic(0) + } + + fn execute(&self, args: Vec) -> Option { + match args[0] { + // Integers, directly return + Value::I8(f) => Some(Value::I8(f)), + Value::I16(f) => Some(Value::I16(f)), + Value::I32(f) => 
Some(Value::I32(f)), + Value::I64(f) => Some(Value::I64(f)), + Value::I128(f) => Some(Value::I128(f)), + Value::ISize(f) => Some(Value::ISize(f)), + Value::U8(f) => Some(Value::U8(f)), + Value::U16(f) => Some(Value::U16(f)), + Value::U32(f) => Some(Value::U32(f)), + Value::U64(f) => Some(Value::U64(f)), + Value::U128(f) => Some(Value::U128(f)), + Value::USize(f) => Some(Value::USize(f)), + + // Floating points, take ceiling + Value::F32(f) => Some(Value::F32(f.ceil())), + Value::F64(f) => Some(Value::F64(f.ceil())), + _ => panic!("should not happen; input variable to abs should be a number"), + } + } +} diff --git a/core/src/common/foreign_functions/datetime_month0.rs b/core/src/common/foreign_functions/datetime_month0.rs index 249b638..e6c6532 100644 --- a/core/src/common/foreign_functions/datetime_month0.rs +++ b/core/src/common/foreign_functions/datetime_month0.rs @@ -5,7 +5,7 @@ use super::*; /// Get the month of the year starting from 0 /// /// ``` scl -/// extern fn $datetime_month(d: DateTime) -> u32 +/// extern fn $datetime_month0(d: DateTime) -> u32 /// ``` #[derive(Clone)] pub struct DateTimeMonth0; diff --git a/core/src/common/foreign_functions/dot.rs b/core/src/common/foreign_functions/dot.rs new file mode 100644 index 0000000..7cba8bf --- /dev/null +++ b/core/src/common/foreign_functions/dot.rs @@ -0,0 +1,36 @@ +use super::*; + +/// Dot product of two tensors +/// +/// ``` scl +/// extern fn $dot(x: Tensor, y: Tensor) -> Tensor +/// ``` +#[derive(Clone)] +pub struct Dot; + +impl ForeignFunction for Dot { + fn name(&self) -> String { + "dot".to_string() + } + + fn num_static_arguments(&self) -> usize { + 2 + } + + fn static_argument_type(&self, i: usize) -> ForeignFunctionParameterType { + assert!(i < 2); + ForeignFunctionParameterType::BaseType(ValueType::Tensor) + } + + fn return_type(&self) -> ForeignFunctionParameterType { + ForeignFunctionParameterType::BaseType(ValueType::Tensor) + } + + fn execute(&self, args: Vec) -> Option { + let mut iter = 
args.into_iter(); + match (iter.next().unwrap(), iter.next().unwrap()) { + (Value::TensorValue(t1), Value::TensorValue(t2)) => t1.dot(t2).map(Value::TensorValue), + _ => None, + } + } +} diff --git a/core/src/common/foreign_functions/exp.rs b/core/src/common/foreign_functions/exp.rs new file mode 100644 index 0000000..5d3fc5a --- /dev/null +++ b/core/src/common/foreign_functions/exp.rs @@ -0,0 +1,23 @@ +use super::*; + +/// Exponential foreign function: e^x +/// +/// ``` scl +/// extern fn $exp(x: T) -> T +/// ``` +#[derive(Clone)] +pub struct Exp; + +impl UnaryFloatFunction for Exp { + fn name(&self) -> String { + "exp".to_string() + } + + fn execute_f32(&self, arg: f32) -> f32 { + arg.exp() + } + + fn execute_f64(&self, arg: f64) -> f64 { + arg.exp() + } +} diff --git a/core/src/common/foreign_functions/exp2.rs b/core/src/common/foreign_functions/exp2.rs new file mode 100644 index 0000000..ad222eb --- /dev/null +++ b/core/src/common/foreign_functions/exp2.rs @@ -0,0 +1,23 @@ +use super::*; + +/// Exponential foreign function: 2^x +/// +/// ``` scl +/// extern fn $exp2(x: T) -> T +/// ``` +#[derive(Clone)] +pub struct Exp2; + +impl UnaryFloatFunction for Exp2 { + fn name(&self) -> String { + "exp2".to_string() + } + + fn execute_f32(&self, arg: f32) -> f32 { + arg.exp2() + } + + fn execute_f64(&self, arg: f64) -> f64 { + arg.exp2() + } +} diff --git a/core/src/common/foreign_functions/floor.rs b/core/src/common/foreign_functions/floor.rs new file mode 100644 index 0000000..93e5f36 --- /dev/null +++ b/core/src/common/foreign_functions/floor.rs @@ -0,0 +1,60 @@ +use super::*; + +/// Floor foreign function +/// +/// ``` scl +/// extern fn $floor(x: T) -> T +/// ``` +#[derive(Clone)] +pub struct Floor; + +impl ForeignFunction for Floor { + fn name(&self) -> String { + "floor".to_string() + } + + fn num_generic_types(&self) -> usize { + 1 + } + + fn generic_type_family(&self, i: usize) -> TypeFamily { + assert_eq!(i, 0); + TypeFamily::Number + } + + fn 
num_static_arguments(&self) -> usize { + 1 + } + + fn static_argument_type(&self, i: usize) -> ForeignFunctionParameterType { + assert_eq!(i, 0); + ForeignFunctionParameterType::Generic(0) + } + + fn return_type(&self) -> ForeignFunctionParameterType { + ForeignFunctionParameterType::Generic(0) + } + + fn execute(&self, args: Vec) -> Option { + match args[0] { + // Integers, directly return + Value::I8(f) => Some(Value::I8(f)), + Value::I16(f) => Some(Value::I16(f)), + Value::I32(f) => Some(Value::I32(f)), + Value::I64(f) => Some(Value::I64(f)), + Value::I128(f) => Some(Value::I128(f)), + Value::ISize(f) => Some(Value::ISize(f)), + Value::U8(f) => Some(Value::U8(f)), + Value::U16(f) => Some(Value::U16(f)), + Value::U32(f) => Some(Value::U32(f)), + Value::U64(f) => Some(Value::U64(f)), + Value::U128(f) => Some(Value::U128(f)), + Value::USize(f) => Some(Value::USize(f)), + + // Floating points, take floor + Value::F32(f) => Some(Value::F32(f.floor())), + Value::F64(f) => Some(Value::F64(f.floor())), + _ => panic!("should not happen; input variable to abs should be a number"), + } + } +} diff --git a/core/src/common/foreign_functions/format.rs b/core/src/common/foreign_functions/format.rs new file mode 100644 index 0000000..4837a5a --- /dev/null +++ b/core/src/common/foreign_functions/format.rs @@ -0,0 +1,133 @@ +use super::*; + +use crate::runtime::env::*; + +/// Format foreign function +/// +/// ``` scl +/// extern fn $format(f: String, args: Any...) 
-> String +/// ``` +#[derive(Clone)] +pub struct Format; + +impl ForeignFunction for Format { + fn name(&self) -> String { + "format".to_string() + } + + fn num_generic_types(&self) -> usize { + 0 + } + + fn num_static_arguments(&self) -> usize { + 1 + } + + fn static_argument_type(&self, i: usize) -> ForeignFunctionParameterType { + assert_eq!(i, 0); + ForeignFunctionParameterType::BaseType(ValueType::String) + } + + fn has_variable_arguments(&self) -> bool { + true + } + + fn variable_argument_type(&self) -> ForeignFunctionParameterType { + ForeignFunctionParameterType::TypeFamily(TypeFamily::Any) + } + + fn return_type(&self) -> ForeignFunctionParameterType { + ForeignFunctionParameterType::BaseType(ValueType::String) + } + + fn execute(&self, args: Vec) -> Option { + match &args[0] { + Value::String(f) => { + let split_str: Vec<&str> = f.split("{}").collect(); + let mut result = "".to_string(); + result += split_str[0]; + // boths lens should be # of braces + 1 + assert_eq!(args.len(), split_str.len()); + for i in 1..args.len() { + let s = match &args[i] { + Value::I8(i) => i.to_string(), + Value::I16(i) => i.to_string(), + Value::I32(i) => i.to_string(), + Value::I64(i) => i.to_string(), + Value::I128(i) => i.to_string(), + Value::ISize(i) => i.to_string(), + Value::U8(i) => i.to_string(), + Value::U16(i) => i.to_string(), + Value::U32(i) => i.to_string(), + Value::U64(i) => i.to_string(), + Value::U128(i) => i.to_string(), + Value::USize(i) => i.to_string(), + Value::F32(i) => i.to_string(), + Value::F64(i) => i.to_string(), + Value::Char(i) => i.to_string(), + Value::Bool(i) => i.to_string(), + Value::Str(i) => i.to_string(), + Value::String(i) => i.to_string(), + Value::Symbol(_) => panic!("[Internal Error] Symbol should not be processed"), + Value::SymbolString(s) => s.to_string(), + Value::DateTime(i) => i.to_string(), + Value::Duration(i) => i.to_string(), + Value::Entity(e) => format!("entity({e:#x})"), + Value::Tensor(_) | Value::TensorValue(_) => 
"tensor".to_string(), + }; + result += s.as_str(); + result += split_str[i]; + } + Some(Value::from(result)) + } + _ => panic!("Format argument not a string"), + } + } + + fn execute_with_env(&self, env: &RuntimeEnvironment, args: Vec) -> Option { + match &args[0] { + Value::String(f) => { + let split_str: Vec<&str> = f.split("{}").collect(); + let mut result = "".to_string(); + result += split_str[0]; + // boths lens should be # of braces + 1 + assert_eq!(args.len(), split_str.len()); + for i in 1..args.len() { + let s = match &args[i] { + Value::I8(i) => i.to_string(), + Value::I16(i) => i.to_string(), + Value::I32(i) => i.to_string(), + Value::I64(i) => i.to_string(), + Value::I128(i) => i.to_string(), + Value::ISize(i) => i.to_string(), + Value::U8(i) => i.to_string(), + Value::U16(i) => i.to_string(), + Value::U32(i) => i.to_string(), + Value::U64(i) => i.to_string(), + Value::U128(i) => i.to_string(), + Value::USize(i) => i.to_string(), + Value::F32(i) => i.to_string(), + Value::F64(i) => i.to_string(), + Value::Char(i) => i.to_string(), + Value::Bool(i) => i.to_string(), + Value::Str(i) => i.to_string(), + Value::String(i) => i.to_string(), + Value::Symbol(i) => env + .symbol_registry + .get_symbol(*i) + .expect("[Internal Error] Cannot find symbol"), + Value::SymbolString(_) => panic!("[Internal Error] SymbolString should not be processed"), + Value::DateTime(i) => i.to_string(), + Value::Duration(i) => i.to_string(), + Value::Entity(e) => format!("entity({e:#x})"), + Value::Tensor(_) | Value::TensorValue(_) => "tensor".to_string(), + }; + result += s.as_str(); + result += split_str[i]; + } + Some(Value::from(result)) + } + _ => panic!("Format argument not a string"), + } + } +} diff --git a/core/src/common/foreign_functions/log.rs b/core/src/common/foreign_functions/log.rs new file mode 100644 index 0000000..b1bc55f --- /dev/null +++ b/core/src/common/foreign_functions/log.rs @@ -0,0 +1,31 @@ +use super::*; + +/// Log (base e) foreign function +/// +/// 
``` scl +/// extern fn $log(x: T) -> T +/// ``` +#[derive(Clone)] +pub struct Log; + +impl UnaryFloatFunction for Log { + fn name(&self) -> String { + "log".to_string() + } + + fn execute_f32_partial(&self, arg: f32) -> Option { + if arg > 0.0 { + Some(arg.ln()) + } else { + None + } + } + + fn execute_f64_partial(&self, arg: f64) -> Option { + if arg > 0.0 { + Some(arg.ln()) + } else { + None + } + } +} diff --git a/core/src/common/foreign_functions/log2.rs b/core/src/common/foreign_functions/log2.rs new file mode 100644 index 0000000..248ecc5 --- /dev/null +++ b/core/src/common/foreign_functions/log2.rs @@ -0,0 +1,31 @@ +use super::*; + +/// Log (base 2) foreign function +/// +/// ``` scl +/// extern fn $log2(x: T) -> T +/// ``` +#[derive(Clone)] +pub struct Log2; + +impl UnaryFloatFunction for Log2 { + fn name(&self) -> String { + "log2".to_string() + } + + fn execute_f32_partial(&self, arg: f32) -> Option { + if arg > 0.0 { + Some(arg.log2()) + } else { + None + } + } + + fn execute_f64_partial(&self, arg: f64) -> Option { + if arg > 0.0 { + Some(arg.log2()) + } else { + None + } + } +} diff --git a/core/src/common/foreign_functions/max.rs b/core/src/common/foreign_functions/max.rs index e39b911..75cf64a 100644 --- a/core/src/common/foreign_functions/max.rs +++ b/core/src/common/foreign_functions/max.rs @@ -9,7 +9,10 @@ use super::*; pub struct Max; impl Max { - fn dyn_max(args: Vec) -> Option where Value: TryInto { + fn dyn_max(args: Vec) -> Option + where + Value: TryInto, + { let mut iter = args.into_iter(); let mut curr_max: T = iter.next()?.try_into().ok()?; while let Some(next_elem) = iter.next() { diff --git a/core/src/common/foreign_functions/min.rs b/core/src/common/foreign_functions/min.rs index b3684e9..df5e5cf 100644 --- a/core/src/common/foreign_functions/min.rs +++ b/core/src/common/foreign_functions/min.rs @@ -9,7 +9,10 @@ use super::*; pub struct Min; impl Min { - fn dyn_min(args: Vec) -> Option where Value: TryInto { + fn dyn_min(args: Vec) -> 
Option + where + Value: TryInto, + { let mut iter = args.into_iter(); let mut curr_min: T = iter.next()?.try_into().ok()?; while let Some(next_elem) = iter.next() { diff --git a/core/src/common/foreign_functions/mod.rs b/core/src/common/foreign_functions/mod.rs index f897202..5d103b7 100644 --- a/core/src/common/foreign_functions/mod.rs +++ b/core/src/common/foreign_functions/mod.rs @@ -1,40 +1,78 @@ //! A library of foreign functions +use super::foreign_function::*; +use super::type_family::*; use super::value::*; use super::value_type::*; -use super::type_family::*; -use super::foreign_function::*; use std::convert::*; mod abs; +mod acos; +mod asin; +mod atan; +mod atan2; +mod ceil; mod cos; mod datetime_day; mod datetime_month; mod datetime_month0; mod datetime_year; +mod dot; +mod exp; +mod exp2; +mod floor; +mod format; mod hash; +mod log; +mod log2; mod max; mod min; +mod pow; +mod powf; +mod sign; mod sin; mod string_char_at; mod string_concat; +mod string_index_of; mod string_length; +mod string_lower; +mod string_trim; +mod string_upper; mod substring; mod tan; pub use abs::*; +pub use acos::*; +pub use asin::*; +pub use atan::*; +pub use atan2::*; +pub use ceil::*; pub use cos::*; pub use datetime_day::*; pub use datetime_month::*; pub use datetime_month0::*; pub use datetime_year::*; +pub use dot::*; +pub use exp::*; +pub use exp2::*; +pub use floor::*; +pub use format::*; pub use hash::*; +pub use log::*; +pub use log2::*; pub use max::*; pub use min::*; +pub use pow::*; +pub use powf::*; +pub use sign::*; pub use sin::*; pub use string_char_at::*; pub use string_concat::*; +pub use string_index_of::*; pub use string_length::*; +pub use string_lower::*; +pub use string_trim::*; +pub use string_upper::*; pub use substring::*; pub use tan::*; diff --git a/core/src/common/foreign_functions/pow.rs b/core/src/common/foreign_functions/pow.rs new file mode 100644 index 0000000..c540334 --- /dev/null +++ b/core/src/common/foreign_functions/pow.rs @@ -0,0 +1,58 
@@ +use super::*; + +/// Power foreign function (x^y) +/// +/// ``` scl +/// extern fn $pow(x: T, y: u32) -> T +/// ``` +#[derive(Clone)] +pub struct Pow; + +impl ForeignFunction for Pow { + fn name(&self) -> String { + "pow".to_string() + } + + fn num_generic_types(&self) -> usize { + 1 + } + + fn generic_type_family(&self, i: usize) -> TypeFamily { + assert_eq!(i, 0); + TypeFamily::Integer + } + + fn num_static_arguments(&self) -> usize { + 2 + } + + fn static_argument_type(&self, i: usize) -> ForeignFunctionParameterType { + match i { + 0 => ForeignFunctionParameterType::Generic(0), + 1 => ForeignFunctionParameterType::BaseType(ValueType::U32), + _ => panic!("No argument {}", i), + } + } + + fn return_type(&self) -> ForeignFunctionParameterType { + ForeignFunctionParameterType::Generic(0) + } + + fn execute(&self, args: Vec) -> Option { + match (&args[0], &args[1]) { + (Value::I8(x), Value::U32(y)) => Some(Value::I8(x.pow(*y))), + (Value::I16(x), Value::U32(y)) => Some(Value::I16(x.pow(*y))), + (Value::I32(x), Value::U32(y)) => Some(Value::I32(x.pow(*y))), + (Value::I64(x), Value::U32(y)) => Some(Value::I64(x.pow(*y))), + (Value::I128(x), Value::U32(y)) => Some(Value::I128(x.pow(*y))), + (Value::ISize(x), Value::U32(y)) => Some(Value::ISize(x.pow(*y))), + (Value::U8(x), Value::U32(y)) => Some(Value::U8(x.pow(*y))), + (Value::U16(x), Value::U32(y)) => Some(Value::U16(x.pow(*y))), + (Value::U32(x), Value::U32(y)) => Some(Value::U32(x.pow(*y))), + (Value::U64(x), Value::U32(y)) => Some(Value::U64(x.pow(*y))), + (Value::U128(x), Value::U32(y)) => Some(Value::U128(x.pow(*y))), + (Value::USize(x), Value::U32(y)) => Some(Value::USize(x.pow(*y))), + _ => panic!("Invalid arguments"), + } + } +} diff --git a/core/src/common/foreign_functions/powf.rs b/core/src/common/foreign_functions/powf.rs new file mode 100644 index 0000000..0713b75 --- /dev/null +++ b/core/src/common/foreign_functions/powf.rs @@ -0,0 +1,55 @@ +use super::*; + +/// Powf foreign function (x^y) +/// +/// 
``` scl +/// extern fn $powf(x: T, y: T) -> T +/// ``` +#[derive(Clone)] +pub struct Powf; + +impl ForeignFunction for Powf { + fn name(&self) -> String { + "powf".to_string() + } + + fn num_generic_types(&self) -> usize { + 1 + } + + fn generic_type_family(&self, i: usize) -> TypeFamily { + match i { + 0 => TypeFamily::Float, + _ => panic!("No argument {}", i), + } + } + + fn num_static_arguments(&self) -> usize { + 2 + } + + fn static_argument_type(&self, i: usize) -> ForeignFunctionParameterType { + match i { + 0 | 1 => ForeignFunctionParameterType::Generic(0), + _ => panic!("No argument {}", i), + } + } + + fn return_type(&self) -> ForeignFunctionParameterType { + ForeignFunctionParameterType::Generic(0) + } + + fn execute(&self, args: Vec) -> Option { + match args[0] { + Value::F32(x) => match args[1] { + Value::F32(y) => Some(Value::F32(x.powf(y))), + _ => panic!("Invalid arguments, should be floats of same bitsize"), + }, + Value::F64(x) => match args[1] { + Value::F64(y) => Some(Value::F64(x.powf(y))), + _ => panic!("Invalid arguments, should be floats of same bitsize"), + }, + _ => panic!("Invalid arguments, should be floats of same bitsize"), + } + } +} diff --git a/core/src/common/foreign_functions/sign.rs b/core/src/common/foreign_functions/sign.rs new file mode 100644 index 0000000..5402a65 --- /dev/null +++ b/core/src/common/foreign_functions/sign.rs @@ -0,0 +1,76 @@ +use super::*; + +/// Sign foreign function +/// +/// ``` scl +/// extern fn $sign(x: T) -> T +/// ``` +#[derive(Clone)] +pub struct Sign; + +impl ForeignFunction for Sign { + fn name(&self) -> String { + "sign".to_string() + } + + fn num_generic_types(&self) -> usize { + 1 + } + + fn generic_type_family(&self, i: usize) -> TypeFamily { + assert_eq!(i, 0); + TypeFamily::Number + } + + fn num_static_arguments(&self) -> usize { + 1 + } + + fn static_argument_type(&self, i: usize) -> ForeignFunctionParameterType { + assert_eq!(i, 0); + ForeignFunctionParameterType::Generic(0) + } + + fn 
return_type(&self) -> ForeignFunctionParameterType { + ForeignFunctionParameterType::Generic(0) + } + + fn execute(&self, args: Vec) -> Option { + match args[0] { + Value::I8(f) => Some(Value::I8(f.signum())), + Value::I16(f) => Some(Value::I16(f.signum())), + Value::I32(f) => Some(Value::I32(f.signum())), + Value::I64(f) => Some(Value::I64(f.signum())), + Value::I128(f) => Some(Value::I128(f.signum())), + Value::ISize(f) => Some(Value::ISize(f.signum())), + + Value::U8(_) => Some(Value::U8(1)), + Value::U16(_) => Some(Value::U16(1)), + Value::U32(_) => Some(Value::U32(1)), + Value::U64(_) => Some(Value::U64(1)), + Value::U128(_) => Some(Value::U128(1)), + Value::USize(_) => Some(Value::USize(1)), + + Value::F32(f) => { + if f < 0.0 { + Some(Value::I32(-1)) + } else if f > 0.0 { + Some(Value::I32(1)) + } else { + Some(Value::I32(0)) + } + } + Value::F64(f) => { + if f < 0.0 { + Some(Value::I32(-1)) + } else if f > 0.0 { + Some(Value::I32(1)) + } else { + Some(Value::I32(0)) + } + } + + _ => panic!("should not happen; input variable to sign should be a number"), + } + } +} diff --git a/core/src/common/foreign_functions/string_index_of.rs b/core/src/common/foreign_functions/string_index_of.rs new file mode 100644 index 0000000..5f553c3 --- /dev/null +++ b/core/src/common/foreign_functions/string_index_of.rs @@ -0,0 +1,43 @@ +use super::*; + +/// String index_of foreign function +/// +/// ``` scl +/// extern fn $string_index_of(s: String, sub: String) -> usize +/// ``` +#[derive(Clone)] +pub struct StringIndexOf; + +impl ForeignFunction for StringIndexOf { + fn name(&self) -> String { + "string_index_of".to_string() + } + + fn num_static_arguments(&self) -> usize { + 2 + } + + fn static_argument_type(&self, i: usize) -> ForeignFunctionParameterType { + match i { + 0 | 1 => ForeignFunctionParameterType::BaseType(ValueType::String), + _ => panic!("Invalid {}-th argument", i), + } + } + + fn return_type(&self) -> ForeignFunctionParameterType { + 
ForeignFunctionParameterType::BaseType(ValueType::USize) + } + + fn execute(&self, args: Vec) -> Option { + assert_eq!(args.len(), 2); + match (&args[0], &args[1]) { + (Value::String(s), Value::String(sub)) => { + match s.find(sub) { + Some(index) => Some(Value::USize(index)), + None => None, // return None if no match found + } + } + _ => panic!("Invalid arguments, expected strings"), + } + } +} diff --git a/core/src/common/foreign_functions/string_lower.rs b/core/src/common/foreign_functions/string_lower.rs new file mode 100644 index 0000000..2011be9 --- /dev/null +++ b/core/src/common/foreign_functions/string_lower.rs @@ -0,0 +1,36 @@ +use super::*; + +/// String lower foreign function +/// +/// ``` scl +/// extern fn $string_lower(s: String) -> String +/// ``` +#[derive(Clone)] +pub struct StringLower; + +impl ForeignFunction for StringLower { + fn name(&self) -> String { + "string_lower".to_string() + } + + fn num_static_arguments(&self) -> usize { + 1 + } + + fn static_argument_type(&self, i: usize) -> ForeignFunctionParameterType { + assert_eq!(i, 0); + ForeignFunctionParameterType::BaseType(ValueType::String) + } + + fn return_type(&self) -> ForeignFunctionParameterType { + ForeignFunctionParameterType::BaseType(ValueType::String) + } + + fn execute(&self, args: Vec) -> Option { + assert_eq!(args.len(), 1); + match &args[0] { + Value::String(s) => Some(Value::from(s.to_ascii_lowercase())), + _ => panic!("Invalid argument, expected string"), + } + } +} diff --git a/core/src/common/foreign_functions/string_trim.rs b/core/src/common/foreign_functions/string_trim.rs new file mode 100644 index 0000000..1c4b263 --- /dev/null +++ b/core/src/common/foreign_functions/string_trim.rs @@ -0,0 +1,36 @@ +use super::*; + +/// String trim foreign function +/// +/// ``` scl +/// extern fn $string_trim(s: String) -> String +/// ``` +#[derive(Clone)] +pub struct StringTrim; + +impl ForeignFunction for StringTrim { + fn name(&self) -> String { + "string_trim".to_string() + } + 
+ fn num_static_arguments(&self) -> usize { + 1 + } + + fn static_argument_type(&self, i: usize) -> ForeignFunctionParameterType { + assert_eq!(i, 0); + ForeignFunctionParameterType::BaseType(ValueType::String) + } + + fn return_type(&self) -> ForeignFunctionParameterType { + ForeignFunctionParameterType::BaseType(ValueType::String) + } + + fn execute(&self, args: Vec) -> Option { + assert_eq!(args.len(), 1); + match &args[0] { + Value::String(s) => Some(Value::from(s.trim().to_string())), + _ => panic!("Invalid argument, expected string"), + } + } +} diff --git a/core/src/common/foreign_functions/string_upper.rs b/core/src/common/foreign_functions/string_upper.rs new file mode 100644 index 0000000..610422b --- /dev/null +++ b/core/src/common/foreign_functions/string_upper.rs @@ -0,0 +1,36 @@ +use super::*; + +/// String upper foreign function +/// +/// ``` scl +/// extern fn $string_upper(s: String) -> String +/// ``` +#[derive(Clone)] +pub struct StringUpper; + +impl ForeignFunction for StringUpper { + fn name(&self) -> String { + "string_upper".to_string() + } + + fn num_static_arguments(&self) -> usize { + 1 + } + + fn static_argument_type(&self, i: usize) -> ForeignFunctionParameterType { + assert_eq!(i, 0); + ForeignFunctionParameterType::BaseType(ValueType::String) + } + + fn return_type(&self) -> ForeignFunctionParameterType { + ForeignFunctionParameterType::BaseType(ValueType::String) + } + + fn execute(&self, args: Vec) -> Option { + assert_eq!(args.len(), 1); + match &args[0] { + Value::String(s) => Some(Value::from(s.to_ascii_uppercase())), + _ => panic!("Invalid argument, expected string"), + } + } +} diff --git a/core/src/common/foreign_predicate.rs b/core/src/common/foreign_predicate.rs index 3101e12..c2f8607 100644 --- a/core/src/common/foreign_predicate.rs +++ b/core/src/common/foreign_predicate.rs @@ -2,10 +2,12 @@ use std::collections::*; use dyn_clone::*; +use crate::runtime::env::RuntimeEnvironment; + +use super::foreign_predicates as fps; +use 
super::input_tag::*; use super::value::*; use super::value_type::*; -use super::input_tag::*; -use super::foreign_predicates as fps; #[derive(Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)] pub enum Binding { @@ -52,7 +54,12 @@ impl std::fmt::Display for ForeignPredicateIdentifier { "pred {}[{}]({})", self.identifier, self.binding_pattern, - self.types.iter().map(|t| format!("{}", t)).collect::>().join(", ") + self + .types + .iter() + .map(|t| format!("{}", t)) + .collect::>() + .join(", ") )) } } @@ -77,10 +84,9 @@ impl BindingPattern { pub fn new(arity: usize, num_bounded: usize) -> Self { assert!(num_bounded <= arity); Self { - pattern: (0..arity).map(|i| { - if i < num_bounded { Binding::Bound } - else { Binding::Free } - }).collect() + pattern: (0..arity) + .map(|i| if i < num_bounded { Binding::Bound } else { Binding::Free }) + .collect(), } } @@ -119,6 +125,29 @@ pub trait ForeignPredicate: DynClone { /// The name of the predicate fn name(&self) -> String; + /// Generic type parameters + fn generic_type_parameters(&self) -> Vec { + vec![] + } + + fn internal_name(&self) -> String { + let name = self.name(); + let type_params = self.generic_type_parameters(); + if type_params.len() > 0 { + format!( + "{}#{}", + name, + type_params + .into_iter() + .map(|t| t.to_string()) + .collect::>() + .join("#") + ) + } else { + name.to_string() + } + } + /// The arity of the predicate (i.e. number of arguments) fn arity(&self) -> usize; @@ -149,7 +178,19 @@ pub trait ForeignPredicate: DynClone { /// /// The `bounded` tuple (`Vec`) should have arity (length) `self.num_bounded()`. 
/// The function returns a sequence of (dynamically) tagged-tuples where the arity is `self.num_free()` - fn evaluate(&self, bounded: &[Value]) -> Vec<(DynamicInputTag, Vec)>; + #[allow(unused_variables)] + fn evaluate(&self, bounded: &[Value]) -> Vec<(DynamicInputTag, Vec)> { + panic!( + "[Internal Error] Missing evaluate function in the foreign predicate `{}`", + self.name() + ) + } + + /// Evaluate the foreign predicate given a tuple containing bounded variables and an environment + #[allow(unused_variables)] + fn evaluate_with_env(&self, env: &RuntimeEnvironment, bounded: &[Value]) -> Vec<(DynamicInputTag, Vec)> { + self.evaluate(bounded) + } } /// The dynamic foreign predicate @@ -176,6 +217,10 @@ impl ForeignPredicate for DynamicForeignPredicate { self.fp.name() } + fn generic_type_parameters(&self) -> Vec { + self.fp.generic_type_parameters() + } + fn arity(&self) -> usize { self.fp.arity() } @@ -191,14 +236,20 @@ impl ForeignPredicate for DynamicForeignPredicate { fn evaluate(&self, bounded: &[Value]) -> Vec<(DynamicInputTag, Vec)> { self.fp.evaluate(bounded) } + + fn evaluate_with_env(&self, env: &RuntimeEnvironment, bounded: &[Value]) -> Vec<(DynamicInputTag, Vec)> { + self.fp.evaluate_with_env(env, bounded) + } } impl std::fmt::Debug for DynamicForeignPredicate { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f - .debug_struct("ForeignPredicate") + f.debug_struct("ForeignPredicate") .field("name", &self.name()) - .field("types", &(0..self.arity()).map(|i| self.argument_type(i)).collect::>()) + .field( + "types", + &(0..self.arity()).map(|i| self.argument_type(i)).collect::>(), + ) .field("num_bounded", &self.num_bounded()) .finish() } @@ -232,7 +283,7 @@ impl ForeignPredicateRegistry { /// Create an empty foreign predicate registry pub fn new() -> Self { Self { - registry: HashMap::new() + registry: HashMap::new(), } } @@ -256,15 +307,23 @@ impl ForeignPredicateRegistry { 
reg.register(fps::SoftNumberLt::new(value_type.clone())).unwrap(); } + // Register soft eq on tensors + reg.register(fps::SoftNumberEq::new(ValueType::Tensor)).unwrap(); + // String operations reg.register(fps::StringCharsBFF::new()).unwrap(); + reg.register(fps::StringFindBBFF::new()).unwrap(); + reg.register(fps::StringSplitBBF::new()).unwrap(); + + // DateTime + reg.register(fps::DateTimeYMD::new()).unwrap(); reg } /// Register a new foreign predicate in the registry pub fn register(&mut self, p: P) -> Result<(), ForeignPredicateError> { - let id = p.name(); + let id = p.internal_name(); if self.contains(&id) { Err(ForeignPredicateError::AlreadyExisted { id: format!("{}", id) }) } else { diff --git a/core/src/common/foreign_predicates/datetime_ymd.rs b/core/src/common/foreign_predicates/datetime_ymd.rs new file mode 100644 index 0000000..2bb3444 --- /dev/null +++ b/core/src/common/foreign_predicates/datetime_ymd.rs @@ -0,0 +1,58 @@ +use super::*; +use chrono::Datelike; + +/// DateTime YMD extraction foreign predicate +/// +/// ``` scl +/// extern pred datetime_ymd(d: DateTime, year: i32, month: u32, date: u32)[bfff] +/// ``` +#[derive(Clone)] +pub struct DateTimeYMD; + +impl Default for DateTimeYMD { + fn default() -> Self { + Self + } +} + +impl DateTimeYMD { + pub fn new() -> Self { + Self + } +} + +impl ForeignPredicate for DateTimeYMD { + fn name(&self) -> String { + "datetime_ymd".to_string() + } + + fn arity(&self) -> usize { + 4 + } + + fn argument_type(&self, i: usize) -> ValueType { + match i { + 0 => ValueType::DateTime, + 1 => ValueType::I32, + 2 | 3 => ValueType::U32, + _ => panic!("Invalid argument ID `{}`", i), + } + } + + fn num_bounded(&self) -> usize { + 1 + } + + fn evaluate(&self, bounded: &[Value]) -> Vec<(DynamicInputTag, Vec)> { + assert_eq!(bounded.len(), 1); + match &bounded[0] { + Value::DateTime(dt) => { + vec![( + DynamicInputTag::None, + vec![Value::from(dt.year()), Value::from(dt.month()), Value::from(dt.day())], + )] + } + _ => 
panic!("Bounded argument is not a DateTime instance"), + } + } +} diff --git a/core/src/common/foreign_predicates/float_eq.rs b/core/src/common/foreign_predicates/float_eq.rs index 27dcc9f..f6710f3 100644 --- a/core/src/common/foreign_predicates/float_eq.rs +++ b/core/src/common/foreign_predicates/float_eq.rs @@ -14,23 +14,21 @@ pub struct FloatEq { impl FloatEq { pub fn new(ty: ValueType) -> Self { assert!(ty.is_float()); - Self { - ty, - threshold: 0.001, - } + Self { ty, threshold: 0.001 } } pub fn new_with_threshold(ty: ValueType, threshold: f64) -> Self { - Self { - ty, - threshold, - } + Self { ty, threshold } } } impl ForeignPredicate for FloatEq { fn name(&self) -> String { - format!("float_eq_{}", self.ty) + "float_eq".to_string() + } + + fn generic_type_parameters(&self) -> Vec { + vec![self.ty.clone()] } fn arity(&self) -> usize { @@ -53,10 +51,10 @@ impl ForeignPredicate for FloatEq { match (&self.ty, lhs, rhs) { (ValueType::F32, Value::F32(l), Value::F32(r)) if (l - r).abs() < (self.threshold as f32) => { vec![(DynamicInputTag::None, vec![])] - }, + } (ValueType::F64, Value::F64(l), Value::F64(r)) if (l - r).abs() < self.threshold => { vec![(DynamicInputTag::None, vec![])] - }, + } _ => vec![], } } diff --git a/core/src/common/foreign_predicates/mod.rs b/core/src/common/foreign_predicates/mod.rs index 4183c05..0482fba 100644 --- a/core/src/common/foreign_predicates/mod.rs +++ b/core/src/common/foreign_predicates/mod.rs @@ -4,11 +4,12 @@ use std::convert::*; use crate::utils::*; -use super::input_tag::*; use super::foreign_predicate::*; +use super::input_tag::*; use super::value::*; use super::value_type::*; +mod datetime_ymd; mod float_eq; mod range; mod soft_cmp; @@ -17,7 +18,10 @@ mod soft_gt; mod soft_lt; mod soft_neq; mod string_chars; +mod string_find; +mod string_split; +pub use datetime_ymd::*; pub use float_eq::*; pub use range::*; pub use soft_cmp::*; @@ -26,3 +30,5 @@ pub use soft_gt::*; pub use soft_lt::*; pub use soft_neq::*; pub use 
string_chars::*; +pub use string_find::*; +pub use string_split::*; diff --git a/core/src/common/foreign_predicates/range.rs b/core/src/common/foreign_predicates/range.rs index f890f6a..5341293 100644 --- a/core/src/common/foreign_predicates/range.rs +++ b/core/src/common/foreign_predicates/range.rs @@ -21,7 +21,10 @@ impl RangeBBF { } /// Compute the numbers between - fn range(begin: &Value, end: &Value) -> impl Iterator where Value: TryInto { + fn range(begin: &Value, end: &Value) -> impl Iterator + where + Value: TryInto, + { pub struct StepIterator { curr: T, end: T, @@ -49,14 +52,23 @@ impl RangeBBF { StepIterator { curr: begin, end } } - fn dyn_range>(begin: &Value, end: &Value) -> Vec<(DynamicInputTag, Vec)> where Value: TryInto { - Self::range::(begin, end).map(|i| (DynamicInputTag::None, vec![i.into()])).collect() + fn dyn_range>(begin: &Value, end: &Value) -> Vec<(DynamicInputTag, Vec)> + where + Value: TryInto, + { + Self::range::(begin, end) + .map(|i| (DynamicInputTag::None, vec![i.into()])) + .collect() } } impl ForeignPredicate for RangeBBF { fn name(&self) -> String { - format!("range_{}", self.ty) + "range".to_string() + } + + fn generic_type_parameters(&self) -> Vec { + vec![self.ty.clone()] } fn arity(&self) -> usize { diff --git a/core/src/common/foreign_predicates/soft_eq.rs b/core/src/common/foreign_predicates/soft_eq.rs index e4eb1da..43caa6a 100644 --- a/core/src/common/foreign_predicates/soft_eq.rs +++ b/core/src/common/foreign_predicates/soft_eq.rs @@ -1,5 +1,7 @@ //! 
The soft equality predicate +use crate::runtime::env::*; + use super::*; /// Soft EQ foreign predicate @@ -27,10 +29,7 @@ impl SoftNumberEq { } pub fn new_with_sigmoid_fn(ty: ValueType, sigmoid: SigmoidFunction) -> Self { - Self { - ty, - sigmoid, - } + Self { ty, sigmoid } } fn soft_eq(&self, lhs: &Value, rhs: &Value) -> Option @@ -55,11 +54,47 @@ impl SoftNumberEq { vec![] } } + + #[cfg(feature = "torch-tensor")] + fn soft_eq_tensor_wrapper( + &self, + env: &RuntimeEnvironment, + lhs: &Value, + rhs: &Value, + ) -> Vec<(DynamicInputTag, Vec)> { + use crate::common::tensors::*; + + match (lhs, rhs) { + (Value::TensorValue(tv1), Value::TensorValue(tv2)) => { + let t1 = env.tensor_registry.eval(tv1).tensor; + let t2 = env.tensor_registry.eval(tv2).tensor; + let r = t1.dot(&t2).sigmoid(); + let tag = DynamicInputTag::Tensor(Tensor::new(r)); + vec![(tag, vec![])] + } + _ => panic!("Input are not tensors"), + } + } + + #[allow(unused)] + #[cfg(not(feature = "torch-tensor"))] + fn soft_eq_tensor_wrapper( + &self, + env: &RuntimeEnvironment, + lhs: &Value, + rhs: &Value, + ) -> Vec<(DynamicInputTag, Vec)> { + vec![] + } } impl ForeignPredicate for SoftNumberEq { fn name(&self) -> String { - format!("soft_eq_{}", self.ty) + "soft_eq".to_string() + } + + fn generic_type_parameters(&self) -> Vec { + vec![self.ty.clone()] } fn arity(&self) -> usize { @@ -75,7 +110,7 @@ impl ForeignPredicate for SoftNumberEq { 2 } - fn evaluate(&self, bounded: &[Value]) -> Vec<(DynamicInputTag, Vec)> { + fn evaluate_with_env(&self, env: &RuntimeEnvironment, bounded: &[Value]) -> Vec<(DynamicInputTag, Vec)> { assert_eq!(bounded.len(), 2); let lhs = &bounded[0]; let rhs = &bounded[1]; @@ -83,17 +118,12 @@ impl ForeignPredicate for SoftNumberEq { ValueType::I8 => self.soft_eq_wrapper::(lhs, rhs), ValueType::I16 => self.soft_eq_wrapper::(lhs, rhs), ValueType::I32 => self.soft_eq_wrapper::(lhs, rhs), - // ValueType::I64 => self.soft_gt_wrapper::(lhs, rhs), - // ValueType::I128 => 
self.soft_gt_wrapper::(lhs, rhs), - // ValueType::ISize => self.soft_gt_wrapper::(lhs, rhs), ValueType::U8 => self.soft_eq_wrapper::(lhs, rhs), ValueType::U16 => self.soft_eq_wrapper::(lhs, rhs), ValueType::U32 => self.soft_eq_wrapper::(lhs, rhs), - // ValueType::U64 => self.soft_gt_wrapper::(lhs, rhs), - // ValueType::U128 => self.soft_gt_wrapper::(lhs, rhs), - // ValueType::USize => self.soft_gt_wrapper::(lhs, rhs), ValueType::F32 => self.soft_eq_wrapper::(lhs, rhs), ValueType::F64 => self.soft_eq_wrapper::(lhs, rhs), + ValueType::Tensor => self.soft_eq_tensor_wrapper(env, lhs, rhs), _ => vec![], } } diff --git a/core/src/common/foreign_predicates/soft_gt.rs b/core/src/common/foreign_predicates/soft_gt.rs index 9a72fcf..c7bb27e 100644 --- a/core/src/common/foreign_predicates/soft_gt.rs +++ b/core/src/common/foreign_predicates/soft_gt.rs @@ -25,10 +25,7 @@ impl SoftNumberGt { } pub fn new_with_sigmoid_fn(ty: ValueType, sigmoid: SigmoidFunction) -> Self { - Self { - ty, - sigmoid, - } + Self { ty, sigmoid } } fn soft_gt(&self, lhs: &Value, rhs: &Value) -> Option @@ -57,7 +54,11 @@ impl SoftNumberGt { impl ForeignPredicate for SoftNumberGt { fn name(&self) -> String { - format!("soft_gt_{}", self.ty) + "soft_gt".to_string() + } + + fn generic_type_parameters(&self) -> Vec { + vec![self.ty.clone()] } fn arity(&self) -> usize { diff --git a/core/src/common/foreign_predicates/soft_lt.rs b/core/src/common/foreign_predicates/soft_lt.rs index e2b9363..a9c85ed 100644 --- a/core/src/common/foreign_predicates/soft_lt.rs +++ b/core/src/common/foreign_predicates/soft_lt.rs @@ -20,15 +20,12 @@ impl SoftNumberLt { pub fn new(ty: ValueType) -> Self { Self { ty, - sigmoid: SigmoidFunction::default() + sigmoid: SigmoidFunction::default(), } } pub fn new_with_sigmoid_fn(ty: ValueType, sigmoid: SigmoidFunction) -> Self { - Self { - ty, - sigmoid, - } + Self { ty, sigmoid } } fn soft_lt(&self, lhs: &Value, rhs: &Value) -> Option @@ -57,7 +54,11 @@ impl SoftNumberLt { impl 
ForeignPredicate for SoftNumberLt { fn name(&self) -> String { - format!("soft_lt_{}", self.ty) + "soft_lt".to_string() + } + + fn generic_type_parameters(&self) -> Vec { + vec![self.ty.clone()] } fn arity(&self) -> usize { diff --git a/core/src/common/foreign_predicates/soft_neq.rs b/core/src/common/foreign_predicates/soft_neq.rs index 9593d45..a3b5390 100644 --- a/core/src/common/foreign_predicates/soft_neq.rs +++ b/core/src/common/foreign_predicates/soft_neq.rs @@ -27,10 +27,7 @@ impl SoftNumberNeq { } pub fn new_with_sigmoid_fn(ty: ValueType, sigmoid: SigmoidFunction) -> Self { - Self { - ty, - sigmoid, - } + Self { ty, sigmoid } } fn soft_neq(&self, lhs: &Value, rhs: &Value) -> Option @@ -59,7 +56,11 @@ impl SoftNumberNeq { impl ForeignPredicate for SoftNumberNeq { fn name(&self) -> String { - format!("soft_neq_{}", self.ty) + "soft_neq".to_string() + } + + fn generic_type_parameters(&self) -> Vec { + vec![self.ty.clone()] } fn arity(&self) -> usize { diff --git a/core/src/common/foreign_predicates/string_chars.rs b/core/src/common/foreign_predicates/string_chars.rs index 5f726b4..33431f9 100644 --- a/core/src/common/foreign_predicates/string_chars.rs +++ b/core/src/common/foreign_predicates/string_chars.rs @@ -1,6 +1,6 @@ use super::*; -/// Range foreign predicate +/// String chars foreign predicate /// /// ``` scl /// extern pred string_chars(s: String, id: usize, c: char)[bff] @@ -46,14 +46,12 @@ impl ForeignPredicate for StringCharsBFF { assert_eq!(bounded.len(), 1); let s = &bounded[0]; match s { - Value::String(s) => { - s - .chars() - .enumerate() - .map(|(i, c)| (DynamicInputTag::None, vec![Value::from(i), Value::from(c)])) - .collect() - } - _ => panic!("Bounded argument is not string") + Value::String(s) => s + .chars() + .enumerate() + .map(|(i, c)| (DynamicInputTag::None, vec![Value::from(i), Value::from(c)])) + .collect(), + _ => panic!("Bounded argument is not string"), } } } diff --git a/core/src/common/foreign_predicates/string_find.rs 
b/core/src/common/foreign_predicates/string_find.rs new file mode 100644 index 0000000..ee2de54 --- /dev/null +++ b/core/src/common/foreign_predicates/string_find.rs @@ -0,0 +1,56 @@ +use super::*; + +/// String find foreign predicate +/// +/// ``` scl +/// extern pred string_find(s: String, pattern: String, begin: usize, end: usize)[bbff] +/// ``` +#[derive(Clone)] +pub struct StringFindBBFF; + +impl Default for StringFindBBFF { + fn default() -> Self { + Self + } +} + +impl StringFindBBFF { + pub fn new() -> Self { + Self + } +} + +impl ForeignPredicate for StringFindBBFF { + fn name(&self) -> String { + "string_find".to_string() + } + + fn arity(&self) -> usize { + 4 + } + + fn argument_type(&self, i: usize) -> ValueType { + match i { + 0 | 1 => ValueType::String, + 2 | 3 => ValueType::USize, + _ => panic!("Invalid argument ID `{}`", i), + } + } + + fn num_bounded(&self) -> usize { + 2 + } + + fn evaluate(&self, bounded: &[Value]) -> Vec<(DynamicInputTag, Vec)> { + assert_eq!(bounded.len(), 2); + match (&bounded[0], &bounded[1]) { + (Value::String(s), Value::String(pattern)) => { + let len: usize = pattern.chars().count(); // .len() doesn't work for non-ASCII str + s.match_indices(pattern.as_str()) + .map(|(i, _)| (DynamicInputTag::None, vec![Value::from(i), Value::from(i + len)])) + .collect() + } + _ => panic!("Bounded arguments are not strings"), + } + } +} diff --git a/core/src/common/foreign_predicates/string_split.rs b/core/src/common/foreign_predicates/string_split.rs new file mode 100644 index 0000000..bfa6cea --- /dev/null +++ b/core/src/common/foreign_predicates/string_split.rs @@ -0,0 +1,53 @@ +use super::*; + +/// String split foreign predicate +/// +/// ``` scl +/// extern pred string_split(s: String, pattern: String, output: String)[bbf] +/// ``` +#[derive(Clone)] +pub struct StringSplitBBF; + +impl Default for StringSplitBBF { + fn default() -> Self { + Self + } +} + +impl StringSplitBBF { + pub fn new() -> Self { + Self + } +} + +impl 
ForeignPredicate for StringSplitBBF { + fn name(&self) -> String { + "string_split".to_string() + } + + fn arity(&self) -> usize { + 3 + } + + fn argument_type(&self, i: usize) -> ValueType { + match i { + 0 | 1 | 2 => ValueType::String, + _ => panic!("Invalid argument ID `{}`", i), + } + } + + fn num_bounded(&self) -> usize { + 2 + } + + fn evaluate(&self, bounded: &[Value]) -> Vec<(DynamicInputTag, Vec)> { + assert_eq!(bounded.len(), 2); + match (&bounded[0], &bounded[1]) { + (Value::String(s), Value::String(pattern)) => s + .split(pattern.as_str()) + .map(|part| (DynamicInputTag::None, vec![Value::from(part.to_string())])) + .collect(), + _ => panic!("Bounded arguments are not strings"), + } + } +} diff --git a/core/src/common/generic_tuple.rs b/core/src/common/generic_tuple.rs index 37ec4c8..aec6563 100644 --- a/core/src/common/generic_tuple.rs +++ b/core/src/common/generic_tuple.rs @@ -22,6 +22,21 @@ impl GenericTuple { false } } + + /// Return the size of the tuple; if is a value, then `None` is returned + pub fn len(&self) -> Option { + match self { + Self::Value(_) => None, + Self::Tuple(ts) => Some(ts.len()), + } + } + + pub fn get_value(&self) -> Option<&T> { + match self { + Self::Value(v) => Some(v), + Self::Tuple(_) => None, + } + } } impl std::ops::Index for GenericTuple { diff --git a/core/src/common/input_file.rs b/core/src/common/input_file.rs index 7950176..d983a46 100644 --- a/core/src/common/input_file.rs +++ b/core/src/common/input_file.rs @@ -7,8 +7,9 @@ pub enum InputFile { deliminator: u8, has_header: bool, has_probability: bool, + keys: Option>, + fields: Option>, }, - Txt(PathBuf), } impl InputFile { @@ -18,6 +19,8 @@ impl InputFile { deliminator: b',', has_header: false, has_probability: false, + keys: None, + fields: None, } } @@ -26,12 +29,22 @@ impl InputFile { deliminator: Option, has_header: Option, has_probability: Option, + keys: Option>, + fields: Option>, ) -> Self { Self::Csv { file_path, deliminator: 
deliminator.unwrap_or(b','), - has_header: has_header.unwrap_or(false), + has_header: has_header.unwrap_or(false) || keys.is_some() || fields.is_some(), has_probability: has_probability.unwrap_or(false), + keys, + fields, + } + } + + pub fn file_path(&self) -> &PathBuf { + match self { + Self::Csv { file_path, .. } => file_path, } } } diff --git a/core/src/common/input_tag.rs b/core/src/common/input_tag.rs index 13f6c36..152b416 100644 --- a/core/src/common/input_tag.rs +++ b/core/src/common/input_tag.rs @@ -1,10 +1,15 @@ -#[derive(Clone, Debug, PartialEq, PartialOrd)] +use serde::*; + +use super::tensors::Tensor; + +#[derive(Clone, Debug, PartialEq, PartialOrd, Serialize)] pub enum DynamicInputTag { None, Exclusive(usize), Bool(bool), Float(f64), ExclusiveFloat(f64, usize), + Tensor(Tensor), } impl DynamicInputTag { @@ -40,6 +45,7 @@ impl std::fmt::Display for DynamicInputTag { Self::Bool(b) => b.fmt(f), Self::Float(n) => n.fmt(f), Self::ExclusiveFloat(n, i) => f.write_str(&format!("{} [ME({})]", n, i)), + Self::Tensor(t) => t.fmt(f), } } } diff --git a/core/src/common/mod.rs b/core/src/common/mod.rs index 1c97bb4..df6c2b1 100644 --- a/core/src/common/mod.rs +++ b/core/src/common/mod.rs @@ -4,6 +4,7 @@ pub mod aggregate_op; pub mod binary_op; pub mod constants; pub mod element; +pub mod entity; pub mod expr; pub mod foreign_function; pub mod foreign_functions; @@ -14,6 +15,8 @@ pub mod input_file; pub mod input_tag; pub mod output_option; pub mod predicate_set; +pub mod symbol_registry; +pub mod tensors; pub mod tuple; pub mod tuple_access; pub mod tuple_type; diff --git a/core/src/common/symbol_registry.rs b/core/src/common/symbol_registry.rs new file mode 100644 index 0000000..14c3892 --- /dev/null +++ b/core/src/common/symbol_registry.rs @@ -0,0 +1,44 @@ +use std::collections::*; + +#[derive(Clone, Debug)] +pub struct SymbolRegistry { + pub symbol_to_id_map: BTreeMap, + pub id_to_symbol_map: Vec, +} + +impl SymbolRegistry { + /// Create a new symbol registry + 
pub fn new() -> Self { + Self { + symbol_to_id_map: BTreeMap::new(), + id_to_symbol_map: Vec::new(), + } + } + + /// Check if the registry is empty + pub fn is_empty(&self) -> bool { + self.id_to_symbol_map.is_empty() + } + + /// Register a symbol into the registry and return an ID to represent the symbol + pub fn register(&mut self, symbol: String) -> usize { + if let Some(id) = self.symbol_to_id_map.get(&symbol) { + id.clone() + } else { + let id = self.id_to_symbol_map.len(); + self.symbol_to_id_map.insert(symbol.clone(), id); + self.id_to_symbol_map.push(symbol); + id + } + } + + /// Getting the symbol id using the string + pub fn get_id(&self, symbol: &str) -> Option { + self.symbol_to_id_map.get(symbol).cloned() + } + + /// Getting a symbol using its ID + pub fn get_symbol(&self, id: usize) -> Option<&String> { + self.id_to_symbol_map.get(id) + } +} diff --git a/core/src/common/tensors/convert.rs b/core/src/common/tensors/convert.rs new file mode 100644 index 0000000..a5cf75e --- /dev/null +++ b/core/src/common/tensors/convert.rs @@ -0,0 +1,15 @@ +use super::Tensor; + +/// The trait defining the piece of information can be converted from Tensor +/// +/// For Python, we want the external tag to be accessible from tensor +pub trait FromTensor: Clone + 'static { + fn from_tensor(tensor: Tensor) -> Option; +} + +impl FromTensor for () { + #[allow(unused)] + fn from_tensor(tensor: Tensor) -> Option { + None + } +} diff --git a/core/src/common/tensors/mod.rs b/core/src/common/tensors/mod.rs new file mode 100644 index 0000000..a18dabd --- /dev/null +++ b/core/src/common/tensors/mod.rs @@ -0,0 +1,17 @@ +mod convert; +mod msg; +mod registry; +mod shape; +mod symbol; +mod tensor; +mod torch; +mod value; + +pub use convert::*; +pub use msg::*; +pub use registry::*; +pub use shape::*; +pub use symbol::*; +pub use tensor::*; +pub use torch::*; +pub use value::*; diff --git a/core/src/common/tensors/msg.rs b/core/src/common/tensors/msg.rs new file mode 100644 index 
0000000..547e3b7 --- /dev/null +++ b/core/src/common/tensors/msg.rs @@ -0,0 +1,2 @@ +pub const NO_TORCH_MSG: &'static str = + "This version of `scallop-core` is not compiled with `torch`. Consider adding feature flag `torch-tensor`"; diff --git a/core/src/common/tensors/registry.rs b/core/src/common/tensors/registry.rs new file mode 100644 index 0000000..d41be29 --- /dev/null +++ b/core/src/common/tensors/registry.rs @@ -0,0 +1,49 @@ +use std::collections::*; + +use super::*; + +pub struct TensorRegistry { + tensors: HashMap>, +} + +impl TensorRegistry { + pub fn new() -> Self { + Self { + tensors: HashMap::new(), + } + } + + pub fn register(&mut self, tensor: Tensor) -> TensorSymbol { + let shape = tensor.shape(); + let tensors_under_shape = self.tensors.entry(shape.clone()).or_default(); + let id = tensors_under_shape.len(); + tensors_under_shape.push(tensor); + TensorSymbol::new(shape, id) + } + + pub fn get(&self, symbol: &TensorSymbol) -> Option<&Tensor> { + self.tensors.get(&symbol.shape).and_then(|ts| ts.get(symbol.id)) + } + + #[cfg(feature = "torch-tensor")] + pub fn eval_expr(&self, value: &TensorExpr) -> Tensor { + match value { + TensorExpr::Symbol(s) => self.get(s).expect("Cannot find symbol").clone(), + TensorExpr::Float(f) => Tensor::new((*f).into()), + TensorExpr::Add(v1, v2) => Tensor::new(self.eval_expr(v1).tensor + self.eval_expr(v2).tensor), + TensorExpr::Sub(v1, v2) => Tensor::new(self.eval_expr(v1).tensor - self.eval_expr(v2).tensor), + TensorExpr::Mul(v1, v2) => Tensor::new(self.eval_expr(v1).tensor * self.eval_expr(v2).tensor), + TensorExpr::Dot(v1, v2) => Tensor::new(self.eval_expr(v1).tensor.dot(&self.eval_expr(v2).tensor)), + } + } + + #[cfg(not(feature = "torch-tensor"))] + #[allow(unused)] + pub fn eval_expr(&self, value: &TensorExpr) -> Tensor { + panic!("{}", NO_TORCH_MSG) + } + + pub fn eval(&self, value: &TensorValue) -> Tensor { + self.eval_expr(&value.expr) + } +} diff --git a/core/src/common/tensors/shape.rs 
b/core/src/common/tensors/shape.rs new file mode 100644 index 0000000..0d375bb --- /dev/null +++ b/core/src/common/tensors/shape.rs @@ -0,0 +1,41 @@ +/// The shape of a tensor +#[derive(Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)] +pub struct TensorShape(Box<[i64]>); + +impl TensorShape { + pub fn scalar() -> Self { + Self(Box::new([])) + } + + /// Get the number of dimensions + pub fn dim(&self) -> usize { + self.0.len() + } +} + +impl From> for TensorShape { + fn from(shape: Vec) -> Self { + Self(shape.into_iter().collect()) + } +} + +impl std::ops::Index for TensorShape { + type Output = i64; + + fn index(&self, index: usize) -> &Self::Output { + &self.0[index] + } +} + +impl std::fmt::Display for TensorShape { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str("<")?; + for i in 0..self.dim() { + if i > 0 { + f.write_str(", ")?; + } + self.0[i].fmt(f)?; + } + f.write_str(">") + } +} diff --git a/core/src/common/tensors/symbol.rs b/core/src/common/tensors/symbol.rs new file mode 100644 index 0000000..942bc65 --- /dev/null +++ b/core/src/common/tensors/symbol.rs @@ -0,0 +1,22 @@ +use super::*; + +/// A symbolic version of the tensor, storing its shape and ID under the shape. +/// Conceptually, this is a "pointer" to the actual tensor in the tensor registry. 
+#[derive(Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)] +pub struct TensorSymbol { + pub shape: TensorShape, + pub id: usize, +} + +impl TensorSymbol { + /// Create a new tensor symbol + pub fn new(shape: TensorShape, id: usize) -> Self { + Self { shape, id } + } +} + +impl std::fmt::Display for TensorSymbol { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_fmt(format_args!("tensor{}(#{})", self.shape, self.id)) + } +} diff --git a/core/src/common/tensors/tensor.rs b/core/src/common/tensors/tensor.rs new file mode 100644 index 0000000..d1241cf --- /dev/null +++ b/core/src/common/tensors/tensor.rs @@ -0,0 +1,108 @@ +use serde::*; + +use super::*; + +/// An actual tensor containing +pub struct Tensor { + #[cfg(feature = "torch-tensor")] + pub tensor: TorchTensor, +} + +impl Clone for Tensor { + #[cfg(feature = "torch-tensor")] + fn clone(&self) -> Self { + Self { + tensor: self.tensor.shallow_clone(), + } + } + + #[cfg(not(feature = "torch-tensor"))] + fn clone(&self) -> Self { + Self {} + } +} + +impl std::fmt::Debug for Tensor { + #[cfg(feature = "torch-tensor")] + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_fmt(format_args!("Tensor(<{:?}>)", self.tensor.as_ptr())) + } + + #[cfg(not(feature = "torch-tensor"))] + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str("Tensor") + } +} + +impl std::fmt::Display for Tensor { + #[cfg(feature = "torch-tensor")] + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_fmt(format_args!("Tensor({})", self.tensor)) + } + + #[cfg(not(feature = "torch-tensor"))] + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str("Tensor") + } +} + +impl std::cmp::PartialEq for Tensor { + fn eq(&self, _: &Self) -> bool { + true + } +} + +impl std::cmp::Eq for Tensor {} + +impl std::cmp::PartialOrd for Tensor { + fn partial_cmp(&self, _: &Self) -> Option { + 
Some(std::cmp::Ordering::Equal) + } +} + +impl std::cmp::Ord for Tensor { + fn cmp(&self, _: &Self) -> std::cmp::Ordering { + std::cmp::Ordering::Equal + } +} + +unsafe impl Send for Tensor {} + +unsafe impl Sync for Tensor {} + +impl Tensor { + #[cfg(feature = "torch-tensor")] + pub fn new(tensor: TorchTensor) -> Self { + Self { tensor } + } + + #[cfg(feature = "torch-tensor")] + pub fn shape(&self) -> TensorShape { + TensorShape::from(self.tensor.size()) + } + + #[cfg(not(feature = "torch-tensor"))] + pub fn shape(&self) -> TensorShape { + TensorShape::scalar() + } + + #[cfg(feature = "torch-tensor")] + pub fn get_f64(&self) -> f64 { + self.tensor.double_value(&[]) + } + + #[cfg(not(feature = "torch-tensor"))] + pub fn get_f64(&self) -> f64 { + panic!("{}", NO_TORCH_MSG) + } +} + +impl Serialize for Tensor { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + use serde::ser::*; + serializer.serialize_struct("Tensor", 0)?.end() + } +} diff --git a/core/src/common/tensors/torch.rs b/core/src/common/tensors/torch.rs new file mode 100644 index 0000000..67d2aff --- /dev/null +++ b/core/src/common/tensors/torch.rs @@ -0,0 +1,2 @@ +#[cfg(feature = "torch-tensor")] +pub use tch::Tensor as TorchTensor; diff --git a/core/src/common/tensors/value.rs b/core/src/common/tensors/value.rs new file mode 100644 index 0000000..97aa5ab --- /dev/null +++ b/core/src/common/tensors/value.rs @@ -0,0 +1,186 @@ +use super::*; + +/// A symbolic value of the tensor +#[derive(Debug, Clone, PartialEq, PartialOrd, Hash)] +pub struct TensorValue { + pub shape: TensorShape, + pub expr: TensorExpr, +} + +impl TensorValue { + pub fn is_scalar(&self) -> bool { + self.shape.dim() == 0 + } + + pub fn add(self, v2: TensorValue) -> Option { + if self.shape == v2.shape { + Some(TensorValue { + shape: self.shape, + expr: self.expr + v2.expr, + }) + } else { + None + } + } + + pub fn sub(self, v2: TensorValue) -> Option { + if self.shape == v2.shape { + Some(TensorValue { + 
shape: self.shape, + expr: self.expr - v2.expr, + }) + } else { + None + } + } + + pub fn mul(self, v2: TensorValue) -> Option { + if self.shape == v2.shape || v2.is_scalar() { + Some(TensorValue { + shape: self.shape, + expr: self.expr * v2.expr, + }) + } else if v2.is_scalar() { + Some(TensorValue { + shape: v2.shape, + expr: self.expr * v2.expr, + }) + } else { + None + } + } + + pub fn dot(self, v2: TensorValue) -> Option { + if self.shape.dim() == 1 && self.shape == v2.shape { + Some(TensorValue { + shape: TensorShape::scalar(), + expr: self.expr.dot(v2.expr), + }) + } else { + None + } + } +} + +impl From for TensorValue { + fn from(value: f64) -> Self { + Self { + shape: TensorShape::scalar(), + expr: TensorExpr::Float(value), + } + } +} + +impl From for TensorValue { + fn from(value: TensorSymbol) -> Self { + Self { + shape: value.shape.clone(), + expr: TensorExpr::Symbol(value), + } + } +} + +impl std::fmt::Display for TensorValue { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.expr.fmt(f) + } +} + +#[derive(Debug, Clone, PartialEq, PartialOrd)] +pub enum TensorExpr { + /// A base tensor + Symbol(TensorSymbol), + + /// Constant floating point + Float(f64), + + /// A sum of two tensors + Add(Box, Box), + + /// A subtraction between two tensors + Sub(Box, Box), + + /// An element-wise multiplication of two tensors + Mul(Box, Box), + + /// A dot product between two tensors + Dot(Box, Box), +} + +impl TensorExpr { + pub fn dot(self, other: Self) -> Self { + Self::Dot(Box::new(self), Box::new(other)) + } +} + +impl std::fmt::Display for TensorExpr { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Symbol(s) => s.fmt(f), + Self::Float(n) => n.fmt(f), + Self::Add(v1, v2) => f.write_fmt(format_args!("{} + {}", v1, v2)), + Self::Sub(v1, v2) => f.write_fmt(format_args!("{} - {}", v1, v2)), + Self::Mul(v1, v2) => f.write_fmt(format_args!("{} * {}", v1, v2)), + Self::Dot(v1, v2) => 
f.write_fmt(format_args!("dot({}, {})", v1, v2)), + } + } +} + +impl std::hash::Hash for TensorExpr { + fn hash(&self, state: &mut H) { + match self { + Self::Symbol(s) => s.hash(state), + Self::Float(f) => i64::from_ne_bytes(f.to_ne_bytes()).hash(state), + Self::Add(v1, v2) => { + "add".hash(state); + v1.hash(state); + v2.hash(state); + } + Self::Sub(v1, v2) => { + "sub".hash(state); + v1.hash(state); + v2.hash(state); + } + Self::Mul(v1, v2) => { + "mul".hash(state); + v1.hash(state); + v2.hash(state); + } + Self::Dot(v1, v2) => { + "dot".hash(state); + v1.hash(state); + v2.hash(state); + } + } + } +} + +impl From for TensorExpr { + fn from(sym: TensorSymbol) -> Self { + Self::Symbol(sym) + } +} + +impl std::ops::Add for TensorExpr { + type Output = TensorExpr; + + fn add(self, rhs: TensorExpr) -> Self::Output { + TensorExpr::Add(Box::new(self), Box::new(rhs)) + } +} + +impl std::ops::Sub for TensorExpr { + type Output = TensorExpr; + + fn sub(self, rhs: TensorExpr) -> Self::Output { + TensorExpr::Sub(Box::new(self), Box::new(rhs)) + } +} + +impl std::ops::Mul for TensorExpr { + type Output = TensorExpr; + + fn mul(self, rhs: TensorExpr) -> Self::Output { + TensorExpr::Mul(Box::new(self), Box::new(rhs)) + } +} diff --git a/core/src/common/tuple.rs b/core/src/common/tuple.rs index a7049d6..22354a2 100644 --- a/core/src/common/tuple.rs +++ b/core/src/common/tuple.rs @@ -18,6 +18,14 @@ impl Tuple { } } + pub fn singleton(value: Value) -> Self { + Self::Tuple(Box::new([Self::Value(value)])) + } + + pub fn from_values>(values: I) -> Self { + Self::Tuple(values.into_iter().map(Self::Value).collect()) + } + pub fn as_values(&self) -> Vec { match self { Self::Value(_) => panic!("Not a tuple"), @@ -51,6 +59,13 @@ impl Tuple { } } + pub fn to_value(self) -> Value { + match self { + Self::Value(p) => p, + _ => panic!("Not a value"), + } + } + pub fn as_i8(&self) -> i8 { AsTuple::::as_tuple(self) } @@ -356,15 +371,6 @@ impl AsTuple for Tuple { } } -// impl AsTuple> for Tuple 
{ -// fn as_tuple(&self) -> Rc { -// match self { -// Self::Value(Value::RcString(s)) => s.clone(), -// _ => panic!("Cannot perform as_tuple>"), -// } -// } -// } - impl AsTuple<()> for Tuple { fn as_tuple(&self) -> () { match self { diff --git a/core/src/common/unary_op.rs b/core/src/common/unary_op.rs index 6fb6e15..91f9d3b 100644 --- a/core/src/common/unary_op.rs +++ b/core/src/common/unary_op.rs @@ -1,6 +1,8 @@ +use serde::*; + use super::value_type::ValueType; -#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] +#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize)] pub enum UnaryOp { Neg, Pos, diff --git a/core/src/common/value.rs b/core/src/common/value.rs index 981549b..687ef1f 100644 --- a/core/src/common/value.rs +++ b/core/src/common/value.rs @@ -1,10 +1,10 @@ -// use std::rc::Rc; - use std::convert::*; -use super::value_type::*; use chrono::{DateTime, Duration, Utc}; +use super::tensors::*; +use super::value_type::*; + #[derive(Debug, Clone, PartialEq, PartialOrd)] pub enum Value { I8(i8), @@ -25,9 +25,13 @@ pub enum Value { Bool(bool), Str(&'static str), String(String), + Symbol(usize), + SymbolString(String), DateTime(DateTime), Duration(Duration), - // RcString(Rc), + Entity(u64), + Tensor(Tensor), + TensorValue(TensorValue), } impl Value { @@ -38,14 +42,14 @@ impl Value { pub fn as_date_time(&self) -> DateTime { match self { Self::DateTime(d) => d.clone(), - _ => panic!("Not a DateTime") + _ => panic!("Not a DateTime"), } } pub fn as_duration(&self) -> Duration { match self { Self::Duration(d) => d.clone(), - _ => panic!("Not a Duration") + _ => panic!("Not a Duration"), } } @@ -63,6 +67,10 @@ impl Value { v => panic!("Cannot get string from value {}", v), } } + + pub fn symbol_str(s: &str) -> Self { + Self::SymbolString(s.to_string()) + } } impl Eq for Value {} @@ -97,8 +105,16 @@ impl std::hash::Hash for Value { Self::Bool(b) => b.hash(state), Self::Str(s) => s.hash(state), Self::String(s) => s.hash(state), + Self::Symbol(s) => 
s.hash(state), + Self::SymbolString(_) => panic!("[Internal Error] Hash should not happen for symbol string"), Self::DateTime(d) => d.hash(state), Self::Duration(d) => d.hash(state), + Self::Entity(e) => { + "entity".hash(state); + e.hash(state); + } + Self::Tensor(_) => panic!("[Internal Error] Hash should not happen for tensor"), + Self::TensorValue(v) => v.hash(state), } } } @@ -124,9 +140,13 @@ impl std::fmt::Display for Value { Self::Bool(i) => f.write_fmt(format_args!("{}", i)), Self::Str(i) => f.write_fmt(format_args!("{:?}", i)), Self::String(i) => f.write_fmt(format_args!("{:?}", i)), + Self::Symbol(_) => panic!("[Internal Error] Cannot display symbol"), + Self::SymbolString(s) => f.write_fmt(format_args!("s\"{}\"", s)), Self::DateTime(i) => f.write_fmt(format_args!("t\"{}\"", i)), Self::Duration(i) => f.write_fmt(format_args!("d\"{}\"", i)), - // Self::RcString(i) => f.write_fmt(format_args!("{:?}", i)), + Self::Entity(e) => f.write_fmt(format_args!("entity({e:#x})")), + Self::Tensor(t) => f.write_fmt(format_args!("{:?}", t)), + Self::TensorValue(v) => f.write_fmt(format_args!("`{}`", v)), } } } diff --git a/core/src/common/value_type.rs b/core/src/common/value_type.rs index 3ce1c86..91f7f43 100644 --- a/core/src/common/value_type.rs +++ b/core/src/common/value_type.rs @@ -1,11 +1,11 @@ -// use std::rc::Rc; +use serde::*; use crate::utils; use super::tuple::*; use super::value::*; -#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize)] pub enum ValueType { I8, I16, @@ -25,9 +25,11 @@ pub enum ValueType { Bool, Str, String, + Symbol, DateTime, Duration, - // RcString, + Entity, + Tensor, } impl ValueType { @@ -52,9 +54,13 @@ impl ValueType { Bool(_) => Self::Bool, Str(_) => Self::Str, String(_) => Self::String, + Symbol(_) => Self::Symbol, + SymbolString(_) => Self::Symbol, DateTime(_) => Self::DateTime, Duration(_) => Self::Duration, - // RcString(_) => Self::RcString, + 
Entity(_) => Self::Entity, + Tensor(_) => Self::Tensor, + TensorValue(_) => Self::Tensor, } } @@ -180,12 +186,19 @@ impl ValueType { } } + pub fn is_entity(&self) -> bool { + match self { + Self::Entity => true, + _ => false, + } + } + pub fn can_type_cast(&self, target: &Self) -> bool { if self.is_numeric() && target.is_numeric() { true } else if self.is_boolean() && target.is_boolean() { true - } else if self.is_char() && (target.is_char() || target.is_string() || target.is_integer()) { + } else if self.is_char() && (target.is_char() || target.is_string() || target.is_integer() || target.is_float()) { true } else if self.is_string() && target.is_numeric() { true @@ -227,11 +240,21 @@ impl ValueType { // String Self::Str => panic!("Cannot parse into a static string"), Self::String => Ok(Value::String(s.to_string())), - // Self::RcString => Ok(Value::RcString(Rc::new(s.to_string()))), + Self::Symbol => panic!("Cannot parse into a symbol"), // DateTime and Duration - Self::DateTime => Ok(Value::DateTime(utils::parse_date_time_string(s).ok_or_else(|| ValueParseError::new(s, self))?)), - Self::Duration => Ok(Value::Duration(utils::parse_duration_string(s).ok_or_else(|| ValueParseError::new(s, self))?)), + Self::DateTime => Ok(Value::DateTime( + utils::parse_date_time_string(s).ok_or_else(|| ValueParseError::new(s, self))?, + )), + Self::Duration => Ok(Value::Duration( + utils::parse_duration_string(s).ok_or_else(|| ValueParseError::new(s, self))?, + )), + + // Entity + Self::Entity => panic!("Cannot parse into an entity"), + + // Tensor + Self::Tensor => panic!("Cannot parse into tensor"), } } @@ -331,10 +354,7 @@ impl ValueType { /// Get all floating point number types pub fn floats() -> &'static [ValueType] { - &[ - ValueType::F32, - ValueType::F64, - ] + &[ValueType::F32, ValueType::F64] } } @@ -381,9 +401,11 @@ impl std::fmt::Display for ValueType { Bool => f.write_str("bool"), Str => f.write_str("&str"), String => f.write_str("String"), - // RcString => 
f.write_str("Rc"), + Symbol => f.write_str("Symbol"), DateTime => f.write_str("DateTime"), Duration => f.write_str("Duration"), + Entity => f.write_str("Entity"), + Tensor => f.write_str("Tensor"), } } } @@ -510,9 +532,3 @@ impl FromType for ValueType { Self::Duration } } - -// impl FromType> for ValueType { -// fn from_type() -> Self { -// Self::RcString -// } -// } diff --git a/core/src/compiler/back/ast.rs b/core/src/compiler/back/ast.rs index b4e707b..50f8a0b 100644 --- a/core/src/compiler/back/ast.rs +++ b/core/src/compiler/back/ast.rs @@ -102,6 +102,16 @@ impl Rule { pub fn body_literals_mut(&mut self) -> impl Iterator { self.body.args.iter_mut() } + + pub fn collect_new_expr_functors(&self) -> impl Iterator { + self.body_literals().filter_map(|l| match l { + Literal::Assign(assign) => match &assign.right { + AssignExpr::New(new_expr) => Some(&new_expr.functor), + _ => None, + }, + _ => None, + }) + } } #[derive(Clone, Debug, PartialEq)] @@ -156,7 +166,7 @@ impl Head { Self::Atom(_) => { // Atomic head has only one pattern false - }, + } Self::Disjunction(disj) => { // Extract the pattern of the first atom in the disjunction let first_pattern = disj[0] @@ -170,19 +180,16 @@ impl Head { // Check if the first pattern is satisfied by all other atoms for a in disj.iter().skip(1) { - let satisfies_pattern = a.args - .iter() - .enumerate() - .all(|(i, t)| { - if let Some(p) = first_pattern.get(i) { - match t { - Term::Variable(v) => p == &v.name, - Term::Constant(_) => p.is_empty(), - } - } else { - false + let satisfies_pattern = a.args.iter().enumerate().all(|(i, t)| { + if let Some(p) = first_pattern.get(i) { + match t { + Term::Variable(v) => p == &v.name, + Term::Constant(_) => p.is_empty(), } - }); + } else { + false + } + }); // If not satisfied, then the head has multiple patterns if !satisfies_pattern { @@ -192,7 +199,7 @@ impl Head { // If all atoms satisfy the first pattern, then the head has only one pattern false - }, + } } } } @@ -319,6 +326,14 @@ 
impl Literal { }) } + /// Create a new assignment of call expression + pub fn new_expr(left: Variable, functor: String, args: Vec) -> Self { + Self::Assign(Assign { + left, + right: AssignExpr::New(NewExpr { functor, args }), + }) + } + /// Create a new assignment of if-then-else expression pub fn binary_constraint(op: BinaryConstraintOp, op1: Term, op2: Term) -> Self { Self::Constraint(Constraint::Binary(BinaryConstraint { op, op1, op2 })) @@ -383,6 +398,24 @@ impl Atom { self.args.iter().any(|a| a.is_constant()) } + /// Check if the constants in the atom's arguments are all in the front + pub fn constant_args_are_upfront(&self) -> bool { + let mut encountered_variable = false; + for arg in &self.args { + match arg { + Term::Variable(_) => { + encountered_variable = true; + } + Term::Constant(_) => { + if encountered_variable { + return false; + } + } + } + } + return true; + } + /// Get the constant arguments pub fn constant_args(&self) -> impl Iterator { self.args.iter().filter_map(|a| match a { @@ -452,6 +485,11 @@ impl Assign { args.extend(a.as_variable().iter()); } } + AssignExpr::New(n) => { + for a in &n.args { + args.extend(a.as_variable().iter()); + } + } } args } @@ -463,6 +501,16 @@ pub enum AssignExpr { Unary(UnaryAssignExpr), IfThenElse(IfThenElseAssignExpr), Call(CallExpr), + New(NewExpr), +} + +impl AssignExpr { + pub fn is_new_expr(&self) -> bool { + match self { + Self::New(_) => true, + _ => false, + } + } } pub type BinaryExprOp = crate::common::binary_op::BinaryOp; @@ -495,6 +543,12 @@ pub struct CallExpr { pub args: Vec, } +#[derive(Clone, Debug, PartialEq)] +pub struct NewExpr { + pub functor: String, + pub args: Vec, +} + /// A constraint literal which is either a binary or unary constraint #[derive(Clone, Debug, PartialEq)] pub enum Constraint { @@ -505,12 +559,20 @@ pub enum Constraint { impl Constraint { /// Create a new equality constraint using two terms pub fn eq(op1: Term, op2: Term) -> Self { - Self::Binary(BinaryConstraint { op: 
BinaryConstraintOp::Eq, op1, op2 }) + Self::Binary(BinaryConstraint { + op: BinaryConstraintOp::Eq, + op1, + op2, + }) } /// Create a new inequality constraint using two terms pub fn neq(op1: Term, op2: Term) -> Self { - Self::Binary(BinaryConstraint { op: BinaryConstraintOp::Neq, op1, op2 }) + Self::Binary(BinaryConstraint { + op: BinaryConstraintOp::Neq, + op1, + op2, + }) } /// Create a new binary constraint using an operator and two terms @@ -587,10 +649,10 @@ pub enum UnaryConstraintOp { Not, } -impl From<&front::UnaryOpNode> for Option { - fn from(op: &front::UnaryOpNode) -> Self { +impl From<&front::ast::UnaryOpNode> for Option { + fn from(op: &front::ast::UnaryOpNode) -> Self { match op { - front::UnaryOpNode::Not => Some(UnaryConstraintOp::Not), + front::ast::UnaryOpNode::Not => Some(UnaryConstraintOp::Not), _ => None, } } diff --git a/core/src/compiler/back/attr.rs b/core/src/compiler/back/attr.rs index 516d2d2..abfa306 100644 --- a/core/src/compiler/back/attr.rs +++ b/core/src/compiler/back/attr.rs @@ -1,3 +1,4 @@ +use crate::common::aggregate_op::AggregateOp; use crate::common::input_file::InputFile; #[derive(Clone, Debug, PartialEq)] @@ -95,8 +96,14 @@ pub enum Attribute { } impl Attribute { - pub fn aggregate_body(num_group_by_vars: usize, num_arg_vars: usize, num_key_vars: usize) -> Self { + pub fn aggregate_body( + aggregator: AggregateOp, + num_group_by_vars: usize, + num_arg_vars: usize, + num_key_vars: usize, + ) -> Self { Self::AggregateBody(AggregateBodyAttribute { + aggregator, num_group_by_vars, num_arg_vars, num_key_vars, @@ -117,14 +124,16 @@ impl Attribute { #[derive(Clone, Debug, PartialEq)] pub struct AggregateBodyAttribute { + pub aggregator: AggregateOp, pub num_group_by_vars: usize, pub num_arg_vars: usize, pub num_key_vars: usize, } impl AggregateBodyAttribute { - pub fn new(num_group_by_vars: usize, num_arg_vars: usize, num_key_vars: usize) -> Self { + pub fn new(aggregator: AggregateOp, num_group_by_vars: usize, num_arg_vars: 
usize, num_key_vars: usize) -> Self { Self { + aggregator, num_group_by_vars, num_arg_vars, num_key_vars, diff --git a/core/src/compiler/back/b2r.rs b/core/src/compiler/back/b2r.rs index a1d3911..b3d0c20 100644 --- a/core/src/compiler/back/b2r.rs +++ b/core/src/compiler/back/b2r.rs @@ -17,6 +17,7 @@ use crate::utils::IdAllocator; struct B2RContext<'a> { id_alloc: &'a mut IdAllocator, relations: &'a mut BTreeMap, + old_relations: &'a BTreeMap, temp_updates: &'a mut Vec, pred_permutations: &'a mut HashMap>, negative_dataflows: &'a mut Vec, @@ -26,6 +27,10 @@ impl<'a> B2RContext<'a> { pub fn add_permutation(&mut self, pred: String, perm: Permutation) { self.pred_permutations.entry(pred).or_default().insert(perm); } + + pub fn get_relation(&self, pred: &str) -> Option<&ram::Relation> { + self.relations.get(pred).or_else(|| self.old_relations.get(pred)) + } } struct NegativeDataflow { @@ -76,6 +81,7 @@ impl Program { let mut id_alloc = IdAllocator::default(); let mut pred_permutations = HashMap::>::new(); let mut negative_dataflows = Vec::::new(); + let mut all_relations = BTreeMap::new(); // For each stratum, generate a ram stratum // populate the EDB permutations @@ -84,8 +90,16 @@ impl Program { .enumerate() .map(|(_, s)| { // Compute the ram stratum - let ram_stratum = - self.stratum_to_ram_stratum(s, &mut id_alloc, &mut pred_permutations, &mut negative_dataflows); + let ram_stratum = self.stratum_to_ram_stratum( + s, + &mut id_alloc, + &all_relations, + &mut pred_permutations, + &mut negative_dataflows, + ); + + // Extend the list of relations + all_relations.extend(ram_stratum.relations.clone()); // Return the stratum ram_stratum @@ -123,7 +137,7 @@ impl Program { // Get the negative dataflows that can be computed at this stratum let curr_neg_dfs = negative_dataflows - .drain_filter(|ndf| ndf.sources.is_subset(&accumulated_sources)) + .extract_if(|ndf| ndf.sources.is_subset(&accumulated_sources)) .collect::>(); // Add the negative dataflow into the stratum @@ 
-145,6 +159,7 @@ impl Program { &self, stratum: &Stratum, id_alloc: &mut IdAllocator, + old_relations: &BTreeMap, pred_permutations: &mut HashMap>, negative_dataflows: &mut Vec, ) -> ram::Stratum { @@ -158,6 +173,7 @@ impl Program { // Compile context let mut b2r_context = B2RContext { id_alloc, + old_relations, relations: &mut relations, temp_updates: &mut temp_updates, pred_permutations, @@ -274,7 +290,17 @@ impl Program { let output = self.outputs.get(pred).cloned().unwrap_or(OutputOption::Hidden); // Check immutability, i.e., the relation is not updated by rules - let immutable = self.rules.iter().find_position(|r| r.head.predicate() == pred).is_none(); + let immutable = self + .rules + .iter() + .find_position(|r| { + if r.head.predicate() == pred { + true + } else { + r.collect_new_expr_functors().find_position(|f| *f == pred).is_some() + } + }) + .is_none(); // The Final Relation let ram_relation = ram::Relation { @@ -356,18 +382,16 @@ impl Program { let projection: Expr = head_atoms[0] .args .iter() - .map(|arg| { - match arg { - Term::Variable(_) => { - let result = Expr::access((0, var_counter)); - var_counter += 1; - result - } - Term::Constant(_) => { - let result = Expr::access((1, const_counter)); - const_counter += 1; - result - } + .map(|arg| match arg { + Term::Variable(_) => { + let result = Expr::access((0, var_counter)); + var_counter += 1; + result + } + Term::Constant(_) => { + let result = Expr::access((1, const_counter)); + const_counter += 1; + result } }) .collect(); @@ -408,117 +432,198 @@ impl Program { } } - fn ground_plan_to_ram_dataflow( + fn empty_ground_plan_to_ram_dataflow(&self, goal: &VariableTuple, atom: &Atom) -> ram::Dataflow { + let ground = ram::Dataflow::relation(atom.predicate.clone()); + if goal.matches(atom) { + ground + } else { + let atom_var_tuple = VariableTuple::empty(); + ram::Dataflow::project(ground, atom_var_tuple.projection(goal)) + } + } + + fn static_with_constant_ground_plan_to_ram_dataflow( &self, ctx: &mut 
B2RContext, goal: &VariableTuple, atom: &Atom, prop: DataflowProp, ) -> ram::Dataflow { - if atom.args.is_empty() { - let ground = ram::Dataflow::relation(atom.predicate.clone()); - if goal.matches(atom) { - ground - } else { - let atom_var_tuple = VariableTuple::empty(); - ram::Dataflow::project(ground, atom_var_tuple.projection(goal)) - } - } else if atom.has_constant_arg() { - // Use find - let (constants, variables) = atom.const_var_partition(); - let const_var = |i: usize, ty: Type| -> Variable { - Variable { - name: format!("const#{}", i), - ty, - } - }; + // Use filter + let sub_vars = atom + .args + .iter() + .enumerate() + .map(|(i, t)| match t { + Term::Constant(c) => const_var(i, c.value_type()), + Term::Variable(v) => v.clone(), + }) + .collect::>(); - // Temp atom - let sub_atom = Atom { - predicate: atom.predicate.clone(), - args: atom - .args + // Create an atom + let sub_atom = Atom { + predicate: atom.predicate.clone(), + args: sub_vars.iter().map(|v| Term::Variable(v.clone())).collect(), + }; + let sub_goal = VariableTuple::from_vars(sub_vars.iter().cloned(), true); + let sub_dataflow = self.ground_plan_to_ram_dataflow(ctx, &sub_goal, &sub_atom, prop.with_need_sorted(false)); + + // Get the filters + let mut filter_exprs = atom.args.iter().enumerate().filter_map(|(i, t)| match t { + Term::Constant(c) => Some(Expr::eq(Expr::access(i), Expr::constant(c.clone()))), + Term::Variable(_) => None, + }); + let mut filter_expr = filter_exprs + .next() + .expect("There should be at least one filter expression."); + for expr in filter_exprs { + filter_expr = filter_expr & expr; + } + + // Filter dataflow + let filter_dataflow = sub_dataflow.filter(filter_expr); + + // Projection dataflow + let project_dataflow = filter_dataflow.project(sub_goal.projection(goal)); + + // Check if we need to create temporary variable + Self::process_dataflow(ctx, goal, project_dataflow, prop) + } + + fn dynamic_with_constant_ground_plan_to_ram_dataflow( + &self, + ctx: &mut 
B2RContext, + goal: &VariableTuple, + atom: &Atom, + prop: DataflowProp, + ) -> ram::Dataflow { + // Use find + let (constants, variables) = atom.const_var_partition(); + + // Temp atom + let sub_atom = Atom { + predicate: atom.predicate.clone(), + args: atom + .args + .iter() + .enumerate() + .map(|(i, t)| match t { + Term::Constant(c) => Term::Variable(const_var(i, c.value_type())), + Term::Variable(v) => Term::Variable(v.clone()), + }) + .collect(), + }; + + // Subgoal + let constants_sub_goal = if constants.len() == 1 { + VariableTuple::Value(const_var(constants[0].0, constants[0].1.value_type())) + } else { + VariableTuple::Tuple( + constants .iter() - .enumerate() - .map(|(i, t)| match t { - Term::Constant(c) => Term::Variable(const_var(i, c.value_type())), - Term::Variable(v) => Term::Variable(v.clone()), - }) + .map(|(i, c)| VariableTuple::Value(const_var(i.clone(), c.value_type()))) .collect(), - }; + ) + }; + let variables_sub_goal = if variables.len() == 1 { + VariableTuple::Value(variables[0].1.clone()) + } else { + VariableTuple::Tuple( + variables + .iter() + .map(|(_, v)| VariableTuple::Value((*v).clone())) + .collect(), + ) + }; + let sub_goal = VariableTuple::from((constants_sub_goal, variables_sub_goal)); - // Subgoal - let constants_sub_goal = if constants.len() == 1 { - VariableTuple::Value(const_var(constants[0].0, constants[0].1.value_type())) - } else { - VariableTuple::Tuple( - constants - .iter() - .map(|(i, c)| VariableTuple::Value(const_var(i.clone(), c.value_type()))) - .collect(), - ) - }; - let variables_sub_goal = if variables.len() == 1 { - VariableTuple::Value(variables[0].1.clone()) - } else { - VariableTuple::Tuple( - variables - .iter() - .map(|(_, v)| VariableTuple::Value((*v).clone())) - .collect(), - ) - }; - let sub_goal = VariableTuple::from((constants_sub_goal, variables_sub_goal)); + // 1. 
Project it into (constants, variables) tuple + let project_1_dataflow = self.ground_plan_to_ram_dataflow(ctx, &sub_goal, &sub_atom, prop.with_need_sorted(true)); - // 1. Project it into (constants, variables) tuple - let project_1_dataflow = self.ground_plan_to_ram_dataflow(ctx, &sub_goal, &sub_atom, prop.with_need_sorted(true)); + // 2. Find using the constants + let find_tuple = if constants.len() == 1 { + Tuple::Value(constants[0].1.clone()) + } else { + Tuple::Tuple(constants.into_iter().map(|(_, t)| Tuple::Value(t.clone())).collect()) + }; + let find_dataflow = ram::Dataflow::find(project_1_dataflow, find_tuple); - // 2. Find using the constants - let find_tuple = if constants.len() == 1 { - Tuple::Value(constants[0].1.clone()) - } else { - Tuple::Tuple(constants.into_iter().map(|(_, t)| Tuple::Value(t.clone())).collect()) - }; - let find_dataflow = ram::Dataflow::find(project_1_dataflow, find_tuple); + // 3. Project into goal + let dataflow = ram::Dataflow::project(find_dataflow, sub_goal.projection(goal)); - // 3. Project into goal - let dataflow = ram::Dataflow::project(find_dataflow, sub_goal.projection(goal)); + // 4. Check if we need to create temporary variable + Self::process_dataflow(ctx, goal, dataflow, prop) + } - // 4. 
Check if we need to create temporary variable - Self::process_dataflow(ctx, goal, dataflow, prop) + fn with_constant_ground_plan_to_ram_dataflow( + &self, + ctx: &mut B2RContext, + goal: &VariableTuple, + atom: &Atom, + prop: DataflowProp, + ) -> ram::Dataflow { + let relation = ctx.get_relation(&atom.predicate).unwrap(); // NOTE: `unwrap` is ok here + if relation.immutable && relation.input_file.is_some() && !atom.constant_args_are_upfront() { + self.static_with_constant_ground_plan_to_ram_dataflow(ctx, goal, atom, prop) } else { - if goal.matches(atom) { - ram::Dataflow::Relation(atom.predicate.clone()) - } else { - let perm = goal.permutation(atom); - if let Some(filter) = Self::atom_filter(atom) { - let dataflow = ram::Dataflow::project( - ram::Dataflow::filter(ram::Dataflow::Relation(atom.predicate.clone()), filter), - perm.expr(), - ); - if prop.need_sorted && !perm.order_preserving() { - ctx.add_permutation(atom.predicate.clone(), perm); - if prop.is_negative { - Self::create_negative_temp_relation(ctx, goal, dataflow) - } else { - Self::create_temp_relation(ctx, goal, dataflow) - } + self.dynamic_with_constant_ground_plan_to_ram_dataflow(ctx, goal, atom, prop) + } + } + + fn no_constant_ground_plan_to_ram_dataflow( + &self, + ctx: &mut B2RContext, + goal: &VariableTuple, + atom: &Atom, + prop: DataflowProp, + ) -> ram::Dataflow { + if goal.matches(atom) { + ram::Dataflow::Relation(atom.predicate.clone()) + } else { + let perm = goal.permutation(atom); + if let Some(filter) = Self::atom_filter(atom) { + let dataflow = ram::Dataflow::project( + ram::Dataflow::filter(ram::Dataflow::Relation(atom.predicate.clone()), filter), + perm.expr(), + ); + if prop.need_sorted && !perm.order_preserving() { + ctx.add_permutation(atom.predicate.clone(), perm); + if prop.is_negative { + Self::create_negative_temp_relation(ctx, goal, dataflow) } else { - dataflow + Self::create_temp_relation(ctx, goal, dataflow) } } else { - if prop.need_sorted && !perm.order_preserving() { - 
let perm_name = Self::permutated_predicate_name(&atom.predicate, &perm); - ctx.add_permutation(atom.predicate.clone(), perm); - ram::Dataflow::Relation(perm_name) - } else { - ram::Dataflow::project(ram::Dataflow::Relation(atom.predicate.clone()), perm.expr()) - } + dataflow + } + } else { + if prop.need_sorted && !perm.order_preserving() { + let perm_name = Self::permutated_predicate_name(&atom.predicate, &perm); + ctx.add_permutation(atom.predicate.clone(), perm); + ram::Dataflow::Relation(perm_name) + } else { + ram::Dataflow::project(ram::Dataflow::Relation(atom.predicate.clone()), perm.expr()) } } } } + fn ground_plan_to_ram_dataflow( + &self, + ctx: &mut B2RContext, + goal: &VariableTuple, + atom: &Atom, + prop: DataflowProp, + ) -> ram::Dataflow { + if atom.args.is_empty() { + self.empty_ground_plan_to_ram_dataflow(goal, atom) + } else if atom.has_constant_arg() { + self.with_constant_ground_plan_to_ram_dataflow(ctx, goal, atom, prop) + } else { + self.no_constant_ground_plan_to_ram_dataflow(ctx, goal, atom, prop) + } + } + fn filter_plan_to_ram_dataflow( &self, ctx: &mut B2RContext, @@ -867,11 +972,22 @@ impl Program { // Get information from the atom let pred: String = atom.predicate.clone(); - let inputs: Vec = atom.args.iter().take(fp.num_bounded()).map(|arg| arg.as_constant().unwrap()).cloned().collect(); + let inputs: Vec = atom + .args + .iter() + .take(fp.num_bounded()) + .map(|arg| arg.as_constant().unwrap()) + .cloned() + .collect(); let ground_dataflow = ram::Dataflow::ForeignPredicateGround(pred, inputs); // Get the projection onto the variable tuple - let var_tuple = atom.args.iter().skip(fp.num_bounded()).map(|arg| arg.as_variable().unwrap()).cloned(); + let var_tuple = atom + .args + .iter() + .skip(fp.num_bounded()) + .map(|arg| arg.as_variable().unwrap()) + .cloned(); let project = VariableTuple::from_vars(var_tuple, false).projection(goal); // Project the dataflow @@ -893,12 +1009,14 @@ impl Program { // Generate information for foreign 
predicate constraint let pred: String = atom.predicate.clone(); - let exprs: Vec = atom.args.iter().map(|arg| { - match arg { + let exprs: Vec = atom + .args + .iter() + .map(|arg| match arg { Term::Constant(c) => Expr::Constant(c.clone()), Term::Variable(v) => Expr::Access(sub_goal.accessor_of(v).unwrap()), - } - }).collect(); + }) + .collect(); // Return a foreign predicate constraint dataflow let dataflow = sub_dataflow.foreign_predicate_constraint(pred, exprs); @@ -926,20 +1044,27 @@ impl Program { // Generate information for foreign predicate constraint let pred: String = atom.predicate.clone(); - let exprs: Vec = atom.args.iter().take(fp.num_bounded()).map(|arg| { - match arg { + let exprs: Vec = atom + .args + .iter() + .take(fp.num_bounded()) + .map(|arg| match arg { Term::Constant(c) => Expr::Constant(c.clone()), - Term::Variable(v) => { - Expr::Access(left_goal.accessor_of(v).unwrap()) - }, - } - }).collect(); + Term::Variable(v) => Expr::Access(left_goal.accessor_of(v).unwrap()), + }) + .collect(); // Generate the joint dataflow let join_dataflow = left_dataflow.foreign_predicate_join(pred, exprs); // Get the variable tuple of the joined output - let free_vars: Vec<_> = atom.args.iter().skip(fp.num_bounded()).map(|arg| arg.as_variable().unwrap()).cloned().collect(); + let free_vars: Vec<_> = atom + .args + .iter() + .skip(fp.num_bounded()) + .map(|arg| arg.as_variable().unwrap()) + .cloned() + .collect(); let right_tuple = VariableTuple::from_vars(free_vars.iter().cloned(), false); let var_tuple = VariableTuple::from((left_goal, right_tuple)); @@ -1015,9 +1140,7 @@ impl Program { .args .iter() .filter_map(|arg| match arg { - Term::Variable(v) => { - Some(VariableTuple::Value(v.clone())) - }, + Term::Variable(v) => Some(VariableTuple::Value(v.clone())), _ => { need_projection = true; None @@ -1166,3 +1289,10 @@ impl Program { format!("{}#perm#{}", pred, perm) } } + +fn const_var(i: usize, ty: Type) -> Variable { + Variable { + name: format!("const#{}", 
i), + ty, + } +} diff --git a/core/src/compiler/back/optimizations/constant_folding.rs b/core/src/compiler/back/optimizations/constant_folding.rs index 8956bb7..b7e421d 100644 --- a/core/src/compiler/back/optimizations/constant_folding.rs +++ b/core/src/compiler/back/optimizations/constant_folding.rs @@ -1,4 +1,5 @@ use crate::common::binary_op::BinaryOp; +use crate::common::entity; use crate::common::expr::{BinaryExpr, Expr, UnaryExpr}; use crate::common::foreign_function::*; use crate::common::unary_op::UnaryOp; @@ -86,6 +87,18 @@ pub fn constant_fold(rule: &mut Rule, function_registry: &ForeignFunctionRegistr } } } + AssignExpr::New(n) => { + let all_constant = n.args.iter().all(|a| a.is_constant()); + if all_constant { + let raw_id = entity::encode_entity(&n.functor, n.args.iter().map(|a| a.as_constant().unwrap())); + let id = Term::Constant(Constant::Entity(raw_id)); + *lit = Literal::Constraint(Constraint::Binary(BinaryConstraint { + op: BinaryConstraintOp::Eq, + op1: Term::Variable(a.left.clone()), + op2: id, + })) + } + } }, Literal::Constraint(c) => match c { Constraint::Binary(b) => match (&b.op, &b.op1, &b.op2) { diff --git a/core/src/compiler/back/optimizations/constant_propagation.rs b/core/src/compiler/back/optimizations/constant_propagation.rs index 6877f3a..0dbe4ee 100644 --- a/core/src/compiler/back/optimizations/constant_propagation.rs +++ b/core/src/compiler/back/optimizations/constant_propagation.rs @@ -129,6 +129,10 @@ pub fn constant_prop(rule: &mut Rule) { function: c.function.clone(), args: c.args.iter().map(substitute_term).collect(), }), + AssignExpr::New(n) => AssignExpr::New(NewExpr { + functor: n.functor.clone(), + args: n.args.iter().map(substitute_term).collect(), + }), }, }) } else { diff --git a/core/src/compiler/back/optimizations/equality_propagation.rs b/core/src/compiler/back/optimizations/equality_propagation.rs index 137e51c..792a503 100644 --- a/core/src/compiler/back/optimizations/equality_propagation.rs +++ 
b/core/src/compiler/back/optimizations/equality_propagation.rs @@ -99,6 +99,10 @@ pub fn propagate_equality(rule: &mut Rule) { function: c.function.clone(), args: c.args.iter().map(substitute_term).collect(), }), + AssignExpr::New(n) => AssignExpr::New(NewExpr { + functor: n.functor.clone(), + args: n.args.iter().map(substitute_term).collect(), + }), }, }), Literal::Constraint(Constraint::Binary(b)) => Literal::Constraint(Constraint::Binary(BinaryConstraint { diff --git a/core/src/compiler/back/pretty.rs b/core/src/compiler/back/pretty.rs index 747f3b8..95340f5 100644 --- a/core/src/compiler/back/pretty.rs +++ b/core/src/compiler/back/pretty.rs @@ -210,6 +210,7 @@ impl Display for AssignExpr { Self::Unary(u) => u.fmt(f), Self::IfThenElse(i) => i.fmt(f), Self::Call(c) => c.fmt(f), + Self::New(n) => n.fmt(f), } } } @@ -250,6 +251,16 @@ impl Display for CallExpr { } } +impl Display for NewExpr { + fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { + f.write_fmt(format_args!( + "new {}({})", + self.functor, + self.args.iter().map(|a| format!("{}", a)).collect::>().join("") + )) + } +} + impl Display for Constraint { fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { match self { diff --git a/core/src/compiler/back/query_plan.rs b/core/src/compiler/back/query_plan.rs index 60042f9..5502bc6 100644 --- a/core/src/compiler/back/query_plan.rs +++ b/core/src/compiler/back/query_plan.rs @@ -51,7 +51,7 @@ impl<'a> QueryPlanContext<'a> { } else { ctx.pos_atoms.push(a.clone()) } - }, + } Literal::NegAtom(n) => ctx.neg_atoms.push(n.atom.clone()), Literal::Assign(a) => ctx.assigns.push(a.clone()), Literal::Constraint(c) => ctx.constraints.push(c.clone()), @@ -65,7 +65,7 @@ impl<'a> QueryPlanContext<'a> { } /// Find all the bounded arguments given the set of positive atoms - pub fn bounded_args_from_pos_atoms_set(&self, set: &Vec<&usize>) -> HashSet { + pub fn bounded_args_from_pos_atoms_set(&self, set: &Vec<&usize>, include_new: bool) -> HashSet { let mut base_bounded_args = 
HashSet::new(); // Add the base cases: all the arguments in the positive atoms form the base bounded args @@ -103,6 +103,10 @@ impl<'a> QueryPlanContext<'a> { cond_bounded && then_br_bounded && else_br_bounded } AssignExpr::Call(c) => c.args.iter().all(|a| term_is_bounded(&new_bounded_args, a)), + AssignExpr::New(n) => { + // If not include new, we do not bound the left variable + include_new && n.args.iter().all(|a| term_is_bounded(&new_bounded_args, a)) + } }; if can_bound { new_bounded_args.insert(assign.left.clone()); @@ -115,7 +119,11 @@ impl<'a> QueryPlanContext<'a> { let predicate = self.foreign_predicate_registry.get(&atom.predicate).unwrap(); // Then check if all the to-bound arguments are bounded - let can_bound = atom.args.iter().take(predicate.num_bounded()).all(|a| term_is_bounded(&new_bounded_args, a)); + let can_bound = atom + .args + .iter() + .take(predicate.num_bounded()) + .all(|a| term_is_bounded(&new_bounded_args, a)); // If it can be bounded, add the rest of the arguments to the bounded args if can_bound { @@ -205,7 +213,7 @@ impl<'a> QueryPlanContext<'a> { } /// Try applying as many assigns as possible - fn try_apply_assigns(&self, applied_assigns: &mut HashSet, mut fringe: Plan) -> Plan { + fn try_apply_non_new_assigns(&self, applied_assigns: &mut HashSet, mut fringe: Plan) -> Plan { // Find all the assigns that are needed to bound the need_projection_vars let mut bounded_vars = fringe.bounded_vars.clone(); loop { @@ -215,6 +223,7 @@ impl<'a> QueryPlanContext<'a> { for (i, assign) in self.assigns.iter().enumerate() { if !applied_assigns.contains(&i) && !bounded_vars.contains(&assign.left) + && !assign.right.is_new_expr() && assign.variable_args().into_iter().all(|v| bounded_vars.contains(v)) { applied_assigns.insert(i); @@ -236,31 +245,39 @@ impl<'a> QueryPlanContext<'a> { } /// Get the essential information for analyzing a foreign predicate atom - fn foreign_predicate_atom_info<'b, 'c>(&'b self, atom: &'c Atom) -> (&'b 
DynamicForeignPredicate, Vec<(usize, &'c Term)>, Vec<(usize, &'c Term)>) { + fn foreign_predicate_atom_info<'b, 'c>( + &'b self, + atom: &'c Atom, + ) -> ( + &'b DynamicForeignPredicate, + Vec<(usize, &'c Term)>, + Vec<(usize, &'c Term)>, + ) { let pred = self.foreign_predicate_registry.get(&atom.predicate).unwrap(); - let (to_bound_arguments, free_arguments): (Vec<_>, Vec<_>) = atom - .args - .iter() - .enumerate() - .partition(|(i, _)| *i < pred.num_bounded()); + let (to_bound_arguments, free_arguments): (Vec<_>, Vec<_>) = + atom.args.iter().enumerate().partition(|(i, _)| *i < pred.num_bounded()); (pred, to_bound_arguments, free_arguments) } fn foreign_predicate_constant_constraints(&self, atom: &Atom, arguments: &Vec<(usize, &Term)>) -> Vec { - arguments.iter().filter_map(|(i, a)| match a { - Term::Constant(c) => { - let op1 = Term::variable(format!("c#{}#{}", &atom.predicate, i), ValueType::type_of(c)); - let op2 = Term::Constant(c.clone()); - Some(Constraint::eq(op1, op2)) - } - Term::Variable(_) => { - None - } - }).collect() + arguments + .iter() + .filter_map(|(i, a)| match a { + Term::Constant(c) => { + let op1 = Term::variable(format!("c#{}#{}", &atom.predicate, i), ValueType::type_of(c)); + let op2 = Term::Constant(c.clone()); + Some(Constraint::eq(op1, op2)) + } + Term::Variable(_) => None, + }) + .collect() } fn foreign_predicate_equality_constraints(&self, var_eq: &Vec<(Variable, Variable)>) -> Vec { - var_eq.iter().map(|(v1, v2)| Constraint::eq(Term::Variable(v1.clone()), Term::Variable(v2.clone()))).collect() + var_eq + .iter() + .map(|(v1, v2)| Constraint::eq(Term::Variable(v1.clone()), Term::Variable(v2.clone()))) + .collect() } /// Given a list of free arguments, rename them to avoid name conflicts @@ -280,7 +297,8 @@ impl<'a> QueryPlanContext<'a> { occurred: &HashSet, ) -> (Vec, Vec<(Variable, Variable)>) { // First build a map from variable names to the number of occurrences - let mut var_occurrences: HashMap = occurred.iter().map(|v| 
(v.name.clone(), (v.clone(), 0))).collect(); + let mut var_occurrences: HashMap = + occurred.iter().map(|v| (v.name.clone(), (v.clone(), 0))).collect(); let mut var_equivalences: Vec<(Variable, Variable)> = Vec::new(); let vars = arguments .iter() @@ -306,10 +324,8 @@ impl<'a> QueryPlanContext<'a> { // Return the variable v.clone() } - }, - Term::Constant(c) => { - Variable::new(format!("c#{}#{}", predicate, i), ValueType::type_of(c)) } + Term::Constant(c) => Variable::new(format!("c#{}#{}", predicate, i), ValueType::type_of(c)), }) .collect::>(); (vars, var_equivalences) @@ -326,7 +342,13 @@ impl<'a> QueryPlanContext<'a> { let (all_vars, var_eq) = self.rename_free_arguments(&atom.predicate, free_arguments, &HashSet::new()); // Create a ground plan - let args = atom.args.iter().take(pred.num_bounded()).cloned().chain(all_vars.iter().cloned().map(Term::Variable)).collect(); + let args = atom + .args + .iter() + .take(pred.num_bounded()) + .cloned() + .chain(all_vars.iter().cloned().map(Term::Variable)) + .collect(); let ground_atom = Atom::new(atom.predicate.clone(), args); let ground_plan = Plan { bounded_vars: all_vars.iter().cloned().collect(), @@ -348,14 +370,14 @@ impl<'a> QueryPlanContext<'a> { .iter() .filter_map(|(_, a)| match a { Term::Variable(v) => Some(v.clone()), - _ => None + _ => None, }) .collect(); // Create a plan with filters on the constants let filter_plan = Plan { bounded_vars: new_bounded_vars, - ram_node: HighRamNode::filter(ground_plan, constraints) + ram_node: HighRamNode::filter(ground_plan, constraints), }; filter_plan @@ -376,12 +398,23 @@ impl<'a> QueryPlanContext<'a> { let (free_vars, var_eq) = self.rename_free_arguments(&atom.predicate, free_arguments, &occurred_variables); // Create an atom - let args = atom.args.iter().take(pred.num_bounded()).cloned().chain(free_vars.iter().cloned().map(Term::Variable)).collect(); + let args = atom + .args + .iter() + .take(pred.num_bounded()) + .cloned() + 
.chain(free_vars.iter().cloned().map(Term::Variable)) + .collect(); let to_join_atom = Atom::new(atom.predicate.clone(), args); // Create the join plan let join_plan = Plan { - bounded_vars: left.bounded_vars.iter().cloned().chain(free_vars.iter().cloned()).collect(), + bounded_vars: left + .bounded_vars + .iter() + .cloned() + .chain(free_vars.iter().cloned()) + .collect(), ram_node: HighRamNode::foreign_predicate_join(left.clone(), to_join_atom), }; @@ -400,14 +433,19 @@ impl<'a> QueryPlanContext<'a> { .iter() .filter_map(|(_, a)| match a { Term::Variable(v) => Some(v.clone()), - _ => None + _ => None, }) .collect(); // Create a plan with filters on the constants let filter_plan = Plan { - bounded_vars: left.bounded_vars.iter().chain(right_bounded_vars.iter()).cloned().collect(), - ram_node: HighRamNode::filter(join_plan, constraints) + bounded_vars: left + .bounded_vars + .iter() + .chain(right_bounded_vars.iter()) + .cloned() + .collect(), + ram_node: HighRamNode::filter(join_plan, constraints), }; filter_plan @@ -421,7 +459,11 @@ impl<'a> QueryPlanContext<'a> { /// /// - `applied_foreign_predicate_atoms`: The set of already applied foreign predicate atoms, represented by their index /// - `fringe`: The current Plan to apply the foreign predicate atoms - fn try_apply_foreign_predicate_atom(&self, applied_foreign_predicate_atoms: &mut HashSet, mut fringe: Plan) -> Plan { + fn try_apply_foreign_predicate_atom( + &self, + applied_foreign_predicate_atoms: &mut HashSet, + mut fringe: Plan, + ) -> Plan { // Find all the foreign predicate atoms let bounded_vars = fringe.bounded_vars.clone(); loop { @@ -434,7 +476,10 @@ impl<'a> QueryPlanContext<'a> { let (pred, to_bound_arguments, free_arguments) = self.foreign_predicate_atom_info(atom); // Check if all the to-bound arguments are bounded; if so, it means that we can apply the atom - if to_bound_arguments.iter().all(|(_, a)| term_is_bounded(&bounded_vars, a)) { + if to_bound_arguments + .iter() + .all(|(_, a)| 
term_is_bounded(&bounded_vars, a)) + { // Mark the atom as applied applied_foreign_predicate_atoms.insert(i); @@ -447,10 +492,7 @@ impl<'a> QueryPlanContext<'a> { // The atom is completely bounded; we add a foreign predicate constraint plan fringe = Plan { bounded_vars: bounded_vars.clone(), - ram_node: HighRamNode::ForeignPredicateConstraint( - Box::new(fringe), - atom.clone(), - ), + ram_node: HighRamNode::ForeignPredicateConstraint(Box::new(fringe), atom.clone()), }; } else if to_bound_arguments.iter().all(|(_, a)| a.is_constant()) { let plan = self.compute_foreign_predicate_ground_atom(atom, pred, &free_arguments); @@ -505,10 +547,17 @@ impl<'a> QueryPlanContext<'a> { (node, 0) } else { // Find the foreign predicate atom - if let Some((i, atom)) = self.foreign_predicate_pos_atoms.iter().enumerate().find(|(_, a)| self.is_ground_foreign_atom(a)) { + if let Some((i, atom)) = self + .foreign_predicate_pos_atoms + .iter() + .enumerate() + .find(|(_, a)| self.is_ground_foreign_atom(a)) + { applied_foreign_predicates.insert(i); // Mark the atom as applied let (pred, _, free_arguments) = self.foreign_predicate_atom_info(atom); let plan = self.compute_foreign_predicate_ground_atom(atom, pred, &free_arguments); + let plan = self.try_apply_non_new_assigns(&mut applied_assigns, plan); + let plan = self.try_apply_constraint(&mut applied_constraints, plan); (plan, 0) } else { panic!("[Internal Error] No foreign predicate atom is ground; should not happen"); @@ -523,7 +572,7 @@ impl<'a> QueryPlanContext<'a> { }; // Note: We always apply constraint first and then assigns - let node = self.try_apply_assigns(&mut applied_assigns, node); + let node = self.try_apply_non_new_assigns(&mut applied_assigns, node); let node = self.try_apply_constraint(&mut applied_constraints, node); let node = self.try_apply_foreign_predicate_atom(&mut applied_foreign_predicates, node); (node, 1) @@ -551,7 +600,7 @@ impl<'a> QueryPlanContext<'a> { } // Note: We always apply constraint first and then 
assigns - let node = self.try_apply_assigns(&mut applied_assigns, node); + let node = self.try_apply_non_new_assigns(&mut applied_assigns, node); let node = self.try_apply_constraint(&mut applied_constraints, node); let node = self.try_apply_foreign_predicate_atom(&mut applied_foreign_predicates, node); (node, 0) @@ -594,9 +643,28 @@ impl<'a> QueryPlanContext<'a> { } // Note: We always apply constraint first and then assigns - fringe = self.try_apply_assigns(&mut applied_assigns, fringe); - fringe = self.try_apply_constraint(&mut applied_constraints, fringe); - fringe = self.try_apply_foreign_predicate_atom(&mut applied_foreign_predicates, fringe); + let mut num_applied_assigns = applied_assigns.len(); + let mut num_applied_constraints = applied_constraints.len(); + let mut num_applied_fp = applied_foreign_predicates.len(); + loop { + fringe = self.try_apply_non_new_assigns(&mut applied_assigns, fringe); + fringe = self.try_apply_constraint(&mut applied_constraints, fringe); + fringe = self.try_apply_foreign_predicate_atom(&mut applied_foreign_predicates, fringe); + + // Check if anything is applied + let does_apply_assign = applied_assigns.len() != num_applied_assigns; + let does_apply_constraint = applied_constraints.len() != num_applied_constraints; + let does_apply_fp = applied_foreign_predicates.len() != num_applied_fp; + + // If so, we continue in a loop + if does_apply_assign || does_apply_constraint || does_apply_fp { + num_applied_assigns = applied_assigns.len(); + num_applied_constraints = applied_constraints.len(); + num_applied_fp = applied_foreign_predicates.len(); + } else { + break; + } + } } // ==== Stage 4: Apply negative atoms ==== @@ -611,6 +679,26 @@ impl<'a> QueryPlanContext<'a> { }; } + // ==== Stage 5: Apply new entity assigns ==== + let new_entity_assigns = self + .assigns + .iter() + .filter(|a| a.right.is_new_expr()) + .cloned() + .collect::>(); + if new_entity_assigns.len() > 0 { + let all_bounded_variables = new_entity_assigns + .iter() 
+ .map(|a| &a.left) + .chain(fringe.bounded_vars.iter()) + .cloned() + .collect(); + fringe = Plan { + bounded_vars: all_bounded_variables, + ram_node: HighRamNode::project(fringe, new_entity_assigns), + }; + } + fringe } @@ -660,9 +748,14 @@ impl State { vec![] } else { let mut next_states: Vec = vec![]; - for (id, atom) in ctx.pos_atoms.iter().enumerate().filter(|(i, _)| !self.visited_atoms.contains(i)) { + for (id, atom) in ctx + .pos_atoms + .iter() + .enumerate() + .filter(|(i, _)| !self.visited_atoms.contains(i)) + { for set in self.visited_atoms.iter().powerset() { - let all_bounded_args = ctx.bounded_args_from_pos_atoms_set(&set); + let all_bounded_args = ctx.bounded_args_from_pos_atoms_set(&set, false); let bounded_vars = atom .variable_args() .filter(|a| all_bounded_args.contains(a)) @@ -798,6 +891,11 @@ impl HighRamNode { Self::Filter(Box::new(p1), cs) } + /// Create a new Project high level ram node + pub fn project(p1: Plan, assigns: Vec) -> Self { + Self::Project(Box::new(p1), assigns) + } + /// Create a new JOIN high level ram node pub fn join(p1: Plan, p2: Plan) -> Self { Self::Join(Box::new(p1), Box::new(p2)) diff --git a/core/src/compiler/back/scc.rs b/core/src/compiler/back/scc.rs index 153f003..0b5f6a6 100644 --- a/core/src/compiler/back/scc.rs +++ b/core/src/compiler/back/scc.rs @@ -244,29 +244,45 @@ impl Program { // Then add all the dependencies by going through rules for rule in &self.rules { + // Collect the head and functor predicates occurred in the rule let head_predicate = rule.head_predicate(); + let functor_predicates = rule.collect_new_expr_functors().collect::>(); + + // Functor predicates also depend on head + for functor_predicate in &functor_predicates { + graph.add_dependency(functor_predicate, head_predicate, E::Positive); + graph.add_dependency(head_predicate, functor_predicate, E::Positive); + } + + // A recording dependency + let mut record_dependency = |pred, edge_type: E| { + graph.add_dependency(head_predicate, pred, 
edge_type.clone()); + for functor_predicate in &functor_predicates { + graph.add_dependency(functor_predicate, pred, edge_type.clone()); + } + }; for atom in rule.body_literals() { match atom { Literal::Atom(a) => { let atom_predicate = &a.predicate; if !self.predicate_registry.contains(atom_predicate) { - graph.add_dependency(head_predicate, atom_predicate, E::Positive); + record_dependency(atom_predicate, E::Positive); } } Literal::NegAtom(a) => { let atom_predicate = &a.atom.predicate; if !self.predicate_registry.contains(atom_predicate) { - graph.add_dependency(head_predicate, atom_predicate, E::Negative); + record_dependency(atom_predicate, E::Negative); } } Literal::Reduce(r) => { let reduce_predicate = &r.body_formula.predicate; - graph.add_dependency(head_predicate, reduce_predicate, E::Aggregation); + record_dependency(reduce_predicate, E::Aggregation); // Add group by predicate also as an aggregation dependency if let Some(group_by_atom) = &r.group_by_formula { let group_by_predicate = &group_by_atom.predicate; - graph.add_dependency(head_predicate, group_by_predicate, E::Aggregation); + record_dependency(group_by_predicate, E::Aggregation); } } _ => {} diff --git a/core/src/compiler/back/var_tuple.rs b/core/src/compiler/back/var_tuple.rs index e92eeec..c80b913 100644 --- a/core/src/compiler/back/var_tuple.rs +++ b/core/src/compiler/back/var_tuple.rs @@ -185,6 +185,10 @@ impl VariableTuple { function: c.function.clone(), args: c.args.iter().map(|a| self.term_to_ram_expr(a).unwrap()).collect(), }), + AssignExpr::New(n) => Expr::New(crate::common::expr::NewExpr { + functor: n.functor.clone(), + args: n.args.iter().map(|a| self.term_to_ram_expr(a).unwrap()).collect(), + }), } } } diff --git a/core/src/compiler/front/analysis.rs b/core/src/compiler/front/analysis.rs index 65e89e1..6676023 100644 --- a/core/src/compiler/front/analysis.rs +++ b/core/src/compiler/front/analysis.rs @@ -15,6 +15,7 @@ pub struct Analysis { pub aggregation_analysis: 
AggregationAnalysis, pub character_literal_analysis: CharacterLiteralAnalysis, pub constant_decl_analysis: ConstantDeclAnalysis, + pub adt_analysis: AlgebraicDataTypeAnalysis, pub head_relation_analysis: HeadRelationAnalysis, pub type_inference: TypeInference, pub boundness_analysis: BoundnessAnalysis, @@ -23,10 +24,7 @@ pub struct Analysis { impl Analysis { /// Create a new front IR analysis object - pub fn new( - function_registry: &ForeignFunctionRegistry, - predicate_registry: &ForeignPredicateRegistry, - ) -> Self { + pub fn new(function_registry: &ForeignFunctionRegistry, predicate_registry: &ForeignPredicateRegistry) -> Self { Self { invalid_constant: InvalidConstantAnalyzer::new(), invalid_wildcard: InvalidWildcardAnalyzer::new(), @@ -36,6 +34,7 @@ impl Analysis { aggregation_analysis: AggregationAnalysis::new(), character_literal_analysis: CharacterLiteralAnalysis::new(), constant_decl_analysis: ConstantDeclAnalysis::new(), + adt_analysis: AlgebraicDataTypeAnalysis::new(), head_relation_analysis: HeadRelationAnalysis::new(predicate_registry), type_inference: TypeInference::new(function_registry, predicate_registry), boundness_analysis: BoundnessAnalysis::new(predicate_registry), @@ -51,6 +50,7 @@ impl Analysis { &mut self.aggregation_analysis, &mut self.character_literal_analysis, &mut self.constant_decl_analysis, + &mut self.adt_analysis, &mut self.invalid_constant, &mut self.invalid_wildcard, ); @@ -58,9 +58,12 @@ impl Analysis { } pub fn process_items(&mut self, items: &Vec) { + // Prepare the type inference module self .type_inference .extend_constant_types(self.constant_decl_analysis.compute_typed_constants()); + + // Create the analyzers and walk the items let mut analyzers = ( &mut self.head_relation_analysis, &mut self.type_inference, @@ -85,6 +88,7 @@ impl Analysis { error_ctx.extend(&mut self.aggregation_analysis.errors); error_ctx.extend(&mut self.character_literal_analysis.errors); error_ctx.extend(&mut self.constant_decl_analysis.errors); + 
error_ctx.extend(&mut self.adt_analysis.errors); error_ctx.extend(&mut self.head_relation_analysis.errors); error_ctx.extend(&mut self.type_inference.errors); error_ctx.extend(&mut self.boundness_analysis.errors); diff --git a/core/src/compiler/front/analyzers/aggregation.rs b/core/src/compiler/front/analyzers/aggregation.rs index 8a1637b..1fb1e4a 100644 --- a/core/src/compiler/front/analyzers/aggregation.rs +++ b/core/src/compiler/front/analyzers/aggregation.rs @@ -47,15 +47,11 @@ impl NodeVisitor for AggregationAnalysis { // Check the binding variables if reduce.bindings().is_empty() { match &reduce.operator().node { - ReduceOperatorNode::Exists - | ReduceOperatorNode::Forall - | ReduceOperatorNode::Unknown(_) => {} - r => { - self.errors.push(AggregationAnalysisError::EmptyBinding { - agg: r.to_string(), - loc: reduce.location().clone(), - }) - } + ReduceOperatorNode::Exists | ReduceOperatorNode::Forall | ReduceOperatorNode::Unknown(_) => {} + r => self.errors.push(AggregationAnalysisError::EmptyBinding { + agg: r.to_string(), + loc: reduce.location().clone(), + }), } } } diff --git a/core/src/compiler/front/analyzers/algebraic_data_type.rs b/core/src/compiler/front/analyzers/algebraic_data_type.rs new file mode 100644 index 0000000..9a0ab4a --- /dev/null +++ b/core/src/compiler/front/analyzers/algebraic_data_type.rs @@ -0,0 +1,98 @@ +use std::collections::*; + +use crate::compiler::front::*; + +#[derive(Clone, Debug)] +pub struct AlgebraicDataTypeAnalysis { + pub errors: Vec, + pub adt_variants: HashMap, + pub adt_types: HashSet, +} + +#[derive(Clone, Debug)] +pub struct VariantInfo { + pub belongs_to_type: Identifier, + pub name: Identifier, + pub location: AstNodeLocation, + pub args: Vec, +} + +impl AlgebraicDataTypeAnalysis { + pub fn new() -> Self { + Self { + errors: Vec::new(), + adt_variants: HashMap::new(), + adt_types: HashSet::new(), + } + } +} + +impl NodeVisitor for AlgebraicDataTypeAnalysis { + fn visit_algebraic_data_type_decl(&mut self, decl: 
&ast::AlgebraicDataTypeDecl) { + // Add the type to the set of adt types + self.adt_types.insert(decl.name().to_string()); + + // And then declare all the constant types + let mut visited_names: HashMap<&str, &AstNodeLocation> = HashMap::new(); + for variant in decl.iter_variants() { + // First check if the variant has already being declared + if let Some(loc) = visited_names.get(variant.name()) { + self.errors.push(ADTError::DuplicateADTVariant { + constructor: variant.name().to_string(), + first_declared: (*loc).clone(), + duplicated: variant.location().clone(), + }); + } else { + visited_names.insert(variant.name(), variant.location()); + } + + // Then check if the variant has occurred previously + if let Some(info) = self.adt_variants.get(variant.name()) { + self.errors.push(ADTError::DuplicateADTVariant { + constructor: variant.name().to_string(), + first_declared: info.location.clone(), + duplicated: variant.location().clone(), + }); + } + + // If everything is well, store the ADT variant + let info = VariantInfo { + belongs_to_type: decl.name_identifier().clone(), + name: variant.name_identifier().clone(), + location: variant.location().clone(), + args: variant.args().clone(), + }; + self.adt_variants.insert(variant.name().to_string(), info); + } + } +} + +#[derive(Clone, Debug)] +pub enum ADTError { + DuplicateADTVariant { + constructor: String, + first_declared: AstNodeLocation, + duplicated: AstNodeLocation, + }, +} + +impl FrontCompileErrorTrait for ADTError { + fn error_type(&self) -> FrontCompileErrorType { + FrontCompileErrorType::Error + } + + fn report(&self, src: &Sources) -> String { + match self { + Self::DuplicateADTVariant { + constructor, + first_declared, + duplicated, + } => { + format!( + "duplicated Algebraic Data Type variant `{}`. 
It is first declared here:\n{}\nwhile we find a duplicated declaration here:\n{}", + constructor, first_declared.report(src), duplicated.report(src) + ) + } + } + } +} diff --git a/core/src/compiler/front/analyzers/boundness/boundness_analysis.rs b/core/src/compiler/front/analyzers/boundness/boundness_analysis.rs index 3570815..a79f6b2 100644 --- a/core/src/compiler/front/analyzers/boundness/boundness_analysis.rs +++ b/core/src/compiler/front/analyzers/boundness/boundness_analysis.rs @@ -39,7 +39,7 @@ impl BoundnessAnalysis { // Make sure the demand attribute is affecting boundness analysis, // through some of the head expressions being bounded let bounded_exprs = if let Some(head_atom) = rule.head().atom() { - if let Some((pattern, _)) = demand_attrs.get(head_atom.predicate()) { + if let Some((pattern, _)) = demand_attrs.get(&head_atom.predicate()) { head_atom .iter_arguments() .zip(pattern.chars()) diff --git a/core/src/compiler/front/analyzers/boundness/context.rs b/core/src/compiler/front/analyzers/boundness/context.rs index 9573844..89ee8d9 100644 --- a/core/src/compiler/front/analyzers/boundness/context.rs +++ b/core/src/compiler/front/analyzers/boundness/context.rs @@ -67,6 +67,11 @@ impl DisjunctionContext { } Formula::Atom(a) => vec![ConjunctionContext::from_atom(a)], Formula::NegAtom(a) => vec![ConjunctionContext::from_neg_atom(a)], + Formula::Case(_) => { + panic!( + "Unexpected `case` visited during boundness analysis; case should be rewritten by previous transformations" + ) + } Formula::Constraint(a) => vec![ConjunctionContext::from_constraint(a)], Formula::Reduce(r) => vec![ConjunctionContext::from_reduce(r)], Formula::ForallExistsReduce(_) => { @@ -294,27 +299,19 @@ impl AggregationContext { fn collect_vars_in_head(head: &RuleHead) -> Vec<(String, Loc)> { match &head.node { RuleHeadNode::Atom(atom) => collect_vars_in_atom(atom), - RuleHeadNode::Disjunction(d) => d.iter().map(collect_vars_in_atom).flatten().collect(), + RuleHeadNode::Conjunction(c) 
=> c.iter().flat_map(collect_vars_in_atom).collect(), + RuleHeadNode::Disjunction(d) => d.iter().flat_map(collect_vars_in_atom).collect(), } } fn collect_vars_in_atom(atom: &Atom) -> Vec<(String, Loc)> { - atom.iter_arguments().map(collect_vars_in_expr).flatten().collect() -} - -fn collect_vars_in_expr(expr: &Expr) -> Vec<(String, Loc)> { - match expr { - Expr::Binary(b) => vec![collect_vars_in_expr(b.op1()), collect_vars_in_expr(b.op2())].concat(), - Expr::Unary(u) => collect_vars_in_expr(u.op1()), - Expr::Variable(v) => vec![(v.name().to_string(), v.location().clone())], - Expr::Constant(_) => vec![], - Expr::Wildcard(_) => vec![], - Expr::IfThenElse(i) => vec![ - collect_vars_in_expr(i.cond()), - collect_vars_in_expr(i.then_br()), - collect_vars_in_expr(i.else_br()), - ] - .concat(), - Expr::Call(c) => c.iter_args().map(|a| collect_vars_in_expr(a)).concat(), - } + atom + .iter_arguments() + .flat_map(|arg| { + arg + .collect_used_variables() + .into_iter() + .map(|v| (v.name().to_string(), v.location().clone())) + }) + .collect() } diff --git a/core/src/compiler/front/analyzers/boundness/foreign.rs b/core/src/compiler/front/analyzers/boundness/foreign.rs index 002389b..5a4e374 100644 --- a/core/src/compiler/front/analyzers/boundness/foreign.rs +++ b/core/src/compiler/front/analyzers/boundness/foreign.rs @@ -13,7 +13,7 @@ impl ForeignPredicateBindings { } pub fn add(&mut self, fp: &F) { - self.bindings.insert(fp.name(), fp.binding_pattern()); + self.bindings.insert(fp.internal_name(), fp.binding_pattern()); } pub fn get(&self, name: &str) -> Option<&BindingPattern> { diff --git a/core/src/compiler/front/analyzers/boundness/local.rs b/core/src/compiler/front/analyzers/boundness/local.rs index abcf5ad..46dea4d 100644 --- a/core/src/compiler/front/analyzers/boundness/local.rs +++ b/core/src/compiler/front/analyzers/boundness/local.rs @@ -18,9 +18,29 @@ pub struct LocalBoundnessAnalysisContext<'a> { impl<'a> NodeVisitor for LocalBoundnessAnalysisContext<'a> { fn 
visit_atom(&mut self, atom: &Atom) { - if let Some(binding) = self.foreign_predicate_bindings.get(atom.predicate()) { - let bounded = atom.iter_arguments().enumerate().filter_map(|(i, a)| if binding[i].is_bound() { Some(a.location().clone()) } else { None } ).collect(); - let to_bound = atom.iter_arguments().enumerate().filter_map(|(i, a)| if binding[i].is_free() { Some(a.location().clone()) } else { None } ).collect(); + if let Some(binding) = self.foreign_predicate_bindings.get(&atom.predicate()) { + let bounded = atom + .iter_arguments() + .enumerate() + .filter_map(|(i, a)| { + if binding[i].is_bound() { + Some(a.location().clone()) + } else { + None + } + }) + .collect(); + let to_bound = atom + .iter_arguments() + .enumerate() + .filter_map(|(i, a)| { + if binding[i].is_free() { + Some(a.location().clone()) + } else { + None + } + }) + .collect(); let dep = BoundnessDependency::ForeignPredicateArgs(bounded, to_bound); self.dependencies.push(dep); } else { diff --git a/core/src/compiler/front/analyzers/character_literal.rs b/core/src/compiler/front/analyzers/character_literal.rs index 4773594..5c2b960 100644 --- a/core/src/compiler/front/analyzers/character_literal.rs +++ b/core/src/compiler/front/analyzers/character_literal.rs @@ -16,21 +16,17 @@ impl CharacterLiteralAnalysis { } impl NodeVisitor for CharacterLiteralAnalysis { - fn visit_constant(&mut self, c: &Constant) { - match &c.node { - ConstantNode::Char(s) => { - let loc = c.location().clone(); - if s.len() == 1 { - // OK - } else if s.len() == 0 { - self.errors.push(CharacterLiteralAnalysisError::EmptyCharacter { loc }) - } else { - self - .errors - .push(CharacterLiteralAnalysisError::InvalidCharacter { loc }) - } - } - _ => {} + fn visit_constant_char(&mut self, c: &ConstantChar) { + let s = c.character_string(); + let loc = c.location().clone(); + if s.len() == 1 { + // OK + } else if s.len() == 0 { + self.errors.push(CharacterLiteralAnalysisError::EmptyCharacter { loc }) + } else { + self + 
.errors + .push(CharacterLiteralAnalysisError::InvalidCharacter { loc }) } } } diff --git a/core/src/compiler/front/analyzers/constant_decl.rs b/core/src/compiler/front/analyzers/constant_decl.rs index 48d31d1..17b6eac 100644 --- a/core/src/compiler/front/analyzers/constant_decl.rs +++ b/core/src/compiler/front/analyzers/constant_decl.rs @@ -15,6 +15,7 @@ use super::super::*; pub struct ConstantDeclAnalysis { pub variables: HashMap, Constant)>, pub variable_use: HashMap, + pub entity_facts: Vec, pub errors: Vec, } @@ -24,6 +25,7 @@ impl ConstantDeclAnalysis { Self { variables: HashMap::new(), variable_use: HashMap::new(), + entity_facts: Vec::new(), errors: vec![], } } @@ -82,13 +84,13 @@ impl ConstantDeclAnalysis { } } else { // If there is no previous max, then directly give it `i`. - return Ok(i) + return Ok(i); } } _ => { // We don't care other cases } - } + }, _ => {} }; @@ -111,7 +113,11 @@ impl ConstantDeclAnalysis { // Then store the variable into the storage self.variables.insert( member.name().to_string(), - (member.location().clone(), Some(Type::usize()), Constant::integer(id as i64)) + ( + member.location().clone(), + Some(Type::usize()), + Constant::integer(id as i64), + ), ); Ok(()) } @@ -145,11 +151,35 @@ impl NodeVisitor for ConstantDeclAnalysis { second_decl: ca.location().clone(), }) } else { - // Then store the variable into the storage - self.variables.insert( - ca.name().to_string(), - (ca.location().clone(), ca.ty().cloned(), ca.value().clone()), - ); + let entity = ca.value(); + + // Then we make sure that the entity is indeed a constant + if let Some(var_loc) = entity.get_first_non_constant_location(&|v| self.variables.contains_key(v.name())) { + self.errors.push(ConstantDeclError::EntityContainsNonConstant { + const_decl_loc: ca.location().clone(), + var_loc: var_loc.clone(), + }) + } else { + // Annotate the type of the entity + let ty = if entity.is_constant() { + ca.ty().cloned() + } else { + Some(Type::entity()) + }; + + // Process the 
entity into a set of entity facts and one final constant value + let (entity_facts, constant) = + entity.to_facts_with_constant_variables(|v| self.variables.get(v.name()).map(|(_, _, c)| c.clone())); + + // Extend the entity facts with the storage + self.entity_facts.extend(entity_facts); + + // Store the variable + self.variables.insert( + ca.identifier().name().to_string(), + (ca.location().clone(), ty, constant), + ); + } } } @@ -230,6 +260,10 @@ pub enum ConstantDeclError { id: i64, loc: Loc, }, + EntityContainsNonConstant { + const_decl_loc: Loc, + var_loc: Loc, + }, } impl FrontCompileErrorTrait for ConstantDeclError { @@ -262,7 +296,18 @@ impl FrontCompileErrorTrait for ConstantDeclError { format!("unknown variable `{}`:\n{}", name, loc.report(src)) } Self::EnumIDAlreadyAssigned { curr_name, id, loc } => { - format!("the enum ID `{}` for variant `{}` has already been assigned\n{}", id, curr_name, loc.report(src)) + format!( + "the enum ID `{}` for variant `{}` has already been assigned\n{}", + id, + curr_name, + loc.report(src) + ) + } + Self::EntityContainsNonConstant { var_loc, .. 
} => { + format!( + "non-constant expression found in constant entity:\n{}", + var_loc.report(src) + ) } } } diff --git a/core/src/compiler/front/analyzers/demand_attr.rs b/core/src/compiler/front/analyzers/demand_attr.rs index ff652c5..1ab53af 100644 --- a/core/src/compiler/front/analyzers/demand_attr.rs +++ b/core/src/compiler/front/analyzers/demand_attr.rs @@ -39,10 +39,12 @@ impl DemandAttributeAnalysis { pub fn set_disjunctive(&mut self, pred: &String, loc: &AstNodeLocation) { if self.demand_attrs.contains_key(pred) { - self.errors.push(DemandAttributeError::DisjunctivePredicateWithDemandAttribute { - pred: pred.clone(), - loc: loc.clone(), - }); + self + .errors + .push(DemandAttributeError::DisjunctivePredicateWithDemandAttribute { + pred: pred.clone(), + loc: loc.clone(), + }); } else { self.disjunctive_predicates.insert(pred.clone()); } @@ -51,28 +53,31 @@ impl DemandAttributeAnalysis { pub fn process_attribute(&mut self, pred: &str, attr: &Attribute) { // Check if the predicate occurs in a disjunctive head if self.disjunctive_predicates.contains(pred) { - self.errors.push(DemandAttributeError::DisjunctivePredicateWithDemandAttribute { - pred: pred.to_string(), - loc: attr.location().clone(), - }); + self + .errors + .push(DemandAttributeError::DisjunctivePredicateWithDemandAttribute { + pred: pred.to_string(), + loc: attr.location().clone(), + }); } // Check the pattern if attr.name() == "demand" { if attr.num_pos_args() == 1 { let value = attr.pos_arg(0).unwrap(); - match &value.node { - ConstantNode::String(s) => { - if is_valid_demand_pattern(s) { + match value.as_constant().map(|v| &v.node) { + Some(ConstantNode::String(s)) => { + let string_content = s.string(); + if is_valid_demand_pattern(string_content) { if let Some((p, l)) = self.demand_attrs.get(pred) { - if p != s { + if p != string_content { self.errors.push(DemandAttributeError::ConflictingPattern { first_loc: l.clone(), second_loc: value.location().clone(), }); } } else { - let attr = 
(s.clone(), value.location().clone()); + let attr = (string_content.to_string(), value.location().clone()); self.demand_attrs.insert(pred.to_string(), attr); } } else { @@ -81,8 +86,12 @@ impl DemandAttributeAnalysis { }); } } - _ => self.errors.push(DemandAttributeError::InvalidArgumentType { - found: value.kind().to_string(), + Some(c) => self.errors.push(DemandAttributeError::InvalidArgumentType { + found: c.kind().to_string(), + loc: value.location().clone(), + }), + None => self.errors.push(DemandAttributeError::InvalidArgumentType { + found: "list".to_string(), loc: value.location().clone(), }), } @@ -113,14 +122,14 @@ impl NodeVisitor for DemandAttributeAnalysis { fn visit_rule_decl(&mut self, rule_decl: &ast::RuleDecl) { if rule_decl.rule().head().is_disjunction() { for predicate in rule_decl.rule().head().iter_predicates() { - self.set_disjunctive(predicate, rule_decl.rule().head().location()); + self.set_disjunctive(&predicate, rule_decl.rule().head().location()); return; // early stopping because this is an error } } // Otherwise, we add the demand attribute if let Some(atom) = rule_decl.rule().head().atom() { - self.process_attributes(atom.predicate(), rule_decl.attributes()); + self.process_attributes(&atom.predicate(), rule_decl.attributes()); } } } @@ -210,7 +219,11 @@ impl FrontCompileErrorTrait for DemandAttributeError { format!("Invalid demand pattern\n{}", loc.report(src)) } Self::DisjunctivePredicateWithDemandAttribute { pred, loc } => { - format!("The predicate `{}` being annotated by `demand` but occurs in a disjunctive rule head\n{}", pred, loc.report(src)) + format!( + "The predicate `{}` being annotated by `demand` but occurs in a disjunctive rule head\n{}", + pred, + loc.report(src) + ) } } } diff --git a/core/src/compiler/front/analyzers/head_relation.rs b/core/src/compiler/front/analyzers/head_relation.rs index de605a9..79fc2b3 100644 --- a/core/src/compiler/front/analyzers/head_relation.rs +++ 
b/core/src/compiler/front/analyzers/head_relation.rs @@ -14,7 +14,10 @@ pub struct HeadRelationAnalysis { impl HeadRelationAnalysis { pub fn new(foreign_predicate_registry: &ForeignPredicateRegistry) -> Self { - let declared_relations = foreign_predicate_registry.iter().map(|(_, p)| p.name().to_string()).collect(); + let declared_relations = foreign_predicate_registry + .iter() + .map(|(_, p)| p.name().to_string()) + .collect(); Self { errors: vec![], used_relations: HashMap::new(), diff --git a/core/src/compiler/front/analyzers/hidden_relation.rs b/core/src/compiler/front/analyzers/hidden_relation.rs index 7efe16e..3e9c872 100644 --- a/core/src/compiler/front/analyzers/hidden_relation.rs +++ b/core/src/compiler/front/analyzers/hidden_relation.rs @@ -33,16 +33,16 @@ impl NodeVisitor for HiddenRelationAnalysis { } fn visit_constant_set_decl(&mut self, decl: &ast::ConstantSetDecl) { - self.process_attributes(decl.predicate(), decl.attributes()) + self.process_attributes(&decl.predicate(), decl.attributes()) } fn visit_fact_decl(&mut self, decl: &ast::FactDecl) { - self.process_attributes(decl.predicate(), decl.attributes()) + self.process_attributes(&decl.predicate(), decl.attributes()) } fn visit_rule_decl(&mut self, rule_decl: &RuleDecl) { for predicate in rule_decl.rule().head().iter_predicates() { - self.process_attributes(predicate, rule_decl.attributes()) + self.process_attributes(&predicate, rule_decl.attributes()) } } } diff --git a/core/src/compiler/front/analyzers/input_files.rs b/core/src/compiler/front/analyzers/input_files.rs index f0124ba..2138801 100644 --- a/core/src/compiler/front/analyzers/input_files.rs +++ b/core/src/compiler/front/analyzers/input_files.rs @@ -27,12 +27,12 @@ impl InputFilesAnalysis { self.input_files.get(relation) } - pub fn process_deliminator(&self, attr_arg: Option<&Constant>) -> Result, InputFilesError> { + pub fn process_deliminator(&self, attr_arg: Option<&AttributeValue>) -> Result, InputFilesError> { match attr_arg { - 
Some(v) => match &v.node { - ConstantNode::String(s) => { + Some(v) => match v.as_string() { + Some(s) => { if s.len() == 1 { - let c = s.chars().next().unwrap(); + let c = s.chars().next().unwrap(); // NOTE: Unwrap is ok since we know the length is 1 if c.is_ascii() { Ok(Some(c as u8)) } else { @@ -54,10 +54,10 @@ impl InputFilesAnalysis { } } - pub fn process_has_header(&self, attr_arg: Option<&Constant>) -> Result, InputFilesError> { + pub fn process_has_header(&self, attr_arg: Option<&AttributeValue>) -> Result, InputFilesError> { match attr_arg { - Some(v) => match &v.node { - ConstantNode::Boolean(b) => Ok(Some(*b)), + Some(v) => match v.as_bool() { + Some(b) => Ok(Some(*b)), _ => Err(InputFilesError::HasHeaderNotBoolean { loc: v.location().clone(), }), @@ -66,10 +66,10 @@ impl InputFilesAnalysis { } } - pub fn process_has_probability(&self, attr_arg: Option<&Constant>) -> Result, InputFilesError> { + pub fn process_has_probability(&self, attr_arg: Option<&AttributeValue>) -> Result, InputFilesError> { match attr_arg { - Some(v) => match &v.node { - ConstantNode::Boolean(b) => Ok(Some(*b)), + Some(v) => match v.as_bool() { + Some(b) => Ok(Some(*b)), _ => Err(InputFilesError::HasProbabilityNotBoolean { loc: v.location().clone(), }), @@ -78,22 +78,90 @@ impl InputFilesAnalysis { } } + pub fn process_keys(&self, attr_arg: Option<&AttributeValue>) -> Result>, InputFilesError> { + match attr_arg { + Some(v) => match &v.node { + AttributeValueNode::Constant(c) => match c.as_string() { + Some(s) => Ok(Some(vec![s.clone()])), + None => Err(InputFilesError::KeysNotString { + loc: v.location().clone(), + }), + }, + AttributeValueNode::List(l) => { + // Make sure that the list is not empty + if l.is_empty() { + return Err(InputFilesError::KeysEmptyList { + loc: v.location().clone(), + }); + } + + // Extract each element of the key + let mut keys = Vec::new(); + for arg in l { + match arg.as_string() { + Some(s) => keys.push(s.clone()), + None => { + return 
Err(InputFilesError::KeysNotString { + loc: arg.location().clone(), + }) + } + } + } + Ok(Some(keys)) + } + AttributeValueNode::Tuple(_) => { + Err(InputFilesError::KeysNotString { + loc: v.location().clone(), + }) + }, + }, + None => Ok(None), + } + } + + fn process_fields(&self, attr_arg: Option<&AttributeValue>) -> Result>, InputFilesError> { + match attr_arg { + Some(v) => match &v.node { + AttributeValueNode::List(l) => { + // Extract each element of the key + let mut keys = Vec::new(); + for arg in l { + match arg.as_string() { + Some(s) => keys.push(s.clone()), + None => { + return Err(InputFilesError::FieldsNotListOfString { + loc: arg.location().clone(), + }) + } + } + } + Ok(Some(keys)) + } + _ => Err(InputFilesError::FieldsNotListOfString { + loc: v.location().clone(), + }), + }, + None => Ok(None), + } + } + /// Assumption: Assumes attr is of `file` pub fn process_attr(&self, attr: &Attribute) -> Result { - if attr.num_pos_args() > 0 { - let arg = attr.pos_arg(0).unwrap(); - match &arg.node { - ConstantNode::String(s) => { + if let Some(arg) = attr.pos_arg(0) { + match arg.as_string() { + Some(s) => { let path = PathBuf::from(s); match path.extension() { Some(s) if s == "csv" => { let deliminator = self.process_deliminator(attr.kw_arg("deliminator"))?; - let has_header = self.process_has_header(attr.kw_arg("has_header"))?; + let has_header = self.process_has_header(attr.kw_arg("has_header").or_else(|| attr.kw_arg("header")))?; let has_probability = self.process_has_probability(attr.kw_arg("has_probability"))?; - let input_file = InputFile::csv_with_options(path, deliminator, has_header, has_probability); + let keys = self.process_keys(attr.kw_arg("keys"))?; + let fields = self.process_fields(attr.kw_arg("fields"))?; + let input_file = + InputFile::csv_with_options(path, deliminator, has_header, has_probability, keys, fields); Ok(input_file) } - Some(s) if s == "txt" => Ok(InputFile::Txt(path)), Some(s) => Err(InputFilesError::UnknownExtension { ext: 
String::from(s.to_str().unwrap()), attr_arg_loc: arg.location().clone(), @@ -168,6 +236,15 @@ pub enum InputFilesError { DeliminatorNotASCII { loc: AstNodeLocation, }, + KeysEmptyList { + loc: AstNodeLocation, + }, + KeysNotString { + loc: AstNodeLocation, + }, + FieldsNotListOfString { + loc: AstNodeLocation, + }, } impl FrontCompileErrorTrait for InputFilesError { @@ -224,6 +301,18 @@ impl FrontCompileErrorTrait for InputFilesError { Self::DeliminatorNotASCII { loc } => { format!("`deliminator` attribute is not an ASCII character\n{}", loc.report(src)) } + Self::KeysEmptyList { loc } => { + format!("`keys` attribute is an empty list\n{}", loc.report(src)) + } + Self::KeysNotString { loc } => { + format!( + "`keys` attribute is not a string or list of strings\n{}", + loc.report(src) + ) + } + Self::FieldsNotListOfString { loc } => { + format!("`fields` attribute is not a list of strings\n{}", loc.report(src)) + } } } } diff --git a/core/src/compiler/front/analyzers/invalid_constant.rs b/core/src/compiler/front/analyzers/invalid_constant.rs index f2cd34f..31baafc 100644 --- a/core/src/compiler/front/analyzers/invalid_constant.rs +++ b/core/src/compiler/front/analyzers/invalid_constant.rs @@ -12,9 +12,21 @@ impl InvalidConstantAnalyzer { } impl NodeVisitor for InvalidConstantAnalyzer { - fn visit_constant(&mut self, constant: &Constant) { + fn visit_constant_datetime(&mut self, constant: &ConstantDateTime) { match &constant.node { - ConstantNode::Invalid(message) => { + Err(message) => { + self.errors.push(InvalidConstantError::InvalidConstant { + loc: constant.location().clone(), + message: message.clone(), + }); + } + _ => {} + } + } + + fn visit_constant_duration(&mut self, constant: &ConstantDuration) { + match &constant.node { + Err(message) => { self.errors.push(InvalidConstantError::InvalidConstant { loc: constant.location().clone(), message: message.clone(), @@ -27,10 +39,7 @@ impl NodeVisitor for InvalidConstantAnalyzer { #[derive(Clone, Debug)] pub enum 
InvalidConstantError { - InvalidConstant { - loc: AstNodeLocation, - message: String, - }, + InvalidConstant { loc: AstNodeLocation, message: String }, } impl FrontCompileErrorTrait for InvalidConstantError { diff --git a/core/src/compiler/front/analyzers/mod.rs b/core/src/compiler/front/analyzers/mod.rs index 96b4a35..72faf60 100644 --- a/core/src/compiler/front/analyzers/mod.rs +++ b/core/src/compiler/front/analyzers/mod.rs @@ -1,4 +1,5 @@ pub mod aggregation; +pub mod algebraic_data_type; pub mod boundness; pub mod character_literal; pub mod constant_decl; @@ -12,6 +13,7 @@ pub mod output_files; pub mod type_inference; pub use aggregation::AggregationAnalysis; +pub use algebraic_data_type::AlgebraicDataTypeAnalysis; pub use boundness::BoundnessAnalysis; pub use character_literal::CharacterLiteralAnalysis; pub use constant_decl::ConstantDeclAnalysis; @@ -26,6 +28,7 @@ pub use type_inference::TypeInference; pub mod errors { pub use super::aggregation::AggregationAnalysisError; + pub use super::algebraic_data_type::ADTError; pub use super::boundness::BoundnessAnalysisError; pub use super::constant_decl::ConstantDeclError; pub use super::demand_attr::DemandAttributeError; diff --git a/core/src/compiler/front/analyzers/output_files.rs b/core/src/compiler/front/analyzers/output_files.rs index bb1456e..cd3eca3 100644 --- a/core/src/compiler/front/analyzers/output_files.rs +++ b/core/src/compiler/front/analyzers/output_files.rs @@ -22,10 +22,10 @@ impl OutputFilesAnalysis { self.output_files.get(relation) } - pub fn process_deliminator(&self, attr_arg: Option<&Constant>) -> Result, OutputFilesError> { + pub fn process_deliminator(&self, attr_arg: Option<&AttributeValue>) -> Result, OutputFilesError> { match attr_arg { - Some(v) => match &v.node { - ConstantNode::String(s) => { + Some(v) => match v.as_string() { + Some(s) => { if s.len() == 1 { let c = s.chars().next().unwrap(); if c.is_ascii() { @@ -50,10 +50,9 @@ impl OutputFilesAnalysis { } pub fn 
process_attribute(&self, attr: &Attribute) -> Result { - if attr.num_pos_args() > 0 { - let arg = attr.pos_arg(0).unwrap(); - match &arg.node { - ConstantNode::String(s) => { + if let Some(arg) = attr.pos_arg(0) { + match arg.as_string() { + Some(s) => { let path = PathBuf::from(s); match path.extension() { Some(s) if s == "csv" => { diff --git a/core/src/compiler/front/analyzers/type_inference/error.rs b/core/src/compiler/front/analyzers/type_inference/error.rs index c8f9ae7..ab6ea06 100644 --- a/core/src/compiler/front/analyzers/type_inference/error.rs +++ b/core/src/compiler/front/analyzers/type_inference/error.rs @@ -5,6 +5,9 @@ use super::*; #[derive(Clone, Debug)] pub enum TypeInferenceError { + UnknownRelation { + relation: String, + }, DuplicateTypeDecl { type_name: String, source_decl_loc: AstNodeLocation, @@ -15,6 +18,10 @@ pub enum TypeInferenceError { source_decl_loc: AstNodeLocation, duplicate_decl_loc: AstNodeLocation, }, + UnknownADTVariant { + predicate: String, + loc: AstNodeLocation, + }, InvalidSubtype { source_type: TypeNode, source_type_loc: AstNodeLocation, @@ -47,6 +54,18 @@ pub enum TypeInferenceError { actual: usize, loc: AstNodeLocation, }, + ADTVariantArityMismatch { + variant: String, + expected: usize, + actual: usize, + loc: AstNodeLocation, + }, + EntityTupleArityMismatch { + predicate: String, + expected: usize, + actual: usize, + source_loc: AstNodeLocation, + }, InvalidArgIndex { predicate: String, index: usize, @@ -134,6 +153,9 @@ pub enum TypeInferenceError { pred: String, loc: AstNodeLocation, }, + Internal { + error_string: String, + }, } impl TypeInferenceError { @@ -154,6 +176,9 @@ impl FrontCompileErrorTrait for TypeInferenceError { fn report(&self, src: &Sources) -> String { match self { + Self::UnknownRelation { relation } => { + format!("unknown relation `{relation}`") + } Self::DuplicateTypeDecl { type_name, source_decl_loc, @@ -174,6 +199,12 @@ impl FrontCompileErrorTrait for TypeInferenceError { predicate, 
source_decl_loc.report(src), duplicate_decl_loc.report(src) ) } + Self::UnknownADTVariant { predicate, loc } => { + format!( + "unknown algebraic data type variant `{predicate}`:\n{}", + loc.report(src) + ) + } Self::InvalidSubtype { source_type, source_type_loc, @@ -204,10 +235,7 @@ impl FrontCompileErrorTrait for TypeInferenceError { .. } => { format!( - "arity mismatch for relation `{}`. Expected {}, found {}:\n{}", - predicate, - expected, - actual, + "arity mismatch for relation `{predicate}`. Expected {expected}, found {actual}:\n{}", mismatch_loc.report(src) ) } @@ -219,6 +247,28 @@ impl FrontCompileErrorTrait for TypeInferenceError { loc.report(src) ) } + Self::ADTVariantArityMismatch { + variant, + expected, + actual, + loc, + } => { + format!( + "arity mismatch for algebraic data type variant `{variant}`. Expected {expected}, found {actual}:\n{}", + loc.report(src) + ) + } + Self::EntityTupleArityMismatch { + predicate, + expected, + actual, + source_loc, + } => { + format!( + "incorrect number of arguments in entity tuple for `{predicate}`. 
Expected {expected}, found {actual}:\n{}", + source_loc.report(src) + ) + } Self::InvalidArgIndex { predicate, index, @@ -237,7 +287,9 @@ impl FrontCompileErrorTrait for TypeInferenceError { } => { format!( "Invalid `{}`-th argument for foreign predicate `{}`:\n{}", - index, predicate, access_loc.report(src) + index, + predicate, + access_loc.report(src) ) } Self::ConstantSetArityMismatch { @@ -283,8 +335,14 @@ impl FrontCompileErrorTrait for TypeInferenceError { t1, t2, t1.location().report(src), t2.location().report(src) ) } - } - Self::CannotUnifyForeignPredicateArgument { pred, i, expected_ty, actual_ty, loc } => { + }, + Self::CannotUnifyForeignPredicateArgument { + pred, + i, + expected_ty, + actual_ty, + loc, + } => { format!( "cannot unify the type of {}-th argument of foreign predicate `{}`, expected type `{}`, found `{}`:\n{}", i, @@ -304,7 +362,12 @@ impl FrontCompileErrorTrait for TypeInferenceError { loc.report(src) ) } - Self::NoMatchingTripletRule { op1_ty, op2_ty, e_ty, location } => { + Self::NoMatchingTripletRule { + op1_ty, + op2_ty, + e_ty, + location, + } => { format!( "no matching rule found; two operands have type `{}` and `{}`, while the expression has type `{}`:\n{}", op1_ty, @@ -375,6 +438,7 @@ impl FrontCompileErrorTrait for TypeInferenceError { loc.report(src), ) } + Self::Internal { error_string } => error_string.clone(), } } } diff --git a/core/src/compiler/front/analyzers/type_inference/foreign_predicate.rs b/core/src/compiler/front/analyzers/type_inference/foreign_predicate.rs index 932aad7..09fc5a4 100644 --- a/core/src/compiler/front/analyzers/type_inference/foreign_predicate.rs +++ b/core/src/compiler/front/analyzers/type_inference/foreign_predicate.rs @@ -1,7 +1,7 @@ use std::collections::*; -use crate::common::value_type::*; use crate::common::foreign_predicate::*; +use crate::common::value_type::*; /// The type of a foreign predicate. /// Essentially a list of basic types. 
@@ -62,7 +62,7 @@ impl PredicateTypeRegistry { /// Add a new foreign predicate to the predicate type registry pub fn add_foreign_predicate(&mut self, p: &P) { - self.predicate_types.insert(p.name(), PredicateType::from(p)); + self.predicate_types.insert(p.internal_name(), PredicateType::from(p)); } /// Check if the registry contains a predicate diff --git a/core/src/compiler/front/analyzers/type_inference/local.rs b/core/src/compiler/front/analyzers/type_inference/local.rs index bc428fa..89e8406 100644 --- a/core/src/compiler/front/analyzers/type_inference/local.rs +++ b/core/src/compiler/front/analyzers/type_inference/local.rs @@ -43,9 +43,15 @@ impl LocalTypeInferenceContext { pub fn unify_atom_arities( &self, + predicate_registry: &PredicateTypeRegistry, inferred_relation_types: &mut HashMap, Loc)>, ) -> Result<(), TypeInferenceError> { for (pred, arities) in &self.atom_arities { + // Skip foreign predicates + if predicate_registry.contains_predicate(pred) { + continue; + } + // Make sure we have inferred relation types for the predicate if !inferred_relation_types.contains_key(pred) { let (arity, atom_loc) = &arities[0]; @@ -214,8 +220,9 @@ impl LocalTypeInferenceContext { .collect::>(); if !tys.is_empty() { let ty = TypeSet::unify_type_sets(tys)?; - let arg_types = &mut inferred_relation_types.get_mut(predicate).unwrap().0; - arg_types[*i] = ty; + if let Some((arg_types, _)) = inferred_relation_types.get_mut(predicate) { + arg_types[*i] = ty; + } } } @@ -474,4 +481,13 @@ impl NodeVisitor for LocalTypeInferenceContext { ); self.unifications.push(unif) } + + fn visit_new_expr(&mut self, n: &NewExpr) { + let unif = Unification::New( + n.functor_name().to_string(), + n.iter_args().map(|a| a.location().clone()).collect(), + n.location().clone(), + ); + self.unifications.push(unif) + } } diff --git a/core/src/compiler/front/analyzers/type_inference/operator_rules.rs b/core/src/compiler/front/analyzers/type_inference/operator_rules.rs index 6d36125..170b137 100644 
--- a/core/src/compiler/front/analyzers/type_inference/operator_rules.rs +++ b/core/src/compiler/front/analyzers/type_inference/operator_rules.rs @@ -24,9 +24,11 @@ lazy_static! { (DateTime, Duration, DateTime), (Duration, DateTime, DateTime), (Duration, Duration, Duration), + (Tensor, Tensor, Tensor), + (Tensor, F64, Tensor), + (F64, Tensor, Tensor), ] }; - pub static ref SUB_TYPING_RULES: Vec<(ValueType, ValueType, ValueType)> = { use ValueType::*; vec![ @@ -47,9 +49,11 @@ lazy_static! { (DateTime, Duration, DateTime), (DateTime, DateTime, Duration), (Duration, Duration, Duration), + (Tensor, Tensor, Tensor), + (Tensor, F64, Tensor), + (F64, Tensor, Tensor), ] }; - pub static ref MULT_TYPING_RULES: Vec<(ValueType, ValueType, ValueType)> = { use ValueType::*; vec![ @@ -69,9 +73,11 @@ lazy_static! { (F64, F64, F64), (Duration, I32, Duration), (I32, Duration, Duration), + (Tensor, Tensor, Tensor), + (Tensor, F64, Tensor), + (F64, Tensor, Tensor), ] }; - pub static ref DIV_TYPING_RULES: Vec<(ValueType, ValueType, ValueType)> = { use ValueType::*; vec![ @@ -92,7 +98,6 @@ lazy_static! { (Duration, I32, Duration), ] }; - pub static ref MOD_TYPING_RULES: Vec<(ValueType, ValueType, ValueType)> = { use ValueType::*; vec![ @@ -112,7 +117,6 @@ lazy_static! 
{ (F64, F64, F64), ] }; - pub static ref COMPARE_TYPING_RULES: Vec<(ValueType, ValueType)> = { use ValueType::*; vec![ diff --git a/core/src/compiler/front/analyzers/type_inference/type_inference.rs b/core/src/compiler/front/analyzers/type_inference/type_inference.rs index b6f6c95..d1dfab9 100644 --- a/core/src/compiler/front/analyzers/type_inference/type_inference.rs +++ b/core/src/compiler/front/analyzers/type_inference/type_inference.rs @@ -3,6 +3,7 @@ use std::collections::*; use crate::common::foreign_function::*; use crate::common::foreign_predicate::*; use crate::common::tuple_type::*; +use crate::common::value::*; use crate::common::value_type::*; use crate::compiler::front::*; @@ -10,30 +11,52 @@ use super::*; #[derive(Clone, Debug)] pub struct TypeInference { + /// A mapping from the custom type name to its inferred value type and decl location pub custom_types: HashMap, + + /// A mapping from referred constant variables' location to its declared type, if the type is specified in the constant declaration pub constant_types: HashMap, + + /// Foreign function types pub foreign_function_type_registry: FunctionTypeRegistry, + + /// Foreign predicate types pub foreign_predicate_type_registry: PredicateTypeRegistry, + + /// A mapping from relation name to its type declaration location pub relation_type_decl_loc: HashMap, + + /// A mapping from internal relation name to ADT variant name, e.g. 
`adt#Node` -> `Node` + pub adt_relations: HashMap)>, + + /// A mapping from relation name to its argument types `Vec` and the location `Loc` where such type is inferred pub inferred_relation_types: HashMap, Loc)>, + + /// A mapping { Rule Location: { Variable Name: Inferred Type } } pub rule_variable_type: HashMap>, + + /// The local inference contexts of a rule pub rule_local_contexts: Vec, + + /// A mapping from a relation name that is queried to the location where it is queried pub query_relations: HashMap, + + /// A mapping from expression ID to its inferred types pub expr_types: HashMap, + + /// A list of errors obtained from the type inference process pub errors: Vec, } impl TypeInference { - pub fn new( - function_registry: &ForeignFunctionRegistry, - predicate_registry: &ForeignPredicateRegistry, - ) -> Self { + pub fn new(function_registry: &ForeignFunctionRegistry, predicate_registry: &ForeignPredicateRegistry) -> Self { Self { custom_types: HashMap::new(), constant_types: HashMap::new(), foreign_function_type_registry: FunctionTypeRegistry::from_foreign_function_registry(function_registry), foreign_predicate_type_registry: PredicateTypeRegistry::from_foreign_predicate_registry(predicate_registry), relation_type_decl_loc: HashMap::new(), + adt_relations: HashMap::new(), inferred_relation_types: HashMap::new(), rule_variable_type: HashMap::new(), rule_local_contexts: Vec::new(), @@ -63,11 +86,15 @@ impl TypeInference { } pub fn relations(&self) -> Vec { - self - .inferred_relation_types - .iter() - .filter_map(|(n, _)| if !n.contains("#") { Some(n.clone()) } else { None }) - .collect() + if self.query_relations.is_empty() { + self + .inferred_relation_types + .iter() + .filter_map(|(n, _)| if !n.contains("#") { Some(n.clone()) } else { None }) + .collect() + } else { + self.query_relations.iter().map(|(n, _)| n.clone()).collect() + } } pub fn has_relation(&self, relation: &str) -> bool { @@ -203,6 +230,62 @@ impl TypeInference { } } + pub fn 
parse_entity_facts( + &self, + facts: &Vec, + ) -> Result)>, TypeInferenceError> { + facts.iter().map(|fact| self.parse_entity_fact(fact)).collect() + } + + pub fn parse_entity_fact(&self, fact: &EntityFact) -> Result<(String, Value, Vec), TypeInferenceError> { + let relation_name = format!("adt#{}", fact.functor.name()); + if self.adt_relations.contains_key(&relation_name) { + // First get the value type + let value_types = self + .relation_arg_types(&relation_name) + .expect("[Internal Error] adt variant not found in storage; this is probably a bug"); + + // Make sure that the arity matches + if value_types.len() != fact.args.len() + 1 { + return Err(TypeInferenceError::ADTVariantArityMismatch { + variant: fact.functor.name().to_string(), + expected: value_types.len() - 1, + actual: fact.args.len(), + loc: fact.loc.clone(), + }); + } + + // Get the id value + let id_value = fact.id.to_value(&value_types[0]); + + // Get the arg values + let arg_values = fact + .args + .iter() + .zip(value_types.iter().skip(1)) + .map(|(arg, ty)| { + if arg.can_unify(ty) { + Ok(arg.to_value(ty)) + } else { + Err(TypeInferenceError::CannotUnifyTypes { + t1: TypeSet::from_constant(arg), + t2: TypeSet::base(ty.clone()), + loc: None, + }) + } + }) + .collect::, _>>()?; + + // The returned fact contains relation name, id, and args + Ok((relation_name, id_value, arg_values)) + } else { + Err(TypeInferenceError::UnknownADTVariant { + predicate: fact.functor.name().to_string(), + loc: fact.functor.location().clone(), + }) + } + } + pub fn infer_types(&mut self) { if let Err(err) = self.infer_types_helper() { self.errors.push(err); @@ -211,8 +294,9 @@ impl TypeInference { fn infer_types_helper(&mut self) -> Result<(), TypeInferenceError> { // Mapping from variable to set of expressions - // Mapping from relation argument to set of expressions let mut inferred_var_expr = HashMap::>>::new(); + + // Mapping from relation argument to set of expressions let mut inferred_relation_expr = 
HashMap::<(String, usize), BTreeSet>::new(); // Fixpoint states @@ -284,6 +368,39 @@ impl NodeVisitor for TypeInference { ); } + fn visit_relation_type_decl(&mut self, relation_type_decl: &RelationTypeDecl) { + if let Some(attr) = relation_type_decl.attributes().find("adt") { + // Get the variant name string + let adt_variant_name = attr + .pos_arg_to_string(0) + .expect("[Internal Error] internally annotated adt attribute does not have a string as the argument 0"); + + // Get the adt variant + let adt_variant_relation_name = format!("adt#{adt_variant_name}"); + + // Get the is_entity list + let adt_is_entity_list: Vec = attr + .pos_arg_to_list(1) + .expect("[Internal Error] internally annotated adt attribute does not have a list of boolean as the argument 1") + .iter() + .map(|arg| { + arg + .as_bool() + .expect( + "[Internal Error] internally annotated adt attribute does not have a list of boolean as the argument 1", + ) + .clone() + }) + .collect(); + + // Add the adt annotation into the `adt_relations` mapping + self.adt_relations.insert( + adt_variant_relation_name, + (adt_variant_name.clone(), adt_is_entity_list), + ); + } + } + fn visit_relation_type(&mut self, relation_type: &RelationType) { // Check if the relation is a foreign predicate let predicate = relation_type.predicate(); @@ -308,6 +425,7 @@ impl NodeVisitor for TypeInference { self.check_and_add_custom_type(enum_type_decl.name(), &ty, enum_type_decl.location()); // And then declare all the constant types + // Note: we do not check for duplicated names here, as they are handled by `ConstantDeclAnalysis`. 
for member in enum_type_decl.iter_members() { match member.assigned_number() { Some(c) => match &c.node { @@ -319,13 +437,11 @@ impl NodeVisitor for TypeInference { }) } } - _ => { - self.errors.push(TypeInferenceError::BadEnumValueKind { - found: c.kind(), - loc: c.location().clone(), - }) - } - } + _ => self.errors.push(TypeInferenceError::BadEnumValueKind { + found: c.kind(), + loc: c.location().clone(), + }), + }, _ => {} } } @@ -334,7 +450,7 @@ impl NodeVisitor for TypeInference { fn visit_const_assignment(&mut self, const_assign: &ConstAssignment) { if let Some(raw_type) = const_assign.ty() { let result = find_value_type(&self.custom_types, raw_type).and_then(|ty| { - let ts = TypeSet::from_constant(const_assign.value()); + let ts = TypeSet::from_constant(const_assign.value().get_constant().expect("[Internal Error] During type inference, all entities should be normalized to constant. This is probably an internal error.")); ts.unify(&TypeSet::BaseType(ty, raw_type.location().clone())) }); match result { @@ -455,7 +571,7 @@ impl NodeVisitor for TypeInference { let pred = fact_decl.predicate(); // Check if the relation is a foreign predicate - if self.foreign_predicate_type_registry.contains_predicate(pred) { + if self.foreign_predicate_type_registry.contains_predicate(&pred) { self.errors.push(TypeInferenceError::CannotRedefineForeignPredicate { pred: pred.to_string(), loc: fact_decl.location().clone(), @@ -463,6 +579,17 @@ impl NodeVisitor for TypeInference { return; } + // Check if the relation is an ADT + if pred.contains("adt#") { + // Make sure that the predicate is an existing ADT relation + if !self.adt_relations.contains_key(&pred) { + self.errors.push(TypeInferenceError::UnknownADTVariant { + predicate: pred[4..].to_string(), + loc: fact_decl.atom().predicate_identifier().location().clone(), + }) + } + } + let maybe_curr_type_sets = fact_decl .iter_arguments() .map(|arg| match arg { @@ -481,18 +608,27 @@ impl NodeVisitor for TypeInference { }; // Check 
the type - if self.inferred_relation_types.contains_key(pred) { - let (original_type_sets, original_type_def_loc) = &self.inferred_relation_types[pred]; + if self.inferred_relation_types.contains_key(&pred) { + let (original_type_sets, original_type_def_loc) = &self.inferred_relation_types[&pred]; // First check if the arity matches if curr_type_sets.len() != original_type_sets.len() { - self.errors.push(TypeInferenceError::ArityMismatch { - predicate: pred.clone(), - expected: original_type_sets.len(), - actual: curr_type_sets.len(), - source_loc: original_type_def_loc.clone(), - mismatch_loc: fact_decl.atom().location().clone(), - }); + if let Some((variant_name, _)) = self.adt_relations.get(&pred) { + self.errors.push(TypeInferenceError::ADTVariantArityMismatch { + variant: variant_name.clone(), + expected: original_type_sets.len() - 1, + actual: curr_type_sets.len() - 1, + loc: fact_decl.atom().location().clone(), + }); + } else { + self.errors.push(TypeInferenceError::ArityMismatch { + predicate: pred.clone(), + expected: original_type_sets.len(), + actual: curr_type_sets.len(), + source_loc: original_type_def_loc.clone(), + mismatch_loc: fact_decl.atom().location().clone(), + }); + } return; } @@ -504,7 +640,7 @@ impl NodeVisitor for TypeInference { .collect::, _>>(); match maybe_new_type_sets { Ok(new_type_sets) => { - self.inferred_relation_types.get_mut(pred).unwrap().0 = new_type_sets; + self.inferred_relation_types.get_mut(&pred).unwrap().0 = new_type_sets; } Err(err) => self.errors.push(err), } @@ -518,7 +654,7 @@ impl NodeVisitor for TypeInference { fn visit_rule(&mut self, rule: &Rule) { for pred in rule.head().iter_predicates() { // Check if a head predicate is a foreign predicate - if self.foreign_predicate_type_registry.contains_predicate(pred) { + if self.foreign_predicate_type_registry.contains_predicate(&pred) { self.errors.push(TypeInferenceError::CannotRedefineForeignPredicate { pred: pred.to_string(), loc: rule.location().clone(), @@ -537,7 
+673,7 @@ impl NodeVisitor for TypeInference { } // First unify atom arity - if let Err(err) = ctx.unify_atom_arities(&mut self.inferred_relation_types) { + if let Err(err) = ctx.unify_atom_arities(&self.foreign_predicate_type_registry, &mut self.inferred_relation_types) { self.errors.push(err); return; } @@ -569,7 +705,9 @@ impl NodeVisitor for TypeInference { } // Unify atom arity - if let Err(err) = ctx.unify_atom_arities(&mut self.inferred_relation_types) { + if let Err(err) = + ctx.unify_atom_arities(&self.foreign_predicate_type_registry, &mut self.inferred_relation_types) + { self.errors.push(err); return; } diff --git a/core/src/compiler/front/analyzers/type_inference/type_set.rs b/core/src/compiler/front/analyzers/type_inference/type_set.rs index b79764a..f5a586a 100644 --- a/core/src/compiler/front/analyzers/type_inference/type_set.rs +++ b/core/src/compiler/front/analyzers/type_inference/type_set.rs @@ -190,9 +190,10 @@ impl TypeSet { ConstantNode::Char(_) => Self::BaseType(ValueType::Char, c.location().clone()), ConstantNode::Boolean(_) => Self::BaseType(ValueType::Bool, c.location().clone()), ConstantNode::String(_) => Self::String(c.location().clone()), + ConstantNode::Symbol(_) => Self::BaseType(ValueType::Symbol, c.location().clone()), ConstantNode::DateTime(_) => Self::BaseType(ValueType::DateTime, c.location().clone()), ConstantNode::Duration(_) => Self::BaseType(ValueType::Duration, c.location().clone()), - ConstantNode::Invalid(_) => panic!("[Internal Error] Should not be called with invalid constant"), + ConstantNode::Entity(_) => Self::BaseType(ValueType::Entity, c.location().clone()), } } diff --git a/core/src/compiler/front/analyzers/type_inference/unification.rs b/core/src/compiler/front/analyzers/type_inference/unification.rs index 5a0c9a1..31ab6f6 100644 --- a/core/src/compiler/front/analyzers/type_inference/unification.rs +++ b/core/src/compiler/front/analyzers/type_inference/unification.rs @@ -55,6 +55,9 @@ pub enum Unification { /// f, 
ops*, $f(ops*) Call(String, Vec, Loc), + + /// C, ops*, new C(ops*) + New(String, Vec, Loc), } impl Unification { @@ -78,15 +81,13 @@ impl Unification { // Unify the type match unify_ty(e, ty.clone(), inferred_expr_types) { Ok(_) => Ok(()), - Err(_) => { - Err(TypeInferenceError::CannotUnifyForeignPredicateArgument { - pred: p.clone(), - i: *i, - expected_ty: ty, - actual_ty: inferred_expr_types.get(e).unwrap().clone(), - loc: e.clone(), - }) - } + Err(_) => Err(TypeInferenceError::CannotUnifyForeignPredicateArgument { + pred: p.clone(), + i: *i, + expected_ty: ty, + actual_ty: inferred_expr_types.get(e).unwrap().clone(), + loc: e.clone(), + }), } } else { Err(TypeInferenceError::InvalidForeignPredicateArgIndex { @@ -359,6 +360,43 @@ impl Unification { }) } } + Self::New(functor, args, e) => { + let adt_variant_relation_name = format!("adt#{functor}"); + + // cond should be boolean + unify_entity(e, inferred_expr_types)?; + + // Get the functor/relation + if let Some((types, _)) = inferred_relation_types.get(&adt_variant_relation_name) { + if args.len() == types.len() - 1 { + for (arg, ty) in args.iter().zip(types.iter().skip(1)) { + let arg_ty = get_or_insert_ty(arg, TypeSet::Any(arg.clone()), inferred_expr_types); + match arg_ty.unify(ty) { + Ok(new_ty) => { + inferred_expr_types.insert(arg.clone(), new_ty); + } + Err(mut err) => { + err.annotate_location(arg); + return Err(err); + } + } + } + Ok(()) + } else { + Err(TypeInferenceError::ADTVariantArityMismatch { + variant: functor.clone(), + expected: types.len() - 1, + actual: args.len(), + loc: e.clone(), + }) + } + } else { + Err(TypeInferenceError::UnknownADTVariant { + predicate: functor.clone(), + loc: e.clone(), + }) + } + } } } } @@ -417,7 +455,7 @@ fn unify_polymorphic_binary_expression( e_ty, location: e.clone(), }) - }, + } AppliedRules::One((t1, t2, te)) => { // If there is exactly one rule that can be applied, then unify them with the exact types unify_ty(op1, TypeSet::BaseType(t1, e.clone()), 
inferred_expr_types)?; @@ -429,7 +467,7 @@ fn unify_polymorphic_binary_expression( // If ther are multiple rules that can be applied, we are not sure about the exact types, // but the type inference is still successful Ok(()) - }, + } } } @@ -465,7 +503,7 @@ fn unify_comparison_expression( e_ty, location: e.clone(), }) - }, + } AppliedRules::One((t1, t2)) => { // If there is exactly one rule that can be applied, then unify them with the exact types unify_ty(op1, TypeSet::BaseType(t1, e.clone()), inferred_expr_types)?; @@ -476,7 +514,7 @@ fn unify_comparison_expression( // If ther are multiple rules that can be applied, we are not sure about the exact types, // but the type inference is still successful Ok(()) - }, + } } } @@ -507,3 +545,8 @@ fn unify_boolean(e: &Loc, inferred_expr_types: &mut HashMap) -> Re let e_ty = TypeSet::BaseType(ValueType::Bool, e.clone()); unify_ty(e, e_ty, inferred_expr_types) } + +fn unify_entity(e: &Loc, inferred_expr_types: &mut HashMap) -> Result { + let e_ty = TypeSet::BaseType(ValueType::Entity, e.clone()); + unify_ty(e, e_ty, inferred_expr_types) +} diff --git a/core/src/compiler/front/ast/attr.rs b/core/src/compiler/front/ast/attr.rs index b3f1522..ae2ae97 100644 --- a/core/src/compiler/front/ast/attr.rs +++ b/core/src/compiler/front/ast/attr.rs @@ -1,17 +1,106 @@ +use std::iter::FromIterator; + +use serde::*; + use super::*; -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, Serialize)] +#[doc(hidden)] +pub enum AttributeValueNode { + Constant(Constant), + List(Vec), + Tuple(Vec), +} + +/// The value of an attribute; it could be a list or a constant +pub type AttributeValue = AstNode; + +impl AttributeValue { + pub fn constant(c: Constant) -> Self { + Self::default(AttributeValueNode::Constant(c)) + } + + pub fn is_constant(&self) -> bool { + match &self.node { + AttributeValueNode::Constant(_) => true, + _ => false, + } + } + + pub fn is_list(&self) -> bool { + match &self.node { + AttributeValueNode::List(_) 
=> true, + _ => false, + } + } + + pub fn is_tuple(&self) -> bool { + match &self.node { + AttributeValueNode::Tuple(_) => true, + _ => false, + } + } + + pub fn as_constant(&self) -> Option<&Constant> { + match &self.node { + AttributeValueNode::Constant(c) => Some(c), + _ => None, + } + } + + pub fn as_bool(&self) -> Option<&bool> { + self.as_constant().and_then(Constant::as_bool) + } + + pub fn as_integer(&self) -> Option<&i64> { + self.as_constant().and_then(Constant::as_integer) + } + + pub fn as_string(&self) -> Option<&String> { + self.as_constant().and_then(Constant::as_string) + } + + pub fn as_list(&self) -> Option<&Vec> { + match &self.node { + AttributeValueNode::List(l) => Some(l), + _ => None, + } + } +} + +impl From for AttributeValue { + fn from(c: Constant) -> Self { + Self::new(c.location().clone_without_id(), AttributeValueNode::Constant(c)) + } +} + +impl FromIterator for AttributeValue { + fn from_iter>(iter: T) -> Self { + Self::default(AttributeValueNode::List(iter.into_iter().collect())) + } +} + +#[derive(Clone, Debug, PartialEq, Serialize)] #[doc(hidden)] pub struct AttributeNode { pub name: Identifier, - pub pos_args: Vec, - pub kw_args: Vec<(Identifier, Constant)>, + pub pos_args: Vec, + pub kw_args: Vec<(Identifier, AttributeValue)>, } /// An attribute of the form `@attr(args...)` pub type Attribute = AstNode; impl Attribute { + pub fn default_with_name(n: String) -> Self { + AttributeNode { + name: Identifier::default_with_name(n), + pos_args: Vec::new(), + kw_args: Vec::new(), + } + .into() + } + pub fn name(&self) -> &String { &self.node.name.node.name } @@ -20,11 +109,42 @@ impl Attribute { self.node.pos_args.len() } - pub fn pos_arg(&self, i: usize) -> Option<&Constant> { + pub fn pos_arg(&self, i: usize) -> Option<&AttributeValue> { self.node.pos_args.get(i) } - pub fn iter_pos_args(&self) -> impl Iterator { + pub fn pos_arg_to_bool(&self, i: usize) -> Option<&bool> { + self + .node + .pos_args + .get(i) + 
.and_then(AttributeValue::as_constant) + .and_then(Constant::as_bool) + } + + pub fn pos_arg_to_integer(&self, i: usize) -> Option<&i64> { + self + .node + .pos_args + .get(i) + .and_then(AttributeValue::as_constant) + .and_then(Constant::as_integer) + } + + pub fn pos_arg_to_string(&self, i: usize) -> Option<&String> { + self + .node + .pos_args + .get(i) + .and_then(AttributeValue::as_constant) + .and_then(Constant::as_string) + } + + pub fn pos_arg_to_list(&self, i: usize) -> Option<&Vec> { + self.node.pos_args.get(i).and_then(AttributeValue::as_list) + } + + pub fn iter_pos_args(&self) -> impl Iterator { self.node.pos_args.iter() } @@ -32,7 +152,7 @@ impl Attribute { self.node.kw_args.len() } - pub fn kw_arg(&self, kw: &str) -> Option<&Constant> { + pub fn kw_arg(&self, kw: &str) -> Option<&AttributeValue> { for (name, arg) in &self.node.kw_args { if name.name() == kw { return Some(arg); @@ -44,3 +164,18 @@ impl Attribute { /// A list of attributes pub type Attributes = Vec; + +pub trait AttributesTrait { + fn find(&self, name: &str) -> Option<&Attribute>; +} + +impl AttributesTrait for Attributes { + fn find(&self, name: &str) -> Option<&Attribute> { + for attr in self { + if attr.name() == name { + return Some(attr); + } + } + None + } +} diff --git a/core/src/compiler/front/ast/const_decl.rs b/core/src/compiler/front/ast/const_decl.rs index 2003860..1943376 100644 --- a/core/src/compiler/front/ast/const_decl.rs +++ b/core/src/compiler/front/ast/const_decl.rs @@ -1,11 +1,13 @@ +use serde::*; + use super::*; -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, Serialize)] #[doc(hidden)] pub struct ConstAssignmentNode { pub name: Identifier, pub ty: Option, - pub value: Constant, + pub value: Entity, } /// A single constant assignment, e.g. 
`X = 42` @@ -32,16 +34,16 @@ impl ConstAssignment { self.node.ty.as_mut() } - pub fn value(&self) -> &Constant { + pub fn value(&self) -> &Entity { &self.node.value } - pub fn value_mut(&mut self) -> &mut Constant { + pub fn value_mut(&mut self) -> &mut Entity { &mut self.node.value } } -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, Serialize)] #[doc(hidden)] pub struct ConstDeclNode { pub attrs: Attributes, diff --git a/core/src/compiler/front/ast/constant.rs b/core/src/compiler/front/ast/constant.rs index 57574e0..34f34a1 100644 --- a/core/src/compiler/front/ast/constant.rs +++ b/core/src/compiler/front/ast/constant.rs @@ -1,11 +1,12 @@ -// use std::rc::Rc; +use chrono::serde::ts_seconds; +use serde::Serialize; use super::*; use crate::common::input_tag::DynamicInputTag; use crate::common::value::Value; use crate::common::value_type::ValueType; -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, Serialize)] #[doc(hidden)] pub struct TagNode(pub DynamicInputTag); @@ -26,19 +27,144 @@ impl Tag { } } -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, Hash, Serialize)] +pub struct ConstantCharNode { + pub character: String, +} + +pub type ConstantChar = AstNode; + +impl ConstantChar { + pub fn character(&self) -> char { + // Unwrap is ok since during parsing + self.node.character.chars().next().unwrap() + } + + pub fn character_string(&self) -> &String { + &self.node.character + } +} + +#[derive(Clone, Debug, PartialEq, Hash, Serialize)] +pub struct ConstantStringNode { + pub string: String, +} + +impl ConstantStringNode { + pub fn new(string: String) -> Self { + Self { string } + } +} + +pub type ConstantString = AstNode; + +impl ConstantString { + pub fn string(&self) -> &String { + &self.node.string + } + + pub fn string_mut(&mut self) -> &mut String { + &mut self.node.string + } +} + +#[derive(Clone, Debug, PartialEq, Hash, Serialize)] +pub struct ConstantSymbolNode { + pub symbol: String, +} + +pub 
type ConstantSymbol = AstNode; + +impl ConstantSymbol { + pub fn symbol(&self) -> &String { + &self.node.symbol + } +} + +#[derive(Clone, Debug, PartialEq, Hash, Serialize)] +pub struct ConstantDateTimeNode { + #[serde(with = "ts_seconds")] + pub datetime: chrono::DateTime, +} + +pub type ConstantDateTime = AstNode>; + +impl ConstantDateTime { + pub fn datetime(&self) -> Option<&chrono::DateTime> { + self.node.as_ref().ok().map(|n| &n.datetime) + } +} + +#[derive(Clone, Debug, PartialEq, Hash)] +pub struct ConstantDurationNode { + pub duration: chrono::Duration, +} + +impl Serialize for ConstantDurationNode { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + use serde::ser::*; + let mut state = serializer.serialize_struct("ConstantDurationNode", 1)?; + state.serialize_field("duration", &self.duration.num_seconds())?; + state.end() + } +} + +pub type ConstantDuration = AstNode>; + +impl ConstantDuration { + pub fn duration(&self) -> Option<&chrono::Duration> { + self.node.as_ref().ok().map(|n| &n.duration) + } +} + +#[derive(Clone, Debug, PartialEq, Serialize)] #[doc(hidden)] pub enum ConstantNode { Integer(i64), Float(f64), - Char(String), + Char(ConstantChar), Boolean(bool), - String(String), - DateTime(chrono::DateTime), - Duration(chrono::Duration), + String(ConstantString), + Symbol(ConstantSymbol), + DateTime(ConstantDateTime), + Duration(ConstantDuration), + Entity(u64), +} - /// Invalid is used to represent a constant that could not be parsed; the string is the error message - Invalid(String), +impl std::hash::Hash for ConstantNode { + fn hash(&self, state: &mut H) { + match self { + Self::Integer(i) => i.hash(state), + Self::Float(f) => i64::from_ne_bytes(f.to_ne_bytes()).hash(state), + Self::Char(c) => c.hash(state), + Self::Boolean(b) => b.hash(state), + Self::String(s) => s.hash(state), + Self::Symbol(s) => s.hash(state), + Self::DateTime(d) => d.hash(state), + Self::Duration(d) => d.hash(state), + Self::Entity(u) => 
u.hash(state), + } + } +} + +impl ConstantNode { + pub fn kind(&self) -> &'static str { + use ConstantNode::*; + match self { + Integer(_) => "integer", + Float(_) => "float", + String(_) => "string", + Symbol(_) => "symbol", + Char(_) => "char", + Boolean(_) => "boolean", + DateTime(_) => "datetime", + Duration(_) => "duration", + Entity(_) => "entity", + } + } } /// A constant, which could be an integer, floating point, character, boolean, or string. @@ -50,6 +176,48 @@ impl Constant { Self::default(ConstantNode::Integer(i)) } + pub fn float(f: f64) -> Self { + Self::default(ConstantNode::Float(f)) + } + + pub fn boolean(b: bool) -> Self { + Self::default(ConstantNode::Boolean(b)) + } + + pub fn string(s: String) -> Self { + Self::default(ConstantNode::String(ConstantStringNode::new(s).into())) + } + + pub fn can_unify(&self, ty: &ValueType) -> bool { + use ConstantNode::*; + match (&self.node, ty) { + (Integer(_), ValueType::I8) + | (Integer(_), ValueType::I16) + | (Integer(_), ValueType::I32) + | (Integer(_), ValueType::I64) + | (Integer(_), ValueType::I128) + | (Integer(_), ValueType::ISize) + | (Integer(_), ValueType::U8) + | (Integer(_), ValueType::U16) + | (Integer(_), ValueType::U32) + | (Integer(_), ValueType::U64) + | (Integer(_), ValueType::U128) + | (Integer(_), ValueType::USize) + | (Integer(_), ValueType::F32) + | (Integer(_), ValueType::F64) + | (Float(_), ValueType::F32) + | (Float(_), ValueType::F64) + | (Char(_), ValueType::Char) + | (Boolean(_), ValueType::Bool) + | (String(_), ValueType::String) + | (Symbol(_), ValueType::Symbol) + | (DateTime(_), ValueType::DateTime) + | (Duration(_), ValueType::Duration) + | (Entity(_), ValueType::Entity) => true, + _ => false, + } + } + pub fn to_value(&self, ty: &ValueType) -> Value { use ConstantNode::*; match (&self.node, ty) { @@ -69,34 +237,78 @@ impl Constant { (Integer(i), ValueType::F64) => Value::F64(*i as f64), (Float(f), ValueType::F32) => Value::F32(*f as f32), (Float(f), ValueType::F64) => 
Value::F64(*f as f64), - (Char(c), ValueType::Char) => Value::Char(c.chars().next().unwrap()), + (Char(c), ValueType::Char) => Value::Char(c.character()), (Boolean(b), ValueType::Bool) => Value::Bool(*b), (String(_), ValueType::Str) => panic!("Cannot cast dynamic string into static string"), - (String(s), ValueType::String) => Value::String(s.clone()), - // (String(s), ValueType::RcString) => Value::RcString(Rc::new(s.clone())), - (DateTime(d), ValueType::DateTime) => Value::DateTime(d.clone()), - (Duration(d), ValueType::Duration) => Value::Duration(d.clone()), + (String(s), ValueType::String) => Value::String(s.string().clone()), + (Symbol(s), ValueType::Symbol) => Value::SymbolString(s.symbol().clone()), + (DateTime(d), ValueType::DateTime) => { + Value::DateTime(d.datetime().expect("Cannot have invalid datetime").clone()) + } + (Duration(d), ValueType::Duration) => { + Value::Duration(d.duration().expect("Cannot have invalid duration").clone()) + } + (Entity(u), ValueType::Entity) => Value::Entity(*u), _ => panic!("Cannot convert front Constant `{:?}` to Type `{}`", self, ty), } } pub fn kind(&self) -> &'static str { - use ConstantNode::*; + self.node.kind() + } + + pub fn as_bool(&self) -> Option<&bool> { match &self.node { - Integer(_) => "integer", - Float(_) => "float", - String(_) => "string", - Char(_) => "char", - Boolean(_) => "boolean", - DateTime(_) => "datetime", - Duration(_) => "duration", - Invalid(_) => "invalid", + ConstantNode::Boolean(b) => Some(b), + _ => None, + } + } + + pub fn as_integer(&self) -> Option<&i64> { + match &self.node { + ConstantNode::Integer(i) => Some(i), + _ => None, + } + } + + pub fn as_float(&self) -> Option<&f64> { + match &self.node { + ConstantNode::Float(f) => Some(f), + _ => None, + } + } + + pub fn as_string(&self) -> Option<&String> { + match &self.node { + ConstantNode::String(s) => Some(s.string()), + _ => None, + } + } + + pub fn as_char(&self) -> Option<&String> { + match &self.node { + ConstantNode::Char(c) 
=> Some(c.character_string()), + _ => None, + } + } + + pub fn as_datetime(&self) -> Option<&chrono::DateTime> { + match &self.node { + ConstantNode::DateTime(d) => d.datetime(), + _ => None, + } + } + + pub fn as_duration(&self) -> Option<&chrono::Duration> { + match &self.node { + ConstantNode::Duration(d) => d.duration(), + _ => None, } } } /// A constant or a variable -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, Serialize)] pub enum ConstantOrVariable { Constant(Constant), Variable(Variable), @@ -132,7 +344,7 @@ impl ConstantOrVariable { } } -#[derive(Clone, PartialEq)] +#[derive(Clone, PartialEq, Serialize)] #[doc(hidden)] pub struct IdentifierNode { pub name: String, @@ -148,9 +360,22 @@ impl IdentifierNode { pub type Identifier = AstNode; impl Identifier { + pub fn default_with_name(name: String) -> Self { + IdentifierNode::new(name).into() + } + pub fn name(&self) -> &str { &self.node.name } + + pub fn map String>(&self, f: F) -> Self { + Self { + loc: self.loc.clone(), + node: IdentifierNode { + name: f(&self.node.name), + }, + } + } } impl std::fmt::Debug for IdentifierNode { diff --git a/core/src/compiler/front/ast/entity.rs b/core/src/compiler/front/ast/entity.rs new file mode 100644 index 0000000..88d025d --- /dev/null +++ b/core/src/compiler/front/ast/entity.rs @@ -0,0 +1,169 @@ +use serde::*; + +use crate::common::entity; + +use super::*; + +#[derive(Clone, Debug, PartialEq, Serialize)] +pub enum EntityNode { + Expr(Expr), + Object(Object), +} + +pub type Entity = AstNode; + +impl Entity { + /// Create a new constant entity + pub fn constant(c: Constant) -> Self { + Self::default(EntityNode::Expr(Expr::Constant(c))) + } + + /// Checks if the entity is a simple constant + pub fn is_constant(&self) -> bool { + match &self.node { + EntityNode::Expr(e) => e.is_constant(), + _ => false, + } + } + + /// Get the constant if the entity is a simple constant + pub fn get_constant(&self) -> Option<&Constant> { + match &self.node { + 
EntityNode::Expr(e) => e.get_constant(), + _ => None, + } + } + + /// Checks if the entity has variable inside + pub fn has_variable(&self) -> bool { + match &self.node { + EntityNode::Expr(e) => e.has_variable(), + EntityNode::Object(o) => o.has_variable(), + } + } + + /// Get the location of the first non-constant in the entity + pub fn get_first_non_constant_location(&self, is_constant: &F) -> Option<&AstNodeLocation> + where + F: Fn(&Variable) -> bool, + { + match &self.node { + EntityNode::Expr(e) => e.get_first_non_constant_location(is_constant), + EntityNode::Object(o) => o.get_first_non_constant_location(is_constant), + } + } + + pub fn to_facts(&self) -> (Vec, Constant) { + self.to_facts_with_constant_variables(|_| None) + } + + pub fn to_facts_with_constant_variables(&self, f: F) -> (Vec, Constant) + where + F: Fn(&Variable) -> Option, + { + fn helper(entity: &Entity, facts: &mut Vec, f: &F) -> Constant + where + F: Fn(&Variable) -> Option, + { + // Check whether we need to recurse + match &entity.node { + EntityNode::Expr(e) => { + if let Some(c) = e.get_constant() { + c.clone() + } else if let Some(v) = e.get_variable() { + if let Some(c) = f(v) { + c + } else { + panic!("[Internal Error] Found non-constant variable in ") + } + } else { + panic!("[Internal Error] Should contain only constant or constant variables") + } + } + EntityNode::Object(obj) => { + let functor = obj.functor().clone_without_location_id(); + let args = obj.iter_args().map(|a| helper(a, facts, f)).collect::>(); + + // Create a hash value + let raw_id = entity::encode_entity(functor.name(), args.iter().map(|a| &a.node)); + + // Create a constant ID of the hash value + let id = Constant { + loc: obj.location().clone(), + node: ConstantNode::Entity(raw_id), + }; + + // Create the entity fact and store it inside the storage + let entity_fact = EntityFact { + functor, + id: id.clone(), + args, + loc: obj.location().clone(), + }; + facts.push(entity_fact); + + // Return the ID + id + } + 
} + } + + let mut facts = Vec::new(); + let constant = helper(self, &mut facts, &f); + (facts, constant) + } +} + +#[derive(Clone, Debug, Serialize)] +pub struct EntityFact { + pub functor: Identifier, + pub id: Constant, + pub args: Vec, + pub loc: AstNodeLocation, +} + +#[derive(Clone, Debug, PartialEq, Serialize)] +pub struct ObjectNode { + pub functor: Identifier, + pub args: Vec, +} + +pub type Object = AstNode; + +impl Object { + pub fn has_variable(&self) -> bool { + self.node.args.iter().any(|a| a.has_variable()) + } + + pub fn functor(&self) -> &Identifier { + &self.node.functor + } + + pub fn functor_mut(&mut self) -> &mut Identifier { + &mut self.node.functor + } + + pub fn functor_name(&self) -> &str { + self.node.functor.name() + } + + pub fn iter_args(&self) -> impl Iterator { + self.node.args.iter() + } + + pub fn iter_args_mut(&mut self) -> impl Iterator { + self.node.args.iter_mut() + } + + pub fn get_first_non_constant_location(&self, is_constant: &F) -> Option<&AstNodeLocation> + where + F: Fn(&Variable) -> bool, + { + for arg in self.iter_args() { + if let Some(loc) = arg.get_first_non_constant_location(is_constant) { + return Some(loc); + } + } + None + } +} diff --git a/core/src/compiler/front/ast/expr.rs b/core/src/compiler/front/ast/expr.rs index e6296ab..4ab0d81 100644 --- a/core/src/compiler/front/ast/expr.rs +++ b/core/src/compiler/front/ast/expr.rs @@ -1,6 +1,8 @@ +use serde::*; + use super::*; -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, Serialize)] pub enum Expr { Constant(Constant), Variable(Variable), @@ -9,6 +11,7 @@ pub enum Expr { Unary(UnaryExpr), IfThenElse(IfThenElseExpr), Call(CallExpr), + New(NewExpr), } impl Expr { @@ -53,6 +56,7 @@ impl Expr { Self::Unary(u) => u.location(), Self::IfThenElse(i) => i.location(), Self::Call(c) => c.location(), + Self::New(n) => n.location(), } } @@ -84,6 +88,34 @@ impl Expr { } } + pub fn get_constant(&self) -> Option<&Constant> { + match self { + Self::Constant(c) 
=> Some(c), + _ => None, + } + } + + pub fn get_variable(&self) -> Option<&Variable> { + match self { + Self::Variable(v) => Some(v), + _ => None, + } + } + + /// Checks if the expression has variables (or wildcards) inside of it + pub fn has_variable(&self) -> bool { + match self { + Self::Constant(_) => false, + Self::Variable(_) => true, + Self::Wildcard(_) => true, + Self::Binary(b) => b.op1().has_variable() || b.op2().has_variable(), + Self::Unary(b) => b.op1().has_variable(), + Self::IfThenElse(i) => i.cond().has_variable() || i.then_br().has_variable() || i.else_br().has_variable(), + Self::Call(c) => c.iter_args().any(|a| a.has_variable()), + Self::New(n) => n.iter_args().any(|a| a.has_variable()), + } + } + pub fn collect_used_variables(&self) -> Vec { let mut vars = vec![]; self.collect_used_variables_helper(&mut vars); @@ -114,11 +146,64 @@ impl Expr { Self::Variable(v) => { vars.push(v.clone()); } + Self::New(n) => { + for a in n.iter_args() { + a.collect_used_variables_helper(vars); + } + } + } + } + + pub fn get_first_variable_location(&self) -> Option<&AstNodeLocation> { + match self { + Expr::Constant(_) => None, + Expr::Variable(v) => Some(v.location()), + Expr::Wildcard(w) => Some(w.location()), + Expr::Binary(b) => Some(b.location()), + Expr::Unary(u) => u.op1().get_first_variable_location(), + Expr::IfThenElse(i) => i + .cond() + .get_first_variable_location() + .or_else(|| i.then_br().get_first_variable_location()) + .or_else(|| i.else_br().get_first_variable_location()), + Expr::Call(c) => { + for arg in c.iter_args() { + if let Some(loc) = arg.get_first_variable_location() { + return Some(loc); + } + } + None + } + Expr::New(n) => { + for arg in n.iter_args() { + if let Some(loc) = arg.get_first_variable_location() { + return Some(loc); + } + } + None + } + } + } + + pub fn get_first_non_constant_location(&self, is_constant: &F) -> Option<&AstNodeLocation> + where + F: Fn(&Variable) -> bool, + { + match self { + Expr::Constant(_) => None, + 
Expr::Variable(v) => { + if is_constant(v) { + None + } else { + Some(self.location()) + } + } + _ => Some(self.location()), } } } -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, Serialize)] #[doc(hidden)] pub struct VariableNode { pub name: Identifier, @@ -148,7 +233,7 @@ impl std::fmt::Display for Variable { } } -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, Serialize)] #[doc(hidden)] pub struct VariableBindingNode { pub name: Identifier, @@ -170,7 +255,7 @@ impl VariableBinding { } } -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, Serialize)] #[doc(hidden)] pub struct WildcardNode; @@ -211,7 +296,7 @@ impl BinaryOp { } } -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, Serialize)] #[doc(hidden)] pub struct BinaryExprNode { pub op: BinaryOp, @@ -235,7 +320,7 @@ impl BinaryExpr { } } -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, Serialize)] #[doc(hidden)] pub enum UnaryOpNode { Neg, @@ -273,7 +358,7 @@ impl UnaryOp { } } -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, Serialize)] #[doc(hidden)] pub struct UnaryExprNode { pub op: UnaryOp, @@ -292,7 +377,7 @@ impl UnaryExpr { } } -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, Serialize)] #[doc(hidden)] pub struct IfThenElseExprNode { pub cond: Box, @@ -316,7 +401,7 @@ impl IfThenElseExpr { } } -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, Serialize)] #[doc(hidden)] pub struct CallExprNode { pub function_identifier: FunctionIdentifier, @@ -356,7 +441,7 @@ impl CallExpr { } } -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, Serialize)] #[doc(hidden)] pub struct FunctionIdentifierNode { pub id: Identifier, @@ -370,3 +455,38 @@ impl FunctionIdentifier { self.node.id.name() } } + +#[derive(Clone, Debug, PartialEq, Serialize)] +#[doc(hidden)] +pub struct NewExprNode { + pub functor: Identifier, + pub args: Vec, +} 
+ +pub type NewExpr = AstNode; + +impl NewExpr { + pub fn num_args(&self) -> usize { + self.node.args.len() + } + + pub fn iter_args(&self) -> impl Iterator { + self.node.args.iter() + } + + pub fn iter_args_mut(&mut self) -> impl Iterator { + self.node.args.iter_mut() + } + + pub fn functor_identifier(&self) -> &Identifier { + &self.node.functor + } + + pub fn functor_identifier_mut(&mut self) -> &mut Identifier { + &mut self.node.functor + } + + pub fn functor_name(&self) -> &str { + self.node.functor.name() + } +} diff --git a/core/src/compiler/front/ast/formula.rs b/core/src/compiler/front/ast/formula.rs index 7b829a5..170a23e 100644 --- a/core/src/compiler/front/ast/formula.rs +++ b/core/src/compiler/front/ast/formula.rs @@ -1,10 +1,13 @@ +use serde::*; + use super::*; /// A formula -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, Serialize)] pub enum Formula { Atom(Atom), NegAtom(NegAtom), + Case(Case), Disjunction(Disjunction), Conjunction(Conjunction), Implies(Implies), @@ -20,12 +23,12 @@ impl Formula { pub fn negate(&self) -> Self { match self { - Self::Atom(a) => { - Self::NegAtom(NegAtom::new(a.location().clone(), NegAtomNode { atom: a.clone() })) + Self::Atom(a) => Self::NegAtom(NegAtom::new(a.location().clone(), NegAtomNode { atom: a.clone() })), + Self::NegAtom(n) => Self::Atom(n.atom().clone()), + Self::Case(_) => { + // TODO + panic!("Cannot have case inside negation") } - Self::NegAtom(n) => { - Self::Atom(n.atom().clone()) - }, Self::Disjunction(d) => Self::Conjunction(Conjunction::new( d.location().clone(), ConjunctionNode { @@ -57,10 +60,11 @@ impl Formula { } } -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, Serialize)] #[doc(hidden)] pub struct AtomNode { pub predicate: Identifier, + pub type_args: Vec, pub args: Vec, } @@ -68,8 +72,41 @@ pub struct AtomNode { pub type Atom = AstNode; impl Atom { - pub fn predicate(&self) -> &String { - &self.node.predicate.node.name + pub fn predicate(&self) -> 
String { + if self.has_type_arg() { + let args = self + .iter_type_arguments() + .map(|a| format!("{}", a)) + .collect::>() + .join("#"); + format!("{}#{}", self.node.predicate.name(), args) + } else { + self.node.predicate.name().to_string() + } + } + + pub fn predicate_identifier(&self) -> &Identifier { + &self.node.predicate + } + + pub fn predicate_identifier_mut(&mut self) -> &mut Identifier { + &mut self.node.predicate + } + + pub fn has_type_arg(&self) -> bool { + !self.node.type_args.is_empty() + } + + pub fn num_type_args(&self) -> usize { + self.node.type_args.len() + } + + pub fn iter_type_arguments(&self) -> impl Iterator { + self.node.type_args.iter() + } + + pub fn iter_type_arguments_mut(&mut self) -> impl Iterator { + self.node.type_args.iter_mut() } pub fn arity(&self) -> usize { @@ -79,9 +116,13 @@ impl Atom { pub fn iter_arguments(&self) -> impl Iterator { self.node.args.iter() } + + pub fn iter_arguments_mut(&mut self) -> impl Iterator { + self.node.args.iter_mut() + } } -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, Serialize)] #[doc(hidden)] pub struct NegAtomNode { pub atom: Atom, @@ -95,12 +136,39 @@ impl NegAtom { &self.node.atom } - pub fn predicate(&self) -> &String { + pub fn predicate(&self) -> String { self.atom().predicate() } } -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, Serialize)] +#[doc(hidden)] +pub struct CaseNode { + pub variable: Variable, + pub entity: Entity, +} + +pub type Case = AstNode; + +impl Case { + pub fn variable(&self) -> &Variable { + &self.node.variable + } + + pub fn variable_name(&self) -> &str { + self.variable().name() + } + + pub fn entity(&self) -> &Entity { + &self.node.entity + } + + pub fn entity_mut(&mut self) -> &mut Entity { + &mut self.node.entity + } +} + +#[derive(Clone, Debug, PartialEq, Serialize)] #[doc(hidden)] pub struct ConjunctionNode { pub args: Vec, @@ -115,7 +183,7 @@ impl Conjunction { } } -#[derive(Clone, Debug, PartialEq)] 
+#[derive(Clone, Debug, PartialEq, Serialize)] #[doc(hidden)] pub struct DisjunctionNode { pub args: Vec, @@ -130,7 +198,7 @@ impl Disjunction { } } -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, Serialize)] #[doc(hidden)] pub struct ImpliesNode { pub left: Box, @@ -150,7 +218,7 @@ impl Implies { } } -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, Serialize)] #[doc(hidden)] pub struct ConstraintNode { pub expr: Expr, @@ -180,7 +248,7 @@ impl Constraint { } /// A variable or a wildcard, e.g. `x` or `_` -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, Serialize)] pub enum VariableOrWildcard { Variable(Variable), Wildcard(Wildcard), @@ -202,7 +270,7 @@ impl VariableOrWildcard { } } -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, Serialize)] #[doc(hidden)] pub struct ReduceNode { pub left: Vec, @@ -253,7 +321,7 @@ impl Reduce { } } -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, Serialize)] #[doc(hidden)] pub enum ReduceOperatorNode { Count, @@ -334,7 +402,7 @@ impl ReduceOperator { } } -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, Serialize)] #[doc(hidden)] pub struct ForallExistsReduceNode { pub negate: bool, diff --git a/core/src/compiler/front/ast/import_decl.rs b/core/src/compiler/front/ast/import_decl.rs index 9fead11..6fa92aa 100644 --- a/core/src/compiler/front/ast/import_decl.rs +++ b/core/src/compiler/front/ast/import_decl.rs @@ -1,6 +1,8 @@ +use serde::*; + use super::*; -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, Serialize)] #[doc(hidden)] pub struct ImportFileNode { pub file_path: String, @@ -8,7 +10,7 @@ pub struct ImportFileNode { pub type ImportFile = AstNode; -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, Serialize)] #[doc(hidden)] pub struct ImportDeclNode { pub attrs: Attributes, diff --git a/core/src/compiler/front/ast/item.rs 
b/core/src/compiler/front/ast/item.rs index b600ae3..0a78939 100644 --- a/core/src/compiler/front/ast/item.rs +++ b/core/src/compiler/front/ast/item.rs @@ -1,6 +1,8 @@ +use serde::*; + use super::*; -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, Serialize)] pub enum Item { ImportDecl(ImportDecl), TypeDecl(TypeDecl), diff --git a/core/src/compiler/front/ast/mod.rs b/core/src/compiler/front/ast/mod.rs index f8860c1..ef78330 100644 --- a/core/src/compiler/front/ast/mod.rs +++ b/core/src/compiler/front/ast/mod.rs @@ -1,6 +1,7 @@ mod attr; mod const_decl; mod constant; +mod entity; mod expr; mod formula; mod import_decl; @@ -15,6 +16,7 @@ mod utils; pub use attr::*; pub use const_decl::*; pub use constant::*; +pub use entity::*; pub use expr::*; pub use formula::*; pub use import_decl::*; diff --git a/core/src/compiler/front/ast/query.rs b/core/src/compiler/front/ast/query.rs index fa705b0..d3f6058 100644 --- a/core/src/compiler/front/ast/query.rs +++ b/core/src/compiler/front/ast/query.rs @@ -1,6 +1,8 @@ +use serde::*; + use super::*; -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, Serialize)] #[doc(hidden)] pub enum QueryNode { Predicate(Identifier), @@ -19,7 +21,7 @@ impl Query { } else { n.to_string() } - }, + } QueryNode::Atom(a) => a.predicate().to_string(), } } @@ -44,7 +46,7 @@ impl Into> for Query { } } -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, Serialize)] #[doc(hidden)] pub struct QueryDeclNode { pub attrs: Attributes, diff --git a/core/src/compiler/front/ast/relation_decl.rs b/core/src/compiler/front/ast/relation_decl.rs index d573bb7..73756dd 100644 --- a/core/src/compiler/front/ast/relation_decl.rs +++ b/core/src/compiler/front/ast/relation_decl.rs @@ -1,6 +1,8 @@ +use serde::*; + use super::*; -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, Serialize)] #[doc(hidden)] pub struct ConstantTupleNode { pub elems: Vec, @@ -14,7 +16,7 @@ impl ConstantTuple { } } 
-#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, Serialize)] #[doc(hidden)] pub struct ConstantSetTupleNode { pub tag: Tag, @@ -37,7 +39,7 @@ impl ConstantSetTuple { } } -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, Serialize)] #[doc(hidden)] pub struct ConstantSetNode { pub tuples: Vec, @@ -52,7 +54,7 @@ impl ConstantSet { } } -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, Serialize)] #[doc(hidden)] pub struct ConstantSetDeclNode { pub attrs: Attributes, @@ -88,7 +90,7 @@ impl ConstantSetDecl { } } -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, Serialize)] #[doc(hidden)] pub struct FactDeclNode { pub attrs: Attributes, @@ -107,7 +109,7 @@ impl FactDecl { &mut self.node.attrs } - pub fn predicate(&self) -> &String { + pub fn predicate(&self) -> String { self.node.atom.predicate() } @@ -139,7 +141,7 @@ impl FactDecl { } } -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, Serialize)] #[doc(hidden)] pub struct RuleDeclNode { pub attrs: Attributes, @@ -181,7 +183,7 @@ impl RuleDecl { } } -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, Serialize)] #[doc(hidden)] pub enum RelationDeclNode { Set(ConstantSetDecl), @@ -207,6 +209,13 @@ impl RelationDecl { RelationDeclNode::Rule(r) => r.attributes_mut(), } } + + pub fn rule(&self) -> Option<&Rule> { + match &self.node { + RelationDeclNode::Rule(r) => Some(&r.node.rule), + _ => None, + } + } } impl From for Item { diff --git a/core/src/compiler/front/ast/rule.rs b/core/src/compiler/front/ast/rule.rs index fc9ea53..d640cf8 100644 --- a/core/src/compiler/front/ast/rule.rs +++ b/core/src/compiler/front/ast/rule.rs @@ -1,6 +1,8 @@ +use serde::*; + use super::*; -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, Serialize)] #[doc(hidden)] pub struct RuleNode { pub head: RuleHead, @@ -41,10 +43,11 @@ impl Into> for Rule { } } -#[derive(Clone, Debug, PartialEq)] 
+#[derive(Clone, Debug, PartialEq, Serialize)] #[doc(hidden)] pub enum RuleHeadNode { Atom(Atom), + Conjunction(Vec), Disjunction(Vec), } @@ -54,6 +57,7 @@ impl RuleHead { pub fn is_atomic(&self) -> bool { match &self.node { RuleHeadNode::Atom(_) => true, + RuleHeadNode::Conjunction(_) => false, RuleHeadNode::Disjunction(_) => false, } } @@ -61,20 +65,31 @@ impl RuleHead { pub fn is_disjunction(&self) -> bool { match &self.node { RuleHeadNode::Atom(_) => false, + RuleHeadNode::Conjunction(_) => false, RuleHeadNode::Disjunction(_) => true, } } + pub fn is_conjunction(&self) -> bool { + match &self.node { + RuleHeadNode::Atom(_) => false, + RuleHeadNode::Conjunction(_) => true, + RuleHeadNode::Disjunction(_) => false, + } + } + pub fn atom(&self) -> Option<&Atom> { match &self.node { RuleHeadNode::Atom(atom) => Some(atom), + RuleHeadNode::Conjunction(_) => None, RuleHeadNode::Disjunction(_) => None, } } - pub fn iter_predicates(&self) -> Vec<&String> { + pub fn iter_predicates(&self) -> Vec { match &self.node { RuleHeadNode::Atom(atom) => vec![atom.predicate()], + RuleHeadNode::Conjunction(atoms) => atoms.iter().map(|atom| atom.predicate()).collect(), RuleHeadNode::Disjunction(atoms) => atoms.iter().map(|atom| atom.predicate()).collect(), } } @@ -82,10 +97,8 @@ impl RuleHead { pub fn iter_arguments(&self) -> Vec<&Expr> { match &self.node { RuleHeadNode::Atom(atom) => atom.iter_arguments().collect(), - RuleHeadNode::Disjunction(atoms) => atoms - .iter() - .flat_map(|atom| atom.iter_arguments()) - .collect(), + RuleHeadNode::Conjunction(atoms) => atoms.iter().flat_map(|atom| atom.iter_arguments()).collect(), + RuleHeadNode::Disjunction(atoms) => atoms.iter().flat_map(|atom| atom.iter_arguments()).collect(), } } } diff --git a/core/src/compiler/front/ast/type_decl.rs b/core/src/compiler/front/ast/type_decl.rs index 7a315ee..d2758a0 100644 --- a/core/src/compiler/front/ast/type_decl.rs +++ b/core/src/compiler/front/ast/type_decl.rs @@ -1,23 +1,57 @@ +use serde::*; + use 
super::*; -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, Serialize)] #[doc(hidden)] pub enum TypeDeclNode { Subtype(SubtypeDecl), Alias(AliasTypeDecl), Relation(RelationTypeDecl), Enum(EnumTypeDecl), + Algebraic(AlgebraicDataTypeDecl), } pub type TypeDecl = AstNode; impl TypeDecl { + pub fn alias(name: Identifier, alias_of: Type) -> Self { + TypeDeclNode::Alias( + AliasTypeDeclNode { + attrs: Attributes::default(), + name, + alias_of, + } + .into(), + ) + .into() + } + + pub fn relation(name: Identifier, args: Vec) -> Self { + TypeDeclNode::Relation( + RelationTypeDeclNode { + attrs: Attributes::default(), + rel_types: vec![RelationTypeNode { + name, + arg_types: args + .into_iter() + .map(|a| ArgTypeBindingNode { name: None, ty: a }.into()) + .collect(), + } + .into()], + } + .into(), + ) + .into() + } + pub fn attributes(&self) -> &Attributes { match &self.node { TypeDeclNode::Subtype(s) => s.attributes(), TypeDeclNode::Alias(a) => a.attributes(), TypeDeclNode::Relation(r) => r.attributes(), TypeDeclNode::Enum(e) => e.attributes(), + TypeDeclNode::Algebraic(a) => a.attributes(), } } @@ -27,11 +61,12 @@ impl TypeDecl { TypeDeclNode::Alias(a) => a.attributes_mut(), TypeDeclNode::Relation(r) => r.attributes_mut(), TypeDeclNode::Enum(e) => e.attributes_mut(), + TypeDeclNode::Algebraic(a) => a.attributes_mut(), } } } -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, Serialize)] #[doc(hidden)] pub struct SubtypeDeclNode { pub attrs: Attributes, @@ -59,7 +94,7 @@ impl SubtypeDecl { } } -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, Serialize)] #[doc(hidden)] pub struct AliasTypeDeclNode { pub attrs: Attributes, @@ -87,7 +122,7 @@ impl AliasTypeDecl { } } -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, Serialize)] #[doc(hidden)] pub struct ArgTypeBindingNode { pub name: Option, @@ -106,7 +141,7 @@ impl ArgTypeBinding { } } -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, 
Debug, PartialEq, Serialize)] #[doc(hidden)] pub struct RelationTypeNode { pub name: Identifier, @@ -140,7 +175,7 @@ impl RelationType { } } -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, Serialize)] #[doc(hidden)] pub struct RelationTypeDeclNode { pub attrs: Attributes, @@ -171,7 +206,7 @@ impl RelationTypeDecl { } } -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, Serialize)] #[doc(hidden)] pub struct EnumTypeDeclNode { pub attrs: Attributes, @@ -207,7 +242,7 @@ impl EnumTypeDecl { } } -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, Serialize)] pub struct EnumTypeMemberNode { pub name: Identifier, pub assigned_num: Option, @@ -228,3 +263,68 @@ impl EnumTypeMember { self.node.assigned_num.as_mut() } } + +#[derive(Clone, Debug, PartialEq, Serialize)] +pub struct AlgebraicDataTypeDeclNode { + pub attrs: Attributes, + pub name: Identifier, + pub variants: Vec, +} + +pub type AlgebraicDataTypeDecl = AstNode; + +impl AlgebraicDataTypeDecl { + pub fn attributes(&self) -> &Vec { + &self.node.attrs + } + + pub fn attributes_mut(&mut self) -> &mut Attributes { + &mut self.node.attrs + } + + pub fn name(&self) -> &str { + self.node.name.name() + } + + pub fn name_identifier(&self) -> &Identifier { + &self.node.name + } + + pub fn iter_variants(&self) -> impl Iterator { + self.node.variants.iter() + } + + pub fn iter_variants_mut(&mut self) -> impl Iterator { + self.node.variants.iter_mut() + } +} + +#[derive(Clone, Debug, PartialEq, Serialize)] +pub struct AlgebraicDataTypeVariantNode { + pub constructor: Identifier, + pub args: Vec, +} + +pub type AlgebraicDataTypeVariant = AstNode; + +impl AlgebraicDataTypeVariant { + pub fn name(&self) -> &str { + self.node.constructor.name() + } + + pub fn name_identifier(&self) -> &Identifier { + &self.node.constructor + } + + pub fn iter_arg_types(&self) -> impl Iterator { + self.node.args.iter() + } + + pub fn iter_arg_types_mut(&mut self) -> impl Iterator { + 
self.node.args.iter_mut() + } + + pub fn args(&self) -> &Vec { + &self.node.args + } +} diff --git a/core/src/compiler/front/ast/types.rs b/core/src/compiler/front/ast/types.rs index c147db7..ecb0f0c 100644 --- a/core/src/compiler/front/ast/types.rs +++ b/core/src/compiler/front/ast/types.rs @@ -1,7 +1,10 @@ -use super::*; +use serde::*; + use crate::common::value_type::*; -#[derive(Debug, Clone, PartialEq)] +use super::*; + +#[derive(Debug, Clone, PartialEq, Serialize)] #[doc(hidden)] pub enum TypeNode { I8, @@ -22,10 +25,12 @@ pub enum TypeNode { Bool, Str, String, - // RcString, + Symbol, DateTime, Duration, - Named(Identifier), + Entity, + Tensor, + Named(String), } impl std::fmt::Display for TypeNode { @@ -49,10 +54,46 @@ impl std::fmt::Display for TypeNode { Self::Bool => f.write_str("bool"), Self::Str => f.write_str("&str"), Self::String => f.write_str("String"), - // Self::RcString => f.write_str("Rc"), + Self::Symbol => f.write_str("Symbol"), Self::DateTime => f.write_str("DateTime"), Self::Duration => f.write_str("Duration"), - Self::Named(i) => f.write_str(&i.node.name), + Self::Entity => f.write_str("Entity"), + Self::Tensor => f.write_str("Tensor"), + Self::Named(i) => f.write_str(i), + } + } +} + +impl std::str::FromStr for TypeNode { + // There will be no error + type Err = (); + + fn from_str(s: &str) -> Result { + match s { + "i8" => Ok(Self::I8), + "i16" => Ok(Self::I16), + "i32" => Ok(Self::I32), + "i64" => Ok(Self::I64), + "i128" => Ok(Self::I128), + "isize" => Ok(Self::ISize), + "u8" => Ok(Self::U8), + "u16" => Ok(Self::U16), + "u32" => Ok(Self::U32), + "u64" => Ok(Self::U64), + "u128" => Ok(Self::U128), + "usize" => Ok(Self::USize), + "f32" => Ok(Self::F32), + "f64" => Ok(Self::F64), + "bool" => Ok(Self::Bool), + "char" => Ok(Self::Char), + "&str" => Ok(Self::Str), + "String" => Ok(Self::String), + "Symbol" => Ok(Self::Symbol), + "DateTime" => Ok(Self::DateTime), + "Duration" => Ok(Self::Duration), + "Entity" => Ok(Self::Entity), + "Tensor" => 
Ok(Self::Tensor), + s => Ok(Self::Named(s.to_string())), } } } @@ -65,11 +106,29 @@ impl Type { Self::default(TypeNode::I8) } + /// Create a new `u64` type AST node + pub fn u64() -> Self { + Self::default(TypeNode::U64) + } + /// Create a new `usize` type AST node pub fn usize() -> Self { Self::default(TypeNode::USize) } + /// Create a new `entity` type AST node + pub fn entity() -> Self { + Self::default(TypeNode::Entity) + } + + /// Get the name of the type if the type node contains a custom named type + pub fn get_name(&self) -> Option<&str> { + match &self.node { + TypeNode::Named(n) => Some(&n), + _ => None, + } + } + /// Convert the type AST node to a value type /// /// Returns `Ok` if the node itself is a base type; @@ -94,10 +153,25 @@ impl Type { TypeNode::Bool => Ok(ValueType::Bool), TypeNode::Str => Ok(ValueType::Str), TypeNode::String => Ok(ValueType::String), - // TypeNode::RcString => Ok(ValueType::RcString), + TypeNode::Symbol => Ok(ValueType::Symbol), TypeNode::DateTime => Ok(ValueType::DateTime), TypeNode::Duration => Ok(ValueType::Duration), - TypeNode::Named(s) => Err(s.name().to_string()), + TypeNode::Entity => Ok(ValueType::Entity), + TypeNode::Tensor => Ok(ValueType::Tensor), + TypeNode::Named(s) => Err(s.to_string()), + } + } +} + +impl From for Type { + fn from(value: Identifier) -> Self { + let type_node = value + .name() + .parse() + .expect("[Internal Error] Casting `Identifier` to `TypeNode` should not fail"); + Self { + loc: value.loc, + node: type_node, } } } diff --git a/core/src/compiler/front/ast/utils.rs b/core/src/compiler/front/ast/utils.rs index 4d4b54c..4deb5ab 100644 --- a/core/src/compiler/front/ast/utils.rs +++ b/core/src/compiler/front/ast/utils.rs @@ -1,14 +1,15 @@ use colored::*; +use serde::*; use super::super::*; -#[derive(Debug, Default, Clone, PartialEq, Eq, PartialOrd, Ord)] +#[derive(Debug, Default, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize)] pub struct Location { pub row: usize, pub col: usize, } 
-#[derive(Debug, Default, Clone, PartialEq, Eq, PartialOrd, Ord)] +#[derive(Debug, Default, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize)] pub struct Span { pub start: T, pub end: T, @@ -30,7 +31,7 @@ impl Span { } } -#[derive(Clone, Default, PartialEq, Eq, PartialOrd, Ord)] +#[derive(Clone, Default, PartialEq, Eq, PartialOrd, Ord, Serialize)] pub struct AstNodeLocation { pub offset_span: Span, pub loc_span: Option>, @@ -183,7 +184,7 @@ impl std::fmt::Debug for AstNodeLocation { } } -#[derive(Clone, PartialEq, Eq)] +#[derive(Clone, PartialEq, Eq, Serialize, Hash)] #[doc(hidden)] pub struct AstNode { pub loc: AstNodeLocation, @@ -232,6 +233,26 @@ impl AstNode { node: self.node.clone(), } } + + pub fn clone_without_location_id(&self) -> Self + where + N: Clone, + { + Self { + loc: self.loc.clone_without_id(), + node: self.node.clone(), + } + } + + pub fn with_span(mut self, loc: &AstNodeLocation) -> Self { + self.loc = loc.clone_without_id(); + self + } + + pub fn with_location(mut self, loc: AstNodeLocation) -> Self { + self.loc = loc; + self + } } impl From for AstNode { diff --git a/core/src/compiler/front/attribute/action.rs b/core/src/compiler/front/attribute/action.rs new file mode 100644 index 0000000..f14c891 --- /dev/null +++ b/core/src/compiler/front/attribute/action.rs @@ -0,0 +1,25 @@ +use crate::compiler::front::*; + +/// An action to perform after analyzing an attribute +pub enum AttributeAction { + /// Remove the current item + RemoveItem, + + /// Adding a new item + AddItem(Item), + + /// Replace the current item with a new item + ReplaceItem(Item), + + /// Multiple actions + Multiple(Vec), + + /// Context process + Context(Box), + + /// Error with a message + Error(String), + + /// Do nothing + Nothing, +} diff --git a/core/src/compiler/front/attribute/error.rs b/core/src/compiler/front/attribute/error.rs new file mode 100644 index 0000000..a98e0ae --- /dev/null +++ b/core/src/compiler/front/attribute/error.rs @@ -0,0 +1,38 @@ +use 
super::super::*; + +#[derive(Clone, Debug)] +pub enum AttributeError { + DuplicatedAttributeProcessor { name: String }, + ReservedAttribute { name: String }, + Custom { msg: String }, +} + +impl AttributeError { + pub fn new_custom(msg: String) -> Self { + Self::Custom { msg } + } +} + +impl std::fmt::Display for AttributeError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::DuplicatedAttributeProcessor { name } => { + f.write_fmt(format_args!("Duplicated attribute processor `{}`", name)) + } + Self::ReservedAttribute { name } => { + f.write_fmt(format_args!("Attribute process `{}` is reserved in Scallop", name)) + } + Self::Custom { msg } => f.write_str(msg), + } + } +} + +impl FrontCompileErrorTrait for AttributeError { + fn error_type(&self) -> FrontCompileErrorType { + FrontCompileErrorType::Error + } + + fn report(&self, _: &Sources) -> String { + format!("{}", self) + } +} diff --git a/core/src/compiler/front/attribute/mod.rs b/core/src/compiler/front/attribute/mod.rs new file mode 100644 index 0000000..09e76a2 --- /dev/null +++ b/core/src/compiler/front/attribute/mod.rs @@ -0,0 +1,11 @@ +mod action; +mod error; +mod post; +mod processor; +mod registry; + +pub use action::*; +pub use error::*; +pub use post::*; +pub use processor::*; +pub use registry::*; diff --git a/core/src/compiler/front/attribute/post.rs b/core/src/compiler/front/attribute/post.rs new file mode 100644 index 0000000..3b9aa5a --- /dev/null +++ b/core/src/compiler/front/attribute/post.rs @@ -0,0 +1,84 @@ +use std::collections::*; + +use crate::utils::*; + +use super::super::*; +use super::*; + +pub enum PostProcessingAction { + RemoveItem { item_index: usize }, + ReplaceItem { item_index: usize, item: Item }, + AddItem { item: Item }, + Context { func: Box }, + Error { msg: String }, +} + +pub struct PostProcessingContext { + actions: Vec, +} + +impl PostProcessingContext { + pub fn new() -> Self { + Self { actions: Vec::new() } + } + + pub fn 
add_action(&mut self, action: AttributeAction, item_index: usize) { + // Helper function + fn add_action(post_proc_actions: &mut Vec, action: AttributeAction, item_index: usize) { + match action { + AttributeAction::AddItem(item) => { + post_proc_actions.push(PostProcessingAction::AddItem { item }); + } + AttributeAction::Context(func) => { + post_proc_actions.push(PostProcessingAction::Context { func }); + } + AttributeAction::Error(msg) => { + post_proc_actions.push(PostProcessingAction::Error { msg }); + } + AttributeAction::Multiple(acts) => { + for act in acts { + add_action(post_proc_actions, act, item_index); + } + } + AttributeAction::Nothing => {} + AttributeAction::RemoveItem => { + post_proc_actions.push(PostProcessingAction::RemoveItem { item_index }); + } + AttributeAction::ReplaceItem(item) => { + post_proc_actions.push(PostProcessingAction::ReplaceItem { item, item_index }); + } + } + } + + // Invoke the helper function + add_action(&mut self.actions, action, item_index); + } + + pub fn process(self, ctx: &mut FrontContext, items: &mut Vec) -> Result<(), AttributeError> { + let mut to_remove_item_index = HashSet::new(); + let mut new_items = Vec::new(); + + for action in self.actions { + match action { + PostProcessingAction::AddItem { item } => { + new_items.push(item); + } + PostProcessingAction::Context { func } => { + func(ctx); + } + PostProcessingAction::RemoveItem { item_index } => { + to_remove_item_index.insert(item_index); + } + PostProcessingAction::ReplaceItem { item_index, item } => { + items[item_index] = item; + } + PostProcessingAction::Error { msg } => return Err(AttributeError::Custom { msg }), + } + } + + items.retain_with_index(|id, _| !to_remove_item_index.contains(&id)); + items.extend(new_items.into_iter()); + + Ok(()) + } +} diff --git a/core/src/compiler/front/attribute/processor.rs b/core/src/compiler/front/attribute/processor.rs new file mode 100644 index 0000000..331b062 --- /dev/null +++ 
b/core/src/compiler/front/attribute/processor.rs @@ -0,0 +1,46 @@ +use dyn_clone::DynClone; + +use super::super::ast; + +use super::*; + +pub trait AttributeProcessor: DynClone + 'static { + fn name(&self) -> String; + + fn apply(&self, item: &ast::Item, attr: &ast::Attribute) -> Result; +} + +impl std::fmt::Debug for dyn AttributeProcessor { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(&self.name()) + } +} + +#[derive(Debug)] +pub struct DynamicAttributeProcessor { + proc: Box, +} + +impl DynamicAttributeProcessor { + pub fn new(p: P) -> Self { + Self { proc: Box::new(p) } + } +} + +impl AttributeProcessor for DynamicAttributeProcessor { + fn name(&self) -> String { + self.proc.name() + } + + fn apply(&self, item: &ast::Item, attr: &ast::Attribute) -> Result { + self.proc.apply(item, attr) + } +} + +impl Clone for DynamicAttributeProcessor { + fn clone(&self) -> Self { + Self { + proc: dyn_clone::clone_box(&*self.proc), + } + } +} diff --git a/core/src/compiler/front/attribute/registry.rs b/core/src/compiler/front/attribute/registry.rs new file mode 100644 index 0000000..f717408 --- /dev/null +++ b/core/src/compiler/front/attribute/registry.rs @@ -0,0 +1,63 @@ +use std::collections::*; + +use crate::compiler::front::FrontContext; + +use super::super::ast; +use super::*; + +const RESERVED_ATTRIBUTES: [&'static str; 3] = ["hidden", "file", "demand"]; + +#[derive(Clone, Debug)] +pub struct AttributeProcessorRegistry { + pub registry: HashMap, +} + +impl AttributeProcessorRegistry { + pub fn new() -> Self { + Self { + registry: HashMap::new(), + } + } + + pub fn has_attribute_processor(&self, name: &str) -> bool { + self.registry.contains_key(name) + } + + pub fn get_attribute_processor(&self, name: &str) -> Option<&DynamicAttributeProcessor> { + self.registry.get(name) + } + + pub fn register

(&mut self, p: P) -> Result<(), AttributeError> + where + P: AttributeProcessor, + { + let name = p.name(); + if RESERVED_ATTRIBUTES.contains(&name.as_str()) { + Err(AttributeError::ReservedAttribute { name }) + } else if self.registry.contains_key(&name) { + Err(AttributeError::DuplicatedAttributeProcessor { name }) + } else { + let dyn_p = DynamicAttributeProcessor::new(p); + self.registry.insert(name.to_string(), dyn_p); + Ok(()) + } + } + + pub fn analyze(&self, items: &Vec) -> Result { + let mut post_proc_ctx = PostProcessingContext::new(); + for (item_index, item) in items.iter().enumerate() { + for attr in item.attributes() { + if let Some(proc) = self.get_attribute_processor(attr.name()) { + let action = proc.apply(item, attr)?; + post_proc_ctx.add_action(action, item_index); + } + } + } + Ok(post_proc_ctx) + } + + pub fn analyze_and_process(&self, ctx: &mut FrontContext, items: &mut Vec) -> Result<(), AttributeError> { + let attr_pos_proc = self.analyze(items)?; + attr_pos_proc.process(ctx, items) + } +} diff --git a/core/src/compiler/front/compile.rs b/core/src/compiler/front/compile.rs index 24179d2..8115ef6 100644 --- a/core/src/compiler/front/compile.rs +++ b/core/src/compiler/front/compile.rs @@ -4,10 +4,12 @@ use std::path::PathBuf; use super::analysis::*; use super::analyzers::*; +use super::attribute::*; use super::*; use crate::common::foreign_function::*; use crate::common::foreign_predicate::*; +use crate::common::tuple::*; use crate::common::tuple_type::*; use crate::common::value_type::*; use crate::utils::CopyOnWrite; @@ -32,6 +34,9 @@ pub struct FrontContext { /// Foreign predicate registry holding all foreign predicates pub foreign_predicate_registry: ForeignPredicateRegistry, + /// Attribute processor registry holding all attribute processors + pub attribute_processor_registry: AttributeProcessorRegistry, + /// Node ID annotator for giving AST node IDs. 
pub node_id_annotator: NodeIdAnnotator, @@ -43,12 +48,14 @@ impl FrontContext { pub fn new() -> Self { let function_registry = ForeignFunctionRegistry::std(); let predicate_registry = ForeignPredicateRegistry::std(); + let attribute_registry = AttributeProcessorRegistry::new(); let analysis = Analysis::new(&function_registry, &predicate_registry); Self { sources: Sources::new(), items: Vec::new(), foreign_function_registry: function_registry, foreign_predicate_registry: predicate_registry, + attribute_processor_registry: attribute_registry, imported_files: HashSet::new(), node_id_annotator: NodeIdAnnotator::new(), analysis: CopyOnWrite::new(analysis), @@ -95,11 +102,11 @@ impl FrontContext { pub fn register_foreign_predicate(&mut self, f: F) -> Result<(), ForeignPredicateError> where - F: ForeignPredicate + Send + Sync + Clone + 'static + F: ForeignPredicate + Send + Sync + Clone + 'static, { // Check if the predicate name has already be defined before - if self.type_inference().has_relation(&f.name()) { - return Err(ForeignPredicateError::AlreadyExisted { id: f.name() }); + if self.type_inference().has_relation(&f.internal_name()) { + return Err(ForeignPredicateError::AlreadyExisted { id: f.internal_name() }); } // Add the predicate to the registry @@ -123,6 +130,19 @@ impl FrontContext { Ok(()) } + pub fn register_attribute_processor

(&mut self, p: P) -> Result<(), AttributeError> + where + P: AttributeProcessor + Send + Sync + Clone, + { + self.attribute_processor_registry.register(p)?; + Ok(()) + } + + pub fn compile_string(&mut self, s: String) -> Result { + let source = StringSource::new(s); + self.compile_source_with_parser(source, parser::str_to_items) + } + pub fn compile_source(&mut self, s: S) -> Result { self.compile_source_with_parser(s, parser::str_to_items) } @@ -212,7 +232,7 @@ impl FrontContext { err.set_source_name(name.to_string()); } error_ctx.set_sources(&dup_ctx.sources); - let source_id = error_ctx.sources.add(source); + let source_id = error_ctx.add_source(source); err.set_source_id(source_id); error_ctx.add(err); return Err(error_ctx); @@ -243,6 +263,18 @@ impl FrontContext { ast.iter_mut().for_each(|item| annotate(item)); } + // Use foreign attribute registry to annotate item + match self + .attribute_processor_registry + .analyze_and_process(&mut dup_ctx, &mut ast) + { + Ok(_) => {} + Err(err) => { + error_ctx.add(err); + return Err(error_ctx); + } + }; + // Front pre-transformaion analysis dup_ctx.analysis.modify(|analysis| { analysis.perform_pre_transformation_analysis(&ast); @@ -253,7 +285,9 @@ impl FrontContext { } // Front transformation; add new items into ast and re-annotate the ast - apply_transformations(&mut ast, dup_ctx.analysis.borrow()); + dup_ctx.analysis.modify_without_copy(|analysis| { + apply_transformations(&mut ast, analysis); + }); dup_ctx.node_id_annotator.walk_items(&mut ast); // Front analysis @@ -317,6 +351,102 @@ impl FrontContext { } } + /// Compile an entity (written in string) into a set of entity facts + pub fn compile_entity_to_facts( + &self, + relation: &str, + entity_tuple: Vec, + ) -> Result>, FrontCompileError> { + use crate::compiler::front::analyzers::type_inference::*; + + // Create an error context + let mut error_ctx = FrontCompileError::with_sources(&self.sources); + + // Check relation types + if let Some(arg_types) = 
self.type_inference().relation_arg_types(relation) { + // First parse the entities from the strings + let mut all_entity_facts = Vec::new(); + let mut final_entity_constants = Vec::new(); + + // Iterate through + for (entity_str, expected_ty) in entity_tuple.into_iter().zip(arg_types.into_iter()) { + let value = if expected_ty.is_entity() { + // Create a source containing the entity string + let entity_source = StringSource::new(entity_str); + + // Try to parse the entity + let mut entity = match parser::str_to_entity(&entity_source.string) { + Ok(entity) => entity, + Err(mut err) => { + let source_id = error_ctx.add_source(entity_source); + err.set_source_id(source_id); + error_ctx.add(err); + return Err(error_ctx); + } + }; + + // Annotate locations on the entity for error reporting + let source_id = error_ctx.add_source(entity_source); + let mut source_id_annotator = SourceIdAnnotator::new(source_id); + let mut node_id_annotator = self.node_id_annotator.clone(); + NodeVisitorMut::walk_entity(&mut (&mut source_id_annotator, &mut node_id_annotator), &mut entity); + + // Then parse the entity into a set of entity facts + let (entity_facts, constant) = entity.to_facts(); + + // Check the types of the entity facts + let entity_facts = match self.type_inference().parse_entity_facts(&entity_facts) { + Ok(entity_facts) => entity_facts, + Err(err) => { + error_ctx.add(err); + return Err(error_ctx); + } + }; + + // Add things into the aggregated list + all_entity_facts.extend(entity_facts); + + // Add the value + if constant.can_unify(&expected_ty) { + constant.to_value(&expected_ty) + } else { + error_ctx.add(TypeInferenceError::CannotUnifyTypes { + t1: TypeSet::from_constant(&constant), + t2: TypeSet::base(expected_ty), + loc: None, + }); + return Err(error_ctx); + } + } else { + match expected_ty.parse(&entity_str) { + Ok(value) => value, + Err(err) => { + error_ctx.add(err); + return Err(error_ctx); + } + } + }; + final_entity_constants.push(value) + } + + // Use the 
type inference engine + let tuple = Tuple::from_values(final_entity_constants.into_iter()); + + // Post-process the facts into a storage + let mut facts: HashMap<_, Vec<_>> = std::iter::once((relation.to_string(), vec![tuple])).collect(); + for (functor, id, args) in all_entity_facts { + let tuple = Tuple::from_values(std::iter::once(id).chain(args.into_iter())); + facts.entry(functor).or_default().push(tuple); + } + Ok(facts) + } else { + error_ctx.add(TypeInferenceError::UnknownRelation { + relation: relation.to_string(), + }); + Err(error_ctx) + } + } + pub fn num_relations(&self) -> usize { self.type_inference().num_relations() } diff --git a/core/src/compiler/front/error.rs b/core/src/compiler/front/error.rs index c124938..2167890 100644 --- a/core/src/compiler/front/error.rs +++ b/core/src/compiler/front/error.rs @@ -1,6 +1,8 @@ use colored::*; use dyn_clone::DynClone; +use crate::common::value_type::ValueParseError; + use super::*; pub enum FrontCompileErrorType { @@ -31,14 +33,6 @@ impl FrontCompileErrorType { } } -pub trait FrontCompileErrorTrait: DynClone + std::fmt::Debug { - /// Get the error type of this error (warning/error) - fn error_type(&self) -> FrontCompileErrorType; - - /// Report the error showing source into string - fn report(&self, src: &Sources) -> String; -} - #[derive(Debug)] pub struct FrontCompileError { pub sources: Sources, @@ -75,6 +69,13 @@ impl FrontCompileError { } } + pub fn with_sources(sources: &Sources) -> Self { + Self { + sources: sources.clone(), + errors: Vec::new(), + } + } + pub fn singleton(e: E) -> Self { Self { sources: Sources::new(), @@ -86,6 +87,10 @@ impl FrontCompileError { self.sources = sources.clone(); } + pub fn add_source(&mut self, source: S) -> usize { + self.sources.add(source) + } + pub fn add(&mut self, error: E) { self.errors.push(Box::new(error)); } @@ -105,13 +110,13 @@ impl FrontCompileError { } pub fn report_errors(&self) { - println!("{}", self) + eprintln!("{}", self) } pub fn 
report_warnings(&self) { for error in &self.errors { if error.error_type().is_warning() { - println!("{} {}\n", error.error_type().marker(), error.report(&self.sources)) + eprintln!("{} {}\n", error.error_type().marker(), error.report(&self.sources)) } } } @@ -120,3 +125,21 @@ impl FrontCompileError { self.errors.clear(); } } + +pub trait FrontCompileErrorTrait: DynClone + std::fmt::Debug { + /// Get the error type of this error (warning/error) + fn error_type(&self) -> FrontCompileErrorType; + + /// Report the error showing source into string + fn report(&self, src: &Sources) -> String; +} + +impl FrontCompileErrorTrait for ValueParseError { + fn error_type(&self) -> FrontCompileErrorType { + FrontCompileErrorType::Error + } + + fn report(&self, _: &Sources) -> String { + format!("cannot parse value `{}` into type `{}`", self.source, self.ty) + } +} diff --git a/core/src/compiler/front/f2b/f2b.rs b/core/src/compiler/front/f2b/f2b.rs index cd3b05a..ec43cac 100644 --- a/core/src/compiler/front/f2b/f2b.rs +++ b/core/src/compiler/front/f2b/f2b.rs @@ -1,6 +1,6 @@ use std::collections::*; -use super::super::analyzers::boundness::{AggregationContext, RuleContext, ForeignPredicateBindings}; +use super::super::analyzers::boundness::{AggregationContext, ForeignPredicateBindings, RuleContext}; use super::super::ast as front; use super::super::ast::{AstNodeLocation, WithLocation}; use super::super::compile::*; @@ -13,6 +13,7 @@ use crate::compiler::back; impl FrontContext { pub fn to_back_program(&self) -> back::Program { + // Generate relations and facts let base_relations = self.to_back_relations(); let facts = self.to_back_facts(); let disjunctive_facts = self.to_back_disjunctive_facts(); @@ -126,7 +127,7 @@ impl FrontContext { } front::RelationDeclNode::Fact(f) => { let pred = f.predicate(); - let tys = self.relation_arg_types(pred).unwrap(); + let tys = self.relation_arg_types(&pred).unwrap(); let args = f.iter_constants().zip(tys.iter()).map(|(c, t)| 
c.to_value(t)).collect(); let back_fact = back::Fact { tag: f.tag().input_tag().clone(), @@ -174,10 +175,6 @@ impl FrontContext { } fn to_back_rules(&self, temp_relations: &mut Vec) -> Vec { - self.rules_to_back_rules(temp_relations) - } - - fn rules_to_back_rules(&self, temp_relations: &mut Vec) -> Vec { self .iter_relation_decls() .filter_map(|rd| match &rd.node { @@ -191,8 +188,9 @@ impl FrontContext { fn rule_decl_to_back_rules(&self, rd: &front::RuleDecl, temp_relations: &mut Vec) -> Vec { let rule_loc = rd.rule().location(); match &rd.rule().head().node { - front::RuleHeadNode::Atom(head) => { - self.atomic_rule_decl_to_back_rules(rule_loc, head, temp_relations) + front::RuleHeadNode::Atom(head) => self.atomic_rule_decl_to_back_rules(rule_loc, head, temp_relations), + front::RuleHeadNode::Conjunction(_) => { + panic!("[Internal Error] Conjunction should be flattened and de-sugared. This is probably a bug.") } front::RuleHeadNode::Disjunction(head_atoms) => { self.disjunctive_rule_decl_to_back_rules(rule_loc, head_atoms, temp_relations) @@ -224,10 +222,7 @@ impl FrontContext { .collect::>(); // Create the head that will be shared across all back rules - let args = head - .iter_arguments() - .map(|a| flatten_expr.get_expr_term(a)) - .collect(); + let args = head.iter_arguments().map(|a| flatten_expr.get_expr_term(a)).collect(); let head = back::Head::atom(pred.clone(), args); // Get the back rules @@ -274,10 +269,7 @@ impl FrontContext { let back_head_atoms = head_atoms .iter() .map(|a| { - let args = a - .iter_arguments() - .map(|a| flatten_expr.get_expr_term(a)) - .collect(); + let args = a.iter_arguments().map(|a| flatten_expr.get_expr_term(a)).collect(); back::Atom::new(a.predicate().clone(), args) }) .collect(); @@ -377,7 +369,11 @@ impl FrontContext { let pred_bindings = ForeignPredicateBindings::from(&self.foreign_predicate_registry); let body_bounded_vars = agg_ctx.body.compute_boundness(&pred_bindings, &vec![]).unwrap(); let group_by_bounded_vars = 
agg_ctx.group_by.as_ref().map_or(BTreeSet::new(), |(ctx, _, _)| { - ctx.compute_boundness(&pred_bindings, &vec![]).unwrap().into_iter().collect() + ctx + .compute_boundness(&pred_bindings, &vec![]) + .unwrap() + .into_iter() + .collect() }); let all_bounded_vars = body_bounded_vars .union(&group_by_bounded_vars) @@ -472,43 +468,30 @@ impl FrontContext { let body_tys = self.type_inference().variable_types(src_rule_loc, body_args.iter()); let body_terms = self.back_terms_with_types(body_args.clone(), body_tys.clone()); - // Get the body to-aggregate relation - let body_attr = back::AggregateBodyAttribute::new(group_by_vars.len(), arg_var_names.len(), to_agg_var_names.len()); - let body_attrs = back::Attributes::singleton(body_attr); - let body_relation = back::Relation::new_with_attrs(body_attrs, body_predicate.clone(), body_tys.clone()); - temp_relations.push(body_relation); - - // Get the rules for body - let body_head = back::Head::atom(body_predicate.clone(), body_terms.clone()); - let body_rules = self.formula_to_back_rules( - flatten_expr, - src_rule_loc, - back::Attributes::new(), - body_predicate.clone(), - &agg_ctx.body, - body_head, - vec![], - temp_relations, - ); - temp_rules.extend(body_rules); - // Get the reduce literal - let body_atom = back::Atom::new(body_predicate.clone(), body_terms); - let left_vars = self.back_vars(src_rule_loc, agg_ctx.left_variable_names().into_iter().collect()); let group_by_vars = self.back_vars(src_rule_loc, group_by_vars.into_iter().collect()); let other_group_by_vars = self.back_vars(src_rule_loc, other_group_by_vars); - let arg_vars = self.back_vars(src_rule_loc, arg_var_names.into_iter().collect()); let to_agg_vars = self.back_vars(src_rule_loc, to_agg_var_names.into_iter().collect()); + let left_vars = self.back_vars(src_rule_loc, agg_ctx.left_variable_names().into_iter().collect()); + let arg_vars = self.back_vars(src_rule_loc, arg_var_names.into_iter().collect()); // Generate the internal aggregate operator let op = 
match &agg_ctx.aggregate_op { front::ReduceOperatorNode::Count => AggregateOp::Count, front::ReduceOperatorNode::Sum => { - assert_eq!(left_vars.len(), 1, "There should be only one var for summation"); + assert_eq!( + left_vars.len(), + 1, + "[Internal Error] There should be only one var for summation" + ); AggregateOp::Sum(left_vars[0].ty.clone()) } front::ReduceOperatorNode::Prod => { - assert_eq!(left_vars.len(), 1, "There should be only one var for production"); + assert_eq!( + left_vars.len(), + 1, + "[Internal Error] There should be only one var for production" + ); AggregateOp::Prod(left_vars[0].ty.clone()) } front::ReduceOperatorNode::Min => AggregateOp::min(!arg_vars.is_empty()), @@ -518,14 +501,36 @@ impl FrontContext { front::ReduceOperatorNode::TopK(k) => AggregateOp::top_k(k.clone()), front::ReduceOperatorNode::CategoricalK(k) => AggregateOp::categorical_k(k.clone()), front::ReduceOperatorNode::Forall => { - panic!("There should be no forall aggregator op. This is a bug"); + panic!("[Internal Error] There should be no forall aggregator op. This is a bug"); } front::ReduceOperatorNode::Unknown(_) => { - panic!("There should be no unknown aggregator op. This is a bug"); + panic!("[Internal Error] There should be no unknown aggregator op. 
This is a bug"); } }; + // Get the body to-aggregate relation + let body_attr = + back::AggregateBodyAttribute::new(op.clone(), group_by_vars.len(), arg_vars.len(), to_agg_vars.len()); + let body_attrs = back::Attributes::singleton(body_attr); + let body_relation = back::Relation::new_with_attrs(body_attrs, body_predicate.clone(), body_tys.clone()); + temp_relations.push(body_relation); + + // Get the rules for body + let body_head = back::Head::atom(body_predicate.clone(), body_terms.clone()); + let body_rules = self.formula_to_back_rules( + flatten_expr, + src_rule_loc, + back::Attributes::new(), + body_predicate.clone(), + &agg_ctx.body, + body_head, + vec![], + temp_relations, + ); + temp_rules.extend(body_rules); + // Get the literal + let body_atom = back::Atom::new(body_predicate.clone(), body_terms); let reduce_literal = back::Reduce::new( op, left_vars, diff --git a/core/src/compiler/front/f2b/flatten_expr.rs b/core/src/compiler/front/f2b/flatten_expr.rs index 11ecb94..b305d95 100644 --- a/core/src/compiler/front/f2b/flatten_expr.rs +++ b/core/src/compiler/front/f2b/flatten_expr.rs @@ -7,7 +7,6 @@ use crate::compiler::front::utils::*; use crate::compiler::front::*; use crate::utils::IdAllocator; -#[derive(Clone, Debug)] pub struct FlattenExprContext<'a> { pub type_inference: &'a TypeInference, pub foreign_predicate_registry: &'a ForeignPredicateRegistry, @@ -41,15 +40,21 @@ pub enum FlattenedNode { function: String, args: Vec, }, + New { + left: back::Variable, + functor: String, + args: Vec, + }, } impl FlattenedNode { pub fn back_var(&self) -> back::Variable { match self { - Self::Binary { left, .. } => left.clone(), - Self::Unary { left, .. } => left.clone(), - Self::IfThenElse { left, .. } => left.clone(), - Self::Call { left, .. } => left.clone(), + Self::Binary { left, .. } + | Self::Unary { left, .. } + | Self::IfThenElse { left, .. } + | Self::Call { left, .. } + | Self::New { left, .. 
} => left.clone(), } } } @@ -58,10 +63,7 @@ impl FlattenedNode { pub type FlattenedLeaf = back::Term; impl<'a> FlattenExprContext<'a> { - pub fn new( - type_inference: &'a TypeInference, - foreign_predicate_registry: &'a ForeignPredicateRegistry, - ) -> Self { + pub fn new(type_inference: &'a TypeInference, foreign_predicate_registry: &'a ForeignPredicateRegistry) -> Self { Self { type_inference, foreign_predicate_registry, @@ -111,6 +113,7 @@ impl<'a> FlattenExprContext<'a> { FlattenedNode::Call { left, function, args } => { self.collect_flattened_literals_of_call_op(left, function, args) } + FlattenedNode::New { left, functor, args } => self.collect_flattened_literals_of_new_op(left, functor, args), } } else { vec![] @@ -226,6 +229,28 @@ impl<'a> FlattenExprContext<'a> { curr_literals } + pub fn collect_flattened_literals_of_new_op( + &self, + left: &back::Variable, + functor: &String, + args: &Vec, + ) -> Vec { + let mut curr_literals = vec![]; + + // The call expression literal + let arg_terms = args.iter().map(|a| self.get_loc_term(a)).collect::>(); + let literal = back::Literal::new_expr(left.clone(), functor.clone(), arg_terms); + curr_literals.push(literal); + + // Collect flattened literals from args + for arg in args { + curr_literals.extend(self.collect_flattened_literals(arg)); + } + + // Return all of them + curr_literals + } + pub fn atom_to_back_literals(&self, atom: &Atom) -> Vec { let mut literals = vec![]; @@ -354,7 +379,8 @@ impl<'a> FlattenExprContext<'a> { Formula::Atom(atom) => self.atom_to_back_literals(atom), Formula::NegAtom(neg_atom) => self.neg_atom_to_back_literals(neg_atom), Formula::Constraint(c) => self.constraint_to_back_literal(c), - Formula::Conjunction(_) + Formula::Case(_) + | Formula::Conjunction(_) | Formula::Disjunction(_) | Formula::Implies(_) | Formula::Reduce(_) @@ -451,6 +477,22 @@ impl<'a> NodeVisitor for FlattenExprContext<'a> { ); } + fn visit_new_expr(&mut self, n: &ast::NewExpr) { + let tmp_var_name = 
self.allocate_tmp_var(); + let functor = format!("adt#{}", n.functor_name()); + self.internal.insert( + n.location().clone(), + FlattenedNode::New { + left: back::Variable { + name: tmp_var_name, + ty: self.type_inference.expr_value_type(n).unwrap(), + }, + functor, + args: n.iter_args().map(|a| a.location().clone()).collect(), + }, + ); + } + fn visit_variable(&mut self, v: &Variable) { let back_var = back::Variable { name: v.name().to_string(), diff --git a/core/src/compiler/front/grammar.lalrpop b/core/src/compiler/front/grammar.lalrpop index 0e17d0e..2bb20b5 100644 --- a/core/src/compiler/front/grammar.lalrpop +++ b/core/src/compiler/front/grammar.lalrpop @@ -41,9 +41,9 @@ match { "^", "==", "!=", + "<", ">", ">=", - "<", "<=", "+", "-", @@ -65,6 +65,9 @@ match { "else", "exists", "forall", + "case", + "is", + "new", // Type "i8", @@ -85,9 +88,11 @@ match { "bool", "&str", "String", - "Rc", + "Symbol", "DateTime", "Duration", + "Entity", + "Tensor", // Boolean keywords "true", @@ -106,9 +111,11 @@ match { // Literals r"[a-zA-Z][a-zA-Z_0-9]*" => name, + r"([a-zA-Z][a-zA-Z_0-9]*)\s*<\s*([a-zA-Z][a-zA-Z_0-9]*)(\s*,\s*([a-zA-Z][a-zA-Z_0-9]*))*\s*>" => specialized_name, r"-?[0-9]+" => int, r"-?[0-9]+(\.[0-9]+)(e-?[0-9]+)?" 
=> float, - r#""[^"]*""# => string, + r#""""((?:[^"])*)"""|"""((?:[\s\S]|\\.)*)"""|"((?:[^"\\]|\\.)*)""# => string, + r#"s"[^"]*""# => symbol_string, r#"t"[^"]*""# => date_time_string, r#"d"[^"]*""# => duration_string, r#"'[^']*'"# => character, @@ -116,16 +123,24 @@ match { // Comments and Whitespaces r"[ \n\r]*" => { }, r"//[^\n\r]*[\n\r]*" => { }, - r"/\*([^\*]*\*+[^\*/])*([^\*]*\*+|[^\*])*\*/" => { }, + r"/\*[^*]*\*+(?:[^/*][^*]*\*+)*/" => { }, } /// ============================== /// /// ========= Attributes ========= /// /// ============================== /// -AttributeArg: Result = { - => Ok(c), - "=" => Err((n, c)), +AttributeValueNode: AttributeValueNode = { + => AttributeValueNode::Constant(c), + "[" > "]" => AttributeValueNode::List(vs), + "(" > ")" => AttributeValueNode::Tuple(ts), +} + +AttributeValue = Spanned; + +AttributeArg: Result = { + => Ok(c), + "=" => Err((n, c)), } AttributeNode: AttributeNode = { @@ -149,9 +164,9 @@ AttributeNode: AttributeNode = { } } -Attribute: Attribute = Spanned; +Attribute = Spanned; -Attributes: Attributes = ; +Attributes = ; /// ==================================== /// /// ========= Type Declaration ========= /// @@ -176,13 +191,15 @@ TypeNode: TypeNode = { "bool" => TypeNode::Bool, "&str" => TypeNode::Str, "String" => TypeNode::String, - // "Rc" => TypeNode::RcString, + "Symbol" => TypeNode::Symbol, "DateTime" => TypeNode::DateTime, "Duration" => TypeNode::Duration, - => TypeNode::Named(n), + "Entity" => TypeNode::Entity, + "Tensor" => TypeNode::Tensor, + => TypeNode::Named(n), } -Type: Type = Spanned; +Type = Spanned; SubtypeDeclNode: SubtypeDeclNode = { "type" "<:" => { @@ -194,7 +211,7 @@ SubtypeDeclNode: SubtypeDeclNode = { } } -SubtypeDecl: SubtypeDecl = Spanned; +SubtypeDecl = Spanned; AliasTypeDeclNode: AliasTypeDeclNode = { "type" "=" => { @@ -206,7 +223,7 @@ AliasTypeDeclNode: AliasTypeDeclNode = { } } -AliasTypeDecl: AliasTypeDecl = Spanned; +AliasTypeDecl = Spanned; ArgTypeBindingNode: ArgTypeBindingNode 
= { ":" => { @@ -270,37 +287,81 @@ EnumTypeMemberNode: EnumTypeMemberNode = { EnumTypeMember = Spanned; +ADTSeparator = "|"; + +AlgebraicDataTypeDeclNode: AlgebraicDataTypeDeclNode = { + "type" "=" ADTSeparator? > => { + AlgebraicDataTypeDeclNode { + attrs, + name, + variants: vs, + } + } +} + +AlgebraicDataTypeDecl = Spanned; + +AlgebraicDataTypeVariantNode: AlgebraicDataTypeVariantNode = { + "(" > ")" => { + AlgebraicDataTypeVariantNode { + constructor, + args, + } + } +} + +AlgebraicDataTypeVariant = Spanned; + TypeDeclNode: TypeDeclNode = { => TypeDeclNode::Subtype(s), => TypeDeclNode::Alias(a), => TypeDeclNode::Relation(r), => TypeDeclNode::Enum(e), + => TypeDeclNode::Algebraic(a), } -TypeDecl: TypeDecl = Spanned; +TypeDecl = Spanned; /// ======================================== /// -/// ========= Relation Declaration ========= /// +/// ========= Constant Declaration ========= /// /// ======================================== /// +EntityNode: EntityNode = { + => EntityNode::Expr(e), + => EntityNode::Object(o), +} + +pub Entity = Spanned; + +ObjectNode: ObjectNode = { + "(" > ")" => { + ObjectNode { + functor, + args, + } + } +} + +Object = Spanned; + ConstAssignmentNode: ConstAssignmentNode = { - ":" "=" => { + ":" "=" => { ConstAssignmentNode { name: n, ty: Some(t), - value: c, + value: e, } }, - "=" => { + "=" => { ConstAssignmentNode { name: n, ty: None, - value: c, + value: e, } } } -ConstAssignment: ConstAssignment = Spanned; +ConstAssignment = Spanned; ConstDeclNode: ConstDeclNode = { "const" > => { @@ -311,7 +372,7 @@ ConstDeclNode: ConstDeclNode = { } } -ConstDecl: ConstDecl = Spanned; +ConstDecl = Spanned; /// ======================================== /// /// ========= Relation Declaration ========= /// @@ -340,23 +401,14 @@ TagNode: TagNode = { Tag: Tag = Spanned; ConstantNode: ConstantNode = { - => ConstantNode::Boolean(b), => ConstantNode::Integer(i), => ConstantNode::Float(f), - => ConstantNode::String(s), - => { - match 
utils::parse_date_time_string(&s) { - Some(v) => ConstantNode::DateTime(v), - None => ConstantNode::Invalid(format!("Cannot parse date time `{}`", s)), - } - }, - => { - match utils::parse_duration_string(&s) { - Some(v) => ConstantNode::Duration(v), - None => ConstantNode::Invalid(format!("Cannot parse duration `{}`", s)), - } - }, - => ConstantNode::Char(c), + => ConstantNode::Boolean(b), + => ConstantNode::Char(c), + => ConstantNode::String(s), + => ConstantNode::Symbol(s), + => ConstantNode::DateTime(s), + => ConstantNode::Duration(s), } Constant: Constant = Spanned; @@ -428,7 +480,12 @@ Variable = Spanned; AtomNode: AtomNode = { "(" > ")" => { - AtomNode { predicate, args } + AtomNode { predicate, type_args: vec![], args } + }, + "(" > ")" => { + let (predicate, type_arg_ids) = n; + let type_args = type_arg_ids.into_iter().map(Type::from).collect(); + AtomNode { predicate, type_args, args } } } @@ -641,6 +698,7 @@ UnitFormula: Formula = { "(" ")" => f, => Formula::Constraint(c), => Formula::Atom(a), + => Formula::Case(c), => Formula::Reduce(r), => Formula::ForallExistsReduce(r), } @@ -656,6 +714,17 @@ ConstraintNode: ConstraintNode = { Constraint = Spanned; +CaseNode: CaseNode = { + "case" "is" => { + CaseNode { + variable: v, + entity: e, + } + } +} + +Case = Spanned; + VariableBindingNode: VariableBindingNode = { => VariableBindingNode { name, ty: None }, "(" ":" ")" => VariableBindingNode { name, ty: Some(ty) }, @@ -830,7 +899,18 @@ CallExprNode: CallExprNode = { } } -CallExpr: CallExpr = Spanned; +CallExpr = Spanned; + +NewExprNode: NewExprNode = { + "new" "(" > ")" => { + NewExprNode { + functor, + args, + } + } +} + +NewExpr = Spanned; UnitExpr: Expr = { "(" ")" => e, @@ -838,18 +918,25 @@ UnitExpr: Expr = { => Expr::Constant(c), => Expr::Variable(v), => Expr::Call(c), + => Expr::New(n), } RuleHeadNode: RuleHeadNode = { => { RuleHeadNode::Atom(a) }, + > => { + RuleHeadNode::Conjunction(atoms) + }, + "{" > "}" => { + RuleHeadNode::Conjunction(atoms) + }, 
"{" > "}" => { RuleHeadNode::Disjunction(atoms) }, } -RuleHead: RuleHead = Spanned; +RuleHead = Spanned; RuleNode: RuleNode = { DefineSymbol => { @@ -868,7 +955,7 @@ RuleDeclNode: RuleDeclNode = { }, } -RuleDecl: RuleDecl = Spanned; +RuleDecl = Spanned; /// ====================================== /// /// ========= Import Declaration ========= /// @@ -882,7 +969,7 @@ ImportFileNode: ImportFileNode = { } } -ImportFile: ImportFile = Spanned; +ImportFile = Spanned; ImportDeclNode: ImportDeclNode = { "import" => { @@ -893,7 +980,7 @@ ImportDeclNode: ImportDeclNode = { } } -ImportDecl: ImportDecl = Spanned; +ImportDecl = Spanned; /// ===================================== /// /// ========= Query Declaration ========= /// @@ -908,7 +995,7 @@ QueryNode: QueryNode = { } } -pub Query: Query = Spanned; +pub Query = Spanned; QueryKeyword = { "query", @@ -923,7 +1010,7 @@ QueryDeclNode: QueryDeclNode = { }, } -QueryDecl: QueryDecl = Spanned; +QueryDecl = Spanned; /// ==================================== /// /// ========= Item Declaration ========= /// @@ -945,35 +1032,135 @@ pub Items: Items = ; Name: String = name => (<>).to_string(); +pub SpecializedPredicate: (Identifier, Vec) = { + => { + // First get the name part + let angle_id = s.find("<").unwrap(); // Split the string using `<` + let name = s[0..angle_id].trim_end(); // The first part is the name + let name_id = Identifier::from_span(start, start + name.len(), IdentifierNode { name: name.to_string() }); // Generate the identifier for name + + // Then get the args part + let all_args_str = &s[angle_id + 1..]; + let arg_local_start_positions = std::iter::once(0) + .chain(all_args_str.match_indices(",").map(|(i, _)| i + 1)) + .chain(std::iter::once(all_args_str.len())) + .collect::>(); + let num_args = arg_local_start_positions.len() - 1; + let arg_ids = (0..num_args) + .map(|i| { + let (curr_begin, curr_end) = (arg_local_start_positions[i], arg_local_start_positions[i + 1] - 1); + let curr_total = curr_end - curr_begin; + 
let local_arg_str = &all_args_str[curr_begin..curr_end]; + let local_arg_start = curr_begin + (curr_total - local_arg_str.trim_start().len()); + let local_arg_end = curr_begin + local_arg_str.trim_end().len(); + let global_arg_start = start + (angle_id + 1) + local_arg_start; + let global_arg_end = start + (angle_id + 1) + local_arg_end; + Identifier::from_span(global_arg_start, global_arg_end, IdentifierNode { name: local_arg_str.trim().to_string() }) + }) + .collect(); + + // Return + (name_id, arg_ids) + } +} + IdentifierNode: IdentifierNode = => IdentifierNode { name }; -Identifier: Identifier = Spanned; +Identifier = Spanned; Int: i64 = int => i64::from_str(<>).unwrap(); Float: f64 = float => f64::from_str(<>).unwrap(); -StringLiteral: String = => { - s[1..s.len() - 1].replace("\\t", "\t").replace("\\n", "\n").replace("\\\\", "\\").into() -}; +StringLiteral: String = { + => { + s[1..s.len() - 1].replace("\"\"", "") + .replace("\\t", "\t") + .replace("\\r", "\r") + .replace("\\n", "\n") + .replace("\\0", "\0") + .replace("\\\"", "\"") + .replace("\\\'", "\'") + .replace("\\\\", "\\") + .into() + } +} -DateTimeLiteral: String = => { - s[2..s.len() - 1].replace("\\t", "\t").replace("\\n", "\n").replace("\\\\", "\\").into() -}; +SymbolLiteral: String = { + => { + s[2..s.len() - 1].into() + } +} -DurationLiteral: String = => { - s[2..s.len() - 1].replace("\\t", "\t").replace("\\n", "\n").replace("\\\\", "\\").into() -}; +DateTimeLiteral: String = { + => { + s[2..s.len() - 1].into() + } +} -CharLiteral: String = => { - s[1..s.len() - 1].replace("\\t", "\t").replace("\\n", "\n").replace("\\'", "'").replace("\\\\", "\\").into() -}; +DurationLiteral: String = { + => { + s[2..s.len() - 1].into() + } +} + +CharLiteral: String = { + => { + s[1..s.len() - 1].replace("\\t", "\t").replace("\\n", "\n").replace("\\'", "'").replace("\\\\", "\\").into() + } +} Bool: bool = { "true" => true, "false" => false, } +ConstantCharNode: ConstantCharNode = { + => { + ConstantCharNode { 
character: c } + } +} + +ConstantChar = Spanned; + +ConstantStringNode: ConstantStringNode = { + => { + ConstantStringNode { string: s } + } +} + +ConstantString = Spanned; + +ConstantSymbolNode: ConstantSymbolNode = { + => { + ConstantSymbolNode { symbol: s } + } +} + +ConstantSymbol = Spanned; + +ConstantDateTimeNode: Result = { + => { + match utils::parse_date_time_string(&s) { + Some(v) => Ok(ConstantDateTimeNode { datetime: v }), + None => Err(format!("Cannot parse date time `{}`", s)), + } + } +} + +ConstantDateTime = Spanned; + +ConstantDurationNode: Result = { + => { + match utils::parse_duration_string(&s) { + Some(v) => Ok(ConstantDurationNode { duration: v }), + None => Err(format!("Cannot parse duration `{}`", s)), + } + } +} + +ConstantDuration = Spanned; + /// =========================== /// /// ========= Helpers ========= /// /// =========================== /// diff --git a/core/src/compiler/front/mod.rs b/core/src/compiler/front/mod.rs index 70fc596..2f47cff 100644 --- a/core/src/compiler/front/mod.rs +++ b/core/src/compiler/front/mod.rs @@ -1,7 +1,8 @@ mod analysis; pub mod analyzers; mod annotation; -mod ast; +pub mod ast; +pub mod attribute; mod compile; mod error; mod f2b; @@ -17,7 +18,7 @@ mod visitor_mut; pub use analysis::*; pub use annotation::*; -pub use ast::*; +use ast::*; pub use compile::*; pub use error::*; pub use f2b::*; diff --git a/core/src/compiler/front/parser.rs b/core/src/compiler/front/parser.rs index 07b2818..5096764 100644 --- a/core/src/compiler/front/parser.rs +++ b/core/src/compiler/front/parser.rs @@ -1,10 +1,12 @@ use super::*; +/// Parse a string into a single `Item` pub fn str_to_item(s: &str) -> Result { let parser = grammar::ItemParser::new(); parser.parse(s).map_err(ParserError::from) } +/// Parse a string into multiple `Items` pub fn str_to_items(s: &str) -> Result, ParserError> { let parser = grammar::ItemsParser::new(); parser.parse(s).map_err(ParserError::from) @@ -25,6 +27,16 @@ pub fn str_to_query(s: &str) -> 
Result, ParserError> { parser.parse(s).map(|q| q.into()).map_err(ParserError::from) } +pub fn str_to_entity(s: &str) -> Result { + let parser = grammar::EntityParser::new(); + parser.parse(s).map(|q| q.into()).map_err(ParserError::from) +} + +pub fn str_to_specialized_predicate(s: &str) -> Result<(Identifier, Vec), ParserError> { + let parser = grammar::SpecializedPredicateParser::new(); + parser.parse(s).map(|q| q.into()).map_err(ParserError::from) +} + pub type RawParserError<'a> = lalrpop_util::ParseError, &'static str>; #[derive(Debug, Clone, PartialEq, Eq)] diff --git a/core/src/compiler/front/pretty.rs b/core/src/compiler/front/pretty.rs index 9429671..6ef97e8 100644 --- a/core/src/compiler/front/pretty.rs +++ b/core/src/compiler/front/pretty.rs @@ -30,6 +30,7 @@ impl Display for TypeDecl { TypeDeclNode::Alias(s) => s.fmt(f), TypeDeclNode::Relation(s) => s.fmt(f), TypeDeclNode::Enum(e) => e.fmt(f), + TypeDeclNode::Algebraic(a) => a.fmt(f), } } } @@ -60,6 +61,29 @@ impl Display for ConstAssignment { } } +impl Display for Entity { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + match &self.node { + EntityNode::Expr(e) => e.fmt(f), + EntityNode::Object(o) => o.fmt(f), + } + } +} + +impl Display for Object { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + f.write_str(self.functor_name())?; + f.write_str("(")?; + for (i, arg) in self.iter_args().enumerate() { + if i > 0 { + f.write_str(", ")?; + } + arg.fmt(f)?; + } + f.write_str(")") + } +} + impl Display for RelationDecl { fn fmt(&self, f: &mut Formatter<'_>) -> Result { match &self.node { @@ -88,6 +112,34 @@ impl Display for Query { } } +impl Display for AttributeValue { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + match &self.node { + AttributeValueNode::Constant(c) => c.fmt(f), + AttributeValueNode::List(l) => { + f.write_str("[")?; + for (i, v) in l.iter().enumerate() { + if i > 0 { + f.write_str(", ")?; + } + v.fmt(f)?; + } + f.write_str("]") + } + AttributeValueNode::Tuple(l) => { + 
f.write_str("(")?; + for (i, v) in l.iter().enumerate() { + if i > 0 { + f.write_str(", ")?; + } + v.fmt(f)?; + } + f.write_str(")") + } + } + } +} + impl Display for Attribute { fn fmt(&self, f: &mut Formatter<'_>) -> Result { f.write_fmt(format_args!("@{}", self.name()))?; @@ -140,7 +192,7 @@ impl Display for Type { impl Display for SubtypeDecl { fn fmt(&self, f: &mut Formatter<'_>) -> Result { for attr in self.attributes() { - attr.fmt(f)?; + f.write_fmt(format_args!("{} ", attr))?; } f.write_fmt(format_args!("type {} <: {}", self.name(), self.subtype_of())) } @@ -149,7 +201,7 @@ impl Display for SubtypeDecl { impl Display for AliasTypeDecl { fn fmt(&self, f: &mut Formatter<'_>) -> Result { for attr in self.attributes() { - attr.fmt(f)?; + f.write_fmt(format_args!("{} ", attr))?; } f.write_fmt(format_args!("type {} = {}", self.name(), self.alias_of())) } @@ -158,7 +210,7 @@ impl Display for AliasTypeDecl { impl Display for RelationTypeDecl { fn fmt(&self, f: &mut Formatter<'_>) -> Result { for attr in self.attributes() { - attr.fmt(f)?; + f.write_fmt(format_args!("{} ", attr))?; } f.write_str("type ")?; for (i, relation_type) in self.relation_types().enumerate() { @@ -188,7 +240,7 @@ impl Display for RelationType { impl Display for EnumTypeDecl { fn fmt(&self, f: &mut Formatter<'_>) -> Result { for attr in self.attributes() { - attr.fmt(f)?; + f.write_fmt(format_args!("{} ", attr))?; } f.write_fmt(format_args!("type {} = ", self.name()))?; for (i, member) in self.iter_members().enumerate() { @@ -211,6 +263,35 @@ impl Display for EnumTypeMember { } } +impl Display for AlgebraicDataTypeDecl { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + for attr in self.attributes() { + f.write_fmt(format_args!("{} ", attr))?; + } + let name = self.name(); + f.write_fmt(format_args!("type {name} = "))?; + for (i, member) in self.iter_variants().enumerate() { + if i > 0 { + f.write_str(" | ")?; + } + member.fmt(f)?; + } + Ok(()) + } +} + +impl Display for 
AlgebraicDataTypeVariant { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + let name = self.name(); + let args = self + .iter_arg_types() + .map(|t| format!("{t}")) + .collect::>() + .join(", "); + f.write_fmt(format_args!("{name}({args})")) + } +} + impl Display for Identifier { fn fmt(&self, f: &mut Formatter<'_>) -> Result { f.write_str(self.name()) @@ -220,7 +301,7 @@ impl Display for Identifier { impl Display for ConstantSetDecl { fn fmt(&self, f: &mut Formatter<'_>) -> Result { for attr in self.attributes() { - attr.fmt(f)?; + f.write_fmt(format_args!("{} ", attr))?; } f.write_fmt(format_args!( "rel {} = {{{}}}", @@ -237,7 +318,7 @@ impl Display for ConstantSetDecl { impl Display for FactDecl { fn fmt(&self, f: &mut Formatter<'_>) -> Result { for attr in self.attributes() { - attr.fmt(f)?; + f.write_fmt(format_args!("{} ", attr))?; } f.write_fmt(format_args!( "rel {}({})", @@ -302,14 +383,51 @@ impl Display for ConstantOrVariable { impl std::fmt::Display for Constant { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match &self.node { - ConstantNode::Integer(i) => f.write_fmt(format_args!("{}", i)), - ConstantNode::Float(n) => f.write_fmt(format_args!("{}", n)), - ConstantNode::Char(c) => f.write_fmt(format_args!("'{}'", c)), - ConstantNode::Boolean(b) => f.write_fmt(format_args!("{}", b)), - ConstantNode::String(s) => f.write_fmt(format_args!("\"{}\"", s)), - ConstantNode::DateTime(s) => f.write_fmt(format_args!("\"{}\"", s)), - ConstantNode::Duration(s) => f.write_fmt(format_args!("\"{}\"", s)), - ConstantNode::Invalid(_) => f.write_str("invalid"), + ConstantNode::Integer(i) => f.write_fmt(format_args!("{i}")), + ConstantNode::Entity(e) => f.write_fmt(format_args!("{e:#x}")), + ConstantNode::Float(n) => f.write_fmt(format_args!("{n}")), + ConstantNode::Char(c) => c.fmt(f), + ConstantNode::Boolean(b) => f.write_fmt(format_args!("{b}")), + ConstantNode::String(s) => s.fmt(f), + ConstantNode::Symbol(s) => s.fmt(f), + 
ConstantNode::DateTime(d) => d.fmt(f), + ConstantNode::Duration(d) => d.fmt(f), + } + } +} + +impl std::fmt::Display for ConstantChar { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_fmt(format_args!("'{}'", self.character_string())) + } +} + +impl std::fmt::Display for ConstantString { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_fmt(format_args!("\"{}\"", self.string())) + } +} + +impl std::fmt::Display for ConstantSymbol { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_fmt(format_args!("s\"{}\"", self.symbol())) + } +} + +impl std::fmt::Display for ConstantDateTime { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match &self.node { + Ok(n) => f.write_fmt(format_args!("t\"{}\"", n.datetime)), + Err(msg) => f.write_fmt(format_args!("[Invalid DateTime: \"{}\"]", msg)), + } + } +} + +impl std::fmt::Display for ConstantDuration { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match &self.node { + Ok(n) => f.write_fmt(format_args!("d\"{}\"", n.duration)), + Err(msg) => f.write_fmt(format_args!("[Invalid Duration: \"{}\"]", msg)), } } } @@ -324,16 +442,25 @@ impl Display for RuleHead { fn fmt(&self, f: &mut Formatter<'_>) -> Result { match &self.node { RuleHeadNode::Atom(a) => a.fmt(f), + RuleHeadNode::Conjunction(c) => { + for (i, a) in c.iter().enumerate() { + if i > 0 { + f.write_str(", ")?; + } + a.fmt(f)?; + } + Ok(()) + } RuleHeadNode::Disjunction(d) => { f.write_str("{")?; for (i, a) in d.iter().enumerate() { if i > 0 { - f.write_str(", ")?; + f.write_str("; ")?; } a.fmt(f)?; } f.write_str("}") - }, + } } } } @@ -343,6 +470,7 @@ impl Display for Formula { match self { Self::Atom(a) => a.fmt(f), Self::NegAtom(a) => a.fmt(f), + Self::Case(c) => c.fmt(f), Self::Disjunction(a) => a.fmt(f), Self::Conjunction(a) => a.fmt(f), Self::Implies(i) => i.fmt(f), @@ -359,6 +487,12 @@ impl Display for NegAtom { } } +impl 
Display for Case { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + f.write_fmt(format_args!("case {} is {}", self.variable_name(), self.entity())) + } +} + impl Display for Disjunction { fn fmt(&self, f: &mut Formatter<'_>) -> Result { f.write_fmt(format_args!( @@ -493,6 +627,11 @@ impl std::fmt::Display for Expr { c.function_identifier(), c.iter_args().map(|a| format!("{}", a)).collect::>().join(", ") )), + Self::New(n) => f.write_fmt(format_args!( + "new {}({})", + n.functor_identifier(), + n.iter_args().map(|a| format!("{}", a)).collect::>().join(", ") + )), } } } diff --git a/core/src/compiler/front/source/cycle_import.rs b/core/src/compiler/front/source/cycle_import.rs index 8be7e13..d33503c 100644 --- a/core/src/compiler/front/source/cycle_import.rs +++ b/core/src/compiler/front/source/cycle_import.rs @@ -10,7 +10,7 @@ pub struct CycleImportError { impl CycleImportError { pub fn report(&self, _: &Sources) { - println!("File `{}` is already imported", self.path.to_str().unwrap()); + eprintln!("File `{}` is already imported", self.path.to_str().unwrap()); } } diff --git a/core/src/compiler/front/source/rust_macro.rs b/core/src/compiler/front/source/rust_macro.rs index ee46fd2..33bdee0 100644 --- a/core/src/compiler/front/source/rust_macro.rs +++ b/core/src/compiler/front/source/rust_macro.rs @@ -103,7 +103,7 @@ fn token_stream_to_src_lines(tokens: TokenStream) -> (usize, Vec) { (first_line_num, src_lines) } -fn populate_str_tokens(str_tokens: &mut Vec<(String, proc_macro2::LineColumn, i32)>, tokens: TokenStream) { +fn populate_str_tokens(str_tokens: &mut Vec<(String, LineColumn, i32)>, tokens: TokenStream) { let mut tokens_iter = tokens.into_iter(); loop { if let Some(token) = tokens_iter.next() { @@ -116,15 +116,15 @@ fn populate_str_tokens(str_tokens: &mut Vec<(String, proc_macro2::LineColumn, i3 proc_macro2::Delimiter::Bracket => ("[".to_string(), -1, "]".to_string()), proc_macro2::Delimiter::None => ("".to_string(), 0, "".to_string()), }; - 
str_tokens.push((open, span.start(), 0)); + str_tokens.push((open, span_start_line_column(&span), 0)); populate_str_tokens(str_tokens, g.stream()); - str_tokens.push((close, span.end(), offset)); + str_tokens.push((close, span_end_line_column(&span), offset)); } proc_macro2::TokenTree::Ident(i) => { let span = i.span(); - str_tokens.push((format!("{}", i), span.start(), 0)); + str_tokens.push((format!("{}", i), span_start_line_column(&span), 0)); } proc_macro2::TokenTree::Punct(p) => { let mut curr_p = p.clone(); @@ -140,11 +140,11 @@ fn populate_str_tokens(str_tokens: &mut Vec<(String, proc_macro2::LineColumn, i3 panic!("Should not happen"); } } - str_tokens.push((op, span.start(), 0)); + str_tokens.push((op, span_start_line_column(&span), 0)); } proc_macro2::TokenTree::Literal(l) => { let span = l.span(); - str_tokens.push((format!("{}", l), span.start(), 0)); + str_tokens.push((format!("{}", l), span_start_line_column(&span), 0)); } } } else { @@ -152,3 +152,25 @@ fn populate_str_tokens(str_tokens: &mut Vec<(String, proc_macro2::LineColumn, i3 } } } + +#[derive(Debug, Clone, Copy)] +struct LineColumn { + line: usize, + column: usize, +} + +fn span_start_line_column(span: &proc_macro2::Span) -> LineColumn { + let s = span.unwrap().start(); + LineColumn { + line: s.line, + column: s.column, + } +} + +fn span_end_line_column(span: &proc_macro2::Span) -> LineColumn { + let s = span.unwrap().end(); + LineColumn { + line: s.line, + column: s.column, + } +} diff --git a/core/src/compiler/front/transform.rs b/core/src/compiler/front/transform.rs index 08f5f7f..b9ba8bc 100644 --- a/core/src/compiler/front/transform.rs +++ b/core/src/compiler/front/transform.rs @@ -1,29 +1,41 @@ use super::transformations::*; use super::*; -pub fn apply_transformations(ast: &mut Vec, analysis: &Analysis) { +pub fn apply_transformations(ast: &mut Vec, analysis: &mut Analysis) { + let mut transform_adt = TransformAlgebraicDataType::new(&mut analysis.adt_analysis); let mut 
transform_const_var_to_const = TransformConstVarToConst::new(&analysis.constant_decl_analysis); let mut transform_atomic_query = TransformAtomicQuery::new(); + let mut transform_conjunctive_head = TransformConjunctiveHead::new(); let mut transform_tagged_rule = TransformTaggedRule::new(); let mut transform_non_const_fact = TransformNonConstantFactToRule; + let mut desugar_case_is = DesugarCaseIs::new(); let mut desugar_forall_exists = DesugarForallExists::new(); let mut forall_to_not_exists = TransformForall; let mut implies_to_disjunction = TransformImplies; let mut visitors = ( + &mut transform_adt, &mut transform_atomic_query, + &mut transform_conjunctive_head, &mut transform_const_var_to_const, &mut transform_tagged_rule, &mut transform_non_const_fact, + &mut desugar_case_is, &mut desugar_forall_exists, &mut forall_to_not_exists, // Note: forall needs to go before implies transformation &mut implies_to_disjunction, ); visitors.walk_items(ast); + // Post-transformation; remove items + ast.retain(|item| transform_conjunctive_head.retain(item) && transform_adt.retain(item)); + // Post-transformation; annotate node ids afterwards let mut new_items = vec![]; - new_items.extend(transform_atomic_query.generate_items()); - new_items.extend(transform_tagged_rule.generate_items()); + new_items.extend(transform_adt.generate_items()); + new_items.extend(transform_const_var_to_const.generate_items()); + new_items.extend(transform_atomic_query.drain_items()); + new_items.extend(transform_conjunctive_head.generate_items()); + new_items.extend(transform_tagged_rule.drain_items()); // Some of the transformations need to be applied to new items as well transform_const_var_to_const.walk_items(&mut new_items); @@ -31,7 +43,3 @@ pub fn apply_transformations(ast: &mut Vec, analysis: &Analysis) { // Extend the ast to incorporate these new items ast.extend(new_items) } - -pub trait Transformation { - fn generate_items(self) -> Vec; -} diff --git 
a/core/src/compiler/front/transformations/adt_to_relation.rs b/core/src/compiler/front/transformations/adt_to_relation.rs new file mode 100644 index 0000000..283c490 --- /dev/null +++ b/core/src/compiler/front/transformations/adt_to_relation.rs @@ -0,0 +1,114 @@ +use crate::compiler::front::analyzers::*; +use crate::compiler::front::visitor_mut::*; +use crate::compiler::front::*; + +#[derive(Debug)] +pub struct TransformAlgebraicDataType<'a> { + analysis: &'a mut AlgebraicDataTypeAnalysis, +} + +impl<'a> NodeVisitorMut for TransformAlgebraicDataType<'a> { + fn visit_type_decl(&mut self, type_decl: &mut TypeDecl) { + match &type_decl.node { + TypeDeclNode::Algebraic(adt_decl) => { + let location = adt_decl.location().clone(); + let name_identifier = adt_decl.name_identifier().clone(); + let new_decl = TypeDecl::alias(name_identifier.clone(), Type::entity()).with_location(location); + *type_decl = new_decl; + } + _ => {} + } + } +} + +impl<'a> TransformAlgebraicDataType<'a> { + pub fn new(analysis: &'a mut AlgebraicDataTypeAnalysis) -> Self { + Self { analysis } + } + + pub fn generate_items(self) -> Vec { + let result = self + .analysis + .adt_variants + .iter() + .map(|(variant_name, variant_info)| { + let rel_name = variant_info + .name + .clone_without_location_id() + .map(|n| format!("adt#{n}")); + + // Generate the args including the first ID type + let first_arg: Type = TypeNode::Named(variant_info.belongs_to_type.name().to_string()).into(); + let arg_types: Vec = std::iter::once(first_arg) + .chain(variant_info.args.iter().cloned()) + .map(|arg| ArgTypeBindingNode { name: None, ty: arg }.into()) + .collect(); + + // Generate an attribute `@adt("VARIANT_NAME", [IS_ARG_0_ENTITY, ...])` + let is_entity: AttributeValue = variant_info + .args + .iter() + .map(|arg| { + let arg_is_entity = if let Some(name) = arg.get_name() { + self.analysis.adt_types.contains(name) + } else { + false + }; + let constant = Constant::boolean(arg_is_entity); + let attr_arg = 
AttributeValue::constant(constant); + attr_arg + }) + .collect(); + let adt_attr: Attribute = AttributeNode { + name: Identifier::default_with_name("adt".to_string()), + pos_args: vec![ + AttributeValue::constant(Constant::string(variant_name.clone())), + is_entity, + ], + kw_args: vec![], + } + .into(); + + // Generate another attribute `@hidden` + let hidden_attr: Attribute = AttributeNode { + name: Identifier::default_with_name("hidden".to_string()), + pos_args: vec![], + kw_args: vec![], + } + .into(); + + // Generate a type declaration item + Item::TypeDecl( + TypeDeclNode::Relation( + RelationTypeDeclNode { + attrs: vec![adt_attr, hidden_attr], + rel_types: vec![RelationTypeNode { + name: rel_name, + arg_types, + } + .into()], + } + .into(), + ) + .into(), + ) + }) + .collect(); + + // Clear the variants + self.analysis.adt_variants.clear(); + + // Return the results + result + } + + pub fn retain(&self, item: &Item) -> bool { + match item { + Item::TypeDecl(td) => match td.node { + TypeDeclNode::Algebraic(_) => false, + _ => true, + }, + _ => true, + } + } +} diff --git a/core/src/compiler/front/transformations/atomic_query.rs b/core/src/compiler/front/transformations/atomic_query.rs index b8feee2..661d558 100644 --- a/core/src/compiler/front/transformations/atomic_query.rs +++ b/core/src/compiler/front/transformations/atomic_query.rs @@ -9,6 +9,23 @@ impl TransformAtomicQuery { pub fn new() -> Self { Self { to_add_rules: vec![] } } + + pub fn drain_items(self) -> Vec { + self + .to_add_rules + .iter() + .map(|rule| { + let rule_decl = RuleDeclNode { + attrs: vec![], + tag: Tag::default_none(), + rule: rule.clone(), + }; + let rel_decl = RelationDeclNode::Rule(rule_decl.into()); + let item = Item::RelationDecl(rel_decl.into()); + item + }) + .collect() + } } impl NodeVisitorMut for TransformAtomicQuery { @@ -32,10 +49,12 @@ impl NodeVisitorMut for TransformAtomicQuery { .collect::>(); let head_atom = AtomNode { predicate: 
IdentifierNode::new(query_name.clone()).into(), + type_args: Vec::new(), args: args.clone(), }; let body_atom = AtomNode { predicate: IdentifierNode::new(a.predicate().clone()).into(), + type_args: Vec::new(), args: args.clone(), }; let eq_constraints = a @@ -67,22 +86,3 @@ impl NodeVisitorMut for TransformAtomicQuery { } } } - -impl Transformation for TransformAtomicQuery { - fn generate_items(self) -> Vec { - self - .to_add_rules - .iter() - .map(|rule| { - let rule_decl = RuleDeclNode { - attrs: vec![], - tag: Tag::default_none(), - rule: rule.clone(), - }; - let rel_decl = RelationDeclNode::Rule(rule_decl.into()); - let item = Item::RelationDecl(rel_decl.into()); - item - }) - .collect() - } -} diff --git a/core/src/compiler/front/transformations/conjunctive_head.rs b/core/src/compiler/front/transformations/conjunctive_head.rs new file mode 100644 index 0000000..7812ad7 --- /dev/null +++ b/core/src/compiler/front/transformations/conjunctive_head.rs @@ -0,0 +1,58 @@ +use crate::compiler::front::*; + +#[derive(Clone, Debug)] +pub struct TransformConjunctiveHead { + to_add_items: Vec, +} + +impl TransformConjunctiveHead { + pub fn new() -> Self { + Self { to_add_items: vec![] } + } + + pub fn retain(&self, item: &Item) -> bool { + match item { + Item::RelationDecl(r) => { + if let Some(rule) = r.rule() { + !rule.head().is_conjunction() + } else { + true + } + } + _ => true, + } + } + + pub fn generate_items(self) -> Vec { + self.to_add_items + } +} + +impl NodeVisitorMut for TransformConjunctiveHead { + fn visit_rule(&mut self, rule: &mut Rule) { + match &rule.head().node { + RuleHeadNode::Conjunction(c) => { + for atom in c { + self.to_add_items.push(Item::RelationDecl( + RelationDeclNode::Rule( + RuleDeclNode { + attrs: Attributes::new(), + tag: Tag::default_none(), + rule: Rule::new( + rule.location().clone_without_id(), + RuleNode { + head: RuleHead::new(rule.location().clone(), RuleHeadNode::Atom(atom.clone())), + body: rule.body().clone(), + }, + ), + } + 
.into(), + ) + .into(), + )); + } + } + _ => {} + } + } +} diff --git a/core/src/compiler/front/transformations/const_var_to_const.rs b/core/src/compiler/front/transformations/const_var_to_const.rs index 2922e11..959bf65 100644 --- a/core/src/compiler/front/transformations/const_var_to_const.rs +++ b/core/src/compiler/front/transformations/const_var_to_const.rs @@ -1,4 +1,4 @@ -use crate::compiler::front::analyzers::ConstantDeclAnalysis; +use crate::{common::input_tag::DynamicInputTag, compiler::front::analyzers::ConstantDeclAnalysis}; use super::super::*; @@ -11,6 +11,45 @@ impl<'a> TransformConstVarToConst<'a> { pub fn new(const_decl_analysis: &'a ConstantDeclAnalysis) -> Self { Self { const_decl_analysis } } + + pub fn generate_items(&self) -> Vec { + self + .const_decl_analysis + .entity_facts + .iter() + .map(|entity_fact| { + Item::RelationDecl( + RelationDeclNode::Fact( + FactDeclNode { + attrs: Attributes::new(), + tag: TagNode(DynamicInputTag::None).into(), + atom: Atom { + loc: entity_fact.loc.clone(), + node: AtomNode { + predicate: { + entity_fact + .functor + .clone_without_location_id() + .map(|n| format!("adt#{n}")) + }, + type_args: vec![], + args: { + std::iter::once(&entity_fact.id) + .chain(entity_fact.args.iter()) + .cloned() + .map(Expr::Constant) + .collect() + }, + }, + }, + } + .into(), + ) + .into(), + ) + }) + .collect() + } } impl<'a> NodeVisitorMut for TransformConstVarToConst<'a> { diff --git a/core/src/compiler/front/transformations/desugar_case_is.rs b/core/src/compiler/front/transformations/desugar_case_is.rs new file mode 100644 index 0000000..b773b15 --- /dev/null +++ b/core/src/compiler/front/transformations/desugar_case_is.rs @@ -0,0 +1,95 @@ +use crate::compiler::front::*; +use crate::utils::IdAllocator; + +#[derive(Clone, Debug)] +pub struct DesugarCaseIs; + +impl DesugarCaseIs { + pub fn new() -> Self { + Self + } + + pub fn transform_case_is_to_formula(&self, case: &Case) -> Formula { + match &case.entity().node { + 
EntityNode::Expr(e) => { + // If the entity is directly an expression, the formula is a constraint + Formula::Constraint( + Constraint::default_with_expr(Expr::binary( + BinaryOp::default_eq(), + Expr::Variable(case.variable().clone()), + e.clone(), + )) + .with_span(&case.loc), + ) + } + EntityNode::Object(o) => { + // If the entity is an object, the formula is a conjunction of atoms + let parent_id = case.id(); + let variable = case.variable().clone(); + let mut variable_counter = IdAllocator::new(); + let mut formulas = vec![]; + + // Recurse through the entity to create formulas + self.transform_object_to_formula_helper(variable, o, parent_id, &mut variable_counter, &mut formulas); + + // Return the conjunction of formulas + Formula::conjunction(formulas) + } + } + } + + fn transform_object_to_formula_helper( + &self, + variable: Variable, + object: &Object, + parent_id: usize, + variable_counter: &mut IdAllocator, + formulas: &mut Vec, + ) { + // Obtain the predicate of the atom that we are going to generate + let predicate = object.functor().clone_without_location_id().map(|n| format!("adt#{n}")); + + // Obtain the second-to-last arguments in the atom + let sub_args = object.iter_args().map(|arg| { + match &arg.node { + EntityNode::Expr(e) => e.clone(), + EntityNode::Object(o) => { + // Obtain a variable id + let variable_id = variable_counter.alloc(); + + // Create a variable from the variable id + let current_variable = Variable::default_with_name(format!("adt#var#{parent_id}#{variable_id}")); + + // Recurse on the object + self.transform_object_to_formula_helper(current_variable.clone(), o, parent_id, variable_counter, formulas); + + // Return the variable as the result + Expr::Variable(current_variable) + } + } + }); + + // Create all arguments including the variable + let args = std::iter::once(Expr::Variable(variable)).chain(sub_args).collect(); + + // Add a formula to the formulas + let formula = Formula::Atom( + AtomNode { + predicate, + type_args: 
vec![], + args, + } + .into(), + ); + formulas.push(formula); + } +} + +impl NodeVisitorMut for DesugarCaseIs { + fn visit_formula(&mut self, formula: &mut Formula) { + match formula { + Formula::Case(c) => *formula = self.transform_case_is_to_formula(c), + _ => {} + } + } +} diff --git a/core/src/compiler/front/transformations/mod.rs b/core/src/compiler/front/transformations/mod.rs index a6859da..58498c9 100644 --- a/core/src/compiler/front/transformations/mod.rs +++ b/core/src/compiler/front/transformations/mod.rs @@ -1,13 +1,19 @@ +mod adt_to_relation; mod atomic_query; +mod conjunctive_head; mod const_var_to_const; +mod desugar_case_is; mod desugar_forall_exists; mod forall_to_not_exists; mod implies_to_disjunction; mod non_constant_fact_to_rule; mod tagged_rule; +pub use adt_to_relation::*; pub use atomic_query::*; +pub use conjunctive_head::*; pub use const_var_to_const::*; +pub use desugar_case_is::*; pub use desugar_forall_exists::*; pub use forall_to_not_exists::*; pub use implies_to_disjunction::*; diff --git a/core/src/compiler/front/transformations/non_constant_fact_to_rule.rs b/core/src/compiler/front/transformations/non_constant_fact_to_rule.rs index ffc71ea..0bb578b 100644 --- a/core/src/compiler/front/transformations/non_constant_fact_to_rule.rs +++ b/core/src/compiler/front/transformations/non_constant_fact_to_rule.rs @@ -30,6 +30,7 @@ impl NodeVisitorMut for TransformNonConstantFactToRule { // all the non-constant arguments will be replaced by a variable let head_atom: Atom = AtomNode { predicate: head.node.predicate.clone(), + type_args: vec![], args: head .iter_arguments() .enumerate() diff --git a/core/src/compiler/front/transformations/tagged_rule.rs b/core/src/compiler/front/transformations/tagged_rule.rs index da6b37c..6b415b5 100644 --- a/core/src/compiler/front/transformations/tagged_rule.rs +++ b/core/src/compiler/front/transformations/tagged_rule.rs @@ -26,6 +26,7 @@ impl TransformTaggedRule { // 2. 
Append the atom to the end let new_atom = AtomNode { predicate: IdentifierNode { name: pred.clone() }.into(), + type_args: vec![], args: vec![], }; let new_atom_form = Formula::Atom(new_atom.into()); @@ -38,34 +39,15 @@ impl TransformTaggedRule { // Return the predicate pred } -} - -impl NodeVisitorMut for TransformTaggedRule { - fn visit_rule_decl(&mut self, rule_decl: &mut RuleDecl) { - // If rule is directly declared with probability - if rule_decl.tag().is_some() { - // Transform the rule - let pred = Self::transform(rule_decl); - - // Store this probability for later - self - .to_add_tags - .push((pred.clone(), rule_decl.tag().input_tag().clone())); - } else if Self::has_prob_attr(rule_decl) { - // If the rule is annotated with `@probabilistic` - Self::transform(rule_decl); - } - } -} -impl Transformation for TransformTaggedRule { - fn generate_items(self) -> Vec { + pub fn drain_items(self) -> Vec { self .to_add_tags .into_iter() .map(|(pred, tag)| { let fact = AtomNode { predicate: IdentifierNode { name: pred.clone() }.into(), + type_args: vec![], args: vec![], }; let fact_decl = FactDeclNode { @@ -80,3 +62,21 @@ impl Transformation for TransformTaggedRule { .collect() } } + +impl NodeVisitorMut for TransformTaggedRule { + fn visit_rule_decl(&mut self, rule_decl: &mut RuleDecl) { + // If rule is directly declared with probability + if rule_decl.tag().is_some() { + // Transform the rule + let pred = Self::transform(rule_decl); + + // Store this probability for later + self + .to_add_tags + .push((pred.clone(), rule_decl.tag().input_tag().clone())); + } else if Self::has_prob_attr(rule_decl) { + // If the rule is annotated with `@probabilistic` + Self::transform(rule_decl); + } + } +} diff --git a/core/src/compiler/front/visitor.rs b/core/src/compiler/front/visitor.rs index dceeec3..9eaca7b 100644 --- a/core/src/compiler/front/visitor.rs +++ b/core/src/compiler/front/visitor.rs @@ -15,14 +15,19 @@ pub trait NodeVisitor { 
node_visitor_func_def!(visit_import_file, ImportFile); node_visitor_func_def!(visit_arg_type_binding, ArgTypeBinding); node_visitor_func_def!(visit_type, Type); + node_visitor_func_def!(visit_type_decl, TypeDecl); node_visitor_func_def!(visit_alias_type_decl, AliasTypeDecl); node_visitor_func_def!(visit_subtype_decl, SubtypeDecl); node_visitor_func_def!(visit_relation_type_decl, RelationTypeDecl); node_visitor_func_def!(visit_relation_type, RelationType); node_visitor_func_def!(visit_enum_type_decl, EnumTypeDecl); node_visitor_func_def!(visit_enum_type_member, EnumTypeMember); + node_visitor_func_def!(visit_algebraic_data_type_decl, AlgebraicDataTypeDecl); + node_visitor_func_def!(visit_algebraic_data_type_variant, AlgebraicDataTypeVariant); node_visitor_func_def!(visit_const_decl, ConstDecl); node_visitor_func_def!(visit_const_assignment, ConstAssignment); + node_visitor_func_def!(visit_entity, Entity); + node_visitor_func_def!(visit_object, Object); node_visitor_func_def!(visit_relation_decl, RelationDecl); node_visitor_func_def!(visit_constant_set_decl, ConstantSetDecl); node_visitor_func_def!(visit_constant_set, ConstantSet); @@ -38,12 +43,12 @@ pub trait NodeVisitor { node_visitor_func_def!(visit_rule_head, RuleHead); node_visitor_func_def!(visit_atom, Atom); node_visitor_func_def!(visit_neg_atom, NegAtom); - node_visitor_func_def!(visit_attribute, Attribute); node_visitor_func_def!(visit_formula, Formula); node_visitor_func_def!(visit_conjunction, Conjunction); node_visitor_func_def!(visit_disjunction, Disjunction); node_visitor_func_def!(visit_implies, Implies); node_visitor_func_def!(visit_constraint, Constraint); + node_visitor_func_def!(visit_case, Case); node_visitor_func_def!(visit_reduce, Reduce); node_visitor_func_def!(visit_forall_exists_reduce, ForallExistsReduce); node_visitor_func_def!(visit_variable_binding, VariableBinding); @@ -52,11 +57,19 @@ pub trait NodeVisitor { node_visitor_func_def!(visit_unary_expr, UnaryExpr); 
node_visitor_func_def!(visit_if_then_else_expr, IfThenElseExpr); node_visitor_func_def!(visit_call_expr, CallExpr); + node_visitor_func_def!(visit_new_expr, NewExpr); node_visitor_func_def!(visit_constant, Constant); + node_visitor_func_def!(visit_constant_char, ConstantChar); + node_visitor_func_def!(visit_constant_string, ConstantString); + node_visitor_func_def!(visit_constant_symbol, ConstantSymbol); + node_visitor_func_def!(visit_constant_duration, ConstantDuration); + node_visitor_func_def!(visit_constant_datetime, ConstantDateTime); node_visitor_func_def!(visit_variable, Variable); node_visitor_func_def!(visit_wildcard, Wildcard); node_visitor_func_def!(visit_identifier, Identifier); node_visitor_func_def!(visit_function_identifier, FunctionIdentifier); + node_visitor_func_def!(visit_attribute, Attribute); + node_visitor_func_def!(visit_attribute_value, AttributeValue); fn walk_items(&mut self, items: &Vec) { for item in items { @@ -101,12 +114,14 @@ pub trait NodeVisitor { } fn walk_type_decl(&mut self, type_decl: &TypeDecl) { + self.visit_type_decl(type_decl); self.visit_location(&type_decl.loc); match &type_decl.node { TypeDeclNode::Alias(a) => self.walk_alias_type_decl(a), TypeDeclNode::Subtype(s) => self.walk_subtype_decl(s), TypeDeclNode::Relation(r) => self.walk_relation_type_decl(r), TypeDeclNode::Enum(e) => self.walk_enum_type_decl(e), + TypeDeclNode::Algebraic(a) => self.walk_algebraic_data_type_decl(a), } } @@ -162,6 +177,24 @@ pub trait NodeVisitor { } } + fn walk_algebraic_data_type_decl(&mut self, algebraic_data_type_decl: &AlgebraicDataTypeDecl) { + self.visit_algebraic_data_type_decl(algebraic_data_type_decl); + self.visit_location(&algebraic_data_type_decl.loc); + self.walk_identifier(&algebraic_data_type_decl.node.name); + for variant in algebraic_data_type_decl.iter_variants() { + self.walk_algebraic_data_type_variant(variant); + } + } + + fn walk_algebraic_data_type_variant(&mut self, variant: &AlgebraicDataTypeVariant) { + 
self.visit_algebraic_data_type_variant(variant); + self.visit_location(&variant.loc); + self.walk_identifier(&variant.node.constructor); + for arg_type in variant.iter_arg_types() { + self.walk_type(arg_type) + } + } + fn walk_const_decl(&mut self, const_decl: &ConstDecl) { self.visit_const_decl(const_decl); self.visit_location(const_decl.location()); @@ -178,7 +211,25 @@ pub trait NodeVisitor { if let Some(ty) = const_assign.ty() { self.walk_type(ty); } - self.walk_constant(const_assign.value()) + self.walk_entity(const_assign.value()) + } + + fn walk_entity(&mut self, entity: &Entity) { + self.visit_entity(entity); + self.visit_location(entity.location()); + match &entity.node { + EntityNode::Expr(e) => self.walk_expr(e), + EntityNode::Object(o) => self.walk_object(o), + } + } + + fn walk_object(&mut self, object: &Object) { + self.visit_object(object); + self.visit_location(object.location()); + self.walk_identifier(object.functor()); + for arg in object.iter_args() { + self.walk_entity(arg); + } } fn walk_relation_decl(&mut self, relation_decl: &RelationDecl) { @@ -273,11 +324,8 @@ pub trait NodeVisitor { self.visit_location(rule_head.location()); match &rule_head.node { RuleHeadNode::Atom(a) => self.walk_atom(a), - RuleHeadNode::Disjunction(d) => { - for atom in d { - self.walk_atom(atom); - } - }, + RuleHeadNode::Conjunction(c) => c.iter().for_each(|a| self.walk_atom(a)), + RuleHeadNode::Disjunction(d) => d.iter().for_each(|a| self.walk_atom(a)), } } @@ -290,6 +338,7 @@ pub trait NodeVisitor { Formula::Constraint(c) => self.walk_constraint(c), Formula::Atom(a) => self.walk_atom(a), Formula::NegAtom(n) => self.walk_neg_atom(n), + Formula::Case(c) => self.walk_case(c), Formula::Reduce(r) => self.walk_reduce(r), Formula::ForallExistsReduce(r) => self.walk_forall_exists_reduce(r), } @@ -324,6 +373,13 @@ pub trait NodeVisitor { self.walk_expr(&cons.node.expr); } + fn walk_case(&mut self, case: &Case) { + self.visit_case(case); + self.visit_location(&case.loc); + 
self.walk_variable(&case.node.variable); + self.walk_entity(&case.node.entity); + } + fn walk_reduce_op(&mut self, reduce_op: &ReduceOperator) { self.visit_location(&reduce_op.loc); } @@ -368,9 +424,12 @@ pub trait NodeVisitor { fn walk_atom(&mut self, atom: &Atom) { self.visit_atom(atom); - self.visit_location(&atom.loc); - self.walk_identifier(&atom.node.predicate); - for arg in &atom.node.args { + self.visit_location(atom.location()); + self.walk_identifier(atom.predicate_identifier()); + for type_arg in atom.iter_type_arguments() { + self.walk_type(type_arg) + } + for arg in atom.iter_arguments() { self.walk_expr(arg); } } @@ -409,6 +468,7 @@ pub trait NodeVisitor { Expr::Unary(u) => self.walk_unary_expr(u), Expr::IfThenElse(i) => self.walk_if_then_else_expr(i), Expr::Call(c) => self.walk_call_expr(c), + Expr::New(n) => self.walk_new_expr(n), } } @@ -452,6 +512,15 @@ pub trait NodeVisitor { } } + fn walk_new_expr(&mut self, n: &NewExpr) { + self.visit_new_expr(n); + self.visit_location(&n.loc); + self.walk_identifier(n.functor_identifier()); + for arg in n.iter_args() { + self.walk_expr(arg); + } + } + fn walk_variable(&mut self, variable: &Variable) { self.visit_variable(variable); self.visit_location(&variable.loc); @@ -466,6 +535,39 @@ pub trait NodeVisitor { fn walk_constant(&mut self, constant: &Constant) { self.visit_constant(constant); self.visit_location(&constant.loc); + match &constant.node { + ConstantNode::Char(c) => self.walk_constant_char(c), + ConstantNode::String(s) => self.walk_constant_string(s), + ConstantNode::Symbol(s) => self.walk_constant_symbol(s), + ConstantNode::DateTime(d) => self.walk_constant_datetime(d), + ConstantNode::Duration(d) => self.walk_constant_duration(d), + _ => {} + } + } + + fn walk_constant_char(&mut self, constant_char: &ConstantChar) { + self.visit_constant_char(constant_char); + self.visit_location(constant_char.location()); + } + + fn walk_constant_string(&mut self, constant_string: &ConstantString) { + 
self.visit_constant_string(constant_string); + self.visit_location(constant_string.location()); + } + + fn walk_constant_symbol(&mut self, constant_symbol: &ConstantSymbol) { + self.visit_constant_symbol(constant_symbol); + self.visit_location(constant_symbol.location()); + } + + fn walk_constant_datetime(&mut self, constant_datetime: &ConstantDateTime) { + self.visit_constant_datetime(constant_datetime); + self.visit_location(constant_datetime.location()); + } + + fn walk_constant_duration(&mut self, constant_duration: &ConstantDuration) { + self.visit_constant_duration(constant_duration); + self.visit_location(constant_duration.location()); } fn walk_tag(&mut self, tag: &Tag) { @@ -488,12 +590,30 @@ pub trait NodeVisitor { self.visit_attribute(attr); self.visit_location(&attr.loc); self.walk_identifier(&attr.node.name); - for c in &attr.node.pos_args { - self.walk_constant(c); + for v in &attr.node.pos_args { + self.walk_attribute_value(v); } - for (n, c) in &attr.node.kw_args { + for (n, v) in &attr.node.kw_args { self.walk_identifier(n); - self.walk_constant(c); + self.walk_attribute_value(v); + } + } + } + + fn walk_attribute_value(&mut self, attribute_value: &AttributeValue) { + self.visit_attribute_value(attribute_value); + self.visit_location(&attribute_value.loc); + match &attribute_value.node { + AttributeValueNode::Constant(c) => self.walk_constant(c), + AttributeValueNode::List(l) => { + for v in l { + self.walk_attribute_value(v); + } + } + AttributeValueNode::Tuple(t) => { + for v in t { + self.walk_attribute_value(v); + } } } } @@ -522,14 +642,19 @@ macro_rules! 
impl_node_visitor_tuple { node_visitor_visit_node!(visit_import_file, ImportFile, ($($id),*)); node_visitor_visit_node!(visit_arg_type_binding, ArgTypeBinding, ($($id),*)); node_visitor_visit_node!(visit_type, Type, ($($id),*)); + node_visitor_visit_node!(visit_type_decl, TypeDecl, ($($id),*)); node_visitor_visit_node!(visit_alias_type_decl, AliasTypeDecl, ($($id),*)); node_visitor_visit_node!(visit_subtype_decl, SubtypeDecl, ($($id),*)); node_visitor_visit_node!(visit_relation_type_decl, RelationTypeDecl, ($($id),*)); node_visitor_visit_node!(visit_relation_type, RelationType, ($($id),*)); node_visitor_visit_node!(visit_enum_type_decl, EnumTypeDecl, ($($id),*)); node_visitor_visit_node!(visit_enum_type_member, EnumTypeMember, ($($id),*)); + node_visitor_visit_node!(visit_algebraic_data_type_decl, AlgebraicDataTypeDecl, ($($id),*)); + node_visitor_visit_node!(visit_algebraic_data_type_variant, AlgebraicDataTypeVariant, ($($id),*)); node_visitor_visit_node!(visit_const_decl, ConstDecl, ($($id),*)); node_visitor_visit_node!(visit_const_assignment, ConstAssignment, ($($id),*)); + node_visitor_visit_node!(visit_entity, Entity, ($($id),*)); + node_visitor_visit_node!(visit_object, Object, ($($id),*)); node_visitor_visit_node!(visit_relation_decl, RelationDecl, ($($id),*)); node_visitor_visit_node!(visit_constant_set_decl, ConstantSetDecl, ($($id),*)); node_visitor_visit_node!(visit_constant_set, ConstantSet, ($($id),*)); @@ -545,12 +670,12 @@ macro_rules! 
impl_node_visitor_tuple { node_visitor_visit_node!(visit_rule_head, RuleHead, ($($id),*)); node_visitor_visit_node!(visit_atom, Atom, ($($id),*)); node_visitor_visit_node!(visit_neg_atom, NegAtom, ($($id),*)); - node_visitor_visit_node!(visit_attribute, Attribute, ($($id),*)); node_visitor_visit_node!(visit_formula, Formula, ($($id),*)); node_visitor_visit_node!(visit_conjunction, Conjunction, ($($id),*)); node_visitor_visit_node!(visit_disjunction, Disjunction, ($($id),*)); node_visitor_visit_node!(visit_implies, Implies, ($($id),*)); node_visitor_visit_node!(visit_constraint, Constraint, ($($id),*)); + node_visitor_visit_node!(visit_case, Case, ($($id),*)); node_visitor_visit_node!(visit_reduce, Reduce, ($($id),*)); node_visitor_visit_node!(visit_forall_exists_reduce, ForallExistsReduce, ($($id),*)); node_visitor_visit_node!(visit_variable_binding, VariableBinding, ($($id),*)); @@ -559,11 +684,19 @@ macro_rules! impl_node_visitor_tuple { node_visitor_visit_node!(visit_unary_expr, UnaryExpr, ($($id),*)); node_visitor_visit_node!(visit_if_then_else_expr, IfThenElseExpr, ($($id),*)); node_visitor_visit_node!(visit_call_expr, CallExpr, ($($id),*)); + node_visitor_visit_node!(visit_new_expr, NewExpr, ($($id),*)); node_visitor_visit_node!(visit_constant, Constant, ($($id),*)); + node_visitor_visit_node!(visit_constant_char, ConstantChar, ($($id),*)); + node_visitor_visit_node!(visit_constant_string, ConstantString, ($($id),*)); + node_visitor_visit_node!(visit_constant_symbol, ConstantSymbol, ($($id),*)); + node_visitor_visit_node!(visit_constant_datetime, ConstantDateTime, ($($id),*)); + node_visitor_visit_node!(visit_constant_duration, ConstantDuration, ($($id),*)); node_visitor_visit_node!(visit_variable, Variable, ($($id),*)); node_visitor_visit_node!(visit_wildcard, Wildcard, ($($id),*)); node_visitor_visit_node!(visit_identifier, Identifier, ($($id),*)); node_visitor_visit_node!(visit_function_identifier, FunctionIdentifier, ($($id),*)); + 
node_visitor_visit_node!(visit_attribute, Attribute, ($($id),*)); + node_visitor_visit_node!(visit_attribute_value, AttributeValue, ($($id),*)); } } } diff --git a/core/src/compiler/front/visitor_mut.rs b/core/src/compiler/front/visitor_mut.rs index b3498fd..eb23e50 100644 --- a/core/src/compiler/front/visitor_mut.rs +++ b/core/src/compiler/front/visitor_mut.rs @@ -15,14 +15,19 @@ pub trait NodeVisitorMut { node_visitor_mut_func_def!(visit_import_file, ImportFile); node_visitor_mut_func_def!(visit_arg_type_binding, ArgTypeBinding); node_visitor_mut_func_def!(visit_type, Type); + node_visitor_mut_func_def!(visit_type_decl, TypeDecl); node_visitor_mut_func_def!(visit_alias_type_decl, AliasTypeDecl); node_visitor_mut_func_def!(visit_subtype_decl, SubtypeDecl); node_visitor_mut_func_def!(visit_relation_type_decl, RelationTypeDecl); node_visitor_mut_func_def!(visit_relation_type, RelationType); node_visitor_mut_func_def!(visit_enum_type_decl, EnumTypeDecl); node_visitor_mut_func_def!(visit_enum_type_member, EnumTypeMember); + node_visitor_mut_func_def!(visit_algebraic_data_type_decl, AlgebraicDataTypeDecl); + node_visitor_mut_func_def!(visit_algebraic_data_type_variant, AlgebraicDataTypeVariant); node_visitor_mut_func_def!(visit_const_decl, ConstDecl); node_visitor_mut_func_def!(visit_const_assignment, ConstAssignment); + node_visitor_mut_func_def!(visit_entity, Entity); + node_visitor_mut_func_def!(visit_object, Object); node_visitor_mut_func_def!(visit_relation_decl, RelationDecl); node_visitor_mut_func_def!(visit_constant_set_decl, ConstantSetDecl); node_visitor_mut_func_def!(visit_constant_set, ConstantSet); @@ -38,12 +43,12 @@ pub trait NodeVisitorMut { node_visitor_mut_func_def!(visit_rule_head, RuleHead); node_visitor_mut_func_def!(visit_atom, Atom); node_visitor_mut_func_def!(visit_neg_atom, NegAtom); - node_visitor_mut_func_def!(visit_attribute, Attribute); node_visitor_mut_func_def!(visit_formula, Formula); node_visitor_mut_func_def!(visit_conjunction, 
Conjunction); node_visitor_mut_func_def!(visit_disjunction, Disjunction); node_visitor_mut_func_def!(visit_implies, Implies); node_visitor_mut_func_def!(visit_constraint, Constraint); + node_visitor_mut_func_def!(visit_case, Case); node_visitor_mut_func_def!(visit_reduce, Reduce); node_visitor_mut_func_def!(visit_forall_exists_reduce, ForallExistsReduce); node_visitor_mut_func_def!(visit_variable_binding, VariableBinding); @@ -52,11 +57,19 @@ pub trait NodeVisitorMut { node_visitor_mut_func_def!(visit_unary_expr, UnaryExpr); node_visitor_mut_func_def!(visit_if_then_else_expr, IfThenElseExpr); node_visitor_mut_func_def!(visit_call_expr, CallExpr); + node_visitor_mut_func_def!(visit_new_expr, NewExpr); node_visitor_mut_func_def!(visit_constant, Constant); + node_visitor_mut_func_def!(visit_constant_char, ConstantChar); + node_visitor_mut_func_def!(visit_constant_string, ConstantString); + node_visitor_mut_func_def!(visit_constant_symbol, ConstantSymbol); + node_visitor_mut_func_def!(visit_constant_duration, ConstantDuration); + node_visitor_mut_func_def!(visit_constant_datetime, ConstantDateTime); node_visitor_mut_func_def!(visit_variable, Variable); node_visitor_mut_func_def!(visit_wildcard, Wildcard); node_visitor_mut_func_def!(visit_identifier, Identifier); node_visitor_mut_func_def!(visit_function_identifier, FunctionIdentifier); + node_visitor_mut_func_def!(visit_attribute, Attribute); + node_visitor_mut_func_def!(visit_attribute_value, AttributeValue); fn walk_items(&mut self, items: &mut Vec) { for item in items { @@ -101,12 +114,14 @@ pub trait NodeVisitorMut { } fn walk_type_decl(&mut self, type_decl: &mut TypeDecl) { + self.visit_type_decl(type_decl); self.visit_location(&mut type_decl.loc); match &mut type_decl.node { TypeDeclNode::Alias(a) => self.walk_alias_type_decl(a), TypeDeclNode::Subtype(s) => self.walk_subtype_decl(s), TypeDeclNode::Relation(r) => self.walk_relation_type_decl(r), TypeDeclNode::Enum(e) => self.walk_enum_type_decl(e), + 
TypeDeclNode::Algebraic(a) => self.walk_algebraic_data_type_decl(a), } } @@ -162,6 +177,24 @@ pub trait NodeVisitorMut { } } + fn walk_algebraic_data_type_decl(&mut self, algebraic_data_type_decl: &mut AlgebraicDataTypeDecl) { + self.visit_algebraic_data_type_decl(algebraic_data_type_decl); + self.visit_location(&mut algebraic_data_type_decl.loc); + self.walk_identifier(&mut algebraic_data_type_decl.node.name); + for variant in algebraic_data_type_decl.iter_variants_mut() { + self.walk_algebraic_data_type_variant(variant); + } + } + + fn walk_algebraic_data_type_variant(&mut self, variant: &mut AlgebraicDataTypeVariant) { + self.visit_algebraic_data_type_variant(variant); + self.visit_location(&mut variant.loc); + self.walk_identifier(&mut variant.node.constructor); + for arg_type in variant.iter_arg_types_mut() { + self.walk_type(arg_type) + } + } + fn walk_const_decl(&mut self, const_decl: &mut ConstDecl) { self.visit_const_decl(const_decl); self.visit_location(const_decl.location_mut()); @@ -178,7 +211,25 @@ pub trait NodeVisitorMut { if let Some(ty) = const_assign.ty_mut() { self.walk_type(ty); } - self.walk_constant(const_assign.value_mut()) + self.walk_entity(const_assign.value_mut()) + } + + fn walk_entity(&mut self, entity: &mut Entity) { + self.visit_entity(entity); + self.visit_location(entity.location_mut()); + match &mut entity.node { + EntityNode::Expr(e) => self.walk_expr(e), + EntityNode::Object(o) => self.walk_object(o), + } + } + + fn walk_object(&mut self, object: &mut Object) { + self.visit_object(object); + self.visit_location(object.location_mut()); + self.walk_identifier(object.functor_mut()); + for arg in object.iter_args_mut() { + self.walk_entity(arg); + } } fn walk_relation_decl(&mut self, relation_decl: &mut RelationDecl) { @@ -273,11 +324,8 @@ pub trait NodeVisitorMut { self.visit_location(rule_head.location_mut()); match &mut rule_head.node { RuleHeadNode::Atom(a) => self.walk_atom(a), - RuleHeadNode::Disjunction(d) => { - for atom in d 
{ - self.walk_atom(atom); - } - }, + RuleHeadNode::Conjunction(c) => c.iter_mut().for_each(|a| self.walk_atom(a)), + RuleHeadNode::Disjunction(d) => d.iter_mut().for_each(|a| self.walk_atom(a)), } } @@ -290,6 +338,7 @@ pub trait NodeVisitorMut { Formula::Constraint(c) => self.walk_constraint(c), Formula::Atom(a) => self.walk_atom(a), Formula::NegAtom(n) => self.walk_neg_atom(n), + Formula::Case(c) => self.walk_case(c), Formula::Reduce(r) => self.walk_reduce(r), Formula::ForallExistsReduce(r) => self.walk_forall_exists_reduce(r), } @@ -324,6 +373,13 @@ pub trait NodeVisitorMut { self.walk_expr(&mut cons.node.expr); } + fn walk_case(&mut self, case: &mut Case) { + self.visit_case(case); + self.visit_location(&mut case.loc); + self.walk_variable(&mut case.node.variable); + self.walk_entity(&mut case.node.entity); + } + fn walk_reduce_op(&mut self, reduce_op: &mut ReduceOperator) { self.visit_location(&mut reduce_op.loc); } @@ -368,9 +424,12 @@ pub trait NodeVisitorMut { fn walk_atom(&mut self, atom: &mut Atom) { self.visit_atom(atom); - self.visit_location(&mut atom.loc); - self.walk_identifier(&mut atom.node.predicate); - for arg in &mut atom.node.args { + self.visit_location(atom.location_mut()); + self.walk_identifier(atom.predicate_identifier_mut()); + for type_arg in atom.iter_type_arguments_mut() { + self.walk_type(type_arg); + } + for arg in atom.iter_arguments_mut() { self.walk_expr(arg); } } @@ -409,6 +468,7 @@ pub trait NodeVisitorMut { Expr::Unary(u) => self.walk_unary_expr(u), Expr::IfThenElse(i) => self.walk_if_then_else_expr(i), Expr::Call(c) => self.walk_call_expr(c), + Expr::New(n) => self.walk_new_expr(n), } } @@ -452,6 +512,15 @@ pub trait NodeVisitorMut { } } + fn walk_new_expr(&mut self, n: &mut NewExpr) { + self.visit_new_expr(n); + self.visit_location(&mut n.loc); + self.walk_identifier(n.functor_identifier_mut()); + for arg in n.iter_args_mut() { + self.walk_expr(arg); + } + } + fn walk_variable(&mut self, variable: &mut Variable) { 
self.visit_variable(variable); self.visit_location(&mut variable.loc); @@ -466,6 +535,39 @@ pub trait NodeVisitorMut { fn walk_constant(&mut self, constant: &mut Constant) { self.visit_constant(constant); self.visit_location(&mut constant.loc); + match &mut constant.node { + ConstantNode::Char(c) => self.walk_constant_char(c), + ConstantNode::String(s) => self.walk_constant_string(s), + ConstantNode::Symbol(s) => self.walk_constant_symbol(s), + ConstantNode::DateTime(d) => self.walk_constant_datetime(d), + ConstantNode::Duration(d) => self.walk_constant_duration(d), + _ => {} + } + } + + fn walk_constant_char(&mut self, constant_char: &mut ConstantChar) { + self.visit_constant_char(constant_char); + self.visit_location(constant_char.location_mut()); + } + + fn walk_constant_string(&mut self, constant_string: &mut ConstantString) { + self.visit_constant_string(constant_string); + self.visit_location(constant_string.location_mut()); + } + + fn walk_constant_symbol(&mut self, constant_symbol: &mut ConstantSymbol) { + self.visit_constant_symbol(constant_symbol); + self.visit_location(constant_symbol.location_mut()); + } + + fn walk_constant_datetime(&mut self, constant_datetime: &mut ConstantDateTime) { + self.visit_constant_datetime(constant_datetime); + self.visit_location(constant_datetime.location_mut()); + } + + fn walk_constant_duration(&mut self, constant_duration: &mut ConstantDuration) { + self.visit_constant_duration(constant_duration); + self.visit_location(constant_duration.location_mut()); } fn walk_tag(&mut self, tag: &mut Tag) { @@ -489,11 +591,29 @@ pub trait NodeVisitorMut { self.visit_location(&mut attr.loc); self.walk_identifier(&mut attr.node.name); for c in &mut attr.node.pos_args { - self.walk_constant(c); + self.walk_attribute_value(c); } for (n, c) in &mut attr.node.kw_args { self.walk_identifier(n); - self.walk_constant(c); + self.walk_attribute_value(c); + } + } + } + + fn walk_attribute_value(&mut self, attribute_value: &mut AttributeValue) { 
+ self.visit_attribute_value(attribute_value); + self.visit_location(&mut attribute_value.loc); + match &mut attribute_value.node { + AttributeValueNode::Constant(c) => self.walk_constant(c), + AttributeValueNode::List(l) => { + for v in l { + self.walk_attribute_value(v); + } + } + AttributeValueNode::Tuple(t) => { + for v in t { + self.walk_attribute_value(v); + } } } } @@ -520,14 +640,19 @@ macro_rules! impl_node_visitor_mut_tuple { node_visitor_mut_visit_node!(visit_item, Item, ($($id),*)); node_visitor_mut_visit_node!(visit_arg_type_binding, ArgTypeBinding, ($($id),*)); node_visitor_mut_visit_node!(visit_type, Type, ($($id),*)); + node_visitor_mut_visit_node!(visit_type_decl, TypeDecl, ($($id),*)); node_visitor_mut_visit_node!(visit_alias_type_decl, AliasTypeDecl, ($($id),*)); node_visitor_mut_visit_node!(visit_subtype_decl, SubtypeDecl, ($($id),*)); node_visitor_mut_visit_node!(visit_relation_type_decl, RelationTypeDecl, ($($id),*)); node_visitor_mut_visit_node!(visit_relation_type, RelationType, ($($id),*)); node_visitor_mut_visit_node!(visit_enum_type_decl, EnumTypeDecl, ($($id),*)); node_visitor_mut_visit_node!(visit_enum_type_member, EnumTypeMember, ($($id),*)); + node_visitor_mut_visit_node!(visit_algebraic_data_type_decl, AlgebraicDataTypeDecl, ($($id),*)); + node_visitor_mut_visit_node!(visit_algebraic_data_type_variant, AlgebraicDataTypeVariant, ($($id),*)); node_visitor_mut_visit_node!(visit_const_decl, ConstDecl, ($($id),*)); node_visitor_mut_visit_node!(visit_const_assignment, ConstAssignment, ($($id),*)); + node_visitor_mut_visit_node!(visit_entity, Entity, ($($id),*)); + node_visitor_mut_visit_node!(visit_object, Object, ($($id),*)); node_visitor_mut_visit_node!(visit_relation_decl, RelationDecl, ($($id),*)); node_visitor_mut_visit_node!(visit_constant_set_decl, ConstantSetDecl, ($($id),*)); node_visitor_mut_visit_node!(visit_constant_set, ConstantSet, ($($id),*)); @@ -543,12 +668,12 @@ macro_rules! 
impl_node_visitor_mut_tuple { node_visitor_mut_visit_node!(visit_rule_head, RuleHead, ($($id),*)); node_visitor_mut_visit_node!(visit_atom, Atom, ($($id),*)); node_visitor_mut_visit_node!(visit_neg_atom, NegAtom, ($($id),*)); - node_visitor_mut_visit_node!(visit_attribute, Attribute, ($($id),*)); node_visitor_mut_visit_node!(visit_formula, Formula, ($($id),*)); node_visitor_mut_visit_node!(visit_conjunction, Conjunction, ($($id),*)); node_visitor_mut_visit_node!(visit_disjunction, Disjunction, ($($id),*)); node_visitor_mut_visit_node!(visit_implies, Implies, ($($id),*)); node_visitor_mut_visit_node!(visit_constraint, Constraint, ($($id),*)); + node_visitor_mut_visit_node!(visit_case, Case, ($($id),*)); node_visitor_mut_visit_node!(visit_reduce, Reduce, ($($id),*)); node_visitor_mut_visit_node!(visit_forall_exists_reduce, ForallExistsReduce, ($($id),*)); node_visitor_mut_visit_node!(visit_variable_binding, VariableBinding, ($($id),*)); @@ -556,11 +681,20 @@ macro_rules! impl_node_visitor_mut_tuple { node_visitor_mut_visit_node!(visit_binary_expr, BinaryExpr, ($($id),*)); node_visitor_mut_visit_node!(visit_unary_expr, UnaryExpr, ($($id),*)); node_visitor_mut_visit_node!(visit_if_then_else_expr, IfThenElseExpr, ($($id),*)); + node_visitor_mut_visit_node!(visit_call_expr, CallExpr, ($($id),*)); + node_visitor_mut_visit_node!(visit_new_expr, NewExpr, ($($id),*)); node_visitor_mut_visit_node!(visit_constant, Constant, ($($id),*)); + node_visitor_mut_visit_node!(visit_constant_char, ConstantChar, ($($id),*)); + node_visitor_mut_visit_node!(visit_constant_string, ConstantString, ($($id),*)); + node_visitor_mut_visit_node!(visit_constant_symbol, ConstantSymbol, ($($id),*)); + node_visitor_mut_visit_node!(visit_constant_datetime, ConstantDateTime, ($($id),*)); + node_visitor_mut_visit_node!(visit_constant_duration, ConstantDuration, ($($id),*)); node_visitor_mut_visit_node!(visit_variable, Variable, ($($id),*)); node_visitor_mut_visit_node!(visit_wildcard, Wildcard, 
($($id),*)); node_visitor_mut_visit_node!(visit_identifier, Identifier, ($($id),*)); node_visitor_mut_visit_node!(visit_function_identifier, FunctionIdentifier, ($($id),*)); + node_visitor_mut_visit_node!(visit_attribute, Attribute, ($($id),*)); + node_visitor_mut_visit_node!(visit_attribute_value, AttributeValue, ($($id),*)); } } } diff --git a/core/src/compiler/ram/ast.rs b/core/src/compiler/ram/ast.rs index 3eec9b9..6b63705 100644 --- a/core/src/compiler/ram/ast.rs +++ b/core/src/compiler/ram/ast.rs @@ -29,6 +29,7 @@ impl Program { } } + /// Get a relation by its name; returns `None` if the relation does not exist. pub fn relation(&self, name: &str) -> Option<&Relation> { self .relation_to_stratum @@ -36,14 +37,17 @@ impl Program { .and_then(|stratum_id| self.strata[*stratum_id].relations.get(name)) } + /// Get a relation by its name; panics if the relation does not exist. pub fn relation_unchecked(&self, name: &str) -> &Relation { &self.strata[self.relation_to_stratum[name]].relations[name] } + /// Iterate over all relations in the program. pub fn relations(&self) -> impl Iterator { self.strata.iter().flat_map(|s| s.relations.values()) } + /// Iterate over all relation (name, type) pairs in the program pub fn relation_types<'a>(&'a self) -> impl 'a + Iterator { self .strata @@ -51,6 +55,7 @@ impl Program { .flat_map(|s| s.relations.iter().map(|(p, r)| (p.clone(), r.tuple_type.clone()))) } + /// Get the tuple type of a relation by its name; returns `None` if the relation does not exist. pub fn relation_tuple_type(&self, predicate: &str) -> Option { if let Some(stratum_id) = self.relation_to_stratum.get(predicate) { Some(self.strata[*stratum_id].relations[predicate].tuple_type.clone()) @@ -59,6 +64,23 @@ impl Program { } } + /// Get the input file option of a relation by its name, returns `None` if the relation does not exist. 
+ pub fn input_file(&self, predicate: &str) -> Option<&InputFile> { + if let Some(stratum_id) = self.relation_to_stratum.get(predicate) { + self.strata[*stratum_id].relations[predicate].input_file.as_ref() + } else { + None + } + } + + /// Iterate through all the input files in the program. + pub fn input_files<'a>(&'a self) -> impl 'a + Iterator { + self + .relations() + .filter_map(move |relation| relation.input_file.as_ref()) + } + + /// Set the target relation(s) to be output relations. pub fn set_output_relations(&mut self, target: Vec<&str>) { self.strata.iter_mut().for_each(|stratum| { stratum.relations.iter_mut().for_each(|(_, relation)| { @@ -71,11 +93,12 @@ impl Program { }) } - pub fn output_option(&self, relation: &str) -> Option { + /// Get the output option of a relation by its name; returns `None` if the relation does not exist. + pub fn output_option(&self, relation: &str) -> Option<&OutputOption> { self .relation_to_stratum .get(relation) - .map(|stratum_id| self.strata[*stratum_id].relations[relation].output.clone()) + .map(|stratum_id| &self.strata[*stratum_id].relations[relation].output) } } @@ -301,8 +324,7 @@ impl Dataflow { | Self::Exclusion(d, _) => d.source_relations(), Self::Reduce(r) => std::iter::once(r.source_relation()).collect(), Self::Relation(r) => std::iter::once(r).collect(), - Self::ForeignPredicateGround(_, _) - | Self::UntaggedVec(_) => HashSet::new(), + Self::ForeignPredicateGround(_, _) | Self::UntaggedVec(_) => HashSet::new(), } } } diff --git a/core/src/compiler/ram/dependency.rs b/core/src/compiler/ram/dependency.rs index ccb4f07..fa54b97 100644 --- a/core/src/compiler/ram/dependency.rs +++ b/core/src/compiler/ram/dependency.rs @@ -68,8 +68,7 @@ impl Update { impl Dataflow { fn collect_dependency(&self, preds: &mut HashSet) { match self { - Self::Unit(_) - | Self::UntaggedVec(_) => {} + Self::Unit(_) | Self::UntaggedVec(_) => {} Self::Relation(r) => { preds.insert(r.clone()); } diff --git a/core/src/compiler/ram/pretty.rs 
b/core/src/compiler/ram/pretty.rs index 142ece8..dc3217b 100644 --- a/core/src/compiler/ram/pretty.rs +++ b/core/src/compiler/ram/pretty.rs @@ -73,7 +73,10 @@ impl Dataflow { match self { // Base relations Self::Unit(t) => f.write_fmt(format_args!("Unit({})", t)), - Self::UntaggedVec(v) => f.write_fmt(format_args!("Vec([{}])", v.iter().map(|t| format!("{}", t)).collect::>().join(", "))), + Self::UntaggedVec(v) => f.write_fmt(format_args!( + "Vec([{}])", + v.iter().map(|t| format!("{}", t)).collect::>().join(", ") + )), Self::Relation(r) => f.write_fmt(format_args!("Relation {}", r)), // Unary operations @@ -158,12 +161,22 @@ impl Dataflow { } Self::ForeignPredicateConstraint(d, pred, args) => { let args = args.iter().map(|a| format!("{:?}", a)).collect::>(); - f.write_fmt(format_args!("ForeignPredicateConstraint[{}({})]\n{}", pred, args.join(", "), padding))?; + f.write_fmt(format_args!( + "ForeignPredicateConstraint[{}({})]\n{}", + pred, + args.join(", "), + padding + ))?; d.pretty_print(f, next_indent, indent_size) } Self::ForeignPredicateJoin(d, pred, args) => { let args = args.iter().map(|a| format!("{:?}", a)).collect::>(); - f.write_fmt(format_args!("ForeignPredicateJoin[{}({})]\n{}", pred, args.join(", "), padding))?; + f.write_fmt(format_args!( + "ForeignPredicateJoin[{}({})]\n{}", + pred, + args.join(", "), + padding + ))?; d.pretty_print(f, next_indent, indent_size) } } diff --git a/core/src/compiler/ram/ram2rs.rs b/core/src/compiler/ram/ram2rs.rs index 9aab626..52d95cc 100644 --- a/core/src/compiler/ram/ram2rs.rs +++ b/core/src/compiler/ram/ram2rs.rs @@ -84,9 +84,9 @@ impl ast::Program { // Composite quote! 
{ - // use std::rc::Rc; use scallop_core::common::input_tag::*; use scallop_core::runtime::provenance::*; + use scallop_core::runtime::env::*; use scallop_core::runtime::statics::*; use scallop_core::runtime::database::extensional::*; #(#strata_rs)* @@ -95,7 +95,8 @@ impl ast::Program { run_with_edb(ctx, ExtensionalDatabase::new()) } pub fn run_with_edb(ctx: &mut C, mut edb: ExtensionalDatabase) -> OutputRelations { - edb.internalize(ctx); + let runtime_env = RuntimeEnvironment::default(); + edb.internalize(&runtime_env, ctx); #(#exec_strata)* #output_relations } @@ -451,9 +452,11 @@ fn value_type_to_rs_type(ty: &ValueType) -> TokenStream { ValueType::Char => quote! { char }, ValueType::Str => quote! { &'static str }, ValueType::String => quote! { String }, + ValueType::Symbol => unimplemented!(), ValueType::DateTime => quote! { DateTime }, ValueType::Duration => quote! { Duration }, - // ValueType::RcString => quote! { Rc }, + ValueType::Entity => quote! { u64 }, + ValueType::Tensor => unimplemented!(), } } @@ -503,6 +506,9 @@ fn expr_to_rs_expr(expr: &Expr) -> TokenStream { Expr::Call(_) => { unimplemented!() } + Expr::New(_) => { + unimplemented!() + } } } @@ -513,6 +519,7 @@ fn input_tag_to_rs_input_tag(tag: &DynamicInputTag) -> TokenStream { DynamicInputTag::Bool(b) => quote! { DynamicInputTag::Bool(#b) }, DynamicInputTag::Float(f) => quote! { DynamicInputTag::Float(#f) }, DynamicInputTag::ExclusiveFloat(f, u) => quote! { DynamicInputTag::ExclusiveFloat(#f, #u) }, + DynamicInputTag::Tensor(_) => unimplemented!(), } } @@ -568,9 +575,13 @@ fn value_to_rs_value(value: &Value) -> TokenStream { Bool(b) => quote! { #b }, Str(s) => quote! { #s }, String(s) => quote! { String::from(#s) }, - // RcString(s) => quote! { Rc::new(String::from(#s)) }, + Symbol(_) => unimplemented!(), + SymbolString(_) => unimplemented!(), DateTime(_) => unimplemented!(), Duration(_) => unimplemented!(), + Entity(e) => quote! 
{ #e }, + Tensor(_) => panic!("[Internal Error] Should not have raw tensor during compilation"), + TensorValue(_) => panic!("[Internal Error] Should not have tensor value during compilation"), } } diff --git a/core/src/integrate/attribute.rs b/core/src/integrate/attribute.rs index e706dac..ce27eac 100644 --- a/core/src/integrate/attribute.rs +++ b/core/src/integrate/attribute.rs @@ -8,17 +8,21 @@ pub enum AttributeArgument { Integer(i64), Boolean(bool), String(String), + List(Vec), } impl AttributeArgument { - pub fn to_front(&self) -> front::Constant { - let c = match self { - Self::Float(f) => front::ConstantNode::Float(f.clone()), - Self::Integer(i) => front::ConstantNode::Integer(i.clone()), - Self::Boolean(b) => front::ConstantNode::Boolean(b.clone()), - Self::String(s) => front::ConstantNode::String(s.clone()), - }; - front::Constant::default(c) + pub fn to_front(&self) -> front::ast::AttributeValue { + match self { + Self::Float(f) => front::ast::Constant::float(f.clone()).into(), + Self::Integer(i) => front::ast::Constant::integer(i.clone()).into(), + Self::Boolean(b) => front::ast::Constant::boolean(b.clone()).into(), + Self::String(s) => front::ast::Constant::string(s.clone()).into(), + Self::List(l) => { + let l = l.iter().map(AttributeArgument::to_front).collect(); + front::ast::AttributeValue::default(front::ast::AttributeValueNode::List(l)) + } + } } } @@ -40,6 +44,15 @@ impl From for AttributeArgument { } } +impl From> for AttributeArgument +where + T: Into, +{ + fn from(v: Vec) -> Self { + Self::List(v.into_iter().map(|t| t.into()).collect()) + } +} + #[derive(Clone, Debug, PartialEq)] pub struct Attribute { pub name: String, @@ -56,8 +69,8 @@ impl Attribute { } } - pub fn to_front(&self) -> front::Attribute { - front::Attribute::default(front::AttributeNode { + pub fn to_front(&self) -> front::ast::Attribute { + front::ast::Attribute::default(front::ast::AttributeNode { name: string_to_front_identifier(&self.name), pos_args: self 
.positional_arguments @@ -77,6 +90,6 @@ impl Attribute { } } -fn string_to_front_identifier(s: &str) -> front::Identifier { - front::Identifier::default(front::IdentifierNode { name: s.to_string() }) +fn string_to_front_identifier(s: &str) -> front::ast::Identifier { + front::ast::Identifier::default(front::ast::IdentifierNode { name: s.to_string() }) } diff --git a/core/src/integrate/context.rs b/core/src/integrate/context.rs index d3509fe7..040598c 100644 --- a/core/src/integrate/context.rs +++ b/core/src/integrate/context.rs @@ -1,9 +1,12 @@ +use std::collections::*; + use crate::common::foreign_function::*; use crate::common::foreign_predicate::*; use crate::common::tuple::*; use crate::common::tuple_type::*; use crate::compiler; +use crate::compiler::front::attribute::AttributeProcessor; use crate::runtime::database::extensional::*; use crate::runtime::database::*; use crate::runtime::dynamic; @@ -112,6 +115,10 @@ impl IntegrateContext { &mut self.internal.prov_ctx } + pub fn runtime_environment(&self) -> &RuntimeEnvironment { + &self.internal.runtime_env + } + pub fn internal_context(&self) -> &InternalIntegrateContext { &self.internal } @@ -148,7 +155,7 @@ impl IntegrateContext { } /// Compile a relation declaration - pub fn add_relation(&mut self, string: &str) -> Result<&compiler::front::RelationTypeDecl, IntegrateError> { + pub fn add_relation(&mut self, string: &str) -> Result<&compiler::front::ast::RelationTypeDecl, IntegrateError> { self.front_has_changed = true; let source = compiler::front::StringSource::new(string.to_string()); self @@ -163,7 +170,7 @@ impl IntegrateContext { &mut self, string: &str, attrs: Vec, - ) -> Result<&compiler::front::RelationTypeDecl, IntegrateError> { + ) -> Result<&compiler::front::ast::RelationTypeDecl, IntegrateError> { self.front_has_changed = true; let source = compiler::front::StringSource::new(string.to_string()); self @@ -224,7 +231,7 @@ impl IntegrateContext { let source = 
compiler::front::StringSource::new(string.to_string()); self .front_ctx - .compile_rule_with_annotator(source, |item: &mut compiler::front::Item| { + .compile_rule_with_annotator(source, |item: &mut compiler::front::ast::Item| { item.attributes_mut().extend(attrs.iter().map(Attribute::to_front)) }) .map_err(IntegrateError::front) @@ -270,6 +277,35 @@ impl IntegrateContext { Ok(()) } + pub fn add_entity(&mut self, relation: &str, entity_tuple: Vec) -> Result<(), IntegrateError> { + // First obtain the facts from the entity str + let facts = self + .front_ctx + .compile_entity_to_facts(relation, entity_tuple) + .map_err(|e| IntegrateError::Compile(vec![compiler::CompileError::Front(e)]))?; + + // Then add the facts to the EDB + for (relation, tuples) in facts { + self + .edb() + .add_facts(&relation, tuples) + .map_err(|e| IntegrateError::Runtime(RuntimeError::Database(e)))?; + } + + Ok(()) + } + + pub fn compile_entity( + &self, + relation: &str, + entity_tuple: Vec, + ) -> Result>, IntegrateError> { + self + .front_ctx + .compile_entity_to_facts(relation, entity_tuple) + .map_err(|e| IntegrateError::Compile(vec![compiler::CompileError::Front(e)])) + } + /// Register a foreign function to the context pub fn register_foreign_function(&mut self, ff: F) -> Result<(), IntegrateError> where @@ -306,6 +342,23 @@ impl IntegrateContext { Ok(()) } + pub fn register_foreign_attribute(&mut self, p: A) -> Result<(), IntegrateError> + where + A: AttributeProcessor + Send + Sync + Clone, + { + // Add the predicate to front compilation context + self + .front_ctx + .register_attribute_processor(p) + .map_err(|e| IntegrateError::Runtime(RuntimeError::ForeignAttribute(e)))?; + + // If goes through, then the front context has changed + self.front_has_changed = true; + + // Return ok + Ok(()) + } + /// Set the context to be non-incremental anymore pub fn set_non_incremental(&mut self) { self.internal.exec_ctx.set_non_incremental(); @@ -582,13 +635,13 @@ impl InternalIntegrateContext { 
} pub fn computed_relation_ref(&mut self, relation: &str) -> Option<&dynamic::DynamicOutputCollection> { - self.exec_ctx.recover(relation, &self.prov_ctx); + self.exec_ctx.recover(relation, &self.runtime_env, &self.prov_ctx); self.exec_ctx.relation_ref(relation) } /// Get the RC'ed output collection of a given relation pub fn computed_relation(&mut self, relation: &str) -> Option>> { - self.exec_ctx.recover(relation, &self.prov_ctx); + self.exec_ctx.recover(relation, &self.runtime_env, &self.prov_ctx); self.exec_ctx.relation(relation) } @@ -598,7 +651,9 @@ impl InternalIntegrateContext { relation: &str, m: &M, ) -> Option>> { - self.exec_ctx.recover_with_monitor(relation, &self.prov_ctx, m); + self + .exec_ctx + .recover_with_monitor(relation, &self.runtime_env, &self.prov_ctx, m); self.exec_ctx.relation(relation) } } diff --git a/core/src/integrate/error.rs b/core/src/integrate/error.rs index eb8c1fa..4c70486 100644 --- a/core/src/integrate/error.rs +++ b/core/src/integrate/error.rs @@ -11,6 +11,17 @@ impl IntegrateError { pub fn front(e: compiler::front::FrontCompileError) -> Self { Self::Compile(vec![compiler::CompileError::Front(e)]) } + + pub fn io(e: IOError) -> Self { + Self::Runtime(RuntimeError::IO(e)) + } + + pub fn kind(&self) -> &'static str { + match self { + Self::Compile(_) => "Compile error occurred; aborted", + Self::Runtime(_) => "Runtime error occurred; aborted", + } + } } impl std::fmt::Display for IntegrateError { diff --git a/core/src/integrate/interpret.rs b/core/src/integrate/interpret.rs index 3959f99..96d3020 100644 --- a/core/src/integrate/interpret.rs +++ b/core/src/integrate/interpret.rs @@ -69,7 +69,7 @@ impl InterpretContext InterpretContext {} OutputOption::Default => { - relation.recover(&self.provenance, true); + // Recover + relation.recover(&self.runtime_env, &self.provenance, true); } - OutputOption::File(_) => { - unimplemented!("Cannot output into file for now") + OutputOption::File(f) => { + // Recover and export the file + 
relation.recover(&self.runtime_env, &self.provenance, true); + database::io::store_file(f, relation).map_err(IntegrateError::io)?; } } } @@ -105,7 +108,7 @@ impl InterpretContext InterpretContext {} OutputOption::Default => { - relation.recover_with_monitor(&self.provenance, m, true); + // Recover + relation.recover_with_monitor(&self.runtime_env, &self.provenance, m, true); } - OutputOption::File(_) => { - unimplemented!("Cannot output into file for now") + OutputOption::File(f) => { + // Recover and export the file + relation.recover_with_monitor(&self.runtime_env, &self.provenance, m, true); + database::io::store_file(f, relation).map_err(IntegrateError::io)?; } } } diff --git a/core/src/lib.rs b/core/src/lib.rs index 64c6740..af6f3b0 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -1,5 +1,6 @@ #![feature(min_specialization)] -#![feature(drain_filter)] +#![feature(extract_if)] +#![feature(proc_macro_span)] pub mod common; pub mod compiler; diff --git a/core/src/runtime/database/error.rs b/core/src/runtime/database/error.rs index b26b80c..31aaf01 100644 --- a/core/src/runtime/database/error.rs +++ b/core/src/runtime/database/error.rs @@ -1,5 +1,6 @@ use crate::common::tuple::*; use crate::common::tuple_type::*; +use crate::runtime::error::*; #[derive(Clone, Debug)] pub enum DatabaseError { @@ -14,6 +15,7 @@ pub enum DatabaseError { NewProgramFacts { relation: String, }, + IO(IOError), } impl std::fmt::Display for DatabaseError { @@ -32,6 +34,7 @@ impl std::fmt::Display for DatabaseError { "New facts in program declared for relation `{}`; cannot incrementally compute", relation )), + Self::IO(error) => error.fmt(f), } } } diff --git a/core/src/runtime/database/extensional/database.rs b/core/src/runtime/database/extensional/database.rs index 3cfbbc7..824a148 100644 --- a/core/src/runtime/database/extensional/database.rs +++ b/core/src/runtime/database/extensional/database.rs @@ -5,6 +5,7 @@ use crate::common::tuple::*; use crate::common::tuple_type::*; use 
crate::compiler::ram; use crate::runtime::dynamic::*; +use crate::runtime::env::*; use crate::runtime::monitor::*; use crate::runtime::provenance::*; use crate::runtime::statics::*; @@ -22,6 +23,9 @@ pub struct ExtensionalDatabase { /// Types of relations pub relation_types: HashMap, + /// Loaded Files + pub input_file_registry: io::InputFileRegistry, + /// Extensional relations pub extensional_relations: HashMap>, @@ -36,6 +40,7 @@ impl ExtensionalDatabase { type_check: true, disjunction_count: 0, relation_types: HashMap::new(), + input_file_registry: io::InputFileRegistry::new(), extensional_relations: HashMap::new(), internalized: false, } @@ -46,6 +51,7 @@ impl ExtensionalDatabase { type_check, disjunction_count: 0, relation_types: HashMap::new(), + input_file_registry: io::InputFileRegistry::new(), extensional_relations: HashMap::new(), internalized: false, } @@ -59,10 +65,15 @@ impl ExtensionalDatabase { type_check: self.type_check, disjunction_count: self.disjunction_count, relation_types: self.relation_types.clone(), - extensional_relations: self.extensional_relations.iter().map(|(pred, rel)| { - let new_rel = rel.clone_with_new_provenance(); - (pred.clone(), new_rel) - }).collect(), + input_file_registry: self.input_file_registry.clone(), + extensional_relations: self + .extensional_relations + .iter() + .map(|(pred, rel)| { + let new_rel = rel.clone_with_new_provenance(); + (pred.clone(), new_rel) + }) + .collect(), internalized: false, } } @@ -75,6 +86,7 @@ impl ExtensionalDatabase { type_check: true, disjunction_count: 0, relation_types: types.collect(), + input_file_registry: io::InputFileRegistry::new(), extensional_relations: HashMap::new(), internalized: false, } @@ -88,6 +100,7 @@ impl ExtensionalDatabase { type_check, disjunction_count: 0, relation_types: types.collect(), + input_file_registry: io::InputFileRegistry::new(), extensional_relations: HashMap::new(), internalized: false, } @@ -105,7 +118,11 @@ impl ExtensionalDatabase { 
self.extensional_relations.contains_key(relation) } - pub fn add_dynamic_input_facts(&mut self, relation: &str, facts: Vec<(DynamicInputTag, T)>) -> Result<(), DatabaseError> + pub fn add_dynamic_input_facts( + &mut self, + relation: &str, + facts: Vec<(DynamicInputTag, T)>, + ) -> Result<(), DatabaseError> where T: Into, { @@ -182,18 +199,23 @@ impl ExtensionalDatabase { } } - pub fn populate_program_facts(&mut self, program: &ram::Program) -> Result<(), DatabaseError> { + pub fn populate_program_facts( + &mut self, + env: &RuntimeEnvironment, + program: &ram::Program, + ) -> Result<(), DatabaseError> { + // Load and cache all the input files + self.input_file_registry.load(program).map_err(DatabaseError::IO)?; + // Iterate through all relations declared in the program for relation in program.relations() { // Check if we need to load the relation facts if !relation.facts.is_empty() { - let edb_relation = self - .extensional_relations - .entry(relation.predicate.clone()) - .or_default(); + let edb_relation = self.get_or_insert_relation(&relation.predicate); if edb_relation.internalized_program_facts && edb_relation.has_program_facts() && edb_relation.num_program_facts() < relation.facts.len() + // TODO: why? 
{ return Err(DatabaseError::NewProgramFacts { relation: relation.predicate.clone(), @@ -203,10 +225,17 @@ impl ExtensionalDatabase { } // Check if we need to load external facts (from files or databases) - if relation.input_file.is_some() { - unimplemented!("Cannot load external file yet"); + if let Some(input_file_config) = &relation.input_file { + let edb_relation = self + .extensional_relations + .entry(relation.predicate.to_string()) + .or_default(); + let loaded_file_content = self.input_file_registry.get(input_file_config.file_path()).unwrap(); // unwrap since this has to be populated before + edb_relation.load_from_file(env, input_file_config, loaded_file_content, &relation.tuple_type)?; } } + + // Return Ok(()) } @@ -224,16 +253,16 @@ impl ExtensionalDatabase { } } - pub fn internalize(&mut self, ctx: &mut Prov) { + pub fn internalize(&mut self, env: &RuntimeEnvironment, ctx: &Prov) { for (_, relation) in &mut self.extensional_relations { - relation.internalize(ctx); + relation.internalize(env, ctx); } self.internalized = true } - pub fn internalize_with_monitor>(&mut self, ctx: &mut Prov, m: &M) { + pub fn internalize_with_monitor>(&mut self, env: &RuntimeEnvironment, ctx: &Prov, m: &M) { for (_, relation) in &mut self.extensional_relations { - relation.internalize_with_monitor(ctx, m); + relation.internalize_with_monitor(env, ctx, m); } self.internalized = true } @@ -266,6 +295,10 @@ impl ExtensionalDatabase { Ok(()) } } + + fn get_or_insert_relation(&mut self, relation: &str) -> &mut ExtensionalRelation { + self.extensional_relations.entry(relation.to_string()).or_default() + } } impl ExtensionalDatabase diff --git a/core/src/runtime/database/extensional/relation.rs b/core/src/runtime/database/extensional/relation.rs index 0d119f9..41b4b9f 100644 --- a/core/src/runtime/database/extensional/relation.rs +++ b/core/src/runtime/database/extensional/relation.rs @@ -1,9 +1,20 @@ +use std::collections::*; +use std::path::*; + +use crate::common::input_file::*; 
use crate::common::input_tag::*; use crate::common::tuple::*; +use crate::common::tuple_type::*; +use crate::common::value::*; +use crate::common::value_type::*; use crate::runtime::dynamic::*; +use crate::runtime::env::*; +use crate::runtime::error::*; use crate::runtime::monitor::*; use crate::runtime::provenance::*; +use super::*; + #[derive(Clone, Debug)] pub struct ExtensionalRelation { /// The facts from the program @@ -13,6 +24,9 @@ pub struct ExtensionalRelation { /// round of internalization of program facts pub internalized_program_facts: bool, + /// Loaded files + pub loaded_files: BTreeSet, + /// Dynamically tagged input facts dynamic_input: Vec<(DynamicInputTag, Tuple)>, @@ -37,6 +51,7 @@ impl ExtensionalRelation { Self { program_facts: vec![], internalized_program_facts: false, + loaded_files: BTreeSet::new(), dynamic_input: vec![], static_input: vec![], internal: DynamicCollection::empty(), @@ -51,11 +66,18 @@ impl ExtensionalRelation { ExtensionalRelation { program_facts: self.program_facts.clone(), internalized_program_facts: false, + loaded_files: BTreeSet::new(), dynamic_input: self.dynamic_input.clone(), - static_input: self.static_input.iter().map(|(tag, tuple)| { - let new_tag = tag.as_ref().and_then(|tag| ConvertFromInputTag::from_input_tag(tag.clone())); - (new_tag, tuple.clone()) - }).collect(), + static_input: self + .static_input + .iter() + .map(|(tag, tuple)| { + let new_tag = tag + .as_ref() + .and_then(|tag| ConvertFromInputTag::from_input_tag(tag.clone())); + (new_tag, tuple.clone()) + }) + .collect(), internal: DynamicCollection::empty(), internalized: false, } @@ -100,16 +122,40 @@ impl ExtensionalRelation { self.static_input.extend(facts) } - pub fn internalize(&mut self, ctx: &mut Prov) { + pub fn load_from_file( + &mut self, + env: &RuntimeEnvironment, + file: &InputFile, + loaded_file_content: &io::InputFileContent, + types: &TupleType, + ) -> Result<(), DatabaseError> { + // Do not load again + if 
self.loaded_files.contains(file.file_path()) { + return Ok(()); + } + + // Load the file into the dynamic input + let result = load_from_file(env, file, loaded_file_content, types).map_err(DatabaseError::IO)?; + self.dynamic_input.extend(result); + + // If loading succeeded, record the file path in the loaded files set + self.loaded_files.insert(file.file_path().clone()); + + // Return ok + Ok(()) + } + + pub fn internalize(&mut self, env: &RuntimeEnvironment, ctx: &Prov) { let mut elems: Vec> = Vec::new(); // First internalize program facts, only if there is program facts if !self.program_facts.is_empty() { // Iterate (not drain) the program facts elems.extend(self.program_facts.iter().map(|(tag, tup)| { + let int_tup = env.internalize_tuple(tup); let maybe_input_tag = StaticInputTag::from_dynamic_input_tag(&tag); let tag = ctx.tagging_optional_fn(maybe_input_tag); - DynamicElement::new(tup.clone(), tag) + DynamicElement::new(int_tup, tag) })); // Set the internalization to `true` @@ -118,15 +164,17 @@ impl ExtensionalRelation { // First internalize dynamic input facts elems.extend(self.dynamic_input.drain(..).map(|(tag, tup)| { + let int_tup = env.internalize_tuple(&tup); let maybe_input_tag = StaticInputTag::from_dynamic_input_tag(&tag); let tag = ctx.tagging_optional_fn(maybe_input_tag); - DynamicElement::new(tup, tag) + DynamicElement::new(int_tup, tag) })); // Then internalize static input facts elems.extend(self.static_input.drain(..).map(|(tag, tup)| { + let int_tup = env.internalize_tuple(&tup); let tag = ctx.tagging_optional_fn(tag); - DynamicElement::new(tup, tag) + DynamicElement::new(int_tup, tag) })); // Add existed facts @@ -137,20 +185,21 @@ impl ExtensionalRelation { self.internalized = true; } - pub fn internalize_with_monitor>(&mut self, ctx: &mut Prov, m: &M) { + pub fn internalize_with_monitor>(&mut self, env: &RuntimeEnvironment, ctx: &Prov, m: &M) { let mut elems: Vec> = Vec::new(); // First internalize program facts, only if 
!self.program_facts.is_empty() { // Iterate (not drain) the program facts elems.extend(self.program_facts.iter().map(|(tag, tup)| { + let int_tup = env.internalize_tuple(tup); let maybe_input_tag = StaticInputTag::from_dynamic_input_tag(&tag); let tag = ctx.tagging_optional_fn(maybe_input_tag.clone()); // !SPECIAL MONITORING! - m.observe_tagging(tup, &maybe_input_tag, &tag); + m.observe_tagging(&int_tup, &maybe_input_tag, &tag); - DynamicElement::new(tup.clone(), tag) + DynamicElement::new(int_tup, tag) })); // Set the internalization to `true` @@ -159,23 +208,25 @@ impl ExtensionalRelation { // First internalize dynamic input facts elems.extend(self.dynamic_input.drain(..).map(|(tag, tup)| { + let int_tup = env.internalize_tuple(&tup); let maybe_input_tag = StaticInputTag::from_dynamic_input_tag(&tag); let tag = ctx.tagging_optional_fn(maybe_input_tag.clone()); // !SPECIAL MONITORING! m.observe_tagging(&tup, &maybe_input_tag, &tag); - DynamicElement::new(tup, tag) + DynamicElement::new(int_tup, tag) })); // Then internalize static input facts elems.extend(self.static_input.drain(..).map(|(input_tag, tup)| { + let int_tup = env.internalize_tuple(&tup); let tag = ctx.tagging_optional_fn(input_tag.clone()); // !SPECIAL MONITORING! m.observe_tagging(&tup, &input_tag, &tag); - DynamicElement::new(tup, tag) + DynamicElement::new(int_tup, tag) })); // Add existed facts @@ -186,3 +237,390 @@ impl ExtensionalRelation { self.internalized = true; } } + +fn load_from_file( + env: &RuntimeEnvironment, + input_file: &InputFile, + loaded_file_content: &io::InputFileContent, + types: &TupleType, +) -> Result, IOError> { + match (input_file, loaded_file_content) { + ( + InputFile::Csv { + keys, + fields, + has_probability, + .. 
+ }, + io::InputFileContent::CSV(content), + ) => load_from_csv(env, keys, fields, *has_probability, content, types), + } +} + +fn load_from_csv( + env: &RuntimeEnvironment, + keys: &Option>, + fields: &Option>, + has_probability: bool, + loaded_file_content: &io::CSVFileContent, + types: &TupleType, +) -> Result, IOError> { + match (keys, fields) { + (Some(keys), Some(fields)) => { + load_from_csv_with_keys_and_fields(env, keys, fields, has_probability, loaded_file_content, types) + } + (Some(keys), None) => load_from_csv_with_keys(env, keys, has_probability, loaded_file_content, types), + (None, Some(fields)) => load_from_csv_with_fields(fields, has_probability, loaded_file_content, types), + (None, None) => load_from_csv_raw(has_probability, loaded_file_content, types), + } +} + +fn load_from_csv_with_keys_and_fields( + env: &RuntimeEnvironment, + keys: &Vec, + fields: &Vec, + has_probability: bool, + loaded_file_content: &io::CSVFileContent, + types: &TupleType, +) -> Result, IOError> { + // Get the value types and the probability offset + let value_types = get_value_types(types)?; + let probability_offset = if has_probability { 1 } else { 0 }; + + // Check arity first + if value_types.len() != keys.len() + 2 { + return Err(IOError::ArityMismatch { + expected: keys.len() + 2, + found: value_types.len(), + }); + } + + // Check types + match (value_types[value_types.len() - 2], value_types[value_types.len() - 1]) { + (ValueType::Symbol, ValueType::String) => { /* GOOD */ } + (ValueType::Symbol, value_type) => { + return Err(IOError::ExpectStringType { + actual: value_type.clone(), + }) + } + (field_type, _) => { + return Err(IOError::ExpectSymbolType { + actual: field_type.clone(), + }) + } + } + + // Get the indices of the keys + let key_value_ids = keys + .iter() + .zip(value_types.iter()) + .map(|(k, t)| { + let id = loaded_file_content + .get_header_id(k) + .ok_or(IOError::CannotFindField { field: k.clone() })?; + Ok((id, *t)) + }) + .collect::, _>>()?; + + 
// Get the set of fields to include + let fields_id_set = fields + .iter() + .map(|f| { + loaded_file_content + .get_header_id(f) + .ok_or(IOError::CannotFindField { field: f.clone() }) + }) + .collect::, _>>()?; + + // Field values + let field_symbols: Vec = loaded_file_content + .iter_headers() + .map(|h| Value::Symbol(env.symbol_registry.register(h.clone()))) + .collect(); + + // Cache the results and process each row + let mut result = vec![]; + for row in loaded_file_content.get_rows() { + // TODO: this could generate multiple facts with the duplicated same tag + // Get the tag + let tag = if has_probability { + let s = row.get(0).ok_or(IOError::IndexOutOfBounds { index: 0 })?; + s.parse::() + .map_err(|_| IOError::CannotParseProbability { value: s.to_string() })? + } else { + DynamicInputTag::None + }; + + // Get the keys + let keys = key_value_ids + .iter() + .map(|(i, t)| { + let s = row.get(*i).ok_or(IOError::IndexOutOfBounds { index: *i })?; + t.parse(s).map_err(|e| IOError::ValueParseError { error: e }) + }) + .collect::, _>>()?; + + // Get all other values + let field_values = row + .iter() + .enumerate() + .skip(probability_offset) // we skip the tag in the front + .filter(|(i, _)| { + // We want to skip the fields that are keys + key_value_ids.iter().find(|(j, _)| i == j).is_none() && fields_id_set.contains(i) + }) + .map(|(i, s)| { + let field = field_symbols + .get(i) + .ok_or(IOError::IndexOutOfBounds { index: i })? 
+ .clone(); + let value = Value::String(s.to_string()); + Ok((field, value)) + }) + .collect::, _>>()?; + + // Create the results + let curr_results = field_values + .into_iter() + .map(|(field, value)| { + let tuple = Tuple::from_values(keys.iter().cloned().chain(vec![field, value])); + let element = (tag.clone(), tuple); + element + }) + .collect::>(); + + // Extend the results + result.extend(curr_results); + } + + // Return + Ok(result) +} + +fn load_from_csv_with_keys( + env: &RuntimeEnvironment, + keys: &Vec, + has_probability: bool, + loaded_file_content: &io::CSVFileContent, + types: &TupleType, +) -> Result, IOError> { + // Get the value types and the probability offset + let value_types = get_value_types(types)?; + let probability_offset = if has_probability { 1 } else { 0 }; + + // Check arity first + if value_types.len() != keys.len() + 2 { + return Err(IOError::ArityMismatch { + expected: keys.len() + 2, + found: value_types.len(), + }); + } + + // Check types + match (value_types[value_types.len() - 2], value_types[value_types.len() - 1]) { + (ValueType::Symbol, ValueType::String) => { /* GOOD */ } + (ValueType::Symbol, value_type) => { + return Err(IOError::ExpectStringType { + actual: value_type.clone(), + }) + } + (field_type, _) => { + return Err(IOError::ExpectSymbolType { + actual: field_type.clone(), + }) + } + } + + // Get the indices of the keys + let key_value_ids = keys + .iter() + .zip(value_types.iter()) + .map(|(k, t)| { + let id = loaded_file_content + .get_header_id(k) + .ok_or(IOError::CannotFindField { field: k.clone() })?; + Ok((id, *t)) + }) + .collect::, _>>()?; + + // Field values + let field_symbols: Vec = loaded_file_content + .iter_headers() + .map(|h| Value::Symbol(env.symbol_registry.register(h.clone()))) + .collect(); + + // Cache the results and process each row + let mut result = vec![]; + for row in loaded_file_content.get_rows() { + // TODO: this could generate multiple facts with the duplicated same tag + // Get the 
tag + let tag = if has_probability { + let s = row.get(0).ok_or(IOError::IndexOutOfBounds { index: 0 })?; + s.parse::() + .map_err(|_| IOError::CannotParseProbability { value: s.to_string() })? + } else { + DynamicInputTag::None + }; + + // Get the keys + let keys = key_value_ids + .iter() + .map(|(i, t)| { + let s = row.get(*i).ok_or(IOError::IndexOutOfBounds { index: *i })?; + t.parse(s).map_err(|e| IOError::ValueParseError { error: e }) + }) + .collect::, _>>()?; + + // Get all other values + let field_values = row + .iter() + .enumerate() + .skip(probability_offset) // we skip the tag in the front + .filter(|(i, _)| { + // We want to skip the fields that are keys + key_value_ids.iter().find(|(j, _)| i == j).is_none() + }) + .map(|(i, s)| { + let field = field_symbols + .get(i) + .ok_or(IOError::IndexOutOfBounds { index: i })? + .clone(); + let value = Value::String(s.to_string()); + Ok((field, value)) + }) + .collect::, _>>()?; + + // Create the results + let curr_results = field_values + .into_iter() + .map(|(field, value)| { + let tuple = Tuple::from_values(keys.iter().cloned().chain(vec![field, value])); + let element = (tag.clone(), tuple); + element + }) + .collect::>(); + + // Extend the results + result.extend(curr_results); + } + + // Return + Ok(result) +} + +fn load_from_csv_with_fields( + fields: &Vec, + has_probability: bool, + loaded_file_content: &io::CSVFileContent, + types: &TupleType, +) -> Result, IOError> { + // Get the value types and the probability offset + let value_types = get_value_types(types)?; + + // Check arity first + if value_types.len() != fields.len() { + return Err(IOError::ArityMismatch { + expected: fields.len(), + found: value_types.len(), + }); + } + + // Get the indices of the keys + let id_type_pairs = fields + .iter() + .zip(value_types.iter()) + .map(|(f, t)| { + let id = loaded_file_content + .get_header_id(f) + .ok_or(IOError::CannotFindField { field: f.clone() })?; + Ok((id, *t)) + }) + .collect::, _>>()?; + + // 
Cache the results and process each row + let mut result = vec![]; + for row in loaded_file_content.get_rows() { + // Get the tag + let tag = if has_probability { + let s = row.get(0).ok_or(IOError::IndexOutOfBounds { index: 0 })?; + s.parse::() + .map_err(|_| IOError::CannotParseProbability { value: s.to_string() })? + } else { + DynamicInputTag::None + }; + + // Get the tuple + let values = id_type_pairs + .iter() + .map(|(id, t)| { + let value = row.get(*id).ok_or(IOError::IndexOutOfBounds { index: *id })?; + t.parse(value).map_err(|e| IOError::ValueParseError { error: e }) + }) + .collect::, _>>()?; + + // Create the tagged-tuple + let tagged_tuple = (tag, Tuple::from(values)); + result.push(tagged_tuple); + } + + Ok(result) +} + +fn load_from_csv_raw( + has_probability: bool, + loaded_file_content: &io::CSVFileContent, + types: &TupleType, +) -> Result, IOError> { + // Get the value types and the probability offset + let value_types = get_value_types(types)?; + let probability_offset = if has_probability { 1 } else { 0 }; + + // Cache the results and process each row + let mut result = vec![]; + for row in loaded_file_content.get_rows() { + // Check arity + if row.len() != value_types.len() + probability_offset { + return Err(IOError::ArityMismatch { + expected: row.len(), + found: value_types.len() + probability_offset, + }); + } + + // Get the tag + let tag = if has_probability { + let s = row.get(0).ok_or(IOError::IndexOutOfBounds { index: 0 })?; + s.parse::() + .map_err(|_| IOError::CannotParseProbability { value: s.to_string() })? 
+ } else { + DynamicInputTag::None + }; + + // Get the tuple + let values = row + .iter() + .skip(probability_offset) + .zip(value_types.iter()) + .map(|(r, t)| t.parse(r).map_err(|e| IOError::ValueParseError { error: e })) + .collect::, _>>()?; + + // Create the tagged-tuple + let tagged_tuple = (tag, Tuple::from(values)); + result.push(tagged_tuple); + } + + Ok(result) +} + +fn get_value_types(types: &TupleType) -> Result, IOError> { + match types { + TupleType::Tuple(ts) => ts + .iter() + .map(|t| match t { + TupleType::Value(v) => Some(v), + _ => None, + }) + .collect::>>() + .ok_or(IOError::InvalidType { types: types.clone() }), + TupleType::Value(_) => Err(IOError::InvalidType { types: types.clone() }), + } +} diff --git a/core/src/runtime/database/intentional/database.rs b/core/src/runtime/database/intentional/database.rs index 26cd1d3..54699cc 100644 --- a/core/src/runtime/database/intentional/database.rs +++ b/core/src/runtime/database/intentional/database.rs @@ -1,7 +1,8 @@ use std::collections::*; -use crate::runtime::database::extensional::ExtensionalRelation; +use crate::runtime::database::extensional::*; use crate::runtime::dynamic::*; +use crate::runtime::env::*; use crate::runtime::monitor::*; use crate::runtime::provenance::*; use crate::utils::*; @@ -74,35 +75,49 @@ impl IntentionalDatabase { self.intentional_relations.contains_key(relation) } - pub fn recover_from_edb(&mut self, relation: &str, ctx: &Prov, edb_relation: &ExtensionalRelation) { + pub fn recover_from_edb( + &mut self, + relation: &str, + env: &RuntimeEnvironment, + ctx: &Prov, + edb_relation: &ExtensionalRelation, + ) { self.intentional_relations.insert( relation.to_string(), IntentionalRelation { recovered: true, internal_facts: DynamicCollection::empty(), - recovered_facts: Ptr::new_rc(DynamicOutputCollection::from( - edb_relation - .internal - .iter() - .map(|elem| (ctx.recover_fn(&elem.tag), elem.tuple.clone())), - )), + recovered_facts: 
Ptr::new_rc(DynamicOutputCollection::from(edb_relation.internal.iter().map( + |elem| { + let tag = ctx.recover_fn(&elem.tag); + let tup = env.externalize_tuple(&elem.tuple); + (tag, tup) + }, + ))), }, ); } /// Recover the output collection for a relation - pub fn recover(&mut self, relation: &str, ctx: &Prov, drain: bool) { + pub fn recover(&mut self, relation: &str, env: &RuntimeEnvironment, ctx: &Prov, drain: bool) { if let Some(r) = self.intentional_relations.get_mut(relation) { - r.recover(ctx, drain); + r.recover(env, ctx, drain); } } /// Recover the output collection for a relation, with a monitor - pub fn recover_with_monitor>(&mut self, relation: &str, ctx: &Prov, m: &M, drain: bool) { + pub fn recover_with_monitor>( + &mut self, + relation: &str, + env: &RuntimeEnvironment, + ctx: &Prov, + m: &M, + drain: bool, + ) { if let Some(r) = self.intentional_relations.get_mut(relation) { // !SPECIAL MONITORING! m.observe_recovering_relation(relation); - r.recover_with_monitor(ctx, m, drain); + r.recover_with_monitor(env, ctx, m, drain); } } diff --git a/core/src/runtime/database/intentional/relation.rs b/core/src/runtime/database/intentional/relation.rs index 4650a3b..724894f 100644 --- a/core/src/runtime/database/intentional/relation.rs +++ b/core/src/runtime/database/intentional/relation.rs @@ -1,5 +1,6 @@ -use crate::runtime::dynamic::{DynamicCollection, DynamicOutputCollection}; -use crate::runtime::monitor::Monitor; +use crate::runtime::dynamic::*; +use crate::runtime::env::*; +use crate::runtime::monitor::*; use crate::runtime::provenance::*; use crate::utils::PointerFamily; @@ -70,22 +71,29 @@ impl IntentionalRelation { } } - pub fn recover_with_monitor>(&mut self, ctx: &Prov, m: &M, drain: bool) { + pub fn recover(&mut self, env: &RuntimeEnvironment, ctx: &Prov, drain: bool) { // Only recover if it is not recovered - if !self.recovered && !self.internal_facts.is_empty() { + if !self.recovered { + // Shortcut: if there is no internal facts, then there is 
nothing to recover + if self.internal_facts.is_empty() { + self.recovered = true; + return; + } + + // Check if we need to drain the internal facts if drain { // Add internal facts to recovered facts, and remove the internal facts Ptr::get_rc_mut(&mut self.recovered_facts).extend(self.internal_facts.drain().map(|elem| { + let output_tup = env.externalize_tuple(&elem.tuple); let output_tag = ctx.recover_fn(&elem.tag); - m.observe_recover(&elem.tuple, &elem.tag, &output_tag); - (output_tag, elem.tuple) + (output_tag, output_tup) })); } else { // Add internal facts to recover facts, do not remove the internal facts Ptr::get_rc_mut(&mut self.recovered_facts).extend(self.internal_facts.iter().map(|elem| { + let output_tup = env.externalize_tuple(&elem.tuple); let output_tag = ctx.recover_fn(&elem.tag); - m.observe_recover(&elem.tuple, &elem.tag, &output_tag); - (output_tag, elem.tuple.clone()) + (output_tag, output_tup) })); } @@ -94,27 +102,24 @@ impl IntentionalRelation { } } - pub fn recover(&mut self, ctx: &Prov, drain: bool) { + pub fn recover_with_monitor>(&mut self, env: &RuntimeEnvironment, ctx: &Prov, m: &M, drain: bool) { // Only recover if it is not recovered - if !self.recovered { - // Shortcut: if there is no internal facts, then there is nothing to recover - if self.internal_facts.is_empty() { - self.recovered = true; - return; - } - - // Check if we need to drain the internal facts + if !self.recovered && !self.internal_facts.is_empty() { if drain { // Add internal facts to recovered facts, and remove the internal facts Ptr::get_rc_mut(&mut self.recovered_facts).extend(self.internal_facts.drain().map(|elem| { + let output_tup = env.externalize_tuple(&elem.tuple); let output_tag = ctx.recover_fn(&elem.tag); - (output_tag, elem.tuple) + m.observe_recover(&output_tup, &elem.tag, &output_tag); + (output_tag, output_tup) })); } else { // Add internal facts to recover facts, do not remove the internal facts Ptr::get_rc_mut(&mut 
self.recovered_facts).extend(self.internal_facts.iter().map(|elem| { + let output_tup = env.externalize_tuple(&elem.tuple); let output_tag = ctx.recover_fn(&elem.tag); - (output_tag, elem.tuple.clone()) + m.observe_recover(&output_tup, &elem.tag, &output_tag); + (output_tag, output_tup) })); } diff --git a/core/src/runtime/database/io/input_file_content.rs b/core/src/runtime/database/io/input_file_content.rs new file mode 100644 index 0000000..955095a --- /dev/null +++ b/core/src/runtime/database/io/input_file_content.rs @@ -0,0 +1,102 @@ +use std::collections::*; +use std::path::*; + +use crate::common::input_file::InputFile; +use crate::runtime::error::IOError; + +#[derive(Debug, Clone)] +pub enum InputFileContent { + CSV(CSVFileContent), +} + +impl InputFileContent { + pub fn load(input_file: &InputFile) -> Result { + match input_file { + InputFile::Csv { + file_path, + deliminator, + has_header, + .. + } => CSVFileContent::from_file(file_path.clone(), *deliminator, *has_header).map(InputFileContent::CSV), + } + } +} + +#[derive(Debug, Clone)] +pub struct CSVFileContent { + fields: Vec, + field_id_map: BTreeMap, + rows: Vec>, +} + +impl CSVFileContent { + pub fn from_file(file_path: PathBuf, deliminator: u8, has_header: bool) -> Result { + let mut rdr = csv::ReaderBuilder::new() + .delimiter(deliminator) + .has_headers(has_header) + .from_path(file_path.clone()) + .map_err(|e| IOError::CannotOpenFile { + file_path, + error: format!("{}", e), + })?; + + // Generate fields + let (fields, field_id_map) = if has_header { + let fields = rdr + .headers() + .map_err(|e| IOError::CannotReadHeader { + error: format!("{}", e), + })? 
+ .iter() + .map(|s| s.to_string()) + .collect::>(); + let field_id_map = fields.iter().enumerate().map(|(i, s)| (s.clone(), i)).collect(); + (fields, field_id_map) + } else { + (vec![], BTreeMap::new()) + }; + + // Generate rows + let mut rows = Vec::new(); + for result in rdr.records() { + let record = result.map_err(|e| IOError::CannotParseCSV { + error: format!("{}", e), + })?; + rows.push(record.iter().map(|s| s.to_string()).collect()); + } + + Ok(Self { + fields, + field_id_map, + rows, + }) + } + + pub fn num_columns(&self) -> usize { + self.fields.len() + } + + pub fn num_rows(&self) -> usize { + self.rows.len() + } + + pub fn get_header_id(&self, header: &str) -> Option { + self.field_id_map.get(header).copied() + } + + pub fn get_ith_header(&self, i: usize) -> Option<&String> { + self.fields.get(i) + } + + pub fn headers(&self) -> &Vec { + &self.fields + } + + pub fn iter_headers(&self) -> impl Iterator { + self.fields.iter() + } + + pub fn get_rows(&self) -> impl Iterator> { + self.rows.iter() + } +} diff --git a/core/src/runtime/database/io/input_file_registry.rs b/core/src/runtime/database/io/input_file_registry.rs new file mode 100644 index 0000000..ffd9419 --- /dev/null +++ b/core/src/runtime/database/io/input_file_registry.rs @@ -0,0 +1,60 @@ +use std::collections::*; +use std::path::*; + +use crate::compiler::ram; +use crate::runtime::error::*; +use crate::utils::*; + +use super::*; + +#[derive(Debug)] +pub struct InputFileRegistry { + pub input_files: Ptr::Rc>, +} + +impl Clone for InputFileRegistry { + fn clone(&self) -> Self { + Self { + input_files: ArcFamily::clone_rc(&self.input_files), + } + } +} + +impl Clone for InputFileRegistry { + fn clone(&self) -> Self { + Self { + input_files: RcFamily::clone_rc(&self.input_files), + } + } +} + +impl InputFileRegistry { + pub fn new() -> Self { + Self { + input_files: Ptr::new_rc(HashMap::new()), + } + } + + pub fn load(&mut self, program: &ram::Program) -> Result<(), IOError> { + // Iterate through 
all the input files in the program + program.input_files().try_for_each(|input_file| { + if Ptr::get_rc(&self.input_files).contains_key(input_file.file_path()) { + // Do nothing; the file is already loaded + Ok(()) + } else { + // Load the file first; will fail if error happens + let input_file_content = InputFileContent::load(input_file)?; + + // Insert into the registry + Ptr::get_rc_mut(&mut self.input_files).insert(input_file.file_path().to_path_buf(), input_file_content); + + // Success + Ok(()) + } + }) + } + + pub fn get(&self, file_path: &PathBuf) -> Option<&InputFileContent> { + Ptr::get_rc(&self.input_files).get(file_path) + } +} diff --git a/core/src/runtime/database/io/mod.rs b/core/src/runtime/database/io/mod.rs new file mode 100644 index 0000000..bb1feba --- /dev/null +++ b/core/src/runtime/database/io/mod.rs @@ -0,0 +1,7 @@ +mod input_file_content; +mod input_file_registry; +mod output_file; + +pub use input_file_content::*; +pub use input_file_registry::*; +pub use output_file::*; diff --git a/core/src/runtime/database/io/output_file.rs b/core/src/runtime/database/io/output_file.rs new file mode 100644 index 0000000..a622352 --- /dev/null +++ b/core/src/runtime/database/io/output_file.rs @@ -0,0 +1,43 @@ +use std::fs::File; +use std::path::*; + +use csv::WriterBuilder; + +use crate::common::output_option::*; +use crate::runtime::error::*; +use crate::runtime::provenance::*; +use crate::utils::*; + +use super::super::*; + +pub fn store_file( + output_file: &OutputFile, + idb_relation: &intentional::IntentionalRelation, +) -> Result<(), IOError> { + match output_file { + OutputFile::CSV(f) => store_csv_file(&f.file_path, f.deliminator, idb_relation), + } +} + +pub fn store_csv_file( + file_path: &PathBuf, + deliminator: u8, + idb_relation: &intentional::IntentionalRelation, +) -> Result<(), IOError> { + // Then load the file + let file = File::create(file_path).map_err(|e| IOError::CannotOpenFile { + file_path: file_path.clone(), + error: 
format!("{}", e), + })?; + + // Write the tuples to the file + let mut wtr = WriterBuilder::new().delimiter(deliminator).from_writer(file); + for (_, tuple) in Ptr::get_rc(&idb_relation.recovered_facts).iter() { + let record = tuple.as_ref_values().into_iter().map(|v| format!("{}", v)); + wtr + .write_record(record) + .map_err(|e| IOError::CannotWriteRecord { error: e.to_string() })?; + } + + Ok(()) +} diff --git a/core/src/runtime/database/mod.rs b/core/src/runtime/database/mod.rs index 8e46b60..bd06d77 100644 --- a/core/src/runtime/database/mod.rs +++ b/core/src/runtime/database/mod.rs @@ -1,5 +1,6 @@ mod error; pub mod extensional; pub mod intentional; +pub mod io; pub use error::*; diff --git a/core/src/runtime/dynamic/dataflow/batching/batch.rs b/core/src/runtime/dynamic/dataflow/batching/batch.rs index 4cab6b2..8a9ef60 100644 --- a/core/src/runtime/dynamic/dataflow/batching/batch.rs +++ b/core/src/runtime/dynamic/dataflow/batching/batch.rs @@ -7,7 +7,7 @@ use super::super::*; #[derive(Clone)] pub enum DynamicBatch<'a, Prov: Provenance> { Vec(std::slice::Iter<'a, DynamicElement>), - UntaggedVec(&'a Prov, std::slice::Iter<'a, Tuple>), + UntaggedVec(&'a Prov, std::vec::IntoIter), SourceVec(std::vec::IntoIter>), DynamicRelationStable(DynamicRelationStableBatch<'a, Prov>), DynamicRelationRecent(DynamicRelationRecentBatch<'a, Prov>), @@ -30,7 +30,7 @@ impl<'a, Prov: Provenance> DynamicBatch<'a, Prov> { Self::Vec(v.iter()) } - pub fn untagged_vec(ctx: &'a Prov, v: std::slice::Iter<'a, Tuple>) -> Self { + pub fn untagged_vec(ctx: &'a Prov, v: std::vec::IntoIter) -> Self { Self::UntaggedVec(ctx, v) } @@ -112,7 +112,7 @@ impl<'a, Prov: Provenance> Iterator for DynamicBatch<'a, Prov> { fn next(&mut self) -> Option { match self { Self::Vec(iter) => iter.next().map(Clone::clone), - Self::UntaggedVec(ctx, iter) => iter.next().map(|t| DynamicElement::new(t.clone(), ctx.one())), + Self::UntaggedVec(ctx, iter) => iter.next().map(|t| DynamicElement::new(t, ctx.one())), 
Self::SourceVec(iter) => iter.next(), Self::DynamicRelationStable(b) => b.next(), Self::DynamicRelationRecent(b) => b.next(), diff --git a/core/src/runtime/dynamic/dataflow/dynamic_dataflow.rs b/core/src/runtime/dynamic/dataflow/dynamic_dataflow.rs index 7d37597..3123e90 100644 --- a/core/src/runtime/dynamic/dataflow/dynamic_dataflow.rs +++ b/core/src/runtime/dynamic/dataflow/dynamic_dataflow.rs @@ -1,7 +1,7 @@ use crate::common::expr::*; -use crate::common::value::*; use crate::common::tuple::*; use crate::common::tuple_type::*; +use crate::common::value::*; use super::*; @@ -36,7 +36,7 @@ impl<'a, Prov: Provenance> DynamicDataflow<'a, Prov> { Self::Vec(vec) } - pub fn untagged_vec(ctx: &'a Prov, vec: &'a Vec) -> Self { + pub fn untagged_vec(ctx: &'a Prov, vec: Vec) -> Self { Self::UntaggedVec(DynamicUntaggedVec::new(ctx, vec)) } @@ -143,12 +143,7 @@ impl<'a, Prov: Provenance> DynamicDataflow<'a, Prov> { }) } - pub fn foreign_predicate_ground( - pred: String, - bounded: Vec, - first_iter: bool, - ctx: &'a Prov, - ) -> Self { + pub fn foreign_predicate_ground(pred: String, bounded: Vec, first_iter: bool, ctx: &'a Prov) -> Self { Self::ForeignPredicateGround(ForeignPredicateGroundDataflow { foreign_predicate: pred, bounded_constants: bounded, @@ -157,12 +152,7 @@ impl<'a, Prov: Provenance> DynamicDataflow<'a, Prov> { }) } - pub fn foreign_predicate_constraint( - self, - pred: String, - args: Vec, - ctx: &'a Prov, - ) -> Self { + pub fn foreign_predicate_constraint(self, pred: String, args: Vec, ctx: &'a Prov) -> Self { Self::ForeignPredicateConstraint(ForeignPredicateConstraintDataflow { dataflow: Box::new(self), foreign_predicate: pred, @@ -171,12 +161,7 @@ impl<'a, Prov: Provenance> DynamicDataflow<'a, Prov> { }) } - pub fn foreign_predicate_join( - self, - pred: String, - args: Vec, - ctx: &'a Prov, - ) -> Self { + pub fn foreign_predicate_join(self, pred: String, args: Vec, ctx: &'a Prov) -> Self { Self::ForeignPredicateJoin(ForeignPredicateJoinDataflow { left: 
Box::new(self), foreign_predicate: pred, diff --git a/core/src/runtime/dynamic/dataflow/dynamic_exclusion.rs b/core/src/runtime/dynamic/dataflow/dynamic_exclusion.rs index 1235891..9b13dec 100644 --- a/core/src/runtime/dynamic/dataflow/dynamic_exclusion.rs +++ b/core/src/runtime/dynamic/dataflow/dynamic_exclusion.rs @@ -1,7 +1,7 @@ use std::collections::*; -use crate::common::tuple::*; use crate::common::input_tag::*; +use crate::common::tuple::*; use crate::runtime::provenance::*; use crate::utils::*; @@ -59,7 +59,11 @@ pub struct ExclusionOp<'a, Prov: Provenance> { impl<'a, Prov: Provenance> ExclusionOp<'a, Prov> { pub fn new(runtime: &'a RuntimeEnvironment, ctx: &'a Prov, visited_exclusion_map: VisitedExclusionMap) -> Self { - Self { runtime, ctx, visited_exclusion_map } + Self { + runtime, + ctx, + visited_exclusion_map, + } } pub fn apply(&self, left: DynamicBatch<'a, Prov>, right: DynamicBatch<'a, Prov>) -> DynamicBatch<'a, Prov> { @@ -118,18 +122,19 @@ impl<'a, Prov: Provenance> Iterator for DynamicExclusionBatch<'a, Prov> { loop { if let Some(left) = &self.left_curr { // First get an exclusion ID - let exc_id = if let Some(id) = RcFamily::get_rc_cell(&self.visited_exclusion_map, |m| m.get(&left.tuple).cloned()) { - // If the left tuple has been visited, directly pull the exclusion id - id - } else if let Some(id) = self.curr_exclusion_id { - // Or we have already generated a new ID for this tuple - id - } else { - // Otherwise, generate a new ID - let id = self.runtime.allocate_new_exclusion_id(); - RcFamily::get_rc_cell_mut(&self.visited_exclusion_map, |m| m.insert(left.tuple.clone(), id)); - id - }; + let exc_id = + if let Some(id) = RcFamily::get_rc_cell(&self.visited_exclusion_map, |m| m.get(&left.tuple).cloned()) { + // If the left tuple has been visited, directly pull the exclusion id + id + } else if let Some(id) = self.curr_exclusion_id { + // Or we have already generated a new ID for this tuple + id + } else { + // Otherwise, generate a new ID + 
let id = self.runtime.allocate_new_exclusion_id(); + RcFamily::get_rc_cell_mut(&self.visited_exclusion_map, |m| m.insert(left.tuple.clone(), id)); + id + }; // Then, iterate through the right if let Some(right) = self.right_clone.next() { diff --git a/core/src/runtime/dynamic/dataflow/foreign_predicate/constraint.rs b/core/src/runtime/dynamic/dataflow/foreign_predicate/constraint.rs index dfdfc44..f9f0202 100644 --- a/core/src/runtime/dynamic/dataflow/foreign_predicate/constraint.rs +++ b/core/src/runtime/dynamic/dataflow/foreign_predicate/constraint.rs @@ -21,21 +21,29 @@ pub struct ForeignPredicateConstraintDataflow<'a, Prov: Provenance> { impl<'a, Prov: Provenance> ForeignPredicateConstraintDataflow<'a, Prov> { pub fn iter_stable(&self, runtime: &'a RuntimeEnvironment) -> DynamicBatches<'a, Prov> { - let fp = runtime.predicate_registry.get(&self.foreign_predicate).expect("Foreign predicate not found"); + let fp = runtime + .predicate_registry + .get(&self.foreign_predicate) + .expect("Foreign predicate not found"); DynamicBatches::ForeignPredicateConstraint(ForeignPredicateConstraintBatches { batches: Box::new(self.dataflow.iter_stable(runtime)), foreign_predicate: fp.clone(), args: self.args.clone(), + env: runtime, ctx: self.ctx, }) } pub fn iter_recent(&self, runtime: &'a RuntimeEnvironment) -> DynamicBatches<'a, Prov> { - let fp = runtime.predicate_registry.get(&self.foreign_predicate).expect("Foreign predicate not found"); + let fp = runtime + .predicate_registry + .get(&self.foreign_predicate) + .expect("Foreign predicate not found"); DynamicBatches::ForeignPredicateConstraint(ForeignPredicateConstraintBatches { batches: Box::new(self.dataflow.iter_recent(runtime)), foreign_predicate: fp.clone(), args: self.args.clone(), + env: runtime, ctx: self.ctx, }) } @@ -46,6 +54,7 @@ pub struct ForeignPredicateConstraintBatches<'a, Prov: Provenance> { pub batches: Box>, pub foreign_predicate: DynamicForeignPredicate, pub args: Vec, + pub env: &'a RuntimeEnvironment, 
pub ctx: &'a Prov, } @@ -58,6 +67,7 @@ impl<'a, Prov: Provenance> Iterator for ForeignPredicateConstraintBatches<'a, Pr batch: Box::new(batch), foreign_predicate: self.foreign_predicate.clone(), args: self.args.clone(), + env: self.env, ctx: self.ctx, }) }) @@ -69,6 +79,7 @@ pub struct ForeignPredicateConstraintBatch<'a, Prov: Provenance> { pub batch: Box>, pub foreign_predicate: DynamicForeignPredicate, pub args: Vec, + pub env: &'a RuntimeEnvironment, pub ctx: &'a Prov, } @@ -80,21 +91,27 @@ impl<'a, Prov: Provenance> Iterator for ForeignPredicateConstraintBatch<'a, Prov let Tagged { tuple, tag } = elem; // Try evaluate the arguments; if failed, continue to the next element in the batch - let values = self.args.iter().map(|arg| { - match arg { + let values = self + .args + .iter() + .map(|arg| match arg { Expr::Access(a) => tuple[a].as_value(), Expr::Constant(c) => c.clone(), - _ => panic!("Invalid argument to bounded foreign predicate") - } - }).collect::>(); + _ => panic!("Invalid argument to bounded foreign predicate"), + }) + .collect::>(); // Evaluate the foreign predicate to produce a list of output tags // Note that there will be at most one output tag since the foreign predicate is bounded - let result = self.foreign_predicate.evaluate(&values); + let result = self.foreign_predicate.evaluate_with_env(self.env, &values); // Check if the foreign predicate returned a tag if !result.is_empty() { - assert_eq!(result.len(), 1, "Bounded foreign predicate should return at most one element per evaluation"); + assert_eq!( + result.len(), + 1, + "Bounded foreign predicate should return at most one element per evaluation" + ); let input_tag = Prov::InputTag::from_dynamic_input_tag(&result[0].0); let new_tag = self.ctx.tagging_optional_fn(input_tag); let combined_tag = self.ctx.mult(&tag, &new_tag); diff --git a/core/src/runtime/dynamic/dataflow/foreign_predicate/ground.rs b/core/src/runtime/dynamic/dataflow/foreign_predicate/ground.rs index 6e6efc6..d4b1281 100644 
--- a/core/src/runtime/dynamic/dataflow/foreign_predicate/ground.rs +++ b/core/src/runtime/dynamic/dataflow/foreign_predicate/ground.rs @@ -1,8 +1,8 @@ use crate::common::foreign_predicate::*; -use crate::common::value::*; use crate::common::tuple::*; -use crate::runtime::provenance::*; +use crate::common::value::*; use crate::runtime::env::*; +use crate::runtime::provenance::*; use super::*; @@ -32,7 +32,7 @@ impl<'a, Prov: Provenance> ForeignPredicateGroundDataflow<'a, Prov> { // Evaluate the foreign predicate let elements = foreign_predicate - .evaluate(&self.bounded_constants) + .evaluate_with_env(runtime, &self.bounded_constants) .into_iter() .map(|(input_tag, values)| { let input_tag = StaticInputTag::from_dynamic_input_tag(&input_tag); diff --git a/core/src/runtime/dynamic/dataflow/foreign_predicate/join.rs b/core/src/runtime/dynamic/dataflow/foreign_predicate/join.rs index d5a2409..e2bfbef 100644 --- a/core/src/runtime/dynamic/dataflow/foreign_predicate/join.rs +++ b/core/src/runtime/dynamic/dataflow/foreign_predicate/join.rs @@ -1,9 +1,9 @@ -use crate::common::foreign_predicate::*; use crate::common::expr::*; +use crate::common::foreign_predicate::*; use crate::common::tuple::*; use crate::common::value::*; -use crate::runtime::provenance::*; use crate::runtime::env::*; +use crate::runtime::provenance::*; use super::*; @@ -35,8 +35,13 @@ impl<'a, Prov: Provenance> ForeignPredicateJoinDataflow<'a, Prov> { pub fn iter_stable(&self, runtime: &'a RuntimeEnvironment) -> DynamicBatches<'a, Prov> { DynamicBatches::ForeignPredicateJoin(ForeignPredicateJoinBatches { batches: Box::new(self.left.iter_stable(runtime)), - foreign_predicate: runtime.predicate_registry.get(&self.foreign_predicate).expect("Foreign predicate not found").clone(), + foreign_predicate: runtime + .predicate_registry + .get(&self.foreign_predicate) + .expect("Foreign predicate not found") + .clone(), args: self.args.clone(), + env: runtime, ctx: self.ctx, }) } @@ -44,8 +49,13 @@ impl<'a, Prov: 
Provenance> ForeignPredicateJoinDataflow<'a, Prov> { pub fn iter_recent(&self, runtime: &'a RuntimeEnvironment) -> DynamicBatches<'a, Prov> { DynamicBatches::ForeignPredicateJoin(ForeignPredicateJoinBatches { batches: Box::new(self.left.iter_recent(runtime)), - foreign_predicate: runtime.predicate_registry.get(&self.foreign_predicate).expect("Foreign predicate not found").clone(), + foreign_predicate: runtime + .predicate_registry + .get(&self.foreign_predicate) + .expect("Foreign predicate not found") + .clone(), args: self.args.clone(), + env: runtime, ctx: self.ctx, }) } @@ -56,6 +66,7 @@ pub struct ForeignPredicateJoinBatches<'a, Prov: Provenance> { pub batches: Box>, pub foreign_predicate: DynamicForeignPredicate, pub args: Vec, + pub env: &'a RuntimeEnvironment, pub ctx: &'a Prov, } @@ -67,9 +78,9 @@ impl<'a, Prov: Provenance> Iterator for ForeignPredicateJoinBatches<'a, Prov> { self.batches.next().map(|mut batch| { // Then, try to get the first element inside of this batch; // if there is an element, we need to evaluate the foreign predicate and produce a current output batch - let first_output_batch = batch.next().map(|elem| { - eval_foreign_predicate(elem, &self.foreign_predicate, &self.args, self.ctx) - }); + let first_output_batch = batch + .next() + .map(|elem| eval_foreign_predicate(elem, &self.foreign_predicate, &self.args, self.env, self.ctx)); // Generate a new batch DynamicBatch::ForeignPredicateJoin(ForeignPredicateJoinBatch { @@ -77,6 +88,7 @@ impl<'a, Prov: Provenance> Iterator for ForeignPredicateJoinBatches<'a, Prov> { foreign_predicate: self.foreign_predicate.clone(), args: self.args.clone(), current_output_batch: first_output_batch, + env: self.env, ctx: self.ctx, }) }) @@ -89,6 +101,7 @@ pub struct ForeignPredicateJoinBatch<'a, Prov: Provenance> { pub foreign_predicate: DynamicForeignPredicate, pub args: Vec, pub current_output_batch: Option<(DynamicElement, std::vec::IntoIter>)>, + pub env: &'a RuntimeEnvironment, pub ctx: &'a Prov, } @@ 
-100,11 +113,12 @@ impl<'a, Prov: Provenance> Iterator for ForeignPredicateJoinBatch<'a, Prov> { if let Some(right_elem) = current_output_batch.next() { let tuple = (left_elem.tuple.clone(), right_elem.tuple); let new_tag = self.ctx.mult(&left_elem.tag, &right_elem.tag); - return Some(DynamicElement::new(tuple, new_tag)) + return Some(DynamicElement::new(tuple, new_tag)); } else { - self.current_output_batch = self.batch.next().map(|elem| { - eval_foreign_predicate(elem, &self.foreign_predicate, &self.args, self.ctx) - }); + self.current_output_batch = self + .batch + .next() + .map(|elem| eval_foreign_predicate(elem, &self.foreign_predicate, &self.args, self.env, self.ctx)); } } None @@ -116,29 +130,35 @@ fn eval_foreign_predicate( elem: DynamicElement, fp: &DynamicForeignPredicate, args: &Vec, + env: &RuntimeEnvironment, ctx: &Prov, ) -> (DynamicElement, std::vec::IntoIter>) { // First get the arguments to pass to the foreign predicate - let args_to_fp: Vec = args.iter().map(|arg| { - match arg { + let args_to_fp: Vec = args + .iter() + .map(|arg| match arg { Expr::Access(a) => elem.tuple[a].as_value(), Expr::Constant(c) => c.clone(), _ => panic!("Foreign predicate join only supports constant and access arguments"), - } - }).collect(); + }) + .collect(); // Then evaluate the foreign predicate on these arguments - let outputs: Vec<_> = fp.evaluate(&args_to_fp).into_iter().map(|(tag, values)| { - // Make sure to tag the output elements - let input_tag = Prov::InputTag::from_dynamic_input_tag(&tag); - let new_tag = ctx.tagging_optional_fn(input_tag); - - // Generate a tuple from the values produced by the foreign predicate - let tuple = Tuple::from(values); - - // Generate the output element - DynamicElement::new(tuple, new_tag) - }).collect(); + let outputs: Vec<_> = fp + .evaluate_with_env(env, &args_to_fp) + .into_iter() + .map(|(tag, values)| { + // Make sure to tag the output elements + let input_tag = Prov::InputTag::from_dynamic_input_tag(&tag); + let new_tag 
= ctx.tagging_optional_fn(input_tag); + + // Generate a tuple from the values produced by the foreign predicate + let tuple = Tuple::from(values); + + // Generate the output element + DynamicElement::new(tuple, new_tag) + }) + .collect(); // Return the input element and output elements pair (elem, outputs.into_iter()) diff --git a/core/src/runtime/dynamic/dataflow/mod.rs b/core/src/runtime/dynamic/dataflow/mod.rs index 91cd7ec..d15839a 100644 --- a/core/src/runtime/dynamic/dataflow/mod.rs +++ b/core/src/runtime/dynamic/dataflow/mod.rs @@ -5,8 +5,8 @@ mod utils; mod antijoin; mod difference; mod dynamic_collection; -mod dynamic_exclusion; mod dynamic_dataflow; +mod dynamic_exclusion; mod dynamic_relation; mod filter; mod find; diff --git a/core/src/runtime/dynamic/dataflow/untagged_vec.rs b/core/src/runtime/dynamic/dataflow/untagged_vec.rs index bafdc69..7a538c4 100644 --- a/core/src/runtime/dynamic/dataflow/untagged_vec.rs +++ b/core/src/runtime/dynamic/dataflow/untagged_vec.rs @@ -7,16 +7,16 @@ use super::*; #[derive(Clone)] pub struct DynamicUntaggedVec<'a, Prov: Provenance> { pub ctx: &'a Prov, - pub tuples: &'a Vec, + pub tuples: Vec, } impl<'a, Prov: Provenance> DynamicUntaggedVec<'a, Prov> { - pub fn new(ctx: &'a Prov, tuples: &'a Vec) -> Self { + pub fn new(ctx: &'a Prov, tuples: Vec) -> Self { Self { ctx, tuples } } pub fn iter_recent(&self, _: &RuntimeEnvironment) -> DynamicBatches<'a, Prov> { - DynamicBatches::single(DynamicBatch::untagged_vec(self.ctx, self.tuples.iter())) + DynamicBatches::single(DynamicBatch::untagged_vec(self.ctx, self.tuples.clone().into_iter())) } pub fn iter_stable(&self, _: &RuntimeEnvironment) -> DynamicBatches<'a, Prov> { diff --git a/core/src/runtime/dynamic/incremental.rs b/core/src/runtime/dynamic/incremental.rs index f3153b4..298189d 100644 --- a/core/src/runtime/dynamic/incremental.rs +++ b/core/src/runtime/dynamic/incremental.rs @@ -102,7 +102,7 @@ impl DynamicExecutionContext { &mut self, program: ram::Program, runtime: 
&RuntimeEnvironment, - ctx: &mut Prov, + ctx: &Prov, ) -> Result<(), RuntimeError> { self.incremental_execute_helper(Some(program), runtime, ctx) } @@ -111,7 +111,7 @@ impl DynamicExecutionContext { &mut self, maybe_new_program: Option, runtime: &RuntimeEnvironment, - ctx: &mut Prov, + ctx: &Prov, ) -> Result<(), RuntimeError> { // Pull the IDB let mut incremental_result = IntentionalDatabase::default(); @@ -122,7 +122,7 @@ impl DynamicExecutionContext { std::mem::swap(&mut self.program, &mut temp_program); let program_ref = if let Some(new_program) = &maybe_new_program { // Process the EDB; populate using program facts - self.edb.populate_program_facts(new_program)?; + self.edb.populate_program_facts(runtime, new_program)?; // If need to incrementalize, remove such computed results let edb_need_update_relations = self.edb.need_update_relations(); @@ -142,17 +142,14 @@ impl DynamicExecutionContext { // Return this new program &new_program } else { - self - .edb - .populate_program_facts(&temp_program) - .expect("Since there is no new program, no error should be raised during program facts population"); + self.edb.populate_program_facts(runtime, &temp_program)?; // If there is no new program, we directly take our current program &temp_program }; // Internalize EDB relations - self.edb.internalize(ctx); + self.edb.internalize(runtime, ctx); // Generate stratum information let strata_info = stratum_inputs_outputs(program_ref); @@ -224,7 +221,7 @@ impl DynamicExecutionContext { ram_program: &ram::Program, strata_info: &StrataInformation, runtime: &RuntimeEnvironment, - ctx: &mut Prov, + ctx: &Prov, ) -> Result, RuntimeError> { let dyn_relas = stratum .relations @@ -285,7 +282,10 @@ impl DynamicExecutionContext { // Check if we need it to be output if self.options.incremental_maintain - || strata_info.stratum_outputs[&stratum_id].contains(rela) + || strata_info + .stratum_outputs + .get(&stratum_id) + .map_or(false, |o| o.contains(rela)) || 
ram_program.relation_unchecked(rela).output.is_not_hidden() { iter.add_output_relation(rela); @@ -316,12 +316,7 @@ impl DynamicExecutionContext { } /// Directly execute the program stored in the file - pub fn execute_with_monitor( - &mut self, - runtime: &RuntimeEnvironment, - ctx: &mut Prov, - m: &M, - ) -> Result<(), RuntimeError> + pub fn execute_with_monitor(&mut self, runtime: &RuntimeEnvironment, ctx: &Prov, m: &M) -> Result<(), RuntimeError> where M: Monitor, { @@ -332,7 +327,7 @@ impl DynamicExecutionContext { &mut self, program: ram::Program, runtime: &RuntimeEnvironment, - ctx: &mut Prov, + ctx: &Prov, m: &M, ) -> Result<(), RuntimeError> where @@ -345,7 +340,7 @@ impl DynamicExecutionContext { &mut self, maybe_new_program: Option, runtime: &RuntimeEnvironment, - ctx: &mut Prov, + ctx: &Prov, m: &M, ) -> Result<(), RuntimeError> where @@ -360,7 +355,7 @@ impl DynamicExecutionContext { std::mem::swap(&mut self.program, &mut temp_program); let program_ref = if let Some(new_program) = &maybe_new_program { // Process the EDB; populate using program facts - self.edb.populate_program_facts(new_program)?; + self.edb.populate_program_facts(runtime, new_program)?; // If need to incrementalize, remove such computed results let edb_need_update_relations = self.edb.need_update_relations(); @@ -382,7 +377,7 @@ impl DynamicExecutionContext { } else { self .edb - .populate_program_facts(&temp_program) + .populate_program_facts(runtime, &temp_program) .expect("Since there is no new program, no error should be raised during program facts population"); // If there is no new program, we directly take our current program @@ -391,7 +386,7 @@ impl DynamicExecutionContext { // Internalize EDB relations // !SPECIAL MONITORING! 
- self.edb.internalize_with_monitor(ctx, m); + self.edb.internalize_with_monitor(runtime, ctx, m); // Generate stratum information let strata_info = stratum_inputs_outputs(program_ref); @@ -472,7 +467,7 @@ impl DynamicExecutionContext { ram_program: &ram::Program, strata_info: &StrataInformation, runtime: &RuntimeEnvironment, - ctx: &mut Prov, + ctx: &Prov, m: &M, ) -> Result, RuntimeError> where @@ -613,18 +608,22 @@ impl DynamicExecutionContext { self.idb.get_internal_collection(r) } - pub fn recover(&mut self, r: &str, ctx: &Prov) { + pub fn recover(&mut self, r: &str, runtime: &RuntimeEnvironment, ctx: &Prov) { if self.idb.has_relation(r) { - self.idb.recover(r, ctx, !self.options.retain_internal_when_recover); + self + .idb + .recover(r, runtime, ctx, !self.options.retain_internal_when_recover); } else if self.edb.has_relation(r) { - self.idb.recover_from_edb(r, ctx, &self.edb.extensional_relations[r]); + self + .idb + .recover_from_edb(r, runtime, ctx, &self.edb.extensional_relations[r]); } } - pub fn recover_with_monitor>(&mut self, r: &str, ctx: &Prov, m: &M) { + pub fn recover_with_monitor>(&mut self, r: &str, runtime: &RuntimeEnvironment, ctx: &Prov, m: &M) { self .idb - .recover_with_monitor(r, ctx, m, !self.options.retain_internal_when_recover) + .recover_with_monitor(r, runtime, ctx, m, !self.options.retain_internal_when_recover) } pub fn relation_ref(&self, r: &str) -> Option<&DynamicOutputCollection> { diff --git a/core/src/runtime/dynamic/iteration.rs b/core/src/runtime/dynamic/iteration.rs index de2a212..55cb99f 100644 --- a/core/src/runtime/dynamic/iteration.rs +++ b/core/src/runtime/dynamic/iteration.rs @@ -77,12 +77,15 @@ impl<'a, Prov: Provenance> DynamicIteration<'a, Prov> { while self.need_to_iterate(ctx, &runtime.iter_limit) { // Perform updates for update in &self.updates { - let dyn_update = self.build_dynamic_update(ctx, update); + let dyn_update = self.build_dynamic_update(runtime, ctx, update); dyn_update .target 
.insert_dataflow_recent(ctx, &dyn_update.dataflow, runtime); } + // Drain from dynamically generated entities in the runtime + self.drain_dynamic_entities(ctx, runtime); + // Update iteration number self.step(); } @@ -96,6 +99,19 @@ impl<'a, Prov: Provenance> DynamicIteration<'a, Prov> { result } + fn drain_dynamic_entities(&mut self, ctx: &Prov, runtime: &RuntimeEnvironment) { + for (relation, tuples) in runtime.drain_new_entities() { + let update = Update { + target: relation, + dataflow: Dataflow::UntaggedVec(tuples), + }; + let dyn_update = self.build_dynamic_update(runtime, ctx, &update); + dyn_update + .target + .insert_dataflow_recent(ctx, &dyn_update.dataflow, runtime); + } + } + fn need_to_iterate(&mut self, ctx: &Prov, iter_limit: &Option) -> bool { // Check if it has been changed if self.changed(ctx) || self.is_first_iteration() { @@ -138,7 +154,7 @@ impl<'a, Prov: Provenance> DynamicIteration<'a, Prov> { // Perform updates for update in &self.updates { - let dyn_update = self.build_dynamic_update(ctx, update); + let dyn_update = self.build_dynamic_update(runtime, ctx, update); dyn_update .target .insert_dataflow_recent(ctx, &dyn_update.dataflow, runtime); @@ -218,14 +234,24 @@ impl<'a, Prov: Provenance> DynamicIteration<'a, Prov> { } } - fn build_dynamic_update(&'a self, ctx: &'a Prov, update: &'a Update) -> DynamicUpdate<'a, Prov> { + fn build_dynamic_update( + &'a self, + env: &RuntimeEnvironment, + ctx: &'a Prov, + update: &'a Update, + ) -> DynamicUpdate<'a, Prov> { DynamicUpdate { target: self.unsafe_get_dynamic_relation(&update.target), - dataflow: self.build_dynamic_dataflow(ctx, &update.dataflow), + dataflow: self.build_dynamic_dataflow(env, ctx, &update.dataflow), } } - fn build_dynamic_dataflow(&'a self, ctx: &'a Prov, dataflow: &'a Dataflow) -> DynamicDataflow<'a, Prov> { + fn build_dynamic_dataflow( + &'a self, + env: &RuntimeEnvironment, + ctx: &'a Prov, + dataflow: &'a Dataflow, + ) -> DynamicDataflow<'a, Prov> { match dataflow { 
Dataflow::Unit(t) => { if self.is_first_iteration() { @@ -235,7 +261,8 @@ impl<'a, Prov: Provenance> DynamicIteration<'a, Prov> { } } Dataflow::UntaggedVec(v) => { - DynamicDataflow::untagged_vec(ctx, v) + let internal_tuple = v.iter().map(|t| env.internalize_tuple(t)).collect(); + DynamicDataflow::untagged_vec(ctx, internal_tuple) } Dataflow::Relation(c) => { if self.input_dynamic_collections.contains_key(c) { @@ -245,47 +272,65 @@ impl<'a, Prov: Provenance> DynamicIteration<'a, Prov> { } } Dataflow::ForeignPredicateGround(p, a) => { - DynamicDataflow::foreign_predicate_ground(p.clone(), a.clone(), self.is_first_iteration(), ctx) + let internal_values = a.iter().map(|v| env.internalize_value(v)).collect(); + DynamicDataflow::foreign_predicate_ground(p.clone(), internal_values, self.is_first_iteration(), ctx) } Dataflow::ForeignPredicateConstraint(d, p, a) => { - self.build_dynamic_dataflow(ctx, d).foreign_predicate_constraint(p.clone(), a.clone(), ctx) + // NOTE: `a` contains accessors which do not need to be internalized + self + .build_dynamic_dataflow(env, ctx, d) + .foreign_predicate_constraint(p.clone(), a.clone(), ctx) } Dataflow::ForeignPredicateJoin(d, p, a) => { - self.build_dynamic_dataflow(ctx, d).foreign_predicate_join(p.clone(), a.clone(), ctx) + // NOTE: `a` contains accessors which do not need to be internalized + self + .build_dynamic_dataflow(env, ctx, d) + .foreign_predicate_join(p.clone(), a.clone(), ctx) + } + Dataflow::OverwriteOne(d) => self.build_dynamic_dataflow(env, ctx, d).overwrite_one(ctx), + Dataflow::Exclusion(d1, d2) => self + .build_dynamic_dataflow(env, ctx, d1) + .dynamic_exclusion(self.build_dynamic_dataflow(env, ctx, d2), ctx), + Dataflow::Filter(d, e) => { + let internal_filter = env.internalize_expr(e); + self.build_dynamic_dataflow(env, ctx, d).filter(internal_filter) + } + Dataflow::Find(d, k) => { + let internal_key = env.internalize_tuple(k); + self.build_dynamic_dataflow(env, ctx, d).find(internal_key) + } + 
Dataflow::Project(d, e) => { + let internal_expr = env.internalize_expr(e); + self.build_dynamic_dataflow(env, ctx, d).project(internal_expr) } - Dataflow::OverwriteOne(d) => self.build_dynamic_dataflow(ctx, d).overwrite_one(ctx), - Dataflow::Exclusion(d1, d2) => self.build_dynamic_dataflow(ctx, d1).dynamic_exclusion(self.build_dynamic_dataflow(ctx, d2), ctx), - Dataflow::Filter(d, e) => self.build_dynamic_dataflow(ctx, d).filter(e.clone()), - Dataflow::Find(d, k) => self.build_dynamic_dataflow(ctx, d).find(k.clone()), - Dataflow::Project(d, e) => self.build_dynamic_dataflow(ctx, d).project(e.clone()), Dataflow::Intersect(d1, d2) => { - let r1 = self.build_dynamic_dataflow(ctx, d1); - let r2 = self.build_dynamic_dataflow(ctx, d2); + let r1 = self.build_dynamic_dataflow(env, ctx, d1); + let r2 = self.build_dynamic_dataflow(env, ctx, d2); r1.intersect(r2, ctx) } Dataflow::Join(d1, d2) => { - let r1 = self.build_dynamic_dataflow(ctx, d1); - let r2 = self.build_dynamic_dataflow(ctx, d2); + let r1 = self.build_dynamic_dataflow(env, ctx, d1); + let r2 = self.build_dynamic_dataflow(env, ctx, d2); r1.join(r2, ctx) } Dataflow::Product(d1, d2) => { - let r1 = self.build_dynamic_dataflow(ctx, d1); - let r2 = self.build_dynamic_dataflow(ctx, d2); + let r1 = self.build_dynamic_dataflow(env, ctx, d1); + let r2 = self.build_dynamic_dataflow(env, ctx, d2); r1.product(r2, ctx) } Dataflow::Union(d1, d2) => { - let r1 = self.build_dynamic_dataflow(ctx, d1); - let r2 = self.build_dynamic_dataflow(ctx, d2); + let r1 = self.build_dynamic_dataflow(env, ctx, d1); + let r2 = self.build_dynamic_dataflow(env, ctx, d2); r1.union(r2) } Dataflow::Difference(d1, d2) => { - let r1 = self.build_dynamic_dataflow(ctx, d1); - let r2 = self.build_dynamic_dataflow(ctx, d2); + let r1 = self.build_dynamic_dataflow(env, ctx, d1); + let r2 = self.build_dynamic_dataflow(env, ctx, d2); r1.difference(r2, ctx) } Dataflow::Antijoin(d1, d2) => { - let r1 = self.build_dynamic_dataflow(ctx, d1); - let r2 = 
self.build_dynamic_dataflow(ctx, d2); + let r1 = self.build_dynamic_dataflow(env, ctx, d1); + let r2 = self.build_dynamic_dataflow(env, ctx, d2); r1.antijoin(r2, ctx) } Dataflow::Reduce(a) => { diff --git a/core/src/runtime/dynamic/mod.rs b/core/src/runtime/dynamic/mod.rs index 31d3a7a..0123942 100644 --- a/core/src/runtime/dynamic/mod.rs +++ b/core/src/runtime/dynamic/mod.rs @@ -3,7 +3,6 @@ mod collection; pub mod dataflow; mod element; mod incremental; -pub mod io; mod iteration; mod output_collection; mod relation; diff --git a/core/src/runtime/env/environment.rs b/core/src/runtime/env/environment.rs index cbd225a..0f276da 100644 --- a/core/src/runtime/env/environment.rs +++ b/core/src/runtime/env/environment.rs @@ -1,23 +1,25 @@ -use std::sync::*; - -use rand::rngs::SmallRng; -use rand::SeedableRng; +use std::collections::*; use crate::common::constants::*; +use crate::common::entity; use crate::common::expr::*; use crate::common::foreign_function::*; use crate::common::foreign_predicate::*; +use crate::common::tensors; use crate::common::tuple::*; +use crate::common::value::*; use crate::common::value_type::*; use crate::utils::*; -#[derive(Clone, Debug)] +use super::*; + +#[derive(Clone)] pub struct RuntimeEnvironment { /// Random seed for reference pub random_seed: u64, /// Random number generater initialized from the random seed - pub rng: Arc>, + pub random: Random, /// Whether we want to early discard 0-tagged facts pub early_discard: bool, @@ -32,7 +34,16 @@ pub struct RuntimeEnvironment { pub predicate_registry: ForeignPredicateRegistry, /// Mutual exclusion ID allocator - pub exclusion_id_allocator: Arc>, + pub exclusion_id_allocator: IdAllocator2, + + /// Symbol registry + pub symbol_registry: SymbolRegistry2, + + /// New Entities + pub new_entities: NewEntitiesStorage2, + + /// Tensor registry + pub tensor_registry: TensorRegistry2, } impl Default for RuntimeEnvironment { @@ -45,51 +56,75 @@ impl RuntimeEnvironment { pub fn new_std() -> Self { Self { 
random_seed: DEFAULT_RANDOM_SEED, - rng: Arc::new(Mutex::new(SmallRng::seed_from_u64(DEFAULT_RANDOM_SEED))), + random: Random::new(DEFAULT_RANDOM_SEED), early_discard: true, iter_limit: None, function_registry: ForeignFunctionRegistry::std(), predicate_registry: ForeignPredicateRegistry::std(), - exclusion_id_allocator: Arc::new(Mutex::new(IdAllocator::new())), + exclusion_id_allocator: IdAllocator2::new(), + symbol_registry: SymbolRegistry2::new(), + new_entities: NewEntitiesStorage2::new(), + tensor_registry: TensorRegistry2::new(), } } pub fn new_with_random_seed(seed: u64) -> Self { Self { random_seed: seed, - rng: Arc::new(Mutex::new(SmallRng::seed_from_u64(seed))), + random: Random::new(seed), early_discard: true, iter_limit: None, function_registry: ForeignFunctionRegistry::std(), predicate_registry: ForeignPredicateRegistry::std(), - exclusion_id_allocator: Arc::new(Mutex::new(IdAllocator::new())), + exclusion_id_allocator: IdAllocator2::new(), + symbol_registry: SymbolRegistry2::new(), + new_entities: NewEntitiesStorage2::new(), + tensor_registry: TensorRegistry2::new(), } } - pub fn new( - ffr: ForeignFunctionRegistry, - fpr: ForeignPredicateRegistry, - ) -> Self { + pub fn new(ffr: ForeignFunctionRegistry, fpr: ForeignPredicateRegistry) -> Self { Self { random_seed: DEFAULT_RANDOM_SEED, - rng: Arc::new(Mutex::new(SmallRng::seed_from_u64(DEFAULT_RANDOM_SEED))), + random: Random::new(DEFAULT_RANDOM_SEED), early_discard: true, iter_limit: None, function_registry: ffr, predicate_registry: fpr, - exclusion_id_allocator: Arc::new(Mutex::new(IdAllocator::new())), + exclusion_id_allocator: IdAllocator2::new(), + symbol_registry: SymbolRegistry2::new(), + new_entities: NewEntitiesStorage2::new(), + tensor_registry: TensorRegistry2::new(), } } pub fn new_with_function_registry(ffr: ForeignFunctionRegistry) -> Self { Self { random_seed: DEFAULT_RANDOM_SEED, - rng: Arc::new(Mutex::new(SmallRng::seed_from_u64(DEFAULT_RANDOM_SEED))), + random: 
Random::new(DEFAULT_RANDOM_SEED), early_discard: true, iter_limit: None, function_registry: ffr, predicate_registry: ForeignPredicateRegistry::std(), - exclusion_id_allocator: Arc::new(Mutex::new(IdAllocator::new())), + exclusion_id_allocator: IdAllocator2::new(), + symbol_registry: SymbolRegistry2::new(), + new_entities: NewEntitiesStorage2::new(), + tensor_registry: TensorRegistry2::new(), + } + } + + pub fn new_from_options(options: RuntimeEnvironmentOptions) -> Self { + Self { + random_seed: options.random_seed, + random: Random::new(options.random_seed), + early_discard: options.early_discard, + iter_limit: options.iter_limit, + function_registry: ForeignFunctionRegistry::std(), + predicate_registry: ForeignPredicateRegistry::std(), + exclusion_id_allocator: IdAllocator2::new(), + symbol_registry: SymbolRegistry2::new(), + new_entities: NewEntitiesStorage2::new(), + tensor_registry: TensorRegistry2::new(), } } @@ -106,7 +141,80 @@ impl RuntimeEnvironment { } pub fn allocate_new_exclusion_id(&self) -> usize { - self.exclusion_id_allocator.lock().unwrap().alloc() + self.exclusion_id_allocator.alloc() + } + + pub fn internalize_tuple(&self, tup: &Tuple) -> Tuple { + match tup { + Tuple::Tuple(ts) => Tuple::Tuple(ts.iter().map(|t| self.internalize_tuple(t)).collect()), + Tuple::Value(v) => Tuple::Value(self.internalize_value(v)), + } + } + + pub fn internalize_value(&self, val: &Value) -> Value { + match val { + Value::SymbolString(s) => { + let symbol_id = self.symbol_registry.register(s.clone()); + Value::Symbol(symbol_id) + } + Value::Tensor(t) => { + let tensor_symbol = self.tensor_registry.register(t.clone()); + Value::TensorValue(tensor_symbol.into()) + } + other => other.clone(), + } + } + + pub fn internalize_expr(&self, expr: &Expr) -> Expr { + match expr { + Expr::Access(a) => Expr::Access(a.clone()), + Expr::Tuple(t) => Expr::Tuple(t.iter().map(|e| self.internalize_expr(e)).collect()), + Expr::Binary(b) => Expr::binary( + b.op.clone(), + 
self.internalize_expr(&b.op1), + self.internalize_expr(&b.op2), + ), + Expr::Unary(u) => Expr::unary(u.op.clone(), self.internalize_expr(&u.op1)), + Expr::Call(c) => Expr::call( + c.function.clone(), + c.args.iter().map(|e| self.internalize_expr(e)).collect(), + ), + Expr::Constant(c) => Expr::Constant(self.internalize_value(c)), + Expr::IfThenElse(ite) => Expr::ite( + self.internalize_expr(&ite.cond), + self.internalize_expr(&ite.then_br), + self.internalize_expr(&ite.else_br), + ), + Expr::New(n) => Expr::new( + n.functor.clone(), + n.args.iter().map(|e| self.internalize_expr(e)).collect(), + ), + } + } + + pub fn externalize_tuple(&self, tup: &Tuple) -> Tuple { + match tup { + Tuple::Tuple(ts) => Tuple::Tuple(ts.iter().map(|t| self.externalize_tuple(t)).collect()), + Tuple::Value(v) => Tuple::Value(self.externalize_value(v)), + } + } + + pub fn externalize_value(&self, val: &Value) -> Value { + match val { + Value::Symbol(s) => { + let symbol = self.symbol_registry.get_symbol(*s).expect("Cannot find symbol"); + Value::SymbolString(symbol) + } + Value::TensorValue(t) => { + let tensor = self.tensor_registry.eval(t); + Value::Tensor(tensor) + } + other => other.clone(), + } + } + + pub fn drain_new_entities(&self) -> HashMap> { + self.new_entities.drain_entities() } pub fn eval(&self, expr: &Expr, tuple: &Tuple) -> Option { @@ -120,12 +228,13 @@ impl RuntimeEnvironment { Expr::Unary(u) => self.eval_unary(u, tuple), Expr::IfThenElse(i) => self.eval_if_then_else(i, tuple), Expr::Call(c) => self.eval_call(c, tuple), + Expr::New(n) => self.eval_new(n, tuple), } } pub fn eval_binary(&self, expr: &BinaryExpr, v: &Tuple) -> Option { use crate::common::binary_op::BinaryOp::*; - use crate::common::value::Value::*; + use Value::*; // Recursively evaluate sub-expressions let lhs_v = self.eval(&expr.op1, v)?; @@ -152,6 +261,15 @@ impl RuntimeEnvironment { (Add, Tuple::Value(DateTime(i1)), Tuple::Value(Duration(i2))) => Tuple::Value(DateTime(i1 + i2)), (Add, 
Tuple::Value(Duration(i1)), Tuple::Value(DateTime(i2))) => Tuple::Value(DateTime(i2 + i1)), (Add, Tuple::Value(Duration(i1)), Tuple::Value(Duration(i2))) => Tuple::Value(Duration(i1 + i2)), + (Add, Tuple::Value(TensorValue(v1)), Tuple::Value(TensorValue(v2))) => { + v1.add(v2).map(TensorValue).map(Tuple::Value)? + } + (Add, Tuple::Value(TensorValue(v1)), Tuple::Value(F64(f2))) => { + v1.add(f2.into()).map(TensorValue).map(Tuple::Value)? + } + (Add, Tuple::Value(F64(f1)), Tuple::Value(TensorValue(v2))) => { + v2.add(f1.into()).map(TensorValue).map(Tuple::Value)? + } (Add, b1, b2) => panic!("Cannot perform ADD on {:?} and {:?}", b1, b2), // Subtraction @@ -171,7 +289,17 @@ impl RuntimeEnvironment { (Sub, Tuple::Value(F64(i1)), Tuple::Value(F64(i2))) => Tuple::Value(F64(i1 - i2)), (Sub, Tuple::Value(DateTime(i1)), Tuple::Value(Duration(i2))) => Tuple::Value(DateTime(i1 - i2)), (Sub, Tuple::Value(DateTime(i1)), Tuple::Value(DateTime(i2))) => Tuple::Value(Duration(i1 - i2)), - (Sub, Tuple::Value(Duration(i1)), Tuple::Value(Duration(i2))) =>Tuple::Value(Duration(i1 - i2)), + (Sub, Tuple::Value(Duration(i1)), Tuple::Value(Duration(i2))) => Tuple::Value(Duration(i1 - i2)), + (Sub, Tuple::Value(TensorValue(v1)), Tuple::Value(TensorValue(v2))) => { + v1.sub(v2).map(TensorValue).map(Tuple::Value)? + } + (Sub, Tuple::Value(TensorValue(v1)), Tuple::Value(F64(f2))) => { + v1.sub(f2.into()).map(TensorValue).map(Tuple::Value)? 
+ } + (Sub, Tuple::Value(F64(f1)), Tuple::Value(TensorValue(v2))) => tensors::TensorValue::from(f1) + .sub(v2) + .map(TensorValue) + .map(Tuple::Value)?, (Sub, b1, b2) => panic!("Cannot perform SUB on {:?} and {:?}", b1, b2), // Multiplication @@ -191,6 +319,16 @@ impl RuntimeEnvironment { (Mul, Tuple::Value(F64(i1)), Tuple::Value(F64(i2))) => Tuple::Value(F64(i1 * i2)), (Mul, Tuple::Value(Duration(i1)), Tuple::Value(I32(i2))) => Tuple::Value(Duration(i1 * i2)), (Mul, Tuple::Value(I32(i1)), Tuple::Value(Duration(i2))) => Tuple::Value(Duration(i2 * i1)), + (Mul, Tuple::Value(TensorValue(v1)), Tuple::Value(TensorValue(v2))) => { + v1.mul(v2).map(TensorValue).map(Tuple::Value)? + } + (Mul, Tuple::Value(TensorValue(v1)), Tuple::Value(F64(f2))) => { + v1.mul(f2.into()).map(TensorValue).map(Tuple::Value)? + } + (Mul, Tuple::Value(F64(f1)), Tuple::Value(TensorValue(v2))) => tensors::TensorValue::from(f1) + .mul(v2) + .map(TensorValue) + .map(Tuple::Value)?, (Mul, b1, b2) => panic!("Cannot perform MUL on {:?} and {:?}", b1, b2), // Division @@ -213,7 +351,7 @@ impl RuntimeEnvironment { } else { Tuple::Value(F32(r)) } - }, + } (Div, Tuple::Value(F64(i1)), Tuple::Value(F64(i2))) => { let r = i1 / i2; if r.is_nan() { @@ -221,7 +359,7 @@ impl RuntimeEnvironment { } else { Tuple::Value(F64(r)) } - }, + } (Div, Tuple::Value(Duration(i1)), Tuple::Value(I32(i2))) => Tuple::Value(Duration(i1 / i2)), (Div, b1, b2) => panic!("Cannot perform DIV on {:?} and {:?}", b1, b2), @@ -267,9 +405,10 @@ impl RuntimeEnvironment { (Eq, Tuple::Value(Bool(i1)), Tuple::Value(Bool(i2))) => Tuple::Value(Bool(i1 == i2)), (Eq, Tuple::Value(Str(i1)), Tuple::Value(Str(i2))) => Tuple::Value(Bool(i1 == i2)), (Eq, Tuple::Value(String(i1)), Tuple::Value(String(i2))) => Tuple::Value(Bool(i1 == i2)), - // (Eq, Tuple::Value(RcString(i1)), Tuple::Value(RcString(i2))) => Tuple::Value(Bool(i1 == i2)), + (Eq, Tuple::Value(Symbol(i1)), Tuple::Value(Symbol(i2))) => Tuple::Value(Bool(i1 == i2)), (Eq, 
Tuple::Value(DateTime(i1)), Tuple::Value(DateTime(i2))) => Tuple::Value(Bool(i1 == i2)), (Eq, Tuple::Value(Duration(i1)), Tuple::Value(Duration(i2))) => Tuple::Value(Bool(i1 == i2)), + (Eq, Tuple::Value(Entity(i1)), Tuple::Value(Entity(i2))) => Tuple::Value(Bool(i1 == i2)), (Eq, b1, b2) => panic!("Cannot perform EQ on {:?} and {:?}", b1, b2), // Not equal to @@ -291,9 +430,10 @@ impl RuntimeEnvironment { (Neq, Tuple::Value(Bool(i1)), Tuple::Value(Bool(i2))) => Tuple::Value(Bool(i1 != i2)), (Neq, Tuple::Value(Str(i1)), Tuple::Value(Str(i2))) => Tuple::Value(Bool(i1 != i2)), (Neq, Tuple::Value(String(i1)), Tuple::Value(String(i2))) => Tuple::Value(Bool(i1 != i2)), - // (Neq, Tuple::Value(RcString(i1)), Tuple::Value(RcString(i2))) => Tuple::Value(Bool(i1 != i2)), + (Neq, Tuple::Value(Symbol(i1)), Tuple::Value(Symbol(i2))) => Tuple::Value(Bool(i1 != i2)), (Neq, Tuple::Value(DateTime(i1)), Tuple::Value(DateTime(i2))) => Tuple::Value(Bool(i1 != i2)), (Neq, Tuple::Value(Duration(i1)), Tuple::Value(Duration(i2))) => Tuple::Value(Bool(i1 != i2)), + (Neq, Tuple::Value(Entity(i1)), Tuple::Value(Entity(i2))) => Tuple::Value(Bool(i1 != i2)), (Neq, b1, b2) => panic!("Cannot perform NEQ on {:?} and {:?}", b1, b2), // Greater than @@ -460,6 +600,8 @@ impl RuntimeEnvironment { (Tuple::Value(Char(s)), T::U64) => s.to_digit(10).map(|i| Tuple::Value(U64(i as u64))), (Tuple::Value(Char(s)), T::U128) => s.to_digit(10).map(|i| Tuple::Value(U128(i as u128))), (Tuple::Value(Char(s)), T::USize) => s.to_digit(10).map(|i| Tuple::Value(USize(i as usize))), + (Tuple::Value(Char(s)), T::F32) => s.to_string().parse().ok().map(|f| Tuple::Value(F32(f))), + (Tuple::Value(Char(s)), T::F64) => s.to_string().parse().ok().map(|f| Tuple::Value(F64(f))), (Tuple::Value(String(s)), T::I8) => s.parse().ok().map(|i| Tuple::Value(I8(i))), (Tuple::Value(String(s)), T::I16) => s.parse().ok().map(|i| Tuple::Value(I16(i))), @@ -494,6 +636,12 @@ impl RuntimeEnvironment { (Tuple::Value(Char(c)), T::String) => 
Some(Tuple::Value(String(c.to_string()))), (Tuple::Value(Str(s)), T::String) => Some(Tuple::Value(String(s.to_string()))), (Tuple::Value(String(s)), T::String) => Some(Tuple::Value(String(s.clone()))), + (Tuple::Value(Symbol(id)), T::String) => Some(Tuple::Value(String( + self + .symbol_registry + .get_symbol(id) + .expect("[Internal Error] Cannot find symbol"), + ))), // Not implemented (v, t) => unimplemented!("Unimplemented type cast from `{:?}` to `{}`", v.tuple_type(), t), @@ -521,10 +669,30 @@ impl RuntimeEnvironment { .collect::>>()?; // Run the function - let result = f.execute(args)?; + let result = f.execute_with_env(self, args)?; // Turn result into tuple Some(Tuple::Value(result)) }) } + + pub fn eval_new(&self, expr: &NewExpr, v: &Tuple) -> Option { + // Evaluate the arguments + let args = expr + .args + .iter() + .map(|a| self.eval(a, v).map(|t| t.as_value())) + .collect::>>()?; + + // Hash all the arguments to get a new entity + let raw_id = entity::encode_entity(&expr.functor, args.iter()); + let id = Value::Entity(raw_id); + + // Combine them to form a tuple for later insertion to new entities list + let tuple = Tuple::from_values(args.into_iter()); + self.new_entities.add(&expr.functor, id.clone(), tuple); + + // Return the value + Some(Tuple::Value(id)) + } } diff --git a/core/src/runtime/env/mod.rs b/core/src/runtime/env/mod.rs index da3cfdd..ed3f4cc 100644 --- a/core/src/runtime/env/mod.rs +++ b/core/src/runtime/env/mod.rs @@ -1,5 +1,13 @@ mod environment; +mod new_entities; mod options; +mod random; +mod symbol_registry; +mod tensor_registry; pub use environment::*; +pub use new_entities::*; pub use options::*; +pub use random::*; +pub use symbol_registry::*; +pub use tensor_registry::*; diff --git a/core/src/runtime/env/new_entities.rs b/core/src/runtime/env/new_entities.rs new file mode 100644 index 0000000..7610425 --- /dev/null +++ b/core/src/runtime/env/new_entities.rs @@ -0,0 +1,79 @@ +use std::collections::*; +use std::sync::*; + +use 
crate::common::tuple::*; +use crate::common::value::*; + +#[derive(Clone, Debug)] +pub struct NewEntitiesStorage { + pub entities: HashMap>>, +} + +impl NewEntitiesStorage { + pub fn new() -> Self { + Self { + entities: HashMap::new(), + } + } + + pub fn add(&mut self, functor: &str, value: Value, tuple: Tuple) { + self + .entities + .entry(functor.to_string()) + .or_default() + .entry(value) + .or_default() + .push(tuple); + } + + pub fn drain_entities(&mut self) -> HashMap> { + // Create an empty dictionary and swap it with the internal one + let mut entities = HashMap::new(); + std::mem::swap(&mut self.entities, &mut entities); + + // Post-process the entities into vector of tuples + entities + .into_iter() + .map(|(relation_name, id_tuple_map)| { + let tuples = id_tuple_map + .into_iter() + .flat_map(|(id, tuples)| { + tuples + .into_iter() + .map(move |tuple| Tuple::from_values(std::iter::once(id.clone()).chain(tuple.as_values()))) + }) + .collect::>(); + (relation_name, tuples) + }) + .collect() + } +} + +#[derive(Debug)] +pub struct NewEntitiesStorage2 { + storage: Mutex, +} + +impl Clone for NewEntitiesStorage2 { + fn clone(&self) -> Self { + Self { + storage: Mutex::new(self.storage.lock().unwrap().clone()), + } + } +} + +impl NewEntitiesStorage2 { + pub fn new() -> Self { + Self { + storage: Mutex::new(NewEntitiesStorage::new()), + } + } + + pub fn add(&self, functor: &str, id: Value, tuple: Tuple) { + self.storage.lock().unwrap().add(functor, id, tuple) + } + + pub fn drain_entities(&self) -> HashMap> { + self.storage.lock().unwrap().drain_entities() + } +} diff --git a/core/src/runtime/env/options.rs b/core/src/runtime/env/options.rs index cb848c4..486c0e6 100644 --- a/core/src/runtime/env/options.rs +++ b/core/src/runtime/env/options.rs @@ -1,12 +1,4 @@ -use std::sync::*; - -use rand::rngs::SmallRng; -use rand::SeedableRng; - use crate::common::constants::*; -use crate::common::foreign_function::*; -use crate::common::foreign_predicate::*; -use 
crate::utils::*; use super::*; @@ -33,17 +25,7 @@ impl RuntimeEnvironmentOptions { } } - /// Build a runtime environment from this options pub fn build(self) -> RuntimeEnvironment { - let rng = SmallRng::seed_from_u64(self.random_seed); - RuntimeEnvironment { - random_seed: self.random_seed, - rng: Arc::new(Mutex::new(rng)), - early_discard: self.early_discard, - iter_limit: self.iter_limit, - function_registry: ForeignFunctionRegistry::std(), - predicate_registry: ForeignPredicateRegistry::std(), - exclusion_id_allocator: Arc::new(Mutex::new(IdAllocator::new())), - } + RuntimeEnvironment::new_from_options(self) } } diff --git a/core/src/runtime/env/random.rs b/core/src/runtime/env/random.rs new file mode 100644 index 0000000..a6027e8 --- /dev/null +++ b/core/src/runtime/env/random.rs @@ -0,0 +1,23 @@ +use std::sync::*; + +use rand::rngs::SmallRng; +use rand::SeedableRng; + +#[derive(Clone, Debug)] +pub struct Random { + pub rng: Arc>, +} + +impl Random { + /// Create a new random module + pub fn new(seed: u64) -> Self { + Self { + rng: Arc::new(Mutex::new(SmallRng::seed_from_u64(seed))), + } + } + + /// Sample an element from a distribution using the rng + pub fn sample_from>(&self, dist: &D) -> T { + dist.sample(&mut *self.rng.lock().unwrap()) + } +} diff --git a/core/src/runtime/env/symbol_registry.rs b/core/src/runtime/env/symbol_registry.rs new file mode 100644 index 0000000..b0e2312 --- /dev/null +++ b/core/src/runtime/env/symbol_registry.rs @@ -0,0 +1,33 @@ +use std::sync::*; + +use crate::common::symbol_registry::*; + +/// A symbol registry with shared internal mutability +#[derive(Clone, Debug)] +pub struct SymbolRegistry2 { + pub registry: Arc>, +} + +impl SymbolRegistry2 { + pub fn new() -> Self { + Self { + registry: Arc::new(Mutex::new(SymbolRegistry::new())), + } + } + + pub fn is_empty(&self) -> bool { + self.registry.lock().unwrap().is_empty() + } + + pub fn register(&self, symbol: String) -> usize { + self.registry.lock().unwrap().register(symbol) 
+ } + + pub fn get_id(&self, symbol: &str) -> Option { + self.registry.lock().unwrap().get_id(symbol) + } + + pub fn get_symbol(&self, id: usize) -> Option { + self.registry.lock().unwrap().get_symbol(id).cloned() + } +} diff --git a/core/src/runtime/env/tensor_registry.rs b/core/src/runtime/env/tensor_registry.rs new file mode 100644 index 0000000..1e78877 --- /dev/null +++ b/core/src/runtime/env/tensor_registry.rs @@ -0,0 +1,29 @@ +use std::sync::*; + +use crate::common::tensors::*; + +/// A reference counted +#[derive(Clone)] +pub struct TensorRegistry2 { + registry: Arc>, +} + +impl TensorRegistry2 { + pub fn new() -> Self { + Self { + registry: Arc::new(Mutex::new(TensorRegistry::new())), + } + } + + pub fn register(&self, tensor: Tensor) -> TensorSymbol { + self.registry.lock().unwrap().register(tensor) + } + + pub fn get(&self, symbol: &TensorSymbol) -> Option { + self.registry.lock().unwrap().get(symbol).cloned() + } + + pub fn eval(&self, value: &TensorValue) -> Tensor { + self.registry.lock().unwrap().eval(value) + } +} diff --git a/core/src/runtime/error/error.rs b/core/src/runtime/error/error.rs index 915fb11..721cb21 100644 --- a/core/src/runtime/error/error.rs +++ b/core/src/runtime/error/error.rs @@ -1,6 +1,7 @@ use super::io::IOError; use crate::common::foreign_function::ForeignFunctionError; use crate::common::foreign_predicate::ForeignPredicateError; +use crate::compiler::front::attribute::AttributeError; use crate::runtime::database::DatabaseError; #[derive(Clone, Debug)] @@ -8,6 +9,7 @@ pub enum RuntimeError { IO(IOError), ForeignFunction(ForeignFunctionError), ForeignPredicate(ForeignPredicateError), + ForeignAttribute(AttributeError), Database(DatabaseError), } @@ -17,6 +19,7 @@ impl std::fmt::Display for RuntimeError { Self::IO(e) => e.fmt(f), Self::ForeignFunction(e) => e.fmt(f), Self::ForeignPredicate(e) => e.fmt(f), + Self::ForeignAttribute(a) => a.fmt(f), Self::Database(e) => e.fmt(f), } } diff --git a/core/src/runtime/error/io.rs 
b/core/src/runtime/error/io.rs index e487027..ad9ae7b 100644 --- a/core/src/runtime/error/io.rs +++ b/core/src/runtime/error/io.rs @@ -1,18 +1,24 @@ use std::path::PathBuf; -use crate::common::tuple_type::TupleType; -use crate::common::value_type::ValueParseError; +use crate::common::tuple_type::*; +use crate::common::value_type::*; #[derive(Clone, Debug)] pub enum IOError { CannotOpenFile { file_path: PathBuf, error: String }, CannotReadFile { error: String }, CannotParseCSV { error: String }, + CannotReadHeader { error: String }, + CannotFindField { field: String }, + IndexOutOfBounds { index: usize }, InvalidType { types: TupleType }, + ExpectSymbolType { actual: ValueType }, + ExpectStringType { actual: ValueType }, ValueParseError { error: ValueParseError }, CannotParseProbability { value: String }, ArityMismatch { expected: usize, found: usize }, CannotWriteRecord { error: String }, + InvalidFileFormat {}, } impl std::fmt::Display for IOError { @@ -25,7 +31,16 @@ impl std::fmt::Display for IOError { )), Self::CannotReadFile { error } => f.write_fmt(format_args!("IO: Cannot read file: {}", error)), Self::CannotParseCSV { error } => f.write_fmt(format_args!("IO: Cannot parse CSV: {}", error)), + Self::CannotReadHeader { error } => f.write_fmt(format_args!("IO: Cannot read CSV header: {}", error)), + Self::CannotFindField { field } => f.write_fmt(format_args!("IO: Cannot find field `{}`", field)), + Self::IndexOutOfBounds { index } => f.write_fmt(format_args!("IO: Index out of bounds: {}", index)), Self::InvalidType { types } => f.write_fmt(format_args!("IO: Invalid tuple type: `{}`", types)), + Self::ExpectSymbolType { actual } => { + f.write_fmt(format_args!("IO: Expect `Symbol` type for field; found `{}`", actual)) + } + Self::ExpectStringType { actual } => { + f.write_fmt(format_args!("IO: Expect `String` type for value; found `{}`", actual)) + } Self::ValueParseError { error } => std::fmt::Display::fmt(error, f), Self::CannotParseProbability { value } => 
f.write_fmt(format_args!("IO: Cannot parse probability `{}`", value)), Self::ArityMismatch { expected, found } => f.write_fmt(format_args!( @@ -33,6 +48,7 @@ impl std::fmt::Display for IOError { expected, found )), Self::CannotWriteRecord { error } => f.write_fmt(format_args!("IO: Cannot write record: {}", error)), + Self::InvalidFileFormat {} => f.write_fmt(format_args!("IO: Invalid file format")), } } } diff --git a/core/src/runtime/monitor/logging.rs b/core/src/runtime/monitor/logging.rs index 13c2e7e..8b2f8d1 100644 --- a/core/src/runtime/monitor/logging.rs +++ b/core/src/runtime/monitor/logging.rs @@ -12,11 +12,11 @@ impl LoggingMonitor { } pub fn warning(&self, s: &str) { - println!("[Warn] {}", s.color(Color::Yellow)); + eprintln!("[Warn] {}", s.color(Color::Yellow)); } pub fn error(&self, s: &str) { - println!("[Error] {}", s.color(Color::Red)); + eprintln!("[Error] {}", s.color(Color::Red)); } } diff --git a/core/src/runtime/provenance/common/diff_prob_storage.rs b/core/src/runtime/provenance/common/diff_prob_storage.rs index 029d416..05b7fc5 100644 --- a/core/src/runtime/provenance/common/diff_prob_storage.rs +++ b/core/src/runtime/provenance/common/diff_prob_storage.rs @@ -55,9 +55,7 @@ impl DiffProbStorage { } pub fn input_tags(&self) -> Vec { - P::get_rc_cell(&self.storage, |s| { - s.iter().filter_map(|(_, t)| t.clone()).collect() - }) + P::get_rc_cell(&self.storage, |s| s.iter().filter_map(|(_, t)| t.clone()).collect()) } pub fn num_input_tags(&self) -> usize { diff --git a/core/src/runtime/provenance/common/input_tags/boolean.rs b/core/src/runtime/provenance/common/input_tags/boolean.rs index 101f727..c8a18ff 100644 --- a/core/src/runtime/provenance/common/input_tags/boolean.rs +++ b/core/src/runtime/provenance/common/input_tags/boolean.rs @@ -1,4 +1,5 @@ use crate::common::input_tag::*; +use crate::common::tensors::*; use super::*; @@ -12,25 +13,37 @@ impl StaticInputTag for bool { } impl ConvertFromInputTag<()> for bool { - fn from_input_tag(_: ()) 
-> Option { None } + fn from_input_tag(_: ()) -> Option { + None + } } impl ConvertFromInputTag for bool { - fn from_input_tag(t: bool) -> Option { Some(t) } + fn from_input_tag(t: bool) -> Option { + Some(t) + } } impl ConvertFromInputTag for bool { - fn from_input_tag(t: usize) -> Option { Some(t > 0) } + fn from_input_tag(t: usize) -> Option { + Some(t > 0) + } } impl ConvertFromInputTag for bool { - fn from_input_tag(t: f32) -> Option { Some(t > 0.0) } + fn from_input_tag(t: f32) -> Option { + Some(t > 0.0) + } } impl ConvertFromInputTag for bool { - fn from_input_tag(t: f64) -> Option { Some(t > 0.0) } + fn from_input_tag(t: f64) -> Option { + Some(t > 0.0) + } } -impl ConvertFromInputTag> for bool { - fn from_input_tag(t: InputDiffProb) -> Option { Some(t.0 > 0.0) } +impl ConvertFromInputTag> for bool { + fn from_input_tag(t: InputDiffProb) -> Option { + Some(t.0 > 0.0) + } } diff --git a/core/src/runtime/provenance/common/input_tags/float.rs b/core/src/runtime/provenance/common/input_tags/float.rs index 98aff4e..0f9b584 100644 --- a/core/src/runtime/provenance/common/input_tags/float.rs +++ b/core/src/runtime/provenance/common/input_tags/float.rs @@ -1,4 +1,5 @@ use crate::common::input_tag::*; +use crate::common::tensors::*; use super::*; @@ -13,33 +14,49 @@ impl StaticInputTag for f64 { } impl ConvertFromInputTag<()> for f64 { - fn from_input_tag(_: ()) -> Option { None } + fn from_input_tag(_: ()) -> Option { + None + } } impl ConvertFromInputTag for f64 { - fn from_input_tag(t: bool) -> Option { Some(if t { 1.0 } else { 0.0 }) } + fn from_input_tag(t: bool) -> Option { + Some(if t { 1.0 } else { 0.0 }) + } } impl ConvertFromInputTag for f64 { - fn from_input_tag(t: usize) -> Option { Some(if t > 0 { 1.0 } else { 0.0 }) } + fn from_input_tag(t: usize) -> Option { + Some(if t > 0 { 1.0 } else { 0.0 }) + } } impl ConvertFromInputTag for f64 { - fn from_input_tag(_: Exclusion) -> Option { None } + fn from_input_tag(_: Exclusion) -> Option { + None + } } impl 
ConvertFromInputTag for f64 { - fn from_input_tag(t: f64) -> Option { Some(t) } + fn from_input_tag(t: f64) -> Option { + Some(t) + } } impl ConvertFromInputTag for f64 { - fn from_input_tag(t: InputExclusiveProb) -> Option { Some(t.prob) } + fn from_input_tag(t: InputExclusiveProb) -> Option { + Some(t.prob) + } } -impl ConvertFromInputTag> for f64 { - fn from_input_tag(t: InputDiffProb) -> Option { Some(t.0) } +impl ConvertFromInputTag> for f64 { + fn from_input_tag(t: InputDiffProb) -> Option { + Some(t.0) + } } -impl ConvertFromInputTag> for f64 { - fn from_input_tag(t: InputExclusiveDiffProb) -> Option { Some(t.prob) } +impl ConvertFromInputTag> for f64 { + fn from_input_tag(t: InputExclusiveDiffProb) -> Option { + Some(t.prob) + } } diff --git a/core/src/runtime/provenance/common/input_tags/input_diff_prob.rs b/core/src/runtime/provenance/common/input_tags/input_diff_prob.rs index 4203feb..1c1e6eb 100644 --- a/core/src/runtime/provenance/common/input_tags/input_diff_prob.rs +++ b/core/src/runtime/provenance/common/input_tags/input_diff_prob.rs @@ -1,4 +1,5 @@ use crate::common::input_tag::*; +use crate::common::tensors::*; use super::*; @@ -12,37 +13,40 @@ use super::*; /// back-propagate gradients into it. /// In such case the probability is treated as a constant. 
#[derive(Clone)] -pub struct InputDiffProb(pub f64, pub Option); +pub struct InputDiffProb(pub f64, pub Option); -impl std::fmt::Debug for InputDiffProb { +impl std::fmt::Debug for InputDiffProb { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { self.0.fmt(f) } } -impl From<(f64, Option)> for InputDiffProb { +impl From<(f64, Option)> for InputDiffProb { fn from((p, t): (f64, Option)) -> Self { Self(p, t) } } -impl StaticInputTag for InputDiffProb { +impl StaticInputTag for InputDiffProb { fn from_dynamic_input_tag(t: &DynamicInputTag) -> Option { match t { - DynamicInputTag::ExclusiveFloat(f, _) => Some(Self(f.clone(), None)), + DynamicInputTag::None => None, + DynamicInputTag::Bool(b) => Some(Self(if *b { 1.0 } else { 0.0 }, None)), + DynamicInputTag::Exclusive(_) => None, DynamicInputTag::Float(f) => Some(Self(f.clone(), None)), - _ => None, + DynamicInputTag::ExclusiveFloat(f, _) => Some(Self(f.clone(), None)), + DynamicInputTag::Tensor(t) => Some(Self(t.get_f64(), T::from_tensor(t.clone()))), } } } -impl ConvertFromInputTag<()> for InputDiffProb { +impl ConvertFromInputTag<()> for InputDiffProb { fn from_input_tag(_: ()) -> Option { None } } -impl ConvertFromInputTag for InputDiffProb { +impl ConvertFromInputTag for InputDiffProb { fn from_input_tag(b: bool) -> Option { if b { None @@ -52,7 +56,7 @@ impl ConvertFromInputTag for InputDiffProb { } } -impl ConvertFromInputTag for InputDiffProb { +impl ConvertFromInputTag for InputDiffProb { fn from_input_tag(u: usize) -> Option { if u > 0 { None @@ -62,31 +66,31 @@ impl ConvertFromInputTag for InputDiffProb { } } -impl ConvertFromInputTag for InputDiffProb { +impl ConvertFromInputTag for InputDiffProb { fn from_input_tag(_: Exclusion) -> Option { None } } -impl ConvertFromInputTag for InputDiffProb { +impl ConvertFromInputTag for InputDiffProb { fn from_input_tag(t: f64) -> Option { Some(Self(t, None)) } } -impl ConvertFromInputTag for InputDiffProb { +impl ConvertFromInputTag for InputDiffProb 
{ fn from_input_tag(t: InputExclusiveProb) -> Option { Some(Self(t.prob, None)) } } -impl ConvertFromInputTag> for InputDiffProb { +impl ConvertFromInputTag> for InputDiffProb { fn from_input_tag(t: InputDiffProb) -> Option { Some(t.clone()) } } -impl ConvertFromInputTag> for InputDiffProb { +impl ConvertFromInputTag> for InputDiffProb { fn from_input_tag(t: InputExclusiveDiffProb) -> Option { Some(Self(t.prob, None)) } diff --git a/core/src/runtime/provenance/common/input_tags/input_exclusion.rs b/core/src/runtime/provenance/common/input_tags/input_exclusion.rs index de5ada1..698edcb 100644 --- a/core/src/runtime/provenance/common/input_tags/input_exclusion.rs +++ b/core/src/runtime/provenance/common/input_tags/input_exclusion.rs @@ -1,4 +1,5 @@ use crate::common::input_tag::*; +use crate::common::tensors::*; use super::*; @@ -57,13 +58,13 @@ impl ConvertFromInputTag for Exclusion { } } -impl ConvertFromInputTag> for Exclusion { +impl ConvertFromInputTag> for Exclusion { fn from_input_tag(_: InputDiffProb) -> Option { None } } -impl ConvertFromInputTag> for Exclusion { +impl ConvertFromInputTag> for Exclusion { fn from_input_tag(t: InputExclusiveDiffProb) -> Option { match &t.exclusion { Some(e) => Some(Self::Exclusive(e.clone())), diff --git a/core/src/runtime/provenance/common/input_tags/input_exclusive_diff_prob.rs b/core/src/runtime/provenance/common/input_tags/input_exclusive_diff_prob.rs index b85559e..fdf3127 100644 --- a/core/src/runtime/provenance/common/input_tags/input_exclusive_diff_prob.rs +++ b/core/src/runtime/provenance/common/input_tags/input_exclusive_diff_prob.rs @@ -1,9 +1,10 @@ use crate::common::input_tag::*; +use crate::common::tensors::*; use super::*; #[derive(Clone)] -pub struct InputExclusiveDiffProb { +pub struct InputExclusiveDiffProb { /// The probability of the tag pub prob: f64, @@ -14,47 +15,80 @@ pub struct InputExclusiveDiffProb { pub exclusion: Option, } -impl InputExclusiveDiffProb { +impl InputExclusiveDiffProb { pub fn 
new(prob: f64, tag: T, exclusion: Option) -> Self { - Self { prob, external_tag: Some(tag), exclusion } + Self { + prob, + external_tag: Some(tag), + exclusion, + } } pub fn new_without_gradient(prob: f64, exclusion: Option) -> Self { - Self { prob, external_tag: None, exclusion } + Self { + prob, + external_tag: None, + exclusion, + } } } -impl std::fmt::Debug for InputExclusiveDiffProb { +impl std::fmt::Debug for InputExclusiveDiffProb { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { self.prob.fmt(f) } } -impl From<(f64, T, Option)> for InputExclusiveDiffProb { +impl From<(f64, T, Option)> for InputExclusiveDiffProb { fn from((prob, tag, exclusion): (f64, T, Option)) -> Self { - Self { prob, external_tag: Some(tag), exclusion } + Self { + prob, + external_tag: Some(tag), + exclusion, + } } } -impl StaticInputTag for InputExclusiveDiffProb { +impl StaticInputTag for InputExclusiveDiffProb { fn from_dynamic_input_tag(t: &DynamicInputTag) -> Option { match t { DynamicInputTag::None => None, - DynamicInputTag::Bool(b) => Some(Self { prob: if *b { 1.0 } else { 0.0 }, external_tag: None, exclusion: None }), - DynamicInputTag::Exclusive(i) => Some(Self { prob: 1.0, external_tag: None, exclusion: Some(i.clone()) }), - DynamicInputTag::Float(prob) => Some(Self { prob: prob.clone(), external_tag: None, exclusion: None }), - DynamicInputTag::ExclusiveFloat(prob, i) => Some(Self { prob: prob.clone(), external_tag: None, exclusion: Some(i.clone()) }), + DynamicInputTag::Bool(b) => Some(Self { + prob: if *b { 1.0 } else { 0.0 }, + external_tag: None, + exclusion: None, + }), + DynamicInputTag::Exclusive(i) => Some(Self { + prob: 1.0, + external_tag: None, + exclusion: Some(i.clone()), + }), + DynamicInputTag::Float(prob) => Some(Self { + prob: prob.clone(), + external_tag: None, + exclusion: None, + }), + DynamicInputTag::ExclusiveFloat(prob, i) => Some(Self { + prob: prob.clone(), + external_tag: None, + exclusion: Some(i.clone()), + }), + 
DynamicInputTag::Tensor(t) => Some(Self { + prob: t.get_f64(), + external_tag: T::from_tensor(t.clone()), + exclusion: None, + }), } } } -impl ConvertFromInputTag<()> for InputExclusiveDiffProb { +impl ConvertFromInputTag<()> for InputExclusiveDiffProb { fn from_input_tag(_: ()) -> Option { None } } -impl ConvertFromInputTag for InputExclusiveDiffProb { +impl ConvertFromInputTag for InputExclusiveDiffProb { fn from_input_tag(b: bool) -> Option { if b { None @@ -64,7 +98,7 @@ impl ConvertFromInputTag for InputExclusiveDiffProb } } -impl ConvertFromInputTag for InputExclusiveDiffProb { +impl ConvertFromInputTag for InputExclusiveDiffProb { fn from_input_tag(u: usize) -> Option { if u > 0 { None @@ -74,7 +108,7 @@ impl ConvertFromInputTag for InputExclusiveDiffProb ConvertFromInputTag for InputExclusiveDiffProb { +impl ConvertFromInputTag for InputExclusiveDiffProb { fn from_input_tag(e: Exclusion) -> Option { match e { Exclusion::Independent => None, @@ -83,25 +117,25 @@ impl ConvertFromInputTag for InputExclusiveDiffPr } } -impl ConvertFromInputTag for InputExclusiveDiffProb { +impl ConvertFromInputTag for InputExclusiveDiffProb { fn from_input_tag(t: f64) -> Option { Some(Self::new_without_gradient(t, None)) } } -impl ConvertFromInputTag for InputExclusiveDiffProb { +impl ConvertFromInputTag for InputExclusiveDiffProb { fn from_input_tag(t: InputExclusiveProb) -> Option { Some(Self::new_without_gradient(t.prob.clone(), t.exclusion.clone())) } } -impl ConvertFromInputTag> for InputExclusiveDiffProb { +impl ConvertFromInputTag> for InputExclusiveDiffProb { fn from_input_tag(t: InputDiffProb) -> Option { Some(Self::new_without_gradient(t.0, None)) } } -impl ConvertFromInputTag> for InputExclusiveDiffProb { +impl ConvertFromInputTag> for InputExclusiveDiffProb { fn from_input_tag(t: InputExclusiveDiffProb) -> Option { Some(t.clone()) } diff --git a/core/src/runtime/provenance/common/input_tags/input_exclusive_prob.rs 
b/core/src/runtime/provenance/common/input_tags/input_exclusive_prob.rs index 3a47156..e98f544 100644 --- a/core/src/runtime/provenance/common/input_tags/input_exclusive_prob.rs +++ b/core/src/runtime/provenance/common/input_tags/input_exclusive_prob.rs @@ -1,4 +1,5 @@ use crate::common::input_tag::*; +use crate::common::tensors::*; use super::*; @@ -104,13 +105,13 @@ impl ConvertFromInputTag for InputExclusiveProb { } } -impl ConvertFromInputTag> for InputExclusiveProb { +impl ConvertFromInputTag> for InputExclusiveProb { fn from_input_tag(t: InputDiffProb) -> Option { Some(Self::new(t.0, None)) } } -impl ConvertFromInputTag> for InputExclusiveProb { +impl ConvertFromInputTag> for InputExclusiveProb { fn from_input_tag(t: InputExclusiveDiffProb) -> Option { Some(Self::new(t.prob.clone(), t.exclusion.clone())) } diff --git a/core/src/runtime/provenance/common/input_tags/natural.rs b/core/src/runtime/provenance/common/input_tags/natural.rs index dfd6c30..7fc80e3 100644 --- a/core/src/runtime/provenance/common/input_tags/natural.rs +++ b/core/src/runtime/provenance/common/input_tags/natural.rs @@ -1,4 +1,5 @@ use crate::common::input_tag::*; +use crate::common::tensors::*; use super::*; @@ -10,30 +11,43 @@ impl StaticInputTag for usize { DynamicInputTag::Bool(b) => Some(if *b { 1 } else { 0 }), DynamicInputTag::Float(f) => Some(if *f > 0.0 { 1 } else { 0 }), DynamicInputTag::ExclusiveFloat(_, _) => Some(1), + DynamicInputTag::Tensor(_) => Some(1), } } } impl ConvertFromInputTag<()> for usize { - fn from_input_tag(_: ()) -> Option { None } + fn from_input_tag(_: ()) -> Option { + None + } } impl ConvertFromInputTag for usize { - fn from_input_tag(t: bool) -> Option { Some(if t { 1 } else { 0 }) } + fn from_input_tag(t: bool) -> Option { + Some(if t { 1 } else { 0 }) + } } impl ConvertFromInputTag for usize { - fn from_input_tag(t: usize) -> Option { Some(t) } + fn from_input_tag(t: usize) -> Option { + Some(t) + } } impl ConvertFromInputTag for usize { - fn 
from_input_tag(t: f32) -> Option { Some(if t > 0.0 { 1 } else { 0 }) } + fn from_input_tag(t: f32) -> Option { + Some(if t > 0.0 { 1 } else { 0 }) + } } impl ConvertFromInputTag for usize { - fn from_input_tag(t: f64) -> Option { Some(if t > 0.0 { 1 } else { 0 }) } + fn from_input_tag(t: f64) -> Option { + Some(if t > 0.0 { 1 } else { 0 }) + } } -impl ConvertFromInputTag> for usize { - fn from_input_tag(t: InputDiffProb) -> Option { Some(if t.0 > 0.0 { 1 } else { 0 }) } +impl ConvertFromInputTag> for usize { + fn from_input_tag(t: InputDiffProb) -> Option { + Some(if t.0 > 0.0 { 1 } else { 0 }) + } } diff --git a/core/src/runtime/provenance/common/input_tags/unit.rs b/core/src/runtime/provenance/common/input_tags/unit.rs index 3104e21..5812669 100644 --- a/core/src/runtime/provenance/common/input_tags/unit.rs +++ b/core/src/runtime/provenance/common/input_tags/unit.rs @@ -1,4 +1,5 @@ use crate::common::input_tag::*; +use crate::common::tensors::*; use super::*; @@ -9,33 +10,49 @@ impl StaticInputTag for () { } impl ConvertFromInputTag<()> for () { - fn from_input_tag(_: ()) -> Option { Some(()) } + fn from_input_tag(_: ()) -> Option { + Some(()) + } } impl ConvertFromInputTag for () { - fn from_input_tag(_: bool) -> Option { Some(()) } + fn from_input_tag(_: bool) -> Option { + Some(()) + } } impl ConvertFromInputTag for () { - fn from_input_tag(_: usize) -> Option { Some(()) } + fn from_input_tag(_: usize) -> Option { + Some(()) + } } impl ConvertFromInputTag for () { - fn from_input_tag(_: f64) -> Option { Some(()) } + fn from_input_tag(_: f64) -> Option { + Some(()) + } } impl ConvertFromInputTag for () { - fn from_input_tag(_: Exclusion) -> Option { Some(()) } + fn from_input_tag(_: Exclusion) -> Option { + Some(()) + } } impl ConvertFromInputTag for () { - fn from_input_tag(_: InputExclusiveProb) -> Option { Some(()) } + fn from_input_tag(_: InputExclusiveProb) -> Option { + Some(()) + } } -impl ConvertFromInputTag> for () { - fn from_input_tag(_: 
InputDiffProb) -> Option { Some(()) } +impl ConvertFromInputTag> for () { + fn from_input_tag(_: InputDiffProb) -> Option { + Some(()) + } } -impl ConvertFromInputTag> for () { - fn from_input_tag(_: InputExclusiveDiffProb) -> Option { Some(()) } +impl ConvertFromInputTag> for () { + fn from_input_tag(_: InputExclusiveDiffProb) -> Option { + Some(()) + } } diff --git a/core/src/runtime/provenance/differentiable/diff_add_mult_prob.rs b/core/src/runtime/provenance/differentiable/diff_add_mult_prob.rs index b0dc877..3ff5887 100644 --- a/core/src/runtime/provenance/differentiable/diff_add_mult_prob.rs +++ b/core/src/runtime/provenance/differentiable/diff_add_mult_prob.rs @@ -2,16 +2,17 @@ use itertools::Itertools; use super::*; use crate::common::element::*; +use crate::common::tensors::*; use crate::runtime::dynamic::*; use crate::runtime::statics::*; use crate::utils::PointerFamily; -pub struct DiffAddMultProbProvenance { +pub struct DiffAddMultProbProvenance { pub valid_threshold: f64, pub storage: P::RcCell>, } -impl Clone for DiffAddMultProbProvenance { +impl Clone for DiffAddMultProbProvenance { fn clone(&self) -> Self { Self { valid_threshold: self.valid_threshold, @@ -20,7 +21,7 @@ impl Clone for DiffAddMultProbProvenance { } } -impl DiffAddMultProbProvenance { +impl DiffAddMultProbProvenance { pub fn input_tags(&self) -> Vec { P::get_rc_cell(&self.storage, |s| s.clone()) } @@ -43,7 +44,7 @@ impl DiffAddMultProbProvenance { } } -impl Default for DiffAddMultProbProvenance { +impl Default for DiffAddMultProbProvenance { fn default() -> Self { Self { valid_threshold: 0.0000, @@ -52,7 +53,7 @@ impl Default for DiffAddMultProbProvenance { } } -impl Provenance for DiffAddMultProbProvenance { +impl Provenance for DiffAddMultProbProvenance { type Tag = DualNumber2; type InputTag = InputDiffProb; diff --git a/core/src/runtime/provenance/differentiable/diff_max_mult_prob.rs b/core/src/runtime/provenance/differentiable/diff_max_mult_prob.rs index a827dfc..cb47230 100644 
--- a/core/src/runtime/provenance/differentiable/diff_max_mult_prob.rs +++ b/core/src/runtime/provenance/differentiable/diff_max_mult_prob.rs @@ -2,16 +2,17 @@ use itertools::Itertools; use super::*; use crate::common::element::*; +use crate::common::tensors::*; use crate::runtime::dynamic::*; use crate::runtime::statics::*; use crate::utils::PointerFamily; -pub struct DiffMaxMultProbProvenance { +pub struct DiffMaxMultProbProvenance { pub valid_threshold: f64, pub storage: P::RcCell>, } -impl Clone for DiffMaxMultProbProvenance { +impl Clone for DiffMaxMultProbProvenance { fn clone(&self) -> Self { Self { valid_threshold: self.valid_threshold, @@ -20,7 +21,7 @@ impl Clone for DiffMaxMultProbProvenance { } } -impl DiffMaxMultProbProvenance { +impl DiffMaxMultProbProvenance { pub fn input_tags(&self) -> Vec { P::get_rc_cell(&self.storage, |s| s.clone()) } @@ -43,7 +44,7 @@ impl DiffMaxMultProbProvenance { } } -impl Default for DiffMaxMultProbProvenance { +impl Default for DiffMaxMultProbProvenance { fn default() -> Self { Self { valid_threshold: 0.0000, @@ -52,7 +53,7 @@ impl Default for DiffMaxMultProbProvenance { } } -impl Provenance for DiffMaxMultProbProvenance { +impl Provenance for DiffMaxMultProbProvenance { type Tag = DualNumber2; type InputTag = InputDiffProb; diff --git a/core/src/runtime/provenance/differentiable/diff_min_max_prob.rs b/core/src/runtime/provenance/differentiable/diff_min_max_prob.rs index 53a2e63..2c00bc3 100644 --- a/core/src/runtime/provenance/differentiable/diff_min_max_prob.rs +++ b/core/src/runtime/provenance/differentiable/diff_min_max_prob.rs @@ -2,6 +2,7 @@ use itertools::Itertools; use super::*; use crate::common::element::*; +use crate::common::tensors::*; use crate::common::value_type::*; use crate::runtime::dynamic::*; use crate::runtime::statics::*; @@ -47,12 +48,12 @@ impl std::fmt::Debug for Prob { impl Tag for Prob {} -pub struct DiffMinMaxProbProvenance { +pub struct DiffMinMaxProbProvenance { pub storage: P::RcCell>, pub 
valid_threshold: f64, } -impl Clone for DiffMinMaxProbProvenance { +impl Clone for DiffMinMaxProbProvenance { fn clone(&self) -> Self { Self { valid_threshold: self.valid_threshold, @@ -61,7 +62,7 @@ impl Clone for DiffMinMaxProbProvenance { } } -impl DiffMinMaxProbProvenance { +impl DiffMinMaxProbProvenance { pub fn collect_chosen_elements<'a, E>(&self, all: &'a Vec, chosen_ids: &Vec) -> Vec<&'a E> where E: Element, @@ -103,7 +104,7 @@ impl DiffMinMaxProbProvenance { } } -impl Default for DiffMinMaxProbProvenance { +impl Default for DiffMinMaxProbProvenance { fn default() -> Self { Self { valid_threshold: -0.0001, @@ -113,21 +114,21 @@ impl Default for DiffMinMaxProbProvenance { } #[derive(Clone)] -pub struct OutputDiffProb(pub f64, pub usize, pub i32, pub Option); +pub struct OutputDiffProb(pub f64, pub usize, pub i32, pub Option); -impl std::fmt::Debug for OutputDiffProb { +impl std::fmt::Debug for OutputDiffProb { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_tuple("").field(&self.0).field(&self.1).field(&self.2).finish() } } -impl std::fmt::Display for OutputDiffProb { +impl std::fmt::Display for OutputDiffProb { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_tuple("").field(&self.0).field(&self.1).field(&self.2).finish() } } -impl Provenance for DiffMinMaxProbProvenance { +impl Provenance for DiffMinMaxProbProvenance { type Tag = Prob; type InputTag = InputDiffProb; @@ -151,9 +152,19 @@ impl Provenance for DiffMinMaxProbProvenan fn recover_fn(&self, t: &Self::Tag) -> Self::OutputTag { match &t.1 { - Derivative::Pos(fact_id) => OutputDiffProb(t.0, *fact_id, 1, Some(P::get_rc_cell(&self.storage, |s| s[*fact_id].clone()))), + Derivative::Pos(fact_id) => OutputDiffProb( + t.0, + *fact_id, + 1, + Some(P::get_rc_cell(&self.storage, |s| s[*fact_id].clone())), + ), Derivative::Zero => OutputDiffProb(t.0, 0, 0, None), - Derivative::Neg(fact_id) => OutputDiffProb(t.0, *fact_id, -1, 
Some(P::get_rc_cell(&self.storage, |s| s[*fact_id].clone()))), + Derivative::Neg(fact_id) => OutputDiffProb( + t.0, + *fact_id, + -1, + Some(P::get_rc_cell(&self.storage, |s| s[*fact_id].clone())), + ), } } diff --git a/core/src/runtime/provenance/differentiable/diff_nand_min_prob.rs b/core/src/runtime/provenance/differentiable/diff_nand_min_prob.rs index 9b7d6ea..13110de 100644 --- a/core/src/runtime/provenance/differentiable/diff_nand_min_prob.rs +++ b/core/src/runtime/provenance/differentiable/diff_nand_min_prob.rs @@ -2,16 +2,17 @@ use itertools::Itertools; use super::*; use crate::common::element::*; +use crate::common::tensors::*; use crate::runtime::dynamic::*; use crate::runtime::statics::*; use crate::utils::PointerFamily; -pub struct DiffNandMinProbProvenance { +pub struct DiffNandMinProbProvenance { pub valid_threshold: f64, pub storage: P::RcCell>, } -impl Clone for DiffNandMinProbProvenance { +impl Clone for DiffNandMinProbProvenance { fn clone(&self) -> Self { Self { valid_threshold: self.valid_threshold, @@ -20,7 +21,7 @@ impl Clone for DiffNandMinProbProvenance { } } -impl DiffNandMinProbProvenance { +impl DiffNandMinProbProvenance { pub fn input_tags(&self) -> Vec { P::get_rc_cell(&self.storage, |s| s.clone()) } @@ -43,7 +44,7 @@ impl DiffNandMinProbProvenance { } } -impl Default for DiffNandMinProbProvenance { +impl Default for DiffNandMinProbProvenance { fn default() -> Self { Self { valid_threshold: 0.0000, @@ -52,7 +53,7 @@ impl Default for DiffNandMinProbProvenance { } } -impl Provenance for DiffNandMinProbProvenance { +impl Provenance for DiffNandMinProbProvenance { type Tag = DualNumber2; type InputTag = InputDiffProb; diff --git a/core/src/runtime/provenance/differentiable/diff_nand_mult_prob.rs b/core/src/runtime/provenance/differentiable/diff_nand_mult_prob.rs index 90e49bd..3711840 100644 --- a/core/src/runtime/provenance/differentiable/diff_nand_mult_prob.rs +++ b/core/src/runtime/provenance/differentiable/diff_nand_mult_prob.rs @@ 
-2,16 +2,17 @@ use itertools::Itertools; use super::*; use crate::common::element::*; +use crate::common::tensors::*; use crate::runtime::dynamic::*; use crate::runtime::statics::*; use crate::utils::PointerFamily; -pub struct DiffNandMultProbProvenance { +pub struct DiffNandMultProbProvenance { pub valid_threshold: f64, pub storage: P::RcCell>, } -impl Clone for DiffNandMultProbProvenance { +impl Clone for DiffNandMultProbProvenance { fn clone(&self) -> Self { Self { valid_threshold: self.valid_threshold, @@ -20,7 +21,7 @@ impl Clone for DiffNandMultProbProvenance { } } -impl DiffNandMultProbProvenance { +impl DiffNandMultProbProvenance { pub fn input_tags(&self) -> Vec { P::get_rc_cell(&self.storage, |s| s.clone()) } @@ -43,7 +44,7 @@ impl DiffNandMultProbProvenance { } } -impl Default for DiffNandMultProbProvenance { +impl Default for DiffNandMultProbProvenance { fn default() -> Self { Self { valid_threshold: 0.0000, @@ -52,7 +53,7 @@ impl Default for DiffNandMultProbProvenance { } } -impl Provenance for DiffNandMultProbProvenance { +impl Provenance for DiffNandMultProbProvenance { type Tag = DualNumber2; type InputTag = InputDiffProb; diff --git a/core/src/runtime/provenance/differentiable/diff_sample_k_proofs.rs b/core/src/runtime/provenance/differentiable/diff_sample_k_proofs.rs index a98eb47..2efa5c2 100644 --- a/core/src/runtime/provenance/differentiable/diff_sample_k_proofs.rs +++ b/core/src/runtime/provenance/differentiable/diff_sample_k_proofs.rs @@ -1,17 +1,19 @@ use rand::prelude::*; use rand::rngs::StdRng; -use super::*; +use crate::common::tensors::*; use crate::utils::PointerFamily; -pub struct DiffSampleKProofsProvenance { +use super::*; + +pub struct DiffSampleKProofsProvenance { pub k: usize, pub sampler: P::Cell, pub storage: DiffProbStorage, pub disjunctions: P::Cell, } -impl Clone for DiffSampleKProofsProvenance { +impl Clone for DiffSampleKProofsProvenance { fn clone(&self) -> Self { Self { k: self.k, @@ -22,7 +24,7 @@ impl Clone for 
DiffSampleKProofsProvenance { } } -impl DiffSampleKProofsProvenance { +impl DiffSampleKProofsProvenance { pub fn new(k: usize) -> Self { Self::new_with_seed(k, 12345678) } @@ -45,7 +47,7 @@ impl DiffSampleKProofsProvenance { } } -impl DNFContextTrait for DiffSampleKProofsProvenance { +impl DNFContextTrait for DiffSampleKProofsProvenance { fn fact_probability(&self, id: &usize) -> f64 { self.storage.fact_probability(id) } @@ -55,7 +57,7 @@ impl DNFContextTrait for DiffSampleKProofsProvenance } } -impl Provenance for DiffSampleKProofsProvenance { +impl Provenance for DiffSampleKProofsProvenance { type Tag = DNFFormula; type InputTag = InputExclusiveDiffProb; @@ -67,7 +69,11 @@ impl Provenance for DiffSampleKProofsProve } fn tagging_fn(&self, input_tag: Self::InputTag) -> Self::Tag { - let InputExclusiveDiffProb { prob, external_tag, exclusion } = input_tag; + let InputExclusiveDiffProb { + prob, + external_tag, + exclusion, + } = input_tag; // First store the probability and generate the id let fact_id = self.storage.add_prob(prob, external_tag); diff --git a/core/src/runtime/provenance/differentiable/diff_top_bottom_k_clauses.rs b/core/src/runtime/provenance/differentiable/diff_top_bottom_k_clauses.rs index 355cfc2..33fdbf7 100644 --- a/core/src/runtime/provenance/differentiable/diff_top_bottom_k_clauses.rs +++ b/core/src/runtime/provenance/differentiable/diff_top_bottom_k_clauses.rs @@ -2,18 +2,20 @@ use std::collections::*; use itertools::Itertools; -use super::*; +use crate::common::tensors::*; use crate::runtime::dynamic::*; use crate::runtime::statics::*; use crate::utils::*; -pub struct DiffTopBottomKClausesProvenance { +use super::*; + +pub struct DiffTopBottomKClausesProvenance { pub k: usize, pub storage: DiffProbStorage, pub disjunctions: P::Cell, } -impl Clone for DiffTopBottomKClausesProvenance { +impl Clone for DiffTopBottomKClausesProvenance { fn clone(&self) -> Self { Self { k: self.k, @@ -23,7 +25,7 @@ impl Clone for DiffTopBottomKClausesProven } } 
-impl DiffTopBottomKClausesProvenance { +impl DiffTopBottomKClausesProvenance { pub fn new(k: usize) -> Self { Self { k, @@ -41,7 +43,7 @@ impl DiffTopBottomKClausesProvenance } } -impl CNFDNFContextTrait for DiffTopBottomKClausesProvenance { +impl CNFDNFContextTrait for DiffTopBottomKClausesProvenance { fn fact_probability(&self, id: &usize) -> f64 { self.storage.fact_probability(id) } @@ -51,7 +53,7 @@ impl CNFDNFContextTrait for DiffTopBottomK } } -impl Provenance for DiffTopBottomKClausesProvenance { +impl Provenance for DiffTopBottomKClausesProvenance { type Tag = CNFDNFFormula; type InputTag = InputExclusiveDiffProb; @@ -63,7 +65,11 @@ impl Provenance for DiffTopBottomKClausesP } fn tagging_fn(&self, input_tag: Self::InputTag) -> Self::Tag { - let InputExclusiveDiffProb { prob, external_tag, exclusion } = input_tag; + let InputExclusiveDiffProb { + prob, + external_tag, + exclusion, + } = input_tag; // First store the probability and generate the id let fact_id = self.storage.add_prob(prob, external_tag); diff --git a/core/src/runtime/provenance/differentiable/diff_top_k_proofs.rs b/core/src/runtime/provenance/differentiable/diff_top_k_proofs.rs index 694716c..52e0942 100644 --- a/core/src/runtime/provenance/differentiable/diff_top_k_proofs.rs +++ b/core/src/runtime/provenance/differentiable/diff_top_k_proofs.rs @@ -1,17 +1,19 @@ use itertools::Itertools; -use super::*; +use crate::common::tensors::*; use crate::runtime::dynamic::*; use crate::runtime::statics::*; use crate::utils::*; -pub struct DiffTopKProofsProvenance { +use super::*; + +pub struct DiffTopKProofsProvenance { pub k: usize, pub storage: DiffProbStorage, pub disjunctions: P::Cell, } -impl Clone for DiffTopKProofsProvenance { +impl Clone for DiffTopKProofsProvenance { fn clone(&self) -> Self { Self { k: self.k, @@ -21,7 +23,7 @@ impl Clone for DiffTopKProofsProvenance { } } -impl DiffTopKProofsProvenance { +impl DiffTopKProofsProvenance { pub fn new(k: usize) -> Self { Self { k, @@ -39,7 +41,7 
@@ impl DiffTopKProofsProvenance { } } -impl DNFContextTrait for DiffTopKProofsProvenance { +impl DNFContextTrait for DiffTopKProofsProvenance { fn fact_probability(&self, id: &usize) -> f64 { self.storage.fact_probability(id) } @@ -49,7 +51,7 @@ impl DNFContextTrait for DiffTopKProofsProvenance Provenance for DiffTopKProofsProvenance { +impl Provenance for DiffTopKProofsProvenance { type Tag = DNFFormula; type InputTag = InputExclusiveDiffProb; @@ -61,7 +63,11 @@ impl Provenance for DiffTopKProofsProvenan } fn tagging_fn(&self, input_tag: Self::InputTag) -> Self::Tag { - let InputExclusiveDiffProb { prob, external_tag, exclusion } = input_tag; + let InputExclusiveDiffProb { + prob, + external_tag, + exclusion, + } = input_tag; // First store the probability and generate the id let fact_id = self.storage.add_prob(prob, external_tag); diff --git a/core/src/runtime/provenance/differentiable/mod.rs b/core/src/runtime/provenance/differentiable/mod.rs index 2148d98..8de7483 100644 --- a/core/src/runtime/provenance/differentiable/mod.rs +++ b/core/src/runtime/provenance/differentiable/mod.rs @@ -6,6 +6,5 @@ pub mod diff_nand_mult_prob; pub mod diff_sample_k_proofs; pub mod diff_top_bottom_k_clauses; pub mod diff_top_k_proofs; -pub mod diff_top_k_proofs_indiv; use super::*; diff --git a/core/src/runtime/provenance/probabilistic/prob_proofs.rs b/core/src/runtime/provenance/probabilistic/prob_proofs.rs index 2d568e5..ec5c6f5 100644 --- a/core/src/runtime/provenance/probabilistic/prob_proofs.rs +++ b/core/src/runtime/provenance/probabilistic/prob_proofs.rs @@ -173,9 +173,7 @@ impl Provenance for ProbProofsProvenance

{ let mut prod = Self::Tag::cartesian_product(t1, t2); prod .proofs - .retain(|proof| { - P::get_cell(&self.disjunctions, |d| !d.has_conflict(&proof.facts)) - }); + .retain(|proof| P::get_cell(&self.disjunctions, |d| !d.has_conflict(&proof.facts))); prod } diff --git a/core/src/runtime/provenance/provenance.rs b/core/src/runtime/provenance/provenance.rs index a321fdb..6978cd0 100644 --- a/core/src/runtime/provenance/provenance.rs +++ b/core/src/runtime/provenance/provenance.rs @@ -2,7 +2,6 @@ use std::collections::HashSet; use std::fmt::{Debug, Display}; use rand::distributions::WeightedIndex; -use rand::prelude::*; use super::*; @@ -12,17 +11,28 @@ use crate::runtime::dynamic::*; use crate::runtime::env::*; use crate::runtime::statics::*; +/// A provenance pub trait Provenance: Clone + 'static { - type Tag: Tag; - + /// The input tag space of the provenance type InputTag: Clone + Debug + StaticInputTag; + /// The (internal) tag space of the provenance + type Tag: Tag; + + /// The output tag space of the provenance type OutputTag: Clone + Debug + Display; + /// The name of the provenance fn name() -> &'static str; + /// Converting input tag to internal tag fn tagging_fn(&self, ext_tag: Self::InputTag) -> Self::Tag; + /// Converting a maybe input tag to internal tag; + /// if the input tag does not exist, we use the `one` tag. 
+ /// + /// Custom provenance may overwrite this to get special behavior when + /// there is no input tag fn tagging_optional_fn(&self, ext_tag: Option) -> Self::Tag { match ext_tag { Some(et) => self.tagging_fn(et), @@ -30,21 +40,30 @@ pub trait Provenance: Clone + 'static { } } + /// Convert the internal tag to the output tag fn recover_fn(&self, t: &Self::Tag) -> Self::OutputTag; + /// Check if we want to discard a fact with the given tag fn discard(&self, t: &Self::Tag) -> bool; + /// The `zero` element in the internal tag space fn zero(&self) -> Self::Tag; + /// The `one` element in the internal tag space fn one(&self) -> Self::Tag; + /// Adding two tags fn add(&self, t1: &Self::Tag, t2: &Self::Tag) -> Self::Tag; - fn saturated(&self, t_old: &Self::Tag, t_new: &Self::Tag) -> bool; - + /// Multiply two tags fn mult(&self, t1: &Self::Tag, t2: &Self::Tag) -> Self::Tag; - fn negate(&self, _: &Self::Tag) -> Option { + /// Negate a tag. + /// + /// If `None` is returned, the tuple will be discarded/removed. 
+ /// By default (if not implemented), negating a tag results in `None`, + #[allow(unused)] + fn negate(&self, t: &Self::Tag) -> Option { None } @@ -52,7 +71,14 @@ pub trait Provenance: Clone + 'static { self.negate(t2).map(|neg_t2| self.mult(t1, &neg_t2)) } - fn weight(&self, _: &Self::Tag) -> f64 { + /// Check if a tag has saturated given its old and new versions + fn saturated(&self, t_old: &Self::Tag, t_new: &Self::Tag) -> bool; + + /// Get the weight of a tag + /// + /// By default (if not implemented), every tag are weighted equally by having a weight of 1 + #[allow(unused)] + fn weight(&self, tag: &Self::Tag) -> f64 { 1.0 } @@ -116,9 +142,7 @@ pub trait Provenance: Clone + 'static { } else { let weights = batch.iter().map(|e| self.weight(&e.tag)).collect::>(); let dist = WeightedIndex::new(&weights).unwrap(); - let sampled_ids = (0..k) - .map(|_| dist.sample(&mut *rt.rng.lock().unwrap())) - .collect::>(); + let sampled_ids = (0..k).map(|_| rt.random.sample_from(&dist)).collect::>(); batch .into_iter() .enumerate() @@ -193,9 +217,7 @@ pub trait Provenance: Clone + 'static { } else { let weights = batch.iter().map(|e| self.weight(&e.tag)).collect::>(); let dist = WeightedIndex::new(&weights).unwrap(); - let sampled_ids = (0..k) - .map(|_| dist.sample(&mut *rt.rng.lock().unwrap())) - .collect::>(); + let sampled_ids = (0..k).map(|_| rt.random.sample_from(&dist)).collect::>(); batch .into_iter() .enumerate() diff --git a/core/src/testing/test_collection.rs b/core/src/testing/test_collection.rs index 5a4b7b0..1a9101d 100644 --- a/core/src/testing/test_collection.rs +++ b/core/src/testing/test_collection.rs @@ -42,9 +42,7 @@ impl> From, Tup)>> for pub fn test_equals(t1: &Tuple, t2: &Tuple) -> bool { match (t1, t2) { - (Tuple::Tuple(ts1), Tuple::Tuple(ts2)) => { - ts1.iter().zip(ts2.iter()).all(|(s1, s2)| test_equals(s1, s2)) - }, + (Tuple::Tuple(ts1), Tuple::Tuple(ts2)) => ts1.iter().zip(ts2.iter()).all(|(s1, s2)| test_equals(s1, s2)), 
(Tuple::Value(Value::F32(t1)), Tuple::Value(Value::F32(t2))) => { if t1.is_infinite() && t1.is_sign_positive() && t2.is_infinite() && t2.is_sign_positive() { true @@ -55,7 +53,7 @@ pub fn test_equals(t1: &Tuple, t2: &Tuple) -> bool { } else { (t1 - t2).abs() < 0.001 } - }, + } (Tuple::Value(Value::F64(t1)), Tuple::Value(Value::F64(t2))) => { if t1.is_infinite() && t1.is_sign_positive() && t2.is_infinite() && t2.is_sign_positive() { true @@ -66,7 +64,7 @@ pub fn test_equals(t1: &Tuple, t2: &Tuple) -> bool { } else { (t1 - t2).abs() < 0.001 } - }, + } _ => t1 == t2, } } @@ -100,11 +98,8 @@ where } } -pub fn expect_output_collection( - name: &str, - actual: &DynamicOutputCollection, - expected: C -) where +pub fn expect_output_collection(name: &str, actual: &DynamicOutputCollection, expected: C) +where Prov: Provenance, Prov::Tag: std::fmt::Debug, C: Into, @@ -115,7 +110,13 @@ pub fn expect_output_collection( for e in &expected.elements { let te = e.clone().into(); let pos = actual.iter().position(|(_, tuple)| test_equals(&tuple, &te)); - assert!(pos.is_some(), "Tuple {:?} not found in `{}` collection {:?}", te, name, actual) + assert!( + pos.is_some(), + "Tuple {:?} not found in `{}` collection {:?}", + te, + name, + actual + ) } // Then check everything in actual is in expected diff --git a/core/src/testing/test_compile.rs b/core/src/testing/test_compile.rs index f96b4d8..259a6cb 100644 --- a/core/src/testing/test_compile.rs +++ b/core/src/testing/test_compile.rs @@ -49,3 +49,27 @@ where } } } + +/// Expect the given program fails to compile in the FRONT compilation stage +/// +/// The given `f` takes in an error `String` and returns whether that string +/// represents a particular error that the user expected. 
+pub fn expect_front_compile_failure_with_modifier(s: &str, m: M, f: F) +where + M: Fn(&mut compiler::front::FrontContext), + F: Fn(String) -> bool, +{ + let mut ctx = compiler::front::FrontContext::new(); + m(&mut ctx); + match ctx.compile_string(s.to_string()) { + Ok(_) => { + panic!("Compilation passed; expected failure") + } + Err(err) => { + let err = format!("{}", err); + if !f(err) { + panic!("Expected failure not found") + } + } + } +} diff --git a/core/src/testing/test_interpret.rs b/core/src/testing/test_interpret.rs index 5d69f20..698f5bc 100644 --- a/core/src/testing/test_interpret.rs +++ b/core/src/testing/test_interpret.rs @@ -1,6 +1,7 @@ use crate::common::tuple::Tuple; use crate::integrate::*; use crate::runtime::database::*; +use crate::runtime::error::*; use crate::runtime::monitor; use crate::runtime::provenance::*; use crate::utils::*; @@ -66,3 +67,34 @@ pub fn expect_interpret_within_iter_limit(s: &str, iter_limit: usize) { let monitor = monitor::IterationCheckingMonitor::new(iter_limit); interpret_string_with_ctx_and_monitor(s.to_string(), prov, &monitor).expect("Interpret Error"); } + +pub fn expect_interpret_failure(s: &str) { + let result = interpret_string(s.to_string()); + match result { + Ok(_) => panic!("Interpreting succeeded instead of expected failure"), + Err(err) => match err { + IntegrateError::Compile(_) => panic!("Expecting runtime error but got compile error instead"), + IntegrateError::Runtime(_) => { /* GOOD */ } + }, + } +} + +pub fn expect_interpret_specific_failure(s: &str, f: F) +where + F: Fn(RuntimeError) -> bool, +{ + let result = interpret_string(s.to_string()); + match result { + Ok(_) => panic!("Interpreting succeeded instead of expected failure"), + Err(err) => match err { + IntegrateError::Compile(_) => panic!("Expecting runtime error but got compile error instead"), + IntegrateError::Runtime(r) => { + if f(r) { + /* GOOD */ + } else { + panic!("Did not capture expected runtime failure") + } + } + }, + } +} diff 
--git a/core/src/utils/chrono.rs b/core/src/utils/chrono.rs index cbb22c9..22cfee6 100644 --- a/core/src/utils/chrono.rs +++ b/core/src/utils/chrono.rs @@ -1,6 +1,10 @@ /// Parse a string into a chrono DateTime +/// +/// If the time portion is not supplied in the input string, the time will be +/// default to 12:00:00am UTC time pub fn parse_date_time_string(d: &str) -> Option> { - dateparser::parse(d).ok() + let midnight_naive = chrono::NaiveTime::from_hms_opt(0, 0, 0).unwrap(); + dateparser::parse_with(d, &chrono::Utc, midnight_naive).ok() } /// Parse a string into a chrono Duration diff --git a/core/src/utils/copy_on_write.rs b/core/src/utils/copy_on_write.rs index 2af2c23..b2056c2 100644 --- a/core/src/utils/copy_on_write.rs +++ b/core/src/utils/copy_on_write.rs @@ -19,6 +19,10 @@ impl CopyOnWrite { f(&mut new_inner); *self = Self(P::new_rc(new_inner)); } + + pub fn modify_without_copy(&mut self, f: F) { + f(P::get_rc_mut(&mut self.0)); + } } impl Clone for CopyOnWrite { diff --git a/core/src/utils/float.rs b/core/src/utils/float.rs index e4e12e9..07ced4f 100644 --- a/core/src/utils/float.rs +++ b/core/src/utils/float.rs @@ -1,17 +1,17 @@ /// Floating Point trait (f32, f64) pub trait Float: - Sized + - Copy + - Clone + - PartialEq + - PartialOrd + - std::fmt::Debug + - std::fmt::Display + - std::ops::Add + - std::ops::Sub + - std::ops::Mul + - std::ops::Div + - std::convert::TryInto + Sized + + Copy + + Clone + + PartialEq + + PartialOrd + + std::fmt::Debug + + std::fmt::Display + + std::ops::Add + + std::ops::Sub + + std::ops::Mul + + std::ops::Div + + std::convert::TryInto { fn zero() -> Self; diff --git a/core/src/utils/id_allocator.rs b/core/src/utils/id_allocator.rs index 879ab19..2ce9849 100644 --- a/core/src/utils/id_allocator.rs +++ b/core/src/utils/id_allocator.rs @@ -1,3 +1,6 @@ +use std::sync::*; + +/// An ID allocator #[derive(Debug, Clone, Default)] pub struct IdAllocator { id: usize, @@ -18,3 +21,27 @@ impl IdAllocator { result } } + +/// An 
alternative ID allocator which has internal mutability +#[derive(Debug, Clone, Default)] +pub struct IdAllocator2 { + pub id_allocator: Arc>, +} + +impl IdAllocator2 { + pub fn new() -> Self { + Self { + id_allocator: Arc::new(Mutex::new(IdAllocator::new())), + } + } + + pub fn new_with_start(start: usize) -> Self { + Self { + id_allocator: Arc::new(Mutex::new(IdAllocator::new_with_start(start))), + } + } + + pub fn alloc(&self) -> usize { + self.id_allocator.lock().unwrap().alloc() + } +} diff --git a/core/src/utils/indexed_retain.rs b/core/src/utils/indexed_retain.rs new file mode 100644 index 0000000..a7977d7 --- /dev/null +++ b/core/src/utils/indexed_retain.rs @@ -0,0 +1,30 @@ +pub trait IndexedRetain { + fn retain_with_index(&mut self, f: F) + where + F: FnMut(usize, &T) -> bool; +} + +impl IndexedRetain for Vec { + fn retain_with_index(&mut self, mut f: F) + where + F: FnMut(usize, &T) -> bool, // the signature of the callback changes + { + let len = self.len(); + let mut del = 0; + { + let v = &mut **self; + + for i in 0..len { + // only implementation change here + if !f(i, &v[i]) { + del += 1; + } else if del > 0 { + v.swap(i - del, i); + } + } + } + if del > 0 { + self.truncate(len - del); + } + } +} diff --git a/core/src/utils/integer.rs b/core/src/utils/integer.rs index 05b24a8..cca8f25 100644 --- a/core/src/utils/integer.rs +++ b/core/src/utils/integer.rs @@ -1,20 +1,20 @@ /// Integer trait (i8 - i128, u8 - u128, isize, usize) pub trait Integer: - Sized + - Copy + - Clone + - PartialEq + - Eq + - PartialOrd + - Ord + - std::fmt::Debug + - std::fmt::Display + - std::ops::Add + - std::ops::Sub + - std::ops::Mul + - std::ops::Div + - std::convert::TryInto + - std::convert::TryInto + Sized + + Copy + + Clone + + PartialEq + + Eq + + PartialOrd + + Ord + + std::fmt::Debug + + std::fmt::Display + + std::ops::Add + + std::ops::Sub + + std::ops::Mul + + std::ops::Div + + std::convert::TryInto + + std::convert::TryInto { fn zero() -> Self; @@ -24,8 +24,12 @@ 
pub trait Integer: macro_rules! impl_integer { ($type:ty) => { impl Integer for $type { - fn zero() -> Self { 0 } - fn one() -> Self { 1 } + fn zero() -> Self { + 0 + } + fn one() -> Self { + 1 + } } }; } diff --git a/core/src/utils/mod.rs b/core/src/utils/mod.rs index 643f9d5..7454e33 100644 --- a/core/src/utils/mod.rs +++ b/core/src/utils/mod.rs @@ -4,6 +4,7 @@ mod chrono; mod copy_on_write; mod float; mod id_allocator; +mod indexed_retain; mod integer; mod pointer_family; @@ -11,5 +12,6 @@ pub use self::chrono::*; pub(crate) use copy_on_write::*; pub use float::*; pub(crate) use id_allocator::*; +pub(crate) use indexed_retain::*; pub use integer::*; pub use pointer_family::*; diff --git a/core/src/utils/pointer_family.rs b/core/src/utils/pointer_family.rs index 3474993..9211bba 100644 --- a/core/src/utils/pointer_family.rs +++ b/core/src/utils/pointer_family.rs @@ -9,7 +9,6 @@ use std::sync::{Arc, Mutex}; /// where `Pointer` is the simple reference counted pointer and `Cell` /// contains a internally mutable pointer. 
pub trait PointerFamily: Clone + PartialEq + 'static { - /* ==================== Ref Counted ==================== */ /// Reference counted pointer @@ -113,14 +112,14 @@ impl PointerFamily for ArcFamily { fn get_cell(ptr: &Self::Cell, f: F) -> O where - F: FnOnce(&T) -> O + F: FnOnce(&T) -> O, { f(&ptr.lock().unwrap()) } fn get_cell_mut(ptr: &Self::Cell, f: F) -> O where - F: FnOnce(&mut T) -> O + F: FnOnce(&mut T) -> O, { f(&mut ptr.lock().unwrap()) } diff --git a/core/tests/compiler/adt.rs b/core/tests/compiler/adt.rs new file mode 100644 index 0000000..5553d46 --- /dev/null +++ b/core/tests/compiler/adt.rs @@ -0,0 +1,153 @@ +use scallop_core::integrate::*; +use scallop_core::runtime::provenance::*; +use scallop_core::testing::*; +use scallop_core::utils::*; + +#[test] +fn adt_duplicated_name_1() { + expect_front_compile_failure( + r#" + type Expr = Const(i32) | Add(Expr, Expr) | Minus(Expr, Expr) | Add(Expr, Expr) + "#, + |e| e.contains("duplicate"), + ) +} + +#[test] +fn adt_duplicated_name_2() { + expect_front_compile_failure( + r#" + type Expr = Const(i32) | Add(Expr, Expr) + type Expr2 = Const(i32) | Sub(Expr, Expr) + "#, + |e| e.contains("duplicate"), + ) +} + +#[test] +fn adt_entity() { + expect_compile( + r#" + type Expr = Const(i32) | Add(Expr, Expr) + const MY_EXPR = Add(Const(5), Const(3)) + "#, + ) +} + +#[test] +fn adt_entity_fail_1() { + expect_front_compile_failure( + r#" + type Expr = Const(i32) | Add(Expr, Expr) + const MY_EXPR = Add(Const(5 + 5), Const(3)) + "#, + |e| e.contains("non-constant"), + ) +} + +#[test] +fn adt_entity_fail_2() { + expect_front_compile_failure( + r#" + type Expr = Const(i32) | Add(Expr, Expr) + const MY_EXPR = Sub(Const(5), Const(3)) + "#, + |e| e.contains("unknown algebraic data type variant"), + ) +} + +#[test] +fn adt_entity_arity_mismatch_1() { + expect_front_compile_failure( + r#" + type Expr = Const(i32) | Add(Expr, Expr) + const MY_EXPR = Add(Const(5), Const(3), Const(7)) + "#, + |e| e.contains("arity mismatch"), 
+ ) +} + +#[test] +fn adt_entity_type_error_1() { + expect_front_compile_failure( + r#" + type Expr = Const(i32) | Add(Expr, Expr) + const MY_EXPR = Add(Const(5), Const("this is a string")) + "#, + |e| e.contains("cannot unify types"), + ) +} + +#[test] +fn adt_add_dynamic_entity_1() { + let prov = unit::UnitProvenance::new(); + let mut ctx = IntegrateContext::::new(prov); + + // Compile a program containing ADT definitions + ctx + .add_program( + r#" + type Expr = Const(i32) | Add(Expr, Expr) + type root(e: Expr) + rel eval(e, y) = case e is Const(y) + rel eval(e, y1 + y2) = case e is Add(e1, e2) and eval(e1, y1) and eval(e2, y2) + rel result(y) = root(e) and eval(e, y) + "#, + ) + .expect("Compile error"); + + // Dynamically add an entity to the context + ctx + .add_entity("root", vec!["Add(Const(5), Add(Const(2), Const(3)))".to_string()]) + .expect("Cannot add entity"); + + // Run the context + ctx.run().expect("Runtime error"); + + // Check the results + expect_output_collection( + "result", + ctx.computed_relation_ref("result").expect("Cannot get result"), + vec![(10i32,)], + ); +} + +#[test] +fn adt_add_dynamic_entity_2() { + let prov = unit::UnitProvenance::new(); + let mut ctx = IntegrateContext::::new(prov); + + // Compile a program containing ADT definitions + ctx + .add_program( + r#" + type Expr = Const(i32) | Add(Expr, Expr) + type root(id: i32, e: Expr) + rel eval(e, y) = case e is Const(y) + rel eval(e, y1 + y2) = case e is Add(e1, e2) and eval(e1, y1) and eval(e2, y2) + rel result(id, y) = root(id, e) and eval(e, y) + "#, + ) + .expect("Compile error"); + + // Dynamically add an entity to the context + ctx + .add_entity( + "root", + vec!["1".to_string(), "Add(Const(5), Add(Const(2), Const(3)))".to_string()], + ) + .expect("Cannot add entity"); + ctx + .add_entity("root", vec!["2".to_string(), "Const(3)".to_string()]) + .expect("Cannot add entity"); + + // Run the context + ctx.run().expect("Runtime error"); + + // Check the results + 
expect_output_collection( + "result", + ctx.computed_relation_ref("result").expect("Cannot get result"), + vec![(1i32, 10i32), (2, 3)], + ); +} diff --git a/core/tests/compiler/mod.rs b/core/tests/compiler/mod.rs index 38ed184..6988fc7 100644 --- a/core/tests/compiler/mod.rs +++ b/core/tests/compiler/mod.rs @@ -1,3 +1,4 @@ +mod adt; mod errors; mod incremental; mod parse; diff --git a/core/tests/compiler/parse.rs b/core/tests/compiler/parse.rs index ac59a81..d33dc00 100644 --- a/core/tests/compiler/parse.rs +++ b/core/tests/compiler/parse.rs @@ -29,3 +29,85 @@ fn parse_rule() { assert!(str_to_item(r#"rel path(a, b) :- path(a, c) /\ edge(c, b)"#).is_ok()); assert!(str_to_item(r#"rel path(a, b) :- edge(a, b) \/ path(a, c) /\ edge(c, b)"#).is_ok()); } + +#[test] +fn ignore_comment_1() { + assert!(str_to_item(r#"rel relate = { /* this is a comment */ }"#).is_ok()); + assert!(str_to_item(r#"rel relate = /* this is a comment */ { }"#).is_ok()); + assert!(str_to_item(r#"rel relate /* this is a comment */ = { }"#).is_ok()); +} + +#[test] +fn ignore_comment_2() { + let items = str_to_items(r#" + rel relate = { /* this is a comment */ } + rel another_relate() = // this is another comment + some_atom() /* this is another comment */ + "#).expect("Compile failure"); + assert_eq!(items.len(), 2); +} + +#[test] +fn ignore_comment_3() { + let items = str_to_items(r#" + rel relate = { (3, 5 /* , 4, pretending to be commented out */) } + "#).expect("Compile failure"); + assert_eq!(items.len(), 1); +} + +#[test] +fn test_parse_specialized_predicate_1() { + let (id, args) = str_to_specialized_predicate("range").expect("Cannot parse"); + assert_eq!(id.name(), "range"); + assert_eq!(id.loc.offset_span.start, 0); + assert_eq!(id.loc.offset_span.end, 5); + assert_eq!(args.len(), 1); + assert_eq!(args[0].name(), "usize"); + assert_eq!(args[0].loc.offset_span.start, 6); + assert_eq!(args[0].loc.offset_span.end, 11); +} + +#[test] +fn test_parse_specialized_predicate_2() { + let (id, args) = 
str_to_specialized_predicate("range< usize,usize >").expect("Cannot parse"); + assert_eq!(id.name(), "range"); + assert_eq!(id.loc.offset_span.start, 0); + assert_eq!(id.loc.offset_span.end, 5); + assert_eq!(args.len(), 2); + assert_eq!(args[0].name(), "usize"); + assert_eq!(args[0].loc.offset_span.start, 9); + assert_eq!(args[0].loc.offset_span.end, 14); + assert_eq!(args[1].name(), "usize"); + assert_eq!(args[1].loc.offset_span.start, 15); + assert_eq!(args[1].loc.offset_span.end, 20); +} + +#[test] +fn test_parse_specialized_predicate_3() { + let (id, args) = str_to_specialized_predicate("dasdf").expect("Cannot parse"); + assert_eq!(id.name(), "dasdf"); + assert_eq!(id.loc.offset_span.start, 0); + assert_eq!(id.loc.offset_span.end, 5); + assert_eq!(args.len(), 2); + assert_eq!(args[0].name(), "usize"); + assert_eq!(args[0].loc.offset_span.start, 6); + assert_eq!(args[0].loc.offset_span.end, 11); + assert_eq!(args[1].name(), "f32"); + assert_eq!(args[1].loc.offset_span.start, 13); + assert_eq!(args[1].loc.offset_span.end, 16); +} + +#[test] +fn test_parse_specialized_predicate_4() { + let (id, args) = str_to_specialized_predicate("dasdf < usize , f32 >").expect("Cannot parse"); + assert_eq!(id.name(), "dasdf"); + assert_eq!(id.loc.offset_span.start, 0); + assert_eq!(id.loc.offset_span.end, 5); + assert_eq!(args.len(), 2); + assert_eq!(args[0].name(), "usize"); + assert_eq!(args[0].loc.offset_span.start, 10); + assert_eq!(args[0].loc.offset_span.end, 15); + assert_eq!(args[1].name(), "f32"); + assert_eq!(args[1].loc.offset_span.start, 21); + assert_eq!(args[1].loc.offset_span.end, 24); +} diff --git a/core/tests/integrate/adt.rs b/core/tests/integrate/adt.rs new file mode 100644 index 0000000..32f4d25 --- /dev/null +++ b/core/tests/integrate/adt.rs @@ -0,0 +1,308 @@ +use scallop_core::testing::*; + +#[test] +fn adt_arith_formula_eval_1() { + expect_interpret_result( + r#" + type Expr = Const(i32) | Add(Expr, Expr) + + rel eval(x, y) = case x is Const(y) + rel 
eval(x, y1 + y2) = case x is Add(x1, x2) and + eval(x1, y1) and eval(x2, y2) + + const MY_EXPR = Add(Const(5), Add(Const(3), Const(6))) + + rel result(y) = eval(MY_EXPR, y) + "#, + ("result", vec![(14i32,)]), + ) +} + +#[test] +fn adt_list_1() { + expect_interpret_result( + r#" + type List = Nil() | Cons(i32, List) + + rel list_sum(l, 0) = case l is Nil() + rel list_sum(l, hd + s) = case l is Cons(hd, tl) and list_sum(tl, s) + + const MY_LIST = Cons(1, Cons(2, Cons(3, Nil()))) + + rel result(y) = list_sum(MY_LIST, y) + "#, + ("result", vec![(6i32,)]), + ) +} + +#[test] +fn adt_binary_tree_1() { + expect_interpret_result( + r#" + type Tree = Nil() | Node(i32, Tree, Tree) + + rel tree_depth(t, 0) = case t is Nil() + rel tree_depth(t, $max(ld, rd) + 1) = case t is Node(_, lt, rt) and + tree_depth(lt, ld) and tree_depth(rt, rd) + + const MY_TREE = Node(1, Node(2, Nil(), Node(3, Nil(), Nil())), Node(4, Nil(), Nil())) + + rel result(y) = tree_depth(MY_TREE, y) + "#, + ("result", vec![(3i32,)]), + ) +} + +const RE_PROGRAM: &'static str = r#" + type RE = Char(char) | Nil() | Con(RE, RE) | Or(RE, RE) | Star(RE) + + rel match(r, i, i) = case r is Nil(), string_chars(s, i, _), input_string(s) + rel match(r, i, i + 1) = case r is Char(c), input_string(s), string_chars(s, i, c) + rel match(r, s, e) = case r is Con(r1, r2), match(r1, s, m), match(r2, m, e) + rel match(r, s, e) = case r is Or(r1, r2), match(r1, s, e) + rel match(r, s, e) = case r is Or(r1, r2), match(r2, s, e) + rel match(r, i, i) = case r is Star(r1), string_chars(s, i, _), input_string(s) + rel match(r, s, e) = case r is Star(r1), match(r1, s, e) + rel match(r, s, e) = case r is Star(r1), match(r1, s, m), match(r, m, e) +"#; + +#[test] +fn adt_regex_1() { + expect_interpret_result( + &format!( + "{RE_PROGRAM}\n{}", + r#" + const MY_RE = Con(Char('a'), Char('b')) + rel input_string("ab") + rel result() = match(MY_RE, 0, 2) + "#, + ), + ("result", vec![()]), + ) +} + +#[test] +fn adt_regex_2() { + 
expect_interpret_result( + &format!( + "{RE_PROGRAM}\n{}", + r#" + const MY_RE = Con(Star(Char('a')), Char('b')) + rel input_string("aaaaaaaab") + rel result() = match(MY_RE, 0, 9) + "#, + ), + ("result", vec![()]), + ) +} + +const CLEVR_PROGRAM: &'static str = r#" + type Color = RED | GREEN | BLUE + type Size = LARGE | SMALL + type SpatialRela = LEFT | RIGHT + type Expr = Scene() | Color(Color, Expr) | Size(Size, Expr) | Rela(SpatialRela, Expr, Expr) | RelaInv(SpatialRela, Expr, Expr) + + rel eval(e, output_obj) = case e is Scene(), input_obj_ids(output_obj) + rel eval(e, output_obj) = case e is Color(c, e1), eval(e1, output_obj), input_obj_color(output_obj, c) + rel eval(e, output_obj) = case e is Size(s, e1), eval(e1, output_obj), input_obj_size(output_obj, s) + rel eval(e, o2) = case e is Rela(r, e1, e2), eval(e1, o1), eval(e2, o2), input_obj_rela(r, o1, o2) + rel eval(e, o1) = case e is RelaInv(r, e1, e2), eval(e1, o1), eval(e2, o2), input_obj_rela(r, o1, o2) +"#; + +#[test] +fn adt_clevr_1() { + expect_interpret_result( + &format!( + "{CLEVR_PROGRAM}\n{}", + r#" + rel input_obj_ids = {0, 1} + rel input_obj_color = {(0, RED), (1, GREEN)} + rel input_obj_size = {(0, LARGE), (1, SMALL)} + rel input_obj_rela = {(0, 1, LEFT), (1, 0, RIGHT)} + + const MY_EXPR = Color(RED, Scene()) + + rel result(o) = eval(MY_EXPR, o) + "#, + ), + ("result", vec![(0usize,)]), + ) +} + +const EQSAT_1_PROGRAM: &'static str = r#" + // The language for simple symbolic arithmetic expression + type Expr = Const(i32) + | Var(String) + | Add(Expr, Expr) + + // A relation `to_string` for visualizing + rel to_string(p, i as String) = case p is Const(i) + rel to_string(p, v) = case p is Var(v) + rel to_string(p, $format("({} + {})", s1, s2)) = case p is Add(p1, p2) and to_string(p1, s1) and to_string(p2, s2) + + // Relation for expression + rel expr(p) = case p is Const(_) or case p is Var(_) or case p is Add(_, _) + + // Definition of rewrite rules suggesting equivalence + rel equivalent(p, 
p) = expr(p) + rel equivalent(p1, p3) = equivalent(p1, p2) and equivalent(p2, p3) + rel equivalent(p, new Add(b, a)) = case p is Add(a, b) + rel equivalent(p1, new Add(a2, b2)) = case p1 is Add(a1, b1) and equivalent(a1, a2) and equivalent(b1, b2) + rel equivalent(p, new Const(a + b)) = case p is Add(Const(a), Const(b)) + rel equivalent(p, p1) = case p is Add(p1, Const(0)) + + // Definition of weight + rel weight(p, 1) = case p is Const(_) + rel weight(p, 1) = case p is Var(_) + rel weight(p, w1 + w2 + 1) = case p is Add(p1, p2) and weight(p1, w1) and weight(p2, w2) + + // Compute equivalent programs + rel equiv_programs(sp) = input_program(p) and equivalent(p, sp) + + // Find the best program (minimum weight) among all programs equivalent to p + rel best_program(p) = w := min[p](w: equiv_programs(p) and weight(p, w)) + rel best_program_str(s) = best_program(best_prog) and to_string(best_prog, s) + query best_program_str +"#; + +#[test] +fn equality_saturation_1() { + expect_interpret_result( + &format!( + "{EQSAT_1_PROGRAM}\n{}", + r#" + const PROGRAM = Add(Add(Const(3), Const(-3)), Var("a")) + rel input_program(PROGRAM) + "#, + ), + ("best_program_str", vec![("a".to_string(),)]), + ) +} + +#[test] +fn equality_saturation_2() { + expect_interpret_result( + &format!( + "{EQSAT_1_PROGRAM}\n{}", + r#" + const PROGRAM = Add(Add(Const(3), Const(-3)), Const(5)) + rel input_program(PROGRAM) + "#, + ), + ("best_program_str", vec![("5".to_string(),)]), + ) +} + +const TYPE_INF_1_PROGRAM: &'static str = r#" + type Op = EQ | NEQ | GEQ | LEQ | GT | LT | AND | OR | XOR | ADD | SUB | MUL | DIV | NEG | NOT + + type Expr = Number(i32) + | Boolean(bool) + | Variable(String) + | Binary(Op, Expr, Expr) + | Unary(Op, Expr) + | Let(String, Expr, Expr) + | Ite(Expr, Expr, Expr) + + type Type = BOOL | INT + + type input_program(expr: Expr) + + // ================= + + // Pretty printing of operators + rel op_to_string = { + (EQ, "=="), (NEQ, "!="), + (GEQ, ">="), (LEQ, "<="), (GT, ">"), 
(LT, "<"), + (AND, "&&"), (OR, "||"), (XOR, "^"), + (ADD, "+"), (SUB, "-"), (MUL, "*"), (DIV, "/"), + (NEG, "-"), (NOT, "!") + } + + // Pretty printing of type + rel ty_to_string = {(BOOL, "bool"), (INT, "int")} + + // Pretty printing of expressions + rel expr_to_string(e, x as String) = case e is Number(x) + rel expr_to_string(e, x as String) = case e is Boolean(x) + rel expr_to_string(e, x) = case e is Variable(x) + rel expr_to_string(e, $format("({} {} {})", op1_str, op_str, op2_str)) = case e is Binary(op, op1, op2) and expr_to_string(op1, op1_str) and expr_to_string(op2, op2_str) and op_to_string(op, op_str) + rel expr_to_string(e, $format("({}{})", op_str, op1_str)) = case e is Unary(op, op1) and expr_to_string(op1, op1_str) and op_to_string(op, op_str) + rel expr_to_string(e, $format("let {} = {} in {}", x, b_str, i_str)) = case e is Let(x, b, i) and expr_to_string(b, b_str) and expr_to_string(i, i_str) + rel expr_to_string(e, $format("if {} then {} else {}", cs, ts, es)) = case e is Ite(c, t, e) and expr_to_string(c, cs) and expr_to_string(t, ts) and expr_to_string(e, es) + + // ================= + + // Basic types of operators + rel eq_op = {EQ, NEQ} + rel comp_op = {GEQ, LEQ, GT, LT} + rel logical_op = {AND, OR, XOR} + rel arith_op = {ADD, SUB, MUL, DIV} + rel unary_arith_op = {NEG} + rel unary_logical_op = {NOT} + + // Typing environment + type Env = Empty() | Cons(String, Type, Env) + const EMPTY_ENV = Empty() + + // Find a variable stored in the typing environment + @demand("bbf") + rel find_type(env, var, ty) = case e is Cons(var, ty, _) + rel find_type(env, var, ty) = case e is Cons(vp, _, tl) and vp != var and find_type(tl, var, ty) + + // The type (`ty`) of an expression (`expr`) under an environment (`env`) + type type_of(env: Env, expr: Expr, ty: Type) + + // Typing rules + @demand("bbf") + rel type_of(env, e, BOOL) = case e is Boolean(_) + rel type_of(env, e, INT) = case e is Number(_) + rel type_of(env, e, ty) = case e is Variable(x) and 
find_type(env, x, ty) + rel type_of(env, e, BOOL) = case e is Binary(op, op1, op2) and eq_op(op) and type_of(env, op1, ty) and type_of(env, op2, ty) + rel type_of(env, e, BOOL) = case e is Binary(op, op1, op2) and comp_op(op) and type_of(env, op1, INT) and type_of(env, op2, INT) + rel type_of(env, e, BOOL) = case e is Binary(op, op1, op2) and logical_op(op) and type_of(env, op1, BOOL) and type_of(env, op2, BOOL) + rel type_of(env, e, INT) = case e is Binary(op, op1, op2) and arith_op(op) and type_of(env, op1, INT) and type_of(env, op2, INT) + rel type_of(env, e, BOOL) = case e is Unary(op, op1) and unary_logical_op(op) and type_of(env, op1, BOOL) + rel type_of(env, e, INT) = case e is Unary(op, op1) and unary_arith_op(op) and type_of(env, op1, INT) + rel type_of(env, e, ty_i) = to_infer_let_cons(env, e, sub_env, i) and type_of(sub_env, i, ty_i) + rel type_of(env, e, ty) = case e is Ite(c, t, e) and type_of(env, c, BOOL) and type_of(env, t, ty) and type_of(env, e, ty) + + // Helpers + @demand("bbff") + rel to_infer_let_cons(env, e, new Cons(x, ty_b, env), i) = case e is Let(x, b, i) and type_of(env, b, ty_b) + + // The result if the type of the input program + rel result(expr_str, ty_str) = input_program(p) and expr_to_string(p, expr_str) and type_of(EMPTY_ENV, p, ty) and ty_to_string(ty, ty_str) + query result +"#; + +#[test] +fn type_inf_1() { + expect_interpret_result( + &format!( + "{TYPE_INF_1_PROGRAM}\n{}", + r#" + const PROGRAM = Let("x", Number(3), Binary(EQ, Variable("x"), Number(4))) + rel input_program(PROGRAM) + "#, + ), + ( + "result", + vec![("let x = 3 in (x == 4)".to_string(), "bool".to_string())], + ), + ) +} + +#[test] +fn type_inf_2() { + expect_interpret_empty_result( + &format!( + "{TYPE_INF_1_PROGRAM}\n{}", + r#" + const PROGRAM = Let("x", Number(3), Binary(ADD, Variable("x"), Boolean(false))) + rel input_program(PROGRAM) + "#, + ), + "result", + ) +} diff --git a/core/tests/integrate/attr.rs b/core/tests/integrate/attr.rs new file mode 100644 
index 0000000..71406a5 --- /dev/null +++ b/core/tests/integrate/attr.rs @@ -0,0 +1,67 @@ +use scallop_core::testing::*; +use scallop_core::compiler::front::FrontContext; +use scallop_core::compiler::front::ast; +use scallop_core::compiler::front::attribute::*; + +mod attr_1 { + use super::*; + + #[derive(Clone)] + struct Foo; + + impl AttributeProcessor for Foo { + fn name(&self) -> String { + "foo".to_string() + } + + fn apply(&self, _: &ast::Item, attr: &ast::Attribute) -> Result { + if attr.num_pos_args() != 3 { + Err(AttributeError::new_custom("foo attribute requires 3 arguments".to_string())) + } else { + if attr.pos_arg(0).and_then(|arg| Some(arg.is_tuple())) == Some(true) { + Ok(AttributeAction::Nothing) + } else { + Err(AttributeError::new_custom("foo attribute requires a tuple as the first argument".to_string())) + } + } + } + } + + #[test] + fn attr_1_test_1() { + expect_compile( + r#" + @foo((1, 2), 3, 4) + type my_relation(a: i32, b: i32) + "#, + ); + } + + #[test] + fn attr_1_test_2() { + expect_front_compile_failure_with_modifier( + r#" + @foo((1, 2), 3) + type my_relation(a: i32, b: i32) + "#, + |ctx: &mut FrontContext| { + ctx.register_attribute_processor(Foo).expect("Cannot register attribute"); + }, + |s| s.contains("foo attribute requires 3 arguments"), + ); + } + + #[test] + fn attr_1_test_3() { + expect_front_compile_failure_with_modifier( + r#" + @foo("asdfasdf", 3, 5) + type my_relation(a: i32, b: i32) + "#, + |ctx: &mut FrontContext| { + ctx.register_attribute_processor(Foo).expect("Cannot register attribute"); + }, + |s| s.contains("foo attribute requires a tuple as the first argument"), + ); + } +} diff --git a/core/tests/integrate/basic.rs b/core/tests/integrate/basic.rs index d18a99f..d52e7cf 100644 --- a/core/tests/integrate/basic.rs +++ b/core/tests/integrate/basic.rs @@ -1,6 +1,6 @@ use scallop_core::runtime::provenance::*; -use scallop_core::utils::*; use scallop_core::testing::*; +use scallop_core::utils::*; #[test] fn 
basic_edge_path_left_recursion() { @@ -653,10 +653,7 @@ fn test_count_with_where_clause() { // Count how many student enrolls in CS class in each class rel count_enroll_cs_in_class(c, n) :- n = count(s: student(c, s), enroll(s, "CS") where c: classes(c)) "#, - vec![( - "count_enroll_cs_in_class", - vec![(0, 1usize), (1, 2), (2, 0)].into(), - )], + vec![("count_enroll_cs_in_class", vec![(0, 1usize), (1, 2), (2, 0)].into())], ) } @@ -669,10 +666,7 @@ fn test_exists_path_1() { rel result1(b) = b := exists(path(0, 2)) rel result2(b) = b := exists(path(0, 3)) "#, - vec![ - ("result1", vec![(true,)].into()), - ("result2", vec![(false,)].into()), - ], + vec![("result1", vec![(true,)].into()), ("result2", vec![(false,)].into())], ) } @@ -1175,14 +1169,16 @@ fn disjunctive_1() { let prov = proofs::ProofsProvenance::::default(); // Pre-generate true tags and false tags - let true_tag = proofs::Proofs::from_proofs(vec![ - proofs::Proof::from_facts(vec![0, 1, 2, 4].into_iter()), - proofs::Proof::from_facts(vec![0, 1, 2, 5].into_iter()), - proofs::Proof::from_facts(vec![0, 1, 3, 4].into_iter()), - ].into_iter()); - let false_tag = proofs::Proofs::from_proofs(vec![ - proofs::Proof::from_facts(vec![0, 1, 3, 5].into_iter()), - ].into_iter()); + let true_tag = proofs::Proofs::from_proofs( + vec![ + proofs::Proof::from_facts(vec![0, 1, 2, 4].into_iter()), + proofs::Proof::from_facts(vec![0, 1, 2, 5].into_iter()), + proofs::Proof::from_facts(vec![0, 1, 3, 4].into_iter()), + ] + .into_iter(), + ); + let false_tag = + proofs::Proofs::from_proofs(vec![proofs::Proof::from_facts(vec![0, 1, 3, 5].into_iter())].into_iter()); // Test expect_interpret_result_with_tag( @@ -1196,3 +1192,392 @@ fn disjunctive_1() { proofs::Proofs::eq, ) } + +#[test] +fn escape_single_newline_char() { + expect_interpret_result( + r#" + rel str = {"Hello\nWorld"} + "#, + ("str", vec![("Hello\nWorld".to_string(),)]), + ); +} + +#[test] +fn escape_multiple_newline_char() { + expect_interpret_result( + r#" + rel str 
= {"Hello\n\n\nWorld"} + "#, + ("str", vec![("Hello\n\n\nWorld".to_string(),)]), + ); +} + +#[test] +fn escape_newline_char_end() { + expect_interpret_result( + r#" + rel str = {"Scallop\n"} + "#, + ("str", vec![("Scallop\n".to_string(),)]), + ); +} + +#[test] +fn escape_newline_char_beginning() { + expect_interpret_result( + r#" + rel str = {"\nScallop"} + "#, + ("str", vec![("\nScallop".to_string(),)]), + ); +} + +#[test] +fn escape_tab_char() { + expect_interpret_result( + r#" + rel str = {"Hello\tWorld"} + "#, + ("str", vec![("Hello\tWorld".to_string(),)]), + ); +} + +#[test] +fn escape_null_char() { + expect_interpret_result( + r#" + rel str = {"Null\0"} + "#, + ("str", vec![("Null\0".to_string(),)]), + ); +} + +#[test] +fn escape_single_quote() { + expect_interpret_result( + r#" + rel str = {"Here is a quote: \'Hi\'"} + "#, + ("str", vec![("Here is a quote: \'Hi\'".to_string(),)]), + ); +} + +#[test] +fn escape_double_quote() { + expect_interpret_result( + r#" + rel str = {"Here is a quote: \"Hi\""} + "#, + ("str", vec![("Here is a quote: \"Hi\"".to_string(),)]), + ); +} + +#[test] +fn escape_backslash() { + expect_interpret_result( + r#" + rel str = {"Back \\ Slash"} + "#, + ("str", vec![("Back \\ Slash".to_string(),)]), + ); +} + +#[test] +fn escape_carriage_return() { + expect_interpret_result( + r#" + rel str = {"Carriage Return\r"} + "#, + ("str", vec![("Carriage Return\r".to_string(),)]), + ); +} + +// #[test] +// fn escape_unicode() { +// expect_interpret_result( +// r#" +// rel str = {"Thumbs up: \u{1F44D}"} +// "#, +// ("str", vec![("Thumbs up: \u{1F44D}".to_string(),)]), +// ); +// } + +#[test] +fn escape_emoji_unicode() { + expect_interpret_result( + r#" + rel str = {"Thumbs up: 👍"} + "#, + ("str", vec![("Thumbs up: \u{1F44D}".to_string(),)]), + ); +} + +#[test] +fn escape_multiline_string() { + expect_interpret_result( + r#" + rel str = {"""This +is +a +multiline +string"""} + "#, + ("str", vec![("This\nis\na\nmultiline\nstring".to_string(),)]), + 
); +} + +#[test] +fn escape_indented_multiline_string() { + expect_interpret_result( + r#" + rel str = {"""This + is + a + multiline + string"""} + "#, + ("str", vec![("This\n is\n a\n multiline\n string".to_string(),)]), + ); +} + +#[test] +fn escape_mix_multiline_string_before() { + expect_interpret_result( + r#" + rel str = {"""This +is +a +multiline +string""", "A regular string"} + "#, + ("str", vec![("This\nis\na\nmultiline\nstring".to_string(),), ("A regular string".to_string(),)]), + ); +} + +#[test] +fn escape_mix_multiline_string_after() { + expect_interpret_result( + r#" + rel str = {"First string", """This +is +a +multiline +string"""} + "#, + ("str", vec![("First string".to_string(),), ("This\nis\na\nmultiline\nstring".to_string(),)]), + ); +} + +#[test] +fn escape_multiple_multiline_string() { + expect_interpret_result( + r#" + rel str = {"""This +is +a +multiline +string""", +"""A +second +multiline +string"""} + "#, + ("str", vec![("This\nis\na\nmultiline\nstring".to_string(),), ("A\nsecond\nmultiline\nstring".to_string(),)]), + ); +} + +#[test] +fn escape_multiline_string_with_regular_string() { + expect_interpret_result( + r#" + rel str = {"""Here is a multiline string with quote: +"This is a test" +By John Doe"""} + "#, + ("str", vec![("Here is a multiline string with quote:\n\"This is a test\"\nBy John Doe".to_string(),)]), + ); +} + +#[test] +fn escape_multiline_string_with_multiple_regular_string() { + expect_interpret_result( + r#" + rel str = {"""Here is a multiline string with quote: +"This is a test" +By +"Anonymous" +"""} + "#, + ("str", vec![("Here is a multiline string with quote:\n\"This is a test\"\nBy\n\"Anonymous\"\n".to_string(),)]), + ); +} + +#[test] +fn escape_multiline_string_with_escaped_double_quote_string() { + expect_interpret_result( + r#" + rel str = {"""Here is a multiline string with quote: +\"\"This is a test\"\" +By Jane Doe"""} + "#, + ("str", vec![("Here is a multiline string with quote:\n\"\"This is a test\"\"\nBy 
Jane Doe".to_string(),)]), + ); +} + +#[test] +fn escape_multiline_string_with_double_quote_string() { + expect_interpret_result( + r#" + rel str = {"""Here is a multiline string with quote: +""This is a test"" +By Jane Doe"""} + "#, + ("str", vec![("Here is a multiline string with quote:\nThis is a test\nBy Jane Doe".to_string(),)]), + ); +} + +#[test] +fn escape_multiline_string_with_triple_quote() { + expect_interpret_result( + r#" + rel str = {"""This is not the end +\"\"\" +But this is"""} + "#, + ("str", vec![("This is not the end\n\"\"\"\nBut this is".to_string(),)]), + ); +} + +#[test] +fn escape_multiline_string_as_single_line() { + expect_interpret_result( + r#" + rel str = {"""This is only one line"""} + "#, + ("str", vec![("This is only one line".to_string(),)]), + ); +} + +#[test] +fn escape_multiline_string_single_newline_char() { + expect_interpret_result( + r#" + rel str = {"""Hello +\n +World"""} + "#, + ("str", vec![("Hello\n\n\nWorld".to_string(),)]), + ); +} + +#[test] +fn escape_multiline_string_multiple_newline_char() { + expect_interpret_result( + r#" + rel str = {"""Hello +\n\n\n +World"""} + "#, + ("str", vec![("Hello\n\n\n\n\nWorld".to_string(),)]), + ); +} + +#[test] +fn escape_multiline_string_newline_char_end() { + expect_interpret_result( + r#" + rel str = {"""Scallop +\n"""} + "#, + ("str", vec![("Scallop\n\n".to_string(),)]), + ); +} + +#[test] +fn escape_multiline_string_newline_char_beginning() { + expect_interpret_result( + r#" + rel str = {"""\n +Scallop"""} + "#, + ("str", vec![("\n\nScallop".to_string(),)]), + ); +} + +#[test] +fn escape_multiline_string_tab_char() { + expect_interpret_result( + r#" + rel str = {"""Hello +\tWorld"""} + "#, + ("str", vec![("Hello\n\tWorld".to_string(),)]), + ); +} + +#[test] +fn escape_multiline_string_null_char() { + expect_interpret_result( + r#" + rel str = {"""Null +\0"""} + "#, + ("str", vec![("Null\n\0".to_string(),)]), + ); +} + +#[test] +fn escape_multiline_string_single_quote() { + 
expect_interpret_result( + r#" + rel str = {"""Here is a quote: +\'Hi\'"""} + "#, + ("str", vec![("Here is a quote:\n\'Hi\'".to_string(),)]), + ); +} + +#[test] +fn escape_multiline_string_double_quote() { + expect_interpret_result( + r#" + rel str = {"""Here is a quote: +\"Hi\""""} + "#, + ("str", vec![("Here is a quote:\n\"Hi\"".to_string(),)]), + ); +} + +#[test] +fn escape_multiline_string_backslash() { + expect_interpret_result( + r#" + rel str = {"""Back +\\ Slash"""} + "#, + ("str", vec![("Back\n\\ Slash".to_string(),)]), + ); +} + +#[test] +fn escape_multiline_string_carriage_return() { + expect_interpret_result( + r#" + rel str = {"""Carriage +Return\r"""} + "#, + ("str", vec![("Carriage\nReturn\r".to_string(),)]), + ); +} + +#[test] +fn escape_multiline_string_emoji_unicode() { + expect_interpret_result( + r#" + rel str = {"""Thumbs up: +👍"""} + "#, + ("str", vec![("Thumbs up:\n\u{1F44D}".to_string(),)]), + ); +} diff --git a/core/tests/integrate/ff.rs b/core/tests/integrate/ff.rs index 0cb8b01..f40a24f 100644 --- a/core/tests/integrate/ff.rs +++ b/core/tests/integrate/ff.rs @@ -1,12 +1,12 @@ use std::convert::*; -use scallop_core::utils::*; -use scallop_core::common::value::*; use scallop_core::common::foreign_function::*; use scallop_core::common::type_family::*; -use scallop_core::runtime::provenance; +use scallop_core::common::value::*; use scallop_core::integrate; +use scallop_core::runtime::provenance; use scallop_core::testing::*; +use scallop_core::utils::*; #[derive(Clone)] pub struct Fib; @@ -87,7 +87,10 @@ fn test_fib_ff() { ctx.add_rule(r#"S(x, $fib(x)) = R(x)"#).unwrap(); // Facts - ctx.edb().add_facts("R", vec![(-10i32,), (0,), (3,), (5,), (8,)]).unwrap(); + ctx + .edb() + .add_facts("R", vec![(-10i32,), (0,), (3,), (5,), (8,)]) + .unwrap(); // Execution ctx.run().unwrap(); @@ -205,3 +208,294 @@ fn ff_substring_2() { ("result", vec![("world!".to_string(),)]), ); } + +#[test] +fn ff_floor_1() { + expect_interpret_result( + r#" + rel my_rel = 
{-1.5, 3.8, 5.0, -6.0} + rel result($floor(x)) = my_rel(x) + "#, + ("result", vec![(-2.0f32,), (3.0,), (5.0,), (-6.0,)]), + ); +} + +#[test] +fn ff_floor_2() { + expect_interpret_result( + r#" + rel my_rel = {-1, 50, 12345, 0} + rel result($floor(x)) = my_rel(x) + "#, + ("result", vec![(-1i32,), (50,), (12345,), (0,)]), + ); +} + +#[test] +fn ff_ceil_1() { + expect_interpret_result( + r#" + rel my_rel = {-1.5, 3.8, 5.0, -6.0} + rel result($ceil(x)) = my_rel(x) + "#, + ("result", vec![(-1.0f32,), (4.0,), (5.0,), (-6.0,)]), + ); +} + +#[test] +fn ff_ceil_2() { + expect_interpret_result( + r#" + rel my_rel = {-1, 50, 12345, 0} + rel result($ceil(x)) = my_rel(x) + "#, + ("result", vec![(-1i32,), (50,), (12345,), (0,)]), + ); +} + +#[test] +fn ff_exp_1() { + expect_interpret_result( + r#" + rel my_rel = {-1.5, 3.8, 0.0} + rel result($exp(x)) = my_rel(x) + "#, + ("result", vec![((-1.5f32).exp(),), (3.8f32.exp(),), (1.0,)]), + ); +} + +#[test] +fn ff_exp2_1() { + expect_interpret_result( + r#" + rel my_rel = {-1.5, 3.8, 0.0} + rel result($exp2(x)) = my_rel(x) + "#, + ("result", vec![((-1.5f32).exp2(),), (3.8f32.exp2(),), (1.0,)]), + ); +} + +#[test] +fn ff_log_1() { + expect_interpret_result( + r#" + rel my_rel = {12.5, 345.89, 1.0, -2.7} + rel result($log(x)) = my_rel(x) + "#, + ("result", vec![(12.5f32.ln(),), (345.89f32.ln(),), (0.0,)]), + ); +} + +#[test] +fn ff_log2_1() { + expect_interpret_result( + r#" + rel my_rel = {12.5, 345.89, 1.0, -2.7} + rel result($log2(x)) = my_rel(x) + "#, + ("result", vec![(12.5f32.log2(),), (345.89f32.log2(),), (0.0,)]), + ); +} + +#[test] +fn ff_pow_1() { + expect_interpret_result( + r#" + rel base = {2, 7} + rel exp = {3, 10, 0} + rel result(n) = base(x), exp(y), n == $pow(x, y) + "#, + ( + "result", + vec![ + (2i32.pow(3),), + (2i32.pow(10),), + (1,), + (7i32.pow(3),), + (7i32.pow(10),), + (1,), + ], + ), + ); +} + +#[test] +fn ff_powf_1() { + expect_interpret_result( + r#" + rel base = {1.5, 3.8} + rel exp = {-2.5, 4.8, 0.0} + rel 
result(n) = base(x), exp(y), n == $powf(x, y) + "#, + ( + "result", + vec![ + (1.5f32.powf(-2.5),), + (1.5f32.powf(4.8),), + (1.0,), + (3.8f32.powf(-2.5),), + (3.8f32.powf(4.8),), + (1.0,), + ], + ), + ); +} + +#[test] +fn ff_acos_1() { + expect_interpret_result( + r#" + rel my_rel = {-1.0, 1.0, 0.0, 1.5, -1.5} + rel result($acos(x)) = my_rel(x) + "#, + ("result", vec![((-1.0f32).acos(),), (1.0f32.acos(),), (0.0f32.acos(),)]), + ); +} + +#[test] +fn ff_asin_1() { + expect_interpret_result( + r#" + rel my_rel = {-1.0, 1.0, 0.0, 1.5, -1.5} + rel result($asin(x)) = my_rel(x) + "#, + ("result", vec![((-1.0f32).asin(),), (1.0f32.asin(),), (0.0f32.asin(),)]), + ); +} + +#[test] +fn ff_atan_1() { + expect_interpret_result( + r#" + rel my_rel = {-1.0, 1.0, 0.0} + rel result($atan(x)) = my_rel(x) + "#, + ("result", vec![((-1.0f32).atan(),), (1.0f32.atan(),), (0.0f32.atan(),)]), + ); +} + +#[test] +fn ff_atan2_1() { + expect_interpret_result( + r#" + rel first = {1.5, -3.8, 0.0} + rel second = {-2.5, 4.8, 0.0} + rel result(n) = first(y), second(x), n == $atan2(y, x) + "#, + ( + "result", + vec![ + (1.5f32.atan2(-2.5),), + (1.5f32.atan2(4.8),), + (1.5f32.atan2(0.0),), + ((-3.8f32).atan2(-2.5),), + ((-3.8f32).atan2(4.8),), + ((-3.8f32).atan2(0.0),), + (0.0f32.atan2(-2.5),), + (0.0f32.atan2(4.8),), + (0.0f32.atan2(0.0),), + ], + ), + ); +} + +#[test] +fn ff_sign_1() { + expect_interpret_result( + r#" + rel my_rel = {-12.5, 34.6, 0.0} + rel result($sign(x)) = my_rel(x) + "#, + ("result", vec![(-1i32,), (1,), (0,)]), + ); +} + +#[test] +fn ff_sign_2() { + expect_interpret_result( + r#" + rel my_rel = {-12, 34, 0} + rel result($sign(x)) = my_rel(x) + "#, + ("result", vec![(-1i32,), (1,), (0,)]), + ); +} + +#[test] +fn ff_format_1() { + expect_interpret_result( + r#" + rel strings = {"hello", "world!"} + rel result(x) = strings(a), strings(b), a != b, x == $format("{} {}", a, b) + "#, + ( + "result", + vec![("hello world!".to_string(),), ("world! 
hello".to_string(),)], + ), + ); +} + +#[test] +fn ff_format_2() { + expect_interpret_result( + r#" + rel numbers = {1.2, -3.4} + rel result(x) = numbers(a), numbers(b), a != b, x == $format("{} {}", a, b) + "#, + ("result", vec![("1.2 -3.4".to_string(),), ("-3.4 1.2".to_string(),)]), + ); +} + +#[test] +fn ff_string_lower_1() { + expect_interpret_result( + r#" + rel string = {"Hello World!", "aBcDeF1234."} + rel result($string_lower(s)) = string(s) + "#, + ( + "result", + vec![("hello world!".to_string(),), ("abcdef1234.".to_string(),)], + ), + ); +} + +#[test] +fn ff_string_upper_1() { + expect_interpret_result( + r#" + rel string = {"Hello World!", "aBcDeF1234."} + rel result($string_upper(s)) = string(s) + "#, + ( + "result", + vec![("HELLO WORLD!".to_string(),), ("ABCDEF1234.".to_string(),)], + ), + ); +} + +#[test] +fn ff_string_index_of_1() { + expect_interpret_result( + r#" + rel string = {"Scallop is cool!"} + rel substring = {"o", "is", "Scallop"} + rel result(i) = string(s), substring(t), i == $string_index_of(s, t) + "#, + ("result", vec![(5usize,), (8,), (0,)]), + ); +} + +#[test] +fn ff_string_trim_1() { + expect_interpret_result( + r#" + rel string = {"Hello World!", " \t ABC def 123 \n"} + rel result($string_trim(s)) = string(s) + "#, + ( + "result", + vec![("Hello World!".to_string(),), ("ABC def 123".to_string(),)], + ), + ); +} diff --git a/core/tests/integrate/fp.rs b/core/tests/integrate/fp.rs index 4bd3cf5..5c3b574 100644 --- a/core/tests/integrate/fp.rs +++ b/core/tests/integrate/fp.rs @@ -4,7 +4,7 @@ use scallop_core::testing::*; fn range_free_1() { expect_interpret_result( r#" - rel result(y) = range_usize(0, 5, y) + rel result(y) = range(0, 5, y) "#, ("result", vec![(0usize,), (1,), (2,), (3,), (4,)]), ); @@ -15,7 +15,7 @@ fn range_constraint_1() { expect_interpret_result( r#" rel base = {("A", "B", 3.0)} - rel result(a, b) = base(a, b, x) and soft_eq_f32(x, 3.0) + rel result(a, b) = base(a, b, x) and soft_eq(x, 3.0) "#, ("result", 
vec![("A".to_string(), "B".to_string())]), ); @@ -26,7 +26,7 @@ fn range_join_1() { expect_interpret_result( r#" rel base = {3} - rel result(y) = base(x) and range_usize(0, x, y) + rel result(y) = base(x) and range(0, x, y) "#, ("result", vec![(0usize,), (1,), (2,)]), ); @@ -37,7 +37,7 @@ fn range_join_2() { expect_interpret_result( r#" rel base = {3} - rel result() = base(x) and range_usize(0, x, 2) + rel result() = base(x) and range(0, x, 2) "#, ("result", vec![()]), ); @@ -48,7 +48,7 @@ fn range_join_3() { expect_interpret_empty_result( r#" rel base = {3} - rel result() = base(x) and range_usize(0, x, 100) + rel result() = base(x) and range(0, x, 100) "#, "result", ); @@ -59,7 +59,7 @@ fn range_join_4() { expect_interpret_result( r#" rel base = {3, 10} - rel result(x) = base(x) and range_usize(0, x, 5) + rel result(x) = base(x) and range(0, x, 5) "#, ("result", vec![(10usize,)]), ); @@ -80,12 +80,81 @@ fn string_chars_1() { fn floating_point_eq_1() { expect_interpret_multi_result( r#" - rel result_1() = float_eq_f32(3.000001, 1.000001 + 2.000001) + rel result_1() = float_eq(3.000001, 1.000001 + 2.000001) rel result_2() = 3.000001 == 1.000001 + 2.000001 "#, + vec![("result_1", vec![()].into()), ("result_2", TestCollection::empty())], + ) +} + +#[test] +fn string_split_1() { + expect_interpret_multi_result( + r#" + rel string = {"abcde ab cde abcde"} + rel pattern1 = {" "} + rel pattern2 = {"ab"} + rel pattern3 = {"abcde"} + rel result1(out) = string(s), pattern1(p), string_split(s, p, out) + rel result2(out) = string(s), pattern2(p), string_split(s, p, out) + rel result3(out) = string(s), pattern3(p), string_split(s, p, out) + "#, vec![ - ("result_1", vec![()].into()), - ("result_2", TestCollection::empty()) + ( + "result1", + vec![ + ("abcde".to_string(),), + ("ab".to_string(),), + ("cde".to_string(),), + ("abcde".to_string(),), + ] + .into(), + ), + ( + "result2", + vec![ + ("".to_string(),), + ("cde ".to_string(),), + (" cde ".to_string(),), + 
("cde".to_string(),), + ] + .into(), + ), + ( + "result3", + vec![("".to_string(),), (" ab cde ".to_string(),), ("".to_string(),)].into(), + ), ], - ) + ); +} + +#[test] +fn string_find_1() { + expect_interpret_multi_result( + r#" + rel string = {"abcde ab cde abcde"} + rel pattern1 = {" "} + rel pattern2 = {"ab"} + rel pattern3 = {"cde"} + rel result1(i, j) = string(s), pattern1(p), string_find(s, p, i, j) + rel result2(i, j) = string(s), pattern2(p), string_find(s, p, i, j) + rel result3(i, j) = string(s), pattern3(p), string_find(s, p, i, j) + "#, + vec![ + ("result1", vec![(5usize, 6usize), (8, 9), (12, 13)].into()), + ("result2", vec![(0usize, 2usize), (6, 8), (13, 15)].into()), + ("result3", vec![(2usize, 5usize), (9, 12), (15, 18)].into()), + ], + ); +} + +#[test] +fn datetime_ymd_1() { + expect_interpret_result( + r#" + rel datetime = {t"2023-04-17T00:00:00Z"} + rel result(y, m, d) = datetime(dt), datetime_ymd(dt, y, m, d) + "#, + ("result", vec![(2023i32, 4u32, 17u32)]), + ); } diff --git a/core/tests/integrate/incr.rs b/core/tests/integrate/incr.rs index b7969b9..9e7ff69 100644 --- a/core/tests/integrate/incr.rs +++ b/core/tests/integrate/incr.rs @@ -80,7 +80,9 @@ fn incr_edge_path_left_branching_1() { ); // Second branch, continuation - second_branch.add_rule(r#"result(x, y) = path(x, y) and x == 1"#).unwrap(); + second_branch + .add_rule(r#"result(x, y) = path(x, y) and x == 1"#) + .unwrap(); second_branch.run().unwrap(); expect_output_collection( "result", diff --git a/core/tests/integrate/io.rs b/core/tests/integrate/io.rs new file mode 100644 index 0000000..01870ca --- /dev/null +++ b/core/tests/integrate/io.rs @@ -0,0 +1,266 @@ +use scallop_core::common::value::*; +use scallop_core::runtime::provenance::*; +use scallop_core::testing::*; + +static CARGO_MANIFEST_DIR: &'static str = env!("CARGO_MANIFEST_DIR"); + +#[test] +fn io_edge() { + expect_interpret_result( + &format!( + r#" + @file("{}/res/testing/csv/edge.csv") + type edge(a: i32, b: i32) + 
rel path(a, b) = edge(a, b) \/ path(a, c) /\ edge(c, b) + query path + "#, + CARGO_MANIFEST_DIR, + ), + ("path", vec![(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]), + ); +} + +#[test] +fn io_edge_with_header() { + expect_interpret_result( + &format!( + r#" + @file("{}/res/testing/csv/edge_with_header.csv", header=true) + type edge(a: i32, b: i32) + rel path(a, b) = edge(a, b) \/ path(a, c) /\ edge(c, b) + query path + "#, + CARGO_MANIFEST_DIR, + ), + ("path", vec![(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]), + ); +} + +#[test] +fn io_edge_with_deliminator() { + expect_interpret_result( + &format!( + r#" + @file("{}/res/testing/csv/edge_with_deliminator.csv", deliminator="\t") + type edge(a: i32, b: i32) + rel path(a, b) = edge(a, b) \/ path(a, c) /\ edge(c, b) + query path + "#, + CARGO_MANIFEST_DIR, + ), + ("path", vec![(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]), + ); +} + +#[test] +fn io_edge_with_deliminator_and_header() { + expect_interpret_result( + &format!( + r#" + @file("{}/res/testing/csv/edge_with_deliminator_and_header.csv", deliminator="\t", header=true) + type edge(a: i32, b: i32) + rel path(a, b) = edge(a, b) \/ path(a, c) /\ edge(c, b) + query path + "#, + CARGO_MANIFEST_DIR, + ), + ("path", vec![(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]), + ); +} + +#[test] +fn io_edge_with_prob() { + let ctx = min_max_prob::MinMaxProbProvenance::new(); + expect_interpret_result_with_tag( + &format!( + r#" + @file("{}/res/testing/csv/edge_with_prob.csv", has_probability=true) + type edge(a: i32, b: i32) + "#, + CARGO_MANIFEST_DIR, + ), + ctx, + ("edge", vec![(0.01, (0, 1)), (0.5, (1, 2)), (0.91, (2, 3))]), + f64::eq, + ); +} + +#[test] +fn io_student() { + expect_interpret_result( + &format!( + r#" + @file("{}/res/testing/csv/student.csv", keys="id") + type student(id: String, field: Symbol, value: String) + "#, + CARGO_MANIFEST_DIR, + ), + ( + "student", + vec![ + ("1".to_string(), Value::symbol_str("name"), "alice".to_string()), + 
("1".to_string(), Value::symbol_str("year"), "2022".to_string()), + ("1".to_string(), Value::symbol_str("gender"), "female".to_string()), + ("2".to_string(), Value::symbol_str("name"), "bob".to_string()), + ("2".to_string(), Value::symbol_str("year"), "2023".to_string()), + ("2".to_string(), Value::symbol_str("gender"), "male".to_string()), + ], + ), + ); +} + +#[test] +fn io_student_with_fields() { + expect_interpret_result( + &format!( + r#" + @file("{}/res/testing/csv/student.csv", fields=["id", "name", "year"]) + type student(id: String, name: String, year: i32) + "#, + CARGO_MANIFEST_DIR, + ), + ( + "student", + vec![ + ("1".to_string(), "alice".to_string(), 2022i32), + ("2".to_string(), "bob".to_string(), 2023i32), + ], + ), + ); +} + +#[test] +fn io_enrollment() { + expect_interpret_result( + &format!( + r#" + @file("{}/res/testing/csv/enrollment.csv", keys=["student_id", "course_id"]) + type enrollment(student_id: String, course_id: String, field: Symbol, value: String) + "#, + CARGO_MANIFEST_DIR, + ), + ( + "enrollment", + vec![ + ( + "1".to_string(), + "cse100".to_string(), + Value::symbol_str("semester"), + "fa".to_string(), + ), + ( + "1".to_string(), + "cse100".to_string(), + Value::symbol_str("year"), + "2020".to_string(), + ), + ( + "1".to_string(), + "cse100".to_string(), + Value::symbol_str("grade"), + "a".to_string(), + ), + ( + "1".to_string(), + "cse102".to_string(), + Value::symbol_str("semester"), + "sp".to_string(), + ), + ( + "1".to_string(), + "cse102".to_string(), + Value::symbol_str("year"), + "2021".to_string(), + ), + ( + "1".to_string(), + "cse102".to_string(), + Value::symbol_str("grade"), + "a".to_string(), + ), + ( + "2".to_string(), + "cse100".to_string(), + Value::symbol_str("semester"), + "sp".to_string(), + ), + ( + "2".to_string(), + "cse100".to_string(), + Value::symbol_str("year"), + "2020".to_string(), + ), + ( + "2".to_string(), + "cse100".to_string(), + Value::symbol_str("grade"), + "b".to_string(), + ), + ], + ), + ); +} 
+ +#[test] +fn io_enrollment_with_keys_and_fields() { + expect_interpret_result( + &format!( + r#" + @file("{}/res/testing/csv/enrollment.csv", keys=["student_id", "course_id"], fields=["grade"]) + type enrollment(student_id: String, course_id: String, field: Symbol, value: String) + "#, + CARGO_MANIFEST_DIR, + ), + ( + "enrollment", + vec![ + ( + "1".to_string(), + "cse100".to_string(), + Value::symbol_str("grade"), + "a".to_string(), + ), + ( + "1".to_string(), + "cse102".to_string(), + Value::symbol_str("grade"), + "a".to_string(), + ), + ( + "2".to_string(), + "cse100".to_string(), + Value::symbol_str("grade"), + "b".to_string(), + ), + ], + ), + ); +} + +#[test] +fn io_enrollment_arity_error() { + expect_interpret_specific_failure( + &format!( + r#" + @file("{}/res/testing/csv/enrollment.csv", keys=["student_id", "course_id"], fields=["grade"]) + type enrollment(student_id: String, course_id: String) + "#, + CARGO_MANIFEST_DIR, + ), + |err| format!("{}", err).contains("IO: Arity mismatch"), + ); +} + +#[test] +fn io_enrollment_type_error() { + expect_interpret_specific_failure( + &format!( + r#" + @file("{}/res/testing/csv/enrollment.csv", keys=["student_id", "course_id"], fields=["grade"]) + type enrollment(student_id: String, course_id: String, field: String, value: String) + "#, + CARGO_MANIFEST_DIR, + ), + |err| format!("{}", err).contains("IO: Expect `Symbol` type for field; found `String`"), + ); +} diff --git a/core/tests/integrate/mod.rs b/core/tests/integrate/mod.rs index 17f35a7..94e91cb 100644 --- a/core/tests/integrate/mod.rs +++ b/core/tests/integrate/mod.rs @@ -1,3 +1,5 @@ +mod adt; +mod attr; mod basic; mod bug; mod dt; @@ -5,6 +7,7 @@ mod edb; mod ff; mod fp; mod incr; +mod io; mod iter; mod prob; mod time; diff --git a/core/tests/integrate/time.rs b/core/tests/integrate/time.rs index 0d2f955..0c6e5b6 100644 --- a/core/tests/integrate/time.rs +++ b/core/tests/integrate/time.rs @@ -28,7 +28,7 @@ fn date_1() { r#" rel r = 
{t"2019-01-01T00:00:00Z"} "#, - ("r", vec![(Utc.with_ymd_and_hms(2019, 01, 01, 0, 0, 0).unwrap(),)]) + ("r", vec![(Utc.with_ymd_and_hms(2019, 01, 01, 0, 0, 0).unwrap(),)]), ) } @@ -38,24 +38,22 @@ fn date_2() { r#" rel r = {(0, t"2019-01-01T00:00:00Z")} "#, - ("r", vec![(0, Utc.with_ymd_and_hms(2019, 01, 01, 0, 0, 0).unwrap())]) + ("r", vec![(0, Utc.with_ymd_and_hms(2019, 01, 01, 0, 0, 0).unwrap())]), ) } #[test] fn bad_date_1() { - expect_front_compile_failure( - r#"rel r = {t"ABCDEF"}"#, - |e| e.contains("Cannot parse date time `ABCDEF`") - ) + expect_front_compile_failure(r#"rel r = {t"ABCDEF"}"#, |e| { + e.contains("Cannot parse date time `ABCDEF`") + }) } #[test] fn bad_duration_1() { - expect_front_compile_failure( - r#"rel r = {d"ABCDEF"}"#, - |e| e.contains("Cannot parse duration `ABCDEF`") - ) + expect_front_compile_failure(r#"rel r = {d"ABCDEF"}"#, |e| { + e.contains("Cannot parse duration `ABCDEF`") + }) } #[test] @@ -66,7 +64,7 @@ fn date_plus_duration_1() { rel q = {d"3d"} rel r(date + duration) = p(date) and q(duration) "#, - ("r", vec![(Utc.with_ymd_and_hms(2019, 01, 04, 0, 0, 0).unwrap(),)]) + ("r", vec![(Utc.with_ymd_and_hms(2019, 01, 04, 0, 0, 0).unwrap(),)]), ) } @@ -78,7 +76,7 @@ fn date_minus_duration_1() { rel q = {d"3d"} rel r(date - duration) = p(date) and q(duration) "#, - ("r", vec![(Utc.with_ymd_and_hms(2019, 01, 01, 0, 0, 0).unwrap(),)]) + ("r", vec![(Utc.with_ymd_and_hms(2019, 01, 01, 0, 0, 0).unwrap(),)]), ) } @@ -89,7 +87,7 @@ fn duration_plus_duration_1() { rel p = {(d"3d", d"2d")} rel r(d1 + d2) = p(d1, d2) "#, - ("r", vec![(Duration::days(5),)]) + ("r", vec![(Duration::days(5),)]), ) } @@ -100,7 +98,7 @@ fn get_year_1() { rel p = {t"2019-01-04T00:00:00Z"} rel r($datetime_year(d)) = p(d) "#, - ("r", vec![(2019i32,)]) + ("r", vec![(2019i32,)]), ) } @@ -111,7 +109,7 @@ fn get_month_1() { rel p = {t"2019-01-04T00:00:00Z"} rel r($datetime_month(d)) = p(d) "#, - ("r", vec![(1u32,)]) + ("r", vec![(1u32,)]), ) } @@ -122,6 +120,6 @@ fn 
get_month0_1() { rel p = {t"2019-01-04T00:00:00Z"} rel r($datetime_month0(d)) = p(d) "#, - ("r", vec![(0u32,)]) + ("r", vec![(0u32,)]), ) } diff --git a/core/tests/runtime/dataflow/dyn_exclusion.rs b/core/tests/runtime/dataflow/dyn_exclusion.rs index 6fc12ed..b29ccb6 100644 --- a/core/tests/runtime/dataflow/dyn_exclusion.rs +++ b/core/tests/runtime/dataflow/dyn_exclusion.rs @@ -1,6 +1,6 @@ use scallop_core::common::expr::*; -use scallop_core::runtime::env::*; use scallop_core::runtime::dynamic::*; +use scallop_core::runtime::env::*; use scallop_core::runtime::provenance::*; use scallop_core::testing::*; use scallop_core::utils::*; @@ -24,7 +24,7 @@ fn test_dynamic_exclusion_1() { target.insert_dataflow_recent( &ctx, &dataflow::DynamicDataflow::from(&source) - .dynamic_exclusion(dataflow::DynamicDataflow::untagged_vec(&ctx, &exc), &ctx) + .dynamic_exclusion(dataflow::DynamicDataflow::untagged_vec(&ctx, exc.clone()), &ctx) .project((Expr::access((0, 0)), Expr::access((1, 0))).into()), &rt, ); @@ -32,10 +32,13 @@ fn test_dynamic_exclusion_1() { } // Inspect the result - expect_collection(&target.complete(&ctx), vec![ - (0, "red".to_string()), - (0, "blue".to_string()), - (1, "red".to_string()), - (1, "blue".to_string()), - ]); + expect_collection( + &target.complete(&ctx), + vec![ + (0, "red".to_string()), + (0, "blue".to_string()), + (1, "red".to_string()), + (1, "blue".to_string()), + ], + ); } diff --git a/core/tests/runtime/dataflow/dyn_foreign_predicate.rs b/core/tests/runtime/dataflow/dyn_foreign_predicate.rs index 8a86085..3e20e05 100644 --- a/core/tests/runtime/dataflow/dyn_foreign_predicate.rs +++ b/core/tests/runtime/dataflow/dyn_foreign_predicate.rs @@ -1,6 +1,6 @@ use scallop_core::common::expr::*; -use scallop_core::common::value::*; use scallop_core::common::tuple::*; +use scallop_core::common::value::*; use scallop_core::runtime::dynamic::dataflow::*; use scallop_core::runtime::dynamic::*; use scallop_core::runtime::env::*; @@ -11,7 +11,7 @@ fn 
test_dyn_dataflow_free_range() { let runtime = RuntimeEnvironment::new_std(); let ctx = unit::UnitProvenance::new(); let df = DynamicDataflow::foreign_predicate_ground( - "range_usize".to_string(), + "range#usize".to_string(), vec![Value::USize(1), Value::USize(5)], true, &ctx, @@ -43,7 +43,7 @@ fn test_dyn_dataflow_soft_lt_1() { DynamicElement::new((1.0, 5.0), 1.0), ]; let df = DynamicDataflow::vec(&source_df).foreign_predicate_constraint( - "soft_lt_f64".to_string(), + "soft_lt#f64".to_string(), vec![Expr::access(0), Expr::access(1)], &ctx, ); @@ -69,9 +69,22 @@ fn test_dyn_dataflow_join_range() { DynamicElement::new((10usize, 10usize), unit::Unit), DynamicElement::new((100usize, 101usize), unit::Unit), ]; - let df = DynamicDataflow::vec(&data).foreign_predicate_join("range_usize".to_string(), vec![Expr::access(0), Expr::access(1)], &ctx); + let df = DynamicDataflow::vec(&data).foreign_predicate_join( + "range#usize".to_string(), + vec![Expr::access(0), Expr::access(1)], + &ctx, + ); let batch = df.iter_recent(&runtime).next().unwrap().collect::>(); - assert_eq!(AsTuple::<((usize, usize), (usize,))>::as_tuple(&batch[0].tuple), ((1usize, 3usize), (1usize,))); - assert_eq!(AsTuple::<((usize, usize), (usize,))>::as_tuple(&batch[1].tuple), ((1usize, 3usize), (2usize,))); - assert_eq!(AsTuple::<((usize, usize), (usize,))>::as_tuple(&batch[2].tuple), ((100usize, 101usize), (100usize,))); + assert_eq!( + AsTuple::<((usize, usize), (usize,))>::as_tuple(&batch[0].tuple), + ((1usize, 3usize), (1usize,)) + ); + assert_eq!( + AsTuple::<((usize, usize), (usize,))>::as_tuple(&batch[1].tuple), + ((1usize, 3usize), (2usize,)) + ); + assert_eq!( + AsTuple::<((usize, usize), (usize,))>::as_tuple(&batch[2].tuple), + ((100usize, 101usize), (100usize,)) + ); } diff --git a/doc/js/hljs_scallop.js b/doc/js/hljs_scallop.js index 67ba819..fd9671e 100644 --- a/doc/js/hljs_scallop.js +++ b/doc/js/hljs_scallop.js @@ -2,20 +2,41 @@ hljs.registerLanguage("pen", (hljs) => ({ name: "Scallop", 
aliases: ["scl", "scallop"], keywords: { - keyword: "import type rel relation query if then else where", - type: "i8 i16 i32 i64 i128 isize u8 u16 u32 u64 u128 usize f32 f64 bool char String", + keyword: "import type const rel query if then else where case is in and or not implies new", + operator: "&& || ^ + - * / % /\\ \\/ => := = :-", + type: "i8 i16 i32 i64 i128 isize u8 u16 u32 u64 u128 usize f32 f64 bool char String DateTime Duration Entity Tensor", literal: "true false", - built_in: "count sum prod min max exists forall unique top", + built_in: "count sum prod min max exists forall unique top categorical uniform", }, contains: [ hljs.C_LINE_COMMENT_MODE, hljs.C_BLOCK_COMMENT_MODE, + { + className: 'string', + begin: "'", + end: "'" + }, { className: "string", variants: [ hljs.QUOTE_STRING_MODE, ] }, + { + className: "string", + begin: 's"', + end: '"' + }, + { + className: "string", + begin: 'd"', + end: '"' + }, + { + className: "string", + begin: 't"', + end: '"' + }, { className: "number", variants: [ diff --git a/doc/src/introduction.md b/doc/src/introduction.md index 765407d..eed3f13 100644 --- a/doc/src/introduction.md +++ b/doc/src/introduction.md @@ -1,3 +1,87 @@ +# Scallop, a Language for Neurosymbolic Programming + +

+ +
+ +Scallop is a language based on DataLog that supports differentiable logical and relational reasoning. +Scallop program can be easily integrated in Python and even with a PyTorch learning module. +You can also use it as another DataLog solver. +This book aims to give both high-level overview of the language usage and also low-level documentation on how each language feature is used. + +The following example shows how knowledge base facts, rules, and probabilistic facts recognized from images can operate together. + ``` scl -rel hello = {"world"} +// Knowledge base facts +rel is_a("giraffe", "mammal") +rel is_a("tiger", "mammal") +rel is_a("mammal", "animal") + +// Knowledge base rules +rel name(a, b) :- name(a, c), is_a(c, b) + +// Recognized from an image, maybe probabilistic +rel name = { + 0.3::(1, "giraffe"), + 0.7::(1, "tiger"), + 0.9::(2, "giraffe"), + 0.1::(2, "tiger"), +} + +// Count the animals +rel num_animals(n) :- n = count(o: name(o, "animal")) ``` + +## Table of Content + +Please refer to the side-bar for a detailed table of content. +At a high-level, we organize this book into the following 5 sections: + +### Installation and Crash Course + +[Installation](installation.md) gives instructions on how to install the Scallop on your machine. +[Crash Course](crash_course.md) gives a quick introduction to what the language is and how it is used. +Both sections are designed so that you can start quickly with Scallop. + +### Scallop and Logic Programming + +[Scallop and Logic Programming](language/index.md) aims to give you a detailed introduction on the language. +It introduces language features such as relational programming, negation and aggregation, queries, foreign constructs, and etc. +Reading through all of these you should be well-versed in Scallop's core functionality and you will be able to use Scallop as a Datalog engine. 
+ +``` scl +@demand("bf") +rel fib = {(0, 1), (1, 1)} +rel fib(x, y1 + y2) = fib(x - 1, y1) and fib(x - 2, y2) and x > 1 +query fib(10, y) +``` + +### Scallop and Probabilistic Programming + +[Scallop and Probabilistic Programming](probabilistic/index.md) introduces the probabilistic side of Scallop. +You will learn to tag facts with probabilities, its underlying algorithms and frameworks, and additional programming constructs for probabilistic semantics. +By the end of this section, you will be familiar with using Scallop as a probabilistic programming language. + +``` scl +rel attr = { 0.99::(OBJECT_A, "blue"), 0.01::(OBJECT_B, "red"), ... } +rel relate = { 0.01::(OBJECT_A, "holds", OBJECT_B), ... } +``` + +### Scallopy and Neurosymbolic Programming + +[Scallopy and Neurosymbolic Programming](scallopy/index.md) goes into the heart of Scallop to introduce applying Scallop to write Neurosymbolic applications. +Neurosymbolic methods are for methods that have both neural and logical components. +For this, we are going to use the Python binding of Scallop, `scallopy`, to integrate with machine learning libraries such as PyTorch. +This section will be describing the API of `scallopy`. + +``` py +sum_2 = scallopy.Module( + program="""type digit_1(i32), digit_2(i32) + rel sum_2(a + b) = digit_1(a) and digit_2(b)""", + input_mappings={"digit_1": range(10), "digit_2": range(10)}, + output_mapping=("sum_2", range(19))) +``` + +### For Developers + +[For Developers](developer/index.md) discusses how developers and researchers who are interested in extending Scallop can step into the source code of Scallop and program extensions. 
diff --git a/doc/src/language/adt_and_entity.md b/doc/src/language/adt_and_entity.md new file mode 100644 index 0000000..8905e13 --- /dev/null +++ b/doc/src/language/adt_and_entity.md @@ -0,0 +1,456 @@ +# Algebraic Data Type and Entities + +Algebraic data types are powerful programming constructs that allows user to define custom data structures and variants. +Consider a traditional functional definition of a `List`: + +``` scl +type IntList = Nil() + | Cons(i32, List) +``` + +We are saying that a `IntList` can be one of two variants, `Nil` and `Cons`: +- `Nil` denotes the end of a list; +- `Cons` contains the current `i32` integer and a continuation of the list. + +In this representation, we can represent a list like `[1, 2, 3]` with `Cons(1, Cons(2, Cons(3, Nil())))`. +This is indeed what we can write in Scallop. +We can declare such a list as a constant: + +``` scl +const MY_LIST = Cons(1, Cons(2, Cons(3, Nil()))) +``` + +In general, we call the type definition of such data structure *Algebraic Data Type* definitions, or *ADT* definitions. +The name *Entity* is used to refer to objects of such data types. +In the example above, the constant `MY_LIST` is an *entity* of the *ADT* named `IntList`. + +In this section, we describe in detail the definition and use of ADT and Entities. +We also touch on the internals. + +## Defining Algebraic Data Types (ADT) + +We use the following syntax to define ADTs: + +``` scl +type TYPE_NAME = VARIANT_NAME(ARG_TYPE_1, ARG_TYPE_2, ...) | ... +``` + +An ADT named `TYPE_NAME` is defined to have multiple (at least 2) named variants with `VARIANT_NAME`. +Each variant holds a tuple of values typed by `ARG_TYPE_1`, `ARG_TYPE_2`, etc. +We call variants that have no argument *terminal variant*s. +Parenthesis are still needed for those variants. + +Please note that there cannot be duplicated variant names, either within the same ADT or different ADTs. 
+For example, the following code would result in compilation failure: + +``` scl +type IntList = Cons(i32, IntList) | Nil() +type BoolList = Cons(bool, BoolList) | Nil() // Failure: Cons and Nil are already defined +``` + +Currently, ADTs do not support generics. +In the above case, the `IntList` and `BoolList` needs to be defined separately with differently named variants. + +### Using ADT to represent arithmetic expressions + +Common data that can be expressed through ADT could be structured expressions. +The following definition describes the abstract syntax tree (AST) of simple arithmetic expressions: + +``` scl +type Expr = Int(i32) // An expression could be a simple integer, + | Add(Expr, Expr) // a summation of two expressions + | Sub(Expr, Expr) // a substraction of two expressions +``` + +The following code encodes a simple expression + +``` scl +// The expression (1 + 3) - 5 +const MY_EXPR = Sub(Add(Int(1), Int(3)), Int(5)) +``` + +### Using ADT to represent data structures + +Data structures such as binary trees can also be represented: + +``` scl +type Tree = Node(i32, Tree, Tree) | Nil() +``` + +Here, `Node(i32, Tree, Tree)` represents a node in a tree holding three things: +an integer (`i32`), a left sub-tree `Tree`, and a right sub-tree `Tree`. +The other variant `Nil` represents an empty sub-tree. +In this encoding, `Node(5, Nil(), Nil())` would be representing a leaf-node holding a number 5. + +The following code encodes a balanced binary search tree: + +``` scl +// 3 +// / \ +// 1 5 +// / \ / \ +// 0 2 4 6 +const MY_TREE = + Node(3, + Node(1, + Node(0, Nil(), Nil()), + Node(2, Nil(), Nil()), + ), + Node(5, + Node(4, Nil(), Nil()), + Node(6, Nil(), Nil()), + ) + ) +``` + +## Working with Entities + +Entities are most commonly created as constants using the `const` keyword. +Let us revisit the `List` example and see how we can use the defined constant in our analysis. 
+ +``` scl +type List = Cons(i32, List) | Nil() + +const MY_LIST = Cons(1, Cons(2, Cons(3, Nil()))) // [1, 2, 3] +``` + +### Using Entities in Relations + +We can include the constant entities as part of a fact: + +``` scl +rel target(MY_LIST) +query target +``` + +As a result of the above program, we are going to get the value of the entity `MY_LIST`: + +``` +target: {(entity(0xff08d5d60a201f17))} +``` + +The value is going to be a 64-bit integer encoded in hex. +It is a unique identifier for the created entity. + +Note that, identical entities are going to have the same identifier. +In the following example, `MY_LIST_1` and `MY_LIST_2` are identical, and therefore their hex identifier are the same. + +``` scl +const MY_LIST_1 = Cons(1, Nil()), + MY_LIST_2 = Cons(1, Nil()), + MY_LIST_3 = Cons(2, Nil()) + +rel lists = { + (1, MY_LIST_1), + (2, MY_LIST_2), + (3, MY_LIST_3), +} + +query lists +// lists: { +// (1, entity(0x678defa0a65c83ab)), // Notice that the entity 1 and 2 are the same +// (2, entity(0x678defa0a65c83ab)), +// (3, entity(0x3734567c3d9f8d3f)), // This one is different than above +// } +``` + +### Decomposing Entities in Rules + +To peek into the content of an Entity, we can *destruct* it using the `case`-`is` operator. +We look at an example of computing the length of a list: + +``` scl +type length(list: List, len: i32) +rel length(list, 0) = case list is Nil() +rel length(list, l + 1) = case list is Cons(_, tl) and length(tl, l) +``` + +We define a recursive relation `length` to compute the length of a list. +There are two cases. +When the list is `Nil()`, this means the list has already ended. +Therefore the list has a length of `0` +For the second case, the list is `Cons(_, tl)`. +Here, the length of list is the length of `tl` plus 1. + +We can then compute the length of a list by `query`ing the `length` relationship on a constant list. 
+ +``` scl +query length(MY_LIST, l) // l = 3 +``` + +### Case Study: Decomposing Entities for Pretty-Printing + +We can look at more examples of using the `case`-`is` operators. +The following set of rules pretty-prints expressions: + +``` scl +type Expr = Int(i32) | Add(Expr, Expr) | Sub(Expr, Expr) + +type to_string(expr: Expr, str: String) +rel to_string(e, $format("{}", i)) = case e is Int(i) +rel to_string(e, $format("({} + {})", a, b)) = case e is Add(e1, e2) and to_string(e1, a) and to_string(e2, b) +rel to_string(e, $format("({} - {})", a, b)) = case e is Sub(e1, e2) and to_string(e1, a) and to_string(e2, b) +``` + +Shown in the example, we have written three `to_string` rules for pretty-printing the `Expr` data structure. +Each rule correspond to handling exactly one of the variants. +For the inductive cases `Add` and `Sub`, we have the `to_string` rule defined recursively so that the sub-expressions are also converted to strings. +For pretty-printing, we have used the `$format` foreign function. + +At the end, running the following snippet + +``` scl +const MY_EXPR = Sub(Add(Int(3), Int(5)), Int(1)) +query to_string(MY_EXPR, s) +``` + +would give the following result, suggesting that the pretty-printed expression is `((3 + 5) - 1)` + +``` +to_string(MY_EXPR, s): {(entity(0xa97605c2703c6249), "((3 + 5) - 1)")} +``` + +### Case Study: Checking Regular Expressions + +With ADT, we can specify the language of regular expressions (regex) with ease. +Let's consider a very simple regex with union (`|`) and star (`*`), while phrases can be grouped together. +For example, the regex `"a*b"` expresses that character `a` can be repeated arbitrary amount of time (including 0-times), followed by a single `b`. +This regex can be used to match strings like `"aaaab"` and `"b"`, but not `"ba"`. + +Let's try to define this regex language in Scallop! 
+ +``` scl +type Regex = Char(char) // a single character + | Star(Regex) // the star of a regex + | Union(Regex, Regex) // a union of two regexes + | Concat(Regex, Regex) // concatenation of two regexes +``` + +As can be seen, we have defined 4 variants of this regex language. +With this, our regex `"a*b"` can be expressed as follows: + +``` scl +// a*b +const A_STAR_B = Concat(Star(Char('a')), Char('b')) +``` + +Now, let's define the actual semantics of this regex language and write a relation `matches` to check if the regex matches with a given sub-string. +We first setup the types of such relations. +- `input_regex` is a unary-relation for holding the regex to be checked against; +- `input_string` is a unary-relation for holding the string to be checked against; +- `matches_substr` is for checking if a sub-regex `r` can be matched with the input string between `begin` and `end` indices, where `end` is exclusive; +- `matches` is a boolean relation telling whether the `A_STAR_B` regex matches with the input string or not. + +``` scl +type input_regex(r: Regex) +type input_string(s: String) +type matches_substr(r: Regex, begin: usize, end: usize) +type matches() +``` + +The main bulk of the code will then be dedicated to define the `matches_substr` relation. +At a high level, we decompose on each type of regex, and match on sub-strings. +The first rule that we are going to write would be for the `Char` variant. + +``` scl +rel matches_substr(r, i, i + 1) = case r is Char(c) and input_string(s) and string_chars(s, i, c) +``` + +The rule suggests that if the regex `r` is a single character `c`, then we go into the input string `s` and find all the index `i` such that its corresponding character is `c`. +The matched sub-string would start at index `i` and end at index `i + 1`. +Note that the `string_chars` relation is a foreign predicate that decomposes the string into characters. 
+ +Similarly, we can write the rules for other variants: + +``` scl +// For star; it matches empty sub-strings [i, i) and recursively on sub-regex +rel matches_substr(r, i, i) = case r is Star(_) and input_string(s) and string_chars(s, i, _) +rel matches_substr(r, b, e) = case r is Star(r1) and matches_substr(r, b, c) and matches_substr(r1, c, e) + +// For union; any string that matches left or right sub-regex would match the union +rel matches_substr(r, b, e) = case r is Union(r1, r2) and matches_substr(r1, b, e) +rel matches_substr(r, b, e) = case r is Union(r1, r2) and matches_substr(r2, b, e) + +// For concat; we need strings to match in a consecutive matter +rel matches_substr(r, b, e) = case r is Concat(r1, r2) and matches_substr(r1, b, c) and matches_substr(r2, c, e) +``` + +Lastly, we add the rule to derive the final `matches` relation. +Basically, it checks if the regex matches the start-to-end of the input string + +``` scl +rel matches() = input_regex(r) and input_string(s) and matches_substr(r, 0, $string_length(s)) +``` + +Let us test the result! + +``` scl +rel input_regex(A_STAR_B) +rel input_string("aaaab") +query matches // {()} +``` + +## Dynamically Creating Entities + +There are cases where we want to create new entities during the deductive process. +This is done through the `new` keyword followed by the entity to create. +Suppose we have the definition of `List` and some pretty-printing code for it: + +``` scl +type List = Cons(i32, List) | Nil() + +rel to_string_2(l, "]") = case l is Nil() +rel to_string_2(l, $format("{}]", i)) = case l is Cons(i, Nil()) +rel to_string_2(l, $format("{}, {}", i, ts)) = case l is Cons(i, tl) and case tl is Cons(_, _) and to_string_2(tl, ts) +rel to_string(l, $format("[{}", tl)) = to_string_2(l, tl) +``` + +The following example shows that, given an input list `l`, we generate a result list `Cons(1, l)`. 
+
+``` scl
+type input_list(List)
+rel result_list(new Cons(1, l)) = input_list(l)
+```
+
+Given an actual list defined as a constant, we will be able to specify that the constant is the input list:
+
+``` scl
+const MY_INPUT_LIST = Cons(2, Cons(3, Nil()))
+rel input_list(MY_INPUT_LIST)
+```
+
+Now, let's visualize the results!
+
+``` scl
+rel input_list_str(s) = to_string(MY_INPUT_LIST, s)
+rel result_list_str(s) = result_list(l) and to_string(l, s)
+
+query input_list_str // [2, 3]
+query result_list_str // [1, 2, 3]
+```
+
+As can be seen, through the `new` operator, we have essentially created a new list containing the element `1`.
+We note that the rule for `result_list` is *not* recursive.
+In general, extra care needs to be taken to ensure that the program does not go into an infinite loop.
+
+### Case Study: Creating Entities for Equality Saturation
+
+In this case study we look at the problem of equality saturation.
+Given a symbolic expression, there might be ways to simplify it, which are defined through *rewrite rules*.
+Notice that after simplification, the program should be equivalent to the input.
+The problem is challenging as there might be multiple ways to apply the rewrite rules.
+How do we then systematically derive the simplest equivalent program?
+ +A simple example here is the symbolic arithmetic expression language, with constant, variables, and summation rule: + +``` scl +type Expr = Const(i32) | Var(String) | Add(Expr, Expr) +``` + +One example expression that we can express in this language would be + +``` scl +const MY_EXPR = Add(Add(Const(-3), Var("a")), Const(3)) // (-3 + a) + 3 +``` + +For visualization, we write a `to_string` function + +``` scl +rel to_string(p, i as String) = case p is Const(i) +rel to_string(p, v) = case p is Var(v) +rel to_string(p, $format("({} + {})", s1, s2)) = + case p is Add(p1, p2) and to_string(p1, s1) and to_string(p2, s2) +``` + +If we query on `to_string` for `MY_EXPR`, we would get + +``` scl +query to_string(MY_EXPR, s) // s = "((-3 + a) + 3)" +``` + +Now let us deal with the actual simplification. +The expression `(-3 + a) + 3` could be simplified to just `a`, as the `-3` and `3` cancels out. +The way to do the simplification is to write two things: + +1. rewrite rules in the form of equivalence relation; +2. the weight function giving each expression a weight to tell which expression is *simpler*. + +For this, the following set of relations needs to be defined. + +``` scl +type input_expr(expr: Expr) +type equivalent(expr_1: Expr, expr_2: Expr) +type weight(expr: Expr, w: i32) +type simplest(expr: Expr) +``` + +Note that we need set a prior knowledge on `equivalent`: the `expr_1` is always *more complex* than the `expr_2`. +This is to prevent the simplification to go to arbitrary direction and result in infinite-loop. +In such case, `equivalent` would not be commutative. +Let us start with `equivalent` and define its basic property of identity and transitivity: + +``` scl +// Identity +rel equivalent(e, e) = case e is Const(_) or case e is Var(_) or case e is Add(_, _) + +// Transitivity +rel equivalent(e1, e3) = equivalent(e1, e2) and equivalent(e2, e3) +``` + +Now, we can write the rewrite rules. 
+The first one we are going to write states that, if `e1` and `e1p` are equivalent and `e2` and `e2p` are equivalent, +their additions (`Add(e1, e2)` and `Add(e1p, e2p)`) are equivalent too. + +``` scl +// e1 == e1p, e2 == e2p ==> (e1 + e2) == (e1p + e2p) +rel equivalent(e, new Add(e1p, e2p)) = case e is Add(e1, e2) and equivalent(e1, e1p) and equivalent(e2, e2p) +``` + +The next rule states that Addition is commutative, such that `Add(a, b)` is equivalent to `Add(b, a)`: + +``` scl +// (a + b) == (b + a) +rel equivalent(e, new Add(b, a)) = case e is Add(a, b) +``` + +We also have a rule for associativity: + +``` scl +// (a + (b + c)) == ((a + b) + c) +rel equivalent(e, new Add(new Add(a, b), c)) = case e is Add(a, Add(b, c)) +``` + +A rule for simplifying adding summation identity 0: + +``` scl +// a + 0 = a +rel equivalent(e, a) = case e is Add(a, Const(0)) +``` + +A rule for reducing two constants addition: + +``` scl +rel equivalent(e, Const(a + b)) = case e is Add(Const(a), Const(b)) +``` + +Now we have 5 rewrite-rules in place, let us define how to compute the weight of each expression. +The leaf nodes (`Var` and `Const`) have weight of `1`, and the addition have the weight from left and right sub-expr added together plus 1. + +``` scl +rel weight(e, 1) = case e is Var(_) or case e is Const(_) +rel weight(e, l + r + 1) = case e is Add(a, b) and weight(a, l) and weight(b, r) +``` + +Lastly, we use the aggregation to find the equivalent programs with the minimum weight, which is our definition of the "simplest" program. +Note that we have used an `argmax` aggregation denoted by `min[p]` here: + +``` scl +rel best_program(p) = _ := min[p](w: input_expr(e) and equivalent(e, p) and weight(p, w)) +``` + +If we query for the best program and turn it into string, we will get our expected output, a single variable `"a"`! 
+
+``` scl
+rel best_program_str(s) = best_program(p) and to_string(p, s)
+query best_program_str // {("a")}
+```
diff --git a/doc/src/language/aggregation.md b/doc/src/language/aggregation.md
index 758eaf1..53cbe2e 100644
--- a/doc/src/language/aggregation.md
+++ b/doc/src/language/aggregation.md
@@ -1 +1,244 @@
-# Rules with Aggregations
+# Aggregations
+
+Aggregations in Scallop can be viewed as operations that aggregate over multiple facts.
+Such operations include counting, summation and product, finding min and max, and logical quantifiers such as exists and forall.
+Aggregations appear in the body of a rule, and can be nested for brevity.
+
+As a concrete example, we look at a program which counts over a set of people:
+
+``` scl
+rel person = {"alice", "bob", "christine"}
+rel num_people(n) = n := count(p: person(p)) // n = 3
+```
+
+
+In general, we use the following syntax for aggregation formulas.
+
+``` scl
+R1, R2, ... := AGGREGATOR(V1, V2, ...: FORMULA (where U1, U2, ...: FORMULA)?)
+```
+
+We name `R1, ...` to be the aggregation *result* variable, `V1, ...` to be the *binding* variable, and the formula inside of the aggregation the *body*.
+When the `where` keyword is used, we have the aggregation associated with an *explicit group-by* clause.
+Here, we call the set of variables `U1, ...` as *group-by variables*.
+The formula under the `where` clause is named the *group-by body*.
+The binding variables need to be fully grounded by the body formula, and the group-by variables (if present) need to also be fully grounded by the group-by body.
+For different types of aggregation, the `AGGREGATOR` might also change and be annotated with different information.
+The number of result variables, the number of binding variables, and their types differ for each aggregation.
+
+Here is a high-level overview of each supported aggregator and their configurations.
+In the table, `...` is used to denote an arbitrary amount of variables.
+
+| Aggregator | Binding Variables | Result Variables |
+|------------|-------------------|------------------|
+| `count` | `Any...` | `usize` |
+| `sum` | `Number` | the same as the binding variable |
+| `prod` | `Number` | the same as the binding variable |
+| `min` | `Any` | the same as the binding variables |
+| `max` | `Any` | the same as the binding variables |
+| `exists` | `Any...` | `bool` |
+| `forall` | `Any...` | `bool` |
+
+Below, we elaborate on each aggregator and describe its usage.
+
+## Count
+
+To count the number of facts, we can use the `count` aggregator.
+Just repeating the examples shown in the beginning:
+
+``` scl
+rel person = {"alice", "bob", "christine"}
+rel num_people(n) = n := count(p: person(p)) // n = 3
+```
+
+We are counting the number of persons appearing in the `person` relation.
+To be more concrete, let's read out the aggregation formula:
+
+> We count the number of `p` such that `p` is a `person`, and assign the result to the variable `n`.
+
+For `count`, there could be an arbitrary (> 0) number of binding variables which can be typed arbitrarily.
+It will only have a single result variable which is typed `usize`.
+For example, you may count the number of `edge`s:
+
+``` scl
+rel num_edges(n) = n := count(a, b: edge(a, b))
+```
+
+Here, we have two binding variables `a` and `b`, meaning that we are counting the number of *distinct* pairs of `a` and `b`.
+
+### Implicit Group-By
+
+With `group-by`, we may count the number of facts under a pre-defined group.
+Consider the example where there is a scene with different colored objects,
+
+``` scl
+rel obj_color = {(0, "red"), (1, "red"), (2, "blue"), (3, "red")}
+rel num_obj_per_color(col, num) = num := count(obj: obj_color(obj, col))
+```
+
+As suggested by the facts inside of `obj_color`, there are `4` objects indexed using `0, 1, 2, 3`, each associated with a different color.
+The object #0, #1, and #3 are `red` and the object #2 is `blue`.
+Therefore, we will get 3 red objects and 1 blue object, as computed in the result of `num_obj_per_color`: + +``` +num_obj_per_color: {("blue", 1), ("red", 3)} +``` + +Let's analyze the rule in detail. +We find that we are counting over `obj` such that the object `obj` has a certain color `col`. +But `col` is also a variable occurring in the head of the rule. +This is an *implicit group-by*, in that the variable `col` is being used as an implicit group-by variable. +That is, we are conditioning the counting procedure under each *group* that is defined by the `col` variable. +Since there are two colors appearing in the `obj_color` relation, we are performing count for each of the two groups. + +In general, if a variable is positively grounded in the body and appear in the head of a parent rule, we call the variable an *implicit group-by variable*. + +### Explicit Group-By + +In the above example, there is no green colored object. +However, how do we know that the number of green object is 0? +The result does not seem to address this problem. + +The missing piece is a *domain* of the possible groups. +Without explicitly setting the domain, Scallop could only search inside of the database on possible groups. +However, we can explicitly tell Scallop about what are the groups. +Consider the following rewrite of the above program: + +``` scl +rel colors = {"red", "green", "blue"} +rel obj_color = {(0, "red"), (1, "red"), (2, "blue"), (3, "red")} +rel num_obj_per_color(col, num) = num := count(obj: obj_color(obj, col) where col: colors(col)) +``` + +With the `where` clause, we have explicitly declared that `col` is a *group-by variable* which is grounded by the `colors` relation. +If we look into the `colors` relation, we find that there are three possible colors that we care about, red, green, and blue. 
+In this case, we will consider `"green"` as the third group and try to count the number of green objects -- of which there are 0:
+
+```
+num_obj_per_color: {("blue", 1), ("green", 0), ("red", 3)}
+```
+
+## Sum and Product
+
+We can use the aggregator of sum and product to aggregate multiple numerical values.
+Consider the following example of sales:
+
+``` scl
+rel sales = {("alice", 1000.0), ("bob", 1200.0), ("christine", 1000.0)}
+```
+
+We can compute the sum of all the sales:
+
+``` scl
+rel total_sales(s) = s := sum(sp: sales(p, sp)) // 3200.0
+```
+
+Notice that the result type of `s` is the same as the type of the binding variable `sp`, which is `f32` as indicated by the decimals in the definition of `sales`.
+
+The product aggregator `prod` can be used in a similar manner as `sum`.
+
+## Min, Max, Argmin, and Argmax
+
+Scallop can compute the minimum or maximum among a set of values.
+In the following example, we find the maximum grade of an exam:
+
+``` scl
+rel exam_grades = {("a", 95.2), ("b", 87.3), ("c", 99.9)}
+rel max_score(m) = m := max(s: exam_grades(_, s)) // 99.9
+```
+
+The number (and types) of binding variables can be arbitrary, but the result variables must match the binding variables.
+In the above case, since `s` is of type `f32`, `m` will be of type `f32` as well.
+
+It is also possible to get argmax/argmin.
+Suppose we want to get the person (along with their grade) who scored the best, we write:
+
+``` scl
+rel best_student(n, s) = s := max[n](s: exam_grades(n, s))
+```
+
+Here, we are still finding the maximum score `s`, but along with `max` we have specified the "arg" (`[n]`) which associates with the maximum score.
+We call `n` an arg variable for the `min`/`max` aggregator.
+The arg variable is grounded by the aggregation body, and can be directly used in the head of the rule.
+ +If we do not care about the grade and just want to know who has the best grade, we can use wildcard `_` to ignore the result variable, like + +``` +rel best_student(n) = _ := max[n](s: exam_grades(n, s)) +``` + +## Exists and Forall + +Logical quantifier such as exists and forall can also be encoded as aggregations. +They will return value of boolean as the aggregation result. + +### Existential Quantifier + +Let us start with discussing the easier of the two, `exists`. +Technically, all variables in the body of Scallop rule are existentially quantified. +We can use `exists` aggregation to make it explicit. +For example, we can check if there exists an object that is blue: + +``` scl +rel obj_color = {(0, "red"), (1, "green")} +rel has_blue(b) = b := exists(o: obj_color(o, "blue")) +``` + +Specifically, we are checking "if there exists an object `o` such that its color is `blue`". +The result is being assigned to a variable `b`. +Since there is no blue object, we will get a result of `has_blue(false)`. + +In case when we just want the result boolean to be `true` or `false`, we can omit the result variables. +For example, we can rewrite the recursive case of edge-path transitive closure as + +``` scl +rel path(a, c) = exists(b: path(a, b) and edge(b, c)) +``` + +We note that this is just a syntax sugar equivalent to the following: + +``` scl +rel path(a, c) = r := exists(b: path(a, b) and edge(b, c)) and r == true +``` + +When we want to know the inexistence of something, we can do + +``` scl +rel no_red() = not exists(o: obj_color(o, "red")) +``` + +Note that there can be arbitrary amount of binding variables. + +### Universal Quantifier + +We can also have universal quantifier `forall`. +For this, there is a special requirement for universal quantification, that the body formula has to be an `implies` formula. 
+In the following example, we check if all the objects are spherical: + +``` scl +type Shape = CUBE | SPHERE | CONE | CYLINDER +rel object = {0, 1, 2} +rel obj_shape = {(0, CUBE), (1, SPHERE), (2, SPHERE)} +rel target(b) = b := forall(o: object(o) implies obj_shape(o, SPHERE)) +``` + +Notice that we have a relation which defines the domain of `object`, suggesting that there are just 3 objects for us to work with. +In the aggregation, we are checking "for all `o` such that `o` is an object, is the object a sphere?" +The result is stored in the variable `b` and propagated to the `target` relation. + +The reason we need to have an *implies* formula is that we need to use the left-hand-side of `implies` to give bounds to the universally quantified variables. +Scallop cannot reason about open domain variables. + +Note that similar to `exists`, we can also remove the result variable. +The following program derives a boolean (arity-0) relation `target` denoting whether all the red objects are cubes: + +``` scl +type Shape = CUBE | SPHERE | CONE | CYLINDER +type Color = RED | GREEN | BLUE +rel obj_shape = {(0, CUBE), (1, SPHERE), (2, SPHERE)} +rel obj_color = {(0, RED), (1, GREEN), (2, GREEN)} +rel target() = forall(o: obj_color(o, RED) implies obj_shape(o, CUBE)) // {()} +``` + +Here, we directly use `obj_color` to serve as the left-hand-side of the `implies`. +There will be one empty tuple being derived, suggesting that the statement is true. diff --git a/doc/src/language/constants.md b/doc/src/language/constants.md new file mode 100644 index 0000000..44f4ab5 --- /dev/null +++ b/doc/src/language/constants.md @@ -0,0 +1,75 @@ +# Declaring Constants + +We can declare constants and give it names. +The general syntax is the following: + +``` scl +const NAME (: TYPE)? = CONSTANT +``` + +For example, we can define the value of `PI`: + +``` scl +const PI = 3.1415926 +``` + +Notice that here we have not specified the type of `PI`. 
+By default, a float value would resort to the place where the constant is used. +If we want to specify a non-default type, we can do + +``` scl +const PI: f64 = 3.1415926 +``` + +We can also declare multiple constants at a time: + +``` scl +const LEFT = 0, UP = 1, RIGHT = 2, DOWN = 3 +``` + +## Enum Types + +We sometimes want to define enum types which contain constant variables. +Common examples include `RED`, `GREEN`, and `BLUE` under the `Color` type, and `LEFT`, `RIGHT`, `UP` under the `Action` type. +These can be achieved by defining enum types: + +``` scl +type Color = RED | GREEN | BLUE +type Action = LEFT | UP | RIGHT | DOWN +``` + +Internally, the values such as `RED` and `UP` are unsigned integer constants. +If not specified, the values start from 0 and goes up 1 at a time. + +For example, given the type definition above, `RED = 0`, `GREEN = 1`, and `BLUE = 2`. +For `Action`s, `LEFT = 0`, `UP = 1`, and etc. +Notice that even when `Color` and `Action` are different types, their values can overlap. + +One can specify the values of these enum variants by attaching actual numbers to them. +In the following example, we have explicitly assigned three values to the colors. + +``` scl +type Color = RED = 3 | GREEN = 5 | BLUE = 7 +``` + +We can also just set a few of those: + +``` scl +type Color = RED | GREEN = 10 | BLUE +``` + +In this case, `RED = 0`, `GREEN = 10`, and `BLUE = 11`. +Notice how blue's value is incremented from `GREEN`. + +## Displaying Constants + +Constants are just values and many of them are integer values. +They are not explicitly associated with any symbols. +If you want to display them correctly, we advise you create auxilliary relations storing the mapping from each constant to its string form. +For example, we can have + +``` scl +rel color_to_string = {(RED, "red"), (GREEN, "green"), (BLUE, "blue")} +``` + +In this case, following the result with `color_to_string` relation will display their desired meanings properly. 
diff --git a/doc/src/language/custom_type.md b/doc/src/language/custom_type.md new file mode 100644 index 0000000..e69de29 diff --git a/doc/src/language/disj_conj_head.md b/doc/src/language/disj_conj_head.md new file mode 100644 index 0000000..753af93 --- /dev/null +++ b/doc/src/language/disj_conj_head.md @@ -0,0 +1 @@ +# Disjunctive and Conjunctive Head diff --git a/doc/src/language/foreign_functions.md b/doc/src/language/foreign_functions.md index 92f76f9..da41d7c 100644 --- a/doc/src/language/foreign_functions.md +++ b/doc/src/language/foreign_functions.md @@ -1 +1,101 @@ # Foreign Functions + +Foreign functions allows for complex value manipulation in Scallop. +Conceptually, they are pure and partial functions that operate on value(s) and return one single value only. +Functions with states, such as `random`, are not allowed as foreign functions. + +## Function Types + +In Scallop, foreign functions are generically typed with optional and variable arguments. +All the functions have a dollar sign (`$`) associated with the function name. +We use the following syntax to denote a function signature + +``` +$FUNC_NAME( + POS_ARG: POS_ARG_TYPE, ..., + OPT_ARG: OPT_ARG_TYPE?, ..., + VAR_ARG_TYPE... +) -> RETURN_TYPE +``` + +The generic arguments are specified in the `<...>` after the function name, and can be annotated by optional type family. +For the arguments of the function, optional arguments have to appear after all positional arguments, and the variable arg type must appear after all positional and optional arguments. +Functions must have a return type. + +For example, the function `$string_char_at(s: String, i: usize) -> char` takes in a string `s` and an index `i`, and returns the character at that location. +The two arguments `s` and `i` are both positional arguments. + +In the function `$substring(s: String, b: usize, e: usize?)`, we have 2 positional arguments (`s` and `b`) and 1 optional argument (`e`). 
+This means that this substring function can be invoked with 2 or 3 arguments.
+Invoking `$substring("hello", 3)` would give us `"lo"`, and invoking `$substring("hello", 1, 3)` would give us `"el"`.
+
+For functions like `$abs(T) -> T`, we have an absolute value function taking in a value of any type that is a number (including integers and floating points).
+The function also returns a type the same as the input.
+
+For a function like `$format(f: String, Any...)`, it looks at the format string and fills all the `{}` symbols in the string with the latter arguments.
+Notice how there can be an arbitrary number of arguments (variable arg) of `Any` type.
+For example, we can have `$format("{} + {}", 3, "a") ==> "3 + a"` and `$format("{}", true) ==> "true"`.
+
+## Function Failures
+
+Foreign functions may fail with errors such as divide-by-zero and index-out-of-bounds.
+When an error happens, values will not be propagated along the computation, and will be dropped silently.
+
+For example, the following program makes use of the foreign function `$string_char_at`.
+It walks through the indices 1, 3, and 5, and gets the character at those indices of the string `"hello"`.
+
+``` scl
+rel indices = {1, 3, 5}
+rel output(i, $string_char_at("hello", i)) = indices(i)
+```
+
+However, there are only 5 characters in the string, meaning that getting the `5`-th character would result in an index-out-of-bounds error.
+Scallop will drop this invocation silently, resulting in only two facts being derived:
+
+```
+output: {(1, 'e'), (3, 'l')}
+```
+
+Similar things happen when `nan` is derived from floating point operations, or when the foreign function otherwise fails.
+
+## Library of Foreign Functions
+
+We hereby list all the available foreign functions, their signatures, descriptions, and an example of how to invoke them.
+The functions here are ordered alphabetically.
+For some of the functions that are slightly more complicated (e.g. `$format`), please refer to the section below for more information.
+
+| Function | Description | Example |
+|:---------|----------------------------------|-----------------|
+| `$abs(x: T) -> T` | Absolute value function \\(\lvert x \rvert\\) | `$abs(-1)` => `1` |
+| `$acos(x: T) -> T` | Arc cosine function \\(\text{acos}(x)\\) | `$acos(0.0)` => `1.5708` |
+| `$atan(x: T) -> T` | Arc tangent function \\(\text{atan}(x)\\) | `$atan(0.0)` => `0.0` |
+| `$atan2(y: T, x: T) -> T` | 2-argument arc tangent function \\( \text{atan}(y, x) \\) | `$atan2(0.0, 1.0)` => `0.0` |
+| `$ceil(x: T) -> T` | Round *up* to closest integer \\( \lceil x \rceil \\) | `$ceil(0.5)` => `1.0` |
+| `$cos(x: T) -> T` | Cosine function \\(\text{cos}(x)\\) | `$cos(0.0)` => `1.0` |
+| `$datetime_day(d: DateTime) -> u32` | Get the day component of a `DateTime`, starting from 1 | `$datetime_day(t"2023-01-01")` => `1` |
+| `$datetime_month(d: DateTime) -> u32` | Get the month component of a `DateTime`, starting from 1 | `$datetime_month(t"2023-01-01")` => `1` |
+| `$datetime_month0(d: DateTime) -> u32` | Get the month component of a `DateTime`, starting from 0 | `$datetime_month0(t"2023-01-01")` => `0` |
+| `$datetime_year(d: DateTime) -> i32` | Get the year component of a `DateTime` | `$datetime_year(t"2023-01-01")` => `2023` |
+| `$dot(a: Tensor, b: Tensor) -> Tensor` | Dot product of two tensors \\(a \cdot b\\); only available when compiled with `torch-tensor` | |
+| `$exp(x: T) -> T` | Exponential function \\(e^x\\) | `$exp(0.0)` => `1.0` |
+| `$exp2(x: T) -> T` | Exponential function \\(2^x\\) (base 2) | `$exp2(2.0)` => `4.0` |
+| `$floor(x: T) -> T` | Round *down* to closest integer \\(\lfloor x \rfloor\\) | `$floor(0.5)` => `0.0` |
+| `$format(String, Any...) -> String` | Formatting string | `$format("{} + {}", 3, "a")` => `"3 + a"` |
+| `$hash(Any...) 
-> u64` | Hash the given values | `$hash("a", 3, 5.5)` => `5862532063111067262` |
+| `$log(x: T) -> T` | Natural logarithm function \\(\text{log}_e(x)\\) | `$log(1.0)` => `0.0` |
+| `$log2(x: T) -> T` | Logarithm function \\(\text{log}_2(x)\\) (base 2) | `$log2(4.0)` => `2.0` |
+| `$max(T...) -> T` | Maximum \\(\text{max}(x_1, x_2, \dots)\\) | `$max(4.0, 1.0, 9.5)` => `9.5` |
+| `$min(T...) -> T` | Minimum \\(\text{min}(x_1, x_2, \dots)\\) | `$min(4.0, 1.0, 9.5)` => `1.0` |
+| `$pow(x: T, y: u32) -> T` | Integer power function \\(x^y\\) | `$pow(2.2, 2)` => `4.84` |
+| `$powf(x: T, y: T) -> T` | Float power function \\(x^y\\) | `$powf(4.0, 0.5)` => `2.0` |
+| `$sign(x: T) -> T` | Sign function that returns \\(\{-1, 0, 1\}\\) in respective types | `$sign(-3.0)` => `-1.0` |
+| `$sin(x: T) -> T` | Sine function \\(\text{sin}(x)\\) | `$sin(0.0)` => `0.0` |
+| `$string_char_at(s: String, i: usize) -> char` | Get the `i`-th character of string `s` | `$string_char_at("hello", 2)` => `'l'` |
+| `$string_concat(String...) -> String` | Concatenate multiple strings | `$string_concat("hello", " ", "world")` => `"hello world"` |
+| `$string_index_of(s: String, pat: String) -> usize` | Find the index of the first occurrence of the pattern `pat` in string `s` | `$string_index_of("hello world", "world")` => `6` |
+| `$string_length(s: String) -> usize` | Get the length of the string | `$string_length("hello")` => `5` |
+| `$string_lower(s: String) -> String` | To lower-case | `$string_lower("LisA")` => `"lisa"` |
+| `$string_trim(s: String) -> String` | Trim a string | `$string_trim(" hello ")` => `"hello"` |
+| `$string_upper(s: String) -> String` | To upper-case | `$string_upper("LisA")` => `"LISA"` |
+| `$substring(s: String, b: usize, e: usize?) 
-> String` | Find the substring given begin index and optional the end index | `$substring("hello world", 6)` => `"world"` | +| `$tan(x: T) -> T` | Tangent function \\(\text{tan}(x)\\) | `$tan(0.0)` => `0.0` | diff --git a/doc/src/language/foreign_predicates.md b/doc/src/language/foreign_predicates.md index fc58c10..64463b2 100644 --- a/doc/src/language/foreign_predicates.md +++ b/doc/src/language/foreign_predicates.md @@ -10,18 +10,44 @@ There could be infinitely many triplets, and we cannot simply enumerate all of t But if the first two arguments `begin` and `end` are given, we can reasonably enumerate the `i`. In Scallop, `range` is available to be used as a **foreign predicate**. -Note that Scallop's foreign predicates are currently statically typed and does not support signature overload. -For example, to use `range` on `i32` data, we will need to invoke `range_i32`: +Notice that `range` can be applied on any integer data, making it a generic predicate. +For example, to use `range` on `i32` data, we will need to invoke `range`: ``` scl -rel result(x) = range_i32(0, 5, x) +rel result(x) = range(0, 5, x) ``` Here we enumerate the value of `x` from 0 (inclusive) to 5 (exclusive), meaning that we will obtain that `result = {0, 1, 2, 3, 4}`. For the rest of this section, we describe in detail how foreign predicates are constructed in Scallop and why are they useful. -## Bound and Free Pattern +## Foreign Predicate Types -## Foreign Predicates are Statically Typed +Foreign predicates can be generic and are statically typed. +In addition to just providing the argument types, we also need to provide a boundness pattern. -## Available Foreign Predicates in `std` +A boundness pattern is a string of length equal to the relation arity and consisting of `b` and `f`. +The character `b` means *bounded*, reflecting the variable on that position is taken as input to the predicate. 
+The character `f` means *free*, suggesting that the variable on that position will be generated as output by the predicate.
+
+For example, the full definition of `range` is
+
+```
+range(begin: T, end: T, i: T)[bbf]
+```
+
+Notice that at the end of the definition we have `[bbf]`.
+Here, `bbf` is a boundness pattern for range, suggesting that `begin` and `end` will be provided as input, and `i` will be generated as output.
+
+> In the future, we plan to allow the definition of multiple boundness patterns
+
+## Standard Library of Foreign Predicates (Part A)
+
+In this part, we give an overview of the foreign predicates that are discrete only.
+
+| Foreign Predicate | Description |
+|-------------------|-------------|
+| `datetime_ymd(d: DateTime, y: i32, m: u32, day: u32)[bfff]` | Get the `y`ear, `m`onth, and `day` from a `DateTime` value |
+| `range(begin: T, end: T, i: T)[bbf]` | Generate all the integers `i` starting from `begin` and end with `end - 1` |
+| `string_chars(s: String, i: usize, c: char)[bff]` | Generate all the index-character tuples inside of string `s` |
+| `string_find(s: String, pat: String, begin: usize, end: usize)[bbff]` | Generate all the begin-end ranges of the pattern `pat`'s occurrence in the string `s` |
+| `string_split(s: String, pat: String, out: String)[bbf]` | Split the string `s` using the pattern `pat` and generate the `out` strings |
diff --git a/doc/src/language/index.md b/doc/src/language/index.md
index e69de29..74dfe6d 100644
--- a/doc/src/language/index.md
+++ b/doc/src/language/index.md
@@ -0,0 +1,5 @@
+# Scallop and Logic Programming
+
+In this part of the book we introduce Scallop as a relational and logical programming language.
+Scallop is a Datalog based language extended with various language features such as negation, aggregation, disjunctive head, algebraic data type, foreign functions, and foreign predicates.
+We will explain all of these concepts in detail, aiming to provide a comprehensive introduction to the core language constructs. diff --git a/doc/src/language/loading_csv.md b/doc/src/language/loading_csv.md new file mode 100644 index 0000000..83d5c82 --- /dev/null +++ b/doc/src/language/loading_csv.md @@ -0,0 +1,153 @@ +# Loading from CSV + +Scallop can be used along with existing datasets loaded from CSVs. +This is usually achieved with annotating on specific relations. +For example, assuming we have a file `edge.csv`, + +``` csv +0,1 +1,2 +``` + +we can load the content of it into a relation `edge` in Scallop using the following syntax + +``` scl +@file("edge.csv") +type edge(from: usize, to: usize) + +rel path(a, c) = edge(a, c) or path(a, b) and edge(b, c) + +query path +``` + +In particular, we annotate the `@file(...)` attribute onto the relation type declaration `type edge(...)`. +The file name is written inside the `@file` attribute. +We require the relation to be declared with types in order for it to be loaded with CSV file content. +Depending on the type declaration, the file content will be parsed into values of certain types. + +From here, the `edge` relation will be loaded with the content `(0, 1)` and `(1, 2)`. +After executing the Scallop program above, we would obtain the result `path` being `(0, 1)`, `(0, 2)`, and `(1, 2)`. + +Certainly, there are many ways to load CSV. +In this section, we introduce the various ways to configure the CSV loading. + +## Headers + +There are CSV files with headers. +Suppose we have the following CSV file + +``` csv +from,to +0,1 +1,2 +``` + +To load this file, we would need to add an additional argument `header=true` to the `@file` attribute: + +``` scl +@file("edge.csv", header=true) +type edge(from: usize, to: usize) +``` + +Note that by default we assume that CSV files don't have headers. + +## Deliminators + +By default, we assume the values inside of the CSV file are deliminated by commas `','`. 
+In case where CSV files have values deliminated by other characters, such as tabs `'\t'`, we would need to specify that in the `@file` attribute:
+
+``` scl
+@file("edge.csv", deliminator="\t")
+type edge(from: usize, to: usize)
+```
+
+Note that deliminators cannot be of multiple characters.
+
+## Parsing Field-Value Pairs
+
+There are many CSV tables which have a lot of columns.
+One way is to specify all the fields and their types, like the following.
+
+``` scl
+type table(field1: type1, field2: type2, ..., fieldn: typen)
+```
+
+However, this might be very hard to encode.
+Therefore, we provide another way of parsing CSV files into relations, by using primary keys and field-value pairs.
+Let's assume we have the following CSV file:
+
+``` csv
+student_id,name,year,gender
+0001,alice,2020,female
+0002,bob,2021,male
+```
+
+We see that `student_id` can serve as the primary key of this table.
+With this, it can be loaded into the following relation
+
+``` scl
+@file("student.csv", keys="student_id")
+type table(student_id: usize, field: String, value: String)
+```
+
+By specifying `keys="student_id"`, we tell Scallop that `student_id` should be viewed as unique primary keys.
+The remaining two elements are `field` and `value`, both of which need to be typed `String`.
+As a result, it produces the following 6 facts in the `table` relation:
+
+```
+(1, "name", "alice"), (1, "year", "2020"), (1, "gender", "female"),
+(2, "name", "bob"), (2, "year", "2021"), (2, "gender", "male")
+```
+
+Note that there could be more than one key.
+Consider the following table
+
+```
+student_id,course_id,enroll_time,grade
+0001,cse100,fa2020,a
+0001,cse101,sp2021,a
+0002,cse120,sp2021,b
+```
+
+We see that the combination of `student_id` and `course_id` form the unique primary keys.
+In this case, they can be loaded using the following syntax: + +``` scl +@file("enrollment.csv", keys=["student_id", "course_id"]) +type enrollment(student_id: usize, course_id: String, field: String, value: String) +``` + +By setting `keys` to be a list `["student_id", "course_id"]`, the `student_id` field is the first primary key and `course_id` is the second. +There are still two additional arguments for the `enrollment` relation. +In general, the arity of the relation will be the number of primary keys plus 2. + +## Specifying Fields to Load + +In case not all fields are desired when loading, one can use the `fields` argument to specify what to load. +Consider the same enrollment table encoded in CSV: + +``` csv +student_id,course_id,enroll_time,grade +0001,cse100,fa2020,a +0001,cse101,sp2021,a +0002,cse120,sp2021,b +``` + +If we only want to get everything but omit the `enroll_time` column, we can do + +``` scl +@file("enrollment.csv", fields=["student_id", "course_id", "grade"]) +type enrollment(student_id: usize, course_id: String, grade: String) +``` + +This can also work in conjunction with the `keys` argument. +In this case, we do not need to specify the primary keys. 
+ +``` scl +@file("enrollment.csv", keys=["student_id", "course_id"], fields=["grade"]) +type enrollment(student_id: usize, course_id: String, field: String, value: String) +// The following facts will be obtained +// enrollment(1, "cse100", "grade", "a") +// enrollment(1, "cse101", "grade", "a") +// enrollment(2, "cse120", "grade", "b") +``` diff --git a/doc/src/language/magic_set.md b/doc/src/language/magic_set.md new file mode 100644 index 0000000..5523f61 --- /dev/null +++ b/doc/src/language/magic_set.md @@ -0,0 +1 @@ +# Magic-Set Transformation diff --git a/doc/src/language/negation.md b/doc/src/language/negation.md index d49ceb6..47471de 100644 --- a/doc/src/language/negation.md +++ b/doc/src/language/negation.md @@ -1 +1,103 @@ -# Rules with Negations +# Negations + +Scallop supports negation to be attached to atoms to form negations. +In the following example, we are trying to obtain the set of people with no children: + +``` scl +rel person = {"bob", "alice", "christine"} // There are three persons of interest +rel father = {("bob", "alice")} // Bob is Alice's father +rel mother = {("alice", "christine")} // Alice is Christine's mother + +rel has_no_child(n) = person(n) and not father(n, _) and not mother(n, _) +``` + +The last rule basically says that if there is a person `n` who is neither anyone's father nor anyone's mother then the person `n` has no child. +This is indeed what we are going to obtain: + +``` +has_no_child: {("christine",)} +``` + +It is clear that negations are very helpful in writing such kind of the rules. +However, there are many restrictions on negations. +We explain in detail such restrictions. + +## Negation and Variable Grounding + +If we look closely to the rule of `has_no_child` above, we will find that there is an atom `person(n)` being used in the body. +Why can't we remove it and just say "if one is neither father nor mother then the one has no child"? 
+ +``` scl +rel has_no_child(n) = not father(n, _) and not mother(n, _) // Error: variable `n` is not grounded +``` + +The problem is with variable grounding. +For the variable `n` to be appeared in the head, there is **no positive atom** that grounds it. +All we are saying are what `n` is not, but not what `n` is. +With only "what it is not", it could be literally anything else in the world. + +Therefore, we need to ground it with a positive atom such as `person(n)`. +With this rule, we have basically + +## Stratified Negation + +Expanding upon [our definition of dependency graph](recursion.md#relation-dependency), +if a predicate occurs in a negative atom in a body, +we say that the predicate of the rule head *negatively depends* on this predicate. +For example, the above `has_no_child` example has the following dependency graph. +Notice that we have marked the *positive* (`pos`) and *negative* (`neg`) on each edge: + +``` +person <--pos-- has_no_child --neg--> father + | + +-----neg-----> mother +``` + +Scallop supports *stratified negation*, which states that there is never a loop in the dependency graph which involves a negative dependency edge. +In other words, if there exists such a loop, the program will be rejected by the Scallop compiler. +Consider the following example: + +``` scl +rel is_true() = not is_true() // Rejected +``` + +The relation `is_true` negatively depends on the relation `is_true` itself, making it a loop containing a negative dependency edge. +The error message would show that this program "cannot be stratified". +If we draw the dependency graph of this program, it look like the following: + +``` scl +is_true <---+ + | | + +--neg---+ +``` + +Since there is a loop (`is_true -> is_true`) and the loop contains a negative edge, this program cannot be stratified. 
The reason that stratified negation is named this way is that, if there is no negative dependency edge in any loop, the whole dependency graph can be decomposed into
+ +``` scl +// There are three classes +rel classes = {0, 1, 2} + +// Each student is enrolled in a course (Math or CS) +rel enroll = { + ("tom", "CS"), ("jenny", "Math"), // Class 0 + ("alice", "CS"), ("bob", "CS"), // Class 1 + ("jerry", "Math"), ("john", "Math"), // Class 2 +} + +// Count how many student enrolls in CS course +rel num_enroll_cs(n) = n := count(s: enroll(s, "CS")) +``` + +Normally, executing a program would result in `scli` outputting every single relation. + +``` +classes: {(0), (1), (2)} +num_enroll_cs: {(3)} +enroll: {("alice", "CS"), ("bob", "CS"), ("jenny", "Math"), ...} +``` + +However, we might only be interested in the relation named `num_enroll_cs`. +In this case, we write a *query* using the `query` keyword: + +``` scl +query num_enroll_cs +``` + +In this case, only the relation `num_enroll_cs` will be output: + +``` +num_enroll_cs: {(3)} +``` + +## Atomic Query + +One can also write atomic query if we just want to get a part of the relation. +For instance, consider the fibonacci example: + +``` scl +type fib(x: i32, y: i32) +rel fib = {(0, 1), (1, 1)} +rel fib(x, y1 + y2) = fib(x - 1, y1) and fib(x - 2, y2) and x <= 10 +query fib(8, y) // fib(8, y): {(8, 34)} +``` + +In this case, we are just looking at the 8-th fibonacci number, which is 34. diff --git a/doc/src/language/recursion.md b/doc/src/language/recursion.md index 30a25cc..a19abb6 100644 --- a/doc/src/language/recursion.md +++ b/doc/src/language/recursion.md @@ -1 +1,156 @@ # Recursive Rules + +One very powerful programming construct with Scallop is to declaratively define recursion. +Inside of a rule, if a relational predicate appearing in the head appears in the body, the predicate is recursive. +For example, the definition of fibonacci number is recursive: + +\\[ \text{fib}(x) = \left\\{ \begin{array}{ll} \text{fib}(x - 1) + \text{fib}(x - 2), & \text{if}~ x > 1 \\\\ 1, & \text{otherwise} \end{array} \right. 
\\] + +Written in Scallop, we encode the function `fib` as a binary relation between the integer input and output: + +``` scl +type fib(x: i32, y: i32) +``` + +We can define the base cases for \\(\text{fib}(0)\\) and \\(\text{fib}(1)\\): + +``` scl +rel fib = {(0, 1), (1, 1)} +``` + +Now it comes to the definition of recursive cases, which peeks into \\(\text{fib}(x - 1)\\) and \\(\text{fib}(x - 2)\\) and sums them. + +``` scl +rel fib(x, y1 + y2) = fib(x - 1, y1) and fib(x - 2, y2) // infinite-loop +``` + +However, when actually executing this, it would not terminate as we are attempting to compute all fibonacci numbers, and there are infinite amount of them. +In order to stop it, we can temporarily add a constraint to limit the value of `x`, so that we only compute the fibonacci number up to 10: + +``` scl +rel fib(x, y1 + y2) = fib(x - 1, y1) and fib(x - 2, y2) and x <= 10 +``` + +At the end, we would get a the `fib` relation to contain the following facts: + +``` +fib: {(0, 1), (1, 1), (2, 2), (3, 3), (4, 5), (5, 8), (6, 13), (7, 21), (8, 34), (9, 55), (10, 89)} +``` + +As suggested by the result, the 10-th fibonacci number is 89. + +## Case Study: Graphs and Transitive Closure + +Following is one of the most widely known Datalog program: computing the `path`s inside of a graph. +By definition, an edge or a sequence of edges constitute a path. +This is reflected by the following two rules: + +``` scl +type edge(i32, i32) + +rel path(a, b) = edge(a, b) +rel path(a, c) = path(a, b) and edge(b, c) +``` + +The first line states that an edge can form a path. +The second line states that a path, connected to a new edge, forms a new path. +As can be seen from the second line, the relation `path` appears in both the body and the head, making it a *recursive relation*. 
+ +In this example, suppose we have + +``` scl +rel edge = {(0, 1), (1, 2)} +``` + +we would get the set of paths to be + +``` +path: {(0, 1), (0, 2), (1, 2)} +``` + +Notice that the path `(0, 2)` is a compound path obtained from joining the two edges `(0, 1)` and `(1, 2)`. + +## Relation Dependency + +Given a rule with head and body, we say that the predicate appearing in the head *depends* on the predicates of the atoms appearing in the body. +This forms a dependency graph. +The above edge-path example would have the following dependency graph: + +``` +edge <--- path <---+ + | | + +------+ +``` + +The relation `edge` depends on nothing, while `path` depends on `edge` and also `path` itself. +This forms a loop in the dependency graph. +In general, if a program has a dependency graph with a loop, then the program requires *recursion*. +Any relation that is involved in a loop would be a *recursive relation*. + +Notice that we are mostly talking about *positive dependency* here, as the atoms in the body of the rule are *positive atoms* (i.e., without annotation of negation or aggregation). +In more complex scenarios, there will be negation or aggregation in a rule, which we explain in detail in future sections. + +## Fixed-point Iteration + +The recursion in Scallop happens in *fixed-point iteration*. +In plain terms, the recursion will continue until there is no new fact being derived in an iteration. +In hind-sight, the whole Scallop program is executed in a loop. +Within one iteration, all of the rules inside of the program are executed. +Let us digest the actual execution happens when executing the above edge-path program: + +``` scl +rel edge = {(0, 1), (1, 2), (2, 3)} +rel path(a, b) = edge(a, b) // rule 1 +rel path(a, c) = path(a, b) and edge(b, c) // rule 2 +``` + +Before the first iteration, the `edge` has already been filled with 3 facts, namely `(0, 1)`, `(1, 2)`, and `(2, 3)`. +But the `path` is empty. 
+Let's now go through all the iterations: + +``` +Iter 0: path = {} +Iter 1: path = {(0, 1), (1, 2), (2, 3)} + Δpath = {(0, 1), (1, 2), (2, 3)} // through applying rule 1 +Iter 2: path = {(0, 1), (1, 2), (2, 3), (0, 2), (1, 3)} + Δpath = {(0, 2), (1, 3)} // through applying rule 2 +Iter 3: path = {(0, 1), (1, 2), (2, 3), (0, 2), (1, 3), (0, 3)} + Δpath = {(0, 3)} // through applying rule 2 +Iter 4: path = {(0, 1), (1, 2), (2, 3), (0, 2), (1, 3), (0, 3)} + Δpath = {} +``` + +In the above note, we also include `Δpath`, which contains the new paths derived during the current iteration. +As can be seen, during iteration 1, paths of length 1 are derived; during iteration 2, paths of length 2 are derived. +During iteration 4, there is no more path to be derived, and therefore the `Δpath` is empty. +This tells us that no new facts are derived and the whole fixed-point iteration is stopped, giving us the final result. + +## Infinite Relations + +As we have described in the *fixed-point iteration*, the recursion will continue until no more fact is derived. +However, we are capable of writing rules that are infinite. +As shown in the first example: + +``` scl +rel fib(x, y1 + y2) = fib(x - 1, y1) and fib(x - 2, y2) +``` + +gives you an infinite relation as there can always be a new `x` to be derived. +In this case, the fixed-point iteration never stops. + +The root cause of this is Scallop's support for *value creationg*, i.e., the creation of new values. +Typically, database systems work in closed-world assumption, that is, all the items being reasoned about are already there. +No computation is done on arbitrarily created values. +But in the above example, we have derived `x` from the grounded expression `x - 1`, hence creating a new value. + +Typically, the way to resolve this is to create bounds on the created values. 
+For example, the rule + +``` scl +rel fib(x, y1 + y2) = fib(x - 1, y1) and fib(x - 2, y2) and x <= 10 +``` + +restricts that `x` cannot be greater than 10. +This makes the fixed-point iteration to stop after around 10 iterations. + +Other way of getting around with this involve the use of [*Magic-Set Transformations*](https://dl.acm.org/doi/pdf/10.1145/6012.15399), which we describe its equivalent in Scallop in [a later section](magic_set.md). diff --git a/doc/src/language/reference_guide.md b/doc/src/language/reference_guide.md new file mode 100644 index 0000000..c6124e0 --- /dev/null +++ b/doc/src/language/reference_guide.md @@ -0,0 +1,111 @@ +# Reference Guide + +We list all the language features supported by Scallop. + +## Import Files + +``` scl +import "path/to/other/file.scl" +``` + +## Type Definition + +### Type Alias Definition + +``` scl +type ObjectId = usize +``` + +### Sub-Type Definition + +``` scl +type Name <: String +``` + +### Enum Type Definition + +``` scl +type Action = LEFT | RIGHT | UP | DOWN +``` + +### Algebraic Data Type Definition + +``` scl +type Expr = Const(i32) | Add(Expr, Expr) | Sub(Expr, Expr) +``` + +### Relation Type Definition + +``` scl +type edge(x: i32, y: i32) +``` + +## Constant Definition + +``` scl +const PI: f32 = 3.1415 +``` + +## Relation Definition + +### Fact Definition + +``` scl +rel edge(1, 2) +``` + +### Set-of-Tuples Definition + +``` scl +rel edge = {(1, 2), (2, 3), (3, 4)} +``` + +### Rule Definition + +``` scl +rel path(a, b) = edge(a, b) or path(a, c) and edge(c, b) +``` + +#### Disjunctive Head + +``` scl +rel { assign(v, false); assign(v, true) } = variable(v) +``` + +#### Atom + +``` scl +fib(x - 1, y) +``` + +#### Negation + +``` scl +rel has_no_child(p) = person(p) and not father(p, _) and not mother(p, _) +``` + +#### Constraint + +``` scl +rel number(0) +rel number(i + 1) = number(i) and i < 10 +``` + +#### Aggregation + +``` scl +rel person = {"alice", "bob", "christine"} +rel num_people(n) = n := 
count(p: person(p)) +``` + +#### Foreign Predicate + +``` scl +rel grid(x, y) = range(0, 5, x) and range(0, 5, y) +``` + +## Query Definition + +``` scl +query path +``` diff --git a/doc/src/language/relation.md b/doc/src/language/relation.md new file mode 100644 index 0000000..bc62bea --- /dev/null +++ b/doc/src/language/relation.md @@ -0,0 +1,178 @@ +# Relations and Facts + +Scallop is a relational and logical programming language. +As described in the [Wikipedia](https://en.wikipedia.org/wiki/Logic_programming): + +> Logic programming is a programming paradigm which is largely based on formal logic. +> Any program written in a logic programming language is a set of sentences in logical form, +> expressing facts and rules about some problem domain. + +In Scallop, relations are the most fundamental building blocks of program. +In the following example, we have declared the type of a relation called `edge`, using the `type` keyword: + +``` scl +type edge(a: i32, b: i32) +``` + +We say that the name `edge` is a *predicate* or a *relation*. +Inside of the parenthesis, we have two `arguments`, `a: i32` and `b: i32`. +Therefore, we have `edge` being an *arity-2* relation, due to it having 2 arguments. +For the argument `a: i32`, we give a name of the field (`a`) and a type of that argument `i32`. +Here, both of the arguments are of the `i32` type, which means signed-*i*nteger, *32*-bit. +For more information on value types, refer to [the Value Types section](#value-types). + +The above line only declares the type of the relation but not the *content* of the relation. +The actual information stored in the relations are called *facts*. +Here we define a single fact under the relation `edge`: + +``` scl +rel edge(0, 1) +``` + +Assuming `0` and `1` each denote an ID of a node, this fact declares that there is an edge going from node `0` to node `1`. +There are two arguments in this fact, matching the arity of this relation. 
Regardless of the predicate `edge`, one can also simply consider the `(0, 1)` as a *tuple*, more specifically, a *2-tuple*.
+ +``` scl +rel greeting("hello world!") +// or +rel greeting = {("hello world!",)} +``` + +Note that for the second way of expressing the fact, we may omit the parenthesis and make it cleaner: + +``` scl +rel greeting = {"hello world!"} +``` + +In light of this, we may write the following rule: + +``` scl +rel possible_digit = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9} +``` + +### Integer Arithmetics as Relations + +Integer arithmetics can be represented as relations as well. +Consider a simple summation in algebra, `a + b = c` encodes the sum relationship among two operands (`a` and `b`) and their summation (`c`). +Encoded in Scallop, they form arity-3 relations: + +``` scl +type add(op1: i32, op2: i32, result: i32) +``` + +Note that, in Scallop, relations are *not* polymorphic. +That is, every relation, no matter declared or inferred, only has one type annotation. + +> We are working on an update in the future to relax this restriction. + +To declare facts of this `add` relation, such as `3 + 4 = 7`, we write + +``` scl +rel add(3, 4, 7) // 3 + 4 = 7 +``` + +However, you might notice that the `add` relation is theoretically *infinite*. +That is, there are infinite amount of facts that can satisfy the `add` relation. +There is no way that we can possibly enumerate or declare all the facts. +In such case, we resort to declaring rules using foreign functions or predicates, which we will discuss later. +For now, let's use `add` as an example relation that encodes integer arithmetics. + +### Terminologies + +We have the following terminologies for describing relations. + +- Boolean Relation: arity-0 relation +- Unary Relation: arity-1 relation +- Binary Relation: arity-2 relation +- Ternary Relation: arity-3 relation + +## Type Inference + +Scallop supports *Type Inference*. +One does not need to fully annotate every relation on their types. +Types are inferred during the compilation process. 
+ +For example, given the following code, + +``` scl +rel edge = {(0, 1), (1, 2)} +``` + +we can infer that the relation `edge` is a binary-relation where both arguments are integers. +Note that when integers are specified, they are set default to the type of `i32`. + +Type inference will fail if conflicts are detected. +In the following snippet, we have the second argument being `1` as integer and also `"1"` as string. + +``` scl +rel edge = {(0, 1), (0, "1")} +``` + +Having this code will raise the following compile error, suggesting that the types cannot be unified. +Note that the following response is generated in `sclrepl` command line interface. + +``` +[Error] cannot unify types `numeric` and `string`, where the first is declared here + REPL:0 | rel edge = {(0, 1), (0, "1")} + | ^ +and the second is declared here + REPL:0 | rel edge = {(0, 1), (0, "1")} + | ^^^ +``` + +For more information on values and types, please refer to the [next section](value_type.md) diff --git a/doc/src/language/rules.md b/doc/src/language/rules.md index e69de29..d08dfab 100644 --- a/doc/src/language/rules.md +++ b/doc/src/language/rules.md @@ -0,0 +1,200 @@ +# Rules + +*Rules* are the fundamental to computation in Scallop. +Each rule defines the value and data flowing from some relation to another relation. +In the following program, we have defined a few facts for the `edge` relation. +On the second line, we have defined that, for each edge `(a, b)`, there is also a path `(a, b)`. +We note that here, `a` and `b` are variables instead of constants as we have with defining facts. +During computation, the two facts in `edge` will populate the `path` relation. +This way, we have defined a rule for the `path`, which is executed during computation. + +``` scl +rel edge = {(0, 1), (1, 2)} +rel path(a, b) = edge(a, b) // (0, 1), (1, 2) +``` + +In this section, we talk about how we write rules in Scallop and how intricate computation can be done through it. 
+ +## Syntax + +In general, the basic rules in Scallop are of the form + +``` +RULE ::= rel ATOM = FORMULA +FORMULA ::= ATOM + | not ATOM + | CONSTRAINT + | AGGREGATION + | FORMULA and FORMULA + | FORMULA or FORMULA + | ( FORMULA ) +``` + +For each rule, we name the atom on the left to be the *head* of the rule, and the formula on the right to be the *body*. +We read it from right to left: when the body formula holds, the head also holds. +The formula might contain atoms, negated atoms, aggregations, conjunction, disjunction, and a few more constructs. +For this section, we focus on simple (positive) atom, constraints, and their conjunctions and disjunctions. +We will leave the discussion of negation and aggregation to the next sections. + +## Atom + +Simple atoms are of the form `RELATION(ARG_1, ARG_2, ...)`. +Similar to facts, we have the relation name followed by a tuple of numerous arguments. +Now, the arguments can be of richer forms, involving variables, constants, expressions, function calls, and many more. + +Considering the most basic example from above: + +``` scl +rel path(a, b) = edge(a, b) +``` + +We have two variables `a` and `b` *grounded* by the `edge` relation. +This means we are treating the variables `a` and `b` as source of information, which can be propagated to the head. +In this example, the head also contains two variables, both being grounded by the body. +Therefore the whole rule is well formed. + +In case the head variables are not grounded by the body, such as the following, + +``` scl +rel path(a, c) = edge(a, b) +``` + +we would get an error that looks like the following: + +``` +[Error] Argument of the head of a rule is ungrounded + REPL:1 | rel path(a, c) = edge(a, b) + | ^ +``` + +The error message points us to the variable `c` that has not being grounded in the body. + +For basic atoms, such as the ones that the user has defined, can be used to directly ground variables which are directly arguments of the atoms. 
+They can be used to ground other variables or expressions. +In the following example, although the rule itself might not make any sense, the variable `a` is used to ground the expression `a + 1`. +Therefore, the rule is completely valid. + +``` scl +rel output_relation(a, a + 1) = input_relation(a) +``` + +In certain cases, expressions can be used to bound variables as well! + +``` scl +rel output_relation(a, b) = input_relation(a, b + 1) +``` + +In the above example, the expression `b + 1` can be used to derive `b`, and thus making the variable `b` grounded. +However, this might not be true for other expressions: + +``` scl +rel output_relation(b, c) = input_relation(b + c) // FAILURE +``` + +The `input_relation` can ground the expression `b + c` directly, however, the two arguments `b` and `c` cannot be derived from their sum, as there are (theoretically) infinite amount of combinations. +In this case, we will get a compilation failure. + +There can be constraints present in atoms as well. +For example, consider the following rule: + +``` scl +rel self_edge(a) = edge(a, a) +``` + +The atom `edge(a, a)` in the body grounds only one variable `a`. +But the pattern is used to match any edge that goes from `a` and to `a` itself. +Therefore, instead of grounding two values representing the "from" and "to" of an `edge`, we are additionally posing constraint on the type of edge that we are matching. +Conceptually, we can view the above rule as the following equivalent rule: + +``` scl +rel self_edge(a) = edge(a, b) and a == b +``` + +where there is an additional constraint posed on the equality of `a` and `b`. +We are going to touch on `and` and constraints in the upcoming sections. + +## Disjunction (Or) + +The body formula can contain logical connectives such as `and`, `or`, `not`, and `implies`, used to connect basic formulas such as *Atom*. 
+In the following example, we are defining that if `a` is `b`'s *father* or *mother*, then `a` is `b`'s parent: + +``` scl +rel parent(a, b) = father(a, b) +rel parent(a, b) = mother(a, b) +``` + +In this program, we have divided the derivation of `parent` into two separate rules, one processing the `father` relationship and the other processing the `mother` relationship. +This natually form a disjunction (or), as the derivation of `parent` can come from 2 disjunctive sources. +Note that in Scallop (or Datalog in general), the ordering of the two rules does not matter. + +Therefore, given that + +``` scl +rel father = {("Bob", "Alice")} +rel mother = {("Christine", "Alice")} +``` + +we can derive that the `parent` relation holding two tuples, `("Bob", "Alice")` and `("Christine", "Alice")`. + +The above program can be rewritten into a more compact form that looks like the following: + +``` scl +rel parent(a, b) = father(a, b) or mother(a, b) +// or +rel parent(a, b) = father(a, b) \/ mother(a, b) +``` + +We have used an explicit `or` (`\/`) keyword to connect the two atoms, `father(a, b)` and `mother(a, b)`. +The `\/` symbol, which is commonly seen in the formal logics as the symbol vee (\\(\vee\\)), is also supported. +Notice that written in this way, each branch of the disjunction need to fully bound the variables/expressions in the head. + +## Conjunction (And) + +To demonstrate the use of `and`, let's look at the following example computing the relation of `grandmother` based on `father` and `mother`: + +``` scl +rel grandmother(a, c) = mother(a, b) and father(b, c) +// or +rel grandmother(a, c) = mother(a, b) /\ father(b, c) +``` + +Notice that the symbol `/\` is a replacement for the `and` operator, which resembles the wedge (\\(\wedge\\)) symbol seen in formal logics. + +As can be seen from the rule, the body grounds three variables `a`, `b`, and `c`. +The variables `a` and `b` comes from `mother` and the variables `b` and `c` comes from `father`. 
+Notice that there is one variable, `b`, in common. +In this case, we are *joining* the relation of `mother` and `father` on the variable `b`. + +## Constraints + +Rule body can have boolean constraints. +For example, the conjunctive rule above can be re-written as + +``` scl +rel grandmother(a, c) = mother(a, b) and father(bp, c) and b == bp +``` + +Here, we are posing an equality (`==`) constraint on `b` and `bp`. +Normally, constraints are such kind of binary expressions involving predicates such as + +- equality and inequality (`==` and `!=`) +- numerical comparisons (`<`, `>`, `<=`, and `>=`) + +## Other constructs + +There are other constructs available for defining rules, which we continue to discuss in detail in other sections: + +- [Disjunctive head](disj_conj_head.md) +- [Recursive Rules](recursion.md) +- [Negation](negation.md) +- [Aggregation](aggregation.md) +- [Foreign Predicates](foreign_predicates.md) + +## Traditional Datalog Syntax + +If you are familiar with traditional Datalog, you can have it by swapping the `=` with `:-`, and the `and` to `,` +For example, the rule for defining `grandmother` can be rewritten as + +``` scl +rel grandmother(a, c) :- mother(a, b), father(b, c) +``` diff --git a/doc/src/language/value_type.md b/doc/src/language/value_type.md new file mode 100644 index 0000000..2c2e967 --- /dev/null +++ b/doc/src/language/value_type.md @@ -0,0 +1,238 @@ +# Values and Types + +Scallop has a built-in set of basic value types, following Rust's naming convention. +From there, we have types such as `Symbol`, `DateTime`, `Entity`, and `Tensor`, which are special types to Scallop. 
+ +| Type | Description | +|------|-------------| +| `i8` | Signed-integer, 8-bit | +| `i16` | Signed-integer, 16-bit | +| `i32` | Signed-integer, 32-bit | +| `i64` | Signed-integer, 64-bit | +| `i128` | Signed-integer, 128-bit | +| `isize` | Signed size; its size is dependent on the system | +| `u8` | Unsigned-integer, 8-bit | +| `u16` | Unsigned-integer, 16-bit | +| `u32` | Unsigned-integer, 32-bit | +| `u64` | Unsigned-integer, 64-bit | +| `u128` | Unsigned-integer, 128-bit | +| `usize` | Unsigned size; its size is dependent on the system | +| `f32` | Floating-point number, 32-bit | +| `f64` | Floating-point number, 64-bit | +| `bool` | Boolean | +| `char` | Character | +| `String` | Variable-length string | +| `Symbol` | Symbol | +| `DateTime` | Date and time | +| `Duration` | Duration | +| `Entity` | Entity | +| `Tensor` | Tensor | + +### Integers + +Integers are the most basic data-type in Scallop. +If not specified, the default integer type that the system will pick is the `i32` (signed integer 32-bit) type: + +``` scl +rel edge = {(0, 1), (1, 2)} // (i32, i32) +``` + +If an unsigned integer type is specified but a negative number is used in the declared facts, a type inference error will be raised. 
+We demonstrate this in the `sclrepl` environment: + +``` +scl> type my_edge(usize, usize) +scl> rel my_edge = {(-1, -5), (0, 3)} +[Error] cannot unify types `usize` and `signed integer`, where the first is declared here + REPL:0 | type my_edge(usize, usize) + | ^^^^^ +and the second is declared here + REPL:1 | rel my_edge = {(-1, -5), (0, 3)} + | ^^ +``` + +Primitive operations that can be used along with integers are + +- Comparators: + - `==` (equality) + - `!=` (inequality) + - `>` (greater-than) + - `>=` (greater-than-or-equal-to) + - `<` (less-than) + - `<=` (less-than-or-equal-to) +- Arithmetic operators: + - `+` (plus) + - `-` (minus/negate) + - `*` (mult) + - `/` (div) + - `%` (mod) + +All of the above operations need to operate on two integers of the same type. +For instance, you cannot compare an `i32` value with a `usize` value. + +### Floating Point Numbers + +Floating point numbers are supported in Scallop as well. +The following example shows the definition of student and their class grades: + +``` scl +type student_grade(name: String, class: String, grade: f32) + +rel student_grade = { + ("alice", "cse 100", 95.2), + ("bob", "cse 100", 90.8), +} +``` + +It is possible derive special floating points such as `inf` and `-inf`, though we cannot declare such values directly. +For the floating point that is `nan` (not-a-number), we will omit the whole fact from the database to maintain sanity. +Specifically, the derivation of `nan` is treated as a failure of foreign functions, which we explain in detail [here](foreign_functions.md). + +All the basic operations that can work on integers would be able to work for floating point numbers as well. + +### Boolean + +Scallop allows the use of boolean values (`true` and `false`). 
+ +``` scl +type variable_assign(String, bool) +rel variable_assign = {("a", true), ("b", false)} +``` + +We support the following boolean operations: + +- Comparisons + - `==` (equality) + - `!=` (inequality) +- Logical operations + - `!` (unary negate) + - `&&` (binary and) + - `||` (binary or) + - `^` (binary xor) + +For example, we can have the following code + +``` scl +rel result(a ^ b) = variable_assign("a", a) and variable_assign("b", b) // true +``` + +### Character + +Scallop allows definition of characters such as `'a'`, `'*'`. +They are single-quoted, and can contain escaped characters such as `'\n'` (new-line) and `'\t'` (tab). + +``` scl +type my_chars = {(0, 'h'), (1, 'e'), (2, 'l'), (3, 'l'), (4, 'o')} +``` + +Comparisons operations `==` and `!=` are available for characters. + +### String + +Scallop support variable length strings of the type `String`. +Strings are declared using the double quote (`"`), and can contain escaped characters such as `\n` and `\t`. + +``` scl +rel greeting = {"Hello World"} +``` + +Strings can certainly be compared using `==` and `!=`. +The main ways for interacting with strings are through foreign functions such as `$string_length`, `$substring`, `$string_concat`, and etc. +Please refer to the [foreign functions section](foreign_functions.md) for more information. + +### Symbols + +Symbols are internally registered strings. +They are most commonly created through [loading from external files](loading_csv.md). +But they can still be specified using the `s`-quoted-string notation: + +``` scl +rel symbols = {s"NAME", s"AGE", s"GENDER"} +``` + +### DateTime and Duration + +`DateTime` and `Duration` are natively supported data structures by Scallop. +We commonly specify `DateTime` and `Duration` using their string form. 
+In the following example, we specify the `DateTime` values using the `t`-quoted-string notation (`t` represents time): + +``` scl +rel event_dates = {("enroll", t"2020-01-01"), ("finish", t"2020-03-01")} +``` + +The dates will be all transformed into UTC time-zone. +When the date part is specified and the time is not specified, we will fill the time `00:00:00 UTC`. +When the time is specified but the date is not, we will use the current date when the program is invoked. +Any reasonable date-time format are acceptable, common ones include + +- `t"2019-11-29 08:08:05-08"` +- `t"4/8/2014 22:05"` +- `t"September 17, 2012 10:09am"` +- `t"2014/04/2 03:00:51"` +- `t"2014年04月08日"` + +`Duration`s can be specified using the `d`-quoted-string notation (`d` represents duration): + +``` scl +rel event_durations = {("e1", d"12 days"), ("e2", d"15 days 20 seconds")} +``` + +The string can contain numbers followed by their units. +When specifying durations, the following units are accepted: + +- nanoseconds (`n`) +- microseconds (`usecs`) +- milliseconds (`msecs`) +- seconds (`secs`) +- minutes (`m`) +- hours (`h`) +- days (`d`) +- weeks (`w`) +- months (`M`) +- years (`y`) + +We can operate between `Duration` and `DateTime` using simple operations such as `+` and `-`: +- `DateTime + Duration ==> DateTime` +- `Duration + Duration ==> Duration` +- `DateTime - DateTime ==> Duration` +- `DateTime - Duration ==> DateTime` +- `Duration - Duration ==> Duration` + +### Entity + +Entity values are 64-bit unsigned integers created through hashing. +They are used to represent pointers of created entities. +They cannot be directly created. +Rather, they are managed by Scallop through the creation of entities. 
+For example, + +``` scl +type List = Nil() | Cons(i32, List) +const MY_LIST = Cons(1, Cons(2, Nil())) +rel input_list(MY_LIST) +query input_list +``` + +The result is then + +``` +input_list: {(entity(0x4cd0d9e6652cdfc7))} +``` + +Please refer to [this section](adt_and_entity.md) for more informaiton on algebraic data types and entities. + +## Type Conversions + +In Scallop, types can be converted using the `as` operator. +For example, we can have + +``` scl +rel numbers = {1, 2, 3, 4, 5} +rel num_str(n as String) = numbers(n) +``` + +to derive the `numbers` to be `{"1", "2", "3", "4", "5"}`. +In general, we can have all numbers castable to each other. +We also have every type being castable to `String`. +For converting `String` to other types, it undergoes a parsing process. +When the parsing does not go through, no result will be returned. diff --git a/doc/src/probabilistic/disjunctive.md b/doc/src/probabilistic/disjunctive.md new file mode 100644 index 0000000..e69de29 diff --git a/doc/src/probabilistic/index.md b/doc/src/probabilistic/index.md index e69de29..9151a7f 100644 --- a/doc/src/probabilistic/index.md +++ b/doc/src/probabilistic/index.md @@ -0,0 +1,19 @@ +# Scallop and Probabilistic Programming + +One fundamental concept in machine learning is *probability*. +Scallop, being a neurosymbolic programming language, supports probability and probabilistic programming natively. +For example, one can write the following program: + +``` scl +type Action = UP | DOWN | LEFT | RIGHT +rel predicted_action = {0.05::UP, 0.09::DOWN, 0.82::LEFT, 0.04::RIGHT} +``` + +where the `predicted_action` relation encodes a distribution of actions and their probabilities. +In particular, the `UP` action is predicted to have a \\(0.05\\) probability. +Here, the `::` symbol is used to suggest that probabilities (such as 0.05) are used to *tag* the facts (such as `UP`). 
+ +Since we can define probability on user declared facts, the derivation of new facts will be associated with probabilities too. +This means that Scallop is doing *probabilistic reasoning*. +The whole probabilistic reasoning semantics of Scallop is defined with the theory of *provenance semiring*. +In this chapter, we give detailed explanation to the probabilities appeared in Scallop. diff --git a/doc/src/probabilistic/library.md b/doc/src/probabilistic/library.md new file mode 100644 index 0000000..e69de29 diff --git a/doc/src/probabilistic/proofs.md b/doc/src/probabilistic/proofs.md new file mode 100644 index 0000000..2671236 --- /dev/null +++ b/doc/src/probabilistic/proofs.md @@ -0,0 +1 @@ +# Proofs Provenance diff --git a/doc/src/probabilistic/provenance.md b/doc/src/probabilistic/provenance.md new file mode 100644 index 0000000..11a1d74 --- /dev/null +++ b/doc/src/probabilistic/provenance.md @@ -0,0 +1,51 @@ +# Tags and Provenance + +Scallop's probabilistic semantics is realized by the *Provenance Semiring* framework. +Inside of this framework, each fact can be *tagged* by an extra piece of information, which we call *tag*. +Such information is propagated throughout the execution of Scallop program according to the *provenance*, which is the mathematical object defining how tags propagate. + +## Motivating Probabilistic Example + +The following example shows a fact `earthquake()` being tagged by a probability `0.03` (earthquake could happen with a 0.03 probability): + +``` scl +rel 0.03::earthquake() +``` + +Concretely, we have an *(external) tag space* of \\([0, 1]\\), which contains real numbers between 0 and 1, which is the space of probabilities. +Similarly, we define another tagged fact `burglary()`: + +``` scl +rel 0.20::burglary() +``` + +We can declare a rule saying that, "when earthquake or burglary happens, an alarm will go off". 
+ +``` scl +rel alarm() = earthquake() or burglary() +query alarm +``` + +Remember that the facts `earthquake()` and `burglary()` are probabilistic. +Intuitively, the derived fact `alarm()` will also be associated with a derived probability. +Based on probability theory, we have + +\\[ +\begin{align} +\Pr(\text{alarm}) +&= \Pr(\text{earthquake} \vee \text{burglary}) \\\\ +&= 1 - \Pr(\neg \text{earthquake} \wedge \neg \text{burglary}) \\\\ +&= 1 - \Pr(\neg \text{earthquake}) \cdot \Pr(\neg \text{burglary}) \\\\ +&= 1 - (1 - \Pr(\text{earthquake})) \cdot (1 - \Pr(\text{burglary})) \\\\ +&= 1 - (1 - 0.03) (1 - 0.20) \\\\ +&= 1 - 0.97 \times 0.8 \\\\ +&= 0.224 +\end{align} +\\] + +This is indeed what we get if we use the `topkproofs` provenance (which we discuss later in the chapter) with the `scli` Scallop interpreter: + +``` +> scli alarm.scl +alarm: {0.224::()} +``` diff --git a/doc/src/res/examples/sum_2.scl b/doc/src/res/examples/sum_2.scl new file mode 100644 index 0000000..a345c47 --- /dev/null +++ b/doc/src/res/examples/sum_2.scl @@ -0,0 +1,2 @@ +type digit_a(i32), digit_b(i32) +rel sum_2(a + b) = digit_a(a) and digit_b(b) diff --git a/doc/src/res/img/scallop-logo-transp-128.png b/doc/src/res/img/scallop-logo-transp-128.png new file mode 100644 index 0000000..456127e Binary files /dev/null and b/doc/src/res/img/scallop-logo-transp-128.png differ diff --git a/doc/src/res/img/scallop-logo-ws-512.png b/doc/src/res/img/scallop-logo-ws-512.png new file mode 100644 index 0000000..1406f9c Binary files /dev/null and b/doc/src/res/img/scallop-logo-ws-512.png differ diff --git a/doc/src/scallopy/branching.md b/doc/src/scallopy/branching.md index 930da15..d056fcd 100644 --- a/doc/src/scallopy/branching.md +++ b/doc/src/scallopy/branching.md @@ -1 +1,24 @@ # Branching Executions + +One cool feature that `scallopy` supports is *branching execution*. +People can create a context, clone it to form new contexts, and modify the new context without touching the old ones. 
+This is particularly useful when incremental computation is desired. + +``` py +# Create the first version of the context +ctx = scallopy.ScallopContext() +ctx.add_relation(...) +ctx.add_facts(...) + +# Branch it into another context +ctx1 = ctx.clone() +ctx1.add_relation(...) +ctx1.add_facts(...) +ctx1.run() # Running the first context + +# Branch it into one more context; `ctx1` and `ctx2` are completely disjoint +ctx2 = ctx.clone() +ctx2.add_relation(...) +ctx2.add_facts(...) +ctx2.run() # Running the second context +``` diff --git a/doc/src/scallopy/context.md b/doc/src/scallopy/context.md index 36ef296..3525cc8 100644 --- a/doc/src/scallopy/context.md +++ b/doc/src/scallopy/context.md @@ -1 +1,174 @@ -# Creating Context +# Scallop Context + +The most fundamental point of interaction of `scallopy` is `ScallopContext`. +The following is a very simple example setting up a `ScallopContext` to compute the `edge-path` program: + +``` py +import scallopy + +# Creating a new context +ctx = scallopy.ScallopContext() + +# Add relation of `edge` +ctx.add_relation("edge", (int, int)) +ctx.add_facts("edge", [(0, 1), (1, 2)]) + +# Add rule of `path` +ctx.add_rule("path(a, c) = edge(a, c) or path(a, b) and edge(b, c)") + +# Run! +ctx.run() + +# Check the result! +print(list(ctx.relation("path"))) # [(0, 1), (0, 2), (1, 2)] +``` + +Roughly, the program above can be divided into three phases: + +1. Setup the context: this involves defining relations, adding facts to relations, and adding rules that do the computation +2. Running the program inside of context +3. Fetch the results + +While the 2nd and 3rd steps are the place where the computation really happens, it's more important for the programmers to correctly setup the full context for computation. +We now elaborate on what are the high-level things to do when setting up the context + +## Configurations + +When creating a new `ScallopContext`, one should configure it with intended provenance. 
+If no argument is supplied, as shown in the above example, the context will be initialized with the default provenance, `unit`, which resembles untagged semantics (a.k.a. discrete Datalog). +To explicitly specify this, you can do + +``` py +ctx = scallopy.ScallopContext(provenance="unit") +``` + +Of course, Scallop can be used to perform reasoning on probabilistic and differentiable inputs. +For instance, you can write the following + +``` py +ctx = scallopy.ScallopContext(provenance="minmaxprob") # Probabilistic +# or +ctx = scallopy.ScallopContext(provenance="diffminmaxprob") # Differentiable +``` + +For more information on possible provenance information, please refer to the [provenance](scallopy/provenance.md) section. +It it worth noting that some provenance, such as `topkproofs`, accept additional parameters such as `k`. +In this case, you can supply this as additional arguments when creating the context: + +``` py +ctx = scallopy.ScallopContext(provenance="topkproofs", k=5) # top-k-proofs provenance with k = 5 +``` + +## Adding Program + +Given that a context has been configured and initialized, we can set it up the quickest by loading a program into the context. +One can either load an external `.scl` file, or directly inserting a program written as Python string. +To directly add a full program string to the context, one can do + +``` py +ctx.add_program(""" + rel edge = {(0, 1), (1, 2)} + rel path(a, c) = edge(a, c) or path(a, b) and edge(b, c) +""") +``` + +On the other hand, assuming that there is a file `edge_path.scl` that contains the same content as the above string, one can do + +``` py +ctx.import_file("edge_path.scl") +``` + +## Adding Relations + +Instead of adding program as a whole, one can also add relations one-at-a-time. +When adding new relations, one would need to supply the name as well as the type of the relation. 
+For example, the `edge` relation can be defined as follows + +``` py +ctx.add_relation("edge", (int, int)) +``` + +Here, we are saying that `edge` is an arity-2 relation storing pairs of integers. +Note that we are specifying the type using Python's `int` type. +This is equivalent to the `i32` type inside Scallop. +Therefore, the above instruction tranlates to the following Scallop code: + +``` scl +rel edge(i32, i32) +``` + +Many existing Python types can directly translate to Scallop type. +In particular, we have the mapping listed as follows: + +| Python Type | Scallop Type | +|-------------|--------------| +| `int` | `i32` | +| `bool` | `bool` | +| `float` | `f32` | +| `str` | `String` | + +In case one want to use types other than the listed ones (e.g., `usize`), they can be accessed directly using the string `"usize"`, or they can be accessed through predefined types such as `scallopy.usize`. +The example below defines a relation of type `(usize, f64, i32)`: + +``` py +ctx.add_relation("my_relation", (scallopy.usize, "f64", int)) +``` + +Specifically for arity-1 relations, users don't need to use a tuple to specify the type. +For instance, + +``` py +ctx.add_relation("digit", int) +``` + +### Configuring Relations + +#### `non_probabilistic` + +## Adding Facts + +The most basic version of adding facts into an existing relation inside of an existing context. +We are assuming that the context has a provenance of `"unit"`. + +``` py +ctx.add_facts("edge", [(1, 2), (2, 3)]) +``` + +## Adding Rules + +### Tagged Rules + +## Running + +## Additional Features + +There are more features provided by the `ScallopContext` interface. +We hereby list them for reference. + +### Cloning + +One can copy a context to create a new context. +The resulting context will contain all the program, configurations, and provenance information. + +``` py +new_ctx = ctx.clone() +``` + +The cloning feature relates to pseudo-incremental computation and branching computation. 
+We elaborate on this in the [Branching Computation](scallopy/branching.md) section. + +### Compiling + +### Iteration Count Limit + +One can configure the + +### Early Discarding + +### Obtaining Context Information + +### Foreign Functions and Predicates + +### Saving and Loading + +Please refer to the [Saving and Loading](scallopy/save_and_load.md) section for more information. diff --git a/doc/src/scallopy/foreign_function.md b/doc/src/scallopy/foreign_function.md new file mode 100644 index 0000000..038c5bb --- /dev/null +++ b/doc/src/scallopy/foreign_function.md @@ -0,0 +1,61 @@ +# Foreign Functions + +While there are existing [foreign functions](../language/foreign_functions.md) such as `$hash` and `$abs`, people sometimes want more functions to be included for specialized computation. +`scallopy` provides such interface and allows user to define foreign functions in Python. +Here is an example defining a custom `$sum` function in Python which is later used in Scallop: + +``` py +# Create a new foreign function by annotating an existing function with `@scallopy.foreign_function` +# Note that this function has variable arguments! +@scallopy.foreign_function +def my_sum(*args: int) -> int: + s = 0 + for x in args: + s += x + return s + +# Create a context +ctx = scallopy.ScallopContext() + +# Register the declared foreign function (`my_sum`) +# Note that the function needs to be registered before it is used +ctx.register_foreign_function(my_sum) + +# Add some relations +ctx.add_relation("I", (int, int)) +ctx.add_facts("I", [(1, 2), (2, 3), (3, 4)]) + +# Add a rule which uses the registered function! +ctx.add_rule("R($my_sum(a, b)) = I(a, b)") + +# Run the context +ctx.run() + +# See the result, should be [(3,), (5,), (7,)] +print(list(ctx.relation("R"))) +``` + +Now we elaborate on how we define new foreign functions in Python. 
+ +## Function Signature + +The annotator `@scallopy.foreign_function` performs analysis of the annotated Python function and makes sure that it is accepted as a Scallop foreign function. +We require that types are annotated on all arguments and the return value. +For simplicity, Python types such as `int`, `bool`, and `str` are mapped to Scallop types (and type families) as following: + +| Python type | Scallop type | Scallop base types | +|-------------|--------------|--------------------| +| `int` | `Integer` family | `i8`, `i16`, ..., `u8`, `u16`, ..., `usize` | +| `float` | `Float` family | `f32`, `f64` | +| `bool` | `bool` | `bool` | +| `str` | `String` | `String` | + +If one desires to use a more fine-grained type + +## Argument Types + +## Optional Arguments + +## Variable Arguments + +## Error Handling diff --git a/doc/src/scallopy/foreign_predicate.md b/doc/src/scallopy/foreign_predicate.md new file mode 100644 index 0000000..e69de29 diff --git a/doc/src/scallopy/getting_started.md b/doc/src/scallopy/getting_started.md new file mode 100644 index 0000000..e69de29 diff --git a/doc/src/scallopy/index.md b/doc/src/scallopy/index.md index e69de29..a0f55e4 100644 --- a/doc/src/scallopy/index.md +++ b/doc/src/scallopy/index.md @@ -0,0 +1,14 @@ +# The Scallop Python Binding `scallopy` + +`scallopy` is the Python binding for Scallop, offering an interface for computationg with Scallop in Python. +In addition, it can be integrated with [PyTorch](https://pytorch.org), allowing users to write Neuro-Symbolic applications that can be connected to PyTorch. +In this section, we elaborate on how to install, configure, and use the `scallopy` library. + +For an example, please look at [Getting Started](getting_started.md). 
+To start reading the documentation, proceed to [Scallopy Context](context.md) + +## Installation + +### TODO: Installation with venv + +### TODO: Installation with Conda diff --git a/doc/src/scallopy/module.md b/doc/src/scallopy/module.md new file mode 100644 index 0000000..e69de29 diff --git a/doc/src/scallopy/module_input.md b/doc/src/scallopy/module_input.md new file mode 100644 index 0000000..e69de29 diff --git a/doc/src/scallopy/module_output.md b/doc/src/scallopy/module_output.md new file mode 100644 index 0000000..e69de29 diff --git a/doc/src/scallopy/save_and_load.md b/doc/src/scallopy/save_and_load.md new file mode 100644 index 0000000..e69de29 diff --git a/doc/src/scallopy/types.md b/doc/src/scallopy/types.md new file mode 100644 index 0000000..e69de29 diff --git a/doc/src/summary.md b/doc/src/summary.md index 34c6dd5..f39844a 100644 --- a/doc/src/summary.md +++ b/doc/src/summary.md @@ -7,30 +7,55 @@ - [Installation](installation.md) - [Crash Course](crash_course.md) -# Reference Guide +# Language Reference Guide - [Scallop and Logic Programming](language/index.md) - - [Types, Relations, and Facts](language/facts.md) - - [Writing Simple Rules](language/rules.md) + - [Relations and Facts](language/relation.md) + - [Writing Rules](language/rules.md) + - [Values and Types](language/value_type.md) - [Writing a Query](language/query.md) - [Recursive Rules](language/recursion.md) - - [Rules with Negations](language/negation.md) - - [Rules with Aggregations](language/aggregation.md) + - [Negations](language/negation.md) + - [Aggregations](language/aggregation.md) + - [Declaring Constants](language/constants.md) + - [Algebraic Data Type and Entities](language/adt_and_entity.md) + - [Loading from CSV](language/loading_csv.md) - [Foreign Functions](language/foreign_functions.md) - [Foreign Predicates](language/foreign_predicates.md) -- [Scallop and Probabilistic Programming](probabilistic/index.md) + - [Magic-Set Transformation](language/magic_set.md) + - 
[Reference Guide](language/reference_guide.md) +- [Provenance and Probabilistic Programming](probabilistic/index.md) + - [Provenance](probabilistic/provenance.md) + - [Proofs Provenance](probabilistic/proofs.md) - [Fact with Probability](probabilistic/facts.md) - [Logic and Probability](probabilistic/logic.md) + - [Provenance Library](probabilistic/library.md) - [Aggregation and Probability](probabilistic/reasoning.md) - [Sampling with Probability](probabilistic/sampling.md) - - [Tags and Provenance](language/provenance.md) -- [Python and Scallop](scallopy/index.md) - - [Creating Context](scallopy/context.md) +- [`scallopy`](scallopy/index.md) + - [Getting Started](scallopy/getting_started.md) + - [Scallop Context](scallopy/context.md) - [Branching Executions](scallopy/branching.md) - [Configuring Provenance](scallopy/provenance.md) + - [Creating Module](scallopy/module.md) + - [Configuring Input Relations](scallopy/module_input.md) + - [Configuring Output Relations](scallopy/module_output.md) + - [Foreign Functions](scallopy/foreign_function.md) + - [Foreign Predicate](scallopy/foreign_predicate.md) + - [Saving and Loading](scallopy/save_and_load.md) + +# Toolchain + +- [Scallop CLI](toolchain/cli.md) +- [Scallop Interpreter](toolchain/scli.md) +- [Scallop REPL](toolchain/sclrepl.md) +- [Scallop Compiler](toolchain/sclc.md) + +# For Developers + - [For Developers](developer/index.md) - - [New Language Construct](developer/language_construct.md) - - [New Binding](developer/binding.md) +- [New Language Construct](developer/language_construct.md) +- [New Binding](developer/binding.md) # Resources diff --git a/doc/src/toolchain/cli.md b/doc/src/toolchain/cli.md new file mode 100644 index 0000000..602a4a0 --- /dev/null +++ b/doc/src/toolchain/cli.md @@ -0,0 +1 @@ +# Scallop CLI diff --git a/doc/src/toolchain/repl.md b/doc/src/toolchain/repl.md new file mode 100644 index 0000000..679e93c --- /dev/null +++ b/doc/src/toolchain/repl.md @@ -0,0 +1 @@ +# Scallop REPL diff 
--git a/doc/src/toolchain/sclc.md b/doc/src/toolchain/sclc.md new file mode 100644 index 0000000..8109f64 --- /dev/null +++ b/doc/src/toolchain/sclc.md @@ -0,0 +1 @@ +# Scallop Compiler diff --git a/doc/src/toolchain/scli.md b/doc/src/toolchain/scli.md new file mode 100644 index 0000000..6683aeb --- /dev/null +++ b/doc/src/toolchain/scli.md @@ -0,0 +1 @@ +# Scallop Interpreter diff --git a/doc/src/toolchain/sclrepl.md b/doc/src/toolchain/sclrepl.md new file mode 100644 index 0000000..679e93c --- /dev/null +++ b/doc/src/toolchain/sclrepl.md @@ -0,0 +1 @@ +# Scallop REPL diff --git a/etc/codegen/Cargo.toml b/etc/codegen/Cargo.toml index fe08147..5851a0a 100644 --- a/etc/codegen/Cargo.toml +++ b/etc/codegen/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "scallop-codegen" -version = "0.1.9" +version = "0.2.0" authors = ["Ziyang Li "] edition = "2018" diff --git a/etc/codegen/src/lib.rs b/etc/codegen/src/lib.rs index be7ddac..15cb635 100644 --- a/etc/codegen/src/lib.rs +++ b/etc/codegen/src/lib.rs @@ -13,10 +13,8 @@ pub fn scallop(tokens: TokenStream) -> TokenStream { let ram = match compiler::compile_source_to_ram(src) { Ok(ram) => ram, Err(errs) => { - for err in errs { - println!("{}", err); - } - return quote! {}.into(); + let all_errs = errs.iter().map(|err| err.to_string()).collect::>().join("\n"); + return quote! 
{ compile_error!(#all_errs); }.into(); } }; diff --git a/etc/scallop-cli/.gitignore b/etc/scallop-cli/.gitignore new file mode 100644 index 0000000..2fc8ce3 --- /dev/null +++ b/etc/scallop-cli/.gitignore @@ -0,0 +1,3 @@ +build +dist +*.egg-info diff --git a/etc/scallop-cli/examples/llm/generate_name.scl b/etc/scallop-cli/examples/llm/generate_name.scl new file mode 100644 index 0000000..435b126 --- /dev/null +++ b/etc/scallop-cli/examples/llm/generate_name.scl @@ -0,0 +1,6 @@ +@gpt_chat_complete(prompt="A typical {lang} name is {n}", pattern="bf") +type generate_name(lang: String, n: String) + +rel languages = {"english", "chinese", "japanese", "french", "germany"} + +rel sampled_name(l, n) = languages(l) and generate_name(l, n) diff --git a/etc/scallop-cli/examples/llm/kinship_composition.scl b/etc/scallop-cli/examples/llm/kinship_composition.scl new file mode 100644 index 0000000..83f9e31 --- /dev/null +++ b/etc/scallop-cli/examples/llm/kinship_composition.scl @@ -0,0 +1,4 @@ +@gpt_complete(prompt="{x}'s {y} is {z}", pattern="bbf") +type composition(x: String, y: String, z: String) + +rel result($string_lower(z)) = composition("father", "mother", z) diff --git a/etc/scallop-cli/examples/llm/mountain_height.scl b/etc/scallop-cli/examples/llm/mountain_height.scl new file mode 100644 index 0000000..97a2fc3 --- /dev/null +++ b/etc/scallop-cli/examples/llm/mountain_height.scl @@ -0,0 +1,8 @@ +@gpt_complete(prompt="The mountain {x}'s height is {y} meters", pattern="bf") +type mountain_height(x: String, y: String) + +rel mountains = {"Mount Everest", "K2"} + +rel result(x, y) = mountains(x), mountain_height(x, y) + +query result diff --git a/etc/scallop-cli/examples/llm/mountain_height_multi.scl b/etc/scallop-cli/examples/llm/mountain_height_multi.scl new file mode 100644 index 0000000..57512f4 --- /dev/null +++ b/etc/scallop-cli/examples/llm/mountain_height_multi.scl @@ -0,0 +1,8 @@ +@gpt_complete(prompt="The mountain {x}'s height is {y} meters or {z} feet?", 
pattern="bff") +type mountain_height(x: String, y: f32, z: f32) + +rel mountains = {"Mount Everest", "K2"} + +rel result(x, y, z) = mountains(x), mountain_height(x, y, z) + +query result diff --git a/etc/scallop-cli/examples/llm/qa_with_function.scl b/etc/scallop-cli/examples/llm/qa_with_function.scl new file mode 100644 index 0000000..5bb1d7d --- /dev/null +++ b/etc/scallop-cli/examples/llm/qa_with_function.scl @@ -0,0 +1,8 @@ +rel questions = { + (1, "what is the height of highest mountain in the world?"), + (2, "are cats larger than dogs?"), +} + +rel answer(id, $gpt_complete(x)) = questions(id, x) + +query answer diff --git a/etc/scallop-cli/examples/llm/qa_with_predicate.scl b/etc/scallop-cli/examples/llm/qa_with_predicate.scl new file mode 100644 index 0000000..5bf2b77 --- /dev/null +++ b/etc/scallop-cli/examples/llm/qa_with_predicate.scl @@ -0,0 +1,8 @@ +rel questions = { + (1, "what is the height of highest mountain in the world?"), + (2, "are cats larger than dogs?"), +} + +rel answer(id, y) = questions(id, x), gpt_complete(x, y) + +query answer diff --git a/etc/scallop-cli/readme.md b/etc/scallop-cli/readme.md new file mode 100644 index 0000000..e69de29 diff --git a/etc/scallop-cli/scallop/__init__.py b/etc/scallop-cli/scallop/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/etc/scallop-cli/scallop/cli.py b/etc/scallop-cli/scallop/cli.py new file mode 100644 index 0000000..8ed3a04 --- /dev/null +++ b/etc/scallop-cli/scallop/cli.py @@ -0,0 +1,57 @@ +# Argument parser +import argparse + +# Scallop imports +import scallopy + +# Project imports +from . import config +from . 
import stdlib + + +def argument_parser(): + parser = argparse.ArgumentParser("scallop", description="Scallop language command line interface") + parser.add_argument("file", nargs="?", default=None, help="The file to execute") + parser.add_argument("-p", "--provenance", type=str, default="unit", help="The provenance to pick") + parser.add_argument("-m", "--module", type=str, default=None, help="Load module in interactive mode") + parser.add_argument("--iter-limit", type=int, default=10, help="Iteration limit") + parser.add_argument("--num-allowed-openai-request", type=int, default=100, help="Limit on the number of openai calls") + parser.add_argument("--openai-gpt-model", type=str, default="text-davinci-003", help="The GPT model we use") + parser.add_argument("--openai-gpt-temperature", type=float, default=0, help="The temperature for the GPT model") + return parser + + +def cmd_args(): + parser = argument_parser() + return parser.parse_args() + + +def main(): + # Parse command line arguments + args = cmd_args() + + # Configure environments + config.configure(args) + + # Create a scallopy context + ctx = scallopy.ScallopContext(provenance=args.provenance) + ctx.set_iter_limit(args.iter_limit) + + # Load the stdlib + stdlib.load_stdlib(ctx) + + # Check if the user has provided a file + if args.file is not None: + # When file is available, Import the target file into the context + ctx.import_file(args.file) + + # Run the context + ctx.run() + + # Print the results + for relation in ctx.relations(): + if ctx.has_relation(relation): + print(f"{relation}\t:\t{list(ctx.relation(relation))}") + + else: + raise NotImplementedError() diff --git a/etc/scallop-cli/scallop/config/__init__.py b/etc/scallop-cli/scallop/config/__init__.py new file mode 100644 index 0000000..9fdac37 --- /dev/null +++ b/etc/scallop-cli/scallop/config/__init__.py @@ -0,0 +1 @@ +from .config import configure diff --git a/etc/scallop-cli/scallop/config/config.py 
b/etc/scallop-cli/scallop/config/config.py new file mode 100644 index 0000000..63b2b2a --- /dev/null +++ b/etc/scallop-cli/scallop/config/config.py @@ -0,0 +1,5 @@ +from . import openai + + +def configure(args): + openai.configure_openai(args) diff --git a/etc/scallop-cli/scallop/config/openai.py b/etc/scallop-cli/scallop/config/openai.py new file mode 100644 index 0000000..b0c7ef1 --- /dev/null +++ b/etc/scallop-cli/scallop/config/openai.py @@ -0,0 +1,45 @@ +import openai +import os + +# Whether the openai plugin has been configured +CONFIGURED = False + +# Number of allowed requests +NUM_ALLOWED_REQUESTS = 0 + +# Number of already performed requests +NUM_PERFORMED_REQUESTS = 0 + +# Temprature of GPT model +TEMPERATURE = 0.0 + +# The GPT model to use +MODEL = None + +def configure_openai(args): + global CONFIGURED + global NUM_ALLOWED_REQUESTS + global NUM_PERFORMED_REQUESTS + global TEMPERATURE + global MODEL + + # Open API + api_key = os.getenv("OPENAI_API_KEY") + if api_key is None: + return + + # Is configured + CONFIGURED = True + + # Set the API Key + openai.api_key = api_key + + # Set request limit + NUM_ALLOWED_REQUESTS = args.num_allowed_openai_request + NUM_PERFORMED_REQUESTS = 0 + + # Set model + MODEL = args.openai_gpt_model + + # Set temperature + TEMPERATURE = args.openai_gpt_temperature diff --git a/etc/scallop-cli/scallop/stdlib/__init__.py b/etc/scallop-cli/scallop/stdlib/__init__.py new file mode 100644 index 0000000..6405b84 --- /dev/null +++ b/etc/scallop-cli/scallop/stdlib/__init__.py @@ -0,0 +1,16 @@ +from scallopy import ScallopContext + +from . import ff +from . import fp +from . 
import attr + +def load_stdlib(scallop_ctx: ScallopContext): + # Register foreign functions + scallop_ctx.register_foreign_function(ff.gpt_complete) + + # Register foreign predicates + scallop_ctx.register_foreign_predicate(fp.gpt_complete) + + # Register foreign attributes + scallop_ctx.register_foreign_attribute(attr.gpt_complete) + scallop_ctx.register_foreign_attribute(attr.gpt_chat_complete) diff --git a/etc/scallop-cli/scallop/stdlib/attr/__init__.py b/etc/scallop-cli/scallop/stdlib/attr/__init__.py new file mode 100644 index 0000000..4f2352d --- /dev/null +++ b/etc/scallop-cli/scallop/stdlib/attr/__init__.py @@ -0,0 +1,2 @@ +from .gpt_complete import gpt_complete +from .gpt_chat import gpt_chat_complete diff --git a/etc/scallop-cli/scallop/stdlib/attr/gpt_chat.py b/etc/scallop-cli/scallop/stdlib/attr/gpt_chat.py new file mode 100644 index 0000000..e9fbcdb --- /dev/null +++ b/etc/scallop-cli/scallop/stdlib/attr/gpt_chat.py @@ -0,0 +1,179 @@ +from typing import * +import re + +import openai + +import scallopy + +from ...config import openai as openai_config + +@scallopy.foreign_attribute +def gpt_chat_complete( + item, + *, + prompt: str, + pattern: str, + examples: List[List[str]] = [], + debug: bool = False, +): + # Check if the annotation is on relation type decl + if not ("TypeDecl" in item and "Relation" in item["TypeDecl"]["node"]): + raise Exception("`gpt` has to be an attribute of a relation type declaration") + + # Get the type decl + relation_type_decls = item["TypeDecl"]["node"]["Relation"]["node"]["rel_types"] + if len(relation_type_decls) > 1: + raise Exception("`gpt` cannot be an attribute on multiple relations") + relation_type_decl = relation_type_decls[0]["node"] + + # Get the relation name and argument types + name = relation_type_decl["name"]["node"]["name"] + arg_types = [(arg_type["node"]["name"]["node"]["name"], arg_type["node"]["ty"]["node"]) for arg_type in relation_type_decl["arg_types"]] + + # Get the Pattern + regex_match = 
re.match("^(b*)(f+)$", pattern) + if regex_match is None: + raise Exception("`gpt` pattern must start with b (optional) and ending with f (required)") + + # Check if the pattern and the arg types match + if len(arg_types) != len(pattern): + raise Exception("`gpt` pattern must have the same length as the number of arguments") + + # Compute the number of `b`ounds and the number of `f`rees. + num_bounded = len(regex_match[1]) + num_free = len(regex_match[2]) + + # Make sure that the types are good + if not all([ty == "String" for (_, ty) in arg_types[:num_bounded]]): + raise Exception("`gpt` annotation requires all input arguments to have `String` type") + + # The storage is special per foreign predicate + STORAGE = {} + + # Main function to Invoke the gpt + def invoke_gpt(*args): + assert len(args) == num_bounded + + # Deal with STORAGE; check if the response is memoized + storage_key = tuple(args) + if storage_key in STORAGE: + responses = STORAGE[storage_key] + + # Check if openai API is configured + elif not openai_config.CONFIGURED: + raise Exception("Open AI Plugin not configured; consider setting OPENAI_API_KEY") + + # Check openai request + elif openai_config.NUM_PERFORMED_REQUESTS > openai_config.NUM_ALLOWED_REQUESTS: + raise Exception("Exceeding allowed number of requests") + + # Need to do a new request + else: + # Fill the prompt with the inputs; replace free variables with the BLANK + filled_prompt = fill_prompt(prompt, args, arg_types, num_bounded) + + # Create a request to openai gpt + system_ctx, messages = fill_template([filled_prompt]) + _, current_conversation = query_gpt_completion(system_ctx, messages, openai_config.MODEL) + responses = extract_responses(current_conversation) + + # Debug print + if debug: + print(f"Prompt: {messages}") + print(f"Responses: {responses}") + + # Store the response + STORAGE[storage_key] = responses + + # Return choices + for response in responses: + tup = parse_choice_text(response.strip(), arg_types, num_bounded) + 
yield tup + + # Generate the foreign predicate + foreign_predicate = scallopy.ForeignPredicate( + invoke_gpt, + name, + input_arg_types=[scallopy.predicate.Type(arg_ty) for (_, arg_ty) in arg_types[:num_bounded]], + output_arg_types=[scallopy.predicate.Type(arg_ty) for (_, arg_ty) in arg_types[num_bounded:]], + tag_type=None, + ) + + # Remove the item and register a foreign predicate + return scallopy.attribute.MultipleActions([ + scallopy.attribute.RegisterForeignPredicateAction(foreign_predicate), + scallopy.attribute.RemoveItemAction(), + ]) + + +def fill_prompt(prompt: str, args: List[str], arg_types: List[Tuple[str, str]], num_bounded: int): + arg_patterns = ["{" + arg_name + "}" for (arg_name, _) in arg_types] + free_args = ["" for (arg_name, _) in arg_types[num_bounded:]] + filled_prompt = prompt + for (arg_pattern, fill) in zip(arg_patterns[:num_bounded], args): + filled_prompt = filled_prompt.replace(arg_pattern, str(fill)) + for (arg_pattern, fill) in zip(arg_patterns[num_bounded:], free_args): + filled_prompt = filled_prompt.replace(arg_pattern, str(fill)) + return filled_prompt + + +def parse_choice_text(text: str, arg_types: List[Tuple[str, str]], num_bounded: int): + ret_arg_regexes = [re.compile(f"{arg_name}\s*=\s*(.+)$\n?", re.MULTILINE) for (arg_name, _) in arg_types[num_bounded:]] + matches = [next(iter(ret_arg_regex.finditer(text))) for ret_arg_regex in ret_arg_regexes] + answers = ["" if match is None else parse_value(match[1], arg_ty) for (match, (_, arg_ty)) in zip(matches, arg_types[num_bounded:])] + return tuple(answers) + + +def parse_value(text: str, arg_type: str) -> Any: + if arg_type == "F32" or arg_type == "F64": + return float(text) + elif arg_type == "I8" or arg_type == "I16" or arg_type == "I32" or arg_type == "I64" or arg_type == "ISize" or \ + arg_type == "U8" or arg_type == "U16" or arg_type == "U32" or arg_type == "U64" or arg_type == "USize": + return int(float(text)) + elif arg_type == "Bool": + if text == "true" or text == 
"True": return True + elif text == "false" or text == "False": return False + else: return False + elif arg_type == "String": + return text + else: + raise NotImplementedError() + +def query_gpt_completion(system_ctx, messages, model): + current_conversation = [system_ctx] + responses = [] + for message in messages: + current_conversation.append(message) + response = openai.ChatCompletion.create( + model=model, + messages=current_conversation, + temperature=0) + responses.append(response) + current_conversation.append(response['choices'][0]['message']) + response_pairs = (responses, current_conversation) + return response_pairs + +def fill_template(methodology_sub_texts): + template_text = ''' Answer in the format of new-line separated \"BLANK_VAR = ANSWER\". + Here is one example: + Father's mother is . + BLANK_n = grandmother + + Another example is, + The mountain Everest is of meters or feets. + BLANK_n = 8848 + BLANK_m = 29032 + + Please fill in the blank(s): + ''' + system_ctx = {"role": "system", "content": f"You are a knowledgable assistant. 
"} + messages = [{"role": "user", "content": template_text + text} for text in methodology_sub_texts] + + return system_ctx, messages + +def extract_responses(conversation): + gpt_responses = [] + for dialogue in conversation: + if dialogue['role'] == 'assistant': + gpt_responses.append(dialogue['content']) + return gpt_responses diff --git a/etc/scallop-cli/scallop/stdlib/attr/gpt_complete.py b/etc/scallop-cli/scallop/stdlib/attr/gpt_complete.py new file mode 100644 index 0000000..a10a42b --- /dev/null +++ b/etc/scallop-cli/scallop/stdlib/attr/gpt_complete.py @@ -0,0 +1,158 @@ +from typing import * +import re + +import openai + +import scallopy + +from ...config import openai as openai_config + +@scallopy.foreign_attribute +def gpt_complete( + item, + *, + prompt: str, + pattern: str, + examples: List[List[str]] = [], + debug: bool = False, +): + # Check if the annotation is on relation type decl + if not ("TypeDecl" in item and "Relation" in item["TypeDecl"]["node"]): + raise Exception("`gpt` has to be an attribute of a relation type declaration") + + # Get the type decl + relation_type_decls = item["TypeDecl"]["node"]["Relation"]["node"]["rel_types"] + if len(relation_type_decls) > 1: + raise Exception("`gpt` cannot be an attribute on multiple relations") + relation_type_decl = relation_type_decls[0]["node"] + + # Get the relation name and argument types + name = relation_type_decl["name"]["node"]["name"] + arg_types = [(arg_type["node"]["name"]["node"]["name"], arg_type["node"]["ty"]["node"]) for arg_type in relation_type_decl["arg_types"]] + + # Get the Pattern + regex_match = re.match("^(b*)(f+)$", pattern) + if regex_match is None: + raise Exception("`gpt` pattern must start with b (optional) and ending with f (required)") + + # Check if the pattern and the arg types match + if len(arg_types) != len(pattern): + raise Exception("`gpt` pattern must have the same length as the number of arguments") + + # Compute the number of `b`ounds and the number of `f`rees. 
+ num_bounded = len(regex_match[1]) + num_free = len(regex_match[2]) + + # Make sure that the types are good + if not all([ty == "String" for (_, ty) in arg_types[:num_bounded]]): + raise Exception("`gpt` annotation requires all input arguments to have `String` type") + + # The storage is special per foreign predicate + STORAGE = {} + + # Main function to Invoke the gpt + def invoke_gpt(*args): + assert len(args) == num_bounded + + # Deal with STORAGE; check if the response is memoized + storage_key = tuple(args) + if storage_key in STORAGE: + response = STORAGE[storage_key] + + # Check if openai API is configured + elif not openai_config.CONFIGURED: + raise Exception("Open AI Plugin not configured; consider setting OPENAI_API_KEY") + + # Check openai request + elif openai_config.NUM_PERFORMED_REQUESTS > openai_config.NUM_ALLOWED_REQUESTS: + raise Exception("Exceeding allowed number of requests") + + # Need to do a new request + else: + # Fill the prompt with the inputs; replace free variables with the BLANK + filled_prompt = fill_prompt(prompt, args, arg_types, num_bounded) + + prompt_header = ''' Answer in the format of new-line separated \"BLANK_VAR = ANSWER\". + Here is one example: + Father's mother is . + BLANK_n = grandmother + + Another example is, + The mountain Everest is of meters or feets. + BLANK_n = 8848 + BLANK_m = 29032 + + Please fill in the blank(s): + ''' + + # Create a full prompt which is "fill in the blank" + full_prompt = f"{prompt_header}\n{filled_prompt}." 
+ + # Create a request to openai gpt + response = openai.Completion.create( + model=openai_config.MODEL, + prompt=full_prompt, + temperature=openai_config.TEMPERATURE) + + # Debug print + if debug: + print(f"Prompt: {full_prompt}") + print(f"Response: {response}") + + # Store the response + STORAGE[storage_key] = response + + # Return choices + for choice in response["choices"]: + choice_text = choice["text"].strip() + tup = parse_choice_text(choice_text, arg_types, num_bounded) + yield tup + + # Generate the foreign predicate + foreign_predicate = scallopy.ForeignPredicate( + invoke_gpt, + name, + input_arg_types=[scallopy.predicate.Type(arg_ty) for (_, arg_ty) in arg_types[:num_bounded]], + output_arg_types=[scallopy.predicate.Type(arg_ty) for (_, arg_ty) in arg_types[num_bounded:]], + tag_type=None, + ) + + # Remove the item and register a foreign predicate + return scallopy.attribute.MultipleActions([ + scallopy.attribute.RegisterForeignPredicateAction(foreign_predicate), + scallopy.attribute.RemoveItemAction(), + ]) + + +def fill_prompt(prompt: str, args: List[str], arg_types: List[Tuple[str, str]], num_bounded: int): + arg_patterns = ["{" + arg_name + "}" for (arg_name, _) in arg_types] + free_args = ["" for (arg_name, _) in arg_types[num_bounded:]] + filled_prompt = prompt + for (arg_pattern, fill) in zip(arg_patterns[:num_bounded], args): + filled_prompt = filled_prompt.replace(arg_pattern, str(fill)) + for (arg_pattern, fill) in zip(arg_patterns[num_bounded:], free_args): + filled_prompt = filled_prompt.replace(arg_pattern, str(fill)) + return filled_prompt + + +def parse_choice_text(text: str, arg_types: List[Tuple[str, str]], num_bounded: int): + ret_arg_regexes = [re.compile(f"{arg_name}\s*=\s*(.+)$\n?", re.MULTILINE) for (arg_name, _) in arg_types[num_bounded:]] + matches = [next(iter(ret_arg_regex.finditer(text))) for ret_arg_regex in ret_arg_regexes] + answers = ["" if match is None else parse_value(match[1], arg_ty) for (match, (_, arg_ty)) in 
zip(matches, arg_types[num_bounded:])] + return tuple(answers) + + +def parse_value(text: str, arg_type: str) -> Any: + if arg_type == "F32" or arg_type == "F64": + return float(text) + elif arg_type == "I8" or arg_type == "I16" or arg_type == "I32" or arg_type == "I64" or arg_type == "ISize" or \ + arg_type == "U8" or arg_type == "U16" or arg_type == "U32" or arg_type == "U64" or arg_type == "USize": + return int(float(text)) + elif arg_type == "Bool": + if text == "true" or text == "True": return True + elif text == "false" or text == "False": return False + else: return False + elif arg_type == "String": + return text + else: + raise NotImplementedError() diff --git a/etc/scallop-cli/scallop/stdlib/ff/__init__.py b/etc/scallop-cli/scallop/stdlib/ff/__init__.py new file mode 100644 index 0000000..9a178ed --- /dev/null +++ b/etc/scallop-cli/scallop/stdlib/ff/__init__.py @@ -0,0 +1 @@ +from .gpt_complete import gpt_complete diff --git a/etc/scallop-cli/scallop/stdlib/ff/gpt_complete.py b/etc/scallop-cli/scallop/stdlib/ff/gpt_complete.py new file mode 100644 index 0000000..2a08085 --- /dev/null +++ b/etc/scallop-cli/scallop/stdlib/ff/gpt_complete.py @@ -0,0 +1,27 @@ +import openai +import scallopy + +from ...config import openai as openai_config + +# For memoization +STORAGE = {} + + +@scallopy.foreign_function +def gpt_complete(s: str) -> str: + if s in STORAGE: + return STORAGE[s] + elif not openai_config.CONFIGURED: + raise Exception("Open AI Plugin not configured; consider setting OPENAI_API_KEY") + elif openai_config.NUM_PERFORMED_REQUESTS > openai_config.NUM_ALLOWED_REQUESTS: + raise Exception("Exceeding allowed number of requests") + else: + openai_config.NUM_PERFORMED_REQUESTS += 1 + response = openai.Completion.create( + model=openai_config.MODEL, + prompt=s, + temperature=openai_config.TEMPERATURE) + choice = response["choices"][0] + result = choice["text"].strip() + STORAGE[s] = result + return result diff --git a/etc/scallop-cli/scallop/stdlib/fp/__init__.py 
b/etc/scallop-cli/scallop/stdlib/fp/__init__.py new file mode 100644 index 0000000..9a178ed --- /dev/null +++ b/etc/scallop-cli/scallop/stdlib/fp/__init__.py @@ -0,0 +1 @@ +from .gpt_complete import gpt_complete diff --git a/etc/scallop-cli/scallop/stdlib/fp/gpt_complete.py b/etc/scallop-cli/scallop/stdlib/fp/gpt_complete.py new file mode 100644 index 0000000..70b682d --- /dev/null +++ b/etc/scallop-cli/scallop/stdlib/fp/gpt_complete.py @@ -0,0 +1,32 @@ +from typing import Tuple + +import openai +import scallopy + +from ...config import openai as openai_config + +STORAGE = {} + +@scallopy.foreign_predicate +def gpt_complete(s: str) -> scallopy.Generator[None, str]: + # Check if the storage already contains the response + if s in STORAGE: + response = STORAGE[s] + else: + if not openai_config.CONFIGURED: + raise Exception("Open AI Plugin not configured; consider setting OPENAI_API_KEY") + elif openai_config.NUM_PERFORMED_REQUESTS > openai_config.NUM_ALLOWED_REQUESTS: + raise Exception("Exceeding allowed number of requests") + else: + # Memoize the response + openai_config.NUM_PERFORMED_REQUESTS += 1 + response = openai.Completion.create( + model=openai_config.MODEL, + prompt=s, + temperature=openai_config.TEMPERATURE) + STORAGE[s] = response + + # Iterate through all the choices + for choice in response["choices"]: + result = choice["text"].strip() + yield (result,) diff --git a/etc/scallop-cli/setup.py b/etc/scallop-cli/setup.py new file mode 100644 index 0000000..0086768 --- /dev/null +++ b/etc/scallop-cli/setup.py @@ -0,0 +1,22 @@ +from distutils.core import setup + +setup( + name='scallop', + version='0.2.0', + packages=[ + 'scallop', + 'scallop.config', + 'scallop.stdlib', + 'scallop.stdlib.ff', + 'scallop.stdlib.fp', + 'scallop.stdlib.attr', + ], + requires=[ + 'scallopy' + ], + entry_points = { + 'console_scripts': [ + 'scallop=scallop.cli:main' + ], + } +) diff --git a/etc/scallop-wasm/Cargo.toml b/etc/scallop-wasm/Cargo.toml index 5e09a0b..59cae89 100644 
--- a/etc/scallop-wasm/Cargo.toml +++ b/etc/scallop-wasm/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "scallop-wasm" -version = "0.1.9" +version = "0.2.0" authors = ["Ziyang Li"] edition = "2018" diff --git a/etc/scallopy/Cargo.toml b/etc/scallopy/Cargo.toml index bdc9f42..69c4f76 100644 --- a/etc/scallopy/Cargo.toml +++ b/etc/scallopy/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "scallopy" -version = "0.1.9" +version = "0.2.0" edition = "2018" [lib] @@ -11,7 +11,13 @@ crate-type = ["cdylib"] scallop-core = { path = "../../core" } sclc-core = { path = "../sclc" } rayon = "1.5" +tch = { features = ["python-extension"], version = "0.13.0", optional = true } +serde = { version = "1.0", features = ["derive"] } +pythonize = "0.18" [dependencies.pyo3] version = "0.18.2" features = ["extension-module"] + +[features] +torch-tensor = ["scallop-core/torch-tensor", "dep:tch"] diff --git a/etc/scallopy/examples/foreign_predicate.py b/etc/scallopy/examples/foreign_predicate.py index 166a385..15c3974 100644 --- a/etc/scallopy/examples/foreign_predicate.py +++ b/etc/scallopy/examples/foreign_predicate.py @@ -23,4 +23,6 @@ def string_semantic_eq(s1: str, s2: str) -> Generator[float, Tuple]: ctx.add_rule("parent(a, b) = kinship(a, r, b) and string_semantic_eq(r, \"mother\")") ctx.add_rule("sibling(a, b) = parent(c, a) and parent(c, b) and a != b") ctx.run() -print(list(ctx.relation("sibling"))) +print("kinship", list(ctx.relation("kinship"))) +print("sibling", list(ctx.relation("sibling"))) +print("parent", list(ctx.relation("parent"))) diff --git a/etc/scallopy/makefile b/etc/scallopy/makefile index 01202c3..685c7a5 100644 --- a/etc/scallopy/makefile +++ b/etc/scallopy/makefile @@ -1,6 +1,9 @@ -all: +develop-scallopy: maturin develop --release +develop-scallopy-plus: + maturin develop --release --features "torch-tensor" + test: python3 tests/test.py diff --git a/etc/scallopy/scallopy/__init__.py b/etc/scallopy/scallopy/__init__.py index 8f5e957..f74ca49 100644 --- 
a/etc/scallopy/scallopy/__init__.py +++ b/etc/scallopy/scallopy/__init__.py @@ -1,10 +1,12 @@ from .context import ScallopContext from .forward import ScallopForwardFunction from .provenance import ScallopProvenance -from .function import GenericTypeParameter, foreign_function -from .predicate import Generator, foreign_predicate +from .function import GenericTypeParameter, foreign_function, ForeignFunction +from .predicate import Generator, foreign_predicate, ForeignPredicate from .types import * from .input_mapping import InputMapping +from .scallopy import torch_tensor_enabled +from .attribute import foreign_attribute, ForeignAttributeProcessor # Provide a few aliases Context = ScallopContext diff --git a/etc/scallopy/scallopy/attribute.py b/etc/scallopy/scallopy/attribute.py new file mode 100644 index 0000000..b387c37 --- /dev/null +++ b/etc/scallopy/scallopy/attribute.py @@ -0,0 +1,100 @@ +from typing import * + +from . import predicate + + +class AttributeAction: + def __init__(self): + self.name = None + + +class MultipleActions(AttributeAction): + def __init__(self, actions: List[AttributeAction]): + self.name = "multiple" + self.actions = actions + + +class RemoveItemAction(AttributeAction): + def __init__(self): + self.name = "remove_item" + + +class NoAction(AttributeAction): + def __init__(self): + self.name = "no_action" + + +class ErrorAction(AttributeAction): + def __init__(self, msg): + self.name = "error" + self.msg = msg + + +class RegisterForeignPredicateAction(AttributeAction): + def __init__(self, fp): + self.name = "register_foreign_predicate" + self.foreign_predicate = fp + + +class ForeignAttributeProcessor: + def __init__(self, name: str, processor: Callable): + self.name = name + self.processor = processor + + def process_value(self, value): + if "Constant" in value["node"]: + return self.process_constant(value["node"]["Constant"]) + elif "List" in value["node"]: + return [self.process_value(v) for v in value["node"]["List"]] + else: + 
raise NotImplementedError() + + def process_constant(self, constant): + if "String" in constant["node"]: + return constant["node"]["String"]["node"]["string"] + elif "Integer" in constant["node"]: + return constant["node"]["Integer"] + elif "Boolean" in constant["node"]: + return constant["node"]["Boolean"] + elif "Float" in constant["node"]: + return constant["node"]["Float"] + else: + print(constant["node"]) + raise NotImplementedError() + + def process_kw_arg_name(self, kw_arg): + return kw_arg[0]["node"]["name"] + + def process_kw_arg_value(self, kw_arg): + return self.process_value(kw_arg[1]) + + def process_attribute(self, attr): + # attr_name = attr["node"]["name"]["node"]["name"] + pos_args = [self.process_value(pos_arg) for pos_arg in attr["node"]["pos_args"]] + kw_args = {self.process_kw_arg_name(kw_arg): self.process_kw_arg_value(kw_arg) for kw_arg in attr["node"]["kw_args"]} + return (pos_args, kw_args) + + def apply(self, item, attr): + (pos_args, kw_args) = self.process_attribute(attr) + try: + result = self.processor(item, *pos_args, **kw_args) + if result is None: + return NoAction() + elif isinstance(result, AttributeAction): + return result + else: + raise Exception("Invalid return value of foreign attribute") + except Exception as err: + return ErrorAction(str(err)) + + +def foreign_attribute(func) -> ForeignAttributeProcessor: + """ + A decorator + """ + + # Get the function name + func_name = func.__name__ + + # Get the attribute processor + return ForeignAttributeProcessor(func_name, func) diff --git a/etc/scallopy/scallopy/collection.py b/etc/scallopy/scallopy/collection.py index 70972ba..e7b567d 100644 --- a/etc/scallopy/scallopy/collection.py +++ b/etc/scallopy/scallopy/collection.py @@ -23,19 +23,24 @@ def __iter__(self): else: yield (1 - p, t) elif self.provenance == "diffaddmultprob" or \ + self.provenance == "diffnandmultprob" or \ + self.provenance == "diffmaxmultprob" or \ + self.provenance == "diffnandminprob" or \ + self.provenance == 
"difftopkproofs" or \ self.provenance == "diffsamplekproofs" or \ self.provenance == "difftopbottomkclauses": + input_tags = self._internal.input_tags() for ((p, deriv), t) in self._internal: - diff_prob = diff_proofs_prob(p, deriv) + diff_prob = diff_proofs_prob(p, deriv, input_tags) yield (diff_prob, t) else: for t in self._internal: yield t -def diff_proofs_prob(p: float, deriv: List[Tuple[int, float, Tensor]]): +def diff_proofs_prob(p: float, deriv: List[Tuple[int, float, Tensor]], input_tags: List[Tensor]): def hook(grad): - for (_, weight, source_tensor) in deriv: + for (tag_id, weight) in deriv: + source_tensor = input_tags[tag_id] if source_tensor.requires_grad: source_tensor.backward(weight * grad, retain_graph=True) v = torch.tensor(p, requires_grad=True) diff --git a/etc/scallopy/scallopy/context.py b/etc/scallopy/scallopy/context.py index 9a75dfc..c9e8b81 100644 --- a/etc/scallopy/scallopy/context.py +++ b/etc/scallopy/scallopy/context.py @@ -13,9 +13,10 @@ from .input_mapping import InputMapping from .function import ForeignFunction from .predicate import ForeignPredicate +from .attribute import ForeignAttributeProcessor from .history import HistoryAction, record_history from .sample_type import SAMPLE_TYPE_TOP_K -from .utils import Counter +from .utils import Counter, _map_entity_tuple_to_str_tuple # Main context class ScallopContext(Context): @@ -80,7 +81,10 @@ def __init__( elif provenance == "diffmaxmultprob2": provenance = "custom" custom_provenance = DiffMaxMultProb2Semiring() - else: pass + elif custom_provenance is not None: + provenance = "custom" + else: + pass # Setup self.provenance = provenance @@ -242,6 +246,13 @@ def register_foreign_predicate(self, foreign_predicate: ForeignPredicate): else: raise Exception("Registering non-foreign-predicate. 
Consider decorating the function with @scallopy.foreign_predicate") + @record_history + def register_foreign_attribute(self, foreign_attribute: ForeignAttributeProcessor): + if type(foreign_attribute) == ForeignAttributeProcessor: + self._internal.register_foreign_attribute(foreign_attribute) + else: + raise Exception("Registering non-foreign-attribute. Consider decorating the function with @scallopy.attribute") + def forward_function( self, output: Optional[str] = None, @@ -484,6 +495,24 @@ def add_rule(self, rule: str, tag: Optional[Any] = None, demand: Optional[str] = """ self._internal.add_rule(rule, tag=tag, demand=demand) + @record_history + def add_entity(self, relation: str, entity: Union[str, Tuple]): + """ + Add an entity to the context. + + ``` python + ctx.add_program("type Expr = Const(i32) | Add(Expr, Expr)") + ctx.add_rule("eval(e, y) = case e is Const(y)") + ctx.add_rule("eval(e, y1 + y2) = case e is Add(e1, e2) and eval(e1, y1) and eval(e2, y2)") + ctx.add_rule("result(y) = root(e) and eval(e, y)") + ctx.add_entity("root", "Add(Const(5), Add(Const(3), Const(4)))") + ctx.run() + print(ctx.relation("result")) # [(12,)], where 12 is derived from 5 + (3 + 4) + ``` + """ + entity = _map_entity_tuple_to_str_tuple(entity) + self._internal.add_entity(relation, entity) + @record_history def compile(self): self._internal.compile() @@ -621,7 +650,6 @@ def supports_disjunctions(self) -> bool: "topbottomkclauses", "diffsamplekproofs", "difftopkproofs", - "difftopkproofsindiv", "difftopbottomkclauses", ]) return self.provenance in PROVENANCE_SUPPORTING_DISJUNCTIONS diff --git a/etc/scallopy/scallopy/forward.py b/etc/scallopy/scallopy/forward.py index 4975cf4..d473901 100644 --- a/etc/scallopy/scallopy/forward.py +++ b/etc/scallopy/scallopy/forward.py @@ -5,13 +5,13 @@ import zipfile import shutil -from .torch_importer import * +from .torch_importer import Context, Tensor, torch from .sample_type import * from .context import ScallopContext from .provenance 
import ScallopProvenance -from .utils import _mapping_tuple +from .utils import _mapping_tuple, _map_entity_tuple_to_str_tuple -class ScallopForwardFunction(torch.nn.Module): +class ScallopForwardFunction(Context): def __init__( self, file: Optional[str] = None, @@ -104,7 +104,7 @@ def __call__(self, *pos_args, **kw_args): return self.forward_fn(*pos_args, **kw_args) -class InternalScallopForwardFunction(torch.nn.Module): +class InternalScallopForwardFunction(Context): FORWARD_FN_COUNTER = 1 """ @@ -261,7 +261,8 @@ def _check_compiled(self, so_dir): def __call__( self, - disjunctions: Optional[Dict[str, List[List[List[int]]]]] = None, + disjunctions: Dict[str, List[List[List[int]]]] = {}, + entities: Dict[str, List[List[str]]] = {}, output_relations: Optional[List[Union[str, List[str]]]] = None, **input_facts: Dict[str, Union[Tensor, List]], ) -> Union[Tensor, Tuple[List[Tuple], Tensor]]: @@ -274,14 +275,20 @@ def __call__( - None, if outputs are provided """ if self.jit: - return self._call_with_static_ctx(disjunctions=disjunctions, input_facts=input_facts) + return self._call_with_static_ctx( + disjunctions=disjunctions, + input_facts=input_facts) else: - return self._call_with_dynamic_ctx(disjunctions=disjunctions, output_relations=output_relations, input_facts=input_facts) + return self._call_with_dynamic_ctx( + disjunctions=disjunctions, + entities=entities, + output_relations=output_relations, + input_facts=input_facts) def _call_with_static_ctx( self, - disjunctions: Optional[Dict] = None, - input_facts: Dict[str, Union[Tensor, List]] = None, + disjunctions: Dict, + input_facts: Dict[str, Union[Tensor, List]], ): # First make sure all facts share the same batch size batch_size = self._compute_and_check_batch_size(input_facts) @@ -315,16 +322,21 @@ def _get_k(self): def _call_with_dynamic_ctx( self, - disjunctions: Optional[Dict] = None, - output_relations: Optional[List[Union[str, List[str]]]] = None, - input_facts: Dict[str, Union[Tensor, List]] = None, + 
disjunctions: Optional[Dict], + entities: Optional[Dict[str, List[List[str]]]], + output_relations: Optional[List[Union[str, List[str]]]], + input_facts: Dict[str, Union[Tensor, List]], ): self.ctx._refresh_training_eval_state() # Set train/eval self.ctx._internal.set_non_incremental() self.ctx._internal.compile() # Compile into back IR # First make sure that all facts share the same batch size - batch_size = self._compute_and_check_batch_size(input_facts) + batch_size = self._compute_and_check_batch_size(input_facts, entities) + + # Get all the entity facts + for (entity_relation, entity_facts) in self._compute_entity_facts(batch_size, entities).items(): + input_facts[entity_relation] = entity_facts # Process the input into a unified form all_inputs = self._process_all_input_facts(batch_size, input_facts, disjunctions) @@ -358,7 +370,7 @@ def _call_with_dynamic_ctx( # Process the output return self._process_output(batch_size, input_tags, output_results) - def _compute_and_check_batch_size(self, inputs: Dict[str, Union[Tensor, List]]) -> int: + def _compute_and_check_batch_size(self, inputs: Dict[str, Union[Tensor, List]], entities: Dict[str, List[List[str]]] = {}) -> int: """ Given the inputs, check if the batch size is consistent over all relations. If so, return the batch size. 
@@ -370,11 +382,29 @@ def _compute_and_check_batch_size(self, inputs: Dict[str, Union[Tensor, List]]) batch_size = len(rela_facts) elif batch_size != len(rela_facts): raise Exception(f"Inconsistency in batch size: expected {batch_size}, found {len(rela_facts)} for relation `{rela}`") + for (rela, rela_entities) in entities.items(): + if batch_size is None: + batch_size = len(rela_entities) + elif batch_size != len(rela_entities): + raise Exception(f"Inconsistency in entity batch size: expected {batch_size}, found {len(rela_facts)} for relation `{rela}`") if batch_size is None: raise Exception("There is no input to the forward function") else: return batch_size + def _compute_entity_facts(self, batch_size, entities): + all_entity_facts = {} + for (relation, batch_of_entities) in entities.items(): + for (i, entities) in enumerate(batch_of_entities): + for entity in entities: + entity = _map_entity_tuple_to_str_tuple(entity) + entity_raw_facts = self.ctx._internal.compile_entity(relation, entity) + for (obj_relation, obj_facts) in entity_raw_facts.items(): + if obj_relation not in all_entity_facts: + all_entity_facts[obj_relation] = [[] for _ in range(batch_size)] + all_entity_facts[obj_relation][i] += [(None, fact) for fact in obj_facts] + return all_entity_facts + def _process_all_input_facts(self, batch_size, all_input_facts, all_disjunctions): """ Given all the input facts where facts may be lists or tensors, @@ -633,30 +663,6 @@ def _batched_prob( tensor_results.append(torch.stack(task_tensor_results)) return self._torch_tensor_apply(torch.stack(tensor_results)) - # Provenance diff topkproofs indiv - elif self.ctx.provenance == "difftopkproofsindiv": - tensor_results = [] - for task_results in tasks: - task_tensor_results = [] - for (k, proofs) in task_results: - task_tensor_proofs = [] - for i in range(k): - if i < len(proofs): - proof = proofs[i] - if len(proof) == 0: - task_tensor_proofs.append(torch.tensor(1.0)) - else: - get_literal_prob = lambda i: 
proof[i][2] if proof[i][1] else 1 - proof[i][2] - agg_prob = get_literal_prob(0) - for j in range(1, len(proof)): - agg_prob *= get_literal_prob(j) - task_tensor_proofs.append(agg_prob) - else: - task_tensor_proofs.append(torch.tensor(0.0)) - task_tensor_results.append(torch.stack(task_tensor_proofs)) - tensor_results.append(torch.stack(task_tensor_results)) - return self._torch_tensor_apply(torch.stack(tensor_results)) - # If we have a custom provenance, use its collate function elif self.ctx.provenance == "custom": return self.ctx._custom_provenance.collate(tasks) @@ -672,7 +678,6 @@ def _has_output_hook(self): elif self.ctx.provenance == "diffnandminprob": return True elif self.ctx.provenance == "diffsamplekproofs": return True elif self.ctx.provenance == "difftopkproofs": return True - elif self.ctx.provenance == "difftopkproofsindiv": return False elif self.ctx.provenance == "difftopbottomkclauses": return True else: return False @@ -691,8 +696,6 @@ def _batched_output_hook( self.ctx.provenance == "diffsamplekproofs" or \ self.ctx.provenance == "difftopbottomkclauses": return self._diff_proofs_batched_output_hook(input_tags, tasks) - elif self.ctx.provenance == "difftopkproofsindiv": - raise Exception("[Internal Error] Should not happen; difftopkproofsindiv does not need output hook") else: raise Exception("[Internal Error] Should not happen") diff --git a/etc/scallopy/scallopy/input_mapping.py b/etc/scallopy/scallopy/input_mapping.py index 814dbea..05282c3 100644 --- a/etc/scallopy/scallopy/input_mapping.py +++ b/etc/scallopy/scallopy/input_mapping.py @@ -120,22 +120,28 @@ def _process_one_tensor(self, tensor: Tensor, mutual_exclusion_counter: Counter) if self.retain_k is not None: if self.sample_dim is not None: if self.sample_strategy == "categorical": - raise NotImplementedError() + transposed_tensor = tensor.transpose(self.sample_dim, tensor.dim() - 1) + distributions = torch.distributions.Categorical(transposed_tensor) + sampled_indices = 
distributions.sample((self.retain_k,)).transpose(0, transposed_tensor.dim() - 1) + for index in self._convert_categorical_sampled_indices(sampled_indices, self.sample_dim): + inc.add(index) elif self.sample_strategy == "top": topk_result = torch.topk(tensor, self.retain_k, dim=self.sample_dim) - for index in self._convert_sampled_indices(topk_result.indices, self.sample_dim): + for index in self._convert_topk_sampled_indices(topk_result.indices, self.sample_dim): inc.add(index) else: raise Exception(f"Unknown sample strategy `{self.sample_strategy}`") else: flat_tensor = torch.flatten(tensor) if self.sample_strategy == "categorical": - raise NotImplementedError() + categorical_distr = torch.distributions.Categorical(probs=flat_tensor) + sampled_indices = categorical_distr.sample((self.retain_k,)) + for index in sampled_indices: + inc.add(self._index_to_mult_dim_index(int(index))) elif self.sample_strategy == "top": topk_result = torch.topk(flat_tensor, self.retain_k) for index in topk_result.indices: - mult_dim_index = self._index_to_mult_dim_index(int(index)) - inc.add(mult_dim_index) + inc.add(self._index_to_mult_dim_index(int(index))) else: raise Exception(f"Unknown sample strategy `{self.sample_strategy}`") @@ -335,13 +341,20 @@ def _index_to_mult_dim_index(self, index): acc_index = acc_index // self._shape[i - 1] return tuple(reversed(mult_dim_index)) - def _convert_sampled_indices(self, sampled_indices, sample_dim): + def _convert_topk_sampled_indices(self, sampled_indices, sample_dim): for partial_index in itertools.product(*[range(d) for (i, d) in enumerate(self._shape) if i != sample_dim]): before, after = partial_index[:sample_dim], partial_index[sample_dim:] indexing_slice = before + (slice(None),) + after for sampled_k_index in sampled_indices[indexing_slice]: yield before + (int(sampled_k_index),) + after + def _convert_categorical_sampled_indices(self, sampled_indices, sample_dim): + for partial_index in itertools.product(*[range(d) for (i, d) in 
enumerate(self._shape) if i != sample_dim]): + before, after = partial_index[:sample_dim], partial_index[sample_dim:] + indexing_slice = before + after + for sampled_k_index in sampled_indices[indexing_slice]: + yield before + (int(sampled_k_index),) + after + def _is_primitive(self, e) -> bool: ty = type(e) return ty == int or ty == float or ty == str or ty == bool diff --git a/etc/scallopy/scallopy/io.py b/etc/scallopy/scallopy/io.py index 668b158..84403cd 100644 --- a/etc/scallopy/scallopy/io.py +++ b/etc/scallopy/scallopy/io.py @@ -1,4 +1,6 @@ -from typing import Optional +from typing import Optional, List, Union +from copy import deepcopy + class CSVFileOptions: """ @@ -15,8 +17,45 @@ def __init__( deliminator: Optional[str] = None, has_header: bool = False, has_probability: bool = False, + keys: Optional[Union[List[str], str]] = None, + fields: Optional[List[str]] = None, ): + # Basic properties self.path = path self.deliminator = deliminator - self.has_header = has_header self.has_probability = has_probability + + # Sanitize the keys and fields + self.keys = _sanitize_keys(keys) + self.fields = fields + + # If there is key or field, then there has to be header + self.has_header = has_header or (self.keys is not None) or (self.fields is not None) + + def with_deliminator(self, deliminator: str): + copied = deepcopy(self) + copied.deliminator = deliminator + return copied + + def with_fields(self, fields: Optional[List[str]]): + copied = deepcopy(self) + copied.fields = fields + return copied + + def with_keys(self, keys: Optional[Union[str, List[str]]]): + copied = deepcopy(self) + copied.keys = keys + return copied + + +def _sanitize_keys(keys: Optional[Union[str, List[str]]]) -> Optional[List[str]]: + if keys is not None: + if type(keys) is str: + return [keys] + else: + assert type(keys) is list, "`keys` should be a string or a list of strings" + for elem in keys: + assert type(elem) is str, "an element in `keys` should be a string" + return keys + else: + 
return None diff --git a/etc/scallopy/scallopy/predicate.py b/etc/scallopy/scallopy/predicate.py index 27e82ed..b4dd703 100644 --- a/etc/scallopy/scallopy/predicate.py +++ b/etc/scallopy/scallopy/predicate.py @@ -1,6 +1,27 @@ from typing import * import inspect + +ALIASES = { + "I8": "i8", + "I16": "i16", + "I32": "i32", + "I64": "i64", + "I128": "i128", + "ISize": "isize", + "U8": "u8", + "U16": "u16", + "U32": "u32", + "U64": "u64", + "U128": "u128", + "USize": "usize", + "F32": "f32", + "F64": "f64", + "Char": "char", + "Bool": "bool", +} + + # Predicate Data Type class Type: def __init__(self, value): @@ -20,6 +41,8 @@ def __init__(self, value): value == "bool" or value == "char" or value == "String" or \ value == "DateTime" or value == "Duration": self.type = value + elif value in ALIASES: + self.type = ALIASES[value] else: raise Exception(f"Unknown scallop predicate type annotation `{value}`") @@ -86,7 +109,7 @@ def pattern(self): return "b" * len(self.input_arg_types) + "f" * len(self.output_arg_types) def does_output_tag(self): - return self.tag_type is not None + return self.tag_type is not None and self.tag_type is not type(None) def foreign_predicate(func: Callable): @@ -133,11 +156,12 @@ def string_chars(s: str) -> scallopy.Generator[Tuple[int, char]]: if len(args) != 2: raise Exception(f"Generator must have 2 type arguments") + # Produce return tag type + return_tag_type = _extract_return_tag_type(args[0]) + # Produce return tuple type, and check that they are all base type - return_tuple_type = _extract_return_tuple_type(args[0]) + return_tuple_type = _extract_return_tuple_type(args[1]) - # Produce return tag type - return_tag_type = _extract_return_tag_type(args[1]) # Create the foreign predicate return ForeignPredicate( diff --git a/etc/scallopy/scallopy/scallopy.pyi b/etc/scallopy/scallopy/scallopy.pyi index 537db95..7f6aab5 100644 --- a/etc/scallopy/scallopy/scallopy.pyi +++ b/etc/scallopy/scallopy/scallopy.pyi @@ -3,6 +3,14 @@ from typing import 
Dict, List, Union, Tuple, Optional, Any from .provenance import ScallopProvenance from .io import CSVFileOptions + +def torch_tensor_enabled() -> bool: + """ + Returns a boolean indicating whether this version of `scallopy` + is compiled with torch tensor enabled. + """ + + class InternalScallopContext: def __init__( self, @@ -64,8 +72,16 @@ class InternalScallopContext: demand: Optional[str] = None, ) -> None: ... + def add_entity(self, relation: str, entity_tuple: Tuple[str]) -> None: ... + + def compile_entity(self, relation: str, entity_tuple: Tuple[str]) -> Dict[str, List[Tuple]]: ... + def register_foreign_function(self, ff: Any) -> None: ... + def register_foreign_predicate(self, fp: Any) -> None: ... + + def register_foreign_attribute(self, ff: Any) -> None: ... + def dump_front_ir(self): ... def relation(self, relation: str) -> InternalScallopCollection: ... @@ -90,6 +106,16 @@ class InternalScallopCollection: Get the number of input facts for a valid provenance semiring """ + def input_tags(self) -> Optional[List[Any]]: + """ + Get all the input tags + """ + + def len(self) -> int: + """ + Get the number of elements in the collection + """ + def __iter__(self) -> InternalScallopCollectionIterator: """ Iterate through the tuples of the collection diff --git a/etc/scallopy/scallopy/torch_importer.py b/etc/scallopy/scallopy/torch_importer.py index a3057e1..8fa0e2c 100644 --- a/etc/scallopy/scallopy/torch_importer.py +++ b/etc/scallopy/scallopy/torch_importer.py @@ -17,3 +17,4 @@ class _Tensor: has_pytorch = False Context = _Context Tensor = _Tensor + torch = None diff --git a/etc/scallopy/scallopy/types.py b/etc/scallopy/scallopy/types.py index 86020f0..e7a7bcd 100644 --- a/etc/scallopy/scallopy/types.py +++ b/etc/scallopy/scallopy/types.py @@ -15,8 +15,12 @@ f64 = "f64" bool = "bool" char = "char" -string = "String" -# rc_string = "Rc" +String = "String" +Symbol = "Symbol" +DateTime = "DateTime" +Duration = "Duration" +Entity = "Entity" +Tensor = "Tensor" # 
Type families Any = "Any" diff --git a/etc/scallopy/scallopy/utils.py b/etc/scallopy/scallopy/utils.py index 0a7bf2f..c36fff5 100644 --- a/etc/scallopy/scallopy/utils.py +++ b/etc/scallopy/scallopy/utils.py @@ -5,6 +5,25 @@ def _mapping_tuple(t): return t if type(t) == tuple else (t,) +def _map_entity_to_str(entity): + if type(entity) == bool: + return "true" if entity else "false" + elif type(entity) == int or type(entity) == float: + return str(entity) + elif type(entity) == str: + return entity + else: + raise Exception(f"Unknown entity type {type(entity)}") + + +def _map_entity_tuple_to_str_tuple(entity): + if type(entity) is tuple or type(entity) is list: + return tuple([_map_entity_to_str(element) for element in entity]) + else: + entity_element = _map_entity_to_str(entity) + return (entity_element,) + + class Counter: def __init__(self): self.count = 0 diff --git a/etc/scallopy/src/collection.rs b/etc/scallopy/src/collection.rs index 9aa3615..a5ab769 100644 --- a/etc/scallopy/src/collection.rs +++ b/etc/scallopy/src/collection.rs @@ -3,10 +3,14 @@ use pyo3::prelude::*; use scallop_core::common; use scallop_core::runtime::dynamic::*; +use scallop_core::runtime::env::*; use scallop_core::runtime::provenance::*; use scallop_core::utils::*; +use crate::runtime::PythonRuntimeEnvironment; + use super::custom_tag; +use super::external_tag::*; use super::provenance::*; use super::tuple::*; @@ -31,40 +35,36 @@ pub enum CollectionEnum { collection: P::Rc>>, }, DiffMinMaxProb { - collection: P::Rc, P>>>, - tags: P::RcCell>>, + collection: P::Rc>>, + tags: P::RcCell>, }, DiffAddMultProb { - collection: P::Rc, P>>>, - tags: P::RcCell>>, + collection: P::Rc>>, + tags: P::RcCell>, }, DiffNandMultProb { - collection: P::Rc, P>>>, - tags: P::RcCell>>, + collection: P::Rc>>, + tags: P::RcCell>, }, DiffMaxMultProb { - collection: P::Rc, P>>>, - tags: P::RcCell>>, + collection: P::Rc>>, + tags: P::RcCell>, }, DiffNandMinProb { - collection: P::Rc, P>>>, - tags: P::RcCell>>, + 
collection: P::Rc>>, + tags: P::RcCell>, }, DiffSampleKProofs { - collection: P::Rc, P>>>, - tags: DiffProbStorage, P>, + collection: P::Rc>>, + tags: DiffProbStorage, }, DiffTopKProofs { - collection: P::Rc, P>>>, - tags: DiffProbStorage, P>, - }, - DiffTopKProofsIndiv { - collection: P::Rc, P>>>, - tags: DiffProbStorage, P>, + collection: P::Rc>>, + tags: DiffProbStorage, }, DiffTopBottomKClauses { - collection: P::Rc, P>>>, - tags: DiffProbStorage, P>, + collection: P::Rc>>, + tags: DiffProbStorage, }, Custom { collection: P::Rc>, @@ -87,7 +87,6 @@ macro_rules! match_collection { CollectionEnum::DiffNandMinProb { collection: $v, .. } => $e, CollectionEnum::DiffSampleKProofs { collection: $v, .. } => $e, CollectionEnum::DiffTopKProofs { collection: $v, .. } => $e, - CollectionEnum::DiffTopKProofsIndiv { collection: $v, .. } => $e, CollectionEnum::DiffTopBottomKClauses { collection: $v, .. } => $e, CollectionEnum::Custom { collection: $v } => $e, } @@ -110,7 +109,6 @@ impl CollectionEnum { Self::DiffNandMinProb { tags, .. } => Some(ArcFamily::get_rc_cell(tags, |t| t.len())), Self::DiffSampleKProofs { tags, .. } => Some(tags.num_input_tags()), Self::DiffTopKProofs { tags, .. } => Some(tags.num_input_tags()), - Self::DiffTopKProofsIndiv { tags, .. } => Some(tags.num_input_tags()), Self::DiffTopBottomKClauses { tags, .. } => Some(tags.num_input_tags()), Self::Custom { .. } => None, } @@ -125,14 +123,13 @@ impl CollectionEnum { Self::TopKProofs { .. } => None, Self::TopBottomKClauses { .. } => None, Self::DiffMinMaxProb { .. } => None, - Self::DiffAddMultProb { tags, .. } => Some(ArcFamily::get_rc_cell(tags, |t| t.clone())), - Self::DiffNandMultProb { tags, .. } => Some(ArcFamily::clone_rc_cell_internal(tags)), - Self::DiffMaxMultProb { tags, .. } => Some(ArcFamily::clone_rc_cell_internal(tags)), - Self::DiffNandMinProb { tags, .. } => Some(ArcFamily::clone_rc_cell_internal(tags)), - Self::DiffSampleKProofs { tags, .. 
} => Some(tags.input_tags()), - Self::DiffTopKProofs { tags, .. } => Some(tags.input_tags()), - Self::DiffTopKProofsIndiv { tags, .. } => Some(tags.input_tags()), - Self::DiffTopBottomKClauses { tags, .. } => Some(tags.input_tags()), + Self::DiffAddMultProb { tags, .. } => Some(ArcFamily::get_rc_cell(tags, |t| t.clone()).into_vec()), + Self::DiffNandMultProb { tags, .. } => Some(ArcFamily::clone_rc_cell_internal(tags).into_vec()), + Self::DiffMaxMultProb { tags, .. } => Some(ArcFamily::clone_rc_cell_internal(tags).into_vec()), + Self::DiffNandMinProb { tags, .. } => Some(ArcFamily::clone_rc_cell_internal(tags).into_vec()), + Self::DiffSampleKProofs { tags, .. } => Some(tags.input_tags().into_vec()), + Self::DiffTopKProofs { tags, .. } => Some(tags.input_tags().into_vec()), + Self::DiffTopBottomKClauses { tags, .. } => Some(tags.input_tags().into_vec()), Self::Custom { .. } => None, } } @@ -163,10 +160,18 @@ impl CollectionEnum { match_collection!(self, c, ith_tag_helper(ArcFamily::get_rc(c), i)) } + + pub fn to_collection(self, env: &RuntimeEnvironment) -> Collection { + Collection { + env: env.into(), + collection: self, + } + } } #[pyclass(unsendable, name = "InternalScallopCollection")] pub struct Collection { + pub env: PythonRuntimeEnvironment, pub collection: CollectionEnum, } @@ -180,22 +185,22 @@ impl Collection { self.collection.input_tags() } + fn len(&self) -> usize { + self.collection.len() + } + fn __iter__(slf: PyRef) -> CollectionIterator { CollectionIterator { + env: slf.env.clone(), collection: slf.collection.clone(), current_index: 0, } } } -impl From> for Collection { - fn from(collection: CollectionEnum) -> Self { - Self { collection } - } -} - #[pyclass(unsendable, name = "InternalScallopCollectionIterator")] pub struct CollectionIterator { + env: PythonRuntimeEnvironment, collection: CollectionEnum, current_index: usize, } @@ -207,10 +212,10 @@ impl CollectionIterator { let i = slf.current_index; slf.current_index += 1; if 
slf.collection.has_empty_tag() { - let tuple = to_python_tuple(slf.collection.ith_tuple(i)); + let tuple = to_python_tuple(slf.collection.ith_tuple(i), &slf.env); IterNextOutput::Yield(tuple) } else { - let tuple = to_python_tuple(slf.collection.ith_tuple(i)); + let tuple = to_python_tuple(slf.collection.ith_tuple(i), &slf.env); let tag = slf.collection.ith_tag(i); let elem = Python::with_gil(|py| (tag, tuple).to_object(py)); IterNextOutput::Yield(elem) diff --git a/etc/scallopy/src/config.rs b/etc/scallopy/src/config.rs new file mode 100644 index 0000000..342e49f --- /dev/null +++ b/etc/scallopy/src/config.rs @@ -0,0 +1,10 @@ +use pyo3::prelude::*; + +#[pyfunction] +pub fn torch_tensor_enabled() -> bool { + if cfg!(feature = "torch-tensor") { + true + } else { + false + } +} diff --git a/etc/scallopy/src/context.rs b/etc/scallopy/src/context.rs index b2f69ad..394a8dc 100644 --- a/etc/scallopy/src/context.rs +++ b/etc/scallopy/src/context.rs @@ -1,7 +1,7 @@ use std::collections::*; use pyo3::prelude::*; -use pyo3::types::PyList; +use pyo3::types::*; use rayon::prelude::*; @@ -9,6 +9,7 @@ use scallop_core::common::tuple::*; use scallop_core::common::tuple_type::*; use scallop_core::compiler; use scallop_core::integrate::*; +use scallop_core::runtime::env::*; use scallop_core::runtime::monitor; use scallop_core::runtime::provenance::*; use scallop_core::utils::*; @@ -17,6 +18,8 @@ use crate::custom_tag; use super::collection::*; use super::error::*; +use super::external_tag::*; +use super::foreign_attribute::*; use super::foreign_function::*; use super::foreign_predicate::*; use super::io::*; @@ -33,15 +36,14 @@ pub enum ContextEnum { AddMultProb(IntegrateContext), TopKProofs(IntegrateContext, AF>), TopBottomKClauses(IntegrateContext, AF>), - DiffMinMaxProb(IntegrateContext, AF>, AF>), - DiffAddMultProb(IntegrateContext, AF>, AF>), - DiffNandMultProb(IntegrateContext, AF>, AF>), - DiffMaxMultProb(IntegrateContext, AF>, AF>), - DiffNandMinProb(IntegrateContext, AF>, 
AF>), - DiffSampleKProofs(IntegrateContext, AF>, AF>), - DiffTopKProofs(IntegrateContext, AF>, AF>), - DiffTopKProofsIndiv(IntegrateContext, AF>, AF>), - DiffTopBottomKClauses(IntegrateContext, AF>, AF>), + DiffMinMaxProb(IntegrateContext, AF>), + DiffAddMultProb(IntegrateContext, AF>), + DiffNandMultProb(IntegrateContext, AF>), + DiffMaxMultProb(IntegrateContext, AF>), + DiffNandMinProb(IntegrateContext, AF>), + DiffSampleKProofs(IntegrateContext, AF>), + DiffTopKProofs(IntegrateContext, AF>), + DiffTopBottomKClauses(IntegrateContext, AF>), Custom(IntegrateContext), } @@ -61,7 +63,6 @@ macro_rules! match_context { ContextEnum::DiffNandMinProb($v) => $e, ContextEnum::DiffSampleKProofs($v) => $e, ContextEnum::DiffTopKProofs($v) => $e, - ContextEnum::DiffTopKProofsIndiv($v) => $e, ContextEnum::DiffTopBottomKClauses($v) => $e, ContextEnum::Custom($v) => $e, } @@ -84,7 +85,6 @@ macro_rules! match_context_except_custom { ContextEnum::DiffNandMinProb($v) => Ok($e), ContextEnum::DiffSampleKProofs($v) => Ok($e), ContextEnum::DiffTopKProofs($v) => Ok($e), - ContextEnum::DiffTopKProofsIndiv($v) => Ok($e), ContextEnum::DiffTopBottomKClauses($v) => Ok($e), ContextEnum::Custom(_) => Err(BindingError::CustomProvenanceUnsupported), } @@ -171,11 +171,6 @@ impl Context { diff_top_k_proofs::DiffTopKProofsProvenance::new(k), )), }), - "difftopkproofsindiv" => Ok(Self { - ctx: ContextEnum::DiffTopKProofsIndiv(IntegrateContext::new_incremental( - diff_top_k_proofs_indiv::DiffTopKProofsIndivProvenance::new(k), - )), - }), "difftopbottomkclauses" => Ok(Self { ctx: ContextEnum::DiffTopBottomKClauses(IntegrateContext::new_incremental( diff_top_bottom_k_clauses::DiffTopBottomKClausesProvenance::new(k), @@ -200,77 +195,106 @@ impl Context { } /// Create a new scallop context with a different provenance as the current context - fn clone_with_new_provenance( - &self, - provenance: &str, - k: usize, - ) -> Result { + fn clone_with_new_provenance(&self, provenance: &str, k: usize) -> Result { // 
Check provenance type match provenance { "unit" => Ok(Self { - ctx: ContextEnum::Unit(match_context_except_custom!(&self.ctx, c, c.clone_with_new_provenance(unit::UnitProvenance::default()))?), + ctx: ContextEnum::Unit(match_context_except_custom!( + &self.ctx, + c, + c.clone_with_new_provenance(unit::UnitProvenance::default()) + )?), }), "proofs" => Ok(Self { - ctx: ContextEnum::Proofs(match_context_except_custom!(&self.ctx, c, c.clone_with_new_provenance(proofs::ProofsProvenance::default()))?), + ctx: ContextEnum::Proofs(match_context_except_custom!( + &self.ctx, + c, + c.clone_with_new_provenance(proofs::ProofsProvenance::default()) + )?), }), "minmaxprob" => Ok(Self { - ctx: ContextEnum::MinMaxProb(match_context_except_custom!(&self.ctx, c, c.clone_with_new_provenance(min_max_prob::MinMaxProbProvenance::default()))?), + ctx: ContextEnum::MinMaxProb(match_context_except_custom!( + &self.ctx, + c, + c.clone_with_new_provenance(min_max_prob::MinMaxProbProvenance::default()) + )?), }), "addmultprob" => Ok(Self { - ctx: ContextEnum::AddMultProb(match_context_except_custom!(&self.ctx, c, c.clone_with_new_provenance(add_mult_prob::AddMultProbProvenance::default()))?), + ctx: ContextEnum::AddMultProb(match_context_except_custom!( + &self.ctx, + c, + c.clone_with_new_provenance(add_mult_prob::AddMultProbProvenance::default()) + )?), }), "topkproofs" => Ok(Self { - ctx: ContextEnum::TopKProofs(match_context_except_custom!(&self.ctx, c, c.clone_with_new_provenance(top_k_proofs::TopKProofsProvenance::new(k)))?), + ctx: ContextEnum::TopKProofs(match_context_except_custom!( + &self.ctx, + c, + c.clone_with_new_provenance(top_k_proofs::TopKProofsProvenance::new(k)) + )?), }), "topbottomkclauses" => Ok(Self { - ctx: ContextEnum::TopBottomKClauses(match_context_except_custom!(&self.ctx, c, c.clone_with_new_provenance( - top_bottom_k_clauses::TopBottomKClausesProvenance::new(k), - ))?), + ctx: ContextEnum::TopBottomKClauses(match_context_except_custom!( + &self.ctx, + c, + 
c.clone_with_new_provenance(top_bottom_k_clauses::TopBottomKClausesProvenance::new(k),) + )?), }), "diffminmaxprob" => Ok(Self { - ctx: ContextEnum::DiffMinMaxProb(match_context_except_custom!(&self.ctx, c, c.clone_with_new_provenance( - diff_min_max_prob::DiffMinMaxProbProvenance::default(), - ))?), + ctx: ContextEnum::DiffMinMaxProb(match_context_except_custom!( + &self.ctx, + c, + c.clone_with_new_provenance(diff_min_max_prob::DiffMinMaxProbProvenance::default(),) + )?), }), "diffaddmultprob" => Ok(Self { - ctx: ContextEnum::DiffAddMultProb(match_context_except_custom!(&self.ctx, c, c.clone_with_new_provenance( - diff_add_mult_prob::DiffAddMultProbProvenance::default(), - ))?), + ctx: ContextEnum::DiffAddMultProb(match_context_except_custom!( + &self.ctx, + c, + c.clone_with_new_provenance(diff_add_mult_prob::DiffAddMultProbProvenance::default(),) + )?), }), "diffnandmultprob" => Ok(Self { - ctx: ContextEnum::DiffNandMultProb(match_context_except_custom!(&self.ctx, c, c.clone_with_new_provenance( - diff_nand_mult_prob::DiffNandMultProbProvenance::default(), - ))?), + ctx: ContextEnum::DiffNandMultProb(match_context_except_custom!( + &self.ctx, + c, + c.clone_with_new_provenance(diff_nand_mult_prob::DiffNandMultProbProvenance::default(),) + )?), }), "diffmaxmultprob" => Ok(Self { - ctx: ContextEnum::DiffMaxMultProb(match_context_except_custom!(&self.ctx, c, c.clone_with_new_provenance( - diff_max_mult_prob::DiffMaxMultProbProvenance::default(), - ))?), + ctx: ContextEnum::DiffMaxMultProb(match_context_except_custom!( + &self.ctx, + c, + c.clone_with_new_provenance(diff_max_mult_prob::DiffMaxMultProbProvenance::default(),) + )?), }), "diffnandminprob" => Ok(Self { - ctx: ContextEnum::DiffNandMinProb(match_context_except_custom!(&self.ctx, c, c.clone_with_new_provenance( - diff_nand_min_prob::DiffNandMinProbProvenance::default(), - ))?), + ctx: ContextEnum::DiffNandMinProb(match_context_except_custom!( + &self.ctx, + c, + 
c.clone_with_new_provenance(diff_nand_min_prob::DiffNandMinProbProvenance::default(),) + )?), }), "diffsamplekproofs" => Ok(Self { - ctx: ContextEnum::DiffSampleKProofs(match_context_except_custom!(&self.ctx, c, c.clone_with_new_provenance( - diff_sample_k_proofs::DiffSampleKProofsProvenance::new(k), - ))?), + ctx: ContextEnum::DiffSampleKProofs(match_context_except_custom!( + &self.ctx, + c, + c.clone_with_new_provenance(diff_sample_k_proofs::DiffSampleKProofsProvenance::new(k),) + )?), }), "difftopkproofs" => Ok(Self { - ctx: ContextEnum::DiffTopKProofs(match_context_except_custom!(&self.ctx, c, c.clone_with_new_provenance( - diff_top_k_proofs::DiffTopKProofsProvenance::new(k), - ))?), - }), - "difftopkproofsindiv" => Ok(Self { - ctx: ContextEnum::DiffTopKProofsIndiv(match_context_except_custom!(&self.ctx, c, c.clone_with_new_provenance( - diff_top_k_proofs_indiv::DiffTopKProofsIndivProvenance::new(k), - ))?), + ctx: ContextEnum::DiffTopKProofs(match_context_except_custom!( + &self.ctx, + c, + c.clone_with_new_provenance(diff_top_k_proofs::DiffTopKProofsProvenance::new(k),) + )?), }), "difftopbottomkclauses" => Ok(Self { - ctx: ContextEnum::DiffTopBottomKClauses(match_context_except_custom!(&self.ctx, c, c.clone_with_new_provenance( - diff_top_bottom_k_clauses::DiffTopBottomKClausesProvenance::new(k), - ))?), + ctx: ContextEnum::DiffTopBottomKClauses(match_context_except_custom!( + &self.ctx, + c, + c.clone_with_new_provenance(diff_top_bottom_k_clauses::DiffTopBottomKClausesProvenance::new(k),) + )?), }), "custom" => Err(BindingError::CustomProvenanceUnsupported), p => Err(BindingError::UnknownProvenance(p.to_string())), @@ -363,14 +387,13 @@ impl Context { ContextEnum::TopKProofs(_) => None, ContextEnum::TopBottomKClauses(_) => None, ContextEnum::DiffMinMaxProb(_) => None, - ContextEnum::DiffAddMultProb(c) => Some(c.provenance_context().input_tags()), - ContextEnum::DiffNandMultProb(c) => Some(c.provenance_context().input_tags()), - 
ContextEnum::DiffMaxMultProb(c) => Some(c.provenance_context().input_tags()), - ContextEnum::DiffNandMinProb(c) => Some(c.provenance_context().input_tags()), - ContextEnum::DiffSampleKProofs(c) => Some(c.provenance_context().input_tags()), - ContextEnum::DiffTopKProofs(c) => Some(c.provenance_context().input_tags()), - ContextEnum::DiffTopKProofsIndiv(c) => Some(c.provenance_context().input_tags()), - ContextEnum::DiffTopBottomKClauses(c) => Some(c.provenance_context().input_tags()), + ContextEnum::DiffAddMultProb(c) => Some(c.provenance_context().input_tags().into_vec()), + ContextEnum::DiffNandMultProb(c) => Some(c.provenance_context().input_tags().into_vec()), + ContextEnum::DiffMaxMultProb(c) => Some(c.provenance_context().input_tags().into_vec()), + ContextEnum::DiffNandMinProb(c) => Some(c.provenance_context().input_tags().into_vec()), + ContextEnum::DiffSampleKProofs(c) => Some(c.provenance_context().input_tags().into_vec()), + ContextEnum::DiffTopKProofs(c) => Some(c.provenance_context().input_tags().into_vec()), + ContextEnum::DiffTopBottomKClauses(c) => Some(c.provenance_context().input_tags().into_vec()), ContextEnum::Custom(_) => None, } } @@ -399,9 +422,6 @@ impl Context { ContextEnum::DiffTopKProofs(c) => { c.provenance_context_mut().set_k(k); } - ContextEnum::DiffTopKProofsIndiv(c) => { - c.provenance_context_mut().set_k(k); - } ContextEnum::DiffTopBottomKClauses(c) => { c.provenance_context_mut().set_k(k); } @@ -478,6 +498,30 @@ impl Context { match_context!(&mut self.ctx, c, add_py_rule(c, rule, tag, attrs)) } + fn add_entity(&mut self, relation: &str, entity_tuple: Vec) -> Result<(), BindingError> { + match_context!(&mut self.ctx, c, { c.add_entity(relation, entity_tuple)? }); + Ok(()) + } + + fn compile_entity( + &mut self, + relation: &str, + entity_tuple: Vec, + ) -> Result>>, BindingError> { + match_context!(&mut self.ctx, c, { + let curr_env = c.runtime_environment().into(); + let facts = c + .compile_entity(relation, entity_tuple)? 
+ .into_iter() + .map(|(relation_name, tuples)| { + let py_objs = tuples.iter().map(|tuple| to_python_tuple(tuple, &curr_env)).collect(); + (relation_name, py_objs) + }) + .collect(); + Ok(facts) + }) + } + /// Register a foreign function fn register_foreign_function(&mut self, f: PyObject) -> Result<(), BindingError> { let ff = PythonForeignFunction::new(f); @@ -492,6 +536,13 @@ impl Context { Ok(()) } + /// Register a foreign attribute + fn register_foreign_attribute(&mut self, attr: PyObject) -> Result<(), BindingError> { + let py_attr = PythonForeignAttribute::new(attr); + match_context!(&mut self.ctx, c, c.register_foreign_attribute(py_attr)?); + Ok(()) + } + /// Execute the program /// /// If the context's ram program is already compiled, the program will be directly executed. @@ -566,8 +617,10 @@ impl Context { /// Has to be called after `ctx.run()`, otherwise the output collection will not be computed. /// Error is returned if the relation is not computed or does not exist. fn relation(&mut self, r: &str) -> Result { - let maybe_coll_enum = match_context!(&mut self.ctx, c, get_output_collection(c, r)); - maybe_coll_enum.map(|c| c.into()) + let (maybe_coll_enum, env) = match_context!(&mut self.ctx, c, { + (get_output_collection(c, r), c.runtime_environment()) + }); + maybe_coll_enum.map(|c| c.to_collection(env)) } /// Get the output collection of the relation, while monitoring the provenance and tags @@ -576,8 +629,10 @@ impl Context { /// Error is returned if the relation is not computed or does not exist. 
fn relation_with_debug_tag(&mut self, r: &str) -> Result { let m = monitor::DebugTagsMonitor; - let maybe_coll_enum = match_context!(&mut self.ctx, c, get_output_collection_monitor(c, &m, r)); - maybe_coll_enum.map(|c| c.into()) + let (maybe_coll_enum, env) = match_context!(&mut self.ctx, c, { + (get_output_collection_monitor(c, &m, r), c.runtime_environment()) + }); + maybe_coll_enum.map(|c| c.to_collection(env)) } /// Check if the context contains a relation @@ -632,7 +687,7 @@ impl Context { ::Tag: std::marker::Sync + std::marker::Send, { let batch_size = inputs.iter().next().unwrap().1.len(); - let inputs = process_batched_inputs::(inputs, |r| c.relation_type(r))?; + let inputs = process_batched_inputs::(inputs, c.runtime_environment(), |r| c.relation_type(r))?; run_batch_parallel(c, batch_size, inputs, output_relations) } @@ -659,26 +714,26 @@ impl Context { } } -fn add_py_rule( - c: &mut IntegrateContext, +fn add_py_rule

( + c: &mut IntegrateContext, rule: &str, tag: Option<&PyAny>, attrs: Vec, ) -> Result<(), BindingError> where - C: PythonProvenance, + P: PythonProvenance, { - let tag: Option = C::process_optional_py_tag(tag)?; + let tag: Option = P::process_optional_py_tag(tag)?; c.add_rule_with_options(rule, tag, attrs)?; Ok(()) } -fn add_py_facts(c: &mut IntegrateContext, relation: &str, elems: &PyList) -> Result<(), BindingError> +fn add_py_facts

(c: &mut IntegrateContext, relation: &str, elems: &PyList) -> Result<(), BindingError> where - C: PythonProvenance, + P: PythonProvenance, { if let Some(tuple_type) = c.relation_type(relation) { - let tuples = C::process_typed_py_facts(elems, &tuple_type)?; + let tuples = P::process_typed_py_facts(elems, &tuple_type, c.runtime_environment())?; c.add_facts(relation, tuples, false)?; Ok(()) } else { @@ -686,12 +741,12 @@ where } } -fn check_py_tuple(c: &IntegrateContext, relation: &str, py_tup: &PyAny) -> Result +fn check_py_tuple

(c: &IntegrateContext, relation: &str, py_tup: &PyAny) -> Result where - C: PythonProvenance, + P: PythonProvenance, { if let Some(tuple_type) = c.relation_type(relation) { - Ok(from_python_tuple(py_tup, &tuple_type).is_ok()) + Ok(from_python_tuple(py_tup, &tuple_type, &c.runtime_environment().into()).is_ok()) } else { Err(BindingError::UnknownRelation(relation.to_string()).into()) } @@ -703,7 +758,7 @@ where { if let Some(tuple_type) = c.relation_type(relation) { for py_tup in py_tups { - if from_python_tuple(py_tup, &tuple_type).is_err() { + if from_python_tuple(py_tup, &tuple_type, &c.runtime_environment().into()).is_err() { return Ok(false); } } @@ -713,12 +768,13 @@ where } } -fn process_batched_inputs( +fn process_batched_inputs( inputs: HashMap>, + env: &RuntimeEnvironment, get_relation_type: G, -) -> PyResult, Tuple)>>)>> +) -> PyResult, Tuple)>>)>> where - C: PythonProvenance, + P: PythonProvenance, G: Fn(&str) -> Option, { inputs @@ -727,7 +783,7 @@ where let tuple_type = get_relation_type(&r).ok_or(BindingError::UnknownRelation(r.clone()))?; let batch = b .into_iter() - .map(|elems| C::process_typed_py_facts(elems, &tuple_type)) + .map(|elems| P::process_typed_py_facts(elems, &tuple_type, env)) .collect::>>()?; Ok((r, batch)) }) @@ -762,7 +818,9 @@ where let computed = temp_ctx .computed_relation(r) .ok_or(BindingError::RelationNotComputed(r.to_string()))?; - Ok(C::to_collection_enum(computed, temp_ctx.provenance_context()).into()) + let collection_enum = C::to_collection_enum(computed, temp_ctx.provenance_context()); + let collection = collection_enum.to_collection(&temp_ctx.runtime_env); + Ok(collection) }) .collect::, _>>() }) @@ -773,7 +831,10 @@ fn is_all_equal + Clone>(i: T) -> bool { i.clone().min() == i.max() } -fn get_output_collection(c: &mut IntegrateContext, r: &str) -> Result, BindingError> +fn get_output_collection( + c: &mut IntegrateContext, + r: &str, +) -> Result, BindingError> where C: PythonProvenance, { diff --git 
a/etc/scallopy/src/custom_tag.rs b/etc/scallopy/src/custom_tag.rs index c9e5743..236d9a5 100644 --- a/etc/scallopy/src/custom_tag.rs +++ b/etc/scallopy/src/custom_tag.rs @@ -38,9 +38,7 @@ impl provenance::Provenance for CustomProvenance { /// Invoking the provenance's tagging function on the input tag fn tagging_fn(&self, i: Self::InputTag) -> Self::Tag { - Python::with_gil(|py| { - Self::Tag::new(self.0.call_method(py, "tagging_fn", (i,), None).unwrap()) - }) + Python::with_gil(|py| Self::Tag::new(self.0.call_method(py, "tagging_fn", (i,), None).unwrap())) } /// Invoking the provenance's recover function on an internal tag @@ -68,17 +66,42 @@ impl provenance::Provenance for CustomProvenance { } fn zero(&self) -> Self::Tag { - Python::with_gil(|py| Self::Tag::new(self.0.call_method(py, "zero", (), None).unwrap().extract(py).unwrap())) + Python::with_gil(|py| { + Self::Tag::new( + self + .0 + .call_method(py, "zero", (), None) + .expect("Python error in `zero`") + .extract(py) + .expect("Python error in `zero`"), + ) + }) } fn one(&self) -> Self::Tag { - Python::with_gil(|py| Self::Tag::new(self.0.call_method(py, "one", (), None).unwrap().extract(py).unwrap())) + Python::with_gil(|py| { + Self::Tag::new( + self + .0 + .call_method(py, "one", (), None) + .expect("Python error in `one`") + .extract(py) + .expect("Python error in `one`"), + ) + }) } fn add(&self, t1: &Self::Tag, t2: &Self::Tag) -> Self::Tag { Python::with_gil(|py| { let input = (t1.0.clone(), t2.0.clone()); - Self::Tag::new(self.0.call_method(py, "add", input, None).unwrap().extract(py).unwrap()) + Self::Tag::new( + self + .0 + .call_method(py, "add", input, None) + .expect("Python error in `add`") + .extract(py) + .expect("Python error in `add`"), + ) }) } @@ -89,9 +112,9 @@ impl provenance::Provenance for CustomProvenance { self .0 .call_method(py, "mult", input, None) - .unwrap() + .expect("Python error in `mult`") .extract(py) - .unwrap(), + .expect("Python error in `mult`"), ) }) } @@ -103,9 
+126,9 @@ impl provenance::Provenance for CustomProvenance { self .0 .call_method(py, "negate", input, None) - .unwrap() + .expect("Python error in `negate`") .extract(py) - .unwrap(), + .expect("Python error in `negate`"), )) }) } diff --git a/etc/scallopy/src/external_tag.rs b/etc/scallopy/src/external_tag.rs new file mode 100644 index 0000000..e14f8c1 --- /dev/null +++ b/etc/scallopy/src/external_tag.rs @@ -0,0 +1,71 @@ +use pyo3::prelude::*; + +use scallop_core::common::tensors::*; + +#[derive(Clone)] +pub struct ExtTag { + pub tag: Py, +} + +impl FromTensor for ExtTag { + #[allow(unused)] + #[cfg(not(feature = "torch-tensor"))] + fn from_tensor(tensor: Tensor) -> Option { + None + } + + #[cfg(feature = "torch-tensor")] + fn from_tensor(tensor: Tensor) -> Option { + use super::torch::*; + Python::with_gil(|py| { + let py_tensor = PyTensor(tensor.tensor); + let py_obj: Py = py_tensor.into_py(py); + let ext_tag: ExtTag = py_obj.into(); + Some(ext_tag) + }) + } +} + +impl From> for ExtTag { + fn from(tag: Py) -> Self { + Self { tag } + } +} + +impl From<&PyAny> for ExtTag { + fn from(tag: &PyAny) -> Self { + Self { tag: tag.into() } + } +} + +impl Into> for ExtTag { + fn into(self) -> Py { + self.tag + } +} + +pub trait ExtTagVec { + fn into_vec(self) -> Vec>; +} + +impl ExtTagVec for Vec { + fn into_vec(self) -> Vec> { + self.into_iter().map(|v| v.tag).collect() + } +} + +pub trait ExtTagOption { + fn into_option(self) -> Option>; +} + +impl ExtTagOption for Option { + fn into_option(self) -> Option> { + self.map(|v| v.tag) + } +} + +impl ExtTagOption for Option<&ExtTag> { + fn into_option(self) -> Option> { + self.map(|v| v.tag.clone()) + } +} diff --git a/etc/scallopy/src/foreign_attribute.rs b/etc/scallopy/src/foreign_attribute.rs new file mode 100644 index 0000000..f560fdd --- /dev/null +++ b/etc/scallopy/src/foreign_attribute.rs @@ -0,0 +1,73 @@ +use pyo3::types::*; +use pyo3::*; + +use scallop_core::compiler::front::attribute::*; +use 
scallop_core::compiler::front::*; + +use crate::foreign_predicate::PythonForeignPredicate; + +#[derive(Clone)] +pub struct PythonForeignAttribute { + py_attr: PyObject, + name: String, +} + +impl PythonForeignAttribute { + pub fn new(py_attr: PyObject) -> Self { + let name = Python::with_gil(|py| { + py_attr + .getattr(py, "name") + .expect("Cannot get foreign predicate name") + .extract(py) + .expect("Foreign predicate name cannot be extracted into String") + }); + + Self { py_attr, name } + } + + pub fn process_action(&self, py: Python, result: Py) -> AttributeAction { + let name: String = result.getattr(py, "name").unwrap().extract(py).unwrap(); + match name.as_str() { + "multiple" => { + let py_actions: Vec> = result.getattr(py, "actions").unwrap().extract(py).unwrap(); + let actions = py_actions + .into_iter() + .map(|py_action| self.process_action(py.clone(), py_action)) + .collect(); + AttributeAction::Multiple(actions) + } + "remove_item" => AttributeAction::RemoveItem, + "no_action" => AttributeAction::Nothing, + "error" => { + let msg: String = result.getattr(py, "msg").unwrap().extract(py).unwrap(); + AttributeAction::Error(msg) + } + "register_foreign_predicate" => { + let py_fp: Py = result.getattr(py, "foreign_predicate").unwrap().extract(py).unwrap(); + let fp = PythonForeignPredicate::new(py_fp); + AttributeAction::Context(Box::new(move |ctx| { + ctx + .register_foreign_predicate(fp) + .expect("Cannot register foreign predicate"); + })) + } + n => panic!("Unknown action `{}`", n), + } + } +} + +impl AttributeProcessor for PythonForeignAttribute { + fn name(&self) -> String { + self.name.clone() + } + + fn apply(&self, item: &ast::Item, attr: &ast::Attribute) -> Result { + Python::with_gil(|py| { + let item_py = pythonize::pythonize(py, item).unwrap(); + let attr_py = pythonize::pythonize(py, attr).unwrap(); + let args = PyTuple::new(py, vec![item_py, attr_py]); + let result = self.py_attr.call_method(py, "apply", args, None).unwrap(); + 
Ok(self.process_action(py, result)) + }) + } +} diff --git a/etc/scallopy/src/foreign_function.rs b/etc/scallopy/src/foreign_function.rs index 2020f5f..5427350 100644 --- a/etc/scallopy/src/foreign_function.rs +++ b/etc/scallopy/src/foreign_function.rs @@ -5,6 +5,7 @@ use scallop_core::common::foreign_function::*; use scallop_core::common::type_family::*; use scallop_core::common::value::*; use scallop_core::common::value_type::*; +use scallop_core::runtime::env::*; use super::tuple::*; @@ -80,7 +81,9 @@ impl ForeignFunction for PythonForeignFunction { .ff .getattr(py, "static_arg_types") .expect("Cannot get foreign function static arg types"); - let static_arg_types: &PyList = static_arg_types.downcast::(py).expect("Cannot cast into PyList"); + let static_arg_types: &PyList = static_arg_types + .downcast::(py) + .expect("Cannot cast into PyList"); static_arg_types.len() }) } @@ -91,7 +94,9 @@ impl ForeignFunction for PythonForeignFunction { .ff .getattr(py, "static_arg_types") .expect("Cannot get foreign function static arg types"); - let static_arg_types: &PyList = static_arg_types.downcast::(py).expect("Cannot cast into PyList"); + let static_arg_types: &PyList = static_arg_types + .downcast::(py) + .expect("Cannot cast into PyList"); let param_type: PyObject = static_arg_types .get_item(i) .expect("Cannot get i-th param") @@ -162,7 +167,7 @@ impl ForeignFunction for PythonForeignFunction { }) } - fn execute(&self, args: Vec) -> Option { + fn execute_with_env(&self, env: &RuntimeEnvironment, args: Vec) -> Option { let ty = self.infer_return_type(&args); // Actually run the function @@ -176,16 +181,22 @@ impl ForeignFunction for PythonForeignFunction { .expect("Cannot extract function"); // Construct the arguments - let args: Vec> = args.iter().map(to_python_value).collect(); + let args: Vec> = args.iter().map(|a| to_python_value(a, &env.into())).collect(); let args_tuple = PyTuple::new(py, args); // Invoke the function - let maybe_result = func.call1(py, 
args_tuple).ok(); + let maybe_result = match func.call1(py, args_tuple) { + Ok(result) => Some(result), + Err(err) => { + eprintln!("{}", err); + None + } + }; // Turn the result back to Scallop value if let Some(result) = maybe_result { let result: &PyAny = result.extract(py).expect(""); - from_python_value(result, &ty).ok() + from_python_value(result, &ty, &env.into()).ok() } else { None } diff --git a/etc/scallopy/src/foreign_predicate.rs b/etc/scallopy/src/foreign_predicate.rs index bc9658a..fc2440f 100644 --- a/etc/scallopy/src/foreign_predicate.rs +++ b/etc/scallopy/src/foreign_predicate.rs @@ -1,14 +1,15 @@ use pyo3::types::*; use pyo3::*; +use scallop_core::common::foreign_predicate::*; +use scallop_core::common::input_tag::*; use scallop_core::common::tuple_type::*; -use scallop_core::common::value_type::*; use scallop_core::common::value::*; -use scallop_core::common::input_tag::*; -use scallop_core::common::foreign_predicate::*; +use scallop_core::common::value_type::*; +use scallop_core::runtime::env::*; -use super::tuple::*; use super::tag::*; +use super::tuple::*; #[derive(Clone)] pub struct PythonForeignPredicate { @@ -21,8 +22,7 @@ pub struct PythonForeignPredicate { impl PythonForeignPredicate { pub fn new(fp: PyObject) -> Self { let name = Python::with_gil(|py| { - fp - .getattr(py, "name") + fp.getattr(py, "name") .expect("Cannot get foreign predicate name") .extract(py) .expect("Foreign predicate name cannot be extracted into String") @@ -37,10 +37,17 @@ impl PythonForeignPredicate { .expect("Cannot extract function into PyObject"); // Invoke the function - let py_types: Vec = func.call0(py).expect("Cannot call function").extract(py).expect("Cannot extract into PyList"); + let py_types: Vec = func + .call0(py) + .expect("Cannot call function") + .extract(py) + .expect("Cannot extract into PyList"); // Convert the Python types into Scallop types - py_types.into_iter().map(|py_type| py_param_type_to_fp_param_type(py_type, py)).collect() + py_types 
+ .into_iter() + .map(|py_type| py_param_type_to_fp_param_type(py_type, py)) + .collect() }); let num_bounded: usize = Python::with_gil(|py| { @@ -51,7 +58,11 @@ impl PythonForeignPredicate { .expect("Cannot extract function into PyObject"); // Invoke the function - func.call0(py).expect("Cannot call function").extract(py).expect("Cannot extract into usize") + func + .call0(py) + .expect("Cannot call function") + .extract(py) + .expect("Cannot extract into usize") }); Self { @@ -68,53 +79,71 @@ impl PythonForeignPredicate { } impl ForeignPredicate for PythonForeignPredicate { - fn name(&self) -> String { - self.name.clone() - } + fn name(&self) -> String { + self.name.clone() + } - fn arity(&self) -> usize { - self.types.len() - } + fn arity(&self) -> usize { + self.types.len() + } - fn argument_type(&self, i: usize) -> ValueType { - self.types[i].clone() - } + fn argument_type(&self, i: usize) -> ValueType { + self.types[i].clone() + } - fn num_bounded(&self) -> usize { - self.num_bounded - } + fn num_bounded(&self) -> usize { + self.num_bounded + } - fn evaluate(&self, bounded: &[Value]) -> Vec<(DynamicInputTag, Vec)> { - Python::with_gil(|py| { - // Construct the arguments - let args: Vec> = bounded.iter().map(to_python_value).collect(); - let args_tuple = PyTuple::new(py, args); - - // Invoke the function - let maybe_result = self.fp.call1(py, args_tuple).ok(); - - // Turn the result back to Scallop values - if let Some(result) = maybe_result { - let output_tuple_type = self.output_tuple_type(); - let elements: Vec<(&PyAny, &PyAny)> = result.extract(py).expect("Cannot extract into list of elements"); - let internal: Option> = elements - .into_iter() - .map(|(py_tag, py_tup)| { - let tag = from_python_input_tag(py_tag).ok()?; - let tuple = from_python_tuple(py_tup, &output_tuple_type).ok()?; - Some((tag, tuple.as_values())) - }) - .collect(); - if let Some(e) = internal { - e - } else { - vec![] - } + fn evaluate_with_env(&self, env: &RuntimeEnvironment, 
bounded: &[Value]) -> Vec<(DynamicInputTag, Vec)> { + Python::with_gil(|py| { + // Construct the arguments + let args: Vec> = bounded.iter().map(|v| to_python_value(v, &env.into())).collect(); + let args_tuple = PyTuple::new(py, args); + + // Invoke the function + let maybe_result = match self.fp.call1(py, args_tuple) { + Ok(result) => Some(result), + Err(err) => { + eprintln!("{}", err); + None + } + }; + + // Turn the result back to Scallop values + if let Some(result) = maybe_result { + let output_tuple_type = self.output_tuple_type(); + let elements: Vec<(&PyAny, &PyAny)> = result.extract(py).expect("Cannot extract into list of elements"); + let internal: Option> = elements + .into_iter() + .map(|(py_tag, py_tup)| { + let tag = match from_python_input_tag(py_tag) { + Ok(tag) => tag, + Err(err) => { + eprintln!("Error when parsing tag: {}", err); + return None; + } + }; + let tuple = match from_python_tuple(py_tup, &output_tuple_type, &env.into()) { + Ok(tuple) => tuple, + Err(err) => { + eprintln!("Error when parsing tuple: {}", err); + return None; + } + }; + Some((tag, tuple.as_values())) + }) + .collect(); + if let Some(e) = internal { + e } else { vec![] } - }) - } + } else { + vec![] + } + }) + } } fn py_param_type_to_fp_param_type(obj: PyObject, py: Python<'_>) -> ValueType { diff --git a/etc/scallopy/src/io.rs b/etc/scallopy/src/io.rs index 7c76243..9ba6f06 100644 --- a/etc/scallopy/src/io.rs +++ b/etc/scallopy/src/io.rs @@ -8,6 +8,8 @@ pub struct CSVFileOptions { pub deliminator: Option, pub has_header: bool, pub has_probability: bool, + pub keys: Option>, + pub fields: Option>, } impl CSVFileOptions { @@ -17,6 +19,8 @@ impl CSVFileOptions { deliminator: None, has_header: false, has_probability: false, + keys: None, + fields: None, } } } @@ -29,6 +33,12 @@ impl Into for CSVFileOptions { } kw_args.insert("has_header".to_string(), self.has_header.into()); kw_args.insert("has_probability".to_string(), self.has_probability.into()); + if let Some(keys) = 
self.keys { + kw_args.insert("keys".to_string(), keys.into()); + } + if let Some(fields) = self.fields { + kw_args.insert("fields".to_string(), fields.into()); + } // Get attribute Attribute { diff --git a/etc/scallopy/src/lib.rs b/etc/scallopy/src/lib.rs index 1deecc0..1c5556e 100644 --- a/etc/scallopy/src/lib.rs +++ b/etc/scallopy/src/lib.rs @@ -1,14 +1,21 @@ mod collection; +mod config; mod context; mod custom_tag; mod error; +mod external_tag; +mod foreign_attribute; mod foreign_function; mod foreign_predicate; mod io; mod provenance; +mod runtime; mod tag; mod tuple; +#[cfg(feature = "torch-tensor")] +mod torch; + use pyo3::prelude::*; use collection::*; @@ -16,8 +23,14 @@ use context::*; #[pymodule] fn scallopy(_py: Python, m: &PyModule) -> PyResult<()> { + // Configurations + m.add_function(wrap_pyfunction!(config::torch_tensor_enabled, m).unwrap())?; + + // Add classes m.add_class::()?; m.add_class::()?; m.add_class::()?; + + // Ok Ok(()) } diff --git a/etc/scallopy/src/provenance.rs b/etc/scallopy/src/provenance.rs index 27de2fd..d9b0cf3 100644 --- a/etc/scallopy/src/provenance.rs +++ b/etc/scallopy/src/provenance.rs @@ -6,11 +6,13 @@ use pyo3::types::*; use scallop_core::common::tuple::*; use scallop_core::common::tuple_type::*; use scallop_core::runtime::dynamic::*; +use scallop_core::runtime::env::*; use scallop_core::runtime::provenance::*; use scallop_core::utils::*; use super::collection::*; use super::custom_tag; +use super::external_tag::*; use super::tuple::*; /// The trait which all provenance contexts used in `scallopy` should implement @@ -18,14 +20,18 @@ use super::tuple::*; /// Contains the functions and default implementations for type conversion from and to python objects pub trait PythonProvenance: Provenance { /// Process a list of python facts while the tuples are typed with `tuple_type` - fn process_typed_py_facts(facts: &PyList, tuple_type: &TupleType) -> PyResult, Tuple)>> { + fn process_typed_py_facts( + facts: &PyList, + tuple_type: 
&TupleType, + env: &RuntimeEnvironment, + ) -> PyResult, Tuple)>> { let facts: Vec<&PyAny> = facts.extract()?; facts .into_iter() .map(|fact| { let (maybe_py_tag, py_tup) = Self::split_py_fact(fact)?; let tag = Self::process_optional_py_tag(maybe_py_tag)?; - let tup = from_python_tuple(py_tup, tuple_type)?; + let tup = from_python_tuple(py_tup, tuple_type, &env.into())?; Ok((tag, tup)) }) .collect::>>() @@ -166,10 +172,10 @@ impl PythonProvenance for top_bottom_k_clauses::TopBottomKClausesProvenance, ArcFamily> { +impl PythonProvenance for diff_min_max_prob::DiffMinMaxProbProvenance { fn process_py_tag(tag: &PyAny) -> PyResult> { let prob: f64 = tag.extract()?; - let tag: Py = tag.into(); + let tag: ExtTag = tag.into(); Ok(Some((prob, Some(tag)).into())) } @@ -182,17 +188,17 @@ impl PythonProvenance for diff_min_max_prob::DiffMinMaxProbProvenance, fn to_output_py_tag(tag: &Self::OutputTag) -> Py { match tag.2 { - 1 => Python::with_gil(|py| (1, tag.3.clone().unwrap()).to_object(py)), + 1 => Python::with_gil(|py| (1, tag.3.as_ref().into_option().unwrap()).to_object(py)), 0 => Python::with_gil(|py| (0, tag.0).to_object(py)), - _ => Python::with_gil(|py| (-1, tag.3.clone().unwrap()).to_object(py)), + _ => Python::with_gil(|py| (-1, tag.3.as_ref().into_option().unwrap()).to_object(py)), } } } -impl PythonProvenance for diff_add_mult_prob::DiffAddMultProbProvenance, ArcFamily> { +impl PythonProvenance for diff_add_mult_prob::DiffAddMultProbProvenance { fn process_py_tag(tag: &PyAny) -> PyResult> { let prob: f64 = tag.extract()?; - let tag: Py = tag.into(); + let tag: ExtTag = tag.into(); Ok(Some((prob, Some(tag)).into())) } @@ -208,10 +214,10 @@ impl PythonProvenance for diff_add_mult_prob::DiffAddMultProbProvenance, ArcFamily> { +impl PythonProvenance for diff_nand_mult_prob::DiffNandMultProbProvenance { fn process_py_tag(tag: &PyAny) -> PyResult> { let prob: f64 = tag.extract()?; - let tag: Py = tag.into(); + let tag: ExtTag = tag.into(); Ok(Some((prob, 
Some(tag)).into())) } @@ -227,10 +233,10 @@ impl PythonProvenance for diff_nand_mult_prob::DiffNandMultProbProvenance, ArcFamily> { +impl PythonProvenance for diff_max_mult_prob::DiffMaxMultProbProvenance { fn process_py_tag(tag: &PyAny) -> PyResult> { let prob: f64 = tag.extract()?; - let tag: Py = tag.into(); + let tag: ExtTag = tag.into(); Ok(Some((prob, Some(tag)).into())) } @@ -246,10 +252,10 @@ impl PythonProvenance for diff_max_mult_prob::DiffMaxMultProbProvenance, ArcFamily> { +impl PythonProvenance for diff_nand_min_prob::DiffNandMinProbProvenance { fn process_py_tag(tag: &PyAny) -> PyResult> { let prob: f64 = tag.extract()?; - let tag: Py = tag.into(); + let tag: ExtTag = tag.into(); Ok(Some((prob, Some(tag)).into())) } @@ -265,11 +271,11 @@ impl PythonProvenance for diff_nand_min_prob::DiffNandMinProbProvenance, ArcFamily> { +impl PythonProvenance for diff_sample_k_proofs::DiffSampleKProofsProvenance { fn process_py_tag(tag: &PyAny) -> PyResult> { let tag_disj_id: (&PyAny, Option) = tag.extract()?; if let Some(prob) = tag_disj_id.0.extract()? { - let tag: Py = tag_disj_id.0.into(); + let tag: ExtTag = tag_disj_id.0.into(); Ok(Some((prob, tag, tag_disj_id.1).into())) } else { Ok(None) @@ -288,11 +294,11 @@ impl PythonProvenance for diff_sample_k_proofs::DiffSampleKProofsProvenance, ArcFamily> { +impl PythonProvenance for diff_top_k_proofs::DiffTopKProofsProvenance { fn process_py_tag(tag: &PyAny) -> PyResult> { let tag_disj_id: (&PyAny, Option) = tag.extract()?; if let Some(prob) = tag_disj_id.0.extract()? 
{ - let tag: Py = tag_disj_id.0.into(); + let tag: ExtTag = tag_disj_id.0.into(); Ok(Some((prob, tag, tag_disj_id.1).into())) } else { Ok(None) @@ -311,34 +317,11 @@ impl PythonProvenance for diff_top_k_proofs::DiffTopKProofsProvenance, } } -impl PythonProvenance for diff_top_k_proofs_indiv::DiffTopKProofsIndivProvenance, ArcFamily> { - fn process_py_tag(tag: &PyAny) -> PyResult> { - let tag_disj_id: (&PyAny, Option) = tag.extract()?; - if let Some(prob) = tag_disj_id.0.extract()? { - let tag: Py = tag_disj_id.0.into(); - Ok(Some((prob, tag, tag_disj_id.1).into())) - } else { - Ok(None) - } - } - - fn to_collection_enum(col: Arc>, ctx: &Self) -> CollectionEnum { - CollectionEnum::DiffTopKProofsIndiv { - collection: col, - tags: ctx.storage.clone_rc(), - } - } - - fn to_output_py_tag(tag: &Self::OutputTag) -> Py { - Python::with_gil(|py| (tag.k, tag.proofs.clone()).to_object(py)) - } -} - -impl PythonProvenance for diff_top_bottom_k_clauses::DiffTopBottomKClausesProvenance, ArcFamily> { +impl PythonProvenance for diff_top_bottom_k_clauses::DiffTopBottomKClausesProvenance { fn process_py_tag(tag: &PyAny) -> PyResult> { let tag_disj_id: (&PyAny, Option) = tag.extract()?; if let Some(prob) = tag_disj_id.0.extract()? 
{ - let tag: Py = tag_disj_id.0.into(); + let tag: ExtTag = tag_disj_id.0.into(); Ok(Some((prob, tag, tag_disj_id.1).into())) } else { Ok(None) diff --git a/etc/scallopy/src/runtime.rs b/etc/scallopy/src/runtime.rs new file mode 100644 index 0000000..13b1069 --- /dev/null +++ b/etc/scallopy/src/runtime.rs @@ -0,0 +1,16 @@ +use scallop_core::runtime::env::*; + +#[derive(Clone)] +pub struct PythonRuntimeEnvironment { + pub symbol_registry: SymbolRegistry2, + pub tensor_registry: TensorRegistry2, +} + +impl<'a> From<&'a RuntimeEnvironment> for PythonRuntimeEnvironment { + fn from(env: &'a RuntimeEnvironment) -> Self { + Self { + symbol_registry: env.symbol_registry.clone(), + tensor_registry: env.tensor_registry.clone(), + } + } +} diff --git a/etc/scallopy/src/tag.rs b/etc/scallopy/src/tag.rs index c15c804..40ee8c7 100644 --- a/etc/scallopy/src/tag.rs +++ b/etc/scallopy/src/tag.rs @@ -25,12 +25,16 @@ enum PythonInputTag<'a> { } pub fn from_python_input_tag(tag: &PyAny) -> Result { - let py_input_tag: PythonInputTag = tag.extract()?; - match py_input_tag { - PythonInputTag::Bool(b) => Ok(DynamicInputTag::Bool(b)), - PythonInputTag::Exclusive(e) => Ok(DynamicInputTag::Exclusive(e)), - PythonInputTag::Float(f) => Ok(DynamicInputTag::Float(f)), - PythonInputTag::ExclusiveFloat(f, e) => Ok(DynamicInputTag::ExclusiveFloat(f, e)), - PythonInputTag::CatchAll(_) => Err(BindingError::InvalidInputTag), + let py_input_tag: Option = tag.extract()?; + if let Some(py_input_tag) = py_input_tag { + match py_input_tag { + PythonInputTag::Bool(b) => Ok(DynamicInputTag::Bool(b)), + PythonInputTag::Exclusive(e) => Ok(DynamicInputTag::Exclusive(e)), + PythonInputTag::Float(f) => Ok(DynamicInputTag::Float(f)), + PythonInputTag::ExclusiveFloat(f, e) => Ok(DynamicInputTag::ExclusiveFloat(f, e)), + PythonInputTag::CatchAll(_) => Err(BindingError::InvalidInputTag), + } + } else { + Ok(DynamicInputTag::None) } } diff --git a/etc/scallopy/src/torch.rs b/etc/scallopy/src/torch.rs new file mode 
100644 index 0000000..de8b2fb --- /dev/null +++ b/etc/scallopy/src/torch.rs @@ -0,0 +1,45 @@ +use pyo3::prelude::*; +use pyo3::{ + exceptions::{PyTypeError, PyValueError}, + AsPyPointer, +}; + +use tch; + +/// A wrapper of pytorch tensor +pub struct PyTensor(pub tch::Tensor); + +impl std::ops::Deref for PyTensor { + type Target = tch::Tensor; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +pub fn wrap_tch_err(err: tch::TchError) -> PyErr { + PyErr::new::(format!("{err:?}")) +} + +impl<'source> FromPyObject<'source> for PyTensor { + fn extract(ob: &'source PyAny) -> PyResult { + let ptr = ob.as_ptr() as *mut tch::python::CPyObject; + let tensor = unsafe { tch::Tensor::pyobject_unpack(ptr) }; + tensor + .map_err(wrap_tch_err)? + .ok_or_else(|| { + let type_ = ob.get_type(); + PyErr::new::(format!("expected a torch.Tensor, got {type_}")) + }) + .map(PyTensor) + } +} + +impl IntoPy for PyTensor { + fn into_py(self, py: Python<'_>) -> PyObject { + self.0.pyobject_wrap().map_or_else( + |_| py.None(), + |ptr| unsafe { PyObject::from_owned_ptr(py, ptr as *mut pyo3::ffi::PyObject) }, + ) + } +} diff --git a/etc/scallopy/src/tuple.rs b/etc/scallopy/src/tuple.rs index ad3cb0e..804c664 100644 --- a/etc/scallopy/src/tuple.rs +++ b/etc/scallopy/src/tuple.rs @@ -1,15 +1,58 @@ -// use std::rc::Rc; - -use pyo3::exceptions::{PyIndexError, PyTypeError}; -use pyo3::{prelude::*, types::PyTuple}; +use pyo3::exceptions::*; +use pyo3::prelude::*; +use pyo3::types::*; +use scallop_core::common::tensors::*; use scallop_core::common::tuple::Tuple; use scallop_core::common::tuple_type::TupleType; use scallop_core::common::value::Value; use scallop_core::common::value_type::ValueType; use scallop_core::utils; -pub fn from_python_tuple(v: &PyAny, ty: &TupleType) -> PyResult { +use super::runtime::*; + +pub fn to_python_tuple(tup: &Tuple, env: &PythonRuntimeEnvironment) -> Py { + match tup { + Tuple::Tuple(t) => Python::with_gil(|py| { + let values = t.iter().map(|t| 
to_python_tuple(t, env)).collect::>(); + PyTuple::new(py, values).into() + }), + Tuple::Value(v) => to_python_value(v, env), + } +} + +pub fn to_python_value(val: &Value, env: &PythonRuntimeEnvironment) -> Py { + use Value::*; + Python::with_gil(|py| match val { + I8(i) => i.to_object(py), + I16(i) => i.to_object(py), + I32(i) => i.to_object(py), + I64(i) => i.to_object(py), + I128(i) => i.to_object(py), + ISize(i) => i.to_object(py), + U8(i) => i.to_object(py), + U16(i) => i.to_object(py), + U32(i) => i.to_object(py), + U64(i) => i.to_object(py), + U128(i) => i.to_object(py), + USize(i) => i.to_object(py), + F32(f) => f.to_object(py), + F64(f) => f.to_object(py), + Char(c) => c.to_object(py), + Bool(b) => b.to_object(py), + Str(s) => s.to_object(py), + String(s) => s.to_object(py), + Symbol(s) => env.symbol_registry.get_symbol(*s).to_object(py), + SymbolString(s) => s.to_object(py), + DateTime(d) => d.to_string().to_object(py), + Duration(d) => d.to_string().to_object(py), + Entity(e) => e.to_object(py), + Tensor(t) => tensor_to_py_object(t.clone(), py), + TensorValue(v) => tensor_to_py_object(env.tensor_registry.eval(v), py), + }) +} + +pub fn from_python_tuple(v: &PyAny, ty: &TupleType, env: &PythonRuntimeEnvironment) -> PyResult { match ty { TupleType::Tuple(ts) => { let tup: &PyTuple = v.downcast()?; @@ -19,7 +62,7 @@ pub fn from_python_tuple(v: &PyAny, ty: &TupleType) -> PyResult { .enumerate() .map(|(i, t)| { let e = tup.get_item(i)?; - from_python_tuple(e, t) + from_python_tuple(e, t, env) }) .collect::>>()?; Ok(Tuple::Tuple(elems)) @@ -27,47 +70,11 @@ pub fn from_python_tuple(v: &PyAny, ty: &TupleType) -> PyResult { Err(PyIndexError::new_err("Invalid tuple size")) } } - TupleType::Value(t) => from_python_value(v, t).map(Tuple::Value), - } -} - -pub fn to_python_tuple(tup: &Tuple) -> Py { - match tup { - Tuple::Tuple(t) => { - Python::with_gil(|py| PyTuple::new(py, t.iter().map(to_python_tuple).collect::>()).into()) - } - Tuple::Value(v) => 
to_python_value(v), + TupleType::Value(t) => from_python_value(v, t, env).map(Tuple::Value), } } -pub fn to_python_value(val: &Value) -> Py { - use Value::*; - match val { - I8(i) => Python::with_gil(|py| i.to_object(py)), - I16(i) => Python::with_gil(|py| i.to_object(py)), - I32(i) => Python::with_gil(|py| i.to_object(py)), - I64(i) => Python::with_gil(|py| i.to_object(py)), - I128(i) => Python::with_gil(|py| i.to_object(py)), - ISize(i) => Python::with_gil(|py| i.to_object(py)), - U8(i) => Python::with_gil(|py| i.to_object(py)), - U16(i) => Python::with_gil(|py| i.to_object(py)), - U32(i) => Python::with_gil(|py| i.to_object(py)), - U64(i) => Python::with_gil(|py| i.to_object(py)), - U128(i) => Python::with_gil(|py| i.to_object(py)), - USize(i) => Python::with_gil(|py| i.to_object(py)), - F32(f) => Python::with_gil(|py| f.to_object(py)), - F64(f) => Python::with_gil(|py| f.to_object(py)), - Char(c) => Python::with_gil(|py| c.to_object(py)), - Bool(b) => Python::with_gil(|py| b.to_object(py)), - Str(s) => Python::with_gil(|py| s.to_object(py)), - String(s) => Python::with_gil(|py| s.to_object(py)), - // RcString(s) => Python::with_gil(|py| s.to_object(py)), - DateTime(d) => Python::with_gil(|py| d.to_string().to_object(py)), - Duration(d) => Python::with_gil(|py| d.to_string().to_object(py)), - } -} - -pub fn from_python_value(v: &PyAny, ty: &ValueType) -> PyResult { +pub fn from_python_value(v: &PyAny, ty: &ValueType, env: &PythonRuntimeEnvironment) -> PyResult { match ty { ValueType::I8 => Ok(Value::I8(v.extract()?)), ValueType::I16 => Ok(Value::I16(v.extract()?)), @@ -85,11 +92,13 @@ pub fn from_python_value(v: &PyAny, ty: &ValueType) -> PyResult { ValueType::F64 => Ok(Value::F64(v.extract()?)), ValueType::Char => Ok(Value::Char(v.extract()?)), ValueType::Bool => Ok(Value::Bool(v.extract()?)), - ValueType::Str => panic!(""), + ValueType::Str => panic!("[Internal Error] Cannot convert python value into static string"), ValueType::String => 
Ok(Value::String(v.extract()?)), - // ValueType::RcString => Ok(Tuple::Value(Value::RcString(Rc::new( - // v.extract::()?, - // )))), + ValueType::Symbol => { + let symbol_str: String = v.extract()?; + let id = env.symbol_registry.register(symbol_str); + Ok(Value::Symbol(id)) + } ValueType::DateTime => { let dt = utils::parse_date_time_string(v.extract()?).ok_or(PyTypeError::new_err("Cannot parse into DateTime"))?; Ok(Value::DateTime(dt)) @@ -98,5 +107,39 @@ pub fn from_python_value(v: &PyAny, ty: &ValueType) -> PyResult { let dt = utils::parse_duration_string(v.extract()?).ok_or(PyTypeError::new_err("Cannot parse into Duration"))?; Ok(Value::Duration(dt)) } + ValueType::Entity => Ok(Value::Entity(v.extract()?)), + ValueType::Tensor => tensor_from_py_object(v, env), } } + +#[cfg(feature = "torch-tensor")] +fn tensor_from_py_object(pyobj: &PyAny, env: &PythonRuntimeEnvironment) -> PyResult { + use super::torch::PyTensor; + let py_tensor: PyTensor = pyobj.extract()?; + let scl_tensor: Tensor = Tensor::new(py_tensor.0); + let symbol: TensorSymbol = env.tensor_registry.register(scl_tensor); + Ok(Value::TensorValue(symbol.into())) +} + +#[cfg(not(feature = "torch-tensor"))] +#[allow(unused)] +fn tensor_from_py_object(pyobj: &PyAny, env: &PythonRuntimeEnvironment) -> PyResult { + panic!( + "This `scallopy` version is not compiled with tensor support; consider adding `torch-tensor` flag when compiling" + ) +} + +#[cfg(feature = "torch-tensor")] +fn tensor_to_py_object(tensor: Tensor, py: Python<'_>) -> PyObject { + use super::torch::PyTensor; + let py_tensor = PyTensor(tensor.tensor); + py_tensor.into_py(py) +} + +#[cfg(not(feature = "torch-tensor"))] +#[allow(unused)] +fn tensor_to_py_object(tensor: Tensor, py: Python<'_>) -> PyObject { + panic!( + "This `scallopy` version is not compiled with tensor support; consider adding `torch-tensor` flag when compiling" + ) +} diff --git a/etc/scallopy/tests/basics.py b/etc/scallopy/tests/basics.py index f833353..186cf11 100644 
--- a/etc/scallopy/tests/basics.py +++ b/etc/scallopy/tests/basics.py @@ -2,7 +2,7 @@ import scallopy -class TestBasics(unittest.TestCase): +class BasicTests(unittest.TestCase): def test_edge_path(self): ctx = scallopy.ScallopContext() ctx.add_relation("edge", (int, int)) diff --git a/etc/scallopy/tests/dbio.py b/etc/scallopy/tests/dbio.py new file mode 100644 index 0000000..541cdbf --- /dev/null +++ b/etc/scallopy/tests/dbio.py @@ -0,0 +1,30 @@ +import unittest +import torch + +import scallopy + +class TestIO(unittest.TestCase): + def test_load_csv(self): + ctx = scallopy.Context() + ctx.add_relation("edge", (int, int), load_csv=scallopy.io.CSVFileOptions("core/res/testing/csv/edge.csv")) + ctx.run() + assert list(ctx.relation("edge")) == [(0, 1), (1, 2), (2, 3)] + + def test_load_csv_with_field(self): + ctx = scallopy.Context() + csv_file = scallopy.io.CSVFileOptions("core/res/testing/csv/student.csv", fields=["id", "name", "year"]) + ctx.add_relation("student", (int, str, int), load_csv=csv_file) + ctx.run() + assert list(ctx.relation("student")) == [(1, "alice", 2022), (2, "bob", 2023)] + + def test_load_csv_with_key_and_field(self): + ctx = scallopy.Context() + csv_file = scallopy.io.CSVFileOptions("core/res/testing/csv/student.csv", keys="id", fields=["name", "year"]) + ctx.add_relation("student", (int, scallopy.Symbol, str), load_csv=csv_file) + ctx.run() + assert list(ctx.relation("student")) == [ + (1, "name", "alice"), + (1, "year", "2022"), + (2, "name", "bob"), + (2, "year", "2023"), + ] diff --git a/etc/scallopy/tests/entity.py b/etc/scallopy/tests/entity.py new file mode 100644 index 0000000..4ac56c2 --- /dev/null +++ b/etc/scallopy/tests/entity.py @@ -0,0 +1,126 @@ +import unittest + +import scallopy + +class TestEntity(unittest.TestCase): + def test_entity_1(self): + ctx = scallopy.Context() + ctx.add_program("type Expr = Const(i32) | Add(Expr, Expr)") + ctx.add_relation("root", scallopy.Entity) + ctx.add_rule("eval(e, y) = case e is Const(y)") + 
ctx.add_rule("eval(e, y1 + y2) = case e is Add(e1, e2) and eval(e1, y1) and eval(e2, y2)") + ctx.add_rule("result(y) = root(e) and eval(e, y)") + ctx.add_entity("root", "Add(Const(5), Add(Const(3), Const(4)))") + ctx.run() + assert list(ctx.relation("result")) == [(12,)] + + def test_entity_2(self): + ctx = scallopy.Context() + ctx.add_program(""" + type Expr = Const(i32) | Add(Expr, Expr) + type root(expr: Expr) + rel eval(e, y) = case e is Const(y) + rel eval(e, y1 + y2) = case e is Add(e1, e2) and eval(e1, y1) and eval(e2, y2) + rel result(y) = root(e) and eval(e, y) + """) + ctx.add_entity("root", "Add(Const(5), Add(Const(3), Const(4)))") + ctx.run() + assert list(ctx.relation("result")) == [(12,)] + + def test_entity_constant_1(self): + ctx = scallopy.Context() + ctx.add_program("""type root(b: bool)""") + ctx.add_entity("root", True) + ctx.run() + assert list(ctx.relation("root")) == [(True,)] + + def test_entity_constant_2(self): + ctx = scallopy.Context() + ctx.add_program("type root(b: String)") + ctx.add_entity("root", "hello world") + ctx.run() + assert list(ctx.relation("root")) == [("hello world",)] + + def test_entity_constant_3(self): + ctx = scallopy.Context() + ctx.add_program("type root(b: i32)") + ctx.add_entity("root", 3) + ctx.run() + assert list(ctx.relation("root")) == [(3,)] + + def test_entity_tuple_1(self): + ctx = scallopy.Context() + ctx.add_program(""" + type Expr = Const(i32) | Add(Expr, Expr) + type root(id: i32, b: Expr) + rel eval(e, y) = case e is Const(y) + rel eval(e, y1 + y2) = case e is Add(e1, e2) and eval(e1, y1) and eval(e2, y2) + rel result(id, y) = root(id, e) and eval(e, y) + """) + ctx.add_entity("root", (1, "Add(Const(3), Const(5))")) + ctx.add_entity("root", (2, "Add(Const(6), Const(5))")) + ctx.run() + assert list(ctx.relation("result")) == [(1, 8), (2, 11)] + + @unittest.expectedFailure + def test_entity_compile_failure_1(self): + # Unexisted variant + ctx = scallopy.Context() + ctx.add_relation("root", 
scallopy.Entity) + ctx.add_entity("root", "Unexisted(5)") + + @unittest.expectedFailure + def test_entity_compile_failure_2(self): + # Arity mismatch + ctx = scallopy.Context() + ctx.add_program("type Expr = Const(i32) | Add(Expr, Expr)") + ctx.add_relation("root", scallopy.Entity) + ctx.add_entity("root", "Add(Const(5), Const(3), Const(7))") + + def test_entity_forward(self): + forward = scallopy.Module( + program=""" + type Expr = Const(i32) | Add(Expr, Expr) + type root(expr: Expr) + rel eval(e, y) = case e is Const(y) + rel eval(e, y1 + y2) = case e is Add(e1, e2) and eval(e1, y1) and eval(e2, y2) + rel result(y) = root(e) and eval(e, y) + """, + provenance="diffminmaxprob", + output_relation="result") + + results, _ = forward( + entities={ + "root": [ + ["Add(Const(5), Const(3))"], + ["Const(10)"], + ["Add(Add(Const(10), Const(2)), Const(2))"], + ] + } + ) + + assert set(results) == set([(8,), (10,), (14,)]) + + def test_entity_forward_2(self): + forward = scallopy.Module( + program=""" + type Expr = Const(i32) | Add(Expr, Expr) + type root(id: i32, expr: Expr) + rel eval(e, y) = case e is Const(y) + rel eval(e, y1 + y2) = case e is Add(e1, e2) and eval(e1, y1) and eval(e2, y2) + rel result(id, y) = root(id, e) and eval(e, y) + """, + provenance="diffminmaxprob", + output_relation="result") + + results, _ = forward( + entities={ + "root": [ + [(1, "Add(Const(5), Const(3))"), (2, "Const(3)")], + [(1, "Const(10)"), (2, "Add(Add(Const(1), Const(2)), Const(3))")], + [(1, "Add(Add(Const(10), Const(2)), Const(2))")], + ] + } + ) + + assert set(results) == set([(1, 8), (1, 10), (1, 14), (2, 3), (2, 6)]) diff --git a/etc/scallopy/tests/foreign_function.py b/etc/scallopy/tests/foreign_function.py index 61b2bdb..82da5e8 100644 --- a/etc/scallopy/tests/foreign_function.py +++ b/etc/scallopy/tests/foreign_function.py @@ -36,15 +36,15 @@ def test_foreign_string_index_of(self): # First create the foreign function @scallopy.foreign_function - def string_index_of(s1: str, s2: 
str) -> scallopy.usize: + def my_string_index_of(s1: str, s2: str) -> scallopy.usize: return s1.index(s2) # Then add the context ctx = scallopy.ScallopContext() - ctx.register_foreign_function(string_index_of) + ctx.register_foreign_function(my_string_index_of) ctx.add_relation("S", (str, str)) ctx.add_facts("S", [("hello world", "hello"), ("hello world", "world"), ("hello world", "42")]) - ctx.add_rule("R($string_index_of(a, b)) = S(a, b)") + ctx.add_rule("R($my_string_index_of(a, b)) = S(a, b)") ctx.run() self.assertEqual(list(ctx.relation("R")), [(0,), (6,)]) diff --git a/etc/scallopy/tests/input_mapping.py b/etc/scallopy/tests/input_mapping.py index d6dc401..cee17e2 100644 --- a/etc/scallopy/tests/input_mapping.py +++ b/etc/scallopy/tests/input_mapping.py @@ -164,6 +164,16 @@ def test_retain_k_2(self): def test_retain_k_3(self): _ = scallopy.InputMapping(range(10), retain_k=3, sample_dim=-10) + def test_categorical_retain_k_1(self): + im = scallopy.InputMapping(range(10), retain_k=3, sample_strategy="categorical") + r = im.process_tensor(torch.softmax(torch.randn((10,)), dim=0)) + assert len(r) <= 3 # less than or equal since this is categorical sampling + + def test_categorical_retain_k_2(self): + im = scallopy.InputMapping({0: range(10), 1: range(10)}, retain_k=3, sample_strategy="categorical") + r = im.process_tensor(torch.softmax(torch.randn((10, 10)), dim=1)) + assert len(r) <= 3 # less than or equal since this is categorical sampling + def test_mult_dim_retain_k_1(self): im = scallopy.InputMapping({0: range(5), 1: range(5)}, retain_k=3) r = im.process_tensor(torch.randn((5, 5))) @@ -179,6 +189,18 @@ def test_mult_dim_retain_k_3(self): r = im.process_tensor(torch.randn((5, 3))) assert len(r) == 6 + def test_mult_dim_categorical_retain_k_1(self): + im = scallopy.InputMapping({0: range(5), 1: range(3)}, retain_k=2, sample_dim=0, sample_strategy="categorical") + i = torch.nn.functional.softmax(torch.randn((5, 3)), dim=0) + r = im.process_tensor(i) + assert 
len(r) <= 2 * 3 # less than or equal since this is categorical sampling + + def test_mult_dim_categorical_retain_k_2(self): + im = scallopy.InputMapping({0: range(5), 1: range(3), 2: range(7)}, retain_k=2, sample_dim=0, sample_strategy="categorical") + i = torch.nn.functional.softmax(torch.randn((5, 3, 7)), dim=0) + r = im.process_tensor(i) + assert len(r) <= 2 * 3 * 7 # less than or equal since this is categorical sampling + def test_retain_threshold_1(self): im = scallopy.InputMapping(range(10), retain_threshold=0.5) t = torch.randn(10) diff --git a/etc/scallopy/tests/tensors.py b/etc/scallopy/tests/tensors.py new file mode 100644 index 0000000..4335e1b --- /dev/null +++ b/etc/scallopy/tests/tensors.py @@ -0,0 +1,108 @@ +import unittest + +import torch +import scallopy + +class TensorTests(unittest.TestCase): + @unittest.skipIf(not scallopy.torch_tensor_enabled(), "not supported in this scallopy version") + def test_tensor_1(self): + x, y = torch.randn(5), torch.randn(5) + s = x.dot(y) + + ctx = scallopy.Context() + ctx.add_relation("r", (int, scallopy.Tensor)) + ctx.add_facts("r", [(1, x), (2, y)]) + ctx.add_rule("y($dot(ta, tb)) = r(1, ta) and r(2, tb)") + + ctx.run() + + result = list(ctx.relation("y"))[0][0] + assert s == result + + @unittest.skipIf(not scallopy.torch_tensor_enabled(), "not supported in this scallopy version") + def test_tensor_2(self): + x, y = torch.randn(5), torch.randn(5) + gt_sum = x + y + + ctx = scallopy.Context() + ctx.add_relation("r", (int, scallopy.Tensor)) + ctx.add_facts("r", [(1, x), (2, y)]) + ctx.add_rule("y(ta + tb) = r(1, ta) and r(2, tb)") + ctx.run() + my_sum = list(ctx.relation("y"))[0][0] + + assert all(gt_sum == my_sum) + + @unittest.skipIf(not scallopy.torch_tensor_enabled(), "not supported in this scallopy version") + def test_tensor_3(self): + x = torch.randn(10) + y = torch.randn(10) + gt_sim = x.dot(y).sigmoid() + + ctx = scallopy.Context(provenance="difftopkproofs") + ctx.add_relation("embed", (int, 
scallopy.Tensor), non_probabilistic=True) + ctx.add_facts("embed", [(1, x), (2, y)]) + ctx.add_rule("similar(x, y) = embed(x, tx) and embed(y, ty) and x != y and soft_eq(tx, ty)") + ctx.run() + my_sim = list(ctx.relation("similar"))[0][0] + + assert gt_sim.item() == my_sim.item() + + @unittest.skipIf(not scallopy.torch_tensor_enabled(), "not supported in this scallopy version") + def test_tensor_backprop_4(self): + x = torch.randn(10, requires_grad=True) + y = torch.randn(10) + opt = torch.optim.Adam(params=[x], lr=0.1) + gt_initial_sim = x.dot(y).sigmoid() + + ctx = scallopy.Context(provenance="difftopkproofs") + ctx.add_relation("embed", (int, scallopy.Tensor), non_probabilistic=True) + ctx.add_facts("embed", [(1, x), (2, y)]) + ctx.add_rule("similar(x, y) = embed(x, tx) and embed(y, ty) and x != y and soft_eq(tx, ty)") + ctx.run() + my_initial_sim = list(ctx.relation("similar"))[0][0] + + assert gt_initial_sim.item() == my_initial_sim.item() + + # Derive a loss, backward, and step + l = torch.nn.functional.mse_loss(my_initial_sim, torch.tensor(1.0)) + l.backward() + opt.step() + + # New similarity + new_sim = x.dot(y).sigmoid() + assert new_sim > my_initial_sim + + @unittest.skipIf(not scallopy.torch_tensor_enabled(), "not supported in this scallopy version") + def test_tensor_forward_backprop_1(self): + batch_size = 16 + + x = torch.randn((batch_size, 10), requires_grad=True) + y = torch.randn((batch_size, 10), requires_grad=True) + opt = torch.optim.Adam(params=[x, y], lr=0.1) + + scl_module = scallopy.Module( + program=""" + type embedding_1(embed: Tensor) + type embedding_2(embed: Tensor) + rel similar() = embedding_1(t1) and embedding_2(t2) and soft_eq(t1, t2) + query similar + """, + non_probabilistic=["embedding_1", "embedding_2"], + output_relation="similar", + output_mapping=(), + dispatch="serial") + + def step() -> float: + result = scl_module(embedding_1=[[(x[i],)] for i in range(batch_size)], embedding_2=[[(y[i],)] for i in range(batch_size)]) + gt 
= torch.ones(batch_size) + l = torch.nn.functional.mse_loss(result, gt) + l.backward() + opt.step() + return l.item() + + curr_loss = step() + for i in range(4): + next_loss = step() + assert next_loss < curr_loss + curr_loss = next_loss diff --git a/etc/scallopy/tests/test.py b/etc/scallopy/tests/test.py index 3f3f25e..f67360e 100644 --- a/etc/scallopy/tests/test.py +++ b/etc/scallopy/tests/test.py @@ -3,10 +3,13 @@ from basics import * from configurations import * from convert import * +from dbio import * +from entity import * from failure import * from forward import * from foreign_function import * from input_mapping import * +from tensors import * if __name__ == '__main__': unittest.main() diff --git a/etc/sclc/Cargo.toml b/etc/sclc/Cargo.toml index 3f5d13b..19b661e 100644 --- a/etc/sclc/Cargo.toml +++ b/etc/sclc/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "sclc-core" -version = "0.1.9" +version = "0.2.0" authors = ["Ziyang Li "] edition = "2018" diff --git a/etc/sclc/src/bin/sclc.rs b/etc/sclc/src/bin/sclc.rs index f160fb3..a619d71 100644 --- a/etc/sclc/src/bin/sclc.rs +++ b/etc/sclc/src/bin/sclc.rs @@ -13,7 +13,7 @@ fn main() { Ok(ram) => ram, Err(errs) => { for err in errs { - println!("{}", err); + eprintln!("{}", err); } return; } diff --git a/etc/sclc/src/pylib.rs b/etc/sclc/src/pylib.rs index 71021a7..1d226a2 100644 --- a/etc/sclc/src/pylib.rs +++ b/etc/sclc/src/pylib.rs @@ -56,9 +56,9 @@ pub fn create_pylib( Ok(()) } else { - println!("[Compile Error]"); - println!("stdout: {}", std::str::from_utf8(&output.stdout).unwrap()); - println!("{}", std::str::from_utf8(&output.stderr).unwrap()); + eprintln!("[Compile Error]"); + eprintln!("stdout: {}", std::str::from_utf8(&output.stdout).unwrap()); + eprintln!("{}", std::str::from_utf8(&output.stderr).unwrap()); Ok(()) } } @@ -145,6 +145,7 @@ fn generate_pylib_code( use pyo3::exceptions::*; use pyo3::types::*; use rayon::prelude::*; + use scallop_core::common::tensors::*; use scallop_core::common::tuple::*; 
use scallop_core::common::tuple_type::*; use scallop_core::common::value::*; @@ -281,9 +282,9 @@ fn generate_context_code() -> TokenStream { Unit(StaticContext), MinMaxProb(StaticContext), AddMultProb(StaticContext), - DiffMinMaxProb(StaticContext, ArcFamily>>), - DiffTopKProofs(StaticContext, ArcFamily>>), - DiffTopBottomKClauses(StaticContext, ArcFamily>>), + DiffMinMaxProb(StaticContext>), + DiffTopKProofs(StaticContext>), + DiffTopBottomKClauses(StaticContext>), } #[pyclass(unsendable, name = "StaticContext")] @@ -312,8 +313,8 @@ fn generate_context_code() -> TokenStream { ContextEnum::MinMaxProb(_) => None, ContextEnum::AddMultProb(_) => None, ContextEnum::DiffMinMaxProb(_) => None, - ContextEnum::DiffTopKProofs(c) => Some(c.prov_ctx.input_tags()), - ContextEnum::DiffTopBottomKClauses(c) => Some(c.prov_ctx.input_tags()), + ContextEnum::DiffTopKProofs(c) => Some(c.prov_ctx.input_tags().into_vec()), + ContextEnum::DiffTopBottomKClauses(c) => Some(c.prov_ctx.input_tags().into_vec()), } } @@ -502,6 +503,13 @@ fn generate_helper_functions() -> TokenStream { Bool(b) => Python::with_gil(|py| b.to_object(py)), Str(s) => Python::with_gil(|py| s.to_object(py)), String(s) => Python::with_gil(|py| s.to_object(py)), + Symbol(_) => unimplemented!(), + SymbolString(_) => unimplemented!(), + DateTime(_) => unimplemented!(), + Duration(_) => unimplemented!(), + Entity(i) => Python::with_gil(|py| i.to_object(py)), + Tensor(_) => unimplemented!(), + TensorValue(_) => unimplemented!(), } } } @@ -544,10 +552,83 @@ fn generate_helper_functions() -> TokenStream { ValueType::Bool => Ok(Tuple::Value(Value::Bool(v.extract()?))), ValueType::Str => panic!("Static reference string cannot be used for Python binding"), ValueType::String => Ok(Tuple::Value(Value::String(v.extract()?))), + ValueType::Symbol => unimplemented!(), + ValueType::DateTime => unimplemented!(), + ValueType::Duration => unimplemented!(), + ValueType::Entity => Ok(Tuple::Value(Value::Entity(v.extract()?))), + 
ValueType::Tensor => unimplemented!(), }, } } + #[derive(Clone)] + pub struct ExtTag { + pub tag: Py, + } + + impl FromTensor for ExtTag { + #[allow(unused)] + #[cfg(not(feature = "torch-tensor"))] + fn from_tensor(tensor: Tensor) -> Option { + None + } + + #[cfg(feature = "torch-tensor")] + fn from_tensor(tensor: Tensor) -> Option { + use super::torch::*; + Python::with_gil(|py| { + let py_tensor = PyTensor(tensor.tensor); + let py_obj: Py = py_tensor.into_py(py); + let ext_tag: ExtTag = py_obj.into(); + Some(ext_tag) + }) + } + } + + impl From> for ExtTag { + fn from(tag: Py) -> Self { + Self { tag } + } + } + + impl From<&PyAny> for ExtTag { + fn from(tag: &PyAny) -> Self { + Self { tag: tag.into() } + } + } + + impl Into> for ExtTag { + fn into(self) -> Py { + self.tag + } + } + + pub trait ExtTagVec { + fn into_vec(self) -> Vec>; + } + + impl ExtTagVec for Vec { + fn into_vec(self) -> Vec> { + self.into_iter().map(|v| v.tag).collect() + } + } + + pub trait ExtTagOption { + fn into_option(self) -> Option>; + } + + impl ExtTagOption for Option { + fn into_option(self) -> Option> { + self.map(|v| v.tag) + } + } + + impl ExtTagOption for Option<&ExtTag> { + fn into_option(self) -> Option> { + self.map(|v| v.tag.clone()) + } + } + trait PythonProvenance: Provenance { /// Process a list of python facts while the tuples are typed with `tuple_type` fn process_typed_py_facts(facts: &PyList, tuple_type: &TupleType) -> Result, Tuple)>, BindingError> { @@ -616,7 +697,7 @@ fn generate_helper_functions() -> TokenStream { impl PythonProvenance for proofs::ProofsProvenance { fn process_py_tag(disj_id: &PyAny) -> Result, BindingError> { let disj_id: usize = disj_id.extract()?; - Ok(Some(proofs::ProofsInputTag::Exclusive(disj_id))) + Ok(Some(Exclusion::Exclusive(disj_id))) } fn to_output_py_tag(proofs: &Self::OutputTag) -> Py { @@ -681,26 +762,26 @@ fn generate_helper_functions() -> TokenStream { } } - impl PythonProvenance for diff_min_max_prob::DiffMinMaxProbProvenance, 
ArcFamily> { + impl PythonProvenance for diff_min_max_prob::DiffMinMaxProbProvenance { fn process_py_tag(tag: &PyAny) -> Result, BindingError> { let prob: f64 = tag.extract().map_err(BindingError::from)?; - let tag: Py = tag.into(); + let tag: ExtTag = tag.into(); Ok(Some((prob, Some(tag)).into())) } fn to_output_py_tag(tag: &Self::OutputTag) -> Py { match tag.2 { - 1 => Python::with_gil(|py| (1, tag.3.clone().unwrap()).to_object(py)), + 1 => Python::with_gil(|py| (1, tag.3.as_ref().into_option().unwrap()).to_object(py)), 0 => Python::with_gil(|py| (0, tag.0).to_object(py)), - _ => Python::with_gil(|py| (-1, tag.3.clone().unwrap()).to_object(py)), + _ => Python::with_gil(|py| (-1, tag.3.as_ref().into_option().unwrap()).to_object(py)), } } } - impl PythonProvenance for diff_add_mult_prob::DiffAddMultProbProvenance, ArcFamily> { + impl PythonProvenance for diff_add_mult_prob::DiffAddMultProbProvenance { fn process_py_tag(tag: &PyAny) -> Result, BindingError> { let prob: f64 = tag.extract().map_err(BindingError::from)?; - let tag: Py = tag.into(); + let tag: ExtTag = tag.into(); Ok(Some((prob, Some(tag)).into())) } @@ -709,14 +790,14 @@ fn generate_helper_functions() -> TokenStream { } fn get_input_tags(&self) -> Option>> { - Some(self.input_tags()) + Some(self.input_tags().into_vec()) } } - impl PythonProvenance for diff_nand_mult_prob::DiffNandMultProbProvenance, ArcFamily> { + impl PythonProvenance for diff_nand_mult_prob::DiffNandMultProbProvenance { fn process_py_tag(tag: &PyAny) -> Result, BindingError> { let prob: f64 = tag.extract().map_err(BindingError::from)?; - let tag: Py = tag.into(); + let tag: ExtTag = tag.into(); Ok(Some((prob, Some(tag)).into())) } @@ -725,14 +806,14 @@ fn generate_helper_functions() -> TokenStream { } fn get_input_tags(&self) -> Option>> { - Some(self.input_tags()) + Some(self.input_tags().into_vec()) } } - impl PythonProvenance for diff_max_mult_prob::DiffMaxMultProbProvenance, ArcFamily> { + impl PythonProvenance for 
diff_max_mult_prob::DiffMaxMultProbProvenance { fn process_py_tag(tag: &PyAny) -> Result, BindingError> { let prob: f64 = tag.extract().map_err(BindingError::from)?; - let tag: Py = tag.into(); + let tag: ExtTag = tag.into(); Ok(Some((prob, Some(tag)).into())) } @@ -741,14 +822,14 @@ fn generate_helper_functions() -> TokenStream { } fn get_input_tags(&self) -> Option>> { - Some(self.input_tags()) + Some(self.input_tags().into_vec()) } } - impl PythonProvenance for diff_nand_min_prob::DiffNandMinProbProvenance, ArcFamily> { + impl PythonProvenance for diff_nand_min_prob::DiffNandMinProbProvenance { fn process_py_tag(tag: &PyAny) -> Result, BindingError> { let prob: f64 = tag.extract()?; - let tag: Py = tag.into(); + let tag: ExtTag = tag.into(); Ok(Some((prob, Some(tag)).into())) } @@ -757,15 +838,15 @@ fn generate_helper_functions() -> TokenStream { } fn get_input_tags(&self) -> Option>> { - Some(self.input_tags()) + Some(self.input_tags().into_vec()) } } - impl PythonProvenance for diff_sample_k_proofs::DiffSampleKProofsProvenance, ArcFamily> { + impl PythonProvenance for diff_sample_k_proofs::DiffSampleKProofsProvenance { fn process_py_tag(tag: &PyAny) -> Result, BindingError> { let tag_disj_id: (&PyAny, Option) = tag.extract()?; if let Some(prob) = tag_disj_id.0.extract()? { - let tag: Py = tag_disj_id.0.into(); + let tag: ExtTag = tag_disj_id.0.into(); Ok(Some((prob, tag, tag_disj_id.1).into())) } else { Ok(None) @@ -777,15 +858,15 @@ fn generate_helper_functions() -> TokenStream { } fn get_input_tags(&self) -> Option>> { - Some(self.input_tags()) + Some(self.input_tags().into_vec()) } } - impl PythonProvenance for diff_top_k_proofs::DiffTopKProofsProvenance, ArcFamily> { + impl PythonProvenance for diff_top_k_proofs::DiffTopKProofsProvenance { fn process_py_tag(tag: &PyAny) -> Result, BindingError> { let tag_disj_id: (&PyAny, Option) = tag.extract()?; if let Some(prob) = tag_disj_id.0.extract()? 
{ - let tag: Py = tag_disj_id.0.into(); + let tag: ExtTag = tag_disj_id.0.into(); Ok(Some((prob, tag, tag_disj_id.1).into())) } else { Ok(None) @@ -797,35 +878,15 @@ fn generate_helper_functions() -> TokenStream { } fn get_input_tags(&self) -> Option>> { - Some(self.input_tags()) - } - } - - impl PythonProvenance for diff_top_k_proofs_indiv::DiffTopKProofsIndivProvenance, ArcFamily> { - fn process_py_tag(tag: &PyAny) -> Result, BindingError> { - let tag_disj_id: (&PyAny, Option) = tag.extract()?; - if let Some(prob) = tag_disj_id.0.extract()? { - let tag: Py = tag_disj_id.0.into(); - Ok(Some((prob, tag, tag_disj_id.1).into())) - } else { - Ok(None) - } - } - - fn to_output_py_tag(tag: &Self::OutputTag) -> Py { - Python::with_gil(|py| (tag.k, tag.proofs.clone()).to_object(py)) - } - - fn get_input_tags(&self) -> Option>> { - Some(self.input_tags()) + Some(self.input_tags().into_vec()) } } - impl PythonProvenance for diff_top_bottom_k_clauses::DiffTopBottomKClausesProvenance, ArcFamily> { + impl PythonProvenance for diff_top_bottom_k_clauses::DiffTopBottomKClausesProvenance { fn process_py_tag(tag: &PyAny) -> Result, BindingError> { let tag_disj_id: (&PyAny, Option) = tag.extract()?; if let Some(prob) = tag_disj_id.0.extract()? 
{ - let tag: Py = tag_disj_id.0.into(); + let tag: ExtTag = tag_disj_id.0.into(); Ok(Some((prob, tag, tag_disj_id.1).into())) } else { Ok(None) @@ -837,7 +898,7 @@ fn generate_helper_functions() -> TokenStream { } fn get_input_tags(&self) -> Option>> { - Some(self.input_tags()) + Some(self.input_tags().into_vec()) } } } diff --git a/etc/sclc/tests/common/check_exec.rs b/etc/sclc/tests/common/check_exec.rs index 6bbe1a6..c66287e 100644 --- a/etc/sclc/tests/common/check_exec.rs +++ b/etc/sclc/tests/common/check_exec.rs @@ -26,7 +26,7 @@ pub fn check_compile_exec_from_program_string(program_name: &str, program_string Ok(ram) => ram, Err(errs) => { for err in errs { - println!("{}", err); + eprintln!("{}", err); } return; } diff --git a/etc/sclc/tests/common/check_pylib.rs b/etc/sclc/tests/common/check_pylib.rs index 2a529e4..dceb88f 100644 --- a/etc/sclc/tests/common/check_pylib.rs +++ b/etc/sclc/tests/common/check_pylib.rs @@ -26,7 +26,7 @@ pub fn check_compile_pylib_from_program_string(program_name: &str, program_strin Ok(ram) => ram, Err(errs) => { for err in errs { - println!("{}", err); + eprintln!("{}", err); } return; } diff --git a/etc/scli/Cargo.toml b/etc/scli/Cargo.toml index afd5a58..5a84284 100644 --- a/etc/scli/Cargo.toml +++ b/etc/scli/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "scli" -version = "0.1.9" +version = "0.2.0" authors = ["Ziyang Li "] edition = "2018" diff --git a/etc/scli/src/main.rs b/etc/scli/src/main.rs index fb70f7c..1867338 100644 --- a/etc/scli/src/main.rs +++ b/etc/scli/src/main.rs @@ -167,9 +167,7 @@ fn main() -> Result<(), String> { let ctx = provenance::top_bottom_k_clauses::TopBottomKClausesProvenance::::new(opt.top_k); interpret(ctx, &opt.input, integrate_opt, predicate_set, monitor_options) } - _ => { - Err(format!("Unknown provenance semiring `{}`", opt.provenance)) - } + _ => Err(format!("Unknown provenance semiring `{}`", opt.provenance)), } } @@ -180,12 +178,14 @@ fn interpret( predicate_set: PredicateSet, monitor_options: 
MonitorOptions, ) -> Result<(), String> { - let mut interpret_ctx = match integrate::InterpretContext::<_, RcFamily>::new_from_file_with_options(file_name, prov, opt) { - Ok(ctx) => ctx, - Err(err) => { - return Err(format!("{}", err)); - } - }; + let mut interpret_ctx = + match integrate::InterpretContext::<_, RcFamily>::new_from_file_with_options(file_name, prov, opt) { + Ok(ctx) => ctx, + Err(err) => { + eprintln!("{}", err); + return Err(err.kind().to_string()); + } + }; // Check if we have any specified monitors, and run the program if !monitor_options.needs_monitor() { diff --git a/etc/scli/tests/integration.rs b/etc/scli/tests/integration.rs index 0d89819..7646f47 100644 --- a/etc/scli/tests/integration.rs +++ b/etc/scli/tests/integration.rs @@ -7,7 +7,8 @@ fn file_doesnt_exist_1() -> Result<(), Box> { let mut cmd = Command::cargo_bin("scli")?; cmd.arg("test/file/doesnt/exist"); - cmd.assert() + cmd + .assert() .failure() .stderr(predicate::str::contains("Cannot open file test/file/doesnt/exist")); diff --git a/etc/sclrepl/Cargo.toml b/etc/sclrepl/Cargo.toml index 2b1a778..9c74027 100644 --- a/etc/sclrepl/Cargo.toml +++ b/etc/sclrepl/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "sclrepl" -version = "0.1.9" +version = "0.2.0" authors = ["Ziyang Li "] edition = "2018" diff --git a/etc/sclrepl/src/main.rs b/etc/sclrepl/src/main.rs index b8fe75a..0868acd 100644 --- a/etc/sclrepl/src/main.rs +++ b/etc/sclrepl/src/main.rs @@ -67,7 +67,7 @@ fn main() -> std::io::Result<()> { } } -fn run(cmd_args: Options, mut ctx: C) -> std::io::Result<()> +fn run(cmd_args: Options, ctx: C) -> std::io::Result<()> where C: runtime::provenance::Provenance, { @@ -78,7 +78,7 @@ where // Compile context let options = compiler::CompileOptions::from(&cmd_args); let mut front_context = compiler::front::FrontContext::new(); - let mut runtime_env = runtime::env::RuntimeEnvironment::default(); + let runtime_env = runtime::env::RuntimeEnvironment::default(); let mut exec_context = 
runtime::dynamic::DynamicExecutionContext::<_, RcFamily>::new_with_options(runtime::dynamic::ExecutionOptions { incremental_maintain: true, @@ -112,7 +112,7 @@ where let queries = items .iter() .filter_map(|item| { - if let compiler::front::Item::QueryDecl(q) = item { + if let compiler::front::ast::Item::QueryDecl(q) = item { Some(q.query().create_relation_name()) } else { None @@ -160,7 +160,7 @@ where } // Interpret the ram - match exec_context.incremental_execute(ram, &mut runtime_env, &mut ctx) { + match exec_context.incremental_execute(ram, &runtime_env, &ctx) { Ok(()) => {} Err(e) => { println!("{:?}", e); @@ -169,7 +169,7 @@ where // Print the result for q in &queries { - exec_context.recover(q, &ctx); + exec_context.recover(q, &runtime_env, &ctx); println!("{}: {}", q, exec_context.relation(q).unwrap()); } } diff --git a/etc/vscode-scl/examples/syntax_test.scl b/etc/vscode-scl/examples/syntax_test.scl index a587934..65022b5 100644 --- a/etc/vscode-scl/examples/syntax_test.scl +++ b/etc/vscode-scl/examples/syntax_test.scl @@ -4,9 +4,7 @@ import "other_file.scl" type MyID <: i8 // Alias type and builtin type `&str` -type MyString <: &str type MyString2 <: String -type MyString3 <: Rc // Relation type declaration type digit(x: usize, y: i8) diff --git a/etc/vscode-scl/examples/syntax_test_2.scl b/etc/vscode-scl/examples/syntax_test_2.scl new file mode 100644 index 0000000..604c6a7 --- /dev/null +++ b/etc/vscode-scl/examples/syntax_test_2.scl @@ -0,0 +1,124 @@ +// This file contains syntax for demonstration of VSCode plugin +// The file itself does not compile + +// Import item +import "file.scl" + +// Alias type definition +type MyType = AliasType // This is an inline comment + +// Subtype definition +type SubType <: SuperType + +// Enum type definition; variables are separated by `|` +type EnumType = VAR_A | VAR_B | VAR_C + +// Enum type definition with assigned number +type EnumType2 = VAR_A = 0 | VAR_B = 3 | VAR_C + +// Algebraic data type definition +type 
AlgebraicDataType = Number(i32) | OtherVariant(MyType, String) + +// ADT definition with new lines +type Expr = Number(i32) + | Var(String) + | Add(Expr, Expr) + +// Relation type definition; arguments are named +type relation(arg1: Type1, arg2: Type2) + +// Relation type definition; arguments are not named +type relation2(i32, i32, String) + +// Attribute +@demand("bf") +type fib(x: i32, y: i32) + +// Multiple attributes; attribute without argument; attribute +@hidden +@my_attr("pos_arg_1", 3, kw_arg_name=5) +@file("my_file.csv", keys=["key_1", "key_2"]) +type my_table(i32, i32, i32) + +// Multiple relation types +type edge(i32, i32), path(a: i32, b: i32) + +// Simple constant definition +const MY_CONST = 3 + +// Entity definition +const MY_CONST_2 = Add(Const(3), Add(Const(5), Const(8))) + +// Multiple constant definitions +const MY_CONST_3 = 1, MY_CONST_4 = Add(Const(3), Add(Const(5))) + +// Multiple constant definitions +const MY_CONST_5 = 1, + MY_CONST_6 = Add(Const(3), Add(Const(5))) + +// Set of facts +rel relation = {(0, 1), (1, 2), (2, 3)} + +// Set of facts with tag +rel relation = {0.01::(0, 1), 0.05::(1, 2), 0.9::(2, 3)} + +// Set of facts with DateTime, Duration, and Symbol +rel relation = {d"1d2h", t"2023/01/01", s"column_name"} + +// One single fact +rel 0.05::relation(0, 1) + +// Multiple facts; entitys may be created inside +// Tags may be floating point or boolean or other constants +rel relation(new Const(), 1), + 0.55::relation(0, 2), + true::relation(0, 3), + relation(MY_CONST, 3, "135135") + +// Rules +rel path(a, b) = edge(a, b) or path(a, c) and edge(c, b) +rel path(a, b) :- edge(a, b) \/ path(a, c) /\ edge(c, b) + +// Rules with tag in the front +rel 0.09::path(a, b) = edge(a, b) + +// Rules with arithmetics inside +rel fib(x, a + b) = fib(x - 1, a) and fib(x - 2, b) + +// Rules with type conversion inside +rel to_string(x, a as String) = case x is Const(a) + +// Rules with if-then-else inside +@demand("bbf") +rel max(a, b, if a > b then a 
else b) + +// Rules with new entity inside +rel expr_to_string(e, $format("({} {} {})", op1_str, op_str, op2_str)) = + case e is Binary(op, op1, op2) and + expr_to_string(op1, op1_str) and expr_to_string(op2, op2_str) and op_to_string(op, op_str) + +// Rules with aggregation inside +rel num_students(n) = n := count(a: student(p)) +rel num_paths(n) = n := count(a, b: path(a, b)) +rel my_rule(x) = _ := count(a, b: path(a, b) where c: color(c)) +rel my_rule_2() = forall(a, b: path(a, b) and path(a, b) implies edge(a, b)) +rel sample_rule(x) = x := top<3>(a, b: path(a, b)) +rel sample_rule(x) = x := categorical<3>(a, b: path(a, b)) +rel sample_rule(x) = x := uniform<3>(a, b: path(a, b)) +rel sample_rule(x) = x := min[a](a, b: path(a, b)) +rel sample_rule(x) = _ := min[x](s: score(x, s)) + +// Nested aggregation +rel nested_agg(x) = x := count(x: m := max(x: relation(x))) + +// Disjunctive datalog +rel { assign(x, true); assign(x, false) } = var(x) + +// Relations with generic arguments +rel grid(x, y) = range(0, 5, x) and range(0, 5, y) + +// Query a relation +query relation + +// Query using an atom +query my_relation(a, a + b, $hash(1, 3, 5)) diff --git a/etc/vscode-scl/package.json b/etc/vscode-scl/package.json index 598bfa2..a59993f 100644 --- a/etc/vscode-scl/package.json +++ b/etc/vscode-scl/package.json @@ -2,7 +2,7 @@ "name": "scallop", "displayName": "Scallop Language Syntax Highlight", "description": "Scallop Language Support", - "version": "0.0.7", + "version": "0.0.8", "repository": { "type": "git", "url": "https://github.com/liby99/scallop-v2" diff --git a/etc/vscode-scl/syntaxes/scallop.tmLanguage.json b/etc/vscode-scl/syntaxes/scallop.tmLanguage.json index 3d294cc..515070d 100644 --- a/etc/vscode-scl/syntaxes/scallop.tmLanguage.json +++ b/etc/vscode-scl/syntaxes/scallop.tmLanguage.json @@ -1,37 +1,187 @@ { - "$schema": "https://raw.githubusercontent.com/martinring/tmlanguage/master/tmlanguage.json", - "name": "Scallop", + "$schema": 
"https://raw.githubusercontent.com/martinring/tmlanguage/master/tmlanguage.json", + "name": "Scallop", "scopeName": "source.scl", - "patterns": [ - { - "include": "#comment" + "patterns": [ + { "include": "#comment" }, + { "include": "#attribute" }, + { "include": "#item" }, + { "include": "#aggregation" }, + { "include": "#case_is" }, + { "include": "#keyword" }, + { "include": "#formula_operator" }, + { "include": "#expr_operator" }, + { "include": "#atom" }, + { "include": "#tag" }, + { "include": "#constants" }, + { "include": "#expr" } + ], + "repository": { + "attribute": { + "patterns": [ + { "include": "#attribute_with_arg" }, + { "include": "#simple_attribute" } + ] }, - { - "include": "#strings" + "simple_attribute": { + "match": "(@)([a-zA-Z][a-zA-Z0-9_]*)", + "captures": { + "1": { + "name": "punctuation.definition.decorator.scallop" + }, + "2": { + "name": "support.function.scallop" + } + } }, - { - "include": "#datetime" + "attribute_with_arg": { + "begin": "(@)([a-zA-Z][a-zA-Z0-9_]*)\\(", + "beginCaptures": { + "1": { + "name": "punctuation.definition.decorator.scallop" + }, + "2": { + "name": "support.function.scallop" + } + }, + "end": "\\)", + "patterns": [ + { "include": "#comment" }, + { "include": "#named_argument" }, + { "include": "#constants" } + ] }, - { - "include": "#duration" + "named_argument": { + "match": "([a-zA-Z][a-zA-Z0-9_]*)\\s*(=)", + "captures": { + "1": { + "name": "variable.parameter.scallop" + }, + "2": { + "name": "keyword.operator.scallop" + } + } }, - { - "include": "#chars" + "item": { + "patterns": [ + { "include": "#import" }, + { "include": "#type_decl" }, + { "include": "#constant_decl" }, + { "include": "#relation_decl" }, + { "include": "#query_decl" } + ] + }, + "import": { + "match": "\\b(import)\\b", + "captures": { + "1": { + "name": "keyword.control.scallop" + } + } + }, + "type_decl": { + "patterns": [ + { "include": "#adt_type_decl" }, + { "include": "#adt_variant" }, + { "include": "#enum_type_decl" }, + { 
"include": "#alias_type_decl" }, + { "include": "#sub_type_decl" }, + { "include": "#relation_type_decl" } + ] }, - { - "include": "#booleans" + "adt_type_decl": { + "begin": "(type)\\s+([a-zA-Z][a-zA-Z0-9_]*)\\s*(=)\\s*([a-zA-Z][a-zA-Z0-9_]*)\\s*\\(", + "end": "\\)", + "beginCaptures": { + "1": { + "name": "keyword.control.scallop" + }, + "2": { + "name": "storage.type.scallop" + }, + "3": { + "name": "keyword.operator.scallop" + }, + "4": { + "name": "entity.name.class.scallop" + } + }, + "patterns": [ + { "include": "#comment" }, + { "include": "#type" } + ] }, - { - "include": "#float" + "adt_variant": { + "begin": "(\\|)\\s*([a-zA-Z][a-zA-Z0-9_]*)\\s*\\(", + "end": "\\)", + "beginCaptures": { + "1": { + "name": "keyword.operator.scallop" + }, + "2": { + "name": "entity.name.class.scallop" + } + }, + "patterns": [ + { "include": "#comment" }, + { "include": "#type" } + ] }, - { - "include": "#integer" + "enum_type_decl": { + "match": "(type)\\s+([a-zA-Z][a-zA-Z0-9_]*)\\s*(=)\\s*([A-Z][A-Z0-9_]*)\\s*(=|\\|)", + "captures": { + "1": { + "name": "keyword.control.scallop" + }, + "2": { + "name": "storage.type.scallop" + }, + "3": { + "name": "keyword.operator.scallop" + }, + "4": { + "name": "constant.other.caps.scallop" + }, + "5": { + "name": "keyword.operator.scallop" + } + } }, - { - "include": "#builtin_types" + "alias_type_decl": { + "match": "(type)\\s+([a-zA-Z][a-zA-Z0-9_]*)\\s*(=)\\s*([a-zA-Z][a-zA-Z0-9_]*)", + "captures": { + "1": { + "name": "keyword.control.scallop" + }, + "2": { + "name": "storage.type.scallop" + }, + "3": { + "name": "keyword.operator.scallop" + }, + "4": { + "name": "storage.type.scallop" + } + } + }, + "sub_type_decl": { + "match": "(type)\\s+([a-zA-Z][a-zA-Z0-9_]*)\\s*(<:)\\s*([a-zA-Z][a-zA-Z0-9_]*)", + "captures": { + "1": { + "name": "keyword.control.scallop" + }, + "2": { + "name": "storage.type.scallop" + }, + "3": { + "name": "keyword.operator.scallop" + }, + "4": { + "name": "storage.type.scallop" + } + } }, - { - "comment": 
"Relation Type Declaration", + "relation_type_decl": { "begin": "(type)\\s+([a-zA-Z][a-zA-Z0-9_]*)\\s*\\(", "end": "\\)", "beginCaptures": { @@ -43,52 +193,71 @@ } }, "patterns": [ - { - "comment": "Relation Type Member Binding", - "match": "([a-zA-Z][a-zA-Z0-9_]*)\\s*:\\s*([a-zA-Z][a-zA-Z0-9_]*)", - "captures": { - "1": { - "name": "variable.parameter.scallop" - }, - "2": { - "name": "storage.type.core.scallop" - } - } - }, - { - "comment": "Relation Type Member", - "name": "storage.type.core.scallop", - "match": "(\\b[a-zA-Z][a-zA-Z0-9_]*\\b)" - } + { "include": "#comment" }, + { "include": "#relation_type_binding" }, + { "include": "#type" } ] }, - { - "comment": "Type Alias Declaration", - "match": "(type)\\s+([a-zA-Z][a-zA-Z0-9_]*)\\s+=", + "relation_type_binding": { + "match": "([a-zA-Z][a-zA-Z0-9_]*)\\s*:\\s*([a-zA-Z][a-zA-Z0-9_]*)", "captures": { "1": { - "name": "keyword.control.scallop" + "name": "variable.parameter.scallop" }, "2": { - "name": "storage.type.core.scallop" + "name": "storage.type.scallop" } } }, - { - "comment": "Sub Type Declaration", - "match": "(type)\\s+([a-zA-Z][a-zA-Z0-9_]*)\\s+<:", - "captures": { + "type": { + "patterns": [ + { "include": "#basic_type" }, + { "include": "#custom_type" } + ] + }, + "basic_type": { + "match": "(i8|i16|i32|i64|i128|isize|u8|u16|u32|u64|u128|usize|f32|f64|bool|char|&str|String|Symbol|DateTime|Duration|Entity)", + "name": "storage.type.scallop" + }, + "custom_type": { + "match": "([A-Z][a-zA-Z0-9_]*)", + "name": "storage.type.scallop" + }, + "constant_decl": { + "patterns": [ + { "include": "#constant_entity_decl" }, + { "include": "#basic_constant_decl" }, + { "include": "#continue_constant_entity_decl" } + ] + }, + "constant_entity_decl": { + "begin": "(const)\\s+([A-Z][A-Z0-9_]*)(\\s*:\\s*([A-Z][A-Z0-9_]*))?\\s*(=)\\s*([a-zA-Z][a-zA-Z0-9_]*)\\s*\\(", + "end": "\\)", + "beginCaptures": { "1": { "name": "keyword.control.scallop" }, "2": { - "name": "storage.type.core.scallop" + "name": 
"constant.other.caps.scallop" + }, + "4": { + "name": "storage.type.scallop" + }, + "5": { + "name": "keyword.operator.scallop" + }, + "6": { + "name": "entity.name.class.scallop" } - } + }, + "patterns": [ + { "include": "#comment" }, + { "include": "#entity" }, + { "include": "#constants" } + ] }, - { - "comment": "Constant Declaration with Type", - "match": "(const)\\s+([A-Z][A-Z0-9_]*)\\s*:\\s*([a-zA-Z][a-zA-Z0-9_]*)", + "basic_constant_decl": { + "match": "(const)\\s+([A-Z][A-Z0-9_]*)\\s*(=)", "captures": { "1": { "name": "keyword.control.scallop" @@ -97,142 +266,442 @@ "name": "constant.other.caps.scallop" }, "3": { - "name": "storage.type.core.scallop" + "name": "keyword.operator.scallop" } } }, - { - "include": "#keywords" + "continue_constant_entity_decl": { + "begin": "([A-Z][A-Z0-9_]*)\\s*(=)\\s*([a-zA-Z][a-zA-Z0-9_]*)\\s*\\(", + "end": "\\)", + "beginCaptures": { + "1": { + "name": "constant.other.caps.scallop" + }, + "2": { + "name": "keyword.operator.scallop" + }, + "3": { + "name": "entity.name.class.scallop" + } + }, + "patterns": [ + { "include": "#comment" }, + { "include": "#entity" }, + { "include": "#expr" } + ] + }, + "entity": { + "begin": "([a-zA-Z][a-zA-Z0-9_]*)\\s*\\(", + "end": "\\)", + "beginCaptures": { + "1": { + "name": "entity.name.class.scallop" + } + }, + "patterns": [ + { "include": "#comment" }, + { "include": "#entity" }, + { "include": "#expr" } + ] }, - { - "comment": "As Operator", - "match": "(\\b(as)\\b)", - "name": "keyword.other.scallop" + "relation_decl": { + "patterns": [ + { "include": "#fact_set_decl" } + ] }, - { - "comment": "Aggregator with Argument", - "match": "\\b(min|max)\\s*\\[", + "fact_set_decl": { + "begin": "(rel)\\s+([a-zA-Z][a-zA-Z0-9_]*)\\s*(=)\\s*\\{", + "beginCaptures": { + "1": { + "name": "keyword.control.scallop" + }, + "2": { + "name": "entity.name.function.scallop" + }, + "3": { + "name": "keyword.operator.scallop" + } + }, + "end": "\\}", + "patterns": [ + { "include": "#comment" }, + { 
"include": "#constants" }, + { "include": "#tag" } + ] + }, + "tag": { + "match": "(\\S+)(::)", "captures": { "1": { - "name": "keyword.other.scallop" + "name": "constant.numeric.scallop" + }, + "2": { + "name": "keyword.operator.scallop" } } }, - { - "comment": "Aggregator with sample count", - "match": "\\b(top)\\s*(<)\\s*(\\d+)\\s*(>)\\s*\\(", + "query_decl": { + "match": "(query)\\s+([a-zA-Z][a-zA-Z0-9_]*)\\n", "captures": { "1": { - "name": "keyword.other.scallop" + "name": "keyword.control.scallop" }, "2": { - "name": "punctuation.brackets.angle.scallop" + "name": "entity.name.function.scallop" + } + } + }, + "aggregation": { + "patterns": [ + { "include": "#simple_aggregation" }, + { "include": "#sample_aggregation" }, + { "include": "#argminmax_aggregation" }, + { "include": "#forall_exists_aggregation" } + ] + }, + "simple_aggregation": { + "begin": "([a-zA-Z_][a-zA-Z0-9_]*)(\\s*,\\s*([a-zA-Z_][a-zA-Z0-9_]*))*\\s*(:=|=)\\s*(count|sum|prod|min|max|exists|forall|unique)\\(([a-zA-Z][a-zA-Z0-9_]*)(\\s*,\\s*([a-zA-Z][a-zA-Z0-9_]*))*\\s*:", + "end": "\\)", + "beginCaptures": { + "1": { + "name": "variable.parameter.scallop" }, "3": { - "name": "constant.numeric.integer.decimal.scallop" + "name": "variable.parameter.scallop" }, "4": { - "name": "punctuation.brackets.angle.scallop" + "name": "keyword.operator.scallop" + }, + "5": { + "name": "keyword.other.scallop" + }, + "6": { + "name": "variable.parameter.scallop" + }, + "8": { + "name": "variable.parameter.scallop" } - } + }, + "patterns": [ + { "include": "#comment" }, + { "include": "#where_clause" }, + { "include": "#formula" } + ] }, - { - "comment": "Aggregator", - "match": "\\b(count|sum|prod|min|max|exists|forall|unique)\\s*\\(", - "captures": { + "sample_aggregation": { + "begin": "([a-zA-Z_][a-zA-Z0-9_]*)(\\s*,\\s*([a-zA-Z_][a-zA-Z0-9_]*))*\\s*(:=|=)\\s*(top|categorical|uniform)(<)(\\d+)(>)\\(([a-zA-Z][a-zA-Z0-9_]*)(\\s*,\\s*([a-zA-Z][a-zA-Z0-9_]*))*\\s*:", + "end": "\\)", + "beginCaptures": { "1": { 
+ "name": "variable.parameter.scallop" + }, + "3": { + "name": "variable.parameter.scallop" + }, + "4": { + "name": "keyword.operator.scallop" + }, + "5": { "name": "keyword.other.scallop" + }, + "6": { + "name": "keyword.operator.scallop" + }, + "7": { + "name": "constant.numeric.scallop" + }, + "8": { + "name": "keyword.operator.scallop" + }, + "9": { + "name": "variable.parameter.scallop" + }, + "11": { + "name": "variable.parameter.scallop" } - } + }, + "patterns": [ + { "include": "#comment" }, + { "include": "#where_clause" }, + { "include": "#formula" } + ] }, - { - "comment": "Foreign Function", - "match": "(\\$[a-zA-Z][a-zA-Z0-9_]*)\\s*\\(", - "captures": { + "argminmax_aggregation": { + "begin": "([a-zA-Z_][a-zA-Z0-9_]*)(\\s*,\\s*([a-zA-Z_][a-zA-Z0-9_]*))*\\s*(:=|=)\\s*(min|max)(\\[)([a-zA-Z][a-zA-Z0-9_]*)(\\s*,\\s*([a-zA-Z][a-zA-Z0-9_]*))*(\\])\\(([a-zA-Z][a-zA-Z0-9_]*)(\\s*,\\s*([a-zA-Z][a-zA-Z0-9_]*))*\\s*:", + "end": "\\)", + "beginCaptures": { "1": { - "name": "entity.name.tag.scallop" + "name": "variable.parameter.scallop" + }, + "3": { + "name": "variable.parameter.scallop" + }, + "4": { + "name": "keyword.operator.scallop" + }, + "5": { + "name": "keyword.other.scallop" + }, + "6": { + "name": "keyword.operator.scallop" + }, + "7": { + "name": "variable.parameter.scallop" + }, + "9": { + "name": "variable.parameter.scallop" + }, + "10": { + "name": "keyword.operator.scallop" + }, + "11": { + "name": "variable.parameter.scallop" + }, + "13": { + "name": "variable.parameter.scallop" } - } + }, + "patterns": [ + { "include": "#comment" }, + { "include": "#where_clause" }, + { "include": "#formula" } + ] + }, + "forall_exists_aggregation": { + "begin": "(forall|exists)\\(([a-zA-Z][a-zA-Z0-9_]*)(\\s*,\\s*([a-zA-Z][a-zA-Z0-9_]*))*\\s*:", + "end": "\\)", + "beginCaptures": { + "1": { + "name": "keyword.other.scallop" + }, + "2": { + "name": "variable.parameter.scallop" + }, + "4": { + "name": "variable.parameter.scallop" + } + }, + "patterns": [ + { 
"include": "#comment" }, + { "include": "#where_clause" }, + { "include": "#formula" } + ] }, - { - "comment": "Atomic Relation", - "match": "[^\\@]\\b((?!((count|sum|prod|min|max|exists|forall|unique|and|or|not)\\b))[a-zA-Z][a-zA-Z0-9_]*)\\s*\\(", + "where_clause": { + "match": "(where)\\s+([a-zA-Z][a-zA-Z0-9_]*)(\\s*,\\s*([a-zA-Z][a-zA-Z0-9_]*))*\\s*:", "captures": { "1": { - "name": "entity.name.function.scallop" + "name": "keyword.control.scallop" + }, + "2": { + "name": "variable.parameter.scallop" + }, + "4": { + "name": "variable.parameter.scallop" } } }, - { - "include": "#constant_set_relation" + "case_is": { + "begin": "(case)\\s+([a-zA-Z][a-zA-Z0-9_]*)\\s+(is)\\s+([a-zA-Z][a-zA-Z0-9_]*)\\s*\\(", + "end": "\\)", + "beginCaptures": { + "1": { + "name": "keyword.control.scallop" + }, + "2": { + "name": "variable.parameter.scallop" + }, + "3": { + "name": "keyword.control.scallop" + }, + "4": { + "name": "entity.name.class.scallop" + } + }, + "patterns": [ + { "include": "#comment" }, + { "include": "#entity" }, + { "include": "#expr" } + ] }, - { - "include": "#query_relation" + "formula": { + "patterns": [ + { "include": "#aggregation" }, + { "include": "#case_is" }, + { "include": "#atom" }, + { "include": "#tag" }, + { "include": "#formula_operator" }, + { "include": "#expr_operator" }, + { "include": "#constants" }, + { "include": "#expr" } + ] }, - { - "comment": "Attribute", - "name": "meta.attribute.scallop", - "begin": "\\@([a-zA-Z][a-zA-Z0-9_]*)\\(", + "atom": { + "patterns": [ + { "include": "#specialized_atom" }, + { "include": "#simple_atom" } + ] + }, + "simple_atom": { + "begin": "([a-zA-Z][a-zA-Z0-9_]*)\\s*\\(", "end": "\\)", + "beginCaptures": { + "1": { + "name": "entity.name.function.scallop" + } + }, "patterns": [ - { - "include": "#float" + { "include": "#comment" }, + { "include": "#relation_type_binding" }, + { "include": "#basic_type" }, + { "include": "#expr" } + ] + }, + "specialized_atom": { + "begin": 
"([a-zA-Z][a-zA-Z0-9_]*)\\s*(<)\\s*([a-zA-Z][a-zA-Z0-9_]*)(\\s*,\\s*([a-zA-Z][a-zA-Z0-9_]*))*\\s*(>)\\s*\\(", + "end": "\\)", + "beginCaptures": { + "1": { + "name": "entity.name.function.scallop" }, - { - "include": "#integer" + "2": { + "name": "punctuation.brackets.angle.scallop" }, - { - "include": "#booleans" + "3": { + "name": "storage.type.scallop" }, - { - "include": "#strings" + "5": { + "name": "storage.type.scallop" }, - { - "include": "#chars" + "6": { + "name": "punctuation.brackets.angle.scallop" + } + }, + "patterns": [ + { "include": "#comment" }, + { "include": "#relation_type_binding" }, + { "include": "#basic_type" }, + { "include": "#expr" } + ] + }, + "expr": { + "patterns": [ + { "include": "#foreign_function_expr" }, + { "include": "#new_entity_expr" }, + { "include": "#constants" }, + { "include": "#expr_operator" }, + { "include": "#variable" } + ] + }, + "foreign_function_expr": { + "begin": "(\\$[a-zA-Z][a-zA-Z0-9_]*)\\s*\\(", + "end": "\\)", + "beginCaptures": { + "1": { + "name": "entity.name.tag.scallop" + } + }, + "patterns": [ + { "include": "#comment" }, + { "include": "#expr" } + ] + }, + "new_entity_expr": { + "begin": "(new)\\s+([a-zA-Z][a-zA-Z0-9]*)\\s*\\(", + "end": "\\)", + "beginCaptures": { + "1": { + "name": "keyword.operator.new" + }, + "2": { + "name": "entity.name.class.scallop" + } + }, + "patterns": [ + { "include": "#comment" }, + { "include": "#expr" } + ] + }, + "variable": { + "match": "([a-zA-Z][a-zA-Z0-9_]*)", + "name": "variable.parameter.scallop" + }, + "keyword": { + "match": "\\b(import|type|const|rel|query)\\b", + "captures": { + "1": { + "name": "keyword.control.scallop" } + } + }, + "formula_operator": { + "patterns": [ + { "include": "#formula_operator_symbol" }, + { "include": "#formula_operator_keywords" } ] }, - { - "comment": "Logical operator", - "name": "keyword.control.scallop", - "match": "(\\b(and|or|not|implies)\\b)" + "formula_operator_symbol": { + "match": "(\\\\/|/\\\\|=>|:-|=|:=|\\|)", + 
"name": "keyword.operator.scallop" }, - { - "comment": "Logical operator", - "name": "keyword.operator.comparison.scallop", - "match": "(&&|\\|\\||==|!=|\\\\/|/\\\\)" + "formula_operator_keywords": { + "match": "\\b(and|or|not|implies|case|is|where)\\b", + "name": "keyword.control.scallop" }, - { - "comment": "Assignment operator", - "name": "keyword.operator.assignment.scallop", - "match": "(:-|=)" + "expr_operator": { + "patterns": [ + { "include": "#expr_operator_symbol" }, + { "include": "#expr_operator_keywords" }, + { "include": "#as_symbol" } + ] }, - { - "comment": "Arithmetic operator", - "name": "keyword.operator.arithmetic.scallop", - "match": "(!|\\+|-|/|\\*|%)" + "expr_operator_symbol": { + "match": "(&&|\\|\\||==|!=|!|\\+|-|/|\\*|%|<=|>=|<|>)", + "name": "keyword.operator.scallop" }, - { - "comment": "Other comparison operators", - "name": "keyword.operator.comparison.scallop", - "match": "(<=|>=|<|>)" + "expr_operator_keywords": { + "match": "\\b(if|then|else)\\b", + "name": "keyword.control.scallop" }, - { - "comment": "Constant Variable", - "match": "(\\b[A-Z][A-Z0-9_]*\\b)", - "name": "constant.other.caps.scallop" + "as_symbol": { + "match": "\\b(as)\\b", + "name": "keyword.other.scallop" }, - { - "comment": "Parameters", - "name": "variable.parameter.scallop", - "match": "(\\b[a-zA-Z][a-zA-Z0-9_]*\\b)" - } - ], - "repository": { - "strings": { - "name": "string.quoted.double.scallop", - "begin": "\"", - "end": "\"", + "constants": { + "patterns": [ + { "include": "#integer" }, + { "include": "#float" }, + { "include": "#boolean" }, + { "include": "#string" }, + { "include": "#char" }, + { "include": "#datetime" }, + { "include": "#duration" }, + { "include": "#symbol" }, + { "include": "#constant_var" } + ] + }, + "integer": { + "comment": "Integer literal (decimal)", + "name": "constant.numeric.integer.decimal.scallop", + "match": "\\b[0-9][0-9_]*\\b" + }, + "float": { + "comment": "Floating point literal (fraction)", + "name": 
"constant.numeric.float.scallop", + "match": "\\b[0-9][0-9_]*\\.[0-9][0-9_]*([eE][+-]?[0-9_]+)?\\b" + }, + "boolean": { + "patterns": [ + { + "name": "constant.language.scallop", + "match": "(\\b(true|false)\\b)" + } + ] + }, + "char": { + "name": "string.quoted.single.scallop", + "begin": "'", + "end": "'", "patterns": [ { "name": "constant.character.escape.scallop", @@ -240,9 +709,9 @@ } ] }, - "datetime": { + "string": { "name": "string.quoted.double.scallop", - "begin": "t\"", + "begin": "\"", "end": "\"", "patterns": [ { @@ -251,9 +720,14 @@ } ] }, - "duration": { + "datetime": { "name": "string.quoted.double.scallop", - "begin": "d\"", + "begin": "(t)\"", + "beginCaptures": { + "1": { + "name": "keyword.control.scallop" + } + }, "end": "\"", "patterns": [ { @@ -262,10 +736,15 @@ } ] }, - "chars": { - "name": "string.quoted.single.scallop", - "begin": "'", - "end": "'", + "duration": { + "name": "string.quoted.double.scallop", + "begin": "(d)\"", + "end": "\"", + "beginCaptures": { + "1": { + "name": "keyword.control.scallop" + } + }, "patterns": [ { "name": "constant.character.escape.scallop", @@ -273,60 +752,29 @@ } ] }, - "constant_set_relation": { - "comment": "Constant Set Relation", - "match": "\\b([a-zA-Z][a-zA-Z0-9_]*)\\s*(=|:-)\\s*\\{", - "captures": { - "1": { - "name": "entity.name.function.scallop" - }, - "2": { - "name": "keyword.operator.assignment.scallop" - } - } - }, - "query_relation": { - "comment": "Query relation", - "match": "\\b(query)\\s+([a-zA-Z][a-zA-Z0-9_]*)", - "captures": { + "symbol": { + "name": "string.quoted.double.scallop", + "begin": "(s)\"", + "end": "\"", + "beginCaptures": { "1": { "name": "keyword.control.scallop" - }, - "2": { - "name": "entity.name.function.scallop" } - } - }, - "builtin_types": { - "comment": "Built-in/core type", - "name": "storage.type.core.scallop", - "match": "\\b(i8|i16|i32|i64|i128|isize|u8|u16|u32|u64|u128|usize|f32|f64|char|bool|String)\\b|(\\&str|Rc\\)" - }, - "float": { - "comment": "Floating 
point literal (fraction)", - "name": "constant.numeric.float.scallop", - "match": "\\b[0-9][0-9_]*\\.[0-9][0-9_]*([eE][+-]?[0-9_]+)?\\b" - }, - "integer": { - "comment": "Integer literal (decimal)", - "name": "constant.numeric.integer.decimal.scallop", - "match": "\\b[0-9][0-9_]*\\b" - }, - "keywords": { + }, "patterns": [ { - "name": "keyword.control.scallop", - "match": "(\\b(import|type|const|rel|relation|if|then|else|where)\\b)" + "name": "constant.character.escape.scallop", + "match": "\\\\." } ] }, - "booleans": { - "patterns": [ - { - "name": "constant.language.scallop", - "match": "(\\b(true|false)\\b)" + "constant_var": { + "match": "\\b([A-Z][A-Z0-9_]*)\\b", + "captures": { + "1": { + "name": "constant.other.caps.scallop" } - ] + } }, "comment": { "patterns": [ @@ -388,5 +836,5 @@ } ] } - } + } } diff --git a/examples/datalog/count.scl b/examples/datalog/count.scl new file mode 100644 index 0000000..2333138 --- /dev/null +++ b/examples/datalog/count.scl @@ -0,0 +1,12 @@ +// There are three classes +rel classes = {0, 1, 2} + +// Each student is enrolled in a course (Math or CS) +rel enroll = { + ("tom", "CS"), ("jenny", "Math"), // Class 0 + ("alice", "CS"), ("bob", "CS"), // Class 1 + ("jerry", "Math"), ("john", "Math"), // Class 2 +} + +// Count how many student enrolls in CS course +rel count_enroll_cs(n) = n := count(s: enroll(s, "CS")) diff --git a/examples/datalog/equality_saturation.scl b/examples/datalog/equality_saturation.scl new file mode 100644 index 0000000..7372265 --- /dev/null +++ b/examples/datalog/equality_saturation.scl @@ -0,0 +1,42 @@ +// The language for simple symbolic arithmetic expression +type Expr = Const(i32) | Var(String) | Add(Expr, Expr) + +// The input to this module is a program +type input_program(program: Expr) + +// A relation `to_string` for visualizing +rel to_string(p, i as String) = case p is Const(i) +rel to_string(p, v) = case p is Var(v) +rel to_string(p, $format("({} + {})", s1, s2)) = case p is Add(p1, p2) and 
to_string(p1, s1) and to_string(p2, s2) + +// Relation for expression +rel expr(p) = case p is Const(_) or case p is Var(_) or case p is Add(_, _) + +// Basic definition of equivalency: it is identity and transitive +rel equivalent(p, p) = expr(p) +rel equivalent(p1, p3) = equivalent(p1, p2) and equivalent(p2, p3) + +// Definition of rewrite rules suggesting equivalence +rel equivalent(p, new Add(b, a)) = case p is Add(a, b) +rel equivalent(p1, new Add(a2, b2)) = case p1 is Add(a1, b1) and equivalent(a1, a2) and equivalent(b1, b2) +rel equivalent(p, new Add(a, new Add(b, c))) = case p is Add(Add(a, b), c) +rel equivalent(p, new Const(a + b)) = case p is Add(Const(a), Const(b)) +rel equivalent(p, p1) = case p is Add(p1, Const(0)) + +// Definition of weight on each type of construct +rel weight(p, 1) = case p is Const(_) +rel weight(p, 1) = case p is Var(_) +rel weight(p, w1 + w2 + 1) = case p is Add(p1, p2) and weight(p1, w1) and weight(p2, w2) + +// Equivalent program strings +rel equiv_programs(sp) = input_program(p) and equivalent(p, sp) + +// Find the best program (minimum weight) among all programs equivalent to p +rel best_program(p) = _ := min[p](w: equiv_programs(p) and weight(p, w)) +rel best_program_str(s) = best_program(best_prog) and to_string(best_prog, s) + +// ======================================== + +const MY_PROGRAM = Add(Add(Const(3), Var("a")), Const(-3)) +rel input_program(MY_PROGRAM) +query best_program_str diff --git a/examples/datalog/evaluate_formula.scl b/examples/datalog/evaluate_formula.scl index 122c38b..d90fb7f 100644 --- a/examples/datalog/evaluate_formula.scl +++ b/examples/datalog/evaluate_formula.scl @@ -1,22 +1,23 @@ // Inputs -type symbol(usize, String) +type symbol(usize, char) type length(usize) // Facts for lexing -rel digit = {"0", "1", "2", "3", "4", "5", "6", "7", "8", "9"} +rel digit = {'0', '1', '2', '3', '4', '5', '6', '7', '8', '9'} type term(value: f32, begin: usize, end: usize) rel term(x as f32, b, b + 1) = symbol(b, 
x) and digit(x) +rel term(n * 10 + x as f32, s, e + 1) = term(n, s, e) and symbol(e, x) and digit(x) type mult_div(value: f32, begin: usize, end: usize) rel mult_div(x, b, r) = term(x, b, r) -rel mult_div(x * y, b, e) = mult_div(x, b, m) and symbol(m, "*") and term(y, m + 1, e) -rel mult_div(x / y, b, e) = mult_div(x, b, m) and symbol(m, "/") and term(y, m + 1, e) +rel mult_div(x * y, b, e) = mult_div(x, b, m) and symbol(m, '*') and term(y, m + 1, e) +rel mult_div(x / y, b, e) = mult_div(x, b, m) and symbol(m, '/') and term(y, m + 1, e) type add_minus(value: f32, begin: usize, end: usize) rel add_minus(x, b, r) = mult_div(x, b, r) -rel add_minus(x + y, b, e) = add_minus(x, b, m) and symbol(m, "+") and mult_div(y, m + 1, e) -rel add_minus(x - y, b, e) = add_minus(x, b, m) and symbol(m, "-") and mult_div(y, m + 1, e) +rel add_minus(x + y, b, e) = add_minus(x, b, m) and symbol(m, '+') and mult_div(y, m + 1, e) +rel add_minus(x - y, b, e) = add_minus(x, b, m) and symbol(m, '-') and mult_div(y, m + 1, e) type result(value: f32) rel result(y) = add_minus(y, 0, l) and length(l) @@ -26,8 +27,7 @@ rel result(y) = add_minus(y, 0, l) and length(l) // Testing related type test_string(String) rel length($string_length(s)) = test_string(s) -rel symbol(0, $string_char_at(s, 0) as String) = test_string(s), $string_length(s) > 0 -rel symbol(i, $string_char_at(s, i) as String) = symbol(i - 1, _), test_string(s), $string_length(s) > i +rel symbol(i, c) = test_string(s), string_chars(s, i, c) rel test_string("123/24+1") query result diff --git a/examples/datalog/exists_blue_obj.scl b/examples/datalog/exists_blue_obj.scl index 44b76f5..a272652 100644 --- a/examples/datalog/exists_blue_obj.scl +++ b/examples/datalog/exists_blue_obj.scl @@ -6,8 +6,8 @@ rel color = {(0, "red"), (1, "green"), (2, "blue"), (3, "blue")} rel shape = {(0, "cube"), (1, "cylinder"), (2, "sphere"), (3, "cube")} // Is there a blue object? 
-rel exists_blue_obj(b) = b = exists(o: color(o, "blue")) +rel exists_blue_obj(b) = b := exists(o: color(o, "blue")) // For each shape, is there a blue object of that shape? -rel exists_blue_obj_of_shape(s, b) :- - b = exists(o: color(o, "blue"), shape(o, s) where s: all_shapes(s)) +rel exists_blue_obj_of_shape(s, b) = + b := exists(o: color(o, "blue"), shape(o, s) where s: all_shapes(s)) diff --git a/examples/datalog/lambda.scl b/examples/datalog/lambda.scl new file mode 100644 index 0000000..0720899 --- /dev/null +++ b/examples/datalog/lambda.scl @@ -0,0 +1,26 @@ +// Language of simply typed lambda calculus (STLC) +type Expr = ConstInt(i32) + | ConstBool(bool) + | Var(String) + | Let(String, Expr, Expr) + | IfThenElse(Expr, Expr, Expr) + | App(Expr, Expr) + | Lambda(String, Expr) + +// Pretty printing +type to_string(e: Expr, s: String) +rel to_string(e, i as String) = case e is ConstInt(i) +rel to_string(e, b as String) = case e is ConstBool(b) +rel to_string(e, s) = case e is Var(s) +rel to_string(e, $format("let {} = ({}) in ({})", x, s1, s2)) = case e is Let(x, e1, e2) and to_string(e1, s1) and to_string(e2, s2) +rel to_string(e, $format("if ({}) then ({}) else ({})", s1, s2, s3)) = case e is IfThenElse(e1, e2, e3) and to_string(e1, s1) and to_string(e2, s2) and to_string(e3, s3) +rel to_string(e, $format("({}) ({})", s1, s2)) = case e is App(e1, e2) and to_string(e1, s1) and to_string(e2, s2) +rel to_string(e, $format("(λ{}.{})", x, s1)) = case e is Lambda(x, e1) and to_string(e1, s1) + +// Define my program +const MY_PROGRAM = App(Lambda("a", Var("a")), ConstInt(1)) + +// Print the program +rel result(s) = to_string(MY_PROGRAM, s) + +query result diff --git a/examples/datalog/pacman_maze_example.scl b/examples/datalog/pacman_maze_example.scl index 241f464..b011117 100644 --- a/examples/datalog/pacman_maze_example.scl +++ b/examples/datalog/pacman_maze_example.scl @@ -5,10 +5,7 @@ type goal_position(x: usize, y: usize) type is_enemy(x: usize, y: usize) // 
Constants -const UP = 0 -const RIGHT = 1 -const DOWN = 2 -const LEFT = 3 +type Action = UP | RIGHT | DOWN | LEFT // Basic connectivity rel node(x, y) = grid_node(x, y), not is_enemy(x, y) diff --git a/examples/datalog/regex.scl b/examples/datalog/regex.scl index 3de7797..3fb65bf 100644 --- a/examples/datalog/regex.scl +++ b/examples/datalog/regex.scl @@ -1,47 +1,28 @@ // =========== REGEX =========== -type regex_char(id: usize, c: char) -type regex_concat(id: usize, left: usize, right: usize) -type regex_union(id: usize, left: usize, right: usize) -type regex_star(id: usize, child: usize) -type regex_root(id: usize) +type Regex = Char(char) | Concat(Regex, Regex) | Union(Regex, Regex) | Star(Regex) +type regex_root(regex: Regex) // Match a single char -rel matches_substr(expr, start, start + 1) :- regex_char(expr, c), char_at(start, c) +rel matches_substr(expr, i, i + 1) :- case expr is Char(c), input_string(s), string_chars(s, i, c) // Match a concatenation -rel matches_substr(expr, l, r) :- regex_concat(expr, le, re), matches_substr(le, l, m), matches_substr(re, m, r) +rel matches_substr(expr, l, r) :- case expr is Concat(le, re), matches_substr(le, l, m), matches_substr(re, m, r) // Match a union -rel matches_substr(expr, l, r) :- regex_union(expr, a, b), matches_substr(a, l, r) -rel matches_substr(expr, l, r) :- regex_union(expr, a, b), matches_substr(b, l, r) +rel matches_substr(expr, l, r) :- case expr is Union(a, _), matches_substr(a, l, r) +rel matches_substr(expr, l, r) :- case expr is Union(_, b), matches_substr(b, l, r) // Match a star -rel matches_substr(expr, i, i) :- regex_star(expr, _), range(0, l + 1, i), input_string(s), strlen(s, l) -rel matches_substr(expr, l, r) :- regex_star(expr, c), matches_substr(c, l, r) -rel matches_substr(expr, l, r) :- regex_star(expr, c), matches_substr(expr, l, m), matches_substr(c, m, r) +rel matches_substr(expr, i, i) :- case expr is Star(_), input_string(s), string_chars(s, i, _) +rel matches_substr(expr, l, r) :- 
case expr is Star(c), matches_substr(c, l, r) +rel matches_substr(expr, l, r) :- case expr is Star(c), matches_substr(expr, l, m), matches_substr(c, m, r) // Matches the whole string -rel matches(true) :- input_string(s), strlen(s, l), regex_root(e), matches_substr(e, 0, l) - -// =========== STRING =========== -type input_string(s: String) -rel char_at(i, $string_char_at(s, i)) :- input_string(s), strlen(s, l), range(0, l, i) - -// =========== HELPER =========== -@demand("bbf") -rel range(a, b, i) :- i == a -rel range(a, b, i) :- range(a, b, i - 1), i < b -@demand("bf") -rel strlen(s, i) :- i == $string_length(s) +rel matches() :- regex_root(e), input_string(s), matches_substr(e, 0, $string_length(s)) // =========== EXAMPLE =========== -rel regex_char(0, 'a') -rel regex_char(1, 'b') -rel regex_concat(2, 0, 1) -rel regex_concat(3, 2, 0) -rel regex_star(4, 1) -rel regex_concat(5, 3, 4) -rel regex_root(5) +const MY_REGEX = Concat(Concat(Concat(Char('a'), Char('b')), Char('a')), Star(Char('b'))) +rel regex_root(MY_REGEX) rel input_string("ababbbb") diff --git a/examples/datalog/sat.scl b/examples/datalog/sat.scl new file mode 100644 index 0000000..aa84223 --- /dev/null +++ b/examples/datalog/sat.scl @@ -0,0 +1,22 @@ +// Boolean formula language +type Formula = Var(String) + | Not(Formula) + | And(Formula, Formula) + | Or(Formula, Formula) + +// Each variable could be assigned either true or false, but not both +rel { assign(v, true); assign(v, false) } = case bf is Var(v) + +// Evaluation the formula to see if it is satisfiable +rel eval(bf, r) = case bf is Var(v) and assign(v, r) +rel eval(bf, !r) = case bf is Not(c) and eval(c, r) +rel eval(bf, lr && rr) = case bf is And(lbf, rbf) and eval(lbf, lr) and eval(rbf, rr) +rel eval(bf, lr || rr) = case bf is Or(lbf, rbf) and eval(lbf, lr) and eval(rbf, rr) + +// =============== + +// (A /\ ~A) \/ (B /\ ~B) +const MY_FORMULA = Or(And(Var("A"), Not(Var("A"))), And(Var("B"), Not(Var("B")))) + +// Query the evaluated result 
+query eval(MY_FORMULA, r) diff --git a/examples/datalog/type_inference.scl b/examples/datalog/type_inference.scl index af9970e..8053df5 100644 --- a/examples/datalog/type_inference.scl +++ b/examples/datalog/type_inference.scl @@ -1,52 +1,89 @@ -// EXP ::= let V = EXP in EXP -// | if EXP then EXP else EXP -// | X + Y | X - Y -// | X and Y | X or Y | not X -// | X == Y | X != Y | X < Y | X <= Y | X > Y | X >= Y - -// Basic syntax constructs -type number(usize, i32) -type boolean(usize, bool) -type variable(usize, String) -type bexp(usize, String, usize, usize) -type aexp(usize, String, usize, usize) -type let_in(usize, String, usize, usize) -type if_then_else(usize, usize, usize, usize) - -// Comparison operations -rel comparison_op = {"==", "!=", ">=", "<=", ">", "<"} -rel logical_op = {"&&", "||", "^"} -rel arith_op = {"+", "-", "*", "/"} - -// A program with each number 0-4 denoting their index -// let x = 3 in x == 4 -// -------------------0 -// -1 ------2 -// -3 -4 -rel let_in = {(0, "x", 1, 2)} -rel number = {(1, 3), (4, 4)} -rel bexp = {(2, "==", 3, 4)} -rel variable = {(3, "x")} - -// Type Inference: - -// - Base case -rel type_of(x, "bool") = boolean(x, _) -rel type_of(x, "int") = number(x, _) -rel type_of(x, t) = variable(x, v), env_type(x, v, t) -rel type_of(e, "bool") = bexp(e, op, x, y), comparison_op(op), type_of(x, "int"), type_of(y, "int") -rel type_of(e, "bool") = bexp(e, op, x, y), logical_op(op), type_of(x, "bool"), type_of(y, "bool") -rel type_of(e, "int") = aexp(e, op, x, y), arith_op(op), type_of(x, "int"), type_of(y, "int") -rel type_of(e, t) = let_in(e, v, b, c), env_type(c, v, tv), type_of(b, tv), type_of(c, t) -rel type_of(e, t) = if_then_else(e, x, y, z), type_of(x, "bool"), type_of(y, t), type_of(z, t) - -// - Environment variable type -rel env_type(x, v, t) = bexp(e, _, x, _), env_type(e, v, t) -rel env_type(y, v, t) = bexp(e, _, _, y), env_type(e, v, t) -rel env_type(x, v, t) = aexp(e, _, x, _), env_type(e, v, t) -rel env_type(y, v, t) 
= aexp(e, _, _, y), env_type(e, v, t) -rel env_type(z, v, t) = let_in(_, v, y, z), type_of(y, t) -rel env_type(z, v2, t) = let_in(x, v1, _, z), env_type(x, v2, t), v1 != v2 -rel env_type(x, v, t) = env_type(e, v, t), if_then_else(e, x, _, _) -rel env_type(y, v, t) = env_type(e, v, t), if_then_else(e, _, y, _) -rel env_type(z, v, t) = env_type(e, v, t), if_then_else(e, _, _, z) +type Op = EQ | NEQ | GEQ | LEQ | GT | LT | AND | OR | XOR | ADD | SUB | MUL | DIV | NEG | NOT + +type Expr = Number(i32) + | Boolean(bool) + | Variable(String) + | Binary(Op, Expr, Expr) + | Unary(Op, Expr) + | Let(String, Expr, Expr) + | Ite(Expr, Expr, Expr) + +type Type = BOOL | INT + +type input_program(expr: Expr) + +// ================= + +// Pretty printing of operators +rel op_to_string = { + (EQ, "=="), (NEQ, "!="), + (GEQ, ">="), (LEQ, "<="), (GT, ">"), (LT, "<"), + (AND, "&&"), (OR, "||"), (XOR, "^"), + (ADD, "+"), (SUB, "-"), (MUL, "*"), (DIV, "/"), + (NEG, "-"), (NOT, "!") +} + +// Pretty printing of type +rel ty_to_string = {(BOOL, "bool"), (INT, "int")} + +// Pretty printing of expressions +rel expr_to_string(e, x as String) = case e is Number(x) +rel expr_to_string(e, x as String) = case e is Boolean(x) +rel expr_to_string(e, x) = case e is Variable(x) +rel expr_to_string(e, $format("({} {} {})", op1_str, op_str, op2_str)) = case e is Binary(op, op1, op2) and expr_to_string(op1, op1_str) and expr_to_string(op2, op2_str) and op_to_string(op, op_str) +rel expr_to_string(e, $format("({}{})", op_str, op1_str)) = case e is Unary(op, op1) and expr_to_string(op1, op1_str) and op_to_string(op, op_str) +rel expr_to_string(e, $format("let {} = {} in {}", x, b_str, i_str)) = case e is Let(x, b, i) and expr_to_string(b, b_str) and expr_to_string(i, i_str) +rel expr_to_string(e, $format("if {} then {} else {}", cs, ts, es)) = case e is Ite(c, t, e) and expr_to_string(c, cs) and expr_to_string(t, ts) and expr_to_string(e, es) + +// ================= + +// Basic types of operators +rel 
eq_op = {EQ, NEQ} +rel comp_op = {GEQ, LEQ, GT, LT} +rel logical_op = {AND, OR, XOR} +rel arith_op = {ADD, SUB, MUL, DIV} +rel unary_arith_op = {NEG} +rel unary_logical_op = {NOT} + +// Typing environment +type Env = Empty() | Cons(String, Type, Env) +const EMPTY_ENV = Empty() + +// Find a variable stored in the typing environment +@demand("bbf") +rel find_type(env, var, ty) = case e is Cons(var, ty, _) +rel find_type(env, var, ty) = case e is Cons(vp, _, tl) and vp != var and find_type(tl, var, ty) + +// The type (`ty`) of an expression (`expr`) under an environment (`env`) +type type_of(env: Env, expr: Expr, ty: Type) + +// Typing rules +@demand("bbf") +rel type_of(env, e, BOOL) = case e is Boolean(_) +rel type_of(env, e, INT) = case e is Number(_) +rel type_of(env, e, ty) = case e is Variable(x) and find_type(env, x, ty) +rel type_of(env, e, BOOL) = case e is Binary(op, op1, op2) and eq_op(op) and type_of(env, op1, ty) and type_of(env, op2, ty) +rel type_of(env, e, BOOL) = case e is Binary(op, op1, op2) and comp_op(op) and type_of(env, op1, INT) and type_of(env, op2, INT) +rel type_of(env, e, BOOL) = case e is Binary(op, op1, op2) and logical_op(op) and type_of(env, op1, BOOL) and type_of(env, op2, BOOL) +rel type_of(env, e, INT) = case e is Binary(op, op1, op2) and arith_op(op) and type_of(env, op1, INT) and type_of(env, op2, INT) +rel type_of(env, e, BOOL) = case e is Unary(op, op1) and unary_logical_op(op) and type_of(env, op1, BOOL) +rel type_of(env, e, INT) = case e is Unary(op, op1) and unary_arith_op(op) and type_of(env, op1, INT) +rel type_of(env, e, ty_i) = to_infer_let_cons(env, e, sub_env, i) and type_of(sub_env, i, ty_i) +rel type_of(env, e, ty) = case e is Ite(c, t, e) and type_of(env, c, BOOL) and type_of(env, t, ty) and type_of(env, e, ty) + +// Helpers +@demand("bbff") +rel to_infer_let_cons(env, e, new Cons(x, ty_b, env), i) = case e is Let(x, b, i) and type_of(env, b, ty_b) + +// The result if the type of the input program +rel result(expr_str, 
ty_str) = input_program(p) and expr_to_string(p, expr_str) and type_of(EMPTY_ENV, p, ty) and ty_to_string(ty, ty_str) + +// ================= + +// let x = 3 in x == 4 +const PROGRAM = Let("x", Number(3), Binary(EQ, Variable("x"), Number(4))) + +// Input program is the `PROGRAM` +rel input_program(PROGRAM) + +query result diff --git a/lib/ram/src/language.rs b/lib/ram/src/language.rs index c401186..77429a9 100644 --- a/lib/ram/src/language.rs +++ b/lib/ram/src/language.rs @@ -18,8 +18,13 @@ define_language! { // Tuple operations "apply" = Apply([Id; 2]), - "cons" = Cons([Id; 2]), - "nil" = Nil, + "tuple-cons" = TupleCons([Id; 2]), + "tuple-nil" = TupleNil, + + // Indexing operations + "index" = Index(Id), + "index-cons" = IndexCons([Id; 2]), + "index-nil" = IndexNil, // Value operations "+" = Add([Id; 2]), @@ -44,9 +49,7 @@ fn var(s: &str) -> Var { } fn is_constant(_v: Var) -> impl Fn(&mut EGraph, Id, &Subst) -> bool { - move |_egraph, _, _subst| { - false - } + move |_egraph, _, _subst| false } /// All the rewrite rules for the language @@ -57,16 +60,14 @@ pub fn ram_rewrite_rules() -> Vec> { rw!("filter-true"; "(filter ?d true)" => "?d"), rw!("filter-false"; "(filter ?d false)" => "empty"), rw!("project-cascade"; "(project (project ?d ?a) ?b)" => "(project ?d (apply ?b ?a))"), - rw!("product-transpose"; "(product ?a ?b)" => "(project (product ?b ?a) (cons (cons 1 nil) (cons (cons 0 nil) nil))))"), - rw!("join-transpose"; "(join ?a ?b)" => "(project (join ?b ?a) (cons (cons 2 nil) (cons (cons 1 nil) nil))))"), - + rw!("product-transpose"; "(product ?a ?b)" => "(project (product ?b ?a) (tuple-cons (index 1) (tuple-cons (index-cons 0 index-nil) tuple-nil))))"), + rw!("join-transpose"; "(join ?a ?b)" => "(project (join ?b ?a) (tuple-cons (index 0) (tuple-cons (index 2) (tuple-cons (index 1) tuple-nil))))"), // Tuple level application rewrites - rw!("access-nil"; "(apply nil ?a)" => "?a"), - rw!("access-tuple-base"; "(apply (cons 0 ?x) (cons ?a ?b))" => "(apply ?x 
?a)"), - rw!("access-tuple-ind"; "(apply (cons ?n ?x) (cons ?a ?b))" => "(apply (cons (- ?n 1) ?x) ?b)"), - rw!("apply-tuple-nil"; "(apply nil ?a)" => "nil"), - rw!("apply-tuple-cons"; "(apply (cons ?a ?b) ?c)" => "(cons (apply ?a ?c) (apply ?b ?c))"), - + rw!("access-nil"; "(apply index-nil ?a)" => "?a"), + rw!("access-tuple-base"; "(apply (index-cons 0 ?x) (tuple-cons ?a ?b))" => "(apply ?x ?a)"), + rw!("access-tuple-ind"; "(apply (index-cons ?n ?x) (tuple-cons ?a ?b))" => "(apply (index-cons (- ?n 1) ?x) ?b)"), + rw!("apply-tuple-nil"; "(apply tuple-nil ?a)" => "tuple-nil"), + rw!("apply-tuple-cons"; "(apply (tuple-cons ?a ?b) ?c)" => "(tuple-cons (apply ?a ?c) (apply ?b ?c))"), // Expression level application rewrites rw!("apply-add"; "(apply (+ ?a ?b) ?t)" => "(+ (apply ?a ?t) (apply ?b ?t))"), rw!("apply-sub"; "(apply (- ?a ?b) ?t)" => "(- (apply ?a ?t) (apply ?b ?t))"), @@ -76,7 +77,6 @@ pub fn ram_rewrite_rules() -> Vec> { rw!("apply-or"; "(apply (|| ?a ?b) ?t)" => "(|| (apply ?a ?t) (apply ?b ?t))"), rw!("apply-not"; "(apply (! ?a) ?t)" => "(! (apply ?a ?t))"), rw!("apply-const"; "(apply ?e ?t)" => "?e" if is_constant(var("?e"))), - // Value level rewrites rw!("add-comm"; "(+ ?a ?b)" => "(+ ?b ?a)"), rw!("add-identity"; "(+ ?a 0)" => "?a"), @@ -91,8 +91,8 @@ pub fn ram_rewrite_rules() -> Vec> { rw!("not-true"; "(! true)" => "false"), rw!("not-false"; "(! false)" => "true"), rw!("not-not"; "(! (! 
?a))" => "?a"), - // Simple arithmetic rewrites for index calculations + rw!("index-desugar"; "(index ?x)" => "(index-cons ?x index-nil)"), rw!("dec-1"; "(- 1 1)" => "0"), rw!("dec-2"; "(- 2 1)" => "1"), rw!("dec-3"; "(- 3 1)" => "2"), @@ -108,7 +108,7 @@ impl CostFunction for RamCostFunction { fn cost(&mut self, enode: &Ram, mut costs: C) -> Self::Cost where - C: FnMut(Id) -> Self::Cost + C: FnMut(Id) -> Self::Cost, { let op_cost = match enode { Ram::Empty => 0, @@ -118,8 +118,11 @@ impl CostFunction for RamCostFunction { Ram::Join(_) => 100, Ram::Sorted(_) => 500, Ram::Apply(_) => 10, - Ram::Cons(_) => 0, - Ram::Nil => 0, + Ram::Index(_) => 10, + Ram::TupleCons(_) => 0, + Ram::TupleNil => 0, + Ram::IndexCons(_) => 0, + Ram::IndexNil => 0, _ => 1, }; enode.fold(op_cost, |sum, id| sum + costs(id)) diff --git a/lib/ram/src/lib.rs b/lib/ram/src/lib.rs index 791daf7..5094986 100644 --- a/lib/ram/src/lib.rs +++ b/lib/ram/src/lib.rs @@ -1,6 +1,6 @@ pub mod generic_tuple; -pub mod tuple_type; mod language; +pub mod tuple_type; pub mod value_type; pub use language::*; diff --git a/lib/ram/tests/test_optim.rs b/lib/ram/tests/test_optim.rs index b4c99d4..d9cb059 100644 --- a/lib/ram/tests/test_optim.rs +++ b/lib/ram/tests/test_optim.rs @@ -7,7 +7,10 @@ fn test_filter_cascade_1() { #[test] fn test_filter_cascade_2() { - assert_eq!(simplify("(filter (filter (filter ?d ?a) ?b) ?c)"), "(filter ?d (&& ?c (&& ?a ?b)))") + assert_eq!( + simplify("(filter (filter (filter ?d ?a) ?b) ?c)"), + "(filter ?d (&& ?c (&& ?a ?b)))" + ) } #[test] @@ -17,21 +20,26 @@ fn test_project_cascade_1() { #[test] fn test_project_cascade_2() { - assert_eq!(simplify(r#" + assert_eq!( + simplify( + r#" (project (project ?d (tuple-cons - (index-cons 1 index-nil) + (index 1) (tuple-cons - (index-cons 0 index-nil) + (index 0) tuple-nil ) ) ) (- - (index-cons 0 index-nil) - (index-cons 1 index-nil) + (index 0) + (index 1) ) - )"#), "(project ?d (- (index-cons 1 index-nil) (index-cons 0 index-nil)))") + )"# + 
), + simplify("(project ?d (- (index 1) (index 0)))") + ) } diff --git a/makefile b/makefile index bda871f..b219c56 100644 --- a/makefile +++ b/makefile @@ -10,6 +10,12 @@ install-sclc: install-sclrepl: cargo install --path etc/sclrepl +install-scallop-cli: install-scallopy + cd etc/scallop-cli; python setup.py install + +develop-scallop-cli: develop-scallopy + cd etc/scallop-cli; python setup.py install + install-scallopy: maturin build --release --manifest-path etc/scallopy/Cargo.toml --out target/wheels/current find target/wheels/current -name "*.whl" -print | xargs pip install --force-reinstall @@ -18,6 +24,45 @@ install-scallopy: develop-scallopy: cd etc/scallopy; maturin develop --release +# ============================================ +# === Scallopy with Torch on normal Device === + +install-scallopy-torch: + maturin build --release \ + --features "torch-tensor" \ + --manifest-path etc/scallopy/Cargo.toml \ + --out target/wheels/current \ + --config 'env.LIBTORCH_USE_PYTORCH = "1"' + find target/wheels/current -name "*.whl" -print | xargs pip install --force-reinstall + rm -rf target/wheels/current + +develop-scallopy-torch: + cd etc/scallopy; maturin develop --release --features "torch-tensor" --config 'env.LIBTORCH_USE_PYTORCH = "1"' + +# ================================================= +# === Scallopy with Torch on Apple M1/M2 Device === + +install-scallopy-torch-apple: + python3 scripts/link_torch_lib.py + maturin build --release \ + --features "torch-tensor" \ + --manifest-path etc/scallopy/Cargo.toml \ + --out target/wheels/current \ + --config 'env.LIBTORCH = "$(shell pwd)/.tmp/torch"' \ + --config 'env.DYLD_LIBRARY_PATH = "$(shell pwd)/.tmp/torch/lib"' + find target/wheels/current -name "*.whl" -print | xargs pip install --force-reinstall + rm -rf target/wheels/current + +develop-scallopy-torch-apple: + python3 scripts/link_torch_lib.py + cd etc/scallopy; maturin develop --release \ + --features "torch-tensor" \ + --config 'env.LIBTORCH = "$(shell 
pwd)/.tmp/torch"' \ + --config 'env.DYLD_LIBRARY_PATH = "$(shell pwd)/.tmp/torch/lib"' + +# ================================================= +# === Scallop WASM for Web Demo and Node === + wasm-demo: make -C etc/scallop-wasm @@ -40,6 +85,12 @@ clean: make -C etc/scallopy clean make -C etc/scallop-wasm clean +check: + cargo check --workspace + +check-plus: + cargo check --workspace --features "torch-tensor" + test: @echo "[Info] Performing cargo test..." @make test-cargo @@ -60,7 +111,10 @@ test-cargo: test-cargo-ignored: cargo test --workspace -- --ignored -test-scallopy: +test-scallopy: develop-scallopy + python3 etc/scallopy/tests/test.py + +test-scallopy-torch: develop-scallopy-torch python3 etc/scallopy/tests/test.py doc: diff --git a/scripts/get_current_version.py b/scripts/get_current_version.py new file mode 100644 index 0000000..6f05a56 --- /dev/null +++ b/scripts/get_current_version.py @@ -0,0 +1,4 @@ +import toml +file = open("core/Cargo.toml", "r") +parsed_toml = toml.loads(file.read()) +print(parsed_toml["package"]["version"]) diff --git a/scripts/link_torch_lib.py b/scripts/link_torch_lib.py new file mode 100644 index 0000000..ca21ab7 --- /dev/null +++ b/scripts/link_torch_lib.py @@ -0,0 +1,38 @@ +import os +import subprocess +import re +import argparse + +# Get an argument +parser = argparse.ArgumentParser() +parser.add_argument("--verbose", action="store_true") +args = parser.parse_args() + +# First get the python directory +output = subprocess.run(["which", "python"], stdout=subprocess.PIPE).stdout.decode() +python_dir = output.split("\n")[0] +if args.verbose: + print(f"Python Directory: {python_dir}") + +# Check python version +python_version = subprocess.run([python_dir, "--version"], stdout=subprocess.PIPE).stdout.decode() +capture = re.search('Python (\\d+).(\\d+).(\\d+)', python_version) +version = f"{capture[1]}.{capture[2]}" +if args.verbose: + print(f"Python Version: {version}") + +# Get the lib +lib_dir = 
os.path.abspath(os.path.join(python_dir, "..", "..", "lib", f"python{version}", "site-packages", "torch")) +if not os.path.exists(lib_dir): + print(f"[Error] Torch lib `{lib_dir}` does not exist") +if args.verbose: + print(f"Torch Lib: {lib_dir}") + +# Link the library to the current directory +tmp_dir = os.path.join(os.path.dirname(__file__), "..", ".tmp") +if not os.path.exists(tmp_dir): + os.mkdir(tmp_dir) +ln_dir = os.path.abspath(os.path.join(tmp_dir, "torch")) +if args.verbose: + print(f"To be linked location: {ln_dir}") +subprocess.run(["ln", "-sf", lib_dir, ln_dir])