From 63bd59bc0cfc98f032af0ace41a56a947e54b937 Mon Sep 17 00:00:00 2001 From: Axay Sagathiya Date: Fri, 20 Dec 2024 12:55:56 +0530 Subject: [PATCH] Encode decode method of compact statement --- dot/parachain/types/statement.go | 71 ++++++++++++++++++++++----- dot/parachain/types/statement_test.go | 62 +++++++++++++++++++++++ 2 files changed, 120 insertions(+), 13 deletions(-) diff --git a/dot/parachain/types/statement.go b/dot/parachain/types/statement.go index dcfb0d8e65..cfb2608bab 100644 --- a/dot/parachain/types/statement.go +++ b/dot/parachain/types/statement.go @@ -4,7 +4,9 @@ package parachaintypes import ( + "bytes" "fmt" + "io" "github.com/ChainSafe/gossamer/lib/common" "github.com/ChainSafe/gossamer/lib/crypto/sr25519" @@ -12,6 +14,8 @@ import ( "github.com/ChainSafe/gossamer/pkg/scale" ) +var BACKING_STATEMENT_MAGIC = [4]byte{'B', 'K', 'N', 'G'} + // Statement is a result of candidate validation. It could be either `Valid` or `Seconded`. type StatementVDTValues interface { Valid | Seconded @@ -184,17 +188,16 @@ type CompactStatementValues interface { Valid | SecondedCandidateHash } -// Statements that can be made about parachain candidates. -// These are the actual values that are signed. -type CompactStatement struct { +// compactStatementInner is a helper struct that is used to encode/decode CompactStatement. +type compactStatementInner struct { inner any } -func setCompactStatement[Value CompactStatementValues](mvdt *CompactStatement, value Value) { +func setCompactStatement[Value CompactStatementValues](mvdt *compactStatementInner, value Value) { mvdt.inner = value } -func (mvdt *CompactStatement) SetValue(value any) (err error) { +func (mvdt *compactStatementInner) SetValue(value any) (err error) { switch value := value.(type) { case Valid: setCompactStatement(mvdt, value) @@ -207,7 +210,7 @@ func (mvdt *CompactStatement) SetValue(value any) (err error) { } } -func (mvdt CompactStatement) IndexValue() (index uint, value any, err error) { +func (mvdt compactStatementInner) IndexValue() (index uint, value any, err error) { switch mvdt.inner.(type) { case Valid: return 2, mvdt.inner, nil @@ -217,12 +220,12 @@ func (mvdt CompactStatement) IndexValue() (index uint, value any, err error) { return 0, nil, scale.ErrUnsupportedVaryingDataTypeValue } -func (mvdt CompactStatement) Value() (value any, err error) { +func (mvdt compactStatementInner) Value() (value any, err error) { _, value, err = mvdt.IndexValue() return } -func (mvdt CompactStatement) ValueAt(index uint) (value any, err error) { +func (mvdt compactStatementInner) ValueAt(index uint) (value any, err error) { switch index { case 2: return Valid{}, nil @@ -232,12 +235,54 @@ func (mvdt CompactStatement) ValueAt(index uint) (value any, err error) { return nil, scale.ErrUnknownVaryingDataTypeValue } -func (c *CompactStatement) Encode() ([]byte, error) { - // TODO: implement this - return nil, nil +// Statements that can be made about parachain candidates. +// These are the actual values that are signed. +type CompactStatement[T CompactStatementValues] struct { + Value T +} + +func (c CompactStatement[CompactStatementValues]) MarshalSCALE() ([]byte, error) { + inner := compactStatementInner{} + err := inner.SetValue(c.Value) + if err != nil { + return nil, fmt.Errorf("setting value: %w", err) + } + + buffer := bytes.NewBuffer(BACKING_STATEMENT_MAGIC[:]) + encoder := scale.NewEncoder(buffer) + + err = encoder.Encode(inner) + if err != nil { + return nil, err + } + + return buffer.Bytes(), nil } -func (c *CompactStatement) Decode(in []byte) error { - // TODO: implement this +func (c *CompactStatement[CompactStatementValues]) UnmarshalSCALE(reader io.Reader) error { + decoder := scale.NewDecoder(reader) + + var magicBytes [4]byte + err := decoder.Decode(&magicBytes) + if err != nil { + return err + } + + if !bytes.Equal(magicBytes[:], BACKING_STATEMENT_MAGIC[:]) { + return fmt.Errorf("invalid magic bytes") + } + + var inner compactStatementInner + err = decoder.Decode(&inner) + if err != nil { + return fmt.Errorf("decoding compactStatementInner: %w", err) + } + + value, err := inner.Value() + if err != nil { + return fmt.Errorf("getting value: %w", err) + } + + c.Value = value.(CompactStatementValues) return nil } diff --git a/dot/parachain/types/statement_test.go b/dot/parachain/types/statement_test.go index 4e8a80024f..28d3fd132d 100644 --- a/dot/parachain/types/statement_test.go +++ b/dot/parachain/types/statement_test.go @@ -169,3 +169,65 @@ func TestStatementVDT_Sign(t *testing.T) { require.NoError(t, err) require.True(t, ok) } + +func TestCompactStatement(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + compactStatement any + encodingValue []byte + expectedErr error + }{ + { + name: "SecondedCandidateHash", + compactStatement: CompactStatement[SecondedCandidateHash]{ + Value: SecondedCandidateHash{Value: getDummyHash(6)}, + }, + encodingValue: []byte{66, 75, 78, 71, 1, + 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6}, + }, + { + name: "Valid", + compactStatement: CompactStatement[Valid]{ + Value: Valid{Value: getDummyHash(7)}, + }, + encodingValue: []byte{ + 66, 75, 78, 71, 2, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7}, + }, + } + + for _, c := range testCases { + c := c + t.Run(c.name, func(t *testing.T) { + t.Parallel() + + t.Run("marshal", func(t *testing.T) { + t.Parallel() + + bytes, err := scale.Marshal(c.compactStatement) + require.NoError(t, err) + require.Equal(t, c.encodingValue, bytes) + }) + + t.Run("unmarshal", func(t *testing.T) { + t.Parallel() + + switch value := c.compactStatement.(type) { + case CompactStatement[Valid]: + var statement CompactStatement[Valid] + err := scale.Unmarshal(c.encodingValue, &statement) + require.NoError(t, err) + require.EqualValues(t, value, statement) + case CompactStatement[SecondedCandidateHash]: + var statement CompactStatement[SecondedCandidateHash] + err := scale.Unmarshal(c.encodingValue, &statement) + require.NoError(t, err) + require.EqualValues(t, value, statement) + } + }) + + }) + } +}