diff --git a/.gitignore b/.gitignore
index 9b4ebe9c..8ff8ac38 100644
--- a/.gitignore
+++ b/.gitignore
@@ -12,3 +12,4 @@ pipes/
 *.sock
 __debug_*
 .DS_Store
+*.csv
diff --git a/backend/loaddata.go b/backend/loaddata.go
index 891db0ad..50127793 100644
--- a/backend/loaddata.go
+++ b/backend/loaddata.go
@@ -8,7 +8,6 @@ import (
 	"runtime"
 	"strconv"
 	"strings"
-	"syscall"
 
 	"github.com/apecloud/myduckserver/adapter"
 	"github.com/apecloud/myduckserver/catalog"
@@ -80,21 +79,15 @@ func (db *DuckBuilder) buildClientSideLoadData(ctx *sql.Context, insert *plan.In
 	}
 	defer reader.Close()
 
-	// Create the FIFO pipe
-	pipeDir := filepath.Join(db.provider.DataDir(), "pipes", "load-data")
-	if err := os.MkdirAll(pipeDir, 0755); err != nil {
-		return nil, err
-	}
-	pipeName := strconv.Itoa(int(ctx.ID())) + ".pipe"
-	pipePath := filepath.Join(pipeDir, pipeName)
-	if err := syscall.Mkfifo(pipePath, 0600); err != nil {
+	pipePath, err := db.CreatePipe(ctx, "load-data")
+	if err != nil {
 		return nil, err
 	}
 	defer os.Remove(pipePath)
 
 	// Write the data to the FIFO pipe.
 	go func() {
-		pipe, err := os.OpenFile(pipePath, os.O_WRONLY, 0600)
+		pipe, err := os.OpenFile(pipePath, os.O_WRONLY, os.ModeNamedPipe)
 		if err != nil {
 			return
 		}
diff --git a/backend/pipe.go b/backend/pipe.go
index e69de29b..f8ea1d7f 100644
--- a/backend/pipe.go
+++ b/backend/pipe.go
@@ -0,0 +1,25 @@
+package backend
+
+import (
+	"os"
+	"path/filepath"
+	"strconv"
+	"syscall"
+
+	"github.com/dolthub/go-mysql-server/sql"
+)
+
+func (h *DuckBuilder) CreatePipe(ctx *sql.Context, subdir string) (string, error) {
+	// Create the FIFO pipe
+	pipeDir := filepath.Join(h.provider.DataDir(), "pipes", subdir)
+	if err := os.MkdirAll(pipeDir, 0755); err != nil {
+		return "", err
+	}
+	pipeName := strconv.Itoa(int(ctx.ID())) + ".pipe"
+	pipePath := filepath.Join(pipeDir, pipeName)
+	ctx.GetLogger().Debugln("Creating FIFO pipe for COPY FROM operation:", pipePath)
+	if err := syscall.Mkfifo(pipePath, 0600); err != nil {
+		return "", err
+	}
+	return pipePath, nil
+}
diff --git a/pgserver/connection_handler.go b/pgserver/connection_handler.go
index ad8e337d..8d87e7f6 100644
--- a/pgserver/connection_handler.go
+++ b/pgserver/connection_handler.go
@@ -20,6 +20,7 @@ import (
 	"context"
 	"crypto/tls"
 	"encoding/json"
+	"errors"
 	"fmt"
 	"io"
 	"net"
@@ -1144,6 +1145,11 @@ func (h *ConnectionHandler) handleCopyToStdout(query ConvertedQuery, copyTo *tre
 	}
 	ctx.SetLogger(ctx.GetLogger().WithField("query", query.String))
+	// Create cancelable context
+	childCtx, cancel := context.WithCancel(ctx)
+	defer cancel()
+	ctx = ctx.WithContext(childCtx)
+
 	table, err := ValidateCopyTo(copyTo, ctx)
 	if err != nil {
 		return err
 	}
@@ -1153,35 +1159,45 @@ func (h *ConnectionHandler) handleCopyToStdout(query ConvertedQuery, copyTo *tre
 	}
 	if copyTo.Statement != nil {
 		stmt = copyTo.Statement.String()
 	}
-	dataWriter, err := NewDataWriter(ctx, h.duckHandler, table, copyTo.Table, copyTo.Columns, stmt, &copyTo.Options)
+	writer, err := NewDataWriter(
+		ctx, h.duckHandler,
+		copyTo.Table.Schema(), table, copyTo.Columns,
+		stmt,
+		&copyTo.Options,
+	)
 	if err != nil {
 		return err
 	}
-	defer dataWriter.Cancel()
+	defer writer.Close()
 
 	// Send CopyOutResponse to the client
+	ctx.GetLogger().Debug("sending CopyOutResponse to the client")
 	copyOutResponse := &pgproto3.CopyOutResponse{
-		OverallFormat: 0, // 0 for text format
+		OverallFormat:     0,           // 0 for text format
+		ColumnFormatCodes: []uint16{0}, // 0 for text format
 	}
 	if err := h.send(copyOutResponse); err != nil {
 		return err
 	}
 
-	// Create a channel to receive the result from the goroutine
-	type copyResult struct {
-		rowCount int
-		err      error
-	}
-	done := make(chan copyResult, 1)
-
-	pipe, ch, err := dataWriter.Start()
+	pipePath, ch, err := writer.Start()
 	if err != nil {
 		return err
 	}
+	var sendErr atomic.Value
 	go func() {
+		// Open the pipe for reading.
+		ctx.GetLogger().Tracef("Opening FIFO pipe for reading: %s", pipePath)
+		pipe, err := os.OpenFile(pipePath, os.O_RDONLY, os.ModeNamedPipe)
+		if err != nil {
+			sendErr.Store(fmt.Errorf("failed to open pipe for reading: %w", err))
+			cancel()
+			return
+		}
 		defer pipe.Close()
-		defer close(done)
+
+		ctx.GetLogger().Debug("Copying data from the pipe to the client")
 		buf := make([]byte, 1<<20) // 1MB buffer
 		for {
 			n, err := pipe.Read(buf)
@@ -1189,8 +1205,10 @@ func (h *ConnectionHandler) handleCopyToStdout(query ConvertedQuery, copyTo *tre
 				copyData := &pgproto3.CopyData{
 					Data: buf[:n],
 				}
+				ctx.GetLogger().Debugf("sending CopyData (%d bytes) to the client", n)
 				if err := h.send(copyData); err != nil {
 					sendErr.Store(err)
+					cancel()
 					return
 				}
 			}
@@ -1199,27 +1217,35 @@ func (h *ConnectionHandler) handleCopyToStdout(query ConvertedQuery, copyTo *tre
 					break
 				}
 				sendErr.Store(err)
+				cancel()
 				return
 			}
 		}
 	}()
 
-	result := <-ch
-	if result.Err != nil {
-		return fmt.Errorf("failed to copy data: %w", result.Err)
-	}
+	select {
+	case <-ctx.Done(): // Context is canceled
+		err, _ := sendErr.Load().(error)
+		return errors.Join(ctx.Err(), err)
+	case result := <-ch:
+		if result.Err != nil {
+			return fmt.Errorf("failed to copy data: %w", result.Err)
+		}
 
-	if err := sendErr.Load(); err != nil {
-		return err.(error)
-	}
+		if err, ok := sendErr.Load().(error); ok {
+			return err
+		}
 
-	// After data is sent and CopyToPipe is finished without errors, send CopyDone
-	if err := h.send(&pgproto3.CopyDone{}); err != nil {
-		return err
-	}
+		// After data is sent and the producer side is finished without errors, send CopyDone
+		ctx.GetLogger().Debug("sending CopyDone to the client")
+		if err := h.send(&pgproto3.CopyDone{}); err != nil {
+			return err
+		}
 
-	// Send CommandComplete with the number of rows copied
-	return h.send(&pgproto3.CommandComplete{
-		CommandTag: []byte(fmt.Sprintf("COPY %d", result.RowCount)),
-	})
+		// Send CommandComplete with the number of rows copied
+		ctx.GetLogger().Debugf("sending CommandComplete to the client")
+		return h.send(&pgproto3.CommandComplete{
+			CommandTag: []byte(fmt.Sprintf("COPY %d", result.RowCount)),
+		})
+	}
 }
diff --git a/pgserver/dataloader.go b/pgserver/dataloader.go
index 5202c375..2208975d 100644
--- a/pgserver/dataloader.go
+++ b/pgserver/dataloader.go
@@ -7,11 +7,9 @@ import (
 	"fmt"
 	"io"
 	"os"
-	"path/filepath"
 	"strconv"
 	"strings"
 	"sync/atomic"
-	"syscall"
 
 	"github.com/apecloud/myduckserver/adapter"
 	"github.com/apecloud/myduckserver/backend"
@@ -64,18 +62,10 @@ type CsvDataLoader struct {
 
 var _ DataLoader = (*CsvDataLoader)(nil)
 
 func NewCsvDataLoader(ctx *sql.Context, handler *DuckHandler, schema string, table sql.InsertableTable, columns tree.NameList, options *tree.CopyOptions) (DataLoader, error) {
-	duckBuilder := handler.e.Analyzer.ExecBuilder.(*backend.DuckBuilder)
-	dataDir := duckBuilder.Provider().DataDir()
-	// Create the FIFO pipe
-	pipeDir := filepath.Join(dataDir, "pipes", "pg-copy-from")
-	if err := os.MkdirAll(pipeDir, 0755); err != nil {
-		return nil, err
-	}
-	pipeName := strconv.Itoa(int(ctx.ID())) + ".pipe"
-	pipePath := filepath.Join(pipeDir, pipeName)
-	ctx.GetLogger().Traceln("Creating FIFO pipe for COPY FROM operation:", pipePath)
-	if err := syscall.Mkfifo(pipePath, 0600); err != nil {
+	duckBuilder := handler.e.Analyzer.ExecBuilder.(*backend.DuckBuilder)
+	pipePath, err := duckBuilder.CreatePipe(ctx, "pg-copy-from")
+	if err != nil {
 		return nil, err
 	}
 
diff --git a/pgserver/datawriter.go b/pgserver/datawriter.go
index bee656f9..bbc3ed2d 100644
--- a/pgserver/datawriter.go
+++ b/pgserver/datawriter.go
@@ -1,22 +1,18 @@
 package pgserver
 
 import (
-	"context"
 	"fmt"
 	"os"
-	"path/filepath"
-	"strconv"
-	"syscall"
 
 	"github.com/apecloud/myduckserver/adapter"
 	"github.com/apecloud/myduckserver/backend"
+	"github.com/apecloud/myduckserver/catalog"
 	"github.com/cockroachdb/cockroachdb-parser/pkg/sql/sem/tree"
 	"github.com/dolthub/go-mysql-server/sql"
 )
 
 type DataWriter struct {
 	ctx      *sql.Context
-	cancel   context.CancelFunc
 	duckSQL  string
 	options  *tree.CopyOptions
 	pipePath string
 }
@@ -25,7 +21,7 @@
 func NewDataWriter(
 	ctx *sql.Context,
 	handler *DuckHandler,
-	table sql.Table, tableName tree.TableName, columns tree.NameList,
+	schema string, table sql.Table, columns tree.NameList,
 	query string,
 	options *tree.CopyOptions,
 ) (*DataWriter, error) {
@@ -43,7 +39,10 @@
 
 	var source string
 	if table != nil {
-		source = tableName.FQString()
+		if schema != "" {
+			source += catalog.QuoteIdentifierANSI(schema) + "."
+		}
+		source += catalog.QuoteIdentifierANSI(table.Name())
 		if columns != nil {
 			source += "(" + columns.String() + ")"
 		}
@@ -51,33 +50,19 @@
 		source = "(" + query + ")"
 	}
 
-	duckBuilder := handler.e.Analyzer.ExecBuilder.(*backend.DuckBuilder)
-	dataDir := duckBuilder.Provider().DataDir()
-	// Create the FIFO pipe
-	pipeDir := filepath.Join(dataDir, "pipes", "pg-copy-to")
-	if err := os.MkdirAll(pipeDir, 0755); err != nil {
-		return nil, err
-	}
-	pipeName := strconv.Itoa(int(ctx.ID())) + ".pipe"
-	pipePath := filepath.Join(pipeDir, pipeName)
-	ctx.GetLogger().Traceln("Creating FIFO pipe for COPY TO operation:", pipePath)
-	if err := syscall.Mkfifo(pipePath, 0600); err != nil {
+	db := handler.e.Analyzer.ExecBuilder.(*backend.DuckBuilder)
+	pipePath, err := db.CreatePipe(ctx, "pg-copy-to")
+	if err != nil {
 		return nil, err
 	}
 
-	// Create cancelable context
-	childCtx, cancel := context.WithCancel(ctx)
-	ctx.Context = childCtx
-
 	// Initialize DataWriter
 	writer := &DataWriter{
 		ctx:      ctx,
-		cancel:   cancel,
 		duckSQL:  fmt.Sprintf("COPY %s TO '%s' (%s)", source, pipePath, format),
 		options:  options,
 		pipePath: pipePath,
-		rowCount: make(chan int64, 1),
 	}
 
 	return writer, nil
 }
@@ -88,31 +73,28 @@ type copyToResult struct {
 	Err error
 }
 
-func (dw *DataWriter) Start() (*os.File, chan int64, error) {
-	// Open the pipe for reading.
-	pipe, err := os.OpenFile(dw.pipePath, os.O_RDONLY, os.ModeNamedPipe)
-	if err != nil {
-		return nil, nil, fmt.Errorf("failed to open pipe for reading: %w", err)
-	}
-
+func (dw *DataWriter) Start() (string, chan copyToResult, error) {
+	// Execute the COPY TO statement in a separate goroutine.
+	ch := make(chan copyToResult, 1)
 	go func() {
-		defer dw.cancel()
-		defer pipe.Close()
 		defer os.Remove(dw.pipePath)
-		defer close(dw.rowCount)
+		defer close(ch)
+
+		dw.ctx.GetLogger().Tracef("Executing COPY TO statement: %s", dw.duckSQL)
+		// This operation will block until the reader opens the pipe for reading.
 		result, err := adapter.ExecCatalog(dw.ctx, dw.duckSQL)
 		if err != nil {
-			dw.err.Store(&err)
+			ch <- copyToResult{Err: err}
 			return
 		}
 
 		affected, _ := result.RowsAffected()
-		dw.rowCount <- affected
+		ch <- copyToResult{RowCount: affected}
 	}()
 
-	return pipe, dw.rowCount, nil
+	return dw.pipePath, ch, nil
 }
 
-func (dw *DataWriter) Cancel() {
-	dw.cancel()
+func (dw *DataWriter) Close() {
+	os.Remove(dw.pipePath)
 }
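
Note (illustrative sketch, not part of the diff): the handler-side pattern above — Start() now returns the FIFO path plus a buffered result channel, and the caller opens the pipe, streams chunks to the client, and reconciles the reader's error with the COPY result — can be summarized as follows. This is a minimal, self-contained sketch; the package name, streamCopyToClient, and the send callback are hypothetical, while copyToResult mirrors the struct in pgserver/datawriter.go.

package pgserverexample

import (
	"context"
	"errors"
	"fmt"
	"io"
	"os"
)

// copyToResult mirrors pgserver.copyToResult from the diff.
type copyToResult struct {
	RowCount int64
	Err      error
}

// streamCopyToClient drains the FIFO written by the COPY ... TO statement and
// forwards each chunk through send, then waits for the producer's result.
func streamCopyToClient(ctx context.Context, pipePath string, ch <-chan copyToResult, send func([]byte) error) (int64, error) {
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	readErr := make(chan error, 1)
	go func() {
		// Opening the FIFO for reading unblocks the writer side, which is
		// parked in its own O_WRONLY open until a reader shows up.
		pipe, err := os.OpenFile(pipePath, os.O_RDONLY, os.ModeNamedPipe)
		if err != nil {
			readErr <- fmt.Errorf("failed to open pipe for reading: %w", err)
			cancel()
			return
		}
		defer pipe.Close()

		buf := make([]byte, 1<<20) // 1MB buffer, matching the handler
		for {
			n, err := pipe.Read(buf)
			if n > 0 {
				if err := send(buf[:n]); err != nil {
					readErr <- err
					cancel()
					return
				}
			}
			if err != nil {
				if errors.Is(err, io.EOF) {
					readErr <- nil
				} else {
					readErr <- err
					cancel()
				}
				return
			}
		}
	}()

	select {
	case <-ctx.Done():
		// Either the caller canceled or the reader goroutine hit an error.
		var rerr error
		select {
		case rerr = <-readErr:
		default:
		}
		return 0, errors.Join(ctx.Err(), rerr)
	case result := <-ch:
		if result.Err != nil {
			return 0, fmt.Errorf("failed to copy data: %w", result.Err)
		}
		// The producer finished writing, so the reader reaches EOF shortly.
		if err := <-readErr; err != nil {
			return 0, err
		}
		return result.RowCount, nil
	}
}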