Skip to content

Commit

Permalink
Add app name only on connection and listener and add doctests
Browse files Browse the repository at this point in the history
  • Loading branch information
laurenceisla committed Jun 22, 2023
1 parent 4db8c39 commit bb0c835
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 16 deletions.
7 changes: 4 additions & 3 deletions src/PostgREST/AppState.hs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ import Data.Time (ZonedTime, defaultTimeLocale, formatTime,
import Data.Time.Clock (UTCTime, getCurrentTime)

import PostgREST.Config (AppConfig (..),
readAppConfig)
readAppConfig,
addPgrstVerToDbUri)
import PostgREST.Config.Database (queryDbSettings,
queryPgVersion,
queryRoleSettings)
Expand Down Expand Up @@ -136,7 +137,7 @@ initPool AppConfig{..} =
(fromIntegral configDbPoolAcquisitionTimeout)
(fromIntegral configDbPoolMaxLifetime)
(fromIntegral configDbPoolMaxIdletime)
(toUtf8 configDbUri)
(toUtf8 $ addPgrstVerToDbUri configDbUri)

-- | Run an action with a database connection.
usePool :: AppState -> SQL.Session a -> IO (Either SQL.UsageError a)
Expand Down Expand Up @@ -418,7 +419,7 @@ listener appState = do

-- forkFinally allows to detect if the thread dies
void . flip forkFinally (handleFinally dbChannel) $ do
dbOrError <- acquire $ toUtf8 configDbUri
dbOrError <- acquire $ toUtf8 (addPgrstVerToDbUri configDbUri)
case dbOrError of
Right db -> do
logWithZTime appState $ "Listening for notifications on the " <> dbChannel <> " channel"
Expand Down
34 changes: 23 additions & 11 deletions src/PostgREST/Config.hs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ module PostgREST.Config
, readPGRSTEnvironment
, toURI
, parseSecret
, addPgrstVerToDbUri
) where

import qualified Crypto.JOSE.Types as JOSE
Expand All @@ -47,7 +48,7 @@ import Data.List (lookup)
import Data.List.NonEmpty (fromList, toList)
import Data.Maybe (fromJust)
import Data.Scientific (floatingOrInteger)
import Network.URI (escapeURIString, isUnescapedInURI, parseURI, uriQuery)
import Network.URI (parseURI, uriQuery)
import Numeric (readOct, showOct)
import System.Environment (getEnvironment)
import System.Posix.Types (FileMode)
Expand Down Expand Up @@ -221,7 +222,7 @@ readAppConfig dbSettings optPath prevDbUri roleSettings roleIsolationLvl = do
decodeLoadFiles :: AppConfig -> IO AppConfig
decodeLoadFiles parsedConfig =
decodeJWKS <$>
(decodeSecret =<< readSecretFile =<< addPgrstVerToDbUri =<< readDbUriFile prevDbUri parsedConfig)
(decodeSecret =<< readSecretFile =<< readDbUriFile prevDbUri parsedConfig)

parser :: Maybe FilePath -> Environment -> [(Text, Text)] -> RoleSettings -> RoleIsolationLvl -> C.Parser C.Config AppConfig
parser optPath env dbSettings roleSettings roleIsolationLvl =
Expand Down Expand Up @@ -464,17 +465,28 @@ readPGRSTEnvironment =
M.map T.pack . M.fromList . filter (isPrefixOf "PGRST_" . fst) <$> getEnvironment

-- | Allows querying the PostgREST version in SQL by adding `fallback_application_name` to the connection string
addPgrstVerToDbUri :: AppConfig -> IO AppConfig
addPgrstVerToDbUri conf = pure $ conf { configDbUri = dbUriWithFallAppName }
--
-- >>> addPgrstVerToDbUri "postgres://user:pass@host:5432/postgres"
-- postgres://user:pass@host:5432/postgres?fallback_application_name=PostgREST%20...
--
-- >>> addPgrstVerToDbUri "postgres://user:pass@host:5432/postgres?"
-- postgres://user:pass@host:5432/postgres?fallback_application_name=PostgREST%20...
--
-- >>> addPgrstVerToDbUri "postgres:///postgres?host=host&port=5432"
-- postgres:///postgres?host=host&port=5432&fallback_application_name=PostgREST%20...
--
-- >>> addPgrstVerToDbUri "host=host port=5432 dbname=postgres"
-- host=host port=5432 dbname=postgres fallback_application_name='PostgREST ...'
addPgrstVerToDbUri :: Text -> Text
addPgrstVerToDbUri dbUri = dbUriWithFallAppName
where
dbUriWithFallAppName = dbUri <>
case uriQuery <$> parseURI (toS dbUri) of
Nothing -> " " <> keyValStr
Just "" -> "?" <> uriStr
Just "?" -> uriStr
_ -> "&" <> uriStr
dbUri = configDbUri conf
uriStr = toS $ escapeURIString isUnescapedInURI $ toS $ pKeyWord <> pgrstVer
keyValStr = pKeyWord <> "'" <> pgrstVer <> "'"
Nothing -> " " <> keyValFmt -- Assume key/value connection string if the uri is not valid
Just "" -> "?" <> uriFmt
Just "?" -> uriFmt
_ -> "&" <> uriFmt
uriFmt = T.replace " " "%20" $ pKeyWord <> pgrstVer
keyValFmt = pKeyWord <> "'" <> pgrstVer <> "'"
pKeyWord = "fallback_application_name="
pgrstVer = "PostgREST " <> T.decodeUtf8 prettyVersion
1 change: 1 addition & 0 deletions test/doc/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,5 @@ main =
, "src/PostgREST/ApiRequest/QueryParams.hs"
, "src/PostgREST/Error.hs"
, "src/PostgREST/MediaType.hs"
, "src/PostgREST/Config.hs"
]
4 changes: 2 additions & 2 deletions test/io/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -1010,7 +1010,7 @@ def test_get_pgrst_version_with_uri_connection_string(dburi, defaultenv):

with run(env=env) as postgrest:
response = postgrest.session.post("/rpc/get_pgrst_version")
assert response.text.startswith('"PostgREST')
assert response.text.startswith('"PostgREST ')


def test_get_pgrst_version_with_keyval_connection_string(dburi, defaultenv):
Expand All @@ -1025,4 +1025,4 @@ def test_get_pgrst_version_with_keyval_connection_string(dburi, defaultenv):

with run(env=env) as postgrest:
response = postgrest.session.post("/rpc/get_pgrst_version")
assert response.text.startswith('"PostgREST')
assert response.text.startswith('"PostgREST ')

0 comments on commit bb0c835

Please sign in to comment.