Skip to content

Commit

Permalink
generic types for termination
Browse files Browse the repository at this point in the history
  • Loading branch information
janmasrovira committed Sep 4, 2024
1 parent ff095a2 commit f47442b
Show file tree
Hide file tree
Showing 17 changed files with 292 additions and 321 deletions.
2 changes: 1 addition & 1 deletion app/Commands/Dev/Termination/CallGraph.hs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ runCommand CallGraphOptions {..} = do
<> _pipelineResult
^. Internal.resultInternalModule
. Internal.internalModuleInfoTable
callMap = Termination.buildCallMap mainModule
callMap = fst (Termination.buildCallMap mainModule)
completeGraph = Termination.completeCallGraph callMap
filteredGraph =
maybe
Expand Down
2 changes: 1 addition & 1 deletion app/Commands/Dev/Termination/Calls.hs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ runCommand :: (Members AppEffects r) => CallsOptions -> Sem r ()
runCommand localOpts@CallsOptions {..} = do
globalOpts <- askGlobalOptions
PipelineResult {..} <- runPipelineTermination _callsInputFile upToInternal
let callMap0 = Termination.buildCallMap (_pipelineResult ^. Internal.resultModule)
let callMap0 = fst (Termination.buildCallMap (_pipelineResult ^. Internal.resultModule))
callMap = case _callsFunctionNameFilter of
Nothing -> callMap0
Just f -> Termination.filterCallMap f callMap0
Expand Down
4 changes: 3 additions & 1 deletion src/Juvix/Compiler/Internal/Data/Cast.hs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ module Juvix.Compiler.Internal.Data.Cast where
import Juvix.Compiler.Internal.Language
import Juvix.Prelude

data CastType = CastInt | CastNat
data CastType
= CastInt
| CastNat

data CastHole = CastHole
{ _castHoleHole :: Hole,
Expand Down
13 changes: 6 additions & 7 deletions src/Juvix/Compiler/Internal/Extra/InstanceInfo.hs
Original file line number Diff line number Diff line change
Expand Up @@ -99,19 +99,18 @@ traitFromExpression metaVars e = case paramFromExpression metaVars e of
Just (InstanceParamApp app) -> Just app
_ -> Nothing

instanceFromTypedExpression :: TypedExpression -> Maybe InstanceInfo
instanceFromTypedExpression TypedExpression {..} = do
InstanceApp {..} <- traitFromExpression metaVars e
mkInstanceInfo :: Iden -> Expression -> Maybe InstanceInfo
mkInstanceInfo funName funTy = do
let (args, ret) = unfoldFunType funTy
metaVars = hashSet (mapMaybe (^. paramName) args)
InstanceApp {..} <- traitFromExpression metaVars ret
return $
InstanceInfo
{ _instanceInfoInductive = _instanceAppHead,
_instanceInfoParams = _instanceAppArgs,
_instanceInfoResult = _typedExpression,
_instanceInfoIden = funName,
_instanceInfoArgs = args
}
where
(args, e) = unfoldFunType _typedType
metaVars = HashSet.fromList $ mapMaybe (^. paramName) args

checkNoMeta :: InstanceParam -> Bool
checkNoMeta = \case
Expand Down
4 changes: 3 additions & 1 deletion src/Juvix/Compiler/Internal/Language.hs
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,8 @@ data NormalizedExpression = NormalizedExpression
}

makePrisms ''Expression
makePrisms ''MutualStatement

makeLenses ''SideIfBranch
makeLenses ''SideIfs
makeLenses ''CaseBranchRhs
Expand Down Expand Up @@ -582,7 +584,7 @@ instance HasAtomicity Pattern where
PatternWildcardConstructor {} -> Atom

instance HasLoc Module where
getLoc m = getLoc (m ^. moduleName) <>? maybe Nothing (Just . getLocSpan) (nonEmpty (m ^. moduleBody . moduleStatements))
getLoc m = getLoc (m ^. moduleName) <>? fmap getLocSpan (nonEmpty (m ^. moduleBody . moduleStatements))

instance HasLoc MutualBlock where
getLoc = getLocSpan . (^. mutualStatements)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,96 +3,53 @@ module Juvix.Compiler.Internal.Translation.FromInternal.Analysis.FunctionCall
)
where

import Data.HashMap.Strict qualified as HashMap
import Juvix.Compiler.Internal.Extra
import Juvix.Compiler.Internal.Translation.FromInternal.Analysis.Termination.Data
import Juvix.Prelude

-- type FunCall = FunCall' Expression
-- type CallMap = CallMap' Expression
-- type FunCallArg = FunCallArg' Expression
viewCallNew :: Expression -> Maybe (FunctionName, [Expression])
viewCallNew = fmap swap . run . runFail . runOutputList . go
where
go :: (Members '[Fail, Output Expression] r) => Expression -> Sem r FunctionName
go = \case
ExpressionIden (IdenFunction fun) -> return fun
ExpressionApplication (Application f arg impl)
| isImplicitOrInstance impl -> go f -- implicit arguments are ignored
| otherwise -> do
fun <- go f
output arg
return fun
_ -> fail

viewCall ::
forall r.
(Members '[Reader SizeInfo] r) =>
Expression ->
Sem r (Maybe (FunCall' Expression))
viewCall = \case
ExpressionIden (IdenFunction x) ->
return (Just (singletonCall x))
ExpressionApplication (Application f x impl)
| isImplicitOrInstance impl -> viewCall f -- implicit arguments are ignored
| otherwise -> do
c <- viewCall f
x' <- callArg
return $ over callArgs (`snoc` x') <$> c
where
callArg :: Sem r (FunCallArg' Expression)
callArg = do
lt <- (^. callRow) <$> lessThan
eq <- (^. callRow) <$> equalTo
let cr = CallRow (lt `mplus` eq)
return
FunCallArg
{ _argRow = cr,
_argExpression = x
}
where
lessThan :: Sem r CallRow
lessThan = case viewExpressionAsPattern x of
Nothing -> return (CallRow Nothing)
Just x' -> do
s <- asks (findIndex (elem x') . (^. sizeSmaller))
return $ case s of
Nothing -> CallRow Nothing
Just s' -> CallRow (Just (s', RLe))
equalTo :: Sem r CallRow
equalTo =
case viewExpressionAsPattern x of
Just x' -> do
s <- asks (elemIndex x' . (^. sizeEqual))
return $ case s of
Nothing -> CallRow Nothing
Just s' -> CallRow (Just (s', REq))
Nothing -> return (CallRow Nothing)
_ -> return Nothing
where
singletonCall :: FunctionName -> FunCall' expr
singletonCall r = FunCall r []

addCall :: forall expr. FunctionName -> FunCall' expr -> CallMap' expr -> CallMap' expr
addCall fun c = over callMap (HashMap.alter (Just . insertCall c) fun)
where
insertCall ::
FunCall' expr ->
Maybe (HashMap FunctionName [FunCall' expr]) ->
HashMap FunctionName [FunCall' expr]
insertCall f = \case
Nothing -> singl f
Just m' -> addFunCall f m'

singl :: FunCall' expr -> HashMap FunctionName [FunCall' expr]
singl f = HashMap.singleton (f ^. callRef) [f]

addFunCall ::
FunCall' expr ->
HashMap FunctionName [FunCall' expr] ->
HashMap FunctionName [FunCall' expr]
addFunCall fc = HashMap.insertWith (flip (<>)) (fc ^. callRef) [fc]
viewCall e = do
si :: SizeInfo <- ask
let rel :: Pattern -> Expression -> Maybe SizeRel'
rel pat expr = do
pexpr <- viewExpressionAsPattern expr
guard (pexpr `elem` pat ^.. patternSubCosmos) $> RLe
<|> guard (pexpr == pat) $> REq
return $ do
(fun, args) <- viewCallNew e
return (mkFunCall rel fun (si ^. sizeEqual) args)

registerFunctionDef ::
forall expr r.
(Members '[State (CallMap' expr)] r) =>
(Members '[State (HashMap FunctionName FunctionDef)] r) =>
Proxy expr ->
FunctionDef ->
Sem r ()
registerFunctionDef Proxy f = modify' @(CallMap' expr) (set (callMapScanned . at (f ^. funDefName)) (Just f))
registerFunctionDef Proxy f = modify' @(HashMap FunctionName FunctionDef) (set (at (f ^. funDefName)) (Just f))

registerCall ::
forall expr r.
(Members '[State (CallMap' expr), Reader FunctionName] r) =>
(Members '[CallMapBuilder' expr, Reader FunctionName] r) =>
FunCall' expr ->
Sem r ()
registerCall c = do
fun <- ask
modify (addCall fun c)
addCall fun c
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ module Juvix.Compiler.Internal.Translation.FromInternal.Analysis.Termination.Che
)
where

import Data.HashMap.Internal.Strict qualified as HashMap
import Juvix.Compiler.Internal.Language as Internal
import Juvix.Compiler.Internal.Translation.FromInternal.Analysis.Termination.Data
import Juvix.Compiler.Internal.Translation.FromInternal.Analysis.Termination.Data.TerminationState
Expand Down Expand Up @@ -65,12 +64,12 @@ checkTerminationShallow' ::
m ->
Sem r ()
checkTerminationShallow' topModule = do
let callmap = buildCallMap topModule
let (callmap, scannedFuns) = buildCallMap topModule
forM_ (callMapRecursiveBehaviour callmap) $ \rb -> do
let funName = rb ^. recursiveBehaviourFun
markedTerminating :: Bool = funInfo ^. Internal.funDefTerminating
funInfo :: FunctionDef
funInfo = HashMap.lookupDefault err funName (callmap ^. callMapScanned)
funInfo = fromMaybe err (scannedFuns ^. at funName)
where
err = error ("Impossible: function not found: " <> funName ^. nameText)
order = findOrder rb
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
module Juvix.Compiler.Internal.Translation.FromInternal.Analysis.Termination.Data
( module Juvix.Compiler.Internal.Translation.FromInternal.Analysis.Termination.Data.FunctionCall,
( module Juvix.Compiler.Internal.Translation.FromInternal.Analysis.Termination.Data.Base,
module Juvix.Compiler.Internal.Translation.FromInternal.Analysis.Termination.Data.Graph,
module Juvix.Compiler.Internal.Translation.FromInternal.Analysis.Termination.Data.SizeInfo,
module Juvix.Compiler.Internal.Translation.FromInternal.Analysis.Termination.Data.SizeRelation,
)
where

import Juvix.Compiler.Internal.Translation.FromInternal.Analysis.Termination.Data.FunctionCall
import Juvix.Compiler.Internal.Translation.FromInternal.Analysis.Termination.Data.Base
import Juvix.Compiler.Internal.Translation.FromInternal.Analysis.Termination.Data.Graph
import Juvix.Compiler.Internal.Translation.FromInternal.Analysis.Termination.Data.SizeInfo
import Juvix.Compiler.Internal.Translation.FromInternal.Analysis.Termination.Data.SizeRelation
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
{-# OPTIONS_GHC -Wno-unused-type-patterns #-}

module Juvix.Compiler.Internal.Translation.FromInternal.Analysis.Termination.Data.Base where

import Data.HashMap.Strict qualified as HashMap
Expand All @@ -7,9 +9,8 @@ import Juvix.Compiler.Internal.Translation.FromInternal.Analysis.Termination.Dat
import Juvix.Extra.Strings qualified as Str
import Juvix.Prelude

data CallMap' expr = CallMap
{ _callMap :: HashMap FunctionName (HashMap FunctionName [FunCall' expr]),
_callMapScanned :: HashMap FunctionName FunctionDef
newtype CallMap' expr = CallMap
{ _callMap :: HashMap FunctionName (HashMap FunctionName [FunCall' expr])
}

data FunCall' expr = FunCall
Expand Down Expand Up @@ -42,12 +43,65 @@ data Call = Call

newtype LexOrder = LexOrder (NonEmpty Int)

data CallMapBuilder' expr :: Effect where
AddCall :: FunctionName -> FunCall' expr -> CallMapBuilder' expr m ()

makeEffect ''CallMapBuilder'

makeLenses ''CallMatrix
makeLenses ''CallRow
makeLenses ''FunCall'
makeLenses ''CallMap'
makeLenses ''FunCallArg'

mkFunCall :: forall pattrn expr. (pattrn -> expr -> Maybe SizeRel') -> FunctionName -> [pattrn] -> [expr] -> FunCall' expr
mkFunCall rel fun pats args =
FunCall
{ _callRef = fun,
_callArgs = map (mkFunCallArg rel pats) args
}

mkFunCallArg :: forall pattrn expr. (pattrn -> expr -> Maybe SizeRel') -> [pattrn] -> expr -> FunCallArg' expr
mkFunCallArg rel pats arg =
let rels = map (`rel` arg) pats
helper srel = (,srel) <$> elemIndex (Just srel) rels
smaller = helper RLe
equal = helper REq
in FunCallArg
{ _argExpression = arg,
_argRow =
CallRow $
smaller
<|> equal
}

execCallMapBuilder :: Sem (CallMapBuilder' expr ': r) a -> Sem r (CallMap' expr)
execCallMapBuilder = fmap fst . runCallMapBuilder

runCallMapBuilder :: Sem (CallMapBuilder' expr ': r) a -> Sem r (CallMap' expr, a)
runCallMapBuilder = reinterpret (runState emptyCallMap) $ \case
AddCall fun c -> modify (addCall' fun c)

addCall' :: forall expr. FunctionName -> FunCall' expr -> CallMap' expr -> CallMap' expr
addCall' fun c = over callMap (HashMap.alter (Just . insertCall c) fun)
where
insertCall ::
FunCall' expr ->
Maybe (HashMap FunctionName [FunCall' expr]) ->
HashMap FunctionName [FunCall' expr]
insertCall f = \case
Nothing -> singl f
Just m' -> addFunCall f m'

singl :: FunCall' expr -> HashMap FunctionName [FunCall' expr]
singl f = HashMap.singleton (f ^. callRef) [f]

addFunCall ::
FunCall' expr ->
HashMap FunctionName [FunCall' expr] ->
HashMap FunctionName [FunCall' expr]
addFunCall fc = HashMap.insertWith (flip (<>)) (fc ^. callRef) [fc]

filterCallMap :: (Foldable f) => f Text -> CallMap' expr -> CallMap' expr
filterCallMap funNames =
over
Expand Down Expand Up @@ -97,7 +151,7 @@ instance (HasAtomicity expr, PrettyCode expr) => PrettyCode (CallMap' expr) wher
(Members '[Reader Options] r) =>
CallMap' expr ->
Sem r (Doc Ann)
ppCode (CallMap m _) = vsep <$> mapM ppEntry (HashMap.toList m)
ppCode (CallMap m) = vsep <$> mapM ppEntry (HashMap.toList m)
where
ppEntry :: (FunctionName, HashMap FunctionName [FunCall' expr]) -> Sem r (Doc Ann)
ppEntry (fun, mcalls) = do
Expand Down Expand Up @@ -125,6 +179,5 @@ kwWaveArrow = keyword Str.waveArrow
emptyCallMap :: CallMap' expr
emptyCallMap =
CallMap
{ _callMap = mempty,
_callMapScanned = mempty
{ _callMap = mempty
}
Loading

0 comments on commit f47442b

Please sign in to comment.