Skip to content

Commit

Permalink
Add documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
bnorthan committed Jan 14, 2025
1 parent c01152a commit cf8929e
Showing 1 changed file with 111 additions and 1 deletion.
112 changes: 111 additions & 1 deletion src/napari_easy_augment_batch_dl/frameworks/base_framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,45 +11,142 @@ class TrainMode:
Pixels = 1

class BaseFramework:
"""
Base class for frameworks.
Frameworks represents DL approaches like Stardist, Cellpose, SAM or Semantic Segmentation Unets.
The frameworks are intended to work with a specific pre-existing file organization
parent_path: path to the parent directory, the images are stored in parent_path/images
parent_path/models: path to the models directory, where the models are stored
parent_path/patches: path to the patches directory, where the patches are stored
"""

# initiate registry for all the frameworks. Each framework should register itself when imported:
# For example a cellpose framework would have the following line in its code:
# BaseFramework.register_framework('CellPoseInstanceFramework', CellPoseInstanceFramework)
registry = {}

@classmethod
def register_framework(cls, name, framework):
cls.registry[name] = framework

def __init__(self, parent_path, num_classes=1):
"""
Base class for all frameworks.
"""

self.parent_path = parent_path
# path to store models
self.model_path = os.path.join(parent_path, 'models')
self.num_classes = num_classes
# path to store patches
self.patch_path = os.path.join(parent_path, 'patches')

# current model name
self.model_name = 'notset'
self.load_mode = LoadMode.NotLoadable

# boxes should be set true if the framework detects bounding boxes
self.boxes = False

# builtins are models that are included with the framework (for example cyto3 in cellpose)
self.builtin_names = []

# model dictionary stores all models including builtins and custom models
self.model_dictionary = {}
self.train_mode = TrainMode.Patches

def train(self, updater=None):
"""
Train the model.
This method must be implemented by derived classes. It defines the
process for training the model, potentially using an optional updater
for progress reporting.
Args:
updater (callable, optional): A callback function that can be
called during training to report progress.
"""
pass

def predict(self, image):
"""
Predict the output for a given image.
This method must be implemented by derived classes. It defines the
process for making predictions on input data.
Args:
image: The input image as a numpy array.
Returns:
The prediction result.
"""
pass

def get_model_names(self):
"""
Get the names of all models that are available for this framework.
Override this method in derived classes to return a list of model names.
Often the list will include builtins and custom models.
"""
return ['notset']

def get_optimizers(self):
"""
Get names of available optimizers.
(optional) Override this method in derived classes to return a list of optimizer names.
"""
return []

def create_callback(self, updater):
"""
Create a callback function for training. The updater function may be wrapped in a custom callback class
For example in a stardist trainer a keras.callbacks.Callback is created that calls the updater function.
In other frameworks the updater function may be used directly by the training function.
Args:
updater (callable): A callback function with parameters (message, progress)
"""
self.updater = updater

def generate_model_name(self, base_name="model"):
"""
Generate a unique model name based on the current time.
May be overridden in derived classes to provide a custom naming scheme.
Args:
base_name (str): The base name for the model.
Returns:
str: The generated model name.
"""
current_time = datetime.now().strftime("%Y%m%d_%H%M%S")
model_name = f"{base_name}_{current_time}"
return model_name

def set_pretrained_model(self, model_name):
"""
Sets a pretrained model. Built in models are models which are included with the framework and are created by the framework developers.
If the model name is a built_in model, the model is created by calling the set_builtin_model method.
For example in Cellpose a Cyto3 model is a built in model. And can be create by calling 'models.CellposeModel(gpu=True, model_type='cyto3')'
For cellpose frame 'set_builtin_model' will be overridden with the above code.
Args:
model_name (str): The name of the model to set
"""
if model_name != 'notset':

model = self.model_dictionary.get(model_name, None)
Expand All @@ -58,15 +155,28 @@ def set_pretrained_model(self, model_name):

# if a built in model set it
if model_name in self.builtin_names:
#self.model = models.CellposeModel(gpu=True, model_type=model_name)
self.set_builtin_model(model_name)
self.model_dictionary[model_name] = self.model
else:
self.model = model

def set_optimizer(self, optimizer):
"""
Set the optimizer for the model.
"""
pass

def set_builtin_model(self, model_name):
"""
Built in models are models which are included with the framework and are created by the framework developers.
For example in Cellpose a Cyto3 model is a built in model. And can be create by calling 'models.CellposeModel(gpu=True, model_type='cyto3')'
For cellpose frame 'set_builtin_model' will be overridden with the above code.
Each framework can optionally override this method to set a built in model using the protocol of the framework.
Args:
model_name (str): The name of the model to set
"""
pass

0 comments on commit cf8929e

Please sign in to comment.