Skip to content

Commit 8b8f3ce

Browse files
committed
migrate link prediction mutate
1 parent eff6707 commit 8b8f3ce

File tree

46 files changed

+1008
-128
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

46 files changed

+1008
-128
lines changed

applications/algorithms/machinery/src/main/java/org/neo4j/gds/applications/algorithms/machinery/AlgorithmEstimationTemplate.java

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,15 @@ public <CONFIGURATION extends AlgoBaseConfig> MemoryEstimateResult estimate(
6565
CONFIGURATION configuration,
6666
Object graphNameOrConfiguration,
6767
MemoryEstimation memoryEstimation
68+
) {
69+
return estimate(configuration, graphNameOrConfiguration, memoryEstimation, DimensionTransformer.DISABLED);
70+
}
71+
72+
public <CONFIGURATION extends AlgoBaseConfig> MemoryEstimateResult estimate(
73+
CONFIGURATION configuration,
74+
Object graphNameOrConfiguration,
75+
MemoryEstimation memoryEstimation,
76+
DimensionTransformer dimensionTransformer
6877
) {
6978
var estimationBuilder = MemoryEstimations.builder("Memory Estimation");
7079

@@ -76,10 +85,14 @@ public <CONFIGURATION extends AlgoBaseConfig> MemoryEstimateResult estimate(
7685

7786
estimationBuilder.add("graph", graphMemoryEstimation.estimateMemoryUsageAfterLoading());
7887

88+
var graphDimensions = graphMemoryEstimation.dimensions();
89+
90+
var transformedDimensions = dimensionTransformer.transform(graphDimensions);
91+
7992
return estimate(
8093
estimationBuilder,
8194
memoryEstimation,
82-
graphMemoryEstimation.dimensions(),
95+
transformedDimensions,
8396
configuration.concurrency()
8497
);
8598
}
@@ -89,7 +102,9 @@ public <CONFIGURATION extends AlgoBaseConfig> MemoryEstimateResult estimate(
89102

90103
var graphDimensions = dimensionsFromActualGraph(graphName, configuration);
91104

92-
return estimate(estimationBuilder, memoryEstimation, graphDimensions, configuration.concurrency());
105+
var transformedDimensions = dimensionTransformer.transform(graphDimensions);
106+
107+
return estimate(estimationBuilder, memoryEstimation, transformedDimensions, configuration.concurrency());
93108
}
94109

95110
throw new IllegalArgumentException(formatWithLocale(
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.applications.algorithms.machinery;
21+
22+
import org.neo4j.gds.core.GraphDimensions;
23+
24+
/**
25+
* For some algorithms we can transform dimensions intelligently,
26+
* to give better estimates
27+
*/
28+
public interface DimensionTransformer {
29+
DimensionTransformer DISABLED = graphDimensions -> graphDimensions;
30+
31+
GraphDimensions transform(GraphDimensions graphDimensions);
32+
}

proc/machine-learning/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/LinkPredictionPipelineCompanion.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import org.neo4j.gds.ml.models.Classifier;
2626
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionModelInfo;
2727
import org.neo4j.gds.ml.pipeline.linkPipeline.train.LinkPredictionTrainConfig;
28+
import org.neo4j.gds.procedures.pipelines.TrainedLPPipelineModel;
2829

2930
import java.util.List;
3031
import java.util.Map;
@@ -47,6 +48,6 @@ public static Model<Classifier.ClassifierData, LinkPredictionTrainConfig, LinkPr
4748
String pipelineName,
4849
String username
4950
) {
50-
return modelCatalog.get(username, pipelineName, Classifier.ClassifierData.class, LinkPredictionTrainConfig.class, LinkPredictionModelInfo.class);
51+
return new TrainedLPPipelineModel(modelCatalog).get(pipelineName, username);
5152
}
5253
}

proc/machine-learning/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/predict/LinkPredictionPipelineMutateProc.java

Lines changed: 6 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,9 @@
1919
*/
2020
package org.neo4j.gds.ml.linkmodels.pipeline.predict;
2121

22-
import org.neo4j.gds.BaseProc;
23-
import org.neo4j.gds.core.model.ModelCatalog;
24-
import org.neo4j.gds.executor.ExecutionContext;
25-
import org.neo4j.gds.executor.MemoryEstimationExecutor;
26-
import org.neo4j.gds.executor.ProcedureExecutor;
2722
import org.neo4j.gds.applications.algorithms.machinery.MemoryEstimateResult;
23+
import org.neo4j.gds.procedures.GraphDataScienceProcedures;
24+
import org.neo4j.gds.procedures.pipelines.MutateResult;
2825
import org.neo4j.procedure.Context;
2926
import org.neo4j.procedure.Description;
3027
import org.neo4j.procedure.Mode;
@@ -36,25 +33,18 @@
3633

3734
import static org.neo4j.gds.ml.linkmodels.pipeline.LinkPredictionPipelineCompanion.ESTIMATE_PREDICT_DESCRIPTION;
3835
import static org.neo4j.gds.ml.linkmodels.pipeline.LinkPredictionPipelineCompanion.PREDICT_DESCRIPTION;
39-
import static org.neo4j.gds.ml.pipeline.PipelineCompanion.preparePipelineConfig;
40-
41-
public class LinkPredictionPipelineMutateProc extends BaseProc {
4236

37+
public class LinkPredictionPipelineMutateProc {
4338
@Context
44-
public ModelCatalog modelCatalog;
39+
public GraphDataScienceProcedures facade;
4540

4641
@Procedure(name = "gds.beta.pipeline.linkPrediction.predict.mutate", mode = Mode.READ)
4742
@Description(PREDICT_DESCRIPTION)
4843
public Stream<MutateResult> mutate(
4944
@Name(value = "graphName") String graphName,
5045
@Name(value = "configuration", defaultValue = "{}") Map<String, Object> configuration
5146
) {
52-
preparePipelineConfig(graphName, configuration);
53-
54-
return new ProcedureExecutor<>(
55-
new LinkPredictionPipelineMutateSpec(),
56-
executionContext()
57-
).compute(graphName, configuration);
47+
return facade.pipelines().linkPrediction().mutate(graphName, configuration);
5848
}
5949

6050
@Procedure(name = "gds.beta.pipeline.linkPrediction.predict.mutate.estimate", mode = Mode.READ)
@@ -63,17 +53,6 @@ public Stream<MemoryEstimateResult> estimate(
6353
@Name(value = "graphNameOrConfiguration") Object graphNameOrConfiguration,
6454
@Name(value = "algoConfiguration") Map<String, Object> algoConfiguration
6555
) {
66-
preparePipelineConfig(graphNameOrConfiguration, algoConfiguration);
67-
return new MemoryEstimationExecutor<>(
68-
new LinkPredictionPipelineMutateSpec(),
69-
executionContext(),
70-
transactionContext()
71-
).computeEstimate(graphNameOrConfiguration, algoConfiguration);
56+
return facade.pipelines().linkPrediction().mutateEstimate(graphNameOrConfiguration, algoConfiguration);
7257
}
73-
74-
@Override
75-
public ExecutionContext executionContext() {
76-
return super.executionContext().withModelCatalog(modelCatalog);
77-
}
78-
7958
}

proc/machine-learning/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/predict/LinkPredictionPipelineMutateResultConsumer.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@
3131
import org.neo4j.gds.executor.ComputationResult;
3232
import org.neo4j.gds.executor.ExecutionContext;
3333
import org.neo4j.gds.ml.linkmodels.LinkPredictionResult;
34+
import org.neo4j.gds.procedures.pipelines.LinkPredictionPredictPipelineExecutor;
35+
import org.neo4j.gds.procedures.pipelines.LinkPredictionPredictPipelineMutateConfig;
36+
import org.neo4j.gds.procedures.pipelines.MutateResult;
3437
import org.neo4j.gds.result.AbstractResultBuilder;
3538
import org.neo4j.gds.termination.TerminationFlag;
3639

proc/machine-learning/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/predict/LinkPredictionPipelineMutateSpec.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@
2626
import org.neo4j.gds.executor.GdsCallable;
2727
import org.neo4j.gds.procedures.algorithms.configuration.NewConfigFunction;
2828
import org.neo4j.gds.ml.linkmodels.LinkPredictionResult;
29+
import org.neo4j.gds.procedures.pipelines.LinkPredictionPredictPipelineExecutor;
30+
import org.neo4j.gds.procedures.pipelines.LinkPredictionPredictPipelineMutateConfig;
31+
import org.neo4j.gds.procedures.pipelines.MutateResult;
2932

3033
import java.util.Collections;
3134
import java.util.stream.Stream;

proc/machine-learning/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/predict/LinkPredictionPipelineStreamSpec.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
import org.neo4j.gds.executor.GdsCallable;
2727
import org.neo4j.gds.procedures.algorithms.configuration.NewConfigFunction;
2828
import org.neo4j.gds.ml.linkmodels.LinkPredictionResult;
29+
import org.neo4j.gds.procedures.pipelines.LinkPredictionPredictPipelineExecutor;
30+
import org.neo4j.gds.procedures.pipelines.LinkPredictionPredictPipelineStreamConfig;
2931

3032
import java.util.Collection;
3133
import java.util.stream.Stream;

proc/machine-learning/src/main/java/org/neo4j/gds/ml/linkmodels/pipeline/predict/LinkPredictionPredictPipelineAlgorithmFactory.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@
3030
import org.neo4j.gds.core.utils.progress.tasks.Task;
3131
import org.neo4j.gds.executor.ExecutionContext;
3232
import org.neo4j.gds.ml.models.ClassifierFactory;
33+
import org.neo4j.gds.procedures.pipelines.LPGraphStoreFilterFactory;
34+
import org.neo4j.gds.procedures.pipelines.LinkPredictionPredictPipelineBaseConfig;
35+
import org.neo4j.gds.procedures.pipelines.LinkPredictionPredictPipelineExecutor;
3336

3437
import static org.neo4j.gds.ml.linkmodels.pipeline.LinkPredictionPipelineCompanion.getTrainedLPPipelineModel;
3538
import static org.neo4j.gds.ml.pipeline.PipelineCompanion.ANONYMOUS_GRAPH;
@@ -71,7 +74,7 @@ public LinkPredictionPredictPipelineExecutor build(
7174
);
7275

7376
var trainConfig = model.trainConfig();
74-
var lpGraphStoreFilter = LPGraphStoreFilterFactory.generate(trainConfig, configuration, graphStore, progressTracker);
77+
var lpGraphStoreFilter = LPGraphStoreFilterFactory.generate(executionContext.log(), trainConfig, configuration, graphStore);
7578

7679
return new LinkPredictionPredictPipelineExecutor(
7780
model.customInfo().pipeline(),
@@ -117,7 +120,7 @@ public GraphDimensions estimatedGraphDimensionTransformer(GraphDimensions graphD
117120
.get(CatalogRequest.of(config.username(), executionContext.databaseId()), config.graphName())
118121
.graphStore();
119122

120-
var lpNodeLabelFilter = LPGraphStoreFilterFactory.generate(model.trainConfig(), config, graphStore, ProgressTracker.NULL_TRACKER);
123+
var lpNodeLabelFilter = LPGraphStoreFilterFactory.generate(executionContext.log(), model.trainConfig(), config, graphStore);
121124

122125
//Taking nodePropertyStepsLabels since they are superset of source&target nodeLabels, to give the upper bound estimation
123126
//In the future we can add nodeCount per label info to GraphDimensions to make more exact estimations

proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/classification/predict/NodeClassificationPredictPipelineAlgorithmFactory.java

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import org.neo4j.gds.ml.pipeline.nodePipeline.classification.train.NodeClassificationPipelineTrainConfig;
3535
import org.neo4j.gds.procedures.pipelines.NodeClassificationPredictPipelineBaseConfig;
3636
import org.neo4j.gds.procedures.pipelines.NodeClassificationPredictPipelineExecutor;
37+
import org.neo4j.gds.procedures.pipelines.TrainedNCPipelineModel;
3738

3839
public class NodeClassificationPredictPipelineAlgorithmFactory
3940
<CONFIG extends NodeClassificationPredictPipelineBaseConfig>
@@ -92,11 +93,9 @@ public NodeClassificationPredictPipelineExecutor build(
9293

9394
@Override
9495
public MemoryEstimation memoryEstimation(CONFIG configuration) {
95-
var model = getTrainedNCPipelineModel(
96-
this.modelCatalog,
97-
configuration.modelName(),
98-
configuration.username()
99-
);
96+
var trainedNCPipelineModel = new TrainedNCPipelineModel(modelCatalog);
97+
98+
var model = trainedNCPipelineModel.get(configuration.modelName(), configuration.username());
10099

101100
return MemoryEstimations.builder(NodeClassificationPredictPipelineExecutor.class.getSimpleName())
102101
.add("Pipeline executor", NodeClassificationPredictPipelineExecutor.estimate(model, configuration, modelCatalog, executionContext.algorithmsProcedureFacade()))
@@ -108,12 +107,8 @@ private static Model<Classifier.ClassifierData, NodeClassificationPipelineTrainC
108107
String modelName,
109108
String username
110109
) {
111-
return modelCatalog.get(
112-
username,
113-
modelName,
114-
Classifier.ClassifierData.class,
115-
NodeClassificationPipelineTrainConfig.class,
116-
NodeClassificationPipelineModelInfo.class
117-
);
110+
var trainedNCPipelineModel = new TrainedNCPipelineModel(modelCatalog);
111+
112+
return trainedNCPipelineModel.get(modelName, username);
118113
}
119114
}

proc/machine-learning/src/test/java/org/neo4j/gds/ml/linkmodels/pipeline/predict/LinkPredictionPredictPipelineAlgorithmFactoryTest.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
import org.neo4j.gds.ml.pipeline.linkPipeline.LinkPredictionPredictPipeline;
3939
import org.neo4j.gds.ml.pipeline.linkPipeline.linkfunctions.L2FeatureStep;
4040
import org.neo4j.gds.ml.pipeline.linkPipeline.train.LinkPredictionTrainConfigImpl;
41+
import org.neo4j.gds.procedures.pipelines.LinkPredictionPredictPipelineStreamConfig;
4142

4243
import java.util.List;
4344
import java.util.Map;

0 commit comments

Comments
 (0)