diff --git a/src/Database/Esqueleto/Record.hs b/src/Database/Esqueleto/Record.hs index bc9fe8217..22a47dc0b 100644 --- a/src/Database/Esqueleto/Record.hs +++ b/src/Database/Esqueleto/Record.hs @@ -1,5 +1,7 @@ {-# LANGUAGE CPP #-} +{-# LANGUAGE AllowAmbiguousTypes #-} {-# LANGUAGE LambdaCase #-} +{-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE RecordWildCards #-} {-# LANGUAGE ScopedTypeVariables #-} @@ -13,6 +15,8 @@ module Database.Esqueleto.Record , DeriveEsqueletoRecordSettings(..) , defaultDeriveEsqueletoRecordSettings + , takeColumns + , takeMaybeColumns ) where import Control.Monad.Trans.State.Strict (StateT(..), evalStateT) @@ -20,6 +24,7 @@ import Data.Proxy (Proxy(..)) import Database.Esqueleto.Experimental (Entity, PersistValue, SqlExpr, Value(..), (:&)(..)) import Database.Esqueleto.Experimental.ToAlias (ToAlias(..)) +import Database.Esqueleto.Experimental.ToMaybe (ToMaybe(..)) import Database.Esqueleto.Experimental.ToAliasReference (ToAliasReference(..)) import Database.Esqueleto.Internal.Internal (SqlSelect(..)) import Language.Haskell.TH @@ -182,6 +187,7 @@ deriveEsqueletoRecordWith settings originalName = do sqlSelectInstanceDec <- makeSqlSelectInstance info sqlMaybeRecordDec <- makeSqlMaybeRecord info toMaybeInstanceDec <- makeToMaybeInstance info + sqlMaybeRecordSelectInstanceDec <- makeSqlMaybeRecordSelectInstance info toAliasInstanceDec <- makeToAliasInstance info toAliasReferenceInstanceDec <- makeToAliasReferenceInstance info pure @@ -189,6 +195,7 @@ deriveEsqueletoRecordWith settings originalName = do , sqlSelectInstanceDec , sqlMaybeRecordDec , toMaybeInstanceDec + , sqlMaybeRecordSelectInstanceDec , toAliasInstanceDec , toAliasReferenceInstanceDec ] @@ -322,24 +329,27 @@ sqlMaybeFieldType :: Type -> Q Type sqlMaybeFieldType fieldType = do maybeSqlType <- reifySqlSelectType fieldType - pure $ - flip fromMaybe maybeSqlType $ - case fieldType of - -- Field type -> Sql type -> Sql Maybe type + pure $ maybe convertFieldType convertSqlType maybeSqlType + where + convertSqlType = ((ConT ''ToMaybeT) `AppT`) + convertFieldType = case fieldType of -- Entity x -> SqlExpr (Entity x) -> SqlExpr (Maybe (Entity x)) AppT (ConT ((==) ''Entity -> True)) _innerType -> (ConT ''SqlExpr) `AppT` ((ConT ''Maybe) `AppT` fieldType) - -- Maybe (Entity x) -> SqlExpr (Maybe (Entity x)) -> SqlExpr (Maybe (Maybe (Entity x))) + -- Maybe (Entity x) -> SqlExpr (Maybe (Entity x)) -> SqlExpr (Maybe (Entity x)) (ConT ((==) ''Maybe -> True)) `AppT` ((ConT ((==) ''Entity -> True)) `AppT` _innerType) -> - (ConT ''SqlExpr) `AppT` ((ConT ''Maybe) `AppT` ((ConT ''Maybe) `AppT` fieldType)) + (ConT ''SqlExpr) `AppT` fieldType + + -- Maybe x -> SqlExpr (Value (Maybe x)) -> SqlExpr (Value (Maybe x)) + inner@((ConT ((==) ''Maybe -> True)) `AppT` _inner) -> (ConT ''SqlExpr) `AppT` ((ConT ''Value) `AppT` inner) -- x -> SqlExpr (Value x) -> SqlExpr (Value (Maybe x)) _ -> (ConT ''SqlExpr) `AppT` ((ConT ''Value) - `AppT` ((ConT ''Maybe) `AppT` fieldType) + `AppT` ((ConT ''Maybe) `AppT` fieldType)) -- | Generates the declaration for an @Sql@-prefixed record, given the original -- record's information. @@ -724,7 +734,7 @@ makeSqlMaybeRecord :: RecordInfo -> Q Dec makeSqlMaybeRecord RecordInfo {..} = do let newConstructor = RecC sqlMaybeConstructorName (makeField `map` sqlMaybeFields) derivingClauses = [] - pure $ DataD constraints sqlName typeVarBinders kind [newConstructor] derivingClauses + pure $ DataD constraints sqlMaybeName typeVarBinders kind [newConstructor] derivingClauses where makeField (fieldName', fieldType) = (fieldName', Bang NoSourceUnpackedness NoSourceStrictness, fieldType) @@ -739,7 +749,7 @@ makeToMaybeInstance info@RecordInfo {..} = do instanceConstraints = [] instanceType = (ConT ''ToMaybe) `AppT` (ConT sqlName) - pure $ InstanceD overlap instanceConstraints instanceType [toMaybeTDec', toMaybeDec] + pure $ InstanceD overlap instanceConstraints instanceType [toMaybeTDec', toMaybeDec'] -- | Generates a `type ToMaybeT ... = ...` declaration for the given record. toMaybeTDec :: RecordInfo -> Q Dec @@ -752,9 +762,189 @@ toMaybeTDec RecordInfo {..} = do -- | Generates a `toMaybe value = ...` declaration for the given record. toMaybeDec :: RecordInfo -> Q Dec toMaybeDec RecordInfo {..} = do - valueName <- newName "value" - let patterns = [VarP valueName] - body = NormalB $ RecConE sqlMaybeName fields - fields = [] - decs = [] - pure $ FunD 'toMaybe [Clause patterns body decs] + (fieldPatterns, fieldExps) <- + unzip <$> forM (zip sqlFields sqlMaybeFields) (\((fieldName', _), (maybeFieldName', _)) -> do + fieldPatternName <- newName (nameBase fieldName') + pure + ( (fieldName', VarP fieldPatternName) + , (maybeFieldName', VarE 'toMaybe `AppE` VarE fieldPatternName) + )) + + pure $ + FunD + 'toMaybe + [ Clause + [ RecP sqlName fieldPatterns + ] + (NormalB $ RecConE sqlMaybeName fieldExps) + [] + ] + +-- | Generates an `SqlSelect` instance for the given record and its +-- @Sql@-prefixed variant. +makeSqlMaybeRecordSelectInstance :: RecordInfo -> Q Dec +makeSqlMaybeRecordSelectInstance info@RecordInfo {..} = do + sqlSelectColsDec' <- sqlMaybeSelectColsDec info + sqlSelectColCountDec' <- sqlMaybeSelectColCountDec info + sqlSelectProcessRowDec' <- sqlMaybeSelectProcessRowDec info + let overlap = Nothing + instanceConstraints = [] + instanceType = + (ConT ''SqlSelect) + `AppT` (ConT sqlMaybeName) + `AppT` (AppT (ConT ''Maybe) (ConT name)) + + pure $ InstanceD overlap instanceConstraints instanceType [sqlSelectColsDec', sqlSelectColCountDec', sqlSelectProcessRowDec'] + +-- | Generates the `sqlSelectCols` declaration for an `SqlSelect` instance. +sqlMaybeSelectColsDec :: RecordInfo -> Q Dec +sqlMaybeSelectColsDec RecordInfo {..} = do + -- Pairs of record field names and local variable names. + fieldNames <- forM sqlMaybeFields (\(name', _type) -> do + var <- newName $ nameBase name' + pure (name', var)) + + -- Patterns binding record fields to local variables. + let fieldPatterns :: [FieldPat] + fieldPatterns = [(name', VarP var) | (name', var) <- fieldNames] + + -- Local variables for fields joined with `:&` in a single expression. + joinedFields :: Exp + joinedFields = + case snd `map` fieldNames of + [] -> TupE [] + [f1] -> VarE f1 + f1 : rest -> + let helper lhs field = + InfixE + (Just lhs) + (ConE '(:&)) + (Just $ VarE field) + in foldl' helper (VarE f1) rest + + identInfo <- newName "identInfo" + -- Roughly: + -- sqlSelectCols $identInfo SqlFoo{..} = sqlSelectCols $identInfo $joinedFields + pure $ + FunD + 'sqlSelectCols + [ Clause + [ VarP identInfo + , RecP sqlMaybeName fieldPatterns + ] + ( NormalB $ + (VarE 'sqlSelectCols) + `AppE` (VarE identInfo) + `AppE` (ParensE joinedFields) + ) + -- `where` clause. + [] + ] + +-- | Generates the `sqlSelectProcessRow` declaration for an `SqlSelect` +-- instance. +sqlMaybeSelectProcessRowDec :: RecordInfo -> Q Dec +sqlMaybeSelectProcessRowDec RecordInfo {..} = do + let + sqlOp x = case x of + -- AppT (ConT ((==) ''Entity -> True)) _innerType -> id + -- (ConT ((==) ''Maybe -> True)) `AppT` ((ConT ((==) ''Entity -> True)) `AppT` _innerType) -> (AppE (VarE 'pure)) + -- inner@((ConT ((==) ''Maybe -> True)) `AppT` _inner) -> (AppE (VarE 'unValue)) + (AppT (ConT ((==) ''SqlExpr -> True)) (AppT (ConT ((==) ''Value -> True)) _)) -> (AppE (VarE 'unValue)) + (AppT (ConT ((==) ''SqlExpr -> True)) (AppT (ConT ((==) ''Entity -> True)) _)) -> id + (AppT (ConT ((==) ''SqlExpr -> True)) (AppT (ConT ((==) ''Maybe -> True)) _)) -> (AppE (VarE 'pure)) + (ConT _) -> id + _ -> error $ show x + + fieldNames <- forM sqlFields (\(name', typ) -> do + var <- newName $ nameBase name' + pure (name', var, sqlOp typ (VarE var))) + + let + joinedFields = + case (\(_,x,_) -> x) `map` fieldNames of + [] -> TupP [] + [f1] -> VarP f1 + f1 : rest -> + let helper lhs field = + InfixP + lhs + '(:&) + (VarP field) + in foldl' helper (VarP f1) rest + + + colsName <- newName "columns" + + let + bodyExp = DoE Nothing + [ BindS joinedFields (AppE (VarE 'sqlSelectProcessRow) (VarE colsName)) + , NoBindS + $ AppE (VarE 'pure) ( + case fieldNames of + [] -> ConE constructorName + (_,_,e):xs -> foldl' + (\acc (_,_,e2) -> AppE (AppE (VarE '(<*>)) acc) e2) + (AppE (AppE (VarE 'fmap) (ConE constructorName)) e) + xs + ) + ] + + pure $ + FunD + 'sqlSelectProcessRow + [ Clause + [VarP colsName] + (NormalB bodyExp) + [] + ] + +-- | Generates the `sqlSelectColCount` declaration for an `SqlSelect` instance. +sqlMaybeSelectColCountDec :: RecordInfo -> Q Dec +sqlMaybeSelectColCountDec RecordInfo {..} = do + let joinedTypes = + case snd `map` sqlMaybeFields of + [] -> TupleT 0 + t1 : rest -> + let helper lhs ty = + InfixT lhs ''(:&) ty + in foldl' helper t1 rest + + -- Roughly: + -- sqlSelectColCount _ = sqlSelectColCount (Proxy @($joinedTypes)) + pure $ + FunD + 'sqlSelectColCount + [ Clause + [WildP] + ( NormalB $ + AppE (VarE 'sqlSelectColCount) $ + ParensE $ + AppTypeE + (ConE 'Proxy) + joinedTypes + ) + -- `where` clause. + [] + ] + +-- | Statefully parse some number of columns from a list of `PersistValue`s, +-- where the number of columns to parse is determined by `sqlSelectColCount` +-- for @a@. +-- +-- This is used to implement `sqlSelectProcessRow` for records created with +-- `deriveEsqueletoRecord`. +takeMaybeColumns :: + forall a b. + (SqlSelect a (ToMaybeT b)) => + StateT [PersistValue] (Either Text) (ToMaybeT b) +takeMaybeColumns = StateT (\pvs -> + let targetColCount = + sqlSelectColCount (Proxy @a) + (target, other) = + splitAt targetColCount pvs + in if length target == targetColCount + then do + value <- sqlSelectProcessRow target + Right (value, other) + else Left "Insufficient columns when trying to parse a column") diff --git a/test/Common/Record.hs b/test/Common/Record.hs index 5cb1599ed..b4172942d 100644 --- a/test/Common/Record.hs +++ b/test/Common/Record.hs @@ -8,25 +8,34 @@ {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedLabels #-} +{-# LANGUAGE OverloadedRecordDot #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE QuasiQuotes #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} -- Tests for `Database.Esqueleto.Record`. module Common.Record (testDeriveEsqueletoRecord) where import Common.Test.Import hiding (from, on) +import Control.Monad.Trans.State.Strict (StateT(..), evalStateT) +import Data.Bifunctor (first) import Data.List (sortOn) +import Data.Maybe (catMaybes) +import Data.Proxy (Proxy(..)) import Database.Esqueleto.Experimental +import Database.Esqueleto.Internal.Internal (SqlSelect(..)) import Database.Esqueleto.Record ( DeriveEsqueletoRecordSettings(..) , defaultDeriveEsqueletoRecordSettings , deriveEsqueletoRecord , deriveEsqueletoRecordWith + , takeColumns + , takeMaybeColumns ) data MyRecord = @@ -112,6 +121,15 @@ myModifiedRecordQuery = do , myModifiedAddressSql = address } +mySubselectRecordQuery :: SqlQuery (SqlExpr (Maybe (Entity Address))) +mySubselectRecordQuery = do + _ :& record <- from $ + table @User + `leftJoin` + myRecordQuery + `on` (do \(user :& record) -> just (user ^. #id) ==. record.myUser ?. #id) + pure $ record.myAddress + testDeriveEsqueletoRecord :: SpecDb testDeriveEsqueletoRecord = describe "deriveEsqueletoRecord" $ do let setup :: MonadIO m => SqlPersistT m () @@ -208,7 +226,6 @@ testDeriveEsqueletoRecord = describe "deriveEsqueletoRecord" $ do } -> addr1 == addr2 -- The keys should match. _ -> False) - itDb "can select user-modified records" $ do setup records <- select myModifiedRecordQuery @@ -235,3 +252,64 @@ testDeriveEsqueletoRecord = describe "deriveEsqueletoRecord" $ do , myModifiedAddress = Just (Entity addr2 Address {addressAddress = "30-50 Feral Hogs Rd"}) } -> addr1 == addr2 -- The keys should match. _ -> False) + + itDb "can left join on records" $ do + setup + records <- select $ do + from + ( table @User + `leftJoin` myRecordQuery `on` (do \(user :& record) -> just (user ^. #id) ==. record.myUser ?. #id) + ) + let sortedRecords = sortOn (\(Entity _ user :& _) -> user.userName) records + liftIO $ sortedRecords !! 0 + `shouldSatisfy` + (\case (_ :& Just (MyRecord {myName = "Rebecca", myAddress = Nothing})) -> True + _ -> False) + liftIO $ sortedRecords !! 1 + `shouldSatisfy` + (\case ( _ :& Just ( MyRecord { myName = "Some Guy" + , myAddress = (Just (Entity addr2 Address {addressAddress = "30-50 Feral Hogs Rd"})) + } + )) -> True + _ -> True) + + itDb "can can handle joins on records with Nothing" $ do + setup + records <- select $ do + from + ( table @User + `leftJoin` myRecordQuery `on` (do \(user :& record) -> user ^. #address ==. record.myAddress ?. #id) + ) + let sortedRecords = sortOn (\(Entity _ user :& _) -> user.userName) records + liftIO $ sortedRecords !! 0 + `shouldSatisfy` + (\case (_ :& Nothing) -> True + _ -> False) + liftIO $ sortedRecords !! 1 + `shouldSatisfy` + (\case ( _ :& Just ( MyRecord { myName = "Some Guy" + , myAddress = (Just (Entity addr2 Address {addressAddress = "30-50 Feral Hogs Rd"})) + } + )) -> True + _ -> True) + + itDb "can left join on nested records" $ do + setup + records <- select $ do + from + ( table @User + `leftJoin` myNestedRecordQuery + `on` (do \(user :& record) -> just (user ^. #id) ==. record.myRecord.myUser ?. #id) + ) + let sortedRecords = sortOn (\(Entity _ user :& _) -> user.userName) records + liftIO $ sortedRecords !! 0 + `shouldSatisfy` + (\case (_ :& Just (MyNestedRecord {myRecord = MyRecord {myName = "Rebecca", myAddress = Nothing}})) -> True + _ -> False) + liftIO $ sortedRecords !! 1 + `shouldSatisfy` + (\case ( _ :& Just ( MyNestedRecord { myRecord = MyRecord { myName = "Some Guy" + , myAddress = (Just (Entity addr2 Address {addressAddress = "30-50 Feral Hogs Rd"})) + } + })) -> True + _ -> True)