Skip to content

Commit

Permalink
[Edge] Merge pull request #198 from continue-revolution/master
Browse files Browse the repository at this point in the history
Android TFLite Support
  • Loading branch information
fanlai0990 committed Feb 10, 2023
2 parents 62af2cd + 31e3249 commit 2014eef
Show file tree
Hide file tree
Showing 46 changed files with 2,359 additions and 18 deletions.
31 changes: 27 additions & 4 deletions fedscale/cloud/aggregation/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,13 @@

This document contains explanation and instruction of aggregation for mobiles.

An example android aggregator accompanied by an [sample android app](https://github.com/SymbioticLab/FedScale/fedscale/cloud/execution/android). The android app has [MNN](https://github.com/alibaba/MNN) backend support.
An example android aggregator accompanied by
- [MNN](https://github.com/SymbioticLab/FedScale/fedscale/edge/mnn/). The android app has [MNN](https://github.com/alibaba/MNN) backend support for training and testing.
- [TFLite](https://github.com/SymbioticLab/FedScale/fedscale/edge/tflite/). The android app has [TFLite](https://www.tensorflow.org/lite) backend support for training and testing.

`fedscale/cloud/aggregation/android_aggregator.py` contains an inherited version of aggregator. While keeping all functionalities of the original [aggregator](https://github.com/SymbioticLab/FedScale/blob/master/fedscale/cloud/aggregation/aggregator.py), it adds support to do bijective conversion between PyTorch model and MNN model. It uses JSON to communicate with android client.
## MNN

`fedscale/cloud/aggregation/aggregator_mnn.py` contains an inherited version of aggregator. While keeping all functionalities of the original [aggregator](https://github.com/SymbioticLab/FedScale/blob/master/fedscale/cloud/aggregation/aggregator.py), it adds support to do bijective conversion between PyTorch model and MNN model. It uses JSON to communicate with android client.

**Note**:
MNN does not support direct conversion from MNN to PyTorch model, so we did a manual conversion from MNN to JSON, then from JSON to PyTorch model. We currently only support Convolution (including Linear) and BatchNorm conversion. We welcome contribution to support more conversion for operators with trainable parameters.
Expand All @@ -22,9 +26,28 @@ cd FedScale
source install.sh
pip install -e .
cd fedscale/cloud/aggregation
python3 android_aggregator.py --experiment_mode=mobile --num_participants=1 --model=linear
python3 aggregator_mnn.py --experiment_mode=mobile --num_participants=1 --model=linear
```
and configure your android app according to the [tutorial](https://github.com/SymbioticLab/FedScale/fedscale/edge/mnn/README.md).

## TFLite

`fedscale/cloud/aggregation/aggregator_tflite.py` contains an inherited version of aggregator. While keeping all functionalities of the original [aggregator](https://github.com/SymbioticLab/FedScale/blob/master/fedscale/cloud/aggregation/aggregator.py), it adds support to do bijective conversion between tensorflow model and TFLite model.

`fedscale/utils/models/tflite_model_provider.py` contains a simple linear model with Flatten->Dense->Dense, used for simple test of our sample android app. Please feel free to contribute to it and add more models.

`fedscale/cloud/internal/tflite_model_adapter.py` defer from `fedscale/cloud/internal/tensorflow_model_adapter.py` in the way that TFLite adapter skip layers without weights, such as Flatten.

In order to run this aggregator with default setting in order to test sample app, please run
```
git clone https://github.com/SymbioticLab/FedScale.git
cd FedScale
source install.sh
pip install -e .
cd fedscale/cloud/aggregation
python3 aggregator_tflite.py --experiment_mode=mobile --num_participants=1 --engine=tensorflow
```
and configure your android app according to the [tutorial](https://github.com/SymbioticLab/FedScale/fedscale/cloud/execution/android/README.md).
and configure your android app according to the [tutorial](https://github.com/SymbioticLab/FedScale/fedscale/edge/tflite/README.md).

---
If you need any other help, feel free to contact FedScale team or the developer [website](https://continue-revolution.github.io) [email](mailto:[email protected]) of this android aggregator.
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
from fedscale.cloud.aggregation.aggregator import Aggregator
from fedscale.utils.models.simple.linear_model import LinearModel
from fedscale.utils.models.mnn_convert import *
from fedscale.cloud.internal.torch_model_adapter import TorchModelAdapter


class Android_Aggregator(Aggregator):
class MNNAggregator(Aggregator):
"""This aggregator collects training/testing feedbacks from Android MNN APPs.
Args:
Expand All @@ -27,24 +28,24 @@ def init_model(self):
NOTE: MNN does not support dropout.
"""
if self.args.model == 'linear':
self.model = LinearModel()
self.model_weights = self.model.state_dict()
self.model_wrapper = TorchModelAdapter(LinearModel())
self.model_weights = self.model_wrapper.get_weights()
else:
super().init_model()
self.mnn_json = torch_to_mnn(self.model, self.input_shape, True)
self.keymap_mnn2torch = init_keymap(self.model_weights, self.mnn_json)
self.mnn_json = torch_to_mnn(self.model_wrapper.get_model(), self.input_shape, True)
self.keymap_mnn2torch = init_keymap(self.model_wrapper.get_model().state_dict(), self.mnn_json)

def round_weight_handler(self, last_model):
def update_weight_aggregation(self, update_weights):
"""
Update model when the round completes.
Then convert new model to mnn json.
Args:
last_model (list): A list of global model weight in last round.
"""
super().round_weight_handler(last_model)
if self.round > 1:
self.mnn_json = torch_to_mnn(self.model, self.input_shape)
super().update_weight_aggregation(update_weights)
if self.model_in_update == self.tasks_round:
self.mnn_json = torch_to_mnn(self.model_wrapper.get_model(), self.input_shape)

def deserialize_response(self, responses):
"""
Expand Down Expand Up @@ -75,12 +76,12 @@ def serialize_response(self, responses):
Returns:
bytes: The serialized response object to server.
"""
if responses == self.model:
if type(responses) is list and all([np.array_equal(a, b) for a, b in zip(responses, self.model_wrapper.get_weights())]):
responses = self.mnn_json
data = json.dumps(responses)
return data.encode('utf-8')


if __name__ == "__main__":
aggregator = Android_Aggregator(parser.args)
aggregator = MNNAggregator(parser.args)
aggregator.run()
85 changes: 85 additions & 0 deletions fedscale/cloud/aggregation/aggregator_tflite.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import numpy as np

import fedscale.cloud.config_parser as parser
from fedscale.cloud.aggregation.aggregator import Aggregator
from fedscale.cloud.internal.tflite_model_adapter import TFLiteModelAdapter
from fedscale.utils.models.tflite_model_provider import *


class TFLiteAggregator(Aggregator):
"""This aggregator collects training/testing feedbacks from Android TFLite APPs.
Args:
args (dictionary): Variable arguments for FedScale runtime config.
Defaults to the setup in arg_parser.py.
"""

def __init__(self, args):
super().__init__(args)
self.tflite_model = None

def init_model(self):
"""
Load the model architecture and convert to TFLite.
"""
self.model_wrapper = TFLiteModelAdapter(
build_simple_linear(self.args))
self.tflite_model = convert_and_save(
TFLiteModel(self.model_wrapper.get_model()))
self.model_weights = self.model_wrapper.get_weights()

def update_weight_aggregation(self, update_weights):
"""
Update model when the round completes.
Then convert new model to TFLite.
Args:
update_weights (list): A list of global model weight in last round.
"""
super().update_weight_aggregation(update_weights)
if self.model_in_update == self.tasks_round:
self.tflite_model = convert_and_save(
TFLiteModel(self.model_wrapper.get_model()))

def deserialize_response(self, responses):
"""
Deserialize the response from executor.
If the response contains mnn json model, convert to pytorch state_dict.
Args:
responses (byte stream): Serialized response from executor.
Returns:
string, bool, or bytes: The deserialized response object from executor.
"""
data = super().deserialize_response(responses)
if "update_weight" in data:
path = f'cache/{data["client_id"]}.ckpt'
with open(path, 'wb') as model_file:
model_file.write(data["update_weight"])
restored_tensors = [
np.asarray(tf.raw_ops.Restore(file_pattern=path, tensor_name=var.name,
dt=var.dtype, name='restore')) for var in self.model_wrapper.get_model().weights if var.trainable]
os.remove(path)
data["update_weight"] = restored_tensors
return data

def serialize_response(self, responses):
"""
Serialize the response to send to server upon assigned job completion.
If the responses is the pytorch model, change it to mnn_json.
Args:
responses (ServerResponse): Serialized response from server.
Returns:
bytes: The serialized response object to server.
"""
if type(responses) is list:
responses = self.tflite_model
return super().serialize_response(responses)


if __name__ == "__main__":
aggregator = TFLiteAggregator(parser.args)
aggregator.run()
21 changes: 21 additions & 0 deletions fedscale/cloud/internal/tflite_model_adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from typing import List

import numpy as np
import tensorflow as tf

from fedscale.cloud.internal.model_adapter_base import ModelAdapterBase


class TFLiteModelAdapter(ModelAdapterBase):
def __init__(self, model: tf.keras.Model):
self.model = model

def set_weights(self, weights: List[np.ndarray]):
for var, weight in zip(self.model.weights, weights):
var.assign(weight)

def get_weights(self) -> List[np.ndarray]:
return [np.asarray(var.read_value()) for var in self.model.weights if var.trainable]

def get_model(self):
return self.model
4 changes: 2 additions & 2 deletions fedscale/edge/mnn/README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Android Sample App
# Android MNN Sample App

This directory contains minimum files modified from [MNN Android Demo](https://github.com/alibaba/MNN/tree/master/project/android/demo). The training and testing will be conducted by MNN C++ backend, while the task execution and communication with server will be managed by Java. The sample has been tested upon image classification with a simple linear model and a small subset of [ImageNet-MINI](https://www.kaggle.com/datasets/ifigotin/imagenetmini-1000). This documentation contains a step-by-step tutorial on how to download, build and config this app on your own device, and modify this app for your own implementation and deployment.

Expand All @@ -25,7 +25,7 @@ This directory contains minimum files modified from [MNN Android Demo](https://g
1. ssh to your own server and run
```
cd fedscale/cloud/aggregation/android
python3 android_aggregator.py --experiment_mode=mobile --num_participants=1 --model=linear
python3 aggregator_mnn.py --experiment_mode=mobile --num_participants=1 --model=linear
```
2. Change aggregator IP address inside `assets/conf.json` and click `Run` inside Android Studio.

Expand Down
2 changes: 1 addition & 1 deletion fedscale/edge/mnn/jni/mnntrainnative.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ Java_com_fedscale_android_mnn_MNNNative_nativeTrain(
StringBuffer s;
Writer<StringBuffer> writer(s);
writer.StartObject();
writer.Key("clientId");
writer.Key("client_id");
writer.String(clientId.c_str());
writer.Key("moving_loss");
writer.Double(epochTrainLoss);
Expand Down
8 changes: 8 additions & 0 deletions fedscale/edge/tflite/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
assets/*train*
assets/TrainSet
assets/*test*
assets/TestSet
assets/*tflite*
.idea
.gradle
app/build
38 changes: 38 additions & 0 deletions fedscale/edge/tflite/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Android TFLite Sample App

This directory contains minimum files modified from [MNN Android Demo](https://github.com/alibaba/MNN/tree/master/project/android/demo) and [TFLite Android Demo](https://www.tensorflow.org/lite/examples/on_device_training/overview). The training and testing will be conducted by TFLite backend, while the task execution and communication with server will be managed by Java. The sample has been tested upon image classification with a simple linear model and a small subset of [ImageNet-MINI](https://www.kaggle.com/datasets/ifigotin/imagenetmini-1000). This documentation contains a step-by-step tutorial on how to download, build and config this app on your own device, and modify this app for your own implementation and deployment.

## Download and build sample android app

1. Download and unzip [sample dataset (TrainTest.zip)](https://drive.google.com/file/d/1nfi3SVzjaE0LPxwj_5DNdqi6rK7BU8kb/view?usp=sharing) to `assets/` directory. Remove `TrainTest.zip` after unzip to save space on your mobile device. After unzip, you should see 3 files and 2 directories under `assets/`:
1. `TrainSet`: Training set directory, contains 316 images.
2. `TestSet`: Testing set directory, contains 34 images.
3. `conf.json`: Configuration file for mobile app.
4. `train_labels.txt`: Training label file with format `<filename> <label>`, where `<filename>` is the path after `TrainSet/`.
5. `test_labels.txt`: Testing label file with the same format as `train_labels.txt`.
2. Install [Android Studio](https://developer.android.com/studio) and open project `fedscale/edge/tflite`. Download necessary SDKs, NDKs and CMake when prompted. My version:
- SDK: API 32
- Android Gradle Plugin Version: 3.5.3
- Gradle Version: 5.4.1
- Source Compatibility: Java 8
- Target Compatibility: Java 8
3. Make project. Android Studio will compile and build the app for you.

## Test this app with default setting

1. ssh to your own server and run
```
cd fedscale/cloud/aggregation/android
python3 aggregator_tflite.py --experiment_mode=mobile --num_participants=1 --engine=tensorflow
```
2. Change aggregator IP address inside `assets/conf.json` and click `Run` inside Android Studio.

## Customize your own app

1. If you want to use your own dataset, please put your data under `assets/TrainSet` and `assets/TestSet`, make sure that your label has the same format as my label file.
1. If you want to change the file/dir name under `assets`, please make sure to change the corresponding config in `assets` attribute inside `assets/conf.json`.
2. If you want to use your own model for **image classification**, please either change `channel`, `width` and `height` inside `assets/conf.json` to your own input and change `num_classes` to your own classes, or override these attributes when sending `CLIENT_TRAIN` request.
3. If you want to use your own model for tasks other than image classification, you may need to write your own TFLite trainer and tester. Please refer to [TFLite](https://www.tensorflow.org/lite/api_docs) for further development guide. You may also need to change `channel`, `width` and `height` inside `assets/conf.json` to your own input and change or remove `num_classes`.

----
If you need any other help, feel free to contact FedScale team or the developer [website](https://continue-revolution.github.io) [email](mailto:[email protected]) of this app.
108 changes: 108 additions & 0 deletions fedscale/edge/tflite/app/build.gradle
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
buildscript {
repositories {
maven {
url "https://plugins.gradle.org/m2/"
}
}
dependencies {
classpath "com.google.protobuf:protobuf-gradle-plugin:0.8.10"
}
}

apply plugin: 'com.android.application'
apply plugin: 'com.google.protobuf'

android {

aaptOptions {
noCompress "tflite"
}

compileSdkVersion 32
compileOptions {
sourceCompatibility JavaVersion.VERSION_1_8
targetCompatibility JavaVersion.VERSION_1_8
}

defaultConfig {
applicationId "com.fedscale.android.executor"
minSdkVersion 23
targetSdkVersion 32
multiDexEnabled true
versionCode 1
versionName "1.0"
testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
}

buildTypes {
release {
minifyEnabled false
proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro'
}
}
sourceSets {
main {
assets {
srcDirs = ["../assets"]
}
}
}
// dataBinding {
// enabled = true
// }
packagingOptions {
exclude 'META-INF/com.android.tools/proguard/coroutines.pro'
}
}

dependencies {

implementation fileTree(dir: 'libs', include: ['*.jar'])

// App compat and UI things
implementation 'androidx.appcompat:appcompat:1.5.0'
implementation 'com.google.android.material:material:1.6.1'
implementation 'androidx.constraintlayout:constraintlayout:2.1.4'

// You need to build grpc-java to obtain these libraries below.
implementation 'javax.annotation:javax.annotation-api:1.3.2'
implementation 'io.grpc:grpc-okhttp:1.49.0' // CURRENT_GRPC_VERSION
implementation 'io.grpc:grpc-protobuf-lite:1.49.0' // CURRENT_GRPC_VERSION
implementation 'io.grpc:grpc-stub:1.49.0' // CURRENT_GRPC_VERSION
implementation 'net.razorvine:pickle:1.3'

//WindowManager
implementation 'androidx.window:window:1.1.0-alpha03'

// Unit testing
testImplementation 'junit:junit:4.13.2'

// Instrumented testing
androidTestImplementation 'androidx.test.ext:junit:1.1.3'
androidTestImplementation 'androidx.test.espresso:espresso-core:3.4.0'

// Tensorflow lite dependencies
implementation 'org.tensorflow:tensorflow-lite:2.9.0'
implementation 'org.tensorflow:tensorflow-lite-gpu:2.9.0'
implementation 'org.tensorflow:tensorflow-lite-support:0.4.2'
implementation 'org.tensorflow:tensorflow-lite-select-tf-ops:2.9.0'
}

protobuf {
protoc { artifact = 'com.google.protobuf:protoc:3.21.1' }
plugins {
grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.49.0' // CURRENT_GRPC_VERSION
}
}
generateProtoTasks {
all().each { task ->
task.builtins {
java { option 'lite' }
}
task.plugins {
grpc { // Options added to --grpc_out
option 'lite' }
}
}
}
}
Loading

0 comments on commit 2014eef

Please sign in to comment.