Skip to content

Commit f349da5

Browse files
committed
add pytorch dcgan
1 parent 9e1feff commit f349da5

File tree

13 files changed

+184
-33
lines changed

13 files changed

+184
-33
lines changed

Diff for: build.gradle

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
// Top-level build file where you can add configuration options common to all sub-projects/modules.
22

33
plugins {
4-
id 'com.android.application' version '8.8.0' apply false
5-
id 'com.android.library' version '8.8.0' apply false
4+
id 'com.android.application' version '8.8.1' apply false
5+
id 'com.android.library' version '8.8.1' apply false
66
id 'org.jetbrains.kotlin.android' version '2.0.20' apply false
77
id 'com.google.devtools.ksp' version '2.0.10-1.0.24' apply false
88
id 'org.jetbrains.kotlin.plugin.compose' version '2.0.20' apply false

Diff for: buildSrc/build.gradle

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ dependencies {
99
// implementation localGroovy()
1010
implementation 'org.jetbrains.kotlin:kotlin-stdlib:2.1.0'
1111
// 添加了这个,就可以看 Android Gradle 插件的源码了
12-
implementation 'com.android.tools.build:gradle-api:8.8.0'
12+
implementation 'com.android.tools.build:gradle-api:8.8.1'
1313
// implementation "org.jetbrains.kotlin:kotlin-script-runtime:1.3.40"
1414
implementation 'com.google.code.gson:gson:2.11.0'
1515
implementation 'com.android.tools:common:31.7.3'

Diff for: imitate/src/main/java/com/engineer/imitate/ui/fragments/EntranceFragment.kt

+2-2
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ import androidx.recyclerview.widget.RecyclerView
2020
import com.alibaba.android.arouter.facade.annotation.Route
2121
import com.andrefrsousa.superbottomsheet.SuperBottomSheetFragment
2222
import com.bumptech.glide.Glide
23-
import com.engineer.ai.DigitalClassificationActivity
23+
import com.engineer.ai.AIHomeActivity
2424
import com.engineer.imitate.R
2525
import com.engineer.imitate.databinding.FragmentEntranceBinding
2626
import com.engineer.imitate.ui.activity.CLActivity
@@ -89,7 +89,7 @@ class EntranceFragment : Fragment() {
8989
}
9090

9191
viewBinding.ai.setOnClickListener {
92-
startActivity(Intent(context, DigitalClassificationActivity::class.java))
92+
startActivity(Intent(context, AIHomeActivity::class.java))
9393
}
9494

9595
viewBinding.scanWifi.setOnClickListener {

Diff for: subs/ai/build.gradle

+5-2
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@ dependencies {
5959
}
6060
implementation 'com.google.android.gms:play-services-tasks:18.2.0'
6161

62-
implementation 'org.pytorch:pytorch_android_lite:1.9.0'
63-
implementation 'org.pytorch:pytorch_android_torchvision:1.9.0'
62+
63+
implementation 'org.pytorch:pytorch_android_lite:1.12.2'
64+
implementation 'org.pytorch:pytorch_android_torchvision_lite:1.12.2'
65+
// implementation 'org.pytorch:pytorch_android:1.10.0'
66+
// implementation 'org.pytorch:pytorch_android_torchvision:1.10.0'
6467
}

Diff for: subs/ai/src/main/AndroidManifest.xml

+1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
<manifest xmlns:android="http://schemas.android.com/apk/res/android">
33

44
<application>
5+
<activity android:name=".AIHomeActivity" android:exported="true" />
56
<activity
67
android:name=".DigitalClassificationActivity"
78
android:exported="true" />

Diff for: subs/ai/src/main/assets/dcgan.pt

-16.2 KB
Binary file not shown.
+32
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
package com.engineer.ai
2+
3+
import android.content.Intent
4+
import android.os.Bundle
5+
import android.widget.Button
6+
import android.widget.LinearLayout
7+
import androidx.appcompat.app.AppCompatActivity
8+
9+
10+
class AIHomeActivity : AppCompatActivity() {
11+
12+
private val pages = arrayOf(GanActivity::class.java,
13+
DigitalClassificationActivity::class.java)
14+
15+
override fun onCreate(savedInstanceState: Bundle?) {
16+
super.onCreate(savedInstanceState)
17+
val linerLayout = LinearLayout(this)
18+
linerLayout.orientation = LinearLayout.VERTICAL
19+
20+
for (page in pages) {
21+
val button = Button(this)
22+
button.text = page.simpleName
23+
button.setOnClickListener {
24+
startActivity(Intent(this, page))
25+
}
26+
linerLayout.addView(button)
27+
}
28+
setContentView(linerLayout)
29+
}
30+
31+
32+
}

Diff for: subs/ai/src/main/java/com/engineer/ai/DigitalClassificationActivity.kt

+1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import android.view.MotionEvent
99
import android.widget.Button
1010
import android.widget.TextView
1111
import com.divyanshu.draw.widget.DrawView
12+
import com.engineer.ai.util.DigitClassifier
1213
import org.tensorflow.lite.TensorFlowLite
1314

1415
class DigitalClassificationActivity : AppCompatActivity() {

Diff for: subs/ai/src/main/java/com/engineer/ai/GanActivity.kt

+24-7
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,10 @@ package com.engineer.ai
33
import android.os.Bundle
44
import androidx.appcompat.app.AppCompatActivity
55
import com.engineer.ai.databinding.ActivityGanBinding
6+
import com.engineer.ai.util.AndroidAssetsFileUtil
7+
import com.engineer.ai.util.Utils
68
import org.pytorch.IValue
9+
import org.pytorch.LiteModuleLoader
710
import org.pytorch.Module
811
import org.pytorch.Tensor
912
import java.util.Random
@@ -28,21 +31,35 @@ class GanActivity : AppCompatActivity() {
2831
val zDim = intArrayOf(1, 100)
2932
val outDims = intArrayOf(64, 64, 3)
3033

31-
val z = FloatArray(zDim[0] * outDims[1])
32-
val random = Random(System.currentTimeMillis())
33-
z[0] = random.nextGaussian().toFloat()
34+
val z = FloatArray(zDim[0] * zDim[1])
35+
36+
val rand = Random()
37+
// 生成高斯随机数
38+
for (c in 0 until zDim[0] * zDim[1]) {
39+
z[c] = rand.nextGaussian().toFloat()
40+
}
3441
val shape = longArrayOf(1, 100)
3542
val tensor = Tensor.fromBlob(z, shape)
3643

3744
val resultT = module.forward(IValue.from(tensor)).toTensor()
3845
val resultArray = resultT.dataAsFloatArray
39-
40-
// val img = floatArrayOf(outDims[0],outDims[1],outDims[2])
41-
46+
val resultImg = Array(outDims[0]) { Array(outDims[1]) { FloatArray(outDims[2]) { 0.0f } } }
47+
var index = 0
48+
// 根据输出的一维数组,解析生成的卡通图像
49+
for (j in 0 until outDims[2]) {
50+
for (k in 0 until outDims[0]) {
51+
for (m in 0 until outDims[1]) {
52+
resultImg[k][m][j] = resultArray[index] * 127.5f + 127.5f
53+
index++
54+
}
55+
}
56+
}
57+
val bitmap = Utils.getBitmap(resultImg, outDims)
58+
viewBinding.ganResult.setImageBitmap(bitmap)
4259
}
4360

4461
private fun initModel() {
45-
module = Module.load(AndroidAssetsFileUtil.assetFilePath(this, modelName))
62+
module = LiteModuleLoader.load(AndroidAssetsFileUtil.assetFilePath(this, modelName))
4663
}
4764

4865

Diff for: subs/ai/src/main/java/com/engineer/ai/Mina.kt

-16
This file was deleted.

Diff for: subs/ai/src/main/java/com/engineer/ai/AndroidAssetsFileUtil.kt renamed to subs/ai/src/main/java/com/engineer/ai/util/AndroidAssetsFileUtil.kt

+1-2
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
1-
package com.engineer.ai
1+
package com.engineer.ai.util
22

33
import android.content.Context
44
import java.io.File
55
import java.io.FileOutputStream
6-
import java.io.IOException
76

87
object AndroidAssetsFileUtil {
98

Diff for: subs/ai/src/main/java/com/engineer/ai/DigitClassifier.kt renamed to subs/ai/src/main/java/com/engineer/ai/util/DigitClassifier.kt

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ See the License for the specific language governing permissions and
1111
limitations under the License.
1212
==============================================================================*/
1313

14-
package com.engineer.ai
14+
package com.engineer.ai.util
1515

1616
import android.content.Context
1717
import android.content.res.AssetManager
+114
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
package com.engineer.ai.util;
2+
3+
import android.content.Context;
4+
import android.database.Cursor;
5+
import android.graphics.Bitmap;
6+
import android.graphics.BitmapFactory;
7+
import android.graphics.Matrix;
8+
import android.net.Uri;
9+
import android.provider.MediaStore;
10+
11+
public class Utils {
12+
13+
/**
14+
* @param context 上下文
15+
* @param uri 资源标识
16+
* @return 路径值
17+
*/
18+
public static String getPathFromUri(Context context, Uri uri) {
19+
20+
String result;
21+
Cursor cursor = context.getContentResolver().query(uri, null, null, null, null);
22+
if (cursor == null) {
23+
result = uri.getPath();
24+
} else {
25+
cursor.moveToFirst();
26+
int idx = cursor.getColumnIndex(MediaStore.Images.ImageColumns.DATA);
27+
result = cursor.getString(idx);
28+
cursor.close();
29+
}
30+
return result;
31+
}
32+
33+
/**
34+
* @param image_array 输入的四维数组
35+
* @param dim_info 未读信息
36+
* @return
37+
*/
38+
39+
public static Bitmap getBitmap(float[][][] image_array, int[] dim_info) {
40+
int count = 0;
41+
int[] color_info = new int[dim_info[0] * dim_info[1]];
42+
// 遍历图像,获取颜色信息
43+
for (int i = 0; i < dim_info[0]; i++) {
44+
for (int j = 0; j < dim_info[1]; j++) {
45+
float[] arr = image_array[i][j];
46+
int alpha = 255;
47+
int red = (int) arr[0];
48+
int green = (int) arr[1];
49+
int blue = (int) arr[2];
50+
int tempARGB = (alpha << 24) | (red << 16) | (green << 8) | blue;
51+
color_info[count++] = tempARGB;
52+
}
53+
}
54+
// 创建bitmap对象
55+
return Bitmap.createBitmap(color_info, dim_info[0], dim_info[1], Bitmap.Config.ARGB_8888);
56+
}
57+
58+
/**
59+
* @param filePath 文件路径
60+
* @return Bitmap对象
61+
*/
62+
public static Bitmap getScaleBitmapByPath(String filePath) {
63+
64+
BitmapFactory.Options options = new BitmapFactory.Options();
65+
options.inJustDecodeBounds = true;
66+
BitmapFactory.decodeFile(filePath, options);
67+
int width = options.outWidth;
68+
int height = options.outHeight;
69+
70+
int maxSize = 500;
71+
options.inSampleSize = 1;
72+
while (true) {
73+
if (width / options.inSampleSize < maxSize || height / options.inSampleSize < maxSize) {
74+
break;
75+
}
76+
options.inSampleSize *= 2;
77+
}
78+
options.inJustDecodeBounds = false;
79+
80+
// 返回解码后的图片
81+
return BitmapFactory.decodeFile(filePath, options);
82+
}
83+
84+
85+
/**
86+
* @param origin 原始Bitmap
87+
* @param newWidth 缩放后的宽度
88+
* @param newHeight 缩放后的高度
89+
* @return 缩放后的Bitmap
90+
*/
91+
public static Bitmap getScaleBitmapByBitmap(Bitmap origin, int newWidth, int newHeight) {
92+
// 如果输入的Bitmap为空,则直接返回
93+
if (origin == null) {
94+
return null;
95+
}
96+
// 原始Bitmap的长宽
97+
int height = origin.getHeight();
98+
int width = origin.getWidth();
99+
100+
// 计算缩放后的图像比例
101+
float scaleWidthRatio= ((float) newWidth) / width;
102+
float scaleHeightRatio = ((float) newHeight) / height;
103+
104+
Matrix matrix = new Matrix();
105+
matrix.postScale(scaleWidthRatio, scaleHeightRatio);
106+
107+
// 创建新的Bitmap
108+
Bitmap scaledBitmap = Bitmap.createBitmap(origin, 0, 0, width, height, matrix, false);
109+
if (!origin.isRecycled()) {
110+
origin.recycle();
111+
}
112+
return scaledBitmap;
113+
}
114+
}

0 commit comments

Comments
 (0)