From 2b192e31492ad1e706448e1733db28c37bb8a9d9 Mon Sep 17 00:00:00 2001 From: crazycs Date: Thu, 15 Jul 2021 16:13:36 +0800 Subject: [PATCH] topsql: refine collect information for DDL and internal SQL (#26047) --- ddl/backfilling.go | 11 ++- ddl/column.go | 14 +-- ddl/ddl.go | 2 +- ddl/ddl_api.go | 127 +++++++++++++------------- ddl/ddl_worker.go | 40 ++++++-- ddl/delete_range.go | 66 ++++++------- ddl/failtest/fail_db_test.go | 2 +- ddl/mock.go | 4 +- ddl/partition.go | 4 +- ddl/reorg.go | 4 +- ddl/restart_test.go | 1 + ddl/table.go | 2 +- ddl/util/util.go | 20 ++-- executor/adapter.go | 5 +- executor/ddl.go | 6 +- executor/ddl_test.go | 16 ++-- executor/executor.go | 2 +- executor/grant.go | 2 +- executor/infoschema_reader.go | 44 ++++----- executor/prepared.go | 2 +- executor/revoke.go | 2 +- executor/show.go | 28 +++--- executor/simple.go | 74 +++++++-------- go.mod | 2 +- go.sum | 4 +- server/conn_stmt.go | 4 +- session/session.go | 15 ++- statistics/handle/handle.go | 5 - util/topsql/reporter/client.go | 4 +- util/topsql/reporter/mock/server.go | 17 +++- util/topsql/reporter/reporter.go | 17 +++- util/topsql/reporter/reporter_test.go | 43 ++++++++- util/topsql/topsql.go | 8 +- util/topsql/topsql_test.go | 12 +-- util/topsql/tracecpu/mock/mock.go | 2 +- 35 files changed, 351 insertions(+), 260 deletions(-) diff --git a/ddl/backfilling.go b/ddl/backfilling.go index 0a9542ff4f39a..a75eb7ce508fb 100644 --- a/ddl/backfilling.go +++ b/ddl/backfilling.go @@ -286,7 +286,7 @@ func (w *backfillWorker) handleBackfillTask(d *ddlCtx, task *reorgBackfillTask, return result } -func (w *backfillWorker) run(d *ddlCtx, bf backfiller) { +func (w *backfillWorker) run(d *ddlCtx, bf backfiller, job *model.Job) { logutil.BgLogger().Info("[ddl] backfill worker start", zap.Int("workerID", w.id)) defer func() { w.resultCh <- &backfillResult{err: errReorgPanic} @@ -297,6 +297,7 @@ func (w *backfillWorker) run(d *ddlCtx, bf backfiller) { if !more { break } + w.ddlWorker.setDDLLabelForTopSQL(job) logutil.BgLogger().Debug("[ddl] backfill worker got task", zap.Int("workerID", w.id), zap.String("task", task.String())) failpoint.Inject("mockBackfillRunErr", func() { @@ -497,7 +498,7 @@ func loadDDLReorgVars(w *worker) error { return errors.Trace(err) } defer w.sessPool.put(ctx) - return ddlutil.LoadDDLReorgVars(ctx) + return ddlutil.LoadDDLReorgVars(w.ddlJobCtx, ctx) } func makeupDecodeColMap(sessCtx sessionctx.Context, t table.Table) (map[int64]decoder.Column, error) { @@ -599,17 +600,17 @@ func (w *worker) writePhysicalTableRecord(t table.PhysicalTable, bfWorkerType ba idxWorker := newAddIndexWorker(sessCtx, w, i, t, indexInfo, decodeColMap, reorgInfo.ReorgMeta.SQLMode) idxWorker.priority = job.Priority backfillWorkers = append(backfillWorkers, idxWorker.backfillWorker) - go idxWorker.backfillWorker.run(reorgInfo.d, idxWorker) + go idxWorker.backfillWorker.run(reorgInfo.d, idxWorker, job) case typeUpdateColumnWorker: updateWorker := newUpdateColumnWorker(sessCtx, w, i, t, oldColInfo, colInfo, decodeColMap, reorgInfo.ReorgMeta.SQLMode) updateWorker.priority = job.Priority backfillWorkers = append(backfillWorkers, updateWorker.backfillWorker) - go updateWorker.backfillWorker.run(reorgInfo.d, updateWorker) + go updateWorker.backfillWorker.run(reorgInfo.d, updateWorker, job) case typeCleanUpIndexWorker: idxWorker := newCleanUpIndexWorker(sessCtx, w, i, t, decodeColMap, reorgInfo.ReorgMeta.SQLMode) idxWorker.priority = job.Priority backfillWorkers = append(backfillWorkers, idxWorker.backfillWorker) - go idxWorker.backfillWorker.run(reorgInfo.d, idxWorker) + go idxWorker.backfillWorker.run(reorgInfo.d, idxWorker, job) default: return errors.New("unknow backfill type") } diff --git a/ddl/column.go b/ddl/column.go index 41d07af7e1b07..d50f2618d6f99 100644 --- a/ddl/column.go +++ b/ddl/column.go @@ -1646,7 +1646,7 @@ func applyNewAutoRandomBits(d *ddlCtx, m *meta.Meta, dbInfo *model.DBInfo, // checkForNullValue ensure there are no null values of the column of this table. // `isDataTruncated` indicates whether the new field and the old field type are the same, in order to be compatible with mysql. -func checkForNullValue(ctx sessionctx.Context, isDataTruncated bool, schema, table, newCol model.CIStr, oldCols ...*model.ColumnInfo) error { +func checkForNullValue(ctx context.Context, sctx sessionctx.Context, isDataTruncated bool, schema, table, newCol model.CIStr, oldCols ...*model.ColumnInfo) error { var buf strings.Builder buf.WriteString("select 1 from %n.%n where ") paramsList := make([]interface{}, 0, 2+len(oldCols)) @@ -1661,11 +1661,11 @@ func checkForNullValue(ctx sessionctx.Context, isDataTruncated bool, schema, tab } } buf.WriteString(" limit 1") - stmt, err := ctx.(sqlexec.RestrictedSQLExecutor).ParseWithParams(context.Background(), buf.String(), paramsList...) + stmt, err := sctx.(sqlexec.RestrictedSQLExecutor).ParseWithParams(ctx, buf.String(), paramsList...) if err != nil { return errors.Trace(err) } - rows, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedStmt(context.Background(), stmt) + rows, _, err := sctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedStmt(ctx, stmt) if err != nil { return errors.Trace(err) } @@ -1785,12 +1785,12 @@ func rollbackModifyColumnJob(t *meta.Meta, tblInfo *model.TableInfo, job *model. func modifyColsFromNull2NotNull(w *worker, dbInfo *model.DBInfo, tblInfo *model.TableInfo, cols []*model.ColumnInfo, newColName model.CIStr, isDataTruncated bool) error { // Get sessionctx from context resource pool. - var ctx sessionctx.Context - ctx, err := w.sessPool.get() + var sctx sessionctx.Context + sctx, err := w.sessPool.get() if err != nil { return errors.Trace(err) } - defer w.sessPool.put(ctx) + defer w.sessPool.put(sctx) skipCheck := false failpoint.Inject("skipMockContextDoExec", func(val failpoint.Value) { @@ -1800,7 +1800,7 @@ func modifyColsFromNull2NotNull(w *worker, dbInfo *model.DBInfo, tblInfo *model. }) if !skipCheck { // If there is a null value inserted, it cannot be modified and needs to be rollback. - err = checkForNullValue(ctx, isDataTruncated, dbInfo.Name, tblInfo.Name, newColName, cols...) + err = checkForNullValue(w.ddlJobCtx, sctx, isDataTruncated, dbInfo.Name, tblInfo.Name, newColName, cols...) if err != nil { return errors.Trace(err) } diff --git a/ddl/ddl.go b/ddl/ddl.go index 9eb05b86741ed..043f58ca99b98 100644 --- a/ddl/ddl.go +++ b/ddl/ddl.go @@ -102,7 +102,7 @@ type DDL interface { CreateIndex(ctx sessionctx.Context, tableIdent ast.Ident, keyType ast.IndexKeyType, indexName model.CIStr, columnNames []*ast.IndexPartSpecification, indexOption *ast.IndexOption, ifNotExists bool) error DropIndex(ctx sessionctx.Context, tableIdent ast.Ident, indexName model.CIStr, ifExists bool) error - AlterTable(ctx sessionctx.Context, tableIdent ast.Ident, spec []*ast.AlterTableSpec) error + AlterTable(ctx context.Context, sctx sessionctx.Context, tableIdent ast.Ident, spec []*ast.AlterTableSpec) error TruncateTable(ctx sessionctx.Context, tableIdent ast.Ident) error RenameTable(ctx sessionctx.Context, oldTableIdent, newTableIdent ast.Ident, isAlterTable bool) error RenameTables(ctx sessionctx.Context, oldTableIdent, newTableIdent []ast.Ident, isAlterTable bool) error diff --git a/ddl/ddl_api.go b/ddl/ddl_api.go index 91fb1a215e337..b964aab7bfde6 100644 --- a/ddl/ddl_api.go +++ b/ddl/ddl_api.go @@ -18,6 +18,7 @@ package ddl import ( + "context" "fmt" "math" "strconv" @@ -2388,8 +2389,8 @@ func isSameTypeMultiSpecs(specs []*ast.AlterTableSpec) bool { return true } -func (d *ddl) AlterTable(ctx sessionctx.Context, ident ast.Ident, specs []*ast.AlterTableSpec) (err error) { - validSpecs, err := resolveAlterTableSpec(ctx, specs) +func (d *ddl) AlterTable(ctx context.Context, sctx sessionctx.Context, ident ast.Ident, specs []*ast.AlterTableSpec) (err error) { + validSpecs, err := resolveAlterTableSpec(sctx, specs) if err != nil { return errors.Trace(err) } @@ -2400,15 +2401,15 @@ func (d *ddl) AlterTable(ctx sessionctx.Context, ident ast.Ident, specs []*ast.A } if len(validSpecs) > 1 { - if !ctx.GetSessionVars().EnableChangeMultiSchema { + if !sctx.GetSessionVars().EnableChangeMultiSchema { return errRunMultiSchemaChanges } if isSameTypeMultiSpecs(validSpecs) { switch validSpecs[0].Tp { case ast.AlterTableAddColumns: - err = d.AddColumns(ctx, ident, validSpecs) + err = d.AddColumns(sctx, ident, validSpecs) case ast.AlterTableDropColumn: - err = d.DropColumns(ctx, ident, validSpecs) + err = d.DropColumns(sctx, ident, validSpecs) default: return errRunMultiSchemaChanges } @@ -2425,14 +2426,14 @@ func (d *ddl) AlterTable(ctx sessionctx.Context, ident ast.Ident, specs []*ast.A switch spec.Tp { case ast.AlterTableAddColumns: if len(spec.NewColumns) != 1 { - err = d.AddColumns(ctx, ident, []*ast.AlterTableSpec{spec}) + err = d.AddColumns(sctx, ident, []*ast.AlterTableSpec{spec}) } else { - err = d.AddColumn(ctx, ident, spec) + err = d.AddColumn(sctx, ident, spec) } case ast.AlterTableAddPartitions: - err = d.AddTablePartitions(ctx, ident, spec) + err = d.AddTablePartitions(sctx, ident, spec) case ast.AlterTableCoalescePartitions: - err = d.CoalescePartitions(ctx, ident, spec) + err = d.CoalescePartitions(sctx, ident, spec) case ast.AlterTableReorganizePartition: err = errors.Trace(errUnsupportedReorganizePartition) case ast.AlterTableCheckPartitions: @@ -2446,24 +2447,24 @@ func (d *ddl) AlterTable(ctx sessionctx.Context, ident ast.Ident, specs []*ast.A case ast.AlterTableRepairPartition: err = errors.Trace(errUnsupportedRepairPartition) case ast.AlterTableDropColumn: - err = d.DropColumn(ctx, ident, spec) + err = d.DropColumn(sctx, ident, spec) case ast.AlterTableDropIndex: - err = d.DropIndex(ctx, ident, model.NewCIStr(spec.Name), spec.IfExists) + err = d.DropIndex(sctx, ident, model.NewCIStr(spec.Name), spec.IfExists) case ast.AlterTableDropPrimaryKey: - err = d.DropIndex(ctx, ident, model.NewCIStr(mysql.PrimaryKeyName), spec.IfExists) + err = d.DropIndex(sctx, ident, model.NewCIStr(mysql.PrimaryKeyName), spec.IfExists) case ast.AlterTableRenameIndex: - err = d.RenameIndex(ctx, ident, spec) + err = d.RenameIndex(sctx, ident, spec) case ast.AlterTableDropPartition: - err = d.DropTablePartition(ctx, ident, spec) + err = d.DropTablePartition(sctx, ident, spec) case ast.AlterTableTruncatePartition: - err = d.TruncateTablePartition(ctx, ident, spec) + err = d.TruncateTablePartition(sctx, ident, spec) case ast.AlterTableWriteable: if !config.TableLockEnabled() { return nil } tName := &ast.TableName{Schema: ident.Schema, Name: ident.Name} if spec.Writeable { - err = d.CleanupTableLock(ctx, []*ast.TableName{tName}) + err = d.CleanupTableLock(sctx, []*ast.TableName{tName}) } else { lockStmt := &ast.LockTablesStmt{ TableLocks: []ast.TableLock{ @@ -2473,50 +2474,50 @@ func (d *ddl) AlterTable(ctx sessionctx.Context, ident ast.Ident, specs []*ast.A }, }, } - err = d.LockTables(ctx, lockStmt) + err = d.LockTables(sctx, lockStmt) } case ast.AlterTableExchangePartition: - err = d.ExchangeTablePartition(ctx, ident, spec) + err = d.ExchangeTablePartition(sctx, ident, spec) case ast.AlterTableAddConstraint: constr := spec.Constraint switch spec.Constraint.Tp { case ast.ConstraintKey, ast.ConstraintIndex: - err = d.CreateIndex(ctx, ident, ast.IndexKeyTypeNone, model.NewCIStr(constr.Name), + err = d.CreateIndex(sctx, ident, ast.IndexKeyTypeNone, model.NewCIStr(constr.Name), spec.Constraint.Keys, constr.Option, constr.IfNotExists) case ast.ConstraintUniq, ast.ConstraintUniqIndex, ast.ConstraintUniqKey: - err = d.CreateIndex(ctx, ident, ast.IndexKeyTypeUnique, model.NewCIStr(constr.Name), + err = d.CreateIndex(sctx, ident, ast.IndexKeyTypeUnique, model.NewCIStr(constr.Name), spec.Constraint.Keys, constr.Option, false) // IfNotExists should be not applied case ast.ConstraintForeignKey: // NOTE: we do not handle `symbol` and `index_name` well in the parser and we do not check ForeignKey already exists, // so we just also ignore the `if not exists` check. - err = d.CreateForeignKey(ctx, ident, model.NewCIStr(constr.Name), spec.Constraint.Keys, spec.Constraint.Refer) + err = d.CreateForeignKey(sctx, ident, model.NewCIStr(constr.Name), spec.Constraint.Keys, spec.Constraint.Refer) case ast.ConstraintPrimaryKey: - err = d.CreatePrimaryKey(ctx, ident, model.NewCIStr(constr.Name), spec.Constraint.Keys, constr.Option) + err = d.CreatePrimaryKey(sctx, ident, model.NewCIStr(constr.Name), spec.Constraint.Keys, constr.Option) case ast.ConstraintFulltext: - ctx.GetSessionVars().StmtCtx.AppendWarning(ErrTableCantHandleFt) + sctx.GetSessionVars().StmtCtx.AppendWarning(ErrTableCantHandleFt) case ast.ConstraintCheck: - ctx.GetSessionVars().StmtCtx.AppendWarning(ErrUnsupportedConstraintCheck.GenWithStackByArgs("ADD CONSTRAINT CHECK")) + sctx.GetSessionVars().StmtCtx.AppendWarning(ErrUnsupportedConstraintCheck.GenWithStackByArgs("ADD CONSTRAINT CHECK")) default: // Nothing to do now. } case ast.AlterTableDropForeignKey: // NOTE: we do not check `if not exists` and `if exists` for ForeignKey now. - err = d.DropForeignKey(ctx, ident, model.NewCIStr(spec.Name)) + err = d.DropForeignKey(sctx, ident, model.NewCIStr(spec.Name)) case ast.AlterTableModifyColumn: - err = d.ModifyColumn(ctx, ident, spec) + err = d.ModifyColumn(ctx, sctx, ident, spec) case ast.AlterTableChangeColumn: - err = d.ChangeColumn(ctx, ident, spec) + err = d.ChangeColumn(ctx, sctx, ident, spec) case ast.AlterTableRenameColumn: - err = d.RenameColumn(ctx, ident, spec) + err = d.RenameColumn(sctx, ident, spec) case ast.AlterTableAlterColumn: - err = d.AlterColumn(ctx, ident, spec) + err = d.AlterColumn(sctx, ident, spec) case ast.AlterTableRenameTable: newIdent := ast.Ident{Schema: spec.NewTable.Schema, Name: spec.NewTable.Name} isAlterTable := true - err = d.RenameTable(ctx, ident, newIdent, isAlterTable) + err = d.RenameTable(sctx, ident, newIdent, isAlterTable) case ast.AlterTableAlterPartition: - if ctx.GetSessionVars().EnableAlterPlacement { - err = d.AlterTableAlterPartition(ctx, ident, spec) + if sctx.GetSessionVars().EnableAlterPlacement { + err = d.AlterTableAlterPartition(sctx, ident, spec) } else { err = errors.New("alter partition alter placement is experimental and it is switched off by tidb_enable_alter_placement") } @@ -2530,20 +2531,20 @@ func (d *ddl) AlterTable(ctx sessionctx.Context, ident ast.Ident, specs []*ast.A if opt.UintValue > shardRowIDBitsMax { opt.UintValue = shardRowIDBitsMax } - err = d.ShardRowID(ctx, ident, opt.UintValue) + err = d.ShardRowID(sctx, ident, opt.UintValue) case ast.TableOptionAutoIncrement: - err = d.RebaseAutoID(ctx, ident, int64(opt.UintValue), autoid.RowIDAllocType, opt.BoolValue) + err = d.RebaseAutoID(sctx, ident, int64(opt.UintValue), autoid.RowIDAllocType, opt.BoolValue) case ast.TableOptionAutoIdCache: if opt.UintValue > uint64(math.MaxInt64) { // TODO: Refine this error. return errors.New("table option auto_id_cache overflows int64") } - err = d.AlterTableAutoIDCache(ctx, ident, int64(opt.UintValue)) + err = d.AlterTableAutoIDCache(sctx, ident, int64(opt.UintValue)) case ast.TableOptionAutoRandomBase: - err = d.RebaseAutoID(ctx, ident, int64(opt.UintValue), autoid.AutoRandomType, opt.BoolValue) + err = d.RebaseAutoID(sctx, ident, int64(opt.UintValue), autoid.AutoRandomType, opt.BoolValue) case ast.TableOptionComment: spec.Comment = opt.StrValue - err = d.AlterTableComment(ctx, ident, spec) + err = d.AlterTableComment(sctx, ident, spec) case ast.TableOptionCharset, ast.TableOptionCollate: // getCharsetAndCollateInTableOption will get the last charset and collate in the options, // so it should be handled only once. @@ -2556,7 +2557,7 @@ func (d *ddl) AlterTable(ctx sessionctx.Context, ident ast.Ident, specs []*ast.A return err } needsOverwriteCols := needToOverwriteColCharset(spec.Options) - err = d.AlterTableCharsetAndCollate(ctx, ident, toCharset, toCollate, needsOverwriteCols) + err = d.AlterTableCharsetAndCollate(sctx, ident, toCharset, toCollate, needsOverwriteCols) handledCharsetOrCollate = true default: err = errUnsupportedAlterTableOption @@ -2567,23 +2568,23 @@ func (d *ddl) AlterTable(ctx sessionctx.Context, ident ast.Ident, specs []*ast.A } } case ast.AlterTableSetTiFlashReplica: - err = d.AlterTableSetTiFlashReplica(ctx, ident, spec.TiFlashReplica) + err = d.AlterTableSetTiFlashReplica(sctx, ident, spec.TiFlashReplica) case ast.AlterTableOrderByColumns: - err = d.OrderByColumns(ctx, ident) + err = d.OrderByColumns(sctx, ident) case ast.AlterTableIndexInvisible: - err = d.AlterIndexVisibility(ctx, ident, spec.IndexName, spec.Visibility) + err = d.AlterIndexVisibility(sctx, ident, spec.IndexName, spec.Visibility) case ast.AlterTableAlterCheck: - ctx.GetSessionVars().StmtCtx.AppendWarning(ErrUnsupportedConstraintCheck.GenWithStackByArgs("ALTER CHECK")) + sctx.GetSessionVars().StmtCtx.AppendWarning(ErrUnsupportedConstraintCheck.GenWithStackByArgs("ALTER CHECK")) case ast.AlterTableDropCheck: - ctx.GetSessionVars().StmtCtx.AppendWarning(ErrUnsupportedConstraintCheck.GenWithStackByArgs("DROP CHECK")) + sctx.GetSessionVars().StmtCtx.AppendWarning(ErrUnsupportedConstraintCheck.GenWithStackByArgs("DROP CHECK")) case ast.AlterTableWithValidation: - ctx.GetSessionVars().StmtCtx.AppendWarning(errUnsupportedAlterTableWithValidation) + sctx.GetSessionVars().StmtCtx.AppendWarning(errUnsupportedAlterTableWithValidation) case ast.AlterTableWithoutValidation: - ctx.GetSessionVars().StmtCtx.AppendWarning(errUnsupportedAlterTableWithoutValidation) + sctx.GetSessionVars().StmtCtx.AppendWarning(errUnsupportedAlterTableWithoutValidation) case ast.AlterTableAddStatistics: - err = d.AlterTableAddStatistics(ctx, ident, spec.Statistics, spec.IfNotExists) + err = d.AlterTableAddStatistics(sctx, ident, spec.Statistics, spec.IfNotExists) case ast.AlterTableDropStatistics: - err = d.AlterTableDropStatistics(ctx, ident, spec.Statistics, spec.IfExists) + err = d.AlterTableDropStatistics(sctx, ident, spec.Statistics, spec.IfExists) default: // Nothing to do now. } @@ -3789,7 +3790,7 @@ func processAndCheckDefaultValueAndColumn(ctx sessionctx.Context, col *table.Col return nil } -func (d *ddl) getModifiableColumnJob(ctx sessionctx.Context, ident ast.Ident, originalColName model.CIStr, +func (d *ddl) getModifiableColumnJob(ctx context.Context, sctx sessionctx.Context, ident ast.Ident, originalColName model.CIStr, spec *ast.AlterTableSpec) (*model.Job, error) { specNewColumn := spec.NewColumns[0] is := d.infoCache.GetLatest() @@ -3894,11 +3895,11 @@ func (d *ddl) getModifiableColumnJob(ctx sessionctx.Context, ident ast.Ident, or // TODO: If user explicitly set NULL, we should throw error ErrPrimaryCantHaveNull. } - if err = processColumnOptions(ctx, newCol, specNewColumn.Options); err != nil { + if err = processColumnOptions(sctx, newCol, specNewColumn.Options); err != nil { return nil, errors.Trace(err) } - if err = checkModifyTypes(ctx, &col.FieldType, &newCol.FieldType, isColumnWithIndex(col.Name.L, t.Meta().Indices)); err != nil { + if err = checkModifyTypes(sctx, &col.FieldType, &newCol.FieldType, isColumnWithIndex(col.Name.L, t.Meta().Indices)); err != nil { if strings.Contains(err.Error(), "Unsupported modifying collation") { colErrMsg := "Unsupported modifying collation of column '%s' from '%s' to '%s' when index is defined on it." err = errUnsupportedModifyCollation.GenWithStack(colErrMsg, col.Name.L, col.Collate, newCol.Collate) @@ -3919,14 +3920,14 @@ func (d *ddl) getModifiableColumnJob(ctx sessionctx.Context, ident ast.Ident, or return nil, errUnsupportedModifyColumn.GenWithStackByArgs("can't set auto_increment") } // Disallow modifying column from auto_increment to not auto_increment if the session variable `AllowRemoveAutoInc` is false. - if !ctx.GetSessionVars().AllowRemoveAutoInc && mysql.HasAutoIncrementFlag(col.Flag) && !mysql.HasAutoIncrementFlag(newCol.Flag) { + if !sctx.GetSessionVars().AllowRemoveAutoInc && mysql.HasAutoIncrementFlag(col.Flag) && !mysql.HasAutoIncrementFlag(newCol.Flag) { return nil, errUnsupportedModifyColumn.GenWithStackByArgs("can't remove auto_increment without @@tidb_allow_remove_auto_inc enabled") } // We support modifying the type definitions of 'null' to 'not null' now. var modifyColumnTp byte if !mysql.HasNotNullFlag(col.Flag) && mysql.HasNotNullFlag(newCol.Flag) { - if err = checkForNullValue(ctx, true, ident.Schema, ident.Name, newCol.Name, col.ColumnInfo); err != nil { + if err = checkForNullValue(ctx, sctx, true, ident.Schema, ident.Name, newCol.Name, col.ColumnInfo); err != nil { return nil, errors.Trace(err) } // `modifyColumnTp` indicates that there is a type modification. @@ -3954,7 +3955,7 @@ func (d *ddl) getModifiableColumnJob(ctx sessionctx.Context, ident ast.Ident, or Type: model.ActionModifyColumn, BinlogInfo: &model.HistoryInfo{}, ReorgMeta: &model.DDLReorgMeta{ - SQLMode: ctx.GetSessionVars().SQLMode, + SQLMode: sctx.GetSessionVars().SQLMode, Warnings: make(map[errors.ErrorID]*terror.Error), WarningsCount: make(map[errors.ErrorID]int64), }, @@ -4118,7 +4119,7 @@ func checkAutoRandom(tableInfo *model.TableInfo, originCol *table.Column, specNe // ChangeColumn renames an existing column and modifies the column's definition, // currently we only support limited kind of changes // that do not need to change or check data on the table. -func (d *ddl) ChangeColumn(ctx sessionctx.Context, ident ast.Ident, spec *ast.AlterTableSpec) error { +func (d *ddl) ChangeColumn(ctx context.Context, sctx sessionctx.Context, ident ast.Ident, spec *ast.AlterTableSpec) error { specNewColumn := spec.NewColumns[0] if len(specNewColumn.Name.Schema.O) != 0 && ident.Schema.L != specNewColumn.Name.Schema.L { return ErrWrongDBName.GenWithStackByArgs(specNewColumn.Name.Schema.O) @@ -4133,19 +4134,19 @@ func (d *ddl) ChangeColumn(ctx sessionctx.Context, ident ast.Ident, spec *ast.Al return ErrWrongTableName.GenWithStackByArgs(spec.OldColumnName.Table.O) } - job, err := d.getModifiableColumnJob(ctx, ident, spec.OldColumnName.Name, spec) + job, err := d.getModifiableColumnJob(ctx, sctx, ident, spec.OldColumnName.Name, spec) if err != nil { if infoschema.ErrColumnNotExists.Equal(err) && spec.IfExists { - ctx.GetSessionVars().StmtCtx.AppendNote(infoschema.ErrColumnNotExists.GenWithStackByArgs(spec.OldColumnName.Name, ident.Name)) + sctx.GetSessionVars().StmtCtx.AppendNote(infoschema.ErrColumnNotExists.GenWithStackByArgs(spec.OldColumnName.Name, ident.Name)) return nil } return errors.Trace(err) } - err = d.doDDLJob(ctx, job) + err = d.doDDLJob(sctx, job) // column not exists, but if_exists flags is true, so we ignore this error. if infoschema.ErrColumnNotExists.Equal(err) && spec.IfExists { - ctx.GetSessionVars().StmtCtx.AppendNote(err) + sctx.GetSessionVars().StmtCtx.AppendNote(err) return nil } err = d.callHookOnChanged(err) @@ -4218,7 +4219,7 @@ func (d *ddl) RenameColumn(ctx sessionctx.Context, ident ast.Ident, spec *ast.Al // ModifyColumn does modification on an existing column, currently we only support limited kind of changes // that do not need to change or check data on the table. -func (d *ddl) ModifyColumn(ctx sessionctx.Context, ident ast.Ident, spec *ast.AlterTableSpec) error { +func (d *ddl) ModifyColumn(ctx context.Context, sctx sessionctx.Context, ident ast.Ident, spec *ast.AlterTableSpec) error { specNewColumn := spec.NewColumns[0] if len(specNewColumn.Name.Schema.O) != 0 && ident.Schema.L != specNewColumn.Name.Schema.L { return ErrWrongDBName.GenWithStackByArgs(specNewColumn.Name.Schema.O) @@ -4228,19 +4229,19 @@ func (d *ddl) ModifyColumn(ctx sessionctx.Context, ident ast.Ident, spec *ast.Al } originalColName := specNewColumn.Name.Name - job, err := d.getModifiableColumnJob(ctx, ident, originalColName, spec) + job, err := d.getModifiableColumnJob(ctx, sctx, ident, originalColName, spec) if err != nil { if infoschema.ErrColumnNotExists.Equal(err) && spec.IfExists { - ctx.GetSessionVars().StmtCtx.AppendNote(infoschema.ErrColumnNotExists.GenWithStackByArgs(originalColName, ident.Name)) + sctx.GetSessionVars().StmtCtx.AppendNote(infoschema.ErrColumnNotExists.GenWithStackByArgs(originalColName, ident.Name)) return nil } return errors.Trace(err) } - err = d.doDDLJob(ctx, job) + err = d.doDDLJob(sctx, job) // column not exists, but if_exists flags is true, so we ignore this error. if infoschema.ErrColumnNotExists.Equal(err) && spec.IfExists { - ctx.GetSessionVars().StmtCtx.AppendNote(err) + sctx.GetSessionVars().StmtCtx.AppendNote(err) return nil } err = d.callHookOnChanged(err) diff --git a/ddl/ddl_worker.go b/ddl/ddl_worker.go index ba6da4fdc7ede..5eaad49acc789 100644 --- a/ddl/ddl_worker.go +++ b/ddl/ddl_worker.go @@ -23,6 +23,7 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/failpoint" + "github.com/pingcap/parser" "github.com/pingcap/parser/model" "github.com/pingcap/parser/terror" pumpcli "github.com/pingcap/tidb-tools/tidb-binlog/pump_client" @@ -37,6 +38,7 @@ import ( "github.com/pingcap/tidb/util/admin" "github.com/pingcap/tidb/util/dbterror" "github.com/pingcap/tidb/util/logutil" + "github.com/pingcap/tidb/util/topsql" "go.etcd.io/etcd/clientv3" "go.uber.org/zap" ) @@ -87,6 +89,17 @@ type worker struct { reorgCtx *reorgCtx // reorgCtx is used for reorganization. delRangeManager delRangeManager logCtx context.Context + + ddlJobCache +} + +// ddlJobCache is a cache for each DDL job. +type ddlJobCache struct { + // below fields are cache for top sql + ddlJobCtx context.Context + cacheSQL string + cacheNormalizedSQL string + cacheDigest *parser.Digest } func newWorker(ctx context.Context, tp workerType, sessPool *sessionPool, delRangeMgr delRangeManager) *worker { @@ -95,6 +108,7 @@ func newWorker(ctx context.Context, tp workerType, sessPool *sessionPool, delRan tp: tp, ddlJobCh: make(chan struct{}, 1), ctx: ctx, + ddlJobCache: ddlJobCache{ddlJobCtx: context.Background()}, reorgCtx: &reorgCtx{notifyCancelReorgJob: 0}, sessPool: sessPool, delRangeManager: delRangeMgr, @@ -354,10 +368,10 @@ func (w *worker) updateDDLJob(t *meta.Meta, job *model.Job, meetErr bool) error return errors.Trace(t.UpdateDDLJob(0, job, updateRawArgs)) } -func (w *worker) deleteRange(job *model.Job) error { +func (w *worker) deleteRange(ctx context.Context, job *model.Job) error { var err error if job.Version <= currentVersion { - err = w.delRangeManager.addDelRangeJob(job) + err = w.delRangeManager.addDelRangeJob(ctx, job) } else { err = errInvalidDDLJobVersion.GenWithStackByArgs(job.Version, currentVersion) } @@ -380,14 +394,14 @@ func (w *worker) finishDDLJob(t *meta.Meta, job *model.Job) (err error) { } // After rolling back an AddIndex operation, we need to use delete-range to delete the half-done index data. - err = w.deleteRange(job) + err = w.deleteRange(w.ddlJobCtx, job) case model.ActionDropSchema, model.ActionDropTable, model.ActionTruncateTable, model.ActionDropIndex, model.ActionDropPrimaryKey, model.ActionDropTablePartition, model.ActionTruncateTablePartition, model.ActionDropColumn, model.ActionDropColumns, model.ActionModifyColumn: - err = w.deleteRange(job) + err = w.deleteRange(w.ddlJobCtx, job) } } if job.Type == model.ActionRecoverTable { - err = finishRecoverTable(w, t, job) + err = finishRecoverTable(w, job) } if err != nil { return errors.Trace(err) @@ -410,7 +424,7 @@ func (w *worker) finishDDLJob(t *meta.Meta, job *model.Job) (err error) { return errors.Trace(err) } -func finishRecoverTable(w *worker, t *meta.Meta, job *model.Job) error { +func finishRecoverTable(w *worker, job *model.Job) error { tbInfo := &model.TableInfo{} var autoIncID, autoRandID, dropJobID, recoverTableCheckFlag int64 var snapshotTS uint64 @@ -451,6 +465,19 @@ func newMetaWithQueueTp(txn kv.Transaction, tp workerType) *meta.Meta { return meta.NewMeta(txn) } +func (w *worker) setDDLLabelForTopSQL(job *model.Job) { + if !variable.TopSQLEnabled() || job == nil { + return + } + + if job.Query != w.cacheSQL { + w.cacheNormalizedSQL, w.cacheDigest = parser.NormalizeDigest(job.Query) + w.cacheSQL = job.Query + } + + w.ddlJobCtx = topsql.AttachSQLInfo(context.Background(), w.cacheNormalizedSQL, w.cacheDigest, "", nil, false) +} + // handleDDLJobQueue handles DDL jobs in DDL Job queue. func (w *worker) handleDDLJobQueue(d *ddlCtx) error { once := true @@ -479,6 +506,7 @@ func (w *worker) handleDDLJobQueue(d *ddlCtx) error { if job == nil || err != nil { return errors.Trace(err) } + w.setDDLLabelForTopSQL(job) if isDone, err1 := isDependencyJobDone(t, job); err1 != nil || !isDone { return errors.Trace(err1) } diff --git a/ddl/delete_range.go b/ddl/delete_range.go index 1aeb5ab0354da..7e350d2cd5203 100644 --- a/ddl/delete_range.go +++ b/ddl/delete_range.go @@ -52,10 +52,10 @@ var ( type delRangeManager interface { // addDelRangeJob add a DDL job into gc_delete_range table. - addDelRangeJob(job *model.Job) error + addDelRangeJob(ctx context.Context, job *model.Job) error // removeFromGCDeleteRange removes the deleting table job from gc_delete_range table by jobID and tableID. // It's use for recover the table that was mistakenly deleted. - removeFromGCDeleteRange(jobID int64, tableID []int64) error + removeFromGCDeleteRange(ctx context.Context, jobID int64, tableID []int64) error start() clear() } @@ -87,14 +87,14 @@ func newDelRangeManager(store kv.Storage, sessPool *sessionPool) delRangeManager } // addDelRangeJob implements delRangeManager interface. -func (dr *delRange) addDelRangeJob(job *model.Job) error { - ctx, err := dr.sessPool.get() +func (dr *delRange) addDelRangeJob(ctx context.Context, job *model.Job) error { + sctx, err := dr.sessPool.get() if err != nil { return errors.Trace(err) } - defer dr.sessPool.put(ctx) + defer dr.sessPool.put(sctx) - err = insertJobIntoDeleteRangeTable(ctx, job) + err = insertJobIntoDeleteRangeTable(ctx, sctx, job) if err != nil { logutil.BgLogger().Error("[ddl] add job into delete-range table failed", zap.Int64("jobID", job.ID), zap.String("jobType", job.Type.String()), zap.Error(err)) return errors.Trace(err) @@ -107,13 +107,13 @@ func (dr *delRange) addDelRangeJob(job *model.Job) error { } // removeFromGCDeleteRange implements delRangeManager interface. -func (dr *delRange) removeFromGCDeleteRange(jobID int64, tableIDs []int64) error { - ctx, err := dr.sessPool.get() +func (dr *delRange) removeFromGCDeleteRange(ctx context.Context, jobID int64, tableIDs []int64) error { + sctx, err := dr.sessPool.get() if err != nil { return errors.Trace(err) } - defer dr.sessPool.put(ctx) - err = util.RemoveMultiFromGCDeleteRange(ctx, jobID, tableIDs) + defer dr.sessPool.put(sctx) + err = util.RemoveMultiFromGCDeleteRange(ctx, sctx, jobID, tableIDs) return errors.Trace(err) } @@ -245,13 +245,13 @@ func (dr *delRange) doTask(ctx sessionctx.Context, r util.DelRangeTask) error { // insertJobIntoDeleteRangeTable parses the job into delete-range arguments, // and inserts a new record into gc_delete_range table. The primary key is // job ID, so we ignore key conflict error. -func insertJobIntoDeleteRangeTable(ctx sessionctx.Context, job *model.Job) error { - now, err := getNowTSO(ctx) +func insertJobIntoDeleteRangeTable(ctx context.Context, sctx sessionctx.Context, job *model.Job) error { + now, err := getNowTSO(sctx) if err != nil { return errors.Trace(err) } - s := ctx.(sqlexec.SQLExecutor) + s := sctx.(sqlexec.SQLExecutor) switch job.Type { case model.ActionDropSchema: var tableIDs []int64 @@ -263,7 +263,7 @@ func insertJobIntoDeleteRangeTable(ctx sessionctx.Context, job *model.Job) error if batchEnd > i+batchInsertDeleteRangeSize { batchEnd = i + batchInsertDeleteRangeSize } - if err := doBatchInsert(s, job.ID, tableIDs[i:batchEnd], now); err != nil { + if err := doBatchInsert(ctx, s, job.ID, tableIDs[i:batchEnd], now); err != nil { return errors.Trace(err) } } @@ -279,7 +279,7 @@ func insertJobIntoDeleteRangeTable(ctx sessionctx.Context, job *model.Job) error for _, pid := range physicalTableIDs { startKey = tablecodec.EncodeTablePrefix(pid) endKey := tablecodec.EncodeTablePrefix(pid + 1) - if err := doInsert(s, job.ID, pid, startKey, endKey, now); err != nil { + if err := doInsert(ctx, s, job.ID, pid, startKey, endKey, now); err != nil { return errors.Trace(err) } } @@ -287,7 +287,7 @@ func insertJobIntoDeleteRangeTable(ctx sessionctx.Context, job *model.Job) error } startKey = tablecodec.EncodeTablePrefix(tableID) endKey := tablecodec.EncodeTablePrefix(tableID + 1) - return doInsert(s, job.ID, tableID, startKey, endKey, now) + return doInsert(ctx, s, job.ID, tableID, startKey, endKey, now) case model.ActionDropTablePartition, model.ActionTruncateTablePartition: var physicalTableIDs []int64 if err := job.DecodeArgs(&physicalTableIDs); err != nil { @@ -296,7 +296,7 @@ func insertJobIntoDeleteRangeTable(ctx sessionctx.Context, job *model.Job) error for _, physicalTableID := range physicalTableIDs { startKey := tablecodec.EncodeTablePrefix(physicalTableID) endKey := tablecodec.EncodeTablePrefix(physicalTableID + 1) - if err := doInsert(s, job.ID, physicalTableID, startKey, endKey, now); err != nil { + if err := doInsert(ctx, s, job.ID, physicalTableID, startKey, endKey, now); err != nil { return errors.Trace(err) } } @@ -312,14 +312,14 @@ func insertJobIntoDeleteRangeTable(ctx sessionctx.Context, job *model.Job) error for _, pid := range partitionIDs { startKey := tablecodec.EncodeTableIndexPrefix(pid, indexID) endKey := tablecodec.EncodeTableIndexPrefix(pid, indexID+1) - if err := doInsert(s, job.ID, indexID, startKey, endKey, now); err != nil { + if err := doInsert(ctx, s, job.ID, indexID, startKey, endKey, now); err != nil { return errors.Trace(err) } } } else { startKey := tablecodec.EncodeTableIndexPrefix(tableID, indexID) endKey := tablecodec.EncodeTableIndexPrefix(tableID, indexID+1) - return doInsert(s, job.ID, indexID, startKey, endKey, now) + return doInsert(ctx, s, job.ID, indexID, startKey, endKey, now) } case model.ActionDropIndex, model.ActionDropPrimaryKey: tableID := job.TableID @@ -333,14 +333,14 @@ func insertJobIntoDeleteRangeTable(ctx sessionctx.Context, job *model.Job) error for _, pid := range partitionIDs { startKey := tablecodec.EncodeTableIndexPrefix(pid, indexID) endKey := tablecodec.EncodeTableIndexPrefix(pid, indexID+1) - if err := doInsert(s, job.ID, indexID, startKey, endKey, now); err != nil { + if err := doInsert(ctx, s, job.ID, indexID, startKey, endKey, now); err != nil { return errors.Trace(err) } } } else { startKey := tablecodec.EncodeTableIndexPrefix(tableID, indexID) endKey := tablecodec.EncodeTableIndexPrefix(tableID, indexID+1) - return doInsert(s, job.ID, indexID, startKey, endKey, now) + return doInsert(ctx, s, job.ID, indexID, startKey, endKey, now) } case model.ActionDropColumn: var colName model.CIStr @@ -352,12 +352,12 @@ func insertJobIntoDeleteRangeTable(ctx sessionctx.Context, job *model.Job) error if len(indexIDs) > 0 { if len(partitionIDs) > 0 { for _, pid := range partitionIDs { - if err := doBatchDeleteIndiceRange(s, job.ID, pid, indexIDs, now); err != nil { + if err := doBatchDeleteIndiceRange(ctx, s, job.ID, pid, indexIDs, now); err != nil { return errors.Trace(err) } } } else { - return doBatchDeleteIndiceRange(s, job.ID, job.TableID, indexIDs, now) + return doBatchDeleteIndiceRange(ctx, s, job.ID, job.TableID, indexIDs, now) } } case model.ActionDropColumns: @@ -371,12 +371,12 @@ func insertJobIntoDeleteRangeTable(ctx sessionctx.Context, job *model.Job) error if len(indexIDs) > 0 { if len(partitionIDs) > 0 { for _, pid := range partitionIDs { - if err := doBatchDeleteIndiceRange(s, job.ID, pid, indexIDs, now); err != nil { + if err := doBatchDeleteIndiceRange(ctx, s, job.ID, pid, indexIDs, now); err != nil { return errors.Trace(err) } } } else { - return doBatchDeleteIndiceRange(s, job.ID, job.TableID, indexIDs, now) + return doBatchDeleteIndiceRange(ctx, s, job.ID, job.TableID, indexIDs, now) } } case model.ActionModifyColumn: @@ -389,10 +389,10 @@ func insertJobIntoDeleteRangeTable(ctx sessionctx.Context, job *model.Job) error return nil } if len(partitionIDs) == 0 { - return doBatchDeleteIndiceRange(s, job.ID, job.TableID, indexIDs, now) + return doBatchDeleteIndiceRange(ctx, s, job.ID, job.TableID, indexIDs, now) } for _, pid := range partitionIDs { - if err := doBatchDeleteIndiceRange(s, job.ID, pid, indexIDs, now); err != nil { + if err := doBatchDeleteIndiceRange(ctx, s, job.ID, pid, indexIDs, now); err != nil { return errors.Trace(err) } } @@ -400,7 +400,7 @@ func insertJobIntoDeleteRangeTable(ctx sessionctx.Context, job *model.Job) error return nil } -func doBatchDeleteIndiceRange(s sqlexec.SQLExecutor, jobID, tableID int64, indexIDs []int64, ts uint64) error { +func doBatchDeleteIndiceRange(ctx context.Context, s sqlexec.SQLExecutor, jobID, tableID int64, indexIDs []int64, ts uint64) error { logutil.BgLogger().Info("[ddl] batch insert into delete-range indices", zap.Int64("jobID", jobID), zap.Int64s("elementIDs", indexIDs)) paramsList := make([]interface{}, 0, len(indexIDs)*5) var buf strings.Builder @@ -416,19 +416,19 @@ func doBatchDeleteIndiceRange(s sqlexec.SQLExecutor, jobID, tableID int64, index } paramsList = append(paramsList, jobID, indexID, startKeyEncoded, endKeyEncoded, ts) } - _, err := s.ExecuteInternal(context.Background(), buf.String(), paramsList...) + _, err := s.ExecuteInternal(ctx, buf.String(), paramsList...) return errors.Trace(err) } -func doInsert(s sqlexec.SQLExecutor, jobID int64, elementID int64, startKey, endKey kv.Key, ts uint64) error { +func doInsert(ctx context.Context, s sqlexec.SQLExecutor, jobID int64, elementID int64, startKey, endKey kv.Key, ts uint64) error { logutil.BgLogger().Info("[ddl] insert into delete-range table", zap.Int64("jobID", jobID), zap.Int64("elementID", elementID)) startKeyEncoded := hex.EncodeToString(startKey) endKeyEncoded := hex.EncodeToString(endKey) - _, err := s.ExecuteInternal(context.Background(), insertDeleteRangeSQL, jobID, elementID, startKeyEncoded, endKeyEncoded, ts) + _, err := s.ExecuteInternal(ctx, insertDeleteRangeSQL, jobID, elementID, startKeyEncoded, endKeyEncoded, ts) return errors.Trace(err) } -func doBatchInsert(s sqlexec.SQLExecutor, jobID int64, tableIDs []int64, ts uint64) error { +func doBatchInsert(ctx context.Context, s sqlexec.SQLExecutor, jobID int64, tableIDs []int64, ts uint64) error { logutil.BgLogger().Info("[ddl] batch insert into delete-range table", zap.Int64("jobID", jobID), zap.Int64s("elementIDs", tableIDs)) var buf strings.Builder buf.WriteString(insertDeleteRangeSQLPrefix) @@ -444,7 +444,7 @@ func doBatchInsert(s sqlexec.SQLExecutor, jobID int64, tableIDs []int64, ts uint } paramsList = append(paramsList, jobID, tableID, startKeyEncoded, endKeyEncoded, ts) } - _, err := s.ExecuteInternal(context.Background(), buf.String(), paramsList...) + _, err := s.ExecuteInternal(ctx, buf.String(), paramsList...) return errors.Trace(err) } diff --git a/ddl/failtest/fail_db_test.go b/ddl/failtest/fail_db_test.go index 0581e7e87b2cd..b2ae168855365 100644 --- a/ddl/failtest/fail_db_test.go +++ b/ddl/failtest/fail_db_test.go @@ -380,7 +380,7 @@ func (s *testFailDBSuite) TestAddIndexWorkerNum(c *C) { tableStart := tablecodec.GenTableRecordPrefix(tbl.Meta().ID) s.cluster.SplitKeys(tableStart, tableStart.PrefixNext(), splitCount) - err = ddlutil.LoadDDLReorgVars(tk.Se) + err = ddlutil.LoadDDLReorgVars(context.Background(), tk.Se) c.Assert(err, IsNil) originDDLAddIndexWorkerCnt := variable.GetDDLReorgWorkerCounter() lastSetWorkerCnt := originDDLAddIndexWorkerCnt diff --git a/ddl/mock.go b/ddl/mock.go index df6a6ce753586..64345d84d9c1b 100644 --- a/ddl/mock.go +++ b/ddl/mock.go @@ -136,12 +136,12 @@ func newMockDelRangeManager() delRangeManager { } // addDelRangeJob implements delRangeManager interface. -func (dr *mockDelRange) addDelRangeJob(job *model.Job) error { +func (dr *mockDelRange) addDelRangeJob(ctx context.Context, job *model.Job) error { return nil } // removeFromGCDeleteRange implements delRangeManager interface. -func (dr *mockDelRange) removeFromGCDeleteRange(jobID int64, tableIDs []int64) error { +func (dr *mockDelRange) removeFromGCDeleteRange(ctx context.Context, jobID int64, tableIDs []int64) error { return nil } diff --git a/ddl/partition.go b/ddl/partition.go index 5ec161abf3168..a3872bbcf0cc0 100644 --- a/ddl/partition.go +++ b/ddl/partition.go @@ -1371,11 +1371,11 @@ func checkExchangePartitionRecordValidation(w *worker, pt *model.TableInfo, inde } defer w.sessPool.put(ctx) - stmt, err := ctx.(sqlexec.RestrictedSQLExecutor).ParseWithParams(context.Background(), sql, paramList...) + stmt, err := ctx.(sqlexec.RestrictedSQLExecutor).ParseWithParams(w.ddlJobCtx, sql, paramList...) if err != nil { return errors.Trace(err) } - rows, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedStmt(context.Background(), stmt) + rows, _, err := ctx.(sqlexec.RestrictedSQLExecutor).ExecRestrictedStmt(w.ddlJobCtx, stmt) if err != nil { return errors.Trace(err) } diff --git a/ddl/reorg.go b/ddl/reorg.go index eb20f31ce5fe7..465546c74c842 100644 --- a/ddl/reorg.go +++ b/ddl/reorg.go @@ -339,11 +339,11 @@ func getTableTotalCount(w *worker, tblInfo *model.TableInfo) int64 { return statistics.PseudoRowCount } sql := "select table_rows from information_schema.tables where tidb_table_id=%?;" - stmt, err := executor.ParseWithParams(context.Background(), sql, tblInfo.ID) + stmt, err := executor.ParseWithParams(w.ddlJobCtx, sql, tblInfo.ID) if err != nil { return statistics.PseudoRowCount } - rows, _, err := executor.ExecRestrictedStmt(context.Background(), stmt) + rows, _, err := executor.ExecRestrictedStmt(w.ddlJobCtx, stmt) if err != nil { return statistics.PseudoRowCount } diff --git a/ddl/restart_test.go b/ddl/restart_test.go index b7791ef7679bd..0eb0b26e781be 100644 --- a/ddl/restart_test.go +++ b/ddl/restart_test.go @@ -43,6 +43,7 @@ func (d *ddl) restartWorkers(ctx context.Context) { for _, worker := range d.workers { worker.wg.Add(1) worker.ctx = d.ctx + worker.ddlJobCtx = context.Background() w := worker go w.start(d.ddlCtx) asyncNotify(worker.ddlJobCh) diff --git a/ddl/table.go b/ddl/table.go index 336471c04bd62..9f46f898bb786 100644 --- a/ddl/table.go +++ b/ddl/table.go @@ -315,7 +315,7 @@ func (w *worker) onRecoverTable(d *ddlCtx, t *meta.Meta, job *model.Job) (ver in } else { tids = []int64{tblInfo.ID} } - err = w.delRangeManager.removeFromGCDeleteRange(dropJobID, tids) + err = w.delRangeManager.removeFromGCDeleteRange(w.ddlJobCtx, dropJobID, tids) if err != nil { return ver, errors.Trace(err) } diff --git a/ddl/util/util.go b/ddl/util/util.go index 0e5eb8fe2051d..1ae4e1a7684b1 100644 --- a/ddl/util/util.go +++ b/ddl/util/util.go @@ -118,7 +118,7 @@ func RemoveFromGCDeleteRange(ctx sessionctx.Context, jobID, elementID int64) err } // RemoveMultiFromGCDeleteRange is exported for ddl pkg to use. -func RemoveMultiFromGCDeleteRange(ctx sessionctx.Context, jobID int64, elementIDs []int64) error { +func RemoveMultiFromGCDeleteRange(ctx context.Context, sctx sessionctx.Context, jobID int64, elementIDs []int64) error { var buf strings.Builder buf.WriteString(completeDeleteMultiRangesSQL) paramIDs := make([]interface{}, 0, 1+len(elementIDs)) @@ -131,7 +131,7 @@ func RemoveMultiFromGCDeleteRange(ctx sessionctx.Context, jobID int64, elementID paramIDs = append(paramIDs, elementID) } buf.WriteString(")") - _, err := ctx.(sqlexec.SQLExecutor).ExecuteInternal(context.TODO(), buf.String(), paramIDs...) + _, err := sctx.(sqlexec.SQLExecutor).ExecuteInternal(ctx, buf.String(), paramIDs...) return errors.Trace(err) } @@ -150,20 +150,20 @@ func UpdateDeleteRange(ctx sessionctx.Context, dr DelRangeTask, newStartKey, old } // LoadDDLReorgVars loads ddl reorg variable from mysql.global_variables. -func LoadDDLReorgVars(ctx sessionctx.Context) error { +func LoadDDLReorgVars(ctx context.Context, sctx sessionctx.Context) error { // close issue #21391 // variable.TiDBRowFormatVersion is used to encode the new row for column type change. - return LoadGlobalVars(ctx, []string{variable.TiDBDDLReorgWorkerCount, variable.TiDBDDLReorgBatchSize, variable.TiDBRowFormatVersion}) + return LoadGlobalVars(ctx, sctx, []string{variable.TiDBDDLReorgWorkerCount, variable.TiDBDDLReorgBatchSize, variable.TiDBRowFormatVersion}) } // LoadDDLVars loads ddl variable from mysql.global_variables. func LoadDDLVars(ctx sessionctx.Context) error { - return LoadGlobalVars(ctx, []string{variable.TiDBDDLErrorCountLimit}) + return LoadGlobalVars(context.Background(), ctx, []string{variable.TiDBDDLErrorCountLimit}) } // LoadGlobalVars loads global variable from mysql.global_variables. -func LoadGlobalVars(ctx sessionctx.Context, varNames []string) error { - if sctx, ok := ctx.(sqlexec.RestrictedSQLExecutor); ok { +func LoadGlobalVars(ctx context.Context, sctx sessionctx.Context, varNames []string) error { + if e, ok := sctx.(sqlexec.RestrictedSQLExecutor); ok { var buf strings.Builder buf.WriteString(loadGlobalVars) paramNames := make([]interface{}, 0, len(varNames)) @@ -175,18 +175,18 @@ func LoadGlobalVars(ctx sessionctx.Context, varNames []string) error { paramNames = append(paramNames, name) } buf.WriteString(")") - stmt, err := sctx.ParseWithParams(context.Background(), buf.String(), paramNames...) + stmt, err := e.ParseWithParams(ctx, buf.String(), paramNames...) if err != nil { return errors.Trace(err) } - rows, _, err := sctx.ExecRestrictedStmt(context.Background(), stmt) + rows, _, err := e.ExecRestrictedStmt(ctx, stmt) if err != nil { return errors.Trace(err) } for _, row := range rows { varName := row.GetString(0) varValue := row.GetString(1) - if err = ctx.GetSessionVars().SetSystemVar(varName, varValue); err != nil { + if err = sctx.GetSessionVars().SetSystemVar(varName, varValue); err != nil { return err } } diff --git a/executor/adapter.go b/executor/adapter.go index 16cd9302bc5f4..4527230880c23 100644 --- a/executor/adapter.go +++ b/executor/adapter.go @@ -307,9 +307,10 @@ func (a *ExecStmt) setPlanLabelForTopSQL(ctx context.Context) context.Context { if a.Plan == nil || !variable.TopSQLEnabled() { return ctx } - normalizedSQL, sqlDigest := a.Ctx.GetSessionVars().StmtCtx.SQLDigest() + vars := a.Ctx.GetSessionVars() + normalizedSQL, sqlDigest := vars.StmtCtx.SQLDigest() normalizedPlan, planDigest := getPlanDigest(a.Ctx, a.Plan) - return topsql.AttachSQLInfo(ctx, normalizedSQL, sqlDigest, normalizedPlan, planDigest) + return topsql.AttachSQLInfo(ctx, normalizedSQL, sqlDigest, normalizedPlan, planDigest, vars.InRestrictedSQL) } // Exec builds an Executor from a plan. If the Executor doesn't return result, diff --git a/executor/ddl.go b/executor/ddl.go index 68505134cbace..b794105ec019d 100644 --- a/executor/ddl.go +++ b/executor/ddl.go @@ -93,7 +93,7 @@ func (e *DDLExec) Next(ctx context.Context, req *chunk.Chunk) (err error) { case *ast.AlterDatabaseStmt: err = e.executeAlterDatabase(x) case *ast.AlterTableStmt: - err = e.executeAlterTable(x) + err = e.executeAlterTable(ctx, x) case *ast.CreateIndexStmt: err = e.executeCreateIndex(x) case *ast.CreateDatabaseStmt: @@ -456,9 +456,9 @@ func (e *DDLExec) executeDropIndex(s *ast.DropIndexStmt) error { return err } -func (e *DDLExec) executeAlterTable(s *ast.AlterTableStmt) error { +func (e *DDLExec) executeAlterTable(ctx context.Context, s *ast.AlterTableStmt) error { ti := ast.Ident{Schema: s.Table.Schema, Name: s.Table.Name} - err := domain.GetDomain(e.ctx).DDL().AlterTable(e.ctx, ti, s.Specs) + err := domain.GetDomain(e.ctx).DDL().AlterTable(ctx, e.ctx, ti, s.Specs) return err } diff --git a/executor/ddl_test.go b/executor/ddl_test.go index b763a84673375..47c3a0c82985a 100644 --- a/executor/ddl_test.go +++ b/executor/ddl_test.go @@ -1147,21 +1147,21 @@ func (s *testSuite6) TestMaxHandleAddIndex(c *C) { func (s *testSuite6) TestSetDDLReorgWorkerCnt(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") - err := ddlutil.LoadDDLReorgVars(tk.Se) + err := ddlutil.LoadDDLReorgVars(context.Background(), tk.Se) c.Assert(err, IsNil) c.Assert(variable.GetDDLReorgWorkerCounter(), Equals, int32(variable.DefTiDBDDLReorgWorkerCount)) tk.MustExec("set @@global.tidb_ddl_reorg_worker_cnt = 1") - err = ddlutil.LoadDDLReorgVars(tk.Se) + err = ddlutil.LoadDDLReorgVars(context.Background(), tk.Se) c.Assert(err, IsNil) c.Assert(variable.GetDDLReorgWorkerCounter(), Equals, int32(1)) tk.MustExec("set @@global.tidb_ddl_reorg_worker_cnt = 100") - err = ddlutil.LoadDDLReorgVars(tk.Se) + err = ddlutil.LoadDDLReorgVars(context.Background(), tk.Se) c.Assert(err, IsNil) c.Assert(variable.GetDDLReorgWorkerCounter(), Equals, int32(100)) _, err = tk.Exec("set @@global.tidb_ddl_reorg_worker_cnt = invalid_val") c.Assert(terror.ErrorEqual(err, variable.ErrWrongTypeForVar), IsTrue, Commentf("err %v", err)) tk.MustExec("set @@global.tidb_ddl_reorg_worker_cnt = 100") - err = ddlutil.LoadDDLReorgVars(tk.Se) + err = ddlutil.LoadDDLReorgVars(context.Background(), tk.Se) c.Assert(err, IsNil) c.Assert(variable.GetDDLReorgWorkerCounter(), Equals, int32(100)) _, err = tk.Exec("set @@global.tidb_ddl_reorg_worker_cnt = -1") @@ -1184,24 +1184,24 @@ func (s *testSuite6) TestSetDDLReorgWorkerCnt(c *C) { func (s *testSuite6) TestSetDDLReorgBatchSize(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") - err := ddlutil.LoadDDLReorgVars(tk.Se) + err := ddlutil.LoadDDLReorgVars(context.Background(), tk.Se) c.Assert(err, IsNil) c.Assert(variable.GetDDLReorgBatchSize(), Equals, int32(variable.DefTiDBDDLReorgBatchSize)) tk.MustExec("set @@global.tidb_ddl_reorg_batch_size = 1") tk.MustQuery("show warnings;").Check(testkit.Rows("Warning 1292 Truncated incorrect tidb_ddl_reorg_batch_size value: '1'")) - err = ddlutil.LoadDDLReorgVars(tk.Se) + err = ddlutil.LoadDDLReorgVars(context.Background(), tk.Se) c.Assert(err, IsNil) c.Assert(variable.GetDDLReorgBatchSize(), Equals, variable.MinDDLReorgBatchSize) tk.MustExec(fmt.Sprintf("set @@global.tidb_ddl_reorg_batch_size = %v", variable.MaxDDLReorgBatchSize+1)) tk.MustQuery("show warnings;").Check(testkit.Rows(fmt.Sprintf("Warning 1292 Truncated incorrect tidb_ddl_reorg_batch_size value: '%d'", variable.MaxDDLReorgBatchSize+1))) - err = ddlutil.LoadDDLReorgVars(tk.Se) + err = ddlutil.LoadDDLReorgVars(context.Background(), tk.Se) c.Assert(err, IsNil) c.Assert(variable.GetDDLReorgBatchSize(), Equals, variable.MaxDDLReorgBatchSize) _, err = tk.Exec("set @@global.tidb_ddl_reorg_batch_size = invalid_val") c.Assert(terror.ErrorEqual(err, variable.ErrWrongTypeForVar), IsTrue, Commentf("err %v", err)) tk.MustExec("set @@global.tidb_ddl_reorg_batch_size = 100") - err = ddlutil.LoadDDLReorgVars(tk.Se) + err = ddlutil.LoadDDLReorgVars(context.Background(), tk.Se) c.Assert(err, IsNil) c.Assert(variable.GetDDLReorgBatchSize(), Equals, int32(100)) tk.MustExec("set @@global.tidb_ddl_reorg_batch_size = -1") diff --git a/executor/executor.go b/executor/executor.go index 297b6a8e66dc2..3b01cc6199b47 100644 --- a/executor/executor.go +++ b/executor/executor.go @@ -1696,7 +1696,7 @@ func ResetContextOfStmt(ctx sessionctx.Context, s ast.StmtNode) (err error) { pprof.SetGoroutineLabels(goCtx) } if variable.TopSQLEnabled() && prepareStmt.SQLDigest != nil { - topsql.AttachSQLInfo(goCtx, prepareStmt.NormalizedSQL, prepareStmt.SQLDigest, "", nil) + topsql.AttachSQLInfo(goCtx, prepareStmt.NormalizedSQL, prepareStmt.SQLDigest, "", nil, vars.InRestrictedSQL) } } // execute missed stmtID uses empty sql diff --git a/executor/grant.go b/executor/grant.go index 46c0324fe24f2..d96e10d2729de 100644 --- a/executor/grant.go +++ b/executor/grant.go @@ -121,7 +121,7 @@ func (e *GrantExec) Next(ctx context.Context, req *chunk.Chunk) error { // Check which user is not exist. for _, user := range e.Users { - exists, err := userExists(e.ctx, user.User.Username, user.User.Hostname) + exists, err := userExists(ctx, e.ctx, user.User.Username, user.User.Hostname) if err != nil { return err } diff --git a/executor/infoschema_reader.go b/executor/infoschema_reader.go index e087ccc757ad0..274146a6161a7 100644 --- a/executor/infoschema_reader.go +++ b/executor/infoschema_reader.go @@ -94,11 +94,11 @@ func (e *memtableRetriever) retrieve(ctx context.Context, sctx sessionctx.Contex case infoschema.TableStatistics: e.setDataForStatistics(sctx, dbs) case infoschema.TableTables: - err = e.setDataFromTables(sctx, dbs) + err = e.setDataFromTables(ctx, sctx, dbs) case infoschema.TableSequences: e.setDataFromSequences(sctx, dbs) case infoschema.TablePartitions: - err = e.setDataFromPartitions(sctx, dbs) + err = e.setDataFromPartitions(ctx, sctx, dbs) case infoschema.TableClusterInfo: err = e.dataForTiDBClusterInfo(sctx) case infoschema.TableAnalyzeStatus: @@ -184,13 +184,13 @@ func (e *memtableRetriever) retrieve(ctx context.Context, sctx sessionctx.Contex return adjustColumns(ret, e.columns, e.table), nil } -func getRowCountAllTable(ctx sessionctx.Context) (map[int64]uint64, error) { - exec := ctx.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(context.TODO(), "select table_id, count from mysql.stats_meta") +func getRowCountAllTable(ctx context.Context, sctx sessionctx.Context) (map[int64]uint64, error) { + exec := sctx.(sqlexec.RestrictedSQLExecutor) + stmt, err := exec.ParseWithParams(ctx, "select table_id, count from mysql.stats_meta") if err != nil { return nil, err } - rows, _, err := exec.ExecRestrictedStmt(context.TODO(), stmt) + rows, _, err := exec.ExecRestrictedStmt(ctx, stmt) if err != nil { return nil, err } @@ -209,13 +209,13 @@ type tableHistID struct { histID int64 } -func getColLengthAllTables(ctx sessionctx.Context) (map[tableHistID]uint64, error) { - exec := ctx.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(context.TODO(), "select table_id, hist_id, tot_col_size from mysql.stats_histograms where is_index = 0") +func getColLengthAllTables(ctx context.Context, sctx sessionctx.Context) (map[tableHistID]uint64, error) { + exec := sctx.(sqlexec.RestrictedSQLExecutor) + stmt, err := exec.ParseWithParams(ctx, "select table_id, hist_id, tot_col_size from mysql.stats_histograms where is_index = 0") if err != nil { return nil, err } - rows, _, err := exec.ExecRestrictedStmt(context.TODO(), stmt) + rows, _, err := exec.ExecRestrictedStmt(ctx, stmt) if err != nil { return nil, err } @@ -278,7 +278,7 @@ var tableStatsCache = &statsCache{} // TableStatsCacheExpiry is the expiry time for table stats cache. var TableStatsCacheExpiry = 3 * time.Second -func (c *statsCache) get(ctx sessionctx.Context) (map[int64]uint64, map[tableHistID]uint64, error) { +func (c *statsCache) get(ctx context.Context, sctx sessionctx.Context) (map[int64]uint64, map[tableHistID]uint64, error) { c.mu.RLock() if time.Since(c.modifyTime) < TableStatsCacheExpiry { tableRows, colLength := c.tableRows, c.colLength @@ -292,11 +292,11 @@ func (c *statsCache) get(ctx sessionctx.Context) (map[int64]uint64, map[tableHis if time.Since(c.modifyTime) < TableStatsCacheExpiry { return c.tableRows, c.colLength, nil } - tableRows, err := getRowCountAllTable(ctx) + tableRows, err := getRowCountAllTable(ctx, sctx) if err != nil { return nil, nil, err } - colLength, err := getColLengthAllTables(ctx) + colLength, err := getColLengthAllTables(ctx, sctx) if err != nil { return nil, nil, err } @@ -445,13 +445,13 @@ func (e *memtableRetriever) setDataForStatisticsInTable(schema *model.DBInfo, ta e.rows = append(e.rows, rows...) } -func (e *memtableRetriever) setDataFromTables(ctx sessionctx.Context, schemas []*model.DBInfo) error { - tableRowsMap, colLengthMap, err := tableStatsCache.get(ctx) +func (e *memtableRetriever) setDataFromTables(ctx context.Context, sctx sessionctx.Context, schemas []*model.DBInfo) error { + tableRowsMap, colLengthMap, err := tableStatsCache.get(ctx, sctx) if err != nil { return err } - checker := privilege.GetPrivilegeManager(ctx) + checker := privilege.GetPrivilegeManager(sctx) var rows [][]types.Datum createTimeTp := mysql.TypeDatetime @@ -469,7 +469,7 @@ func (e *memtableRetriever) setDataFromTables(ctx sessionctx.Context, schemas [] continue } - if checker != nil && !checker.RequestVerification(ctx.GetSessionVars().ActiveRoles, schema.Name.L, table.Name.L, "", mysql.AllPrivMask) { + if checker != nil && !checker.RequestVerification(sctx.GetSessionVars().ActiveRoles, schema.Name.L, table.Name.L, "", mysql.AllPrivMask) { continue } pkType := "NONCLUSTERED" @@ -480,7 +480,7 @@ func (e *memtableRetriever) setDataFromTables(ctx sessionctx.Context, schemas [] var autoIncID interface{} hasAutoIncID, _ := infoschema.HasAutoIncrementColumn(table) if hasAutoIncID { - autoIncID, err = getAutoIncrementID(ctx, schema, table) + autoIncID, err = getAutoIncrementID(sctx, schema, table) if err != nil { return err } @@ -692,17 +692,17 @@ func calcCharOctLength(lenInChar int, cs string) int { return lenInBytes } -func (e *memtableRetriever) setDataFromPartitions(ctx sessionctx.Context, schemas []*model.DBInfo) error { - tableRowsMap, colLengthMap, err := tableStatsCache.get(ctx) +func (e *memtableRetriever) setDataFromPartitions(ctx context.Context, sctx sessionctx.Context, schemas []*model.DBInfo) error { + tableRowsMap, colLengthMap, err := tableStatsCache.get(ctx, sctx) if err != nil { return err } - checker := privilege.GetPrivilegeManager(ctx) + checker := privilege.GetPrivilegeManager(sctx) var rows [][]types.Datum createTimeTp := mysql.TypeDatetime for _, schema := range schemas { for _, table := range schema.Tables { - if checker != nil && !checker.RequestVerification(ctx.GetSessionVars().ActiveRoles, schema.Name.L, table.Name.L, "", mysql.SelectPriv) { + if checker != nil && !checker.RequestVerification(sctx.GetSessionVars().ActiveRoles, schema.Name.L, table.Name.L, "", mysql.SelectPriv) { continue } createTime := types.NewTime(types.FromGoTime(table.GetUpdateTime()), createTimeTp, types.DefaultFsp) diff --git a/executor/prepared.go b/executor/prepared.go index 7eb92ad37ff61..c92bc996e56a6 100644 --- a/executor/prepared.go +++ b/executor/prepared.go @@ -182,7 +182,7 @@ func (e *PrepareExec) Next(ctx context.Context, req *chunk.Chunk) error { } normalizedSQL, digest := parser.NormalizeDigest(prepared.Stmt.Text()) if variable.TopSQLEnabled() { - ctx = topsql.AttachSQLInfo(ctx, normalizedSQL, digest, "", nil) + ctx = topsql.AttachSQLInfo(ctx, normalizedSQL, digest, "", nil, vars.InRestrictedSQL) } if !plannercore.PreparedPlanCacheEnabled() { diff --git a/executor/revoke.go b/executor/revoke.go index 3cc14ec3d1856..8532f4f5fb3c7 100644 --- a/executor/revoke.go +++ b/executor/revoke.go @@ -96,7 +96,7 @@ func (e *RevokeExec) Next(ctx context.Context, req *chunk.Chunk) error { } // Check if user exists. - exists, err := userExists(e.ctx, user.User.Username, user.User.Hostname) + exists, err := userExists(ctx, e.ctx, user.User.Username, user.User.Hostname) if err != nil { return err } diff --git a/executor/show.go b/executor/show.go index 69ea2d041a51b..9810b003cdea9 100644 --- a/executor/show.go +++ b/executor/show.go @@ -139,7 +139,7 @@ func (e *ShowExec) fetchAll(ctx context.Context) error { case ast.ShowCreateSequence: return e.fetchShowCreateSequence() case ast.ShowCreateUser: - return e.fetchShowCreateUser() + return e.fetchShowCreateUser(ctx) case ast.ShowCreateView: return e.fetchShowCreateView() case ast.ShowCreateDatabase: @@ -149,7 +149,7 @@ func (e *ShowExec) fetchAll(ctx context.Context) error { case ast.ShowDrainerStatus: return e.fetchShowPumpOrDrainerStatus(node.DrainerNode) case ast.ShowEngines: - return e.fetchShowEngines() + return e.fetchShowEngines(ctx) case ast.ShowGrants: return e.fetchShowGrants() case ast.ShowIndex: @@ -165,7 +165,7 @@ func (e *ShowExec) fetchAll(ctx context.Context) error { case ast.ShowOpenTables: return e.fetchShowOpenTables() case ast.ShowTableStatus: - return e.fetchShowTableStatus() + return e.fetchShowTableStatus(ctx) case ast.ShowTriggers: return e.fetchShowTriggers() case ast.ShowVariables: @@ -290,14 +290,14 @@ func (e *ShowExec) fetchShowBind() error { return nil } -func (e *ShowExec) fetchShowEngines() error { +func (e *ShowExec) fetchShowEngines(ctx context.Context) error { exec := e.ctx.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(context.TODO(), `SELECT * FROM information_schema.engines`) + stmt, err := exec.ParseWithParams(ctx, `SELECT * FROM information_schema.engines`) if err != nil { return errors.Trace(err) } - rows, _, err := exec.ExecRestrictedStmt(context.TODO(), stmt) + rows, _, err := exec.ExecRestrictedStmt(ctx, stmt) if err != nil { return errors.Trace(err) } @@ -411,7 +411,7 @@ func (e *ShowExec) fetchShowTables() error { return nil } -func (e *ShowExec) fetchShowTableStatus() error { +func (e *ShowExec) fetchShowTableStatus(ctx context.Context) error { checker := privilege.GetPrivilegeManager(e.ctx) if checker != nil && e.ctx.GetSessionVars().User != nil { if !checker.DBIsVisible(e.ctx.GetSessionVars().ActiveRoles, e.DBName.O) { @@ -424,7 +424,7 @@ func (e *ShowExec) fetchShowTableStatus() error { exec := e.ctx.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(context.TODO(), `SELECT + stmt, err := exec.ParseWithParams(ctx, `SELECT table_name, engine, version, row_format, table_rows, avg_row_length, data_length, max_data_length, index_length, data_free, auto_increment, create_time, update_time, check_time, @@ -447,7 +447,7 @@ func (e *ShowExec) fetchShowTableStatus() error { snapshot = e.ctx.GetSessionVars().SnapshotTS } - rows, _, err := exec.ExecRestrictedStmt(context.TODO(), stmt, sqlexec.ExecOptionWithSnapshot(snapshot)) + rows, _, err := exec.ExecRestrictedStmt(ctx, stmt, sqlexec.ExecOptionWithSnapshot(snapshot)) if err != nil { return errors.Trace(err) } @@ -1312,7 +1312,7 @@ func (e *ShowExec) fetchShowCollation() error { } // fetchShowCreateUser composes show create create user result. -func (e *ShowExec) fetchShowCreateUser() error { +func (e *ShowExec) fetchShowCreateUser(ctx context.Context) error { checker := privilege.GetPrivilegeManager(e.ctx) if checker == nil { return errors.New("miss privilege checker") @@ -1334,11 +1334,11 @@ func (e *ShowExec) fetchShowCreateUser() error { exec := e.ctx.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(context.TODO(), `SELECT plugin FROM %n.%n WHERE User=%? AND Host=%?`, mysql.SystemDB, mysql.UserTable, userName, hostName) + stmt, err := exec.ParseWithParams(ctx, `SELECT plugin FROM %n.%n WHERE User=%? AND Host=%?`, mysql.SystemDB, mysql.UserTable, userName, hostName) if err != nil { return errors.Trace(err) } - rows, _, err := exec.ExecRestrictedStmt(context.TODO(), stmt) + rows, _, err := exec.ExecRestrictedStmt(ctx, stmt) if err != nil { return errors.Trace(err) } @@ -1354,11 +1354,11 @@ func (e *ShowExec) fetchShowCreateUser() error { authplugin = rows[0].GetString(0) } - stmt, err = exec.ParseWithParams(context.TODO(), `SELECT Priv FROM %n.%n WHERE User=%? AND Host=%?`, mysql.SystemDB, mysql.GlobalPrivTable, userName, hostName) + stmt, err = exec.ParseWithParams(ctx, `SELECT Priv FROM %n.%n WHERE User=%? AND Host=%?`, mysql.SystemDB, mysql.GlobalPrivTable, userName, hostName) if err != nil { return errors.Trace(err) } - rows, _, err = exec.ExecRestrictedStmt(context.TODO(), stmt) + rows, _, err = exec.ExecRestrictedStmt(ctx, stmt) if err != nil { return errors.Trace(err) } diff --git a/executor/simple.go b/executor/simple.go index 2ba3483625a8b..56ea1b4383b69 100644 --- a/executor/simple.go +++ b/executor/simple.go @@ -120,7 +120,7 @@ func (e *SimpleExec) Next(ctx context.Context, req *chunk.Chunk) (err error) { switch x := e.Statement.(type) { case *ast.GrantRoleStmt: - err = e.executeGrantRole(x) + err = e.executeGrantRole(ctx, x) case *ast.UseStmt: err = e.executeUse(x) case *ast.FlushStmt: @@ -136,13 +136,13 @@ func (e *SimpleExec) Next(ctx context.Context, req *chunk.Chunk) (err error) { case *ast.CreateUserStmt: err = e.executeCreateUser(ctx, x) case *ast.AlterUserStmt: - err = e.executeAlterUser(x) + err = e.executeAlterUser(ctx, x) case *ast.DropUserStmt: - err = e.executeDropUser(x) + err = e.executeDropUser(ctx, x) case *ast.RenameUserStmt: err = e.executeRenameUser(x) case *ast.SetPwdStmt: - err = e.executeSetPwd(x) + err = e.executeSetPwd(ctx, x) case *ast.KillStmt: err = e.executeKillStmt(ctx, x) case *ast.BinlogStmt: @@ -153,9 +153,9 @@ func (e *SimpleExec) Next(ctx context.Context, req *chunk.Chunk) (err error) { case *ast.SetRoleStmt: err = e.executeSetRole(x) case *ast.RevokeRoleStmt: - err = e.executeRevokeRole(x) + err = e.executeRevokeRole(ctx, x) case *ast.SetDefaultRoleStmt: - err = e.executeSetDefaultRole(x) + err = e.executeSetDefaultRole(ctx, x) case *ast.ShutdownStmt: err = e.executeShutdown(x) case *ast.AdminStmt: @@ -196,9 +196,9 @@ func (e *SimpleExec) setDefaultRoleNone(s *ast.SetDefaultRoleStmt) error { return nil } -func (e *SimpleExec) setDefaultRoleRegular(s *ast.SetDefaultRoleStmt) error { +func (e *SimpleExec) setDefaultRoleRegular(ctx context.Context, s *ast.SetDefaultRoleStmt) error { for _, user := range s.UserList { - exists, err := userExists(e.ctx, user.Username, user.Hostname) + exists, err := userExists(ctx, e.ctx, user.Username, user.Hostname) if err != nil { return err } @@ -207,7 +207,7 @@ func (e *SimpleExec) setDefaultRoleRegular(s *ast.SetDefaultRoleStmt) error { } } for _, role := range s.RoleList { - exists, err := userExists(e.ctx, role.Username, role.Hostname) + exists, err := userExists(ctx, e.ctx, role.Username, role.Hostname) if err != nil { return err } @@ -266,9 +266,9 @@ func (e *SimpleExec) setDefaultRoleRegular(s *ast.SetDefaultRoleStmt) error { return nil } -func (e *SimpleExec) setDefaultRoleAll(s *ast.SetDefaultRoleStmt) error { +func (e *SimpleExec) setDefaultRoleAll(ctx context.Context, s *ast.SetDefaultRoleStmt) error { for _, user := range s.UserList { - exists, err := userExists(e.ctx, user.Username, user.Hostname) + exists, err := userExists(ctx, e.ctx, user.Username, user.Hostname) if err != nil { return err } @@ -376,7 +376,7 @@ func (e *SimpleExec) setDefaultRoleForCurrentUser(s *ast.SetDefaultRoleStmt) (er return nil } -func (e *SimpleExec) executeSetDefaultRole(s *ast.SetDefaultRoleStmt) (err error) { +func (e *SimpleExec) executeSetDefaultRole(ctx context.Context, s *ast.SetDefaultRoleStmt) (err error) { sessionVars := e.ctx.GetSessionVars() checker := privilege.GetPrivilegeManager(e.ctx) if checker == nil { @@ -401,11 +401,11 @@ func (e *SimpleExec) executeSetDefaultRole(s *ast.SetDefaultRoleStmt) (err error switch s.SetRoleOpt { case ast.SetRoleAll: - err = e.setDefaultRoleAll(s) + err = e.setDefaultRoleAll(ctx, s) case ast.SetRoleNone: err = e.setDefaultRoleNone(s) case ast.SetRoleRegular: - err = e.setDefaultRoleRegular(s) + err = e.setDefaultRoleRegular(ctx, s) } if err != nil { return @@ -633,9 +633,9 @@ func (e *SimpleExec) executeBegin(ctx context.Context, s *ast.BeginStmt) error { return nil } -func (e *SimpleExec) executeRevokeRole(s *ast.RevokeRoleStmt) error { +func (e *SimpleExec) executeRevokeRole(ctx context.Context, s *ast.RevokeRoleStmt) error { for _, role := range s.Roles { - exists, err := userExists(e.ctx, role.Username, role.Hostname) + exists, err := userExists(ctx, e.ctx, role.Username, role.Hostname) if err != nil { return errors.Trace(err) } @@ -657,7 +657,7 @@ func (e *SimpleExec) executeRevokeRole(s *ast.RevokeRoleStmt) error { } sql := new(strings.Builder) for _, user := range s.Users { - exists, err := userExists(e.ctx, user.Username, user.Hostname) + exists, err := userExists(ctx, e.ctx, user.Username, user.Hostname) if err != nil { return errors.Trace(err) } @@ -760,7 +760,7 @@ func (e *SimpleExec) executeCreateUser(ctx context.Context, s *ast.CreateUserStm if len(users) > 0 { sqlexec.MustFormatSQL(sql, ",") } - exists, err1 := userExists(e.ctx, spec.User.Username, spec.User.Hostname) + exists, err1 := userExists(ctx, e.ctx, spec.User.Username, spec.User.Hostname) if err1 != nil { return err1 } @@ -837,7 +837,7 @@ func (e *SimpleExec) executeCreateUser(ctx context.Context, s *ast.CreateUserStm return err } -func (e *SimpleExec) executeAlterUser(s *ast.AlterUserStmt) error { +func (e *SimpleExec) executeAlterUser(ctx context.Context, s *ast.AlterUserStmt) error { if s.CurrentAuth != nil { user := e.ctx.GetSessionVars().User if user == nil { @@ -903,7 +903,7 @@ func (e *SimpleExec) executeAlterUser(s *ast.AlterUserStmt) error { } } - exists, err := userExists(e.ctx, spec.User.Username, spec.User.Hostname) + exists, err := userExists(ctx, e.ctx, spec.User.Username, spec.User.Hostname) if err != nil { return err } @@ -926,22 +926,22 @@ func (e *SimpleExec) executeAlterUser(s *ast.AlterUserStmt) error { if !ok { return errors.Trace(ErrPasswordFormat) } - stmt, err := exec.ParseWithParams(context.TODO(), `UPDATE %n.%n SET authentication_string=%? WHERE Host=%? and User=%?;`, mysql.SystemDB, mysql.UserTable, pwd, spec.User.Hostname, spec.User.Username) + stmt, err := exec.ParseWithParams(ctx, `UPDATE %n.%n SET authentication_string=%? WHERE Host=%? and User=%?;`, mysql.SystemDB, mysql.UserTable, pwd, spec.User.Hostname, spec.User.Username) if err != nil { return err } - _, _, err = exec.ExecRestrictedStmt(context.TODO(), stmt) + _, _, err = exec.ExecRestrictedStmt(ctx, stmt) if err != nil { failedUsers = append(failedUsers, spec.User.String()) } } if len(privData) > 0 { - stmt, err := exec.ParseWithParams(context.TODO(), "INSERT INTO %n.%n (Host, User, Priv) VALUES (%?,%?,%?) ON DUPLICATE KEY UPDATE Priv = values(Priv)", mysql.SystemDB, mysql.GlobalPrivTable, spec.User.Hostname, spec.User.Username, string(hack.String(privData))) + stmt, err := exec.ParseWithParams(ctx, "INSERT INTO %n.%n (Host, User, Priv) VALUES (%?,%?,%?) ON DUPLICATE KEY UPDATE Priv = values(Priv)", mysql.SystemDB, mysql.GlobalPrivTable, spec.User.Hostname, spec.User.Username, string(hack.String(privData))) if err != nil { return err } - _, _, err = exec.ExecRestrictedStmt(context.TODO(), stmt) + _, _, err = exec.ExecRestrictedStmt(ctx, stmt) if err != nil { failedUsers = append(failedUsers, spec.User.String()) } @@ -969,7 +969,7 @@ func (e *SimpleExec) executeAlterUser(s *ast.AlterUserStmt) error { return nil } -func (e *SimpleExec) executeGrantRole(s *ast.GrantRoleStmt) error { +func (e *SimpleExec) executeGrantRole(ctx context.Context, s *ast.GrantRoleStmt) error { sessionVars := e.ctx.GetSessionVars() for i, user := range s.Users { if user.CurrentUser { @@ -979,7 +979,7 @@ func (e *SimpleExec) executeGrantRole(s *ast.GrantRoleStmt) error { } for _, role := range s.Roles { - exists, err := userExists(e.ctx, role.Username, role.Hostname) + exists, err := userExists(ctx, e.ctx, role.Username, role.Hostname) if err != nil { return err } @@ -988,7 +988,7 @@ func (e *SimpleExec) executeGrantRole(s *ast.GrantRoleStmt) error { } } for _, user := range s.Users { - exists, err := userExists(e.ctx, user.Username, user.Hostname) + exists, err := userExists(ctx, e.ctx, user.Username, user.Hostname) if err != nil { return err } @@ -1147,7 +1147,7 @@ func renameUserHostInSystemTable(sqlExecutor sqlexec.SQLExecutor, tableName, use return err } -func (e *SimpleExec) executeDropUser(s *ast.DropUserStmt) error { +func (e *SimpleExec) executeDropUser(ctx context.Context, s *ast.DropUserStmt) error { // Check privileges. // Check `CREATE USER` privilege. checker := privilege.GetPrivilegeManager(e.ctx) @@ -1182,7 +1182,7 @@ func (e *SimpleExec) executeDropUser(s *ast.DropUserStmt) error { sql := new(strings.Builder) for _, user := range s.UserList { - exists, err := userExists(e.ctx, user.Username, user.Hostname) + exists, err := userExists(ctx, e.ctx, user.Username, user.Hostname) if err != nil { return err } @@ -1300,13 +1300,13 @@ func (e *SimpleExec) executeDropUser(s *ast.DropUserStmt) error { return nil } -func userExists(ctx sessionctx.Context, name string, host string) (bool, error) { - exec := ctx.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(context.TODO(), `SELECT * FROM %n.%n WHERE User=%? AND Host=%?;`, mysql.SystemDB, mysql.UserTable, name, host) +func userExists(ctx context.Context, sctx sessionctx.Context, name string, host string) (bool, error) { + exec := sctx.(sqlexec.RestrictedSQLExecutor) + stmt, err := exec.ParseWithParams(ctx, `SELECT * FROM %n.%n WHERE User=%? AND Host=%?;`, mysql.SystemDB, mysql.UserTable, name, host) if err != nil { return false, err } - rows, _, err := exec.ExecRestrictedStmt(context.TODO(), stmt) + rows, _, err := exec.ExecRestrictedStmt(ctx, stmt) if err != nil { return false, err } @@ -1343,7 +1343,7 @@ func (e *SimpleExec) userAuthPlugin(name string, host string) (string, error) { return authplugin, nil } -func (e *SimpleExec) executeSetPwd(s *ast.SetPwdStmt) error { +func (e *SimpleExec) executeSetPwd(ctx context.Context, s *ast.SetPwdStmt) error { var u, h string if s.User == nil { if e.ctx.GetSessionVars().User == nil { @@ -1360,7 +1360,7 @@ func (e *SimpleExec) executeSetPwd(s *ast.SetPwdStmt) error { u = s.User.Username h = s.User.Hostname } - exists, err := userExists(e.ctx, u, h) + exists, err := userExists(ctx, e.ctx, u, h) if err != nil { return err } @@ -1381,11 +1381,11 @@ func (e *SimpleExec) executeSetPwd(s *ast.SetPwdStmt) error { // update mysql.user exec := e.ctx.(sqlexec.RestrictedSQLExecutor) - stmt, err := exec.ParseWithParams(context.TODO(), `UPDATE %n.%n SET authentication_string=%? WHERE User=%? AND Host=%?;`, mysql.SystemDB, mysql.UserTable, pwd, u, h) + stmt, err := exec.ParseWithParams(ctx, `UPDATE %n.%n SET authentication_string=%? WHERE User=%? AND Host=%?;`, mysql.SystemDB, mysql.UserTable, pwd, u, h) if err != nil { return err } - _, _, err = exec.ExecRestrictedStmt(context.TODO(), stmt) + _, _, err = exec.ExecRestrictedStmt(ctx, stmt) domain.GetDomain(e.ctx).NotifyUpdatePrivilege(e.ctx) return err } diff --git a/go.mod b/go.mod index 5c19112ee5b94..8ea96e89a7fba 100644 --- a/go.mod +++ b/go.mod @@ -47,7 +47,7 @@ require ( github.com/pingcap/parser v0.0.0-20210707071004-31c87e37af5c github.com/pingcap/sysutil v0.0.0-20210315073920-cc0985d983a3 github.com/pingcap/tidb-tools v4.0.9-0.20201127090955-2707c97b3853+incompatible - github.com/pingcap/tipb v0.0.0-20210628060001-1793e022b962 + github.com/pingcap/tipb v0.0.0-20210708040514-0f154bb0dc0f github.com/prometheus/client_golang v1.5.1 github.com/prometheus/client_model v0.2.0 github.com/prometheus/common v0.9.1 diff --git a/go.sum b/go.sum index c4791ba6775f2..f040b0c56d144 100644 --- a/go.sum +++ b/go.sum @@ -454,8 +454,8 @@ github.com/pingcap/sysutil v0.0.0-20210315073920-cc0985d983a3/go.mod h1:tckvA041 github.com/pingcap/tidb-dashboard v0.0.0-20210312062513-eef5d6404638/go.mod h1:OzFN8H0EDMMqeulPhPMw2i2JaiZWOKFQ7zdRPhENNgo= github.com/pingcap/tidb-tools v4.0.9-0.20201127090955-2707c97b3853+incompatible h1:ceznmu/lLseGHP/jKyOa/3u/5H3wtLLLqkH2V3ssSjg= github.com/pingcap/tidb-tools v4.0.9-0.20201127090955-2707c97b3853+incompatible/go.mod h1:XGdcy9+yqlDSEMTpOXnwf3hiTeqrV6MN/u1se9N8yIM= -github.com/pingcap/tipb v0.0.0-20210628060001-1793e022b962 h1:9Y9Eci9LwAEhyXAlAU0bSix7Nemm3G267oyN3GVK+j0= -github.com/pingcap/tipb v0.0.0-20210628060001-1793e022b962/go.mod h1:A7mrd7WHBl1o63LE2bIBGEJMTNWXqhgmYiOvMLxozfs= +github.com/pingcap/tipb v0.0.0-20210708040514-0f154bb0dc0f h1:q6WgGOeY+hbkvtKLyi6nAew7Ptl5vXyeI61VJuJdXnQ= +github.com/pingcap/tipb v0.0.0-20210708040514-0f154bb0dc0f/go.mod h1:A7mrd7WHBl1o63LE2bIBGEJMTNWXqhgmYiOvMLxozfs= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= diff --git a/server/conn_stmt.go b/server/conn_stmt.go index 77217df69f588..237c1f599e6b1 100644 --- a/server/conn_stmt.go +++ b/server/conn_stmt.go @@ -133,7 +133,7 @@ func (cc *clientConn) handleStmtExecute(ctx context.Context, data []byte) (err e if variable.TopSQLEnabled() { preparedStmt, _ := cc.preparedStmtID2CachePreparedStmt(stmtID) if preparedStmt != nil && preparedStmt.SQLDigest != nil { - ctx = topsql.AttachSQLInfo(ctx, preparedStmt.NormalizedSQL, preparedStmt.SQLDigest, "", nil) + ctx = topsql.AttachSQLInfo(ctx, preparedStmt.NormalizedSQL, preparedStmt.SQLDigest, "", nil, false) } } @@ -277,7 +277,7 @@ func (cc *clientConn) handleStmtFetch(ctx context.Context, data []byte) (err err if variable.TopSQLEnabled() { prepareObj, _ := cc.preparedStmtID2CachePreparedStmt(stmtID) if prepareObj != nil && prepareObj.SQLDigest != nil { - ctx = topsql.AttachSQLInfo(ctx, prepareObj.NormalizedSQL, prepareObj.SQLDigest, "", nil) + ctx = topsql.AttachSQLInfo(ctx, prepareObj.NormalizedSQL, prepareObj.SQLDigest, "", nil, false) } } sql := "" diff --git a/session/session.go b/session/session.go index c0f00eb8815a8..568fd0666fad7 100644 --- a/session/session.go +++ b/session/session.go @@ -24,6 +24,7 @@ import ( "encoding/json" "fmt" "net" + "runtime/pprof" "runtime/trace" "strconv" "strings" @@ -1317,6 +1318,10 @@ func (s *session) ExecuteInternal(ctx context.Context, sql string, args ...inter s.sessionVars.InRestrictedSQL = true defer func() { s.sessionVars.InRestrictedSQL = origin + if variable.TopSQLEnabled() { + // Restore the goroutine label by using the original ctx after execution is finished. + pprof.SetGoroutineLabels(ctx) + } }() if span := opentracing.SpanFromContext(ctx); span != nil && span.Tracer() != nil { @@ -1458,8 +1463,9 @@ func (s *session) ParseWithParams(ctx context.Context, sql string, args ...inter if variable.TopSQLEnabled() { normalized, digest := parser.NormalizeDigest(sql) if digest != nil { - // Fixme: reset/clean the label when internal sql execute finish. - topsql.AttachSQLInfo(ctx, normalized, digest, "", nil) + // Reset the goroutine label when internal sql execute finish. + // Specifically reset in ExecRestrictedStmt function. + topsql.AttachSQLInfo(ctx, normalized, digest, "", nil, s.sessionVars.InRestrictedSQL) } } return stmts[0], nil @@ -1468,6 +1474,9 @@ func (s *session) ParseWithParams(ctx context.Context, sql string, args ...inter // ExecRestrictedStmt implements RestrictedSQLExecutor interface. func (s *session) ExecRestrictedStmt(ctx context.Context, stmtNode ast.StmtNode, opts ...sqlexec.OptionFuncAlias) ( []chunk.Row, []*ast.ResultField, error) { + if variable.TopSQLEnabled() { + defer pprof.SetGoroutineLabels(ctx) + } var execOption sqlexec.ExecOption for _, opt := range opts { opt(&execOption) @@ -1574,7 +1583,7 @@ func (s *session) ExecuteStmt(ctx context.Context, stmtNode ast.StmtNode) (sqlex } normalizedSQL, digest := s.sessionVars.StmtCtx.SQLDigest() if variable.TopSQLEnabled() { - ctx = topsql.AttachSQLInfo(ctx, normalizedSQL, digest, "", nil) + ctx = topsql.AttachSQLInfo(ctx, normalizedSQL, digest, "", nil, s.sessionVars.InRestrictedSQL) } if err := s.validateStatementReadOnlyInStaleness(stmtNode); err != nil { diff --git a/statistics/handle/handle.go b/statistics/handle/handle.go index fd344d3ad622e..a518d6430a279 100644 --- a/statistics/handle/handle.go +++ b/statistics/handle/handle.go @@ -17,7 +17,6 @@ import ( "context" "encoding/json" "fmt" - "runtime/pprof" "sort" "strconv" "sync" @@ -122,10 +121,6 @@ func (h *Handle) withRestrictedSQLExecutor(ctx context.Context, fn func(context. func (h *Handle) execRestrictedSQL(ctx context.Context, sql string, params ...interface{}) ([]chunk.Row, []*ast.ResultField, error) { return h.withRestrictedSQLExecutor(ctx, func(ctx context.Context, exec sqlexec.RestrictedSQLExecutor) ([]chunk.Row, []*ast.ResultField, error) { - if variable.TopSQLEnabled() { - // Restore the goroutine label by using the original ctx after execution is finished. - defer pprof.SetGoroutineLabels(ctx) - } stmt, err := exec.ParseWithParams(ctx, sql, params...) if err != nil { return nil, nil, errors.Trace(err) diff --git a/util/topsql/reporter/client.go b/util/topsql/reporter/client.go index 541fb42db14e2..efc33c71e6088 100644 --- a/util/topsql/reporter/client.go +++ b/util/topsql/reporter/client.go @@ -142,9 +142,11 @@ func (r *GRPCReportClient) sendBatchSQLMeta(ctx context.Context, sqlMap *sync.Ma cnt := 0 sqlMap.Range(func(key, value interface{}) bool { cnt++ + meta := value.(SQLMeta) sqlMeta := &tipb.SQLMeta{ SqlDigest: []byte(key.(string)), - NormalizedSql: value.(string), + NormalizedSql: meta.normalizedSQL, + IsInternalSql: meta.isInternal, } if err = stream.Send(sqlMeta); err != nil { return false diff --git a/util/topsql/reporter/mock/server.go b/util/topsql/reporter/mock/server.go index 0f29e8c3ad445..8051339288224 100644 --- a/util/topsql/reporter/mock/server.go +++ b/util/topsql/reporter/mock/server.go @@ -18,7 +18,7 @@ type mockAgentServer struct { sync.Mutex addr string grpcServer *grpc.Server - sqlMetas map[string]string + sqlMetas map[string]tipb.SQLMeta planMetas map[string]string records [][]*tipb.CPUTimeRecord hang struct { @@ -38,7 +38,7 @@ func StartMockAgentServer() (*mockAgentServer, error) { agentServer := &mockAgentServer{ addr: fmt.Sprintf("127.0.0.1:%d", lis.Addr().(*net.TCPAddr).Port), grpcServer: server, - sqlMetas: make(map[string]string, 5000), + sqlMetas: make(map[string]tipb.SQLMeta, 5000), planMetas: make(map[string]string, 5000), } agentServer.hang.beginTime.Store(time.Now()) @@ -99,7 +99,7 @@ func (svr *mockAgentServer) ReportSQLMeta(stream tipb.TopSQLAgent_ReportSQLMetaS return err } svr.Lock() - svr.sqlMetas[string(req.SqlDigest)] = req.NormalizedSql + svr.sqlMetas[string(req.SqlDigest)] = *req svr.Unlock() } return stream.SendAndClose(&tipb.EmptyResponse{}) @@ -140,14 +140,21 @@ func (svr *mockAgentServer) WaitCollectCnt(cnt int, timeout time.Duration) { } } +func (svr *mockAgentServer) GetSQLMetaByDigest(digest []byte) (tipb.SQLMeta, bool) { + svr.Lock() + sqlMeta, exist := svr.sqlMetas[string(digest)] + svr.Unlock() + return sqlMeta, exist +} + func (svr *mockAgentServer) GetSQLMetaByDigestBlocking(digest []byte, timeout time.Duration) (normalizedSQL string, exist bool) { start := time.Now() for { svr.Lock() - normalizedSQL, exist = svr.sqlMetas[string(digest)] + sqlMeta, exist := svr.sqlMetas[string(digest)] svr.Unlock() if exist || time.Since(start) > timeout { - return normalizedSQL, exist + return sqlMeta.NormalizedSql, exist } time.Sleep(time.Millisecond) } diff --git a/util/topsql/reporter/reporter.go b/util/topsql/reporter/reporter.go index 180ad32bab9ce..03e4fb673f27d 100644 --- a/util/topsql/reporter/reporter.go +++ b/util/topsql/reporter/reporter.go @@ -46,7 +46,7 @@ var _ TopSQLReporter = &RemoteTopSQLReporter{} // TopSQLReporter collects Top SQL metrics. type TopSQLReporter interface { tracecpu.Collector - RegisterSQL(sqlDigest []byte, normalizedSQL string) + RegisterSQL(sqlDigest []byte, normalizedSQL string, isInternal bool) RegisterPlan(planDigest []byte, normalizedPlan string) Close() } @@ -120,7 +120,7 @@ type RemoteTopSQLReporter struct { cancel context.CancelFunc client ReportClient - // normalizedSQLMap is an map, whose keys are SQL digest strings and values are normalized SQL strings + // normalizedSQLMap is an map, whose keys are SQL digest strings and values are SQLMeta. normalizedSQLMap atomic.Value // sync.Map sqlMapLength atomic2.Int64 @@ -133,6 +133,12 @@ type RemoteTopSQLReporter struct { reportCollectedDataChan chan collectedData } +// SQLMeta is the SQL meta which contains the normalized SQL string and a bool field which uses to distinguish internal SQL. +type SQLMeta struct { + normalizedSQL string + isInternal bool +} + // NewRemoteTopSQLReporter creates a new TopSQL reporter // // planBinaryDecoder is a decoding function which will be called asynchronously to decode the plan binary to string @@ -179,14 +185,17 @@ var ( // Note that the normalized SQL string can be of >1M long. // This function should be thread-safe, which means parallelly calling it in several goroutines should be fine. // It should also return immediately, and do any CPU-intensive job asynchronously. -func (tsr *RemoteTopSQLReporter) RegisterSQL(sqlDigest []byte, normalizedSQL string) { +func (tsr *RemoteTopSQLReporter) RegisterSQL(sqlDigest []byte, normalizedSQL string, isInternal bool) { if tsr.sqlMapLength.Load() >= variable.TopSQLVariable.MaxCollect.Load() { ignoreExceedSQLCounter.Inc() return } m := tsr.normalizedSQLMap.Load().(*sync.Map) key := string(sqlDigest) - _, loaded := m.LoadOrStore(key, normalizedSQL) + _, loaded := m.LoadOrStore(key, SQLMeta{ + normalizedSQL: normalizedSQL, + isInternal: isInternal, + }) if !loaded { tsr.sqlMapLength.Add(1) } diff --git a/util/topsql/reporter/reporter_test.go b/util/topsql/reporter/reporter_test.go index 4146ca3e5767c..fbbd9caa8911e 100644 --- a/util/topsql/reporter/reporter_test.go +++ b/util/topsql/reporter/reporter_test.go @@ -48,7 +48,7 @@ func populateCache(tsr *RemoteTopSQLReporter, begin, end int, timestamp uint64) for i := begin; i < end; i++ { key := []byte("sqlDigest" + strconv.Itoa(i+1)) value := "sqlNormalized" + strconv.Itoa(i+1) - tsr.RegisterSQL(key, value) + tsr.RegisterSQL(key, value, false) } // register normalized plan for i := begin; i < end; i++ { @@ -177,7 +177,7 @@ func (s *testTopSQLReporter) TestCollectAndEvicted(c *C) { func (s *testTopSQLReporter) newSQLCPUTimeRecord(tsr *RemoteTopSQLReporter, sqlID int, cpuTimeMs uint32) tracecpu.SQLCPUTimeRecord { key := []byte("sqlDigest" + strconv.Itoa(sqlID)) value := "sqlNormalized" + strconv.Itoa(sqlID) - tsr.RegisterSQL(key, value) + tsr.RegisterSQL(key, value, sqlID%2 == 0) key = []byte("planDigest" + strconv.Itoa(sqlID)) value = "planNormalized" + strconv.Itoa(sqlID) @@ -268,7 +268,7 @@ func (s *testTopSQLReporter) TestCollectCapacity(c *C) { for i := 0; i < n; i++ { key := []byte("sqlDigest" + strconv.Itoa(i)) value := "sqlNormalized" + strconv.Itoa(i) - tsr.RegisterSQL(key, value) + tsr.RegisterSQL(key, value, false) } } registerPlan := func(n int) { @@ -397,6 +397,43 @@ func (s *testTopSQLReporter) TestDataPoints(c *C) { c.Assert(d.TimestampList, IsNil) } +func (s *testTopSQLReporter) TestCollectInternal(c *C) { + agentServer, err := mock.StartMockAgentServer() + c.Assert(err, IsNil) + defer agentServer.Stop() + + tsr := setupRemoteTopSQLReporter(3000, 1, agentServer.Address()) + defer tsr.Close() + + records := []tracecpu.SQLCPUTimeRecord{ + s.newSQLCPUTimeRecord(tsr, 1, 1), + s.newSQLCPUTimeRecord(tsr, 2, 2), + } + s.collectAndWait(tsr, 1, records) + + // Wait agent server collect finish. + agentServer.WaitCollectCnt(1, time.Second*10) + + // check for equality of server received batch and the original data + results := agentServer.GetLatestRecords() + c.Assert(results, HasLen, 2) + for _, req := range results { + id := 0 + prefix := "sqlDigest" + if strings.HasPrefix(string(req.SqlDigest), prefix) { + n, err := strconv.Atoi(string(req.SqlDigest)[len(prefix):]) + c.Assert(err, IsNil) + id = n + } + if id == 0 { + c.Fatalf("the id should not be 0") + } + sqlMeta, exist := agentServer.GetSQLMetaByDigest(req.SqlDigest) + c.Assert(exist, IsTrue) + c.Assert(sqlMeta.IsInternalSql, Equals, id%2 == 0) + } +} + func BenchmarkTopSQL_CollectAndIncrementFrequency(b *testing.B) { tsr := initializeCache(maxSQLNum, 120, ":23333") for i := 0; i < b.N; i++ { diff --git a/util/topsql/topsql.go b/util/topsql/topsql.go index e52bf978c9d31..c7f1ec0705872 100644 --- a/util/topsql/topsql.go +++ b/util/topsql/topsql.go @@ -53,7 +53,7 @@ func Close() { } // AttachSQLInfo attach the sql information info top sql. -func AttachSQLInfo(ctx context.Context, normalizedSQL string, sqlDigest *parser.Digest, normalizedPlan string, planDigest *parser.Digest) context.Context { +func AttachSQLInfo(ctx context.Context, normalizedSQL string, sqlDigest *parser.Digest, normalizedPlan string, planDigest *parser.Digest, isInternal bool) context.Context { if len(normalizedSQL) == 0 || sqlDigest == nil || len(sqlDigest.Bytes()) == 0 { return ctx } @@ -67,7 +67,7 @@ func AttachSQLInfo(ctx context.Context, normalizedSQL string, sqlDigest *parser. if len(normalizedPlan) == 0 || len(planDigestBytes) == 0 { // If plan digest is '', indicate it is the first time to attach the SQL info, since it only know the sql digest. - linkSQLTextWithDigest(sqlDigestBytes, normalizedSQL) + linkSQLTextWithDigest(sqlDigestBytes, normalizedSQL, isInternal) } else { linkPlanTextWithDigest(planDigestBytes, normalizedPlan) } @@ -105,7 +105,7 @@ func AttachSQLInfo(ctx context.Context, normalizedSQL string, sqlDigest *parser. return ctx } -func linkSQLTextWithDigest(sqlDigest []byte, normalizedSQL string) { +func linkSQLTextWithDigest(sqlDigest []byte, normalizedSQL string, isInternal bool) { if len(normalizedSQL) > MaxSQLTextSize { normalizedSQL = normalizedSQL[:MaxSQLTextSize] } @@ -116,7 +116,7 @@ func linkSQLTextWithDigest(sqlDigest []byte, normalizedSQL string) { } topc, ok := c.(reporter.TopSQLReporter) if ok { - topc.RegisterSQL(sqlDigest, normalizedSQL) + topc.RegisterSQL(sqlDigest, normalizedSQL, isInternal) } } diff --git a/util/topsql/topsql_test.go b/util/topsql/topsql_test.go index 88748e110cff9..3dd24cb397eb6 100644 --- a/util/topsql/topsql_test.go +++ b/util/topsql/topsql_test.go @@ -198,10 +198,10 @@ func (s *testSuite) TestMaxSQLAndPlanTest(c *C) { // Test for normal sql and plan sql := "select * from t" sqlDigest := mock.GenSQLDigest(sql) - topsql.AttachSQLInfo(ctx, sql, sqlDigest, "", nil) + topsql.AttachSQLInfo(ctx, sql, sqlDigest, "", nil, false) plan := "TableReader table:t" planDigest := genDigest(plan) - topsql.AttachSQLInfo(ctx, sql, sqlDigest, plan, planDigest) + topsql.AttachSQLInfo(ctx, sql, sqlDigest, plan, planDigest, false) cSQL := collector.GetSQL(sqlDigest.Bytes()) c.Assert(cSQL, Equals, sql) @@ -211,10 +211,10 @@ func (s *testSuite) TestMaxSQLAndPlanTest(c *C) { // Test for huge sql and plan sql = genStr(topsql.MaxSQLTextSize + 10) sqlDigest = mock.GenSQLDigest(sql) - topsql.AttachSQLInfo(ctx, sql, sqlDigest, "", nil) + topsql.AttachSQLInfo(ctx, sql, sqlDigest, "", nil, false) plan = genStr(topsql.MaxPlanTextSize + 10) planDigest = genDigest(plan) - topsql.AttachSQLInfo(ctx, sql, sqlDigest, plan, planDigest) + topsql.AttachSQLInfo(ctx, sql, sqlDigest, plan, planDigest, false) cSQL = collector.GetSQL(sqlDigest.Bytes()) c.Assert(cSQL, Equals, sql[:topsql.MaxSQLTextSize]) @@ -229,10 +229,10 @@ func (s *testSuite) setTopSQLEnable(enabled bool) { func (s *testSuite) mockExecuteSQL(sql, plan string) { ctx := context.Background() sqlDigest := mock.GenSQLDigest(sql) - topsql.AttachSQLInfo(ctx, sql, sqlDigest, "", nil) + topsql.AttachSQLInfo(ctx, sql, sqlDigest, "", nil, false) s.mockExecute(time.Millisecond * 100) planDigest := genDigest(plan) - topsql.AttachSQLInfo(ctx, sql, sqlDigest, plan, planDigest) + topsql.AttachSQLInfo(ctx, sql, sqlDigest, plan, planDigest, false) s.mockExecute(time.Millisecond * 300) } diff --git a/util/topsql/tracecpu/mock/mock.go b/util/topsql/tracecpu/mock/mock.go index b45ce9060be21..c0546a5528b1f 100644 --- a/util/topsql/tracecpu/mock/mock.go +++ b/util/topsql/tracecpu/mock/mock.go @@ -127,7 +127,7 @@ func (c *TopSQLCollector) GetPlan(planDigest []byte) string { } // RegisterSQL uses for testing. -func (c *TopSQLCollector) RegisterSQL(sqlDigest []byte, normalizedSQL string) { +func (c *TopSQLCollector) RegisterSQL(sqlDigest []byte, normalizedSQL string, isInternal bool) { digestStr := string(hack.String(sqlDigest)) c.Lock() _, ok := c.sqlMap[digestStr]