Skip to content

Commit

Permalink
better helper persistent state loading
Browse files Browse the repository at this point in the history
  • Loading branch information
ChristianMct committed Jul 19, 2024
1 parent 3b6635e commit 6fcba5c
Show file tree
Hide file tree
Showing 7 changed files with 77 additions and 47 deletions.
6 changes: 5 additions & 1 deletion circuits/circuits.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,11 @@ func (tr *TestRuntime) Input(opl OperandLabel) *FutureOperand {
tr.l.Lock()
defer tr.l.Unlock()
fop := NewFutureOperand(opl)
ct, err := tr.encryptor.EncryptNew(tr.inputProvider(opl))
pt := tr.inputProvider(opl)
if pt == nil {
panic(fmt.Errorf("input provider returned nil input for %s", opl))
}
ct, err := tr.encryptor.EncryptNew(pt)
if err != nil {
panic(err)
}
Expand Down
8 changes: 8 additions & 0 deletions circuits/operand.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,14 @@ func (opl OperandLabel) CircuitID() sessions.CircuitID {
return sessions.CircuitID(strings.Trim(path.Dir(nopl.Path), "/"))
}

func (opl OperandLabel) CiphertextID() sessions.CiphertextID {
nopl, err := url.Parse(string(opl))
if err != nil {
panic(fmt.Errorf("invalid operand label: %s", opl))
}
return sessions.CiphertextID(path.Base(nopl.Path))
}

// HasNode returns true if the operand label has the given host id.
func (opl OperandLabel) HasNode(id sessions.NodeID) bool {
nopl, err := url.Parse(string(opl))
Expand Down
46 changes: 16 additions & 30 deletions node/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,29 +150,19 @@ func (node *Node) Run(ctx context.Context, app App, ip compute.InputProvider, up

// runs the setup phase
if node.IsHelperNode() {
go func() {

// loads the setup state from persistent storage and rebuilds a ("fake") event log
// TODO: proper log storing and loading ?
setupSigs := setup.DescriptionToSignatureList(*app.SetupDescription)
sigList := make([]protocols.Signature, 0, len(setupSigs))
for _, sig := range setupSigs {
protoCompleted, err := node.setup.GetCompletedDescriptor(ctx, sig)
// TODO: error checking against a keynotfoud type of error to distinguish real failure cases from the absence of a completed descriptor
if protoCompleted == nil || err != nil {
sigList = append(sigList, sig)
} else {
pd := *protoCompleted
if sig.Type == protocols.RKG {
rkg1Desc := *protoCompleted
rkg1Desc.Type = protocols.RKG1
sc.setupCoordinator.outgoing <- setup.Event{Event: protocols.Event{EventType: protocols.Started, Descriptor: rkg1Desc}}
sc.setupCoordinator.outgoing <- setup.Event{Event: protocols.Event{EventType: protocols.Completed, Descriptor: rkg1Desc}}
}
sc.setupCoordinator.outgoing <- setup.Event{Event: protocols.Event{EventType: protocols.Started, Descriptor: pd}}
sc.setupCoordinator.outgoing <- setup.Event{Event: protocols.Event{EventType: protocols.Completed, Descriptor: pd}}
}
// loads the setup state from persistent storage and rebuilds a ("fake") event log
// TODO: proper log storing and loading ?
setupSigs := setup.DescriptionToSignatureList(*app.SetupDescription)
sigList := make([]protocols.Signature, 0, len(setupSigs))
for _, sig := range setupSigs {
protoCompleted, err := node.setup.GetCompletedDescriptor(ctx, sig)
// TODO: error checking against a keynotfoud type of error to distinguish real failure cases from the absence of a completed descriptor
if protoCompleted == nil || err != nil {
sigList = append(sigList, sig)
}
}

go func() {

node.Logf("running setup phase: %d signatures to run", len(sigList))
for _, sig := range sigList {
Expand Down Expand Up @@ -217,6 +207,10 @@ func (node *Node) Run(ctx context.Context, app App, ip compute.InputProvider, up
return cds, or, nil
}

func (node *Node) GetCompletedSetupDescriptor(ctx context.Context, sig protocols.Signature) (*protocols.Descriptor, error) {
return node.setup.GetCompletedDescriptor(ctx, sig)
}

// Transport interface implementation

// GetAggregationOutput returns the aggregation output for a given protocol descriptor.
Expand Down Expand Up @@ -301,14 +295,6 @@ func (node *Node) Logf(msg string, v ...any) {
log.Printf("%s | [node] %s\n", node.id, fmt.Sprintf(msg, v...))
}

// func (node *Node) RegisterPostsetupHandler(h func(*pkg.SessionStore, compute.PublicKeyBackend) error) {
// node.postsetupHandler = h
// }

// func (node *Node) RegisterPrecomputeHandler(h func(*pkg.SessionStore, compute.PublicKeyBackend) error) {
// node.precomputeHandler = h
// }

// pkg.PublicKeyBackend interface implementation

// GetCollectivePublicKey returns the collective public key.
Expand Down
4 changes: 2 additions & 2 deletions protocols/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ func (s *Executor) Run(ctx context.Context, trans Transport) error { // TODO: ca
select {
case qpd, more := <-s.queuedPart:
if !more {
s.Logf("closed participation queue")
//s.Logf("closed participation queue")
return nil
}
if err := s.runAsParticipant(qpd.ctx, qpd.pd); err != nil {
Expand Down Expand Up @@ -586,7 +586,7 @@ func (s *Executor) Register(peer sessions.NodeID) error {
s.connectedNodesMu.Unlock()
s.connectedNodesCond.Broadcast()

s.Logf("registered peer %v, %d online nodes", peer, len(s.connectedNodes))
//s.Logf("registered peer %v, %d online nodes", peer, len(s.connectedNodes))
return nil // TODO: Implement
}

Expand Down
52 changes: 42 additions & 10 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/ChristianMct/helium/node"
"github.com/ChristianMct/helium/protocols"
"github.com/ChristianMct/helium/services/compute"
"github.com/ChristianMct/helium/services/setup"
"github.com/ChristianMct/helium/sessions"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
Expand Down Expand Up @@ -101,8 +102,13 @@ type nodeCoordinator struct {
}

func (nc *nodeCoordinator) Register(ctx context.Context) (evChan *coordinator.Channel[node.Event], present int, err error) {
outgoing := make(chan node.Event)

incoming := make(chan node.Event, len(nc.events))
for _, ev := range nc.events {
incoming <- ev
}

outgoing := make(chan node.Event)
go func() {
for ev := range outgoing {
ev := ev
Expand All @@ -111,7 +117,7 @@ func (nc *nodeCoordinator) Register(ctx context.Context) (evChan *coordinator.Ch
nc.CloseEventLog()
}()

return &coordinator.Channel[node.Event]{Outgoing: outgoing}, 0, nil
return &coordinator.Channel[node.Event]{Outgoing: outgoing, Incoming: incoming}, len(nc.events), nil
}

type nodeTransport struct {
Expand Down Expand Up @@ -139,19 +145,45 @@ func (nt *nodeTransport) PutCiphertext(ctx context.Context, ct sessions.Cipherte
}

func (hsv *HeliumServer) Run(ctx context.Context, app node.App, ip compute.InputProvider) (cdescs chan<- circuits.Descriptor, outs <-chan circuits.Output, err error) {

// populates a pseudo log with completed protocols
// TODO: proper log storing and loading
setupSigs := setup.DescriptionToSignatureList(*app.SetupDescription)
for _, sig := range setupSigs {
protoCompleted, err := hsv.helperNode.GetCompletedSetupDescriptor(ctx, sig)
// TODO: error checking against a keynotfoud type of error to distinguish real failure cases from the absence of a completed descriptor
if protoCompleted != nil && err == nil {
pd := *protoCompleted
if sig.Type == protocols.RKG {
rkg1Desc := *protoCompleted
rkg1Desc.Type = protocols.RKG1
hsv.AppendEventToLog(
node.Event{SetupEvent: &setup.Event{Event: protocols.Event{EventType: protocols.Started, Descriptor: rkg1Desc}}},
node.Event{SetupEvent: &setup.Event{Event: protocols.Event{EventType: protocols.Completed, Descriptor: rkg1Desc}}},
)
}
hsv.AppendEventToLog(
node.Event{SetupEvent: &setup.Event{Event: protocols.Event{EventType: protocols.Started, Descriptor: pd}}},
node.Event{SetupEvent: &setup.Event{Event: protocols.Event{EventType: protocols.Completed, Descriptor: pd}}},
)
}
}

return hsv.helperNode.Run(ctx, app, ip, &nodeCoordinator{hsv}, &nodeTransport{s: hsv})
}

// AppendEventToLog is called by the server side to append a new event to the log and send it to all connected peers.
func (hsv *HeliumServer) AppendEventToLog(event node.Event) {
func (hsv *HeliumServer) AppendEventToLog(events ...node.Event) {
hsv.mu.Lock()
hsv.events = append(hsv.events, event)
for nodeID, node := range hsv.nodes {
if node.sendQueue != nil {
select {
case node.sendQueue <- event:
default:
panic(fmt.Errorf("node %s has full send queue", nodeID)) // TODO: handle this by closing stream instead
hsv.events = append(hsv.events, events...)
for _, event := range events {
for nodeID, node := range hsv.nodes {
if node.sendQueue != nil {
select {
case node.sendQueue <- event:
default:
panic(fmt.Errorf("node %s has full send queue", nodeID)) // TODO: handle this by closing stream instead
}
}
}
}
Expand Down
6 changes: 3 additions & 3 deletions services/compute/participant.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,10 +172,10 @@ func (p *participantRuntime) Input(opl circuits.OperandLabel) *circuits.FutureOp
switch enc := p.Encoder.(type) {
case *bgv.Encoder:
inpt = bgv.NewPlaintext(p.sess.Params.(bgv.Parameters), p.sess.Params.GetRLWEParameters().MaxLevel())
err = enc.Encode(in, inpt)
err = enc.ShallowCopy().Encode(in, inpt)
case *ckks.Encoder:
inpt = ckks.NewPlaintext(p.sess.Params.(ckks.Parameters), p.sess.Params.GetRLWEParameters().MaxLevel())
err = p.Encoder.(*ckks.Encoder).Encode(in, inpt)
err = enc.ShallowCopy().Encode(in, inpt)
}
if err != nil {
panic(fmt.Errorf("cannot encode input: %w", err))
Expand All @@ -184,7 +184,7 @@ func (p *participantRuntime) Input(opl circuits.OperandLabel) *circuits.FutureOp
fallthrough
case isRLWEPLaintext(in):
inpt := in.(*rlwe.Plaintext)
inct, err := p.EncryptNew(inpt)
inct, err := p.Encryptor.ShallowCopy().EncryptNew(inpt)
if err != nil {
panic(err)
}
Expand Down
2 changes: 1 addition & 1 deletion services/compute/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ type InputProvider func(context.Context, sessions.CircuitID, circuits.OperandLab

// NoInput is an input provider that returns nil for all inputs.
var NoInput InputProvider = func(_ context.Context, _ sessions.CircuitID, _ circuits.OperandLabel, _ sessions.Session) (any, error) {
return nil, nil
return nil, fmt.Errorf("node has no input")
}

// OutputReceiver is a type for receiving outputs from a circuit.
Expand Down

0 comments on commit 6fcba5c

Please sign in to comment.