diff --git a/.gitignore b/.gitignore index 8ff8ac38..2416eb99 100644 --- a/.gitignore +++ b/.gitignore @@ -13,3 +13,4 @@ pipes/ __debug_* .DS_Store *.csv +*.parquet diff --git a/logo/myduck-logo.png b/logo/myduck-logo.png new file mode 100644 index 00000000..4202b333 Binary files /dev/null and b/logo/myduck-logo.png differ diff --git a/pgserver/connection_handler.go b/pgserver/connection_handler.go index a109c9d9..60a32f43 100644 --- a/pgserver/connection_handler.go +++ b/pgserver/connection_handler.go @@ -39,6 +39,7 @@ import ( "github.com/cockroachdb/cockroachdb-parser/pkg/sql/sem/tree" gms "github.com/dolthub/go-mysql-server" "github.com/dolthub/go-mysql-server/server" + "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/vitess/go/mysql" "github.com/jackc/pgx/v5/pgproto3" "github.com/jackc/pgx/v5/pgtype" @@ -602,8 +603,15 @@ func (h *ConnectionHandler) handleQueryOutsideEngine(query ConvertedQuery) (hand return true, false, h.handleCopyFromStdinQuery(query, stmt, h.Conn()) } case *tree.CopyTo: - return true, true, h.handleCopyToStdout(query, stmt) + return true, true, h.handleCopyToStdout(query, stmt, "" /* unused */, tree.CopyFormatBinary, "") } + + if query.StatementTag == "COPY" { + if subquery, format, options, ok := ParseCopy(query.String); ok { + return true, true, h.handleCopyToStdout(query, nil, subquery, format, options) + } + } + return false, true, nil } @@ -1307,7 +1315,7 @@ func returnsRow(tag string) bool { } } -func (h *ConnectionHandler) handleCopyToStdout(query ConvertedQuery, copyTo *tree.CopyTo) error { +func (h *ConnectionHandler) handleCopyToStdout(query ConvertedQuery, copyTo *tree.CopyTo, subquery string, format tree.CopyFormat, rawOptions string) error { ctx, err := h.duckHandler.NewContext(context.Background(), h.mysqlConn, query.String) if err != nil { return err @@ -1319,20 +1327,40 @@ func (h *ConnectionHandler) handleCopyToStdout(query ConvertedQuery, copyTo *tre defer cancel() ctx = ctx.WithContext(childCtx) - table, err 
:= ValidateCopyTo(copyTo, ctx) - if err != nil { - return err - } + var ( + schema string + table sql.Table + columns tree.NameList + stmt string + options *tree.CopyOptions + ) - var stmt string - if copyTo.Statement != nil { - stmt = copyTo.Statement.String() + if copyTo != nil { + // PG-parsable COPY TO + table, err = ValidateCopyTo(copyTo, ctx) + if err != nil { + return err + } + if copyTo.Statement != nil { + stmt = `(` + copyTo.Statement.String() + `)` + } + schema = copyTo.Table.Schema() + columns = copyTo.Columns + options = ©To.Options + } else { + // Non-PG-parsable COPY TO, which is parsed via regex. + stmt = subquery + options = &tree.CopyOptions{ + CopyFormat: format, + HasFormat: true, + } } + writer, err := NewDataWriter( ctx, h.duckHandler, - copyTo.Table.Schema(), table, copyTo.Columns, + schema, table, columns, stmt, - ©To.Options, + options, rawOptions, ) if err != nil { return err diff --git a/pgserver/copy.go b/pgserver/copy.go new file mode 100644 index 00000000..9911df76 --- /dev/null +++ b/pgserver/copy.go @@ -0,0 +1,56 @@ +package pgserver + +import ( + "regexp" + "strings" + + "github.com/cockroachdb/cockroachdb-parser/pkg/sql/sem/tree" + "github.com/dolthub/go-mysql-server/sql" +) + +const ( + CopyFormatParquet = tree.CopyFormatCSV + 1 + CopyFormatJSON = tree.CopyFormatCSV + 2 +) + +var ( + // We are supporting the parquet/... formats for COPY TO, but + // COPY ... TO STDOUT [WITH] (FORMAT PARQUET, OPT1 v1, OPT2, OPT3 v3, ...) + // Let's match them with regex and extract the ... part. 
+ // Update regex to capture FORMAT and other options + reCopyToFormat = regexp.MustCompile(`(?i)^COPY\s+(.*?)\s+TO\s+STDOUT(?:\s+(?:WITH\s*)?\(\s*(?:FORMAT\s+(\w+)\s*,?\s*)?(.*?)\s*\))?$`) +) + +func ParseCopy(stmt string) (query string, format tree.CopyFormat, options string, ok bool) { + stmt = RemoveComments(stmt) + stmt = sql.RemoveSpaceAndDelimiter(stmt, ';') + m := reCopyToFormat.FindStringSubmatch(stmt) + if m == nil { + return "", 0, "", false + } + query = strings.TrimSpace(m[1]) + + var formatStr string + if m[2] != "" { + formatStr = strings.ToUpper(m[2]) + } else { + formatStr = "TEXT" + } + + options = strings.TrimSpace(m[3]) + + switch formatStr { + case "PARQUET": + format = CopyFormatParquet + case "JSON": + format = CopyFormatJSON + case "CSV": + format = tree.CopyFormatCSV + case "BINARY": + format = tree.CopyFormatBinary + case "", "TEXT": + format = tree.CopyFormatText + } + + return query, format, options, true +} diff --git a/pgserver/datawriter.go b/pgserver/datawriter.go index fc15e893..b1441f9d 100644 --- a/pgserver/datawriter.go +++ b/pgserver/datawriter.go @@ -24,7 +24,7 @@ func NewDataWriter( handler *DuckHandler, schema string, table sql.Table, columns tree.NameList, query string, - options *tree.CopyOptions, + options *tree.CopyOptions, rawOptions string, ) (*DataWriter, error) { // Create the FIFO pipe db := handler.e.Analyzer.ExecBuilder.(*backend.DuckBuilder) @@ -51,18 +51,41 @@ func NewDataWriter( builder.WriteString(")") } } else { - builder.WriteString("(") + // the parentheses have already been added builder.WriteString(query) - builder.WriteString(")") } builder.WriteString(" TO '") builder.WriteString(pipePath) switch options.CopyFormat { + case CopyFormatParquet: + builder.WriteString("' (FORMAT PARQUET") + if rawOptions != "" { + builder.WriteString(", ") + builder.WriteString(rawOptions) + } + builder.WriteString(")") + + case CopyFormatJSON: + builder.WriteString("' (FORMAT JSON") + if rawOptions != "" { + 
builder.WriteString(", ") + builder.WriteString(rawOptions) + } + builder.WriteString(")") + case tree.CopyFormatText, tree.CopyFormatCSV: builder.WriteString("' (FORMAT CSV") + if rawOptions != "" { + // TODO(fan): For TEXT format, we should add some default options if not specified. + builder.WriteString(", ") + builder.WriteString(rawOptions) + builder.WriteString(")") + break + } + builder.WriteString(", HEADER ") if options.HasHeader && options.Header { builder.WriteString("true") @@ -98,13 +121,12 @@ func NewDataWriter( } else if options.CopyFormat == tree.CopyFormatText { builder.WriteString(`, NULLSTR '\N'`) } + builder.WriteString(")") case tree.CopyFormatBinary: return nil, fmt.Errorf("BINARY format is not supported for COPY TO") } - builder.WriteString(")") - return &DataWriter{ ctx: ctx, duckSQL: builder.String(), diff --git a/pgserver/stmt.go b/pgserver/stmt.go index 0ff25803..1fb0d0ac 100644 --- a/pgserver/stmt.go +++ b/pgserver/stmt.go @@ -1,6 +1,7 @@ package pgserver import ( + "bytes" "strings" "unicode" @@ -82,18 +83,18 @@ func GetStatementTag(stmt *duckdb.Stmt) string { } func GuessStatementTag(query string) string { - // Remove leading line and block comments + // Remove leading comments query = RemoveLeadingComments(query) // Remove trailing semicolon query = sql.RemoveSpaceAndDelimiter(query, ';') - // Guess the statement tag by looking for the first space in the query. 
+ // Guess the statement tag by looking for the first non-identifier character for i, c := range query { - if unicode.IsSpace(c) { + if !unicode.IsLetter(c) && c != '_' { return strings.ToUpper(query[:i]) } } - return "" + return strings.ToUpper(query) } func RemoveLeadingComments(query string) string { @@ -109,12 +110,28 @@ func RemoveLeadingComments(query string) string { } i += end + 1 } else if strings.HasPrefix(query[i:], "/*") { - // Skip block comment - end := strings.Index(query[i+2:], "*/") - if end == -1 { + // Skip block comment with nesting support + nestLevel := 1 + pos := i + 2 + for pos < n && nestLevel > 0 { + if pos+1 < n { + if query[pos] == '/' && query[pos+1] == '*' { + nestLevel++ + pos += 2 + continue + } + if query[pos] == '*' && query[pos+1] == '/' { + nestLevel-- + pos += 2 + continue + } + } + pos++ + } + if nestLevel > 0 { return "" } - i += end + 4 + i = pos } else if unicode.IsSpace(rune(query[i])) { // Skip whitespace i++ @@ -124,3 +141,121 @@ func RemoveLeadingComments(query string) string { } return query[i:] } + +// RemoveComments removes comments from a query string. +// It supports line comments (--), block comments (/* ... */), and quoted strings. 
// RemoveComments returns query with SQL comments stripped.
//
//   - Line comments ("-- ...") are dropped up to the end of the line; the
//     newline that terminates them, if any, is kept.
//   - Block comments ("/* ... */") are dropped entirely and may be nested.
//     An unterminated block comment swallows the rest of the input.
//   - String literals ('...' and E'...'), quoted identifiers ("...") and
//     dollar-quoted strings ($tag$...$tag$) are copied verbatim, so comment
//     markers inside them are preserved. Inside string literals a backslash
//     protects the character that follows it. Unterminated quotes copy the
//     rest of the input.
//
// Nothing is written in place of a removed block comment, so tokens that were
// separated only by a comment become adjacent in the result.
func RemoveComments(query string) string {
	var buf bytes.Buffer
	runes := []rune(query)
	length := len(runes)
	pos := 0

	for pos < length {
		hasNext := pos+1 < length

		switch {
		case hasNext && runes[pos] == '-' && runes[pos+1] == '-':
			// Line comment: skip to the end of the line, keep the newline.
			pos += 2
			for pos < length && runes[pos] != '\n' {
				pos++
			}
			if pos < length {
				buf.WriteRune('\n')
				pos++
			}

		case hasNext && runes[pos] == '/' && runes[pos+1] == '*':
			// Block comment: track nesting depth until it is balanced.
			depth := 1
			pos += 2
			for pos < length && depth > 0 {
				if pos+1 < length {
					if runes[pos] == '/' && runes[pos+1] == '*' {
						depth++
						pos += 2
						continue
					}
					if runes[pos] == '*' && runes[pos+1] == '/' {
						depth--
						pos += 2
						continue
					}
				}
				pos++
			}

		case runes[pos] == '\'' || (hasNext && runes[pos] == 'E' && runes[pos+1] == '\''):
			// String literal, optionally with the E prefix.
			if runes[pos] == 'E' {
				buf.WriteRune('E')
				pos++
			}
			buf.WriteRune('\'')
			pos++
			for pos < length {
				if runes[pos] == '\'' {
					buf.WriteRune('\'')
					pos++
					break
				}
				if pos+1 < length && runes[pos] == '\\' {
					// Copy the escape pair as a unit so an escaped quote
					// does not end the literal.
					buf.WriteRune('\\')
					buf.WriteRune(runes[pos+1])
					pos += 2
					continue
				}
				buf.WriteRune(runes[pos])
				pos++
			}

		case runes[pos] == '$':
			// Possible dollar-quoted string: "$", an optional tag, "$".
			tagEnd := pos + 1
			for tagEnd < length && isDollarTagRune(runes[tagEnd]) {
				tagEnd++
			}
			if tagEnd < length && runes[tagEnd] == '$' {
				// Keep the tag as runes: its length is used as an offset
				// into runes, so a byte length would be wrong for tags that
				// contain non-ASCII letters.
				tag := runes[pos : tagEnd+1]
				buf.WriteString(string(tag))
				pos = tagEnd + 1
				for pos < length {
					if hasRunesAt(runes, pos, tag) {
						buf.WriteString(string(tag))
						pos += len(tag)
						break
					}
					buf.WriteRune(runes[pos])
					pos++
				}
			} else {
				// A lone '$' (e.g. a positional parameter) is ordinary text.
				buf.WriteRune('$')
				pos++
			}

		case runes[pos] == '"':
			// Quoted identifier.
			buf.WriteRune('"')
			pos++
			for pos < length {
				if runes[pos] == '"' {
					buf.WriteRune('"')
					pos++
					break
				}
				buf.WriteRune(runes[pos])
				pos++
			}

		default:
			buf.WriteRune(runes[pos])
			pos++
		}
	}

	return buf.String()
}

// isDollarTagRune reports whether r may appear inside a dollar-quote tag.
func isDollarTagRune(r rune) bool {
	return unicode.IsLetter(r) || unicode.IsDigit(r) || r == '_'
}

// hasRunesAt reports whether want occurs in runes starting at index pos.
func hasRunesAt(runes []rune, pos int, want []rune) bool {
	if pos+len(want) > len(runes) {
		return false
	}
	for i, r := range want {
		if runes[pos+i] != r {
			return false
		}
	}
	return true
}
whitespace", + query: " \t\n ", + want: "", + }, + { + name: "unclosed block comment", + query: "/* unclosed comment SELECT 1;", + want: "", + }, + { + name: "not a leading comment", + query: "SELECT /* not leading */ 1;", + want: "SELECT /* not leading */ 1;", + }, + { + name: "empty input", + query: "", + want: "", + }, } for _, tt := range tests { - got := RemoveLeadingComments(tt.query) - if got != tt.want { - t.Errorf("RemoveLeadingComments(%q) = %q; want %q", tt.query, got, tt.want) - } + t.Run(tt.name, func(t *testing.T) { + got := RemoveLeadingComments(tt.query) + if got != tt.want { + t.Errorf("RemoveLeadingComments(%q) = %q; want %q", tt.query, got, tt.want) + } + }) + } +} + +func TestRemoveComments(t *testing.T) { + tests := []struct { + name string + query string + want string + }{ + { + name: "line comments", + query: "SELECT 1; -- comment\n-- another comment\nSELECT 2;", + want: "SELECT 1; \n\nSELECT 2;", + }, + { + name: "block comments", + query: "SELECT /* in-line */ 1; /* multi\nline\ncomment */ SELECT 2;", + want: "SELECT 1; SELECT 2;", + }, + { + name: "nested block comments", + query: "SELECT /* outer /* inner */ rest */ 1;", + want: "SELECT 1;", + }, + { + name: "comments in string literals", + query: "SELECT '-- not a comment' AS c1, '/* also not */ a comment' AS c2;", + want: "SELECT '-- not a comment' AS c1, '/* also not */ a comment' AS c2;", + }, + { + name: "comments in quoted identifiers", + query: `SELECT "-- not /* a */ comment" FROM t1;`, + want: `SELECT "-- not /* a */ comment" FROM t1;`, + }, + { + name: "dollar quoted strings", + query: "SELECT $tag$-- not /* a */ comment$tag$, $$ /* not */ comment $$;", + want: "SELECT $tag$-- not /* a */ comment$tag$, $$ /* not */ comment $$;", + }, + { + name: "complex dollar quotes", + query: "SELECT $a$b$/* not $a$b$ comment */$$a$b$, $tag$$tag$;", + want: "SELECT $a$b$/* not $a$b$ comment */$$a$b$, $tag$$tag$;", + }, + { + name: "escaped quotes in strings", + query: "SELECT 'string with 
\\'-- not a comment\\' continues';", + want: "SELECT 'string with \\'-- not a comment\\' continues';", + }, + { + name: "mixed comments and quotes", + query: "/* c1 */ SELECT -- c2\n'/* c3 */' /* c4 */ FROM /* c5 */ t1;", + want: " SELECT \n'/* c3 */' FROM t1;", + }, + { + name: "postgres escape string", + query: "SELECT E'\\t' /* comment */ AS tab;", + want: "SELECT E'\\t' AS tab;", + }, + { + name: "postgres escape string with embedded comments", + query: "SELECT E'-- not/*a*/comment\\n' FROM t1;", + want: "SELECT E'-- not/*a*/comment\\n' FROM t1;", + }, + { + name: "mixed postgres strings", + query: "SELECT E'\\t', '\\t', /* comment */ E'\\n';", + want: "SELECT E'\\t', '\\t', E'\\n';", + }, + { + name: "empty input", + query: "", + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := RemoveComments(tt.query) + if got != tt.want { + t.Errorf("RemoveComments(%q) = %q; want %q", tt.query, got, tt.want) + } + }) } } diff --git a/pgtest/psql/copy/parquet.sql b/pgtest/psql/copy/parquet.sql new file mode 100644 index 00000000..7b17f6ce --- /dev/null +++ b/pgtest/psql/copy/parquet.sql @@ -0,0 +1,31 @@ +CREATE SCHEMA IF NOT EXISTS test_psql_copy_to_parquet; + +USE test_psql_copy_to_parquet; + +CREATE TABLE t (a int, b text, c float); + +INSERT INTO t VALUES (1, 'one', 1.1), (2, 'two', 2.2), (3, 'three', 3.3), (4, 'four', 4.4), (5, 'five', 5.5); + +\o 'stdout-1.parquet' + +COPY t TO STDOUT (FORMAT PARQUET); + +\o 'stdout-2.parquet' + +\copy t (a, b) TO STDOUT (FORMAT PARQUET); + +\o 'stdout-3.parquet' + +COPY t TO STDOUT (FORMAT PARQUET); + +\o 'stdout-4.parquet' + +\copy (SELECT a * a, b, c + a FROM t) TO STDOUT (FORMAT PARQUET); + +\echo `duckdb -c "SELECT * FROM 'stdout-1.parquet'"` + +\echo `duckdb -c "SELECT * FROM 'stdout-2.parquet'"` + +\echo `duckdb -c "SELECT * FROM 'stdout-3.parquet'"` + +\echo `duckdb -c "SELECT * FROM 'stdout-4.parquet'"` \ No newline at end of file