diff --git a/alpha-api/src/main/java/at/ac/tuwien/kr/alpha/api/programs/modules/Module.java b/alpha-api/src/main/java/at/ac/tuwien/kr/alpha/api/programs/modules/Module.java index 10b268d10..d43bdce6c 100644 --- a/alpha-api/src/main/java/at/ac/tuwien/kr/alpha/api/programs/modules/Module.java +++ b/alpha-api/src/main/java/at/ac/tuwien/kr/alpha/api/programs/modules/Module.java @@ -9,7 +9,7 @@ public interface Module { String getName(); - Set getInputSpec(); + Predicate getInputSpec(); Set getOutputSpec(); diff --git a/alpha-commons/src/main/java/at/ac/tuwien/kr/alpha/commons/programs/modules/ModuleImpl.java b/alpha-commons/src/main/java/at/ac/tuwien/kr/alpha/commons/programs/modules/ModuleImpl.java index 06dbe385f..be4fd2d31 100644 --- a/alpha-commons/src/main/java/at/ac/tuwien/kr/alpha/commons/programs/modules/ModuleImpl.java +++ b/alpha-commons/src/main/java/at/ac/tuwien/kr/alpha/commons/programs/modules/ModuleImpl.java @@ -9,11 +9,11 @@ class ModuleImpl implements Module { private final String name; - private final Set inputSpec; + private final Predicate inputSpec; private final Set outputSpec; private final InputProgram implementation; - ModuleImpl(String name, Set inputSpec, Set outputSpec, InputProgram implementation) { + ModuleImpl(String name, Predicate inputSpec, Set outputSpec, InputProgram implementation) { this.name = name; this.inputSpec = inputSpec; this.outputSpec = outputSpec; @@ -26,7 +26,7 @@ public String getName() { } @Override - public Set getInputSpec() { + public Predicate getInputSpec() { return this.inputSpec; } diff --git a/alpha-commons/src/main/java/at/ac/tuwien/kr/alpha/commons/programs/modules/Modules.java b/alpha-commons/src/main/java/at/ac/tuwien/kr/alpha/commons/programs/modules/Modules.java index 57dc7d2ef..0d8d04e0b 100644 --- a/alpha-commons/src/main/java/at/ac/tuwien/kr/alpha/commons/programs/modules/Modules.java +++ b/alpha-commons/src/main/java/at/ac/tuwien/kr/alpha/commons/programs/modules/Modules.java @@ -12,7 +12,7 @@ private Modules() { throw new AssertionError("Cannot instantiate utility class!"); } - public static Module newModule(final String name, final Set inputSpec, final Set outputSpec, final InputProgram implementation) { + public static Module newModule(final String name, final Predicate inputSpec, final Set outputSpec, final InputProgram implementation) { return new ModuleImpl(name, inputSpec, outputSpec, implementation); } diff --git a/alpha-core/src/main/antlr/at/ac/tuwien/kr/alpha/core/antlr/ASPCore2.g4 b/alpha-core/src/main/antlr/at/ac/tuwien/kr/alpha/core/antlr/ASPCore2.g4 index 38af994f0..876c4e190 100644 --- a/alpha-core/src/main/antlr/at/ac/tuwien/kr/alpha/core/antlr/ASPCore2.g4 +++ b/alpha-core/src/main/antlr/at/ac/tuwien/kr/alpha/core/antlr/ASPCore2.g4 @@ -121,5 +121,5 @@ test_assert_all : TEST_ASSERT_ALL CURLY_OPEN statements? CURLY_CLOSE; test_assert_some : TEST_ASSERT_SOME CURLY_OPEN statements? CURLY_CLOSE; -module_signature : CURLY_OPEN predicate_specs? CURLY_CLOSE ARROW CURLY_OPEN ('*' | predicate_specs) CURLY_CLOSE; +module_signature : predicate_spec ARROW CURLY_OPEN ('*' | predicate_specs) CURLY_CLOSE; diff --git a/alpha-core/src/main/java/at/ac/tuwien/kr/alpha/core/parser/ParseTreeVisitor.java b/alpha-core/src/main/java/at/ac/tuwien/kr/alpha/core/parser/ParseTreeVisitor.java index 4f6e05561..16226a168 100644 --- a/alpha-core/src/main/java/at/ac/tuwien/kr/alpha/core/parser/ParseTreeVisitor.java +++ b/alpha-core/src/main/java/at/ac/tuwien/kr/alpha/core/parser/ParseTreeVisitor.java @@ -335,7 +335,7 @@ public Object visitDirective_module(ASPCore2Parser.Directive_moduleContext ctx) } // directive_module: SHARP DIRECTIVE_MODULE id PAREN_OPEN module_signature PAREN_CLOSE CURLY_OPEN statements CURLY_CLOSE; String name = visitId(ctx.id()); - ImmutablePair, Set> moduleSignature = visitModule_signature(ctx.module_signature()); + ImmutablePair> moduleSignature = visitModule_signature(ctx.module_signature()); startNestedProgram(); visitStatements(ctx.statements()); InputProgram moduleImplementation = endNestedProgram(); @@ -343,10 +343,10 @@ public Object visitDirective_module(ASPCore2Parser.Directive_moduleContext ctx) return null; } - public ImmutablePair, Set> visitModule_signature(ASPCore2Parser.Module_signatureContext ctx) { - Set inputPredicates = ctx.predicate_specs(0) != null ? visitPredicate_specs(ctx.predicate_specs(0)) : Collections.emptySet(); - Set outputPredicates = ctx.predicate_specs(1) != null ? visitPredicate_specs(ctx.predicate_specs(1)) : Collections.emptySet(); - return ImmutablePair.of(inputPredicates, outputPredicates); + public ImmutablePair> visitModule_signature(ASPCore2Parser.Module_signatureContext ctx) { + Predicate inputPredicate = visitPredicate_spec(ctx.predicate_spec()); + Set outputPredicates = ctx.predicate_specs() != null ? visitPredicate_specs(ctx.predicate_specs()) : Collections.emptySet(); + return ImmutablePair.of(inputPredicate, outputPredicates); } @Override diff --git a/alpha-core/src/main/java/at/ac/tuwien/kr/alpha/core/programs/transformation/ModuleLinker.java b/alpha-core/src/main/java/at/ac/tuwien/kr/alpha/core/programs/transformation/ModuleLinker.java new file mode 100644 index 000000000..8a5e6c867 --- /dev/null +++ b/alpha-core/src/main/java/at/ac/tuwien/kr/alpha/core/programs/transformation/ModuleLinker.java @@ -0,0 +1,73 @@ +package at.ac.tuwien.kr.alpha.core.programs.transformation; + +import at.ac.tuwien.kr.alpha.api.Alpha; +import at.ac.tuwien.kr.alpha.api.programs.NormalProgram; +import at.ac.tuwien.kr.alpha.api.programs.atoms.ExternalAtom; +import at.ac.tuwien.kr.alpha.api.programs.atoms.ModuleAtom; +import at.ac.tuwien.kr.alpha.api.programs.literals.Literal; +import at.ac.tuwien.kr.alpha.api.programs.literals.ModuleLiteral; +import at.ac.tuwien.kr.alpha.api.programs.modules.Module; +import at.ac.tuwien.kr.alpha.api.programs.rules.NormalRule; +import at.ac.tuwien.kr.alpha.api.programs.rules.heads.NormalHead; +import at.ac.tuwien.kr.alpha.commons.programs.rules.Rules; + +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.Function; +import java.util.stream.Collectors; + +/** + * Program transformation that translates {@link at.ac.tuwien.kr.alpha.api.programs.literals.ModuleLiteral}s into + * {@link at.ac.tuwien.kr.alpha.api.programs.literals.ExternalLiteral}s by constructing {@link at.ac.tuwien.kr.alpha.api.programs.atoms.ExternalAtom}s + * which solve the ASP implementation of the module with the given inputs. + */ +public class ModuleLinker extends ProgramTransformation { + + // Note: References to a standard library of modules that are always available for linking should be member variables of a linker. + + private final Alpha moduleRunner; + + public ModuleLinker(Alpha moduleRunner) { + this.moduleRunner = moduleRunner; + } + + + @Override + public NormalProgram apply(NormalProgram inputProgram) { + Map moduleTable = inputProgram.getModules().stream().collect(Collectors.toMap(Module::getName, Function.identity())); + List transformedRules = inputProgram.getRules().stream() + .map(rule -> containsModuleAtom(rule) ? linkModuleAtoms(rule, moduleTable) : rule) + .collect(Collectors.toList()); + return null; + } + + private NormalRule linkModuleAtoms(NormalRule rule, Map moduleTable) { + NormalHead newHead = rule.getHead(); + Set newBody = rule.getBody().stream() + .map(literal -> { + if (literal instanceof ModuleLiteral) { + ModuleLiteral moduleLiteral = (ModuleLiteral) literal; + return translateModuleAtom(moduleLiteral.getAtom(), moduleTable).toLiteral(!moduleLiteral.isNegated()); + } else { + return literal; + } + }) + .collect(Collectors.toSet()); + return Rules.newNormalRule(newHead, newBody); + } + + private ExternalAtom translateModuleAtom(ModuleAtom moduleAtom, Map moduleTable) { + if (!moduleTable.containsKey(moduleAtom.getModuleName())) { + throw new IllegalArgumentException("Module " + moduleAtom.getModuleName() + " not found in module table."); + } + Module implementationModule = moduleTable.get(moduleAtom.getModuleName()); + //implementationModule. + return null; + } + + private static boolean containsModuleAtom(NormalRule rule) { + return rule.getBody().stream().anyMatch(literal -> literal instanceof ModuleLiteral); + } + +} diff --git a/alpha-core/src/test/java/at/ac/tuwien/kr/alpha/core/parser/ParserTest.java b/alpha-core/src/test/java/at/ac/tuwien/kr/alpha/core/parser/ParserTest.java index 13e63878c..624658a64 100644 --- a/alpha-core/src/test/java/at/ac/tuwien/kr/alpha/core/parser/ParserTest.java +++ b/alpha-core/src/test/java/at/ac/tuwien/kr/alpha/core/parser/ParserTest.java @@ -69,10 +69,10 @@ public class ParserTest { private static final String UNIT_TEST_EXPECT_UNSAT = "p(1). p(2). " - + ":- p(X), p(Y), X + Y = 3." - + "#test expected_unsat(expect: unsat) {" - + "given {}" - + "}"; + + ":- p(X), p(Y), X + Y = 3." + + "#test expected_unsat(expect: unsat) {" + + "given {}" + + "}"; private static final String UNIT_TEST_BASIC_TEST = "a :- b. #test ensure_a(expect: 1) { given { b. } assertForAll { :- not a. } }"; @@ -86,17 +86,13 @@ public class ParserTest { private static final String UNIT_TEST_KEYWORDS_AS_IDS = "assert(a) :- given(b). # test test(expect: 1) { given { given(b). } assertForAll { :- not assert(a). :- assertForSome(b).}}"; - private static final String MODULE_SIMPLE = "#module aSimpleModule({input/1} => {out1/2, out2/3}) { p(a). p(b). q(X) :- p(X). }"; + private static final String MODULE_SIMPLE = "#module aSimpleModule(input/1 => {out1/2, out2/3}) { p(a). p(b). q(X) :- p(X). }"; - private static final String MODULE_OUTPUT_ALL = "#module mod({in/1} => {*}) { a(X). b(X) :- a(X).}"; + private static final String MODULE_OUTPUT_ALL = "#module mod(in/1 => {*}) { a(X). b(X) :- a(X).}"; - private static final String MODULE_WITH_REGULAR_STMTS = "p(a). p(b). q(X) :- p(X). #module aSimpleModule({input/1} => {out1/2, out2/3}) { p(a). p(b). q(X) :- p(X). }"; + private static final String MODULE_WITH_REGULAR_STMTS = "p(a). p(b). q(X) :- p(X). #module aSimpleModule(input/1 => {out1/2, out2/3}) { p(a). p(b). q(X) :- p(X). }"; - private static final String MODULE_MULTIPLE_DEFINITIONS = "a. b(5). #module aSimpleModule({input/1} => {out1/2, out2/3}) { p(a). p(b). q(X) :- p(X). } q(Y) :- r(S, Y), t(S). #module anotherModule({input/1} => {out1/2, out2/3}) { p(a). p(b). q(X) :- p(X). }"; - - private static final String MODULE_EMPTY_INPUT_SPEC = "#module someModule({} => {*}) {p(a).}"; - - private static final String MODULE_MULTIPLE_INPUTS = "#module someModule({input/1, input2/2} => {out1/2}) {p(a).}"; + private static final String MODULE_MULTIPLE_DEFINITIONS = "a. b(5). #module aSimpleModule(input/1 => {out1/2, out2/3}) { p(a). p(b). q(X) :- p(X). } q(Y) :- r(S, Y), t(S). #module anotherModule(input/1 => {out1/2, out2/3}) { p(a). p(b). q(X) :- p(X). }"; private static final String MODULE_LITERAL = "p(a). q(b). r(X) :- p(X), q(Y), #mod[X, Y](X)."; @@ -336,11 +332,9 @@ public void simpleModule() { assertEquals(1, modules.size()); Module module = modules.get(0); assertEquals("aSimpleModule", module.getName()); - Set inputSpec = module.getInputSpec(); - assertEquals(1, inputSpec.size()); - Predicate inputPredicate = inputSpec.iterator().next(); - assertEquals("input", inputPredicate.getName()); - assertEquals(1, inputPredicate.getArity()); + Predicate inputSpec = module.getInputSpec(); + assertEquals("input", inputSpec.getName()); + assertEquals(1, inputSpec.getArity()); Set outputSpec = module.getOutputSpec(); assertEquals(2, outputSpec.size()); assertTrue(outputSpec.contains(Predicates.getPredicate("out1", 2))); @@ -358,11 +352,9 @@ public void moduleOutputAll() { assertEquals(1, modules.size()); Module module = modules.get(0); assertEquals("mod", module.getName()); - Set inputSpec = module.getInputSpec(); - assertEquals(1, inputSpec.size()); - Predicate inputPredicate = inputSpec.iterator().next(); - assertEquals("in", inputPredicate.getName()); - assertEquals(1, inputPredicate.getArity()); + Predicate inputSpec = module.getInputSpec(); + assertEquals("in", inputSpec.getName()); + assertEquals(1, inputSpec.getArity()); assertTrue(module.getOutputSpec().isEmpty()); InputProgram implementation = module.getImplementation(); assertEquals(1, implementation.getFacts().size()); @@ -379,11 +371,9 @@ public void moduleAndRegularStmts() { assertEquals(1, modules.size()); Module module = modules.get(0); assertEquals("aSimpleModule", module.getName()); - Set inputSpec = module.getInputSpec(); - assertEquals(1, inputSpec.size()); - Predicate inputPredicate = inputSpec.iterator().next(); - assertEquals("input", inputPredicate.getName()); - assertEquals(1, inputPredicate.getArity()); + Predicate inputSpec = module.getInputSpec(); + assertEquals("input", inputSpec.getName()); + assertEquals(1, inputSpec.getArity()); Set outputSpec = module.getOutputSpec(); assertEquals(2, outputSpec.size()); assertTrue(outputSpec.contains(Predicates.getPredicate("out1", 2))); @@ -407,42 +397,13 @@ public void multipleModuleDefinitions() { @Test public void invalidNestedModule() { assertThrows(IllegalStateException.class, () -> - parser.parse("#module aSimpleModule({input/1} => {out1/2, out2/3}) { p(a). p(b). #module anotherModule({input/1} => {out1/2, out2/3}) { p(a). p(b). } }")); + parser.parse("#module aSimpleModule(input/1 => {out1/2, out2/3}) { p(a). p(b). #module anotherModule(input/1 => {out1/2, out2/3}) { p(a). p(b). } }")); } @Test public void invalidNestedTest() { assertThrows(IllegalStateException.class, () -> - parser.parse("#module mod({foo/1} => {*}) { #test test(expect: 1) { given { b. } assertForAll { :- a. } } }")); - } - - @Test - public void emptyInputSpec() { - InputProgram prog = parser.parse(MODULE_EMPTY_INPUT_SPEC); - List modules = prog.getModules(); - assertEquals(1, modules.size()); - Module module = modules.get(0); - assertEquals("someModule", module.getName()); - Set inputSpec = module.getInputSpec(); - assertTrue(inputSpec.isEmpty()); - Set outputSpec = module.getOutputSpec(); - assertTrue(outputSpec.isEmpty()); - } - - @Test - public void multipleInputs() { - InputProgram prog = parser.parse(MODULE_MULTIPLE_INPUTS); - List modules = prog.getModules(); - assertEquals(1, modules.size()); - Module module = modules.get(0); - assertEquals("someModule", module.getName()); - Set inputSpec = module.getInputSpec(); - assertEquals(2, inputSpec.size()); - assertTrue(inputSpec.contains(Predicates.getPredicate("input", 1))); - assertTrue(inputSpec.contains(Predicates.getPredicate("input2", 2))); - Set outputSpec = module.getOutputSpec(); - assertEquals(1, outputSpec.size()); - assertTrue(outputSpec.contains(Predicates.getPredicate("out1", 2))); + parser.parse("#module mod(foo/1 => {*}) { #test test(expect: 1) { given { b. } assertForAll { :- a. } } }")); } @Test