Skip to content

Commit

Permalink
merge: Simplify confluence check (#1285)
Browse files Browse the repository at this point in the history
This PR simplifies confluence check by using a prefix count array data
structure for the clause index in MCTs.

It also stores MCTs in the core, which will be helpful later.
  • Loading branch information
ice1000 authored Jan 25, 2025
2 parents 32ac332 + cb5abcb commit 0830357
Show file tree
Hide file tree
Showing 13 changed files with 164 additions and 123 deletions.
17 changes: 8 additions & 9 deletions base/src/main/java/org/aya/normalize/Normalizer.java
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
// Copyright (c) 2020-2024 Tesla (Yinsen) Zhang.
// Copyright (c) 2020-2025 Tesla (Yinsen) Zhang.
// Use of this source code is governed by the MIT license that can be found in the LICENSE.md file.
package org.aya.normalize;

import java.util.function.BiFunction;
import java.util.function.UnaryOperator;

import static org.aya.generic.State.Stuck;

import kala.collection.SeqView;
import kala.collection.immutable.ImmutableSeq;
import kala.collection.immutable.ImmutableSet;
Expand All @@ -25,15 +30,9 @@
import org.aya.syntax.ref.LocalVar;
import org.aya.tyck.TyckState;
import org.aya.tyck.tycker.Stateful;
import org.aya.util.error.WithPos;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

import java.util.function.BiFunction;
import java.util.function.UnaryOperator;

import static org.aya.generic.State.Stuck;

/**
* Unlike in pre-v0.30 Aya, we use only one normalizer, only doing head reduction,
* and we merge conservative normalizer and the whnf normalizer.
Expand Down Expand Up @@ -92,8 +91,8 @@ case FnCall(FnDef.Delegate delegate, int ulift, var args) -> {
term = body.instTele(args.view());
continue;
}
case Either.Right(var clauses): {
var result = tryUnfoldClauses(clauses.view().map(WithPos::data),
case Either.Right(var body): {
var result = tryUnfoldClauses(body.matchingsView(),
args, core.is(Modifier.Overlap), ulift);
// we may get stuck
if (result == null) return defaultValue;
Expand Down
17 changes: 8 additions & 9 deletions base/src/main/java/org/aya/primitive/ShapeMatcher.java
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
// Copyright (c) 2020-2024 Tesla (Yinsen) Zhang.
// Copyright (c) 2020-2025 Tesla (Yinsen) Zhang.
// Use of this source code is governed by the MIT license that can be found in the LICENSE.md file.
package org.aya.primitive;

import java.util.Objects;
import java.util.function.BiPredicate;
import java.util.function.BooleanSupplier;
import java.util.function.Function;

import kala.collection.SeqLike;
import kala.collection.immutable.ImmutableMap;
import kala.collection.immutable.ImmutableSeq;
Expand All @@ -24,15 +29,9 @@
import org.aya.util.Pair;
import org.aya.util.RepoLike;
import org.aya.util.error.Panic;
import org.aya.util.error.WithPos;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

import java.util.Objects;
import java.util.function.BiPredicate;
import java.util.function.BooleanSupplier;
import java.util.function.Function;

/**
* @author kiva
*/
Expand Down Expand Up @@ -123,10 +122,10 @@ private boolean matchFn(@NotNull FnShape shape, @NotNull FnDef def) {
return switch (new Pair<>(shape.body(), def.body())) {
case Pair(Either.Left(var termShape), Either.Left(var term)) ->
matchInside(() -> captures.put(shape.name(), def.ref()), () -> matchTerm(termShape, term));
case Pair(Either.Right(var clauseShapes), Either.Right(var clauses)) -> {
case Pair(Either.Right(var clauseShapes), Either.Right(var body)) -> {
var mode = def.is(Modifier.Overlap) ? MatchMode.Sub : MatchMode.Eq;
yield matchInside(() -> captures.put(shape.name(), def.ref()), () ->
matchMany(mode, clauseShapes, clauses.view().map(WithPos::data), this::matchClause));
matchMany(mode, clauseShapes, body.matchingsView(), this::matchClause));
}
default -> false;
};
Expand Down
11 changes: 5 additions & 6 deletions base/src/main/java/org/aya/terck/CallResolver.java
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
// Copyright (c) 2020-2024 Tesla (Yinsen) Zhang.
// Copyright (c) 2020-2025 Tesla (Yinsen) Zhang.
// Use of this source code is governed by the MIT license that can be found in the LICENSE.md file.
package org.aya.terck;

import java.util.function.Consumer;

import kala.collection.Set;
import kala.collection.immutable.ImmutableSeq;
import kala.collection.immutable.ImmutableSet;
Expand All @@ -22,14 +24,11 @@
import org.aya.syntax.core.term.xtt.PAppTerm;
import org.aya.tyck.TyckState;
import org.aya.tyck.tycker.Stateful;
import org.aya.util.error.WithPos;
import org.aya.util.terck.CallGraph;
import org.aya.util.terck.CallMatrix;
import org.aya.util.terck.Relation;
import org.jetbrains.annotations.NotNull;

import java.util.function.Consumer;

/**
* Resolve calls and build call graph of recursive functions,
* after {@link org.aya.tyck.StmtTycker}.
Expand Down Expand Up @@ -139,8 +138,8 @@ private Relation compareConArgs(@NotNull ImmutableSeq<Term> conArgs, @NotNull Pa
}

public void check() {
var clauses = caller.body().getRightValue();
clauses.view().map(WithPos::data).forEach(this);
var clauses = caller.body().getRightValue().matchingsView();
clauses.forEach(this);
}

@Override public void accept(@NotNull Term.Matching matching) {
Expand Down
16 changes: 7 additions & 9 deletions base/src/main/java/org/aya/tyck/ExprTycker.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@
// Use of this source code is governed by the MIT license that can be found in the LICENSE.md file.
package org.aya.tyck;

import kala.collection.AbstractSeqView;
import java.util.Comparator;
import java.util.function.BiFunction;
import java.util.function.Function;

import kala.collection.immutable.ImmutableSeq;
import kala.collection.immutable.ImmutableTreeSeq;
import kala.collection.mutable.MutableList;
Expand Down Expand Up @@ -51,10 +54,6 @@
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

import java.util.Comparator;
import java.util.function.BiFunction;
import java.util.function.Function;

public final class ExprTycker extends AbstractTycker implements Unifiable {
public final @NotNull MutableTreeSet<WithPos<Expr.WithTerm>> withTerms =
MutableTreeSet.create(Comparator.comparing(SourceNode::sourcePos));
Expand Down Expand Up @@ -192,9 +191,7 @@ && whnf(type) instanceof DataCall dataCall
ImmutableSeq.fill(discriminant.size(), i ->
new LocalVar("match" + i, discriminant.get(i).sourcePos(), GenerateKind.Basic.Tyck)),
ImmutableSeq.empty(), clauses);
var wellClauses = clauseTycker.check(exprPos)
.wellTyped()
.map(WithPos::data);
var wellClauses = clauseTycker.check(exprPos).wellTyped().matchingsView();

// Find free occurrences
var usages = new FreeCollector();
Expand All @@ -205,7 +202,8 @@ && whnf(type) instanceof DataCall dataCall
var captures = usages.collected();
var lifted = new Matchy(type.bindTele(wellArgs.size(), captures.view()),
new QName(QPath.fileLevel(fileModule), "match-" + exprPos.lineColumnString()),
wellClauses.map(clause -> clause.update(clause.body().bindTele(clause.bindCount(), captures.view()))));
wellClauses.map(clause -> clause.update(clause.body().bindTele(clause.bindCount(), captures.view())))
.toImmutableSeq());

var wellTerms = wellArgs.map(Jdg::wellTyped);
return new MatchCall(lifted, wellTerms, captures.map(FreeTerm::new));
Expand Down
37 changes: 23 additions & 14 deletions base/src/main/java/org/aya/tyck/StmtTycker.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
// Use of this source code is governed by the MIT license that can be found in the LICENSE.md file.
package org.aya.tyck;

import static org.aya.tyck.tycker.TeleTycker.loadTele;

import kala.collection.immutable.ImmutableSeq;
import kala.control.Either;
import kala.control.Option;
Expand Down Expand Up @@ -44,8 +46,6 @@
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

import static org.aya.tyck.tycker.TeleTycker.loadTele;

public record StmtTycker(
@NotNull SuppressingReporter reporter, @NotNull ModulePath fileModule,
@NotNull ShapeFactory shapeFactory, @NotNull PrimFactory primFactory
Expand Down Expand Up @@ -85,8 +85,6 @@ public void suppress(@NotNull Decl decl) {
case FnDecl fnDecl -> {
var fnRef = fnDecl.ref;
assert fnRef.signature != null;

var factory = FnDef.factory(body -> new FnDef(fnRef, fnDecl.modifiers, body));
var teleVars = fnDecl.telescope.map(Expr.Param::ref);

yield switch (fnDecl.body) {
Expand All @@ -100,7 +98,7 @@ yield switch (fnDecl.body) {
var zonker = new Finalizer.Zonk<>(tycker);
var resultTerm = zonker.zonk(result).bindTele(teleVars.view());
fnRef.signature = fnRef.signature.descent(zonker::zonk);
yield factory.apply(Either.left(resultTerm));
yield new FnDef(fnRef, fnDecl.modifiers, Either.left(resultTerm));
}
case FnBody.BlockBody body -> {
var clauses = body.clauses();
Expand All @@ -119,22 +117,33 @@ yield switch (fnDecl.body) {

var orderIndependent = fnDecl.modifiers.contains(Modifier.Overlap);
FnDef def;
ClauseTycker.TyckResult patResult;
boolean hasLhsError;
FnClauseBody coreBody;
if (orderIndependent) {
// Order-independent.
patResult = clauseTycker.checkNoClassify();
def = factory.apply(Either.right(patResult.wellTyped()));
if (!patResult.hasLhsError()) {
var patResult = clauseTycker.checkNoClassify();
coreBody = new FnClauseBody(patResult.wellTyped());
def = new FnDef(fnRef, fnDecl.modifiers, Either.right(coreBody));
hasLhsError = patResult.hasLhsError();
if (!hasLhsError) {
var rawParams = signature.params();
var confluence = new YouTrack(rawParams, tycker, fnDecl.sourcePos());
confluence.check(patResult, signature.result(),
PatClassifier.classify(patResult.clauses().view(), rawParams.view(), tycker, fnDecl.sourcePos()));
var classes = PatClassifier.classify(patResult.clauses().view(),
rawParams.view(), tycker, fnDecl.sourcePos());
var absurds = patResult.absurdPrefixCount();
coreBody.classes = classes.map(cls -> cls.ignoreAbsurd(absurds));
confluence.check(coreBody, signature.result());
}
} else {
patResult = clauseTycker.check(fnDecl.entireSourcePos());
def = factory.apply(Either.right(patResult.wellTyped()));
var patResult = clauseTycker.check(fnDecl.entireSourcePos());
coreBody = patResult.wellTyped();
hasLhsError = patResult.hasLhsError();
def = new FnDef(fnRef, fnDecl.modifiers, Either.right(coreBody));
}
if (!hasLhsError) {
var hitConflChecker = new IApplyConfl(def, tycker, fnDecl.sourcePos());
hitConflChecker.check();
}
if (!patResult.hasLhsError()) new IApplyConfl(def, tycker, fnDecl.sourcePos()).check();
yield def;
}
};
Expand Down
52 changes: 41 additions & 11 deletions base/src/main/java/org/aya/tyck/pat/ClauseTycker.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,21 @@
// Use of this source code is governed by the MIT license that can be found in the LICENSE.md file.
package org.aya.tyck.pat;

import java.util.function.Supplier;
import java.util.function.UnaryOperator;

import kala.collection.SeqView;
import kala.collection.immutable.ImmutableSeq;
import kala.collection.immutable.primitive.ImmutableIntArray;
import kala.collection.immutable.primitive.ImmutableIntSeq;
import kala.value.primitive.MutableBooleanValue;
import org.aya.generic.Renamer;
import org.aya.normalize.Finalizer;
import org.aya.prettier.AyaPrettierOptions;
import org.aya.syntax.concrete.Expr;
import org.aya.syntax.concrete.Pattern;
import org.aya.syntax.core.Jdg;
import org.aya.syntax.core.def.FnClauseBody;
import org.aya.syntax.core.pat.Pat;
import org.aya.syntax.core.pat.PatToTerm;
import org.aya.syntax.core.pat.TypeEraser;
Expand All @@ -31,11 +37,10 @@
import org.aya.util.error.SourcePos;
import org.aya.util.error.WithPos;
import org.aya.util.reporter.Reporter;
import org.aya.util.tyck.pat.PatClass;
import org.jetbrains.annotations.Contract;
import org.jetbrains.annotations.NotNull;

import java.util.function.Supplier;
import java.util.function.UnaryOperator;
import org.jetbrains.annotations.Nullable;

public final class ClauseTycker implements Problematic, Stateful {
private final @NotNull ExprTycker exprTycker;
Expand All @@ -46,6 +51,18 @@ public record TyckResult(@NotNull ImmutableSeq<Pat.Preclause<Term>> clauses, boo
public @NotNull ImmutableSeq<WithPos<Term.Matching>> wellTyped() {
return clauses.flatMap(Pat.Preclause::lift);
}
/// @return null if there is no absurd pattern
public @Nullable ImmutableIntSeq absurdPrefixCount() {
var ints = new int[clauses.size()];
var count = 0;
for (int i = 0; i < clauses.size(); i++) {
var clause = clauses.get(i);
if (clause.expr() == null) count++;
ints[i] = count;
}
if (count == 0) return null;
return ImmutableIntArray.Unsafe.wrap(ints);
}
}

/**
Expand Down Expand Up @@ -89,6 +106,7 @@ public void addLocalLet(@NotNull ImmutableSeq<LocalVar> teleBinds, @NotNull Expr
}
}

public record WorkerResult(FnClauseBody wellTyped, boolean hasLhsError) { }
public record Worker(
@NotNull ClauseTycker parent,
@NotNull ImmutableSeq<Param> telescope,
Expand All @@ -97,20 +115,31 @@ public record Worker(
@NotNull ImmutableSeq<LocalVar> elims,
@NotNull ImmutableSeq<Pattern.Clause> clauses
) {
public @NotNull TyckResult check(@NotNull SourcePos overallPos) {
public @NotNull WorkerResult check(@NotNull SourcePos overallPos) {
var lhs = checkAllLhs();

if (lhs.noneMatch(r -> r.hasError)) {
var classes = PatClassifier.classify(
ImmutableSeq<PatClass<ImmutableSeq<Term>>> classes;
var hasError = lhs.anyMatch(LhsResult::hasError);
if (!hasError) {
classes = PatClassifier.classify(
lhs.view().map(LhsResult::clause),
telescope.view().concat(unpi.params()), parent.exprTycker, overallPos);
if (clauses.isNotEmpty()) {
var usages = PatClassifier.firstMatchDomination(clauses, parent, classes);
// refinePatterns(lhs, usages, classes);
}
} else {
classes = null;
}

return parent.checkAllRhs(teleVars, lhs.map(cl -> cl.mapPats(new TypeEraser())));
var map = lhs.map(cl -> cl.mapPats(new TypeEraser()));
var rhs = parent.checkAllRhs(teleVars, map, hasError);
var wellTyped = new FnClauseBody(rhs.wellTyped());
if (classes != null) {
var absurds = rhs.absurdPrefixCount();
wellTyped.classes = classes.map(cl -> cl.ignoreAbsurd(absurds));
}
return new WorkerResult(wellTyped, hasError);
}

public @NotNull ImmutableSeq<LhsResult> checkAllLhs() {
Expand All @@ -120,7 +149,8 @@ public record Worker(
}

public @NotNull TyckResult checkNoClassify() {
return parent.checkAllRhs(teleVars, checkAllLhs());
var lhsResults = checkAllLhs();
return parent.checkAllRhs(teleVars, lhsResults, lhsResults.anyMatch(LhsResult::hasError));
}
}

Expand All @@ -138,9 +168,9 @@ public record Worker(

public @NotNull TyckResult checkAllRhs(
@NotNull ImmutableSeq<LocalVar> vars,
@NotNull ImmutableSeq<LhsResult> lhsResults
@NotNull ImmutableSeq<LhsResult> lhsResults,
boolean lhsError
) {
var lhsError = lhsResults.anyMatch(LhsResult::hasError);
var rhsResult = lhsResults.map(x -> checkRhs(vars, x));

// inline terms in rhsResult
Expand Down Expand Up @@ -188,7 +218,7 @@ public record Worker(
// It would be nice if we have a SourcePos here
new Pat.Bind(new LocalVar("unpi" + idx, SourcePos.NONE, GenerateKind.Basic.Tyck),
x.type()));

var wellTypedPats = patResult.wellTyped().appendedAll(missingPats);
var patWithTypeBound = Pat.collectVariables(wellTypedPats.view());

Expand Down
Loading

0 comments on commit 0830357

Please sign in to comment.