Skip to content

Commit 714f26f

Browse files
committed
migrate node regression train
1 parent aab3a1c commit 714f26f

File tree

12 files changed

+479
-95
lines changed

12 files changed

+479
-95
lines changed

pipeline/src/main/java/org/neo4j/gds/ml/pipeline/nodePipeline/regression/NodeRegressionTrainAlgorithm.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import org.neo4j.gds.api.GraphStore;
2323
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
2424
import org.neo4j.gds.ml.pipeline.PipelineTrainAlgorithm;
25+
import org.neo4j.gds.ml.pipeline.PipelineTrainer;
2526
import org.neo4j.gds.ml.pipeline.nodePipeline.NodeFeatureStep;
2627
import org.neo4j.gds.ml.pipeline.nodePipeline.regression.NodeRegressionTrainResult.NodeRegressionTrainPipelineResult;
2728

@@ -31,8 +32,8 @@ public class NodeRegressionTrainAlgorithm extends PipelineTrainAlgorithm<
3132
NodeRegressionPipelineTrainConfig,
3233
NodeFeatureStep> {
3334

34-
NodeRegressionTrainAlgorithm(
35-
NodeRegressionTrain pipelineTrainer,
35+
public NodeRegressionTrainAlgorithm(
36+
PipelineTrainer<NodeRegressionTrainResult> pipelineTrainer,
3637
NodeRegressionTrainingPipeline pipeline,
3738
GraphStore graphStore,
3839
NodeRegressionPipelineTrainConfig config,

proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/regression/NodeRegressionPipelineTrainProc.java

Lines changed: 6 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,8 @@
1919
*/
2020
package org.neo4j.gds.ml.pipeline.node.regression;
2121

22-
import org.neo4j.gds.BaseProc;
23-
import org.neo4j.gds.core.model.ModelCatalog;
24-
import org.neo4j.gds.executor.ExecutionContext;
25-
import org.neo4j.gds.executor.ProcedureExecutor;
22+
import org.neo4j.gds.procedures.GraphDataScienceProcedures;
23+
import org.neo4j.gds.procedures.pipelines.NodeRegressionPipelineTrainResult;
2624
import org.neo4j.procedure.Context;
2725
import org.neo4j.procedure.Description;
2826
import org.neo4j.procedure.Mode;
@@ -32,28 +30,16 @@
3230
import java.util.Map;
3331
import java.util.stream.Stream;
3432

35-
import static org.neo4j.gds.ml.pipeline.PipelineCompanion.preparePipelineConfig;
36-
37-
public class NodeRegressionPipelineTrainProc extends BaseProc {
38-
33+
public class NodeRegressionPipelineTrainProc {
3934
@Context
40-
public ModelCatalog modelCatalog;
35+
public GraphDataScienceProcedures facade;
4136

4237
@Procedure(name = "gds.alpha.pipeline.nodeRegression.train", mode = Mode.READ)
4338
@Description("Trains a node regression model based on a pipeline")
44-
public Stream<TrainResult> train(
39+
public Stream<NodeRegressionPipelineTrainResult> train(
4540
@Name(value = "graphName") String graphName,
4641
@Name(value = "configuration", defaultValue = "{}") Map<String, Object> configuration
4742
) {
48-
preparePipelineConfig(graphName, configuration);
49-
return new ProcedureExecutor<>(
50-
new NodeRegressionPipelineTrainSpec(),
51-
executionContext()
52-
).compute(graphName, configuration);
53-
}
54-
55-
@Override
56-
public ExecutionContext executionContext() {
57-
return super.executionContext().withModelCatalog(modelCatalog);
43+
return facade.pipelines().nodeRegression().train(graphName, configuration);
5844
}
5945
}

proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/regression/NodeRegressionPipelineTrainSpec.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
import org.neo4j.gds.ml.pipeline.nodePipeline.regression.NodeRegressionPipelineTrainConfig;
3232
import org.neo4j.gds.ml.pipeline.nodePipeline.regression.NodeRegressionTrainAlgorithm;
3333
import org.neo4j.gds.ml.pipeline.nodePipeline.regression.NodeRegressionTrainPipelineAlgorithmFactory;
34-
import org.neo4j.gds.ml.pipeline.nodePipeline.regression.NodeRegressionTrainResult;
34+
import org.neo4j.gds.procedures.pipelines.NodeRegressionPipelineTrainResult;
3535
import org.neo4j.graphdb.GraphDatabaseService;
3636

3737
import java.util.List;
@@ -42,9 +42,9 @@
4242
@GdsCallable(name = "gds.alpha.pipeline.nodeRegression.train", description = "Trains a node regression model based on a pipeline", executionMode = TRAIN)
4343
public class NodeRegressionPipelineTrainSpec implements AlgorithmSpec<
4444
NodeRegressionTrainAlgorithm,
45-
NodeRegressionTrainResult.NodeRegressionTrainPipelineResult,
45+
org.neo4j.gds.ml.pipeline.nodePipeline.regression.NodeRegressionTrainResult.NodeRegressionTrainPipelineResult,
4646
NodeRegressionPipelineTrainConfig,
47-
Stream<TrainResult>,
47+
Stream<NodeRegressionPipelineTrainResult>,
4848
NodeRegressionTrainPipelineAlgorithmFactory> {
4949
@Override
5050
public String name() {
@@ -62,7 +62,7 @@ public NewConfigFunction<NodeRegressionPipelineTrainConfig> newConfigFunction()
6262
}
6363

6464
@Override
65-
public ComputationResultConsumer<NodeRegressionTrainAlgorithm, NodeRegressionTrainResult.NodeRegressionTrainPipelineResult, NodeRegressionPipelineTrainConfig, Stream<TrainResult>> computationResultConsumer() {
65+
public ComputationResultConsumer<NodeRegressionTrainAlgorithm, org.neo4j.gds.ml.pipeline.nodePipeline.regression.NodeRegressionTrainResult.NodeRegressionTrainPipelineResult, NodeRegressionPipelineTrainConfig, Stream<NodeRegressionPipelineTrainResult>> computationResultConsumer() {
6666
return (computationResult, executionContext) -> {
6767
return computationResult.result().map(result -> {
6868
var model = result.model();
@@ -83,7 +83,7 @@ public ComputationResultConsumer<NodeRegressionTrainAlgorithm, NodeRegressionTra
8383
throw e;
8484
}
8585
}
86-
return Stream.of(new TrainResult(model, result.trainingStatistics(), computationResult.computeMillis()
86+
return Stream.of(new NodeRegressionPipelineTrainResult(model, result.trainingStatistics(), computationResult.computeMillis()
8787
));
8888
}).orElseGet(Stream::empty);
8989
};

procedures/pipelines-facade-api/src/main/java/org/neo4j/gds/procedures/pipelines/NodeRegressionFacade.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,6 @@ public interface NodeRegressionFacade {
2626
Stream<PredictMutateResult> mutate(String graphName, Map<String, Object> configuration);
2727

2828
Stream<NodeRegressionStreamResult> stream(String graphName, Map<String, Object> configuration);
29+
30+
Stream<NodeRegressionPipelineTrainResult> train(String graphName, Map<String, Object> configuration);
2931
}

proc/machine-learning/src/main/java/org/neo4j/gds/ml/pipeline/node/regression/TrainResult.java renamed to procedures/pipelines-facade-api/src/main/java/org/neo4j/gds/procedures/pipelines/NodeRegressionPipelineTrainResult.java

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,9 @@
1717
* You should have received a copy of the GNU General Public License
1818
* along with this program. If not, see <http://www.gnu.org/licenses/>.
1919
*/
20-
package org.neo4j.gds.ml.pipeline.node.regression;
20+
package org.neo4j.gds.procedures.pipelines;
2121

2222
import org.neo4j.gds.core.model.Model;
23-
import org.neo4j.gds.procedures.pipelines.MLTrainResult;
2423
import org.neo4j.gds.ml.models.Regressor;
2524
import org.neo4j.gds.ml.pipeline.nodePipeline.regression.NodeRegressionPipelineModelInfo;
2625
import org.neo4j.gds.ml.pipeline.nodePipeline.regression.NodeRegressionPipelineTrainConfig;
@@ -29,11 +28,10 @@
2928
import java.util.Map;
3029
import java.util.Optional;
3130

32-
public class TrainResult extends MLTrainResult {
33-
31+
public class NodeRegressionPipelineTrainResult extends MLTrainResult {
3432
public final Map<String, Object> modelSelectionStats;
3533

36-
TrainResult(
34+
public NodeRegressionPipelineTrainResult(
3735
Model<Regressor.RegressorData, NodeRegressionPipelineTrainConfig, NodeRegressionPipelineModelInfo> model,
3836
TrainingStatistics trainingStatistics,
3937
long trainMillis

procedures/pipelines-facade/src/main/java/org/neo4j/gds/procedures/pipelines/LocalNodeRegressionFacade.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,4 +74,13 @@ public Stream<NodeRegressionStreamResult> stream(String graphNameAsString, Map<S
7474

7575
return pipelineApplications.nodeRegressionPredictStream(graphName, configuration);
7676
}
77+
78+
/**
 * Entry point for training a node regression pipeline model on a named graph.
 *
 * Steps, in order:
 * 1. hands the raw graph name and the configuration map to
 *    PipelineCompanion.preparePipelineConfig (NOTE(review): presumably this
 *    injects the graph name into the map - confirm against PipelineCompanion);
 * 2. parses the raw graph name into a GraphName;
 * 3. delegates to pipelineApplications.nodeRegressionTrain and returns its stream.
 *
 * @param graphNameAsString the graph name exactly as supplied by the caller
 * @param configuration     the user-supplied training configuration
 * @return whatever stream pipelineApplications.nodeRegressionTrain produces
 */
@Override
79+
public Stream<NodeRegressionPipelineTrainResult> train(String graphNameAsString, Map<String, Object> configuration) {
80+
// Runs before the graph name is parsed, on the unparsed name.
PipelineCompanion.preparePipelineConfig(graphNameAsString, configuration);
81+
82+
var graphName = GraphName.parse(graphNameAsString);
83+
84+
return pipelineApplications.nodeRegressionTrain(graphName, configuration);
85+
}
7786
}
Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,209 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.procedures.pipelines;
21+
22+
import org.neo4j.common.DependencyResolver;
23+
import org.neo4j.gds.api.CloseableResourceRegistry;
24+
import org.neo4j.gds.api.DatabaseId;
25+
import org.neo4j.gds.api.Graph;
26+
import org.neo4j.gds.api.GraphStore;
27+
import org.neo4j.gds.api.NodeLookup;
28+
import org.neo4j.gds.api.ProcedureReturnColumns;
29+
import org.neo4j.gds.api.User;
30+
import org.neo4j.gds.applications.algorithms.machinery.AlgorithmMachinery;
31+
import org.neo4j.gds.applications.algorithms.machinery.Computation;
32+
import org.neo4j.gds.applications.algorithms.machinery.ProgressTrackerCreator;
33+
import org.neo4j.gds.core.model.ModelCatalog;
34+
import org.neo4j.gds.core.utils.progress.TaskRegistryFactory;
35+
import org.neo4j.gds.core.utils.warnings.UserLogRegistryFactory;
36+
import org.neo4j.gds.core.write.NodePropertyExporterBuilder;
37+
import org.neo4j.gds.core.write.RelationshipExporterBuilder;
38+
import org.neo4j.gds.executor.ImmutableExecutionContext;
39+
import org.neo4j.gds.logging.Log;
40+
import org.neo4j.gds.metrics.Metrics;
41+
import org.neo4j.gds.ml.pipeline.PipelineCompanion;
42+
import org.neo4j.gds.ml.pipeline.nodePipeline.NodeFeatureProducer;
43+
import org.neo4j.gds.ml.pipeline.nodePipeline.regression.NodeRegressionPipelineTrainConfig;
44+
import org.neo4j.gds.ml.pipeline.nodePipeline.regression.NodeRegressionTrain;
45+
import org.neo4j.gds.ml.pipeline.nodePipeline.regression.NodeRegressionTrainAlgorithm;
46+
import org.neo4j.gds.ml.pipeline.nodePipeline.regression.NodeRegressionTrainResult;
47+
import org.neo4j.gds.procedures.algorithms.AlgorithmsProcedureFacade;
48+
import org.neo4j.gds.termination.TerminationMonitor;
49+
50+
/**
 * Computation that trains a node regression pipeline.
 *
 * When computed, it looks up the training pipeline named in the configuration,
 * assembles an execution context from the injected collaborators, creates a
 * progress tracker, builds the feature producer and trainer, and finally runs a
 * NodeRegressionTrainAlgorithm through AlgorithmMachinery.
 *
 * Instances are only obtainable through the static create(...) factory; the
 * constructor is private.
 */
final class NodeRegressionTrainComputation implements Computation<NodeRegressionTrainResult.NodeRegressionTrainPipelineResult> {
51+
// Created per instance rather than injected.
private final AlgorithmMachinery algorithmMachinery = new AlgorithmMachinery();
52+
53+
// Collaborators below are stored as-is and only read inside compute(...).
private final Log log;
54+
private final ModelCatalog modelCatalog;
55+
private final PipelineRepository pipelineRepository;
56+
private final CloseableResourceRegistry closeableResourceRegistry;
57+
private final DatabaseId databaseId;
58+
private final DependencyResolver dependencyResolver;
59+
private final Metrics metrics;
60+
private final NodeLookup nodeLookup;
61+
private final NodePropertyExporterBuilder nodePropertyExporterBuilder;
62+
private final ProcedureReturnColumns procedureReturnColumns;
63+
private final RelationshipExporterBuilder relationshipExporterBuilder;
64+
private final TaskRegistryFactory taskRegistryFactory;
65+
private final TerminationMonitor terminationMonitor;
66+
private final UserLogRegistryFactory userLogRegistryFactory;
67+
private final ProgressTrackerCreator progressTrackerCreator;
68+
private final AlgorithmsProcedureFacade algorithmsProcedureFacade;
69+
private final NodeRegressionPipelineTrainConfig configuration;
70+
71+
/**
 * Assigns every argument to the field of the same name; performs no
 * validation or other work.
 */
private NodeRegressionTrainComputation(
72+
Log log,
73+
ModelCatalog modelCatalog,
74+
PipelineRepository pipelineRepository,
75+
CloseableResourceRegistry closeableResourceRegistry,
76+
DatabaseId databaseId,
77+
DependencyResolver dependencyResolver,
78+
Metrics metrics,
79+
NodeLookup nodeLookup,
80+
NodePropertyExporterBuilder nodePropertyExporterBuilder,
81+
ProcedureReturnColumns procedureReturnColumns,
82+
RelationshipExporterBuilder relationshipExporterBuilder,
83+
TaskRegistryFactory taskRegistryFactory,
84+
TerminationMonitor terminationMonitor,
85+
UserLogRegistryFactory userLogRegistryFactory,
86+
ProgressTrackerCreator progressTrackerCreator,
87+
AlgorithmsProcedureFacade algorithmsProcedureFacade,
88+
NodeRegressionPipelineTrainConfig configuration
89+
) {
90+
this.log = log;
91+
this.modelCatalog = modelCatalog;
92+
this.pipelineRepository = pipelineRepository;
93+
this.closeableResourceRegistry = closeableResourceRegistry;
94+
this.databaseId = databaseId;
95+
this.dependencyResolver = dependencyResolver;
96+
this.metrics = metrics;
97+
this.nodeLookup = nodeLookup;
98+
this.nodePropertyExporterBuilder = nodePropertyExporterBuilder;
99+
this.procedureReturnColumns = procedureReturnColumns;
100+
this.relationshipExporterBuilder = relationshipExporterBuilder;
101+
this.taskRegistryFactory = taskRegistryFactory;
102+
this.terminationMonitor = terminationMonitor;
103+
this.userLogRegistryFactory = userLogRegistryFactory;
104+
this.progressTrackerCreator = progressTrackerCreator;
105+
this.algorithmsProcedureFacade = algorithmsProcedureFacade;
106+
this.configuration = configuration;
107+
}
108+
109+
/**
 * Static factory: forwards all arguments, in the same order, to the private
 * constructor. The declared return type is the Computation interface, not
 * the concrete class.
 */
static Computation<NodeRegressionTrainResult.NodeRegressionTrainPipelineResult> create(
110+
Log log,
111+
ModelCatalog modelCatalog,
112+
PipelineRepository pipelineRepository,
113+
CloseableResourceRegistry closeableResourceRegistry,
114+
DatabaseId databaseId,
115+
DependencyResolver dependencyResolver,
116+
Metrics metrics,
117+
NodeLookup nodeLookup,
118+
NodePropertyExporterBuilder nodePropertyExporterBuilder,
119+
ProcedureReturnColumns procedureReturnColumns,
120+
RelationshipExporterBuilder relationshipExporterBuilder,
121+
TaskRegistryFactory taskRegistryFactory,
122+
TerminationMonitor terminationMonitor,
123+
UserLogRegistryFactory userLogRegistryFactory,
124+
ProgressTrackerCreator progressTrackerCreator,
125+
AlgorithmsProcedureFacade algorithmsProcedureFacade,
126+
NodeRegressionPipelineTrainConfig configuration
127+
) {
128+
return new NodeRegressionTrainComputation(
129+
log,
130+
modelCatalog,
131+
pipelineRepository,
132+
closeableResourceRegistry,
133+
databaseId,
134+
dependencyResolver,
135+
metrics,
136+
nodeLookup,
137+
nodePropertyExporterBuilder,
138+
procedureReturnColumns,
139+
relationshipExporterBuilder,
140+
taskRegistryFactory,
141+
terminationMonitor,
142+
userLogRegistryFactory,
143+
progressTrackerCreator,
144+
algorithmsProcedureFacade,
145+
configuration
146+
);
147+
}
148+
149+
/**
 * Runs the training.
 *
 * Only graphStore is read in this method; the graph parameter is not
 * referenced by the body.
 *
 * @param graph      not used by this implementation
 * @param graphStore the graph store handed to the feature producer, the
 *                   trainer and the algorithm; its node count sizes the
 *                   progress task
 * @return the value returned by
 *         AlgorithmMachinery.runAlgorithmsAndManageProgressTracker
 */
@Override
150+
public NodeRegressionTrainResult.NodeRegressionTrainPipelineResult compute(Graph graph, GraphStore graphStore) {
151+
// The second constructor argument is fixed to false; user.isAdmin() is what
// feeds isGdsAdmin(...) on the execution context below.
// NOTE(review): presumably this is the admin flag - confirm that treating
// every caller as non-admin is intended.
var user = new User(configuration.username(), false);
152+
var pipelineName = PipelineName.parse(configuration.pipeline());
153+
var pipeline = pipelineRepository.getNodeRegressionTrainingPipeline(
154+
user,
155+
pipelineName
156+
);
157+
158+
// The first configured metric is passed as the main metric.
// NOTE(review): get(0) assumes metrics() is non-empty - presumably
// guaranteed by configuration validation; confirm.
PipelineCompanion.validateMainMetric(pipeline, configuration.metrics().get(0).toString());
159+
160+
// Execution context assembled purely from the injected collaborators and
// the user created above.
var executionContext = ImmutableExecutionContext.builder()
161+
.algorithmsProcedureFacade(algorithmsProcedureFacade)
162+
.closeableResourceRegistry(closeableResourceRegistry)
163+
.databaseId(databaseId)
164+
.dependencyResolver(dependencyResolver)
165+
.isGdsAdmin(user.isAdmin())
166+
.log(log)
167+
.metrics(metrics)
168+
.modelCatalog(modelCatalog)
169+
.nodeLookup(nodeLookup)
170+
.nodePropertyExporterBuilder(nodePropertyExporterBuilder)
171+
.relationshipExporterBuilder(relationshipExporterBuilder)
172+
.returnColumns(procedureReturnColumns)
173+
.taskRegistryFactory(taskRegistryFactory)
174+
.terminationMonitor(terminationMonitor)
175+
.userLogRegistryFactory(userLogRegistryFactory)
176+
.username(user.getUsername())
177+
.build();
178+
179+
// One progress tracker is created here and shared by the feature producer,
// the trainer, the algorithm and the machinery call at the end.
var task = NodeRegressionTrain.progressTask(pipeline, graphStore.nodeCount());
180+
var progressTracker = progressTrackerCreator.createProgressTracker(configuration, task);
181+
182+
var nodeFeatureProducer = NodeFeatureProducer.create(
183+
graphStore,
184+
configuration,
185+
executionContext,
186+
progressTracker
187+
);
188+
189+
// Validation of the pipeline's node property steps happens before the
// trainer is created.
nodeFeatureProducer.validateNodePropertyStepsContextConfigs(pipeline.nodePropertySteps());
190+
191+
var pipelineTrainer = NodeRegressionTrain.create(
192+
graphStore,
193+
pipeline,
194+
configuration,
195+
nodeFeatureProducer,
196+
progressTracker
197+
);
198+
199+
var algorithm = new NodeRegressionTrainAlgorithm(
200+
pipelineTrainer,
201+
pipeline,
202+
graphStore,
203+
configuration,
204+
progressTracker
205+
);
206+
207+
// NOTE(review): the meaning of the trailing `true` is not visible here -
// confirm against AlgorithmMachinery.
return algorithmMachinery.runAlgorithmsAndManageProgressTracker(algorithm, progressTracker, true);
208+
}
209+
}

0 commit comments

Comments
 (0)