diff --git a/generator/golang/extension/unknown/binary.go b/generator/golang/extension/unknown/binary.go new file mode 100644 index 00000000..34e21d30 --- /dev/null +++ b/generator/golang/extension/unknown/binary.go @@ -0,0 +1,399 @@ +// Copyright 2023 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package unknown . +package unknown + +import ( + "encoding/binary" + "errors" + "math" +) + +var InvalidDataLength = errors.New("invalid data length") + +// Binary protocol for bthrift. +var Binary binaryProtocol + +type binaryProtocol struct{} + +func (binaryProtocol) WriteStructBegin(buf []byte, name string) int { + return 0 +} + +func (binaryProtocol) WriteStructEnd(buf []byte) int { + return 0 +} + +func (binaryProtocol) WriteFieldBegin(buf []byte, name string, typeID int, id int16) int { + return Binary.WriteByte(buf, int8(typeID)) + Binary.WriteI16(buf[1:], id) +} + +func (binaryProtocol) WriteFieldEnd(buf []byte) int { + return 0 +} + +func (binaryProtocol) WriteFieldStop(buf []byte) int { + return Binary.WriteByte(buf, TStop) +} + +func (binaryProtocol) WriteMapBegin(buf []byte, keyType, valueType int, size int) int { + return Binary.WriteByte(buf, int8(keyType)) + + Binary.WriteByte(buf[1:], int8(valueType)) + + Binary.WriteI32(buf[2:], int32(size)) +} + +func (binaryProtocol) WriteMapEnd(buf []byte) int { + return 0 +} + +func (binaryProtocol) WriteListBegin(buf []byte, elemType int, size int) int { + return Binary.WriteByte(buf, int8(elemType)) + + Binary.WriteI32(buf[1:], int32(size)) +} + +func (binaryProtocol) WriteListEnd(buf []byte) int { + return 0 +} + +func (binaryProtocol) WriteSetBegin(buf []byte, elemType int, size int) int { + return Binary.WriteByte(buf, int8(elemType)) + + Binary.WriteI32(buf[1:], int32(size)) +} + +func (binaryProtocol) WriteSetEnd(buf []byte) int { + return 0 +} + +func (binaryProtocol) WriteBool(buf []byte, value bool) int { + if value { + return Binary.WriteByte(buf, 1) + } + return Binary.WriteByte(buf, 0) +} + +func (binaryProtocol) WriteByte(buf []byte, value int8) int { + buf[0] = byte(value) + return 1 +} + +func (binaryProtocol) WriteI16(buf []byte, value int16) int { + binary.BigEndian.PutUint16(buf, uint16(value)) + return 2 +} + +func (binaryProtocol) WriteI32(buf []byte, value int32) int { + binary.BigEndian.PutUint32(buf, uint32(value)) + return 4 +} + +func (binaryProtocol) WriteI64(buf []byte, value int64) int { + binary.BigEndian.PutUint64(buf, uint64(value)) + return 8 +} + +func (binaryProtocol) WriteDouble(buf []byte, value float64) int { + return Binary.WriteI64(buf, int64(math.Float64bits(value))) +} + +func (binaryProtocol) WriteString(buf []byte, value string) int { + l := Binary.WriteI32(buf, int32(len(value))) + copy(buf[l:], value) + return l + len(value) +} + +func (binaryProtocol) WriteBinary(buf, value []byte) int { + l := Binary.WriteI32(buf, int32(len(value))) + copy(buf[l:], value) + return l + len(value) +} + +func (binaryProtocol) StructBeginLength(name string) int { + return 0 +} + +func (binaryProtocol) StructEndLength() int { + return 0 +} + +func (binaryProtocol) FieldBeginLength(name string, typeID int, id int16) int { + return Binary.ByteLength(int8(typeID)) + Binary.I16Length(id) +} + +func (binaryProtocol) FieldEndLength() int { + return 0 +} + +func (binaryProtocol) FieldStopLength() int { + return Binary.ByteLength(TStop) +} + +func (binaryProtocol) MapBeginLength(keyType, valueType int, size int) int { + return Binary.ByteLength(int8(keyType)) + + Binary.ByteLength(int8(valueType)) + + Binary.I32Length(int32(size)) +} + +func (binaryProtocol) MapEndLength() int { + return 0 +} + +func (binaryProtocol) ListBeginLength(elemType int, size int) int { + return Binary.ByteLength(int8(elemType)) + + Binary.I32Length(int32(size)) +} + +func (binaryProtocol) ListEndLength() int { + return 0 +} + +func (binaryProtocol) SetBeginLength(elemType int, size int) int { + return Binary.ByteLength(int8(elemType)) + + Binary.I32Length(int32(size)) +} + +func (binaryProtocol) SetEndLength() int { + return 0 +} + +func (binaryProtocol) BoolLength(value bool) int { + if value { + return Binary.ByteLength(1) + } + return Binary.ByteLength(0) +} + +func (binaryProtocol) ByteLength(value int8) int { + return 1 +} + +func (binaryProtocol) I16Length(value int16) int { + return 2 +} + +func (binaryProtocol) I32Length(value int32) int { + return 4 +} + +func (binaryProtocol) I64Length(value int64) int { + return 8 +} + +func (binaryProtocol) DoubleLength(value float64) int { + return Binary.I64Length(int64(math.Float64bits(value))) +} + +func (binaryProtocol) StringLength(value string) int { + return Binary.I32Length(int32(len(value))) + len(value) +} + +func (binaryProtocol) BinaryLength(value []byte) int { + return Binary.I32Length(int32(len(value))) + len(value) +} + +func (binaryProtocol) ReadMessageEnd(buf []byte) (int, error) { + return 0, nil +} + +func (binaryProtocol) ReadStructBegin(buf []byte) (name string, length int, err error) { + return +} + +func (binaryProtocol) ReadStructEnd(buf []byte) (int, error) { + return 0, nil +} + +func (binaryProtocol) ReadFieldBegin(buf []byte) (name string, typeID int, id int16, length int, err error) { + t, l, e := Binary.ReadByte(buf) + length += l + typeID = int(t) + if e != nil { + err = e + return + } + if typeID != TStop { + id, l, err = Binary.ReadI16(buf[length:]) + length += l + } + return +} + +func (binaryProtocol) ReadFieldEnd(buf []byte) (int, error) { + return 0, nil +} + +func (binaryProtocol) ReadMapBegin(buf []byte) (keyType, valueType int, size, length int, err error) { + k, l, e := Binary.ReadByte(buf) + length += l + if e != nil { + err = e + return + } + keyType = int(k) + v, l, e := Binary.ReadByte(buf[length:]) + length += l + if e != nil { + err = e + return + } + valueType = int(v) + size32, l, e := Binary.ReadI32(buf[length:]) + length += l + if e != nil { + err = e + return + } + if size32 < 0 { + err = InvalidDataLength + return + } + size = int(size32) + return +} + +func (binaryProtocol) ReadMapEnd(buf []byte) (int, error) { + return 0, nil +} + +func (binaryProtocol) ReadListBegin(buf []byte) (elemType int, size, length int, err error) { + b, l, e := Binary.ReadByte(buf) + length += l + if e != nil { + err = e + return + } + elemType = int(b) + size32, l, e := Binary.ReadI32(buf[length:]) + length += l + if e != nil { + err = e + return + } + if size32 < 0 { + err = InvalidDataLength + return + } + size = int(size32) + + return +} + +func (binaryProtocol) ReadListEnd(buf []byte) (int, error) { + return 0, nil +} + +func (binaryProtocol) ReadSetBegin(buf []byte) (elemType int, size, length int, err error) { + b, l, e := Binary.ReadByte(buf) + length += l + if e != nil { + err = e + return + } + elemType = int(b) + size32, l, e := Binary.ReadI32(buf[length:]) + length += l + if e != nil { + err = e + return + } + if size32 < 0 { + err = InvalidDataLength + return + } + size = int(size32) + return +} + +func (binaryProtocol) ReadSetEnd(buf []byte) (int, error) { + return 0, nil +} + +func (binaryProtocol) ReadBool(buf []byte) (value bool, length int, err error) { + b, l, e := Binary.ReadByte(buf) + v := true + if b != 1 { + v = false + } + return v, l, e +} + +func (binaryProtocol) ReadByte(buf []byte) (value int8, length int, err error) { + if len(buf) < 1 { + return value, length, InvalidDataLength + } + return int8(buf[0]), 1, err +} + +func (binaryProtocol) ReadI16(buf []byte) (value int16, length int, err error) { + if len(buf) < 2 { + return value, length, InvalidDataLength + } + value = int16(binary.BigEndian.Uint16(buf)) + return value, 2, err +} + +func (binaryProtocol) ReadI32(buf []byte) (value int32, length int, err error) { + if len(buf) < 4 { + return value, length, InvalidDataLength + } + value = int32(binary.BigEndian.Uint32(buf)) + return value, 4, err +} + +func (binaryProtocol) ReadI64(buf []byte) (value int64, length int, err error) { + if len(buf) < 8 { + return value, length, InvalidDataLength + } + value = int64(binary.BigEndian.Uint64(buf)) + return value, 8, err +} + +func (binaryProtocol) ReadDouble(buf []byte) (value float64, length int, err error) { + if len(buf) < 8 { + return value, length, InvalidDataLength + } + value = math.Float64frombits(binary.BigEndian.Uint64(buf)) + return value, 8, err +} + +func (binaryProtocol) ReadString(buf []byte) (value string, length int, err error) { + size, l, e := Binary.ReadI32(buf) + length += l + if e != nil { + err = e + return + } + if size < 0 || int(size) > len(buf) { + return value, length, InvalidDataLength + } + value = string(buf[length : length+int(size)]) + length += int(size) + return +} + +func (binaryProtocol) ReadBinary(buf []byte) (value []byte, length int, err error) { + size, l, e := Binary.ReadI32(buf) + length += l + if e != nil { + err = e + return + } + if size < 0 || int(size) > len(buf) { + return value, length, InvalidDataLength + } + value = make([]byte, size) + copy(value, buf[length:length+int(size)]) + length += int(size) + return +} diff --git a/generator/golang/extension/unknown/unknown.go b/generator/golang/extension/unknown/unknown.go index 320c4115..667e215f 100644 --- a/generator/golang/extension/unknown/unknown.go +++ b/generator/golang/extension/unknown/unknown.go @@ -46,248 +46,414 @@ func SetNestingDepthLimit(d int) { maxNestingDepth = d } -// Field is used to store unrecognized field when deserializing data. -type Field struct { - Name string - ID int16 - Type int - KeyType int - ValType int - Value interface{} -} - -// Fields is a list of Field. -type Fields []*Field +// Fields stores all undeserialized unknown fields. +type Fields []byte -// Append reads an unrecognized field and append it to the current slice. +// Append reads an object of a generalized type from xprot and serializes the object into Fields for compatibility +// with the thrift interface, and the performance is greatly discounted for this reason. +// +// Deprecated: Use the FastCodec api provided by Kitex for serialization/deserialization to improve performance. func (fs *Fields) Append(xprot TProtocol, name string, fieldType TType, id int16) error { iprot, err := convert(xprot) if err != nil { return err } - f, err := read(iprot, name, asInt(fieldType), id, maxNestingDepth) + ft := asInt(fieldType) + buf := ([]byte)(*fs)[:cap(*fs)] + offset := len(*fs) + ensureBytesLen(&buf, offset, Binary.FieldBeginLength(name, ft, id)) + offset += Binary.WriteFieldBegin(buf[offset:], name, ft, id) + offset, err = read(&buf, offset, iprot, name, ft, id, maxNestingDepth) if err != nil { return err } - *fs = append(*fs, f) + *fs = buf[:offset] return nil } -// Write writes out the unknown fields. +// Write reads an object of a generalized type from Fields and srializes the object into xprot for compatibility +// with the thrift interface, and the performance is greatly discounted for this reason. +// +// Deprecated: Use the FastCodec api provided by Kitex for serialization/deserialization to improve performance. func (fs *Fields) Write(xprot TProtocol) (err error) { oprot, err := convert(xprot) if err != nil { return err } - var i int - var f *Field - for i, f = range *fs { - if err = oprot.WriteFieldBegin(ctx, f.Name, f.Type, f.ID); err != nil { - break + rbuf := []byte(*fs) + var offset int + for offset < len(rbuf) { + name, fieldType, fieldID, l, err := Binary.ReadFieldBegin(rbuf[offset:]) + offset += l + if err != nil { + return fmt.Errorf("read field begin error: %w", err) + } + if err = oprot.WriteFieldBegin(ctx, name, fieldType, fieldID); err != nil { + return fmt.Errorf("write field begin error: %w", err) + } + + l, err = write(oprot, name, fieldType, fieldID, rbuf[offset:]) + offset += l + if err != nil { + return fmt.Errorf("write struct field error: %w", err) } - if err = write(oprot, f); err != nil { - break + + l, err = Binary.ReadFieldEnd(rbuf[offset:]) + offset += l + if err != nil { + return fmt.Errorf("read field end error: %w", err) } if err = oprot.WriteFieldEnd(ctx); err != nil { - break + return fmt.Errorf("write field end error: %w", err) } } if err != nil { - err = fmt.Errorf("write field error unknown.%d(name:%s type:%d id:%d): %w", - i, f.Name, f.Type, f.ID, err) + err = fmt.Errorf("write field error unknown: %w", err) } return err } -// write writes out the unknown field. -func write(oprot *protocol, f *Field) (err error) { - switch f.Type { +// write writes fields out the oprot. +func write(oprot *protocol, name string, fieldType int, id int16, fs []byte) (offset int, err error) { + switch fieldType { case TBool: - return oprot.WriteBool(ctx, f.Value.(bool)) + v, l, err := Binary.ReadBool(fs[offset:]) + offset += l + if err != nil { + return offset, err + } + if err = oprot.WriteBool(ctx, v); err != nil { + return offset, err + } case TByte: - return oprot.WriteByte(ctx, f.Value.(int8)) - case TDouble: - return oprot.WriteDouble(ctx, f.Value.(float64)) + v, l, err := Binary.ReadByte(fs[offset:]) + offset += l + if err != nil { + return offset, err + } + if err = oprot.WriteByte(ctx, v); err != nil { + return offset, err + } case TI16: - return oprot.WriteI16(ctx, f.Value.(int16)) + v, l, err := Binary.ReadI16(fs[offset:]) + offset += l + if err != nil { + return offset, err + } + if err = oprot.WriteI16(ctx, v); err != nil { + return offset, err + } case TI32: - return oprot.WriteI32(ctx, f.Value.(int32)) + v, l, err := Binary.ReadI32(fs[offset:]) + offset += l + if err != nil { + return offset, err + } + if err = oprot.WriteI32(ctx, v); err != nil { + return offset, err + } case TI64: - return oprot.WriteI64(ctx, f.Value.(int64)) + v, l, err := Binary.ReadI64(fs[offset:]) + offset += l + if err != nil { + return offset, err + } + if err = oprot.WriteI64(ctx, v); err != nil { + return offset, err + } + case TDouble: + v, l, err := Binary.ReadDouble(fs[offset:]) + offset += l + if err != nil { + return offset, err + } + if err = oprot.WriteDouble(ctx, v); err != nil { + return offset, err + } case TString: - return oprot.WriteString(ctx, f.Value.(string)) + v, l, err := Binary.ReadString(fs[offset:]) + offset += l + if err != nil { + return offset, err + } + if err = oprot.WriteString(ctx, v); err != nil { + return offset, err + } case TSet: - vs := f.Value.([]*Field) - if err = oprot.WriteSetBegin(ctx, f.ValType, len(vs)); err != nil { - return fmt.Errorf("write set begin error: %w", err) + ttype, size, l, err := Binary.ReadSetBegin(fs[offset:]) + offset += l + if err != nil { + return offset, fmt.Errorf("read set begin error: %w", err) } - for _, v := range vs { - if err = write(oprot, v); err != nil { - return fmt.Errorf("write set elem error: %w", err) + if err = oprot.WriteSetBegin(ctx, ttype, size); err != nil { + return offset, err + } + for i := 0; i < size; i++ { + l, err = write(oprot, "", ttype, int16(i), fs[offset:]) + offset += l + if err != nil { + return offset, err } } + l, err = Binary.ReadSetEnd(fs[offset:]) + offset += l + if err != nil { + return offset, fmt.Errorf("read set end error: %w", err) + } if err = oprot.WriteSetEnd(ctx); err != nil { - return fmt.Errorf("write set end error: %w", err) + return offset, err } case TList: - vs := f.Value.([]*Field) - if err = oprot.WriteListBegin(ctx, f.ValType, len(vs)); err != nil { - return fmt.Errorf("write list begin error: %w", err) + ttype, size, l, err := Binary.ReadListBegin(fs[offset:]) + offset += l + if err != nil { + return offset, fmt.Errorf("read list begin error: %w", err) + } + if err = oprot.WriteListBegin(ctx, ttype, size); err != nil { + return offset, err } - for _, v := range vs { - if err = write(oprot, v); err != nil { - return fmt.Errorf("write list elem error: %w", err) + for i := 0; i < size; i++ { + l, err = write(oprot, "", ttype, int16(i), fs[offset:]) + offset += l + if err != nil { + return offset, err } } + l, err = Binary.ReadListEnd(fs[offset:]) + offset += l + if err != nil { + return offset, fmt.Errorf("read list end error: %w", err) + } if err = oprot.WriteListEnd(ctx); err != nil { - return fmt.Errorf("write list end error: %w", err) + return offset, err } case TMap: - kvs := f.Value.([]*Field) - if err = oprot.WriteMapBegin(ctx, f.KeyType, f.ValType, len(kvs)/2); err != nil { - return fmt.Errorf("write map begin error: %w", err) + kttype, vttype, size, l, err := Binary.ReadMapBegin(fs[offset:]) + offset += l + if err != nil { + return offset, fmt.Errorf("read map begin error: %w", err) } - for i := 0; i < len(kvs); i += 2 { - if err = write(oprot, kvs[i]); err != nil { - return fmt.Errorf("write map key error: %w", err) + if err = oprot.WriteMapBegin(ctx, kttype, vttype, size); err != nil { + return offset, fmt.Errorf("write map begin error: %w", err) + } + for i := 0; i < size; i++ { + l, err = write(oprot, "", kttype, int16(i), fs[offset:]) + offset += l + if err != nil { + return offset, fmt.Errorf("read map key error: %w", err) } - if err = write(oprot, kvs[i+1]); err != nil { - return fmt.Errorf("write map value error: %w", err) + l, err = write(oprot, "", vttype, int16(i), fs[offset:]) + offset += l + if err != nil { + return offset, fmt.Errorf("read map value error: %w", err) } } + l, err = Binary.ReadMapEnd(fs[offset:]) + offset += l + if err != nil { + return offset, fmt.Errorf("read map end error: %w", err) + } if err = oprot.WriteMapEnd(ctx); err != nil { - return fmt.Errorf("write map end error: %w", err) + return offset, fmt.Errorf("write map end error: %w", err) } case TStruct: - fs := Fields(f.Value.([]*Field)) - if err = oprot.WriteStructBegin(ctx, f.Name); err != nil { - return fmt.Errorf("write struct begin error: %w", err) + _, l, err := Binary.ReadStructBegin(fs[offset:]) + offset += l + if err != nil { + return offset, fmt.Errorf("read struct begin error: %w", err) } - if err = fs.Write(oprot); err != nil { - return fmt.Errorf("write struct field error: %w", err) + for { + name, fieldTypeID, fieldID, l, err := Binary.ReadFieldBegin(fs[offset:]) + offset += l + if err != nil { + return offset, fmt.Errorf("read field begin error: %w", err) + } + if fieldTypeID == TStop { + if err = oprot.WriteFieldStop(ctx); err != nil { + return offset, fmt.Errorf("write field stop error: %w", err) + } + break + } + if err = oprot.WriteFieldBegin(ctx, name, fieldTypeID, fieldID); err != nil { + return offset, fmt.Errorf("write field begin error: %w", err) + } + l, err = write(oprot, name, fieldTypeID, fieldID, fs[offset:]) + offset += l + if err != nil { + return offset, fmt.Errorf("write struct field error: %w", err) + } + l, err = Binary.ReadFieldEnd(fs[offset:]) + offset += l + if err != nil { + return offset, fmt.Errorf("read field end error: %w", err) + } + if err = oprot.WriteFieldEnd(ctx); err != nil { + return offset, fmt.Errorf("write field end error: %w", err) + } } - if err = oprot.WriteFieldStop(ctx); err != nil { - return fmt.Errorf("write struct stop error: %w", err) + l, err = Binary.ReadStructEnd(fs[offset:]) + offset += l + if err != nil { + return offset, fmt.Errorf("read struct end error: %w", err) } if err = oprot.WriteStructEnd(ctx); err != nil { - return fmt.Errorf("write struct end error: %w", err) + return offset, fmt.Errorf("write struct end error: %w", err) } default: - return ErrUnknownType(f.Type) + return offset, ErrUnknownType(fieldType) } - return + return offset, nil } // read reads an unknown field from the given TProtocol. -func read(iprot *protocol, name string, fieldType int, id int16, maxDepth int) (f *Field, err error) { +func read(buf *[]byte, offset int, iprot *protocol, name string, fieldType int, id int16, maxDepth int) (noffset int, err error) { if maxDepth <= 0 { - return nil, ErrExceedDepthLimit + return 0, ErrExceedDepthLimit } - - var size int - f = &Field{Name: name, ID: id, Type: asInt(fieldType)} switch fieldType { case TBool: - f.Value, err = iprot.ReadBool(ctx) + var v bool + v, err = iprot.ReadBool(ctx) + ensureBytesLen(buf, offset, Binary.BoolLength(v)) + offset += Binary.WriteBool((*buf)[offset:], v) case TByte: - f.Value, err = iprot.ReadByte(ctx) + var v int8 + v, err = iprot.ReadByte(ctx) + ensureBytesLen(buf, offset, Binary.ByteLength(v)) + offset += Binary.WriteByte((*buf)[offset:], v) case TI16: - f.Value, err = iprot.ReadI16(ctx) + var v int16 + v, err = iprot.ReadI16(ctx) + ensureBytesLen(buf, offset, Binary.I16Length(v)) + offset += Binary.WriteI16((*buf)[offset:], v) case TI32: - f.Value, err = iprot.ReadI32(ctx) + var v int32 + v, err = iprot.ReadI32(ctx) + ensureBytesLen(buf, offset, Binary.I32Length(v)) + offset += Binary.WriteI32((*buf)[offset:], v) case TI64: - f.Value, err = iprot.ReadI64(ctx) + var v int64 + v, err = iprot.ReadI64(ctx) + ensureBytesLen(buf, offset, Binary.I64Length(v)) + offset += Binary.WriteI64((*buf)[offset:], v) case TDouble: - f.Value, err = iprot.ReadDouble(ctx) + var v float64 + v, err = iprot.ReadDouble(ctx) + ensureBytesLen(buf, offset, Binary.DoubleLength(v)) + offset += Binary.WriteDouble((*buf)[offset:], v) case TString: - f.Value, err = iprot.ReadString(ctx) + var v string + v, err = iprot.ReadString(ctx) + ensureBytesLen(buf, offset, Binary.StringLength(v)) + offset += Binary.WriteString((*buf)[offset:], v) case TSet: - f.ValType, size, err = iprot.ReadSetBegin(ctx) + var valType int + var size int + valType, size, err = iprot.ReadSetBegin(ctx) if err != nil { - return nil, fmt.Errorf("read set begin error: %w", err) + return 0, fmt.Errorf("read set begin error: %w", err) } - set := make([]*Field, 0, size) + ensureBytesLen(buf, offset, Binary.SetBeginLength(valType, size)) + offset += Binary.WriteSetBegin((*buf)[offset:], valType, size) for i := 0; i < size; i++ { - v, err2 := read(iprot, "", f.ValType, int16(i), maxDepth-1) - if err2 != nil { - return nil, fmt.Errorf("read set elem error: %w", err) + offset, err = read(buf, offset, iprot, "", valType, int16(i), maxDepth-1) + if err != nil { + return 0, fmt.Errorf("read set elem error: %w", err) } - set = append(set, v) } if err = iprot.ReadSetEnd(ctx); err != nil { - return nil, fmt.Errorf("read set end error: %w", err) + return 0, fmt.Errorf("read set end error: %w", err) } - f.Value = set + ensureBytesLen(buf, offset, Binary.SetEndLength()) + offset += Binary.WriteSetEnd((*buf)[offset:]) case TList: - f.ValType, size, err = iprot.ReadListBegin(ctx) + var valType int + var size int + valType, size, err = iprot.ReadListBegin(ctx) if err != nil { - return nil, fmt.Errorf("read list begin error: %w", err) + return 0, fmt.Errorf("read list begin error: %w", err) } - list := make([]*Field, 0, size) + ensureBytesLen(buf, offset, Binary.ListBeginLength(valType, size)) + offset += Binary.WriteListBegin((*buf)[offset:], valType, size) for i := 0; i < size; i++ { - v, err2 := read(iprot, "", f.ValType, int16(i), maxDepth-1) - if err2 != nil { - return nil, fmt.Errorf("read list elem error: %w", err) + offset, err = read(buf, offset, iprot, "", valType, int16(i), maxDepth-1) + if err != nil { + return 0, fmt.Errorf("read list elem error: %w", err) } - list = append(list, v) } if err = iprot.ReadListEnd(ctx); err != nil { - return nil, fmt.Errorf("read list end error: %w", err) + return 0, fmt.Errorf("read list end error: %w", err) } - f.Value = list + ensureBytesLen(buf, offset, Binary.ListEndLength()) + offset += Binary.WriteListEnd((*buf)[offset:]) case TMap: - f.KeyType, f.ValType, size, err = iprot.ReadMapBegin(ctx) + var keyType, valType int + var size int + keyType, valType, size, err = iprot.ReadMapBegin(ctx) if err != nil { - return nil, fmt.Errorf("read map begin error: %w", err) + return 0, fmt.Errorf("read map begin error: %w", err) } - flatMap := make([]*Field, 0, size*2) + ensureBytesLen(buf, offset, Binary.MapBeginLength(keyType, valType, size)) + offset += Binary.WriteMapBegin((*buf)[offset:], keyType, valType, size) for i := 0; i < size; i++ { - k, err2 := read(iprot, "", f.KeyType, int16(i), maxDepth-1) - if err2 != nil { - return nil, fmt.Errorf("read map key error: %w", err) + offset, err = read(buf, offset, iprot, "", keyType, int16(i), maxDepth-1) + if err != nil { + return 0, fmt.Errorf("read map key error: %w", err) } - v, err2 := read(iprot, "", f.ValType, int16(i), maxDepth-1) - if err2 != nil { - return nil, fmt.Errorf("read map value error: %w", err) + offset, err = read(buf, offset, iprot, "", valType, int16(i), maxDepth-1) + if err != nil { + return 0, fmt.Errorf("read map value error: %w", err) } - flatMap = append(flatMap, k, v) } if err = iprot.ReadMapEnd(ctx); err != nil { - return nil, fmt.Errorf("read map end error: %w", err) + return 0, fmt.Errorf("read map end error: %w", err) } - f.Value = flatMap + ensureBytesLen(buf, offset, Binary.MapEndLength()) + offset += Binary.WriteMapEnd((*buf)[offset:]) case TStruct: - _, err = iprot.ReadStructBegin(ctx) + name, err := iprot.ReadStructBegin(ctx) if err != nil { - return nil, fmt.Errorf("read struct begin error: %w", err) + return 0, fmt.Errorf("read struct begin error: %w", err) } - var fields []*Field + ensureBytesLen(buf, offset, Binary.StructBeginLength(name)) + offset += Binary.WriteStructBegin((*buf)[offset:], name) for { name, fieldTypeID, fieldID, err := iprot.ReadFieldBegin(ctx) if err != nil { - return nil, fmt.Errorf("read field begin error: %w", err) + return 0, fmt.Errorf("read field begin error: %w", err) } if fieldTypeID == TStop { + ensureBytesLen(buf, offset, Binary.FieldStopLength()) + offset += Binary.WriteFieldStop((*buf)[offset:]) break } - v, err := read(iprot, name, fieldTypeID, fieldID, maxDepth-1) + ensureBytesLen(buf, offset, Binary.FieldBeginLength(name, fieldTypeID, fieldID)) + offset += Binary.WriteFieldBegin((*buf)[offset:], name, fieldTypeID, fieldID) + offset, err = read(buf, offset, iprot, name, fieldTypeID, fieldID, maxDepth-1) if err != nil { - return nil, fmt.Errorf("read struct field error: %w", err) + return 0, fmt.Errorf("read struct field error: %w", err) } - if err := iprot.ReadFieldEnd(ctx); err != nil { - return nil, fmt.Errorf("read field end error: %w", err) + if err = iprot.ReadFieldEnd(ctx); err != nil { + return 0, fmt.Errorf("read field end error: %w", err) } - fields = append(fields, v) + ensureBytesLen(buf, offset, Binary.FieldEndLength()) + offset += Binary.WriteFieldEnd((*buf)[offset:]) } if err = iprot.ReadStructEnd(ctx); err != nil { - return nil, fmt.Errorf("read struct end error: %w", err) + return 0, fmt.Errorf("read struct end error: %w", err) } - f.Value = fields + ensureBytesLen(buf, offset, Binary.StructEndLength()) + offset += Binary.WriteStructEnd((*buf)[offset:]) default: - return nil, ErrUnknownType(fieldType) + return 0, ErrUnknownType(fieldType) } - if err != nil { - return nil, err + return offset, nil +} + +func ensureBytesLen(buf *[]byte, offset int, l int) { + if len(*buf)-offset < l { + nb := make([]byte, (offset+l)*2) + copy(nb, (*buf)[:offset]) + *buf = nb } - return }