Merge pull request #32 from sb-ai-lab/bugfix/fix_cli
Fix CLI seed initialization
mralexdmitriy authored Apr 23, 2024
2 parents eecd13e + ba92095 commit 1b7f622
Showing 12 changed files with 51 additions and 14 deletions.
3 changes: 3 additions & 0 deletions examples/configs/logreg-sbol-smm-vm-yc.yml
@@ -50,6 +50,9 @@ member:
   logging_level: 'info'
   heartbeat_interval: 2.
   sent_task_timout: 3600
+  member_model_params: {
+    output_dim: 19,
+  }

 grpc_arbiter:
   use_arbiter: False
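Note: the agents below read config.common.seed, but no example config in this commit shows that key. Judging only by the accessor path, it would presumably sit under the common section of the YAML config, along the lines of this hypothetical fragment (the value 42 is arbitrary):

common:
  world_size: 2
  seed: 42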
4 changes: 2 additions & 2 deletions stalactite/data_loader.py
@@ -60,11 +60,11 @@ def preprocess(self):
         self._search_and_fill_preprocessor_params()
         if len(self.preprocessors_params) != 0:
             train_split_data, preprocessors_dict = self._preprocess_split(
-                self.dataset[self.member_id][train_split_key], self.preprocessors_params  # #todo: refactor
+                self.dataset[self.member_id][train_split_key], self.preprocessors_params
             )
             test_split_data, _ = self._preprocess_split(
                 self.dataset[self.member_id][test_split_key],
-                self.preprocessors_params,  # todo: refactor
+                self.preprocessors_params,
                 preprocessors_dict=preprocessors_dict,
             )

23 changes: 18 additions & 5 deletions stalactite/data_utils.py
@@ -3,6 +3,7 @@
 from stalactite.ml.arbitered.base import PartyArbiter
 from stalactite.ml import (
     HonestPartyMasterLinReg,
+    HonestPartyMasterLinRegConsequently,
     HonestPartyMemberLogReg,
     HonestPartyMemberResNet,
     HonestPartyMemberEfficientNet,
@@ -32,6 +33,8 @@ def get_party_master(config_path: str, is_infer: bool = False) -> PartyMaster:
     if config.grpc_arbiter.use_arbiter:
         master_processor, processors = load_processors_arbitered(config)
         master_processor = master_processor if config.data.dataset.lower() == "sbol_smm" else processors[0]
+        if config.data.dataset_size == -1:
+            config.data.dataset_size = len(master_processor.dataset[config.data.train_split][config.data.uids_key])
         master_class = ArbiteredPartyMasterLogReg
         if config.grpc_arbiter.security_protocol_params is not None:
             if config.grpc_arbiter.security_protocol_params.he_type == 'paillier':
@@ -59,10 +62,13 @@ def get_party_master(config_path: str, is_infer: bool = False) -> PartyMaster:
             do_train=not is_infer,
             do_save_model=config.vfl_model.do_save_model,
             model_path=config.vfl_model.vfl_model_path,
+            seed=config.common.seed
         )

     else:
         master_processor, processors = load_processors_honest(config)
+        if config.data.dataset_size == -1:
+            config.data.dataset_size = len(master_processor.dataset[config.data.train_split][config.data.uids_key])
         if 'logreg' in config.vfl_model.vfl_model_name:
             master_class = HonestPartyMasterLogReg
         elif "resnet" in config.vfl_model.vfl_model_name:
@@ -72,7 +78,10 @@ def get_party_master(config_path: str, is_infer: bool = False) -> PartyMaster:
         elif "mlp" in config.vfl_model.vfl_model_name:
             master_class = HonestPartyMasterMLPSplitNN
         else:
-            master_class = HonestPartyMasterLinReg
+            if config.vfl_model.is_consequently:
+                master_class = HonestPartyMasterLinRegConsequently
+            else:
+                master_class = HonestPartyMasterLinReg
         return master_class(
             uid="master",
             epochs=config.vfl_model.epochs,
@@ -90,7 +99,8 @@ def get_party_master(config_path: str, is_infer: bool = False) -> PartyMaster:
             do_train=not is_infer,
             model_name=config.vfl_model.vfl_model_name if
             config.vfl_model.vfl_model_name in ["resnet", "mlp", "efficientnet"] else None,
-            model_params=config.master.master_model_params
+            model_params=config.master.master_model_params,
+            seed=config.common.seed
         )


@@ -124,7 +134,8 @@ def get_party_member(config_path: str, member_rank: int, is_infer: bool = False)
             do_train=not is_infer,
             do_save_model=config.vfl_model.do_save_model,
             model_path=config.vfl_model.vfl_model_path,
-            use_inner_join=False
+            use_inner_join=False,
+            seed=config.common.seed
         )

     else:
@@ -140,6 +151,7 @@ def get_party_member(config_path: str, member_rank: int, is_infer: bool = False)
         else:
             member_class = HonestPartyMemberLinReg

+        member_ids = [f"member-{member_rank}" for member_rank in range(config.common.world_size)]
         return member_class(
             uid=f"member-{member_rank}",
             member_record_uids=processors[member_rank].dataset[config.data.train_split][config.data.uids_key],
@@ -152,13 +164,14 @@ def get_party_member(config_path: str, member_rank: int, is_infer: bool = False)
             report_train_metrics_iteration=config.common.report_train_metrics_iteration,
             report_test_metrics_iteration=config.common.report_test_metrics_iteration,
             is_consequently=config.vfl_model.is_consequently,
-            members=None,
+            members=member_ids if config.vfl_model.is_consequently else None,
             do_predict=is_infer,
             do_train=not is_infer,
             do_save_model=config.vfl_model.do_save_model,
             model_path=config.vfl_model.vfl_model_path,
             model_params=config.member.member_model_params,
-            use_inner_join=True if member_rank == 0 else False
+            use_inner_join=True if member_rank == 0 else False,
+            seed=config.common.seed
         )


2 changes: 2 additions & 0 deletions stalactite/ml/arbitered/logistic_regression/party_master.py
@@ -39,6 +39,7 @@ def __init__(
             do_save_model: bool = False,
             processor=None,
             run_mlflow: bool = False,
+            seed: int = None

     ) -> None:
         """ Initialize ArbiteredPartyMasterLinReg.
@@ -78,6 +79,7 @@ def __init__(
         self.do_predict = do_predict
         self.do_save_model = do_save_model
         self.model_path = model_path
+        self.seed = seed

         self.uid2tensor_idx = None
         self.uid2tensor_idx_test = None
2 changes: 2 additions & 0 deletions stalactite/ml/arbitered/logistic_regression/party_member.py
@@ -33,6 +33,7 @@ def __init__(
             do_save_model: bool = False,
             processor=None,
             use_inner_join: bool = True,
+            seed: int = None
     ) -> None:
         self.id = uid
         self.epochs = epochs
@@ -52,6 +53,7 @@ def __init__(
         self.do_save_model = do_save_model
         self.model_path = model_path
         self.use_inner_join = use_inner_join
+        self.seed = seed

     def predict_partial(self, uids: RecordsBatch) -> DataTensor:
         logger.info(f'{self.id} makes partial predictions')
3 changes: 2 additions & 1 deletion stalactite/ml/honest/logistic_regression/party_member.py
@@ -4,7 +4,7 @@

 from stalactite.ml.honest.linear_regression.party_member import HonestPartyMemberLinReg
 from stalactite.models import LogisticRegressionBatch
-
+from stalactite.utils import init_linear_np

 class HonestPartyMemberLogReg(HonestPartyMemberLinReg):
     def initialize_model_from_params(self, **model_params) -> Any:
@@ -19,6 +19,7 @@ def initialize_model(self, do_load_model: bool = False) -> None:
             input_dim=self._dataset[self._data_params.train_split][self._data_params.features_key].shape[1],
             **self._model_params
         )
+        init_linear_np(self._model.linear, seed=self.seed)

     def initialize_optimizer(self) -> None:
         self._optimizer = SGD([
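The init_linear_np helper imported above comes from stalactite/utils.py, which is not among the files changed in this commit, so its definition is not shown. As a hedged sketch only (an assumption about its behavior, not the actual implementation), a NumPy-seeded initializer for a torch nn.Linear layer could look like this:

import numpy as np
import torch
from torch import nn


def init_linear_np(layer: nn.Linear, seed: int = None) -> None:
    """Deterministically (re)initialize a linear layer from a seeded NumPy RNG."""
    # NumPy, not torch, drives the randomness, so the result is
    # reproducible regardless of torch's global RNG state.
    rng = np.random.default_rng(seed)
    out_dim, in_dim = layer.weight.shape
    bound = 1.0 / np.sqrt(in_dim)  # uniform bound similar to torch's default for Linear
    weights = rng.uniform(-bound, bound, size=(out_dim, in_dim))
    with torch.no_grad():
        layer.weight.copy_(torch.from_numpy(weights).to(layer.weight.dtype))
        if layer.bias is not None:
            layer.bias.zero_()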
2 changes: 1 addition & 1 deletion stalactite/ml/honest/split_learning/resnet/party_master.py
@@ -14,7 +14,7 @@ class HonestPartyMasterResNetSplitNN(HonestPartyMasterSplitNN):

     def initialize_model(self, do_load_model: bool = False) -> None:
         """ Initialize the model based on the specified model name. """
-        self._model = ResNetTop(**self._model_params)
+        self._model = ResNetTop(**self._model_params, seed=self.seed)
         self._criterion = torch.nn.BCEWithLogitsLoss(
             pos_weight=self.class_weights) if self.binary else torch.nn.CrossEntropyLoss(weight=self.class_weights)

2 changes: 1 addition & 1 deletion stalactite/ml/honest/split_learning/resnet/party_member.py
@@ -14,7 +14,7 @@ def initialize_model(self, do_load_model: bool = False) -> None:
         if do_load_model:
             self._model = self.load_model()
         else:
-            self._model = ResNetBottom(input_dim=input_dim, **self._model_params)
+            self._model = ResNetBottom(input_dim=input_dim, **self._model_params, seed=self.seed)

     def initialize_optimizer(self) -> None:
         self._optimizer = SGD([
6 changes: 5 additions & 1 deletion stalactite/models/split_learning/resnet_bottom.py
@@ -8,6 +8,7 @@
 import torch.nn as nn

 from stalactite.models.resnet import ResNetBlock
+from stalactite.utils import init_linear_np

 logger = logging.getLogger(__name__)

@@ -41,6 +42,7 @@ def __init__(
             use_noise: bool = False,
             device: torch.device = torch.device("cpu"),
             init_weights: float = None,
+            seed: int = None,
             **kwargs,
     ):
         super(ResNetBottom, self).__init__()
@@ -52,7 +54,7 @@ def __init__(
         assert (
             len(drop_rate) == len(hid_factor) and len(drop_rate[0]) == 2
         ), "Wrong number hidden_sizes/drop_rates. Must be equal."
-
+        self.seed = seed
         num_features = input_dim if num_init_features is None else num_init_features
         self.dense0 = nn.Linear(input_dim, num_features) if num_init_features is not None else nn.Identity()
         self.features1 = nn.Sequential(OrderedDict([]))
@@ -75,6 +77,8 @@ def __init__(
             if isinstance(m, nn.Linear):
                 if init_weights:
                     nn.init.constant_(m.weight, init_weights)
+                else:
+                    init_linear_np(m, seed=self.seed)
                 nn.init.zeros_(m.bias)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
7 changes: 6 additions & 1 deletion stalactite/models/split_learning/resnet_top.py
@@ -5,6 +5,8 @@
 import torch
 from torch import nn, Tensor

+from stalactite.utils import init_linear_np
+
 logger = logging.getLogger(__name__)


@@ -33,12 +35,13 @@ def __init__(
             num_init_features: Optional[int] = None,
             use_bn: bool = True,
             init_weights: float = None,
+            seed: int = None,
             **kwargs,
     ):
         super(ResNetTop, self).__init__()

         num_features = input_dim if num_init_features is None else num_init_features
-
+        self.seed = seed
         self.features = nn.Sequential(OrderedDict([]))
         if use_bn:
             self.features.add_module("norm", nn.BatchNorm1d(num_features))
@@ -50,6 +53,8 @@ def __init__(
             if isinstance(m, nn.Linear):
                 if init_weights:
                     nn.init.constant_(m.weight, init_weights)
+                else:
+                    init_linear_np(m, seed=self.seed)
                 nn.init.zeros_(m.bias)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
8 changes: 6 additions & 2 deletions stalactite/run_grpc_agent.py
@@ -11,6 +11,7 @@
 from stalactite.configs import VFLConfig
 from stalactite.ml.arbitered.base import Role
 from stalactite.data_utils import get_party_master, get_party_arbiter, get_party_member
+from stalactite.utils import seed_all

 import logging

@@ -36,7 +37,7 @@
 def main(config_path, infer, role):
     config = VFLConfig.load_and_validate(config_path)
     global_logging(role=role, config=config)
-
+    seed_all(config.common.seed)
     arbiter_grpc_host = None
     if config.grpc_arbiter.use_arbiter:
         arbiter_grpc_host = os.environ.get("GRPC_ARBITER_HOST", config.grpc_arbiter.external_host)
@@ -61,9 +62,12 @@ def main(config_path, infer, role):
         comm.run()

     elif role == Role.master:
+        master = get_party_master(config_path, is_infer=infer)
+        if config.data.dataset_size == -1:
+            config.data.dataset_size = len(master.target_uids)
         with reporting(config):
             comm = GRpcMasterPartyCommunicator(
-                participant=get_party_master(config_path, is_infer=infer),
+                participant=master,
                 world_size=config.common.world_size,
                 port=config.grpc_server.port,
                 server_thread_pool_size=config.grpc_server.server_threadpool_max_workers,
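seed_all is likewise imported from stalactite/utils.py, which this commit does not modify, so its body is not shown here. A minimal sketch of what such a helper conventionally does, assuming it seeds the standard Python, NumPy, and torch generators (not the library's actual code):

import random

import numpy as np
import torch


def seed_all(seed: int) -> None:
    """Seed every RNG an agent may draw from, for reproducible runs."""
    random.seed(seed)        # Python's stdlib generator
    np.random.seed(seed)     # legacy global NumPy generator
    torch.manual_seed(seed)  # torch CPU generator
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)  # every visible GPU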
3 changes: 3 additions & 0 deletions stalactite/utils_main.py
@@ -496,6 +496,9 @@ def local_arbiter_main():
         comm.run()
         logger.info("Finishing thread %s" % threading.current_thread().name)

+    if config.data.dataset_size == -1:
+        config.data.dataset_size = len(master.target_uids)
+
     with reporting(config):
         run_local_agents(
             master=master,
