Skip to content

Commit

Permalink
Implement RFC 57 (#702)
Browse files Browse the repository at this point in the history
Signed-off-by: Craig Disselkoen <[email protected]>
  • Loading branch information
cdisselkoen authored Mar 20, 2024
1 parent d6d5e98 commit 0f9df95
Show file tree
Hide file tree
Showing 17 changed files with 607 additions and 457 deletions.
52 changes: 14 additions & 38 deletions cedar-policy-core/src/ast/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,17 +93,6 @@ pub enum ExprKind<T = ()> {
/// Second arg
arg2: Arc<Expr<T>>,
},
/// Multiplication by constant
///
/// This isn't just a BinaryOp because its arguments aren't both expressions.
/// (Similar to how `like` isn't a BinaryOp and has its own AST node as well.)
MulByConst {
/// first argument, which may be an arbitrary expression, but must
/// evaluate to Long type
arg: Arc<Expr<T>>,
/// second argument, which must be an integer constant
constant: Integer,
},
/// Application of an extension function to n arguments
/// INVARIANT (MethodStyleArgs):
/// if op.style is MethodStyle then args _cannot_ be empty.
Expand Down Expand Up @@ -384,9 +373,9 @@ impl Expr {
ExprBuilder::new().sub(e1, e2)
}

/// Create a 'mul' expression. First argument must evaluate to Long type.
pub fn mul(e: Expr, c: Integer) -> Self {
ExprBuilder::new().mul(e, c)
/// Create a 'mul' expression. Arguments must evaluate to Long type
pub fn mul(e1: Expr, e2: Expr) -> Self {
ExprBuilder::new().mul(e1, e2)
}

/// Create a 'neg' expression. `e` must evaluate to Long type.
Expand Down Expand Up @@ -589,9 +578,6 @@ impl Expr {
Ok(Expr::record(map)
.expect("cannot have a duplicate key because the input was already a BTreeMap"))
}
ExprKind::MulByConst { arg, constant } => {
Ok(Expr::mul(arg.substitute(definitions)?, *constant))
}
ExprKind::Is { expr, entity_type } => Ok(Expr::is_entity_type(
expr.substitute(definitions)?,
entity_type.clone(),
Expand Down Expand Up @@ -862,11 +848,12 @@ impl<T> ExprBuilder<T> {
})
}

/// Create a 'mul' expression. First argument must evaluate to Long type.
pub fn mul(self, e: Expr<T>, c: Integer) -> Expr<T> {
self.with_expr_kind(ExprKind::MulByConst {
arg: Arc::new(e),
constant: c,
/// Create a 'mul' expression. Arguments must evaluate to Long type
pub fn mul(self, e1: Expr<T>, e2: Expr<T>) -> Expr<T> {
self.with_expr_kind(ExprKind::BinaryApp {
op: BinaryOp::Mul,
arg1: Arc::new(e1),
arg2: Arc::new(e2),
})
}

Expand Down Expand Up @@ -1077,12 +1064,12 @@ impl<T: Clone> ExprBuilder<T> {
}

/// Errors when constructing an `Expr`
#[derive(Debug, PartialEq, Diagnostic, Error)]
#[derive(Debug, PartialEq, Eq, Clone, Diagnostic, Error)]
pub enum ExprConstructionError {
/// A key occurred twice (or more) in a record literal
/// The same key occurred two or more times in a single record literal
#[error("duplicate key `{key}` in record literal")]
DuplicateKeyInRecordLiteral {
/// The key which occurred twice (or more) in the record literal
/// The key which occurred two or more times in the record literal
key: SmolStr,
},
}
Expand Down Expand Up @@ -1178,13 +1165,6 @@ impl<T> Expr<T> {
arg2: arg21,
},
) => op == op1 && arg1.eq_shape(arg11) && arg2.eq_shape(arg21),
(
MulByConst { arg, constant },
MulByConst {
arg: arg1,
constant: constant1,
},
) => constant == constant1 && arg.eq_shape(arg1),
(
ExtensionFunctionApp { fn_name, args },
ExtensionFunctionApp {
Expand Down Expand Up @@ -1274,10 +1254,6 @@ impl<T> Expr<T> {
arg1.hash_shape(state);
arg2.hash_shape(state);
}
ExprKind::MulByConst { arg, constant } => {
arg.hash_shape(state);
constant.hash(state);
}
ExprKind::ExtensionFunctionApp { fn_name, args } => {
fn_name.hash(state);
state.write_usize(args.len());
Expand Down Expand Up @@ -1654,8 +1630,8 @@ mod test {
Expr::sub(Expr::val(1), Expr::val(1)),
),
(
ExprBuilder::with_data(1).mul(temp.clone(), 1),
Expr::mul(Expr::val(1), 1),
ExprBuilder::with_data(1).mul(temp.clone(), temp.clone()),
Expr::mul(Expr::val(1), Expr::val(1)),
),
(
ExprBuilder::with_data(1).neg(temp.clone()),
Expand Down
3 changes: 0 additions & 3 deletions cedar-policy-core/src/ast/expr_iterator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,6 @@ impl<'a, T> Iterator for ExprIterator<'a, T> {
self.expression_stack.push(arg1);
self.expression_stack.push(arg2);
}
ExprKind::MulByConst { arg, .. } => {
self.expression_stack.push(arg);
}
ExprKind::ExtensionFunctionApp { args, .. } => {
for arg in args.as_ref() {
self.expression_stack.push(arg);
Expand Down
6 changes: 6 additions & 0 deletions cedar-policy-core/src/ast/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@ pub enum BinaryOp {
/// Arguments must have Long type
Sub,

/// Integer multiplication
///
/// Arguments must have Long type
Mul,

/// Hierarchy membership. Specifically, is the first arg a member of the
/// second.
///
Expand Down Expand Up @@ -103,6 +108,7 @@ impl std::fmt::Display for BinaryOp {
BinaryOp::LessEq => write!(f, "_<=_"),
BinaryOp::Add => write!(f, "_+_"),
BinaryOp::Sub => write!(f, "_-_"),
BinaryOp::Mul => write!(f, "_*_"),
BinaryOp::In => write!(f, "_in_"),
BinaryOp::Contains => write!(f, "contains"),
BinaryOp::ContainsAll => write!(f, "containsAll"),
Expand Down
4 changes: 4 additions & 0 deletions cedar-policy-core/src/ast/policy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,17 @@ extern crate tsify;

/// Top level structure for a policy template.
/// Contains both the AST for template, and the list of open slots in the template.
///
/// Note that this "template" may have no slots, in which case this `Template` represents a static policy
#[derive(Clone, Hash, Eq, PartialEq, Debug, Serialize, Deserialize)]
#[serde(from = "TemplateBody")]
#[serde(into = "TemplateBody")]
pub struct Template {
body: TemplateBody,
/// INVARIANT (slot cache correctness): This Vec must contain _all_ of the open slots in `body`
/// This is maintained by the only two public constructors, `new` and `instantiate_inline_policy`
///
/// Note that `slots` may be empty, in which case this `Template` represents a static policy
slots: Vec<SlotId>,
}

Expand Down
15 changes: 7 additions & 8 deletions cedar-policy-core/src/ast/restricted_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use crate::parser::err::ParseErrors;
use crate::parser::{self, Loc};
use miette::Diagnostic;
use serde::{Deserialize, Serialize};
use smol_str::SmolStr;
use smol_str::{SmolStr, ToSmolStr};
use std::hash::{Hash, Hasher};
use std::ops::Deref;
use std::sync::Arc;
Expand Down Expand Up @@ -477,15 +477,11 @@ fn is_restricted(expr: &Expr) -> Result<(), RestrictedExprError> {
expr: expr.clone(),
}),
ExprKind::UnaryApp { op, .. } => Err(RestrictedExprError::InvalidRestrictedExpression {
feature: op.to_string().into(),
feature: op.to_smolstr(),
expr: expr.clone(),
}),
ExprKind::BinaryApp { op, .. } => Err(RestrictedExprError::InvalidRestrictedExpression {
feature: op.to_string().into(),
expr: expr.clone(),
}),
ExprKind::MulByConst { .. } => Err(RestrictedExprError::InvalidRestrictedExpression {
feature: "multiplication".into(),
feature: op.to_smolstr(),
expr: expr.clone(),
}),
ExprKind::GetAttr { .. } => Err(RestrictedExprError::InvalidRestrictedExpression {
Expand Down Expand Up @@ -654,6 +650,7 @@ pub enum RestrictedExprParseError {
#[cfg(test)]
mod test {
use super::*;
use crate::ast::ExprConstructionError;
use crate::parser::err::{ParseError, ToASTError, ToASTErrorKind};
use crate::parser::Loc;
use std::str::FromStr;
Expand Down Expand Up @@ -706,7 +703,9 @@ mod test {
RestrictedExpr::from_str(str),
Err(RestrictedExprParseError::Parse(ParseErrors(vec![
ParseError::ToAST(ToASTError::new(
ToASTErrorKind::DuplicateKeyInRecordLiteral { key: "foo".into() },
ToASTErrorKind::ExprConstructionError(
ExprConstructionError::DuplicateKeyInRecordLiteral { key: "foo".into() }
),
Loc::new(0..32, Arc::from(str))
))
]))),
Expand Down
1 change: 0 additions & 1 deletion cedar-policy-core/src/ast/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,6 @@ impl TryFrom<Expr> for ValueKind {
ExprKind::Or { .. } => Err(NotValue::NotValue { loc }),
ExprKind::UnaryApp { .. } => Err(NotValue::NotValue { loc }),
ExprKind::BinaryApp { .. } => Err(NotValue::NotValue { loc }),
ExprKind::MulByConst { .. } => Err(NotValue::NotValue { loc }),
ExprKind::ExtensionFunctionApp { .. } => Err(NotValue::NotValue { loc }),
ExprKind::GetAttr { .. } => Err(NotValue::NotValue { loc }),
ExprKind::HasAttr { .. } => Err(NotValue::NotValue { loc }),
Expand Down
4 changes: 2 additions & 2 deletions cedar-policy-core/src/est/err.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
* limitations under the License.
*/

use crate::ast::{self, SlotId};
use crate::ast;
use crate::entities::JsonDeserializationError;
use crate::parser::err::ParseErrors;
use crate::parser::unescape;
Expand Down Expand Up @@ -49,7 +49,7 @@ pub enum FromJsonError {
#[diagnostic(help("slots are currently unsupported in `{clausetype}` clauses"))]
SlotsInConditionClause {
/// Slot that was found in a when/unless clause
slot: SlotId,
slot: ast::SlotId,
/// Clause type, e.g. "when" or "unless"
clausetype: &'static str,
},
Expand Down
29 changes: 5 additions & 24 deletions cedar-policy-core/src/est/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -620,26 +620,10 @@ impl Expr {
(*left).clone().try_into_ast(id.clone())?,
(*right).clone().try_into_ast(id)?,
)),
Expr::ExprNoExt(ExprNoExt::Mul { left, right }) => {
let left: ast::Expr = (*left).clone().try_into_ast(id.clone())?;
let right: ast::Expr = (*right).clone().try_into_ast(id)?;
let left_c = match left.expr_kind() {
ast::ExprKind::Lit(ast::Literal::Long(c)) => Some(c),
_ => None,
};
let right_c = match right.expr_kind() {
ast::ExprKind::Lit(ast::Literal::Long(c)) => Some(c),
_ => None,
};
match (left_c, right_c) {
(_, Some(c)) => Ok(ast::Expr::mul(left, *c)),
(Some(c), _) => Ok(ast::Expr::mul(right, *c)),
(None, None) => Err(FromJsonError::MultiplicationByNonConstant {
arg1: left,
arg2: right,
})?,
}
}
Expr::ExprNoExt(ExprNoExt::Mul { left, right }) => Ok(ast::Expr::mul(
(*left).clone().try_into_ast(id.clone())?,
(*right).clone().try_into_ast(id)?,
)),
Expr::ExprNoExt(ExprNoExt::Contains { left, right }) => Ok(ast::Expr::contains(
(*left).clone().try_into_ast(id.clone())?,
(*right).clone().try_into_ast(id)?,
Expand Down Expand Up @@ -780,15 +764,12 @@ impl From<ast::Expr> for Expr {
ast::BinaryOp::LessEq => Expr::lesseq(arg1, arg2),
ast::BinaryOp::Add => Expr::add(arg1, arg2),
ast::BinaryOp::Sub => Expr::sub(arg1, arg2),
ast::BinaryOp::Mul => Expr::mul(arg1, arg2),
ast::BinaryOp::Contains => Expr::contains(Arc::new(arg1), arg2),
ast::BinaryOp::ContainsAll => Expr::contains_all(Arc::new(arg1), arg2),
ast::BinaryOp::ContainsAny => Expr::contains_any(Arc::new(arg1), arg2),
}
}
ast::ExprKind::MulByConst { arg, constant } => Expr::mul(
Arc::unwrap_or_clone(arg).into(),
Expr::lit(CedarValueJson::Long(constant as InputInteger)),
),
ast::ExprKind::ExtensionFunctionApp { fn_name, args } => {
let args = Arc::unwrap_or_clone(args)
.into_iter()
Expand Down
50 changes: 25 additions & 25 deletions cedar-policy-core/src/evaluator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,11 @@ impl<'e> Evaluator<'e> {
match op {
BinaryOp::Eq => Ok((arg1 == arg2).into()),
// comparison and arithmetic operators, which only work on Longs
BinaryOp::Less | BinaryOp::LessEq | BinaryOp::Add | BinaryOp::Sub => {
BinaryOp::Less
| BinaryOp::LessEq
| BinaryOp::Add
| BinaryOp::Sub
| BinaryOp::Mul => {
let i1 = arg1.get_as_long()?;
let i2 = arg2.get_as_long()?;
match op {
Expand Down Expand Up @@ -436,6 +440,17 @@ impl<'e> Evaluator<'e> {
loc.cloned(),
)),
},
BinaryOp::Mul => match i1.checked_mul(i2) {
Some(prod) => Ok(prod.into()),
None => Err(EvaluationError::integer_overflow(
IntegerOverflowError::BinaryOp {
op: *op,
arg1,
arg2,
},
loc.cloned(),
)),
},
// PANIC SAFETY `op` is checked to be one of the above
#[allow(clippy::unreachable)]
_ => {
Expand Down Expand Up @@ -530,22 +545,6 @@ impl<'e> Evaluator<'e> {
}
}
}
ExprKind::MulByConst { arg, constant } => match self.partial_interpret(arg, slots)? {
PartialValue::Value(arg) => {
let i1 = arg.get_as_long()?;
match i1.checked_mul(*constant) {
Some(prod) => Ok(prod.into()),
None => Err(EvaluationError::integer_overflow(
IntegerOverflowError::Multiplication {
arg,
constant: *constant,
},
loc.cloned(),
)),
}
}
PartialValue::Residual(r) => Ok(PartialValue::Residual(Expr::mul(r, *constant))),
},
ExprKind::ExtensionFunctionApp { fn_name, args } => {
let args = args
.iter()
Expand Down Expand Up @@ -2986,17 +2985,17 @@ pub mod test {
);
// 5 * (-3)
assert_eq!(
eval.interpret_inline_policy(&Expr::mul(Expr::val(5), -3)),
eval.interpret_inline_policy(&Expr::mul(Expr::val(5), Expr::val(-3))),
Ok(Value::from(-15))
);
// 5 * 0
assert_eq!(
eval.interpret_inline_policy(&Expr::mul(Expr::val(5), 0)),
eval.interpret_inline_policy(&Expr::mul(Expr::val(5), Expr::val(0))),
Ok(Value::from(0))
);
// "5" * 0
assert_matches!(
eval.interpret_inline_policy(&Expr::mul(Expr::val("5"), 0)),
eval.interpret_inline_policy(&Expr::mul(Expr::val("5"), Expr::val(0))),
Err(e) => assert_eq!(e.error_kind(),
&EvaluationErrorKind::TypeError {
expected: nonempty![Type::Long],
Expand All @@ -3006,11 +3005,12 @@ pub mod test {
);
// overflow
assert_eq!(
eval.interpret_inline_policy(&Expr::mul(Expr::val(Integer::MAX - 1), 3)),
eval.interpret_inline_policy(&Expr::mul(Expr::val(Integer::MAX - 1), Expr::val(3))),
Err(EvaluationError::integer_overflow(
IntegerOverflowError::Multiplication {
arg: Value::from(Integer::MAX - 1),
constant: 3,
IntegerOverflowError::BinaryOp {
op: BinaryOp::Mul,
arg1: Value::from(Integer::MAX - 1),
arg2: Value::from(3),
},
None
))
Expand Down Expand Up @@ -5550,7 +5550,7 @@ pub mod test {
let exts = Extensions::none();
let eval = Evaluator::new(empty_request(), &es, &exts);

let e = Expr::mul(Expr::unknown(Unknown::new_untyped("a")), 32);
let e = Expr::mul(Expr::unknown(Unknown::new_untyped("a")), Expr::val(32));
let r = eval.partial_interpret(&e, &HashMap::new()).unwrap();
assert_eq!(r, PartialValue::Residual(e));
}
Expand Down
Loading

0 comments on commit 0f9df95

Please sign in to comment.