-
Notifications
You must be signed in to change notification settings - Fork 121
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Edge] Merge pull request #198 from continue-revolution/master
Android TFLite Support
- Loading branch information
Showing
46 changed files
with
2,359 additions
and
18 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,9 +2,13 @@ | |
|
||
This document contains explanation and instruction of aggregation for mobiles. | ||
|
||
An example android aggregator accompanied by an [sample android app](https://github.com/SymbioticLab/FedScale/fedscale/cloud/execution/android). The android app has [MNN](https://github.com/alibaba/MNN) backend support. | ||
An example android aggregator accompanied by | ||
- [MNN](https://github.com/SymbioticLab/FedScale/fedscale/edge/mnn/). The android app has [MNN](https://github.com/alibaba/MNN) backend support for training and testing. | ||
- [TFLite](https://github.com/SymbioticLab/FedScale/fedscale/edge/tflite/). The android app has [TFLite](https://www.tensorflow.org/lite) backend support for training and testing. | ||
|
||
`fedscale/cloud/aggregation/android_aggregator.py` contains an inherited version of aggregator. While keeping all functionalities of the original [aggregator](https://github.com/SymbioticLab/FedScale/blob/master/fedscale/cloud/aggregation/aggregator.py), it adds support to do bijective conversion between PyTorch model and MNN model. It uses JSON to communicate with android client. | ||
## MNN | ||
|
||
`fedscale/cloud/aggregation/aggregator_mnn.py` contains an inherited version of aggregator. While keeping all functionalities of the original [aggregator](https://github.com/SymbioticLab/FedScale/blob/master/fedscale/cloud/aggregation/aggregator.py), it adds support to do bijective conversion between PyTorch model and MNN model. It uses JSON to communicate with android client. | ||
|
||
**Note**: | ||
MNN does not support direct conversion from MNN to PyTorch model, so we did a manual conversion from MNN to JSON, then from JSON to PyTorch model. We currently only support Convolution (including Linear) and BatchNorm conversion. We welcome contribution to support more conversion for operators with trainable parameters. | ||
|
@@ -22,9 +26,28 @@ cd FedScale | |
source install.sh | ||
pip install -e . | ||
cd fedscale/cloud/aggregation | ||
python3 android_aggregator.py --experiment_mode=mobile --num_participants=1 --model=linear | ||
python3 aggregator_mnn.py --experiment_mode=mobile --num_participants=1 --model=linear | ||
``` | ||
and configure your android app according to the [tutorial](https://github.com/SymbioticLab/FedScale/fedscale/edge/mnn/README.md). | ||
|
||
## TFLite | ||
|
||
`fedscale/cloud/aggregation/aggregator_tflite.py` contains an inherited version of aggregator. While keeping all functionalities of the original [aggregator](https://github.com/SymbioticLab/FedScale/blob/master/fedscale/cloud/aggregation/aggregator.py), it adds support to do bijective conversion between tensorflow model and TFLite model. | ||
|
||
`fedscale/utils/models/tflite_model_provider.py` contains a simple linear model with Flatten->Dense->Dense, used for simple test of our sample android app. Please feel free to contribute to it and add more models. | ||
|
||
`fedscale/cloud/internal/tflite_model_adapter.py` defer from `fedscale/cloud/internal/tensorflow_model_adapter.py` in the way that TFLite adapter skip layers without weights, such as Flatten. | ||
|
||
In order to run this aggregator with default setting in order to test sample app, please run | ||
``` | ||
git clone https://github.com/SymbioticLab/FedScale.git | ||
cd FedScale | ||
source install.sh | ||
pip install -e . | ||
cd fedscale/cloud/aggregation | ||
python3 aggregator_tflite.py --experiment_mode=mobile --num_participants=1 --engine=tensorflow | ||
``` | ||
and configure your android app according to the [tutorial](https://github.com/SymbioticLab/FedScale/fedscale/cloud/execution/android/README.md). | ||
and configure your android app according to the [tutorial](https://github.com/SymbioticLab/FedScale/fedscale/edge/tflite/README.md). | ||
|
||
--- | ||
If you need any other help, feel free to contact FedScale team or the developer [website](https://continue-revolution.github.io) [email](mailto:[email protected]) of this android aggregator. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
import numpy as np | ||
|
||
import fedscale.cloud.config_parser as parser | ||
from fedscale.cloud.aggregation.aggregator import Aggregator | ||
from fedscale.cloud.internal.tflite_model_adapter import TFLiteModelAdapter | ||
from fedscale.utils.models.tflite_model_provider import * | ||
|
||
|
||
class TFLiteAggregator(Aggregator): | ||
"""This aggregator collects training/testing feedbacks from Android TFLite APPs. | ||
Args: | ||
args (dictionary): Variable arguments for FedScale runtime config. | ||
Defaults to the setup in arg_parser.py. | ||
""" | ||
|
||
def __init__(self, args): | ||
super().__init__(args) | ||
self.tflite_model = None | ||
|
||
def init_model(self): | ||
""" | ||
Load the model architecture and convert to TFLite. | ||
""" | ||
self.model_wrapper = TFLiteModelAdapter( | ||
build_simple_linear(self.args)) | ||
self.tflite_model = convert_and_save( | ||
TFLiteModel(self.model_wrapper.get_model())) | ||
self.model_weights = self.model_wrapper.get_weights() | ||
|
||
def update_weight_aggregation(self, update_weights): | ||
""" | ||
Update model when the round completes. | ||
Then convert new model to TFLite. | ||
Args: | ||
update_weights (list): A list of global model weight in last round. | ||
""" | ||
super().update_weight_aggregation(update_weights) | ||
if self.model_in_update == self.tasks_round: | ||
self.tflite_model = convert_and_save( | ||
TFLiteModel(self.model_wrapper.get_model())) | ||
|
||
def deserialize_response(self, responses): | ||
""" | ||
Deserialize the response from executor. | ||
If the response contains mnn json model, convert to pytorch state_dict. | ||
Args: | ||
responses (byte stream): Serialized response from executor. | ||
Returns: | ||
string, bool, or bytes: The deserialized response object from executor. | ||
""" | ||
data = super().deserialize_response(responses) | ||
if "update_weight" in data: | ||
path = f'cache/{data["client_id"]}.ckpt' | ||
with open(path, 'wb') as model_file: | ||
model_file.write(data["update_weight"]) | ||
restored_tensors = [ | ||
np.asarray(tf.raw_ops.Restore(file_pattern=path, tensor_name=var.name, | ||
dt=var.dtype, name='restore')) for var in self.model_wrapper.get_model().weights if var.trainable] | ||
os.remove(path) | ||
data["update_weight"] = restored_tensors | ||
return data | ||
|
||
def serialize_response(self, responses): | ||
""" | ||
Serialize the response to send to server upon assigned job completion. | ||
If the responses is the pytorch model, change it to mnn_json. | ||
Args: | ||
responses (ServerResponse): Serialized response from server. | ||
Returns: | ||
bytes: The serialized response object to server. | ||
""" | ||
if type(responses) is list: | ||
responses = self.tflite_model | ||
return super().serialize_response(responses) | ||
|
||
|
||
if __name__ == "__main__": | ||
aggregator = TFLiteAggregator(parser.args) | ||
aggregator.run() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
from typing import List | ||
|
||
import numpy as np | ||
import tensorflow as tf | ||
|
||
from fedscale.cloud.internal.model_adapter_base import ModelAdapterBase | ||
|
||
|
||
class TFLiteModelAdapter(ModelAdapterBase): | ||
def __init__(self, model: tf.keras.Model): | ||
self.model = model | ||
|
||
def set_weights(self, weights: List[np.ndarray]): | ||
for var, weight in zip(self.model.weights, weights): | ||
var.assign(weight) | ||
|
||
def get_weights(self) -> List[np.ndarray]: | ||
return [np.asarray(var.read_value()) for var in self.model.weights if var.trainable] | ||
|
||
def get_model(self): | ||
return self.model |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
assets/*train* | ||
assets/TrainSet | ||
assets/*test* | ||
assets/TestSet | ||
assets/*tflite* | ||
.idea | ||
.gradle | ||
app/build |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
# Android TFLite Sample App | ||
|
||
This directory contains minimum files modified from [MNN Android Demo](https://github.com/alibaba/MNN/tree/master/project/android/demo) and [TFLite Android Demo](https://www.tensorflow.org/lite/examples/on_device_training/overview). The training and testing will be conducted by TFLite backend, while the task execution and communication with server will be managed by Java. The sample has been tested upon image classification with a simple linear model and a small subset of [ImageNet-MINI](https://www.kaggle.com/datasets/ifigotin/imagenetmini-1000). This documentation contains a step-by-step tutorial on how to download, build and config this app on your own device, and modify this app for your own implementation and deployment. | ||
|
||
## Download and build sample android app | ||
|
||
1. Download and unzip [sample dataset (TrainTest.zip)](https://drive.google.com/file/d/1nfi3SVzjaE0LPxwj_5DNdqi6rK7BU8kb/view?usp=sharing) to `assets/` directory. Remove `TrainTest.zip` after unzip to save space on your mobile device. After unzip, you should see 3 files and 2 directories under `assets/`: | ||
1. `TrainSet`: Training set directory, contains 316 images. | ||
2. `TestSet`: Testing set directory, contains 34 images. | ||
3. `conf.json`: Configuration file for mobile app. | ||
4. `train_labels.txt`: Training label file with format `<filename> <label>`, where `<filename>` is the path after `TrainSet/`. | ||
5. `test_labels.txt`: Testing label file with the same format as `train_labels.txt`. | ||
2. Install [Android Studio](https://developer.android.com/studio) and open project `fedscale/edge/tflite`. Download necessary SDKs, NDKs and CMake when prompted. My version: | ||
- SDK: API 32 | ||
- Android Gradle Plugin Version: 3.5.3 | ||
- Gradle Version: 5.4.1 | ||
- Source Compatibility: Java 8 | ||
- Target Compatibility: Java 8 | ||
3. Make project. Android Studio will compile and build the app for you. | ||
|
||
## Test this app with default setting | ||
|
||
1. ssh to your own server and run | ||
``` | ||
cd fedscale/cloud/aggregation/android | ||
python3 aggregator_tflite.py --experiment_mode=mobile --num_participants=1 --engine=tensorflow | ||
``` | ||
2. Change aggregator IP address inside `assets/conf.json` and click `Run` inside Android Studio. | ||
|
||
## Customize your own app | ||
|
||
1. If you want to use your own dataset, please put your data under `assets/TrainSet` and `assets/TestSet`, make sure that your label has the same format as my label file. | ||
1. If you want to change the file/dir name under `assets`, please make sure to change the corresponding config in `assets` attribute inside `assets/conf.json`. | ||
2. If you want to use your own model for **image classification**, please either change `channel`, `width` and `height` inside `assets/conf.json` to your own input and change `num_classes` to your own classes, or override these attributes when sending `CLIENT_TRAIN` request. | ||
3. If you want to use your own model for tasks other than image classification, you may need to write your own TFLite trainer and tester. Please refer to [TFLite](https://www.tensorflow.org/lite/api_docs) for further development guide. You may also need to change `channel`, `width` and `height` inside `assets/conf.json` to your own input and change or remove `num_classes`. | ||
|
||
---- | ||
If you need any other help, feel free to contact FedScale team or the developer [website](https://continue-revolution.github.io) [email](mailto:[email protected]) of this app. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
buildscript { | ||
repositories { | ||
maven { | ||
url "https://plugins.gradle.org/m2/" | ||
} | ||
} | ||
dependencies { | ||
classpath "com.google.protobuf:protobuf-gradle-plugin:0.8.10" | ||
} | ||
} | ||
|
||
apply plugin: 'com.android.application' | ||
apply plugin: 'com.google.protobuf' | ||
|
||
android { | ||
|
||
aaptOptions { | ||
noCompress "tflite" | ||
} | ||
|
||
compileSdkVersion 32 | ||
compileOptions { | ||
sourceCompatibility JavaVersion.VERSION_1_8 | ||
targetCompatibility JavaVersion.VERSION_1_8 | ||
} | ||
|
||
defaultConfig { | ||
applicationId "com.fedscale.android.executor" | ||
minSdkVersion 23 | ||
targetSdkVersion 32 | ||
multiDexEnabled true | ||
versionCode 1 | ||
versionName "1.0" | ||
testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner" | ||
} | ||
|
||
buildTypes { | ||
release { | ||
minifyEnabled false | ||
proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro' | ||
} | ||
} | ||
sourceSets { | ||
main { | ||
assets { | ||
srcDirs = ["../assets"] | ||
} | ||
} | ||
} | ||
// dataBinding { | ||
// enabled = true | ||
// } | ||
packagingOptions { | ||
exclude 'META-INF/com.android.tools/proguard/coroutines.pro' | ||
} | ||
} | ||
|
||
dependencies { | ||
|
||
implementation fileTree(dir: 'libs', include: ['*.jar']) | ||
|
||
// App compat and UI things | ||
implementation 'androidx.appcompat:appcompat:1.5.0' | ||
implementation 'com.google.android.material:material:1.6.1' | ||
implementation 'androidx.constraintlayout:constraintlayout:2.1.4' | ||
|
||
// You need to build grpc-java to obtain these libraries below. | ||
implementation 'javax.annotation:javax.annotation-api:1.3.2' | ||
implementation 'io.grpc:grpc-okhttp:1.49.0' // CURRENT_GRPC_VERSION | ||
implementation 'io.grpc:grpc-protobuf-lite:1.49.0' // CURRENT_GRPC_VERSION | ||
implementation 'io.grpc:grpc-stub:1.49.0' // CURRENT_GRPC_VERSION | ||
implementation 'net.razorvine:pickle:1.3' | ||
|
||
//WindowManager | ||
implementation 'androidx.window:window:1.1.0-alpha03' | ||
|
||
// Unit testing | ||
testImplementation 'junit:junit:4.13.2' | ||
|
||
// Instrumented testing | ||
androidTestImplementation 'androidx.test.ext:junit:1.1.3' | ||
androidTestImplementation 'androidx.test.espresso:espresso-core:3.4.0' | ||
|
||
// Tensorflow lite dependencies | ||
implementation 'org.tensorflow:tensorflow-lite:2.9.0' | ||
implementation 'org.tensorflow:tensorflow-lite-gpu:2.9.0' | ||
implementation 'org.tensorflow:tensorflow-lite-support:0.4.2' | ||
implementation 'org.tensorflow:tensorflow-lite-select-tf-ops:2.9.0' | ||
} | ||
|
||
protobuf { | ||
protoc { artifact = 'com.google.protobuf:protoc:3.21.1' } | ||
plugins { | ||
grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.49.0' // CURRENT_GRPC_VERSION | ||
} | ||
} | ||
generateProtoTasks { | ||
all().each { task -> | ||
task.builtins { | ||
java { option 'lite' } | ||
} | ||
task.plugins { | ||
grpc { // Options added to --grpc_out | ||
option 'lite' } | ||
} | ||
} | ||
} | ||
} |
Oops, something went wrong.