diff --git a/src/Juvix/Compiler/Core/Info/FreeVarsInfo.hs b/src/Juvix/Compiler/Core/Info/FreeVarsInfo.hs index 2a665ab4e5..b9be178967 100644 --- a/src/Juvix/Compiler/Core/Info/FreeVarsInfo.hs +++ b/src/Juvix/Compiler/Core/Info/FreeVarsInfo.hs @@ -19,7 +19,10 @@ makeLenses ''FreeVarsInfo -- | Computes free variable info for each subnode. Assumption: no subnode is a -- closure. computeFreeVarsInfo :: Node -> Node -computeFreeVarsInfo = umap go +computeFreeVarsInfo = computeFreeVarsInfo' 1 + +computeFreeVarsInfo' :: Int -> Node -> Node +computeFreeVarsInfo' lambdaMultiplier = umap go where go :: Node -> Node go node = case node of @@ -27,6 +30,15 @@ computeFreeVarsInfo = umap go mkVar (Info.insert fvi _varInfo) _varIndex where fvi = FreeVarsInfo (Map.singleton _varIndex 1) + NLam Lambda {..} -> + modifyInfo (Info.insert fvi) node + where + fvi = + FreeVarsInfo + . fmap (* lambdaMultiplier) + . Map.mapKeysMonotonic (\idx -> idx - 1) + . Map.filterWithKey (\idx _ -> idx >= 1) + $ getFreeVarsInfo _lambdaBody ^. infoFreeVars _ -> modifyInfo (Info.insert fvi) node where @@ -34,11 +46,11 @@ computeFreeVarsInfo = umap go FreeVarsInfo $ foldr ( \NodeChild {..} acc -> - Map.unionWith (+) acc $ - Map.mapKeysMonotonic (\idx -> idx - _childBindersNum) $ - Map.filterWithKey - (\idx _ -> idx >= _childBindersNum) - (getFreeVarsInfo _childNode ^. infoFreeVars) + Map.unionWith (+) acc + . Map.mapKeysMonotonic (\idx -> idx - _childBindersNum) + . Map.filterWithKey + (\idx _ -> idx >= _childBindersNum) + $ getFreeVarsInfo _childNode ^. infoFreeVars ) mempty (children node) diff --git a/src/Juvix/Compiler/Core/Info/ShallowFreeVarsInfo.hs b/src/Juvix/Compiler/Core/Info/ShallowFreeVarsInfo.hs deleted file mode 100644 index 31cc9d4f62..0000000000 --- a/src/Juvix/Compiler/Core/Info/ShallowFreeVarsInfo.hs +++ /dev/null @@ -1,50 +0,0 @@ -module Juvix.Compiler.Core.Info.ShallowFreeVarsInfo where - -import Data.Map qualified as Map -import Juvix.Compiler.Core.Extra -import Juvix.Compiler.Core.Info qualified as Info - -newtype ShallowFreeVarsInfo = ShallowFreeVarsInfo - { -- map free variables to the number of their shallow occurrences (not under binders) - _infoShallowFreeVars :: Map Index Int - } - -instance IsInfo ShallowFreeVarsInfo - -kShallowFreeVarsInfo :: Key ShallowFreeVarsInfo -kShallowFreeVarsInfo = Proxy - -makeLenses ''ShallowFreeVarsInfo - --- | Computes shallow free variable info for each subnode. Assumption: no --- subnode is a closure. -computeShallowFreeVarsInfo :: Node -> Node -computeShallowFreeVarsInfo = umap go - where - go :: Node -> Node - go node = case node of - NVar Var {..} -> - mkVar (Info.insert fvi _varInfo) _varIndex - where - fvi = ShallowFreeVarsInfo (Map.singleton _varIndex 1) - _ -> - modifyInfo (Info.insert fvi) node - where - fvi = - ShallowFreeVarsInfo $ - foldr - ( \NodeChild {..} acc -> - if - | _childBindersNum == 0 -> - Map.unionWith (+) acc (getShallowFreeVarsInfo _childNode ^. infoShallowFreeVars) - | otherwise -> - acc - ) - mempty - (children node) - -getShallowFreeVarsInfo :: Node -> ShallowFreeVarsInfo -getShallowFreeVarsInfo = fromJust . Info.lookup kShallowFreeVarsInfo . getInfo - -shallowFreeVarOccurrences :: Index -> Node -> Int -shallowFreeVarOccurrences idx n = fromMaybe 0 (Map.lookup idx (getShallowFreeVarsInfo n ^. infoShallowFreeVars)) diff --git a/src/Juvix/Compiler/Core/Transformation/Optimize/LetFolding.hs b/src/Juvix/Compiler/Core/Transformation/Optimize/LetFolding.hs index 87cf884eb8..d95891177c 100644 --- a/src/Juvix/Compiler/Core/Transformation/Optimize/LetFolding.hs +++ b/src/Juvix/Compiler/Core/Transformation/Optimize/LetFolding.hs @@ -40,7 +40,7 @@ letFolding' isFoldable tab = mapAllNodes ( removeInfo kFreeVarsInfo . convertNode isFoldable tab - . computeFreeVarsInfo + . computeFreeVarsInfo' 2 ) tab diff --git a/tests/Compilation/positive/test059.juvix b/tests/Compilation/positive/test059.juvix index 03a5be832f..91ea8dcbce 100644 --- a/tests/Compilation/positive/test059.juvix +++ b/tests/Compilation/positive/test059.juvix @@ -1,15 +1,15 @@ -- builtin list module test059; -import Stdlib.Prelude open hiding {head}; +import Stdlib.Prelude open; mylist : List Nat := [1; 2; 3 + 1]; mylist2 : List (List Nat) := [[10]; [2]; 3 + 1 :: nil]; -head : {a : Type} -> a -> List a -> a +head' : {a : Type} -> a -> List a -> a | a [] := a | a [x; _] := x | _ (h :: _) := h; -main : Nat := head 50 mylist + head 50 (head [] mylist2); +main : Nat := head' 50 mylist + head' 50 (head' [] mylist2);