From 1eeec574476df4201e17465647136474ab9af428 Mon Sep 17 00:00:00 2001 From: Clement Delafargue Date: Wed, 11 Oct 2023 15:28:46 +0200 Subject: [PATCH] wip: short-circuiting boolean operators Even though the semantics of the stack machine are eager, what we actually care about is && and || being non-strict (rather than lazy). The difference between non-strict and lazy is not observable from the outside since datalog cannot perform side-effects. --- biscuit-auth/src/datalog/expression.rs | 301 +++++++++++++++---------- 1 file changed, 179 insertions(+), 122 deletions(-) diff --git a/biscuit-auth/src/datalog/expression.rs b/biscuit-auth/src/datalog/expression.rs index 8a5dd13a..49463601 100644 --- a/biscuit-auth/src/datalog/expression.rs +++ b/biscuit-auth/src/datalog/expression.rs @@ -86,132 +86,160 @@ impl Binary { fn evaluate( &self, left: Term, - right: Term, + right: Result, symbols: &mut TemporarySymbolTable, ) -> Result { - match (self, left, right) { - // integer - (Binary::LessThan, Term::Integer(i), Term::Integer(j)) => Ok(Term::Bool(i < j)), - (Binary::GreaterThan, Term::Integer(i), Term::Integer(j)) => Ok(Term::Bool(i > j)), - (Binary::LessOrEqual, Term::Integer(i), Term::Integer(j)) => Ok(Term::Bool(i <= j)), - (Binary::GreaterOrEqual, Term::Integer(i), Term::Integer(j)) => Ok(Term::Bool(i >= j)), - (Binary::Equal, Term::Integer(i), Term::Integer(j)) => Ok(Term::Bool(i == j)), - (Binary::NotEqual, Term::Integer(i), Term::Integer(j)) => Ok(Term::Bool(i != j)), - (Binary::Add, Term::Integer(i), Term::Integer(j)) => i - .checked_add(j) - .map(Term::Integer) - .ok_or(error::Expression::Overflow), - (Binary::Sub, Term::Integer(i), Term::Integer(j)) => i - .checked_sub(j) - .map(Term::Integer) - .ok_or(error::Expression::Overflow), - (Binary::Mul, Term::Integer(i), Term::Integer(j)) => i - .checked_mul(j) - .map(Term::Integer) - .ok_or(error::Expression::Overflow), - (Binary::Div, Term::Integer(i), Term::Integer(j)) => i - .checked_div(j) - .map(Term::Integer) - 
.ok_or(error::Expression::DivideByZero), - (Binary::BitwiseAnd, Term::Integer(i), Term::Integer(j)) => Ok(Term::Integer(i & j)), - (Binary::BitwiseOr, Term::Integer(i), Term::Integer(j)) => Ok(Term::Integer(i | j)), - (Binary::BitwiseXor, Term::Integer(i), Term::Integer(j)) => Ok(Term::Integer(i ^ j)), - - // string - (Binary::Prefix, Term::Str(s), Term::Str(pref)) => { - match (symbols.get_symbol(s), symbols.get_symbol(pref)) { - (Some(s), Some(pref)) => Ok(Term::Bool(s.starts_with(pref))), - (Some(_), None) => Err(error::Expression::UnknownSymbol(pref)), - _ => Err(error::Expression::UnknownSymbol(s)), + // && and || are short-circuiting operators: depending on the left-hand side, the right-hand side may not be evaluated + match (self, &left) { + (Binary::And, Term::Bool(i)) => Ok(Term::Bool( + *i && { + match right? { + Term::Bool(j) => Ok(j), + _ => Err(error::Expression::InvalidType), + }? + }, + )), + (Binary::Or, Term::Bool(i)) => Ok(Term::Bool( + *i || { + match right? { + Term::Bool(j) => Ok(j), + _ => Err(error::Expression::InvalidType), + }? + }, + )), + + // (Binary::And, Term::Bool(i), Term::Bool(j)) => Ok(Term::Bool(i & j)), + // (Binary::Or, Term::Bool(i), Term::Bool(j)) => Ok(Term::Bool(i | j)), + _ => match (self, left, right?) 
{ + // integer + (Binary::LessThan, Term::Integer(i), Term::Integer(j)) => Ok(Term::Bool(i < j)), + (Binary::GreaterThan, Term::Integer(i), Term::Integer(j)) => Ok(Term::Bool(i > j)), + (Binary::LessOrEqual, Term::Integer(i), Term::Integer(j)) => Ok(Term::Bool(i <= j)), + (Binary::GreaterOrEqual, Term::Integer(i), Term::Integer(j)) => { + Ok(Term::Bool(i >= j)) } - } - (Binary::Suffix, Term::Str(s), Term::Str(suff)) => { - match (symbols.get_symbol(s), symbols.get_symbol(suff)) { - (Some(s), Some(suff)) => Ok(Term::Bool(s.ends_with(suff))), - (Some(_), None) => Err(error::Expression::UnknownSymbol(suff)), - _ => Err(error::Expression::UnknownSymbol(s)), + (Binary::Equal, Term::Integer(i), Term::Integer(j)) => Ok(Term::Bool(i == j)), + (Binary::NotEqual, Term::Integer(i), Term::Integer(j)) => Ok(Term::Bool(i != j)), + (Binary::Add, Term::Integer(i), Term::Integer(j)) => i + .checked_add(j) + .map(Term::Integer) + .ok_or(error::Expression::Overflow), + (Binary::Sub, Term::Integer(i), Term::Integer(j)) => i + .checked_sub(j) + .map(Term::Integer) + .ok_or(error::Expression::Overflow), + (Binary::Mul, Term::Integer(i), Term::Integer(j)) => i + .checked_mul(j) + .map(Term::Integer) + .ok_or(error::Expression::Overflow), + (Binary::Div, Term::Integer(i), Term::Integer(j)) => i + .checked_div(j) + .map(Term::Integer) + .ok_or(error::Expression::DivideByZero), + (Binary::BitwiseAnd, Term::Integer(i), Term::Integer(j)) => { + Ok(Term::Integer(i & j)) } - } - (Binary::Regex, Term::Str(s), Term::Str(r)) => { - match (symbols.get_symbol(s), symbols.get_symbol(r)) { - (Some(s), Some(r)) => Ok(Term::Bool( - Regex::new(r).map(|re| re.is_match(s)).unwrap_or(false), - )), - (Some(_), None) => Err(error::Expression::UnknownSymbol(r)), - _ => Err(error::Expression::UnknownSymbol(s)), + (Binary::BitwiseOr, Term::Integer(i), Term::Integer(j)) => Ok(Term::Integer(i | j)), + (Binary::BitwiseXor, Term::Integer(i), Term::Integer(j)) => { + Ok(Term::Integer(i ^ j)) } - } - 
(Binary::Contains, Term::Str(s), Term::Str(pattern)) => { - match (symbols.get_symbol(s), symbols.get_symbol(pattern)) { - (Some(s), Some(pattern)) => Ok(Term::Bool(s.contains(pattern))), - (Some(_), None) => Err(error::Expression::UnknownSymbol(pattern)), - _ => Err(error::Expression::UnknownSymbol(s)), + + // string + (Binary::Prefix, Term::Str(s), Term::Str(pref)) => { + match (symbols.get_symbol(s), symbols.get_symbol(pref)) { + (Some(s), Some(pref)) => Ok(Term::Bool(s.starts_with(pref))), + (Some(_), None) => Err(error::Expression::UnknownSymbol(pref)), + _ => Err(error::Expression::UnknownSymbol(s)), + } } - } - (Binary::Add, Term::Str(s1), Term::Str(s2)) => { - match (symbols.get_symbol(s1), symbols.get_symbol(s2)) { - (Some(s1), Some(s2)) => { - let s = format!("{}{}", s1, s2); - let sym = symbols.insert(&s); - Ok(Term::Str(sym)) + (Binary::Suffix, Term::Str(s), Term::Str(suff)) => { + match (symbols.get_symbol(s), symbols.get_symbol(suff)) { + (Some(s), Some(suff)) => Ok(Term::Bool(s.ends_with(suff))), + (Some(_), None) => Err(error::Expression::UnknownSymbol(suff)), + _ => Err(error::Expression::UnknownSymbol(s)), } - (Some(_), None) => Err(error::Expression::UnknownSymbol(s2)), - _ => Err(error::Expression::UnknownSymbol(s1)), } - } - (Binary::Equal, Term::Str(i), Term::Str(j)) => Ok(Term::Bool(i == j)), - (Binary::NotEqual, Term::Str(i), Term::Str(j)) => Ok(Term::Bool(i != j)), - - // date - (Binary::LessThan, Term::Date(i), Term::Date(j)) => Ok(Term::Bool(i < j)), - (Binary::GreaterThan, Term::Date(i), Term::Date(j)) => Ok(Term::Bool(i > j)), - (Binary::LessOrEqual, Term::Date(i), Term::Date(j)) => Ok(Term::Bool(i <= j)), - (Binary::GreaterOrEqual, Term::Date(i), Term::Date(j)) => Ok(Term::Bool(i >= j)), - (Binary::Equal, Term::Date(i), Term::Date(j)) => Ok(Term::Bool(i == j)), - (Binary::NotEqual, Term::Date(i), Term::Date(j)) => Ok(Term::Bool(i != j)), - - // symbol - - // byte array - (Binary::Equal, Term::Bytes(i), Term::Bytes(j)) => 
Ok(Term::Bool(i == j)), - (Binary::NotEqual, Term::Bytes(i), Term::Bytes(j)) => Ok(Term::Bool(i != j)), - - // set - (Binary::Equal, Term::Set(set), Term::Set(s)) => Ok(Term::Bool(set == s)), - (Binary::NotEqual, Term::Set(set), Term::Set(s)) => Ok(Term::Bool(set != s)), - (Binary::Intersection, Term::Set(set), Term::Set(s)) => { - Ok(Term::Set(set.intersection(&s).cloned().collect())) - } - (Binary::Union, Term::Set(set), Term::Set(s)) => { - Ok(Term::Set(set.union(&s).cloned().collect())) - } - (Binary::Contains, Term::Set(set), Term::Set(s)) => Ok(Term::Bool(set.is_superset(&s))), - (Binary::Contains, Term::Set(set), Term::Integer(i)) => { - Ok(Term::Bool(set.contains(&Term::Integer(i)))) - } - (Binary::Contains, Term::Set(set), Term::Date(i)) => { - Ok(Term::Bool(set.contains(&Term::Date(i)))) - } - (Binary::Contains, Term::Set(set), Term::Bool(i)) => { - Ok(Term::Bool(set.contains(&Term::Bool(i)))) - } - (Binary::Contains, Term::Set(set), Term::Str(i)) => { - Ok(Term::Bool(set.contains(&Term::Str(i)))) - } - (Binary::Contains, Term::Set(set), Term::Bytes(i)) => { - Ok(Term::Bool(set.contains(&Term::Bytes(i)))) - } + (Binary::Regex, Term::Str(s), Term::Str(r)) => { + match (symbols.get_symbol(s), symbols.get_symbol(r)) { + (Some(s), Some(r)) => Ok(Term::Bool( + Regex::new(r).map(|re| re.is_match(s)).unwrap_or(false), + )), + (Some(_), None) => Err(error::Expression::UnknownSymbol(r)), + _ => Err(error::Expression::UnknownSymbol(s)), + } + } + (Binary::Contains, Term::Str(s), Term::Str(pattern)) => { + match (symbols.get_symbol(s), symbols.get_symbol(pattern)) { + (Some(s), Some(pattern)) => Ok(Term::Bool(s.contains(pattern))), + (Some(_), None) => Err(error::Expression::UnknownSymbol(pattern)), + _ => Err(error::Expression::UnknownSymbol(s)), + } + } + (Binary::Add, Term::Str(s1), Term::Str(s2)) => { + match (symbols.get_symbol(s1), symbols.get_symbol(s2)) { + (Some(s1), Some(s2)) => { + let s = format!("{}{}", s1, s2); + let sym = symbols.insert(&s); + 
Ok(Term::Str(sym)) + } + (Some(_), None) => Err(error::Expression::UnknownSymbol(s2)), + _ => Err(error::Expression::UnknownSymbol(s1)), + } + } + (Binary::Equal, Term::Str(i), Term::Str(j)) => Ok(Term::Bool(i == j)), + (Binary::NotEqual, Term::Str(i), Term::Str(j)) => Ok(Term::Bool(i != j)), + + // date + (Binary::LessThan, Term::Date(i), Term::Date(j)) => Ok(Term::Bool(i < j)), + (Binary::GreaterThan, Term::Date(i), Term::Date(j)) => Ok(Term::Bool(i > j)), + (Binary::LessOrEqual, Term::Date(i), Term::Date(j)) => Ok(Term::Bool(i <= j)), + (Binary::GreaterOrEqual, Term::Date(i), Term::Date(j)) => Ok(Term::Bool(i >= j)), + (Binary::Equal, Term::Date(i), Term::Date(j)) => Ok(Term::Bool(i == j)), + (Binary::NotEqual, Term::Date(i), Term::Date(j)) => Ok(Term::Bool(i != j)), + + // symbol + + // byte array + (Binary::Equal, Term::Bytes(i), Term::Bytes(j)) => Ok(Term::Bool(i == j)), + (Binary::NotEqual, Term::Bytes(i), Term::Bytes(j)) => Ok(Term::Bool(i != j)), + + // set + (Binary::Equal, Term::Set(set), Term::Set(s)) => Ok(Term::Bool(set == s)), + (Binary::NotEqual, Term::Set(set), Term::Set(s)) => Ok(Term::Bool(set != s)), + (Binary::Intersection, Term::Set(set), Term::Set(s)) => { + Ok(Term::Set(set.intersection(&s).cloned().collect())) + } + (Binary::Union, Term::Set(set), Term::Set(s)) => { + Ok(Term::Set(set.union(&s).cloned().collect())) + } + (Binary::Contains, Term::Set(set), Term::Set(s)) => { + Ok(Term::Bool(set.is_superset(&s))) + } + (Binary::Contains, Term::Set(set), Term::Integer(i)) => { + Ok(Term::Bool(set.contains(&Term::Integer(i)))) + } + (Binary::Contains, Term::Set(set), Term::Date(i)) => { + Ok(Term::Bool(set.contains(&Term::Date(i)))) + } + (Binary::Contains, Term::Set(set), Term::Bool(i)) => { + Ok(Term::Bool(set.contains(&Term::Bool(i)))) + } + (Binary::Contains, Term::Set(set), Term::Str(i)) => { + Ok(Term::Bool(set.contains(&Term::Str(i)))) + } + (Binary::Contains, Term::Set(set), Term::Bytes(i)) => { + 
Ok(Term::Bool(set.contains(&Term::Bytes(i)))) + } - // boolean - (Binary::And, Term::Bool(i), Term::Bool(j)) => Ok(Term::Bool(i & j)), - (Binary::Or, Term::Bool(i), Term::Bool(j)) => Ok(Term::Bool(i | j)), - (Binary::Equal, Term::Bool(i), Term::Bool(j)) => Ok(Term::Bool(i == j)), - (Binary::NotEqual, Term::Bool(i), Term::Bool(j)) => Ok(Term::Bool(i != j)), + // boolean + (Binary::Equal, Term::Bool(i), Term::Bool(j)) => Ok(Term::Bool(i == j)), + (Binary::NotEqual, Term::Bool(i), Term::Bool(j)) => Ok(Term::Bool(i != j)), - _ => { - //println!("unexpected value type on the stack"); - Err(error::Expression::InvalidType) - } + _ => { + //println!("unexpected value type on the stack"); + Err(error::Expression::InvalidType) + } + }, } } @@ -248,29 +276,33 @@ impl Expression { values: &HashMap, symbols: &mut TemporarySymbolTable, ) -> Result { - let mut stack: Vec = Vec::new(); + let mut stack: Vec> = Vec::new(); for op in self.ops.iter() { //println!("op: {:?}\t| stack: {:?}", op, stack); match op { Op::Value(Term::Variable(i)) => match values.get(i) { - Some(term) => stack.push(term.clone()), + Some(term) => stack.push(Ok(term.clone())), None => { //println!("unknown variable {}", i); return Err(error::Expression::UnknownVariable(*i)); } }, - Op::Value(term) => stack.push(term.clone()), + Op::Value(term) => stack.push(Ok(term.clone())), Op::Unary(unary) => match stack.pop() { None => { //println!("expected a value on the stack"); return Err(error::Expression::InvalidStack); } - Some(term) => stack.push(unary.evaluate(term, symbols)?), + // since && and || are short-circuiting, we don't abort computation as soon as there is an error. Instead we stack it, to allow further calls to && or || to discard it + Some(Err(e)) => stack.push(Err(e)), + Some(Ok(term)) => stack.push(unary.evaluate(term, symbols)), }, Op::Binary(binary) => match (stack.pop(), stack.pop()) { + // since && and || are short-circuiting, we don't abort computation as soon as there is an error. 
Instead we stack it, to allow further calls to && or || to discard it. + // && and || are non-strict on the second argument only, so the left-term is always handled (Some(right_term), Some(left_term)) => { - stack.push(binary.evaluate(left_term, right_term, symbols)?) + stack.push(binary.evaluate(left_term?, right_term, symbols)) } _ => { @@ -282,7 +314,7 @@ impl Expression { } if stack.len() == 1 { - Ok(stack.remove(0)) + stack.remove(0) } else { Err(error::Expression::InvalidStack) } @@ -346,6 +378,31 @@ mod tests { assert_eq!(res, Ok(Term::Bool(true))); } + #[test] + fn boolean_short_circuit() { + let symbols = SymbolTable::new(); + let mut tmp_symbols = TemporarySymbolTable::new(&symbols); + + let ops = vec![ + Op::Value(Term::Bool(true)), + Op::Value(Term::Bool(false)), + Op::Value(Term::Bool(false)), + Op::Binary(Binary::GreaterThan), + Op::Unary(Unary::Parens), + Op::Binary(Binary::Or), + ]; + + let values: HashMap = Default::default(); + + println!("ops: {:?}", ops); + + let e = Expression { ops }; + println!("print: {}", e.print(&symbols).unwrap()); + + let res = e.evaluate(&values, &mut tmp_symbols); + assert_eq!(res, Ok(Term::Bool(true))); + } + #[test] fn bitwise() { for (op, v1, v2, expected) in [