From cf8929e406954f5ab14fc85219d698003f6a1f01 Mon Sep 17 00:00:00 2001 From: bnorthan Date: Tue, 14 Jan 2025 18:04:08 -0500 Subject: [PATCH] Add documentation --- .../frameworks/base_framework.py | 112 +++++++++++++++++- 1 file changed, 111 insertions(+), 1 deletion(-) diff --git a/src/napari_easy_augment_batch_dl/frameworks/base_framework.py b/src/napari_easy_augment_batch_dl/frameworks/base_framework.py index 5cd1e5b..68dbe72 100644 --- a/src/napari_easy_augment_batch_dl/frameworks/base_framework.py +++ b/src/napari_easy_augment_batch_dl/frameworks/base_framework.py @@ -11,6 +11,21 @@ class TrainMode: Pixels = 1 class BaseFramework: + """ + Base class for frameworks. + + Frameworks represents DL approaches like Stardist, Cellpose, SAM or Semantic Segmentation Unets. + + The frameworks are intended to work with a specific pre-existing file organization + + parent_path: path to the parent directory, the images are stored in parent_path/images + parent_path/models: path to the models directory, where the models are stored + parent_path/patches: path to the patches directory, where the patches are stored + """ + + # initiate registry for all the frameworks. Each framework should register itself when imported: + # For example a cellpose framework would have the following line in its code: + # BaseFramework.register_framework('CellPoseInstanceFramework', CellPoseInstanceFramework) registry = {} @classmethod @@ -18,38 +33,120 @@ def register_framework(cls, name, framework): cls.registry[name] = framework def __init__(self, parent_path, num_classes=1): + """ + Base class for all frameworks. + """ + self.parent_path = parent_path + # path to store models self.model_path = os.path.join(parent_path, 'models') self.num_classes = num_classes + # path to store patches self.patch_path = os.path.join(parent_path, 'patches') + + # current model name self.model_name = 'notset' self.load_mode = LoadMode.NotLoadable + + # boxes should be set true if the framework detects bounding boxes self.boxes = False + + # builtins are models that are included with the framework (for example cyto3 in cellpose) self.builtin_names = [] + + # model dictionary stores all models including builtins and custom models self.model_dictionary = {} self.train_mode = TrainMode.Patches def train(self, updater=None): + """ + Train the model. + + This method must be implemented by derived classes. It defines the + process for training the model, potentially using an optional updater + for progress reporting. + + Args: + updater (callable, optional): A callback function that can be + called during training to report progress. + """ pass def predict(self, image): + """ + Predict the output for a given image. + + This method must be implemented by derived classes. It defines the + process for making predictions on input data. + + Args: + image: The input image as a numpy array. + + Returns: + The prediction result. + """ pass def get_model_names(self): + """ + Get the names of all models that are available for this framework. + + Override this method in derived classes to return a list of model names. + + Often the list will include builtins and custom models. + """ return ['notset'] def get_optimizers(self): + """ + Get names of available optimizers. + + (optional) Override this method in derived classes to return a list of optimizer names. + """ return [] def create_callback(self, updater): + """ + Create a callback function for training. The updater function may be wrapped in a custom callback class + + For example in a stardist trainer a keras.callbacks.Callback is created that calls the updater function. + + In other frameworks the updater function may be used directly by the training function. + + Args: + updater (callable): A callback function with parameters (message, progress) + """ self.updater = updater def generate_model_name(self, base_name="model"): + """ + Generate a unique model name based on the current time. + + May be overridden in derived classes to provide a custom naming scheme. + + Args: + base_name (str): The base name for the model. + + Returns: + str: The generated model name. + """ current_time = datetime.now().strftime("%Y%m%d_%H%M%S") model_name = f"{base_name}_{current_time}" return model_name def set_pretrained_model(self, model_name): + """ + Sets a pretrained model. Built in models are models which are included with the framework and are created by the framework developers. + + If the model name is a built_in model, the model is created by calling the set_builtin_model method. + + For example in Cellpose a Cyto3 model is a built in model. And can be create by calling 'models.CellposeModel(gpu=True, model_type='cyto3')' + + For cellpose frame 'set_builtin_model' will be overridden with the above code. + + Args: + model_name (str): The name of the model to set + """ if model_name != 'notset': model = self.model_dictionary.get(model_name, None) @@ -58,15 +155,28 @@ def set_pretrained_model(self, model_name): # if a built in model set it if model_name in self.builtin_names: - #self.model = models.CellposeModel(gpu=True, model_type=model_name) self.set_builtin_model(model_name) self.model_dictionary[model_name] = self.model else: self.model = model def set_optimizer(self, optimizer): + """ + Set the optimizer for the model. + """ pass def set_builtin_model(self, model_name): + """ + Built in models are models which are included with the framework and are created by the framework developers. + + For example in Cellpose a Cyto3 model is a built in model. And can be create by calling 'models.CellposeModel(gpu=True, model_type='cyto3')' + + For cellpose frame 'set_builtin_model' will be overridden with the above code. + + Each framework can optionally override this method to set a built in model using the protocol of the framework. + Args: + model_name (str): The name of the model to set + """ pass