Skip to content

Commit

Permalink
Merge pull request #77 from strongdm/tag-support
Browse files Browse the repository at this point in the history
RFC-82: Implement support for Entity Tags
  • Loading branch information
patjakdev authored Dec 13, 2024
2 parents 96529dc + 21b073c commit d9c6d61
Show file tree
Hide file tree
Showing 32 changed files with 679 additions and 11 deletions.
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ The Go implementation does not yet include:
- schema support and the [validator](https://docs.cedarpolicy.com/policies/validation.html)
- the formatter
- partial evaluation
- support for [RFC 82](https://github.com/cedar-policy/rfcs/blob/main/text/0082-entity-tags.md) (entity tags)
- support for [policy templates](https://docs.cedarpolicy.com/policies/templates.html)

## Quick Start
Expand Down
10 changes: 10 additions & 0 deletions ast/ast_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,16 @@ func TestASTByTable(t *testing.T) {
ast.Permit().When(ast.Long(42).Has("key")),
internalast.Permit().When(internalast.Long(42).Has("key")),
},
{
"opGetTag",
ast.Permit().When(ast.EntityUID("T", "1").GetTag(ast.String("key"))),
internalast.Permit().When(internalast.EntityUID("T", "1").GetTag(internalast.String("key"))),
},
{
"opsHasTag",
ast.Permit().When(ast.EntityUID("T", "1").HasTag(ast.String("key"))),
internalast.Permit().When(internalast.EntityUID("T", "1").HasTag(internalast.String("key"))),
},
{
"opIsIpv4",
ast.Permit().When(ast.Long(42).IsIpv4()),
Expand Down
8 changes: 8 additions & 0 deletions ast/operator.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,14 @@ func (lhs Node) Has(attr types.String) Node {
return wrapNode(lhs.Node.Has(attr))
}

func (lhs Node) GetTag(rhs Node) Node {
return wrapNode(lhs.Node.GetTag(rhs.Node))
}

func (lhs Node) HasTag(rhs Node) Node {
return wrapNode(lhs.Node.HasTag(rhs.Node))
}

// ___ ____ _ _ _
// |_ _| _ \ / \ __| | __| |_ __ ___ ___ ___
// | || |_) / _ \ / _` |/ _` | '__/ _ \/ __/ __|
Expand Down
4 changes: 4 additions & 0 deletions authorize.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ const (
// IsAuthorized uses the combination of the PolicySet and Entities to determine
// if the given Request to determine Decision and Diagnostic.
func (p PolicySet) IsAuthorized(entities types.EntityGetter, req Request) (Decision, Diagnostic) {
if entities == nil {
var zero types.EntityMap
entities = zero
}
env := eval.Env{
Entities: entities,
Principal: req.Principal,
Expand Down
31 changes: 30 additions & 1 deletion authorize_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (

"github.com/cedar-policy/cedar-go"
"github.com/cedar-policy/cedar-go/internal/testutil"
"github.com/cedar-policy/cedar-go/types"
)

//nolint:revive // due to table test function-length
Expand All @@ -15,7 +16,7 @@ func TestIsAuthorized(t *testing.T) {
tests := []struct {
Name string
Policy string
Entities cedar.EntityMap
Entities types.EntityGetter
Principal, Action, Resource cedar.EntityUID
Context cedar.Record
Want cedar.Decision
Expand All @@ -33,6 +34,34 @@ func TestIsAuthorized(t *testing.T) {
Want: true,
DiagErr: 0,
},
{
Name: "permit-when-tags",
Policy: `permit(principal,action,resource) when { principal.hasTag("foo") };`,
Entities: types.EntityMap{
cuzco: types.Entity{
Tags: types.NewRecord(cedar.RecordMap{
"foo": types.String("bar"),
}),
},
},
Principal: cuzco,
Action: dropTable,
Resource: cedar.NewEntityUID("table", "whatever"),
Context: cedar.Record{},
Want: true,
DiagErr: 0,
},
{
Name: "nil-entity-getter",
Policy: `permit(principal,action,resource);`,
Entities: nil,
Principal: cuzco,
Action: dropTable,
Resource: cedar.NewEntityUID("table", "whatever"),
Context: cedar.Record{},
Want: true,
DiagErr: 0,
},
{
Name: "simple-forbid",
Policy: `forbid(principal,action,resource);`,
Expand Down
Binary file modified corpus-tests.tar.gz
Binary file not shown.
12 changes: 11 additions & 1 deletion corpus_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (

"github.com/cedar-policy/cedar-go"
"github.com/cedar-policy/cedar-go/internal/testutil"
"github.com/cedar-policy/cedar-go/types"
"github.com/cedar-policy/cedar-go/x/exp/batch"
)

Expand Down Expand Up @@ -244,6 +245,7 @@ func TestCorpusRelated(t *testing.T) {
tests := []struct {
name string
policy string
entities types.EntityGetter
request cedar.Request
decision cedar.Decision
reasons []cedar.PolicyID
Expand All @@ -258,6 +260,7 @@ func TestCorpusRelated(t *testing.T) {
) when {
(true && (((!870985681610) == principal) == principal)) && principal
};`,
nil,
cedar.Request{Principal: cedar.NewEntityUID("a", "\u0000\u0000"), Action: cedar.NewEntityUID("Action", "action"), Resource: cedar.NewEntityUID("a", "\u0000\u0000")},
cedar.Deny,
nil,
Expand All @@ -273,6 +276,7 @@ func TestCorpusRelated(t *testing.T) {
) when {
(((!870985681610) == principal) == principal)
};`,
nil,
cedar.Request{Principal: cedar.NewEntityUID("a", "\u0000\u0000"), Action: cedar.NewEntityUID("Action", "action"), Resource: cedar.NewEntityUID("a", "\u0000\u0000")},
cedar.Deny,
nil,
Expand All @@ -287,6 +291,7 @@ func TestCorpusRelated(t *testing.T) {
) when {
((!870985681610) == principal)
};`,
nil,
cedar.Request{Principal: cedar.NewEntityUID("a", "\u0000\u0000"), Action: cedar.NewEntityUID("Action", "action"), Resource: cedar.NewEntityUID("a", "\u0000\u0000")},
cedar.Deny,
nil,
Expand All @@ -302,6 +307,7 @@ func TestCorpusRelated(t *testing.T) {
) when {
(!870985681610)
};`,
nil,
cedar.Request{Principal: cedar.NewEntityUID("a", "\u0000\u0000"), Action: cedar.NewEntityUID("Action", "action"), Resource: cedar.NewEntityUID("a", "\u0000\u0000")},
cedar.Deny,
nil,
Expand All @@ -317,6 +323,7 @@ func TestCorpusRelated(t *testing.T) {
) when {
((!42) == principal)
};`,
nil,
cedar.Request{},
cedar.Deny,
nil,
Expand All @@ -332,6 +339,7 @@ func TestCorpusRelated(t *testing.T) {
) when {
(!42 == principal)
};`,
nil,
cedar.Request{},
cedar.Deny,
nil,
Expand All @@ -346,6 +354,7 @@ func TestCorpusRelated(t *testing.T) {
) when {
true && ((if (principal in action) then (ip("")) else (if true then (ip("6b6b:f00::32ff:ffff:6368/00")) else (ip("7265:6c69:706d:6f43:5f74:6f70:7374:6f68")))).isMulticast())
};`,
nil,
cedar.Request{Principal: cedar.NewEntityUID("a", "\u0000\b\u0011\u0000R"), Action: cedar.NewEntityUID("Action", "action"), Resource: cedar.NewEntityUID("a", "\u0000\b\u0011\u0000R")},
cedar.Deny,
nil,
Expand All @@ -360,6 +369,7 @@ func TestCorpusRelated(t *testing.T) {
) when {
true && ip("6b6b:f00::32ff:ffff:6368/00").isMulticast()
};`,
nil,
cedar.Request{},
cedar.Deny,
nil,
Expand All @@ -386,7 +396,7 @@ func TestCorpusRelated(t *testing.T) {
t.Parallel()
policy, err := cedar.NewPolicySetFromBytes("", []byte(tt.policy))
testutil.OK(t, err)
ok, diag := policy.IsAuthorized(cedar.EntityMap{}, tt.request)
ok, diag := policy.IsAuthorized(tt.entities, tt.request)
testutil.Equals(t, ok, tt.decision)
var reasons []cedar.PolicyID
for _, n := range diag.Reasons {
Expand Down
4 changes: 4 additions & 0 deletions internal/eval/convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ func toEval(n ast.IsNode) Evaler {
return newAttributeAccessEval(toEval(v.Arg), v.Value)
case ast.NodeTypeHas:
return newHasEval(toEval(v.Arg), v.Value)
case ast.NodeTypeGetTag:
return newGetTagEval(toEval(v.Left), toEval(v.Right))
case ast.NodeTypeHasTag:
return newHasTagEval(toEval(v.Left), toEval(v.Right))
case ast.NodeTypeLike:
return newLikeEval(toEval(v.Arg), v.Value)
case ast.NodeTypeIfThenElse:
Expand Down
18 changes: 17 additions & 1 deletion internal/eval/convert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,18 @@ func TestToEval(t *testing.T) {
types.True,
testutil.OK,
},
{
"getTag",
ast.EntityUID("T", "ID").GetTag(ast.String("key")),
types.Long(42),
testutil.OK,
},
{
"hasTag",
ast.EntityUID("T", "ID").HasTag(ast.String("key")),
types.True,
testutil.OK,
},
{
"like",
ast.String("test").Like(types.Pattern{}),
Expand Down Expand Up @@ -355,7 +367,11 @@ func TestToEval(t *testing.T) {
Action: types.NewEntityUID("Action", "test"),
Resource: types.NewEntityUID("Resource", "database"),
Context: types.Record{},
Entities: types.EntityMap{},
Entities: types.EntityMap{
types.NewEntityUID("T", "ID"): types.Entity{
Tags: types.NewRecord(types.RecordMap{"key": types.Long(42)}),
},
},
})
tt.err(t, err)
testutil.Equals(t, out, tt.out)
Expand Down
70 changes: 69 additions & 1 deletion internal/eval/evalers.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ var errOverflow = fmt.Errorf("integer overflow")
var errUnknownExtensionFunction = fmt.Errorf("function does not exist")
var errArity = fmt.Errorf("wrong number of arguments provided to extension function")
var errAttributeAccess = fmt.Errorf("does not have the attribute")
var errTagAccess = fmt.Errorf("does not have the tag")
var errEntityNotExist = fmt.Errorf("does not exist")
var errUnspecifiedEntity = fmt.Errorf("unspecified entity")

Expand Down Expand Up @@ -804,6 +805,73 @@ func (n *hasEval) Eval(env Env) (types.Value, error) {
return types.Boolean(ok), nil
}

// getTagEval
type getTagEval struct {
lhs, rhs Evaler
}

func newGetTagEval(object, tag Evaler) *getTagEval {
return &getTagEval{lhs: object, rhs: tag}
}

func (n *getTagEval) Eval(env Env) (types.Value, error) {
eid, err := evalEntity(n.lhs, env)
if err != nil {
return zeroValue(), err
}

var zero types.EntityUID
if eid == zero {
return zeroValue(), fmt.Errorf("cannot access tag `%s` of %w", n.rhs, errUnspecifiedEntity)
}

t, err := evalString(n.rhs, env)
if err != nil {
return zeroValue(), err
}

e, ok := env.Entities.Get(eid)
if !ok {
return zeroValue(), fmt.Errorf("entity `%v` %w", eid.String(), errEntityNotExist)
}

val, ok := e.Tags.Get(t)
if !ok {
return zeroValue(), fmt.Errorf("`%s` %w `%s`", eid.String(), errTagAccess, t)
}

return val, nil
}

// hasTagEval
type hasTagEval struct {
lhs, rhs Evaler
}

func newHasTagEval(object, tag Evaler) *hasTagEval {
return &hasTagEval{lhs: object, rhs: tag}
}

func (n *hasTagEval) Eval(env Env) (types.Value, error) {
eid, err := evalEntity(n.lhs, env)
if err != nil {
return zeroValue(), err
}

t, err := evalString(n.rhs, env)
if err != nil {
return zeroValue(), err
}

e, ok := env.Entities.Get(eid)
if !ok {
return types.False, nil
}

_, ok = e.Tags.Get(t)
return types.Boolean(ok), nil
}

// likeEval
type likeEval struct {
lhs Evaler
Expand Down Expand Up @@ -1139,7 +1207,7 @@ func newExtensionEval(name types.Path, args []Evaler) Evaler {

if i, ok := extensions.ExtMap[name]; ok {
if i.Args != len(args) {
return newErrorEval(fmt.Errorf("%w: %s takes %d parameter(s)", errArity, name, i.Args))
return newErrorEval(fmt.Errorf("%w: %s takes %d parameter(s), but %d provided", errArity, name, i.Args, len(args)))
}
switch {
case name == "datetime":
Expand Down
Loading

0 comments on commit d9c6d61

Please sign in to comment.