diff --git a/src/Elara/AST/Generic/Instances/StripLocation.hs b/src/Elara/AST/Generic/Instances/StripLocation.hs index dac67b9..3cecd78 100644 --- a/src/Elara/AST/Generic/Instances/StripLocation.hs +++ b/src/Elara/AST/Generic/Instances/StripLocation.hs @@ -200,8 +200,7 @@ instance forall (ast1 :: LocatedAST) (ast2 :: UnlocatedAST). ( ASTLocate' ast1 ~ Located , ASTLocate' ast2 ~ Unlocated - , ( StripLocation (Select "Infixed" ast1) (Select "Infixed" ast2) - ) + , (StripLocation (Select "Infixed" ast1) (Select "Infixed" ast2)) , ( StripLocation (CleanupLocated (Located (Select "SymOp" ast1))) (Select "SymOp" ast2) diff --git a/src/Elara/AST/Pretty.hs b/src/Elara/AST/Pretty.hs index e4556d3..b517d98 100644 --- a/src/Elara/AST/Pretty.hs +++ b/src/Elara/AST/Pretty.hs @@ -45,8 +45,7 @@ prettyLambdaExpr args body = parens (if ?contextFree then prettyCTFLambdaExpr el long = align - ( "\\" <+> hsep (pretty <$> args) <+> "->" <> hardline <> nest indentDepth (pretty body) - ) + ("\\" <+> hsep (pretty <$> args) <+> "->" <> hardline <> nest indentDepth (pretty body)) prettyFunctionCall :: (?contextFree :: Bool, Pretty a, Pretty b) => a -> b -> Doc AnsiStyle prettyFunctionCall e1' e2' = parens (if ?contextFree then short else group (flatAlt long short)) diff --git a/src/Elara/Parse/Declaration.hs b/src/Elara/Parse/Declaration.hs index b4e5cb1..c981773 100644 --- a/src/Elara/Parse/Declaration.hs +++ b/src/Elara/Parse/Declaration.hs @@ -42,8 +42,7 @@ defDec modName = fmapLocated Declaration $ do ( Declaration' modName name - ( DeclarationBody declBody - ) + (DeclarationBody declBody) ) letDec :: Located ModuleName -> Parser FrontendDeclaration diff --git a/src/Elara/TypeInfer/ConstraintGeneration.hs b/src/Elara/TypeInfer/ConstraintGeneration.hs index a61ad15..f6aa874 100644 --- a/src/Elara/TypeInfer/ConstraintGeneration.hs +++ b/src/Elara/TypeInfer/ConstraintGeneration.hs @@ -11,7 +11,8 @@ import Elara.AST.Typed (TypedExpr, TypedExpr') import Elara.AST.VarRef import Elara.Data.Unique (UniqueGen) import Elara.TypeInfer.Environment (InferError, LocalTypeEnvironment, TypeEnvKey (TermVarKey), TypeEnvironment, addType, lookupLocalVar, lookupLocalVarType, lookupType, withLocalType) -import Elara.TypeInfer.Type (AxiomScheme, Constraint (..), Monotype (..), Scalar (..), Substitutable (..), Substitution, Type (Forall)) +import Elara.TypeInfer.Ftv (occurs) +import Elara.TypeInfer.Type (AxiomScheme, Constraint (..), Monotype (..), Scalar (..), Substitutable (..), Substitution (..), Type (Forall)) import Elara.TypeInfer.Unique (makeUniqueTyVar) import Polysemy import Polysemy.Error @@ -95,9 +96,61 @@ tidyConstraint (Conjunction x (Conjunction y z)) | x == y = Conjunction (tidyCon tidyConstraint (Conjunction c1 c2) = Conjunction (tidyConstraint c1) (tidyConstraint c2) tidyConstraint (Equality x y) = Equality x y --- unifyConstraints :: [Constraint loc] -> - --- simplifyConstraints :: AxiomScheme loc -> Constraint loc -> Constraint loc -> (Constraint loc, Substitution loc) --- simplifyConstraints axioms given wanted = (residual, subst) --- where --- (residual, subst) = todo +solveConstraints :: AxiomScheme loc -> Constraint loc -> Constraint loc -> Sem '[Error UnifyError] (Constraint loc, Substitution loc) +solveConstraints axioms given wanted = do + (substitution, simplifiedWanted) <- unifyEquality wanted + + -- let simplifiedGiven = substitute substitution given + + -- let simplifiedWanted' = substitute substitution simplifiedWanted + + -- let (entailable, residualWanted) = entail axioms simplifiedGiven simplifiedWanted' + + pure (todo, substitution) + +unifyEquality :: Constraint loc -> Sem '[Error UnifyError] (Substitution loc, Constraint loc) +unifyEquality (Equality a b) = unify a b +unifyEquality c = pure (mempty, c) + +unify :: HasCallStack => Monotype loc -> Monotype loc -> Sem '[Error UnifyError] (Substitution loc, Constraint loc) +unify (TypeVar a) (TypeVar b) | a == b = pure (mempty, EmptyConstraint) +unify (TypeVar a) b = + if occurs a b + then throw OccursCheckFailed + else pure (Substitution [(a, b)], EmptyConstraint) +unify a (TypeVar b) = unify (TypeVar b) a -- swap to avoid duplication +unify (Scalar a) (Scalar b) = + if a == b + then pure (mempty, EmptyConstraint) + else throw ScalarMismatch +unify (TypeConstructor a as) (TypeConstructor b bs) + | a /= b = throw TypeConstructorMismatch + | length as /= length bs = throw ArityMismatch + | otherwise = unifyMany as bs +unify (Function a b) (Function c d) = do + (s1, c1) <- unify a c + (s2, c2) <- unify (substituteAll s1 b) (substituteAll s1 d) + pure (s1 <> s2, c1 <> c2) +unify a b = throw $ UnificationFailed $ "Unification failed: " <> show a <> " and " <> show b + +unifyMany :: + HasCallStack => + [Monotype loc] -> + [Monotype loc] -> + Sem + '[Error UnifyError] + (Substitution loc, Constraint loc) +unifyMany [] _ = pure (mempty, EmptyConstraint) +unifyMany _ [] = pure (mempty, EmptyConstraint) +unifyMany (a : as) (b : bs) = do + (s1, c1) <- unify a b + (s2, c2) <- unifyMany (fmap (substituteAll s1) as) (fmap (substituteAll s1) bs) + pure (s1 <> s2, c1 <> c2) + +data UnifyError + = OccursCheckFailed + | ScalarMismatch + | TypeConstructorMismatch + | ArityMismatch + | UnificationFailed String + deriving (Eq, Show) diff --git a/src/Elara/TypeInfer/Ftv.hs b/src/Elara/TypeInfer/Ftv.hs new file mode 100644 index 0000000..3879475 --- /dev/null +++ b/src/Elara/TypeInfer/Ftv.hs @@ -0,0 +1,20 @@ +module Elara.TypeInfer.Ftv where + +import Data.Set (difference, member) +import Elara.TypeInfer.Type (Monotype (..), Type (..)) +import Elara.TypeInfer.Unique + +class Ftv a where + ftv :: a -> Set UniqueTyVar + +instance Ftv (Monotype loc) where + ftv (TypeVar tv) = one tv + ftv (Scalar _) = mempty + ftv (TypeConstructor _ ts) = foldMap ftv ts + ftv (Function t1 t2) = ftv t1 <> ftv t2 + +instance Ftv (Type loc) where + ftv (Forall tv _ t) = ftv t `difference` one tv + +occurs :: Ftv a => UniqueTyVar -> a -> Bool +occurs tv a = tv `member` ftv a diff --git a/src/Elara/TypeInfer/Type.hs b/src/Elara/TypeInfer/Type.hs index 53efe26..f0a6a15 100644 --- a/src/Elara/TypeInfer/Type.hs +++ b/src/Elara/TypeInfer/Type.hs @@ -3,9 +3,7 @@ module Elara.TypeInfer.Type where import Data.Kind qualified as Kind import Elara.AST.Name -import Elara.AST.VarRef (UnlocatedVarRef, VarRef) import Elara.TypeInfer.Unique -import Prelude hiding (Constraint) -- | A type scheme σ data Type loc @@ -67,11 +65,13 @@ data Scalar | ScalarString | ScalarChar | ScalarUnit - deriving (Generic, Show, Eq, Ord) + deriving (Generic, Show, Eq, Ord, Enum, Bounded) type DataCon = Qualified TypeName -data Substitution loc = Substitution [(UniqueTyVar, Monotype loc)] +newtype Substitution loc = Substitution [(UniqueTyVar, Monotype loc)] + deriving newtype (Semigroup, Monoid) + deriving stock (Eq, Show) class Substitutable (a :: k -> Kind.Type) where substitute :: UniqueTyVar -> Monotype loc -> a loc -> a loc diff --git a/test/Arbitrary/Type.hs b/test/Arbitrary/Type.hs new file mode 100644 index 0000000..faf06b9 --- /dev/null +++ b/test/Arbitrary/Type.hs @@ -0,0 +1,28 @@ +module Arbitrary.Type where + +import Elara.AST.Name +import Elara.Data.Unique (unsafeMkUnique) +import Elara.TypeInfer.Type +import Elara.TypeInfer.Unique +import Hedgehog (Gen) +import Hedgehog.Gen qualified as Gen +import Hedgehog.Range qualified as Range +import Region (qualifiedTest) + +-- | contrary to what the name suggests, this will NOT be unique :) +genUniqueTypeVar :: Gen UniqueTyVar +genUniqueTypeVar = unsafeMkUnique Nothing <$> Gen.integral (Range.linear 0 100) + +typeConstructorNames :: [TypeName] +typeConstructorNames = ["List", "Maybe", "Pair", "Box", "IO"] + +genMonotype :: Gen (Monotype loc) +genMonotype = + Gen.recursive + Gen.choice + [ TypeVar <$> genUniqueTypeVar + , Scalar <$> Gen.enumBounded + ] + [ TypeConstructor <$> Gen.element (qualifiedTest <$> typeConstructorNames) <*> Gen.list (Range.linear 0 2) genMonotype + , Function <$> genMonotype <*> genMonotype + ] diff --git a/test/Infer.hs b/test/Infer.hs index 3d6f2ab..67ef51f 100644 --- a/test/Infer.hs +++ b/test/Infer.hs @@ -16,6 +16,7 @@ import Elara.TypeInfer.Type import Hedgehog (Property, assert, evalEither, evalEitherM, evalIO, failure, forAll, property) import Hedgehog.Gen qualified as Gen import Hedgehog.Range qualified as Range +import Infer.Unify qualified as Unify import Optics.Operators.Unsafe ((^?!)) import Polysemy (Sem, run, runM, subsume, subsume_) import Polysemy.Error (runError) @@ -30,11 +31,12 @@ import Test.Syd.Hedgehog () import Prelude hiding (fail) spec :: Spec -spec = describe "Infers types correctly" $ parallel $ do +spec = describe "Infers types correctly" $ do literalTests lambdaTests it "infers literals" prop_literalTypesInvariants + Unify.spec -- Literal Type Inference Tests literalTests :: Spec diff --git a/test/Infer/Unify.hs b/test/Infer/Unify.hs new file mode 100644 index 0000000..ae6c0bc --- /dev/null +++ b/test/Infer/Unify.hs @@ -0,0 +1,62 @@ +module Infer.Unify where + +import Arbitrary.Type (genMonotype, genUniqueTypeVar) +import Elara.TypeInfer.ConstraintGeneration +import Elara.TypeInfer.Type +import Hedgehog (Gen, Property, evalEither, forAll, property, (===)) +import Hedgehog.Gen qualified as Gen +import Hedgehog.Range qualified as Range +import Polysemy +import Polysemy.Error +import Test.Syd +import Test.Syd.Hedgehog () + +spec :: Spec +spec = describe "Type unification" $ do + it "unifies type variables" prop_unify_type_vars + it "unifies scalars" prop_unify_scalars + it "unifies functions" prop_unify_functions + it "unifies self" prop_unify_self + it "fails to unify mismatched types" prop_unify_failure + +runUnify :: + Sem '[Error UnifyError] (Substitution loc, Constraint loc) -> + Either UnifyError (Substitution loc, Constraint loc) +runUnify = run . runError + +prop_unify_type_vars :: Property +prop_unify_type_vars = property $ do + a <- forAll $ genUniqueTypeVar + let typeVar = TypeVar a + (sub, _) <- evalEither $ runUnify $ unify typeVar typeVar + sub === Substitution [] + +prop_unify_scalars :: Property +prop_unify_scalars = property $ do + a <- forAll $ Gen.enumBounded + let scalarType = Scalar a + (sub, _) <- evalEither $ runUnify $ unify scalarType scalarType + sub === Substitution [] + +prop_unify_self :: Property +prop_unify_self = property $ do + a <- forAll genMonotype + (sub, _) <- evalEither $ runUnify $ unify a a + sub === Substitution [] + +prop_unify_functions :: Property +prop_unify_functions = property $ do + a <- forAll genMonotype + b <- forAll genMonotype + (sub, _) <- evalEither $ runUnify $ unify (Function a b) (Function a b) + sub === Substitution [] + +-- Hedgehog property: Check unification failure for mismatched types +prop_unify_failure :: Property +prop_unify_failure = property $ do + a <- forAll genMonotype + b <- forAll genMonotype + -- let's come back to this later + -- let result = runUnify $ unify a b + -- result === Left (UnificationFailed $ "Unification failed: " <> show a <> " and " <> show b) + guard $ a /= b