diff --git a/CHANGELOG.md b/CHANGELOG.md index ac956f5329..627259d7bc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,9 @@ This project adheres to [Semantic Versioning](http://semver.org/). - #3001, Add `statement_timeout` set on functions - @taimoorzaeem - #3045, Apply superuser settings on impersonated roles if they have PostgreSQL 15 `GRANT SET ON PARAMETER` privilege - @steve-chavez - #3062, Add config for enabling the `Server-Timing` header - @develop7 + - #915, Add support for aggregate functions - @timabdulla + + The aggregate functions SUM(), MAX(), MIN(), AVG(), and COUNT() are now supported. + + It's disabled by default, you can enable it with `db-aggregates-enabled`. ### Fixed diff --git a/postgrest.cabal b/postgrest.cabal index 80601ebf0b..7fedc1b628 100644 --- a/postgrest.cabal +++ b/postgrest.cabal @@ -201,6 +201,7 @@ test-suite spec Feature.OpenApi.RootSpec Feature.OpenApi.SecurityOpenApiSpec Feature.OptionsSpec + Feature.Query.AggregateFunctionsSpec Feature.Query.AndOrParamsSpec Feature.Query.ComputedRelsSpec Feature.Query.CustomMediaSpec diff --git a/src/PostgREST/ApiRequest/QueryParams.hs b/src/PostgREST/ApiRequest/QueryParams.hs index d6515068f8..ff7de07d21 100644 --- a/src/PostgREST/ApiRequest/QueryParams.hs +++ b/src/PostgREST/ApiRequest/QueryParams.hs @@ -31,8 +31,8 @@ import Data.Tree (Tree (..)) import Text.Parsec.Error (errorMessages, showErrorMessages) import Text.ParserCombinators.Parsec (GenParser, ParseError, Parser, - anyChar, between, char, digit, - eof, errorPos, letter, + anyChar, between, char, choice, + digit, eof, errorPos, letter, lookAhead, many1, noneOf, notFollowedBy, oneOf, optionMaybe, sepBy, sepBy1, @@ -43,7 +43,8 @@ import PostgREST.RangeQuery (NonnegRange, allRange, rangeOffset, restrictRange) import PostgREST.SchemaCache.Identifiers (FieldName) -import PostgREST.ApiRequest.Types (EmbedParam (..), EmbedPath, Field, +import PostgREST.ApiRequest.Types (AggregateFunction (..), + EmbedParam (..), EmbedPath, Field, Filter (..), FtsOperator (..), Hint, JoinType (..), JsonOperand (..), @@ -58,7 +59,7 @@ import PostgREST.ApiRequest.Types (EmbedParam (..), EmbedPath, Field, SimpleOperator (..), SingleVal, TrileanVal (..)) -import Protolude hiding (try) +import Protolude hiding (Sum, try) data QueryParams = QueryParams @@ -99,7 +100,7 @@ data QueryParams = -- 'select' is a reserved parameter that selects the fields to be returned: -- -- >>> qsSelect <$> parse False "select=name,location" --- Right [Node {rootLabel = SelectField {selField = ("name",[]), selCast = Nothing, selAlias = Nothing}, subForest = []},Node {rootLabel = SelectField {selField = ("location",[]), selCast = Nothing, selAlias = Nothing}, subForest = []}] +-- Right [Node {rootLabel = SelectField {selField = ("name",[]), selAggregateFunction = Nothing, selAggregateCast = Nothing, selCast = Nothing, selAlias = Nothing}, subForest = []},Node {rootLabel = SelectField {selField = ("location",[]), selAggregateFunction = Nothing, selAggregateCast = Nothing, selCast = Nothing, selAlias = Nothing}, subForest = []}] -- -- Filters are parameters whose value contains an operator, separated by a '.' from its value: -- @@ -282,16 +283,16 @@ pTreePath = do -- Parse select= into a Forest of SelectItems -- -- >>> P.parse pFieldForest "" "id" --- Right [Node {rootLabel = SelectField {selField = ("id",[]), selCast = Nothing, selAlias = Nothing}, subForest = []}] +-- Right [Node {rootLabel = SelectField {selField = ("id",[]), selAggregateFunction = Nothing, selAggregateCast = Nothing, selCast = Nothing, selAlias = Nothing}, subForest = []}] -- -- >>> P.parse pFieldForest "" "client(id)" --- Right [Node {rootLabel = SelectRelation {selRelation = "client", selAlias = Nothing, selHint = Nothing, selJoinType = Nothing}, subForest = [Node {rootLabel = SelectField {selField = ("id",[]), selCast = Nothing, selAlias = Nothing}, subForest = []}]}] +-- Right [Node {rootLabel = SelectRelation {selRelation = "client", selAlias = Nothing, selHint = Nothing, selJoinType = Nothing}, subForest = [Node {rootLabel = SelectField {selField = ("id",[]), selAggregateFunction = Nothing, selAggregateCast = Nothing, selCast = Nothing, selAlias = Nothing}, subForest = []}]}] -- -- >>> P.parse pFieldForest "" "*,client(*,nested(*))" --- Right [Node {rootLabel = SelectField {selField = ("*",[]), selCast = Nothing, selAlias = Nothing}, subForest = []},Node {rootLabel = SelectRelation {selRelation = "client", selAlias = Nothing, selHint = Nothing, selJoinType = Nothing}, subForest = [Node {rootLabel = SelectField {selField = ("*",[]), selCast = Nothing, selAlias = Nothing}, subForest = []},Node {rootLabel = SelectRelation {selRelation = "nested", selAlias = Nothing, selHint = Nothing, selJoinType = Nothing}, subForest = [Node {rootLabel = SelectField {selField = ("*",[]), selCast = Nothing, selAlias = Nothing}, subForest = []}]}]}] +-- Right [Node {rootLabel = SelectField {selField = ("*",[]), selAggregateFunction = Nothing, selAggregateCast = Nothing, selCast = Nothing, selAlias = Nothing}, subForest = []},Node {rootLabel = SelectRelation {selRelation = "client", selAlias = Nothing, selHint = Nothing, selJoinType = Nothing}, subForest = [Node {rootLabel = SelectField {selField = ("*",[]), selAggregateFunction = Nothing, selAggregateCast = Nothing, selCast = Nothing, selAlias = Nothing}, subForest = []},Node {rootLabel = SelectRelation {selRelation = "nested", selAlias = Nothing, selHint = Nothing, selJoinType = Nothing}, subForest = [Node {rootLabel = SelectField {selField = ("*",[]), selAggregateFunction = Nothing, selAggregateCast = Nothing, selCast = Nothing, selAlias = Nothing}, subForest = []}]}]}] -- -- >>> P.parse pFieldForest "" "*,...client(*),other(*)" --- Right [Node {rootLabel = SelectField {selField = ("*",[]), selCast = Nothing, selAlias = Nothing}, subForest = []},Node {rootLabel = SpreadRelation {selRelation = "client", selHint = Nothing, selJoinType = Nothing}, subForest = [Node {rootLabel = SelectField {selField = ("*",[]), selCast = Nothing, selAlias = Nothing}, subForest = []}]},Node {rootLabel = SelectRelation {selRelation = "other", selAlias = Nothing, selHint = Nothing, selJoinType = Nothing}, subForest = [Node {rootLabel = SelectField {selField = ("*",[]), selCast = Nothing, selAlias = Nothing}, subForest = []}]}] +-- Right [Node {rootLabel = SelectField {selField = ("*",[]), selAggregateFunction = Nothing, selAggregateCast = Nothing, selCast = Nothing, selAlias = Nothing}, subForest = []},Node {rootLabel = SpreadRelation {selRelation = "client", selHint = Nothing, selJoinType = Nothing}, subForest = [Node {rootLabel = SelectField {selField = ("*",[]), selAggregateFunction = Nothing, selAggregateCast = Nothing, selCast = Nothing, selAlias = Nothing}, subForest = []}]},Node {rootLabel = SelectRelation {selRelation = "other", selAlias = Nothing, selHint = Nothing, selJoinType = Nothing}, subForest = [Node {rootLabel = SelectField {selField = ("*",[]), selAggregateFunction = Nothing, selAggregateCast = Nothing, selCast = Nothing, selAlias = Nothing}, subForest = []}]}] -- -- >>> P.parse pFieldForest "" "" -- Right [] @@ -299,7 +300,7 @@ pTreePath = do -- >>> P.parse pFieldForest "" "id,clients(name[])" -- Left (line 1, column 16): -- unexpected '[' --- expecting letter, digit, "-", "->>", "->", "::", ")", "," or end of input +-- expecting letter, digit, "-", "->>", "->", "::", ".", ")", "," or end of input -- -- >>> P.parse pFieldForest "" "data->>-78xy" -- Left (line 1, column 11): @@ -452,35 +453,37 @@ pRelationSelect :: Parser SelectItem pRelationSelect = lexeme $ do alias <- optionMaybe ( try(pFieldName <* aliasSeparator) ) name <- pFieldName + guard (name /= "count") (hint, jType) <- pEmbedParams try (void $ lookAhead (string "(")) return $ SelectRelation name alias hint jType + -- | -- Parse regular fields in select -- -- >>> P.parse pFieldSelect "" "name" --- Right (SelectField {selField = ("name",[]), selCast = Nothing, selAlias = Nothing}) +-- Right (SelectField {selField = ("name",[]), selAggregateFunction = Nothing, selAggregateCast = Nothing, selCast = Nothing, selAlias = Nothing}) -- -- >>> P.parse pFieldSelect "" "name->jsonpath" --- Right (SelectField {selField = ("name",[JArrow {jOp = JKey {jVal = "jsonpath"}}]), selCast = Nothing, selAlias = Nothing}) +-- Right (SelectField {selField = ("name",[JArrow {jOp = JKey {jVal = "jsonpath"}}]), selAggregateFunction = Nothing, selAggregateCast = Nothing, selCast = Nothing, selAlias = Nothing}) -- -- >>> P.parse pFieldSelect "" "name::cast" --- Right (SelectField {selField = ("name",[]), selCast = Just "cast", selAlias = Nothing}) +-- Right (SelectField {selField = ("name",[]), selAggregateFunction = Nothing, selAggregateCast = Nothing, selCast = Just "cast", selAlias = Nothing}) -- -- >>> P.parse pFieldSelect "" "alias:name" --- Right (SelectField {selField = ("name",[]), selCast = Nothing, selAlias = Just "alias"}) +-- Right (SelectField {selField = ("name",[]), selAggregateFunction = Nothing, selAggregateCast = Nothing, selCast = Nothing, selAlias = Just "alias"}) -- -- >>> P.parse pFieldSelect "" "alias:name->jsonpath::cast" --- Right (SelectField {selField = ("name",[JArrow {jOp = JKey {jVal = "jsonpath"}}]), selCast = Just "cast", selAlias = Just "alias"}) +-- Right (SelectField {selField = ("name",[JArrow {jOp = JKey {jVal = "jsonpath"}}]), selAggregateFunction = Nothing, selAggregateCast = Nothing, selCast = Just "cast", selAlias = Just "alias"}) -- -- >>> P.parse pFieldSelect "" "*" --- Right (SelectField {selField = ("*",[]), selCast = Nothing, selAlias = Nothing}) +-- Right (SelectField {selField = ("*",[]), selAggregateFunction = Nothing, selAggregateCast = Nothing, selCast = Nothing, selAlias = Nothing}) -- -- >>> P.parse pFieldSelect "" "name!hint" -- Left (line 1, column 5): -- unexpected '!' --- expecting letter, digit, "-", "->>", "->", "::", ")", "," or end of input +-- expecting letter, digit, "-", "->>", "->", "::", ".", ")", "," or end of input -- -- >>> P.parse pFieldSelect "" "*!hint" -- Left (line 1, column 2): @@ -495,18 +498,36 @@ pFieldSelect :: Parser SelectItem pFieldSelect = lexeme $ try (do s <- pStar pEnd - return $ SelectField (s, []) Nothing Nothing) + return $ SelectField (s, []) Nothing Nothing Nothing Nothing) + <|> try (do + alias <- optionMaybe ( try(pFieldName <* aliasSeparator) ) + _ <- string "count()" + aggCast' <- optionMaybe (string "::" *> pIdentifier) + pEnd + return $ SelectField ("*", []) (Just Count) (toS <$> aggCast') Nothing alias) <|> do - alias <- optionMaybe ( try(pFieldName <* aliasSeparator) ) - fld <- pField - cast' <- optionMaybe (string "::" *> pIdentifier) + alias <- optionMaybe ( try(pFieldName <* aliasSeparator) ) + fld <- pField + cast' <- optionMaybe (string "::" *> pIdentifier) + agg <- optionMaybe (try (char '.' *> pAggregation <* string "()")) + aggCast' <- optionMaybe (string "::" *> pIdentifier) pEnd - return $ SelectField fld (toS <$> cast') alias + return $ SelectField fld agg (toS <$> aggCast') (toS <$> cast') alias where pEnd = try (void $ lookAhead (string ")")) <|> try (void $ lookAhead (string ",")) <|> try eof pStar = string "*" $> "*" + pAggregation = choice + [ string "sum" $> Sum + , string "avg" $> Avg + , string "count" $> Count + -- Using 'try' for "min" and "max" to allow backtracking. + -- This is necessary because both start with the same character 'm', + -- and without 'try', a partial match on "max" would prevent "min" from being tried. + , try (string "max") $> Max + , try (string "min") $> Min + ] -- | diff --git a/src/PostgREST/ApiRequest/Types.hs b/src/PostgREST/ApiRequest/Types.hs index e09d57db8d..2fc73b0458 100644 --- a/src/PostgREST/ApiRequest/Types.hs +++ b/src/PostgREST/ApiRequest/Types.hs @@ -1,6 +1,7 @@ {-# LANGUAGE DuplicateRecordFields #-} module PostgREST.ApiRequest.Types - ( Alias + ( AggregateFunction(..) + , Alias , Cast , Depth , EmbedParam(..) @@ -42,12 +43,14 @@ import PostgREST.SchemaCache.Routine (Routine (..)) import Protolude --- | The value in `/tbl?select=alias:field::cast` +-- | The value in `/tbl?select=alias:field.aggregateFunction()::cast` data SelectItem = SelectField - { selField :: Field - , selCast :: Maybe Cast - , selAlias :: Maybe Alias + { selField :: Field + , selAggregateFunction :: Maybe AggregateFunction + , selAggregateCast :: Maybe Cast + , selCast :: Maybe Cast + , selAlias :: Maybe Alias } -- | The value in `/tbl?select=alias:another_tbl(*)` | SelectRelation @@ -65,7 +68,8 @@ data SelectItem deriving (Eq, Show) data ApiRequestError - = AmbiguousRelBetween Text Text [Relationship] + = AggregatesNotAllowed + | AmbiguousRelBetween Text Text [Relationship] | AmbiguousRpc [Routine] | BinaryFieldError MediaType | MediaTypeError [ByteString] @@ -135,6 +139,9 @@ type Cast = Text type Alias = Text type Hint = Text +data AggregateFunction = Sum | Avg | Max | Min | Count + deriving (Show, Eq) + data EmbedParam -- | Disambiguates an embedding operation when there's multiple relationships -- between two tables. Can be the name of a foreign key constraint, column diff --git a/src/PostgREST/Config.hs b/src/PostgREST/Config.hs index af1238d7f6..8b4411611f 100644 --- a/src/PostgREST/Config.hs +++ b/src/PostgREST/Config.hs @@ -69,6 +69,7 @@ import Protolude hiding (Proxy, toList) data AppConfig = AppConfig { configAppSettings :: [(Text, Text)] + , configDbAggregates :: Bool , configDbAnonRole :: Maybe BS.ByteString , configDbChannel :: Text , configDbChannelEnabled :: Bool @@ -138,7 +139,8 @@ toText conf = where -- apply conf to all pgrst settings pgrstSettings = (\(k, v) -> (k, v conf)) <$> - [("db-anon-role", q . T.decodeUtf8 . fromMaybe "" . configDbAnonRole) + [("db-aggregates-enabled", T.toLower . show . configDbAggregates) + ,("db-anon-role", q . T.decodeUtf8 . fromMaybe "" . configDbAnonRole) ,("db-channel", q . configDbChannel) ,("db-channel-enabled", T.toLower . show . configDbChannelEnabled) ,("db-extra-search-path", q . T.intercalate "," . configDbExtraSearchPath) @@ -232,6 +234,7 @@ parser :: Maybe FilePath -> Environment -> [(Text, Text)] -> RoleSettings -> Rol parser optPath env dbSettings roleSettings roleIsolationLvl = AppConfig <$> parseAppSettings "app.settings" + <*> (fromMaybe False <$> optBool "db-aggregates-enabled") <*> (fmap encodeUtf8 <$> optString "db-anon-role") <*> (fromMaybe "pgrst" <$> optString "db-channel") <*> (fromMaybe True <$> optBool "db-channel-enabled") diff --git a/src/PostgREST/Config/Database.hs b/src/PostgREST/Config/Database.hs index 4672289f92..697d06cee4 100644 --- a/src/PostgREST/Config/Database.hs +++ b/src/PostgREST/Config/Database.hs @@ -45,7 +45,8 @@ prefix = "pgrst." dbSettingsNames :: [Text] dbSettingsNames = (prefix <>) <$> - ["db_anon_role" + ["db_aggregates_enabled" + ,"db_anon_role" ,"db_pre_config" ,"db_extra_search_path" ,"db_max_rows" diff --git a/src/PostgREST/Error.hs b/src/PostgREST/Error.hs index b05e7744a6..90c3ea0a8e 100644 --- a/src/PostgREST/Error.hs +++ b/src/PostgREST/Error.hs @@ -61,6 +61,7 @@ class (JSON.ToJSON a) => PgrstError a where responseLBS (status err) (baseHeader : headers err) $ errorPayload err instance PgrstError ApiRequestError where + status AggregatesNotAllowed{} = HTTP.status400 status AmbiguousRelBetween{} = HTTP.status300 status AmbiguousRpc{} = HTTP.status300 status BinaryFieldError{} = HTTP.status406 @@ -198,6 +199,9 @@ instance JSON.ToJSON ApiRequestError where (Just $ JSON.String $ T.decodeUtf8 ("Invalid preferences: " <> BS.intercalate ", " prefs)) Nothing + toJSON AggregatesNotAllowed = toJsonPgrstError + ApiRequestErrorCode23 "Use of aggregate functions is not allowed" Nothing Nothing + toJSON (NoRelBetween parent child embedHint schema allRels) = toJsonPgrstError SchemaCacheErrorCode00 ("Could not find a relationship between '" <> parent <> "' and '" <> child <> "' in the schema cache") @@ -604,6 +608,7 @@ data ErrorCode | ApiRequestErrorCode20 | ApiRequestErrorCode21 | ApiRequestErrorCode22 + | ApiRequestErrorCode23 -- Schema Cache errors | SchemaCacheErrorCode00 | SchemaCacheErrorCode01 @@ -652,6 +657,7 @@ buildErrorCode code = "PGRST" <> case code of ApiRequestErrorCode20 -> "120" ApiRequestErrorCode21 -> "121" ApiRequestErrorCode22 -> "122" + ApiRequestErrorCode23 -> "123" SchemaCacheErrorCode00 -> "200" SchemaCacheErrorCode01 -> "201" diff --git a/src/PostgREST/Plan.hs b/src/PostgREST/Plan.hs index 4948b69525..68cadeda63 100644 --- a/src/PostgREST/Plan.hs +++ b/src/PostgREST/Plan.hs @@ -34,7 +34,7 @@ import qualified Data.Set as S import qualified PostgREST.SchemaCache.Routine as Routine import Data.Either.Combinators (mapLeft, mapRight) -import Data.List (delete) +import Data.List (delete, lookup) import Data.Tree (Tree (..)) import PostgREST.ApiRequest (Action (..), @@ -296,18 +296,21 @@ resolveQueryInputField ctx field = withTextParse ctx $ resolveTypeOrUnknown ctx -- | Adds filters, order, limits on its respective nodes. -- | Adds joins conditions obtained from resource embedding. readPlan :: QualifiedIdentifier -> AppConfig -> SchemaCache -> ApiRequest -> Either Error ReadPlanTree -readPlan qi@QualifiedIdentifier{..} AppConfig{configDbMaxRows} SchemaCache{dbTables, dbRelationships, dbRepresentations} apiRequest = +readPlan qi@QualifiedIdentifier{..} AppConfig{configDbMaxRows, configDbAggregates} SchemaCache{dbTables, dbRelationships, dbRepresentations} apiRequest = let -- JSON output format hardcoded for now. In the future we might want to support other output mappings such as CSV. ctx = ResolverContext dbTables dbRepresentations qi "json" in mapLeft ApiRequestError $ treeRestrictRange configDbMaxRows (iAction apiRequest) =<< + validateAggFunctions configDbAggregates =<< + hoistSpreadAggFunctions =<< + addRelSelects =<< addNullEmbedFilters =<< validateSpreadEmbeds =<< addRelatedOrders =<< - addDataRepresentationAliases =<< - expandStarsForDataRepresentations ctx =<< + addAliases =<< + expandStars ctx =<< addRels qiSchema (iAction apiRequest) dbRelationships Nothing =<< addLogicTrees ctx apiRequest =<< addRanges apiRequest =<< @@ -320,7 +323,7 @@ initReadRequest ctx@ResolverContext{qi=QualifiedIdentifier{..}} = foldr (treeEntry rootDepth) $ Node defReadPlan{from=qi ctx, relName=qiName, depth=rootDepth} [] where rootDepth = 0 - defReadPlan = ReadPlan [] (QualifiedIdentifier mempty mempty) Nothing [] [] allRange mempty Nothing [] Nothing mempty Nothing Nothing False rootDepth + defReadPlan = ReadPlan [] (QualifiedIdentifier mempty mempty) Nothing [] [] allRange mempty Nothing [] Nothing mempty Nothing Nothing False [] rootDepth treeEntry :: Depth -> Tree SelectItem -> ReadPlanTree -> ReadPlanTree treeEntry depth (Node si fldForest) (Node q rForest) = let nxtDepth = succ depth in @@ -336,49 +339,86 @@ initReadRequest ctx@ResolverContext{qi=QualifiedIdentifier{..}} = (Node defReadPlan{from=QualifiedIdentifier qiSchema selRelation, relName=selRelation, relHint=selHint, relJoinType=selJoinType, depth=nxtDepth, relIsSpread=True} []) fldForest:rForest SelectField{..} -> - Node q{select=(resolveOutputField ctx{qi=from q} selField, selCast, selAlias):select q} rForest + Node q{select=CoercibleSelectField (resolveOutputField ctx{qi=from q} selField) selAggregateFunction selAggregateCast selCast selAlias:select q} rForest --- | Preserve the original field name if data representation is used to coerce the value. -addDataRepresentationAliases :: ReadPlanTree -> Either ApiRequestError ReadPlanTree -addDataRepresentationAliases rPlanTree = Right $ fmap (\rPlan@ReadPlan{select=sel} -> rPlan{select=map aliasSelectItem sel}) rPlanTree +-- If an alias is explicitly specified, it is always respected. However, an alias may be +-- determined automatically in the case of a select term with a JSON path, or in the case +-- of domain representations. +addAliases :: ReadPlanTree -> Either ApiRequestError ReadPlanTree +addAliases = Right . fmap addAliasToPlan where - aliasSelectItem :: (CoercibleField, Maybe Cast, Maybe Alias) -> (CoercibleField, Maybe Cast, Maybe Alias) - -- If there already is an alias, don't overwrite it. - aliasSelectItem (fld@(CoercibleField{cfName=fieldName, cfTransform=(Just _)}), Nothing, Nothing) = (fld, Nothing, Just fieldName) - aliasSelectItem fld = fld + addAliasToPlan rp@ReadPlan{select=sel} = rp{select=map aliasSelectField sel} + + aliasSelectField :: CoercibleSelectField -> CoercibleSelectField + aliasSelectField field@CoercibleSelectField{csField=fieldDetails, csAggFunction=aggFun, csAlias=alias} + | isJust alias || isJust aggFun = field + | isJsonKeyPath fieldDetails, Just key <- lastJsonKey fieldDetails = field { csAlias = Just key } + | isTransformPath fieldDetails = field { csAlias = Just (cfName fieldDetails) } + | otherwise = field + + isJsonKeyPath CoercibleField{cfJsonPath=(_: _)} = True + isJsonKeyPath _ = False + + isTransformPath CoercibleField{cfTransform=(Just _), cfName=_} = True + isTransformPath _ = False + + lastJsonKey CoercibleField{cfName=fieldName, cfJsonPath=jsonPath} = + case jOp <$> lastMay jsonPath of + Just (JKey key) -> Just key + Just (JIdx _) -> Just $ fromMaybe fieldName lastKey + -- We get the lastKey because on: + -- `select=data->1->mycol->>2`, we need to show the result as [ {"mycol": ..}, {"mycol": ..} ] + -- `select=data->3`, we need to show the result as [ {"data": ..}, {"data": ..} ] + where lastKey = jVal <$> find (\case JKey{} -> True; _ -> False) (jOp <$> reverse jsonPath) + Nothing -> Nothing knownColumnsInContext :: ResolverContext -> [Column] knownColumnsInContext ResolverContext{..} = fromMaybe [] $ HM.lookup qi tables >>= Just . tableColumnsList --- | Expand "select *" into explicit field names of the table, if necessary to apply data representations. -expandStarsForDataRepresentations :: ResolverContext -> ReadPlanTree -> Either ApiRequestError ReadPlanTree -expandStarsForDataRepresentations ctx@ResolverContext{qi} rPlanTree = Right $ fmap expandStars rPlanTree +-- | Expand "select *" into explicit field names of the table in the following situations: +-- * When there are data representations present. +-- * When there is an aggregate function in a given ReadPlan or its parent. +expandStars :: ResolverContext -> ReadPlanTree -> Either ApiRequestError ReadPlanTree +expandStars ctx rPlanTree = Right $ expandStarsForReadPlan False rPlanTree where - expandStars :: ReadPlan -> ReadPlan + expandStarsForReadPlan :: Bool -> ReadPlanTree -> ReadPlanTree + expandStarsForReadPlan hasAgg (Node rp@ReadPlan{select, from=fromQI, fromAlias=alias} children) = + let + newHasAgg = hasAgg || any (isJust . csAggFunction) select + newCtx = adjustContext ctx fromQI alias + newRPlan = expandStarsForTable newCtx newHasAgg rp + in Node newRPlan (map (expandStarsForReadPlan newHasAgg) children) + + -- Choose the appropriate context based on whether we're dealing with "pgrst_source" + adjustContext :: ResolverContext -> QualifiedIdentifier -> Maybe Text -> ResolverContext -- When the schema is "" and the table is the source CTE, we assume the true source table is given in the from -- alias and belongs to the request schema. See the bit in `addRels` with `newFrom = ...`. - expandStars rPlan@ReadPlan{from=(QualifiedIdentifier "" "pgrst_source"), fromAlias=(Just tblAlias)} = - expandStarsForTable ctx{qi=qi{qiName=tblAlias}} rPlan - expandStars rPlan@ReadPlan{from=fromTable} = - expandStarsForTable ctx{qi=fromTable} rPlan - -expandStarsForTable :: ResolverContext -> ReadPlan -> ReadPlan -expandStarsForTable ctx@ResolverContext{representations, outputType} rplan@ReadPlan{select=selectItems} = - -- If we have a '*' select AND the target table has at least one data representation, expand. - if ("*" `elem` map (\(field, _, _) -> cfName field) selectItems) && any hasOutputRep knownColumns - then rplan{select=concatMap (expandStarSelectItem knownColumns) selectItems} - else rplan + adjustContext context@ResolverContext{qi=ctxQI} (QualifiedIdentifier "" "pgrst_source") (Just a) = context{qi=ctxQI{qiName=a}} + adjustContext context fromQI _ = context{qi=fromQI} + +expandStarsForTable :: ResolverContext -> Bool -> ReadPlan -> ReadPlan +expandStarsForTable ctx@ResolverContext{representations, outputType} hasAgg rp@ReadPlan{select=selectFields} + -- We expand if either of the below are true: + -- * We have a '*' select AND there is an aggregate function in this ReadPlan's sub-tree. + -- * We have a '*' select AND the target table has at least one data representation. + -- We ignore any '*' selects that have an aggregate function attached (i.e for COUNT(*)). + | hasStarSelect && (hasAgg || hasDataRepresentation) = rp{select = concatMap (expandStarSelectField knownColumns) selectFields} + | otherwise = rp where + hasStarSelect = "*" `elem` map (cfName . csField) filteredSelectFields + filteredSelectFields = filter (isNothing . csAggFunction) selectFields + hasDataRepresentation = any hasOutputRep knownColumns knownColumns = knownColumnsInContext ctx hasOutputRep :: Column -> Bool hasOutputRep col = HM.member (colNominalType col, outputType) representations - expandStarSelectItem :: [Column] -> (CoercibleField, Maybe Cast, Maybe Alias) -> [(CoercibleField, Maybe Cast, Maybe Alias)] - expandStarSelectItem columns (CoercibleField{cfName="*", cfJsonPath=[]}, b, c) = map (\col -> (withOutputFormat ctx $ resolveColumnField col, b, c)) columns - expandStarSelectItem _ selectItem = [selectItem] + expandStarSelectField :: [Column] -> CoercibleSelectField -> [CoercibleSelectField] + expandStarSelectField columns sel@CoercibleSelectField{csField=CoercibleField{cfName="*", cfJsonPath=[]}, csAggFunction=Nothing} = + map (\col -> sel { csField = withOutputFormat ctx $ resolveColumnField col }) columns + expandStarSelectField _ selectField = [selectField] -- | Enforces the `max-rows` config on the result treeRestrictRange :: Maybe Integer -> Action -> ReadPlanTree -> Either ApiRequestError ReadPlanTree @@ -535,6 +575,123 @@ findRel schema allRels origin target hint = ) ) $ fromMaybe mempty $ HM.lookup (QualifiedIdentifier schema origin, schema) allRels + +addRelSelects :: ReadPlanTree -> Either ApiRequestError ReadPlanTree +addRelSelects node@(Node rp forest) + | null forest = Right node + | otherwise = + let newForest = rights $ addRelSelects <$> forest + newRelSelects = mapMaybe generateRelSelectField newForest + in Right $ Node rp { relSelect = newRelSelects } newForest + +generateRelSelectField :: ReadPlanTree -> Maybe RelSelectField +generateRelSelectField (Node rp@ReadPlan{relToParent=Just _, relAggAlias, relIsSpread = True} _) = + Just $ Spread { rsSpreadSel = generateSpreadSelectFields rp, rsAggAlias = relAggAlias } +generateRelSelectField (Node ReadPlan{relToParent=Just rel, select, relName, relAlias, relAggAlias, relIsSpread = False} forest) = + Just $ JsonEmbed { rsEmbedMode, rsSelName, rsAggAlias = relAggAlias, rsEmptyEmbed } + where + rsSelName = fromMaybe relName relAlias + rsEmbedMode = if relIsToOne rel then JsonObject else JsonArray + rsEmptyEmbed = null select && null forest +generateRelSelectField _ = Nothing + +generateSpreadSelectFields :: ReadPlan -> [SpreadSelectField] +generateSpreadSelectFields ReadPlan{select, relSelect} = + -- We combine the select and relSelect fields into a single list of SpreadSelectField. + selectSpread ++ relSelectSpread + where + selectSpread = map selectToSpread select + selectToSpread :: CoercibleSelectField -> SpreadSelectField + selectToSpread CoercibleSelectField{csField = CoercibleField{cfName}, csAlias} = + SpreadSelectField { ssSelName = fromMaybe cfName csAlias, ssSelAggFunction = Nothing, ssSelAggCast = Nothing, ssSelAlias = Nothing } + + relSelectSpread = concatMap relSelectToSpread relSelect + relSelectToSpread :: RelSelectField -> [SpreadSelectField] + relSelectToSpread (JsonEmbed{rsSelName}) = + [SpreadSelectField { ssSelName = rsSelName, ssSelAggFunction = Nothing, ssSelAggCast = Nothing, ssSelAlias = Nothing }] + relSelectToSpread (Spread{rsSpreadSel}) = + rsSpreadSel + +-- When aggregates are present in a ReadPlan that will be spread, we "hoist" +-- to the highest level possible so that their semantics make sense. For instance, +-- imagine the user performs the following request: +-- `GET /projects?select=client_id,...project_invoices(invoice_total.sum())` +-- +-- In this case, it is sensible that we would expect to receive the sum of the +-- `invoice_total`, grouped by the `client_id`. Without hoisting, the sum would +-- be performed in the sub-query for the joined table `project_invoices`, thus +-- making it essentially a no-op. With hoisting, we hoist the aggregate function +-- so that the aggregate function is performed in a more sensible context. +-- +-- We will try to hoist the aggregate function to the highest possible level, +-- which means that we hoist until we reach the root node, or until we reach a +-- ReadPlan that will be embedded a JSON object or JSON array. + +-- This type alias represents an aggregate that is to be hoisted to the next +-- level up. The first tuple of `Alias` and `FieldName` contain the alias for +-- the joined table and the original field name for the hoisted field. +-- +-- The second tuple contains the aggregate function to be applied, the cast, and +-- the alias, if it was supplied by the user or otherwise determined. +type HoistedAgg = ((Alias, FieldName), (AggregateFunction, Maybe Cast, Maybe Alias)) + +hoistSpreadAggFunctions :: ReadPlanTree -> Either ApiRequestError ReadPlanTree +hoistSpreadAggFunctions tree = Right $ fst $ applySpreadAggHoistingToNode tree + +applySpreadAggHoistingToNode :: ReadPlanTree -> (ReadPlanTree, [HoistedAgg]) +applySpreadAggHoistingToNode (Node rp@ReadPlan{relAggAlias, relToParent, relIsSpread} children) = + let (newChildren, childAggLists) = unzip $ map applySpreadAggHoistingToNode children + allChildAggLists = concat childAggLists + (newSelects, aggList) = if depth rp == 0 || (isJust relToParent && not relIsSpread) + then (select rp, []) + else hoistFromSelectFields relAggAlias (select rp) + + newRelSelects = if null children + then relSelect rp + else map (hoistIntoRelSelectFields allChildAggLists) $ relSelect rp + in (Node rp { select = newSelects, relSelect = newRelSelects } newChildren, aggList) + +-- Hoist aggregate functions from the select list of a ReadPlan, and return the +-- updated select list and the list of hoisted aggregates. +hoistFromSelectFields :: Alias -> [CoercibleSelectField] -> ([CoercibleSelectField], [HoistedAgg]) +hoistFromSelectFields aggAlias fields = + let (newFields, maybeAggs) = foldr processField ([], []) fields + in (newFields, catMaybes maybeAggs) + where + processField field (newFields, aggList) = + let (modifiedField, maybeAgg) = modifyField field + in (modifiedField : newFields, maybeAgg : aggList) + + modifyField field = + case csAggFunction field of + Just aggFunc -> + ( field { csAggFunction = Nothing, csAggCast = Nothing }, + Just ((aggAlias, determineFieldName field), (aggFunc, csAggCast field, csAlias field))) + Nothing -> (field, Nothing) + + determineFieldName field = fromMaybe (cfName $ csField field) (csAlias field) + +-- Taking the hoisted aggregates, modify the rel selects to apply the aggregates, +-- and any applicable casts or aliases. +hoistIntoRelSelectFields :: [HoistedAgg] -> RelSelectField -> RelSelectField +hoistIntoRelSelectFields aggList r@(Spread {rsSpreadSel = spreadSelects, rsAggAlias = aggAlias}) = + r { rsSpreadSel = map updateSelect spreadSelects } + where + updateSelect s = + case lookup (aggAlias, ssSelName s) aggList of + Just (aggFunc, aggCast, fldAlias) -> + s { ssSelAggFunction = Just aggFunc, + ssSelAggCast = aggCast, + ssSelAlias = fldAlias } + Nothing -> s +hoistIntoRelSelectFields _ r = r + +validateAggFunctions :: Bool -> ReadPlanTree -> Either ApiRequestError ReadPlanTree +validateAggFunctions aggFunctionsAllowed (Node rp@ReadPlan {select} forest) + | aggFunctionsAllowed = Node rp <$> traverse (validateAggFunctions aggFunctionsAllowed) forest + | any (isJust . csAggFunction) select = Left AggregatesNotAllowed + | otherwise = Node rp <$> traverse (validateAggFunctions aggFunctionsAllowed) forest + addFilters :: ResolverContext -> ApiRequest -> ReadPlanTree -> Either ApiRequestError ReadPlanTree addFilters ctx ApiRequest{..} rReq = foldr addFilterToNode (Right rReq) flts @@ -608,7 +765,8 @@ addRelatedOrders (Node rp@ReadPlan{order,from} forest) = do -- relName = "projects", -- relToParent = Nothing, -- relJoinConds = [], --- relAlias = Nothing, relAggAlias = "clients_projects_1", relHint = Nothing, relJoinType = Nothing, relIsSpread = False, depth = 1 +-- relAlias = Nothing, relAggAlias = "clients_projects_1", relHint = Nothing, relJoinType = Nothing, relIsSpread = False, depth = 1, +-- relSelect = [] -- }, -- subForest = [] -- } @@ -633,7 +791,8 @@ addRelatedOrders (Node rp@ReadPlan{order,from} forest) = do -- ) -- ], -- order = [], range_ = fullRange, relName = "clients", relToParent = Nothing, relJoinConds = [], relAlias = Nothing, relAggAlias = "", relHint = Nothing, --- relJoinType = Nothing, relIsSpread = False, depth = 0 +-- relJoinType = Nothing, relIsSpread = False, depth = 0, +-- relSelect = [] -- }, -- subForest = subForst -- } @@ -789,7 +948,7 @@ inferColsEmbedNeeds (Node ReadPlan{select} forest) pkCols | "*" `elem` fldNames = ["*"] | otherwise = returnings where - fldNames = cfName . (\(f, _, _) -> f) <$> select + fldNames = cfName . csField <$> select -- Without fkCols, when a mutatePlan to -- /projects?select=name,clients(name) occurs, the RETURNING SQL part would -- be `RETURNING name`(see QueryBuilder). This would make the embedding diff --git a/src/PostgREST/Plan/ReadPlan.hs b/src/PostgREST/Plan/ReadPlan.hs index f0de4430a4..854cf1ffa7 100644 --- a/src/PostgREST/Plan/ReadPlan.hs +++ b/src/PostgREST/Plan/ReadPlan.hs @@ -6,11 +6,12 @@ module PostgREST.Plan.ReadPlan import Data.Tree (Tree (..)) -import PostgREST.ApiRequest.Types (Alias, Cast, Depth, Hint, +import PostgREST.ApiRequest.Types (Alias, Depth, Hint, JoinType, NodeName) -import PostgREST.Plan.Types (CoercibleField (..), - CoercibleLogicTree, - CoercibleOrderTerm) +import PostgREST.Plan.Types (CoercibleLogicTree, + CoercibleOrderTerm, + CoercibleSelectField (..), + RelSelectField (..)) import PostgREST.RangeQuery (NonnegRange) import PostgREST.SchemaCache.Identifiers (FieldName, QualifiedIdentifier) @@ -28,7 +29,7 @@ data JoinCondition = deriving (Eq, Show) data ReadPlan = ReadPlan - { select :: [(CoercibleField, Maybe Cast, Maybe Alias)] + { select :: [CoercibleSelectField] , from :: QualifiedIdentifier , fromAlias :: Maybe Alias , where_ :: [CoercibleLogicTree] @@ -42,6 +43,7 @@ data ReadPlan = ReadPlan , relHint :: Maybe Hint , relJoinType :: Maybe JoinType , relIsSpread :: Bool + , relSelect :: [RelSelectField] , depth :: Depth -- ^ used for aliasing } diff --git a/src/PostgREST/Plan/Types.hs b/src/PostgREST/Plan/Types.hs index c9267e3d90..97de469952 100644 --- a/src/PostgREST/Plan/Types.hs +++ b/src/PostgREST/Plan/Types.hs @@ -1,13 +1,18 @@ module PostgREST.Plan.Types ( CoercibleField(..) + , CoercibleSelectField(..) , unknownField , CoercibleLogicTree(..) , CoercibleFilter(..) , TransformerProc , CoercibleOrderTerm(..) + , RelSelectField(..) + , RelJsonEmbedMode(..) + , SpreadSelectField(..) ) where -import PostgREST.ApiRequest.Types (Field, JsonPath, LogicOperator, +import PostgREST.ApiRequest.Types (AggregateFunction, Alias, Cast, + Field, JsonPath, LogicOperator, OpExpr, OrderDirection, OrderNulls) import PostgREST.SchemaCache.Identifiers (FieldName) @@ -65,3 +70,37 @@ data CoercibleOrderTerm , coNullOrder :: Maybe OrderNulls } deriving (Eq, Show) + +data CoercibleSelectField = CoercibleSelectField + { csField :: CoercibleField + , csAggFunction :: Maybe AggregateFunction + , csAggCast :: Maybe Cast + , csCast :: Maybe Cast + , csAlias :: Maybe Alias + } + deriving (Eq, Show) + +data RelJsonEmbedMode = JsonObject | JsonArray + deriving (Show, Eq) + +data RelSelectField + = JsonEmbed + { rsSelName :: FieldName + , rsAggAlias :: Alias + , rsEmbedMode :: RelJsonEmbedMode + , rsEmptyEmbed :: Bool + } + | Spread + { rsSpreadSel :: [SpreadSelectField] + , rsAggAlias :: Alias + } + deriving (Eq, Show) + +data SpreadSelectField = + SpreadSelectField + { ssSelName :: FieldName + , ssSelAggFunction :: Maybe AggregateFunction + , ssSelAggCast :: Maybe Cast + , ssSelAlias :: Maybe Alias + } + deriving (Eq, Show) diff --git a/src/PostgREST/Query/QueryBuilder.hs b/src/PostgREST/Query/QueryBuilder.hs index c81772b810..0d51c55484 100644 --- a/src/PostgREST/Query/QueryBuilder.hs +++ b/src/PostgREST/Query/QueryBuilder.hs @@ -19,7 +19,8 @@ module PostgREST.Query.QueryBuilder import qualified Data.ByteString.Char8 as BS import qualified Hasql.DynamicStatements.Snippet as SQL -import Data.Tree (Tree (..)) +import Data.Maybe (fromJust) +import Data.Tree (Tree (..)) import PostgREST.ApiRequest.Preferences (PreferResolution (..)) import PostgREST.Config.PgVersion (PgVersion, pgVersion110, @@ -27,8 +28,7 @@ import PostgREST.Config.PgVersion (PgVersion, pgVersion110, import PostgREST.SchemaCache.Identifiers (QualifiedIdentifier (..)) import PostgREST.SchemaCache.Relationship (Cardinality (..), Junction (..), - Relationship (..), - relIsToOne) + Relationship (..)) import PostgREST.SchemaCache.Routine (RoutineParam (..)) import PostgREST.ApiRequest.Types @@ -42,45 +42,70 @@ import PostgREST.RangeQuery (allRange) import Protolude readPlanToQuery :: ReadPlanTree -> SQL.Snippet -readPlanToQuery (Node ReadPlan{select,from=mainQi,fromAlias,where_=logicForest,order, range_=readRange, relToParent, relJoinConds} forest) = +readPlanToQuery node@(Node ReadPlan{select,from=mainQi,fromAlias,where_=logicForest,order, range_=readRange, relToParent, relJoinConds, relSelect} forest) = "SELECT " <> - intercalateSnippet ", " ((pgFmtSelectItem qi <$> (if null select && null forest then defSelect else select)) ++ selects) <> " " <> + intercalateSnippet ", " ((pgFmtSelectItem qi <$> (if null select && null forest then defSelect else select)) ++ joinsSelects) <> " " <> fromFrag <> " " <> intercalateSnippet " " joins <> " " <> (if null logicForest && null relJoinConds then mempty else "WHERE " <> intercalateSnippet " AND " (map (pgFmtLogicTree qi) logicForest ++ map pgFmtJoinCondition relJoinConds)) <> " " <> + groupF qi select relSelect <> " " <> orderF qi order <> " " <> limitOffsetF readRange where fromFrag = fromF relToParent mainQi fromAlias qi = getQualifiedIdentifier relToParent mainQi fromAlias - defSelect = [(unknownField "*" [], Nothing, Nothing)] -- gets all the columns in case of an empty select, ignoring/obtaining these columns is done at the aggregation stage - (selects, joins) = foldr getSelectsJoins ([],[]) forest + -- gets all the columns in case of an empty select, ignoring/obtaining these columns is done at the aggregation stage + defSelect = [CoercibleSelectField (unknownField "*" []) Nothing Nothing Nothing Nothing] + joins = getJoins node + joinsSelects = getJoinSelects node -getSelectsJoins :: ReadPlanTree -> ([SQL.Snippet], [SQL.Snippet]) -> ([SQL.Snippet], [SQL.Snippet]) -getSelectsJoins (Node ReadPlan{relToParent=Nothing} _) _ = ([], []) -getSelectsJoins rr@(Node ReadPlan{select, relName, relToParent=Just rel, relAggAlias, relAlias, relJoinType, relIsSpread} forest) (selects,joins) = +getJoinSelects :: ReadPlanTree -> [SQL.Snippet] +getJoinSelects (Node ReadPlan{relSelect} _) = + mapMaybe relSelectToSnippet relSelect + where + relSelectToSnippet :: RelSelectField -> Maybe SQL.Snippet + relSelectToSnippet fld = + let aggAlias = pgFmtIdent $ rsAggAlias fld + in + case fld of + JsonEmbed{rsEmptyEmbed = True} -> + Nothing + JsonEmbed{rsSelName, rsEmbedMode = JsonObject} -> + Just $ "row_to_json(" <> aggAlias <> ".*)::jsonb AS " <> pgFmtIdent rsSelName + JsonEmbed{rsSelName, rsEmbedMode = JsonArray} -> + Just $ "COALESCE( " <> aggAlias <> "." <> aggAlias <> ", '[]') AS " <> pgFmtIdent rsSelName + Spread{rsSpreadSel, rsAggAlias} -> + Just $ intercalateSnippet ", " (pgFmtSpreadSelectItem rsAggAlias <$> rsSpreadSel) + +getJoins :: ReadPlanTree -> [SQL.Snippet] +getJoins (Node _ []) = [] +getJoins (Node ReadPlan{relSelect} forest) = + map (\fld -> + let alias = rsAggAlias fld + matchingNode = fromJust $ find (\(Node ReadPlan{relAggAlias} _) -> alias == relAggAlias) forest + in getJoin fld matchingNode + ) relSelect + +getJoin :: RelSelectField -> ReadPlanTree -> SQL.Snippet +getJoin fld node@(Node ReadPlan{relJoinType} _) = let - subquery = readPlanToQuery rr - aliasOrName = pgFmtIdent $ fromMaybe relName relAlias - aggAlias = pgFmtIdent relAggAlias correlatedSubquery sub al cond = (if relJoinType == Just JTInner then "INNER" else "LEFT") <> " JOIN LATERAL ( " <> sub <> " ) AS " <> al <> " ON " <> cond - (sel, joi) = if relIsToOne rel - then - ( if relIsSpread - then aggAlias <> ".*" - else "row_to_json(" <> aggAlias <> ".*) AS " <> aliasOrName - , correlatedSubquery subquery aggAlias "TRUE") - else - ( "COALESCE( " <> aggAlias <> "." <> aggAlias <> ", '[]') AS " <> aliasOrName - , correlatedSubquery ( - "SELECT json_agg(" <> aggAlias <> ") AS " <> aggAlias <> - "FROM (" <> subquery <> " ) AS " <> aggAlias - ) aggAlias $ if relJoinType == Just JTInner then aggAlias <> " IS NOT NULL" else "TRUE") + subquery = readPlanToQuery node + aggAlias = pgFmtIdent $ rsAggAlias fld in - (if null select && null forest then selects else sel:selects, joi:joins) + case fld of + JsonEmbed{rsEmbedMode = JsonObject} -> + correlatedSubquery subquery aggAlias "TRUE" + Spread{} -> + correlatedSubquery subquery aggAlias "TRUE" + JsonEmbed{rsEmbedMode = JsonArray} -> + let + subq = "SELECT json_agg(" <> aggAlias <> ")::jsonb AS " <> aggAlias <> " FROM (" <> subquery <> " ) AS " <> aggAlias + condition = if relJoinType == Just JTInner then aggAlias <> " IS NOT NULL" else "TRUE" + in correlatedSubquery subq aggAlias condition mutatePlanToQuery :: MutatePlan -> SQL.Snippet mutatePlanToQuery (Insert mainQi iCols body onConflct putConditions returnings _ applyDefaults) = diff --git a/src/PostgREST/Query/SqlFragment.hs b/src/PostgREST/Query/SqlFragment.hs index 94cafa956a..03f6177bd2 100644 --- a/src/PostgREST/Query/SqlFragment.hs +++ b/src/PostgREST/Query/SqlFragment.hs @@ -9,6 +9,7 @@ module PostgREST.Query.SqlFragment ( noLocationF , handlerF , countF + , groupF , fromQi , limitOffsetF , locationF @@ -21,6 +22,7 @@ module PostgREST.Query.SqlFragment , pgFmtLogicTree , pgFmtOrderTerm , pgFmtSelectItem + , pgFmtSpreadSelectItem , fromJsonBodyF , responseHeadersF , responseStatusF @@ -54,7 +56,8 @@ import Control.Arrow ((***)) import Data.Foldable (foldr1) import Text.InterpolatedString.Perl6 (qc) -import PostgREST.ApiRequest.Types (Alias, Cast, +import PostgREST.ApiRequest.Types (AggregateFunction (..), + Alias, Cast, FtsOperator (..), JsonOperand (..), JsonOperation (..), @@ -75,6 +78,9 @@ import PostgREST.Plan.Types (CoercibleField (..), CoercibleFilter (..), CoercibleLogicTree (..), CoercibleOrderTerm (..), + CoercibleSelectField (..), + RelSelectField (..), + SpreadSelectField (..), unknownField) import PostgREST.RangeQuery (NonnegRange, allRange, rangeLimit, rangeOffset) @@ -86,7 +92,7 @@ import PostgREST.SchemaCache.Routine (MediaHandler (..), funcReturnsSetOfScalar, funcReturnsSingleComposite) -import Protolude hiding (cast) +import Protolude hiding (Sum, cast) sourceCTEName :: Text sourceCTEName = "pgrst_source" @@ -258,12 +264,34 @@ pgFmtCoerceNamed :: CoercibleField -> SQL.Snippet pgFmtCoerceNamed CoercibleField{cfName=fn, cfTransform=(Just formatterProc)} = pgFmtCallUnary formatterProc (pgFmtIdent fn) <> " AS " <> pgFmtIdent fn pgFmtCoerceNamed CoercibleField{cfName=fn} = pgFmtIdent fn -pgFmtSelectItem :: QualifiedIdentifier -> (CoercibleField, Maybe Cast, Maybe Alias) -> SQL.Snippet -pgFmtSelectItem table (fld, Nothing, alias) = pgFmtTableCoerce table fld <> pgFmtAs (cfName fld) (cfJsonPath fld) alias +pgFmtSelectItem :: QualifiedIdentifier -> CoercibleSelectField -> SQL.Snippet +pgFmtSelectItem table CoercibleSelectField{csField=fld, csAggFunction=agg, csAggCast=aggCast, csCast=cast, csAlias=alias} = + pgFmtApplyAggregate agg aggCast (pgFmtApplyCast cast (pgFmtTableCoerce table fld)) <> pgFmtAs alias + +pgFmtSpreadSelectItem :: Alias -> SpreadSelectField -> SQL.Snippet +pgFmtSpreadSelectItem aggAlias SpreadSelectField{ssSelName, ssSelAggFunction, ssSelAggCast, ssSelAlias} = + pgFmtApplyAggregate ssSelAggFunction ssSelAggCast fullSelName <> pgFmtAs ssSelAlias + where + fullSelName = case ssSelName of + "*" -> pgFmtIdent aggAlias <> ".*" + _ -> pgFmtIdent aggAlias <> "." <> pgFmtIdent ssSelName + +pgFmtApplyAggregate :: Maybe AggregateFunction -> Maybe Cast -> SQL.Snippet -> SQL.Snippet +pgFmtApplyAggregate Nothing _ snippet = snippet +pgFmtApplyAggregate (Just agg) aggCast snippet = + pgFmtApplyCast aggCast aggregatedSnippet + where + convertAggFunction :: AggregateFunction -> SQL.Snippet + -- Convert from e.g. Sum (the data type) to SUM + convertAggFunction = SQL.sql . BS.map toUpper . BS.pack . show + aggregatedSnippet = convertAggFunction agg <> "(" <> snippet <> ")" + +pgFmtApplyCast :: Maybe Cast -> SQL.Snippet -> SQL.Snippet +pgFmtApplyCast Nothing snippet = snippet -- Ideally we'd quote the cast with "pgFmtIdent cast". However, that would invalidate common casts such as "int", "bigint", etc. -- Try doing: `select 1::"bigint"` - it'll err, using "int8" will work though. There's some parser magic that pg does that's invalidated when quoting. -- Not quoting should be fine, we validate the input on Parsers. -pgFmtSelectItem table (fld, Just cast, alias) = "CAST (" <> pgFmtTableCoerce table fld <> " AS " <> SQL.sql (encodeUtf8 cast) <> " )" <> pgFmtAs (cfName fld) (cfJsonPath fld) alias +pgFmtApplyCast (Just cast) snippet = "CAST( " <> snippet <> " AS " <> SQL.sql (encodeUtf8 cast) <> " )" -- TODO: At this stage there shouldn't be a Maybe since ApiRequest should ensure that an INSERT/UPDATE has a body fromJsonBodyF :: Maybe LBS.ByteString -> [CoercibleField] -> Bool -> Bool -> Bool -> SQL.Snippet @@ -395,17 +423,40 @@ pgFmtJsonPath = \case pgFmtJsonOperand (JKey k) = unknownLiteral k pgFmtJsonOperand (JIdx i) = unknownLiteral i <> "::int" -pgFmtAs :: FieldName -> JsonPath -> Maybe Alias -> SQL.Snippet -pgFmtAs _ [] Nothing = mempty -pgFmtAs fName jp Nothing = case jOp <$> lastMay jp of - Just (JKey key) -> " AS " <> pgFmtIdent key - Just (JIdx _) -> " AS " <> pgFmtIdent (fromMaybe fName lastKey) - -- We get the lastKey because on: - -- `select=data->1->mycol->>2`, we need to show the result as [ {"mycol": ..}, {"mycol": ..} ] - -- `select=data->3`, we need to show the result as [ {"data": ..}, {"data": ..} ] - where lastKey = jVal <$> find (\case JKey{} -> True; _ -> False) (jOp <$> reverse jp) - Nothing -> mempty -pgFmtAs _ _ (Just alias) = " AS " <> pgFmtIdent alias +pgFmtAs :: Maybe Alias -> SQL.Snippet +pgFmtAs Nothing = mempty +pgFmtAs (Just alias) = " AS " <> pgFmtIdent alias + +groupF :: QualifiedIdentifier -> [CoercibleSelectField] -> [RelSelectField] -> SQL.Snippet +groupF qi select relSelect + | (noSelectsAreAggregated && noRelSelectsAreAggregated) || null groupTerms = mempty + | otherwise = " GROUP BY " <> intercalateSnippet ", " groupTerms + where + noSelectsAreAggregated = null $ [s | s@(CoercibleSelectField { csAggFunction = Just _ }) <- select] + noRelSelectsAreAggregated = all (\case Spread sels _ -> all (isNothing . ssSelAggFunction) sels; _ -> True) relSelect + groupTermsFromSelect = mapMaybe (pgFmtGroup qi) select + groupTermsFromRelSelect = mapMaybe groupTermFromRelSelectField relSelect + groupTerms = groupTermsFromSelect ++ groupTermsFromRelSelect + +groupTermFromRelSelectField :: RelSelectField -> Maybe SQL.Snippet +groupTermFromRelSelectField (JsonEmbed { rsSelName }) = + Just $ pgFmtIdent rsSelName +groupTermFromRelSelectField (Spread { rsSpreadSel, rsAggAlias }) = + if null groupTerms + then Nothing + else + Just $ intercalateSnippet ", " groupTerms + where + processField :: SpreadSelectField -> Maybe SQL.Snippet + processField SpreadSelectField{ssSelAggFunction = Just _} = Nothing + processField SpreadSelectField{ssSelName, ssSelAlias} = + Just $ pgFmtIdent rsAggAlias <> "." <> pgFmtIdent (fromMaybe ssSelName ssSelAlias) + groupTerms = mapMaybe processField rsSpreadSel + +pgFmtGroup :: QualifiedIdentifier -> CoercibleSelectField -> Maybe SQL.Snippet +pgFmtGroup _ CoercibleSelectField{csAggFunction=Just _} = Nothing +pgFmtGroup _ CoercibleSelectField{csAlias=Just alias, csAggFunction=Nothing} = Just $ pgFmtIdent alias +pgFmtGroup qi CoercibleSelectField{csField=fld, csAlias=Nothing, csAggFunction=Nothing} = Just $ pgFmtField qi fld countF :: SQL.Snippet -> Bool -> (SQL.Snippet, SQL.Snippet) countF countQuery shouldCount = diff --git a/test/io/configs/expected/aliases.config b/test/io/configs/expected/aliases.config index bb67b29648..bf2df05a11 100644 --- a/test/io/configs/expected/aliases.config +++ b/test/io/configs/expected/aliases.config @@ -1,3 +1,4 @@ +db-aggregates-enabled = false db-anon-role = "" db-channel = "pgrst" db-channel-enabled = true diff --git a/test/io/configs/expected/boolean-numeric.config b/test/io/configs/expected/boolean-numeric.config index 9bd66476dc..1359f09fef 100644 --- a/test/io/configs/expected/boolean-numeric.config +++ b/test/io/configs/expected/boolean-numeric.config @@ -1,3 +1,4 @@ +db-aggregates-enabled = false db-anon-role = "" db-channel = "pgrst" db-channel-enabled = true diff --git a/test/io/configs/expected/boolean-string.config b/test/io/configs/expected/boolean-string.config index 9bd66476dc..1359f09fef 100644 --- a/test/io/configs/expected/boolean-string.config +++ b/test/io/configs/expected/boolean-string.config @@ -1,3 +1,4 @@ +db-aggregates-enabled = false db-anon-role = "" db-channel = "pgrst" db-channel-enabled = true diff --git a/test/io/configs/expected/defaults.config b/test/io/configs/expected/defaults.config index 30168ee149..aefd98aa0d 100644 --- a/test/io/configs/expected/defaults.config +++ b/test/io/configs/expected/defaults.config @@ -1,3 +1,4 @@ +db-aggregates-enabled = false db-anon-role = "" db-channel = "pgrst" db-channel-enabled = true diff --git a/test/io/configs/expected/no-defaults-with-db-other-authenticator.config b/test/io/configs/expected/no-defaults-with-db-other-authenticator.config index e7837011de..b8d0b018b0 100644 --- a/test/io/configs/expected/no-defaults-with-db-other-authenticator.config +++ b/test/io/configs/expected/no-defaults-with-db-other-authenticator.config @@ -1,3 +1,4 @@ +db-aggregates-enabled = false db-anon-role = "pre_config_role" db-channel = "postgrest" db-channel-enabled = false diff --git a/test/io/configs/expected/no-defaults-with-db.config b/test/io/configs/expected/no-defaults-with-db.config index ac4d87ef54..2cb69cb722 100644 --- a/test/io/configs/expected/no-defaults-with-db.config +++ b/test/io/configs/expected/no-defaults-with-db.config @@ -1,3 +1,4 @@ +db-aggregates-enabled = false db-anon-role = "anonymous" db-channel = "postgrest" db-channel-enabled = false diff --git a/test/io/configs/expected/no-defaults.config b/test/io/configs/expected/no-defaults.config index 09dda558f3..2a45b21df0 100644 --- a/test/io/configs/expected/no-defaults.config +++ b/test/io/configs/expected/no-defaults.config @@ -1,3 +1,4 @@ +db-aggregates-enabled = true db-anon-role = "root" db-channel = "postgrest" db-channel-enabled = false diff --git a/test/io/configs/expected/types.config b/test/io/configs/expected/types.config index d7d6429312..e6d0328b65 100644 --- a/test/io/configs/expected/types.config +++ b/test/io/configs/expected/types.config @@ -1,3 +1,4 @@ +db-aggregates-enabled = false db-anon-role = "" db-channel = "pgrst" db-channel-enabled = true diff --git a/test/io/configs/no-defaults-env.yaml b/test/io/configs/no-defaults-env.yaml index 709a149de1..4f03111a28 100644 --- a/test/io/configs/no-defaults-env.yaml +++ b/test/io/configs/no-defaults-env.yaml @@ -1,5 +1,6 @@ PGRST_APP_SETTINGS_test2: test PGRST_APP_SETTINGS_test: test +PGRST_DB_AGGREGATES_ENABLED: true PGRST_DB_ANON_ROLE: root PGRST_DB_CHANNEL: postgrest PGRST_DB_CHANNEL_ENABLED: false diff --git a/test/io/configs/no-defaults.config b/test/io/configs/no-defaults.config index dbd18aee27..d9944c352a 100644 --- a/test/io/configs/no-defaults.config +++ b/test/io/configs/no-defaults.config @@ -1,3 +1,4 @@ +db-aggregates-enabled = true db-anon-role = "root" db-channel = "postgrest" db-channel-enabled = false diff --git a/test/io/db_config.sql b/test/io/db_config.sql index 78bcd0ae03..265b403178 100644 --- a/test/io/db_config.sql +++ b/test/io/db_config.sql @@ -6,6 +6,7 @@ ALTER ROLE db_config_authenticator SET pgrst.openapi_server_proxy_uri = 'https:/ ALTER ROLE db_config_authenticator SET pgrst.jwt_secret = 'REALLY=REALLY=REALLY=REALLY=VERY=SAFE'; ALTER ROLE db_config_authenticator SET pgrst.jwt_secret_is_base64 = 'false'; ALTER ROLE db_config_authenticator SET pgrst.jwt_role_claim_key = '."a"."role"'; +ALTER ROLE db_config_authenticator SET pgrst.db_aggregates_enabled = 'false'; ALTER ROLE db_config_authenticator SET pgrst.db_anon_role = 'anonymous'; ALTER ROLE db_config_authenticator SET pgrst.db_tx_end = 'commit-allow-override'; ALTER ROLE db_config_authenticator SET pgrst.db_pre_config = 'postgrest.preconf'; @@ -53,6 +54,7 @@ ALTER ROLE other_authenticator SET pgrst.jwt_aud = 'https://otherexample.org'; ALTER ROLE other_authenticator SET pgrst.openapi_server_proxy_uri = 'https://otherexample.org/api'; ALTER ROLE other_authenticator SET pgrst.jwt_secret = 'ODERREALLYREALLYREALLYREALLYVERYSAFE'; ALTER ROLE other_authenticator SET pgrst.jwt_secret_is_base64 = 'true'; +ALTER ROLE other_authenticator SET pgrst.db_aggregates_enabled = 'false'; ALTER ROLE other_authenticator SET pgrst.db_schemas = 'test, other_tenant1, other_tenant2'; ALTER ROLE other_authenticator SET pgrst.db_root_spec = 'other_root'; ALTER ROLE other_authenticator SET pgrst.db_plan_enabled = 'true'; diff --git a/test/spec/Feature/Query/AggregateFunctionsSpec.hs b/test/spec/Feature/Query/AggregateFunctionsSpec.hs new file mode 100644 index 0000000000..5030fa3f2b --- /dev/null +++ b/test/spec/Feature/Query/AggregateFunctionsSpec.hs @@ -0,0 +1,168 @@ +module Feature.Query.AggregateFunctionsSpec where + +import Network.Wai (Application) + +import Test.Hspec hiding (pendingWith) +import Test.Hspec.Wai +import Test.Hspec.Wai.JSON + +import Protolude hiding (get) +import SpecHelper + +allowed :: SpecWith ((), Application) +allowed = + describe "aggregate functions" $ do + context "performing a count without specifying a field" $ do + it "returns the count of all rows when no other fields are selected" $ + get "/entities?select=count()" `shouldRespondWith` + [json|[{ "count": 4 }]|] { matchHeaders = [matchContentTypeJson] } + it "allows you to specify an alias for the count" $ + get "/entities?select=cnt:count()" `shouldRespondWith` + [json|[{ "cnt": 4 }]|] { matchHeaders = [matchContentTypeJson] } + it "allows you to cast the result of the count" $ + get "/entities?select=count()::text" `shouldRespondWith` + [json|[{ "count": "4" }]|] { matchHeaders = [matchContentTypeJson] } + it "returns the count grouped by all provided fields when other fields are selected" $ + get "/projects?select=c:count(),client_id&order=client_id.desc" `shouldRespondWith` + [json|[{ "c": 1, "client_id": null }, { "c": 2, "client_id": 2 }, { "c": 2, "client_id": 1}]|] { matchHeaders = [matchContentTypeJson] } + + context "performing a count by using it as a column (backwards compat)" $ do + it "returns the count of all rows when no other fields are selected" $ + get "/entities?select=count" `shouldRespondWith` + [json|[{ "count": 4 }]|] { matchHeaders = [matchContentTypeJson] } + it "returns the embedded count of another resource" $ + get "/clients?select=name,projects(count)'" `shouldRespondWith` + [json|[{"name":"Microsoft","projects":[{"count": 2}]}, {"name":"Apple","projects":[{"count": 2}]}]|] { matchHeaders = [matchContentTypeJson] } + + context "performing an aggregation on one or more fields" $ do + it "supports sum()" $ + get "/project_invoices?select=invoice_total.sum()" `shouldRespondWith` + [json|[{"sum":8800}]|] { matchHeaders = [matchContentTypeJson] } + it "supports avg()" $ + get "/project_invoices?select=invoice_total.avg()" `shouldRespondWith` + [json|[{"avg":1100.0000000000000000}]|] { matchHeaders = [matchContentTypeJson] } + it "supports min()" $ + get "/project_invoices?select=invoice_total.min()" `shouldRespondWith` + [json|[{ "min": 100 }]|] { matchHeaders = [matchContentTypeJson] } + it "supports max()" $ + get "/project_invoices?select=invoice_total.max()" `shouldRespondWith` + [json|[{ "max": 4000 }]|] { matchHeaders = [matchContentTypeJson] } + it "supports count()" $ + get "/project_invoices?select=invoice_total.count()" `shouldRespondWith` + [json|[{ "count": 8 }]|] { matchHeaders = [matchContentTypeJson] } + it "groups by any fields selected that do not have an aggregate applied" $ + get "/project_invoices?select=invoice_total.sum(),invoice_total.max(),invoice_total.min(),project_id&order=project_id.desc" `shouldRespondWith` + [json|[ + {"sum":4100,"max":4000,"min":100,"project_id":4}, + {"sum":3200,"max":2000,"min":1200,"project_id":3}, + {"sum":1200,"max":700,"min":500,"project_id":2}, + {"sum":300,"max":200,"min":100,"project_id":1} ]|] + { matchHeaders = [matchContentTypeJson] } + it "supports the use of aliases on fields that will be used in the group by" $ + get "/project_invoices?select=invoice_total.sum(),invoice_total.max(),invoice_total.min(),pid:project_id&order=project_id.desc" `shouldRespondWith` + [json|[ + {"sum":4100,"max":4000,"min":100,"pid":4}, + {"sum":3200,"max":2000,"min":1200,"pid":3}, + {"sum":1200,"max":700,"min":500,"pid":2}, + {"sum":300,"max":200,"min":100,"pid":1}]|] + { matchHeaders = [matchContentTypeJson] } + it "allows you to specify an alias for the aggregate" $ + get "/project_invoices?select=total_charged:invoice_total.sum(),project_id&order=project_id.desc" `shouldRespondWith` + [json|[ + {"total_charged":4100,"project_id":4}, + {"total_charged":3200,"project_id":3}, + {"total_charged":1200,"project_id":2}, + {"total_charged":300,"project_id":1}]|] { matchHeaders = [matchContentTypeJson] } + it "allows you to cast the result of the aggregate" $ + get "/project_invoices?select=total_charged:invoice_total.sum()::text,project_id&order=project_id.desc" `shouldRespondWith` + [json|[ + {"total_charged":"4100","project_id":4}, + {"total_charged":"3200","project_id":3}, + {"total_charged":"1200","project_id":2}, + {"total_charged":"300","project_id":1}]|] { matchHeaders = [matchContentTypeJson] } + it "allows you to cast the input argument of the aggregate" $ + get "/trash_details?select=jsonb_col->>key::integer.sum()" `shouldRespondWith` + [json|[{"sum": 24}]|] { matchHeaders = [matchContentTypeJson] } + it "allows the combination of an alias, a before cast, and an after cast" $ + get "/trash_details?select=s:jsonb_col->>key::integer.sum()::text" `shouldRespondWith` + [json|[{"s": "24"}]|] { matchHeaders = [matchContentTypeJson] } + it "supports use of aggregates on RPC functions that return table values" $ + get "/rpc/getallprojects?select=id.max()" `shouldRespondWith` + [json|[{"max": 5}]|] { matchHeaders = [matchContentTypeJson] } + it "allows the use of an JSON-embedded relationship column as part of the group by" $ + get "/project_invoices?select=project_id,total:invoice_total.sum(),projects(name)&order=project_id" `shouldRespondWith` + [json|[ + {"project_id": 1, "total": 300, "projects": {"name": "Windows 7"}}, + {"project_id": 2, "total": 1200, "projects": {"name": "Windows 10"}}, + {"project_id": 3, "total": 3200, "projects": {"name": "IOS"}}, + {"project_id": 4, "total": 4100, "projects": {"name": "OSX"}}]|] { matchHeaders = [matchContentTypeJson] } + context "performing aggregations that involve JSON-embedded relationships" $ do + it "supports sum()" $ + get "/projects?select=name,project_invoices(invoice_total.sum())" `shouldRespondWith` + [json|[ + {"name":"Windows 7","project_invoices":[{"sum": 300}]}, + {"name":"Windows 10","project_invoices":[{"sum": 1200}]}, + {"name":"IOS","project_invoices":[{"sum": 3200}]}, + {"name":"OSX","project_invoices":[{"sum": 4100}]}, + {"name":"Orphan","project_invoices":[{"sum": null}]}]|] + { matchHeaders = [matchContentTypeJson] } + it "supports max()" $ + get "/projects?select=name,project_invoices(invoice_total.max())" `shouldRespondWith` + [json|[{"name":"Windows 7","project_invoices":[{"max": 200}]}, + {"name":"Windows 10","project_invoices":[{"max": 700}]}, + {"name":"IOS","project_invoices":[{"max": 2000}]}, + {"name":"OSX","project_invoices":[{"max": 4000}]}, + {"name":"Orphan","project_invoices":[{"max": null}]}]|] + { matchHeaders = [matchContentTypeJson] } + it "supports avg()" $ + get "/projects?select=name,project_invoices(invoice_total.avg())" `shouldRespondWith` + [json|[{"name":"Windows 7","project_invoices":[{"avg": 150.0000000000000000}]}, + {"name":"Windows 10","project_invoices":[{"avg": 600.0000000000000000}]}, + {"name":"IOS","project_invoices":[{"avg": 1600.0000000000000000}]}, + {"name":"OSX","project_invoices":[{"avg": 2050.0000000000000000}]}, + {"name":"Orphan","project_invoices":[{"avg": null}]}]|] + { matchHeaders = [matchContentTypeJson] } + it "supports min()" $ + get "/projects?select=name,project_invoices(invoice_total.min())" `shouldRespondWith` + [json|[{"name":"Windows 7","project_invoices":[{"min": 100}]}, + {"name":"Windows 10","project_invoices":[{"min": 500}]}, + {"name":"IOS","project_invoices":[{"min": 1200}]}, + {"name":"OSX","project_invoices":[{"min": 100}]}, + {"name":"Orphan","project_invoices":[{"min": null}]}]|] + { matchHeaders = [matchContentTypeJson] } + it "supports all at once" $ + get "/projects?select=name,project_invoices(invoice_total.max(),invoice_total.min(),invoice_total.avg(),invoice_total.sum(),invoice_total.count())" `shouldRespondWith` + [json|[ + {"name":"Windows 7","project_invoices":[{"avg": 150.0000000000000000, "max": 200, "min": 100, "sum": 300, "count": 2}]}, + {"name":"Windows 10","project_invoices":[{"avg": 600.0000000000000000, "max": 700, "min": 500, "sum": 1200, "count": 2}]}, + {"name":"IOS","project_invoices":[{"avg": 1600.0000000000000000, "max": 2000, "min": 1200, "sum": 3200, "count": 2}]}, + {"name":"OSX","project_invoices":[{"avg": 2050.0000000000000000, "max": 4000, "min": 100, "sum": 4100, "count": 2}]}, + {"name":"Orphan","project_invoices":[{"avg": null, "max": null, "min": null, "sum": null, "count": 0}]}]|] + { matchHeaders = [matchContentTypeJson] } + + context "performing aggregations on spreaded fields from an embedded resource" $ do + it "supports the use of aggregates on spreaded fields" $ do + get "/budget_expenses?select=total_expenses:expense_amount.sum(),...budget_categories(budget_owner,total_budget:budget_amount.sum())&order=budget_categories(budget_owner)" `shouldRespondWith` + [json|[ + {"total_expenses": 600.52,"budget_owner": "Brian Smith", "total_budget": 2000.42}, + {"total_expenses": 100.22, "budget_owner": "Jane Clarkson","total_budget": 7000.41}, + {"total_expenses": 900.27, "budget_owner": "Sally Hughes", "total_budget": 500.23}]|] + { matchHeaders = [matchContentTypeJson] } + it "supports the use of aggregates on spreaded fields when only aggregates are supplied" $ do + get "/budget_expenses?select=...budget_categories(total_budget:budget_amount.sum())" `shouldRespondWith` + [json|[{"total_budget": 9501.06}]|] + { matchHeaders = [matchContentTypeJson] } + +disallowed :: SpecWith ((), Application) +disallowed = + describe "attempting to use an aggregate when aggregate functions are disallowed" $ do + it "prevents the use of aggregates" $ + get "/project_invoices?select=invoice_total.sum()" `shouldRespondWith` + [json|{ + "hint":null, + "details":null, + "code":"PGRST123", + "message":"Use of aggregate functions is not allowed" + }|] + { matchStatus = 400 + , matchHeaders = [matchContentTypeJson] } diff --git a/test/spec/Feature/Query/PlanSpec.hs b/test/spec/Feature/Query/PlanSpec.hs index dcd2373bef..704fe877e0 100644 --- a/test/spec/Feature/Query/PlanSpec.hs +++ b/test/spec/Feature/Query/PlanSpec.hs @@ -348,12 +348,12 @@ spec actualPgVersion = do r1 <- request methodGet "/users?select=*,tasks!inner(*)&tasks.id=eq.1" [planHdr] "" - liftIO $ planCost r1 `shouldSatisfy` (< 20876.14) + liftIO $ planCost r1 `shouldSatisfy` (< 20888.83) r2 <- request methodGet "/users?select=*,tasks(*)&tasks.id=eq.1&tasks=not.is.null" [planHdr] "" - liftIO $ planCost r2 `shouldSatisfy` (< 20876.14) + liftIO $ planCost r2 `shouldSatisfy` (< 20888.83) describe "function call costs" $ do it "should not exceed cost when calling setof composite proc" $ do diff --git a/test/spec/Main.hs b/test/spec/Main.hs index 86ced44310..369db9a96e 100644 --- a/test/spec/Main.hs +++ b/test/spec/Main.hs @@ -34,6 +34,7 @@ import qualified Feature.OpenApi.ProxySpec import qualified Feature.OpenApi.RootSpec import qualified Feature.OpenApi.SecurityOpenApiSpec import qualified Feature.OptionsSpec +import qualified Feature.Query.AggregateFunctionsSpec import qualified Feature.Query.AndOrParamsSpec import qualified Feature.Query.ComputedRelsSpec import qualified Feature.Query.CustomMediaSpec @@ -109,6 +110,7 @@ main = do pgSafeUpdateApp = app testPgSafeUpdateEnabledCfg obsApp = app testObservabilityCfg serverTiming = app testCfgServerTiming + aggregatesEnabled = app testCfgAggregatesEnabled extraSearchPathApp = appDbs testCfgExtraSearchPath unicodeApp = appDbs testUnicodeCfg @@ -242,6 +244,12 @@ main = do parallel $ before serverTiming $ describe "Feature.Query.ServerTimingSpec.spec" Feature.Query.ServerTimingSpec.spec + parallel $ before aggregatesEnabled $ + describe "Feature.Query.AggregateFunctionsSpec" Feature.Query.AggregateFunctionsSpec.allowed + + parallel $ before withApp $ + describe "Feature.Query.AggregateFunctionsDisallowedSpec." Feature.Query.AggregateFunctionsSpec.disallowed + -- Note: the rollback tests can not run in parallel, because they test persistance and -- this results in race conditions diff --git a/test/spec/SpecHelper.hs b/test/spec/SpecHelper.hs index 5deba3f043..0717ef12b4 100644 --- a/test/spec/SpecHelper.hs +++ b/test/spec/SpecHelper.hs @@ -98,6 +98,7 @@ baseCfg :: AppConfig baseCfg = let secret = Just $ encodeUtf8 "reallyreallyreallyreallyverysafe" in AppConfig { configAppSettings = [ ("app.settings.app_host", "localhost") , ("app.settings.external_api_secret", "0123456789abcdef") ] + , configDbAggregates = False , configDbAnonRole = Just "postgrest_test_anonymous" , configDbChannel = mempty , configDbChannelEnabled = True @@ -235,6 +236,9 @@ testObservabilityCfg = baseCfg { configServerTraceHeader = Just $ mk "X-Request- testCfgServerTiming :: AppConfig testCfgServerTiming = baseCfg { configDbPlanEnabled = True } +testCfgAggregatesEnabled :: AppConfig +testCfgAggregatesEnabled = baseCfg { configDbAggregates = True } + analyzeTable :: Text -> IO () analyzeTable tableName = void $ readProcess "psql" ["-U", "postgres", "--set", "ON_ERROR_STOP=1", "-a", "-c", toS $ "ANALYZE test.\"" <> tableName <> "\""] [] diff --git a/test/spec/fixtures/data.sql b/test/spec/fixtures/data.sql index f5c4c71dd7..6ccf5e8962 100644 --- a/test/spec/fixtures/data.sql +++ b/test/spec/fixtures/data.sql @@ -866,3 +866,27 @@ TRUNCATE TABLE timestamps CASCADE; INSERT INTO timestamps VALUES ('2023-10-18 12:37:59.611000+0000'); INSERT INTO timestamps VALUES ('2023-10-18 14:37:59.611000+0000'); INSERT INTO timestamps VALUES ('2023-10-18 16:37:59.611000+0000'); + +TRUNCATE TABLE project_invoices CASCADE; +INSERT INTO project_invoices VALUES (1, 100, 1); +INSERT INTO project_invoices VALUES (2, 200, 1); +INSERT INTO project_invoices VALUES (3, 500, 2); +INSERT INTO project_invoices VALUES (4, 700, 2); +INSERT INTO project_invoices VALUES (5, 1200, 3); +INSERT INTO project_invoices VALUES (6, 2000, 3); +INSERT INTO project_invoices VALUES (7, 100, 4); +INSERT INTO project_invoices VALUES (8, 4000, 4); + +TRUNCATE TABLE budget_categories CASCADE; +INSERT INTO budget_categories VALUES (1, 'Beanie Babies', 'Brian Smith', 1000.31); +INSERT INTO budget_categories VALUES (2, 'DVDs', 'Jane Clarkson', 2000.12); +INSERT INTO budget_categories VALUES (3, 'Pizza', 'Brian Smith', 1000.11); +INSERT INTO budget_categories VALUES (4, 'Opera Tickets', 'Jane Clarkson', 7000.41); +INSERT INTO budget_categories VALUES (5, 'Nuclear Fusion Research', 'Sally Hughes', 500.23); +INSERT INTO budget_categories VALUES (6, 'T-5hirts', 'Dana de Groot', 500.33); + +TRUNCATE TABLE budget_expenses CASCADE; +INSERT INTO budget_expenses VALUES (1, 200.26, 1); +INSERT INTO budget_expenses VALUES (2, 400.26, 3); +INSERT INTO budget_expenses VALUES (3, 100.22, 4); +INSERT INTO budget_expenses VALUES (5, 900.27, 5); diff --git a/test/spec/fixtures/schema.sql b/test/spec/fixtures/schema.sql index 30036916f1..29999ecc55 100644 --- a/test/spec/fixtures/schema.sql +++ b/test/spec/fixtures/schema.sql @@ -3631,3 +3631,22 @@ create table empty_string as select 1 as id, ''::text as string; create table timestamps ( t timestamp with time zone ); + +create table project_invoices ( + id int primary key +, invoice_total numeric +, project_id integer references projects(id) +); + +create table budget_categories ( + id int primary key +, category_name text +, budget_owner text +, budget_amount numeric +); + +create table budget_expenses ( + id int primary key +, expense_amount numeric +, budget_category_id integer references budget_categories(id) +);