Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify confluence check #1285

Merged
merged 6 commits into from
Jan 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 8 additions & 9 deletions base/src/main/java/org/aya/normalize/Normalizer.java
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
// Copyright (c) 2020-2024 Tesla (Yinsen) Zhang.
// Copyright (c) 2020-2025 Tesla (Yinsen) Zhang.
// Use of this source code is governed by the MIT license that can be found in the LICENSE.md file.
package org.aya.normalize;

import java.util.function.BiFunction;
import java.util.function.UnaryOperator;

import static org.aya.generic.State.Stuck;

import kala.collection.SeqView;
import kala.collection.immutable.ImmutableSeq;
import kala.collection.immutable.ImmutableSet;
Expand All @@ -25,15 +30,9 @@
import org.aya.syntax.ref.LocalVar;
import org.aya.tyck.TyckState;
import org.aya.tyck.tycker.Stateful;
import org.aya.util.error.WithPos;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

import java.util.function.BiFunction;
import java.util.function.UnaryOperator;

import static org.aya.generic.State.Stuck;

/**
* Unlike in pre-v0.30 Aya, we use only one normalizer, only doing head reduction,
* and we merge conservative normalizer and the whnf normalizer.
Expand Down Expand Up @@ -92,8 +91,8 @@ case FnCall(FnDef.Delegate delegate, int ulift, var args) -> {
term = body.instTele(args.view());
continue;
}
case Either.Right(var clauses): {
var result = tryUnfoldClauses(clauses.view().map(WithPos::data),
case Either.Right(var body): {
var result = tryUnfoldClauses(body.matchingsView(),
args, core.is(Modifier.Overlap), ulift);
// we may get stuck
if (result == null) return defaultValue;
Expand Down
17 changes: 8 additions & 9 deletions base/src/main/java/org/aya/primitive/ShapeMatcher.java
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
// Copyright (c) 2020-2024 Tesla (Yinsen) Zhang.
// Copyright (c) 2020-2025 Tesla (Yinsen) Zhang.
// Use of this source code is governed by the MIT license that can be found in the LICENSE.md file.
package org.aya.primitive;

import java.util.Objects;
import java.util.function.BiPredicate;
import java.util.function.BooleanSupplier;
import java.util.function.Function;

import kala.collection.SeqLike;
import kala.collection.immutable.ImmutableMap;
import kala.collection.immutable.ImmutableSeq;
Expand All @@ -24,15 +29,9 @@
import org.aya.util.Pair;
import org.aya.util.RepoLike;
import org.aya.util.error.Panic;
import org.aya.util.error.WithPos;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

import java.util.Objects;
import java.util.function.BiPredicate;
import java.util.function.BooleanSupplier;
import java.util.function.Function;

/**
* @author kiva
*/
Expand Down Expand Up @@ -123,10 +122,10 @@ private boolean matchFn(@NotNull FnShape shape, @NotNull FnDef def) {
return switch (new Pair<>(shape.body(), def.body())) {
case Pair(Either.Left(var termShape), Either.Left(var term)) ->
matchInside(() -> captures.put(shape.name(), def.ref()), () -> matchTerm(termShape, term));
case Pair(Either.Right(var clauseShapes), Either.Right(var clauses)) -> {
case Pair(Either.Right(var clauseShapes), Either.Right(var body)) -> {
var mode = def.is(Modifier.Overlap) ? MatchMode.Sub : MatchMode.Eq;
yield matchInside(() -> captures.put(shape.name(), def.ref()), () ->
matchMany(mode, clauseShapes, clauses.view().map(WithPos::data), this::matchClause));
matchMany(mode, clauseShapes, body.matchingsView(), this::matchClause));
}
default -> false;
};
Expand Down
11 changes: 5 additions & 6 deletions base/src/main/java/org/aya/terck/CallResolver.java
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
// Copyright (c) 2020-2024 Tesla (Yinsen) Zhang.
// Copyright (c) 2020-2025 Tesla (Yinsen) Zhang.
// Use of this source code is governed by the MIT license that can be found in the LICENSE.md file.
package org.aya.terck;

import java.util.function.Consumer;

import kala.collection.Set;
import kala.collection.immutable.ImmutableSeq;
import kala.collection.immutable.ImmutableSet;
Expand All @@ -22,14 +24,11 @@
import org.aya.syntax.core.term.xtt.PAppTerm;
import org.aya.tyck.TyckState;
import org.aya.tyck.tycker.Stateful;
import org.aya.util.error.WithPos;
import org.aya.util.terck.CallGraph;
import org.aya.util.terck.CallMatrix;
import org.aya.util.terck.Relation;
import org.jetbrains.annotations.NotNull;

import java.util.function.Consumer;

/**
* Resolve calls and build call graph of recursive functions,
* after {@link org.aya.tyck.StmtTycker}.
Expand Down Expand Up @@ -139,8 +138,8 @@ private Relation compareConArgs(@NotNull ImmutableSeq<Term> conArgs, @NotNull Pa
}

public void check() {
var clauses = caller.body().getRightValue();
clauses.view().map(WithPos::data).forEach(this);
var clauses = caller.body().getRightValue().matchingsView();
clauses.forEach(this);
}

@Override public void accept(@NotNull Term.Matching matching) {
Expand Down
16 changes: 7 additions & 9 deletions base/src/main/java/org/aya/tyck/ExprTycker.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@
// Use of this source code is governed by the MIT license that can be found in the LICENSE.md file.
package org.aya.tyck;

import kala.collection.AbstractSeqView;
import java.util.Comparator;
import java.util.function.BiFunction;
import java.util.function.Function;

import kala.collection.immutable.ImmutableSeq;
import kala.collection.immutable.ImmutableTreeSeq;
import kala.collection.mutable.MutableList;
Expand Down Expand Up @@ -51,10 +54,6 @@
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

import java.util.Comparator;
import java.util.function.BiFunction;
import java.util.function.Function;

public final class ExprTycker extends AbstractTycker implements Unifiable {
public final @NotNull MutableTreeSet<WithPos<Expr.WithTerm>> withTerms =
MutableTreeSet.create(Comparator.comparing(SourceNode::sourcePos));
Expand Down Expand Up @@ -192,9 +191,7 @@ && whnf(type) instanceof DataCall dataCall
ImmutableSeq.fill(discriminant.size(), i ->
new LocalVar("match" + i, discriminant.get(i).sourcePos(), GenerateKind.Basic.Tyck)),
ImmutableSeq.empty(), clauses);
var wellClauses = clauseTycker.check(exprPos)
.wellTyped()
.map(WithPos::data);
var wellClauses = clauseTycker.check(exprPos).wellTyped().matchingsView();

// Find free occurrences
var usages = new FreeCollector();
Expand All @@ -205,7 +202,8 @@ && whnf(type) instanceof DataCall dataCall
var captures = usages.collected();
var lifted = new Matchy(type.bindTele(wellArgs.size(), captures.view()),
new QName(QPath.fileLevel(fileModule), "match-" + exprPos.lineColumnString()),
wellClauses.map(clause -> clause.update(clause.body().bindTele(clause.bindCount(), captures.view()))));
wellClauses.map(clause -> clause.update(clause.body().bindTele(clause.bindCount(), captures.view())))
.toImmutableSeq());

var wellTerms = wellArgs.map(Jdg::wellTyped);
return new MatchCall(lifted, wellTerms, captures.map(FreeTerm::new));
Expand Down
37 changes: 23 additions & 14 deletions base/src/main/java/org/aya/tyck/StmtTycker.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
// Use of this source code is governed by the MIT license that can be found in the LICENSE.md file.
package org.aya.tyck;

import static org.aya.tyck.tycker.TeleTycker.loadTele;

import kala.collection.immutable.ImmutableSeq;
import kala.control.Either;
import kala.control.Option;
Expand Down Expand Up @@ -44,8 +46,6 @@
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

import static org.aya.tyck.tycker.TeleTycker.loadTele;

public record StmtTycker(
@NotNull SuppressingReporter reporter, @NotNull ModulePath fileModule,
@NotNull ShapeFactory shapeFactory, @NotNull PrimFactory primFactory
Expand Down Expand Up @@ -85,8 +85,6 @@ public void suppress(@NotNull Decl decl) {
case FnDecl fnDecl -> {
var fnRef = fnDecl.ref;
assert fnRef.signature != null;

var factory = FnDef.factory(body -> new FnDef(fnRef, fnDecl.modifiers, body));
var teleVars = fnDecl.telescope.map(Expr.Param::ref);

yield switch (fnDecl.body) {
Expand All @@ -100,7 +98,7 @@ yield switch (fnDecl.body) {
var zonker = new Finalizer.Zonk<>(tycker);
var resultTerm = zonker.zonk(result).bindTele(teleVars.view());
fnRef.signature = fnRef.signature.descent(zonker::zonk);
yield factory.apply(Either.left(resultTerm));
yield new FnDef(fnRef, fnDecl.modifiers, Either.left(resultTerm));
}
case FnBody.BlockBody body -> {
var clauses = body.clauses();
Expand All @@ -119,22 +117,33 @@ yield switch (fnDecl.body) {

var orderIndependent = fnDecl.modifiers.contains(Modifier.Overlap);
FnDef def;
ClauseTycker.TyckResult patResult;
boolean hasLhsError;
FnClauseBody coreBody;
if (orderIndependent) {
// Order-independent.
patResult = clauseTycker.checkNoClassify();
def = factory.apply(Either.right(patResult.wellTyped()));
if (!patResult.hasLhsError()) {
var patResult = clauseTycker.checkNoClassify();
coreBody = new FnClauseBody(patResult.wellTyped());
def = new FnDef(fnRef, fnDecl.modifiers, Either.right(coreBody));
hasLhsError = patResult.hasLhsError();
if (!hasLhsError) {
var rawParams = signature.params();
var confluence = new YouTrack(rawParams, tycker, fnDecl.sourcePos());
confluence.check(patResult, signature.result(),
PatClassifier.classify(patResult.clauses().view(), rawParams.view(), tycker, fnDecl.sourcePos()));
var classes = PatClassifier.classify(patResult.clauses().view(),
rawParams.view(), tycker, fnDecl.sourcePos());
var absurds = patResult.absurdPrefixCount();
coreBody.classes = classes.map(cls -> cls.ignoreAbsurd(absurds));
confluence.check(coreBody, signature.result());
}
} else {
patResult = clauseTycker.check(fnDecl.entireSourcePos());
def = factory.apply(Either.right(patResult.wellTyped()));
var patResult = clauseTycker.check(fnDecl.entireSourcePos());
coreBody = patResult.wellTyped();
hasLhsError = patResult.hasLhsError();
def = new FnDef(fnRef, fnDecl.modifiers, Either.right(coreBody));
}
if (!hasLhsError) {
var hitConflChecker = new IApplyConfl(def, tycker, fnDecl.sourcePos());
hitConflChecker.check();
}
if (!patResult.hasLhsError()) new IApplyConfl(def, tycker, fnDecl.sourcePos()).check();
yield def;
}
};
Expand Down
52 changes: 41 additions & 11 deletions base/src/main/java/org/aya/tyck/pat/ClauseTycker.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,21 @@
// Use of this source code is governed by the MIT license that can be found in the LICENSE.md file.
package org.aya.tyck.pat;

import java.util.function.Supplier;
import java.util.function.UnaryOperator;

import kala.collection.SeqView;
import kala.collection.immutable.ImmutableSeq;
import kala.collection.immutable.primitive.ImmutableIntArray;
import kala.collection.immutable.primitive.ImmutableIntSeq;
import kala.value.primitive.MutableBooleanValue;
import org.aya.generic.Renamer;
import org.aya.normalize.Finalizer;
import org.aya.prettier.AyaPrettierOptions;
import org.aya.syntax.concrete.Expr;
import org.aya.syntax.concrete.Pattern;
import org.aya.syntax.core.Jdg;
import org.aya.syntax.core.def.FnClauseBody;
import org.aya.syntax.core.pat.Pat;
import org.aya.syntax.core.pat.PatToTerm;
import org.aya.syntax.core.pat.TypeEraser;
Expand All @@ -31,11 +37,10 @@
import org.aya.util.error.SourcePos;
import org.aya.util.error.WithPos;
import org.aya.util.reporter.Reporter;
import org.aya.util.tyck.pat.PatClass;
import org.jetbrains.annotations.Contract;
import org.jetbrains.annotations.NotNull;

import java.util.function.Supplier;
import java.util.function.UnaryOperator;
import org.jetbrains.annotations.Nullable;

public final class ClauseTycker implements Problematic, Stateful {
private final @NotNull ExprTycker exprTycker;
Expand All @@ -46,6 +51,18 @@ public record TyckResult(@NotNull ImmutableSeq<Pat.Preclause<Term>> clauses, boo
public @NotNull ImmutableSeq<WithPos<Term.Matching>> wellTyped() {
return clauses.flatMap(Pat.Preclause::lift);
}
/// @return null if there is no absurd pattern
public @Nullable ImmutableIntSeq absurdPrefixCount() {
var ints = new int[clauses.size()];
var count = 0;
for (int i = 0; i < clauses.size(); i++) {
var clause = clauses.get(i);
if (clause.expr() == null) count++;
ints[i] = count;
}
if (count == 0) return null;
return ImmutableIntArray.Unsafe.wrap(ints);
}
}

/**
Expand Down Expand Up @@ -89,6 +106,7 @@ public void addLocalLet(@NotNull ImmutableSeq<LocalVar> teleBinds, @NotNull Expr
}
}

public record WorkerResult(FnClauseBody wellTyped, boolean hasLhsError) { }
public record Worker(
@NotNull ClauseTycker parent,
@NotNull ImmutableSeq<Param> telescope,
Expand All @@ -97,20 +115,31 @@ public record Worker(
@NotNull ImmutableSeq<LocalVar> elims,
@NotNull ImmutableSeq<Pattern.Clause> clauses
) {
public @NotNull TyckResult check(@NotNull SourcePos overallPos) {
public @NotNull WorkerResult check(@NotNull SourcePos overallPos) {
var lhs = checkAllLhs();

if (lhs.noneMatch(r -> r.hasError)) {
var classes = PatClassifier.classify(
ImmutableSeq<PatClass<ImmutableSeq<Term>>> classes;
var hasError = lhs.anyMatch(LhsResult::hasError);
if (!hasError) {
classes = PatClassifier.classify(
lhs.view().map(LhsResult::clause),
telescope.view().concat(unpi.params()), parent.exprTycker, overallPos);
if (clauses.isNotEmpty()) {
var usages = PatClassifier.firstMatchDomination(clauses, parent, classes);
// refinePatterns(lhs, usages, classes);
}
} else {
classes = null;
}

return parent.checkAllRhs(teleVars, lhs.map(cl -> cl.mapPats(new TypeEraser())));
var map = lhs.map(cl -> cl.mapPats(new TypeEraser()));
var rhs = parent.checkAllRhs(teleVars, map, hasError);
var wellTyped = new FnClauseBody(rhs.wellTyped());
if (classes != null) {
var absurds = rhs.absurdPrefixCount();
wellTyped.classes = classes.map(cl -> cl.ignoreAbsurd(absurds));
}
return new WorkerResult(wellTyped, hasError);
}

public @NotNull ImmutableSeq<LhsResult> checkAllLhs() {
Expand All @@ -120,7 +149,8 @@ public record Worker(
}

public @NotNull TyckResult checkNoClassify() {
return parent.checkAllRhs(teleVars, checkAllLhs());
var lhsResults = checkAllLhs();
return parent.checkAllRhs(teleVars, lhsResults, lhsResults.anyMatch(LhsResult::hasError));
}
}

Expand All @@ -138,9 +168,9 @@ public record Worker(

public @NotNull TyckResult checkAllRhs(
@NotNull ImmutableSeq<LocalVar> vars,
@NotNull ImmutableSeq<LhsResult> lhsResults
@NotNull ImmutableSeq<LhsResult> lhsResults,
boolean lhsError
) {
var lhsError = lhsResults.anyMatch(LhsResult::hasError);
var rhsResult = lhsResults.map(x -> checkRhs(vars, x));

// inline terms in rhsResult
Expand Down Expand Up @@ -188,7 +218,7 @@ public record Worker(
// It would be nice if we have a SourcePos here
new Pat.Bind(new LocalVar("unpi" + idx, SourcePos.NONE, GenerateKind.Basic.Tyck),
x.type()));

var wellTypedPats = patResult.wellTyped().appendedAll(missingPats);
var patWithTypeBound = Pat.collectVariables(wellTypedPats.view());

Expand Down
Loading
Loading