From 60e8972fc88b2491ad30b0cba63164eab4f93851 Mon Sep 17 00:00:00 2001 From: Fan Yang Date: Fri, 22 Nov 2024 21:17:27 +0800 Subject: [PATCH] feat: arrow in, arrow out --- .github/workflows/psql.yml | 2 +- .gitignore | 1 + pgserver/arrowloader.go | 147 ++++++++++++++ pgserver/arrowwriter.go | 129 ++++++++++++ pgserver/connection_data.go | 4 + pgserver/connection_handler.go | 88 +++++++-- pgserver/copy.go | 197 +++++++++++++++++-- pgserver/copy_test.go | 349 +++++++++++++++++++++++++++++++++ pgserver/dataloader.go | 199 +++++++++++-------- pgserver/datawriter.go | 33 ++-- pgtest/psql/copy/arrow.sql | 19 ++ 11 files changed, 1033 insertions(+), 135 deletions(-) create mode 100644 pgserver/arrowloader.go create mode 100644 pgserver/arrowwriter.go create mode 100644 pgserver/copy_test.go create mode 100644 pgtest/psql/copy/arrow.sql diff --git a/.github/workflows/psql.yml b/.github/workflows/psql.yml index 0b567241..6b822101 100644 --- a/.github/workflows/psql.yml +++ b/.github/workflows/psql.yml @@ -27,7 +27,7 @@ jobs: run: | go get . - pip3 install "sqlglot[rs]" + pip3 install "sqlglot[rs]" pyarrow pandas curl -LJO https://github.com/duckdb/duckdb/releases/download/v1.1.3/duckdb_cli-linux-amd64.zip unzip duckdb_cli-linux-amd64.zip diff --git a/.gitignore b/.gitignore index 2416eb99..09f18562 100644 --- a/.gitignore +++ b/.gitignore @@ -14,3 +14,4 @@ __debug_* .DS_Store *.csv *.parquet +*.arrow diff --git a/pgserver/arrowloader.go b/pgserver/arrowloader.go new file mode 100644 index 00000000..ef083215 --- /dev/null +++ b/pgserver/arrowloader.go @@ -0,0 +1,147 @@ +package pgserver + +import ( + "context" + "os" + "strconv" + "strings" + + "github.com/apache/arrow-go/v18/arrow/ipc" + "github.com/apecloud/myduckserver/adapter" + "github.com/apecloud/myduckserver/backend" + "github.com/cockroachdb/cockroachdb-parser/pkg/sql/sem/tree" + "github.com/dolthub/go-mysql-server/sql" + "github.com/marcboeker/go-duckdb" +) + +type ArrowDataLoader struct { + PipeDataLoader + arrowName string + options string +} + +var _ DataLoader = (*ArrowDataLoader)(nil) + +func NewArrowDataLoader(ctx *sql.Context, handler *DuckHandler, schema string, table sql.InsertableTable, columns tree.NameList, options string) (DataLoader, error) { + // Create the FIFO pipe + duckBuilder := handler.e.Analyzer.ExecBuilder.(*backend.DuckBuilder) + pipePath, err := duckBuilder.CreatePipe(ctx, "pg-from-arrow") + if err != nil { + return nil, err + } + arrowName := "__sys_copy_from_arrow_" + strconv.Itoa(int(ctx.ID())) + "__" + + // Create cancelable context + childCtx, cancel := context.WithCancel(ctx) + ctx.Context = childCtx + + loader := &ArrowDataLoader{ + PipeDataLoader: PipeDataLoader{ + ctx: ctx, + cancel: cancel, + schema: schema, + table: table, + columns: columns, + pipePath: pipePath, + rowCount: make(chan int64, 1), + logger: ctx.GetLogger(), + }, + arrowName: arrowName, + options: options, + } + loader.read = func() { + loader.executeInsert(loader.buildSQL(), pipePath) + } + + return loader, nil +} + +// buildSQL builds the DuckDB INSERT statement. 
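+// For example (names here are illustrative), with schema "db", table "t",
+// columns (a, b), and connection ID 42, the generated statement would be:
+//
+//	INSERT INTO db.t (a, b) FROM __sys_copy_from_arrow_42__
+//
+// which uses DuckDB's FROM-first syntax to read from the registered Arrow view.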
+func (loader *ArrowDataLoader) buildSQL() string { + var b strings.Builder + b.Grow(256) + + b.WriteString("INSERT INTO ") + if loader.schema != "" { + b.WriteString(loader.schema) + b.WriteString(".") + } + b.WriteString(loader.table.Name()) + + if len(loader.columns) > 0 { + b.WriteString(" (") + b.WriteString(loader.columns.String()) + b.WriteString(")") + } + + b.WriteString(" FROM ") + b.WriteString(loader.arrowName) + + return b.String() +} + +func (loader *ArrowDataLoader) executeInsert(sql string, pipePath string) { + defer close(loader.rowCount) + + // Open the pipe for reading. + loader.logger.Debugf("Opening pipe for reading: %s", pipePath) + pipe, err := os.OpenFile(pipePath, os.O_RDONLY, os.ModeNamedPipe) + if err != nil { + loader.err.Store(&err) + // Open the pipe once to unblock the writer + pipe, _ = os.OpenFile(pipePath, os.O_RDONLY, os.ModeNamedPipe) + loader.errPipe.Store(pipe) + return + } + + // Create an Arrow IPC reader from the pipe. + loader.logger.Debugf("Creating Arrow IPC reader from pipe: %s", pipePath) + arrowReader, err := ipc.NewReader(pipe) + if err != nil { + loader.err.Store(&err) + return + } + defer arrowReader.Release() + + conn, err := adapter.GetConn(loader.ctx) + if err != nil { + loader.err.Store(&err) + return + } + + // Register the Arrow IPC reader to DuckDB. + loader.logger.Debugf("Registering Arrow IPC reader into DuckDB: %s", loader.arrowName) + var release func() + if err := conn.Raw(func(driverConn any) error { + conn := driverConn.(*duckdb.Conn) + arrow, err := duckdb.NewArrowFromConn(conn) + if err != nil { + return err + } + + release, err = arrow.RegisterView(arrowReader, loader.arrowName) + return err + }); err != nil { + loader.err.Store(&err) + return + } + defer release() + + // Execute the INSERT statement. + // This will block until the reader has finished reading the data. 
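+	// The registered view is backed by the IPC stream coming from the pipe, so
+	// executing the INSERT drains the stream one record batch at a time; the
+	// row count only becomes available once the stream reaches EOF.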
+ loader.logger.Debugln("Executing SQL:", sql) + result, err := conn.ExecContext(loader.ctx, sql) + if err != nil { + loader.err.Store(&err) + return + } + + rows, err := result.RowsAffected() + if err != nil { + loader.err.Store(&err) + return + } + + loader.logger.Debugf("Inserted %d rows", rows) + loader.rowCount <- rows +} diff --git a/pgserver/arrowwriter.go b/pgserver/arrowwriter.go new file mode 100644 index 00000000..c705d531 --- /dev/null +++ b/pgserver/arrowwriter.go @@ -0,0 +1,129 @@ +package pgserver + +import ( + "os" + "strings" + + "github.com/apache/arrow-go/v18/arrow/ipc" + "github.com/apecloud/myduckserver/adapter" + "github.com/apecloud/myduckserver/backend" + "github.com/apecloud/myduckserver/catalog" + "github.com/cockroachdb/cockroachdb-parser/pkg/sql/sem/tree" + "github.com/dolthub/go-mysql-server/sql" + "github.com/marcboeker/go-duckdb" +) + +type ArrowWriter struct { + ctx *sql.Context + duckSQL string + pipePath string + rawOptions string +} + +func NewArrowWriter( + ctx *sql.Context, + handler *DuckHandler, + schema string, table sql.Table, columns tree.NameList, + query string, + rawOptions string, +) (*ArrowWriter, error) { + // Create the FIFO pipe + db := handler.e.Analyzer.ExecBuilder.(*backend.DuckBuilder) + pipePath, err := db.CreatePipe(ctx, "pg-to-arrow") + if err != nil { + return nil, err + } + + var builder strings.Builder + builder.Grow(128) + + if table != nil { + // https://duckdb.org/docs/sql/query_syntax/from.html#from-first-syntax + // FROM table_name [ SELECT column_list ] + builder.WriteString("FROM ") + if schema != "" { + builder.WriteString(catalog.QuoteIdentifierANSI(schema)) + builder.WriteString(".") + } + builder.WriteString(catalog.QuoteIdentifierANSI(table.Name())) + if columns != nil { + builder.WriteString(" SELECT ") + builder.WriteString(columns.String()) + } + } else { + builder.WriteString(query) + } + + return &ArrowWriter{ + ctx: ctx, + duckSQL: builder.String(), + pipePath: pipePath, + rawOptions: rawOptions, // TODO(fan): parse rawOptions + }, nil +} + +func (dw *ArrowWriter) Start() (string, chan CopyToResult, error) { + // Execute the statement in a separate goroutine. + ch := make(chan CopyToResult, 1) + go func() { + defer os.Remove(dw.pipePath) + defer close(ch) + + dw.ctx.GetLogger().Tracef("Executing statement via Arrow interface: %s", dw.duckSQL) + conn, err := adapter.GetConn(dw.ctx) + if err != nil { + ch <- CopyToResult{Err: err} + return + } + + // Open the pipe for writing. + // This operation will block until the reader opens the pipe for reading. + pipe, err := os.OpenFile(dw.pipePath, os.O_WRONLY, os.ModeNamedPipe) + if err != nil { + ch <- CopyToResult{Err: err} + return + } + defer pipe.Close() + + rowCount := int64(0) + + if err := conn.Raw(func(driverConn any) error { + conn := driverConn.(*duckdb.Conn) + arrow, err := duckdb.NewArrowFromConn(conn) + if err != nil { + return err + } + + // TODO(fan): Currently, this API materializes the entire result set in memory. + // We should consider modifying the API to allow streaming the result set. 
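+			// Until that is addressed, memory usage of COPY TO (FORMAT ARROW)
+			// grows with the size of the result set, even though the batches
+			// are relayed to the pipe incrementally below.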
+			recordReader, err := arrow.QueryContext(dw.ctx, dw.duckSQL)
+			if err != nil {
+				return err
+			}
+			defer recordReader.Release()
+
+			writer := ipc.NewWriter(pipe, ipc.WithSchema(recordReader.Schema()))
+			defer writer.Close()
+
+			for recordReader.Next() {
+				record := recordReader.Record()
+				rowCount += record.NumRows()
+				if err := writer.Write(record); err != nil {
+					return err
+				}
+			}
+			return recordReader.Err()
+		}); err != nil {
+			ch <- CopyToResult{Err: err}
+			return
+		}
+
+		ch <- CopyToResult{RowCount: rowCount}
+	}()
+
+	return dw.pipePath, ch, nil
+}
+
+func (dw *ArrowWriter) Close() {
+	os.Remove(dw.pipePath)
+}
diff --git a/pgserver/connection_data.go b/pgserver/connection_data.go
index 4d001579..c293b417 100644
--- a/pgserver/connection_data.go
+++ b/pgserver/connection_data.go
@@ -69,6 +69,10 @@ type copyFromStdinState struct {
 	copyFromStdinNode *tree.CopyFrom
 	// targetTable stores the table that the data is being loaded into.
 	targetTable sql.InsertableTable
+
+	// For non-PG-parsable COPY FROM
+	rawOptions string
+
 	// dataLoader is the implementation of DataLoader that is used to load each individual CopyData chunk into the
 	// target table.
 	dataLoader DataLoader
diff --git a/pgserver/connection_handler.go b/pgserver/connection_handler.go
index 53f53da5..840829e9 100644
--- a/pgserver/connection_handler.go
+++ b/pgserver/connection_handler.go
@@ -15,7 +15,6 @@
 package pgserver
 
 import (
-	"bufio"
 	"bytes"
 	"context"
 	"crypto/tls"
@@ -478,6 +477,8 @@ func (h *ConnectionHandler) handleQuery(message *pgproto3.Query) (endOfMessages
 	handled, endOfMessages, err = h.handleQueryOutsideEngine(query)
 	if handled {
 		return endOfMessages, err
+	} else if err != nil {
+		h.logger.Warnf("Failed to handle query %v outside engine: %v", query, err)
 	}
 
 	return true, h.query(query)
@@ -600,15 +601,35 @@ func (h *ConnectionHandler) handleQueryOutsideEngine(query ConvertedQuery) (hand
 		// We send endOfMessages=false since the server will be in COPY DATA mode and won't
 		// be ready for more queries until COPY DATA mode is completed.
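+		// (Wire flow: the server sends CopyInResponse, the client streams CopyData
+		// messages, and the exchange ends with CopyDone or CopyFail, after which
+		// the server reports CommandComplete.)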
if stmt.Stdin { - return true, false, h.handleCopyFromStdinQuery(query, stmt, h.Conn()) + return true, false, h.handleCopyFromStdinQuery(query, stmt, "") } case *tree.CopyTo: - return true, true, h.handleCopyToStdout(query, stmt, "" /* unused */, tree.CopyFormatBinary, "") + return true, true, h.handleCopyToStdout(query, stmt, "" /* unused */, stmt.Options.CopyFormat, "") } if query.StatementTag == "COPY" { - if subquery, format, options, ok := ParseCopy(query.String); ok { - return true, true, h.handleCopyToStdout(query, nil, subquery, format, options) + if target, format, options, ok := ParseCopyFrom(query.String); ok { + stmt, err := parser.ParseOne("COPY " + target + " FROM STDIN") + if err != nil { + return false, true, err + } + copyFrom := stmt.AST.(*tree.CopyFrom) + copyFrom.Options.CopyFormat = format + return true, false, h.handleCopyFromStdinQuery(query, copyFrom, options) + } + if subquery, format, options, ok := ParseCopyTo(query.String); ok { + if strings.HasPrefix(subquery, "(") && strings.HasSuffix(subquery, ")") { + // subquery may be richer than Postgres supports, so we just pass it as a string + return true, true, h.handleCopyToStdout(query, nil, subquery, format, options) + } + // subquery is "table [(column_list)]", so we can parse it and pass the AST + stmt, err := parser.ParseOne("COPY " + subquery + " TO STDOUT") + if err != nil { + return false, true, err + } + copyTo := stmt.AST.(*tree.CopyTo) + copyTo.Options.CopyFormat = format + return true, true, h.handleCopyToStdout(query, copyTo, "", format, options) } } @@ -829,8 +850,15 @@ func (h *ConnectionHandler) handleCopyDataHelper(message *pgproto3.CopyData) (st if table == nil { return false, true, fmt.Errorf("no target table found") } + rawOptions := h.copyFromStdinState.rawOptions switch copyFrom.Options.CopyFormat { + case CopyFormatArrow: + dataLoader, err = NewArrowDataLoader( + sqlCtx, h.duckHandler, + copyFrom.Table.Schema(), table, copyFrom.Columns, + rawOptions, + ) case tree.CopyFormatText: // Remove trailing backslash, comma and newline characters from the data if bytes.HasSuffix(message.Data, []byte{'\n'}) { @@ -848,6 +876,7 @@ func (h *ConnectionHandler) handleCopyDataHelper(message *pgproto3.CopyData) (st sqlCtx, h.duckHandler, copyFrom.Table.Schema(), table, copyFrom.Columns, ©From.Options, + rawOptions, ) case tree.CopyFormatBinary: err = fmt.Errorf("BINARY format is not supported for COPY FROM") @@ -859,12 +888,15 @@ func (h *ConnectionHandler) handleCopyDataHelper(message *pgproto3.CopyData) (st return false, false, err } + ready := dataLoader.Start() + if err, hasErr := <-ready; hasErr { + return false, false, err + } + h.copyFromStdinState.dataLoader = dataLoader } - byteReader := bytes.NewReader(message.Data) - reader := bufio.NewReader(byteReader) - if err = dataLoader.LoadChunk(sqlCtx, reader); err != nil { + if err = dataLoader.LoadChunk(sqlCtx, message.Data); err != nil { return false, false, err } @@ -1248,7 +1280,10 @@ func (h *ConnectionHandler) discardAll(query ConvertedQuery) error { // handleCopyFromStdinQuery handles the COPY FROM STDIN query at the Doltgres layer, without passing it to the engine. // COPY FROM STDIN can't be handled directly by the GMS engine, since COPY FROM STDIN relies on multiple messages sent // over the wire. 
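+// rawOptions carries the raw option list for COPY variants that the PG parser
+// cannot handle (e.g. FORMAT ARROW); it is stashed in copyFromStdinState and
+// passed on to the DataLoader created when the first CopyData message arrives.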
-func (h *ConnectionHandler) handleCopyFromStdinQuery(query ConvertedQuery, copyFrom *tree.CopyFrom, conn net.Conn) error { +func (h *ConnectionHandler) handleCopyFromStdinQuery( + query ConvertedQuery, copyFrom *tree.CopyFrom, + rawOptions string, // For non-PG-parseable COPY FROM +) error { sqlCtx, err := h.duckHandler.NewContext(context.Background(), h.mysqlConn, query.String) if err != nil { return err @@ -1263,10 +1298,19 @@ func (h *ConnectionHandler) handleCopyFromStdinQuery(query ConvertedQuery, copyF h.copyFromStdinState = ©FromStdinState{ copyFromStdinNode: copyFrom, targetTable: table, + rawOptions: rawOptions, + } + + var format byte + switch copyFrom.Options.CopyFormat { + case tree.CopyFormatText, tree.CopyFormatCSV, CopyFormatJSON: + format = 0 // text format + default: + format = 1 // binary format } return h.send(&pgproto3.CopyInResponse{ - OverallFormat: 0, + OverallFormat: format, }) } @@ -1342,12 +1386,24 @@ func (h *ConnectionHandler) handleCopyToStdout(query ConvertedQuery, copyTo *tre } } - writer, err := NewDataWriter( - ctx, h.duckHandler, - schema, table, columns, - stmt, - options, rawOptions, - ) + var writer DataWriter + + switch format { + case CopyFormatArrow: + writer, err = NewArrowWriter( + ctx, h.duckHandler, + schema, table, columns, + stmt, + rawOptions, + ) + default: + writer, err = NewDuckDataWriter( + ctx, h.duckHandler, + schema, table, columns, + stmt, + options, rawOptions, + ) + } if err != nil { return err } diff --git a/pgserver/copy.go b/pgserver/copy.go index 9911df76..129b8704 100644 --- a/pgserver/copy.go +++ b/pgserver/copy.go @@ -1,8 +1,11 @@ package pgserver import ( + "fmt" "regexp" + "strconv" "strings" + "unicode" "github.com/cockroachdb/cockroachdb-parser/pkg/sql/sem/tree" "github.com/dolthub/go-mysql-server/sql" @@ -11,17 +14,41 @@ import ( const ( CopyFormatParquet = tree.CopyFormatCSV + 1 CopyFormatJSON = tree.CopyFormatCSV + 2 + CopyFormatArrow = tree.CopyFormatCSV + 3 ) var ( - // We are supporting the parquet/... formats for COPY TO, but + // We are supporting the parquet|json|arrow formats for COPY TO, but // COPY ... TO STDOUT [WITH] (FORMAT PARQUET, OPT1 v1, OPT2, OPT3 v3, ...) + // cannot be parsed by the PG parser. // Let's match them with regex and extract the ... part. - // Update regex to capture FORMAT and other options reCopyToFormat = regexp.MustCompile(`(?i)^COPY\s+(.*?)\s+TO\s+STDOUT(?:\s+(?:WITH\s*)?\(\s*(?:FORMAT\s+(\w+)\s*,?\s*)?(.*?)\s*\))?$`) + // Also for COPY ... FROM STDIN [WITH] (FORMAT PARQUET, OPT1 v1, OPT2, OPT3 v3, ...) + reCopyFromFormat = regexp.MustCompile(`(?i)^COPY\s+(.*?)\s+FROM\s+STDIN(?:\s+(?:WITH\s*)?\(\s*(?:FORMAT\s+(\w+)\s*,?\s*)?(.*?)\s*\))?$`) ) -func ParseCopy(stmt string) (query string, format tree.CopyFormat, options string, ok bool) { +func ParseFormat(s string) (format tree.CopyFormat, ok bool) { + switch strings.ToUpper(s) { + case "PARQUET": + format = CopyFormatParquet + case "JSON": + format = CopyFormatJSON + case "ARROW": + format = CopyFormatArrow + case "CSV": + format = tree.CopyFormatCSV + case "BINARY": + format = tree.CopyFormatBinary + case "", "TEXT": + format = tree.CopyFormatText + default: + return 0, false + } + return format, true +} + +// ParseCopyTo parses a COPY TO statement and returns the query, format, and options. 
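+// For example (an illustrative statement), given
+//
+//	COPY (SELECT * FROM t) TO STDOUT (FORMAT PARQUET, COMPRESSION 'zstd')
+//
+// it returns the query "(SELECT * FROM t)", CopyFormatParquet, and the
+// options string "COMPRESSION 'zstd'".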
+func ParseCopyTo(stmt string) (query string, format tree.CopyFormat, options string, ok bool) {
 	stmt = RemoveComments(stmt)
 	stmt = sql.RemoveSpaceAndDelimiter(stmt, ';')
 	m := reCopyToFormat.FindStringSubmatch(stmt)
@@ -29,28 +56,156 @@ func ParseCopy(stmt string) (query string, format tree.CopyFormat, options strin
 		return "", 0, "", false
 	}
 	query = strings.TrimSpace(m[1])
+	format, ok = ParseFormat(strings.TrimSpace(m[2]))
+	options = strings.TrimSpace(m[3])
+	return
+}
-	var formatStr string
-	if m[2] != "" {
-		formatStr = strings.ToUpper(m[2])
-	} else {
-		formatStr = "TEXT"
+// ParseCopyFrom parses a COPY FROM statement and returns the target table, format, and options.
+func ParseCopyFrom(stmt string) (target string, format tree.CopyFormat, options string, ok bool) {
+	stmt = RemoveComments(stmt)
+	stmt = sql.RemoveSpaceAndDelimiter(stmt, ';')
+	m := reCopyFromFormat.FindStringSubmatch(stmt)
+	if m == nil {
+		return "", 0, "", false
 	}
-
+	target = strings.TrimSpace(m[1])
+	format, ok = ParseFormat(strings.TrimSpace(m[2]))
 	options = strings.TrimSpace(m[3])
+	return
+}
-	switch formatStr {
-	case "PARQUET":
-		format = CopyFormatParquet
-	case "JSON":
-		format = CopyFormatJSON
-	case "CSV":
-		format = tree.CopyFormatCSV
-	case "BINARY":
-		format = tree.CopyFormatBinary
-	case "", "TEXT":
-		format = tree.CopyFormatText
+type OptionValueType uint8
+
+const (
+	OptionValueTypeBool   OptionValueType = iota // bool
+	OptionValueTypeInt                           // int
+	OptionValueTypeFloat                         // float64
+	OptionValueTypeString                        // string
+)
+
+// ParseCopyOptions parses the options string and returns the parsed options as a map.
+// The options string is a comma-separated list of key-value pairs: `OPT1 1, OPT2, OPT3 'v3', OPT4 E'v4', ...`.
+// The allowed map specifies the allowed options and their value types. Its keys are the option names in uppercase.
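+// For example (option names here are illustrative), parsing
+//
+//	HEADER, DELIMITER E'\t', SAMPLE_SIZE 1024
+//
+// with allowed types Bool, String, and Int respectively yields
+// {"HEADER": true, "DELIMITER": "\t" (a tab), "SAMPLE_SIZE": 1024}.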
+func ParseCopyOptions(options string, allowed map[string]OptionValueType) (result map[string]any, err error) { + result = make(map[string]any) + var key, value string + inQuotes := false + expectComma := false + readingKey := true + var sb strings.Builder + + parseOption := func() error { + k := strings.TrimSpace(key) + if k == "" { + return nil + } + k = strings.ToUpper(k) + if _, ok := allowed[k]; !ok { + return fmt.Errorf("unsupported option: %s", k) + } + v := strings.TrimSpace(value) + + switch allowed[k] { + case OptionValueTypeBool: + if v == "" { + result[k] = true + } else { + val, err := strconv.ParseBool(v) + if err != nil { + return fmt.Errorf("invalid bool value for %s: %v", k, err) + } + result[k] = val + } + case OptionValueTypeInt: + val, err := strconv.Atoi(v) + if err != nil { + return fmt.Errorf("invalid int value for %s: %v", k, err) + } + result[k] = val + case OptionValueTypeFloat: + val, err := strconv.ParseFloat(v, 64) + if err != nil { + return fmt.Errorf("invalid float value for %s: %v", k, err) + } + result[k] = val + case OptionValueTypeString: + if strings.HasPrefix(v, `E'`) && strings.HasSuffix(v, `'`) { + // Remove the 'E' prefix and unescape the value + unquoted, err := strconv.Unquote(`"` + v[2:len(v)-1] + `"`) + if err != nil { + return fmt.Errorf("invalid escaped string value for %s: %v", k, err) + } + v = unquoted + } else if strings.HasPrefix(v, "'") && strings.HasSuffix(v, "'") { + // Trim the single quotes + v = v[1 : len(v)-1] + // Replace double single quotes with a single quote + v = strings.ReplaceAll(v, "''", "'") + } else { + return fmt.Errorf("invalid string value for %s: %q", k, v) + } + result[k] = v + } + key, value = "", "" + readingKey = true + return nil + } + + for i, c := range options { + if expectComma && c != ',' && !unicode.IsSpace(c) { + return nil, fmt.Errorf("expected comma before %q", options[i:]) + } + switch c { + case '\'': + inQuotes = !inQuotes + sb.WriteRune(c) + case ',': + if !inQuotes { + if readingKey { + key = sb.String() + } else { + value = sb.String() + } + if err := parseOption(); err != nil { + return nil, err + } + expectComma = false + sb.Reset() + } else { + sb.WriteRune(c) + } + default: + if unicode.IsSpace(c) { + if !inQuotes { + if sb.Len() > 0 { + if readingKey { + key = sb.String() + sb.Reset() + readingKey = false + } else { + expectComma = true + } + } + } else { + sb.WriteRune(c) + } + } else { + sb.WriteRune(c) + } + } + } + + if sb.Len() > 0 { + if readingKey { + key = sb.String() + } else { + value = sb.String() + } + if err := parseOption(); err != nil { + return nil, err + } } - return query, format, options, true + return result, nil } diff --git a/pgserver/copy_test.go b/pgserver/copy_test.go new file mode 100644 index 00000000..ae4801ad --- /dev/null +++ b/pgserver/copy_test.go @@ -0,0 +1,349 @@ +package pgserver + +import ( + "testing" +) + +func TestParseCopyOptions(t *testing.T) { + tests := []struct { + name string + options string + allowed map[string]OptionValueType + expected map[string]any + wantErr bool + }{ + { + name: "Valid options with different types", + options: "OPT1 1, OPT2, OPT3 'v3', OPT4 E'v4', OPT5 3.14", + allowed: map[string]OptionValueType{ + "OPT1": OptionValueTypeInt, + "OPT2": OptionValueTypeBool, + "OPT3": OptionValueTypeString, + "OPT4": OptionValueTypeString, + "OPT5": OptionValueTypeFloat, + }, + expected: map[string]any{ + "OPT1": 1, + "OPT2": true, + "OPT3": "v3", + "OPT4": "v4", + "OPT5": 3.14, + }, + wantErr: false, + }, + { + name: "Unsupported option", + 
options: "OPT1 1, OPT6 'unsupported'", + allowed: map[string]OptionValueType{ + "OPT1": OptionValueTypeInt, + }, + expected: nil, + wantErr: true, + }, + { + name: "Invalid int value", + options: "OPT1 'invalid'", + allowed: map[string]OptionValueType{ + "OPT1": OptionValueTypeInt, + }, + expected: nil, + wantErr: true, + }, + { + name: "Invalid float value", + options: "OPT1 'invalid'", + allowed: map[string]OptionValueType{ + "OPT1": OptionValueTypeFloat, + }, + expected: nil, + wantErr: true, + }, + { + name: "Boolean option with explicit value", + options: "OPT1 false", + allowed: map[string]OptionValueType{ + "OPT1": OptionValueTypeBool, + }, + expected: map[string]any{ + "OPT1": false, + }, + wantErr: false, + }, + { + name: "Boolean option without value", + options: "OPT1", + allowed: map[string]OptionValueType{ + "OPT1": OptionValueTypeBool, + }, + expected: map[string]any{ + "OPT1": true, + }, + wantErr: false, + }, + { + name: "String option with escaped quotes", + options: "OPT1 'value with ''escaped'' quotes'", + allowed: map[string]OptionValueType{ + "OPT1": OptionValueTypeString, + }, + expected: map[string]any{ + "OPT1": "value with 'escaped' quotes", + }, + wantErr: false, + }, + { + name: "String option with E-escaped newlines", + options: "OPT1 E'line1\\nline2'", + allowed: map[string]OptionValueType{ + "OPT1": OptionValueTypeString, + }, + expected: map[string]any{ + "OPT1": "line1\nline2", + }, + wantErr: false, + }, + { + name: "String option with escaped backslash", + options: "OPT1 E'\\\\path\\\\to\\\\file'", + allowed: map[string]OptionValueType{ + "OPT1": OptionValueTypeString, + }, + expected: map[string]any{ + "OPT1": "\\path\\to\\file", + }, + wantErr: false, + }, + { + name: "String option with unbalanced quotes", + options: "OPT1 'unbalanced", + allowed: map[string]OptionValueType{ + "OPT1": OptionValueTypeString, + }, + expected: nil, + wantErr: true, + }, + { + name: "String option with spaces", + options: "OPT1 'value with spaces'", + allowed: map[string]OptionValueType{ + "OPT1": OptionValueTypeString, + }, + expected: map[string]any{ + "OPT1": "value with spaces", + }, + wantErr: false, + }, + { + name: "String option with special characters", + options: "OPT1 'special !@#$%^&*() characters'", + allowed: map[string]OptionValueType{ + "OPT1": OptionValueTypeString, + }, + expected: map[string]any{ + "OPT1": "special !@#$%^&*() characters", + }, + wantErr: false, + }, + { + name: "String option with double backslashes", + options: "OPT1 E'\\\\\\\\'", + allowed: map[string]OptionValueType{ + "OPT1": OptionValueTypeString, + }, + expected: map[string]any{ + "OPT1": "\\\\", + }, + wantErr: false, + }, + { + name: "Mixed options with quoted and unquoted values", + options: "OPT1 'string value', OPT2 42, OPT3 E'escaped\\tvalue', OPT4 true", + allowed: map[string]OptionValueType{ + "OPT1": OptionValueTypeString, + "OPT2": OptionValueTypeInt, + "OPT3": OptionValueTypeString, + "OPT4": OptionValueTypeBool, + }, + expected: map[string]any{ + "OPT1": "string value", + "OPT2": 42, + "OPT3": "escaped\tvalue", + "OPT4": true, + }, + wantErr: false, + }, + { + name: "Option with empty string value", + options: "OPT1 ''", + allowed: map[string]OptionValueType{ + "OPT1": OptionValueTypeString, + }, + expected: map[string]any{ + "OPT1": "", + }, + wantErr: false, + }, + { + name: "Option with only spaces in quotes", + options: "OPT1 ' '", + allowed: map[string]OptionValueType{ + "OPT1": OptionValueTypeString, + }, + expected: map[string]any{ + "OPT1": " ", + }, + wantErr: 
false, + }, + { + name: "String option with escaped single quote", + options: "OPT1 'it''s a test'", + allowed: map[string]OptionValueType{ + "OPT1": OptionValueTypeString, + }, + expected: map[string]any{ + "OPT1": "it's a test", + }, + wantErr: false, + }, + { + name: "Option with missing value after E prefix", + options: "OPT1 E", + allowed: map[string]OptionValueType{ + "OPT1": OptionValueTypeString, + }, + expected: nil, + wantErr: true, + }, + { + name: "Option with invalid escape sequence", + options: "OPT1 E'\\x'", + allowed: map[string]OptionValueType{ + "OPT1": OptionValueTypeString, + }, + expected: nil, + wantErr: true, + }, + { + name: "Option with numeric value in quotes", + options: "OPT1 '123'", + allowed: map[string]OptionValueType{ + "OPT1": OptionValueTypeInt, + }, + expected: nil, + wantErr: true, + }, + { + name: "Option with float value in quotes", + options: "OPT1 '123.456'", + allowed: map[string]OptionValueType{ + "OPT1": OptionValueTypeFloat, + }, + expected: nil, + wantErr: true, + }, + { + name: "Option with multiple consecutive commas", + options: "OPT1 'value1',, OPT2 'value2'", + allowed: map[string]OptionValueType{ + "OPT1": OptionValueTypeString, + "OPT2": OptionValueTypeString, + }, + expected: map[string]any{ + "OPT1": "value1", + "OPT2": "value2", + }, + wantErr: false, + }, + { + name: "Option with missing comma between options", + options: "OPT1 'value1' OPT2 'value2'", + allowed: map[string]OptionValueType{ + "OPT1": OptionValueTypeString, + "OPT2": OptionValueTypeString, + }, + expected: nil, + wantErr: true, + }, + { + name: "Option with excess whitespace", + options: " OPT1 'value1' , OPT2 'value2' ", + allowed: map[string]OptionValueType{ + "OPT1": OptionValueTypeString, + "OPT2": OptionValueTypeString, + }, + expected: map[string]any{ + "OPT1": "value1", + "OPT2": "value2", + }, + wantErr: false, + }, + { + name: "Option with tab and newline characters", + options: "OPT1 E'line1\\nline2\\tend'", + allowed: map[string]OptionValueType{ + "OPT1": OptionValueTypeString, + }, + expected: map[string]any{ + "OPT1": "line1\nline2\tend", + }, + wantErr: false, + }, + { + name: "Option with hexadecimal escape sequence", + options: "OPT1 E'\\x41\\x42\\x43'", + allowed: map[string]OptionValueType{ + "OPT1": OptionValueTypeString, + }, + expected: map[string]any{ + "OPT1": "ABC", + }, + wantErr: false, + }, + { + name: "Option with Unicode escape sequence", + options: "OPT1 E'\\u263A'", + allowed: map[string]OptionValueType{ + "OPT1": OptionValueTypeString, + }, + expected: map[string]any{ + "OPT1": "☺", + }, + wantErr: false, + }, + { + name: "Option with null character", + options: "OPT1 E'null\\0char'", + allowed: map[string]OptionValueType{ + "OPT1": OptionValueTypeString, + }, + expected: nil, + wantErr: true, + }, + { + name: "Empty options string", + options: "", + allowed: map[string]OptionValueType{}, + expected: map[string]any{}, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := ParseCopyOptions(tt.options, tt.allowed) + if (err != nil) != tt.wantErr { + t.Errorf("ParseCopyOptions() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr { + if len(result) != len(tt.expected) { + t.Errorf("ParseCopyOptions() got = %v, want %v", result, tt.expected) + } + for k, v := range tt.expected { + if result[k] != v { + t.Errorf("ParseCopyOptions() got = %v, want %v", result, tt.expected) + } + } + } + }) + } +} diff --git a/pgserver/dataloader.go b/pgserver/dataloader.go index 
8bc6056f..876d51f6 100644
--- a/pgserver/dataloader.go
+++ b/pgserver/dataloader.go
@@ -1,11 +1,9 @@
 package pgserver
 
 import (
-	"bufio"
 	"context"
 	"errors"
 	"fmt"
-	"io"
 	"os"
 	"strconv"
 	"strings"
@@ -15,6 +13,7 @@ import (
 	"github.com/apecloud/myduckserver/backend"
 	"github.com/cockroachdb/cockroachdb-parser/pkg/sql/sem/tree"
 	"github.com/dolthub/go-mysql-server/sql"
+	"github.com/sirupsen/logrus"
 )
 
 // DataLoader allows callers to insert rows from multiple chunks into a table. Rows encoded in each chunk will not
@@ -22,10 +21,14 @@ import (
 // incomplete records, and saving that partial record until the next call to LoadChunk, so that it may be prefixed
 // with the incomplete record.
 type DataLoader interface {
+	// Start prepares the DataLoader for loading data. This may involve creating goroutines, opening files, etc.
+	// Start must be called before any calls to LoadChunk.
+	Start() <-chan error
+
 	// LoadChunk reads the records from |data| and inserts them into the previously configured table. Data records
 	// are not guaranteed to start and end cleanly on chunk boundaries, so implementations must recognize incomplete
 	// records and save them to prepend on the next processed chunk.
-	LoadChunk(ctx *sql.Context, data *bufio.Reader) error
+	LoadChunk(ctx *sql.Context, data []byte) error
 
 	// Abort aborts the current load operation and releases all used resources.
 	Abort(ctx *sql.Context) error
@@ -45,23 +48,112 @@ type LoadDataResults struct {
 
 var ErrCopyAborted = fmt.Errorf("COPY operation aborted")
 
-type CsvDataLoader struct {
+type PipeDataLoader struct {
 	ctx      *sql.Context
 	cancel   context.CancelFunc
 	schema   string
 	table    sql.InsertableTable
 	columns  tree.NameList
-	options  *tree.CopyOptions
 	pipePath string
-	pipe     *os.File
-	errPipe  *os.File // for error handling
+	read     func()
+	pipe     atomic.Pointer[os.File] // for writing
+	errPipe  atomic.Pointer[os.File] // for error handling
 	rowCount chan int64
 	err      atomic.Pointer[error]
+	logger   *logrus.Entry
+}
+
+func (loader *PipeDataLoader) Start() <-chan error {
+	// Open the reader.
+	go loader.read()
+
+	ready := make(chan error, 1)
+	go func() {
+		defer close(ready)
+
+		// TODO(fan): If the reader fails to open the pipe, the writer will block forever.
+		// Open the pipe for writing.
+		// This operation will block until the reader opens the pipe for reading.
+		loader.logger.Debugf("Opening pipe for writing: %s", loader.pipePath)
+		pipe, err := os.OpenFile(loader.pipePath, os.O_WRONLY, os.ModeNamedPipe)
+		if err != nil {
+			ready <- err
+			return
+		}
+
+		// If the COPY operation failed to start, close the pipe and return the error.
+		if loader.errPipe.Load() != nil {
+			ready <- errors.Join(*loader.err.Load(), pipe.Close(), loader.errPipe.Load().Close())
+			return
+		}
+
+		loader.pipe.Store(pipe)
+	}()
+	return ready
+}
+
+func (loader *PipeDataLoader) LoadChunk(ctx *sql.Context, data []byte) error {
+	if errp := loader.err.Load(); errp != nil {
+		return fmt.Errorf("COPY operation has been aborted: %w", *errp)
+	}
+	loader.logger.Tracef("Copying %d bytes to pipe %s", len(data), loader.pipePath)
+	// Write the data to the FIFO pipe.
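+	// Writes to a FIFO block once the kernel pipe buffer fills up, so this call
+	// naturally throttles the client to the pace at which DuckDB drains the pipe.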
+ _, err := loader.pipe.Load().Write(data) + if err != nil { + loader.logger.Error("Copying data to pipe failed:", err) + loader.Abort(ctx) + return err + } + loader.logger.Tracef("Copied %d bytes to pipe %s", len(data), loader.pipePath) + return nil +} + +func (loader *PipeDataLoader) Abort(ctx *sql.Context) error { + defer os.Remove(loader.pipePath) + loader.err.Store(&ErrCopyAborted) + loader.cancel() + <-loader.rowCount // Ensure the reader has exited + return loader.pipe.Load().Close() +} + +func (loader *PipeDataLoader) Finish(ctx *sql.Context) (*LoadDataResults, error) { + defer os.Remove(loader.pipePath) + + if errp := loader.err.Load(); errp != nil { + loader.logger.Errorln("COPY operation failed:", *errp) + return nil, *errp + } + + // Close the pipe to signal the reader to exit + if err := loader.pipe.Load().Close(); err != nil { + return nil, err + } + + rows := <-loader.rowCount + + // Now the reader has exited, check the error again + if errp := loader.err.Load(); errp != nil { + loader.logger.Errorln("COPY operation failed:", *errp) + return nil, *errp + } + + return &LoadDataResults{ + RowsLoaded: int32(rows), + }, nil +} + +type CsvDataLoader struct { + PipeDataLoader + options *tree.CopyOptions } var _ DataLoader = (*CsvDataLoader)(nil) -func NewCsvDataLoader(ctx *sql.Context, handler *DuckHandler, schema string, table sql.InsertableTable, columns tree.NameList, options *tree.CopyOptions) (DataLoader, error) { +func NewCsvDataLoader( + ctx *sql.Context, handler *DuckHandler, + schema string, table sql.InsertableTable, columns tree.NameList, options *tree.CopyOptions, + rawOptions string, // For non-PG-parsable COPY FROM, unused for now +) (DataLoader, error) { // Create the FIFO pipe duckBuilder := handler.e.Analyzer.ExecBuilder.(*backend.DuckBuilder) pipePath, err := duckBuilder.CreatePipe(ctx, "pg-copy-from") @@ -74,37 +166,22 @@ func NewCsvDataLoader(ctx *sql.Context, handler *DuckHandler, schema string, tab ctx.Context = childCtx loader := &CsvDataLoader{ - ctx: ctx, - cancel: cancel, - schema: schema, - table: table, - columns: columns, - options: options, - pipePath: pipePath, - rowCount: make(chan int64, 1), - } - - // Execute the DuckDB COPY statement in a goroutine. - sql := loader.buildSQL() - loader.ctx.GetLogger().Trace(sql) - go loader.executeCopy(sql, pipePath) - - // TODO(fan): If the reader fails to open the pipe, the writer will block forever. - - // Open the pipe for writing. - // This operation will block until the reader opens the pipe for reading. - pipe, err := os.OpenFile(pipePath, os.O_WRONLY, os.ModeNamedPipe) - if err != nil { - return nil, err + PipeDataLoader: PipeDataLoader{ + ctx: ctx, + cancel: cancel, + schema: schema, + table: table, + columns: columns, + pipePath: pipePath, + rowCount: make(chan int64, 1), + logger: ctx.GetLogger(), + }, + options: options, } - - // If the COPY operation failed to start, close the pipe and return the error. 
- if loader.errPipe != nil { - return nil, errors.Join(*loader.err.Load(), pipe.Close(), loader.errPipe.Close()) + loader.read = func() { + loader.executeCopy(loader.buildSQL(), pipePath) } - loader.pipe = pipe - return loader, nil } @@ -174,12 +251,14 @@ func (loader *CsvDataLoader) buildSQL() string { func (loader *CsvDataLoader) executeCopy(sql string, pipePath string) { defer close(loader.rowCount) + loader.logger.Debugf("Executing COPY statement: %s", sql) result, err := adapter.Exec(loader.ctx, sql) if err != nil { loader.ctx.GetLogger().Error(err) loader.err.Store(&err) // Open the pipe once to unblock the writer - loader.errPipe, _ = os.OpenFile(pipePath, os.O_RDONLY, os.ModeNamedPipe) + pipe, _ := os.OpenFile(pipePath, os.O_RDONLY, os.ModeNamedPipe) + loader.errPipe.Store(pipe) return } @@ -192,52 +271,6 @@ func (loader *CsvDataLoader) executeCopy(sql string, pipePath string) { loader.rowCount <- rows } -func (loader *CsvDataLoader) LoadChunk(ctx *sql.Context, data *bufio.Reader) error { - if errp := loader.err.Load(); errp != nil { - return fmt.Errorf("COPY operation has been aborted: %w", *errp) - } - // Write the data to the FIFO pipe. - _, err := io.Copy(loader.pipe, data) - if err != nil { - ctx.GetLogger().Error("Copying data to pipe failed:", err) - loader.Abort(ctx) - return err - } - return nil -} - -func (loader *CsvDataLoader) Abort(ctx *sql.Context) error { - defer os.Remove(loader.pipePath) - loader.err.Store(&ErrCopyAborted) - loader.cancel() - <-loader.rowCount // Ensure the reader has exited - return loader.pipe.Close() -} - -func (loader *CsvDataLoader) Finish(ctx *sql.Context) (*LoadDataResults, error) { - defer os.Remove(loader.pipePath) - - if errp := loader.err.Load(); errp != nil { - return nil, *errp - } - - // Close the pipe to signal the reader to exit - if err := loader.pipe.Close(); err != nil { - return nil, err - } - - rows := <-loader.rowCount - - // Now the reader has exited, check the error again - if errp := loader.err.Load(); errp != nil { - return nil, *errp - } - - return &LoadDataResults{ - RowsLoaded: int32(rows), - }, nil -} - func singleQuotedDuckChar(s string) string { if len(s) == 0 { return `''` diff --git a/pgserver/datawriter.go b/pgserver/datawriter.go index b1441f9d..1261ea69 100644 --- a/pgserver/datawriter.go +++ b/pgserver/datawriter.go @@ -12,20 +12,30 @@ import ( "github.com/dolthub/go-mysql-server/sql" ) -type DataWriter struct { +type DataWriter interface { + Start() (string, chan CopyToResult, error) + Close() +} + +type CopyToResult struct { + RowCount int64 + Err error +} + +type DuckDataWriter struct { ctx *sql.Context duckSQL string options *tree.CopyOptions pipePath string } -func NewDataWriter( +func NewDuckDataWriter( ctx *sql.Context, handler *DuckHandler, schema string, table sql.Table, columns tree.NameList, query string, options *tree.CopyOptions, rawOptions string, -) (*DataWriter, error) { +) (*DuckDataWriter, error) { // Create the FIFO pipe db := handler.e.Analyzer.ExecBuilder.(*backend.DuckBuilder) pipePath, err := db.CreatePipe(ctx, "pg-copy-to") @@ -127,7 +137,7 @@ func NewDataWriter( return nil, fmt.Errorf("BINARY format is not supported for COPY TO") } - return &DataWriter{ + return &DuckDataWriter{ ctx: ctx, duckSQL: builder.String(), options: options, @@ -135,14 +145,9 @@ func NewDataWriter( }, nil } -type copyToResult struct { - RowCount int64 - Err error -} - -func (dw *DataWriter) Start() (string, chan copyToResult, error) { +func (dw *DuckDataWriter) Start() (string, chan CopyToResult, error) { // 
Execute the COPY TO statement in a separate goroutine. - ch := make(chan copyToResult, 1) + ch := make(chan CopyToResult, 1) go func() { defer os.Remove(dw.pipePath) defer close(ch) @@ -152,16 +157,16 @@ func (dw *DataWriter) Start() (string, chan copyToResult, error) { // This operation will block until the reader opens the pipe for reading. result, err := adapter.ExecCatalog(dw.ctx, dw.duckSQL) if err != nil { - ch <- copyToResult{Err: err} + ch <- CopyToResult{Err: err} return } affected, _ := result.RowsAffected() - ch <- copyToResult{RowCount: affected} + ch <- CopyToResult{RowCount: affected} }() return dw.pipePath, ch, nil } -func (dw *DataWriter) Close() { +func (dw *DuckDataWriter) Close() { os.Remove(dw.pipePath) } diff --git a/pgtest/psql/copy/arrow.sql b/pgtest/psql/copy/arrow.sql new file mode 100644 index 00000000..33b3cfb5 --- /dev/null +++ b/pgtest/psql/copy/arrow.sql @@ -0,0 +1,19 @@ +CREATE SCHEMA IF NOT EXISTS test_psql_copy_to_arrow; + +USE test_psql_copy_to_arrow; + +CREATE TABLE t (a int, b text, c float); + +INSERT INTO t VALUES (1, 'one', 1.1), (2, 'two', 2.2), (3, 'three', 3.3), (4, 'four', 4.4), (5, 'five', 5.5); + +\o 'stdout.arrow' + +COPY t TO STDOUT (FORMAT ARROW); + +\o + +\echo `python -c "import pyarrow as pa; reader = pa.ipc.open_stream('stdout.arrow'); print(reader.read_all().to_pandas())"` + +\copy t FROM 'stdout.arrow' (FORMAT ARROW); + +SELECT * FROM t; \ No newline at end of file