diff --git a/backend/pipe.go b/backend/pipe.go new file mode 100644 index 00000000..e69de29b diff --git a/pgserver/connection_data.go b/pgserver/connection_data.go index 18b6d98c..4d001579 100644 --- a/pgserver/connection_data.go +++ b/pgserver/connection_data.go @@ -18,6 +18,7 @@ import ( "fmt" "github.com/cockroachdb/cockroachdb-parser/pkg/sql/sem/tree" + "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/vitess/go/vt/proto/query" "github.com/jackc/pgx/v5/pgproto3" "github.com/lib/pq/oid" @@ -66,6 +67,8 @@ type copyFromStdinState struct { // node is used to look at what parameters were specified, such as which table to load data into, file format, // delimiters, etc. copyFromStdinNode *tree.CopyFrom + // targetTable stores the targetTable that the data is being loaded into. + targetTable sql.InsertableTable // dataLoader is the implementation of DataLoader that is used to load each individual CopyData chunk into the // target table. dataLoader DataLoader diff --git a/pgserver/connection_handler.go b/pgserver/connection_handler.go index 0f8b425d..ad8e337d 100644 --- a/pgserver/connection_handler.go +++ b/pgserver/connection_handler.go @@ -27,16 +27,17 @@ import ( "runtime/debug" "slices" "strings" + "sync/atomic" "github.com/cockroachdb/cockroachdb-parser/pkg/sql/parser" "github.com/cockroachdb/cockroachdb-parser/pkg/sql/sem/tree" gms "github.com/dolthub/go-mysql-server" "github.com/dolthub/go-mysql-server/server" - "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/vitess/go/mysql" "github.com/jackc/pgx/v5/pgproto3" "github.com/jackc/pgx/v5/pgtype" "github.com/sirupsen/logrus" + // Import the new datawriter package ) // ConnectionHandler is responsible for the entire lifecycle of a user connection: receiving messages they send, @@ -348,7 +349,7 @@ func (h *ConnectionHandler) receiveMessage() (bool, error) { // handleMessages processes the message provided and returns status flags indicating what the connection should do next. // If the |stop| response parameter is true, it indicates that the connection should be closed by the caller. If the // |endOfMessages| response parameter is true, it indicates that no more messages are expected for the current operation -// and a READY FOR QUERY message should be sent back to the client, so it can send the next query. +// and a READY FOR QUERY message should be sent back to the client so it can send the next query. func (h *ConnectionHandler) handleMessage(msg pgproto3.Message) (stop, endOfMessages bool, err error) { logrus.Tracef("Handling message: %T", msg) switch message := msg.(type) { @@ -448,11 +449,12 @@ func (h *ConnectionHandler) handleQueryOutsideEngine(query ConvertedQuery) (hand if stmt.Stdin { return true, false, h.handleCopyFromStdinQuery(query, stmt, h.Conn()) } + case *tree.CopyTo: + return true, true, h.handleCopyToStdout(query, stmt) } return false, true, nil } -// handleParse handles a parse message, returning any error that occurs func (h *ConnectionHandler) handleParse(message *pgproto3.Parse) error { h.waitForSync = true @@ -662,20 +664,9 @@ func (h *ConnectionHandler) handleCopyDataHelper(message *pgproto3.CopyData) (st if copyFrom == nil { return false, false, fmt.Errorf("no COPY FROM STDIN node found") } - - // TODO: It would be better to get the table from the copyFromStdinNode – not by calling core.GetSqlTableFromContext - schemaName := copyFrom.Table.Schema() - tableName := copyFrom.Table.Table() - table, err := GetSqlTableFromContext(sqlCtx, schemaName, tableName) - if err != nil { - return false, true, err - } + table := h.copyFromStdinState.targetTable if table == nil { - return false, true, fmt.Errorf(`relation "%s" does not exist`, tableName) - } - insertableTable, ok := table.(sql.InsertableTable) - if !ok { - return false, true, fmt.Errorf(`table "%s" is read-only`, tableName) + return false, true, fmt.Errorf("no target table found") } switch copyFrom.Options.CopyFormat { @@ -693,12 +684,15 @@ func (h *ConnectionHandler) handleCopyDataHelper(message *pgproto3.CopyData) (st } fallthrough case tree.CopyFormatCSV: - dataLoader, err = NewCsvDataLoader(sqlCtx, h.duckHandler, schemaName, insertableTable, copyFrom.Columns, ©From.Options) + dataLoader, err = NewCsvDataLoader( + sqlCtx, h.duckHandler, + copyFrom.Table.Schema(), table, copyFrom.Columns, + ©From.Options, + ) case tree.CopyFormatBinary: err = fmt.Errorf("BINARY format is not supported for COPY FROM") default: - err = fmt.Errorf("unknown format specified for COPY FROM: %v", - copyFrom.Options.CopyFormat) + err = fmt.Errorf("unknown format specified for COPY FROM: %v", copyFrom.Options.CopyFormat) } if err != nil { @@ -1097,12 +1091,14 @@ func (h *ConnectionHandler) handleCopyFromStdinQuery(query ConvertedQuery, copyF } sqlCtx.SetLogger(sqlCtx.GetLogger().WithField("query", query.String)) - if err := ValidateCopyFrom(copyFrom, sqlCtx); err != nil { + table, err := ValidateCopyFrom(copyFrom, sqlCtx) + if err != nil { return err } h.copyFromStdinState = ©FromStdinState{ copyFromStdinNode: copyFrom, + targetTable: table, } return h.send(&pgproto3.CopyInResponse{ @@ -1140,3 +1136,90 @@ func returnsRow(tag string) bool { return false } } + +func (h *ConnectionHandler) handleCopyToStdout(query ConvertedQuery, copyTo *tree.CopyTo) error { + ctx, err := h.duckHandler.NewContext(context.Background(), h.mysqlConn, query.String) + if err != nil { + return err + } + ctx.SetLogger(ctx.GetLogger().WithField("query", query.String)) + + table, err := ValidateCopyTo(copyTo, ctx) + if err != nil { + return err + } + + var stmt string + if copyTo.Statement != nil { + stmt = copyTo.Statement.String() + } + dataWriter, err := NewDataWriter(ctx, h.duckHandler, table, copyTo.Table, copyTo.Columns, stmt, ©To.Options) + if err != nil { + return err + } + defer dataWriter.Cancel() + + // Send CopyOutResponse to the client + copyOutResponse := &pgproto3.CopyOutResponse{ + OverallFormat: 0, // 0 for text format + } + if err := h.send(copyOutResponse); err != nil { + return err + } + + // Create a channel to receive the result from the goroutine + type copyResult struct { + rowCount int + err error + } + done := make(chan copyResult, 1) + + pipe, ch, err := dataWriter.Start() + if err != nil { + return err + } + var sendErr atomic.Value + go func() { + defer pipe.Close() + defer close(done) + buf := make([]byte, 1<<20) // 1MB buffer + for { + n, err := pipe.Read(buf) + if n > 0 { + copyData := &pgproto3.CopyData{ + Data: buf[:n], + } + if err := h.send(copyData); err != nil { + sendErr.Store(err) + return + } + } + if err != nil { + if err == io.EOF { + break + } + sendErr.Store(err) + return + } + } + }() + + result := <-ch + if result.Err != nil { + return fmt.Errorf("failed to copy data: %w", result.Err) + } + + if err := sendErr.Load(); err != nil { + return err.(error) + } + + // After data is sent and CopyToPipe is finished without errors, send CopyDone + if err := h.send(&pgproto3.CopyDone{}); err != nil { + return err + } + + // Send CommandComplete with the number of rows copied + return h.send(&pgproto3.CommandComplete{ + CommandTag: []byte(fmt.Sprintf("COPY %d", result.RowCount)), + }) +} diff --git a/pgserver/dataloader.go b/pgserver/dataloader.go index fef7750b..5202c375 100644 --- a/pgserver/dataloader.go +++ b/pgserver/dataloader.go @@ -63,28 +63,28 @@ type CsvDataLoader struct { var _ DataLoader = (*CsvDataLoader)(nil) -func NewCsvDataLoader(sqlCtx *sql.Context, handler *DuckHandler, schema string, table sql.InsertableTable, columns tree.NameList, options *tree.CopyOptions) (DataLoader, error) { +func NewCsvDataLoader(ctx *sql.Context, handler *DuckHandler, schema string, table sql.InsertableTable, columns tree.NameList, options *tree.CopyOptions) (DataLoader, error) { duckBuilder := handler.e.Analyzer.ExecBuilder.(*backend.DuckBuilder) dataDir := duckBuilder.Provider().DataDir() // Create the FIFO pipe - pipeDir := filepath.Join(dataDir, "pipes", "load-data") + pipeDir := filepath.Join(dataDir, "pipes", "pg-copy-from") if err := os.MkdirAll(pipeDir, 0755); err != nil { return nil, err } - pipeName := strconv.Itoa(int(sqlCtx.ID())) + ".pipe" + pipeName := strconv.Itoa(int(ctx.ID())) + ".pipe" pipePath := filepath.Join(pipeDir, pipeName) - sqlCtx.GetLogger().Traceln("Creating FIFO pipe for COPY operation:", pipePath) + ctx.GetLogger().Traceln("Creating FIFO pipe for COPY FROM operation:", pipePath) if err := syscall.Mkfifo(pipePath, 0600); err != nil { return nil, err } // Create cancelable context - childCtx, cancel := context.WithCancel(sqlCtx) - sqlCtx.Context = childCtx + childCtx, cancel := context.WithCancel(ctx) + ctx.Context = childCtx loader := &CsvDataLoader{ - ctx: sqlCtx, + ctx: ctx, cancel: cancel, schema: schema, table: table, @@ -103,7 +103,7 @@ func NewCsvDataLoader(sqlCtx *sql.Context, handler *DuckHandler, schema string, // Open the pipe for writing. // This operation will block until the reader opens the pipe for reading. - pipe, err := os.OpenFile(pipePath, os.O_WRONLY, 0600) + pipe, err := os.OpenFile(pipePath, os.O_WRONLY, os.ModeNamedPipe) if err != nil { return nil, err } diff --git a/pgserver/datawriter.go b/pgserver/datawriter.go new file mode 100644 index 00000000..bee656f9 --- /dev/null +++ b/pgserver/datawriter.go @@ -0,0 +1,118 @@ +package pgserver + +import ( + "context" + "fmt" + "os" + "path/filepath" + "strconv" + "syscall" + + "github.com/apecloud/myduckserver/adapter" + "github.com/apecloud/myduckserver/backend" + "github.com/cockroachdb/cockroachdb-parser/pkg/sql/sem/tree" + "github.com/dolthub/go-mysql-server/sql" +) + +type DataWriter struct { + ctx *sql.Context + cancel context.CancelFunc + duckSQL string + options *tree.CopyOptions + pipePath string +} + +func NewDataWriter( + ctx *sql.Context, + handler *DuckHandler, + table sql.Table, tableName tree.TableName, columns tree.NameList, + query string, + options *tree.CopyOptions, +) (*DataWriter, error) { + // https://www.postgresql.org/docs/current/sql-copy.html + // https://duckdb.org/docs/sql/statements/copy.html#csv-options + var format string + switch options.CopyFormat { + case tree.CopyFormatText: + format = `FORMAT CSV, DELIMITER '\t', QUOTE '', ESCAPE '', NULLSTR '\N'` + case tree.CopyFormatCSV: + format = `FORMAT CSV` + case tree.CopyFormatBinary: + return nil, fmt.Errorf("BINARY format is not supported for COPY TO") + } + + var source string + if table != nil { + source = tableName.FQString() + if columns != nil { + source += "(" + columns.String() + ")" + } + } else { + source = "(" + query + ")" + } + + duckBuilder := handler.e.Analyzer.ExecBuilder.(*backend.DuckBuilder) + dataDir := duckBuilder.Provider().DataDir() + + // Create the FIFO pipe + pipeDir := filepath.Join(dataDir, "pipes", "pg-copy-to") + if err := os.MkdirAll(pipeDir, 0755); err != nil { + return nil, err + } + pipeName := strconv.Itoa(int(ctx.ID())) + ".pipe" + pipePath := filepath.Join(pipeDir, pipeName) + ctx.GetLogger().Traceln("Creating FIFO pipe for COPY TO operation:", pipePath) + if err := syscall.Mkfifo(pipePath, 0600); err != nil { + return nil, err + } + + // Create cancelable context + childCtx, cancel := context.WithCancel(ctx) + ctx.Context = childCtx + + // Initialize DataWriter + writer := &DataWriter{ + ctx: ctx, + cancel: cancel, + duckSQL: fmt.Sprintf("COPY %s TO '%s' (%s)", source, pipePath, format), + options: options, + pipePath: pipePath, + rowCount: make(chan int64, 1), + } + + return writer, nil +} + +type copyToResult struct { + RowCount int64 + Err error +} + +func (dw *DataWriter) Start() (*os.File, chan int64, error) { + // Open the pipe for reading. + pipe, err := os.OpenFile(dw.pipePath, os.O_RDONLY, os.ModeNamedPipe) + if err != nil { + return nil, nil, fmt.Errorf("failed to open pipe for reading: %w", err) + } + + go func() { + defer dw.cancel() + defer pipe.Close() + defer os.Remove(dw.pipePath) + defer close(dw.rowCount) + // This operation will block until the reader opens the pipe for reading. + result, err := adapter.ExecCatalog(dw.ctx, dw.duckSQL) + if err != nil { + dw.err.Store(&err) + return + } + affected, _ := result.RowsAffected() + dw.rowCount <- affected + }() + + return pipe, dw.rowCount, nil +} + +func (dw *DataWriter) Cancel() { + dw.cancel() +} diff --git a/pgserver/validate.go b/pgserver/validate.go index 5e8202f4..e61c4e4d 100644 --- a/pgserver/validate.go +++ b/pgserver/validate.go @@ -16,40 +16,42 @@ package pgserver import ( "fmt" - "strings" "github.com/cockroachdb/cockroachdb-parser/pkg/sql/sem/tree" "github.com/dolthub/go-mysql-server/sql" ) -// Validate returns an error if the CopyFrom node is invalid, for example if it contains columns that -// are not in the table schema. -func ValidateCopyFrom(cf *tree.CopyFrom, ctx *sql.Context) error { +// ValidateCopyFrom returns an error if the CopyFrom node is invalid. +func ValidateCopyFrom(cf *tree.CopyFrom, ctx *sql.Context) (sql.InsertableTable, error) { table, err := GetSqlTableFromContext(ctx, cf.Table.Schema(), cf.Table.Table()) if err != nil { - return err + return nil, err } if table == nil { - return fmt.Errorf(`relation "%s" does not exist`, cf.Table.Table()) + return nil, fmt.Errorf(`relation "%s" does not exist`, cf.Table.Table()) } - if _, ok := table.(sql.InsertableTable); !ok { - return fmt.Errorf(`table "%s" is read-only`, cf.Table.Table()) + if it, ok := table.(sql.InsertableTable); !ok { + return nil, fmt.Errorf(`table "%s" is read-only`, cf.Table.Table()) + } else { + return it, nil } +} - // If a set of columns was explicitly specified, validate them - if len(cf.Columns) > 0 { - if len(table.Schema()) != len(cf.Columns) { - return fmt.Errorf("invalid column name list for table %s: %v", table.Name(), cf.Columns) - } - - for i, col := range table.Schema() { - name := cf.Columns[i] - nameString := strings.Trim(name.String(), `"`) - if nameString != col.Name { - return fmt.Errorf("invalid column name list for table %s: %v", table.Name(), cf.Columns) - } +// ValidateCopyTo returns an error if the CopyTo node is invalid, for example if it contains columns that +// are not in the table schema. +func ValidateCopyTo(ct *tree.CopyTo, ctx *sql.Context) (sql.Table, error) { + if ct.Table.Table() == "" { + if ct.Statement == nil { + return nil, fmt.Errorf("no table specified") } + return nil, nil } - - return nil + table, err := GetSqlTableFromContext(ctx, ct.Table.Schema(), ct.Table.Table()) + if err != nil { + return nil, err + } + if table == nil { + return nil, fmt.Errorf(`relation "%s" does not exist`, ct.Table.Table()) + } + return table, nil }