From 292f3061a401e56f4c4a1b60e3ed4ed5ade23168 Mon Sep 17 00:00:00 2001 From: EricDinging Date: Sun, 17 Dec 2023 13:48:14 -0500 Subject: [PATCH] Pass unit test --- fedscale/cloud/internal/model_adapter_base.py | 7 ++++++- .../internal/tensorflow_model_adapter.py | 19 +++++++++++++++++-- .../cloud/aggregation/test_aggregator.py | 14 ++++++++++---- 3 files changed, 33 insertions(+), 7 deletions(-) diff --git a/fedscale/cloud/internal/model_adapter_base.py b/fedscale/cloud/internal/model_adapter_base.py index 067ef91c..18beab77 100644 --- a/fedscale/cloud/internal/model_adapter_base.py +++ b/fedscale/cloud/internal/model_adapter_base.py @@ -7,11 +7,16 @@ class ModelAdapterBase(abc.ABC): """ Represents an adapter that operates on a framework-specific model. """ + @abc.abstractmethod - def set_weights(self, weights: np.ndarray): + def set_weights( + self, weights: np.ndarray, is_aggregator=True, client_training_results=None + ): """ Set the model's weights to the numpy weights array. :param weights: numpy weights array + :param is_aggregator: boolean indicating whether the caller is the aggregator + :param client_training_results: list of gradients from every clients, for q-fedavg """ pass diff --git a/fedscale/cloud/internal/tensorflow_model_adapter.py b/fedscale/cloud/internal/tensorflow_model_adapter.py index 5c5785a5..360fc031 100644 --- a/fedscale/cloud/internal/tensorflow_model_adapter.py +++ b/fedscale/cloud/internal/tensorflow_model_adapter.py @@ -10,13 +10,28 @@ class TensorflowModelAdapter(ModelAdapterBase): def __init__(self, model: tf.keras.Model): self.model = model - def set_weights(self, weights: List[np.ndarray]): + def set_weights( + self, + weights: List[np.ndarray], + is_aggregator=True, + client_training_results=None, + ): + """ + Set the model's weights to the numpy weights array. + :param weights: numpy weights array + :param is_aggregator: boolean indicating whether the caller is the aggregator + :param client_training_results: list of gradients from every clients, for q-fedavg + """ for i, layer in enumerate(self.model.layers): if layer.trainable: layer.set_weights(weights[i]) def get_weights(self) -> List[np.ndarray]: - return [np.asarray(layer.get_weights()) for layer in self.model.layers if layer.trainable] + return [ + np.asarray(layer.get_weights()) + for layer in self.model.layers + if layer.trainable + ] def get_model(self): return self.model diff --git a/fedscale/tests/cloud/aggregation/test_aggregator.py b/fedscale/tests/cloud/aggregation/test_aggregator.py index 4bd81d55..58c739f1 100644 --- a/fedscale/tests/cloud/aggregation/test_aggregator.py +++ b/fedscale/tests/cloud/aggregation/test_aggregator.py @@ -14,6 +14,7 @@ def __init__(self, model_wrapper): self.model_in_update = 1 self.tasks_round = 3 self.model_wrapper = model_wrapper + self.client_training_results = None def multiply_weights(weights, factor): @@ -23,8 +24,9 @@ def multiply_weights(weights, factor): class TestAggregator: def test_update_weight_aggregation_for_keras_model(self): x = tf.keras.Input(shape=(2,)) - y = tf.keras.layers.Dense(2, activation='softmax')( - tf.keras.layers.Dense(4, activation='softmax')(x)) + y = tf.keras.layers.Dense(2, activation="softmax")( + tf.keras.layers.Dense(4, activation="softmax")(x) + ) model = tf.keras.Model(x, y) model_adapter = TensorflowModelAdapter(model) aggregator = MockAggregator(model_adapter) @@ -34,7 +36,9 @@ def test_update_weight_aggregation_for_keras_model(self): aggregator.update_weight_aggregation(multiply_weights(weights, 2)) aggregator.model_in_update += 1 aggregator.update_weight_aggregation(multiply_weights(weights, 5)) - np.array_equal(aggregator.model_wrapper.get_weights(), multiply_weights(weights, 3)) + np.array_equal( + aggregator.model_wrapper.get_weights(), multiply_weights(weights, 3) + ) def test_update_weight_aggregation_for_torch_model(self): model = torch.nn.Linear(3, 2) @@ -46,4 +50,6 @@ def test_update_weight_aggregation_for_torch_model(self): aggregator.update_weight_aggregation(multiply_weights(weights, 2)) aggregator.model_in_update += 1 aggregator.update_weight_aggregation(multiply_weights(weights, 5)) - np.array_equal(aggregator.model_wrapper.get_weights(), multiply_weights(weights, 3)) + np.array_equal( + aggregator.model_wrapper.get_weights(), multiply_weights(weights, 3) + )