From 9956a27a054bbd3d125badafd2e8af4236e2e93f Mon Sep 17 00:00:00 2001 From: Peter Thestrup Waade Date: Sat, 4 May 2024 10:21:00 +0200 Subject: [PATCH] first JET edits --- docs/Project.toml | 1 + .../utils/get_states.jl | 4 +- src/create_hgf/hgf_structs.jl | 2 +- .../node_updates/continuous_state_node.jl | 4 +- src/update_hgf/nonlinear_transforms.jl | 6 +- test/testsuite/Aqua.jl | 3 +- test/testsuite/test_custom_structures.jl | 55 ++++++++----------- test/testsuite/test_nonlinear_transforms.jl | 27 ++++----- 8 files changed, 49 insertions(+), 53 deletions(-) diff --git a/docs/Project.toml b/docs/Project.toml index b8c4ddf..2809280 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -9,6 +9,7 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Glob = "c27321d9-0574-5035-807b-f59d2c89b15c" HierarchicalGaussianFiltering = "63d42c3e-681c-42be-892f-a47f35336a79" +JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899" Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" diff --git a/src/ActionModels_variations/utils/get_states.jl b/src/ActionModels_variations/utils/get_states.jl index f81f417..104a667 100644 --- a/src/ActionModels_variations/utils/get_states.jl +++ b/src/ActionModels_variations/utils/get_states.jl @@ -36,7 +36,9 @@ function ActionModels.get_states(node::AbstractNode, state_name::String) #If the state does not exist in the node if !(Symbol(state_name) in fieldnames(typeof(node.states))) #throw an error - throw(ArgumentError("The node $node_name does not have the state $state_name")) + throw( + ArgumentError("The node $(node.node_name) does not have the state $state_name"), + ) end #If a prediction state has been specified diff --git a/src/create_hgf/hgf_structs.jl b/src/create_hgf/hgf_structs.jl index a60f6ee..c6ff7b3 100644 --- a/src/create_hgf/hgf_structs.jl +++ b/src/create_hgf/hgf_structs.jl @@ -79,7 +79,7 @@ end """ """ Base.@kwdef mutable struct OrderedNodes - all_nodes::Vector{AbstractNode} = [] + all_nodes::Vector{AbstractNode} = AbstractNode[] input_nodes::Vector{AbstractInputNode} = [] all_state_nodes::Vector{AbstractStateNode} = [] early_update_state_nodes::Vector{AbstractStateNode} = [] diff --git a/src/update_hgf/node_updates/continuous_state_node.jl b/src/update_hgf/node_updates/continuous_state_node.jl index 2af667b..ab32105 100644 --- a/src/update_hgf/node_updates/continuous_state_node.jl +++ b/src/update_hgf/node_updates/continuous_state_node.jl @@ -48,7 +48,7 @@ function calculate_prediction_mean(node::ContinuousStateNode, stepsize::Real) coupling_transform; derivation_level = 0, child = node, - parent = parent + parent = parent, ) #Add the drift increment @@ -240,7 +240,7 @@ function calculate_posterior_precision_increment( derivation_level = 2, parent = node, child = child, - ) * + ) * child.states.value_prediction_error ) diff --git a/src/update_hgf/nonlinear_transforms.jl b/src/update_hgf/nonlinear_transforms.jl index 6edc188..8cc577f 100644 --- a/src/update_hgf/nonlinear_transforms.jl +++ b/src/update_hgf/nonlinear_transforms.jl @@ -1,6 +1,6 @@ #Transformation (of varioius derivations) for a linear transformation function transform_parent_value( - parent_value::Real, + parent_value::Real, transform_type::LinearTransform; derivation_level::Integer, parent::AbstractNode, @@ -31,10 +31,10 @@ function transform_parent_value( transform_function = child.parameters.coupling_transforms[parent.name].base_function elseif derivation_level == 1 transform_function = - child.parameters.coupling_transforms[parent.name].first_derivation + child.parameters.coupling_transforms[parent.name].first_derivation elseif derivation_level == 2 transform_function = - child.parameters.coupling_transforms[parent.name].second_derivation + child.parameters.coupling_transforms[parent.name].second_derivation else @error "derivation level is misspecified" end diff --git a/test/testsuite/Aqua.jl b/test/testsuite/Aqua.jl index 25e76e1..ed45564 100644 --- a/test/testsuite/Aqua.jl +++ b/test/testsuite/Aqua.jl @@ -1,4 +1,3 @@ using HierarchicalGaussianFiltering using Aqua -Aqua.test_all(HierarchicalGaussianFiltering, -ambiguities = false) +Aqua.test_all(HierarchicalGaussianFiltering, ambiguities = false) diff --git a/test/testsuite/test_custom_structures.jl b/test/testsuite/test_custom_structures.jl index 4f9ed85..6d336ef 100644 --- a/test/testsuite/test_custom_structures.jl +++ b/test/testsuite/test_custom_structures.jl @@ -5,35 +5,28 @@ using HierarchicalGaussianFiltering @testset "test custom structures" begin @testset "Many continuous nodes" begin - nodes = [ - ContinuousInput(name = "u"), - ContinuousInput(name = "u2"), - - ContinuousState(name = "x1"), - ContinuousState(name = "x2"), - ContinuousState(name = "x3"), - ContinuousState(name = "x4"), - ContinuousState(name = "x5"), - - ] - - edges = Dict( - ("u", "x1") => ObservationCoupling(), - ("u2", "x2") => ObservationCoupling(), - - ("x1", "x2") => DriftCoupling(), - ("x1", "x3") => VolatilityCoupling(), - ("u2", "x3") => NoiseCoupling(), - ("x2", "x4") => VolatilityCoupling(), - ("x3", "x5") => DriftCoupling(), - ) - - hgf = init_hgf( - nodes = nodes, - edges = edges, - verbose = false, - ) - - update_hgf!(hgf, [1,1]) + nodes = [ + ContinuousInput(name = "u"), + ContinuousInput(name = "u2"), + ContinuousState(name = "x1"), + ContinuousState(name = "x2"), + ContinuousState(name = "x3"), + ContinuousState(name = "x4"), + ContinuousState(name = "x5"), + ] + + edges = Dict( + ("u", "x1") => ObservationCoupling(), + ("u2", "x2") => ObservationCoupling(), + ("x1", "x2") => DriftCoupling(), + ("x1", "x3") => VolatilityCoupling(), + ("u2", "x3") => NoiseCoupling(), + ("x2", "x4") => VolatilityCoupling(), + ("x3", "x5") => DriftCoupling(), + ) + + hgf = init_hgf(nodes = nodes, edges = edges, verbose = false) + + update_hgf!(hgf, [1, 1]) end -end \ No newline at end of file +end diff --git a/test/testsuite/test_nonlinear_transforms.jl b/test/testsuite/test_nonlinear_transforms.jl index 5ac204d..fbab0ad 100644 --- a/test/testsuite/test_nonlinear_transforms.jl +++ b/test/testsuite/test_nonlinear_transforms.jl @@ -2,7 +2,7 @@ using Test using HierarchicalGaussianFiltering @testset "Testing nonlinear transforms" begin - + @testset "Sinoid transform" begin nodes = [ ContinuousInput(name = "u"), @@ -10,31 +10,32 @@ using HierarchicalGaussianFiltering ContinuousState(name = "x2"), ] - base = function(x, parameters::Dict) + base = function (x, parameters::Dict) sin(x) end - first_derivative = function(x, parameters::Dict) + first_derivative = function (x, parameters::Dict) cos(x) end - second_derivative = function(x, parameters::Dict) + second_derivative = function (x, parameters::Dict) -sin(x) end transform_parameters = Dict() edges = Dict( ("u", "x1") => ObservationCoupling(), - ("x1", "x2") => DriftCoupling(2,NonlinearTransform(base, first_derivative, second_derivative, transform_parameters)), + ("x1", "x2") => DriftCoupling( + 2, + NonlinearTransform( + base, + first_derivative, + second_derivative, + transform_parameters, + ), + ), ) - hgf = init_hgf( - nodes = nodes, - edges = edges, - verbose = false, - ) + hgf = init_hgf(nodes = nodes, edges = edges, verbose = false) update_hgf!(hgf, 1) end end - - -