Skip to content

Commit

Permalink
First pass at adding aggregate functions
Browse files Browse the repository at this point in the history
  • Loading branch information
timabdulla committed Aug 30, 2023
1 parent 7dc6e2b commit 4962bef
Show file tree
Hide file tree
Showing 6 changed files with 67 additions and 29 deletions.
20 changes: 14 additions & 6 deletions src/PostgREST/ApiRequest/QueryParams.hs
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ import Data.Tree (Tree (..))
import Text.Parsec.Error (errorMessages,
showErrorMessages)
import Text.ParserCombinators.Parsec (GenParser, ParseError, Parser,
anyChar, between, char, digit,
eof, errorPos, letter,
anyChar, between, char, choice,
digit, eof, errorPos, letter,
lookAhead, many1, noneOf,
notFollowedBy, oneOf,
optionMaybe, sepBy, sepBy1,
Expand All @@ -43,7 +43,8 @@ import PostgREST.RangeQuery (NonnegRange, allRange,
rangeOffset, restrictRange)
import PostgREST.SchemaCache.Identifiers (FieldName)

import PostgREST.ApiRequest.Types (EmbedParam (..), EmbedPath, Field,
import PostgREST.ApiRequest.Types (AggregateFunction(..),
EmbedParam (..), EmbedPath, Field,
Filter (..), FtsOperator (..),
Hint, JoinType (..),
JsonOperand (..),
Expand All @@ -58,7 +59,7 @@ import PostgREST.ApiRequest.Types (EmbedParam (..), EmbedPath, Field,
SimpleOperator (..), SingleVal,
TrileanVal (..))

import Protolude hiding (try)
import Protolude hiding (try, Sum)

data QueryParams =
QueryParams
Expand Down Expand Up @@ -495,18 +496,25 @@ pFieldSelect :: Parser SelectItem
pFieldSelect = lexeme $ try (do
s <- pStar
pEnd
return $ SelectField (s, []) Nothing Nothing)
return $ SelectField (s, []) Nothing Nothing Nothing)
<|> do
alias <- optionMaybe ( try(pFieldName <* aliasSeparator) )
fld <- pField
agg <- optionMaybe (try (char '.' *> pAggregation <* string "()"))
cast' <- optionMaybe (string "::" *> pIdentifier)
pEnd
return $ SelectField fld (toS <$> cast') alias
return $ SelectField fld agg (toS <$> cast') alias
where
pEnd = try (void $ lookAhead (string ")")) <|>
try (void $ lookAhead (string ",")) <|>
try eof
pStar = string "*" $> "*"
pAggregation = choice
[ string "sum" $> Sum
, string "avg" $> Avg
, string "max" $> Max
, string "min" $> Min

Check warning on line 516 in src/PostgREST/ApiRequest/QueryParams.hs

View check run for this annotation

Codecov / codecov/patch

src/PostgREST/ApiRequest/QueryParams.hs#L512-L516

Added lines #L512 - L516 were not covered by tests
]


-- |
Expand Down
15 changes: 10 additions & 5 deletions src/PostgREST/ApiRequest/Types.hs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
{-# LANGUAGE DuplicateRecordFields #-}
module PostgREST.ApiRequest.Types
( Alias
( AggregateFunction(..)
, Alias
, Cast
, Depth
, EmbedParam(..)
Expand Down Expand Up @@ -42,12 +43,13 @@ import PostgREST.SchemaCache.Routine (Routine (..))

import Protolude

-- | The value in `/tbl?select=alias:field::cast`
-- | The value in `/tbl?select=alias:field.aggregateFunction()::cast`
data SelectItem
= SelectField
{ selField :: Field
, selCast :: Maybe Cast
, selAlias :: Maybe Alias
{ selField :: Field
, selAggregateFunction :: Maybe AggregateFunction
, selCast :: Maybe Cast
, selAlias :: Maybe Alias

Check warning on line 52 in src/PostgREST/ApiRequest/Types.hs

View check run for this annotation

Codecov / codecov/patch

src/PostgREST/ApiRequest/Types.hs#L49-L52

Added lines #L49 - L52 were not covered by tests
}
-- | The value in `/tbl?select=alias:another_tbl(*)`
| SelectRelation
Expand Down Expand Up @@ -128,6 +130,9 @@ type Cast = Text
type Alias = Text
type Hint = Text

data AggregateFunction = Sum | Avg | Max | Min
deriving (Show, Eq)

Check warning on line 134 in src/PostgREST/ApiRequest/Types.hs

View check run for this annotation

Codecov / codecov/patch

src/PostgREST/ApiRequest/Types.hs#L134

Added line #L134 was not covered by tests

data EmbedParam
-- | Disambiguates an embedding operation when there's multiple relationships
-- between two tables. Can be the name of a foreign key constraint, column
Expand Down
18 changes: 9 additions & 9 deletions src/PostgREST/Plan.hs
Original file line number Diff line number Diff line change
Expand Up @@ -317,15 +317,15 @@ initReadRequest ctx@ResolverContext{qi=QualifiedIdentifier{..}} =
(Node defReadPlan{from=QualifiedIdentifier qiSchema selRelation, relName=selRelation, relHint=selHint, relJoinType=selJoinType, depth=nxtDepth, relIsSpread=True} [])
fldForest:rForest
SelectField{..} ->
Node q{select=(resolveOutputField ctx{qi=from q} selField, selCast, selAlias):select q} rForest
Node q{select=(resolveOutputField ctx{qi=from q} selField, selAggregateFunction, selCast, selAlias):select q} rForest

-- | Preserve the original field name if data representation is used to coerce the value.
addDataRepresentationAliases :: ReadPlanTree -> Either ApiRequestError ReadPlanTree
addDataRepresentationAliases rPlanTree = Right $ fmap (\rPlan@ReadPlan{select=sel} -> rPlan{select=map aliasSelectItem sel}) rPlanTree
where
aliasSelectItem :: (CoercibleField, Maybe Cast, Maybe Alias) -> (CoercibleField, Maybe Cast, Maybe Alias)
aliasSelectItem :: (CoercibleField, Maybe AggregateFunction, Maybe Cast, Maybe Alias) -> (CoercibleField, Maybe AggregateFunction, Maybe Cast, Maybe Alias)
-- If there already is an alias, don't overwrite it.
aliasSelectItem (fld@(CoercibleField{cfName=fieldName, cfTransform=(Just _)}), Nothing, Nothing) = (fld, Nothing, Just fieldName)
aliasSelectItem (fld@(CoercibleField{cfName=fieldName, cfTransform=(Just _)}), Nothing, Nothing, Nothing) = (fld, Nothing, Nothing, Just fieldName)
aliasSelectItem fld = fld

knownColumnsInContext :: ResolverContext -> [Column]
Expand All @@ -348,7 +348,7 @@ expandStarsForDataRepresentations ctx@ResolverContext{qi} rPlanTree = Right $ fm
expandStarsForTable :: ResolverContext -> ReadPlan -> ReadPlan
expandStarsForTable ctx@ResolverContext{representations, outputType} rplan@ReadPlan{select=selectItems} =
-- If we have a '*' select AND the target table has at least one data representation, expand.
if ("*" `elem` map (\(field, _, _) -> cfName field) selectItems) && any hasOutputRep knownColumns
if ("*" `elem` map (\(field, _, _, _) -> cfName field) selectItems) && any hasOutputRep knownColumns
then rplan{select=concatMap (expandStarSelectItem knownColumns) selectItems}
else rplan
where
Expand All @@ -357,8 +357,8 @@ expandStarsForTable ctx@ResolverContext{representations, outputType} rplan@ReadP
hasOutputRep :: Column -> Bool
hasOutputRep col = HM.member (colNominalType col, outputType) representations

expandStarSelectItem :: [Column] -> (CoercibleField, Maybe Cast, Maybe Alias) -> [(CoercibleField, Maybe Cast, Maybe Alias)]
expandStarSelectItem columns (CoercibleField{cfName="*", cfJsonPath=[]}, b, c) = map (\col -> (withOutputFormat ctx $ resolveColumnField col, b, c)) columns
expandStarSelectItem :: [Column] -> (CoercibleField, Maybe AggregateFunction, Maybe Cast, Maybe Alias) -> [(CoercibleField, Maybe AggregateFunction,Maybe Cast, Maybe Alias)]
expandStarSelectItem columns (CoercibleField{cfName="*", cfJsonPath=[]}, b, c, d) = map (\col -> (withOutputFormat ctx $ resolveColumnField col, b, c, d)) columns
expandStarSelectItem _ selectItem = [selectItem]

-- | Enforces the `max-rows` config on the result
Expand Down Expand Up @@ -770,7 +770,7 @@ inferColsEmbedNeeds (Node ReadPlan{select} forest) pkCols
| "*" `elem` fldNames = ["*"]
| otherwise = returnings
where
fldNames = cfName . (\(f, _, _) -> f) <$> select
fldNames = cfName . (\(f, _, _, _) -> f) <$> select
-- Without fkCols, when a mutatePlan to
-- /projects?select=name,clients(name) occurs, the RETURNING SQL part would
-- be `RETURNING name`(see QueryBuilder). This would make the embedding
Expand Down Expand Up @@ -839,8 +839,8 @@ binaryField AppConfig{configRawMediaTypes} acceptMediaType proc rpTree
_ -> False

fstFieldName :: ReadPlanTree -> Maybe FieldName
fstFieldName (Node ReadPlan{select=(CoercibleField{cfName="*", cfJsonPath=[]}, _, _):_} []) = Nothing
fstFieldName (Node ReadPlan{select=[(CoercibleField{cfName=fld, cfJsonPath=[]}, _, _)]} []) = Just fld
fstFieldName (Node ReadPlan{select=(CoercibleField{cfName="*", cfJsonPath=[]}, _, _, _):_} []) = Nothing
fstFieldName (Node ReadPlan{select=[(CoercibleField{cfName=fld, cfJsonPath=[]}, _, _, _)]} []) = Just fld
fstFieldName _ = Nothing


Expand Down
6 changes: 3 additions & 3 deletions src/PostgREST/Plan/ReadPlan.hs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ module PostgREST.Plan.ReadPlan

import Data.Tree (Tree (..))

import PostgREST.ApiRequest.Types (Alias, Cast, Depth, Hint,
JoinType, NodeName)
import PostgREST.ApiRequest.Types (AggregateFunction, Alias, Cast, Depth,
Hint, JoinType, NodeName)
import PostgREST.Plan.Types (CoercibleField (..),
CoercibleLogicTree,
CoercibleOrderTerm)
Expand All @@ -28,7 +28,7 @@ data JoinCondition =
deriving (Eq, Show)

data ReadPlan = ReadPlan
{ select :: [(CoercibleField, Maybe Cast, Maybe Alias)]
{ select :: [(CoercibleField, Maybe AggregateFunction, Maybe Cast, Maybe Alias)]
, from :: QualifiedIdentifier
, fromAlias :: Maybe Alias
, where_ :: [CoercibleLogicTree]
Expand Down
3 changes: 2 additions & 1 deletion src/PostgREST/Query/QueryBuilder.hs
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,13 @@ readPlanToQuery (Node ReadPlan{select,from=mainQi,fromAlias,where_=logicForest,o
(if null logicForest && null relJoinConds
then mempty
else "WHERE " <> intercalateSnippet " AND " (map (pgFmtLogicTree qi) logicForest ++ map pgFmtJoinCondition relJoinConds)) <> " " <>
groupF qi select <> " " <>
orderF qi order <> " " <>
limitOffsetF readRange
where
fromFrag = fromF relToParent mainQi fromAlias
qi = getQualifiedIdentifier relToParent mainQi fromAlias
defSelect = [(unknownField "*" [], Nothing, Nothing)] -- gets all the columns in case of an empty select, ignoring/obtaining these columns is done at the aggregation stage
defSelect = [(unknownField "*" [], Nothing, Nothing, Nothing)] -- gets all the columns in case of an empty select, ignoring/obtaining these columns is done at the aggregation stage
(selects, joins) = foldr getSelectsJoins ([],[]) forest

getSelectsJoins :: ReadPlanTree -> ([SQL.Snippet], [SQL.Snippet]) -> ([SQL.Snippet], [SQL.Snippet])
Expand Down
34 changes: 29 additions & 5 deletions src/PostgREST/Query/SqlFragment.hs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ module PostgREST.Query.SqlFragment
( noLocationF
, aggF
, countF
, groupF
, fromQi
, limitOffsetF
, locationF
Expand Down Expand Up @@ -50,7 +51,8 @@ import Control.Arrow ((***))
import Data.Foldable (foldr1)
import Text.InterpolatedString.Perl6 (qc)

import PostgREST.ApiRequest.Types (Alias, Cast,
import PostgREST.ApiRequest.Types (AggregateFunction (..),
Alias, Cast,
FtsOperator (..),
JsonOperand (..),
JsonOperation (..),
Expand Down Expand Up @@ -82,7 +84,7 @@ import PostgREST.SchemaCache.Routine (ResultAggregate (..),
funcReturnsSetOfScalar,
funcReturnsSingleComposite)

import Protolude hiding (cast)
import Protolude hiding (cast, Sum)

sourceCTEName :: Text
sourceCTEName = "pgrst_source"
Expand Down Expand Up @@ -260,12 +262,20 @@ pgFmtCoerceNamed :: CoercibleField -> SQL.Snippet
pgFmtCoerceNamed CoercibleField{cfName=fn, cfTransform=(Just formatterProc)} = pgFmtCallUnary formatterProc (pgFmtIdent fn) <> " AS " <> pgFmtIdent fn
pgFmtCoerceNamed CoercibleField{cfName=fn} = pgFmtIdent fn

pgFmtSelectItem :: QualifiedIdentifier -> (CoercibleField, Maybe Cast, Maybe Alias) -> SQL.Snippet
pgFmtSelectItem table (fld, Nothing, alias) = pgFmtTableCoerce table fld <> pgFmtAs (cfName fld) (cfJsonPath fld) alias
pgFmtSelectItem :: QualifiedIdentifier -> (CoercibleField, Maybe AggregateFunction, Maybe Cast, Maybe Alias) -> SQL.Snippet
pgFmtSelectItem table (fld, agg, Nothing, alias) = pgFmtApplyAggregate agg (pgFmtTableCoerce table fld <> pgFmtAs (cfName fld) (cfJsonPath fld) alias)
-- Ideally we'd quote the cast with "pgFmtIdent cast". However, that would invalidate common casts such as "int", "bigint", etc.
-- Try doing: `select 1::"bigint"` - it'll err, using "int8" will work though. There's some parser magic that pg does that's invalidated when quoting.
-- Not quoting should be fine, we validate the input on Parsers.
pgFmtSelectItem table (fld, Just cast, alias) = "CAST (" <> pgFmtTableCoerce table fld <> " AS " <> SQL.sql (encodeUtf8 cast) <> " )" <> pgFmtAs (cfName fld) (cfJsonPath fld) alias
pgFmtSelectItem table (fld, agg, Just cast, alias) = "CAST (" <> pgFmtTableCoerce table fld <> " AS " <> SQL.sql (encodeUtf8 cast) <> " )" <> pgFmtApplyAggregate agg (pgFmtAs (cfName fld) (cfJsonPath fld) alias)

pgFmtApplyAggregate :: Maybe AggregateFunction -> SQL.Snippet -> SQL.Snippet
pgFmtApplyAggregate Nothing snippet = snippet
pgFmtApplyAggregate (Just agg) snippet = case agg of
Sum -> "SUM(" <> snippet <> ")"
Max -> "MAX(" <> snippet <> ")"
Min -> "MIN(" <> snippet <> ")"
Avg -> "AVG(" <> snippet <> ")"

Check warning on line 278 in src/PostgREST/Query/SqlFragment.hs

View check run for this annotation

Codecov / codecov/patch

src/PostgREST/Query/SqlFragment.hs#L274-L278

Added lines #L274 - L278 were not covered by tests

-- TODO: At this stage there shouldn't be a Maybe since ApiRequest should ensure that an INSERT/UPDATE has a body
fromJsonBodyF :: Maybe LBS.ByteString -> [CoercibleField] -> Bool -> Bool -> Bool -> SQL.Snippet
Expand Down Expand Up @@ -409,6 +419,19 @@ pgFmtAs fName jp Nothing = case jOp <$> lastMay jp of
Nothing -> mempty
pgFmtAs _ _ (Just alias) = " AS " <> pgFmtIdent alias

groupF :: QualifiedIdentifier -> [(CoercibleField, Maybe AggregateFunction, Maybe Cast, Maybe Alias)] -> SQL.Snippet
groupF _ [] = mempty
groupF qi fields =
if all (\(_, agg, _, _) -> isNothing agg) fields || all (\(_, agg, _, _) -> isJust agg) fields
then
mempty
else
" GROUP BY " <> intercalateSnippet ", " (pgFmtGroup qi <$> (filter (\(_, agg, _, _) -> isNothing agg) fields))

Check warning on line 429 in src/PostgREST/Query/SqlFragment.hs

View check run for this annotation

Codecov / codecov/patch

src/PostgREST/Query/SqlFragment.hs#L429

Added line #L429 was not covered by tests

pgFmtGroup :: QualifiedIdentifier -> (CoercibleField, Maybe AggregateFunction, Maybe Cast, Maybe Alias) -> SQL.Snippet
pgFmtGroup _ (_, Just _, _, _) = mempty
pgFmtGroup qi (fld, _, _, _) = pgFmtField qi fld

Check warning on line 433 in src/PostgREST/Query/SqlFragment.hs

View check run for this annotation

Codecov / codecov/patch

src/PostgREST/Query/SqlFragment.hs#L432-L433

Added lines #L432 - L433 were not covered by tests

countF :: SQL.Snippet -> Bool -> (SQL.Snippet, SQL.Snippet)
countF countQuery shouldCount =
if shouldCount
Expand Down Expand Up @@ -496,6 +519,7 @@ setConfigLocalJson prefix keyVals = [setConfigLocal mempty (prefix, gucJsonVal k
arrayByteStringToText :: [(ByteString, ByteString)] -> [(Text,Text)]
arrayByteStringToText keyVal = (T.decodeUtf8 *** T.decodeUtf8) <$> keyVal

-- Investigate this
aggF :: Maybe Routine -> ResultAggregate -> SQL.Snippet
aggF rout = \case
BuiltinAggJson -> asJsonF rout False
Expand Down

0 comments on commit 4962bef

Please sign in to comment.