3636import org .neo4j .gds .functions .AsNodeFunc ;
3737
3838import java .util .HashMap ;
39+ import java .util .HashSet ;
3940import java .util .List ;
4041import java .util .Map ;
42+ import java .util .Optional ;
4143import java .util .stream .Stream ;
4244
4345import static org .assertj .core .api .Assertions .assertThat ;
4446import static org .assertj .core .api .InstanceOfAssertFactories .DOUBLE ;
47+ import static org .assertj .core .api .InstanceOfAssertFactories .LIST ;
4548import static org .assertj .core .api .InstanceOfAssertFactories .LONG ;
4649import static org .assertj .core .api .InstanceOfAssertFactories .LONG_ARRAY ;
50+ import static org .assertj .core .api .InstanceOfAssertFactories .SET ;
4751
4852class LouvainWriteProcTest extends BaseProcTest {
4953
5054 @ Neo4jGraph
5155 private static final String DB_CYPHER =
5256 "CREATE" +
53- " (a:Node {seed: 1})" + // 0
54- ", (b:Node {seed: 1})" + // 1
55- ", (c:Node {seed: 1})" + // 2
56- ", (d:Node {seed: 1})" + // 3
57- ", (e:Node {seed: 1})" + // 4
58- ", (f:Node {seed: 1})" + // 5
59- ", (g:Node {seed: 2})" + // 6
60- ", (h:Node {seed: 2})" + // 7
61- ", (i:Node {seed: 2})" + // 8
62- ", (j:Node {seed: 42})" + // 9
63- ", (k:Node {seed: 42})" + // 10
64- ", (l:Node {seed: 42})" + // 11
65- ", (m:Node {seed: 42})" + // 12
66- ", (n:Node {seed: 42})" + // 13
67- ", (x:Node {seed: 1})" + // 14
68-
69- ", (a)-[:TYPE {weight: 1.0}]->(b)" +
70- ", (a)-[:TYPE {weight: 1.0}]->(d)" +
71- ", (a)-[:TYPE {weight: 1.0}]->(f)" +
72- ", (b)-[:TYPE {weight: 1.0}]->(d)" +
73- ", (b)-[:TYPE {weight: 1.0}]->(x)" +
74- ", (b)-[:TYPE {weight: 1.0}]->(g)" +
75- ", (b)-[:TYPE {weight: 1.0}]->(e)" +
76- ", (c)-[:TYPE {weight: 1.0}]->(x)" +
77- ", (c)-[:TYPE {weight: 1.0}]->(f)" +
78- ", (d)-[:TYPE {weight: 1.0}]->(k)" +
79- ", (e)-[:TYPE {weight: 1.0}]->(x)" +
80- ", (e)-[:TYPE {weight: 0.01}]->(f)" +
81- ", (e)-[:TYPE {weight: 1.0}]->(h)" +
82- ", (f)-[:TYPE {weight: 1.0}]->(g)" +
83- ", (g)-[:TYPE {weight: 1.0}]->(h)" +
84- ", (h)-[:TYPE {weight: 1.0}]->(i)" +
85- ", (h)-[:TYPE {weight: 1.0}]->(j)" +
86- ", (i)-[:TYPE {weight: 1.0}]->(k)" +
87- ", (j)-[:TYPE {weight: 1.0}]->(k)" +
88- ", (j)-[:TYPE {weight: 1.0}]->(m)" +
89- ", (j)-[:TYPE {weight: 1.0}]->(n)" +
90- ", (k)-[:TYPE {weight: 1.0}]->(m)" +
91- ", (k)-[:TYPE {weight: 1.0}]->(l)" +
92- ", (l)-[:TYPE {weight: 1.0}]->(n)" +
93- ", (m)-[:TYPE {weight: 1.0}]->(n)" ;
57+ " (b0:Node {seed: 111})" +
58+ ", (b1:Node {seed: 111})" +
59+ ", (a0:Node {seed: 222})" +
60+ ", (a1:Node {seed: 222})" +
61+ ", (a2:Node {seed: 222})" +
62+ ", (b0)-[:TYPE]->(b1)" +
63+ ", (a0)-[:TYPE]->(a1)" +
64+ ", (a0)-[:TYPE]->(a2)" +
65+ ", (a1)-[:TYPE]->(a2)" ;
66+
9467
9568 @ Inject
9669 private IdFunction idFunction ;
@@ -121,24 +94,30 @@ void tearDown() {
12194 void testWrite () {
12295
12396 var query = "CALL gds.louvain.write('myGraph', { writeProperty: 'myFancyCommunity'})" +
124- " YIELD communityCount, modularity, modularities, ranLevels, preProcessingMillis, " +
97+ " YIELD communityCount, modularity, modularities, ranLevels, preProcessingMillis, nodePropertiesWritten, " +
12598 " computeMillis, writeMillis, postProcessingMillis, communityDistribution, configuration" ;
12699
127100 runQueryWithRowConsumer (query , row -> {
101+
102+ assertThat (row .getNumber ("nodePropertiesWritten" ))
103+ .asInstanceOf (LONG )
104+ .as ("wrong node props written " )
105+ .isEqualTo (5L );
106+
128107 assertThat (row .getNumber ("communityCount" ))
129108 .asInstanceOf (LONG )
130109 .as ("wrong community count" )
131- .isEqualTo (3L );
110+ .isEqualTo (2L );
132111
133112 assertThat (row .get ("modularities" ))
134- .asList ( )
135- .as ("invalud modularities" )
136- .hasSize (2 );
113+ .asInstanceOf ( LIST )
114+ .as ("invalid modularities" )
115+ .hasSize (1 );
137116
138117 assertThat (row .getNumber ("ranLevels" ))
139118 .asInstanceOf (LONG )
140119 .as ("invalid level count" )
141- .isEqualTo (2L );
120+ .isEqualTo (1L );
142121
143122 assertUserInput (row , "includeIntermediateCommunities" , false );
144123
@@ -168,12 +147,10 @@ void testWrite() {
168147 long nodeId = row .getNumber ("id" ).longValue ();
169148 actualCommunities .put (nodeId , row .getNumber ("community" ).longValue ());
170149 });
171-
172150 CommunityHelper .assertCommunities (
173151 actualCommunities ,
174- idFunction .of ("a" , "b" , "c" , "d" , "e" , "f" , "x" ),
175- idFunction .of ("g" , "h" , "i" ),
176- idFunction .of ("j" , "k" , "l" , "m" , "n" )
152+ idFunction .of ("a0" ,"a1" ,"a2" ),
153+ idFunction .of ("b0" ,"b1" )
177154 );
178155 }
179156
@@ -190,40 +167,23 @@ void testWriteIntermediateCommunities() {
190167 runQueryWithRowConsumer ("MATCH (n) RETURN n.myFancyCommunity as myFancyCommunity" , row -> {
191168 assertThat (row .get ("myFancyCommunity" ))
192169 .asInstanceOf (LONG_ARRAY )
193- .hasSize (2 );
170+ .hasSize (1 );
194171 });
195172 }
196173
197174 @ Test
198175 void testWriteWithSeeding () {
199176 var query = "CALL gds.louvain.write('myGraph', { writeProperty: 'myFancyWriteProperty', seedProperty: 'seed'})" +
200177 " YIELD communityCount, ranLevels" ;
201- runQueryWithRowConsumer (
202- query ,
203- row -> {
204- assertThat (row .getNumber ("communityCount" ))
205- .asInstanceOf (LONG )
206- .as ("wrong community count" )
207- .isEqualTo (3L );
208- assertThat (row .getNumber ("ranLevels" ))
209- .asInstanceOf (LONG )
210- .as ("wrong number of levels" )
211- .isEqualTo (1L );
212- }
213- );
178+ runQuery (query );
214179
215- Map < Long , Long > actualCommunities = new HashMap <>();
180+ var actualCommunities = new HashSet <>();
216181 runQueryWithRowConsumer ("MATCH (n) RETURN id(n) as id, n.myFancyWriteProperty as community" , row -> {
217- long nodeId = row .getNumber ("id" ).longValue ();
218- actualCommunities .put (nodeId , row .getNumber ("community" ).longValue ());
182+ actualCommunities .add (row .getNumber ("community" ).longValue ());
219183 });
220184
221- CommunityHelper .assertCommunities (
222- actualCommunities ,
223- idFunction .of ("a" , "b" , "c" , "d" , "e" , "f" , "x" ),
224- idFunction .of ("g" , "h" , "i" ),
225- idFunction .of ("j" , "k" , "l" , "m" , "n" )
226- );
185+
186+ assertThat (actualCommunities ).containsExactlyInAnyOrderElementsOf (List .of (111L ,222L ));
227187 }
228188
229189
@@ -252,46 +212,40 @@ void zeroCommunitiesInEmptyGraph() {
252212 static Stream <Arguments > communitySizeInputs () {
253213 return Stream .of (
254214 // configuration | expectedCommunityCount | expectedCommunityIds
255- Arguments .of (Map .of ("minCommunitySize" , 1 ), 3L , List . of ( 11L , 13L , 14L )),
256- Arguments .of (Map .of ("minCommunitySize" , 1 , "consecutiveIds" , true ), 3L , List .of (0L , 1L , 2L )),
257- Arguments .of (Map .of ("minCommunitySize" , 1 , "seedProperty" , "seed" ), 3L , List .of (1L , 2L , 42L )),
258- Arguments .of (Map .of ("minCommunitySize" , 3 , "seedProperty" , "seed" ), 3L , List .of (2L , 42L ))
215+ Arguments .of (Map .of ("minCommunitySize" , 3 ), 1 , Optional . empty ( )),
216+ Arguments .of (Map .of ("minCommunitySize" , 3 , "consecutiveIds" , true ), 1 , Optional . of ( List .of (0L ) )),
217+ Arguments .of (Map .of ("minCommunitySize" , 1 , "seedProperty" , "seed" ), 2 , Optional . of ( List .of (111L , 222L ) )),
218+ Arguments .of (Map .of ("minCommunitySize" , 3 , "seedProperty" , "seed" ), 1 , Optional . of ( List .of (222L ) ))
259219 );
260220 }
261221
262222 @ ParameterizedTest
263223 @ MethodSource ("communitySizeInputs" )
264- void testWriteWithMinCommunitySize (Map <String , Object > parameters , long expectedCommunityCount , List <Long > expectedCommunityIds ) {
265- var createQuery = GdsCypher .call (DEFAULT_GRAPH_NAME )
266- .graphProject ()
267- .withAnyLabel ()
268- .withAnyRelationshipType ()
269- .withNodeProperty ("seed" )
270- .yields ();
271- runQuery (createQuery );
224+ void testWriteWithMinCommunitySize (Map <String , Object > parameters , int expectedSize ,Optional <List <Long >> expectedCommunityIds ) {
272225
273226 var query = GdsCypher
274- .call (DEFAULT_GRAPH_NAME )
227+ .call ("myGraph" )
275228 .algo ("louvain" )
276229 .writeMode ()
277230 .addParameter ("writeProperty" , "writeProperty" )
278231 .addParameter ("concurrency" , 1 )
279232 .addAllParameters (parameters )
280- .yields ("communityCount" );
233+ .yields ();
281234
282- runQueryWithRowConsumer (query , row -> {
283- assertThat (row .getNumber ("communityCount" ))
284- .asInstanceOf (LONG )
285- .isEqualTo (expectedCommunityCount );
286- });
235+ runQuery (query );
287236
237+ var hashSet =new HashSet <Long >();
288238 runQueryWithRowConsumer (
289- "MATCH (n) RETURN collect(DISTINCT n.writeProperty) AS communities " ,
239+ "MATCH (n) WHERE n.writeProperty IS NOT NULL RETURN n.writeProperty AS community " ,
290240 row -> {
291- assertThat (row .get ("communities" ))
292- .asList ()
293- .containsExactlyInAnyOrderElementsOf (expectedCommunityIds );
241+ hashSet .add (row .getNumber ("community" ).longValue ());
294242 }
295243 );
244+ assertThat (hashSet ).satisfies (
245+ set -> {
246+ assertThat (set ).hasSize (expectedSize );
247+ expectedCommunityIds
248+ .ifPresent (exactIds -> assertThat (set ).asInstanceOf (SET ).containsAll (exactIds ));
249+ });
296250 }
297251}
0 commit comments