diff --git a/ast/ast_test.go b/ast/ast_test.go index e2b249d..53a0379 100644 --- a/ast/ast_test.go +++ b/ast/ast_test.go @@ -12,6 +12,10 @@ import ( ) func TestASTByTable(t *testing.T) { + type CustomString string + type CustomBool bool + type CustomInt int + type CustomInt64 int64 t.Parallel() tests := []struct { name string @@ -143,6 +147,16 @@ func TestASTByTable(t *testing.T) { ast.Permit().When(ast.Boolean(true)), internalast.Permit().When(internalast.Boolean(true)), }, + { + "customValueBoolFalse", + ast.Permit().When(ast.Boolean(CustomBool(false))), + internalast.Permit().When(internalast.Boolean(false)), + }, + { + "customValueBoolTrue", + ast.Permit().When(ast.Boolean(CustomBool(true))), + internalast.Permit().When(internalast.Boolean(true)), + }, { "valueTrue", ast.Permit().When(ast.True()), @@ -158,6 +172,21 @@ func TestASTByTable(t *testing.T) { ast.Permit().When(ast.String("cedar")), internalast.Permit().When(internalast.String("cedar")), }, + { + "customValueString", + ast.Permit().When(ast.String(CustomString("cedar"))), + internalast.Permit().When(internalast.String("cedar")), + }, + { + "customValueInt", + ast.Permit().When(ast.Long(CustomInt(42))), + internalast.Permit().When(internalast.Long(42)), + }, + { + "customValueInt64", + ast.Permit().When(ast.Long(CustomInt64(42))), + internalast.Permit().When(internalast.Long(42)), + }, { "valueLong", ast.Permit().When(ast.Long(42)), diff --git a/ast/value.go b/ast/value.go index 258683d..3be1731 100644 --- a/ast/value.go +++ b/ast/value.go @@ -9,7 +9,7 @@ import ( ) // Boolean creates a value node containing a Boolean. -func Boolean[T bool | types.Boolean](b T) Node { +func Boolean[T ~bool](b T) Node { return wrapNode(ast.Boolean(types.Boolean(b))) } @@ -24,12 +24,12 @@ func False() Node { } // String creates a value node containing a String. -func String[T string | types.String](s T) Node { +func String[T ~string](s T) Node { return wrapNode(ast.String(types.String(s))) } // Long creates a value node containing a Long. -func Long[T int | int64 | types.Long](l T) Node { +func Long[T ~int | ~int64](l T) Node { return wrapNode(ast.Long(types.Long(l))) }