From 6fcba5c6f8270b21959487afc230cf9be32fd60c Mon Sep 17 00:00:00 2001 From: Christian Mouchet Date: Fri, 19 Jul 2024 15:39:05 +0200 Subject: [PATCH] better helper persistent state loading --- circuits/circuits.go | 6 +++- circuits/operand.go | 8 +++++ node/node.go | 46 ++++++++++------------------- protocols/executor.go | 4 +-- server.go | 52 ++++++++++++++++++++++++++------- services/compute/participant.go | 6 ++-- services/compute/service.go | 2 +- 7 files changed, 77 insertions(+), 47 deletions(-) diff --git a/circuits/circuits.go b/circuits/circuits.go index 5eea04c..c72d15a 100644 --- a/circuits/circuits.go +++ b/circuits/circuits.go @@ -202,7 +202,11 @@ func (tr *TestRuntime) Input(opl OperandLabel) *FutureOperand { tr.l.Lock() defer tr.l.Unlock() fop := NewFutureOperand(opl) - ct, err := tr.encryptor.EncryptNew(tr.inputProvider(opl)) + pt := tr.inputProvider(opl) + if pt == nil { + panic(fmt.Errorf("input provider returned nil input for %s", opl)) + } + ct, err := tr.encryptor.EncryptNew(pt) if err != nil { panic(err) } diff --git a/circuits/operand.go b/circuits/operand.go index acabba6..96d1723 100644 --- a/circuits/operand.go +++ b/circuits/operand.go @@ -85,6 +85,14 @@ func (opl OperandLabel) CircuitID() sessions.CircuitID { return sessions.CircuitID(strings.Trim(path.Dir(nopl.Path), "/")) } +func (opl OperandLabel) CiphertextID() sessions.CiphertextID { + nopl, err := url.Parse(string(opl)) + if err != nil { + panic(fmt.Errorf("invalid operand label: %s", opl)) + } + return sessions.CiphertextID(path.Base(nopl.Path)) +} + // HasNode returns true if the operand label has the given host id. func (opl OperandLabel) HasNode(id sessions.NodeID) bool { nopl, err := url.Parse(string(opl)) diff --git a/node/node.go b/node/node.go index 69ee9d0..0ba3855 100644 --- a/node/node.go +++ b/node/node.go @@ -150,29 +150,19 @@ func (node *Node) Run(ctx context.Context, app App, ip compute.InputProvider, up // runs the setup phase if node.IsHelperNode() { - go func() { - - // loads the setup state from persistent storage and rebuilds a ("fake") event log - // TODO: proper log storing and loading ? - setupSigs := setup.DescriptionToSignatureList(*app.SetupDescription) - sigList := make([]protocols.Signature, 0, len(setupSigs)) - for _, sig := range setupSigs { - protoCompleted, err := node.setup.GetCompletedDescriptor(ctx, sig) - // TODO: error checking against a keynotfoud type of error to distinguish real failure cases from the absence of a completed descriptor - if protoCompleted == nil || err != nil { - sigList = append(sigList, sig) - } else { - pd := *protoCompleted - if sig.Type == protocols.RKG { - rkg1Desc := *protoCompleted - rkg1Desc.Type = protocols.RKG1 - sc.setupCoordinator.outgoing <- setup.Event{Event: protocols.Event{EventType: protocols.Started, Descriptor: rkg1Desc}} - sc.setupCoordinator.outgoing <- setup.Event{Event: protocols.Event{EventType: protocols.Completed, Descriptor: rkg1Desc}} - } - sc.setupCoordinator.outgoing <- setup.Event{Event: protocols.Event{EventType: protocols.Started, Descriptor: pd}} - sc.setupCoordinator.outgoing <- setup.Event{Event: protocols.Event{EventType: protocols.Completed, Descriptor: pd}} - } + // loads the setup state from persistent storage and rebuilds a ("fake") event log + // TODO: proper log storing and loading ? + setupSigs := setup.DescriptionToSignatureList(*app.SetupDescription) + sigList := make([]protocols.Signature, 0, len(setupSigs)) + for _, sig := range setupSigs { + protoCompleted, err := node.setup.GetCompletedDescriptor(ctx, sig) + // TODO: error checking against a keynotfoud type of error to distinguish real failure cases from the absence of a completed descriptor + if protoCompleted == nil || err != nil { + sigList = append(sigList, sig) } + } + + go func() { node.Logf("running setup phase: %d signatures to run", len(sigList)) for _, sig := range sigList { @@ -217,6 +207,10 @@ func (node *Node) Run(ctx context.Context, app App, ip compute.InputProvider, up return cds, or, nil } +func (node *Node) GetCompletedSetupDescriptor(ctx context.Context, sig protocols.Signature) (*protocols.Descriptor, error) { + return node.setup.GetCompletedDescriptor(ctx, sig) +} + // Transport interface implementation // GetAggregationOutput returns the aggregation output for a given protocol descriptor. @@ -301,14 +295,6 @@ func (node *Node) Logf(msg string, v ...any) { log.Printf("%s | [node] %s\n", node.id, fmt.Sprintf(msg, v...)) } -// func (node *Node) RegisterPostsetupHandler(h func(*pkg.SessionStore, compute.PublicKeyBackend) error) { -// node.postsetupHandler = h -// } - -// func (node *Node) RegisterPrecomputeHandler(h func(*pkg.SessionStore, compute.PublicKeyBackend) error) { -// node.precomputeHandler = h -// } - // pkg.PublicKeyBackend interface implementation // GetCollectivePublicKey returns the collective public key. diff --git a/protocols/executor.go b/protocols/executor.go index 5712222..17fe20c 100644 --- a/protocols/executor.go +++ b/protocols/executor.go @@ -259,7 +259,7 @@ func (s *Executor) Run(ctx context.Context, trans Transport) error { // TODO: ca select { case qpd, more := <-s.queuedPart: if !more { - s.Logf("closed participation queue") + //s.Logf("closed participation queue") return nil } if err := s.runAsParticipant(qpd.ctx, qpd.pd); err != nil { @@ -586,7 +586,7 @@ func (s *Executor) Register(peer sessions.NodeID) error { s.connectedNodesMu.Unlock() s.connectedNodesCond.Broadcast() - s.Logf("registered peer %v, %d online nodes", peer, len(s.connectedNodes)) + //s.Logf("registered peer %v, %d online nodes", peer, len(s.connectedNodes)) return nil // TODO: Implement } diff --git a/server.go b/server.go index c321cc6..d0bb270 100644 --- a/server.go +++ b/server.go @@ -15,6 +15,7 @@ import ( "github.com/ChristianMct/helium/node" "github.com/ChristianMct/helium/protocols" "github.com/ChristianMct/helium/services/compute" + "github.com/ChristianMct/helium/services/setup" "github.com/ChristianMct/helium/sessions" "google.golang.org/grpc" "google.golang.org/grpc/codes" @@ -101,8 +102,13 @@ type nodeCoordinator struct { } func (nc *nodeCoordinator) Register(ctx context.Context) (evChan *coordinator.Channel[node.Event], present int, err error) { - outgoing := make(chan node.Event) + incoming := make(chan node.Event, len(nc.events)) + for _, ev := range nc.events { + incoming <- ev + } + + outgoing := make(chan node.Event) go func() { for ev := range outgoing { ev := ev @@ -111,7 +117,7 @@ func (nc *nodeCoordinator) Register(ctx context.Context) (evChan *coordinator.Ch nc.CloseEventLog() }() - return &coordinator.Channel[node.Event]{Outgoing: outgoing}, 0, nil + return &coordinator.Channel[node.Event]{Outgoing: outgoing, Incoming: incoming}, len(nc.events), nil } type nodeTransport struct { @@ -139,19 +145,45 @@ func (nt *nodeTransport) PutCiphertext(ctx context.Context, ct sessions.Cipherte } func (hsv *HeliumServer) Run(ctx context.Context, app node.App, ip compute.InputProvider) (cdescs chan<- circuits.Descriptor, outs <-chan circuits.Output, err error) { + + // populates a pseudo log with completed protocols + // TODO: proper log storing and loading + setupSigs := setup.DescriptionToSignatureList(*app.SetupDescription) + for _, sig := range setupSigs { + protoCompleted, err := hsv.helperNode.GetCompletedSetupDescriptor(ctx, sig) + // TODO: error checking against a keynotfoud type of error to distinguish real failure cases from the absence of a completed descriptor + if protoCompleted != nil && err == nil { + pd := *protoCompleted + if sig.Type == protocols.RKG { + rkg1Desc := *protoCompleted + rkg1Desc.Type = protocols.RKG1 + hsv.AppendEventToLog( + node.Event{SetupEvent: &setup.Event{Event: protocols.Event{EventType: protocols.Started, Descriptor: rkg1Desc}}}, + node.Event{SetupEvent: &setup.Event{Event: protocols.Event{EventType: protocols.Completed, Descriptor: rkg1Desc}}}, + ) + } + hsv.AppendEventToLog( + node.Event{SetupEvent: &setup.Event{Event: protocols.Event{EventType: protocols.Started, Descriptor: pd}}}, + node.Event{SetupEvent: &setup.Event{Event: protocols.Event{EventType: protocols.Completed, Descriptor: pd}}}, + ) + } + } + return hsv.helperNode.Run(ctx, app, ip, &nodeCoordinator{hsv}, &nodeTransport{s: hsv}) } // AppendEventToLog is called by the server side to append a new event to the log and send it to all connected peers. -func (hsv *HeliumServer) AppendEventToLog(event node.Event) { +func (hsv *HeliumServer) AppendEventToLog(events ...node.Event) { hsv.mu.Lock() - hsv.events = append(hsv.events, event) - for nodeID, node := range hsv.nodes { - if node.sendQueue != nil { - select { - case node.sendQueue <- event: - default: - panic(fmt.Errorf("node %s has full send queue", nodeID)) // TODO: handle this by closing stream instead + hsv.events = append(hsv.events, events...) + for _, event := range events { + for nodeID, node := range hsv.nodes { + if node.sendQueue != nil { + select { + case node.sendQueue <- event: + default: + panic(fmt.Errorf("node %s has full send queue", nodeID)) // TODO: handle this by closing stream instead + } } } } diff --git a/services/compute/participant.go b/services/compute/participant.go index 3ec3733..d583695 100644 --- a/services/compute/participant.go +++ b/services/compute/participant.go @@ -172,10 +172,10 @@ func (p *participantRuntime) Input(opl circuits.OperandLabel) *circuits.FutureOp switch enc := p.Encoder.(type) { case *bgv.Encoder: inpt = bgv.NewPlaintext(p.sess.Params.(bgv.Parameters), p.sess.Params.GetRLWEParameters().MaxLevel()) - err = enc.Encode(in, inpt) + err = enc.ShallowCopy().Encode(in, inpt) case *ckks.Encoder: inpt = ckks.NewPlaintext(p.sess.Params.(ckks.Parameters), p.sess.Params.GetRLWEParameters().MaxLevel()) - err = p.Encoder.(*ckks.Encoder).Encode(in, inpt) + err = enc.ShallowCopy().Encode(in, inpt) } if err != nil { panic(fmt.Errorf("cannot encode input: %w", err)) @@ -184,7 +184,7 @@ func (p *participantRuntime) Input(opl circuits.OperandLabel) *circuits.FutureOp fallthrough case isRLWEPLaintext(in): inpt := in.(*rlwe.Plaintext) - inct, err := p.EncryptNew(inpt) + inct, err := p.Encryptor.ShallowCopy().EncryptNew(inpt) if err != nil { panic(err) } diff --git a/services/compute/service.go b/services/compute/service.go index 8b65129..84c1fbd 100644 --- a/services/compute/service.go +++ b/services/compute/service.go @@ -84,7 +84,7 @@ type InputProvider func(context.Context, sessions.CircuitID, circuits.OperandLab // NoInput is an input provider that returns nil for all inputs. var NoInput InputProvider = func(_ context.Context, _ sessions.CircuitID, _ circuits.OperandLabel, _ sessions.Session) (any, error) { - return nil, nil + return nil, fmt.Errorf("node has no input") } // OutputReceiver is a type for receiving outputs from a circuit.