Skip to content

Commit

Permalink
executor: Implement batch point get for local temporary table (pingca…
Browse files Browse the repository at this point in the history
  • Loading branch information
lcwangchao authored Jul 19, 2021
1 parent 014005a commit 2f028b3
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 6 deletions.
37 changes: 31 additions & 6 deletions executor/batch_point_get.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,9 @@ func (e *BatchPointGetExec) Open(context.Context) error {
setResourceGroupTagForTxn(stmtCtx, snapshot)
// Avoid network requests for the temporary table.
if e.tblInfo.TempTableType == model.TempTableGlobal {
snapshot = globalTemporaryTableSnapshot{snapshot}
snapshot = temporaryTableSnapshot{snapshot, nil}
} else if e.tblInfo.TempTableType == model.TempTableLocal {
snapshot = temporaryTableSnapshot{snapshot, e.ctx.GetSessionVars().TemporaryTableData}
}
var batchGetter kv.BatchGetter = snapshot
if txn.Valid() {
Expand All @@ -166,14 +168,37 @@ func (e *BatchPointGetExec) Open(context.Context) error {
return nil
}

// Global temporary table would always be empty, so get the snapshot data of it is meanless.
// globalTemporaryTableSnapshot inherits kv.Snapshot and override the BatchGet methods to return empty.
type globalTemporaryTableSnapshot struct {
// Temporary table would always use memBuffer in session as snapshot.
// temporaryTableSnapshot inherits kv.Snapshot and override the BatchGet methods to return empty.
type temporaryTableSnapshot struct {
kv.Snapshot
memBuffer kv.MemBuffer
}

func (s globalTemporaryTableSnapshot) BatchGet(ctx context.Context, keys []kv.Key) (map[string][]byte, error) {
return make(map[string][]byte), nil
func (s temporaryTableSnapshot) BatchGet(ctx context.Context, keys []kv.Key) (map[string][]byte, error) {
values := make(map[string][]byte)
if s.memBuffer == nil {
return values, nil
}

for _, key := range keys {
val, err := s.memBuffer.Get(ctx, key)
if err == kv.ErrNotExist {
continue
}

if err != nil {
return nil, err
}

if len(val) == 0 {
continue
}

values[string(key)] = val
}

return values, nil
}

// Close implements the Executor interface.
Expand Down
46 changes: 46 additions & 0 deletions session/session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4986,6 +4986,7 @@ func (s *testSessionSuite) TestLocalTemporaryTablePointGet(c *C) {
tk.MustExec("create temporary table tmp1 (id int primary key auto_increment, u int unique, v int)")
tk.MustExec("insert into tmp1 values(1, 11, 101)")
tk.MustExec("insert into tmp1 values(2, 12, 102)")
tk.MustExec("insert into tmp1 values(4, 14, 104)")

// check point get out transaction
tk.MustQuery("select * from tmp1 where id=1").Check(testkit.Rows("1 11 101"))
Expand All @@ -5004,10 +5005,55 @@ func (s *testSessionSuite) TestLocalTemporaryTablePointGet(c *C) {
tk.MustQuery("select * from tmp1 where u=13").Check(testkit.Rows("3 13 103"))
tk.MustExec("update tmp1 set v=999 where id=2")
tk.MustQuery("select * from tmp1 where id=2").Check(testkit.Rows("2 12 999"))
tk.MustExec("delete from tmp1 where id=4")
tk.MustQuery("select * from tmp1 where id=4").Check(testkit.Rows())
tk.MustQuery("select * from tmp1 where u=14").Check(testkit.Rows())
tk.MustExec("commit")

// check point get after transaction
tk.MustQuery("select * from tmp1 where id=3").Check(testkit.Rows("3 13 103"))
tk.MustQuery("select * from tmp1 where u=13").Check(testkit.Rows("3 13 103"))
tk.MustQuery("select * from tmp1 where id=2").Check(testkit.Rows("2 12 999"))
tk.MustQuery("select * from tmp1 where id=4").Check(testkit.Rows())
tk.MustQuery("select * from tmp1 where u=14").Check(testkit.Rows())
}

func (s *testSessionSuite) TestLocalTemporaryTableBatchPointGet(c *C) {
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("set @@tidb_enable_noop_functions=1")
tk.MustExec("use test")
tk.MustExec("create temporary table tmp1 (id int primary key auto_increment, u int unique, v int)")
tk.MustExec("insert into tmp1 values(1, 11, 101)")
tk.MustExec("insert into tmp1 values(2, 12, 102)")
tk.MustExec("insert into tmp1 values(3, 13, 103)")
tk.MustExec("insert into tmp1 values(4, 14, 104)")

// check point get out transaction
tk.MustQuery("select * from tmp1 where id in (1, 3)").Check(testkit.Rows("1 11 101", "3 13 103"))
tk.MustQuery("select * from tmp1 where u in (11, 13)").Check(testkit.Rows("1 11 101", "3 13 103"))
tk.MustQuery("select * from tmp1 where id in (1, 3, 5)").Check(testkit.Rows("1 11 101", "3 13 103"))
tk.MustQuery("select * from tmp1 where u in (11, 13, 15)").Check(testkit.Rows("1 11 101", "3 13 103"))

// check point get in transaction
tk.MustExec("begin")
tk.MustQuery("select * from tmp1 where id in (1, 3)").Check(testkit.Rows("1 11 101", "3 13 103"))
tk.MustQuery("select * from tmp1 where u in (11, 13)").Check(testkit.Rows("1 11 101", "3 13 103"))
tk.MustQuery("select * from tmp1 where id in (1, 3, 5)").Check(testkit.Rows("1 11 101", "3 13 103"))
tk.MustQuery("select * from tmp1 where u in (11, 13, 15)").Check(testkit.Rows("1 11 101", "3 13 103"))
tk.MustExec("insert into tmp1 values(6, 16, 106)")
tk.MustQuery("select * from tmp1 where id in (1, 6)").Check(testkit.Rows("1 11 101", "6 16 106"))
tk.MustQuery("select * from tmp1 where u in (11, 16)").Check(testkit.Rows("1 11 101", "6 16 106"))
tk.MustExec("update tmp1 set v=999 where id=3")
tk.MustQuery("select * from tmp1 where id in (1, 3)").Check(testkit.Rows("1 11 101", "3 13 999"))
tk.MustQuery("select * from tmp1 where u in (11, 13)").Check(testkit.Rows("1 11 101", "3 13 999"))
tk.MustExec("delete from tmp1 where id=4")
tk.MustQuery("select * from tmp1 where id in (1, 4)").Check(testkit.Rows("1 11 101"))
tk.MustQuery("select * from tmp1 where u in (11, 14)").Check(testkit.Rows("1 11 101"))
tk.MustExec("commit")

// check point get after transaction
tk.MustQuery("select * from tmp1 where id in (1, 3, 6)").Check(testkit.Rows("1 11 101", "3 13 999", "6 16 106"))
tk.MustQuery("select * from tmp1 where u in (11, 13, 16)").Check(testkit.Rows("1 11 101", "3 13 999", "6 16 106"))
tk.MustQuery("select * from tmp1 where id in (1, 4)").Check(testkit.Rows("1 11 101"))
tk.MustQuery("select * from tmp1 where u in (11, 14)").Check(testkit.Rows("1 11 101"))
}

0 comments on commit 2f028b3

Please sign in to comment.