diff --git a/src/PostgREST/ApiRequest.hs b/src/PostgREST/ApiRequest.hs index f154f05cd51..f82606f2dd2 100644 --- a/src/PostgREST/ApiRequest.hs +++ b/src/PostgREST/ApiRequest.hs @@ -70,6 +70,32 @@ data Target = TargetIdent QualifiedIdentifier | TargetUnknown [Text] deriving Eq +-- | RPC query param value `/rpc/func?v=`, used for VARIADIC functions on form-urlencoded POST and GETs +-- | It can be fixed `?v=1` or repeated `?v=1&v=2&v=3. +data RpcParamValue = Fixed Text | Variadic [Text] +instance JSON.ToJSON RpcParamValue where + toJSON (Fixed v) = JSON.toJSON v + toJSON (Variadic v) = JSON.toJSON v + +toRpcParamValue :: ProcDescription -> (Text, Text) -> (Text, RpcParamValue) +toRpcParamValue proc (k, v) | argIsVariadic k = (k, Variadic [v]) + | otherwise = (k, Fixed v) + where + argIsVariadic arg = isJust $ find (\PgArg{pgaName, pgaVar} -> pgaName == arg && pgaVar) $ pdArgs proc + +-- | Convert rpc params `/rpc/func?a=val1&b=val2` to json `{"a": "val1", "b": "val2"} +jsonRpcParams :: ProcDescription -> [(Text, Text)] -> PayloadJSON +jsonRpcParams proc prms = + if not $ pdHasVariadic proc then -- if proc has no variadic arg, save steps and directly convert to json + ProcessedJSON (JSON.encode $ M.fromList $ second JSON.toJSON <$> prms) (S.fromList $ fst <$> prms) + else + let paramsMap = M.fromListWith mergeParams $ toRpcParamValue proc <$> prms in + ProcessedJSON (JSON.encode paramsMap) (S.fromList $ M.keys paramsMap) + where + mergeParams :: RpcParamValue -> RpcParamValue -> RpcParamValue + mergeParams (Variadic a) (Variadic b) = Variadic $ b ++ a + mergeParams _ v = v -- repeated params for non-variadic arguments are not merged + {-| Describes what the user wants to do. This data type is a translation of the raw elements of an HTTP request into domain @@ -187,33 +213,27 @@ userApiRequest confSchemas rootSpec dbStructure req reqBody (Just ProcessedJSON{pjKeys}, _) -> pjKeys (Just RawJSON{}, Just cls) -> cls _ -> S.empty - payload = - case (contentType, action) of - (_, ActionInvoke InvGet) -> Right rpcPrmsToJson - (_, ActionInvoke InvHead) -> Right rpcPrmsToJson - (CTApplicationJSON, _) -> - if isJust columns - then Right $ RawJSON reqBody - else note "All object keys must match" . payloadAttributes reqBody - =<< if BL.null reqBody && isTargetingProc - then Right emptyObject - else JSON.eitherDecode reqBody - (CTTextCSV, _) -> do - json <- csvToJson <$> CSV.decodeByName reqBody - note "All lines must have same number of fields" $ payloadAttributes (JSON.encode json) json - (CTUrlEncoded, _) -> - let json = paramsFromList . map (toS *** toS) . parseSimpleQuery $ toS reqBody - keys = S.fromList $ M.keys json in - Right $ ProcessedJSON (JSON.encode json) keys - (ct, _) -> - Left $ toS $ "Content-Type not acceptable: " <> toMime ct - rpcPrmsToJson = ProcessedJSON (JSON.encode $ paramsFromList rpcQParams) (S.fromList $ fst <$> rpcQParams) - paramsFromList ls = M.fromListWith mergeParams $ toRpcParamsWith isVariadic ls - where - isVariadic k = - case target of - TargetProc{tProc} -> argIsVariadic tProc k - _ -> False + payload = case contentType of + CTApplicationJSON -> + if isJust columns + then Right $ RawJSON reqBody + else note "All object keys must match" . payloadAttributes reqBody + =<< if BL.null reqBody && isTargetingProc + then Right emptyObject + else JSON.eitherDecode reqBody + CTTextCSV -> do + json <- csvToJson <$> CSV.decodeByName reqBody + note "All lines must have same number of fields" $ payloadAttributes (JSON.encode json) json + CTUrlEncoded -> + let urlEncodedBody = parseSimpleQuery $ toS reqBody in + case target of + TargetProc{tProc} -> + Right $ jsonRpcParams tProc $ (toS *** toS) <$> urlEncodedBody + _ -> + let paramsMap = M.fromList $ (toS *** JSON.String . toS) <$> urlEncodedBody in + Right $ ProcessedJSON (JSON.encode paramsMap) $ S.fromList (M.keys paramsMap) + ct -> + Left $ toS $ "Content-Type not acceptable: " <> toMime ct topLevelRange = fromMaybe allRange $ M.lookup "limit" ranges -- if no limit is specified, get all the request rows action = case method of @@ -257,16 +277,14 @@ userApiRequest confSchemas rootSpec dbStructure req reqBody ["rpc", pName] -> TargetProc (callFindProc pName) False other -> TargetUnknown other - shouldParsePayload = - action `elem` - [ActionCreate, ActionUpdate, ActionSingleUpsert, - ActionInvoke InvPost, - -- Though ActionInvoke{isGet=True}(a GET /rpc/..) doesn't really have a payload, we use the payload variable as a way + shouldParsePayload = action `elem` [ActionCreate, ActionUpdate, ActionSingleUpsert, ActionInvoke InvPost] + relevantPayload = case (target, action) of + -- Though ActionInvoke GET/HEAD doesn't really have a payload, we use the payload variable as a way -- to store the query string arguments to the function. - ActionInvoke InvGet, - ActionInvoke InvHead] - relevantPayload | shouldParsePayload = rightToMaybe payload - | otherwise = Nothing + (TargetProc{tProc}, ActionInvoke InvGet) -> Just $ jsonRpcParams tProc rpcQParams + (TargetProc{tProc}, ActionInvoke InvHead) -> Just $ jsonRpcParams tProc rpcQParams + _ | shouldParsePayload -> rightToMaybe payload + | otherwise -> Nothing path = pathInfo req method = requestMethod req hdrs = requestHeaders req diff --git a/src/PostgREST/DbStructure.hs b/src/PostgREST/DbStructure.hs index e3d563d1809..3cb456159e2 100644 --- a/src/PostgREST/DbStructure.hs +++ b/src/PostgREST/DbStructure.hs @@ -128,7 +128,7 @@ sourceColumnFromRow allCols (s1,t1,c1,s2,t2,c2) = (,) <$> col1 <*> col2 decodeProcs :: HD.Result ProcsMap decodeProcs = -- Duplicate rows for a function means they're overloaded, order these by least args according to ProcDescription Ord instance - map sort . M.fromListWith (++) . map ((\(x,y) -> (x, [y])) . addKey) <$> HD.rowList procRow + map sort . M.fromListWith (++) . map ((\(x,y) -> (x, [y])) . addKey . addHasVariadic) <$> HD.rowList procRow where procRow = ProcDescription <$> column HD.text @@ -141,6 +141,10 @@ decodeProcs = <*> column HD.bool <*> column HD.char) <*> (parseVolatility <$> column HD.char) + <*> pure False + + addHasVariadic :: ProcDescription -> ProcDescription + addHasVariadic pd@ProcDescription{pdArgs} = pd{pdHasVariadic=isJust $ find pgaVar pdArgs} addKey :: ProcDescription -> (QualifiedIdentifier, ProcDescription) addKey pd = (QualifiedIdentifier (pdSchema pd) (pdName pd), pd) diff --git a/src/PostgREST/Types.hs b/src/PostgREST/Types.hs index 377a8f4316c..56964e71b5e 100644 --- a/src/PostgREST/Types.hs +++ b/src/PostgREST/Types.hs @@ -4,7 +4,6 @@ Description : PostgREST common types and functions used by the rest of the modul -} {-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE DuplicateRecordFields #-} -{-# LANGUAGE NamedFieldPuns #-} module PostgREST.Types where @@ -148,14 +147,15 @@ data ProcDescription = ProcDescription { , pdArgs :: [PgArg] , pdReturnType :: RetType , pdVolatility :: ProcVolatility +, pdHasVariadic :: Bool } deriving (Show, Eq) -- Order by least number of args in the case of overloaded functions instance Ord ProcDescription where - ProcDescription schema1 name1 des1 args1 rt1 vol1 `compare` ProcDescription schema2 name2 des2 args2 rt2 vol2 + ProcDescription schema1 name1 des1 args1 rt1 vol1 hasVar1 `compare` ProcDescription schema2 name2 des2 args2 rt2 vol2 hasVar2 | schema1 == schema2 && name1 == name2 && length args1 < length args2 = LT | schema2 == schema2 && name1 == name2 && length args1 > length args2 = GT - | otherwise = (schema1, name1, des1, args1, rt1, vol1) `compare` (schema2, name2, des2, args2, rt2, vol2) + | otherwise = (schema1, name1, des1, args1, rt1, vol1, hasVar1) `compare` (schema2, name2, des2, args2, rt2, vol2, hasVar2) -- | A map of all procs, all of which can be overloaded(one entry will have more than one ProcDescription). -- | It uses a HashMap for a faster lookup. @@ -171,7 +171,7 @@ findProc qi payloadKeys paramsAsSingleObject allProcs = fromMaybe fallback bestM where -- instead of passing Maybe ProcDescription around, we create a fallback description here when we can't find a matching function -- args is empty, but because "specifiedProcArgs" will fill the missing arguments with default type text, this is not a problem - fallback = ProcDescription (qiSchema qi) (qiName qi) Nothing mempty (SetOf $ Composite $ QualifiedIdentifier "" "record") Volatile + fallback = ProcDescription (qiSchema qi) (qiName qi) Nothing mempty (SetOf $ Composite $ QualifiedIdentifier mempty "record") Volatile False bestMatch = case M.lookup qi allProcs of Nothing -> Nothing @@ -202,12 +202,6 @@ procTableName proc = case pdReturnType proc of Single (Composite qi) -> Just $ qiName qi _ -> Nothing -argIsVariadic :: ProcDescription -> Text -> Bool -argIsVariadic proc arg = - case find (\PgArg{pgaName} -> pgaName == arg) $ pdArgs proc of - Just PgArg{pgaVar} -> pgaVar - _ -> False - type Schema = Text type TableName = Text @@ -414,26 +408,6 @@ type Alias = Text type Cast = Text type NodeName = Text --- RPC query param, used for POST of form-data and GET requests -data RpcParamValue = Fixed Text | Variadic [Text] - -mergeParams :: RpcParamValue -> RpcParamValue -> RpcParamValue -mergeParams (Variadic a) (Variadic b) = Variadic $ b ++ a --- repeated params for non-variadic arguments are not merged -mergeParams _ v = v - -instance JSON.ToJSON RpcParamValue where - toJSON (Fixed v) = JSON.toJSON v - toJSON (Variadic v) = JSON.toJSON v - -type RpcParams = [(Text, RpcParamValue)] - -toRpcParamsWith :: (Text -> Bool) -> [(Text, Text)] -> RpcParams -toRpcParamsWith isVariadic ls = toRpcParamValue <$> ls - where - toRpcParamValue (k, v) - | isVariadic k = (k, Variadic [v]) - | otherwise = (k, Fixed v) {-| Custom guc header, it's obtained by parsing the json in a: diff --git a/test/Feature/JsonOperatorSpec.hs b/test/Feature/JsonOperatorSpec.hs index 9f0d47f3e7d..094f5b8f8dc 100644 --- a/test/Feature/JsonOperatorSpec.hs +++ b/test/Feature/JsonOperatorSpec.hs @@ -7,7 +7,8 @@ import Test.Hspec import Test.Hspec.Wai import Test.Hspec.Wai.JSON -import PostgREST.Types (PgVersion, pgVersion112, pgVersion121, pgVersion95) +import PostgREST.Types (PgVersion, pgVersion112, pgVersion121, + pgVersion95) import Protolude hiding (get) import SpecHelper diff --git a/test/Feature/QuerySpec.hs b/test/Feature/QuerySpec.hs index b27c21385f2..823c824716f 100644 --- a/test/Feature/QuerySpec.hs +++ b/test/Feature/QuerySpec.hs @@ -10,7 +10,8 @@ import Test.Hspec.Wai.JSON import Text.Heredoc -import PostgREST.Types (PgVersion, pgVersion96, pgVersion112, pgVersion121) +import PostgREST.Types (PgVersion, pgVersion112, pgVersion121, + pgVersion96) import Protolude hiding (get) import SpecHelper diff --git a/test/Feature/RpcPreRequestGucsSpec.hs b/test/Feature/RpcPreRequestGucsSpec.hs index 77ff5120b0a..9afc7a9a683 100644 --- a/test/Feature/RpcPreRequestGucsSpec.hs +++ b/test/Feature/RpcPreRequestGucsSpec.hs @@ -11,7 +11,7 @@ import Test.Hspec.Wai import Test.Hspec.Wai.JSON import Text.Heredoc -import Protolude hiding (get) +import Protolude hiding (get) import SpecHelper spec :: SpecWith ((), Application) diff --git a/test/Feature/RpcSpec.hs b/test/Feature/RpcSpec.hs index ca2716e6a7f..b894fab1ccd 100644 --- a/test/Feature/RpcSpec.hs +++ b/test/Feature/RpcSpec.hs @@ -3,7 +3,7 @@ module Feature.RpcSpec where import qualified Data.ByteString.Lazy as BL (empty) import Network.Wai (Application) -import Network.Wai.Test (SResponse (simpleBody, simpleStatus, simpleHeaders)) +import Network.Wai.Test (SResponse (simpleBody, simpleHeaders, simpleStatus)) import Network.HTTP.Types import Test.Hspec hiding (pendingWith) @@ -523,8 +523,8 @@ spec actualPgVersion = `shouldRespondWith` [json|["hi", "there"]|] - it "returns first value for repeated params without VARIADIC" $ - get "/rpc/sayhello?name=world&name=ignored" + it "returns last value for repeated params without VARIADIC" $ + get "/rpc/sayhello?name=ignored&name=world" `shouldRespondWith` [json|"Hello, world"|]