Fix shardnode raft structure
aminst committed Nov 28, 2023
1 parent d4fbc43 commit 332509d
Showing 4 changed files with 49 additions and 140 deletions.
106 changes: 23 additions & 83 deletions pkg/shardnode/raft.go
@@ -41,33 +41,24 @@ func (p positionState) isPathInPaths(paths []int) bool {
}

type shardNodeFSM struct {
requestLog map[string][]string // map of block to requesting requestIDs
requestLogMu sync.Mutex
pathMap map[string]int // map of requestID to new path
pathMapMu sync.Mutex
storageIDMap map[string]int // map of requestID to new storageID
storageIDMapMu sync.Mutex
responseMap map[string]string // map of requestID to response map[string]string
responseMapMu sync.Mutex
requestLog map[string][]string // map of block to requesting requestIDs
pathMap map[string]int // map of requestID to new path
storageIDMap map[string]int // map of requestID to new storageID
stash map[string]stashState // map of block to stashState
stashMu sync.Mutex
responseChannel sync.Map // map of requestId to their channel for receiving response map[string] chan string
acks map[string][]string // map of requestID to array of blocks
acksMu sync.Mutex
nacks map[string][]string // map of requestID to array of blocks
nacksMu sync.Mutex
responseChannel sync.Map // map of requestId to their channel for receiving response map[string] chan string
acks map[string][]string // map of requestID to array of blocks
nacks map[string][]string // map of requestID to array of blocks
positionMap map[string]positionState // map of block to positionState
positionMapMu sync.RWMutex
raftNode RaftNodeWIthState
raftNodeMu sync.Mutex
}

func newShardNodeFSM() *shardNodeFSM {
return &shardNodeFSM{
requestLog: make(map[string][]string),
pathMap: make(map[string]int),
storageIDMap: make(map[string]int),
responseMap: make(map[string]string),
stash: make(map[string]stashState),
responseChannel: sync.Map{},
acks: make(map[string][]string),
@@ -81,7 +72,6 @@ func (fsm *shardNodeFSM) String() string {
out = out + fmt.Sprintf("requestLog: %v\n", fsm.requestLog)
out = out + fmt.Sprintf("pathMap: %v\n", fsm.pathMap)
out = out + fmt.Sprintf("storageIDMap: %v\n", fsm.storageIDMap)
out = out + fmt.Sprintf("responseMap: %v\n", fsm.responseMap)
out = out + fmt.Sprintf("stash: %v\n", fsm.stash)
out = out + fmt.Sprintf("responseChannel: %v\n", fsm.responseChannel)
out = out + fmt.Sprintf("acks: %v\n", fsm.acks)
@@ -91,19 +81,6 @@ func (fsm *shardNodeFSM) String() string {
}

func (fsm *shardNodeFSM) handleReplicateRequestAndPathAndStorage(requestID string, r ReplicateRequestAndPathAndStoragePayload) (isFirst bool) {
log.Debug().Msgf("Aquiring lock for shardNodeFSM in handleReplicateRequestAndPathAndStorage")
fsm.requestLogMu.Lock()
fsm.pathMapMu.Lock()
fsm.storageIDMapMu.Lock()
log.Debug().Msgf("Aquired lock for shardNodeFSM in handleReplicateRequestAndPathAndStorage")
defer func() {
log.Debug().Msgf("Releasing lock for shardNodeFSM in handleReplicateRequestAndPathAndStorage")
fsm.requestLogMu.Unlock()
fsm.pathMapMu.Unlock()
fsm.storageIDMapMu.Unlock()
log.Debug().Msgf("Released lock for shardNodeFSM in handleReplicateRequestAndPathAndStorage")
}()

fsm.requestLog[r.RequestedBlock] = append(fsm.requestLog[r.RequestedBlock], requestID)
fsm.pathMap[requestID] = r.Path
fsm.storageIDMap[requestID] = r.StorageID
@@ -115,32 +92,20 @@ func (fsm *shardNodeFSM) handleReplicateRequestAndPathAndStorage(requestID strin
return isFirst
}

type localReplicaChangeHandlerFunc func(requestID string, r ReplicateResponsePayload)

// It handles the response replication changes locally on each raft replica.
// The leader doesn't wait for this to finish to return success for the response replication command.
func (fsm *shardNodeFSM) handleLocalResponseReplicationChanges(requestID string, r ReplicateResponsePayload) {
log.Debug().Msgf("Aquiring lock for shardNodeFSM in handleLocalResponseReplicationChanges")
func (fsm *shardNodeFSM) handleReplicateResponse(requestID string, r ReplicateResponsePayload) string {
log.Debug().Msgf("Aquiring lock for shardNodeFSM in handleReplicateResponse")
start := time.Now()
fsm.stashMu.Lock()
fsm.responseMapMu.Lock()
fsm.positionMapMu.Lock()
fsm.pathMapMu.Lock()
fsm.storageIDMapMu.Lock()
fsm.requestLogMu.Lock()
fsm.raftNodeMu.Lock()
log.Debug().Msgf("Aquired lock for shardNodeFSM in handleLocalResponseReplicationChanges")
end := time.Now()
log.Debug().Msgf("Aquired lock for shardNodeFSM in handleReplicateResponse in %v", end.Sub(start))
// log.Debug().Msgf("Aquired lock for shardNodeFSM in handleReplicateResponse")
defer func() {
log.Debug().Msgf("Releasing lock for shardNodeFSM in handleReplicateResponse")
fsm.stashMu.Unlock()
fsm.responseMapMu.Unlock()
fsm.positionMapMu.Unlock()
fsm.pathMapMu.Unlock()
fsm.storageIDMapMu.Unlock()
fsm.requestLogMu.Unlock()
fsm.raftNodeMu.Unlock()
log.Debug().Msgf("Released lock for shardNodeFSM in handleReplicateResponse")
}()

stashState, exists := fsm.stash[r.RequestedBlock]
if exists {
if r.OpType == Write {
@@ -149,7 +114,7 @@ func (fsm *shardNodeFSM) handleLocalResponseReplicationChanges(requestID string,
fsm.stash[r.RequestedBlock] = stashState
}
} else {
response := fsm.responseMap[requestID]
response := r.Response
stashState := fsm.stash[r.RequestedBlock]
if r.OpType == Read {
stashState.value = response
@@ -161,37 +126,27 @@ func (fsm *shardNodeFSM) handleLocalResponseReplicationChanges(requestID string,
}
if fsm.raftNode.State() == raft.Leader {
fsm.positionMap[r.RequestedBlock] = positionState{path: fsm.pathMap[requestID], storageID: fsm.storageIDMap[requestID]}
for i := len(fsm.requestLog[r.RequestedBlock]) - 1; i >= 0; i-- {
for i := len(fsm.requestLog[r.RequestedBlock]) - 1; i >= 1; i-- { // We don't need to send the response to the first request
log.Debug().Msgf("Sending response to concurrent request number %d in requestLog for block %s", i, r.RequestedBlock)
timeout := time.After(5 * time.Second) // TODO: think about this in the batching scenario
responseChan, _ := fsm.responseChannel.Load(fsm.requestLog[r.RequestedBlock][i])
select {
case <-timeout:
log.Error().Msgf("timeout in sending response to concurrent requests")
log.Error().Msgf("timeout in sending response to concurrent request number %d in requestLog for block %s", i, r.RequestedBlock)
continue
case responseChan.(chan string) <- fsm.stash[r.RequestedBlock].value:
log.Debug().Msgf("sent response to concurrent request number %d in requestLog for block %s", i, r.RequestedBlock)
delete(fsm.pathMap, fsm.requestLog[r.RequestedBlock][i])
delete(fsm.storageIDMap, fsm.requestLog[r.RequestedBlock][i])
fsm.responseChannel.Delete(fsm.requestLog[r.RequestedBlock][i])
fsm.requestLog[r.RequestedBlock] = append(fsm.requestLog[r.RequestedBlock][:i], fsm.requestLog[r.RequestedBlock][i+1:]...)
}
}
}
delete(fsm.responseMap, requestID)
}

func (fsm *shardNodeFSM) handleReplicateResponse(requestID string, r ReplicateResponsePayload, f localReplicaChangeHandlerFunc) {
log.Debug().Msgf("Aquiring lock for shardNodeFSM in handleReplicateResponse")
fsm.responseMapMu.Lock()
log.Debug().Msgf("Aquired lock for shardNodeFSM in handleReplicateResponse")

defer func() {
log.Debug().Msgf("Releasing lock for shardNodeFSM in handleReplicateResponse")
fsm.responseMapMu.Unlock()
log.Debug().Msgf("Released lock for shardNodeFSM in handleReplicateResponse")
}()

fsm.responseMap[requestID] = r.Response
go f(requestID, r)
delete(fsm.pathMap, requestID)
delete(fsm.storageIDMap, requestID)
fsm.responseChannel.Delete(requestID)
delete(fsm.requestLog, r.RequestedBlock)
return fsm.stash[r.RequestedBlock].value
}

func (fsm *shardNodeFSM) handleReplicateSentBlocks(r ReplicateSentBlocksPayload) {
@@ -216,14 +171,10 @@ func (fsm *shardNodeFSM) handleReplicateSentBlocks(r ReplicateSentBlocksPayload)
// If an acked block was changed during the eviction, it will keep it.
func (fsm *shardNodeFSM) handleLocalAcksNacksReplicationChanges(requestID string) {
log.Debug().Msgf("Aquiring lock for shardNodeFSM in handleLocalAcksNacksReplicationChanges")
fsm.acksMu.Lock()
fsm.nacksMu.Lock()
fsm.stashMu.Lock()
log.Debug().Msgf("Aquired lock for shardNodeFSM in handleLocalAcksNacksReplicationChanges")
defer func() {
log.Debug().Msgf("Releasing lock for shardNodeFSM in handleLocalAcksNacksReplicationChanges")
fsm.acksMu.Unlock()
fsm.nacksMu.Unlock()
fsm.stashMu.Unlock()
log.Debug().Msgf("Released lock for shardNodeFSM in handleLocalAcksNacksReplicationChanges")
}()
@@ -244,17 +195,6 @@ func (fsm *shardNodeFSM) handleLocalAcksNacksReplicationChanges(requestID string
}

func (fsm *shardNodeFSM) handleReplicateAcksNacks(r ReplicateAcksNacksPayload) {
log.Debug().Msgf("Aquiring lock for shardNodeFSM in handleReplicateAcksNacks")
fsm.acksMu.Lock()
fsm.nacksMu.Lock()
log.Debug().Msgf("Aquired lock for shardNodeFSM in handleReplicateAcksNacks")
defer func() {
log.Debug().Msgf("Releasing lock for shardNodeFSM in handleReplicateAcksNacks")
fsm.acksMu.Unlock()
fsm.nacksMu.Unlock()
log.Debug().Msgf("Released lock for shardNodeFSM in handleReplicateAcksNacks")
}()

requestID := uuid.New().String()
fsm.acks[requestID] = r.AckedBlocks
fsm.nacks[requestID] = r.NackedBlocks
@@ -286,7 +226,7 @@ func (fsm *shardNodeFSM) Apply(rLog *raft.Log) interface{} {
if err != nil {
return fmt.Errorf("could not unmarshall the response replication command; %s", err)
}
fsm.handleReplicateResponse(requestID, responseReplicationPayload, fsm.handleLocalResponseReplicationChanges)
return fsm.handleReplicateResponse(requestID, responseReplicationPayload)
} else if command.Type == ReplicateSentBlocksCommand {
log.Debug().Msgf("got replication command for replicate sent blocks")
var replicateSentBlocksPayload ReplicateSentBlocksPayload
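The heart of this raft.go diff is the last hunk: Apply now returns the result of handleReplicateResponse instead of handing the work to a separate local-replica goroutine, so the replicated responseMap and the per-field mutexes can be deleted. A minimal sketch of the underlying hashicorp/raft pattern (illustrative names such as valueFSM and leaderQuery, not the repository's code):

package sketch

import (
	"fmt"
	"io"
	"time"

	"github.com/hashicorp/raft"
)

// valueFSM stands in for shardNodeFSM: Apply runs serially on every replica,
// and its return value is surfaced to the node that submitted the log entry.
type valueFSM struct{}

func (f *valueFSM) Apply(l *raft.Log) interface{} {
	// Decode l.Data, mutate the FSM's state, then hand the result back;
	// hashicorp/raft exposes this return value through ApplyFuture.Response().
	return string(l.Data)
}

// Snapshot and Restore are required by raft.FSM but elided in this sketch.
func (f *valueFSM) Snapshot() (raft.FSMSnapshot, error) { return nil, nil }
func (f *valueFSM) Restore(rc io.ReadCloser) error      { return nil }

// leaderQuery mirrors the server.go change further down: only the leader
// applies the command, and the response comes back via ApplyFuture.Response().
func leaderQuery(r *raft.Raft, cmd []byte) (string, error) {
	if r.State() != raft.Leader {
		return "", fmt.Errorf("not the leader")
	}
	future := r.Apply(cmd, 2*time.Second)
	if err := future.Error(); err != nil {
		return "", err // replication failed or leadership was lost
	}
	return future.Response().(string), nil
}

Because hashicorp/raft applies log entries on a single FSM goroutine, per-map locking inside Apply handlers becomes largely unnecessary, which lines up with the slimmed-down shardNodeFSM struct at the top of this diff.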
50 changes: 11 additions & 39 deletions pkg/shardnode/raft_test.go
@@ -60,28 +60,6 @@ func createTestReplicateResponsePayload(block string, response string, value str
}
}

// This test only checks the functionality of the handleReplicateResponse method;
// It doesn't run the local replica handler go routine so it doesn't affect the test results.
// This is achieved by giving an anonymous empty function to the handleReplicateResponse method.
func TestHandleReplicateResponseWithoutPreviousValue(t *testing.T) {
shardNodeFSM := newShardNodeFSM()
payload := createTestReplicateResponsePayload("block", "response", "value", Read)
shardNodeFSM.handleReplicateResponse("request2", payload, func(requestID string, r ReplicateResponsePayload) {})
if shardNodeFSM.responseMap["request2"] != "response" {
t.Errorf("Expected response to be \"response\" for request1, but the value is %s", shardNodeFSM.responseMap["request2"])
}
}

func TestHandleReplicateResponseOverridingPreviousValue(t *testing.T) {
shardNodeFSM := newShardNodeFSM()
shardNodeFSM.responseMap["request2"] = "prev"
payload := createTestReplicateResponsePayload("block", "response", "value", Read)
shardNodeFSM.handleReplicateResponse("request2", payload, func(requestID string, r ReplicateResponsePayload) {})
if shardNodeFSM.responseMap["request2"] != "response" {
t.Errorf("Expected response to be \"response\" for request1, but the value is %s", shardNodeFSM.responseMap["request2"])
}
}

type responseMessage struct {
requestID string
response string
@@ -139,32 +117,30 @@ func (m *mockRaftNodeFollower) State() raft.RaftState {

// In this case all the go routines should get the value that resides in stash.
// The stash value has priority over the response value.
func TestHandleLocalReplicaChangesWhenValueInStashReturnsCorrectReadValueToAllWaitingRequests(t *testing.T) {
func TestHandleReplicateResponseWhenValueInStashReturnsCorrectReadValueToAllWaitingRequests(t *testing.T) {
shardNodeFSM := newShardNodeFSM()
shardNodeFSM.raftNode = &mockRaftNodeLeader{}
shardNodeFSM.requestLog["block"] = []string{"request1", "request2", "request3"}
shardNodeFSM.responseChannel.Store("request1", make(chan string))
shardNodeFSM.responseChannel.Store("request2", make(chan string))
shardNodeFSM.responseChannel.Store("request3", make(chan string))
shardNodeFSM.stash["block"] = stashState{value: "test_value"}

payload := createTestReplicateResponsePayload("block", "response", "value", Read)
go shardNodeFSM.handleLocalResponseReplicationChanges("request1", payload)
go shardNodeFSM.handleReplicateResponse("request1", payload)

checkWaitingChannelsHelper(t, shardNodeFSM.responseChannel, "test_value")
}

func TestHandleLocalReplicaChangesWhenValueInStashReturnsCorrectWriteValueToAllWaitingRequests(t *testing.T) {
func TestHandleReplicateResponseWhenValueInStashReturnsCorrectWriteValueToAllWaitingRequests(t *testing.T) {
shardNodeFSM := newShardNodeFSM()
shardNodeFSM.raftNode = &mockRaftNodeLeader{}
shardNodeFSM.requestLog["block"] = []string{"request1", "request2", "request3"}
shardNodeFSM.responseChannel.Store("request1", make(chan string))
shardNodeFSM.responseChannel.Store("request2", make(chan string))
shardNodeFSM.responseChannel.Store("request3", make(chan string))
shardNodeFSM.stash["block"] = stashState{value: "test_value"}

payload := createTestReplicateResponsePayload("block", "response", "value_write", Write)
go shardNodeFSM.handleLocalResponseReplicationChanges("request1", payload)
go shardNodeFSM.handleReplicateResponse("request1", payload)

checkWaitingChannelsHelper(t, shardNodeFSM.responseChannel, "value_write")

@@ -173,17 +149,15 @@ func TestHandleLocalReplicaChangesWhenValueInStashReturnsCorrectWriteValueToAllW
}
}

func TestHandleLocalReplicaChangesWhenValueNotInStashReturnsResponseToAllWaitingRequests(t *testing.T) {
func TestHandleReplicateResponseWhenValueNotInStashReturnsResponseToAllWaitingRequests(t *testing.T) {
shardNodeFSM := newShardNodeFSM()
shardNodeFSM.raftNode = &mockRaftNodeLeader{}
shardNodeFSM.requestLog["block"] = []string{"request1", "request2", "request3"}
shardNodeFSM.responseChannel.Store("request1", make(chan string))
shardNodeFSM.responseChannel.Store("request2", make(chan string))
shardNodeFSM.responseChannel.Store("request3", make(chan string))
shardNodeFSM.responseMap["request1"] = "response_from_oramnode"

payload := createTestReplicateResponsePayload("block", "response", "", Read)
go shardNodeFSM.handleLocalResponseReplicationChanges("request1", payload)
payload := createTestReplicateResponsePayload("block", "response_from_oramnode", "", Read)
go shardNodeFSM.handleReplicateResponse("request1", payload)

checkWaitingChannelsHelper(t, shardNodeFSM.responseChannel, "response_from_oramnode")

@@ -192,17 +166,15 @@ func TestHandleLocalReplicaChangesWhenValueNotInStashReturnsResponseToAllWaiting
}
}

func TestHandleLocalReplicaChangesWhenValueNotInStashReturnsWriteResponseToAllWaitingRequests(t *testing.T) {
func TestHandleReplicateResponseWhenValueNotInStashReturnsWriteResponseToAllWaitingRequests(t *testing.T) {
shardNodeFSM := newShardNodeFSM()
shardNodeFSM.raftNode = &mockRaftNodeLeader{}
shardNodeFSM.requestLog["block"] = []string{"request1", "request2", "request3"}
shardNodeFSM.responseChannel.Store("request1", make(chan string))
shardNodeFSM.responseChannel.Store("request2", make(chan string))
shardNodeFSM.responseChannel.Store("request3", make(chan string))
shardNodeFSM.responseMap["request1"] = "response_from_oramnode"

payload := createTestReplicateResponsePayload("block", "response", "write_val", Write)
go shardNodeFSM.handleLocalResponseReplicationChanges("request1", payload)
go shardNodeFSM.handleReplicateResponse("request1", payload)

checkWaitingChannelsHelper(t, shardNodeFSM.responseChannel, "write_val")

@@ -211,7 +183,7 @@ func TestHandleLocalReplicaChangesWhenValueNotInStashReturnsWriteResponseToAllWa
}
}

func TestHandleLocalReplicaChangesWhenNotLeaderDoesNotWriteOnChannels(t *testing.T) {
func TestHandleReplicateResponseWhenNotLeaderDoesNotWriteOnChannels(t *testing.T) {
shardNodeFSM := newShardNodeFSM()
shardNodeFSM.raftNode = &mockRaftNodeFollower{}
shardNodeFSM.requestLog["block"] = []string{"request1", "request2"}
Expand All @@ -220,7 +192,7 @@ func TestHandleLocalReplicaChangesWhenNotLeaderDoesNotWriteOnChannels(t *testing
shardNodeFSM.stash["block"] = stashState{value: "test_value"}

payload := createTestReplicateResponsePayload("block", "response", "", Read)
go shardNodeFSM.handleLocalResponseReplicationChanges("request1", payload)
go shardNodeFSM.handleReplicateResponse("request1", payload)

for {
ch1Any, _ := shardNodeFSM.responseChannel.Load("request1")
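The renamed tests above exercise leader and follower behaviour by assigning a stub to shardNodeFSM.raftNode. The stub types are defined elsewhere in raft_test.go and are not part of this diff; a reconstruction of the two State() methods the tests rely on (the real mocks may carry more than this):

type mockRaftNodeLeader struct{}

// The FSM's leader-only branch runs when State() reports raft.Leader.
func (m *mockRaftNodeLeader) State() raft.RaftState {
	return raft.Leader
}

type mockRaftNodeFollower struct{}

// With a follower stub, handleReplicateResponse must not write to the waiting
// channels, which is what TestHandleReplicateResponseWhenNotLeaderDoesNotWriteOnChannels checks.
func (m *mockRaftNodeFollower) State() raft.RaftState {
	return raft.Follower
}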
18 changes: 15 additions & 3 deletions pkg/shardnode/server.go
@@ -9,6 +9,7 @@ import (
"time"

pb "github.com/dsg-uwaterloo/oblishard/api/shardnode"
"github.com/dsg-uwaterloo/oblishard/pkg/commonerrs"
"github.com/dsg-uwaterloo/oblishard/pkg/config"
"github.com/dsg-uwaterloo/oblishard/pkg/rpc"
"github.com/dsg-uwaterloo/oblishard/pkg/storage"
@@ -87,6 +88,11 @@ func (s *shardNodeServer) sendBatchesForever() {
func (s *shardNodeServer) sendCurrentBatches() {
storageQueues := make(map[int][]blockRequest)
responseChannels := make(map[string]chan string)
// TODO: I have another idea instead of the high priority lock.
// I can have a seperate go routine that has a for loop that manages the lock
// It has two channels one for low priority and one for high priority
// It will always start by trying to read from the high priority channel,
// if it is empty, it will read from the low priority channel.
s.batchManager.mu.HighPriorityLock()
for storageID, requests := range s.batchManager.storageQueues {
storageQueues[storageID] = append(storageQueues[storageID], requests...)
@@ -109,6 +115,7 @@ func (s *shardNodeServer) sendCurrentBatches() {
log.Error().Msgf("Could not get value from the oramnode; %s", response.err)
continue
}
log.Debug().Msgf("Got batch response from oram node replica: %v", response)
go func(response batchResponse) {
for _, readPathReply := range response.Responses {
log.Debug().Msgf("Got reply from oram node replica: %v", readPathReply)
@@ -127,6 +134,9 @@ }
}

func (s *shardNodeServer) query(ctx context.Context, op OperationType, block string, value string) (string, error) {
if s.raftNode.State() != raft.Leader {
return "", fmt.Errorf(commonerrs.NotTheLeaderError)
}
tracer := otel.Tracer("")
ctx, querySpan := tracer.Start(ctx, "shardnode query")
requestID, err := rpc.GetRequestIDFromContext(ctx)
@@ -167,11 +177,15 @@ func (s *shardNodeServer) query(ctx context.Context, op OperationType, block str
return "", fmt.Errorf("could not create response replication command; %s", err)
}
_, responseReplicationSpan := tracer.Start(ctx, "apply response replication")
err = s.raftNode.Apply(responseReplicationCommand, 0).Error()
responseApplyFuture := s.raftNode.Apply(responseReplicationCommand, 0)
responseReplicationSpan.End()
err = responseApplyFuture.Error()
if err != nil {
return "", fmt.Errorf("could not apply log to the FSM; %s", err)
}
response := responseApplyFuture.Response().(string)
log.Debug().Msgf("Got is first response from response channel for block %s; value: %s", block, response)
return response, nil
}
responseValue := <-responseChannel
log.Debug().Msgf("Got response from response channel for block %s; value: %s", block, responseValue)
@@ -301,9 +315,7 @@ func StartServer(shardNodeServerID int, ip string, rpcPort int, replicaID int, r
if err != nil {
log.Fatal().Msgf("The raft node creation did not succeed; %s", err)
}
shardNodeFSM.raftNodeMu.Lock()
shardNodeFSM.raftNode = r
shardNodeFSM.raftNodeMu.Unlock()

if !isFirst {
conn, err := grpc.Dial(joinAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
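The TODO block added at the top of sendCurrentBatches describes an alternative to batchManager.mu.HighPriorityLock(): a single goroutine that owns the queues and serves two request channels, always preferring the high-priority one. A rough sketch of that idea (illustrative only; lockManager and its fields are invented names, not code from this commit):

// One goroutine runs critical sections submitted as closures. The batch sender
// would submit on highPriority, individual queries on lowPriority.
type lockManager struct {
	highPriority chan func()
	lowPriority  chan func()
}

func newLockManager() *lockManager {
	m := &lockManager{
		highPriority: make(chan func()),
		lowPriority:  make(chan func()),
	}
	go m.run()
	return m
}

func (m *lockManager) run() {
	for {
		// Always try the high-priority channel first.
		select {
		case critical := <-m.highPriority:
			critical()
			continue
		default:
		}
		// Nothing urgent is waiting: take whichever request arrives next.
		select {
		case critical := <-m.highPriority:
			critical()
		case critical := <-m.lowPriority:
			critical()
		}
	}
}

A caller would wrap its queue manipulation in a closure and send it on the appropriate channel instead of taking the mutex; a caller that must block until its closure finishes would additionally wait on a done channel. Whether to adopt this in place of the high-priority lock is left open by the TODO.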