diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 8c4fb08..e3e82e0 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -11,11 +11,11 @@ jobs: name: Build and Publish runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 with: fetch-depth: 0 - - uses: actions/setup-go@v1 + - uses: actions/setup-go@v5 with: go-version: 1.21 diff --git a/tools/gen_registry/main.go b/tools/gen_registry/main.go new file mode 100644 index 0000000..d9afd2a --- /dev/null +++ b/tools/gen_registry/main.go @@ -0,0 +1,59 @@ +package main + +import ( + "fmt" + "go/format" + "os" + "path/filepath" + "regexp" + "strings" +) + +func main() { + root, err := os.Getwd() + if err != nil { + panic(err) + } + + registryPath := filepath.Join(root, "registry_types.go") + content, err := os.ReadFile(registryPath) + if err != nil { + panic(err) + } + + re := regexp.MustCompile(`"([A-Za-z0-9_]+)"`) + matches := re.FindAllStringSubmatch(string(content), -1) + if len(matches) == 0 { + panic("no type names found in registry_types.go") + } + + seen := map[string]bool{} + names := make([]string, 0, len(matches)) + for _, match := range matches { + name := match[1] + if seen[name] { + continue + } + seen[name] = true + names = append(names, name) + } + + var out strings.Builder + out.WriteString("// Code generated by tools/gen_registry; DO NOT EDIT.\n\n") + out.WriteString("package types\n\n") + out.WriteString("var baseCodecFactories = map[string]CodecFactory{\n") + for _, name := range names { + out.WriteString(fmt.Sprintf("\t%q: func() Decoder { return &%s{} },\n", strings.ToLower(name), name)) + } + out.WriteString("}\n") + + formatted, err := format.Source([]byte(out.String())) + if err != nil { + panic(err) + } + + outPath := filepath.Join(root, "registry_gen.go") + if err := os.WriteFile(outPath, formatted, 0o644); err != nil { + panic(err) + } +} diff --git a/types/Bool.go b/types/Bool.go index 0a9aa45..7120d1b 100644 --- a/types/Bool.go +++ b/types/Bool.go @@ -8,8 +8,12 @@ func (b *Bool) Process() { b.Value = b.getNextBool() } -func (b *Bool) Encode(value bool) string { - if value { +func (b *Bool) Encode(value interface{}) string { + v, ok := value.(bool) + if !ok { + panic("invalid bool input") + } + if v { return "01" } return "00" diff --git a/types/Bytes.go b/types/Bytes.go index 4c8eb91..3fe7dfd 100644 --- a/types/Bytes.go +++ b/types/Bytes.go @@ -17,18 +17,22 @@ func (b *Bytes) Process() { } } -func (b *Bytes) Encode(value string) string { +func (b *Bytes) Encode(value interface{}) string { + valueStr, ok := value.(string) + if !ok { + panic("invalid bytes input") + } var bytes []byte - if strings.HasPrefix(value, "0x") { - value = utiles.TrimHex(value) - if len(value)%2 == 1 { - value += "0" + if strings.HasPrefix(valueStr, "0x") { + valueStr = utiles.TrimHex(valueStr) + if len(valueStr)%2 == 1 { + valueStr += "0" } } else { - value = utiles.BytesToHex([]byte(value)) + valueStr = utiles.BytesToHex([]byte(valueStr)) } - bytes = utiles.HexToBytes(value) - return Encode("Compact", len(bytes)) + value + bytes = utiles.HexToBytes(valueStr) + return Encode("Compact", len(bytes)) + valueStr } func (b *Bytes) TypeStructString() string { @@ -41,13 +45,17 @@ func (h *HexBytes) Process() { h.Value = utiles.AddHex(utiles.BytesToHex(h.NextBytes(h.ProcessAndUpdateData("Compact").(int)))) } -func (h *HexBytes) Encode(value string) string { - value = utiles.TrimHex(value) - if len(value)%2 == 1 { - value += "0" +func (h *HexBytes) Encode(value interface{}) string { + valueStr, ok := value.(string) + if !ok { + panic("invalid hexbytes input") + } + valueStr = utiles.TrimHex(valueStr) + if len(valueStr)%2 == 1 { + valueStr += "0" } - bytes := utiles.HexToBytes(value) - return Encode("Compact", len(bytes)) + value + bytes := utiles.HexToBytes(valueStr) + return Encode("Compact", len(bytes)) + valueStr } func (h *HexBytes) TypeStructString() string { @@ -56,8 +64,12 @@ func (h *HexBytes) TypeStructString() string { type String struct{ Bytes } -func (s *String) Encode(value string) string { - bytes := []byte(value) +func (s *String) Encode(value interface{}) string { + valueStr, ok := value.(string) + if !ok { + panic("invalid string input") + } + bytes := []byte(valueStr) return Encode("Compact", len(bytes)) + utiles.BytesToHex(bytes) } diff --git a/types/FixedArray.go b/types/FixedArray.go index 370d78c..afcf799 100644 --- a/types/FixedArray.go +++ b/types/FixedArray.go @@ -2,7 +2,6 @@ package types import ( "fmt" - "reflect" "strings" "github.com/itering/scale.go/types/scaleBytes" @@ -49,31 +48,29 @@ func (f *FixedArray) TypeStructString() string { func (f *FixedArray) Encode(value interface{}) string { var raw string - if reflect.TypeOf(value).Kind() == reflect.String && value.(string) == "" { - return "" - } - switch reflect.TypeOf(value).Kind() { - case reflect.Slice: - s := reflect.ValueOf(value) - if s.Len() != f.FixedLength { - panic("fixed length not match") - } - subType := f.SubType - for i := 0; i < s.Len(); i++ { - raw += EncodeWithOpt(subType, s.Index(i).Interface(), &ScaleDecoderOption{Spec: f.Spec, Metadata: f.Metadata}) + if valueStr, ok := value.(string); ok { + if valueStr == "" { + return "" } - return raw - case reflect.String: - valueStr := value.(string) if strings.HasPrefix(valueStr, "0x") { return utiles.TrimHex(valueStr) } else { return utiles.BytesToHex([]byte(valueStr)) } - default: - if f.FixedLength == 1 { - return EncodeWithOpt(f.SubType, value, &ScaleDecoderOption{Spec: f.Spec, Metadata: f.Metadata}) + } + values, ok := asInterfaceSlice(value) + if ok { + if len(values) != f.FixedLength { + panic("fixed length not match") + } + subType := f.SubType + for _, item := range values { + raw += EncodeWithOpt(subType, item, &ScaleDecoderOption{Spec: f.Spec, Metadata: f.Metadata}) } - panic(fmt.Errorf("invalid vec input")) + return raw + } + if f.FixedLength == 1 { + return EncodeWithOpt(f.SubType, value, &ScaleDecoderOption{Spec: f.Spec, Metadata: f.Metadata}) } + panic(fmt.Errorf("invalid fixed array input: expected fixed length %d with subtype %q, got value of type %T", f.FixedLength, f.SubType, value)) } diff --git a/types/Option.go b/types/Option.go index 028d690..837dcb3 100644 --- a/types/Option.go +++ b/types/Option.go @@ -25,9 +25,27 @@ func (o *Option) Encode(value interface{}) string { if v, ok := value.(string); ok && v == "" { return "00" } - if utiles.IsNil(value) { + if value == nil { return "00" } + switch v := value.(type) { + case []byte: + if v == nil { + return "00" + } + case []interface{}: + if v == nil { + return "00" + } + case map[string]interface{}: + if v == nil { + return "00" + } + case error: + if v == nil { + return "00" + } + } if o.SubType == "bool" { if value.(bool) { return "01" diff --git a/types/Results.go b/types/Results.go index 7eb9977..4c05828 100644 --- a/types/Results.go +++ b/types/Results.go @@ -26,15 +26,19 @@ func (b *Result) Process() { } } -func (b *Result) Encode(value map[string]interface{}) string { +func (b *Result) Encode(value interface{}) string { + typed, ok := value.(map[string]interface{}) + if !ok { + panic("invalid Result input") + } subType := strings.Split(b.SubType, ",") if len(subType) != 2 { panic("Result subType not illegal") } - if data, ok := value["Ok"]; ok { + if data, ok := typed["Ok"]; ok { return "00" + EncodeWithOpt(subType[0], data, &ScaleDecoderOption{Spec: b.Spec, Metadata: b.Metadata}) } - if data, ok := value["Error"]; ok { + if data, ok := typed["Error"]; ok { return "01" + EncodeWithOpt(subType[1], data, &ScaleDecoderOption{Spec: b.Spec, Metadata: b.Metadata}) } panic("illegal Result data") diff --git a/types/Struct.go b/types/Struct.go index 118b932..8e2bc54 100644 --- a/types/Struct.go +++ b/types/Struct.go @@ -24,11 +24,15 @@ func (s *Struct) Process() { s.Value = result } -func (s *Struct) Encode(value map[string]interface{}) string { +func (s *Struct) Encode(value interface{}) string { var raw string + typed, ok := value.(map[string]interface{}) + if !ok { + panic("invalid struct input") + } if s.TypeMapping != nil { for k, v := range s.TypeMapping.Names { - raw += EncodeWithOpt(s.TypeMapping.Types[k], value[v], &ScaleDecoderOption{Spec: s.Spec, Metadata: s.Metadata}) + raw += EncodeWithOpt(s.TypeMapping.Types[k], typed[v], &ScaleDecoderOption{Spec: s.Spec, Metadata: s.Metadata}) } } return raw diff --git a/types/Uint.go b/types/Uint.go index 9f834e9..83fcbbb 100644 --- a/types/Uint.go +++ b/types/Uint.go @@ -1,12 +1,10 @@ package types import ( - "bytes" "encoding/binary" "fmt" "io" "math/big" - "reflect" "strings" "github.com/huandu/xstrings" @@ -53,13 +51,10 @@ type U16 struct { } func (u *U16) Process() { - buf := &bytes.Buffer{} - var reader io.Reader - reader = buf - _, _ = buf.Write(u.NextBytes(2)) - c := make([]byte, 2) - _, _ = reader.Read(c) - u.Value = binary.LittleEndian.Uint16(c) + data := u.NextBytes(2) + var c [2]byte + copy(c[:], data) + u.Value = binary.LittleEndian.Uint16(c[:]) } func (u *U16) Encode(value interface{}) string { @@ -78,9 +73,9 @@ func (u *U16) Encode(value interface{}) string { case float64: u16 = uint16(v) } - bs := make([]byte, 2) - binary.LittleEndian.PutUint16(bs, u16) - return utiles.BytesToHex(bs) + var bs [2]byte + binary.LittleEndian.PutUint16(bs[:], u16) + return utiles.BytesToHex(bs[:]) } func (u *U16) TypeStructString() string { @@ -92,13 +87,10 @@ type U32 struct { } func (u *U32) Process() { - buf := &bytes.Buffer{} - var reader io.Reader - reader = buf - _, _ = buf.Write(u.NextBytes(4)) - c := make([]byte, 4) - _, _ = reader.Read(c) - u.Value = binary.LittleEndian.Uint32(c) + data := u.NextBytes(4) + var c [4]byte + copy(c[:], data) + u.Value = binary.LittleEndian.Uint32(c[:]) } func (u *U32) Encode(value interface{}) string { @@ -113,9 +105,9 @@ func (u *U32) Encode(value interface{}) string { case float64: u32 = uint32(v) } - bs := make([]byte, 4) - binary.LittleEndian.PutUint32(bs, u32) - return utiles.BytesToHex(bs) + var bs [4]byte + binary.LittleEndian.PutUint32(bs[:], u32) + return utiles.BytesToHex(bs[:]) } func (u *U32) TypeStructString() string { @@ -128,12 +120,10 @@ type U64 struct { } func (u *U64) Process() { - buf := &bytes.Buffer{} - u.Reader = buf - _, _ = buf.Write(u.NextBytes(8)) - c := make([]byte, 8) - _, _ = u.Reader.Read(c) - u.Value = binary.LittleEndian.Uint64(c) + data := u.NextBytes(8) + var c [8]byte + copy(c[:], data) + u.Value = binary.LittleEndian.Uint64(c[:]) } func (u *U64) Encode(value interface{}) string { @@ -150,9 +140,9 @@ func (u *U64) Encode(value interface{}) string { case float64: u64 = uint64(v) } - bs := make([]byte, 8) - binary.LittleEndian.PutUint64(bs, u64) - return utiles.BytesToHex(bs) + var bs [8]byte + binary.LittleEndian.PutUint64(bs[:], u64) + return utiles.BytesToHex(bs[:]) } func (u *U64) TypeStructString() string { @@ -208,33 +198,38 @@ func (u *U256) Process() { func (u *U256) Encode(value interface{}) string { var raw string - if reflect.TypeOf(value).Kind() == reflect.String && value.(string) == "" { + switch v := value.(type) { + case nil: return "" - } - switch reflect.TypeOf(value).String() { - case reflect.Slice.String(): - s := reflect.ValueOf(value) - if s.Len() != 32 { - panic("fixed length not match") - } - for i := 0; i < s.Len(); i++ { - raw += EncodeWithOpt("U8", s.Index(i).Interface(), nil) + case string: + if v == "" { + return "" } - return raw - case reflect.String.String(): - valueStr := value.(string) - if strings.HasPrefix(valueStr, "0x") { - return utiles.TrimHex(valueStr) + if strings.HasPrefix(v, "0x") { + return utiles.TrimHex(v) } else { - return utiles.BytesToHex([]byte(valueStr)) + return utiles.BytesToHex([]byte(v)) } - case "decimal.Decimal": - value = value.(decimal.Decimal).BigInt() - fallthrough - case "*big.Int": - bigVal := fmt.Sprintf("%064s", value.(*big.Int).Text(16)) + case decimal.Decimal: + bigVal := fmt.Sprintf("%064s", v.BigInt().Text(16)) + return utiles.BytesToHex(utiles.ReverseBytes(utiles.HexToBytes(bigVal))) + case *big.Int: + if v == nil { + return "" + } + bigVal := fmt.Sprintf("%064s", v.Text(16)) return utiles.BytesToHex(utiles.ReverseBytes(utiles.HexToBytes(bigVal))) default: - panic(fmt.Errorf("invalid vec input")) + values, ok := asInterfaceSlice(value) + if !ok { + panic(fmt.Errorf("invalid U256 input: expected slice-like value, got %T (%v)", value, value)) + } + if len(values) != 32 { + panic("fixed length not match") + } + for _, item := range values { + raw += EncodeWithOpt("U8", item, nil) + } + return raw } } diff --git a/types/Vectors.go b/types/Vectors.go index 7055d58..5243397 100644 --- a/types/Vectors.go +++ b/types/Vectors.go @@ -2,7 +2,6 @@ package types import ( "fmt" - "reflect" "strings" "github.com/itering/scale.go/types/scaleBytes" @@ -37,20 +36,24 @@ func (v *Vec) Process() { func (v *Vec) Encode(value interface{}) string { var raw string - if reflect.TypeOf(value).Kind() == reflect.String && value.(string) == "" { + if v, ok := value.(string); ok { + if v == "" { + return Encode("Compact", 0) + } + panic(fmt.Errorf("invalid vec input")) + } + if value == nil { return Encode("Compact", 0) } - switch reflect.TypeOf(value).Kind() { - case reflect.Slice: - s := reflect.ValueOf(value) - raw += Encode("Compact", s.Len()) - for i := 0; i < s.Len(); i++ { - raw += utiles.TrimHex(EncodeWithOpt(v.SubType, s.Index(i).Interface(), &ScaleDecoderOption{Spec: v.Spec, Metadata: v.Metadata})) - } - return raw - default: + values, ok := asInterfaceSlice(value) + if !ok { panic(fmt.Errorf("invalid vec input")) } + raw += Encode("Compact", len(values)) + for _, item := range values { + raw += utiles.TrimHex(EncodeWithOpt(v.SubType, item, &ScaleDecoderOption{Spec: v.Spec, Metadata: v.Metadata})) + } + return raw } func (v *Vec) TypeStructString() string { diff --git a/types/base.go b/types/base.go index dca2f7a..feaac0f 100644 --- a/types/base.go +++ b/types/base.go @@ -1,8 +1,8 @@ package types import ( + "encoding/binary" "fmt" - "reflect" "regexp" "strings" @@ -41,12 +41,21 @@ type AdditionalSigned struct { Type string `json:"type"` } -type IScaleDecoder interface { +type Decoder interface { Init(data scaleBytes.ScaleBytes, option *ScaleDecoderOption) Process() - Encode(interface{}) string - TypeStructString() string + GetData() scaleBytes.ScaleBytes + GetInternalCall() []string + GetValue() interface{} +} + +type TypeMappingGetter interface { + GetTypeMapping() *TypeMapping +} + +type Encoder interface { + Encode(interface{}) string } type ScaleDecoder struct { @@ -102,6 +111,22 @@ func (s *ScaleDecoder) Encode(interface{}) string { panic(fmt.Sprintf("not found base type %s", s.TypeName)) } +func (s *ScaleDecoder) GetData() scaleBytes.ScaleBytes { + return s.Data +} + +func (s *ScaleDecoder) GetInternalCall() []string { + return s.InternalCall +} + +func (s *ScaleDecoder) GetValue() interface{} { + return s.Value +} + +func (s *ScaleDecoder) GetTypeMapping() *TypeMapping { + return s.TypeMapping +} + // TypeStructString Type Struct string func (s *ScaleDecoder) TypeStructString() string { return s.TypeName @@ -157,33 +182,75 @@ func (s *ScaleDecoder) buildStruct() { func (s *ScaleDecoder) ProcessAndUpdateData(typeString string) interface{} { r := RuntimeType{Module: s.Module} + if value, ok := s.fastProcess(typeString); ok { + return value + } - class, value, subType := r.GetCodecClass(typeString, s.Spec) - if class == nil { + decoder, subType, err := r.GetCodec(typeString, s.Spec) + if err != nil { panic(fmt.Sprintf("Not found decoder class %s", typeString)) } offsetStart := s.Data.Offset // init - method, exist := class.MethodByName("Init") - if !exist { - panic(fmt.Sprintf("%s not implement init function", typeString)) - } option := ScaleDecoderOption{SubType: subType, Spec: s.Spec, Metadata: s.Metadata, Module: s.Module, TypeName: typeString} - method.Func.Call([]reflect.Value{value, reflect.ValueOf(s.Data), reflect.ValueOf(&option)}) + decoder.Init(s.Data, &option) // process do decode - value.MethodByName("Process").Call(nil) - elementData := value.Elem().FieldByName("Data").Interface().(scaleBytes.ScaleBytes) - if internalCall := value.Elem().FieldByName("InternalCall").Interface().([]string); len(internalCall) > 0 { + decoder.Process() + elementData := decoder.GetData() + if internalCall := decoder.GetInternalCall(); len(internalCall) > 0 { s.InternalCall = append(s.InternalCall, internalCall...) } s.Data.Offset = elementData.Offset s.Data.Data = elementData.Data s.RawValue = utiles.BytesToHex(s.Data.Data[offsetStart:s.Data.Offset]) - return value.Elem().FieldByName("Value").Interface() + return decoder.GetValue() +} + +func (s *ScaleDecoder) fastProcess(typeString string) (interface{}, bool) { + switch strings.ToLower(typeString) { + case "u8": + offsetStart := s.Data.Offset + data := s.Data.GetNextBytes(1) + var value int + if len(data) > 0 { + value = int(data[0]) + } + s.RawValue = utiles.BytesToHex(s.Data.Data[offsetStart:s.Data.Offset]) + return value, true + case "u16": + offsetStart := s.Data.Offset + data := s.Data.GetNextBytes(2) + var c [2]byte + copy(c[:], data) + s.RawValue = utiles.BytesToHex(s.Data.Data[offsetStart:s.Data.Offset]) + return binary.LittleEndian.Uint16(c[:]), true + case "u32": + offsetStart := s.Data.Offset + data := s.Data.GetNextBytes(4) + var c [4]byte + copy(c[:], data) + s.RawValue = utiles.BytesToHex(s.Data.Data[offsetStart:s.Data.Offset]) + return binary.LittleEndian.Uint32(c[:]), true + case "u64": + offsetStart := s.Data.Offset + data := s.Data.GetNextBytes(8) + var c [8]byte + copy(c[:], data) + s.RawValue = utiles.BytesToHex(s.Data.Data[offsetStart:s.Data.Offset]) + return binary.LittleEndian.Uint64(c[:]), true + case "bool": + offsetStart := s.Data.Offset + data := s.Data.GetNextBytes(1) + value := len(data) > 0 && data[0] == 1 + s.RawValue = utiles.BytesToHex(s.Data.Data[offsetStart:s.Data.Offset]) + return value, true + default: + return nil, false + } } func Encode(typeString string, data interface{}) string { @@ -199,24 +266,21 @@ func EncodeWithOpt(typeString string, data interface{}, opt *ScaleDecoderOption) opt = &ScaleDecoderOption{Spec: -1} } opt.TypeName = typeString - class, value, subType := r.GetCodecClass(typeString, opt.Spec) - if class == nil { + decoder, subType, err := r.GetCodec(typeString, opt.Spec) + if err != nil { panic(fmt.Sprintf("Not found decoder class %s", typeString)) } - method, _ := class.MethodByName("Init") opt.SubType = subType - method.Func.Call([]reflect.Value{value, reflect.ValueOf(scaleBytes.EmptyScaleBytes()), reflect.ValueOf(opt)}) - var val reflect.Value - if data == nil { - val = reflect.New(reflect.TypeOf("")).Elem() - } else { - val = reflect.ValueOf(data) + decoder.Init(scaleBytes.EmptyScaleBytes(), opt) + dataVal := data + if dataVal == nil { + dataVal = "" } - out := value.MethodByName("Encode").Call([]reflect.Value{val}) - if len(out) > 0 { - return utiles.TrimHex(strings.ToLower(out[0].String())) + encoder, ok := decoder.(Encoder) + if !ok { + panic(fmt.Sprintf("%s not implement Encode function", typeString)) } - return "" + return utiles.TrimHex(strings.ToLower(encoder.Encode(dataVal))) } func EqTypeStringWithTypeStruct(typeString string, dest *source.TypeStruct) bool { @@ -249,18 +313,13 @@ func EqTypeStringWithTypeStruct(typeString string, dest *source.TypeStruct) bool // getTypeStructString get type struct string func getTypeStructString(typeString string, recursiveTime int) string { r := RuntimeType{} - class, value, subType := r.GetCodecClass(typeString, 0) - if class == nil { + decoder, subType, err := r.GetCodec(typeString, 0) + if err != nil { return "" } - method, _ := class.MethodByName("Init") opt := &ScaleDecoderOption{SubType: subType, TypeName: typeString, recursiveTime: recursiveTime} - method.Func.Call([]reflect.Value{value, reflect.ValueOf(scaleBytes.EmptyScaleBytes()), reflect.ValueOf(opt)}) - typeNameValue := value.MethodByName("TypeStructString").Call(nil) - if len(typeNameValue) == 0 { - return "" - } - return typeNameValue[0].String() + decoder.Init(scaleBytes.EmptyScaleBytes(), opt) + return decoder.TypeStructString() } // Eq check type string is equal diff --git a/types/basic_types_test.go b/types/basic_types_test.go new file mode 100644 index 0000000..ae49349 --- /dev/null +++ b/types/basic_types_test.go @@ -0,0 +1,52 @@ +package types + +import ( + "testing" + + "github.com/itering/scale.go/types/scaleBytes" + "github.com/itering/scale.go/utiles" + "github.com/stretchr/testify/assert" +) + +func TestBytesAndHexBytes(t *testing.T) { + raw := "1054657374" + m := ScaleDecoder{} + m.Init(scaleBytes.ScaleBytes{Data: utiles.HexToBytes(raw)}, nil) + assert.Equal(t, "Test", m.ProcessAndUpdateData("Bytes").(string)) + assert.Equal(t, raw, Encode("Bytes", "Test")) + + hexRaw := "080102" + m.Init(scaleBytes.ScaleBytes{Data: utiles.HexToBytes(hexRaw)}, nil) + assert.Equal(t, "0x0102", m.ProcessAndUpdateData("HexBytes").(string)) + assert.Equal(t, hexRaw, Encode("HexBytes", "0x0102")) +} + +func TestOptionBool(t *testing.T) { + m := ScaleDecoder{} + m.Init(scaleBytes.ScaleBytes{Data: utiles.HexToBytes("01")}, &ScaleDecoderOption{SubType: "bool"}) + assert.Equal(t, true, m.ProcessAndUpdateData("Option").(bool)) + assert.Equal(t, "01", Encode("Option", true)) + assert.Equal(t, "02", Encode("Option", false)) + assert.Equal(t, "00", Encode("Option", nil)) +} + +func TestNull(t *testing.T) { + m := ScaleDecoder{} + m.Init(scaleBytes.ScaleBytes{Data: []byte{}}, nil) + assert.Nil(t, m.ProcessAndUpdateData("Null")) + assert.Equal(t, "", Encode("Null", nil)) +} + +func TestHashTypes(t *testing.T) { + h256Raw := "1111111111111111111111111111111111111111111111111111111111111111" + m := ScaleDecoder{} + m.Init(scaleBytes.ScaleBytes{Data: utiles.HexToBytes(h256Raw)}, nil) + assert.Equal(t, h256Raw, utiles.TrimHex(m.ProcessAndUpdateData("H256").(string))) + assert.Equal(t, h256Raw, Encode("H256", "0x"+h256Raw)) + + h512Raw := "2222222222222222222222222222222222222222222222222222222222222222" + + "2222222222222222222222222222222222222222222222222222222222222222" + m.Init(scaleBytes.ScaleBytes{Data: utiles.HexToBytes(h512Raw)}, nil) + assert.Equal(t, h512Raw, utiles.TrimHex(m.ProcessAndUpdateData("H512").(string))) + assert.Equal(t, h512Raw, Encode("H512", "0x"+h512Raw)) +} diff --git a/types/benchmark_test.go b/types/benchmark_test.go new file mode 100644 index 0000000..a5efd37 --- /dev/null +++ b/types/benchmark_test.go @@ -0,0 +1,40 @@ +package types + +import ( + "testing" + + "github.com/itering/scale.go/types/scaleBytes" + "github.com/itering/scale.go/utiles" + "github.com/shopspring/decimal" +) + +func BenchmarkDecodeU32(b *testing.B) { + data := utiles.HexToBytes("64000000") + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + decoder := ScaleDecoder{} + decoder.Init(scaleBytes.ScaleBytes{Data: data}, nil) + _ = decoder.ProcessAndUpdateData("U32") + } +} + +func BenchmarkDecodeRegistrationBalanceOf(b *testing.B) { + data := utiles.HexToBytes("04010000000200a0724e180900000000000000000000000d505552455354414b452d30310e507572655374616b65204c74641b68747470733a2f2f7777772e707572657374616b652e636f6d2f000000000d40707572657374616b65636f") + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + decoder := ScaleDecoder{} + decoder.Init(scaleBytes.ScaleBytes{Data: data}, nil) + _ = decoder.ProcessAndUpdateData("Registration") + } +} + +func BenchmarkEncodeCompactBalance(b *testing.B) { + value := decimal.NewFromInt32(750000000) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = Encode("Compact", value) + } +} diff --git a/types/customType.go b/types/customType.go index 5b03674..b95658b 100644 --- a/types/customType.go +++ b/types/customType.go @@ -21,6 +21,12 @@ func newStruct(names, typeString []string) *TypeMapping { } func RegCustomTypes(registry map[string]source.TypeStruct) { + resetCodecCacheOnExit := false + registerCustomKey := func(key string, factory CodecFactory) { + regCustomKey(key, factory, false) + resetCodecCacheOnExit = true + } + for key := range registry { typeStruct := registry[key] if typeStruct.V14 { @@ -47,7 +53,7 @@ func RegCustomTypes(registry map[string]source.TypeStruct) { instant := TypeRegistry[strings.ToLower(typeString)] TypeRegistryLock.RUnlock() if instant != nil { - regCustomKey(key, instant) + registerCustomKey(key, instant) continue } @@ -58,7 +64,7 @@ func RegCustomTypes(registry map[string]source.TypeStruct) { instant = TypeRegistry[strings.ToLower(explainedType.TypeString)] TypeRegistryLock.RUnlock() if instant != nil { - regCustomKey(key, instant) + registerCustomKey(key, instant) continue } } else { @@ -72,29 +78,44 @@ func RegCustomTypes(registry map[string]source.TypeStruct) { typeParts := reg.FindStringSubmatch(typeString) if len(typeParts) > 2 { if strings.EqualFold(typeParts[1], "vec") { - v := Vec{} - v.SubType = typeParts[2] - regCustomKey(key, &v) + subType := typeParts[2] + registerCustomKey(key, func() Decoder { + v := Vec{} + v.SubType = subType + return &v + }) continue } else if strings.EqualFold(typeParts[1], "option") { - v := Option{} - v.SubType = typeParts[2] - regCustomKey(key, &v) + subType := typeParts[2] + registerCustomKey(key, func() Decoder { + v := Option{} + v.SubType = subType + return &v + }) continue } else if strings.EqualFold(typeParts[1], "compact") { - v := Compact{} - v.SubType = typeParts[2] - regCustomKey(key, &v) + subType := typeParts[2] + registerCustomKey(key, func() Decoder { + v := Compact{} + v.SubType = subType + return &v + }) continue } else if strings.EqualFold(typeParts[1], "BTreeMap") { - v := BTreeMap{} - v.SubType = typeParts[2] - regCustomKey(key, &v) + subType := typeParts[2] + registerCustomKey(key, func() Decoder { + v := BTreeMap{} + v.SubType = subType + return &v + }) continue } else if strings.EqualFold(typeParts[1], "BTreeSet") { - v := BTreeSet{} - v.SubType = typeParts[2] - regCustomKey(key, &v) + subType := typeParts[2] + registerCustomKey(key, func() Decoder { + v := BTreeSet{} + v.SubType = subType + return &v + }) continue } } @@ -103,21 +124,25 @@ func RegCustomTypes(registry map[string]source.TypeStruct) { // Tuple if typeString != "()" && string(typeString[0]) == "(" && typeString[len(typeString)-1:] == ")" { - s := Struct{} - s.TypeString = typeString - s.buildStruct() - regCustomKey(key, &s) + subType := typeString + registerCustomKey(key, func() Decoder { + s := Struct{} + s.TypeString = subType + s.buildStruct() + return &s + }) continue } // Array if typeString != "[]" && string(typeString[0]) == "[" && typeString[len(typeString)-1:] == "]" { if typePart := strings.Split(typeString[1:len(typeString)-1], ";"); len(typePart) == 2 { - fixed := FixedArray{ - FixedLength: utiles.StringToInt(strings.TrimSpace(typePart[1])), - SubType: strings.TrimSpace(typePart[0]), - } - regCustomKey(key, &fixed) + length := utiles.StringToInt(strings.TrimSpace(typePart[1])) + subType := strings.TrimSpace(typePart[0]) + registerCustomKey(key, func() Decoder { + fixed := FixedArray{FixedLength: length, SubType: subType} + return &fixed + }) continue } } @@ -126,9 +151,12 @@ func RegCustomTypes(registry map[string]source.TypeStruct) { reg := regexp.MustCompile("^Result<(.+)>$") typeParts := reg.FindStringSubmatch(typeString) if len(typeParts) > 1 { - r := Result{} - r.SubType = typeParts[1] - regCustomKey(key, &r) + subType := typeParts[1] + registerCustomKey(key, func() Decoder { + r := Result{} + r.SubType = subType + return &r + }) continue } } @@ -138,10 +166,12 @@ func RegCustomTypes(registry map[string]source.TypeStruct) { names = append(names, v[0]) typeStrings = append(typeStrings, v[1]) } - s := Struct{} - s.TypeMapping = newStruct(names, typeStrings) - - regCustomKey(key, &s) + subTypeMapping := newStruct(names, typeStrings) + registerCustomKey(key, func() Decoder { + s := Struct{} + s.TypeMapping = subTypeMapping + return &s + }) continue case "enum": var names, typeStrings []string @@ -149,21 +179,33 @@ func RegCustomTypes(registry map[string]source.TypeStruct) { names = append(names, v[0]) typeStrings = append(typeStrings, v[1]) } - e := Enum{ValueList: typeStruct.ValueList} - e.TypeMapping = newStruct(names, typeStrings) - regCustomKey(key, &e) + subTypeMapping := newStruct(names, typeStrings) + valueList := typeStruct.ValueList + registerCustomKey(key, func() Decoder { + e := Enum{ValueList: valueList} + e.TypeMapping = subTypeMapping + return &e + }) continue case "set": - regCustomKey(key, &Set{ValueList: typeStruct.ValueList, BitLength: typeStruct.BitLength}) + valueList := typeStruct.ValueList + bitLength := typeStruct.BitLength + registerCustomKey(key, func() Decoder { + return &Set{ValueList: valueList, BitLength: bitLength} + }) continue } } + + if resetCodecCacheOnExit { + resetCodecCache() + } } -func regCustomKey(key string, rt interface{}) { +func regCustomKey(key string, factory CodecFactory, reset ...bool) { slice := strings.Split(key, "#") if len(slice) == 2 { // for Special - special := Special{Registry: rt, Version: []int{0, 99999999}} + special := Special{Registry: factory, Version: []int{0, 99999999}} if version := strings.Split(slice[1], "-"); len(version) == 2 { special.Version[0] = utiles.StringToInt(version[0]) if version[1] != "?" { @@ -183,8 +225,15 @@ func regCustomKey(key string, rt interface{}) { specialRegistryLock.Unlock() } else { TypeRegistryLock.Lock() - TypeRegistry[key] = rt + TypeRegistry[key] = factory TypeRegistryLock.Unlock() } + shouldReset := true + if len(reset) > 0 { + shouldReset = reset[0] + } + if shouldReset { + resetCodecCache() + } } diff --git a/types/customType_test.go b/types/customType_test.go index bf6d58b..2a2a9f0 100644 --- a/types/customType_test.go +++ b/types/customType_test.go @@ -6,6 +6,8 @@ import ( "testing" "github.com/itering/scale.go/source" + "github.com/itering/scale.go/types/scaleBytes" + "github.com/itering/scale.go/utiles" ) func TestRegCustomTypesConcurrency(t *testing.T) { @@ -25,3 +27,59 @@ func TestRegCustomTypesConcurrency(t *testing.T) { wg.Wait() } + +func TestCodecCacheConcurrency(t *testing.T) { + RegCustomTypes(map[string]source.TypeStruct{ + "MyVec": {Type: "string", TypeString: "Vec"}, + }) + raw := "080100000002000000" + errCh := make(chan error, 1000) + wg := sync.WaitGroup{} + for i := 0; i < 1000; i++ { + wg.Add(1) + go func() { + defer wg.Done() + m := ScaleDecoder{} + m.Init(scaleBytes.ScaleBytes{Data: utiles.HexToBytes(raw)}, nil) + result := m.ProcessAndUpdateData("MyVec") + values, ok := result.([]interface{}) + if !ok || len(values) != 2 { + errCh <- fmt.Errorf("decode mismatch: %v", result) + return + } + encoded := Encode("MyVec", []interface{}{1, 2}) + if encoded != raw { + errCh <- fmt.Errorf("encode mismatch: %s", encoded) + } + }() + } + wg.Wait() + close(errCh) + for err := range errCh { + t.Error(err) + } +} + +func TestRegCustomTypesResetsCodecCacheAfterBulkRegister(t *testing.T) { + r := RuntimeType{} + _, _, err := r.GetCodec("U32", 0) + if err != nil { + t.Fatal(err) + } + codecCacheLock.RLock() + before := len(codecCache) + codecCacheLock.RUnlock() + if before == 0 { + t.Fatal("expected codec cache to contain entries before custom registration") + } + RegCustomTypes(map[string]source.TypeStruct{ + "CacheResetTypeA": {Type: "string", TypeString: "u32"}, + "CacheResetTypeB": {Type: "string", TypeString: "u64"}, + }) + codecCacheLock.RLock() + after := len(codecCache) + codecCacheLock.RUnlock() + if after != 0 { + t.Fatalf("expected codec cache to be reset after bulk registration, got %d entries", after) + } +} diff --git a/types/encode_helpers.go b/types/encode_helpers.go new file mode 100644 index 0000000..6d0d9d5 --- /dev/null +++ b/types/encode_helpers.go @@ -0,0 +1,42 @@ +package types + +func toInterfaceSlice[T any](values []T) []interface{} { + out := make([]interface{}, len(values)) + for i, v := range values { + out[i] = v + } + return out +} + +func asInterfaceSlice(value interface{}) ([]interface{}, bool) { + switch v := value.(type) { + case []interface{}: + return v, true + case []string: + return toInterfaceSlice(v), true + case []int: + return toInterfaceSlice(v), true + case []int8: + return toInterfaceSlice(v), true + case []int16: + return toInterfaceSlice(v), true + case []int32: + return toInterfaceSlice(v), true + case []int64: + return toInterfaceSlice(v), true + case []uint: + return toInterfaceSlice(v), true + case []uint16: + return toInterfaceSlice(v), true + case []uint32: + return toInterfaceSlice(v), true + case []uint64: + return toInterfaceSlice(v), true + case []byte: + return toInterfaceSlice(v), true + case []bool: + return toInterfaceSlice(v), true + default: + return nil, false + } +} diff --git a/types/registry.go b/types/registry.go index 9c4c388..3a0be69 100644 --- a/types/registry.go +++ b/types/registry.go @@ -1,9 +1,10 @@ +//go:generate go run ../tools/gen_registry + package types import ( "errors" "fmt" - "reflect" "regexp" "strings" "sync" @@ -18,14 +19,23 @@ type RuntimeType struct { Module string } +type CodecFactory func() Decoder + +type codecCacheEntry struct { + factory CodecFactory + subType string +} + type Special struct { Version []int - Registry interface{} + Registry CodecFactory } var ( - TypeRegistry map[string]interface{} + TypeRegistry map[string]CodecFactory TypeRegistryLock = &sync.RWMutex{} + codecCache = make(map[string]codecCacheEntry) + codecCacheLock = &sync.RWMutex{} specialRegistry = make(map[string][]Special) specialRegistryLock = &sync.RWMutex{} V14Types = make(map[string]source.TypeStruct) @@ -52,233 +62,157 @@ func init() { } func regBaseType() { - registry := make(map[string]interface{}) - scales := []interface{}{ - &Null{}, - &U8{}, - &U16{}, - &U32{}, - &U64{}, - &U256{}, - &Float64{}, - &Float32{}, - &U128{}, - &H160{}, - &H256{}, - &H512{}, - &Address{}, - &Option{}, - &Struct{}, - &Enum{}, - &Bytes{}, - &Vec{}, - &BoundedVec{}, - &WeakBoundedVec{}, - &Set{}, - &Compact{}, - &CompactU32{}, - &Bool{}, - &HexBytes{}, - &Moment{}, - &BlockNumber{}, - &AccountId{}, - &BoxProposal{}, - &Signature{}, - &Era{}, - &EraExtrinsic{}, - &Balance{}, - &LogDigest{}, - &Other{}, - &ChangesTrieRoot{}, - &AuthoritiesChange{}, - &SealV0{}, - &Consensus{}, - &Seal{}, - &PreRuntime{}, - &Exposure{}, - &RawAuraPreDigest{}, - &RawBabePreDigestPrimary{}, - &RawBabePreDigestSecondary{}, - &RawBabePreDigestSecondaryVRF{}, - &SlotNumber{}, - &LockIdentifier{}, - &Call{}, - &EcdsaSignature{}, - &EthereumAddress{}, - &Data{}, - &VoteOutcome{}, - &String{}, - &GenericAddress{}, - &OpaqueCall{}, - &BitVec{}, - &MetadataModuleEvent{}, - &MetadataModuleCallArgument{}, - &MetadataModuleCall{}, - &MetadataV13ModuleStorage{}, - &MetadataV7ModuleConstants{}, - &MetadataV7ModuleStorageEntry{}, - &MetadataV13ModuleStorageEntry{}, - &MetadataV8Module{}, - &MetadataV7ModuleStorage{}, - &MetadataV9Decoder{}, - &MetadataV10Decoder{}, - &MetadataV11Decoder{}, - &MetadataV12Decoder{}, - &MetadataV13Decoder{}, - &MetadataV14Decoder{}, - &MetadataV15Decoder{}, - &MetadataV12Module{}, - &MetadataV13Module{}, - &MetadataV14Module{}, - &MetadataV15Module{}, - &RuntimeApiMetadataV15{}, - &RuntimeApiMethodMetadataV15{}, - &RuntimeApiMethodParamMetadataV15{}, - &MetadataV14ModuleStorage{}, - &MetadataV14ModuleStorageEntry{}, - &PalletConstantMetadataV14{}, - &MetadataModuleError{}, - &GenericLookupSource{}, - &BTreeMap{}, - &BTreeSet{}, - &Box{}, - &Result{}, - &RuntimeEnvironmentUpdated{}, - &WrapperOpaque{}, - &Range{}, - &RangeInclusive{}, - &SubstrateFixedU64{}, - &SubstrateFixedI128{}, - &Empty{}, - &OuterEnumsMetadataV15{}, - &CustomMetadataV15{}, + registry := make(map[string]CodecFactory, len(baseCodecFactories)+24) + for key, factory := range baseCodecFactories { + registry[key] = factory } - for _, class := range scales { - valueOf := reflect.ValueOf(class) - if valueOf.Type().Kind() == reflect.Ptr { - registry[strings.ToLower(reflect.Indirect(valueOf).Type().Name())] = class - } else { - registry[strings.ToLower(valueOf.Type().Name())] = class - } - } - registry["compact"] = &CompactU32{} - registry["compact"] = &CompactMoment{} - registry["str"] = &String{} - registry["hash"] = &H256{} - registry["blockhash"] = &H256{} - registry["i8"] = &IntFixed{FixedLength: 1} - registry["i16"] = &IntFixed{FixedLength: 2} - registry["i32"] = &IntFixed{FixedLength: 4} - registry["i64"] = &IntFixed{FixedLength: 8} - registry["i128"] = &IntFixed{FixedLength: 16} - registry["i256"] = &IntFixed{FixedLength: 32} - registry["h128"] = &FixedU8{FixedLength: 16} - registry["[u8; 32]"] = &FixedU8{FixedLength: 32} - registry["[u8; 64]"] = &FixedU8{FixedLength: 64} - registry["[u8; 65]"] = &FixedU8{FixedLength: 65} - registry["[u8; 16]"] = &FixedU8{FixedLength: 16} - registry["[u8; 20]"] = &FixedU8{FixedLength: 20} - registry["[u8; 8]"] = &FixedU8{FixedLength: 8} - registry["[u8; 4]"] = &FixedU8{FixedLength: 4} - registry["[u8; 2]"] = &FixedU8{FixedLength: 2} - registry["[u8; 256]"] = &FixedU8{FixedLength: 256} - registry["[u128; 3]"] = &FixedArray{FixedLength: 3, SubType: "u128"} + registry["compact"] = registry["compactu32"] + registry["compact"] = func() Decoder { return &CompactMoment{} } + registry["str"] = registry["string"] + registry["hash"] = registry["h256"] + registry["blockhash"] = registry["h256"] + registry["i8"] = func() Decoder { return &IntFixed{FixedLength: 1} } + registry["i16"] = func() Decoder { return &IntFixed{FixedLength: 2} } + registry["i32"] = func() Decoder { return &IntFixed{FixedLength: 4} } + registry["i64"] = func() Decoder { return &IntFixed{FixedLength: 8} } + registry["i128"] = func() Decoder { return &IntFixed{FixedLength: 16} } + registry["i256"] = func() Decoder { return &IntFixed{FixedLength: 32} } + registry["h128"] = func() Decoder { return &FixedU8{FixedLength: 16} } + registry["[u8; 32]"] = func() Decoder { return &FixedU8{FixedLength: 32} } + registry["[u8; 64]"] = func() Decoder { return &FixedU8{FixedLength: 64} } + registry["[u8; 65]"] = func() Decoder { return &FixedU8{FixedLength: 65} } + registry["[u8; 16]"] = func() Decoder { return &FixedU8{FixedLength: 16} } + registry["[u8; 20]"] = func() Decoder { return &FixedU8{FixedLength: 20} } + registry["[u8; 8]"] = func() Decoder { return &FixedU8{FixedLength: 8} } + registry["[u8; 4]"] = func() Decoder { return &FixedU8{FixedLength: 4} } + registry["[u8; 2]"] = func() Decoder { return &FixedU8{FixedLength: 2} } + registry["[u8; 256]"] = func() Decoder { return &FixedU8{FixedLength: 256} } + registry["[u128; 3]"] = func() Decoder { return &FixedArray{FixedLength: 3, SubType: "u128"} } TypeRegistryLock.Lock() TypeRegistry = registry TypeRegistryLock.Unlock() + resetCodecCache() // todo change load source pallet type to lazy load RegCustomTypes(source.LoadTypeRegistry([]byte(source.BaseType))) } -func (r *RuntimeType) getCodecInstant(t string, spec int) (reflect.Type, reflect.Value, error) { +func resetCodecCache() { + codecCacheLock.Lock() + codecCache = make(map[string]codecCacheEntry) + codecCacheLock.Unlock() +} + +func (r *RuntimeType) getCodecInstant(t string, spec int) (Decoder, CodecFactory, error) { t = override.ModuleType(strings.ToLower(t), r.Module) - rt, err := r.specialVersionCodec(t, spec) + factory, err := r.specialVersionCodec(t, spec) if err != nil { TypeRegistryLock.RLock() - rt = TypeRegistry[strings.ToLower(t)] + factory = TypeRegistry[strings.ToLower(t)] TypeRegistryLock.RUnlock() // fixed array - if rt == nil && t != "[]" && string(t[0]) == "[" && t[len(t)-1:] == "]" { + if factory == nil && t != "[]" && string(t[0]) == "[" && t[len(t)-1:] == "]" { if typePart := strings.Split(t[1:len(t)-1], ";"); len(typePart) >= 2 { remainPart := typePart[0 : len(typePart)-1] - fixed := FixedArray{ - FixedLength: utiles.StringToInt(strings.TrimSpace(typePart[len(typePart)-1])), - SubType: strings.TrimSpace(strings.Join(remainPart, ";")), - } - rt = &fixed + fixedLength := utiles.StringToInt(strings.TrimSpace(typePart[len(typePart)-1])) + subType := strings.TrimSpace(strings.Join(remainPart, ";")) + factory = func() Decoder { return &FixedArray{FixedLength: fixedLength, SubType: subType} } } } - if rt == nil { - return nil, reflect.ValueOf((*error)(nil)).Elem(), errors.New("Scale codec type nil" + t) + if factory == nil { + return nil, nil, errors.New("Scale codec type nil" + t) } } - value := reflect.ValueOf(rt) - if value.Kind() == reflect.Ptr { - value = reflect.Indirect(value) - } - p := reflect.New(value.Type()) - p.Elem().Set(value) - return p.Type(), p, nil + return factory(), factory, nil } -func (r *RuntimeType) GetCodecClass(typeString string, spec int) (reflect.Type, reflect.Value, string) { +func (r *RuntimeType) GetCodec(typeString string, spec int) (Decoder, string, error) { var typeParts []string typeString = convert.ConvertType(typeString) + cacheKey := fmt.Sprintf("%s|%d|%s", r.Module, spec, typeString) + codecCacheLock.RLock() + entry, ok := codecCache[cacheKey] + codecCacheLock.RUnlock() + if ok { + return entry.factory(), entry.subType, nil + } // complex if typeString[len(typeString)-1:] == ">" { - decoderClass, rc, err := r.getCodecInstant(typeString, spec) + decoder, factory, err := r.getCodecInstant(typeString, spec) if err == nil { - return decoderClass, rc, "" + codecCacheLock.Lock() + codecCache[cacheKey] = codecCacheEntry{factory: factory, subType: ""} + codecCacheLock.Unlock() + return decoder, "", nil } reg := regexp.MustCompile("^([^<]*)<(.+)>$") typeParts = reg.FindStringSubmatch(typeString) } if len(typeParts) > 0 { - class, rc, err := r.getCodecInstant(typeParts[1], spec) + decoder, factory, err := r.getCodecInstant(typeParts[1], spec) if err == nil { - return class, rc, typeParts[2] + codecCacheLock.Lock() + codecCache[cacheKey] = codecCacheEntry{factory: factory, subType: typeParts[2]} + codecCacheLock.Unlock() + return decoder, typeParts[2], nil } } else { - class, rc, err := r.getCodecInstant(typeString, spec) + decoder, factory, err := r.getCodecInstant(typeString, spec) if err == nil { - return class, rc, "" + codecCacheLock.Lock() + codecCache[cacheKey] = codecCacheEntry{factory: factory, subType: ""} + codecCacheLock.Unlock() + return decoder, "", nil } } // Tuple if typeString != "()" && string(typeString[0]) == "(" && typeString[len(typeString)-1:] == ")" { - decoderClass, rc, _ := r.getCodecInstant("Struct", spec) - s := rc.Interface().(*Struct) + decoder, _, err := r.getCodecInstant("Struct", spec) + if err != nil { + return nil, "", err + } + s, ok := decoder.(*Struct) + if !ok { + return nil, "", fmt.Errorf("invalid struct decoder for %s", typeString) + } s.TypeString = typeString s.buildStruct() - return decoderClass, rc, "" + codecCacheLock.Lock() + codecCache[cacheKey] = codecCacheEntry{factory: func() Decoder { + clone := Struct{} + clone.TypeString = typeString + clone.buildStruct() + return &clone + }, subType: ""} + codecCacheLock.Unlock() + return s, "", nil } // namespace if strings.Contains(typeString, "::") && typeString != "::" { namespaceSlice := strings.Split(typeString, "::") - return r.GetCodecClass(namespaceSlice[len(namespaceSlice)-1], spec) + return r.GetCodec(namespaceSlice[len(namespaceSlice)-1], spec) } - return nil, reflect.ValueOf((*error)(nil)).Elem(), "" + return nil, "", fmt.Errorf("scale codec type nil %s", typeString) } -func (r *RuntimeType) specialVersionCodec(t string, spec int) (interface{}, error) { - var rt interface{} +func (r *RuntimeType) specialVersionCodec(t string, spec int) (CodecFactory, error) { + var factory CodecFactory specialRegistryLock.RLock() specials, ok := specialRegistry[t] specialRegistryLock.RUnlock() if ok { for _, special := range specials { if spec >= special.Version[0] && spec <= special.Version[1] { - rt = special.Registry - return rt, nil + factory = special.Registry + return factory, nil } } } - return rt, fmt.Errorf("not found") + return factory, fmt.Errorf("not found") } diff --git a/types/registry_gen.go b/types/registry_gen.go new file mode 100644 index 0000000..623f2fe --- /dev/null +++ b/types/registry_gen.go @@ -0,0 +1,104 @@ +// Code generated by tools/gen_registry; DO NOT EDIT. + +package types + +var baseCodecFactories = map[string]CodecFactory{ + "null": func() Decoder { return &Null{} }, + "u8": func() Decoder { return &U8{} }, + "u16": func() Decoder { return &U16{} }, + "u32": func() Decoder { return &U32{} }, + "u64": func() Decoder { return &U64{} }, + "u256": func() Decoder { return &U256{} }, + "float64": func() Decoder { return &Float64{} }, + "float32": func() Decoder { return &Float32{} }, + "u128": func() Decoder { return &U128{} }, + "h160": func() Decoder { return &H160{} }, + "h256": func() Decoder { return &H256{} }, + "h512": func() Decoder { return &H512{} }, + "address": func() Decoder { return &Address{} }, + "option": func() Decoder { return &Option{} }, + "struct": func() Decoder { return &Struct{} }, + "enum": func() Decoder { return &Enum{} }, + "bytes": func() Decoder { return &Bytes{} }, + "vec": func() Decoder { return &Vec{} }, + "boundedvec": func() Decoder { return &BoundedVec{} }, + "weakboundedvec": func() Decoder { return &WeakBoundedVec{} }, + "set": func() Decoder { return &Set{} }, + "compact": func() Decoder { return &Compact{} }, + "compactu32": func() Decoder { return &CompactU32{} }, + "bool": func() Decoder { return &Bool{} }, + "hexbytes": func() Decoder { return &HexBytes{} }, + "moment": func() Decoder { return &Moment{} }, + "blocknumber": func() Decoder { return &BlockNumber{} }, + "accountid": func() Decoder { return &AccountId{} }, + "boxproposal": func() Decoder { return &BoxProposal{} }, + "signature": func() Decoder { return &Signature{} }, + "era": func() Decoder { return &Era{} }, + "eraextrinsic": func() Decoder { return &EraExtrinsic{} }, + "balance": func() Decoder { return &Balance{} }, + "logdigest": func() Decoder { return &LogDigest{} }, + "other": func() Decoder { return &Other{} }, + "changestrieroot": func() Decoder { return &ChangesTrieRoot{} }, + "authoritieschange": func() Decoder { return &AuthoritiesChange{} }, + "sealv0": func() Decoder { return &SealV0{} }, + "consensus": func() Decoder { return &Consensus{} }, + "seal": func() Decoder { return &Seal{} }, + "preruntime": func() Decoder { return &PreRuntime{} }, + "exposure": func() Decoder { return &Exposure{} }, + "rawaurapredigest": func() Decoder { return &RawAuraPreDigest{} }, + "rawbabepredigestprimary": func() Decoder { return &RawBabePreDigestPrimary{} }, + "rawbabepredigestsecondary": func() Decoder { return &RawBabePreDigestSecondary{} }, + "rawbabepredigestsecondaryvrf": func() Decoder { return &RawBabePreDigestSecondaryVRF{} }, + "slotnumber": func() Decoder { return &SlotNumber{} }, + "lockidentifier": func() Decoder { return &LockIdentifier{} }, + "call": func() Decoder { return &Call{} }, + "ecdsasignature": func() Decoder { return &EcdsaSignature{} }, + "ethereumaddress": func() Decoder { return &EthereumAddress{} }, + "data": func() Decoder { return &Data{} }, + "voteoutcome": func() Decoder { return &VoteOutcome{} }, + "string": func() Decoder { return &String{} }, + "genericaddress": func() Decoder { return &GenericAddress{} }, + "opaquecall": func() Decoder { return &OpaqueCall{} }, + "bitvec": func() Decoder { return &BitVec{} }, + "metadatamoduleevent": func() Decoder { return &MetadataModuleEvent{} }, + "metadatamodulecallargument": func() Decoder { return &MetadataModuleCallArgument{} }, + "metadatamodulecall": func() Decoder { return &MetadataModuleCall{} }, + "metadatav13modulestorage": func() Decoder { return &MetadataV13ModuleStorage{} }, + "metadatav7moduleconstants": func() Decoder { return &MetadataV7ModuleConstants{} }, + "metadatav7modulestorageentry": func() Decoder { return &MetadataV7ModuleStorageEntry{} }, + "metadatav13modulestorageentry": func() Decoder { return &MetadataV13ModuleStorageEntry{} }, + "metadatav8module": func() Decoder { return &MetadataV8Module{} }, + "metadatav7modulestorage": func() Decoder { return &MetadataV7ModuleStorage{} }, + "metadatav9decoder": func() Decoder { return &MetadataV9Decoder{} }, + "metadatav10decoder": func() Decoder { return &MetadataV10Decoder{} }, + "metadatav11decoder": func() Decoder { return &MetadataV11Decoder{} }, + "metadatav12decoder": func() Decoder { return &MetadataV12Decoder{} }, + "metadatav13decoder": func() Decoder { return &MetadataV13Decoder{} }, + "metadatav14decoder": func() Decoder { return &MetadataV14Decoder{} }, + "metadatav15decoder": func() Decoder { return &MetadataV15Decoder{} }, + "metadatav12module": func() Decoder { return &MetadataV12Module{} }, + "metadatav13module": func() Decoder { return &MetadataV13Module{} }, + "metadatav14module": func() Decoder { return &MetadataV14Module{} }, + "metadatav15module": func() Decoder { return &MetadataV15Module{} }, + "runtimeapimetadatav15": func() Decoder { return &RuntimeApiMetadataV15{} }, + "runtimeapimethodmetadatav15": func() Decoder { return &RuntimeApiMethodMetadataV15{} }, + "runtimeapimethodparammetadatav15": func() Decoder { return &RuntimeApiMethodParamMetadataV15{} }, + "metadatav14modulestorage": func() Decoder { return &MetadataV14ModuleStorage{} }, + "metadatav14modulestorageentry": func() Decoder { return &MetadataV14ModuleStorageEntry{} }, + "palletconstantmetadatav14": func() Decoder { return &PalletConstantMetadataV14{} }, + "metadatamoduleerror": func() Decoder { return &MetadataModuleError{} }, + "genericlookupsource": func() Decoder { return &GenericLookupSource{} }, + "btreemap": func() Decoder { return &BTreeMap{} }, + "btreeset": func() Decoder { return &BTreeSet{} }, + "box": func() Decoder { return &Box{} }, + "result": func() Decoder { return &Result{} }, + "runtimeenvironmentupdated": func() Decoder { return &RuntimeEnvironmentUpdated{} }, + "wrapperopaque": func() Decoder { return &WrapperOpaque{} }, + "range": func() Decoder { return &Range{} }, + "rangeinclusive": func() Decoder { return &RangeInclusive{} }, + "substratefixedu64": func() Decoder { return &SubstrateFixedU64{} }, + "substratefixedi128": func() Decoder { return &SubstrateFixedI128{} }, + "empty": func() Decoder { return &Empty{} }, + "outerenumsmetadatav15": func() Decoder { return &OuterEnumsMetadataV15{} }, + "custommetadatav15": func() Decoder { return &CustomMetadataV15{} }, +} diff --git a/types/registry_types.go b/types/registry_types.go new file mode 100644 index 0000000..cf785cd --- /dev/null +++ b/types/registry_types.go @@ -0,0 +1,104 @@ +// Code generated by tools/gen_registry; DO NOT EDIT. + +package types + +var baseCodecTypeNames = []string{ + "Null", + "U8", + "U16", + "U32", + "U64", + "U256", + "Float64", + "Float32", + "U128", + "H160", + "H256", + "H512", + "Address", + "Option", + "Struct", + "Enum", + "Bytes", + "Vec", + "BoundedVec", + "WeakBoundedVec", + "Set", + "Compact", + "CompactU32", + "Bool", + "HexBytes", + "Moment", + "BlockNumber", + "AccountId", + "BoxProposal", + "Signature", + "Era", + "EraExtrinsic", + "Balance", + "LogDigest", + "Other", + "ChangesTrieRoot", + "AuthoritiesChange", + "SealV0", + "Consensus", + "Seal", + "PreRuntime", + "Exposure", + "RawAuraPreDigest", + "RawBabePreDigestPrimary", + "RawBabePreDigestSecondary", + "RawBabePreDigestSecondaryVRF", + "SlotNumber", + "LockIdentifier", + "Call", + "EcdsaSignature", + "EthereumAddress", + "Data", + "VoteOutcome", + "String", + "GenericAddress", + "OpaqueCall", + "BitVec", + "MetadataModuleEvent", + "MetadataModuleCallArgument", + "MetadataModuleCall", + "MetadataV13ModuleStorage", + "MetadataV7ModuleConstants", + "MetadataV7ModuleStorageEntry", + "MetadataV13ModuleStorageEntry", + "MetadataV8Module", + "MetadataV7ModuleStorage", + "MetadataV9Decoder", + "MetadataV10Decoder", + "MetadataV11Decoder", + "MetadataV12Decoder", + "MetadataV13Decoder", + "MetadataV14Decoder", + "MetadataV15Decoder", + "MetadataV12Module", + "MetadataV13Module", + "MetadataV14Module", + "MetadataV15Module", + "RuntimeApiMetadataV15", + "RuntimeApiMethodMetadataV15", + "RuntimeApiMethodParamMetadataV15", + "MetadataV14ModuleStorage", + "MetadataV14ModuleStorageEntry", + "PalletConstantMetadataV14", + "MetadataModuleError", + "GenericLookupSource", + "BTreeMap", + "BTreeSet", + "Box", + "Result", + "RuntimeEnvironmentUpdated", + "WrapperOpaque", + "Range", + "RangeInclusive", + "SubstrateFixedU64", + "SubstrateFixedI128", + "Empty", + "OuterEnumsMetadataV15", + "CustomMetadataV15", +} diff --git a/types/type_mapping_test.go b/types/type_mapping_test.go new file mode 100644 index 0000000..cd302cd --- /dev/null +++ b/types/type_mapping_test.go @@ -0,0 +1,35 @@ +package types + +import ( + "testing" + + "github.com/itering/scale.go/types/scaleBytes" + "github.com/stretchr/testify/assert" +) + +func TestGetTypeMappingConsensus(t *testing.T) { + r := RuntimeType{} + dec, _, err := r.GetCodec("Consensus", 0) + assert.NoError(t, err) + dec.Init(scaleBytes.EmptyScaleBytes(), &ScaleDecoderOption{}) + + getter, ok := dec.(TypeMappingGetter) + assert.True(t, ok) + tm := getter.GetTypeMapping() + if assert.NotNil(t, tm) { + assert.Equal(t, []string{"engine", "data"}, tm.Names) + assert.Equal(t, []string{"u32", "Vec"}, tm.Types) + } +} + +func TestGetCodecFixedArrayReturnsFreshInstance(t *testing.T) { + resetCodecCache() + r := RuntimeType{} + first, _, err := r.GetCodec("[u16; 2]", 0) + assert.NoError(t, err) + second, _, err := r.GetCodec("[u16; 2]", 0) + assert.NoError(t, err) + assert.IsType(t, &FixedArray{}, first) + assert.IsType(t, &FixedArray{}, second) + assert.NotSame(t, first, second) +} diff --git a/types/types.go b/types/types.go index 6258c2e..ee39bfe 100644 --- a/types/types.go +++ b/types/types.go @@ -1,9 +1,7 @@ package types import ( - "bytes" "fmt" - "io" "math" "strconv" "strings" @@ -47,8 +45,12 @@ func (h *H160) Process() { h.Value = utiles.AddHex(utiles.BytesToHex(h.NextBytes(20))) } -func (h *H160) Encode(value string) string { - return utiles.AddHex(strings.ToLower(value)) +func (h *H160) Encode(value interface{}) string { + valueStr, ok := value.(string) + if !ok { + panic("invalid H160 input") + } + return utiles.AddHex(strings.ToLower(valueStr)) } func (h *H160) TypeStructString() string { @@ -63,8 +65,12 @@ func (h *H256) Process() { h.Value = utiles.AddHex(utiles.BytesToHex(h.NextBytes(32))) } -func (h *H256) Encode(value string) string { - return utiles.AddHex(strings.ToLower(value)) +func (h *H256) Encode(value interface{}) string { + valueStr, ok := value.(string) + if !ok { + panic("invalid H256 input") + } + return utiles.AddHex(strings.ToLower(valueStr)) } func (h *H256) TypeStructString() string { @@ -79,8 +85,12 @@ func (h *H512) Process() { h.Value = utiles.AddHex(utiles.BytesToHex(h.NextBytes(64))) } -func (h *H512) Encode(value string) string { - return utiles.AddHex(strings.ToLower(value)) +func (h *H512) Encode(value interface{}) string { + valueStr, ok := value.(string) + if !ok { + panic("invalid H512 input") + } + return utiles.AddHex(strings.ToLower(valueStr)) } func (h *H512) TypeStructString() string { @@ -100,7 +110,11 @@ func (e *Era) Process() { } } -func (e *Era) Encode(era string) string { +func (e *Era) Encode(value interface{}) string { + era, ok := value.(string) + if !ok { + panic("invalid Era input") + } return era } @@ -183,8 +197,12 @@ func (s *AccountId) Process() { s.Value = utiles.AddHex(xstrings.RightJustify(utiles.BytesToHex(s.NextBytes(32)), 64, "0")) } -func (s *AccountId) Encode(value string) string { - return value +func (s *AccountId) Encode(value interface{}) string { + valueStr, ok := value.(string) + if !ok { + panic("invalid AccountId input") + } + return valueStr } func (s *AccountId) TypeStructString() string { @@ -196,17 +214,14 @@ type Balance struct { } func (b *Balance) Process() { - buf := &bytes.Buffer{} - var reader io.Reader - reader = buf - _, _ = buf.Write(b.NextBytes(16)) - c := make([]byte, 16) - _, _ = reader.Read(c) - if utiles.BytesToHex(c) == "ffffffffffffffffffffffffffffffff" { + data := b.NextBytes(16) + var c [16]byte + copy(c[:], data) + if utiles.BytesToHex(c[:]) == "ffffffffffffffffffffffffffffffff" { b.Value = decimal.NewFromInt32(-1) return } - b.Value = decimal.NewFromBigInt(uint128.FromBytes(c).Big(), 0) + b.Value = decimal.NewFromBigInt(uint128.FromBytes(c[:]).Big(), 0) } type LogDigest struct{ Enum } @@ -371,8 +386,12 @@ func (d *Data) Process() { } } -func (d *Data) Encode(v map[string]interface{}) string { - key, val, err := utiles.GetEnumValue(v) +func (d *Data) Encode(value interface{}) string { + typed, ok := value.(map[string]interface{}) + if !ok { + panic("invalid Data input") + } + key, val, err := utiles.GetEnumValue(typed) if err != nil { panic(err) } @@ -380,7 +399,11 @@ func (d *Data) Encode(v map[string]interface{}) string { if key == "None" { return Encode("U8", 0) } - return Encode("U8", index+32) + Encode(d.TypeMapping.Types[index], val.(string)) + valStr, ok := val.(string) + if !ok { + panic("invalid Data value") + } + return Encode("U8", index+32) + Encode(d.TypeMapping.Types[index], valStr) } // raw data if strings.HasPrefix(key, "Raw") { @@ -390,10 +413,14 @@ func (d *Data) Encode(v map[string]interface{}) string { } else { l++ indexRaw := Encode("U8", l) - if l == len(val.(string)) { - return indexRaw + val.(string) + valStr, ok := val.(string) + if !ok { + panic("invalid Data raw value") } - return indexRaw + utiles.BytesToHex([]byte(val.(string))) + if l == len(valStr) { + return indexRaw + valStr + } + return indexRaw + utiles.BytesToHex([]byte(valStr)) } } panic("invalid enum key") @@ -494,9 +521,13 @@ func (b *BitVec) TypeStructString() string { return "BitVec" } -func (b *BitVec) Encode(value string) string { - value = strings.TrimPrefix(value, "0b") - values := strings.Split(value, "_") +func (b *BitVec) Encode(value interface{}) string { + valueStr, ok := value.(string) + if !ok { + panic("invalid BitVec input") + } + valueStr = strings.TrimPrefix(valueStr, "0b") + values := strings.Split(valueStr, "_") var u8a []byte for _, v := range values { b, _ := strconv.ParseUint(v, 2, 8) diff --git a/types/types_test.go b/types/types_test.go index 51fcc03..42ab71b 100644 --- a/types/types_test.go +++ b/types/types_test.go @@ -4,7 +4,6 @@ import ( "encoding/binary" "encoding/json" "math/big" - "reflect" "strings" "sync" "testing" @@ -146,9 +145,7 @@ func TestReferendumInfo(t *testing.T) { "proposalHash": "0x295ce46278975a53b855188482af699f7726fbbeac89cf16a1741c4698dcdbc9", "tally": map[string]interface{}{"ayes": "0", "nays": "0", "turnout": "0"}, "threshold": "SuperMajorityApprove", }} - if !reflect.DeepEqual(utiles.ToString(c), utiles.ToString(r)) { - t.Errorf("Test TestReferendumInfo Process fail, decode return %v", r.(map[string]interface{})) - } + assert.Equal(t, utiles.ToString(c), utiles.ToString(r)) } func TestEthereumAccountId(t *testing.T) { @@ -365,16 +362,34 @@ func TestFixedArray(t *testing.T) { } } +func TestFixedArrayEncodeInvalidInputMessage(t *testing.T) { + assert.PanicsWithError(t, `invalid fixed array input: expected fixed length 2 with subtype "u16", got value of type int`, func() { + Encode("[u16; 2]", 1) + }) +} + +func TestVecEncodeSliceTypes(t *testing.T) { + assert.Equal(t, "080100000002000000", Encode("Vec", []uint32{1, 2})) + assert.Equal(t, "080100000002000000", Encode("Vec", []int{1, 2})) +} + func TestU256(t *testing.T) { raw := "0x1001000000000000000000000000000000000000000000000000000000000000" m := ScaleDecoder{} m.Init(scaleBytes.ScaleBytes{Data: utiles.HexToBytes(raw)}, nil) assert.Equal(t, raw, utiles.AddHex(Encode("U256", m.ProcessAndUpdateData("U256").(decimal.Decimal)))) assert.Equal(t, raw, utiles.AddHex(Encode("U256", "0x1001000000000000000000000000000000000000000000000000000000000000"))) + assert.Equal(t, raw, utiles.AddHex(Encode("U256", utiles.HexToBytes(raw)))) m.Init(scaleBytes.ScaleBytes{Data: utiles.HexToBytes("0x00b5070000000000000000000000000000000000000000000000000000000000")}, nil) assert.Equal(t, int64(505088), m.ProcessAndUpdateData("U256").(decimal.Decimal).IntPart()) } +func TestU256EncodeInvalidInputMessage(t *testing.T) { + assert.PanicsWithError(t, "invalid U256 input: expected slice-like value, got int (1)", func() { + Encode("U256", 1) + }) +} + func TestXcmV2ResultType(t *testing.T) { raw := "0204031501020042f2f9dc" m := ScaleDecoder{} diff --git a/types/v13.go b/types/v13.go index 536c03c..c427a40 100644 --- a/types/v13.go +++ b/types/v13.go @@ -97,7 +97,12 @@ func registerOriginCaller(originCallers []OriginCaller) { e.TypeMapping.Types = append(e.TypeMapping.Types, "NULL") } } - regCustomKey(strings.ToLower("OriginCaller"), &e) + typeMapping := e.TypeMapping + regCustomKey(strings.ToLower("OriginCaller"), func() Decoder { + origin := Enum{} + origin.TypeMapping = typeMapping + return &origin + }) } type MetadataV13Module struct { diff --git a/utiles/tools.go b/utiles/tools.go index 752ccbe..928b0c0 100644 --- a/utiles/tools.go +++ b/utiles/tools.go @@ -5,11 +5,11 @@ import ( "encoding/json" "errors" "fmt" - "github.com/shopspring/decimal" "math/big" - "reflect" "strconv" "strings" + + "github.com/shopspring/decimal" ) func StringToInt(s string) int { @@ -154,17 +154,6 @@ func U8Encode(i int) string { return BytesToHex(bs) } -func IsNil(a interface{}) bool { - if a == nil { - return true - } - switch reflect.TypeOf(a).Kind() { - case reflect.Ptr, reflect.Map, reflect.Array, reflect.Chan, reflect.Slice, reflect.Interface, reflect.Func: - return reflect.ValueOf(a).IsNil() - } - return false -} - // GetEnumValue get enum single key && value func GetEnumValue(e map[string]interface{}) (string, interface{}, error) { for key, v := range e {