Skip to content

Commit

Permalink
Merge pull request #246 from EricDinging/fix-test
Browse files Browse the repository at this point in the history
Fix q-fedavg model loading error | Fix test cases
  • Loading branch information
fanlai0990 committed Dec 18, 2023
2 parents 0f90918 + 292f306 commit 7ec441c
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 13 deletions.
12 changes: 6 additions & 6 deletions fedscale/cloud/aggregation/optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,9 @@ def update_round_gradient(
update_weights = result["update_weight"]
if type(update_weights) is dict:
update_weights = [x for x in update_weights.values()]

weights = [
torch.from_numpy(np.asarray(x, dtype=np.float32)).to(
device=self.device
)
for x in update_weights
torch.tensor(x).to(device=self.device) for x in update_weights
]
grads = [
(u - v) * 1.0 / learning_rate for u, v in zip(last_model, weights)
Expand All @@ -100,8 +98,10 @@ def update_round_gradient(
) + (1.0 / learning_rate) * np.float_power(loss + 1e-10, qfedq)

# update global model
for idx, param in enumerate(target_model.parameters()):
param.data = last_model[idx] - Deltas[idx] / (hs + 1e-10)
new_state_dict = {
name: last_model[idx] - Deltas[idx] / (hs + 1e-10) for idx, name in enumerate(target_model.state_dict().keys())
}
target_model.load_state_dict(new_state_dict)

else:
# The default optimizer, FedAvg, has been applied in aggregator.py on the fly
Expand Down
7 changes: 6 additions & 1 deletion fedscale/cloud/internal/model_adapter_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,16 @@ class ModelAdapterBase(abc.ABC):
"""
Represents an adapter that operates on a framework-specific model.
"""

@abc.abstractmethod
def set_weights(self, weights: np.ndarray):
def set_weights(
self, weights: np.ndarray, is_aggregator=True, client_training_results=None
):
"""
Set the model's weights to the numpy weights array.
:param weights: numpy weights array
:param is_aggregator: boolean indicating whether the caller is the aggregator
:param client_training_results: list of gradients from every clients, for q-fedavg
"""
pass

Expand Down
19 changes: 17 additions & 2 deletions fedscale/cloud/internal/tensorflow_model_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,28 @@ class TensorflowModelAdapter(ModelAdapterBase):
def __init__(self, model: tf.keras.Model):
self.model = model

def set_weights(self, weights: List[np.ndarray]):
def set_weights(
self,
weights: List[np.ndarray],
is_aggregator=True,
client_training_results=None,
):
"""
Set the model's weights to the numpy weights array.
:param weights: numpy weights array
:param is_aggregator: boolean indicating whether the caller is the aggregator
:param client_training_results: list of gradients from every clients, for q-fedavg
"""
for i, layer in enumerate(self.model.layers):
if layer.trainable:
layer.set_weights(weights[i])

def get_weights(self) -> List[np.ndarray]:
return [np.asarray(layer.get_weights()) for layer in self.model.layers if layer.trainable]
return [
np.asarray(layer.get_weights())
for layer in self.model.layers
if layer.trainable
]

def get_model(self):
return self.model
14 changes: 10 additions & 4 deletions fedscale/tests/cloud/aggregation/test_aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def __init__(self, model_wrapper):
self.model_in_update = 1
self.tasks_round = 3
self.model_wrapper = model_wrapper
self.client_training_results = None


def multiply_weights(weights, factor):
Expand All @@ -23,8 +24,9 @@ def multiply_weights(weights, factor):
class TestAggregator:
def test_update_weight_aggregation_for_keras_model(self):
x = tf.keras.Input(shape=(2,))
y = tf.keras.layers.Dense(2, activation='softmax')(
tf.keras.layers.Dense(4, activation='softmax')(x))
y = tf.keras.layers.Dense(2, activation="softmax")(
tf.keras.layers.Dense(4, activation="softmax")(x)
)
model = tf.keras.Model(x, y)
model_adapter = TensorflowModelAdapter(model)
aggregator = MockAggregator(model_adapter)
Expand All @@ -34,7 +36,9 @@ def test_update_weight_aggregation_for_keras_model(self):
aggregator.update_weight_aggregation(multiply_weights(weights, 2))
aggregator.model_in_update += 1
aggregator.update_weight_aggregation(multiply_weights(weights, 5))
np.array_equal(aggregator.model_wrapper.get_weights(), multiply_weights(weights, 3))
np.array_equal(
aggregator.model_wrapper.get_weights(), multiply_weights(weights, 3)
)

def test_update_weight_aggregation_for_torch_model(self):
model = torch.nn.Linear(3, 2)
Expand All @@ -46,4 +50,6 @@ def test_update_weight_aggregation_for_torch_model(self):
aggregator.update_weight_aggregation(multiply_weights(weights, 2))
aggregator.model_in_update += 1
aggregator.update_weight_aggregation(multiply_weights(weights, 5))
np.array_equal(aggregator.model_wrapper.get_weights(), multiply_weights(weights, 3))
np.array_equal(
aggregator.model_wrapper.get_weights(), multiply_weights(weights, 3)
)

0 comments on commit 7ec441c

Please sign in to comment.