Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(pg): COPY TO STDOUT #184

Merged
merged 11 commits into from
Nov 21, 2024
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ pipes/
*.sock
__debug_*
.DS_Store
*.csv
13 changes: 3 additions & 10 deletions backend/loaddata.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"runtime"
"strconv"
"strings"
"syscall"

"github.com/apecloud/myduckserver/adapter"
"github.com/apecloud/myduckserver/catalog"
Expand Down Expand Up @@ -80,21 +79,15 @@ func (db *DuckBuilder) buildClientSideLoadData(ctx *sql.Context, insert *plan.In
}
defer reader.Close()

// Create the FIFO pipe
pipeDir := filepath.Join(db.provider.DataDir(), "pipes", "load-data")
if err := os.MkdirAll(pipeDir, 0755); err != nil {
return nil, err
}
pipeName := strconv.Itoa(int(ctx.ID())) + ".pipe"
pipePath := filepath.Join(pipeDir, pipeName)
if err := syscall.Mkfifo(pipePath, 0600); err != nil {
pipePath, err := db.CreatePipe(ctx, "load-data")
if err != nil {
return nil, err
}
defer os.Remove(pipePath)

// Write the data to the FIFO pipe.
go func() {
pipe, err := os.OpenFile(pipePath, os.O_WRONLY, 0600)
pipe, err := os.OpenFile(pipePath, os.O_WRONLY, os.ModeNamedPipe)
if err != nil {
return
}
Expand Down
25 changes: 25 additions & 0 deletions backend/pipe.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package backend

import (
"os"
"path/filepath"
"strconv"
"syscall"

"github.com/dolthub/go-mysql-server/sql"
)

func (h *DuckBuilder) CreatePipe(ctx *sql.Context, subdir string) (string, error) {
// Create the FIFO pipe
pipeDir := filepath.Join(h.provider.DataDir(), "pipes", subdir)
if err := os.MkdirAll(pipeDir, 0755); err != nil {
return "", err
}
pipeName := strconv.Itoa(int(ctx.ID())) + ".pipe"
pipePath := filepath.Join(pipeDir, pipeName)
ctx.GetLogger().Debugln("Creating FIFO pipe for COPY FROM operation:", pipePath)
fanyang01 marked this conversation as resolved.
Show resolved Hide resolved
if err := syscall.Mkfifo(pipePath, 0600); err != nil {
return "", err
}
return pipePath, nil
}
3 changes: 3 additions & 0 deletions pgserver/connection_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"fmt"

"github.com/cockroachdb/cockroachdb-parser/pkg/sql/sem/tree"
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/vitess/go/vt/proto/query"
"github.com/jackc/pgx/v5/pgproto3"
"github.com/lib/pq/oid"
Expand Down Expand Up @@ -66,6 +67,8 @@ type copyFromStdinState struct {
// node is used to look at what parameters were specified, such as which table to load data into, file format,
// delimiters, etc.
copyFromStdinNode *tree.CopyFrom
// targetTable stores the targetTable that the data is being loaded into.
targetTable sql.InsertableTable
// dataLoader is the implementation of DataLoader that is used to load each individual CopyData chunk into the
// target table.
dataLoader DataLoader
Expand Down
149 changes: 129 additions & 20 deletions pgserver/connection_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,23 +20,25 @@ import (
"context"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"os"
"runtime/debug"
"slices"
"strings"
"sync/atomic"

"github.com/cockroachdb/cockroachdb-parser/pkg/sql/parser"
"github.com/cockroachdb/cockroachdb-parser/pkg/sql/sem/tree"
gms "github.com/dolthub/go-mysql-server"
"github.com/dolthub/go-mysql-server/server"
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/vitess/go/mysql"
"github.com/jackc/pgx/v5/pgproto3"
"github.com/jackc/pgx/v5/pgtype"
"github.com/sirupsen/logrus"
// Import the new datawriter package
)

// ConnectionHandler is responsible for the entire lifecycle of a user connection: receiving messages they send,
Expand Down Expand Up @@ -348,7 +350,7 @@ func (h *ConnectionHandler) receiveMessage() (bool, error) {
// handleMessages processes the message provided and returns status flags indicating what the connection should do next.
// If the |stop| response parameter is true, it indicates that the connection should be closed by the caller. If the
// |endOfMessages| response parameter is true, it indicates that no more messages are expected for the current operation
// and a READY FOR QUERY message should be sent back to the client, so it can send the next query.
// and a READY FOR QUERY message should be sent back to the client so it can send the next query.
func (h *ConnectionHandler) handleMessage(msg pgproto3.Message) (stop, endOfMessages bool, err error) {
logrus.Tracef("Handling message: %T", msg)
switch message := msg.(type) {
Expand Down Expand Up @@ -448,11 +450,12 @@ func (h *ConnectionHandler) handleQueryOutsideEngine(query ConvertedQuery) (hand
if stmt.Stdin {
return true, false, h.handleCopyFromStdinQuery(query, stmt, h.Conn())
}
case *tree.CopyTo:
return true, true, h.handleCopyToStdout(query, stmt)
}
return false, true, nil
}

// handleParse handles a parse message, returning any error that occurs
func (h *ConnectionHandler) handleParse(message *pgproto3.Parse) error {
h.waitForSync = true

Expand Down Expand Up @@ -662,20 +665,9 @@ func (h *ConnectionHandler) handleCopyDataHelper(message *pgproto3.CopyData) (st
if copyFrom == nil {
return false, false, fmt.Errorf("no COPY FROM STDIN node found")
}

// TODO: It would be better to get the table from the copyFromStdinNode – not by calling core.GetSqlTableFromContext
schemaName := copyFrom.Table.Schema()
tableName := copyFrom.Table.Table()
table, err := GetSqlTableFromContext(sqlCtx, schemaName, tableName)
if err != nil {
return false, true, err
}
table := h.copyFromStdinState.targetTable
if table == nil {
return false, true, fmt.Errorf(`relation "%s" does not exist`, tableName)
}
insertableTable, ok := table.(sql.InsertableTable)
if !ok {
return false, true, fmt.Errorf(`table "%s" is read-only`, tableName)
return false, true, fmt.Errorf("no target table found")
}

switch copyFrom.Options.CopyFormat {
Expand All @@ -693,12 +685,15 @@ func (h *ConnectionHandler) handleCopyDataHelper(message *pgproto3.CopyData) (st
}
fallthrough
case tree.CopyFormatCSV:
dataLoader, err = NewCsvDataLoader(sqlCtx, h.duckHandler, schemaName, insertableTable, copyFrom.Columns, &copyFrom.Options)
dataLoader, err = NewCsvDataLoader(
sqlCtx, h.duckHandler,
copyFrom.Table.Schema(), table, copyFrom.Columns,
&copyFrom.Options,
)
case tree.CopyFormatBinary:
err = fmt.Errorf("BINARY format is not supported for COPY FROM")
default:
err = fmt.Errorf("unknown format specified for COPY FROM: %v",
copyFrom.Options.CopyFormat)
err = fmt.Errorf("unknown format specified for COPY FROM: %v", copyFrom.Options.CopyFormat)
}

if err != nil {
Expand Down Expand Up @@ -1101,12 +1096,14 @@ func (h *ConnectionHandler) handleCopyFromStdinQuery(query ConvertedQuery, copyF
}
sqlCtx.SetLogger(sqlCtx.GetLogger().WithField("query", query.String))

if err := ValidateCopyFrom(copyFrom, sqlCtx); err != nil {
table, err := ValidateCopyFrom(copyFrom, sqlCtx)
if err != nil {
return err
}

h.copyFromStdinState = &copyFromStdinState{
copyFromStdinNode: copyFrom,
targetTable: table,
}

return h.send(&pgproto3.CopyInResponse{
Expand Down Expand Up @@ -1144,3 +1141,115 @@ func returnsRow(tag string) bool {
return false
}
}

// handleCopyToStdout handles a COPY ... TO STDOUT statement. It starts a
// DataWriter that produces the result bytes into a FIFO pipe, streams the
// pipe's contents to the client as CopyData messages on a background
// goroutine, and completes the protocol with CopyDone and CommandComplete.
// Returns the first error from validation, the writer, or the send path.
func (h *ConnectionHandler) handleCopyToStdout(query ConvertedQuery, copyTo *tree.CopyTo) error {
	ctx, err := h.duckHandler.NewContext(context.Background(), h.mysqlConn, query.String)
	if err != nil {
		return err
	}
	ctx.SetLogger(ctx.GetLogger().WithField("query", query.String))

	// Create cancelable context
	// Cancellation is used by the reader goroutine below to abort the
	// whole operation when sending to the client fails.
	childCtx, cancel := context.WithCancel(ctx)
	defer cancel()
	ctx = ctx.WithContext(childCtx)

	// Resolve and validate the source table of the COPY TO.
	table, err := ValidateCopyTo(copyTo, ctx)
	if err != nil {
		return err
	}

	// COPY (SELECT ...) TO STDOUT carries an inner statement instead of a
	// bare table reference; pass its text through to the writer.
	var stmt string
	if copyTo.Statement != nil {
		stmt = copyTo.Statement.String()
	}
	writer, err := NewDataWriter(
		ctx, h.duckHandler,
		copyTo.Table.Schema(), table, copyTo.Columns,
		stmt,
		&copyTo.Options,
	)
	if err != nil {
		return err
	}
	defer writer.Close()

	// Send CopyOutResponse to the client
	ctx.GetLogger().Debug("sending CopyOutResponse to the client")
	copyOutResponse := &pgproto3.CopyOutResponse{
		OverallFormat:     0,           // 0 for text format
		ColumnFormatCodes: []uint16{0}, // 0 for text format
	}
	if err := h.send(copyOutResponse); err != nil {
		return err
	}

	// Start the writer; it returns the FIFO path it writes into and a
	// channel that delivers the final result (row count or error).
	pipePath, ch, err := writer.Start()
	if err != nil {
		return err
	}

	// sendErr communicates a failure from the reader goroutine back to
	// this function; atomic.Value because the two sides race.
	var sendErr atomic.Value
	go func() {
		// Open the pipe for reading.
		// NOTE(review): opening a FIFO read-only blocks until the writer
		// side opens it — presumably the DataWriter does so; verify.
		ctx.GetLogger().Tracef("Opening FIFO pipe for reading: %s", pipePath)
		pipe, err := os.OpenFile(pipePath, os.O_RDONLY, os.ModeNamedPipe)
		if err != nil {
			sendErr.Store(fmt.Errorf("failed to open pipe for reading: %w", err))
			cancel()
			return
		}
		defer pipe.Close()

		ctx.GetLogger().Debug("Copying data from the pipe to the client")
		buf := make([]byte, 1<<20) // 1MB buffer
		for {
			n, err := pipe.Read(buf)
			// Per io.Reader contract, process the n bytes read before
			// inspecting err: a Read may return data AND an error.
			if n > 0 {
				copyData := &pgproto3.CopyData{
					Data: buf[:n],
				}
				ctx.GetLogger().Debugf("sending CopyData (%d bytes) to the client", n)
				if err := h.send(copyData); err != nil {
					// Record the failure and cancel so the select below
					// unblocks via ctx.Done().
					sendErr.Store(err)
					cancel()
					return
				}
			}
			if err != nil {
				if err == io.EOF {
					// Writer closed its end: all data has been streamed.
					break
				}
				sendErr.Store(err)
				cancel()
				return
			}
		}
	}()

	select {
	case <-ctx.Done(): // Context is canceled
		// Either the client connection failed (sendErr set) or the parent
		// context was canceled; surface both causes.
		err, _ := sendErr.Load().(error)
		return errors.Join(ctx.Err(), err)
	case result := <-ch:
		// The writer finished producing data; it may still have failed.
		if result.Err != nil {
			return fmt.Errorf("failed to copy data: %w", result.Err)
		}

		// A send failure may have been recorded even though the writer
		// completed; prefer reporting it over a success tag.
		if err, ok := sendErr.Load().(error); ok {
			return err
		}

		// After data is sent and the producer side is finished without errors, send CopyDone
		ctx.GetLogger().Debug("sending CopyDone to the client")
		if err := h.send(&pgproto3.CopyDone{}); err != nil {
			return err
		}

		// Send CommandComplete with the number of rows copied
		ctx.GetLogger().Debugf("sending CommandComplete to the client")
		return h.send(&pgproto3.CommandComplete{
			CommandTag: []byte(fmt.Sprintf("COPY %d", result.RowCount)),
		})
	}
}
26 changes: 8 additions & 18 deletions pgserver/dataloader.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,9 @@ import (
"fmt"
"io"
"os"
"path/filepath"
"strconv"
"strings"
"sync/atomic"
"syscall"

"github.com/apecloud/myduckserver/adapter"
"github.com/apecloud/myduckserver/backend"
Expand Down Expand Up @@ -63,28 +61,20 @@ type CsvDataLoader struct {

var _ DataLoader = (*CsvDataLoader)(nil)

func NewCsvDataLoader(sqlCtx *sql.Context, handler *DuckHandler, schema string, table sql.InsertableTable, columns tree.NameList, options *tree.CopyOptions) (DataLoader, error) {
duckBuilder := handler.e.Analyzer.ExecBuilder.(*backend.DuckBuilder)
dataDir := duckBuilder.Provider().DataDir()

func NewCsvDataLoader(ctx *sql.Context, handler *DuckHandler, schema string, table sql.InsertableTable, columns tree.NameList, options *tree.CopyOptions) (DataLoader, error) {
// Create the FIFO pipe
pipeDir := filepath.Join(dataDir, "pipes", "load-data")
if err := os.MkdirAll(pipeDir, 0755); err != nil {
return nil, err
}
pipeName := strconv.Itoa(int(sqlCtx.ID())) + ".pipe"
pipePath := filepath.Join(pipeDir, pipeName)
sqlCtx.GetLogger().Traceln("Creating FIFO pipe for COPY operation:", pipePath)
if err := syscall.Mkfifo(pipePath, 0600); err != nil {
duckBuilder := handler.e.Analyzer.ExecBuilder.(*backend.DuckBuilder)
pipePath, err := duckBuilder.CreatePipe(ctx, "pg-copy-from")
if err != nil {
return nil, err
}

// Create cancelable context
childCtx, cancel := context.WithCancel(sqlCtx)
sqlCtx.Context = childCtx
childCtx, cancel := context.WithCancel(ctx)
ctx.Context = childCtx

loader := &CsvDataLoader{
ctx: sqlCtx,
ctx: ctx,
cancel: cancel,
schema: schema,
table: table,
Expand All @@ -103,7 +93,7 @@ func NewCsvDataLoader(sqlCtx *sql.Context, handler *DuckHandler, schema string,

// Open the pipe for writing.
// This operation will block until the reader opens the pipe for reading.
pipe, err := os.OpenFile(pipePath, os.O_WRONLY, 0600)
pipe, err := os.OpenFile(pipePath, os.O_WRONLY, os.ModeNamedPipe)
if err != nil {
return nil, err
}
Expand Down
Loading