Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

libtorch C++ jni #2

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
446 changes: 11 additions & 435 deletions README.md

Large diffs are not rendered by default.

82 changes: 27 additions & 55 deletions build.sbt
Original file line number Diff line number Diff line change
@@ -1,55 +1,27 @@
resolvers +=
"Sonatype OSS Snapshots" at "https://oss.sonatype.org/content/repositories/snapshots"

import Dependencies._

lazy val root = (project in file(".")).settings(
inThisBuild(
List(
organization := "be.botkop",
scalaVersion := "2.12.5",
version := "0.1.2-SNAPSHOT"
)),
name := "scorch",
libraryDependencies += numsca,
libraryDependencies += scalaTest % Test
)

crossScalaVersions := Seq("2.11.12", "2.12.4")

publishTo := {
val nexus = "https://oss.sonatype.org/"
if (isSnapshot.value)
Some("snapshots" at nexus + "content/repositories/snapshots")
else
Some("releases" at nexus + "service/local/staging/deploy/maven2")
}

pomIncludeRepository := { _ =>
false
}

licenses := Seq(
"BSD-style" -> url("http://www.opensource.org/licenses/bsd-license.php"))

homepage := Some(url("https://github.com/botkop"))

scmInfo := Some(
ScmInfo(
url("https://github.com/botkop/scorch"),
"scm:[email protected]:botkop/scorch.git"
)
)

developers := List(
Developer(
id = "botkop",
name = "Koen Dejonghe",
email = "[email protected]",
url = url("https://github.com/botkop")
)
)

publishMavenStyle := true
publishArtifact in Test := false
// skip in publish := true
import sbt._
import sbt.Keys._


version := "1.0"

scalaVersion := "2.12.7"


// https://mvnrepository.com/artifact/org.bytedeco/javacpp
libraryDependencies += "org.bytedeco" % "javacpp" % "1.4.3"
libraryDependencies += "org.scala-lang" % "scala-reflect" % "2.12.7"

enablePlugins(JniGeneratorPlugin, JniBuildPlugin)
JniBuildPlugin.autoImport.torchLibPath in jniBuild := "/home/nazar/libtorch"
//sourceDirectory in nativeCompile := sourceDirectory.value / "native"
//target in nativeCompile :=target.value / "native" / nativePlatform.value


libraryDependencies += "com.typesafe.scala-logging" %% "scala-logging" % "3.7.2"
libraryDependencies += "ch.qos.logback" % "logback-classic" % "1.2.3"

lazy val scalaTest = "org.scalatest" %% "scalatest" % "3.0.3"

libraryDependencies += scalaTest % Test


6 changes: 0 additions & 6 deletions project/Dependencies.scala

This file was deleted.

48 changes: 48 additions & 0 deletions project/JniBuildPlugin.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@

import sbt._
import sbt.Keys._

import sys.process._


object JniBuildPlugin extends AutoPlugin {

override val trigger: PluginTrigger = noTrigger

override val requires: Plugins = plugins.JvmPlugin

object autoImport extends JniGeneratorKeys {
lazy val jniBuild = taskKey[Unit]("Builds so lib")
}

import autoImport._

override lazy val projectSettings: Seq[Setting[_]] =Seq(

targetGeneratorDir in jniBuild := sourceDirectory.value / "native" ,

targetLibName in jniBuild := "java_torch_lib",

jniBuild := {
val directory = (targetGeneratorDir in jniBuild).value
val cmake_prefix = (torchLibPath in jniBuild).value
val log = streams.value.log

log.info("Build to " + directory.getAbsolutePath)
val command = s"cmake -H$directory -B$directory -DCMAKE_PREFIX_PATH=$cmake_prefix"
log.info(command)
val exitCode = Process(command) ! log
if (exitCode != 0) sys.error(s"An error occurred while running cmake. Exit code: $exitCode.")
val command1 = s"make -C$directory"
log.info(command1)
val exitCode1 = Process(command1) ! log
if (exitCode1 != 0) sys.error(s"An error occurred while running make. Exit code: $exitCode1.")
},

jniBuild := jniBuild.dependsOn(jniGen).value,
compile := (compile in Compile).dependsOn(jniBuild).value,

)


}
125 changes: 125 additions & 0 deletions project/JniGeneratorPlugin.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@

import java.io.{File, FileInputStream}

import org.objectweb.asm.{ClassReader, ClassVisitor, MethodVisitor, Opcodes}

import scala.collection.JavaConverters._
import scala.collection.mutable
import sbt._
import sbt.Keys._

import sys.process._

trait JniGeneratorKeys {

lazy val torchLibPath = settingKey[String]("Path to C++ torch library.")

lazy val targetGeneratorDir = settingKey[File]("target directory to store generated cpp files.")

lazy val targetLibName = settingKey[String]("target cpp file name.")

lazy val builderClass = settingKey[String]("class name that generates cpp file.")

lazy val jniGen = taskKey[Unit]("Generates cpp files")

lazy val javahClasses: TaskKey[Set[String]] = taskKey[Set[String]](
"Finds the fully qualified names of classes containing native declarations.")

}


object JniGeneratorPlugin extends AutoPlugin {

override val trigger: PluginTrigger = noTrigger

override val requires: Plugins = plugins.JvmPlugin

object autoImport extends JniGeneratorKeys

import autoImport._

override lazy val projectSettings: Seq[Setting[_]] =Seq(
javahClasses in jniGen := {
import xsbti.compile._
val compiled: CompileAnalysis = (compile in Compile).value
val classFiles: Set[File] = compiled.readStamps.getAllProductStamps.asScala.keySet.toSet
val nativeClasses = classFiles flatMap { file => findNativeClasses(file) }
nativeClasses
},

targetGeneratorDir in jniGen := sourceDirectory.value / "native" ,

targetLibName in jniGen := "java_torch_lib",

builderClass in jniGen := "generate.Builder",

jniGen := {
val directory = (targetGeneratorDir in jniGen).value
val builder = (builderClass in jniGen).value
val libName = (targetLibName in jniGen).value
// The full classpath cannot be used here since it also generates resources. In a project combining JniJavah and
// JniPackage, we would have a chicken-and-egg problem.
val classPath: String = ((dependencyClasspath in Compile).value.map(_.data) ++ {
Seq((classDirectory in Compile).value)
}).mkString(sys.props("path.separator"))
val classes = (javahClasses in jniGen).value
val log = streams.value.log

if (classes.nonEmpty) {
log.info("Sources will be generated to " + directory.getAbsolutePath)
log.info("Generating header for " + classes.mkString(" "))
val command = s"java -classpath $classPath $builder -d $directory -o $libName ${classes.mkString(" ")}" // " torch_scala.NativeLibraryConfig" }"
log.info(command)
val exitCode = Process(command) ! log
if (exitCode != 0) sys.error(s"An error occurred while running javah. Exit code: $exitCode.")
}
}

)

private class NativeFinder extends ClassVisitor(Opcodes.ASM5) {
private var fullyQualifiedName: String = ""

/** Classes found to contain at least one @native definition. */
private val _nativeClasses = mutable.HashSet.empty[String]

def nativeClasses: Set[String] = _nativeClasses.toSet

override def visit(
version: Int, access: Int, name: String, signature: String, superName: String,
interfaces: Array[String]): Unit = {
fullyQualifiedName = name.replaceAll("/", ".")
}

override def visitMethod(
access: Int, name: String, desc: String, signature: String, exceptions: Array[String]): MethodVisitor = {
val isNative = (access & Opcodes.ACC_NATIVE) != 0
if (isNative)
_nativeClasses += fullyQualifiedName
// Return null, meaning that we do not want to visit the method further.
null
}
}

/** Finds classes containing native implementations (i.e., `@native` definitions).
*
* @param javaFile Java file from which classes are being read.
* @return Set containing all the fully qualified names of classes that contain at least one member annotated with
* the `@native` annotation.
*/
def findNativeClasses(javaFile: File): Set[String] = {
var inputStream: FileInputStream = null
try {
inputStream = new FileInputStream(javaFile)
val reader = new ClassReader(inputStream)
val finder = new NativeFinder
reader.accept(finder, 0)
finder.nativeClasses
} finally {
if (inputStream != null)
inputStream.close()
}
}


}
1 change: 0 additions & 1 deletion project/build.properties

This file was deleted.

25 changes: 24 additions & 1 deletion project/plugins.sbt
Original file line number Diff line number Diff line change
@@ -1 +1,24 @@
addSbtPlugin("org.xerial.sbt" % "sbt-sonatype" % "2.0")
/* Copyright 2017-18, Emmanouil Antonios Platanios. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*/

import sbt.Defaults.sbtPluginExtra

logLevel := Level.Warn

libraryDependencies ++= Seq(
"ch.qos.logback" % "logback-classic" % "1.2.3",
"org.ow2.asm" % "asm" % "6.2.1")


Loading