Skip to content

Commit

Permalink
edited errors for transformations
Browse files Browse the repository at this point in the history
  • Loading branch information
PTWaade committed May 1, 2024
1 parent d7f2ef5 commit 3a15a32
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 16 deletions.
1 change: 1 addition & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Glob = "c27321d9-0574-5035-807b-f59d2c89b15c"
HierarchicalGaussianFiltering = "63d42c3e-681c-42be-892f-a47f35336a79"
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
Expand Down
24 changes: 17 additions & 7 deletions src/update_hgf/node_updates/continuous_state_node.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,10 @@ function calculate_prediction_mean(node::ContinuousStateNode, stepsize::Real)
#Transform the parent's value
drift_increment = transform_parent_value(
parent.states.posterior_mean,
coupling_transform,
coupling_transform;
derivation_level = 0,
child = node,
parent = parent
)

#Add the drift increment
Expand Down Expand Up @@ -225,17 +227,21 @@ function calculate_posterior_precision_increment(
#Calculate the increment
child.states.prediction_precision * (
coupling_strength^2 * transform_parent_value(
coupling_transform,
node.states.posterior_mean,
coupling_transform;
derivation_level = 1,
parent = node,
child = child,
) -
coupling_strength *
node.states.value_prediction_error *
transform_parent_value(
coupling_transform,
node.states.posterior_mean,
coupling_transform;
derivation_level = 2,
)
parent = node,
child = child,
) *
child.states.value_prediction_error
)

end
Expand Down Expand Up @@ -396,9 +402,11 @@ function calculate_posterior_mean_increment(
(
child.parameters.coupling_strengths[node.name] *
transform_parent_value(
child.parameters.coupling_transforms[node.name],
node.states.posterior_mean,
child.parameters.coupling_transforms[node.name];
derivation_level = 1,
parent = node,
child = child,
) *
child.states.prediction_precision
) / node.states.posterior_precision
Expand All @@ -416,9 +424,11 @@ function calculate_posterior_mean_increment(
(
child.parameters.coupling_strengths[node.name] *
transform_parent_value(
child.parameters.coupling_transforms[node.name],
node.states.posterior_mean,
child.parameters.coupling_transforms[node.name];
derivation_level = 1,
parent = node,
child = child,
) *
child.states.prediction_precision
) / node.states.prediction_precision
Expand Down
22 changes: 13 additions & 9 deletions src/update_hgf/nonlinear_transforms.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
#Transformation (of varioius derivations) for a linear transformation
function transform_parent_value(
transform_type::LinearTransform,
parent_value::Real;
parent_value::Real,
transform_type::LinearTransform;
derivation_level::Integer,
parent::AbstractNode,
child::AbstractNode,
)
if derivation_level == 0
return parent_value
elseif derivation_level == 0
elseif derivation_level == 1
return 1
elseif derivation_level == 0
elseif derivation_level == 2
return 0
else
@error "derivation level is misspecified"
Expand All @@ -17,26 +19,28 @@ end

#Transformation (of varioius derivations) for a nonlinear transformation
function transform_parent_value(
transform_type::NonlinearTransform,
parent_value::Real,
transform_type::NonlinearTransform;
derivation_level::Integer,
parent::AbstractNode,
child::AbstractNode,
)

#Get the transformation function that fits the derivation level
if derivation_level == 0
transform_function = node.parameters.coupling_transforms[parent.name].base_function
transform_function = child.parameters.coupling_transforms[parent.name].base_function
elseif derivation_level == 1
transform_function =
node.parameters.coupling_transforms[parent.name].first_derivation
child.parameters.coupling_transforms[parent.name].first_derivation
elseif derivation_level == 2
transform_function =
node.parameters.coupling_transforms[parent.name].second_derivation
child.parameters.coupling_transforms[parent.name].second_derivation
else
@error "derivation level is misspecified"
end

#Get the transformation parameters
transform_parameters = node.parameters.coupling_transforms[parent.name].parameters
transform_parameters = child.parameters.coupling_transforms[parent.name].parameters

#Transform the value
return transform_function(parent_value, transform_parameters)
Expand Down
39 changes: 39 additions & 0 deletions test/testsuite/test_custom_structures.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
using Test
using HierarchicalGaussianFiltering


@testset "test custom structures" begin

@testset "Many continuous nodes" begin
nodes = [
ContinuousInput(name = "u"),
ContinuousInput(name = "u2"),

ContinuousState(name = "x1"),
ContinuousState(name = "x2"),
ContinuousState(name = "x3"),
ContinuousState(name = "x4"),
ContinuousState(name = "x5"),

]

edges = Dict(
("u", "x1") => ObservationCoupling(),
("u2", "x2") => ObservationCoupling(),

("x1", "x2") => DriftCoupling(),
("x1", "x3") => VolatilityCoupling(),
("u2", "x3") => NoiseCoupling(),
("x2", "x4") => VolatilityCoupling(),
("x3", "x5") => DriftCoupling(),
)

hgf = init_hgf(
nodes = nodes,
edges = edges,
verbose = false,
)

update_hgf!(hgf, [1,1])
end
end
40 changes: 40 additions & 0 deletions test/testsuite/test_nonlinear_transforms.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
using Test
using HierarchicalGaussianFiltering

@testset "Testing nonlinear transforms" begin

@testset "Sinoid transform" begin
nodes = [
ContinuousInput(name = "u"),
ContinuousState(name = "x1"),
ContinuousState(name = "x2"),
]

base = function(x, parameters::Dict)
sin(x)
end
first_derivative = function(x, parameters::Dict)
cos(x)
end
second_derivative = function(x, parameters::Dict)
-sin(x)
end
transform_parameters = Dict()

edges = Dict(
("u", "x1") => ObservationCoupling(),
("x1", "x2") => DriftCoupling(2,NonlinearTransform(base, first_derivative, second_derivative, transform_parameters)),
)

hgf = init_hgf(
nodes = nodes,
edges = edges,
verbose = false,
)

update_hgf!(hgf, 1)
end
end



0 comments on commit 3a15a32

Please sign in to comment.