Skip to content

Commit

Permalink
unification of types
Browse files Browse the repository at this point in the history
  • Loading branch information
bristermitten committed Nov 23, 2024
1 parent 1601b02 commit 0003f18
Show file tree
Hide file tree
Showing 9 changed files with 180 additions and 18 deletions.
3 changes: 1 addition & 2 deletions src/Elara/AST/Generic/Instances/StripLocation.hs
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,7 @@ instance
forall (ast1 :: LocatedAST) (ast2 :: UnlocatedAST).
( ASTLocate' ast1 ~ Located
, ASTLocate' ast2 ~ Unlocated
, ( StripLocation (Select "Infixed" ast1) (Select "Infixed" ast2)
)
, (StripLocation (Select "Infixed" ast1) (Select "Infixed" ast2))
, ( StripLocation
(CleanupLocated (Located (Select "SymOp" ast1)))
(Select "SymOp" ast2)
Expand Down
3 changes: 1 addition & 2 deletions src/Elara/AST/Pretty.hs
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,7 @@ prettyLambdaExpr args body = parens (if ?contextFree then prettyCTFLambdaExpr el

long =
align
( "\\" <+> hsep (pretty <$> args) <+> "->" <> hardline <> nest indentDepth (pretty body)
)
("\\" <+> hsep (pretty <$> args) <+> "->" <> hardline <> nest indentDepth (pretty body))

prettyFunctionCall :: (?contextFree :: Bool, Pretty a, Pretty b) => a -> b -> Doc AnsiStyle
prettyFunctionCall e1' e2' = parens (if ?contextFree then short else group (flatAlt long short))
Expand Down
3 changes: 1 addition & 2 deletions src/Elara/Parse/Declaration.hs
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,7 @@ defDec modName = fmapLocated Declaration $ do
( Declaration'
modName
name
( DeclarationBody declBody
)
(DeclarationBody declBody)
)

letDec :: Located ModuleName -> Parser FrontendDeclaration
Expand Down
67 changes: 60 additions & 7 deletions src/Elara/TypeInfer/ConstraintGeneration.hs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ import Elara.AST.Typed (TypedExpr, TypedExpr')
import Elara.AST.VarRef
import Elara.Data.Unique (UniqueGen)
import Elara.TypeInfer.Environment (InferError, LocalTypeEnvironment, TypeEnvKey (TermVarKey), TypeEnvironment, addType, lookupLocalVar, lookupLocalVarType, lookupType, withLocalType)
import Elara.TypeInfer.Type (AxiomScheme, Constraint (..), Monotype (..), Scalar (..), Substitutable (..), Substitution, Type (Forall))
import Elara.TypeInfer.Ftv (occurs)
import Elara.TypeInfer.Type (AxiomScheme, Constraint (..), Monotype (..), Scalar (..), Substitutable (..), Substitution (..), Type (Forall))
import Elara.TypeInfer.Unique (makeUniqueTyVar)
import Polysemy
import Polysemy.Error
Expand Down Expand Up @@ -95,9 +96,61 @@ tidyConstraint (Conjunction x (Conjunction y z)) | x == y = Conjunction (tidyCon
tidyConstraint (Conjunction c1 c2) = Conjunction (tidyConstraint c1) (tidyConstraint c2)
tidyConstraint (Equality x y) = Equality x y

-- unifyConstraints :: [Constraint loc] ->

-- simplifyConstraints :: AxiomScheme loc -> Constraint loc -> Constraint loc -> (Constraint loc, Substitution loc)
-- simplifyConstraints axioms given wanted = (residual, subst)
-- where
-- (residual, subst) = todo
solveConstraints :: AxiomScheme loc -> Constraint loc -> Constraint loc -> Sem '[Error UnifyError] (Constraint loc, Substitution loc)
solveConstraints axioms given wanted = do
(substitution, simplifiedWanted) <- unifyEquality wanted

-- let simplifiedGiven = substitute substitution given

-- let simplifiedWanted' = substitute substitution simplifiedWanted

-- let (entailable, residualWanted) = entail axioms simplifiedGiven simplifiedWanted'

pure (todo, substitution)

unifyEquality :: Constraint loc -> Sem '[Error UnifyError] (Substitution loc, Constraint loc)
unifyEquality (Equality a b) = unify a b
unifyEquality c = pure (mempty, c)

unify :: HasCallStack => Monotype loc -> Monotype loc -> Sem '[Error UnifyError] (Substitution loc, Constraint loc)
unify (TypeVar a) (TypeVar b) | a == b = pure (mempty, EmptyConstraint)
unify (TypeVar a) b =
if occurs a b
then throw OccursCheckFailed
else pure (Substitution [(a, b)], EmptyConstraint)
unify a (TypeVar b) = unify (TypeVar b) a -- swap to avoid duplication
unify (Scalar a) (Scalar b) =
if a == b
then pure (mempty, EmptyConstraint)
else throw ScalarMismatch
unify (TypeConstructor a as) (TypeConstructor b bs)
| a /= b = throw TypeConstructorMismatch
| length as /= length bs = throw ArityMismatch
| otherwise = unifyMany as bs
unify (Function a b) (Function c d) = do
(s1, c1) <- unify a c
(s2, c2) <- unify (substituteAll s1 b) (substituteAll s1 d)
pure (s1 <> s2, c1 <> c2)
unify a b = throw $ UnificationFailed $ "Unification failed: " <> show a <> " and " <> show b

unifyMany ::
HasCallStack =>
[Monotype loc] ->
[Monotype loc] ->
Sem
'[Error UnifyError]
(Substitution loc, Constraint loc)
unifyMany [] _ = pure (mempty, EmptyConstraint)
unifyMany _ [] = pure (mempty, EmptyConstraint)
unifyMany (a : as) (b : bs) = do
(s1, c1) <- unify a b
(s2, c2) <- unifyMany (fmap (substituteAll s1) as) (fmap (substituteAll s1) bs)
pure (s1 <> s2, c1 <> c2)

data UnifyError
= OccursCheckFailed
| ScalarMismatch
| TypeConstructorMismatch
| ArityMismatch
| UnificationFailed String
deriving (Eq, Show)
20 changes: 20 additions & 0 deletions src/Elara/TypeInfer/Ftv.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
module Elara.TypeInfer.Ftv where

import Data.Set (difference, member)
import Elara.TypeInfer.Type (Monotype (..), Type (..))
import Elara.TypeInfer.Unique

class Ftv a where
ftv :: a -> Set UniqueTyVar

instance Ftv (Monotype loc) where
ftv (TypeVar tv) = one tv
ftv (Scalar _) = mempty
ftv (TypeConstructor _ ts) = foldMap ftv ts
ftv (Function t1 t2) = ftv t1 <> ftv t2

instance Ftv (Type loc) where
ftv (Forall tv _ t) = ftv t `difference` one tv

occurs :: Ftv a => UniqueTyVar -> a -> Bool
occurs tv a = tv `member` ftv a
8 changes: 4 additions & 4 deletions src/Elara/TypeInfer/Type.hs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@ module Elara.TypeInfer.Type where

import Data.Kind qualified as Kind
import Elara.AST.Name
import Elara.AST.VarRef (UnlocatedVarRef, VarRef)
import Elara.TypeInfer.Unique
import Prelude hiding (Constraint)

-- | A type scheme σ
data Type loc
Expand Down Expand Up @@ -67,11 +65,13 @@ data Scalar
| ScalarString
| ScalarChar
| ScalarUnit
deriving (Generic, Show, Eq, Ord)
deriving (Generic, Show, Eq, Ord, Enum, Bounded)

type DataCon = Qualified TypeName

data Substitution loc = Substitution [(UniqueTyVar, Monotype loc)]
newtype Substitution loc = Substitution [(UniqueTyVar, Monotype loc)]
deriving newtype (Semigroup, Monoid)
deriving stock (Eq, Show)

class Substitutable (a :: k -> Kind.Type) where
substitute :: UniqueTyVar -> Monotype loc -> a loc -> a loc
Expand Down
28 changes: 28 additions & 0 deletions test/Arbitrary/Type.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
module Arbitrary.Type where

import Elara.AST.Name
import Elara.Data.Unique (unsafeMkUnique)
import Elara.TypeInfer.Type
import Elara.TypeInfer.Unique
import Hedgehog (Gen)
import Hedgehog.Gen qualified as Gen
import Hedgehog.Range qualified as Range
import Region (qualifiedTest)

-- | contrary to what the name suggests, this will NOT be unique :)
genUniqueTypeVar :: Gen UniqueTyVar
genUniqueTypeVar = unsafeMkUnique Nothing <$> Gen.integral (Range.linear 0 100)

typeConstructorNames :: [TypeName]
typeConstructorNames = ["List", "Maybe", "Pair", "Box", "IO"]

genMonotype :: Gen (Monotype loc)
genMonotype =
Gen.recursive
Gen.choice
[ TypeVar <$> genUniqueTypeVar
, Scalar <$> Gen.enumBounded
]
[ TypeConstructor <$> Gen.element (qualifiedTest <$> typeConstructorNames) <*> Gen.list (Range.linear 0 2) genMonotype
, Function <$> genMonotype <*> genMonotype
]
4 changes: 3 additions & 1 deletion test/Infer.hs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import Elara.TypeInfer.Type
import Hedgehog (Property, assert, evalEither, evalEitherM, evalIO, failure, forAll, property)
import Hedgehog.Gen qualified as Gen
import Hedgehog.Range qualified as Range
import Infer.Unify qualified as Unify
import Optics.Operators.Unsafe ((^?!))
import Polysemy (Sem, run, runM, subsume, subsume_)
import Polysemy.Error (runError)
Expand All @@ -30,11 +31,12 @@ import Test.Syd.Hedgehog ()
import Prelude hiding (fail)

spec :: Spec
spec = describe "Infers types correctly" $ parallel $ do
spec = describe "Infers types correctly" $ do
literalTests
lambdaTests

it "infers literals" prop_literalTypesInvariants
Unify.spec

-- Literal Type Inference Tests
literalTests :: Spec
Expand Down
62 changes: 62 additions & 0 deletions test/Infer/Unify.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
module Infer.Unify where

import Arbitrary.Type (genMonotype, genUniqueTypeVar)
import Elara.TypeInfer.ConstraintGeneration
import Elara.TypeInfer.Type
import Hedgehog (Gen, Property, evalEither, forAll, property, (===))
import Hedgehog.Gen qualified as Gen
import Hedgehog.Range qualified as Range
import Polysemy
import Polysemy.Error
import Test.Syd
import Test.Syd.Hedgehog ()

spec :: Spec
spec = describe "Type unification" $ do
it "unifies type variables" prop_unify_type_vars
it "unifies scalars" prop_unify_scalars
it "unifies functions" prop_unify_functions
it "unifies self" prop_unify_self
it "fails to unify mismatched types" prop_unify_failure

runUnify ::
Sem '[Error UnifyError] (Substitution loc, Constraint loc) ->
Either UnifyError (Substitution loc, Constraint loc)
runUnify = run . runError

prop_unify_type_vars :: Property
prop_unify_type_vars = property $ do
a <- forAll $ genUniqueTypeVar
let typeVar = TypeVar a
(sub, _) <- evalEither $ runUnify $ unify typeVar typeVar
sub === Substitution []

prop_unify_scalars :: Property
prop_unify_scalars = property $ do
a <- forAll $ Gen.enumBounded
let scalarType = Scalar a
(sub, _) <- evalEither $ runUnify $ unify scalarType scalarType
sub === Substitution []

prop_unify_self :: Property
prop_unify_self = property $ do
a <- forAll genMonotype
(sub, _) <- evalEither $ runUnify $ unify a a
sub === Substitution []

prop_unify_functions :: Property
prop_unify_functions = property $ do
a <- forAll genMonotype
b <- forAll genMonotype
(sub, _) <- evalEither $ runUnify $ unify (Function a b) (Function a b)
sub === Substitution []

-- Hedgehog property: Check unification failure for mismatched types
prop_unify_failure :: Property
prop_unify_failure = property $ do
a <- forAll genMonotype
b <- forAll genMonotype
-- let's come back to this later
-- let result = runUnify $ unify a b
-- result === Left (UnificationFailed $ "Unification failed: " <> show a <> " and " <> show b)
guard $ a /= b

0 comments on commit 0003f18

Please sign in to comment.