Skip to content

Commit

Permalink
Evolog Modules: Revert to single-predicate module input
Browse files Browse the repository at this point in the history
  • Loading branch information
madmike200590 committed Jul 24, 2024
1 parent 761331c commit 60b819f
Show file tree
Hide file tree
Showing 7 changed files with 103 additions and 69 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ public interface Module {

String getName();

Set<Predicate> getInputSpec();
Predicate getInputSpec();

Set<Predicate> getOutputSpec();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@
class ModuleImpl implements Module {

private final String name;
private final Set<Predicate> inputSpec;
private final Predicate inputSpec;
private final Set<Predicate> outputSpec;
private final InputProgram implementation;

ModuleImpl(String name, Set<Predicate> inputSpec, Set<Predicate> outputSpec, InputProgram implementation) {
ModuleImpl(String name, Predicate inputSpec, Set<Predicate> outputSpec, InputProgram implementation) {
this.name = name;
this.inputSpec = inputSpec;
this.outputSpec = outputSpec;
Expand All @@ -26,7 +26,7 @@ public String getName() {
}

@Override
public Set<Predicate> getInputSpec() {
public Predicate getInputSpec() {
return this.inputSpec;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ private Modules() {
throw new AssertionError("Cannot instantiate utility class!");
}

public static Module newModule(final String name, final Set<Predicate> inputSpec, final Set<Predicate> outputSpec, final InputProgram implementation) {
public static Module newModule(final String name, final Predicate inputSpec, final Set<Predicate> outputSpec, final InputProgram implementation) {
return new ModuleImpl(name, inputSpec, outputSpec, implementation);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,5 +121,5 @@ test_assert_all : TEST_ASSERT_ALL CURLY_OPEN statements? CURLY_CLOSE;

test_assert_some : TEST_ASSERT_SOME CURLY_OPEN statements? CURLY_CLOSE;

module_signature : CURLY_OPEN predicate_specs? CURLY_CLOSE ARROW CURLY_OPEN ('*' | predicate_specs) CURLY_CLOSE;
module_signature : predicate_spec ARROW CURLY_OPEN ('*' | predicate_specs) CURLY_CLOSE;

Original file line number Diff line number Diff line change
Expand Up @@ -335,18 +335,18 @@ public Object visitDirective_module(ASPCore2Parser.Directive_moduleContext ctx)
}
// directive_module: SHARP DIRECTIVE_MODULE id PAREN_OPEN module_signature PAREN_CLOSE CURLY_OPEN statements CURLY_CLOSE;
String name = visitId(ctx.id());
ImmutablePair<Set<Predicate>, Set<Predicate>> moduleSignature = visitModule_signature(ctx.module_signature());
ImmutablePair<Predicate, Set<Predicate>> moduleSignature = visitModule_signature(ctx.module_signature());
startNestedProgram();
visitStatements(ctx.statements());
InputProgram moduleImplementation = endNestedProgram();
currentLevelProgramBuilder.addModule(Modules.newModule(name, moduleSignature.getLeft(), moduleSignature.getRight(), moduleImplementation));
return null;
}

public ImmutablePair<Set<Predicate>, Set<Predicate>> visitModule_signature(ASPCore2Parser.Module_signatureContext ctx) {
Set<Predicate> inputPredicates = ctx.predicate_specs(0) != null ? visitPredicate_specs(ctx.predicate_specs(0)) : Collections.emptySet();
Set<Predicate> outputPredicates = ctx.predicate_specs(1) != null ? visitPredicate_specs(ctx.predicate_specs(1)) : Collections.emptySet();
return ImmutablePair.of(inputPredicates, outputPredicates);
public ImmutablePair<Predicate, Set<Predicate>> visitModule_signature(ASPCore2Parser.Module_signatureContext ctx) {
Predicate inputPredicate = visitPredicate_spec(ctx.predicate_spec());
Set<Predicate> outputPredicates = ctx.predicate_specs() != null ? visitPredicate_specs(ctx.predicate_specs()) : Collections.emptySet();
return ImmutablePair.of(inputPredicate, outputPredicates);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package at.ac.tuwien.kr.alpha.core.programs.transformation;

import at.ac.tuwien.kr.alpha.api.Alpha;
import at.ac.tuwien.kr.alpha.api.programs.NormalProgram;
import at.ac.tuwien.kr.alpha.api.programs.atoms.ExternalAtom;
import at.ac.tuwien.kr.alpha.api.programs.atoms.ModuleAtom;
import at.ac.tuwien.kr.alpha.api.programs.literals.Literal;
import at.ac.tuwien.kr.alpha.api.programs.literals.ModuleLiteral;
import at.ac.tuwien.kr.alpha.api.programs.modules.Module;
import at.ac.tuwien.kr.alpha.api.programs.rules.NormalRule;
import at.ac.tuwien.kr.alpha.api.programs.rules.heads.NormalHead;
import at.ac.tuwien.kr.alpha.commons.programs.rules.Rules;

import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;

/**
* Program transformation that translates {@link at.ac.tuwien.kr.alpha.api.programs.literals.ModuleLiteral}s into
* {@link at.ac.tuwien.kr.alpha.api.programs.literals.ExternalLiteral}s by constructing {@link at.ac.tuwien.kr.alpha.api.programs.atoms.ExternalAtom}s
* which solve the ASP implementation of the module with the given inputs.
*/
public class ModuleLinker extends ProgramTransformation<NormalProgram, NormalProgram> {

// Note: References to a standard library of modules that are always available for linking should be member variables of a linker.

private final Alpha moduleRunner;

public ModuleLinker(Alpha moduleRunner) {
this.moduleRunner = moduleRunner;
}


@Override
public NormalProgram apply(NormalProgram inputProgram) {
Map<String, Module> moduleTable = inputProgram.getModules().stream().collect(Collectors.toMap(Module::getName, Function.identity()));
List<NormalRule> transformedRules = inputProgram.getRules().stream()
.map(rule -> containsModuleAtom(rule) ? linkModuleAtoms(rule, moduleTable) : rule)
.collect(Collectors.toList());
return null;
}

private NormalRule linkModuleAtoms(NormalRule rule, Map<String, Module> moduleTable) {
NormalHead newHead = rule.getHead();
Set<Literal> newBody = rule.getBody().stream()
.map(literal -> {
if (literal instanceof ModuleLiteral) {
ModuleLiteral moduleLiteral = (ModuleLiteral) literal;
return translateModuleAtom(moduleLiteral.getAtom(), moduleTable).toLiteral(!moduleLiteral.isNegated());
} else {
return literal;
}
})
.collect(Collectors.toSet());
return Rules.newNormalRule(newHead, newBody);
}

private ExternalAtom translateModuleAtom(ModuleAtom moduleAtom, Map<String, Module> moduleTable) {
if (!moduleTable.containsKey(moduleAtom.getModuleName())) {
throw new IllegalArgumentException("Module " + moduleAtom.getModuleName() + " not found in module table.");
}
Module implementationModule = moduleTable.get(moduleAtom.getModuleName());
//implementationModule.
return null;
}

private static boolean containsModuleAtom(NormalRule rule) {
return rule.getBody().stream().anyMatch(literal -> literal instanceof ModuleLiteral);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,10 @@ public class ParserTest {

private static final String UNIT_TEST_EXPECT_UNSAT =
"p(1). p(2). "
+ ":- p(X), p(Y), X + Y = 3."
+ "#test expected_unsat(expect: unsat) {"
+ "given {}"
+ "}";
+ ":- p(X), p(Y), X + Y = 3."
+ "#test expected_unsat(expect: unsat) {"
+ "given {}"
+ "}";

private static final String UNIT_TEST_BASIC_TEST =
"a :- b. #test ensure_a(expect: 1) { given { b. } assertForAll { :- not a. } }";
Expand All @@ -86,17 +86,13 @@ public class ParserTest {
private static final String UNIT_TEST_KEYWORDS_AS_IDS =
"assert(a) :- given(b). # test test(expect: 1) { given { given(b). } assertForAll { :- not assert(a). :- assertForSome(b).}}";

private static final String MODULE_SIMPLE = "#module aSimpleModule({input/1} => {out1/2, out2/3}) { p(a). p(b). q(X) :- p(X). }";
private static final String MODULE_SIMPLE = "#module aSimpleModule(input/1 => {out1/2, out2/3}) { p(a). p(b). q(X) :- p(X). }";

private static final String MODULE_OUTPUT_ALL = "#module mod({in/1} => {*}) { a(X). b(X) :- a(X).}";
private static final String MODULE_OUTPUT_ALL = "#module mod(in/1 => {*}) { a(X). b(X) :- a(X).}";

private static final String MODULE_WITH_REGULAR_STMTS = "p(a). p(b). q(X) :- p(X). #module aSimpleModule({input/1} => {out1/2, out2/3}) { p(a). p(b). q(X) :- p(X). }";
private static final String MODULE_WITH_REGULAR_STMTS = "p(a). p(b). q(X) :- p(X). #module aSimpleModule(input/1 => {out1/2, out2/3}) { p(a). p(b). q(X) :- p(X). }";

private static final String MODULE_MULTIPLE_DEFINITIONS = "a. b(5). #module aSimpleModule({input/1} => {out1/2, out2/3}) { p(a). p(b). q(X) :- p(X). } q(Y) :- r(S, Y), t(S). #module anotherModule({input/1} => {out1/2, out2/3}) { p(a). p(b). q(X) :- p(X). }";

private static final String MODULE_EMPTY_INPUT_SPEC = "#module someModule({} => {*}) {p(a).}";

private static final String MODULE_MULTIPLE_INPUTS = "#module someModule({input/1, input2/2} => {out1/2}) {p(a).}";
private static final String MODULE_MULTIPLE_DEFINITIONS = "a. b(5). #module aSimpleModule(input/1 => {out1/2, out2/3}) { p(a). p(b). q(X) :- p(X). } q(Y) :- r(S, Y), t(S). #module anotherModule(input/1 => {out1/2, out2/3}) { p(a). p(b). q(X) :- p(X). }";

private static final String MODULE_LITERAL = "p(a). q(b). r(X) :- p(X), q(Y), #mod[X, Y](X).";

Expand Down Expand Up @@ -336,11 +332,9 @@ public void simpleModule() {
assertEquals(1, modules.size());
Module module = modules.get(0);
assertEquals("aSimpleModule", module.getName());
Set<Predicate> inputSpec = module.getInputSpec();
assertEquals(1, inputSpec.size());
Predicate inputPredicate = inputSpec.iterator().next();
assertEquals("input", inputPredicate.getName());
assertEquals(1, inputPredicate.getArity());
Predicate inputSpec = module.getInputSpec();
assertEquals("input", inputSpec.getName());
assertEquals(1, inputSpec.getArity());
Set<Predicate> outputSpec = module.getOutputSpec();
assertEquals(2, outputSpec.size());
assertTrue(outputSpec.contains(Predicates.getPredicate("out1", 2)));
Expand All @@ -358,11 +352,9 @@ public void moduleOutputAll() {
assertEquals(1, modules.size());
Module module = modules.get(0);
assertEquals("mod", module.getName());
Set<Predicate> inputSpec = module.getInputSpec();
assertEquals(1, inputSpec.size());
Predicate inputPredicate = inputSpec.iterator().next();
assertEquals("in", inputPredicate.getName());
assertEquals(1, inputPredicate.getArity());
Predicate inputSpec = module.getInputSpec();
assertEquals("in", inputSpec.getName());
assertEquals(1, inputSpec.getArity());
assertTrue(module.getOutputSpec().isEmpty());
InputProgram implementation = module.getImplementation();
assertEquals(1, implementation.getFacts().size());
Expand All @@ -379,11 +371,9 @@ public void moduleAndRegularStmts() {
assertEquals(1, modules.size());
Module module = modules.get(0);
assertEquals("aSimpleModule", module.getName());
Set<Predicate> inputSpec = module.getInputSpec();
assertEquals(1, inputSpec.size());
Predicate inputPredicate = inputSpec.iterator().next();
assertEquals("input", inputPredicate.getName());
assertEquals(1, inputPredicate.getArity());
Predicate inputSpec = module.getInputSpec();
assertEquals("input", inputSpec.getName());
assertEquals(1, inputSpec.getArity());
Set<Predicate> outputSpec = module.getOutputSpec();
assertEquals(2, outputSpec.size());
assertTrue(outputSpec.contains(Predicates.getPredicate("out1", 2)));
Expand All @@ -407,42 +397,13 @@ public void multipleModuleDefinitions() {
@Test
public void invalidNestedModule() {
assertThrows(IllegalStateException.class, () ->
parser.parse("#module aSimpleModule({input/1} => {out1/2, out2/3}) { p(a). p(b). #module anotherModule({input/1} => {out1/2, out2/3}) { p(a). p(b). } }"));
parser.parse("#module aSimpleModule(input/1 => {out1/2, out2/3}) { p(a). p(b). #module anotherModule(input/1 => {out1/2, out2/3}) { p(a). p(b). } }"));
}

@Test
public void invalidNestedTest() {
assertThrows(IllegalStateException.class, () ->
parser.parse("#module mod({foo/1} => {*}) { #test test(expect: 1) { given { b. } assertForAll { :- a. } } }"));
}

@Test
public void emptyInputSpec() {
InputProgram prog = parser.parse(MODULE_EMPTY_INPUT_SPEC);
List<Module> modules = prog.getModules();
assertEquals(1, modules.size());
Module module = modules.get(0);
assertEquals("someModule", module.getName());
Set<Predicate> inputSpec = module.getInputSpec();
assertTrue(inputSpec.isEmpty());
Set<Predicate> outputSpec = module.getOutputSpec();
assertTrue(outputSpec.isEmpty());
}

@Test
public void multipleInputs() {
InputProgram prog = parser.parse(MODULE_MULTIPLE_INPUTS);
List<Module> modules = prog.getModules();
assertEquals(1, modules.size());
Module module = modules.get(0);
assertEquals("someModule", module.getName());
Set<Predicate> inputSpec = module.getInputSpec();
assertEquals(2, inputSpec.size());
assertTrue(inputSpec.contains(Predicates.getPredicate("input", 1)));
assertTrue(inputSpec.contains(Predicates.getPredicate("input2", 2)));
Set<Predicate> outputSpec = module.getOutputSpec();
assertEquals(1, outputSpec.size());
assertTrue(outputSpec.contains(Predicates.getPredicate("out1", 2)));
parser.parse("#module mod(foo/1 => {*}) { #test test(expect: 1) { given { b. } assertForAll { :- a. } } }"));
}

@Test
Expand Down

0 comments on commit 60b819f

Please sign in to comment.