diff --git a/docs/Project.toml b/docs/Project.toml index b3a6a54..0e3c304 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -5,6 +5,7 @@ DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Glob = "c27321d9-0574-5035-807b-f59d2c89b15c" HierarchicalGaussianFiltering = "63d42c3e-681c-42be-892f-a47f35336a79" JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899" diff --git a/src/update_hgf/node_updates/continuous_state_node.jl b/src/update_hgf/node_updates/continuous_state_node.jl index 5876fa3..2af667b 100644 --- a/src/update_hgf/node_updates/continuous_state_node.jl +++ b/src/update_hgf/node_updates/continuous_state_node.jl @@ -45,8 +45,10 @@ function calculate_prediction_mean(node::ContinuousStateNode, stepsize::Real) #Transform the parent's value drift_increment = transform_parent_value( parent.states.posterior_mean, - coupling_transform, + coupling_transform; derivation_level = 0, + child = node, + parent = parent ) #Add the drift increment @@ -225,17 +227,21 @@ function calculate_posterior_precision_increment( #Calculate the increment child.states.prediction_precision * ( coupling_strength^2 * transform_parent_value( - coupling_transform, node.states.posterior_mean, + coupling_transform; derivation_level = 1, + parent = node, + child = child, ) - coupling_strength * - node.states.value_prediction_error * transform_parent_value( - coupling_transform, node.states.posterior_mean, + coupling_transform; derivation_level = 2, - ) + parent = node, + child = child, + ) * + child.states.value_prediction_error ) end @@ -396,9 +402,11 @@ function calculate_posterior_mean_increment( ( child.parameters.coupling_strengths[node.name] * transform_parent_value( - child.parameters.coupling_transforms[node.name], node.states.posterior_mean, + child.parameters.coupling_transforms[node.name]; derivation_level = 1, + parent = node, + child = child, ) * child.states.prediction_precision ) / node.states.posterior_precision @@ -416,9 +424,11 @@ function calculate_posterior_mean_increment( ( child.parameters.coupling_strengths[node.name] * transform_parent_value( - child.parameters.coupling_transforms[node.name], node.states.posterior_mean, + child.parameters.coupling_transforms[node.name]; derivation_level = 1, + parent = node, + child = child, ) * child.states.prediction_precision ) / node.states.prediction_precision diff --git a/src/update_hgf/nonlinear_transforms.jl b/src/update_hgf/nonlinear_transforms.jl index 2217e17..6edc188 100644 --- a/src/update_hgf/nonlinear_transforms.jl +++ b/src/update_hgf/nonlinear_transforms.jl @@ -1,14 +1,16 @@ #Transformation (of varioius derivations) for a linear transformation function transform_parent_value( - transform_type::LinearTransform, - parent_value::Real; + parent_value::Real, + transform_type::LinearTransform; derivation_level::Integer, + parent::AbstractNode, + child::AbstractNode, ) if derivation_level == 0 return parent_value - elseif derivation_level == 0 + elseif derivation_level == 1 return 1 - elseif derivation_level == 0 + elseif derivation_level == 2 return 0 else @error "derivation level is misspecified" @@ -17,26 +19,28 @@ end #Transformation (of varioius derivations) for a nonlinear transformation function transform_parent_value( - transform_type::NonlinearTransform, parent_value::Real, + transform_type::NonlinearTransform; derivation_level::Integer, + parent::AbstractNode, + child::AbstractNode, ) #Get the transformation function that fits the derivation level if derivation_level == 0 - transform_function = node.parameters.coupling_transforms[parent.name].base_function + transform_function = child.parameters.coupling_transforms[parent.name].base_function elseif derivation_level == 1 transform_function = - node.parameters.coupling_transforms[parent.name].first_derivation + child.parameters.coupling_transforms[parent.name].first_derivation elseif derivation_level == 2 transform_function = - node.parameters.coupling_transforms[parent.name].second_derivation + child.parameters.coupling_transforms[parent.name].second_derivation else @error "derivation level is misspecified" end #Get the transformation parameters - transform_parameters = node.parameters.coupling_transforms[parent.name].parameters + transform_parameters = child.parameters.coupling_transforms[parent.name].parameters #Transform the value return transform_function(parent_value, transform_parameters) diff --git a/test/testsuite/test_custom_structures.jl b/test/testsuite/test_custom_structures.jl new file mode 100644 index 0000000..4f9ed85 --- /dev/null +++ b/test/testsuite/test_custom_structures.jl @@ -0,0 +1,39 @@ +using Test +using HierarchicalGaussianFiltering + + +@testset "test custom structures" begin + + @testset "Many continuous nodes" begin + nodes = [ + ContinuousInput(name = "u"), + ContinuousInput(name = "u2"), + + ContinuousState(name = "x1"), + ContinuousState(name = "x2"), + ContinuousState(name = "x3"), + ContinuousState(name = "x4"), + ContinuousState(name = "x5"), + + ] + + edges = Dict( + ("u", "x1") => ObservationCoupling(), + ("u2", "x2") => ObservationCoupling(), + + ("x1", "x2") => DriftCoupling(), + ("x1", "x3") => VolatilityCoupling(), + ("u2", "x3") => NoiseCoupling(), + ("x2", "x4") => VolatilityCoupling(), + ("x3", "x5") => DriftCoupling(), + ) + + hgf = init_hgf( + nodes = nodes, + edges = edges, + verbose = false, + ) + + update_hgf!(hgf, [1,1]) + end +end \ No newline at end of file diff --git a/test/testsuite/test_nonlinear_transforms.jl b/test/testsuite/test_nonlinear_transforms.jl new file mode 100644 index 0000000..5ac204d --- /dev/null +++ b/test/testsuite/test_nonlinear_transforms.jl @@ -0,0 +1,40 @@ +using Test +using HierarchicalGaussianFiltering + +@testset "Testing nonlinear transforms" begin + + @testset "Sinoid transform" begin + nodes = [ + ContinuousInput(name = "u"), + ContinuousState(name = "x1"), + ContinuousState(name = "x2"), + ] + + base = function(x, parameters::Dict) + sin(x) + end + first_derivative = function(x, parameters::Dict) + cos(x) + end + second_derivative = function(x, parameters::Dict) + -sin(x) + end + transform_parameters = Dict() + + edges = Dict( + ("u", "x1") => ObservationCoupling(), + ("x1", "x2") => DriftCoupling(2,NonlinearTransform(base, first_derivative, second_derivative, transform_parameters)), + ) + + hgf = init_hgf( + nodes = nodes, + edges = edges, + verbose = false, + ) + + update_hgf!(hgf, 1) + end +end + + +