Skip to content

Commit

Permalink
Detect redundant patterns (#3101)
Browse files Browse the repository at this point in the history
* Closes #3008
* Implements the algorithm from [Luc Maranget, Warnings for Pattern
Matching](https://www.cambridge.org/core/services/aop-cambridge-core/content/view/3165B75113781E2431E3856972940347/S0956796807006223a.pdf/warnings-for-pattern-matching.pdf)
to detect redundant patterns.
* Adds an option to the Core pretty printer to print match patterns in a
user-friendly format consistent with pattern syntax in Juvix frontend
language.
  • Loading branch information
lukaszcz authored Oct 30, 2024
1 parent 23837ed commit 68a79bc
Show file tree
Hide file tree
Showing 17 changed files with 301 additions and 38 deletions.
6 changes: 4 additions & 2 deletions src/Juvix/Compiler/Core/Data/TransformationId.hs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ data TransformationId
| UnrollRecursion
| ComputeTypeInfo
| ComputeCaseANF
| DetectRedundantPatterns
| MatchToCase
| EtaExpandApps
| DisambiguateNames
Expand Down Expand Up @@ -58,10 +59,10 @@ data PipelineId
type TransformationLikeId = TransformationLikeId' TransformationId PipelineId

toTypecheckTransformations :: [TransformationId]
toTypecheckTransformations = [MatchToCase]
toTypecheckTransformations = [DetectRedundantPatterns, MatchToCase]

toStoredTransformations :: [TransformationId]
toStoredTransformations = [EtaExpandApps, MatchToCase, NatToPrimInt, IntToPrimInt, ConvertBuiltinTypes, OptPhaseEval, DisambiguateNames]
toStoredTransformations = [EtaExpandApps, DetectRedundantPatterns, MatchToCase, NatToPrimInt, IntToPrimInt, ConvertBuiltinTypes, OptPhaseEval, DisambiguateNames]

combineInfoTablesTransformations :: [TransformationId]
combineInfoTablesTransformations = [CombineInfoTables, FilterUnreachable]
Expand All @@ -83,6 +84,7 @@ instance TransformationId' TransformationId where
LambdaLetRecLifting -> strLifting
LetRecLifting -> strLetRecLifting
TopEtaExpand -> strTopEtaExpand
DetectRedundantPatterns -> strDetectRedundantPatterns
MatchToCase -> strMatchToCase
EtaExpandApps -> strEtaExpandApps
IdentityTrans -> strIdentity
Expand Down
3 changes: 3 additions & 0 deletions src/Juvix/Compiler/Core/Data/TransformationId/Strings.hs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ strLetRecLifting = "letrec-lifting"
strTopEtaExpand :: Text
strTopEtaExpand = "top-eta-expand"

strDetectRedundantPatterns :: Text
strDetectRedundantPatterns = "detect-redundant-patterns"

strMatchToCase :: Text
strMatchToCase = "match-to-case"

Expand Down
17 changes: 16 additions & 1 deletion src/Juvix/Compiler/Core/Extra/Base.hs
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,11 @@ isTypeBool = \case
NPrim (TypePrim _ (PrimBool _)) -> True
_ -> False

isUniverse :: Type -> Bool
isUniverse = \case
NUniv {} -> True
_ -> False

-- | `expandType argtys ty` expands the dynamic target of `ty` to match the
-- number of arguments with types specified by `argstys`. For example,
-- `expandType [int, string] (int -> any) = int -> string -> any`.
Expand Down Expand Up @@ -675,9 +680,19 @@ destruct = \case
concat
[ br
^. matchBranchInfo
: concatMap getPatternInfos (br ^. matchBranchPatterns)
: getSideIfBranchInfos (br ^. matchBranchRhs)
++ concatMap getPatternInfos (br ^. matchBranchPatterns)
| br <- branches
]

getSideIfBranchInfos :: MatchBranchRhs -> [Info]
getSideIfBranchInfos = \case
MatchBranchRhsExpression _ -> []
MatchBranchRhsIfs ifs -> map getSideIfBranchInfos' (toList ifs)
where
getSideIfBranchInfos' :: SideIfBranch -> Info
getSideIfBranchInfos' SideIfBranch {..} = _sideIfBranchInfo

-- sets the infos and the binder types in the patterns
setPatternsInfos :: forall r. (Members '[Input Info, Input Node] r) => NonEmpty Pattern -> Sem r (NonEmpty Pattern)
setPatternsInfos = mapM goPattern
Expand Down
3 changes: 2 additions & 1 deletion src/Juvix/Compiler/Core/Language/Nodes.hs
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ data PatternWildcard' i a = PatternWildcard

data PatternConstr' i a = PatternConstr
{ _patternConstrInfo :: i,
_patternConstrFixity :: Maybe Fixity,
_patternConstrBinder :: Binder' a,
_patternConstrTag :: !Tag,
_patternConstrArgs :: ![Pattern' i a]
Expand Down Expand Up @@ -549,7 +550,7 @@ instance (Eq a) => Eq (MatchBranch' i a) where
(MatchBranch _ pats1 b1) == (MatchBranch _ pats2 b2) = pats1 == pats2 && b1 == b2

instance (Eq a) => Eq (PatternConstr' i a) where
(PatternConstr _ _ tag1 ps1) == (PatternConstr _ _ tag2 ps2) = tag1 == tag2 && ps1 == ps2
(PatternConstr _ _ _ tag1 ps1) == (PatternConstr _ _ _ tag2 ps2) = tag1 == tag2 && ps1 == ps2

instance (Eq a) => Eq (SideIfBranch' i a) where
(SideIfBranch _ c1 b1) == (SideIfBranch _ c2 b2) = c1 == c2 && b1 == b2
Expand Down
74 changes: 56 additions & 18 deletions src/Juvix/Compiler/Core/Pretty/Base.hs
Original file line number Diff line number Diff line change
Expand Up @@ -305,18 +305,53 @@ instance (PrettyCode a) => PrettyCode (If' i a) where

instance PrettyCode PatternWildcard where
ppCode PatternWildcard {..} = do
n <- ppName KNameLocal (_patternWildcardBinder ^. binderName)
ppWithType n (_patternWildcardBinder ^. binderType)
bPretty <- asks (^. optPrettyPatterns)
let name = _patternWildcardBinder ^. binderName
if
| not bPretty -> do
n <- ppName KNameLocal name
ppWithType n (_patternWildcardBinder ^. binderType)
| isPrefixOf "_" (fromText name) || name == "?" || name == "" ->
return kwWildcard
| otherwise ->
ppName KNameLocal name

instance PrettyCode PatternConstr where
ppCode PatternConstr {..} = do
n <- ppName KNameConstructor (getInfoName _patternConstrInfo)
bPretty <- asks (^. optPrettyPatterns)
let cname = getInfoName _patternConstrInfo
n <- ppName KNameConstructor cname
bn <- ppName KNameLocal (_patternConstrBinder ^. binderName)
let mkpat :: Doc Ann -> Doc Ann
mkpat pat = if _patternConstrBinder ^. binderName == "?" || _patternConstrBinder ^. binderName == "" then pat else bn <> kwAt <> parens pat
args <- mapM (ppRightExpression appFixity) _patternConstrArgs
let name = fromText (_patternConstrBinder ^. binderName)
mkpat :: Doc Ann -> Doc Ann
mkpat pat = if name == "?" || name == "" || (bPretty && isPrefixOf "_" name) then pat else bn <> kwAt <> parens pat
args0 =
if
| bPretty ->
filter (not . isWildcardTypeBinder) _patternConstrArgs
| otherwise ->
_patternConstrArgs
args <- mapM (ppRightExpression appFixity) args0
let pat = mkpat (hsep (n : args))
ppWithType pat (_patternConstrBinder ^. binderType)
if
| bPretty ->
case _patternConstrFixity of
Nothing -> do
return pat
Just fixity
| isBinary fixity ->
goBinary (cname == ",") fixity n args0
| isUnary fixity ->
goUnary fixity n args0
_ -> impossible
| otherwise ->
ppWithType pat (_patternConstrBinder ^. binderType)
where
isWildcardTypeBinder :: Pattern -> Bool
isWildcardTypeBinder = \case
PatWildcard PatternWildcard {..} ->
isUniverse (typeTarget (_patternWildcardBinder ^. binderType))
_ -> False

instance PrettyCode Pattern where
ppCode = \case
Expand Down Expand Up @@ -683,7 +718,7 @@ instance (PrettyCode a) => PrettyCode [a] where
-- printing values
--------------------------------------------------------------------------------

goBinary :: (Member (Reader Options) r) => Bool -> Fixity -> Doc Ann -> [Value] -> Sem r (Doc Ann)
goBinary :: (HasAtomicity a, PrettyCode a, Member (Reader Options) r) => Bool -> Fixity -> Doc Ann -> [a] -> Sem r (Doc Ann)
goBinary isComma fixity name = \case
[] -> return (parens name)
[arg] -> do
Expand All @@ -700,7 +735,7 @@ goBinary isComma fixity name = \case
_ ->
impossible

goUnary :: (Member (Reader Options) r) => Fixity -> Doc Ann -> [Value] -> Sem r (Doc Ann)
goUnary :: (HasAtomicity a, PrettyCode a, Member (Reader Options) r) => Fixity -> Doc Ann -> [a] -> Sem r (Doc Ann)
goUnary fixity name = \case
[] -> return (parens name)
[arg] -> do
Expand Down Expand Up @@ -731,19 +766,22 @@ instance PrettyCode Value where
ValueFun -> return "<function>"
ValueType -> return "<type>"

ppValueSequence :: (Member (Reader Options) r) => [Value] -> Sem r (Doc Ann)
ppValueSequence vs = hsep <$> mapM (ppRightExpression appFixity) vs

docValueSequence :: [Value] -> Doc Ann
docValueSequence =
run
. runReader defaultOptions
. ppValueSequence

--------------------------------------------------------------------------------
-- helper functions
--------------------------------------------------------------------------------

ppSequence ::
(PrettyCode a, HasAtomicity a, Member (Reader Options) r) =>
[a] ->
Sem r (Doc Ann)
ppSequence vs = hsep <$> mapM (ppRightExpression appFixity) vs

docSequence :: (PrettyCode a, HasAtomicity a) => Options -> [a] -> Doc Ann
docSequence opts =
run
. runReader opts
. ppSequence

ppPostExpression ::
(PrettyCode a, HasAtomicity a, Member (Reader Options) r) =>
Fixity ->
Expand Down
9 changes: 6 additions & 3 deletions src/Juvix/Compiler/Core/Pretty/Options.hs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ import Juvix.Prelude
data Options = Options
{ _optShowIdentIds :: Bool,
_optShowDeBruijnIndices :: Bool,
_optShowArgsNum :: Bool
_optShowArgsNum :: Bool,
_optPrettyPatterns :: Bool
}

makeLenses ''Options
Expand All @@ -15,15 +16,17 @@ defaultOptions =
Options
{ _optShowIdentIds = False,
_optShowDeBruijnIndices = False,
_optShowArgsNum = False
_optShowArgsNum = False,
_optPrettyPatterns = False
}

traceOptions :: Options
traceOptions =
Options
{ _optShowIdentIds = True,
_optShowDeBruijnIndices = True,
_optShowArgsNum = True
_optShowArgsNum = True,
_optPrettyPatterns = False
}

fromGenericOptions :: GenericOptions -> Options
Expand Down
2 changes: 2 additions & 0 deletions src/Juvix/Compiler/Core/Transformation.hs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import Juvix.Compiler.Core.Transformation.CombineInfoTables (combineInfoTables)
import Juvix.Compiler.Core.Transformation.ComputeCaseANF
import Juvix.Compiler.Core.Transformation.ComputeTypeInfo
import Juvix.Compiler.Core.Transformation.ConvertBuiltinTypes
import Juvix.Compiler.Core.Transformation.DetectRedundantPatterns
import Juvix.Compiler.Core.Transformation.DisambiguateNames
import Juvix.Compiler.Core.Transformation.Eta
import Juvix.Compiler.Core.Transformation.FoldTypeSynonyms
Expand Down Expand Up @@ -75,6 +76,7 @@ applyTransformations ts tbl = foldM (flip appTrans) tbl ts
ComputeTypeInfo -> return . computeTypeInfo
ComputeCaseANF -> return . computeCaseANF
UnrollRecursion -> unrollRecursion
DetectRedundantPatterns -> mapError (JuvixError @CoreError) . detectRedundantPatterns
MatchToCase -> mapError (JuvixError @CoreError) . matchToCase
EtaExpandApps -> return . etaExpansionApps
DisambiguateNames -> return . disambiguateNames
Expand Down
129 changes: 129 additions & 0 deletions src/Juvix/Compiler/Core/Transformation/DetectRedundantPatterns.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
module Juvix.Compiler.Core.Transformation.DetectRedundantPatterns where

import Data.HashSet qualified as HashSet
import Juvix.Compiler.Core.Error
import Juvix.Compiler.Core.Extra
import Juvix.Compiler.Core.Info.LocationInfo
import Juvix.Compiler.Core.Options
import Juvix.Compiler.Core.Pretty hiding (Options)
import Juvix.Compiler.Core.Transformation.Base

type PatternRow = [Pattern]

type PatternMatrix = [PatternRow]

-- | Checks for redundant patterns in `Match` nodes. The algorithm is based on
-- the paper: Luc Maranget, "Warnings for pattern matching", JFP 17 (3):
-- 387–421, 2007.
detectRedundantPatterns :: (Members '[Error CoreError, Reader CoreOptions] r) => Module -> Sem r Module
detectRedundantPatterns md = do
fCoverage <- asks (^. optCheckCoverage)
if
| fCoverage ->
mapAllNodesM (umapM (goDetectRedundantPatterns md)) md
| otherwise ->
return md

goDetectRedundantPatterns ::
forall r.
(Members '[Error CoreError, Reader CoreOptions] r) =>
Module ->
Node ->
Sem r Node
goDetectRedundantPatterns md node = case node of
NMatch m -> do
checkMatch m
return node
_ -> return node
where
mockFile = $(mkAbsFile "/check-redundant-patterns")
defaultLoc = singletonInterval (mkInitialLoc mockFile)

checkMatch :: Match -> Sem r ()
checkMatch Match {..} = case _matchBranches of
[] -> return ()
MatchBranch {..} : brs -> go [toList _matchBranchPatterns] brs
where
go :: PatternMatrix -> [MatchBranch] -> Sem r ()
go matrix = \case
[] -> return ()
MatchBranch {..} : branches -> do
let row = toList _matchBranchPatterns
unless (check matrix row) $
throw
CoreError
{ _coreErrorMsg = ppOutput ("Redundant pattern" <> seq <> ": " <> pat),
_coreErrorNode = Nothing,
_coreErrorLoc = fromMaybe defaultLoc (getInfoLocation _matchBranchInfo)
}
case _matchBranchRhs of
MatchBranchRhsExpression {} -> go (row : matrix) branches
MatchBranchRhsIfs {} -> go matrix branches
where
opts = defaultOptions {_optPrettyPatterns = True}
seq = if length _matchBranchPatterns == 1 then "" else " sequence"
pat = if length _matchBranchPatterns == 1 then doc opts (head _matchBranchPatterns) else docSequence opts (toList _matchBranchPatterns)

-- Returns True if vector is useful (not redundant) for matrix, i.e. it is
-- not covered by any row in the matrix. See Definition 6 and Section 3.1 in
-- the paper.
check :: PatternMatrix -> PatternRow -> Bool
check matrix vector = case vector of
[]
| null matrix -> True
| otherwise -> False
(p : ps) -> case p of
PatConstr PatternConstr {..} ->
check
(specialize _patternConstrTag (length _patternConstrArgs) matrix)
(_patternConstrArgs ++ ps)
PatWildcard {} ->
let col = map head' matrix
tagsSet = getPatTags col
tags = toList tagsSet
ind = lookupConstructorInfo md (head' tags) ^. constructorInductive
ctrsNum = length (lookupInductiveInfo md ind ^. inductiveConstructors)
in if
| not (null tags) && length tags == ctrsNum ->
go tags
| otherwise ->
check (computeDefault matrix) ps
where
go :: [Tag] -> Bool
go = \case
[] -> False
(tag : tags') ->
check matrix' (replicate argsNum p ++ ps) || go tags'
where
argsNum = lookupConstructorInfo md tag ^. constructorArgsNum
matrix' = specialize tag argsNum matrix

getPatTags :: [Pattern] -> HashSet Tag
getPatTags = \case
[] ->
mempty
PatConstr PatternConstr {..} : pats ->
HashSet.insert _patternConstrTag (getPatTags pats)
_ : pats ->
getPatTags pats

specialize :: Tag -> Int -> PatternMatrix -> PatternMatrix
specialize tag argsNum = mapMaybe go
where
go :: PatternRow -> Maybe PatternRow
go row = case row of
PatConstr PatternConstr {..} : row'
| _patternConstrTag == tag -> Just $ _patternConstrArgs ++ row'
| otherwise -> Nothing
w@PatWildcard {} : row' ->
Just $ replicate argsNum w ++ row'
[] -> impossible

computeDefault :: PatternMatrix -> PatternMatrix
computeDefault matrix = mapMaybe go matrix
where
go :: PatternRow -> Maybe PatternRow
go row = case row of
PatConstr {} : _ -> Nothing
PatWildcard {} : row' -> Just row'
[] -> impossible
Loading

0 comments on commit 68a79bc

Please sign in to comment.