diff --git a/chainweb.cabal b/chainweb.cabal index b8c91b26ed..6943052fd7 100644 --- a/chainweb.cabal +++ b/chainweb.cabal @@ -249,6 +249,8 @@ library , Chainweb.Utils.RequestLog , Chainweb.Utils.Rule , Chainweb.Utils.Serialization + , Chainweb.Utils.Throttling + , Chainweb.Utils.TokenLimiting , Chainweb.VerifierPlugin , Chainweb.VerifierPlugin.Allow , Chainweb.VerifierPlugin.Hyperlane.Announcement @@ -361,6 +363,7 @@ library , base64-bytestring-kadena == 0.1 , binary >= 0.8 , bytestring >= 0.10.12 + , cache >= 0.1.1.2 , case-insensitive >= 1.2 , cassava >= 0.5.1 , chainweb-storage >= 0.1 @@ -434,7 +437,7 @@ library , time >= 1.12.2 , tls >=1.9 , tls-session-manager >= 0.0 - , token-bucket >= 0.1 + , token-limiter >= 0.1 , transformers >= 0.5 , trifecta >= 2.1 , unliftio >= 0.2 diff --git a/src/Chainweb/Chainweb.hs b/src/Chainweb/Chainweb.hs index f596ffc8b9..bb78a2acaf 100644 --- a/src/Chainweb/Chainweb.hs +++ b/src/Chainweb/Chainweb.hs @@ -185,6 +185,7 @@ import P2P.Peer import qualified Pact.Types.ChainMeta as P import qualified Pact.Types.Command as P +import qualified Chainweb.Utils.Throttling as Throttling -- -------------------------------------------------------------------------- -- -- Chainweb Resources @@ -718,28 +719,29 @@ runChainweb cw nowServing = do logg Warn $ "OpenAPI spec validation enabled on service API, make sure this is what you want" mkValidationMiddleware else return id - - concurrentlies_ - - -- 1. Start serving Rest API - [ (if tls then serve else servePlain) - $ httpLog - . throttle (_chainwebPutPeerThrottler cw) - . throttle (_chainwebMempoolThrottler cw) - . throttle (_chainwebThrottler cw) - . p2pRequestSizeLimit - . p2pValidationMiddleware - - -- 2. Start Clients (with a delay of 500ms) - , threadDelay 500000 >> clients - - -- 3. Start serving local API - , threadDelay 500000 >> do - serveServiceApi - $ serviceHttpLog - . serviceRequestSizeLimit - . serviceApiValidationMiddleware - ] + Throttling.throttleMiddleware (logFunction $ _chainwebLogger cw) "p2p" p2pThrottleEconomy $ \p2pThrottler -> + Throttling.throttleMiddleware (logFunction $ _chainwebLogger cw) "service" serviceThrottleEconomy $ \serviceThrottler -> + + concurrentlies_ + + -- 1. Start serving Rest API + [ (if tls then serve else servePlain) + $ httpLog + . p2pRequestSizeLimit + . p2pThrottler + . p2pValidationMiddleware + + -- 2. Start Clients (with a delay of 500ms) + , threadDelay 500000 >> clients + + -- 3. Start serving local API + , threadDelay 500000 >> do + serveServiceApi + $ serviceHttpLog + . serviceRequestSizeLimit + . serviceThrottler + . serviceApiValidationMiddleware + ] where @@ -864,6 +866,22 @@ runChainweb cw nowServing = do setMaxLengthForRequest (\_req -> pure $ Just $ 2 * 1024 * 1024) -- 2MB defaultRequestSizeLimitSettings + p2pThrottleEconomy = Throttling.ThrottleEconomy + { Throttling.requestCost = 10 + , Throttling.requestBody100ByteCost = 1 + , Throttling.responseBody100ByteCost = 2 + , Throttling.maxBudget = 35_000 + , Throttling.freeRate = 35_000 + } + + serviceThrottleEconomy = Throttling.ThrottleEconomy + { Throttling.requestCost = 10 + , Throttling.requestBody100ByteCost = 1 + , Throttling.responseBody100ByteCost = 2 + , Throttling.maxBudget = 50_000 + , Throttling.freeRate = 50_000 + } + -- Request size limit for the P2P API -- -- NOTE: this may need to have to be adjusted if the p2p limits for batch diff --git a/src/Chainweb/Utils.hs b/src/Chainweb/Utils.hs index f5f1e4a3a9..1913fbde26 100644 --- a/src/Chainweb/Utils.hs +++ b/src/Chainweb/Utils.hs @@ -233,7 +233,7 @@ import Configuration.Utils hiding (Error, Lens) import Control.Concurrent (threadDelay) import Control.Concurrent.Async import Control.Concurrent.MVar -import Control.Concurrent.TokenBucket +import Control.Concurrent.TokenLimiter import Control.DeepSeq import Control.Exception (SomeAsyncException(..), evaluate) import Control.Lens hiding ((.=)) @@ -970,9 +970,13 @@ runForeverThrottled -> IO () -> IO () runForeverThrottled logfun name burst rate a = mask $ \umask -> do - tokenBucket <- newTokenBucket + let config = defaultLimitConfig + { maxBucketTokens = fromIntegral burst + , bucketRefillTokensPerSecond = fromIntegral rate + } + tokenBucket <- newRateLimiter config logfun Debug $ "start " <> name - let runThrottled = tokenBucketWait tokenBucket burst rate >> a + let runThrottled = waitDebit config tokenBucket 1 >> a go = do forever (umask runThrottled) `catchAllSynchronous` \e -> logfun Error $ name <> " failed: " <> sshow e <> ". Restarting ..." @@ -1494,4 +1498,4 @@ unsafeHead msg = \case unsafeTail :: HasCallStack => String -> [a] -> [a] unsafeTail msg = \case _ : xs -> xs - [] -> error $ "unsafeTail: empty list: " <> msg \ No newline at end of file + [] -> error $ "unsafeTail: empty list: " <> msg diff --git a/src/Chainweb/Utils/Throttling.hs b/src/Chainweb/Utils/Throttling.hs new file mode 100644 index 0000000000..3a19fed541 --- /dev/null +++ b/src/Chainweb/Utils/Throttling.hs @@ -0,0 +1,145 @@ +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE DeriveAnyClass #-} +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} + +module Chainweb.Utils.Throttling + ( ThrottleEconomy(..) + , ThrottledException(..) + , throttleMiddleware + ) where + +import Data.LogMessage +import Data.Text (Text) +import qualified Network.Wai as Wai +import qualified Network.Wai.Internal as Wai.Internal +import Chainweb.Utils.TokenLimiting +import Control.Exception.Safe +import Network.HTTP.Types.Status +import qualified Data.ByteString as BS +import qualified Data.Text as T +import Data.Hashable +import Network.Socket (SockAddr(..)) +import qualified Data.ByteString.Builder as BSB +import System.IO.Unsafe (unsafeInterleaveIO) +import qualified Data.ByteString.Lazy as LBS + +data ThrottleEconomy = ThrottleEconomy + { requestCost :: Int + , requestBody100ByteCost :: Int + , responseBody100ByteCost :: Int + , maxBudget :: Int + , freeRate :: Int + } + +data ThrottledException = ThrottledException Text + deriving (Show, Exception) + +hashWithSalt' :: Hashable a => a -> Int -> Int +hashWithSalt' = flip hashWithSalt + +newtype HashableSockAddr = HashableSockAddr SockAddr + deriving newtype Eq +instance Hashable HashableSockAddr where + hashWithSalt salt (HashableSockAddr sockAddr) = case sockAddr of + SockAddrInet port hostAddr -> + -- constructor tag + hashWithSalt' (1 :: Word) + . hashWithSalt' (fromIntegral port :: Word) + . hashWithSalt' hostAddr + $ salt + SockAddrInet6 port flowInfo hostAddr scopeId -> + hashWithSalt' (2 :: Word) + . hashWithSalt' (fromIntegral port :: Word) + . hashWithSalt' flowInfo + . hashWithSalt' hostAddr + . hashWithSalt' scopeId + $ salt + SockAddrUnix str -> + hashWithSalt' (3 :: Word) + . hashWithSalt' str + $ salt + +debitOrDie :: Hashable k => TokenLimitMap k -> (Text, k) -> Int -> IO () +debitOrDie tokenLimitMap (name, k) cost = do + tryDebit cost k tokenLimitMap >>= \case + True -> return () + False -> throwIO (ThrottledException name) + +throttleMiddleware :: LogFunction -> Text -> ThrottleEconomy -> (Wai.Middleware -> IO r) -> IO r +throttleMiddleware logfun name ThrottleEconomy{..} k = + withTokenLimitMap logfun ("request-throttler-" <> name) limitCachePolicy limitConfig $ \tokenLimitMap -> do + k $ middleware tokenLimitMap + where + middleware tokenLimitMap app request respond = do + debitOrDie' requestCost + meteredRequest <- meterRequest debitOrDie' request + app meteredRequest (meterResponse debitOrDie' respond) + where + host = HashableSockAddr $ Wai.remoteHost request + hostText = T.pack $ show (Wai.remoteHost request) + debitOrDie' = debitOrDie tokenLimitMap (hostText, host) + + limitCachePolicy = TokenLimitCachePolicy 30 + limitConfig = defaultLimitConfig + { maxBucketTokens = maxBudget + , initialBucketTokens = maxBudget + , bucketRefillTokensPerSecond = freeRate + } + + meterRequest debit request + | requestBody100ByteCost == 0 = return request + | otherwise = case Wai.requestBodyLength request of + Wai.KnownLength requestBodyLen -> do + () <- debit $ (requestBody100ByteCost * fromIntegral requestBodyLen) `div` 100 + return request + Wai.ChunkedBody -> + return (Wai.setRequestBodyChunks (getMeteredRequestBodyChunk debit request) request) + + getMeteredRequestBodyChunk debit request = do + chunk <- Wai.getRequestBodyChunk request + -- charge *after* receiving a request body chunk + () <- debit $ (requestBody100ByteCost * BS.length chunk) `div` 100 + return chunk + + -- the only way to match on responses without using internal API is via + -- responseToStream, which converts any response into a streaming response. + -- unfortunately: + -- * all of the responses produced by servant are builder responses, + -- not streaming responses + -- * streaming responses are not supported by http2; we try to use http2 + -- (see https://hackage.haskell.org/package/http2-5.3.5/docs/src/Network.HTTP2.Server.Run.html#runIO) + -- * a streaming response body may be less efficient than a builder + -- response body, in particular because it needs to use a chunked + -- encoding + -- + meterResponse + :: (Int -> IO ()) + -> (Wai.Response -> IO a) -> Wai.Response -> IO a + meterResponse _ respond response + | responseBody100ByteCost == 0 = respond response + meterResponse debit respond (Wai.Internal.ResponseStream status headers responseBody) = do + respond + $ Wai.responseStream status headers + $ meterStreamingResponseBody debit responseBody + meterResponse debit respond (Wai.Internal.ResponseBuilder status headers responseBody) = do + respond + <$> Wai.responseLBS status headers . LBS.fromChunks + =<< meterBuilderResponseBody debit (LBS.toChunks $ BSB.toLazyByteString responseBody) + meterResponse _ _ _ = error "unrecognized response type" + + meterStreamingResponseBody debit responseBody send flush = responseBody + (\chunkBSBuilder -> do + let chunkBS = BS.toStrict (BSB.toLazyByteString chunkBSBuilder) + () <- debit $ (responseBody100ByteCost * BS.length chunkBS) `div` 100 + -- charger *before* sending a response body chunk + send (BSB.byteString chunkBS) + ) + flush + meterBuilderResponseBody debit (chunk:chunks) = unsafeInterleaveIO $ do + () <- debit $ (responseBody100ByteCost * BS.length chunk) `div` 100 + (chunk:) <$> meterBuilderResponseBody debit chunks + meterBuilderResponseBody _ [] = return [] diff --git a/src/Chainweb/Utils/TokenLimiting.hs b/src/Chainweb/Utils/TokenLimiting.hs new file mode 100644 index 0000000000..6cec5ea153 --- /dev/null +++ b/src/Chainweb/Utils/TokenLimiting.hs @@ -0,0 +1,151 @@ +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} + +-- | A concurrent, expiring map from @k@ to RateLimiter. +module Chainweb.Utils.TokenLimiting +( TokenLimitMap +, TokenLimitCachePolicy(..) +, LimitConfig(..) +, withTokenLimitMap +, startTokenLimitMap +, stopTokenLimitMap +, defaultLimitConfig +, makeLimitConfig +, getLimiter +, withLimiter +, getLimitPolicy +, tryDebit +, waitDebit +, penalize +) where + +import Control.Concurrent.Async (Async) +import qualified Control.Concurrent.Async as Async +import Control.Concurrent.STM +import Control.Concurrent.TokenLimiter (LimitConfig(..), RateLimiter) +import qualified Control.Concurrent.TokenLimiter as TL +import Control.Exception +import Control.Monad +import Data.Cache (Cache) +import qualified Data.Cache as Cache +import Data.Hashable (Hashable) +import Data.Text (Text) +import GHC.Generics +import System.Clock (TimeSpec) +import qualified System.Clock as Clock + +import Chainweb.Utils +import Data.LogMessage + +data TokenLimitMap k = TokenLimitMap + { _tlmMap :: !(Cache k RateLimiter) + , _tlmLimitPolicy :: !LimitConfig + -- ^ token bucket rate limiting policy (max num tokens, refill rate, etc) + , _tlmReaper :: !(Async ()) + , _tlmCachePolicy :: !TokenLimitCachePolicy + -- ^ inactivity period before your key's rate limiter is expired from the + -- cache + } deriving (Generic) + +newtype TokenLimitCachePolicy = TokenLimitCachePolicy + { policyExpirationTime :: TimeSpec + } deriving (Generic, Eq, Ord, Num, Show) + +makeLimitConfig :: Int -> Int -> Int -> LimitConfig +makeLimitConfig mx it ref = + defaultLimitConfig + { maxBucketTokens = mx + , initialBucketTokens = it + , bucketRefillTokensPerSecond = ref + } + +withTokenLimitMap + :: (Eq k, Hashable k) + => LogFunctionText + -> Text + -> TokenLimitCachePolicy + -> LimitConfig + -> (TokenLimitMap k -> IO a) + -> IO a +withTokenLimitMap logfun mapName expPolicy@(TokenLimitCachePolicy expTSpec) lcfg act = + mask $ \restore -> do + cache <- restore (Cache.newCache (Just expTSpec)) + Async.withAsync (reap cache) $ \rtid -> do + let m = TokenLimitMap cache lcfg rtid expPolicy + restore $ act m + where + reap = reaper logfun mapName + +startTokenLimitMap + :: (Eq k, Hashable k) + => LogFunctionText + -> Text + -> TokenLimitCachePolicy + -> LimitConfig + -> IO (TokenLimitMap k) +startTokenLimitMap logfun mapName expPolicy@(TokenLimitCachePolicy expTSpec) + lcfg = do + cache <- Cache.newCache (Just expTSpec) + rtid <- Async.async (reap cache) + return $! TokenLimitMap cache lcfg rtid expPolicy + where + reap = reaper logfun mapName + +stopTokenLimitMap :: TokenLimitMap k -> IO () +stopTokenLimitMap tlm = Async.cancel t `finally` + void (Async.waitCatch t) + where + t = _tlmReaper tlm + +reaper + :: (Eq k, Hashable k) + => LogFunctionText + -> Text + -> Cache k v -> IO () +reaper logfun mapName cache = runForever logfun mapName $ do + approximateThreadDelay (2 * 60 * 1000000) -- two minute cycle time + Cache.purgeExpired cache + +getLimiter :: (Eq k, Hashable k) => k -> TokenLimitMap k -> IO RateLimiter +getLimiter k (TokenLimitMap cache limitPolicy _ cachePolicy) = + Cache.lookup' cache k >>= maybe noKey return + where + expTSpec = policyExpirationTime cachePolicy + addExpireTime = (+ expTSpec) + noKey = do + now <- Clock.getTime Clock.Monotonic + rl <- TL.newRateLimiter limitPolicy + atomically $ do + mbV <- Cache.lookupSTM False k cache now + maybe (do Cache.insertSTM k rl cache (Just $! addExpireTime now) + return rl) + (\rl' -> return rl') + mbV + + +withLimiter + :: (Eq k, Hashable k) + => k + -> TokenLimitMap k + -> (RateLimiter -> IO b) + -> IO b +withLimiter k tlm act = getLimiter k tlm >>= act + +penalize :: (Eq k, Hashable k) => Int -> k -> TokenLimitMap k -> IO Int +penalize ndebits k tlm = withLimiter k tlm $ \rl -> + TL.penalize rl ndebits + +tryDebit :: (Eq k, Hashable k) => Int -> k -> TokenLimitMap k -> IO Bool +tryDebit ndebits k tlm = withLimiter k tlm $ \rl -> + TL.tryDebit (_tlmLimitPolicy tlm) rl ndebits + +waitDebit :: (Eq k, Hashable k) => Int -> k -> TokenLimitMap k -> IO () +waitDebit ndebits k tlm = withLimiter k tlm $ \rl -> + TL.waitDebit (_tlmLimitPolicy tlm) rl ndebits + +defaultLimitConfig :: LimitConfig +defaultLimitConfig = TL.defaultLimitConfig + +getLimitPolicy :: TokenLimitMap k -> LimitConfig +getLimitPolicy = _tlmLimitPolicy