From 332509d84662adf0e1adc98c903aef51222152d8 Mon Sep 17 00:00:00 2001 From: aminst Date: Tue, 28 Nov 2023 17:54:41 -0500 Subject: [PATCH] Fix shardnode raft structure --- pkg/shardnode/raft.go | 106 ++++++++--------------------------- pkg/shardnode/raft_test.go | 50 ++++------------- pkg/shardnode/server.go | 18 +++++- pkg/shardnode/server_test.go | 15 ----- 4 files changed, 49 insertions(+), 140 deletions(-) diff --git a/pkg/shardnode/raft.go b/pkg/shardnode/raft.go index 0ffd0c3..e0aef13 100644 --- a/pkg/shardnode/raft.go +++ b/pkg/shardnode/raft.go @@ -41,25 +41,17 @@ func (p positionState) isPathInPaths(paths []int) bool { } type shardNodeFSM struct { - requestLog map[string][]string // map of block to requesting requestIDs - requestLogMu sync.Mutex - pathMap map[string]int // map of requestID to new path - pathMapMu sync.Mutex - storageIDMap map[string]int // map of requestID to new storageID - storageIDMapMu sync.Mutex - responseMap map[string]string // map of requestID to response map[string]string - responseMapMu sync.Mutex + requestLog map[string][]string // map of block to requesting requestIDs + pathMap map[string]int // map of requestID to new path + storageIDMap map[string]int // map of requestID to new storageID stash map[string]stashState // map of block to stashState stashMu sync.Mutex - responseChannel sync.Map // map of requestId to their channel for receiving response map[string] chan string - acks map[string][]string // map of requestID to array of blocks - acksMu sync.Mutex - nacks map[string][]string // map of requestID to array of blocks - nacksMu sync.Mutex + responseChannel sync.Map // map of requestId to their channel for receiving response map[string] chan string + acks map[string][]string // map of requestID to array of blocks + nacks map[string][]string // map of requestID to array of blocks positionMap map[string]positionState // map of block to positionState positionMapMu sync.RWMutex raftNode RaftNodeWIthState - raftNodeMu 
sync.Mutex } func newShardNodeFSM() *shardNodeFSM { @@ -67,7 +59,6 @@ func newShardNodeFSM() *shardNodeFSM { requestLog: make(map[string][]string), pathMap: make(map[string]int), storageIDMap: make(map[string]int), - responseMap: make(map[string]string), stash: make(map[string]stashState), responseChannel: sync.Map{}, acks: make(map[string][]string), @@ -81,7 +72,6 @@ func (fsm *shardNodeFSM) String() string { out = out + fmt.Sprintf("requestLog: %v\n", fsm.requestLog) out = out + fmt.Sprintf("pathMap: %v\n", fsm.pathMap) out = out + fmt.Sprintf("storageIDMap: %v\n", fsm.storageIDMap) - out = out + fmt.Sprintf("responseMap: %v\n", fsm.responseMap) out = out + fmt.Sprintf("stash: %v\n", fsm.stash) out = out + fmt.Sprintf("responseChannel: %v\n", fsm.responseChannel) out = out + fmt.Sprintf("acks: %v\n", fsm.acks) @@ -91,19 +81,6 @@ func (fsm *shardNodeFSM) String() string { } func (fsm *shardNodeFSM) handleReplicateRequestAndPathAndStorage(requestID string, r ReplicateRequestAndPathAndStoragePayload) (isFirst bool) { - log.Debug().Msgf("Aquiring lock for shardNodeFSM in handleReplicateRequestAndPathAndStorage") - fsm.requestLogMu.Lock() - fsm.pathMapMu.Lock() - fsm.storageIDMapMu.Lock() - log.Debug().Msgf("Aquired lock for shardNodeFSM in handleReplicateRequestAndPathAndStorage") - defer func() { - log.Debug().Msgf("Releasing lock for shardNodeFSM in handleReplicateRequestAndPathAndStorage") - fsm.requestLogMu.Unlock() - fsm.pathMapMu.Unlock() - fsm.storageIDMapMu.Unlock() - log.Debug().Msgf("Released lock for shardNodeFSM in handleReplicateRequestAndPathAndStorage") - }() - fsm.requestLog[r.RequestedBlock] = append(fsm.requestLog[r.RequestedBlock], requestID) fsm.pathMap[requestID] = r.Path fsm.storageIDMap[requestID] = r.StorageID @@ -115,32 +92,20 @@ func (fsm *shardNodeFSM) handleReplicateRequestAndPathAndStorage(requestID strin return isFirst } -type localReplicaChangeHandlerFunc func(requestID string, r ReplicateResponsePayload) - -// It handles the response 
replication changes locally on each raft replica. -// The leader doesn't wait for this to finish to return success for the response replication command. -func (fsm *shardNodeFSM) handleLocalResponseReplicationChanges(requestID string, r ReplicateResponsePayload) { - log.Debug().Msgf("Aquiring lock for shardNodeFSM in handleLocalResponseReplicationChanges") +func (fsm *shardNodeFSM) handleReplicateResponse(requestID string, r ReplicateResponsePayload) string { + log.Debug().Msgf("Aquiring lock for shardNodeFSM in handleReplicateResponse") + start := time.Now() fsm.stashMu.Lock() - fsm.responseMapMu.Lock() fsm.positionMapMu.Lock() - fsm.pathMapMu.Lock() - fsm.storageIDMapMu.Lock() - fsm.requestLogMu.Lock() - fsm.raftNodeMu.Lock() - log.Debug().Msgf("Aquired lock for shardNodeFSM in handleLocalResponseReplicationChanges") + end := time.Now() + log.Debug().Msgf("Aquired lock for shardNodeFSM in handleReplicateResponse in %v", end.Sub(start)) + // log.Debug().Msgf("Aquired lock for shardNodeFSM in handleReplicateResponse") defer func() { log.Debug().Msgf("Releasing lock for shardNodeFSM in handleReplicateResponse") fsm.stashMu.Unlock() - fsm.responseMapMu.Unlock() fsm.positionMapMu.Unlock() - fsm.pathMapMu.Unlock() - fsm.storageIDMapMu.Unlock() - fsm.requestLogMu.Unlock() - fsm.raftNodeMu.Unlock() log.Debug().Msgf("Released lock for shardNodeFSM in handleReplicateResponse") }() - stashState, exists := fsm.stash[r.RequestedBlock] if exists { if r.OpType == Write { @@ -149,7 +114,7 @@ func (fsm *shardNodeFSM) handleLocalResponseReplicationChanges(requestID string, fsm.stash[r.RequestedBlock] = stashState } } else { - response := fsm.responseMap[requestID] + response := r.Response stashState := fsm.stash[r.RequestedBlock] if r.OpType == Read { stashState.value = response @@ -161,37 +126,27 @@ func (fsm *shardNodeFSM) handleLocalResponseReplicationChanges(requestID string, } if fsm.raftNode.State() == raft.Leader { fsm.positionMap[r.RequestedBlock] = positionState{path: 
fsm.pathMap[requestID], storageID: fsm.storageIDMap[requestID]} - for i := len(fsm.requestLog[r.RequestedBlock]) - 1; i >= 0; i-- { + for i := len(fsm.requestLog[r.RequestedBlock]) - 1; i >= 1; i-- { // We don't need to send the response to the first request + log.Debug().Msgf("Sending response to concurrent request number %d in requestLog for block %s", i, r.RequestedBlock) timeout := time.After(5 * time.Second) // TODO: think about this in the batching scenario responseChan, _ := fsm.responseChannel.Load(fsm.requestLog[r.RequestedBlock][i]) select { case <-timeout: - log.Error().Msgf("timeout in sending response to concurrent requests") + log.Error().Msgf("timeout in sending response to concurrent request number %d in requestLog for block %s", i, r.RequestedBlock) continue case responseChan.(chan string) <- fsm.stash[r.RequestedBlock].value: + log.Debug().Msgf("sent response to concurrent request number %d in requestLog for block %s", i, r.RequestedBlock) delete(fsm.pathMap, fsm.requestLog[r.RequestedBlock][i]) delete(fsm.storageIDMap, fsm.requestLog[r.RequestedBlock][i]) fsm.responseChannel.Delete(fsm.requestLog[r.RequestedBlock][i]) - fsm.requestLog[r.RequestedBlock] = append(fsm.requestLog[r.RequestedBlock][:i], fsm.requestLog[r.RequestedBlock][i+1:]...) 
} } } - delete(fsm.responseMap, requestID) -} - -func (fsm *shardNodeFSM) handleReplicateResponse(requestID string, r ReplicateResponsePayload, f localReplicaChangeHandlerFunc) { - log.Debug().Msgf("Aquiring lock for shardNodeFSM in handleReplicateResponse") - fsm.responseMapMu.Lock() - log.Debug().Msgf("Aquired lock for shardNodeFSM in handleReplicateResponse") - - defer func() { - log.Debug().Msgf("Releasing lock for shardNodeFSM in handleReplicateResponse") - fsm.responseMapMu.Unlock() - log.Debug().Msgf("Released lock for shardNodeFSM in handleReplicateResponse") - }() - - fsm.responseMap[requestID] = r.Response - go f(requestID, r) + delete(fsm.pathMap, requestID) + delete(fsm.storageIDMap, requestID) + fsm.responseChannel.Delete(requestID) + delete(fsm.requestLog, r.RequestedBlock) + return fsm.stash[r.RequestedBlock].value } func (fsm *shardNodeFSM) handleReplicateSentBlocks(r ReplicateSentBlocksPayload) { @@ -216,14 +171,10 @@ func (fsm *shardNodeFSM) handleReplicateSentBlocks(r ReplicateSentBlocksPayload) // If an acked block was changed during the eviction, it will keep it. 
func (fsm *shardNodeFSM) handleLocalAcksNacksReplicationChanges(requestID string) { log.Debug().Msgf("Aquiring lock for shardNodeFSM in handleLocalAcksNacksReplicationChanges") - fsm.acksMu.Lock() - fsm.nacksMu.Lock() fsm.stashMu.Lock() log.Debug().Msgf("Aquired lock for shardNodeFSM in handleLocalAcksNacksReplicationChanges") defer func() { log.Debug().Msgf("Releasing lock for shardNodeFSM in handleLocalAcksNacksReplicationChanges") - fsm.acksMu.Unlock() - fsm.nacksMu.Unlock() fsm.stashMu.Unlock() log.Debug().Msgf("Released lock for shardNodeFSM in handleLocalAcksNacksReplicationChanges") }() @@ -244,17 +195,6 @@ func (fsm *shardNodeFSM) handleLocalAcksNacksReplicationChanges(requestID string } func (fsm *shardNodeFSM) handleReplicateAcksNacks(r ReplicateAcksNacksPayload) { - log.Debug().Msgf("Aquiring lock for shardNodeFSM in handleReplicateAcksNacks") - fsm.acksMu.Lock() - fsm.nacksMu.Lock() - log.Debug().Msgf("Aquired lock for shardNodeFSM in handleReplicateAcksNacks") - defer func() { - log.Debug().Msgf("Releasing lock for shardNodeFSM in handleReplicateAcksNacks") - fsm.acksMu.Unlock() - fsm.nacksMu.Unlock() - log.Debug().Msgf("Released lock for shardNodeFSM in handleReplicateAcksNacks") - }() - requestID := uuid.New().String() fsm.acks[requestID] = r.AckedBlocks fsm.nacks[requestID] = r.NackedBlocks @@ -286,7 +226,7 @@ func (fsm *shardNodeFSM) Apply(rLog *raft.Log) interface{} { if err != nil { return fmt.Errorf("could not unmarshall the response replication command; %s", err) } - fsm.handleReplicateResponse(requestID, responseReplicationPayload, fsm.handleLocalResponseReplicationChanges) + return fsm.handleReplicateResponse(requestID, responseReplicationPayload) } else if command.Type == ReplicateSentBlocksCommand { log.Debug().Msgf("got replication command for replicate sent blocks") var replicateSentBlocksPayload ReplicateSentBlocksPayload diff --git a/pkg/shardnode/raft_test.go b/pkg/shardnode/raft_test.go index a435469..17cb590 100644 --- 
a/pkg/shardnode/raft_test.go +++ b/pkg/shardnode/raft_test.go @@ -60,28 +60,6 @@ func createTestReplicateResponsePayload(block string, response string, value str } } -// This test only checks the functionality of the handleReplicateResponse method; -// It doesn't run the local replica handler go routine so it doesn't affect the test results. -// This is achieved by giving an anonymous empty function to the handleReplicateResponse method. -func TestHandleReplicateResponseWithoutPreviousValue(t *testing.T) { - shardNodeFSM := newShardNodeFSM() - payload := createTestReplicateResponsePayload("block", "response", "value", Read) - shardNodeFSM.handleReplicateResponse("request2", payload, func(requestID string, r ReplicateResponsePayload) {}) - if shardNodeFSM.responseMap["request2"] != "response" { - t.Errorf("Expected response to be \"response\" for request1, but the value is %s", shardNodeFSM.responseMap["request2"]) - } -} - -func TestHandleReplicateResponseOverridingPreviousValue(t *testing.T) { - shardNodeFSM := newShardNodeFSM() - shardNodeFSM.responseMap["request2"] = "prev" - payload := createTestReplicateResponsePayload("block", "response", "value", Read) - shardNodeFSM.handleReplicateResponse("request2", payload, func(requestID string, r ReplicateResponsePayload) {}) - if shardNodeFSM.responseMap["request2"] != "response" { - t.Errorf("Expected response to be \"response\" for request1, but the value is %s", shardNodeFSM.responseMap["request2"]) - } -} - type responseMessage struct { requestID string response string @@ -139,32 +117,30 @@ func (m *mockRaftNodeFollower) State() raft.RaftState { // In this case all the go routines should get the value that resides in stash. // The stash value has priority over the response value. 
-func TestHandleLocalReplicaChangesWhenValueInStashReturnsCorrectReadValueToAllWaitingRequests(t *testing.T) { +func TestHandleReplicateResponseWhenValueInStashReturnsCorrectReadValueToAllWaitingRequests(t *testing.T) { shardNodeFSM := newShardNodeFSM() shardNodeFSM.raftNode = &mockRaftNodeLeader{} shardNodeFSM.requestLog["block"] = []string{"request1", "request2", "request3"} - shardNodeFSM.responseChannel.Store("request1", make(chan string)) shardNodeFSM.responseChannel.Store("request2", make(chan string)) shardNodeFSM.responseChannel.Store("request3", make(chan string)) shardNodeFSM.stash["block"] = stashState{value: "test_value"} payload := createTestReplicateResponsePayload("block", "response", "value", Read) - go shardNodeFSM.handleLocalResponseReplicationChanges("request1", payload) + go shardNodeFSM.handleReplicateResponse("request1", payload) checkWaitingChannelsHelper(t, shardNodeFSM.responseChannel, "test_value") } -func TestHandleLocalReplicaChangesWhenValueInStashReturnsCorrectWriteValueToAllWaitingRequests(t *testing.T) { +func TestHandleReplicateResponseWhenValueInStashReturnsCorrectWriteValueToAllWaitingRequests(t *testing.T) { shardNodeFSM := newShardNodeFSM() shardNodeFSM.raftNode = &mockRaftNodeLeader{} shardNodeFSM.requestLog["block"] = []string{"request1", "request2", "request3"} - shardNodeFSM.responseChannel.Store("request1", make(chan string)) shardNodeFSM.responseChannel.Store("request2", make(chan string)) shardNodeFSM.responseChannel.Store("request3", make(chan string)) shardNodeFSM.stash["block"] = stashState{value: "test_value"} payload := createTestReplicateResponsePayload("block", "response", "value_write", Write) - go shardNodeFSM.handleLocalResponseReplicationChanges("request1", payload) + go shardNodeFSM.handleReplicateResponse("request1", payload) checkWaitingChannelsHelper(t, shardNodeFSM.responseChannel, "value_write") @@ -173,17 +149,15 @@ func TestHandleLocalReplicaChangesWhenValueInStashReturnsCorrectWriteValueToAllW } } 
-func TestHandleLocalReplicaChangesWhenValueNotInStashReturnsResponseToAllWaitingRequests(t *testing.T) { +func TestHandleReplicateResponseWhenValueNotInStashReturnsResponseToAllWaitingRequests(t *testing.T) { shardNodeFSM := newShardNodeFSM() shardNodeFSM.raftNode = &mockRaftNodeLeader{} shardNodeFSM.requestLog["block"] = []string{"request1", "request2", "request3"} - shardNodeFSM.responseChannel.Store("request1", make(chan string)) shardNodeFSM.responseChannel.Store("request2", make(chan string)) shardNodeFSM.responseChannel.Store("request3", make(chan string)) - shardNodeFSM.responseMap["request1"] = "response_from_oramnode" - payload := createTestReplicateResponsePayload("block", "response", "", Read) - go shardNodeFSM.handleLocalResponseReplicationChanges("request1", payload) + payload := createTestReplicateResponsePayload("block", "response_from_oramnode", "", Read) + go shardNodeFSM.handleReplicateResponse("request1", payload) checkWaitingChannelsHelper(t, shardNodeFSM.responseChannel, "response_from_oramnode") @@ -192,17 +166,15 @@ func TestHandleLocalReplicaChangesWhenValueNotInStashReturnsResponseToAllWaiting } } -func TestHandleLocalReplicaChangesWhenValueNotInStashReturnsWriteResponseToAllWaitingRequests(t *testing.T) { +func TestHandleReplicateResponseWhenValueNotInStashReturnsWriteResponseToAllWaitingRequests(t *testing.T) { shardNodeFSM := newShardNodeFSM() shardNodeFSM.raftNode = &mockRaftNodeLeader{} shardNodeFSM.requestLog["block"] = []string{"request1", "request2", "request3"} - shardNodeFSM.responseChannel.Store("request1", make(chan string)) shardNodeFSM.responseChannel.Store("request2", make(chan string)) shardNodeFSM.responseChannel.Store("request3", make(chan string)) - shardNodeFSM.responseMap["request1"] = "response_from_oramnode" payload := createTestReplicateResponsePayload("block", "response", "write_val", Write) - go shardNodeFSM.handleLocalResponseReplicationChanges("request1", payload) + go 
shardNodeFSM.handleReplicateResponse("request1", payload) checkWaitingChannelsHelper(t, shardNodeFSM.responseChannel, "write_val") @@ -211,7 +183,7 @@ func TestHandleLocalReplicaChangesWhenValueNotInStashReturnsWriteResponseToAllWa } } -func TestHandleLocalReplicaChangesWhenNotLeaderDoesNotWriteOnChannels(t *testing.T) { +func TestHandleReplicateResponseWhenNotLeaderDoesNotWriteOnChannels(t *testing.T) { shardNodeFSM := newShardNodeFSM() shardNodeFSM.raftNode = &mockRaftNodeFollower{} shardNodeFSM.requestLog["block"] = []string{"request1", "request2"} @@ -220,7 +192,7 @@ func TestHandleLocalReplicaChangesWhenNotLeaderDoesNotWriteOnChannels(t *testing shardNodeFSM.stash["block"] = stashState{value: "test_value"} payload := createTestReplicateResponsePayload("block", "response", "", Read) - go shardNodeFSM.handleLocalResponseReplicationChanges("request1", payload) + go shardNodeFSM.handleReplicateResponse("request1", payload) for { ch1Any, _ := shardNodeFSM.responseChannel.Load("request1") diff --git a/pkg/shardnode/server.go b/pkg/shardnode/server.go index 7ef1349..93b38de 100644 --- a/pkg/shardnode/server.go +++ b/pkg/shardnode/server.go @@ -9,6 +9,7 @@ import ( "time" pb "github.com/dsg-uwaterloo/oblishard/api/shardnode" + "github.com/dsg-uwaterloo/oblishard/pkg/commonerrs" "github.com/dsg-uwaterloo/oblishard/pkg/config" "github.com/dsg-uwaterloo/oblishard/pkg/rpc" "github.com/dsg-uwaterloo/oblishard/pkg/storage" @@ -87,6 +88,11 @@ func (s *shardNodeServer) sendBatchesForever() { func (s *shardNodeServer) sendCurrentBatches() { storageQueues := make(map[int][]blockRequest) responseChannels := make(map[string]chan string) + // TODO: I have another idea instead of the high priority lock. 
+ // I can have a separate go routine that has a for loop that manages the lock + // It has two channels: one for low priority and one for high priority + // It will always start by trying to read from the high priority channel, + // if it is empty, it will read from the low priority channel. s.batchManager.mu.HighPriorityLock() for storageID, requests := range s.batchManager.storageQueues { storageQueues[storageID] = append(storageQueues[storageID], requests...) @@ -109,6 +115,7 @@ func (s *shardNodeServer) sendCurrentBatches() { log.Error().Msgf("Could not get value from the oramnode; %s", response.err) continue } + log.Debug().Msgf("Got batch response from oram node replica: %v", response) go func(response batchResponse) { for _, readPathReply := range response.Responses { log.Debug().Msgf("Got reply from oram node replica: %v", readPathReply) @@ -127,6 +134,9 @@ func (s *shardNodeServer) sendCurrentBatches() { } func (s *shardNodeServer) query(ctx context.Context, op OperationType, block string, value string) (string, error) { + if s.raftNode.State() != raft.Leader { + return "", fmt.Errorf(commonerrs.NotTheLeaderError) + } tracer := otel.Tracer("") ctx, querySpan := tracer.Start(ctx, "shardnode query") requestID, err := rpc.GetRequestIDFromContext(ctx) @@ -167,11 +177,15 @@ func (s *shardNodeServer) query(ctx context.Context, op OperationType, block str return "", fmt.Errorf("could not create response replication command; %s", err) } _, responseReplicationSpan := tracer.Start(ctx, "apply response replication") - err = s.raftNode.Apply(responseReplicationCommand, 0).Error() + responseApplyFuture := s.raftNode.Apply(responseReplicationCommand, 0) responseReplicationSpan.End() + err = responseApplyFuture.Error() if err != nil { return "", fmt.Errorf("could not apply log to the FSM; %s", err) } + response := responseApplyFuture.Response().(string) + log.Debug().Msgf("Got is first response from response channel for block %s; value: %s", block, response) + return
response, nil } responseValue := <-responseChannel log.Debug().Msgf("Got response from response channel for block %s; value: %s", block, responseValue) @@ -301,9 +315,7 @@ func StartServer(shardNodeServerID int, ip string, rpcPort int, replicaID int, r if err != nil { log.Fatal().Msgf("The raft node creation did not succeed; %s", err) } - shardNodeFSM.raftNodeMu.Lock() shardNodeFSM.raftNode = r - shardNodeFSM.raftNodeMu.Unlock() if !isFirst { conn, err := grpc.Dial(joinAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) diff --git a/pkg/shardnode/server_test.go b/pkg/shardnode/server_test.go index b54bb16..4f381b2 100644 --- a/pkg/shardnode/server_test.go +++ b/pkg/shardnode/server_test.go @@ -106,9 +106,7 @@ func startLeaderRaftNodeServer(t *testing.T, batchSize int, withBatchReponses bo if err != nil { t.Errorf("unable to start raft server; %v", err) } - fsm.raftNodeMu.Lock() fsm.raftNode = r - fsm.raftNodeMu.Unlock() <-r.LeaderCh() // wait to become the leader oramNodeClients := getMockOramNodeClients() if withBatchReponses { @@ -238,25 +236,12 @@ func TestQueryCleansTempValuesInFSMAfterExecution(t *testing.T) { s := startLeaderRaftNodeServer(t, 1, false) ctx := metadata.NewIncomingContext(context.Background(), metadata.Pairs("requestid", "request1")) s.query(ctx, Write, "a", "val") - s.shardNodeFSM.pathMapMu.Lock() - s.shardNodeFSM.storageIDMapMu.Lock() - s.shardNodeFSM.responseMapMu.Lock() - s.shardNodeFSM.requestLogMu.Lock() - defer func() { - s.shardNodeFSM.pathMapMu.Unlock() - s.shardNodeFSM.storageIDMapMu.Unlock() - s.shardNodeFSM.responseMapMu.Unlock() - s.shardNodeFSM.requestLogMu.Unlock() - }() if _, exists := s.shardNodeFSM.pathMap["request1"]; exists { t.Errorf("query should remove the request from the pathMap after successful execution.") } if _, exists := s.shardNodeFSM.storageIDMap["request1"]; exists { t.Errorf("query should remove the request from the storageIDMap after successful execution.") } - if _, exists := 
s.shardNodeFSM.responseMap["request1"]; exists { - t.Errorf("query should remove the request from the responseMap after successful execution.") - } if _, exists := s.shardNodeFSM.requestLog["request1"]; exists { t.Errorf("query should remove the request from the requestLog after successful execution.") }