diff --git a/cmd/diag/command/collect.go b/cmd/diag/command/collect.go
index a5cbffd6..0360fe42 100644
--- a/cmd/diag/command/collect.go
+++ b/cmd/diag/command/collect.go
@@ -157,6 +157,9 @@ func newCollectCmd() *cobra.Command {
 	cmd.Flags().StringSliceVar(&ext, "exclude", nil, "types of data not to collect")
 	cmd.Flags().StringSliceVar(&cOpt.MetricsFilter, "metricsfilter", nil, "prefix of metrics to collect")
 	cmd.Flags().StringSliceVar(&cOpt.MetricsExclude, "metricsexclude", []string{"node_interrupts_total"}, "prefix of metrics to exclude")
+	cmd.Flags().StringSliceVar(&cOpt.MetricsLowPriority, "metrics-low-priority", []string{"tidb_tikvclient_request_seconds_bucket"},
+		"prefix of metrics to collect with low priority")
+	cmd.Flags().IntVar(&cOpt.MetricsMinInterval, "metrics-min-interval", 60, "the minimum interval of a single request in seconds")
 	cmd.Flags().IntVar(&cOpt.MetricsLimit, "metricslimit", 10000, "metric size limit of single request, specified in series*hour per request")
 	cmd.Flags().StringVar(&metricsConf, "metricsconfig", "", "config file of metricsfilter")
 	cmd.Flags().StringSliceVar(&labels, "metricslabel", nil, "only collect metrics that match labels")
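Both new flags are prefix filters, matching the existing metricsfilter/metricsexclude semantics. Below is a minimal sketch of what prefix matching means for --metrics-low-priority; the helper is an illustrative stand-in for the repo's utils.MatchPrefixs, not its actual implementation.

package main

import (
	"fmt"
	"strings"
)

// matchPrefixes reports whether name starts with any of the given prefixes,
// the way the collector decides which metrics are deferred as low priority.
func matchPrefixes(name string, prefixes []string) bool {
	for _, p := range prefixes {
		if strings.HasPrefix(name, p) {
			return true
		}
	}
	return false
}

func main() {
	lowPriority := []string{"tidb_tikvclient_request_seconds_bucket"}
	fmt.Println(matchPrefixes("tidb_tikvclient_request_seconds_bucket", lowPriority)) // true
	fmt.Println(matchPrefixes("tidb_server_query_total", lowPriority))                // false
}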
diff --git a/collector/collect.go b/collector/collect.go
index d4c07cd9..d13c88f6 100644
--- a/collector/collect.go
+++ b/collector/collect.go
@@ -108,28 +108,30 @@ type BaseOptions struct {
 
 // CollectOptions contains the options defining which type of data to collect
 type CollectOptions struct {
-	RawRequest      interface{}       // raw collect command or request
-	Mode            string            // the cluster is deployed with what type of tool
-	DiagMode        string            // run diag collect at command line mode or server mode
-	ProfileName     string            // the name of a pre-defined collecting profile
-	Collectors      CollectTree       // struct to show which collector is enabled
-	MetricsFilter   []string          // prefix of metrics to collect
-	MetricsExclude  []string          //prefix of metrics to exclude
-	MetricsLabel    map[string]string // label to filte metrics
-	Dir             string            // target directory to store collected data
-	Limit           int               // rate limit of SCP
-	MetricsLimit    int               // query limit of one request
-	PerfDuration    int               //seconds: profile time(s), default is 30s.
-	CompressScp     bool              // compress of files during collecting
-	CompressMetrics bool              // compress of files during collecting
-	RawMonitor      bool              // collect raw data for metrics
-	ExitOnError     bool              // break the process and exit when an error occur
-	ExtendedAttrs   map[string]string // extended attributes used for manual collecting mode
-	ExplainSQLPath  string            // File path for explain sql
-	ExplainSqls     []string          // explain sqls
-	CurrDB          string
-	Header          []string
-	UsePortForward  bool // use portforward when call api inside k8s cluster
+	RawRequest         interface{}       // raw collect command or request
+	Mode               string            // the cluster is deployed with what type of tool
+	DiagMode           string            // run diag collect at command line mode or server mode
+	ProfileName        string            // the name of a pre-defined collecting profile
+	Collectors         CollectTree       // struct to show which collector is enabled
+	MetricsFilter      []string          // prefix of metrics to collect
+	MetricsExclude     []string          // prefix of metrics to exclude
+	MetricsLowPriority []string          // prefix of metrics to collect with low priority
+	MetricsLabel       map[string]string // label to filter metrics
+	Dir                string            // target directory to store collected data
+	Limit              int               // rate limit of SCP
+	MetricsLimit       int               // query limit of one request
+	MetricsMinInterval int               // minimum query interval of a single request in seconds, default is 60 (1min)
+	PerfDuration       int               // seconds: profile time(s), default is 30s.
+	CompressScp        bool              // compress of files during collecting
+	CompressMetrics    bool              // compress of files during collecting
+	RawMonitor         bool              // collect raw data for metrics
+	ExitOnError        bool              // break the process and exit when an error occurs
+	ExtendedAttrs      map[string]string // extended attributes used for manual collecting mode
+	ExplainSQLPath     string            // File path for explain sql
+	ExplainSqls        []string          // explain sqls
+	CurrDB             string
+	Header             []string
+	UsePortForward     bool              // use portforward when calling APIs inside a k8s cluster
 }
 
 // CollectStat is estimated size stats of data to be collected
@@ -301,7 +303,9 @@ func (m *Manager) CollectClusterInfo(
 			label:        cOpt.MetricsLabel,
 			filter:       cOpt.MetricsFilter,
 			exclude:      cOpt.MetricsExclude,
+			lowPriority:  cOpt.MetricsLowPriority,
 			limit:        cOpt.MetricsLimit,
+			minInterval:  cOpt.MetricsMinInterval,
 			compress:     cOpt.CompressMetrics,
 			customHeader: cOpt.Header,
 			portForward:  cOpt.UsePortForward,
@@ -537,8 +541,9 @@ func (m *Manager) CollectClusterInfo(
 	// run collectors
 	collectErrs := make(map[string]error)
 	for _, c := range collectors {
-		fmt.Printf("Collecting %s...\n", c.Desc())
-		m.logger.Infof("Collecting %s...\n", c.Desc())
+		timeNow := time.Now()
+		fmt.Printf("Collecting %s..., time:%v\n", c.Desc(), timeNow)
+		m.logger.Infof("Collecting %s..., time:%v\n", c.Desc(), timeNow)
 		if err := c.Collect(m, cls); err != nil {
 			if cOpt.ExitOnError {
 				return "", err
@@ -569,7 +574,7 @@ func (m *Manager) CollectClusterInfo(
 	}
 	logStr := fmt.Sprintf("The collected data has been stored in %s. For more details, please refer to the log at %s/diag.log.", dir, dir)
 	fmt.Println(logStr)
-	m.logger.Infof(logStr)
+	m.logger.Infof("%s", logStr)
 
 	return resultDir, nil
 }
diff --git a/collector/prom2influx.go b/collector/prom2influx.go
index c07cf9ba..f12b82df 100644
--- a/collector/prom2influx.go
+++ b/collector/prom2influx.go
@@ -269,7 +269,7 @@ func buildPoints(
 func writeBatchPoints(client influx.Client, data promDump, opts *RebuildOptions) error {
 	// build and write points
 	var errr error
-	tl := utils.NewTokenLimiter(uint(opts.Concurrency))
+	tl := utils.NewTokenLimiter(opts.Concurrency)
 	wg := sync.WaitGroup{}
 	for _, series := range data.Data.Result {
 		ptChan := buildPoints(series, opts)
diff --git a/collector/prometheus.go b/collector/prometheus.go
index 08cf20fd..d971af13 100644
--- a/collector/prometheus.go
+++ b/collector/prometheus.go
@@ -18,6 +18,7 @@ import (
 	"fmt"
 	"io"
 	"maps"
+	"net"
 	"net/http"
 	"net/url"
 	"os"
@@ -49,12 +50,13 @@ import (
 )
 
 const (
-	subdirMonitor = "monitor"
-	subdirAlerts  = "alerts"
-	subdirMetrics = "metrics"
-	subdirRaw     = "raw"
-	maxQueryRange = 120 * 60 // 120min
-	minQueryRange = 5 * 60   // 5min
+	subdirMonitor   = "monitor"
+	subdirAlerts    = "alerts"
+	subdirMetrics   = "metrics"
+	subdirRaw       = "raw"
+	maxQueryRange   = 120 * 60 // 120min
+	smallQueryRange = 15       // 15s
+	logQuerySeries  = 120000   // equals 3600*speedLimit/300(s) with the default speedLimit of 10000, i.e. the series count at which the query window shrinks to 300s
 )
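These constants bound the query-splitting arithmetic used later in collectSingleMetric: one request's time window is block = 3600*speedLimit/series seconds, clamped between the minimum interval and maxQueryRange, and 120000 series is exactly where the window drops to 300s at the default limit. A standalone sketch of that rule (the function name is mine, not part of the patch):

package main

import "fmt"

const maxQueryRange = 120 * 60 // 120min, as in prometheus.go

// blockSeconds mirrors the splitting rule: one request may carry roughly
// speedLimit series*hours, so its time span shrinks as the series count
// grows, clamped to sane bounds.
func blockSeconds(speedLimit, series, minInterval int) int {
	if speedLimit == 0 {
		speedLimit = 10000
	}
	block := 3600 * speedLimit / series
	if block > maxQueryRange {
		block = maxQueryRange
	}
	if block < minInterval {
		block = minInterval
	}
	return block
}

func main() {
	// 120000 series at the default limit is exactly the 300s window
	// that logQuerySeries marks as worth logging.
	fmt.Println(blockSeconds(10000, 120000, 60)) // 300
	// A small series count hits the 120min ceiling.
	fmt.Println(blockSeconds(10000, 5000, 60)) // 7200
}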
 
 type collectMonitor struct {
@@ -176,7 +178,9 @@ type MetricCollectOptions struct {
 	metrics      []string // metric list
 	filter       []string
 	exclude      []string
+	lowPriority  []string
 	limit        int // series*min per query
+	minInterval  int // the minimum interval of a single request in seconds
 	compress     bool
 	customHeader []string
 	endpoint     string
@@ -281,10 +285,9 @@ func (c *MetricCollectOptions) Collect(m *Manager, topo *models.TiDBCluster) err
 	if c.endpoint == "" {
 		return nil
 	}
+	startTime := time.Now()
 	mb := progress.NewMultiBar("+ Dumping metrics")
 	bars := make(map[string]*progress.MultiBarItem)
-	total := len(c.metrics)
-	mu := sync.Mutex{}
 	key := c.endpoint
 	if _, ok := bars[key]; !ok {
@@ -301,9 +304,13 @@ func (c *MetricCollectOptions) Collect(m *Manager, topo *models.TiDBCluster) err
 	if cpuCnt < qLimit {
 		qLimit = cpuCnt
 	}
-	tl := utils.NewTokenLimiter(uint(qLimit))
+	// Prometheus' default query.max-concurrency is 20, so cap qLimit at 20.
+	defaultQueryMaxConcurrency := 20
+	if qLimit > defaultQueryMaxConcurrency {
+		qLimit = defaultQueryMaxConcurrency
+	}
+	tl := utils.NewTokenLimiter(qLimit)
 
-	done := 1
 	if err := ensureMonitorDir(c.resultDir, subdirMetrics, strings.ReplaceAll(c.endpoint, ":", "-")); err != nil {
 		bars[key].UpdateDisplay(&progress.DisplayProps{
 			Prefix: fmt.Sprintf("  - Query server %s: %s", key, err),
@@ -312,17 +319,85 @@ func (c *MetricCollectOptions) Collect(m *Manager, topo *models.TiDBCluster) err
 		return err
 	}
 
-	client := &http.Client{Timeout: time.Second * time.Duration(c.opt.APITimeout)}
-
-	for _, mtc := range c.metrics {
+	client := &http.Client{
+		Transport: &http.Transport{
+			MaxIdleConns:          defaultQueryMaxConcurrency,
+			MaxIdleConnsPerHost:   10,
+			IdleConnTimeout:       30 * time.Second,
+			ExpectContinueTimeout: 1 * time.Second,
+			DialContext: (&net.Dialer{
+				Timeout:   5 * time.Second,
+				KeepAlive: 30 * time.Second,
+			}).DialContext,
+		},
+		Timeout: time.Second * time.Duration(c.opt.APITimeout),
+	}
+
+	if len(c.lowPriority) == 0 {
+		c.collectMetrics(m.logger, client, c.metrics, midPriority, tl, bars)
+		m.logger.Infof("Dumping metrics finished .......................................... token limit:%d, took:%v",
+			qLimit, time.Since(startTime))
+	} else {
+		c.collectMetrics(m.logger, client, c.metrics, highPriority, tl, bars)
+		m.logger.Infof("Dumping high priority metrics finished .......................................... token limit:%d, concurrency:%d, took:%v",
+			qLimit, c.opt.Concurrency, time.Since(startTime))
+		startTime = time.Now()
+		c.collectMetrics(m.logger, client, c.lowPriority, lowPriority, tl, bars)
+		m.logger.Infof("Dumping low priority metrics finished .......................................... token limit:%d, took:%v",
+			qLimit, time.Since(startTime))
+	}
+
+	return nil
+}
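Worth noting the shape of the clamp above: the effective query concurrency is the minimum of the requested value, the CPU count, and Prometheus' default query.max-concurrency of 20. A small sketch; the assumption that qLimit starts from the user-requested concurrency is mine, since its initialization is not shown in the hunk.

package main

import "fmt"

// effectiveQueryLimit mirrors the clamping in Collect: the requested
// concurrency is capped by CPU count and by Prometheus' default
// query.max-concurrency of 20.
func effectiveQueryLimit(requested, cpuCnt int) int {
	qLimit := requested
	if cpuCnt < qLimit {
		qLimit = cpuCnt
	}
	if qLimit > 20 {
		qLimit = 20
	}
	return qLimit
}

func main() {
	fmt.Println(effectiveQueryLimit(64, 32)) // 20
	fmt.Println(effectiveQueryLimit(8, 4))   // 4
}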
+
+const (
+	lowPriority  = -1
+	midPriority  = 0
+	highPriority = 1
+)
+
+func (c *MetricCollectOptions) collectMetrics(
+	l *logprinter.Logger,
+	client *http.Client,
+	metrics []string,
+	priority int,
+	tl *utils.TokenLimiter,
+	bars map[string]*progress.MultiBarItem,
+) {
+	done := 1
+	key := c.endpoint
+	mu := sync.Mutex{}
+	minInterval := c.minInterval
+	if minInterval < smallQueryRange {
+		minInterval = smallQueryRange
+	}
+	concurrency := 1
+	total := len(c.metrics)
+	if priority == highPriority {
+		total = len(c.metrics) - len(c.lowPriority)
+	} else if priority == lowPriority {
+		total = len(c.lowPriority)
+		concurrency = tl.Cap()
+	}
+	tsEnd, _ := utils.ParseTime(c.GetBaseOptions().ScrapeEnd)
+	tsStart, _ := utils.ParseTime(c.GetBaseOptions().ScrapeBegin)
+	originInfo := queryRangeInfo{
+		queryBegin:  tsStart,
+		queryEnd:    tsEnd,
+		intervalSec: minInterval,
+	}
+	for _, mtc := range metrics {
+		if priority == highPriority && utils.MatchPrefixs(mtc, c.lowPriority) {
+			continue
+		}
+
 		go func(tok *utils.Token, mtc string) {
 			bars[key].UpdateDisplay(&progress.DisplayProps{
 				Prefix: fmt.Sprintf("  - Querying server %s", key),
 				Suffix: fmt.Sprintf("%d/%d querying %s ...", done, total, mtc),
 			})
-			tsEnd, _ := utils.ParseTime(c.GetBaseOptions().ScrapeEnd)
-			tsStart, _ := utils.ParseTime(c.GetBaseOptions().ScrapeBegin)
-			collectMetric(m.logger, client, key, tsStart, tsEnd, mtc, c.label, c.resultDir, c.limit, c.compress, c.customHeader, "")
+			collectSingleMetric(l, client, key, originInfo, concurrency, mtc, c.label, c.resultDir, c.limit, c.compress, c.customHeader, "", tok.ID, tl)
 
 			mu.Lock()
 			done++
@@ -337,10 +412,7 @@ func (c *MetricCollectOptions) Collect(m *Manager, topo *models.TiDBCluster) err
 			tl.Put(tok)
 		}(tl.Get(), mtc)
 	}
-
 	tl.Wait()
-
-	return nil
 }
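collectMetrics keeps the pre-existing fan-out shape: one goroutine per metric, gated by the token limiter, with tl.Wait() as the final barrier. A self-contained sketch of that pattern, using a plain buffered channel in place of utils.TokenLimiter:

package main

import (
	"fmt"
	"runtime"
	"sync"
)

// Token-limited fan-out in the style of collectMetrics: receiving blocks
// until a token is free (Get), sending returns it (Put), and the drain
// loop plays the role of tl.Wait().
func main() {
	const limit = 3
	tokens := make(chan int, limit)
	for i := 0; i < limit; i++ {
		tokens <- i
	}

	var mu sync.Mutex
	done := 0
	metrics := []string{"a", "b", "c", "d", "e"}
	for _, m := range metrics {
		go func(tok int, m string) {
			// ... query Prometheus for metric m here ...
			mu.Lock()
			done++
			mu.Unlock()
			fmt.Printf("token %d finished %s\n", tok, m)
			tokens <- tok // Put: return the token
		}(<-tokens, m) // Get: blocks until a token is free
	}

	// Wait: all tokens back in the channel means all work is done.
	for len(tokens) < limit {
		runtime.Gosched()
	}
	mu.Lock()
	fmt.Println("collected", done, "metrics")
	mu.Unlock()
}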
 
 func getMetricList(c *http.Client, addr string, customHeader []string) ([]string, error) {
@@ -398,24 +470,28 @@ func makeURL(addr string, path string, queries map[string]string) string {
 	return link + "?" + vals.Encode()
 }
 
-func collectMetric(
+func collectSingleMetric(
 	l *logprinter.Logger,
 	c *http.Client,
 	promAddr string,
-	beginTime, endTime time.Time,
+	originInfo queryRangeInfo,
+	concurrency int,
 	mtc string,
 	label map[string]string,
 	resultDir string,
-	speedlimit int,
+	speedLimit int,
 	compress bool,
 	customHeader []string,
 	instance string,
+	curTokenID int,
+	tl *utils.TokenLimiter,
 ) {
 	nameSuffix := ""
 	if len(instance) > 0 {
 		nameSuffix = "." + strings.ReplaceAll(instance, ":", "-")
 	}
 	query := generateQueryWitLabel(mtc, label)
+	beginTime, endTime := originInfo.queryBegin, originInfo.queryEnd
 	queries := map[string]string{
 		"match[]": query,
 		"start":   beginTime.Format(time.RFC3339),
@@ -466,7 +542,7 @@ func collectMetric(
 			newLabel := make(map[string]string)
 			maps.Copy(newLabel, label)
 			newLabel["instance"] = instance
-			collectMetric(l, c, promAddr, beginTime, endTime, mtc, newLabel, resultDir, speedlimit, compress, customHeader, instance)
+			collectSingleMetric(l, c, promAddr, originInfo, concurrency, mtc, newLabel, resultDir, speedLimit, compress, customHeader, instance, curTokenID, tl)
 		}
 	}
 	return
@@ -478,18 +554,38 @@ func collectMetric(
 	}
 
 	// split time into smaller ranges to avoid querying too many data in one request
-	if speedlimit == 0 {
-		speedlimit = 10000
+	if speedLimit == 0 {
+		speedLimit = 10000
 	}
-	block := 3600 * speedlimit / series
+	block := 3600 * speedLimit / series
 	if block > maxQueryRange {
 		block = maxQueryRange
 	}
-	if block < minQueryRange {
-		block = minQueryRange
+	if block < originInfo.intervalSec {
+		block = originInfo.intervalSec
 	}
 
-	l.Debugf("Dumping metric %s-%s-%s%s...", mtc, beginTime.Format(time.RFC3339), endTime.Format(time.RFC3339), nameSuffix)
+	if block == originInfo.intervalSec || series >= logQuerySeries {
+		l.Infof("Collecting single metric %s: series count %d is large and the query interval is %ds, concurrency:%d, speedLimit:%d, req timeout:%v ...",
+			mtc+nameSuffix, series, block, tl.Cap(), speedLimit, c.Timeout)
+	}
+	retryOption := tiuputils.RetryOption{
+		Attempts: 3,
+		Delay:    time.Microsecond * 300,
+		Timeout:  c.Timeout*3 + 5*time.Second, // make sure the retry timeout is longer than the API timeout
+	}
+	goCnt := 0
+	taskCnt := 0
+	qInfo := queryInfo{
+		query:        query,
+		promAddr:     promAddr,
+		customHeader: customHeader,
+		compress:     compress,
+		retryOption:  retryOption,
+	}
+	queryInfoCh := make(chan queryInfo, concurrency)
+	wg := WaitGroupWrapper{}
+	startTime := time.Now()
 	for queryEnd := endTime; queryEnd.After(beginTime); queryEnd = queryEnd.Add(time.Duration(-block) * time.Second) {
 		querySec := block
 		queryBegin := queryEnd.Add(time.Duration(-block) * time.Second)
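The hunk that follows rewrites the body of this per-window loop, which walks backwards from the scrape end and clamps the final partial window at the scrape begin. A standalone sketch of just the traversal:

package main

import (
	"fmt"
	"time"
)

// chunk walks backwards from end to begin in block-sized steps, the same
// traversal collectSingleMetric uses to split one metric into many
// bounded Prometheus queries.
func chunk(begin, end time.Time, block int) []string {
	var ranges []string
	for queryEnd := end; queryEnd.After(begin); queryEnd = queryEnd.Add(time.Duration(-block) * time.Second) {
		queryBegin := queryEnd.Add(time.Duration(-block) * time.Second)
		if queryBegin.Before(begin) {
			queryBegin = begin // clamp the final partial window
		}
		ranges = append(ranges, fmt.Sprintf("%s..%s", queryBegin.Format("15:04"), queryEnd.Format("15:04")))
	}
	return ranges
}

func main() {
	end := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC)
	begin := end.Add(-25 * time.Minute)
	// 600s blocks over 25min: two full windows plus one 5min remainder.
	fmt.Println(chunk(begin, end, 600)) // [11:50..12:00 11:40..11:50 11:35..11:40]
}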
@@ -497,71 +593,165 @@ func collectMetric(
 			querySec = int(queryEnd.Sub(beginTime).Seconds())
 			queryBegin = beginTime
 		}
-		if err := tiuputils.Retry(
-			func() error {
-				req, err := http.NewRequest(
-					http.MethodGet,
-					fmt.Sprintf("http://%s/api/v1/query?%s", promAddr, url.Values{
-						"query": {fmt.Sprintf("%s[%ds]", query, querySec)},
-						"time":  {queryEnd.Format(time.RFC3339)},
-					}.Encode()),
-					nil)
-				if err != nil {
-					return err
-				}
-				utils.AddHeaders(req.Header, customHeader)
-				resp, err := c.Do(req)
-				if err != nil {
-					l.Errorf("failed query metric %s: %s, retry...", mtc+nameSuffix, err)
-					return err
-				}
-				// Prometheus API response format is JSON. Every successful API request returns a 2xx status code.
-				if resp.StatusCode/100 != 2 {
-					l.Errorf("failed query metric %s: Status Code %d, retry...", mtc+nameSuffix, resp.StatusCode)
-				}
-				defer resp.Body.Close()
+		startTime0 := time.Now()
 
-				dst, err := os.Create(
-					filepath.Join(
-						resultDir, subdirMonitor, subdirMetrics, strings.ReplaceAll(promAddr, ":", "-"),
-						fmt.Sprintf("%s-%s-%s%s.json", mtc, queryBegin.Format(time.RFC3339), queryEnd.Format(time.RFC3339), nameSuffix),
-					),
-				)
-				if err != nil {
-					l.Errorf("collect metric %s: %s, retry...", mtc+nameSuffix, err)
+		qInfo.queryRangeInfo = queryRangeInfo{
+			queryBegin:  queryBegin,
+			queryEnd:    queryEnd,
+			intervalSec: querySec,
+		}
+		logInfo := ""
+		if concurrency == 1 {
+			if err := collectSingleQuery(l, c, curTokenID, resultDir, mtc, nameSuffix, qInfo); err != nil {
+				l.Errorf("Error querying metrics %s: %s... timeout:%v, took:%v",
+					mtc+nameSuffix, err, c.Timeout*3+5*time.Second, time.Since(startTime0))
+			}
+		} else {
+			queryInfoCh <- qInfo
+			if goCnt == 0 {
+				logInfo = fmt.Sprintf(" with a new goroutine ID:%v", curTokenID)
+				wg.RunWithRecover(func() { collectQueries(l, c, curTokenID, resultDir, mtc, nameSuffix, queryInfoCh) }, nil)
+				goCnt++
+			} else if goCnt < concurrency {
+				token := tl.TryGet()
+				if token != nil {
+					logInfo = fmt.Sprintf(" with a new goroutine ID:%v", token.ID)
+					wg.RunWithRecover(func() {
+						collectQueries(l, c, token.ID, resultDir, mtc, nameSuffix, queryInfoCh)
+						tl.Put(token)
+					}, nil)
+					goCnt++
 				}
-				defer dst.Close()
+			}
+			l.Infof("Collecting single metric %s%s, goroutines:%d, interval:%ds, put task no.%d range[%v:%v] to chan ...",
+				mtc+nameSuffix, logInfo, goCnt, qInfo.intervalSec, taskCnt, queryBegin.Format(time.RFC3339), queryEnd.Format(time.RFC3339))
+		}
+		taskCnt++
+	}
+	if concurrency == 1 {
+		return
+	}
 
-				var enc io.WriteCloser
-				var n int64
-				if compress {
-					// compress the metric
-					enc, err = zstd.NewWriter(dst)
-					if err != nil {
-						l.Errorf("failed compressing metric %s: %s, retry...\n", mtc+nameSuffix, err)
-						return err
-					}
-					defer enc.Close()
-				} else {
-					enc = dst
-				}
-				n, err = io.Copy(enc, resp.Body)
+	startTime1 := time.Now()
+	for {
+		if len(queryInfoCh) == 0 {
+			close(queryInfoCh)
+			break
+		}
+		if wg.PanicCnt == concurrency {
+			break
+		}
+		time.Sleep(5 * time.Millisecond)
+	}
+	wg.Wait()
+	l.Infof("Collected single metric %s from %s to %s took:%v, goroutines:%d, concurrency:%d, final wait:%v",
+		mtc+nameSuffix, endTime.Format(time.RFC3339), beginTime.Format(time.RFC3339), time.Since(startTime), goCnt, concurrency, time.Since(startTime1))
+}
+
+type queryInfo struct {
+	query    string
+	promAddr string
+	queryRangeInfo
+	customHeader []string
+	compress     bool
+	retryOption  tiuputils.RetryOption
+}
+
+type queryRangeInfo struct {
+	queryBegin  time.Time
+	queryEnd    time.Time
+	intervalSec int
+}
+
+func collectQueries(l *logprinter.Logger, c *http.Client, tokenID int, resultDir, mtc, nameSuffix string,
+	queryInfoCh chan queryInfo) {
+	for {
+		qInfo, ok := <-queryInfoCh
+		if !ok {
+			l.Infof("[ID:%d] collect metric %s goroutine finished", tokenID, mtc+nameSuffix)
+			return
+		}
+
+		startTime0 := time.Now()
+		err := collectSingleQuery(l, c, tokenID, resultDir, mtc, nameSuffix, qInfo)
+		if err != nil {
+			l.Errorf("[ID:%d] failed to collect a query of metric %s after retries: %s... client timeout:%v, took:%v",
+				tokenID, mtc+nameSuffix, err, c.Timeout*3+5*time.Second, time.Since(startTime0))
+		}
+	}
+}
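collectSingleMetric and collectQueries form a bounded producer/consumer: the loop enqueues one queryInfo per window and lazily grows the consumer pool up to `concurrency` via TryGet, then closes the channel so consumers fall out of their receive loop. A simplified sketch of the same pattern; it closes the channel directly after enqueueing instead of polling len() as the patch does:

package main

import (
	"fmt"
	"sync"
)

// Bounded-channel fan-out in the style of collectSingleMetric: the
// producer enqueues tasks and lazily starts consumers, then closes the
// channel so consumers can exit.
func main() {
	tasks := make(chan int, 4) // capacity plays the role of concurrency
	var wg sync.WaitGroup

	consume := func(id int) {
		defer wg.Done()
		for task := range tasks { // exits when the channel is closed
			fmt.Printf("worker %d ran task %d\n", id, task)
		}
	}

	workers := 0
	for task := 0; task < 10; task++ {
		tasks <- task
		if workers < cap(tasks) { // lazily grow the worker pool
			wg.Add(1)
			go consume(workers)
			workers++
		}
	}

	// Producers are done, so a plain close is safe here; the patch
	// instead polls len(queryInfoCh) == 0 before closing.
	close(tasks)
	wg.Wait()
}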
+
+func collectSingleQuery(l *logprinter.Logger, c *http.Client, tokenID int, resultDir, mtc, nameSuffix string, qInfo queryInfo) error {
+	i := 0
+	return tiuputils.Retry(
+		func() error {
+			startTime := time.Now()
+			req, err := http.NewRequest(
+				http.MethodGet,
+				fmt.Sprintf("http://%s/api/v1/query?%s", qInfo.promAddr, url.Values{
+					"query": {fmt.Sprintf("%s[%ds]", qInfo.query, qInfo.intervalSec)},
+					"time":  {qInfo.queryEnd.Format(time.RFC3339)},
+				}.Encode()),
+				nil)
+			if err != nil {
+				return err
+			}
+			getTime := time.Since(startTime)
+			utils.AddHeaders(req.Header, qInfo.customHeader)
+			resp, err := c.Do(req)
+			i++
+			if err != nil {
+				l.Errorf("[ID:%d-try:%d] failed query metric %s: %s, retry... interval:%ds, took:%v. If Prometheus OOM is the cause, consider reducing concurrency and metrics-min-interval",
+					tokenID, i, mtc+nameSuffix, err, qInfo.intervalSec, getTime)
+				time.Sleep(200 * time.Millisecond)
+				return err
+			}
+			// Prometheus API response format is JSON. Every successful API request returns a 2xx status code.
+			if resp.StatusCode/100 != 2 {
+				l.Errorf("[ID:%d-try:%d] failed query metric %s: Status Code %d, retry... interval:%ds, took:%v",
+					tokenID, i, mtc+nameSuffix, resp.StatusCode, qInfo.intervalSec, getTime)
+				time.Sleep(200 * time.Millisecond)
+			}
+			defer resp.Body.Close()
+
+			dst, err := os.Create(
+				filepath.Join(
+					resultDir, subdirMonitor, subdirMetrics, strings.ReplaceAll(qInfo.promAddr, ":", "-"),
+					fmt.Sprintf("%s-%s-%s%s.json", mtc, qInfo.queryBegin.Format(time.RFC3339), qInfo.queryEnd.Format(time.RFC3339), nameSuffix),
+				),
+			)
+			if err != nil {
+				l.Errorf("[ID:%d-try:%d] failed query metric %s: %s, retry...", tokenID, i, mtc+nameSuffix, err)
+			}
+			defer dst.Close()
+
+			var enc io.WriteCloser
+			var n int64
+			if qInfo.compress {
+				// compress the metric
+				enc, err = zstd.NewWriter(dst)
 				if err != nil {
-					l.Errorf("failed writing metric %s to file: %s, retry...\n", mtc+nameSuffix, err)
+					l.Errorf("[ID:%d-try:%d] failed compressing metric %s: %s, retry...\n", tokenID, i, mtc+nameSuffix, err)
 					return err
 				}
-				l.Debugf(" Dumped metric %s from %s to %s (%d bytes)", mtc+nameSuffix, queryBegin.Format(time.RFC3339), queryEnd.Format(time.RFC3339), n)
-				return nil
-			},
-			tiuputils.RetryOption{
-				Attempts: 3,
-				Delay:    time.Microsecond * 300,
-				Timeout:  c.Timeout*3 + 5*time.Second, //make sure the retry timeout is longer than the api timeout
-			},
-		); err != nil {
-			l.Errorf("Error quering metrics %s: %s", mtc+nameSuffix, err)
-		}
-	}
+				defer enc.Close()
+			} else {
+				enc = dst
+			}
+			n, err = io.Copy(enc, resp.Body)
+			if err != nil {
+				l.Errorf("[ID:%d-try:%d] failed writing metric %s to file: %s, retry... took:%v\n",
+					tokenID, i, mtc+nameSuffix, err, time.Since(startTime))
+				return err
+			}
+			if time.Since(startTime) > time.Second {
+				l.Infof("[ID:%d-try:%d] Collected a query of metric %s from %s to %s (%d bytes), took a long time:%v",
+					tokenID, i, mtc+nameSuffix, qInfo.queryBegin.Format(time.RFC3339), qInfo.queryEnd.Format(time.RFC3339), n, time.Since(startTime))
+			}
+			return nil
+		},
+		qInfo.retryOption,
+	)
 }
 
 func ensureMonitorDir(base string, sub ...string) error {
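collectSingleQuery uses Prometheus' instant-query endpoint with a range selector (query=metric[Ns]&time=end), which returns the raw samples of the preceding N seconds in one JSON matrix, rather than /api/v1/query_range. A minimal sketch of the request URL it builds; the address in main is illustrative, not from the patch:

package main

import (
	"fmt"
	"net/url"
)

// buildMatrixQueryURL builds the same kind of instant query the collector
// issues: a range selector pulls all raw samples of the metric in the
// last intervalSec seconds before the given timestamp.
func buildMatrixQueryURL(promAddr, query string, intervalSec int, end string) string {
	return fmt.Sprintf("http://%s/api/v1/query?%s", promAddr, url.Values{
		"query": {fmt.Sprintf("%s[%ds]", query, intervalSec)},
		"time":  {end},
	}.Encode())
}

func main() {
	fmt.Println(buildMatrixQueryURL("127.0.0.1:9090",
		`tidb_server_query_total{instance="tidb-0"}`, 60, "2025-01-01T12:00:00Z"))
}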
diff --git a/collector/util.go b/collector/util.go
new file mode 100644
index 00000000..3188cf2f
--- /dev/null
+++ b/collector/util.go
@@ -0,0 +1,46 @@
+// Copyright 2025 PingCAP, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package collector
+
+import (
+	"sync"
+)
+
+// WaitGroupWrapper is a wrapper for sync.WaitGroup
+type WaitGroupWrapper struct {
+	sync.WaitGroup
+	PanicCnt int
+}
+
+// RunWithRecover wraps a goroutine startup call with forced recovery: it adds 1 to the
+// WaitGroup and calls Done when the function returns. exec is the logic to execute;
+// recoverFn is a handler called with the recover result before PanicCnt is updated,
+// and passing `nil` means no-op.
+func (w *WaitGroupWrapper) RunWithRecover(exec func(), recoverFn func(r any)) {
+	w.Add(1)
+	go func() {
+		defer func() {
+			r := recover()
+			if recoverFn != nil {
+				recoverFn(r)
+			}
+			if r != nil {
+				w.PanicCnt++
+			}
+			w.Done()
+		}()
+		exec()
+	}()
+}
diff --git a/pkg/utils/tokenlimiter.go b/pkg/utils/tokenlimiter.go
index 03cc7ee4..3725ef2b 100644
--- a/pkg/utils/tokenlimiter.go
+++ b/pkg/utils/tokenlimiter.go
@@ -19,14 +19,20 @@ import (
 // Token is used as a permission to keep on running.
 type Token struct {
+	ID int
 }
 
 // TokenLimiter is used to limit the number of concurrent tasks.
 type TokenLimiter struct {
-	count uint
+	count int
 	ch    chan *Token
 }
 
+// Cap returns the total number of tokens.
+func (tl *TokenLimiter) Cap() int {
+	return tl.count
+}
+
 // Put releases the token.
 func (tl *TokenLimiter) Put(tk *Token) {
 	tl.ch <- tk
@@ -37,18 +43,28 @@ func (tl *TokenLimiter) Get() *Token {
 	return <-tl.ch
 }
 
+// TryGet tries to obtain a token without blocking; it returns nil if none is available.
+func (tl *TokenLimiter) TryGet() *Token {
+	select {
+	case token := <-tl.ch:
+		return token
+	default:
+		return nil
+	}
+}
+
 // Wait all token put back
 func (tl *TokenLimiter) Wait() {
-	for len(tl.ch) < int(tl.count) {
+	for len(tl.ch) < tl.count {
 		runtime.Gosched()
 	}
 }
 
 // NewTokenLimiter creates a TokenLimiter with count tokens.
-func NewTokenLimiter(count uint) *TokenLimiter {
+func NewTokenLimiter(count int) *TokenLimiter {
 	tl := &TokenLimiter{count: count, ch: make(chan *Token, count)}
-	for i := uint(0); i < count; i++ {
-		tl.ch <- &Token{}
+	for i := 0; i < count; i++ {
+		tl.ch <- &Token{ID: i}
 	}
 	return tl
 }
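With the ID field, the limiter now supports both blocking and opportunistic acquisition. A quick usage sketch mirroring how collectSingleMetric grows its consumer pool; it assumes the module path github.com/pingcap/diag for the import:

package main

import (
	"fmt"

	"github.com/pingcap/diag/pkg/utils"
)

func main() {
	tl := utils.NewTokenLimiter(2)
	fmt.Println("cap:", tl.Cap()) // 2

	tok := tl.Get() // blocking acquire, always succeeds eventually
	fmt.Println("got token", tok.ID)

	// Non-blocking acquire: scale out only if a spare token exists.
	if extra := tl.TryGet(); extra != nil {
		fmt.Println("opportunistically got token", extra.ID)
		tl.Put(extra)
	} else {
		fmt.Println("no spare token, keep going single-threaded")
	}

	tl.Put(tok)
	tl.Wait() // returns once every token is back
}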