Skip to content

Commit

Permalink
Bumping version to 0.2.4
Browse files Browse the repository at this point in the history
  • Loading branch information
Liby99 committed Aug 30, 2024
1 parent 847d68f commit da21ec8
Show file tree
Hide file tree
Showing 78 changed files with 2,078 additions and 304 deletions.
2 changes: 0 additions & 2 deletions .github/workflows/scallopy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@ jobs:
max-parallel: 5
matrix:
python-version:
- "3.8"
- "3.9"
- "3.10"

steps:
Expand Down
11 changes: 11 additions & 0 deletions changelog.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,14 @@
# v0.2.4, Aug 30, 2024

- Rule tags can now be expressions with potential reference to local variables: `rel 1/n::head() = body(n)`
- Allowing for sparse gradient computation inside Scallopy to minimize memory footprint
- Allowing users to specify per-datapoint output mapping inside Scallopy
- Adding destructor syntax so that ADTs can be used in a more idiomatic way
- Unifying the behavior of integer overflow inside Scallop
- Multiple bugs fixed

# v0.2.3, Jun 23, 2024

# v0.2.2, Oct 25, 2023

- Adding `wmc_with_disjunctions` option for provenances that deal with boolean formulas for more accurate probability estimation
Expand Down
2 changes: 1 addition & 1 deletion core/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "scallop-core"
version = "0.2.2"
version = "0.2.4"
authors = ["Ziyang Li <[email protected]>"]
edition = "2018"

Expand Down
46 changes: 23 additions & 23 deletions core/src/common/foreign_predicate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,29 +40,29 @@ impl Binding {
}
}

/// The identifier of a foreign predicate in a registry
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
pub struct ForeignPredicateIdentifier {
identifier: String,
types: Box<[ValueType]>,
binding_pattern: BindingPattern,
}

impl std::fmt::Display for ForeignPredicateIdentifier {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_fmt(format_args!(
"pred {}[{}]({})",
self.identifier,
self.binding_pattern,
self
.types
.iter()
.map(|t| format!("{}", t))
.collect::<Vec<_>>()
.join(", ")
))
}
}
// /// The identifier of a foreign predicate in a registry
// #[derive(Clone, Debug, Hash, PartialEq, Eq)]
// pub struct ForeignPredicateIdentifier {
// identifier: String,
// types: Box<[ValueType]>,
// binding_pattern: BindingPattern,
// }

// impl std::fmt::Display for ForeignPredicateIdentifier {
// fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
// f.write_fmt(format_args!(
// "pred {}[{}]({})",
// self.identifier,
// self.binding_pattern,
// self
// .types
// .iter()
// .map(|t| format!("{}", t))
// .collect::<Vec<_>>()
// .join(", ")
// ))
// }
// }

/// A binding pattern for a predicate, e.g. bbf
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
Expand Down
2 changes: 1 addition & 1 deletion core/src/compiler/back/compile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ impl Program {
// Perform rule level optimizations
for rule in &mut self.rules {
// First propagate equality
optimizations::propagate_equality(rule);
optimizations::propagate_equality(rule, &self.predicate_registry);

// Enter the loop of constant folding/propagation
loop {
Expand Down
24 changes: 19 additions & 5 deletions core/src/compiler/back/optimizations/equality_propagation.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
use std::collections::*;

use crate::common::foreign_predicate::*;

use super::super::*;

pub fn propagate_equality(rule: &mut Rule) {
pub fn propagate_equality(rule: &mut Rule, foreign_predicate_registry: &ForeignPredicateRegistry) {
let mut substitutions = HashMap::<_, Variable>::new();
let mut ignore_literals = HashSet::new();
let mut cannot_substitute = HashSet::<Variable>::new();
Expand All @@ -18,7 +20,7 @@ pub fn propagate_equality(rule: &mut Rule) {
}

// Find all the bounded variables by atom and assign
let bounded = bounded_by_atom_and_assign(rule);
let bounded = bounded_by_atom_and_assign(rule, foreign_predicate_registry);

// Collect all substitutions
for (i, literal) in rule.body_literals().enumerate() {
Expand Down Expand Up @@ -136,14 +138,26 @@ pub fn propagate_equality(rule: &mut Rule) {
attributes: rule.attributes.clone(),
head: new_head,
body: Conjunction { args: new_literals },
}
};
}

fn bounded_by_atom_and_assign(rule: &Rule) -> HashSet<Variable> {
fn bounded_by_atom_and_assign(rule: &Rule, foreign_predicate_registry: &ForeignPredicateRegistry) -> HashSet<Variable> {
let mut bounded = rule
.body_literals()
.flat_map(|l| match l {
Literal::Atom(a) => a.variable_args().cloned().collect::<Vec<_>>(),
Literal::Atom(atom) => {
if let Some(fp) = foreign_predicate_registry.get(&atom.predicate) {
// If atom is on foreign predicate, only the variables that are free will be bounded
atom.args[fp.num_bounded()..fp.arity()]
.iter()
.filter_map(|term| term.as_variable())
.cloned()
.collect::<Vec<_>>()
} else {
// If atom is on a normal relation, all the variables will be bounded
atom.variable_args().cloned().collect::<Vec<_>>()
}
}
_ => vec![],
})
.collect::<HashSet<_>>();
Expand Down
8 changes: 7 additions & 1 deletion core/src/compiler/front/analysis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ pub struct Analysis {
pub constant_decl_analysis: ConstantDeclAnalysis,
pub adt_analysis: AlgebraicDataTypeAnalysis,
pub head_relation_analysis: HeadRelationAnalysis,
pub tagged_rule_analysis: TaggedRuleAnalysis,
pub type_inference: TypeInference,
pub boundness_analysis: BoundnessAnalysis,
pub demand_attr_analysis: DemandAttributeAnalysis,
Expand All @@ -41,6 +42,7 @@ impl Analysis {
constant_decl_analysis: ConstantDeclAnalysis::new(),
adt_analysis: AlgebraicDataTypeAnalysis::new(),
head_relation_analysis: HeadRelationAnalysis::new(predicate_registry),
tagged_rule_analysis: TaggedRuleAnalysis::new(),
type_inference: TypeInference::new(function_registry, predicate_registry, aggregate_registry),
boundness_analysis: BoundnessAnalysis::new(predicate_registry),
demand_attr_analysis: DemandAttributeAnalysis::new(),
Expand Down Expand Up @@ -78,12 +80,15 @@ impl Analysis {
items.walk(&mut analyzers);
}

pub fn post_analysis(&mut self) {
pub fn post_analysis(&mut self, foreign_predicate_registry: &mut ForeignPredicateRegistry) {
self.head_relation_analysis.compute_errors();
self.type_inference.check_query_predicates();
self.type_inference.infer_types();
self.demand_attr_analysis.check_arity(&self.type_inference);
self.boundness_analysis.check_boundness(&self.demand_attr_analysis);
self
.tagged_rule_analysis
.register_predicates(&self.type_inference, foreign_predicate_registry);
}

pub fn dump_errors(&mut self, error_ctx: &mut FrontCompileError) {
Expand All @@ -98,5 +103,6 @@ impl Analysis {
error_ctx.extend(&mut self.type_inference.errors);
error_ctx.extend(&mut self.boundness_analysis.errors);
error_ctx.extend(&mut self.demand_attr_analysis.errors);
error_ctx.extend(&mut self.tagged_rule_analysis.errors);
}
}
2 changes: 1 addition & 1 deletion core/src/compiler/front/analyzers/constant_decl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ impl NodeVisitor<FactDecl> for ConstantDeclAnalysis {
for v in vars {
if self.variables.contains_key(v.variable_name()) {
self.variable_use.insert(v.location().clone(), v.name().to_string());
} else {
} else if !fact_decl.atom().iter_args().any(|arg| arg.is_destruct()) {
self.errors.push(ConstantDeclError::UnknownConstantVariable {
name: v.name().to_string(),
loc: v.location().clone(),
Expand Down
2 changes: 2 additions & 0 deletions core/src/compiler/front/analyzers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ pub mod input_files;
pub mod invalid_constant;
pub mod invalid_wildcard;
pub mod output_files;
pub mod tagged_rule;
pub mod type_inference;

pub use aggregation::AggregationAnalysis;
Expand All @@ -24,6 +25,7 @@ pub use input_files::InputFilesAnalysis;
pub use invalid_constant::InvalidConstantAnalyzer;
pub use invalid_wildcard::InvalidWildcardAnalyzer;
pub use output_files::OutputFilesAnalysis;
pub use tagged_rule::TaggedRuleAnalysis;
pub use type_inference::TypeInference;

pub mod errors {
Expand Down
185 changes: 185 additions & 0 deletions core/src/compiler/front/analyzers/tagged_rule.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
use lazy_static::lazy_static;
use std::collections::*;

use crate::common::expr;
use crate::common::foreign_predicate::*;
use crate::common::input_tag::*;
use crate::common::tuple::*;
use crate::common::unary_op;
use crate::common::value::*;
use crate::common::value_type::*;

use crate::compiler::front::*;
use crate::runtime::env::*;

lazy_static! {
pub static ref TAG_TYPE: Vec<ValueType> = {
use ValueType::*;
vec![F64, F32, Bool]
};
}

#[derive(Clone, Debug)]
pub struct TaggedRuleAnalysis {
pub to_add_tag_predicates: HashMap<ast::NodeLocation, ToAddTagPredicate>,
pub errors: Vec<FrontCompileErrorMessage>,
}

impl TaggedRuleAnalysis {
pub fn new() -> Self {
Self {
to_add_tag_predicates: HashMap::new(),
errors: Vec::new(),
}
}

pub fn add_tag_predicate(
&mut self,
rule_id: ast::NodeLocation,
name: String,
arg_name: String,
tag_loc: ast::NodeLocation,
) {
let pred = ToAddTagPredicate::new(name, arg_name, tag_loc);
self.to_add_tag_predicates.insert(rule_id, pred);
}

pub fn register_predicates(
&mut self,
type_inference: &super::TypeInference,
foreign_predicate_registry: &mut ForeignPredicateRegistry,
) {
for (rule_id, tag_predicate) in self.to_add_tag_predicates.drain() {
if let Some(rule_variable_type) = type_inference.rule_variable_type.get(&rule_id) {
if let Some(var_ty) = rule_variable_type.get(&tag_predicate.arg_name) {
match get_target_tag_type(var_ty, &tag_predicate.tag_loc) {
Ok(target_tag_ty) => {
// This means that we have an okay tag that is type checked
// Create a foreign predicate and register it
let fp = TagPredicate::new(tag_predicate.name.clone(), target_tag_ty);
if let Err(err) = foreign_predicate_registry.register(fp) {
self.errors.push(FrontCompileErrorMessage::error().msg(err.to_string()));
}
}
Err(err) => {
self.errors.push(err);
}
}
}
}
}
}
}

fn get_target_tag_type(
var_ty: &analyzers::type_inference::TypeSet,
loc: &ast::NodeLocation,
) -> Result<ValueType, FrontCompileErrorMessage> {
// Top priority: if var_ty is a base type, directly check if it is among some expected type
if let Some(base_ty) = var_ty.get_base_type() {
if TAG_TYPE.contains(&base_ty) {
return Ok(base_ty);
}
}

// Then we check if the value can be casted into certain types
for tag_ty in TAG_TYPE.iter() {
if var_ty.can_type_cast(tag_ty) {
return Ok(var_ty.to_default_value_type());
}
}

// If not, then
return Err(
FrontCompileErrorMessage::error()
.msg(format!(
"A value of type `{var_ty}` cannot be casted into a dynamic tag"
))
.src(loc.clone()),
);
}

/// The information of a helper tag predicate
///
/// Suppose we have a rule
/// ``` ignore
/// rel 1/p :: head() = body(p)
/// ```
///
/// This rule will be transformed into
/// ``` ignore
/// rel head() = body(p) and tag#head#1#var == 1 / p and tag#head#1(tag#head#1#var)
/// ```
#[derive(Clone, Debug)]
pub struct ToAddTagPredicate {
/// The name of the predicate
pub name: String,

/// The main tag expression
pub arg_name: String,

/// Tag location
pub tag_loc: ast::NodeLocation,
}

impl ToAddTagPredicate {
pub fn new(name: String, arg_name: String, tag_loc: ast::NodeLocation) -> Self {
Self {
name,
arg_name,
tag_loc,
}
}
}

/// An actual predicate
#[derive(Clone, Debug)]
pub struct TagPredicate {
/// The name of he predicate
pub name: String,

/// args
pub arg_ty: ValueType,
}

impl TagPredicate {
pub fn new(name: String, arg_ty: ValueType) -> Self {
Self { name, arg_ty }
}
}

impl ForeignPredicate for TagPredicate {
fn name(&self) -> String {
self.name.clone()
}

fn arity(&self) -> usize {
1
}

fn argument_type(&self, i: usize) -> ValueType {
assert_eq!(i, 0);
self.arg_ty.clone()
}

fn num_bounded(&self) -> usize {
1
}

fn evaluate_with_env(&self, env: &RuntimeEnvironment, bounded: &[Value]) -> Vec<(DynamicInputTag, Vec<Value>)> {
// Result tuple
let tup = vec![];

// Create a type cast expression and evaluate it on the given values
let tuple = Tuple::from_values(bounded.iter().cloned());
let cast_expr = expr::Expr::unary(unary_op::UnaryOp::TypeCast(ValueType::F64), expr::Expr::access(0));
let maybe_computed_tag = env.eval(&cast_expr, &tuple);

// Return the value
if let Some(Tuple::Value(Value::F64(f))) = maybe_computed_tag {
vec![(DynamicInputTag::Float(f), tup)]
} else {
vec![]
}
}
}
Loading

0 comments on commit da21ec8

Please sign in to comment.