模型一切就緒後緊接著重頭戲就是將模型放到 Android 上讀取並進行預測。首先執行 Android Studio 並開啟一個新專案,其中我們要建立一個 Ktolin 基底的專案。
接著打開 build.gradel(app)
新增 tensorflow-lite
的套件。
implementation "org.tensorflow:tensorflow-lite:+"
另外為了避免簽署生成 apk 期間壓縮我們的模型,我們需要在該檔案內 android{ }
中加入以下描述:
aaptOptions {
noCompress "tflite"
noCompress "lite"
}
接著建立一個 assets
資料夾放入稍早所轉換好的 .tflite
模型,並將此資料夾放在專案資料夾中 app -> src -> main
的位置。
我們先處理 layout 畫面,首先建立一個 EditText
提供使用者輸入數值,並有一個按鈕送(Button)出並觸發模型預測。最後將預測結果顯示在畫面上。
<?xml version="1.0" encoding="utf-8"?>
<androidx.constraintlayout.widget.ConstraintLayout xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:app="http://schemas.android.com/apk/res-auto"
xmlns:tools="http://schemas.android.com/tools"
android:layout_width="match_parent"
android:layout_height="match_parent"
tools:context=".MainActivity">
<EditText
android:id="@+id/numberField"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:ems="10"
android:inputType="number"
app:layout_constraintBottom_toTopOf="@+id/guideline"
app:layout_constraintEnd_toEndOf="parent"
app:layout_constraintStart_toStartOf="parent" />
<androidx.constraintlayout.widget.Guideline
android:id="@+id/guideline"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:orientation="horizontal"
app:layout_constraintGuide_percent="0.5" />
<Button
android:id="@+id/btnPredict"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:layout_marginTop="16dp"
android:text="Predict"
app:layout_constraintEnd_toEndOf="@+id/numberField"
app:layout_constraintStart_toStartOf="@+id/numberField"
app:layout_constraintTop_toBottomOf="@+id/numberField" />
<TextView
android:id="@+id/txtResult"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:layout_marginTop="32dp"
android:fontFamily="sans-serif-black"
android:textColor="@android:color/black"
android:textSize="18sp"
app:layout_constraintEnd_toEndOf="@+id/btnPredict"
app:layout_constraintStart_toStartOf="@+id/btnPredict"
app:layout_constraintTop_toBottomOf="@+id/btnPredict" />
</androidx.constraintlayout.widget.ConstraintLayout>
接著開啟 MainActivity
撰寫主程式,首先建立一個 initInterpreter()
函數載入模型並初始化。其中 options 是對模型的一些資源設定,例如我們設定使用 4 個執行緒。以及設定使用 setUseNNAPI
,Android Neural Networks API (NNAPI) 是一個 Android C API,專門為在邊緣設備上針對機器學習運行計算密集型運算而設計。因為我們在這次範例中使用很簡單的網路層架構,如果是影像辨識專案有使用到許多卷積層相關的 API 那麼該模型是可能無法進行 tflite-android 的 NNAPI 加速的。loadModelFile()
函式負責去讀取 regression.tflite
並提供模型初始化。第三個函數是 doInference()
負責接收使用者輸入的數值,並丟入模型預測。我們可以發現 TFLite 一樣是透過 interpreter
進行模型預測,我們需要事先將輸出的變數建立一個空陣列並且使用 FloatArray
。
class MainActivity : AppCompatActivity() {
private lateinit var interpreter: Interpreter
private val mModelPath = "regression.tflite"
private lateinit var resultText : TextView
private lateinit var editText : EditText
private lateinit var checkButton : Button
override fun onCreate(savedInstanceState: Bundle?) {
super.onCreate(savedInstanceState)
setContentView(R.layout.activity_main)
resultText = findViewById(R.id.txtResult)
editText = findViewById(R.id.numberField)
checkButton = findViewById(R.id.btnPredict)
checkButton.setOnClickListener {
var result = doInference(editText.text.toString())
runOnUiThread {
resultText.text = result.toString()
}
}
initInterpreter()
}
private fun initInterpreter(){
val options = Interpreter.Options()
options.setNumThreads(4)
options.setUseNNAPI(true)
interpreter = Interpreter(loadModelFile(assets, mModelPath), options)
}
private fun doInference(inputString: String): Float {
val inputVal = FloatArray(1)
inputVal[0] = inputString.toFloat()
val output = Array(1) { FloatArray(1) }
interpreter.run(inputVal, output)
return output[0][0]
}
private fun loadModelFile(assetManager: AssetManager, modelPath: String): MappedByteBuffer {
val fileDescriptor = assetManager.openFd(modelPath)
val inputStream = FileInputStream(fileDescriptor.fileDescriptor)
val fileChannel = inputStream.channel
val startOffset = fileDescriptor.startOffset
val declaredLength = fileDescriptor.declaredLength
return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength)
}
}