Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: omelkonian/setup-agda@v2.3
- uses: omelkonian/setup-agda@main
with:
agda-version: 2.8.0
stdlib-version: 2.3
Expand Down
191 changes: 144 additions & 47 deletions Tactic/Derive.agda
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
-- checker. Writing an actual derivation strategy then does not
-- require dealing with any mutual recursion, it is all handled here.
--
-- TODO: This breaks with:
-- - mutual recursion that nests too deep (i.e. deeper than 1, e.g. Term)
-- - indexed datatypes that require absurd clauses (e.g. Vec)
-- TODO: This breaks with indexed datatypes that require absurd clauses (e.g. Vec)
--
-- TODO: This is very slow for mutual recursion that nests too deep (e.g. Term)
--
-- TODO: support type classes with more than one field
--
Expand All @@ -27,7 +27,6 @@ import Data.Bool.ListAction as L
import Data.List as L hiding (any)
import Data.List.NonEmpty as NE
import Data.String as S
open import Data.Maybe using (fromMaybe)
open import Reflection.Tactic
open import Reflection.Utils
open import Reflection.Utils.TCI
Expand All @@ -44,13 +43,32 @@ open import Class.Traversable

instance
_ = ContextMonad-MonadTC
_ = Functor-M
_ = Functor-M {M = TC}
_ = Show-List
_ = DecEq-×

open ClauseExprM

-- A wrapper chain `[n₁ , n₂ , … , nₖ]` represents the applied type
-- `n₁ (n₂ (… (nₖ dName)))`.
WrapperChain : Set
WrapperChain = List Name

applyChain : WrapperChain → Term → Term
applyChain [] t = t
applyChain (n ∷ ns) t = n ∙⟦ applyChain ns t ⟧

-- Does the a path of the syntax tree of the term match the given
-- wrapper chain applied to `dName`?
matchesChain : WrapperChain → Name → Term → Bool
matchesChain [] dName (def n _) = ⌊ dName ≟ n ⌋
matchesChain [] _ _ = false
matchesChain (n ∷ ns) dName (def n' args) =
⌊ n ≟ n' ⌋ ∧ L.any (λ where (arg _ t) → matchesChain ns dName t) args
matchesChain _ _ _ = false

-- generate the type of the `className dName` instance
genClassType : ℕ → Name → Maybe Name → TC Type
genClassType : ℕ → Name → Maybe WrapperChain → TC Type
genClassType arity dName wName = do
params ← getParamsAndIndices dName
let params' = L.map (λ where (abs x y) → abs x (hide y)) $ take (length params ∸ arity) params
Expand Down Expand Up @@ -86,54 +104,133 @@ genClassType arity dName wName = do
x' ← (abs "_" ∘ iArg) <$> (genSortInstance k k i)
(x' ∷_) <$> (extendContext ("", hArg unknown) $ genSortInstanceWithCtx xs)

modifyClassType : Maybe Name → TypeView → Type
modifyClassType nothing (tel , ty) = tyView (tel , className ∙⟦ ty ⟧)
modifyClassType (just n) (tel , ty) = tyView (tel , className ∙⟦ n ∙⟦ ty ⟧ ⟧)

lookupName : List (Name × Name) → Name → Maybe Name
lookupName = lookupᵇ (λ n n' → ⌊ n ≟ n' ⌋)

-- Look at the constructors of the argument and return all types that
-- recursively contain it. This isn't very clever right now.
genMutualHelpers : Name → TC (List Name)
genMutualHelpers n = do
tys ← L.map (unArg ∘ unAbs) <$> (L.concatMap (proj₁ ∘ viewTy ∘ proj₂) <$> getConstrs n)
return $ deduplicate _≟_ $ L.mapMaybe helper tys
where
helper : Type → Maybe Name
helper (def n' args) =
if L.any (λ where (arg _ (def n'' _)) → ⌊ n ≟ n'' ⌋ ; _ → false) args
then just n' else nothing
helper _ = nothing

module _ (arity : ℕ) (genCe : (Name → Maybe Name) → List SinglePattern → List (NE.List⁺ SinglePattern × TC (ClauseExpr ⊎ Maybe Term))) where
modifyClassType : Maybe WrapperChain → TypeView → Type
modifyClassType nothing (tel , ty) = tyView (tel , className ∙⟦ ty ⟧)
modifyClassType (just ns) (tel , ty) = tyView (tel , className ∙⟦ applyChain ns ty ⟧)

-- Entry in the translation table: ((chain , target data name) , instance name).
-- The chain applied to the target datatype gives the concrete type whose
-- instance we resolved to `instance name`. An empty chain represents a
-- user-declared base instance for the target itself.
TransEntry : Set
TransEntry = (WrapperChain × Name) × Name

lookupByTerm : List TransEntry → Term → Maybe Name
lookupByTerm l ty = proj₂ <$> L.findᵇ (λ where ((c , t) , _) → matchesChain c t ty) l

-- allChainsTo:
-- Look at the constructors of the argument and return all non-empty
-- wrapper chains `[n₁ , … , nₖ]` together with the mutual-peer name `t`
-- they terminate at, such that some constructor field has a subterm
-- of the form `n₁ (n₂ (… (nₖ t)))`.
private module AllChainsTo (ns : List Name) where
nameInSeeds : Name → Bool
nameInSeeds n' = L.any (_≡ᵇ n') ns

-- All chains from the head of `t` down to some position equal to
-- some seed in `ns`.
mutual
chainsTo : Term → List (WrapperChain × Name)
chainsTo (def n' args) with nameInSeeds n'
... | true = [ ([] , n') ]
... | false = L.map (λ where (c , t) → (n' ∷ c , t)) (chainsToArgs args)
chainsTo _ = []

chainsToArgs : List (Arg Term) → List (WrapperChain × Name)
chainsToArgs [] = []
chainsToArgs (arg _ t ∷ rest) = chainsTo t ++ chainsToArgs rest

-- `chainsTo` applied to the term and, recursively, to every
-- argument position, so nested sub-chains are also collected.
-- Bare references to seeds themselves are skipped (no helper needed).
mutual
allChainsTo : Term → List (WrapperChain × Name)
allChainsTo t@(def n' args) with nameInSeeds n'
... | true = []
... | false = chainsTo t ++ allChainsToArgs args
allChainsTo _ = []

allChainsToArgs : List (Arg Term) → List (WrapperChain × Name)
allChainsToArgs [] = []
allChainsToArgs (arg _ t ∷ rest) = allChainsTo t ++ allChainsToArgs rest

open AllChainsTo using (allChainsTo)

-- Collect all non-empty wrapper chains (tagged with their terminating
-- seed) discovered in the constructors arguments of the constructors
-- of any seed.
genMutualHelpers : List Name → TC (List (WrapperChain × Name))
genMutualHelpers ns = do
tysPerSeed ← traverse
(λ n → L.map (unArg ∘ unAbs) <$> (L.concatMap (proj₁ ∘ viewTy ∘ proj₂) <$> getConstrs n)) ns
return $ deduplicate _≟_ $ L.concatMap (allChainsTo ns) $ concat tysPerSeed

module _ (arity : ℕ) (genCe : (Term → Maybe Name) → List SinglePattern → List (NE.List⁺ SinglePattern × TC (ClauseExpr ⊎ Maybe Term))) where
-- Generate the declaration & definition of a particular derivation.
--
-- Takes a dictionary (for mutual recursion), a wrapper (also for
-- mutual recursion), the name of the original type we want to derive
-- Show for and the name we want to define Show originally at.
deriveSingle : List (Name × Name) → Name → Name → Maybe Name → TC (Arg Name × Type × List Clause)
deriveSingle transName dName iName wName = inDebugPath "DeriveSingle" do
debugLog ("For: " ∷ᵈ dName ∷ᵈ [])
goalTy ← genClassType arity dName wName
ps ← constructorPatterns' (fromMaybe dName wName ∙)
-- `tgtData` is the target datatype and `wName` is an optional wrapper
-- chain; the resulting instance has type
-- `className (applyChain wName tgtData)` and is declared at `iName`.
-- Base instances (wName ≡ nothing) are declared with `iArg` so they
-- participate in instance search; helpers (wName ≡ just _) are
-- declared with `vArg` and referenced explicitly via `transName`.

deriveSingle : List TransEntry → Name → Name → Maybe WrapperChain → TC (Arg Name × Type × List Clause)
deriveSingle transName tgtData iName wChain = inDebugPath "DeriveSingle" do
debugLog ("For: " ∷ᵈ tgtData ∷ᵈ [])
-- e.g. `⦃ DecEq A ⦄ → DecEq (List (Arg A))` for chain [List,Arg]
goalTy ← genClassType arity tgtData wChain
-- we only ever have to pattern-match on the outermost patterns
-- since we call other instances directly (i.e. we wouldn't match
-- on `Arg` in the above example, but rather call another helper
-- `⦃ DecEq A ⦄ → DecEq (Arg A)`)
let outerName = maybe (λ where [] → tgtData ; (x ∷ _) → x) tgtData wChain
ps ← constructorPatterns' (outerName ∙)
-- TODO: find a good way of printing this
--debugLogᵐ ("Constrs: " ∷ᵈᵐ ps ᵛⁿ ∷ᵈᵐ []ᵐ)
cs ← local (λ c → record c { goal = inj₂ goalTy }) $
singleMatchExpr ([] , iArg (Pattern.proj projName)) $ contMatch $ multiMatchExprM $
genCe (lookupName transName) ps
let defName = maybe (maybe vArg (iArg iName) ∘ lookupName transName) (iArg iName) wName
genCe (lookupByTerm transName) ps
-- only names declared by the user should participate in instance search
let defName = maybe (λ _ → vArg iName) (iArg iName) wChain
return (defName , goalTy , clauseExprToClauses cs)

deriveMulti : Name × Name × List Name → TC (List (Arg Name × Type × List Clause))
deriveMulti (dName , iName , hClasses) = do
hClassNames ← traverse ⦃ Functor-List ⦄
(λ cn → freshName (showName className S.++ "-" S.++ showName cn S.++ showName dName)) hClasses
traverse ⦃ Functor-List ⦄ (deriveSingle (L.zip hClasses hClassNames) dName iName) (nothing ∷ L.map just hClasses)
-- Derive all instances for a group of mutually-defined seeds at once,
-- sharing a single translation table that covers both the user-named
-- base instances and the fresh wrapper-chain helpers.
deriveGroup : List (Name × Name) → TC (List (Arg Name × Type × List Clause))
deriveGroup seeds = do
-- discover all wrapper-chain helpers needed by any seed's constructors,
-- e.g. ([List,Arg],Term) when Term has a `List (Arg Term)` field
helpers ← genMutualHelpers $ L.map proj₁ seeds
-- generate fresh, human-readable names for the helpers
helperNames ← traverse ⦃ Functor-List ⦄ mkHelperName helpers
let helperTable : List TransEntry
helperTable = L.zip helpers helperNames

-- seeds get empty-chain entries so cross-seed references are
-- resolved explicitly rather than relying on instance search
seedTable : List TransEntry
seedTable = L.map (λ (s , i) → (([] , s) , i)) seeds

-- all derivations share a single table so every mutual peer and helper is in scope
open D (seedTable ++ helperTable)
baseResults ← traverse deriveBase seeds
helperResults ← traverse deriveHelper helperTable
return (baseResults ++ helperResults)
where
mkHelperName : WrapperChain × Name → TC Name
mkHelperName (chain , tgt) = freshName
(showName className
S.++ L.foldr (λ n s → "-" S.++ showName n S.++ s) "" chain
S.++ "-" S.++ showName tgt)

module D (transName : List (TransEntry)) where
deriveBase : Name × Name → TC (Arg Name × Type × List Clause)
deriveBase (s , i) = deriveSingle transName s i nothing

deriveHelper : TransEntry → TC (Arg Name × Type × List Clause)
deriveHelper ((chain , tgt) , n) = deriveSingle transName tgt n (just chain)

derive-Class : ⦃ TCOptions ⦄ → List (Name × Name) → UnquoteDecl
derive-Class l = initUnquoteWithGoal (className ∙) $
declareAndDefineFuns =<< runAndReset (concat <$> traverse ⦃ Functor-List ⦄ helper l)
where
helper : Name × Name → TC (List (Arg Name × Type × List Clause))
helper (a , b) = do hs ← genMutualHelpers a ; deriveMulti (a , b , hs)
declareAndDefineFuns =<< runAndReset (deriveGroup l)
8 changes: 6 additions & 2 deletions Tactic/Derive/DecEq.agda
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,10 @@ private
pattern ``yes' x = quote _because_ ◇⟦ quote true ◇ ∣ x ⟧
pattern ``no' x = quote _because_ ◇⟦ quote false ◇ ∣ x ⟧

module _ (transName : Name → Maybe Name) where
module _ (transName : Term → Maybe Name) where

eqFromTerm : Term → Term → Term → Term
eqFromTerm (def n _) t t' with transName n
eqFromTerm ty@(def _ _) t t' with transName ty
... | just n' = def (quote _≟_) (iArg (n' ∙) ∷ vArg t ∷ vArg t' ∷ [])
... | nothing = quote _≟_ ∙⟦ t ∣ t' ⟧
eqFromTerm _ t t' = quote _≟_ ∙⟦ t ∣ t' ⟧
Expand Down Expand Up @@ -149,4 +149,8 @@ private

unquoteDecl DecEq-M₁ DecEq-M₂ = derive-DecEq $ (quote M₁ , DecEq-M₁) ∷ (quote M₂ , DecEq-M₂) ∷ []

unquoteDecl DecEq-E5 = derive-DecEq [ (quote E5 , DecEq-E5) ]

unquoteDecl DecEq-N₁ DecEq-N₂ = derive-DecEq $ (quote N₁ , DecEq-N₁) ∷ (quote N₂ , DecEq-N₂) ∷ []

-- Expected: DecEq-Term DecEq-Product
8 changes: 6 additions & 2 deletions Tactic/Derive/Show.agda
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@ wrapWithPars s = "(" ++S s ++S ")"
genPars : Term → Term
genPars t = quote wrapWithPars ∙⟦ t ⟧

module _ (transName : Name → Maybe Name) where
module _ (transName : Term → Maybe Name) where
showFromTerm : Term → Term → Term
showFromTerm (def n _) t with transName n
showFromTerm ty@(def _ _) t with transName ty
... | just n' = def (quote show) (iArg (n' ∙) ∷ vArg t ∷ [])
... | nothing = quote show ∙⟦ t ⟧
showFromTerm _ t = quote show ∙⟦ t ⟧
Expand Down Expand Up @@ -92,4 +92,8 @@ private

unquoteDecl Show-M₁ Show-M₂ = derive-Show $ (quote M₁ , Show-M₁) ∷ (quote M₂ , Show-M₂) ∷ []

unquoteDecl Show-E5 = derive-Show [ (quote E5 , Show-E5) ]

unquoteDecl Show-N₁ Show-N₂ = derive-Show $ (quote N₁ , Show-N₁) ∷ (quote N₂ , Show-N₂) ∷ []

-- Expected: Show-Product Show-Term
19 changes: 18 additions & 1 deletion Tactic/Derive/TestTypes.agda
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ data E4 : {n : ℕ} → Fin n → Set where
c1E4 : ∀ {k} → E4 {suc k} zero
c2E4 : ∀ {k} {l} → E4 {suc k} (suc l)

data E5 : Set where
c1E5 : List (Maybe E5) → E5
c2E5 : E5

record R1 : Set where
field f1R1 : E1
f2R1 : E2 ℕ
Expand All @@ -54,8 +58,21 @@ data M₂ where
m₂ : M₂
m₁→₂ : M₁ → M₂

-- Like M₁/M₂ but N₁ wraps N₂ in a List.
-- This exercises cross-seed wrapper chains: when deriving for both seeds
-- together the helper `DecEq-List-N₂` (or `Show-List-N₂`) lands in the
-- same mutual group, making termination visible to Agda.
data N₁ : Set
data N₂ : Set
data N₁ where
n₁ : N₁
n₂→₁ : List N₂ → N₁
data N₂ where
n₂ : N₂
n₁→₂ : N₁ → N₂

AllTestTypes : List Name
AllTestTypes = quote E0 ∷ quote E1 ∷ quote E2 ∷ quote E3 ∷ quote R1 ∷ quote R2 ∷ quote M₁ ∷ quote M₂ ∷ []
AllTestTypes = quote E0 ∷ quote E1 ∷ quote E2 ∷ quote E3 ∷ quote R1 ∷ quote R2 ∷ quote M₁ ∷ quote M₂ ∷ quote E5 ∷ []

open import Data.Bool using (Bool) public
open import Data.Char using (Char) public
Expand Down
Loading