Skip to content

Commit

Permalink
fix: make COPY TO STDOUT robust to errors, again (#331)
Browse files Browse the repository at this point in the history
  • Loading branch information
fanyang01 authored Dec 26, 2024
1 parent 481a550 commit 0c978f0
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 28 deletions.
13 changes: 11 additions & 2 deletions pgserver/arrowwriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package pgserver
import (
"os"
"strings"
"sync/atomic"

"github.com/apache/arrow-go/v18/arrow/ipc"
"github.com/apecloud/myduckserver/adapter"
Expand Down Expand Up @@ -62,24 +63,31 @@ func NewArrowWriter(
}, nil
}

func (dw *ArrowWriter) Start() (string, chan CopyToResult, error) {
func (dw *ArrowWriter) Start(globalErr *atomic.Pointer[error]) (string, chan CopyToResult, error) {
// Execute the statement in a separate goroutine.
ch := make(chan CopyToResult, 1)
go func() {
defer os.Remove(dw.pipePath)
defer close(ch)

dw.ctx.GetLogger().Tracef("Executing statement via Arrow interface: %s", dw.duckSQL)
conn, err := adapter.GetConn(dw.ctx)
if err != nil {
globalErr.Store(&err)
ch <- CopyToResult{Err: err}
return
}

// If there is a global error, return immediately.
if e := globalErr.Load(); e != nil {
ch <- CopyToResult{Err: *e}
return
}

// Open the pipe for writing.
// This operation will block until the reader opens the pipe for reading.
pipe, err := os.OpenFile(dw.pipePath, os.O_WRONLY, os.ModeNamedPipe)
if err != nil {
globalErr.Store(&err)
ch <- CopyToResult{Err: err}
return
}
Expand Down Expand Up @@ -114,6 +122,7 @@ func (dw *ArrowWriter) Start() (string, chan CopyToResult, error) {
}
return recordReader.Err()
}); err != nil {
globalErr.Store(&err)
ch <- CopyToResult{Err: err}
return
}
Expand Down
44 changes: 22 additions & 22 deletions pgserver/connection_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -1414,13 +1414,13 @@ func (h *ConnectionHandler) handleCopyToStdout(query ConvertedStatement, copyTo
}
defer writer.Close()

pipePath, ch, err := writer.Start()
var globalErr atomic.Pointer[error]
pipePath, ch, err := writer.Start(&globalErr)
if err != nil {
return err
}

done := make(chan struct{})
var globalErr atomic.Value
var blocked atomic.Bool
blocked.Store(true)
go func() {
Expand All @@ -1431,7 +1431,8 @@ func (h *ConnectionHandler) handleCopyToStdout(query ConvertedStatement, copyTo
pipe, err := os.OpenFile(pipePath, os.O_RDONLY, os.ModeNamedPipe)
blocked.Store(false)
if err != nil {
globalErr.Store(fmt.Errorf("failed to open pipe for reading: %w", err))
err = fmt.Errorf("failed to open pipe for reading: %w", err)
globalErr.Store(&err)
cancel()
return
}
Expand Down Expand Up @@ -1465,39 +1466,39 @@ func (h *ConnectionHandler) handleCopyToStdout(query ConvertedStatement, copyTo

switch format {
case tree.CopyFormatText:
flag := true
responsed := false
reader := bufio.NewReader(pipe)
for {
line, err := reader.ReadSlice('\n')
if err != nil {
if err == io.EOF {
break
}
globalErr.Store(err)
globalErr.Store(&err)
cancel()
return
}
if flag {
flag = false
if !responsed {
responsed = true
count := bytes.Count(line, []byte{'\t'})
err := sendCopyOutResponse(count + 1)
if err != nil {
globalErr.Store(err)
globalErr.Store(&err)
cancel()
return
}
}
err = sendCopyData(line)
if err != nil {
globalErr.Store(err)
globalErr.Store(&err)
cancel()
return
}
}
default:
err := sendCopyOutResponse(1)
if err != nil {
globalErr.Store(err)
globalErr.Store(&err)
cancel()
return
}
Expand All @@ -1509,14 +1510,14 @@ func (h *ConnectionHandler) handleCopyToStdout(query ConvertedStatement, copyTo
if err == io.EOF {
break
}
globalErr.Store(err)
globalErr.Store(&err)
cancel()
return
}
if n > 0 {
err := sendCopyData(buf[:n])
if err != nil {
globalErr.Store(err)
globalErr.Store(&err)
cancel()
return
}
Expand All @@ -1528,30 +1529,29 @@ func (h *ConnectionHandler) handleCopyToStdout(query ConvertedStatement, copyTo
select {
case <-ctx.Done(): // Context is canceled
<-done
err, _ := globalErr.Load().(error)
return errors.Join(ctx.Err(), err)
if errPtr := globalErr.Load(); errPtr != nil {
return errors.Join(ctx.Err(), err)
}
return ctx.Err()
case result := <-ch:
if blocked.Load() {
// If the reader goroutine is still blocked opening the pipe for reading but the writer has already exited,
// then we need to open the pipe for writing ourselves so that the reader's open call returns and it can finish.
globalErr.Store(errors.Join(
fmt.Errorf("pipe is opened for reading but the writer has exited"),
result.Err,
))
pipe, _ := os.OpenFile(pipePath, os.O_WRONLY, os.ModeNamedPipe)
<-done
if pipe != nil {
pipe.Close()
}
} else {
<-done
}

<-done

if result.Err != nil {
return fmt.Errorf("failed to copy data: %w", result.Err)
}

if err, ok := globalErr.Load().(error); ok {
return err
if errPtr := globalErr.Load(); errPtr != nil {
return *errPtr
}

// After data is sent and the producer side is finished without errors, send CopyDone
Expand Down
7 changes: 4 additions & 3 deletions pgserver/datawriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"fmt"
"os"
"strings"
"sync/atomic"

"github.com/apecloud/myduckserver/adapter"
"github.com/apecloud/myduckserver/backend"
Expand All @@ -13,7 +14,7 @@ import (
)

type DataWriter interface {
Start() (string, chan CopyToResult, error)
Start(globalErr *atomic.Pointer[error]) (string, chan CopyToResult, error)
Close()
}

Expand Down Expand Up @@ -145,18 +146,18 @@ func NewDuckDataWriter(
}, nil
}

func (dw *DuckDataWriter) Start() (string, chan CopyToResult, error) {
func (dw *DuckDataWriter) Start(globalErr *atomic.Pointer[error]) (string, chan CopyToResult, error) {
// Execute the COPY TO statement in a separate goroutine.
ch := make(chan CopyToResult, 1)
go func() {
defer os.Remove(dw.pipePath)
defer close(ch)

dw.ctx.GetLogger().Tracef("Executing COPY TO statement: %s", dw.duckSQL)

// This operation will block until the reader opens the pipe for reading.
result, err := adapter.ExecCatalog(dw.ctx, dw.duckSQL)
if err != nil {
globalErr.Store(&err)
ch <- CopyToResult{Err: err}
return
}
Expand Down
1 change: 0 additions & 1 deletion test/bats/postgres/copy_tests.bats
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,6 @@ EOF
}

@test "copy error handling" {
skip
# Test copying from non-existent schema
run psql_exec "\copy nonexistent_schema.t TO STDOUT;"
[ "$status" -ne 0 ]
Expand Down

0 comments on commit 0c978f0

Please sign in to comment.