Skip to content

Commit

Permalink
Return an error for schema with an unknown extension type (#890)
Browse files Browse the repository at this point in the history
Signed-off-by: John Kastner <[email protected]>
  • Loading branch information
john-h-kastner-aws authored May 21, 2024
1 parent 89fb874 commit 5b9930c
Show file tree
Hide file tree
Showing 7 changed files with 234 additions and 43 deletions.
21 changes: 20 additions & 1 deletion cedar-policy-validator/src/err.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
* limitations under the License.
*/

use std::collections::HashSet;
use std::{collections::HashSet, fmt::Display};

use cedar_policy_core::{
ast::{EntityAttrEvaluationError, EntityUID, Name},
Expand Down Expand Up @@ -202,6 +202,10 @@ pub enum SchemaError {
#[error("the `__expr` escape is no longer supported")]
#[diagnostic(help("to create an entity reference, use `__entity`; to create an extension value, use `__extn`; and for all other values, use JSON directly"))]
ExprEscapeUsed,
/// The schema used an extension type that the validator doesn't know about.
#[error(transparent)]
#[diagnostic(transparent)]
UnknownExtensionType(UnknownExtensionType),
}

impl From<transitive_closure::TcError<EntityUID>> for SchemaError {
Expand Down Expand Up @@ -294,3 +298,18 @@ impl JsonDeserializationError {
}
}
}

#[derive(Error, Debug)]
#[error("unknown extension type `{actual}`")]
pub struct UnknownExtensionType {
pub(crate) actual: Name,
pub(crate) suggested_replacement: Option<String>,
}

impl Diagnostic for UnknownExtensionType {
fn help<'a>(&'a self) -> Option<Box<dyn std::fmt::Display + 'a>> {
self.suggested_replacement
.as_ref()
.map(|suggestion| Box::new(format!("did you mean `{suggestion}`?")) as Box<dyn Display>)
}
}
168 changes: 141 additions & 27 deletions cedar-policy-validator/src/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,17 +136,18 @@ impl TryFrom<NamespaceDefinition> for ValidatorSchema {
type Error = SchemaError;

fn try_from(nsd: NamespaceDefinition) -> Result<ValidatorSchema> {
ValidatorSchema::from_schema_fragments([ValidatorSchemaFragment::from_namespaces([
nsd.try_into()?
])])
ValidatorSchema::from_schema_fragments(
[ValidatorSchemaFragment::from_namespaces([nsd.try_into()?])],
Extensions::all_available(),
)
}
}

impl TryFrom<SchemaFragment> for ValidatorSchema {
type Error = SchemaError;

fn try_from(frag: SchemaFragment) -> Result<ValidatorSchema> {
ValidatorSchema::from_schema_fragments([frag.try_into()?])
ValidatorSchema::from_schema_fragments([frag.try_into()?], Extensions::all_available())
}
}

Expand Down Expand Up @@ -221,16 +222,20 @@ impl ValidatorSchema {
action_behavior: ActionBehavior,
extensions: Extensions<'_>,
) -> Result<ValidatorSchema> {
Self::from_schema_fragments([ValidatorSchemaFragment::from_schema_fragment(
schema_file,
action_behavior,
Self::from_schema_fragments(
[ValidatorSchemaFragment::from_schema_fragment(
schema_file,
action_behavior,
extensions,
)?],
extensions,
)?])
)
}

/// Construct a new `ValidatorSchema` from some number of schema fragments.
pub fn from_schema_fragments(
fragments: impl IntoIterator<Item = ValidatorSchemaFragment>,
extensions: Extensions<'_>,
) -> Result<ValidatorSchema> {
let mut type_defs = HashMap::new();
let mut entity_type_fragments = HashMap::new();
Expand Down Expand Up @@ -271,7 +276,7 @@ impl ValidatorSchema {
}

let resolver = CommonTypeResolver::new(&type_defs);
let type_defs = resolver.resolve()?;
let type_defs = resolver.resolve(extensions)?;

// Invert the `parents` relation defined by entities and action so far
// to get a `children` relation.
Expand Down Expand Up @@ -656,14 +661,17 @@ impl TryInto<ValidatorSchema> for NamespaceDefinitionWithActionAttributes {
type Error = SchemaError;

fn try_into(self) -> Result<ValidatorSchema> {
ValidatorSchema::from_schema_fragments([ValidatorSchemaFragment::from_namespaces([
ValidatorNamespaceDef::from_namespace_definition(
None,
self.0,
crate::ActionBehavior::PermitAttributes,
Extensions::all_available(),
)?,
])])
ValidatorSchema::from_schema_fragments(
[ValidatorSchemaFragment::from_namespaces([
ValidatorNamespaceDef::from_namespace_definition(
None,
self.0,
crate::ActionBehavior::PermitAttributes,
Extensions::all_available(),
)?,
])],
Extensions::all_available(),
)
}
}

Expand Down Expand Up @@ -819,7 +827,7 @@ impl<'a> CommonTypeResolver<'a> {
}

// Resolve common type references
fn resolve(&self) -> Result<HashMap<Name, Type>> {
fn resolve(&self, extensions: Extensions) -> Result<HashMap<Name, Type>> {
let sorted_names = self
.topo_sort()
.map_err(SchemaError::CycleInCommonTypeReferences)?;
Expand All @@ -845,6 +853,7 @@ impl<'a> CommonTypeResolver<'a> {
ValidatorNamespaceDef::try_schema_type_into_validator_type(
ns.as_ref(),
substituted_ty,
extensions,
)?
.resolve_type_defs(&HashMap::new())?,
);
Expand All @@ -866,6 +875,7 @@ mod test {
use crate::{SchemaType, SchemaTypeVariant};

use cedar_policy_core::ast::RestrictedExpr;
use cedar_policy_core::test_utils::{expect_err, ExpectedErrorMessageBuilder};
use cool_asserts::assert_matches;
use serde_json::json;

Expand Down Expand Up @@ -1389,6 +1399,7 @@ mod test {
let ty: Type = ValidatorNamespaceDef::try_schema_type_into_validator_type(
Some(&Name::parse_unqualified_name("NS").expect("Expected namespace.")),
schema_ty,
Extensions::all_available(),
)
.expect("Error converting schema type to type.")
.resolve_type_defs(&HashMap::new())
Expand All @@ -1409,6 +1420,7 @@ mod test {
let ty: Type = ValidatorNamespaceDef::try_schema_type_into_validator_type(
Some(&Name::parse_unqualified_name("NS").expect("Expected namespace.")),
schema_ty,
Extensions::all_available(),
)
.expect("Error converting schema type to type.")
.resolve_type_defs(&HashMap::new())
Expand All @@ -1434,10 +1446,14 @@ mod test {
additional_attributes: false,
}),
);
let ty: Type = ValidatorNamespaceDef::try_schema_type_into_validator_type(None, schema_ty)
.expect("Error converting schema type to type.")
.resolve_type_defs(&HashMap::new())
.unwrap();
let ty: Type = ValidatorNamespaceDef::try_schema_type_into_validator_type(
None,
schema_ty,
Extensions::all_available(),
)
.expect("Error converting schema type to type.")
.resolve_type_defs(&HashMap::new())
.unwrap();
assert_eq!(ty, Type::closed_record_with_attributes(None));
}

Expand Down Expand Up @@ -1476,7 +1492,8 @@ mod test {

#[test]
fn schema_no_fragments() {
let schema = ValidatorSchema::from_schema_fragments([]).unwrap();
let schema =
ValidatorSchema::from_schema_fragments([], Extensions::all_available()).unwrap();
assert!(schema.entity_types.is_empty());
assert!(schema.action_ids.is_empty());
}
Expand Down Expand Up @@ -1771,7 +1788,11 @@ mod test {
.unwrap()
.try_into()
.unwrap();
let schema = ValidatorSchema::from_schema_fragments([fragment1, fragment2]).unwrap();
let schema = ValidatorSchema::from_schema_fragments(
[fragment1, fragment2],
Extensions::all_available(),
)
.unwrap();

assert_eq!(
schema.entity_types.iter().next().unwrap().1.attributes,
Expand Down Expand Up @@ -1806,7 +1827,10 @@ mod test {
.try_into()
.unwrap();

let schema = ValidatorSchema::from_schema_fragments([fragment1, fragment2]);
let schema = ValidatorSchema::from_schema_fragments(
[fragment1, fragment2],
Extensions::all_available(),
);

match schema {
Err(SchemaError::DuplicateCommonType(s)) if s.contains("A::MyLong") => (),
Expand Down Expand Up @@ -2202,13 +2226,103 @@ mod test {
assert_matches!(schema, Err(SchemaError::UndeclaredCommonTypes(types)) =>
assert_eq!(types, HashSet::from(["Demo::id".to_string()])));
}

#[test]
fn unknown_extension_type() {
let src: serde_json::Value = json!({
"": {
"commonTypes": { },
"entityTypes": {
"User": {
"shape": {
"type": "Record",
"attributes": {
"a": {
"type": "Extension",
"name": "ip",
}
}
}
}
},
"actions": {}
}
});
let schema = ValidatorSchema::from_json_value(src.clone(), Extensions::all_available());
assert_matches!(schema, Err(e) => {
expect_err(
&src,
&miette::Report::new(e),
&ExpectedErrorMessageBuilder::error("unknown extension type `ip`")
.help("did you mean `ipaddr`?")
.build());
});

let src: serde_json::Value = json!({
"": {
"commonTypes": { },
"entityTypes": { },
"actions": {
"A": {
"appliesTo": {
"context": {
"type": "Record",
"attributes": {
"a": {
"type": "Extension",
"name": "deciml",
}
}
}
}
}
}
}
});
let schema = ValidatorSchema::from_json_value(src.clone(), Extensions::all_available());
assert_matches!(schema, Err(e) => {
expect_err(
&src,
&miette::Report::new(e),
&ExpectedErrorMessageBuilder::error("unknown extension type `deciml`")
.help("did you mean `decimal`?")
.build());
});

let src: serde_json::Value = json!({
"": {
"commonTypes": {
"ty": {
"type": "Record",
"attributes": {
"a": {
"type": "Extension",
"name": "i",
}
}
}
},
"entityTypes": { },
"actions": { },
}
});
let schema = ValidatorSchema::from_json_value(src.clone(), Extensions::all_available());
assert_matches!(schema, Err(e) => {
expect_err(
&src,
&miette::Report::new(e),
&ExpectedErrorMessageBuilder::error("unknown extension type `i`")
.help("did you mean `ipaddr`?")
.build());
});
}
}

#[cfg(test)]
mod test_resolver {
use std::collections::HashMap;

use cedar_policy_core::ast::Name;
use cedar_policy_core::{ast::Name, extensions::Extensions};
use cool_asserts::assert_matches;

use super::CommonTypeResolver;
Expand All @@ -2221,7 +2335,7 @@ mod test_resolver {
type_defs.extend(def.type_defs.type_defs.into_iter());
}
let resolver = CommonTypeResolver::new(&type_defs);
resolver.resolve()
resolver.resolve(Extensions::all_available())
}

#[test]
Expand Down
Loading

0 comments on commit 5b9930c

Please sign in to comment.