Skip to content

Commit

Permalink
Merge pull request #209 from continue-revolution/master
Browse files Browse the repository at this point in the history
Merge TFLite & MNN Clients
  • Loading branch information
fanlai0990 committed Mar 3, 2023
2 parents 2014eef + 04ec45a commit d8fd582
Show file tree
Hide file tree
Showing 159 changed files with 36,084 additions and 1,854 deletions.
8 changes: 7 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,10 @@ landmark
coqa
.idea
out/
build/
build/
.vscode
fedscale/cache
fedscale/tmp*
fedscale/*.tflite
fedscale/*.ckpt
fedscale/cloud/aggregation/cache
10 changes: 3 additions & 7 deletions fedscale/cloud/aggregation/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,7 @@ An example android aggregator accompanied by
**Note**:
MNN does not support direct conversion from an MNN model to a PyTorch model, so we did a manual conversion from MNN to JSON, and then from JSON to a PyTorch model. We currently support conversion only for Convolution (including Linear) and BatchNorm operators. We welcome contributions that add conversion support for more operators with trainable parameters.

`scripts/convert.sh` contains model conversion code. It will clone MNN and build converter. You do not need to manually run this script. This script is run internally inside android aggregator.

`fedscale/utils/models/simple/linear_model.py` contains a simple linear model with Flatten->Linear->Softmax, used for simple test of our sample android app.

`fedscale/utils/models/mnn_convert.py` contains all the code necessary for MNN<->PyTorch model conversion.
`fedscale/utils/models/mnn_model_provider.py` contains all the code necessary for MNN<->PyTorch model conversion, as well as the currently supported PyTorch models that can be converted to MNN without bugs.

To run this aggregator with the default settings and test the sample app, please run
```
Expand All @@ -26,7 +22,7 @@ cd FedScale
source install.sh
pip install -e .
cd fedscale/cloud/aggregation
python3 aggregator_mnn.py --experiment_mode=mobile --num_participants=1 --model=linear
python3 aggregator_mnn.py --experiment_mode mobile --num_participants 1 --num_classes 10 --input_shape 32 32 3 --model linear
```
and configure your android app according to the [tutorial](https://github.com/SymbioticLab/FedScale/fedscale/edge/mnn/README.md).

Expand All @@ -45,7 +41,7 @@ cd FedScale
source install.sh
pip install -e .
cd fedscale/cloud/aggregation
python3 aggregator_tflite.py --experiment_mode=mobile --num_participants=1 --engine=tensorflow
python3 aggregator_tflite.py --experiment_mode mobile --num_participants 1 --num_classes 10 --input_shape 32 32 3 --engine tensorflow --model [linear|mobilenetv3|resnet50|mobilenetv3_finetune|resnet50_finetune] --learning_rate 1e-2
```
and configure your android app according to the [tutorial](https://github.com/SymbioticLab/FedScale/fedscale/edge/tflite/README.md).

Expand Down
45 changes: 22 additions & 23 deletions fedscale/cloud/aggregation/aggregator_mnn.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
import json

import fedscale.cloud.config_parser as parser
from fedscale.cloud.aggregation.aggregator import Aggregator
from fedscale.utils.models.simple.linear_model import LinearModel
from fedscale.utils.models.mnn_convert import *
from fedscale.utils.models.mnn_model_provider import *
from fedscale.cloud.internal.torch_model_adapter import TorchModelAdapter


Expand All @@ -14,55 +11,58 @@ class MNNAggregator(Aggregator):
args (dictionary): Variable arguments for fedscale runtime config.
Defaults to the setup in arg_parser.py.
"""

def __init__(self, args):
super().__init__(args)

# == mnn model and keymap ==
self.mnn_json = None
self.mnn_model = None
self.keymap_mnn2torch = {}
self.input_shape = args.input_shape

def init_model(self):
"""
Load the model architecture and convert to mnn.
NOTE: MNN does not support dropout.
"""
if self.args.model == 'linear':
self.model_wrapper = TorchModelAdapter(LinearModel())
self.model_weights = self.model_wrapper.get_weights()
else:
super().init_model()
self.mnn_json = torch_to_mnn(self.model_wrapper.get_model(), self.input_shape, True)
self.keymap_mnn2torch = init_keymap(self.model_wrapper.get_model().state_dict(), self.mnn_json)
self.model_wrapper = TorchModelAdapter(
get_mnn_model(self.args.model, self.args))
self.model_weights = self.model_wrapper.get_weights()
self.mnn_model = torch_to_mnn(
self.model_wrapper.get_model(), self.input_shape)
self.keymap_mnn2torch = init_keymap(
self.model_wrapper.get_model().state_dict())

def update_weight_aggregation(self, update_weights):
"""
Update model when the round completes.
Then convert new model to mnn json.
Args:
last_model (list): A list of global model weight in last round.
"""
super().update_weight_aggregation(update_weights)
if self.model_in_update == self.tasks_round:
self.mnn_json = torch_to_mnn(self.model_wrapper.get_model(), self.input_shape)
self.mnn_model = torch_to_mnn(
self.model_wrapper.get_model(), self.input_shape)

def deserialize_response(self, responses):
"""
Deserialize the response from executor.
If the response contains mnn json model, convert to pytorch state_dict.
Args:
responses (byte stream): Serialized response from executor.
Returns:
string, bool, or bytes: The deserialized response object from executor.
"""
data = json.loads(responses.decode('utf-8'))
data = super().deserialize_response(responses)
if "update_weight" in data:
data["update_weight"] = mnn_to_torch(
self.keymap_mnn2torch,
json.loads(data["update_weight"]))
data["update_weight"],
data["client_id"])
return data

def serialize_response(self, responses):
Expand All @@ -76,10 +76,9 @@ def serialize_response(self, responses):
Returns:
bytes: The serialized response object to server.
"""
if type(responses) is list and all([np.array_equal(a, b) for a, b in zip(responses, self.model_wrapper.get_weights())]):
responses = self.mnn_json
data = json.dumps(responses)
return data.encode('utf-8')
if type(responses) is list:
responses = self.mnn_model
return super().serialize_response(responses)


if __name__ == "__main__":
Expand Down
132 changes: 120 additions & 12 deletions fedscale/cloud/aggregation/aggregator_tflite.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import os
import numpy as np

import fedscale.cloud.config_parser as parser
from fedscale.cloud.aggregation.aggregator import Aggregator
from fedscale.cloud.internal.tflite_model_adapter import TFLiteModelAdapter
from fedscale.utils.models.tflite_model_provider import *
from fedscale.cloud.channels import job_api_pb2
from fedscale.cloud.fllibs import *


class TFLiteAggregator(Aggregator):
Expand All @@ -17,15 +20,15 @@ class TFLiteAggregator(Aggregator):
def __init__(self, args):
super().__init__(args)
self.tflite_model = None
self.base = None

def init_model(self):
"""
Load the model architecture and convert to TFLite.
"""
self.model_wrapper = TFLiteModelAdapter(
build_simple_linear(self.args))
self.tflite_model = convert_and_save(
TFLiteModel(self.model_wrapper.get_model()))
model, self.base = get_tflite_model(self.args.model, self.args)
self.model_wrapper = TFLiteModelAdapter(model)
self.tflite_model = convert_and_save(model, self.base, self.args)
self.model_weights = self.model_wrapper.get_weights()

def update_weight_aggregation(self, update_weights):
Expand All @@ -39,12 +42,12 @@ def update_weight_aggregation(self, update_weights):
super().update_weight_aggregation(update_weights)
if self.model_in_update == self.tasks_round:
self.tflite_model = convert_and_save(
TFLiteModel(self.model_wrapper.get_model()))
self.model_wrapper.get_model(), self.base, self.args)

def deserialize_response(self, responses):
"""
Deserialize the response from executor.
If the response contains mnn json model, convert to pytorch state_dict.
If the response contains mnn model, convert to pytorch state_dict.
Args:
responses (byte stream): Serialized response from executor.
Expand All @@ -58,27 +61,132 @@ def deserialize_response(self, responses):
with open(path, 'wb') as model_file:
model_file.write(data["update_weight"])
restored_tensors = [
np.asarray(tf.raw_ops.Restore(file_pattern=path, tensor_name=var.name,
dt=var.dtype, name='restore')) for var in self.model_wrapper.get_model().weights if var.trainable]
np.asarray(tf.raw_ops.Restore(
file_pattern=path, tensor_name=var.name,
dt=var.dtype, name='restore')
) for var in self.model_wrapper.get_model().weights]
os.remove(path)
data["update_weight"] = restored_tensors
return data

def serialize_response(self, responses):
"""
Serialize the response to send to server upon assigned job completion.
If the responses is the pytorch model, change it to mnn_json.
""" Serialize the response to send to server upon assigned job completion
Args:
responses (ServerResponse): Serialized response from server.
Returns:
bytes: The serialized response object to server.
"""
if type(responses) is list:
responses = self.tflite_model
return super().serialize_response(responses)

# def create_client_task(self, executor_id):
# """Issue a new client training task to specific executor

# Args:
# executorId (int): Executor Id.

# Returns:
# tuple: Training config for new task. (dictionary, PyTorch or TensorFlow module)

# """
# next_client_id = self.resource_manager.get_next_task(executor_id)
# train_config = None
# # NOTE: model = None then the executor will load the global model broadcasted in UPDATE_MODEL
# if next_client_id is not None:
# config = self.get_client_conf(next_client_id)
# train_config = {'client_id': next_client_id, 'task_config': config}
# return train_config, self.tflite_model_bytes

def CLIENT_PING(self, request, context):
    """Handle client ping requests.

    The event at the head of the executor's queue is reported to the
    executor but, except for SHUT_DOWN, is left in the queue. Any
    UPDATE_MODEL events found at the head of the queue are discarded first.
    If the queue is empty (initially, or after discarding UPDATE_MODEL
    events), a dummy event with a dummy response is returned.

    Args:
        request (PingRequest): Ping request info from executor.
        context: Unused by this handler.

    Returns:
        ServerResponse: Server response to ping request
    """
    # NOTE: client_id = executor_id in deployment,
    # while multiple client_id may use the same executor_id (VMs) in simulations
    executor_id, client_id = request.executor_id, request.client_id

    # Default reply, used whenever there is nothing to hand out.
    current_event = commons.DUMMY_EVENT
    response_data = response_msg = commons.DUMMY_RESPONSE

    pending_events = self.individual_client_events[executor_id]

    # NOTE: This is a temp solution to bypass the following errors:
    # 1. problem: server->client update_model package dropped, server->client model_test in error
    #    solution: ignore update_model, send model in model_test package
    # 2. problem: server->client client_train package dropped, server->client dummy_event forever
    #    solution: keep event inside queue until client confirm event completed
    #    pitfall: simulation executor L388 multi-thread may ping the same event more than once
    #             update_model no confirmation, no way to tell if update_model finished
    #
    # Guard on emptiness while draining: the queue may contain nothing but
    # UPDATE_MODEL events, in which case indexing [0] after the last
    # popleft() would raise IndexError.
    while pending_events and pending_events[0] == commons.UPDATE_MODEL:
        pending_events.popleft()

    if pending_events:
        current_event = pending_events[0]
        if current_event == commons.CLIENT_TRAIN:
            response_msg, response_data = self.create_client_task(
                executor_id)
            if response_msg is None:
                # No task could be created: answer with a dummy event.
                current_event = commons.DUMMY_EVENT
                if self.experiment_mode != commons.SIMULATION_MODE:
                    self.individual_client_events[executor_id].append(
                        commons.CLIENT_TRAIN)
        elif current_event == commons.MODEL_TEST:
            response_msg = self.get_test_config(client_id)
            # The model is shipped together with the test request.
            response_data = self.tflite_model
        elif current_event == commons.SHUT_DOWN:
            response_msg = self.get_shutdown_config(executor_id)
            # SHUT_DOWN is removed here rather than on completion.
            self.individual_client_events[executor_id].popleft()

    response_msg, response_data = self.serialize_response(
        response_msg), self.serialize_response(response_data)
    # NOTE: in simulation mode, response data is pickle for faster (de)serialization
    response = job_api_pb2.ServerResponse(event=current_event,
                                          meta=response_msg, data=response_data)
    if current_event != commons.DUMMY_EVENT:
        logging.info(
            f"Issue EVENT ({current_event}) to EXECUTOR ({executor_id})")

    return response

def CLIENT_EXECUTE_COMPLETION(self, request, context):
"""FL clients complete the execution task.
Only MODEL_TEST and UPLOAD_MODEL completions are processed; any other
event is logged as undefined. In every case the reply is whatever
CLIENT_PING returns for the same request.
Args:
request (CompleteRequest): Complete request info from executor.
context: Forwarded unchanged to CLIENT_PING.
Returns:
ServerResponse: Server response to job completion request
"""

# Unpack the completion report sent by the executor.
executor_id, client_id, event = request.executor_id, request.client_id, request.event
execution_status, execution_msg = request.status, request.msg
meta_result, data_result = request.meta_result, request.data_result

if event in (commons.MODEL_TEST, commons.UPLOAD_MODEL):
# `is False` is an identity check: only a literal False is treated as failure.
if execution_status is False:
logging.error(
f"Executor {executor_id} fails to run client {client_id}, due to {execution_msg}")
else:
self.add_event_handler(
executor_id, event, meta_result, data_result)
# Remove the head event from this executor's queue and log it.
# NOTE(review): indentation is not visible in this view, so it is unclear
# whether this pop runs only on success or also after a failed execution
# — confirm against the real file before changing this block.
event_pop = self.individual_client_events[executor_id].popleft()
logging.info(f"Event {event_pop} popped from queue.")
else:
logging.error(
f"Received undefined event {event} from client {client_id}")
# Answer the completion request with the executor's next pending event.
return self.CLIENT_PING(request, context)


if __name__ == "__main__":
aggregator = TFLiteAggregator(parser.args)
Expand Down
2 changes: 1 addition & 1 deletion fedscale/cloud/internal/tflite_model_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def set_weights(self, weights: List[np.ndarray]):
var.assign(weight)

def get_weights(self) -> List[np.ndarray]:
return [np.asarray(var.read_value()) for var in self.model.weights if var.trainable]
return [np.asarray(var.read_value()) for var in self.model.weights]

def get_model(self):
return self.model
5 changes: 5 additions & 0 deletions fedscale/edge/android/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
.idea
.gradle
app/build
*/.cxx
assets/dataset
File renamed without changes.
21 changes: 21 additions & 0 deletions fedscale/edge/android/app/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Build script for the Android app's native library `mnncore`, compiled from
# src/main/jni/mnntrainnative.cpp and linked against prebuilt MNN shared
# libraries shipped under src/main/jniLibs/<ABI>/.
cmake_minimum_required(VERSION 3.4.1)
set(CMAKE_CXX_STANDARD 11)
include_directories(src/main/jni/include)

# Declare the prebuilt MNN libraries as IMPORTED targets; they are not built here.
add_library(MNN SHARED IMPORTED)
add_library(MNNTrain SHARED IMPORTED)
add_library(MNNConvertDeps SHARED IMPORTED)
add_library(MNNExpress SHARED IMPORTED)

# Point each imported target at its .so for the ABI being built (${ANDROID_ABI}).
# Note: the MNNExpress target maps to libMNN_Express.so (file name has an underscore).
set_target_properties(MNN PROPERTIES IMPORTED_LOCATION ${CMAKE_SOURCE_DIR}/src/main/jniLibs/${ANDROID_ABI}/libMNN.so)
set_target_properties(MNNTrain PROPERTIES IMPORTED_LOCATION ${CMAKE_SOURCE_DIR}/src/main/jniLibs/${ANDROID_ABI}/libMNNTrain.so)
set_target_properties(MNNConvertDeps PROPERTIES IMPORTED_LOCATION ${CMAKE_SOURCE_DIR}/src/main/jniLibs/${ANDROID_ABI}/libMNNConvertDeps.so)
set_target_properties(MNNExpress PROPERTIES IMPORTED_LOCATION ${CMAKE_SOURCE_DIR}/src/main/jniLibs/${ANDROID_ABI}/libMNN_Express.so)

# The shared library built from source in this project.
add_library(mnncore SHARED src/main/jni/mnntrainnative.cpp)

# Locate the `log` and `jnigraphics` libraries; the resulting paths are stored
# in the log-lib and jnigraphics-lib variables used in the link step below.
find_library(log-lib log)
find_library(jnigraphics-lib jnigraphics)

# Pass -DMNN_USE_LOGCAT to the compiler.
add_definitions(-DMNN_USE_LOGCAT)
# Link mnncore against the four imported MNN targets plus the two located libraries.
target_link_libraries(mnncore MNN MNNTrain MNNConvertDeps MNNExpress ${log-lib} ${jnigraphics-lib})
Loading

0 comments on commit d8fd582

Please sign in to comment.