Skip to content

Commit

Permalink
Feat fedavg tf1.15 (#952)
Browse files Browse the repository at this point in the history
  • Loading branch information
whisylan authored Feb 25, 2022
1 parent 423273b commit 882310f
Show file tree
Hide file tree
Showing 19 changed files with 1,152 additions and 2 deletions.
8 changes: 8 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,14 @@ protobuf:
--grpc_python_out=. \
protocols/fedlearner/channel/*.proto

python -m grpc_tools.protoc -I. \
--python_out=. \
fedlearner/fedavg/cluster/cluster.proto
python -m grpc_tools.protoc -I. \
--python_out=. \
--grpc_python_out=. \
fedlearner/fedavg/training_service.proto

lint:
pylint --rcfile ci/pylintrc fedlearner example

Expand Down
17 changes: 17 additions & 0 deletions deploy/scripts/trainer/run_fedavg.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#!/bin/bash
# Launcher for a FedAvg trainer: optionally pulls user code, derives the
# leader address from the role / peer address, then runs "<ROLE>.py".
#
# Environment read by this script:
#   CODE_KEY   - if non-empty, passed to pull_code together with $PWD
#   ROLE       - selects the script to run; "leader" gets a fixed address
#   PEER_ADDR  - used as the leader address when ROLE is not "leader"
set -ex

# NOTE(review): pull_code is not defined here; presumably it comes from one
# of the sourced helper scripts below - confirm.
source /app/deploy/scripts/hdfs_common.sh
source /app/deploy/scripts/env_to_args.sh

# All expansions are quoted so values containing spaces or glob characters
# are passed through as a single, literal argument.
if [[ -n "${CODE_KEY}" ]]; then
    pull_code "${CODE_KEY}" "${PWD}"
fi

if [[ "${ROLE}" == "leader" ]]; then
    export FL_LEADER_ADDRESS="0.0.0.0:50051"
elif [[ -n "${PEER_ADDR}" ]]; then
    export FL_LEADER_ADDRESS="${PEER_ADDR}"
fi
# If neither branch matched, FL_LEADER_ADDRESS is left as inherited (possibly
# unset) and the Python script decides what to do.

python "${ROLE}.py"
1 change: 1 addition & 0 deletions deploy/scripts/wait4pair_wrapper.sh
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ while [[ true ]]; do
export PEER_ADDR=`cat ${pair}`
break
else
echo "still waiting for peer addr"
sleep 1
fi
done
Expand Down
29 changes: 29 additions & 0 deletions example/fedavg/mnist/follower.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import os
from fedlearner.fedavg import train_from_keras_model
from .model import create_model, x_train, y_train, x_test, y_test

# Follower side of the two-party FedAvg MNIST example.

# Leader address comes from FL_LEADER_ADDRESS, falling back to a local default.
fed_leader_address = os.getenv("FL_LEADER_ADDRESS", "0.0.0.0:6870")
# Name of this process within the cluster; matches the entry in "followers".
fl_name = "follower"
# Cluster description passed to train_from_keras_model: one leader (with an
# address) and one follower (name only).
fl_cluster = {
"leader": {
"name": "leader",
"address": fed_leader_address
},
"followers": [{
"name": "follower"
}]
}

model = create_model()
# The follower trains on the second half of the training set; the slice
# bounds are computed from each array's own length.
x = x_train[len(x_train) // 2:]
y = y_train[len(y_train) // 2:]
# NOTE(review): steps_per_sync is presumably the number of local steps
# between synchronisations with the leader - confirm in fedlearner.fedavg.
train_from_keras_model(model,
x,
y,
batch_size=30,
epochs=1,
fl_name=fl_name,
fl_cluster=fl_cluster,
steps_per_sync=10)

# Evaluate the trained model on the full test set.
model.evaluate(x_test, y_test)
29 changes: 29 additions & 0 deletions example/fedavg/mnist/leader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import os
from fedlearner.fedavg import train_from_keras_model
from .model import create_model, x_train, y_train, x_test, y_test

# Leader side of the two-party FedAvg MNIST example.

# Leader address comes from FL_LEADER_ADDRESS, falling back to a local default.
fed_leader_address = os.getenv("FL_LEADER_ADDRESS", "0.0.0.0:6870")
# Name of this process within the cluster; matches the "leader" entry below.
fl_name = "leader"
# Cluster description passed to train_from_keras_model: one leader (with an
# address) and one follower (name only).
fl_cluster = {
"leader": {
"name": "leader",
"address": fed_leader_address
},
"followers": [{
"name": "follower"
}]
}

model = create_model()
# The leader trains on the first half of the training set; the slice bounds
# are computed from each array's own length.
x = x_train[:len(x_train) // 2]
y = y_train[:len(y_train) // 2]
# NOTE(review): steps_per_sync is presumably the number of local steps
# between synchronisations with the followers - confirm in fedlearner.fedavg.
train_from_keras_model(model,
x,
y,
batch_size=30,
epochs=1,
fl_name=fl_name,
fl_cluster=fl_cluster,
steps_per_sync=10)

# Evaluate the trained model on the full test set.
model.evaluate(x_test, y_test)
21 changes: 21 additions & 0 deletions example/fedavg/mnist/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import tensorflow as tf
import numpy as np

# Load MNIST once at import time; the resulting arrays are module-level names
# that other modules can import.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
# Flatten every sample to one row, convert to float32 and divide by 255.
x_train = x_train.reshape(x_train.shape[0], -1).astype(np.float32) / 255.0
# Labels are cast to int32.
y_train = y_train.astype(np.int32)

# Same preprocessing for the test split.
x_test = x_test.reshape(x_test.shape[0], -1).astype(np.float32) / 255.0
y_test = y_test.astype(np.int32)


def create_model():
    """Build and compile the dense classifier used by the MNIST example.

    The network takes 784 input features, has two hidden ReLU layers of 200
    units each, and a 10-unit softmax output. It is compiled with SGD
    (learning rate 0.01), sparse categorical cross-entropy loss and the
    'acc' metric.

    Returns:
        The compiled ``tf.keras.Sequential`` model.
    """
    dense = tf.keras.layers.Dense
    stack = [
        dense(200, activation='relu', input_shape=(784, )),
        dense(200, activation='relu'),
        dense(10, activation='softmax'),
    ]
    net = tf.keras.Sequential(stack)

    sgd = tf.keras.optimizers.SGD(0.01)
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()
    net.compile(optimizer=sgd, loss=loss_fn, metrics=['acc'])
    return net
1 change: 1 addition & 0 deletions fedlearner/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@
from fedlearner import data_join
from fedlearner import proxy
from fedlearner import trainer
from fedlearner import fedavg
157 changes: 157 additions & 0 deletions fedlearner/common/grpc_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
import time
import collections
import grpc
from . import fl_logging as logging


class LocalServicerContext(grpc.ServicerContext):
    """Inert ``grpc.ServicerContext`` implementation.

    Every method either returns a fixed value or does nothing, so an
    instance can be handed to code that expects a servicer context without
    a real RPC behind it (presumably for in-process calls - confirm with
    callers).
    """

    # --- read-only call information -------------------------------------

    def invocation_metadata(self):
        """Return an empty metadata tuple."""
        return ()

    def peer(self):
        """Return the fixed peer string "local"."""
        return "local"

    def peer_identities(self):
        """Return None: no peer identities are available."""
        return None

    def peer_identity_key(self):
        """Return None: there is no identity key."""
        return None

    def auth_context(self):
        """Return a new, empty auth context mapping."""
        return {}

    def is_active(self):
        """Always report the context as active."""
        return True

    def time_remaining(self):
        """Return None, i.e. no deadline is reported."""
        return None

    # --- mutators: accepted and ignored ---------------------------------

    def set_compression(self, compression):
        """Ignore the requested compression and return NoCompression."""
        return grpc.Compression.NoCompression

    def send_initial_metadata(self, initial_metadata):
        """Accept and discard initial metadata."""

    def set_trailing_metadata(self, trailing_metadata):
        """Accept and discard trailing metadata."""

    def abort(self, code, details):
        """Do nothing.

        NOTE: this returns normally instead of raising, so a caller that
        relies on abort() terminating the handler will keep running.
        """

    def abort_with_status(self, status):
        """Do nothing; like abort(), this returns normally."""

    def set_code(self, code):
        """Accept and discard the status code."""

    def set_details(self, details):
        """Accept and discard the status details."""

    def disable_next_message_compression(self):
        """Do nothing."""

    def cancel(self):
        """Do nothing."""

    def add_callback(self, callback):
        """Discard the callback; it is never invoked and None is returned."""


def call_with_retry(call, max_retry_times=None, retry_interval=1):
    """Invoke ``call()`` and retry it while it raises ``grpc.RpcError``.

    Args:
        call: zero-argument callable to invoke.
        max_retry_times: maximum number of attempts (the first attempt
            counts). ``None`` means retry forever.
        retry_interval: seconds to sleep between attempts.

    Returns:
        Whatever ``call()`` returns on its first successful attempt.

    Raises:
        grpc.RpcError: the error from the last attempt, once the attempt
            budget is used up. Any other exception type propagates
            immediately without a retry.
    """
    attempts = 0
    while True:
        attempts += 1
        try:
            return call()
        except grpc.RpcError as err:
            out_of_budget = (max_retry_times is not None
                             and attempts >= max_retry_times)
            if out_of_budget:
                raise err
            logging.warning(
                "grpc call error, status: %s"
                ", details: %s, wait %ds for retry", err.code(), err.details(),
                retry_interval)
            time.sleep(retry_interval)


#def remote_insecure_channel(address, options=None, compression=None):
# EGRESS_URL = os.getenv('EGRESS_URL', None)
# EGRESS_HOST = os.environ.get('EGRESS_HOST', None)
# EGRESS_DOMAIN = os.environ.get('EGRESS_DOMAIN', None)
# if not EGRESS_URL:
# return grpc.insecure_channel(address, options, compression)
#
# options = list(options) if options else list()
# default_authority = EGRESS_HOST or address
# options.append(('grpc.default_authority', default_authority))
# channel = grpc.insecure_channel(EGRESS_URL, options, compression)
#
# if EGRESS_DOMAIN:
# address = address + '.' + EGRESS_DOMAIN
# channel = grpc.intercept_channel(
# channel, add_metadata_interceptor({'x-host': address}))
#
# return channel
#
#
#def add_metadata_interceptor(headers):
# if not isinstance(headers, dict):
# raise TypeError("headers must be a dict")
# headers = list(headers.items())
#
# def add_metadata_fn(client_call_details, request_iterator,
# request_streaming, response_streaming):
# metadata = list(client_call_details.metadata or [])
# metadata.extend(headers)
# client_call_details = _ClientCallDetails(
# client_call_details.method, client_call_details.timeout, metadata,
# client_call_details.credentials)
# return client_call_details, request_iterator, None
#
# return _GenericClientInterceptor(add_metadata_fn)


class _ClientCallDetails(
        collections.namedtuple(
            '_ClientCallDetails',
            ('method', 'timeout', 'metadata', 'credentials')),
        grpc.ClientCallDetails):
    """Immutable ``grpc.ClientCallDetails`` backed by a namedtuple.

    Carries the four fields ``method``, ``timeout``, ``metadata`` and
    ``credentials``, so an interceptor can build a modified copy of the
    call details it was given.
    """


class _GenericClientInterceptor(grpc.UnaryUnaryClientInterceptor,
                                grpc.UnaryStreamClientInterceptor,
                                grpc.StreamUnaryClientInterceptor,
                                grpc.StreamStreamClientInterceptor):
    """Client interceptor that routes all four RPC shapes to one function.

    The wrapped function is called as
    ``fn(client_call_details, request_iterator, request_streaming,
    response_streaming)`` and must return a 3-tuple
    ``(new_details, new_request_iterator, postprocess)``. ``postprocess``
    may be falsy; when truthy it is applied to the continuation's result.
    For unary requests the single request is wrapped in a one-item iterator
    before the call and unwrapped with ``next()`` afterwards.
    """

    def __init__(self, interceptor_function):
        self._fn = interceptor_function

    def intercept_unary_unary(self, continuation, client_call_details,
                              request):
        details, requests, after = self._fn(
            client_call_details, iter((request, )), False, False)
        result = continuation(details, next(requests))
        if after:
            return after(result)
        return result

    def intercept_unary_stream(self, continuation, client_call_details,
                               request):
        details, requests, after = self._fn(
            client_call_details, iter((request, )), False, True)
        result_it = continuation(details, next(requests))
        if after:
            return after(result_it)
        return result_it

    def intercept_stream_unary(self, continuation, client_call_details,
                               request_iterator):
        details, requests, after = self._fn(
            client_call_details, request_iterator, True, False)
        result = continuation(details, requests)
        if after:
            return after(result)
        return result

    def intercept_stream_stream(self, continuation, client_call_details,
                                request_iterator):
        details, requests, after = self._fn(
            client_call_details, request_iterator, True, True)
        result_it = continuation(details, requests)
        if after:
            return after(result_it)
        return result_it
2 changes: 2 additions & 0 deletions fedlearner/fedavg/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@

from .fedavg import train_from_keras_model
50 changes: 50 additions & 0 deletions fedlearner/fedavg/_global_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Copyright 2020 The FedLearner Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import threading
from fedlearner.common import stats

class _GlobalContext:
    """Process-wide job/task identity plus a lazily created stats client.

    Attributes are read from the environment at construction time:
      job         FL_JOB, else APPLICATION_ID, else "unknow"
      task        FL_TASK, else "unknow"
      task_index  FL_TASK_INDEX, else INDEX, else "0", converted to int
    Empty environment values are treated the same as missing ones.
    """

    @staticmethod
    def _first_env(names, fallback):
        # Return the first non-empty environment value among `names`.
        for name in names:
            value = os.getenv(name)
            if value:
                return value
        return fallback

    def __init__(self):
        self.job = self._first_env(("FL_JOB", "APPLICATION_ID"), "unknow")
        self.task = self._first_env(("FL_TASK", ), "unknow")
        # int() raises ValueError if the environment holds a non-numeric index.
        self.task_index = int(
            self._first_env(("FL_TASK_INDEX", "INDEX"), "0"))

        self._stats_client = None

        self._lock = threading.Lock()

    @property
    def stats_client(self):
        """Return the tagged stats client, creating it on first access."""
        existing = self._stats_client
        if existing:
            return existing

        # Re-check under the lock so concurrent first accesses do not each
        # build their own client.
        with self._lock:
            if not self._stats_client:
                tags = {
                    "job": self.job,
                    "task": self.task,
                    "task_index": self.task_index,
                }
                self._stats_client = stats.with_tags(tags)
            return self._stats_client


# Shared instance created at import time.
global_context = _GlobalContext()
3 changes: 3 additions & 0 deletions fedlearner/fedavg/cluster/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@

from .cluster_pb2 import FLNodeDef, FLClusterDef
from .cluster_spec import FLClusterSpec
13 changes: 13 additions & 0 deletions fedlearner/fedavg/cluster/cluster.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
syntax = "proto3";

package fedlearner.cluster;

// One participant in a federated-learning cluster.
message FLNodeDef {
// Identifier of the node within the cluster.
string name = 1;
// Network address of the node.
// NOTE(review): presumably "host:port"; may be left empty - confirm.
string address = 2;
}

// Cluster layout: exactly one leader field plus any number of followers.
message FLClusterDef {
// The leader node.
FLNodeDef leader = 1;
// Zero or more follower nodes.
repeated FLNodeDef followers = 2;
}
Loading

0 comments on commit 882310f

Please sign in to comment.