From 98c2132375263dd91ca41e91bb96acfb82461860 Mon Sep 17 00:00:00 2001
From: "lukyanov_kirya@bk.ru"
Date: Thu, 21 Nov 2024 13:04:25 +0300
Subject: [PATCH 1/8] make gnn_models.py and models_utils.py better

---
 src/models_builder/gnn_models.py   | 378 ++++++++++++++++++++++-------
 src/models_builder/models_utils.py |  24 +-
 2 files changed, 312 insertions(+), 90 deletions(-)

diff --git a/src/models_builder/gnn_models.py b/src/models_builder/gnn_models.py
index c906a32..846b418 100644
--- a/src/models_builder/gnn_models.py
+++ b/src/models_builder/gnn_models.py
@@ -2,10 +2,14 @@
 import json
 import random
 from math import ceil
+from pathlib import Path
 from types import FunctionType
+from typing import Callable, List, Union, Any, Type, Protocol
+
 import numpy as np
 import torch
 import sklearn.metrics
+from flask_socketio import SocketIO
 from torch.nn.utils import clip_grad_norm
 from torch import tensor
 import torch.nn.functional as F
@@ -19,6 +23,7 @@
     hash_data_sha256, \
     TECHNICAL_PARAMETER_KEY, IMPORT_INFO_KEY, OPTIMIZERS_PARAMETERS_PATH, FUNCTIONS_PARAMETERS_PATH
 from aux.declaration import Declare
+from base.datasets_processing import DatasetManager
 from explainers.explainer import ProgressBar
 from explainers.ProtGNN.MCTS import mcts_args
 from attacks.evasion_attacks import EvasionAttacker
@@ -44,7 +49,10 @@ class Metric:
     }
 
     @staticmethod
-    def add_custom(name, compute_function):
+    def add_custom(
+            name: str,
+            compute_function: Callable
+    ) -> None:
         """
         Register a custom metric.
         Example for accuracy:
@@ -60,7 +68,12 @@ def add_custom(name, compute_function):
             raise NameError(f"Metric '{name}' already registered, use another name")
         Metric.available_metrics[name] = compute_function
 
-    def __init__(self, name, mask, **kwargs):
+    def __init__(
+            self,
+            name: str,
+            mask: Union[str, List[bool]],
+            **kwargs
+    ):
         """
         :param name: name to refer to this metric
         :param mask: 'train', 'val', 'test', or a bool valued list
@@ -70,16 +83,22 @@ def __init__(self, name, mask, **kwargs):
         self.mask = mask
         self.kwargs = kwargs
 
-    def compute(self, y_true, y_pred):
+    def compute(
+            self,
+            y_true,
+            y_pred
+    ):
         if self.name in Metric.available_metrics:
             if y_true.device != "cpu":
                 y_true = y_true.cpu()
             return Metric.available_metrics[self.name](y_true, y_pred, **self.kwargs)
 
-        raise NotImplementedError()
+        raise NotImplementedError(f"Metric '{self.name}' is not registered")
 
     @staticmethod
-    def create_mask_by_target_list(y_true, target_list=None):
+    def create_mask_by_target_list(
+            y_true,
+            target_list: List = None
+    ) -> torch.Tensor:
         if target_list is None:
             mask = [True] * len(y_true)
         else:
@@ -88,7 +107,6 @@ def create_mask_by_target_list(y_true, target_list=None):
             if 0 <= i < len(mask):
                 mask[i] = True
         return tensor(mask)
-        # return mask
 
 
 class GNNModelManager:
@@ -96,9 +114,11 @@ class GNNModelManager:
     training, evaluation, save and load principle
     """
 
-    def __init__(self,
-                 manager_config=None,
-                 modification: ModelModificationConfig = None):
+    def __init__(
+            self,
+            manager_config: ModelManagerConfig = None,
+            modification: ModelModificationConfig = None
+    ):
         """
-        :param manager_config: socket to use for sending data to frontend
-        :param modification: socket to use for sending data to frontend
+        :param manager_config: full description of the model manager parameters
+        :param modification: model modification config: model version index and attack/defense variability
         """
@@ -118,11 +138,7 @@ def __init__(self,
         # else:
         #     raise Exception()
 
-        # if modification is None:
-        #     modification = ModelModificationConfig()
-
         if modification is None:
-            # raise RuntimeError("model manager config must be specified")
             modification = ConfigPattern(
                 _config_class="ModelModificationConfig",
                 _config_kwargs={},
@@ -141,10 +157,6 @@ def __init__(self,
 
         # QUE Kirill do we need to store
it? maybe pass when need to self.dataset_path = None - - # FIXME Kirill, remove self.gen_dataset - self.gen_dataset = None - self.mi_defender = None self.mi_defense_name = None self.mi_defense_config = None @@ -184,28 +196,48 @@ def __init__(self, self.set_evasion_attacker() self.set_evasion_defender() - def train_model(self, **kwargs): + def train_model( + self, + **kwargs + ): pass - def train_1_step(self, gen_dataset): + def train_1_step( + self, + gen_dataset: DatasetManager + ): """ Perform 1 step of model training. """ # raise NotImplementedError() pass - def train_complete(self, gen_dataset, steps=None, **kwargs): + def train_complete( + self, + gen_dataset: DatasetManager, + steps: int = None, + **kwargs + ) -> None: """ """ # raise NotImplementedError() pass - def train_on_batch(self, batch, **kwargs): + def train_on_batch( + self, + batch, + **kwargs + ): pass - def evaluate_model(self, **kwargs): + def evaluate_model( + self, + **kwargs + ): pass - def get_name(self): + def get_name( + self + ) -> str: manager_name = self.manager_config.to_saveable_dict() # FIXME Kirill, make ModelManagerConfig and remove manager_name[CONFIG_CLASS_NAME] manager_name[CONFIG_CLASS_NAME] = self.__class__.__name__ @@ -215,13 +247,20 @@ def get_name(self): json_str = json.dumps(manager_name, indent=2) return json_str - def load_model(self, path=None, **kwargs): + def load_model( + self, + path: Union[str, Path] = None, + **kwargs + ) -> Type: """ Load model from torch save format """ raise NotImplementedError() - def save_model(self, path=None): + def save_model( + self, + path: Union[str, Path] = None + ) -> None: """ Save the model in torch format @@ -230,11 +269,17 @@ def save_model(self, path=None): """ raise NotImplementedError() - def model_path_info(self): + def model_path_info( + self + ) -> Union[str, Path]: path, _ = Declare.models_path(self) return path - def load_model_executor(self, path=None, **kwargs): + def load_model_executor( + self, + path: Union[str, Path, None] = None, + **kwargs + ) -> Union[str, Path]: """ Load executor. Generates the download model path if no other path is specified. @@ -267,7 +312,9 @@ class variables self.gnn.eval() return model_dir_path - def get_hash(self): + def get_hash( + self + ) -> str: """ calculates the hash on behalf of the model manager required for storage. The sha256 algorithm is used. @@ -277,15 +324,18 @@ def get_hash(self): gnn_MM_name_hash = hash_data_sha256(json_object.encode('utf-8')) return gnn_MM_name_hash - def save_model_executor(self, path=None, files_paths=None): + def save_model_executor( + self, + path: Union[str, Path, None] = None, + files_paths: List[Union[str, Path]] = None + ) -> Path: """ Save executor, generates paths and prepares all information about the model and its parameters for saving - :param gnn_architecture_path: path to save the architecture of the model, - by default it forms the path itself. :param path: path to save the model. 
By default, the path is compiled based on the global class variables + :param files_paths: """ if path is None: dir_path, files_paths = Declare.models_path(self) @@ -319,7 +369,11 @@ def save_model_executor(self, path=None, files_paths=None): f.write(self.mi_attack_config.json_for_config()) return path.parent - def set_poison_attacker(self, poison_attack_config=None, poison_attack_name: str = None): + def set_poison_attacker( + self, + poison_attack_config: PoisonAttackConfig = None, + poison_attack_name: str = None + ) -> None: if poison_attack_config is None: poison_attack_config = ConfigPattern( _class_name=poison_attack_name or "EmptyPoisonAttacker", @@ -357,7 +411,11 @@ def set_poison_attacker(self, poison_attack_config=None, poison_attack_name: str ) self.poison_attack_flag = True - def set_evasion_attacker(self, evasion_attack_config=None, evasion_attack_name: str = None): + def set_evasion_attacker( + self, + evasion_attack_config: EvasionAttackConfig = None, + evasion_attack_name: str = None + ) -> None: if evasion_attack_config is None: evasion_attack_config = ConfigPattern( _class_name=evasion_attack_name or "EmptyEvasionAttacker", @@ -393,7 +451,11 @@ def set_evasion_attacker(self, evasion_attack_config=None, evasion_attack_name: ) self.evasion_attack_flag = True - def set_mi_attacker(self, mi_attack_config=None, mi_attack_name: str = None): + def set_mi_attacker( + self, + mi_attack_config: MIAttackConfig = None, + mi_attack_name: str = None + ) -> None: if mi_attack_config is None: mi_attack_config = ConfigPattern( _class_name=mi_attack_name or "EmptyMIAttacker", @@ -429,7 +491,11 @@ def set_mi_attacker(self, mi_attack_config=None, mi_attack_name: str = None): ) self.mi_attack_flag = True - def set_poison_defender(self, poison_defense_config=None, poison_defense_name: str = None): + def set_poison_defender( + self, + poison_defense_config: PoisonDefenseConfig = None, + poison_defense_name: str = None + ) -> None: if poison_defense_config is None: poison_defense_config = ConfigPattern( _class_name=poison_defense_name or "EmptyPoisonDefender", @@ -465,7 +531,11 @@ def set_poison_defender(self, poison_defense_config=None, poison_defense_name: s ) self.poison_defense_flag = True - def set_evasion_defender(self, evasion_defense_config=None, evasion_defense_name: str = None): + def set_evasion_defender( + self, + evasion_defense_config: EvasionDefenseConfig = None, + evasion_defense_name: str = None + ) -> None: if evasion_defense_config is None: evasion_defense_config = ConfigPattern( _class_name=evasion_defense_name or "EmptyEvasionDefender", @@ -501,7 +571,11 @@ def set_evasion_defender(self, evasion_defense_config=None, evasion_defense_name ) self.evasion_defense_flag = True - def set_mi_defender(self, mi_defense_config=None, mi_defense_name: str = None): + def set_mi_defender( + self, + mi_defense_config: MIDefenseConfig = None, + mi_defense_name: str = None + ) -> None: """ """ @@ -541,15 +615,21 @@ def set_mi_defender(self, mi_defense_config=None, mi_defense_name: str = None): self.mi_defense_flag = True @staticmethod - def available_attacker(): + def available_attacker( + ): pass @staticmethod - def available_defender(): + def available_defender( + ): pass @staticmethod - def from_model_path(model_path, dataset_path, **kwargs): + def from_model_path( + model_path: dict, + dataset_path: Union[str, Path], + **kwargs + ) -> [Type, Path]: """ Use information about model and model manager take gnn model, create gnn model manager object and load weights to the save model @@ 
-610,7 +690,9 @@ def from_model_path(model_path, dataset_path, **kwargs): return gnn_model_manager_obj, model_dir_path - def get_full_info(self): + def get_full_info( + self + ) -> dict: """ Get available info about model for frontend """ @@ -623,7 +705,9 @@ def get_full_info(self): result["epochs"] = f"Epochs={self.epochs}" return result - def get_model_data(self): + def get_model_data( + self + ) -> dict: """ :return: dict with the available functions of the model manager by the 'functions' key. """ @@ -638,7 +722,9 @@ def get_own_functions(cls): return model_data @staticmethod - def take_gnn_obj(gnn_file): + def take_gnn_obj( + gnn_file: Union[str, Path] + ) -> Type: with open(gnn_file) as f: params = json.load(f) class_name = params.pop(CONFIG_CLASS_NAME) @@ -667,22 +753,34 @@ def take_gnn_obj(gnn_file): obj_name) return gnn - def before_epoch(self, gen_dataset): + def before_epoch( + self, + gen_dataset: DatasetManager + ): """ This hook is called before training the next training epoch """ pass - def after_epoch(self, gen_dataset): + def after_epoch( + self, + gen_dataset: DatasetManager + ): """ This hook is called after training the next training epoch """ pass - def before_batch(self, batch): + def before_batch( + self, + batch + ): """ This hook is called before training the next training batch """ pass - def after_batch(self, batch): + def after_batch( + self, + batch + ): """ This hook is called after training the next training batch """ pass @@ -725,8 +823,8 @@ class FrameworkGNNModelManager(GNNModelManager): to prevent leakage of the response during training. """ - def __init__(self, gnn=None, - dataset_path=None, + def __init__(self, gnn: Type = None, + dataset_path: Union[str, Path] = None, **kwargs ): """ @@ -770,7 +868,9 @@ def __init__(self, gnn=None, if self.gnn is not None: self.init() - def init(self): + def init( + self + ) -> None: """ Initialize optimizer and loss function. 
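+
+        A usage sketch (assumes a manager already constructed with a gnn and a
+        manager_config that names an optimizer):
+        >>> model_manager.init()  # (re)creates self.optimizer and, if configured, self.loss_function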
""" @@ -785,7 +885,14 @@ def init(self): if "loss_function" in getattr(self.manager_config, CONFIG_OBJ): self.loss_function = getattr(self.manager_config, CONFIG_OBJ).loss_function.create_obj() - def train_complete(self, gen_dataset, steps=None, pbar=None, metrics=None, **kwargs): + def train_complete( + self, + gen_dataset: DatasetManager, + steps: int = None, + pbar: Protocol = None, + metrics: Union[List[Metric], Metric] = None, + **kwargs + ) -> None: for _ in range(steps): self.before_epoch(gen_dataset) print("epoch", self.modification.epochs) @@ -800,10 +907,19 @@ def train_complete(self, gen_dataset, steps=None, pbar=None, metrics=None, **kwa if early_stopping_flag: break - def early_stopping(self, train_loss, gen_dataset, metrics, steps): + def early_stopping( + self, + train_loss, + gen_dataset: DatasetManager, + metrics: Union[List[Metric], Metric], + steps: int + ) -> bool: return False - def train_1_step(self, gen_dataset): + def train_1_step( + self, + gen_dataset: DatasetManager + ) -> List[Union[float, int]]: task_type = gen_dataset.domain() if task_type == "single-graph": # FIXME Kirill, add data_x_copy mask @@ -828,9 +944,13 @@ def train_1_step(self, gen_dataset): print("loss %.8f" % loss) self.modification.epochs += 1 self.gnn.eval() - return loss.cpu().detach().numpy().tolist() + return loss.detach().numpy().tolist() - def train_on_batch_full(self, batch, task_type=None): + def train_on_batch_full( + self, + batch, + task_type: str = None + ) -> torch.Tensor: if self.mi_defender: self.mi_defender.pre_batch() if self.evasion_defender: @@ -848,12 +968,19 @@ def train_on_batch_full(self, batch, task_type=None): loss = self.optimizer_step(loss=loss) return loss - def optimizer_step(self, loss): + def optimizer_step( + self, + loss: torch.Tensor + ) -> torch.Tensor: loss.backward() self.optimizer.step() return loss - def train_on_batch(self, batch, task_type=None): + def train_on_batch( + self, + batch, + task_type: str = None + ) -> torch.Tensor: loss = None if hasattr(batch, "edge_weight"): weight = batch.edge_weight @@ -893,11 +1020,18 @@ def train_on_batch(self, batch, task_type=None): raise ValueError("Unsupported task type") return loss - def get_name(self, **kwargs): + def get_name( + self, + **kwargs + ) -> str: json_str = super().get_name() return json_str - def load_model(self, path=None, **kwargs): + def load_model( + self, + path: Union[str, Path, None] = None, + **kwargs + ) -> Type: """ Load model from torch save format @@ -914,7 +1048,10 @@ class variables self.init() return self.gnn - def save_model(self, path=None): + def save_model( + self, + path: Union[str, Path] = None + ) -> None: """ Save the model in torch format @@ -923,7 +1060,12 @@ def save_model(self, path=None): """ torch.save(self.gnn.state_dict(), path) - def report_results(self, train_loss, gen_dataset, metrics): + def report_results( + self, + train_loss, + gen_dataset: DatasetManager, + metrics: List[Metric] + ) -> None: metrics_values = self.evaluate_model(gen_dataset=gen_dataset, metrics=metrics) self.compute_stats_data(gen_dataset, predictions=True, logits=True) self.send_epoch_results( @@ -932,8 +1074,15 @@ def report_results(self, train_loss, gen_dataset, metrics): for k, v in self.stats_data.items()}, weights={"weights": self.gnn.get_weights()}, loss=train_loss) - def train_model(self, gen_dataset, save_model_flag=True, mode=None, steps=None, metrics=None, - socket=None): + def train_model( + self, + gen_dataset: DatasetManager, + save_model_flag: bool = True, + mode: Union[str, None] = 
+    def train_model(
+            self,
+            gen_dataset: DatasetManager,
+            save_model_flag: bool = True,
+            mode: Union[str, None] = None,
+            steps: int = None,
+            metrics: List[Metric] = None,
+            socket: SocketIO = None
+    ) -> None:
         """
         Convenient train method.
 
@@ -988,7 +1137,12 @@ def train_model(self, gen_dataset, save_model_flag=True, mode=None, steps=None,
         finally:
             self.socket = None
 
-    def run_model(self, gen_dataset, mask='test', out='answers'):
+    def run_model(
+            self,
+            gen_dataset: DatasetManager,
+            mask: Union[str, List[bool], torch.Tensor] = 'test',
+            out: str = 'answers'
+    ) -> torch.Tensor:
         """
         Run the model on a part of dataset specified with a mask.
 
@@ -1019,7 +1173,7 @@ def run_model(self, gen_dataset, mask='test', out='answers'):
             dataset = gen_dataset.dataset
             part_loader = DataLoader(
                 dataset.index_select(mask), batch_size=self.batch, shuffle=False)
-            full_out = torch.Tensor()
+            full_out = torch.empty(0)
             # y_true = torch.Tensor()
         if hasattr(self, 'optimizer'):
             self.optimizer.zero_grad()
@@ -1071,7 +1225,11 @@ def run_model(self, gen_dataset, mask='test', out='answers'):
 
         return full_out
 
-    def evaluate_model(self, gen_dataset, metrics):
+    def evaluate_model(
+            self,
+            gen_dataset: DatasetManager,
+            metrics: List[Metric]
+    ) -> dict:
         """
         Compute metrics for a model result on a part of dataset specified by the metric mask.
 
@@ -1111,7 +1268,12 @@ def evaluate_model(self, gen_dataset, metrics):
             self.mi_attacker.attack()
         return metrics_values
 
-    def compute_stats_data(self, gen_dataset, predictions=False, logits=False):
+    def compute_stats_data(
+            self,
+            gen_dataset: DatasetManager,
+            predictions: bool = False,
+            logits: bool = False
+    ):
         """
         :param gen_dataset: wrapper over the dataset, stores the dataset and all meta-information
         about the dataset
@@ -1132,7 +1294,14 @@ def compute_stats_data(self, gen_dataset, predictions=False, logits=False):
             logits = self.run_model(gen_dataset, mask='all', out='logits')
             self.stats_data["embeddings"] = logits.detach().cpu().tolist()
 
-    def send_data(self, block, msg, tag='model', obligate=True, socket=None):
+    def send_data(
+            self,
+            block,
+            msg,
+            tag='model',
+            obligate=True,
+            socket=None
+    ):
         """
         Send data to the frontend.
 
@@ -1151,8 +1320,15 @@ def send_data(self, block, msg, tag='model', obligate=True, socket=None):
             socket.send(block=block, msg=msg, tag=tag, obligate=obligate)
             return True
 
-    def send_epoch_results(self, metrics_values=None, stats_data=None, weights=None, loss=None, obligate=False,
-                           socket=None):
+    def send_epoch_results(
+            self,
+            metrics_values=None,
+            stats_data=None,
+            weights=None,
+            loss=None,
+            obligate=False,
+            socket=None
+    ):
         """
         Send updates to the frontend after a training epoch: epoch, metrics, logits, loss.
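+
+        A sketch of a typical call (values are hypothetical):
+        >>> self.send_epoch_results(metrics_values={'test': {'Accuracy': 0.9}}, loss=0.42)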
@@ -1174,7 +1350,10 @@ def send_epoch_results(self, metrics_values=None, stats_data=None, weights=None, if stats_data: self.send_data("mt", stats_data, tag='model_stats', obligate=obligate, socket=socket) - def load_train_test_split(self, gen_dataset): + def load_train_test_split( + self, + gen_dataset: DatasetManager + ) -> DatasetManager: path = self.model_path_info() path = path / 'train_test_split' gen_dataset.train_mask, gen_dataset.val_mask, gen_dataset.test_mask, _ = torch.load(path)[:] @@ -1182,10 +1361,6 @@ def load_train_test_split(self, gen_dataset): class ProtGNNModelManager(FrameworkGNNModelManager): - # additional_config = ModelManagerConfig( - # loss_function={CONFIG_CLASS_NAME: "CrossEntropyLoss"}, - # mask_features=[], - # ) additional_config = ConfigPattern( _config_class="ModelManagerConfig", _config_kwargs={ @@ -1208,10 +1383,17 @@ class ProtGNNModelManager(FrameworkGNNModelManager): } ) - def __init__(self, gnn=None, dataset_path=None, **kwargs): + def __init__( + self, + gnn: Type = None, + dataset_path: Union[str, Path] = None, + **kwargs + ): super().__init__(gnn=gnn, dataset_path=dataset_path, **kwargs) # Get prot layer and its params + self.is_best = None + self.cur_acc = None self.prot_layer = getattr(self.gnn, self.gnn.prot_layer_name) _config_obj = getattr(self.manager_config, CONFIG_OBJ) self.clst = _config_obj.clst @@ -1231,7 +1413,10 @@ def __init__(self, gnn=None, dataset_path=None, **kwargs): self.gnn.best_prots = self.prot_layer.prototype_graphs self.best_acc = 0.0 - def save_model(self, path=None): + def save_model( + self, + path: Union[str, Path, None] = None + ) -> None: """ Save the model in torch format @@ -1242,7 +1427,11 @@ def save_model(self, path=None): "best_prots": self.gnn.best_prots, }, path) - def load_model(self, path=None, **kwargs): + def load_model( + self, + path: Union[str, Path, None] = None, + **kwargs + ) -> Type: """ Load model from torch save format @@ -1259,7 +1448,11 @@ class variables self.init() return self.gnn - def train_on_batch(self, batch, task_type=None): + def train_on_batch( + self, + batch, + task_type: str = None + ) -> torch.Tensor: if task_type == "single-graph": self.optimizer.zero_grad() logits = self.gnn(batch.x, batch.edge_index) @@ -1322,13 +1515,19 @@ def train_on_batch(self, batch, task_type=None): raise ValueError("Unsupported task type") return loss - def optimizer_step(self, loss): + def optimizer_step( + self, + loss: torch.Tensor + ) -> torch.Tensor: loss.backward() torch.nn.utils.clip_grad_value_(self.gnn.parameters(), clip_value=2.0) self.optimizer.step() return loss - def before_epoch(self, gen_dataset): + def before_epoch( + self, + gen_dataset: DatasetManager + ): cur_step = self.modification.epochs train_ind = [n for n, x in enumerate(gen_dataset.train_mask) if x] # Prototype projection @@ -1351,9 +1550,12 @@ def after_epoch(self, gen_dataset): # check if best model metrics_values = self.evaluate_model( - gen_dataset, metrics=[Metric("Accuracy", mask='val'), - Metric("Precision", mask='val'), - Metric("Recall", mask='val')]) + gen_dataset, metrics=[ + Metric("Accuracy", mask='val'), + Metric("Precision", mask='val'), + Metric("Recall", mask='val') + ] + ) self.cur_acc = metrics_values['val']["Accuracy"] self.is_best = (self.cur_acc - self.best_acc >= 0.01) @@ -1362,7 +1564,13 @@ def after_epoch(self, gen_dataset): self.early_stop_count = 0 self.gnn.best_prots = self.prot_layer.prototype_graphs - def early_stopping(self, train_loss, gen_dataset, metrics, steps): + def early_stopping( + self, + 
train_loss, + gen_dataset: DatasetManager, + metrics: Union[List[Metric], Metric], + steps: int + ) -> bool: step = self.modification.epochs if self.is_best: self.early_stop_count = 0 diff --git a/src/models_builder/models_utils.py b/src/models_builder/models_utils.py index d060e45..c1b8a6b 100644 --- a/src/models_builder/models_utils.py +++ b/src/models_builder/models_utils.py @@ -1,8 +1,13 @@ +from typing import Any + import torch from torch_geometric.nn import MessagePassing -def apply_message_gradient_capture(layer, name): +def apply_message_gradient_capture( + layer: Any, + name: str +) -> None: """ # Example how get Tensors # for name, layer in self.gnn.named_children(): @@ -12,23 +17,32 @@ def apply_message_gradient_capture(layer, name): original_message = layer.message layer.message_gradients = {} - def capture_message_gradients(x_j, *args, **kwargs): + def capture_message_gradients( + x_j: torch.Tensor, + *args, + **kwargs + ): x_j = x_j.requires_grad_() if not layer.training: return original_message(x_j=x_j, *args, **kwargs) - def save_message_grad(grad): + def save_message_grad( + grad: torch.Tensor + ) -> None: layer.message_gradients[name] = grad.detach() x_j.register_hook(save_message_grad) return original_message(x_j=x_j, *args, **kwargs) layer.message = capture_message_gradients - def get_message_gradients(): + def get_message_gradients( + ) -> dict: return layer.message_gradients layer.get_message_gradients = get_message_gradients -def apply_decorator_to_graph_layers(model): +def apply_decorator_to_graph_layers( + model: Any +) -> None: # TODO Kirill add more options """ Example how use this def From 7d06840eb8f0757600de0222e4b90a3155475a1b Mon Sep 17 00:00:00 2001 From: "lukyanov_kirya@bk.ru" Date: Thu, 21 Nov 2024 15:00:56 +0300 Subject: [PATCH 2/8] make better configs.py and gnn_constructor.py --- src/aux/configs.py | 321 ++++++++++++++++++-------- src/models_builder/gnn_constructor.py | 142 +++++++++--- 2 files changed, 330 insertions(+), 133 deletions(-) diff --git a/src/aux/configs.py b/src/aux/configs.py index 911e568..303affa 100644 --- a/src/aux/configs.py +++ b/src/aux/configs.py @@ -4,6 +4,8 @@ from json import JSONEncoder import copy import inspect +from pathlib import Path +from typing import Union, Any, Type, Tuple, List from aux.utils import setting_class_default_parameters, EXPLAINERS_INIT_PARAMETERS_PATH, \ EXPLAINERS_LOCAL_RUN_PARAMETERS_PATH, EXPLAINERS_GLOBAL_RUN_PARAMETERS_PATH, \ @@ -21,7 +23,10 @@ # Patch of json.dumps() - classes which implement to_json() can be jsonified -def _default(self, obj): +def _default( + self, + obj: Any +): return getattr(obj.__class__, "to_json", _default.default)(obj) @@ -49,7 +54,10 @@ class GeneralConfig: "_config_kwargs"} _CONFIG_KEYS = "_config_keys" - def __init__(self, **kwargs): + def __init__( + self, + **kwargs + ): self._config_keys = set() for key, value in kwargs.items(): @@ -63,7 +71,11 @@ def __init__(self, **kwargs): self._config_keys.add(key) setattr(self, key, value) - def __setattr__(self, key, value): + def __setattr__( + self, + key: str, + value: Any + ) -> None: frame = inspect.currentframe() try: locals_info = frame.f_back.f_locals @@ -93,16 +105,24 @@ def __setattr__(self, key, value): finally: del frame - def json_for_config(self): + def json_for_config( + self + ) -> str: config_kwargs = self.to_saveable_dict().copy() config_kwargs = dict(sorted(config_kwargs.items())) json_object = json.dumps(config_kwargs, indent=2) return json_object - def hash_for_config(self): + def hash_for_config( + self 
+ ) -> str: return hash_data_sha256(self.json_for_config().encode('utf-8')) - def to_saveable_dict(self, compact=False, **kwargs): + def to_saveable_dict( + self, + compact: bool = False, + **kwargs + ) -> dict: def sorted_dict(d): res = {} for key in sorted(d): @@ -137,7 +157,9 @@ def sorted_dict(d): dct[key] = str(value) return dct - def to_dict(self): + def to_dict( + self + ) -> dict: """ Represent config as a dictionary, as well as all included configs. Dict is a copy of all values. """ @@ -152,7 +174,10 @@ def to_dict(self): return res @staticmethod - def set_defaults_config_pattern_info(key, value): + def set_defaults_config_pattern_info( + key: str, + value: dict + ) -> dict: if "_config_kwargs" not in value: raise Exception("_config_kwargs can't set automatically") if key in _key_path: @@ -172,14 +197,23 @@ def set_defaults_config_pattern_info(key, value): # f"{key} is not currently supported") return value - def to_json(self): + def to_json( + self + ): """ Special method which allows to use json.dumps() on Config object """ return self.to_dict() -class ConfigPattern(GeneralConfig): - def __init__(self, _config_class: str, _config_kwargs, - _class_name: str = None, _import_path: str = None, _class_import_info=None): +class ConfigPattern( + GeneralConfig +): + def __init__( + self, + _config_class: str, + _config_kwargs: Union[dict, None], + _class_name: Union[str, None] = None, + _import_path: Union[str, Path, None] = None, + _class_import_info: str = None): if _import_path is not None: _import_path = str(_import_path) super().__init__(_class_name=_class_name, _import_path=_import_path, @@ -196,7 +230,10 @@ def __init__(self, _config_class: str, _config_kwargs, self._config_kwargs, save_kwargs = self._set_defaults() setattr(self, CONFIG_OBJ, self.make_config_by_pattern(save_kwargs)) - def __getattribute__(self, item): + def __getattribute__( + self, + item: Union[str, Type] + ) -> Any: if item == "__dict__" or item == "__class__": return object.__getattribute__(self, item) @@ -212,14 +249,20 @@ def __getattribute__(self, item): attr = getattr(self, CONFIG_OBJ).__getattribute__(item) return attr - def __setattr__(self, key, value): + def __setattr__( + self, + key: str, + value: Any + ) -> None: if (hasattr(self, CONFIG_OBJ) and hasattr(getattr(self, CONFIG_OBJ), self._CONFIG_KEYS) and key in getattr(getattr(self, CONFIG_OBJ), self._CONFIG_KEYS)): getattr(self, CONFIG_OBJ).__setattr__(key, value) else: super().__setattr__(key, value) - def _set_defaults(self): + def _set_defaults( + self + ) -> [dict, dict]: default_parameters_file_path = self._import_path kwargs = self._config_kwargs @@ -232,7 +275,10 @@ def _set_defaults(self): ) return init_kwargs, save_kwargs - def make_config_by_pattern(self, save_kwargs): + def make_config_by_pattern( + self, + save_kwargs: dict + ) -> Type: config_class = import_by_name(self._config_class, ["aux.configs"]) config_obj = config_class(save_kwargs=save_kwargs, **self._config_kwargs) return config_obj @@ -253,7 +299,10 @@ def create_obj(self, **kwargs): print(e) return obj - def merge(self, config): + def merge( + self, + config: Type + ): self_config_obj = getattr(self, CONFIG_OBJ) if config.__class__.__name__ == "ConfigPattern": config_obj = getattr(config, CONFIG_OBJ) @@ -262,7 +311,12 @@ def merge(self, config): setattr(self, CONFIG_OBJ, self_config_obj.merge(config)) return self - def to_saveable_dict(self, compact=False, need_full=True, **kwargs): + def to_saveable_dict( + self, + compact: bool = False, + need_full: bool = True, + **kwargs 
+ ) -> dict: """ Create dict which values are strings without spaces and are guaranteed to be sorted by key including inner dicts and configs. @@ -284,7 +338,9 @@ def to_saveable_dict(self, compact=False, need_full=True, **kwargs): return dct -class Config(GeneralConfig): +class Config( + GeneralConfig +): """ Contains a set of named parameters. Immutable - values can be set in constructor only. Parameters can be dicts or Configs themselves. @@ -296,31 +352,50 @@ class Config(GeneralConfig): # _mutable = False # _CONFIG_KEYS = "_config_keys" - def __init__(self, save_kwargs=None, **kwargs): + def __init__( + self, + save_kwargs: Union[dict, None] = None, + **kwargs + ): self.__dict__[CONFIG_SAVE_KWARGS_KEY] = save_kwargs super().__init__(**kwargs) - def __str__(self): + def __str__( + self + ) -> str: return str(dict(filter(lambda x: x[0] in self._config_keys, self.__dict__.items()))) - def __iter__(self): + def __iter__( + self + ): for key, value in self.__dict__.items(): if key in self._config_keys: yield key, value - def __getitem__(self, item): + def __getitem__( + self, + item: str + ) -> Any: if item in self._config_keys: return self.__dict__[item] - def __contains__(self, item): + def __contains__( + self, + item: str + ) -> bool: return item in self._config_keys - def __eq__(self, other): + def __eq__( + self, + other: Type + ) -> bool: if type(other) != type(self): return False return all(getattr(self, a) == getattr(other, a) for a in self._config_keys) - def copy(self): + def copy( + self + ) -> object: res = type(self)() # res.__dict__ = self.__dict__.copy() for k, v in self.__dict__.items(): @@ -329,7 +404,10 @@ def copy(self): res.__dict__[k] = copy.copy(v) return res - def merge(self, config): + def merge( + self, + config: Union[dict, object] + ): """ Create a new config with params obtained by updating self params with a given ones. Given config is a dict or a Config. """ @@ -344,7 +422,11 @@ def merge(self, config): kwargs.update(config.copy()) return type(self)(**kwargs) - def to_saveable_dict(self, compact=False, **kwargs): + def to_saveable_dict( + self, + compact: bool = False, + **kwargs + ) -> dict: """ Create dict which values are strings without spaces and are guaranteed to be sorted by key including inner dicts and configs. @@ -361,23 +443,34 @@ def to_saveable_dict(self, compact=False, **kwargs): return dct -class DatasetConfig(Config): +class DatasetConfig( + Config +): """ Contains a set of distinguishing characteristics to identify the dataset or family of datasets. Determines the path to the file with raw data in the inner storage. """ - def __init__(self, domain: str = None, group: str = None, graph: str = None): + def __init__( + self, + domain: str = None, + group: str = None, + graph: str = None + ): """ """ super().__init__(domain=domain, group=group, graph=graph) - def full_name(self): + def full_name( + self + ) -> tuple: """ Return all fields as a tuple. """ return tuple([self.domain, self.group, self.graph]) @staticmethod - def from_full_name(full_name: tuple): + def from_full_name( + full_name: tuple + ) -> object: """ Build DatasetConfig from a name tuple. """ res = DatasetConfig( domain=full_name[0], group=full_name[1], graph=full_name[2]) @@ -390,17 +483,21 @@ class DatasetVarConfig(Config): Specifies the path to the file with tensors in the inner storage. 
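+
+    Example (field values are hypothetical):
+    >>> dvc = DatasetVarConfig(features={"attr": {}}, labeling="binary", dataset_ver_ind=0)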
""" - def __init__(self, - features: dict = None, - labeling: str = None, - dataset_ver_ind: int = None): + def __init__( + self, + features: dict = None, + labeling: str = None, + dataset_ver_ind: int = None + ): """ """ super().__init__( features=features, labeling=labeling, dataset_ver_ind=dataset_ver_ind) -class ModelStructureConfig(Config): +class ModelStructureConfig( + Config +): """ Contains a full description of a model structure. Represents a list of layers. @@ -578,73 +675,84 @@ class ModelStructureConfig(Config): >>> } """ - def __init__(self, layers=None): + def __init__( + self, + layers=None + ): """ """ super().__init__(layers=layers) - def __str__(self): + def __str__( + self + ) -> str: return json.dumps(self, indent=2) - def __iter__(self): + def __iter__( + self + ) -> None: for layer in self.layers: yield layer - def __getitem__(self, item): + def __getitem__( + self, + item: int + ) -> Any: assert isinstance(item, int) return self.layers[item] - def __len__(self): + def __len__( + self + ) -> int: return len(self.layers) -class ModelConfig(Config): +class ModelConfig( + Config +): """ Config for GNN model. Can contain structure (for framework models) and/or additional parameters (for custom models). """ - def __init__(self, - structure: [dict, ModelStructureConfig] = None, - **kwargs): + def __init__( + self, + structure: Union[dict, ModelStructureConfig] = None, + **kwargs + ): if structure is not None and not isinstance(structure, Config): assert isinstance(structure, dict) structure = ModelStructureConfig(**structure) super().__init__(structure=structure, **kwargs) -class ModelManagerConfig(Config): +class ModelManagerConfig( + Config +): """ Full description of model manager parameters. """ - # key_path = { - # "optimizer": OPTIMIZERS_PARAMETERS_PATH, - # "loss_function": FUNCTIONS_PARAMETERS_PATH, - # } - - def __init__(self, **kwargs): - """ """ - # FIXME misha how to find all such params? - # if CONFIG_CLASS_NAME in kwargs: - - # for key, value in kwargs.items(): - # if key in self.key_path and not isinstance(value, Config): - # if 'CONFIG_PARAMS_PATH_KEY' not in value: - # value[CONFIG_PARAMS_PATH_KEY] = self.key_path[key] - # kwargs[key] = Config(**value) - + def __init__( + self, + **kwargs + ): super().__init__(**kwargs) -class ModelModificationConfig(Config): +class ModelModificationConfig( + Config +): """ Variability of a model given its structure and manager. Represents model attack type and the instance version. """ _mutable = True - def __init__(self, - model_ver_ind: [int, None] = None, - epochs=None, **kwargs): + def __init__( + self, + model_ver_ind: [int, None] = None, + epochs=None, + **kwargs + ): """ :param model_ver_ind: model index when saving. 
If None, then takes the nearest unoccupied index starting from 0 in ascending increments of 1 @@ -653,18 +761,26 @@ def __init__(self, epochs=epochs, **kwargs) self.__dict__[DATA_CHANGE_FLAG] = False - def __setattr__(self, key, value): + def __setattr__( + self, + key: str, + value: Any + ) -> None: # Any change of ModelModificationConfig should change flag self.__dict__[DATA_CHANGE_FLAG] = True super().__setattr__(key, value) - def data_change_flag(self): + def data_change_flag( + self + ) -> bool: loc = self.__dict__[DATA_CHANGE_FLAG] self.__dict__[DATA_CHANGE_FLAG] = False return loc -class EvasionAttackConfig(Config): +class EvasionAttackConfig( + Config +): _mutable = True def __init__( @@ -676,7 +792,9 @@ def __init__( ) -class EvasionDefenseConfig(Config): +class EvasionDefenseConfig( + Config +): _mutable = True def __init__( @@ -688,7 +806,9 @@ def __init__( ) -class PoisonAttackConfig(Config): +class PoisonAttackConfig( + Config +): _mutable = True def __init__( @@ -700,7 +820,9 @@ def __init__( ) -class PoisonDefenseConfig(Config): +class PoisonDefenseConfig( + Config +): _mutable = True def __init__( @@ -712,7 +834,9 @@ def __init__( ) -class MIAttackConfig(Config): +class MIAttackConfig( + Config +): _mutable = True def __init__( @@ -724,7 +848,9 @@ def __init__( ) -class MIDefenseConfig(Config): +class MIDefenseConfig( + Config +): _mutable = True def __init__( @@ -736,7 +862,9 @@ def __init__( ) -class ExplainerInitConfig(Config): +class ExplainerInitConfig( + Config +): """ """ @@ -750,45 +878,38 @@ def __init__(self, ) -class ExplainerRunConfig(Config): +class ExplainerRunConfig( + Config +): """ """ - def __init__(self, - # mode: str, - # class_name: str = None, - **kwargs): - # assert mode in ["local", "global"] - # path = { - # "local": EXPLAINERS_LOCAL_RUN_PARAMETERS_PATH, - # "global": EXPLAINERS_GLOBAL_RUN_PARAMETERS_PATH, - # }[mode] - # if kwargs is None: - # kwargs = {} - # if CONFIG_CLASS_NAME not in kwargs.keys(): - # if class_name is None: - # raise Exception("ExplainerRunConfig need class_name") - # kwargs[CONFIG_CLASS_NAME] = class_name - # kwargs[CONFIG_PARAMS_PATH_KEY] = path + def __init__( + self, + **kwargs + ): super().__init__( **kwargs - # kwargs=Config(**kwargs), - # mode=mode, ) -class ExplainerModificationConfig(Config): +class ExplainerModificationConfig( + Config +): """ """ # _mutable = True - def __init__(self, - explainer_ver_ind: [int, None] = None, - **kwargs): + def __init__( + self, + explainer_ver_ind: [int, None] = None, + **kwargs + ): super().__init__( explainer_ver_ind=explainer_ver_ind, - **kwargs) + **kwargs + ) if __name__ == '__main__': diff --git a/src/models_builder/gnn_constructor.py b/src/models_builder/gnn_constructor.py index 9c82197..e09e8a5 100644 --- a/src/models_builder/gnn_constructor.py +++ b/src/models_builder/gnn_constructor.py @@ -2,16 +2,17 @@ import json import warnings from collections import OrderedDict -from typing import Dict, Callable +from typing import Dict, Callable, Union, Iterator, Type import torch from torch.nn.parameter import UninitializedParameter from torch.utils import hooks +from ..parameter import Parameter from torch.utils.hooks import RemovableHandle from torch_geometric.nn import MessagePassing from aux.utils import import_by_name, CUSTOM_LAYERS_INFO_PATH, MODULES_PARAMETERS_PATH, hash_data_sha256, \ TECHNICAL_PARAMETER_KEY, IMPORT_INFO_KEY -from aux.configs import ModelConfig, CONFIG_CLASS_NAME +from aux.configs import ModelConfig, CONFIG_CLASS_NAME, ModelStructureConfig class 
GNNConstructor: @@ -33,34 +34,46 @@ def __init__(self, self.obj_name = None self.model_config = model_config - def forward(self): + def forward( + self + ): raise NotImplementedError("forward can't be called, because it is not implemented") - def get_all_layer_embeddings(self): + def get_all_layer_embeddings( + self + ): """ :return: vectors representing the input data from the outputs of each layer of the neural network """ raise NotImplementedError("get_all_layer_embeddings can't be called, because it is not implemented") - def get_architecture(self): + def get_architecture( + self + ): """ :return: the architecture of the model, for display on the front """ raise NotImplementedError("get_architecture can't be called, because it is not implemented") - def get_num_hops(self): + def get_num_hops( + self + ): """ :return: the number of graph convolution layers. Required for some model interpretation algorithms to work """ raise NotImplementedError("get_num_hops can't be called, because it is not implemented") - def reset_parameters(self): + def reset_parameters( + self + ): """ resets all model parameters. Required for the reset button on the front to work. """ raise NotImplementedError("reset_parameters can't be called, because it is not implemented") - def get_predictions(self): + def get_predictions( + self + ): """ :return: a vector of estimates for the distribution of input data by class. Required for some interpretation algorithms to work. @@ -68,13 +81,17 @@ def get_predictions(self): """ raise NotImplementedError("get_predictions can't be called, because it is not implemented") - def get_parameters(self): + def get_parameters( + self + ): """ :return: matrix with model parameters """ raise NotImplementedError("get_predictions can't be called, because it is not implemented") - def get_answer(self): + def get_answer( + self + ): """ :return: an answer to which class the input belongs to. Required for some interpretation methods to work. @@ -82,7 +99,11 @@ def get_answer(self): """ raise NotImplementedError("get_answer can't be called, because it is not implemented") - def get_name(self, obj_name_flag=False, **kwargs): + def get_name( + self, + obj_name_flag: bool = False, + **kwargs + ): gnn_name = self.model_config.to_saveable_dict().copy() gnn_name[CONFIG_CLASS_NAME] = self.__class__.__name__ if obj_name_flag: @@ -93,7 +114,9 @@ def get_name(self, obj_name_flag=False, **kwargs): json_str = json.dumps(gnn_name, indent=2) return json_str - def suitable_model_managers(self): + def suitable_model_managers( + self + ): """ :return: a set of names of suitable model manager classes. Model manager classes must be inherited from the GNNModelManager class @@ -102,13 +125,17 @@ def suitable_model_managers(self): # === Permanent methods (not to be overwritten) - def get_hash(self): + def get_hash( + self + ) -> str: gnn_name = self.get_name() json_object = json.dumps(gnn_name) gnn_name_hash = hash_data_sha256(json_object.encode('utf-8')) return gnn_name_hash - def get_full_info(self): + def get_full_info( + self + ) -> dict: """ Get available info about model for frontend """ # FIXMe architecture and weights can be not accessible @@ -129,16 +156,23 @@ def get_full_info(self): return result -class GNNConstructorTorch(GNNConstructor, torch.nn.Module): +class GNNConstructorTorch( + GNNConstructor, + torch.nn.Module +): """ Base class for writing models using the torch library. Inherited from GNNConstructor and torch.nn.Module classes. 
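+
+    A minimal subclass sketch (the layer and forward logic are assumed):
+    >>> class MyGNN(GNNConstructorTorch):
+    >>>     def forward(self, x, edge_index):
+    >>>         return self.conv(x, edge_index)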
""" - def __init__(self): + def __init__( + self + ): super().__init__() torch.nn.Module.__init__(self) - def flow(self): + def flow( + self + ): """ Flow direction of message passing, usually 'source_to_target' """ for module in self.modules(): @@ -146,7 +180,9 @@ def flow(self): return module.flow return 'source_to_target' - def get_neurons(self): + def get_neurons( + self + ): """ Return number of neurons of each convolution layer as list: [n_1, n_2, ..., n_k] """ neurons = [] @@ -165,7 +201,9 @@ def get_neurons(self): neurons.append(n_neurons) return neurons - def get_weights(self): + def get_weights( + self + ): """ Get model weights calling torch.nn.Module.state_dict() to draw them on the frontend. """ @@ -190,13 +228,18 @@ def get_weights(self): return model_data -class FrameworkGNNConstructor(GNNConstructorTorch): +class FrameworkGNNConstructor( + GNNConstructorTorch +): """ A class that uses metaprogramming to form a wide variety of models using the 'structure' variable in json format. Inherited from the GNNConstructorTorch class. """ - def __init__(self, model_config: ModelConfig = None, ): + def __init__( + self, + model_config: ModelConfig = None, + ): """ :param model_config: description of the gnn structure """ @@ -309,7 +352,10 @@ def __init__(self, model_config: ModelConfig = None, ): self.model_manager_restrictions = set() self._check_model_structure(self.structure) - def _check_model_structure(self, structure): + def _check_model_structure( + self, + structure: Union[dict, ModelStructureConfig], + ): with open(CUSTOM_LAYERS_INFO_PATH) as f: information_check_correctness_models = json.load(f) allowable_transitions = set(information_check_correctness_models["allowable_transitions"]) @@ -390,20 +436,30 @@ def _check_model_structure(self, structure): model_strong_restrictions.update(layer_strong_restrictions) self.model_manager_restrictions = model_manager_restrictions - def get_all_layer_embeddings(self, *args, **kwargs): + def get_all_layer_embeddings( + self, + *args, + **kwargs + ) -> dict: self._save_emb_flag = True layer_emb_dict = self(*args, **kwargs) self._save_emb_flag = False return layer_emb_dict - def register_my_forward_hook(self, hook: Callable[..., None]) -> RemovableHandle: + def register_my_forward_hook( + self, hook: Callable[..., None] + ) -> RemovableHandle: r"""Registers a forward hook on the module. 
""" handle = hooks.RemovableHandle(self._forward_hooks) self._my_forward_hooks[handle.id] = hook return handle - def forward(self, *args, **kwargs): + def forward( + self, + *args, + **kwargs + ) -> torch.Tensor: layer_ind = -1 tensor_storage = {} dim_cat = 0 @@ -483,15 +539,21 @@ def forward(self, *args, **kwargs): return layer_emb_dict return x - def reset_parameters(self): + def reset_parameters( + self + ) -> None: for elem in list(self.__dict__['_modules'].items()): if {'reset_parameters'}.issubset(dir(getattr(self, elem[0]))): getattr(self, elem[0]).reset_parameters() - def get_architecture(self): + def get_architecture( + self + ) -> Union[dict, ModelStructureConfig]: return self.structure - def get_num_hops(self): + def get_num_hops( + self + ) -> int: if self.num_hops is None: num_hops = 0 for module in self.modules(): @@ -505,21 +567,35 @@ def get_num_hops(self): else: return self.num_hops - def get_predictions(self, *args, **kwargs): + def get_predictions( + self, + *args, + **kwargs + ) -> torch.Tensor: return self(*args, **kwargs).softmax(dim=-1) # return self.forward(*args, **kwargs) - def get_parameters(self): + def get_parameters( + self + ) -> Iterator[Parameter]: return self.parameters() - def get_answer(self, *args, **kwargs): + def get_answer( + self, + *args, + **kwargs + ) -> torch.Tensor: return self.get_predictions(*args, **kwargs).argmax(dim=1) - def suitable_model_managers(self): + def suitable_model_managers( + self + ) -> set: return self.model_manager_restrictions @staticmethod - def arguments_read(*args, **kwargs): + def arguments_read( + *args, **kwargs + ) -> [torch.Tensor, torch.Tensor, Type, torch.Tensor]: """ The method is launched when the forward is executed extracts from the variable data or kwargs the data necessary to pass the forward: x, edge_index, batch From 87854b303ccebb506accfe3f2d0bae73216224b7 Mon Sep 17 00:00:00 2001 From: "lukyanov_kirya@bk.ru" Date: Thu, 21 Nov 2024 15:38:09 +0300 Subject: [PATCH 3/8] make better files in aux, models_zoo.py + some small fix --- src/aux/custom_decorators.py | 36 +++++++++--- src/aux/data_info.py | 85 ++++++++++++++++++++------- src/aux/declaration.py | 48 +++++++++++---- src/aux/prefix_storage.py | 69 +++++++++++++++++----- src/aux/utils.py | 25 ++++++-- src/models_builder/gnn_constructor.py | 3 +- src/models_builder/gnn_models.py | 1 + src/models_builder/models_zoo.py | 6 +- 8 files changed, 210 insertions(+), 63 deletions(-) diff --git a/src/aux/custom_decorators.py b/src/aux/custom_decorators.py index f6b9cb2..7f290bd 100644 --- a/src/aux/custom_decorators.py +++ b/src/aux/custom_decorators.py @@ -2,19 +2,28 @@ from functools import wraps import logging import functools +from typing import Callable logging.basicConfig(level=logging.INFO) -def retry(max_tries=3, delay_seconds=1): +def retry( + max_tries: int = 3, + delay_seconds: int = 1 +): """ Allows you to re-execute the program after the Nth amount of time :param max_tries: number of restart attempts :param delay_seconds: time interval between attempts """ - def decorator_retry(func): + def decorator_retry( + func: Callable + ) -> Callable: @wraps(func) - def wrapper_retry(*args, **kwargs): + def wrapper_retry( + *args, + **kwargs + ): tries = 0 while tries < max_tries: try: @@ -30,13 +39,17 @@ def wrapper_retry(*args, **kwargs): return decorator_retry -def memoize(func): +def memoize( + func: Callable +) -> Callable: """ Caching function """ cache = {} - def wrapper(*args): + def wrapper( + *args + ): if args in cache: return cache[args] else: @@ 
-47,11 +60,16 @@ def wrapper(*args): return wrapper -def timing_decorator(func): +def timing_decorator( + func: Callable +) -> Callable: """ Timing functions """ - def wrapper(*args, **kwargs): + def wrapper( + *args, + **kwargs + ): start_time = time.time() result = func(*args, **kwargs) end_time = time.time() @@ -61,7 +79,9 @@ def wrapper(*args, **kwargs): return wrapper -def log_execution(func): +def log_execution( + func: Callable +) -> Callable: """ Function call logging """ diff --git a/src/aux/data_info.py b/src/aux/data_info.py index 390823a..ae82230 100644 --- a/src/aux/data_info.py +++ b/src/aux/data_info.py @@ -2,6 +2,8 @@ import importlib.util import json import logging +from typing import List, Tuple, Union + # from pydantic.utils import deep_update # from pydantic.v1.utils import deep_update @@ -24,7 +26,8 @@ class DataInfo: """ @staticmethod - def refresh_all_data_info(): + def refresh_all_data_info( + ) -> None: """ Calling all files to update with information about saved objects """ @@ -35,7 +38,8 @@ def refresh_all_data_info(): DataInfo.refresh_explanations_dir_structure() @staticmethod - def refresh_data_dir_structure(): + def refresh_data_dir_structure( + ) -> None: """ Calling a file update with information about saved raw datasets """ @@ -51,7 +55,8 @@ def refresh_data_dir_structure(): prev_path = path @staticmethod - def refresh_models_dir_structure(): + def refresh_models_dir_structure( + ) -> None: """ Calling a file update with information about saved models """ @@ -62,7 +67,8 @@ def refresh_models_dir_structure(): f.write(str(Path(*path)) + '\n') @staticmethod - def refresh_explanations_dir_structure(): + def refresh_explanations_dir_structure( + ) -> None: """ Calling a file update with information about saved explanations """ @@ -73,7 +79,8 @@ def refresh_explanations_dir_structure(): f.write(str(Path(*path)) + '\n') @staticmethod - def refresh_data_var_dir_structure(): + def refresh_data_var_dir_structure( + ) -> None: """ Calling a file update with information about saved prepared datasets """ @@ -84,7 +91,9 @@ def refresh_data_var_dir_structure(): f.write(str(Path(*path)) + '\n') @staticmethod - def take_keys_etc_by_prefix(prefix): + def take_keys_etc_by_prefix( + prefix: Tuple + ) -> [List, List, dict, int]: """ :param prefix: what data and in what order were used to form the path when saving the object @@ -114,7 +123,11 @@ def take_keys_etc_by_prefix(prefix): return keys_list, full_keys_list, dir_structure, empty_dir_shift @staticmethod - def values_list_by_path_and_keys(path, full_keys_list, dir_structure): + def values_list_by_path_and_keys( + path: Union[str, Path], + full_keys_list: List, + dir_structure: dict + ) -> List: """ :param path: path of the saved object @@ -133,7 +146,10 @@ def values_list_by_path_and_keys(path, full_keys_list, dir_structure): return parts_val @staticmethod - def values_list_and_technical_files_by_path_and_prefix(path, prefix): + def values_list_and_technical_files_by_path_and_prefix( + path: Union[str, Path], + prefix: Tuple + ) -> [List, dict]: """ :param path: path of the saved object @@ -169,12 +185,16 @@ def values_list_and_technical_files_by_path_and_prefix(path, prefix): else: file_name = file_info_dict["file_name"] file_name += file_info_dict["format"] - description_info.update({key: {parts_val[-1]: os.path.join(os.path.join(*path[:parts_parse]), file_name)}}) + description_info.update( + {key: {parts_val[-1]: os.path.join(os.path.join(*path[:parts_parse]), file_name)}}) parts_parse += 1 return parts_val, 
description_info @staticmethod - def fill_prefix_storage(prefix, file_with_paths): + def fill_prefix_storage( + prefix: Tuple, + file_with_paths: Union[str, Path] + ) -> [PrefixStorage, dict]: """ Fill prefix storage by file with paths @@ -183,7 +203,7 @@ def fill_prefix_storage(prefix, file_with_paths): :param file_with_paths: file with paths of saved objects :return: fill prefix storage and dict with description_info about objects use hash """ - keys_list, full_keys_list, dir_structure, empty_dir_shift =\ + keys_list, full_keys_list, dir_structure, empty_dir_shift = \ DataInfo.take_keys_etc_by_prefix(prefix=prefix) ps = PrefixStorage(keys_list) with open(file_with_paths, 'r', encoding='utf-8') as f: @@ -202,7 +222,10 @@ def fill_prefix_storage(prefix, file_with_paths): return ps, description_info @staticmethod - def deep_update(d, u): + def deep_update( + d: dict, + u: dict + ) -> dict: for k, v in u.items(): if isinstance(v, collections.abc.Mapping): d[k] = DataInfo.deep_update(d.get(k, {}), v) @@ -211,14 +234,19 @@ def deep_update(d, u): return d @staticmethod - def description_info_with_paths_to_description_info_with_files_values(description_info, root_path): + def description_info_with_paths_to_description_info_with_files_values( + description_info: dict, + root_path: Union[str, Path] + ) -> dict: for description_info_key, description_info_val in description_info.items(): for obj_name, obj_file_path in description_info_val.items(): with open(os.path.join(root_path, obj_file_path)) as f: description_info[description_info_key][obj_name] = f.read() return description_info + @staticmethod - def explainers_parse(): + def explainers_parse( + ) -> [PrefixStorage, dict]: """ Parses the path to explainers from a technical file with the paths of all saved explainers. """ @@ -232,7 +260,8 @@ def explainers_parse(): return ps, description_info @staticmethod - def models_parse(): + def models_parse( + ) -> [PrefixStorage, dict]: """ Parses the path to models from a technical file with the paths of all saved models. """ @@ -246,7 +275,8 @@ def models_parse(): return ps, description_info @staticmethod - def data_parse(): + def data_parse( + ) -> [PrefixStorage, dict]: """ Parses the path to raw datasets from a technical file with the paths of all saved raw datasets. """ @@ -257,7 +287,8 @@ def data_parse(): return ps @staticmethod - def data_var_parse(): + def data_var_parse( + ) -> [PrefixStorage, dict]: """ Parses the path to prepared datasets from a technical file with the paths of all saved prepared datasets. """ @@ -268,7 +299,9 @@ def data_var_parse(): return ps @staticmethod - def clean_prepared_data(dry_run=False): + def clean_prepared_data( + dry_run: bool = False + ) -> None: """ Remove all prepared data for all datasets. 
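+
+        Example:
+        >>> DataInfo.clean_prepared_data(dry_run=True)  # only prints paths, removes nothing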
""" @@ -279,7 +312,9 @@ def clean_prepared_data(dry_run=False): shutil.rmtree(path) @staticmethod - def all_obj_ver_by_obj_path(obj_dir_path): + def all_obj_ver_by_obj_path( + obj_dir_path: Union[str, Path] + ) -> set: """ :param obj_dir_path: path to the saved object @@ -294,7 +329,9 @@ def all_obj_ver_by_obj_path(obj_dir_path): return set(vers_ind) @staticmethod - def del_all_empty_folders(dir_path): + def del_all_empty_folders( + dir_path: Union[str, Path] + ) -> None: """ Deletes all empty folders and files with meta information in the selected directory @@ -316,7 +353,8 @@ def del_all_empty_folders(dir_path): class UserCodeInfo: @staticmethod - def user_models_list_ref(): + def user_models_list_ref( + ) -> dict: """ :return: dict with information about user models objects in directory /user_model_list Contains information about objects class name, objects names and import paths @@ -371,7 +409,10 @@ def models_init(): return user_models_obj_dict_info @staticmethod - def take_user_model_obj(user_file_path, obj_name: str): + def take_user_model_obj( + user_file_path: Union[str, Path], + obj_name: str + ) -> object: """ :param user_file_path: path to the user file with user model :param obj_name: user object name diff --git a/src/aux/declaration.py b/src/aux/declaration.py index 7988911..ee3eefc 100644 --- a/src/aux/declaration.py +++ b/src/aux/declaration.py @@ -1,5 +1,7 @@ import json +from typing import Union, Type +from aux.configs import DatasetConfig, DatasetVarConfig from aux.utils import MODELS_DIR, GRAPHS_DIR, EXPLANATIONS_DIR, hash_data_sha256, \ SAVE_DIR_STRUCTURE_PATH import os @@ -12,7 +14,11 @@ class Declare: """ @staticmethod - def obj_info_to_path(what_save=None, previous_path=None, obj_info=None): + def obj_info_to_path( + what_save: str = None, + previous_path: Union[str, Path] = None, + obj_info: Union[None, list, tuple, dict] = None + ) -> [Path, list]: """ :param what_save: the path for which object is being built. 
Now support: data_root, data_prepared, models, explanations @@ -87,7 +93,9 @@ def obj_info_to_path(what_save=None, previous_path=None, obj_info=None): return path, files_paths @staticmethod - def dataset_root_dir(dataset_config): + def dataset_root_dir( + dataset_config: DatasetConfig + ) -> [Path, list]: """ :param dataset_config: DatasetConfig :return: forms the path to the data folder and adds to it the path to a specific dataset @@ -99,7 +107,10 @@ def dataset_root_dir(dataset_config): return path, files_paths @staticmethod - def dataset_prepared_dir(dataset_config, dataset_var_config): + def dataset_prepared_dir( + dataset_config: DatasetConfig, + dataset_var_config: DatasetVarConfig + ) -> [Path, list]: """ :param dataset_config: DatasetConfig :param dataset_var_config: DatasetVarConfig @@ -128,7 +139,9 @@ def dataset_prepared_dir(dataset_config, dataset_var_config): return path, files_paths @staticmethod - def models_path(class_obj): + def models_path( + class_obj: Type + ) -> [Path, list]: """ :param class_obj: class base on GNNModelManager :return: The path where the model will be saved @@ -189,8 +202,8 @@ def declare_model_by_config( mi_attack_hash: str, evasion_attack_hash: str, poison_attack_hash: str, - epochs=None, - ): + epochs: Union[int, str] = None, + ) -> [Path, list]: """ Formation of the way to save the path of the model in the root of the project according to its hyperparameters and features @@ -199,6 +212,12 @@ def declare_model_by_config( :param model_ver_ind: index of explain version :param gnn_name: gnn hash :param epochs: number of epochs during which the model was trained + :param mi_defense_hash: + :param evasion_defense_hash: + :param poison_defense_hash: + :param mi_attack_hash: + :param evasion_attack_hash: + :param poison_attack_hash: :return: the path where the model is saved use information from ModelConfig """ if not isinstance(model_ver_ind, int) or model_ver_ind < 0: @@ -223,16 +242,19 @@ def declare_model_by_config( return path, files_paths @staticmethod - def explanation_file_path(models_path: str, explainer_name: str, - explainer_ver_ind: int = None, - explainer_run_kwargs=None, explainer_init_kwargs=None): + def explanation_file_path( + models_path: str, + explainer_name: str, + explainer_ver_ind: int = None, + explainer_run_kwargs: dict = None, + explainer_init_kwargs: dict = None + ) -> [Path, list]: """ :param explainer_init_kwargs: dict with kwargs for explainer class :param explainer_run_kwargs:dict with kwargs for run explanation :param models_path: model path :param explainer_name: explainer name. Example: Zorro :param explainer_ver_ind: index of explain version - :param explainer_attack_type: type of attack on explainer. 
Now support: original :return: path for explanations result file and list with technical files """ explainer_init_kwargs = explainer_init_kwargs.copy() @@ -279,7 +301,10 @@ def explanation_file_path(models_path: str, explainer_name: str, return path, files_paths @staticmethod - def explainer_kwargs_path_full(model_path, explainer_path): + def explainer_kwargs_path_full( + model_path: Union[str, Path], + explainer_path: Union[str, Path] + ) -> list: """ :param model_path: model path :param explainer_path: explanation path @@ -287,6 +312,7 @@ def explainer_kwargs_path_full(model_path, explainer_path): """ path = Path(str(model_path).replace(str(MODELS_DIR), str(EXPLANATIONS_DIR))) what_save = "explanations" + # BUG Misha, check is correct next line, because in def obj_info_to_path can't be Path or str obj_info = explainer_path _, files_paths = Declare.obj_info_to_path(what_save=what_save, previous_path=path, diff --git a/src/aux/prefix_storage.py b/src/aux/prefix_storage.py index 376a780..8b3c844 100644 --- a/src/aux/prefix_storage.py +++ b/src/aux/prefix_storage.py @@ -1,6 +1,7 @@ import json from json.encoder import JSONEncoder from pathlib import Path +from typing import Union class PrefixStorage: @@ -10,26 +11,38 @@ class PrefixStorage: * adding, removing, filtering, iterating elements; * gathering contents from file structure. """ - def __init__(self, keys: (tuple, list)): + def __init__( + self, + keys: Union[tuple, list] + ): assert isinstance(keys, (tuple, list)) assert len(keys) >= 1 self._keys = keys self.content = {} if len(keys) > 1 else set() @property - def depth(self): + def depth( + self + ) -> int: return len(self._keys) @property - def keys(self): + def keys( + self + ) -> tuple: return tuple(self._keys) - def size(self): + def size( + self + ) -> int: def count(obj): return sum(count(_) for _ in obj.values()) if isinstance(obj, dict) else len(obj) return count(self.content) - def add(self, values: (dict, tuple, list)): + def add( + self, + values: Union[dict, tuple, list] + ) -> None: """ Add one list of values. """ @@ -56,7 +69,11 @@ def add(obj, depth): else: raise TypeError("dict, tuple, or list were expected") - def merge(self, ps, ignore_conflicts=False): + def merge( + self, + ps, + ignore_conflicts: bool = False + ) -> None: """ Extend this with another PrefixStorage with same keys. if ignore_conflicts=True, do not raise Exception when values sets intersect. @@ -80,7 +97,10 @@ def merge(content1, content2): merge(self.content, ps.content) - def remove(self, values: (dict, tuple, list)): + def remove( + self, + values: Union[dict, tuple, list] + ) -> None: """ Remove one tuple of values if it is present. """ @@ -97,7 +117,10 @@ def rm(obj, depth): rm(self.content, 0) - def filter(self, key_values: dict): + def filter( + self, + key_values: dict + ): """ Find all items satisfying specified key values. Returns a new PrefixStorage. """ @@ -132,7 +155,10 @@ def filter(obj, depth): ps.content = filter(self.content, 0) return ps - def check(self, values: (dict, tuple, list)): + def check( + self, + values: Union[dict, tuple, list] + ) -> bool: """ Check if a tuple of values is present. """ @@ -151,7 +177,9 @@ def check(self, values: (dict, tuple, list)): else: return False - def __iter__(self): + def __iter__( + self + ): def enum(obj, elems): if isinstance(obj, (set, list)): for e in obj: @@ -165,7 +193,9 @@ def enum(obj, elems): yield _ @staticmethod - def from_json(string): + def from_json( + string: str + ): """ Construct PrefixStorage object from a json string. 
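A short usage sketch of `PrefixStorage`, assuming two string keys; `add`, `size`, `check`, and `filter` behave as their docstrings above describe:

    ps = PrefixStorage(keys=("group", "name"))
    ps.add(("attacks", "fgsm"))
    ps.add(("attacks", "pgd"))
    ps.add(("defenses", "adv_training"))
    assert ps.size() == 3
    assert ps.check(("attacks", "pgd"))
    attacks_only = ps.filter({"group": "attacks"})  # a new PrefixStorage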
""" @@ -174,7 +204,10 @@ def from_json(string): ps.content = data["content"] return ps - def to_json(self, **dump_args): + def to_json( + self, + **dump_args + ) -> str: """ Return json string. """ class Encoder(JSONEncoder): def default(self, obj): @@ -183,7 +216,11 @@ def default(self, obj): return json.JSONEncoder.default(self, obj) return json.dumps({"keys": self.keys, "content": self.content}, cls=Encoder, **dump_args) - def fill_from_folder(self, path: Path, file_pattern=r".*"): + def fill_from_folder( + self, + path: Path, + file_pattern: str = r".*" + ) -> None: """ Recursively walk over the given folder and repeat its structure. The content will be replaced. @@ -208,7 +245,11 @@ def walk(p, elems): self.add(e) print(f"Added {self.size()} items of {len(res)} files found.") - def remap(self, mapping, only_values=False): + def remap( + self, + mapping, + only_values: bool = False + ): """ Change keys order and combination. """ diff --git a/src/aux/utils.py b/src/aux/utils.py index 2acd503..017951b 100644 --- a/src/aux/utils.py +++ b/src/aux/utils.py @@ -3,6 +3,8 @@ import warnings from pathlib import Path from pydoc import locate +from typing import Union, Type, Any + import numpy as np root_dir = Path(__file__).parent.parent.parent.resolve() # directory of source root @@ -40,11 +42,16 @@ TECHNICAL_PARAMETER_KEY = "_technical_parameter" -def hash_data_sha256(data): +def hash_data_sha256( + data +) -> str: return hashlib.sha256(data).hexdigest() -def import_by_name(name: str, packs: list = None): +def import_by_name( + name: str, + packs: list = None +) -> None: """ Import name from packages, return class :param name: class name, full or relative @@ -63,7 +70,9 @@ def import_by_name(name: str, packs: list = None): raise ImportError(f"Unknown {packs} model '{name}', couldn't import.") -def model_managers_info_by_names_list(model_managers_names: set): +def model_managers_info_by_names_list( + model_managers_names: set +) -> dict: """ :param model_managers_names: set with model managers class names (user and framework) :return: dict with info about model managers @@ -86,7 +95,11 @@ def model_managers_info_by_names_list(model_managers_names: set): return model_managers_info -def setting_class_default_parameters(class_name: str, class_kwargs: dict, default_parameters_file_path): +def setting_class_default_parameters( + class_name: str, + class_kwargs: dict, + default_parameters_file_path: Union[str, Path] +) -> [dict, dict]: """ :param class_name: class name, should be same in default_parameters_file :param class_kwargs: dict with parameters, which needs to be supplemented with default parameters @@ -143,6 +156,8 @@ def setting_class_default_parameters(class_name: str, class_kwargs: dict, defaul return class_kwargs_for_save, class_kwargs_for_init -def all_subclasses(cls): +def all_subclasses( + cls: Type[Any] +) -> set: return set(cls.__subclasses__()).union( [s for c in cls.__subclasses__() for s in all_subclasses(c)]) diff --git a/src/models_builder/gnn_constructor.py b/src/models_builder/gnn_constructor.py index e09e8a5..3df22e8 100644 --- a/src/models_builder/gnn_constructor.py +++ b/src/models_builder/gnn_constructor.py @@ -6,7 +6,6 @@ import torch from torch.nn.parameter import UninitializedParameter from torch.utils import hooks -from ..parameter import Parameter from torch.utils.hooks import RemovableHandle from torch_geometric.nn import MessagePassing @@ -577,7 +576,7 @@ def get_predictions( def get_parameters( self - ) -> Iterator[Parameter]: + ) -> Iterator: return 
self.parameters() def get_answer( diff --git a/src/models_builder/gnn_models.py b/src/models_builder/gnn_models.py index 846b418..bd547b6 100644 --- a/src/models_builder/gnn_models.py +++ b/src/models_builder/gnn_models.py @@ -1228,6 +1228,7 @@ def run_model( def evaluate_model( self, gen_dataset: DatasetManager, + metrics: Union[List[Metric], Metric] ) -> dict: """ Compute metrics for a model result on a part of dataset specified by the metric mask. diff --git a/src/models_builder/models_zoo.py b/src/models_builder/models_zoo.py index 72003bd..a8a9b18 100644 --- a/src/models_builder/models_zoo.py +++ b/src/models_builder/models_zoo.py @@ -1,8 +1,12 @@ +from base.datasets_processing import DatasetManager from models_builder.gnn_constructor import FrameworkGNNConstructor from aux.configs import ModelConfig, ModelStructureConfig -def model_configs_zoo(dataset, model_name): +def model_configs_zoo( + dataset: DatasetManager, + model_name: str +): gat_gin_lin = FrameworkGNNConstructor( model_config=ModelConfig( structure=ModelStructureConfig( From b326cb0c72f69ef48ced97b0a48163ea55bbdf44 Mon Sep 17 00:00:00 2001 From: "lukyanov_kirya@bk.ru" Date: Thu, 21 Nov 2024 15:48:42 +0300 Subject: [PATCH 4/8] make better files in defense --- src/defense/defense_base.py | 18 ++++++-- src/defense/evasion_defense.py | 84 +++++++++++++++++++++++++--------- src/defense/mi_defense.py | 29 +++++++++--- src/defense/poison_defense.py | 42 +++++++++++++---- 4 files changed, 134 insertions(+), 39 deletions(-) diff --git a/src/defense/defense_base.py b/src/defense/defense_base.py index 1951bf6..d46921a 100644 --- a/src/defense/defense_base.py +++ b/src/defense/defense_base.py @@ -1,14 +1,26 @@ +from typing import Type + +from base.datasets_processing import DatasetManager + + class Defender: name = "Defender" - def __init__(self): + def __init__( + self + ): pass - def defense_diff(self): + def defense_diff( + self + ): pass @staticmethod - def check_availability(gen_dataset, model_manager): + def check_availability( + gen_dataset: DatasetManager, + model_manager: Type + ): return False diff --git a/src/defense/evasion_defense.py b/src/defense/evasion_defense.py index 8a85f34..18d2cca 100644 --- a/src/defense/evasion_defense.py +++ b/src/defense/evasion_defense.py @@ -1,8 +1,10 @@ +from typing import Type + import torch from defense.defense_base import Defender from src.aux.utils import import_by_name -from src.aux.configs import ModelModificationConfig, ConfigPattern +from src.aux.configs import ModelModificationConfig, ConfigPattern, EvasionAttackConfig from src.aux.utils import POISON_ATTACK_PARAMETERS_PATH, POISON_DEFENSE_PARAMETERS_PATH, EVASION_ATTACK_PARAMETERS_PATH, \ EVASION_DEFENSE_PARAMETERS_PATH from attacks.evasion_attacks import FGSMAttacker @@ -12,24 +14,43 @@ import copy -class EvasionDefender(Defender): - def __init__(self, **kwargs): +class EvasionDefender( + Defender +): + def __init__( + self, + **kwargs + ): super().__init__() - def pre_batch(self, **kwargs): + def pre_batch( + self, + **kwargs + ): pass - def post_batch(self, **kwargs): + def post_batch( + self, + **kwargs + ): pass -class EmptyEvasionDefender(EvasionDefender): +class EmptyEvasionDefender( + EvasionDefender +): name = "EmptyEvasionDefender" - def pre_batch(self, **kwargs): + def pre_batch( + self, + **kwargs + ): pass - def post_batch(self, **kwargs): + def post_batch( + self, + **kwargs + ): pass @@ -52,27 +73,39 @@ def post_batch(self, model_manager, batch, loss, **kwargs): # TODO Kirill, add code in pre_batch -class 
QuantizationDefender(EvasionDefender): +class QuantizationDefender( + EvasionDefender +): name = "QuantizationDefender" - def __init__(self, qbit=8): + def __init__( + self, + qbit: int = 8 + ): super().__init__() self.regularization_strength = qbit - def pre_batch(self, **kwargs): + def pre_batch( + self, + **kwargs + ): + # TODO Kirill pass -class DataWrap: - def __init__(self, batch) -> None: - self.data = batch - self.dataset = self - - -class AdvTraining(EvasionDefender): +class AdvTraining( + EvasionDefender +): + # TODO Kirill, rewrite name = "AdvTraining" - def __init__(self, attack_name=None, attack_config=None, attack_type=None, device='cpu'): + def __init__( + self, + attack_name: str = None, + attack_config: EvasionAttackConfig = None, + attack_type: str = None, + device: str = 'cpu' + ): super().__init__() assert device is not None, "Please specify 'device'!" if not attack_config: @@ -118,7 +151,11 @@ def __init__(self, attack_name=None, attack_config=None, attack_type=None, devic else: raise KeyError(f"There is no {self.attack_config._class_name} class") - def pre_batch(self, model_manager, batch): + def pre_batch( + self, + model_manager: Type, + batch + ): super().pre_batch(model_manager=model_manager, batch=batch) self.perturbed_gen_dataset = data.Data() self.perturbed_gen_dataset.data = copy.deepcopy(batch) @@ -129,7 +166,12 @@ def pre_batch(self, model_manager, batch): gen_dataset=self.perturbed_gen_dataset, mask_tensor=self.perturbed_gen_dataset.data.train_mask) - def post_batch(self, model_manager, batch, loss) -> dict: + def post_batch( + self, + model_manager: Type, + batch, + loss: torch.Tensor + ) -> dict: super().post_batch(model_manager=model_manager, batch=batch, loss=loss) # Output on perturbed data outputs = model_manager.gnn(self.perturbed_gen_dataset.data.x, self.perturbed_gen_dataset.data.edge_index) diff --git a/src/defense/mi_defense.py b/src/defense/mi_defense.py index 5563cde..75ce9c0 100644 --- a/src/defense/mi_defense.py +++ b/src/defense/mi_defense.py @@ -1,22 +1,39 @@ from defense.defense_base import Defender -class MIDefender(Defender): - def __init__(self, **kwargs): +class MIDefender( + Defender +): + def __init__( + self, + **kwargs + ): super().__init__() - def pre_batch(self, **kwargs): + def pre_batch( + self, + **kwargs + ): pass - def post_batch(self, **kwargs): + def post_batch( + self, + **kwargs + ): pass class EmptyMIDefender(MIDefender): name = "EmptyMIDefender" - def pre_batch(self, **kwargs): + def pre_batch( + self, + **kwargs + ): pass - def post_batch(self, **kwargs): + def post_batch( + self, + **kwargs + ): pass diff --git a/src/defense/poison_defense.py b/src/defense/poison_defense.py index d09df04..0e3020a 100644 --- a/src/defense/poison_defense.py +++ b/src/defense/poison_defense.py @@ -1,26 +1,43 @@ import numpy as np +from base.datasets_processing import DatasetManager from defense.defense_base import Defender -class PoisonDefender(Defender): - def __init__(self, **kwargs): +class PoisonDefender( + Defender +): + def __init__( + self, + **kwargs + ): super().__init__() - def defense(self, **kwargs): + def defense( + self, + **kwargs + ): pass -class BadRandomPoisonDefender(PoisonDefender): +class BadRandomPoisonDefender( + PoisonDefender +): name = "BadRandomPoisonDefender" - def __init__(self, n_edges_percent=0.1): + def __init__( + self, + n_edges_percent: float = 0.1 + ): self.defense_diff = None super().__init__() self.n_edges_percent = n_edges_percent - def defense(self, gen_dataset): + def defense( + self, + 
gen_dataset: DatasetManager + ) -> DatasetManager: edge_index = gen_dataset.data.edge_index random_indices = np.random.choice( edge_index.shape[1], @@ -35,12 +52,19 @@ def defense(self, gen_dataset): self.defense_diff = edge_index_diff return gen_dataset - def defense_diff(self): + def defense_diff( + self + ): return self.defense_diff -class EmptyPoisonDefender(PoisonDefender): +class EmptyPoisonDefender( + PoisonDefender +): name = "EmptyPoisonDefender" - def defense(self, gen_dataset): + def defense( + self, + gen_dataset:DatasetManager + ) -> DatasetManager: return gen_dataset From 0b827951af7aa94991813b3d52f654b9de50391f Mon Sep 17 00:00:00 2001 From: "lukyanov_kirya@bk.ru" Date: Thu, 21 Nov 2024 16:01:00 +0300 Subject: [PATCH 5/8] make better files in attacks --- src/attacks/attack_base.py | 23 ++++- src/attacks/evasion_attacks.py | 150 ++++++++++++++++++++++++--------- src/attacks/mi_attacks.py | 18 +++- src/attacks/poison_attacks.py | 56 ++++++------ 4 files changed, 174 insertions(+), 73 deletions(-) diff --git a/src/attacks/attack_base.py b/src/attacks/attack_base.py index 16754aa..fad6055 100644 --- a/src/attacks/attack_base.py +++ b/src/attacks/attack_base.py @@ -1,17 +1,32 @@ +from typing import Type + +from base.datasets_processing import DatasetManager + + class Attacker: name = "Attacker" - def __init__(self): + def __init__( + self + ): pass - def attack(self, **kwargs): + def attack( + self, + **kwargs + ): pass - def attack_diff(self): + def attack_diff( + self + ): pass @staticmethod - def check_availability(gen_dataset, model_manager): + def check_availability( + gen_dataset: DatasetManager, + model_manager: Type + ): return False diff --git a/src/attacks/evasion_attacks.py b/src/attacks/evasion_attacks.py index 47a65a5..a3ad225 100644 --- a/src/attacks/evasion_attacks.py +++ b/src/attacks/evasion_attacks.py @@ -1,8 +1,11 @@ +from typing import Type, Union + import torch import torch.nn.functional as F import numpy as np from attacks.attack_base import Attacker +from base.datasets_processing import DatasetManager # Nettack imports from src.attacks.nettack.nettack import Nettack @@ -11,31 +14,50 @@ # PGD imports from attacks.evasion_attacks_collection.pgd.utils import Projection, RandomSampling import torch.nn.functional as F -from torch_geometric.utils import to_dense_adj, dense_to_sparse, k_hop_subgraph +from torch_geometric.utils import k_hop_subgraph from tqdm import tqdm -from torch_geometric.nn import SGConv -class EvasionAttacker(Attacker): - def __init__(self, **kwargs): +class EvasionAttacker( + Attacker +): + def __init__( + self, + **kwargs + ): super().__init__() -class EmptyEvasionAttacker(EvasionAttacker): +class EmptyEvasionAttacker( + EvasionAttacker +): name = "EmptyEvasionAttacker" - def attack(self, **kwargs): + def attack( + self, + **kwargs + ): pass -class FGSMAttacker(EvasionAttacker): +class FGSMAttacker( + EvasionAttacker +): name = "FGSM" - def __init__(self, epsilon=0.1): + def __init__( + self, + epsilon: float = 0.1 + ): super().__init__() self.epsilon = epsilon - def attack(self, model_manager, gen_dataset, mask_tensor): + def attack( + self, + model_manager: Type, + gen_dataset: DatasetManager, + mask_tensor: torch.Tensor + ): gen_dataset.data.x.requires_grad = True output = model_manager.gnn(gen_dataset.data.x, gen_dataset.data.edge_index, gen_dataset.data.batch) loss = model_manager.loss_function(output[mask_tensor], @@ -49,16 +71,20 @@ def attack(self, model_manager, gen_dataset, mask_tensor): return gen_dataset -class 
PGDAttacker(EvasionAttacker): +class PGDAttacker( + EvasionAttacker +): name = "PGD" - def __init__(self, - is_feature_attack=False, - element_idx=0, - epsilon=0.5, - learning_rate=0.001, - num_iterations=100, - num_rand_trials=100): + def __init__( + self, + is_feature_attack: bool = False, + element_idx: int = 0, + epsilon: float = 0.5, + learning_rate: float = 0.001, + num_iterations: int = 100, + num_rand_trials: int = 100 + ): super().__init__() self.attack_diff = None @@ -69,13 +95,22 @@ def __init__(self, self.num_iterations = num_iterations self.num_rand_trials = num_rand_trials - def attack(self, model_manager, gen_dataset, mask_tensor): + def attack( + self, + model_manager: Type, + gen_dataset: DatasetManager, + mask_tensor: torch.Tensor + ) -> None: if gen_dataset.is_multi(): self._attack_on_graph(model_manager, gen_dataset) else: self._attack_on_node(model_manager, gen_dataset) - def _attack_on_node(self, model_manager, gen_dataset): + def _attack_on_node( + self, + model_manager: Type, + gen_dataset: DatasetManager + ) -> None: node_idx = self.element_idx edge_index = gen_dataset.data.edge_index @@ -118,7 +153,11 @@ def _attack_on_node(self, model_manager, gen_dataset): else: # structure attack pass - def _attack_on_graph(self, model_manager, gen_dataset): + def _attack_on_graph( + self, + model_manager: Type, + gen_dataset: DatasetManager + ): graph_idx = self.element_idx edge_index = gen_dataset.dataset[graph_idx].edge_index @@ -149,21 +188,26 @@ def _attack_on_graph(self, model_manager, gen_dataset): else: # structure attack pass - def attack_diff(self): + def attack_diff( + self + ): return self.attack_diff -class NettackEvasionAttacker(EvasionAttacker): +class NettackEvasionAttacker( + EvasionAttacker +): name = "NettackEvasionAttacker" - def __init__(self, - node_idx=0, - n_perturbations=None, - perturb_features=True, - perturb_structure=True, - direct=True, - n_influencers=0 - ): + def __init__( + self, + node_idx: int = 0, + n_perturbations: Union[int, None] = None, + perturb_features: bool = True, + perturb_structure: bool = True, + direct: bool = True, + n_influencers: int = 0 + ): super().__init__() self.attack_diff = None @@ -174,7 +218,12 @@ def __init__(self, self.direct = direct self.n_influencers = n_influencers - def attack(self, model_manager, gen_dataset, mask_tensor): + def attack( + self, + model_manager: Type, + gen_dataset: DatasetManager, + mask_tensor: torch.Tensor + ) -> DatasetManager: # Prepare data = gen_dataset.data _A_obs, _X_obs, _z_obs = data_to_csr_matrix(data) @@ -222,11 +271,17 @@ def attack(self, model_manager, gen_dataset, mask_tensor): return gen_dataset - def attack_diff(self): + def attack_diff( + self + ): return self.attack_diff @staticmethod - def _evasion(gen_dataset, feature_perturbations, structure_perturbations): + def _evasion( + gen_dataset: DatasetManager, + feature_perturbations, + structure_perturbations + ): cleaned_feat_pert = list(filter(None, feature_perturbations)) if cleaned_feat_pert: # list is not empty x = gen_dataset.data.x.clone() @@ -243,17 +298,27 @@ def _evasion(gen_dataset, feature_perturbations, structure_perturbations): # add edges for edge in cleaned_struct_pert: edge_index = torch.cat((edge_index, - torch.tensor((edge[0], edge[1]), dtype=torch.int32).to(torch.int64).unsqueeze(1)), dim=1) + torch.tensor((edge[0], edge[1]), dtype=torch.int32).to(torch.int64).unsqueeze( + 1)), dim=1) edge_index = torch.cat((edge_index, - torch.tensor((edge[1], edge[0]), dtype=torch.int32).to(torch.int64).unsqueeze(1)), 
dim=1) + torch.tensor((edge[1], edge[0]), dtype=torch.int32).to(torch.int64).unsqueeze( + 1)), dim=1) gen_dataset.data.edge_index = edge_index -class NettackGroupEvasionAttacker(EvasionAttacker): + +class NettackGroupEvasionAttacker( + EvasionAttacker +): name = "NettackGroupEvasionAttacker" - def __init__(self,node_idxs, **kwargs): + + def __init__( + self, + node_idxs: list, + **kwargs + ): super().__init__() - self.node_idxs = node_idxs # kwargs.get("node_idxs") + self.node_idxs = node_idxs # kwargs.get("node_idxs") assert isinstance(self.node_idxs, list) self.n_perturbations = kwargs.get("n_perturbations") self.perturb_features = kwargs.get("perturb_features") @@ -262,8 +327,13 @@ def __init__(self,node_idxs, **kwargs): self.n_influencers = kwargs.get("n_influencers") self.attacker = NettackEvasionAttacker(0, **kwargs) - def attack(self, model_manager, gen_dataset, mask_tensor): + def attack( + self, + model_manager: Type, + gen_dataset: DatasetManager, + mask_tensor: torch.Tensor + ) -> DatasetManager: for node_idx in self.node_idxs: self.attacker.node_idx = node_idx gen_dataset = self.attacker.attack(model_manager, gen_dataset, mask_tensor) - return gen_dataset \ No newline at end of file + return gen_dataset diff --git a/src/attacks/mi_attacks.py b/src/attacks/mi_attacks.py index 7e1d080..27ab795 100644 --- a/src/attacks/mi_attacks.py +++ b/src/attacks/mi_attacks.py @@ -1,13 +1,23 @@ from attacks.attack_base import Attacker -class MIAttacker(Attacker): - def __init__(self, **kwargs): +class MIAttacker( + Attacker +): + def __init__( + self, + **kwargs + ): super().__init__() -class EmptyMIAttacker(MIAttacker): +class EmptyMIAttacker( + MIAttacker +): name = "EmptyMIAttacker" - def attack(self, **kwargs): + def attack( + self, + **kwargs + ): pass diff --git a/src/attacks/poison_attacks.py b/src/attacks/poison_attacks.py index e02028d..d70418b 100644 --- a/src/attacks/poison_attacks.py +++ b/src/attacks/poison_attacks.py @@ -4,30 +4,51 @@ from attacks.attack_base import Attacker from pathlib import Path +from base.datasets_processing import DatasetManager + POISON_ATTACKS_DIR = Path(__file__).parent.resolve() / 'poison_attacks_collection' -class PoisonAttacker(Attacker): - def __init__(self, **kwargs): + +class PoisonAttacker( + Attacker +): + def __init__( + self, + **kwargs + ): super().__init__() -class EmptyPoisonAttacker(PoisonAttacker): +class EmptyPoisonAttacker( + PoisonAttacker +): name = "EmptyPoisonAttacker" - def attack(self, **kwargs): + def attack( + self, + **kwargs + ): pass -class RandomPoisonAttack(PoisonAttacker): +class RandomPoisonAttack( + PoisonAttacker +): name = "RandomPoisonAttack" - def __init__(self, n_edges_percent=0.1): + def __init__( + self, + n_edges_percent: float = 0.1 + ): self.attack_diff = None super().__init__() self.n_edges_percent = n_edges_percent - def attack(self, gen_dataset): + def attack( + self, + gen_dataset: DatasetManager + ) -> DatasetManager: edge_index = gen_dataset.data.edge_index random_indices = np.random.choice( edge_index.shape[1], @@ -42,22 +63,7 @@ def attack(self, gen_dataset): self.attack_diff = edge_index_diff return gen_dataset - def attack_diff(self): + def attack_diff( + self + ): return self.attack_diff - -class EmptyPoisonAttacker(PoisonAttacker): - name = "EmptyPoisonAttacker" - - def attack(self, **kwargs): - pass - -# for attack_name in POISON_ATTACKS_DIR.rglob("*_attack.py"): -# try: -# importlib.import_module(str(attack_name)) -# except ImportError: -# print(f"Couldn't import Attack: {attack_name}") - -# import 
attacks.poison_attacks_collection.metattack.meta_gradient_attack - -# # TODO this is not best practice to import this thing here this way -# from attacks.poison_attacks_collection.metattack.meta_gradient_attack import BaseMeta From ba86e9d82715d1ce7db0e2ca6a715797c6b8c777 Mon Sep 17 00:00:00 2001 From: "lukyanov_kirya@bk.ru" Date: Thu, 21 Nov 2024 16:17:45 +0300 Subject: [PATCH 6/8] change DatasetManager to GeneralDataset in def call type --- src/attacks/attack_base.py | 4 +-- src/attacks/evasion_attacks.py | 20 ++++++------ src/attacks/poison_attacks.py | 6 ++-- src/defense/defense_base.py | 4 +-- src/defense/poison_defense.py | 10 +++--- src/explainers/explainer.py | 53 +++++++++++++++++++++++++------- src/models_builder/gnn_models.py | 34 ++++++++++---------- src/models_builder/models_zoo.py | 4 +-- 8 files changed, 83 insertions(+), 52 deletions(-) diff --git a/src/attacks/attack_base.py b/src/attacks/attack_base.py index fad6055..80f2b86 100644 --- a/src/attacks/attack_base.py +++ b/src/attacks/attack_base.py @@ -1,6 +1,6 @@ from typing import Type -from base.datasets_processing import DatasetManager +from base.datasets_processing import GeneralDataset class Attacker: @@ -24,7 +24,7 @@ def attack_diff( @staticmethod def check_availability( - gen_dataset: DatasetManager, + gen_dataset: GeneralDataset, model_manager: Type ): return False diff --git a/src/attacks/evasion_attacks.py b/src/attacks/evasion_attacks.py index a3ad225..9dff2a6 100644 --- a/src/attacks/evasion_attacks.py +++ b/src/attacks/evasion_attacks.py @@ -5,7 +5,7 @@ import numpy as np from attacks.attack_base import Attacker -from base.datasets_processing import DatasetManager +from base.datasets_processing import GeneralDataset # Nettack imports from src.attacks.nettack.nettack import Nettack @@ -55,7 +55,7 @@ def __init__( def attack( self, model_manager: Type, - gen_dataset: DatasetManager, + gen_dataset: GeneralDataset, mask_tensor: torch.Tensor ): gen_dataset.data.x.requires_grad = True @@ -98,7 +98,7 @@ def __init__( def attack( self, model_manager: Type, - gen_dataset: DatasetManager, + gen_dataset: GeneralDataset, mask_tensor: torch.Tensor ) -> None: if gen_dataset.is_multi(): @@ -109,7 +109,7 @@ def attack( def _attack_on_node( self, model_manager: Type, - gen_dataset: DatasetManager + gen_dataset: GeneralDataset ) -> None: node_idx = self.element_idx @@ -156,7 +156,7 @@ def _attack_on_node( def _attack_on_graph( self, model_manager: Type, - gen_dataset: DatasetManager + gen_dataset: GeneralDataset ): graph_idx = self.element_idx @@ -221,9 +221,9 @@ def __init__( def attack( self, model_manager: Type, - gen_dataset: DatasetManager, + gen_dataset: GeneralDataset, mask_tensor: torch.Tensor - ) -> DatasetManager: + ) -> GeneralDataset: # Prepare data = gen_dataset.data _A_obs, _X_obs, _z_obs = data_to_csr_matrix(data) @@ -278,7 +278,7 @@ def attack_diff( @staticmethod def _evasion( - gen_dataset: DatasetManager, + gen_dataset: GeneralDataset, feature_perturbations, structure_perturbations ): @@ -330,9 +330,9 @@ def __init__( def attack( self, model_manager: Type, - gen_dataset: DatasetManager, + gen_dataset: GeneralDataset, mask_tensor: torch.Tensor - ) -> DatasetManager: + ) -> GeneralDataset: for node_idx in self.node_idxs: self.attacker.node_idx = node_idx gen_dataset = self.attacker.attack(model_manager, gen_dataset, mask_tensor) diff --git a/src/attacks/poison_attacks.py b/src/attacks/poison_attacks.py index d70418b..dcfffee 100644 --- a/src/attacks/poison_attacks.py +++ b/src/attacks/poison_attacks.py @@ -4,7 
+4,7 @@ from attacks.attack_base import Attacker from pathlib import Path -from base.datasets_processing import DatasetManager +from base.datasets_processing import GeneralDataset POISON_ATTACKS_DIR = Path(__file__).parent.resolve() / 'poison_attacks_collection' @@ -47,8 +47,8 @@ def __init__( def attack( self, - gen_dataset: DatasetManager - ) -> DatasetManager: + gen_dataset: GeneralDataset + ) -> GeneralDataset: edge_index = gen_dataset.data.edge_index random_indices = np.random.choice( edge_index.shape[1], diff --git a/src/defense/defense_base.py b/src/defense/defense_base.py index d46921a..3997e5e 100644 --- a/src/defense/defense_base.py +++ b/src/defense/defense_base.py @@ -1,6 +1,6 @@ from typing import Type -from base.datasets_processing import DatasetManager +from base.datasets_processing import GeneralDataset class Defender: @@ -18,7 +18,7 @@ def defense_diff( @staticmethod def check_availability( - gen_dataset: DatasetManager, + gen_dataset: GeneralDataset, model_manager: Type ): return False diff --git a/src/defense/poison_defense.py b/src/defense/poison_defense.py index 0e3020a..fdb9dc2 100644 --- a/src/defense/poison_defense.py +++ b/src/defense/poison_defense.py @@ -1,6 +1,6 @@ import numpy as np -from base.datasets_processing import DatasetManager +from base.datasets_processing import GeneralDataset from defense.defense_base import Defender @@ -36,8 +36,8 @@ def __init__( def defense( self, - gen_dataset: DatasetManager - ) -> DatasetManager: + gen_dataset: GeneralDataset + ) -> GeneralDataset: edge_index = gen_dataset.data.edge_index random_indices = np.random.choice( edge_index.shape[1], @@ -65,6 +65,6 @@ class EmptyPoisonDefender( def defense( self, - gen_dataset:DatasetManager - ) -> DatasetManager: + gen_dataset: GeneralDataset + ) -> GeneralDataset: return gen_dataset diff --git a/src/explainers/explainer.py b/src/explainers/explainer.py index bb80e3f..f7fce49 100644 --- a/src/explainers/explainer.py +++ b/src/explainers/explainer.py @@ -1,18 +1,30 @@ from time import sleep from abc import ABC, abstractmethod +from typing import Union, Callable, Any, Type + +from flask_socketio import SocketIO from tqdm import tqdm -from base.datasets_processing import GeneralDataset +from base.datasets_processing import GeneralDataset, DatasetManager class ProgressBar(tqdm): - def __init__(self, socket, dst, *args, **kwargs): + def __init__( + self, + socket: SocketIO, + dst, + *args, + **kwargs + ): super(ProgressBar, self).__init__(*args, **kwargs) self.dst = dst self.socket = socket self._kwargs = {} - def _report(self, obligate=False): + def _report( + self, + obligate: bool = False + ) -> None: if self.socket is not None: msg = {} msg.update(self._kwargs) @@ -23,20 +35,32 @@ def _report(self, obligate=False): }}) self.socket.send(block=self.dst, msg=msg, tag=self.dst + '_progress', obligate=obligate) - def reset(self, total=None, **kwargs): + def reset( + self, + total: Union[float, None] = None, + **kwargs + ): res = super().reset(total=total) self._kwargs = kwargs self._report(obligate=True) return res - def update(self, n=1): + def update( + self, + n: int = 1 + ): res = super().update(n=n) self._report(obligate=True) return res -def finalize_decorator(func): - def wrapper(*args, **kwargs): +def finalize_decorator( + func: Callable +) -> Callable: + def wrapper( + *args, + **kwargs + ) -> Any: # Before call self: Explainer = args[0] self._run_mode = args[1] @@ -50,18 +74,25 @@ def wrapper(*args, **kwargs): return wrapper -class Explainer(ABC): +class Explainer( + ABC +): """ 
Superclass for supported explainers. """ name = 'Explainer' @staticmethod - def check_availability(gen_dataset, model_manager): + def check_availability( + gen_dataset: DatasetManager, + model_manager: Type + ) -> bool: """ Availability check for the given dataset and model manager. """ return False - def __init__(self, gen_dataset: GeneralDataset, model): + def __init__( + self, + gen_dataset: GeneralDataset, model): """ :param gen_dataset: dataset :param model: GNN model @@ -170,4 +201,4 @@ def _finalize(self): self.explanation = Explanation(type='string', local=False, data=data) # Remove unpickable attributes - self.pbar = None \ No newline at end of file + self.pbar = None diff --git a/src/models_builder/gnn_models.py b/src/models_builder/gnn_models.py index bd547b6..b95a25a 100644 --- a/src/models_builder/gnn_models.py +++ b/src/models_builder/gnn_models.py @@ -23,7 +23,7 @@ hash_data_sha256, \ TECHNICAL_PARAMETER_KEY, IMPORT_INFO_KEY, OPTIMIZERS_PARAMETERS_PATH, FUNCTIONS_PARAMETERS_PATH from aux.declaration import Declare -from base.datasets_processing import DatasetManager +from base.datasets_processing import GeneralDataset from explainers.explainer import ProgressBar from explainers.ProtGNN.MCTS import mcts_args from attacks.evasion_attacks import EvasionAttacker @@ -204,7 +204,7 @@ def train_model( def train_1_step( self, - gen_dataset: DatasetManager + gen_dataset: GeneralDataset ): """ Perform 1 step of model training. """ @@ -213,7 +213,7 @@ def train_1_step( def train_complete( self, - gen_dataset: DatasetManager, + gen_dataset: GeneralDataset, steps: int = None, **kwargs ) -> None: @@ -755,7 +755,7 @@ def take_gnn_obj( def before_epoch( self, - gen_dataset: DatasetManager + gen_dataset: GeneralDataset ): """ This hook is called before training the next training epoch """ @@ -763,7 +763,7 @@ def before_epoch( def after_epoch( self, - gen_dataset: DatasetManager + gen_dataset: GeneralDataset ): """ This hook is called after training the next training epoch """ @@ -887,7 +887,7 @@ def init( def train_complete( self, - gen_dataset: DatasetManager, + gen_dataset: GeneralDataset, steps: int = None, pbar: Protocol = None, metrics: Union[List[Metric], Metric] = None, @@ -910,7 +910,7 @@ def train_complete( def early_stopping( self, train_loss, - gen_dataset: DatasetManager, + gen_dataset: GeneralDataset, metrics: Union[List[Metric], Metric], steps: int ) -> bool: @@ -918,7 +918,7 @@ def early_stopping( def train_1_step( self, - gen_dataset: DatasetManager + gen_dataset: GeneralDataset ) -> List[Union[float, int]]: task_type = gen_dataset.domain() if task_type == "single-graph": @@ -1063,7 +1063,7 @@ def save_model( def report_results( self, train_loss, - gen_dataset: DatasetManager, + gen_dataset: GeneralDataset, metrics: List[Metric] ) -> None: metrics_values = self.evaluate_model(gen_dataset=gen_dataset, metrics=metrics) @@ -1076,7 +1076,7 @@ def report_results( def train_model( self, - gen_dataset: DatasetManager, + gen_dataset: GeneralDataset, save_model_flag: bool = True, mode: Union[str, None] = None, steps=None, @@ -1139,7 +1139,7 @@ def train_model( def run_model( self, - gen_dataset: DatasetManager, + gen_dataset: GeneralDataset, mask: Union[str, List[bool], torch.Tensor] = 'test', out: str = 'answers' ) -> torch.Tensor: @@ -1227,7 +1227,7 @@ def run_model( def evaluate_model( self, - gen_dataset: DatasetManager, + gen_dataset: GeneralDataset, metrics: Union[List[Metric], Metric] ) -> dict: """ @@ -1271,7 +1271,7 @@ def evaluate_model( def compute_stats_data( self, - 
gen_dataset: DatasetManager, + gen_dataset: GeneralDataset, predictions: bool = False, logits: bool = False ): @@ -1353,8 +1353,8 @@ def send_epoch_results( def load_train_test_split( self, - gen_dataset: DatasetManager - ) -> DatasetManager: + gen_dataset: GeneralDataset + ) -> GeneralDataset: path = self.model_path_info() path = path / 'train_test_split' gen_dataset.train_mask, gen_dataset.val_mask, gen_dataset.test_mask, _ = torch.load(path)[:] @@ -1527,7 +1527,7 @@ def optimizer_step( def before_epoch( self, - gen_dataset: DatasetManager + gen_dataset: GeneralDataset ): cur_step = self.modification.epochs train_ind = [n for n, x in enumerate(gen_dataset.train_mask) if x] @@ -1568,7 +1568,7 @@ def after_epoch(self, gen_dataset): def early_stopping( self, train_loss, - gen_dataset: DatasetManager, + gen_dataset: GeneralDataset, metrics: Union[List[Metric], Metric], steps: int ) -> bool: diff --git a/src/models_builder/models_zoo.py b/src/models_builder/models_zoo.py index a8a9b18..6b353ce 100644 --- a/src/models_builder/models_zoo.py +++ b/src/models_builder/models_zoo.py @@ -1,10 +1,10 @@ -from base.datasets_processing import DatasetManager +from base.datasets_processing import GeneralDataset from models_builder.gnn_constructor import FrameworkGNNConstructor from aux.configs import ModelConfig, ModelStructureConfig def model_configs_zoo( - dataset: DatasetManager, + dataset: GeneralDataset, model_name: str ): gat_gin_lin = FrameworkGNNConstructor( From 45220f9870e93e49e83f48aad27459fa6cbb83fa Mon Sep 17 00:00:00 2001 From: "lukyanov_kirya@bk.ru" Date: Thu, 21 Nov 2024 16:45:30 +0300 Subject: [PATCH 7/8] make better files in explainers --- src/explainers/explainer.py | 55 +++++++++++++++----- src/explainers/explainer_metrics.py | 78 ++++++++++++++++++++++------ src/explainers/explainers_manager.py | 49 +++++++++++++---- src/explainers/explanation.py | 54 +++++++++++++++---- src/models_builder/gnn_models.py | 2 +- 5 files changed, 188 insertions(+), 50 deletions(-) diff --git a/src/explainers/explainer.py b/src/explainers/explainer.py index f7fce49..d940fb4 100644 --- a/src/explainers/explainer.py +++ b/src/explainers/explainer.py @@ -1,3 +1,4 @@ +from pathlib import Path from time import sleep from abc import ABC, abstractmethod from typing import Union, Callable, Any, Type @@ -5,7 +6,7 @@ from flask_socketio import SocketIO from tqdm import tqdm -from base.datasets_processing import GeneralDataset, DatasetManager +from base.datasets_processing import GeneralDataset class ProgressBar(tqdm): @@ -84,7 +85,7 @@ class Explainer( @staticmethod def check_availability( - gen_dataset: DatasetManager, + gen_dataset: GeneralDataset, model_manager: Type ) -> bool: """ Availability check for the given dataset and model manager. """ @@ -92,11 +93,13 @@ def check_availability( def __init__( self, - gen_dataset: GeneralDataset, model): + gen_dataset: GeneralDataset, + model: Type, + **kwargs + ): """ :param gen_dataset: dataset :param model: GNN model - :param kwargs: init args """ self.gen_dataset = gen_dataset self.model = model @@ -110,7 +113,12 @@ def __init__( @finalize_decorator @abstractmethod - def run(self, mode, kwargs, finalize=True): + def run( + self, + mode: str, + kwargs: dict, + finalize: bool = True + ): """ Run explanation on a given element (node or graph). finalize_decorator handles finalize() call when run() is finished. 
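`finalize_decorator` is a standard wrap-and-post-process decorator: per its docstring, it records the run mode on the instance, delegates to `run`, and calls `_finalize` once the run completes. A functionally equivalent sketch with explicit parameter names instead of `*args`:

    from functools import wraps

    def finalize_decorator(func):
        @wraps(func)
        def wrapper(self, mode, kwargs, finalize=True):
            self._run_mode = mode            # remembered for _finalize()
            result = func(self, mode, kwargs, finalize)
            if finalize:
                self._finalize()             # convert explanation to json-able form
            return result
        return wrapper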
@@ -123,7 +131,9 @@ def run(self, mode, kwargs, finalize=True): pass @abstractmethod - def _finalize(self): + def _finalize( + self + ): """ Convert current explanation into inner framework json-able format. @@ -131,7 +141,10 @@ def _finalize(self): """ pass - def save(self, path): + def save( + self, + path: Union[str, Path] + ) -> None: """ Dump explanation in json format at a given path. @@ -141,24 +154,40 @@ def save(self, path): self.explanation.save(path) -class DummyExplainer(Explainer): +class DummyExplainer( + Explainer +): """ Dummy explainer for debugging """ name = '_Dummy' @staticmethod - def check_availability(gen_dataset, model_manager): + def check_availability( + gen_dataset: GeneralDataset, + model_manager: Type + ) -> bool: """ Fits for all """ return True - def __init__(self, gen_dataset, model, init_arg=None, **kwargs): + def __init__( + self, + gen_dataset: GeneralDataset, + model: Type, + init_arg=None, + **kwargs + ): Explainer.__init__(self, gen_dataset, model) self.init_arg = init_arg self._local_explanation = None self._global_explanation = None @finalize_decorator - def run(self, mode, kwargs, finalize=True): + def run( + self, + mode: str, + kwargs: dict, + finalize: bool = True + ) -> None: self.pbar.reset(total=10, mode=mode) if mode == "local": assert self._global_explanation is not None @@ -187,7 +216,9 @@ def run(self, mode, kwargs, finalize=True): # Remove unpickable attributes self.pbar = None - def _finalize(self): + def _finalize( + self + ) -> None: mode = self._run_mode if mode == "local": assert self._global_explanation is not None diff --git a/src/explainers/explainer_metrics.py b/src/explainers/explainer_metrics.py index 32e0eb5..40d127b 100644 --- a/src/explainers/explainer_metrics.py +++ b/src/explainers/explainer_metrics.py @@ -1,10 +1,18 @@ +from typing import Type + import numpy as np import torch from torch_geometric.utils import subgraph class NodesExplainerMetric: - def __init__(self, model, graph, explainer, kwargs_dict): + def __init__( + self, + model: Type, + graph, + explainer, + kwargs_dict: dict + ): self.model = model self.explainer = explainer self.graph = graph @@ -21,7 +29,10 @@ def __init__(self, model, graph, explainer, kwargs_dict): self.dictionary = { } - def evaluate(self, target_nodes_indices): + def evaluate( + self, + target_nodes_indices: list + ) -> dict: num_targets = len(target_nodes_indices) sparsity = 0 stability = 0 @@ -46,7 +57,10 @@ def evaluate(self, target_nodes_indices): self.dictionary["fidelity"] = fidelity return self.dictionary - def calculate_fidelity(self, target_nodes_indices): + def calculate_fidelity( + self, + target_nodes_indices: list + ) -> float: original_answer = self.model.get_answer(self.x, self.edge_index) same_answers_count = 0 for node_ind in target_nodes_indices: @@ -62,7 +76,10 @@ def calculate_fidelity(self, target_nodes_indices): fidelity = same_answers_count / len(target_nodes_indices) return fidelity - def calculate_sparsity(self, node_ind): + def calculate_sparsity( + self, + node_ind: int + ) -> float: explanation = self.get_explanation(node_ind) sparsity = 1 - (len(explanation["data"]["nodes"]) + len(explanation["data"]["edges"])) / ( len(self.x) + len(self.edge_index)) @@ -70,11 +87,11 @@ def calculate_sparsity(self, node_ind): def calculate_stability( self, - node_ind, - graph_perturbations_nums=10, - feature_change_percent=0.05, - node_removal_percent=0.05 - ): + node_ind: int, + graph_perturbations_nums: int = 10, + feature_change_percent: float = 0.05, + node_removal_percent: 
float = 0.05 + ) -> float: base_explanation = self.get_explanation(node_ind) stability = 0 for _ in range(graph_perturbations_nums): @@ -90,7 +107,11 @@ def calculate_stability( stability = stability / graph_perturbations_nums return stability - def calculate_consistency(self, node_ind, num_explanation_runs=10): + def calculate_consistency( + self, + node_ind: int, + num_explanation_runs: int = 10 + ) -> float: explanation = self.get_explanation(node_ind) consistency = 0 for _ in range(num_explanation_runs): @@ -103,13 +124,22 @@ def calculate_consistency(self, node_ind, num_explanation_runs=10): consistency = consistency / num_explanation_runs return consistency - def calculate_explanation(self, x, edge_index, node_idx, **kwargs): + def calculate_explanation( + self, + x: torch.Tensor, + edge_index: torch.Tensor, + node_idx: int, + **kwargs + ): print(f"Processing explanation calculation for node id {node_idx}.") self.explainer.evaluate_tensor_graph(x, edge_index, node_idx, **kwargs) print(f"Explanation calculation for node id {node_idx} completed.") return self.explainer.explanation.dictionary - def get_explanation(self, node_ind): + def get_explanation( + self, + node_ind: int + ): if node_ind in self.nodes_explanations: node_explanation = self.nodes_explanations[node_ind] else: @@ -118,7 +148,9 @@ def get_explanation(self, node_ind): return node_explanation @staticmethod - def parse_explanation(explanation): + def parse_explanation( + explanation: dict + ) -> [dict, dict]: important_nodes = { int(node): float(weight) for node, weight in explanation["data"]["nodes"].items() } @@ -129,7 +161,12 @@ def parse_explanation(explanation): return important_nodes, important_edges @staticmethod - def filter_graph_by_explanation(x, edge_index, explanation, target_node): + def filter_graph_by_explanation( + x: torch.Tensor, + edge_index: torch.Tensor, + explanation: dict, + target_node: int + ) -> [torch.Tensor, torch.Tensor, int]: important_nodes, important_edges = NodesExplainerMetric.parse_explanation(explanation) all_important_nodes = set(important_nodes.keys()) all_important_nodes.add(target_node) @@ -147,7 +184,10 @@ def filter_graph_by_explanation(x, edge_index, explanation, target_node): return new_x, new_edge_index, new_target_node @staticmethod - def calculate_explanation_vectors(base_explanation, perturbed_explanation): + def calculate_explanation_vectors( + base_explanation, + perturbed_explanation + ): base_important_nodes, base_important_edges = NodesExplainerMetric.parse_explanation( base_explanation ) @@ -171,7 +211,13 @@ def calculate_explanation_vectors(base_explanation, perturbed_explanation): return base_explanation_vector, perturbed_explanation_vector @staticmethod - def perturb_graph(x, edge_index, node_ind, feature_change_percent, node_removal_percent): + def perturb_graph( + x: torch.Tensor, + edge_index: torch.Tensor, + node_ind: int, + feature_change_percent: float, + node_removal_percent: float + ) -> [torch.Tensor, torch.Tensor]: new_x = x.clone() num_nodes = x.shape[0] num_features = x.shape[1] diff --git a/src/explainers/explainers_manager.py b/src/explainers/explainers_manager.py index cf93aae..8909d8e 100644 --- a/src/explainers/explainers_manager.py +++ b/src/explainers/explainers_manager.py @@ -1,8 +1,11 @@ import json +from socket import SocketIO +from typing import Union, Type -from aux.configs import ExplainerInitConfig, ExplainerModificationConfig, CONFIG_OBJ, ConfigPattern +from aux.configs import ExplainerInitConfig, ExplainerModificationConfig, CONFIG_OBJ, 
ConfigPattern, ExplainerRunConfig from aux.declaration import Declare from aux.utils import EXPLAINERS_INIT_PARAMETERS_PATH +from base.datasets_processing import GeneralDataset from explainers.explainer import Explainer, ProgressBar from explainers.explainer_metrics import NodesExplainerMetric @@ -36,12 +39,14 @@ class FrameworkExplainersManager: def __init__( self, - dataset, gnn_manager, - init_config=None, + dataset: GeneralDataset, + gnn_manager: Type, + init_config: Union[ConfigPattern, ExplainerInitConfig] = None, explainer_name: str = None, - modification_config: ExplainerModificationConfig = None, + modification_config: Union[ConfigPattern, ExplainerModificationConfig] = None, device: str = None ): + self.files_paths = None if device is None: device = "cpu" self.device = device @@ -107,20 +112,27 @@ def __init__( self.gen_dataset, model=self.gnn, device=self.device, # device=device("cpu"), - **init_kwargs) + **init_kwargs + ) self.explanation = None self.explanation_data = None self.running = False - def save_explanation(self, run_config): + def save_explanation( + self, + run_config: Union[ConfigPattern, ExplainerRunConfig] + ) -> None: """ Save explanation to file. """ self.explanation_result_path(run_config) self.explainer.save(self.explainer_result_file_path) print("Saved explanation") - def load_explanation(self, run_config): + def load_explanation( + self, + run_config: Union[ConfigPattern, ExplainerRunConfig] + ) -> dict: if self.modification_config.explainer_ver_ind is None: raise RuntimeError("explainer_ver_ind should not be None") self.explanation_result_path(run_config) @@ -133,7 +145,10 @@ def load_explanation(self, run_config): f"{self.explainer_result_file_path}") return explanation - def explanation_result_path(self, run_config): + def explanation_result_path( + self, + run_config: Union[ConfigPattern, ExplainerRunConfig] + ) -> None: # TODO pass configs self.explainer_result_file_path, self.files_paths = Declare.explanation_file_path( models_path=self.gnn_model_path, @@ -143,7 +158,11 @@ def explanation_result_path(self, run_config): explainer_run_kwargs=run_config.to_saveable_dict(), ) - def conduct_experiment(self, run_config, socket=None): + def conduct_experiment( + self, + run_config: Union[ConfigPattern, ExplainerRunConfig], + socket: SocketIO = None + ) -> dict: """ Runs the full cycle of the interpretation experiment """ @@ -176,7 +195,12 @@ def conduct_experiment(self, run_config, socket=None): return result - def evaluate_metrics(self, target_nodes_indices, run_config=None, socket=None): + def evaluate_metrics( + self, + target_nodes_indices: list, + run_config: Union[ConfigPattern, ExplainerRunConfig, None] = None, + socket: SocketIO = None + ) -> dict: """ Evaluates explanation metrics between given node indices """ @@ -221,7 +245,10 @@ def evaluate_metrics(self, target_nodes_indices, run_config=None, socket=None): return result @staticmethod - def available_explainers(gen_dataset, model_manager): + def available_explainers( + gen_dataset: GeneralDataset, + model_manager: Type + ) -> list: """ Get a list of explainers applicable for current model and dataset. """ return [ diff --git a/src/explainers/explanation.py b/src/explainers/explanation.py index adf21e6..d6eda27 100644 --- a/src/explainers/explanation.py +++ b/src/explainers/explanation.py @@ -1,11 +1,19 @@ import json +from pathlib import Path +from typing import Union, Any class Explanation: """ General class to represent GNN explanation. 
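An end-to-end usage sketch for `FrameworkExplainersManager`, assuming a built dataset, a trained `gnn_manager`, and a prepared `run_config`; the explainer name follows the docstring example above:

    explainers_manager = FrameworkExplainersManager(
        dataset=dataset,
        gnn_manager=gnn_manager,
        explainer_name="Zorro",  # example name from the docstrings above
    )
    result = explainers_manager.conduct_experiment(run_config)
    metrics = explainers_manager.evaluate_metrics(
        target_nodes_indices=[0, 1, 2], run_config=run_config)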
""" - def __init__(self, local, type, data=None, meta=None): + def __init__( + self, + local: bool, + type: str, + data: str = None, + meta=None + ): """ :param local: True if local, False if global :param type: "subgraph", "prototype", etc @@ -20,21 +28,32 @@ def __init__(self, local, type, data=None, meta=None): if meta is not None: self.dictionary['info']['meta'] = meta - def save(self, path): + def save( + self, + path: Union[str, Path] + ) -> None: with open(path, 'w', encoding='utf-8') as f: json.dump(self.dictionary, f, ensure_ascii=False, indent=4) -class AttributionExplanation(Explanation): +class AttributionExplanation( + Explanation +): """ Attribution explanation as important subgraph. Importance scores (binary or continual) can be assigned to nodes, edges, and features. """ - def __init__(self, local=True, directed=False, nodes="binary", edges=False, features=False): + def __init__( + self, + local: bool = True, + directed: bool = False, + nodes: str = "binary", + edges: bool = False, + features: bool = False + ): """ :param local: True if local, False if global - :param type: "subgraph", "prototype", etc :param nodes: "binary", "continuous", or None/False :param edges: "binary", "continuous", or None/False :param features: "binary", "continuous", or None/False @@ -44,18 +63,33 @@ def __init__(self, local=True, directed=False, nodes="binary", edges=False, feat super(AttributionExplanation, self).__init__(local=local, type="subgraph", meta=meta) self.dictionary['info']['directed'] = directed - def add_edges(self, edge_data): + def add_edges( + self, + edge_data: dict + ) -> None: self.dictionary['data']['edges'] = edge_data - def add_features(self, feature_data): + def add_features( + self, + feature_data: dict + ) -> None: self.dictionary['data']['features'] = feature_data - def add_nodes(self, node_data): + def add_nodes( + self, + node_data: dict + ) -> None: self.dictionary['data']['nodes'] = node_data -class ConceptExplanationGlobal(Explanation): - def __init__(self, raw_neurons, n_neurons): +class ConceptExplanationGlobal( + Explanation +): + def __init__( + self, + raw_neurons: list, + n_neurons: Any + ): Explanation.__init__(self, False, 'string') self.dictionary['data']['neurons'] = {} for n in range(n_neurons): diff --git a/src/models_builder/gnn_models.py b/src/models_builder/gnn_models.py index b95a25a..f1d020d 100644 --- a/src/models_builder/gnn_models.py +++ b/src/models_builder/gnn_models.py @@ -944,7 +944,7 @@ def train_1_step( print("loss %.8f" % loss) self.modification.epochs += 1 self.gnn.eval() - return loss.detach().numpy().tolist() + return loss.cpu().detach().numpy().tolist() def train_on_batch_full( self, From c73763a9e939bff6caf043f57723fd2e06c94db4 Mon Sep 17 00:00:00 2001 From: "lukyanov_kirya@bk.ru" Date: Thu, 21 Nov 2024 17:14:43 +0300 Subject: [PATCH 8/8] make better files in base --- src/base/custom_datasets.py | 75 ++++++++--- src/base/datasets_processing.py | 220 ++++++++++++++++++++++++-------- src/base/ptg_datasets.py | 67 +++++++--- src/base/vk_datasets.py | 51 ++++++-- 4 files changed, 319 insertions(+), 94 deletions(-) diff --git a/src/base/custom_datasets.py b/src/base/custom_datasets.py index 912835c..2be1bf5 100644 --- a/src/base/custom_datasets.py +++ b/src/base/custom_datasets.py @@ -1,20 +1,27 @@ import json import os from pathlib import Path +from typing import Union + import numpy as np import torch from torch_geometric.data import Data, InMemoryDataset from aux.declaration import Declare from base.datasets_processing import 
GeneralDataset, DatasetInfo -from aux.configs import DatasetConfig, DatasetVarConfig +from aux.configs import DatasetConfig, DatasetVarConfig, ConfigPattern from base.ptg_datasets import LocalDataset -class CustomDataset(GeneralDataset): +class CustomDataset( + GeneralDataset +): """ User-defined dataset in 'ij' format. """ - def __init__(self, dataset_config: DatasetConfig): + def __init__( + self, + dataset_config: Union[ConfigPattern, DatasetConfig] + ): """ Args: dataset_config: DatasetConfig dict from frontend @@ -27,31 +34,44 @@ def __init__(self, dataset_config: DatasetConfig): self.edge_index = None @property - def node_attributes_dir(self): + def node_attributes_dir( + self + ): """ Path to dir with node attributes. """ return self.root_dir / 'raw' / (self.name + '.node_attributes') @property - def edge_attributes_dir(self): + def edge_attributes_dir( + self + ): """ Path to dir with edge attributes. """ return self.root_dir / 'raw' / (self.name + '.edge_attributes') @property - def labels_dir(self): + def labels_dir( + self + ): """ Path to dir with labels. """ return self.root_dir / 'raw' / (self.name + '.labels') @property - def edges_path(self): + def edges_path( + self + ): """ Path to file with edge list. """ return self.root_dir / 'raw' / (self.name + '.ij') @property - def edge_index_path(self): + def edge_index_path( + self + ): """ Path to dir with labels. """ return self.root_dir / 'raw' / (self.name + '.edge_index') - def build(self, dataset_var_config: DatasetVarConfig): + def build( + self, + dataset_var_config: Union[ConfigPattern, DatasetVarConfig] + ) -> None: """ Build ptg dataset based on dataset_var_config and create DatasetVarData. """ if dataset_var_config == self.dataset_var_config: @@ -62,7 +82,10 @@ def build(self, dataset_var_config: DatasetVarConfig): self.dataset_var_config = dataset_var_config self.dataset = LocalDataset(self.results_dir, process_func=self._create_ptg) - def _compute_stat(self, stat): + def _compute_stat( + self, + stat: str + ) -> dict: """ Compute some additional stats """ if stat == "attr_corr": @@ -123,7 +146,9 @@ def _compute_stat(self, stat): else: return super()._compute_stat(stat) - def _compute_dataset_data(self): + def _compute_dataset_data( + self + ) -> None: """ Get DatasetData for debug graph Structure according to https://docs.google.com/spreadsheets/d/1fNI3sneeGoOFyIZP_spEjjD-7JX2jNl_P8CQrA4HZiI/edit#gid=1096434224 """ @@ -272,7 +297,9 @@ def _compute_dataset_data(self): # if self.info.name == "": # self.dataset_data['info']['name'] = '/'.join(self.dataset_config.full_name()) - def _create_ptg(self): + def _create_ptg( + self + ) -> None: """ Create PTG Dataset and save tensors """ if self.edge_index is None: @@ -295,7 +322,10 @@ def _create_ptg(self): self.results_dir.mkdir(exist_ok=True, parents=True) torch.save(InMemoryDataset.collate(data_list), self.results_dir / 'data.pt') - def _iter_nodes(self, graph: int = None): + def _iter_nodes( + self, + graph: int = None + ) -> None: """ Iterate over nodes according to mapping. 
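The path properties above pin down the raw layout a `CustomDataset` named, say, `my_graph` is read from; the per-file contents are sketched here only as an assumption:

    # <root_dir>/
    #   raw/
    #     .info                      # DatasetInfo metadata
    #     my_graph.ij                # edge list
    #     my_graph.edge_index        # cached edge index, if present
    #     my_graph.node_attributes/  # one file per node attribute
    #     my_graph.edge_attributes/
    #     my_graph.labels/
    #
    # build() then materializes the PTG tensors under 'prepared/':
    dataset = CustomDataset(dataset_config)   # dataset_config assumed given
    dataset.build(dataset_var_config)
    ptg_dataset = dataset.dataset             # LocalDataset from _create_ptg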
Yields pairs of (node_index, original_id) """ # offset = sum(self.info.nodes[:graph]) if self.is_multi() else 0 @@ -308,7 +338,10 @@ def _iter_nodes(self, graph: int = None): for n in range(self.info.nodes[graph or 0]): yield offset+n, str(n) - def _labeling_tensor(self, g_ix=None) -> list: + def _labeling_tensor( + self, + g_ix=None + ) -> list: """ Returns list of labels (not tensors) """ y = [] # Read labels @@ -330,21 +363,29 @@ def _labeling_tensor(self, g_ix=None) -> list: return y - def _feature_tensor(self, g_ix=None) -> list: + def _feature_tensor( + self, + g_ix=None + ) -> list: """ Returns list of features (not tensors) for graph g_ix. """ features = self.dataset_var_config.features # dict about attributes construction nodes_onehot = "str_g" in features and features["str_g"] == "one_hot" # Read attributes - def one_hot(x, values): + def one_hot( + x: int, + values: list + ) -> list: res = [0] * len(values) for ix, v in enumerate(values): if x == v: res[ix] = 1 return res - def as_is(x): + def as_is( + x + ) -> list: return x if isinstance(x, list) else [x] # TODO other encoding types from Kirill diff --git a/src/base/datasets_processing.py b/src/base/datasets_processing.py index d696d04..1fce165 100644 --- a/src/base/datasets_processing.py +++ b/src/base/datasets_processing.py @@ -1,11 +1,14 @@ import json import os from pathlib import Path +from typing import Union, Type + import torch +import torch_geometric from torch import default_generator, randperm from torch_geometric.data import Dataset, InMemoryDataset -from aux.configs import DatasetConfig, DatasetVarConfig +from aux.configs import DatasetConfig, DatasetVarConfig, ConfigPattern from aux.custom_decorators import timing_decorator from aux.declaration import Declare from aux.utils import TORCH_GEOM_GRAPHS_PATH @@ -17,7 +20,9 @@ class DatasetInfo: Some fields are obligate, others are not. """ - def __init__(self): + def __init__( + self + ): self.name: str = "" self.count: int = None self.directed: bool = None @@ -35,7 +40,9 @@ def __init__(self): self.edge_info: dict = {} self.graph_info: dict = {} - def check_validity(self): + def check_validity( + self + ) -> None: """ Check existing fields have allowed values. """ assert self.count > 0 assert len(self.node_attributes) > 0 @@ -56,7 +63,9 @@ def check_validity(self): assert isinstance(k, str) assert isinstance(v, int) and v > 1 - def check_consistency(self): + def check_consistency( + self + ) -> None: """ Check existing fields are consistent. """ assert self.count == len(self.nodes) assert len(self.node_attributes["names"]) == len(self.node_attributes["types"]) == len( @@ -64,13 +73,18 @@ def check_consistency(self): assert len(self.edge_attributes["names"]) == len(self.edge_attributes["types"]) == len( self.edge_attributes["values"]) - def check_sufficiency(self): + def check_sufficiency( + self + ) -> None: """ Check all obligates fields are defined. """ for attr in self.__dict__.keys(): if attr is None: raise ValueError(f"Attribute '{attr}' of metainfo should be defined.") - def check_consistency_with_dataset(self, dataset: Dataset): + def check_consistency_with_dataset( + self, + dataset: Dataset + ) -> None: """ Check if metainfo fields are consistent with dataset. 
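A worked example of the `one_hot` helper above (restated verbatim); values outside the known set encode to all zeros:

    def one_hot(x, values):
        res = [0] * len(values)
        for ix, v in enumerate(values):
            if x == v:
                res[ix] = 1
        return res

    assert one_hot("b", ["a", "b", "c"]) == [0, 1, 0]
    assert one_hot("z", ["a", "b", "c"]) == [0, 0, 0]  # unknown value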
""" assert self.count == len(dataset) from base.ptg_datasets import is_graph_directed @@ -80,17 +94,24 @@ def check_consistency_with_dataset(self, dataset: Dataset): assert self.node_attributes["types"][0] == "other" # TODO check features values range - def check(self): + def check( + self + ) -> None: """ Check metainfo is sufficient, consistent, and valid. """ self.check_sufficiency() self.check_consistency() self.check_validity() - def to_dict(self): + def to_dict( + self + ) -> dict: """ Return info as a dictionary. """ return dict(self.__dict__) - def save(self, path: Path): + def save( + self, + path: Union[str, Path] + ) -> None: """ Save into file non-null info. """ not_nones = {k: v for k, v in self.__dict__.items() if v is not None} path.parent.mkdir(exist_ok=True, parents=True) @@ -98,7 +119,9 @@ def save(self, path: Path): json.dump(not_nones, f, indent=1) @staticmethod - def induce(dataset: Dataset): + def induce( + dataset: Dataset + ): """ Induce metainfo from a given PTG dataset. """ res = DatasetInfo() res.count = len(dataset) @@ -116,7 +139,9 @@ def induce(dataset: Dataset): return res @staticmethod - def read(path: Path): + def read( + path: Union[str, Path] + ): """ Read info from a file. """ with path.open('r') as f: a_dict = json.load(f) @@ -128,7 +153,9 @@ def read(path: Path): return res @staticmethod - def node_attributes_to_node_attr_slices(node_attributes): + def node_attributes_to_node_attr_slices( + node_attributes: dict + ) -> dict: node_attr_slices = {} start_attr_index = 0 for i in range(len(node_attributes['names'])): @@ -145,7 +172,12 @@ def node_attributes_to_node_attr_slices(node_attributes): class VisiblePart: - def __init__(self, gen_dataset, center: [int, list] = None, depth: [int] = None): + def __init__( + self, + gen_dataset, + center: [int, list] = None, + depth: [int] = None + ): """ Compute a part of dataset specified by a center node/graph and a depth :param gen_dataset: @@ -226,10 +258,14 @@ def __init__(self, gen_dataset, center: [int, list] = None, depth: [int] = None) self.nodes = gen_dataset.info.nodes[0] self._ixes = list(range(self.nodes)) - def ixes(self): + def ixes( + self + ) -> list: return self._ixes - def as_dict(self): + def as_dict( + self + ) -> dict: res = {} if self.nodes: res['nodes'] = self.nodes @@ -239,7 +275,10 @@ def as_dict(self): res['graphs'] = self.graphs return res - def filter(self, array): + def filter( + self, + array + ) -> dict: """ Suppose ixes = [2,4]: [a, b, c, d] -> {2: b, 4: d} """ return {ix: array[ix] for ix in self._ixes} @@ -249,7 +288,10 @@ class GeneralDataset: """ Generalisation of PTG and user-defined datasets: custom, VK, etc. """ - def __init__(self, dataset_config: DatasetConfig): + def __init__( + self, + dataset_config: Union[DatasetConfig, ConfigPattern] + ): """ Args: dataset_config: DatasetConfig dict from frontend @@ -281,51 +323,71 @@ def __init__(self, dataset_config: DatasetConfig): self._labels = None @property - def root_dir(self): + def root_dir( + self + ): """ Dataset root directory with folders 'raw' and 'prepared'. """ # FIXME Misha, dataset_prepared_dir return path and files_paths not only path return Declare.dataset_root_dir(self.dataset_config)[0] @property - def results_dir(self): + def results_dir( + self + ): """ Path to 'prepared/../' folder where tensor data is stored. 
""" # FIXME Misha, dataset_prepared_dir return path and files_paths not only path return Path(Declare.dataset_prepared_dir(self.dataset_config, self.dataset_var_config)[0]) @property - def raw_dir(self): + def raw_dir( + self + ): """ Path to 'raw/' folder where raw data is stored. """ return self.root_dir / 'raw' @property - def api_path(self): + def api_path( + self + ): """ Path to '.api' file. Could be not present. """ return self.root_dir / '.api' @property - def info_path(self): + def info_path( + self + ): """ Path to '.info' file. """ return self.root_dir / 'raw' / '.info' @property - def stats_dir(self): + def stats_dir( + self + ): """ Path to '.stats' directory. """ return self.root_dir / '.stats' @property - def data(self): + def data( + self + ): return self.dataset._data @property - def num_classes(self): + def num_classes( + self + ): return self.dataset.num_classes @property - def num_node_features(self): + def num_node_features( + self + ): return self.dataset.num_node_features @property - def labels(self): + def labels( + self + ): if self._labels is None: # NOTE: this is a copy from torch_geometric.data.dataset v=2.3.1 from torch_geometric.data.dataset import _get_flattened_data_list @@ -333,22 +395,34 @@ def labels(self): self._labels = torch.cat([data.y for data in data_list if 'y' in data], dim=0) return self._labels - def __len__(self): + def __len__( + self + ) -> int: return self.info.count - def domain(self): + def domain( + self + ) -> str: return self.dataset_config.domain - def is_multi(self): + def is_multi( + self + ) -> bool: """ Return whether this dataset is multiple-graphs or single-graph. """ return self.info.count > 1 - def build(self, dataset_var_config: DatasetVarConfig): + def build( + self, + dataset_var_config: Union[ConfigPattern, DatasetVarConfig] + ): """ Create node feature tensors from attributes based on dataset_var_config. """ raise NotImplementedError() - def get_dataset_data(self, part=None): + def get_dataset_data( + self, + part: Union[dict, None] = None + ) -> dict: """ Get DatasetData for specified graphs or nodes """ edges_list = [] @@ -379,7 +453,9 @@ def get_dataset_data(self, part=None): return res - def _compute_dataset_data(self): + def _compute_dataset_data( + self + ) -> None: num = len(self.dataset) data_list = [self.dataset.get(ix) for ix in range(num)] is_directed = self.info.directed @@ -427,13 +503,19 @@ def _compute_dataset_data(self): # if self.info.name == "": # self.dataset_data['info']['name'] = '/'.join(self.dataset_config.full_name()) - def set_visible_part(self, part: dict): + def set_visible_part( + self, + part: dict + ) -> None: if self.dataset_data is None: self._compute_dataset_data() self.visible_part = VisiblePart(self, **part) - def get_dataset_var_data(self, part=None): + def get_dataset_var_data( + self, + part: Union[dict, None] = None + ) -> dict: """ Get DatasetVarData for specified graphs or nodes """ if self.dataset_var_data is None: @@ -455,7 +537,9 @@ def get_dataset_var_data(self, part=None): return dataset_var_data - def _compute_dataset_var_data(self): + def _compute_dataset_var_data( + self + ) -> None: """ Prepare dataset_var_data for frontend on demand. """ # FIXME version fail in torch-geom 2.3.1 @@ -485,7 +569,10 @@ def _compute_dataset_var_data(self): "labels": labels if self.is_multi() else labels[0], } - def get_stat(self, stat): + def get_stat( + self, + stat + ): """ Get statistics. 
""" if stat in self.stats: @@ -511,7 +598,10 @@ def get_stat(self, stat): json.dump(value, f, ensure_ascii=False) return value - def _compute_stat(self, stat): + def _compute_stat( + self, + stat + ): """ Compute statistics. """ if self.is_multi(): # try: @@ -608,7 +698,9 @@ def _compute_stat(self, stat): value = str(e) return value - def is_one_hot_able(self): + def is_one_hot_able( + self + ) -> bool: """ Return whether features are 1-hot encodings. If yes, nodes can be colored. """ assert self.dataset_var_config @@ -633,12 +725,16 @@ def is_one_hot_able(self): elif features['attr'][attr] == 'other': # Check honestly each feature vector feats = self.dataset_var_data['features'] - res = all(all(all(x == 1 or x == 0 for x in f) for f in feat) for feat in feats) and\ + res = all(all(all(x == 1 or x == 0 for x in f) for f in feat) for feat in feats) and \ all(all(sum(f) == 1 for f in feat) for feat in feats) return res - def train_test_split(self, percent_train_class: float = 0.8, percent_test_class: float = 0.2): + def train_test_split( + self, + percent_train_class: float = 0.8, + percent_test_class: float = 0.2 + ) -> None: """ Compute train-validation-test split of graphs/nodes. """ self.percent_train_class = percent_train_class self.percent_test_class = percent_test_class @@ -680,7 +776,10 @@ def train_test_split(self, percent_train_class: float = 0.8, percent_test_class: self.dataset.data.test_mask = test_mask self.dataset.data.val_mask = val_mask - def save_train_test_mask(self, path): + def save_train_test_mask( + self, + path: Union[str, Path] + ) -> None: """ Save current train/test mask to a given path (together with the model). """ if path is not None: path /= 'train_test_split' @@ -700,7 +799,9 @@ class DatasetManager: @staticmethod def register_torch_geometric_local( - dataset: InMemoryDataset, name: str = None) -> GeneralDataset: + dataset: InMemoryDataset, + name: str = None + ) -> GeneralDataset: """ Save a given PTG dataset locally. Dataset is then always available for use by its config. @@ -717,8 +818,10 @@ def register_torch_geometric_local( # QUE Misha, Kirill - can we use get_by_config always instead of it? @staticmethod @timing_decorator - def get_by_config(dataset_config: DatasetConfig, - dataset_var_config: DatasetVarConfig = None) -> GeneralDataset: + def get_by_config( + dataset_config: DatasetConfig, + dataset_var_config: DatasetVarConfig = None + ) -> GeneralDataset: """ Get GeneralDataset by dataset config. Used from the frontend. """ dataset_group = dataset_config.group @@ -743,7 +846,10 @@ def get_by_config(dataset_config: DatasetConfig, @staticmethod @timing_decorator - def get_by_full_name(full_name=None, **kwargs): + def get_by_full_name( + full_name=None, + **kwargs + ) -> [GeneralDataset, torch_geometric.data.Data, Path]: """ Get PTG dataset by its full name tuple. Starts the creation of an object from raw data or takes already saved datasets in prepared @@ -770,8 +876,10 @@ def get_by_full_name(full_name=None, **kwargs): @staticmethod def register_torch_geometric_api( - dataset: Dataset, name: str = None, - obj_name: str = 'DATASET_TO_EXPORT') -> GeneralDataset: + dataset: Dataset, + name: str = None, + obj_name: str = 'DATASET_TO_EXPORT' + ) -> GeneralDataset: """ Register a user defined code implementing a PTG dataset. This function should be called at each framework run to make the dataset available for use. 
@@ -797,7 +905,9 @@ def register_torch_geometric_api(
         return gen_dataset
 
     @staticmethod
-    def register_custom_ij(path: Path) -> GeneralDataset:
+    def register_custom_ij(
+        path: Path
+    ) -> GeneralDataset:
         """
         :return: GeneralDataset
         """
@@ -805,8 +915,12 @@ def register_custom_ij(path: Path) -> GeneralDataset:
 
     @staticmethod
     def _register_torch_geometric(
-            dataset: Dataset, name=None, group=None,
-            exists_ok=False, copy_data=False) -> GeneralDataset:
+        dataset: Dataset,
+        name: Union[str, None] = None,
+        group: Union[str, None] = None,
+        exists_ok: bool = False,
+        copy_data: bool = False
+    ) -> GeneralDataset:
         """
         Create GeneralDataset from an externally specified torch geometric dataset.
 
@@ -874,7 +988,9 @@ def _register_torch_geometric(
         return gen_dataset
 
 
-def is_in_torch_geometric_datasets(full_name=None):
+def is_in_torch_geometric_datasets(
+    full_name: tuple = None
+) -> bool:
     from aux.prefix_storage import PrefixStorage
     with open(TORCH_GEOM_GRAPHS_PATH, 'r') as f:
         return PrefixStorage.from_json(f.read()).check(full_name)
diff --git a/src/base/ptg_datasets.py b/src/base/ptg_datasets.py
index 5cd555f..f42f058 100644
--- a/src/base/ptg_datasets.py
+++ b/src/base/ptg_datasets.py
@@ -3,16 +3,20 @@
 import os
 import shutil
 from pathlib import Path
+from typing import Union, Callable
+
 import torch
 from torch_geometric.data import InMemoryDataset, Data, Dataset
 from torch_geometric.data.data import BaseData
 
 from aux.utils import import_by_name, root_dir, root_dir_len
 from base.datasets_processing import GeneralDataset, is_in_torch_geometric_datasets, DatasetInfo
-from aux.configs import DatasetConfig, DatasetVarConfig
+from aux.configs import DatasetConfig, DatasetVarConfig, ConfigPattern
 
 
-class PTGDataset(GeneralDataset):
+class PTGDataset(
+    GeneralDataset
+):
     """ Contains a PTG dataset. """
     attr_name = 'unknown'
@@ -22,7 +26,11 @@ class PTGDataset(GeneralDataset):
         dataset_ver_ind=0
     )
 
-    def __init__(self, dataset_config: DatasetConfig, **kwargs):
+    def __init__(
+        self,
+        dataset_config: Union[ConfigPattern, DatasetConfig],
+        **kwargs
+    ):
         """
         :param dataset_config: dataset config dictionary
         :param kwargs: additional args to init torch dataset class
@@ -127,14 +135,20 @@ def __init__(self, dataset_config: DatasetConfig, **kwargs):
         #     raise FileNotFoundError(
         #         f"No data found for dataset '{self.dataset_config.full_name()}'")
 
-    def move_processed(self, processed: (str, Path)):
+    def move_processed(
+        self,
+        processed: Union[str, Path]
+    ) -> None:
         if not self.results_dir.exists():
             self.results_dir.mkdir(parents=True)
             os.rename(processed, self.results_dir)
         else:
             shutil.rmtree(processed)
 
-    def move_raw(self, raw: (str, Path)):
+    def move_raw(
+        self,
+        raw: Union[str, Path]
+    ) -> None:
         if Path(raw) == self.raw_dir:
             return
         if not self.raw_dir.exists():
@@ -143,25 +157,41 @@ def move_raw(self, raw: (str, Path)):
         else:
             raise RuntimeError(f"raw_dir '{self.raw_dir}' already exists")
 
-    def _compute_dataset_data(self, center=None, depth=None):
+    def _compute_dataset_data(
+        self,
+        center=None,
+        depth: Union[int, None] = None
+    ) -> None:
         # assert len(name_type) == 1  # FIXME
         dataset_data = super()._compute_dataset_data()
         # FIXME add features
         return dataset_data
 
-    def build(self, dataset_var_config: dict=None):
+    def build(
+        self,
+        dataset_var_config: Union[ConfigPattern, DatasetVarConfig] = None
+    ) -> None:
         """ PTG dataset is already built """
         # Use cached ptg dataset. Only default dataset_var_config is allowed.
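+        # I.e. callers may only pass the variant config this dataset already has
+        # (`self.dataset_var_config`); anything else fails the assert below.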
assert self.dataset_var_config == dataset_var_config
 
 
-class LocalDataset(InMemoryDataset):
+class LocalDataset(
+    InMemoryDataset
+):
     """ Locally saved PTG Dataset. """
-    def __init__(self, results_dir, process_func=None, **kwargs):
+    def __init__(
+        self,
+        results_dir: Union[str, Path],
+        process_func: Union[Callable, None] = None,
+        **kwargs
+    ):
         """
         :param results_dir:
@@ -172,7 +202,7 @@ def __init__(self, results_dir, process_func=None, **kwargs):
         if process_func:
             self.process = process_func
         # Init and process if needed
-        super().__init__(None, **kwargs)
+        super().__init__(None, **kwargs)
 
         # Load
         self.data, *rest_data = torch.load(self.processed_paths[0])
@@ -180,22 +210,31 @@ def __init__(self, results_dir, process_func=None, **kwargs):
         try:
             self.slices = rest_data[0]
             # TODO can use rest_data[1] ?
-        except IndexError: pass
+        except IndexError:
+            pass
 
     @property
-    def processed_file_names(self):
+    def processed_file_names(
+        self
+    ) -> str:
         return 'data.pt'
 
-    def process(self):
+    def process(
+        self
+    ):
         raise RuntimeError("Dataset is supposed to be processed and saved earlier.")
         # torch.save(self.collate(self.data_list), self.processed_paths[0])
 
     @property
-    def processed_dir(self) -> str:
+    def processed_dir(
+        self
+    ) -> str:
         return self.results_dir
 
 
-def is_graph_directed(data: (Data, BaseData)) -> bool:
+def is_graph_directed(
+    data: Union[Data, BaseData]
+) -> bool:
     """ Detect whether graph is directed or not (for each edge i->j, exists j->i).
     """
     # Note: this does not work correctly. E.g. for TUDataset/MUTAG it incorrectly says directed.
diff --git a/src/base/vk_datasets.py b/src/base/vk_datasets.py
index 4d48349..936ae11 100644
--- a/src/base/vk_datasets.py
+++ b/src/base/vk_datasets.py
@@ -6,6 +6,8 @@
 from numbers import Number
 from operator import itemgetter
 from pathlib import Path
+from typing import Union
+from aux.configs import ConfigPattern
 
 import numpy as np
 
@@ -23,7 +25,8 @@ class AttrInfo:
     _attribute_vals_cache = {}  # (full_name, attribute) -> attribute_vals
 
     @staticmethod
-    def vk_attr():
+    def vk_attr(
+    ) -> dict:
         vk_dict = {
             ('age',): list(range(0, len(AGE_GROUPS) + 1)),
             ('sex',): [1, 2],
@@ -39,7 +42,10 @@ def vk_attr():
         return vk_dict
 
     @staticmethod
-    def attribute_vals(full_name, attribute: [str, tuple, list]) -> list:
+    def attribute_vals(
+        full_name: tuple,
+        attribute: Union[str, tuple, list]
+    ) -> list:
         """ Get a set of possible attribute values or None for textual or continuous attributes.
         """
         # Convert to tuple
@@ -90,10 +96,14 @@ def attribute_vals(full_name, attribute: [str, tuple, list]) -> list:
         return res
 
     @staticmethod
-    def one_hot(full_name, attribute: [str, tuple, list], value, add_none=False):
-        """ 1-hot encoding feature. If no such value, return all zeros or with 1 it in last element.
+    def one_hot(
+        full_name: tuple,
+        attribute: Union[str, tuple, list],
+        value: Union[int, list],
+        add_none: bool = False
+    ) -> Union[np.ndarray, list]:
+        """ 1-hot encoding feature. If no such value, return all zeros, or a vector with 1 in the last element.
         :param full_name:
-        :param graph: MyGraph
         :param attribute: attribute name, e.g. 'sex', ('personal', 'smoking').
         :param value: value of this attribute. If a list, a multiple-hot vector will be constructed.
         :param add_none: if True, last element of returned vector encodes undefined or
@@ -142,11 +152,19 @@ def one_hot(full_name, attribute: [str, tuple, list], value, add_none=False):
         return res  # all zeros here
 
 
-class VKDataset(CustomDataset):
+class VKDataset(
+    CustomDataset
+):
     """ Custom dataset of VK samples with specific attributes processing and features creation.
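+    Feature tensors are assembled from VK profile attributes (e.g. 'age', 'sex')
+    according to dataset_var_config.features; see AttrInfo above for the encodings.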
""" - def __init__(self, dataset_config: DatasetConfig, add_none=False): + def __init__( + self, + dataset_config: Union[ConfigPatter, DatasetConfig], + add_none: bool = False + ): """ Args: dataset_config: DatasetConfig dict from frontend @@ -156,7 +175,9 @@ def __init__(self, dataset_config: DatasetConfig, add_none=False): super().__init__(dataset_config) self.add_none = add_none - def _compute_dataset_data(self): + def _compute_dataset_data( + self + ): """ Get DatasetData for VK graph """ super()._compute_dataset_data() @@ -174,7 +195,10 @@ def _compute_dataset_data(self): # labelings[filename] = max([-1 if x is None else x for x in d.values()]) + 1 # self.dataset_data["info"]["labelings"] = labelings - def _feature_tensor(self, g_ix=None) -> list: + def _feature_tensor( + self, + g_ix=None + ) -> list: # FIXME Misha self.node_map[graph] ... x = [[] for _ in range(len(self.node_map))] features = self.dataset_var_config.features @@ -204,7 +228,10 @@ def _feature_tensor(self, g_ix=None) -> list: return x @staticmethod - def bdate_to_age(attr_dir_path: str, node_map: list): + def bdate_to_age( + attr_dir_path: str, + node_map: list + ) -> None: with open(attr_dir_path / Path('bdate'), 'r') as f: age_dict = json.load(f) node_age = {} @@ -223,7 +250,11 @@ def bdate_to_age(attr_dir_path: str, node_map: list): json.dump(node_age, f1) -def make_vk_labeling(attr_path: str, labeling_path: str, attr_val: int = 1): +def make_vk_labeling( + attr_path: str, + labeling_path: str, + attr_val: int = 1 +) -> None: """ Creates a markup file where the attribute's target value is set to 1 and the rest to 0 Args: