diff --git a/compiler/cpp/src/thrift/generate/t_go_generator.cc b/compiler/cpp/src/thrift/generate/t_go_generator.cc index 8ed914009db..089d15ffe30 100644 --- a/compiler/cpp/src/thrift/generate/t_go_generator.cc +++ b/compiler/cpp/src/thrift/generate/t_go_generator.cc @@ -2988,7 +2988,7 @@ void t_go_generator::generate_process_function(t_service* tservice, t_function* string write_err; if (!tfunction->is_oneway()) { write_err = tmp("_write_err"); - f_types_ << indent() << "var " << write_err << " error" << '\n'; + f_types_ << indent() << "var " << write_err << " thrift.TException" << '\n'; } f_types_ << indent() << "args := " << argsname << "{}" << '\n'; f_types_ << indent() << "if err2 := args." << read_method_name_ << "(ctx, iprot); err2 != nil {" << '\n'; @@ -3120,14 +3120,24 @@ void t_go_generator::generate_process_function(t_service* tservice, t_function* // Avoid writing the error to the wire if it's ErrAbandonRequest f_types_ << indent() << "if errors.Is(err2, thrift.ErrAbandonRequest) {" << '\n'; indent_up(); - f_types_ << indent() << "return false, thrift.WrapTException(err2)" << '\n'; + f_types_ << indent() << "return false, &thrift.TProcessorError{" << '\n'; + indent_up(); + f_types_ << indent() << "WriteError: thrift.WrapTException(err2)," << '\n'; + f_types_ << indent() << "EndpointError: err," << '\n'; + indent_down(); + f_types_ << indent() << "}" << '\n'; indent_down(); f_types_ << indent() << "}" << '\n'; f_types_ << indent() << "if errors.Is(err2, context.Canceled) {" << '\n'; indent_up(); - f_types_ << indent() << "if err := context.Cause(ctx); errors.Is(err, thrift.ErrAbandonRequest) {" << '\n'; + f_types_ << indent() << "if err3 := context.Cause(ctx); errors.Is(err3, thrift.ErrAbandonRequest) {" << '\n'; indent_up(); - f_types_ << indent() << "return false, thrift.WrapTException(err)" << '\n'; + f_types_ << indent() << "return false, &thrift.TProcessorError{" << '\n'; + indent_up(); + f_types_ << indent() << "WriteError: thrift.WrapTException(err3)," << '\n'; + f_types_ << indent() << "EndpointError: err," << '\n'; + indent_down(); + f_types_ << indent() << "}" << '\n'; indent_down(); f_types_ << indent() << "}" << '\n'; indent_down(); @@ -3168,7 +3178,12 @@ void t_go_generator::generate_process_function(t_service* tservice, t_function* f_types_ << indent() << "if " << write_err << " != nil {" << '\n'; indent_up(); - f_types_ << indent() << "return false, thrift.WrapTException(" << write_err << ")" << '\n'; + f_types_ << indent() << "return false, &thrift.TProcessorError{" << '\n'; + indent_up(); + f_types_ << indent() << "WriteError: " << write_err << "," << '\n'; + f_types_ << indent() << "EndpointError: err," << '\n'; + indent_down(); + f_types_ << indent() << "}" << '\n'; indent_down(); f_types_ << indent() << "}" << '\n'; @@ -3230,7 +3245,12 @@ void t_go_generator::generate_process_function(t_service* tservice, t_function* f_types_ << indent() << "if " << write_err << " != nil {" << '\n'; indent_up(); - f_types_ << indent() << "return false, thrift.WrapTException(" << write_err << ")" << '\n'; + f_types_ << indent() << "return false, &thrift.TProcessorError{" << '\n'; + indent_up(); + f_types_ << indent() << "WriteError: " << write_err << "," << '\n'; + f_types_ << indent() << "EndpointError: err," << '\n'; + indent_down(); + f_types_ << indent() << "}" << '\n'; indent_down(); f_types_ << indent() << "}" << '\n'; diff --git a/lib/go/test/tests/processor_middleware_test.go b/lib/go/test/tests/processor_middleware_test.go index 1bd911cfe60..aedd93f2279 100644 --- a/lib/go/test/tests/processor_middleware_test.go +++ b/lib/go/test/tests/processor_middleware_test.go @@ -32,9 +32,12 @@ import ( const errorMessage = "foo error" -type serviceImpl struct{} +type serviceImpl struct { + sleepTime time.Duration +} -func (serviceImpl) Ping(_ context.Context) (err error) { +func (s serviceImpl) Ping(_ context.Context) (err error) { + time.Sleep(s.sleepTime) return &processormiddlewaretest.Error{ Foo: thrift.StringPtr(errorMessage), } @@ -67,9 +70,14 @@ func checkError(tb testing.TB, err error) { } func TestProcessorMiddleware(t *testing.T) { - const timeout = time.Second + const ( + sleepTime = 10 * time.Millisecond + timeout = sleepTime / 5 + ) - processor := processormiddlewaretest.NewServiceProcessor(&serviceImpl{}) + processor := processormiddlewaretest.NewServiceProcessor(&serviceImpl{ + sleepTime: sleepTime, + }) serverTransport, err := thrift.NewTServerSocket("127.0.0.1:0") if err != nil { t.Fatalf("Could not find available server port: %v", err) @@ -80,7 +88,9 @@ func TestProcessorMiddleware(t *testing.T) { thrift.NewTHeaderTransportFactoryConf(nil, nil), thrift.NewTHeaderProtocolFactoryConf(nil), ) - defer server.Stop() + t.Cleanup(func() { + server.Stop() + }) var wg sync.WaitGroup wg.Add(1) go func() { @@ -103,6 +113,14 @@ func TestProcessorMiddleware(t *testing.T) { client := processormiddlewaretest.NewServiceClient(thrift.NewTStandardClient(protocol, protocol)) - err = client.Ping(context.Background()) - checkError(t, err) + for label, timeout := range map[string]time.Duration{ + "enough-time": sleepTime * 10, + "not-enough-time": sleepTime / 2, + } { + t.Run(label, func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + t.Cleanup(cancel) + client.Ping(ctx) + }) + } } diff --git a/lib/go/thrift/processor_factory.go b/lib/go/thrift/processor_factory.go index 245a3ccfc98..aebf50a12f4 100644 --- a/lib/go/thrift/processor_factory.go +++ b/lib/go/thrift/processor_factory.go @@ -19,7 +19,11 @@ package thrift -import "context" +import ( + "context" + "fmt" + "strings" +) // A processor is a generic object which operates upon an input stream and // writes to some output stream. @@ -78,3 +82,49 @@ func NewTProcessorFunctionFactory(p TProcessorFunction) TProcessorFunctionFactor func (p *tProcessorFunctionFactory) GetProcessorFunction(trans TTransport) TProcessorFunction { return p.processor } + +// TProcessorError is the combined original error returned by the endpoint +// implementation, and I/O error when writing the response back to the client. +// +// This type will be returned by Process function if there's an error happened +// during writing the response back to the client. ProcessorMiddlewares can +// check for this type (use errors.As) to get the underlying write and endpoint +// errors. +type TProcessorError struct { + // WriteError is the error happened during writing the response to the + // client, always set. + WriteError TException + + // EndpointError is the original error returned by the endpoint + // implementation, might be nil. + EndpointError TException +} + +func (tpe *TProcessorError) Unwrap() []error { + if tpe.EndpointError != nil { + return []error{ + tpe.WriteError, + tpe.EndpointError, + } + } + return []error{tpe.WriteError} +} + +func (tpe *TProcessorError) Error() string { + var sb strings.Builder + sb.WriteString("thrift.TProcessorError: ") + sb.WriteString(fmt.Sprintf("write response to client: %v", tpe.WriteError)) + if tpe.EndpointError != nil { + sb.WriteString(fmt.Sprintf("; original error from endpoint: %v", tpe.EndpointError)) + } + return sb.String() +} + +func (tpe *TProcessorError) TExceptionType() TExceptionType { + return tpe.WriteError.TExceptionType() +} + +var ( + _ error = (*TProcessorError)(nil) + _ TException = (*TProcessorError)(nil) +) diff --git a/lib/go/thrift/simple_server_test.go b/lib/go/thrift/simple_server_test.go index f3a59ee18ea..0fcb01bd8b8 100644 --- a/lib/go/thrift/simple_server_test.go +++ b/lib/go/thrift/simple_server_test.go @@ -294,6 +294,7 @@ func TestErrAbandonRequest(t *testing.T) { if !errors.Is(ErrAbandonRequest, context.Canceled) { t.Error("errors.Is(ErrAbandonRequest, context.Canceled) returned false") } + //lint:ignore SA1032 Intentional order for this test. if errors.Is(context.Canceled, ErrAbandonRequest) { t.Error("errors.Is(context.Canceled, ErrAbandonRequest) returned true") }