diff --git a/ok-marketplace-ml/src/main/kotlin/Inferrer.kt b/ok-marketplace-ml/src/main/kotlin/Inferrer.kt index 0589789..b4ca38e 100644 --- a/ok-marketplace-ml/src/main/kotlin/Inferrer.kt +++ b/ok-marketplace-ml/src/main/kotlin/Inferrer.kt @@ -7,18 +7,32 @@ import ai.onnxruntime.OrtSession.SessionOptions import java.io.IOException import java.nio.file.Paths +/** + * Основной класс, выполняющий анализ текста с использованием модели машинного обучения + */ class Inferrer( modelPath: String = "model.onnx", private val tokenizerJson: String = "tokenizer.json", ) { + /** + * Токенайзер преобразует текст в набор токенов + */ private val tokenizer: HuggingFaceTokenizer by lazy { runCatching { HuggingFaceTokenizer.newInstance(Paths.get(tokenizerJson)) } .onFailure { e -> e.printStackTrace() } .getOrThrow() } + + /** + * Onnx-runtime environment - среда исполнения модели + */ private val env: OrtEnvironment by lazy { OrtEnvironment.getEnvironment() ?: throw Exception("Failed to get ORT environment") } + + /** + * Onnx-runtime session - сессия среды исполнения модели + */ private val session: OrtSession by lazy { val s = env.createSession(modelPath, SessionOptions()) ?: throw Exception("Failed to get session") println( @@ -33,7 +47,8 @@ class Inferrer( } /* - seperates tokens into arrays according to class ids + Расширение для разбора результатов инференса + separates tokens into arrays according to class ids below is the relation from class id to the label "id2label": { @@ -59,8 +74,12 @@ class Inferrer( private fun findMaxIndex(arr: FloatArray): Int = arr.indices.maxBy { arr[it] } + /** + * Инференс - главный метод вычисления результатов машинного анализа + */ fun infer(inputText: String) = try { + // Выполняем предварительное кодирования текста в массивы val encoding = try { tokenizer.encode(inputText) } catch (ioException: IOException) { @@ -68,7 +87,8 @@ class Inferrer( throw ioException } - val tokens = encoding.tokens ?: throw Exception("No tokens detected") // tokenize the text + val tokens = encoding.tokens ?: throw Exception("No tokens detected") // извлечение токенов + // Формируем входные данные для модели val modelInputs = mapOf( "input_ids" to OnnxTensor.createTensor( env, @@ -80,13 +100,19 @@ class Inferrer( ), ) + // Объект для хранения результатов инференса val inferringResult = InferringResult() - @Suppress("UNCHECKED_CAST") + + // Выполняем инференс session.run(modelInputs) + // извлекаем результат инференса и преобразуем в нужный формат ?.firstOrNull() ?.value ?.value - ?.let { it as? Array> } + ?.let { + @Suppress("UNCHECKED_CAST") + it as? Array> + } ?.firstOrNull() ?.forEachIndexed { i, logits0i -> try { @@ -96,15 +122,17 @@ class Inferrer( } } ?: throw Exception("Empty result") + + // выводим результат инференса inferringResult.displayResult(tokens) } catch (e: OrtException) { e.printStackTrace() } + /** + * Вывод результатов в консоль + */ private fun InferringResult.displayResult(tokens: Array) { - /* - * Showing the results - * */ val tokensSpecialChar = tokens[1][0].toString() // word seperators in tokens println("All persons in the text: ${persons.cleanResult(tokensSpecialChar)}") println("All Organizations in the text: ${organizations.cleanResult(tokensSpecialChar)}") @@ -112,6 +140,9 @@ class Inferrer( println("All Miscellanous entities in the text: ${misc.cleanResult(tokensSpecialChar)}") } + /** + * Вспомогательная функция для вывода результатов инференса в консоль + */ private fun String.cleanResult(tokensSpecialChar: String) = split(tokensSpecialChar.toRegex()) .dropLastWhile { it.isEmpty() } .filter { it.isNotBlank() } diff --git a/ok-marketplace-ml/src/main/kotlin/InferringResult.kt b/ok-marketplace-ml/src/main/kotlin/InferringResult.kt index ce4cc64..da8fbca 100644 --- a/ok-marketplace-ml/src/main/kotlin/InferringResult.kt +++ b/ok-marketplace-ml/src/main/kotlin/InferringResult.kt @@ -1,6 +1,21 @@ +/** + * Модель для представления результатов инференса + */ data class InferringResult( + /** + * Персоны в тексте + */ var persons: String = "", + /** + * Локации в тексте + */ var locations: String = "", + /** + * Организации в тексте + */ var organizations: String = "", + /** + * Остальные значимые элементы в тексте + */ var misc: String = "", )