Skip to content

Commit

Permalink
Fix a bug in NormalizeVisitor
Browse files Browse the repository at this point in the history
  • Loading branch information
valis committed Jul 6, 2024
1 parent 59998df commit 88101db
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ public FunctionDefinition getDefinition() {

@Override
public @Nullable Expression evaluate() {
return NormalizeVisitor.INSTANCE.eval(this);
return NormalizeVisitor.INSTANCE.eval(this, true);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ public List<? extends Expression> getArguments() {
}

public Expression eval() {
return NormalizeVisitor.INSTANCE.eval(myExpression);
return NormalizeVisitor.INSTANCE.eval(myExpression, true);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -455,18 +455,18 @@ public Expression visitBody(Body body, List<? extends Expression> defCallArgs, L
if (body instanceof Expression) {
ExprSubstitution substitution = addArguments(getDataTypeArgumentsSubstitution(expr), defCallArgs, definition);
LevelSubstitution levelSubstitution = expr.getLevelSubstitution();
if (body instanceof CaseExpression caseExpr && !((CaseExpression) body).isSCase()) {
if (body instanceof CaseExpression caseExpr && !caseExpr.isSCase()) {
List<Expression> args = new ArrayList<>(caseExpr.getArguments().size());
for (Expression arg : caseExpr.getArguments()) {
args.add(arg.subst(substitution, levelSubstitution));
}
Expression result = eval(caseExpr.getElimBody(), args, substitution, levelSubstitution, SubstExpression.make(caseExpr, substitution, levelSubstitution), mode);
Expression result = eval(caseExpr.getElimBody(), args, substitution, levelSubstitution, SubstExpression.make(caseExpr, substitution, levelSubstitution), mode, false);
return result == null ? caseExpr.subst(substitution, levelSubstitution) : result;
} else {
return ((Expression) body).subst(substitution, levelSubstitution).accept(this, mode);
}
} else if (body instanceof ElimBody) {
return eval((ElimBody) body, defCallArgs, getDataTypeArgumentsSubstitution(expr), expr.getLevelSubstitution(), expr, mode);
return eval((ElimBody) body, defCallArgs, getDataTypeArgumentsSubstitution(expr), expr.getLevelSubstitution(), expr, mode, false);
} else {
throw new IllegalStateException();
}
Expand All @@ -489,18 +489,18 @@ public Deque<Expression> makeStack(List<? extends Expression> arguments) {
return stack;
}

public Expression eval(Expression expr) {
public Expression eval(Expression expr, boolean shouldProgress) {
if (expr instanceof LeveledDefCallExpression defCall) {
Body body = defCall instanceof FunCallExpression ? ((FunCallExpression) defCall).getDefinition().getActualBody() : defCall instanceof ConCallExpression ? ((ConCallExpression) defCall).getDefinition().getBody() : null;
if (body instanceof Expression) {
return ((Expression) body).subst(new ExprSubstitution().add(defCall.getDefinition().getParameters(), defCall.getDefCallArguments()), defCall.getLevelSubstitution());
} else if (body instanceof ElimBody) {
return eval((ElimBody) body, defCall.getDefCallArguments(), new ExprSubstitution(), defCall.getLevelSubstitution(), expr, null);
return eval((ElimBody) body, defCall.getDefCallArguments(), new ExprSubstitution(), defCall.getLevelSubstitution(), expr, null, shouldProgress);
} else {
return null;
}
} else if (expr instanceof CaseExpression) {
return eval(((CaseExpression) expr).getElimBody(), ((CaseExpression) expr).getArguments(), new ExprSubstitution(), LevelSubstitution.EMPTY, expr, null);
return eval(((CaseExpression) expr).getElimBody(), ((CaseExpression) expr).getArguments(), new ExprSubstitution(), LevelSubstitution.EMPTY, expr, null, shouldProgress);
} else {
return null;
}
Expand All @@ -510,7 +510,7 @@ private static boolean isBlocked(FunctionDefinition def) {
return def.isSFunc() || def == Prelude.PLUS || def == Prelude.MUL || def == Prelude.MINUS || def == Prelude.DIV || def == Prelude.MOD || def == Prelude.DIV_MOD || def == Prelude.COERCE || def == Prelude.COERCE2;
}

public Expression eval(ElimBody elimBody, List<? extends Expression> arguments, ExprSubstitution substitution, LevelSubstitution levelSubstitution, Expression resultExpr, NormalizationMode mode) {
public Expression eval(ElimBody elimBody, List<? extends Expression> arguments, ExprSubstitution substitution, LevelSubstitution levelSubstitution, Expression origExpr, NormalizationMode mode, boolean shouldProgress) {
Deque<Expression> stack = makeStack(arguments);
List<Expression> argList = new ArrayList<>();
Expression result = null;
Expand All @@ -519,6 +519,7 @@ public Expression eval(ElimBody elimBody, List<? extends Expression> arguments,
int recursiveParam = -1;
int sucs = 0;

Expression resultExpr = origExpr;
ElimTree elimTree = elimBody.getElimTree();
while (true) {
for (int i = 0; i < elimTree.getSkip(); i++) {
Expand Down Expand Up @@ -588,7 +589,7 @@ public Expression eval(ElimBody elimBody, List<? extends Expression> arguments,
} else {
break;
}
} else if (resultExpr instanceof FunCallExpression funCall && ((FunCallExpression) resultExpr).getDefinition().getBody() instanceof Expression) {
} else if (resultExpr instanceof FunCallExpression funCall && funCall.getDefinition().getBody() instanceof Expression) {
resultExpr = Objects.requireNonNull((Expression) funCall.getDefinition().getBody()).subst(addArguments(new ExprSubstitution(), funCall.getDefCallArguments(), funCall.getDefinition()), funCall.getLevelSubstitution());
} else if (resultExpr instanceof ReferenceExpression && ((ReferenceExpression) resultExpr).getBinding() instanceof EvaluatingBinding) {
resultExpr = ((EvaluatingBinding) ((ReferenceExpression) resultExpr).getBinding()).getExpression();
Expand Down Expand Up @@ -651,7 +652,10 @@ public Expression eval(ElimBody elimBody, List<? extends Expression> arguments,
if (resultExpr instanceof SubstExpression) {
resultExpr = ((SubstExpression) resultExpr).eval();
}
if (mode == NormalizationMode.WHNF && resultExpr instanceof FunCallExpression funCall && ((FunCallExpression) resultExpr).getDefinition().getBody() instanceof ElimBody) {
if (shouldProgress && resultExpr == origExpr) {
return null;
}
if (mode == NormalizationMode.WHNF && resultExpr instanceof FunCallExpression funCall && funCall.getDefinition().getBody() instanceof ElimBody) {
List<Expression> newArgs = ((ElimBody) Objects.requireNonNull(funCall.getDefinition().getBody())).getElimTree().normalizeArguments(funCall.getDefCallArguments());
resultExpr = FunCallExpression.make(funCall.getDefinition(), funCall.getLevels(), newArgs);
}
Expand Down Expand Up @@ -715,7 +719,7 @@ private ElimTree updateStack(Deque<Expression> stack, List<Expression> argList,
}
if (elimTree == null && branchElimTree.isArray()) {
Expression type = argument.getType().normalize(NormalizationMode.WHNF);
if (type instanceof ClassCallExpression classCall && ((ClassCallExpression) type).getDefinition() == Prelude.DEP_ARRAY) {
if (type instanceof ClassCallExpression classCall && classCall.getDefinition() == Prelude.DEP_ARRAY) {
Expression length = classCall.getImplementationHere(Prelude.ARRAY_LENGTH, argument);
if (length != null) {
length = length.normalize(NormalizationMode.WHNF);
Expand Down Expand Up @@ -797,8 +801,7 @@ public Expression visitDefCall(DefCallExpression expr, NormalizationMode mode) {
}

public Expression evalFieldCall(ClassField field, Expression arg) {
if (arg instanceof FunCallExpression funCall && ((FunCallExpression) arg).getDefinition().getResultType() instanceof ClassCallExpression) {
ClassCallExpression classCall = (ClassCallExpression) funCall.getDefinition().getResultType();
if (arg instanceof FunCallExpression funCall && funCall.getDefinition().getResultType() instanceof ClassCallExpression classCall) {
Expression impl = classCall.getAbsImplementationHere(field);
if (impl != null) {
return impl.subst(new ExprSubstitution(classCall.getThisBinding(), arg).add(funCall.getDefinition().getParameters(), funCall.getDefCallArguments()), funCall.getLevelSubstitution());
Expand Down Expand Up @@ -1011,7 +1014,7 @@ public Expression visitLet(LetExpression let, NormalizationMode mode) {
@Override
public Expression visitCase(CaseExpression expr, NormalizationMode mode) {
if (!expr.isSCase()) {
Expression result = eval(expr.getElimBody(), expr.getArguments(), new ExprSubstitution(), LevelSubstitution.EMPTY, mode == NormalizationMode.WHNF ? expr : null, mode);
Expression result = eval(expr.getElimBody(), expr.getArguments(), new ExprSubstitution(), LevelSubstitution.EMPTY, mode == NormalizationMode.WHNF ? expr : null, mode, false);
if (result != null) {
return result;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -546,7 +546,12 @@ public Expression visitPEval(PEvalExpression expr, Expression expectedType) {
List<Expression> args = new ArrayList<>(3);
args.add(type);
args.add(expr.getExpression());
args.add(expr.eval());
Expression evaluated = expr.eval();
if (evaluated == null) {
throw new CoreException(CoreErrorWrapper.make(new TypecheckingError("Expression does not evaluate", mySourceNode), expr.getExpression()));
}

args.add(evaluated);
return check(expectedType, FunCallExpression.make(Prelude.PATH_INFIX, new LevelPair(sort.getPLevel(), sort.getHLevel()), args), expr);
}

Expand Down Expand Up @@ -675,7 +680,7 @@ private boolean checkElimPattern(Expression type, Pattern pattern, List<Binding>
if (constructor != Prelude.EMPTY_ARRAY && constructor != Prelude.ARRAY_CONS) {
throw new CoreException(CoreErrorWrapper.make(new TypecheckingError("Expected either '" + Prelude.EMPTY_ARRAY.getName() + "' or '" + Prelude.ARRAY_CONS.getName() + "'", mySourceNode), errorExpr));
}
if (!(type instanceof ClassCallExpression classCall && ((ClassCallExpression) type).getDefinition() == Prelude.DEP_ARRAY)) {
if (!(type instanceof ClassCallExpression classCall && classCall.getDefinition() == Prelude.DEP_ARRAY)) {
throw new CoreException(CoreErrorWrapper.make(new TypeMismatchError(new ClassCallExpression(Prelude.DEP_ARRAY, LevelPair.STD), type, mySourceNode), errorExpr));
}
Expression length = classCall.getAbsImplementationHere(Prelude.ARRAY_LENGTH);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ private boolean checkClause(ExtElimClause clause, Concrete.SourceNode sourceNode

Expression evaluatedExpr1;
if (elimBody != null && (definition == null || definition instanceof FunctionDefinition && (((FunctionDefinition) definition).getKind() == CoreFunctionDefinition.Kind.SFUNC || ((FunctionDefinition) definition).getKind() == CoreFunctionDefinition.Kind.TYPE) || expr instanceof GoalErrorExpression)) {
evaluatedExpr1 = NormalizeVisitor.INSTANCE.eval(elimBody, pair.proj1, new ExprSubstitution(), LevelSubstitution.EMPTY, null, null);
evaluatedExpr1 = NormalizeVisitor.INSTANCE.eval(elimBody, pair.proj1, new ExprSubstitution(), LevelSubstitution.EMPTY, null, null, false);
if (evaluatedExpr1 == null && definition != null) {
evaluatedExpr1 = definition.getDefCall(definition.makeIdLevels(), pair.proj1);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -321,9 +321,9 @@ private Expression normalizeRecursiveCalls(Expression expr, FunctionDefinition d
@Override
public Expression visitFunCall(FunCallExpression expr, Void params) {
if (expr.getDefinition() == def) {
Expression result = NormalizeVisitor.INSTANCE.eval(expr);
Expression result = NormalizeVisitor.INSTANCE.eval(expr, false);
if (result != null) {
if (result instanceof FunCallExpression funCall && ((FunCallExpression) result).getDefinition() == def) {
if (result instanceof FunCallExpression funCall && funCall.getDefinition() == def) {
List<Expression> args = new ArrayList<>();
for (Expression arg : funCall.getDefCallArguments()) {
args.add(arg.accept(this, null).normalize(NormalizationMode.WHNF));
Expand All @@ -343,7 +343,7 @@ private Expression evalBody(ExprSubstitution substitution, ElimBody body, List<E
for (Expression arg : args) {
substArgs.add(arg.subst(substitution));
}
return NormalizeVisitor.INSTANCE.eval(body, substArgs, new ExprSubstitution(), LevelSubstitution.EMPTY, null, null);
return NormalizeVisitor.INSTANCE.eval(body, substArgs, new ExprSubstitution(), LevelSubstitution.EMPTY, null, null, false);
}

private int getIntervalBindings(List<? extends ExpressionPattern> patterns, int index, List<Binding> result) {
Expand Down Expand Up @@ -628,7 +628,7 @@ private Result doTypechecking(List<Concrete.Pattern> patterns, DependentLink par
if (pattern instanceof Concrete.ConstructorPattern conPattern) {
Definition def = conPattern.getConstructor() instanceof TCDefReferable ? ((TCDefReferable) conPattern.getConstructor()).getTypechecked() : null;
if (def instanceof DConstructor constructor && def != Prelude.EMPTY_ARRAY && def != Prelude.ARRAY_CONS) {
if (myVisitor == null || ((DConstructor) def).getPattern() == null) {
if (myVisitor == null || constructor.getPattern() == null) {
return null;
}

Expand Down

0 comments on commit 88101db

Please sign in to comment.