Skip to content

Commit

Permalink
internal/eval: move typename to eval
Browse files Browse the repository at this point in the history
Addresses IDX-142

Signed-off-by: philhassey <[email protected]>
  • Loading branch information
philhassey committed Aug 23, 2024
1 parent 327edac commit bf3d07d
Show file tree
Hide file tree
Showing 21 changed files with 67 additions and 69 deletions.
6 changes: 3 additions & 3 deletions internal/eval/evalers.go
Original file line number Diff line number Diff line change
Expand Up @@ -837,7 +837,7 @@ func (n *attributeAccessEval) Eval(ctx *Context) (types.Value, error) {
case types.Record:
record = vv
default:
return zeroValue(), fmt.Errorf("%w: expected one of [record, (entity of type `any_entity_type`)], got %v", ErrType, v.TypeName())
return zeroValue(), fmt.Errorf("%w: expected one of [record, (entity of type `any_entity_type`)], got %v", ErrType, TypeName(v))
}
val, ok := record[n.attribute]
if !ok {
Expand Down Expand Up @@ -873,7 +873,7 @@ func (n *hasEval) Eval(ctx *Context) (types.Value, error) {
case types.Record:
record = vv
default:
return zeroValue(), fmt.Errorf("%w: expected one of [record, (entity of type `any_entity_type`)], got %v", ErrType, v.TypeName())
return zeroValue(), fmt.Errorf("%w: expected one of [record, (entity of type `any_entity_type`)], got %v", ErrType, TypeName(v))
}
_, ok := record[n.attribute]
return types.Boolean(ok), nil
Expand Down Expand Up @@ -969,7 +969,7 @@ func (n *inEval) Eval(ctx *Context) (types.Value, error) {
}
default:
return zeroValue(), fmt.Errorf(
"%w: expected one of [set, (entity of type `any_entity_type`)], got %v", ErrType, rhs.TypeName())
"%w: expected one of [set, (entity of type `any_entity_type`)], got %v", ErrType, TypeName(rhs))
}
return types.Boolean(entityIn(lhs, query, ctx.Entities)), nil
}
Expand Down
43 changes: 34 additions & 9 deletions internal/eval/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,76 +6,101 @@ import (
"github.com/cedar-policy/cedar-go/types"
)

func TypeName(v types.Value) string {
switch t := v.(type) {
case types.Boolean:
return "bool"
case types.Decimal:
return "decimal"
case types.EntityType:
return fmt.Sprintf("(EntityType of type `%s`)", t)
case types.EntityUID:
return fmt.Sprintf("(entity of type `%s`)", t.Type)
case types.IPAddr:
return "IP"
case types.Long:
return "long"
case types.Record:
return "record"
case types.Set:
return "set"
case types.String:
return "string"
default:
return "unknown type"
}
}

var ErrType = fmt.Errorf("type error")

func ValueToBool(v types.Value) (types.Boolean, error) {
bv, ok := v.(types.Boolean)
if !ok {
return false, fmt.Errorf("%w: expected bool, got %v", ErrType, v.TypeName())
return false, fmt.Errorf("%w: expected bool, got %v", ErrType, TypeName(v))
}
return bv, nil
}

func ValueToLong(v types.Value) (types.Long, error) {
lv, ok := v.(types.Long)
if !ok {
return 0, fmt.Errorf("%w: expected long, got %v", ErrType, v.TypeName())
return 0, fmt.Errorf("%w: expected long, got %v", ErrType, TypeName(v))
}
return lv, nil
}

func ValueToString(v types.Value) (types.String, error) {
sv, ok := v.(types.String)
if !ok {
return "", fmt.Errorf("%w: expected string, got %v", ErrType, v.TypeName())
return "", fmt.Errorf("%w: expected string, got %v", ErrType, TypeName(v))
}
return sv, nil
}

func ValueToSet(v types.Value) (types.Set, error) {
sv, ok := v.(types.Set)
if !ok {
return nil, fmt.Errorf("%w: expected set, got %v", ErrType, v.TypeName())
return nil, fmt.Errorf("%w: expected set, got %v", ErrType, TypeName(v))
}
return sv, nil
}

func ValueToRecord(v types.Value) (types.Record, error) {
rv, ok := v.(types.Record)
if !ok {
return nil, fmt.Errorf("%w: expected record got %v", ErrType, v.TypeName())
return nil, fmt.Errorf("%w: expected record got %v", ErrType, TypeName(v))
}
return rv, nil
}

func ValueToEntity(v types.Value) (types.EntityUID, error) {
ev, ok := v.(types.EntityUID)
if !ok {
return types.EntityUID{}, fmt.Errorf("%w: expected (entity of type `any_entity_type`), got %v", ErrType, v.TypeName())
return types.EntityUID{}, fmt.Errorf("%w: expected (entity of type `any_entity_type`), got %v", ErrType, TypeName(v))
}
return ev, nil
}

func ValueToEntityType(v types.Value) (types.EntityType, error) {
ev, ok := v.(types.EntityType)
if !ok {
return "", fmt.Errorf("%w: expected (EntityType of type `any_entity_type`), got %v", ErrType, v.TypeName())
return "", fmt.Errorf("%w: expected (EntityType of type `any_entity_type`), got %v", ErrType, TypeName(v))
}
return ev, nil
}

func ValueToDecimal(v types.Value) (types.Decimal, error) {
d, ok := v.(types.Decimal)
if !ok {
return 0, fmt.Errorf("%w: expected decimal, got %v", ErrType, v.TypeName())
return 0, fmt.Errorf("%w: expected decimal, got %v", ErrType, TypeName(v))
}
return d, nil
}

func ValueToIP(v types.Value) (types.IPAddr, error) {
i, ok := v.(types.IPAddr)
if !ok {
return types.IPAddr{}, fmt.Errorf("%w: expected ipaddr, got %v", ErrType, v.TypeName())
return types.IPAddr{}, fmt.Errorf("%w: expected ipaddr, got %v", ErrType, TypeName(v))
}
return i, nil
}
29 changes: 29 additions & 0 deletions internal/eval/util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,3 +151,32 @@ func TestUtil(t *testing.T) {
})

}

func TestTypeName(t *testing.T) {
t.Parallel()
tests := []struct {
name string
in types.Value
out string
}{

{"boolean", types.Boolean(true), "bool"},
{"decimal", types.Decimal(42), "decimal"},
{"entityType", types.EntityType("T"), "(EntityType of type `T`)"},
{"entityUID", types.NewEntityUID("T", "42"), "(entity of type `T`)"},
{"ip", types.IPAddr{}, "IP"},
{"long", types.Long(42), "long"},
{"record", types.Record{}, "record"},
{"set", types.Set{}, "set"},
{"string", types.String("test"), "string"},
{"nil", nil, "unknown type"},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
out := TypeName(tt.in)
testutil.Equals(t, out, tt.out)
})
}
}
5 changes: 0 additions & 5 deletions types/boolean_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,4 @@ func TestBool(t *testing.T) {
AssertValueString(t, types.Boolean(true), "true")
})

t.Run("TypeName", func(t *testing.T) {
t.Parallel()
tn := types.Boolean(true).TypeName()
testutil.Equals(t, tn, "bool")
})
}
1 change: 0 additions & 1 deletion types/decimal.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,6 @@ func (a Decimal) Equal(bi Value) bool {
return ok && a == b
}

func (v Decimal) TypeName() string { return "decimal" }

// Cedar produces a valid Cedar language representation of the Decimal, e.g. `decimal("12.34")`.
func (v Decimal) Cedar() string { return `decimal("` + v.String() + `")` }
Expand Down
5 changes: 0 additions & 5 deletions types/decimal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,4 @@ func TestDecimal(t *testing.T) {
testutil.FatalIf(t, zero.Equal(f), "%v Equal to %v", zero, f)
})

t.Run("TypeName", func(t *testing.T) {
t.Parallel()
tn := types.Decimal(0).TypeName()
testutil.Equals(t, tn, "decimal")
})
}
2 changes: 0 additions & 2 deletions types/entity_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package types

import (
"encoding/json"
"fmt"
"strings"
)

Expand All @@ -13,7 +12,6 @@ func (a EntityType) Equal(bi Value) bool {
b, ok := bi.(EntityType)
return ok && a == b
}
func (v EntityType) TypeName() string { return fmt.Sprintf("(EntityType of type `%s`)", v) }

func (v EntityType) String() string { return string(v) }
func (v EntityType) Cedar() string { return string(v) }
Expand Down
6 changes: 1 addition & 5 deletions types/entity_type_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,7 @@ func TestEntityType(t *testing.T) {
testutil.Equals(t, a.Equal(c), false)
testutil.Equals(t, c.Equal(a), false)
})
t.Run("TypeName", func(t *testing.T) {
t.Parallel()
a := types.EntityType("X")
testutil.Equals(t, a.TypeName(), "(EntityType of type `X`)")
})

t.Run("String", func(t *testing.T) {
t.Parallel()
a := types.EntityType("X")
Expand Down
2 changes: 0 additions & 2 deletions types/entity_uid.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package types

import (
"encoding/json"
"fmt"
"strconv"
)

Expand All @@ -28,7 +27,6 @@ func (a EntityUID) Equal(bi Value) bool {
b, ok := bi.(EntityUID)
return ok && a == b
}
func (v EntityUID) TypeName() string { return fmt.Sprintf("(entity of type `%s`)", v.Type) }

// String produces a string representation of the EntityUID, e.g. `Type::"id"`.
func (v EntityUID) String() string { return v.Cedar() }
Expand Down
5 changes: 0 additions & 5 deletions types/entity_uid_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,4 @@ func TestEntity(t *testing.T) {
AssertValueString(t, types.EntityUID{Type: "namespace::type", ID: "id"}, `namespace::type::"id"`)
})

t.Run("TypeName", func(t *testing.T) {
t.Parallel()
tn := types.EntityUID{"T", "id"}.TypeName()
testutil.Equals(t, tn, "(entity of type `T`)")
})
}
1 change: 0 additions & 1 deletion types/ipaddr.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ func (a IPAddr) Equal(bi Value) bool {
return ok && a == b
}

func (v IPAddr) TypeName() string { return "IP" }

// Cedar produces a valid Cedar language representation of the IPAddr, e.g. `ip("127.0.0.1")`.
func (v IPAddr) Cedar() string { return `ip("` + v.String() + `")` }
Expand Down
5 changes: 0 additions & 5 deletions types/ipaddr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -270,9 +270,4 @@ func TestIP(t *testing.T) {
}
})

t.Run("TypeName", func(t *testing.T) {
t.Parallel()
tn := types.IPAddr{}.TypeName()
testutil.Equals(t, tn, "IP")
})
}
1 change: 0 additions & 1 deletion types/long.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ func (a Long) Equal(bi Value) bool {

// ExplicitMarshalJSON marshals the Long into JSON.
func (v Long) ExplicitMarshalJSON() ([]byte, error) { return json.Marshal(v) }
func (v Long) TypeName() string { return "long" }

// String produces a string representation of the Long, e.g. `42`.
func (v Long) String() string { return v.Cedar() }
Expand Down
5 changes: 0 additions & 5 deletions types/long_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,4 @@ func TestLong(t *testing.T) {
AssertValueString(t, types.Long(1), "1")
})

t.Run("TypeName", func(t *testing.T) {
t.Parallel()
tn := types.Long(1).TypeName()
testutil.Equals(t, tn, "long")
})
}
1 change: 0 additions & 1 deletion types/record.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ func (v Record) MarshalJSON() ([]byte, error) {
// ExplicitMarshalJSON marshals the Record into JSON, the marshaller uses the
// explicit JSON form for all the values in the Record.
func (v Record) ExplicitMarshalJSON() ([]byte, error) { return v.MarshalJSON() }
func (r Record) TypeName() string { return "record" }

// String produces a string representation of the Record, e.g. `{"a":1,"b":2,"c":3}`.
func (r Record) String() string { return r.Cedar() }
Expand Down
5 changes: 0 additions & 5 deletions types/record_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,4 @@ func TestRecord(t *testing.T) {
`{"bar":"blah", "foo":true}`)
})

t.Run("TypeName", func(t *testing.T) {
t.Parallel()
tn := types.Record{}.TypeName()
testutil.Equals(t, tn, "record")
})
}
2 changes: 0 additions & 2 deletions types/set.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,6 @@ func (v Set) MarshalJSON() ([]byte, error) {
// explicit JSON form for all the values in the Set.
func (v Set) ExplicitMarshalJSON() ([]byte, error) { return v.MarshalJSON() }

func (v Set) TypeName() string { return "set" }

// String produces a string representation of the Set, e.g. `[1,2,3]`.
func (v Set) String() string { return v.Cedar() }

Expand Down
5 changes: 0 additions & 5 deletions types/set_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,4 @@ func TestSet(t *testing.T) {
"[true, 1]")
})

t.Run("TypeName", func(t *testing.T) {
t.Parallel()
tn := types.Set{}.TypeName()
testutil.Equals(t, tn, "set")
})
}
1 change: 0 additions & 1 deletion types/string.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ func (a String) Equal(bi Value) bool {

// ExplicitMarshalJSON marshals the String into JSON.
func (v String) ExplicitMarshalJSON() ([]byte, error) { return json.Marshal(v) }
func (v String) TypeName() string { return "string" }

// String produces an unquoted string representation of the String, e.g. `hello`.
func (v String) String() string {
Expand Down
5 changes: 0 additions & 5 deletions types/string_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,4 @@ func TestString(t *testing.T) {
AssertValueString(t, types.String("hello\ngoodbye"), "hello\ngoodbye")
})

t.Run("TypeName", func(t *testing.T) {
t.Parallel()
tn := types.String("hello").TypeName()
testutil.Equals(t, tn, "string")
})
}
1 change: 0 additions & 1 deletion types/value.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,5 @@ type Value interface {
// Sets or Records where the type is not defined.
ExplicitMarshalJSON() ([]byte, error)
Equal(Value) bool
TypeName() string
deepClone() Value
}

0 comments on commit bf3d07d

Please sign in to comment.