Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

wip: short-circuiting boolean operators #188

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
301 changes: 179 additions & 122 deletions biscuit-auth/src/datalog/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,132 +86,160 @@ impl Binary {
fn evaluate(
&self,
left: Term,
right: Term,
right: Result<Term, error::Expression>,
symbols: &mut TemporarySymbolTable,
) -> Result<Term, error::Expression> {
match (self, left, right) {
// integer
(Binary::LessThan, Term::Integer(i), Term::Integer(j)) => Ok(Term::Bool(i < j)),
(Binary::GreaterThan, Term::Integer(i), Term::Integer(j)) => Ok(Term::Bool(i > j)),
(Binary::LessOrEqual, Term::Integer(i), Term::Integer(j)) => Ok(Term::Bool(i <= j)),
(Binary::GreaterOrEqual, Term::Integer(i), Term::Integer(j)) => Ok(Term::Bool(i >= j)),
(Binary::Equal, Term::Integer(i), Term::Integer(j)) => Ok(Term::Bool(i == j)),
(Binary::NotEqual, Term::Integer(i), Term::Integer(j)) => Ok(Term::Bool(i != j)),
(Binary::Add, Term::Integer(i), Term::Integer(j)) => i
.checked_add(j)
.map(Term::Integer)
.ok_or(error::Expression::Overflow),
(Binary::Sub, Term::Integer(i), Term::Integer(j)) => i
.checked_sub(j)
.map(Term::Integer)
.ok_or(error::Expression::Overflow),
(Binary::Mul, Term::Integer(i), Term::Integer(j)) => i
.checked_mul(j)
.map(Term::Integer)
.ok_or(error::Expression::Overflow),
(Binary::Div, Term::Integer(i), Term::Integer(j)) => i
.checked_div(j)
.map(Term::Integer)
.ok_or(error::Expression::DivideByZero),
(Binary::BitwiseAnd, Term::Integer(i), Term::Integer(j)) => Ok(Term::Integer(i & j)),
(Binary::BitwiseOr, Term::Integer(i), Term::Integer(j)) => Ok(Term::Integer(i | j)),
(Binary::BitwiseXor, Term::Integer(i), Term::Integer(j)) => Ok(Term::Integer(i ^ j)),

// string
(Binary::Prefix, Term::Str(s), Term::Str(pref)) => {
match (symbols.get_symbol(s), symbols.get_symbol(pref)) {
(Some(s), Some(pref)) => Ok(Term::Bool(s.starts_with(pref))),
(Some(_), None) => Err(error::Expression::UnknownSymbol(pref)),
_ => Err(error::Expression::UnknownSymbol(s)),
// && and || are short-circuiting operators: depending on the left-hand side, the right-hand side may not be evaluated
match (self, &left) {
(Binary::And, Term::Bool(i)) => Ok(Term::Bool(
*i && {
match right? {
Term::Bool(j) => Ok(j),
_ => Err(error::Expression::InvalidType),
}?
},
)),
(Binary::Or, Term::Bool(i)) => Ok(Term::Bool(
*i || {
match right? {
Term::Bool(j) => Ok(j),
_ => Err(error::Expression::InvalidType),
}?
},
)),

// (Binary::And, Term::Bool(i), Term::Bool(j)) => Ok(Term::Bool(i & j)),
// (Binary::Or, Term::Bool(i), Term::Bool(j)) => Ok(Term::Bool(i | j)),
_ => match (self, left, right?) {
// integer
(Binary::LessThan, Term::Integer(i), Term::Integer(j)) => Ok(Term::Bool(i < j)),
(Binary::GreaterThan, Term::Integer(i), Term::Integer(j)) => Ok(Term::Bool(i > j)),
(Binary::LessOrEqual, Term::Integer(i), Term::Integer(j)) => Ok(Term::Bool(i <= j)),
(Binary::GreaterOrEqual, Term::Integer(i), Term::Integer(j)) => {
Ok(Term::Bool(i >= j))
}
}
(Binary::Suffix, Term::Str(s), Term::Str(suff)) => {
match (symbols.get_symbol(s), symbols.get_symbol(suff)) {
(Some(s), Some(suff)) => Ok(Term::Bool(s.ends_with(suff))),
(Some(_), None) => Err(error::Expression::UnknownSymbol(suff)),
_ => Err(error::Expression::UnknownSymbol(s)),
(Binary::Equal, Term::Integer(i), Term::Integer(j)) => Ok(Term::Bool(i == j)),
(Binary::NotEqual, Term::Integer(i), Term::Integer(j)) => Ok(Term::Bool(i != j)),
(Binary::Add, Term::Integer(i), Term::Integer(j)) => i
.checked_add(j)
.map(Term::Integer)
.ok_or(error::Expression::Overflow),
(Binary::Sub, Term::Integer(i), Term::Integer(j)) => i
.checked_sub(j)
.map(Term::Integer)
.ok_or(error::Expression::Overflow),
(Binary::Mul, Term::Integer(i), Term::Integer(j)) => i
.checked_mul(j)
.map(Term::Integer)
.ok_or(error::Expression::Overflow),
(Binary::Div, Term::Integer(i), Term::Integer(j)) => i
.checked_div(j)
.map(Term::Integer)
.ok_or(error::Expression::DivideByZero),
(Binary::BitwiseAnd, Term::Integer(i), Term::Integer(j)) => {
Ok(Term::Integer(i & j))
}
}
(Binary::Regex, Term::Str(s), Term::Str(r)) => {
match (symbols.get_symbol(s), symbols.get_symbol(r)) {
(Some(s), Some(r)) => Ok(Term::Bool(
Regex::new(r).map(|re| re.is_match(s)).unwrap_or(false),
)),
(Some(_), None) => Err(error::Expression::UnknownSymbol(r)),
_ => Err(error::Expression::UnknownSymbol(s)),
(Binary::BitwiseOr, Term::Integer(i), Term::Integer(j)) => Ok(Term::Integer(i | j)),
(Binary::BitwiseXor, Term::Integer(i), Term::Integer(j)) => {
Ok(Term::Integer(i ^ j))
}
}
(Binary::Contains, Term::Str(s), Term::Str(pattern)) => {
match (symbols.get_symbol(s), symbols.get_symbol(pattern)) {
(Some(s), Some(pattern)) => Ok(Term::Bool(s.contains(pattern))),
(Some(_), None) => Err(error::Expression::UnknownSymbol(pattern)),
_ => Err(error::Expression::UnknownSymbol(s)),

// string
(Binary::Prefix, Term::Str(s), Term::Str(pref)) => {
match (symbols.get_symbol(s), symbols.get_symbol(pref)) {
(Some(s), Some(pref)) => Ok(Term::Bool(s.starts_with(pref))),
(Some(_), None) => Err(error::Expression::UnknownSymbol(pref)),
_ => Err(error::Expression::UnknownSymbol(s)),
}
}
}
(Binary::Add, Term::Str(s1), Term::Str(s2)) => {
match (symbols.get_symbol(s1), symbols.get_symbol(s2)) {
(Some(s1), Some(s2)) => {
let s = format!("{}{}", s1, s2);
let sym = symbols.insert(&s);
Ok(Term::Str(sym))
(Binary::Suffix, Term::Str(s), Term::Str(suff)) => {
match (symbols.get_symbol(s), symbols.get_symbol(suff)) {
(Some(s), Some(suff)) => Ok(Term::Bool(s.ends_with(suff))),
(Some(_), None) => Err(error::Expression::UnknownSymbol(suff)),
_ => Err(error::Expression::UnknownSymbol(s)),
}
(Some(_), None) => Err(error::Expression::UnknownSymbol(s2)),
_ => Err(error::Expression::UnknownSymbol(s1)),
}
}
(Binary::Equal, Term::Str(i), Term::Str(j)) => Ok(Term::Bool(i == j)),
(Binary::NotEqual, Term::Str(i), Term::Str(j)) => Ok(Term::Bool(i != j)),

// date
(Binary::LessThan, Term::Date(i), Term::Date(j)) => Ok(Term::Bool(i < j)),
(Binary::GreaterThan, Term::Date(i), Term::Date(j)) => Ok(Term::Bool(i > j)),
(Binary::LessOrEqual, Term::Date(i), Term::Date(j)) => Ok(Term::Bool(i <= j)),
(Binary::GreaterOrEqual, Term::Date(i), Term::Date(j)) => Ok(Term::Bool(i >= j)),
(Binary::Equal, Term::Date(i), Term::Date(j)) => Ok(Term::Bool(i == j)),
(Binary::NotEqual, Term::Date(i), Term::Date(j)) => Ok(Term::Bool(i != j)),

// symbol

// byte array
(Binary::Equal, Term::Bytes(i), Term::Bytes(j)) => Ok(Term::Bool(i == j)),
(Binary::NotEqual, Term::Bytes(i), Term::Bytes(j)) => Ok(Term::Bool(i != j)),

// set
(Binary::Equal, Term::Set(set), Term::Set(s)) => Ok(Term::Bool(set == s)),
(Binary::NotEqual, Term::Set(set), Term::Set(s)) => Ok(Term::Bool(set != s)),
(Binary::Intersection, Term::Set(set), Term::Set(s)) => {
Ok(Term::Set(set.intersection(&s).cloned().collect()))
}
(Binary::Union, Term::Set(set), Term::Set(s)) => {
Ok(Term::Set(set.union(&s).cloned().collect()))
}
(Binary::Contains, Term::Set(set), Term::Set(s)) => Ok(Term::Bool(set.is_superset(&s))),
(Binary::Contains, Term::Set(set), Term::Integer(i)) => {
Ok(Term::Bool(set.contains(&Term::Integer(i))))
}
(Binary::Contains, Term::Set(set), Term::Date(i)) => {
Ok(Term::Bool(set.contains(&Term::Date(i))))
}
(Binary::Contains, Term::Set(set), Term::Bool(i)) => {
Ok(Term::Bool(set.contains(&Term::Bool(i))))
}
(Binary::Contains, Term::Set(set), Term::Str(i)) => {
Ok(Term::Bool(set.contains(&Term::Str(i))))
}
(Binary::Contains, Term::Set(set), Term::Bytes(i)) => {
Ok(Term::Bool(set.contains(&Term::Bytes(i))))
}
(Binary::Regex, Term::Str(s), Term::Str(r)) => {
match (symbols.get_symbol(s), symbols.get_symbol(r)) {
(Some(s), Some(r)) => Ok(Term::Bool(
Regex::new(r).map(|re| re.is_match(s)).unwrap_or(false),
)),
(Some(_), None) => Err(error::Expression::UnknownSymbol(r)),
_ => Err(error::Expression::UnknownSymbol(s)),
}
}
(Binary::Contains, Term::Str(s), Term::Str(pattern)) => {
match (symbols.get_symbol(s), symbols.get_symbol(pattern)) {
(Some(s), Some(pattern)) => Ok(Term::Bool(s.contains(pattern))),
(Some(_), None) => Err(error::Expression::UnknownSymbol(pattern)),
_ => Err(error::Expression::UnknownSymbol(s)),
}
}
(Binary::Add, Term::Str(s1), Term::Str(s2)) => {
match (symbols.get_symbol(s1), symbols.get_symbol(s2)) {
(Some(s1), Some(s2)) => {
let s = format!("{}{}", s1, s2);
let sym = symbols.insert(&s);
Ok(Term::Str(sym))
}
(Some(_), None) => Err(error::Expression::UnknownSymbol(s2)),
_ => Err(error::Expression::UnknownSymbol(s1)),
}
}
(Binary::Equal, Term::Str(i), Term::Str(j)) => Ok(Term::Bool(i == j)),
(Binary::NotEqual, Term::Str(i), Term::Str(j)) => Ok(Term::Bool(i != j)),

// date
(Binary::LessThan, Term::Date(i), Term::Date(j)) => Ok(Term::Bool(i < j)),
(Binary::GreaterThan, Term::Date(i), Term::Date(j)) => Ok(Term::Bool(i > j)),
(Binary::LessOrEqual, Term::Date(i), Term::Date(j)) => Ok(Term::Bool(i <= j)),
(Binary::GreaterOrEqual, Term::Date(i), Term::Date(j)) => Ok(Term::Bool(i >= j)),
(Binary::Equal, Term::Date(i), Term::Date(j)) => Ok(Term::Bool(i == j)),
(Binary::NotEqual, Term::Date(i), Term::Date(j)) => Ok(Term::Bool(i != j)),

// symbol

// byte array
(Binary::Equal, Term::Bytes(i), Term::Bytes(j)) => Ok(Term::Bool(i == j)),
(Binary::NotEqual, Term::Bytes(i), Term::Bytes(j)) => Ok(Term::Bool(i != j)),

// set
(Binary::Equal, Term::Set(set), Term::Set(s)) => Ok(Term::Bool(set == s)),
(Binary::NotEqual, Term::Set(set), Term::Set(s)) => Ok(Term::Bool(set != s)),
(Binary::Intersection, Term::Set(set), Term::Set(s)) => {
Ok(Term::Set(set.intersection(&s).cloned().collect()))
}
(Binary::Union, Term::Set(set), Term::Set(s)) => {
Ok(Term::Set(set.union(&s).cloned().collect()))
}
(Binary::Contains, Term::Set(set), Term::Set(s)) => {
Ok(Term::Bool(set.is_superset(&s)))
}
(Binary::Contains, Term::Set(set), Term::Integer(i)) => {
Ok(Term::Bool(set.contains(&Term::Integer(i))))
}
(Binary::Contains, Term::Set(set), Term::Date(i)) => {
Ok(Term::Bool(set.contains(&Term::Date(i))))
}
(Binary::Contains, Term::Set(set), Term::Bool(i)) => {
Ok(Term::Bool(set.contains(&Term::Bool(i))))
}
(Binary::Contains, Term::Set(set), Term::Str(i)) => {
Ok(Term::Bool(set.contains(&Term::Str(i))))
}
(Binary::Contains, Term::Set(set), Term::Bytes(i)) => {
Ok(Term::Bool(set.contains(&Term::Bytes(i))))
}

// boolean
(Binary::And, Term::Bool(i), Term::Bool(j)) => Ok(Term::Bool(i & j)),
(Binary::Or, Term::Bool(i), Term::Bool(j)) => Ok(Term::Bool(i | j)),
(Binary::Equal, Term::Bool(i), Term::Bool(j)) => Ok(Term::Bool(i == j)),
(Binary::NotEqual, Term::Bool(i), Term::Bool(j)) => Ok(Term::Bool(i != j)),
// boolean
(Binary::Equal, Term::Bool(i), Term::Bool(j)) => Ok(Term::Bool(i == j)),
(Binary::NotEqual, Term::Bool(i), Term::Bool(j)) => Ok(Term::Bool(i != j)),

_ => {
//println!("unexpected value type on the stack");
Err(error::Expression::InvalidType)
}
_ => {
//println!("unexpected value type on the stack");
Err(error::Expression::InvalidType)
}
},
}
}

Expand Down Expand Up @@ -248,29 +276,33 @@ impl Expression {
values: &HashMap<u32, Term>,
symbols: &mut TemporarySymbolTable,
) -> Result<Term, error::Expression> {
let mut stack: Vec<Term> = Vec::new();
let mut stack: Vec<Result<Term, error::Expression>> = Vec::new();

for op in self.ops.iter() {
//println!("op: {:?}\t| stack: {:?}", op, stack);
match op {
Op::Value(Term::Variable(i)) => match values.get(i) {
Some(term) => stack.push(term.clone()),
Some(term) => stack.push(Ok(term.clone())),
None => {
//println!("unknown variable {}", i);
return Err(error::Expression::UnknownVariable(*i));
}
},
Op::Value(term) => stack.push(term.clone()),
Op::Value(term) => stack.push(Ok(term.clone())),
Op::Unary(unary) => match stack.pop() {
None => {
//println!("expected a value on the stack");
return Err(error::Expression::InvalidStack);
}
Some(term) => stack.push(unary.evaluate(term, symbols)?),
// since && and || are short-circuiting, we don't abort computation as soon as there is an error. Instead we stack it, to allow further calls to && or || to discard it
Some(Err(e)) => stack.push(Err(e)),
Some(Ok(term)) => stack.push(unary.evaluate(term, symbols)),
},
Op::Binary(binary) => match (stack.pop(), stack.pop()) {
// since && and || are short-circuiting, we don't abort computation as soon as there is an error. Instead we stack it, to allow further calls to && or || to discard it.
// && and || are non-strict on the second argument only, so the left-term is always handled
(Some(right_term), Some(left_term)) => {
stack.push(binary.evaluate(left_term, right_term, symbols)?)
stack.push(binary.evaluate(left_term?, right_term, symbols))
}

_ => {
Expand All @@ -282,7 +314,7 @@ impl Expression {
}

if stack.len() == 1 {
Ok(stack.remove(0))
stack.remove(0)
} else {
Err(error::Expression::InvalidStack)
}
Expand Down Expand Up @@ -346,6 +378,31 @@ mod tests {
assert_eq!(res, Ok(Term::Bool(true)));
}

#[test]
fn boolean_short_circuit() {
let symbols = SymbolTable::new();
let mut tmp_symbols = TemporarySymbolTable::new(&symbols);

let ops = vec![
Op::Value(Term::Bool(true)),
Op::Value(Term::Bool(false)),
Op::Value(Term::Bool(false)),
Op::Binary(Binary::GreaterThan),
Op::Unary(Unary::Parens),
Op::Binary(Binary::Or),
];

let values: HashMap<u32, Term> = Default::default();

println!("ops: {:?}", ops);

let e = Expression { ops };
println!("print: {}", e.print(&symbols).unwrap());

let res = e.evaluate(&values, &mut tmp_symbols);
assert_eq!(res, Ok(Term::Bool(true)));
}

#[test]
fn bitwise() {
for (op, v1, v2, expected) in [
Expand Down
Loading