Skip to content

Commit f9ea100

Browse files
committed
implement secure connection for sgx
1 parent 71ec5d3 commit f9ea100

File tree

2 files changed

+139
-31
lines changed

2 files changed

+139
-31
lines changed

pkg/vault/sgx/rpc/client.go

Lines changed: 40 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,13 @@ package rpc
22

33
import (
44
"context"
5+
"crypto/ed25519"
56
"encoding/binary"
67
"io"
78
"net"
89
"time"
910

11+
"github.com/ecadlabs/signatory/pkg/utils/secureconn"
1012
"github.com/ecadlabs/signatory/pkg/vault"
1113
"github.com/fxamacker/cbor/v2"
1214
"github.com/kr/pretty"
@@ -21,6 +23,14 @@ func NewClient[C any](conn net.Conn) *Client[C] {
2123
return &Client[C]{conn: conn}
2224
}
2325

26+
func NewSecureClient[C any](conn net.Conn, serverPublicKey ed25519.PublicKey, clientPrivateKey ed25519.PrivateKey) (*Client[C], error) {
27+
secureConn, err := secureconn.WrapConnection(conn, serverPublicKey, clientPrivateKey)
28+
if err != nil {
29+
return nil, err
30+
}
31+
return &Client[C]{conn: secureConn}, nil
32+
}
33+
2434
func (c *Client[C]) Close() error {
2535
return c.conn.Close()
2636
}
@@ -71,22 +81,38 @@ func RoundTripRaw[T, C any](ctx context.Context, conn net.Conn, req *Request[C],
7181
conn.SetDeadline(time.Time{})
7282
}()
7383

74-
wrBuf := make([]byte, len(reqBuf)+4)
75-
binary.BigEndian.PutUint32(wrBuf, uint32(len(reqBuf)))
76-
copy(wrBuf[4:], reqBuf)
77-
if _, err := conn.Write(wrBuf); err != nil {
78-
return res, err
79-
}
84+
_, isSecureConn := conn.(*secureconn.SecureConn)
8085

81-
var lenBuf [4]byte
82-
if _, err := io.ReadFull(conn, lenBuf[:]); err != nil {
83-
return res, err
84-
}
85-
rBuf := make([]byte, int(binary.BigEndian.Uint32(lenBuf[:])))
86-
if _, err := io.ReadFull(conn, rBuf); err != nil {
87-
return res, err
86+
if isSecureConn {
87+
if _, err := conn.Write(reqBuf); err != nil {
88+
return res, err
89+
}
90+
91+
rBuf := make([]byte, 65536)
92+
n, readErr := conn.Read(rBuf)
93+
if readErr != nil {
94+
return res, readErr
95+
}
96+
err = cbor.Unmarshal(rBuf[:n], &res)
97+
} else {
98+
wrBuf := make([]byte, len(reqBuf)+4)
99+
binary.BigEndian.PutUint32(wrBuf, uint32(len(reqBuf)))
100+
copy(wrBuf[4:], reqBuf)
101+
if _, err := conn.Write(wrBuf); err != nil {
102+
return res, err
103+
}
104+
105+
var lenBuf [4]byte
106+
if _, err := io.ReadFull(conn, lenBuf[:]); err != nil {
107+
return res, err
108+
}
109+
rBuf := make([]byte, int(binary.BigEndian.Uint32(lenBuf[:])))
110+
if _, err := io.ReadFull(conn, rBuf); err != nil {
111+
return res, err
112+
}
113+
err = cbor.Unmarshal(rBuf, &res)
88114
}
89-
err = cbor.Unmarshal(rBuf, &res)
115+
90116
if err == nil {
91117
debugLog(">>> %# v\n", pretty.Formatter(&res))
92118
}

pkg/vault/sgx/sgx.go

Lines changed: 99 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,17 @@ package sgx
22

33
import (
44
"context"
5+
"crypto/ed25519"
6+
"encoding/base64"
7+
"encoding/hex"
58
"encoding/json"
69
"errors"
710
"fmt"
811
"iter"
912
"net"
1013
"os"
1114
"path/filepath"
15+
"strings"
1216
"sync"
1317

1418
tz "github.com/ecadlabs/gotez/v2"
@@ -78,13 +82,15 @@ func (e *encryptedKey) UnmarshalJSON(data []byte) error {
7882
///////////////////////////////////////////////////////////////////////////////////////////
7983

8084
type Config struct {
81-
SGXHost string `yaml:"host"`
82-
SGXPort string `yaml:"port"`
83-
EncryptionKeyID string `yaml:"encryption_key_id"`
84-
ProxyPort *uint32 `yaml:"proxy_local_port"`
85-
ProxyRemoteAddress string `yaml:"proxy_remote_address"`
86-
Storage *StorageConfig `yaml:"storage"`
87-
Credentials *Credentials `yaml:"credentials"`
85+
SGXHost string `yaml:"host"`
86+
SGXPort string `yaml:"port"`
87+
ServerPublicKey string `yaml:"server_public_key"` // Ed25519 public key (hex or base64)
88+
ClientPrivateKeyPath string `yaml:"client_private_key_path"` // Path to identity file containing Ed25519 private key
89+
EncryptionKeyID string `yaml:"encryption_key_id"`
90+
ProxyPort *uint32 `yaml:"proxy_local_port"`
91+
ProxyRemoteAddress string `yaml:"proxy_remote_address"`
92+
Storage *StorageConfig `yaml:"storage"`
93+
Credentials *Credentials `yaml:"credentials"`
8894
}
8995

9096
func resolve[T comparable](value T, ev string) T {
@@ -118,18 +124,64 @@ func populateConfig(c *Config) *Config {
118124
c = &zero
119125
}
120126
return &Config{
121-
SGXHost: resolve(c.SGXHost, "SGX_HOST"),
122-
SGXPort: resolve(c.SGXPort, "SGX_PORT"),
123-
EncryptionKeyID: resolve(c.EncryptionKeyID, "ENCRYPTION_KEY_ID"),
124-
ProxyPort: resolvePtr(c.ProxyPort, "PROXY_LOCAL_PORT"),
125-
ProxyRemoteAddress: resolve(c.ProxyRemoteAddress, "PROXY_REMOTE_ADDRESS"),
126-
Storage: c.Storage,
127-
Credentials: c.Credentials,
127+
SGXHost: resolve(c.SGXHost, "SGX_HOST"),
128+
SGXPort: resolve(c.SGXPort, "SGX_PORT"),
129+
ServerPublicKey: resolve(c.ServerPublicKey, "SGX_SERVER_PUBLIC_KEY"),
130+
ClientPrivateKeyPath: resolve(c.ClientPrivateKeyPath, "SGX_CLIENT_PRIVATE_KEY_PATH"),
131+
EncryptionKeyID: resolve(c.EncryptionKeyID, "ENCRYPTION_KEY_ID"),
132+
ProxyPort: resolvePtr(c.ProxyPort, "PROXY_LOCAL_PORT"),
133+
ProxyRemoteAddress: resolve(c.ProxyRemoteAddress, "PROXY_REMOTE_ADDRESS"),
134+
Storage: c.Storage,
135+
Credentials: c.Credentials,
128136
}
129137
}
130138

131139
type Credentials = awsutils.Config
132140

141+
func parseEd25519PublicKey(keyStr string) (ed25519.PublicKey, error) {
142+
if keyStr == "" {
143+
return nil, errors.New("empty public key")
144+
}
145+
data, err := hex.DecodeString(keyStr)
146+
if err != nil {
147+
data, err = base64.StdEncoding.DecodeString(keyStr)
148+
if err != nil {
149+
return nil, fmt.Errorf("failed to decode public key as hex or base64: %w", err)
150+
}
151+
}
152+
if len(data) != ed25519.PublicKeySize {
153+
return nil, fmt.Errorf("invalid public key size: expected %d, got %d", ed25519.PublicKeySize, len(data))
154+
}
155+
return ed25519.PublicKey(data), nil
156+
}
157+
158+
func loadClientPrivateKeyFromFile(path string) (ed25519.PrivateKey, error) {
159+
if path == "" {
160+
return nil, errors.New("empty identity file path")
161+
}
162+
path = os.ExpandEnv(path)
163+
data, err := os.ReadFile(path)
164+
if err != nil {
165+
return nil, fmt.Errorf("failed to read identity file: %w", err)
166+
}
167+
hexStr := string(data)
168+
hexStr = strings.TrimSpace(hexStr)
169+
if hexStr == "" {
170+
return nil, errors.New("identity file is empty")
171+
}
172+
keyData, err := hex.DecodeString(hexStr)
173+
if err != nil {
174+
keyData, err = base64.StdEncoding.DecodeString(hexStr)
175+
if err != nil {
176+
return nil, fmt.Errorf("failed to decode private key as hex or base64: %w", err)
177+
}
178+
}
179+
if len(keyData) != ed25519.PrivateKeySize {
180+
return nil, fmt.Errorf("invalid private key size: expected %d, got %d", ed25519.PrivateKeySize, len(keyData))
181+
}
182+
return ed25519.PrivateKey(keyData), nil
183+
}
184+
133185
type SgxVault[C any] struct {
134186
client *rpc.Client[C]
135187
storage keyBlobStorage
@@ -222,22 +274,52 @@ func newWithStorage(ctx context.Context, config *Config, storage keyBlobStorage)
222274
return nil, errors.New("(SGX): missing SGX port")
223275
}
224276

277+
var serverPublicKey ed25519.PublicKey
278+
var clientPrivateKey ed25519.PrivateKey
279+
if conf.ServerPublicKey != "" {
280+
serverPublicKey, err = parseEd25519PublicKey(conf.ServerPublicKey)
281+
if err != nil {
282+
return nil, fmt.Errorf("(SGX): invalid server public key: %w", err)
283+
}
284+
285+
if conf.ClientPrivateKeyPath == "" {
286+
return nil, errors.New("(SGX): client private key file path required when server public key is provided")
287+
}
288+
289+
clientPrivateKey, err = loadClientPrivateKeyFromFile(conf.ClientPrivateKeyPath)
290+
if err != nil {
291+
return nil, fmt.Errorf("(SGX): failed to load client private key from file: %w", err)
292+
}
293+
}
294+
225295
addr := net.JoinHostPort(conf.SGXHost, conf.SGXPort)
226296
log.Infof("(SGX): connecting to the enclave signer on %v...", addr)
227297
conn, err := net.Dial("tcp", addr)
228298
if err != nil {
229299
return nil, fmt.Errorf("(SGX): %w", err)
230300
}
231301

232-
v, err := newWithConn(ctx, conn, rpcCred, storage)
302+
v, err := newWithConn(ctx, conn, rpcCred, storage, serverPublicKey, clientPrivateKey)
233303
if err != nil {
234304
return nil, err
235305
}
236306
return v, nil
237307
}
238308

239-
func newWithConn[C any](ctx context.Context, conn net.Conn, credentials *C, storage keyBlobStorage) (*SgxVault[C], error) {
240-
client := rpc.NewClient[C](conn)
309+
func newWithConn[C any](ctx context.Context, conn net.Conn, credentials *C, storage keyBlobStorage, serverPublicKey ed25519.PublicKey, clientPrivateKey ed25519.PrivateKey) (*SgxVault[C], error) {
310+
var client *rpc.Client[C]
311+
var err error
312+
313+
if serverPublicKey != nil {
314+
client, err = rpc.NewSecureClient[C](conn, serverPublicKey, clientPrivateKey)
315+
if err != nil {
316+
conn.Close()
317+
return nil, fmt.Errorf("(SGX): failed to establish secure connection: %w", err)
318+
}
319+
} else {
320+
client = rpc.NewClient[C](conn)
321+
}
322+
241323
client.Logger = log.StandardLogger()
242324

243325
if err := client.Initialize(ctx, credentials); err != nil {

0 commit comments

Comments
 (0)