Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
HDembinski committed Aug 1, 2024
1 parent b0ed4ad commit 6c41d5c
Showing 1 changed file with 44 additions and 9 deletions.
53 changes: 44 additions & 9 deletions tests/test_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -1066,27 +1066,62 @@ def model(x, a, b):


def test_LeastSquares_2D():
x = np.array([1.0, 2.0, 3.0])
y = np.array([4.0, 5.0, 6.0])
z = 1.5 * x + 0.2 * y
ze = 1.5

def model(xy, a, b):
x, y = xy
return a * x + b * y

c = LeastSquares((x, y), z, ze, model, grad=numerical_model_gradient(model))
x = np.array([1.0, 2.0, 3.0])
y = np.array([4.0, 5.0, 6.0])
f = model((x, y), 1.5, 0.2)
fe = 1.5

c = LeastSquares((x, y), f, fe, model, grad=numerical_model_gradient(model))
assert c.ndata == 3

ref = numerical_cost_gradient(c)
assert_allclose(c.grad(1, 2), ref(1, 2))

assert_equal(c.x, (x, y))
assert_equal(c.y, f)
assert_equal(c.yerror, fe)
assert_allclose(c(1.5, 0.2), 0.0)
assert_allclose(c(2.5, 0.2), np.sum(((f - 2.5 * x - 0.2 * y) / fe) ** 2))
assert_allclose(c(1.5, 1.2), np.sum(((f - 1.5 * x - 1.2 * y) / fe) ** 2))

c.y = 2 * f
assert_equal(c.y, 2 * f)
c.x = (y, x)
assert_equal(c.x, (y, x))


def test_LeastSquares_3D():
def model(xyz, a, b):
x, y, z = xyz
return a * x + b * y + a * b * z

x = np.array([1.0, 2.0, 3.0, 4.0])
y = np.array([4.0, 5.0, 6.0, 7.0])
z = np.array([7.0, 8.0, 9.0, 10.0])

f = model((x, y, z), 1.5, 0.2)
fe = 1.5

c = LeastSquares((x, y, z), f, fe, model, grad=numerical_model_gradient(model))
assert c.ndata == 4

ref = numerical_cost_gradient(c)
assert_allclose(c.grad(1, 2), ref(1, 2))

assert_equal(c.x, (x, y))
assert_equal(c.y, z)
assert_equal(c.yerror, ze)
assert_equal(c.yerror, fe)
assert_allclose(c(1.5, 0.2), 0.0)
assert_allclose(c(2.5, 0.2), np.sum(((z - 2.5 * x - 0.2 * y) / ze) ** 2))
assert_allclose(c(1.5, 1.2), np.sum(((z - 1.5 * x - 1.2 * y) / ze) ** 2))
assert_allclose(
c(2.5, 0.2), np.sum(((f - 2.5 * x - 0.2 * y - 2.5 * 0.2 * z) / fe) ** 2)
)
assert_allclose(
c(1.5, 1.2), np.sum(((f - 1.5 * x - 1.2 * y - 1.5 * 1.2 * z) / fe) ** 2)
)

c.y = 2 * z
assert_equal(c.y, 2 * z)
Expand Down

0 comments on commit 6c41d5c

Please sign in to comment.