Skip to content

Commit a9cbbd4

Browse files
rstzcopybara-github
authored andcommitted
Refactor: Add set_node_format to DecisionForestInterface
PiperOrigin-RevId: 580872047
1 parent ff849d9 commit a9cbbd4

File tree

3 files changed

+17
-30
lines changed

3 files changed

+17
-30
lines changed

tensorflow_decision_forests/tensorflow/ops/training/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ cc_library(
111111
"@ydf//yggdrasil_decision_forests/learner/distributed_decision_tree/dataset_cache:dataset_cache_common",
112112
"@ydf//yggdrasil_decision_forests/model:abstract_model",
113113
"@ydf//yggdrasil_decision_forests/model:model_library",
114+
"@ydf//yggdrasil_decision_forests/model/decision_tree:decision_forest_interface",
114115
"@ydf//yggdrasil_decision_forests/model/gradient_boosted_trees",
115116
"@ydf//yggdrasil_decision_forests/model/random_forest",
116117
"@ydf//yggdrasil_decision_forests/utils:concurrency",

tensorflow_decision_forests/tensorflow/ops/training/kernel.cc

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,8 @@
3737
#include "yggdrasil_decision_forests/learner/abstract_learner.pb.h"
3838
#include "yggdrasil_decision_forests/learner/learner_library.h"
3939
#include "yggdrasil_decision_forests/model/abstract_model.h"
40-
#include "yggdrasil_decision_forests/model/gradient_boosted_trees/gradient_boosted_trees.h"
40+
#include "yggdrasil_decision_forests/model/decision_tree/decision_forest_interface.h"
4141
#include "yggdrasil_decision_forests/model/model_library.h"
42-
#include "yggdrasil_decision_forests/model/random_forest/random_forest.h"
43-
#include "yggdrasil_decision_forests/utils/distribution.pb.h"
4442
#include "yggdrasil_decision_forests/utils/logging.h"
4543
#include "yggdrasil_decision_forests/utils/tensorflow.h"
4644

@@ -903,20 +901,15 @@ class SimpleMLModelTrainer : public tensorflow::OpKernel {
903901

904902
RETURN_IF_ERROR(model.status());
905903

906-
// If the model is GBT or RF, set the node format.
904+
// If the model is a decision forest, set the node format.
907905
if (!training_state->node_format.empty()) {
908906
// Set the model format.
909-
auto* gbt_model = dynamic_cast<
910-
model::gradient_boosted_trees::GradientBoostedTreesModel*>(
911-
model.value().get());
912-
if (gbt_model) {
913-
gbt_model->set_node_format(training_state->node_format);
914-
}
915-
916-
auto* rf_model = dynamic_cast<model::random_forest::RandomForestModel*>(
917-
model.value().get());
918-
if (rf_model) {
919-
rf_model->set_node_format(training_state->node_format);
907+
auto* df_model =
908+
dynamic_cast<model::DecisionForestInterface*>(model.value().get());
909+
if (df_model) {
910+
df_model->set_node_format(training_state->node_format);
911+
} else {
912+
YDF_LOG(INFO) << "The node format cannot be set for this model type";
920913
}
921914
}
922915

tensorflow_decision_forests/tensorflow/ops/training/kernel_on_file.cc

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,8 @@
3030
#include "yggdrasil_decision_forests/learner/abstract_learner.h"
3131
#include "yggdrasil_decision_forests/learner/abstract_learner.pb.h"
3232
#include "yggdrasil_decision_forests/learner/learner_library.h"
33-
#include "yggdrasil_decision_forests/model/abstract_model.h"
34-
#include "yggdrasil_decision_forests/model/gradient_boosted_trees/gradient_boosted_trees.h"
33+
#include "yggdrasil_decision_forests/model/decision_tree/decision_forest_interface.h"
3534
#include "yggdrasil_decision_forests/model/model_library.h"
36-
#include "yggdrasil_decision_forests/model/random_forest/random_forest.h"
3735
#include "yggdrasil_decision_forests/utils/logging.h"
3836
#include "yggdrasil_decision_forests/utils/tensorflow.h"
3937

@@ -211,20 +209,15 @@ class SimpleMLModelTrainerOnFile : public tensorflow::OpKernel {
211209

212210
RETURN_IF_ERROR(model.status());
213211

214-
// If the model is GBT or RF, set the node format.
212+
// If the model is a decision forest, set the node format.
215213
if (!training_state->node_format.empty()) {
216214
// Set the model format.
217-
auto* gbt_model = dynamic_cast<
218-
model::gradient_boosted_trees::GradientBoostedTreesModel*>(
219-
model.value().get());
220-
if (gbt_model) {
221-
gbt_model->set_node_format(training_state->node_format);
222-
}
223-
224-
auto* rf_model = dynamic_cast<model::random_forest::RandomForestModel*>(
225-
model.value().get());
226-
if (rf_model) {
227-
rf_model->set_node_format(training_state->node_format);
215+
auto* df_model =
216+
dynamic_cast<model::DecisionForestInterface*>(model.value().get());
217+
if (df_model) {
218+
df_model->set_node_format(training_state->node_format);
219+
} else {
220+
YDF_LOG(INFO) << "The node format cannot be set for this model type";
228221
}
229222
}
230223

0 commit comments

Comments
 (0)