-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* M8l4 mlkotlin (#35) * m8l4 ML in Kotlin project * m8l4 ML in Kotlin project1 * m8l4 ML in Kotlin project3 * m8l4 ML in Kotlin project3 * m8l4 ML in Kotlin project3 (cherry picked from commit a63f224e18f74c2780bff7a2ff6faba585242baf) * M8l4 ML (cherry picked from commit 5d42aa43f041bc360c978a8c0ababa9ce74f4911)
- Loading branch information
Showing
11 changed files
with
304 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
*.onnx_data filter=lfs diff=lfs merge=lfs -text | ||
*.onnx filter=lfs diff=lfs merge=lfs -text |
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
# Kotlin ONNX ML Sample | ||
|
||
Демонстрация использования [ONNX](https://onnxruntime.ai/docs/get-started/with-java.html) в Kotlin на примере NLP | ||
модели [Roberta NER model](https://huggingface.co/xlm-roberta-large-finetuned-conll03-english). Необходимо скачать файлы модели (`model.onnx`, `model.onnx_data`, `tokenizer.json`) в | ||
папку [onnx-model](onnx-model). | ||
|
||
|
||
[Ноутбук с python-моделью](./Ml_demo1.ipynb) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
// Build script of the ONNX ML sample module.
plugins {
    id("build-jvm")
}

group = "ru.otus.otuskotlin.marketplace.ml"
version = "0.0.1"

// Every project of this build resolves its dependencies from Maven Central.
allprojects {
    repositories {
        mavenCentral()
    }
}

dependencies {
    // Versions come from the shared `libs` version catalog.
    implementation(libs.ml.onnx.runtime)
    implementation(libs.ml.tokenizer)
    implementation(libs.logback)

    testImplementation(kotlin("test-junit5"))
}

// Run tests on the JUnit Platform.
tasks.test {
    useJUnitPlatform()
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
* | ||
!/.gitattributes | ||
!/.gitignore |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
// Plugin resolution is declared first, as Gradle expects for settings scripts.
// The `build-jvm` / `build-kmp` plugin ids are presumably provided by the
// included build `../build-plugin` - confirm against that build.
pluginManagement {
    includeBuild("../build-plugin")
    plugins {
        id("build-jvm") apply false
        id("build-kmp") apply false
    }
    repositories {
        mavenCentral()
        gradlePluginPortal()
    }
}

plugins {
    id("org.gradle.toolchains.foojay-resolver-convention") version "0.5.0"
}

rootProject.name = "ok-marketplace-ml"

// Share the version catalog that lives one directory above this build.
dependencyResolutionManagement {
    versionCatalogs {
        create("libs") {
            from(files("../gradle/libs.versions.toml"))
        }
    }
}

// Enables type-safe project accessors, i.e. constructs such as
// implementation(projects.m2l5Gradle.sub1.ssub1)
enableFeaturePreview("TYPESAFE_PROJECT_ACCESSORS")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer | ||
import ai.onnxruntime.OnnxTensor | ||
import ai.onnxruntime.OrtEnvironment | ||
import ai.onnxruntime.OrtException | ||
import ai.onnxruntime.OrtSession | ||
import ai.onnxruntime.OrtSession.SessionOptions | ||
import java.io.IOException | ||
import java.nio.file.Paths | ||
|
||
/**
 * Analyses text with a machine-learning model: tokenizes the input, runs an ONNX
 * token-classification model over it and prints the recognised entities to stdout.
 *
 * The tokenizer, the ONNX Runtime environment and the session are created lazily on
 * first use. Call [close] once the instance is no longer needed so that the session
 * and the tokenizer are released.
 *
 * @param modelPath path to the ONNX model file
 * @property tokenizerJson path to the tokenizer definition file
 */
class Inferrer(
    modelPath: String = "model.onnx",
    private val tokenizerJson: String = "tokenizer.json",
) : AutoCloseable {
    // The Lazy holders are kept separately so that close() can tell whether the
    // resource was ever created without forcing its initialisation.
    private val tokenizerHolder: Lazy<HuggingFaceTokenizer> = lazy {
        HuggingFaceTokenizer.newInstance(Paths.get(tokenizerJson))
    }

    /**
     * Tokenizer that converts text into a set of tokens.
     */
    private val tokenizer: HuggingFaceTokenizer by tokenizerHolder

    /**
     * ONNX Runtime environment in which the model is executed.
     */
    private val env: OrtEnvironment by lazy {
        OrtEnvironment.getEnvironment() ?: error("Failed to get ORT environment")
    }

    private val sessionHolder: Lazy<OrtSession> = lazy {
        val s = env.createSession(modelPath, SessionOptions()) ?: error("Failed to get session")
        println(
            """
            Model Input Names: ${s.inputNames.joinToString()}
            Model Input info: ${s.inputInfo.entries.joinToString { "${it.key}=${it.value}" }}
            Model Output Names: ${s.outputNames.joinToString()}
            Model Output info: ${s.outputInfo.entries.joinToString { "${it.key}=${it.value}" }}
            """.trimIndent()
        )
        s
    }

    /**
     * ONNX Runtime session for the model; prints the model's input/output
     * description when it is first created.
     */
    private val session: OrtSession by sessionHolder

    /**
     * Appends [token] to the bucket that corresponds to the predicted class id.
     *
     * Mapping from class id to label (`id2label` of the model):
     * ```
     * 0: B-LOC, 1: B-MISC, 2: B-ORG, 3: I-LOC, 4: I-MISC, 5: I-ORG, 6: I-PER, 7: O
     * ```
     * Tokens of any other class are ignored.
     */
    private fun InferringResult.post(
        clazz: Int,
        token: String,
    ) {
        when (clazz) {
            6 -> persons += token
            2, 5 -> organizations += token
            3, 0 -> locations += token
            1, 4 -> misc += token
            else -> Unit
        }
    }

    /**
     * Returns the index of the largest value in [arr], or `null` when [arr] is empty.
     */
    private fun findMaxIndex(arr: FloatArray): Int? = arr.indices.maxByOrNull { arr[it] }

    /**
     * Runs the model for a single sequence and returns one logits row per token.
     *
     * The input tensors and the run result wrap native resources, so each of them is
     * closed as soon as the values have been extracted.
     *
     * @throws IllegalStateException if the model produced no usable output
     */
    private fun runModel(ids: LongArray, attentionMask: LongArray): Array<FloatArray> =
        OnnxTensor.createTensor(env, arrayOf(ids)).use { idsTensor ->
            OnnxTensor.createTensor(env, arrayOf(attentionMask)).use { maskTensor ->
                val modelInputs = mapOf(
                    "input_ids" to idsTensor,
                    "attention_mask" to maskTensor,
                )
                session.run(modelInputs).use { result ->
                    // The first output is expected to be [batch][token][class] logits.
                    @Suppress("UNCHECKED_CAST")
                    val batch = result?.firstOrNull()?.value?.value as? Array<Array<FloatArray>>
                    batch?.firstOrNull() ?: error("Empty result")
                }
            }
        }

    /**
     * Inference - the main entry point: tokenizes [inputText], runs the model,
     * groups the tokens by predicted entity class and prints the outcome.
     *
     * An [OrtException] raised while running the model is reported via
     * `printStackTrace()` and not rethrown; other failures propagate to the caller.
     *
     * @throws IllegalStateException if the tokenizer yields no tokens/ids/attention mask
     *   or the model returns an empty result
     */
    fun infer(inputText: String) {
        try {
            // Encode the text into the arrays the model consumes.
            val encoding = tokenizer.encode(inputText)
            val tokens = encoding.tokens ?: error("No tokens detected")

            val logits = runModel(
                ids = encoding.ids ?: error("Empty ids"),
                attentionMask = encoding.attentionMask ?: error("Empty attention mask"),
            )

            // Collect every token under the class with the highest score.
            val inferringResult = InferringResult()
            logits.forEachIndexed { i, tokenLogits ->
                // Rows without a matching token or without scores are skipped.
                val token = tokens.getOrNull(i) ?: return@forEachIndexed
                val clazz = findMaxIndex(tokenLogits) ?: return@forEachIndexed
                inferringResult.post(clazz, token)
            }

            inferringResult.displayResult(tokens)
        } catch (e: OrtException) {
            e.printStackTrace()
        }
    }

    /**
     * Prints the collected entities to the console.
     */
    private fun InferringResult.displayResult(tokens: Array<String>) {
        // The first character of the second token is used as the word separator
        // marker; it is empty when there is no such token.
        val tokensSpecialChar = tokens.getOrNull(1)?.firstOrNull()?.toString().orEmpty()
        println("All persons in the text: ${persons.cleanResult(tokensSpecialChar)}")
        println("All Organizations in the text: ${organizations.cleanResult(tokensSpecialChar)}")
        println("All Locations in the text: ${locations.cleanResult(tokensSpecialChar)}")
        println("All Miscellanous entities in the text: ${misc.cleanResult(tokensSpecialChar)}")
    }

    /**
     * Splits the accumulated tokens on [tokensSpecialChar] (taken literally, not as a
     * regular expression), drops blank parts and joins the rest with ", ".
     */
    private fun String.cleanResult(tokensSpecialChar: String): String {
        val words = if (tokensSpecialChar.isEmpty()) listOf(this) else split(tokensSpecialChar)
        return words.filter { it.isNotBlank() }.joinToString()
    }

    /**
     * Releases the session and the tokenizer if they were created.
     */
    override fun close() {
        try {
            if (sessionHolder.isInitialized()) session.close()
        } finally {
            if (tokenizerHolder.isInitialized()) tokenizer.close()
        }
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
/**
 * Model that represents the results of an inference run.
 *
 * @property persons persons found in the text
 * @property locations locations found in the text
 * @property organizations organizations found in the text
 * @property misc other significant elements found in the text
 */
data class InferringResult(
    var persons: String = "",
    var locations: String = "",
    var organizations: String = "",
    var misc: String = "",
)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
import kotlin.test.Test | ||
|
||
/**
 * Runs the inferrer over the sample texts and prints what it finds.
 * The model and tokenizer files are read from the `onnx-model` directory.
 */
class OnnxInferTest {
    @Test
    fun onnxinferTest() {
        val inferrer = Inferrer("onnx-model/model.onnx", "onnx-model/tokenizer.json")
        for (text in inputTexts) {
            println("========================================")
            println("TEXT: $text")
            inferrer.infer(text)
        }
    }

    companion object {
        /** Sample inputs: one short English text, one longer English text, one Russian text. */
        val inputTexts = listOf(
            "Ahwar wants to work at Google in london. EU rejected German call to boycott British lamb.",
            """
                KotlinDL is a high-level Deep Learning API written in Kotlin and inspired by Keras. Under the hood, it uses TensorFlow Java API and ONNX Runtime API for Java. KotlinDL offers simple APIs for training deep learning models from scratch, importing existing Keras and ONNX models for inference, and leveraging transfer learning for tailoring existing pre-trained models to your tasks.
                This project aims to make Deep Learning easier for JVM and Android developers and simplify deploying deep learning models in production environments.
                Here's an example of what a classic convolutional neural network LeNet would look like in KotlinDL:
            """.trimIndent(),
            """
                «Я́ндекс» — российская транснациональная компания в отрасли информационных технологий, чьё головное юридическое лицо зарегистрировано в Нидерландах, владеющая одноимённой системой поиска в интернете, интернет-порталом и веб-службами в нескольких странах. Наиболее заметное положение занимает на рынках России, Белоруссии и Казахстана[5].
                Поисковая система Yandex.ru была официально анонсирована 23 сентября 1997 года и первое время развивалась в рамках компании CompTek International. Как отдельная компания «Яндекс» образовалась в 2000 году.
                В мае 2011 года «Яндекс» провёл первичное размещение акций, заработав на этом больше, чем какая-либо из интернет-компаний со времён IPO-поисковика Google в 2004 году[6][7].
            """.trimIndent(),
        )
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters