Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(pg): COPY TO STDOUT #184

Merged
merged 11 commits into from
Nov 21, 2024
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ pipes/
*.sock
__debug_*
.DS_Store
*.csv
13 changes: 3 additions & 10 deletions backend/loaddata.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"runtime"
"strconv"
"strings"
"syscall"

"github.com/apecloud/myduckserver/adapter"
"github.com/apecloud/myduckserver/catalog"
Expand Down Expand Up @@ -80,21 +79,15 @@ func (db *DuckBuilder) buildClientSideLoadData(ctx *sql.Context, insert *plan.In
}
defer reader.Close()

// Create the FIFO pipe
pipeDir := filepath.Join(db.provider.DataDir(), "pipes", "load-data")
if err := os.MkdirAll(pipeDir, 0755); err != nil {
return nil, err
}
pipeName := strconv.Itoa(int(ctx.ID())) + ".pipe"
pipePath := filepath.Join(pipeDir, pipeName)
if err := syscall.Mkfifo(pipePath, 0600); err != nil {
pipePath, err := db.CreatePipe(ctx, "load-data")
if err != nil {
return nil, err
}
defer os.Remove(pipePath)

// Write the data to the FIFO pipe.
go func() {
pipe, err := os.OpenFile(pipePath, os.O_WRONLY, 0600)
pipe, err := os.OpenFile(pipePath, os.O_WRONLY, os.ModeNamedPipe)
if err != nil {
return
}
Expand Down
25 changes: 25 additions & 0 deletions backend/pipe.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package backend

import (
"os"
"path/filepath"
"strconv"
"syscall"

"github.com/dolthub/go-mysql-server/sql"
)

// CreatePipe creates a uniquely named FIFO pipe under
// <DataDir>/pipes/<subdir> for streaming data between the server and DuckDB
// (used by LOAD DATA and COPY operations). The pipe is named after the
// connection ID, so each connection gets one pipe per operation kind.
// The caller is responsible for removing the pipe when the operation ends.
func (h *DuckBuilder) CreatePipe(ctx *sql.Context, subdir string) (string, error) {
	// Create the FIFO pipe
	pipeDir := filepath.Join(h.provider.DataDir(), "pipes", subdir)
	if err := os.MkdirAll(pipeDir, 0755); err != nil {
		return "", err
	}
	pipeName := strconv.Itoa(int(ctx.ID())) + ".pipe"
	pipePath := filepath.Join(pipeDir, pipeName)
	ctx.GetLogger().Debugln("Creating FIFO pipe for LOAD/COPY operation:", pipePath)
	// Remove any stale pipe left behind by an earlier session that reused
	// this connection ID (e.g. a crash before the caller's deferred cleanup
	// ran); otherwise Mkfifo fails with EEXIST and the operation can never
	// succeed for this connection again.
	if err := os.Remove(pipePath); err != nil && !os.IsNotExist(err) {
		return "", err
	}
	if err := syscall.Mkfifo(pipePath, 0600); err != nil {
		return "", err
	}
	return pipePath, nil
}
3 changes: 3 additions & 0 deletions pgserver/connection_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"fmt"

"github.com/cockroachdb/cockroachdb-parser/pkg/sql/sem/tree"
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/vitess/go/vt/proto/query"
"github.com/jackc/pgx/v5/pgproto3"
"github.com/lib/pq/oid"
Expand Down Expand Up @@ -66,6 +67,8 @@ type copyFromStdinState struct {
// node is used to look at what parameters were specified, such as which table to load data into, file format,
// delimiters, etc.
copyFromStdinNode *tree.CopyFrom
// targetTable stores the targetTable that the data is being loaded into.
targetTable sql.InsertableTable
// dataLoader is the implementation of DataLoader that is used to load each individual CopyData chunk into the
// target table.
dataLoader DataLoader
Expand Down
159 changes: 138 additions & 21 deletions pgserver/connection_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"runtime/debug"
"slices"
"strings"
"sync/atomic"

"github.com/apecloud/myduckserver/adapter"
"github.com/apecloud/myduckserver/catalog"
Expand All @@ -38,7 +39,6 @@ import (
"github.com/cockroachdb/cockroachdb-parser/pkg/sql/sem/tree"
gms "github.com/dolthub/go-mysql-server"
"github.com/dolthub/go-mysql-server/server"
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/vitess/go/mysql"
"github.com/jackc/pgx/v5/pgproto3"
"github.com/jackc/pgx/v5/pgtype"
Expand Down Expand Up @@ -341,7 +341,7 @@ func (h *ConnectionHandler) receiveMessage() (bool, error) {
if HandlePanics {
defer func() {
if r := recover(); r != nil {
fmt.Printf("Listener recovered panic: %v\n%s\n", r, string(debug.Stack()))
h.logger.Debugf("Listener recovered panic: %v\n%s\n", r, string(debug.Stack()))

var eomErr error
if rErr, ok := r.(error); ok {
Expand All @@ -352,7 +352,7 @@ func (h *ConnectionHandler) receiveMessage() (bool, error) {

if !endOfMessages && h.waitForSync {
if syncErr := h.discardToSync(); syncErr != nil {
fmt.Println(syncErr.Error())
h.logger.Error(syncErr.Error())
}
}
h.endOfMessages(eomErr)
Expand Down Expand Up @@ -601,6 +601,8 @@ func (h *ConnectionHandler) handleQueryOutsideEngine(query ConvertedQuery) (hand
if stmt.Stdin {
return true, false, h.handleCopyFromStdinQuery(query, stmt, h.Conn())
}
case *tree.CopyTo:
return true, true, h.handleCopyToStdout(query, stmt)
}
return false, true, nil
}
Expand Down Expand Up @@ -815,25 +817,13 @@ func (h *ConnectionHandler) handleCopyDataHelper(message *pgproto3.CopyData) (st
if copyFrom == nil {
return false, false, fmt.Errorf("no COPY FROM STDIN node found")
}

// TODO: It would be better to get the table from the copyFromStdinNode – not by calling core.GetSqlTableFromContext
schemaName := copyFrom.Table.Schema()
tableName := copyFrom.Table.Table()
table, err := GetSqlTableFromContext(sqlCtx, schemaName, tableName)
if err != nil {
return false, true, err
}
table := h.copyFromStdinState.targetTable
if table == nil {
return false, true, fmt.Errorf(`relation "%s" does not exist`, tableName)
}
insertableTable, ok := table.(sql.InsertableTable)
if !ok {
return false, true, fmt.Errorf(`table "%s" is read-only`, tableName)
return false, true, fmt.Errorf("no target table found")
}

switch copyFrom.Options.CopyFormat {
case tree.CopyFormatText:
copyFrom.Options.Delimiter = tree.NewStrVal("\t")
// Remove trailing backslash, comma and newline characters from the data
if bytes.HasSuffix(message.Data, []byte{'\n'}) {
message.Data = message.Data[:len(message.Data)-1]
Expand All @@ -846,12 +836,15 @@ func (h *ConnectionHandler) handleCopyDataHelper(message *pgproto3.CopyData) (st
}
fallthrough
case tree.CopyFormatCSV:
dataLoader, err = NewCsvDataLoader(sqlCtx, h.duckHandler, schemaName, insertableTable, copyFrom.Columns, &copyFrom.Options)
dataLoader, err = NewCsvDataLoader(
sqlCtx, h.duckHandler,
copyFrom.Table.Schema(), table, copyFrom.Columns,
&copyFrom.Options,
)
case tree.CopyFormatBinary:
err = fmt.Errorf("BINARY format is not supported for COPY FROM")
default:
err = fmt.Errorf("unknown format specified for COPY FROM: %v",
copyFrom.Options.CopyFormat)
err = fmt.Errorf("unknown format specified for COPY FROM: %v", copyFrom.Options.CopyFormat)
}

if err != nil {
Expand Down Expand Up @@ -1254,12 +1247,14 @@ func (h *ConnectionHandler) handleCopyFromStdinQuery(query ConvertedQuery, copyF
}
sqlCtx.SetLogger(sqlCtx.GetLogger().WithField("query", query.String))

if err := ValidateCopyFrom(copyFrom, sqlCtx); err != nil {
table, err := ValidateCopyFrom(copyFrom, sqlCtx)
if err != nil {
return err
}

h.copyFromStdinState = &copyFromStdinState{
copyFromStdinNode: copyFrom,
targetTable: table,
}

return h.send(&pgproto3.CopyInResponse{
Expand Down Expand Up @@ -1297,3 +1292,125 @@ func returnsRow(tag string) bool {
return false
}
}

// handleCopyToStdout handles a COPY ... TO STDOUT statement. It validates the
// copy target, starts a DataWriter that streams the result into a FIFO pipe,
// relays the pipe's contents to the client as CopyData messages, and finishes
// with CopyDone plus a CommandComplete tag carrying the copied row count.
func (h *ConnectionHandler) handleCopyToStdout(query ConvertedQuery, copyTo *tree.CopyTo) error {
	ctx, err := h.duckHandler.NewContext(context.Background(), h.mysqlConn, query.String)
	if err != nil {
		return err
	}
	ctx.SetLogger(ctx.GetLogger().WithField("query", query.String))

	// Create cancelable context. The reader goroutine below calls cancel()
	// on any send/read failure so the writer side can stop producing.
	childCtx, cancel := context.WithCancel(ctx)
	defer cancel()
	ctx = ctx.WithContext(childCtx)

	table, err := ValidateCopyTo(copyTo, ctx)
	if err != nil {
		return err
	}

	// COPY (SELECT ...) TO STDOUT carries an inner statement; a plain
	// COPY table TO STDOUT leaves Statement nil, and stmt stays empty.
	var stmt string
	if copyTo.Statement != nil {
		stmt = copyTo.Statement.String()
	}
	writer, err := NewDataWriter(
		ctx, h.duckHandler,
		copyTo.Table.Schema(), table, copyTo.Columns,
		stmt,
		&copyTo.Options,
	)
	if err != nil {
		return err
	}
	defer writer.Close()

	// Send CopyOutResponse to the client
	ctx.GetLogger().Debug("sending CopyOutResponse to the client")
	copyOutResponse := &pgproto3.CopyOutResponse{
		OverallFormat:     0,           // 0 for text format
		ColumnFormatCodes: []uint16{0}, // 0 for text format
	}
	if err := h.send(copyOutResponse); err != nil {
		return err
	}

	// Start the writer; it produces data into the FIFO at pipePath and
	// reports its final result (row count or error) on ch.
	pipePath, ch, err := writer.Start()
	if err != nil {
		return err
	}

	// done signals that the reader goroutine has exited; sendErr carries the
	// first error it hit (stored atomically, read after <-done).
	done := make(chan struct{})
	var sendErr atomic.Value
	go func() {
		defer close(done)

		// Open the pipe for reading. This blocks until the writer side
		// opens its end of the FIFO.
		ctx.GetLogger().Tracef("Opening FIFO pipe for reading: %s", pipePath)
		pipe, err := os.OpenFile(pipePath, os.O_RDONLY, os.ModeNamedPipe)
		if err != nil {
			sendErr.Store(fmt.Errorf("failed to open pipe for reading: %w", err))
			cancel()
			return
		}
		defer pipe.Close()

		ctx.GetLogger().Debug("Copying data from the pipe to the client")
		defer func() {
			ctx.GetLogger().Debug("Finished copying data from the pipe to the client")
		}()

		buf := make([]byte, 1<<20) // 1MB buffer
		for {
			// Read may return n > 0 together with an error (including
			// io.EOF), so forward the data before inspecting err.
			n, err := pipe.Read(buf)
			if n > 0 {
				copyData := &pgproto3.CopyData{
					Data: buf[:n],
				}
				ctx.GetLogger().Debugf("sending CopyData (%d bytes) to the client", n)
				if err := h.send(copyData); err != nil {
					sendErr.Store(err)
					cancel()
					return
				}
			}
			if err != nil {
				// EOF means the writer closed its end: normal completion.
				if err == io.EOF {
					break
				}
				sendErr.Store(err)
				cancel()
				return
			}
		}
	}()

	select {
	case <-ctx.Done(): // Context is canceled
		// Wait for the reader goroutine before touching sendErr.
		<-done
		err, _ := sendErr.Load().(error)
		return errors.Join(ctx.Err(), err)
	case result := <-ch:
		// Writer finished; wait for the reader to drain the pipe and exit.
		<-done

		if result.Err != nil {
			return fmt.Errorf("failed to copy data: %w", result.Err)
		}

		if err, ok := sendErr.Load().(error); ok {
			return err
		}

		// After data is sent and the producer side is finished without errors, send CopyDone
		ctx.GetLogger().Debug("sending CopyDone to the client")
		if err := h.send(&pgproto3.CopyDone{}); err != nil {
			return err
		}

		// Send CommandComplete with the number of rows copied
		ctx.GetLogger().Debugf("sending CommandComplete to the client")
		return h.send(&pgproto3.CommandComplete{
			CommandTag: []byte(fmt.Sprintf("COPY %d", result.RowCount)),
		})
	}
}
Loading