diff --git a/protocols/executor.go b/protocols/executor.go index 2c8e0a9..4eac4d3 100644 --- a/protocols/executor.go +++ b/protocols/executor.go @@ -358,13 +358,13 @@ func (s *Executor) runAsAggregator(ctx context.Context, sess *sessions.Session, clearProtocol := func() { s.connectedNodesMu.Lock() + s.runningProtoMu.Lock() for _, part := range pd.Participants { s.connectedNodes[part].Remove(pd.ID()) } s.connectedNodesMu.Unlock() s.connectedNodesCond.Broadcast() - s.runningProtoMu.Lock() delete(s.runningProtos, pid) s.runningProtoMu.Unlock() } @@ -378,8 +378,8 @@ func (s *Executor) runAsAggregator(ctx context.Context, sess *sessions.Session, input, err := s.inputProvider(ctx, pd) if err != nil { - cancelAgg() clearProtocol() + cancelAgg() aggOut.Error = fmt.Errorf("cannot get input for protocol: %w", err) return } @@ -586,11 +586,21 @@ func (s *Executor) Register(peer sessions.NodeID) error { func (s *Executor) Unregister(peer sessions.NodeID) error { s.connectedNodesMu.Lock() - _, has := s.connectedNodes[peer] + pids, has := s.connectedNodes[peer] if !has { panic("unregistering an unregistered node") } - s.DisconnectedNode(peer) + + s.runningProtoMu.RLock() + for pid := range pids { + p, has := s.runningProtos[pid] + if !has { + panic("incoherent state: protocol not running but node is registered for it") + } + p.disconnected <- peer + } + s.runningProtoMu.RUnlock() + delete(s.connectedNodes, peer) s.connectedNodesMu.Unlock() @@ -646,15 +656,6 @@ func (s *Executor) getProtocolDescriptor(sig Signature, sess *sessions.Session) return pd } -func (s *Executor) DisconnectedNode(id sessions.NodeID) { - s.runningProtoMu.RLock() - protoIds := s.connectedNodes[id] - for pid := range protoIds { - s.runningProtos[pid].disconnected <- id - } - s.runningProtoMu.RUnlock() -} - func (s *Executor) Logf(msg string, v ...any) { log.Printf("%s | [executor] %s\n", s.self, fmt.Sprintf(msg, v...)) }