Skip to content

Commit 3ee5b7f

Browse files
Louvain Rewrite write test
1 parent f50fb52 commit 3ee5b7f

File tree

1 file changed

+51
-97
lines changed

1 file changed

+51
-97
lines changed

proc/community/src/integrationTest/java/org/neo4j/gds/louvain/LouvainWriteProcTest.java

Lines changed: 51 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -36,61 +36,34 @@
3636
import org.neo4j.gds.functions.AsNodeFunc;
3737

3838
import java.util.HashMap;
39+
import java.util.HashSet;
3940
import java.util.List;
4041
import java.util.Map;
42+
import java.util.Optional;
4143
import java.util.stream.Stream;
4244

4345
import static org.assertj.core.api.Assertions.assertThat;
4446
import static org.assertj.core.api.InstanceOfAssertFactories.DOUBLE;
47+
import static org.assertj.core.api.InstanceOfAssertFactories.LIST;
4548
import static org.assertj.core.api.InstanceOfAssertFactories.LONG;
4649
import static org.assertj.core.api.InstanceOfAssertFactories.LONG_ARRAY;
50+
import static org.assertj.core.api.InstanceOfAssertFactories.SET;
4751

4852
class LouvainWriteProcTest extends BaseProcTest {
4953

5054
@Neo4jGraph
5155
private static final String DB_CYPHER =
5256
"CREATE" +
53-
" (a:Node {seed: 1})" + // 0
54-
", (b:Node {seed: 1})" + // 1
55-
", (c:Node {seed: 1})" + // 2
56-
", (d:Node {seed: 1})" + // 3
57-
", (e:Node {seed: 1})" + // 4
58-
", (f:Node {seed: 1})" + // 5
59-
", (g:Node {seed: 2})" + // 6
60-
", (h:Node {seed: 2})" + // 7
61-
", (i:Node {seed: 2})" + // 8
62-
", (j:Node {seed: 42})" + // 9
63-
", (k:Node {seed: 42})" + // 10
64-
", (l:Node {seed: 42})" + // 11
65-
", (m:Node {seed: 42})" + // 12
66-
", (n:Node {seed: 42})" + // 13
67-
", (x:Node {seed: 1})" + // 14
68-
69-
", (a)-[:TYPE {weight: 1.0}]->(b)" +
70-
", (a)-[:TYPE {weight: 1.0}]->(d)" +
71-
", (a)-[:TYPE {weight: 1.0}]->(f)" +
72-
", (b)-[:TYPE {weight: 1.0}]->(d)" +
73-
", (b)-[:TYPE {weight: 1.0}]->(x)" +
74-
", (b)-[:TYPE {weight: 1.0}]->(g)" +
75-
", (b)-[:TYPE {weight: 1.0}]->(e)" +
76-
", (c)-[:TYPE {weight: 1.0}]->(x)" +
77-
", (c)-[:TYPE {weight: 1.0}]->(f)" +
78-
", (d)-[:TYPE {weight: 1.0}]->(k)" +
79-
", (e)-[:TYPE {weight: 1.0}]->(x)" +
80-
", (e)-[:TYPE {weight: 0.01}]->(f)" +
81-
", (e)-[:TYPE {weight: 1.0}]->(h)" +
82-
", (f)-[:TYPE {weight: 1.0}]->(g)" +
83-
", (g)-[:TYPE {weight: 1.0}]->(h)" +
84-
", (h)-[:TYPE {weight: 1.0}]->(i)" +
85-
", (h)-[:TYPE {weight: 1.0}]->(j)" +
86-
", (i)-[:TYPE {weight: 1.0}]->(k)" +
87-
", (j)-[:TYPE {weight: 1.0}]->(k)" +
88-
", (j)-[:TYPE {weight: 1.0}]->(m)" +
89-
", (j)-[:TYPE {weight: 1.0}]->(n)" +
90-
", (k)-[:TYPE {weight: 1.0}]->(m)" +
91-
", (k)-[:TYPE {weight: 1.0}]->(l)" +
92-
", (l)-[:TYPE {weight: 1.0}]->(n)" +
93-
", (m)-[:TYPE {weight: 1.0}]->(n)";
57+
" (b0:Node {seed: 111})" +
58+
", (b1:Node {seed: 111})" +
59+
", (a0:Node {seed: 222})" +
60+
", (a1:Node {seed: 222})" +
61+
", (a2:Node {seed: 222})" +
62+
", (b0)-[:TYPE]->(b1)"+
63+
", (a0)-[:TYPE]->(a1)" +
64+
", (a0)-[:TYPE]->(a2)" +
65+
", (a1)-[:TYPE]->(a2)";
66+
9467

9568
@Inject
9669
private IdFunction idFunction;
@@ -121,24 +94,30 @@ void tearDown() {
12194
void testWrite() {
12295

12396
var query = "CALL gds.louvain.write('myGraph', { writeProperty: 'myFancyCommunity'})" +
124-
" YIELD communityCount, modularity, modularities, ranLevels, preProcessingMillis, " +
97+
" YIELD communityCount, modularity, modularities, ranLevels, preProcessingMillis, nodePropertiesWritten," +
12598
" computeMillis, writeMillis, postProcessingMillis, communityDistribution, configuration";
12699

127100
runQueryWithRowConsumer(query, row -> {
101+
102+
assertThat(row.getNumber("nodePropertiesWritten"))
103+
.asInstanceOf(LONG)
104+
.as("wrong node props written ")
105+
.isEqualTo(5L);
106+
128107
assertThat(row.getNumber("communityCount"))
129108
.asInstanceOf(LONG)
130109
.as("wrong community count")
131-
.isEqualTo(3L);
110+
.isEqualTo(2L);
132111

133112
assertThat(row.get("modularities"))
134-
.asList()
135-
.as("invalud modularities")
136-
.hasSize(2);
113+
.asInstanceOf(LIST)
114+
.as("invalid modularities")
115+
.hasSize(1);
137116

138117
assertThat(row.getNumber("ranLevels"))
139118
.asInstanceOf(LONG)
140119
.as("invalid level count")
141-
.isEqualTo(2L);
120+
.isEqualTo(1L);
142121

143122
assertUserInput(row, "includeIntermediateCommunities", false);
144123

@@ -168,12 +147,10 @@ void testWrite() {
168147
long nodeId = row.getNumber("id").longValue();
169148
actualCommunities.put(nodeId, row.getNumber("community").longValue());
170149
});
171-
172150
CommunityHelper.assertCommunities(
173151
actualCommunities,
174-
idFunction.of("a", "b", "c", "d", "e", "f", "x"),
175-
idFunction.of("g", "h", "i"),
176-
idFunction.of("j", "k", "l", "m", "n")
152+
idFunction.of("a0","a1","a2"),
153+
idFunction.of("b0","b1")
177154
);
178155
}
179156

@@ -190,40 +167,23 @@ void testWriteIntermediateCommunities() {
190167
runQueryWithRowConsumer("MATCH (n) RETURN n.myFancyCommunity as myFancyCommunity", row -> {
191168
assertThat(row.get("myFancyCommunity"))
192169
.asInstanceOf(LONG_ARRAY)
193-
.hasSize(2);
170+
.hasSize(1);
194171
});
195172
}
196173

197174
@Test
198175
void testWriteWithSeeding() {
199176
var query = "CALL gds.louvain.write('myGraph', { writeProperty: 'myFancyWriteProperty', seedProperty: 'seed'})" +
200177
" YIELD communityCount, ranLevels";
201-
runQueryWithRowConsumer(
202-
query,
203-
row -> {
204-
assertThat(row.getNumber("communityCount"))
205-
.asInstanceOf(LONG)
206-
.as("wrong community count")
207-
.isEqualTo(3L);
208-
assertThat(row.getNumber("ranLevels"))
209-
.asInstanceOf(LONG)
210-
.as("wrong number of levels")
211-
.isEqualTo(1L);
212-
}
213-
);
178+
runQuery(query);
214179

215-
Map<Long, Long> actualCommunities = new HashMap<>();
180+
var actualCommunities = new HashSet<>();
216181
runQueryWithRowConsumer("MATCH (n) RETURN id(n) as id, n.myFancyWriteProperty as community", row -> {
217-
long nodeId = row.getNumber("id").longValue();
218-
actualCommunities.put(nodeId, row.getNumber("community").longValue());
182+
actualCommunities.add(row.getNumber("community").longValue());
219183
});
220184

221-
CommunityHelper.assertCommunities(
222-
actualCommunities,
223-
idFunction.of("a", "b", "c", "d", "e", "f", "x"),
224-
idFunction.of("g", "h", "i"),
225-
idFunction.of("j", "k", "l", "m", "n")
226-
);
185+
186+
assertThat(actualCommunities).containsExactlyInAnyOrderElementsOf(List.of(111L,222L));
227187
}
228188

229189

@@ -252,46 +212,40 @@ void zeroCommunitiesInEmptyGraph() {
252212
static Stream<Arguments> communitySizeInputs() {
253213
return Stream.of(
254214
// configuration | expectedCommunityCount | expectedCommunityIds
255-
Arguments.of(Map.of("minCommunitySize", 1), 3L, List.of(11L, 13L, 14L)),
256-
Arguments.of(Map.of("minCommunitySize", 1, "consecutiveIds", true), 3L, List.of(0L, 1L, 2L)),
257-
Arguments.of(Map.of("minCommunitySize", 1, "seedProperty", "seed"), 3L, List.of(1L, 2L, 42L)),
258-
Arguments.of(Map.of("minCommunitySize", 3, "seedProperty", "seed"), 3L, List.of(2L, 42L))
215+
Arguments.of(Map.of("minCommunitySize", 3), 1, Optional.empty()),
216+
Arguments.of(Map.of("minCommunitySize", 3, "consecutiveIds", true), 1, Optional.of(List.of(0L))),
217+
Arguments.of(Map.of("minCommunitySize", 1, "seedProperty", "seed"), 2,Optional.of(List.of(111L,222L))),
218+
Arguments.of(Map.of("minCommunitySize", 3, "seedProperty", "seed"), 1,Optional.of(List.of(222L)))
259219
);
260220
}
261221

262222
@ParameterizedTest
263223
@MethodSource("communitySizeInputs")
264-
void testWriteWithMinCommunitySize(Map<String, Object> parameters, long expectedCommunityCount, List<Long> expectedCommunityIds) {
265-
var createQuery = GdsCypher.call(DEFAULT_GRAPH_NAME)
266-
.graphProject()
267-
.withAnyLabel()
268-
.withAnyRelationshipType()
269-
.withNodeProperty("seed")
270-
.yields();
271-
runQuery(createQuery);
224+
void testWriteWithMinCommunitySize(Map<String, Object> parameters, int expectedSize,Optional<List<Long>> expectedCommunityIds) {
272225

273226
var query = GdsCypher
274-
.call(DEFAULT_GRAPH_NAME)
227+
.call("myGraph")
275228
.algo("louvain")
276229
.writeMode()
277230
.addParameter("writeProperty", "writeProperty")
278231
.addParameter("concurrency", 1)
279232
.addAllParameters(parameters)
280-
.yields("communityCount");
233+
.yields();
281234

282-
runQueryWithRowConsumer(query, row -> {
283-
assertThat(row.getNumber("communityCount"))
284-
.asInstanceOf(LONG)
285-
.isEqualTo(expectedCommunityCount);
286-
});
235+
runQuery(query);
287236

237+
var hashSet=new HashSet<Long>();
288238
runQueryWithRowConsumer(
289-
"MATCH (n) RETURN collect(DISTINCT n.writeProperty) AS communities ",
239+
"MATCH (n) WHERE n.writeProperty IS NOT NULL RETURN n.writeProperty AS community ",
290240
row -> {
291-
assertThat(row.get("communities"))
292-
.asList()
293-
.containsExactlyInAnyOrderElementsOf(expectedCommunityIds);
241+
hashSet.add(row.getNumber("community").longValue());
294242
}
295243
);
244+
assertThat(hashSet).satisfies(
245+
set-> {
246+
assertThat(set).hasSize(expectedSize);
247+
expectedCommunityIds
248+
.ifPresent(exactIds -> assertThat(set).asInstanceOf(SET).containsAll(exactIds));
249+
});
296250
}
297251
}

0 commit comments

Comments
 (0)