Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved documentation layout for algorithm class #1809

Merged
merged 10 commits into from
Aug 23, 2024
112 changes: 72 additions & 40 deletions Wrappers/Python/cil/optimisation/algorithms/Algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,31 +26,22 @@


class Algorithm:
'''Base class for iterative algorithms
r"""Base class providing minimal infrastructure for iterative algorithms.

provides the minimal infrastructure.
An iterative algorithm is designed to solve an optimization problem by repeatedly refining a solution. In CIL, we use iterative algorithms to minimize an objective function, often referred to as a loss. The process begins with an initial guess, and with each iteration, the algorithm updates the current solution based on the results of previous iterations (previous iterates). Iterative algorithms typically continue until a stopping criterion is met, indicating that an optimal or sufficiently good solution has been found. In CIL, stopping criteria can be implemented using a callback function (`cil.optimisation.utilities.callbacks`).

The user is required to implement the :code:`set_up`, :code:`__init__`, :code:`update` and :code:`update_objective` methods.

Algorithms are iterables so can be easily run in a for loop. They will
stop as soon as the stop criterion is met.
The user is required to implement the :code:`set_up`, :code:`__init__`, :code:`update` and
and :code:`update_objective` methods

A courtesy method :code:`run` is available to run :code:`n` iterations. The method accepts
a :code:`callbacks` list of callables, each of which receive the current Algorithm object
(which in turn contains the iteration number and the actual objective value)
and can be used to trigger print to screens and other user interactions. The :code:`run`
method will stop when the stopping criterion is met or `StopIteration` is raised.
'''
The method :code:`run` is available to run :code:`n` iterations. The method accepts a :code:`callbacks` list of callables, each of which receive the current Algorithm object (which in turn contains the iteration number and the actual objective value) and can be used to trigger print to screens and other user interactions. The :code:`run` method will stop when the stopping criterion is met or `StopIteration` is raised.
MargaretDuff marked this conversation as resolved.
Show resolved Hide resolved

Parameters
----------
update_objective_interval: int, optional, default 1
the interval every which we would save the current objective. 1 means every iteration, 2 every 2 iteration and so forth. This is by default 1 and should be increased when evaluating the objective is computationally expensive.
MargaretDuff marked this conversation as resolved.
Show resolved Hide resolved
MargaretDuff marked this conversation as resolved.
Show resolved Hide resolved
"""

def __init__(self, update_objective_interval=1, max_iteration=None, log_file=None):
'''Set the minimal number of parameters:

:param update_objective_interval: the interval every which we would save the current\
objective. 1 means every iteration, 2 every 2 iteration\
and so forth. This is by default 1 and should be increased\
when evaluating the objective is computationally expensive.
:type update_objective_interval: int, optional, default 1
'''
self.iteration = -1
self.__max_iteration = 1
if max_iteration is not None:
Expand Down Expand Up @@ -82,9 +73,11 @@ def should_stop(self):
return self.iteration > self.max_iteration

def __set_up_logger(self, *_, **__):
"""Do not use: this is being deprecated"""
warn("use `run(callbacks=[LogfileCallback(log_file)])` instead", DeprecationWarning, stacklevel=2)

def max_iteration_stop_criterion(self):
"""Do not use: this is being deprecated"""
warn("use `should_stop()` instead of `max_iteration_stop_criterion()`", DeprecationWarning, stacklevel=2)
return self.iteration > self.max_iteration

Expand Down Expand Up @@ -119,10 +112,9 @@ def __next__(self):
return self.iteration

def _update_previous_solution(self):
""" Update the previous solution with the current one
r""" Update the previous solution with the current one

The concrete algorithm calls update_previous_solution. Normally this would
entail the swapping of pointers:
The concrete algorithm calls update_previous_solution. Normally this would entail the swapping of pointers:
MargaretDuff marked this conversation as resolved.
Show resolved Hide resolved

.. highlight:: python
.. code-block:: python
Expand All @@ -135,26 +127,54 @@ def _update_previous_solution(self):
pass

def get_output(self):
" Returns the current solution. "
r""" Returns the current solution.

Returns
-------
DataContainer
The current solution

"""
return self.x

def _provable_convergence_condition(self):
r""" Checks if the algorithm set-up (e.g. chosen step-sizes or other parameters) meets a mathematical convergence criterion.

Returns
-------
bool: Outcome of the convergence check
"""
raise NotImplementedError(" Convergence criterion is not implemented for this algorithm. ")

def is_provably_convergent(self):
""" Check if the algorithm is convergent based on the provable convergence criterion.
r""" Check if the algorithm is convergent based on the provable convergence criterion.

Returns
-------
Boolean
Outcome of the convergence check

"""
return self._provable_convergence_condition()

@property
def solution(self):
" Returns the current solution. "
return self.get_output()

def get_last_loss(self, return_all=False):
'''Returns the last stored value of the loss function

if update_objective_interval is 1 it is the value of the objective at the current
iteration. If update_objective_interval > 1 it is the last stored value.
r'''Returns the last stored value of the loss function. "Loss" is an alias for "objective value". If `update_objective_interval` is 1 it is the value of the objective at the current iteration. If update_objective_interval > 1 it is the last stored value.

Parameters
----------
return_all: Boolean, default is False
If True, returns all the stored loss functions

Returns
-------
Float
Last stored value of the loss function

'''
try:
objective = self.__loss[-1]
Expand All @@ -174,12 +194,12 @@ def update_objective(self):
def iterations(self):
'''returns the iterations at which the objective has been evaluated'''
return self._iteration

@property
def loss(self):
'''returns the list of the values of the objective during the iteration
'''returns a list of the values of the objective (alias of loss) during the iteration

The length of this list may be shorter than the number of iterations run when
the update_objective_interval > 1
The length of this list may be shorter than the number of iterations run when the `update_objective_interval` > 1
'''
return self.__loss

Expand All @@ -198,23 +218,31 @@ def max_iteration(self, value):

@property
def update_objective_interval(self):
'''gets the update_objective_interval'''
return self.__update_objective_interval

@update_objective_interval.setter
def update_objective_interval(self, value):
'''sets the update_objective_interval'''
if not isinstance(value, Integral) or value < 0:
raise ValueError('interval must be an integer >= 0')
self.__update_objective_interval = value

def run(self, iterations=None, callbacks: Optional[List[Callback]]=None, verbose=1, **kwargs):
'''run upto :code:`iterations` with callbacks/logging.

:param iterations: number of iterations to run. If not set the algorithm will
run until :code:`should_stop()` is reached
:param verbose: 0=quiet, 1=info, 2=debug
:param callbacks: list of callables which are passed the current Algorithm
object each iteration. Defaults to :code:`[ProgressCallback(verbose)]`.
'''
r"""run upto :code:`iterations` with callbacks/logging.

For a demonstration of callbacks see https://github.com/TomographicImaging/CIL-Demos/blob/main/misc/callback_demonstration.ipynb

Parameters
-----------
iterations: int, default is None
Number of iterations to run. If not set the algorithm will run until :code:`should_stop()` is reached
callbacks: list of callables, default is Defaults to :code:`[ProgressCallback(verbose)]`
List of callables which are passed the current Algorithm object each iteration. Defaults to :code:`[ProgressCallback(verbose)]`.
verbose: 0=quiet, 1=info, 2=debug
Passed to the default callback to determine the verbosity of the printed output.
"""

if 'print_interval' in kwargs:
warn("use `TextProgressCallback(miniters)` instead of `run(print_interval)`",
DeprecationWarning, stacklevel=2)
Expand Down Expand Up @@ -248,6 +276,7 @@ def run(self, iterations=None, callbacks: Optional[List[Callback]]=None, verbose
break

def objective_to_dict(self, verbose=False):
"""Internal function to save and print objective functions"""
obj = self.get_last_objective(return_all=verbose)
if isinstance(obj, list) and len(obj) == 3:
if not np.isnan(obj[1:]).all():
Expand All @@ -256,11 +285,14 @@ def objective_to_dict(self, verbose=False):
return {'objective': obj}

def objective_to_string(self, verbose=False):
"""Do not use: this is being deprecated"""
warn("consider using `run(callbacks=[LogfileCallback(log_file)])` instead", DeprecationWarning, stacklevel=2)
return str(self.objective_to_dict(verbose=verbose))

def verbose_output(self, *_, **__):
"""Do not use: this is being deprecated"""
warn("use `run(callbacks=[ProgressCallback()])` instead", DeprecationWarning, stacklevel=2)

def verbose_header(self, *_, **__):
"""Do not use: this is being deprecated"""
warn("consider using `run(callbacks=[LogfileCallback(log_file)])` instead", DeprecationWarning, stacklevel=2)
Loading