diff --git a/subprojects/interpreter-rete/src/main/java/tools/refinery/interpreter/rete/aggregation/ColumnAggregatorNode.java b/subprojects/interpreter-rete/src/main/java/tools/refinery/interpreter/rete/aggregation/ColumnAggregatorNode.java index d5f4a0e4a..a055cc809 100644 --- a/subprojects/interpreter-rete/src/main/java/tools/refinery/interpreter/rete/aggregation/ColumnAggregatorNode.java +++ b/subprojects/interpreter-rete/src/main/java/tools/refinery/interpreter/rete/aggregation/ColumnAggregatorNode.java @@ -30,6 +30,7 @@ import java.util.Collection; import java.util.Map; import java.util.Map.Entry; +import java.util.Optional; /** * Timeless implementation of the column aggregator node. @@ -300,7 +301,7 @@ private void propagateBatchUpdate(Collection> updates, Tim if (updates.isEmpty()) { return; } - var oldValues = CollectionsFactory.createMap(); + var oldValues = CollectionsFactory.>createMap(); for (var entry : updates) { var update = entry.getKey(); var key = groupMask.transform(update); @@ -316,7 +317,8 @@ private void propagateBatchUpdate(Collection> updates, Tim var oldMainAccumulator = memory.get(key); oldValues.computeIfAbsent(key, ignoredKey -> - oldMainAccumulator == null ? NEUTRAL : operator.getAggregate(oldMainAccumulator)); + Optional.ofNullable(oldMainAccumulator == null ? NEUTRAL : + operator.getAggregate(oldMainAccumulator))); Accumulator newMainAccumulator = oldMainAccumulator == null ? operator.createNeutral() : oldMainAccumulator; for (int i = 0; i < count; i++) { @@ -329,7 +331,7 @@ private void propagateBatchUpdate(Collection> updates, Tim var oldValue = entry.getValue(); var newMainAccumulator = getMainAccumulator(key); var newValue = operator.getAggregate(newMainAccumulator); - propagateAggregateResultUpdate(key, oldValue, newValue, timestamp); + propagateAggregateResultUpdate(key, oldValue.orElse(null), newValue, timestamp); } } diff --git a/subprojects/store-query-interpreter/src/test/java/tools/refinery/store/query/interpreter/AggregatorBatchingTest.java b/subprojects/store-query-interpreter/src/test/java/tools/refinery/store/query/interpreter/AggregatorBatchingTest.java index 0bdd21489..2153d9e72 100644 --- a/subprojects/store-query-interpreter/src/test/java/tools/refinery/store/query/interpreter/AggregatorBatchingTest.java +++ b/subprojects/store-query-interpreter/src/test/java/tools/refinery/store/query/interpreter/AggregatorBatchingTest.java @@ -20,8 +20,7 @@ import tools.refinery.store.representation.Symbol; import tools.refinery.store.tuple.Tuple; -import java.util.Map; -import java.util.Optional; +import java.util.*; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.is; @@ -38,6 +37,11 @@ class AggregatorBatchingTest { personView.call(p1), output.assign(valuesView.aggregate(new InstrumentedAggregator(), p1, Variable.of())) )); + private final Query queryMax = Query.of(Integer.class, (builder, p1, output) -> builder + .clause( + personView.call(p1), + output.assign(valuesView.aggregate(new InstrumentedAggregatorMax(), p1, Variable.of())) + )); private int extractCount = 0; @@ -48,6 +52,7 @@ void batchTest() { var valuesInterpretation = model.getInterpretation(values); var queryEngine = model.getAdapter(ModelQueryAdapter.class); var resultSet = queryEngine.getResultSet(query); + var resultSetMax = queryEngine.getResultSet(queryMax); assertThat(extractCount, is(1)); @@ -69,6 +74,11 @@ void batchTest() { Tuple.of(1), Optional.of(0), Tuple.of(2), Optional.empty() ), resultSet); + assertNullableResults(Map.of( + Tuple.of(0), Optional.of(3), + Tuple.of(1), Optional.of(1), + Tuple.of(2), Optional.empty() + ), resultSetMax); } @Test @@ -124,7 +134,7 @@ private Model createModel() { var store = ModelStore.builder() .symbols(person, values) .with(QueryInterpreterAdapter.builder() - .query(query)) + .queries(query, queryMax)) .build(); return store.createEmptyModel(); } @@ -185,4 +195,61 @@ public StatefulAggregate deepCopy() { return new InstrumentedAggregate(sum); } } + + class InstrumentedAggregatorMax implements StatefulAggregator { + @Override + public Class getResultType() { + return Integer.class; + } + + @Override + public Class getInputType() { + return Integer.class; + } + + @Override + public StatefulAggregate createEmptyAggregate() { + return new InstrumentedAggregateMax(); + } + } + class InstrumentedAggregateMax implements StatefulAggregate { + private final List numbers; + + public InstrumentedAggregateMax() { + this.numbers = new ArrayList<>(); + } + public InstrumentedAggregateMax(List numbers) { + this.numbers = new ArrayList<>(); + this.numbers.addAll(numbers); + } + + @Override + public void add(Integer value) { + numbers.add(value); + } + + @Override + public void remove(Integer value) { + numbers.remove(value); + } + + @Override + public Integer getResult() { + if(numbers.isEmpty()){ + return null; + } else { + return Collections.max(numbers); + } + } + + @Override + public boolean isEmpty() { + return numbers.isEmpty(); + } + + @Override + public StatefulAggregate deepCopy() { + return new InstrumentedAggregateMax(numbers); + } + } }