Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change SET LOCAL gucs to set_config #1600

Merged
merged 2 commits into from
Dec 8, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ This project adheres to [Semantic Versioning](http://semver.org/).
- #504, Add `log-level` config option. The admitted levels are: crit, error, warn and info - @steve-chavez
- #1607, Enable embedding through multiple views recursively - @wolfgangwalther
- #1598, Allow rollback of the transaction with Prefer tx=rollback - @wolfgangwalther
- #1633, Enable prepared statements for filters. When behind a connection pooler, you can disable preparing with `db-prepared-statements=false` - @steve-chavez
- #1633, #1600, Enable prepared statements for filters. When behind a connection pooler, you can disable preparing with `db-prepared-statements=false` - @steve-chavez

### Fixed

Expand Down
39 changes: 23 additions & 16 deletions src/PostgREST/Middleware.hs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ Description : Sets CORS policy. Also the PostgreSQL GUCs, role, search_path and

module PostgREST.Middleware where

import qualified Hasql.Decoders as HD
import qualified Hasql.DynamicStatements.Statement as H
import PostgREST.Private.Common

import qualified Data.Aeson as JSON
import qualified Data.ByteString.Char8 as BS
import qualified Data.CaseInsensitive as CI
Expand All @@ -33,7 +37,7 @@ import Network.Wai.Middleware.Static (only, staticPolicy)

import PostgREST.ApiRequest (ApiRequest (..))
import PostgREST.Config (AppConfig (..))
import PostgREST.QueryBuilder (setLocalQuery, setLocalSearchPathQuery)
import PostgREST.QueryBuilder (setConfigLocal)
import PostgREST.Types (LogLevel (..))
import Protolude hiding (head, toS)
import Protolude.Conv (toS)
Expand All @@ -44,23 +48,26 @@ runPgLocals :: AppConfig -> M.HashMap Text JSON.Value ->
(ApiRequest -> H.Transaction Response) ->
ApiRequest -> H.Transaction Response
runPgLocals conf claims app req = do
H.sql $ toS . mconcat $ setSearchPathSql : setRoleSql ++ claimsSql ++ [methodSql, pathSql] ++ headersSql ++ cookiesSql ++ appSettingsSql
traverse_ H.sql preReq
H.statement mempty $ H.dynamicallyParameterized
("select " <> intercalateSnippet ", " (searchPathSql : roleSql ++ claimsSql ++ [methodSql, pathSql] ++ headersSql ++ cookiesSql ++ appSettingsSql))
HD.noResult (configDbPreparedStatements conf)
traverse_ H.sql preReqSql
app req
where
methodSql = setLocalQuery mempty ("request.method", toS $ iMethod req)
pathSql = setLocalQuery mempty ("request.path", toS $ iPath req)
headersSql = setLocalQuery "request.header." <$> iHeaders req
cookiesSql = setLocalQuery "request.cookie." <$> iCookies req
claimsSql = setLocalQuery "request.jwt.claim." <$> [(c,unquoted v) | (c,v) <- M.toList claimsWithRole]
appSettingsSql = setLocalQuery mempty <$> configAppSettings conf
setRoleSql = maybeToList $ (\x ->
setLocalQuery mempty ("role", unquoted x)) <$> M.lookup "role" claimsWithRole
setSearchPathSql = setLocalSearchPathQuery (iSchema req : configDbExtraSearchPath conf)
-- role claim defaults to anon if not specified in jwt
claimsWithRole = M.union claims (M.singleton "role" anon)
anon = JSON.String . toS $ configDbAnonRole conf
preReq = (\f -> "select " <> toS f <> "();") <$> configDbPreRequest conf
methodSql = setConfigLocal mempty ("request.method", toS $ iMethod req)
pathSql = setConfigLocal mempty ("request.path", toS $ iPath req)
headersSql = setConfigLocal "request.header." <$> iHeaders req
cookiesSql = setConfigLocal "request.cookie." <$> iCookies req
claimsWithRole =
let anon = JSON.String . toS $ configDbAnonRole conf in -- role claim defaults to anon if not specified in jwt
M.union claims (M.singleton "role" anon)
claimsSql = setConfigLocal "request.jwt.claim." <$> [(c,unquoted v) | (c,v) <- M.toList claimsWithRole]
roleSql = maybeToList $ (\x -> setConfigLocal mempty ("role", unquoted x)) <$> M.lookup "role" claimsWithRole
appSettingsSql = setConfigLocal mempty <$> configAppSettings conf
searchPathSql =
let schemas = T.intercalate ", " (iSchema req : configDbExtraSearchPath conf) in
setConfigLocal mempty ("search_path", schemas)
preReqSql = (\f -> "select " <> toS f <> "();") <$> configDbPreRequest conf

-- | Log in apache format. Only requests that have a status greater than minStatus are logged.
-- | There's no way to filter logs in the apache format on wai-extra: https://hackage.haskell.org/package/wai-extra-3.0.29.2/docs/Network-Wai-Middleware-RequestLogger.html#t:OutputFormat.
Expand Down
14 changes: 12 additions & 2 deletions src/PostgREST/Private/Common.hs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@ Description : Common helper functions.
module PostgREST.Private.Common where

import Data.Maybe
import qualified Hasql.Decoders as HD
import qualified Hasql.Encoders as HE
import qualified Hasql.Decoders as HD
import qualified Hasql.DynamicStatements.Snippet as H
import qualified Hasql.Encoders as HE
import Protolude

import Data.Foldable (foldr1)

column :: HD.Value a -> HD.Row a
column = HD.column . HD.nonNullable

Expand All @@ -23,3 +26,10 @@ param = HE.param . HE.nonNullable

arrayParam :: HE.Value a -> HE.Params [a]
arrayParam = param . HE.array . HE.dimension foldl' . HE.element . HE.nonNullable

emptySnippetOnFalse :: H.Snippet -> Bool -> H.Snippet
emptySnippetOnFalse val cond = if cond then mempty else val

intercalateSnippet :: ByteString -> [H.Snippet] -> H.Snippet
intercalateSnippet _ [] = mempty
intercalateSnippet frag snippets = foldr1 (\a b -> a <> H.sql frag <> b) snippets
12 changes: 2 additions & 10 deletions src/PostgREST/Private/QueryFragment.hs
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,8 @@ import Protolude hiding (cast,
import Protolude.Conv (toS)
import Text.InterpolatedString.Perl6 (qc)

import qualified Hasql.Encoders as HE

import Data.Foldable (foldr1)
import qualified Hasql.Encoders as HE
import PostgREST.Private.Common

noLocationF :: SqlFragment
noLocationF = "array[]::text[]"
Expand Down Expand Up @@ -250,10 +249,3 @@ unknownEncoder = H.encoderAndParam (HE.nonNullable HE.unknown)

unknownLiteral :: Text -> H.Snippet
unknownLiteral = unknownEncoder . encodeUtf8

emptySnippetOnFalse :: H.Snippet -> Bool -> H.Snippet
emptySnippetOnFalse val cond = if cond then mempty else val

intercalateSnippet :: SqlFragment -> [H.Snippet] -> H.Snippet
intercalateSnippet _ [] = mempty
intercalateSnippet frag snippets = foldr1 (\a b -> a <> H.sql frag <> b) snippets
15 changes: 6 additions & 9 deletions src/PostgREST/QueryBuilder.hs
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@ module PostgREST.QueryBuilder (
, readRequestToCountQuery
, requestToCallProcQuery
, limitedQuery
, setLocalQuery
, setLocalSearchPathQuery
, setConfigLocal
) where

import qualified Data.ByteString.Char8 as BS
Expand All @@ -27,6 +26,7 @@ import qualified Hasql.DynamicStatements.Snippet as H
import Data.Tree (Tree (..))

import Data.Maybe
import PostgREST.Private.Common
import PostgREST.Private.QueryFragment
import PostgREST.Types
import Protolude hiding (cast, intercalate,
Expand Down Expand Up @@ -173,10 +173,7 @@ readRequestToCountQuery (Node (Select{from=qi, where_=logicForest}, _) _) =
limitedQuery :: H.Snippet -> Maybe Integer -> H.Snippet
limitedQuery query maxRows = query <> H.sql (maybe mempty (\x -> " LIMIT " <> BS.pack (show x)) maxRows)

setLocalQuery :: Text -> (Text, Text) -> SqlQuery
setLocalQuery prefix (k, v) =
"SET LOCAL " <> pgFmtIdent (prefix <> k) <> " = " <> pgFmtLit v <> ";"

setLocalSearchPathQuery :: [Text] -> SqlQuery
setLocalSearchPathQuery vals =
"SET LOCAL search_path = " <> BS.intercalate ", " (pgFmtLit <$> vals) <> ";"
-- | Do a pg set_config(setting, value, true) call. This is equivalent to a SET LOCAL.
setConfigLocal :: Text -> (Text, Text) -> H.Snippet
setConfigLocal prefix (k, v) =
"set_config(" <> unknownLiteral (prefix <> k) <> ", " <> unknownLiteral v <> ", true)"