diff --git a/pkg/core/region.go b/pkg/core/region.go index b4ebede7722d..94fc525f11b4 100644 --- a/pkg/core/region.go +++ b/pkg/core/region.go @@ -47,6 +47,7 @@ import ( const ( randomRegionMaxRetry = 10 scanRegionLimit = 1000 + batchSearchSize = 16 // CollectFactor is the factor to collect the count of region. CollectFactor = 0.9 ) @@ -1512,15 +1513,40 @@ func (r *RegionsInfo) QueryRegions( // getRegionsByKeys searches RegionInfo from regionTree by keys. func (r *RegionsInfo) getRegionsByKeys(keys [][]byte) []*RegionInfo { - r.t.RLock() - defer r.t.RUnlock() - return r.tree.searchByKeys(keys) + regions := make([]*RegionInfo, 0, len(keys)) + // Split the keys into multiple batches, and search each batch separately. + // This is to avoid the lock contention on the `regionTree`. + for _, batch := range splitKeysIntoBatches(keys) { + r.t.RLock() + results := r.tree.searchByKeys(batch) + r.t.RUnlock() + regions = append(regions, results...) + } + return regions +} + +func splitKeysIntoBatches(keys [][]byte) [][][]byte { + keysLen := len(keys) + batches := make([][][]byte, 0, (keysLen+batchSearchSize-1)/batchSearchSize) + for i := 0; i < keysLen; i += batchSearchSize { + end := i + batchSearchSize + if end > keysLen { + end = keysLen + } + batches = append(batches, keys[i:end]) + } + return batches } func (r *RegionsInfo) getRegionsByPrevKeys(prevKeys [][]byte) []*RegionInfo { - r.t.RLock() - defer r.t.RUnlock() - return r.tree.searchByPrevKeys(prevKeys) + regions := make([]*RegionInfo, 0, len(prevKeys)) + for _, batch := range splitKeysIntoBatches(prevKeys) { + r.t.RLock() + results := r.tree.searchByPrevKeys(batch) + r.t.RUnlock() + regions = append(regions, results...) + } + return regions } // sortOutKeyIDMap will iterate the regions, convert it to a slice of regionID that corresponds to the input regions. diff --git a/server/grpc_service.go b/server/grpc_service.go index 77ac9ddddc6c..1301abe2ae9a 100644 --- a/server/grpc_service.go +++ b/server/grpc_service.go @@ -1553,9 +1553,6 @@ func (s *GrpcServer) QueryRegion(stream pdpb.PD_QueryRegionServer) error { // TODO: add forwarding logic. - if s.IsClosed() { - return errs.ErrNotStarted - } if clusterID := keypath.ClusterID(); request.GetHeader().GetClusterId() != clusterID { return errs.ErrMismatchClusterID(clusterID, request.GetHeader().GetClusterId()) }