Skip to content

Commit

Permalink
Make bucket accesses happen in parallel
Browse files Browse the repository at this point in the history
  • Loading branch information
aminst committed Nov 14, 2023
1 parent 0b49a08 commit 937b8a2
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 47 deletions.
146 changes: 108 additions & 38 deletions pkg/oramnode/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,12 @@ import (

// storage is the set of storage operations the ORAM node server relies on.
// It is declared at the consumer so that tests can substitute a mock
// implementation for the real storage handler.
type storage interface {
	// GetMaxAccessCount returns the access-count threshold that earlyReshuffle
	// compares each bucket's access count against.
	GetMaxAccessCount() int
	// LockStorage and UnlockStorage bracket a sequence of operations on one
	// storage; evict and ReadPath hold the lock for their whole duration.
	LockStorage(storageID int)
	UnlockStorage(storageID int)
	GetBlockOffset(bucketID int, storageID int, blocks []string) (offset int, isReal bool, blockFound string, err error)
	GetAccessCount(bucketID int, storageID int) (count int, err error)
	ReadBucket(bucketID int, storageID int) (blocks map[string]string, err error)
	// WriteBucket no longer takes an isAtomic flag; a method name may be
	// declared only once in an interface, so only the 4-argument form is kept.
	WriteBucket(bucketID int, storageID int, ReadBucketBlocks map[string]string, shardNodeBlocks map[string]string) (writtenBlocks map[string]string, err error)
	ReadBlock(bucketID int, storageID int, offset int) (value string, err error)
	GetBucketsInPaths(paths []int) (bucketIDs []int, err error)
}
Expand Down Expand Up @@ -89,48 +91,78 @@ func (o *oramNodeServer) performFailedOperations() error {
func (o *oramNodeServer) earlyReshuffle(buckets []int, storageID int) error {
log.Debug().Msgf("Performing early reshuffle with buckets %v and storageID %d", buckets, storageID)
// TODO: can we make this a background thread?
errorChan := make(chan error)
for _, bucket := range buckets {
accessCount, err := o.storageHandler.GetAccessCount(bucket, storageID)
if err != nil {
return fmt.Errorf("unable to get access count from the server; %s", err)
}
if accessCount < o.storageHandler.GetMaxAccessCount() {
continue
}
localStash, err := o.storageHandler.ReadBucket(bucket, storageID)
if err != nil {
return fmt.Errorf("unable to read bucket from the server; %s", err)
}
writtenBlocks, err := o.storageHandler.WriteBucket(bucket, storageID, localStash, nil, false)
if err != nil {
return fmt.Errorf("unable to write bucket from the server; %s", err)
}
for block := range localStash {
if _, exists := writtenBlocks[block]; !exists {
return fmt.Errorf("unable to write all blocks to the bucket")
go func(bucket int) {
accessCount, err := o.storageHandler.GetAccessCount(bucket, storageID)
if err != nil {
errorChan <- fmt.Errorf("unable to get access count from the server; %s", err)
return
}
if accessCount < o.storageHandler.GetMaxAccessCount() {
errorChan <- nil
return
}
localStash, err := o.storageHandler.ReadBucket(bucket, storageID)
if err != nil {
errorChan <- fmt.Errorf("unable to read bucket from the server; %s", err)
return
}
writtenBlocks, err := o.storageHandler.WriteBucket(bucket, storageID, localStash, nil)
if err != nil {
errorChan <- fmt.Errorf("unable to write bucket from the server; %s", err)
return
}
for block := range localStash {
if _, exists := writtenBlocks[block]; !exists {
errorChan <- fmt.Errorf("unable to write all blocks to the bucket")
return
}
}
errorChan <- nil
}(bucket)
}
for i := 0; i < len(buckets); i++ {
err := <-errorChan
if err != nil {
return err
}
}
return nil
}

// readBucketResponse is the message asyncReadBucket sends on its response
// channel: the result of one ReadBucket call tagged with the bucket it was
// issued for.
type readBucketResponse struct {
	bucket int               // ID of the bucket that was read
	blocks map[string]string // blocks returned by ReadBucket
	err    error             // error returned by ReadBucket, if any
}

// asyncReadBucket reads one bucket through the storage handler and sends the
// outcome, success or failure, on responseChan. It is intended to be run as a
// goroutine; the send blocks until responseChan can accept the value.
func (o *oramNodeServer) asyncReadBucket(bucket int, storageID int, responseChan chan readBucketResponse) {
	result := readBucketResponse{bucket: bucket}
	result.blocks, result.err = o.storageHandler.ReadBucket(bucket, storageID)
	responseChan <- result
}

func (o *oramNodeServer) readAllBuckets(buckets []int, storageID int) (blocksFromReadBucket map[int]map[string]string, err error) {
log.Debug().Msgf("Reading all buckets with buckets %v and storageID %d", buckets, storageID)
blocksFromReadBucket = make(map[int]map[string]string) // map of bucket to map of block to value
if err != nil {
return nil, fmt.Errorf("unable to get bucket ids for early reshuffle path; %v", err)
}
readBucketResponseChan := make(chan readBucketResponse)
for _, bucket := range buckets {
blocks, err := o.storageHandler.ReadBucket(bucket, storageID)
log.Debug().Msgf("Got blocks %v from bucket %d", blocks, bucket)
go o.asyncReadBucket(bucket, storageID, readBucketResponseChan)
}
for i := 0; i < len(buckets); i++ {
response := <-readBucketResponseChan
log.Debug().Msgf("Got blocks %v from bucket %d", response.blocks, response.bucket)
if err != nil {
return nil, fmt.Errorf("unable to read bucket; %s", err)
}
if blocksFromReadBucket[bucket] == nil {
blocksFromReadBucket[bucket] = make(map[string]string)
if blocksFromReadBucket[response.bucket] == nil {
blocksFromReadBucket[response.bucket] = make(map[string]string)
}
for block, value := range blocks {
blocksFromReadBucket[bucket][block] = value
for block, value := range response.blocks {
blocksFromReadBucket[response.bucket][block] = value
}
}
return blocksFromReadBucket, nil
Expand Down Expand Up @@ -161,7 +193,7 @@ func (o *oramNodeServer) writeBackBlocksToAllBuckets(buckets []int, storageID in
receivedBlocksIsWritten[block] = false
}
for i := len(buckets) - 1; i >= 0; i-- {
writtenBlocks, err := o.storageHandler.WriteBucket(buckets[i], storageID, blocksFromReadBucket[buckets[i]], receivedBlocksCopy, true)
writtenBlocks, err := o.storageHandler.WriteBucket(buckets[i], storageID, blocksFromReadBucket[buckets[i]], receivedBlocksCopy)
if err != nil {
return nil, fmt.Errorf("unable to atomic write bucket; %s", err)
}
Expand All @@ -176,6 +208,8 @@ func (o *oramNodeServer) writeBackBlocksToAllBuckets(buckets []int, storageID in
}

func (o *oramNodeServer) evict(paths []int, storageID int) error {
o.storageHandler.LockStorage(storageID)
defer o.storageHandler.UnlockStorage(storageID)
log.Debug().Msgf("Evicting with paths %v and storageID %d", paths, storageID)
beginEvictionCommand, err := newReplicateBeginEvictionCommand(paths, storageID)
if err != nil {
Expand Down Expand Up @@ -238,13 +272,39 @@ func (o *oramNodeServer) getDistinctPathsInBatch(requests []*pb.BlockRequest) []
return pathList
}

// blockOffsetResponse is the message asyncGetBlockOffset sends on its response
// channel: the four results of one GetBlockOffset call tagged with the bucket
// the call was issued for.
type blockOffsetResponse struct {
	bucketID   int    // bucket the offset lookup was performed on
	offset     int    // offset returned by GetBlockOffset
	isReal     bool   // isReal flag returned by GetBlockOffset
	blockFound string // blockFound value returned by GetBlockOffset
	err        error  // error returned by GetBlockOffset, if any
}

// asyncGetBlockOffset asks the storage handler for the block offset within one
// bucket and sends the outcome, success or failure, on responseChan. It is
// intended to be run as a goroutine; the send blocks until responseChan can
// accept the value.
func (o *oramNodeServer) asyncGetBlockOffset(bucketID int, storageID int, blocks []string, responseChan chan blockOffsetResponse) {
	result := blockOffsetResponse{bucketID: bucketID}
	result.offset, result.isReal, result.blockFound, result.err = o.storageHandler.GetBlockOffset(bucketID, storageID, blocks)
	responseChan <- result
}

// readBlockResponse is the message asyncReadBlock sends on its response
// channel: the result of one ReadBlock call tagged with the block name the
// caller associated with it.
type readBlockResponse struct {
	block string // block name supplied by the caller
	value string // value returned by ReadBlock
	err   error  // error returned by ReadBlock, if any
}

// asyncReadBlock reads the value at offset in one bucket through the storage
// handler and sends the outcome, tagged with block, on responseChan. It is
// intended to be run as a goroutine; the send blocks until responseChan can
// accept the value.
func (o *oramNodeServer) asyncReadBlock(block string, bucketID int, storageID int, offset int, responseChan chan readBlockResponse) {
	result := readBlockResponse{block: block}
	result.value, result.err = o.storageHandler.ReadBlock(bucketID, storageID, offset)
	responseChan <- result
}

func (o *oramNodeServer) ReadPath(ctx context.Context, request *pb.ReadPathRequest) (*pb.ReadPathReply, error) {
if o.raftNode.State() != raft.Leader {
return nil, fmt.Errorf(commonerrs.NotTheLeaderError)
}
log.Debug().Msgf("Received read path request %v", request)
tracer := otel.Tracer("")
ctx, span := tracer.Start(ctx, "oramnode read path request")
o.storageHandler.LockStorage(int(request.StorageId))
defer o.storageHandler.UnlockStorage(int(request.StorageId))

var blocks []string
for _, request := range request.Requests {
Expand Down Expand Up @@ -272,34 +332,44 @@ func (o *oramNodeServer) ReadPath(ctx context.Context, request *pb.ReadPathReque
return nil, fmt.Errorf("could not get bucket ids in the paths; %v", err)
}
_, getBlockOffsetsSpan := tracer.Start(ctx, "get block offsets")
offsetListResponseChan := make(chan blockOffsetResponse)
for _, bucketID := range buckets {
offset, isReal, blockFound, err := o.storageHandler.GetBlockOffset(bucketID, int(request.StorageId), blocks)
if err != nil {
go o.asyncGetBlockOffset(bucketID, int(request.StorageId), blocks, offsetListResponseChan)
}
getBlockOffsetsSpan.End()
for i := 0; i < len(buckets); i++ {
response := <-offsetListResponseChan
if response.err != nil {
return nil, fmt.Errorf("could not get offset from storage")
}
if isReal {
realBlockBucketMapping[bucketID] = blockFound
if response.isReal {
realBlockBucketMapping[response.bucketID] = response.blockFound
}
offsetList[bucketID] = offset
offsetList[response.bucketID] = response.offset
}
getBlockOffsetsSpan.End()
log.Debug().Msgf("Got offsets %v", offsetList)

returnValues := make(map[string]string) // map of block to value
for _, block := range blocks {
returnValues[block] = ""
}
_, readBlocksSpan := tracer.Start(ctx, "read blocks")
readBlockResponseChan := make(chan readBlockResponse)
realReadCount := 0
for _, bucketID := range buckets {
if block, exists := realBlockBucketMapping[bucketID]; exists {
value, err := o.storageHandler.ReadBlock(bucketID, int(request.StorageId), offsetList[bucketID])
if err != nil {
return nil, err
}
returnValues[block] = value
go o.asyncReadBlock(block, bucketID, int(request.StorageId), offsetList[bucketID], readBlockResponseChan)
realReadCount++
} else {
o.storageHandler.ReadBlock(bucketID, int(request.StorageId), offsetList[bucketID])
go o.storageHandler.ReadBlock(bucketID, int(request.StorageId), offsetList[bucketID])
}
}
for i := 0; i < realReadCount; i++ {
response := <-readBlockResponseChan
if response.err != nil {
return nil, err
}
returnValues[response.block] = response.value
}
readBlocksSpan.End()
log.Debug().Msgf("Going to return values %v", returnValues)
Expand Down
18 changes: 12 additions & 6 deletions pkg/oramnode/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ type mockStorageHandler struct {
levelCount int
maxAccessCount int
latestReadBlock int
writeBucketFunc func(bucketID int, storageID int, ReadBucketBlocks map[string]string, shardNodeBlocks map[string]string, isAtomic bool) (writtenBlocks map[string]string, err error)
writeBucketFunc func(bucketID int, storageID int, ReadBucketBlocks map[string]string, shardNodeBlocks map[string]string) (writtenBlocks map[string]string, err error)
readBlockAccessedOffsets []int
}

Expand All @@ -112,7 +112,7 @@ func newMockStorageHandler(levelCount int, maxAccessCount int) *mockStorageHandl
levelCount: levelCount,
maxAccessCount: maxAccessCount,
latestReadBlock: 0,
writeBucketFunc: func(_ int, _ int, ReadBucketBlocks map[string]string, shardNodeBlocks map[string]string, _ bool) (writtenBlocks map[string]string, err error) {
writeBucketFunc: func(_ int, _ int, ReadBucketBlocks map[string]string, shardNodeBlocks map[string]string) (writtenBlocks map[string]string, err error) {
writtenBlocks = make(map[string]string)
for block, value := range ReadBucketBlocks {
writtenBlocks[block] = value
Expand All @@ -125,7 +125,7 @@ func newMockStorageHandler(levelCount int, maxAccessCount int) *mockStorageHandl
}
}

func (m *mockStorageHandler) withCustomWriteFunc(customeWriteFunc func(bucketID int, storageID int, ReadBucketBlocks map[string]string, shardNodeBlocks map[string]string, isAtomic bool) (writtenBlocks map[string]string, err error)) *mockStorageHandler {
func (m *mockStorageHandler) withCustomWriteFunc(customeWriteFunc func(bucketID int, storageID int, ReadBucketBlocks map[string]string, shardNodeBlocks map[string]string) (writtenBlocks map[string]string, err error)) *mockStorageHandler {
m.writeBucketFunc = customeWriteFunc
return m
}
Expand All @@ -134,6 +134,12 @@ func (m *mockStorageHandler) GetMaxAccessCount() int {
return m.maxAccessCount
}

// LockStorage is a no-op; it exists so the mock provides the LockStorage
// method of the storage interface.
func (m *mockStorageHandler) LockStorage(storageID int) {}

// UnlockStorage is a no-op; it exists so the mock provides the UnlockStorage
// method of the storage interface.
func (m *mockStorageHandler) UnlockStorage(storageID int) {}

// GetRandomPathAndStorageID ignores its context and always reports path 0 on
// storage 0, giving tests a fixed result.
func (m *mockStorageHandler) GetRandomPathAndStorageID(context.Context) (path int, storageID int) {
	path, storageID = 0, 0
	return path, storageID
}
Expand All @@ -155,8 +161,8 @@ func (m *mockStorageHandler) ReadBucket(bucketID int, storageID int) (blocks map
return blocks, nil
}

func (m *mockStorageHandler) WriteBucket(bucketID int, storageID int, readBucketBlocks map[string]string, shardNodeBlocks map[string]string, isAtomic bool) (writtenBlocks map[string]string, err error) {
return m.writeBucketFunc(bucketID, storageID, readBucketBlocks, shardNodeBlocks, isAtomic)
func (m *mockStorageHandler) WriteBucket(bucketID int, storageID int, readBucketBlocks map[string]string, shardNodeBlocks map[string]string) (writtenBlocks map[string]string, err error) {
return m.writeBucketFunc(bucketID, storageID, readBucketBlocks, shardNodeBlocks)
}

func (m *mockStorageHandler) ReadBlock(bucketID int, storageID int, offset int) (value string, err error) {
Expand Down Expand Up @@ -259,7 +265,7 @@ func TestWriteBackBlocksToAllBucketsPushesReceivedBlocksToTree(t *testing.T) {

func TestWriteBackBlocksToAllBucketsReturnsFalseForNotPushedReceivedBlocks(t *testing.T) {
o := startLeaderRaftNodeServer(t).withMockStorageHandler(newMockStorageHandler(3, 4).withCustomWriteFunc(
func(_ int, _ int, _ map[string]string, shardNodeBlocks map[string]string, _ bool) (writtenBlocks map[string]string, err error) {
func(_ int, _ int, _ map[string]string, shardNodeBlocks map[string]string) (writtenBlocks map[string]string, err error) {
writtenBlocks = make(map[string]string)
for block, val := range shardNodeBlocks {
if block != "a" {
Expand Down
21 changes: 20 additions & 1 deletion pkg/storage/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"math/rand"
"strconv"
"strings"
"sync"

"github.com/dsg-uwaterloo/oblishard/pkg/config"
"github.com/redis/go-redis/v9"
Expand All @@ -24,6 +25,7 @@ type StorageHandler struct {
S int
shift int
storages map[int]*redis.Client // map of storage id to redis client
storageMus map[int]*sync.Mutex // map of storage id to mutex
key []byte
}

Expand All @@ -33,12 +35,17 @@ func NewStorageHandler(treeHeight int, Z int, S int, shift int, redisEndpoints [
for _, endpoint := range redisEndpoints {
storages[endpoint.ID] = getClient(endpoint.IP, endpoint.Port)
}
storageMus := make(map[int]*sync.Mutex)
for storageID := range storages {
storageMus[storageID] = &sync.Mutex{}
}
s := &StorageHandler{
treeHeight: treeHeight,
Z: Z,
S: S,
shift: shift,
storages: storages,
storageMus: storageMus,
key: []byte("passphrasewhichneedstobe32bytes!"),
}
return s
Expand All @@ -48,6 +55,18 @@ func (s *StorageHandler) GetMaxAccessCount() int {
return s.S
}

// LockStorage blocks until the mutex registered for storageID is acquired.
// Every call must be paired with a later UnlockStorage for the same ID.
//
// It panics with a descriptive message if no mutex is registered for
// storageID, instead of failing with a bare nil-pointer dereference.
func (s *StorageHandler) LockStorage(storageID int) {
	mu := s.storageMus[storageID]
	if mu == nil {
		panic("storage: no mutex registered for storage " + strconv.Itoa(storageID))
	}
	log.Debug().Msgf("Acquiring lock for storage %d", storageID)
	mu.Lock()
	log.Debug().Msgf("Acquired lock for storage %d", storageID)
}

// UnlockStorage releases the mutex registered for storageID. It must only be
// called after a matching LockStorage for the same ID.
//
// It panics with a descriptive message if no mutex is registered for
// storageID, instead of failing with a bare nil-pointer dereference.
func (s *StorageHandler) UnlockStorage(storageID int) {
	mu := s.storageMus[storageID]
	if mu == nil {
		panic("storage: no mutex registered for storage " + strconv.Itoa(storageID))
	}
	log.Debug().Msgf("Releasing lock for storage %d", storageID)
	mu.Unlock()
	log.Debug().Msgf("Released lock for storage %d", storageID)
}

func (s *StorageHandler) InitDatabase() error {
log.Debug().Msgf("Initializing the redis database")
for _, client := range s.storages {
Expand Down Expand Up @@ -141,7 +160,7 @@ func (s *StorageHandler) ReadBucket(bucketID int, storageID int) (blocks map[str
// WriteBucket writes readBucketBlocks and shardNodeBlocks to the storage shard.
// It prioritizes readBucketBlocks over shardNodeBlocks.
// It returns the blocks that were written into the storage shard in the writtenBlocks variable.
func (s *StorageHandler) WriteBucket(bucketID int, storageID int, readBucketBlocks map[string]string, shardNodeBlocks map[string]string, isAtomic bool) (writtenBlocks map[string]string, err error) {
func (s *StorageHandler) WriteBucket(bucketID int, storageID int, readBucketBlocks map[string]string, shardNodeBlocks map[string]string) (writtenBlocks map[string]string, err error) {
log.Debug().Msgf("Writing bucket %d to storage %d", bucketID, storageID)
// TODO: It should make the counter zero
values := make([]string, s.Z+s.S)
Expand Down
4 changes: 2 additions & 2 deletions pkg/storage/storage_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func TestGetBlockOffset(t *testing.T) {

s := NewStorageHandler(3, 1, 9, 1, []config.RedisEndpoint{{ID: 0, IP: "localhost", Port: 6379}})
s.InitDatabase()
s.WriteBucket(1, 0, map[string]string{"user1": "value1"}, map[string]string{}, true)
s.WriteBucket(1, 0, map[string]string{"user1": "value1"}, map[string]string{})

offset, isReal, blockFound, err := s.GetBlockOffset(bucketId, storageId, []string{"user8", "user10", expectedFound})
if err != nil {
Expand Down Expand Up @@ -85,7 +85,7 @@ func TestWriteBucketBlock(t *testing.T) {
s := NewStorageHandler(3, 1, 9, 1, []config.RedisEndpoint{{ID: 0, IP: "localhost", Port: 6379}})
s.InitDatabase()
expectedWrittenBlocks := map[string]string{"user1": "value1"}
writtenBlocks, _ := s.WriteBucket(bucketId, storageId, map[string]string{"user1": "value1"}, map[string]string{"user10": "value10"}, true)
writtenBlocks, _ := s.WriteBucket(bucketId, storageId, map[string]string{"user1": "value1"}, map[string]string{"user10": "value10"})
for block := range writtenBlocks {
if _, exist := expectedWrittenBlocks[block]; !exist {
t.Errorf("%s was written", block)
Expand Down

0 comments on commit 937b8a2

Please sign in to comment.