From 8fd21f7e49ab6a56ef9aac0036cbc78b26a69306 Mon Sep 17 00:00:00 2001 From: JmPotato Date: Tue, 28 Nov 2023 16:10:10 +0800 Subject: [PATCH 1/6] Implement the marshaler interfaces for Rule/RuleOP Signed-off-by: JmPotato --- client/http/codec.go | 99 +++++++++++++++++++++++++++++++++++++++ client/http/codec_test.go | 64 +++++++++++++++++++++++++ client/http/types.go | 76 ++++++++++++++++++++++++++++++ 3 files changed, 239 insertions(+) create mode 100644 client/http/codec.go create mode 100644 client/http/codec_test.go diff --git a/client/http/codec.go b/client/http/codec.go new file mode 100644 index 00000000000..0512c5adce9 --- /dev/null +++ b/client/http/codec.go @@ -0,0 +1,99 @@ +// Copyright 2023 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package http + +import ( + "github.com/pingcap/errors" +) + +const ( + encGroupSize = 8 + encMarker = byte(0xFF) + encPad = byte(0x0) +) + +var pads = make([]byte, encGroupSize) + +// encodeBytes guarantees the encoded value is in ascending order for comparison, +// encoding with the following rule: +// +// [group1][marker1]...[groupN][markerN] +// group is 8 bytes slice which is padding with 0. +// marker is `0xFF - padding 0 count` +// +// For example: +// +// [] -> [0, 0, 0, 0, 0, 0, 0, 0, 247] +// [1, 2, 3] -> [1, 2, 3, 0, 0, 0, 0, 0, 250] +// [1, 2, 3, 0] -> [1, 2, 3, 0, 0, 0, 0, 0, 251] +// [1, 2, 3, 4, 5, 6, 7, 8] -> [1, 2, 3, 4, 5, 6, 7, 8, 255, 0, 0, 0, 0, 0, 0, 0, 0, 247] +// +// Refer: https://github.com/facebook/mysql-5.6/wiki/MyRocks-record-format#memcomparable-format +func encodeBytes(data []byte) []byte { + // Allocate more space to avoid unnecessary slice growing. + // Assume that the byte slice size is about `(len(data) / encGroupSize + 1) * (encGroupSize + 1)` bytes, + // that is `(len(data) / 8 + 1) * 9` in our implement. + dLen := len(data) + result := make([]byte, 0, (dLen/encGroupSize+1)*(encGroupSize+1)) + for idx := 0; idx <= dLen; idx += encGroupSize { + remain := dLen - idx + padCount := 0 + if remain >= encGroupSize { + result = append(result, data[idx:idx+encGroupSize]...) + } else { + padCount = encGroupSize - remain + result = append(result, data[idx:]...) + result = append(result, pads[:padCount]...) + } + + marker := encMarker - byte(padCount) + result = append(result, marker) + } + return result +} + +func decodeBytes(b []byte) ([]byte, []byte, error) { + buf := make([]byte, 0, len(b)) + for { + if len(b) < encGroupSize+1 { + return nil, nil, errors.New("insufficient bytes to decode value") + } + + groupBytes := b[:encGroupSize+1] + + group := groupBytes[:encGroupSize] + marker := groupBytes[encGroupSize] + + padCount := encMarker - marker + if padCount > encGroupSize { + return nil, nil, errors.Errorf("invalid marker byte, group bytes %q", groupBytes) + } + + realGroupSize := encGroupSize - padCount + buf = append(buf, group[:realGroupSize]...) + b = b[encGroupSize+1:] + + if padCount != 0 { + // Check validity of padding bytes. + for _, v := range group[realGroupSize:] { + if v != encPad { + return nil, nil, errors.Errorf("invalid padding byte, group bytes %q", groupBytes) + } + } + break + } + } + return b, buf, nil +} diff --git a/client/http/codec_test.go b/client/http/codec_test.go new file mode 100644 index 00000000000..cf3f42430fe --- /dev/null +++ b/client/http/codec_test.go @@ -0,0 +1,64 @@ +// Copyright 2023 TiKV Project Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package http + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestBytesCodec(t *testing.T) { + inputs := []struct { + enc []byte + dec []byte + }{ + {[]byte{}, []byte{0, 0, 0, 0, 0, 0, 0, 0, 247}}, + {[]byte{0}, []byte{0, 0, 0, 0, 0, 0, 0, 0, 248}}, + {[]byte{1, 2, 3}, []byte{1, 2, 3, 0, 0, 0, 0, 0, 250}}, + {[]byte{1, 2, 3, 0}, []byte{1, 2, 3, 0, 0, 0, 0, 0, 251}}, + {[]byte{1, 2, 3, 4, 5, 6, 7}, []byte{1, 2, 3, 4, 5, 6, 7, 0, 254}}, + {[]byte{0, 0, 0, 0, 0, 0, 0, 0}, []byte{0, 0, 0, 0, 0, 0, 0, 0, 255, 0, 0, 0, 0, 0, 0, 0, 0, 247}}, + {[]byte{1, 2, 3, 4, 5, 6, 7, 8}, []byte{1, 2, 3, 4, 5, 6, 7, 8, 255, 0, 0, 0, 0, 0, 0, 0, 0, 247}}, + {[]byte{1, 2, 3, 4, 5, 6, 7, 8, 9}, []byte{1, 2, 3, 4, 5, 6, 7, 8, 255, 9, 0, 0, 0, 0, 0, 0, 0, 248}}, + } + + for _, input := range inputs { + b := encodeBytes(input.enc) + require.Equal(t, input.dec, b) + + _, d, err := decodeBytes(b) + require.NoError(t, err) + require.Equal(t, input.enc, d) + } + + // Test error decode. + errInputs := [][]byte{ + {1, 2, 3, 4}, + {0, 0, 0, 0, 0, 0, 0, 247}, + {0, 0, 0, 0, 0, 0, 0, 0, 246}, + {0, 0, 0, 0, 0, 0, 0, 1, 247}, + {1, 2, 3, 4, 5, 6, 7, 8, 0}, + {1, 2, 3, 4, 5, 6, 7, 8, 255, 1}, + {1, 2, 3, 4, 5, 6, 7, 8, 255, 1, 2, 3, 4, 5, 6, 7, 8}, + {1, 2, 3, 4, 5, 6, 7, 8, 255, 1, 2, 3, 4, 5, 6, 7, 8, 255}, + {1, 2, 3, 4, 5, 6, 7, 8, 255, 1, 2, 3, 4, 5, 6, 7, 8, 0}, + } + + for _, input := range errInputs { + _, _, err := decodeBytes(input) + require.Error(t, err) + } +} diff --git a/client/http/types.go b/client/http/types.go index 4e99d911e0b..db5ada047af 100644 --- a/client/http/types.go +++ b/client/http/types.go @@ -341,6 +341,44 @@ func (r *Rule) Clone() *Rule { return &clone } +var ( + _ json.Marshaler = (*Rule)(nil) + _ json.Unmarshaler = (*Rule)(nil) +) + +// MarshalJSON implements `json.Marshaler` interface to make sure we could set the correct start/end key. +func (r *Rule) MarshalJSON() ([]byte, error) { + r.StartKeyHex = hex.EncodeToString(encodeBytes(r.StartKey)) + r.EndKeyHex = hex.EncodeToString(encodeBytes(r.EndKey)) + return json.Marshal(r) +} + +// UnmarshalJSON implements `json.Unmarshaler` interface to make sure we could get the correct start/end key. +func (r *Rule) UnmarshalJSON(bytes []byte) error { + if err := json.Unmarshal(bytes, r); err != nil { + return err + } + + startKey, err := hex.DecodeString(r.StartKeyHex) + if err != nil { + return err + } + + endKey, err := hex.DecodeString(r.EndKeyHex) + if err != nil { + return err + } + + _, r.StartKey, err = decodeBytes(startKey) + if err != nil { + return err + } + + _, r.EndKey, err = decodeBytes(endKey) + + return err +} + // RuleOpType indicates the operation type type RuleOpType string @@ -364,6 +402,44 @@ func (r RuleOp) String() string { return string(b) } +var ( + _ json.Marshaler = (*RuleOp)(nil) + _ json.Unmarshaler = (*RuleOp)(nil) +) + +// MarshalJSON implements `json.Marshaler` interface to make sure we could set the correct start/end key. +func (r *RuleOp) MarshalJSON() ([]byte, error) { + r.StartKeyHex = hex.EncodeToString(encodeBytes(r.StartKey)) + r.EndKeyHex = hex.EncodeToString(encodeBytes(r.EndKey)) + return json.Marshal(r) +} + +// UnmarshalJSON implements `json.Unmarshaler` interface to make sure we could get the correct start/end key. +func (r *RuleOp) UnmarshalJSON(bytes []byte) error { + if err := json.Unmarshal(bytes, r); err != nil { + return err + } + + startKey, err := hex.DecodeString(r.StartKeyHex) + if err != nil { + return err + } + + endKey, err := hex.DecodeString(r.EndKeyHex) + if err != nil { + return err + } + + _, r.StartKey, err = decodeBytes(startKey) + if err != nil { + return err + } + + _, r.EndKey, err = decodeBytes(endKey) + + return err +} + // RuleGroup defines properties of a rule group. type RuleGroup struct { ID string `json:"id,omitempty"` From 01dbb8f980b50dbf28a6a7b23cb72e42474cc561 Mon Sep 17 00:00:00 2001 From: JmPotato Date: Tue, 28 Nov 2023 16:22:43 +0800 Subject: [PATCH 2/6] Refine the codec function Signed-off-by: JmPotato --- client/http/codec.go | 32 +++- client/http/codec_test.go | 4 +- client/http/types.go | 156 ++++++++++++++---- client/http/types_test.go | 63 +++++++ tests/integrations/client/http_client_test.go | 34 +++- 5 files changed, 242 insertions(+), 47 deletions(-) diff --git a/client/http/codec.go b/client/http/codec.go index 0512c5adce9..cac1a89fd87 100644 --- a/client/http/codec.go +++ b/client/http/codec.go @@ -15,6 +15,8 @@ package http import ( + "encoding/hex" + "github.com/pingcap/errors" ) @@ -64,11 +66,11 @@ func encodeBytes(data []byte) []byte { return result } -func decodeBytes(b []byte) ([]byte, []byte, error) { +func decodeBytes(b []byte) ([]byte, error) { buf := make([]byte, 0, len(b)) for { if len(b) < encGroupSize+1 { - return nil, nil, errors.New("insufficient bytes to decode value") + return nil, errors.New("insufficient bytes to decode value") } groupBytes := b[:encGroupSize+1] @@ -78,7 +80,7 @@ func decodeBytes(b []byte) ([]byte, []byte, error) { padCount := encMarker - marker if padCount > encGroupSize { - return nil, nil, errors.Errorf("invalid marker byte, group bytes %q", groupBytes) + return nil, errors.Errorf("invalid marker byte, group bytes %q", groupBytes) } realGroupSize := encGroupSize - padCount @@ -89,11 +91,31 @@ func decodeBytes(b []byte) ([]byte, []byte, error) { // Check validity of padding bytes. for _, v := range group[realGroupSize:] { if v != encPad { - return nil, nil, errors.Errorf("invalid padding byte, group bytes %q", groupBytes) + return nil, errors.Errorf("invalid padding byte, group bytes %q", groupBytes) } } break } } - return b, buf, nil + return buf, nil +} + +// keyToKeyHexStr converts a raw key to a hex string after encoding. +func rawKeyToKeyHexStr(key []byte) string { + if len(key) == 0 { + return "" + } + return hex.EncodeToString(encodeBytes(key)) +} + +// keyHexStrToRawKey converts a hex string to a raw key after decoding. +func keyHexStrToRawKey(hexKey string) ([]byte, error) { + if len(hexKey) == 0 { + return make([]byte, 0), nil + } + key, err := hex.DecodeString(hexKey) + if err != nil { + return nil, err + } + return decodeBytes(key) } diff --git a/client/http/codec_test.go b/client/http/codec_test.go index cf3f42430fe..fa8d413a0d1 100644 --- a/client/http/codec_test.go +++ b/client/http/codec_test.go @@ -39,7 +39,7 @@ func TestBytesCodec(t *testing.T) { b := encodeBytes(input.enc) require.Equal(t, input.dec, b) - _, d, err := decodeBytes(b) + d, err := decodeBytes(b) require.NoError(t, err) require.Equal(t, input.enc, d) } @@ -58,7 +58,7 @@ func TestBytesCodec(t *testing.T) { } for _, input := range errInputs { - _, _, err := decodeBytes(input) + _, err := decodeBytes(input) require.Error(t, err) } } diff --git a/client/http/types.go b/client/http/types.go index db5ada047af..1d8db36d100 100644 --- a/client/http/types.go +++ b/client/http/types.go @@ -346,37 +346,79 @@ var ( _ json.Unmarshaler = (*Rule)(nil) ) +// This is a helper struct used to customizing the JSON marshal/unmarshal methods of `Rule`. +type rule struct { + GroupID string `json:"group_id"` + ID string `json:"id"` + Index int `json:"index,omitempty"` + Override bool `json:"override,omitempty"` + StartKeyHex string `json:"start_key"` + EndKeyHex string `json:"end_key"` + Role PeerRoleType `json:"role"` + IsWitness bool `json:"is_witness"` + Count int `json:"count"` + LabelConstraints []LabelConstraint `json:"label_constraints,omitempty"` + LocationLabels []string `json:"location_labels,omitempty"` + IsolationLevel string `json:"isolation_level,omitempty"` +} + // MarshalJSON implements `json.Marshaler` interface to make sure we could set the correct start/end key. func (r *Rule) MarshalJSON() ([]byte, error) { - r.StartKeyHex = hex.EncodeToString(encodeBytes(r.StartKey)) - r.EndKeyHex = hex.EncodeToString(encodeBytes(r.EndKey)) - return json.Marshal(r) + tempRule := &rule{ + GroupID: r.GroupID, + ID: r.ID, + Index: r.Index, + Override: r.Override, + StartKeyHex: r.StartKeyHex, + EndKeyHex: r.EndKeyHex, + Role: r.Role, + IsWitness: r.IsWitness, + Count: r.Count, + LabelConstraints: r.LabelConstraints, + LocationLabels: r.LocationLabels, + IsolationLevel: r.IsolationLevel, + } + // Converts the start/end key to hex format if the corresponding hex field is empty. + if len(r.StartKey) > 0 && len(r.StartKeyHex) == 0 { + tempRule.StartKeyHex = rawKeyToKeyHexStr(r.StartKey) + } + if len(r.EndKey) > 0 && len(r.EndKeyHex) == 0 { + tempRule.EndKeyHex = rawKeyToKeyHexStr(r.EndKey) + } + return json.Marshal(tempRule) } // UnmarshalJSON implements `json.Unmarshaler` interface to make sure we could get the correct start/end key. func (r *Rule) UnmarshalJSON(bytes []byte) error { - if err := json.Unmarshal(bytes, r); err != nil { - return err - } - - startKey, err := hex.DecodeString(r.StartKeyHex) + var tempRule rule + err := json.Unmarshal(bytes, &tempRule) if err != nil { return err } - - endKey, err := hex.DecodeString(r.EndKeyHex) + newRule := Rule{ + GroupID: tempRule.GroupID, + ID: tempRule.ID, + Index: tempRule.Index, + Override: tempRule.Override, + StartKeyHex: tempRule.StartKeyHex, + EndKeyHex: tempRule.EndKeyHex, + Role: tempRule.Role, + IsWitness: tempRule.IsWitness, + Count: tempRule.Count, + LabelConstraints: tempRule.LabelConstraints, + LocationLabels: tempRule.LocationLabels, + IsolationLevel: tempRule.IsolationLevel, + } + newRule.StartKey, err = keyHexStrToRawKey(newRule.StartKeyHex) if err != nil { return err } - - _, r.StartKey, err = decodeBytes(startKey) + newRule.EndKey, err = keyHexStrToRawKey(newRule.EndKeyHex) if err != nil { return err } - - _, r.EndKey, err = decodeBytes(endKey) - - return err + *r = newRule + return nil } // RuleOpType indicates the operation type @@ -407,37 +449,87 @@ var ( _ json.Unmarshaler = (*RuleOp)(nil) ) +// This is a helper struct used to customizing the JSON marshal/unmarshal methods of `RuleOp`. +type ruleOp struct { + GroupID string `json:"group_id"` + ID string `json:"id"` + Index int `json:"index,omitempty"` + Override bool `json:"override,omitempty"` + StartKeyHex string `json:"start_key"` + EndKeyHex string `json:"end_key"` + Role PeerRoleType `json:"role"` + IsWitness bool `json:"is_witness"` + Count int `json:"count"` + LabelConstraints []LabelConstraint `json:"label_constraints,omitempty"` + LocationLabels []string `json:"location_labels,omitempty"` + IsolationLevel string `json:"isolation_level,omitempty"` + Action RuleOpType `json:"action"` + DeleteByIDPrefix bool `json:"delete_by_id_prefix"` +} + // MarshalJSON implements `json.Marshaler` interface to make sure we could set the correct start/end key. func (r *RuleOp) MarshalJSON() ([]byte, error) { - r.StartKeyHex = hex.EncodeToString(encodeBytes(r.StartKey)) - r.EndKeyHex = hex.EncodeToString(encodeBytes(r.EndKey)) - return json.Marshal(r) + tempRuleOp := &ruleOp{ + GroupID: r.GroupID, + ID: r.ID, + Index: r.Index, + Override: r.Override, + StartKeyHex: r.StartKeyHex, + EndKeyHex: r.EndKeyHex, + Role: r.Role, + IsWitness: r.IsWitness, + Count: r.Count, + LabelConstraints: r.LabelConstraints, + LocationLabels: r.LocationLabels, + IsolationLevel: r.IsolationLevel, + Action: r.Action, + DeleteByIDPrefix: r.DeleteByIDPrefix, + } + // Converts the start/end key to hex format if the corresponding hex field is empty. + if len(r.StartKey) > 0 && len(r.StartKeyHex) == 0 { + tempRuleOp.StartKeyHex = rawKeyToKeyHexStr(r.StartKey) + } + if len(r.EndKey) > 0 && len(r.EndKeyHex) == 0 { + tempRuleOp.EndKeyHex = rawKeyToKeyHexStr(r.EndKey) + } + return json.Marshal(tempRuleOp) } // UnmarshalJSON implements `json.Unmarshaler` interface to make sure we could get the correct start/end key. func (r *RuleOp) UnmarshalJSON(bytes []byte) error { - if err := json.Unmarshal(bytes, r); err != nil { - return err - } - - startKey, err := hex.DecodeString(r.StartKeyHex) + var tempRuleOp ruleOp + err := json.Unmarshal(bytes, &tempRuleOp) if err != nil { return err } - - endKey, err := hex.DecodeString(r.EndKeyHex) + newRuleOp := RuleOp{ + Rule: &Rule{ + GroupID: tempRuleOp.GroupID, + ID: tempRuleOp.ID, + Index: tempRuleOp.Index, + Override: tempRuleOp.Override, + StartKeyHex: tempRuleOp.StartKeyHex, + EndKeyHex: tempRuleOp.EndKeyHex, + Role: tempRuleOp.Role, + IsWitness: tempRuleOp.IsWitness, + Count: tempRuleOp.Count, + LabelConstraints: tempRuleOp.LabelConstraints, + LocationLabels: tempRuleOp.LocationLabels, + IsolationLevel: tempRuleOp.IsolationLevel, + }, + Action: tempRuleOp.Action, + DeleteByIDPrefix: tempRuleOp.DeleteByIDPrefix, + } + newRuleOp.StartKey, err = keyHexStrToRawKey(newRuleOp.StartKeyHex) if err != nil { return err } - - _, r.StartKey, err = decodeBytes(startKey) + newRuleOp.EndKey, err = keyHexStrToRawKey(newRuleOp.EndKeyHex) if err != nil { return err } - - _, r.EndKey, err = decodeBytes(endKey) - - return err + *r = newRuleOp + return nil } // RuleGroup defines properties of a rule group. diff --git a/client/http/types_test.go b/client/http/types_test.go index 0dfebacbdcf..c43a5453646 100644 --- a/client/http/types_test.go +++ b/client/http/types_test.go @@ -15,6 +15,7 @@ package http import ( + "encoding/json" "testing" "github.com/stretchr/testify/require" @@ -47,3 +48,65 @@ func TestMergeRegionsInfo(t *testing.T) { re.Equal(2, len(regionsInfo.Regions)) re.Equal(append(regionsInfo1.Regions, regionsInfo2.Regions...), regionsInfo.Regions) } + +func TestRuleStartEndKey(t *testing.T) { + re := require.New(t) + // Empty start/end key and key hex. + ruleToMarshal := &Rule{} + rule := mustMarshalAndUnmarshal(re, ruleToMarshal) + re.Equal("", rule.StartKeyHex) + re.Equal("", rule.EndKeyHex) + re.Equal([]byte(""), rule.StartKey) + re.Equal([]byte(""), rule.EndKey) + // Empty start/end key and non-empty key hex. + ruleToMarshal = &Rule{ + StartKeyHex: rawKeyToKeyHexStr([]byte("a")), + EndKeyHex: rawKeyToKeyHexStr([]byte("b")), + } + rule = mustMarshalAndUnmarshal(re, ruleToMarshal) + re.Equal([]byte("a"), rule.StartKey) + re.Equal([]byte("b"), rule.EndKey) + re.Equal(ruleToMarshal.StartKeyHex, rule.StartKeyHex) + re.Equal(ruleToMarshal.EndKeyHex, rule.EndKeyHex) + // Non-empty start/end key and empty key hex. + ruleToMarshal = &Rule{ + StartKey: []byte("a"), + EndKey: []byte("b"), + } + rule = mustMarshalAndUnmarshal(re, ruleToMarshal) + re.Equal(ruleToMarshal.StartKey, rule.StartKey) + re.Equal(ruleToMarshal.EndKey, rule.EndKey) + re.Equal(rawKeyToKeyHexStr(ruleToMarshal.StartKey), rule.StartKeyHex) + re.Equal(rawKeyToKeyHexStr(ruleToMarshal.EndKey), rule.EndKeyHex) + // Non-empty start/end key and non-empty key hex. + ruleToMarshal = &Rule{ + StartKey: []byte("a"), + EndKey: []byte("b"), + StartKeyHex: rawKeyToKeyHexStr([]byte("c")), + EndKeyHex: rawKeyToKeyHexStr([]byte("d")), + } + rule = mustMarshalAndUnmarshal(re, ruleToMarshal) + re.Equal([]byte("c"), rule.StartKey) + re.Equal([]byte("d"), rule.EndKey) + re.Equal(ruleToMarshal.StartKeyHex, rule.StartKeyHex) + re.Equal(ruleToMarshal.EndKeyHex, rule.EndKeyHex) + // Half of each pair of keys is empty. + ruleToMarshal = &Rule{ + StartKey: []byte("a"), + EndKeyHex: rawKeyToKeyHexStr([]byte("d")), + } + rule = mustMarshalAndUnmarshal(re, ruleToMarshal) + re.Equal(ruleToMarshal.StartKey, rule.StartKey) + re.Equal([]byte("d"), rule.EndKey) + re.Equal(rawKeyToKeyHexStr(ruleToMarshal.StartKey), rule.StartKeyHex) + re.Equal(ruleToMarshal.EndKeyHex, rule.EndKeyHex) +} + +func mustMarshalAndUnmarshal(re *require.Assertions, rule *Rule) *Rule { + ruleJSON, err := json.Marshal(rule) + re.NoError(err) + var newRule *Rule + err = json.Unmarshal(ruleJSON, &newRule) + re.NoError(err) + return newRule +} diff --git a/tests/integrations/client/http_client_test.go b/tests/integrations/client/http_client_test.go index a007b893187..0cc3ceaa06a 100644 --- a/tests/integrations/client/http_client_test.go +++ b/tests/integrations/client/http_client_test.go @@ -170,18 +170,22 @@ func (suite *httpClientTestSuite) TestRule() { re.Equal(bundles[0], bundle) // Check if we have the default rule. suite.checkRule(re, &pd.Rule{ - GroupID: placement.DefaultGroupID, - ID: placement.DefaultRuleID, - Role: pd.Voter, - Count: 3, + GroupID: placement.DefaultGroupID, + ID: placement.DefaultRuleID, + Role: pd.Voter, + Count: 3, + StartKey: []byte{}, + EndKey: []byte{}, }, 1, true) // Should be the same as the rules in the bundle. suite.checkRule(re, bundle.Rules[0], 1, true) testRule := &pd.Rule{ - GroupID: placement.DefaultGroupID, - ID: "test", - Role: pd.Voter, - Count: 3, + GroupID: placement.DefaultGroupID, + ID: "test", + Role: pd.Voter, + Count: 3, + StartKey: []byte{}, + EndKey: []byte{}, } err = suite.client.SetPlacementRule(suite.ctx, testRule) re.NoError(err) @@ -233,6 +237,18 @@ func (suite *httpClientTestSuite) TestRule() { ruleGroup, err = suite.client.GetPlacementRuleGroupByID(suite.ctx, testRuleGroup.ID) re.ErrorContains(err, http.StatusText(http.StatusNotFound)) re.Empty(ruleGroup) + // Test the start key and end key. + testRule = &pd.Rule{ + GroupID: placement.DefaultGroupID, + ID: "test", + Role: pd.Voter, + Count: 5, + StartKey: []byte("a1"), + EndKey: []byte(""), + } + err = suite.client.SetPlacementRule(suite.ctx, testRule) + re.NoError(err) + suite.checkRule(re, testRule, 1, true) } func (suite *httpClientTestSuite) checkRule( @@ -262,6 +278,8 @@ func checkRuleFunc( re.Equal(rule.ID, r.ID) re.Equal(rule.Role, r.Role) re.Equal(rule.Count, r.Count) + re.Equal(rule.StartKey, r.StartKey) + re.Equal(rule.EndKey, r.EndKey) return } if exist { From 75be3f794520daa52899d5acbe067799b8831268 Mon Sep 17 00:00:00 2001 From: JmPotato Date: Wed, 29 Nov 2023 15:33:35 +0800 Subject: [PATCH 3/6] Add onlyCount to GetRegionStatusByKeyRange Signed-off-by: JmPotato --- client/http/api.go | 6 +++++- client/http/client.go | 6 +++--- tests/integrations/client/http_client_test.go | 9 +++++++-- 3 files changed, 15 insertions(+), 6 deletions(-) diff --git a/client/http/api.go b/client/http/api.go index 6b317330b61..7da60934fb2 100644 --- a/client/http/api.go +++ b/client/http/api.go @@ -96,8 +96,12 @@ func RegionsByStoreID(storeID uint64) string { } // RegionStatsByKeyRange returns the path of PD HTTP API to get region stats by start key and end key. -func RegionStatsByKeyRange(keyRange *KeyRange) string { +func RegionStatsByKeyRange(keyRange *KeyRange, onlyCount bool) string { startKeyStr, endKeyStr := keyRange.EscapeAsUTF8Str() + if onlyCount { + return fmt.Sprintf("%s?start_key=%s&end_key=%s&count", + StatsRegion, startKeyStr, endKeyStr) + } return fmt.Sprintf("%s?start_key=%s&end_key=%s", StatsRegion, startKeyStr, endKeyStr) } diff --git a/client/http/client.go b/client/http/client.go index d15693e11d4..ffcdcb02d10 100644 --- a/client/http/client.go +++ b/client/http/client.go @@ -50,7 +50,7 @@ type Client interface { GetHotReadRegions(context.Context) (*StoreHotPeersInfos, error) GetHotWriteRegions(context.Context) (*StoreHotPeersInfos, error) GetHistoryHotRegions(context.Context, *HistoryHotRegionsRequest) (*HistoryHotRegions, error) - GetRegionStatusByKeyRange(context.Context, *KeyRange) (*RegionStats, error) + GetRegionStatusByKeyRange(context.Context, *KeyRange, bool) (*RegionStats, error) GetStores(context.Context) (*StoresInfo, error) /* Config-related interfaces */ GetScheduleConfig(context.Context) (map[string]interface{}, error) @@ -399,10 +399,10 @@ func (c *client) GetHistoryHotRegions(ctx context.Context, req *HistoryHotRegion // GetRegionStatusByKeyRange gets the region status by key range. // The keys in the key range should be encoded in the UTF-8 bytes format. -func (c *client) GetRegionStatusByKeyRange(ctx context.Context, keyRange *KeyRange) (*RegionStats, error) { +func (c *client) GetRegionStatusByKeyRange(ctx context.Context, keyRange *KeyRange, onlyCount bool) (*RegionStats, error) { var regionStats RegionStats err := c.requestWithRetry(ctx, - "GetRegionStatusByKeyRange", RegionStatsByKeyRange(keyRange), + "GetRegionStatusByKeyRange", RegionStatsByKeyRange(keyRange, onlyCount), http.MethodGet, http.NoBody, ®ionStats, ) if err != nil { diff --git a/tests/integrations/client/http_client_test.go b/tests/integrations/client/http_client_test.go index 0cc3ceaa06a..500da8c6ace 100644 --- a/tests/integrations/client/http_client_test.go +++ b/tests/integrations/client/http_client_test.go @@ -105,9 +105,14 @@ func (suite *httpClientTestSuite) TestMeta() { re.NoError(err) re.Equal(int64(2), regions.Count) re.Len(regions.Regions, 2) - regionStats, err := suite.client.GetRegionStatusByKeyRange(suite.ctx, pd.NewKeyRange([]byte("a1"), []byte("a3"))) + regionStats, err := suite.client.GetRegionStatusByKeyRange(suite.ctx, pd.NewKeyRange([]byte("a1"), []byte("a3")), false) re.NoError(err) - re.Equal(2, regionStats.Count) + re.Greater(regionStats.Count, 0) + re.NotEmpty(regionStats.StoreLeaderCount) + regionStats, err = suite.client.GetRegionStatusByKeyRange(suite.ctx, pd.NewKeyRange([]byte("a1"), []byte("a3")), true) + re.NoError(err) + re.Greater(regionStats.Count, 0) + re.Empty(regionStats.StoreLeaderCount) hotReadRegions, err := suite.client.GetHotReadRegions(suite.ctx) re.NoError(err) re.Len(hotReadRegions.AsPeer, 1) From 2384f38ca62398816f283a68a96cbd6688e82d52 Mon Sep 17 00:00:00 2001 From: JmPotato Date: Wed, 29 Nov 2023 16:01:21 +0800 Subject: [PATCH 4/6] Add GetRegionsReplicatedStateByKeyRange Signed-off-by: JmPotato --- client/http/api.go | 8 ++++++++ client/http/client.go | 13 +++++++++++++ tests/integrations/client/http_client_test.go | 3 +++ 3 files changed, 24 insertions(+) diff --git a/client/http/api.go b/client/http/api.go index 7da60934fb2..f744fd0c395 100644 --- a/client/http/api.go +++ b/client/http/api.go @@ -31,6 +31,7 @@ const ( Regions = "/pd/api/v1/regions" regionsByKey = "/pd/api/v1/regions/key" RegionsByStoreIDPrefix = "/pd/api/v1/regions/store" + regionsReplicated = "/pd/api/v1/regions/replicated" EmptyRegions = "/pd/api/v1/regions/check/empty-region" AccelerateSchedule = "/pd/api/v1/regions/accelerate-schedule" AccelerateScheduleInBatch = "/pd/api/v1/regions/accelerate-schedule/batch" @@ -95,6 +96,13 @@ func RegionsByStoreID(storeID uint64) string { return fmt.Sprintf("%s/%d", RegionsByStoreIDPrefix, storeID) } +// RegionsReplicatedByKeyRange returns the path of PD HTTP API to get replicated regions with given start key and end key. +func RegionsReplicatedByKeyRange(keyRange *KeyRange) string { + startKeyStr, endKeyStr := keyRange.EscapeAsHexStr() + return fmt.Sprintf("%s?startKey=%s&endKey=%s", + regionsReplicated, startKeyStr, endKeyStr) +} + // RegionStatsByKeyRange returns the path of PD HTTP API to get region stats by start key and end key. func RegionStatsByKeyRange(keyRange *KeyRange, onlyCount bool) string { startKeyStr, endKeyStr := keyRange.EscapeAsUTF8Str() diff --git a/client/http/client.go b/client/http/client.go index ffcdcb02d10..21b7bea20db 100644 --- a/client/http/client.go +++ b/client/http/client.go @@ -47,6 +47,7 @@ type Client interface { GetRegions(context.Context) (*RegionsInfo, error) GetRegionsByKeyRange(context.Context, *KeyRange, int) (*RegionsInfo, error) GetRegionsByStoreID(context.Context, uint64) (*RegionsInfo, error) + GetRegionsReplicatedStateByKeyRange(context.Context, *KeyRange) (string, error) GetHotReadRegions(context.Context) (*StoreHotPeersInfos, error) GetHotWriteRegions(context.Context) (*StoreHotPeersInfos, error) GetHistoryHotRegions(context.Context, *HistoryHotRegionsRequest) (*HistoryHotRegions, error) @@ -356,6 +357,18 @@ func (c *client) GetRegionsByStoreID(ctx context.Context, storeID uint64) (*Regi return ®ions, nil } +// GetRegionsReplicatedStateByKeyRange gets the regions replicated state info by key range. +func (c *client) GetRegionsReplicatedStateByKeyRange(ctx context.Context, keyRange *KeyRange) (string, error) { + var state string + err := c.requestWithRetry(ctx, + "GetRegionsReplicatedStateByKeyRange", RegionsReplicatedByKeyRange(keyRange), + http.MethodGet, http.NoBody, &state) + if err != nil { + return "", err + } + return state, nil +} + // GetHotReadRegions gets the hot read region statistics info. func (c *client) GetHotReadRegions(ctx context.Context) (*StoreHotPeersInfos, error) { var hotReadRegions StoreHotPeersInfos diff --git a/tests/integrations/client/http_client_test.go b/tests/integrations/client/http_client_test.go index 500da8c6ace..6c636d2a2a1 100644 --- a/tests/integrations/client/http_client_test.go +++ b/tests/integrations/client/http_client_test.go @@ -105,6 +105,9 @@ func (suite *httpClientTestSuite) TestMeta() { re.NoError(err) re.Equal(int64(2), regions.Count) re.Len(regions.Regions, 2) + state, err := suite.client.GetRegionsReplicatedStateByKeyRange(suite.ctx, pd.NewKeyRange([]byte("a1"), []byte("a3"))) + re.NoError(err) + re.Equal("INPROGRESS", state) regionStats, err := suite.client.GetRegionStatusByKeyRange(suite.ctx, pd.NewKeyRange([]byte("a1"), []byte("a3")), false) re.NoError(err) re.Greater(regionStats.Count, 0) From a7f5fb63923cc326bdac6288883b6de115af89e6 Mon Sep 17 00:00:00 2001 From: JmPotato Date: Wed, 29 Nov 2023 17:01:39 +0800 Subject: [PATCH 5/6] Add TestRuleOpStartEndKey Signed-off-by: JmPotato --- client/http/types_test.go | 88 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 88 insertions(+) diff --git a/client/http/types_test.go b/client/http/types_test.go index c43a5453646..74482e29c3c 100644 --- a/client/http/types_test.go +++ b/client/http/types_test.go @@ -110,3 +110,91 @@ func mustMarshalAndUnmarshal(re *require.Assertions, rule *Rule) *Rule { re.NoError(err) return newRule } + +func TestRuleOpStartEndKey(t *testing.T) { + re := require.New(t) + // Empty start/end key and key hex. + ruleOpToMarshal := &RuleOp{ + Rule: &Rule{}, + } + ruleOp := mustMarshalAndUnmarshalRuleOp(re, ruleOpToMarshal) + re.Equal("", ruleOp.StartKeyHex) + re.Equal("", ruleOp.EndKeyHex) + re.Equal([]byte(""), ruleOp.StartKey) + re.Equal([]byte(""), ruleOp.EndKey) + // Empty start/end key and non-empty key hex. + ruleOpToMarshal = &RuleOp{ + Rule: &Rule{ + StartKeyHex: rawKeyToKeyHexStr([]byte("a")), + EndKeyHex: rawKeyToKeyHexStr([]byte("b")), + }, + Action: RuleOpAdd, + DeleteByIDPrefix: true, + } + ruleOp = mustMarshalAndUnmarshalRuleOp(re, ruleOpToMarshal) + re.Equal([]byte("a"), ruleOp.StartKey) + re.Equal([]byte("b"), ruleOp.EndKey) + re.Equal(ruleOpToMarshal.StartKeyHex, ruleOp.StartKeyHex) + re.Equal(ruleOpToMarshal.EndKeyHex, ruleOp.EndKeyHex) + re.Equal(ruleOpToMarshal.Action, ruleOp.Action) + re.Equal(ruleOpToMarshal.DeleteByIDPrefix, ruleOp.DeleteByIDPrefix) + // Non-empty start/end key and empty key hex. + ruleOpToMarshal = &RuleOp{ + Rule: &Rule{ + StartKey: []byte("a"), + EndKey: []byte("b"), + }, + Action: RuleOpAdd, + DeleteByIDPrefix: true, + } + ruleOp = mustMarshalAndUnmarshalRuleOp(re, ruleOpToMarshal) + re.Equal(ruleOpToMarshal.StartKey, ruleOp.StartKey) + re.Equal(ruleOpToMarshal.EndKey, ruleOp.EndKey) + re.Equal(rawKeyToKeyHexStr(ruleOpToMarshal.StartKey), ruleOp.StartKeyHex) + re.Equal(rawKeyToKeyHexStr(ruleOpToMarshal.EndKey), ruleOp.EndKeyHex) + re.Equal(ruleOpToMarshal.Action, ruleOp.Action) + re.Equal(ruleOpToMarshal.DeleteByIDPrefix, ruleOp.DeleteByIDPrefix) + // Non-empty start/end key and non-empty key hex. + ruleOpToMarshal = &RuleOp{ + Rule: &Rule{ + StartKey: []byte("a"), + EndKey: []byte("b"), + StartKeyHex: rawKeyToKeyHexStr([]byte("c")), + EndKeyHex: rawKeyToKeyHexStr([]byte("d")), + }, + Action: RuleOpAdd, + DeleteByIDPrefix: true, + } + ruleOp = mustMarshalAndUnmarshalRuleOp(re, ruleOpToMarshal) + re.Equal([]byte("c"), ruleOp.StartKey) + re.Equal([]byte("d"), ruleOp.EndKey) + re.Equal(ruleOpToMarshal.StartKeyHex, ruleOp.StartKeyHex) + re.Equal(ruleOpToMarshal.EndKeyHex, ruleOp.EndKeyHex) + re.Equal(ruleOpToMarshal.Action, ruleOp.Action) + re.Equal(ruleOpToMarshal.DeleteByIDPrefix, ruleOp.DeleteByIDPrefix) + // Half of each pair of keys is empty. + ruleOpToMarshal = &RuleOp{ + Rule: &Rule{ + StartKey: []byte("a"), + EndKeyHex: rawKeyToKeyHexStr([]byte("d")), + }, + Action: RuleOpDel, + DeleteByIDPrefix: false, + } + ruleOp = mustMarshalAndUnmarshalRuleOp(re, ruleOpToMarshal) + re.Equal(ruleOpToMarshal.StartKey, ruleOp.StartKey) + re.Equal([]byte("d"), ruleOp.EndKey) + re.Equal(rawKeyToKeyHexStr(ruleOpToMarshal.StartKey), ruleOp.StartKeyHex) + re.Equal(ruleOpToMarshal.EndKeyHex, ruleOp.EndKeyHex) + re.Equal(ruleOpToMarshal.Action, ruleOp.Action) + re.Equal(ruleOpToMarshal.DeleteByIDPrefix, ruleOp.DeleteByIDPrefix) +} + +func mustMarshalAndUnmarshalRuleOp(re *require.Assertions, ruleOp *RuleOp) *RuleOp { + ruleOpJSON, err := json.Marshal(ruleOp) + re.NoError(err) + var newRuleOp *RuleOp + err = json.Unmarshal(ruleOpJSON, &newRuleOp) + re.NoError(err) + return newRuleOp +} From a98d738338a6e9906c98232d202713af8aaa5041 Mon Sep 17 00:00:00 2001 From: JmPotato Date: Thu, 30 Nov 2023 15:56:46 +0800 Subject: [PATCH 6/6] Address the comments Signed-off-by: JmPotato --- client/http/client.go | 2 ++ client/http/codec.go | 2 +- server/api/stats.go | 5 +++-- server/api/store.go | 16 ++++++++-------- 4 files changed, 14 insertions(+), 11 deletions(-) diff --git a/client/http/client.go b/client/http/client.go index 21b7bea20db..36355a90d19 100644 --- a/client/http/client.go +++ b/client/http/client.go @@ -358,6 +358,7 @@ func (c *client) GetRegionsByStoreID(ctx context.Context, storeID uint64) (*Regi } // GetRegionsReplicatedStateByKeyRange gets the regions replicated state info by key range. +// The keys in the key range should be encoded in the hex bytes format (without encoding to the UTF-8 bytes). func (c *client) GetRegionsReplicatedStateByKeyRange(ctx context.Context, keyRange *KeyRange) (string, error) { var state string err := c.requestWithRetry(ctx, @@ -411,6 +412,7 @@ func (c *client) GetHistoryHotRegions(ctx context.Context, req *HistoryHotRegion } // GetRegionStatusByKeyRange gets the region status by key range. +// If the `onlyCount` flag is true, the result will only include the count of regions. // The keys in the key range should be encoded in the UTF-8 bytes format. func (c *client) GetRegionStatusByKeyRange(ctx context.Context, keyRange *KeyRange, onlyCount bool) (*RegionStats, error) { var regionStats RegionStats diff --git a/client/http/codec.go b/client/http/codec.go index cac1a89fd87..26be64b4f28 100644 --- a/client/http/codec.go +++ b/client/http/codec.go @@ -100,7 +100,7 @@ func decodeBytes(b []byte) ([]byte, error) { return buf, nil } -// keyToKeyHexStr converts a raw key to a hex string after encoding. +// rawKeyToKeyHexStr converts a raw key to a hex string after encoding. func rawKeyToKeyHexStr(key []byte) string { if len(key) == 0 { return "" diff --git a/server/api/stats.go b/server/api/stats.go index 1798597b6cc..915d33ddfdf 100644 --- a/server/api/stats.go +++ b/server/api/stats.go @@ -36,8 +36,9 @@ func newStatsHandler(svr *server.Server, rd *render.Render) *statsHandler { // @Tags stats // @Summary Get region statistics of a specified range. -// @Param start_key query string true "Start key" -// @Param end_key query string true "End key" +// @Param start_key query string true "Start key" +// @Param end_key query string true "End key" +// @Param count query bool false "Whether only count the number of regions" // @Produce json // @Success 200 {object} statistics.RegionStats // @Router /stats/region [get] diff --git a/server/api/store.go b/server/api/store.go index a44850d35cc..8537cd45c5b 100644 --- a/server/api/store.go +++ b/server/api/store.go @@ -172,14 +172,14 @@ func newStoreHandler(handler *server.Handler, rd *render.Render) *storeHandler { } } -// @Tags store +// @Tags store // @Summary Get a store's information. // @Param id path integer true "Store Id" -// @Produce json +// @Produce json // @Success 200 {object} StoreInfo // @Failure 400 {string} string "The input is invalid." // @Failure 404 {string} string "The store does not exist." -// @Failure 500 {string} string "PD server failed to proceed the request." +// @Failure 500 {string} string "PD server failed to proceed the request." // @Router /store/{id} [get] func (h *storeHandler) GetStore(w http.ResponseWriter, r *http.Request) { rc := getCluster(r) @@ -735,13 +735,13 @@ func (h *storesHandler) GetStoresProgress(w http.ResponseWriter, r *http.Request } // @Tags store -// @Summary Get all stores in the cluster. -// @Param state query array true "Specify accepted store states." +// @Summary Get all stores in the cluster. +// @Param state query array true "Specify accepted store states." // @Produce json -// @Success 200 {object} StoresInfo +// @Success 200 {object} StoresInfo // @Failure 500 {string} string "PD server failed to proceed the request." -// @Router /stores [get] -// @Deprecated Better to use /stores/check instead. +// @Router /stores [get] +// @Deprecated Better to use /stores/check instead. func (h *storesHandler) GetAllStores(w http.ResponseWriter, r *http.Request) { rc := getCluster(r) stores := rc.GetMetaStores()