Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ApproximateGradientSumFunction class and example SGFunction for Stochastic Gradient Descent #1550

Merged
merged 183 commits into from
Mar 26, 2024
Merged
Show file tree
Hide file tree
Changes from 160 commits
Commits
Show all changes
183 commits
Select commit Hold shift + click to select a range
8389932
First attempt at sampling class
MargaretDuff Aug 2, 2023
7331c73
Changed how probabilities and samplers interact in SPDHG
MargaretDuff Aug 2, 2023
343509c
Initial playing
MargaretDuff Aug 7, 2023
40ac9c0
Ready to start some basic testing
MargaretDuff Aug 8, 2023
34fb1d5
Started to debug
MargaretDuff Aug 8, 2023
6ef169a
Testind SGD
MargaretDuff Aug 8, 2023
2bce666
Update sampling.py
MargaretDuff Aug 8, 2023
ea759c5
Changed to factory method style and added in permuatations
MargaretDuff Aug 9, 2023
d1909a3
Debugging and fixing random generator in show epochs
MargaretDuff Aug 9, 2023
98b0694
Testing SPDHG
MargaretDuff Aug 9, 2023
05b67cb
Changed the show epochs
MargaretDuff Aug 10, 2023
001350b
Meeting with Vaggelis, Jakob, Gemma and Edo
MargaretDuff Aug 11, 2023
890dec0
Set up for installation
MargaretDuff Aug 14, 2023
25806fc
Added staggered and custom order and started with writing documentation
MargaretDuff Aug 14, 2023
75abbfe
Work on documentation
MargaretDuff Aug 14, 2023
ebdf329
Commenting and examples in the class
MargaretDuff Aug 15, 2023
ba35fb8
Debugging sampler
MargaretDuff Aug 15, 2023
ff5cdf1
sorted build and imports
MargaretDuff Aug 15, 2023
f62f064
Changes to todo
MargaretDuff Aug 16, 2023
beac6fa
Changes after dev meeting
MargaretDuff Aug 17, 2023
1202e53
Checking probabilities in init
MargaretDuff Aug 18, 2023
079935b
initial testing
MargaretDuff Aug 23, 2023
43e3dc4
Sped up PDHG and SPDHG testing
MargaretDuff Aug 24, 2023
004ab2f
Removed timing statements
MargaretDuff Aug 24, 2023
7b857e0
Got rid of epochs - still need to fix the shuffle
MargaretDuff Sep 13, 2023
1f7d546
Fixed random without replacement shuffle=False
MargaretDuff Sep 14, 2023
6993a95
Changes after meeting 12-09-2023. Remove epochs in sampler and deprec…
MargaretDuff Sep 14, 2023
bafc748
Sampler unit tests added
MargaretDuff Sep 19, 2023
d62aa2b
Some checks for setting step sizes
MargaretDuff Sep 19, 2023
c81b71c
Started looking at unit tests and debugging SPDHG setters and init
MargaretDuff Sep 21, 2023
b28f2f1
Notes after discussions with gemma
MargaretDuff Sep 22, 2023
4a87f48
Changes after discussion with gemma
MargaretDuff Sep 25, 2023
b35222f
Updated tests
MargaretDuff Sep 25, 2023
6e552af
Just a commenting change
MargaretDuff Sep 25, 2023
4ae9b3c
Tiny changes
MargaretDuff Sep 28, 2023
69c1e1a
Merge branch 'master' of github.com:TomographicImaging/CIL into SGD
MargaretDuff Sep 28, 2023
6575af6
Initial changes and tests- currently failing tests
MargaretDuff Sep 28, 2023
6b463bc
Sorted tests and checks on the set_norms function
MargaretDuff Oct 2, 2023
215bfa6
Changed a comment
MargaretDuff Oct 2, 2023
3898a03
Removed reference to dask
MargaretDuff Oct 4, 2023
b946d79
Bug fixes
MargaretDuff Oct 5, 2023
96e4730
Changes based on Gemma's review
MargaretDuff Oct 5, 2023
3c36f3f
Small changes
MargaretDuff Oct 9, 2023
1ca3a2b
Comments from Edo fixed
MargaretDuff Oct 9, 2023
4b541e7
Merge branch 'master' into blockoperator-norms
MargaretDuff Oct 9, 2023
9a04de4
Added stuff to gitignore
MargaretDuff Oct 9, 2023
5a302c8
Fixed tests
MargaretDuff Oct 9, 2023
0bffa24
Added a note to the documentation about which sampler to use
MargaretDuff Oct 11, 2023
8416837
Option for list or blockfunction
MargaretDuff Oct 11, 2023
37565fc
Fixed the bugs of the previous commit
MargaretDuff Oct 11, 2023
18647af
Merge branch 'master' of github.com:MargaretDuff/CIL-margaret into st…
MargaretDuff Oct 11, 2023
222c377
Moved the sampler to the algorithms folder
MargaretDuff Oct 12, 2023
1d70eb3
Updated tests
MargaretDuff Oct 12, 2023
5c9fa3a
Sampler inheritance
MargaretDuff Oct 12, 2023
48d355b
Notes from meeting
MargaretDuff Oct 12, 2023
8e84276
Moved sampler to a new folder algorithms.utilities- think there is st…
MargaretDuff Oct 12, 2023
a9cb92e
Some notes from the stochastic meeting
MargaretDuff Oct 12, 2023
c552257
changed cmake file for new folder
MargaretDuff Oct 12, 2023
c6e1458
Some changes from Edo
MargaretDuff Oct 16, 2023
2b35fad
Maths documentation
MargaretDuff Oct 16, 2023
43e6fee
Some more Edo comments on sampler
MargaretDuff Oct 16, 2023
f77b553
Tried to sort the tests
MargaretDuff Oct 17, 2023
cf1b7f1
Vaggelis comment on checks
MargaretDuff Oct 17, 2023
c2c4df9
Change to jinja version in doc_environment.yml
MargaretDuff Oct 17, 2023
544a215
Merge branch 'TomographicImaging:master' into blockoperator-norms
MargaretDuff Oct 17, 2023
d11296f
Revert changes to docs_environment.yml
lauramurgatroyd Oct 18, 2023
32e057b
Docstring change
MargaretDuff Oct 18, 2023
4e0ca6a
Docstring change
MargaretDuff Oct 18, 2023
87f1a00
Revert naming of docs environment file
lauramurgatroyd Oct 18, 2023
2ff165a
Updated changelog
MargaretDuff Oct 18, 2023
81fc7e2
Updated changelog
MargaretDuff Oct 18, 2023
8f100e0
Updated changelog
MargaretDuff Oct 18, 2023
920cf90
Started adding new unit tests
MargaretDuff Oct 20, 2023
3a02a47
More work on tests
MargaretDuff Oct 20, 2023
10748ef
SG tests
MargaretDuff Oct 20, 2023
381342c
Changes to docstring
MargaretDuff Oct 25, 2023
c67818b
Changes to tests
MargaretDuff Oct 25, 2023
6b5ff83
SGD tests including SumFunction
MargaretDuff Oct 25, 2023
c7e5b9f
Merge branch 'SGD' of github.com:MargaretDuff/CIL-margaret into SGD
MargaretDuff Oct 26, 2023
876d4c9
Added size to the BlockOperator
MargaretDuff Oct 26, 2023
5ae4aaf
Merged the blockoperator-norms branch
MargaretDuff Oct 30, 2023
b983e2f
Removed precalculated_norms and pull the prob_weights from the sampler
MargaretDuff Oct 31, 2023
71cbdf9
Changes to setting tau and new unit test
MargaretDuff Oct 31, 2023
8f24634
Just some comments
MargaretDuff Oct 31, 2023
f0f4de3
Changes after discussion with Edo and Gemma
MargaretDuff Nov 2, 2023
100a42d
Merge branch 'blockoperator-norms' of github.com:MargaretDuff/CIL-mar…
MargaretDuff Nov 2, 2023
26584c9
Documentation changes
MargaretDuff Nov 2, 2023
ba8226b
Merge branch 'blockoperator-norms' of github.com:MargaretDuff/CIL-mar…
MargaretDuff Nov 2, 2023
d182423
Changes to SPDHG with block_norms
MargaretDuff Nov 3, 2023
ad86a58
Started setting up factory methods
MargaretDuff Nov 6, 2023
40ba3f4
Added function sampler
MargaretDuff Nov 6, 2023
835ce83
Abstract base class
MargaretDuff Nov 6, 2023
3760458
prob_weights to sampler
MargaretDuff Nov 7, 2023
878675d
TODO:s
MargaretDuff Nov 7, 2023
5ce0a09
Changes after stochastic meeting
MargaretDuff Nov 8, 2023
2d99762
Updates to sampler
MargaretDuff Nov 8, 2023
7154834
Updates to SPDHG after stochastic meeting
MargaretDuff Nov 8, 2023
520b9fa
Updated unit tests
MargaretDuff Nov 8, 2023
13c27e3
Merge branch 'master' into SGD
MargaretDuff Nov 8, 2023
11a4624
Merge branch 'master' into stochastic_sampling
MargaretDuff Nov 8, 2023
4e7f2b6
Merge error fixed
MargaretDuff Nov 8, 2023
d861a13
SPDHG documentation changes
MargaretDuff Nov 15, 2023
c0f0703
Merge branch 'stochastic_sampling' of github.com:MargaretDuff/CIL-mar…
MargaretDuff Nov 22, 2023
fce2a8e
Merged sampler into SGD
MargaretDuff Nov 22, 2023
ea4f114
Merge branch 'master' into SPDHG_unit_tests
MargaretDuff Nov 22, 2023
f95560f
Merged in SPDHG speed up
MargaretDuff Nov 22, 2023
0af2e61
Changes from meeting with Edo and Gemma
MargaretDuff Nov 22, 2023
8e14034
Remove changes to BlockOperator.py
MargaretDuff Nov 22, 2023
5c34e69
sigma and tau properties
MargaretDuff Nov 22, 2023
d1fffdf
Another attempt at speeding up unit tests
MargaretDuff Nov 23, 2023
b3dc8a1
Added random seeds to tests
MargaretDuff Nov 23, 2023
edbaa9f
Started on Gemma's suggestions
MargaretDuff Nov 24, 2023
dc1b67a
Some more of Gemma's changes
MargaretDuff Nov 27, 2023
3b41fc4
Last of Gemma's changes
MargaretDuff Nov 27, 2023
7e5759b
Merge branch 'master' into stochastic_sampling
MargaretDuff Nov 27, 2023
bab0b98
Edo's comments
MargaretDuff Nov 28, 2023
41ff3b5
New __str__ functions in sampler
MargaretDuff Nov 30, 2023
aaa7200
Documentation changes
MargaretDuff Nov 30, 2023
b9bb04d
Documentation changes x2
MargaretDuff Nov 30, 2023
ef25425
Moved custom order to an example of a function
MargaretDuff Dec 5, 2023
0948e39
Back to num_indices and more explanation for custom function examples
MargaretDuff Dec 5, 2023
5804f7d
Updates from chat with Gemma
MargaretDuff Dec 7, 2023
fca94f4
Updates from chat with Gemma
MargaretDuff Dec 7, 2023
7d4ffe6
Pulled prime factorisation code out of the Herman Meyer function
MargaretDuff Dec 8, 2023
2dba9d7
created herman_meyer sampling as a fucntion of iteration number
gfardell Dec 8, 2023
c576a51
Merge pull request #1 from gfardell/stochastic_sampling_hm
MargaretDuff Dec 11, 2023
4c36fdf
Merge branch 'master' into stochastic_sampling
MargaretDuff Dec 11, 2023
f5c2d96
Update Wrappers/Python/cil/optimisation/algorithms/SPDHG.py
MargaretDuff Dec 11, 2023
188000f
Changes from Edo review
MargaretDuff Dec 11, 2023
0155e3d
Merge branch 'stochastic_sampling' of github.com:MargaretDuff/CIL-mar…
MargaretDuff Dec 11, 2023
86c1e3e
Removed from_order to replace with functions
MargaretDuff Dec 11, 2023
47542a5
fix failing tests
MargaretDuff Dec 12, 2023
ddbdbb3
Test fix...again
MargaretDuff Dec 12, 2023
060e915
Merged conflict
MargaretDuff Dec 12, 2023
8e7a6ac
Merge branch 'master' into stochastic_sampling
MargaretDuff Dec 12, 2023
c8b9cc4
Merge branch 'stochastic_sampling' of github.com:MargaretDuff/CIL-mar…
MargaretDuff Dec 12, 2023
5f03675
Fixed tests
MargaretDuff Dec 12, 2023
44173f4
Merge branch 'master' into SGD
MargaretDuff Dec 20, 2023
d72b7c1
Merge branch 'master' into SGD
MargaretDuff Jan 25, 2024
3531b96
Tidy up PR
MargaretDuff Jan 25, 2024
d383733
Tidy up PR
MargaretDuff Jan 25, 2024
470ed86
Updated doc strings and requirements for sampler class - need to do d…
MargaretDuff Jan 25, 2024
54cf27c
optimisation.rst updated to add in the new documentation
MargaretDuff Jan 26, 2024
4c4a26c
Changes after discussion with Edo and Kris
MargaretDuff Feb 12, 2024
c064388
Fixed merge error
MargaretDuff Feb 12, 2024
4b97d9b
Fixed merge error
MargaretDuff Feb 13, 2024
f41ae7b
Added skip astra
MargaretDuff Feb 13, 2024
be75374
New data_passes function and getter
MargaretDuff Feb 13, 2024
778c7c1
New data_passes function and getter
MargaretDuff Feb 13, 2024
655df78
Merge branch 'master' into SGD
MargaretDuff Feb 13, 2024
9990536
Rate to step_size
MargaretDuff Feb 13, 2024
57b71e1
Use backtracking in unit tests
MargaretDuff Feb 13, 2024
eac5397
Getter for num_functions
MargaretDuff Feb 13, 2024
9befcf8
Discussion with Zeljko, Edo and Vaggelis
MargaretDuff Feb 14, 2024
e1019a3
Documentation on data passes
MargaretDuff Feb 14, 2024
c971dd6
Comments from discussion with Edo
MargaretDuff Feb 27, 2024
66c4853
Merge branch 'master' into SGD
MargaretDuff Feb 27, 2024
addfe47
Changes after Vaggelis review
MargaretDuff Mar 12, 2024
649e186
Merge branch 'master' into SGD
MargaretDuff Mar 12, 2024
a3bd92a
Merge branch 'master' into SGD
MargaretDuff Mar 13, 2024
3a74f75
Tweak to unit tests after discussion with Edo
MargaretDuff Mar 14, 2024
2395b03
Some of Jakob's comments
MargaretDuff Mar 20, 2024
a7b5016
Updated documentation from Vaggelis and Jakob comments
MargaretDuff Mar 21, 2024
c716d31
Merge branch 'master' into SGD
MargaretDuff Mar 21, 2024
b9d8ab4
Try to fix rst file example
MargaretDuff Mar 21, 2024
08e91a9
Merge branch 'SGD' of github.com:MargaretDuff/CIL-margaret into SGD
MargaretDuff Mar 21, 2024
9aebc50
Try to fix rst file example
MargaretDuff Mar 21, 2024
e034e03
Try to fix rst file bullet points
MargaretDuff Mar 21, 2024
cf0e60f
Try to fix rst file example
MargaretDuff Mar 21, 2024
fe21db0
Try to fix SGD docs
MargaretDuff Mar 21, 2024
53dd837
Updated example after Vaggelis comments
MargaretDuff Mar 21, 2024
843413f
Discussions with Edo and Gemma
MargaretDuff Mar 21, 2024
f3e416a
Documentation for the multiplication factor
MargaretDuff Mar 22, 2024
c22868e
Documentation for the multiplication factor
MargaretDuff Mar 22, 2024
6294ccb
Merge branch 'master' into SGD
MargaretDuff Mar 25, 2024
561f3e7
Improved documentation after discussion with Edo
MargaretDuff Mar 25, 2024
eab5ce4
Neaten documentation
MargaretDuff Mar 25, 2024
ddb159d
Some of Edo's comments
MargaretDuff Mar 26, 2024
13ded02
Edo's comments
MargaretDuff Mar 26, 2024
e976cf7
Try to fix formating in documentation
MargaretDuff Mar 26, 2024
049041f
Merge branch 'master' into SGD
MargaretDuff Mar 26, 2024
bc6d8b5
Update change log
MargaretDuff Mar 26, 2024
e039e99
Merge branch 'SGD' of github.com:MargaretDuff/CIL-margaret into SGD
MargaretDuff Mar 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
# Copyright 2024 United Kingdom Research and Innovation
# Copyright 2024 The University of Manchester
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Authors:
# CIL Developers, listed at: https://github.com/TomographicImaging/CIL/blob/master/NOTICE.txt


from cil.optimisation.functions import SumFunction
from cil.optimisation.utilities import Sampler
import numbers
from abc import ABC, abstractmethod
import numpy as np


class ApproximateGradientSumFunction(SumFunction, ABC):
r"""ApproximateGradientSumFunction represents the following sum

.. math:: \sum_{i=0}^{n-1} f_{i} = (f_{0} + f_{2} + ... + f_{n-1})

where there are :math:`n` functions. The gradient method from a CIL function is overwritten and calls an approximate gradient method.
MargaretDuff marked this conversation as resolved.
Show resolved Hide resolved

It is an abstract base class and any child classes must implement an `approximate_gradient` function.

Parameters:
-----------
functions : `list` of functions
A list of functions: :code:`[f_{0}, f_{2}, ..., f_{n-1}]`. Each function is assumed to be smooth function with an implemented :func:`~Function.gradient` method. Each function must have the same domain. The number of functions must be strictly greater than 1.
MargaretDuff marked this conversation as resolved.
Show resolved Hide resolved
sampler: An instance of a CIL Sampler class ( :meth:`~optimisation.utilities.sampler`) or another class which has a `next` function implemented to output integers in {0,...,n-1}.
MargaretDuff marked this conversation as resolved.
Show resolved Hide resolved
This sampler is called each time gradient is called and sets the internal `function_num` passed to the `approximate_gradient` function. The `num_indices` must match the number of functions provided. Default is `Sampler.random_with_replacement(len(functions))`.
MargaretDuff marked this conversation as resolved.
Show resolved Hide resolved
MargaretDuff marked this conversation as resolved.
Show resolved Hide resolved


Note
-----
We provide two ways of keeping track the amount of data you have seen:
MargaretDuff marked this conversation as resolved.
Show resolved Hide resolved
- `data_passes_indices` a list of lists the length of which should be the number of iterations currently run. Each entry corresponds to the indices of the function numbers seen in that iteration.
MargaretDuff marked this conversation as resolved.
Show resolved Hide resolved
- `data_passes` is a list of floats the length of which should be the number of iterations currently run. Each entry corresponds to the proportion of data seen up to this iteration. Warning: if your functions do not contain an equal `amount` of data, for example your data was not partitioned into equal batches, then you must first use the `set_data_partition_weights" function for this to be accurate.



Note
----
The :meth:`~ApproximateGradientSumFunction.gradient` returns the approximate gradient depending on an index provided by the :code:`sampler` method.

Example
-------
Consider the objective is to minimise:

.. math:: \sum_{i=0}^{n-1} f_{i}(x) = \sum_{i=0}^{n-1}\|A_{i} x - b_{i}\|^{2}

>>> list_of_functions = [LeastSquares(Ai, b=bi)] for Ai,bi in zip(A_subsets, b_subsets))
>>> f = ApproximateGradientSumFunction(list_of_functions)

>>> list_of_functions = [LeastSquares(Ai, b=bi)] for Ai,bi in zip(A_subsets, b_subsets))
>>> sampler = Sampler.random_shuffle(len(list_of_functions))
>>> f = ApproximateGradientSumFunction(list_of_functions, sampler=sampler)

MargaretDuff marked this conversation as resolved.
Show resolved Hide resolved

"""

def __init__(self, functions, sampler=None):

if sampler is None:
sampler = Sampler.random_with_replacement(len(functions))

if not isinstance(functions, list):
raise TypeError("Input to functions should be a list of functions")
if not hasattr(sampler, "next"):
raise ValueError('The provided sampler must have a `next` method')

self.sampler = sampler

self._partition_weights = [1 / len(functions)] * len(functions)

self._data_passes_indices = []

super(ApproximateGradientSumFunction, self).__init__(*functions)

def __call__(self, x):
r"""Returns the value of the sum of functions at :math:`x`.

.. math:: (f_{0} + f_{1} + ... + f_{n-1})(x) = f_{0}(x) + f_{1}(x) + ... + f_{n-1}(x)

Parameters
----------
x : DataContainer

--------
float
the value of the SumFunction at x


"""
return super(ApproximateGradientSumFunction, self).__call__(x)

def full_gradient(self, x, out=None):
r"""Returns the value of the full gradient of the sum of functions at :math:`x`.

.. math:: \nabla_x(f_{0} + f_{1} + ... + f_{n-1})(x) = \nabla_xf_{0}(x) + \nabla_xf_{1}(x) + ... + \nabla_xf_{n-1}(x)

Parameters
----------
x : DataContainer
out: return DataContainer, if `None` a new DataContainer is returned, default `None`.

Returns
--------
DataContainer
MargaretDuff marked this conversation as resolved.
Show resolved Hide resolved
the value of the gradient of the sum function at x or nothing if `out`
"""

return super(ApproximateGradientSumFunction, self).gradient(x, out=out)

@abstractmethod
def approximate_gradient(self, x, function_num, out=None):
""" Computes the approximate gradient for each selected function at :code:`x` given a `function_number` in {0,...,len(functions)-1}.
MargaretDuff marked this conversation as resolved.
Show resolved Hide resolved

Parameters
----------
x : DataContainer
out: return DataContainer, if `None` a new DataContainer is returned, default `None`.
function_num: `int`
Between 0 and the number of functions in the list
Returns
--------
DataContainer
the value of the approximate gradient of the sum function at :code:`x` given a `function_number` in {0,...,len(functions)-1} or nothing if `out`
MargaretDuff marked this conversation as resolved.
Show resolved Hide resolved
"""
pass

def gradient(self, x, out=None):
""" Selects a random function using the `sampler` and then calls the approximate gradient at :code:`x`

Parameters
----------
x : DataContainer
out: return DataContainer, if `None` a new DataContainer is returned, default `None`.

Returns
--------
DataContainer
the value of the approximate gradient of the sum function at :code:`x` or nothing if `out`
MargaretDuff marked this conversation as resolved.
Show resolved Hide resolved
"""

self.function_num = self.sampler.next()

if self.function_num > self.num_functions:
MargaretDuff marked this conversation as resolved.
Show resolved Hide resolved
raise IndexError(
'The sampler has outputted an index larger than the number of functions to sample from. Please ensure your sampler samples from {0,1,...,len(functions)-1} only.')

if isinstance(self.function_num, numbers.Number):
return self.approximate_gradient(x, self.function_num, out=out)
raise ValueError("Batch gradient is not yet implemented")
MargaretDuff marked this conversation as resolved.
Show resolved Hide resolved

def _update_data_passes_indices(self, indices):
""" Internal function that updates the list of lists containing the function indices seen at each iteration.

Parameters
----------
indices: list
List of indices seen in a given iteration

"""
self._data_passes_indices.append(indices)

def set_data_partition_weights(self, weights):
""" Setter for the partition weights used to calculate the data passes

Parameters
----------
weights: list of positive floats that sum to one.
The proportion of the data held in each function. Equivalent to the proportions that you partitioned your data into.

"""
if len(weights) != len(self.functions):
raise ValueError(
'The provided weights must be a list the same length as the number of functions')

if abs(sum(weights) - 1) > 1e-6:
raise ValueError('The provided weights must sum to one')

if any(np.array(weights) < 0):
raise ValueError(
'The provided weights must be greater than or equal to zero')

self._partition_weights = weights

@property
def data_passes_indices(self):
return self._data_passes_indices

@property
def data_passes(self):
MargaretDuff marked this conversation as resolved.
Show resolved Hide resolved
data_passes = []
for el in self._data_passes_indices:
try:
data_passes.append(data_passes[-1])
except IndexError:
data_passes.append(0)
for i in el:
data_passes[-1] += self._partition_weights[i]
return data_passes
3 changes: 3 additions & 0 deletions Wrappers/Python/cil/optimisation/functions/Function.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,9 @@ def __add__(self, other):
else:
return super(SumFunction, self).__add__(other)

@property
def num_functions(self):
return len(self.functions)

class ScaledFunction(Function):

Expand Down
72 changes: 72 additions & 0 deletions Wrappers/Python/cil/optimisation/functions/SGFunction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# -*- coding: utf-8 -*-
# This work is part of the Core Imaging Library (CIL) developed by CCPi
# (Collaborative Computational Project in Tomographic Imaging), with
# substantial contributions by UKRI-STFC and University of Manchester.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

MargaretDuff marked this conversation as resolved.
Show resolved Hide resolved
from .ApproximateGradientSumFunction import ApproximateGradientSumFunction
from .Function import SumFunction

class SGFunction(ApproximateGradientSumFunction):

"""
Stochastic gradient function, a child class of `ApproximateGradientSumFunction`, which defines from a list of functions, :math:`{f_0,...,f_{n-1}}` a `SumFunction`, :math:`f_0+...+f_{n-1}` where each time the `gradient` is called, the `sampler` provides an index, :math:`i \in {0,...,n-1}`
MargaretDuff marked this conversation as resolved.
Show resolved Hide resolved
MargaretDuff marked this conversation as resolved.
Show resolved Hide resolved
and the gradient function returns the approximate gradient :math:`n\nabla_xf_i(x)`. This can be used with the `cil.optimisation.algorithms` algorithm GD to give a stochastic gradient descent algorithm.

Parameters:
-----------
functions : `list` of functions
A list of functions: :code:`[f_{0}, f_{1}, ..., f_{n-1}]`. Each function is assumed to be smooth function with an implemented :func:`~Function.gradient` method. Each function must have the same domain. The number of functions must be strictly greater than 1.
MargaretDuff marked this conversation as resolved.
Show resolved Hide resolved
sampler: An instance of a CIL Sampler class ( :meth:`~optimisation.utilities.sampler`) or another class which has a `next` function implemented to output integers in {0,...,n-1}.
MargaretDuff marked this conversation as resolved.
Show resolved Hide resolved
This sampler is called each time gradient is called and sets the internal `function_num` passed to the `approximate_gradient` function. The `num_indices` must match the number of functions provided. Default is `Sampler.random_with_replacement(len(functions))`.
"""

def __init__(self, functions, sampler=None):
super(SGFunction, self).__init__(functions, sampler)


def approximate_gradient(self, x, function_num, out=None):

r""" Returns the gradient of the function at index `function_num` at :code:`x`.

Parameters
----------
x : DataContainer
out: return DataContainer, if `None` a new DataContainer is returned, default `None`.
function_num: `int`
Between 0 and the number of functions in the list
Returns
--------
DataContainer
the value of the approximate gradient of the sum function at :code:`x` given a `function_number` in {0,...,len(functions)-1} or nothing if `out`
"""

self._update_data_passes_indices([function_num])
MargaretDuff marked this conversation as resolved.
Show resolved Hide resolved

# compute gradient of randomly selected(function_num) function
MargaretDuff marked this conversation as resolved.
Show resolved Hide resolved
if out is None:
out = self.functions[function_num].gradient(x)
else:
self.functions[function_num].gradient(x, out = out)

# scale wrt number of functions
out*=self.num_functions

return out






3 changes: 3 additions & 0 deletions Wrappers/Python/cil/optimisation/functions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,6 @@
from .KullbackLeibler import KullbackLeibler
from .Rosenbrock import Rosenbrock
from .TotalVariation import TotalVariation
from .ApproximateGradientSumFunction import ApproximateGradientSumFunction
from .SGFunction import SGFunction

Loading
Loading