From f2993e3451f9cf27ab37af0c39d8a7ec0c4f0258 Mon Sep 17 00:00:00 2001 From: Lukasz Czajka Date: Fri, 30 Aug 2024 21:09:34 +0200 Subject: [PATCH 1/9] add side conditions in Core Match (wip) --- src/Juvix/Compiler/Core/Extra/Base.hs | 29 +++++++++++++---------- src/Juvix/Compiler/Core/Language.hs | 2 ++ src/Juvix/Compiler/Core/Language/Nodes.hs | 12 +++++++++- 3 files changed, 29 insertions(+), 14 deletions(-) diff --git a/src/Juvix/Compiler/Core/Extra/Base.hs b/src/Juvix/Compiler/Core/Extra/Base.hs index 951876c840..6b5f062198 100644 --- a/src/Juvix/Compiler/Core/Extra/Base.hs +++ b/src/Juvix/Compiler/Core/Extra/Base.hs @@ -101,12 +101,6 @@ mkMatch i vtys rty vs bs = NMatch (Match i vtys rty vs bs) mkMatch' :: NonEmpty Type -> Type -> NonEmpty Node -> [MatchBranch] -> Node mkMatch' = mkMatch Info.empty -mkMatchBranch :: Info -> NonEmpty Pattern -> Node -> MatchBranch -mkMatchBranch = MatchBranch - -mkMatchBranch' :: NonEmpty Pattern -> Node -> MatchBranch -mkMatchBranch' = MatchBranch mempty - mkIf :: Info -> Symbol -> Node -> Node -> Node -> Node mkIf i sym v b1 b2 = mkCase i sym v [br] (Just b2) where @@ -641,18 +635,21 @@ destruct = \case : map noBinders (toList vtys) ++ map noBinders (toList vs) ++ concat - [ br - : reverse (foldl' (\acc b -> manyBinders (take (length acc) bis) (b ^. binderType) : acc) [] bis) - | (bis, br) <- branchChildren + [ brs + ++ reverse (foldl' (\acc b -> manyBinders (take (length acc) bis) (b ^. binderType) : acc) [] bis) + | (bis, brs) <- branchChildren ] where - branchChildren :: [([Binder], NodeChild)] + branchChildren :: [([Binder], [NodeChild])] branchChildren = - [ (binders, manyBinders binders (br ^. matchBranchBody)) + [ (binders, map (manyBinders binders) (concatMap branchRhsChildren (br ^. matchBranchRhs))) | br <- branches, let binders = concatMap getPatternBinders (toList (br ^. matchBranchPatterns)) ] + branchRhsChildren :: SideIfBranch -> [Node] + branchRhsChildren = undefined + branchInfos :: [Info] branchInfos = concat @@ -684,14 +681,20 @@ destruct = \case let mkBranch :: MatchBranch -> Sem '[Input Node, Input Info] MatchBranch mkBranch br = do bi' <- inputJust - b' <- inputJust + b' <- mkBranchRhs (br ^. matchBranchRhs) pats' <- setPatternsInfos (br ^. matchBranchPatterns) return br { _matchBranchInfo = bi', _matchBranchPatterns = pats', - _matchBranchBody = b' + _matchBranchRhs = b' } + mkBranchRhs :: NonEmpty SideIfBranch -> Sem '[Input Node, Input Info] (NonEmpty SideIfBranch) + mkBranchRhs brs = + mapM mkSideIfBranch brs + mkSideIfBranch :: SideIfBranch -> Sem '[Input Node, Input Info] SideIfBranch + mkSideIfBranch = + undefined numVals = length vs values' :: NonEmpty Node valueTypes' :: NonEmpty Node diff --git a/src/Juvix/Compiler/Core/Language.hs b/src/Juvix/Compiler/Core/Language.hs index 36adbc2b7f..564abc6390 100644 --- a/src/Juvix/Compiler/Core/Language.hs +++ b/src/Juvix/Compiler/Core/Language.hs @@ -47,6 +47,8 @@ type PatternWildcard = PatternWildcard' Info Node type PatternConstr = PatternConstr' Info Node +type SideIfBranch = SideIfBranch' Info Node + type Pattern = Pattern' Info Node type PiLhs = PiLhs' Info Node diff --git a/src/Juvix/Compiler/Core/Language/Nodes.hs b/src/Juvix/Compiler/Core/Language/Nodes.hs index 833fdcd285..5f759ac7cf 100644 --- a/src/Juvix/Compiler/Core/Language/Nodes.hs +++ b/src/Juvix/Compiler/Core/Language/Nodes.hs @@ -183,7 +183,7 @@ data Match' i a = Match data MatchBranch' i a = MatchBranch { _matchBranchInfo :: i, _matchBranchPatterns :: !(NonEmpty (Pattern' i a)), - _matchBranchBody :: !a + _matchBranchRhs :: !(NonEmpty (SideIfBranch' i a)) } data Pattern' i a @@ -202,6 +202,12 @@ data PatternConstr' i a = PatternConstr _patternConstrArgs :: ![Pattern' i a] } +data SideIfBranch' i a = SideIfBranch + { _sizeIfBranchInfo :: i, + _sideIfBranchCondition :: !a, + _sideIfBranchBody :: !a + } + -- | Useful for unfolding Pi data PiLhs' i a = PiLhs { _piLhsInfo :: i, @@ -439,6 +445,7 @@ makeLenses ''Match' makeLenses ''MatchBranch' makeLenses ''PatternWildcard' makeLenses ''PatternConstr' +makeLenses ''SideIfBranch' makeLenses ''Pi' makeLenses ''Lambda' makeLenses ''Univ' @@ -534,6 +541,9 @@ instance (Eq a) => Eq (MatchBranch' i a) where instance (Eq a) => Eq (PatternConstr' i a) where (PatternConstr _ _ tag1 ps1) == (PatternConstr _ _ tag2 ps2) = tag1 == tag2 && ps1 == ps2 +instance (Eq a) => Eq (SideIfBranch' i a) where + (SideIfBranch _ c1 b1) == (SideIfBranch _ c2 b2) = c1 == c2 && b1 == b2 + instance Hashable (Ident' i) where hashWithSalt s = hashWithSalt s . (^. identSymbol) From 52c508ae170228e277c98ef9a4305e710efd26fa Mon Sep 17 00:00:00 2001 From: Lukasz Czajka Date: Mon, 2 Sep 2024 12:32:55 +0200 Subject: [PATCH 2/9] update destruct for Match nodes --- src/Juvix/Compiler/Core/Extra/Base.hs | 31 ++++++++++++++++++----- src/Juvix/Compiler/Core/Language.hs | 2 ++ src/Juvix/Compiler/Core/Language/Nodes.hs | 14 ++++++++-- 3 files changed, 38 insertions(+), 9 deletions(-) diff --git a/src/Juvix/Compiler/Core/Extra/Base.hs b/src/Juvix/Compiler/Core/Extra/Base.hs index 6b5f062198..62fba58f05 100644 --- a/src/Juvix/Compiler/Core/Extra/Base.hs +++ b/src/Juvix/Compiler/Core/Extra/Base.hs @@ -642,13 +642,19 @@ destruct = \case where branchChildren :: [([Binder], [NodeChild])] branchChildren = - [ (binders, map (manyBinders binders) (concatMap branchRhsChildren (br ^. matchBranchRhs))) + [ (binders, map (manyBinders binders) (branchRhsChildren (br ^. matchBranchRhs))) | br <- branches, let binders = concatMap getPatternBinders (toList (br ^. matchBranchPatterns)) ] - branchRhsChildren :: SideIfBranch -> [Node] - branchRhsChildren = undefined + branchRhsChildren :: MatchBranchRhs -> [Node] + branchRhsChildren = \case + MatchBranchRhsExpression e -> [e] + MatchBranchRhsIfs ifs -> concatMap sideIfBranchChildren ifs + + sideIfBranchChildren :: SideIfBranch -> [Node] + sideIfBranchChildren SideIfBranch {..} = + [_sideIfBranchCondition, _sideIfBranchBody] branchInfos :: [Info] branchInfos = @@ -689,12 +695,23 @@ destruct = \case _matchBranchPatterns = pats', _matchBranchRhs = b' } - mkBranchRhs :: NonEmpty SideIfBranch -> Sem '[Input Node, Input Info] (NonEmpty SideIfBranch) - mkBranchRhs brs = + mkBranchRhs :: MatchBranchRhs -> Sem '[Input Node, Input Info] MatchBranchRhs + mkBranchRhs = \case + MatchBranchRhsExpression _ -> do + e' <- inputJust + return (MatchBranchRhsExpression e') + MatchBranchRhsIfs ifs -> do + ifs' <- mkSideIfs ifs + return (MatchBranchRhsIfs ifs') + mkSideIfs :: NonEmpty SideIfBranch -> Sem '[Input Node, Input Info] (NonEmpty SideIfBranch) + mkSideIfs brs = mapM mkSideIfBranch brs mkSideIfBranch :: SideIfBranch -> Sem '[Input Node, Input Info] SideIfBranch - mkSideIfBranch = - undefined + mkSideIfBranch _ = do + _sideIfBranchInfo <- inputJust + _sideIfBranchCondition <- inputJust + _sideIfBranchBody <- inputJust + return SideIfBranch {..} numVals = length vs values' :: NonEmpty Node valueTypes' :: NonEmpty Node diff --git a/src/Juvix/Compiler/Core/Language.hs b/src/Juvix/Compiler/Core/Language.hs index 564abc6390..7542c90734 100644 --- a/src/Juvix/Compiler/Core/Language.hs +++ b/src/Juvix/Compiler/Core/Language.hs @@ -47,6 +47,8 @@ type PatternWildcard = PatternWildcard' Info Node type PatternConstr = PatternConstr' Info Node +type MatchBranchRhs = MatchBranchRhs' Info Node + type SideIfBranch = SideIfBranch' Info Node type Pattern = Pattern' Info Node diff --git a/src/Juvix/Compiler/Core/Language/Nodes.hs b/src/Juvix/Compiler/Core/Language/Nodes.hs index 5f759ac7cf..aa164b0381 100644 --- a/src/Juvix/Compiler/Core/Language/Nodes.hs +++ b/src/Juvix/Compiler/Core/Language/Nodes.hs @@ -183,7 +183,7 @@ data Match' i a = Match data MatchBranch' i a = MatchBranch { _matchBranchInfo :: i, _matchBranchPatterns :: !(NonEmpty (Pattern' i a)), - _matchBranchRhs :: !(NonEmpty (SideIfBranch' i a)) + _matchBranchRhs :: !(MatchBranchRhs' i a) } data Pattern' i a @@ -202,8 +202,12 @@ data PatternConstr' i a = PatternConstr _patternConstrArgs :: ![Pattern' i a] } +data MatchBranchRhs' i a + = MatchBranchRhsExpression !a + | MatchBranchRhsIfs !(NonEmpty (SideIfBranch' i a)) + data SideIfBranch' i a = SideIfBranch - { _sizeIfBranchInfo :: i, + { _sideIfBranchInfo :: i, _sideIfBranchCondition :: !a, _sideIfBranchBody :: !a } @@ -443,6 +447,7 @@ makeLenses ''Case' makeLenses ''CaseBranch' makeLenses ''Match' makeLenses ''MatchBranch' +makeLenses ''MatchBranchRhs' makeLenses ''PatternWildcard' makeLenses ''PatternConstr' makeLenses ''SideIfBranch' @@ -535,6 +540,11 @@ instance (Eq a) => Eq (Pi' i a) where eqOn (^. piBinder . binderType) ..&&.. eqOn (^. piBody) +instance (Eq a) => Eq (MatchBranchRhs' i a) where + (MatchBranchRhsExpression e1) == (MatchBranchRhsExpression e2) = e1 == e2 + (MatchBranchRhsIfs ifs1) == (MatchBranchRhsIfs ifs2) = ifs1 == ifs2 + _ == _ = False + instance (Eq a) => Eq (MatchBranch' i a) where (MatchBranch _ pats1 b1) == (MatchBranch _ pats2 b2) = pats1 == pats2 && b1 == b2 From 1c099e0af23095364030e95f456b24ede88efab8 Mon Sep 17 00:00:00 2001 From: Lukasz Czajka Date: Mon, 2 Sep 2024 14:28:47 +0200 Subject: [PATCH 3/9] fromInternal --- src/Juvix/Compiler/Core/Extra/Base.hs | 8 +++ .../Compiler/Core/Translation/FromInternal.hs | 50 +++++++++++++------ 2 files changed, 44 insertions(+), 14 deletions(-) diff --git a/src/Juvix/Compiler/Core/Extra/Base.hs b/src/Juvix/Compiler/Core/Extra/Base.hs index 62fba58f05..f4ccedd004 100644 --- a/src/Juvix/Compiler/Core/Extra/Base.hs +++ b/src/Juvix/Compiler/Core/Extra/Base.hs @@ -116,6 +116,14 @@ mkIf i sym v b1 b2 = mkCase i sym v [br] (Just b2) mkIf' :: Symbol -> Node -> Node -> Node -> Node mkIf' = mkIf Info.empty +mkIfs :: Symbol -> [(Info, Node, Node)] -> Node -> Node +mkIfs sym = \case + [] -> id + ((i, v, b) : rest) -> mkIf i sym v b . mkIfs sym rest + +mkIfs' :: Symbol -> [(Node, Node)] -> Node -> Node +mkIfs' sym = mkIfs sym . map (\(v, b) -> (Info.empty, v, b)) + mkBinder :: Text -> Type -> Binder mkBinder name ty = Binder name Nothing ty diff --git a/src/Juvix/Compiler/Core/Translation/FromInternal.hs b/src/Juvix/Compiler/Core/Translation/FromInternal.hs index ffaa4be376..b3578f015c 100644 --- a/src/Juvix/Compiler/Core/Translation/FromInternal.hs +++ b/src/Juvix/Compiler/Core/Translation/FromInternal.hs @@ -492,6 +492,8 @@ goCase c = do rty <- goType (fromJust $ c ^. Internal.caseExpressionWholeType) return (mkMatch i (pure ty) rty (pure expr) branches) _ -> + -- If the type of the value matched on is not an inductive type, then the + -- case expression has one branch with a variable pattern. case c ^. Internal.caseBranches of Internal.CaseBranch {..} :| _ -> case _caseBranchPattern ^. Internal.patternArgPattern of @@ -499,34 +501,54 @@ goCase c = do vars <- asks (^. indexTableVars) varsNum <- asks (^. indexTableVarsNum) let vars' = addPatternVariableNames _caseBranchPattern varsNum vars - body <- + rhs <- local (set indexTableVars vars') (underBinders 1 (goCaseBranchRhs _caseBranchRhs)) - return $ mkLet i (Binder (name ^. nameText) (Just $ name ^. nameLoc) ty) expr body + case rhs of + MatchBranchRhsExpression body -> + return $ mkLet i (Binder (name ^. nameText) (Just $ name ^. nameLoc) ty) expr body + _ -> + impossible _ -> impossible where goCaseBranch :: Type -> Internal.CaseBranch -> Sem r MatchBranch goCaseBranch ty b = goPatternArgs 0 (b ^. Internal.caseBranchRhs) [b ^. Internal.caseBranchPattern] [ty] --- | FIXME Fix this as soon as side if conditions are implemented in Core. This --- is needed so that we can test typechecking without a crash. -todoSideIfs :: - forall r. - (Members '[InfoTableBuilder, Reader InternalTyped.TypesTable, Reader InternalTyped.FunctionsTable, Reader Internal.InfoTable, Reader IndexTable, NameIdGen, Error BadScope] r) => - Internal.SideIfs -> - Sem r Node -todoSideIfs s = goExpression (s ^. Internal.sideIfBranches . _head1 . Internal.sideIfBranchBody) - goCaseBranchRhs :: forall r. (Members '[InfoTableBuilder, Reader InternalTyped.TypesTable, Reader InternalTyped.FunctionsTable, Reader Internal.InfoTable, Reader IndexTable, NameIdGen, Error BadScope] r) => Internal.CaseBranchRhs -> - Sem r Node + Sem r MatchBranchRhs goCaseBranchRhs = \case - Internal.CaseBranchRhsExpression e -> goExpression e - Internal.CaseBranchRhsIf i -> todoSideIfs i + Internal.CaseBranchRhsExpression e -> MatchBranchRhsExpression <$> goExpression e + Internal.CaseBranchRhsIf Internal.SideIfs {..} -> case _sideIfElse of + Just elseBranch -> do + branches <- toList <$> mapM goSideIfBranch _sideIfBranches + elseBranch' <- goExpression elseBranch + boolSym <- getBoolSymbol + return $ MatchBranchRhsExpression $ mkIfs' boolSym branches elseBranch' + where + goSideIfBranch :: Internal.SideIfBranch -> Sem r (Node, Node) + goSideIfBranch Internal.SideIfBranch {..} = do + cond <- goExpression _sideIfBranchCondition + body <- goExpression _sideIfBranchBody + return (cond, body) + Nothing -> do + branches <- mapM goSideIfBranch _sideIfBranches + return $ MatchBranchRhsIfs branches + where + goSideIfBranch :: Internal.SideIfBranch -> Sem r SideIfBranch + goSideIfBranch Internal.SideIfBranch {..} = do + cond <- goExpression _sideIfBranchCondition + body <- goExpression _sideIfBranchBody + return $ + SideIfBranch + { _sideIfBranchInfo = setInfoLocation (getLoc _sideIfBranchCondition) mempty, + _sideIfBranchCondition = cond, + _sideIfBranchBody = body + } goLambda :: forall r. From b396b293164ccc6955e2ca1d0b6063b2fe75f276 Mon Sep 17 00:00:00 2001 From: Lukasz Czajka Date: Mon, 2 Sep 2024 15:18:28 +0200 Subject: [PATCH 4/9] Core parsing --- .../Compiler/Core/Translation/FromSource.hs | 73 ++++++++++++++++++- 1 file changed, 72 insertions(+), 1 deletion(-) diff --git a/src/Juvix/Compiler/Core/Translation/FromSource.hs b/src/Juvix/Compiler/Core/Translation/FromSource.hs index acc1f98a21..3be6c2e31a 100644 --- a/src/Juvix/Compiler/Core/Translation/FromSource.hs +++ b/src/Juvix/Compiler/Core/Translation/FromSource.hs @@ -1002,6 +1002,67 @@ matchBranch :: matchBranch patsNum varsNum vars = do off <- P.getOffset pats <- branchPatterns varsNum vars + rhs <- branchRhs off pats patsNum varsNum vars + return $ MatchBranch Info.empty (fromList pats) rhs + +branchRhs :: + (Member InfoTableBuilder r) => + Int -> + [Pattern] -> + Int -> + Index -> + HashMap Text Level -> + ParsecS r MatchBranchRhs +branchRhs off pats patsNum varsNum vars = + branchRhsExpr off pats patsNum varsNum vars + <|> branchRhsIf off pats patsNum varsNum vars + +branchRhsExpr :: + (Member InfoTableBuilder r) => + Int -> + [Pattern] -> + Int -> + Index -> + HashMap Text Level -> + ParsecS r MatchBranchRhs +branchRhsExpr off pats patsNum varsNum vars = do + kw kwAssign + unless (length pats == patsNum) $ + parseFailure off "wrong number of patterns" + let pis :: [Binder] + pis = concatMap getPatternBinders pats + (vars', varsNum') = + foldl' + ( \(vs, k) name -> + (HashMap.insert name k vs, k + 1) + ) + (vars, varsNum) + (map (^. binderName) pis) + br <- bracedExpr varsNum' vars' + return $ MatchBranchRhsExpression br + +branchRhsIf :: + (Member InfoTableBuilder r) => + Int -> + [Pattern] -> + Int -> + Index -> + HashMap Text Level -> + ParsecS r MatchBranchRhs +branchRhsIf off pats patsNum varsNum vars = do + ifs <- sideIfs off pats patsNum varsNum vars + return $ MatchBranchRhsIfs ifs + +sideIfs :: + (Member InfoTableBuilder r) => + Int -> + [Pattern] -> + Int -> + Index -> + HashMap Text Level -> + ParsecS r (NonEmpty SideIfBranch) +sideIfs off pats patsNum varsNum vars = do + cond <- branchCond varsNum vars kw kwAssign unless (length pats == patsNum) $ parseFailure off "wrong number of patterns" @@ -1015,7 +1076,17 @@ matchBranch patsNum varsNum vars = do (vars, varsNum) (map (^. binderName) pis) br <- bracedExpr varsNum' vars' - return $ MatchBranch Info.empty (fromList pats) br + conds <- optional (sideIfs off pats patsNum varsNum vars) + return $ SideIfBranch Info.empty cond br :| maybe [] toList conds + +branchCond :: + (Member InfoTableBuilder r) => + Index -> + HashMap Text Level -> + ParsecS r Node +branchCond varsNum vars = do + kw kwIf + expr varsNum vars branchPatterns :: (Member InfoTableBuilder r) => From 10fd8f50f28f3f46948aaadd5e28d7a989420d18 Mon Sep 17 00:00:00 2001 From: Lukasz Czajka Date: Mon, 2 Sep 2024 15:40:19 +0200 Subject: [PATCH 5/9] Core printing --- src/Juvix/Compiler/Core/Pretty/Base.hs | 21 +++++++++++++++++-- .../Core/Transformation/MatchToCase.hs | 2 +- 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/src/Juvix/Compiler/Core/Pretty/Base.hs b/src/Juvix/Compiler/Core/Pretty/Base.hs index 433567ce9d..8cb74a3aa4 100644 --- a/src/Juvix/Compiler/Core/Pretty/Base.hs +++ b/src/Juvix/Compiler/Core/Pretty/Base.hs @@ -369,6 +369,23 @@ instance PrettyCode Bottom where ty' <- ppCode _bottomType return (parens (kwBottom <+> kwColon <+> ty')) +instance PrettyCode SideIfBranch where + ppCode :: (Member (Reader Options) r) => SideIfBranch -> Sem r (Doc Ann) + ppCode SideIfBranch {..} = do + cond <- ppCode _sideIfBranchCondition + body <- ppCode _sideIfBranchBody + return $ kwIf <+> cond <+> kwAssign <+> body + +instance PrettyCode MatchBranchRhs where + ppCode :: (Member (Reader Options) r) => MatchBranchRhs -> Sem r (Doc Ann) + ppCode = \case + MatchBranchRhsExpression x -> do + e <- ppCode x + return $ kwAssign <+> e + MatchBranchRhsIfs x -> do + brs <- mapM ppCode x + return $ vsep brs + instance PrettyCode Node where ppCode :: forall r. (Member (Reader Options) r) => Node -> Sem r (Doc Ann) ppCode node = case node of @@ -394,11 +411,11 @@ instance PrettyCode Node where ppCodeCase' branchBinderNames branchBinderTypes branchTagNames x NMatch Match {..} -> do let branchPatterns = map (^. matchBranchPatterns) _matchBranches - branchBodies = map (^. matchBranchBody) _matchBranches + branchRhs = map (^. matchBranchRhs) _matchBranches pats <- mapM ppPatterns branchPatterns vs <- mapM ppCode _matchValues vs' <- zipWithM ppWithType (toList vs) (toList _matchValueTypes) - bs <- sequence $ zipWithExact (\ps br -> ppCode br >>= \br' -> return $ ps <+> kwAssign <+> br') pats branchBodies + bs <- sequence $ zipWithExact (\ps br -> ppCode br >>= \br' -> return $ ps <+> br') pats branchRhs let bss = bracesIndent $ align $ concatWith (\a b -> a <> kwSemicolon <> line <> b) bs rty <- ppTypeAnnot _matchReturnType return $ kwMatch <+> hsep (punctuate comma vs') <+> kwWith <> rty <+> bss diff --git a/src/Juvix/Compiler/Core/Transformation/MatchToCase.hs b/src/Juvix/Compiler/Core/Transformation/MatchToCase.hs index 9ac354f6b4..b498b74c90 100644 --- a/src/Juvix/Compiler/Core/Transformation/MatchToCase.hs +++ b/src/Juvix/Compiler/Core/Transformation/MatchToCase.hs @@ -58,7 +58,7 @@ goMatchToCase recur node = case node of matchBranchToPatternRow MatchBranch {..} = PatternRow { _patternRowPatterns = toList _matchBranchPatterns, - _patternRowBody = _matchBranchBody, + _patternRowBody = undefined, _patternRowIgnoredPatternsNum = 0, _patternRowBinderChangesRev = [BCAdd n] } From becd3bc6954681266d8ef81f2907332708a70363 Mon Sep 17 00:00:00 2001 From: Lukasz Czajka Date: Mon, 2 Sep 2024 16:05:36 +0200 Subject: [PATCH 6/9] Core evaluator --- src/Juvix/Compiler/Core/Evaluator.hs | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/src/Juvix/Compiler/Core/Evaluator.hs b/src/Juvix/Compiler/Core/Evaluator.hs index 21bc538477..cb79eb9b15 100644 --- a/src/Juvix/Compiler/Core/Evaluator.hs +++ b/src/Juvix/Compiler/Core/Evaluator.hs @@ -149,7 +149,7 @@ geval opts herr tab env0 = eval' env0 match n env vs = \case br : brs -> case matchPatterns [] vs (toList (br ^. matchBranchPatterns)) of - Just args -> eval' (args ++ env) (br ^. matchBranchBody) + Just args -> matchRhs (args ++ env) (br ^. matchBranchRhs) Nothing -> match n env vs brs where matchPatterns :: [Node] -> [Node] -> [Pattern] -> Maybe [Node] @@ -169,6 +169,18 @@ geval opts herr tab env0 = eval' env0 | tag == _patternConstrTag = matchPatterns (v : acc) args _patternConstrArgs patmatch _ _ _ = Nothing + + matchRhs :: [Node] -> MatchBranchRhs -> Node + matchRhs env' = \case + MatchBranchRhsExpression e -> eval' env' e + MatchBranchRhsIfs ifs -> matchIfs env' (toList ifs) + + matchIfs :: [Node] -> [SideIfBranch] -> Node + matchIfs env' = \case + SideIfBranch {..} : ifs -> case eval' env' _sideIfBranchCondition of + NCtr (Constr _ (BuiltinTag TagTrue) []) -> eval' env' _sideIfBranchBody + _ -> matchIfs env' ifs + [] -> match n env vs brs [] -> evalError "no matching pattern" (substEnv env n) From 2fd362b0608affc96e12093b22eb979de0e883a7 Mon Sep 17 00:00:00 2001 From: Lukasz Czajka Date: Wed, 4 Sep 2024 15:15:46 +0200 Subject: [PATCH 7/9] pattern matching compilation --- .../Core/Extra/Recursors/RMap/Named.hs | 2 +- .../Core/Transformation/MatchToCase.hs | 24 +++++++++++++------ 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/src/Juvix/Compiler/Core/Extra/Recursors/RMap/Named.hs b/src/Juvix/Compiler/Core/Extra/Recursors/RMap/Named.hs index e7863152cc..b4012a3673 100644 --- a/src/Juvix/Compiler/Core/Extra/Recursors/RMap/Named.hs +++ b/src/Juvix/Compiler/Core/Extra/Recursors/RMap/Named.hs @@ -53,7 +53,7 @@ where _ -> recur [] node where - cont :: Level -> [BinderChange] -> Node -> Node + cont :: [BinderChange] -> Node -> Node cont bcs = go (recur . (bcs ++)) (k + bindersNumFromBinderChange bcs) ``` produces diff --git a/src/Juvix/Compiler/Core/Transformation/MatchToCase.hs b/src/Juvix/Compiler/Core/Transformation/MatchToCase.hs index b498b74c90..e9c62be4f5 100644 --- a/src/Juvix/Compiler/Core/Transformation/MatchToCase.hs +++ b/src/Juvix/Compiler/Core/Transformation/MatchToCase.hs @@ -12,7 +12,7 @@ import Juvix.Compiler.Core.Transformation.Base data PatternRow = PatternRow { _patternRowPatterns :: [Pattern], - _patternRowBody :: Node, + _patternRowRhs :: MatchBranchRhs, -- | The number of initial wildcard binders in `_patternRowPatterns` which -- don't originate from the input _patternRowIgnoredPatternsNum :: Int, @@ -58,7 +58,7 @@ goMatchToCase recur node = case node of matchBranchToPatternRow MatchBranch {..} = PatternRow { _patternRowPatterns = toList _matchBranchPatterns, - _patternRowBody = undefined, + _patternRowRhs = _matchBranchRhs, _patternRowIgnoredPatternsNum = 0, _patternRowBinderChangesRev = [BCAdd n] } @@ -104,10 +104,10 @@ goMatchToCase recur node = case node of pat' = if length pat == 1 then doc defaultOptions (head' pat) else docValueSequence pat mockFile = $(mkAbsFile "/match-to-case") defaultLoc = singletonInterval (mkInitialLoc mockFile) - r@PatternRow {..} : _ + r@PatternRow {..} : matrix' | all isPatWildcard _patternRowPatterns -> -- The first row matches all values (Section 4, case 2) - compileMatchingRow bindersNum vs r + compileMatchingRow err bindersNum vs matrix' r _ -> do -- Section 4, case 3 -- Select the first column @@ -181,9 +181,19 @@ goMatchToCase recur node = case node of where ii = lookupInductiveInfo md ind - compileMatchingRow :: Level -> [Level] -> PatternRow -> Sem r Node - compileMatchingRow bindersNum vs PatternRow {..} = - goMatchToCase (recur . (bcs ++)) _patternRowBody + compileMatchingRow :: ([Value] -> [Value]) -> Level -> [Level] -> PatternMatrix -> PatternRow -> Sem r Node + compileMatchingRow err bindersNum vs matrix PatternRow {..} = + case _patternRowRhs of + MatchBranchRhsExpression body -> + goMatchToCase (recur . (bcs ++)) body + MatchBranchRhsIfs ifs -> do + -- If the branch has side-conditions, then we need to continue pattern + -- matching when none of the conditions is satisfied. + body <- compile err bindersNum vs matrix + md <- ask + let boolSym = lookupConstructorInfo md (BuiltinTag TagTrue) ^. constructorInductive + ifs' = map (\(SideIfBranch i c b) -> (i, c, b)) (toList ifs) + return $ mkIfs boolSym ifs' body where bcs = reverse $ From 6a81b8a43aad2431b1c3ed167cad1148a8e4570e Mon Sep 17 00:00:00 2001 From: Lukasz Czajka Date: Wed, 4 Sep 2024 16:30:09 +0200 Subject: [PATCH 8/9] fix Format.juvix test --- src/Juvix/Compiler/Internal/Language.hs | 2 +- tests/positive/Format.juvix | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/Juvix/Compiler/Internal/Language.hs b/src/Juvix/Compiler/Internal/Language.hs index 662c691513..ce43ab97c7 100644 --- a/src/Juvix/Compiler/Internal/Language.hs +++ b/src/Juvix/Compiler/Internal/Language.hs @@ -669,7 +669,7 @@ instance HasLoc CaseBranch where getLoc c = getLoc (c ^. caseBranchPattern) <> getLoc (c ^. caseBranchRhs) instance HasLoc Case where - getLoc c = getLocSpan (c ^. caseBranches) + getLoc c = getLoc (c ^. caseExpression) <> getLocSpan (c ^. caseBranches) instance HasLoc Expression where getLoc = \case diff --git a/tests/positive/Format.juvix b/tests/positive/Format.juvix index 6e08aa7b9f..65021767c3 100644 --- a/tests/positive/Format.juvix +++ b/tests/positive/Format.juvix @@ -465,7 +465,8 @@ module SideIfConditions; | suc (suc n) | if 0 < 0 := 3 | else := n - | suc n if 0 < 0 := 3; + | suc n if 0 < 0 := 3 + | suc zero := 5; end; module MultiIf; From 9b6f05a30e70f5fb73f7d509213cab4b923b6b5e Mon Sep 17 00:00:00 2001 From: Lukasz Czajka Date: Thu, 5 Sep 2024 14:21:36 +0200 Subject: [PATCH 9/9] add tests --- src/Juvix/Compiler/Internal/Language.hs | 2 +- test/Compilation/Negative.hs | 6 +++++- tests/Compilation/negative/test007.juvix | 10 ++++++++++ tests/Compilation/positive/out/test035.out | 5 +++++ tests/Compilation/positive/test035.juvix | 15 ++++++++++++++- 5 files changed, 35 insertions(+), 3 deletions(-) create mode 100644 tests/Compilation/negative/test007.juvix diff --git a/src/Juvix/Compiler/Internal/Language.hs b/src/Juvix/Compiler/Internal/Language.hs index ce43ab97c7..662c691513 100644 --- a/src/Juvix/Compiler/Internal/Language.hs +++ b/src/Juvix/Compiler/Internal/Language.hs @@ -669,7 +669,7 @@ instance HasLoc CaseBranch where getLoc c = getLoc (c ^. caseBranchPattern) <> getLoc (c ^. caseBranchRhs) instance HasLoc Case where - getLoc c = getLoc (c ^. caseExpression) <> getLocSpan (c ^. caseBranches) + getLoc c = getLocSpan (c ^. caseBranches) instance HasLoc Expression where getLoc = \case diff --git a/test/Compilation/Negative.hs b/test/Compilation/Negative.hs index c6fbf1f57a..afbc4b2c63 100644 --- a/test/Compilation/Negative.hs +++ b/test/Compilation/Negative.hs @@ -53,5 +53,9 @@ tests = NegTest "Test006: Ill scoped term (This is a bug. It should be positive)" $(mkRelDir ".") - $(mkRelFile "test006.juvix") + $(mkRelFile "test006.juvix"), + NegTest + "Test007: Pattern matching coverage with side conditions" + $(mkRelDir ".") + $(mkRelFile "test007.juvix") ] diff --git a/tests/Compilation/negative/test007.juvix b/tests/Compilation/negative/test007.juvix new file mode 100644 index 0000000000..7235242512 --- /dev/null +++ b/tests/Compilation/negative/test007.juvix @@ -0,0 +1,10 @@ +module test007; + +import Stdlib.Prelude open; + +f (x : List Nat) : Nat := + case x of + | nil := 0 + | x :: _ if true := x; + +main : Nat := f (1 :: 2 :: nil); diff --git a/tests/Compilation/positive/out/test035.out b/tests/Compilation/positive/out/test035.out index 80257c93d2..a420b9e14c 100644 --- a/tests/Compilation/positive/out/test035.out +++ b/tests/Compilation/positive/out/test035.out @@ -7,3 +7,8 @@ 13536 1 0 +1 +0 +4 +9 +0 diff --git a/tests/Compilation/positive/test035.juvix b/tests/Compilation/positive/test035.juvix index 463330ce92..30bd483f0d 100644 --- a/tests/Compilation/positive/test035.juvix +++ b/tests/Compilation/positive/test035.juvix @@ -40,6 +40,14 @@ h : Nat -> Nat | (suc (suc (suc (suc n)))) := n | _ := 0; +hh (x : Nat) : Nat := + case x of + | zero := 1 + | (suc n) if h n == 0 := n + | (suc zero) := 17 + | (suc (suc (suc (suc (suc (suc zero)))))) := 9 + | _ := 0; + printListNatLn : List Nat → IO | nil := printStringLn "nil" | (x :: xs) := @@ -54,4 +62,9 @@ main : IO := >>> printNatLn (f (gen 18)) >>> printNatLn (f (gen 20)) >>> printNatLn (h 5) - >>> printNatLn (h 3); + >>> printNatLn (h 3) + >>> printNatLn (hh 0) + >>> printNatLn (hh 1) + >>> printNatLn (hh 5) + >>> printNatLn (hh 6) + >>> printNatLn (hh 7);