From 1e8663d114102f81c4713670fad46c1fd64b9cb9 Mon Sep 17 00:00:00 2001 From: Christian Mouchet Date: Fri, 12 Jul 2024 18:22:48 +0200 Subject: [PATCH] first version of secret provisioning interface --- examples/vec-mul/main.go | 42 +++++++++++++++++++--------------------- helium.go | 9 +++++---- helium_test.go | 4 ++-- node/config.go | 5 ++++- node/localtest.go | 22 ++++++++++----------- node/node.go | 29 +++++++++++++-------------- sessions/sessions.go | 16 +++++++-------- sessions/sessionstore.go | 4 ++-- sessions/testsession.go | 5 ++--- 9 files changed, 67 insertions(+), 69 deletions(-) diff --git a/examples/vec-mul/main.go b/examples/vec-mul/main.go index 0176ea0..aff8745 100644 --- a/examples/vec-mul/main.go +++ b/examples/vec-mul/main.go @@ -35,13 +35,11 @@ var ( Threshold: 3, // the number of honest nodes assumed by the system. ShamirPks: map[sessions.NodeID]mhe.ShamirPublicPoint{"node-1": 1, "node-2": 2, "node-3": 3, "node-4": 4}, // the shamir public-key of the nodes for the t-out-of-n-threshold scheme. PublicSeed: []byte{'e', 'x', 'a', 'm', 'p', 'l', 'e', 's', 'e', 'e', 'd'}, // the CRS - Secrets: nil, // normally read from a file, simulated here for simplicity (see loadSecrets) } // the configuration of peer nodes peerNodeConfig = node.Config{ ID: "", // read from command line args - Address: "", // read from command line args HelperID: "helper", // the node id of the helper node SessionParameters: []sessions.Parameters{sessionParams}, @@ -56,7 +54,6 @@ var ( // the configuration of the helper node. Similar as for peer node, but enables multiple protocol and circuit evaluations at once. helperConfig = node.Config{ ID: "", // read from command line args - Address: "", // read from command line args HelperID: "helper", SessionParameters: []sessions.Parameters{sessionParams}, @@ -144,17 +141,8 @@ func main() { var config node.Config if nodeID == helperID { config = helperConfig - if len(nodeAddr) == 0 { - log.Fatal("address of helper node not set, must provide with -address flag") - } - config.Address = nodeAddr } else { config = peerNodeConfig - secrets, err := loadSecrets(config.SessionParameters[0], nodeID) // session node must load their secrets. - if err != nil { - log.Fatalf("could not load node's secrets: %s", err) - } - config.SessionParameters[0].Secrets = secrets } config.ID = nodeID @@ -185,7 +173,8 @@ func main() { if nodeID == helperID { statsProvider, cdescs, outs, err = helium.RunHeliumServer(ctx, config, nodelist, app, ip) } else { - statsProvider, outs, err = helium.RunHeliumClient(ctx, config, nodelist, app, ip) + secrets := loadSecrets(config.SessionParameters[0], nodeID) + statsProvider, outs, err = helium.RunHeliumClient(ctx, config, nodelist, secrets, app, ip) } if err != nil { log.Fatalf("could not run node: %s", err) @@ -234,17 +223,26 @@ func main() { } // simulates loading the secrets. In a real application, the secrets would be loaded from a secure storage. -func loadSecrets(sp sessions.Parameters, nid sessions.NodeID) (secrets *sessions.Secrets, err error) { +func loadSecrets(params sessions.Parameters, nid sessions.NodeID) node.SecretProvider { - ss, err := sessions.GenTestSecretKeys(sp) - if err != nil { - return nil, err - } + var sp node.SecretProvider = func(sid sessions.ID) (*sessions.Secrets, error) { + + if sid != params.ID { + return nil, fmt.Errorf("no secret for session %s", sid) + } + + ss, err := sessions.GenTestSecretKeys(params) + if err != nil { + return nil, err + } + + secrets, ok := ss[nid] + if !ok { + return nil, fmt.Errorf("node %s not in session", nid) + } - secrets, ok := ss[nid] - if !ok { - return nil, fmt.Errorf("node %s not in session", nid) + return secrets, nil } - return + return sp } diff --git a/helium.go b/helium.go index 8f86a30..b11a402 100644 --- a/helium.go +++ b/helium.go @@ -14,14 +14,15 @@ import ( func RunHeliumServer(ctx context.Context, config node.Config, nl node.List, app node.App, ip compute.InputProvider) (hsv *HeliumServer, cdescs chan<- circuits.Descriptor, outs <-chan circuits.Output, err error) { - helperNode, err := node.New(config, nl) + helperNode, err := node.New(config, nl, nil) // TODO: assumes that the helper node never has any secrets if err != nil { return nil, nil, nil, err } hsv = NewHeliumServer(helperNode) - lis, err := net.Listen("tcp", string(config.Address)) + bindAddress := string(nl.AddressOf(config.ID)) + lis, err := net.Listen("tcp", bindAddress) if err != nil { return nil, nil, nil, err } @@ -37,9 +38,9 @@ func RunHeliumServer(ctx context.Context, config node.Config, nl node.List, app return } -func RunHeliumClient(ctx context.Context, config node.Config, nl node.List, app node.App, ip compute.InputProvider) (hc *HeliumClient, outs <-chan circuits.Output, err error) { +func RunHeliumClient(ctx context.Context, config node.Config, nl node.List, secrets node.SecretProvider, app node.App, ip compute.InputProvider) (hc *HeliumClient, outs <-chan circuits.Output, err error) { - n, err := node.New(config, nl) + n, err := node.New(config, nl, secrets) if err != nil { return nil, nil, err } diff --git a/helium_test.go b/helium_test.go index c016375..7dba388 100644 --- a/helium_test.go +++ b/helium_test.go @@ -88,7 +88,7 @@ func TestSetup(t *testing.T) { helper := NewHeliumServer(lt.HelperNode) clients := make([]*HeliumClient, ts.N) for i := 0; i < ts.N; i++ { - clients[i] = NewHeliumClient(lt.PeerNodes[i], lt.HelperNode.ID(), lt.HelperNode.Address()) + clients[i] = NewHeliumClient(lt.PeerNodes[i], lt.HelperNode.ID(), "local") } lis := bufconn.Listen(buffConBufferSize) @@ -188,7 +188,7 @@ func TestCompute(t *testing.T) { helper := NewHeliumServer(lt.HelperNode) clients := make([]*HeliumClient, ts.N) for i := 0; i < ts.N; i++ { - clients[i] = NewHeliumClient(lt.PeerNodes[i], lt.HelperNode.ID(), lt.HelperNode.Address()) + clients[i] = NewHeliumClient(lt.PeerNodes[i], lt.HelperNode.ID(), "local") } lis := bufconn.Listen(buffConBufferSize) diff --git a/node/config.go b/node/config.go index ec337bb..b962b6d 100644 --- a/node/config.go +++ b/node/config.go @@ -18,7 +18,6 @@ import ( // In the current implementation, only a single session per node is supported. type Config struct { ID sessions.NodeID - Address Address HelperID sessions.NodeID SessionParameters []sessions.Parameters SetupConfig setup.ServiceConfig @@ -118,3 +117,7 @@ type TLSConfig struct { OwnPk string // Own public key as a PEM encoded string OwnSk string // Own secret key as a PEM encoded string } + +// SecretProvider is a function that returns the secrets for a session, +// given the session ID. +type SecretProvider func(sessions.ID) (*sessions.Secrets, error) diff --git a/node/localtest.go b/node/localtest.go index 7c61af4..6d398cb 100644 --- a/node/localtest.go +++ b/node/localtest.go @@ -71,12 +71,7 @@ func NewLocalTest(config LocalTestConfig) (test *LocalTest, err error) { config.SessionParams.ID = "test-session" config.SessionParams.PublicSeed = []byte{'l', 'a', 't', 't', 'i', 'g', '0'} - nodeSecrets, err := sessions.GenTestSecretKeys(*config.SessionParams) - if err != nil { - return nil, err - } - - test.SessNodeConfigs, test.HelperConfig = genNodeConfigs(config, test.List, nodeSecrets) + test.SessNodeConfigs, test.HelperConfig = genNodeConfigs(config, test.List) if config.SessionParams != nil { test.TestSession, err = sessions.NewTestSessionFromParams(*config.SessionParams, test.HelperConfig.ID) @@ -85,15 +80,22 @@ func NewLocalTest(config LocalTestConfig) (test *LocalTest, err error) { } } + secrets, err := sessions.GenTestSecretKeys(test.TestSession.SessParams) + if err != nil { + return nil, err + } + test.Nodes = make([]*Node, 1+config.PeerNodes) - test.HelperNode, err = New(test.HelperConfig, test.List) + test.HelperNode, err = New(test.HelperConfig, test.List, nil) if err != nil { return nil, err } test.Nodes[0] = test.HelperNode for i, nc := range test.SessNodeConfigs { var err error - test.Nodes[i+1], err = New(nc, test.List) + test.Nodes[i+1], err = New(nc, test.List, func(_ sessions.ID) (*sessions.Secrets, error) { + return secrets[nc.ID], nil + }) if err != nil { return nil, err } @@ -114,7 +116,7 @@ func NewLocalTest(config LocalTestConfig) (test *LocalTest, err error) { } // genNodeConfigs generates the necessary NodeConfig for each party specified in the LocalTestConfig. -func genNodeConfigs(config LocalTestConfig, nl List, secrets map[sessions.NodeID]*sessions.Secrets) (sessNodesConfig []Config, helperNodeConfig Config) { +func genNodeConfigs(config LocalTestConfig, nl List) (sessNodesConfig []Config, helperNodeConfig Config) { tlsConfigs, err := createTLSConfigs(config, nl) if err != nil { @@ -152,7 +154,6 @@ func genNodeConfigs(config LocalTestConfig, nl List, secrets map[sessions.NodeID MaxCircuitEvaluation: 1, }, } - sessNodesConfig[i].SessionParameters[0].Secrets = secrets[nid] } helperExecConfig := protocols.ExecutorConfig{ @@ -162,7 +163,6 @@ func genNodeConfigs(config LocalTestConfig, nl List, secrets map[sessions.NodeID } helperNodeConfig = Config{ ID: hid, - Address: "local", HelperID: hid, SessionParameters: []sessions.Parameters{*sp}, TLSConfig: tlsConfigs[hid], diff --git a/node/node.go b/node/node.go index a7ff6d4..0dfdbf7 100644 --- a/node/node.go +++ b/node/node.go @@ -33,7 +33,6 @@ import ( // - the peer nodes connect to the helper node and provide their protocol shares and // encrypted inputs to the compuation. Peer nodes do not need to have an address. type Node struct { - addr Address id, helperID sessions.NodeID nodeList List @@ -58,7 +57,7 @@ type Node struct { // New creates a new Helium node from the provided config and node list. // The method returns an error if the config is invalid or if the node list is empty. -func New(config Config, nodeList List) (node *Node, err error) { +func New(config Config, nodeList List, secretsProvider SecretProvider) (node *Node, err error) { node = new(Node) if err := ValidateConfig(config, nodeList); err != nil { @@ -66,7 +65,6 @@ func New(config Config, nodeList List) (node *Node, err error) { } node.id = config.ID - node.addr = config.Address node.helperID = config.HelperID node.nodeList = nodeList @@ -79,7 +77,7 @@ func New(config Config, nodeList List) (node *Node, err error) { // session node.sessStore = sessions.NewStore() for _, sp := range config.SessionParameters { - _, err = node.createNewSession(sp) + _, err = node.createNewSession(sp, secretsProvider) if err != nil { return nil, err } @@ -256,16 +254,6 @@ func (node *Node) ID() sessions.NodeID { return node.id } -// Address returns the node's address. -func (node *Node) Address() Address { - return node.addr -} - -// HasAddress returns true if the node has an address. -func (node *Node) HasAddress() bool { - return node.addr != "" -} - // IsHelperNode returns true if the node is the helper node. func (node *Node) IsHelperNode() bool { return node.id == node.helperID @@ -366,8 +354,17 @@ func (node *Node) GetDecryptor(ctx context.Context) (*rlwe.Decryptor, error) { return node.compute.GetDecryptor(ctx) } -func (node *Node) createNewSession(sessParams sessions.Parameters) (sess *sessions.Session, err error) { - sess, err = node.sessStore.NewRLWESession(sessParams, node.id) +func (node *Node) createNewSession(sessParams sessions.Parameters, secrets SecretProvider) (sess *sessions.Session, err error) { + + var sec *sessions.Secrets + if slices.Contains(sessParams.Nodes, node.id) { + sec, err = secrets(sessParams.ID) + if err != nil { + return nil, err + } + } + + sess, err = node.sessStore.NewRLWESession(node.id, sessParams, sec) if err != nil { return sess, err } diff --git a/sessions/sessions.go b/sessions/sessions.go index 8c377b2..5a62c17 100644 --- a/sessions/sessions.go +++ b/sessions/sessions.go @@ -34,6 +34,7 @@ type CiphertextID string // Session holds the session's critical state. type Session struct { Parameters + Secrets NodeID NodeID Params FHEParameters @@ -59,11 +60,11 @@ type Parameters struct { Threshold int ShamirPks map[NodeID]drlwe.ShamirPublicPoint PublicSeed []byte - *Secrets + //*Secrets } // NewSession creates a new session. -func NewSession(sessParams Parameters, nodeID NodeID) (sess *Session, err error) { +func NewSession(nodeID NodeID, sessParams Parameters, secrets *Secrets) (sess *Session, err error) { sess = new(Session) sess.NodeID = nodeID //sess.ObjectStore = objStore @@ -109,14 +110,13 @@ func NewSession(sessParams Parameters, nodeID NodeID) (sess *Session, err error) // node re-generates its secret-key material for the session if utils.NewSet(sessParams.Nodes).Contains(nodeID) { - if sessParams.Secrets == nil || len(sessParams.Secrets.PrivateSeed) == 0 { + if secrets == nil || len(secrets.PrivateSeed) == 0 { return nil, fmt.Errorf("session nodes must specify session secrets") } - sess.Secrets = new(Secrets) - sess.PrivateSeed = slices.Clone(sessParams.Secrets.PrivateSeed) + sess.PrivateSeed = slices.Clone(secrets.PrivateSeed) - sessPrng, err := sampling.NewKeyedPRNG(sessParams.Secrets.PrivateSeed) + sessPrng, err := sampling.NewKeyedPRNG(secrets.PrivateSeed) if err != nil { return nil, fmt.Errorf("could not create session PRNG: %s", err) } @@ -132,10 +132,10 @@ func NewSession(sessParams Parameters, nodeID NodeID) (sess *Session, err error) } if sessParams.Threshold < len(sessParams.Nodes) { - if sessParams.ThresholdSecretKey == nil { + if secrets.ThresholdSecretKey == nil { return nil, fmt.Errorf("session nodes must specify threshold secret key when session threshold is less than the number of nodes") } - sess.ThresholdSecretKey = &drlwe.ShamirSecretShare{Poly: *sessParams.ThresholdSecretKey.CopyNew()} // TODO: add copy method to Lattigo + sess.ThresholdSecretKey = &drlwe.ShamirSecretShare{Poly: *secrets.ThresholdSecretKey.CopyNew()} // TODO: add copy method to Lattigo } } diff --git a/sessions/sessionstore.go b/sessions/sessionstore.go index ece544b..c67a8c0 100644 --- a/sessions/sessionstore.go +++ b/sessions/sessionstore.go @@ -22,13 +22,13 @@ func NewStore() *Store { return ss } -func (s *Store) NewRLWESession(sessParams Parameters, nodeID NodeID) (sess *Session, err error) { +func (s *Store) NewRLWESession(nodeID NodeID, sessParams Parameters, secrets *Secrets) (sess *Session, err error) { if _, exists := s.sess[sessParams.ID]; exists { return nil, fmt.Errorf("session id already exists: %s", sessParams.ID) } - sess, err = NewSession(sessParams, nodeID) + sess, err = NewSession(nodeID, sessParams, secrets) if err != nil { return nil, err } diff --git a/sessions/testsession.go b/sessions/testsession.go index 514e9e3..1da969f 100644 --- a/sessions/testsession.go +++ b/sessions/testsession.go @@ -68,10 +68,9 @@ func NewTestSessionFromParams(sp Parameters, helperID NodeID) (*TestSession, err for _, nid := range sp.Nodes { spi := sp - spi.Secrets = nodeSecrets[nid] // computes the ideal secret-key for the test - ts.NodeSessions[nid], err = NewSession(spi, nid) + ts.NodeSessions[nid], err = NewSession(nid, spi, nodeSecrets[nid]) if err != nil { return nil, err } @@ -82,7 +81,7 @@ func NewTestSessionFromParams(sp Parameters, helperID NodeID) (*TestSession, err ts.RlweParams.RingQP().AtLevel(ts.SkIdeal.Value.Q.Level(), ts.SkIdeal.Value.P.Level()).Add(sk.Value, ts.SkIdeal.Value, ts.SkIdeal.Value) } - ts.HelperSession, err = NewSession(sp, helperID) + ts.HelperSession, err = NewSession(helperID, sp, nil) if err != nil { return nil, err }