diff --git a/lightning/pkg/importer/precheck_impl.go b/lightning/pkg/importer/precheck_impl.go index 3327813e701fa..d36afc0254978 100644 --- a/lightning/pkg/importer/precheck_impl.go +++ b/lightning/pkg/importer/precheck_impl.go @@ -1296,7 +1296,7 @@ func checkFieldCompatibility( values []types.Datum, logger log.Logger, ) bool { - se := kv.NewSessionCtx(&encode.SessionOptions{ + se := kv.NewSession(&encode.SessionOptions{ SQLMode: mysql.ModeStrictTransTables, }, logger) for i, col := range tbl.Columns { @@ -1307,7 +1307,7 @@ func checkFieldCompatibility( if i >= len(values) { break } - _, err := table.CastValue(se, values[i], col, true, false) + _, err := table.CastColumnValue(se.GetExprCtx(), values[i], col, true, false) if err != nil { logger.Error("field value is not consistent with column type", zap.String("value", values[i].GetString()), zap.Any("column_info", col), zap.Error(err)) diff --git a/pkg/executor/batch_checker.go b/pkg/executor/batch_checker.go index 5bec8e3837558..a3aed84315672 100644 --- a/pkg/executor/batch_checker.go +++ b/pkg/executor/batch_checker.go @@ -280,7 +280,7 @@ func getOldRow(ctx context.Context, sctx sessionctx.Context, txn kv.Transaction, } cols := t.WritableCols() - oldRow, oldRowMap, err := tables.DecodeRawRowData(sctx, t.Meta(), handle, cols, oldValue) + oldRow, oldRowMap, err := tables.DecodeRawRowData(sctx.GetExprCtx(), t.Meta(), handle, cols, oldValue) if err != nil { return nil, err } diff --git a/pkg/executor/importer/BUILD.bazel b/pkg/executor/importer/BUILD.bazel index 4ae10013e1088..bfec2f2411196 100644 --- a/pkg/executor/importer/BUILD.bazel +++ b/pkg/executor/importer/BUILD.bazel @@ -44,6 +44,7 @@ go_library( "//pkg/parser/model", "//pkg/parser/mysql", "//pkg/parser/terror", + "//pkg/planner/context", "//pkg/planner/core", "//pkg/planner/util", "//pkg/sessionctx", diff --git a/pkg/executor/importer/import.go b/pkg/executor/importer/import.go index e95eb79faf97d..be3dfb5f434e6 100644 --- a/pkg/executor/importer/import.go +++ b/pkg/executor/importer/import.go @@ -45,6 +45,7 @@ import ( "github.com/pingcap/tidb/pkg/parser/model" "github.com/pingcap/tidb/pkg/parser/mysql" "github.com/pingcap/tidb/pkg/parser/terror" + planctx "github.com/pingcap/tidb/pkg/planner/context" plannercore "github.com/pingcap/tidb/pkg/planner/core" plannerutil "github.com/pingcap/tidb/pkg/planner/util" "github.com/pingcap/tidb/pkg/sessionctx" @@ -1291,17 +1292,17 @@ func (p *Plan) IsGlobalSort() bool { // CreateColAssignExprs creates the column assignment expressions using session context. // RewriteAstExpr will write ast node in place(due to xxNode.Accept), but it doesn't change node content, // so we sync it. -func (e *LoadDataController) CreateColAssignExprs(sctx sessionctx.Context) ([]expression.Expression, []contextutil.SQLWarn, error) { +func (e *LoadDataController) CreateColAssignExprs(planCtx planctx.PlanContext) ([]expression.Expression, []contextutil.SQLWarn, error) { e.colAssignMu.Lock() defer e.colAssignMu.Unlock() res := make([]expression.Expression, 0, len(e.ColumnAssignments)) allWarnings := []contextutil.SQLWarn{} for _, assign := range e.ColumnAssignments { - newExpr, err := plannerutil.RewriteAstExprWithPlanCtx(sctx.GetPlanCtx(), assign.Expr, nil, nil, false) + newExpr, err := plannerutil.RewriteAstExprWithPlanCtx(planCtx, assign.Expr, nil, nil, false) // col assign expr warnings is static, we should generate it for each row processed. // so we save it and clear it here. - allWarnings = append(allWarnings, sctx.GetSessionVars().StmtCtx.GetWarnings()...) - sctx.GetSessionVars().StmtCtx.SetWarnings(nil) + allWarnings = append(allWarnings, planCtx.GetSessionVars().StmtCtx.GetWarnings()...) + planCtx.GetSessionVars().StmtCtx.SetWarnings(nil) if err != nil { return nil, nil, err } diff --git a/pkg/executor/importer/kv_encode.go b/pkg/executor/importer/kv_encode.go index a7ea8d89b59e0..6a6d66af42671 100644 --- a/pkg/executor/importer/kv_encode.go +++ b/pkg/executor/importer/kv_encode.go @@ -27,7 +27,6 @@ import ( "github.com/pingcap/tidb/pkg/meta/autoid" "github.com/pingcap/tidb/pkg/parser/ast" "github.com/pingcap/tidb/pkg/parser/mysql" //nolint: goimports - "github.com/pingcap/tidb/pkg/sessionctx/variable" "github.com/pingcap/tidb/pkg/table" "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util/chunk" @@ -64,8 +63,8 @@ func NewTableKVEncoder( return nil, err } // we need a non-nil TxnCtx to avoid panic when evaluating set clause - baseKVEncoder.SessionCtx.Vars.TxnCtx = new(variable.TransactionContext) - colAssignExprs, _, err := ti.CreateColAssignExprs(baseKVEncoder.SessionCtx) + baseKVEncoder.SessionCtx.SetTxnCtxNotNil() + colAssignExprs, _, err := ti.CreateColAssignExprs(baseKVEncoder.SessionCtx.GetPlanCtx()) if err != nil { return nil, err } @@ -95,24 +94,20 @@ func (en *tableKVEncoder) Encode(row []types.Datum, rowID int64) (*kv.Pairs, err } func (en *tableKVEncoder) GetColumnSize() map[int64]int64 { - sessionVars := en.SessionCtx.GetSessionVars() - sessionVars.TxnCtxMu.Lock() - defer sessionVars.TxnCtxMu.Unlock() - return sessionVars.TxnCtx.TableDeltaMap[en.TableMeta().ID].ColSize + return en.SessionCtx.GetColumnSize(en.TableMeta().ID) } // todo merge with code in load_data.go func (en *tableKVEncoder) parserData2TableData(parserData []types.Datum, rowID int64) ([]types.Datum, error) { row := make([]types.Datum, 0, len(en.insertColumns)) - sessionVars := en.SessionCtx.GetSessionVars() setVar := func(name string, col *types.Datum) { // User variable names are not case-sensitive // https://dev.mysql.com/doc/refman/8.0/en/user-variables.html name = strings.ToLower(name) if col == nil || col.IsNull() { - sessionVars.UnsetUserVar(name) + en.SessionCtx.UnsetUserVar(name) } else { - sessionVars.SetUserVarVal(name, *col) + en.SessionCtx.SetUserVarVal(name, *col) } } @@ -166,7 +161,7 @@ func (en *tableKVEncoder) getRow(vals []types.Datum, rowID int64) ([]types.Datum row := make([]types.Datum, len(en.Columns)) hasValue := make([]bool, len(en.Columns)) for i := 0; i < len(en.insertColumns); i++ { - casted, err := table.CastValue(en.SessionCtx, vals[i], en.insertColumns[i].ToInfo(), false, false) + casted, err := table.CastColumnValue(en.SessionCtx.GetExprCtx(), vals[i], en.insertColumns[i].ToInfo(), false, false) if err != nil { return nil, err } diff --git a/pkg/executor/load_data.go b/pkg/executor/load_data.go index fde54ed25360e..298cef32299d1 100644 --- a/pkg/executor/load_data.go +++ b/pkg/executor/load_data.go @@ -291,7 +291,7 @@ func initEncodeCommitWorkers(e *LoadDataWorker) (*encodeWorker, *commitWorker, e if err2 != nil { return nil, nil, err2 } - colAssignExprs, exprWarnings, err2 := e.controller.CreateColAssignExprs(insertValues.Ctx()) + colAssignExprs, exprWarnings, err2 := e.controller.CreateColAssignExprs(insertValues.Ctx().GetPlanCtx()) if err2 != nil { return nil, nil, err2 } diff --git a/pkg/lightning/backend/kv/BUILD.bazel b/pkg/lightning/backend/kv/BUILD.bazel index fa510dd6d62a3..6d201fe17adb1 100644 --- a/pkg/lightning/backend/kv/BUILD.bazel +++ b/pkg/lightning/backend/kv/BUILD.bazel @@ -58,13 +58,12 @@ go_test( "base_test.go", "kv2sql_test.go", "session_internal_test.go", - "session_test.go", "sql2kv_test.go", ], embed = [":kv"], flaky = True, race = "on", - shard_count = 19, + shard_count = 18, deps = [ "//pkg/ddl", "//pkg/kv", diff --git a/pkg/lightning/backend/kv/base.go b/pkg/lightning/backend/kv/base.go index 3b4d753a42cc7..0c67492c323d0 100644 --- a/pkg/lightning/backend/kv/base.go +++ b/pkg/lightning/backend/kv/base.go @@ -28,7 +28,6 @@ import ( "github.com/pingcap/tidb/pkg/meta/autoid" "github.com/pingcap/tidb/pkg/parser/model" "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/sessionctx/variable" "github.com/pingcap/tidb/pkg/table" "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util/chunk" @@ -208,11 +207,7 @@ func (e *BaseKVEncoder) Record2KV(record, originalRow []types.Datum, rowID int64 // AddRecord adds a record into encoder func (e *BaseKVEncoder) AddRecord(record []types.Datum) (kv.Handle, error) { - txn, err := e.SessionCtx.Txn(true) - if err != nil { - return nil, err - } - return e.table.AddRecord(e.SessionCtx.GetTableCtx(), txn, record, table.DupKeyCheckSkip) + return e.table.AddRecord(e.SessionCtx.GetTableCtx(), e.SessionCtx.Txn(), record, table.DupKeyCheckSkip) } // TableAllocators returns the allocators of the table @@ -258,8 +253,10 @@ func (e *BaseKVEncoder) getActualDatum(col *table.Column, rowID int64, inputDatu ) isBadNullValue := false + exprCtx := e.SessionCtx.GetExprCtx() + errCtx := exprCtx.GetEvalCtx().ErrCtx() if inputDatum != nil { - value, err = table.CastValue(e.SessionCtx, *inputDatum, col.ToInfo(), false, false) + value, err = table.CastColumnValue(exprCtx, *inputDatum, col.ToInfo(), false, false) if err != nil { return value, err } @@ -272,7 +269,7 @@ func (e *BaseKVEncoder) getActualDatum(col *table.Column, rowID int64, inputDatu switch { case IsAutoIncCol(col.ToInfo()): // we still need a conversion, e.g. to catch overflow with a TINYINT column. - value, err = table.CastValue(e.SessionCtx, + value, err = table.CastColumnValue(exprCtx, types.NewIntDatum(rowID), col.ToInfo(), false, false) case e.IsAutoRandomCol(col.ToInfo()): var val types.Datum @@ -282,21 +279,19 @@ func (e *BaseKVEncoder) getActualDatum(col *table.Column, rowID int64, inputDatu } else { val = types.NewIntDatum(realRowID) } - value, err = table.CastValue(e.SessionCtx, val, col.ToInfo(), false, false) + value, err = table.CastColumnValue(exprCtx, val, col.ToInfo(), false, false) case col.IsGenerated(): // inject some dummy value for gen col so that MutRowFromDatums below sees a real value instead of nil. // if MutRowFromDatums sees a nil it won't initialize the underlying storage and cause SetDatum to panic. value = types.GetMinValue(&col.FieldType) case isBadNullValue: - err = col.HandleBadNull(e.SessionCtx.Vars.StmtCtx.ErrCtx(), &value, 0) + err = col.HandleBadNull(errCtx, &value, 0) default: // copy from the following GetColDefaultValue function, when this is true it will use getColDefaultExprValue if col.DefaultIsExpr { // the expression rewriter requires a non-nil TxnCtx. - e.SessionCtx.Vars.TxnCtx = new(variable.TransactionContext) - defer func() { - e.SessionCtx.Vars.TxnCtx = nil - }() + deferFn := e.SessionCtx.SetTxnCtxNotNil() + defer deferFn() } value, err = table.GetColDefaultValue(e.SessionCtx.GetExprCtx(), col.ToInfo()) } @@ -363,7 +358,7 @@ func (e *BaseKVEncoder) LogEvalGenExprFailed(row []types.Datum, colInfo *model.C // TruncateWarns resets the warnings in session context. func (e *BaseKVEncoder) TruncateWarns() { - e.SessionCtx.Vars.StmtCtx.TruncateWarnings(0) + e.SessionCtx.GetExprCtx().GetEvalCtx().TruncateWarnings(0) } func evalGeneratedColumns(se *Session, record []types.Datum, cols []*table.Column, @@ -375,7 +370,7 @@ func evalGeneratedColumns(se *Session, record []types.Datum, cols []*table.Colum if err != nil { return col, err } - value, err := table.CastValue(se, evaluated, col, false, false) + value, err := table.CastColumnValue(se.GetExprCtx(), evaluated, col, false, false) if err != nil { return col, err } diff --git a/pkg/lightning/backend/kv/kv2sql.go b/pkg/lightning/backend/kv/kv2sql.go index 9a6afa9724f3c..14b8538944478 100644 --- a/pkg/lightning/backend/kv/kv2sql.go +++ b/pkg/lightning/backend/kv/kv2sql.go @@ -54,7 +54,7 @@ func (t *TableKVDecoder) DecodeHandleFromIndex(indexInfo *model.IndexInfo, key, // DecodeRawRowData decodes raw row data into a datum slice and a (columnID:columnValue) map. func (t *TableKVDecoder) DecodeRawRowData(h kv.Handle, value []byte) ([]types.Datum, map[int64]types.Datum, error) { - return tables.DecodeRawRowData(t.se, t.tbl.Meta(), h, t.tbl.Cols(), value) + return tables.DecodeRawRowData(t.se.GetExprCtx(), t.tbl.Meta(), h, t.tbl.Cols(), value) } // DecodeRawRowDataAsStr decodes raw row data into a string. @@ -92,6 +92,8 @@ func (t *TableKVDecoder) IterRawIndexKeys(h kv.Handle, rawRow []byte, fn func([] var buffer []types.Datum var indexBuffer []byte + evalCtx := t.se.GetExprCtx().GetEvalCtx() + ec, loc := evalCtx.ErrCtx(), evalCtx.Location() for _, index := range indices { // skip clustered PK if index.Meta().Primary && isCommonHandle { @@ -102,8 +104,7 @@ func (t *TableKVDecoder) IterRawIndexKeys(h kv.Handle, rawRow []byte, fn func([] if err != nil { return err } - sc := t.se.Vars.StmtCtx - iter := index.GenIndexKVIter(sc.ErrCtx(), sc.TimeZone(), indexValues, h, nil) + iter := index.GenIndexKVIter(ec, loc, indexValues, h, nil) for iter.Valid() { indexKey, _, _, err := iter.Next(indexBuffer, nil) if err != nil { diff --git a/pkg/lightning/backend/kv/kv2sql_test.go b/pkg/lightning/backend/kv/kv2sql_test.go index 411767d0d443f..27939dc65f739 100644 --- a/pkg/lightning/backend/kv/kv2sql_test.go +++ b/pkg/lightning/backend/kv/kv2sql_test.go @@ -53,8 +53,7 @@ func TestIterRawIndexKeysClusteredPK(t *testing.T) { require.NoError(t, err) sctx := kv.NewSession(sessionOpts, log.L()) - txn, err := sctx.Txn(true) - require.NoError(t, err) + txn := sctx.Txn() handle, err := tbl.AddRecord(sctx.GetTableCtx(), txn, []types.Datum{types.NewIntDatum(1), types.NewIntDatum(2)}) require.NoError(t, err) paris := sctx.TakeKvPairs() @@ -94,7 +93,7 @@ func TestIterRawIndexKeysIntPK(t *testing.T) { require.NoError(t, err) sctx := kv.NewSession(sessionOpts, log.L()) - txn, err := sctx.Txn(true) + txn := sctx.Txn() require.NoError(t, err) handle, err := tbl.AddRecord(sctx.GetTableCtx(), txn, []types.Datum{types.NewIntDatum(1), types.NewIntDatum(2)}) require.NoError(t, err) diff --git a/pkg/lightning/backend/kv/session.go b/pkg/lightning/backend/kv/session.go index 9618a567c9279..906d5085196bb 100644 --- a/pkg/lightning/backend/kv/session.go +++ b/pkg/lightning/backend/kv/session.go @@ -20,11 +20,13 @@ import ( "context" "errors" "fmt" + "maps" "strconv" "sync" "github.com/docker/go-units" "github.com/pingcap/tidb/pkg/errctx" + "github.com/pingcap/tidb/pkg/expression" exprctx "github.com/pingcap/tidb/pkg/expression/context" exprctximpl "github.com/pingcap/tidb/pkg/expression/contextsession" infoschema "github.com/pingcap/tidb/pkg/infoschema/context" @@ -40,6 +42,7 @@ import ( "github.com/pingcap/tidb/pkg/sessionctx/variable" tbctx "github.com/pingcap/tidb/pkg/table/context" tbctximpl "github.com/pingcap/tidb/pkg/table/contextimpl" + "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util/mathutil" "github.com/pingcap/tidb/pkg/util/topsql/stmtstats" "go.uber.org/zap" @@ -283,14 +286,16 @@ func (*transaction) MayFlush() error { } type planCtxImpl struct { - *Session + *session *planctximpl.PlanCtxExtendedImpl } -// Session is a trimmed down Session type which only wraps our own trimmed-down +// session is a trimmed down Session type which only wraps our own trimmed-down // transaction type and provides the session variables to the TiDB library // optimized for Lightning. -type Session struct { +// The `session` object is private to make sure it is only used by public `Session` struct to provide limited access. +// TODO: remove `session` and build related context without a mocked `sessionctx.Context` instead. +type session struct { sessionctx.Context planctx.EmptyPlanContextExtended txn transaction @@ -302,14 +307,9 @@ type Session struct { values map[fmt.Stringer]any } -// NewSessionCtx creates a new trimmed down Session matching the options. -func NewSessionCtx(options *encode.SessionOptions, logger log.Logger) sessionctx.Context { - return NewSession(options, logger) -} - -// NewSession creates a new trimmed down Session matching the options. -func NewSession(options *encode.SessionOptions, logger log.Logger) *Session { - s := &Session{ +// newSession creates a new trimmed down Session matching the options. +func newSession(options *encode.SessionOptions, logger log.Logger) *session { + s := &session{ values: make(map[fmt.Stringer]any, 1), } sqlMode := options.SQLMode @@ -359,7 +359,7 @@ func NewSession(options *encode.SessionOptions, logger log.Logger) *Session { s.Vars = vars s.exprCtx = exprctximpl.NewSessionExprContext(s) s.planctx = &planCtxImpl{ - Session: s, + session: s, PlanCtxExtendedImpl: planctximpl.NewPlanCtxExtendedImpl(s), } s.tblctx = tbctximpl.NewTableContextImpl(s) @@ -368,69 +368,131 @@ func NewSession(options *encode.SessionOptions, logger log.Logger) *Session { return s } -// TakeKvPairs returns the current Pairs and resets the buffer. -func (se *Session) TakeKvPairs() *Pairs { - memBuf := &se.txn.MemBuf - pairs := memBuf.kvPairs - if pairs.BytesBuf != nil { - pairs.MemBuf = memBuf - } - memBuf.kvPairs = &Pairs{Pairs: make([]common.KvPair, 0, len(pairs.Pairs))} - memBuf.size = 0 - return pairs -} - // Txn implements the sessionctx.Context interface -func (se *Session) Txn(_ bool) (kv.Transaction, error) { +func (se *session) Txn(_ bool) (kv.Transaction, error) { return &se.txn, nil } // GetSessionVars implements the sessionctx.Context interface -func (se *Session) GetSessionVars() *variable.SessionVars { +func (se *session) GetSessionVars() *variable.SessionVars { return se.Vars } // GetPlanCtx returns the PlanContext. -func (se *Session) GetPlanCtx() planctx.PlanContext { +func (se *session) GetPlanCtx() planctx.PlanContext { return se.planctx } // GetExprCtx returns the expression context of the session. -func (se *Session) GetExprCtx() exprctx.ExprContext { +func (se *session) GetExprCtx() exprctx.ExprContext { return se.exprCtx } // GetTableCtx returns the table.MutateContext -func (se *Session) GetTableCtx() tbctx.MutateContext { +func (se *session) GetTableCtx() tbctx.MutateContext { return se.tblctx } // SetValue saves a value associated with this context for key. -func (se *Session) SetValue(key fmt.Stringer, value any) { +func (se *session) SetValue(key fmt.Stringer, value any) { se.values[key] = value } // Value returns the value associated with this context for key. -func (se *Session) Value(key fmt.Stringer) any { +func (se *session) Value(key fmt.Stringer) any { return se.values[key] } // StmtAddDirtyTableOP implements the sessionctx.Context interface -func (*Session) StmtAddDirtyTableOP(_ int, _ int64, _ kv.Handle) {} +func (*session) StmtAddDirtyTableOP(_ int, _ int64, _ kv.Handle) {} // GetInfoSchema implements the sessionctx.Context interface. -func (*Session) GetInfoSchema() infoschema.MetaOnlyInfoSchema { +func (*session) GetInfoSchema() infoschema.MetaOnlyInfoSchema { return nil } // GetStmtStats implements the sessionctx.Context interface. -func (*Session) GetStmtStats() *stmtstats.StatementStats { +func (*session) GetStmtStats() *stmtstats.StatementStats { return nil } +// Session is used to provide context for lightning. +type Session struct { + sctx *session +} + +// NewSession creates a new Session. +func NewSession(options *encode.SessionOptions, logger log.Logger) *Session { + return &Session{ + sctx: newSession(options, logger), + } +} + +// GetExprCtx returns the expression context +func (s *Session) GetExprCtx() expression.BuildContext { + return s.sctx.GetExprCtx() +} + +// Txn returns the internal txn. +func (s *Session) Txn() kv.Transaction { + return &s.sctx.txn +} + +// GetTableCtx returns the table MutateContext. +func (s *Session) GetTableCtx() tbctx.MutateContext { + return s.sctx.tblctx +} + +// GetPlanCtx returns the context for planner. +func (s *Session) GetPlanCtx() planctx.PlanContext { + return s.sctx.planctx +} + +// TakeKvPairs returns the current Pairs and resets the buffer. +func (s *Session) TakeKvPairs() *Pairs { + memBuf := &s.sctx.txn.MemBuf + pairs := memBuf.kvPairs + if pairs.BytesBuf != nil { + pairs.MemBuf = memBuf + } + memBuf.kvPairs = &Pairs{Pairs: make([]common.KvPair, 0, len(pairs.Pairs))} + memBuf.size = 0 + return pairs +} + +// SetTxnCtxNotNil sets the internal SessionVars.TxnCtx to a non-nil value to avoid some panics. +// TODO: remove it after code refactoring. +func (s *Session) SetTxnCtxNotNil() func() { + s.sctx.Vars.TxnCtx = new(variable.TransactionContext) + return func() { + s.sctx.Vars.TxnCtx = nil + } +} + +// SetUserVarVal sets the value of a user variable. +func (s *Session) SetUserVarVal(name string, dt types.Datum) { + s.sctx.Vars.SetUserVarVal(name, dt) +} + +// UnsetUserVar unsets a user variable. +func (s *Session) UnsetUserVar(varName string) { + s.sctx.Vars.UnsetUserVar(varName) +} + +// GetColumnSize returns the size of each column. +func (s *Session) GetColumnSize(tblID int64) (ret map[int64]int64) { + vars := s.sctx.Vars + vars.TxnCtxMu.Lock() + defer vars.TxnCtxMu.Unlock() + if txnCtx := s.sctx.Vars.TxnCtx; txnCtx != nil { + return maps.Clone(txnCtx.TableDeltaMap[tblID].ColSize) + } + return ret +} + // Close implements the sessionctx.Context interface -func (se *Session) Close() { - memBuf := &se.txn.MemBuf +func (s *Session) Close() { + memBuf := &s.sctx.txn.MemBuf if memBuf.buf != nil { memBuf.buf.destroy() memBuf.buf = nil diff --git a/pkg/lightning/backend/kv/session_test.go b/pkg/lightning/backend/kv/session_test.go deleted file mode 100644 index f432a5c98bed3..0000000000000 --- a/pkg/lightning/backend/kv/session_test.go +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright 2019 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package kv_test - -import ( - "testing" - - "github.com/pingcap/tidb/pkg/lightning/backend/encode" - "github.com/pingcap/tidb/pkg/lightning/backend/kv" - "github.com/pingcap/tidb/pkg/lightning/log" - "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/stretchr/testify/require" -) - -func TestSession(t *testing.T) { - session := kv.NewSessionCtx(&encode.SessionOptions{SQLMode: mysql.ModeNone, Timestamp: 1234567890}, log.L()) - _, err := session.Txn(true) - require.NoError(t, err) -} diff --git a/pkg/lightning/backend/kv/sql2kv.go b/pkg/lightning/backend/kv/sql2kv.go index 1413bdc827949..1321266e1174e 100644 --- a/pkg/lightning/backend/kv/sql2kv.go +++ b/pkg/lightning/backend/kv/sql2kv.go @@ -32,8 +32,6 @@ import ( "github.com/pingcap/tidb/pkg/meta/autoid" "github.com/pingcap/tidb/pkg/parser/model" "github.com/pingcap/tidb/pkg/parser/mysql" //nolint: goimports - "github.com/pingcap/tidb/pkg/sessionctx" - "github.com/pingcap/tidb/pkg/sessionctx/variable" "github.com/pingcap/tidb/pkg/table" "github.com/pingcap/tidb/pkg/tablecodec" "github.com/pingcap/tidb/pkg/types" @@ -45,7 +43,7 @@ type tableKVEncoder struct { } // GetSession4test is only used for test. -func GetSession4test(encoder encode.Encoder) sessionctx.Context { +func GetSession4test(encoder encode.Encoder) *Session { return encoder.(*tableKVEncoder).SessionCtx } @@ -84,10 +82,9 @@ func CollectGeneratedColumns(se *Session, meta *model.TableInfo, cols []*table.C } // the expression rewriter requires a non-nil TxnCtx. - se.Vars.TxnCtx = new(variable.TransactionContext) - defer func() { - se.Vars.TxnCtx = nil - }() + // TODO: remove it after code refactoring. + deferFn := se.SetTxnCtxNotNil() + defer deferFn() // not using TableInfo2SchemaAndNames to avoid parsing all virtual generated columns again. exprColumns := make([]*expression.Column, 0, len(cols)) @@ -250,7 +247,7 @@ func (kvcodec *tableKVEncoder) Encode(row []types.Datum, rowValue := rowID j := columnPermutation[len(kvcodec.Columns)] if j >= 0 && j < len(row) { - value, err = table.CastValue(kvcodec.SessionCtx, row[j], + value, err = table.CastColumnValue(kvcodec.SessionCtx.GetExprCtx(), row[j], ExtraHandleColumnInfo, false, false) rowValue = value.GetInt64() } else { diff --git a/pkg/lightning/backend/local/duplicate_test.go b/pkg/lightning/backend/local/duplicate_test.go index 8a42e7f6f69c4..fc17ec9293308 100644 --- a/pkg/lightning/backend/local/duplicate_test.go +++ b/pkg/lightning/backend/local/duplicate_test.go @@ -94,7 +94,7 @@ func buildTableForTestConvertToErrFoundConflictRecords(t *testing.T, node []ast. Logger: log.L(), }) require.NoError(t, err) - encoder.SessionCtx.GetSessionVars().RowEncoder.Enable = true + encoder.SessionCtx.GetTableCtx().GetRowEncodingConfig().RowEncoder.Enable = true data1 := []types.Datum{ types.NewIntDatum(1), diff --git a/pkg/lightning/backend/tidb/BUILD.bazel b/pkg/lightning/backend/tidb/BUILD.bazel index 7fc382b7b46ee..d65342fd2a725 100644 --- a/pkg/lightning/backend/tidb/BUILD.bazel +++ b/pkg/lightning/backend/tidb/BUILD.bazel @@ -18,7 +18,6 @@ go_library( "//pkg/lightning/verification", "//pkg/parser/model", "//pkg/parser/mysql", - "//pkg/sessionctx", "//pkg/table", "//pkg/types", "//pkg/util/dbutil", diff --git a/pkg/lightning/backend/tidb/tidb.go b/pkg/lightning/backend/tidb/tidb.go index 186e17d9d6d06..2c70cc1df82b9 100644 --- a/pkg/lightning/backend/tidb/tidb.go +++ b/pkg/lightning/backend/tidb/tidb.go @@ -39,7 +39,6 @@ import ( "github.com/pingcap/tidb/pkg/lightning/verification" "github.com/pingcap/tidb/pkg/parser/model" "github.com/pingcap/tidb/pkg/parser/mysql" - "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/table" "github.com/pingcap/tidb/pkg/types" "github.com/pingcap/tidb/pkg/util/dbutil" @@ -83,7 +82,6 @@ func (rows tidbRows) MarshalLogArray(encoder zapcore.ArrayEncoder) error { type tidbEncoder struct { mode mysql.SQLMode tbl table.Table - se sessionctx.Context // the index of table columns for each data field. // index == len(table.columns) means this field is `_tidb_rowid` columnIdx []int @@ -105,17 +103,10 @@ func NewEncodingBuilder() encode.EncodingBuilder { // NewEncoder creates a KV encoder. // It implements the `backend.EncodingBuilder` interface. -func (*encodingBuilder) NewEncoder(ctx context.Context, config *encode.EncodingConfig) (encode.Encoder, error) { - se := kv.NewSessionCtx(&config.SessionOptions, log.FromContext(ctx)) - if config.SQLMode.HasStrictMode() { - se.GetSessionVars().SkipUTF8Check = false - se.GetSessionVars().SkipASCIICheck = false - } - +func (*encodingBuilder) NewEncoder(_ context.Context, config *encode.EncodingConfig) (encode.Encoder, error) { return &tidbEncoder{ mode: config.SQLMode, tbl: config.Table, - se: se, path: config.Path, logger: config.Logger, }, nil diff --git a/pkg/lightning/errormanager/errormanager.go b/pkg/lightning/errormanager/errormanager.go index 21a77131d8f85..7c83d7e486a6a 100644 --- a/pkg/lightning/errormanager/errormanager.go +++ b/pkg/lightning/errormanager/errormanager.go @@ -611,7 +611,7 @@ func (em *ErrorManager) ReplaceConflictKeys( if err != nil { return errors.Trace(err) } - decodedData, _, err := tables.DecodeRawRowData(encoder.SessionCtx, + decodedData, _, err := tables.DecodeRawRowData(encoder.SessionCtx.GetExprCtx(), tbl.Meta(), overwrittenHandle, tbl.Cols(), overwritten) if err != nil { return errors.Trace(err) @@ -791,7 +791,7 @@ func (em *ErrorManager) ReplaceConflictKeys( if err != nil { return errors.Trace(err) } - decodedData, _, err := tables.DecodeRawRowData(encoder.SessionCtx, + decodedData, _, err := tables.DecodeRawRowData(encoder.SessionCtx.GetExprCtx(), tbl.Meta(), handle, tbl.Cols(), latestValue) if err != nil { return errors.Trace(err) @@ -820,7 +820,7 @@ func (em *ErrorManager) ReplaceConflictKeys( if err != nil { return errors.Trace(err) } - decodedData, _, err := tables.DecodeRawRowData(encoder.SessionCtx, + decodedData, _, err := tables.DecodeRawRowData(encoder.SessionCtx.GetExprCtx(), tbl.Meta(), handle, tbl.Cols(), rawValue) if err != nil { return errors.Trace(err) diff --git a/pkg/lightning/errormanager/errormanager_test.go b/pkg/lightning/errormanager/errormanager_test.go index e611f24d3b1da..ed353b3629c55 100644 --- a/pkg/lightning/errormanager/errormanager_test.go +++ b/pkg/lightning/errormanager/errormanager_test.go @@ -231,7 +231,7 @@ func TestReplaceConflictOneKey(t *testing.T) { Logger: log.L(), }) require.NoError(t, err) - encoder.SessionCtx.GetSessionVars().RowEncoder.Enable = true + encoder.SessionCtx.GetTableCtx().GetRowEncodingConfig().RowEncoder.Enable = true data1 := []types.Datum{ types.NewIntDatum(1), @@ -420,7 +420,7 @@ func TestReplaceConflictOneUniqueKey(t *testing.T) { Logger: log.L(), }) require.NoError(t, err) - encoder.SessionCtx.GetSessionVars().RowEncoder.Enable = true + encoder.SessionCtx.GetTableCtx().GetRowEncodingConfig().RowEncoder.Enable = true data1 := []types.Datum{ types.NewIntDatum(1), diff --git a/pkg/lightning/errormanager/resolveconflict_test.go b/pkg/lightning/errormanager/resolveconflict_test.go index 117420e84a90c..a76c7418aa54c 100644 --- a/pkg/lightning/errormanager/resolveconflict_test.go +++ b/pkg/lightning/errormanager/resolveconflict_test.go @@ -67,7 +67,7 @@ func TestReplaceConflictMultipleKeysNonclusteredPk(t *testing.T) { Logger: log.L(), }) require.NoError(t, err) - encoder.SessionCtx.GetSessionVars().RowEncoder.Enable = true + encoder.SessionCtx.GetTableCtx().GetRowEncodingConfig().RowEncoder.Enable = true data1 := []types.Datum{ types.NewIntDatum(1), @@ -288,7 +288,7 @@ func TestReplaceConflictOneKeyNonclusteredPk(t *testing.T) { Logger: log.L(), }) require.NoError(t, err) - encoder.SessionCtx.GetSessionVars().RowEncoder.Enable = true + encoder.SessionCtx.GetTableCtx().GetRowEncodingConfig().RowEncoder.Enable = true data1 := []types.Datum{ types.NewIntDatum(1), @@ -456,7 +456,7 @@ func TestReplaceConflictOneUniqueKeyNonclusteredPk(t *testing.T) { Logger: log.L(), }) require.NoError(t, err) - encoder.SessionCtx.GetSessionVars().RowEncoder.Enable = true + encoder.SessionCtx.GetTableCtx().GetRowEncodingConfig().RowEncoder.Enable = true data1 := []types.Datum{ types.NewIntDatum(1), @@ -662,7 +662,7 @@ func TestReplaceConflictOneUniqueKeyNonclusteredVarcharPk(t *testing.T) { Logger: log.L(), }) require.NoError(t, err) - encoder.SessionCtx.GetSessionVars().RowEncoder.Enable = true + encoder.SessionCtx.GetTableCtx().GetRowEncodingConfig().RowEncoder.Enable = true data1 := []types.Datum{ types.NewStringDatum("x"), diff --git a/pkg/table/tables/index_test.go b/pkg/table/tables/index_test.go index 32ff3ed2e7411..7819f6805dfa6 100644 --- a/pkg/table/tables/index_test.go +++ b/pkg/table/tables/index_test.go @@ -190,7 +190,7 @@ func TestGenIndexValueFromIndex(t *testing.T) { SessionOptions: sessionOpts, }) require.NoError(t, err) - encoder.SessionCtx.GetSessionVars().RowEncoder.Enable = true + encoder.SessionCtx.GetTableCtx().GetRowEncodingConfig().RowEncoder.Enable = true data1 := []types.Datum{ types.NewIntDatum(1), diff --git a/pkg/table/tables/tables.go b/pkg/table/tables/tables.go index d6eaa5daf6db8..623e7ef729b9f 100644 --- a/pkg/table/tables/tables.go +++ b/pkg/table/tables/tables.go @@ -1052,7 +1052,7 @@ func RowWithCols(t table.Table, ctx sessionctx.Context, h kv.Handle, cols []*tab if err != nil { return nil, err } - v, _, err := DecodeRawRowData(ctx, t.Meta(), h, cols, value) + v, _, err := DecodeRawRowData(ctx.GetExprCtx(), t.Meta(), h, cols, value) if err != nil { return nil, err } @@ -1072,7 +1072,7 @@ func containFullColInHandle(meta *model.TableInfo, col *table.Column) (containFu } // DecodeRawRowData decodes raw row data into a datum slice and a (columnID:columnValue) map. -func DecodeRawRowData(ctx sessionctx.Context, meta *model.TableInfo, h kv.Handle, cols []*table.Column, +func DecodeRawRowData(ctx expression.BuildContext, meta *model.TableInfo, h kv.Handle, cols []*table.Column, value []byte) ([]types.Datum, map[int64]types.Datum, error) { v := make([]types.Datum, len(cols)) colTps := make(map[int64]*types.FieldType, len(cols)) @@ -1096,7 +1096,7 @@ func DecodeRawRowData(ctx sessionctx.Context, meta *model.TableInfo, h kv.Handle if err != nil { return nil, nil, err } - dt, err = tablecodec.Unflatten(dt, &col.FieldType, ctx.GetSessionVars().Location()) + dt, err = tablecodec.Unflatten(dt, &col.FieldType, ctx.GetEvalCtx().Location()) if err != nil { return nil, nil, err } @@ -1107,7 +1107,7 @@ func DecodeRawRowData(ctx sessionctx.Context, meta *model.TableInfo, h kv.Handle } colTps[col.ID] = &col.FieldType } - rowMap, err := tablecodec.DecodeRowToDatumMap(value, colTps, ctx.GetSessionVars().Location()) + rowMap, err := tablecodec.DecodeRowToDatumMap(value, colTps, ctx.GetEvalCtx().Location()) if err != nil { return nil, rowMap, err } @@ -1130,9 +1130,9 @@ func DecodeRawRowData(ctx sessionctx.Context, meta *model.TableInfo, h kv.Handle continue } if col.ChangeStateInfo != nil { - v[i], _, err = GetChangingColVal(ctx.GetExprCtx(), cols, col, rowMap, defaultVals) + v[i], _, err = GetChangingColVal(ctx, cols, col, rowMap, defaultVals) } else { - v[i], err = GetColDefaultValue(ctx.GetExprCtx(), col, defaultVals) + v[i], err = GetColDefaultValue(ctx, col, defaultVals) } if err != nil { return nil, rowMap, err