Skip to content

Commit

Permalink
Merge pull request #1211 from input-output-hk/abailly-iohk/thread-saf…
Browse files Browse the repository at this point in the history
…e-reliability-persistence

Make persistence incremental thread-safe for Network Reliability layer
  • Loading branch information
v0d1ch committed Dec 16, 2023
2 parents 5d1593a + 1b07b81 commit 124156f
Show file tree
Hide file tree
Showing 7 changed files with 122 additions and 65 deletions.
28 changes: 17 additions & 11 deletions hydra-node/json-schemas/logs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -483,7 +483,7 @@ definitions:
- missing
- acknowledged
- localCounter
- partyIndex
- theirIndex
properties:
tag:
type: string
Expand All @@ -500,7 +500,7 @@ definitions:
type: array
items:
type: number
partyIndex:
theirIndex:
type: number
- title: BroadcastCounter
description: >-
Expand All @@ -509,13 +509,13 @@ definitions:
additionalProperties: false
required:
- tag
- partyIndex
- ourIndex
- localCounter
properties:
tag:
type: string
enum: ["BroadcastCounter"]
partyIndex:
ourIndex:
type: number
localCounter:
type: array
Expand All @@ -528,13 +528,13 @@ definitions:
additionalProperties: false
required:
- tag
- partyIndex
- ourIndex
- localCounter
properties:
tag:
type: string
enum: ["BroadcastPing"]
partyIndex:
ourIndex:
type: number
localCounter:
type: array
Expand All @@ -549,7 +549,8 @@ definitions:
- tag
- acknowledged
- localCounter
- partyIndex
- theirIndex
- ourIndex
properties:
tag:
type: string
Expand All @@ -562,7 +563,9 @@ definitions:
type: array
items:
type: number
partyIndex:
theirIndex:
type: number
ourIndex:
type: number
- title: ClearedMessageQueue
description: >-
Expand Down Expand Up @@ -590,7 +593,8 @@ definitions:
- tag
- acknowledged
- localCounter
- partyIndex
- theirIndex
- ourIndex
properties:
tag:
type: string
Expand All @@ -603,7 +607,9 @@ definitions:
type: array
items:
type: number
partyIndex:
theirIndex:
type: number
ourIndex:
type: number
- title: ReliabilityFailedToFindMsg
description: >-
Expand Down Expand Up @@ -2063,7 +2069,7 @@ definitions:
enum: ["WaitOnContestationDeadline"]
- title: WaitOnTxs
description: >-
Some transactions from a proposed snapshot have not been seen yet
Some transactions from a proposed snapshot have not been seen yet.
type: object
additionalProperties: false
required:
Expand Down
81 changes: 44 additions & 37 deletions hydra-node/src/Hydra/Network/Reliability.hs
Original file line number Diff line number Diff line change
Expand Up @@ -86,13 +86,16 @@ import Cardano.Binary (serialize')
import Cardano.Crypto.Util (SignableRepresentation (getSignableRepresentation))
import Control.Concurrent.Class.MonadSTM (
MonadSTM (readTQueue, writeTQueue),
modifyTVar',
newTQueueIO,
newTVarIO,
readTVarIO,
writeTVar,
)
import Control.Tracer (Tracer)
import Data.IntMap qualified as IMap
import Data.Sequence.Strict ((|>))
import Data.Sequence.Strict qualified as Seq
import Data.Vector (
Vector,
elemIndex,
Expand All @@ -116,13 +119,13 @@ data ReliableMsg msg = ReliableMsg
-- ^ Vector of highest known counter for each known party. Serves as announcement of
-- which messages the sender of `ReliableMsg` has seen. The individual counters have
-- nothing to do with the `payload` also included in this message.
, message :: msg
, payload :: msg
}
deriving stock (Eq, Show, Generic)
deriving anyclass (ToJSON, FromJSON)

instance ToCBOR msg => ToCBOR (ReliableMsg msg) where
toCBOR ReliableMsg{knownMessageIds, message} = toCBOR knownMessageIds <> toCBOR message
toCBOR ReliableMsg{knownMessageIds, payload} = toCBOR knownMessageIds <> toCBOR payload

-- | Decodes a 'ReliableMsg' by reading two CBOR values in sequence and
-- applying the constructor to them in field order. This mirrors the 'ToCBOR'
-- instance, which writes @knownMessageIds@ first, followed by the carried
-- message.
instance FromCBOR msg => FromCBOR (ReliableMsg msg) where
fromCBOR = ReliableMsg <$> fromCBOR <*> fromCBOR
Expand All @@ -135,11 +138,11 @@ instance ToCBOR msg => SignableRepresentation (ReliableMsg msg) where
-- __NOTE__: Log items are documented in a YAML schema file which is not
-- currently public, but should be.
data ReliabilityLog
= Resending {missing :: Vector Int, acknowledged :: Vector Int, localCounter :: Vector Int, partyIndex :: Int}
| BroadcastCounter {partyIndex :: Int, localCounter :: Vector Int}
| BroadcastPing {partyIndex :: Int, localCounter :: Vector Int}
| Received {acknowledged :: Vector Int, localCounter :: Vector Int, partyIndex :: Int}
| Ignored {acknowledged :: Vector Int, localCounter :: Vector Int, partyIndex :: Int}
= Resending {missing :: Vector Int, acknowledged :: Vector Int, localCounter :: Vector Int, theirIndex :: Int}
| BroadcastCounter {ourIndex :: Int, localCounter :: Vector Int}
| BroadcastPing {ourIndex :: Int, localCounter :: Vector Int}
| Received {acknowledged :: Vector Int, localCounter :: Vector Int, theirIndex :: Int, ourIndex :: Int}
| Ignored {acknowledged :: Vector Int, localCounter :: Vector Int, theirIndex :: Int, ourIndex :: Int}
| ReliabilityFailedToFindMsg
{ missingMsgIndex :: Int
, sentMessagesLength :: Int
Expand Down Expand Up @@ -224,56 +227,60 @@ withReliability ::
NetworkComponent m (Authenticated (Heartbeat msg)) (Heartbeat msg) a
withReliability tracer MessagePersistence{saveAcks, loadAcks, appendMessage, loadMessages} me otherParties withRawNetwork callback action = do
acksCache <- loadAcks >>= newTVarIO
sentMessages <- loadMessages >>= newTVarIO . Seq.fromList
resendQ <- newTQueueIO
let ourIndex = fromMaybe (error "This cannot happen because we constructed the list with our party inside.") (findPartyIndex me)
let resend = writeTQueue resendQ
withRawNetwork (reliableCallback acksCache resend ourIndex) $ \network@Network{broadcast} -> do
withRawNetwork (reliableCallback acksCache sentMessages resend ourIndex) $ \network@Network{broadcast} -> do
withAsync (forever $ atomically (readTQueue resendQ) >>= broadcast) $ \_ ->
reliableBroadcast ourIndex acksCache network
reliableBroadcast sentMessages ourIndex acksCache network
where
allParties = fromList $ sort $ me : otherParties
reliableBroadcast ourIndex acksCache Network{broadcast} =
reliableBroadcast sentMessages ourIndex acksCache Network{broadcast} =
action $
Network
{ broadcast = \msg ->
case msg of
Data{} -> do
newAckCounter <- incrementAckCounter
localCounter <- atomically $ cacheMessage msg >> incrementAckCounter
saveAcks localCounter
appendMessage msg
saveAcks newAckCounter
traceWith tracer (BroadcastCounter ourIndex newAckCounter)
broadcast $ ReliableMsg newAckCounter msg
traceWith tracer BroadcastCounter{ourIndex, localCounter}
broadcast $ ReliableMsg localCounter msg
Ping{} -> do
acks <- readTVarIO acksCache
saveAcks acks
traceWith tracer (BroadcastPing ourIndex acks)
broadcast $ ReliableMsg acks msg
localCounter <- readTVarIO acksCache
saveAcks localCounter
traceWith tracer BroadcastPing{ourIndex, localCounter}
broadcast $ ReliableMsg localCounter msg
}
where
incrementAckCounter = atomically $ do
incrementAckCounter = do
acks <- readTVar acksCache
let newAcks = constructAcks acks ourIndex
writeTVar acksCache newAcks
pure newAcks

reliableCallback acksCache resend ourIndex (Authenticated (ReliableMsg acks msg) party) = do
if length acks /= length allParties
cacheMessage msg =
modifyTVar' sentMessages (|> msg)

reliableCallback acksCache sentMessages resend ourIndex (Authenticated (ReliableMsg acknowledged payload) party) = do
if length acknowledged /= length allParties
then
traceWith
tracer
ReceivedMalformedAcks
{ fromParty = party
, partyAcks = acks
, partyAcks = acknowledged
, numberOfParties = length allParties
}
else do
eShouldCallbackWithKnownAcks <- atomically $ runMaybeT $ do
loadedAcks <- lift $ readTVar acksCache
partyIndex <- hoistMaybe $ findPartyIndex party
messageAckForParty <- hoistMaybe (acks !? partyIndex)
messageAckForParty <- hoistMaybe (acknowledged !? partyIndex)
knownAckForParty <- hoistMaybe $ loadedAcks !? partyIndex
if
| isPing msg ->
| isPing payload ->
-- we do not update indices on Pings but we do propagate them
return (True, partyIndex, loadedAcks)
| messageAckForParty == knownAckForParty + 1 -> do
Expand All @@ -286,33 +293,33 @@ withReliability tracer MessagePersistence{saveAcks, loadAcks, appendMessage, loa
return (False, partyIndex, loadedAcks)

case eShouldCallbackWithKnownAcks of
Just (shouldCallback, partyIndex, knownAcks) -> do
Just (shouldCallback, theirIndex, localCounter) -> do
if shouldCallback
then do
callback (Authenticated msg party)
traceWith tracer (Received acks knownAcks partyIndex)
else traceWith tracer (Ignored acks knownAcks partyIndex)
callback Authenticated{payload, party}
traceWith tracer Received{acknowledged, localCounter, theirIndex, ourIndex}
else traceWith tracer Ignored{acknowledged, localCounter, theirIndex, ourIndex}

when (isPing msg) $
resendMessagesIfLagging resend partyIndex knownAcks acks ourIndex
when (isPing payload) $
resendMessagesIfLagging sentMessages resend theirIndex localCounter acknowledged ourIndex
Nothing -> pure ()

constructAcks acks wantedIndex =
zipWith (\ack i -> if i == wantedIndex then ack + 1 else ack) acks partyIndexes

partyIndexes = generate (length allParties) id

resendMessagesIfLagging resend partyIndex knownAcks messageAcks myIndex = do
let mmessageAckForUs = messageAcks !? myIndex
resendMessagesIfLagging sentMessages resend theirIndex knownAcks acknowledged myIndex = do
let mmessageAckForUs = acknowledged !? myIndex
let mknownAckForUs = knownAcks !? myIndex
case (mmessageAckForUs, mknownAckForUs) of
(Just messageAckForUs, Just knownAckForUs) ->
-- We resend messages if our peer notified us that it's lagging behind our
-- latest message sent
when (messageAckForUs < knownAckForUs) $ do
let missing = fromList [messageAckForUs + 1 .. knownAckForUs]
storedMessages <- loadMessages
let messages = IMap.fromList (zip [1 ..] storedMessages)
storedMessages <- readTVarIO sentMessages
let messages = IMap.fromList (zip [1 ..] $ toList storedMessages)
forM_ missing $ \idx -> do
case messages IMap.!? idx of
Nothing ->
Expand All @@ -324,9 +331,9 @@ withReliability tracer MessagePersistence{saveAcks, loadAcks, appendMessage, loa
, messageAckForUs = messageAckForUs
}
Just missingMsg -> do
let newAcks' = zipWith (\ack i -> if i == myIndex then idx else ack) knownAcks partyIndexes
traceWith tracer (Resending missing messageAcks newAcks' partyIndex)
atomically $ resend $ ReliableMsg newAcks' missingMsg
let localCounter = zipWith (\ack i -> if i == myIndex then idx else ack) knownAcks partyIndexes
traceWith tracer Resending{missing, acknowledged, localCounter, theirIndex}
atomically $ resend $ ReliableMsg localCounter missingMsg
_ -> pure ()

-- Find the index of a party in the list of all parties.
Expand Down
2 changes: 1 addition & 1 deletion hydra-node/src/Hydra/Node/Network.hs
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ withNetwork tracer connectionMessages configuration callback action = do
-- * Some state already exists and is loaded,
-- * The number of parties is not the same as the number of acknowledgments saved.
configureMessagePersistence ::
(MonadIO m, MonadThrow m, FromJSON msg, ToJSON msg) =>
(MonadIO m, MonadThrow m, FromJSON msg, ToJSON msg, MonadSTM m, MonadThread m, MonadThrow (STM m)) =>
Tracer m (HydraNodeLog tx) ->
FilePath ->
Int ->
Expand Down
24 changes: 21 additions & 3 deletions hydra-node/src/Hydra/Persistence.hs
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,18 @@ module Hydra.Persistence where

import Hydra.Prelude

import Control.Concurrent.Class.MonadSTM (newTVarIO, throwSTM, writeTVar)
import Control.Monad.Class.MonadFork (myThreadId)
import Data.Aeson qualified as Aeson
import Data.ByteString qualified as BS
import Data.ByteString.Char8 qualified as C8
import System.Directory (createDirectoryIfMissing, doesFileExist)
import System.FilePath (takeDirectory)
import UnliftIO.IO.File (withBinaryFile, writeBinaryFileDurableAtomic)

newtype PersistenceException
data PersistenceException
= PersistenceException String
| IncorrectAccessException String
deriving stock (Eq, Show)

-- | Makes 'PersistenceException' usable with the standard throw/catch
-- machinery. No methods are overridden, so the class defaults apply.
instance Exception PersistenceException
Expand Down Expand Up @@ -53,18 +56,33 @@ data PersistenceIncremental a m = PersistenceIncremental
}

-- | Initialize persistence handle for given type 'a' at given file path.
--
-- This instance of `PersistenceIncremental` is "thread-safe" in the sense that
-- it prevents loading from a different thread once one starts `append`ing
-- through the handle. If another thread attempts to `loadAll` after this point,
-- an `IncorrectAccessException` will be raised.
createPersistenceIncremental ::
(MonadIO m, MonadThrow m) =>
forall a m.
(MonadIO m, MonadThrow m, MonadSTM m, MonadThread m, MonadThrow (STM m)) =>
FilePath ->
m (PersistenceIncremental a m)
createPersistenceIncremental fp = do
liftIO . createDirectoryIfMissing True $ takeDirectory fp
authorizedThread <- newTVarIO Nothing
pure $
PersistenceIncremental
{ append = \a -> do
tid <- myThreadId
atomically $ writeTVar authorizedThread $ Just tid
let bytes = toStrict $ Aeson.encode a <> "\n"
liftIO $ withBinaryFile fp AppendMode (`BS.hPut` bytes)
, loadAll =
, loadAll = do
tid <- myThreadId
atomically $ do
authTid <- readTVar authorizedThread
when (isJust authTid && authTid /= Just tid) $
throwSTM (IncorrectAccessException $ "Trying to load persisted data in " <> fp <> " from different thread")

liftIO (doesFileExist fp) >>= \case
False -> pure []
True -> do
Expand Down
4 changes: 2 additions & 2 deletions hydra-node/test/Hydra/API/ServerSpec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -365,8 +365,8 @@ spec = describe "ServerSpec" $
guard $ v ^? key "headStatus" == Just (Aeson.String "Initializing")
guard $ v ^? key "snapshotUtxo" == Just expectedUtxos

-- expect the api server to load events from apiPersistence and project headStatus correctly
withTestAPIServer port alice apiPersistence tracer $ \_ -> do
newApiPersistence <- createPersistenceIncremental $ persistenceDir <> "/server-output"
withTestAPIServer port alice newApiPersistence tracer $ \_ -> do
waitForValue port $ \v -> do
guard $ v ^? key "headStatus" == Just (Aeson.String "Initializing")
guard $ v ^? key "snapshotUtxo" == Just expectedUtxos
Expand Down
Loading

0 comments on commit 124156f

Please sign in to comment.