Skip to content

Commit 71ec5d3

Browse files
committed
add Intel SGX vault support with file storage implementation
1 parent 639ac33 commit 71ec5d3

File tree

5 files changed

+992
-0
lines changed

5 files changed

+992
-0
lines changed

pkg/vault/preamble/preamble.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,6 @@ import (
1212
_ "github.com/ecadlabs/signatory/pkg/vault/mem"
1313
_ "github.com/ecadlabs/signatory/pkg/vault/nitro"
1414
_ "github.com/ecadlabs/signatory/pkg/vault/pkcs11"
15+
_ "github.com/ecadlabs/signatory/pkg/vault/sgx"
1516
_ "github.com/ecadlabs/signatory/pkg/vault/yubi"
1617
)

pkg/vault/sgx/file_storage.go

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
package sgx
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"errors"
7+
"iter"
8+
"os"
9+
"slices"
10+
"sync"
11+
12+
"github.com/ecadlabs/signatory/pkg/utils"
13+
)
14+
15+
type fileStorage struct {
16+
path string
17+
mtx sync.RWMutex
18+
keys []*encryptedKey
19+
}
20+
21+
func newFileStorage(path string) (*fileStorage, error) {
22+
buf, err := os.ReadFile(path)
23+
if err != nil || len(buf) == 0 {
24+
if err != nil && !errors.Is(err, os.ErrNotExist) {
25+
return nil, err
26+
}
27+
return &fileStorage{
28+
path: path,
29+
keys: make([]*encryptedKey, 0),
30+
}, nil
31+
}
32+
33+
var keys []*encryptedKey
34+
if err = json.Unmarshal(buf, &keys); err != nil {
35+
return nil, err
36+
}
37+
return &fileStorage{
38+
path: path,
39+
keys: keys,
40+
}, nil
41+
}
42+
43+
type fileResult struct {
44+
keys []*encryptedKey
45+
}
46+
47+
func (f *fileResult) Err() error { return nil }
48+
func (f *fileResult) Result() iter.Seq[*encryptedKey] {
49+
return slices.Values(f.keys)
50+
}
51+
52+
func (f *fileStorage) GetKeys(ctx context.Context) (result[*encryptedKey], error) {
53+
f.mtx.RLock()
54+
defer f.mtx.RUnlock()
55+
return &fileResult{keys: f.keys}, nil
56+
}
57+
58+
func (f *fileStorage) ImportKey(ctx context.Context, encryptedKey *encryptedKey) (err error) {
59+
f.mtx.Lock()
60+
defer f.mtx.Unlock()
61+
62+
f.keys = append(f.keys, encryptedKey)
63+
64+
data, err := json.MarshalIndent(f.keys, "", " ")
65+
if err != nil {
66+
return err
67+
}
68+
return utils.WriteRename(f.path, "keys", data)
69+
}

pkg/vault/sgx/rpc/client.go

Lines changed: 223 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,223 @@
1+
package rpc
2+
3+
import (
4+
"context"
5+
"encoding/binary"
6+
"io"
7+
"net"
8+
"time"
9+
10+
"github.com/ecadlabs/signatory/pkg/vault"
11+
"github.com/fxamacker/cbor/v2"
12+
"github.com/kr/pretty"
13+
)
14+
15+
type Client[C any] struct {
16+
Logger Logger
17+
conn net.Conn
18+
}
19+
20+
func NewClient[C any](conn net.Conn) *Client[C] {
21+
return &Client[C]{conn: conn}
22+
}
23+
24+
func (c *Client[C]) Close() error {
25+
return c.conn.Close()
26+
}
27+
28+
func (c *Client[C]) Conn() net.Conn {
29+
return c.conn
30+
}
31+
32+
type Logger interface {
33+
Debugf(format string, args ...interface{})
34+
}
35+
36+
var aLongTimeAgo = time.Unix(1, 0)
37+
38+
func RoundTripRaw[T, C any](ctx context.Context, conn net.Conn, req *Request[C], log Logger) (r T, err error) {
39+
var debugLog func(format string, args ...interface{})
40+
if log != nil {
41+
debugLog = log.Debugf
42+
} else {
43+
debugLog = func(string, ...interface{}) {}
44+
}
45+
46+
var res T
47+
reqBuf, err := cbor.Marshal(req)
48+
if err != nil {
49+
return res, err
50+
}
51+
debugLog("<<< %# v\n", pretty.Formatter(req))
52+
53+
intErr := make(chan error)
54+
done := make(chan struct{})
55+
56+
go func() {
57+
select {
58+
case <-ctx.Done():
59+
conn.SetDeadline(aLongTimeAgo)
60+
intErr <- ctx.Err()
61+
case <-done:
62+
intErr <- nil
63+
}
64+
}()
65+
66+
defer func() {
67+
close(done)
68+
if e := <-intErr; e != nil {
69+
err = e
70+
}
71+
conn.SetDeadline(time.Time{})
72+
}()
73+
74+
wrBuf := make([]byte, len(reqBuf)+4)
75+
binary.BigEndian.PutUint32(wrBuf, uint32(len(reqBuf)))
76+
copy(wrBuf[4:], reqBuf)
77+
if _, err := conn.Write(wrBuf); err != nil {
78+
return res, err
79+
}
80+
81+
var lenBuf [4]byte
82+
if _, err := io.ReadFull(conn, lenBuf[:]); err != nil {
83+
return res, err
84+
}
85+
rBuf := make([]byte, int(binary.BigEndian.Uint32(lenBuf[:])))
86+
if _, err := io.ReadFull(conn, rBuf); err != nil {
87+
return res, err
88+
}
89+
err = cbor.Unmarshal(rBuf, &res)
90+
if err == nil {
91+
debugLog(">>> %# v\n", pretty.Formatter(&res))
92+
}
93+
return res, err
94+
}
95+
96+
func RoundTrip[T, C any](ctx context.Context, conn net.Conn, req *Request[C], log Logger) (r Result[*T], err error) {
97+
return RoundTripRaw[Result[*T]](ctx, conn, req, log)
98+
}
99+
100+
func (c *Client[C]) Initialize(ctx context.Context, cred *C) error {
101+
res, err := RoundTrip[struct{}](ctx, c.conn, &Request[C]{Initialize: cred}, c.Logger)
102+
if err != nil {
103+
return err
104+
}
105+
return res.Error()
106+
}
107+
108+
func (c *Client[C]) Import(ctx context.Context, keyData []byte) (*ImportResult, error) {
109+
res, err := RoundTrip[ImportResult](ctx, c.conn, &Request[C]{
110+
Import: keyData,
111+
}, c.Logger)
112+
if err == nil && res.Error() != nil {
113+
err = res.Error()
114+
}
115+
if err != nil {
116+
return nil, err
117+
}
118+
return res.Ok, nil
119+
}
120+
121+
func (c *Client[C]) ImportUnencrypted(ctx context.Context, priv *PrivateKey) (*GenerateAndImportResult, error) {
122+
res, err := RoundTrip[GenerateAndImportResult](ctx, c.conn, &Request[C]{
123+
ImportUnencrypted: priv,
124+
}, c.Logger)
125+
if err == nil && res.Error() != nil {
126+
err = res.Error()
127+
}
128+
if err != nil {
129+
return nil, err
130+
}
131+
return res.Ok, nil
132+
}
133+
134+
func (c *Client[C]) Generate(ctx context.Context, keyType KeyType) (*GenerateResult, error) {
135+
res, err := RoundTrip[GenerateResult](ctx, c.conn, &Request[C]{
136+
Generate: &keyType,
137+
}, c.Logger)
138+
if err == nil && res.Error() != nil {
139+
err = res.Error()
140+
}
141+
if err != nil {
142+
return nil, err
143+
}
144+
return res.Ok, nil
145+
}
146+
147+
func (c *Client[C]) GenerateAndImport(ctx context.Context, keyType KeyType) (*GenerateAndImportResult, error) {
148+
res, err := RoundTrip[GenerateAndImportResult](ctx, c.conn, &Request[C]{
149+
GenerateAndImport: &keyType,
150+
}, c.Logger)
151+
if err == nil && res.Error() != nil {
152+
err = res.Error()
153+
}
154+
if err != nil {
155+
return nil, err
156+
}
157+
return res.Ok, nil
158+
}
159+
160+
func (c *Client[C]) Sign(ctx context.Context, handle uint64, message []byte, opt *vault.SignOptions) (sig *Signature, err error) {
161+
res, err := RoundTrip[Signature](ctx, c.conn, &Request[C]{
162+
Sign: &SignRequest{Handle: handle, Message: message, Version: opt.Version.ToUint8()},
163+
}, c.Logger)
164+
if err == nil && res.Error() != nil {
165+
err = res.Error()
166+
}
167+
if err != nil {
168+
return nil, err
169+
}
170+
return res.Ok, nil
171+
}
172+
173+
func (c *Client[C]) SignWith(ctx context.Context, keyData []byte, message []byte) (sig *Signature, err error) {
174+
res, err := RoundTrip[Signature](ctx, c.conn, &Request[C]{
175+
SignWith: &SignWithRequest{EncryptedPrivateKey: keyData, Message: message},
176+
}, c.Logger)
177+
if err == nil && res.Error() != nil {
178+
err = res.Error()
179+
}
180+
if err != nil {
181+
return nil, err
182+
}
183+
return res.Ok, nil
184+
}
185+
186+
func (c *Client[C]) PublicKey(ctx context.Context, handle uint64) (publicKey *PublicKey, err error) {
187+
res, err := RoundTrip[PublicKey](ctx, c.conn, &Request[C]{
188+
PublicKey: &handle,
189+
}, c.Logger)
190+
if err == nil && res.Error() != nil {
191+
err = res.Error()
192+
}
193+
if err != nil {
194+
return nil, err
195+
}
196+
return res.Ok, nil
197+
}
198+
199+
func (c *Client[C]) PublicKeyFrom(ctx context.Context, data []byte) (publicKey *PublicKey, err error) {
200+
res, err := RoundTrip[PublicKey](ctx, c.conn, &Request[C]{
201+
PublicKeyFrom: data,
202+
}, c.Logger)
203+
if err == nil && res.Error() != nil {
204+
err = res.Error()
205+
}
206+
if err != nil {
207+
return nil, err
208+
}
209+
return res.Ok, nil
210+
}
211+
212+
func (c *Client[C]) ProvePossession(ctx context.Context, handle uint64) (sig *Signature, err error) {
213+
res, err := RoundTrip[Signature](ctx, c.conn, &Request[C]{
214+
ProvePossession: &handle,
215+
}, c.Logger)
216+
if err == nil && res.Error() != nil {
217+
err = res.Error()
218+
}
219+
if err != nil {
220+
return nil, err
221+
}
222+
return res.Ok, nil
223+
}

0 commit comments

Comments
 (0)