-
Notifications
You must be signed in to change notification settings - Fork 28
AWS S3 support #433
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
AWS S3 support #433
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,12 +16,13 @@ | |
| package org.trustedanalytics.sparktk.models | ||
|
|
||
| import java.io.{ FileOutputStream, File } | ||
| import java.net.URI | ||
| import java.net.{ URL, URI } | ||
| import java.nio.DoubleBuffer | ||
| import java.nio.file.{ Files, Path } | ||
| import org.apache.commons.lang.StringUtils | ||
| import org.apache.hadoop.conf.Configuration | ||
| import org.apache.commons.io.{ IOUtils, FileUtils } | ||
| import org.apache.spark.SparkContext | ||
| import org.trustedanalytics.model.archive.format.ModelArchiveFormat | ||
| import org.trustedanalytics.sparktk.saveload.SaveLoad | ||
|
|
||
|
|
@@ -91,7 +92,8 @@ object ScoringModelUtils { | |
| * @param sourcePath Path to source location. Defaults to use the path to the currently running jar. | ||
| * @return full path to the location of the MAR file for Scoring Engine | ||
| */ | ||
| def saveToMar(marSavePath: String, | ||
| def saveToMar(sc: SparkContext, | ||
| marSavePath: String, | ||
| modelClass: String, | ||
| modelSrcDir: java.nio.file.Path, | ||
| modelReader: String = classOf[SparkTkModelAdapter].getName, | ||
|
|
@@ -114,24 +116,81 @@ object ScoringModelUtils { | |
| val x = new TkSearchPath(absolutePath.substring(0, absolutePath.lastIndexOf("/"))) | ||
| var jarFileList = x.jarsInSearchPath.values.toList | ||
|
|
||
| if (marSavePath.startsWith("hdfs")) { | ||
| val protocol = getProtocol(marSavePath) | ||
|
|
||
| if ("file".equalsIgnoreCase(protocol)) { | ||
| print("Local") | ||
| jarFileList = jarFileList ::: List(new File(modelSrcDir.toString)) | ||
| } | ||
| else { | ||
| print("not local") | ||
|
||
| val modelFile = Files.createTempDirectory("localModel") | ||
| val localModelPath = new org.apache.hadoop.fs.Path(modelFile.toString) | ||
| val hdfsFileSystem: org.apache.hadoop.fs.FileSystem = org.apache.hadoop.fs.FileSystem.get(new URI(modelFile.toString), new Configuration()) | ||
| hdfsFileSystem.copyToLocalFile(new org.apache.hadoop.fs.Path(modelSrcDir.toString), localModelPath) | ||
| jarFileList = jarFileList ::: List(new File(localModelPath.toString)) | ||
| } | ||
| else { | ||
| jarFileList = jarFileList ::: List(new File(modelSrcDir.toString)) | ||
| } | ||
| ModelArchiveFormat.write(jarFileList, modelReader, modelClass, zipOutStream) | ||
|
|
||
| } | ||
| SaveLoad.saveMar(marSavePath, zipFile) | ||
| SaveLoad.saveMar(sc, marSavePath, zipFile) | ||
| } | ||
| finally { | ||
| FileUtils.deleteQuietly(zipFile) | ||
| IOUtils.closeQuietly(zipOutStream) | ||
| } | ||
| } | ||
|
|
||
/**
 * Determines which protocol (URI scheme) a location string refers to.
 *
 * An absolute URI yields its own scheme. A non-absolute URI is handed to the URL
 * constructor instead. If any step throws, the source is treated as a plain file
 * name, unless it begins with "//", in which case it is rejected.
 *
 * @param source URI or file name whose protocol is wanted; must not be null.
 *
 * @return The protocol for the given source.
 * @throws IllegalArgumentException if source is null, or if parsing fails and
 *                                  source starts with "//".
 */
def getProtocol(source: String): String = {
  require(source != null, "marfile source must not be null")

  try {
    val parsed = new URI(source)
    if (parsed.isAbsolute) parsed.getScheme else new URL(source).getProtocol
  }
  catch {
    // A leading "//" is a scheme-less network reference; it is never read as a file name.
    case _: Exception if source.startsWith("//") =>
      throw new IllegalArgumentException("Relative context: " + source)
    // Everything else that failed to parse falls back to the File-based lookup.
    case _: Exception =>
      getProtocol(new File(source))
  }
}
|
|
||
/**
 * Looks up the protocol of the URL that a file converts to.
 *
 * @param file File whose protocol is wanted.
 *
 * @return The protocol of the file's URL, or "unknown" if the conversion
 *         throws an exception.
 */
private def getProtocol(file: File): String =
  try file.toURI.toURL.getProtocol
  catch {
    case _: Exception => "unknown"
  }
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,6 +16,7 @@ | |
| package org.trustedanalytics.sparktk.saveload | ||
|
|
||
| import java.io.File | ||
| import java.nio.file.Files | ||
| import org.apache.commons.io.FileUtils | ||
| import org.apache.hadoop.fs.Path | ||
| import java.net.URI | ||
|
|
@@ -28,6 +29,7 @@ import org.json4s.jackson.Serialization | |
| import org.json4s.{ NoTypeHints, Extraction, DefaultFormats } | ||
| import org.json4s.jackson.JsonMethods._ | ||
| import org.json4s.JsonDSL._ | ||
| import org.trustedanalytics.sparktk.models.ScoringModelUtils | ||
|
|
||
| /** | ||
| * Simple save/load library which uses json4s to read/write text files, including info for format validation | ||
|
|
@@ -56,21 +58,24 @@ object SaveLoad { | |
| * @param zipFile the MAR file to be stored | ||
| * @return full path to the location of the MAR file | ||
| */ | ||
| def saveMar(storagePath: String, zipFile: File): String = { | ||
| if (storagePath.startsWith("hdfs")) { | ||
| def saveMar(sc: SparkContext, storagePath: String, zipFile: File): String = { | ||
|
|
||
| val protocol = ScoringModelUtils.getProtocol(storagePath) | ||
|
|
||
| if ("file".equalsIgnoreCase(protocol)) { | ||
| print("Local") | ||
|
||
| val file = new File(storagePath) | ||
| FileUtils.copyFile(zipFile, file) | ||
| file.getCanonicalPath | ||
| } | ||
| else { | ||
| val hdfsPath = new Path(storagePath) | ||
| val hdfsFileSystem: org.apache.hadoop.fs.FileSystem = org.apache.hadoop.fs.FileSystem.get(new URI(storagePath), new Configuration()) | ||
| val hdfsFileSystem: org.apache.hadoop.fs.FileSystem = org.apache.hadoop.fs.FileSystem.get(new URI(storagePath), sc.hadoopConfiguration) | ||
| val localPath = new Path(zipFile.getAbsolutePath) | ||
| hdfsFileSystem.copyFromLocalFile(false, true, localPath, hdfsPath) | ||
| hdfsFileSystem.setPermission(hdfsPath, new FsPermission(FsAction.ALL, FsAction.ALL, FsAction.NONE)) | ||
| storagePath | ||
| } | ||
| else { | ||
| val file = new File(storagePath) | ||
| FileUtils.copyFile(zipFile, file) | ||
| file.getCanonicalPath | ||
| } | ||
|
|
||
| } | ||
|
|
||
| /** | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove these debug prints, or replace them with more informative messages written through a logger.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
removed the debug statements