Skip to content

Commit

Permalink
Encode decode method of compact statement
Browse files Browse the repository at this point in the history
  • Loading branch information
axaysagathiya committed Dec 20, 2024
1 parent 3156a15 commit 63bd59b
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 13 deletions.
71 changes: 58 additions & 13 deletions dot/parachain/types/statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,18 @@
package parachaintypes

import (
"bytes"
"fmt"
"io"

"github.com/ChainSafe/gossamer/lib/common"
"github.com/ChainSafe/gossamer/lib/crypto/sr25519"
"github.com/ChainSafe/gossamer/lib/keystore"
"github.com/ChainSafe/gossamer/pkg/scale"
)

var BACKING_STATEMENT_MAGIC = [4]byte{'B', 'K', 'N', 'G'}

// Statement is a result of candidate validation. It could be either `Valid` or `Seconded`.
type StatementVDTValues interface {
Valid | Seconded
Expand Down Expand Up @@ -184,17 +188,16 @@ type CompactStatementValues interface {
Valid | SecondedCandidateHash
}

// Statements that can be made about parachain candidates.
// These are the actual values that are signed.
type CompactStatement struct {
// compactStatementInner is a helper struct that is used to encode/decode CompactStatement.
type compactStatementInner struct {
inner any
}

func setCompactStatement[Value CompactStatementValues](mvdt *CompactStatement, value Value) {
func setCompactStatement[Value CompactStatementValues](mvdt *compactStatementInner, value Value) {
mvdt.inner = value
}

func (mvdt *CompactStatement) SetValue(value any) (err error) {
func (mvdt *compactStatementInner) SetValue(value any) (err error) {
switch value := value.(type) {
case Valid:
setCompactStatement(mvdt, value)
Expand All @@ -207,7 +210,7 @@ func (mvdt *CompactStatement) SetValue(value any) (err error) {
}
}

func (mvdt CompactStatement) IndexValue() (index uint, value any, err error) {
func (mvdt compactStatementInner) IndexValue() (index uint, value any, err error) {
switch mvdt.inner.(type) {
case Valid:
return 2, mvdt.inner, nil
Expand All @@ -217,12 +220,12 @@ func (mvdt CompactStatement) IndexValue() (index uint, value any, err error) {
return 0, nil, scale.ErrUnsupportedVaryingDataTypeValue
}

func (mvdt CompactStatement) Value() (value any, err error) {
func (mvdt compactStatementInner) Value() (value any, err error) {
_, value, err = mvdt.IndexValue()
return
}

func (mvdt CompactStatement) ValueAt(index uint) (value any, err error) {
func (mvdt compactStatementInner) ValueAt(index uint) (value any, err error) {
switch index {
case 2:
return Valid{}, nil
Expand All @@ -232,12 +235,54 @@ func (mvdt CompactStatement) ValueAt(index uint) (value any, err error) {
return nil, scale.ErrUnknownVaryingDataTypeValue
}

func (c *CompactStatement) Encode() ([]byte, error) {
// TODO: implement this
return nil, nil
// Statements that can be made about parachain candidates.
// These are the actual values that are signed.
type CompactStatement[T CompactStatementValues] struct {
Value T
}

func (c CompactStatement[CompactStatementValues]) MarshalSCALE() ([]byte, error) {
inner := compactStatementInner{}
err := inner.SetValue(c.Value)
if err != nil {
return nil, fmt.Errorf("setting value: %w", err)
}

buffer := bytes.NewBuffer(BACKING_STATEMENT_MAGIC[:])
encoder := scale.NewEncoder(buffer)

err = encoder.Encode(inner)
if err != nil {
return nil, err
}

return buffer.Bytes(), nil
}

func (c *CompactStatement) Decode(in []byte) error {
// TODO: implement this
func (c *CompactStatement[CompactStatementValues]) UnmarshalSCALE(reader io.Reader) error {
decoder := scale.NewDecoder(reader)

var magicBytes [4]byte
err := decoder.Decode(&magicBytes)
if err != nil {
return err
}

if !bytes.Equal(magicBytes[:], BACKING_STATEMENT_MAGIC[:]) {
return fmt.Errorf("invalid magic bytes")
}

var inner compactStatementInner
err = decoder.Decode(&inner)
if err != nil {
return fmt.Errorf("decoding compactStatementInner: %w", err)
}

value, err := inner.Value()
if err != nil {
return fmt.Errorf("getting value: %w", err)
}

c.Value = value.(CompactStatementValues)
return nil
}
62 changes: 62 additions & 0 deletions dot/parachain/types/statement_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,3 +169,65 @@ func TestStatementVDT_Sign(t *testing.T) {
require.NoError(t, err)
require.True(t, ok)
}

func TestCompactStatement(t *testing.T) {
t.Parallel()

testCases := []struct {
name string
compactStatement any
encodingValue []byte
expectedErr error
}{
{
name: "SecondedCandidateHash",
compactStatement: CompactStatement[SecondedCandidateHash]{
Value: SecondedCandidateHash{Value: getDummyHash(6)},
},
encodingValue: []byte{66, 75, 78, 71, 1,
6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6},
},
{
name: "Valid",
compactStatement: CompactStatement[Valid]{
Value: Valid{Value: getDummyHash(7)},
},
encodingValue: []byte{
66, 75, 78, 71, 2,
7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7},
},
}

for _, c := range testCases {
c := c
t.Run(c.name, func(t *testing.T) {
t.Parallel()

t.Run("marshal", func(t *testing.T) {
t.Parallel()

bytes, err := scale.Marshal(c.compactStatement)
require.NoError(t, err)
require.Equal(t, c.encodingValue, bytes)
})

t.Run("unmarshal", func(t *testing.T) {
t.Parallel()

switch value := c.compactStatement.(type) {
case CompactStatement[Valid]:
var statement CompactStatement[Valid]
err := scale.Unmarshal(c.encodingValue, &statement)
require.NoError(t, err)
require.EqualValues(t, value, statement)
case CompactStatement[SecondedCandidateHash]:
var statement CompactStatement[SecondedCandidateHash]
err := scale.Unmarshal(c.encodingValue, &statement)
require.NoError(t, err)
require.EqualValues(t, value, statement)
}
})

})
}
}

0 comments on commit 63bd59b

Please sign in to comment.