Skip to content

Commit 632c290

Browse files
committed
Make it work
1 parent 28cc8eb commit 632c290

File tree

6 files changed

+170
-66
lines changed

6 files changed

+170
-66
lines changed

adapter/adapter.go

+23
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@ type ConnectionHolder interface {
1111
GetConn(ctx context.Context) (*stdsql.Conn, error)
1212
GetTxn(ctx context.Context, options *stdsql.TxOptions) (*stdsql.Tx, error)
1313
GetCatalogConn(ctx context.Context) (*stdsql.Conn, error)
14+
GetCatalogTxn(ctx context.Context, options *stdsql.TxOptions) (*stdsql.Tx, error)
15+
TryGetTxn() *stdsql.Tx
16+
CloseTxn()
1417
}
1518

1619
func GetConn(ctx *sql.Context) (*stdsql.Conn, error) {
@@ -21,6 +24,18 @@ func GetTxn(ctx *sql.Context, options *stdsql.TxOptions) (*stdsql.Tx, error) {
2124
return ctx.Session.(ConnectionHolder).GetTxn(ctx, options)
2225
}
2326

27+
func GetCatalogTxn(ctx *sql.Context, options *stdsql.TxOptions) (*stdsql.Tx, error) {
28+
return ctx.Session.(ConnectionHolder).GetCatalogTxn(ctx, options)
29+
}
30+
31+
func TryGetTxn(ctx *sql.Context) *stdsql.Tx {
32+
return ctx.Session.(ConnectionHolder).TryGetTxn()
33+
}
34+
35+
func CloseTxn(ctx *sql.Context) {
36+
ctx.Session.(ConnectionHolder).CloseTxn()
37+
}
38+
2439
func Query(ctx *sql.Context, query string, args ...any) (*stdsql.Rows, error) {
2540
conn, err := GetConn(ctx)
2641
if err != nil {
@@ -75,6 +90,14 @@ func ExecCatalog(ctx *sql.Context, query string, args ...any) (stdsql.Result, er
7590
return conn.ExecContext(ctx, query, args...)
7691
}
7792

93+
func ExecCatalogInTxn(ctx *sql.Context, query string, args ...any) (stdsql.Result, error) {
94+
tx, err := ctx.Session.(ConnectionHolder).GetCatalogTxn(ctx, nil)
95+
if err != nil {
96+
return nil, err
97+
}
98+
return tx.ExecContext(ctx, query, args...)
99+
}
100+
78101
func ExecInTxn(ctx *sql.Context, query string, args ...any) (stdsql.Result, error) {
79102
tx, err := GetTxn(ctx, nil)
80103
if err != nil {

backend/connpool.go

+8
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,14 @@ func (p *ConnectionPool) GetTxn(ctx context.Context, id uint32, schemaName strin
119119
return tx, nil
120120
}
121121

122+
func (p *ConnectionPool) TryGetTxn(id uint32) *stdsql.Tx {
123+
entry, ok := p.txns.Load(id)
124+
if !ok {
125+
return nil
126+
}
127+
return entry.(*stdsql.Tx)
128+
}
129+
122130
func (p *ConnectionPool) CloseTxn(id uint32) {
123131
p.txns.Delete(id)
124132
}

backend/session.go

+17-5
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ func (sess Session) StartTransaction(ctx *sql.Context, tCharacteristic sql.Trans
9090

9191
var tx *stdsql.Tx
9292
if startUnderlyingTx {
93-
sess.GetLogger().Infoln("StartDuckTransaction")
93+
sess.GetLogger().Trace("StartDuckTransaction")
9494
tx, err = sess.GetTxn(ctx, &stdsql.TxOptions{ReadOnly: tCharacteristic == sql.ReadOnly})
9595
if err != nil {
9696
return nil, err
@@ -101,10 +101,10 @@ func (sess Session) StartTransaction(ctx *sql.Context, tCharacteristic sql.Trans
101101

102102
// CommitTransaction implements sql.TransactionSession.
103103
func (sess Session) CommitTransaction(ctx *sql.Context, tx sql.Transaction) error {
104-
sess.GetLogger().Infoln("CommitTransaction")
104+
sess.GetLogger().Trace("CommitTransaction")
105105
transaction := tx.(*Transaction)
106106
if transaction.tx != nil {
107-
sess.GetLogger().Infoln("CommitDuckTransaction")
107+
sess.GetLogger().Trace("CommitDuckTransaction")
108108
defer sess.CloseTxn()
109109
if err := transaction.tx.Commit(); err != nil {
110110
return err
@@ -115,10 +115,10 @@ func (sess Session) CommitTransaction(ctx *sql.Context, tx sql.Transaction) erro
115115

116116
// Rollback implements sql.TransactionSession.
117117
func (sess Session) Rollback(ctx *sql.Context, tx sql.Transaction) error {
118-
sess.GetLogger().Infoln("Rollback")
118+
sess.GetLogger().Trace("Rollback")
119119
transaction := tx.(*Transaction)
120120
if transaction.tx != nil {
121-
sess.GetLogger().Infoln("RollbackDuckTransaction")
121+
sess.GetLogger().Trace("RollbackDuckTransaction")
122122
defer sess.CloseTxn()
123123
if err := transaction.tx.Rollback(); err != nil {
124124
return err
@@ -193,10 +193,22 @@ func (sess Session) GetCatalogConn(ctx context.Context) (*stdsql.Conn, error) {
193193
return sess.pool.GetConn(ctx, sess.ID())
194194
}
195195

196+
// GetTxn implements adapter.ConnectionHolder.
196197
func (sess Session) GetTxn(ctx context.Context, options *stdsql.TxOptions) (*stdsql.Tx, error) {
197198
return sess.pool.GetTxn(ctx, sess.ID(), sess.GetCurrentDatabase(), options)
198199
}
199200

201+
// GetCatalogTxn implements adapter.ConnectionHolder.
202+
func (sess Session) GetCatalogTxn(ctx context.Context, options *stdsql.TxOptions) (*stdsql.Tx, error) {
203+
return sess.pool.GetTxn(ctx, sess.ID(), "", options)
204+
}
205+
206+
// TryGetTxn implements adapter.ConnectionHolder.
207+
func (sess Session) TryGetTxn() *stdsql.Tx {
208+
return sess.pool.TryGetTxn(sess.ID())
209+
}
210+
211+
// CloseTxn implements adapter.ConnectionHolder.
200212
func (sess Session) CloseTxn() {
201213
sess.pool.CloseTxn(sess.ID())
202214
}

binlogreplication/binlog_position_store.go

+7-7
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ import (
3232
)
3333

3434
const binlogPositionDirectory = ".replica"
35-
const binlogPositionFilename = "binlog-position"
3635
const mysqlFlavor = "MySQL56"
3736
const defaultChannelName = ""
3837

@@ -61,10 +60,7 @@ func (store *binlogPositionStore) Load(ctx *sql.Context, engine *gms.Engine) (po
6160
}
6261

6362
// Strip off the "MySQL56/" prefix
64-
prefix := "MySQL56/"
65-
if strings.HasPrefix(positionString, prefix) {
66-
positionString = positionString[len(prefix):]
67-
}
63+
positionString = strings.TrimPrefix(positionString, "MySQL56/")
6864

6965
return replication.ParsePosition(mysqlFlavor, positionString)
7066
}
@@ -81,7 +77,11 @@ func (store *binlogPositionStore) Save(ctx *sql.Context, engine *gms.Engine, pos
8177
store.mu.Lock()
8278
defer store.mu.Unlock()
8379

84-
if _, err := adapter.ExecInTxn(ctx, catalog.InternalTables.BinlogPosition.UpsertStmt(), defaultChannelName, position.String()); err != nil {
80+
if _, err := adapter.ExecCatalogInTxn(
81+
ctx,
82+
catalog.InternalTables.BinlogPosition.UpsertStmt(),
83+
defaultChannelName, position.String(),
84+
); err != nil {
8585
return fmt.Errorf("unable to save binlog position: %w", err)
8686
}
8787
return nil
@@ -94,7 +94,7 @@ func (store *binlogPositionStore) Delete(ctx *sql.Context, engine *gms.Engine) e
9494
store.mu.Lock()
9595
defer store.mu.Unlock()
9696

97-
_, err := adapter.ExecInTxn(ctx, catalog.InternalTables.BinlogPosition.DeleteStmt(), defaultChannelName)
97+
_, err := adapter.ExecCatalogInTxn(ctx, catalog.InternalTables.BinlogPosition.DeleteStmt(), defaultChannelName)
9898
return err
9999
}
100100

0 commit comments

Comments
 (0)