@@ -2,13 +2,17 @@ package sgx
22
33import (
44 "context"
5+ "crypto/ed25519"
6+ "encoding/base64"
7+ "encoding/hex"
58 "encoding/json"
69 "errors"
710 "fmt"
811 "iter"
912 "net"
1013 "os"
1114 "path/filepath"
15+ "strings"
1216 "sync"
1317
1418 tz "github.com/ecadlabs/gotez/v2"
@@ -78,13 +82,15 @@ func (e *encryptedKey) UnmarshalJSON(data []byte) error {
7882///////////////////////////////////////////////////////////////////////////////////////////
7983
8084type Config struct {
81- SGXHost string `yaml:"host"`
82- SGXPort string `yaml:"port"`
83- EncryptionKeyID string `yaml:"encryption_key_id"`
84- ProxyPort * uint32 `yaml:"proxy_local_port"`
85- ProxyRemoteAddress string `yaml:"proxy_remote_address"`
86- Storage * StorageConfig `yaml:"storage"`
87- Credentials * Credentials `yaml:"credentials"`
85+ SGXHost string `yaml:"host"`
86+ SGXPort string `yaml:"port"`
87+ ServerPublicKey string `yaml:"server_public_key"` // Ed25519 public key (hex or base64)
88+ ClientPrivateKeyPath string `yaml:"client_private_key_path"` // Path to identity file containing Ed25519 private key
89+ EncryptionKeyID string `yaml:"encryption_key_id"`
90+ ProxyPort * uint32 `yaml:"proxy_local_port"`
91+ ProxyRemoteAddress string `yaml:"proxy_remote_address"`
92+ Storage * StorageConfig `yaml:"storage"`
93+ Credentials * Credentials `yaml:"credentials"`
8894}
8995
9096func resolve [T comparable ](value T , ev string ) T {
@@ -118,18 +124,64 @@ func populateConfig(c *Config) *Config {
118124 c = & zero
119125 }
120126 return & Config {
121- SGXHost : resolve (c .SGXHost , "SGX_HOST" ),
122- SGXPort : resolve (c .SGXPort , "SGX_PORT" ),
123- EncryptionKeyID : resolve (c .EncryptionKeyID , "ENCRYPTION_KEY_ID" ),
124- ProxyPort : resolvePtr (c .ProxyPort , "PROXY_LOCAL_PORT" ),
125- ProxyRemoteAddress : resolve (c .ProxyRemoteAddress , "PROXY_REMOTE_ADDRESS" ),
126- Storage : c .Storage ,
127- Credentials : c .Credentials ,
127+ SGXHost : resolve (c .SGXHost , "SGX_HOST" ),
128+ SGXPort : resolve (c .SGXPort , "SGX_PORT" ),
129+ ServerPublicKey : resolve (c .ServerPublicKey , "SGX_SERVER_PUBLIC_KEY" ),
130+ ClientPrivateKeyPath : resolve (c .ClientPrivateKeyPath , "SGX_CLIENT_PRIVATE_KEY_PATH" ),
131+ EncryptionKeyID : resolve (c .EncryptionKeyID , "ENCRYPTION_KEY_ID" ),
132+ ProxyPort : resolvePtr (c .ProxyPort , "PROXY_LOCAL_PORT" ),
133+ ProxyRemoteAddress : resolve (c .ProxyRemoteAddress , "PROXY_REMOTE_ADDRESS" ),
134+ Storage : c .Storage ,
135+ Credentials : c .Credentials ,
128136 }
129137}
130138
131139type Credentials = awsutils.Config
132140
141+ func parseEd25519PublicKey (keyStr string ) (ed25519.PublicKey , error ) {
142+ if keyStr == "" {
143+ return nil , errors .New ("empty public key" )
144+ }
145+ data , err := hex .DecodeString (keyStr )
146+ if err != nil {
147+ data , err = base64 .StdEncoding .DecodeString (keyStr )
148+ if err != nil {
149+ return nil , fmt .Errorf ("failed to decode public key as hex or base64: %w" , err )
150+ }
151+ }
152+ if len (data ) != ed25519 .PublicKeySize {
153+ return nil , fmt .Errorf ("invalid public key size: expected %d, got %d" , ed25519 .PublicKeySize , len (data ))
154+ }
155+ return ed25519 .PublicKey (data ), nil
156+ }
157+
158+ func loadClientPrivateKeyFromFile (path string ) (ed25519.PrivateKey , error ) {
159+ if path == "" {
160+ return nil , errors .New ("empty identity file path" )
161+ }
162+ path = os .ExpandEnv (path )
163+ data , err := os .ReadFile (path )
164+ if err != nil {
165+ return nil , fmt .Errorf ("failed to read identity file: %w" , err )
166+ }
167+ hexStr := string (data )
168+ hexStr = strings .TrimSpace (hexStr )
169+ if hexStr == "" {
170+ return nil , errors .New ("identity file is empty" )
171+ }
172+ keyData , err := hex .DecodeString (hexStr )
173+ if err != nil {
174+ keyData , err = base64 .StdEncoding .DecodeString (hexStr )
175+ if err != nil {
176+ return nil , fmt .Errorf ("failed to decode private key as hex or base64: %w" , err )
177+ }
178+ }
179+ if len (keyData ) != ed25519 .PrivateKeySize {
180+ return nil , fmt .Errorf ("invalid private key size: expected %d, got %d" , ed25519 .PrivateKeySize , len (keyData ))
181+ }
182+ return ed25519 .PrivateKey (keyData ), nil
183+ }
184+
133185type SgxVault [C any ] struct {
134186 client * rpc.Client [C ]
135187 storage keyBlobStorage
@@ -222,22 +274,52 @@ func newWithStorage(ctx context.Context, config *Config, storage keyBlobStorage)
222274 return nil , errors .New ("(SGX): missing SGX port" )
223275 }
224276
277+ var serverPublicKey ed25519.PublicKey
278+ var clientPrivateKey ed25519.PrivateKey
279+ if conf .ServerPublicKey != "" {
280+ serverPublicKey , err = parseEd25519PublicKey (conf .ServerPublicKey )
281+ if err != nil {
282+ return nil , fmt .Errorf ("(SGX): invalid server public key: %w" , err )
283+ }
284+
285+ if conf .ClientPrivateKeyPath == "" {
286+ return nil , errors .New ("(SGX): client private key file path required when server public key is provided" )
287+ }
288+
289+ clientPrivateKey , err = loadClientPrivateKeyFromFile (conf .ClientPrivateKeyPath )
290+ if err != nil {
291+ return nil , fmt .Errorf ("(SGX): failed to load client private key from file: %w" , err )
292+ }
293+ }
294+
225295 addr := net .JoinHostPort (conf .SGXHost , conf .SGXPort )
226296 log .Infof ("(SGX): connecting to the enclave signer on %v..." , addr )
227297 conn , err := net .Dial ("tcp" , addr )
228298 if err != nil {
229299 return nil , fmt .Errorf ("(SGX): %w" , err )
230300 }
231301
232- v , err := newWithConn (ctx , conn , rpcCred , storage )
302+ v , err := newWithConn (ctx , conn , rpcCred , storage , serverPublicKey , clientPrivateKey )
233303 if err != nil {
234304 return nil , err
235305 }
236306 return v , nil
237307}
238308
239- func newWithConn [C any ](ctx context.Context , conn net.Conn , credentials * C , storage keyBlobStorage ) (* SgxVault [C ], error ) {
240- client := rpc.NewClient [C ](conn )
309+ func newWithConn [C any ](ctx context.Context , conn net.Conn , credentials * C , storage keyBlobStorage , serverPublicKey ed25519.PublicKey , clientPrivateKey ed25519.PrivateKey ) (* SgxVault [C ], error ) {
310+ var client * rpc.Client [C ]
311+ var err error
312+
313+ if serverPublicKey != nil {
314+ client , err = rpc .NewSecureClient [C ](conn , serverPublicKey , clientPrivateKey )
315+ if err != nil {
316+ conn .Close ()
317+ return nil , fmt .Errorf ("(SGX): failed to establish secure connection: %w" , err )
318+ }
319+ } else {
320+ client = rpc.NewClient [C ](conn )
321+ }
322+
241323 client .Logger = log .StandardLogger ()
242324
243325 if err := client .Initialize (ctx , credentials ); err != nil {
0 commit comments