Skip to content

Commit

Permalink
Evolog Modules: aggregate terms into list using aggregate literal
Browse files Browse the repository at this point in the history
  • Loading branch information
madmike200590 committed Aug 6, 2024
1 parent d1f349e commit 3540762
Show file tree
Hide file tree
Showing 7 changed files with 140 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
public final class Terms {

public static final String LIST_TERM_SYMBOL = "lst";
public static final ConstantTerm<String> EMPTY_LIST = Terms.newSymbolicConstant("emptyList");
public static final ConstantTerm<String> EMPTY_LIST = Terms.newSymbolicConstant("lst_empty");

/**
* Since this is purely a utility class, it may not be instantiated.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ public class AggregateRewriting extends ProgramTransformation<InputProgram, Inpu
private final AbstractAggregateEncoder sumLessOrEqualEncoder;
private final AbstractAggregateEncoder minEncoder;
private final AbstractAggregateEncoder maxEncoder;
private final AbstractAggregateEncoder listEncoder;

/**
* Creates a new {@link AggregateRewriting} transformation.
Expand All @@ -51,6 +52,8 @@ public AggregateRewriting(boolean useSortingCircuit, boolean supportNegativeInte
this.countEqualsEncoder = AggregateEncoders.newCountEqualsEncoder();
this.minEncoder = AggregateEncoders.newMinEncoder();
this.maxEncoder = AggregateEncoders.newMaxEncoder();
this.listEncoder = AggregateEncoders.newListEncoder();

}

/**
Expand Down Expand Up @@ -117,6 +120,8 @@ private AbstractAggregateEncoder getEncoderForAggregateFunction(AggregateFunctio
} else {
throw new UnsupportedOperationException("No fitting encoder for aggregate function " + function + "and operator " + operator + "!");
}
case LIST:
return listEncoder;
default:
throw new UnsupportedOperationException("Unsupported aggregate function/comparison operator: " + function + ", " + operator);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package at.ac.tuwien.kr.alpha.core.programs.transformation.aggregates.encoders;

import at.ac.tuwien.kr.alpha.api.programs.atoms.AggregateAtom.AggregateFunctionSymbol;
import at.ac.tuwien.kr.alpha.core.parser.ProgramParserImpl;

public final class AggregateEncoders {

Expand Down Expand Up @@ -32,4 +33,8 @@ public static MinMaxEncoder newMaxEncoder() {
return new MinMaxEncoder(AggregateFunctionSymbol.MAX);
}

public static AbstractAggregateEncoder newListEncoder() {
return new ListEncoder(new ProgramParserImpl());
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package at.ac.tuwien.kr.alpha.core.programs.transformation.aggregates.encoders;

import at.ac.tuwien.kr.alpha.api.programs.InputProgram;
import at.ac.tuwien.kr.alpha.api.programs.Predicate;
import at.ac.tuwien.kr.alpha.api.programs.ProgramParser;
import at.ac.tuwien.kr.alpha.api.programs.atoms.AggregateAtom;
import at.ac.tuwien.kr.alpha.api.programs.atoms.BasicAtom;
import at.ac.tuwien.kr.alpha.api.programs.terms.Term;
import at.ac.tuwien.kr.alpha.commons.Predicates;
import at.ac.tuwien.kr.alpha.commons.comparisons.ComparisonOperators;
import at.ac.tuwien.kr.alpha.commons.programs.atoms.Atoms;
import at.ac.tuwien.kr.alpha.commons.util.Util;
import at.ac.tuwien.kr.alpha.core.programs.transformation.aggregates.AggregateRewritingContext;
import org.stringtemplate.v4.ST;

import java.util.Set;

public class ListEncoder extends AbstractAggregateEncoder {

private static final ST LIST_AGGREGATION = Util.aspStringTemplate(
// First, establish ordering of elements (which we need to establish the order within the list)
"$id$_element_greater(ARGS, N, K) :- $id$_element(ARGS, N), $id$_element(ARGS, K), N > K. " +
"$id$_element_not_successor(ARGS, N, K) :- $id$_element_greater(ARGS, N, I), $id$_element_greater(ARGS, I, K). " +
"$id$_element_successor(ARGS, N, K) :- $id$_element_greater(ARGS, N, K), not $id$_element_not_successor(ARGS, N, K). " +
"$id$_element_has_successor(ARGS, N) :- $id$_element_successor(ARGS, _, N). " +
// Now build the list as a recursively nested function term
"$id$_lst_element(ARGS, IDX, lst(N, lst_empty)) :- $id$_element(ARGS, N), not $id$_element_has_successor(ARGS, N), IDX = 0. " +
"$id$_lst_element(ARGS, IDX, lst(N, lst(K, TAIL))) :- $id$_element(ARGS, N), $id$_element_successor(ARGS, K, N), $id$_lst_element(ARGS, PREV_IDX, lst(K, TAIL)), IDX = PREV_IDX + 1. " +
"has_next_$id$_element(ARGS, IDX) :- $id$_lst_element(ARGS, IDX, _), NEXT_IDX = IDX + 1, $id$_lst_element(ARGS, NEXT_IDX, _). " +
"$aggregate_result$(ARGS, LIST) :- $id$_lst_element(ARGS, IDX, LIST), not has_next_$id$_element(ARGS, IDX).");

private final ProgramParser parser;

protected ListEncoder(ProgramParser parser) {
super(AggregateAtom.AggregateFunctionSymbol.LIST, Set.of(ComparisonOperators.EQ));
this.parser = parser;
}

@Override
protected InputProgram encodeAggregateResult(AggregateRewritingContext.AggregateInfo aggregateToEncode) {
ST encodingTemplate = new ST(LIST_AGGREGATION);
encodingTemplate.add("id", aggregateToEncode.getId());
encodingTemplate.add("aggregate_result", aggregateToEncode.getOutputAtom().getPredicate().getName());
return parser.parse(encodingTemplate.render());
}

@Override
protected BasicAtom buildElementRuleHead(String aggregateId, AggregateAtom.AggregateElement element, Term aggregateArguments) {
Predicate headPredicate = Predicates.getPredicate(this.getElementTuplePredicateSymbol(aggregateId), 2);
if (element.getElementTerms().size() != 1) {
throw new IllegalArgumentException("List elements may only consist of one term.");
}
Term value = element.getElementTerms().get(0);
return Atoms.newBasicAtom(headPredicate, aggregateArguments, value);
}

@Override
protected String getElementTuplePredicateSymbol(String aggregateId) {
return aggregateId + "_element";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,16 @@
import at.ac.tuwien.kr.alpha.api.impl.AlphaFactory;
import at.ac.tuwien.kr.alpha.api.programs.InputProgram;
import at.ac.tuwien.kr.alpha.api.programs.Predicate;
import at.ac.tuwien.kr.alpha.api.programs.atoms.Atom;
import at.ac.tuwien.kr.alpha.api.programs.terms.Term;
import at.ac.tuwien.kr.alpha.commons.Predicates;
import at.ac.tuwien.kr.alpha.commons.programs.atoms.Atoms;
import at.ac.tuwien.kr.alpha.commons.programs.terms.Terms;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;

import java.util.List;
import java.util.SortedSet;
import java.util.function.Function;
import java.util.stream.Collectors;

Expand Down Expand Up @@ -74,6 +77,10 @@ public class AggregateRewritingTest {
"p(1..10)."
+ "q :- X = #count { Y : p( Y ) }, X = #count { Z : p( Z ) },"
+ " Y = #count { X : p( X ) }, 1 <= #count { X : p( X ) }, Z = #max { W : p( W ) }.";

private static final String LIST_COLLECT =
"p(1). p(2). p(3)."
+ " q(X) :- X = #list{ Y : p(Y) }.";
//@formatter:on

// Use an alpha instance with default config for all test cases
Expand Down Expand Up @@ -233,4 +240,18 @@ public void setComplexEqualityWithGlobals() {
assertTrue(answerSet.getPredicateInstances(q).contains(Atoms.newBasicAtom(q)));
}

@Test
public void listCollect() {
List<AnswerSet> answerSets = solve.apply(LIST_COLLECT);
assertEquals(1, answerSets.size());
AnswerSet answerSet = answerSets.get(0);
Predicate q = Predicates.getPredicate("q", 1);
SortedSet<Atom> instances = answerSet.getPredicateInstances(q);
assertEquals(1, instances.size());
Atom instance = instances.first();
assertEquals(1, instance.getTerms().size());
Term term = instance.getTerms().get(0);
assertEquals(Terms.asListTerm(List.of(Terms.newConstant(1), Terms.newConstant(2), Terms.newConstant(3))), term);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ private static DynamicTest alphaEnd2EndTest(String testName, String... fileset)
Stream<DynamicTest> alphaEnd2EndTests() {
return Stream.of(
alphaEnd2EndTest("3-Coloring", E2E_TESTS_DIR + "3col.asp"),
alphaEnd2EndTest("modules-basic", E2E_TESTS_DIR + "modules-basic.evl")
alphaEnd2EndTest("modules-basic", E2E_TESTS_DIR + "modules-basic.evl"),
alphaEnd2EndTest("neighboring-vertices-list", E2E_TESTS_DIR + "neighboring-vertices-list.evl")
);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
% Graph is undirected
edge(Y, X) :- edge(X, Y).

%% Generate list of neighboring vertices for each vertex in the graph
neighbor(V, N) :- vertex(V), vertex(N), edge(V, N).
neighbors(V, LST) :- vertex(V), LST = #list{ N : neighbor(V, N)}.

#test lineGraph(expect: 1) {
given {
vertex(1). vertex(2). edge(1, 2).
}
assertForAll {
:- not neighbors(1, lst(2, lst_empty)).
:- not neighbors(2, lst(1, lst_empty)).
}
}

#test network(expect: 1) {
given {
vertex(1..8).

edge(1, 2).
edge(1, 3).
edge(1, 4).
edge(2, 5).
edge(1, 3).
edge(1, 4).
edge(2, 5).
edge(3, 6).
edge(4, 7).
edge(5, 8).
edge(6, 8).
edge(7, 8).
}
assertForAll {
:- not neighbors(1, lst(2, lst(3, lst(4, lst_empty)))).
:- not neighbors(2, lst(1, lst(5, lst_empty))).
:- not neighbors(3, lst(1, lst(6, lst_empty))).
:- not neighbors(4, lst(1, lst(7, lst_empty))).
:- not neighbors(5, lst(2, lst(8, lst_empty))).
:- not neighbors(6, lst(3, lst(8, lst_empty))).
:- not neighbors(7, lst(4, lst(8, lst_empty))).
:- not neighbors(8, lst(5, lst(6, lst(7, lst_empty)))).
}
}

0 comments on commit 3540762

Please sign in to comment.