Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,13 @@
package org.trustedanalytics.sparktk.models

import java.io.{ FileOutputStream, File }
import java.net.URI
import java.net.{ URL, URI }
import java.nio.DoubleBuffer
import java.nio.file.{ Files, Path }
import org.apache.commons.lang.StringUtils
import org.apache.hadoop.conf.Configuration
import org.apache.commons.io.{ IOUtils, FileUtils }
import org.apache.spark.SparkContext
import org.trustedanalytics.model.archive.format.ModelArchiveFormat
import org.trustedanalytics.sparktk.saveload.SaveLoad

Expand Down Expand Up @@ -91,7 +92,8 @@ object ScoringModelUtils {
* @param sourcePath Path to source location. Defaults to use the path to the currently running jar.
* @return full path to the location of the MAR file for Scoring Engine
*/
def saveToMar(marSavePath: String,
def saveToMar(sc: SparkContext,
marSavePath: String,
modelClass: String,
modelSrcDir: java.nio.file.Path,
modelReader: String = classOf[SparkTkModelAdapter].getName,
Expand All @@ -114,24 +116,81 @@ object ScoringModelUtils {
val x = new TkSearchPath(absolutePath.substring(0, absolutePath.lastIndexOf("/")))
var jarFileList = x.jarsInSearchPath.values.toList

if (marSavePath.startsWith("hdfs")) {
val protocol = getProtocol(marSavePath)

if ("file".equalsIgnoreCase(protocol)) {
print("Local")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove or enhance these prints (or log)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed the debug statements

jarFileList = jarFileList ::: List(new File(modelSrcDir.toString))
}
else {
print("not local")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed

val modelFile = Files.createTempDirectory("localModel")
val localModelPath = new org.apache.hadoop.fs.Path(modelFile.toString)
val hdfsFileSystem: org.apache.hadoop.fs.FileSystem = org.apache.hadoop.fs.FileSystem.get(new URI(modelFile.toString), new Configuration())
hdfsFileSystem.copyToLocalFile(new org.apache.hadoop.fs.Path(modelSrcDir.toString), localModelPath)
jarFileList = jarFileList ::: List(new File(localModelPath.toString))
}
else {
jarFileList = jarFileList ::: List(new File(modelSrcDir.toString))
}
ModelArchiveFormat.write(jarFileList, modelReader, modelClass, zipOutStream)

}
SaveLoad.saveMar(marSavePath, zipFile)
SaveLoad.saveMar(sc, marSavePath, zipFile)
}
finally {
FileUtils.deleteQuietly(zipFile)
IOUtils.closeQuietly(zipOutStream)
}
}

/**
 * Returns the protocol for a given URI or filename.
 *
 * @param source URI or filename to examine; must be non-null and non-empty.
 * @return The protocol (scheme) of the given source, e.g. "hdfs", "file", "http".
 * @throws IllegalArgumentException if source is null/empty or is a
 *                                  scheme-relative reference (starts with "//").
 */
def getProtocol(source: String): String = {
  // Reject empty as well as null (review request): an empty source has no protocol.
  require(source != null && source.trim.nonEmpty, "marfile source must not be null or empty")

  // Assign the try expression directly to the result instead of mutating a var.
  try {
    val uri = new URI(source)
    if (uri.isAbsolute) {
      // Absolute URI carries its own scheme (e.g. hdfs://host/path).
      uri.getScheme
    }
    else {
      // Not an absolute URI; fall back to URL parsing (e.g. http://host/path).
      new URL(source).getProtocol
    }
  }
  catch {
    case ex: Exception =>
      if (source.startsWith("//")) {
        // A scheme-relative reference has no resolvable protocol on its own.
        throw new IllegalArgumentException("Does not support Relative context: " + source)
      }
      else {
        // Anything else (e.g. a plain local path) is treated as a file.
        getProtocol(new File(source))
      }
  }
}

/**
 * Returns the protocol for a given file.
 *
 * @param file Determine the protocol for this file.
 * @return The protocol for the given file, or "unknown" if it cannot be determined.
 */
private def getProtocol(file: File): String = {
  try {
    file.toURI.toURL.getProtocol
  }
  catch {
    // NOTE(review): returning the sentinel "unknown" forces callers to special-case
    // it; consider letting the exception propagate or returning a Try instead.
    case ex: Exception => "unknown"
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,7 @@ case class LogisticRegressionModel private[logistic_regression] (observationColu
try {
tmpDir = Files.createTempDirectory("sparktk-scoring-model")
save(sc, tmpDir.toString)
ScoringModelUtils.saveToMar(marSavePath, classOf[LogisticRegressionModel].getName, tmpDir)
ScoringModelUtils.saveToMar(sc, marSavePath, classOf[LogisticRegressionModel].getName, tmpDir)
}
finally {
sys.addShutdownHook(FileUtils.deleteQuietly(tmpDir.toFile)) // Delete temporary directory on exit
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ case class NaiveBayesModel private[naive_bayes] (sparkModel: SparkNaiveBayesMode
try {
tmpDir = Files.createTempDirectory("sparktk-scoring-model")
save(sc, tmpDir.toString)
ScoringModelUtils.saveToMar(marSavePath, classOf[NaiveBayesModel].getName, tmpDir)
ScoringModelUtils.saveToMar(sc, marSavePath, classOf[NaiveBayesModel].getName, tmpDir)
}
finally {
sys.addShutdownHook(FileUtils.deleteQuietly(tmpDir.toFile)) // Delete temporary directory on exit
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,7 @@ case class RandomForestClassifierModel private[random_forest_classifier] (sparkM
try {
tmpDir = Files.createTempDirectory("sparktk-scoring-model")
save(sc, tmpDir.toString, overwrite = true)
ScoringModelUtils.saveToMar(marSavePath, classOf[RandomForestClassifierModel].getName, tmpDir)
ScoringModelUtils.saveToMar(sc, marSavePath, classOf[RandomForestClassifierModel].getName, tmpDir)
}
finally {
sys.addShutdownHook(FileUtils.deleteQuietly(tmpDir.toFile)) // Delete temporary directory on exit
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ case class SvmModel private[svm] (sparkModel: SparkSvmModel,
try {
tmpDir = Files.createTempDirectory("sparktk-scoring-model")
save(sc, tmpDir.toString)
ScoringModelUtils.saveToMar(marSavePath, classOf[SvmModel].getName, tmpDir)
ScoringModelUtils.saveToMar(sc, marSavePath, classOf[SvmModel].getName, tmpDir)
}
finally {
sys.addShutdownHook(FileUtils.deleteQuietly(tmpDir.toFile)) // Delete temporary directory on exit
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ case class GaussianMixtureModel private[gmm] (observationColumns: Seq[String],
try {
tmpDir = Files.createTempDirectory("sparktk-scoring-model")
save(sc, tmpDir.toString)
ScoringModelUtils.saveToMar(marSavePath, classOf[GaussianMixtureModel].getName, tmpDir)
ScoringModelUtils.saveToMar(sc, marSavePath, classOf[GaussianMixtureModel].getName, tmpDir)
}
finally {
sys.addShutdownHook(FileUtils.deleteQuietly(tmpDir.toFile)) // Delete temporary directory on exit
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ case class KMeansModel private[kmeans] (columns: Seq[String],
try {
tmpDir = Files.createTempDirectory("sparktk-scoring-model")
save(sc, tmpDir.toString)
ScoringModelUtils.saveToMar(marSavePath, classOf[KMeansModel].getName, tmpDir)
ScoringModelUtils.saveToMar(sc, marSavePath, classOf[KMeansModel].getName, tmpDir)
}
finally {
sys.addShutdownHook(FileUtils.deleteQuietly(tmpDir.toFile)) // Delete temporary directory on exit
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ case class LdaModel private[lda] (documentColumnName: String,
try {
tmpDir = Files.createTempDirectory("sparktk-scoring-model")
save(sc, tmpDir.toString)
ScoringModelUtils.saveToMar(marSavePath, classOf[LdaModel].getName, tmpDir)
ScoringModelUtils.saveToMar(sc, marSavePath, classOf[LdaModel].getName, tmpDir)
}
finally {
sys.addShutdownHook(FileUtils.deleteQuietly(tmpDir.toFile)) // Delete temporary directory on exit
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ case class PcaModel private[pca] (columns: Seq[String],
try {
tmpDir = Files.createTempDirectory("sparktk-scoring-model")
save(sc, tmpDir.toString)
ScoringModelUtils.saveToMar(marSavePath, classOf[PcaModel].getName, tmpDir)
ScoringModelUtils.saveToMar(sc, marSavePath, classOf[PcaModel].getName, tmpDir)
}
finally {
sys.addShutdownHook(FileUtils.deleteQuietly(tmpDir.toFile)) // Delete temporary directory on exit
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ case class CollaborativeFilteringModel(sourceColumnName: String,
try {
tmpDir = Files.createTempDirectory("sparktk-scoring-model")
save(sc, tmpDir.toString)
ScoringModelUtils.saveToMar(marSavePath, classOf[CollaborativeFilteringModel].getName, tmpDir)
ScoringModelUtils.saveToMar(sc, marSavePath, classOf[CollaborativeFilteringModel].getName, tmpDir)
}
finally {
sys.addShutdownHook(FileUtils.deleteQuietly(tmpDir.toFile)) // Delete temporary directory on exit
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ case class LinearRegressionModel(observationColumns: Seq[String],
// The spark linear regression model save will fail, if we don't specify the "overwrite", since the temp
// directory has already been created.
save(sc, tmpDir.toString, overwrite = true)
ScoringModelUtils.saveToMar(marSavePath, classOf[LinearRegressionModel].getName, tmpDir)
ScoringModelUtils.saveToMar(sc, marSavePath, classOf[LinearRegressionModel].getName, tmpDir)
}
finally {
sys.addShutdownHook(FileUtils.deleteQuietly(tmpDir.toFile)) // Delete temporary directory on exit
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ case class RandomForestRegressorModel private[random_forest_regressor] (sparkMod
try {
tmpDir = Files.createTempDirectory("sparktk-scoring-model")
save(sc, tmpDir.toString, overwrite = true)
ScoringModelUtils.saveToMar(marSavePath, classOf[RandomForestRegressorModel].getName, tmpDir)
ScoringModelUtils.saveToMar(sc, marSavePath, classOf[RandomForestRegressorModel].getName, tmpDir)
}
finally {
sys.addShutdownHook(FileUtils.deleteQuietly(tmpDir.toFile)) // Delete temporary directory on exit
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ case class CoxProportionalHazardsModel private[cox_ph] (sparkModel: CoxPhModel,
try {
tmpDir = Files.createTempDirectory("sparktk-scoring-model")
save(sc, tmpDir.toString, overwrite = true)
ScoringModelUtils.saveToMar(marSavePath, classOf[CoxProportionalHazardsModel].getName, tmpDir)
ScoringModelUtils.saveToMar(sc, marSavePath, classOf[CoxProportionalHazardsModel].getName, tmpDir)
}
finally {
sys.addShutdownHook(FileUtils.deleteQuietly(tmpDir.toFile)) // Delete temporary directory on exit
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ case class ArimaModel private[arima] (ts: DenseVector,
try {
tmpDir = Files.createTempDirectory("sparktk-scoring-model")
save(sc, tmpDir.toString)
ScoringModelUtils.saveToMar(marSavePath, classOf[ArimaModel].getName, tmpDir)
ScoringModelUtils.saveToMar(sc, marSavePath, classOf[ArimaModel].getName, tmpDir)
}
finally {
sys.addShutdownHook(FileUtils.deleteQuietly(tmpDir.toFile)) // Delete temporary directory on exit
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ case class ArimaxModel private[arimax] (timeseriesColumn: String,
try {
tmpDir = Files.createTempDirectory("sparktk-scoring-model")
save(sc, tmpDir.toString)
ScoringModelUtils.saveToMar(marSavePath, classOf[ArimaxModel].getName, tmpDir)
ScoringModelUtils.saveToMar(sc, marSavePath, classOf[ArimaxModel].getName, tmpDir)
}
finally {
sys.addShutdownHook(FileUtils.deleteQuietly(tmpDir.toFile)) // Delete temporary directory on exit
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ case class ArxModel private[arx] (timeseriesColumn: String,
try {
tmpDir = Files.createTempDirectory("sparktk-scoring-model")
save(sc, tmpDir.toString)
ScoringModelUtils.saveToMar(marSavePath, classOf[ArxModel].getName, tmpDir)
ScoringModelUtils.saveToMar(sc, marSavePath, classOf[ArxModel].getName, tmpDir)
}
finally {
sys.addShutdownHook(FileUtils.deleteQuietly(tmpDir.toFile)) // Delete temporary directory on exit
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ case class MaxModel private[max] (timeseriesColumn: String,
try {
tmpDir = Files.createTempDirectory("sparktk-scoring-model")
save(sc, tmpDir.toString)
ScoringModelUtils.saveToMar(marSavePath, classOf[MaxModel].getName, tmpDir)
ScoringModelUtils.saveToMar(sc, marSavePath, classOf[MaxModel].getName, tmpDir)
}
finally {
sys.addShutdownHook(FileUtils.deleteQuietly(tmpDir.toFile)) // Delete temporary directory on exit
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package org.trustedanalytics.sparktk.saveload

import java.io.File
import java.nio.file.Files
import org.apache.commons.io.FileUtils
import org.apache.hadoop.fs.Path
import java.net.URI
Expand All @@ -28,6 +29,7 @@ import org.json4s.jackson.Serialization
import org.json4s.{ NoTypeHints, Extraction, DefaultFormats }
import org.json4s.jackson.JsonMethods._
import org.json4s.JsonDSL._
import org.trustedanalytics.sparktk.models.ScoringModelUtils

/**
* Simple save/load library which uses json4s to read/write text files, including info for format validation
Expand Down Expand Up @@ -56,21 +58,24 @@ object SaveLoad {
* @param zipFile the MAR file to be stored
* @return full path to the location of the MAR file
*/
def saveMar(sc: SparkContext, storagePath: String, zipFile: File): String = {
  // Decide where to store the MAR file from the path's protocol rather than a
  // fragile startsWith("hdfs") prefix check.
  val protocol = ScoringModelUtils.getProtocol(storagePath)

  if ("file".equalsIgnoreCase(protocol)) {
    // Local file system: copy the MAR file directly to the destination path.
    // NOTE(review): other schemes (e.g. s3) fall into the Hadoop branch below;
    // confirm that is the intended routing for non-hdfs, non-file protocols.
    val file = new File(storagePath)
    FileUtils.copyFile(zipFile, file)
    file.getCanonicalPath
  }
  else {
    // HDFS (or another Hadoop-supported file system): upload the local MAR file
    // using the SparkContext's Hadoop configuration so cluster settings apply.
    val hdfsPath = new Path(storagePath)
    val hdfsFileSystem: org.apache.hadoop.fs.FileSystem = org.apache.hadoop.fs.FileSystem.get(new URI(storagePath), sc.hadoopConfiguration)
    val localPath = new Path(zipFile.getAbsolutePath)
    hdfsFileSystem.copyFromLocalFile(false, true, localPath, hdfsPath)
    // Owner/group read-write-execute, no access for others.
    hdfsFileSystem.setPermission(hdfsPath, new FsPermission(FsAction.ALL, FsAction.ALL, FsAction.NONE))
    storagePath
  }
}

/**
Expand Down