diff --git a/pgserver/arrowwriter.go b/pgserver/arrowwriter.go index c705d53..c685d34 100644 --- a/pgserver/arrowwriter.go +++ b/pgserver/arrowwriter.go @@ -3,6 +3,7 @@ package pgserver import ( "os" "strings" + "sync/atomic" "github.com/apache/arrow-go/v18/arrow/ipc" "github.com/apecloud/myduckserver/adapter" @@ -62,24 +63,31 @@ func NewArrowWriter( }, nil } -func (dw *ArrowWriter) Start() (string, chan CopyToResult, error) { +func (dw *ArrowWriter) Start(globalErr *atomic.Pointer[error]) (string, chan CopyToResult, error) { // Execute the statement in a separate goroutine. ch := make(chan CopyToResult, 1) go func() { - defer os.Remove(dw.pipePath) defer close(ch) dw.ctx.GetLogger().Tracef("Executing statement via Arrow interface: %s", dw.duckSQL) conn, err := adapter.GetConn(dw.ctx) if err != nil { + globalErr.Store(&err) ch <- CopyToResult{Err: err} return } + // If there is a global error, return immediately. + if e := globalErr.Load(); e != nil { + ch <- CopyToResult{Err: *e} + return + } + // Open the pipe for writing. // This operation will block until the reader opens the pipe for reading. pipe, err := os.OpenFile(dw.pipePath, os.O_WRONLY, os.ModeNamedPipe) if err != nil { + globalErr.Store(&err) ch <- CopyToResult{Err: err} return } @@ -114,6 +122,7 @@ func (dw *ArrowWriter) Start() (string, chan CopyToResult, error) { } return recordReader.Err() }); err != nil { + globalErr.Store(&err) ch <- CopyToResult{Err: err} return } diff --git a/pgserver/connection_handler.go b/pgserver/connection_handler.go index 212c080..9fd6606 100644 --- a/pgserver/connection_handler.go +++ b/pgserver/connection_handler.go @@ -1414,13 +1414,13 @@ func (h *ConnectionHandler) handleCopyToStdout(query ConvertedStatement, copyTo } defer writer.Close() - pipePath, ch, err := writer.Start() + var globalErr atomic.Pointer[error] + pipePath, ch, err := writer.Start(&globalErr) if err != nil { return err } done := make(chan struct{}) - var globalErr atomic.Value var blocked atomic.Bool blocked.Store(true) go func() { @@ -1431,7 +1431,8 @@ func (h *ConnectionHandler) handleCopyToStdout(query ConvertedStatement, copyTo pipe, err := os.OpenFile(pipePath, os.O_RDONLY, os.ModeNamedPipe) blocked.Store(false) if err != nil { - globalErr.Store(fmt.Errorf("failed to open pipe for reading: %w", err)) + err = fmt.Errorf("failed to open pipe for reading: %w", err) + globalErr.Store(&err) cancel() return } @@ -1465,7 +1466,7 @@ func (h *ConnectionHandler) handleCopyToStdout(query ConvertedStatement, copyTo switch format { case tree.CopyFormatText: - flag := true + responsed := false reader := bufio.NewReader(pipe) for { line, err := reader.ReadSlice('\n') @@ -1473,23 +1474,23 @@ func (h *ConnectionHandler) handleCopyToStdout(query ConvertedStatement, copyTo if err == io.EOF { break } - globalErr.Store(err) + globalErr.Store(&err) cancel() return } - if flag { - flag = false + if !responsed { + responsed = true count := bytes.Count(line, []byte{'\t'}) err := sendCopyOutResponse(count + 1) if err != nil { - globalErr.Store(err) + globalErr.Store(&err) cancel() return } } err = sendCopyData(line) if err != nil { - globalErr.Store(err) + globalErr.Store(&err) cancel() return } @@ -1497,7 +1498,7 @@ func (h *ConnectionHandler) handleCopyToStdout(query ConvertedStatement, copyTo default: err := sendCopyOutResponse(1) if err != nil { - globalErr.Store(err) + globalErr.Store(&err) cancel() return } @@ -1509,14 +1510,14 @@ func (h *ConnectionHandler) handleCopyToStdout(query ConvertedStatement, copyTo if err == io.EOF { break } - globalErr.Store(err) + globalErr.Store(&err) cancel() return } if n > 0 { err := sendCopyData(buf[:n]) if err != nil { - globalErr.Store(err) + globalErr.Store(&err) cancel() return } @@ -1528,30 +1529,29 @@ func (h *ConnectionHandler) handleCopyToStdout(query ConvertedStatement, copyTo select { case <-ctx.Done(): // Context is canceled <-done - err, _ := globalErr.Load().(error) - return errors.Join(ctx.Err(), err) + if errPtr := globalErr.Load(); errPtr != nil { + return errors.Join(ctx.Err(), err) + } + return ctx.Err() case result := <-ch: if blocked.Load() { // If the pipe is still opened for reading but the writer has exited, // then we need to open the pipe for writing again to unblock the reader. - globalErr.Store(errors.Join( - fmt.Errorf("pipe is opened for reading but the writer has exited"), - result.Err, - )) pipe, _ := os.OpenFile(pipePath, os.O_WRONLY, os.ModeNamedPipe) + <-done if pipe != nil { pipe.Close() } + } else { + <-done } - <-done - if result.Err != nil { return fmt.Errorf("failed to copy data: %w", result.Err) } - if err, ok := globalErr.Load().(error); ok { - return err + if errPtr := globalErr.Load(); errPtr != nil { + return *errPtr } // After data is sent and the producer side is finished without errors, send CopyDone diff --git a/pgserver/datawriter.go b/pgserver/datawriter.go index 1261ea6..d9de498 100644 --- a/pgserver/datawriter.go +++ b/pgserver/datawriter.go @@ -4,6 +4,7 @@ import ( "fmt" "os" "strings" + "sync/atomic" "github.com/apecloud/myduckserver/adapter" "github.com/apecloud/myduckserver/backend" @@ -13,7 +14,7 @@ import ( ) type DataWriter interface { - Start() (string, chan CopyToResult, error) + Start(globalErr *atomic.Pointer[error]) (string, chan CopyToResult, error) Close() } @@ -145,11 +146,10 @@ func NewDuckDataWriter( }, nil } -func (dw *DuckDataWriter) Start() (string, chan CopyToResult, error) { +func (dw *DuckDataWriter) Start(globalErr *atomic.Pointer[error]) (string, chan CopyToResult, error) { // Execute the COPY TO statement in a separate goroutine. ch := make(chan CopyToResult, 1) go func() { - defer os.Remove(dw.pipePath) defer close(ch) dw.ctx.GetLogger().Tracef("Executing COPY TO statement: %s", dw.duckSQL) @@ -157,6 +157,7 @@ func (dw *DuckDataWriter) Start() (string, chan CopyToResult, error) { // This operation will block until the reader opens the pipe for reading. result, err := adapter.ExecCatalog(dw.ctx, dw.duckSQL) if err != nil { + globalErr.Store(&err) ch <- CopyToResult{Err: err} return } diff --git a/test/bats/postgres/copy_tests.bats b/test/bats/postgres/copy_tests.bats index b397f96..5296b1f 100644 --- a/test/bats/postgres/copy_tests.bats +++ b/test/bats/postgres/copy_tests.bats @@ -133,7 +133,6 @@ EOF } @test "copy error handling" { - skip # Test copying from non-existent schema run psql_exec "\copy nonexistent_schema.t TO STDOUT;" [ "$status" -ne 0 ]