Skip to content

Commit

Permalink
Updating Scallop Plugin System
Browse files Browse the repository at this point in the history
  • Loading branch information
Liby99 committed Sep 14, 2024
1 parent 7b1a8cc commit e48b151
Show file tree
Hide file tree
Showing 176 changed files with 4,275 additions and 2,215 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@ __pycache__

# Scallop
*.sclcmpl
/.scllog
1 change: 1 addition & 0 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
- Rule tags can now be expressions with potential reference to local variables: `rel 1/n::head() = body(n)`
- Allowing for sparse gradient computation inside Scallopy to minimize memory footprint
- Allowing users to specify per-datapoint output mapping inside Scallopy
- Adding `topkproofsdebug` provenance for obtaining the top-k proofs in Scallopy
- Adding destructor syntax so that ADTs can be used in a more idiomatic way
- Unifying the behavior of integer overflow inside Scallop
- Multiple bugs fixed
Expand Down
3 changes: 0 additions & 3 deletions core/src/common/foreign_aggregates/min_max.rs
Original file line number Diff line number Diff line change
Expand Up @@ -270,12 +270,9 @@ impl MinMaxAggregator {
let i = if self.is_min { i } else { batch.len() - 1 - i };
let Tagged { tuple, tag } = &batch[i];
let and_true_tag = prov.mult(&accumulated_false_tag, tag);
println!("and true tag: {and_true_tag:?}");
result.push(DynamicElement::<P>::new(tuple.clone(), and_true_tag));
if let Some(f) = prov.negate(tag).map(|neg| prov.mult(&accumulated_false_tag, &neg)) {
println!("?????");
accumulated_false_tag = f;
println!("{accumulated_false_tag:?}");
}
}

Expand Down
3 changes: 3 additions & 0 deletions core/src/common/foreign_predicate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,9 @@ impl ForeignPredicateRegistry {
// Tensor
reg.register(fps::TensorShape::new()).unwrap();

// Provenance
reg.register(fps::NewTagVariable::new()).unwrap();

reg
}

Expand Down
2 changes: 2 additions & 0 deletions core/src/common/foreign_predicates/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use super::value_type::*;

mod datetime_ymd;
mod float_eq;
mod new_tag_variable;
mod range;
mod soft_cmp;
mod soft_eq;
Expand All @@ -25,6 +26,7 @@ mod tensor_shape;

pub use datetime_ymd::*;
pub use float_eq::*;
pub use new_tag_variable::*;
pub use range::*;
pub use soft_cmp::*;
pub use soft_eq::*;
Expand Down
42 changes: 42 additions & 0 deletions core/src/common/foreign_predicates/new_tag_variable.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
//! The floating point equality predicate
use super::*;

#[derive(Clone, Debug)]
pub struct NewTagVariable;

impl NewTagVariable {
pub fn new() -> Self {
Self
}
}

impl ForeignPredicate for NewTagVariable {
fn name(&self) -> String {
"new_tag_variable".to_string()
}

fn generic_type_parameters(&self) -> Vec<ValueType> {
vec![]
}

fn arity(&self) -> usize {
0
}

#[allow(unused)]
fn argument_type(&self, i: usize) -> ValueType {
unreachable!("Shouldn't be called as there is no argument")
}

fn num_bounded(&self) -> usize {
0
}

#[allow(unused)]
fn evaluate(&self, bounded: &[Value]) -> Vec<(DynamicInputTag, Vec<Value>)> {
vec![
(DynamicInputTag::NewVariable, vec![]),
]
}
}
4 changes: 3 additions & 1 deletion core/src/common/input_tag.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use super::foreign_tensor::DynamicExternalTensor;
#[derive(Clone, Debug, PartialEq, PartialOrd, Serialize)]
pub enum DynamicInputTag {
None,
NewVariable,
Exclusive(usize),
Bool(bool),
Natural(usize),
Expand Down Expand Up @@ -44,7 +45,8 @@ impl std::fmt::Display for DynamicInputTag {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::None => Ok(()),
Self::Exclusive(i) => f.write_str(&format!("[ME({})]", i)),
Self::NewVariable => f.write_str("new Var"),
Self::Exclusive(i) => f.write_fmt(format_args!("[ME({})]", i)),
Self::Bool(b) => b.fmt(f),
Self::Natural(n) => n.fmt(f),
Self::Float(n) => n.fmt(f),
Expand Down
30 changes: 30 additions & 0 deletions core/src/compiler/back/attr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,26 @@ impl Attributes {
}
None
}

pub fn goal_attr(&self) -> Option<&GoalAttribute> {
for attr in &self.attrs {
match attr {
Attribute::Goal(g) => return Some(g),
_ => {}
}
}
None
}

pub fn scheduler_attr(&self) -> Option<&SchedulerAttribute> {
for attr in &self.attrs {
match attr {
Attribute::Scheduler(s) => return Some(s),
_ => {}
}
}
None
}
}

impl<I> From<I> for Attributes
Expand All @@ -92,6 +112,8 @@ pub enum Attribute {
Demand(DemandAttribute),
MagicSet(MagicSetAttribute),
InputFile(InputFileAttribute),
Goal(GoalAttribute),
Scheduler(SchedulerAttribute),
}

impl Attribute {
Expand Down Expand Up @@ -182,3 +204,11 @@ pub struct MagicSetAttribute;
pub struct InputFileAttribute {
pub input_file: InputFile,
}

#[derive(Clone, Debug, PartialEq)]
pub struct GoalAttribute;

#[derive(Clone, Debug, PartialEq)]
pub struct SchedulerAttribute {
pub scheduler: crate::runtime::env::Scheduler,
}
2 changes: 2 additions & 0 deletions core/src/compiler/back/b2r.rs
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,8 @@ impl Program {
facts: vec![facts, disjunctive_facts].concat(),
input_file,
output,
is_goal: rel.attributes.goal_attr().is_some(),
scheduler: rel.attributes.scheduler_attr().map(|sa| sa.scheduler.clone()),
immutable,
};

Expand Down
23 changes: 23 additions & 0 deletions core/src/compiler/back/pretty.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use crate::runtime::env::Scheduler;

use super::*;
use std::fmt::{Display, Formatter, Result as FmtResult};

Expand Down Expand Up @@ -57,6 +59,8 @@ impl Display for Attribute {
Self::Demand(d) => d.fmt(f),
Self::MagicSet(d) => d.fmt(f),
Self::InputFile(i) => i.fmt(f),
Self::Goal(g) => g.fmt(f),
Self::Scheduler(s) => s.fmt(f),
}
}
}
Expand Down Expand Up @@ -97,6 +101,25 @@ impl Display for InputFileAttribute {
}
}

impl Display for GoalAttribute {
fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
f.write_str("@goal")
}
}

impl Display for SchedulerAttribute {
fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
f.write_str("@scheduler(")?;
match &self.scheduler {
Scheduler::LFP => f.write_str("\"lfp\"")?,
Scheduler::AStar => f.write_str("\"a-star\"")?,
Scheduler::DFS => f.write_str("\"dfs\"")?,
Scheduler::Beam { beam_size } => f.write_fmt(format_args!("\"beam\", beam_size = {beam_size}"))?,
}
f.write_str(")")
}
}

impl Display for Fact {
fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
if self.tag.is_some() {
Expand Down
21 changes: 16 additions & 5 deletions core/src/compiler/back/query_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,14 +145,16 @@ impl<'a> QueryPlanContext<'a> {
}

fn pos_atom_arcs(&self, beam_size: usize) -> State {
let atom_relations = self.pos_atoms.iter().enumerate().map(|(i, atom)| (i, atom.predicate.clone())).collect();

// If there is no positive atom, return an empty state
if self.pos_atoms.is_empty() {
return State::new();
return State::new(atom_relations);
}

// Maintain a priority queue of searching states
let mut priority_queue = BinaryHeap::new();
priority_queue.push(State::new());
priority_queue.push(State::new(atom_relations));

// Maintain a set of final states
let mut final_states = BinaryHeap::new();
Expand Down Expand Up @@ -750,6 +752,7 @@ fn term_is_bounded(bounded_vars: &HashSet<Variable>, term: &Term) -> bool {

#[derive(Clone, Debug, Default, PartialEq, Eq)]
struct State {
atom_relations: HashMap<usize, String>,
visited_atoms: Vec<usize>,
arcs: Vec<Arc>,
}
Expand All @@ -767,8 +770,12 @@ impl std::cmp::Ord for State {
}

impl State {
pub fn new() -> Self {
Self::default()
pub fn new(atom_relations: HashMap<usize, String>) -> Self {
Self {
atom_relations,
visited_atoms: vec![],
arcs: vec![],
}
}

pub fn aggregated_weight(&self) -> i32 {
Expand Down Expand Up @@ -797,10 +804,12 @@ impl State {
let arc = Arc {
left: set.iter().map(|i| **i).collect::<Vec<_>>(),
right: id,
left_relations: set.iter().map(|id| self.atom_relations[id].clone()).collect(),
bounded_vars,
is_edb,
};
next_states.push(State {
atom_relations: self.atom_relations.clone(),
visited_atoms: vec![self.visited_atoms.clone(), vec![id]].concat(),
arcs: vec![self.arcs.clone(), vec![arc]].concat(),
});
Expand All @@ -820,15 +829,17 @@ impl State {
struct Arc {
left: Vec<usize>,
right: usize,
left_relations: Vec<String>,
bounded_vars: HashSet<Variable>,
is_edb: bool,
}

impl Arc {
pub fn weight(&self) -> i32 {
let demand_weight = self.left_relations.iter().filter(|r| r.starts_with("d#")).count() as i32;
let num_bounded_vars = self.bounded_vars.len() as i32;
let edb_weight = if self.left.is_empty() && self.is_edb { 1 } else { 0 };
num_bounded_vars + edb_weight
demand_weight + num_bounded_vars + edb_weight
}
}

Expand Down
29 changes: 19 additions & 10 deletions core/src/compiler/back/scc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,21 @@ impl Program {
// Collect the head predicate
let head_predicate = rule.head_predicate();

// Step 1. Deal with the dependencies between goal predicate and its dependencies
if let Some(head_relation) = self.relation_of_predicate(head_predicate) {
if head_relation.attributes.goal_attr().is_some() {
for atom in rule.body_literals() {
match atom {
Literal::Atom(a) if !self.predicate_registry.contains(&a.predicate) => {
graph.add_dependency(&a.predicate, head_predicate, E::Positive);
}
_ => {}
}
}
}
}

// Step 2. Deal with dependencies related to Entity and Functors
// Collect all the related functor predicates
let does_create_dyn_ent = rule.needs_dynamically_parse_entity(&self.function_registry, &self.predicate_registry);
let functor_predicates: Vec<_> = if does_create_dyn_ent {
Expand Down Expand Up @@ -278,17 +293,11 @@ impl Program {
};
for atom in rule.body_literals() {
match atom {
Literal::Atom(a) => {
let atom_predicate = &a.predicate;
if !self.predicate_registry.contains(atom_predicate) {
record_dependency(atom_predicate, E::Positive);
}
Literal::Atom(a) if !self.predicate_registry.contains(&a.predicate) => {
record_dependency(&a.predicate, E::Positive);
}
Literal::NegAtom(a) => {
let atom_predicate = &a.atom.predicate;
if !self.predicate_registry.contains(atom_predicate) {
record_dependency(atom_predicate, E::Negative);
}
Literal::NegAtom(a) if !self.predicate_registry.contains(&a.atom.predicate) => {
record_dependency(&a.atom.predicate, E::Negative);
}
Literal::Reduce(r) => {
let reduce_predicate = &r.body_formula.predicate;
Expand Down
9 changes: 9 additions & 0 deletions core/src/compiler/front/analysis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ pub struct Analysis {
pub character_literal_analysis: CharacterLiteralAnalysis,
pub constant_decl_analysis: ConstantDeclAnalysis,
pub adt_analysis: AlgebraicDataTypeAnalysis,
pub goal_relation_analysis: GoalRelationAnalysis,
pub head_relation_analysis: HeadRelationAnalysis,
pub scheduler_attr_analysis: SchedulerAttributeAnalysis,
pub tagged_rule_analysis: TaggedRuleAnalysis,
pub type_inference: TypeInference,
pub boundness_analysis: BoundnessAnalysis,
Expand All @@ -42,6 +44,8 @@ impl Analysis {
constant_decl_analysis: ConstantDeclAnalysis::new(),
adt_analysis: AlgebraicDataTypeAnalysis::new(),
head_relation_analysis: HeadRelationAnalysis::new(predicate_registry),
goal_relation_analysis: GoalRelationAnalysis::new(),
scheduler_attr_analysis: SchedulerAttributeAnalysis::new(),
tagged_rule_analysis: TaggedRuleAnalysis::new(),
type_inference: TypeInference::new(function_registry, predicate_registry, aggregate_registry),
boundness_analysis: BoundnessAnalysis::new(predicate_registry),
Expand All @@ -60,6 +64,7 @@ impl Analysis {
&mut self.adt_analysis,
&mut self.invalid_constant,
&mut self.invalid_wildcard,
&mut self.scheduler_attr_analysis,
);
items.walk(&mut analyzers);
}
Expand All @@ -72,6 +77,7 @@ impl Analysis {

// Create the analyzers and walk the items
let mut analyzers = (
&mut self.goal_relation_analysis,
&mut self.head_relation_analysis,
&mut self.type_inference,
&mut self.demand_attr_analysis,
Expand All @@ -81,6 +87,7 @@ impl Analysis {
}

pub fn post_analysis(&mut self, foreign_predicate_registry: &mut ForeignPredicateRegistry) {
self.goal_relation_analysis.compute_errors();
self.head_relation_analysis.compute_errors();
self.type_inference.check_query_predicates();
self.type_inference.infer_types();
Expand All @@ -99,7 +106,9 @@ impl Analysis {
error_ctx.extend(&mut self.character_literal_analysis.errors);
error_ctx.extend(&mut self.constant_decl_analysis.errors);
error_ctx.extend(&mut self.adt_analysis.errors);
error_ctx.extend(&mut self.goal_relation_analysis.errors);
error_ctx.extend(&mut self.head_relation_analysis.errors);
error_ctx.extend(&mut self.scheduler_attr_analysis.errors);
error_ctx.extend(&mut self.type_inference.errors);
error_ctx.extend(&mut self.boundness_analysis.errors);
error_ctx.extend(&mut self.demand_attr_analysis.errors);
Expand Down
Loading

0 comments on commit e48b151

Please sign in to comment.