Skip to content

Commit

Permalink
wip: short-circuiting boolean operators
Browse files Browse the repository at this point in the history
Even though the semantics of the stack machine are eager, what we actually care about is && and || being non-strict (rather than lazy).
The difference between non-strict and lazy is not observable from the outside, since datalog cannot perform side effects.
  • Loading branch information
divarvel committed Oct 11, 2023
1 parent 132b0c9 commit 1eeec57
Showing 1 changed file with 179 additions and 122 deletions.
301 changes: 179 additions & 122 deletions biscuit-auth/src/datalog/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,132 +86,160 @@ impl Binary {
fn evaluate(
&self,
left: Term,
right: Term,
right: Result<Term, error::Expression>,
symbols: &mut TemporarySymbolTable,
) -> Result<Term, error::Expression> {
match (self, left, right) {
// integer
(Binary::LessThan, Term::Integer(i), Term::Integer(j)) => Ok(Term::Bool(i < j)),
(Binary::GreaterThan, Term::Integer(i), Term::Integer(j)) => Ok(Term::Bool(i > j)),
(Binary::LessOrEqual, Term::Integer(i), Term::Integer(j)) => Ok(Term::Bool(i <= j)),
(Binary::GreaterOrEqual, Term::Integer(i), Term::Integer(j)) => Ok(Term::Bool(i >= j)),
(Binary::Equal, Term::Integer(i), Term::Integer(j)) => Ok(Term::Bool(i == j)),
(Binary::NotEqual, Term::Integer(i), Term::Integer(j)) => Ok(Term::Bool(i != j)),
(Binary::Add, Term::Integer(i), Term::Integer(j)) => i
.checked_add(j)
.map(Term::Integer)
.ok_or(error::Expression::Overflow),
(Binary::Sub, Term::Integer(i), Term::Integer(j)) => i
.checked_sub(j)
.map(Term::Integer)
.ok_or(error::Expression::Overflow),
(Binary::Mul, Term::Integer(i), Term::Integer(j)) => i
.checked_mul(j)
.map(Term::Integer)
.ok_or(error::Expression::Overflow),
(Binary::Div, Term::Integer(i), Term::Integer(j)) => i
.checked_div(j)
.map(Term::Integer)
.ok_or(error::Expression::DivideByZero),
(Binary::BitwiseAnd, Term::Integer(i), Term::Integer(j)) => Ok(Term::Integer(i & j)),
(Binary::BitwiseOr, Term::Integer(i), Term::Integer(j)) => Ok(Term::Integer(i | j)),
(Binary::BitwiseXor, Term::Integer(i), Term::Integer(j)) => Ok(Term::Integer(i ^ j)),

// string
(Binary::Prefix, Term::Str(s), Term::Str(pref)) => {
match (symbols.get_symbol(s), symbols.get_symbol(pref)) {
(Some(s), Some(pref)) => Ok(Term::Bool(s.starts_with(pref))),
(Some(_), None) => Err(error::Expression::UnknownSymbol(pref)),
_ => Err(error::Expression::UnknownSymbol(s)),
// && and || are short-circuiting operators: depending on the left-hand side, the right-hand side may not be evaluated
match (self, &left) {
(Binary::And, Term::Bool(i)) => Ok(Term::Bool(
*i && {
match right? {
Term::Bool(j) => Ok(j),
_ => Err(error::Expression::InvalidType),
}?
},
)),
(Binary::Or, Term::Bool(i)) => Ok(Term::Bool(
*i || {
match right? {
Term::Bool(j) => Ok(j),
_ => Err(error::Expression::InvalidType),
}?
},
)),

// (Binary::And, Term::Bool(i), Term::Bool(j)) => Ok(Term::Bool(i & j)),
// (Binary::Or, Term::Bool(i), Term::Bool(j)) => Ok(Term::Bool(i | j)),
_ => match (self, left, right?) {
// integer
(Binary::LessThan, Term::Integer(i), Term::Integer(j)) => Ok(Term::Bool(i < j)),
(Binary::GreaterThan, Term::Integer(i), Term::Integer(j)) => Ok(Term::Bool(i > j)),
(Binary::LessOrEqual, Term::Integer(i), Term::Integer(j)) => Ok(Term::Bool(i <= j)),
(Binary::GreaterOrEqual, Term::Integer(i), Term::Integer(j)) => {
Ok(Term::Bool(i >= j))
}
}
(Binary::Suffix, Term::Str(s), Term::Str(suff)) => {
match (symbols.get_symbol(s), symbols.get_symbol(suff)) {
(Some(s), Some(suff)) => Ok(Term::Bool(s.ends_with(suff))),
(Some(_), None) => Err(error::Expression::UnknownSymbol(suff)),
_ => Err(error::Expression::UnknownSymbol(s)),
(Binary::Equal, Term::Integer(i), Term::Integer(j)) => Ok(Term::Bool(i == j)),
(Binary::NotEqual, Term::Integer(i), Term::Integer(j)) => Ok(Term::Bool(i != j)),
(Binary::Add, Term::Integer(i), Term::Integer(j)) => i
.checked_add(j)
.map(Term::Integer)
.ok_or(error::Expression::Overflow),
(Binary::Sub, Term::Integer(i), Term::Integer(j)) => i
.checked_sub(j)
.map(Term::Integer)
.ok_or(error::Expression::Overflow),
(Binary::Mul, Term::Integer(i), Term::Integer(j)) => i
.checked_mul(j)
.map(Term::Integer)
.ok_or(error::Expression::Overflow),
(Binary::Div, Term::Integer(i), Term::Integer(j)) => i
.checked_div(j)
.map(Term::Integer)
.ok_or(error::Expression::DivideByZero),
(Binary::BitwiseAnd, Term::Integer(i), Term::Integer(j)) => {
Ok(Term::Integer(i & j))
}
}
(Binary::Regex, Term::Str(s), Term::Str(r)) => {
match (symbols.get_symbol(s), symbols.get_symbol(r)) {
(Some(s), Some(r)) => Ok(Term::Bool(
Regex::new(r).map(|re| re.is_match(s)).unwrap_or(false),
)),
(Some(_), None) => Err(error::Expression::UnknownSymbol(r)),
_ => Err(error::Expression::UnknownSymbol(s)),
(Binary::BitwiseOr, Term::Integer(i), Term::Integer(j)) => Ok(Term::Integer(i | j)),
(Binary::BitwiseXor, Term::Integer(i), Term::Integer(j)) => {
Ok(Term::Integer(i ^ j))
}
}
(Binary::Contains, Term::Str(s), Term::Str(pattern)) => {
match (symbols.get_symbol(s), symbols.get_symbol(pattern)) {
(Some(s), Some(pattern)) => Ok(Term::Bool(s.contains(pattern))),
(Some(_), None) => Err(error::Expression::UnknownSymbol(pattern)),
_ => Err(error::Expression::UnknownSymbol(s)),

// string
(Binary::Prefix, Term::Str(s), Term::Str(pref)) => {
match (symbols.get_symbol(s), symbols.get_symbol(pref)) {
(Some(s), Some(pref)) => Ok(Term::Bool(s.starts_with(pref))),
(Some(_), None) => Err(error::Expression::UnknownSymbol(pref)),
_ => Err(error::Expression::UnknownSymbol(s)),
}
}
}
(Binary::Add, Term::Str(s1), Term::Str(s2)) => {
match (symbols.get_symbol(s1), symbols.get_symbol(s2)) {
(Some(s1), Some(s2)) => {
let s = format!("{}{}", s1, s2);
let sym = symbols.insert(&s);
Ok(Term::Str(sym))
(Binary::Suffix, Term::Str(s), Term::Str(suff)) => {
match (symbols.get_symbol(s), symbols.get_symbol(suff)) {
(Some(s), Some(suff)) => Ok(Term::Bool(s.ends_with(suff))),
(Some(_), None) => Err(error::Expression::UnknownSymbol(suff)),
_ => Err(error::Expression::UnknownSymbol(s)),
}
(Some(_), None) => Err(error::Expression::UnknownSymbol(s2)),
_ => Err(error::Expression::UnknownSymbol(s1)),
}
}
(Binary::Equal, Term::Str(i), Term::Str(j)) => Ok(Term::Bool(i == j)),
(Binary::NotEqual, Term::Str(i), Term::Str(j)) => Ok(Term::Bool(i != j)),

// date
(Binary::LessThan, Term::Date(i), Term::Date(j)) => Ok(Term::Bool(i < j)),
(Binary::GreaterThan, Term::Date(i), Term::Date(j)) => Ok(Term::Bool(i > j)),
(Binary::LessOrEqual, Term::Date(i), Term::Date(j)) => Ok(Term::Bool(i <= j)),
(Binary::GreaterOrEqual, Term::Date(i), Term::Date(j)) => Ok(Term::Bool(i >= j)),
(Binary::Equal, Term::Date(i), Term::Date(j)) => Ok(Term::Bool(i == j)),
(Binary::NotEqual, Term::Date(i), Term::Date(j)) => Ok(Term::Bool(i != j)),

// symbol

// byte array
(Binary::Equal, Term::Bytes(i), Term::Bytes(j)) => Ok(Term::Bool(i == j)),
(Binary::NotEqual, Term::Bytes(i), Term::Bytes(j)) => Ok(Term::Bool(i != j)),

// set
(Binary::Equal, Term::Set(set), Term::Set(s)) => Ok(Term::Bool(set == s)),
(Binary::NotEqual, Term::Set(set), Term::Set(s)) => Ok(Term::Bool(set != s)),
(Binary::Intersection, Term::Set(set), Term::Set(s)) => {
Ok(Term::Set(set.intersection(&s).cloned().collect()))
}
(Binary::Union, Term::Set(set), Term::Set(s)) => {
Ok(Term::Set(set.union(&s).cloned().collect()))
}
(Binary::Contains, Term::Set(set), Term::Set(s)) => Ok(Term::Bool(set.is_superset(&s))),
(Binary::Contains, Term::Set(set), Term::Integer(i)) => {
Ok(Term::Bool(set.contains(&Term::Integer(i))))
}
(Binary::Contains, Term::Set(set), Term::Date(i)) => {
Ok(Term::Bool(set.contains(&Term::Date(i))))
}
(Binary::Contains, Term::Set(set), Term::Bool(i)) => {
Ok(Term::Bool(set.contains(&Term::Bool(i))))
}
(Binary::Contains, Term::Set(set), Term::Str(i)) => {
Ok(Term::Bool(set.contains(&Term::Str(i))))
}
(Binary::Contains, Term::Set(set), Term::Bytes(i)) => {
Ok(Term::Bool(set.contains(&Term::Bytes(i))))
}
(Binary::Regex, Term::Str(s), Term::Str(r)) => {
match (symbols.get_symbol(s), symbols.get_symbol(r)) {
(Some(s), Some(r)) => Ok(Term::Bool(
Regex::new(r).map(|re| re.is_match(s)).unwrap_or(false),
)),
(Some(_), None) => Err(error::Expression::UnknownSymbol(r)),
_ => Err(error::Expression::UnknownSymbol(s)),
}
}
(Binary::Contains, Term::Str(s), Term::Str(pattern)) => {
match (symbols.get_symbol(s), symbols.get_symbol(pattern)) {
(Some(s), Some(pattern)) => Ok(Term::Bool(s.contains(pattern))),
(Some(_), None) => Err(error::Expression::UnknownSymbol(pattern)),
_ => Err(error::Expression::UnknownSymbol(s)),
}
}
(Binary::Add, Term::Str(s1), Term::Str(s2)) => {
match (symbols.get_symbol(s1), symbols.get_symbol(s2)) {
(Some(s1), Some(s2)) => {
let s = format!("{}{}", s1, s2);
let sym = symbols.insert(&s);
Ok(Term::Str(sym))
}
(Some(_), None) => Err(error::Expression::UnknownSymbol(s2)),
_ => Err(error::Expression::UnknownSymbol(s1)),
}
}
(Binary::Equal, Term::Str(i), Term::Str(j)) => Ok(Term::Bool(i == j)),
(Binary::NotEqual, Term::Str(i), Term::Str(j)) => Ok(Term::Bool(i != j)),

// date
(Binary::LessThan, Term::Date(i), Term::Date(j)) => Ok(Term::Bool(i < j)),
(Binary::GreaterThan, Term::Date(i), Term::Date(j)) => Ok(Term::Bool(i > j)),
(Binary::LessOrEqual, Term::Date(i), Term::Date(j)) => Ok(Term::Bool(i <= j)),
(Binary::GreaterOrEqual, Term::Date(i), Term::Date(j)) => Ok(Term::Bool(i >= j)),
(Binary::Equal, Term::Date(i), Term::Date(j)) => Ok(Term::Bool(i == j)),
(Binary::NotEqual, Term::Date(i), Term::Date(j)) => Ok(Term::Bool(i != j)),

// symbol

// byte array
(Binary::Equal, Term::Bytes(i), Term::Bytes(j)) => Ok(Term::Bool(i == j)),
(Binary::NotEqual, Term::Bytes(i), Term::Bytes(j)) => Ok(Term::Bool(i != j)),

// set
(Binary::Equal, Term::Set(set), Term::Set(s)) => Ok(Term::Bool(set == s)),
(Binary::NotEqual, Term::Set(set), Term::Set(s)) => Ok(Term::Bool(set != s)),
(Binary::Intersection, Term::Set(set), Term::Set(s)) => {
Ok(Term::Set(set.intersection(&s).cloned().collect()))
}
(Binary::Union, Term::Set(set), Term::Set(s)) => {
Ok(Term::Set(set.union(&s).cloned().collect()))
}
(Binary::Contains, Term::Set(set), Term::Set(s)) => {
Ok(Term::Bool(set.is_superset(&s)))
}
(Binary::Contains, Term::Set(set), Term::Integer(i)) => {
Ok(Term::Bool(set.contains(&Term::Integer(i))))
}
(Binary::Contains, Term::Set(set), Term::Date(i)) => {
Ok(Term::Bool(set.contains(&Term::Date(i))))
}
(Binary::Contains, Term::Set(set), Term::Bool(i)) => {
Ok(Term::Bool(set.contains(&Term::Bool(i))))
}
(Binary::Contains, Term::Set(set), Term::Str(i)) => {
Ok(Term::Bool(set.contains(&Term::Str(i))))
}
(Binary::Contains, Term::Set(set), Term::Bytes(i)) => {
Ok(Term::Bool(set.contains(&Term::Bytes(i))))
}

// boolean
(Binary::And, Term::Bool(i), Term::Bool(j)) => Ok(Term::Bool(i & j)),
(Binary::Or, Term::Bool(i), Term::Bool(j)) => Ok(Term::Bool(i | j)),
(Binary::Equal, Term::Bool(i), Term::Bool(j)) => Ok(Term::Bool(i == j)),
(Binary::NotEqual, Term::Bool(i), Term::Bool(j)) => Ok(Term::Bool(i != j)),
// boolean
(Binary::Equal, Term::Bool(i), Term::Bool(j)) => Ok(Term::Bool(i == j)),
(Binary::NotEqual, Term::Bool(i), Term::Bool(j)) => Ok(Term::Bool(i != j)),

_ => {
//println!("unexpected value type on the stack");
Err(error::Expression::InvalidType)
}
_ => {
//println!("unexpected value type on the stack");
Err(error::Expression::InvalidType)
}
},
}
}

Expand Down Expand Up @@ -248,29 +276,33 @@ impl Expression {
values: &HashMap<u32, Term>,
symbols: &mut TemporarySymbolTable,
) -> Result<Term, error::Expression> {
let mut stack: Vec<Term> = Vec::new();
let mut stack: Vec<Result<Term, error::Expression>> = Vec::new();

for op in self.ops.iter() {
//println!("op: {:?}\t| stack: {:?}", op, stack);
match op {
Op::Value(Term::Variable(i)) => match values.get(i) {
Some(term) => stack.push(term.clone()),
Some(term) => stack.push(Ok(term.clone())),
None => {
//println!("unknown variable {}", i);
return Err(error::Expression::UnknownVariable(*i));
}
},
Op::Value(term) => stack.push(term.clone()),
Op::Value(term) => stack.push(Ok(term.clone())),
Op::Unary(unary) => match stack.pop() {
None => {
//println!("expected a value on the stack");
return Err(error::Expression::InvalidStack);
}
Some(term) => stack.push(unary.evaluate(term, symbols)?),
// since && and || are short-circuiting, we don't abort computation as soon as there is an error. Instead we stack it, to allow further calls to && or || to discard it
Some(Err(e)) => stack.push(Err(e)),
Some(Ok(term)) => stack.push(unary.evaluate(term, symbols)),
},
Op::Binary(binary) => match (stack.pop(), stack.pop()) {
// since && and || are short-circuiting, we don't abort computation as soon as there is an error. Instead we stack it, to allow further calls to && or || to discard it.
// && and || are non-strict on the second argument only, so the left-term is always handled
(Some(right_term), Some(left_term)) => {
stack.push(binary.evaluate(left_term, right_term, symbols)?)
stack.push(binary.evaluate(left_term?, right_term, symbols))
}

_ => {
Expand All @@ -282,7 +314,7 @@ impl Expression {
}

if stack.len() == 1 {
Ok(stack.remove(0))
stack.remove(0)
} else {
Err(error::Expression::InvalidStack)
}
Expand Down Expand Up @@ -346,6 +378,31 @@ mod tests {
assert_eq!(res, Ok(Term::Bool(true)));
}

#[test]
fn boolean_short_circuit() {
    // Postfix program for `true || (false > false)`.
    // `>` has no arm for boolean operands, so the right-hand side evaluates
    // to an error. The expression can only yield `true` if `||` discards
    // that error because its left-hand side is already `true`.
    let ops = vec![
        Op::Value(Term::Bool(true)),
        Op::Value(Term::Bool(false)),
        Op::Value(Term::Bool(false)),
        Op::Binary(Binary::GreaterThan),
        Op::Unary(Unary::Parens),
        Op::Binary(Binary::Or),
    ];
    println!("ops: {:?}", ops);

    // Symbol tables: the base table is used for printing, the temporary one
    // is handed to the evaluator.
    let symbols = SymbolTable::new();
    let mut scratch_symbols = TemporarySymbolTable::new(&symbols);

    let expression = Expression { ops };
    println!("print: {}", expression.print(&symbols).unwrap());

    // The expression references no variables, so the bindings stay empty.
    let bindings: HashMap<u32, Term> = HashMap::new();

    let outcome = expression.evaluate(&bindings, &mut scratch_symbols);
    assert_eq!(outcome, Ok(Term::Bool(true)));
}

#[test]
fn bitwise() {
for (op, v1, v2, expected) in [
Expand Down

0 comments on commit 1eeec57

Please sign in to comment.