Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Detect constant side conditions in matches #3133

Merged
merged 8 commits into from
Nov 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/Juvix/Compiler/Core/Data/TransformationId.hs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ data TransformationId
| UnrollRecursion
| ComputeTypeInfo
| ComputeCaseANF
| DetectConstantSideConditions
| DetectRedundantPatterns
| MatchToCase
| EtaExpandApps
Expand Down Expand Up @@ -59,10 +60,10 @@ data PipelineId
type TransformationLikeId = TransformationLikeId' TransformationId PipelineId

toTypecheckTransformations :: [TransformationId]
toTypecheckTransformations = [DetectRedundantPatterns, MatchToCase]
toTypecheckTransformations = [DetectConstantSideConditions, DetectRedundantPatterns, MatchToCase]

toStoredTransformations :: [TransformationId]
toStoredTransformations = [EtaExpandApps, DetectRedundantPatterns, MatchToCase, NatToPrimInt, IntToPrimInt, ConvertBuiltinTypes, OptPhaseEval, DisambiguateNames]
toStoredTransformations = [EtaExpandApps, DetectConstantSideConditions, DetectRedundantPatterns, MatchToCase, NatToPrimInt, IntToPrimInt, ConvertBuiltinTypes, OptPhaseEval, DisambiguateNames]

combineInfoTablesTransformations :: [TransformationId]
combineInfoTablesTransformations = [CombineInfoTables, FilterUnreachable]
Expand All @@ -84,6 +85,7 @@ instance TransformationId' TransformationId where
LambdaLetRecLifting -> strLifting
LetRecLifting -> strLetRecLifting
TopEtaExpand -> strTopEtaExpand
DetectConstantSideConditions -> strDetectConstantSideConditions
DetectRedundantPatterns -> strDetectRedundantPatterns
MatchToCase -> strMatchToCase
EtaExpandApps -> strEtaExpandApps
Expand Down
3 changes: 3 additions & 0 deletions src/Juvix/Compiler/Core/Data/TransformationId/Strings.hs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ strLetRecLifting = "letrec-lifting"
strTopEtaExpand :: Text
strTopEtaExpand = "top-eta-expand"

strDetectConstantSideConditions :: Text
strDetectConstantSideConditions = "detect-constant-side-conditions"

strDetectRedundantPatterns :: Text
strDetectRedundantPatterns = "detect-redundant-patterns"

Expand Down
2 changes: 2 additions & 0 deletions src/Juvix/Compiler/Core/Transformation.hs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import Juvix.Compiler.Core.Transformation.CombineInfoTables (combineInfoTables)
import Juvix.Compiler.Core.Transformation.ComputeCaseANF
import Juvix.Compiler.Core.Transformation.ComputeTypeInfo
import Juvix.Compiler.Core.Transformation.ConvertBuiltinTypes
import Juvix.Compiler.Core.Transformation.DetectConstantSideConditions
import Juvix.Compiler.Core.Transformation.DetectRedundantPatterns
import Juvix.Compiler.Core.Transformation.DisambiguateNames
import Juvix.Compiler.Core.Transformation.Eta
Expand Down Expand Up @@ -76,6 +77,7 @@ applyTransformations ts tbl = foldM (flip appTrans) tbl ts
ComputeTypeInfo -> return . computeTypeInfo
ComputeCaseANF -> return . computeCaseANF
UnrollRecursion -> unrollRecursion
DetectConstantSideConditions -> mapError (JuvixError @CoreError) . detectConstantSideConditions
DetectRedundantPatterns -> mapError (JuvixError @CoreError) . detectRedundantPatterns
MatchToCase -> mapError (JuvixError @CoreError) . matchToCase
EtaExpandApps -> return . etaExpansionApps
Expand Down
1 change: 0 additions & 1 deletion src/Juvix/Compiler/Core/Transformation/ComputeCaseANF.hs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ module Juvix.Compiler.Core.Transformation.ComputeCaseANF (computeCaseANF) where
-- ```
-- let z := f x in case z of { c y := y + x; d y := y }
-- ```
-- This transformation is needed for the Nockma backend.

import Juvix.Compiler.Core.Data.BinderList qualified as BL
import Juvix.Compiler.Core.Extra
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
module Juvix.Compiler.Core.Transformation.DetectConstantSideConditions
( detectConstantSideConditions,
)
where

import Juvix.Compiler.Core.Error
import Juvix.Compiler.Core.Extra
import Juvix.Compiler.Core.Info.LocationInfo
import Juvix.Compiler.Core.Options
import Juvix.Compiler.Core.Transformation.Base

detectConstantSideConditions :: forall r. (Members '[Error CoreError, Reader CoreOptions] r) => Module -> Sem r Module
detectConstantSideConditions md = mapAllNodesM (umapM go) md
where
mockFile = $(mkAbsFile "/detect-constant-side-conditions")
defaultLoc = singletonInterval (mkInitialLoc mockFile)

boolSym = lookupConstructorInfo md (BuiltinTag TagTrue) ^. constructorInductive

go :: Node -> Sem r Node
go node = case node of
NMatch m -> NMatch <$> (overM matchBranches (mapMaybeM convertMatchBranch) m)
_ -> return node

convertMatchBranch :: MatchBranch -> Sem r (Maybe MatchBranch)
convertMatchBranch br@MatchBranch {..} =
case _matchBranchRhs of
MatchBranchRhsExpression {} ->
return $ Just br
MatchBranchRhsIfs ifs ->
case ifs1 of
[] ->
case nonEmpty ifs0 of
Nothing -> return Nothing
Just ifs0' -> return $ Just $ set matchBranchRhs (MatchBranchRhsIfs ifs0') br
SideIfBranch {..} : ifs1' -> do
fCoverage <- asks (^. optCheckCoverage)
unless (not fCoverage || null ifs1') $
throw
CoreError
{ _coreErrorMsg = "Redundant side condition",
_coreErrorNode = Nothing,
_coreErrorLoc = fromMaybe defaultLoc (getInfoLocation (head' ifs1' ^. sideIfBranchInfo))
}
let ifsBody = mkIfs boolSym (map (\(SideIfBranch i c b) -> (i, c, b)) ifs0) _sideIfBranchBody
return $ Just $ set matchBranchRhs (MatchBranchRhsExpression ifsBody) br
where
ifs' = filter (not . isFalseConstr . (^. sideIfBranchCondition)) (toList ifs)
(ifs0, ifs1) = span (not . isTrueConstr . (^. sideIfBranchCondition)) ifs'
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ goDetectRedundantPatterns md node = case node of
return node
_ -> return node
where
mockFile = $(mkAbsFile "/check-redundant-patterns")
mockFile = $(mkAbsFile "/detect-redundant-patterns")
defaultLoc = singletonInterval (mkInitialLoc mockFile)

checkMatch :: Match -> Sem r ()
Expand All @@ -52,7 +52,7 @@ goDetectRedundantPatterns md node = case node of
unless (check matrix row) $
throw
CoreError
{ _coreErrorMsg = ppOutput ("Redundant pattern" <> seq <> ": " <> pat),
{ _coreErrorMsg = ppOutput ("Redundant pattern" <> seq <> ": " <> pat <> "\nPerhaps you mistyped a constructor name in an earlier pattern?"),
_coreErrorNode = Nothing,
_coreErrorLoc = fromMaybe defaultLoc (getInfoLocation _matchBranchInfo)
}
Expand All @@ -61,8 +61,8 @@ goDetectRedundantPatterns md node = case node of
MatchBranchRhsIfs {} -> go matrix branches
where
opts = defaultOptions {_optPrettyPatterns = True}
seq = if length _matchBranchPatterns == 1 then "" else " sequence"
pat = if length _matchBranchPatterns == 1 then doc opts (head _matchBranchPatterns) else docSequence opts (toList _matchBranchPatterns)
seq = if isSingleton (toList _matchBranchPatterns) then "" else " sequence"
pat = if isSingleton (toList _matchBranchPatterns) then doc opts (head _matchBranchPatterns) else docSequence opts (toList _matchBranchPatterns)

-- Returns True if vector is useful (not redundant) for matrix, i.e. it is
-- not covered by any row in the matrix. See Definition 6 and Section 3.1 in
Expand Down
4 changes: 2 additions & 2 deletions src/Juvix/Compiler/Core/Transformation/MatchToCase.hs
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,8 @@ goMatchToCase recur node = case node of
mkBuiltinApp' OpFail [mkConstant' (ConstString ("Pattern sequence not matched: " <> ppTrace pat))]
where
pat = err (replicate (length vs) ValueWildcard)
seq = if length pat == 1 then "" else "sequence "
pat' = if length pat == 1 then doc defaultOptions (head' pat) else docSequence defaultOptions pat
seq = if isSingleton pat then "" else "sequence "
pat' = if isSingleton pat then doc defaultOptions (head' pat) else docSequence defaultOptions pat
r@PatternRow {..} : matrix'
| all isPatWildcard _patternRowPatterns ->
-- The first row matches all values (Section 4, case 2)
Expand Down
2 changes: 2 additions & 0 deletions src/Juvix/Prelude/Base/Foundation.hs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ module Juvix.Prelude.Base.Foundation
module GHC.Generics,
module GHC.Num,
module GHC.Real,
module GHC.Utils.Misc,
module Control.Lens,
module Language.Haskell.TH.Syntax,
module Prettyprinter,
Expand Down Expand Up @@ -197,6 +198,7 @@ import GHC.Generics (Generic)
import GHC.Num
import GHC.Real
import GHC.Stack.Types
import GHC.Utils.Misc (isSingleton)
import Language.Haskell.TH.Syntax (Exp, Lift, Q)
import Numeric hiding (exp, log, pi)
import Path (Abs, Dir, File, Path, Rel, SomeBase (..))
Expand Down
14 changes: 13 additions & 1 deletion test/Compilation/Negative.hs
Original file line number Diff line number Diff line change
Expand Up @@ -69,5 +69,17 @@ tests =
NegTest
"Test010: Redundant pattern detection with complex patterns"
$(mkRelDir ".")
$(mkRelFile "test010.juvix")
$(mkRelFile "test010.juvix"),
NegTest
"Test011: Redundant pattern detection with side conditions"
$(mkRelDir ".")
$(mkRelFile "test011.juvix"),
NegTest
"Test012: Pattern matching coverage with side conditions"
$(mkRelDir ".")
$(mkRelFile "test012.juvix"),
NegTest
"Test013: Redundant side condition detection"
$(mkRelDir ".")
$(mkRelFile "test013.juvix")
]
7 changes: 6 additions & 1 deletion test/Compilation/Positive.hs
Original file line number Diff line number Diff line change
Expand Up @@ -480,5 +480,10 @@ tests =
"Test081: Non-duplication in let-folding"
$(mkRelDir ".")
$(mkRelFile "test081.juvix")
$(mkRelFile "out/test081.out")
$(mkRelFile "out/test081.out"),
posTest
"Test082: Pattern matching with side conditions"
$(mkRelDir ".")
$(mkRelFile "test082.juvix")
$(mkRelFile "out/test082.out")
]
1 change: 0 additions & 1 deletion tests/Compilation/negative/test001.juvix
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,3 @@ f : List Nat -> List Nat -> Nat
| _ nil := 0;

main : Nat := f (1 :: nil) (2 :: nil);

15 changes: 15 additions & 0 deletions tests/Compilation/negative/test011.juvix
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
-- Redundant pattern after a true side condition
module test011;

import Stdlib.Prelude open;

f (x : List Nat) : Nat :=
case x of
| nil := 0
| x :: _ :: nil := x
| _ :: _ :: _ :: _ if true := 0
| _ :: _ :: x :: nil := x
| _ :: nil := 1
| _ := 2;

main : Nat := f (1 :: 2 :: nil);
11 changes: 11 additions & 0 deletions tests/Compilation/negative/test012.juvix
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
-- Non-exhaustive pattern matching with false side conditions
module test012;

import Stdlib.Prelude open;

f (x : List Nat) : Nat :=
case x of
| nil := 0
| x :: _ if false := x;

main : Nat := f (1 :: 2 :: nil);
14 changes: 14 additions & 0 deletions tests/Compilation/negative/test013.juvix
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
-- Redundant side condition
module test013;

import Stdlib.Prelude open;

f (x : List Nat) : Nat :=
case x of
| nil := 0
| x :: _ if x > 0 := x
| if true := 0
| if false := 1
| if x == 0 := 2;

main : Nat := f (1 :: 2 :: nil);
1 change: 1 addition & 0 deletions tests/Compilation/positive/out/test082.out
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
4
22 changes: 22 additions & 0 deletions tests/Compilation/positive/test082.juvix
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
-- Pattern matching with side conditions
module test082;

import Stdlib.Prelude open;

f (lst : List Nat) : Nat :=
case lst of
| [] := 0
| x :: xs
| if x == 0 := 1
| if true := 2;

g (lst : List Nat) : Nat :=
case lst of
| [] := 0
| _ :: _ if false := 0
| x :: xs
| if x == 0 := 1
| if false := 2
| if true := 3;

main : Nat := f [0; 1; 2] + g [1; 2];
Loading