Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Compilation of side conditions in pattern matches #2984

Merged
merged 9 commits into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion src/Juvix/Compiler/Core/Evaluator.hs
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ geval opts herr tab env0 = eval' env0
match n env vs = \case
br : brs ->
case matchPatterns [] vs (toList (br ^. matchBranchPatterns)) of
Just args -> eval' (args ++ env) (br ^. matchBranchBody)
Just args -> matchRhs (args ++ env) (br ^. matchBranchRhs)
Nothing -> match n env vs brs
where
matchPatterns :: [Node] -> [Node] -> [Pattern] -> Maybe [Node]
Expand All @@ -169,6 +169,18 @@ geval opts herr tab env0 = eval' env0
| tag == _patternConstrTag =
matchPatterns (v : acc) args _patternConstrArgs
patmatch _ _ _ = Nothing

matchRhs :: [Node] -> MatchBranchRhs -> Node
matchRhs env' = \case
MatchBranchRhsExpression e -> eval' env' e
MatchBranchRhsIfs ifs -> matchIfs env' (toList ifs)

matchIfs :: [Node] -> [SideIfBranch] -> Node
matchIfs env' = \case
SideIfBranch {..} : ifs -> case eval' env' _sideIfBranchCondition of
NCtr (Constr _ (BuiltinTag TagTrue) []) -> eval' env' _sideIfBranchBody
_ -> matchIfs env' ifs
[] -> match n env vs brs
[] ->
evalError "no matching pattern" (substEnv env n)

Expand Down
54 changes: 41 additions & 13 deletions src/Juvix/Compiler/Core/Extra/Base.hs
Original file line number Diff line number Diff line change
Expand Up @@ -101,12 +101,6 @@ mkMatch i vtys rty vs bs = NMatch (Match i vtys rty vs bs)
mkMatch' :: NonEmpty Type -> Type -> NonEmpty Node -> [MatchBranch] -> Node
mkMatch' = mkMatch Info.empty

mkMatchBranch :: Info -> NonEmpty Pattern -> Node -> MatchBranch
mkMatchBranch = MatchBranch

mkMatchBranch' :: NonEmpty Pattern -> Node -> MatchBranch
mkMatchBranch' = MatchBranch mempty

mkIf :: Info -> Symbol -> Node -> Node -> Node -> Node
mkIf i sym v b1 b2 = mkCase i sym v [br] (Just b2)
where
Expand All @@ -122,6 +116,14 @@ mkIf i sym v b1 b2 = mkCase i sym v [br] (Just b2)
mkIf' :: Symbol -> Node -> Node -> Node -> Node
mkIf' = mkIf Info.empty

mkIfs :: Symbol -> [(Info, Node, Node)] -> Node -> Node
mkIfs sym = \case
[] -> id
((i, v, b) : rest) -> mkIf i sym v b . mkIfs sym rest

mkIfs' :: Symbol -> [(Node, Node)] -> Node -> Node
mkIfs' sym = mkIfs sym . map (\(v, b) -> (Info.empty, v, b))

mkBinder :: Text -> Type -> Binder
mkBinder name ty = Binder name Nothing ty

Expand Down Expand Up @@ -641,18 +643,27 @@ destruct = \case
: map noBinders (toList vtys)
++ map noBinders (toList vs)
++ concat
[ br
: reverse (foldl' (\acc b -> manyBinders (take (length acc) bis) (b ^. binderType) : acc) [] bis)
| (bis, br) <- branchChildren
[ brs
++ reverse (foldl' (\acc b -> manyBinders (take (length acc) bis) (b ^. binderType) : acc) [] bis)
| (bis, brs) <- branchChildren
]
where
branchChildren :: [([Binder], NodeChild)]
branchChildren :: [([Binder], [NodeChild])]
branchChildren =
[ (binders, manyBinders binders (br ^. matchBranchBody))
[ (binders, map (manyBinders binders) (branchRhsChildren (br ^. matchBranchRhs)))
| br <- branches,
let binders = concatMap getPatternBinders (toList (br ^. matchBranchPatterns))
]

branchRhsChildren :: MatchBranchRhs -> [Node]
branchRhsChildren = \case
MatchBranchRhsExpression e -> [e]
MatchBranchRhsIfs ifs -> concatMap sideIfBranchChildren ifs

sideIfBranchChildren :: SideIfBranch -> [Node]
sideIfBranchChildren SideIfBranch {..} =
[_sideIfBranchCondition, _sideIfBranchBody]

branchInfos :: [Info]
branchInfos =
concat
Expand Down Expand Up @@ -684,14 +695,31 @@ destruct = \case
let mkBranch :: MatchBranch -> Sem '[Input Node, Input Info] MatchBranch
mkBranch br = do
bi' <- inputJust
b' <- inputJust
b' <- mkBranchRhs (br ^. matchBranchRhs)
pats' <- setPatternsInfos (br ^. matchBranchPatterns)
return
br
{ _matchBranchInfo = bi',
_matchBranchPatterns = pats',
_matchBranchBody = b'
_matchBranchRhs = b'
}
mkBranchRhs :: MatchBranchRhs -> Sem '[Input Node, Input Info] MatchBranchRhs
mkBranchRhs = \case
MatchBranchRhsExpression _ -> do
e' <- inputJust
return (MatchBranchRhsExpression e')
MatchBranchRhsIfs ifs -> do
ifs' <- mkSideIfs ifs
return (MatchBranchRhsIfs ifs')
mkSideIfs :: NonEmpty SideIfBranch -> Sem '[Input Node, Input Info] (NonEmpty SideIfBranch)
mkSideIfs brs =
mapM mkSideIfBranch brs
mkSideIfBranch :: SideIfBranch -> Sem '[Input Node, Input Info] SideIfBranch
mkSideIfBranch _ = do
_sideIfBranchInfo <- inputJust
_sideIfBranchCondition <- inputJust
_sideIfBranchBody <- inputJust
return SideIfBranch {..}
numVals = length vs
values' :: NonEmpty Node
valueTypes' :: NonEmpty Node
Expand Down
2 changes: 1 addition & 1 deletion src/Juvix/Compiler/Core/Extra/Recursors/RMap/Named.hs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ where
_ ->
recur [] node
where
cont :: Level -> [BinderChange] -> Node -> Node
cont :: [BinderChange] -> Node -> Node
cont bcs = go (recur . (bcs ++)) (k + bindersNumFromBinderChange bcs)
```
produces
Expand Down
4 changes: 4 additions & 0 deletions src/Juvix/Compiler/Core/Language.hs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ type PatternWildcard = PatternWildcard' Info Node

type PatternConstr = PatternConstr' Info Node

type MatchBranchRhs = MatchBranchRhs' Info Node

type SideIfBranch = SideIfBranch' Info Node

type Pattern = Pattern' Info Node

type PiLhs = PiLhs' Info Node
Expand Down
22 changes: 21 additions & 1 deletion src/Juvix/Compiler/Core/Language/Nodes.hs
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ data Match' i a = Match
data MatchBranch' i a = MatchBranch
{ _matchBranchInfo :: i,
_matchBranchPatterns :: !(NonEmpty (Pattern' i a)),
_matchBranchBody :: !a
_matchBranchRhs :: !(MatchBranchRhs' i a)
}

data Pattern' i a
Expand All @@ -202,6 +202,16 @@ data PatternConstr' i a = PatternConstr
_patternConstrArgs :: ![Pattern' i a]
}

data MatchBranchRhs' i a
= MatchBranchRhsExpression !a
| MatchBranchRhsIfs !(NonEmpty (SideIfBranch' i a))

data SideIfBranch' i a = SideIfBranch
{ _sideIfBranchInfo :: i,
_sideIfBranchCondition :: !a,
_sideIfBranchBody :: !a
}

-- | Useful for unfolding Pi
data PiLhs' i a = PiLhs
{ _piLhsInfo :: i,
Expand Down Expand Up @@ -437,8 +447,10 @@ makeLenses ''Case'
makeLenses ''CaseBranch'
makeLenses ''Match'
makeLenses ''MatchBranch'
makeLenses ''MatchBranchRhs'
makeLenses ''PatternWildcard'
makeLenses ''PatternConstr'
makeLenses ''SideIfBranch'
makeLenses ''Pi'
makeLenses ''Lambda'
makeLenses ''Univ'
Expand Down Expand Up @@ -528,12 +540,20 @@ instance (Eq a) => Eq (Pi' i a) where
eqOn (^. piBinder . binderType)
..&&.. eqOn (^. piBody)

instance (Eq a) => Eq (MatchBranchRhs' i a) where
(MatchBranchRhsExpression e1) == (MatchBranchRhsExpression e2) = e1 == e2
(MatchBranchRhsIfs ifs1) == (MatchBranchRhsIfs ifs2) = ifs1 == ifs2
_ == _ = False

instance (Eq a) => Eq (MatchBranch' i a) where
(MatchBranch _ pats1 b1) == (MatchBranch _ pats2 b2) = pats1 == pats2 && b1 == b2

instance (Eq a) => Eq (PatternConstr' i a) where
(PatternConstr _ _ tag1 ps1) == (PatternConstr _ _ tag2 ps2) = tag1 == tag2 && ps1 == ps2

instance (Eq a) => Eq (SideIfBranch' i a) where
(SideIfBranch _ c1 b1) == (SideIfBranch _ c2 b2) = c1 == c2 && b1 == b2

instance Hashable (Ident' i) where
hashWithSalt s = hashWithSalt s . (^. identSymbol)

Expand Down
21 changes: 19 additions & 2 deletions src/Juvix/Compiler/Core/Pretty/Base.hs
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,23 @@ instance PrettyCode Bottom where
ty' <- ppCode _bottomType
return (parens (kwBottom <+> kwColon <+> ty'))

instance PrettyCode SideIfBranch where
ppCode :: (Member (Reader Options) r) => SideIfBranch -> Sem r (Doc Ann)
ppCode SideIfBranch {..} = do
cond <- ppCode _sideIfBranchCondition
body <- ppCode _sideIfBranchBody
return $ kwIf <+> cond <+> kwAssign <+> body

instance PrettyCode MatchBranchRhs where
ppCode :: (Member (Reader Options) r) => MatchBranchRhs -> Sem r (Doc Ann)
ppCode = \case
MatchBranchRhsExpression x -> do
e <- ppCode x
return $ kwAssign <+> e
MatchBranchRhsIfs x -> do
brs <- mapM ppCode x
return $ vsep brs

instance PrettyCode Node where
ppCode :: forall r. (Member (Reader Options) r) => Node -> Sem r (Doc Ann)
ppCode node = case node of
Expand All @@ -394,11 +411,11 @@ instance PrettyCode Node where
ppCodeCase' branchBinderNames branchBinderTypes branchTagNames x
NMatch Match {..} -> do
let branchPatterns = map (^. matchBranchPatterns) _matchBranches
branchBodies = map (^. matchBranchBody) _matchBranches
branchRhs = map (^. matchBranchRhs) _matchBranches
pats <- mapM ppPatterns branchPatterns
vs <- mapM ppCode _matchValues
vs' <- zipWithM ppWithType (toList vs) (toList _matchValueTypes)
bs <- sequence $ zipWithExact (\ps br -> ppCode br >>= \br' -> return $ ps <+> kwAssign <+> br') pats branchBodies
bs <- sequence $ zipWithExact (\ps br -> ppCode br >>= \br' -> return $ ps <+> br') pats branchRhs
let bss = bracesIndent $ align $ concatWith (\a b -> a <> kwSemicolon <> line <> b) bs
rty <- ppTypeAnnot _matchReturnType
return $ kwMatch <+> hsep (punctuate comma vs') <+> kwWith <> rty <+> bss
Expand Down
24 changes: 17 additions & 7 deletions src/Juvix/Compiler/Core/Transformation/MatchToCase.hs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import Juvix.Compiler.Core.Transformation.Base

data PatternRow = PatternRow
{ _patternRowPatterns :: [Pattern],
_patternRowBody :: Node,
_patternRowRhs :: MatchBranchRhs,
-- | The number of initial wildcard binders in `_patternRowPatterns` which
-- don't originate from the input
_patternRowIgnoredPatternsNum :: Int,
Expand Down Expand Up @@ -58,7 +58,7 @@ goMatchToCase recur node = case node of
matchBranchToPatternRow MatchBranch {..} =
PatternRow
{ _patternRowPatterns = toList _matchBranchPatterns,
_patternRowBody = _matchBranchBody,
_patternRowRhs = _matchBranchRhs,
_patternRowIgnoredPatternsNum = 0,
_patternRowBinderChangesRev = [BCAdd n]
}
Expand Down Expand Up @@ -104,10 +104,10 @@ goMatchToCase recur node = case node of
pat' = if length pat == 1 then doc defaultOptions (head' pat) else docValueSequence pat
mockFile = $(mkAbsFile "/match-to-case")
defaultLoc = singletonInterval (mkInitialLoc mockFile)
r@PatternRow {..} : _
r@PatternRow {..} : matrix'
| all isPatWildcard _patternRowPatterns ->
-- The first row matches all values (Section 4, case 2)
compileMatchingRow bindersNum vs r
compileMatchingRow err bindersNum vs matrix' r
_ -> do
-- Section 4, case 3
-- Select the first column
Expand Down Expand Up @@ -181,9 +181,19 @@ goMatchToCase recur node = case node of
where
ii = lookupInductiveInfo md ind

compileMatchingRow :: Level -> [Level] -> PatternRow -> Sem r Node
compileMatchingRow bindersNum vs PatternRow {..} =
goMatchToCase (recur . (bcs ++)) _patternRowBody
compileMatchingRow :: ([Value] -> [Value]) -> Level -> [Level] -> PatternMatrix -> PatternRow -> Sem r Node
compileMatchingRow err bindersNum vs matrix PatternRow {..} =
case _patternRowRhs of
MatchBranchRhsExpression body ->
goMatchToCase (recur . (bcs ++)) body
MatchBranchRhsIfs ifs -> do
-- If the branch has side-conditions, then we need to continue pattern
-- matching when none of the conditions is satisfied.
body <- compile err bindersNum vs matrix
md <- ask
let boolSym = lookupConstructorInfo md (BuiltinTag TagTrue) ^. constructorInductive
ifs' = map (\(SideIfBranch i c b) -> (i, c, b)) (toList ifs)
return $ mkIfs boolSym ifs' body
where
bcs =
reverse $
Expand Down
50 changes: 36 additions & 14 deletions src/Juvix/Compiler/Core/Translation/FromInternal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -492,41 +492,63 @@ goCase c = do
rty <- goType (fromJust $ c ^. Internal.caseExpressionWholeType)
return (mkMatch i (pure ty) rty (pure expr) branches)
_ ->
-- If the type of the value matched on is not an inductive type, then the
-- case expression has one branch with a variable pattern.
case c ^. Internal.caseBranches of
Internal.CaseBranch {..} :| _ ->
case _caseBranchPattern ^. Internal.patternArgPattern of
Internal.PatternVariable name -> do
vars <- asks (^. indexTableVars)
varsNum <- asks (^. indexTableVarsNum)
let vars' = addPatternVariableNames _caseBranchPattern varsNum vars
body <-
rhs <-
local
(set indexTableVars vars')
(underBinders 1 (goCaseBranchRhs _caseBranchRhs))
return $ mkLet i (Binder (name ^. nameText) (Just $ name ^. nameLoc) ty) expr body
case rhs of
MatchBranchRhsExpression body ->
return $ mkLet i (Binder (name ^. nameText) (Just $ name ^. nameLoc) ty) expr body
_ ->
impossible
_ ->
impossible
where
goCaseBranch :: Type -> Internal.CaseBranch -> Sem r MatchBranch
goCaseBranch ty b = goPatternArgs 0 (b ^. Internal.caseBranchRhs) [b ^. Internal.caseBranchPattern] [ty]

-- | FIXME Fix this as soon as side if conditions are implemented in Core. This
-- is needed so that we can test typechecking without a crash.
todoSideIfs ::
forall r.
(Members '[InfoTableBuilder, Reader InternalTyped.TypesTable, Reader InternalTyped.FunctionsTable, Reader Internal.InfoTable, Reader IndexTable, NameIdGen, Error BadScope] r) =>
Internal.SideIfs ->
Sem r Node
todoSideIfs s = goExpression (s ^. Internal.sideIfBranches . _head1 . Internal.sideIfBranchBody)

goCaseBranchRhs ::
forall r.
(Members '[InfoTableBuilder, Reader InternalTyped.TypesTable, Reader InternalTyped.FunctionsTable, Reader Internal.InfoTable, Reader IndexTable, NameIdGen, Error BadScope] r) =>
Internal.CaseBranchRhs ->
Sem r Node
Sem r MatchBranchRhs
goCaseBranchRhs = \case
Internal.CaseBranchRhsExpression e -> goExpression e
Internal.CaseBranchRhsIf i -> todoSideIfs i
Internal.CaseBranchRhsExpression e -> MatchBranchRhsExpression <$> goExpression e
Internal.CaseBranchRhsIf Internal.SideIfs {..} -> case _sideIfElse of
Just elseBranch -> do
branches <- toList <$> mapM goSideIfBranch _sideIfBranches
elseBranch' <- goExpression elseBranch
boolSym <- getBoolSymbol
return $ MatchBranchRhsExpression $ mkIfs' boolSym branches elseBranch'
where
goSideIfBranch :: Internal.SideIfBranch -> Sem r (Node, Node)
goSideIfBranch Internal.SideIfBranch {..} = do
cond <- goExpression _sideIfBranchCondition
body <- goExpression _sideIfBranchBody
return (cond, body)
Nothing -> do
branches <- mapM goSideIfBranch _sideIfBranches
return $ MatchBranchRhsIfs branches
where
goSideIfBranch :: Internal.SideIfBranch -> Sem r SideIfBranch
goSideIfBranch Internal.SideIfBranch {..} = do
cond <- goExpression _sideIfBranchCondition
body <- goExpression _sideIfBranchBody
return $
SideIfBranch
{ _sideIfBranchInfo = setInfoLocation (getLoc _sideIfBranchCondition) mempty,
_sideIfBranchCondition = cond,
_sideIfBranchBody = body
}

goLambda ::
forall r.
Expand Down
Loading
Loading