Skip to content

Commit

Permalink
Add fallback_application_name to db-uri
Browse files Browse the repository at this point in the history
  • Loading branch information
laurenceisla committed Jun 14, 2023
1 parent 11a9849 commit b62fc36
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 1 deletion.
20 changes: 19 additions & 1 deletion src/PostgREST/Config.hs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ import Data.List (lookup)
import Data.List.NonEmpty (fromList, toList)
import Data.Maybe (fromJust)
import Data.Scientific (floatingOrInteger)
import Network.URI (escapeURIString, isUnescapedInURI, parseURI, uriQuery)
import Numeric (readOct, showOct)
import System.Environment (getEnvironment)
import System.Posix.Types (FileMode)
Expand All @@ -60,6 +61,7 @@ import PostgREST.Config.Proxy (Proxy (..),
import PostgREST.MediaType (MediaType (..), toMime)
import PostgREST.SchemaCache.Identifiers (QualifiedIdentifier, dumpQi,
toQi)
import PostgREST.Version (prettyVersion)

import Protolude hiding (Proxy, toList)

Expand Down Expand Up @@ -219,7 +221,7 @@ readAppConfig dbSettings optPath prevDbUri roleSettings roleIsolationLvl = do
decodeLoadFiles :: AppConfig -> IO AppConfig
decodeLoadFiles parsedConfig =
decodeJWKS <$>
(decodeSecret =<< readSecretFile =<< readDbUriFile prevDbUri parsedConfig)
(decodeSecret =<< readSecretFile =<< addPgrstVerToDbUri =<< readDbUriFile prevDbUri parsedConfig)

parser :: Maybe FilePath -> Environment -> [(Text, Text)] -> RoleSettings -> RoleIsolationLvl -> C.Parser C.Config AppConfig
parser optPath env dbSettings roleSettings roleIsolationLvl =
Expand Down Expand Up @@ -460,3 +462,19 @@ type Environment = M.Map [Char] Text
readPGRSTEnvironment :: IO Environment
readPGRSTEnvironment =
M.map T.pack . M.fromList . filter (isPrefixOf "PGRST_" . fst) <$> getEnvironment

-- | Allows querying the PostgREST version in SQL by adding `fallback_application_name` to the connection string
addPgrstVerToDbUri :: AppConfig -> IO AppConfig
addPgrstVerToDbUri conf = pure $ conf { configDbUri = dbUriWithFallAppName }
where
dbUriWithFallAppName = dbUri <>
case uriQuery <$> parseURI (toS dbUri) of
Nothing -> " " <> keyValStr
Just "" -> "?" <> uriStr
Just "?" -> uriStr
_ -> "&" <> uriStr
dbUri = configDbUri conf
uriStr = toS $ escapeURIString isUnescapedInURI $ toS $ pKeyWord <> pgrstVer
keyValStr = pKeyWord <> "'" <> pgrstVer <> "'"
pKeyWord = "fallback_application_name="
pgrstVer = "PostgREST " <> T.decodeUtf8 prettyVersion
9 changes: 9 additions & 0 deletions test/io/fixtures.sql
Original file line number Diff line number Diff line change
Expand Up @@ -158,3 +158,12 @@ create or replace function migrate_function() returns void as $_$
$$ language sql;
notify pgrst, 'reload schema';
$_$ language sql security definer;

create or replace function get_pgrst_version() returns text
language sql
as $$
select application_name
from pg_stat_activity
where application_name ilike 'postgrest%'
limit 1;
$$
28 changes: 28 additions & 0 deletions test/io/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -998,3 +998,31 @@ def test_openapi_in_big_schema(defaultenv):
with run(env=env) as postgrest:
response = postgrest.session.get("/")
assert response.status_code == 200

def test_get_pgrst_version_with_uri_connection_string(dburi, defaultenv):
"The fallback_application_name should be added to the db-uri if it has a URI format"
defaultenv_without_libpq = {
key: value
for key, value in defaultenv.items()
if key not in ["PGDATABASE", "PGHOST", "PGUSER"]
}
env = {**defaultenv_without_libpq, "PGRST_DB_URI": dburi.decode()}

with run(env=env) as postgrest:
response = postgrest.session.post("/rpc/get_pgrst_version")
assert response.text.startswith('"PostgREST')


def test_get_pgrst_version_with_keyval_connection_string(dburi, defaultenv):
"The fallback_application_name should be added to the db-uri if it has a keyword/value format"
uri = f'dbname={defaultenv["PGDATABASE"]} host={defaultenv["PGHOST"]} user={defaultenv["PGUSER"]}'
defaultenv_without_libpq = {
key: value
for key, value in defaultenv.items()
if key not in ["PGDATABASE", "PGHOST", "PGUSER"]
}
env = {**defaultenv_without_libpq, "PGRST_DB_URI": uri}

with run(env=env) as postgrest:
response = postgrest.session.post("/rpc/get_pgrst_version")
assert response.text.startswith('"PostgREST')

0 comments on commit b62fc36

Please sign in to comment.