Skip to content

Commit

Permalink
M8l4 ml (#37)
Browse files Browse the repository at this point in the history
* M8l4 mlkotlin (#35)

* m8l4 ML in Kotlin project

* m8l4 ML in Kotlin project1

* m8l4 ML in Kotlin project3

* m8l4 ML in Kotlin project3

* m8l4 ML in Kotlin project3

(cherry picked from commit a63f224e18f74c2780bff7a2ff6faba585242baf)

* M8l4 ML

(cherry picked from commit 5d42aa43f041bc360c978a8c0ababa9ce74f4911)
  • Loading branch information
svok committed Oct 23, 2024
1 parent ac81140 commit 5e22b48
Show file tree
Hide file tree
Showing 11 changed files with 304 additions and 0 deletions.
2 changes: 2 additions & 0 deletions .gitattributes
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
*.onnx_data filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
26 changes: 26 additions & 0 deletions docs/ml-models-list.md

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions gradle/libs.versions.toml
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,9 @@ mkpl-cor = { module = "ru.otus.otuskotlin.marketplace.libs:ok-marketplace-lib-co
mkpl-state-common = { module = "ru.otus.otuskotlin.marketplace.state:ok-marketplace-states-common", version.ref = "mkpl" }
mkpl-state-biz = { module = "ru.otus.otuskotlin.marketplace.state:ok-marketplace-states-biz", version.ref = "mkpl" }

# Machine Learning
ml-tokenizer = "ai.djl.huggingface:tokenizers:0.25.0"
ml-onnx-runtime = "com.microsoft.onnxruntime:onnxruntime:1.16.3"

[bundles]
kotest = ["kotest-junit5", "kotest-core", "kotest-datatest", "kotest-property"]
Expand Down
8 changes: 8 additions & 0 deletions ok-marketplace-ml/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Kotlin ONNX ML Sample

Демонстрация использования [ONNX Runtime](https://onnxruntime.ai/docs/get-started/with-java.html) в Kotlin на примере
NLP-модели [XLM-RoBERTa NER](https://huggingface.co/xlm-roberta-large-finetuned-conll03-english). Перед запуском
необходимо скачать файлы модели (`model.onnx`, `model.onnx_data`, `tokenizer.json`) в папку [onnx-model](onnx-model).


[Ноутбук с python-моделью](./Ml_demo1.ipynb)
26 changes: 26 additions & 0 deletions ok-marketplace-ml/build.gradle.kts
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// Build script of the ONNX ML sample module.
plugins {
    id("build-jvm")
}

version = "0.0.1"
group = "ru.otus.otuskotlin.marketplace.ml"

allprojects {
    repositories {
        mavenCentral()
    }
}

dependencies {
    // ML runtime and tokenizer coordinates are taken from the `libs` version catalog.
    implementation(libs.ml.onnx.runtime)
    implementation(libs.ml.tokenizer)
    implementation(libs.logback)

    testImplementation(kotlin("test-junit5"))
}

// Run tests on the JUnit Platform (required for the JUnit 5 based kotlin-test artifact).
tasks.test {
    useJUnitPlatform()
}
3 changes: 3 additions & 0 deletions ok-marketplace-ml/onnx-model/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
*
!/.gitattributes
!/.gitignore
29 changes: 29 additions & 0 deletions ok-marketplace-ml/settings.gradle.kts
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// Plugin resolution comes first: convention plugins are provided by the sibling `build-plugin` build.
pluginManagement {
    includeBuild("../build-plugin")
    plugins {
        id("build-jvm") apply false
        id("build-kmp") apply false
    }
    repositories {
        mavenCentral()
        gradlePluginPortal()
    }
}

plugins {
    id("org.gradle.toolchains.foojay-resolver-convention") version "0.5.0"
}

rootProject.name = "ok-marketplace-ml"

// Expose the version catalog from the parent directory as `libs` in this build.
dependencyResolutionManagement {
    versionCatalogs {
        create("libs") {
            from(files("../gradle/libs.versions.toml"))
        }
    }
}

// Enables type-safe project accessors, i.e. constructs such as
// implementation(projects.m2l5Gradle.sub1.ssub1)
enableFeaturePreview("TYPESAFE_PROJECT_ACCESSORS")
150 changes: 150 additions & 0 deletions ok-marketplace-ml/src/main/kotlin/Inferrer.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer
import ai.onnxruntime.OnnxTensor
import ai.onnxruntime.OrtEnvironment
import ai.onnxruntime.OrtException
import ai.onnxruntime.OrtSession
import ai.onnxruntime.OrtSession.SessionOptions
import java.io.IOException
import java.nio.file.Paths

/**
 * Runs text analysis (named-entity tagging) with an ONNX machine-learning model.
 *
 * The tokenizer, the ONNX Runtime environment and the session are created lazily on first use.
 *
 * NOTE(review): this class never closes the session or the tokenizer; they live as long as the
 * instance does. Confirm that is acceptable for the intended usage.
 *
 * @param modelPath path to the ONNX model file, handed to [OrtEnvironment.createSession]
 * @property tokenizerJson path to the tokenizer definition file loaded by [HuggingFaceTokenizer]
 */
class Inferrer(
    modelPath: String = "model.onnx",
    private val tokenizerJson: String = "tokenizer.json",
) {
    /**
     * Tokenizer that turns text into a sequence of tokens.
     * A load failure is printed and then rethrown.
     */
    private val tokenizer: HuggingFaceTokenizer by lazy {
        runCatching { HuggingFaceTokenizer.newInstance(Paths.get(tokenizerJson)) }
            .onFailure { e -> e.printStackTrace() }
            .getOrThrow()
    }

    /**
     * ONNX Runtime environment the model is executed in.
     */
    private val env: OrtEnvironment by lazy {
        OrtEnvironment.getEnvironment() ?: throw Exception("Failed to get ORT environment")
    }

    /**
     * ONNX Runtime session for the model; prints the model's input/output metadata once on creation.
     */
    private val session: OrtSession by lazy {
        val s = env.createSession(modelPath, SessionOptions()) ?: throw Exception("Failed to get session")
        println(
            """
            Model Input Names: ${s.inputNames.joinToString()}
            Model Input info: ${s.inputInfo.entries.joinToString { "${it.key}=${it.value}" }}
            Model Output Names: ${s.outputNames.joinToString()}
            Model Output info: ${s.outputInfo.entries.joinToString { "${it.key}=${it.value}" }}
            """.trimIndent()
        )
        s
    }

    /*
     * Appends a token to the result field that matches its class id.
     * Relation from class id to label:
     * "id2label": {
     *   "0": "B-LOC",
     *   "1": "B-MISC",
     *   "2": "B-ORG",
     *   "3": "I-LOC",
     *   "4": "I-MISC",
     *   "5": "I-ORG",
     *   "6": "I-PER",
     *   "7": "O"
     * }
     */
    private fun InferringResult.post(
        clazz: Int,
        token: String,
    ) {
        when (clazz) {
            6 -> persons += token
            2, 5 -> organizations += token
            3, 0 -> locations += token
            1, 4 -> misc += token
            else -> Unit // "O" and any unknown class id are not collected
        }
    }

    /** Index of the largest value in [arr]; [arr] must not be empty. */
    private fun findMaxIndex(arr: FloatArray): Int = arr.indices.maxBy { arr[it] }

    /**
     * Executes the model for one encoded text and returns the score rows of the first batch entry
     * (one [FloatArray] of class scores per token).
     *
     * Both input tensors and the session output wrap native resources, so each one is closed as soon
     * as the scores have been copied out of the output.
     *
     * @throws OrtException if tensor creation or model execution fails
     * @throws Exception if the model produced no usable output
     */
    private fun runModel(ids: LongArray, attentionMask: LongArray): Array<FloatArray> =
        OnnxTensor.createTensor(env, arrayOf(ids)).use { idsTensor ->
            OnnxTensor.createTensor(env, arrayOf(attentionMask)).use { maskTensor ->
                session.run(mapOf("input_ids" to idsTensor, "attention_mask" to maskTensor)).use { output ->
                    @Suppress("UNCHECKED_CAST")
                    val batches = output?.firstOrNull()?.value?.value as? Array<Array<FloatArray>>
                    batches?.firstOrNull() ?: throw Exception("Empty result")
                }
            }
        }

    /**
     * Inference entry point: tokenizes [inputText], runs the model, assigns every token to the class
     * with the highest score and prints the collected entities to the console.
     *
     * An [OrtException] raised by ONNX Runtime is printed and not rethrown; any other exception
     * (tokenizer failure, empty model output) propagates to the caller.
     */
    fun infer(inputText: String) {
        try {
            // Encode the text into token ids / attention mask
            val encoding = try {
                tokenizer.encode(inputText)
            } catch (ioException: IOException) {
                ioException.printStackTrace()
                throw ioException
            }

            val tokens = encoding.tokens ?: throw Exception("No tokens detected")
            val ids = encoding.ids ?: throw Exception("Empty ids")
            val attentionMask = encoding.attentionMask ?: throw Exception("Empty attention mask")

            val scoresPerToken = runModel(ids, attentionMask)
            if (scoresPerToken.size != tokens.size) {
                System.err.println(
                    "Model returned ${scoresPerToken.size} score rows for ${tokens.size} tokens; extra entries are ignored"
                )
            }

            // Holder for the entities found in this text
            val inferringResult = InferringResult()
            scoresPerToken.forEachIndexed { i, scores ->
                val token = tokens.getOrNull(i)
                // Skip rows that have no matching token or no scores instead of failing the whole text
                if (token != null && scores.isNotEmpty()) {
                    inferringResult.post(findMaxIndex(scores), token)
                }
            }

            inferringResult.displayResult(tokens)
        } catch (e: OrtException) {
            e.printStackTrace()
        }
    }

    /**
     * Prints the collected entities to the console.
     *
     * The first character of the second token is used as the word separator inside the accumulated
     * strings. When there is no such character the strings are printed unsplit.
     */
    private fun InferringResult.displayResult(tokens: Array<String>) {
        val tokensSpecialChar = tokens.getOrNull(1)?.firstOrNull()?.toString().orEmpty() // word separator in tokens
        println("All persons in the text: ${persons.cleanResult(tokensSpecialChar)}")
        println("All Organizations in the text: ${organizations.cleanResult(tokensSpecialChar)}")
        println("All Locations in the text: ${locations.cleanResult(tokensSpecialChar)}")
        println("All Miscellanous entities in the text: ${misc.cleanResult(tokensSpecialChar)}")
    }

    /**
     * Splits the accumulated token string on [tokensSpecialChar] and joins the non-blank parts with ", ".
     *
     * The separator is matched literally (not as a regular expression), so characters such as `(` or `.`
     * are safe. An empty separator leaves the string unsplit.
     */
    private fun String.cleanResult(tokensSpecialChar: String): String {
        val parts = if (tokensSpecialChar.isEmpty()) listOf(this) else split(tokensSpecialChar)
        return parts
            .filter { it.isNotBlank() }
            .joinToString()
    }
}
21 changes: 21 additions & 0 deletions ok-marketplace-ml/src/main/kotlin/InferringResult.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
/**
 * Holder for the results of one inference run.
 *
 * Every field defaults to an empty string and is mutable, so text can be appended to it.
 *
 * @property persons persons found in the text
 * @property locations locations found in the text
 * @property organizations organizations found in the text
 * @property misc other significant entities found in the text
 */
data class InferringResult(
    var persons: String = "",
    var locations: String = "",
    var organizations: String = "",
    var misc: String = "",
)
35 changes: 35 additions & 0 deletions ok-marketplace-ml/src/test/kotlin/OnnxInferTest.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import kotlin.test.Test

/**
 * Smoke test: feeds a few sample texts through [Inferrer] and lets it print what it finds.
 * It expects the model and tokenizer files under `onnx-model/` and makes no assertions on the output.
 */
class OnnxInferTest {
    @Test
    fun onnxinferTest() {
        val ner = Inferrer(modelPath = "onnx-model/model.onnx", tokenizerJson = "onnx-model/tokenizer.json")
        for (text in inputTexts) {
            println("========================================")
            println("TEXT: $text")
            ner.infer(text)
        }
    }

    companion object {
        /** Sample inputs: a short English sentence pair, an English paragraph and a Russian paragraph. */
        val inputTexts = listOf(
            "Ahwar wants to work at Google in london. EU rejected German call to boycott British lamb.",
            """
            KotlinDL is a high-level Deep Learning API written in Kotlin and inspired by Keras. Under the hood, it uses TensorFlow Java API and ONNX Runtime API for Java. KotlinDL offers simple APIs for training deep learning models from scratch, importing existing Keras and ONNX models for inference, and leveraging transfer learning for tailoring existing pre-trained models to your tasks.
            This project aims to make Deep Learning easier for JVM and Android developers and simplify deploying deep learning models in production environments.
            Here's an example of what a classic convolutional neural network LeNet would look like in KotlinDL:
            """.trimIndent(),
            """
            «Я́ндекс» — российская транснациональная компания в отрасли информационных технологий, чьё головное юридическое лицо зарегистрировано в Нидерландах, владеющая одноимённой системой поиска в интернете, интернет-порталом и веб-службами в нескольких странах. Наиболее заметное положение занимает на рынках России, Белоруссии и Казахстана[5].
            Поисковая система Yandex.ru была официально анонсирована 23 сентября 1997 года и первое время развивалась в рамках компании CompTek International. Как отдельная компания «Яндекс» образовалась в 2000 году.
            В мае 2011 года «Яндекс» провёл первичное размещение акций, заработав на этом больше, чем какая-либо из интернет-компаний со времён IPO-поисковика Google в 2004 году[6][7].
            """.trimIndent(),
        )
    }
}
1 change: 1 addition & 0 deletions settings.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,5 @@ includeBuild("ok-marketplace-states")
includeBuild("ok-marketplace-libs")

includeBuild("ok-marketplace-tests")
includeBuild("ok-marketplace-ml")
includeBuild("pgkn")

0 comments on commit 5e22b48

Please sign in to comment.