diff --git a/core/src/compiler/front/analyzers/aggregation.rs b/core/src/compiler/front/analyzers/aggregation.rs index fd30528..8a1637b 100644 --- a/core/src/compiler/front/analyzers/aggregation.rs +++ b/core/src/compiler/front/analyzers/aggregation.rs @@ -43,6 +43,21 @@ impl NodeVisitor for AggregationAnalysis { } } } + + // Check the binding variables + if reduce.bindings().is_empty() { + match &reduce.operator().node { + ReduceOperatorNode::Exists + | ReduceOperatorNode::Forall + | ReduceOperatorNode::Unknown(_) => {} + r => { + self.errors.push(AggregationAnalysisError::EmptyBinding { + agg: r.to_string(), + loc: reduce.location().clone(), + }) + } + } + } } } @@ -51,6 +66,7 @@ pub enum AggregationAnalysisError { NonMinMaxAggregationHasArgument { op: ReduceOperator }, UnknownAggregator { agg: String, loc: Loc }, ForallBodyNotImplies { loc: Loc }, + EmptyBinding { agg: String, loc: Loc }, } impl FrontCompileErrorTrait for AggregationAnalysisError { @@ -76,6 +92,13 @@ impl FrontCompileErrorTrait for AggregationAnalysisError { loc.report(src) ) } + Self::EmptyBinding { agg, loc } => { + format!( + "the binding variables of `{}` aggregation cannot be empty\n{}", + agg, + loc.report(src), + ) + } } } } diff --git a/core/src/compiler/front/ast/formula.rs b/core/src/compiler/front/ast/formula.rs index 72bd2e2..0bba8ad 100644 --- a/core/src/compiler/front/ast/formula.rs +++ b/core/src/compiler/front/ast/formula.rs @@ -269,6 +269,24 @@ pub enum ReduceOperatorNode { Unknown(String), } +impl ReduceOperatorNode { + pub fn to_string(&self) -> String { + match self { + Self::Count => "count".to_string(), + Self::Sum => "sum".to_string(), + Self::Prod => "prod".to_string(), + Self::Min => "min".to_string(), + Self::Max => "max".to_string(), + Self::Exists => "exists".to_string(), + Self::Forall => "forall".to_string(), + Self::Unique => "unique".to_string(), + Self::TopK(k) => format!("top<{}>", k), + Self::CategoricalK(k) => format!("categorical<{}>", k), + Self::Unknown(_) => "unknown".to_string(), + } + } +} + /// A reduce opeartor, e.g. 
`count` pub type ReduceOperator = AstNode; @@ -312,19 +330,7 @@ impl ReduceOperator { } pub fn to_string(&self) -> String { - match &self.node { - ReduceOperatorNode::Count => "count".to_string(), - ReduceOperatorNode::Sum => "sum".to_string(), - ReduceOperatorNode::Prod => "prod".to_string(), - ReduceOperatorNode::Min => "min".to_string(), - ReduceOperatorNode::Max => "max".to_string(), - ReduceOperatorNode::Exists => "exists".to_string(), - ReduceOperatorNode::Forall => "forall".to_string(), - ReduceOperatorNode::Unique => "unique".to_string(), - ReduceOperatorNode::TopK(k) => format!("top<{}>", k), - ReduceOperatorNode::CategoricalK(k) => format!("categorical<{}>", k), - ReduceOperatorNode::Unknown(_) => "unknown".to_string(), - } + self.node.to_string() } } diff --git a/core/src/compiler/front/grammar.lalrpop b/core/src/compiler/front/grammar.lalrpop index c9c39fc..88dbaa6 100644 --- a/core/src/compiler/front/grammar.lalrpop +++ b/core/src/compiler/front/grammar.lalrpop @@ -583,6 +583,16 @@ ReduceGroupBy: (Vec, Box) = { ReduceAssignmentSymbol = { "=", ":=" } ReduceNode: ReduceNode = { + ReduceAssignmentSymbol "(" ")" => { + ReduceNode { + left: vs, + operator: op, + args: args, + bindings: vec![], + body: Box::new(f), + group_by: g, + } + }, ReduceAssignmentSymbol "(" > ":" ")" => { ReduceNode { left: vs, @@ -605,6 +615,14 @@ ForallExistsReduceOpNode: ReduceOperatorNode = { ForallExistsReduceOp = Spanned; ForallExistsReduceNode: ForallExistsReduceNode = { + "(" ")" => { + ForallExistsReduceNode { + operator: op, + bindings: vec![], + body: Box::new(f), + group_by: g, + } + }, "(" > ":" ")" => { ForallExistsReduceNode { operator: op, diff --git a/core/tests/compiler/errors.rs b/core/tests/compiler/errors.rs index 7a4120d..c9cc24e 100644 --- a/core/tests/compiler/errors.rs +++ b/core/tests/compiler/errors.rs @@ -128,3 +128,13 @@ fn bad_enum_type_decl() { |e| e.contains("has already been assigned"), ) } + +#[test] +fn bad_no_binding_agg_1() { + expect_front_compile_failure( + r#" + rel r() = x := count(edge(1, 3)) + "#, + |e| e.contains("binding variables of `count` aggregation cannot be empty"), + ) +} diff --git a/core/tests/integrate/basic.rs b/core/tests/integrate/basic.rs index fbaf6d0..0da2220 100644 --- a/core/tests/integrate/basic.rs +++ b/core/tests/integrate/basic.rs @@ -660,6 +660,22 @@ fn test_count_with_where_clause() { ) } +#[test] +fn test_exists_path_1() { + expect_interpret_multi_result( + r#" + rel edge = {(0, 1), (1, 2)} + rel path(x, y) = edge(x, y) or (path(x, z) and edge(z, y)) + rel result1(b) = b := exists(path(0, 2)) + rel result2(b) = b := exists(path(0, 3)) + "#, + vec![ + ("result1", vec![(true,)].into()), + ("result2", vec![(false,)].into()), + ], + ) +} + #[test] fn test_exists_with_where_clause() { expect_interpret_multi_result( diff --git a/etc/scallopy/Cargo.toml b/etc/scallopy/Cargo.toml index aad9f1b..7c1abcd 100644 --- a/etc/scallopy/Cargo.toml +++ b/etc/scallopy/Cargo.toml @@ -13,5 +13,5 @@ sclc-core = { path = "../sclc" } rayon = "1.5" [dependencies.pyo3] -version = "0.16.5" +version = "0.18.2" features = ["extension-module"] diff --git a/etc/scallopy/examples/sum_2_forward.py b/etc/scallopy/examples/sum_2_forward.py index fb9f56c..845d5c4 100644 --- a/etc/scallopy/examples/sum_2_forward.py +++ b/etc/scallopy/examples/sum_2_forward.py @@ -8,7 +8,7 @@ compute_sum_2 = scallopy.ScallopForwardFunction( program=sum_2_program, - provenance="diffminmaxprob", + provenance="diffaddmultprob2", input_mappings={"digit_a": list(range(10)), "digit_b": 
list(range(10))}, output_mappings={"sum_2": list(range(19))}) diff --git a/etc/scallopy/scallopy/context.py b/etc/scallopy/scallopy/context.py index 1694baf..81ff03d 100644 --- a/etc/scallopy/scallopy/context.py +++ b/etc/scallopy/scallopy/context.py @@ -577,6 +577,12 @@ def has_input_mapping(self, relation: str) -> bool: return True return False + def set_sample_topk_facts(self, relation: str, amount: int): + if relation in self._input_mappings: + self._input_mappings[relation].set_sample_topk_facts(amount) + else: + raise Exception(f"Unknown relation {relation}") + def requires_tag(self) -> bool: """ Returns whether the context requires facts to be associated with tags diff --git a/etc/scallopy/scallopy/forward.py b/etc/scallopy/scallopy/forward.py index 441510d..4975cf4 100644 --- a/etc/scallopy/scallopy/forward.py +++ b/etc/scallopy/scallopy/forward.py @@ -1,4 +1,5 @@ from typing import Dict, Union, List, Optional, Tuple, Any, Callable +import logging import os import sys import zipfile @@ -31,6 +32,7 @@ def __init__( early_discard: Optional[bool] = None, iter_limit: Optional[int] = None, retain_graph: bool = False, + retain_topk: Optional[Dict[str, int]] = None, jit: bool = False, jit_name: str = "", jit_recompile: bool = False, @@ -68,6 +70,11 @@ def __init__( for (relation, mapping) in input_mappings.items(): self.ctx.set_input_mapping(relation, mapping) + # Set the retain top-k + if retain_topk is not None: + for (relation, k) in retain_topk.items(): + self.ctx.set_sample_topk_facts(relation, k) + # Add input facts if specified if facts is not None: for (relation, elems) in facts.items(): @@ -130,6 +137,11 @@ def __init__( self.recompile = recompile self.fn_counter = self.FORWARD_FN_COUNTER + # Preprocess the dispatch + if self.ctx.provenance == "custom" and self.dispatch == "parallel": + logging.warning("custom provenance does not support parallel dispatch; falling back to serial dispatch. Consider creating the forward function using `dispatch=\"serial\"`.") + self.dispatch = "serial" + # Populate the output and output mapping fields self._process_output_mapping(output, output_mapping, output_mappings) for output_relation in self.outputs: diff --git a/etc/scallopy/scallopy/input_mapping.py b/etc/scallopy/scallopy/input_mapping.py index df5aeed..814dbea 100644 --- a/etc/scallopy/scallopy/input_mapping.py +++ b/etc/scallopy/scallopy/input_mapping.py @@ -68,6 +68,11 @@ def __init__( if not (0 <= self.sample_dim < self.dimension): raise Exception(f"Invalid sampling dimension {self.sample_dim}; total dimension is {self.dimension}") + def set_sample_topk_facts(self, amount: int): + self.retain_k = amount + self.sample_dim = None + self.sample_strategy = "top" + def __getitem__(self, index) -> Tuple: """Get the tuple of the input mapping from an index""" if self._kind == "dict": diff --git a/etc/scallopy/scallopy/provenance.py b/etc/scallopy/scallopy/provenance.py index 829d59c..4e8dff8 100644 --- a/etc/scallopy/scallopy/provenance.py +++ b/etc/scallopy/scallopy/provenance.py @@ -6,42 +6,48 @@ class ScallopProvenance: """ Base class for a provenance context. Any class implementing `ScallopProvenance` must override the following functions: - - `base` - `zero` - `one` - `add` - `mult` + - `negate` + - `saturate` """ - def base(self, info): + def tagging_fn(self, input_tag): """ - Given the base information, generate a tag for the tuple. - Base information is specified as `I1`, `I2`, ... in the following example: + Given the input tags, generate an internal tag for the tuple. 
+ For example, in the following code, internal tags are the I1, I2, ...: ``` python ctx.add_facts("RELA", [(I1, T1), (I2, T2), ...]) ``` - This `base` function should take in info like `I1` and return the base tag. + This `tagging_fn` function should take in input tags like `I1` and return an internal tag + + If not implemented, we assume the input tags are the internal tags. + If the input tag is not provided (`None`), we use the `one()` as the internal tag. """ - return info + if input_tag is None: return self.one() + else: return input_tag - def disjunction_base(self, infos): + def recover_fn(self, internal_tag): """ - Given a set of base informations associated with a set of tuples forming a - disjunction, return the list of tags associated with each of them. + Given an internal tag, recover the output tag. + + If not implemented, we assume the internal tags are the output tags. """ - return [self.base(i) for i in infos] + return internal_tag - def is_valid(self, tag): + def discard(self, tag): """ - Check if a given tag is valid. + Check if a given tag needs to be discarded. When a tag is invalid, the tuple associated will be removed during reasoning. The default implementation assumes every tag is valid. An example of an invalid tag: a probability tag of probability 0.0 """ - return True + return False def zero(self): """ @@ -67,6 +73,12 @@ def mult(self, t1, t2): """ raise Exception("Not implemented") + def negate(self, t): + """ + Perform semiring negation on a tag (`t`) + """ + raise Exception("Not implemented") + def aggregate_count(self, elems): """ Aggregate a count of the given elements @@ -111,12 +123,6 @@ def __init__(self): if not torch_importer.has_pytorch: raise Exception("PyTorch unavailable. You can use this semiring only with PyTorch") - def base(self, info: torch_importer.Tensor): - """ - If a torch tensor is provided then keep that tensor as the tag; otherwise we give it 1.0 - """ - return info if info is not None else self.one() - def zero(self): """ Zero tag is a floating point 0.0 (i.e. 0.0 probability being true) @@ -165,12 +171,6 @@ def __init__(self): if not torch_importer.has_pytorch: raise Exception("PyTorch unavailable. You can use this semiring only with PyTorch") - def base(self, info: torch_importer.Tensor): - """ - If a torch tensor is provided then keep that tensor as the tag; otherwise we give it 1.0 - """ - return info if info is not None else self.one() - def zero(self): """ Zero tag is a floating point 0.0 (i.e. 0.0 probability being true) @@ -219,12 +219,6 @@ def __init__(self): if not torch_importer.has_pytorch: raise Exception("PyTorch unavailable. You can use this semiring only with PyTorch") - def base(self, info: torch_importer.Tensor): - """ - If a torch tensor is provided then keep that tensor as the tag; otherwise we give it 1.0 - """ - return info if info is not None else self.one() - def zero(self): """ Zero tag is a floating point 0.0 (i.e. 
0.0 probability being true) diff --git a/etc/scallopy/src/context.rs b/etc/scallopy/src/context.rs index c73abd8..f93930b 100644 --- a/etc/scallopy/src/context.rs +++ b/etc/scallopy/src/context.rs @@ -83,7 +83,7 @@ impl Context { /// * `k` - an unsigned integer serving as the hyper-parameter for provenance such as `"topkproofs"` /// * `custom_provenance` - an optional python object serving as the provenance context #[new] - #[args(provenance = "\"unit\"", k = "3", custom_provenance = "None")] + #[pyo3(signature=(provenance="unit", k=3, custom_provenance=None))] fn new(provenance: &str, k: usize, custom_provenance: Option>) -> Result { // Check provenance type match provenance { @@ -316,7 +316,7 @@ impl Context { /// # ctx = Context::new("unit", 3, None).unwrap(); /// ctx.add_relation("atom(usize, usize)", None, None).unwrap(); /// ``` - #[args(load_csv = "None", demand = "None")] + #[pyo3(signature=(relation, load_csv=None, demand=None))] fn add_relation( &mut self, relation: &str, @@ -360,7 +360,7 @@ impl Context { } /// Add a rule - #[args(tag = "None", demand = "None")] + #[pyo3(signature=(rule, tag=None, demand=None))] fn add_rule(&mut self, rule: &str, tag: Option<&PyAny>, demand: Option) -> Result<(), BindingError> { // Attributes let mut attrs = Vec::new(); @@ -491,7 +491,7 @@ impl Context { /// Get the number of relations in this context. /// If `include_hidden` is set `true`, the result will also count the hidden relations - #[args(include_hidden = false)] + #[pyo3(signature=(include_hidden = false))] fn num_relations(&self, include_hidden: bool) -> usize { if include_hidden { match_context!(&self.ctx, c, c.num_all_relations()) @@ -502,7 +502,7 @@ impl Context { /// Get a list of relations in this context. /// If `include_hidden` is set `true`, the result will also include the hidden relations - #[args(include_hidden = false)] + #[pyo3(signature=(include_hidden = false))] fn relations(&self, include_hidden: bool) -> Vec { if include_hidden { match_context!(&self.ctx, c, c.all_relations()) diff --git a/etc/scallopy/src/custom_tag.rs b/etc/scallopy/src/custom_tag.rs index 6f362e9..c9e5743 100644 --- a/etc/scallopy/src/custom_tag.rs +++ b/etc/scallopy/src/custom_tag.rs @@ -1,10 +1,12 @@ use pyo3::{Py, PyAny, Python}; use scallop_core::runtime::provenance; +/// The custom tag which holds an arbitrary python object. #[derive(Clone, Debug)] pub struct CustomTag(pub Py); impl CustomTag { + /// Create a new custom tag with a python object. 
pub fn new(tag: Py) -> Self { Self(tag) } @@ -16,8 +18,10 @@ impl std::fmt::Display for CustomTag { } } +/// Custom tag is a tag impl provenance::Tag for CustomTag {} +/// The custom provenance which is a wrapper of a python class #[derive(Clone, Debug)] pub struct CustomProvenance(pub Py); @@ -32,13 +36,14 @@ impl provenance::Provenance for CustomProvenance { "scallopy-custom" } + /// Invoking the provenance's tagging function on the input tag fn tagging_fn(&self, i: Self::InputTag) -> Self::Tag { Python::with_gil(|py| { - let result = self.0.call_method(py, "tagging_fn", (i,), None).unwrap(); - Self::Tag::new(result) + Self::Tag::new(self.0.call_method(py, "tagging_fn", (i,), None).unwrap()) }) } + /// Invoking the provenance's recover function on an internal tag fn recover_fn(&self, t: &Self::Tag) -> Self::OutputTag { Python::with_gil(|py| { self @@ -50,6 +55,7 @@ impl provenance::Provenance for CustomProvenance { }) } + /// Invoking the provenance's discard function on an internal tag fn discard(&self, t: &Self::Tag) -> bool { Python::with_gil(|py| { self diff --git a/etc/scallopy/src/foreign_function.rs b/etc/scallopy/src/foreign_function.rs index 8cc7954..2020f5f 100644 --- a/etc/scallopy/src/foreign_function.rs +++ b/etc/scallopy/src/foreign_function.rs @@ -42,7 +42,7 @@ impl ForeignFunction for PythonForeignFunction { .getattr(py, "generic_type_params") .expect("Cannot get foreign function generic type parameters"); let generic_type_params: &PyList = generic_type_params - .cast_as::(py) + .downcast::(py) .expect("Cannot cast into PyList"); generic_type_params.len() }) @@ -55,7 +55,7 @@ impl ForeignFunction for PythonForeignFunction { .getattr(py, "generic_type_params") .expect("Cannot get foreign function generic type parameters"); let generic_type_params: &PyList = generic_type_params - .cast_as::(py) + .downcast::(py) .expect("Cannot cast into PyList"); let param: String = generic_type_params .get_item(i) @@ -80,7 +80,7 @@ impl ForeignFunction for PythonForeignFunction { .ff .getattr(py, "static_arg_types") .expect("Cannot get foreign function static arg types"); - let static_arg_types: &PyList = static_arg_types.cast_as::(py).expect("Cannot cast into PyList"); + let static_arg_types: &PyList = static_arg_types.downcast::(py).expect("Cannot cast into PyList"); static_arg_types.len() }) } @@ -91,7 +91,7 @@ impl ForeignFunction for PythonForeignFunction { .ff .getattr(py, "static_arg_types") .expect("Cannot get foreign function static arg types"); - let static_arg_types: &PyList = static_arg_types.cast_as::(py).expect("Cannot cast into PyList"); + let static_arg_types: &PyList = static_arg_types.downcast::(py).expect("Cannot cast into PyList"); let param_type: PyObject = static_arg_types .get_item(i) .expect("Cannot get i-th param") @@ -108,7 +108,7 @@ impl ForeignFunction for PythonForeignFunction { .getattr(py, "optional_arg_types") .expect("Cannot get foreign function optional arg types"); let optional_arg_types: &PyList = optional_arg_types - .cast_as::(py) + .downcast::(py) .expect("Cannot cast into PyList"); optional_arg_types.len() }) @@ -121,7 +121,7 @@ impl ForeignFunction for PythonForeignFunction { .getattr(py, "optional_arg_types") .expect("Cannot get foreign function optional arg types"); let optional_arg_types: &PyList = optional_arg_types - .cast_as::(py) + .downcast::(py) .expect("Cannot cast into PyList"); let param_type: PyObject = optional_arg_types .get_item(i) diff --git a/etc/scallopy/src/tuple.rs b/etc/scallopy/src/tuple.rs index 8cd205c..ad3cb0e 100644 --- 
a/etc/scallopy/src/tuple.rs +++ b/etc/scallopy/src/tuple.rs @@ -12,7 +12,7 @@ use scallop_core::utils; pub fn from_python_tuple(v: &PyAny, ty: &TupleType) -> PyResult { match ty { TupleType::Tuple(ts) => { - let tup: &PyTuple = v.cast_as()?; + let tup: &PyTuple = v.downcast()?; if tup.len() == ts.len() { let elems = ts .iter() diff --git a/etc/sclc/src/main.rs b/etc/sclc/src/main.rs deleted file mode 100644 index a191f37..0000000 --- a/etc/sclc/src/main.rs +++ /dev/null @@ -1,40 +0,0 @@ -#![feature(path_file_prefix)] - -mod exec; -mod options; -mod py; - -use structopt::StructOpt; - -use scallop_core::compiler; - -fn main() { - // Command line arguments - let opt = options::Options::from_args(); - - // Compile - let compile_opt = compiler::CompileOptions::from(&opt); - let ram = match compiler::compile_file_to_ram_with_options(&opt.input, &compile_opt) { - Ok(ram) => ram, - Err(errs) => { - for err in errs { - println!("{}", err); - } - return; - } - }; - - // Turn the ram module into a sequence of rust tokens - let module = ram.to_rs_module(&compile_opt); - - // Print the module string if debugging - if opt.debug_rs { - println!("{}", module); - } - - // Depending on the mode, create artifacts - match opt.mode.as_str() { - "executable" => exec::create_executable(&opt, &ram, module), - m => panic!("Unknown compilation mode --mode `{}`", m), - }; -} diff --git a/examples/.gitignore b/examples/.gitignore index df54770..fafdea3 100644 --- a/examples/.gitignore +++ b/examples/.gitignore @@ -1,2 +1,3 @@ output_csv temp_scl +playground diff --git a/examples/datalog/all_cube_is_blue.scl b/examples/datalog/all_cube_is_blue.scl new file mode 100644 index 0000000..590dfb0 --- /dev/null +++ b/examples/datalog/all_cube_is_blue.scl @@ -0,0 +1,5 @@ +rel obj = {0, 1, 2} +rel shape = {(0, "cube"), (1, "sphere"), (2, "cube")} +rel color = {(0, "blue"), (1, "red"), (2, "blue")} + +rel result(b) :- b = forall(o: shape(o, "cube") => color(o, "blue")) diff --git a/examples/datalog/count_where.scl b/examples/datalog/count_where.scl new file mode 100644 index 0000000..71d47c2 --- /dev/null +++ b/examples/datalog/count_where.scl @@ -0,0 +1,24 @@ +// There are three classes +rel classes = {0, 1, 2} + +// There are 6 students, 2 in each class +rel student = { + (0, "tom"), (0, "jenny"), // Class 0 + (1, "alice"), (1, "bob"), // Class 1 + (2, "jerry"), (2, "john"), // Class 2 +} + +// Each student is enrolled in a course (Math or CS) +rel enroll = { + ("tom", "CS"), ("jenny", "Math"), // Class 0 + ("alice", "CS"), ("bob", "CS"), // Class 1 + ("jerry", "Math"), ("john", "Math"), // Class 2 +} + +// Count how many student enrolls in CS course in each class +rel count_enroll_cs_in_class(c, n) :- + n = count(s: student(c, s), enroll(s, "CS") where c: classes(c)) + +// Expected: {(0, 1), (1, 2), (2, 0)} +// Interpretation: class 0 has 1 student enroll in CS, class 1 has 2, class 2 has 0 +query count_enroll_cs_in_class diff --git a/examples/datalog/edge_path.scl b/examples/datalog/edge_path.scl new file mode 100644 index 0000000..0d977dc --- /dev/null +++ b/examples/datalog/edge_path.scl @@ -0,0 +1,4 @@ +rel edge = {(0, 1), (1, 2), (2, 3)} +rel path(a, b) = edge(a, b) +rel path(a, c) = path(a, b) and edge(b, c) +query path diff --git a/examples/datalog/edge_path_undir.scl b/examples/datalog/edge_path_undir.scl new file mode 100644 index 0000000..e91a048 --- /dev/null +++ b/examples/datalog/edge_path_undir.scl @@ -0,0 +1,4 @@ +rel edge = {(0, 1), (1, 2), (2, 3)} +rel path(a, b) = edge(a, b) or (edge(a, c) and path(c, b)) 
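The newly added edge/path examples above are the standard transitive-closure programs; as a minimal sketch (not part of the diff), the same rules can be driven from Python through the usual scallopy ScallopContext API (add_relation, add_facts, add_rule, run, relation):

import scallopy

# Minimal sketch: run the edge/path rules from examples/datalog/edge_path.scl
# through a plain (non-probabilistic) Scallop context.
ctx = scallopy.ScallopContext(provenance="unit")
ctx.add_relation("edge", (int, int))
ctx.add_facts("edge", [(0, 1), (1, 2), (2, 3)])
ctx.add_rule("path(a, b) = edge(a, b)")
ctx.add_rule("path(a, c) = path(a, b) and edge(b, c)")
ctx.run()
print(list(ctx.relation("path")))  # expect every reachable pair, including (0, 3)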
+rel undir_path(a, b) = path(a, b) \/ path(b, a) +query undir_path diff --git a/examples/datalog/evaluate_formula.scl b/examples/datalog/evaluate_formula.scl new file mode 100644 index 0000000..122c38b --- /dev/null +++ b/examples/datalog/evaluate_formula.scl @@ -0,0 +1,33 @@ +// Inputs +type symbol(usize, String) +type length(usize) + +// Facts for lexing +rel digit = {"0", "1", "2", "3", "4", "5", "6", "7", "8", "9"} + +type term(value: f32, begin: usize, end: usize) +rel term(x as f32, b, b + 1) = symbol(b, x) and digit(x) + +type mult_div(value: f32, begin: usize, end: usize) +rel mult_div(x, b, r) = term(x, b, r) +rel mult_div(x * y, b, e) = mult_div(x, b, m) and symbol(m, "*") and term(y, m + 1, e) +rel mult_div(x / y, b, e) = mult_div(x, b, m) and symbol(m, "/") and term(y, m + 1, e) + +type add_minus(value: f32, begin: usize, end: usize) +rel add_minus(x, b, r) = mult_div(x, b, r) +rel add_minus(x + y, b, e) = add_minus(x, b, m) and symbol(m, "+") and mult_div(y, m + 1, e) +rel add_minus(x - y, b, e) = add_minus(x, b, m) and symbol(m, "-") and mult_div(y, m + 1, e) + +type result(value: f32) +rel result(y) = add_minus(y, 0, l) and length(l) + +// =============================================== // + +// Testing related +type test_string(String) +rel length($string_length(s)) = test_string(s) +rel symbol(0, $string_char_at(s, 0) as String) = test_string(s), $string_length(s) > 0 +rel symbol(i, $string_char_at(s, i) as String) = symbol(i - 1, _), test_string(s), $string_length(s) > i + +rel test_string("123/24+1") +query result diff --git a/examples/demo_scl/exists_blue_obj.scl b/examples/datalog/exists_blue_obj.scl similarity index 100% rename from examples/demo_scl/exists_blue_obj.scl rename to examples/datalog/exists_blue_obj.scl diff --git a/examples/demo_scl/fib_dt_1.scl b/examples/datalog/fibonacci.scl similarity index 100% rename from examples/demo_scl/fib_dt_1.scl rename to examples/datalog/fibonacci.scl diff --git a/examples/datalog/kinship.scl b/examples/datalog/kinship.scl new file mode 100644 index 0000000..398e276 --- /dev/null +++ b/examples/datalog/kinship.scl @@ -0,0 +1,21 @@ +const MOTHER = 1 +const FATHER = 2 +const GRANDMOTHER = 3 +const GRANDFATHER = 4 + +rel transitive = { + (MOTHER, MOTHER, GRANDMOTHER), + (FATHER, FATHER, GRANDFATHER), + (FATHER, MOTHER, GRANDMOTHER), + (MOTHER, FATHER, GRANDFATHER), +} + +rel context = { + (FATHER, "bob", "john"), + (MOTHER, "john", "alice"), +} + +rel derived(r, a, b) :- context(r, a, b) +rel derived(r3, a, b) :- derived(r1, a, c), derived(r2, c, b), transitive(r1, r2, r3) + +query derived(r, "bob", "alice") diff --git a/examples/datalog/pacman_maze_example.scl b/examples/datalog/pacman_maze_example.scl new file mode 100644 index 0000000..241f464 --- /dev/null +++ b/examples/datalog/pacman_maze_example.scl @@ -0,0 +1,53 @@ +// Input from neural networks +type grid_node(x: usize, y: usize) +type curr_position(x: usize, y: usize) +type goal_position(x: usize, y: usize) +type is_enemy(x: usize, y: usize) + +// Constants +const UP = 0 +const RIGHT = 1 +const DOWN = 2 +const LEFT = 3 + +// Basic connectivity +rel node(x, y) = grid_node(x, y), not is_enemy(x, y) +rel edge(x, y, x, yp, UP) = node(x, y), node(x, yp), yp == y + 1 +rel edge(x, y, xp, y, RIGHT) = node(x, y), node(xp, y), xp == x + 1 +rel edge(x, y, x, yp, DOWN) = node(x, y), node(x, yp), yp == y - 1 +rel edge(x, y, xp, y, LEFT) = node(x, y), node(xp, y), xp == x - 1 + +// Path for connectivity; will condition on no enemy on the path +rel path(x, y, x, y) = node(x, y) 
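For neural-input programs like this maze, the `retain_topk` option added to ScallopForwardFunction earlier in this diff prunes unlikely input facts before reasoning. The snippet below is an illustrative sketch only: the tiny program string, the relation names, and the choice of the difftopkproofs provenance are assumptions, while the keyword arguments mirror the forward.py changes above.

import scallopy

# Sketch of `retain_topk`: keep only the 2 most probable `is_enemy` facts
# per sample (forward.py forwards each entry to set_sample_topk_facts).
cells = [(x, y) for x in range(2) for y in range(2)]
program = """
type is_enemy(x: usize, y: usize)
rel grid = {(0, 0), (0, 1), (1, 0), (1, 1)}
rel safe(x, y) = grid(x, y), not is_enemy(x, y)
"""
forward = scallopy.ScallopForwardFunction(
  program=program,
  provenance="difftopkproofs",
  input_mappings={"is_enemy": cells},
  retain_topk={"is_enemy": 2},
  output_mappings={"safe": cells})

Calling the resulting module with a (batch, 4) tensor of enemy probabilities then behaves like any other torch module.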
+rel path(x, y, xp, yp) = edge(x, y, xp, yp, _) +rel path(x, y, xpp, ypp) = path(x, y, xp, yp), edge(xp, yp, xpp, ypp, _) + +// Get the next position +rel next_position(a, xp, yp) = curr_position(x, y), edge(x, y, xp, yp, a) +rel action_score(a) = next_position(a, x, y), goal_position(gx, gy), path(x, y, gx, gy) + +// ============ EXAMPLE ============ + +// The following example denotes the following arena +// +// * - - +// E E - +// x - - +// +// where "*" is the goal and "x" is the current position. +// The agent needs to avoid the enemies ("E") +// and therefore the best action is to go RIGHT (represented by integer 1) + +rel grid_node = { + (0, 2), (1, 2), (2, 2), + (0, 1), (1, 1), (2, 1), + (0, 0), (1, 0), (2, 0), +} + +rel is_enemy = {(0, 1), (1, 1)} + +rel goal_position = {(0, 2)} + +rel curr_position = {(0, 0)} + +query action_score diff --git a/examples/datalog/regex.scl b/examples/datalog/regex.scl new file mode 100644 index 0000000..3de7797 --- /dev/null +++ b/examples/datalog/regex.scl @@ -0,0 +1,48 @@ +// =========== REGEX =========== +type regex_char(id: usize, c: char) +type regex_concat(id: usize, left: usize, right: usize) +type regex_union(id: usize, left: usize, right: usize) +type regex_star(id: usize, child: usize) +type regex_root(id: usize) + +// Match a single char +rel matches_substr(expr, start, start + 1) :- regex_char(expr, c), char_at(start, c) + +// Match a concatenation +rel matches_substr(expr, l, r) :- regex_concat(expr, le, re), matches_substr(le, l, m), matches_substr(re, m, r) + +// Match a union +rel matches_substr(expr, l, r) :- regex_union(expr, a, b), matches_substr(a, l, r) +rel matches_substr(expr, l, r) :- regex_union(expr, a, b), matches_substr(b, l, r) + +// Match a star +rel matches_substr(expr, i, i) :- regex_star(expr, _), range(0, l + 1, i), input_string(s), strlen(s, l) +rel matches_substr(expr, l, r) :- regex_star(expr, c), matches_substr(c, l, r) +rel matches_substr(expr, l, r) :- regex_star(expr, c), matches_substr(expr, l, m), matches_substr(c, m, r) + +// Matches the whole string +rel matches(true) :- input_string(s), strlen(s, l), regex_root(e), matches_substr(e, 0, l) + +// =========== STRING =========== +type input_string(s: String) +rel char_at(i, $string_char_at(s, i)) :- input_string(s), strlen(s, l), range(0, l, i) + +// =========== HELPER =========== +@demand("bbf") +rel range(a, b, i) :- i == a +rel range(a, b, i) :- range(a, b, i - 1), i < b +@demand("bf") +rel strlen(s, i) :- i == $string_length(s) + +// =========== EXAMPLE =========== +rel regex_char(0, 'a') +rel regex_char(1, 'b') +rel regex_concat(2, 0, 1) +rel regex_concat(3, 2, 0) +rel regex_star(4, 1) +rel regex_concat(5, 3, 4) +rel regex_root(5) + +rel input_string("ababbbb") + +query matches diff --git a/examples/demo_scl/type_inference.scl b/examples/datalog/type_inference.scl similarity index 100% rename from examples/demo_scl/type_inference.scl rename to examples/datalog/type_inference.scl diff --git a/examples/demo_scl/all_cube_is_blue.scl b/examples/demo_scl/all_cube_is_blue.scl deleted file mode 100644 index e69de29..0000000 diff --git a/examples/demo_scl/linked_list.scl b/examples/demo_scl/linked_list.scl deleted file mode 100644 index e69de29..0000000 diff --git a/examples/bug_scl/arity-mismatch-1/1.scl b/examples/legacy/bug_scl/arity-mismatch-1/1.scl similarity index 100% rename from examples/bug_scl/arity-mismatch-1/1.scl rename to examples/legacy/bug_scl/arity-mismatch-1/1.scl diff --git a/examples/bug_scl/arity-mismatch-1/2.scl 
b/examples/legacy/bug_scl/arity-mismatch-1/2.scl similarity index 100% rename from examples/bug_scl/arity-mismatch-1/2.scl rename to examples/legacy/bug_scl/arity-mismatch-1/2.scl diff --git a/examples/legacy/bug_scl/char_at_error/bug.scl b/examples/legacy/bug_scl/char_at_error/bug.scl new file mode 100644 index 0000000..4d876d2 --- /dev/null +++ b/examples/legacy/bug_scl/char_at_error/bug.scl @@ -0,0 +1,3 @@ +rel input("1357") +rel string_char_at(0, $string_char_at(s, 0)) :- input(s), 0 < $string_length(s) +rel string_char_at(i, $string_char_at(s, i)) :- input(s), i < $string_length(s), string_char_at(i - 1, _) diff --git a/examples/legacy/bug_scl/expr_test_1/bug.scl b/examples/legacy/bug_scl/expr_test_1/bug.scl new file mode 100644 index 0000000..b8a0315 --- /dev/null +++ b/examples/legacy/bug_scl/expr_test_1/bug.scl @@ -0,0 +1,9 @@ +rel eval(e, c) = constant(e, c) +rel eval(e, a + b) = binary(e, "+", l, r), eval(l, a), eval(r, b) +rel eval(e, a - b) = binary(e, "-", l, r), eval(l, a), eval(r, b) +rel result(y) = eval(e, y), goal(e) + +rel constant = { (0, 1), (1, 2), (2, 3) } +rel binary = { (3, "+", 0, 1), (4, "-", 3, 2) } +rel goal(4) +query result \ No newline at end of file diff --git a/examples/bug_scl/io-issue-18/bug.scl b/examples/legacy/bug_scl/io-issue-18/bug.scl similarity index 100% rename from examples/bug_scl/io-issue-18/bug.scl rename to examples/legacy/bug_scl/io-issue-18/bug.scl diff --git a/examples/bug_scl/io-issue-23/ok.scl b/examples/legacy/bug_scl/io-issue-23/ok.scl similarity index 100% rename from examples/bug_scl/io-issue-23/ok.scl rename to examples/legacy/bug_scl/io-issue-23/ok.scl diff --git a/examples/bug_scl/io-issue-3/bad-1.scl b/examples/legacy/bug_scl/io-issue-3/bad-1.scl similarity index 100% rename from examples/bug_scl/io-issue-3/bad-1.scl rename to examples/legacy/bug_scl/io-issue-3/bad-1.scl diff --git a/examples/bug_scl/io-issue-3/good.scl b/examples/legacy/bug_scl/io-issue-3/good.scl similarity index 100% rename from examples/bug_scl/io-issue-3/good.scl rename to examples/legacy/bug_scl/io-issue-3/good.scl diff --git a/examples/bug_scl/io-issue-4/bug.scl b/examples/legacy/bug_scl/io-issue-4/bug.scl similarity index 100% rename from examples/bug_scl/io-issue-4/bug.scl rename to examples/legacy/bug_scl/io-issue-4/bug.scl diff --git a/examples/bug_scl/io-issue-7/bug.scl b/examples/legacy/bug_scl/io-issue-7/bug.scl similarity index 100% rename from examples/bug_scl/io-issue-7/bug.scl rename to examples/legacy/bug_scl/io-issue-7/bug.scl diff --git a/examples/bug_scl/io-issue-7/ok.scl b/examples/legacy/bug_scl/io-issue-7/ok.scl similarity index 100% rename from examples/bug_scl/io-issue-7/ok.scl rename to examples/legacy/bug_scl/io-issue-7/ok.scl diff --git a/examples/bug_scl/readme.md b/examples/legacy/bug_scl/readme.md similarity index 100% rename from examples/bug_scl/readme.md rename to examples/legacy/bug_scl/readme.md diff --git a/examples/demo_scl/agent_pathfinding.scl b/examples/legacy/demo_scl/agent_pathfinding.scl similarity index 89% rename from examples/demo_scl/agent_pathfinding.scl rename to examples/legacy/demo_scl/agent_pathfinding.scl index 357aa3d..0d117da 100644 --- a/examples/demo_scl/agent_pathfinding.scl +++ b/examples/legacy/demo_scl/agent_pathfinding.scl @@ -20,7 +20,7 @@ rel next_state(b) = curr_state(a), edge(a, b) type goal(usize) // Whether we can reach the goal from next_state -rel can_reach_goal(x) = next_state(x), goal(y), reach(x, y), not enemy(x) +rel can_reach_goal() = curr_state(x), goal(y), reach(x, y), not enemy(x) 
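With can_reach_goal now nullary, the interesting output is its tag rather than any tuple. The sketch below shows how it could be inspected from Python under a probabilistic provenance; it assumes the scallopy context API (in particular import_file), which is not part of this diff.

import scallopy

# Sketch: load the (renamed) pathfinding example and read off the
# probability tag attached to the empty tuple of `can_reach_goal`.
ctx = scallopy.ScallopContext(provenance="topkproofs", k=3)
ctx.import_file("examples/legacy/demo_scl/agent_pathfinding.scl")
ctx.run()
for prob, tup in ctx.relation("can_reach_goal"):
  print(prob, tup)  # a single empty tuple, tagged with the reachability probability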
///////////// Example 1 ///////////// @@ -43,7 +43,7 @@ rel edge = { // There are enemies in 4, 5, 6 rel enemy = { 0.1::1, 0.1::2, 0.1::3, - 0.2::4, 0.9::5, 0.9::6, + 0.1::4, 0.9::5, 0.9::6, 0.1::7, 0.1::8, 0.1::9, } @@ -51,6 +51,6 @@ rel enemy = { rel goal(3) // We want to start from node 5, 7, or 9 -rel curr_state(8) +rel curr_state(9) query can_reach_goal diff --git a/examples/legacy/demo_scl/all_cube_is_blue.scl b/examples/legacy/demo_scl/all_cube_is_blue.scl new file mode 100644 index 0000000..590dfb0 --- /dev/null +++ b/examples/legacy/demo_scl/all_cube_is_blue.scl @@ -0,0 +1,5 @@ +rel obj = {0, 1, 2} +rel shape = {(0, "cube"), (1, "sphere"), (2, "cube")} +rel color = {(0, "blue"), (1, "red"), (2, "blue")} + +rel result(b) :- b = forall(o: shape(o, "cube") => color(o, "blue")) diff --git a/examples/legacy/demo_scl/avoiding_enemy_arena.scl b/examples/legacy/demo_scl/avoiding_enemy_arena.scl new file mode 100644 index 0000000..241f464 --- /dev/null +++ b/examples/legacy/demo_scl/avoiding_enemy_arena.scl @@ -0,0 +1,53 @@ +// Input from neural networks +type grid_node(x: usize, y: usize) +type curr_position(x: usize, y: usize) +type goal_position(x: usize, y: usize) +type is_enemy(x: usize, y: usize) + +// Constants +const UP = 0 +const RIGHT = 1 +const DOWN = 2 +const LEFT = 3 + +// Basic connectivity +rel node(x, y) = grid_node(x, y), not is_enemy(x, y) +rel edge(x, y, x, yp, UP) = node(x, y), node(x, yp), yp == y + 1 +rel edge(x, y, xp, y, RIGHT) = node(x, y), node(xp, y), xp == x + 1 +rel edge(x, y, x, yp, DOWN) = node(x, y), node(x, yp), yp == y - 1 +rel edge(x, y, xp, y, LEFT) = node(x, y), node(xp, y), xp == x - 1 + +// Path for connectivity; will condition on no enemy on the path +rel path(x, y, x, y) = node(x, y) +rel path(x, y, xp, yp) = edge(x, y, xp, yp, _) +rel path(x, y, xpp, ypp) = path(x, y, xp, yp), edge(xp, yp, xpp, ypp, _) + +// Get the next position +rel next_position(a, xp, yp) = curr_position(x, y), edge(x, y, xp, yp, a) +rel action_score(a) = next_position(a, x, y), goal_position(gx, gy), path(x, y, gx, gy) + +// ============ EXAMPLE ============ + +// The following example denotes the following arena +// +// * - - +// E E - +// x - - +// +// where "*" is the goal and "x" is the current position. 
+// The agent needs to avoid the enemies ("E") +// and therefore the best action is to go RIGHT (represented by integer 1) + +rel grid_node = { + (0, 2), (1, 2), (2, 2), + (0, 1), (1, 1), (2, 1), + (0, 0), (1, 0), (2, 0), +} + +rel is_enemy = {(0, 1), (1, 1)} + +rel goal_position = {(0, 2)} + +rel curr_position = {(0, 0)} + +query action_score diff --git a/examples/legacy/demo_scl/avoiding_enemy_arena_prob.scl b/examples/legacy/demo_scl/avoiding_enemy_arena_prob.scl new file mode 100644 index 0000000..767f832 --- /dev/null +++ b/examples/legacy/demo_scl/avoiding_enemy_arena_prob.scl @@ -0,0 +1,50 @@ +// Input from neural networks +type grid_node(x: i8, y: i8) +type is_agent(x: i8, y: i8) +type is_goal(x: i8, y: i8) +type is_enemy(x: i8, y: i8) + +// Constants +const UP = 0, RIGHT = 1, DOWN = 2, LEFT = 3 + +// Basic connectivity +rel node(x, y) = grid_node(x, y), not is_enemy(x, y) +rel edge(x, y, x, yp, UP) = node(x, y), node(x, yp), yp == y + 1 +rel edge(x, y, xp, y, RIGHT) = node(x, y), node(xp, y), xp == x + 1 +rel edge(x, y, x, yp, DOWN) = node(x, y), node(x, yp), yp == y - 1 +rel edge(x, y, xp, y, LEFT) = node(x, y), node(xp, y), xp == x - 1 + +// Path for connectivity; will condition on no enemy on the path +rel path(x, y, x, y) = node(x, y) +rel path(x, y, xp, yp) = edge(x, y, xp, yp, _) +rel path(x, y, xpp, ypp) = path(x, y, xp, yp), edge(xp, yp, xpp, ypp, _) + +// Get the next position +rel next_position(a, xp, yp) = is_agent(x, y), edge(x, y, xp, yp, a) +rel expected_reward(a) = next_position(a, x, y), is_goal(gx, gy), path(x, y, gx, gy) + +// ============ EXAMPLE ============ + +// The following example denotes the following arena +// +// g - - +// E E - +// a - - +// +// where "g" is the goal and "a" is the current position. +// The agent needs to avoid the enemies ("E") +// and therefore the best action is to go RIGHT (represented by integer 1) + +rel grid_node = { + 0.95::(0, 2), 0.95::(1, 2), 0.95::(2, 2), + 0.95::(0, 1), 0.95::(1, 1), 0.95::(2, 1), + 0.95::(0, 0), 0.95::(1, 0), 0.95::(2, 0), +} + +rel is_enemy = {0.99::(0, 1), 0.99::(1, 1)} + +rel is_agent = {0.99::(0, 0)} + +rel is_goal = {0.98::(0, 2)} + +query expected_reward diff --git a/examples/demo_scl/count_where.scl b/examples/legacy/demo_scl/count_where.scl similarity index 100% rename from examples/demo_scl/count_where.scl rename to examples/legacy/demo_scl/count_where.scl diff --git a/examples/demo_scl/datalog.scl b/examples/legacy/demo_scl/datalog.scl similarity index 100% rename from examples/demo_scl/datalog.scl rename to examples/legacy/demo_scl/datalog.scl diff --git a/examples/demo_scl/disjunction.scl b/examples/legacy/demo_scl/disjunction.scl similarity index 100% rename from examples/demo_scl/disjunction.scl rename to examples/legacy/demo_scl/disjunction.scl diff --git a/examples/demo_scl/edge_path_csv_1.scl b/examples/legacy/demo_scl/edge_path_csv_1.scl similarity index 83% rename from examples/demo_scl/edge_path_csv_1.scl rename to examples/legacy/demo_scl/edge_path_csv_1.scl index 4ab9566..849b6e5 100644 --- a/examples/demo_scl/edge_path_csv_1.scl +++ b/examples/legacy/demo_scl/edge_path_csv_1.scl @@ -1,5 +1,5 @@ @file("examples/input_csv/edge.csv") -input edge(usize, usize) +type edge(usize, usize) @demand("bf") rel path(a, c) = edge(a, c) \/ path(a, b) /\ edge(b, c) diff --git a/examples/demo_scl/edge_path_csv_2.scl b/examples/legacy/demo_scl/edge_path_csv_2.scl similarity index 88% rename from examples/demo_scl/edge_path_csv_2.scl rename to examples/legacy/demo_scl/edge_path_csv_2.scl index 
5feef22..59d3a63 100644 --- a/examples/demo_scl/edge_path_csv_2.scl +++ b/examples/legacy/demo_scl/edge_path_csv_2.scl @@ -1,5 +1,5 @@ @file("examples/input_csv/edge_prob.csv", deliminator = "\t", has_header = true, has_probability = true) -input edge(usize, usize) +type edge(usize, usize) @demand("bf") rel path(a, c) = edge(a, c) \/ path(a, b) /\ edge(b, c) diff --git a/examples/demo_scl/edge_path_dt_1.scl b/examples/legacy/demo_scl/edge_path_dt_1.scl similarity index 100% rename from examples/demo_scl/edge_path_dt_1.scl rename to examples/legacy/demo_scl/edge_path_dt_1.scl diff --git a/examples/legacy/demo_scl/exists_blue_obj.scl b/examples/legacy/demo_scl/exists_blue_obj.scl new file mode 100644 index 0000000..44b76f5 --- /dev/null +++ b/examples/legacy/demo_scl/exists_blue_obj.scl @@ -0,0 +1,13 @@ +// A set of all the shapes +rel all_shapes = {"cube", "cylinder", "sphere"} + +// Each object has two attributes: color and shape +rel color = {(0, "red"), (1, "green"), (2, "blue"), (3, "blue")} +rel shape = {(0, "cube"), (1, "cylinder"), (2, "sphere"), (3, "cube")} + +// Is there a blue object? +rel exists_blue_obj(b) = b = exists(o: color(o, "blue")) + +// For each shape, is there a blue object of that shape? +rel exists_blue_obj_of_shape(s, b) :- + b = exists(o: color(o, "blue"), shape(o, s) where s: all_shapes(s)) diff --git a/examples/legacy/demo_scl/fib_dt_1.scl b/examples/legacy/demo_scl/fib_dt_1.scl new file mode 100644 index 0000000..0081095 --- /dev/null +++ b/examples/legacy/demo_scl/fib_dt_1.scl @@ -0,0 +1,4 @@ +@demand("bf") +rel fib(x, a + b) = fib(x - 1, a), fib(x - 2, b), x > 1 +rel fib = {(0, 1), (1, 1)} +query fib(10, y) diff --git a/examples/legacy/demo_scl/kinship.scl b/examples/legacy/demo_scl/kinship.scl new file mode 100644 index 0000000..398e276 --- /dev/null +++ b/examples/legacy/demo_scl/kinship.scl @@ -0,0 +1,21 @@ +const MOTHER = 1 +const FATHER = 2 +const GRANDMOTHER = 3 +const GRANDFATHER = 4 + +rel transitive = { + (MOTHER, MOTHER, GRANDMOTHER), + (FATHER, FATHER, GRANDFATHER), + (FATHER, MOTHER, GRANDMOTHER), + (MOTHER, FATHER, GRANDFATHER), +} + +rel context = { + (FATHER, "bob", "john"), + (MOTHER, "john", "alice"), +} + +rel derived(r, a, b) :- context(r, a, b) +rel derived(r3, a, b) :- derived(r1, a, c), derived(r2, c, b), transitive(r1, r2, r3) + +query derived(r, "bob", "alice") diff --git a/examples/demo_scl/lambda.scl b/examples/legacy/demo_scl/lambda.scl similarity index 100% rename from examples/demo_scl/lambda.scl rename to examples/legacy/demo_scl/lambda.scl diff --git a/examples/legacy/demo_scl/linked_list.scl b/examples/legacy/demo_scl/linked_list.scl new file mode 100644 index 0000000..04e43e2 --- /dev/null +++ b/examples/legacy/demo_scl/linked_list.scl @@ -0,0 +1,16 @@ +type LinkedList <: usize +type cons(LinkedList, i32, LinkedList) +type nil(LinkedList) + +type length(LinkedList, usize) +rel length(list, 0) :- nil(list) +rel length(list, l + 1) :- cons(list, _, tail), length(tail, l) + +// ==== Example ==== +const L1 = 0 +const L2 = 1 +const L3 = 2 +const L4 = 3 +rel nil = {L1} +rel cons = {(L2, 10, L1), (L3, 20, L2), (L4, 30, L3)} +query length(L4, l) diff --git a/examples/demo_scl/mnist_add_sub.scl b/examples/legacy/demo_scl/mnist_add_sub.scl similarity index 100% rename from examples/demo_scl/mnist_add_sub.scl rename to examples/legacy/demo_scl/mnist_add_sub.scl diff --git a/examples/legacy/demo_scl/multi_digit_hwf.scl b/examples/legacy/demo_scl/multi_digit_hwf.scl new file mode 100644 index 0000000..e0e5a77 --- /dev/null +++ 
b/examples/legacy/demo_scl/multi_digit_hwf.scl @@ -0,0 +1,49 @@ +// Inputs +type symbol(usize, String) +type length(usize) + +// Facts for lexing +rel digit = {"0", "1", "2", "3", "4", "5", "6", "7", "8", "9"} +rel mult_div = {"*", "/"} +rel plus_minus = {"+", "-"} + +// Parsing +type value_node(id: u64, string: String, begin: usize, end: usize) +rel value_node($hash(x, d), d, x, x + 1) = symbol(x, d), digit(d) +rel value_node($hash(joint, b - 1, e), joint, b - 1, e) = + symbol(b - 1, dh), digit(dh), value_node(x, dr, b, e), joint == $string_concat(dh, dr) + +type mult_div_node(id: u64, string: String, left_node: u64, right_node: u64, begin: usize, end: usize) +rel mult_div_node(id, string, 0, 0, b, e) = value_node(id, string, b, e) +rel mult_div_node($hash(id, s, l, r), s, l, r, b, e) = + symbol(id, s), mult_div(s), mult_div_node(l, _, _, _, b, id), value_node(r, _, id + 1, e) + +type plus_minus_node(id: u64, string: String, left_node: u64, right_node: u64, begin: usize, end: usize) +rel plus_minus_node(id, string, l, r, b, e) = mult_div_node(id, string, l, r, b, e) +rel plus_minus_node($hash(id, s, l, r), s, l, r, b, e) = + symbol(id, s), plus_minus(s), plus_minus_node(l, _, _, _, b, id), mult_div_node(r, _, _, _, id + 1, e) + +type root_node(id: u64) +rel root_node(id) = plus_minus_node(id, _, _, _, 0, l), length(l) + +// Evaluate AST +@demand("bf") +rel eval(x, s as f64) = value_node(x, s, _, _) +rel eval(x, y1 + y2) = plus_minus_node(x, "+", l, r, _, _), eval(l, y1), eval(r, y2) +rel eval(x, y1 - y2) = plus_minus_node(x, "-", l, r, _, _), eval(l, y1), eval(r, y2) +rel eval(x, y1 * y2) = mult_div_node(x, "*", l, r, _, _), eval(l, y1), eval(r, y2) +rel eval(x, y1 / y2) = mult_div_node(x, "/", l, r, _, _), eval(l, y1), eval(r, y2), y2 != 0.0 + +// Compute result +rel result(y) = eval(e, y), root_node(e) + +// =============================================== // + +// Testing related +type test_string(String) +rel length($string_length(s)) = test_string(s) +rel symbol(0, $string_char_at(s, 0) as String) = test_string(s), $string_length(s) > 0 +rel symbol(i, $string_char_at(s, i) as String) = symbol(i - 1, _), test_string(s), $string_length(s) > i + +rel test_string("123/24+1") +query result diff --git a/examples/demo_scl/music_theory.scl b/examples/legacy/demo_scl/music_theory.scl similarity index 100% rename from examples/demo_scl/music_theory.scl rename to examples/legacy/demo_scl/music_theory.scl diff --git a/examples/demo_scl/negate_1.scl b/examples/legacy/demo_scl/negate_1.scl similarity index 100% rename from examples/demo_scl/negate_1.scl rename to examples/legacy/demo_scl/negate_1.scl diff --git a/examples/demo_scl/no_green_between.scl b/examples/legacy/demo_scl/no_green_between.scl similarity index 100% rename from examples/demo_scl/no_green_between.scl rename to examples/legacy/demo_scl/no_green_between.scl diff --git a/examples/demo_scl/output_1.scl b/examples/legacy/demo_scl/output_1.scl similarity index 92% rename from examples/demo_scl/output_1.scl rename to examples/legacy/demo_scl/output_1.scl index fd36d8c..aa84821 100644 --- a/examples/demo_scl/output_1.scl +++ b/examples/legacy/demo_scl/output_1.scl @@ -2,4 +2,4 @@ rel edge = {(0, 1), (1, 2), (2, 3)} rel path(a, b) = edge(a, b) \/ path(a, c) /\ edge(c, b) @file("examples/output_csv/output_1_path.csv") -output path +query path diff --git a/examples/demo_scl/prob_rule_1.scl b/examples/legacy/demo_scl/prob_rule_1.scl similarity index 100% rename from examples/demo_scl/prob_rule_1.scl rename to 
examples/legacy/demo_scl/prob_rule_1.scl diff --git a/examples/demo_scl/query_atom_2.scl b/examples/legacy/demo_scl/query_atom_2.scl similarity index 100% rename from examples/demo_scl/query_atom_2.scl rename to examples/legacy/demo_scl/query_atom_2.scl diff --git a/examples/demo_scl/range_dt_1.scl b/examples/legacy/demo_scl/range_dt_1.scl similarity index 100% rename from examples/demo_scl/range_dt_1.scl rename to examples/legacy/demo_scl/range_dt_1.scl diff --git a/examples/legacy/demo_scl/regex.scl b/examples/legacy/demo_scl/regex.scl new file mode 100644 index 0000000..da12b7e --- /dev/null +++ b/examples/legacy/demo_scl/regex.scl @@ -0,0 +1,47 @@ +// =========== REGEX =========== +type regex_char(id: usize, c: char) +type regex_concat(id: usize, left: usize, right: usize) +type regex_star(id: usize, child: usize) +type regex_root(id: usize) + +// Match a single char +rel matches_substr(expr, start, start + 1) :- regex_char(expr, c), char_at(start, c) + +// Match a concatenation +rel matches_substr(expr, l, r) :- regex_concat(expr, le, re), matches_substr(le, l, m), matches_substr(re, m, r) + +// Match a union +rel matches_substr(expr, l, r) :- regex_union(expr, a, b), matches_substr(a, l, r) +rel matches_substr(expr, l, r) :- regex_union(expr, a, b), matches_substr(b, l, r) + +// Match a star +rel matches_substr(expr, i, i) :- regex_star(expr, _), range(0, l + 1, i), input_string(s), strlen(s, l) +rel matches_substr(expr, l, r) :- regex_star(expr, c), matches_substr(c, l, r) +rel matches_substr(expr, l, r) :- regex_star(expr, c), matches_substr(expr, l, m), matches_substr(c, m, r) + +// Matches the whole string +rel matches() :- input_string(s), strlen(s, l), regex_root(e), matches_substr(e, 0, l) + +// =========== STRING =========== +type input_string(s: String) +rel char_at(i, $string_char_at(s, i)) :- input_string(s), strlen(s, l), range(0, l, i) + +// =========== HELPER =========== +@demand("bbf") +rel range(a, b, i) :- i == a +rel range(a, b, i) :- range(a, b, i - 1), i < b +@demand("bf") +rel strlen(s, i) :- i == $string_length(s) + +// =========== EXAMPLE =========== +rel regex_char(0, 'a') +rel regex_char(1, 'b') +rel regex_concat(2, 0, 1) +rel regex_concat(3, 2, 0) +rel regex_star(4, 1) +rel regex_concat(5, 3, 4) +rel regex_root(5) + +rel input_string("ababbbb") + +query matches diff --git a/examples/legacy/demo_scl/sat_1.scl b/examples/legacy/demo_scl/sat_1.scl new file mode 100644 index 0000000..f55edb5 --- /dev/null +++ b/examples/legacy/demo_scl/sat_1.scl @@ -0,0 +1,22 @@ +type assign(String, bool) + +// Assignments to variables A, B, and C +rel assign = {1.0::("A", true); 1.0::("A", false)} +rel assign = {1.0::("B", true); 1.0::("B", false)} +rel assign = {1.0::("C", true); 1.0::("C", false)} + +// Boolean formula (A and !B) or (B and !C) +rel bf_var = {(1, "A"), (2, "B"), (3, "B"), (4, "C")} +rel bf_not = {(5, 2), (6, 4)} +rel bf_and = {(7, 1, 5), (8, 3, 6)} +rel bf_or = {(9, 7, 8)} +rel bf_root = {9} + +// Evaluation +rel eval_bf(bf, r) :- bf_var(bf, v), assign(v, r) +rel eval_bf(bf, !r) :- bf_not(bf, c), eval_bf(c, r) +rel eval_bf(bf, lr && rr) :- bf_and(bf, lbf, rbf), eval_bf(lbf, lr), eval_bf(rbf, rr) +rel eval_bf(bf, lr || rr) :- bf_or(bf, lbf, rbf), eval_bf(lbf, lr), eval_bf(rbf, rr) +rel eval(r) :- bf_root(bf), eval_bf(bf, r) + +query eval diff --git a/examples/legacy/demo_scl/sat_2.scl b/examples/legacy/demo_scl/sat_2.scl new file mode 100644 index 0000000..6516c81 --- /dev/null +++ b/examples/legacy/demo_scl/sat_2.scl @@ -0,0 +1,21 @@ +type assign(String, bool) 
+ +// Assignments to variables A and B +rel assign = {1.0::("A", true); 1.0::("A", false)} +rel assign = {1.0::("B", true); 1.0::("B", false)} + +// Boolean formula (A and !A) or (B and !B) +rel bf_var = {(1, "A"), (2, "B")} +rel bf_not = {(3, 1), (4, 2)} +rel bf_and = {(5, 1, 3), (6, 2, 4)} +rel bf_or = {(7, 5, 6)} +rel bf_root = {7} + +// Evaluation +rel eval_bf(bf, r) :- bf_var(bf, v), assign(v, r) +rel eval_bf(bf, !r) :- bf_not(bf, c), eval_bf(c, r) +rel eval_bf(bf, lr && rr) :- bf_and(bf, lbf, rbf), eval_bf(lbf, lr), eval_bf(rbf, rr) +rel eval_bf(bf, lr || rr) :- bf_or(bf, lbf, rbf), eval_bf(lbf, lr), eval_bf(rbf, rr) +rel eval(r) :- bf_root(bf), eval_bf(bf, r) + +query eval diff --git a/examples/demo_scl/stdlib.scl b/examples/legacy/demo_scl/stdlib.scl similarity index 100% rename from examples/demo_scl/stdlib.scl rename to examples/legacy/demo_scl/stdlib.scl diff --git a/examples/demo_scl/stdlib_usage_1.scl b/examples/legacy/demo_scl/stdlib_usage_1.scl similarity index 100% rename from examples/demo_scl/stdlib_usage_1.scl rename to examples/legacy/demo_scl/stdlib_usage_1.scl diff --git a/examples/legacy/demo_scl/type_inference.scl b/examples/legacy/demo_scl/type_inference.scl new file mode 100644 index 0000000..af9970e --- /dev/null +++ b/examples/legacy/demo_scl/type_inference.scl @@ -0,0 +1,52 @@ +// EXP ::= let V = EXP in EXP +// | if EXP then EXP else EXP +// | X + Y | X - Y +// | X and Y | X or Y | not X +// | X == Y | X != Y | X < Y | X <= Y | X > Y | X >= Y + +// Basic syntax constructs +type number(usize, i32) +type boolean(usize, bool) +type variable(usize, String) +type bexp(usize, String, usize, usize) +type aexp(usize, String, usize, usize) +type let_in(usize, String, usize, usize) +type if_then_else(usize, usize, usize, usize) + +// Comparison operations +rel comparison_op = {"==", "!=", ">=", "<=", ">", "<"} +rel logical_op = {"&&", "||", "^"} +rel arith_op = {"+", "-", "*", "/"} + +// A program with each number 0-4 denoting their index +// let x = 3 in x == 4 +// -------------------0 +// -1 ------2 +// -3 -4 +rel let_in = {(0, "x", 1, 2)} +rel number = {(1, 3), (4, 4)} +rel bexp = {(2, "==", 3, 4)} +rel variable = {(3, "x")} + +// Type Inference: + +// - Base case +rel type_of(x, "bool") = boolean(x, _) +rel type_of(x, "int") = number(x, _) +rel type_of(x, t) = variable(x, v), env_type(x, v, t) +rel type_of(e, "bool") = bexp(e, op, x, y), comparison_op(op), type_of(x, "int"), type_of(y, "int") +rel type_of(e, "bool") = bexp(e, op, x, y), logical_op(op), type_of(x, "bool"), type_of(y, "bool") +rel type_of(e, "int") = aexp(e, op, x, y), arith_op(op), type_of(x, "int"), type_of(y, "int") +rel type_of(e, t) = let_in(e, v, b, c), env_type(c, v, tv), type_of(b, tv), type_of(c, t) +rel type_of(e, t) = if_then_else(e, x, y, z), type_of(x, "bool"), type_of(y, t), type_of(z, t) + +// - Environment variable type +rel env_type(x, v, t) = bexp(e, _, x, _), env_type(e, v, t) +rel env_type(y, v, t) = bexp(e, _, _, y), env_type(e, v, t) +rel env_type(x, v, t) = aexp(e, _, x, _), env_type(e, v, t) +rel env_type(y, v, t) = aexp(e, _, _, y), env_type(e, v, t) +rel env_type(z, v, t) = let_in(_, v, y, z), type_of(y, t) +rel env_type(z, v2, t) = let_in(x, v1, _, z), env_type(x, v2, t), v1 != v2 +rel env_type(x, v, t) = env_type(e, v, t), if_then_else(e, x, _, _) +rel env_type(y, v, t) = env_type(e, v, t), if_then_else(e, _, y, _) +rel env_type(z, v, t) = env_type(e, v, t), if_then_else(e, _, _, z) diff --git a/examples/demo_scl/undir_edge_path.scl 
b/examples/legacy/demo_scl/undir_edge_path.scl similarity index 100% rename from examples/demo_scl/undir_edge_path.scl rename to examples/legacy/demo_scl/undir_edge_path.scl diff --git a/examples/good_scl/animal.scl b/examples/legacy/good_scl/animal.scl similarity index 100% rename from examples/good_scl/animal.scl rename to examples/legacy/good_scl/animal.scl diff --git a/examples/good_scl/bmi.scl b/examples/legacy/good_scl/bmi.scl similarity index 100% rename from examples/good_scl/bmi.scl rename to examples/legacy/good_scl/bmi.scl diff --git a/examples/good_scl/bmi_2.scl b/examples/legacy/good_scl/bmi_2.scl similarity index 100% rename from examples/good_scl/bmi_2.scl rename to examples/legacy/good_scl/bmi_2.scl diff --git a/examples/legacy/good_scl/categorical_sample.scl b/examples/legacy/good_scl/categorical_sample.scl new file mode 100644 index 0000000..371866e --- /dev/null +++ b/examples/legacy/good_scl/categorical_sample.scl @@ -0,0 +1,7 @@ +const A = 0, B = 1, C = 2 + +rel obj_color = {0.4::(A, "red"); 0.3::(A, "blue"); 0.3::(A, "green")} +rel obj_color = {0.3::(B, "red"); 0.5::(B, "blue"); 0.2::(B, "green")} +rel obj_color = {0.05::(C, "red"); 0.05::(C, "blue"); 0.9::(C, "green")} + +rel sampled_obj_color(obj, c) = c := categorical<1>(c: obj_color(obj, c)) diff --git a/examples/good_scl/count_digit_3_or_4.scl b/examples/legacy/good_scl/count_digit_3_or_4.scl similarity index 100% rename from examples/good_scl/count_digit_3_or_4.scl rename to examples/legacy/good_scl/count_digit_3_or_4.scl diff --git a/examples/legacy/good_scl/date_time_1.scl b/examples/legacy/good_scl/date_time_1.scl new file mode 100644 index 0000000..6e7a4ef --- /dev/null +++ b/examples/legacy/good_scl/date_time_1.scl @@ -0,0 +1,91 @@ +rel event = {(1, t"2021-05-01T01:17:02.604456Z"), (2, t"2019-11-29 08:15:47.624504-08"), (3, t"Wed, 02 Jun 2021 06:31:39 GMT"), + (4, t"2017-07-19 03:21:51+00:00"), (5, t"2014-04-26 05:24:37 PM"), (6, t"2012-08-03 18:31:59.257000000"), + (7, t"2014-12-16 06:20:00 GMT"), (8, t"2021-02-21 PST"), (9, t"2012-08-03 18:31:59.257000000 +0000"), + (10, t"September 17, 2012, 10:10:09"), (11, t"May 6 at 9:24 PM"), (12, t"4:00pm"), + (13, t"14 May 2019 19:11:40.164"), (14, t"oct. 
7, 1970"), (15, t"May 26, 2021, 12:49 AM PDT"), + (16, t"03/19/2012 10:11:59.318636"), (17, t"8/8/1965 01:00 PM"), (18, t"7 oct 70"), + (19, t"171113 14:14:20"), (20, t"03.31.2014"), (21, t"2012/03/19 10:11:59"), + (22, t"2014年04月08日11时25分18秒") + } + +rel duration = {(1, d"15 days 20 seconds 100 milliseconds"), (2, d"14 days seconds"), (3, d".:++++]][][[][15[]][seconds][]:}}}}")} + +//duration and datetime + +rel adding_duration_and_datetime_one(dr, dte, x + y) = event(dte, x), duration(dr, y) + +rel adding_duration_and_datetime_two(dr, dte, y + x) = event(dte, x), duration(dr, y) + +rel subtracting_duration_from_datetime(dr, dte, x - y) = event(dte, x), duration(dr, y) + +//duration and duration + +rel adding_durations(dr_one, dr_two, x + y) = duration(dr_one, x), duration(dr_two, y) + +rel subtracting_durations(dr_one, dr_two, x - y) = duration(dr_one, x), duration(dr_two, y) + +rel eq_durations(dr_one, dr_two) = duration(dr_one, x), duration(dr_two, y), x == y + +rel neq_durations(dr_one, dr_two) = duration(dr_one, x), duration(dr_two, y), x != y + +rel gt_durations(dr_one, dr_two) = duration(dr_one, x), duration(dr_two, y), x > y + +rel gte_durations(dr_one, dr_two) = duration(dr_one, x), duration(dr_two, y), x >= y + +rel lt_durations(dr_one, dr_two) = duration(dr_one, x), duration(dr_two, y), x < y + +rel lte_durations(dr_one, dr_two) = duration(dr_one, x), duration(dr_two, y), x <= y + +//duration and int + +rel div_duration(dr, x/3) = duration(dr, x) + +rel mul_duration_three(dr, x*3) = duration(dr, x) + +rel mul_duration_four(dr, 4*x) = duration(dr, x) + +rel mul_duration_neg(dr, -1*x) = duration(dr, x) + + + +//datetime and datetime + +rel subtracting_datetimes(dt_one, dt_two, x - y) = event(dt_one, x), event(dt_two, y), x >= y + +rel eq_datetimes(dt_one, dt_two) = event(dt_one, x), event(dt_two, y), x == y + +rel neq_datetimes(dt_one, dt_two) = event(dt_one, x), event(dt_two, y), x != y + +rel gt_datetimes(dt_one, dt_two) = event(dt_one, x), event(dt_two, y), x > y + +rel gte_datetimes(dt_one, dt_two) = event(dt_one, x), event(dt_two, y), x >= y + +rel lt_datetimes(dt_one, dt_two) = event(dt_one, x), event(dt_two, y), x < y + +rel lte_datetimes(dt_one, dt_two) = event(dt_one, x), event(dt_two, y), x <= y + +//aggregators for datetime + +rel how_many_datetimes(x) = x = count(a: event(_,a)) + +rel how_many_datetimes_less_than(x, y) = y = count(i: event(i,a), event(x,f), a < f) + +rel exists_a_datetime_less_than(x, y) = y = exists(i: event(i,a), event(x,f), a < f) + +rel closest_later_event(x,y) = y = min[i](dte: event(i, dte), event(x,f), dte > f) + + + +//aggregators for duration + +rel how_many_durations(x) = x = count(a: duration(_,a)) + +rel how_many_durations_less_than(x, y) = y = count(i: duration(i,a), duration(x,f), a < f) + +rel exists_durations_less_than(x, y) = y = exists(i: duration(i,a), duration(x,f), a < f) + +rel sum_durations(x) = x = sum(a: duration(_,a)) + +rel closest_later_duration(x,y) = y = min[i](dte: duration(i, dte), duration(x,f), dte > f) + + diff --git a/examples/good_scl/digit_sum.scl b/examples/legacy/good_scl/digit_sum.scl similarity index 100% rename from examples/good_scl/digit_sum.scl rename to examples/legacy/good_scl/digit_sum.scl diff --git a/examples/good_scl/digit_sum_2.scl b/examples/legacy/good_scl/digit_sum_2.scl similarity index 100% rename from examples/good_scl/digit_sum_2.scl rename to examples/legacy/good_scl/digit_sum_2.scl diff --git a/examples/good_scl/digit_sum_prob.scl b/examples/legacy/good_scl/digit_sum_prob.scl similarity 
index 100% rename from examples/good_scl/digit_sum_prob.scl rename to examples/legacy/good_scl/digit_sum_prob.scl diff --git a/examples/good_scl/double_dice.scl b/examples/legacy/good_scl/double_dice.scl similarity index 100% rename from examples/good_scl/double_dice.scl rename to examples/legacy/good_scl/double_dice.scl diff --git a/examples/good_scl/edge_path.scl b/examples/legacy/good_scl/edge_path.scl similarity index 100% rename from examples/good_scl/edge_path.scl rename to examples/legacy/good_scl/edge_path.scl diff --git a/examples/good_scl/edge_path_2.scl b/examples/legacy/good_scl/edge_path_2.scl similarity index 100% rename from examples/good_scl/edge_path_2.scl rename to examples/legacy/good_scl/edge_path_2.scl diff --git a/examples/good_scl/expr.scl b/examples/legacy/good_scl/expr.scl similarity index 100% rename from examples/good_scl/expr.scl rename to examples/legacy/good_scl/expr.scl diff --git a/examples/good_scl/expr_parse.scl b/examples/legacy/good_scl/expr_parse.scl similarity index 100% rename from examples/good_scl/expr_parse.scl rename to examples/legacy/good_scl/expr_parse.scl diff --git a/examples/good_scl/expr_prob.scl b/examples/legacy/good_scl/expr_prob.scl similarity index 100% rename from examples/good_scl/expr_prob.scl rename to examples/legacy/good_scl/expr_prob.scl diff --git a/examples/good_scl/fib.scl b/examples/legacy/good_scl/fib.scl similarity index 100% rename from examples/good_scl/fib.scl rename to examples/legacy/good_scl/fib.scl diff --git a/examples/good_scl/fib_dt.scl b/examples/legacy/good_scl/fib_dt.scl similarity index 100% rename from examples/good_scl/fib_dt.scl rename to examples/legacy/good_scl/fib_dt.scl diff --git a/examples/good_scl/forall_1.scl b/examples/legacy/good_scl/forall_1.scl similarity index 100% rename from examples/good_scl/forall_1.scl rename to examples/legacy/good_scl/forall_1.scl diff --git a/examples/good_scl/forall_3.scl b/examples/legacy/good_scl/forall_3.scl similarity index 100% rename from examples/good_scl/forall_3.scl rename to examples/legacy/good_scl/forall_3.scl diff --git a/examples/good_scl/has_three_objs.scl b/examples/legacy/good_scl/has_three_objs.scl similarity index 100% rename from examples/good_scl/has_three_objs.scl rename to examples/legacy/good_scl/has_three_objs.scl diff --git a/examples/legacy/good_scl/hashing_1.scl b/examples/legacy/good_scl/hashing_1.scl new file mode 100644 index 0000000..0ed0901 --- /dev/null +++ b/examples/legacy/good_scl/hashing_1.scl @@ -0,0 +1,2 @@ +rel edge = {(0, 1), (1, 2)} +rel result(c) :- c == $hash(a, b), edge(a, b) diff --git a/examples/legacy/good_scl/hashing_3.scl b/examples/legacy/good_scl/hashing_3.scl new file mode 100644 index 0000000..38be42d --- /dev/null +++ b/examples/legacy/good_scl/hashing_3.scl @@ -0,0 +1 @@ +rel result(x) = x == $hash(1, 3) diff --git a/examples/legacy/good_scl/how_many_3.scl b/examples/legacy/good_scl/how_many_3.scl new file mode 100644 index 0000000..2fb4b52 --- /dev/null +++ b/examples/legacy/good_scl/how_many_3.scl @@ -0,0 +1,2 @@ +rel digit = {0.91::(0, 0), 0.01::(0, 1), 0.01::(0, 2), 0.01::(0, 3)} +rel result(n) :- n = count(o: digit(o, 3)) diff --git a/examples/legacy/good_scl/how_many_3_with_disj.scl b/examples/legacy/good_scl/how_many_3_with_disj.scl new file mode 100644 index 0000000..2bfd158 --- /dev/null +++ b/examples/legacy/good_scl/how_many_3_with_disj.scl @@ -0,0 +1,4 @@ +rel digit = {0.91::(0, 0); 0.03::(0, 1); 0.03::(0, 2); 0.03::(0, 3)} +rel digit = {0.02::(1, 0); 0.03::(1, 1); 0.02::(1, 2); 0.93::(1, 3)} +rel 
digit = {0.01::(2, 0); 0.01::(2, 1); 0.01::(2, 2); 0.96::(2, 3)} +rel result(n) :- n = count(o: digit(o, 3)) diff --git a/examples/good_scl/hwf.scl b/examples/legacy/good_scl/hwf.scl similarity index 100% rename from examples/good_scl/hwf.scl rename to examples/legacy/good_scl/hwf.scl diff --git a/examples/good_scl/hwf_parsing.scl b/examples/legacy/good_scl/hwf_parsing.scl similarity index 100% rename from examples/good_scl/hwf_parsing.scl rename to examples/legacy/good_scl/hwf_parsing.scl diff --git a/examples/good_scl/implies.scl b/examples/legacy/good_scl/implies.scl similarity index 100% rename from examples/good_scl/implies.scl rename to examples/legacy/good_scl/implies.scl diff --git a/examples/legacy/good_scl/ite_1.scl b/examples/legacy/good_scl/ite_1.scl new file mode 100644 index 0000000..88d0764 --- /dev/null +++ b/examples/legacy/good_scl/ite_1.scl @@ -0,0 +1,2 @@ +rel a = {true} +rel b(x) = x == (if y then 3 else 4), a(y) diff --git a/examples/legacy/good_scl/kinship_ic_1.scl b/examples/legacy/good_scl/kinship_ic_1.scl new file mode 100644 index 0000000..c396e92 --- /dev/null +++ b/examples/legacy/good_scl/kinship_ic_1.scl @@ -0,0 +1,4 @@ +rel context = {0.9::("father", "A", "B"); 0.01::("mother", "A", "B")} +rel context = {0.7::("uncle", "B", "A"); 0.01::("son", "B", "A"); 0.01::("daughter", "B", "A")} +rel ic(r) = r = forall(a, b: context("father", a, b) => (context("son", b, a) or context("daughter", b, a))) +query ic diff --git a/examples/legacy/good_scl/negate_query_2.scl b/examples/legacy/good_scl/negate_query_2.scl new file mode 100644 index 0000000..3744742 --- /dev/null +++ b/examples/legacy/good_scl/negate_query_2.scl @@ -0,0 +1,3 @@ +rel B("Alice") +rel A() :- ~B(_) +query A diff --git a/examples/good_scl/no_incoming_edge.scl b/examples/legacy/good_scl/no_incoming_edge.scl similarity index 100% rename from examples/good_scl/no_incoming_edge.scl rename to examples/legacy/good_scl/no_incoming_edge.scl diff --git a/examples/good_scl/obj_color.scl b/examples/legacy/good_scl/obj_color.scl similarity index 100% rename from examples/good_scl/obj_color.scl rename to examples/legacy/good_scl/obj_color.scl diff --git a/examples/good_scl/obj_color_2.scl b/examples/legacy/good_scl/obj_color_2.scl similarity index 100% rename from examples/good_scl/obj_color_2.scl rename to examples/legacy/good_scl/obj_color_2.scl diff --git a/examples/good_scl/obj_color_3.scl b/examples/legacy/good_scl/obj_color_3.scl similarity index 100% rename from examples/good_scl/obj_color_3.scl rename to examples/legacy/good_scl/obj_color_3.scl diff --git a/examples/good_scl/odd_even.scl b/examples/legacy/good_scl/odd_even.scl similarity index 100% rename from examples/good_scl/odd_even.scl rename to examples/legacy/good_scl/odd_even.scl diff --git a/examples/good_scl/odd_even_2.scl b/examples/legacy/good_scl/odd_even_2.scl similarity index 100% rename from examples/good_scl/odd_even_2.scl rename to examples/legacy/good_scl/odd_even_2.scl diff --git a/examples/good_scl/odd_even_3.scl b/examples/legacy/good_scl/odd_even_3.scl similarity index 100% rename from examples/good_scl/odd_even_3.scl rename to examples/legacy/good_scl/odd_even_3.scl diff --git a/examples/legacy/good_scl/path_top2proofs.scl b/examples/legacy/good_scl/path_top2proofs.scl new file mode 100644 index 0000000..291b04f --- /dev/null +++ b/examples/legacy/good_scl/path_top2proofs.scl @@ -0,0 +1,20 @@ +// B -- C -- D +// | | | +// A E +// | | +// G -- H + +// Edges +const A = 0, B = 1, C = 2, D = 3, E = 4, F = 5, G = 6, H = 7 +rel is_enemy 
= {0.01::B, 0.01::C, 0.01::D, 0.99::F} +rel raw_edge = {(A, B), (B, C), (C, D), (D, E)} +rel raw_edge = {/*(A, F), (F, E)*/} +rel raw_edge = {(C, F)} +rel raw_edge = {(G, A), (G, H), (H, F)} + +// Recursive rules +rel edge(a, b) = edge(b, a) or (raw_edge(a, b) and not is_enemy(a) and not is_enemy(b)) +rel path(a, c) = edge(a, c) or (path(a, b) and edge(b, c)) + +// Query +query path(H, E) diff --git a/examples/good_scl/prov_fixpoint.scl b/examples/legacy/good_scl/prov_fixpoint.scl similarity index 100% rename from examples/good_scl/prov_fixpoint.scl rename to examples/legacy/good_scl/prov_fixpoint.scl diff --git a/examples/good_scl/query_only.scl b/examples/legacy/good_scl/query_only.scl similarity index 100% rename from examples/good_scl/query_only.scl rename to examples/legacy/good_scl/query_only.scl diff --git a/examples/legacy/good_scl/sample_top_1.scl b/examples/legacy/good_scl/sample_top_1.scl new file mode 100644 index 0000000..517a699 --- /dev/null +++ b/examples/legacy/good_scl/sample_top_1.scl @@ -0,0 +1,2 @@ +rel digit_a = {0.01::0, 0.01::1, 0.2::2, 0.3::3, 0.4::4, 0.01::5, 0.01::6, 0.01::7, 0.01::8, 0.01::9} +rel sampled_digit_a(x) :- x = top<3>(x: digit_a(x)) diff --git a/examples/legacy/good_scl/sample_top_2.scl b/examples/legacy/good_scl/sample_top_2.scl new file mode 100644 index 0000000..8c94f66 --- /dev/null +++ b/examples/legacy/good_scl/sample_top_2.scl @@ -0,0 +1,7 @@ +rel symbol = { + (0, "0"), (0, "1"), (0, "2"), (0, "3"), (0, "4"), (0, "5"), (0, "+"), (0, "-"), (0, "*"), (0, "/"), + (1, "0"), (1, "1"), (1, "2"), (1, "3"), (1, "4"), (1, "5"), (1, "+"), (1, "-"), (1, "*"), (1, "/"), + (2, "0"), (2, "1"), (2, "2"), (2, "3"), (2, "4"), (2, "5"), (2, "+"), (2, "-"), (2, "*"), (2, "/"), +} + +rel sampled_symbols(id, sym) :- sym = top<3>(s: symbol(id, s)) diff --git a/examples/good_scl/spectrl.scl b/examples/legacy/good_scl/spectrl.scl similarity index 100% rename from examples/good_scl/spectrl.scl rename to examples/legacy/good_scl/spectrl.scl diff --git a/examples/good_scl/srl_1.scl b/examples/legacy/good_scl/srl_1.scl similarity index 100% rename from examples/good_scl/srl_1.scl rename to examples/legacy/good_scl/srl_1.scl diff --git a/examples/good_scl/student_grade_1.scl b/examples/legacy/good_scl/student_grade_1.scl similarity index 100% rename from examples/good_scl/student_grade_1.scl rename to examples/legacy/good_scl/student_grade_1.scl diff --git a/examples/good_scl/student_grade_2.scl b/examples/legacy/good_scl/student_grade_2.scl similarity index 100% rename from examples/good_scl/student_grade_2.scl rename to examples/legacy/good_scl/student_grade_2.scl diff --git a/examples/legacy/good_scl/sum_1.scl b/examples/legacy/good_scl/sum_1.scl new file mode 100644 index 0000000..28b44cc --- /dev/null +++ b/examples/legacy/good_scl/sum_1.scl @@ -0,0 +1,2 @@ +rel color_num_obj :- {("blue", 1), ("red", 3), ("yellow", 6)} +rel num_obj(n) :- n = sum(y: color_num_obj(_, y)) diff --git a/examples/good_scl/temporal_1.scl b/examples/legacy/good_scl/temporal_1.scl similarity index 100% rename from examples/good_scl/temporal_1.scl rename to examples/legacy/good_scl/temporal_1.scl diff --git a/examples/good_scl/temporal_2.scl b/examples/legacy/good_scl/temporal_2.scl similarity index 100% rename from examples/good_scl/temporal_2.scl rename to examples/legacy/good_scl/temporal_2.scl diff --git a/examples/input_csv/edge.csv b/examples/legacy/input_csv/edge.csv similarity index 100% rename from examples/input_csv/edge.csv rename to examples/legacy/input_csv/edge.csv diff --git 
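The top<3> sampling used in sample_top_1.scl and sample_top_2.scl above can also be driven from Python through the ScallopContext/forward_function pattern that the rest of this patch uses. A minimal sketch, assuming the same "difftopkproofs" provenance as the MNIST experiments; the variable names here are illustrative:

import torch
import scallopy

# Mirror sample_top_1.scl: keep the top-3 most likely digits of a distribution
ctx = scallopy.ScallopContext(provenance="difftopkproofs", k=3)
ctx.add_relation("digit_a", int, input_mapping=list(range(10)))
ctx.add_rule("sampled_digit_a(x) :- x = top<3>(x: digit_a(x))")
sample_top_3 = ctx.forward_function("sampled_digit_a", list(range(10)))

digit_a = torch.softmax(torch.randn(1, 10), dim=1)  # one distribution over digits 0..9
scores = sample_top_3(digit_a=digit_a)               # tensor of shape (1, 10)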
a/examples/input_csv/edge_prob.csv b/examples/legacy/input_csv/edge_prob.csv similarity index 100% rename from examples/input_csv/edge_prob.csv rename to examples/legacy/input_csv/edge_prob.csv diff --git a/examples/invalid_scl/arity_mismatch.scl b/examples/legacy/invalid_scl/arity_mismatch.scl similarity index 56% rename from examples/invalid_scl/arity_mismatch.scl rename to examples/legacy/invalid_scl/arity_mismatch.scl index acc62ac..b90ac9f 100644 --- a/examples/invalid_scl/arity_mismatch.scl +++ b/examples/legacy/invalid_scl/arity_mismatch.scl @@ -1,2 +1,2 @@ -input edge(a: usize, b: usize) +type edge(a: usize, b: usize) rel path(a, b) :- edge(a, b, c), edge(a) diff --git a/examples/legacy/invalid_scl/conflicting_constant_decl_type.scl b/examples/legacy/invalid_scl/conflicting_constant_decl_type.scl new file mode 100644 index 0000000..309e65d --- /dev/null +++ b/examples/legacy/invalid_scl/conflicting_constant_decl_type.scl @@ -0,0 +1,3 @@ +const V: usize = 5 +type r(u32, usize) +rel r = {(V, 3), (3, V)} diff --git a/examples/invalid_scl/dup_input.scl b/examples/legacy/invalid_scl/dup_input.scl similarity index 100% rename from examples/invalid_scl/dup_input.scl rename to examples/legacy/invalid_scl/dup_input.scl diff --git a/examples/invalid_scl/dup_type_decl.scl b/examples/legacy/invalid_scl/dup_type_decl.scl similarity index 100% rename from examples/invalid_scl/dup_type_decl.scl rename to examples/legacy/invalid_scl/dup_type_decl.scl diff --git a/examples/invalid_scl/dup_type_decl_2.scl b/examples/legacy/invalid_scl/dup_type_decl_2.scl similarity index 100% rename from examples/invalid_scl/dup_type_decl_2.scl rename to examples/legacy/invalid_scl/dup_type_decl_2.scl diff --git a/examples/legacy/invalid_scl/hashing_2.scl b/examples/legacy/invalid_scl/hashing_2.scl new file mode 100644 index 0000000..fd19564 --- /dev/null +++ b/examples/legacy/invalid_scl/hashing_2.scl @@ -0,0 +1,2 @@ +rel edge = {(0, 1), (1, 2)} +rel result(c) :- c == $hash(a, x), edge(a, b) diff --git a/examples/invalid_scl/invalid_input_ext.scl b/examples/legacy/invalid_scl/invalid_input_ext.scl similarity index 100% rename from examples/invalid_scl/invalid_input_ext.scl rename to examples/legacy/invalid_scl/invalid_input_ext.scl diff --git a/examples/invalid_scl/invalid_query.scl b/examples/legacy/invalid_scl/invalid_query.scl similarity index 100% rename from examples/invalid_scl/invalid_query.scl rename to examples/legacy/invalid_scl/invalid_query.scl diff --git a/examples/invalid_scl/invalid_query_type.scl b/examples/legacy/invalid_scl/invalid_query_type.scl similarity index 100% rename from examples/invalid_scl/invalid_query_type.scl rename to examples/legacy/invalid_scl/invalid_query_type.scl diff --git a/examples/invalid_scl/invalid_wildcard_1.scl b/examples/legacy/invalid_scl/invalid_wildcard_1.scl similarity index 100% rename from examples/invalid_scl/invalid_wildcard_1.scl rename to examples/legacy/invalid_scl/invalid_wildcard_1.scl diff --git a/examples/invalid_scl/odd_even_non_stratified.scl b/examples/legacy/invalid_scl/odd_even_non_stratified.scl similarity index 100% rename from examples/invalid_scl/odd_even_non_stratified.scl rename to examples/legacy/invalid_scl/odd_even_non_stratified.scl diff --git a/examples/invalid_scl/type_error_arith.scl b/examples/legacy/invalid_scl/type_error_arith.scl similarity index 100% rename from examples/invalid_scl/type_error_arith.scl rename to examples/legacy/invalid_scl/type_error_arith.scl diff --git a/examples/invalid_scl/type_error_bad_cmp.scl 
b/examples/legacy/invalid_scl/type_error_bad_cmp.scl similarity index 100% rename from examples/invalid_scl/type_error_bad_cmp.scl rename to examples/legacy/invalid_scl/type_error_bad_cmp.scl diff --git a/examples/invalid_scl/type_error_bad_const.scl b/examples/legacy/invalid_scl/type_error_bad_const.scl similarity index 100% rename from examples/invalid_scl/type_error_bad_const.scl rename to examples/legacy/invalid_scl/type_error_bad_const.scl diff --git a/examples/invalid_scl/type_error_cast.scl b/examples/legacy/invalid_scl/type_error_cast.scl similarity index 100% rename from examples/invalid_scl/type_error_cast.scl rename to examples/legacy/invalid_scl/type_error_cast.scl diff --git a/examples/invalid_scl/type_error_constraint.scl b/examples/legacy/invalid_scl/type_error_constraint.scl similarity index 100% rename from examples/invalid_scl/type_error_constraint.scl rename to examples/legacy/invalid_scl/type_error_constraint.scl diff --git a/examples/invalid_scl/type_error_count.scl b/examples/legacy/invalid_scl/type_error_count.scl similarity index 100% rename from examples/invalid_scl/type_error_count.scl rename to examples/legacy/invalid_scl/type_error_count.scl diff --git a/examples/invalid_scl/type_error_not.scl b/examples/legacy/invalid_scl/type_error_not.scl similarity index 100% rename from examples/invalid_scl/type_error_not.scl rename to examples/legacy/invalid_scl/type_error_not.scl diff --git a/examples/invalid_scl/type_error_num_cmp.scl b/examples/legacy/invalid_scl/type_error_num_cmp.scl similarity index 100% rename from examples/invalid_scl/type_error_num_cmp.scl rename to examples/legacy/invalid_scl/type_error_num_cmp.scl diff --git a/examples/invalid_scl/type_error_rela.scl b/examples/legacy/invalid_scl/type_error_rela.scl similarity index 100% rename from examples/invalid_scl/type_error_rela.scl rename to examples/legacy/invalid_scl/type_error_rela.scl diff --git a/examples/invalid_scl/unbound_1.scl b/examples/legacy/invalid_scl/unbound_1.scl similarity index 100% rename from examples/invalid_scl/unbound_1.scl rename to examples/legacy/invalid_scl/unbound_1.scl diff --git a/examples/invalid_scl/unbound_2.scl b/examples/legacy/invalid_scl/unbound_2.scl similarity index 100% rename from examples/invalid_scl/unbound_2.scl rename to examples/legacy/invalid_scl/unbound_2.scl diff --git a/examples/invalid_scl/unbound_3.scl b/examples/legacy/invalid_scl/unbound_3.scl similarity index 100% rename from examples/invalid_scl/unbound_3.scl rename to examples/legacy/invalid_scl/unbound_3.scl diff --git a/examples/legacy/invalid_scl/undeclared_relation.scl b/examples/legacy/invalid_scl/undeclared_relation.scl new file mode 100644 index 0000000..f81c515 --- /dev/null +++ b/examples/legacy/invalid_scl/undeclared_relation.scl @@ -0,0 +1,2 @@ +rel path(a, b) = edge(a, b) +rel path(a, c) = path(a, b), edge(b, c) diff --git a/examples/invalid_scl/unknown_type.scl b/examples/legacy/invalid_scl/unknown_type.scl similarity index 100% rename from examples/invalid_scl/unknown_type.scl rename to examples/legacy/invalid_scl/unknown_type.scl diff --git a/examples/tutorial_scl/graph_algo.scl b/examples/legacy/tutorial_scl/graph_algo.scl similarity index 100% rename from examples/tutorial_scl/graph_algo.scl rename to examples/legacy/tutorial_scl/graph_algo.scl diff --git a/examples/tutorial_scl/graph_algo_autograder.py b/examples/legacy/tutorial_scl/graph_algo_autograder.py similarity index 100% rename from examples/tutorial_scl/graph_algo_autograder.py rename to 
examples/legacy/tutorial_scl/graph_algo_autograder.py diff --git a/examples/tutorial_scl/graph_algo_example.scl b/examples/legacy/tutorial_scl/graph_algo_example.scl similarity index 100% rename from examples/tutorial_scl/graph_algo_example.scl rename to examples/legacy/tutorial_scl/graph_algo_example.scl diff --git a/examples/tutorial_scl/relations.scl b/examples/legacy/tutorial_scl/relations.scl similarity index 100% rename from examples/tutorial_scl/relations.scl rename to examples/legacy/tutorial_scl/relations.scl diff --git a/examples/tutorial_scl/scene_graph.scl b/examples/legacy/tutorial_scl/scene_graph.scl similarity index 100% rename from examples/tutorial_scl/scene_graph.scl rename to examples/legacy/tutorial_scl/scene_graph.scl diff --git a/examples/tutorial_scl/scene_graph_example_1.scl b/examples/legacy/tutorial_scl/scene_graph_example_1.scl similarity index 100% rename from examples/tutorial_scl/scene_graph_example_1.scl rename to examples/legacy/tutorial_scl/scene_graph_example_1.scl diff --git a/examples/tutorial_scl/visual_question_answering.scl b/examples/legacy/tutorial_scl/visual_question_answering.scl similarity index 100% rename from examples/tutorial_scl/visual_question_answering.scl rename to examples/legacy/tutorial_scl/visual_question_answering.scl diff --git a/examples/probabilistic/alarm.scl b/examples/probabilistic/alarm.scl new file mode 100644 index 0000000..db9f8c2 --- /dev/null +++ b/examples/probabilistic/alarm.scl @@ -0,0 +1,4 @@ +rel 0.05::earthquake() +rel 0.90::burglary() +rel 0.95::alarm() :- earthquake() or burglary() +query alarm() diff --git a/examples/probabilistic/digit_less_than.scl b/examples/probabilistic/digit_less_than.scl new file mode 100644 index 0000000..a7dbf85 --- /dev/null +++ b/examples/probabilistic/digit_less_than.scl @@ -0,0 +1,29 @@ +rel digit_1 = { + 0.01::0, + 0.01::1, + 0.01::2, + 0.91::3, + 0.01::4, + 0.01::5, + 0.01::6, + 0.01::7, + 0.01::8, + 0.01::9, +} + +rel digit_2 = { + 0.01::0, + 0.01::1, + 0.01::2, + 0.02::3, + 0.01::4, + 0.01::5, + 0.01::6, + 0.90::7, + 0.01::8, + 0.01::9, +} + +rel less_than(a < b) = digit_1(a) and digit_2(b) + +query less_than diff --git a/examples/probabilistic/digit_sum_2.scl b/examples/probabilistic/digit_sum_2.scl new file mode 100644 index 0000000..750ca0d --- /dev/null +++ b/examples/probabilistic/digit_sum_2.scl @@ -0,0 +1,29 @@ +rel digit_1 = { + 0.01::0, + 0.01::1, + 0.01::2, + 0.91::3, + 0.01::4, + 0.01::5, + 0.01::6, + 0.01::7, + 0.01::8, + 0.01::9, +} + +rel digit_2 = { + 0.01::0, + 0.01::1, + 0.01::2, + 0.02::3, + 0.01::4, + 0.01::5, + 0.01::6, + 0.90::7, + 0.01::8, + 0.01::9, +} + +rel sum_2(a + b) = digit_1(a) and digit_2(b) + +query sum_2 diff --git a/experiments/mnist/docs/+_964.jpg b/experiments/mnist/docs/+_964.jpg deleted file mode 100644 index d0b80f6..0000000 Binary files a/experiments/mnist/docs/+_964.jpg and /dev/null differ diff --git a/experiments/mnist/docs/3_226.jpg b/experiments/mnist/docs/3_226.jpg deleted file mode 100644 index 7953e93..0000000 Binary files a/experiments/mnist/docs/3_226.jpg and /dev/null differ diff --git a/experiments/mnist/docs/5_42754.jpg b/experiments/mnist/docs/5_42754.jpg deleted file mode 100644 index 25dcb28..0000000 Binary files a/experiments/mnist/docs/5_42754.jpg and /dev/null differ diff --git a/experiments/mnist/how_many_3.py b/experiments/mnist/how_many_3.py index 124f5ef..f92eb79 100644 --- a/experiments/mnist/how_many_3.py +++ b/experiments/mnist/how_many_3.py @@ -128,18 +128,18 @@ def __init__(self, num_elements, provenance, k): # 
Scallop Context self.scl_ctx = scallopy.ScallopContext(provenance=provenance, k=k) - self.scl_ctx.add_relation("digit", (int, int), input_mapping=[(i, d) for i in range(num_elements) for d in range(10)]) + self.scl_ctx.add_relation("digit", (int, int), input_mapping={0: range(num_elements), 1: range(10)}) self.scl_ctx.add_rule("how_many_3(x) :- x = count(o: digit(o, 3))") # The `how_many_3` logical reasoning module - self.how_many_3 = self.scl_ctx.forward_function("how_many_3", list(range(num_elements + 1))) + if args.debug_tag: + self.how_many_3 = self.scl_ctx.forward_function("how_many_3", list(range(num_elements + 1)), debug_provenance=True, dispatch="single") + else: + self.how_many_3 = self.scl_ctx.forward_function("how_many_3", list(range(num_elements + 1))) def forward(self, x: List[torch.Tensor]): - # Apply mnist net on each image - digit_distrs = [self.mnist_net(imgs) for imgs in x] - - # Concatenate them into the same big tensor - digit = torch.cat(tuple(digit_distrs), dim=1) + # Apply mnist net on each image and stack them into a big tensor + digit = torch.stack([self.mnist_net(imgs) for imgs in x], dim=1) # Then execute the reasoning module; the result is a size `num_elements + 1` tensor return self.how_many_3(digit=digit) @@ -214,6 +214,8 @@ def train(self, n_epochs): parser.add_argument("--seed", type=int, default=1234) parser.add_argument("--provenance", type=str, default="difftopbottomkclauses") parser.add_argument("--top-k", type=int, default=3) + parser.add_argument("--debug-tag", action="store_true") + parser.add_argument("--softmax", action="store_true") args = parser.parse_args() # Parameters diff --git a/experiments/mnist/how_many_3_or_4.py b/experiments/mnist/how_many_3_or_4.py index a03d8e6..9a16919 100644 --- a/experiments/mnist/how_many_3_or_4.py +++ b/experiments/mnist/how_many_3_or_4.py @@ -129,18 +129,15 @@ def __init__(self, num_elements, provenance, k): # Scallop Context self.scl_ctx = scallopy.ScallopContext(provenance=provenance, k=k) - self.scl_ctx.add_relation("digit", (int, int), input_mapping=[(i, d) for i in range(num_elements) for d in range(10)]) + self.scl_ctx.add_relation("digit", (int, int), input_mapping={0: range(num_elements), 1: range(10)}) self.scl_ctx.add_rule("how_many(x) :- x = count(o: digit(o, 3) or digit(o, 4))") # The `how_many` logical reasoning module - self.how_many = self.scl_ctx.forward_function("how_many", list(range(num_elements + 1))) + self.how_many = self.scl_ctx.forward_function("how_many", list(range(num_elements + 1)), jit=args.jit) def forward(self, x: List[torch.Tensor]): - # Apply mnist net on each image - digit_distrs = [self.mnist_net(imgs) for imgs in x] - - # Concatenate them into the same big tensor - digit = torch.cat(tuple(digit_distrs), dim=1) + # Apply mnist net on each image and stack them into the same big tensor + digit = torch.stack([self.mnist_net(imgs) for imgs in x], dim=1) # Then execute the reasoning module; the result is a size `num_elements + 1` tensor return self.how_many(digit=digit) @@ -197,7 +194,6 @@ def test_epoch(self, epoch): def train(self, n_epochs): self.test_epoch(0) - exit() for epoch in range(1, n_epochs + 1): self.train_epoch(epoch) self.test_epoch(epoch) @@ -227,6 +223,7 @@ def print_random_distribution(num_elements): parser.add_argument("--seed", type=int, default=1234) parser.add_argument("--provenance", type=str, default="difftopbottomkclauses") parser.add_argument("--top-k", type=int, default=2) + parser.add_argument("--jit", action="store_true") args = parser.parse_args() # 
Parameters @@ -238,8 +235,6 @@ def print_random_distribution(num_elements): # Dataloaders train_loader, test_loader = mnist_how_many_3_or_4_loader(data_dir, args.num_elements, args.dataset_up_scale, args.batch_size_train, args.batch_size_test) - - print_random_distribution(args.num_elements) # Create trainer and train diff --git a/experiments/mnist/plot_confusion_matrix.py b/experiments/mnist/plot_confusion_matrix.py new file mode 100644 index 0000000..c4214d5 --- /dev/null +++ b/experiments/mnist/plot_confusion_matrix.py @@ -0,0 +1,62 @@ +from argparse import ArgumentParser +import os + +# Computation +import torch +import torchvision +import numpy +from sklearn.metrics import confusion_matrix + +# Plotting +import matplotlib.pyplot as plt +import seaborn as sn +import pandas as pd + +if __name__ == "__main__": + # Argument parser + parser = ArgumentParser("plot_confusion_matrix") + parser.add_argument("--model-file", default="mnist_sort_2/mnist_sort_2_net.pkl") + parser.add_argument("--batch-size", type=int, default=64) + parser.add_argument("--plot-image", action="store_true") + parser.add_argument("--image-file", default="confusion.png") + parser.add_argument("--task", type=str, default="sum_2") + args = parser.parse_args() + + if args.task == "sum_2": + import sum_2 as module + from sum_2 import MNISTSum2Net, MNISTNet + + # Directories + data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), "../../data")) + model_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), "../../model")) + + # Load mnist dataset + mnist_dataset = torchvision.datasets.MNIST(data_dir, train=False, download=True, transform=module.mnist_img_transform) + mnist_loader = torch.utils.data.DataLoader(mnist_dataset, batch_size=args.batch_size) + + # Load model + mnist_subtask_net = torch.load(open(os.path.join(model_dir, args.model_file), "rb")) + mnist_net = mnist_subtask_net.mnist_net + mnist_net.eval() + + # Get prediction result + y_true, y_pred = [], [] + with torch.no_grad(): + for (imgs, digits) in mnist_loader: + pred_digits = numpy.argmax(mnist_net(imgs), axis=1) + y_true += [d.item() for d in digits] + y_pred += [d.item() for d in pred_digits] + + # Compute confusion matrix + cm = confusion_matrix(y_true, y_pred) + + # Plot image or print + if args.plot_image: + df_cm = pd.DataFrame(cm, index=list(range(10)), columns=list(range(10))) + plt.figure(figsize=(10,7)) + sn.heatmap(df_cm, annot=True, cmap=plt.cm.Blues) + plt.ylabel("Actual") + plt.xlabel("Predicted") + plt.savefig(args.image_file) + else: + print(cm) diff --git a/experiments/mnist/sort_2/run.py b/experiments/mnist/sort_2/run.py index d03e883..15a760f 100644 --- a/experiments/mnist/sort_2/run.py +++ b/experiments/mnist/sort_2/run.py @@ -110,7 +110,7 @@ def forward(self, x): class MNISTSort2Net(nn.Module): - def __init__(self, provenance, train_k, test_k, wmc_type, max_digit=9): + def __init__(self, provenance, train_k, test_k, max_digit=9): super(MNISTSort2Net, self).__init__() self.max_digit = max_digit self.num_classes = max_digit + 1 @@ -119,7 +119,7 @@ def __init__(self, provenance, train_k, test_k, wmc_type, max_digit=9): self.mnist_net = MNISTNet(num_classes=self.num_classes) # Scallop Context - self.scl_ctx = scallopy.ScallopContext(provenance=provenance, train_k=train_k, test_k=test_k, wmc_type=wmc_type) + self.scl_ctx = scallopy.ScallopContext(provenance=provenance, train_k=train_k, test_k=test_k) self.scl_ctx.add_relation("digit_1", int, input_mapping=list(range(self.num_classes))) self.scl_ctx.add_relation("digit_2", 
int, input_mapping=list(range(self.num_classes))) self.scl_ctx.add_rule("less_than(a < b) = digit_1(a), digit_2(b)") @@ -149,8 +149,8 @@ def nll_loss(output, ground_truth): class Trainer(): - def __init__(self, train_loader, test_loader, model_dir, learning_rate, loss, train_k, test_k, provenance, wmc_type, max_digit=9): - self.network = MNISTSort2Net(provenance, train_k=train_k, test_k=test_k, wmc_type=wmc_type, max_digit=max_digit) + def __init__(self, train_loader, test_loader, model_dir, learning_rate, loss, train_k, test_k, provenance, max_digit=9): + self.network = MNISTSort2Net(provenance, train_k=train_k, test_k=test_k, max_digit=max_digit) self.optimizer = optim.Adam(self.network.parameters(), lr=learning_rate) self.train_loader = train_loader self.test_loader = test_loader @@ -207,7 +207,6 @@ def run(self, n_epochs): parser.add_argument("--provenance", type=str, default="difftopkproofs") parser.add_argument("--train-k", type=int, default=3) parser.add_argument("--test-k", type=int, default=3) - parser.add_argument("--wmc-type", type=str, default="bottom-up") parser.add_argument("--max-digit", type=int, default=9) args = parser.parse_args() @@ -224,5 +223,5 @@ def run(self, n_epochs): train_loader, test_loader = mnist_sort_2_loader(data_dir, args.batch_size, max_digit=args.max_digit) # Create trainer and train - trainer = Trainer(train_loader, test_loader, model_dir, args.learning_rate, args.loss_fn, args.train_k, args.test_k, args.provenance, args.wmc_type, max_digit=args.max_digit) + trainer = Trainer(train_loader, test_loader, model_dir, args.learning_rate, args.loss_fn, args.train_k, args.test_k, args.provenance, max_digit=args.max_digit) trainer.run(args.n_epochs) diff --git a/experiments/mnist/sum_2.py b/experiments/mnist/sum_2.py index c0cc6a2..26b60e7 100644 --- a/experiments/mnist/sum_2.py +++ b/experiments/mnist/sum_2.py @@ -119,7 +119,7 @@ def __init__(self, provenance, k): self.scl_ctx.add_rule("sum_2(a + b) :- digit_1(a), digit_2(b)") # The `sum_2` logical reasoning module - self.sum_2 = self.scl_ctx.forward_function("sum_2", output_mapping=[(i,) for i in range(19)]) + self.sum_2 = self.scl_ctx.forward_function("sum_2", output_mapping=[(i,) for i in range(19)], jit=args.jit, dispatch=args.dispatch) def forward(self, x: Tuple[torch.Tensor, torch.Tensor]): (a_imgs, b_imgs) = x @@ -143,11 +143,13 @@ def nll_loss(output, ground_truth): class Trainer(): - def __init__(self, train_loader, test_loader, learning_rate, loss, k, provenance): + def __init__(self, train_loader, test_loader, model_dir, learning_rate, loss, k, provenance): + self.model_dir = model_dir self.network = MNISTSum2Net(provenance, k) self.optimizer = optim.Adam(self.network.parameters(), lr=learning_rate) self.train_loader = train_loader self.test_loader = test_loader + self.best_loss = 10000000000 if loss == "nll": self.loss = nll_loss elif loss == "bce": @@ -166,7 +168,7 @@ def train_epoch(self, epoch): self.optimizer.step() iter.set_description(f"[Train Epoch {epoch}] Loss: {loss.item():.4f}") - def test(self, epoch): + def test_epoch(self, epoch): self.network.eval() num_items = len(self.test_loader.dataset) test_loss = 0 @@ -180,12 +182,15 @@ def test(self, epoch): correct += pred.eq(target.data.view_as(pred)).sum() perc = 100. 
* correct / num_items iter.set_description(f"[Test Epoch {epoch}] Total loss: {test_loss:.4f}, Accuracy: {correct}/{num_items} ({perc:.2f}%)") + if test_loss < self.best_loss: + self.best_loss = test_loss + torch.save(self.network, os.path.join(model_dir, "sum_2_best.pt")) def train(self, n_epochs): - self.test(0) + self.test_epoch(0) for epoch in range(1, n_epochs + 1): self.train_epoch(epoch) - self.test(epoch) + self.test_epoch(epoch) if __name__ == "__main__": @@ -199,6 +204,8 @@ def train(self, n_epochs): parser.add_argument("--seed", type=int, default=1234) parser.add_argument("--provenance", type=str, default="difftopkproofs") parser.add_argument("--top-k", type=int, default=3) + parser.add_argument("--jit", action="store_true") + parser.add_argument("--dispatch", type=str, default="parallel") args = parser.parse_args() # Parameters @@ -214,10 +221,12 @@ def train(self, n_epochs): # Data data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), "../../data")) + model_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), "../../model/mnist_sum_2")) + os.makedirs(model_dir, exist_ok=True) # Dataloaders train_loader, test_loader = mnist_sum_2_loader(data_dir, batch_size_train, batch_size_test) # Create trainer and train - trainer = Trainer(train_loader, test_loader, learning_rate, loss_fn, k, provenance) + trainer = Trainer(train_loader, test_loader, model_dir, learning_rate, loss_fn, k, provenance) trainer.train(n_epochs) diff --git a/experiments/pacman_maze/arena.py b/experiments/pacman_maze/arena.py new file mode 100644 index 0000000..b15a5fe --- /dev/null +++ b/experiments/pacman_maze/arena.py @@ -0,0 +1,216 @@ +from typing import * +import gym +import numpy +import cv2 +import random +import torch +import os + +RES_DIR = os.path.abspath(os.path.join(os.path.abspath(__file__), "../res")) + +class AvoidingArena(gym.Env): + def __init__( + self, + grid_dim: Tuple[int, int] =(5, 5), + cell_size: float = 0.5, + dpi: int = 80, + num_enemies: int = 5, + easy: bool = False, + default_reward: float = 0.00, + on_success_reward: float = 1.0, + on_failure_reward: float = 0.0, + remain_unchanged_reward: float = 0.0, + ): + """ + :param grid_dim, (int, int), a tuple of two integers for (grid_x, grid_y) + :param cell_size, float, the side length of each cell, in inches + :param dpi, int, dimension (number of pixels) per inch + :param num_enemies, int, maximum number of enemies on the arena + """ + self.grid_x, self.grid_y = grid_dim + self.cell_size = cell_size + self.dpi = dpi + self.num_enemies = num_enemies + self.image_w, self.image_h = self.grid_x * self.cell_size, self.grid_y * self.cell_size + self.easy = easy + self.default_reward = default_reward + self.on_success_reward = on_success_reward + self.on_failure_reward = on_failure_reward + self.remain_unchanged_reward = remain_unchanged_reward + + # Initialize environment states + self.curr_pos = None + self.start_pos = None + self.goal_pos = None + self.enemies = None + + # Load background and enemy images + self.background_image = cv2.imread(os.path.join(RES_DIR, "back.webp")) + enemy_image_1 = cv2.imread(os.path.join(RES_DIR, "enemy1.webp"), cv2.IMREAD_UNCHANGED) + enemy_image_2 = cv2.imread(os.path.join(RES_DIR, "enemy2.webp"), cv2.IMREAD_UNCHANGED) + self.enemy_images = [enemy_image_1, enemy_image_2] + self.goal_image = cv2.imread(os.path.join(RES_DIR, "flag.png"), cv2.IMREAD_UNCHANGED) + self.agent_image = cv2.imread(os.path.join(RES_DIR, "agent.png"), cv2.IMREAD_UNCHANGED) + + def reset(self): + # Generate start 
position + self.start_pos = self.sample_point() + self.curr_pos = self.start_pos + + if self.easy: + possible_positions = [] + for (off_x, off_y) in [(0, 1), (1, 0), (0, -1), (1, 0)]: + goal_x = self.start_pos[0] + off_x + goal_y = self.start_pos[1] + off_y + if 0 <= goal_x < self.grid_x and 0 <= goal_y < self.grid_y: + possible_positions.append((goal_x, goal_y)) + self.goal_pos = possible_positions[random.randrange(0, len(possible_positions))] + else: + # Generate end position + self.goal_pos = self.sample_point() + while self.goal_pos == self.start_pos: + self.goal_pos = self.sample_point() + + # Generate enemy positions + self.enemies = [] + self.enemy_types = [] + num_tries = 0 + while len(self.enemies) < self.num_enemies and num_tries < 100: + num_tries += 1 + try_pos = self.sample_point() + if self.ok_enemy_position(try_pos): + self.enemies.append(try_pos) + self.enemy_types.append(random.randint(0, 1)) + + # Return + return () + + def step(self, action): + prev_pos = self.curr_pos + + # Step action + if action == 0 and self.curr_pos[1] < self.grid_y - 1: # If can move up + self.curr_pos = (self.curr_pos[0], self.curr_pos[1] + 1) + elif action == 1 and self.curr_pos[0] < self.grid_x - 1: # If can move right + self.curr_pos = (self.curr_pos[0] + 1, self.curr_pos[1]) + elif action == 2 and self.curr_pos[1] > 0: # If can move down + self.curr_pos = (self.curr_pos[0], self.curr_pos[1] - 1) + elif action == 3 and self.curr_pos[0] > 0: # If can move left + self.curr_pos = (self.curr_pos[0] - 1, self.curr_pos[1]) + + # Check if reached goal position + done, reward = False, self.default_reward + if self.curr_pos in self.enemies: done, reward = True, self.on_failure_reward # Hitting enemy + elif self.curr_pos == self.goal_pos: done, reward = True, self.on_success_reward # Reaching goal + elif self.curr_pos == prev_pos: done, reward = False, self.remain_unchanged_reward # Stay in same position + + # Return + return ((), done, reward, ()) + + def hidden_state(self): + """ + Return a tuple (current_position, goal_position, enemy_positions) + where enemy positions is a list of enemy positions + + This hidden_state should not be used by model that desires to solve the game + """ + return (self.curr_pos, self.goal_pos, self.enemies) + + def render(self): + w, h = int(self.image_w * self.dpi), int(self.image_h * self.dpi) + # image = numpy.zeros((w, h, 3), dtype=numpy.uint8) + image = numpy.zeros((h, w, 3), dtype=numpy.uint8) + + # Setup the background + image[0:h, 0:w] = self.background_image[0:h, 0:w] + + # Draw the current position + self._paint_spirit(image, self.agent_image, self.curr_pos) + + # Draw the goal position + self._paint_spirit(image, self.goal_image, self.goal_pos) + + # Draw the enemy position + for (i, enemy_pos) in enumerate(self.enemies): + self._paint_spirit(image, self.enemy_images[self.enemy_types[i]], enemy_pos) + + return image + + def render_torch_tensor(self, image=None): + image = self.render() if image is None else image + image = numpy.ascontiguousarray(image, dtype=numpy.float32) / 255 + torch_image = torch.tensor(image).permute(2, 0, 1).float() + return torch.stack([torch_image]) + + def _paint_spirit(self, background, spirit, orig_cell_pos): + cell_pos = (orig_cell_pos[0], self.grid_y - orig_cell_pos[1] - 1) + cell_w, cell_h = self.cell_pixel_size() + agent_image = cv2.resize(spirit, (cell_w, cell_h), interpolation=cv2.INTER_AREA) + agent_offset_x, agent_offset_y = cell_pos[0] * cell_w, cell_pos[1] * cell_h + agent_end_x, agent_end_y = agent_offset_x + cell_w, 
agent_offset_y + cell_h + agent_img_gray = agent_image[:, :, 3] + _, mask = cv2.threshold(agent_img_gray, 120, 255, cv2.THRESH_BINARY) + mask_inv = cv2.bitwise_not(agent_img_gray) + source = background[agent_offset_y:agent_end_y, agent_offset_x:agent_end_x] + bg = cv2.bitwise_or(source, source, mask=mask_inv) + fg = cv2.bitwise_and(agent_image, agent_image, mask=mask) + background[agent_offset_y:agent_end_y, agent_offset_x:agent_end_x] = cv2.add(bg, fg[:, :, 0:3]) + + def paint_color(self, background, colors, cell_pos): + size_x, size_y = 10, 10 + cell_pos = (cell_pos[0], self.grid_y - cell_pos[1] - 1) + cell_w, cell_h = self.cell_pixel_size() + agent_offset_x, agent_offset_y = cell_pos[0] * cell_w, cell_pos[1] * cell_h + agent_end_x, agent_end_y = agent_offset_x + size_x, agent_offset_y + size_y + get_channel = lambda c: numpy.ones((size_y, size_x), dtype=numpy.uint8) * int(255 * colors[c]) + color = numpy.transpose(numpy.stack([get_channel(i) for i in range(3)]), (1, 2, 0)) + background[agent_offset_y:agent_end_y, agent_offset_x:agent_end_x] = color + + def print_state(self): + print("┌" + ("─" * ((self.grid_x + 2) * 2 - 3)) + "┐") + for j in range(self.grid_y - 1, -1, -1): + print("│", end=" ") + for i in range(self.grid_x): + print(self.pos_char((i, j)), end=" ") + print("│") + print("└" + ("─" * ((self.grid_x + 2) * 2 - 3)) + "┘") + + def pos_char(self, pos): + if pos == self.curr_pos: return 'C' + elif pos == self.start_pos: return 'S' + elif pos == self.goal_pos: return 'G' + elif pos in self.enemies: return '▒' + else: return ' ' + + def string_of_action(self, action): + if action == 0: return "up" + elif action == 1: return "right" + elif action == 2: return "down" + elif action == 3: return "left" + else: raise Exception(f"Unknown action `{action}`") + + def sample_point(self): + return (random.randint(0, self.grid_x - 1), random.randint(0, self.grid_y - 1)) + + def ok_enemy_position(self, pos): + return self.manhatten_distance(pos, self.start_pos) > 1 and self.manhatten_distance(pos, self.goal_pos) > 1 + + def cell_pixel_size(self): + return (int(self.cell_size * self.dpi), int(self.cell_size * self.dpi)) + + def manhatten_distance(self, p1, p2): + return abs(p1[0] - p2[0]) + abs(p1[1] - p2[1]) + +def crop_cell_image(image, grid_dim, cell_pixel_size, orig_cell_pos): + cell_pos = (orig_cell_pos[0], grid_dim[1] - orig_cell_pos[1] - 1) + cell_w, cell_h = cell_pixel_size + agent_offset_x, agent_offset_y = cell_pos[0] * cell_w, cell_pos[1] * cell_h + agent_end_x, agent_end_y = agent_offset_x + cell_w, agent_offset_y + cell_h + return image[agent_offset_y:agent_end_y, agent_offset_x:agent_end_x] + +def crop_cell_image_torch(image, grid_dim, cell_pixel_size, orig_cell_pos): + cell_pos = (orig_cell_pos[0], grid_dim[1] - orig_cell_pos[1] - 1) + cell_w, cell_h = cell_pixel_size + agent_offset_x, agent_offset_y = cell_pos[0] * cell_w, cell_pos[1] * cell_h + agent_end_x, agent_end_y = agent_offset_x + cell_w, agent_offset_y + cell_h + return image[:, agent_offset_y:agent_end_y, agent_offset_x:agent_end_x] diff --git a/experiments/pacman_maze/examples/arena_1.scl b/experiments/pacman_maze/examples/arena_1.scl new file mode 100644 index 0000000..cc941bb --- /dev/null +++ b/experiments/pacman_maze/examples/arena_1.scl @@ -0,0 +1,25 @@ +import "../scl/arena.scl" +import "../scl/grid_node.scl" + +// This example replicates the following 4x4 grid +// +// * O * G +// * * X * +// * O O O +// * * * * + +rel grid_size(4, 4) + +rel curr_position = {0.98::(2, 2)} + +rel goal_position = {0.98::(3, 3)} + 
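A minimal usage sketch for the AvoidingArena environment defined in arena.py above, mirroring the loop in run_demo.py later in this patch; the 30-step cap is an illustrative addition:

import random
from arena import AvoidingArena

env = AvoidingArena(grid_dim=(5, 5), cell_size=0.5, dpi=80, num_enemies=5)
env.reset()
reward, done = 0.0, False
for _ in range(30):                    # cap episode length (illustrative)
    action = random.randint(0, 3)      # 0 = up, 1 = right, 2 = down, 3 = left
    _, done, reward, _ = env.step(action)
    frame = env.render()               # numpy (BGR) image of the current arena
    if done:
        break
print("Success!" if done and reward > 0 else "Failed!")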
+rel is_enemy = { + 0.01::(0, 3), 0.99::(1, 3), 0.01::(2, 3), 0.01::(3, 3), + 0.01::(0, 2), 0.01::(1, 2), 0.01::(2, 2), 0.01::(3, 2), + 0.01::(0, 1), 0.99::(1, 1), 0.99::(2, 1), 0.99::(3, 1), + 0.01::(0, 0), 0.01::(1, 0), 0.01::(2, 0), 0.01::(3, 0), +} + +query action_score +query next_position diff --git a/experiments/pacman_maze/examples/arena_2.scl b/experiments/pacman_maze/examples/arena_2.scl new file mode 100644 index 0000000..2f8f32a --- /dev/null +++ b/experiments/pacman_maze/examples/arena_2.scl @@ -0,0 +1,25 @@ +import "../scl/arena.scl" +import "../scl/grid_node.scl" + +// This example replicates the following 4x4 grid +// +// * O * G +// * * * * +// * O O O +// * * X * + +rel grid_size(4, 4) + +rel curr_position = {0.98::(2, 0)} + +rel goal_position = {0.99::(3, 3)} + +rel is_enemy = { + 0.01::(0, 3), 0.99::(1, 3), 0.01::(2, 3), 0.01::(3, 3), + 0.01::(0, 2), 0.01::(1, 2), 0.01::(2, 2), 0.01::(3, 2), + 0.01::(0, 1), 0.99::(1, 1), 0.99::(2, 1), 0.99::(3, 1), + 0.01::(0, 0), 0.01::(1, 0), 0.01::(2, 0), 0.01::(3, 0), +} + +query action_score +query next_position diff --git a/experiments/pacman_maze/examples/arena_3.scl b/experiments/pacman_maze/examples/arena_3.scl new file mode 100644 index 0000000..48627ce --- /dev/null +++ b/experiments/pacman_maze/examples/arena_3.scl @@ -0,0 +1,26 @@ +import "../scl/arena.scl" +import "../scl/grid_node.scl" + +// This example replicates the following 4x4 grid +// +// * O * G +// * * * * +// * O O O +// * * X * + +rel grid_size(4, 4) + +rel curr_position = {0.98::(3, 2)} + +rel goal_position = {0.99::(3, 3)} + +rel is_enemy = { + 0.01::(0, 3), 0.99::(1, 3), 0.01::(2, 3), 0.01::(3, 3), + 0.01::(0, 2), 0.01::(1, 2), 0.01::(2, 2), 0.01::(3, 2), + 0.01::(0, 1), 0.99::(1, 1), 0.99::(2, 1), 0.99::(3, 1), + 0.01::(0, 0), 0.01::(1, 0), 0.01::(2, 0), 0.01::(3, 0), +} + +query action_score +// query next_position +// query node diff --git a/experiments/pacman_maze/examples/arena_4.scl b/experiments/pacman_maze/examples/arena_4.scl new file mode 100644 index 0000000..57ce773 --- /dev/null +++ b/experiments/pacman_maze/examples/arena_4.scl @@ -0,0 +1,31 @@ +import "../scl/arena.scl" +import "../scl/grid_node.scl" + +// This example replicates the following 4x4 grid +// +// * * C * * +// * E * E * +// * * * * * +// G * E * E +// * * * * * + +rel grid_node = { + 0.95::(0, 4), 0.95::(1, 4), 0.95::(2, 4), 0.95::(3, 4), 0.95::(4, 4), + 0.95::(0, 3), 0.95::(1, 3), 0.95::(2, 3), 0.95::(3, 3), 0.95::(4, 3), + 0.95::(0, 2), 0.95::(1, 2), 0.95::(2, 2), 0.95::(3, 2), 0.95::(4, 2), + 0.95::(0, 1), 0.95::(1, 1), 0.95::(2, 1), 0.95::(3, 1), 0.95::(4, 1), + 0.95::(0, 0), 0.95::(1, 0), 0.95::(2, 0), 0.95::(3, 0), 0.95::(4, 0), +} + +rel curr_position = {(2, 4)} + +rel goal_position = {(0, 1)} + +rel is_enemy = {(1, 3), (3, 3), (2, 1), (4, 1)} + +query action_score + +// R: {Pos(2), Pos(3), Pos(7), Pos(11), Pos(12), Pos(15), Pos(16)}, +// {Pos(2), Pos(3), Pos(7), Pos(10), Pos(11), Pos(12), Pos(15)} +// {Pos(0), Pos(1), Pos(2), Pos(3), Pos(5), Pos(10), Pos(15)} +// L: {Pos(0), Pos(1), Pos(2), Pos(5), Pos(10), Pos(15)} diff --git a/experiments/pacman_maze/res/agent.png b/experiments/pacman_maze/res/agent.png new file mode 100644 index 0000000..54a098d Binary files /dev/null and b/experiments/pacman_maze/res/agent.png differ diff --git a/experiments/pacman_maze/res/back.webp b/experiments/pacman_maze/res/back.webp new file mode 100644 index 0000000..b466bc3 Binary files /dev/null and b/experiments/pacman_maze/res/back.webp differ diff --git 
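The arena_*.scl examples above import ../scl/arena.scl and query action_score. A hedged sketch of loading one of them from Python; import_file, run, and relation are assumed ScallopContext methods (they do not appear in this patch), and the relative imports are assumed to resolve against the example file's location:

import scallopy

# Assumption: ScallopContext exposes import_file/run/relation.
# Only the provenance name and the file path below come from this patch.
ctx = scallopy.ScallopContext(provenance="difftopkproofs", k=1)
ctx.import_file("experiments/pacman_maze/examples/arena_1.scl")
ctx.run()
print(list(ctx.relation("action_score")))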
a/experiments/pacman_maze/res/enemy1.webp b/experiments/pacman_maze/res/enemy1.webp new file mode 100644 index 0000000..d76767e Binary files /dev/null and b/experiments/pacman_maze/res/enemy1.webp differ diff --git a/experiments/pacman_maze/res/enemy2.webp b/experiments/pacman_maze/res/enemy2.webp new file mode 100644 index 0000000..95a5174 Binary files /dev/null and b/experiments/pacman_maze/res/enemy2.webp differ diff --git a/experiments/pacman_maze/res/flag.png b/experiments/pacman_maze/res/flag.png new file mode 100644 index 0000000..6616cf6 Binary files /dev/null and b/experiments/pacman_maze/res/flag.png differ diff --git a/experiments/pacman_maze/run.py b/experiments/pacman_maze/run.py new file mode 100644 index 0000000..e2f0e83 --- /dev/null +++ b/experiments/pacman_maze/run.py @@ -0,0 +1,366 @@ +from argparse import ArgumentParser +import random +import cv2 +import torch +from torch import nn +from torch import optim +from tqdm import tqdm +import scallopy +from collections import namedtuple, deque +import os + +from arena import AvoidingArena, crop_cell_image_torch + +FILE_DIR = os.path.abspath(os.path.join(os.path.abspath(__file__), "../")) + +Transition = namedtuple("Transition", ("state", "action", "next_state", "reward")) + +class CellClassifier(nn.Module): + """ + Classifies each cell (in image format) into one of 4 classes: agent, goal, enemy, [empty] + """ + def __init__(self): + super(CellClassifier, self).__init__() + self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=8, stride=4, padding=2) + self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=4, stride=4, padding=1) + self.fc1 = nn.Linear(in_features=288, out_features=256) + self.fc2 = nn.Linear(in_features=256, out_features=4) + self.relu = nn.ReLU() + + def forward(self, x): + batch_size, _, _, _ = x.shape + x = self.relu(self.conv1(x)) # In: (80, 80, 4) Out: (20, 20, 16) + x = self.relu(self.conv2(x)) # In: (20, 20, 16) Out: (10, 10, 32) + x = x.view(batch_size, -1) + x = self.relu(self.fc1(x)) # In: (3200,) Out: (256,) + x = self.fc2(x) # In: (256,) Out: (4,) + x = torch.softmax(x, dim=1) + return x + + +class EntityExtractor(nn.Module): + """ + Divide the whole image into grid cells, and pass the grid cells into the CellFeatureNet. 
+ The output of this network is 3 separate vectors: is_actor, is_goal, is_enemy, + Each vector is of length #cells, mapping each cell to respective property (actor, goal, enemy) + """ + def __init__(self, grid_x, grid_y, cell_pixel_size): + super(EntityExtractor, self).__init__() + self.grid_x = grid_x + self.grid_y = grid_y + self.cell_pixel_size = cell_pixel_size + self.cell_dim = (self.grid_x, self.grid_y) + self.cells = [(i, j) for i in range(grid_x) for j in range(grid_y)] + self.cell_feature_net = CellClassifier() + + def forward(self, x): + batch_size, _, _, _ = x.shape + num_cells = len(self.cells) + cells = torch.stack([torch.stack([crop_cell_image_torch(x[i], self.cell_dim, self.cell_pixel_size, c) for c in self.cells]) for i in range(batch_size)]) + cells = cells.reshape(batch_size * num_cells, 3, self.cell_pixel_size[0], self.cell_pixel_size[1]) + features = self.cell_feature_net(cells) + batched_features = features.reshape(batch_size, num_cells, 4) + is_actor = batched_features[:, :, 0] + is_goal = batched_features[:, :, 1] + is_enemy = batched_features[:, :, 2] + return (is_actor, is_goal, is_enemy) + + +class PolicyNet(nn.Module): + """ + A policy net that takes in an image and return the action scores as [UP, RIGHT, BOTTOM, LEFT] + """ + def __init__(self, grid_x, grid_y, cell_pixel_size): + super(PolicyNet, self).__init__() + self.grid_x, self.grid_y = grid_x, grid_y + self.cells = [(x, y) for x in range(grid_x) for y in range(grid_y)] + + # Setup CNNs that process the image and extract features + self.extract_entity = EntityExtractor(grid_x, grid_y, cell_pixel_size) + + # Setup scallop context and scallop forward functions + self.path_planner = scallopy.Module( + program=""" + // Static input facts + type grid_node(x: usize, y: usize) + + // Input from neural networks + type actor(x: usize, y: usize) + type goal(x: usize, y: usize) + type enemy(x: usize, y: usize) + + // Possible actions to take + type Action = Up | Right | Down | Left + + // Output relation + type action_to_take(Action) + + // Basic connectivity + rel node(x, y) = grid_node(x, y), not enemy(x, y) + rel edge(x, y, x, yp, Up) = node(x, y), node(x, yp), yp == y + 1 // Up + rel edge(x, y, xp, y, Right) = node(x, y), node(xp, y), xp == x + 1 // Right + rel edge(x, y, x, yp, Down) = node(x, y), node(x, yp), yp == y - 1 // Down + rel edge(x, y, xp, y, Left) = node(x, y), node(xp, y), xp == x - 1 // Left + + // Path for connectivity; will condition on no enemy on the path + rel path(x, y, x, y) = node(x, y) + rel path(x, y, xp, yp) = edge(x, y, xp, yp, _) + rel path(x, y, xpp, ypp) = path(x, y, xp, yp), edge(xp, yp, xpp, ypp, _) + + // Get the next position + rel next_position(a, xp, yp) = actor(x, y), edge(x, y, xp, yp, a) + rel next_action(a) = next_position(a, x, y), goal(gx, gy), path(x, y, gx, gy) + + // Constraint violation + rel too_many_goal() = n := count(x, y: goal(x, y)), n > 1 + rel too_many_enemy() = n := count(x, y: enemy(x, y)), n > 5 + rel violation() = too_many_goal() or too_many_enemy() + """, + provenance="difftopkproofs", + k=1, + facts={"grid_node": [(torch.tensor(args.attenuation, requires_grad=False), c) for c in self.cells]}, + input_mappings={"actor": self.cells, "goal": self.cells, "enemy": self.cells}, + retain_topk={"actor": 3, "goal": 3, "enemy": 7}, + output_mappings={"next_action": list(range(4)), "violation": ()}, + retain_graph=True) + + def forward(self, x): + actor, goal, enemy = self.extract_entity(x) + result = self.path_planner(actor=actor, goal=goal, enemy=enemy) + 
next_action = torch.softmax(result["next_action"], dim=1) + violation = result["violation"] * 0.2 + return next_action, violation + + def visualize(self, arena, raw_image, torch_image): + curr_position, goal_position, is_enemy = self.extract_entity(torch_image) + for (i, c) in enumerate(self.cells): + blue, green, red = curr_position[0, i], goal_position[0, i], is_enemy[0, i] + arena.paint_color(raw_image, (blue, green, red), c) + return + + +class ReplayMemory: + def __init__(self, capacity): + self.memory = deque([], maxlen=capacity) + + def push(self, transition): + self.memory.append(transition) + + def sample(self, batch_size): + return random.sample(self.memory, batch_size) + + def __len__(self): + return len(self.memory) + + +class DQN: + def __init__(self, grid_x, grid_y, cell_pixel_size): + self.policy_net = PolicyNet(grid_x, grid_y, cell_pixel_size) + + # Create another network + self.target_net = PolicyNet(grid_x, grid_y, cell_pixel_size) + self.target_net.load_state_dict(self.policy_net.state_dict()) + self.target_net.eval() + + # Store replay memory + self.memory = ReplayMemory(args.replay_memory_capacity) + + # Loss function and optimizer + self.criterion = nn.HuberLoss() + self.violation_loss = nn.SmoothL1Loss() + self.optimizer = optim.RMSprop(self.policy_net.parameters(), args.learning_rate) + + def predict_action(self, state_image): + action_scores, _ = self.policy_net(state_image) # [0.25, 0.24, 0.26, 0.25] + action = torch.argmax(action_scores, dim=1) # 2 + return action + + def observe_transition(self, transition): + self.memory.push(transition) + + def optimize_model(self): + if len(self.memory) < args.batch_size: return 0.0 + + # Pull out a batch and its relevant features + batch = self.memory.sample(args.batch_size) + non_final_mask = torch.tensor([transition.next_state != None for transition in batch], dtype=torch.bool) + non_final_next_states = torch.stack([transition.next_state[0] for transition in batch if transition.next_state is not None]) + action_batch = torch.stack([transition.action for transition in batch]) + state_batch = torch.stack([transition.state[0] for transition in batch]) + reward_batch = torch.stack([torch.tensor(transition.reward) for transition in batch]) + + # Prepare the loss function + state_action_values_raw, violations = self.policy_net(state_batch) + state_action_values = state_action_values_raw.gather(1, action_batch)[:, 0] + next_state_values = torch.zeros(args.batch_size) + next_state_values[non_final_mask] = self.target_net(non_final_next_states)[0].max(1)[0].detach() + expected_state_action_values = (next_state_values * args.gamma) + reward_batch + + # Compute the loss: + loss1 = self.criterion(state_action_values, expected_state_action_values) # loss1 is the criterion + loss2 = self.violation_loss(violations, torch.zeros(args.batch_size)) # loss2 is the violation + loss = loss1 + loss2 + + self.optimizer.zero_grad() + loss.backward(retain_graph=True) + for param in self.policy_net.parameters(): + param.grad.data.clamp_(-1, 1) + self.optimizer.step() + + # Return loss + return loss.detach() + + def update_target(self): + self.target_net.load_state_dict(self.policy_net.state_dict()) + + +class Trainer: + def __init__(self, grid_x, grid_y, cell_size, dpi, num_enemies, epsilon): + self.arena = AvoidingArena((grid_x, grid_y), cell_size, dpi, num_enemies, easy=args.easy) + self.dqn = DQN(grid_x, grid_y, self.arena.cell_pixel_size()) + self.epsilon = epsilon + + def show_image(self, raw_image, torch_image): + if args.overlay_prediction: + 
self.dqn.policy_net.visualize(self.arena, raw_image, torch_image) + cv2.namedWindow("Current Arena", cv2.WINDOW_NORMAL) + cv2.resizeWindow("Current Arena", args.window_size, args.window_size) + cv2.imshow("Current Arena", raw_image) + cv2.waitKey(int(args.show_run_interval * 1000)) + + def train_epoch(self, epoch): + self.epsilon = self.epsilon * args.epsilon_falloff + success, failure, optimize_count, sum_loss = 0, 0, 0, 0.0 + iterator = tqdm(range(args.num_train_episodes)) + for episode_i in iterator: + _ = self.arena.reset() + curr_raw_image = self.arena.render() + curr_state_image = self.arena.render_torch_tensor(image=curr_raw_image) + for _ in range(args.num_steps): + # Render + if args.show_train_run: + self.show_image(curr_raw_image, curr_state_image) + + # Pick an action + if random.random() < self.epsilon: action = torch.tensor([random.randint(0, 3)]) + else: action = self.dqn.predict_action(curr_state_image) + + # Step the environment + _, done, reward, _ = self.arena.step(action[0]) + + # Get the next state + if done: + next_raw_image = None + next_state_image = None + else: + next_raw_image = self.arena.render() + next_state_image = self.arena.render_torch_tensor(image=next_raw_image) + + # Record the transition in memory buffer + transition = Transition(curr_state_image, action, next_state_image, reward) + self.dqn.observe_transition(transition) + + # Update the model + loss = self.dqn.optimize_model() + sum_loss += loss + optimize_count += 1 + + # Update the next state + if done: + if reward > 0: success += 1 + else: failure += 1 + break + else: + curr_raw_image = next_raw_image + curr_state_image = next_state_image + + # Update the target net + if episode_i % args.target_update == 0: + self.dqn.update_target() + + # Print information + success_rate = (success / (episode_i + 1)) * 100.0 + avg_loss = sum_loss / optimize_count + iterator.set_description(f"[Train Epoch {epoch}] Avg Loss: {avg_loss}, Success: {success}/{episode_i + 1} ({success_rate:.2f}%)") + + def test_epoch(self, epoch): + success, failure = 0, 0 + iterator = tqdm(range(args.num_test_episodes)) + for episode_i in iterator: + _ = self.arena.reset() + raw_image = self.arena.render() + state_image = self.arena.render_torch_tensor(image=raw_image) + for _ in range(args.num_steps): + # Show image + if args.show_test_run: + self.show_image(raw_image, state_image) + + # Pick an action + action = self.dqn.predict_action(state_image) + _, done, reward, _ = self.arena.step(action[0]) + raw_image = self.arena.render() + state_image = self.arena.render_torch_tensor(image=raw_image) + + # Update the next state + if done: + if reward > 0: success += 1 + else: failure += 1 + break + + # Print information + success_rate = (success / (episode_i + 1)) * 100.0 + iterator.set_description(f"[Test Epoch {epoch}] Success {success}/{episode_i + 1} ({success_rate:.2f}%)") + + def run(self): + # self.test_epoch(0) + for i in range(1, args.num_epochs + 1): + self.train_epoch(i) + self.test_epoch(i) + + +if __name__ == "__main__": + parser = ArgumentParser() + parser.add_argument("--grid-x", type=int, default=5) + parser.add_argument("--grid-y", type=int, default=5) + parser.add_argument("--cell-size", type=float, default=0.5) + parser.add_argument("--dpi", type=int, default=80) + parser.add_argument("--batch-size", type=int, default=24) + parser.add_argument("--num-enemies", type=int, default=5) + parser.add_argument("--num-epochs", type=int, default=100) + parser.add_argument("--num-train-episodes", type=int, default=100) + 
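The optimize_model step above is a standard DQN update: Q-values of the taken actions are gathered from the policy net and regressed toward reward + gamma * max_a Q_target(s', a), with terminal next states contributing zero. A self-contained illustration of that target computation, with random tensors standing in for the real batch:

import torch

batch_size, num_actions, gamma = 4, 4, 0.999
q_policy = torch.rand(batch_size, num_actions, requires_grad=True)  # Q(s, a) from the policy net
q_target_next = torch.rand(batch_size, num_actions)                 # Q(s', a) from the target net
actions = torch.randint(0, num_actions, (batch_size, 1))
rewards = torch.rand(batch_size)
non_final = torch.tensor([True, True, False, True])                 # False where the episode ended

state_action_values = q_policy.gather(1, actions)[:, 0]             # Q(s, a_taken)
next_state_values = torch.zeros(batch_size)
next_state_values[non_final] = q_target_next[non_final].max(1)[0]   # max_a Q_target(s', a); 0 for terminal s'
expected = rewards + gamma * next_state_values                      # Bellman target
loss = torch.nn.HuberLoss()(state_action_values, expected)
loss.backward()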
parser.add_argument("--num-test-episodes", type=int, default=100) + parser.add_argument("--num-steps", type=int, default=30) + parser.add_argument("--target-update", type=int, default=10) + parser.add_argument("--epsilon", type=float, default=0.9) + parser.add_argument("--epsilon-falloff", type=float, default=0.98) + parser.add_argument("--gamma", type=float, default=0.999) + parser.add_argument("--learning-rate", type=float, default=0.0001) + parser.add_argument("--replay-memory-capacity", type=int, default=3000) + parser.add_argument("--seed", type=int, default=1357) + parser.add_argument("--cuda", action="store_true") + parser.add_argument("--gpu", type=int, default=0) + parser.add_argument("--attenuation", type=float, default=0.95) + parser.add_argument("--show-run", action="store_true") + parser.add_argument("--show-train-run", action="store_true") + parser.add_argument("--show-test-run", action="store_true") + parser.add_argument("--show-run-interval", type=int, default=0.001) + parser.add_argument("--window-size", type=int, default=200) + parser.add_argument("--overlay-prediction", action="store_true") + parser.add_argument("--easy", action="store_true") + args = parser.parse_args() + + # Set parameters + args.show_run_interval = max(0.001, args.show_run_interval) # Minimum 1ms + if args.show_run: + args.show_train_run = True + args.show_test_run = True + torch.manual_seed(args.seed) + random.seed(args.seed) + if args.cuda: + if torch.cuda.is_available(): device = torch.device(f"cuda:{args.gpu}") + else: raise Exception("No cuda available") + else: device = torch.device("cpu") + + # Train + trainer = Trainer(args.grid_x, args.grid_y, args.cell_size, args.dpi, args.num_enemies, args.epsilon) + trainer.run() diff --git a/experiments/pacman_maze/run_demo.py b/experiments/pacman_maze/run_demo.py new file mode 100644 index 0000000..5937ad9 --- /dev/null +++ b/experiments/pacman_maze/run_demo.py @@ -0,0 +1,93 @@ +import random +import argparse +import cv2 + +from arena import AvoidingArena + +ESC_KEY = 27 +LEFT_KEY = 63234 +UP_KEY = 63232 +DOWN_KEY = 63233 +RIGHT_KEY = 63235 + +def show_image(env: AvoidingArena): + # Render an image + image = env.render() + + # Show the image + cv2.imshow("Avoid Enemy", image) + + # Show the cells + if args.show_cells: + cells = [(i, j) for i in range(args.grid_x) for j in range(args.grid_y)] + for cell_pos in cells: + cell_image = env.get_cell_image(image, cell_pos) + cv2.imshow(f"Cell {cell_pos}", cell_image) + + # Wait for next frame + if args.auto: + cv2.waitKey(int(args.interval * 1000)) + elif args.manual: + key = cv2.waitKeyEx() + if key == ESC_KEY: + cv2.destroyAllWindows() + exit() + elif key == UP_KEY: + return env.step(0) + elif key == RIGHT_KEY: + return env.step(1) + elif key == DOWN_KEY: + return env.step(2) + elif key == LEFT_KEY: + return env.step(3) + else: + key = cv2.waitKey() + if key == ESC_KEY: + cv2.destroyAllWindows() + exit() + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--show-image", action="store_true") + parser.add_argument("--show-cells", action="store_true") + parser.add_argument("--print-arena", action="store_true") + parser.add_argument("--auto", action="store_true") + parser.add_argument("--manual", action="store_true") + parser.add_argument("--interval", type=float, default=0.3) + parser.add_argument("--grid-x", type=int, default=5) + parser.add_argument("--grid-y", type=int, default=5) + parser.add_argument("--cell-size", type=float, default=0.5) + 
parser.add_argument("--num-enemies", type=int, default=5) + parser.add_argument("--dpi", type=int, default=80) + parser.add_argument("--seed", type=int, default=1358) + args = parser.parse_args() + + # Manual + if args.manual: + args.show_image = True + random.seed(args.seed) + + # Start environment + env = AvoidingArena((args.grid_x, args.grid_y), args.cell_size, args.dpi, args.num_enemies) + done, reward = False, 0 + env.reset() + + # Print or show image + if args.print_arena: env.print_state() + if args.show_image: result = show_image(env) + + # Enter loop + while not done: + if not args.manual: + action = random.randint(0, 3) + print(f"Taking action: {env.string_of_action(action)}") + _, done, reward, _ = env.step(action) + else: + _, done, reward, _ = result + + # If finished + if done: print("Success!" if reward > 0 else "Failed!") + + # Print or show image + if args.print_arena: env.print_state() + if args.show_image: result = show_image(env) diff --git a/experiments/pacman_maze/run_dqn.py b/experiments/pacman_maze/run_dqn.py new file mode 100644 index 0000000..584806e --- /dev/null +++ b/experiments/pacman_maze/run_dqn.py @@ -0,0 +1,255 @@ +from argparse import ArgumentParser + +import random +import torch +from torch import nn +from torch import optim +import cv2 +from tqdm import tqdm +from collections import namedtuple, deque + +from arena import AvoidingArena + +Transition = namedtuple("Transition", ("state", "action", "next_state", "reward")) + +class PolicyNet(nn.Module): + def __init__(self): + super(PolicyNet, self).__init__() + self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=8, stride=4, padding=2) + self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=4, stride=4, padding=1) + self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=4, padding=1) + self.fc1 = nn.Linear(in_features=576, out_features=256) + self.fc2 = nn.Linear(in_features=256, out_features=4) + self.relu = nn.ReLU() + + def forward(self, x): + batch_size, _, _, _ = x.shape + x = self.relu(self.conv1(x)) # In: (80, 80, 4) Out: (20, 20, 16) + x = self.relu(self.conv2(x)) # In: (20, 20, 16) Out: (10, 10, 32) + x = self.relu(self.conv3(x)) # In: (20, 20, 32) Out: (10, 10, 64) + x = x.view(batch_size, -1) # In: (10, 10, 32) Out: (3200,) + x = self.relu(self.fc1(x)) # In: (3200,) Out: (256,) + x = self.fc2(x) # In: (256,) Out: (4,) + return x + + +class ReplayMemory: + def __init__(self, capacity): + self.memory = deque([], maxlen=capacity) + + def push(self, transition): + self.memory.append(transition) + + def sample(self, batch_size): + return random.sample(self.memory, batch_size) + + def __len__(self): + return len(self.memory) + + +class DQN: + def __init__(self): + self.policy_net = PolicyNet() + + # Create another network + self.target_net = PolicyNet() + self.target_net.load_state_dict(self.policy_net.state_dict()) + self.target_net.eval() + + # Store replay memory + self.memory = ReplayMemory(args.replay_memory_capacity) + + # Loss function and optimizer + self.criterion = nn.SmoothL1Loss() + self.optimizer = optim.RMSprop(self.policy_net.parameters(), lr=args.learning_rate) + + def predict_action(self, state_image): + action_scores = self.policy_net(state_image) + action = torch.argmax(action_scores, dim=1) + return action + + def observe_transition(self, transition): + self.memory.push(transition) + + def optimize_model(self): + if len(self.memory) < args.batch_size: return 0.0 + + # Pull out a batch and its relevant features + batch = 
self.memory.sample(args.batch_size) + non_final_mask = torch.tensor([transition.next_state != None for transition in batch], dtype=torch.bool) + non_final_next_states = torch.stack([transition.next_state[0] for transition in batch if transition.next_state is not None]) + action_batch = torch.stack([transition.action for transition in batch]) + state_batch = torch.stack([transition.state[0] for transition in batch]) + reward_batch = torch.stack([torch.tensor(transition.reward) for transition in batch]) + + # Prepare the loss function + state_action_values = self.policy_net(state_batch).gather(1, action_batch)[:, 0] + next_state_values = torch.zeros(args.batch_size) + next_state_values[non_final_mask] = self.target_net(non_final_next_states).max(1)[0].detach() + expected_state_action_values = (next_state_values * args.gamma) + reward_batch + + # Compute the loss + loss = self.criterion(state_action_values, expected_state_action_values) + self.optimizer.zero_grad() + loss.backward() + for param in self.policy_net.parameters(): + param.grad.data.clamp_(-1, 1) + self.optimizer.step() + + # Return loss + return loss.detach() + + def update_target(self): + self.target_net.load_state_dict(self.policy_net.state_dict()) + + +class Trainer: + def __init__(self, grid_x, grid_y, cell_size, dpi, num_enemies, epsilon): + self.dqn = DQN() + self.epsilon = epsilon + self.arena = AvoidingArena( + (grid_x, grid_y), + cell_size, + dpi, + num_enemies, + default_reward=0.0, + on_success_reward=1.0, + on_failure_reward=0.0, + remain_unchanged_reward=0.0, + ) + + def show_image(self, image): + cv2.imshow("Current Arena", image) + cv2.waitKey(int(args.show_run_interval * 1000)) + + def train_epoch(self, epoch): + self.epsilon = self.epsilon * args.epsilon_falloff + success, failure, optimize_count, sum_loss = 0, 0, 0, 0.0 + iterator = tqdm(range(args.num_train_episodes)) + for episode_i in iterator: + _ = self.arena.reset() + curr_raw_image = self.arena.render() + curr_state_image = self.arena.render_torch_tensor(image=curr_raw_image) + sum_loss = 0.0 + for _ in range(args.num_steps): + # Render + if args.show_train_run: + self.show_image(curr_raw_image) + + # Pick an action + if random.random() < self.epsilon: action = torch.tensor([random.randint(0, 3)]) + else: action = self.dqn.predict_action(curr_state_image) + + # Step the environment + _, done, reward, _ = self.arena.step(action[0]) + + # Record the transition in memory buffer + if done: + next_raw_image = None + next_state_image = None + else: + next_raw_image = self.arena.render() + next_state_image = self.arena.render_torch_tensor(image=next_raw_image) + transition = Transition(curr_state_image, action, next_state_image, reward) + self.dqn.observe_transition(transition) + + # Update the model + loss = self.dqn.optimize_model() + sum_loss += loss + optimize_count += 1 + + # Update the next state + if done: + if reward > 0: success += 1 + else: failure += 1 + break + else: + curr_raw_image = next_raw_image + curr_state_image = next_state_image + + # Update the target net + if episode_i % args.target_update == 0: + self.dqn.update_target() + + # Print information + success_rate = (success / (episode_i + 1)) * 100.0 + avg_loss = sum_loss / optimize_count + iterator.set_description(f"[Train Epoch {epoch}] Avg Loss: {avg_loss}, Success: {success}/{episode_i + 1} ({success_rate:.2f}%)") + + def test_epoch(self, epoch): + success, failure = 0, 0 + iterator = tqdm(range(args.num_test_episodes)) + for episode_i in iterator: + _ = self.arena.reset() + raw_image = 
self.arena.render() + state_image = self.arena.render_torch_tensor(image=raw_image) + for _ in range(args.num_steps): + # Show image + if args.show_test_run: + self.show_image(raw_image) + + # Pick an action + action = self.dqn.predict_action(state_image) + _, done, reward, _ = self.arena.step(action[0]) + raw_image = self.arena.render() + state_image = self.arena.render_torch_tensor(image=raw_image) + + # Update the next state + if done: + if reward > 0: success += 1 + else: failure += 1 + break + + # Print information + success_rate = (success / (episode_i + 1)) * 100.0 + iterator.set_description(f"[Test Epoch {epoch}] Success {success}/{episode_i + 1} ({success_rate:.2f}%)") + + def run(self): + # self.test_epoch(0) + for i in range(1, args.num_epochs + 1): + self.train_epoch(i) + self.test_epoch(i) + + +if __name__ == "__main__": + parser = ArgumentParser() + parser.add_argument("--grid-x", type=int, default=5) + parser.add_argument("--grid-y", type=int, default=5) + parser.add_argument("--cell-size", type=float, default=0.5) + parser.add_argument("--dpi", type=int, default=80) + parser.add_argument("--batch-size", type=int, default=16) + parser.add_argument("--panelize-staying", action="store_true") + parser.add_argument("--num-enemies", type=int, default=5) + parser.add_argument("--num-epochs", type=int, default=100) + parser.add_argument("--num-train-episodes", type=int, default=500) + parser.add_argument("--num-test-episodes", type=int, default=50) + parser.add_argument("--num-steps", type=int, default=30) + parser.add_argument("--learning-rate", type=float, default=0.0001) + parser.add_argument("--target-update", type=int, default=10) + parser.add_argument("--epsilon", type=float, default=0.9) + parser.add_argument("--epsilon-falloff", type=float, default=0.98) + parser.add_argument("--gamma", type=float, default=0.999) + parser.add_argument("--replay-memory-capacity", type=int, default=1000) + parser.add_argument("--seed", type=int, default=1357) + parser.add_argument("--cuda", action="store_true") + parser.add_argument("--gpu", type=int, default=0) + parser.add_argument("--show-run", action="store_true") + parser.add_argument("--show-train-run", action="store_true") + parser.add_argument("--show-test-run", action="store_true") + parser.add_argument("--show-run-interval", type=int, default=0.001) + args = parser.parse_args() + + # Set parameters + torch.manual_seed(args.seed) + random.seed(args.seed) + if args.show_run: + args.show_train_run = True + args.show_test_run = True + if args.cuda: + if torch.cuda.is_available(): device = torch.device(f"cuda:{args.gpu}") + else: raise Exception("No cuda available") + else: device = torch.device("cpu") + + # Train + trainer = Trainer(args.grid_x, args.grid_y, args.cell_size, args.dpi, args.num_enemies, args.epsilon) + trainer.run() diff --git a/experiments/pacman_maze/run_old.py b/experiments/pacman_maze/run_old.py new file mode 100644 index 0000000..b51ca58 --- /dev/null +++ b/experiments/pacman_maze/run_old.py @@ -0,0 +1,322 @@ +from argparse import ArgumentParser +import random +import cv2 +import torch +from torch import nn +from torch import optim +from tqdm import tqdm +import scallopy +from collections import namedtuple, deque +import os + +from arena import AvoidingArena, crop_cell_image_torch + +FILE_DIR = os.path.abspath(os.path.join(os.path.abspath(__file__), "../")) + +Transition = namedtuple("Transition", ("state", "action", "next_state", "reward")) + +class CellClassifier(nn.Module): + """ + Classifies each cell (in image 
format) into one of 4 classes: agent, goal, enemy, [empty] + """ + def __init__(self): + super(CellClassifier, self).__init__() + self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=8, stride=4, padding=2) + self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=4, stride=4, padding=1) + self.fc1 = nn.Linear(in_features=288, out_features=256) + self.fc2 = nn.Linear(in_features=256, out_features=4) + self.relu = nn.ReLU() + + def forward(self, x): + batch_size, _, _, _ = x.shape + x = self.relu(self.conv1(x)) # In: (80, 80, 4) Out: (20, 20, 16) + x = self.relu(self.conv2(x)) # In: (20, 20, 16) Out: (10, 10, 32) + x = x.view(batch_size, -1) + x = self.relu(self.fc1(x)) # In: (3200,) Out: (256,) + x = self.fc2(x) # In: (256,) Out: (4,) + x = torch.softmax(x, dim=1) + return x + + +class EntityExtractor(nn.Module): + """ + Divide the whole image into grid cells, and pass the grid cells into the CellFeatureNet. + The output of this network is 3 separate vectors: is_agent, is_goal, is_enemy, + Each vector is of length #cells, mapping each cell to respective property (agent, goal, enemy) + """ + def __init__(self, grid_x, grid_y, cell_pixel_size): + super(EntityExtractor, self).__init__() + self.grid_x = grid_x + self.grid_y = grid_y + self.cell_pixel_size = cell_pixel_size + self.cell_dim = (self.grid_x, self.grid_y) + self.cells = [(i, j) for i in range(grid_x) for j in range(grid_y)] + self.cell_classifier = CellClassifier() + + def forward(self, x): + batch_size, _, _, _ = x.shape + num_cells = len(self.cells) + cells = torch.stack([torch.stack([crop_cell_image_torch(x[i], self.cell_dim, self.cell_pixel_size, c) for c in self.cells]) for i in range(batch_size)]) + cells = cells.reshape(batch_size * num_cells, 3, self.cell_pixel_size[0], self.cell_pixel_size[1]) + features = self.cell_classifier(cells) + batched_features = features.reshape(batch_size, num_cells, 4) + is_agent = batched_features[:, :, 0] + is_goal = batched_features[:, :, 1] + is_enemy = batched_features[:, :, 2] + return (is_agent, is_goal, is_enemy) + + +class PolicyNet(nn.Module): + """ + A policy net that takes in an image and return the action scores as [UP, RIGHT, BOTTOM, LEFT] + """ + def __init__(self, grid_x, grid_y, cell_pixel_size): + super(PolicyNet, self).__init__() + self.cells = [(x, y) for x in range(grid_x) for y in range(grid_y)] + self.grid_x = grid_x + self.grid_y = grid_y + + # Setup CNNs that process the image and extract features + self.extract_entity = EntityExtractor(grid_x, grid_y, cell_pixel_size) + + # Setup scallop context and scallop forward functions + self.ctx = scallopy.ScallopContext(provenance="difftopbottomkclauses", k=1) + self.ctx.import_file(os.path.join(FILE_DIR, "scl", "arena.scl")) + self.ctx.add_facts("grid_node", [(torch.tensor(args.attenuation, requires_grad=False), c) for c in self.cells]) + self.ctx.set_input_mapping("curr_position", self.cells, retain_k=3) + self.ctx.set_input_mapping("goal_position", self.cells, retain_k=3) + self.ctx.set_input_mapping("is_enemy", self.cells, retain_k=7) + self.predict_action = self.ctx.forward_function("action_score", list(range(4)), jit=args.jit, recompile=args.recompile) + + def forward(self, x): + curr_position, goal_position, is_enemy = self.extract_entity(x) + exp_reward = self.predict_action(curr_position=curr_position, goal_position=goal_position, is_enemy=is_enemy) + return exp_reward + + def visualize(self, arena, raw_image, torch_image): + curr_position, goal_position, is_enemy = 
self.extract_entity(torch_image) + for (i, c) in enumerate(self.cells): + blue, green, red = curr_position[0, i], goal_position[0, i], is_enemy[0, i] + arena.paint_color(raw_image, (blue, green, red), c) + return + + +class ReplayMemory: + def __init__(self, capacity): + self.memory = deque([], maxlen=capacity) + + def push(self, transition): + self.memory.append(transition) + + def sample(self, batch_size): + return random.sample(self.memory, batch_size) + + def __len__(self): + return len(self.memory) + + +class DQN: + def __init__(self, grid_x, grid_y, cell_pixel_size): + self.policy_net = PolicyNet(grid_x, grid_y, cell_pixel_size) + + # Create another network + self.target_net = PolicyNet(grid_x, grid_y, cell_pixel_size) + self.target_net.load_state_dict(self.policy_net.state_dict()) + self.target_net.eval() + + # Store replay memory + self.memory = ReplayMemory(args.replay_memory_capacity) + + # Loss function and optimizer + self.criterion = nn.HuberLoss() + self.optimizer = optim.RMSprop(self.policy_net.parameters(), args.learning_rate) + + def predict_action(self, state_image): + action_scores = self.policy_net(state_image) # [0.25, 0.24, 0.26, 0.25] + action = torch.argmax(action_scores, dim=1) # 2 + return action + + def observe_transition(self, transition): + self.memory.push(transition) + + def optimize_model(self): + if len(self.memory) < args.batch_size: return 0.0 + + # Pull out a batch and its relevant features + batch = self.memory.sample(args.batch_size) + non_final_mask = torch.tensor([transition.next_state != None for transition in batch], dtype=torch.bool) + non_final_next_states = torch.stack([transition.next_state[0] for transition in batch if transition.next_state is not None]) + action_batch = torch.stack([transition.action for transition in batch]) + state_batch = torch.stack([transition.state[0] for transition in batch]) + reward_batch = torch.stack([torch.tensor(transition.reward) for transition in batch]) + + # Prepare the loss function + state_action_values = self.policy_net(state_batch).gather(1, action_batch)[:, 0] + next_state_values = torch.zeros(args.batch_size) + next_state_values[non_final_mask] = self.target_net(non_final_next_states).max(1)[0].detach() + expected_state_action_values = (next_state_values * args.gamma) + reward_batch + + # Compute the loss + loss = self.criterion(state_action_values, expected_state_action_values) + self.optimizer.zero_grad() + loss.backward() + for param in self.policy_net.parameters(): + param.grad.data.clamp_(-1, 1) + self.optimizer.step() + + # Return loss + return loss.detach() + + def update_target(self): + self.target_net.load_state_dict(self.policy_net.state_dict()) + + +class Trainer: + def __init__(self, grid_x, grid_y, cell_size, dpi, num_enemies, epsilon): + self.arena = AvoidingArena((grid_x, grid_y), cell_size, dpi, num_enemies, easy=args.easy) + self.dqn = DQN(grid_x, grid_y, self.arena.cell_pixel_size()) + self.epsilon = epsilon + + def show_image(self, raw_image, torch_image): + if args.overlay_prediction: + self.dqn.policy_net.visualize(self.arena, raw_image, torch_image) + cv2.imshow("Current Arena", raw_image) + cv2.waitKey(int(args.show_run_interval * 1000)) + + def train_epoch(self, epoch): + self.epsilon = self.epsilon * args.epsilon_falloff + success, failure, optimize_count, sum_loss = 0, 0, 0, 0.0 + iterator = tqdm(range(args.num_train_episodes)) + for episode_i in iterator: + _ = self.arena.reset() + curr_raw_image = self.arena.render() + curr_state_image = 
self.arena.render_torch_tensor(image=curr_raw_image) + for _ in range(args.num_steps): + # Render + if args.show_train_run: + self.show_image(curr_raw_image, curr_state_image) + + # Pick an action + if random.random() < self.epsilon: action = torch.tensor([random.randint(0, 3)]) + else: action = self.dqn.predict_action(curr_state_image) + + # Step the environment + _, done, reward, _ = self.arena.step(action[0]) + + # Get the next state + if done: + next_raw_image = None + next_state_image = None + else: + next_raw_image = self.arena.render() + next_state_image = self.arena.render_torch_tensor(image=next_raw_image) + + # Record the transition in memory buffer + transition = Transition(curr_state_image, action, next_state_image, reward) + self.dqn.observe_transition(transition) + + # Update the model + loss = self.dqn.optimize_model() + sum_loss += loss + optimize_count += 1 + + # Update the next state + if done: + if reward > 0: success += 1 + else: failure += 1 + break + else: + curr_raw_image = next_raw_image + curr_state_image = next_state_image + + # Update the target net + if episode_i % args.target_update == 0: + self.dqn.update_target() + + # Print information + success_rate = (success / (episode_i + 1)) * 100.0 + avg_loss = sum_loss / optimize_count + iterator.set_description(f"[Train Epoch {epoch}] Avg Loss: {avg_loss}, Success: {success}/{episode_i + 1} ({success_rate:.2f}%)") + + def test_epoch(self, epoch): + success, failure = 0, 0 + iterator = tqdm(range(args.num_test_episodes)) + for episode_i in iterator: + _ = self.arena.reset() + raw_image = self.arena.render() + state_image = self.arena.render_torch_tensor(image=raw_image) + for _ in range(args.num_steps): + # Show image + if args.show_test_run: + self.show_image(raw_image, state_image) + + # Pick an action + action = self.dqn.predict_action(state_image) + _, done, reward, _ = self.arena.step(action[0]) + raw_image = self.arena.render() + state_image = self.arena.render_torch_tensor(image=raw_image) + + # Update the next state + if done: + if reward > 0: success += 1 + else: failure += 1 + break + + # Print information + success_rate = (success / (episode_i + 1)) * 100.0 + iterator.set_description(f"[Test Epoch {epoch}] Success {success}/{episode_i + 1} ({success_rate:.2f}%)") + + def run(self): + # self.test_epoch(0) + for i in range(1, args.num_epochs + 1): + self.train_epoch(i) + self.test_epoch(i) + + +if __name__ == "__main__": + parser = ArgumentParser() + parser.add_argument("--grid-x", type=int, default=5) + parser.add_argument("--grid-y", type=int, default=5) + parser.add_argument("--cell-size", type=float, default=0.5) + parser.add_argument("--dpi", type=int, default=80) + parser.add_argument("--batch-size", type=int, default=24) + parser.add_argument("--num-enemies", type=int, default=5) + parser.add_argument("--num-epochs", type=int, default=100) + parser.add_argument("--num-train-episodes", type=int, default=100) + parser.add_argument("--num-test-episodes", type=int, default=100) + parser.add_argument("--num-steps", type=int, default=30) + parser.add_argument("--target-update", type=int, default=10) + parser.add_argument("--epsilon", type=float, default=0.9) + parser.add_argument("--epsilon-falloff", type=float, default=0.98) + parser.add_argument("--gamma", type=float, default=0.999) + parser.add_argument("--learning-rate", type=float, default=0.0001) + parser.add_argument("--replay-memory-capacity", type=int, default=3000) + parser.add_argument("--seed", type=int, default=1357) + 
parser.add_argument("--cuda", action="store_true") + parser.add_argument("--gpu", type=int, default=0) + parser.add_argument("--jit", action="store_true") + parser.add_argument("--recompile", action="store_true") + parser.add_argument("--attenuation", type=float, default=0.95) + parser.add_argument("--show-run", action="store_true") + parser.add_argument("--show-train-run", action="store_true") + parser.add_argument("--show-test-run", action="store_true") + parser.add_argument("--show-run-interval", type=int, default=0.001) + parser.add_argument("--overlay-prediction", action="store_true") + parser.add_argument("--easy", action="store_true") + args = parser.parse_args() + + # Set parameters + args.show_run_interval = max(0.001, args.show_run_interval) # Minimum 1ms + if args.show_run: + args.show_train_run = True + args.show_test_run = True + torch.manual_seed(args.seed) + random.seed(args.seed) + if args.cuda: + if torch.cuda.is_available(): device = torch.device(f"cuda:{args.gpu}") + else: raise Exception("No cuda available") + else: device = torch.device("cpu") + + # Train + trainer = Trainer(args.grid_x, args.grid_y, args.cell_size, args.dpi, args.num_enemies, args.epsilon) + trainer.run() diff --git a/experiments/pacman_maze/run_random.py b/experiments/pacman_maze/run_random.py new file mode 100644 index 0000000..45d75b7 --- /dev/null +++ b/experiments/pacman_maze/run_random.py @@ -0,0 +1,45 @@ +from argparse import ArgumentParser +from tqdm import tqdm +import random + +from arena import AvoidingArena + +class RandomPolicy: + def __init__(self): pass + + def __call__(self, _): + return random.randint(0, 3) + +def test_random_model(): + # Initialize + arena = AvoidingArena((args.grid_x, args.grid_y), args.cell_size, args.dpi, args.num_enemies) + model = RandomPolicy() + success, failure = 0, 0 + iterator = tqdm(range(args.num_episodes)) + for episode_i in iterator: + state = arena.reset() + for _ in range(args.num_steps): + action = model(state) + _, done, reward, _ = arena.step(action) + if done: + if reward > 0: success += 1 + else: failure += 1 + break + + # Print + success_rate = (success / (episode_i + 1)) * 100.0 + iterator.set_description(f"[Test] {success}/{episode_i + 1} ({success_rate:.2f}%)") + + +if __name__ == "__main__": + parser = ArgumentParser() + parser.add_argument("--grid-x", type=int, default=5) + parser.add_argument("--grid-y", type=int, default=5) + parser.add_argument("--cell-size", type=float, default=0.5) + parser.add_argument("--dpi", type=int, default=80) + parser.add_argument("--num-enemies", type=int, default=5) + parser.add_argument("--num-episodes", type=int, default=1000) + parser.add_argument("--num-steps", type=int, default=30) + args = parser.parse_args() + + test_random_model() diff --git a/experiments/pacman_maze/run_scallop_sanity.py b/experiments/pacman_maze/run_scallop_sanity.py new file mode 100644 index 0000000..268ad23 --- /dev/null +++ b/experiments/pacman_maze/run_scallop_sanity.py @@ -0,0 +1,87 @@ +from argparse import ArgumentParser +from tqdm import tqdm +import numpy +import random +import scallopy +import os +import cv2 + +from arena import AvoidingArena + +FILE_DIR = os.path.abspath(os.path.join(os.path.abspath(__file__), "../")) + +class ScallopPolicy: + def __init__(self, grid_x, grid_y): + self.grid_x, self.grid_y = grid_x, grid_y + self.cells = [(x, y) for x in range(self.grid_x) for y in range(self.grid_y)] + self.ctx = scallopy.ScallopContext(provenance="topkproofs", k=1) + self.ctx.import_file(os.path.join(FILE_DIR, "scl", 
"arena.scl")) + self.ctx.add_facts("grid_node", [(args.attenuation, c) for c in self.cells]) + + def __call__(self, hidden_state): + curr_pos, goal_pos, enemies = hidden_state + temp_ctx = self.ctx.clone() + temp_ctx.add_facts("curr_position", [(None, curr_pos)]) + temp_ctx.add_facts("goal_position", [(None, goal_pos)]) + temp_ctx.add_facts("is_enemy", [(None, p) for p in enemies]) + temp_ctx.run() + output = list(temp_ctx.relation("action_score")) + if len(output) > 0: + action_id = numpy.argmax(numpy.array([p for (p, _) in output])) + action = output[action_id][1][0] + return action + else: + return random.randrange(0, 4) + + +def show_image(raw_image): + cv2.imshow("Current Arena", raw_image) + cv2.waitKey(int(args.show_run_interval * 1000)) + + +def test_scallop_model(): + # Initialize + arena = AvoidingArena((args.grid_x, args.grid_y), args.cell_size, args.dpi, args.num_enemies) + model = ScallopPolicy(args.grid_x, args.grid_y) + success, failure = 0, 0 + iterator = tqdm(range(args.num_episodes)) + for episode_i in iterator: + _ = arena.reset() + if args.show_run: + show_image(arena.render()) + for _ in range(args.num_steps): + action = model(arena.hidden_state()) + _, done, reward, _ = arena.step(action) + if args.show_run: + show_image(arena.render()) + if done: + if reward > 0: success += 1 + else: failure += 1 + break + + # Print + success_rate = (success / (episode_i + 1)) * 100.0 + iterator.set_description(f"[Test] {success}/{episode_i + 1} ({success_rate:.2f}%)") + + +if __name__ == "__main__": + parser = ArgumentParser() + parser.add_argument("--grid-x", type=int, default=5) + parser.add_argument("--grid-y", type=int, default=5) + parser.add_argument("--cell-size", type=float, default=0.5) + parser.add_argument("--dpi", type=int, default=80) + parser.add_argument("--num-enemies", type=int, default=5) + parser.add_argument("--num-episodes", type=int, default=1000) + parser.add_argument("--num-steps", type=int, default=30) + parser.add_argument("--attenuation", type=float, default=0.9) + parser.add_argument("--seed", type=int, default=12345) + parser.add_argument("--show-run", action="store_true") + parser.add_argument("--show-run-interval", type=float, default=0.001) + args = parser.parse_args() + + # Other arguments + args.show_run = True + random.seed(args.seed) + + # Test scallop model + test_scallop_model() diff --git a/experiments/pacman_maze/scl/arena.scl b/experiments/pacman_maze/scl/arena.scl new file mode 100644 index 0000000..0accd5f --- /dev/null +++ b/experiments/pacman_maze/scl/arena.scl @@ -0,0 +1,21 @@ +// Input from neural networks +type grid_node(x: usize, y: usize) +type curr_position(x: usize, y: usize) +type goal_position(x: usize, y: usize) +type is_enemy(x: usize, y: usize) + +// Basic connectivity +rel node(x, y) = grid_node(x, y), not is_enemy(x, y) +rel edge(x, y, x, yp, 0) = node(x, y), node(x, yp), yp == y + 1 // Up +rel edge(x, y, xp, y, 1) = node(x, y), node(xp, y), xp == x + 1 // Right +rel edge(x, y, x, yp, 2) = node(x, y), node(x, yp), yp == y - 1 // Down +rel edge(x, y, xp, y, 3) = node(x, y), node(xp, y), xp == x - 1 // Left + +// Path for connectivity; will condition on no enemy on the path +rel path(x, y, x, y) = node(x, y) +rel path(x, y, xp, yp) = edge(x, y, xp, yp, _) +rel path(x, y, xpp, ypp) = path(x, y, xp, yp), edge(xp, yp, xpp, ypp, _) + +// Get the next position +rel next_position(a, xp, yp) = curr_position(x, y), edge(x, y, xp, yp, a) +rel action_score(a) = next_position(a, x, y), goal_position(gx, gy), path(x, y, gx, gy) diff --git 
a/experiments/pacman_maze/scl/arena_w_constraint.scl b/experiments/pacman_maze/scl/arena_w_constraint.scl new file mode 100644 index 0000000..42e72b7 --- /dev/null +++ b/experiments/pacman_maze/scl/arena_w_constraint.scl @@ -0,0 +1,31 @@ +// Input from neural networks +type grid_node(x: usize, y: usize) +type curr_position(x: usize, y: usize) +type goal_position(x: usize, y: usize) +type is_enemy(x: usize, y: usize) + +// Basic connectivity +rel node(x, y) = grid_node(x, y), not is_enemy(x, y) +rel edge(x, y, x, yp, 0) = node(x, y), node(x, yp), yp == y + 1 // Up +rel edge(x, y, xp, y, 1) = node(x, y), node(xp, y), xp == x + 1 // Right +rel edge(x, y, x, yp, 2) = node(x, y), node(x, yp), yp == y - 1 // Down +rel edge(x, y, xp, y, 3) = node(x, y), node(xp, y), xp == x - 1 // Left + +// Path for connectivity; will condition on no enemy on the path +rel path(x, y, x, y) = node(x, y) +rel path(x, y, xp, yp) = edge(x, y, xp, yp, _) +rel path(x, y, xpp, ypp) = path(x, y, xp, yp), edge(xp, yp, xpp, ypp, _) + +// Get the next position +rel next_position(a, xp, yp) = curr_position(x, y), edge(x, y, xp, yp, a) +rel action_score(a) = next_position(a, x, y), goal_position(gx, gy), path(x, y, gx, gy) + +// Constraint violation +type too_many_goal() +type too_many_agent() +type too_many_enemy() +// Comment the following out if we don't want a particular constraint +rel too_many_goal() :- n = count(x, y: goal_position(x, y)), n > 1 +// rel too_many_agent() :- n = count(x, y: curr_position(x, y)), n > 1 +rel too_many_enemy() :- n = count(x, y: is_enemy(x, y)), n > 5 +rel violation() = too_many_goal() or too_many_enemy() or too_many_agent() diff --git a/experiments/pacman_maze/scl/grid_node.scl b/experiments/pacman_maze/scl/grid_node.scl new file mode 100644 index 0000000..379c5c4 --- /dev/null +++ b/experiments/pacman_maze/scl/grid_node.scl @@ -0,0 +1,5 @@ +type grid_size(x: usize, y: usize) + +rel grid_node(0, 0) +rel grid_node(x, yp) = grid_node(x, y), grid_size(_, gy), yp == y + 1, yp < gy +rel grid_node(xp, y) = grid_node(x, y), grid_size(gx, _), xp == x + 1, xp < gx
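The Scallop programs added under scl/ are only exercised indirectly through the training and sanity-check scripts above. For quick manual experimentation, a standalone query against arena.scl can be written with the same scallopy calls that run_scallop_sanity.py already uses (ScallopContext, import_file, add_facts, run, relation). The sketch below is illustrative only: the 3x3 grid, the 0.9 attenuation tag, the entity placement, and the relative file path are assumptions, not part of this patch.

import os
import scallopy

# Assumed location of the Scallop program relative to the repository root (illustrative)
ARENA_SCL = os.path.join("experiments", "pacman_maze", "scl", "arena.scl")

# Same provenance configuration as ScallopPolicy in run_scallop_sanity.py
ctx = scallopy.ScallopContext(provenance="topkproofs", k=1)
ctx.import_file(ARENA_SCL)

# A 3x3 arena: every cell is a grid node carrying the same attenuation tag
cells = [(x, y) for x in range(3) for y in range(3)]
ctx.add_facts("grid_node", [(0.9, c) for c in cells])

# Agent at (0, 0), goal at (2, 2), a single enemy at (1, 1); None tags mean "certainly true"
ctx.add_facts("curr_position", [(None, (0, 0))])
ctx.add_facts("goal_position", [(None, (2, 2))])
ctx.add_facts("is_enemy", [(None, (1, 1))])

# Run the program and inspect the scored actions (0 = up, 1 = right, 2 = down, 3 = left)
ctx.run()
for prob, (action,) in ctx.relation("action_score"):
  print(f"action {action}: {prob:.4f}")

Since the enemy removes (1, 1) from node, both up (towards (0, 1)) and right (towards (1, 0)) should receive scores, while down and left fall off the grid and produce no action_score tuples at all.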
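Similarly, scl/grid_node.scl derives the full cell set recursively from a single grid_size fact. Under the same assumptions as the sketch above (fact format and provenance mirroring run_scallop_sanity.py, arbitrary 4x3 grid size, illustrative file path), a quick enumeration check could look like this:

import os
import scallopy

GRID_NODE_SCL = os.path.join("experiments", "pacman_maze", "scl", "grid_node.scl")

# Facts are (tag, tuple) pairs as in run_scallop_sanity.py; None marks a certain fact
ctx = scallopy.ScallopContext(provenance="topkproofs", k=1)
ctx.import_file(GRID_NODE_SCL)
ctx.add_facts("grid_size", [(None, (4, 3))])
ctx.run()

# grid_node should enumerate all 12 cells (0, 0), (0, 1), ..., (3, 2) of the 4 x 3 grid
cells = sorted(t for _, t in ctx.relation("grid_node"))
print(cells)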