Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow calling variadic functions with repeated parameters #1603

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ This project adheres to [Semantic Versioning](http://semver.org/).
- #1512, Allow schema cache reloading with NOTIFY - @steve-chavez
- #1119, Allow config file reloading with SIGUSR2 - @steve-chavez
- #1558, Allow 'Bearer' with and without capitalization as authentication schema - @wolfgangwalther
- #1470, Allow calling RPC with variadic argument by passing repeated params - @wolfgangwalther
- #1559, No downtime when reloading the schema cache with SIGUSR1 - @steve-chavez
- #504, Add `log-level` config option. The admitted levels are: crit, error, warn and info - @steve-chavez

Expand Down
55 changes: 39 additions & 16 deletions src/PostgREST/ApiRequest.hs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
Module : PostgREST.ApiRequest
Description : PostgREST functions to translate HTTP request to a domain type called ApiRequest.
-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE NamedFieldPuns #-}

module PostgREST.ApiRequest (
ApiRequest(..)
Expand Down Expand Up @@ -45,6 +46,7 @@ import Web.Cookie (parseCookiesText)
import Data.Ranged.Boundaries

import PostgREST.Error (ApiRequestError (..))
import PostgREST.Parsers (pRequestColumns)
import PostgREST.RangeQuery (NonnegRange, allRange, rangeGeq,
rangeLimit, rangeOffset, rangeRequested,
restrictRange)
Expand All @@ -63,7 +65,7 @@ data Action = ActionCreate | ActionRead{isHead :: Bool}
deriving Eq
-- | The target db object of a user action
data Target = TargetIdent QualifiedIdentifier
| TargetProc{tpQi :: QualifiedIdentifier, tpIsRootSpec :: Bool}
| TargetProc{tProc :: ProcDescription, tpIsRootSpec :: Bool}
| TargetDefaultSpec{tdsSchema :: Schema} -- The default spec offered at root "/"
| TargetUnknown [Text]
deriving Eq
Expand All @@ -90,7 +92,7 @@ data ApiRequest = ApiRequest {
, iLogic :: [(Text, Text)] -- ^ &and and &or parameters used for complex boolean logic
, iSelect :: Maybe Text -- ^ &select parameter used to shape the response
, iOnConflict :: Maybe Text -- ^ &on_conflict parameter used to upsert on specific unique keys
, iColumns :: Maybe Text -- ^ &columns parameter used to shape the payload
, iColumns :: S.Set FieldName -- ^ parsed colums from &columns parameter and payload
, iOrder :: [(Text, Text)] -- ^ &order parameters for each level
, iCanonicalQS :: ByteString -- ^ Alphabetized (canonical) request query string for response URLs
, iJWT :: Text -- ^ JSON Web Token
Expand All @@ -103,12 +105,13 @@ data ApiRequest = ApiRequest {
}

-- | Examines HTTP request and translates it into user intent.
userApiRequest :: NonEmpty Schema -> Maybe Text -> Request -> RequestBody -> Either ApiRequestError ApiRequest
userApiRequest confSchemas rootSpec req reqBody
userApiRequest :: NonEmpty Schema -> Maybe Text -> DbStructure -> Request -> RequestBody -> Either ApiRequestError ApiRequest
userApiRequest confSchemas rootSpec dbStructure req reqBody
| isJust profile && fromJust profile `notElem` confSchemas = Left $ UnacceptableSchema $ toList confSchemas
| isTargetingProc && method `notElem` ["HEAD", "GET", "POST"] = Left ActionInappropriate
| topLevelRange == emptyRange = Left InvalidRange
| shouldParsePayload && isLeft payload = either (Left . InvalidBody . toS) witness payload
| isLeft parsedColumns = either Left witness parsedColumns
| otherwise = Right ApiRequest {
iAction = action
, iTarget = target
Expand All @@ -131,7 +134,7 @@ userApiRequest confSchemas rootSpec req reqBody
, iLogic = [(toS k, toS $ fromJust v) | (k,v) <- qParams, isJust v, endingIn ["and", "or"] k ]
, iSelect = toS <$> join (lookup "select" qParams)
, iOnConflict = toS <$> join (lookup "on_conflict" qParams)
, iColumns = columns
, iColumns = payloadColumns
, iOrder = [(toS k, toS $ fromJust v) | (k,v) <- qParams, isJust v, endingIn ["order"] k ]
, iCanonicalQS = toS $ urlEncodeVars
. L.sortOn fst
Expand Down Expand Up @@ -174,6 +177,16 @@ userApiRequest confSchemas rootSpec req reqBody
columns
| action `elem` [ActionCreate, ActionUpdate, ActionInvoke InvPost] = toS <$> join (lookup "columns" qParams)
| otherwise = Nothing
parsedColumns = pRequestColumns columns
payloadColumns =
case (contentType, action) of
(_, ActionInvoke InvGet) -> S.fromList $ fst <$> rpcQParams
(_, ActionInvoke InvHead) -> S.fromList $ fst <$> rpcQParams
(CTOther "application/x-www-form-urlencoded", _) -> S.fromList $ map (toS . fst) $ parseSimpleQuery $ toS reqBody
_ -> case (relevantPayload, fromRight Nothing parsedColumns) of
(Just ProcessedJSON{pjKeys}, _) -> pjKeys
(Just RawJSON{}, Just cls) -> cls
_ -> S.empty
payload =
case (contentType, action) of
(_, ActionInvoke InvGet) -> Right rpcPrmsToJson
Expand All @@ -189,12 +202,18 @@ userApiRequest confSchemas rootSpec req reqBody
json <- csvToJson <$> CSV.decodeByName reqBody
note "All lines must have same number of fields" $ payloadAttributes (JSON.encode json) json
(CTOther "application/x-www-form-urlencoded", _) ->
let json = M.fromList . map (toS *** JSON.String . toS) . parseSimpleQuery $ toS reqBody
let json = paramsFromList . map (toS *** toS) . parseSimpleQuery $ toS reqBody
wolfgangwalther marked this conversation as resolved.
Show resolved Hide resolved
keys = S.fromList $ M.keys json in
Right $ ProcessedJSON (JSON.encode json) keys
(ct, _) ->
Left $ toS $ "Content-Type not acceptable: " <> toMime ct
rpcPrmsToJson = ProcessedJSON (JSON.encode $ M.fromList $ second JSON.toJSON <$> rpcQParams) (S.fromList $ fst <$> rpcQParams)
rpcPrmsToJson = ProcessedJSON (JSON.encode $ paramsFromList rpcQParams) (S.fromList $ fst <$> rpcQParams)
paramsFromList ls = M.fromListWith mergeParams $ toRpcParamsWith isVariadic ls
where
isVariadic k =
case target of
TargetProc{tProc} -> argIsVariadic tProc k
_ -> False
topLevelRange = fromMaybe allRange $ M.lookup "limit" ranges -- if no limit is specified, get all the request rows
action =
case method of
Expand Down Expand Up @@ -226,13 +245,17 @@ userApiRequest confSchemas rootSpec req reqBody
= Just $ maybe defaultSchema toS $ lookupHeader "Accept-Profile"
| otherwise = Nothing
schema = fromMaybe defaultSchema profile
target = case path of
[] -> case rootSpec of
Just pName -> TargetProc (QualifiedIdentifier schema pName) True
Nothing -> TargetDefaultSpec schema
[table] -> TargetIdent $ QualifiedIdentifier schema table
["rpc", proc] -> TargetProc (QualifiedIdentifier schema proc) False
other -> TargetUnknown other
target =
let
callFindProc proc = findProc (QualifiedIdentifier schema proc) payloadColumns (hasPrefer (show SingleObject)) $ dbProcs dbStructure
in
case path of
[] -> case rootSpec of
Just pName -> TargetProc (callFindProc pName) True
Nothing -> TargetDefaultSpec schema
[table] -> TargetIdent $ QualifiedIdentifier schema table
["rpc", pName] -> TargetProc (callFindProc pName) False
other -> TargetUnknown other

shouldParsePayload =
action `elem`
Expand Down
65 changes: 27 additions & 38 deletions src/PostgREST/App.hs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ import PostgREST.Error (PgError (..), SimpleError (..),
errorResponseFor, singularityError)
import PostgREST.Middleware
import PostgREST.OpenAPI
import PostgREST.Parsers (pRequestColumns)
import PostgREST.QueryBuilder (limitedQuery, mutateRequestToQuery,
readRequestToCountQuery,
readRequestToQuery,
Expand All @@ -76,51 +75,38 @@ postgrest logLev refConf refDbStructure pool getTime connWorker =
Nothing -> respond . errorResponseFor $ ConnectionLostError
Just dbStructure -> do
response <- do
let apiReq = userApiRequest (configSchemas conf) (configRootSpec conf) req body
-- Need to parse ?columns early because findProc needs it to solve overloaded functions.
apiReqCols = (,) <$> apiReq <*> (pRequestColumns . iColumns =<< apiReq)
case apiReqCols of
let apiReq = userApiRequest (configSchemas conf) (configRootSpec conf) dbStructure req body
case apiReq of
Left err -> return . errorResponseFor $ err
Right (apiRequest, maybeCols) -> do
Right apiRequest -> do
-- The jwt must be checked before touching the db.
attempt <- attemptJwtClaims (configJWKS conf) (configJwtAudience conf) (toS $ iJWT apiRequest) time (rightToMaybe $ configRoleClaimKey conf)
case jwtClaims attempt of
Left errJwt -> return . errorResponseFor $ errJwt
Right claims -> do
let
authed = containsRole claims
cols = case (iPayload apiRequest, maybeCols) of
(Just ProcessedJSON{pjKeys}, _) -> pjKeys
(Just RawJSON{}, Just cls) -> cls
_ -> S.empty
proc = case iTarget apiRequest of
TargetProc qi _ -> findProc qi cols (iPreferParameters apiRequest == Just SingleObject) $ dbProcs dbStructure
_ -> Nothing
handleReq = runPgLocals conf claims (app dbStructure proc cols conf) apiRequest
txMode = transactionMode proc (iAction apiRequest)
dbResp <- P.use pool $ HT.transaction HT.ReadCommitted txMode handleReq
handleReq = runPgLocals conf claims (app dbStructure conf) apiRequest
dbResp <- P.use pool $ HT.transaction HT.ReadCommitted (txMode apiRequest) handleReq
return $ either (errorResponseFor . PgError authed) identity dbResp
-- Launch the connWorker when the connection is down. The postgrest function can respond successfully(with a stale schema cache) before the connWorker is done.
when (responseStatus response == status503) connWorker
respond response

transactionMode :: Maybe ProcDescription -> Action -> HT.Mode
transactionMode proc action =
case action of
ActionRead _ -> HT.Read
ActionInfo -> HT.Read
ActionInspect _ -> HT.Read
ActionInvoke InvGet -> HT.Read
ActionInvoke InvHead -> HT.Read
ActionInvoke InvPost ->
let v = maybe Volatile pdVolatility proc in
if v == Stable || v == Immutable
then HT.Read
else HT.Write
txMode :: ApiRequest -> HT.Mode
txMode apiRequest =
case (iAction apiRequest, iTarget apiRequest) of
(ActionRead _ , _) -> HT.Read
(ActionInfo , _) -> HT.Read
(ActionInspect _ , _) -> HT.Read
(ActionInvoke InvGet , _) -> HT.Read
(ActionInvoke InvHead, _) -> HT.Read
(ActionInvoke InvPost, TargetProc ProcDescription{pdVolatility=Stable} _) -> HT.Read
(ActionInvoke InvPost, TargetProc ProcDescription{pdVolatility=Immutable} _) -> HT.Read
_ -> HT.Write

app :: DbStructure -> Maybe ProcDescription -> S.Set FieldName -> AppConfig -> ApiRequest -> H.Transaction Response
app dbStructure proc cols conf apiRequest =
app :: DbStructure -> AppConfig -> ApiRequest -> H.Transaction Response
app dbStructure conf apiRequest =
let rawContentTypes = (decodeContentType <$> configRawMediaTypes conf) `L.union` [ CTOctetStream, CTTextPlain ] in
case responseContentTypeOrError (iAccepts apiRequest) rawContentTypes (iAction apiRequest) (iTarget apiRequest) of
Left errorResponse -> return errorResponse
Expand Down Expand Up @@ -210,7 +196,7 @@ app dbStructure proc cols conf apiRequest =
Left err -> return $ errorResponseFor err
Right (ghdrs, gstatus) -> do
let
updateIsNoOp = S.null cols
updateIsNoOp = S.null (iColumns apiRequest)
defStatus | queryTotal == 0 && not updateIsNoOp = status404
| iPreferRepresentation apiRequest == Full = status200
| otherwise = status204
Expand Down Expand Up @@ -294,14 +280,14 @@ app dbStructure proc cols conf apiRequest =
allOrigins = ("Access-Control-Allow-Origin", "*") :: Header in
return $ responseLBS status200 [allOrigins, allowH] mempty

(ActionInvoke invMethod, TargetProc qi@(QualifiedIdentifier tSchema pName) _, Just pJson) ->
let tName = fromMaybe pName $ procTableName =<< proc in
case readSqlParts tSchema tName of
(ActionInvoke invMethod, TargetProc proc@ProcDescription{pdSchema, pdName} _, Just pJson) ->
let tName = fromMaybe pdName $ procTableName proc in
case readSqlParts pdSchema tName of
Left errorResponse -> return errorResponse
Right (q, cq, bField, returning) -> do
let
preferParams = iPreferParameters apiRequest
pq = requestToCallProcQuery qi (specifiedProcArgs cols proc) returnsScalar preferParams returning
pq = requestToCallProcQuery (QualifiedIdentifier pdSchema pdName) (specifiedProcArgs (iColumns apiRequest) proc) returnsScalar preferParams returning
stm = callProcStatement returnsScalar pq q cq shouldCount (contentType == CTSingularJSON)
(contentType == CTTextCSV) (contentType `elem` rawContentTypes) (preferParams == Just MultipleObjects)
bField pgVer
Expand Down Expand Up @@ -351,7 +337,10 @@ app dbStructure proc cols conf apiRequest =
plannedCount = iPreferCount apiRequest == Just PlannedCount
shouldCount = exactCount || estimatedCount
topLevelRange = iTopLevelRange apiRequest
returnsScalar = maybe False procReturnsScalar proc
returnsScalar =
case iTarget apiRequest of
TargetProc proc _ -> procReturnsScalar proc
_ -> False
pgVer = pgVersion dbStructure
profileH = contentProfileH <$> iProfile apiRequest

Expand All @@ -370,7 +359,7 @@ app dbStructure proc cols conf apiRequest =
mutateSqlParts s t =
let
readReq = readRequest s t maxRows (dbRelations dbStructure) apiRequest
mutReq = mutateRequest s t apiRequest cols (tablePKCols dbStructure s t) =<< readReq
mutReq = mutateRequest s t apiRequest (tablePKCols dbStructure s t) =<< readReq
in
(,) <$>
(readRequestToQuery <$> readReq) <*>
Expand Down
10 changes: 5 additions & 5 deletions src/PostgREST/DbRequestBuilder.hs
Original file line number Diff line number Diff line change
Expand Up @@ -287,15 +287,15 @@ addProperty f (targetNodeName:remainingPath, a) (Node rn forest) =
where
pathNode = find (\(Node (_,(nodeName,_,alias,_,_)) _) -> nodeName == targetNodeName || alias == Just targetNodeName) forest

mutateRequest :: Schema -> TableName -> ApiRequest -> S.Set FieldName -> [FieldName] -> ReadRequest -> Either Response MutateRequest
mutateRequest schema tName apiRequest cols pkCols readReq = mapLeft errorResponseFor $
mutateRequest :: Schema -> TableName -> ApiRequest -> [FieldName] -> ReadRequest -> Either Response MutateRequest
mutateRequest schema tName apiRequest pkCols readReq = mapLeft errorResponseFor $
case action of
ActionCreate -> do
confCols <- case iOnConflict apiRequest of
Nothing -> pure pkCols
Just param -> pRequestOnConflict param
pure $ Insert qi cols ((,) <$> iPreferResolution apiRequest <*> Just confCols) [] returnings
ActionUpdate -> Update qi cols <$> combinedLogic <*> pure returnings
pure $ Insert qi (iColumns apiRequest) ((,) <$> iPreferResolution apiRequest <*> Just confCols) [] returnings
ActionUpdate -> Update qi (iColumns apiRequest) <$> combinedLogic <*> pure returnings
ActionSingleUpsert ->
(\flts ->
if null (iLogic apiRequest) &&
Expand All @@ -304,7 +304,7 @@ mutateRequest schema tName apiRequest cols pkCols readReq = mapLeft errorRespons
all (\case
Filter _ (OpExpr False (Op "eq" _)) -> True
_ -> False) flts
then Insert qi cols (Just (MergeDuplicates, pkCols)) <$> combinedLogic <*> pure returnings
then Insert qi (iColumns apiRequest) (Just (MergeDuplicates, pkCols)) <$> combinedLogic <*> pure returnings
else
Left InvalidFilters) =<< filters
ActionDelete -> Delete qi <$> combinedLogic <*> pure returnings
Expand Down
10 changes: 6 additions & 4 deletions src/PostgREST/DbStructure.hs
Original file line number Diff line number Diff line change
Expand Up @@ -149,14 +149,16 @@ decodeProcs =
parseArgs = mapMaybe parseArg . filter (not . isPrefixOf "OUT" . toS) . map strip . split (==',')

parseArg :: Text -> Maybe PgArg
parseArg a =
let arg = lastDef "" $ splitOn "INOUT " a
(body, def) = breakOn " DEFAULT " arg
parseArg arg =
let isVariadic = isPrefixOf "VARIADIC " $ toS arg
-- argmode can be IN, OUT, INOUT, or VARIADIC
argNoMode = lastDef "" $ splitOn (if isVariadic then "VARIADIC " else "INOUT ") arg
(body, def) = breakOn " DEFAULT " argNoMode
(name, typ) = breakOn " " body in
if T.null typ
then Nothing
else Just $
PgArg (dropAround (== '"') name) (strip typ) (T.null def)
PgArg (dropAround (== '"') name) (strip typ) (T.null def) isVariadic

parseRetType :: Text -> Text -> Bool -> Char -> RetType
parseRetType schema name isSetOf typ
Expand Down
2 changes: 1 addition & 1 deletion src/PostgREST/OpenAPI.hs
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ makeProcSchema pd =
& required .~ map pgaName (filter pgaReq (pdArgs pd))

makeProcProperty :: PgArg -> (Text, Referenced Schema)
makeProcProperty (PgArg n t _) = (n, Inline s)
makeProcProperty (PgArg n t _ _) = (n, Inline s)
where
s = (mempty :: Schema)
& type_ ?~ toSwaggerType t
Expand Down
13 changes: 8 additions & 5 deletions src/PostgREST/QueryBuilder.hs
Original file line number Diff line number Diff line change
Expand Up @@ -137,15 +137,18 @@ requestToCallProcQuery qi pgArgs returnsScalar preferParams returnings =
BS.unwords [
normalizedBody <> ",",
"pgrst_args AS (",
"SELECT * FROM json_to_recordset(" <> selectBody <> ") AS _(" <> fmtArgs (\a -> " " <> encodeUtf8 (pgaType a)) <> ")",
"SELECT * FROM json_to_recordset(" <> selectBody <> ") AS _(" <> fmtArgs (const mempty) (\a -> " " <> encodeUtf8 (pgaType a)) <> ")",
")"]
, if paramsAsMultipleObjects
then fmtArgs (\a -> " := pgrst_args." <> pgFmtIdent (pgaName a))
else fmtArgs (\a -> " := (SELECT " <> pgFmtIdent (pgaName a) <> " FROM pgrst_args LIMIT 1)")
then fmtArgs varadicPrefix (\a -> " := pgrst_args." <> pgFmtIdent (pgaName a))
else fmtArgs varadicPrefix (\a -> " := (SELECT " <> pgFmtIdent (pgaName a) <> " FROM pgrst_args LIMIT 1)")
)

fmtArgs :: (PgArg -> SqlFragment) -> SqlFragment
fmtArgs argFrag = BS.intercalate ", " ((\a -> pgFmtIdent (pgaName a) <> argFrag a) <$> pgArgs)
fmtArgs :: (PgArg -> SqlFragment) -> (PgArg -> SqlFragment) -> SqlFragment
fmtArgs argFragPre argFragSuf = BS.intercalate ", " ((\a -> argFragPre a <> pgFmtIdent (pgaName a) <> argFragSuf a) <$> pgArgs)

varadicPrefix :: PgArg -> SqlFragment
varadicPrefix a = if pgaVar a then "VARIADIC " else mempty

sourceBody :: SqlFragment
sourceBody
Expand Down
Loading