Skip to content

Commit

Permalink
Add aggregate functions (#2925)
Browse files Browse the repository at this point in the history
The aggregate functions SUM(), MAX(), MIN(), AVG(), and COUNT() are now supported.
  • Loading branch information
timabdulla authored Nov 23, 2023
1 parent c3301a1 commit 1c60b50
Show file tree
Hide file tree
Showing 29 changed files with 667 additions and 114 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ This project adheres to [Semantic Versioning](http://semver.org/).
- #3001, Add `statement_timeout` set on functions - @taimoorzaeem
- #3045, Apply superuser settings on impersonated roles if they have PostgreSQL 15 `GRANT SET ON PARAMETER` privilege - @steve-chavez
- #3062, Add config for enabling the `Server-Timing` header - @develop7
- #915, Add support for aggregate functions - @timabdulla
+ The aggregate functions SUM(), MAX(), MIN(), AVG(), and COUNT() are now supported.
+ It's disabled by default, you can enable it with `db-aggregates-enabled`.

### Fixed

Expand Down
1 change: 1 addition & 0 deletions postgrest.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ test-suite spec
Feature.OpenApi.RootSpec
Feature.OpenApi.SecurityOpenApiSpec
Feature.OptionsSpec
Feature.Query.AggregateFunctionsSpec
Feature.Query.AndOrParamsSpec
Feature.Query.ComputedRelsSpec
Feature.Query.CustomMediaSpec
Expand Down
65 changes: 43 additions & 22 deletions src/PostgREST/ApiRequest/QueryParams.hs
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ import Data.Tree (Tree (..))
import Text.Parsec.Error (errorMessages,
showErrorMessages)
import Text.ParserCombinators.Parsec (GenParser, ParseError, Parser,
anyChar, between, char, digit,
eof, errorPos, letter,
anyChar, between, char, choice,
digit, eof, errorPos, letter,
lookAhead, many1, noneOf,
notFollowedBy, oneOf,
optionMaybe, sepBy, sepBy1,
Expand All @@ -43,7 +43,8 @@ import PostgREST.RangeQuery (NonnegRange, allRange,
rangeOffset, restrictRange)
import PostgREST.SchemaCache.Identifiers (FieldName)

import PostgREST.ApiRequest.Types (EmbedParam (..), EmbedPath, Field,
import PostgREST.ApiRequest.Types (AggregateFunction (..),
EmbedParam (..), EmbedPath, Field,
Filter (..), FtsOperator (..),
Hint, JoinType (..),
JsonOperand (..),
Expand All @@ -58,7 +59,7 @@ import PostgREST.ApiRequest.Types (EmbedParam (..), EmbedPath, Field,
SimpleOperator (..), SingleVal,
TrileanVal (..))

import Protolude hiding (try)
import Protolude hiding (Sum, try)

data QueryParams =
QueryParams
Expand Down Expand Up @@ -99,7 +100,7 @@ data QueryParams =
-- 'select' is a reserved parameter that selects the fields to be returned:
--
-- >>> qsSelect <$> parse False "select=name,location"
-- Right [Node {rootLabel = SelectField {selField = ("name",[]), selCast = Nothing, selAlias = Nothing}, subForest = []},Node {rootLabel = SelectField {selField = ("location",[]), selCast = Nothing, selAlias = Nothing}, subForest = []}]
-- Right [Node {rootLabel = SelectField {selField = ("name",[]), selAggregateFunction = Nothing, selAggregateCast = Nothing, selCast = Nothing, selAlias = Nothing}, subForest = []},Node {rootLabel = SelectField {selField = ("location",[]), selAggregateFunction = Nothing, selAggregateCast = Nothing, selCast = Nothing, selAlias = Nothing}, subForest = []}]
--
-- Filters are parameters whose value contains an operator, separated by a '.' from its value:
--
Expand Down Expand Up @@ -282,24 +283,24 @@ pTreePath = do
-- Parse select= into a Forest of SelectItems
--
-- >>> P.parse pFieldForest "" "id"
-- Right [Node {rootLabel = SelectField {selField = ("id",[]), selCast = Nothing, selAlias = Nothing}, subForest = []}]
-- Right [Node {rootLabel = SelectField {selField = ("id",[]), selAggregateFunction = Nothing, selAggregateCast = Nothing, selCast = Nothing, selAlias = Nothing}, subForest = []}]
--
-- >>> P.parse pFieldForest "" "client(id)"
-- Right [Node {rootLabel = SelectRelation {selRelation = "client", selAlias = Nothing, selHint = Nothing, selJoinType = Nothing}, subForest = [Node {rootLabel = SelectField {selField = ("id",[]), selCast = Nothing, selAlias = Nothing}, subForest = []}]}]
-- Right [Node {rootLabel = SelectRelation {selRelation = "client", selAlias = Nothing, selHint = Nothing, selJoinType = Nothing}, subForest = [Node {rootLabel = SelectField {selField = ("id",[]), selAggregateFunction = Nothing, selAggregateCast = Nothing, selCast = Nothing, selAlias = Nothing}, subForest = []}]}]
--
-- >>> P.parse pFieldForest "" "*,client(*,nested(*))"
-- Right [Node {rootLabel = SelectField {selField = ("*",[]), selCast = Nothing, selAlias = Nothing}, subForest = []},Node {rootLabel = SelectRelation {selRelation = "client", selAlias = Nothing, selHint = Nothing, selJoinType = Nothing}, subForest = [Node {rootLabel = SelectField {selField = ("*",[]), selCast = Nothing, selAlias = Nothing}, subForest = []},Node {rootLabel = SelectRelation {selRelation = "nested", selAlias = Nothing, selHint = Nothing, selJoinType = Nothing}, subForest = [Node {rootLabel = SelectField {selField = ("*",[]), selCast = Nothing, selAlias = Nothing}, subForest = []}]}]}]
-- Right [Node {rootLabel = SelectField {selField = ("*",[]), selAggregateFunction = Nothing, selAggregateCast = Nothing, selCast = Nothing, selAlias = Nothing}, subForest = []},Node {rootLabel = SelectRelation {selRelation = "client", selAlias = Nothing, selHint = Nothing, selJoinType = Nothing}, subForest = [Node {rootLabel = SelectField {selField = ("*",[]), selAggregateFunction = Nothing, selAggregateCast = Nothing, selCast = Nothing, selAlias = Nothing}, subForest = []},Node {rootLabel = SelectRelation {selRelation = "nested", selAlias = Nothing, selHint = Nothing, selJoinType = Nothing}, subForest = [Node {rootLabel = SelectField {selField = ("*",[]), selAggregateFunction = Nothing, selAggregateCast = Nothing, selCast = Nothing, selAlias = Nothing}, subForest = []}]}]}]
--
-- >>> P.parse pFieldForest "" "*,...client(*),other(*)"
-- Right [Node {rootLabel = SelectField {selField = ("*",[]), selCast = Nothing, selAlias = Nothing}, subForest = []},Node {rootLabel = SpreadRelation {selRelation = "client", selHint = Nothing, selJoinType = Nothing}, subForest = [Node {rootLabel = SelectField {selField = ("*",[]), selCast = Nothing, selAlias = Nothing}, subForest = []}]},Node {rootLabel = SelectRelation {selRelation = "other", selAlias = Nothing, selHint = Nothing, selJoinType = Nothing}, subForest = [Node {rootLabel = SelectField {selField = ("*",[]), selCast = Nothing, selAlias = Nothing}, subForest = []}]}]
-- Right [Node {rootLabel = SelectField {selField = ("*",[]), selAggregateFunction = Nothing, selAggregateCast = Nothing, selCast = Nothing, selAlias = Nothing}, subForest = []},Node {rootLabel = SpreadRelation {selRelation = "client", selHint = Nothing, selJoinType = Nothing}, subForest = [Node {rootLabel = SelectField {selField = ("*",[]), selAggregateFunction = Nothing, selAggregateCast = Nothing, selCast = Nothing, selAlias = Nothing}, subForest = []}]},Node {rootLabel = SelectRelation {selRelation = "other", selAlias = Nothing, selHint = Nothing, selJoinType = Nothing}, subForest = [Node {rootLabel = SelectField {selField = ("*",[]), selAggregateFunction = Nothing, selAggregateCast = Nothing, selCast = Nothing, selAlias = Nothing}, subForest = []}]}]
--
-- >>> P.parse pFieldForest "" ""
-- Right []
--
-- >>> P.parse pFieldForest "" "id,clients(name[])"
-- Left (line 1, column 16):
-- unexpected '['
-- expecting letter, digit, "-", "->>", "->", "::", ")", "," or end of input
-- expecting letter, digit, "-", "->>", "->", "::", ".", ")", "," or end of input
--
-- >>> P.parse pFieldForest "" "data->>-78xy"
-- Left (line 1, column 11):
Expand Down Expand Up @@ -452,35 +453,37 @@ pRelationSelect :: Parser SelectItem
pRelationSelect = lexeme $ do
alias <- optionMaybe ( try(pFieldName <* aliasSeparator) )
name <- pFieldName
guard (name /= "count")
(hint, jType) <- pEmbedParams
try (void $ lookAhead (string "("))
return $ SelectRelation name alias hint jType


-- |
-- Parse regular fields in select
--
-- >>> P.parse pFieldSelect "" "name"
-- Right (SelectField {selField = ("name",[]), selCast = Nothing, selAlias = Nothing})
-- Right (SelectField {selField = ("name",[]), selAggregateFunction = Nothing, selAggregateCast = Nothing, selCast = Nothing, selAlias = Nothing})
--
-- >>> P.parse pFieldSelect "" "name->jsonpath"
-- Right (SelectField {selField = ("name",[JArrow {jOp = JKey {jVal = "jsonpath"}}]), selCast = Nothing, selAlias = Nothing})
-- Right (SelectField {selField = ("name",[JArrow {jOp = JKey {jVal = "jsonpath"}}]), selAggregateFunction = Nothing, selAggregateCast = Nothing, selCast = Nothing, selAlias = Nothing})
--
-- >>> P.parse pFieldSelect "" "name::cast"
-- Right (SelectField {selField = ("name",[]), selCast = Just "cast", selAlias = Nothing})
-- Right (SelectField {selField = ("name",[]), selAggregateFunction = Nothing, selAggregateCast = Nothing, selCast = Just "cast", selAlias = Nothing})
--
-- >>> P.parse pFieldSelect "" "alias:name"
-- Right (SelectField {selField = ("name",[]), selCast = Nothing, selAlias = Just "alias"})
-- Right (SelectField {selField = ("name",[]), selAggregateFunction = Nothing, selAggregateCast = Nothing, selCast = Nothing, selAlias = Just "alias"})
--
-- >>> P.parse pFieldSelect "" "alias:name->jsonpath::cast"
-- Right (SelectField {selField = ("name",[JArrow {jOp = JKey {jVal = "jsonpath"}}]), selCast = Just "cast", selAlias = Just "alias"})
-- Right (SelectField {selField = ("name",[JArrow {jOp = JKey {jVal = "jsonpath"}}]), selAggregateFunction = Nothing, selAggregateCast = Nothing, selCast = Just "cast", selAlias = Just "alias"})
--
-- >>> P.parse pFieldSelect "" "*"
-- Right (SelectField {selField = ("*",[]), selCast = Nothing, selAlias = Nothing})
-- Right (SelectField {selField = ("*",[]), selAggregateFunction = Nothing, selAggregateCast = Nothing, selCast = Nothing, selAlias = Nothing})
--
-- >>> P.parse pFieldSelect "" "name!hint"
-- Left (line 1, column 5):
-- unexpected '!'
-- expecting letter, digit, "-", "->>", "->", "::", ")", "," or end of input
-- expecting letter, digit, "-", "->>", "->", "::", ".", ")", "," or end of input
--
-- >>> P.parse pFieldSelect "" "*!hint"
-- Left (line 1, column 2):
Expand All @@ -495,18 +498,36 @@ pFieldSelect :: Parser SelectItem
pFieldSelect = lexeme $ try (do
s <- pStar
pEnd
return $ SelectField (s, []) Nothing Nothing)
return $ SelectField (s, []) Nothing Nothing Nothing Nothing)
<|> try (do
alias <- optionMaybe ( try(pFieldName <* aliasSeparator) )
_ <- string "count()"
aggCast' <- optionMaybe (string "::" *> pIdentifier)
pEnd
return $ SelectField ("*", []) (Just Count) (toS <$> aggCast') Nothing alias)
<|> do
alias <- optionMaybe ( try(pFieldName <* aliasSeparator) )
fld <- pField
cast' <- optionMaybe (string "::" *> pIdentifier)
alias <- optionMaybe ( try(pFieldName <* aliasSeparator) )
fld <- pField
cast' <- optionMaybe (string "::" *> pIdentifier)
agg <- optionMaybe (try (char '.' *> pAggregation <* string "()"))
aggCast' <- optionMaybe (string "::" *> pIdentifier)
pEnd
return $ SelectField fld (toS <$> cast') alias
return $ SelectField fld agg (toS <$> aggCast') (toS <$> cast') alias
where
pEnd = try (void $ lookAhead (string ")")) <|>
try (void $ lookAhead (string ",")) <|>
try eof
pStar = string "*" $> "*"
pAggregation = choice
[ string "sum" $> Sum
, string "avg" $> Avg
, string "count" $> Count
-- Using 'try' for "min" and "max" to allow backtracking.
-- This is necessary because both start with the same character 'm',
-- and without 'try', a partial match on "max" would prevent "min" from being tried.
, try (string "max") $> Max
, try (string "min") $> Min
]


-- |
Expand Down
19 changes: 13 additions & 6 deletions src/PostgREST/ApiRequest/Types.hs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
{-# LANGUAGE DuplicateRecordFields #-}
module PostgREST.ApiRequest.Types
( Alias
( AggregateFunction(..)
, Alias
, Cast
, Depth
, EmbedParam(..)
Expand Down Expand Up @@ -42,12 +43,14 @@ import PostgREST.SchemaCache.Routine (Routine (..))

import Protolude

-- | The value in `/tbl?select=alias:field::cast`
-- | The value in `/tbl?select=alias:field.aggregateFunction()::cast`
data SelectItem
= SelectField
{ selField :: Field
, selCast :: Maybe Cast
, selAlias :: Maybe Alias
{ selField :: Field
, selAggregateFunction :: Maybe AggregateFunction
, selAggregateCast :: Maybe Cast
, selCast :: Maybe Cast
, selAlias :: Maybe Alias
}
-- | The value in `/tbl?select=alias:another_tbl(*)`
| SelectRelation
Expand All @@ -65,7 +68,8 @@ data SelectItem
deriving (Eq, Show)

data ApiRequestError
= AmbiguousRelBetween Text Text [Relationship]
= AggregatesNotAllowed
| AmbiguousRelBetween Text Text [Relationship]
| AmbiguousRpc [Routine]
| BinaryFieldError MediaType
| MediaTypeError [ByteString]
Expand Down Expand Up @@ -135,6 +139,9 @@ type Cast = Text
type Alias = Text
type Hint = Text

data AggregateFunction = Sum | Avg | Max | Min | Count
deriving (Show, Eq)

data EmbedParam
-- | Disambiguates an embedding operation when there's multiple relationships
-- between two tables. Can be the name of a foreign key constraint, column
Expand Down
5 changes: 4 additions & 1 deletion src/PostgREST/Config.hs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ import Protolude hiding (Proxy, toList)

data AppConfig = AppConfig
{ configAppSettings :: [(Text, Text)]
, configDbAggregates :: Bool
, configDbAnonRole :: Maybe BS.ByteString
, configDbChannel :: Text
, configDbChannelEnabled :: Bool
Expand Down Expand Up @@ -138,7 +139,8 @@ toText conf =
where
-- apply conf to all pgrst settings
pgrstSettings = (\(k, v) -> (k, v conf)) <$>
[("db-anon-role", q . T.decodeUtf8 . fromMaybe "" . configDbAnonRole)
[("db-aggregates-enabled", T.toLower . show . configDbAggregates)
,("db-anon-role", q . T.decodeUtf8 . fromMaybe "" . configDbAnonRole)
,("db-channel", q . configDbChannel)
,("db-channel-enabled", T.toLower . show . configDbChannelEnabled)
,("db-extra-search-path", q . T.intercalate "," . configDbExtraSearchPath)
Expand Down Expand Up @@ -232,6 +234,7 @@ parser :: Maybe FilePath -> Environment -> [(Text, Text)] -> RoleSettings -> Rol
parser optPath env dbSettings roleSettings roleIsolationLvl =
AppConfig
<$> parseAppSettings "app.settings"
<*> (fromMaybe False <$> optBool "db-aggregates-enabled")
<*> (fmap encodeUtf8 <$> optString "db-anon-role")
<*> (fromMaybe "pgrst" <$> optString "db-channel")
<*> (fromMaybe True <$> optBool "db-channel-enabled")
Expand Down
3 changes: 2 additions & 1 deletion src/PostgREST/Config/Database.hs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ prefix = "pgrst."
dbSettingsNames :: [Text]
dbSettingsNames =
(prefix <>) <$>
["db_anon_role"
["db_aggregates_enabled"
,"db_anon_role"
,"db_pre_config"
,"db_extra_search_path"
,"db_max_rows"
Expand Down
6 changes: 6 additions & 0 deletions src/PostgREST/Error.hs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ class (JSON.ToJSON a) => PgrstError a where
responseLBS (status err) (baseHeader : headers err) $ errorPayload err

instance PgrstError ApiRequestError where
status AggregatesNotAllowed{} = HTTP.status400
status AmbiguousRelBetween{} = HTTP.status300
status AmbiguousRpc{} = HTTP.status300
status BinaryFieldError{} = HTTP.status406
Expand Down Expand Up @@ -198,6 +199,9 @@ instance JSON.ToJSON ApiRequestError where
(Just $ JSON.String $ T.decodeUtf8 ("Invalid preferences: " <> BS.intercalate ", " prefs))
Nothing

toJSON AggregatesNotAllowed = toJsonPgrstError
ApiRequestErrorCode23 "Use of aggregate functions is not allowed" Nothing Nothing

toJSON (NoRelBetween parent child embedHint schema allRels) = toJsonPgrstError
SchemaCacheErrorCode00
("Could not find a relationship between '" <> parent <> "' and '" <> child <> "' in the schema cache")
Expand Down Expand Up @@ -604,6 +608,7 @@ data ErrorCode
| ApiRequestErrorCode20
| ApiRequestErrorCode21
| ApiRequestErrorCode22
| ApiRequestErrorCode23
-- Schema Cache errors
| SchemaCacheErrorCode00
| SchemaCacheErrorCode01
Expand Down Expand Up @@ -652,6 +657,7 @@ buildErrorCode code = "PGRST" <> case code of
ApiRequestErrorCode20 -> "120"
ApiRequestErrorCode21 -> "121"
ApiRequestErrorCode22 -> "122"
ApiRequestErrorCode23 -> "123"

SchemaCacheErrorCode00 -> "200"
SchemaCacheErrorCode01 -> "201"
Expand Down
Loading

0 comments on commit 1c60b50

Please sign in to comment.