Skip to content

Commit

Permalink
GD: fix rate => step_size
Browse files Browse the repository at this point in the history
  • Loading branch information
casperdcl committed Feb 13, 2024
1 parent a3e8b98 commit 194dcf6
Showing 1 changed file with 8 additions and 17 deletions.
25 changes: 8 additions & 17 deletions Wrappers/Python/test/test_algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@
from cil.plugins.astra import ProjectionOperator

class TestAlgorithms(CCPiTestClass):

def test_GD(self):
ig = ImageGeometry(12,13,14)
initial = ig.allocate()
Expand All @@ -77,19 +76,15 @@ def test_GD(self):
identity = IdentityOperator(ig)

norm2sq = LeastSquares(identity, b)
rate = norm2sq.L / 3.
step_size = norm2sq.L / 3.

alg = GD(initial=initial,
objective_function=norm2sq,
rate=rate, atol=1e-9, rtol=1e-6)
alg = GD(initial=initial, objective_function=norm2sq, step_size=step_size,
atol=1e-9, rtol=1e-6)
alg.max_iteration = 1000
alg.run(verbose=0)
self.assertNumpyArrayAlmostEqual(alg.x.as_array(), b.as_array())
alg = GD(initial=initial,
objective_function=norm2sq,
rate=rate, max_iteration=20,
update_objective_interval=2,
atol=1e-9, rtol=1e-6)
alg = GD(initial=initial, objective_function=norm2sq, step_size=step_size,
atol=1e-9, rtol=1e-6, max_iteration=20, update_objective_interval=2)
alg.max_iteration = 20
self.assertTrue(alg.max_iteration == 20)
self.assertTrue(alg.update_objective_interval==2)
Expand Down Expand Up @@ -129,17 +124,13 @@ def test_GDArmijo(self):
identity = IdentityOperator(ig)

norm2sq = LeastSquares(identity, b)
rate = None

alg = GD(initial=initial,
objective_function=norm2sq, rate=rate)
alg = GD(initial=initial, objective_function=norm2sq)
alg.max_iteration = 100
alg.run(verbose=0)
self.assertNumpyArrayAlmostEqual(alg.x.as_array(), b.as_array())
alg = GD(initial=initial,
objective_function=norm2sq,
max_iteration=20,
update_objective_interval=2)
alg = GD(initial=initial, objective_function=norm2sq,
max_iteration=20, update_objective_interval=2)
#alg.max_iteration = 20
self.assertTrue(alg.max_iteration == 20)
self.assertTrue(alg.update_objective_interval==2)
Expand Down

0 comments on commit 194dcf6

Please sign in to comment.