From 59098ed796ffb725ba534fa188db642bc39f3315 Mon Sep 17 00:00:00 2001 From: Jan Brabec Date: Mon, 6 May 2019 07:50:24 +0200 Subject: [PATCH] Initial commit. v1.0.0 --- .gitignore | 48 + LICENSE | 60 + LICENSE-binary | 523 +++++++ NOTICE | 31 + NOTICE-binary | 1168 +++++++++++++++ README.md | 98 +- pom.xml | 163 +++ .../OptimizedDecisionTreeClassifier.scala | 341 +++++ .../OptimizedRandomForestClassifier.scala | 367 +++++ .../OptimizedDecisionTreeRegressor.scala | 323 +++++ .../OptimizedRandomForestRegressor.scala | 326 +++++ .../ml/tree/LocalTrainingAlgorithm.scala | 30 + .../apache/spark/ml/tree/OptimizedNode.scala | 430 ++++++ .../spark/ml/tree/impl/AggUpdateUtils.scala | 85 ++ .../spark/ml/tree/impl/FeatureColumn.scala | 97 ++ .../spark/ml/tree/impl/ImpurityUtils.scala | 126 ++ .../ml/tree/impl/LocalDecisionTree.scala | 268 ++++ .../ml/tree/impl/LocalDecisionTreeUtils.scala | 142 ++ .../tree/impl/LocalTrainingScheduling.scala | 117 ++ .../ml/tree/impl/OptimizedRandomForest.scala | 1264 +++++++++++++++++ .../spark/ml/tree/impl/SplitUtils.scala | 206 +++ .../spark/ml/tree/impl/TrainingInfo.scala | 152 ++ .../ml/tree/impl/TrainingStatistics.scala | 22 + .../org/apache/spark/ml/tree/treeModels.scala | 355 +++++ .../org/apache/spark/ml/tree/treeParams.scala | 136 ++ .../mllib/tree/OptimizedRandomForest.scala | 476 +++++++ .../OptimizedForestStrategy.scala | 164 +++ .../TimePredictionStrategy.scala | 25 + ...OptimizedDecisionTreeClassifierSuite.scala | 432 ++++++ ...OptimizedRandomForestClassifierSuite.scala | 250 ++++ .../OptimizedDecisionTreeRegressorSuite.scala | 170 +++ .../OptimizedRandomForestRegressorSuite.scala | 154 ++ .../impl/LocalDecisionTreeRegressor.scala | 76 + .../ml/tree/impl/LocalTrainingPlanSuite.scala | 74 + .../ml/tree/impl/LocalTreeDataSuite.scala | 202 +++ .../tree/impl/LocalTreeIntegrationSuite.scala | 106 ++ .../spark/ml/tree/impl/LocalTreeTests.scala | 108 ++ .../ml/tree/impl/LocalTreeUnitSuite.scala | 109 ++ .../ml/tree/impl/LocalTreeUtilsSuite.scala | 96 ++ ...ptimizedDecisionTreeIntegrationSuite.scala | 128 ++ .../impl/OptimizedRandomForestSuite.scala | 663 +++++++++ .../ml/tree/impl/OptimizedTreeTests.scala | 319 +++++ .../ml/tree/impl/TreeSplitUtilsSuite.scala | 163 +++ 43 files changed, 10591 insertions(+), 2 deletions(-) create mode 100755 .gitignore mode change 100644 => 100755 LICENSE create mode 100755 LICENSE-binary create mode 100755 NOTICE create mode 100755 NOTICE-binary mode change 100644 => 100755 README.md create mode 100755 pom.xml create mode 100755 src/main/scala/org/apache/spark/ml/classification/OptimizedDecisionTreeClassifier.scala create mode 100755 src/main/scala/org/apache/spark/ml/classification/OptimizedRandomForestClassifier.scala create mode 100755 src/main/scala/org/apache/spark/ml/regression/OptimizedDecisionTreeRegressor.scala create mode 100755 src/main/scala/org/apache/spark/ml/regression/OptimizedRandomForestRegressor.scala create mode 100755 src/main/scala/org/apache/spark/ml/tree/LocalTrainingAlgorithm.scala create mode 100755 src/main/scala/org/apache/spark/ml/tree/OptimizedNode.scala create mode 100755 src/main/scala/org/apache/spark/ml/tree/impl/AggUpdateUtils.scala create mode 100755 src/main/scala/org/apache/spark/ml/tree/impl/FeatureColumn.scala create mode 100755 src/main/scala/org/apache/spark/ml/tree/impl/ImpurityUtils.scala create mode 100755 src/main/scala/org/apache/spark/ml/tree/impl/LocalDecisionTree.scala create mode 100755 src/main/scala/org/apache/spark/ml/tree/impl/LocalDecisionTreeUtils.scala create mode 100755 src/main/scala/org/apache/spark/ml/tree/impl/LocalTrainingScheduling.scala create mode 100755 src/main/scala/org/apache/spark/ml/tree/impl/OptimizedRandomForest.scala create mode 100755 src/main/scala/org/apache/spark/ml/tree/impl/SplitUtils.scala create mode 100755 src/main/scala/org/apache/spark/ml/tree/impl/TrainingInfo.scala create mode 100755 src/main/scala/org/apache/spark/ml/tree/impl/TrainingStatistics.scala create mode 100755 src/main/scala/org/apache/spark/ml/tree/treeModels.scala create mode 100755 src/main/scala/org/apache/spark/ml/tree/treeParams.scala create mode 100755 src/main/scala/org/apache/spark/mllib/tree/OptimizedRandomForest.scala create mode 100755 src/main/scala/org/apache/spark/mllib/tree/configuration/OptimizedForestStrategy.scala create mode 100755 src/main/scala/org/apache/spark/mllib/tree/configuration/TimePredictionStrategy.scala create mode 100755 src/test/scala/org/apache/spark/ml/classification/OptimizedDecisionTreeClassifierSuite.scala create mode 100755 src/test/scala/org/apache/spark/ml/classification/OptimizedRandomForestClassifierSuite.scala create mode 100755 src/test/scala/org/apache/spark/ml/regression/OptimizedDecisionTreeRegressorSuite.scala create mode 100755 src/test/scala/org/apache/spark/ml/regression/OptimizedRandomForestRegressorSuite.scala create mode 100755 src/test/scala/org/apache/spark/ml/tree/impl/LocalDecisionTreeRegressor.scala create mode 100755 src/test/scala/org/apache/spark/ml/tree/impl/LocalTrainingPlanSuite.scala create mode 100755 src/test/scala/org/apache/spark/ml/tree/impl/LocalTreeDataSuite.scala create mode 100755 src/test/scala/org/apache/spark/ml/tree/impl/LocalTreeIntegrationSuite.scala create mode 100755 src/test/scala/org/apache/spark/ml/tree/impl/LocalTreeTests.scala create mode 100755 src/test/scala/org/apache/spark/ml/tree/impl/LocalTreeUnitSuite.scala create mode 100755 src/test/scala/org/apache/spark/ml/tree/impl/LocalTreeUtilsSuite.scala create mode 100755 src/test/scala/org/apache/spark/ml/tree/impl/OptimizedDecisionTreeIntegrationSuite.scala create mode 100755 src/test/scala/org/apache/spark/ml/tree/impl/OptimizedRandomForestSuite.scala create mode 100755 src/test/scala/org/apache/spark/ml/tree/impl/OptimizedTreeTests.scala create mode 100755 src/test/scala/org/apache/spark/ml/tree/impl/TreeSplitUtilsSuite.scala diff --git a/.gitignore b/.gitignore new file mode 100755 index 0000000..7e12bb5 --- /dev/null +++ b/.gitignore @@ -0,0 +1,48 @@ +# use glob syntax. +syntax: glob +*.ser +*.class +*~ +*.bak +*.off +*.old + +# eclipse conf file +.settings +.classpath +.project +.manager +.idea + +# building +target +build +null +tmp +temp +dist +test-output +build.log +pom.xml.releaseBackup +release.properties +TestJavaClass*.java +interpolated-*.xml + +# other scm +.svn +.CVS +.hg* + +# switch to regexp syntax. +# syntax: regexp +# ^\.pc/ + +#SHITTY output not in target directory +build.log + +# IDEA conf files +*.iml +*.ipr +*.iws + +spark-warehouse diff --git a/LICENSE b/LICENSE old mode 100644 new mode 100755 index 261eeb9..b771bd5 --- a/LICENSE +++ b/LICENSE @@ -199,3 +199,63 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. + + +------------------------------------------------------------------------------------ +This product bundles various third-party components under other open source licenses. +This section summarizes those components and their licenses. See licenses/ +for text of these licenses. + + +Apache Software Foundation License 2.0 +-------------------------------------- + +common/network-common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java +core/src/main/java/org/apache/spark/util/collection/TimSort.java +core/src/main/resources/org/apache/spark/ui/static/bootstrap* +core/src/main/resources/org/apache/spark/ui/static/jsonFormatter* +core/src/main/resources/org/apache/spark/ui/static/vis* +docs/js/vendor/bootstrap.js + + +Python Software Foundation License +---------------------------------- + +pyspark/heapq3.py + + +BSD 3-Clause +------------ + +python/lib/py4j-*-src.zip +python/pyspark/cloudpickle.py +python/pyspark/join.py +core/src/main/resources/org/apache/spark/ui/static/d3.min.js + +The CSS style for the navigation sidebar of the documentation was originally +submitted by Óscar Nájera for the scikit-learn project. The scikit-learn project +is distributed under the 3-Clause BSD license. + + +MIT License +----------- + +core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js +core/src/main/resources/org/apache/spark/ui/static/*dataTables* +core/src/main/resources/org/apache/spark/ui/static/graphlib-dot.min.js +ore/src/main/resources/org/apache/spark/ui/static/jquery* +core/src/main/resources/org/apache/spark/ui/static/sorttable.js +docs/js/vendor/anchor.min.js +docs/js/vendor/jquery* +docs/js/vendor/modernizer* + + +Creative Commons CC0 1.0 Universal Public Domain Dedication +----------------------------------------------------------- +(see LICENSE-CC0.txt) + +data/mllib/images/kittens/29.5.a_b_EGDP022204.jpg +data/mllib/images/kittens/54893.jpg +data/mllib/images/kittens/DP153539.jpg +data/mllib/images/kittens/DP802813.jpg +data/mllib/images/multi-channel/chr30.4.184.jpg \ No newline at end of file diff --git a/LICENSE-binary b/LICENSE-binary new file mode 100755 index 0000000..66c5599 --- /dev/null +++ b/LICENSE-binary @@ -0,0 +1,523 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +------------------------------------------------------------------------------------ +This project bundles some components that are also licensed under the Apache +License Version 2.0: + +commons-beanutils:commons-beanutils +org.apache.zookeeper:zookeeper +oro:oro +commons-configuration:commons-configuration +commons-digester:commons-digester +com.chuusai:shapeless_2.12 +com.googlecode.javaewah:JavaEWAH +com.twitter:chill-java +com.twitter:chill_2.12 +com.univocity:univocity-parsers +javax.jdo:jdo-api +joda-time:joda-time +net.sf.opencsv:opencsv +org.apache.derby:derby +org.objenesis:objenesis +org.roaringbitmap:RoaringBitmap +org.scalanlp:breeze-macros_2.12 +org.scalanlp:breeze_2.12 +org.typelevel:macro-compat_2.12 +org.yaml:snakeyaml +org.apache.xbean:xbean-asm5-shaded +com.squareup.okhttp3:logging-interceptor +com.squareup.okhttp3:okhttp +com.squareup.okio:okio +org.apache.spark:spark-catalyst_2.12 +org.apache.spark:spark-kvstore_2.12 +org.apache.spark:spark-launcher_2.12 +org.apache.spark:spark-mllib-local_2.12 +org.apache.spark:spark-network-common_2.12 +org.apache.spark:spark-network-shuffle_2.12 +org.apache.spark:spark-sketch_2.12 +org.apache.spark:spark-tags_2.12 +org.apache.spark:spark-unsafe_2.12 +commons-httpclient:commons-httpclient +com.vlkan:flatbuffers +com.ning:compress-lzf +io.airlift:aircompressor +io.dropwizard.metrics:metrics-core +io.dropwizard.metrics:metrics-ganglia +io.dropwizard.metrics:metrics-graphite +io.dropwizard.metrics:metrics-json +io.dropwizard.metrics:metrics-jvm +org.iq80.snappy:snappy +com.clearspring.analytics:stream +com.jamesmurty.utils:java-xmlbuilder +commons-codec:commons-codec +commons-collections:commons-collections +io.fabric8:kubernetes-client +io.fabric8:kubernetes-model +io.netty:netty +io.netty:netty-all +net.hydromatic:eigenbase-properties +net.sf.supercsv:super-csv +org.apache.arrow:arrow-format +org.apache.arrow:arrow-memory +org.apache.arrow:arrow-vector +org.apache.commons:commons-crypto +org.apache.commons:commons-lang3 +org.apache.hadoop:hadoop-annotations +org.apache.hadoop:hadoop-auth +org.apache.hadoop:hadoop-client +org.apache.hadoop:hadoop-common +org.apache.hadoop:hadoop-hdfs +org.apache.hadoop:hadoop-mapreduce-client-app +org.apache.hadoop:hadoop-mapreduce-client-common +org.apache.hadoop:hadoop-mapreduce-client-core +org.apache.hadoop:hadoop-mapreduce-client-jobclient +org.apache.hadoop:hadoop-mapreduce-client-shuffle +org.apache.hadoop:hadoop-yarn-api +org.apache.hadoop:hadoop-yarn-client +org.apache.hadoop:hadoop-yarn-common +org.apache.hadoop:hadoop-yarn-server-common +org.apache.hadoop:hadoop-yarn-server-web-proxy +org.apache.httpcomponents:httpclient +org.apache.httpcomponents:httpcore +org.apache.orc:orc-core +org.apache.orc:orc-mapreduce +org.mortbay.jetty:jetty +org.mortbay.jetty:jetty-util +com.jolbox:bonecp +org.json4s:json4s-ast_2.12 +org.json4s:json4s-core_2.12 +org.json4s:json4s-jackson_2.12 +org.json4s:json4s-scalap_2.12 +com.carrotsearch:hppc +com.fasterxml.jackson.core:jackson-annotations +com.fasterxml.jackson.core:jackson-core +com.fasterxml.jackson.core:jackson-databind +com.fasterxml.jackson.dataformat:jackson-dataformat-yaml +com.fasterxml.jackson.module:jackson-module-jaxb-annotations +com.fasterxml.jackson.module:jackson-module-paranamer +com.fasterxml.jackson.module:jackson-module-scala_2.12 +com.github.mifmif:generex +com.google.code.findbugs:jsr305 +com.google.code.gson:gson +com.google.inject:guice +com.google.inject.extensions:guice-servlet +com.twitter:parquet-hadoop-bundle +commons-cli:commons-cli +commons-dbcp:commons-dbcp +commons-io:commons-io +commons-lang:commons-lang +commons-logging:commons-logging +commons-net:commons-net +commons-pool:commons-pool +io.fabric8:zjsonpatch +javax.inject:javax.inject +javax.validation:validation-api +log4j:apache-log4j-extras +log4j:log4j +net.sf.jpam:jpam +org.apache.avro:avro +org.apache.avro:avro-ipc +org.apache.avro:avro-mapred +org.apache.commons:commons-compress +org.apache.commons:commons-math3 +org.apache.curator:curator-client +org.apache.curator:curator-framework +org.apache.curator:curator-recipes +org.apache.directory.api:api-asn1-api +org.apache.directory.api:api-util +org.apache.directory.server:apacheds-i18n +org.apache.directory.server:apacheds-kerberos-codec +org.apache.htrace:htrace-core +org.apache.ivy:ivy +org.apache.mesos:mesos +org.apache.parquet:parquet-column +org.apache.parquet:parquet-common +org.apache.parquet:parquet-encoding +org.apache.parquet:parquet-format +org.apache.parquet:parquet-hadoop +org.apache.parquet:parquet-jackson +org.apache.thrift:libfb303 +org.apache.thrift:libthrift +org.codehaus.jackson:jackson-core-asl +org.codehaus.jackson:jackson-mapper-asl +org.datanucleus:datanucleus-api-jdo +org.datanucleus:datanucleus-core +org.datanucleus:datanucleus-rdbms +org.lz4:lz4-java +org.spark-project.hive:hive-beeline +org.spark-project.hive:hive-cli +org.spark-project.hive:hive-exec +org.spark-project.hive:hive-jdbc +org.spark-project.hive:hive-metastore +org.xerial.snappy:snappy-java +stax:stax-api +xerces:xercesImpl +org.codehaus.jackson:jackson-jaxrs +org.codehaus.jackson:jackson-xc +org.eclipse.jetty:jetty-client +org.eclipse.jetty:jetty-continuation +org.eclipse.jetty:jetty-http +org.eclipse.jetty:jetty-io +org.eclipse.jetty:jetty-jndi +org.eclipse.jetty:jetty-plus +org.eclipse.jetty:jetty-proxy +org.eclipse.jetty:jetty-security +org.eclipse.jetty:jetty-server +org.eclipse.jetty:jetty-servlet +org.eclipse.jetty:jetty-servlets +org.eclipse.jetty:jetty-util +org.eclipse.jetty:jetty-webapp +org.eclipse.jetty:jetty-xml + +core/src/main/java/org/apache/spark/util/collection/TimSort.java +core/src/main/resources/org/apache/spark/ui/static/bootstrap* +core/src/main/resources/org/apache/spark/ui/static/jsonFormatter* +core/src/main/resources/org/apache/spark/ui/static/vis* +docs/js/vendor/bootstrap.js + + +------------------------------------------------------------------------------------ +This product bundles various third-party components under other open source licenses. +This section summarizes those components and their licenses. See licenses-binary/ +for text of these licenses. + + +BSD 2-Clause +------------ + +com.github.luben:zstd-jni +javolution:javolution +com.esotericsoftware:kryo-shaded +com.esotericsoftware:minlog +com.esotericsoftware:reflectasm +com.google.protobuf:protobuf-java +org.codehaus.janino:commons-compiler +org.codehaus.janino:janino +jline:jline +org.jodd:jodd-core + + +BSD 3-Clause +------------ + +dk.brics.automaton:automaton +org.antlr:antlr-runtime +org.antlr:ST4 +org.antlr:stringtemplate +org.antlr:antlr4-runtime +antlr:antlr +com.github.fommil.netlib:core +com.thoughtworks.paranamer:paranamer +org.scala-lang:scala-compiler +org.scala-lang:scala-library +org.scala-lang:scala-reflect +org.scala-lang.modules:scala-parser-combinators_2.12 +org.scala-lang.modules:scala-xml_2.12 +org.fusesource.leveldbjni:leveldbjni-all +net.sourceforge.f2j:arpack_combined_all +xmlenc:xmlenc +net.sf.py4j:py4j +org.jpmml:pmml-model +org.jpmml:pmml-schema + +python/lib/py4j-*-src.zip +python/pyspark/cloudpickle.py +python/pyspark/join.py +core/src/main/resources/org/apache/spark/ui/static/d3.min.js + +The CSS style for the navigation sidebar of the documentation was originally +submitted by Óscar Nájera for the scikit-learn project. The scikit-learn project +is distributed under the 3-Clause BSD license. + + +MIT License +----------- + +org.spire-math:spire-macros_2.12 +org.spire-math:spire_2.12 +org.typelevel:machinist_2.12 +net.razorvine:pyrolite +org.slf4j:jcl-over-slf4j +org.slf4j:jul-to-slf4j +org.slf4j:slf4j-api +org.slf4j:slf4j-log4j12 +com.github.scopt:scopt_2.12 + +core/src/main/resources/org/apache/spark/ui/static/dagre-d3.min.js +core/src/main/resources/org/apache/spark/ui/static/*dataTables* +core/src/main/resources/org/apache/spark/ui/static/graphlib-dot.min.js +ore/src/main/resources/org/apache/spark/ui/static/jquery* +core/src/main/resources/org/apache/spark/ui/static/sorttable.js +docs/js/vendor/anchor.min.js +docs/js/vendor/jquery* +docs/js/vendor/modernizer* + + +Common Development and Distribution License (CDDL) 1.0 +------------------------------------------------------ + +javax.activation:activation http://www.oracle.com/technetwork/java/javase/tech/index-jsp-138795.html +javax.xml.stream:stax-api https://jcp.org/en/jsr/detail?id=173 + + +Common Development and Distribution License (CDDL) 1.1 +------------------------------------------------------ + +javax.annotation:javax.annotation-api https://jcp.org/en/jsr/detail?id=250 +javax.servlet:javax.servlet-api https://javaee.github.io/servlet-spec/ +javax.transaction:jta http://www.oracle.com/technetwork/java/index.html +javax.ws.rs:javax.ws.rs-api https://github.com/jax-rs +javax.xml.bind:jaxb-api https://github.com/javaee/jaxb-v2 +org.glassfish.hk2:hk2-api https://github.com/javaee/glassfish +org.glassfish.hk2:hk2-locator (same) +org.glassfish.hk2:hk2-utils +org.glassfish.hk2:osgi-resource-locator +org.glassfish.hk2.external:aopalliance-repackaged +org.glassfish.hk2.external:javax.inject +org.glassfish.jersey.bundles.repackaged:jersey-guava +org.glassfish.jersey.containers:jersey-container-servlet +org.glassfish.jersey.containers:jersey-container-servlet-core +org.glassfish.jersey.core:jersey-client +org.glassfish.jersey.core:jersey-common +org.glassfish.jersey.core:jersey-server +org.glassfish.jersey.media:jersey-media-jaxb + + +Eclipse Distribution License (EDL) 1.0 +-------------------------------------- + +org.glassfish.jaxb:jaxb-runtime +jakarta.xml.bind:jakarta.xml.bind-api +com.sun.istack:istack-commons-runtime +jakarta.activation:jakarta.activation-api + + +Mozilla Public License (MPL) 1.1 +-------------------------------- + +com.github.rwl:jtransforms https://sourceforge.net/projects/jtransforms/ + + +Python Software Foundation License +---------------------------------- + +pyspark/heapq3.py + + +Public Domain +------------- + +aopalliance:aopalliance +net.iharder:base64 +org.tukaani:xz + + +Creative Commons CC0 1.0 Universal Public Domain Dedication +----------------------------------------------------------- +(see LICENSE-CC0.txt) + +data/mllib/images/kittens/29.5.a_b_EGDP022204.jpg +data/mllib/images/kittens/54893.jpg +data/mllib/images/kittens/DP153539.jpg +data/mllib/images/kittens/DP802813.jpg +data/mllib/images/multi-channel/chr30.4.184.jpg diff --git a/NOTICE b/NOTICE new file mode 100755 index 0000000..ce361fb --- /dev/null +++ b/NOTICE @@ -0,0 +1,31 @@ +Optimized Random Forest +Copyright 2019 and onwards Cisco Systems. + +Apache Spark +Copyright 2014 and onwards The Apache Software Foundation. + +This product includes software developed at +The Apache Software Foundation (http://www.apache.org/). + + +Export Control Notice +--------------------- + +This distribution includes cryptographic software. The country in which you currently reside may have +restrictions on the import, possession, use, and/or re-export to another country, of encryption software. +BEFORE using any encryption software, please check your country's laws, regulations and policies concerning +the import, possession, or use, and re-export of encryption software, to see if this is permitted. See + for more information. + +The U.S. Government Department of Commerce, Bureau of Industry and Security (BIS), has classified this +software as Export Commodity Control Number (ECCN) 5D002.C.1, which includes information security software +using or performing cryptographic functions with asymmetric algorithms. The form and manner of this Apache +Software Foundation distribution makes it eligible for export under the License Exception ENC Technology +Software Unrestricted (TSU) exception (see the BIS Export Administration Regulations, Section 740.13) for +both object code and source code. + +The following provides more details on the included cryptographic software: + +This software uses Apache Commons Crypto (https://commons.apache.org/proper/commons-crypto/) to +support authentication, and encryption and decryption of data sent across the network between +services. diff --git a/NOTICE-binary b/NOTICE-binary new file mode 100755 index 0000000..26cafd7 --- /dev/null +++ b/NOTICE-binary @@ -0,0 +1,1168 @@ +Optimized Random Forest +Copyright 2019 and onwards Cisco Systems. + +Apache Spark +Copyright 2014 and onwards The Apache Software Foundation. + +This product includes software developed at +The Apache Software Foundation (http://www.apache.org/). + + +Export Control Notice +--------------------- + +This distribution includes cryptographic software. The country in which you currently reside may have +restrictions on the import, possession, use, and/or re-export to another country, of encryption software. +BEFORE using any encryption software, please check your country's laws, regulations and policies concerning +the import, possession, or use, and re-export of encryption software, to see if this is permitted. See + for more information. + +The U.S. Government Department of Commerce, Bureau of Industry and Security (BIS), has classified this +software as Export Commodity Control Number (ECCN) 5D002.C.1, which includes information security software +using or performing cryptographic functions with asymmetric algorithms. The form and manner of this Apache +Software Foundation distribution makes it eligible for export under the License Exception ENC Technology +Software Unrestricted (TSU) exception (see the BIS Export Administration Regulations, Section 740.13) for +both object code and source code. + +The following provides more details on the included cryptographic software: + +This software uses Apache Commons Crypto (https://commons.apache.org/proper/commons-crypto/) to +support authentication, and encryption and decryption of data sent across the network between +services. + + +// ------------------------------------------------------------------ +// NOTICE file corresponding to the section 4d of The Apache License, +// Version 2.0, in this case for +// ------------------------------------------------------------------ + +Hive Beeline +Copyright 2016 The Apache Software Foundation + +This product includes software developed at +The Apache Software Foundation (http://www.apache.org/). + +Apache Avro +Copyright 2009-2014 The Apache Software Foundation + +This product currently only contains code developed by authors +of specific components, as identified by the source code files; +if such notes are missing files have been created by +Tatu Saloranta. + +For additional credits (generally to people who reported problems) +see CREDITS file. + +Apache Commons Compress +Copyright 2002-2012 The Apache Software Foundation + +This product includes software developed by +The Apache Software Foundation (http://www.apache.org/). + +Apache Avro Mapred API +Copyright 2009-2014 The Apache Software Foundation + +Apache Avro IPC +Copyright 2009-2014 The Apache Software Foundation + +Objenesis +Copyright 2006-2013 Joe Walnes, Henri Tremblay, Leonardo Mesquita + +Apache XBean :: ASM 5 shaded (repackaged) +Copyright 2005-2015 The Apache Software Foundation + +-------------------------------------- + +This product includes software developed at +OW2 Consortium (http://asm.ow2.org/) + +This product includes software developed by The Apache Software +Foundation (http://www.apache.org/). + +The binary distribution of this product bundles binaries of +org.iq80.leveldb:leveldb-api (https://github.com/dain/leveldb), which has the +following notices: +* Copyright 2011 Dain Sundstrom +* Copyright 2011 FuseSource Corp. http://fusesource.com + +The binary distribution of this product bundles binaries of +org.fusesource.hawtjni:hawtjni-runtime (https://github.com/fusesource/hawtjni), +which has the following notices: +* This product includes software developed by FuseSource Corp. + http://fusesource.com +* This product includes software developed at + Progress Software Corporation and/or its subsidiaries or affiliates. +* This product includes software developed by IBM Corporation and others. + +The binary distribution of this product bundles binaries of +Gson 2.2.4, +which has the following notices: + + The Netty Project + ================= + +Please visit the Netty web site for more information: + + * http://netty.io/ + +Copyright 2014 The Netty Project + +The Netty Project licenses this file to you under the Apache License, +version 2.0 (the "License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at: + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +License for the specific language governing permissions and limitations +under the License. + +Also, please refer to each LICENSE..txt file, which is located in +the 'license' directory of the distribution file, for the license terms of the +components that this product depends on. + +------------------------------------------------------------------------------- +This product contains the extensions to Java Collections Framework which has +been derived from the works by JSR-166 EG, Doug Lea, and Jason T. Greene: + + * LICENSE: + * license/LICENSE.jsr166y.txt (Public Domain) + * HOMEPAGE: + * http://gee.cs.oswego.edu/cgi-bin/viewcvs.cgi/jsr166/ + * http://viewvc.jboss.org/cgi-bin/viewvc.cgi/jbosscache/experimental/jsr166/ + +This product contains a modified version of Robert Harder's Public Domain +Base64 Encoder and Decoder, which can be obtained at: + + * LICENSE: + * license/LICENSE.base64.txt (Public Domain) + * HOMEPAGE: + * http://iharder.sourceforge.net/current/java/base64/ + +This product contains a modified portion of 'Webbit', an event based +WebSocket and HTTP server, which can be obtained at: + + * LICENSE: + * license/LICENSE.webbit.txt (BSD License) + * HOMEPAGE: + * https://github.com/joewalnes/webbit + +This product contains a modified portion of 'SLF4J', a simple logging +facade for Java, which can be obtained at: + + * LICENSE: + * license/LICENSE.slf4j.txt (MIT License) + * HOMEPAGE: + * http://www.slf4j.org/ + +This product contains a modified portion of 'ArrayDeque', written by Josh +Bloch of Google, Inc: + + * LICENSE: + * license/LICENSE.deque.txt (Public Domain) + +This product contains a modified portion of 'Apache Harmony', an open source +Java SE, which can be obtained at: + + * LICENSE: + * license/LICENSE.harmony.txt (Apache License 2.0) + * HOMEPAGE: + * http://archive.apache.org/dist/harmony/ + +This product contains a modified version of Roland Kuhn's ASL2 +AbstractNodeQueue, which is based on Dmitriy Vyukov's non-intrusive MPSC queue. +It can be obtained at: + + * LICENSE: + * license/LICENSE.abstractnodequeue.txt (Public Domain) + * HOMEPAGE: + * https://github.com/akka/akka/blob/wip-2.2.3-for-scala-2.11/akka-actor/src/main/java/akka/dispatch/AbstractNodeQueue.java + +This product contains a modified portion of 'jbzip2', a Java bzip2 compression +and decompression library written by Matthew J. Francis. It can be obtained at: + + * LICENSE: + * license/LICENSE.jbzip2.txt (MIT License) + * HOMEPAGE: + * https://code.google.com/p/jbzip2/ + +This product contains a modified portion of 'libdivsufsort', a C API library to construct +the suffix array and the Burrows-Wheeler transformed string for any input string of +a constant-size alphabet written by Yuta Mori. It can be obtained at: + + * LICENSE: + * license/LICENSE.libdivsufsort.txt (MIT License) + * HOMEPAGE: + * https://code.google.com/p/libdivsufsort/ + +This product contains a modified portion of Nitsan Wakart's 'JCTools', Java Concurrency Tools for the JVM, + which can be obtained at: + + * LICENSE: + * license/LICENSE.jctools.txt (ASL2 License) + * HOMEPAGE: + * https://github.com/JCTools/JCTools + +This product optionally depends on 'JZlib', a re-implementation of zlib in +pure Java, which can be obtained at: + + * LICENSE: + * license/LICENSE.jzlib.txt (BSD style License) + * HOMEPAGE: + * http://www.jcraft.com/jzlib/ + +This product optionally depends on 'Compress-LZF', a Java library for encoding and +decoding data in LZF format, written by Tatu Saloranta. It can be obtained at: + + * LICENSE: + * license/LICENSE.compress-lzf.txt (Apache License 2.0) + * HOMEPAGE: + * https://github.com/ning/compress + +This product optionally depends on 'lz4', a LZ4 Java compression +and decompression library written by Adrien Grand. It can be obtained at: + + * LICENSE: + * license/LICENSE.lz4.txt (Apache License 2.0) + * HOMEPAGE: + * https://github.com/jpountz/lz4-java + +This product optionally depends on 'lzma-java', a LZMA Java compression +and decompression library, which can be obtained at: + + * LICENSE: + * license/LICENSE.lzma-java.txt (Apache License 2.0) + * HOMEPAGE: + * https://github.com/jponge/lzma-java + +This product contains a modified portion of 'jfastlz', a Java port of FastLZ compression +and decompression library written by William Kinney. It can be obtained at: + + * LICENSE: + * license/LICENSE.jfastlz.txt (MIT License) + * HOMEPAGE: + * https://code.google.com/p/jfastlz/ + +This product contains a modified portion of and optionally depends on 'Protocol Buffers', Google's data +interchange format, which can be obtained at: + + * LICENSE: + * license/LICENSE.protobuf.txt (New BSD License) + * HOMEPAGE: + * http://code.google.com/p/protobuf/ + +This product optionally depends on 'Bouncy Castle Crypto APIs' to generate +a temporary self-signed X.509 certificate when the JVM does not provide the +equivalent functionality. It can be obtained at: + + * LICENSE: + * license/LICENSE.bouncycastle.txt (MIT License) + * HOMEPAGE: + * http://www.bouncycastle.org/ + +This product optionally depends on 'Snappy', a compression library produced +by Google Inc, which can be obtained at: + + * LICENSE: + * license/LICENSE.snappy.txt (New BSD License) + * HOMEPAGE: + * http://code.google.com/p/snappy/ + +This product optionally depends on 'JBoss Marshalling', an alternative Java +serialization API, which can be obtained at: + + * LICENSE: + * license/LICENSE.jboss-marshalling.txt (GNU LGPL 2.1) + * HOMEPAGE: + * http://www.jboss.org/jbossmarshalling + +This product optionally depends on 'Caliper', Google's micro- +benchmarking framework, which can be obtained at: + + * LICENSE: + * license/LICENSE.caliper.txt (Apache License 2.0) + * HOMEPAGE: + * http://code.google.com/p/caliper/ + +This product optionally depends on 'Apache Commons Logging', a logging +framework, which can be obtained at: + + * LICENSE: + * license/LICENSE.commons-logging.txt (Apache License 2.0) + * HOMEPAGE: + * http://commons.apache.org/logging/ + +This product optionally depends on 'Apache Log4J', a logging framework, which +can be obtained at: + + * LICENSE: + * license/LICENSE.log4j.txt (Apache License 2.0) + * HOMEPAGE: + * http://logging.apache.org/log4j/ + +This product optionally depends on 'Aalto XML', an ultra-high performance +non-blocking XML processor, which can be obtained at: + + * LICENSE: + * license/LICENSE.aalto-xml.txt (Apache License 2.0) + * HOMEPAGE: + * http://wiki.fasterxml.com/AaltoHome + +This product contains a modified version of 'HPACK', a Java implementation of +the HTTP/2 HPACK algorithm written by Twitter. It can be obtained at: + + * LICENSE: + * license/LICENSE.hpack.txt (Apache License 2.0) + * HOMEPAGE: + * https://github.com/twitter/hpack + +This product contains a modified portion of 'Apache Commons Lang', a Java library +provides utilities for the java.lang API, which can be obtained at: + + * LICENSE: + * license/LICENSE.commons-lang.txt (Apache License 2.0) + * HOMEPAGE: + * https://commons.apache.org/proper/commons-lang/ + +The binary distribution of this product bundles binaries of +Commons Codec 1.4, +which has the following notices: + * src/test/org/apache/commons/codec/language/DoubleMetaphoneTest.javacontains test data from http://aspell.net/test/orig/batch0.tab.Copyright (C) 2002 Kevin Atkinson (kevina@gnu.org) + =============================================================================== + The content of package org.apache.commons.codec.language.bm has been translated + from the original php source code available at http://stevemorse.org/phoneticinfo.htm + with permission from the original authors. + Original source copyright:Copyright (c) 2008 Alexander Beider & Stephen P. Morse. + +The binary distribution of this product bundles binaries of +Commons Lang 2.6, +which has the following notices: + * This product includes software from the Spring Framework,under the Apache License 2.0 (see: StringUtils.containsWhitespace()) + +The binary distribution of this product bundles binaries of +Apache Log4j 1.2.17, +which has the following notices: + * ResolverUtil.java + Copyright 2005-2006 Tim Fennell + Dumbster SMTP test server + Copyright 2004 Jason Paul Kitchen + TypeUtil.java + Copyright 2002-2012 Ramnivas Laddad, Juergen Hoeller, Chris Beams + +The binary distribution of this product bundles binaries of +Jetty 6.1.26, +which has the following notices: + * ============================================================== + Jetty Web Container + Copyright 1995-2016 Mort Bay Consulting Pty Ltd. + ============================================================== + + The Jetty Web Container is Copyright Mort Bay Consulting Pty Ltd + unless otherwise noted. + + Jetty is dual licensed under both + + * The Apache 2.0 License + http://www.apache.org/licenses/LICENSE-2.0.html + + and + + * The Eclipse Public 1.0 License + http://www.eclipse.org/legal/epl-v10.html + + Jetty may be distributed under either license. + + ------ + Eclipse + + The following artifacts are EPL. + * org.eclipse.jetty.orbit:org.eclipse.jdt.core + + The following artifacts are EPL and ASL2. + * org.eclipse.jetty.orbit:javax.security.auth.message + + The following artifacts are EPL and CDDL 1.0. + * org.eclipse.jetty.orbit:javax.mail.glassfish + + ------ + Oracle + + The following artifacts are CDDL + GPLv2 with classpath exception. + https://glassfish.dev.java.net/nonav/public/CDDL+GPL.html + + * javax.servlet:javax.servlet-api + * javax.annotation:javax.annotation-api + * javax.transaction:javax.transaction-api + * javax.websocket:javax.websocket-api + + ------ + Oracle OpenJDK + + If ALPN is used to negotiate HTTP/2 connections, then the following + artifacts may be included in the distribution or downloaded when ALPN + module is selected. + + * java.sun.security.ssl + + These artifacts replace/modify OpenJDK classes. The modififications + are hosted at github and both modified and original are under GPL v2 with + classpath exceptions. + http://openjdk.java.net/legal/gplv2+ce.html + + ------ + OW2 + + The following artifacts are licensed by the OW2 Foundation according to the + terms of http://asm.ow2.org/license.html + + org.ow2.asm:asm-commons + org.ow2.asm:asm + + ------ + Apache + + The following artifacts are ASL2 licensed. + + org.apache.taglibs:taglibs-standard-spec + org.apache.taglibs:taglibs-standard-impl + + ------ + MortBay + + The following artifacts are ASL2 licensed. Based on selected classes from + following Apache Tomcat jars, all ASL2 licensed. + + org.mortbay.jasper:apache-jsp + org.apache.tomcat:tomcat-jasper + org.apache.tomcat:tomcat-juli + org.apache.tomcat:tomcat-jsp-api + org.apache.tomcat:tomcat-el-api + org.apache.tomcat:tomcat-jasper-el + org.apache.tomcat:tomcat-api + org.apache.tomcat:tomcat-util-scan + org.apache.tomcat:tomcat-util + + org.mortbay.jasper:apache-el + org.apache.tomcat:tomcat-jasper-el + org.apache.tomcat:tomcat-el-api + + ------ + Mortbay + + The following artifacts are CDDL + GPLv2 with classpath exception. + + https://glassfish.dev.java.net/nonav/public/CDDL+GPL.html + + org.eclipse.jetty.toolchain:jetty-schemas + + ------ + Assorted + + The UnixCrypt.java code implements the one way cryptography used by + Unix systems for simple password protection. Copyright 1996 Aki Yoshida, + modified April 2001 by Iris Van den Broeke, Daniel Deville. + Permission to use, copy, modify and distribute UnixCrypt + for non-commercial or commercial purposes and without fee is + granted provided that the copyright notice appears in all copies./ + +The binary distribution of this product bundles binaries of +Snappy for Java 1.0.4.1, +which has the following notices: + * This product includes software developed by Google + Snappy: http://code.google.com/p/snappy/ (New BSD License) + + This product includes software developed by Apache + PureJavaCrc32C from apache-hadoop-common http://hadoop.apache.org/ + (Apache 2.0 license) + + This library contains statically linked libstdc++. This inclusion is allowed by + "GCC RUntime Library Exception" + http://gcc.gnu.org/onlinedocs/libstdc++/manual/license.html + + == Contributors == + * Tatu Saloranta + * Providing benchmark suite + * Alec Wysoker + * Performance and memory usage improvement + +The binary distribution of this product bundles binaries of +Xerces2 Java Parser 2.9.1, +which has the following notices: + * ========================================================================= + == NOTICE file corresponding to section 4(d) of the Apache License, == + == Version 2.0, in this case for the Apache Xerces Java distribution. == + ========================================================================= + + Apache Xerces Java + Copyright 1999-2007 The Apache Software Foundation + + This product includes software developed at + The Apache Software Foundation (http://www.apache.org/). + + Portions of this software were originally based on the following: + - software copyright (c) 1999, IBM Corporation., http://www.ibm.com. + - software copyright (c) 1999, Sun Microsystems., http://www.sun.com. + - voluntary contributions made by Paul Eng on behalf of the + Apache Software Foundation that were originally developed at iClick, Inc., + software copyright (c) 1999. + +Apache Commons Collections +Copyright 2001-2015 The Apache Software Foundation + +Apache Commons Configuration +Copyright 2001-2008 The Apache Software Foundation + +Apache Jakarta Commons Digester +Copyright 2001-2006 The Apache Software Foundation + +Apache Commons BeanUtils +Copyright 2000-2008 The Apache Software Foundation + +ApacheDS Protocol Kerberos Codec +Copyright 2003-2013 The Apache Software Foundation + +ApacheDS I18n +Copyright 2003-2013 The Apache Software Foundation + +Apache Directory API ASN.1 API +Copyright 2003-2013 The Apache Software Foundation + +Apache Directory LDAP API Utilities +Copyright 2003-2013 The Apache Software Foundation + +Curator Client +Copyright 2011-2015 The Apache Software Foundation + +htrace-core +Copyright 2015 The Apache Software Foundation + + ========================================================================= + == NOTICE file corresponding to section 4(d) of the Apache License, == + == Version 2.0, in this case for the Apache Xerces Java distribution. == + ========================================================================= + + Portions of this software were originally based on the following: + - software copyright (c) 1999, IBM Corporation., http://www.ibm.com. + - software copyright (c) 1999, Sun Microsystems., http://www.sun.com. + - voluntary contributions made by Paul Eng on behalf of the + Apache Software Foundation that were originally developed at iClick, Inc., + software copyright (c) 1999. + +# Jackson JSON processor + +Jackson is a high-performance, Free/Open Source JSON processing library. +It was originally written by Tatu Saloranta (tatu.saloranta@iki.fi), and has +been in development since 2007. +It is currently developed by a community of developers, as well as supported +commercially by FasterXML.com. + +## Licensing + +Jackson core and extension components may licensed under different licenses. +To find the details that apply to this artifact see the accompanying LICENSE file. +For more information, including possible other licensing options, contact +FasterXML.com (http://fasterxml.com). + +## Credits + +A list of contributors may be found from CREDITS file, which is included +in some artifacts (usually source distributions); but is always available +from the source code management (SCM) system project uses. + +Apache HttpCore +Copyright 2005-2017 The Apache Software Foundation + +Curator Recipes +Copyright 2011-2015 The Apache Software Foundation + +Curator Framework +Copyright 2011-2015 The Apache Software Foundation + +Apache Commons Lang +Copyright 2001-2016 The Apache Software Foundation + +This product includes software from the Spring Framework, +under the Apache License 2.0 (see: StringUtils.containsWhitespace()) + +Apache Commons Math +Copyright 2001-2015 The Apache Software Foundation + +This product includes software developed for Orekit by +CS Systèmes d'Information (http://www.c-s.fr/) +Copyright 2010-2012 CS Systèmes d'Information + +Apache log4j +Copyright 2007 The Apache Software Foundation + +# Compress LZF + +This library contains efficient implementation of LZF compression format, +as well as additional helper classes that build on JDK-provided gzip (deflat) +codec. + +Library is licensed under Apache License 2.0, as per accompanying LICENSE file. + +## Credit + +Library has been written by Tatu Saloranta (tatu.saloranta@iki.fi). +It was started at Ning, inc., as an official Open Source process used by +platform backend, but after initial versions has been developed outside of +Ning by supporting community. + +Other contributors include: + +* Jon Hartlaub (first versions of streaming reader/writer; unit tests) +* Cedrik Lime: parallel LZF implementation + +Various community members have contributed bug reports, and suggested minor +fixes; these can be found from file "VERSION.txt" in SCM. + +Apache Commons Net +Copyright 2001-2012 The Apache Software Foundation + +Copyright 2011 The Netty Project + +http://www.apache.org/licenses/LICENSE-2.0 + +This product contains a modified version of 'JZlib', a re-implementation of +zlib in pure Java, which can be obtained at: + + * LICENSE: + * license/LICENSE.jzlib.txt (BSD Style License) + * HOMEPAGE: + * http://www.jcraft.com/jzlib/ + +This product contains a modified version of 'Webbit', a Java event based +WebSocket and HTTP server: + +This product optionally depends on 'Protocol Buffers', Google's data +interchange format, which can be obtained at: + +This product optionally depends on 'SLF4J', a simple logging facade for Java, +which can be obtained at: + +This product optionally depends on 'Apache Log4J', a logging framework, +which can be obtained at: + +This product optionally depends on 'JBoss Logging', a logging framework, +which can be obtained at: + + * LICENSE: + * license/LICENSE.jboss-logging.txt (GNU LGPL 2.1) + * HOMEPAGE: + * http://anonsvn.jboss.org/repos/common/common-logging-spi/ + +This product optionally depends on 'Apache Felix', an open source OSGi +framework implementation, which can be obtained at: + + * LICENSE: + * license/LICENSE.felix.txt (Apache License 2.0) + * HOMEPAGE: + * http://felix.apache.org/ + +Jackson core and extension components may be licensed under different licenses. +To find the details that apply to this artifact see the accompanying LICENSE file. +For more information, including possible other licensing options, contact +FasterXML.com (http://fasterxml.com). + +Apache Ivy (TM) +Copyright 2007-2014 The Apache Software Foundation + +Portions of Ivy were originally developed at +Jayasoft SARL (http://www.jayasoft.fr/) +and are licensed to the Apache Software Foundation under the +"Software Grant License Agreement" + +SSH and SFTP support is provided by the JCraft JSch package, +which is open source software, available under +the terms of a BSD style license. +The original software and related information is available +at http://www.jcraft.com/jsch/. + + +ORC Core +Copyright 2013-2018 The Apache Software Foundation + +Apache Commons Lang +Copyright 2001-2011 The Apache Software Foundation + +ORC MapReduce +Copyright 2013-2018 The Apache Software Foundation + +Apache Parquet Format +Copyright 2017 The Apache Software Foundation + +Arrow Vectors +Copyright 2017 The Apache Software Foundation + +Arrow Format +Copyright 2017 The Apache Software Foundation + +Arrow Memory +Copyright 2017 The Apache Software Foundation + +Apache Commons CLI +Copyright 2001-2009 The Apache Software Foundation + +Google Guice - Extensions - Servlet +Copyright 2006-2011 Google, Inc. + +Apache Commons IO +Copyright 2002-2012 The Apache Software Foundation + +Google Guice - Core Library +Copyright 2006-2011 Google, Inc. + +mesos +Copyright 2017 The Apache Software Foundation + +Apache Parquet Hadoop Bundle (Incubating) +Copyright 2015 The Apache Software Foundation + +Hive Query Language +Copyright 2016 The Apache Software Foundation + +Apache Extras Companion for log4j 1.2. +Copyright 2007 The Apache Software Foundation + +Hive Metastore +Copyright 2016 The Apache Software Foundation + +Apache Commons Logging +Copyright 2003-2013 The Apache Software Foundation + +========================================================================= +== NOTICE file corresponding to section 4(d) of the Apache License, == +== Version 2.0, in this case for the DataNucleus distribution. == +========================================================================= + +=================================================================== +This product includes software developed by many individuals, +including the following: +=================================================================== +Erik Bengtson +Andy Jefferson + +=================================================================== +This product has included contributions from some individuals, +including the following: +=================================================================== + +=================================================================== +This product includes software developed by many individuals, +including the following: +=================================================================== +Andy Jefferson +Erik Bengtson +Joerg von Frantzius +Marco Schulze + +=================================================================== +This product has included contributions from some individuals, +including the following: +=================================================================== +Barry Haddow +Ralph Ullrich +David Ezzio +Brendan de Beer +David Eaves +Martin Taal +Tony Lai +Roland Szabo +Anton Troshin (Timesten) + +=================================================================== +This product also includes software developed by the TJDO project +(http://tjdo.sourceforge.net/). +=================================================================== + +=================================================================== +This product also includes software developed by the Apache Commons project +(http://commons.apache.org/). +=================================================================== + +Apache Commons Pool +Copyright 1999-2009 The Apache Software Foundation + +Apache Commons DBCP +Copyright 2001-2010 The Apache Software Foundation + +Apache Java Data Objects (JDO) +Copyright 2005-2006 The Apache Software Foundation + +Apache Jakarta HttpClient +Copyright 1999-2007 The Apache Software Foundation + +Apache HttpClient +Copyright 1999-2017 The Apache Software Foundation + +Apache Commons Codec +Copyright 2002-2014 The Apache Software Foundation + +src/test/org/apache/commons/codec/language/DoubleMetaphoneTest.java +contains test data from http://aspell.net/test/orig/batch0.tab. +Copyright (C) 2002 Kevin Atkinson (kevina@gnu.org) + +=============================================================================== + +The content of package org.apache.commons.codec.language.bm has been translated +from the original php source code available at http://stevemorse.org/phoneticinfo.htm +with permission from the original authors. +Original source copyright: +Copyright (c) 2008 Alexander Beider & Stephen P. Morse. + +============================================================================= += NOTICE file corresponding to section 4d of the Apache License Version 2.0 = +============================================================================= +This product includes software developed by +Joda.org (http://www.joda.org/). + +=================================================================== +This product has included contributions from some individuals, +including the following: +=================================================================== +Joerg von Frantzius +Thomas Marti +Barry Haddow +Marco Schulze +Ralph Ullrich +David Ezzio +Brendan de Beer +David Eaves +Martin Taal +Tony Lai +Roland Szabo +Marcus Mennemeier +Xuan Baldauf +Eric Sultan + +Apache Thrift +Copyright 2006-2010 The Apache Software Foundation. + +========================================================================= +== NOTICE file corresponding to section 4(d) of the Apache License, +== Version 2.0, in this case for the Apache Derby distribution. +== +== DO NOT EDIT THIS FILE DIRECTLY. IT IS GENERATED +== BY THE buildnotice TARGET IN THE TOP LEVEL build.xml FILE. +== +========================================================================= + +Apache Derby +Copyright 2004-2015 The Apache Software Foundation + +========================================================================= + +Portions of Derby were originally developed by +International Business Machines Corporation and are +licensed to the Apache Software Foundation under the +"Software Grant and Corporate Contribution License Agreement", +informally known as the "Derby CLA". +The following copyright notice(s) were affixed to portions of the code +with which this file is now or was at one time distributed +and are placed here unaltered. + +(C) Copyright 1997,2004 International Business Machines Corporation. All rights reserved. + +(C) Copyright IBM Corp. 2003. + +The portion of the functionTests under 'nist' was originally +developed by the National Institute of Standards and Technology (NIST), +an agency of the United States Department of Commerce, and adapted by +International Business Machines Corporation in accordance with the NIST +Software Acknowledgment and Redistribution document at +http://www.itl.nist.gov/div897/ctg/sql_form.htm + +The JDBC apis for small devices and JDBC3 (under java/stubs/jsr169 and +java/stubs/jdbc3) were produced by trimming sources supplied by the +Apache Harmony project. In addition, the Harmony SerialBlob and +SerialClob implementations are used. The following notice covers the Harmony sources: + +Portions of Harmony were originally developed by +Intel Corporation and are licensed to the Apache Software +Foundation under the "Software Grant and Corporate Contribution +License Agreement", informally known as the "Intel Harmony CLA". + +The Derby build relies on source files supplied by the Apache Felix +project. The following notice covers the Felix files: + + Apache Felix Main + Copyright 2008 The Apache Software Foundation + + I. Included Software + + This product includes software developed at + The Apache Software Foundation (http://www.apache.org/). + Licensed under the Apache License 2.0. + + This product includes software developed at + The OSGi Alliance (http://www.osgi.org/). + Copyright (c) OSGi Alliance (2000, 2007). + Licensed under the Apache License 2.0. + + This product includes software from http://kxml.sourceforge.net. + Copyright (c) 2002,2003, Stefan Haustein, Oberhausen, Rhld., Germany. + Licensed under BSD License. + + II. Used Software + + This product uses software developed at + The OSGi Alliance (http://www.osgi.org/). + Copyright (c) OSGi Alliance (2000, 2007). + Licensed under the Apache License 2.0. + + III. License Summary + - Apache License 2.0 + - BSD License + +The Derby build relies on jar files supplied by the Apache Lucene +project. The following notice covers the Lucene files: + +Apache Lucene +Copyright 2013 The Apache Software Foundation + +Includes software from other Apache Software Foundation projects, +including, but not limited to: + - Apache Ant + - Apache Jakarta Regexp + - Apache Commons + - Apache Xerces + +ICU4J, (under analysis/icu) is licensed under an MIT styles license +and Copyright (c) 1995-2008 International Business Machines Corporation and others + +Some data files (under analysis/icu/src/data) are derived from Unicode data such +as the Unicode Character Database. See http://unicode.org/copyright.html for more +details. + +Brics Automaton (under core/src/java/org/apache/lucene/util/automaton) is +BSD-licensed, created by Anders Møller. See http://www.brics.dk/automaton/ + +The levenshtein automata tables (under core/src/java/org/apache/lucene/util/automaton) were +automatically generated with the moman/finenight FSA library, created by +Jean-Philippe Barrette-LaPierre. This library is available under an MIT license, +see http://sites.google.com/site/rrettesite/moman and +http://bitbucket.org/jpbarrette/moman/overview/ + +The class org.apache.lucene.util.WeakIdentityMap was derived from +the Apache CXF project and is Apache License 2.0. + +The Google Code Prettify is Apache License 2.0. +See http://code.google.com/p/google-code-prettify/ + +JUnit (junit-4.10) is licensed under the Common Public License v. 1.0 +See http://junit.sourceforge.net/cpl-v10.html + +This product includes code (JaspellTernarySearchTrie) from Java Spelling Checkin +g Package (jaspell): http://jaspell.sourceforge.net/ +License: The BSD License (http://www.opensource.org/licenses/bsd-license.php) + +The snowball stemmers in + analysis/common/src/java/net/sf/snowball +were developed by Martin Porter and Richard Boulton. +The snowball stopword lists in + analysis/common/src/resources/org/apache/lucene/analysis/snowball +were developed by Martin Porter and Richard Boulton. +The full snowball package is available from + http://snowball.tartarus.org/ + +The KStem stemmer in + analysis/common/src/org/apache/lucene/analysis/en +was developed by Bob Krovetz and Sergio Guzman-Lara (CIIR-UMass Amherst) +under the BSD-license. + +The Arabic,Persian,Romanian,Bulgarian, and Hindi analyzers (common) come with a default +stopword list that is BSD-licensed created by Jacques Savoy. These files reside in: +analysis/common/src/resources/org/apache/lucene/analysis/ar/stopwords.txt, +analysis/common/src/resources/org/apache/lucene/analysis/fa/stopwords.txt, +analysis/common/src/resources/org/apache/lucene/analysis/ro/stopwords.txt, +analysis/common/src/resources/org/apache/lucene/analysis/bg/stopwords.txt, +analysis/common/src/resources/org/apache/lucene/analysis/hi/stopwords.txt +See http://members.unine.ch/jacques.savoy/clef/index.html. + +The German,Spanish,Finnish,French,Hungarian,Italian,Portuguese,Russian and Swedish light stemmers +(common) are based on BSD-licensed reference implementations created by Jacques Savoy and +Ljiljana Dolamic. These files reside in: +analysis/common/src/java/org/apache/lucene/analysis/de/GermanLightStemmer.java +analysis/common/src/java/org/apache/lucene/analysis/de/GermanMinimalStemmer.java +analysis/common/src/java/org/apache/lucene/analysis/es/SpanishLightStemmer.java +analysis/common/src/java/org/apache/lucene/analysis/fi/FinnishLightStemmer.java +analysis/common/src/java/org/apache/lucene/analysis/fr/FrenchLightStemmer.java +analysis/common/src/java/org/apache/lucene/analysis/fr/FrenchMinimalStemmer.java +analysis/common/src/java/org/apache/lucene/analysis/hu/HungarianLightStemmer.java +analysis/common/src/java/org/apache/lucene/analysis/it/ItalianLightStemmer.java +analysis/common/src/java/org/apache/lucene/analysis/pt/PortugueseLightStemmer.java +analysis/common/src/java/org/apache/lucene/analysis/ru/RussianLightStemmer.java +analysis/common/src/java/org/apache/lucene/analysis/sv/SwedishLightStemmer.java + +The Stempel analyzer (stempel) includes BSD-licensed software developed +by the Egothor project http://egothor.sf.net/, created by Leo Galambos, Martin Kvapil, +and Edmond Nolan. + +The Polish analyzer (stempel) comes with a default +stopword list that is BSD-licensed created by the Carrot2 project. The file resides +in stempel/src/resources/org/apache/lucene/analysis/pl/stopwords.txt. +See http://project.carrot2.org/license.html. + +The SmartChineseAnalyzer source code (smartcn) was +provided by Xiaoping Gao and copyright 2009 by www.imdict.net. + +WordBreakTestUnicode_*.java (under modules/analysis/common/src/test/) +is derived from Unicode data such as the Unicode Character Database. +See http://unicode.org/copyright.html for more details. + +The Morfologik analyzer (morfologik) includes BSD-licensed software +developed by Dawid Weiss and Marcin Miłkowski (http://morfologik.blogspot.com/). + +Morfologik uses data from Polish ispell/myspell dictionary +(http://www.sjp.pl/slownik/en/) licenced on the terms of (inter alia) +LGPL and Creative Commons ShareAlike. + +Morfologic includes data from BSD-licensed dictionary of Polish (SGJP) +(http://sgjp.pl/morfeusz/) + +Servlet-api.jar and javax.servlet-*.jar are under the CDDL license, the original +source code for this can be found at http://www.eclipse.org/jetty/downloads.php + +=========================================================================== +Kuromoji Japanese Morphological Analyzer - Apache Lucene Integration +=========================================================================== + +This software includes a binary and/or source version of data from + + mecab-ipadic-2.7.0-20070801 + +which can be obtained from + + http://atilika.com/releases/mecab-ipadic/mecab-ipadic-2.7.0-20070801.tar.gz + +or + + http://jaist.dl.sourceforge.net/project/mecab/mecab-ipadic/2.7.0-20070801/mecab-ipadic-2.7.0-20070801.tar.gz + +=========================================================================== +mecab-ipadic-2.7.0-20070801 Notice +=========================================================================== + +Nara Institute of Science and Technology (NAIST), +the copyright holders, disclaims all warranties with regard to this +software, including all implied warranties of merchantability and +fitness, in no event shall NAIST be liable for +any special, indirect or consequential damages or any damages +whatsoever resulting from loss of use, data or profits, whether in an +action of contract, negligence or other tortuous action, arising out +of or in connection with the use or performance of this software. + +A large portion of the dictionary entries +originate from ICOT Free Software. The following conditions for ICOT +Free Software applies to the current dictionary as well. + +Each User may also freely distribute the Program, whether in its +original form or modified, to any third party or parties, PROVIDED +that the provisions of Section 3 ("NO WARRANTY") will ALWAYS appear +on, or be attached to, the Program, which is distributed substantially +in the same form as set out herein and that such intended +distribution, if actually made, will neither violate or otherwise +contravene any of the laws and regulations of the countries having +jurisdiction over the User or the intended distribution itself. + +NO WARRANTY + +The program was produced on an experimental basis in the course of the +research and development conducted during the project and is provided +to users as so produced on an experimental basis. Accordingly, the +program is provided without any warranty whatsoever, whether express, +implied, statutory or otherwise. The term "warranty" used herein +includes, but is not limited to, any warranty of the quality, +performance, merchantability and fitness for a particular purpose of +the program and the nonexistence of any infringement or violation of +any right of any third party. + +Each user of the program will agree and understand, and be deemed to +have agreed and understood, that there is no warranty whatsoever for +the program and, accordingly, the entire risk arising from or +otherwise connected with the program is assumed by the user. + +Therefore, neither ICOT, the copyright holder, or any other +organization that participated in or was otherwise related to the +development of the program and their respective officials, directors, +officers and other employees shall be held liable for any and all +damages, including, without limitation, general, special, incidental +and consequential damages, arising out of or otherwise in connection +with the use or inability to use the program or any product, material +or result produced or otherwise obtained by using the program, +regardless of whether they have been advised of, or otherwise had +knowledge of, the possibility of such damages at any time during the +project or thereafter. Each user will be deemed to have agreed to the +foregoing by his or her commencement of use of the program. The term +"use" as used herein includes, but is not limited to, the use, +modification, copying and distribution of the program and the +production of secondary products from the program. + +In the case where the program, whether in its original form or +modified, was distributed or delivered to or received by a user from +any person, organization or entity other than ICOT, unless it makes or +grants independently of ICOT any specific warranty to the user in +writing, such person, organization or entity, will also be exempted +from and not be held liable to the user for any such damages as noted +above as far as the program is concerned. + +The Derby build relies on a jar file supplied by the JSON Simple +project, hosted at https://code.google.com/p/json-simple/. +The JSON simple jar file is licensed under the Apache 2.0 License. + +Hive CLI +Copyright 2016 The Apache Software Foundation + +Hive JDBC +Copyright 2016 The Apache Software Foundation + + +Chill is a set of Scala extensions for Kryo. +Copyright 2012 Twitter, Inc. + +Third Party Dependencies: + +Kryo 2.17 +BSD 3-Clause License +http://code.google.com/p/kryo + +Commons-Codec 1.7 +Apache Public License 2.0 +http://hadoop.apache.org + + + +Breeze is distributed under an Apache License V2.0 (See LICENSE) + +=============================================================================== + +Proximal algorithms outlined in Proximal.scala (package breeze.optimize.proximal) +are based on https://github.com/cvxgrp/proximal (see LICENSE for details) and distributed with +Copyright (c) 2014 by Debasish Das (Verizon), all rights reserved. + +=============================================================================== + +QuadraticMinimizer class in package breeze.optimize.proximal is distributed with Copyright (c) +2014, Debasish Das (Verizon), all rights reserved. + +=============================================================================== + +NonlinearMinimizer class in package breeze.optimize.proximal is distributed with Copyright (c) +2015, Debasish Das (Verizon), all rights reserved. + + +stream-lib +Copyright 2016 AddThis + +This product includes software developed by AddThis. + +This product also includes code adapted from: + +Apache Solr (http://lucene.apache.org/solr/) +Copyright 2014 The Apache Software Foundation + +Apache Mahout (http://mahout.apache.org/) +Copyright 2014 The Apache Software Foundation diff --git a/README.md b/README.md old mode 100644 new mode 100755 index 37ead91..32a7127 --- a/README.md +++ b/README.md @@ -1,2 +1,96 @@ -# oraf -Optimized RAndom Forests +# ORaF (Optimized Random Forest on Apache Spark) + +ORaF is a library which aims to improve the performance of distributed random forest training on large datasets in Spark MLlib. To optimize the training process, we introduce a local training phase in which we complete the tree induction of sufficiently small nodes in-memory on a single executor. Additionally, we group these nodes into larger and more balanced local training tasks using bin packing and effectively schedule the processing of these tasks into batches by computing their expected duration. Our algorithm speeds up the training process significantly (**~100x on our data**), enables the training of deeper decision trees and mitigates runtime memory issues. + +A thorough explanation of the used methods and experiments can be found in [Distributed Algorithms for Decision Forest Training in the Network Traffic Classification Task](https://dspace.cvut.cz/bitstream/handle/10467/76092/F3-BP-2018-Starosta-Radek-thesis.pdf?sequence=-1&isAllowed=y) thesis. + +## Installation + +Use `mvn package` to build the project to jar file in Maven. You can also download prebuilt jar file in the releases tab. + +We plan to add ORaF to https://spark-packages.org/ soon. + +Currently ORaF depends and was tested on Apache Spark 2.4.0. We will try to update the dependency regularly to more recent Spark versions. If you would like to try ORaF on a version of Spark that we do not officialy support yet, feel free to try it. In our experience, the jar file usually works even on slightly different minor or patch versions of Spark. + +## Example + +The interface is almost identical to the original RandomForestClassifier / RandomForestRegressor classes (see [RandomForestClassifier](https://spark.apache.org/docs/latest/ml-classification-regression.html#random-forest-classifier)). It includes all of the fundamental methods for training, saving / loading models and inference, but we don't support computing classification probabilities and feature importances (see #impurity_stats) + + import org.apache.spark.ml.classification.OptimizedRandomForestClassifier + + val orf = new OptimizedRandomForestClassifier() + .setImpurity("entropy") + .setMaxDepth(30) + .setNumTrees(5) + .setMaxMemoryMultiplier(2.0) + + // trainingData is a Dataset containing "label" (Double) and "features" (ml.Vector) columns + val model = orf.fit(trainingData) + + // testData is a Dataset with a "features" (ml.Vector) column, predictions are filled into a new "prediction" column + val dataWithPredictions = model.transform(testData) + + +## Old mllib interface example + +The training interface is again identical to the mllib RandomForest class (see [MLlib ensembles](https://spark.apache.org/docs/latest/mllib-ensembles.html)). This interface returns the same models as the new ml interface (OptimizedRandomForestClassificationModel / OptimizedRandomForestRegressionModel), as the old model is unable to store trees deeper than 30 levels because of node indexing. + + import org.apache.spark.mllib.tree.configuration.OptimizedForestStrategy + import org.apache.spark.mllib.tree.OptimizedRandomForest + + val strategy = new OptimizedForestStrategy(algo = Classification, + impurity = Entropy, + maxDepth = 30, + numClasses = 3, + numTrees = 5 + maxMemoryMultiplier = 2.0) + + // trainingData is an RDD of LabeledPoints + val (model, statistics) = OptimizedRandomForest.trainClassifier( + input = trainingData, + featureSubsetStrategy = "sqrt", + strategy = strategy, + numTrees = 5) + + // testData is an RDD of mllib.Vectors + val dataWithPredictions = testData.map { point => + (point, model.predict(point.features)) + } + +## Additional parameters + +These parameters can be set in the OptimizedForestStrategy object (RDD mllib interface), or in the OptimizedRandomForestClassifier / Regressor class (DataFrame ml interface). + +- maxMemoryMultiplier (Double) + - This parameter affects the threshold deciding whether a task is small enough to be trained locally. It is used to multiply the estimate of the tasks memory consumption (the larger the value, the smaller the task has to be in order for it to be selected for local training). Default value is 4.0, which is very conservative. Increasing this parameter can also help balancing the tasks if your dataset isn't very large and the training doesn't utilize the cluster fully. +- timePredictionStrategy (TimePredictionStrategy) + - Logic behind the task scheduling. By default, the tasks are sorted by the number of data points, which works well in most cases. During our experiments, we found that the entropy in the given node also plays a large role in the final training time of the nodes, so in our inhouse implementation we use a linear regressor combining both task size and entropy (see #thesis). +- localTrainingAlgorithm (LocalTrainingAlgorithm) + - Implementation of the local decision tree training. Default is an implementation by Siddharth Murching ([smurching](https://github.com/smurching), [SPARK-3162](https://github.com/apache/spark/pull/19433)) which is based on the Yggdrasil algorithm. In the current state, this implementation is probably not the most efficient solution, because it doesn't fully utilize the advantages of the columnar format, but still requires the data to be transformed into it. +- maxTasksPerBin (Int) + - This parameter can be used to limit the total number of tasks packed into one bin (the batch of training tasks sent to a single executor). By default, the amount of tasks is not limited and the algorithm tries to make the bins as large as possible. +- customSplits (Array[Array[Double]]) + - The default discretization logic that is hardcoded into the current random forest implementation can work poorly on some datasets (i.e. when classes are highly imbalanced), so this allows the users to pass in their own precomputed threshold values for individual features. + +## Notable differences + +### Removal of ImpurityStats from final models + +We have decided to remove the ImpurityStats objects in the finalized version of the tree model. In classification, the final predicted value is the majority class in the appropriate leaf node and we don't compute the individual class probabilities. In most cases, this does not have any significant impact on the classification performance [1], but helped us mitigate some of the memory management issues we've encountered with larger datasets. + +[1] L. Breiman. Bagging predictors. Technical Report 421, University of California Berkeley, 1994. + +### Removal of tree depth limit + +As the trees are now eventually trained locally on one executor core, we no longer need to have a globally unique index for every node. This means we can theoretically train the complete subtree for every node, although this would probably be too time intensive for large datasets. + +Because the improved algorithm allows training trees deeper than 30 levels which cannot be represented in the 1.x version of the MLlib decision tree models, the old mllib interface also returns the new ml models, which include a convenience predict method for the old mllib Vectors. (see #mlexample) + +### NodeIdCache enabled by default + +Additionally, our method relies heavily on the presence of NodeIdCache, which is used to quickly pair data points with their respective tree nodes. We have decided to enable it by default, as it provides a significant speed up by sacrificing some memory. + +## Authors + +* Radek Starosta (rstarost@cisco.com, github: @rstarosta) +* Jan Brabec (janbrabe@cisco.com, github: @BrabecJan, twitter: @BrabecJan91) diff --git a/pom.xml b/pom.xml new file mode 100755 index 0000000..38527a4 --- /dev/null +++ b/pom.xml @@ -0,0 +1,163 @@ + + 4.0.0 + com.cisco.cognitive + oraf + 1.1.0 + ${project.artifactId} + Optimized Random Forest + 2018 + + + 1.8 + 1.8 + UTF-8 + 2.11.8 + 2.11 + 2.11 + 4.2.0 + 2.4.0 + + + + + org.scala-lang + scala-library + ${scala.version} + + + + + junit + junit + 4.12 + test + + + org.scalatest + scalatest_${scala.compat.version} + 3.0.3 + test + + + org.apache.spark + spark-core_${scala.binary.version} + ${project.depVersion} + + + org.apache.spark + spark-core_${scala.binary.version} + ${project.depVersion} + test-jar + test + + + org.apache.spark + spark-streaming_${scala.binary.version} + ${project.depVersion} + + + org.apache.spark + spark-sql_${scala.binary.version} + ${project.depVersion} + + + org.apache.spark + spark-graphx_${scala.binary.version} + ${project.depVersion} + + + org.apache.spark + spark-mllib-local_${scala.binary.version} + ${project.depVersion} + + + org.apache.spark + spark-mllib-local_${scala.binary.version} + ${project.depVersion} + test-jar + test + + + org.apache.spark + spark-mllib_${scala.binary.version} + ${project.depVersion} + + + org.apache.spark + spark-mllib_${scala.binary.version} + ${project.depVersion} + test-jar + test + + + org.apache.spark + spark-sql_${scala.binary.version} + ${project.depVersion} + test-jar + test + + + org.apache.spark + spark-catalyst_${scala.binary.version} + ${project.depVersion} + test-jar + test + + + + + src/main/scala + src/test/scala + + + + net.alchim31.maven + scala-maven-plugin + 3.3.2 + + + + compile + testCompile + + + + -dependencyfile + ${project.build.directory}/.scala_dependencies + + + + + + + org.apache.maven.plugins + maven-surefire-plugin + 2.21.0 + + + true + + + + org.scalatest + scalatest-maven-plugin + 2.0.0 + + ${project.build.directory}/surefire-reports + . + TestSuiteReport.txt + + + + + test + + test + + + + + + + diff --git a/src/main/scala/org/apache/spark/ml/classification/OptimizedDecisionTreeClassifier.scala b/src/main/scala/org/apache/spark/ml/classification/OptimizedDecisionTreeClassifier.scala new file mode 100755 index 0000000..a84440f --- /dev/null +++ b/src/main/scala/org/apache/spark/ml/classification/OptimizedDecisionTreeClassifier.scala @@ -0,0 +1,341 @@ +/* + * Modified work Copyright (C) 2019 Cisco Systems + * + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification + +import org.apache.hadoop.fs.Path +import org.apache.spark.annotation.Since +import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors} +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.tree.OptimizedDecisionTreeModelReadWrite.NodeData +import org.apache.spark.ml.tree._ +import org.apache.spark.ml.tree.impl.OptimizedRandomForest +import org.apache.spark.ml.util.Instrumentation.instrumented +import org.apache.spark.ml.util._ +import org.apache.spark.mllib.tree.configuration.{TimePredictionStrategy, Algo => OldAlgo, OptimizedForestStrategy => OldStrategy} +import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel} +import org.apache.spark.mllib.linalg.{Vector => OldVector} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.Dataset +import org.json4s.JsonDSL._ +import org.json4s.{DefaultFormats, JObject} + + +/** + * Decision tree learning algorithm (http://en.wikipedia.org/wiki/Decision_tree_learning) + * for classification. + * It supports both binary and multiclass labels, as well as both continuous and categorical + * features. + */ +@Since("1.4.0") +class OptimizedDecisionTreeClassifier @Since("1.4.0") ( + @Since("1.4.0") override val uid: String) + extends ProbabilisticClassifier[Vector, OptimizedDecisionTreeClassifier, + OptimizedDecisionTreeClassificationModel] + with OptimizedDecisionTreeClassifierParams with DefaultParamsWritable { + + @Since("1.4.0") + def this() = this(Identifiable.randomUID("odtc")) + + // Override parameter setters from parent trait for Java API compatibility. + + /** @group setParam */ + @Since("1.4.0") + override def setMaxDepth(value: Int): this.type = set(maxDepth, value) + + /** @group setParam */ + @Since("1.4.0") + override def setMaxBins(value: Int): this.type = set(maxBins, value) + + /** @group setParam */ + @Since("1.4.0") + override def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) + + /** @group setParam */ + @Since("1.4.0") + override def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) + + /** @group expertSetParam */ + @Since("1.4.0") + override def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) + + /** @group expertSetParam */ + @Since("1.4.0") + override def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) + + /** + * Specifies how often to checkpoint the cached node IDs. + * E.g. 10 means that the cache will get checkpointed every 10 iterations. + * This is only used if cacheNodeIds is true and if the checkpoint directory is set in + * [[org.apache.spark.SparkContext]]. + * Must be at least 1. + * (default = 10) + * @group setParam + */ + @Since("1.4.0") + override def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) + + /** @group setParam */ + @Since("1.4.0") + override def setImpurity(value: String): this.type = set(impurity, value) + + /** @group setParam */ + @Since("1.6.0") + override def setSeed(value: Long): this.type = set(seed, value) + + /** @group setParam */ + @Since("2.0.0") + override def setMaxMemoryMultiplier(value: Double): this.type = set(maxMemoryMultiplier, value) + + /** @group setParam */ + @Since("2.0.0") + override def setTimePredictionStrategy(value: TimePredictionStrategy) = timePredictionStrategy = value + + /** @group setParam */ + @Since("2.0.0") + override def setMaxTasksPerBin(value: Int): this.type + = set(maxTasksPerBin, value) + + /** @group setParam */ + @Since("2.0.0") + override def setCustomSplits(value: Option[Array[Array[Double]]]) = customSplits = value + + /** @group setParam */ + @Since("2.0.0") + override def setLocalTrainingAlgorithm(value: LocalTrainingAlgorithm) = localTrainingAlgorithm = value + + override protected def train(dataset: Dataset[_]): OptimizedDecisionTreeClassificationModel = instrumented { instr => + instr.logPipelineStage(this) + instr.logDataset(dataset) + val categoricalFeatures: Map[Int, Int] = + MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) + val numClasses: Int = getNumClasses(dataset) + + if (isDefined(thresholds)) { + require($(thresholds).length == numClasses, this.getClass.getSimpleName + + ".train() called with non-matching numClasses and thresholds.length." + + s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}") + } + + val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses) + val strategy = + getOldStrategy(categoricalFeatures, numClasses) + + instr.logParams(this, params: _*) + + val trees = OptimizedRandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all", + seed = $(seed), instr = Some(instr), parentUID = Some(uid))._1 + + trees.head.asInstanceOf[OptimizedDecisionTreeClassificationModel] + } + + private[ml] def train(data: RDD[LabeledPoint], + oldStrategy: OldStrategy): OptimizedDecisionTreeClassificationModel = instrumented { instr => + instr.logPipelineStage(this) + instr.logDataset(data) + instr.logParams(this, params: _*) + + val trees = OptimizedRandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy = "all", + seed = 0L, instr = Some(instr), parentUID = Some(uid))._1 + + trees.head.asInstanceOf[OptimizedDecisionTreeClassificationModel] + } + + /** (private[ml]) Create a Strategy instance to use with the old API. */ + private[ml] def getOldStrategy( + categoricalFeatures: Map[Int, Int], + numClasses: Int): OldStrategy = { + super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity, + subsamplingRate = 1.0) + } + + @Since("1.4.1") + override def copy(extra: ParamMap): OptimizedDecisionTreeClassifier = defaultCopy(extra) +} + +@Since("1.4.0") +object OptimizedDecisionTreeClassifier + extends DefaultParamsReadable[OptimizedDecisionTreeClassifier] { + /** Accessor for supported impurities: entropy, gini */ + @Since("1.4.0") + final val supportedImpurities: Array[String] = TreeClassifierParams.supportedImpurities + + @Since("2.0.0") + override def load(path: String): OptimizedDecisionTreeClassifier = super.load(path) +} + +/** + * Decision tree model (http://en.wikipedia.org/wiki/Decision_tree_learning) for classification. + * It supports both binary and multiclass labels, as well as both continuous and categorical + * features. + */ +@Since("1.4.0") +class OptimizedDecisionTreeClassificationModel private[ml] ( + @Since("1.4.0")override val uid: String, + @Since("1.4.0")override val rootNode: OptimizedNode, + @Since("1.6.0")override val numFeatures: Int, + @Since("1.5.0")override val numClasses: Int) + extends ProbabilisticClassificationModel[Vector, OptimizedDecisionTreeClassificationModel] + with OptimizedDecisionTreeModel with OptimizedDecisionTreeClassifierParams with MLWritable with Serializable { + + require(rootNode != null, + "DecisionTreeClassificationModel given null rootNode, but it requires a non-null rootNode.") + + /** + * Construct a decision tree classification model. + * @param rootNode Root node of tree, with other nodes attached. + */ + private[ml] def this(rootNode: OptimizedNode, numFeatures: Int, numClasses: Int) = + this(Identifiable.randomUID("dtc"), rootNode, numFeatures, numClasses) + + override def predict(features: Vector): Double = { + rootNode.predictImpl(features).prediction + } + + def predict(features: OldVector): Double = { + rootNode.predictImpl(Vectors.dense(features.toArray)).prediction + } + + def oldPredict(vector: OldVector): Double = { + makePredictionForOldVector(rootNode, vector) + } + + private def makePredictionForOldVector(topNode: OptimizedNode, features: OldVector): Double = { + topNode match { + case node: OptimizedLeafNode => + node.prediction + case node: OptimizedInternalNode => + val shouldGoLeft = node.split match { + case split: ContinuousSplit => + features(split.featureIndex) <= split.threshold + + case split: CategoricalSplit => + // leftCategories will sort every time, rather use + split.leftCategories.contains(features(split.featureIndex)) + } + + if (shouldGoLeft) { + makePredictionForOldVector(node.leftChild, features) + } else { + makePredictionForOldVector(node.rightChild, features) + } + + case _ => throw new RuntimeException("Unexpected error in OptimizedDecisionTreeClassificationModel, unknown Node type.") + } + } + + // TODO: Make sure this is correct + override protected def predictRaw(features: Vector): Vector = { + val predictions = Array.fill[Double](numClasses)(0.0) + predictions(rootNode.predictImpl(features).prediction.toInt) = 1.0 + + Vectors.dense(predictions) + } + + override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = { + rawPrediction match { + case dv: DenseVector => + ProbabilisticClassificationModel.normalizeToProbabilitiesInPlace(dv) + dv + case sv: SparseVector => + throw new RuntimeException("Unexpected error in DecisionTreeClassificationModel:" + + " raw2probabilityInPlace encountered SparseVector") + } + } + + @Since("1.4.0") + override def copy(extra: ParamMap): OptimizedDecisionTreeClassificationModel = { + copyValues(new OptimizedDecisionTreeClassificationModel(uid, rootNode, numFeatures, numClasses), extra) + .setParent(parent) + } + + @Since("1.4.0") + override def toString: String = { + s"DecisionTreeClassificationModel (uid=$uid) of depth $depth with $numNodes nodes" + } + + /** Convert to spark.mllib DecisionTreeModel (losing some information) */ + override private[spark] def toOld: OldDecisionTreeModel = { + new OldDecisionTreeModel(rootNode.toOld(1), OldAlgo.Classification) + } + + @Since("2.0.0") + override def write: MLWriter = + new OptimizedDecisionTreeClassificationModel.DecisionTreeClassificationModelWriter(this) +} + +@Since("2.0.0") +object OptimizedDecisionTreeClassificationModel extends MLReadable[OptimizedDecisionTreeClassificationModel] { + + @Since("2.0.0") + override def read: MLReader[OptimizedDecisionTreeClassificationModel] = + new DecisionTreeClassificationModelReader + + @Since("2.0.0") + override def load(path: String): OptimizedDecisionTreeClassificationModel = super.load(path) + + private[OptimizedDecisionTreeClassificationModel] + class DecisionTreeClassificationModelWriter(instance: OptimizedDecisionTreeClassificationModel) + extends MLWriter { + + override protected def saveImpl(path: String): Unit = { + val extraMetadata: JObject = Map( + "numFeatures" -> instance.numFeatures, + "numClasses" -> instance.numClasses) + DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata)) + val (nodeData, _) = NodeData.build(instance.rootNode, 0) + val dataPath = new Path(path, "data").toString + sparkSession.createDataFrame(nodeData).write.parquet(dataPath) + } + } + + private class DecisionTreeClassificationModelReader + extends MLReader[OptimizedDecisionTreeClassificationModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[OptimizedDecisionTreeClassificationModel].getName + + override def load(path: String): OptimizedDecisionTreeClassificationModel = { + implicit val format = DefaultFormats + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] + val numClasses = (metadata.metadata \ "numClasses").extract[Int] + val root = OptimizedDecisionTreeModelReadWrite.loadTreeNodes(path, metadata, sparkSession) + val model = new OptimizedDecisionTreeClassificationModel(metadata.uid, root, numFeatures, numClasses) + metadata.getAndSetParams(model) + model + } + } + + /** Convert a model from the old API */ + private[ml] def fromOld( + oldModel: OldDecisionTreeModel, + parent: OptimizedDecisionTreeClassifier, + categoricalFeatures: Map[Int, Int], + numFeatures: Int = -1): OptimizedDecisionTreeClassificationModel = { + require(oldModel.algo == OldAlgo.Classification, + s"Cannot convert non-classification DecisionTreeModel (old API) to" + + s" DecisionTreeClassificationModel (new API). Algo is: ${oldModel.algo}") + val rootNode = OptimizedNode.fromOld(oldModel.topNode, categoricalFeatures) + val uid = if (parent != null) parent.uid else Identifiable.randomUID("dtc") + // Can't infer number of features from old model, so default to -1 + new OptimizedDecisionTreeClassificationModel(uid, rootNode, numFeatures, -1) + } +} diff --git a/src/main/scala/org/apache/spark/ml/classification/OptimizedRandomForestClassifier.scala b/src/main/scala/org/apache/spark/ml/classification/OptimizedRandomForestClassifier.scala new file mode 100755 index 0000000..4079905 --- /dev/null +++ b/src/main/scala/org/apache/spark/ml/classification/OptimizedRandomForestClassifier.scala @@ -0,0 +1,367 @@ +/* + * Modified work Copyright (C) 2019 Cisco Systems + * + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification + +import org.apache.spark.annotation.Since +import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors} +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.tree._ +import org.apache.spark.ml.tree.impl.OptimizedRandomForest +import org.apache.spark.ml.util.DefaultParamsReader.Metadata +import org.apache.spark.ml.util.Instrumentation.instrumented +import org.apache.spark.ml.util._ +import org.apache.spark.mllib.tree.configuration.{TimePredictionStrategy, Algo => OldAlgo} +import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel} +import org.apache.spark.mllib.linalg.{Vector => OldVector} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, Dataset} +import org.json4s.{DefaultFormats, JObject} +import org.json4s.JsonDSL._ +import org.apache.spark.sql.functions._ + + +/** + * Random Forest learning algorithm for + * classification. + * It supports both binary and multiclass labels, as well as both continuous and categorical + * features. + */ +@Since("1.4.0") +class OptimizedRandomForestClassifier @Since("1.4.0") ( + @Since("1.4.0") override val uid: String) + extends ProbabilisticClassifier[Vector, OptimizedRandomForestClassifier, + OptimizedRandomForestClassificationModel] + with OptimizedRandomForestClassifierParams with DefaultParamsWritable { + + @Since("1.4.0") + def this() = this(Identifiable.randomUID("orfc")) + + // Override parameter setters from parent trait for Java API compatibility. + + // Parameters from TreeClassifierParams: + + /** @group setParam */ + @Since("1.4.0") + override def setMaxDepth(value: Int): this.type = set(maxDepth, value) + + /** @group setParam */ + @Since("1.4.0") + override def setMaxBins(value: Int): this.type = set(maxBins, value) + + /** @group setParam */ + @Since("1.4.0") + override def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) + + /** @group setParam */ + @Since("1.4.0") + override def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) + + /** @group expertSetParam */ + @Since("1.4.0") + override def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) + + /** @group expertSetParam */ + @Since("1.4.0") + override def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) + + /** + * Specifies how often to checkpoint the cached node IDs. + * E.g. 10 means that the cache will get checkpointed every 10 iterations. + * This is only used if cacheNodeIds is true and if the checkpoint directory is set in + * [[org.apache.spark.SparkContext]]. + * Must be at least 1. + * (default = 10) + * @group setParam + */ + @Since("1.4.0") + override def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) + + /** @group setParam */ + @Since("1.4.0") + override def setImpurity(value: String): this.type = set(impurity, value) + + // Parameters from TreeEnsembleParams: + + /** @group setParam */ + @Since("1.4.0") + override def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value) + + /** @group setParam */ + @Since("1.4.0") + override def setSeed(value: Long): this.type = set(seed, value) + + // Parameters from RandomForestParams: + + /** @group setParam */ + @Since("1.4.0") + override def setNumTrees(value: Int): this.type = set(numTrees, value) + + /** @group setParam */ + @Since("1.4.0") + override def setFeatureSubsetStrategy(value: String): this.type = + set(featureSubsetStrategy, value) + + /** @group setParam */ + @Since("2.0.0") + override def setMaxMemoryMultiplier(value: Double): this.type = set(maxMemoryMultiplier, value) + + /** @group setParam */ + @Since("2.0.0") + override def setTimePredictionStrategy(value: TimePredictionStrategy) = timePredictionStrategy = value + + /** @group setParam */ + @Since("2.0.0") + override def setMaxTasksPerBin(value: Int): this.type + = set(maxTasksPerBin, value) + + /** @group setParam */ + @Since("2.0.0") + override def setCustomSplits(value: Option[Array[Array[Double]]]) = customSplits = value + + /** @group setParam */ + @Since("2.0.0") + override def setLocalTrainingAlgorithm(value: LocalTrainingAlgorithm) = localTrainingAlgorithm = value + + override protected def train(dataset: Dataset[_]): OptimizedRandomForestClassificationModel = instrumented { instr => + instr.logPipelineStage(this) + instr.logDataset(dataset) + val categoricalFeatures: Map[Int, Int] = + MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) + val numClasses: Int = getNumClasses(dataset) + + if (isDefined(thresholds)) { + require($(thresholds).length == numClasses, this.getClass.getSimpleName + + ".train() called with non-matching numClasses and thresholds.length." + + s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}") + } + + val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses) + val strategy = + super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity) + + instr.logParams(this, labelCol, featuresCol, predictionCol, probabilityCol, rawPredictionCol, + impurity, numTrees, featureSubsetStrategy, maxDepth, maxBins, maxMemoryInMB, minInfoGain, + minInstancesPerNode, seed, subsamplingRate, thresholds, cacheNodeIds, checkpointInterval) + + val trees = OptimizedRandomForest + .run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr))._1 + .map(_.asInstanceOf[OptimizedDecisionTreeClassificationModel]) + + val numFeatures = oldDataset.first().features.size + instr.logNumClasses(numClasses) + instr.logNumFeatures(numFeatures) + new OptimizedRandomForestClassificationModel(uid, trees, numFeatures, numClasses) + } + + @Since("1.4.1") + override def copy(extra: ParamMap): OptimizedRandomForestClassifier = defaultCopy(extra) +} + +@Since("1.4.0") +object OptimizedRandomForestClassifier + extends DefaultParamsReadable[OptimizedRandomForestClassifier] { + /** Accessor for supported impurity settings: entropy, gini */ + @Since("1.4.0") + final val supportedImpurities: Array[String] = TreeClassifierParams.supportedImpurities + + /** Accessor for supported featureSubsetStrategy settings: auto, all, onethird, sqrt, log2 */ + @Since("1.4.0") + final val supportedFeatureSubsetStrategies: Array[String] = + TreeEnsembleParams.supportedFeatureSubsetStrategies + + @Since("2.0.0") + override def load(path: String): OptimizedRandomForestClassifier = super.load(path) +} + +/** + * Random Forest model for classification. + * It supports both binary and multiclass labels, as well as both continuous and categorical + * features. + * + * @param _trees Decision trees in the ensemble. + * Warning: These have null parents. + */ +@Since("1.4.0") +class OptimizedRandomForestClassificationModel private[spark] ( + @Since("1.5.0") override val uid: String, + private val _trees: Array[OptimizedDecisionTreeClassificationModel], + @Since("1.6.0") override val numFeatures: Int, + @Since("1.5.0") override val numClasses: Int) + extends ProbabilisticClassificationModel[Vector, OptimizedRandomForestClassificationModel] + with OptimizedRandomForestClassifierParams with OptimizedTreeEnsembleModel[OptimizedDecisionTreeClassificationModel] + with MLWritable with Serializable { + + require(_trees.nonEmpty, "RandomForestClassificationModel requires at least 1 tree.") + + /** + * Construct a random forest classification model, with all trees weighted equally. + * + * @param trees Component trees + */ + private[spark] def this( + trees: Array[OptimizedDecisionTreeClassificationModel], + numFeatures: Int, + numClasses: Int) = + this(Identifiable.randomUID("rfc"), trees, numFeatures, numClasses) + + @Since("1.4.0") + override def trees: Array[OptimizedDecisionTreeClassificationModel] = _trees + + // Note: We may add support for weights (based on tree performance) later on. + private lazy val _treeWeights: Array[Double] = Array.fill[Double](_trees.length)(1.0) + + @Since("1.4.0") + override def treeWeights: Array[Double] = _treeWeights + + override protected def transformImpl(dataset: Dataset[_]): DataFrame = { + val bcastModel = dataset.sparkSession.sparkContext.broadcast(this) + val predictUDF = udf { (features: Any) => + bcastModel.value.predict(features.asInstanceOf[Vector]) + } + dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) + } + + override protected def predictRaw(features: Vector): Vector = { + // TODO: When we add a generic Bagging class, handle transform there: SPARK-7128 + // Classifies using majority votes. + // Ignore the tree weights since all are 1.0 for now. + val votes = Array.fill[Double](numClasses)(0.0) + _trees.view.foreach { tree => + votes(tree.rootNode.predictImpl(features).prediction.toInt) += 1 + } + Vectors.dense(votes) + } + + def predict(vector: OldVector): Double = { + predict(Vectors.dense(vector.toArray)) + } + + def oldPredict(vector: OldVector): Double = { + val predictions = _trees.map(_.oldPredict(vector)) + // Find most prevalent value + predictions.groupBy(identity).maxBy(_._2.length)._1 + } + + override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = { + rawPrediction match { + case dv: DenseVector => + ProbabilisticClassificationModel.normalizeToProbabilitiesInPlace(dv) + dv + case sv: SparseVector => + throw new RuntimeException("Unexpected error in RandomForestClassificationModel:" + + " raw2probabilityInPlace encountered SparseVector") + } + } + + @Since("1.4.0") + override def copy(extra: ParamMap): OptimizedRandomForestClassificationModel = { + copyValues(new OptimizedRandomForestClassificationModel(uid, _trees, numFeatures, numClasses), extra) + .setParent(parent) + } + + @Since("1.4.0") + override def toString: String = { + s"RandomForestClassificationModel (uid=$uid) with $getNumTrees trees" + } + + /** (private[ml]) Convert to a model in the old API */ + private[ml] def toOld: OldRandomForestModel = { + new OldRandomForestModel(OldAlgo.Classification, _trees.map(_.toOld)) + } + + @Since("2.0.0") + override def write: MLWriter = + new OptimizedRandomForestClassificationModelWriter(this) +} + +private +class OptimizedRandomForestClassificationModelWriter(instance: OptimizedRandomForestClassificationModel) + extends MLWriter { + + override protected def saveImpl(path: String): Unit = { + // Note: numTrees is not currently used, but could be nice to store for fast querying. + val extraMetadata: JObject = Map( + "numFeatures" -> instance.numFeatures, + "numClasses" -> instance.numClasses, + "numTrees" -> instance.getNumTrees) + OptimizedEnsembleModelReadWrite.saveImpl(instance, path, sparkSession, extraMetadata) + } +} + +@Since("2.0.0") +object OptimizedRandomForestClassificationModel extends MLReadable[OptimizedRandomForestClassificationModel] { + + @Since("2.0.0") + override def read: MLReader[OptimizedRandomForestClassificationModel] = + new RandomForestClassificationModelReader + + @Since("2.0.0") + override def load(path: String): OptimizedRandomForestClassificationModel = super.load(path) + + + private class RandomForestClassificationModelReader + extends MLReader[OptimizedRandomForestClassificationModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[OptimizedRandomForestClassificationModel].getName + private val treeClassName = classOf[OptimizedDecisionTreeClassificationModel].getName + + override def load(path: String): OptimizedRandomForestClassificationModel = { + implicit val format = DefaultFormats + val (metadata: Metadata, treesData: Array[(Metadata, OptimizedNode)], _) = + OptimizedEnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName) + val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] + val numClasses = (metadata.metadata \ "numClasses").extract[Int] + val numTrees = (metadata.metadata \ "numTrees").extract[Int] + + val trees: Array[OptimizedDecisionTreeClassificationModel] = treesData.map { + case (treeMetadata, root) => + val tree = + new OptimizedDecisionTreeClassificationModel(treeMetadata.uid, root, numFeatures, numClasses) + treeMetadata.getAndSetParams(tree) + tree + } + require(numTrees == trees.length, s"RandomForestClassificationModel.load expected $numTrees" + + s" trees based on metadata but found ${trees.length} trees.") + + val model = new OptimizedRandomForestClassificationModel(metadata.uid, trees, numFeatures, numClasses) + metadata.getAndSetParams(model) + model + } + } + + /** Convert a model from the old API */ + private[ml] def fromOld( + oldModel: OldRandomForestModel, + parent: OptimizedRandomForestClassifier, + categoricalFeatures: Map[Int, Int], + numClasses: Int, + numFeatures: Int = -1): OptimizedRandomForestClassificationModel = { + require(oldModel.algo == OldAlgo.Classification, "Cannot convert RandomForestModel" + + s" with algo=${oldModel.algo} (old API) to RandomForestClassificationModel (new API).") + val newTrees = oldModel.trees.map { tree => + // parent for each tree is null since there is no good way to set this. + OptimizedDecisionTreeClassificationModel.fromOld(tree, null, categoricalFeatures) + } + val uid = if (parent != null) parent.uid else Identifiable.randomUID("rfc") + new OptimizedRandomForestClassificationModel(uid, newTrees, numFeatures, numClasses) + } +} diff --git a/src/main/scala/org/apache/spark/ml/regression/OptimizedDecisionTreeRegressor.scala b/src/main/scala/org/apache/spark/ml/regression/OptimizedDecisionTreeRegressor.scala new file mode 100755 index 0000000..7d57375 --- /dev/null +++ b/src/main/scala/org/apache/spark/ml/regression/OptimizedDecisionTreeRegressor.scala @@ -0,0 +1,323 @@ +/* + * Modified work Copyright (C) 2019 Cisco Systems + * + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression + +import org.apache.hadoop.fs.Path +import org.apache.spark.annotation.Since +import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.linalg.{Vector, Vectors} +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.tree.OptimizedDecisionTreeModelReadWrite.NodeData +import org.apache.spark.ml.tree._ +import org.apache.spark.ml.tree.impl.OptimizedRandomForest +import org.apache.spark.ml.util.Instrumentation.instrumented +import org.apache.spark.ml.util._ +import org.apache.spark.ml.{PredictionModel, Predictor} +import org.apache.spark.mllib.tree.configuration.{TimePredictionStrategy, Algo => OldAlgo, OptimizedForestStrategy => OldStrategy} +import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel} +import org.apache.spark.mllib.linalg.{Vector => OldVector} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.{DataFrame, Dataset} +import org.json4s.JsonDSL._ +import org.json4s.{DefaultFormats, JObject} + + +/** + * Decision tree + * learning algorithm for regression. + * It supports both continuous and categorical features. + * + * TODO: Add maxPartitions setter + */ +@Since("1.4.0") +class OptimizedDecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) + extends Predictor[Vector, OptimizedDecisionTreeRegressor, OptimizedDecisionTreeRegressionModel] + with OptimizedDecisionTreeRegressorParams with DefaultParamsWritable { + + @Since("1.4.0") + def this() = this(Identifiable.randomUID("odtr")) + + // Override parameter setters from parent trait for Java API compatibility. + /** @group setParam */ + @Since("1.4.0") + override def setMaxDepth(value: Int): this.type = set(maxDepth, value) + + /** @group setParam */ + @Since("1.4.0") + override def setMaxBins(value: Int): this.type = set(maxBins, value) + + /** @group setParam */ + @Since("1.4.0") + override def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) + + /** @group setParam */ + @Since("1.4.0") + override def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) + + /** @group expertSetParam */ + @Since("1.4.0") + override def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) + + /** @group expertSetParam */ + @Since("1.4.0") + override def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) + + /** + * Specifies how often to checkpoint the cached node IDs. + * E.g. 10 means that the cache will get checkpointed every 10 iterations. + * This is only used if cacheNodeIds is true and if the checkpoint directory is set in + * [[org.apache.spark.SparkContext]]. + * Must be at least 1. + * (default = 10) + * @group setParam + */ + @Since("1.4.0") + override def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) + + /** @group setParam */ + @Since("1.4.0") + override def setImpurity(value: String): this.type = set(impurity, value) + + /** @group setParam */ + @Since("1.6.0") + override def setSeed(value: Long): this.type = set(seed, value) + + /** @group setParam */ + @Since("2.0.0") + override def setMaxMemoryMultiplier(value: Double): this.type = set(maxMemoryMultiplier, value) + + /** @group setParam */ + @Since("2.0.0") + override def setTimePredictionStrategy(value: TimePredictionStrategy) = timePredictionStrategy = value + + /** @group setParam */ + @Since("2.0.0") + override def setMaxTasksPerBin(value: Int): this.type + = set(maxTasksPerBin, value) + + /** @group setParam */ + @Since("2.0.0") + override def setCustomSplits(value: Option[Array[Array[Double]]]) = customSplits = value + + /** @group setParam */ + @Since("2.0.0") + override def setLocalTrainingAlgorithm(value: LocalTrainingAlgorithm) = localTrainingAlgorithm = value + + override protected def train(dataset: Dataset[_]): OptimizedDecisionTreeRegressionModel = instrumented { instr => + instr.logPipelineStage(this) + instr.logDataset(dataset) + val categoricalFeatures: Map[Int, Int] = + MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) + + val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) + val strategy = getOldStrategy(categoricalFeatures) + + instr.logParams(this, params: _*) + + val trees = OptimizedRandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all", + seed = $(seed), instr = Some(instr), parentUID = Some(uid))._1 + + trees.head.asInstanceOf[OptimizedDecisionTreeRegressionModel] + } + + private[ml] def train(data: RDD[LabeledPoint], + oldStrategy: OldStrategy): OptimizedDecisionTreeRegressionModel = instrumented { instr => + instr.logPipelineStage(this) + instr.logDataset(data) + instr.logParams(this, params: _*) + + val trees = OptimizedRandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy = "all", + seed = 0L, instr = Some(instr), parentUID = Some(uid))._1 + + trees.head.asInstanceOf[OptimizedDecisionTreeRegressionModel] + } + + /** (private[ml]) Create a Strategy instance to use with the old API. */ + private[ml] def getOldStrategy(categoricalFeatures: Map[Int, Int]): OldStrategy = { + super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity, + subsamplingRate = 1.0) + } + + @Since("1.4.0") + override def copy(extra: ParamMap): OptimizedDecisionTreeRegressor = defaultCopy(extra) +} + +@Since("1.4.0") +object OptimizedDecisionTreeRegressor + extends DefaultParamsReadable[OptimizedDecisionTreeRegressor] { + /** Accessor for supported impurities: variance */ + final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities + + @Since("2.0.0") + override def load(path: String): OptimizedDecisionTreeRegressor = super.load(path) +} + +/** + * + * Decision tree (Wikipedia) model for regression. + * It supports both continuous and categorical features. + * @param rootNode Root of the decision tree + */ +@Since("1.4.0") +class OptimizedDecisionTreeRegressionModel private[ml] ( + override val uid: String, + override val rootNode: OptimizedNode, + override val numFeatures: Int) + extends PredictionModel[Vector, OptimizedDecisionTreeRegressionModel] + with OptimizedDecisionTreeModel with OptimizedDecisionTreeRegressorParams with MLWritable with Serializable { + + require(rootNode != null, + "DecisionTreeRegressionModel given null rootNode, but it requires a non-null rootNode.") + + /** + * Construct a decision tree regression model. + * @param rootNode Root node of tree, with other nodes attached. + */ + private[ml] def this(rootNode: OptimizedNode, numFeatures: Int) = + this(Identifiable.randomUID("dtr"), rootNode, numFeatures) + + override def predict(features: Vector): Double = { + rootNode.predictImpl(features).prediction + } + + def predict(features: OldVector): Double = { + predict(Vectors.dense(features.toArray)) + } + + def oldPredict(vector: OldVector): Double = { + makePredictionForOldVector(rootNode, vector) + } + + private def makePredictionForOldVector(topNode: OptimizedNode, features: OldVector): Double = { + topNode match { + case node: OptimizedLeafNode => + node.prediction + case node: OptimizedInternalNode => + val shouldGoLeft = node.split match { + case split: ContinuousSplit => + features(split.featureIndex) <= split.threshold + + case split: CategoricalSplit => + // leftCategories will sort every time, rather use copied ml.Vector? + split.leftCategories.contains(features(split.featureIndex)) + } + + if (shouldGoLeft) { + makePredictionForOldVector(node.leftChild, features) + } else { + makePredictionForOldVector(node.rightChild, features) + } + + case _ => throw new RuntimeException("Unexpected error in OptimizedDecisionTreeRegressionModel, unknown Node type.") + } + } + + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { + transformSchema(dataset.schema, logging = true) + transformImpl(dataset) + } + + override protected def transformImpl(dataset: Dataset[_]): DataFrame = { + val predictUDF = udf { (features: Vector) => predict(features) } + var output = dataset.toDF() + if ($(predictionCol).nonEmpty) { + output = output.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) + } + output + } + + @Since("1.4.0") + override def copy(extra: ParamMap): OptimizedDecisionTreeRegressionModel = { + copyValues(new OptimizedDecisionTreeRegressionModel(uid, rootNode, numFeatures), extra).setParent(parent) + } + + @Since("1.4.0") + override def toString: String = { + s"DecisionTreeRegressionModel (uid=$uid) of depth $depth with $numNodes nodes" + } + + /** Convert to spark.mllib DecisionTreeModel (losing some information) */ + override private[spark] def toOld: OldDecisionTreeModel = { + new OldDecisionTreeModel(rootNode.toOld(1), OldAlgo.Regression) + } + + @Since("2.0.0") + override def write: MLWriter = + new OptimizedDecisionTreeRegressionModel.DecisionTreeRegressionModelWriter(this) +} + +@Since("2.0.0") +object OptimizedDecisionTreeRegressionModel extends MLReadable[OptimizedDecisionTreeRegressionModel] { + + @Since("2.0.0") + override def read: MLReader[OptimizedDecisionTreeRegressionModel] = + new DecisionTreeRegressionModelReader + + @Since("2.0.0") + override def load(path: String): OptimizedDecisionTreeRegressionModel = super.load(path) + + private[OptimizedDecisionTreeRegressionModel] + class DecisionTreeRegressionModelWriter(instance: OptimizedDecisionTreeRegressionModel) + extends MLWriter { + + override protected def saveImpl(path: String): Unit = { + val extraMetadata: JObject = Map( + "numFeatures" -> instance.numFeatures) + DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata)) + val (nodeData, _) = NodeData.build(instance.rootNode, 0) + val dataPath = new Path(path, "data").toString + sparkSession.createDataFrame(nodeData).write.parquet(dataPath) + } + } + + private class DecisionTreeRegressionModelReader + extends MLReader[OptimizedDecisionTreeRegressionModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[OptimizedDecisionTreeRegressionModel].getName + + override def load(path: String): OptimizedDecisionTreeRegressionModel = { + implicit val format = DefaultFormats + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] + val root = OptimizedDecisionTreeModelReadWrite.loadTreeNodes(path, metadata, sparkSession) + val model = new OptimizedDecisionTreeRegressionModel(metadata.uid, root, numFeatures) + metadata.getAndSetParams(model) + model + } + } + + /** Convert a model from the old API */ + private[ml] def fromOld( + oldModel: OldDecisionTreeModel, + parent: OptimizedDecisionTreeRegressor, + categoricalFeatures: Map[Int, Int], + numFeatures: Int = -1): OptimizedDecisionTreeRegressionModel = { + require(oldModel.algo == OldAlgo.Regression, + s"Cannot convert non-regression DecisionTreeModel (old API) to" + + s" DecisionTreeRegressionModel (new API). Algo is: ${oldModel.algo}") + val rootNode = OptimizedNode.fromOld(oldModel.topNode, categoricalFeatures) + val uid = if (parent != null) parent.uid else Identifiable.randomUID("dtr") + new OptimizedDecisionTreeRegressionModel(uid, rootNode, numFeatures) + } +} + diff --git a/src/main/scala/org/apache/spark/ml/regression/OptimizedRandomForestRegressor.scala b/src/main/scala/org/apache/spark/ml/regression/OptimizedRandomForestRegressor.scala new file mode 100755 index 0000000..b640e42 --- /dev/null +++ b/src/main/scala/org/apache/spark/ml/regression/OptimizedRandomForestRegressor.scala @@ -0,0 +1,326 @@ +/* + * Modified work Copyright (C) 2019 Cisco Systems + * + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression + +import org.apache.spark.annotation.Since +import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.linalg.{Vector, Vectors} +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.tree._ +import org.apache.spark.ml.tree.impl.OptimizedRandomForest +import org.apache.spark.ml.util.DefaultParamsReader.Metadata +import org.apache.spark.ml.util.Instrumentation.instrumented +import org.apache.spark.ml.util._ +import org.apache.spark.ml.{PredictionModel, Predictor} +import org.apache.spark.mllib.tree.configuration.{TimePredictionStrategy, Algo => OldAlgo} +import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel} +import org.apache.spark.mllib.linalg.{Vector => OldVector} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, Dataset} +import org.apache.spark.sql.functions._ +import org.json4s.{DefaultFormats, JObject} +import org.json4s.JsonDSL._ + + +/** + * Random Forest + * learning algorithm for regression. + * It supports both continuous and categorical features. + * + */ +@Since("1.4.0") +class OptimizedRandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) + extends Predictor[Vector, OptimizedRandomForestRegressor, OptimizedRandomForestRegressionModel] + with OptimizedRandomForestRegressorParams with DefaultParamsWritable { + + @Since("1.4.0") + def this() = this(Identifiable.randomUID("orfr")) + + // Override parameter setters from parent trait for Java API compatibility. + + // Parameters from TreeRegressorParams: + + /** @group setParam */ + @Since("1.4.0") + override def setMaxDepth(value: Int): this.type = set(maxDepth, value) + + /** @group setParam */ + @Since("1.4.0") + override def setMaxBins(value: Int): this.type = set(maxBins, value) + + /** @group setParam */ + @Since("1.4.0") + override def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value) + + /** @group setParam */ + @Since("1.4.0") + override def setMinInfoGain(value: Double): this.type = set(minInfoGain, value) + + /** @group expertSetParam */ + @Since("1.4.0") + override def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) + + /** @group expertSetParam */ + @Since("1.4.0") + override def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value) + + /** + * Specifies how often to checkpoint the cached node IDs. + * E.g. 10 means that the cache will get checkpointed every 10 iterations. + * This is only used if cacheNodeIds is true and if the checkpoint directory is set in + * [[org.apache.spark.SparkContext]]. + * Must be at least 1. + * (default = 10) + * @group setParam + */ + @Since("1.4.0") + override def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) + + /** @group setParam */ + @Since("1.4.0") + override def setImpurity(value: String): this.type = set(impurity, value) + + // Parameters from TreeEnsembleParams: + + /** @group setParam */ + @Since("1.4.0") + override def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value) + + /** @group setParam */ + @Since("1.4.0") + override def setSeed(value: Long): this.type = set(seed, value) + + // Parameters from RandomForestParams: + + /** @group setParam */ + @Since("1.4.0") + override def setNumTrees(value: Int): this.type = set(numTrees, value) + + /** @group setParam */ + @Since("1.4.0") + override def setFeatureSubsetStrategy(value: String): this.type = + set(featureSubsetStrategy, value) + + /** @group setParam */ + @Since("2.0.0") + override def setMaxMemoryMultiplier(value: Double): this.type = set(maxMemoryMultiplier, value) + + /** @group setParam */ + @Since("2.0.0") + override def setTimePredictionStrategy(value: TimePredictionStrategy) = timePredictionStrategy = value + + /** @group setParam */ + @Since("2.0.0") + override def setMaxTasksPerBin(value: Int): this.type + = set(maxTasksPerBin, value) + + /** @group setParam */ + @Since("2.0.0") + override def setCustomSplits(value: Option[Array[Array[Double]]]) = customSplits = value + + /** @group setParam */ + @Since("2.0.0") + override def setLocalTrainingAlgorithm(value: LocalTrainingAlgorithm) = localTrainingAlgorithm = value + + override protected def train(dataset: Dataset[_]): OptimizedRandomForestRegressionModel = instrumented { instr => + val categoricalFeatures: Map[Int, Int] = + MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) + val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) + val strategy = + super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity) + + instr.logPipelineStage(this) + instr.logDataset(dataset) + instr.logParams(this, labelCol, featuresCol, predictionCol, impurity, numTrees, + featureSubsetStrategy, maxDepth, maxBins, maxMemoryInMB, minInfoGain, + minInstancesPerNode, seed, subsamplingRate, cacheNodeIds, checkpointInterval) + + val trees = OptimizedRandomForest + .run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr))._1 + .map(_.asInstanceOf[OptimizedDecisionTreeRegressionModel]) + + val numFeatures = oldDataset.first().features.size + instr.logNamedValue(Instrumentation.loggerTags.numFeatures, numFeatures) + new OptimizedRandomForestRegressionModel(uid, trees, numFeatures) + } + + @Since("1.4.0") + override def copy(extra: ParamMap): OptimizedRandomForestRegressor = defaultCopy(extra) +} + +@Since("1.4.0") +object OptimizedRandomForestRegressor extends DefaultParamsReadable[OptimizedRandomForestRegressor]{ + /** Accessor for supported impurity settings: variance */ + @Since("1.4.0") + final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities + + /** Accessor for supported featureSubsetStrategy settings: auto, all, onethird, sqrt, log2 */ + @Since("1.4.0") + final val supportedFeatureSubsetStrategies: Array[String] = + TreeEnsembleParams.supportedFeatureSubsetStrategies + + @Since("2.0.0") + override def load(path: String): OptimizedRandomForestRegressor = super.load(path) +} + +/** + * Random Forest model for regression. + * It supports both continuous and categorical features. + * + * @param _trees Decision trees in the ensemble. + * @param numFeatures Number of features used by this model + */ +@Since("1.4.0") +class OptimizedRandomForestRegressionModel private[spark] ( + override val uid: String, + private val _trees: Array[OptimizedDecisionTreeRegressionModel], + override val numFeatures: Int) + extends PredictionModel[Vector, OptimizedRandomForestRegressionModel] + with OptimizedRandomForestRegressorParams with OptimizedTreeEnsembleModel[OptimizedDecisionTreeRegressionModel] + with MLWritable with Serializable { + + require(_trees.nonEmpty, "RandomForestRegressionModel requires at least 1 tree.") + + /** + * Construct a random forest regression model, with all trees weighted equally. + * + * @param trees Component trees + */ + private[ml] def this(trees: Array[OptimizedDecisionTreeRegressionModel], numFeatures: Int) = + this(Identifiable.randomUID("rfr"), trees, numFeatures) + + @Since("1.4.0") + override def trees: Array[OptimizedDecisionTreeRegressionModel] = _trees + + // Note: We may add support for weights (based on tree performance) later on. + private lazy val _treeWeights: Array[Double] = Array.fill[Double](_trees.length)(1.0) + + @Since("1.4.0") + override def treeWeights: Array[Double] = _treeWeights + + override protected def transformImpl(dataset: Dataset[_]): DataFrame = { + val bcastModel = dataset.sparkSession.sparkContext.broadcast(this) + val predictUDF = udf { (features: Any) => + bcastModel.value.predict(features.asInstanceOf[Vector]) + } + dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) + } + + override def predict(features: Vector): Double = { + // TODO: When we add a generic Bagging class, handle transform there. SPARK-7128 + // Predict average of tree predictions. + // Ignore the weights since all are 1.0 for now. + _trees.map(_.rootNode.predictImpl(features).prediction).sum / getNumTrees + } + + def predict(features: OldVector): Double = { + predict(Vectors.dense(features.toArray)) + } + + def oldPredict(vector: OldVector): Double = { + _trees.map(_.oldPredict(vector)).sum / getNumTrees + } + + @Since("1.4.0") + override def copy(extra: ParamMap): OptimizedRandomForestRegressionModel = { + copyValues(new OptimizedRandomForestRegressionModel(uid, _trees, numFeatures), extra).setParent(parent) + } + + @Since("1.4.0") + override def toString: String = { + s"RandomForestRegressionModel (uid=$uid) with $getNumTrees trees" + } + + /** (private[ml]) Convert to a model in the old API */ + private[ml] def toOld: OldRandomForestModel = { + new OldRandomForestModel(OldAlgo.Regression, _trees.map(_.toOld)) + } + + @Since("2.0.0") + override def write: MLWriter = + new OptimizedRandomForestRegressionModelWriter(this) +} + +private +class OptimizedRandomForestRegressionModelWriter(instance: OptimizedRandomForestRegressionModel) + extends MLWriter { + + override protected def saveImpl(path: String): Unit = { + val extraMetadata: JObject = Map( + "numFeatures" -> instance.numFeatures, + "numTrees" -> instance.getNumTrees) + OptimizedEnsembleModelReadWrite.saveImpl(instance, path, sparkSession, extraMetadata) + } +} + +@Since("2.0.0") +object OptimizedRandomForestRegressionModel extends MLReadable[OptimizedRandomForestRegressionModel] { + + @Since("2.0.0") + override def read: MLReader[OptimizedRandomForestRegressionModel] = new RandomForestRegressionModelReader + + @Since("2.0.0") + override def load(path: String): OptimizedRandomForestRegressionModel = super.load(path) + + + private class RandomForestRegressionModelReader extends MLReader[OptimizedRandomForestRegressionModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[OptimizedRandomForestRegressionModel].getName + private val treeClassName = classOf[OptimizedDecisionTreeRegressionModel].getName + + override def load(path: String): OptimizedRandomForestRegressionModel = { + implicit val format = DefaultFormats + val (metadata: Metadata, treesData: Array[(Metadata, OptimizedNode)], treeWeights: Array[Double]) = + OptimizedEnsembleModelReadWrite.loadImpl(path, sparkSession, className, treeClassName) + val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] + val numTrees = (metadata.metadata \ "numTrees").extract[Int] + + val trees: Array[OptimizedDecisionTreeRegressionModel] = treesData.map { case (treeMetadata, root) => + val tree = + new OptimizedDecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures) + treeMetadata.getAndSetParams(tree) + tree + } + require(numTrees == trees.length, s"RandomForestRegressionModel.load expected $numTrees" + + s" trees based on metadata but found ${trees.length} trees.") + + val model = new OptimizedRandomForestRegressionModel(metadata.uid, trees, numFeatures) + metadata.getAndSetParams(model) + model + } + } + + /** Convert a model from the old API */ + def fromOld( + oldModel: OldRandomForestModel, + parent: OptimizedRandomForestRegressor, + categoricalFeatures: Map[Int, Int], + numFeatures: Int = -1): OptimizedRandomForestRegressionModel = { + require(oldModel.algo == OldAlgo.Regression, "Cannot convert RandomForestModel" + + s" with algo=${oldModel.algo} (old API) to RandomForestRegressionModel (new API).") + val newTrees = oldModel.trees.map { tree => + // parent for each tree is null since there is no good way to set this. + OptimizedDecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures) + } + val uid = if (parent != null) parent.uid else Identifiable.randomUID("rfr") + new OptimizedRandomForestRegressionModel(uid, newTrees, numFeatures) + } +} diff --git a/src/main/scala/org/apache/spark/ml/tree/LocalTrainingAlgorithm.scala b/src/main/scala/org/apache/spark/ml/tree/LocalTrainingAlgorithm.scala new file mode 100755 index 0000000..dd8437a --- /dev/null +++ b/src/main/scala/org/apache/spark/ml/tree/LocalTrainingAlgorithm.scala @@ -0,0 +1,30 @@ +/* + * Copyright (C) 2019 Cisco Systems + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree + +import org.apache.spark.ml.tree.impl.{DecisionTreeMetadata, TreePoint} + +trait LocalTrainingAlgorithm extends Serializable { + + def fitNode(input: Array[TreePoint], + instanceWeights: Array[Double], + node: OptimizedLearningNode, + metadata: DecisionTreeMetadata, + splits: Array[Array[Split]], + maxDepthOverride: Option[Int] = None, + prune: Boolean = true): OptimizedNode +} diff --git a/src/main/scala/org/apache/spark/ml/tree/OptimizedNode.scala b/src/main/scala/org/apache/spark/ml/tree/OptimizedNode.scala new file mode 100755 index 0000000..c6a0263 --- /dev/null +++ b/src/main/scala/org/apache/spark/ml/tree/OptimizedNode.scala @@ -0,0 +1,430 @@ +/* + * Modified work Copyright (C) 2019 Cisco Systems + * + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree + +import org.apache.spark.ml.linalg.Vector +import org.apache.spark.mllib.tree.model.{ImpurityStats, InformationGainStats => OldInformationGainStats, Node => OldNode, Predict => OldPredict} + +/** + * Decision tree node interface. + */ +sealed abstract class OptimizedNode extends Serializable { + + // TODO: Add aggregate stats (once available). This will happen after we move the DecisionTree + // code into the new API and deprecate the old API. SPARK-3727 + + /** Prediction a leaf node makes, or which an internal node would make if it were a leaf node */ + def prediction: Double + + /** Impurity measure at this node (for training data) */ + def impurity: Double + + /** Recursive prediction helper method */ + private[ml] def predictImpl(features: Vector): OptimizedLeafNode + + /** + * Get the number of nodes in tree below this node, including leaf nodes. + * E.g., if this is a leaf, returns 0. If both children are leaves, returns 2. + */ + private[tree] def numDescendants: Int + + /** + * Recursive print function. + * @param indentFactor The number of spaces to add to each level of indentation. + */ + private[tree] def subtreeToString(indentFactor: Int = 0): String + + /** + * Get depth of tree from this node. + * E.g.: Depth 0 means this is a leaf node. Depth 1 means 1 internal and 2 leaf nodes. + */ + private[tree] def subtreeDepth: Int + + /** + * Create a copy of this node in the old Node format, recursively creating child nodes as needed. + * @param id Node ID using old format IDs + */ + private[ml] def toOld(id: Int): OldNode + + /** + * Trace down the tree, and return the largest feature index used in any split. + * @return Max feature index used in a split, or -1 if there are no splits (single leaf node). + */ + private[ml] def maxSplitFeatureIndex(): Int + + /** Returns a deep copy of the subtree rooted at this node. */ + private[tree] def deepCopy(): OptimizedNode +} + +private[ml] object OptimizedNode { + + /** + * Create a new Node from the old Node format, recursively creating child nodes as needed. + */ + def fromOld(oldNode: OldNode, categoricalFeatures: Map[Int, Int]): OptimizedNode = { + if (oldNode.isLeaf) { + // TODO: Once the implementation has been moved to this API, then include sufficient + // statistics here. + new OptimizedLeafNode(prediction = oldNode.predict.predict, + impurity = oldNode.impurity) + } else { + val gain = if (oldNode.stats.nonEmpty) { + oldNode.stats.get.gain + } else { + 0.0 + } + new OptimizedInternalNode(prediction = oldNode.predict.predict, impurity = oldNode.impurity, + gain = gain, leftChild = fromOld(oldNode.leftNode.get, categoricalFeatures), + rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures), + split = Split.fromOld(oldNode.split.get, categoricalFeatures)) + } + } +} + +/** + * Decision tree leaf node. + * @param prediction Prediction this node makes + * @param impurity Impurity measure at this node (for training data) + */ +class OptimizedLeafNode private[ml]( + override val prediction: Double, + override val impurity: Double) extends OptimizedNode { + + override def toString: String = + s"LeafNode(prediction = $prediction, impurity = $impurity)" + + override private[ml] def predictImpl(features: Vector): OptimizedLeafNode = this + + override private[tree] def numDescendants: Int = 0 + + override private[tree] def subtreeToString(indentFactor: Int = 0): String = { + val prefix: String = " " * indentFactor + prefix + s"Predict: $prediction\n" + } + + override private[tree] def subtreeDepth: Int = 0 + + override private[ml] def toOld(id: Int): OldNode = { + // TODO: Probability can't be computed without impurityStats + new OldNode(id, new OldPredict(prediction, prob = 0.0), + impurity, isLeaf = true, None, None, None, None) + } + + override private[ml] def maxSplitFeatureIndex(): Int = -1 + + override private[tree] def deepCopy(): OptimizedNode = { + new OptimizedLeafNode(prediction, impurity) + } +} + +/** + * Internal Decision Tree node. + * @param prediction Prediction this node would make if it were a leaf node + * @param impurity Impurity measure at this node (for training data) + * @param gain Information gain value. Values less than 0 indicate missing values; + * this quirk will be removed with future updates. + * @param leftChild Left-hand child node + * @param rightChild Right-hand child node + * @param split Information about the test used to split to the left or right child. + */ +class OptimizedInternalNode private[ml]( + override val prediction: Double, + override val impurity: Double, + val gain: Double, + val leftChild: OptimizedNode, + val rightChild: OptimizedNode, + val split: Split) extends OptimizedNode { + + // Note to developers: The constructor argument impurityStats should be reconsidered before we + // make the constructor public. We may be able to improve the representation. + + override def toString: String = { + s"InternalNode(prediction = $prediction, impurity = $impurity, split = $split)" + } + + override private[ml] def predictImpl(features: Vector): OptimizedLeafNode = { + if (split.shouldGoLeft(features)) { + leftChild.predictImpl(features) + } else { + rightChild.predictImpl(features) + } + } + + override private[tree] def numDescendants: Int = { + 2 + leftChild.numDescendants + rightChild.numDescendants + } + + override private[tree] def subtreeToString(indentFactor: Int = 0): String = { + val prefix: String = " " * indentFactor + prefix + s"If (${OptimizedInternalNode.splitToString(split, left = true)})\n" + + leftChild.subtreeToString(indentFactor + 1) + + prefix + s"Else (${OptimizedInternalNode.splitToString(split, left = false)})\n" + + rightChild.subtreeToString(indentFactor + 1) + } + + override private[tree] def subtreeDepth: Int = { + 1 + math.max(leftChild.subtreeDepth, rightChild.subtreeDepth) + } + + override private[ml] def toOld(id: Int): OldNode = { + assert(id.toLong * 2 < Int.MaxValue, "Decision Tree could not be converted from new to old API" + + " since the old API does not support deep trees.") + new OldNode(id, new OldPredict(prediction, prob = 0.0), impurity, + isLeaf = false, Some(split.toOld), Some(leftChild.toOld(OldNode.leftChildIndex(id))), + Some(rightChild.toOld(OldNode.rightChildIndex(id))), + Some(new OldInformationGainStats(gain, impurity, leftChild.impurity, rightChild.impurity, + new OldPredict(leftChild.prediction, prob = 0.0), + new OldPredict(rightChild.prediction, prob = 0.0)))) + } + + override private[ml] def maxSplitFeatureIndex(): Int = { + math.max(split.featureIndex, + math.max(leftChild.maxSplitFeatureIndex(), rightChild.maxSplitFeatureIndex())) + } + + override private[tree] def deepCopy(): OptimizedNode = { + new OptimizedInternalNode(prediction, impurity, gain, leftChild.deepCopy(), + rightChild.deepCopy(), split) + } +} + +private object OptimizedInternalNode { + + /** + * Helper method for [[Node.subtreeToString()]]. + * @param split Split to print + * @param left Indicates whether this is the part of the split going to the left, + * or that going to the right. + */ + private def splitToString(split: Split, left: Boolean): String = { + val featureStr = s"feature ${split.featureIndex}" + split match { + case contSplit: ContinuousSplit => + if (left) { + s"$featureStr <= ${contSplit.threshold}" + } else { + s"$featureStr > ${contSplit.threshold}" + } + case catSplit: CategoricalSplit => + val categoriesStr = catSplit.leftCategories.mkString("{", ",", "}") + if (left) { + s"$featureStr in $categoriesStr" + } else { + s"$featureStr not in $categoriesStr" + } + } + } +} + +/** + * Version of a node used in learning. This uses vars so that we can modify nodes as we split the + * tree by adding children, etc. + * + * For now, we use node IDs. These will be kept internal since we hope to remove node IDs + * in the future, or at least change the indexing (so that we can support much deeper trees). + * + * This node can either be: + * - a leaf node, with leftChild, rightChild, split set to null, or + * - an internal node, with all values set + * + * @param id We currently use the same indexing as the old implementation in + * [[org.apache.spark.mllib.tree.model.Node]], but this will change later. + * @param isLeaf Indicates whether this node will definitely be a leaf in the learned tree, + * so that we do not need to consider splitting it further. + * @param stats Impurity statistics for this node. + */ +private[tree] class OptimizedLearningNode( + var id: Int, + var leftChild: Option[OptimizedLearningNode], + var rightChild: Option[OptimizedLearningNode], + var split: Option[Split], + var isLeaf: Boolean, + var stats: ImpurityStats) extends Serializable { + + /** + * Convert this [[OptimizedLearningNode]] to a regular [[Node]], and recurse on any children. + */ + def toNode: OptimizedNode = toNode(prune = true) + + def toNode(prune: Boolean = true): OptimizedNode = { + + if (!leftChild.isEmpty || !rightChild.isEmpty) { + assert(leftChild.nonEmpty && rightChild.nonEmpty && split.nonEmpty && stats != null, + "Unknown error during Decision Tree learning. Could not convert LearningNode to Node.") + (leftChild.get.toNode(prune), rightChild.get.toNode(prune)) match { + case (l: OptimizedLeafNode, r: OptimizedLeafNode) if prune && l.prediction == r.prediction => + new OptimizedLeafNode(l.prediction, stats.impurity) + case (l, r) => + new OptimizedInternalNode(stats.impurityCalculator.predict, stats.impurity, stats.gain, + l, r, split.get) + } + } else { + if (stats.valid) { + new OptimizedLeafNode(stats.impurityCalculator.predict, stats.impurity) + } else { + // Here we want to keep same behavior with the old mllib.DecisionTreeModel + new OptimizedLeafNode(stats.impurityCalculator.predict, -1.0) + } + } + } + + def toNodeWithLocalNodesMap(localNodesMap: Map[(Int, Int), OptimizedNode], treeIndex: Int, prune: Boolean): OptimizedNode = { + localNodesMap.getOrElse((treeIndex, id), { + if (!leftChild.isEmpty || !rightChild.isEmpty) { + assert(leftChild.nonEmpty && rightChild.nonEmpty && split.nonEmpty && stats != null, + "Unknown error during Decision Tree learning. Could not convert LearningNode to Node.") + (leftChild.get.toNodeWithLocalNodesMap(localNodesMap, treeIndex, prune), + rightChild.get.toNodeWithLocalNodesMap(localNodesMap, treeIndex, prune) + ) match { + case (l: OptimizedLeafNode, r: OptimizedLeafNode) if prune && l.prediction == r.prediction => + new OptimizedLeafNode(l.prediction, stats.impurity) + case (l, r) => + new OptimizedInternalNode(stats.impurityCalculator.predict, stats.impurity, stats.gain, + l, r, split.get) + } + } else { + if (stats.valid) { + new OptimizedLeafNode(stats.impurityCalculator.predict, stats.impurity) + } else { + // Here we want to keep same behavior with the old mllib.DecisionTreeModel + new OptimizedLeafNode(stats.impurityCalculator.predict, -1.0) + } + } + }) + } + + /** + * Get the node index corresponding to this data point. + * This function mimics prediction, passing an example from the root node down to a leaf + * or unsplit node; that node's index is returned. + * + * @param binnedFeatures Binned feature vector for data point. + * @param splits possible splits for all features, indexed (numFeatures)(numSplits) + * @return Leaf index if the data point reaches a leaf. + * Otherwise, last node reachable in tree matching this example. + * Note: This is the global node index, i.e., the index used in the tree. + * This index is different from the index used during training a particular + * group of nodes on one call to + * [[org.apache.spark.ml.tree.impl.RandomForest.findBestSplits()]]. + */ + def predictImpl(binnedFeatures: Array[Int], splits: Array[Array[Split]]): Int = { + if (this.isLeaf || this.split.isEmpty) { + this.id + } else { + val split = this.split.get + val featureIndex = split.featureIndex + val splitLeft = split.shouldGoLeft(binnedFeatures(featureIndex), splits(featureIndex)) + if (this.leftChild.isEmpty) { + // Not yet split. Return next layer of nodes to train + if (splitLeft) { + OptimizedLearningNode.leftChildIndex(this.id) + } else { + OptimizedLearningNode.rightChildIndex(this.id) + } + } else { + if (splitLeft) { + this.leftChild.get.predictImpl(binnedFeatures, splits) + } else { + this.rightChild.get.predictImpl(binnedFeatures, splits) + } + } + } + } + +} + +private[tree] object OptimizedLearningNode { + + /** Create a node with some of its fields set. */ + def apply( + id: Int, + isLeaf: Boolean, + stats: ImpurityStats): OptimizedLearningNode = { + new OptimizedLearningNode(id, None, None, None, isLeaf, stats) + } + + /** Create an empty node with the given node index. Values must be set later on. */ + def emptyNode(nodeIndex: Int): OptimizedLearningNode = { + new OptimizedLearningNode(nodeIndex, None, None, None, false, null) + } + + // The below indexing methods were copied from spark.mllib.tree.model.Node + + /** + * Return the index of the left child of this node. + */ + def leftChildIndex(nodeIndex: Int): Int = nodeIndex << 1 + + /** + * Return the index of the right child of this node. + */ + def rightChildIndex(nodeIndex: Int): Int = (nodeIndex << 1) + 1 + + /** + * Get the parent index of the given node, or 0 if it is the root. + */ + def parentIndex(nodeIndex: Int): Int = nodeIndex >> 1 + + /** + * Return the level of a tree which the given node is in. + */ + def indexToLevel(nodeIndex: Int): Int = if (nodeIndex == 0) { + throw new IllegalArgumentException(s"0 is not a valid node index.") + } else { + java.lang.Integer.numberOfTrailingZeros(java.lang.Integer.highestOneBit(nodeIndex)) + } + + /** + * Returns true if this is a left child. + * Note: Returns false for the root. + */ + def isLeftChild(nodeIndex: Int): Boolean = nodeIndex > 1 && nodeIndex % 2 == 0 + + /** + * Return the maximum number of nodes which can be in the given level of the tree. + * @param level Level of tree (0 = root). + */ + def maxNodesInLevel(level: Int): Int = 1 << level + + /** + * Return the index of the first node in the given level. + * @param level Level of tree (0 = root). + */ + def startIndexInLevel(level: Int): Int = 1 << level + + /** + * Traces down from a root node to get the node with the given node index. + * This assumes the node exists. + */ + def getNode(nodeIndex: Int, rootNode: OptimizedLearningNode): OptimizedLearningNode = { + var tmpNode: OptimizedLearningNode = rootNode + var levelsToGo = indexToLevel(nodeIndex) + while (levelsToGo > 0) { + if ((nodeIndex & (1 << levelsToGo - 1)) == 0) { + tmpNode = tmpNode.leftChild.get + } else { + tmpNode = tmpNode.rightChild.get + } + levelsToGo -= 1 + } + tmpNode + } + +} diff --git a/src/main/scala/org/apache/spark/ml/tree/impl/AggUpdateUtils.scala b/src/main/scala/org/apache/spark/ml/tree/impl/AggUpdateUtils.scala new file mode 100755 index 0000000..07e4a16 --- /dev/null +++ b/src/main/scala/org/apache/spark/ml/tree/impl/AggUpdateUtils.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree.impl + +import org.apache.spark.ml.tree.Split + +/** + * Helpers for updating DTStatsAggregators during collection of sufficient stats for tree training. + */ +private[impl] object AggUpdateUtils { + + /** + * Updates the parent node stats of the passed-in impurity aggregator with the labels + * corresponding to the feature values at indices [from, to). + * @param indices Array of row indices for feature values; indices(i) = row index of the ith + * feature value + */ + private[impl] def updateParentImpurity( + statsAggregator: DTStatsAggregator, + indices: Array[Int], + from: Int, + to: Int, + instanceWeights: Array[Double], + labels: Array[Double]): Unit = { + from.until(to).foreach { idx => + val rowIndex = indices(idx) + val label = labels(rowIndex) + statsAggregator.updateParent(label, instanceWeights(rowIndex)) + } + } + + /** + * Update aggregator for an (unordered feature, label) pair + * @param featureSplits Array of splits for the current feature + */ + private[impl] def updateUnorderedFeature( + agg: DTStatsAggregator, + featureValue: Int, + label: Double, + featureIndex: Int, + featureIndexIdx: Int, + featureSplits: Array[Split], + instanceWeight: Double): Unit = { + val leftNodeFeatureOffset = agg.getFeatureOffset(featureIndexIdx) + // Each unordered split has a corresponding bin for impurity stats of data points that fall + // onto the left side of the split. For each unordered split, update left-side bin if applicable + // for the current data point. + val numSplits = agg.metadata.numSplits(featureIndex) + var splitIndex = 0 + while (splitIndex < numSplits) { + if (featureSplits(splitIndex).shouldGoLeft(featureValue, featureSplits)) { + agg.featureUpdate(leftNodeFeatureOffset, splitIndex, label, instanceWeight) + } + splitIndex += 1 + } + } + + /** Update aggregator for an (ordered feature, label) pair */ + private[impl] def updateOrderedFeature( + agg: DTStatsAggregator, + featureValue: Int, + label: Double, + featureIndexIdx: Int, + instanceWeight: Double): Unit = { + // The bin index of an ordered feature is just the feature value itself + val binIndex = featureValue + agg.update(featureIndexIdx, binIndex, label, instanceWeight) + } + +} diff --git a/src/main/scala/org/apache/spark/ml/tree/impl/FeatureColumn.scala b/src/main/scala/org/apache/spark/ml/tree/impl/FeatureColumn.scala new file mode 100755 index 0000000..a403fd7 --- /dev/null +++ b/src/main/scala/org/apache/spark/ml/tree/impl/FeatureColumn.scala @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree.impl + +import org.apache.spark.util.collection.BitSet + +/** + * Stores values for a single training data column (a single continuous or categorical feature). + * + * Values are currently stored in a dense representation only. + * TODO: Support sparse storage (to optimize deeper levels of the tree), and maybe compressed + * storage (to optimize upper levels of the tree). + * + * TODO: Sort feature values to support more complicated splitting logic (e.g. considering every + * possible continuous split instead of discretizing continuous features). + * + * TODO: Consider sorting feature values; the only changed required would be to + * sort values at construction-time. Sorting might improve locality during stats + * aggregation (we'd frequently update the same O(statsSize) array for a (feature, bin), + * instead of frequently updating for the same feature). + * + */ +private[impl] class FeatureColumn( + val featureIndex: Int, + val values: Array[Int]) + extends Serializable { + + /** For debugging */ + override def toString: String = { + " FeatureVector(" + + s" featureIndex: $featureIndex,\n" + + s" values: ${values.mkString(", ")},\n" + + " )" + } + + def deepCopy(): FeatureColumn = new FeatureColumn(featureIndex, values.clone()) + + override def equals(other: Any): Boolean = { + other match { + case o: FeatureColumn => + featureIndex == o.featureIndex && values.sameElements(o.values) + case _ => false + } + } + + override def hashCode: Int = { + com.google.common.base.Objects.hashCode( + featureIndex: java.lang.Integer, + values) + } + + /** + * Reorders the subset of feature values at indices [from, to) in the passed-in column + * according to the split information encoded in instanceBitVector (feature values for rows + * that split left appear before feature values for rows that split right). + * + * @param numLeftRows Number of rows on the left side of the split + * @param tempVals Destination buffer for reordered feature values + * @param instanceBitVector instanceBitVector(i) = true if the row for the (from + i)th feature + * value splits right, false otherwise + */ + private[ml] def updateForSplit( + from: Int, + to: Int, + numLeftRows: Int, + tempVals: Array[Int], + instanceBitVector: BitSet): Unit = { + LocalDecisionTreeUtils.updateArrayForSplit(values, from, to, numLeftRows, tempVals, + instanceBitVector) + } +} + +private[impl] object FeatureColumn { + /** + * Store column values sorted by decision tree node (i.e. all column values for a node occur + * in a contiguous subarray). + */ + private[impl] def apply(featureIndex: Int, values: Array[Int]) = { + new FeatureColumn(featureIndex, values) + } + +} diff --git a/src/main/scala/org/apache/spark/ml/tree/impl/ImpurityUtils.scala b/src/main/scala/org/apache/spark/ml/tree/impl/ImpurityUtils.scala new file mode 100755 index 0000000..b8b26a1 --- /dev/null +++ b/src/main/scala/org/apache/spark/ml/tree/impl/ImpurityUtils.scala @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree.impl + +import org.apache.spark.mllib.tree.impurity._ +import org.apache.spark.mllib.tree.model.ImpurityStats + +/** Helper methods for impurity-related calculations during node split decisions. */ +private[impl] object ImpurityUtils { + + /** + * Get impurity calculator containing statistics for all labels for rows corresponding to + * feature values in [from, to). + * @param indices indices(i) = row index corresponding to ith feature value + */ + private[impl] def getParentImpurityCalculator( + metadata: DecisionTreeMetadata, + indices: Array[Int], + from: Int, + to: Int, + instanceWeights: Array[Double], + labels: Array[Double]): ImpurityCalculator = { + // Compute sufficient stats (e.g. label counts) for all data at the current node, + // store result in currNodeStatsAgg.parentStats so that we can share it across + // all features for the current node + val currNodeStatsAgg = new DTStatsAggregator(metadata, featureSubset = None) + AggUpdateUtils.updateParentImpurity(currNodeStatsAgg, indices, from, to, + instanceWeights, labels) + currNodeStatsAgg.getParentImpurityCalculator() + } + + /** + * Calculate the impurity statistics for a given (feature, split) based upon left/right + * aggregates. + * + * @param parentImpurityCalculator An ImpurityCalculator containing the impurity stats + * of the node currently being split. + * @param leftImpurityCalculator left node aggregates for this (feature, split) + * @param rightImpurityCalculator right node aggregate for this (feature, split) + * @param metadata learning and dataset metadata for DecisionTree + * @return Impurity statistics for this (feature, split) + */ + private[impl] def calculateImpurityStats( + parentImpurityCalculator: ImpurityCalculator, + leftImpurityCalculator: ImpurityCalculator, + rightImpurityCalculator: ImpurityCalculator, + metadata: DecisionTreeMetadata): ImpurityStats = { + + val impurity: Double = parentImpurityCalculator.calculate() + + val leftCount = leftImpurityCalculator.count + val rightCount = rightImpurityCalculator.count + + val totalCount = leftCount + rightCount + + // If left child or right child doesn't satisfy minimum instances per node, + // then this split is invalid, return invalid information gain stats. + if ((leftCount < metadata.minInstancesPerNode) || + (rightCount < metadata.minInstancesPerNode)) { + return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator) + } + + val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0 + val rightImpurity = rightImpurityCalculator.calculate() + + val leftWeight = leftCount / totalCount.toDouble + val rightWeight = rightCount / totalCount.toDouble + + val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity + // If information gain doesn't satisfy minimum information gain, + // then this split is invalid, return invalid information gain stats. + if (gain < metadata.minInfoGain) { + return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator) + } + + new ImpurityStats(gain, impurity, parentImpurityCalculator, + leftImpurityCalculator, rightImpurityCalculator) + } + + /** + * Given an impurity aggregator containing label statistics for a given (node, feature, bin), + * returns the corresponding "centroid", used to order bins while computing best splits. + * + * @param metadata learning and dataset metadata for DecisionTree + */ + private[impl] def getCentroid( + metadata: DecisionTreeMetadata, + binStats: ImpurityCalculator): Double = { + + if (binStats.count != 0) { + if (metadata.isMulticlass) { + // multiclass classification + // For categorical features in multiclass classification, + // the bins are ordered by the impurity of their corresponding labels. + binStats.calculate() + } else if (metadata.isClassification) { + // binary classification + // For categorical features in binary classification, + // the bins are ordered by the count of class 1. + binStats.stats(1) + } else { + // regression + // For categorical features in regression and binary classification, + // the bins are ordered by the prediction. + binStats.predict + } + } else { + Double.MaxValue + } + } +} diff --git a/src/main/scala/org/apache/spark/ml/tree/impl/LocalDecisionTree.scala b/src/main/scala/org/apache/spark/ml/tree/impl/LocalDecisionTree.scala new file mode 100755 index 0000000..6b40db7 --- /dev/null +++ b/src/main/scala/org/apache/spark/ml/tree/impl/LocalDecisionTree.scala @@ -0,0 +1,268 @@ +/* + * Modified work Copyright (C) 2019 Cisco Systems + * + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree.impl + +import scala.util.Random + +import org.apache.spark.ml.tree._ +import org.apache.spark.mllib.tree.model.ImpurityStats +import org.apache.spark.util.random.SamplingUtils + + +/** Object exposing methods for local training of decision trees */ +class LocalDecisionTree extends LocalTrainingAlgorithm { + + /** + * Fully splits the passed-in node on the provided local dataset, returning the finalized Internal / Leaf Node + * with its fully trained descendants. + * + * @param node LearningNode to use as the root of the subtree fit on the passed-in dataset + * @param metadata learning and dataset metadata for DecisionTree + * @param splits splits(i) = array of splits for feature i + */ + def fitNode( + input: Array[TreePoint], + instanceWeights: Array[Double], + node: OptimizedLearningNode, + metadata: DecisionTreeMetadata, + splits: Array[Array[Split]], + maxDepthOverride: Option[Int] = None, + prune: Boolean = true): OptimizedNode = { + + // The case with 1 node (depth = 0) is handled separately. + // This allows all iterations in the depth > 0 case to use the same code. + // TODO: Check that learning works when maxDepth > 0 but learning stops at 1 node (because of + // other parameters). + val maxDepth = maxDepthOverride.getOrElse(metadata.maxDepth) + + if (maxDepth == 0) { + return node.toNode + } + + val labels = input.map(_.label) + + // Prepare column store. + // Note: rowToColumnStoreDense checks to make sure numRows < Int.MaxValue. + val colStoreInit: Array[Array[Int]] = LocalDecisionTreeUtils + .rowToColumnStoreDense(input.map(_.binnedFeatures)) + + // Fit a decision tree model on the dataset + val learningNode = trainDecisionTree(node, colStoreInit, instanceWeights, labels, + metadata, splits, maxDepth) + + // Create the finalized InternalNode and prune the tree + learningNode.toNode(prune) + } + + /** + * Locally fits a decision tree model. + * + * @param rootNode Node to use as root of the tree fit on the passed-in dataset + * @param colStoreInit Array of columns of training data + * @param instanceWeights Array of weights for each training example + * @param metadata learning and dataset metadata for DecisionTree + * @param splits splits(i) = Array of possible splits for feature i + * @return rootNode with its completely trained subtree + */ + private[ml] def trainDecisionTree( + rootNode: OptimizedLearningNode, + colStoreInit: Array[Array[Int]], + instanceWeights: Array[Double], + labels: Array[Double], + metadata: DecisionTreeMetadata, + splits: Array[Array[Split]], + maxDepth: Int): OptimizedLearningNode = { + + // Sort each column by decision tree node. + val colStore: Array[FeatureColumn] = colStoreInit.zipWithIndex.map { case (col, featureIndex) => + val featureArity: Int = metadata.featureArity.getOrElse(featureIndex, 0) + FeatureColumn(featureIndex, col) + } + + val numRows = colStore.headOption match { + case None => 0 + case Some(column) => column.values.length + } + + // Create a new TrainingInfo describing the status of our partially-trained subtree + // at each iteration of training + var trainingInfo: TrainingInfo = TrainingInfo(colStore, + nodeOffsets = Array[(Int, Int)]((0, numRows)), currentLevelActiveNodes = Array(rootNode)) + + // Iteratively learn, one level of the tree at a time. + // Note: We do not use node IDs. + var currentLevel = 0 + var doneLearning = false + val rng = new Random() + + while (currentLevel < maxDepth && !doneLearning) { + // Splits each active node if possible, returning an array of new active nodes + val nextLevelNodes: Array[OptimizedLearningNode] = + computeBestSplits(trainingInfo, instanceWeights, labels, metadata, splits, rng) + // Count number of non-leaf nodes in the next level + val estimatedRemainingActive = nextLevelNodes.count(!_.isLeaf) + // TODO: Check to make sure we split something, and stop otherwise. + doneLearning = currentLevel + 1 >= maxDepth || estimatedRemainingActive == 0 + if (!doneLearning) { + // Obtain a new trainingInfo instance describing our current training status + trainingInfo = trainingInfo.update(splits, nextLevelNodes) + } + currentLevel += 1 + } + + // Done with learning + rootNode + } + + /** + * Iterate over feature values and labels for a specific (node, feature), updating stats + * aggregator for the current node. + */ private[impl] def updateAggregator( statsAggregator: DTStatsAggregator, col: FeatureColumn, indices: Array[Int], instanceWeights: Array[Double], + labels: Array[Double], + from: Int, + to: Int, + featureIndexIdx: Int, + featureSplits: Array[Split]): Unit = { + val metadata = statsAggregator.metadata + if (metadata.isUnordered(col.featureIndex)) { + from.until(to).foreach { idx => + val rowIndex = indices(idx) + AggUpdateUtils.updateUnorderedFeature(statsAggregator, col.values(idx), labels(rowIndex), + featureIndex = col.featureIndex, featureIndexIdx, featureSplits, + instanceWeight = instanceWeights(rowIndex)) + } + } else { + from.until(to).foreach { idx => + val rowIndex = indices(idx) + AggUpdateUtils.updateOrderedFeature(statsAggregator, col.values(idx), labels(rowIndex), + featureIndexIdx, instanceWeight = instanceWeights(rowIndex)) + } + } + } + + /** + * Find the best splits for all active nodes + * + * @param trainingInfo Contains node offset info for current set of active nodes + * @return Array of new active nodes formed by splitting the current set of active nodes. + */ + private def computeBestSplits( + trainingInfo: TrainingInfo, + instanceWeights: Array[Double], + labels: Array[Double], + metadata: DecisionTreeMetadata, + splits: Array[Array[Split]], + rng: Random): Array[OptimizedLearningNode] = { + // For each node, select the best split across all features + trainingInfo match { + case TrainingInfo(columns: Array[FeatureColumn], nodeOffsets: Array[(Int, Int)], + currentLevelActiveNodes: Array[OptimizedLearningNode], _) => { + // Filter out leaf nodes from the previous iteration + val activeNonLeafs = currentLevelActiveNodes.zipWithIndex.filterNot(_._1.isLeaf) + // Iterate over the active nodes in the current level. + activeNonLeafs.flatMap { case (node: OptimizedLearningNode, nodeIndex: Int) => + // Features for the current node start at fromOffset and end at toOffset + val (from, to) = nodeOffsets(nodeIndex) + // Get impurityCalculator containing label stats for all data points at the current node + val parentImpurityCalc = ImpurityUtils.getParentImpurityCalculator(metadata, + trainingInfo.indices, from, to, instanceWeights, labels) + + // Randomly select a subset of features + val featureSubset = if (metadata.subsamplingFeatures) { + Some(SamplingUtils.reservoirSampleAndCount(Range(0, metadata.numFeatures).iterator, + metadata.numFeaturesPerNode, rng.nextLong())._1) + } else { + None + } + + val validFeatureSplits = OptimizedRandomForest.getFeaturesWithSplits(metadata, + featuresForNode = featureSubset) + // Find the best split for each feature for the current node + val splitsAndImpurityInfo = validFeatureSplits.map { case (_, featureIndex) => + val col = columns(featureIndex) + // Create a DTStatsAggregator to hold label statistics for each bin of the current + // feature & compute said label statistics + val statsAggregator = new DTStatsAggregator(metadata, Some(Array(featureIndex))) + updateAggregator(statsAggregator, col, trainingInfo.indices, instanceWeights, + labels, from, to, featureIndexIdx = 0, splits(col.featureIndex)) + // Choose best split for current feature based on label statistics + SplitUtils.chooseSplit(statsAggregator, featureIndex, featureIndexIdx = 0, + splits(featureIndex), Some(parentImpurityCalc)) + } + // Find the best split overall (across all features) for the current node + val (bestSplit, bestStats) = OptimizedRandomForest.getBestSplitByGain(parentImpurityCalc, + metadata, featuresForNode = None, splitsAndImpurityInfo) + // Split current node, get an iterator over its children + splitIfPossible(node, metadata, bestStats, bestSplit) + } + } + } + } + + /** + * Splits the passed-in node if permitted by the parameters of the learning algorithm, + * returning an iterator over its children. Returns an empty array if node could not be split. + * + * @param metadata learning and dataset metadata for DecisionTree + * @param stats Label impurity stats associated with the current node + */ + private[impl] def splitIfPossible( + node: OptimizedLearningNode, + metadata: DecisionTreeMetadata, + stats: ImpurityStats, + split: Split): Iterator[OptimizedLearningNode] = { + if (stats.valid) { + // Split node and return an iterator over its children; we filter out leaf nodes later + doSplit(node, split, stats) + Iterator(node.leftChild.get, node.rightChild.get) + } else { + node.stats = stats + node.isLeaf = true + Iterator() + } + } + + /** + * Splits the passed-in node. This method returns nothing, but modifies the passed-in node + * by updating its split and stats members. + * + * @param split Split to associate with the passed-in node + * @param stats Label impurity statistics to associate with the passed-in node + */ + private[impl] def doSplit( + node: OptimizedLearningNode, + split: Split, + stats: ImpurityStats): Unit = { + val leftChildIsLeaf = stats.leftImpurity == 0 + node.leftChild = Some(OptimizedLearningNode(id = OptimizedLearningNode.leftChildIndex(node.id), + isLeaf = leftChildIsLeaf, + ImpurityStats.getEmptyImpurityStats(stats.leftImpurityCalculator))) + val rightChildIsLeaf = stats.rightImpurity == 0 + node.rightChild = Some(OptimizedLearningNode( + id = OptimizedLearningNode.rightChildIndex(node.id), + isLeaf = rightChildIsLeaf, + ImpurityStats.getEmptyImpurityStats(stats.rightImpurityCalculator) + )) + node.split = Some(split) + node.isLeaf = false + node.stats = stats + } + +} diff --git a/src/main/scala/org/apache/spark/ml/tree/impl/LocalDecisionTreeUtils.scala b/src/main/scala/org/apache/spark/ml/tree/impl/LocalDecisionTreeUtils.scala new file mode 100755 index 0000000..99f632a --- /dev/null +++ b/src/main/scala/org/apache/spark/ml/tree/impl/LocalDecisionTreeUtils.scala @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree.impl + +import scala.collection.mutable.Builder + +import org.apache.spark.internal.Logging +import org.apache.spark.util.collection.BitSet + +/** + * Utility methods specific to local decision tree training. + */ +private[ml] object LocalDecisionTreeUtils extends Logging { + + /** + * Convert a dataset of binned feature values from row storage to column storage. + * Stores data as [[org.apache.spark.ml.linalg.DenseVector]]. + * + * + * @param rowStore An array of input data rows, each represented as an + * int array of binned feature values + * @return Transpose of rowStore as an array of columns consisting of binned feature values. + * + * TODO: Add implementation for sparse data. + * For sparse data, distribute more evenly based on number of non-zeros. + * (First collect stats to decide how to partition.) + */ + private[impl] def rowToColumnStoreDense(rowStore: Array[Array[Int]]): Array[Array[Int]] = { + // Compute the number of rows in the data + val numRows = { + val longNumRows: Long = rowStore.length + require(longNumRows < Int.MaxValue, s"rowToColumnStore given RDD with $longNumRows rows," + + s" but can handle at most ${Int.MaxValue} rows") + longNumRows.toInt + } + + // Check that the input dataset isn't empty (0 rows) or featureless (rows with 0 features) + require(numRows > 0, "Local decision tree training requires numRows > 0.") + val numFeatures = rowStore(0).length + require(numFeatures > 0, "Local decision tree training requires numFeatures > 0.") + // Return the transpose of the rowStore matrix + rowStore.transpose + } + + private[impl] def rowToColumnStoreDenseWithSubsampling + (rowStore: Array[Array[Int]], selectedFeatures: Array[Int]): Array[Array[Int]] = { + // Compute the number of rows in the data + val numRows = { + val longNumRows: Long = rowStore.length + require(longNumRows < Int.MaxValue, s"rowToColumnStore given RDD with $longNumRows rows," + + s" but can handle at most ${Int.MaxValue} rows") + longNumRows.toInt + } + + // Check that the input dataset isn't empty (0 rows) or featureless (rows with 0 features) + require(numRows > 0, "Local decision tree training requires numRows > 0.") + val numFeatures = rowStore(0).length + require(numFeatures > 0, "Local decision tree training requires numFeatures > 0.") + // Return the transpose of the rowStore matrix + + + transposeSelectedFeatures(rowStore, selectedFeatures) + } + + + def transposeSelectedFeatures(rowStore: Array[Array[Int]], + selectedFeatures: Array[Int]): Array[Array[Int]] = { + val bb: Builder[Array[Int], Array[Array[Int]]] = Array.newBuilder + + val bs = selectedFeatures map (_ => Array.newBuilder[Int]) + for (xs <- rowStore) { + var i = 0 + for (x <- selectedFeatures) { + bs(i) += xs(x) + i += 1 + } + } + for (b <- bs) bb += b.result() + bb.result() + } + + /** + * Reorders the subset of array values at indices [from, to) + * according to the split information encoded in instanceBitVector (values for rows + * that split left appear before feature values for rows that split right). + * + * @param numLeftRows Number of rows on the left side of the split + * @param tempVals Destination buffer for reordered feature values + * @param instanceBitVector instanceBitVector(i) = true if the row corresponding to the + * (from + i)th array value splits right, false otherwise + */ + private[ml] def updateArrayForSplit( + values: Array[Int], + from: Int, + to: Int, + numLeftRows: Int, + tempVals: Array[Int], + instanceBitVector: BitSet): Unit = { + + // BEGIN SORTING + // We sort the [from, to) slice of col based on instance bit. + // All instances going "left" in the split (which are false) + // should be ordered before the instances going "right". The instanceBitVector + // gives us the split bit value for each instance based on the instance's index. + // We copy our feature values into @tempVals and @tempIndices either: + // 1) in the [from, numLeftRows) range if the bit is false, or + // 2) in the [numLeftRows, to) range if the bit is true. + var (leftInstanceIdx, rightInstanceIdx) = (0, numLeftRows) + var idx = from + while (idx < to) { + val bit = instanceBitVector.get(idx - from) + if (bit) { + tempVals(rightInstanceIdx) = values(idx) + rightInstanceIdx += 1 + } else { + tempVals(leftInstanceIdx) = values(idx) + leftInstanceIdx += 1 + } + idx += 1 + } + // END SORTING + // update the column values and indices + // with the corresponding indices + System.arraycopy(tempVals, 0, values, from, to - from) + } + +} diff --git a/src/main/scala/org/apache/spark/ml/tree/impl/LocalTrainingScheduling.scala b/src/main/scala/org/apache/spark/ml/tree/impl/LocalTrainingScheduling.scala new file mode 100755 index 0000000..e7df8ee --- /dev/null +++ b/src/main/scala/org/apache/spark/ml/tree/impl/LocalTrainingScheduling.scala @@ -0,0 +1,117 @@ +/* + * Copyright (C) 2019 Cisco Systems + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree.impl + +import org.apache.spark.ml.tree.OptimizedLearningNode +import org.apache.spark.mllib.tree.configuration.TimePredictionStrategy + +import scala.collection.mutable +import scala.collection.mutable.ListBuffer + +/** + * Represents a single decision tree Node selected to be split locally. + * + * @param node OptimizedLearningNode to split + * @param treeIndex index of its corresponding tree in the ensemble + * @param rows number of data points in the data subset corresponding to the node + * @param impurity the impurity value in the current data subset corresponding to the node + */ +case class LocalTrainingTask(node: OptimizedLearningNode, + treeIndex: Int, + rows: Long, + impurity: Double) { + + /** + * Computes the expected time requirement of the task based on a given time prediction strategy. + * The computed time is a value relative to other tasks, not the actual time in seconds. + * + * @param timePredictionStrategy strategy to calculate expected time requirements + * @return Double time prediction (relative value) + */ + private[impl] def computeTimePrediction(timePredictionStrategy: TimePredictionStrategy) + : Double = { + timePredictionStrategy.predict(rows, impurity) + } +} +object LocalTrainingTask { + // Implicit ordering of the local training tasks in decreasing order based on its data size -- + // we to pack the largest tasks first (greedy first-fit descending bin-packing). + implicit val orderingByRows: Ordering[LocalTrainingTask] = + Ordering.by((task: LocalTrainingTask) => task.rows).reverse +} + +/** + * Represents a set of LocalTrainingTasks to be processed together on one executor. + * (i.e. the total memory requirements of all tasks in the bin is below the local training threshold) + * + * @param maxRows the maximum number of data points which fit in this bin + * @param timePredictionStrategy strategy to calculate expected time requirements + */ +class LocalTrainingBin(val maxRows: Long, + timePredictionStrategy: TimePredictionStrategy) { + var currentRows: Long = 0 + var tasks: ListBuffer[LocalTrainingTask] = mutable.ListBuffer[LocalTrainingTask]() + var totalTimePrediction: Double = 0 + + /** + * Attempts to add the task into the LocalTrainingBin and returns whether the action succeeded. + * @param task LocalTrainingTask + * @return true if task was succesfully added / false if the task couldn't fit anymore + */ + def fitTask(task: LocalTrainingTask): Boolean = { + if (currentRows + task.rows <= maxRows) { + tasks += task + currentRows += task.rows + totalTimePrediction += task.computeTimePrediction(timePredictionStrategy) + return true + } + false + } +} + +object LocalTrainingBin { + // Implicit ordering of the bins -- we want to process the bins that are expected to take + // the longest time during the earliest batches of the local training process. + implicit val orderingByTimePrediction: Ordering[LocalTrainingBin] = + Ordering.by((bin: LocalTrainingBin) => bin.totalTimePrediction).reverse +} + +/** + * + * @param maxBinRows + * @param timePredictionStrategy + * @param maxTasksPerBin + */ +class LocalTrainingPlan(val maxBinRows: Long, + val timePredictionStrategy: TimePredictionStrategy, + val maxTasksPerBin: Int) { + var bins: mutable.ListBuffer[LocalTrainingBin] = mutable.ListBuffer[LocalTrainingBin]() + + /** + * Schedules the LocalTrainingTask into the first available LocalTrainingBin, or creates a new + * one if it doesn't fit into any of them. + * + * @param task LocalTrainingTask + */ + def scheduleTask(task: LocalTrainingTask): Unit = { + bins.find(bin => bin.tasks.size < maxTasksPerBin && bin.fitTask(task)).getOrElse { + val newBin = new LocalTrainingBin(maxBinRows, timePredictionStrategy) + newBin.fitTask(task) + bins += newBin + } + } +} diff --git a/src/main/scala/org/apache/spark/ml/tree/impl/OptimizedRandomForest.scala b/src/main/scala/org/apache/spark/ml/tree/impl/OptimizedRandomForest.scala new file mode 100755 index 0000000..3c83bc5 --- /dev/null +++ b/src/main/scala/org/apache/spark/ml/tree/impl/OptimizedRandomForest.scala @@ -0,0 +1,1264 @@ +/* + * Modified work Copyright (C) 2019 Cisco Systems + * + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree.impl + +import java.io.IOException + +import org.apache.spark.Partitioner +import org.apache.spark.internal.Logging +import org.apache.spark.ml.classification.OptimizedDecisionTreeClassificationModel +import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.regression.OptimizedDecisionTreeRegressionModel +import org.apache.spark.ml.tree._ +import org.apache.spark.ml.util.Instrumentation +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, OptimizedForestStrategy => OldStrategy} +import org.apache.spark.mllib.tree.impurity.ImpurityCalculator +import org.apache.spark.mllib.tree.model.ImpurityStats +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.random.{SamplingUtils, XORShiftRandom} + +import scala.collection.{SeqView, mutable} +import scala.util.{Random, Try} + + +/** + * ALGORITHM + * + * This is a sketch of the algorithm to help new developers. + * + * The algorithm partitions data by instances (rows). + * On each iteration, the algorithm splits a set of nodes. In order to choose the best split + * for a given node, sufficient statistics are collected from the distributed data. + * For each node, the statistics are collected to some worker node, and that worker selects + * the best split. + * + * This setup requires discretization of continuous features. This binning is done in the + * findSplits() method during initialization, after which each continuous feature becomes + * an ordered discretized feature with at most maxBins possible values. + * + * The main loop in the algorithm operates on a queue of nodes (nodeStack). These nodes + * lie at the periphery of the tree being trained. If multiple trees are being trained at once, + * then this queue contains nodes from all of them. Each iteration works roughly as follows: + * On the master node: + * - Some number of nodes are pulled off of the queue (based on the amount of memory + * required for their sufficient statistics). + * - For random forests, if featureSubsetStrategy is not "all," then a subset of candidate + * features are chosen for each node. See method selectNodesToSplit(). + * On worker nodes, via method findBestSplits(): + * - The worker makes one pass over its subset of instances. + * - For each (tree, node, feature, split) tuple, the worker collects statistics about + * splitting. Note that the set of (tree, node) pairs is limited to the nodes selected + * from the queue for this iteration. The set of features considered can also be limited + * based on featureSubsetStrategy. + * - For each node, the statistics for that node are aggregated to a particular worker + * via reduceByKey(). The designated worker chooses the best (feature, split) pair, + * or chooses to stop splitting if the stopping criteria are met. + * On the master node: + * - The master collects all decisions about splitting nodes and updates the model. + * - The updated model is passed to the workers on the next iteration. + * This process continues until the node queue is empty. + * + * Most of the methods in this implementation support the statistics aggregation, which is + * the heaviest part of the computation. In general, this implementation is bound by either + * the cost of statistics computation on workers or by communicating the sufficient statistics. + */ +private[spark] object OptimizedRandomForest extends Logging { + + /** + * Train a random forest. + * + * @param input Training data: RDD of `LabeledPoint` + * @return an unweighted set of trees + */ + def run( + input: RDD[LabeledPoint], + strategy: OldStrategy, + numTrees: Int, + featureSubsetStrategy: String, + seed: Long, + instr: Option[Instrumentation], + prune: Boolean = true, + parentUID: Option[String] = None, + computeStatistics: Boolean = false) + : (Array[OptimizedDecisionTreeModel], Option[TrainingStatistics]) = { + + val timer = new TimeTracker() + + timer.start("total") + + timer.start("init") + + val retaggedInput = input.retag(classOf[LabeledPoint]) + val metadata = + DecisionTreeMetadata.buildMetadata(retaggedInput, strategy, numTrees, featureSubsetStrategy) + instr match { + case Some(instrumentation) => + instrumentation.logNumFeatures(metadata.numFeatures) + instrumentation.logNumClasses(metadata.numClasses) + case None => + logInfo("numFeatures: " + metadata.numFeatures) + logInfo("numClasses: " + metadata.numClasses) + } + + val timePredictionStrategy = strategy.getTimePredictionStrategy + val localTrainingAlgorithm: LocalTrainingAlgorithm = strategy.getLocalTrainingAlgorithm + + // Find the splits and the corresponding bins (interval between the splits) using a sample + // of the input data. + timer.start("findSplits") + + val splits = strategy.customSplits.map(splits => { + if(metadata.numFeatures != splits.length) { + throw new IllegalArgumentException("strategy.customSplits have wrong size: metadata.numFeatures= " + + s"${metadata.numFeatures} while customSplits.length= ${splits.length}") + } + splits.zipWithIndex.map { case (feature, idx) => + // Set metadata: + metadata.setNumSplits(idx, feature.length) + // Convert Array[Array[Double]] into Array[Array[Split]] + feature.map(threshold => new ContinuousSplit(idx, threshold).asInstanceOf[Split]) + } + }).getOrElse(findSplits(retaggedInput, metadata, seed)) + + timer.stop("findSplits") + logDebug("numBins: feature: number of bins") + logDebug(Range(0, metadata.numFeatures).map { featureIndex => + s"\t$featureIndex\t${metadata.numBins(featureIndex)}" + }.mkString("\n")) + + // Bin feature values (TreePoint representation). + // Cache input RDD for speedup during multiple passes. + val treeInput = TreePoint.convertToTreeRDD(retaggedInput, splits, metadata) + + val withReplacement = numTrees > 1 + + val baggedInput = BaggedPoint + .convertToBaggedRDD(treeInput, strategy.subsamplingRate, numTrees, withReplacement, seed) + .persist(StorageLevel.MEMORY_AND_DISK) + + val distributedMaxDepth = Math.min(strategy.maxDepth, 30) + + // Max memory usage for aggregates + // TODO: Calculate memory usage more precisely. + val maxMemoryUsage: Long = strategy.maxMemoryInMB * 1024L * 1024L + logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.") + + /* + * The main idea here is to perform group-wise training of the decision tree nodes thus + * reducing the passes over the data from (# nodes) to (# nodes / maxNumberOfNodesPerGroup). + * Each data sample is handled by a particular node (or it reaches a leaf and is not used + * in lower levels). + */ + + // Create an RDD of node Id cache. + // At first, all the rows belong to the root nodes (node Id == 1). + val nodeIdCache = if (strategy.useNodeIdCache) { + Some(NodeIdCache.init( + data = baggedInput, + numTrees = numTrees, + checkpointInterval = strategy.checkpointInterval, + initVal = 1)) + } else { + None + } + + /* + * Stack of nodes to train: (treeIndex, node) + * The reason this is a stack is that we train many trees at once, but we want to focus on + * completing trees, rather than training all simultaneously. If we are splitting nodes from + * 1 tree, then the new nodes to split will be put at the top of this stack, so we will continue + * training the same tree in the next iteration. This focus allows us to send fewer trees to + * workers on each iteration; see topNodesForGroup below. + */ + val nodeStack = new mutable.ArrayStack[(Int, OptimizedLearningNode)] + val localTrainingStack = new mutable.ListBuffer[LocalTrainingTask] + + val rng = new Random() + rng.setSeed(seed) + + // Allocate and queue root nodes. + val topNodes = + Array.fill[OptimizedLearningNode](numTrees)(OptimizedLearningNode.emptyNode(nodeIndex = 1)) + Range(0, numTrees).foreach(treeIndex => nodeStack.push((treeIndex, topNodes(treeIndex)))) + + timer.stop("init") + + timer.start("distributedTraining") + + // Calculate the local training threshold (the number of data points which fit onto a single executor core). + // Attempts to determine this value dynamically from the cluster setup. + + val numCores = baggedInput.context.getConf.getInt("spark.executor.cores", 1) + + val maxExecutorMemory = Try( + baggedInput.sparkContext.getExecutorMemoryStatus.head match { + case (executorId, (usedMemory, maxMemory)) => maxMemory / numCores + } + ).getOrElse(maxMemoryUsage) + + val nodeMemUsage = OptimizedRandomForest.aggregateSizeForNode(metadata, None) * 8L + val featuresMem = (metadata.numFeatures + metadata.numTrees + 1) * 8L + + val localTrainingThreshold = + ((maxExecutorMemory - nodeMemUsage) / (strategy.maxMemoryMultiplier * featuresMem)).toInt + + val trainingLimits = TrainingLimits(localTrainingThreshold, distributedMaxDepth) + + while (nodeStack.nonEmpty) { + // Collect some nodes to split, and choose features for each node (if subsampling). + // Each group of nodes may come from one or multiple trees, and at multiple levels. + val (nodesForGroup, treeToNodeToIndexInfo) = + OptimizedRandomForest.selectNodesToSplit(nodeStack, maxMemoryUsage, metadata, rng) + // Sanity check (should never occur): + assert(nodesForGroup.nonEmpty, + s"OptimizedRandomForest selected empty nodesForGroup. Error for unknown reason.") + + // Only send trees to worker if they contain nodes being split this iteration. + val topNodesForGroup: Map[Int, OptimizedLearningNode] = + nodesForGroup.keys.map(treeIdx => treeIdx -> topNodes(treeIdx)).toMap + + // Choose node splits, and enqueue new nodes as needed. + timer.start("findBestSplits") + OptimizedRandomForest.findBestSplits(baggedInput, metadata, topNodesForGroup, nodesForGroup, + treeToNodeToIndexInfo, splits, + (nodeStack, localTrainingStack), + trainingLimits, + timer, nodeIdCache) + timer.stop("findBestSplits") + } + timer.stop("distributedTraining") + + timer.start("localTraining") + + val nodeStats = mutable.ListBuffer.empty[NodeStatistics] + val numExecutors = Math.max(baggedInput.context.getExecutorMemoryStatus.size - 1, 1) + + val numPartitions = numExecutors * numCores + + /** + * Pack smaller nodes together using first-fit decreasing bin-packing and then sort + * the resulting bins by their predicted duration (implicitly in decreasing order). + * + * @return List[LocalTrainingBin] sorted in decreasing order + */ + def scheduleLocalTrainingTasks: Seq[LocalTrainingBin] = { + val trainingPlan = new LocalTrainingPlan(localTrainingThreshold, + timePredictionStrategy, + strategy.maxTasksPerBin) + + localTrainingStack.sorted.foreach(task => trainingPlan.scheduleTask(task)) + trainingPlan.bins.sorted.toList + } + + /** + * Group all nodes in the current batch of local training tasks by tree. + * + * @param batch + * @return (treeId, nodeList) tuples for all trees in the batch + */ + def getNodesForTrees(batch: Seq[LocalTrainingBin]): Map[Int, Seq[Int]] = { + batch.flatMap(bin => { + bin.tasks.map(task => (task.treeIndex, task.node.id)) + }).groupBy { + case (treeId, _) => treeId + }.map { + case (treeId, nodes) => (treeId, nodes.map { case (_, nodeId) => nodeId }) + } + } + + /** + * Determine which node subset the input BaggedPoint belongs to for every tree, either + * using NodeIdCache or by evaluating it in the current model. + * + * @return (baggedPoint, nodeIdArray) + */ + def getNodeIdsForPoints: RDD[(BaggedPoint[TreePoint], Array[Int])] = { + if (nodeIdCache.nonEmpty) { + baggedInput.zip(nodeIdCache.get.nodeIdsForInstances) + } else { + baggedInput.map(point => + (point, + Range(0, numTrees) + .map(treeId => topNodes(treeId).predictImpl(point.datum.binnedFeatures, splits)) + .toArray) + ) + } + } + + /** + * Filter the points used in nodes in the current batch and duplicate them if they are + * used in multiple trees. + * + * @return RDD((treeId, nodeId), (treePoint, sampleWeight)) + */ + def filterDataInBatch(batch: Seq[LocalTrainingBin], + pointsWithNodeIds: RDD[(BaggedPoint[TreePoint], Array[Int])]) = { + val nodeSets: Map[Int, Seq[Int]] = getNodesForTrees(batch) + val nodeSetsBc = baggedInput.sparkContext.broadcast(nodeSets) + + pointsWithNodeIds.flatMap { + case (baggedPoint, nodeIdsForTree) => + nodeSetsBc.value.keys + .filter(treeId => baggedPoint.subsampleWeights(treeId) > 0) + .map(treeId => (treeId, nodeIdsForTree(treeId))) + .filter { case (treeId, nodeId) => nodeSetsBc.value(treeId).contains(nodeId) } + .map { case (treeId, nodeId) => + ((treeId, nodeId), (baggedPoint.datum, baggedPoint.subsampleWeights(treeId))) + } + } + } + + /** + * Partition the data so that each bin is processed on one executor. + * + * @return partitioned data + */ + def partitionByBin(batch: Seq[LocalTrainingBin], + filtered: RDD[((Int, Int), (TreePoint, Double))]) = { + val treeNodeMapping = batch.zipWithIndex.flatMap { + case (bin, partitionIndex) => + bin.tasks.map(task => ((task.treeIndex, task.node.id), partitionIndex)) + }.toMap + + filtered.partitionBy(new NodeIdPartitioner(batch.length, treeNodeMapping)) + } + + /** + * In each partition, group points that belong to the same node and train the nodes + * using a local training algorithm. + * + * @return + */ + def runLocalTraining(partitioned: RDD[((Int, Int), (TreePoint, Double))]) = { + partitioned + .mapPartitions(partition => { + partition.toSeq + .groupBy { case (nodeIds, _) => nodeIds } + .values + .map(pointsWithIndices => + (pointsWithIndices.head._1, pointsWithIndices.map { case (_, point) => point })) + .map { case ((treeIndex, nodeIndex), points) => + trainNodeLocally(treeIndex, nodeIndex, points) + }.toIterator + }).collect() + } + + /** + * Run local training and collect statistics about the training duration and data size. + * @return + */ + def trainNodeLocally(treeIndex: Int, nodeIndex: Int, points: Seq[(TreePoint, Double)]) = { + val startTime = System.nanoTime() + val pointArray = points.map(_._1).toArray + val instanceWeights = points.map(_._2).toArray + val node = OptimizedLearningNode.emptyNode(nodeIndex) + + val currentLevel = LearningNode.indexToLevel(nodeIndex) + val localMaxDepth = metadata.maxDepth - currentLevel + + val tree = localTrainingAlgorithm.fitNode(pointArray, instanceWeights, node, + metadata, splits, Some(localMaxDepth), prune) + + val time = (System.nanoTime() - startTime) / 1e9 + (time, tree, points.length, treeIndex, nodeIndex) + } + + /** + * Update the main model on driver with a locally trained subtree. + */ + def updateModelWithSubtree(learningNode: OptimizedLearningNode, treeIndex: Int): Unit = { + val parent = OptimizedLearningNode.getNode( + OptimizedLearningNode.parentIndex(learningNode.id), topNodes(treeIndex)) + if (OptimizedLearningNode.isLeftChild(learningNode.id)) { + parent.leftChild = Some(learningNode) + } else { + parent.rightChild = Some(learningNode) + } + } + + + timer.start("localTrainingScheduling") + val trainingPlan = scheduleLocalTrainingTasks + timer.stop("localTrainingScheduling") + + val pointsWithNodeIds = getNodeIdsForPoints.cache() + + val finishedNodeMap = trainingPlan.grouped(numPartitions).flatMap(batch => { + timer.start("localTrainingFitting") + + val filtered = filterDataInBatch(batch, pointsWithNodeIds) + val partitioned = partitionByBin(batch, filtered) + val finished = runLocalTraining(partitioned) + + val nodesMap = finished.map { case (time, node, rows, treeIndex, nodeIndex) => + if (computeStatistics) { + nodeStats += NodeStatistics(nodeIndex, rows, node.impurity, time) + } + + ((treeIndex, nodeIndex), node) + } + + timer.stop("localTrainingFitting") + nodesMap + }).toMap + + timer.stop("localTraining") + + baggedInput.unpersist() + + timer.stop("total") + + + logInfo("Internal timing for DecisionTree:") + logInfo(s"$timer") + + // Delete any remaining checkpoints used for node Id cache. + if (nodeIdCache.nonEmpty) { + try { + nodeIdCache.get.deleteAllCheckpoints() + } catch { + case e: IOException => + logWarning(s"delete all checkpoints failed. Error reason: ${e.getMessage}") + } + } + + val numFeatures = metadata.numFeatures + + val model: Array[OptimizedDecisionTreeModel] = parentUID match { + case Some(uid) => + if (strategy.algo == OldAlgo.Classification) { + topNodes.zipWithIndex.map { case (rootNode, treeIndex) => + new OptimizedDecisionTreeClassificationModel(uid, + rootNode.toNodeWithLocalNodesMap(finishedNodeMap, treeIndex, prune), + numFeatures, strategy.getNumClasses) + } + } else { + topNodes.zipWithIndex.map { case (rootNode, treeIndex) => + new OptimizedDecisionTreeRegressionModel(uid, + rootNode.toNodeWithLocalNodesMap(finishedNodeMap, treeIndex, prune), + numFeatures) + } + } + case None => + if (strategy.algo == OldAlgo.Classification) { + topNodes.zipWithIndex.map { case (rootNode, treeIndex) => + new OptimizedDecisionTreeClassificationModel(rootNode.toNodeWithLocalNodesMap(finishedNodeMap, treeIndex, prune), + numFeatures, strategy.getNumClasses) + } + } else { + topNodes.zipWithIndex.map { case (rootNode, treeIndex) => + new OptimizedDecisionTreeRegressionModel(rootNode.toNodeWithLocalNodesMap(finishedNodeMap, treeIndex, prune), + numFeatures) + } + } + } + + if (computeStatistics) { + return (model, Some(TrainingStatistics(timer, nodeStats.toList))) + } + + (model, None) + } + + /** + * Helper for binSeqOp, for data which can contain a mix of ordered and unordered features. + * + * For ordered features, a single bin is updated. + * For unordered features, bins correspond to subsets of categories; either the left or right bin + * for each subset is updated. + * + * @param agg Array storing aggregate calculation, with a set of sufficient statistics for + * each (feature, bin). + * @param treePoint Data point being aggregated. + * @param splits possible splits indexed (numFeatures)(numSplits) + * @param unorderedFeatures Set of indices of unordered features. + * @param instanceWeight Weight (importance) of instance in dataset. + */ + private def mixedBinSeqOp( + agg: DTStatsAggregator, + treePoint: TreePoint, + splits: Array[Array[Split]], + unorderedFeatures: Set[Int], + instanceWeight: Double, + featuresForNode: Option[Array[Int]]): Unit = { + val numFeaturesPerNode = if (featuresForNode.nonEmpty) { + // Use subsampled features + featuresForNode.get.length + } else { + // Use all features + agg.metadata.numFeatures + } + // Iterate over features. + var featureIndexIdx = 0 + while (featureIndexIdx < numFeaturesPerNode) { + val featureIndex = if (featuresForNode.nonEmpty) { + featuresForNode.get.apply(featureIndexIdx) + } else { + featureIndexIdx + } + if (unorderedFeatures.contains(featureIndex)) { + AggUpdateUtils.updateUnorderedFeature(agg, + featureValue = treePoint.binnedFeatures(featureIndex), label = treePoint.label, + featureIndex = featureIndex, featureIndexIdx = featureIndexIdx, + featureSplits = splits(featureIndex), instanceWeight = instanceWeight) + } else { + AggUpdateUtils.updateOrderedFeature(agg, + featureValue = treePoint.binnedFeatures(featureIndex), label = treePoint.label, + featureIndexIdx = featureIndexIdx, instanceWeight = instanceWeight) + } + featureIndexIdx += 1 + } + } + + /** + * Helper for binSeqOp, for regression and for classification with only ordered features. + * + * For each feature, the sufficient statistics of one bin are updated. + * + * @param agg Array storing aggregate calculation, with a set of sufficient statistics for + * each (feature, bin). + * @param treePoint Data point being aggregated. + * @param instanceWeight Weight (importance) of instance in dataset. + */ + private def orderedBinSeqOp( + agg: DTStatsAggregator, + treePoint: TreePoint, + instanceWeight: Double, + featuresForNode: Option[Array[Int]]): Unit = { + val label = treePoint.label + + // Iterate over features. + if (featuresForNode.nonEmpty) { + // Use subsampled features + var featureIndexIdx = 0 + while (featureIndexIdx < featuresForNode.get.length) { + val binIndex = treePoint.binnedFeatures(featuresForNode.get.apply(featureIndexIdx)) + agg.update(featureIndexIdx, binIndex, label, instanceWeight) + featureIndexIdx += 1 + } + } else { + // Use all features + val numFeatures = agg.metadata.numFeatures + var featureIndex = 0 + while (featureIndex < numFeatures) { + val binIndex = treePoint.binnedFeatures(featureIndex) + agg.update(featureIndex, binIndex, label, instanceWeight) + featureIndex += 1 + } + } + } + + /** + * Given a group of nodes, this finds the best split for each node. + * + * @param input Training data: RDD of [[TreePoint]] + * @param metadata Learning and dataset metadata + * @param topNodesForGroup For each tree in group, tree index -> root node. + * Used for matching instances with nodes. + * @param nodesForGroup Mapping: treeIndex --> nodes to be split in tree + * @param treeToNodeToIndexInfo Mapping: treeIndex --> nodeIndex --> nodeIndexInfo, + * where nodeIndexInfo stores the index in the group and the + * feature subsets (if using feature subsets). + * @param splits possible splits for all features, indexed (numFeatures)(numSplits) + * @param stacks Queue of nodes to split, with values (treeIndex, node). + * Updated with new non-leaf nodes which are created. + * @param nodeIdCache Node Id cache containing an RDD of Array[Int] where + * each value in the array is the data point's node Id + * for a corresponding tree. This is used to prevent the need + * to pass the entire tree to the executors during + * the node stat aggregation phase. + */ + private[tree] def findBestSplits( + input: RDD[BaggedPoint[TreePoint]], + metadata: DecisionTreeMetadata, + topNodesForGroup: Map[Int, OptimizedLearningNode], + nodesForGroup: Map[Int, Array[OptimizedLearningNode]], + treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]], + splits: Array[Array[Split]], + stacks: (mutable.ArrayStack[(Int, OptimizedLearningNode)], + mutable.ListBuffer[LocalTrainingTask]), + limits: TrainingLimits, + timer: TimeTracker = new TimeTracker, + nodeIdCache: Option[NodeIdCache] = None): Unit = { + + /* + * The high-level descriptions of the best split optimizations are noted here. + * + * *Group-wise training* + * We perform bin calculations for groups of nodes to reduce the number of + * passes over the data. Each iteration requires more computation and storage, + * but saves several iterations over the data. + * + * *Bin-wise computation* + * We use a bin-wise best split computation strategy instead of a straightforward best split + * computation strategy. Instead of analyzing each sample for contribution to the left/right + * child node impurity of every split, we first categorize each feature of a sample into a + * bin. We exploit this structure to calculate aggregates for bins and then use these aggregates + * to calculate information gain for each split. + * + * *Aggregation over partitions* + * Instead of performing a flatMap/reduceByKey operation, we exploit the fact that we know + * the number of splits in advance. Thus, we store the aggregates (at the appropriate + * indices) in a single array for all bins and rely upon the RDD aggregate method to + * drastically reduce the communication overhead. + */ + + // numNodes: Number of nodes in this group + val numNodes = nodesForGroup.values.map(_.length).sum + logDebug("numNodes = " + numNodes) + logDebug("numFeatures = " + metadata.numFeatures) + logDebug("numClasses = " + metadata.numClasses) + logDebug("isMulticlass = " + metadata.isMulticlass) + logDebug("isMulticlassWithCategoricalFeatures = " + + metadata.isMulticlassWithCategoricalFeatures) + logDebug("using nodeIdCache = " + nodeIdCache.nonEmpty.toString) + + val (nodeStack, localTrainingStack) = stacks + + /** + * Performs a sequential aggregation over a partition for a particular tree and node. + * + * For each feature, the aggregate sufficient statistics are updated for the relevant + * bins. + * + * @param treeIndex Index of the tree that we want to perform aggregation for. + * @param nodeInfo The node info for the tree node. + * @param agg Array storing aggregate calculation, with a set of sufficient statistics + * for each (node, feature, bin). + * @param baggedPoint Data point being aggregated. + */ + def nodeBinSeqOp( + treeIndex: Int, + nodeInfo: NodeIndexInfo, + agg: Array[DTStatsAggregator], + baggedPoint: BaggedPoint[TreePoint]): Unit = { + if (nodeInfo != null) { + val aggNodeIndex = nodeInfo.nodeIndexInGroup + val featuresForNode = nodeInfo.featureSubset + val instanceWeight = baggedPoint.subsampleWeights(treeIndex) + if (metadata.unorderedFeatures.isEmpty) { + orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, instanceWeight, featuresForNode) + } else { + mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, splits, + metadata.unorderedFeatures, instanceWeight, featuresForNode) + } + agg(aggNodeIndex).updateParent(baggedPoint.datum.label, instanceWeight) + } + } + + /** + * Performs a sequential aggregation over a partition. + * + * Each data point contributes to one node. For each feature, + * the aggregate sufficient statistics are updated for the relevant bins. + * + * @param agg Array storing aggregate calculation, with a set of sufficient statistics for + * each (node, feature, bin). + * @param baggedPoint Data point being aggregated. + * @return agg + */ + def binSeqOp( + agg: Array[DTStatsAggregator], + baggedPoint: BaggedPoint[TreePoint]): Array[DTStatsAggregator] = { + treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) => + val nodeIndex = + topNodesForGroup(treeIndex).predictImpl(baggedPoint.datum.binnedFeatures, splits) + nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), agg, baggedPoint) + } + agg + } + + /** + * Do the same thing as binSeqOp, but with nodeIdCache. + */ + def binSeqOpWithNodeIdCache( + agg: Array[DTStatsAggregator], + dataPoint: (BaggedPoint[TreePoint], Array[Int])): Array[DTStatsAggregator] = { + treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) => + val baggedPoint = dataPoint._1 + val nodeIdCache = dataPoint._2 + val nodeIndex = nodeIdCache(treeIndex) + nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), agg, baggedPoint) + } + + agg + } + + /** + * Get node index in group --> features indices map, + * which is a short cut to find feature indices for a node given node index in group. + */ + def getNodeToFeatures( + treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]]): Option[Map[Int, Array[Int]]] = { + if (!metadata.subsamplingFeatures) { + None + } else { + val mutableNodeToFeatures = new mutable.HashMap[Int, Array[Int]]() + treeToNodeToIndexInfo.values.foreach { nodeIdToNodeInfo => + nodeIdToNodeInfo.values.foreach { nodeIndexInfo => + assert(nodeIndexInfo.featureSubset.isDefined) + mutableNodeToFeatures(nodeIndexInfo.nodeIndexInGroup) = nodeIndexInfo.featureSubset.get + } + } + Some(mutableNodeToFeatures.toMap) + } + } + + def addTrainingTask(node: OptimizedLearningNode, + treeIndex: Int, + rows: Long, + nodeLevel: Int, + impurity: Double) = { + if (rows < limits.localTrainingThreshold) { + val task = new LocalTrainingTask(node, treeIndex, rows, impurity) + localTrainingStack += task + } else { + nodeStack.push((treeIndex, node)) + } + } + + // array of nodes to train indexed by node index in group + val nodes = new Array[OptimizedLearningNode](numNodes) + nodesForGroup.foreach { case (treeIndex, nodesForTree) => + nodesForTree.foreach { node => + nodes(treeToNodeToIndexInfo(treeIndex)(node.id).nodeIndexInGroup) = node + } + } + + // Calculate best splits for all nodes in the group + timer.start("chooseSplits") + + // In each partition, iterate all instances and compute aggregate stats for each node, + // yield a (nodeIndex, nodeAggregateStats) pair for each node. + // After a `reduceByKey` operation, + // stats of a node will be shuffled to a particular partition and be combined together, + // then best splits for nodes are found there. + // Finally, only best Splits for nodes are collected to driver to construct decision tree. + val nodeToFeatures = getNodeToFeatures(treeToNodeToIndexInfo) + val nodeToFeaturesBc = input.sparkContext.broadcast(nodeToFeatures) + + val partitionAggregates: RDD[(Int, DTStatsAggregator)] = if (nodeIdCache.nonEmpty) { + input.zip(nodeIdCache.get.nodeIdsForInstances).mapPartitions { points => + // Construct a nodeStatsAggregators array to hold node aggregate stats, + // each node will have a nodeStatsAggregator + val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex => + val featuresForNode = nodeToFeaturesBc.value.map { nodeToFeatures => + nodeToFeatures(nodeIndex) + } + new DTStatsAggregator(metadata, featuresForNode) + } + + // iterator all instances in current partition and update aggregate stats + points.foreach(binSeqOpWithNodeIdCache(nodeStatsAggregators, _)) + + // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs, + // which can be combined with other partition using `reduceByKey` + nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator + } + } else { + input.mapPartitions { points => + // Construct a nodeStatsAggregators array to hold node aggregate stats, + // each node will have a nodeStatsAggregator + val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex => + val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures => + Some(nodeToFeatures(nodeIndex)) + } + new DTStatsAggregator(metadata, featuresForNode) + } + + // iterator all instances in current partition and update aggregate stats + points.foreach(binSeqOp(nodeStatsAggregators, _)) + + // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs, + // which can be combined with other partition using `reduceByKey` + nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator + } + } + + // Aggregate sufficient stats by node, then find best splits + val nodeToBestSplits = partitionAggregates.reduceByKey((a, b) => a.merge(b)).map { + case (nodeIndex, aggStats) => + val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures => + Some(nodeToFeatures(nodeIndex)) + } + + // find best split for each node + val (split: Split, stats: ImpurityStats) = + OptimizedRandomForest.binsToBestSplit(aggStats, splits, featuresForNode, nodes(nodeIndex)) + (nodeIndex, (split, stats)) + }.collectAsMap() + + timer.stop("chooseSplits") + + // Perform splits + val nodeIdUpdaters = if (nodeIdCache.nonEmpty) { + Array.fill[mutable.Map[Int, NodeIndexUpdater]]( + metadata.numTrees)(mutable.Map[Int, NodeIndexUpdater]()) + } else { + null + } + // Iterate over all nodes in this group. + nodesForGroup.foreach { case (treeIndex, nodesForTree) => + nodesForTree.foreach { node => + val nodeIndex = node.id + val nodeLevel = LearningNode.indexToLevel(nodeIndex) + val nodeInfo = treeToNodeToIndexInfo(treeIndex)(nodeIndex) + val aggNodeIndex = nodeInfo.nodeIndexInGroup + val (split: Split, stats: ImpurityStats) = + nodeToBestSplits(aggNodeIndex) + logDebug("best split = " + split) + + // Extract info for this node. Create children if not leaf. + val isLeaf = + (stats.gain <= 0) || (nodeLevel == limits.distributedMaxDepth) + node.isLeaf = isLeaf + node.stats = stats + logDebug("Node = " + node) + + if (!isLeaf) { + node.split = Some(split) + val childIsLeaf = (nodeLevel + 1) == limits.distributedMaxDepth + val leftChildIsLeaf = childIsLeaf || (stats.leftImpurity == 0.0) + val rightChildIsLeaf = childIsLeaf || (stats.rightImpurity == 0.0) + node.leftChild = Some(OptimizedLearningNode( + OptimizedLearningNode.leftChildIndex(nodeIndex), + leftChildIsLeaf, ImpurityStats.getEmptyImpurityStats(stats.leftImpurityCalculator))) + node.rightChild = Some(OptimizedLearningNode( + LearningNode.rightChildIndex(nodeIndex), + rightChildIsLeaf, ImpurityStats.getEmptyImpurityStats(stats.rightImpurityCalculator))) + + if (nodeIdCache.nonEmpty) { + val nodeIndexUpdater = NodeIndexUpdater( + split = split, + nodeIndex = nodeIndex) + nodeIdUpdaters(treeIndex).put(nodeIndex, nodeIndexUpdater) + } + + // enqueue left child and right child if they are not leaves + if (!leftChildIsLeaf) { + addTrainingTask(node.leftChild.get, treeIndex, stats.leftImpurityCalculator.count, + nodeLevel, stats.leftImpurity) + } + if (!rightChildIsLeaf) { + addTrainingTask(node.rightChild.get, treeIndex, stats.rightImpurityCalculator.count, + nodeLevel, stats.rightImpurity) + } + + logDebug("leftChildIndex = " + node.leftChild.get.id + + ", impurity = " + stats.leftImpurity) + logDebug("rightChildIndex = " + node.rightChild.get.id + + ", impurity = " + stats.rightImpurity) + } + } + } + + if (nodeIdCache.nonEmpty) { + // Update the cache if needed. + nodeIdCache.get.updateNodeIndices(input, nodeIdUpdaters, splits) + } + } + + + /** + * Return a list of pairs (featureIndexIdx, featureIndex) where featureIndex is the global + * (across all trees) index of a feature and featureIndexIdx is the index of a feature within the + * list of features for a given node. Filters out features known to be constant + * (features with 0 splits) + */ + private[impl] def getFeaturesWithSplits( + metadata: DecisionTreeMetadata, + featuresForNode: Option[Array[Int]]): SeqView[(Int, Int), Seq[_]] = { + Range(0, metadata.numFeaturesPerNode).view.map { featureIndexIdx => + featuresForNode.map(features => (featureIndexIdx, features(featureIndexIdx))) + .getOrElse((featureIndexIdx, featureIndexIdx)) + }.withFilter { case (_, featureIndex) => + metadata.numSplits(featureIndex) != 0 + } + } + + private[impl] def getBestSplitByGain( + parentImpurityCalculator: ImpurityCalculator, + metadata: DecisionTreeMetadata, + featuresForNode: Option[Array[Int]], + splitsAndImpurityInfo: Seq[(Split, ImpurityStats)]): (Split, ImpurityStats) = { + val (bestSplit, bestSplitStats) = + if (splitsAndImpurityInfo.isEmpty) { + // If no valid splits for features, then this split is invalid, + // return invalid information gain stats. Take any split and continue. + // Splits is empty, so arbitrarily choose to split on any threshold + val dummyFeatureIndex = featuresForNode.map(_.head).getOrElse(0) + if (metadata.isContinuous(dummyFeatureIndex)) { + (new ContinuousSplit(dummyFeatureIndex, 0), + ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator)) + } else { + val numCategories = metadata.featureArity(dummyFeatureIndex) + (new CategoricalSplit(dummyFeatureIndex, Array(), numCategories), + ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator)) + } + } else { + splitsAndImpurityInfo.maxBy(_._2.gain) + } + (bestSplit, bestSplitStats) + } + + /** + * Find the best split for a node. + * + * @param binAggregates Bin statistics. + * @return tuple for best split: (Split, information gain, prediction at node) + */ + private[tree] def binsToBestSplit( + binAggregates: DTStatsAggregator, + splits: Array[Array[Split]], + featuresForNode: Option[Array[Int]], + node: OptimizedLearningNode): (Split, ImpurityStats) = { + val validFeatureSplits = getFeaturesWithSplits(binAggregates.metadata, featuresForNode) + // For each (feature, split), calculate the gain, and select the best (feature, split). + val parentImpurityCalc = if (node.stats == null) None else Some(node.stats.impurityCalculator) + val splitsAndImpurityInfo = + validFeatureSplits.map { case (featureIndexIdx, featureIndex) => + SplitUtils.chooseSplit(binAggregates, featureIndex, featureIndexIdx, splits(featureIndex), + parentImpurityCalc) + } + getBestSplitByGain(binAggregates.getParentImpurityCalculator(), binAggregates.metadata, + featuresForNode, splitsAndImpurityInfo) + } + + private[impl] def findUnorderedSplits( + metadata: DecisionTreeMetadata, + featureIndex: Int): Array[Split] = { + // Unordered features + // 2^(maxFeatureValue - 1) - 1 combinations + val featureArity = metadata.featureArity(featureIndex) + Array.tabulate[Split](metadata.numSplits(featureIndex)) { splitIndex => + val categories = extractMultiClassCategories(splitIndex + 1, featureArity) + new CategoricalSplit(featureIndex, categories.toArray, featureArity) + } + } + + /** + * Returns splits for decision tree calculation. + * Continuous and categorical features are handled differently. + * + * Continuous features: + * For each feature, there are numBins - 1 possible splits representing the possible binary + * decisions at each node in the tree. + * This finds locations (feature values) for splits using a subsample of the data. + * + * Categorical features: + * For each feature, there is 1 bin per split. + * Splits and bins are handled in 2 ways: + * (a) "unordered features" + * For multiclass classification with a low-arity feature + * (i.e., if isMulticlass && isSpaceSufficientForAllCategoricalSplits), + * the feature is split based on subsets of categories. + * (b) "ordered features" + * For regression and binary classification, + * and for multiclass classification with a high-arity feature, + * there is one bin per category. + * + * @param input Training data: RDD of [[LabeledPoint]] + * @param metadata Learning and dataset metadata + * @param seed random seed + * @return Splits, an Array of [[Split]] + * of size (numFeatures, numSplits) + */ + protected[tree] def findSplits( + input: RDD[LabeledPoint], + metadata: DecisionTreeMetadata, + seed: Long): Array[Array[Split]] = { + + logDebug("isMulticlass = " + metadata.isMulticlass) + + val numFeatures = metadata.numFeatures + + // Sample the input only if there are continuous features. + val continuousFeatures = Range(0, numFeatures).filter(metadata.isContinuous) + val sampledInput = if (continuousFeatures.nonEmpty) { + // Calculate the number of samples for approximate quantile calculation. + val requiredSamples = math.max(metadata.maxBins * metadata.maxBins, 10000) + val fraction = if (requiredSamples < metadata.numExamples) { + requiredSamples.toDouble / metadata.numExamples + } else { + 1.0 + } + logDebug("fraction of data used for calculating quantiles = " + fraction) + input.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt()) + } else { + input.sparkContext.emptyRDD[LabeledPoint] + } + + findSplitsBySorting(sampledInput, metadata, continuousFeatures) + } + + private def findSplitsBySorting( + input: RDD[LabeledPoint], + metadata: DecisionTreeMetadata, + continuousFeatures: IndexedSeq[Int]): Array[Array[Split]] = { + + val continuousSplits: scala.collection.Map[Int, Array[Split]] = { + // reduce the parallelism for split computations when there are less + // continuous features than input partitions. this prevents tasks from + // being spun up that will definitely do no work. + val numPartitions = math.min(continuousFeatures.length, input.partitions.length) + + input + .flatMap(point => continuousFeatures.map(idx => (idx, point.features(idx)))) + .groupByKey(numPartitions) + .map { case (idx, samples) => + val thresholds = findSplitsForContinuousFeature(samples, metadata, idx) + val splits: Array[Split] = thresholds.map(thresh => new ContinuousSplit(idx, thresh)) + logDebug(s"featureIndex = $idx, numSplits = ${splits.length}") + (idx, splits) + }.collectAsMap() + } + + val numFeatures = metadata.numFeatures + val splits: Array[Array[Split]] = Array.tabulate(numFeatures) { + case i if metadata.isContinuous(i) => + val split = continuousSplits(i) + metadata.setNumSplits(i, split.length) + split + + case i if metadata.isCategorical(i) && metadata.isUnordered(i) => + findUnorderedSplits(metadata, i) + + case i if metadata.isCategorical(i) => + // Ordered features + // Splits are constructed as needed during training. + Array.empty[Split] + } + splits + } + + /** + * Nested method to extract list of eligible categories given an index. It extracts the + * position of ones in a binary representation of the input. If binary + * representation of an number is 01101 (13), the output list should (3.0, 2.0, + * 0.0). The maxFeatureValue depict the number of rightmost digits that will be tested for ones. + */ + private[tree] def extractMultiClassCategories( + input: Int, + maxFeatureValue: Int): List[Double] = { + var categories = List[Double]() + var j = 0 + var bitShiftedInput = input + while (j < maxFeatureValue) { + if (bitShiftedInput % 2 != 0) { + // updating the list of categories. + categories = j.toDouble :: categories + } + // Right shift by one + bitShiftedInput = bitShiftedInput >> 1 + j += 1 + } + categories + } + + /** + * Find splits for a continuous feature + * NOTE: Returned number of splits is set based on `featureSamples` and + * could be different from the specified `numSplits`. + * The `numSplits` attribute in the `DecisionTreeMetadata` class will be set accordingly. + * + * @param featureSamples feature values of each sample + * @param metadata decision tree metadata + * NOTE: `metadata.numbins` will be changed accordingly + * if there are not enough splits to be found + * @param featureIndex feature index to find splits + * @return array of split thresholds + */ + private[tree] def findSplitsForContinuousFeature( + featureSamples: Iterable[Double], + metadata: DecisionTreeMetadata, + featureIndex: Int): Array[Double] = { + require(metadata.isContinuous(featureIndex), + "findSplitsForContinuousFeature can only be used to find splits for a continuous feature.") + + val splits: Array[Double] = if (featureSamples.isEmpty) { + Array.empty[Double] + } else { + val numSplits = metadata.numSplits(featureIndex) + + // get count for each distinct value + val (valueCountMap, numSamples) = featureSamples.foldLeft((Map.empty[Double, Int], 0)) { + case ((m, cnt), x) => + (m + ((x, m.getOrElse(x, 0) + 1)), cnt + 1) + } + // sort distinct values + val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray + + val possibleSplits = valueCounts.length - 1 + if (possibleSplits == 0) { + // constant feature + Array.empty[Double] + } else if (possibleSplits <= numSplits) { + // if possible splits is not enough or just enough, just return all possible splits + (1 to possibleSplits) + .map(index => (valueCounts(index - 1)._1 + valueCounts(index)._1) / 2.0) + .toArray + } else { + // stride between splits + val stride: Double = numSamples.toDouble / (numSplits + 1) + logDebug("stride = " + stride) + + // iterate `valueCount` to find splits + val splitsBuilder = mutable.ArrayBuilder.make[Double] + var index = 1 + // currentCount: sum of counts of values that have been visited + var currentCount = valueCounts(0)._2 + // targetCount: target value for `currentCount`. + // If `currentCount` is closest value to `targetCount`, + // then current value is a split threshold. + // After finding a split threshold, `targetCount` is added by stride. + var targetCount = stride + while (index < valueCounts.length) { + val previousCount = currentCount + currentCount += valueCounts(index)._2 + val previousGap = math.abs(previousCount - targetCount) + val currentGap = math.abs(currentCount - targetCount) + // If adding count of current value to currentCount + // makes the gap between currentCount and targetCount smaller, + // previous value is a split threshold. + if (previousGap < currentGap) { + splitsBuilder += (valueCounts(index - 1)._1 + valueCounts(index)._1) / 2.0 + targetCount += stride + } + index += 1 + } + + splitsBuilder.result() + } + } + splits + } + + private[tree] class NodeIndexInfo( + val nodeIndexInGroup: Int, + val featureSubset: Option[Array[Int]]) extends Serializable + + /** + * Pull nodes off of the queue, and collect a group of nodes to be split on this iteration. + * This tracks the memory usage for aggregates and stops adding nodes when too much memory + * will be needed; this allows an adaptive number of nodes since different nodes may require + * different amounts of memory (if featureSubsetStrategy is not "all"). + * + * @param nodeStack Queue of nodes to split. + * @param maxMemoryUsage Bound on size of aggregate statistics. + * @return (nodesForGroup, treeToNodeToIndexInfo). + * nodesForGroup holds the nodes to split: treeIndex --> nodes in tree. + * + * treeToNodeToIndexInfo holds indices selected features for each node: + * treeIndex --> (global) node index --> (node index in group, feature indices). + * The (global) node index is the index in the tree; the node index in group is the + * index in [0, numNodesInGroup) of the node in this group. + * The feature indices are None if not subsampling features. + */ + private[tree] def selectNodesToSplit( + nodeStack: mutable.ArrayStack[(Int, OptimizedLearningNode)], + maxMemoryUsage: Long, + metadata: DecisionTreeMetadata, + rng: Random) + : (Map[Int, Array[OptimizedLearningNode]], Map[Int, Map[Int, NodeIndexInfo]]) = { + // Collect some nodes to split: + // nodesForGroup(treeIndex) = nodes to split + val mutableNodesForGroup = + new mutable.HashMap[Int, mutable.ArrayBuffer[OptimizedLearningNode]]() + val mutableTreeToNodeToIndexInfo = + new mutable.HashMap[Int, mutable.HashMap[Int, NodeIndexInfo]]() + var memUsage: Long = 0L + var numNodesInGroup = 0 + // If maxMemoryInMB is set very small, we want to still try to split 1 node, + // so we allow one iteration if memUsage == 0. + var groupDone = false + while (nodeStack.nonEmpty && !groupDone) { + val (treeIndex, node) = nodeStack.top + // Choose subset of features for node (if subsampling). + val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) { + Some(SamplingUtils.reservoirSampleAndCount(Range(0, + metadata.numFeatures).iterator, metadata.numFeaturesPerNode, rng.nextLong())._1) + } else { + None + } + // Check if enough memory remains to add this node to the group. + val nodeMemUsage = OptimizedRandomForest.aggregateSizeForNode(metadata, featureSubset) * 8L + if (memUsage + nodeMemUsage <= maxMemoryUsage || memUsage == 0) { + nodeStack.pop() + mutableNodesForGroup.getOrElseUpdate(treeIndex, + new mutable.ArrayBuffer[OptimizedLearningNode]()) += node + mutableTreeToNodeToIndexInfo + .getOrElseUpdate(treeIndex, new mutable.HashMap[Int, NodeIndexInfo]())(node.id) + = new NodeIndexInfo(numNodesInGroup, featureSubset) + numNodesInGroup += 1 + memUsage += nodeMemUsage + } else { + groupDone = true + } + } + if (memUsage > maxMemoryUsage) { + // If maxMemoryUsage is 0, we should still allow splitting 1 node. + logWarning(s"Tree learning is using approximately $memUsage bytes per iteration, which" + + s" exceeds requested limit maxMemoryUsage=$maxMemoryUsage. This allows splitting" + + s" $numNodesInGroup nodes in this iteration.") + } + // Convert mutable maps to immutable ones. + val nodesForGroup: Map[Int, Array[OptimizedLearningNode]] = + mutableNodesForGroup.mapValues(_.toArray).toMap + val treeToNodeToIndexInfo = mutableTreeToNodeToIndexInfo.mapValues(_.toMap).toMap + (nodesForGroup, treeToNodeToIndexInfo) + } + + /** + * Get the number of values to be stored for this node in the bin aggregates. + * + * @param featureSubset Indices of features which may be split at this node. + * If None, then use all features. + */ + private def aggregateSizeForNode( + metadata: DecisionTreeMetadata, + featureSubset: Option[Array[Int]]): Long = { + val totalBins = if (featureSubset.nonEmpty) { + featureSubset.get.map(featureIndex => metadata.numBins(featureIndex).toLong).sum + } else { + metadata.numBins.map(_.toLong).sum + } + if (metadata.isClassification) { + metadata.numClasses * totalBins + } else { + 3 * totalBins + } + } +} + +private class NodeIdPartitioner(override val numPartitions: Int, + val nodeIdPartitionMapping: Map[(Int, Int), Int]) + extends Partitioner { + + def getPartition(key: Any): Int = { + val k = key.asInstanceOf[(Int, Int)] + + // orElse part should never happen + nodeIdPartitionMapping.getOrElse(k, 1) + } +} + +private case class TrainingLimits(localTrainingThreshold: Int, + distributedMaxDepth: Int) diff --git a/src/main/scala/org/apache/spark/ml/tree/impl/SplitUtils.scala b/src/main/scala/org/apache/spark/ml/tree/impl/SplitUtils.scala new file mode 100755 index 0000000..206405a --- /dev/null +++ b/src/main/scala/org/apache/spark/ml/tree/impl/SplitUtils.scala @@ -0,0 +1,206 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree.impl + +import org.apache.spark.internal.Logging +import org.apache.spark.ml.tree.{CategoricalSplit, Split} +import org.apache.spark.mllib.tree.impurity.ImpurityCalculator +import org.apache.spark.mllib.tree.model.ImpurityStats + +/** Utility methods for choosing splits during local & distributed tree training. */ +private[impl] object SplitUtils extends Logging { + + /** Sorts ordered feature categories by label centroid, returning an ordered list of categories */ + private def sortByCentroid( + binAggregates: DTStatsAggregator, + featureIndex: Int, + featureIndexIdx: Int): List[Int] = { + /* Each bin is one category (feature value). + * The bins are ordered based on centroidForCategories, and this ordering determines which + * splits are considered. (With K categories, we consider K - 1 possible splits.) + * + * centroidForCategories is a list: (category, centroid) + */ + val numCategories = binAggregates.metadata.numBins(featureIndex) + val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx) + + val centroidForCategories = Range(0, numCategories).map { featureValue => + val categoryStats = + binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) + val centroid = ImpurityUtils.getCentroid(binAggregates.metadata, categoryStats) + (featureValue, centroid) + } + logDebug("Centroids for categorical variable: " + centroidForCategories.mkString(",")) + // bins sorted by centroids + val categoriesSortedByCentroid = centroidForCategories.toList.sortBy(_._2).map(_._1) + logDebug("Sorted centroids for categorical variable = " + + categoriesSortedByCentroid.mkString(",")) + categoriesSortedByCentroid + } + + /** + * Find the best split for an unordered categorical feature at a single node. + * + * Algorithm: + * - Considers all possible subsets (exponentially many) + * + * @param featureIndex Global index of feature being split. + * @param featureIndexIdx Index of feature being split within subset of features for current node. + * @param featureSplits Array of splits for the current feature + * @param parentCalculator Optional: ImpurityCalculator containing impurity stats for current node + * @return (best split, statistics for split) If no valid split was found, the returned + * ImpurityStats instance will be invalid (have member valid = false). + */ + private[impl] def chooseUnorderedCategoricalSplit( + binAggregates: DTStatsAggregator, + featureIndex: Int, + featureIndexIdx: Int, + featureSplits: Array[Split], + parentCalculator: Option[ImpurityCalculator] = None): (Split, ImpurityStats) = { + // Unordered categorical feature + val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx) + val numSplits = binAggregates.metadata.numSplits(featureIndex) + val parentCalc = parentCalculator.getOrElse(binAggregates.getParentImpurityCalculator()) + val (bestFeatureSplitIndex, bestFeatureGainStats) = + Range(0, numSplits).map { splitIndex => + val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIndex) + val rightChildStats = binAggregates.getParentImpurityCalculator() + .subtract(leftChildStats) + val gainAndImpurityStats = ImpurityUtils.calculateImpurityStats(parentCalc, + leftChildStats, rightChildStats, binAggregates.metadata) + (splitIndex, gainAndImpurityStats) + }.maxBy(_._2.gain) + (featureSplits(bestFeatureSplitIndex), bestFeatureGainStats) + + } + + /** + * Choose splitting rule: feature value <= threshold + * + * @return (best split, statistics for split) If the best split actually puts all instances + * in one leaf node, then it will be set to None. If no valid split was found, the + * returned ImpurityStats instance will be invalid (have member valid = false) + */ + private[impl] def chooseContinuousSplit( + binAggregates: DTStatsAggregator, + featureIndex: Int, + featureIndexIdx: Int, + featureSplits: Array[Split], + parentCalculator: Option[ImpurityCalculator] = None): (Split, ImpurityStats) = { + // For a continuous feature, bins are already sorted for splitting + // Number of "categories" = number of bins + val sortedCategories = Range(0, binAggregates.metadata.numBins(featureIndex)).toList + // Get & return best split info + val (bestFeatureSplitIndex, bestFeatureGainStats) = orderedSplitHelper(binAggregates, + featureIndex, featureIndexIdx, sortedCategories, parentCalculator) + (featureSplits(bestFeatureSplitIndex), bestFeatureGainStats) + } + + /** + * Computes the index of the best split for an ordered feature. + * @param parentCalculator Optional: ImpurityCalculator containing impurity stats for current node + */ + private def orderedSplitHelper( + binAggregates: DTStatsAggregator, + featureIndex: Int, + featureIndexIdx: Int, + categoriesSortedByCentroid: List[Int], + parentCalculator: Option[ImpurityCalculator]): (Int, ImpurityStats) = { + // Cumulative sum (scanLeft) of bin statistics. + // Afterwards, binAggregates for a bin is the sum of aggregates for + // that bin + all preceding bins. + val numSplits = binAggregates.metadata.numSplits(featureIndex) + val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx) + var splitIndex = 0 + while (splitIndex < numSplits) { + val currentCategory = categoriesSortedByCentroid(splitIndex) + val nextCategory = categoriesSortedByCentroid(splitIndex + 1) + binAggregates.mergeForFeature(nodeFeatureOffset, nextCategory, currentCategory) + splitIndex += 1 + } + // lastCategory = index of bin with total aggregates for this (node, feature) + val lastCategory = categoriesSortedByCentroid.last + + // Find best split. + val parentCalc = parentCalculator.getOrElse(binAggregates.getParentImpurityCalculator()) + Range(0, numSplits).map { splitIndex => + val featureValue = categoriesSortedByCentroid(splitIndex) + val leftChildStats = + binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) + val rightChildStats = + binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory) + rightChildStats.subtract(leftChildStats) + val gainAndImpurityStats = ImpurityUtils.calculateImpurityStats(parentCalc, + leftChildStats, rightChildStats, binAggregates.metadata) + (splitIndex, gainAndImpurityStats) + }.maxBy(_._2.gain) + } + + /** + * Choose the best split for an ordered categorical feature. + * @param parentCalculator Optional: ImpurityCalculator containing impurity stats for current node + */ + private[impl] def chooseOrderedCategoricalSplit( + binAggregates: DTStatsAggregator, + featureIndex: Int, + featureIndexIdx: Int, + parentCalculator: Option[ImpurityCalculator] = None): (Split, ImpurityStats) = { + // Sort feature categories by label centroid + val categoriesSortedByCentroid = sortByCentroid(binAggregates, featureIndex, featureIndexIdx) + // Get index, stats of best split + val (bestFeatureSplitIndex, bestFeatureGainStats) = orderedSplitHelper(binAggregates, + featureIndex, featureIndexIdx, categoriesSortedByCentroid, parentCalculator) + // Create result (CategoricalSplit instance) + val categoriesForSplit = + categoriesSortedByCentroid.map(_.toDouble).slice(0, bestFeatureSplitIndex + 1) + val numCategories = binAggregates.metadata.featureArity(featureIndex) + val bestFeatureSplit = + new CategoricalSplit(featureIndex, categoriesForSplit.toArray, numCategories) + (bestFeatureSplit, bestFeatureGainStats) + } + + /** + * Choose the best split for a feature at a node. + * + * @param parentCalculator Optional: ImpurityCalculator containing impurity stats for current node + * @return (best split, statistics for split) If no valid split was found, the returned + * ImpurityStats will have member stats.valid = false. + */ + private[impl] def chooseSplit( + statsAggregator: DTStatsAggregator, + featureIndex: Int, + featureIndexIdx: Int, + featureSplits: Array[Split], + parentCalculator: Option[ImpurityCalculator] = None): (Split, ImpurityStats) = { + val metadata = statsAggregator.metadata + if (metadata.isCategorical(featureIndex)) { + if (metadata.isUnordered(featureIndex)) { + SplitUtils.chooseUnorderedCategoricalSplit(statsAggregator, featureIndex, + featureIndexIdx, featureSplits, parentCalculator) + } else { + SplitUtils.chooseOrderedCategoricalSplit(statsAggregator, featureIndex, + featureIndexIdx, parentCalculator) + } + } else { + SplitUtils.chooseContinuousSplit(statsAggregator, featureIndex, featureIndexIdx, + featureSplits, parentCalculator) + } + + } + +} diff --git a/src/main/scala/org/apache/spark/ml/tree/impl/TrainingInfo.scala b/src/main/scala/org/apache/spark/ml/tree/impl/TrainingInfo.scala new file mode 100755 index 0000000..3f22dde --- /dev/null +++ b/src/main/scala/org/apache/spark/ml/tree/impl/TrainingInfo.scala @@ -0,0 +1,152 @@ +/* + * Modified work Copyright (C) 2019 Cisco Systems + * + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree.impl + +import org.apache.spark.ml.tree.{OptimizedLearningNode, Split} +import org.apache.spark.util.collection.BitSet + +import scala.collection.mutable.ArrayBuffer + +/** + * Maintains intermediate state of data (columns) and tree during local tree training. + * Primary local tree training data structure; contains all information required to describe + * the state of the algorithm at any point during learning.?? + * + * Nodes are indexed left-to-right along the periphery of the tree, with 0-based indices. + * The "periphery" is the set of leaf nodes (active and inactive). + * + * @param columns Array of columns. + * Each column is sorted first by nodes (left-to-right along the tree periphery); + * all columns share this first level of sorting. + * @param nodeOffsets Offsets into the columns indicating the first level of sorting (by node). + * The rows corresponding to the node activeNodes(i) are in the range + * [nodeOffsets(i)(0), nodeOffsets(i)(1)) . + * @param currentLevelActiveNodes Nodes which are active (could still be split). + * Inactive nodes are known to be leaves in the final tree. + */ +private[impl] case class TrainingInfo( + columns: Array[FeatureColumn], + nodeOffsets: Array[(Int, Int)], + currentLevelActiveNodes: Array[OptimizedLearningNode], + rowIndices: Option[Array[Int]] = None) extends Serializable { + + // pre-allocated temporary buffers that we use to sort + // instances in left and right children during update + val tempVals: Array[Int] = new Array[Int](columns.head.values.length) + + // Array of row indices for feature values, shared across all columns. + // For each column (col) in [[columns]], col(j) is the feature value corresponding to the row + // with index indices(j). + val indices: Array[Int] = rowIndices.getOrElse(columns.head.values.indices.toArray) + + /** For debugging */ + override def toString: String = { + "TrainingInfo(" + + " columns: {\n" + + columns.mkString(",\n") + + " },\n" + + s" nodeOffsets: ${nodeOffsets.mkString(", ")},\n" + + s" activeNodes: ${currentLevelActiveNodes.iterator.mkString(", ")},\n" + + ")\n" + } + + /** + * Update columns and nodeOffsets for the next level of the tree. + * + * Update columns: + * For each (previously) active node, + * Compute bitset indicating whether each training instance under the node splits left/right + * For each column, + * Sort corresponding range of instances based on bitset. + * Update nodeOffsets, activeNodes: + * Split offsets for nodes which split (which can be identified using the bitset). + * + * @return Updated partition info + */ + def update(splits: Array[Array[Split]], + newActiveNodes: Array[OptimizedLearningNode]): TrainingInfo = { + // Create buffers for storing our new arrays of node offsets & impurities + val newNodeOffsets = new ArrayBuffer[(Int, Int)]() + // Update (per-node) sorting of each column to account for creation of new nodes + var nodeIdx = 0 + while (nodeIdx < currentLevelActiveNodes.length) { + val node = currentLevelActiveNodes(nodeIdx) + // Get new active node offsets from active nodes that were split + if (!node.isLeaf) { + // Get split and FeatureVector corresponding to feature for split + val split = node.split.get + val col = columns(split.featureIndex) + val (from, to) = nodeOffsets(nodeIdx) + // Compute bitset indicating whether each training example splits left/right + val bitset = TrainingInfo.bitSetFromSplit(col, from, to, split, splits(split.featureIndex)) + // Update each column according to the bitset + val numRows = to - from + // Allocate shared temp buffers (shared across all columns) for reordering + // feature values/indices for current node. + val tempVals = new Array[Int](numRows) + val numLeftRows = numRows - bitset.cardinality() + // Reorder values for each column + columns.foreach { col => + LocalDecisionTreeUtils.updateArrayForSplit(col.values, from, to, numLeftRows, tempVals, + bitset) + } + // Reorder indices (shared across all columns) + LocalDecisionTreeUtils.updateArrayForSplit(indices, from, to, numLeftRows, tempVals, bitset) + // Add new node offsets to array + val leftIndices = (from, from + numLeftRows) + val rightIndices = (from + numLeftRows, to) + newNodeOffsets ++= Array(leftIndices, rightIndices) + } + nodeIdx += 1 + } + TrainingInfo(columns, newNodeOffsets.toArray, newActiveNodes, Some(indices)) + } + +} + +/** Training-info specific utility methods. */ +private[impl] object TrainingInfo { + /** + * For a given feature, for a given node, apply a split and return a bitset indicating the + * outcome of the split for each instance at that node. + * + * @param col Column for feature + * @param from Start offset in col for the node + * @param to End offset in col for the node + * @param split Split to apply to instances at this node. + * @return Bitset indicating splits for instances at this node. + * These bits are sorted by the row indices. + * bitset(i) = true if ith example for current node splits right, false otherwise. + */ + private[impl] def bitSetFromSplit( + col: FeatureColumn, + from: Int, + to: Int, + split: Split, + featureSplits: Array[Split]): BitSet = { + val bitset = new BitSet(to - from) + from.until(to).foreach { i => + if (!split.shouldGoLeft(col.values(i), featureSplits)) { + bitset.set(i - from) + } + } + bitset + } +} diff --git a/src/main/scala/org/apache/spark/ml/tree/impl/TrainingStatistics.scala b/src/main/scala/org/apache/spark/ml/tree/impl/TrainingStatistics.scala new file mode 100755 index 0000000..9e383ef --- /dev/null +++ b/src/main/scala/org/apache/spark/ml/tree/impl/TrainingStatistics.scala @@ -0,0 +1,22 @@ +/* + * Copyright (C) 2019 Cisco Systems + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree.impl + +case class TrainingStatistics(timers: TimeTracker, + localTrainingStatistics: Seq[NodeStatistics]) + +case class NodeStatistics(id: Int, rows: Int, impurity: Double, time: Double) diff --git a/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/src/main/scala/org/apache/spark/ml/tree/treeModels.scala new file mode 100755 index 0000000..cd5705b --- /dev/null +++ b/src/main/scala/org/apache/spark/ml/tree/treeModels.scala @@ -0,0 +1,355 @@ +/* + * Modified work Copyright (C) 2019 Cisco Systems + * + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree + +import org.apache.hadoop.fs.Path +import org.apache.spark.ml.linalg.{Vector, Vectors} +import org.apache.spark.ml.param.{Param, Params} +import org.apache.spark.ml.util.DefaultParamsReader.Metadata +import org.apache.spark.ml.util.{DefaultParamsReader, DefaultParamsWriter} +import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{Dataset, SparkSession} +import org.json4s._ +import org.json4s.jackson.JsonMethods._ + +/** + * Abstraction for Decision Tree models. + * + * TODO: Add support for predicting probabilities and raw predictions SPARK-3727 + */ +private[spark] trait OptimizedDecisionTreeModel { + + /** Root of the decision tree */ + def rootNode: OptimizedNode + + /** Number of nodes in tree, including leaf nodes. */ + def numNodes: Int = { + 1 + rootNode.numDescendants + } + + /** + * Depth of the tree. + * E.g.: Depth 0 means 1 leaf node. Depth 1 means 1 internal node and 2 leaf nodes. + */ + lazy val depth: Int = { + rootNode.subtreeDepth + } + + /** Summary of the model */ + override def toString: String = { + // Implementing classes should generally override this method to be more descriptive. + s"DecisionTreeModel of depth $depth with $numNodes nodes" + } + + /** Full description of model */ + def toDebugString: String = { + val header = toString + "\n" + header + rootNode.subtreeToString(2) + } + + /** + * Trace down the tree, and return the largest feature index used in any split. + * + * @return Max feature index used in a split, or -1 if there are no splits (single leaf node). + */ + private[ml] def maxSplitFeatureIndex(): Int = rootNode.maxSplitFeatureIndex() + + /** Convert to spark.mllib DecisionTreeModel (losing some information) */ + private[spark] def toOld: OldDecisionTreeModel +} + +/** + * Abstraction for models which are ensembles of decision trees + * + * TODO: Add support for predicting probabilities and raw predictions SPARK-3727 + * + * @tparam M Type of tree model in this ensemble + */ +private[ml] trait OptimizedTreeEnsembleModel[M <: OptimizedDecisionTreeModel] { + + // Note: We use getTrees since subclasses of TreeEnsembleModel will store subclasses of + // DecisionTreeModel. + + /** Trees in this ensemble. Warning: These have null parent Estimators. */ + def trees: Array[M] + + /** Weights for each tree, zippable with [[trees]] */ + def treeWeights: Array[Double] + + /** Weights used by the python wrappers. */ + // Note: An array cannot be returned directly due to serialization problems. + private[spark] def javaTreeWeights: Vector = Vectors.dense(treeWeights) + + /** Summary of the model */ + override def toString: String = { + // Implementing classes should generally override this method to be more descriptive. + s"TreeEnsembleModel with ${trees.length} trees" + } + + /** Full description of model */ + def toDebugString: String = { + val header = toString + "\n" + header + trees.zip(treeWeights).zipWithIndex.map { case ((tree, weight), treeIndex) => + s" Tree $treeIndex (weight $weight):\n" + tree.rootNode.subtreeToString(4) + }.fold("")(_ + _) + } + + /** Total number of nodes, summed over all trees in the ensemble. */ + lazy val totalNumNodes: Int = trees.map(_.numNodes).sum +} + +/** Helper classes for tree model persistence */ +private[ml] object OptimizedDecisionTreeModelReadWrite { + + /** + * Info for a [[org.apache.spark.ml.tree.Split]] + * + * @param featureIndex Index of feature split on + * @param leftCategoriesOrThreshold For categorical feature, set of leftCategories. + * For continuous feature, threshold. + * @param numCategories For categorical feature, number of categories. + * For continuous feature, -1. + */ + case class SplitData( + featureIndex: Int, + leftCategoriesOrThreshold: Array[Double], + numCategories: Int) { + + def getSplit: Split = { + if (numCategories != -1) { + new CategoricalSplit(featureIndex, leftCategoriesOrThreshold, numCategories) + } else { + assert(leftCategoriesOrThreshold.length == 1, s"DecisionTree split data expected" + + s" 1 threshold for ContinuousSplit, but found thresholds: " + + leftCategoriesOrThreshold.mkString(", ")) + new ContinuousSplit(featureIndex, leftCategoriesOrThreshold(0)) + } + } + } + + object SplitData { + def apply(split: Split): SplitData = split match { + case s: CategoricalSplit => + SplitData(s.featureIndex, s.leftCategories, s.numCategories) + case s: ContinuousSplit => + SplitData(s.featureIndex, Array(s.threshold), -1) + } + } + + /** + * Info for a [[OptimizedNode]] + * + * @param id Index used for tree reconstruction. Indices follow a pre-order traversal. + * @param gain Gain, or arbitrary value if leaf node. + * @param leftChild Left child index, or arbitrary value if leaf node. + * @param rightChild Right child index, or arbitrary value if leaf node. + * @param split Split info, or arbitrary value if leaf node. + */ + case class NodeData( + id: Int, + prediction: Double, + impurity: Double, + gain: Double, + leftChild: Int, + rightChild: Int, + split: SplitData) + + object NodeData { + /** + * Create [[NodeData]] instances for this node and all children. + * + * @param id Current ID. IDs are assigned via a pre-order traversal. + * @return (sequence of nodes in pre-order traversal order, largest ID in subtree) + * The nodes are returned in pre-order traversal (root first) so that it is easy to + * get the ID of the subtree's root node. + */ + def build(node: OptimizedNode, id: Int): (Seq[NodeData], Int) = node match { + case n: OptimizedInternalNode => + val (leftNodeData, leftIdx) = build(n.leftChild, id + 1) + val (rightNodeData, rightIdx) = build(n.rightChild, leftIdx + 1) + val thisNodeData = NodeData(id, n.prediction, n.impurity, + n.gain, leftNodeData.head.id, rightNodeData.head.id, SplitData(n.split)) + (thisNodeData +: (leftNodeData ++ rightNodeData), rightIdx) + case _: OptimizedLeafNode => + (Seq(NodeData(id, node.prediction, node.impurity, + -1.0, -1, -1, SplitData(-1, Array.empty[Double], -1))), + id) + } + } + + /** + * Load a decision tree from a file. + * @return Root node of reconstructed tree + */ + def loadTreeNodes( + path: String, + metadata: DefaultParamsReader.Metadata, + sparkSession: SparkSession): OptimizedNode = { + import sparkSession.implicits._ + implicit val format = DefaultFormats + + // Get impurity to construct ImpurityCalculator for each node + val impurityType: String = { + val impurityJson: JValue = metadata.getParamValue("impurity") + Param.jsonDecode[String](compact(render(impurityJson))) + } + + val dataPath = new Path(path, "data").toString + val data = sparkSession.read.parquet(dataPath).as[NodeData] + buildTreeFromNodes(data.collect(), impurityType) + } + + /** + * Given all data for all nodes in a tree, rebuild the tree. + * @param data Unsorted node data + * @param impurityType Impurity type for this tree + * @return Root node of reconstructed tree + */ + def buildTreeFromNodes(data: Array[NodeData], impurityType: String): OptimizedNode = { + // Load all nodes, sorted by ID. + val nodes = data.sortBy(_.id) + // Sanity checks; could remove + assert(nodes.head.id == 0, s"Decision Tree load failed. Expected smallest node ID to be 0," + + s" but found ${nodes.head.id}") + assert(nodes.last.id == nodes.length - 1, s"Decision Tree load failed. Expected largest" + + s" node ID to be ${nodes.length - 1}, but found ${nodes.last.id}") + // We fill `finalNodes` in reverse order. Since node IDs are assigned via a pre-order + // traversal, this guarantees that child nodes will be built before parent nodes. + val finalNodes = new Array[OptimizedNode](nodes.length) + nodes.reverseIterator.foreach { case n: NodeData => + val node = if (n.leftChild != -1) { + val leftChild = finalNodes(n.leftChild) + val rightChild = finalNodes(n.rightChild) + new OptimizedInternalNode(n.prediction, n.impurity, n.gain, leftChild, rightChild, + n.split.getSplit) + } else { + new OptimizedLeafNode(n.prediction, n.impurity) + } + finalNodes(n.id) = node + } + // Return the root node + finalNodes.head + } +} + +private[ml] object OptimizedEnsembleModelReadWrite { + + import OptimizedDecisionTreeModelReadWrite.NodeData + + /** + * Helper method for saving a tree ensemble to disk. + * + * @param instance Tree ensemble model + * @param path Path to which to save the ensemble model. + * @param extraMetadata Metadata such as numFeatures, numClasses, numTrees. + */ + def saveImpl[M <: Params with OptimizedTreeEnsembleModel[_ <: OptimizedDecisionTreeModel]]( + instance: M, + path: String, + sql: SparkSession, + extraMetadata: JObject): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sql.sparkContext, Some(extraMetadata)) + val treesMetadataWeights: Array[(Int, String, Double)] = instance.trees.zipWithIndex.map { + case (tree, treeID) => + (treeID, + DefaultParamsWriter.getMetadataToSave(tree.asInstanceOf[Params], sql.sparkContext), + instance.treeWeights(treeID)) + } + val treesMetadataPath = new Path(path, "treesMetadata").toString + sql.createDataFrame(treesMetadataWeights).toDF("treeID", "metadata", "weights") + .write.parquet(treesMetadataPath) + val dataPath = new Path(path, "data").toString + val nodeDataRDD = sql.sparkContext.parallelize(instance.trees.zipWithIndex).flatMap { + case (tree, treeID) => EnsembleNodeData.build(tree, treeID) + } + sql.createDataFrame(nodeDataRDD).write.parquet(dataPath) + } + + /** + * Helper method for loading a tree ensemble from disk. + * This reconstructs all trees, returning the root nodes. + * @param path Path given to `saveImpl` + * @param className Class name for ensemble model type + * @param treeClassName Class name for tree model type in the ensemble + * @return (ensemble metadata, array over trees of (tree metadata, root node)), + * where the root node is linked with all descendents + * @see `saveImpl` for how the model was saved + */ + def loadImpl( + path: String, + sql: SparkSession, + className: String, + treeClassName: String): (Metadata, Array[(Metadata, OptimizedNode)], Array[Double]) = { + import sql.implicits._ + implicit val format = DefaultFormats + val metadata = DefaultParamsReader.loadMetadata(path, sql.sparkContext, className) + + // Get impurity to construct ImpurityCalculator for each node + val impurityType: String = { + val impurityJson: JValue = metadata.getParamValue("impurity") + Param.jsonDecode[String](compact(render(impurityJson))) + } + + val treesMetadataPath = new Path(path, "treesMetadata").toString + val treesMetadataRDD: RDD[(Int, (Metadata, Double))] = sql.read.parquet(treesMetadataPath) + .select("treeID", "metadata", "weights").as[(Int, String, Double)].rdd.map { + case (treeID: Int, json: String, weights: Double) => + treeID -> ((DefaultParamsReader.parseMetadata(json, treeClassName), weights)) + } + + val treesMetadataWeights = treesMetadataRDD.sortByKey().values.collect() + val treesMetadata = treesMetadataWeights.map(_._1) + val treesWeights = treesMetadataWeights.map(_._2) + + val dataPath = new Path(path, "data").toString + val nodeData: Dataset[EnsembleNodeData] = + sql.read.parquet(dataPath).as[EnsembleNodeData] + val rootNodesRDD: RDD[(Int, OptimizedNode)] = + nodeData.rdd.map(d => (d.treeID, d.nodeData)).groupByKey().map { + case (treeID: Int, nodeData: Iterable[NodeData]) => + treeID -> OptimizedDecisionTreeModelReadWrite.buildTreeFromNodes(nodeData.toArray, impurityType) + } + val rootNodes: Array[OptimizedNode] = rootNodesRDD.sortByKey().values.collect() + (metadata, treesMetadata.zip(rootNodes), treesWeights) + } + + /** + * Info for one [[Node]] in a tree ensemble + * + * @param treeID Tree index + * @param nodeData Data for this node + */ + case class EnsembleNodeData( + treeID: Int, + nodeData: NodeData) + + object EnsembleNodeData { + /** + * Create [[EnsembleNodeData]] instances for the given tree. + * + * @return Sequence of nodes for this tree + */ + def build(tree: OptimizedDecisionTreeModel, treeID: Int): Seq[EnsembleNodeData] = { + val (nodeData: Seq[NodeData], _) = NodeData.build(tree.rootNode, 0) + nodeData.map(nd => EnsembleNodeData(treeID, nd)) + } + } +} diff --git a/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/src/main/scala/org/apache/spark/ml/tree/treeParams.scala new file mode 100755 index 0000000..2b7a560 --- /dev/null +++ b/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -0,0 +1,136 @@ +/* + * Modified work Copyright (C) 2019 Cisco Systems + * + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree + +import java.util.Locale + +import org.apache.spark.ml.param._ +import org.apache.spark.ml.tree.impl.LocalDecisionTree +import org.apache.spark.mllib.tree.configuration.{DefaultTimePredictionStrategy, OptimizedForestStrategy, TimePredictionStrategy, Algo => OldAlgo} +import org.apache.spark.mllib.tree.impurity.{Impurity => OldImpurity} + + +private[ml] trait OptimizedDecisionTreeParams extends DecisionTreeParams { + + final val maxMemoryMultiplier: DoubleParam = new DoubleParam(this, "maxMemoryMultiplier", "", + ParamValidators.gt(0.0)) + + var timePredictionStrategy: TimePredictionStrategy = new DefaultTimePredictionStrategy + + final val maxTasksPerBin: IntParam + = new IntParam (this, "maxTasksPerBin", "", ParamValidators.gt(0)) + + var customSplits: Option[Array[Array[Double]]] = None + + var localTrainingAlgorithm: LocalTrainingAlgorithm = new LocalDecisionTree + + setDefault(maxDepth -> 5, maxBins -> 32, minInstancesPerNode -> 1, minInfoGain -> 0.0, + maxMemoryInMB -> 256, cacheNodeIds -> true, checkpointInterval -> 10, + maxMemoryMultiplier -> 4, maxTasksPerBin -> Int.MaxValue) + + @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") + def setMaxMemoryMultiplier(value: Double): this.type = set(maxMemoryMultiplier, value) + + /** @group expertGetParam */ + final def getMaxMemoryMultiplier: Double = $(maxMemoryMultiplier) + + @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") + def setTimePredictionStrategy(value: TimePredictionStrategy) = timePredictionStrategy = value + + final def getTimePredictionStrategy: TimePredictionStrategy = timePredictionStrategy + + @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") + def setMaxTasksPerBin(value: Int): this.type = set(maxTasksPerBin, value) + + /** @group expertGetParam */ + final def getMaxTasksPerBin: Int = $(maxTasksPerBin) + + @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") + def setCustomSplits(value: Option[Array[Array[Double]]]) = customSplits = value + + /** @group expertGetParam */ + final def getCustomSplits: Option[Array[Array[Double]]] = customSplits + + @deprecated("This method is deprecated and will be removed in 3.0.0.", "2.1.0") + def setLocalTrainingAlgorithm(value: LocalTrainingAlgorithm) = this.localTrainingAlgorithm = value + + /** @group expertGetParam */ + final def getLocalTrainingAlgorithm: LocalTrainingAlgorithm = localTrainingAlgorithm + + /** (private[ml]) Create a Strategy instance to use with the old API. */ + private[ml] override def getOldStrategy( + categoricalFeatures: Map[Int, Int], + numClasses: Int, + oldAlgo: OldAlgo.Algo, + oldImpurity: OldImpurity, + subsamplingRate: Double): OptimizedForestStrategy = { + val strategy = OptimizedForestStrategy.defaultStrategy(oldAlgo) + strategy.impurity = oldImpurity + strategy.checkpointInterval = getCheckpointInterval + strategy.maxBins = getMaxBins + strategy.maxDepth = getMaxDepth + strategy.maxMemoryInMB = getMaxMemoryInMB + strategy.minInfoGain = getMinInfoGain + strategy.minInstancesPerNode = getMinInstancesPerNode + strategy.useNodeIdCache = getCacheNodeIds + strategy.numClasses = numClasses + strategy.categoricalFeaturesInfo = categoricalFeatures + strategy.subsamplingRate = subsamplingRate + strategy.maxMemoryMultiplier = getMaxMemoryMultiplier + strategy.timePredictionStrategy = getTimePredictionStrategy + strategy.localTrainingAlgorithm = getLocalTrainingAlgorithm + strategy + } +} + +private[spark] object OptimizedTreeEnsembleParams { + // These options should be lowercase. + final val supportedTimePredictionStrategies: Array[String] = + Array("size").map(_.toLowerCase(Locale.ROOT)) + + final val supportedLocalTrainingAlgorithms: Array[String] = + Array("yggdrasil").map(_.toLowerCase(Locale.ROOT)) +} + +private[ml] trait OptimizedDecisionTreeClassifierParams + extends OptimizedDecisionTreeParams with TreeClassifierParams + +private[ml] trait OptimizedDecisionTreeRegressorParams + extends OptimizedDecisionTreeParams with TreeRegressorParams + +private[ml] trait OptimizedTreeEnsembleParams extends TreeEnsembleParams + with OptimizedDecisionTreeParams { + private[ml] override def getOldStrategy( + categoricalFeatures: Map[Int, Int], + numClasses: Int, + oldAlgo: OldAlgo.Algo, + oldImpurity: OldImpurity): OptimizedForestStrategy = { + super.getOldStrategy(categoricalFeatures, numClasses, oldAlgo, oldImpurity, getSubsamplingRate) + } +} + +private[ml] trait OptimizedRandomForestParams extends RandomForestParams + with OptimizedTreeEnsembleParams + +private[ml] trait OptimizedRandomForestClassifierParams + extends OptimizedRandomForestParams with TreeClassifierParams + +private[ml] trait OptimizedRandomForestRegressorParams + extends OptimizedRandomForestParams with TreeRegressorParams diff --git a/src/main/scala/org/apache/spark/mllib/tree/OptimizedRandomForest.scala b/src/main/scala/org/apache/spark/mllib/tree/OptimizedRandomForest.scala new file mode 100755 index 0000000..c9a7a9d --- /dev/null +++ b/src/main/scala/org/apache/spark/mllib/tree/OptimizedRandomForest.scala @@ -0,0 +1,476 @@ +/* + * Modified work Copyright (C) 2019 Cisco Systems + * + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree + +import org.apache.spark.annotation.Since +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.internal.Logging +import org.apache.spark.ml.classification.{OptimizedDecisionTreeClassificationModel, OptimizedRandomForestClassificationModel} +import org.apache.spark.ml.regression.{OptimizedDecisionTreeRegressionModel, OptimizedRandomForestRegressionModel} +import org.apache.spark.ml.tree.impl.{TrainingStatistics, OptimizedRandomForest => NewRandomForest} +import org.apache.spark.ml.tree.{OptimizedDecisionTreeModel => NewDTModel, TreeEnsembleParams => NewRFParams} +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.configuration.Algo._ +import org.apache.spark.mllib.tree.configuration.OptimizedForestStrategy +import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ +import org.apache.spark.mllib.tree.impurity.Impurities +import org.apache.spark.rdd.RDD +import org.apache.spark.util.Utils + +import scala.collection.JavaConverters._ +import scala.util.Try + + +/** + * A class that implements a Random Forest + * learning algorithm for classification and regression. + * It supports both continuous and categorical features. + * + * The settings for featureSubsetStrategy are based on the following references: + * - log2: tested in Breiman (2001) + * - sqrt: recommended by Breiman manual for random forests + * - The defaults of sqrt (classification) and onethird (regression) match the R randomForest + * package. + * + * @see Breiman (2001) + * @see + * Breiman manual for random forests + * @param strategy The configuration parameters for the random forest algorithm which specify + * the type of random forest (classification or regression), feature type + * (continuous, categorical), depth of the tree, quantile calculation strategy, + * etc. + * @param numTrees If 1, then no bootstrapping is used. If greater than 1, then bootstrapping is + * done. + * @param featureSubsetStrategy Number of features to consider for splits at each node. + * Supported values: "auto", "all", "sqrt", "log2", "onethird". + * Supported numerical values: "(0.0-1.0]", "[1-n]". + * If "auto" is set, this parameter is set based on numTrees: + * if numTrees == 1, set to "all"; + * if numTrees is greater than 1 (forest) set to "sqrt" for + * classification and to "onethird" for regression. + * If a real value "n" in the range (0, 1.0] is set, + * use n * number of features. + * If an integer value "n" in the range (1, num features) is set, + * use n features. + * @param seed Random seed for bootstrapping and choosing feature subsets. + */ +private class OptimizedRandomForest ( + private val strategy: OptimizedForestStrategy, + private val numTrees: Int, + featureSubsetStrategy: String, + private val seed: Int) + extends Serializable with Logging { + + strategy.assertValid() + require(numTrees > 0, s"RandomForest requires numTrees > 0, but was given numTrees = $numTrees.") + require(OptimizedRandomForest.supportedFeatureSubsetStrategies.contains(featureSubsetStrategy) + || Try(featureSubsetStrategy.toInt).filter(_ > 0).isSuccess + || Try(featureSubsetStrategy.toDouble).filter(_ > 0).filter(_ <= 1.0).isSuccess, + s"RandomForest given invalid featureSubsetStrategy: $featureSubsetStrategy." + + s" Supported values: ${NewRFParams.supportedFeatureSubsetStrategies.mkString(", ")}," + + s" (0.0-1.0], [1-n].") + + /** + * Method to train a decision tree model over an RDD + * + * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * @return RandomForestModel that can be used for prediction. + */ + def run(input: RDD[LabeledPoint]): (Array[NewDTModel], Option[TrainingStatistics]) = { + NewRandomForest.run(input.map(_.asML), strategy, numTrees, + featureSubsetStrategy, seed.toLong, None, prune = true, None, computeStatistics = true) + } +} + +@Since("1.2.0") +object OptimizedRandomForest extends Serializable with Logging { + + /** + * Method to train a decision tree model for binary or multiclass classification. + * + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * Labels should take values {0, 1, ..., numClasses-1}. + * @param strategy Parameters for training each tree in the forest. + * @param numTrees Number of trees in the random forest. + * @param featureSubsetStrategy Number of features to consider for splits at each node. + * Supported values: "auto", "all", "sqrt", "log2", "onethird". + * If "auto" is set, this parameter is set based on numTrees: + * if numTrees == 1, set to "all"; + * if numTrees is greater than 1 (forest) set to "sqrt". + * @param seed Random seed for bootstrapping and choosing feature subsets. + * @return RandomForestModel that can be used for prediction. + */ + @Since("1.2.0") + def trainClassifier( + input: RDD[LabeledPoint], + strategy: OptimizedForestStrategy, + numTrees: Int, + featureSubsetStrategy: String, + seed: Int): OptimizedRandomForestClassificationModel = { + require(strategy.algo == Classification, + s"RandomForest.trainClassifier given Strategy with invalid algo: ${strategy.algo}") + val rf = new OptimizedRandomForest(strategy, numTrees, featureSubsetStrategy, seed) + val (trees, _) = rf.run(input) + + val classificationTrees = trees.map(_.asInstanceOf[OptimizedDecisionTreeClassificationModel]) + val numFeatures = input.first().features.size + val numClasses = strategy.getNumClasses + + new OptimizedRandomForestClassificationModel(classificationTrees, numFeatures, numClasses) + } + + /** + * Method to train a decision tree model for binary or multiclass classification. + * + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * Labels should take values {0, 1, ..., numClasses-1}. + * @param numClasses Number of classes for classification. + * @param categoricalFeaturesInfo Map storing arity of categorical features. An entry (n to k) + * indicates that feature n is categorical with k categories + * indexed from 0: {0, 1, ..., k-1}. + * @param numTrees Number of trees in the random forest. + * @param featureSubsetStrategy Number of features to consider for splits at each node. + * Supported values: "auto", "all", "sqrt", "log2", "onethird". + * If "auto" is set, this parameter is set based on numTrees: + * if numTrees == 1, set to "all"; + * if numTrees is greater than 1 (forest) set to "sqrt". + * @param impurity Criterion used for information gain calculation. + * Supported values: "gini" (recommended) or "entropy". + * @param maxDepth Maximum depth of the tree (e.g. depth 0 means 1 leaf node, depth 1 means + * 1 internal node + 2 leaf nodes). + * (suggested value: 4) + * @param maxBins Maximum number of bins used for splitting features + * (suggested value: 100) + * @param seed Random seed for bootstrapping and choosing feature subsets. + * @return RandomForestModel that can be used for prediction. + */ + @Since("1.2.0") + def trainClassifier( + input: RDD[LabeledPoint], + numClasses: Int, + categoricalFeaturesInfo: Map[Int, Int], + numTrees: Int, + featureSubsetStrategy: String, + impurity: String, + maxDepth: Int, + maxBins: Int, + seed: Int = Utils.random.nextInt()): OptimizedRandomForestClassificationModel = { + val impurityType = Impurities.fromString(impurity) + val strategy = new OptimizedForestStrategy(Classification, impurityType, maxDepth, + numClasses, maxBins, Sort, categoricalFeaturesInfo) + trainClassifier(input, strategy, numTrees, featureSubsetStrategy, seed) + } + + /** + * Java-friendly API for `org.apache.spark.mllib.tree.RandomForest.trainClassifier` + */ + @Since("1.2.0") + def trainClassifier( + input: JavaRDD[LabeledPoint], + numClasses: Int, + categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer], + numTrees: Int, + featureSubsetStrategy: String, + impurity: String, + maxDepth: Int, + maxBins: Int, + seed: Int): OptimizedRandomForestClassificationModel = { + trainClassifier(input.rdd, numClasses, + categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap, + numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, seed) + } + + /** + * Method to train a decision tree model for regression. + * + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * Labels are real numbers. + * @param strategy Parameters for training each tree in the forest. + * @param numTrees Number of trees in the random forest. + * @param featureSubsetStrategy Number of features to consider for splits at each node. + * Supported values: "auto", "all", "sqrt", "log2", "onethird". + * If "auto" is set, this parameter is set based on numTrees: + * if numTrees == 1, set to "all"; + * if numTrees is greater than 1 (forest) set to "onethird". + * @param seed Random seed for bootstrapping and choosing feature subsets. + * @return RandomForestModel that can be used for prediction. + */ + @Since("1.2.0") + def trainRegressor( + input: RDD[LabeledPoint], + strategy: OptimizedForestStrategy, + numTrees: Int, + featureSubsetStrategy: String, + seed: Int): OptimizedRandomForestRegressionModel = { + require(strategy.algo == Regression, + s"RandomForest.trainRegressor given Strategy with invalid algo: ${strategy.algo}") + val rf = new OptimizedRandomForest(strategy, numTrees, featureSubsetStrategy, seed) + val (trees, _) = rf.run(input) + + val regressionTrees = trees.map(_.asInstanceOf[OptimizedDecisionTreeRegressionModel]) + val numFeatures = input.first().features.size + + new OptimizedRandomForestRegressionModel(Identifiable.randomUID("rfc"), regressionTrees, numFeatures) + } + + /** + * Method to train a decision tree model for regression. + * + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * Labels are real numbers. + * @param categoricalFeaturesInfo Map storing arity of categorical features. An entry (n to k) + * indicates that feature n is categorical with k categories + * indexed from 0: {0, 1, ..., k-1}. + * @param numTrees Number of trees in the random forest. + * @param featureSubsetStrategy Number of features to consider for splits at each node. + * Supported values: "auto", "all", "sqrt", "log2", "onethird". + * If "auto" is set, this parameter is set based on numTrees: + * if numTrees == 1, set to "all"; + * if numTrees is greater than 1 (forest) set to "onethird". + * @param impurity Criterion used for information gain calculation. + * The only supported value for regression is "variance". + * @param maxDepth Maximum depth of the tree. (e.g., depth 0 means 1 leaf node, depth 1 means + * 1 internal node + 2 leaf nodes). + * (suggested value: 4) + * @param maxBins Maximum number of bins used for splitting features. + * (suggested value: 100) + * @param seed Random seed for bootstrapping and choosing feature subsets. + * @return RandomForestModel that can be used for prediction. + */ + @Since("1.2.0") + def trainRegressor( + input: RDD[LabeledPoint], + categoricalFeaturesInfo: Map[Int, Int], + numTrees: Int, + featureSubsetStrategy: String, + impurity: String, + maxDepth: Int, + maxBins: Int, + seed: Int = Utils.random.nextInt()) + : OptimizedRandomForestRegressionModel = { + val impurityType = Impurities.fromString(impurity) + val strategy = new OptimizedForestStrategy(Regression, impurityType, maxDepth, + 0, maxBins, Sort, categoricalFeaturesInfo) + trainRegressor(input, strategy, numTrees, featureSubsetStrategy, seed) + } + + /** + * Java-friendly API for `org.apache.spark.mllib.tree.RandomForest.trainRegressor` + */ + @Since("1.2.0") + def trainRegressor( + input: JavaRDD[LabeledPoint], + categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer], + numTrees: Int, + featureSubsetStrategy: String, + impurity: String, + maxDepth: Int, + maxBins: Int, + seed: Int): OptimizedRandomForestRegressionModel = { + trainRegressor(input.rdd, + categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap, + numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, seed) + } + + @Since("1.2.0") + def trainClassifierWithStatistics( + input: RDD[LabeledPoint], + strategy: OptimizedForestStrategy, + numTrees: Int, + featureSubsetStrategy: String, + seed: Int) + : (OptimizedRandomForestClassificationModel, TrainingStatistics) = { + require(strategy.algo == Classification, + s"RandomForest.trainClassifier given Strategy with invalid algo: ${strategy.algo}") + val rf = new OptimizedRandomForest(strategy, numTrees, featureSubsetStrategy, seed) + val (trees, statistics) = rf.run(input) + + val classificationTrees = trees.map(_.asInstanceOf[OptimizedDecisionTreeClassificationModel]) + val numFeatures = input.first().features.size + val numClasses = strategy.getNumClasses + val m = new OptimizedRandomForestClassificationModel(classificationTrees, numFeatures, numClasses) + + (m, statistics.get) + } + + + /** + * Method to train a decision tree model for binary or multiclass classification. + * + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * Labels should take values {0, 1, ..., numClasses-1}. + * @param numClasses Number of classes for classification. + * @param categoricalFeaturesInfo Map storing arity of categorical features. An entry (n to k) + * indicates that feature n is categorical with k categories + * indexed from 0: {0, 1, ..., k-1}. + * @param numTrees Number of trees in the random forest. + * @param featureSubsetStrategy Number of features to consider for splits at each node. + * Supported values: "auto", "all", "sqrt", "log2", "onethird". + * If "auto" is set, this parameter is set based on numTrees: + * if numTrees == 1, set to "all"; + * if numTrees is greater than 1 (forest) set to "sqrt". + * @param impurity Criterion used for information gain calculation. + * Supported values: "gini" (recommended) or "entropy". + * @param maxDepth Maximum depth of the tree (e.g. depth 0 means 1 leaf node, depth 1 means + * 1 internal node + 2 leaf nodes). + * (suggested value: 4) + * @param maxBins Maximum number of bins used for splitting features + * (suggested value: 100) + * @param seed Random seed for bootstrapping and choosing feature subsets. + * @return RandomForestModel that can be used for prediction. + */ + @Since("1.2.0") + def trainClassifierWithStatistics( + input: RDD[LabeledPoint], + numClasses: Int, + categoricalFeaturesInfo: Map[Int, Int], + numTrees: Int, + featureSubsetStrategy: String, + impurity: String, + maxDepth: Int, + maxBins: Int, + seed: Int = Utils.random.nextInt()) + : (OptimizedRandomForestClassificationModel, TrainingStatistics) = { + val impurityType = Impurities.fromString(impurity) + val strategy = new OptimizedForestStrategy(Classification, impurityType, maxDepth, + numClasses, maxBins, Sort, categoricalFeaturesInfo) + trainClassifierWithStatistics(input, strategy, numTrees, featureSubsetStrategy, seed) + } + + /** + * Java-friendly API for `org.apache.spark.mllib.tree.RandomForest.trainClassifier` + */ + @Since("1.2.0") + def trainClassifierWithStatistics( + input: JavaRDD[LabeledPoint], + numClasses: Int, + categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer], + numTrees: Int, + featureSubsetStrategy: String, + impurity: String, + maxDepth: Int, + maxBins: Int, + seed: Int): (OptimizedRandomForestClassificationModel, TrainingStatistics) = { + trainClassifierWithStatistics(input.rdd, numClasses, + categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap, + numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, seed) + } + + /** + * Method to train a decision tree model for regression. + * + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * Labels are real numbers. + * @param strategy Parameters for training each tree in the forest. + * @param numTrees Number of trees in the random forest. + * @param featureSubsetStrategy Number of features to consider for splits at each node. + * Supported values: "auto", "all", "sqrt", "log2", "onethird". + * If "auto" is set, this parameter is set based on numTrees: + * if numTrees == 1, set to "all"; + * if numTrees is greater than 1 (forest) set to "onethird". + * @param seed Random seed for bootstrapping and choosing feature subsets. + * @return RandomForestModel that can be used for prediction. + */ + @Since("1.2.0") + def trainRegressorWithStatistics( + input: RDD[LabeledPoint], + strategy: OptimizedForestStrategy, + numTrees: Int, + featureSubsetStrategy: String, + seed: Int) + : (OptimizedRandomForestRegressionModel, TrainingStatistics) = { + require(strategy.algo == Regression, + s"RandomForest.trainRegressor given Strategy with invalid algo: ${strategy.algo}") + val rf = new OptimizedRandomForest(strategy, numTrees, featureSubsetStrategy, seed) + val (trees, statistics) = rf.run(input) + + val regressionTrees = trees.map(_.asInstanceOf[OptimizedDecisionTreeRegressionModel]) + val numFeatures = input.first().features.size + val m = new OptimizedRandomForestRegressionModel(Identifiable.randomUID("rfc"), regressionTrees, numFeatures) + + (m, statistics.get) + } + + /** + * Method to train a decision tree model for regression. + * + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * Labels are real numbers. + * @param categoricalFeaturesInfo Map storing arity of categorical features. An entry (n to k) + * indicates that feature n is categorical with k categories + * indexed from 0: {0, 1, ..., k-1}. + * @param numTrees Number of trees in the random forest. + * @param featureSubsetStrategy Number of features to consider for splits at each node. + * Supported values: "auto", "all", "sqrt", "log2", "onethird". + * If "auto" is set, this parameter is set based on numTrees: + * if numTrees == 1, set to "all"; + * if numTrees is greater than 1 (forest) set to "onethird". + * @param impurity Criterion used for information gain calculation. + * The only supported value for regression is "variance". + * @param maxDepth Maximum depth of the tree. (e.g., depth 0 means 1 leaf node, depth 1 means + * 1 internal node + 2 leaf nodes). + * (suggested value: 4) + * @param maxBins Maximum number of bins used for splitting features. + * (suggested value: 100) + * @param seed Random seed for bootstrapping and choosing feature subsets. + * @return RandomForestModel that can be used for prediction. + */ + @Since("1.2.0") + def trainRegressorWithStatistics( + input: RDD[LabeledPoint], + categoricalFeaturesInfo: Map[Int, Int], + numTrees: Int, + featureSubsetStrategy: String, + impurity: String, + maxDepth: Int, + maxBins: Int, + seed: Int = Utils.random.nextInt()) + : (OptimizedRandomForestRegressionModel, TrainingStatistics) = { + val impurityType = Impurities.fromString(impurity) + val strategy = new OptimizedForestStrategy(Regression, impurityType, maxDepth, + 0, maxBins, Sort, categoricalFeaturesInfo) + trainRegressorWithStatistics(input, strategy, numTrees, featureSubsetStrategy, seed) + } + + /** + * Java-friendly API for `org.apache.spark.mllib.tree.RandomForest.trainRegressor` + */ + @Since("1.2.0") + def trainRegressorWithStatistics( + input: JavaRDD[LabeledPoint], + categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer], + numTrees: Int, + featureSubsetStrategy: String, + impurity: String, + maxDepth: Int, + maxBins: Int, + seed: Int): (OptimizedRandomForestRegressionModel, TrainingStatistics) = { + trainRegressorWithStatistics(input.rdd, + categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap, + numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, seed) + } + + /** + * List of supported feature subset sampling strategies. + */ + @Since("1.2.0") + val supportedFeatureSubsetStrategies: Array[String] = NewRFParams.supportedFeatureSubsetStrategies +} diff --git a/src/main/scala/org/apache/spark/mllib/tree/configuration/OptimizedForestStrategy.scala b/src/main/scala/org/apache/spark/mllib/tree/configuration/OptimizedForestStrategy.scala new file mode 100755 index 0000000..556d758 --- /dev/null +++ b/src/main/scala/org/apache/spark/mllib/tree/configuration/OptimizedForestStrategy.scala @@ -0,0 +1,164 @@ +/* + * Modified work Copyright (C) 2019 Cisco Systems + * + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.configuration + +import org.apache.spark.annotation.Since +import org.apache.spark.ml.tree.LocalTrainingAlgorithm +import org.apache.spark.ml.tree.impl.LocalDecisionTree +import org.apache.spark.mllib.tree.configuration.Algo._ +import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ +import org.apache.spark.mllib.tree.impurity.{Gini, Impurity, Variance} + +import scala.beans.BeanProperty +import scala.collection.JavaConverters._ + +/** + * Stores all the configuration options for tree construction + * @param algo Learning goal. Supported: + * `org.apache.spark.mllib.tree.configuration.Algo.Classification`, + * `org.apache.spark.mllib.tree.configuration.Algo.Regression` + * @param impurity Criterion used for information gain calculation. + * Supported for Classification: [[org.apache.spark.mllib.tree.impurity.Gini]], + * [[org.apache.spark.mllib.tree.impurity.Entropy]]. + * Supported for Regression: [[org.apache.spark.mllib.tree.impurity.Variance]]. + * @param maxDepth Maximum depth of the tree (e.g. depth 0 means 1 leaf node, depth 1 means + * 1 internal node + 2 leaf nodes). + * @param numClasses Number of classes for classification. + * (Ignored for regression.) + * Default value is 2 (binary classification). + * @param maxBins Maximum number of bins used for discretizing continuous features and + * for choosing how to split on features at each node. + * More bins give higher granularity. + * @param quantileCalculationStrategy Algorithm for calculating quantiles. Supported: + * `org.apache.spark.mllib.tree.configuration.QuantileStrategy.Sort` + * @param categoricalFeaturesInfo A map storing information about the categorical variables and the + * number of discrete values they take. An entry (n to k) + * indicates that feature n is categorical with k categories + * indexed from 0: {0, 1, ..., k-1}. + * @param minInstancesPerNode Minimum number of instances each child must have after split. + * Default value is 1. If a split cause left or right child + * to have less than minInstancesPerNode, + * this split will not be considered as a valid split. + * @param minInfoGain Minimum information gain a split must get. Default value is 0.0. + * If a split has less information gain than minInfoGain, + * this split will not be considered as a valid split. + * @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. Default value is + * 256 MB. If too small, then 1 node will be split per iteration, and + * its aggregates may exceed this size. + * @param subsamplingRate Fraction of the training data used for learning decision tree. + * @param useNodeIdCache If this is true, instead of passing trees to executors, the algorithm will + * maintain a separate RDD of node Id cache for each row. + * @param checkpointInterval How often to checkpoint when the node Id cache gets updated. + * E.g. 10 means that the cache will get checkpointed every 10 updates. If + * the checkpoint directory is not set in + * [[org.apache.spark.SparkContext]], this setting is ignored. + */ +@Since("1.0.0") +class OptimizedForestStrategy @Since("1.3.0")( + algo: Algo, + impurity: Impurity, + maxDepth: Int, + numClasses: Int = 2, + maxBins: Int = 32, + quantileCalculationStrategy: QuantileStrategy = Sort, + categoricalFeaturesInfo: Map[Int, Int] + = Map[Int, Int](), + minInstancesPerNode: Int = 1, + minInfoGain: Double = 0.0, + maxMemoryInMB: Int = 256, + subsamplingRate: Double = 1, + useNodeIdCache: Boolean = true, + checkpointInterval: Int = 10, + @Since("2.1.2") @BeanProperty var maxMemoryMultiplier: Double = 4.0, + @Since("2.1.2") @BeanProperty var timePredictionStrategy: TimePredictionStrategy = new DefaultTimePredictionStrategy, + @Since("2.1.2") @BeanProperty var maxTasksPerBin: Int = Int.MaxValue, + @Since("2.3.1") @BeanProperty var localTrainingAlgorithm: LocalTrainingAlgorithm = new LocalDecisionTree, + @Since("2.3.1") @BeanProperty var customSplits: Option[Array[Array[Double]]] = None) + extends Strategy(algo, impurity, maxDepth, numClasses, maxBins, quantileCalculationStrategy, + categoricalFeaturesInfo, minInstancesPerNode, minInfoGain, maxMemoryInMB, subsamplingRate, + useNodeIdCache, checkpointInterval) with Serializable { + + /** + * Java-friendly constructor for [[org.apache.spark.mllib.tree.configuration.Strategy]] + */ + @Since("1.1.0") + def this( + algo: Algo, + impurity: Impurity, + maxDepth: Int, + numClasses: Int, + maxBins: Int, + categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer]) { + this(algo, impurity, maxDepth, numClasses, maxBins, Sort, + categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap) + } + /** + * Check validity of parameters. + * Throws exception if invalid. + */ + private[spark] override def assertValid(): Unit = { + super.assertValid() + require(maxMemoryMultiplier > 0, + s"DecisionTree Strategy requires maxMemoryMultiplier > 0, but was given " + + s"$maxMemoryMultiplier") + require(maxTasksPerBin > 0, + s"DecisionTree Strategy requires maxTasksPerBin > 0, but was given " + + s"$maxMemoryMultiplier") + } + + /** + * Returns a shallow copy of this instance. + */ + @Since("1.2.0") + override def copy: OptimizedForestStrategy = { + new OptimizedForestStrategy(algo, impurity, maxDepth, numClasses, maxBins, + quantileCalculationStrategy, categoricalFeaturesInfo, minInstancesPerNode, minInfoGain, + maxMemoryInMB, subsamplingRate, useNodeIdCache, checkpointInterval, + maxMemoryMultiplier, timePredictionStrategy, maxTasksPerBin, localTrainingAlgorithm, customSplits) + } +} + +@Since("1.2.0") +object OptimizedForestStrategy { + + /** + * Construct a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTree]] + * @param algo "Classification" or "Regression" + */ + @Since("1.2.0") + def defaultStrategy(algo: String): OptimizedForestStrategy = { + defaultStrategy(Algo.fromString(algo)) + } + + /** + * Construct a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTree]] + * @param algo Algo.Classification or Algo.Regression + */ + @Since("1.3.0") + def defaultStrategy(algo: Algo): OptimizedForestStrategy = algo match { + case Algo.Classification => + new OptimizedForestStrategy(algo = Classification, impurity = Gini, maxDepth = 10, + numClasses = 2) + case Algo.Regression => + new OptimizedForestStrategy(algo = Regression, impurity = Variance, maxDepth = 10, + numClasses = 0) + } + +} diff --git a/src/main/scala/org/apache/spark/mllib/tree/configuration/TimePredictionStrategy.scala b/src/main/scala/org/apache/spark/mllib/tree/configuration/TimePredictionStrategy.scala new file mode 100755 index 0000000..60b5703 --- /dev/null +++ b/src/main/scala/org/apache/spark/mllib/tree/configuration/TimePredictionStrategy.scala @@ -0,0 +1,25 @@ +/* + * Copyright (C) 2019 Cisco Systems + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.configuration + +trait TimePredictionStrategy extends Serializable { + def predict(rows: Long, impurity: Double): Double +} + +class DefaultTimePredictionStrategy extends TimePredictionStrategy with Serializable { + override def predict(rows: Long, impurity: Double): Double = rows +} diff --git a/src/test/scala/org/apache/spark/ml/classification/OptimizedDecisionTreeClassifierSuite.scala b/src/test/scala/org/apache/spark/ml/classification/OptimizedDecisionTreeClassifierSuite.scala new file mode 100755 index 0000000..8549022 --- /dev/null +++ b/src/test/scala/org/apache/spark/ml/classification/OptimizedDecisionTreeClassifierSuite.scala @@ -0,0 +1,432 @@ +/* + * Modified work Copyright (C) 2019 Cisco Systems + * + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.linalg.{Vector, Vectors} +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.tree.OptimizedLeafNode +import org.apache.spark.ml.tree.impl.OptimizedTreeTests +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} +import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} +import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, DecisionTreeSuite => OldDecisionTreeSuite} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, Row} + +class OptimizedDecisionTreeClassifierSuite extends MLTest with DefaultReadWriteTest { + + import OptimizedDecisionTreeClassifierSuite.compareAPIs + import testImplicits._ + + private var categoricalDataPointsRDD: RDD[LabeledPoint] = _ + private var orderedLabeledPointsWithLabel0RDD: RDD[LabeledPoint] = _ + private var orderedLabeledPointsWithLabel1RDD: RDD[LabeledPoint] = _ + private var categoricalDataPointsForMulticlassRDD: RDD[LabeledPoint] = _ + private var continuousDataPointsForMulticlassRDD: RDD[LabeledPoint] = _ + private var categoricalDataPointsForMulticlassForOrderedFeaturesRDD: RDD[LabeledPoint] = _ + + override def beforeAll() { + super.beforeAll() + categoricalDataPointsRDD = + sc.parallelize(OldDecisionTreeSuite.generateCategoricalDataPoints()).map(_.asML) + orderedLabeledPointsWithLabel0RDD = + sc.parallelize(OldDecisionTreeSuite.generateOrderedLabeledPointsWithLabel0()).map(_.asML) + orderedLabeledPointsWithLabel1RDD = + sc.parallelize(OldDecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()).map(_.asML) + categoricalDataPointsForMulticlassRDD = + sc.parallelize(OldDecisionTreeSuite.generateCategoricalDataPointsForMulticlass()).map(_.asML) + continuousDataPointsForMulticlassRDD = + sc.parallelize(OldDecisionTreeSuite.generateContinuousDataPointsForMulticlass()).map(_.asML) + categoricalDataPointsForMulticlassForOrderedFeaturesRDD = sc.parallelize( + OldDecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures()) + .map(_.asML) + } + + test("params") { + ParamsSuite.checkParams(new OptimizedDecisionTreeClassifier) + val model = new OptimizedDecisionTreeClassificationModel("dtc", new OptimizedLeafNode(0.0, 0.0), 1, 2) + ParamsSuite.checkParams(model) + } + + ///////////////////////////////////////////////////////////////////////////// + // Tests calling train() + ///////////////////////////////////////////////////////////////////////////// + + test("Binary classification stump with ordered categorical features") { + val odt = new OptimizedDecisionTreeClassifier() + .setImpurity("gini") + .setMaxDepth(2) + .setMaxBins(100) + .setSeed(1) + val dt = new DecisionTreeClassifier() + .setImpurity("gini") + .setMaxDepth(2) + .setMaxBins(100) + .setSeed(1) + val categoricalFeatures = Map(0 -> 3, 1 -> 3) + val numClasses = 2 + compareAPIs(categoricalDataPointsRDD, dt, odt, categoricalFeatures, numClasses) + } + + test("Binary classification stump with fixed labels 0,1 for Entropy,Gini") { + val dt = new DecisionTreeClassifier() + .setMaxDepth(3) + .setMaxBins(100) + val odt = new OptimizedDecisionTreeClassifier() + .setMaxDepth(3) + .setMaxBins(100) + val numClasses = 2 + Array(orderedLabeledPointsWithLabel0RDD, orderedLabeledPointsWithLabel1RDD).foreach { rdd => + OptimizedDecisionTreeClassifier.supportedImpurities.foreach { impurity => + dt.setImpurity(impurity) + odt.setImpurity(impurity) + compareAPIs(rdd, dt, odt, categoricalFeatures = Map.empty[Int, Int], numClasses) + } + } + } + + test("Multiclass classification stump with 3-ary (unordered) categorical features") { + val rdd = categoricalDataPointsForMulticlassRDD + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(4) + val odt = new OptimizedDecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(4) + val numClasses = 3 + val categoricalFeatures = Map(0 -> 3, 1 -> 3) + compareAPIs(rdd, dt, odt, categoricalFeatures, numClasses) + } + + test("Binary classification stump with 1 continuous feature, to check off-by-1 error") { + val arr = Array( + LabeledPoint(0.0, Vectors.dense(0.0)), + LabeledPoint(1.0, Vectors.dense(1.0)), + LabeledPoint(1.0, Vectors.dense(2.0)), + LabeledPoint(1.0, Vectors.dense(3.0))) + val rdd = sc.parallelize(arr) + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(4) + val odt = new OptimizedDecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(4) + val numClasses = 2 + compareAPIs(rdd, dt, odt, categoricalFeatures = Map.empty[Int, Int], numClasses) + } + + test("Binary classification stump with 2 continuous features") { + val arr = Array( + LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))), + LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))), + LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))), + LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 2.0))))) + val rdd = sc.parallelize(arr) + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(4) + val odt = new OptimizedDecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(4) + val numClasses = 2 + compareAPIs(rdd, dt, odt, categoricalFeatures = Map.empty[Int, Int], numClasses) + } + + test("Multiclass classification stump with unordered categorical features," + + " with just enough bins") { + val maxBins = 2 * (math.pow(2, 3 - 1).toInt - 1) // just enough bins to allow unordered features + val rdd = categoricalDataPointsForMulticlassRDD + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(4) + .setMaxBins(maxBins) + val odt = new OptimizedDecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(4) + .setMaxBins(maxBins) + val categoricalFeatures = Map(0 -> 3, 1 -> 3) + val numClasses = 3 + compareAPIs(rdd, dt, odt, categoricalFeatures, numClasses) + } + + test("Multiclass classification stump with continuous features") { + val rdd = continuousDataPointsForMulticlassRDD + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(4) + .setMaxBins(100) + val odt = new OptimizedDecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(4) + .setMaxBins(100) + val numClasses = 3 + compareAPIs(rdd, dt, odt, categoricalFeatures = Map.empty[Int, Int], numClasses) + } + + test("Multiclass classification stump with continuous + unordered categorical features") { + val rdd = continuousDataPointsForMulticlassRDD + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(4) + .setMaxBins(100) + val odt = new OptimizedDecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(4) + .setMaxBins(100) + val categoricalFeatures = Map(0 -> 3) + val numClasses = 3 + compareAPIs(rdd, dt, odt, categoricalFeatures, numClasses) + } + + test("Multiclass classification stump with 10-ary (ordered) categorical features") { + val rdd = categoricalDataPointsForMulticlassForOrderedFeaturesRDD + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(4) + .setMaxBins(100) + val odt = new OptimizedDecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(4) + .setMaxBins(100) + val categoricalFeatures = Map(0 -> 10, 1 -> 10) + val numClasses = 3 + compareAPIs(rdd, dt, odt, categoricalFeatures, numClasses) + } + + test("Multiclass classification tree with 10-ary (ordered) categorical features," + + " with just enough bins") { + val rdd = categoricalDataPointsForMulticlassForOrderedFeaturesRDD + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(4) + .setMaxBins(10) + val odt = new OptimizedDecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(4) + .setMaxBins(10) + val categoricalFeatures = Map(0 -> 10, 1 -> 10) + val numClasses = 3 + compareAPIs(rdd, dt, odt, categoricalFeatures, numClasses) + } + + test("split must satisfy min instances per node requirements") { + val arr = Array( + LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))), + LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))), + LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0))))) + val rdd = sc.parallelize(arr) + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(2) + .setMinInstancesPerNode(2) + val odt = new OptimizedDecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(2) + .setMinInstancesPerNode(2) + val numClasses = 2 + compareAPIs(rdd, dt, odt, categoricalFeatures = Map.empty[Int, Int], numClasses) + } + + test("do not choose split that does not satisfy min instance per node requirements") { + // if a split does not satisfy min instances per node requirements, + // this split is invalid, even though the information gain of split is large. + val arr = Array( + LabeledPoint(0.0, Vectors.dense(0.0, 1.0)), + LabeledPoint(1.0, Vectors.dense(1.0, 1.0)), + LabeledPoint(0.0, Vectors.dense(0.0, 0.0)), + LabeledPoint(0.0, Vectors.dense(0.0, 0.0))) + val rdd = sc.parallelize(arr) + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxBins(2) + .setMaxDepth(2) + .setMinInstancesPerNode(2) + val odt = new OptimizedDecisionTreeClassifier() + .setImpurity("Gini") + .setMaxBins(2) + .setMaxDepth(2) + .setMinInstancesPerNode(2) + val categoricalFeatures = Map(0 -> 2, 1 -> 2) + val numClasses = 2 + compareAPIs(rdd, dt, odt, categoricalFeatures, numClasses) + } + + test("split must satisfy min info gain requirements") { + val arr = Array( + LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))), + LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))), + LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0))))) + val rdd = sc.parallelize(arr) + + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(2) + .setMinInfoGain(1.0) + val odt = new OptimizedDecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(2) + .setMinInfoGain(1.0) + val numClasses = 2 + compareAPIs(rdd, dt, odt, categoricalFeatures = Map.empty[Int, Int], numClasses) + } + + test("predictRaw and predictProbability") { + val rdd = continuousDataPointsForMulticlassRDD + val dt = new OptimizedDecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(4) + .setMaxBins(100) + val categoricalFeatures = Map(0 -> 3) + val numClasses = 3 + + val newData: DataFrame = OptimizedTreeTests.setMetadata(rdd, categoricalFeatures, numClasses) + val newTree = dt.fit(newData) + + MLTestingUtils.checkCopyAndUids(dt, newTree) + + testTransformer[(Vector, Double)](newData, newTree, + "prediction", "rawPrediction", "probability") { + case Row(pred: Double, rawPred: Vector, probPred: Vector) => + assert(pred === rawPred.argmax, + s"Expected prediction $pred but calculated ${rawPred.argmax} from rawPrediction.") + val sum = rawPred.toArray.sum + assert(Vectors.dense(rawPred.toArray.map(_ / sum)) === probPred, + "probability prediction mismatch") + } + + ProbabilisticClassifierSuite.testPredictMethods[ + Vector, OptimizedDecisionTreeClassificationModel](this, newTree, newData) + } + + test("prediction on single instance") { + val rdd = continuousDataPointsForMulticlassRDD + val dt = new OptimizedDecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(4) + .setMaxBins(100) + val categoricalFeatures = Map(0 -> 3) + val numClasses = 3 + + val newData: DataFrame = OptimizedTreeTests.setMetadata(rdd, categoricalFeatures, numClasses) + val newTree = dt.fit(newData) + + testPredictionModelSinglePrediction(newTree, newData) + } + + test("training with 1-category categorical feature") { + val data = sc.parallelize(Seq( + LabeledPoint(0, Vectors.dense(0, 2, 3)), + LabeledPoint(1, Vectors.dense(0, 3, 1)), + LabeledPoint(0, Vectors.dense(0, 2, 2)), + LabeledPoint(1, Vectors.dense(0, 3, 9)), + LabeledPoint(0, Vectors.dense(0, 2, 6)) + )) + val df = OptimizedTreeTests.setMetadata(data, Map(0 -> 1), 2) + val dt = new OptimizedDecisionTreeClassifier().setMaxDepth(3) + dt.fit(df) + } + + test("should support all NumericType labels and not support other types") { + val dt = new OptimizedDecisionTreeClassifier().setMaxDepth(1) + MLTestingUtils.checkNumericTypes[OptimizedDecisionTreeClassificationModel, OptimizedDecisionTreeClassifier]( + dt, spark) { (expected, actual) => + OptimizedTreeTests.checkEqual(expected, actual) + } + } + + test("Fitting without numClasses in metadata") { + val df: DataFrame = OptimizedTreeTests.featureImportanceData(sc).toDF() + val dt = new OptimizedDecisionTreeClassifier().setMaxDepth(1) + dt.fit(df) + } + + ///////////////////////////////////////////////////////////////////////////// + // Tests of model save/load + ///////////////////////////////////////////////////////////////////////////// + + test("read/write") { + def checkModelData( + model: OptimizedDecisionTreeClassificationModel, + model2: OptimizedDecisionTreeClassificationModel): Unit = { + OptimizedTreeTests.checkEqual(model, model2) + assert(model.numFeatures === model2.numFeatures) + assert(model.numClasses === model2.numClasses) + } + + val dt = new OptimizedDecisionTreeClassifier() + val rdd = OptimizedTreeTests.getTreeReadWriteData(sc) + + val allParamSettings = OptimizedTreeTests.allParamSettings ++ Map("impurity" -> "entropy") + + // Categorical splits with tree depth 2 + val categoricalData: DataFrame = + OptimizedTreeTests.setMetadata(rdd, Map(0 -> 2, 1 -> 3), numClasses = 2) + testEstimatorAndModelReadWrite(dt, categoricalData, allParamSettings, + allParamSettings, checkModelData) + + // Continuous splits with tree depth 2 + val continuousData: DataFrame = + OptimizedTreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2) + testEstimatorAndModelReadWrite(dt, continuousData, allParamSettings, + allParamSettings, checkModelData) + + // Continuous splits with tree depth 0 + testEstimatorAndModelReadWrite(dt, continuousData, allParamSettings ++ Map("maxDepth" -> 0), + allParamSettings ++ Map("maxDepth" -> 0), checkModelData) + } + + test("SPARK-20043: " + + "ImpurityCalculator builder fails for uppercase impurity type Gini in model read/write") { + val rdd = OptimizedTreeTests.getTreeReadWriteData(sc) + val data: DataFrame = + OptimizedTreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2) + + val dt = new OptimizedDecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(2) + val model = dt.fit(data) + + testDefaultReadWrite(model) + } +} + +private[ml] object OptimizedDecisionTreeClassifierSuite extends SparkFunSuite { + + /** + * Train 2 decision trees on the given dataset, one using the old API and one using the new API. + * Convert the old tree to the new format, compare them, and fail if they are not exactly equal. + */ + def compareAPIs( + data: RDD[LabeledPoint], + dt: DecisionTreeClassifier, + odt: OptimizedDecisionTreeClassifier, + categoricalFeatures: Map[Int, Int], + numClasses: Int): Unit = { + val numFeatures = data.first().features.size + val newData: DataFrame = OptimizedTreeTests.setMetadata(data, categoricalFeatures, numClasses) + val newTree = dt.fit(newData) + val optimizedTree = odt.fit(newData) + // Use parent from newTree since this is not checked anyways. + OptimizedTreeTests.checkEqual(newTree, optimizedTree) + assert(newTree.numFeatures === numFeatures) + assert(optimizedTree.numFeatures === numFeatures) + } +} diff --git a/src/test/scala/org/apache/spark/ml/classification/OptimizedRandomForestClassifierSuite.scala b/src/test/scala/org/apache/spark/ml/classification/OptimizedRandomForestClassifierSuite.scala new file mode 100755 index 0000000..4e1fe14 --- /dev/null +++ b/src/test/scala/org/apache/spark/ml/classification/OptimizedRandomForestClassifierSuite.scala @@ -0,0 +1,250 @@ +/* + * Modified work Copyright (C) 2019 Cisco Systems + * + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.linalg.{Vector, Vectors} +import org.apache.spark.ml.tree.impl.{OptimizedTreeTests, TreeTests} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} +import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} +import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest} +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, Row} + +/** + * Test suite for [[OptimizedRandomForestClassifier]]. + */ +class OptimizedRandomForestClassifierSuite extends MLTest with DefaultReadWriteTest { + + import OptimizedRandomForestClassifierSuite.compareAPIs + import testImplicits._ + + private var orderedLabeledPoints50_1000: RDD[LabeledPoint] = _ + private var orderedLabeledPoints5_20: RDD[LabeledPoint] = _ + + override def beforeAll() { + super.beforeAll() + orderedLabeledPoints50_1000 = + sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000)) + .map(_.asML) + orderedLabeledPoints5_20 = + sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 5, 20)) + .map(_.asML) + } + + ///////////////////////////////////////////////////////////////////////////// + // Tests calling train() + ///////////////////////////////////////////////////////////////////////////// + + def binaryClassificationTestWithContinuousFeatures(rf: RandomForestClassifier, orf: OptimizedRandomForestClassifier) { + val categoricalFeatures = Map.empty[Int, Int] + val numClasses = 2 + val newRF = rf + .setImpurity("Gini") + .setMaxDepth(2) + .setNumTrees(1) + .setFeatureSubsetStrategy("auto") + .setSeed(123) + val optimizedRF = orf + .setImpurity("Gini") + .setMaxDepth(2) + .setNumTrees(1) + .setFeatureSubsetStrategy("auto") + .setSeed(123) + compareAPIs(orderedLabeledPoints50_1000, newRF, optimizedRF, categoricalFeatures, numClasses) + } + + test("Binary classification with continuous features:" + + " comparing DecisionTree vs. RandomForest(numTrees = 1)") { + val rf = new RandomForestClassifier() + val orf = new OptimizedRandomForestClassifier() + binaryClassificationTestWithContinuousFeatures(rf, orf) + } + + test("Binary classification with continuous features and node Id cache:" + + " comparing DecisionTree vs. RandomForest(numTrees = 1)") { + val rf = new RandomForestClassifier() + .setCacheNodeIds(true) + val orf = new OptimizedRandomForestClassifier() + .setCacheNodeIds(true) + binaryClassificationTestWithContinuousFeatures(rf, orf) + } + + test("alternating categorical and continuous features with multiclass labels to test indexing") { + val arr = Array( + LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0, 3.0, 1.0)), + LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0, 1.0, 2.0)), + LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0, 6.0, 3.0)), + LabeledPoint(2.0, Vectors.dense(0.0, 2.0, 1.0, 3.0, 2.0)) + ) + val rdd = sc.parallelize(arr) + val categoricalFeatures = Map(0 -> 3, 2 -> 2, 4 -> 4) + val numClasses = 3 + + val rf = new RandomForestClassifier() + .setImpurity("Gini") + .setMaxDepth(5) + .setNumTrees(2) + .setFeatureSubsetStrategy("all") + .setSeed(12345) + val orf = new OptimizedRandomForestClassifier() + .setImpurity("Gini") + .setMaxDepth(5) + .setNumTrees(2) + .setFeatureSubsetStrategy("all") + .setSeed(12345) + compareAPIs(rdd, rf, orf, categoricalFeatures, numClasses) + } + + // Skip test: Different random generators are created during local training + ignore("subsampling rate in RandomForest") { + val rdd = orderedLabeledPoints5_20 + val categoricalFeatures = Map.empty[Int, Int] + val numClasses = 2 + + val rf1 = new RandomForestClassifier() + .setImpurity("Gini") + .setMaxDepth(2) + .setCacheNodeIds(true) + .setNumTrees(3) + .setFeatureSubsetStrategy("auto") + .setSeed(123) + val orf1 = new OptimizedRandomForestClassifier() + .setImpurity("Gini") + .setMaxDepth(2) + .setCacheNodeIds(true) + .setNumTrees(3) + .setFeatureSubsetStrategy("auto") + .setSeed(123) + compareAPIs(rdd, rf1, orf1, categoricalFeatures, numClasses) + + val rf2 = rf1.setSubsamplingRate(0.5) + val orf2 = orf1.setSubsamplingRate(0.5) + compareAPIs(rdd, rf2, orf2, categoricalFeatures, numClasses) + } + + test("predictRaw and predictProbability") { + val rdd = orderedLabeledPoints5_20 + val rf = new OptimizedRandomForestClassifier() + .setImpurity("Gini") + .setMaxDepth(3) + .setNumTrees(3) + .setSeed(123) + val categoricalFeatures = Map.empty[Int, Int] + val numClasses = 2 + + val df: DataFrame = OptimizedTreeTests.setMetadata(rdd, categoricalFeatures, numClasses) + val model = rf.fit(df) + + MLTestingUtils.checkCopyAndUids(rf, model) + + testTransformer[(Vector, Double)](df, model, "prediction", "rawPrediction") { + case Row(pred: Double, rawPred: Vector) => + assert(pred === rawPred.argmax, + s"Expected prediction $pred but calculated ${rawPred.argmax} from rawPrediction.") + } + + ProbabilisticClassifierSuite.testPredictMethods[ + Vector, OptimizedRandomForestClassificationModel](this, model, df) + } + + test("prediction on single instance") { + val rdd = orderedLabeledPoints5_20 + val rf = new OptimizedRandomForestClassifier() + .setImpurity("Gini") + .setMaxDepth(3) + .setNumTrees(3) + .setSeed(123) + val categoricalFeatures = Map.empty[Int, Int] + val numClasses = 2 + + val df: DataFrame = OptimizedTreeTests.setMetadata(rdd, categoricalFeatures, numClasses) + val model = rf.fit(df) + + testPredictionModelSinglePrediction(model, df) + } + + test("Fitting without numClasses in metadata") { + val df: DataFrame = OptimizedTreeTests.featureImportanceData(sc).toDF() + val rf = new OptimizedRandomForestClassifier().setMaxDepth(1).setNumTrees(1) + rf.fit(df) + } + + test("should support all NumericType labels and not support other types") { + val rf = new OptimizedRandomForestClassifier().setMaxDepth(1) + MLTestingUtils.checkNumericTypes[OptimizedRandomForestClassificationModel, OptimizedRandomForestClassifier]( + rf, spark) { (expected, actual) => + OptimizedTreeTests.checkEqual(expected, actual) + } + } + + ///////////////////////////////////////////////////////////////////////////// + // Tests of model save/load + ///////////////////////////////////////////////////////////////////////////// + + test("read/write") { + def checkModelData( + model: OptimizedRandomForestClassificationModel, + model2: OptimizedRandomForestClassificationModel): Unit = { + OptimizedTreeTests.checkEqual(model, model2) + assert(model.numFeatures === model2.numFeatures) + assert(model.numClasses === model2.numClasses) + } + + val rf = new OptimizedRandomForestClassifier().setNumTrees(2) + val rdd = OptimizedTreeTests.getTreeReadWriteData(sc) + + val allParamSettings = OptimizedTreeTests.allParamSettings ++ Map("impurity" -> "entropy") + + val continuousData: DataFrame = + OptimizedTreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2) + testEstimatorAndModelReadWrite(rf, continuousData, allParamSettings, + allParamSettings, checkModelData) + } +} + +private object OptimizedRandomForestClassifierSuite extends SparkFunSuite { + + /** + * Train 2 models on the given dataset, one using the old API and one using the new API. + * Convert the old model to the new format, compare them, and fail if they are not exactly equal. + */ + def compareAPIs( + data: RDD[LabeledPoint], + rf: RandomForestClassifier, + orf: OptimizedRandomForestClassifier, + categoricalFeatures: Map[Int, Int], + numClasses: Int): Unit = { + val numFeatures = data.first().features.size + val newData: DataFrame = OptimizedTreeTests.setMetadata(data, categoricalFeatures, numClasses) + val newModel = rf.fit(newData) + val optimizedModel = orf.fit(newData) + + // Use parent from newTree since this is not checked anyways. + OptimizedTreeTests.checkEqualOldClassification(newModel, optimizedModel) + assert(optimizedModel.hasParent) + assert(!optimizedModel.trees.head.asInstanceOf[OptimizedDecisionTreeClassificationModel].hasParent) + assert(optimizedModel.numClasses === numClasses) + assert(optimizedModel.numFeatures === numFeatures) + } +} \ No newline at end of file diff --git a/src/test/scala/org/apache/spark/ml/regression/OptimizedDecisionTreeRegressorSuite.scala b/src/test/scala/org/apache/spark/ml/regression/OptimizedDecisionTreeRegressorSuite.scala new file mode 100755 index 0000000..ec9c567 --- /dev/null +++ b/src/test/scala/org/apache/spark/ml/regression/OptimizedDecisionTreeRegressorSuite.scala @@ -0,0 +1,170 @@ +/* + * Modified work Copyright (C) 2019 Cisco Systems + * + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.tree.impl.{OptimizedTreeTests, TreeTests} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} +import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} +import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, DecisionTreeSuite => OldDecisionTreeSuite} +import org.apache.spark.mllib.util.LinearDataGenerator +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.DataFrame + +class OptimizedDecisionTreeRegressorSuite extends MLTest with DefaultReadWriteTest { + + import OptimizedDecisionTreeRegressorSuite.compareAPIs + import testImplicits._ + + private var categoricalDataPointsRDD: RDD[LabeledPoint] = _ + private var linearRegressionData: DataFrame = _ + + private val seed = 42 + + override def beforeAll() { + super.beforeAll() + categoricalDataPointsRDD = + sc.parallelize(OldDecisionTreeSuite.generateCategoricalDataPoints().map(_.asML)) + linearRegressionData = sc.parallelize(LinearDataGenerator.generateLinearInput( + intercept = 6.3, weights = Array(4.7, 7.2), xMean = Array(0.9, -1.3), + xVariance = Array(0.7, 1.2), nPoints = 1000, seed, eps = 0.5), 2).map(_.asML).toDF() + } + + ///////////////////////////////////////////////////////////////////////////// + // Tests calling train() + ///////////////////////////////////////////////////////////////////////////// + + test("Regression stump with 3-ary (ordered) categorical features") { + val dt = new DecisionTreeRegressor() + .setImpurity("variance") + .setMaxDepth(2) + .setMaxBins(100) + .setSeed(1) + val odt = new OptimizedDecisionTreeRegressor() + .setImpurity("variance") + .setMaxDepth(2) + .setMaxBins(100) + .setSeed(1) + val categoricalFeatures = Map(0 -> 3, 1 -> 3) + compareAPIs(categoricalDataPointsRDD, dt, odt, categoricalFeatures) + } + + test("Regression stump with binary (ordered) categorical features") { + val dt = new DecisionTreeRegressor() + .setImpurity("variance") + .setMaxDepth(2) + .setMaxBins(100) + val odt = new OptimizedDecisionTreeRegressor() + .setImpurity("variance") + .setMaxDepth(2) + .setMaxBins(100) + val categoricalFeatures = Map(0 -> 2, 1 -> 2) + compareAPIs(categoricalDataPointsRDD, dt, odt, categoricalFeatures) + } + + test("copied model must have the same parent") { + val categoricalFeatures = Map(0 -> 2, 1 -> 2) + val df = TreeTests.setMetadata(categoricalDataPointsRDD, categoricalFeatures, numClasses = 0) + val dtr = new OptimizedDecisionTreeRegressor() + .setImpurity("variance") + .setMaxDepth(2) + .setMaxBins(8) + val model = dtr.fit(df) + MLTestingUtils.checkCopyAndUids(dtr, model) + } + + test("prediction on single instance") { + val dt = new OptimizedDecisionTreeRegressor() + .setImpurity("variance") + .setMaxDepth(3) + .setSeed(123) + + // In this data, feature 1 is very important. + val data: RDD[LabeledPoint] = OptimizedTreeTests.featureImportanceData(sc) + val categoricalFeatures = Map.empty[Int, Int] + val df: DataFrame = OptimizedTreeTests.setMetadata(data, categoricalFeatures, 0) + + val model = dt.fit(df) + testPredictionModelSinglePrediction(model, df) + } + + test("should support all NumericType labels and not support other types") { + val dt = new OptimizedDecisionTreeRegressor().setMaxDepth(1) + MLTestingUtils.checkNumericTypes[OptimizedDecisionTreeRegressionModel, OptimizedDecisionTreeRegressor]( + dt, spark, isClassification = false) { (expected, actual) => + OptimizedTreeTests.checkEqual(expected, actual) + } + } + + ///////////////////////////////////////////////////////////////////////////// + // Tests of model save/load + ///////////////////////////////////////////////////////////////////////////// + + test("read/write") { + def checkModelData( + model: OptimizedDecisionTreeRegressionModel, + model2: OptimizedDecisionTreeRegressionModel): Unit = { + OptimizedTreeTests.checkEqual(model, model2) + assert(model.numFeatures === model2.numFeatures) + } + + val dt = new OptimizedDecisionTreeRegressor() + val rdd = OptimizedTreeTests.getTreeReadWriteData(sc) + + // Categorical splits with tree depth 2 + val categoricalData: DataFrame = + OptimizedTreeTests.setMetadata(rdd, Map(0 -> 2, 1 -> 3), numClasses = 0) + testEstimatorAndModelReadWrite(dt, categoricalData, + OptimizedTreeTests.allParamSettings, OptimizedTreeTests.allParamSettings, checkModelData) + + // Continuous splits with tree depth 2 + val continuousData: DataFrame = + OptimizedTreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 0) + testEstimatorAndModelReadWrite(dt, continuousData, + OptimizedTreeTests.allParamSettings, OptimizedTreeTests.allParamSettings, checkModelData) + + // Continuous splits with tree depth 0 + testEstimatorAndModelReadWrite(dt, continuousData, + OptimizedTreeTests.allParamSettings ++ Map("maxDepth" -> 0), + OptimizedTreeTests.allParamSettings ++ Map("maxDepth" -> 0), checkModelData) + } +} + +private[ml] object OptimizedDecisionTreeRegressorSuite extends SparkFunSuite { + + /** + * Train 2 decision trees on the given dataset, one using the old API and one using the new API. + * Convert the old tree to the new format, compare them, and fail if they are not exactly equal. + */ + def compareAPIs( + data: RDD[LabeledPoint], + dt: DecisionTreeRegressor, + odt: OptimizedDecisionTreeRegressor, + categoricalFeatures: Map[Int, Int]): Unit = { + val numFeatures = data.first().features.size + val newData: DataFrame = OptimizedTreeTests.setMetadata(data, categoricalFeatures, numClasses = 0) + val newTree = dt.fit(newData) + val optimizedTree = odt.fit(newData) + // Use parent from newTree since this is not checked anyways. + OptimizedTreeTests.checkEqual(newTree, optimizedTree) + assert(optimizedTree.numFeatures === numFeatures) + } +} \ No newline at end of file diff --git a/src/test/scala/org/apache/spark/ml/regression/OptimizedRandomForestRegressorSuite.scala b/src/test/scala/org/apache/spark/ml/regression/OptimizedRandomForestRegressorSuite.scala new file mode 100755 index 0000000..f9e533a --- /dev/null +++ b/src/test/scala/org/apache/spark/ml/regression/OptimizedRandomForestRegressorSuite.scala @@ -0,0 +1,154 @@ +/* + * Modified work Copyright (C) 2019 Cisco Systems + * + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.tree.impl.OptimizedTreeTests +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} +import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint} +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} +import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.DataFrame + +/** + * Test suite for [[OptimizedRandomForestRegressor]]. + */ +class OptimizedRandomForestRegressorSuite extends MLTest with DefaultReadWriteTest{ + + import OptimizedRandomForestRegressorSuite.compareAPIs + import testImplicits._ + + private var orderedLabeledPoints50_1000: RDD[LabeledPoint] = _ + + override def beforeAll() { + super.beforeAll() + orderedLabeledPoints50_1000 = + sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000) + .map(_.asML)) + } + + ///////////////////////////////////////////////////////////////////////////// + // Tests calling train() + ///////////////////////////////////////////////////////////////////////////// + + def regressionTestWithContinuousFeatures(rf: RandomForestRegressor, orf: OptimizedRandomForestRegressor) { + val categoricalFeaturesInfo = Map.empty[Int, Int] + val newRF = rf + .setImpurity("variance") + .setMaxDepth(2) + .setMaxBins(10) + .setNumTrees(1) + .setFeatureSubsetStrategy("auto") + .setSeed(123) + val optimizedRF = orf + .setImpurity("variance") + .setMaxDepth(2) + .setMaxBins(10) + .setNumTrees(1) + .setFeatureSubsetStrategy("auto") + .setSeed(123) + compareAPIs(orderedLabeledPoints50_1000, newRF, optimizedRF, categoricalFeaturesInfo) + } + + test("Regression with continuous features:" + + " comparing DecisionTree vs. RandomForest(numTrees = 1)") { + val rf = new RandomForestRegressor() + val orf = new OptimizedRandomForestRegressor() + regressionTestWithContinuousFeatures(rf, orf) + } + + test("Regression with continuous features and node Id cache :" + + " comparing DecisionTree vs. RandomForest(numTrees = 1)") { + val rf = new RandomForestRegressor() + .setCacheNodeIds(true) + val orf = new OptimizedRandomForestRegressor() + .setCacheNodeIds(true) + regressionTestWithContinuousFeatures(rf, orf) + } + + test("prediction on single instance") { + val rf = new OptimizedRandomForestRegressor() + .setImpurity("variance") + .setMaxDepth(2) + .setMaxBins(10) + .setNumTrees(1) + .setFeatureSubsetStrategy("auto") + .setSeed(123) + + val df = orderedLabeledPoints50_1000.toDF() + val model = rf.fit(df) + testPredictionModelSinglePrediction(model, df) + } + + + test("should support all NumericType labels and not support other types") { + val rf = new OptimizedRandomForestRegressor().setMaxDepth(1) + MLTestingUtils.checkNumericTypes[OptimizedRandomForestRegressionModel, OptimizedRandomForestRegressor]( + rf, spark, isClassification = false) { (expected, actual) => + OptimizedTreeTests.checkEqual(expected, actual) + } + } + + ///////////////////////////////////////////////////////////////////////////// + // Tests of model save/load + ///////////////////////////////////////////////////////////////////////////// + + test("read/write") { + def checkModelData( + model: OptimizedRandomForestRegressionModel, + model2: OptimizedRandomForestRegressionModel): Unit = { + OptimizedTreeTests.checkEqual(model, model2) + assert(model.numFeatures === model2.numFeatures) + } + + val rf = new OptimizedRandomForestRegressor().setNumTrees(2) + val rdd = OptimizedTreeTests.getTreeReadWriteData(sc) + + val allParamSettings = OptimizedTreeTests.allParamSettings ++ Map("impurity" -> "variance") + + val continuousData: DataFrame = + OptimizedTreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 0) + testEstimatorAndModelReadWrite(rf, continuousData, allParamSettings, + allParamSettings, checkModelData) + } +} + +private object OptimizedRandomForestRegressorSuite extends SparkFunSuite { + + /** + * Train 2 models on the given dataset, one using the old API and one using the new API. + * Convert the old model to the new format, compare them, and fail if they are not exactly equal. + */ + def compareAPIs( + data: RDD[LabeledPoint], + rf: RandomForestRegressor, + orf: OptimizedRandomForestRegressor, + categoricalFeatures: Map[Int, Int]): Unit = { + val numFeatures = data.first().features.size + val newData: DataFrame = OptimizedTreeTests.setMetadata(data, categoricalFeatures, numClasses = 0) + val newModel = rf.fit(newData) + val optimizedModel = orf.fit(newData) + // Use parent from newTree since this is not checked anyways. + OptimizedTreeTests.checkEqualOldRegression(newModel, optimizedModel) + assert(newModel.numFeatures === numFeatures) + } +} \ No newline at end of file diff --git a/src/test/scala/org/apache/spark/ml/tree/impl/LocalDecisionTreeRegressor.scala b/src/test/scala/org/apache/spark/ml/tree/impl/LocalDecisionTreeRegressor.scala new file mode 100755 index 0000000..fe4e9fe --- /dev/null +++ b/src/test/scala/org/apache/spark/ml/tree/impl/LocalDecisionTreeRegressor.scala @@ -0,0 +1,76 @@ +/* + * Modified work Copyright (C) 2019 Cisco Systems + * + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree.impl + +import org.apache.spark.ml.Predictor +import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.linalg.Vector +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, OptimizedDecisionTreeRegressionModel} +import org.apache.spark.ml.tree.{DecisionTreeParams, OptimizedDecisionTreeParams, TreeRegressorParams} +import org.apache.spark.ml.util.{Identifiable, MetadataUtils} +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.Dataset + +/** + * Test-only class for fitting a decision tree regressor on a dataset small enough to fit on a + * single machine. + */ +private[impl] final class LocalDecisionTreeRegressor(override val uid: String) + extends Predictor[Vector, LocalDecisionTreeRegressor, OptimizedDecisionTreeRegressionModel] + with OptimizedDecisionTreeParams with TreeRegressorParams { + + def this() = this(Identifiable.randomUID("local_dtr")) + + // Override parameter setters from parent trait for Java API compatibility. + override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value) + + override def setMaxBins(value: Int): this.type = super.setMaxBins(value) + + override def setMinInstancesPerNode(value: Int): this.type = + super.setMinInstancesPerNode(value) + + override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value) + + override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value) + + override def setImpurity(value: String): this.type = super.setImpurity(value) + + override def setSeed(value: Long): this.type = super.setSeed(value) + + override def copy(extra: ParamMap): LocalDecisionTreeRegressor = defaultCopy(extra) + + override protected def train(dataset: Dataset[_]): OptimizedDecisionTreeRegressionModel = { + val categoricalFeatures: Map[Int, Int] = + MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) + val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) + val strategy = getOldStrategy(categoricalFeatures) + val model = LocalTreeTests.train(oldDataset, strategy, parentUID = Some(uid), + seed = getSeed) + model.asInstanceOf[OptimizedDecisionTreeRegressionModel] + } + + /** Create a Strategy instance to use with the old API. */ + private[impl] def getOldStrategy(categoricalFeatures: Map[Int, Int]): OldStrategy = { + super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity, + subsamplingRate = 1.0) + } +} diff --git a/src/test/scala/org/apache/spark/ml/tree/impl/LocalTrainingPlanSuite.scala b/src/test/scala/org/apache/spark/ml/tree/impl/LocalTrainingPlanSuite.scala new file mode 100755 index 0000000..6d7965c --- /dev/null +++ b/src/test/scala/org/apache/spark/ml/tree/impl/LocalTrainingPlanSuite.scala @@ -0,0 +1,74 @@ +/* + * Copyright (C) 2019 Cisco Systems + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree.impl + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.tree.OptimizedLearningNode +import org.apache.spark.mllib.tree.configuration.{DefaultTimePredictionStrategy, TimePredictionStrategy} +import org.apache.spark.mllib.util.MLlibTestSparkContext + +class LocalTrainingPlanSuite extends SparkFunSuite with MLlibTestSparkContext { + + val timePredictonStrategy: TimePredictionStrategy = new DefaultTimePredictionStrategy + + test("memory restriction") { + val plan = new LocalTrainingPlan(10, timePredictonStrategy, Int.MaxValue) + + plan.scheduleTask(new LocalTrainingTask(node = OptimizedLearningNode.emptyNode(1), + treeIndex = 1, rows = 2, impurity = 1.0)) + plan.scheduleTask(new LocalTrainingTask(node = OptimizedLearningNode.emptyNode(1), + treeIndex = 1, rows = 2, impurity = 1.0)) + plan.scheduleTask(new LocalTrainingTask(node = OptimizedLearningNode.emptyNode(1), + treeIndex = 1, rows = 2, impurity = 1.0)) + + plan.scheduleTask(new LocalTrainingTask(node = OptimizedLearningNode.emptyNode(1), + treeIndex = 1, rows = 9, impurity = 1.0)) + + assert(plan.bins.length == 2) + assert(plan.bins.head.tasks.length == 3) + assert(plan.bins(1).tasks.length == 1) + } + + test("count restriction") { + val plan = new LocalTrainingPlan(10, timePredictonStrategy, 2) + + plan.scheduleTask(new LocalTrainingTask(node = OptimizedLearningNode.emptyNode(1), + treeIndex = 1, rows = 2, impurity = 1.0)) + plan.scheduleTask(new LocalTrainingTask(node = OptimizedLearningNode.emptyNode(1), + treeIndex = 1, rows = 2, impurity = 1.0)) + plan.scheduleTask(new LocalTrainingTask(node = OptimizedLearningNode.emptyNode(1), + treeIndex = 1, rows = 2, impurity = 1.0)) + + assert(plan.bins.length == 2) + assert(plan.bins.head.tasks.length == 2) + assert(plan.bins(1).tasks.length == 1) + } + + test("task implicit ordering by memory usage descending") { + val l = List(new LocalTrainingTask(node = OptimizedLearningNode.emptyNode(1), + treeIndex = 1, rows = 1, impurity = 1.0), + new LocalTrainingTask(node = OptimizedLearningNode.emptyNode(1), + treeIndex = 2, rows = 5, impurity = 1.0), + new LocalTrainingTask(node = OptimizedLearningNode.emptyNode(1), + treeIndex = 3, rows = 3, impurity = 1.0) + ) + + val sorted = l.sorted + + assert(sorted.head.treeIndex == 2) + } +} diff --git a/src/test/scala/org/apache/spark/ml/tree/impl/LocalTreeDataSuite.scala b/src/test/scala/org/apache/spark/ml/tree/impl/LocalTreeDataSuite.scala new file mode 100755 index 0000000..91184ca --- /dev/null +++ b/src/test/scala/org/apache/spark/ml/tree/impl/LocalTreeDataSuite.scala @@ -0,0 +1,202 @@ +/* + * Modified work Copyright (C) 2019 Cisco Systems + * + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree.impl + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.tree._ +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.util.collection.BitSet + +import scala.util.Random + +/** Suite exercising data structures (FeatureVector, TrainingInfo) for local tree training. */ +class LocalTreeDataSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + + test("FeatureVector: updating columns for split") { + val vecLength = 100 + // Create a column of vecLength values + val values = 0.until(vecLength).toArray + val col = FeatureColumn(-1, values) + // Pick a random subset of indices to split left + val rng = new Random(seed = 42) + val leftProb = 0.5 + val (leftIdxs, rightIdxs) = values.indices.partition(_ => rng.nextDouble() < leftProb) + // Determine our expected result after updating for split + val expected = leftIdxs.map(values(_)) ++ rightIdxs.map(values(_)) + // Create a bitset indicating whether each of our values splits left or right + val instanceBitVector = new BitSet(values.length) + rightIdxs.foreach(instanceBitVector.set) + // Update column, compare new values to expected result + val tempVals = new Array[Int](vecLength) + val tempIndices = new Array[Int](vecLength) + LocalDecisionTreeUtils.updateArrayForSplit(col.values, from = 0, to = vecLength, + leftIdxs.length, tempVals, instanceBitVector) + assert(col.values.sameElements(expected)) + } + + /* Check that FeatureVector methods produce expected results */ + test("FeatureVector: constructor and deepCopy") { + // Create a feature vector v, modify a deep copy of v, and check that + // v itself was not modified + val v = new FeatureColumn(1, Array(1, 2, 3)) + val vCopy = v.deepCopy() + vCopy.values(0) = 1000 + assert(v.values(0) !== vCopy.values(0)) + } + + // Get common TrainingInfo for tests + // Data: + // Feature 0 (continuous): [3, 2, 0, 1] + // Feature 1 (categorical):[0, 0, 2, 1] + private def getTrainingInfo(): TrainingInfo = { + val numRows = 4 + // col1 is continuous features + val col1 = FeatureColumn(featureIndex = 0, Array(3, 2, 0, 1)) + // col2 is categorical features + val catFeatureIdx = 1 + val col2 = FeatureColumn(featureIndex = catFeatureIdx, values = Array(0, 0, 2, 1)) + + val nodeOffsets = Array((0, numRows)) + val activeNodes = Array(OptimizedLearningNode.emptyNode(nodeIndex = -1)) + TrainingInfo(Array(col1, col2), nodeOffsets, activeNodes) + } + + // Check that TrainingInfo correctly updates node offsets, sorts column values during update() + test("TrainingInfo.update(): correctness when splitting on continuous features") { + // Get TrainingInfo + // Feature 0 (continuous): [3, 2, 0, 1] + // Feature 1 (categorical):[0, 0, 2, 1] + val info = getTrainingInfo() + val activeNodes = info.currentLevelActiveNodes + val contFeatureIdx = 0 + + // For continuous feature, active node has a split with threshold 1 + val contNode = activeNodes(contFeatureIdx) + contNode.split = Some(new ContinuousSplit(contFeatureIdx, threshold = 1)) + + // Update TrainingInfo for continuous split + val contValues = info.columns(contFeatureIdx).values + val splits = Array(LocalTreeTests.getContinuousSplits(contValues, contFeatureIdx)) + val newInfo = info.update(splits, newActiveNodes = Array(contNode)) + + assert(newInfo.columns.length === 2) + // Continuous split should send feature values [0, 1] to the left, [3, 2] to the right + // ==> row indices (2, 3) should split left, row indices (0, 1) should split right + val expectedContCol = new FeatureColumn(0, values = Array(0, 1, 3, 2)) + val expectedCatCol = new FeatureColumn(1, values = Array(2, 1, 0, 0)) + val expectedIndices = Array(2, 3, 0, 1) + assert(newInfo.columns(0) === expectedContCol) + assert(newInfo.columns(1) === expectedCatCol) + assert(newInfo.indices === expectedIndices) + // Check that node offsets were updated properly + assert(newInfo.nodeOffsets === Array((0, 2), (2, 4))) + } + + test("TrainingInfo.update(): correctness when splitting on categorical features") { + // Get TrainingInfo + // Feature 0 (continuous): [3, 2, 0, 1] + // Feature 1 (categorical):[0, 0, 2, 1] + val info = getTrainingInfo() + val activeNodes = info.currentLevelActiveNodes + val catFeatureIdx = 1 + + // For categorical feature, active node puts category 2 on left side of split + val catNode = activeNodes(0) + val catSplit = new CategoricalSplit(catFeatureIdx, _leftCategories = Array(2), + numCategories = 3) + catNode.split = Some(catSplit) + + // Update TrainingInfo for categorical split + val splits: Array[Array[Split]] = Array(Array.empty, Array(catSplit)) + val newInfo = info.update(splits, newActiveNodes = Array(catNode)) + + assert(newInfo.columns.length === 2) + // Categorical split should send feature values [2] to the left, [0, 1] to the right + // ==> row 2 should split left, rows [0, 1, 3] should split right + val expectedContCol = new FeatureColumn(0, values = Array(0, 3, 2, 1)) + val expectedCatCol = new FeatureColumn(1, values = Array(2, 0, 0, 1)) + val expectedIndices = Array(2, 0, 1, 3) + assert(newInfo.columns(0) === expectedContCol) + assert(newInfo.columns(1) === expectedCatCol) + assert(newInfo.indices === expectedIndices) + // Check that node offsets were updated properly + assert(newInfo.nodeOffsets === Array((0, 1), (1, 4))) + } + + private def getSetBits(bitset: BitSet): Set[Int] = { + Range(0, bitset.capacity).filter(bitset.get).toSet + } + + test("TrainingInfo.bitSetFromSplit correctness: splitting a single node") { + val featureIndex = 0 + val thresholds = Array(1, 2, 4, 6, 7) + val values = thresholds.indices.toArray + val splits = LocalTreeTests.getContinuousSplits(thresholds, featureIndex) + val col = FeatureColumn(0, values) + val fromOffset = 0 + val toOffset = col.values.length + val numRows = toOffset + // Create split; first three rows (with feature values [1, 2, 4]) should split left, as they + // have feature values <= 5. Last two rows (feature values [6, 7]) should split right. + val split = new ContinuousSplit(0, threshold = 5) + val bitset = TrainingInfo.bitSetFromSplit(col, fromOffset, toOffset, split, splits) + // Check that the last two rows (row indices [3, 4] within the set of rows being split) + // fall on the right side of the split. + assert(getSetBits(bitset) === Set(3, 4)) + } + + test("TrainingInfo.bitSetFromSplit correctness: splitting 2 nodes") { + // Assume there was already 1 split, which split rows (represented by row index) as: + // (0, 2, 4) | (1, 3) + val thresholds = Array(1, 2, 4, 6, 7) + val values = thresholds.indices.toArray + val splits = LocalTreeTests.getContinuousSplits(thresholds, featureIndex = 0) + val col = new FeatureColumn(0, values) + + /** + * Computes a bitset for splitting rows in with indices in [fromOffset, toOffset) using a + * continuous split with the specified threshold. Then, checks that right side of the split + * contains the row indices in expectedRight. + */ + def checkSplit( + fromOffset: Int, + toOffset: Int, + threshold: Double, + expectedRight: Set[Int]): Unit = { + val split = new ContinuousSplit(0, threshold) + val numRows = col.values.length + val bitset = TrainingInfo.bitSetFromSplit(col, fromOffset, toOffset, split, splits) + assert(getSetBits(bitset) === expectedRight) + } + + // Split rows corresponding to left child node (rows [0, 2, 4]) + checkSplit(fromOffset = 0, toOffset = 3, threshold = 0.5, expectedRight = Set(0, 1, 2)) + checkSplit(fromOffset = 0, toOffset = 3, threshold = 1.5, expectedRight = Set(1, 2)) + checkSplit(fromOffset = 0, toOffset = 3, threshold = 2, expectedRight = Set(2)) + checkSplit(fromOffset = 0, toOffset = 3, threshold = 5, expectedRight = Set()) + // Split rows corresponding to right child node (rows [1, 3]) + checkSplit(fromOffset = 3, toOffset = 5, threshold = 1, expectedRight = Set(0, 1)) + checkSplit(fromOffset = 3, toOffset = 5, threshold = 6.5, expectedRight = Set(1)) + checkSplit(fromOffset = 3, toOffset = 5, threshold = 8, expectedRight = Set()) + } + +} diff --git a/src/test/scala/org/apache/spark/ml/tree/impl/LocalTreeIntegrationSuite.scala b/src/test/scala/org/apache/spark/ml/tree/impl/LocalTreeIntegrationSuite.scala new file mode 100755 index 0000000..7d4393c --- /dev/null +++ b/src/test/scala/org/apache/spark/ml/tree/impl/LocalTreeIntegrationSuite.scala @@ -0,0 +1,106 @@ +/* + * Modified work Copyright (C) 2019 Cisco Systems + * + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree.impl + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.Estimator +import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.linalg.Vectors +import org.apache.spark.ml.regression.DecisionTreeRegressor +import org.apache.spark.mllib.tree.DecisionTreeSuite +import org.apache.spark.mllib.util.{LogisticRegressionDataGenerator, MLlibTestSparkContext} +import org.apache.spark.sql.DataFrame + +/** Tests checking equivalence of trees produced by local and distributed tree training. */ +class LocalTreeIntegrationSuite extends SparkFunSuite with MLlibTestSparkContext { + + val medDepthTreeSettings = OptimizedTreeTests.allParamSettings ++ Map[String, Any]("maxDepth" -> 4) + + /** + * For each (paramName, paramVal) pair in the passed-in map, set the corresponding + * parameter of the passed-in estimator & return the estimator. + */ + private def setParams[E <: Estimator[_]](estimator: E, params: Map[String, Any]): E = { + params.foreach { case (p, v) => + estimator.set(estimator.getParam(p), v) + } + estimator + } + + /** + * Verifies that local tree training & distributed training produce the same tree + * when fit on the same dataset with the same set of params. + */ + private def testEquivalence(train: DataFrame, testParams: Map[String, Any]): Unit = { + val distribTree = setParams(new DecisionTreeRegressor(), testParams) + val localTree = setParams(new LocalDecisionTreeRegressor(), testParams) + val localModel = localTree.fit(train) + val model = distribTree.fit(train) + OptimizedTreeTests.checkEqual(model, localModel) + } + + + test("Local & distributed training produce the same tree on a toy dataset") { + val data = sc.parallelize(Range(0, 8).map(x => LabeledPoint(x, Vectors.dense(x)))) + val df = spark.createDataFrame(data) + testEquivalence(df, OptimizedTreeTests.allParamSettings) + } + + test("Local & distributed training produce the same tree on a slightly larger toy dataset") { + val data = sc.parallelize(Range(0, 16).map(x => LabeledPoint(x, Vectors.dense(x)))) + val df = spark.createDataFrame(data) + testEquivalence(df, medDepthTreeSettings) + } + + test("Local & distributed training produce the same tree on a larger toy dataset") { + val data = sc.parallelize(Range(0, 64).map(x => LabeledPoint(x, Vectors.dense(x)))) + val df = spark.createDataFrame(data) + testEquivalence(df, medDepthTreeSettings) + } + + test("Local & distributed training produce same tree on a dataset of categorical features") { + val data = sc.parallelize(DecisionTreeSuite.generateCategoricalDataPoints().map(_.asML)) + // Create a map of categorical feature index to arity; each feature has arity nclasses + val featuresMap: Map[Int, Int] = Map(0 -> 3, 1 -> 3) + // Convert the data RDD to a DataFrame with metadata indicating the arity of each of its + // categorical features + val df = OptimizedTreeTests.setMetadata(data, featuresMap, numClasses = 2) + testEquivalence(df, OptimizedTreeTests.allParamSettings) + } + + test("Local & distributed training produce the same tree on a dataset of continuous features") { + val sqlContext = spark.sqlContext + import sqlContext.implicits._ + // Use maxDepth = 5 and default params + val params = medDepthTreeSettings + val data = LogisticRegressionDataGenerator.generateLogisticRDD(spark.sparkContext, + nexamples = 1000, nfeatures = 5, eps = 2.0, nparts = 1, probOne = 0.2) + .map(_.asML).toDF().cache() + testEquivalence(data, params) + } + + test("Local & distributed training produce the same tree on a dataset of constant features") { + // Generate constant, continuous data + val data = sc.parallelize(Range(0, 8).map(_ => LabeledPoint(1, Vectors.dense(1)))) + val df = spark.createDataFrame(data) + testEquivalence(df, OptimizedTreeTests.allParamSettings) + } + +} diff --git a/src/test/scala/org/apache/spark/ml/tree/impl/LocalTreeTests.scala b/src/test/scala/org/apache/spark/ml/tree/impl/LocalTreeTests.scala new file mode 100755 index 0000000..9e8e939 --- /dev/null +++ b/src/test/scala/org/apache/spark/ml/tree/impl/LocalTreeTests.scala @@ -0,0 +1,108 @@ +/* + * Modified work Copyright (C) 2019 Cisco Systems + * + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree.impl + +import org.apache.spark.internal.Logging +import org.apache.spark.ml.classification.OptimizedDecisionTreeClassificationModel +import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.regression.OptimizedDecisionTreeRegressionModel +import org.apache.spark.ml.tree._ +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} +import org.apache.spark.rdd.RDD + + +/** Object providing test-only methods for local decision tree training. */ +private[impl] object LocalTreeTests extends Logging { + + /** + * Given the root node of a decision tree, returns a corresponding DecisionTreeModel + * @param algo Enum describing the algorithm used to fit the tree + * @param numClasses Number of label classes (for classification trees) + * @param parentUID UID of parent estimator + */ + private[impl] def finalizeTree( + rootNode: OptimizedNode, + algo: OldAlgo.Algo, + numClasses: Int, + numFeatures: Int, + parentUID: Option[String]): OptimizedDecisionTreeModel = { + parentUID match { + case Some(uid) => + if (algo == OldAlgo.Classification) { + new OptimizedDecisionTreeClassificationModel(uid, rootNode, numFeatures = numFeatures, + numClasses = numClasses) + } else { + new OptimizedDecisionTreeRegressionModel(uid, rootNode, numFeatures = numFeatures) + } + case None => + if (algo == OldAlgo.Classification) { + new OptimizedDecisionTreeClassificationModel(rootNode, numFeatures = numFeatures, + numClasses = numClasses) + } else { + new OptimizedDecisionTreeRegressionModel(rootNode, numFeatures = numFeatures) + } + } + } + + /** + * Method to locally train a decision tree model over an RDD. Assumes the RDD is small enough + * to be collected at a single worker and used to fit a decision tree locally. + * Only used for testing. + */ + private[impl] def train( + input: RDD[LabeledPoint], + strategy: OldStrategy, + seed: Long, + parentUID: Option[String] = None): OptimizedDecisionTreeModel = { + + // Validate input data + require(input.count() > 0, "Local decision tree training requires > 0 training examples.") + val numFeatures = input.first().features.size + require(numFeatures > 0, "Local decision tree training requires > 0 features.") + + // Construct metadata, find splits + val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) + val splits = RandomForest.findSplits(input, metadata, seed) + + // Bin feature values (convert to TreePoint representation). + val treeInput = TreePoint.convertToTreeRDD(input, splits, metadata).collect() + val instanceWeights = Array.fill[Double](treeInput.length)(1.0) + + // Create tree root node + val initialRoot = OptimizedLearningNode.emptyNode(nodeIndex = 1) + // Fit tree + val rootNode = (new LocalDecisionTree).fitNode(treeInput, instanceWeights, + initialRoot, metadata, splits) + finalizeTree(rootNode, strategy.algo, strategy.numClasses, numFeatures, parentUID) + } + + /** + * Returns an array of continuous splits for the feature with index featureIndex and the passed-in + * set of values. Creates one continuous split per value in values. + */ + private[impl] def getContinuousSplits( + values: Array[Int], + featureIndex: Int): Array[Split] = { + val splits = values.sorted.map { + new ContinuousSplit(featureIndex, _).asInstanceOf[Split] + } + splits + } +} diff --git a/src/test/scala/org/apache/spark/ml/tree/impl/LocalTreeUnitSuite.scala b/src/test/scala/org/apache/spark/ml/tree/impl/LocalTreeUnitSuite.scala new file mode 100755 index 0000000..c7a0a3b --- /dev/null +++ b/src/test/scala/org/apache/spark/ml/tree/impl/LocalTreeUnitSuite.scala @@ -0,0 +1,109 @@ +/* + * Modified work Copyright (C) 2019 Cisco Systems + * + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree.impl + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.linalg.Vectors +import org.apache.spark.ml.tree._ +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.mllib.util.MLlibTestSparkContext + +/** Unit tests for helper classes/methods specific to local tree training */ +class LocalTreeUnitSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + + test("Fit a single decision tree regressor on constant features") { + // Generate constant, continuous data + val data = sc.parallelize(Range(0, 8).map(_ => LabeledPoint(1, Vectors.dense(1)))) + val df = spark.sqlContext.createDataFrame(data) + // Initialize estimator + val dt = new LocalDecisionTreeRegressor() + .setFeaturesCol("features") // indexedFeatures + .setLabelCol("label") + .setMaxDepth(3) + // Fit model + val model = dt.fit(df) + assert(model.rootNode.isInstanceOf[OptimizedLeafNode]) + val root = model.rootNode.asInstanceOf[OptimizedLeafNode] + assert(root.prediction == 1) + } + + test("Fit a single decision tree regressor on some continuous features") { + // Generate continuous data + val data = sc.parallelize(Range(0, 8).map(x => LabeledPoint(x, Vectors.dense(x)))) + val df = spark.createDataFrame(data) + // Initialize estimator + val dt = new LocalDecisionTreeRegressor() + .setFeaturesCol("features") // indexedFeatures + .setLabelCol("label") + .setMaxDepth(3) + // Fit model + val model = dt.fit(df) + + // Check that model is of depth 3 (the specified max depth) and that leaf/internal nodes have + // the correct class. + // Validate root + assert(model.rootNode.isInstanceOf[OptimizedInternalNode]) + // Validate first level of tree (nodes with depth = 1) + val root = model.rootNode.asInstanceOf[OptimizedInternalNode] + assert(root.leftChild.isInstanceOf[OptimizedInternalNode] && root.rightChild.isInstanceOf[OptimizedInternalNode]) + // Validate second and third levels of tree (nodes with depth = 2 or 3) + val left = root.leftChild.asInstanceOf[OptimizedInternalNode] + val right = root.rightChild.asInstanceOf[OptimizedInternalNode] + val grandkids = Array(left.leftChild, left.rightChild, right.leftChild, right.rightChild) + grandkids.foreach { grandkid => + assert(grandkid.isInstanceOf[OptimizedInternalNode]) + val grandkidNode = grandkid.asInstanceOf[OptimizedInternalNode] + assert(grandkidNode.leftChild.isInstanceOf[OptimizedLeafNode]) + assert(grandkidNode.rightChild.isInstanceOf[OptimizedLeafNode]) + } + } + + test("Fit deep local trees") { + + /** + * Deep tree test. Tries to fit tree on synthetic data designed to force tree + * to split to specified depth. + */ + def deepTreeTest(depth: Int): Unit = { + val deepTreeData = OptimizedTreeTests.deepTreeData(sc, depth) + val df = spark.createDataFrame(deepTreeData) + // Construct estimators; single-tree random forest & decision tree regressor. + val localTree = new LocalDecisionTreeRegressor() + .setFeaturesCol("features") // indexedFeatures + .setLabelCol("label") + .setMaxDepth(depth) + .setMinInfoGain(0.0) + + // Fit model, check depth... + val localModel = localTree.fit(df) + assert(localModel.rootNode.subtreeDepth == depth) + } + + // Test small depth tree + deepTreeTest(10) + // Test medium depth tree + deepTreeTest(40) + // Test high depth tree + deepTreeTest(200) + } + +} diff --git a/src/test/scala/org/apache/spark/ml/tree/impl/LocalTreeUtilsSuite.scala b/src/test/scala/org/apache/spark/ml/tree/impl/LocalTreeUtilsSuite.scala new file mode 100755 index 0000000..bc79432 --- /dev/null +++ b/src/test/scala/org/apache/spark/ml/tree/impl/LocalTreeUtilsSuite.scala @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree.impl + +import org.apache.spark.SparkFunSuite + +/** Unit tests for helper classes/methods specific to local tree training */ +class LocalTreeUtilsSuite extends SparkFunSuite { + + test("rowToColumnStoreDense: transforms row-major data into a column-major representation") { + // Attempt to transform an empty training dataset + intercept[IllegalArgumentException] { + LocalDecisionTreeUtils.rowToColumnStoreDense(Array.empty) + } + + // Transform a training dataset consisting of a single row + { + val rowLength = 10 + val data = Array(0.until(rowLength).toArray) + val transposed = LocalDecisionTreeUtils.rowToColumnStoreDense(data) + assert(transposed.length == rowLength, + s"Column-major representation of $rowLength-element row " + + s"contained ${transposed.length} elements") + transposed.foreach { col => + assert(col.length == 1, s"Column-major representation of a single row " + + s"contained column of length ${col.length}, expected length: 1") + } + } + + // Transform a dataset consisting of a single column + { + val colSize = 10 + val data = Array.tabulate[Array[Int]](colSize)(Array(_)) + val transposed = LocalDecisionTreeUtils.rowToColumnStoreDense(data) + assert(transposed.length > 0, s"Column-major representation of $colSize-element column " + + s"was empty.") + assert(transposed.length == 1, s"Column-major representation of $colSize-element column " + + s"should be a single array but was ${transposed.length} arrays.") + assert(transposed(0).length == colSize, + s"Column-major representation of $colSize-element column contained " + + s"${transposed(0).length} elements") + } + + // Transform a 2x3 (non-square) dataset + { + val data = Array(Array(0, 1, 2), Array(3, 4, 5)) + val expected = Array(Array(0, 3), Array(1, 4), Array(2, 5)) + val transposed = LocalDecisionTreeUtils.rowToColumnStoreDense(data) + transposed.zip(expected).foreach { case (resultCol, expectedCol) => + assert(resultCol.sameElements(expectedCol), s"Result column" + + s"${resultCol.mkString(", ")} differed from expected col ${expectedCol.mkString(", ")}") + } + } + } + + test("transposeSelectedFeatures") { + { + val data = Array(Array(1, 2, 3), Array(4, 5, 6), Array(7, 8, 9)) + val selected = Array(1) + + val transposed = LocalDecisionTreeUtils.transposeSelectedFeatures(data, selected) + val expected = Array(Array(2, 5, 8)) + transposed.zip(expected).foreach { case (resultCol, expectedCol) => + assert(resultCol.sameElements(expectedCol), s"Result column" + + s"${resultCol.mkString(", ")} differed from expected col ${expectedCol.mkString(", ")}") + } + } + + { + val data = Array(Array(1, 2, 3), Array(4, 5, 6), Array(7, 8, 9)) + val selected = Array(0, 2) + + val transposed = LocalDecisionTreeUtils.transposeSelectedFeatures(data, selected) + val expected = Array(Array(1, 4, 7), Array(3, 6, 9)) + transposed.zip(expected).foreach { case (resultCol, expectedCol) => + assert(resultCol.sameElements(expectedCol), s"Result column" + + s"${resultCol.mkString(", ")} differed from expected col ${expectedCol.mkString(", ")}") + } + } + } +} diff --git a/src/test/scala/org/apache/spark/ml/tree/impl/OptimizedDecisionTreeIntegrationSuite.scala b/src/test/scala/org/apache/spark/ml/tree/impl/OptimizedDecisionTreeIntegrationSuite.scala new file mode 100755 index 0000000..df92a8e --- /dev/null +++ b/src/test/scala/org/apache/spark/ml/tree/impl/OptimizedDecisionTreeIntegrationSuite.scala @@ -0,0 +1,128 @@ +/* + * Modified work Copyright (C) 2019 Cisco Systems + * + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree.impl + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.Estimator +import org.apache.spark.ml.classification.{DecisionTreeClassifier, OptimizedDecisionTreeClassifier} +import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.linalg.Vectors +import org.apache.spark.ml.regression.{DecisionTreeRegressor, OptimizedDecisionTreeRegressor} +import org.apache.spark.mllib.tree.DecisionTreeSuite +import org.apache.spark.mllib.util.{LogisticRegressionDataGenerator, MLlibTestSparkContext} +import org.apache.spark.sql.DataFrame + + +/** Tests checking equivalence of trees produced by local and distributed tree training. */ +class OptimizedDecisionTreeIntegrationSuite extends SparkFunSuite with MLlibTestSparkContext { + + val medDepthTreeSettings = OptimizedTreeTests.allParamSettings ++ Map[String, Any]("maxDepth" -> 4) + + /** + * For each (paramName, paramVal) pair in the passed-in map, set the corresponding + * parameter of the passed-in estimator & return the estimator. + */ + private def setParams[E <: Estimator[_]](estimator: E, params: Map[String, Any]): E = { + params.foreach { case (p, v) => + estimator.set(estimator.getParam(p), v) + } + estimator + } + + /** + * Verifies that local tree training & distributed training produce the same tree + * when fit on the same dataset with the same set of params. + */ + private def testEquivalence(train: DataFrame, testParams: Map[String, Any]): Unit = { + val oldTree = setParams(new DecisionTreeRegressor(), testParams) + val newTree = setParams(new OptimizedDecisionTreeRegressor(), testParams) + val newModel = newTree.fit(train) + val oldModel = oldTree.fit(train) + OptimizedTreeTests.checkEqual(oldModel, newModel) + } + + private def testClassifierEquivalence(train: DataFrame, testParams: Map[String, Any]): Unit = { + val oldTree = setParams(new DecisionTreeClassifier(), testParams) + val newTree = setParams(new OptimizedDecisionTreeClassifier(), testParams) + val newModel = newTree.fit(train) + val model = oldTree.fit(train) + OptimizedTreeTests.checkEqual(model, newModel) + } + + test("Local & distributed training produce the same tree on a toy dataset") { + val data = sc.parallelize(Range(0, 8).map(x => LabeledPoint(x, Vectors.dense(x)))) + val df = spark.createDataFrame(data) + testEquivalence(df, OptimizedTreeTests.allParamSettings) + testClassifierEquivalence(df, OptimizedTreeTests.allParamSettings) + } + + test("Local & distributed training produce the same tree with two feature values") { + val data = sc.parallelize(Range(0, 8).map(x => { + if (x > 3) { + LabeledPoint(x, Vectors.dense(0.0)) + } else { + LabeledPoint(x, Vectors.dense(1.0)) + }})) + val df = spark.createDataFrame(data) + testEquivalence(df, OptimizedTreeTests.allParamSettings) + testClassifierEquivalence(df, OptimizedTreeTests.allParamSettings) + } + + test("Local & distributed training produce the same tree on a slightly larger toy dataset") { + val data = sc.parallelize(Range(0, 10).map(x => LabeledPoint(x, Vectors.dense(x)))) + val df = spark.createDataFrame(data) + testEquivalence(df, medDepthTreeSettings) + } + + test("Local & distributed training produce the same tree on a larger toy dataset") { + val data = sc.parallelize(Range(0, 64).map(x => LabeledPoint(x, Vectors.dense(x)))) + val df = spark.createDataFrame(data) + testEquivalence(df, medDepthTreeSettings) + } + + test("Local & distributed training produce same tree on a dataset of categorical features") { + val data = sc.parallelize(DecisionTreeSuite.generateCategoricalDataPoints().map(_.asML)) + // Create a map of categorical feature index to arity; each feature has arity nclasses + val featuresMap: Map[Int, Int] = Map(0 -> 3, 1 -> 3) + // Convert the data RDD to a DataFrame with metadata indicating the arity of each of its + // categorical features + val df = OptimizedTreeTests.setMetadata(data, featuresMap, numClasses = 2) + testEquivalence(df, OptimizedTreeTests.allParamSettings) + } + + test("Local & distributed training produce the same tree on a dataset of continuous features") { + val sqlContext = spark.sqlContext + import sqlContext.implicits._ + // Use maxDepth = 5 and default params + val params = medDepthTreeSettings + val data = LogisticRegressionDataGenerator.generateLogisticRDD(spark.sparkContext, + nexamples = 1000, nfeatures = 5, eps = 2.0, nparts = 1, probOne = 0.2) + .map(_.asML).toDF().cache() + testEquivalence(data, params) + } + + test("Local & distributed training produce the same tree on a dataset of constant features") { + // Generate constant, continuous data + val data = sc.parallelize(Range(0, 8).map(_ => LabeledPoint(1, Vectors.dense(1)))) + val df = spark.createDataFrame(data) + testEquivalence(df, OptimizedTreeTests.allParamSettings) + } + +} diff --git a/src/test/scala/org/apache/spark/ml/tree/impl/OptimizedRandomForestSuite.scala b/src/test/scala/org/apache/spark/ml/tree/impl/OptimizedRandomForestSuite.scala new file mode 100755 index 0000000..ad1b717 --- /dev/null +++ b/src/test/scala/org/apache/spark/ml/tree/impl/OptimizedRandomForestSuite.scala @@ -0,0 +1,663 @@ +/* + * Modified work Copyright (C) 2019 Cisco Systems + * + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree.impl + +import scala.annotation.tailrec +import scala.collection.mutable + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.classification.DecisionTreeClassificationModel +import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.linalg.{Vector, Vectors} +import org.apache.spark.ml.tree._ +import org.apache.spark.ml.util.TestingUtils._ +import org.apache.spark.mllib.tree.{DecisionTreeSuite => OldDTSuite, EnsembleTestHelper} +import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, QuantileStrategy, OptimizedForestStrategy => OldStrategy} +import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, GiniCalculator, Variance} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.util.collection.OpenHashMap + +/** + * Test suite for [[RandomForest]]. + */ +class OptimizedRandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { + + import OptimizedRandomForestSuite.mapToVec + + private val seed = 42 + + ///////////////////////////////////////////////////////////////////////////// + // Tests for split calculation + ///////////////////////////////////////////////////////////////////////////// + + test("Binary classification with continuous features: split calculation") { + val arr = OldDTSuite.generateOrderedLabeledPointsWithLabel1().map(_.asML) + assert(arr.length === 1000) + val rdd = sc.parallelize(arr) + val strategy = new OldStrategy(OldAlgo.Classification, Gini, 3, 2, 100) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + assert(!metadata.isUnordered(featureIndex = 0)) + val splits = OptimizedRandomForest.findSplits(rdd, metadata, seed = 42) + assert(splits.length === 2) + assert(splits(0).length === 99) + } + + test("Binary classification with binary (ordered) categorical features: split calculation") { + val arr = OldDTSuite.generateCategoricalDataPoints().map(_.asML) + assert(arr.length === 1000) + val rdd = sc.parallelize(arr) + val strategy = new OldStrategy(OldAlgo.Classification, Gini, maxDepth = 2, numClasses = 2, + maxBins = 100, categoricalFeaturesInfo = Map(0 -> 2, 1 -> 2)) + + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + val splits = OptimizedRandomForest.findSplits(rdd, metadata, seed = 42) + assert(!metadata.isUnordered(featureIndex = 0)) + assert(!metadata.isUnordered(featureIndex = 1)) + assert(splits.length === 2) + // no splits pre-computed for ordered categorical features + assert(splits(0).length === 0) + } + + test("Binary classification with 3-ary (ordered) categorical features," + + " with no samples for one category: split calculation") { + val arr = OldDTSuite.generateCategoricalDataPoints().map(_.asML) + assert(arr.length === 1000) + val rdd = sc.parallelize(arr) + val strategy = new OldStrategy(OldAlgo.Classification, Gini, maxDepth = 2, numClasses = 2, + maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) + + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + assert(!metadata.isUnordered(featureIndex = 0)) + assert(!metadata.isUnordered(featureIndex = 1)) + val splits = OptimizedRandomForest.findSplits(rdd, metadata, seed = 42) + assert(splits.length === 2) + // no splits pre-computed for ordered categorical features + assert(splits(0).length === 0) + } + + test("find splits for a continuous feature") { + // find splits for normal case + { + val fakeMetadata = new DecisionTreeMetadata(1, 200000, 0, 0, + Map(), Set(), + Array(6), Gini, QuantileStrategy.Sort, + 0, 0, 0.0, 0, 0 + ) + val featureSamples = Array.fill(10000)(math.random).filter(_ != 0.0) + val splits = OptimizedRandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) + assert(splits.length === 5) + assert(fakeMetadata.numSplits(0) === 5) + assert(fakeMetadata.numBins(0) === 6) + // check returned splits are distinct + assert(splits.distinct.length === splits.length) + } + + // SPARK-16957: Use midpoints for split values. + { + val fakeMetadata = new DecisionTreeMetadata(1, 8, 0, 0, + Map(), Set(), + Array(3), Gini, QuantileStrategy.Sort, + 0, 0, 0.0, 0, 0 + ) + + // TODO: Why doesn't this work after filtering 0.0? + // possibleSplits <= numSplits + { + val featureSamples = Array(0, 1, 0, 0, 1, 0, 1, 1).map(_.toDouble) + val splits = OptimizedRandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) + val expectedSplits = Array((0.0 + 1.0) / 2) + assert(splits === expectedSplits) + } + + // possibleSplits > numSplits + { + val featureSamples = Array(0, 0, 1, 1, 2, 2, 3, 3).map(_.toDouble) + val splits = OptimizedRandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) + val expectedSplits = Array((0.0 + 1.0) / 2, (2.0 + 3.0) / 2) + assert(splits === expectedSplits) + } + } + + // find splits should not return identical splits + // when there are not enough split candidates, reduce the number of splits in metadata + { + val fakeMetadata = new DecisionTreeMetadata(1, 12, 0, 0, + Map(), Set(), + Array(5), Gini, QuantileStrategy.Sort, + 0, 0, 0.0, 0, 0 + ) + val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3).map(_.toDouble) + val splits = OptimizedRandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) + val expectedSplits = Array((1.0 + 2.0) / 2, (2.0 + 3.0) / 2) + assert(splits === expectedSplits) + // check returned splits are distinct + assert(splits.distinct.length === splits.length) + } + + // find splits when most samples close to the minimum + { + val fakeMetadata = new DecisionTreeMetadata(1, 18, 0, 0, + Map(), Set(), + Array(3), Gini, QuantileStrategy.Sort, + 0, 0, 0.0, 0, 0 + ) + val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5) + .map(_.toDouble) + val splits = OptimizedRandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) + val expectedSplits = Array((2.0 + 3.0) / 2, (3.0 + 4.0) / 2) + assert(splits === expectedSplits) + } + + // find splits when most samples close to the maximum + { + val fakeMetadata = new DecisionTreeMetadata(1, 17, 0, 0, + Map(), Set(), + Array(2), Gini, QuantileStrategy.Sort, + 0, 0, 0.0, 0, 0 + ) + val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2) + .map(_.toDouble).filter(_ != 0.0) + val splits = OptimizedRandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) + val expectedSplits = Array((1.0 + 2.0) / 2) + assert(splits === expectedSplits) + } + + // find splits for constant feature + { + val fakeMetadata = new DecisionTreeMetadata(1, 3, 0, 0, + Map(), Set(), + Array(3), Gini, QuantileStrategy.Sort, + 0, 0, 0.0, 0, 0 + ) + val featureSamples = Array(0, 0, 0).map(_.toDouble).filter(_ != 0.0) + val featureSamplesEmpty = Array.empty[Double] + val splits = OptimizedRandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) + assert(splits === Array.empty[Double]) + val splitsEmpty = + OptimizedRandomForest.findSplitsForContinuousFeature(featureSamplesEmpty, fakeMetadata, 0) + assert(splitsEmpty === Array.empty[Double]) + } + } + + test("train with empty arrays") { + val lp = LabeledPoint(1.0, Vectors.dense(Array.empty[Double])) + val data = Array.fill(5)(lp) + val rdd = sc.parallelize(data) + + val strategy = new OldStrategy(OldAlgo.Regression, Gini, maxDepth = 2, + maxBins = 5) + withClue("DecisionTree requires number of features > 0," + + " but was given an empty features vector") { + intercept[IllegalArgumentException] { + OptimizedRandomForest.run(rdd, strategy, 1, "all", 42L, instr = None)._1 + } + } + } + + test("train with constant features") { + val lp = LabeledPoint(1.0, Vectors.dense(0.0, 0.0, 0.0)) + val data = Array.fill(5)(lp) + val rdd = sc.parallelize(data) + val strategy = new OldStrategy( + OldAlgo.Classification, + Gini, + maxDepth = 2, + numClasses = 2, + maxBins = 5, + categoricalFeaturesInfo = Map(0 -> 1, 1 -> 5)) + val Array(tree) = OptimizedRandomForest.run(rdd, strategy, 1, "all", 42L, instr = None)._1 + assert(tree.rootNode.impurity === -1.0) + assert(tree.depth === 0) + assert(tree.rootNode.prediction === lp.label) + + // Test with no categorical features + val strategy2 = new OldStrategy( + OldAlgo.Regression, + Variance, + maxDepth = 2, + maxBins = 5) + val Array(tree2) = OptimizedRandomForest.run(rdd, strategy2, 1, "all", 42L, instr = None)._1 + assert(tree2.rootNode.impurity === -1.0) + assert(tree2.depth === 0) + assert(tree2.rootNode.prediction === lp.label) + } + + test("Multiclass classification with unordered categorical features: split calculations") { + val arr = OldDTSuite.generateCategoricalDataPoints().map(_.asML) + assert(arr.length === 1000) + val rdd = sc.parallelize(arr) + val strategy = new OldStrategy( + OldAlgo.Classification, + Gini, + maxDepth = 2, + numClasses = 100, + maxBins = 100, + categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) + + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + assert(metadata.isUnordered(featureIndex = 0)) + assert(metadata.isUnordered(featureIndex = 1)) + val splits = OptimizedRandomForest.findSplits(rdd, metadata, seed = 42) + assert(splits.length === 2) + assert(splits(0).length === 3) + assert(metadata.numSplits(0) === 3) + assert(metadata.numBins(0) === 3) + assert(metadata.numSplits(1) === 3) + assert(metadata.numBins(1) === 3) + + // Expecting 2^2 - 1 = 3 splits per feature + def checkCategoricalSplit(s: Split, featureIndex: Int, leftCategories: Array[Double]): Unit = { + assert(s.featureIndex === featureIndex) + assert(s.isInstanceOf[CategoricalSplit]) + val s0 = s.asInstanceOf[CategoricalSplit] + assert(s0.leftCategories === leftCategories) + assert(s0.numCategories === 3) // for this unit test + } + // Feature 0 + checkCategoricalSplit(splits(0)(0), 0, Array(0.0)) + checkCategoricalSplit(splits(0)(1), 0, Array(1.0)) + checkCategoricalSplit(splits(0)(2), 0, Array(0.0, 1.0)) + // Feature 1 + checkCategoricalSplit(splits(1)(0), 1, Array(0.0)) + checkCategoricalSplit(splits(1)(1), 1, Array(1.0)) + checkCategoricalSplit(splits(1)(2), 1, Array(0.0, 1.0)) + } + + test("Multiclass classification with ordered categorical features: split calculations") { + val arr = OldDTSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures().map(_.asML) + assert(arr.length === 3000) + val rdd = sc.parallelize(arr) + val strategy = new OldStrategy(OldAlgo.Classification, Gini, maxDepth = 2, numClasses = 100, + maxBins = 100, categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10)) + // 2^(10-1) - 1 > 100, so categorical features will be ordered + + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + assert(!metadata.isUnordered(featureIndex = 0)) + assert(!metadata.isUnordered(featureIndex = 1)) + val splits = OptimizedRandomForest.findSplits(rdd, metadata, seed = 42) + assert(splits.length === 2) + // no splits pre-computed for ordered categorical features + assert(splits(0).length === 0) + } + + ///////////////////////////////////////////////////////////////////////////// + // Tests of other algorithm internals + ///////////////////////////////////////////////////////////////////////////// + + test("extract categories from a number for multiclass classification") { + val l = OptimizedRandomForest.extractMultiClassCategories(13, 10) + assert(l.length === 3) + assert(Seq(3.0, 2.0, 0.0) === l) + } + + test("Avoid aggregation on the last level") { + val arr = Array( + LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0)), + LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)), + LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)), + LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0))) + val input = sc.parallelize(arr) + + val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 1, + numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3)) + val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) + val splits = OptimizedRandomForest.findSplits(input, metadata, seed = 42) + + val treeInput = TreePoint.convertToTreeRDD(input, splits, metadata) + val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, withReplacement = false) + + val topNode = OptimizedLearningNode.emptyNode(nodeIndex = 1) + assert(topNode.isLeaf === false) + assert(topNode.stats === null) + + val nodesForGroup = Map(0 -> Array(topNode)) + val treeToNodeToIndexInfo = Map(0 -> Map( + topNode.id -> new OptimizedRandomForest.NodeIndexInfo(0, None) + )) + val nodeStack = new mutable.ArrayStack[(Int, OptimizedLearningNode)] + val localTrainingStack = new mutable.ListBuffer[LocalTrainingTask] + val maxMemoryUsage = 100 * 1024L * 1024L + val maxMemoryMultiplier = 4.0 + OptimizedRandomForest.findBestSplits(baggedInput, metadata, Map(0 -> topNode), + nodesForGroup, treeToNodeToIndexInfo, splits, (nodeStack, localTrainingStack), TrainingLimits(100000, metadata.maxDepth)) + + // don't enqueue leaf nodes into node queue + assert(nodeStack.isEmpty) + + // set impurity and predict for topNode + assert(topNode.stats !== null) + assert(topNode.stats.impurity > 0.0) + + // set impurity and predict for child nodes + assert(topNode.leftChild.get.toNode.prediction === 0.0) + assert(topNode.rightChild.get.toNode.prediction === 1.0) + assert(topNode.leftChild.get.stats.impurity === 0.0) + assert(topNode.rightChild.get.stats.impurity === 0.0) + } + + test("Avoid aggregation if impurity is 0.0") { + val arr = Array( + LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0)), + LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)), + LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)), + LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0))) + val input = sc.parallelize(arr) + + val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 5, + numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3)) + val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) + val splits = OptimizedRandomForest.findSplits(input, metadata, seed = 42) + + val treeInput = TreePoint.convertToTreeRDD(input, splits, metadata) + val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, withReplacement = false) + + val topNode = OptimizedLearningNode.emptyNode(nodeIndex = 1) + assert(topNode.isLeaf === false) + assert(topNode.stats === null) + + val nodesForGroup = Map(0 -> Array(topNode)) + val treeToNodeToIndexInfo = Map(0 -> Map( + topNode.id -> new OptimizedRandomForest.NodeIndexInfo(0, None) + )) + val nodeStack = new mutable.ArrayStack[(Int, OptimizedLearningNode)] + val localTrainingStack = new mutable.ListBuffer[LocalTrainingTask] + val maxMemoryUsage = 100 * 1024L * 1024L + val maxMemoryMultiplier = 4.0 + OptimizedRandomForest.findBestSplits(baggedInput, metadata, Map(0 -> topNode), + nodesForGroup, treeToNodeToIndexInfo, splits, (nodeStack, localTrainingStack), TrainingLimits(100000, metadata.maxDepth)) + + // don't enqueue a node into node queue if its impurity is 0.0 + assert(nodeStack.isEmpty) + + // set impurity and predict for topNode + assert(topNode.stats !== null) + assert(topNode.stats.impurity > 0.0) + + // set impurity and predict for child nodes + assert(topNode.leftChild.get.toNode.prediction === 0.0) + assert(topNode.rightChild.get.toNode.prediction === 1.0) + assert(topNode.leftChild.get.stats.impurity === 0.0) + assert(topNode.rightChild.get.stats.impurity === 0.0) + } + + test("Use soft prediction for binary classification with ordered categorical features") { + // The following dataset is set up such that the best split is {1} vs. {0, 2}. + // If the hard prediction is used to order the categories, then {0} vs. {1, 2} is chosen. + val arr = Array( + LabeledPoint(0.0, Vectors.dense(0.0)), + LabeledPoint(0.0, Vectors.dense(0.0)), + LabeledPoint(0.0, Vectors.dense(0.0)), + LabeledPoint(1.0, Vectors.dense(0.0)), + LabeledPoint(0.0, Vectors.dense(1.0)), + LabeledPoint(0.0, Vectors.dense(1.0)), + LabeledPoint(0.0, Vectors.dense(1.0)), + LabeledPoint(0.0, Vectors.dense(1.0)), + LabeledPoint(0.0, Vectors.dense(2.0)), + LabeledPoint(0.0, Vectors.dense(2.0)), + LabeledPoint(0.0, Vectors.dense(2.0)), + LabeledPoint(1.0, Vectors.dense(2.0))) + val input = sc.parallelize(arr) + + // Must set maxBins s.t. the feature will be treated as an ordered categorical feature. + val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 1, + numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3), maxBins = 3) + + val model = OptimizedRandomForest.run(input, strategy, numTrees = 1, featureSubsetStrategy = "all", + seed = 42, instr = None, prune = false)._1.head + + model.rootNode match { + case n: OptimizedInternalNode => n.split match { + case s: CategoricalSplit => + assert(s.leftCategories === Array(1.0)) + case _ => fail("model.rootNode.split was not a CategoricalSplit") + } + case _ => fail("model.rootNode was not an InternalNode") + } + } + + test("Second level node building with vs. without groups") { + val arr = OldDTSuite.generateOrderedLabeledPoints().map(_.asML) + assert(arr.length === 1000) + val rdd = sc.parallelize(arr) + // For tree with 1 group + val strategy1 = + new OldStrategy(OldAlgo.Classification, Entropy, 3, 2, 100, maxMemoryInMB = 1000) + // For tree with multiple groups + val strategy2 = + new OldStrategy(OldAlgo.Classification, Entropy, 3, 2, 100, maxMemoryInMB = 0) + + val tree1 = OptimizedRandomForest.run(rdd, strategy1, numTrees = 1, featureSubsetStrategy = "all", + seed = 42, instr = None)._1.head + val tree2 = OptimizedRandomForest.run(rdd, strategy2, numTrees = 1, featureSubsetStrategy = "all", + seed = 42, instr = None)._1.head + + def getChildren(rootNode: OptimizedNode): Array[OptimizedInternalNode] = rootNode match { + case n: OptimizedInternalNode => + assert(n.leftChild.isInstanceOf[OptimizedInternalNode]) + assert(n.rightChild.isInstanceOf[OptimizedInternalNode]) + Array(n.leftChild.asInstanceOf[OptimizedInternalNode], n.rightChild.asInstanceOf[OptimizedInternalNode]) + case _ => fail("rootNode was not an InternalNode") + } + + // Single group second level tree construction. + val children1 = getChildren(tree1.rootNode) + val children2 = getChildren(tree2.rootNode) + + // Verify whether the splits obtained using single group and multiple group level + // construction strategies are the same. + for (i <- 0 until 2) { + assert(children1(i).gain > 0) + assert(children2(i).gain > 0) + assert(children1(i).split === children2(i).split) + assert(children1(i).impurity === children2(i).impurity) + assert(children1(i).leftChild.impurity === children2(i).leftChild.impurity) + assert(children1(i).rightChild.impurity === children2(i).rightChild.impurity) + assert(children1(i).prediction === children2(i).prediction) + } + } + + def binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy: OldStrategy) { + val numFeatures = 50 + val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures, 1000) + val rdd = sc.parallelize(arr).map(_.asML) + + // Select feature subset for top nodes. Return true if OK. + def checkFeatureSubsetStrategy( + numTrees: Int, + featureSubsetStrategy: String, + numFeaturesPerNode: Int): Unit = { + val seeds = Array(123, 5354, 230, 349867, 23987) + val maxMemoryUsage: Long = 128 * 1024L * 1024L + val metadata = + DecisionTreeMetadata.buildMetadata(rdd, strategy, numTrees, featureSubsetStrategy) + seeds.foreach { seed => + val failString = s"Failed on test with:" + + s"numTrees=$numTrees, featureSubsetStrategy=$featureSubsetStrategy," + + s" numFeaturesPerNode=$numFeaturesPerNode, seed=$seed" + val nodeStack = new mutable.ArrayStack[(Int, OptimizedLearningNode)] + val topNodes: Array[OptimizedLearningNode] = new Array[OptimizedLearningNode](numTrees) + Range(0, numTrees).foreach { treeIndex => + topNodes(treeIndex) = OptimizedLearningNode.emptyNode(nodeIndex = 1) + nodeStack.push((treeIndex, topNodes(treeIndex))) + } + val rng = new scala.util.Random(seed = seed) + val (nodesForGroup: Map[Int, Array[OptimizedLearningNode]], + treeToNodeToIndexInfo: Map[Int, Map[Int, OptimizedRandomForest.NodeIndexInfo]]) = + OptimizedRandomForest.selectNodesToSplit(nodeStack, maxMemoryUsage, metadata, rng) + + assert(nodesForGroup.size === numTrees, failString) + assert(nodesForGroup.values.forall(_.length == 1), failString) // 1 node per tree + + if (numFeaturesPerNode == numFeatures) { + // featureSubset values should all be None + assert(treeToNodeToIndexInfo.values.forall(_.values.forall(_.featureSubset.isEmpty)), + failString) + } else { + // Check number of features. + assert(treeToNodeToIndexInfo.values.forall(_.values.forall( + _.featureSubset.get.length === numFeaturesPerNode)), failString) + } + } + } + + checkFeatureSubsetStrategy(numTrees = 1, "auto", numFeatures) + checkFeatureSubsetStrategy(numTrees = 1, "all", numFeatures) + checkFeatureSubsetStrategy(numTrees = 1, "sqrt", math.sqrt(numFeatures).ceil.toInt) + checkFeatureSubsetStrategy(numTrees = 1, "log2", + (math.log(numFeatures) / math.log(2)).ceil.toInt) + checkFeatureSubsetStrategy(numTrees = 1, "onethird", (numFeatures / 3.0).ceil.toInt) + + val realStrategies = Array(".1", ".10", "0.10", "0.1", "0.9", "1.0") + for (strategy <- realStrategies) { + val expected = (strategy.toDouble * numFeatures).ceil.toInt + checkFeatureSubsetStrategy(numTrees = 1, strategy, expected) + } + + val integerStrategies = Array("1", "10", "100", "1000", "10000") + for (strategy <- integerStrategies) { + val expected = if (strategy.toInt < numFeatures) strategy.toInt else numFeatures + checkFeatureSubsetStrategy(numTrees = 1, strategy, expected) + } + + val invalidStrategies = Array("-.1", "-.10", "-0.10", ".0", "0.0", "1.1", "0") + for (invalidStrategy <- invalidStrategies) { + intercept[IllegalArgumentException] { + val metadata = + DecisionTreeMetadata.buildMetadata(rdd, strategy, numTrees = 1, invalidStrategy) + } + } + + checkFeatureSubsetStrategy(numTrees = 2, "all", numFeatures) + checkFeatureSubsetStrategy(numTrees = 2, "auto", math.sqrt(numFeatures).ceil.toInt) + checkFeatureSubsetStrategy(numTrees = 2, "sqrt", math.sqrt(numFeatures).ceil.toInt) + checkFeatureSubsetStrategy(numTrees = 2, "log2", + (math.log(numFeatures) / math.log(2)).ceil.toInt) + checkFeatureSubsetStrategy(numTrees = 2, "onethird", (numFeatures / 3.0).ceil.toInt) + + for (strategy <- realStrategies) { + val expected = (strategy.toDouble * numFeatures).ceil.toInt + checkFeatureSubsetStrategy(numTrees = 2, strategy, expected) + } + + for (strategy <- integerStrategies) { + val expected = if (strategy.toInt < numFeatures) strategy.toInt else numFeatures + checkFeatureSubsetStrategy(numTrees = 2, strategy, expected) + } + for (invalidStrategy <- invalidStrategies) { + intercept[IllegalArgumentException] { + val metadata = + DecisionTreeMetadata.buildMetadata(rdd, strategy, numTrees = 2, invalidStrategy) + } + } + } + + test("Binary classification with continuous features: subsampling features") { + val categoricalFeaturesInfo = Map.empty[Int, Int] + val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 2, + numClasses = 2, categoricalFeaturesInfo = categoricalFeaturesInfo) + binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy) + } + + test("Binary classification with continuous features and node Id cache: subsampling features") { + val categoricalFeaturesInfo = Map.empty[Int, Int] + val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 2, + numClasses = 2, categoricalFeaturesInfo = categoricalFeaturesInfo, + useNodeIdCache = true) + binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy) + } + + test("normalizeMapValues") { + val map = new OpenHashMap[Int, Double]() + map(0) = 1.0 + map(2) = 2.0 + TreeEnsembleModel.normalizeMapValues(map) + val expected = Map(0 -> 1.0 / 3.0, 2 -> 2.0 / 3.0) + assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01) + } + + /////////////////////////////////////////////////////////////////////////////// + // Tests for pruning of redundant subtrees (generated by a split improving the + // impurity measure, but always leading to the same prediction). + /////////////////////////////////////////////////////////////////////////////// + + test("SPARK-3159 tree model redundancy - classification") { + // The following dataset is set up such that splitting over feature_1 for points having + // feature_0 = 0 improves the impurity measure, despite the prediction will always be 0 + // in both branches. + val arr = Array( + LabeledPoint(0.0, Vectors.dense(0.0, 1.0)), + LabeledPoint(1.0, Vectors.dense(0.0, 1.0)), + LabeledPoint(0.0, Vectors.dense(0.0, 0.0)), + LabeledPoint(1.0, Vectors.dense(1.0, 0.0)), + LabeledPoint(0.0, Vectors.dense(1.0, 0.0)), + LabeledPoint(1.0, Vectors.dense(1.0, 1.0)) + ) + val rdd = sc.parallelize(arr) + + val numClasses = 2 + val strategy = new OldStrategy(algo = OldAlgo.Classification, impurity = Gini, maxDepth = 4, + numClasses = numClasses, maxBins = 32) + + val prunedTree = OptimizedRandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto", + seed = 42, instr = None)._1.head + + val unprunedTree = OptimizedRandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto", + seed = 42, instr = None, prune = false)._1.head + + assert(prunedTree.numNodes === 5) + assert(unprunedTree.numNodes === 7) + } + + test("SPARK-3159 tree model redundancy - regression") { + // The following dataset is set up such that splitting over feature_0 for points having + // feature_1 = 1 improves the impurity measure, despite the prediction will always be 0.5 + // in both branches. + val arr = Array( + LabeledPoint(0.0, Vectors.dense(0.0, 1.0)), + LabeledPoint(1.0, Vectors.dense(0.0, 1.0)), + LabeledPoint(0.0, Vectors.dense(0.0, 0.0)), + LabeledPoint(0.0, Vectors.dense(1.0, 0.0)), + LabeledPoint(1.0, Vectors.dense(1.0, 1.0)), + LabeledPoint(0.0, Vectors.dense(1.0, 1.0)), + LabeledPoint(0.5, Vectors.dense(1.0, 1.0)) + ) + val rdd = sc.parallelize(arr) + + val strategy = new OldStrategy(algo = OldAlgo.Regression, impurity = Variance, maxDepth = 4, + numClasses = 0, maxBins = 32) + + val prunedTree = OptimizedRandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto", + seed = 42, instr = None)._1.head + + val unprunedTree = OptimizedRandomForest.run(rdd, strategy, numTrees = 1, featureSubsetStrategy = "auto", + seed = 42, instr = None, prune = false)._1.head + + assert(prunedTree.numNodes === 3) + assert(unprunedTree.numNodes === 5) + } +} + +private object OptimizedRandomForestSuite { + def mapToVec(map: Map[Int, Double]): Vector = { + val size = (map.keys.toSeq :+ 0).max + 1 + val (indices, values) = map.toSeq.sortBy(_._1).unzip + Vectors.sparse(size, indices.toArray, values.toArray) + } +} diff --git a/src/test/scala/org/apache/spark/ml/tree/impl/OptimizedTreeTests.scala b/src/test/scala/org/apache/spark/ml/tree/impl/OptimizedTreeTests.scala new file mode 100755 index 0000000..1b30d95 --- /dev/null +++ b/src/test/scala/org/apache/spark/ml/tree/impl/OptimizedTreeTests.scala @@ -0,0 +1,319 @@ +/* + * Modified work Copyright (C) 2019 Cisco Systems + * + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree.impl + +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute} +import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, OptimizedDecisionTreeClassificationModel} +import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.linalg.Vectors +import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, OptimizedDecisionTreeRegressionModel} +import org.apache.spark.ml.tree._ +import org.apache.spark.mllib.tree.impurity.{Entropy, Impurity} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, SparkSession} +import org.apache.spark.{SparkContext, SparkFunSuite} + +import scala.collection.JavaConverters._ + +private[ml] object OptimizedTreeTests extends SparkFunSuite { + + /** + * Convert the given data to a DataFrame, and set the features and label metadata. + * @param data Dataset. Categorical features and labels must already have 0-based indices. + * This must be non-empty. + * @param categoricalFeatures Map: categorical feature index to number of distinct values + * @param numClasses Number of classes label can take. If 0, mark as continuous. + * @return DataFrame with metadata + */ + def setMetadata( + data: RDD[LabeledPoint], + categoricalFeatures: Map[Int, Int], + numClasses: Int): DataFrame = { + val spark = SparkSession.builder() + .sparkContext(data.sparkContext) + .getOrCreate() + import spark.implicits._ + + val df = data.toDF() + val numFeatures = data.first().features.size + val featuresAttributes = Range(0, numFeatures).map { feature => + if (categoricalFeatures.contains(feature)) { + NominalAttribute.defaultAttr.withIndex(feature).withNumValues(categoricalFeatures(feature)) + } else { + NumericAttribute.defaultAttr.withIndex(feature) + } + }.toArray + val featuresMetadata = new AttributeGroup("features", featuresAttributes).toMetadata() + val labelAttribute = if (numClasses == 0) { + NumericAttribute.defaultAttr.withName("label") + } else { + NominalAttribute.defaultAttr.withName("label").withNumValues(numClasses) + } + val labelMetadata = labelAttribute.toMetadata() + df.select(df("features").as("features", featuresMetadata), + df("label").as("label", labelMetadata)) + } + + /** + * Java-friendly version of `setMetadata()` + */ + def setMetadata( + data: JavaRDD[LabeledPoint], + categoricalFeatures: java.util.Map[java.lang.Integer, java.lang.Integer], + numClasses: Int): DataFrame = { + setMetadata(data.rdd, categoricalFeatures.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap, + numClasses) + } + + /** + * Set label metadata (particularly the number of classes) on a DataFrame. + * @param data Dataset. Categorical features and labels must already have 0-based indices. + * This must be non-empty. + * @param numClasses Number of classes label can take. If 0, mark as continuous. + * @param labelColName Name of the label column on which to set the metadata. + * @param featuresColName Name of the features column + * @return DataFrame with metadata + */ + def setMetadata( + data: DataFrame, + numClasses: Int, + labelColName: String, + featuresColName: String): DataFrame = { + val labelAttribute = if (numClasses == 0) { + NumericAttribute.defaultAttr.withName(labelColName) + } else { + NominalAttribute.defaultAttr.withName(labelColName).withNumValues(numClasses) + } + val labelMetadata = labelAttribute.toMetadata() + data.select(data(featuresColName), data(labelColName).as(labelColName, labelMetadata)) + } + + /** Returns a DecisionTreeMetadata instance with hard-coded values for use in tests */ + def getMetadata( + numExamples: Int, + numFeatures: Int, + numClasses: Int, + featureArity: Map[Int, Int], + impurity: Impurity = Entropy, + unorderedFeatures: Option[Set[Int]] = None): DecisionTreeMetadata = { + // By default, assume all categorical features within tests + // have small enough arity to be treated as unordered + val unordered = unorderedFeatures.getOrElse(featureArity.keys.toSet) + + // Set numBins appropriately for categorical features + val maxBins = 4 + val numBins: Array[Int] = 0.until(numFeatures).toArray.map { featureIndex => + if (featureArity.contains(featureIndex) && featureArity(featureIndex) > 0) { + featureArity(featureIndex) + } else { + maxBins + } + } + + new DecisionTreeMetadata(numFeatures = numFeatures, numExamples = numExamples, + numClasses = numClasses, maxBins = maxBins, minInfoGain = 0.0, featureArity = featureArity, + unorderedFeatures = unordered, numBins = numBins, impurity = impurity, + quantileStrategy = null, maxDepth = 5, minInstancesPerNode = 1, numTrees = 1, + numFeaturesPerNode = 2) + } + + /** + * Check if the two trees are exactly the same. + * Note: I hesitate to override Node.equals since it could cause problems if users + * make mistakes such as creating loops of Nodes. + * If the trees are not equal, this prints the two trees and throws an exception. + */ + def checkEqual(a: DecisionTreeModel, b: OptimizedDecisionTreeModel): Unit = { + try { + checkEqual(a.rootNode, b.rootNode) + } catch { + case ex: Exception => + throw new AssertionError("checkEqual failed since the two trees were not identical.\n" + + "TREE A:\n" + a.toDebugString + "\n" + + "TREE B:\n" + b.toDebugString + "\n", ex) + } + } + + def checkEqual(a: OptimizedDecisionTreeModel, b: OptimizedDecisionTreeModel): Unit = { + try { + checkEqual(a.rootNode, b.rootNode) + } catch { + case ex: Exception => + throw new AssertionError("checkEqual failed since the two trees were not identical.\n" + + "TREE A:\n" + a.toDebugString + "\n" + + "TREE B:\n" + b.toDebugString + "\n", ex) + } + } + + /** + * Return true iff the two nodes and their descendants are exactly the same. + * Note: I hesitate to override Node.equals since it could cause problems if users + * make mistakes such as creating loops of Nodes. + */ + private def checkEqual(a: Node, b: OptimizedNode): Unit = { + assert(a.prediction === b.prediction) + assert(a.impurity === b.impurity) +// assert(a.impurityStats.stats === b.impurityStats.stats) + (a, b) match { + case (aye: InternalNode, bee: OptimizedInternalNode) => + assert(aye.split === bee.split) + checkEqual(aye.leftChild, bee.leftChild) + checkEqual(aye.rightChild, bee.rightChild) + case (aye: LeafNode, bee: OptimizedLeafNode) => // do nothing + case _ => + println(a.getClass.getCanonicalName, b.getClass.getCanonicalName) + throw new AssertionError("Found mismatched nodes") + } + } + + /** + * Check if the two models are exactly the same. + * If the models are not equal, this throws an exception. + */ + private def checkEqual(a: OptimizedNode, b: OptimizedNode): Unit = { + assert(a.prediction === b.prediction) + assert(a.impurity === b.impurity) + // assert(a.impurityStats.stats === b.impurityStats.stats) + (a, b) match { + case (aye: OptimizedInternalNode, bee: OptimizedInternalNode) => + assert(aye.split === bee.split) + checkEqual(aye.leftChild, bee.leftChild) + checkEqual(aye.rightChild, bee.rightChild) + case (aye: OptimizedLeafNode, bee: OptimizedLeafNode) => // do nothing + case _ => + println(a.getClass.getCanonicalName, b.getClass.getCanonicalName) + throw new AssertionError("Found mismatched nodes") + } + } + + def checkEqualOldClassification(a: TreeEnsembleModel[DecisionTreeClassificationModel], b: OptimizedTreeEnsembleModel[OptimizedDecisionTreeClassificationModel]): Unit = { + try { + a.trees.zip(b.trees).foreach { case (treeA, treeB) => + OptimizedTreeTests.checkEqual(treeA, treeB) + } + assert(a.treeWeights === b.treeWeights) + } catch { + case ex: Exception => throw new AssertionError( + "checkEqual failed since the two tree ensembles were not identical") + } + } + + def checkEqualOldRegression(a: TreeEnsembleModel[DecisionTreeRegressionModel], b: OptimizedTreeEnsembleModel[OptimizedDecisionTreeRegressionModel]): Unit = { + try { + a.trees.zip(b.trees).foreach { case (treeA, treeB) => + OptimizedTreeTests.checkEqual(treeA, treeB) + } + assert(a.treeWeights === b.treeWeights) + } catch { + case ex: Exception => throw new AssertionError( + "checkEqual failed since the two tree ensembles were not identical") + } + } + + + def checkEqual[M <: OptimizedDecisionTreeModel](a: OptimizedTreeEnsembleModel[M], b: OptimizedTreeEnsembleModel[M]): Unit = { + try { + a.trees.zip(b.trees).foreach { case (treeA, treeB) => + OptimizedTreeTests.checkEqual(treeA, treeB) + } + assert(a.treeWeights === b.treeWeights) + } catch { + case ex: Exception => throw new AssertionError( + "checkEqual failed since the two tree ensembles were not identical") + } + } + + + /** + * Create some toy data for testing feature importances. + */ + def featureImportanceData(sc: SparkContext): RDD[LabeledPoint] = sc.parallelize(Seq( + new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 1)), + new LabeledPoint(1, Vectors.dense(1, 1, 0, 1, 0)), + new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0)), + new LabeledPoint(0, Vectors.dense(1, 0, 0, 0, 0)), + new LabeledPoint(1, Vectors.dense(1, 1, 0, 0, 0)) + )) + + /** + * Create some toy data for testing correctness of variance. + */ + def varianceData(sc: SparkContext): RDD[LabeledPoint] = sc.parallelize(Seq( + new LabeledPoint(1.0, Vectors.dense(Array(0.0))), + new LabeledPoint(2.0, Vectors.dense(Array(1.0))), + new LabeledPoint(3.0, Vectors.dense(Array(2.0))), + new LabeledPoint(10.0, Vectors.dense(Array(3.0))), + new LabeledPoint(12.0, Vectors.dense(Array(4.0))), + new LabeledPoint(14.0, Vectors.dense(Array(5.0))) + )) + + /** + * Create toy data that can be used for testing deep tree training; the generated data requires + * [[depth]] splits to split fully. Thus a tree fit on the generated data should have a depth of + * [[depth]] (unless splitting halts early due to other constraints e.g. max depth or min + * info gain). + */ + def deepTreeData(sc: SparkContext, depth: Int): RDD[LabeledPoint] = { + // Create a dataset with [[depth]] binary features; a training point has a label of 1 + // iff all features have a value of 1. + sc.parallelize(Range(0, depth + 1).map { idx => + val features = Array.fill[Double](depth)(1) + if (idx == depth) { + LabeledPoint(1.0, Vectors.dense(features)) + } else { + features(idx) = 0.0 + LabeledPoint(0.0, Vectors.dense(features)) + } + }) + } + + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. + * + * This set of Params is for all Decision Tree-based models. + */ + val allParamSettings: Map[String, Any] = Map( + "checkpointInterval" -> 7, + "seed" -> 543L, + "maxDepth" -> 2, + "maxBins" -> 20, + "minInstancesPerNode" -> 2, + "minInfoGain" -> 1e-14, + "maxMemoryInMB" -> 257, + "cacheNodeIds" -> true + ) + + /** Data for tree read/write tests which produces a non-trivial tree. */ + def getTreeReadWriteData(sc: SparkContext): RDD[LabeledPoint] = { + val arr = Array( + LabeledPoint(0.0, Vectors.dense(0.0, 0.0)), + LabeledPoint(1.0, Vectors.dense(0.0, 1.0)), + LabeledPoint(0.0, Vectors.dense(0.0, 0.0)), + LabeledPoint(0.0, Vectors.dense(0.0, 2.0)), + LabeledPoint(0.0, Vectors.dense(1.0, 0.0)), + LabeledPoint(1.0, Vectors.dense(1.0, 1.0)), + LabeledPoint(1.0, Vectors.dense(1.0, 0.0)), + LabeledPoint(1.0, Vectors.dense(1.0, 2.0))) + sc.parallelize(arr) + } +} diff --git a/src/test/scala/org/apache/spark/ml/tree/impl/TreeSplitUtilsSuite.scala b/src/test/scala/org/apache/spark/ml/tree/impl/TreeSplitUtilsSuite.scala new file mode 100755 index 0000000..f8d1ef9 --- /dev/null +++ b/src/test/scala/org/apache/spark/ml/tree/impl/TreeSplitUtilsSuite.scala @@ -0,0 +1,163 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree.impl + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.tree.{CategoricalSplit, ContinuousSplit, Split} +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.mllib.tree.impurity.{Entropy, Impurity} +import org.apache.spark.mllib.tree.model.ImpurityStats +import org.apache.spark.mllib.util.MLlibTestSparkContext + +/** Suite exercising helper methods for making split decisions during decision tree training. */ +class TreeSplitUtilsSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + + /** + * Get a DTStatsAggregator for sufficient stat collection/impurity calculation populated + * with the data from the specified training points. + */ + private def getAggregator( + metadata: DecisionTreeMetadata, + col: FeatureColumn, + from: Int, + to: Int, + labels: Array[Double], + featureSplits: Array[Split]): DTStatsAggregator = { + + val statsAggregator = new DTStatsAggregator(metadata, featureSubset = None) + val instanceWeights = Array.fill[Double](col.values.length)(1.0) + val indices = col.values.indices.toArray + AggUpdateUtils.updateParentImpurity(statsAggregator, indices, from, to, instanceWeights, labels) + (new LocalDecisionTree).updateAggregator(statsAggregator, col, indices, instanceWeights, labels, + from, to, col.featureIndex, featureSplits) + statsAggregator + } + + /** Check that left/right impurities match what we'd expect for a split. */ + private def validateImpurityStats( + impurity: Impurity, + labels: Array[Double], + stats: ImpurityStats, + expectedLeftStats: Array[Double], + expectedRightStats: Array[Double]): Unit = { + // Verify that impurity stats were computed correctly for split + val numClasses = (labels.max + 1).toInt + val fullImpurityStatsArray + = Array.tabulate[Double](numClasses)((label: Int) => labels.count(_ == label).toDouble) + val fullImpurity = Entropy.calculate(fullImpurityStatsArray, labels.length) + assert(stats.impurityCalculator.stats === fullImpurityStatsArray) + assert(stats.impurity === fullImpurity) + assert(stats.leftImpurityCalculator.stats === expectedLeftStats) + assert(stats.rightImpurityCalculator.stats === expectedRightStats) + assert(stats.valid) + } + + /* * * * * * * * * * * Choosing Splits * * * * * * * * * * */ + + test("chooseSplit: choose correct type of split (continuous split)") { + // Construct (binned) continuous data + val labels = Array(0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0) + val col = FeatureColumn(featureIndex = 0, values = Array(8, 1, 1, 2, 3, 5, 6)) + // Get an array of continuous splits corresponding to values in our binned data + val splits = LocalTreeTests.getContinuousSplits(1.to(8).toArray, featureIndex = 0) + // Construct DTStatsAggregator, compute sufficient stats + val metadata = OptimizedTreeTests.getMetadata(numExamples = 7, + numFeatures = 1, numClasses = 2, Map.empty) + val statsAggregator = getAggregator(metadata, col, from = 1, to = 4, labels, splits) + // Choose split, check that it's a valid ContinuousSplit + val (split1, stats1) = SplitUtils.chooseSplit(statsAggregator, col.featureIndex, + col.featureIndex, splits) + assert(stats1.valid && split1.isInstanceOf[ContinuousSplit]) + } + + test("chooseOrderedCategoricalSplit: basic case") { + // Helper method for testing ordered categorical split + def testHelper( + values: Array[Int], + labels: Array[Double], + expectedLeftCategories: Array[Double], + expectedLeftStats: Array[Double], + expectedRightStats: Array[Double]): Unit = { + val featureIndex = 0 + // Construct FeatureVector to store categorical data + val featureArity = values.max + 1 + val arityMap = Map[Int, Int](featureIndex -> featureArity) + val col = FeatureColumn(featureIndex = 0, values = values) + // Construct DTStatsAggregator, compute sufficient stats + val metadata = OptimizedTreeTests.getMetadata(numExamples = values.length, numFeatures = 1, + numClasses = 2, arityMap, unorderedFeatures = Some(Set.empty)) + val statsAggregator = getAggregator(metadata, col, from = 0, to = values.length, + labels, featureSplits = Array.empty) + // Choose split + val (split, stats) = + SplitUtils.chooseOrderedCategoricalSplit(statsAggregator, col.featureIndex, + col.featureIndex) + // Verify that split has the expected left-side/right-side categories + val expectedRightCategories = Range(0, featureArity) + .filter(c => !expectedLeftCategories.contains(c)).map(_.toDouble).toArray + split match { + case s: CategoricalSplit => + assert(s.featureIndex === featureIndex) + assert(s.leftCategories === expectedLeftCategories) + assert(s.rightCategories === expectedRightCategories) + case _ => + throw new AssertionError( + s"Expected CategoricalSplit but got ${split.getClass.getSimpleName}") + } + validateImpurityStats(Entropy, labels, stats, expectedLeftStats, expectedRightStats) + } + + val values = Array(0, 0, 1, 2, 2, 2, 2) + val labels1 = Array(0, 0, 1, 1, 1, 1, 1).map(_.toDouble) + testHelper(values, labels1, Array(0.0), Array(2.0, 0.0), Array(0.0, 5.0)) + + val labels2 = Array(0, 0, 0, 1, 1, 1, 1).map(_.toDouble) + testHelper(values, labels2, Array(0.0, 1.0), Array(3.0, 0.0), Array(0.0, 4.0)) + } + + test("chooseContinuousSplit: basic case") { + // Construct data for continuous feature + val featureIndex = 0 + val thresholds = Array(0, 1, 2, 3) + val values = thresholds.indices.toArray + val labels = Array(0.0, 0.0, 1.0, 1.0) + val col = FeatureColumn(featureIndex = featureIndex, values = values) + + // Construct DTStatsAggregator, compute sufficient stats + val splits = LocalTreeTests.getContinuousSplits(thresholds, featureIndex) + val metadata = OptimizedTreeTests.getMetadata(numExamples = values.length, numFeatures = 1, + numClasses = 2, Map.empty) + val statsAggregator = getAggregator(metadata, col, from = 0, to = values.length, labels, splits) + + // Choose split, verify that it has expected threshold + val (split, stats) = SplitUtils.chooseContinuousSplit(statsAggregator, featureIndex, + featureIndex, splits) + split match { + case s: ContinuousSplit => + assert(s.featureIndex === featureIndex) + assert(s.threshold === 1) + case _ => + throw new AssertionError( + s"Expected ContinuousSplit but got ${split.getClass.getSimpleName}") + } + // Verify impurity stats of split + validateImpurityStats(Entropy, labels, stats, expectedLeftStats = Array(2.0, 0.0), + expectedRightStats = Array(0.0, 2.0)) + } +}