diff --git a/internal/errors.go b/internal/errors.go index 494a3f446..e0290de7c 100644 --- a/internal/errors.go +++ b/internal/errors.go @@ -61,6 +61,6 @@ func (e *Error) Unwrap() error { return e.err } -func internalError(code string, err error) Error { - return Error{Code: code, err: err} +func newInternalError(code string, err error) *Error { + return &Error{Code: code, err: err} } diff --git a/internal/internal.go b/internal/internal.go index a133d08bb..042630f63 100644 --- a/internal/internal.go +++ b/internal/internal.go @@ -370,22 +370,22 @@ func (p *envoyExtAuthzGrpcServer) Check(ctx context.Context, req *ext_authz_v3.C func (p *envoyExtAuthzGrpcServer) check(ctx context.Context, req interface{}) (*ext_authz_v3.CheckResponse, func() *rpc_status.Status, *Error) { var err error var evalErr error - var internalErr Error + var internalErr *Error start := time.Now() logger := p.manager.Logger() result, stopeval, err := envoyauth.NewEvalResult() if err != nil { logger.WithFields(map[string]interface{}{"err": err}).Error("Unable to start new evaluation.") - internalErr = internalError(StartCheckErr, err) - return nil, func() *rpc_status.Status { return nil }, &internalErr + internalErr = newInternalError(StartCheckErr, err) + return nil, func() *rpc_status.Status { return nil }, internalErr } txn, txnClose, err := result.GetTxn(ctx, p.Store()) if err != nil { logger.WithFields(map[string]interface{}{"err": err}).Error("Unable to start new storage transaction.") - internalErr = internalError(StartTxnErr, err) - return nil, func() *rpc_status.Status { return nil }, &internalErr + internalErr = newInternalError(StartTxnErr, err) + return nil, func() *rpc_status.Status { return nil }, internalErr } result.Txn = txn @@ -398,9 +398,9 @@ func (p *envoyExtAuthzGrpcServer) check(ctx context.Context, req interface{}) (* stopeval() if p.cfg.EnablePerformanceMetrics { var topdownError *topdown.Error - if internalErr.Unwrap() != nil && errors.As(internalErr.Unwrap(), &topdownError) { + if internalErr != nil && errors.As(internalErr.Unwrap(), &topdownError) { p.metricErrorCounter.With(prometheus.Labels{"reason": topdownError.Code}).Inc() - } else if internalErr.Code != "" { + } else if internalErr != nil && internalErr.Code != "" { p.metricErrorCounter.With(prometheus.Labels{"reason": internalErr.Code}).Inc() } } @@ -422,27 +422,41 @@ func (p *envoyExtAuthzGrpcServer) check(ctx context.Context, req interface{}) (* input, err = envoyauth.RequestToInput(req, logger, p.cfg.protoSet, p.cfg.SkipRequestBodyParse) if err != nil { - internalErr = internalError(RequestParseErr, err) - return nil, stop, &internalErr + internalErr = newInternalError(RequestParseErr, err) + return &ext_authz_v3.CheckResponse{ + Status: &rpc_status.Status{ + Code: int32(code.Code_PERMISSION_DENIED), + Message: internalErr.Error(), + }, + HttpResponse: &ext_authz_v3.CheckResponse_DeniedResponse{ + DeniedResponse: &ext_authz_v3.DeniedHttpResponse{ + Status: &ext_type_v3.HttpStatus{ + Code: ext_type_v3.StatusCode(ext_type_v3.StatusCode_BadRequest), + }, + Body: internalErr.Error(), + }, + }, + DynamicMetadata: nil, + }, stop, nil } if ctx.Err() != nil { err = errors.Wrap(ctx.Err(), "check request timed out before query execution") - internalErr = internalError(CheckRequestTimeoutErr, err) - return nil, stop, &internalErr + internalErr = newInternalError(CheckRequestTimeoutErr, err) + return nil, stop, internalErr } var inputValue ast.Value inputValue, err = ast.InterfaceToValue(input) if err != nil { - internalErr = internalError(InputParseErr, err) - return nil, stop, &internalErr + internalErr = newInternalError(InputParseErr, err) + return nil, stop, internalErr } if err = envoyauth.Eval(ctx, p, inputValue, result); err != nil { evalErr = err - internalErr = internalError(EnvoyAuthEvalErr, err) - return nil, stop, &internalErr + internalErr = newInternalError(EnvoyAuthEvalErr, err) + return nil, stop, internalErr } resp := &ext_authz_v3.CheckResponse{} @@ -451,8 +465,8 @@ func (p *envoyExtAuthzGrpcServer) check(ctx context.Context, req interface{}) (* allowed, err = result.IsAllowed() if err != nil { err = errors.Wrap(err, "failed to get response status") - internalErr = internalError(EnvoyAuthResultErr, err) - return nil, stop, &internalErr + internalErr = newInternalError(EnvoyAuthResultErr, err) + return nil, stop, internalErr } status := int32(code.Code_PERMISSION_DENIED) @@ -467,16 +481,16 @@ func (p *envoyExtAuthzGrpcServer) check(ctx context.Context, req interface{}) (* responseHeaders, err = result.GetResponseEnvoyHeaderValueOptions() if err != nil { err = errors.Wrap(err, "failed to get response headers") - internalErr = internalError(EnvoyAuthResultErr, err) - return nil, stop, &internalErr + internalErr = newInternalError(EnvoyAuthResultErr, err) + return nil, stop, internalErr } var dynamicMetadata *_structpb.Struct dynamicMetadata, err = result.GetDynamicMetadata() if err != nil { err = errors.Wrap(err, "failed to get dynamic metadata") - internalErr = internalError(EnvoyAuthResultErr, err) - return nil, stop, &internalErr + internalErr = newInternalError(EnvoyAuthResultErr, err) + return nil, stop, internalErr } resp.DynamicMetadata = dynamicMetadata @@ -486,16 +500,16 @@ func (p *envoyExtAuthzGrpcServer) check(ctx context.Context, req interface{}) (* headersToRemove, err = result.GetRequestHTTPHeadersToRemove() if err != nil { err = errors.Wrap(err, "failed to get request headers to remove") - internalErr = internalError(EnvoyAuthResultErr, err) - return nil, stop, &internalErr + internalErr = newInternalError(EnvoyAuthResultErr, err) + return nil, stop, internalErr } var responseHeadersToAdd []*ext_core_v3.HeaderValueOption responseHeadersToAdd, err = result.GetResponseHTTPHeadersToAdd() if err != nil { err = errors.Wrap(err, "failed to get response headers to send to client") - internalErr = internalError(EnvoyAuthResultErr, err) - return nil, stop, &internalErr + internalErr = newInternalError(EnvoyAuthResultErr, err) + return nil, stop, internalErr } resp.HttpResponse = &ext_authz_v3.CheckResponse_OkResponse{ @@ -510,16 +524,16 @@ func (p *envoyExtAuthzGrpcServer) check(ctx context.Context, req interface{}) (* body, err = result.GetResponseBody() if err != nil { err = errors.Wrap(err, "failed to get response body") - internalErr = internalError(EnvoyAuthResultErr, err) - return nil, stop, &internalErr + internalErr = newInternalError(EnvoyAuthResultErr, err) + return nil, stop, internalErr } var httpStatus *ext_type_v3.HttpStatus httpStatus, err = result.GetResponseEnvoyHTTPStatus() if err != nil { err = errors.Wrap(err, "failed to get response http status") - internalErr = internalError(EnvoyAuthResultErr, err) - return nil, stop, &internalErr + internalErr = newInternalError(EnvoyAuthResultErr, err) + return nil, stop, internalErr } deniedResponse := &ext_authz_v3.DeniedHttpResponse{ diff --git a/internal/internal_test.go b/internal/internal_test.go index 7ea6fadd2..f4edb1f28 100644 --- a/internal/internal_test.go +++ b/internal/internal_test.go @@ -8,6 +8,8 @@ import ( "context" "errors" "fmt" + ext_type_v2 "github.com/envoyproxy/go-control-plane/envoy/type" + ext_type_v3 "github.com/envoyproxy/go-control-plane/envoy/type/v3" "net/http" "net/http/httptest" "reflect" @@ -862,12 +864,20 @@ func TestCheckBadDecisionWithLogger(t *testing.T) { ctx := context.Background() output, err := server.Check(ctx, &req) - if err == nil { - t.Fatal("Expected error but got nil") + if err != nil { + t.Fatal("Expected nil err") } - if output != nil { - t.Fatalf("Expected no output but got %v", output) + if output == nil { + t.Fatal("Expected output but got nil") + } + if output.Status.Code != int32(code.Code_PERMISSION_DENIED) { + t.Fatalf("Expected status %v status: %v", int32(code.Code_PERMISSION_DENIED), output.Status.Code) + } + if deniedResponse, ok := output.HttpResponse.(*ext_authz.CheckResponse_DeniedResponse); !ok { + t.Fatalf("Expected http response of type ext_authz.CheckResponse_DeniedResponse") + } else if deniedResponse.DeniedResponse.Status.Code != ext_type_v3.StatusCode_BadRequest { + t.Fatalf("Unexpected http status code: %v", deniedResponse.DeniedResponse.Status.Code) } if len(customLogger.events) != 1 { @@ -1056,12 +1066,20 @@ func TestCheckBadDecisionWithLoggerV2(t *testing.T) { ctx := context.Background() output, err := server.Check(ctx, &req) - if err == nil { - t.Fatal("Expected error but got nil") + if err != nil { + t.Fatal("Expected nil err") } - if output != nil { - t.Fatalf("Expected no output but got %v", output) + if output == nil { + t.Fatal("Expected output but got nil") + } + if output.Status.Code != int32(code.Code_PERMISSION_DENIED) { + t.Fatalf("Expected status %v status: %v", int32(code.Code_PERMISSION_DENIED), output.Status.Code) + } + if deniedResponse, ok := output.HttpResponse.(*ext_authz_v2.CheckResponse_DeniedResponse); !ok { + t.Fatalf("Expected http response of type ext_authz.CheckResponse_DeniedResponse") + } else if deniedResponse.DeniedResponse.Status.Code != ext_type_v2.StatusCode_BadRequest { + t.Fatalf("Unexpected http status code: %v", deniedResponse.DeniedResponse.Status.Code) } if len(customLogger.events) != 1 {