@@ -26,7 +26,6 @@
 import org.neo4j.gds.core.utils.partition.PartitionUtils;
 import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
 import org.neo4j.gds.mem.BitUtil;
-import org.neo4j.gds.ml.core.functions.Sigmoid;
 import org.neo4j.gds.ml.core.tensor.FloatVector;
 
 import java.util.ArrayList;
@@ -37,8 +36,6 @@
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.function.LongUnaryOperator;
 
-import static org.neo4j.gds.ml.core.tensor.operations.FloatVectorOperations.addInPlace;
-import static org.neo4j.gds.ml.core.tensor.operations.FloatVectorOperations.scale;
 import static org.neo4j.gds.utils.StringFormatting.formatWithLocale;
 
 public class Node2VecModel {
@@ -58,7 +55,7 @@ public class Node2VecModel {
     private final ProgressTracker progressTracker;
     private final long randomSeed;
 
-    private static final double EPSILON = 1e-10;
+    static final double EPSILON = 1e-10;
 
     Node2VecModel(
         LongUnaryOperator toOriginalId,
@@ -192,89 +189,6 @@ private HugeObjectArray<FloatVector> initializeEmbeddings(
         return embeddings;
     }
 
-    private static final class TrainingTask implements Runnable {
-        private final HugeObjectArray<FloatVector> centerEmbeddings;
-        private final HugeObjectArray<FloatVector> contextEmbeddings;
-
-        private final PositiveSampleProducer positiveSampleProducer;
-        private final NegativeSampleProducer negativeSampleProducer;
-        private final FloatVector centerGradientBuffer;
-        private final FloatVector contextGradientBuffer;
-        private final int negativeSamplingRate;
-        private final float learningRate;
-
-        private final ProgressTracker progressTracker;
-
-        private double lossSum;
-
-        private TrainingTask(
-            HugeObjectArray<FloatVector> centerEmbeddings,
-            HugeObjectArray<FloatVector> contextEmbeddings,
-            PositiveSampleProducer positiveSampleProducer,
-            NegativeSampleProducer negativeSampleProducer,
-            float learningRate,
-            int negativeSamplingRate,
-            int embeddingDimensions,
-            ProgressTracker progressTracker
-        ) {
-            this.centerEmbeddings = centerEmbeddings;
-            this.contextEmbeddings = contextEmbeddings;
-            this.positiveSampleProducer = positiveSampleProducer;
-            this.negativeSampleProducer = negativeSampleProducer;
-            this.learningRate = learningRate;
-            this.negativeSamplingRate = negativeSamplingRate;
-
-            this.centerGradientBuffer = new FloatVector(embeddingDimensions);
-            this.contextGradientBuffer = new FloatVector(embeddingDimensions);
-            this.progressTracker = progressTracker;
-        }
-
-        @Override
-        public void run() {
-            var buffer = new long[2];
-
-            // this corresponds to a stochastic optimizer as the embeddings are updated after each sample
-            while (positiveSampleProducer.next(buffer)) {
-                trainSample(buffer[0], buffer[1], true);
-
-                for (var i = 0; i < negativeSamplingRate; i++) {
-                    trainSample(buffer[0], negativeSampleProducer.next(), false);
-                }
-                progressTracker.logProgress();
-            }
-        }
-
-        private void trainSample(long center, long context, boolean positive) {
-            var centerEmbedding = centerEmbeddings.get(center);
-            var contextEmbedding = contextEmbeddings.get(context);
-
-            // L_pos = -log sigmoid(center * context) ; gradient: -sigmoid(-center * context)
-            // L_neg = -log sigmoid(-center * context) ; gradient: sigmoid(center * context)
-            float affinity = centerEmbedding.innerProduct(contextEmbedding);
-
-            // When |affinity| > 40, positiveSigmoid = 1. Double precision is not enough.
-            // Make sure negativeSigmoid can never be 0 to avoid infinity loss.
-            double positiveSigmoid = Sigmoid.sigmoid(affinity);
-            double negativeSigmoid = 1 - positiveSigmoid;
-
-            lossSum -= positive ? Math.log(positiveSigmoid + EPSILON) : Math.log(negativeSigmoid + EPSILON);
-
-            float gradient = positive ? (float) -negativeSigmoid : (float) positiveSigmoid;
-            // we are doing gradient descent, so we go in the negative direction of the gradient here
-            float scaledGradient = -gradient * learningRate;
-
-            scale(contextEmbedding.data(), scaledGradient, centerGradientBuffer.data());
-            scale(centerEmbedding.data(), scaledGradient, contextGradientBuffer.data());
-
-            addInPlace(centerEmbedding.data(), centerGradientBuffer.data());
-            addInPlace(contextEmbedding.data(), contextGradientBuffer.data());
-        }
-
-        double lossSum() {
-            return lossSum;
-        }
-    }
-
     static class FloatConsumer {
         float[] values;
         int index;
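
For context on what the removed TrainingTask computed: each (center, context) pair triggers one stochastic gradient step of skip-gram with negative sampling, where L_pos = -log sigmoid(center . context), L_neg = -log sigmoid(-center . context), and EPSILON keeps the logarithm finite when a sigmoid saturates at large |affinity|. Below is a minimal, self-contained sketch of that single-sample update on plain float arrays; the class and method names here are illustrative, not part of the GDS API.

public final class SgnsStepSketch {

    static final double EPSILON = 1e-10; // keeps Math.log finite when a sigmoid underflows to 0

    /**
     * One SGD step for a single (center, context) sample.
     * L_pos = -log sigmoid(center . context); dL/dAffinity = -sigmoid(-center . context)
     * L_neg = -log sigmoid(-center . context); dL/dAffinity = sigmoid(center . context)
     * Returns the loss contributed by this sample.
     */
    static double trainSample(float[] center, float[] context, boolean positive, float learningRate) {
        float affinity = 0f;
        for (int i = 0; i < center.length; i++) {
            affinity += center[i] * context[i];
        }
        double positiveSigmoid = 1.0 / (1.0 + Math.exp(-affinity));
        double negativeSigmoid = 1.0 - positiveSigmoid;

        // for |affinity| > ~40 one of the sigmoids underflows to 0 in double precision;
        // adding EPSILON avoids a -Infinity loss
        double loss = positive
            ? -Math.log(positiveSigmoid + EPSILON)
            : -Math.log(negativeSigmoid + EPSILON);

        float gradient = positive ? (float) -negativeSigmoid : (float) positiveSigmoid;
        // gradient descent: step in the negative direction of the gradient
        float step = -gradient * learningRate;

        for (int i = 0; i < center.length; i++) {
            // mirror the gradient buffers in the removed code: both vectors
            // are updated from their pre-update values
            float centerBefore = center[i];
            center[i] += step * context[i];
            context[i] += step * centerBefore;
        }
        return loss;
    }

    public static void main(String[] args) {
        float[] center = {0.10f, -0.20f, 0.30f};
        float[] context = {0.05f, 0.10f, -0.10f};
        // one positive pair; a full run would follow each positive sample with
        // negativeSamplingRate negative pairs, as the removed run() loop did
        double loss = trainSample(center, context, true, 0.025f);
        System.out.println("positive-sample loss: " + loss);
    }
}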