From 6733f40544eaef83c68e3b3075947dad7f9e1850 Mon Sep 17 00:00:00 2001 From: Daniel Firth Date: Wed, 3 Apr 2024 10:43:21 +0000 Subject: [PATCH] Swap UnversionedProtocol for protocol versioned with HydraVersionData --- hydra-node/src/Hydra/Network/Ouroboros.hs | 47 ++++++++++++++++++++--- 1 file changed, 42 insertions(+), 5 deletions(-) diff --git a/hydra-node/src/Hydra/Network/Ouroboros.hs b/hydra-node/src/Hydra/Network/Ouroboros.hs index 099de87567b..7fef932eec6 100644 --- a/hydra-node/src/Hydra/Network/Ouroboros.hs +++ b/hydra-node/src/Hydra/Network/Ouroboros.hs @@ -27,6 +27,7 @@ import Data.Aeson (object, withObject, (.:), (.=)) import Data.Aeson qualified as Aeson import Data.Aeson.Types qualified as Aeson import Data.Map.Strict as Map +import Data.Text qualified as T import Hydra.Logging (Tracer, nullTracer) import Hydra.Network ( Host (..), @@ -56,6 +57,7 @@ import Network.TypedProtocol.Codec ( AnyMessageAndAgency (..), ) import Network.TypedProtocol.Pipelined () +import Ouroboros.Network.CodecCBORTerm (CodecCBORTerm (..)) import Ouroboros.Network.Driver.Simple ( TraceSendRecv (..), ) @@ -81,7 +83,7 @@ import Ouroboros.Network.Mux ( RunMiniProtocol (..), mkMiniProtocolCbFromPeer, ) -import Ouroboros.Network.Protocol.Handshake.Codec (noTimeLimitsHandshake) +import Ouroboros.Network.Protocol.Handshake.Codec (cborTermVersionDataCodec, codecHandshake, noTimeLimitsHandshake) import Ouroboros.Network.Protocol.Handshake.Type (Handshake, Message (..), RefuseReason (..)) import Ouroboros.Network.Protocol.Handshake.Unversioned ( UnversionedProtocol, @@ -89,7 +91,7 @@ import Ouroboros.Network.Protocol.Handshake.Unversioned ( unversionedProtocol, unversionedProtocolDataCodec, ) -import Ouroboros.Network.Protocol.Handshake.Version (acceptableVersion, queryVersion) +import Ouroboros.Network.Protocol.Handshake.Version (Accept (..), Acceptable, Queryable, acceptableVersion, queryVersion, simpleSingletonVersions) import Ouroboros.Network.Server.Socket (AcceptedConnectionsLimit (AcceptedConnectionsLimit)) import Ouroboros.Network.Snocket (makeSocketBearer, socketSnocket) import Ouroboros.Network.Socket ( @@ -112,6 +114,41 @@ import Ouroboros.Network.Subscription qualified as Subscription import Ouroboros.Network.Subscription.Ip (SubscriptionParams (..), WithIPList (WithIPList)) import Ouroboros.Network.Subscription.Worker (LocalAddresses (LocalAddresses)) +versionNumberCodec :: CodecCBORTerm (String, Maybe Int) HydraVersionData +versionNumberCodec = CodecCBORTerm{encodeTerm, decodeTerm} + where + encodeTerm x = CBOR.TInt $ hydraVersionNumber x + + decodeTerm (CBOR.TInt x) = Right $ MkHydraVersionData x + decodeTerm _ = Left $ ("unknown tag", Nothing) + +newtype HydraVersionData = MkHydraVersionData {hydraVersionNumber :: Int} + deriving stock (Eq, Show, Generic, Ord) + +instance Acceptable HydraVersionData where + acceptableVersion a b = + if hydraVersionNumber a /= hydraVersionNumber b + then Refuse $ T.pack "Incompatible versions" + else Accept $ MkHydraVersionData (hydraVersionNumber a) + +instance Queryable HydraVersionData where + queryVersion _ = False + +dataCodecCBORTerm :: HydraVersionData -> CodecCBORTerm Text HydraVersionData +dataCodecCBORTerm _ = CodecCBORTerm{encodeTerm, decodeTerm} + where + -- We are using @CBOR.TInt@ instead of @CBOR.TInteger@, since for small + -- integers generated by QuickCheck they will be encoded as @TkInt@ and then + -- are decoded back to @CBOR.TInt@ rather than @COBR.TInteger@. The same for + -- other @CodecCBORTerm@ records. + encodeTerm (MkHydraVersionData x) = + CBOR.TInt x + + decodeTerm (CBOR.TInt x) = + Right $ MkHydraVersionData x + decodeTerm n = + Left $ T.pack $ "decodeTerm VersionData: unrecognised tag: " ++ show n + withOuroborosNetwork :: forall msg. (ToCBOR msg, FromCBOR msg) => @@ -190,12 +227,12 @@ withOuroborosNetwork tracer localHost remoteHosts networkCallback between = do chan <- newBroadcastChannel connectToNodeSocket iomgr - unversionedHandshakeCodec + (codecHandshake versionNumberCodec) noTimeLimitsHandshake - unversionedProtocolDataCodec + (cborTermVersionDataCodec dataCodecCBORTerm) networkConnectTracers (HandshakeCallbacks acceptableVersion queryVersion) - (unversionedProtocol (app chan)) + (simpleSingletonVersions (MkHydraVersionData 0) (MkHydraVersionData 1) (app chan)) sn where networkConnectTracers :: NetworkConnectTracers a v