Skip to content

Commit

Permalink
M8l4 ML
Browse files Browse the repository at this point in the history
  • Loading branch information
svok committed Jul 26, 2024
1 parent b6a259c commit 772e542
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 7 deletions.
45 changes: 38 additions & 7 deletions ok-marketplace-ml/src/main/kotlin/Inferrer.kt
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,32 @@ import ai.onnxruntime.OrtSession.SessionOptions
import java.io.IOException
import java.nio.file.Paths

/**
* Основной класс, выполняющий анализ текста с использованием модели машинного обучения
*/
class Inferrer(
modelPath: String = "model.onnx",
private val tokenizerJson: String = "tokenizer.json",
) {
/**
* Токенайзер преобразует текст в набор токенов
*/
private val tokenizer: HuggingFaceTokenizer by lazy {
runCatching { HuggingFaceTokenizer.newInstance(Paths.get(tokenizerJson)) }
.onFailure { e -> e.printStackTrace() }
.getOrThrow()
}

/**
* Onnx-runtime environment - среда исполнения модели
*/
private val env: OrtEnvironment by lazy {
OrtEnvironment.getEnvironment() ?: throw Exception("Failed to get ORT environment")
}

/**
* Onnx-runtime session - сессия среды исполнения модели
*/
private val session: OrtSession by lazy {
val s = env.createSession(modelPath, SessionOptions()) ?: throw Exception("Failed to get session")
println(
Expand All @@ -33,7 +47,8 @@ class Inferrer(
}

/*
seperates tokens into arrays according to class ids
Расширение для разбора результатов инференса
separates tokens into arrays according to class ids
below is the relation from class id to the label
"id2label": {
Expand All @@ -59,16 +74,21 @@ class Inferrer(

private fun findMaxIndex(arr: FloatArray): Int = arr.indices.maxBy { arr[it] }

/**
* Инференс - главный метод вычисления результатов машинного анализа
*/
fun infer(inputText: String) = try {

// Выполняем предварительное кодирования текста в массивы
val encoding = try {
tokenizer.encode(inputText)
} catch (ioException: IOException) {
ioException.printStackTrace()
throw ioException
}

val tokens = encoding.tokens ?: throw Exception("No tokens detected") // tokenize the text
val tokens = encoding.tokens ?: throw Exception("No tokens detected") // извлечение токенов
// Формируем входные данные для модели
val modelInputs = mapOf(
"input_ids" to OnnxTensor.createTensor(
env,
Expand All @@ -80,13 +100,19 @@ class Inferrer(
),
)

// Объект для хранения результатов инференса
val inferringResult = InferringResult()
@Suppress("UNCHECKED_CAST")

// Выполняем инференс
session.run(modelInputs)
// извлекаем результат инференса и преобразуем в нужный формат
?.firstOrNull()
?.value
?.value
?.let { it as? Array<Array<FloatArray>> }
?.let {
@Suppress("UNCHECKED_CAST")
it as? Array<Array<FloatArray>>
}
?.firstOrNull()
?.forEachIndexed { i, logits0i ->
try {
Expand All @@ -96,22 +122,27 @@ class Inferrer(
}
}
?: throw Exception("Empty result")

// выводим результат инференса
inferringResult.displayResult(tokens)
} catch (e: OrtException) {
e.printStackTrace()
}

/**
* Вывод результатов в консоль
*/
private fun InferringResult.displayResult(tokens: Array<String>) {
/*
* Showing the results
* */
val tokensSpecialChar = tokens[1][0].toString() // word seperators in tokens
println("All persons in the text: ${persons.cleanResult(tokensSpecialChar)}")
println("All Organizations in the text: ${organizations.cleanResult(tokensSpecialChar)}")
println("All Locations in the text: ${locations.cleanResult(tokensSpecialChar)}")
println("All Miscellanous entities in the text: ${misc.cleanResult(tokensSpecialChar)}")
}

/**
* Вспомогательная функция для вывода результатов инференса в консоль
*/
private fun String.cleanResult(tokensSpecialChar: String) = split(tokensSpecialChar.toRegex())
.dropLastWhile { it.isEmpty() }
.filter { it.isNotBlank() }
Expand Down
15 changes: 15 additions & 0 deletions ok-marketplace-ml/src/main/kotlin/InferringResult.kt
Original file line number Diff line number Diff line change
@@ -1,6 +1,21 @@
/**
* Модель для представления результатов инференса
*/
data class InferringResult(
/**
* Персоны в тексте
*/
var persons: String = "",
/**
* Локации в тексте
*/
var locations: String = "",
/**
* Организации в тексте
*/
var organizations: String = "",
/**
* Остальные значимые элементы в тексте
*/
var misc: String = "",
)

0 comments on commit 772e542

Please sign in to comment.