M8l5 ml #34

Merged
merged 1 commit into from
Oct 23, 2024
2 changes: 2 additions & 0 deletions .gitattributes
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
*.onnx_data filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
26 changes: 26 additions & 0 deletions docs/ml-models-list.md

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions gradle/libs.versions.toml
@@ -132,6 +132,9 @@ mkpl-cor = { module = "ru.otus.otuskotlin.marketplace.libs:ok-marketplace-lib-co
mkpl-state-common = { module = "ru.otus.otuskotlin.marketplace.state:ok-marketplace-states-common", version.ref = "mkpl" }
mkpl-state-biz = { module = "ru.otus.otuskotlin.marketplace.state:ok-marketplace-states-biz", version.ref = "mkpl" }

# Machine Learning
ml-tokenizer = "ai.djl.huggingface:tokenizers:0.25.0"
ml-onnx-runtime = "com.microsoft.onnxruntime:onnxruntime:1.16.3"

[bundles]
kotest = ["kotest-junit5", "kotest-core", "kotest-datatest", "kotest-property"]
8 changes: 8 additions & 0 deletions ok-marketplace-ml/README.md
@@ -0,0 +1,8 @@
# Kotlin ONNX ML Sample

A demonstration of using [ONNX](https://onnxruntime.ai/docs/get-started/with-java.html) in Kotlin, with the NLP
model [Roberta NER model](https://huggingface.co/xlm-roberta-large-finetuned-conll03-english) as the example. Download the model files (`model.onnx`, `model.onnx_data`, `tokenizer.json`) into
the [onnx-model](onnx-model) folder.


[Notebook with the Python model](./Ml_demo1.ipynb)
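
Since the model files are not stored in the repository, it can help to verify that they are in place before running inference. The following is a minimal sketch, not part of this PR: the file names come from the README above, while `requiredModelFiles`, `missingModelFiles`, and the assumption that the working directory is the module root are hypothetical.

```kotlin
import java.nio.file.Files
import java.nio.file.Path
import java.nio.file.Paths

// File names as listed in the README; the helper names are hypothetical.
val requiredModelFiles = listOf("model.onnx", "model.onnx_data", "tokenizer.json")

/** Returns the names of the required files that are absent from [modelDir]. */
fun missingModelFiles(modelDir: Path): List<String> =
    requiredModelFiles.filterNot { Files.isRegularFile(modelDir.resolve(it)) }

fun main() {
    // Assumes the program is started from the ok-marketplace-ml directory.
    val missing = missingModelFiles(Paths.get("onnx-model"))
    if (missing.isEmpty()) {
        println("All model files are present")
    } else {
        println("Missing model files: ${missing.joinToString()}")
    }
}
```

A check like this only confirms that the files exist; it does not validate their contents.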
26 changes: 26 additions & 0 deletions ok-marketplace-ml/build.gradle.kts
@@ -0,0 +1,26 @@
plugins {
id("build-jvm")
}

group = "ru.otus.otuskotlin.marketplace.ml"
version = "0.0.1"

dependencies {
implementation(libs.ml.onnx.runtime)
implementation(libs.ml.tokenizer)
implementation(libs.logback)

testImplementation(kotlin("test-junit5"))
}

tasks {
test {
useJUnitPlatform()
}
}

allprojects {
repositories {
mavenCentral()
}
}
3 changes: 3 additions & 0 deletions ok-marketplace-ml/onnx-model/.gitignore
@@ -0,0 +1,3 @@
*
!/.gitattributes
!/.gitignore
29 changes: 29 additions & 0 deletions ok-marketplace-ml/settings.gradle.kts
@@ -0,0 +1,29 @@
rootProject.name = "ok-marketplace-ml"

dependencyResolutionManagement {
versionCatalogs {
create("libs") {
from(files("../gradle/libs.versions.toml"))
}
}
}

pluginManagement {
includeBuild("../build-plugin")
plugins {
id("build-jvm") apply false
id("build-kmp") apply false
}
repositories {
mavenCentral()
gradlePluginPortal()
}
}

plugins {
id("org.gradle.toolchains.foojay-resolver-convention") version "0.5.0"
}

// Enables constructs like this one
//implementation(projects.m2l5Gradle.sub1.ssub1)
enableFeaturePreview("TYPESAFE_PROJECT_ACCESSORS")
150 changes: 150 additions & 0 deletions ok-marketplace-ml/src/main/kotlin/Inferrer.kt
@@ -0,0 +1,150 @@
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer
import ai.onnxruntime.OnnxTensor
import ai.onnxruntime.OrtEnvironment
import ai.onnxruntime.OrtException
import ai.onnxruntime.OrtSession
import ai.onnxruntime.OrtSession.SessionOptions
import java.io.IOException
import java.nio.file.Paths

/**
* Main class that analyzes text using a machine learning model
*/
class Inferrer(
modelPath: String = "model.onnx",
private val tokenizerJson: String = "tokenizer.json",
) {
/**
* The tokenizer converts text into a sequence of tokens
*/
private val tokenizer: HuggingFaceTokenizer by lazy {
runCatching { HuggingFaceTokenizer.newInstance(Paths.get(tokenizerJson)) }
.onFailure { e -> e.printStackTrace() }
.getOrThrow()
}

/**
* Onnx-runtime environment: the environment in which the model is executed
*/
private val env: OrtEnvironment by lazy {
OrtEnvironment.getEnvironment() ?: throw Exception("Failed to get ORT environment")
}

/**
* Onnx-runtime session: the session of the model's runtime environment
*/
private val session: OrtSession by lazy {
val s = env.createSession(modelPath, SessionOptions()) ?: throw Exception("Failed to get session")
println(
"""
Model Input Names: ${s.inputNames.joinToString()}
Model Input info: ${s.inputInfo.entries.joinToString { "${it.key}=${it.value}" }}
Model Output Names: ${s.outputNames.joinToString()}
Model Output info: ${s.outputInfo.entries.joinToString { "${it.key}=${it.value}" }}
""".trimIndent()
)
s
}

/*
Extension for parsing inference results:
appends each token to the result field that matches its class id

below is the relation from class id to the label
"id2label": {
"0": "B-LOC",
"1": "B-MISC",
"2": "B-ORG",
"3": "I-LOC",
"4": "I-MISC",
"5": "I-ORG",
"6": "I-PER",
"7": "O"
* */
private fun InferringResult.post(
clazz: Int,
token: String,
) = when (clazz) {
6 -> persons += token
2, 5 -> organizations += token
3, 0 -> locations += token
1, 4 -> misc += token
else -> Unit
}

private fun findMaxIndex(arr: FloatArray): Int = arr.indices.maxBy { arr[it] }

/**
* Inference: the main method that computes the results of the model's analysis
*/
fun infer(inputText: String) = try {

// Pre-encode the text into arrays
val encoding = try {
tokenizer.encode(inputText)
} catch (ioException: IOException) {
ioException.printStackTrace()
throw ioException
}

val tokens = encoding.tokens ?: throw Exception("No tokens detected") // extract the tokens
// Build the model inputs
val modelInputs = mapOf(
"input_ids" to OnnxTensor.createTensor(
env,
arrayOf(encoding.ids ?: throw Exception("Empty ids"))
),
"attention_mask" to OnnxTensor.createTensor(
env,
arrayOf(encoding.attentionMask ?: throw Exception("Empty attention mask"))
),
)

// Object that holds the inference results
val inferringResult = InferringResult()

// Run inference
session.run(modelInputs)
// extract the inference result and convert it to the required shape
?.firstOrNull()
?.value
?.value
?.let {
@Suppress("UNCHECKED_CAST")
it as? Array<Array<FloatArray>>
}
?.firstOrNull()
?.forEachIndexed { i, logits0i ->
try {
inferringResult.post(findMaxIndex(logits0i), tokens[i])
} catch (exception: Exception) {
exception.printStackTrace()
}
}
?: throw Exception("Empty result")

// print the inference result
inferringResult.displayResult(tokens)
} catch (e: OrtException) {
e.printStackTrace()
}

/**
* Prints the results to the console
*/
private fun InferringResult.displayResult(tokens: Array<String>) {
val tokensSpecialChar = tokens[1][0].toString() // word separators in tokens
println("All persons in the text: ${persons.cleanResult(tokensSpecialChar)}")
println("All Organizations in the text: ${organizations.cleanResult(tokensSpecialChar)}")
println("All Locations in the text: ${locations.cleanResult(tokensSpecialChar)}")
println("All Miscellaneous entities in the text: ${misc.cleanResult(tokensSpecialChar)}")
}

/**
* Helper that cleans up inference results for console output
*/
private fun String.cleanResult(tokensSpecialChar: String) = split(tokensSpecialChar.toRegex())
.dropLastWhile { it.isEmpty() }
.filter { it.isNotBlank() }
.joinToString()
}
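
The decoding step in `infer` takes, for each token, the class with the highest logit and routes the token by class id. The sketch below isolates that step so it can be run without the model. It is illustrative only: the label table is copied from the `id2label` comment in `Inferrer.kt`, while `argMax`, `labelTokens`, and the logit values are hypothetical and not produced by the real model.

```kotlin
// Label table copied from the id2label comment in Inferrer.kt.
val id2label = listOf("B-LOC", "B-MISC", "B-ORG", "I-LOC", "I-MISC", "I-ORG", "I-PER", "O")

/** Index of the largest logit; the same idea as findMaxIndex in Inferrer. */
fun argMax(logits: FloatArray): Int = logits.indices.maxBy { logits[it] }

/** Pairs each token with the label of its highest-scoring class. */
fun labelTokens(tokens: List<String>, logits: List<FloatArray>): List<Pair<String, String>> =
    tokens.zip(logits) { token, row -> token to id2label[argMax(row)] }

fun main() {
    val tokens = listOf("▁Google", "▁in", "▁London")
    // Made-up logits, one row of 8 class scores per token.
    val logits = listOf(
        floatArrayOf(0.1f, 0.0f, 3.2f, 0.0f, 0.0f, 0.4f, 0.0f, 0.5f), // class 2 is the largest
        floatArrayOf(0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 4.0f), // class 7 is the largest
        floatArrayOf(2.9f, 0.0f, 0.0f, 0.3f, 0.0f, 0.0f, 0.0f, 0.2f), // class 0 is the largest
    )
    labelTokens(tokens, logits).forEach { (token, label) -> println("$token -> $label") }
}
```

Keeping the label names next to the ids, rather than matching on bare integers, makes it easier to see which branch handles which entity type.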
21 changes: 21 additions & 0 deletions ok-marketplace-ml/src/main/kotlin/InferringResult.kt
@@ -0,0 +1,21 @@
/**
* Model that represents the inference results
*/
data class InferringResult(
/**
* Persons in the text
*/
var persons: String = "",
/**
* Locations in the text
*/
var locations: String = "",
/**
* Organizations in the text
*/
var organizations: String = "",
/**
* Other significant entities in the text
*/
*/
var misc: String = "",
)
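
Each field of `InferringResult` accumulates sub-word tokens by string concatenation, and `cleanResult` in `Inferrer.kt` later splits the string on the word-start marker taken from the first real token. The sketch below shows that merge-then-split idea in isolation. It assumes the marker is the SentencePiece character `▁` (U+2581), which is what XLM-RoBERTa tokenizers normally emit; `WORD_MARKER` and `mergeSubTokens` are hypothetical names, not part of this PR.

```kotlin
// Assumed SentencePiece word-start marker (U+2581); Inferrer reads it from the tokens instead.
const val WORD_MARKER = "▁"

/** Concatenates sub-word tokens and splits the result back into whole words. */
fun mergeSubTokens(tokens: List<String>): List<String> =
    tokens.joinToString(separator = "")
        .split(WORD_MARKER)
        .filter { it.isNotBlank() }

fun main() {
    // "▁Ah" + "war" form one word because only "▁Ah" carries the marker.
    println(mergeSubTokens(listOf("▁Ah", "war", "▁Google"))) // prints [Ahwar, Google]
}
```

Note that this approach joins adjacent entities of the same type only at word boundaries, so two separate person names appear as separate words rather than as grouped multi-word entities.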
35 changes: 35 additions & 0 deletions ok-marketplace-ml/src/test/kotlin/OnnxInferTest.kt
@@ -0,0 +1,35 @@
import kotlin.test.Test

class OnnxInferTest {
@Test
fun onnxinferTest() {
val inferrer = Inferrer(
modelPath = "onnx-model/model.onnx",
tokenizerJson = "onnx-model/tokenizer.json",
)
inputTexts.forEach {
println("========================================")
println("TEXT: $it")
inferrer.infer(it)
}
}

companion object {
val inputTexts = listOf(
"Ahwar wants to work at Google in london. EU rejected German call to boycott British lamb.",
"""
KotlinDL is a high-level Deep Learning API written in Kotlin and inspired by Keras. Under the hood, it uses TensorFlow Java API and ONNX Runtime API for Java. KotlinDL offers simple APIs for training deep learning models from scratch, importing existing Keras and ONNX models for inference, and leveraging transfer learning for tailoring existing pre-trained models to your tasks.
This project aims to make Deep Learning easier for JVM and Android developers and simplify deploying deep learning models in production environments.
Here's an example of what a classic convolutional neural network LeNet would look like in KotlinDL:

""".trimIndent(),
"""
«Я́ндекс» — российская транснациональная компания в отрасли информационных технологий, чьё головное юридическое лицо зарегистрировано в Нидерландах, владеющая одноимённой системой поиска в интернете, интернет-порталом и веб-службами в нескольких странах. Наиболее заметное положение занимает на рынках России, Белоруссии и Казахстана[5].

Поисковая система Yandex.ru была официально анонсирована 23 сентября 1997 года и первое время развивалась в рамках компании CompTek International. Как отдельная компания «Яндекс» образовалась в 2000 году.

В мае 2011 года «Яндекс» провёл первичное размещение акций, заработав на этом больше, чем какая-либо из интернет-компаний со времён IPO-поисковика Google в 2004 году[6][7].
""".trimIndent(),
)
}
}
1 change: 1 addition & 0 deletions settings.gradle.kts
@@ -16,4 +16,5 @@ includeBuild("ok-marketplace-states")
includeBuild("ok-marketplace-libs")

includeBuild("ok-marketplace-tests")
includeBuild("ok-marketplace-ml")
includeBuild("pgkn")