From 9bea9a962a6888e2a3f14fc0e3cde32536862988 Mon Sep 17 00:00:00 2001 From: zhengbomo Date: Sat, 4 Dec 2021 16:20:03 +0800 Subject: [PATCH 1/3] migrated to the Android embedding v2 --- android/build.gradle | 4 +- .../java/sq/flutter/tflite/TflitePlugin.java | 62 +++++++++++++++---- 2 files changed, 51 insertions(+), 15 deletions(-) diff --git a/android/build.gradle b/android/build.gradle index 8002459..f2b8004 100644 --- a/android/build.gradle +++ b/android/build.gradle @@ -33,7 +33,7 @@ android { } dependencies { - compile 'org.tensorflow:tensorflow-lite:+' - compile 'org.tensorflow:tensorflow-lite-gpu:+' + implementation 'org.tensorflow:tensorflow-lite:+' + implementation 'org.tensorflow:tensorflow-lite-gpu:+' } } diff --git a/android/src/main/java/sq/flutter/tflite/TflitePlugin.java b/android/src/main/java/sq/flutter/tflite/TflitePlugin.java index 2579d35..29aa626 100644 --- a/android/src/main/java/sq/flutter/tflite/TflitePlugin.java +++ b/android/src/main/java/sq/flutter/tflite/TflitePlugin.java @@ -1,5 +1,6 @@ package sq.flutter.tflite; +import android.app.Activity; import android.content.Context; import android.content.res.AssetFileDescriptor; import android.content.res.AssetManager; @@ -19,11 +20,20 @@ import android.renderscript.Type; import android.util.Log; +import androidx.annotation.NonNull; + +import io.flutter.FlutterInjector; +import io.flutter.embedding.android.FlutterActivity; +import io.flutter.embedding.engine.loader.FlutterLoader; +import io.flutter.embedding.engine.plugins.activity.ActivityAware; +import io.flutter.embedding.engine.plugins.FlutterPlugin; +import io.flutter.embedding.engine.plugins.activity.ActivityPluginBinding; +import io.flutter.plugin.common.BinaryMessenger; import io.flutter.plugin.common.MethodCall; import io.flutter.plugin.common.MethodChannel; import io.flutter.plugin.common.MethodChannel.MethodCallHandler; import io.flutter.plugin.common.MethodChannel.Result; -import io.flutter.plugin.common.PluginRegistry.Registrar; +import io.flutter.plugin.common.PluginRegistry; import org.tensorflow.lite.DataType; import org.tensorflow.lite.Interpreter; @@ -52,8 +62,8 @@ import java.util.Vector; -public class TflitePlugin implements MethodCallHandler { - private final Registrar mRegistrar; +public class TflitePlugin implements FlutterPlugin, MethodCallHandler, ActivityAware { + private Activity activity; private Interpreter tfLite; private boolean tfLiteBusy = false; private int inputSize = 0; @@ -82,17 +92,41 @@ public class TflitePlugin implements MethodCallHandler { List parentToChildEdges = new ArrayList<>(); List childToParentEdges = new ArrayList<>(); - public static void registerWith(Registrar registrar) { - final MethodChannel channel = new MethodChannel(registrar.messenger(), "tflite"); - channel.setMethodCallHandler(new TflitePlugin(registrar)); + private MethodChannel channel; + + @Override + public void onAttachedToEngine(@NonNull FlutterPluginBinding flutterPluginBinding) { + channel = new MethodChannel(flutterPluginBinding.getBinaryMessenger(), "ddd"); + channel.setMethodCallHandler(this); + } + + @Override + public void onAttachedToActivity(@NonNull ActivityPluginBinding binding) { + activity = binding.getActivity(); + } + + @Override + public void onDetachedFromActivity() { + activity = null; + } + + @Override + public void onDetachedFromActivityForConfigChanges() { + this.onDetachedFromActivity(); } - private TflitePlugin(Registrar registrar) { - this.mRegistrar = registrar; + @Override + public void onReattachedToActivityForConfigChanges(@NonNull ActivityPluginBinding binding) { + this.onAttachedToActivity(binding); + } + + @Override + public void onDetachedFromEngine(@NonNull FlutterPluginBinding binding) { + channel.setMethodCallHandler(null); } @Override - public void onMethodCall(MethodCall call, Result result) { + public void onMethodCall(@NonNull MethodCall call, @NonNull Result result) { if (call.method.equals("loadModel")) { try { String res = loadModel((HashMap) call.arguments); @@ -205,8 +239,9 @@ private String loadModel(HashMap args) throws IOException { String key = null; AssetManager assetManager = null; if (isAsset) { - assetManager = mRegistrar.context().getAssets(); - key = mRegistrar.lookupKeyForAsset(model); + assetManager = activity.getApplicationContext().getAssets(); + FlutterLoader loader = FlutterInjector.instance().flutterLoader(); + key = loader.getLookupKeyForAsset(model); AssetFileDescriptor fileDescriptor = assetManager.openFd(key); FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor()); FileChannel fileChannel = inputStream.getChannel(); @@ -238,7 +273,8 @@ private String loadModel(HashMap args) throws IOException { if (labels.length() > 0) { if (isAsset) { - key = mRegistrar.lookupKeyForAsset(labels); + FlutterLoader loader = FlutterInjector.instance().flutterLoader(); + key = loader.getLookupKeyForAsset(labels); loadLabels(assetManager, key); } else { loadLabels(null, labels); @@ -411,7 +447,7 @@ ByteBuffer feedInputTensorFrame(List bytesList, int imageHeight, int ima Bitmap bitmapRaw = Bitmap.createBitmap(imageWidth, imageHeight, Bitmap.Config.ARGB_8888); Allocation bmData = renderScriptNV21ToRGBA888( - mRegistrar.context(), + activity.getApplicationContext(), imageWidth, imageHeight, data); From 1d6e4369bd29d6d52deb01dc31e5a2a53ca30256 Mon Sep 17 00:00:00 2001 From: zhengbomo Date: Sat, 4 Dec 2021 16:50:34 +0800 Subject: [PATCH 2/3] update channel name --- android/src/main/java/sq/flutter/tflite/TflitePlugin.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/android/src/main/java/sq/flutter/tflite/TflitePlugin.java b/android/src/main/java/sq/flutter/tflite/TflitePlugin.java index 29aa626..54582c1 100644 --- a/android/src/main/java/sq/flutter/tflite/TflitePlugin.java +++ b/android/src/main/java/sq/flutter/tflite/TflitePlugin.java @@ -96,7 +96,7 @@ public class TflitePlugin implements FlutterPlugin, MethodCallHandler, ActivityA @Override public void onAttachedToEngine(@NonNull FlutterPluginBinding flutterPluginBinding) { - channel = new MethodChannel(flutterPluginBinding.getBinaryMessenger(), "ddd"); + channel = new MethodChannel(flutterPluginBinding.getBinaryMessenger(), "tflite"); channel.setMethodCallHandler(this); } From 08c3df135fd5dbb6de21482d6247ec9ce7f81ee4 Mon Sep 17 00:00:00 2001 From: bomo Date: Thu, 16 Dec 2021 21:59:54 +0800 Subject: [PATCH 3/3] rename global variable name --- ios/Classes/TflitePlugin.mm | 104 ++++++++++++++++++------------------ 1 file changed, 52 insertions(+), 52 deletions(-) diff --git a/ios/Classes/TflitePlugin.mm b/ios/Classes/TflitePlugin.mm index b1871b3..049891a 100644 --- a/ios/Classes/TflitePlugin.mm +++ b/ios/Classes/TflitePlugin.mm @@ -1116,13 +1116,13 @@ void runSegmentationOnFrame(NSDictionary* args, FlutterResult result) { } -NSArray* part_names = @[ +NSArray* _tflite_part_names = @[ @"nose", @"leftEye", @"rightEye", @"leftEar", @"rightEar", @"leftShoulder", @"rightShoulder", @"leftElbow", @"rightElbow", @"leftWrist", @"rightWrist", @"leftHip", @"rightHip", @"leftKnee", @"rightKnee", @"leftAnkle", @"rightAnkle" ]; -NSArray* pose_chain = @[ +NSArray* _tflite_pose_chain = @[ @[@"nose", @"leftEye"], @[@"leftEye", @"leftEar"], @[@"nose", @"rightEye"], @[@"rightEye", @"rightEar"], @[@"nose", @"leftShoulder"], @[@"leftShoulder", @"leftElbow"], @[@"leftElbow", @"leftWrist"], @@ -1133,23 +1133,23 @@ void runSegmentationOnFrame(NSDictionary* args, FlutterResult result) { @[@"rightKnee", @"rightAnkle"] ]; -NSMutableDictionary* parts_ids = [NSMutableDictionary dictionary]; -NSMutableArray* parent_to_child_edges = [NSMutableArray array]; -NSMutableArray* child_to_parent_edges = [NSMutableArray array]; -int local_maximum_radius = 1; -int output_stride = 16; -int height; -int width; -int num_keypoints; +NSMutableDictionary* _tflite_parts_ids = [NSMutableDictionary dictionary]; +NSMutableArray* _tflite_parent_to_child_edges = [NSMutableArray array]; +NSMutableArray* _tflite_child_to_parent_edges = [NSMutableArray array]; +int _tflite_local_maximum_radius = 1; +int _tflite_output_stride = 16; +int _tflite_height; +int _tflite_width; +int _tflite_num_keypoints; void initPoseNet() { - if ([parts_ids count] == 0) { - for (int i = 0; i < [part_names count]; ++i) - [parts_ids setValue:[NSNumber numberWithInt:i] forKey:part_names[i]]; + if ([_tflite_parts_ids count] == 0) { + for (int i = 0; i < [_tflite_part_names count]; ++i) + [_tflite_parts_ids setValue:[NSNumber numberWithInt:i] forKey:_tflite_part_names[i]]; - for (int i = 0; i < [pose_chain count]; ++i) { - [parent_to_child_edges addObject:parts_ids[pose_chain[i][1]]]; - [child_to_parent_edges addObject:parts_ids[pose_chain[i][0]]]; + for (int i = 0; i < [_tflite_pose_chain count]; ++i) { + [_tflite_parent_to_child_edges addObject:_tflite_parts_ids[_tflite_pose_chain[i][1]]]; + [_tflite_child_to_parent_edges addObject:_tflite_parts_ids[_tflite_pose_chain[i][0]]]; } } } @@ -1163,12 +1163,12 @@ bool scoreIsMaximumInLocalWindow(int keypoint_id, bool local_maxium = true; int y_start = MAX(heatmap_y - local_maximum_radius, 0); - int y_end = MIN(heatmap_y + local_maximum_radius + 1, height); + int y_end = MIN(heatmap_y + local_maximum_radius + 1, _tflite_height); for (int y_current = y_start; y_current < y_end; ++y_current) { int x_start = MAX(heatmap_x - local_maximum_radius, 0); - int x_end = MIN(heatmap_x + local_maximum_radius + 1, width); + int x_end = MIN(heatmap_x + local_maximum_radius + 1, _tflite_width); for (int x_current = x_start; x_current < x_end; ++x_current) { - if (sigmoid(scores[(y_current * width + x_current) * num_keypoints + keypoint_id]) > score) { + if (sigmoid(scores[(y_current * _tflite_width + x_current) * _tflite_num_keypoints + keypoint_id]) > score) { local_maxium = false; break; } @@ -1188,11 +1188,11 @@ PriorityQueue buildPartWithScoreQueue(float* scores, float threshold, int local_maximum_radius) { PriorityQueue pq; - for (int heatmap_y = 0; heatmap_y < height; ++heatmap_y) { - for (int heatmap_x = 0; heatmap_x < width; ++heatmap_x) { - for (int keypoint_id = 0; keypoint_id < num_keypoints; ++keypoint_id) { - float score = sigmoid(scores[(heatmap_y * width + heatmap_x) * - num_keypoints + keypoint_id]); + for (int heatmap_y = 0; heatmap_y < _tflite_height; ++heatmap_y) { + for (int heatmap_x = 0; heatmap_x < _tflite_width; ++heatmap_x) { + for (int keypoint_id = 0; keypoint_id < _tflite_num_keypoints; ++keypoint_id) { + float score = sigmoid(scores[(heatmap_y * _tflite_width + heatmap_x) * + _tflite_num_keypoints + keypoint_id]); if (score < threshold) continue; if (scoreIsMaximumInLocalWindow(keypoint_id, score, heatmap_y, heatmap_x, @@ -1217,11 +1217,11 @@ void getImageCoords(float* res, int heatmap_x = [keypoint[@"x"] intValue]; int keypoint_id = [keypoint[@"partId"] intValue]; - int offset = (heatmap_y * width + heatmap_x) * num_keypoints * 2 + keypoint_id; + int offset = (heatmap_y * _tflite_width + heatmap_x) * _tflite_num_keypoints * 2 + keypoint_id; float offset_y = offsets[offset]; - float offset_x = offsets[offset + num_keypoints]; - res[0] = heatmap_y * output_stride + offset_y; - res[1] = heatmap_x * output_stride + offset_x; + float offset_x = offsets[offset + _tflite_num_keypoints]; + res[0] = heatmap_y * _tflite_output_stride + offset_y; + res[1] = heatmap_x * _tflite_output_stride + offset_x; } @@ -1244,19 +1244,19 @@ bool withinNmsRadiusOfCorrespondingPoint(NSMutableArray* poses, } void getStridedIndexNearPoint(int* res, float _y, float _x) { - int y_ = round(_y / output_stride); - int x_ = round(_x / output_stride); - int y = y_ < 0 ? 0 : y_ > height - 1 ? height - 1 : y_; - int x = x_ < 0 ? 0 : x_ > width - 1 ? width - 1 : x_; + int y_ = round(_y / _tflite_output_stride); + int x_ = round(_x / _tflite_output_stride); + int y = y_ < 0 ? 0 : y_ > _tflite_height - 1 ? _tflite_height - 1 : y_; + int x = x_ < 0 ? 0 : x_ > _tflite_width - 1 ? _tflite_width - 1 : x_; res[0] = y; res[1] = x; } void getDisplacement(float* res, int edgeId, int* keypoint, float* displacements) { - int num_edges = (int)[parent_to_child_edges count]; + int num_edges = (int)[_tflite_parent_to_child_edges count]; int y = keypoint[0]; int x = keypoint[1]; - int offset = (y * width + x) * num_edges * 2 + edgeId; + int offset = (y * _tflite_width + x) * num_edges * 2 + edgeId; res[0] = displacements[offset]; res[1] = displacements[offset + num_edges]; } @@ -1265,7 +1265,7 @@ float getInstanceScore(NSMutableDictionary* keypoints) { float scores = 0; for (NSMutableDictionary* keypoint in keypoints.allValues) scores += [keypoint[@"score"] floatValue]; - return scores / num_keypoints; + return scores / _tflite_num_keypoints; } NSMutableDictionary* traverseToTargetKeypoint(int edge_id, @@ -1298,25 +1298,25 @@ float getInstanceScore(NSMutableDictionary* keypoints) { int target_keypoint_y = target_keypoint_indices[0]; int target_keypoint_x = target_keypoint_indices[1]; - int offset = (target_keypoint_y * width + target_keypoint_x) * num_keypoints * 2 + target_keypoint_id; + int offset = (target_keypoint_y * _tflite_width + target_keypoint_x) * _tflite_num_keypoints * 2 + target_keypoint_id; float offset_y = offsets[offset]; - float offset_x = offsets[offset + num_keypoints]; + float offset_x = offsets[offset + _tflite_num_keypoints]; - target_keypoint[0] = target_keypoint_y * output_stride + offset_y; - target_keypoint[1] = target_keypoint_x * output_stride + offset_x; + target_keypoint[0] = target_keypoint_y * _tflite_output_stride + offset_y; + target_keypoint[1] = target_keypoint_x * _tflite_output_stride + offset_x; } int target_keypoint_indices[2]; getStridedIndexNearPoint(target_keypoint_indices, target_keypoint[0], target_keypoint[1]); - float score = sigmoid(scores[(target_keypoint_indices[0] * width + - target_keypoint_indices[1]) * num_keypoints + target_keypoint_id]); + float score = sigmoid(scores[(target_keypoint_indices[0] * _tflite_width + + target_keypoint_indices[1]) * _tflite_num_keypoints + target_keypoint_id]); NSMutableDictionary* keypoint = [NSMutableDictionary dictionary]; [keypoint setValue:[NSNumber numberWithFloat:score] forKey:@"score"]; [keypoint setValue:[NSNumber numberWithFloat:target_keypoint[0] / input_size] forKey:@"y"]; [keypoint setValue:[NSNumber numberWithFloat:target_keypoint[1] / input_size] forKey:@"x"]; - [keypoint setValue:part_names[target_keypoint_id] forKey:@"part"]; + [keypoint setValue:_tflite_part_names[target_keypoint_id] forKey:@"part"]; return keypoint; } @@ -1330,9 +1330,9 @@ float getInstanceScore(NSMutableDictionary* keypoints) { assert(interpreter->outputs().size() == 4); TfLiteTensor* scores_tensor = interpreter->tensor(interpreter->outputs()[0]); #endif - height = scores_tensor->dims->data[1]; - width = scores_tensor->dims->data[2]; - num_keypoints = scores_tensor->dims->data[3]; + _tflite_height = scores_tensor->dims->data[1]; + _tflite_width = scores_tensor->dims->data[2]; + _tflite_num_keypoints = scores_tensor->dims->data[3]; #ifdef TFLITE2 float* scores = TfLiteInterpreterGetOutputTensor(interpreter, 0)->data.f; @@ -1345,9 +1345,9 @@ float getInstanceScore(NSMutableDictionary* keypoints) { float* displacements_fwd = interpreter->typed_output_tensor(2); float* displacements_bwd = interpreter->typed_output_tensor(3); #endif - PriorityQueue pq = buildPartWithScoreQueue(scores, threshold, local_maximum_radius); + PriorityQueue pq = buildPartWithScoreQueue(scores, threshold, _tflite_local_maximum_radius); - int num_edges = (int)[parent_to_child_edges count]; + int num_edges = (int)[_tflite_parent_to_child_edges count]; int sqared_nms_radius = nms_radius * nms_radius; NSMutableArray* results = [NSMutableArray array]; @@ -1367,14 +1367,14 @@ float getInstanceScore(NSMutableDictionary* keypoints) { [keypoint setValue:[NSNumber numberWithFloat:[root[@"score"] floatValue]] forKey:@"score"]; [keypoint setValue:[NSNumber numberWithFloat:root_point[0] / input_size] forKey:@"y"]; [keypoint setValue:[NSNumber numberWithFloat:root_point[1] / input_size] forKey:@"x"]; - [keypoint setValue:part_names[[root[@"partId"] intValue]] forKey:@"part"]; + [keypoint setValue:_tflite_part_names[[root[@"partId"] intValue]] forKey:@"part"]; NSMutableDictionary* keypoints = [NSMutableDictionary dictionary]; [keypoints setObject:keypoint forKey:root[@"partId"]]; for (int edge = num_edges - 1; edge >= 0; --edge) { - int source_keypoint_id = [parent_to_child_edges[edge] intValue]; - int target_keypoint_id = [child_to_parent_edges[edge] intValue]; + int source_keypoint_id = [_tflite_parent_to_child_edges[edge] intValue]; + int target_keypoint_id = [_tflite_child_to_parent_edges[edge] intValue]; if (keypoints[[NSNumber numberWithInt:source_keypoint_id]] && !(keypoints[[NSNumber numberWithInt:target_keypoint_id]])) { keypoint = traverseToTargetKeypoint(edge, keypoints[[NSNumber numberWithInt:source_keypoint_id]], @@ -1384,8 +1384,8 @@ float getInstanceScore(NSMutableDictionary* keypoints) { } for (int edge = 0; edge < num_edges; ++edge) { - int source_keypoint_id = [child_to_parent_edges[edge] intValue]; - int target_keypoint_id = [parent_to_child_edges[edge] intValue]; + int source_keypoint_id = [_tflite_child_to_parent_edges[edge] intValue]; + int target_keypoint_id = [_tflite_parent_to_child_edges[edge] intValue]; if (keypoints[[NSNumber numberWithInt:source_keypoint_id]] && !(keypoints[[NSNumber numberWithInt:target_keypoint_id]])) { keypoint = traverseToTargetKeypoint(edge, keypoints[[NSNumber numberWithInt:source_keypoint_id]],