Skip to content

Commit

Permalink
Pass unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
EricDinging committed Dec 17, 2023
1 parent c3a940a commit 292f306
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 7 deletions.
7 changes: 6 additions & 1 deletion fedscale/cloud/internal/model_adapter_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,16 @@ class ModelAdapterBase(abc.ABC):
"""
Represents an adapter that operates on a framework-specific model.
"""

@abc.abstractmethod
def set_weights(self, weights: np.ndarray):
def set_weights(
self, weights: np.ndarray, is_aggregator=True, client_training_results=None
):
"""
Set the model's weights to the numpy weights array.
:param weights: numpy weights array
:param is_aggregator: boolean indicating whether the caller is the aggregator
:param client_training_results: list of gradients from every clients, for q-fedavg
"""
pass

Expand Down
19 changes: 17 additions & 2 deletions fedscale/cloud/internal/tensorflow_model_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,28 @@ class TensorflowModelAdapter(ModelAdapterBase):
def __init__(self, model: tf.keras.Model):
self.model = model

def set_weights(self, weights: List[np.ndarray]):
def set_weights(
self,
weights: List[np.ndarray],
is_aggregator=True,
client_training_results=None,
):
"""
Set the model's weights to the numpy weights array.
:param weights: numpy weights array
:param is_aggregator: boolean indicating whether the caller is the aggregator
:param client_training_results: list of gradients from every clients, for q-fedavg
"""
for i, layer in enumerate(self.model.layers):
if layer.trainable:
layer.set_weights(weights[i])

def get_weights(self) -> List[np.ndarray]:
return [np.asarray(layer.get_weights()) for layer in self.model.layers if layer.trainable]
return [
np.asarray(layer.get_weights())
for layer in self.model.layers
if layer.trainable
]

def get_model(self):
return self.model
14 changes: 10 additions & 4 deletions fedscale/tests/cloud/aggregation/test_aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def __init__(self, model_wrapper):
self.model_in_update = 1
self.tasks_round = 3
self.model_wrapper = model_wrapper
self.client_training_results = None


def multiply_weights(weights, factor):
Expand All @@ -23,8 +24,9 @@ def multiply_weights(weights, factor):
class TestAggregator:
def test_update_weight_aggregation_for_keras_model(self):
x = tf.keras.Input(shape=(2,))
y = tf.keras.layers.Dense(2, activation='softmax')(
tf.keras.layers.Dense(4, activation='softmax')(x))
y = tf.keras.layers.Dense(2, activation="softmax")(
tf.keras.layers.Dense(4, activation="softmax")(x)
)
model = tf.keras.Model(x, y)
model_adapter = TensorflowModelAdapter(model)
aggregator = MockAggregator(model_adapter)
Expand All @@ -34,7 +36,9 @@ def test_update_weight_aggregation_for_keras_model(self):
aggregator.update_weight_aggregation(multiply_weights(weights, 2))
aggregator.model_in_update += 1
aggregator.update_weight_aggregation(multiply_weights(weights, 5))
np.array_equal(aggregator.model_wrapper.get_weights(), multiply_weights(weights, 3))
np.array_equal(
aggregator.model_wrapper.get_weights(), multiply_weights(weights, 3)
)

def test_update_weight_aggregation_for_torch_model(self):
model = torch.nn.Linear(3, 2)
Expand All @@ -46,4 +50,6 @@ def test_update_weight_aggregation_for_torch_model(self):
aggregator.update_weight_aggregation(multiply_weights(weights, 2))
aggregator.model_in_update += 1
aggregator.update_weight_aggregation(multiply_weights(weights, 5))
np.array_equal(aggregator.model_wrapper.get_weights(), multiply_weights(weights, 3))
np.array_equal(
aggregator.model_wrapper.get_weights(), multiply_weights(weights, 3)
)

0 comments on commit 292f306

Please sign in to comment.