Skip to content

Commit

Permalink
[*] adopt test to transformation functions
Browse files Browse the repository at this point in the history
  • Loading branch information
dmitryikh committed Mar 16, 2019
1 parent f50d815 commit 9a9c6da
Show file tree
Hide file tree
Showing 14 changed files with 4,735 additions and 2,506 deletions.
32 changes: 23 additions & 9 deletions doc.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,12 @@ build_breast_cancer_model.py:
'objective': 'binary',
}
clf = lgb.train(params, d_train, n_estimators)
# note raw_score=True used here because `leaves` output only raw scores
y_pred = clf.predict(X_test, raw_score=True)
y_pred = clf.predict(X_test)
y_pred_raw = clf.predict(X_test, raw_score=True)
clf.save_model('lg_breast_cancer.model') # save the model in txt format
np.savetxt('lg_breast_cancer_true_predictions.txt', y_pred)
np.savetxt('lg_breast_cancer_true_predictions_raw.txt', y_pred_raw)
np.savetxt('breast_cancer_test.tsv', X_test, delimiter='\t')
predict_breast_cancer_model.go:
Expand All @@ -55,39 +56,52 @@ predict_breast_cancer_model.go:
}
// loading model
model, err := leaves.LGEnsembleFromFile("lg_breast_cancer.model", false)
model, err := leaves.LGEnsembleFromFile("lg_breast_cancer.model", true)
if err != nil {
panic(err)
}
fmt.Printf("Name: %s\n", model.Name())
fmt.Printf("NFeatures: %d\n", model.NFeatures())
fmt.Printf("NRawOutputGroups: %d\n", model.NRawOutputGroups())
fmt.Printf("NOutputGroups: %d\n", model.NOutputGroups())
fmt.Printf("NEstimators: %d\n", model.NEstimators())
fmt.Printf("Transformation: %s\n", model.Transformation().Name())
// loading true predictions as DenseMat
truePredictions, err := mat.DenseMatFromCsvFile("lg_breast_cancer_true_predictions.txt", 0, false, "\t", 0.0)
if err != nil {
panic(err)
}
truePredictionsRaw, err := mat.DenseMatFromCsvFile("lg_breast_cancer_true_predictions_raw.txt", 0, false, "\t", 0.0)
if err != nil {
panic(err)
}
// preallocate slice to store model predictions
predictions := make([]float64, test.Rows*model.NRawOutputGroups())
predictions := make([]float64, test.Rows*model.NOutputGroups())
// do predictions
model.PredictDense(test.Values, test.Rows, test.Cols, predictions, 0, 1)
// compare results
const tolerance = 1e-6
if err := util.AlmostEqualFloat64Slices(truePredictions.Values, predictions, tolerance); err != nil {
panic(fmt.Errorf("different predictions: %s", err.Error()))
}
// compare raw predictions (before transformation function)
rawModel := model.EnsembleWithRawPredictions()
rawModel.PredictDense(test.Values, test.Rows, test.Cols, predictions, 0, 1)
if err := util.AlmostEqualFloat64Slices(truePredictionsRaw.Values, predictions, tolerance); err != nil {
panic(fmt.Errorf("different raw predictions: %s", err.Error()))
}
fmt.Println("Predictions the same!")
}
Output:
Name: lightgbm.gbdt
NFeatures: 30
NRawOutputGroups: 1
NOutputGroups: 1
NEstimators: 30
Transformation: logistic
Predictions the same!
XGBoost Model
Expand Down Expand Up @@ -146,7 +160,7 @@ predict_iris_model.go:
}
fmt.Printf("Name: %s\n", model.Name())
fmt.Printf("NFeatures: %d\n", model.NFeatures())
fmt.Printf("NRawOutputGroups: %d\n", model.NRawOutputGroups())
fmt.Printf("NOutputGroups: %d\n", model.NOutputGroups())
fmt.Printf("NEstimators: %d\n", model.NEstimators())
// loading true predictions as DenseMat
Expand All @@ -156,7 +170,7 @@ predict_iris_model.go:
}
// preallocate slice to store model predictions
predictions := make([]float64, csr.Rows()*model.NRawOutputGroups())
predictions := make([]float64, csr.Rows()*model.NOutputGroups())
// do predictions
model.PredictCSR(csr.RowHeaders, csr.ColIndexes, csr.Values, predictions, 0, 1)
// compare results
Expand All @@ -177,7 +191,7 @@ Output:
Name: xgboost.gbtree
NFeatures: 4
NRawOutputGroups: 3
NOutputGroups: 3
NEstimators: 5
Predictions the same! (mismatch = 0)
Expand Down
Loading

0 comments on commit 9a9c6da

Please sign in to comment.