diff --git a/.github/workflows/CI_full.yml b/.github/workflows/CI_full.yml index 6f6d6c8..7942932 100644 --- a/.github/workflows/CI_full.yml +++ b/.github/workflows/CI_full.yml @@ -37,7 +37,7 @@ jobs: with: version: ${{ matrix.version }} arch: ${{ matrix.arch }} - - uses: julia-actions/cache@v1 + - uses: julia-actions/cache@v2 - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-runtest@v1 - uses: julia-actions/julia-processcoverage@v1 diff --git a/.github/workflows/CI_small.yml b/.github/workflows/CI_small.yml index 41a5a20..b8a0c04 100644 --- a/.github/workflows/CI_small.yml +++ b/.github/workflows/CI_small.yml @@ -33,7 +33,7 @@ jobs: with: version: ${{ matrix.version }} arch: ${{ matrix.arch }} - - uses: julia-actions/cache@v1 + - uses: julia-actions/cache@v2 - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-runtest@v1 - uses: julia-actions/julia-processcoverage@v1 diff --git a/Project.toml b/Project.toml index 1672728..fe65b3b 100644 --- a/Project.toml +++ b/Project.toml @@ -4,7 +4,7 @@ authors = [ "Peter Thestrup Waade ptw@cas.au.dk", "Anna Hedvig Møller hedvig.2808@gmail.com", "Jacopo Comoglio jacopo.comoglio@gmail.com", "Christoph Mathys chmathys@cas.au.dk"] -version = "0.5.1" +version = "0.5.2" [deps] ActionModels = "320cf53b-cc3b-4b34-9a10-0ecb113566a3" @@ -15,4 +15,4 @@ RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" ActionModels = "0.5" Distributions = "0.25" RecipesBase = "1" -julia = "1.9" +julia = "1.10" diff --git a/README.md b/README.md index 255aefc..0acb0d5 100644 --- a/README.md +++ b/README.md @@ -3,6 +3,7 @@ [![Build Status](https://github.com/ilabcode/HierarchicalGaussianFiltering.jl/actions/workflows/CI_full.yml/badge.svg?branch=main)](https://github.com/ilabcode/HierarchicalGaussianFiltering.jl/actions/workflows/CI_full.yml?query=branch%3Amain) [![Coverage](https://codecov.io/gh/ilabcode/HierarchicalGaussianFiltering.jl/branch/main/graph/badge.svg?token=NVFiiPydFA)](https://codecov.io/gh/ilabcode/HierarchicalGaussianFiltering.jl) [![License: GNU](https://img.shields.io/badge/License-GNU-yellow)]() +[![Aqua QA](https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg)](https://github.com/JuliaTesting/Aqua.jl) # Welcome to The Hierarchical Gaussian Filtering Package! diff --git a/docs/Project.toml b/docs/Project.toml index b3a6a54..2809280 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,12 +1,15 @@ [deps] ActionModels = "320cf53b-cc3b-4b34-9a10-0ecb113566a3" +Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Glob = "c27321d9-0574-5035-807b-f59d2c89b15c" HierarchicalGaussianFiltering = "63d42c3e-681c-42be-892f-a47f35336a79" +JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899" Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" diff --git a/src/ActionModels_variations/utils/get_history.jl b/src/ActionModels_variations/utils/get_history.jl index 75bf0ec..8b8ce6b 100644 --- a/src/ActionModels_variations/utils/get_history.jl +++ b/src/ActionModels_variations/utils/get_history.jl @@ -7,7 +7,7 @@ Gets the history of a state from a specific node in an HGF. A vector of states c Gets the history of all states for a specific node in an HGF. If only a node object is passed, it will return the history of all states in that node. If only an HGF object is passed, it will return the history of all states in all nodes in the HGF. """ -function ActionModels.get_history() end +# function ActionModels.get_history() end ### For getting histories of specific states ### function ActionModels.get_history(hgf::HGF, target_state::Tuple{String,String}) diff --git a/src/ActionModels_variations/utils/get_parameters.jl b/src/ActionModels_variations/utils/get_parameters.jl index c1a890c..d5ed385 100644 --- a/src/ActionModels_variations/utils/get_parameters.jl +++ b/src/ActionModels_variations/utils/get_parameters.jl @@ -7,7 +7,7 @@ Gets a single parameter value from a specific node in an HGF. A vector of parame Gets all parameter values for a specific node in an HGF. If only a node object is passed, returns all parameters in that node. If only an HGF object is passed, returns all parameters of all nodes in the HGF. """ -function ActionModels.get_parameters() end +# function ActionModels.get_parameters() end ### For getting a specific parameter from a specific node ### #For parameters other than coupling strengths diff --git a/src/ActionModels_variations/utils/get_states.jl b/src/ActionModels_variations/utils/get_states.jl index 02db849..104a667 100644 --- a/src/ActionModels_variations/utils/get_states.jl +++ b/src/ActionModels_variations/utils/get_states.jl @@ -7,7 +7,7 @@ Gets a single state value from a specific node in an HGF. A vector of states can Gets all parameter values for a specific node in an HGF. If only a node object is passed, returns all states in that node. If only an HGF object is passed, returns all states of all nodes in the HGF. """ -function ActionModels.get_states() end +# function ActionModels.get_states() end ### For getting a specific state from a specific node ### @@ -36,7 +36,9 @@ function ActionModels.get_states(node::AbstractNode, state_name::String) #If the state does not exist in the node if !(Symbol(state_name) in fieldnames(typeof(node.states))) #throw an error - throw(ArgumentError("The node $node_name does not have the state $state_name")) + throw( + ArgumentError("The node $(node.node_name) does not have the state $state_name"), + ) end #If a prediction state has been specified diff --git a/src/ActionModels_variations/utils/give_inputs.jl b/src/ActionModels_variations/utils/give_inputs.jl index 4720e54..064e464 100644 --- a/src/ActionModels_variations/utils/give_inputs.jl +++ b/src/ActionModels_variations/utils/give_inputs.jl @@ -3,7 +3,7 @@ give_inputs!(hgf::HGF, inputs) Give inputs to an agent. Input can be a single value, a vector of values, or an array of values. """ -function ActionModels.give_inputs!() end +# function ActionModels.give_inputs!() end ### Giving a single input ### diff --git a/src/ActionModels_variations/utils/set_parameters.jl b/src/ActionModels_variations/utils/set_parameters.jl index e83f2c6..baa6402 100644 --- a/src/ActionModels_variations/utils/set_parameters.jl +++ b/src/ActionModels_variations/utils/set_parameters.jl @@ -7,7 +7,7 @@ Setting a single parameter value for an HGF. Set mutliple parameters values for an HGF. Takes a dictionary of parameter names and values. """ -function ActionModels.set_parameters!() end +# function ActionModels.set_parameters!() end ### For setting a single parameter ### diff --git a/src/HierarchicalGaussianFiltering.jl b/src/HierarchicalGaussianFiltering.jl index fce8b48..770475c 100644 --- a/src/HierarchicalGaussianFiltering.jl +++ b/src/HierarchicalGaussianFiltering.jl @@ -4,7 +4,7 @@ module HierarchicalGaussianFiltering using ActionModels, Distributions, RecipesBase #Export functions -export init_node, init_hgf, premade_hgf, check_hgf, check_node, update_hgf! +export init_node, init_hgf, premade_hgf, check_hgf, update_hgf! export get_prediction, get_surprise export premade_agent, init_agent, plot_predictive_simulation, plot_trajectory, plot_trajectory! diff --git a/src/create_hgf/hgf_structs.jl b/src/create_hgf/hgf_structs.jl index a60f6ee..c6ff7b3 100644 --- a/src/create_hgf/hgf_structs.jl +++ b/src/create_hgf/hgf_structs.jl @@ -79,7 +79,7 @@ end """ """ Base.@kwdef mutable struct OrderedNodes - all_nodes::Vector{AbstractNode} = [] + all_nodes::Vector{AbstractNode} = AbstractNode[] input_nodes::Vector{AbstractInputNode} = [] all_state_nodes::Vector{AbstractStateNode} = [] early_update_state_nodes::Vector{AbstractStateNode} = [] diff --git a/src/update_hgf/node_updates/continuous_state_node.jl b/src/update_hgf/node_updates/continuous_state_node.jl index 5876fa3..ab32105 100644 --- a/src/update_hgf/node_updates/continuous_state_node.jl +++ b/src/update_hgf/node_updates/continuous_state_node.jl @@ -45,8 +45,10 @@ function calculate_prediction_mean(node::ContinuousStateNode, stepsize::Real) #Transform the parent's value drift_increment = transform_parent_value( parent.states.posterior_mean, - coupling_transform, + coupling_transform; derivation_level = 0, + child = node, + parent = parent, ) #Add the drift increment @@ -225,17 +227,21 @@ function calculate_posterior_precision_increment( #Calculate the increment child.states.prediction_precision * ( coupling_strength^2 * transform_parent_value( - coupling_transform, node.states.posterior_mean, + coupling_transform; derivation_level = 1, + parent = node, + child = child, ) - coupling_strength * - node.states.value_prediction_error * transform_parent_value( - coupling_transform, node.states.posterior_mean, + coupling_transform; derivation_level = 2, - ) + parent = node, + child = child, + ) * + child.states.value_prediction_error ) end @@ -396,9 +402,11 @@ function calculate_posterior_mean_increment( ( child.parameters.coupling_strengths[node.name] * transform_parent_value( - child.parameters.coupling_transforms[node.name], node.states.posterior_mean, + child.parameters.coupling_transforms[node.name]; derivation_level = 1, + parent = node, + child = child, ) * child.states.prediction_precision ) / node.states.posterior_precision @@ -416,9 +424,11 @@ function calculate_posterior_mean_increment( ( child.parameters.coupling_strengths[node.name] * transform_parent_value( - child.parameters.coupling_transforms[node.name], node.states.posterior_mean, + child.parameters.coupling_transforms[node.name]; derivation_level = 1, + parent = node, + child = child, ) * child.states.prediction_precision ) / node.states.prediction_precision diff --git a/src/update_hgf/nonlinear_transforms.jl b/src/update_hgf/nonlinear_transforms.jl index 2217e17..8cc577f 100644 --- a/src/update_hgf/nonlinear_transforms.jl +++ b/src/update_hgf/nonlinear_transforms.jl @@ -1,14 +1,16 @@ #Transformation (of varioius derivations) for a linear transformation function transform_parent_value( - transform_type::LinearTransform, - parent_value::Real; + parent_value::Real, + transform_type::LinearTransform; derivation_level::Integer, + parent::AbstractNode, + child::AbstractNode, ) if derivation_level == 0 return parent_value - elseif derivation_level == 0 + elseif derivation_level == 1 return 1 - elseif derivation_level == 0 + elseif derivation_level == 2 return 0 else @error "derivation level is misspecified" @@ -17,26 +19,28 @@ end #Transformation (of varioius derivations) for a nonlinear transformation function transform_parent_value( - transform_type::NonlinearTransform, parent_value::Real, + transform_type::NonlinearTransform; derivation_level::Integer, + parent::AbstractNode, + child::AbstractNode, ) #Get the transformation function that fits the derivation level if derivation_level == 0 - transform_function = node.parameters.coupling_transforms[parent.name].base_function + transform_function = child.parameters.coupling_transforms[parent.name].base_function elseif derivation_level == 1 transform_function = - node.parameters.coupling_transforms[parent.name].first_derivation + child.parameters.coupling_transforms[parent.name].first_derivation elseif derivation_level == 2 transform_function = - node.parameters.coupling_transforms[parent.name].second_derivation + child.parameters.coupling_transforms[parent.name].second_derivation else @error "derivation level is misspecified" end #Get the transformation parameters - transform_parameters = node.parameters.coupling_transforms[parent.name].parameters + transform_parameters = child.parameters.coupling_transforms[parent.name].parameters #Transform the value return transform_function(parent_value, transform_parameters) diff --git a/test/Project.toml b/test/Project.toml index 38e7d75..60bf59e 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -4,6 +4,7 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +GR = "28b8d3ca-fb5f-59d9-8090-bfdbd6d07a71" Glob = "c27321d9-0574-5035-807b-f59d2c89b15c" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899" diff --git a/test/testsuite/Aqua.jl b/test/testsuite/Aqua.jl new file mode 100644 index 0000000..ed45564 --- /dev/null +++ b/test/testsuite/Aqua.jl @@ -0,0 +1,3 @@ +using HierarchicalGaussianFiltering +using Aqua +Aqua.test_all(HierarchicalGaussianFiltering, ambiguities = false) diff --git a/test/testsuite/test_custom_structures.jl b/test/testsuite/test_custom_structures.jl new file mode 100644 index 0000000..6d336ef --- /dev/null +++ b/test/testsuite/test_custom_structures.jl @@ -0,0 +1,32 @@ +using Test +using HierarchicalGaussianFiltering + + +@testset "test custom structures" begin + + @testset "Many continuous nodes" begin + nodes = [ + ContinuousInput(name = "u"), + ContinuousInput(name = "u2"), + ContinuousState(name = "x1"), + ContinuousState(name = "x2"), + ContinuousState(name = "x3"), + ContinuousState(name = "x4"), + ContinuousState(name = "x5"), + ] + + edges = Dict( + ("u", "x1") => ObservationCoupling(), + ("u2", "x2") => ObservationCoupling(), + ("x1", "x2") => DriftCoupling(), + ("x1", "x3") => VolatilityCoupling(), + ("u2", "x3") => NoiseCoupling(), + ("x2", "x4") => VolatilityCoupling(), + ("x3", "x5") => DriftCoupling(), + ) + + hgf = init_hgf(nodes = nodes, edges = edges, verbose = false) + + update_hgf!(hgf, [1, 1]) + end +end diff --git a/test/testsuite/test_nonlinear_transforms.jl b/test/testsuite/test_nonlinear_transforms.jl new file mode 100644 index 0000000..fbab0ad --- /dev/null +++ b/test/testsuite/test_nonlinear_transforms.jl @@ -0,0 +1,41 @@ +using Test +using HierarchicalGaussianFiltering + +@testset "Testing nonlinear transforms" begin + + @testset "Sinoid transform" begin + nodes = [ + ContinuousInput(name = "u"), + ContinuousState(name = "x1"), + ContinuousState(name = "x2"), + ] + + base = function (x, parameters::Dict) + sin(x) + end + first_derivative = function (x, parameters::Dict) + cos(x) + end + second_derivative = function (x, parameters::Dict) + -sin(x) + end + transform_parameters = Dict() + + edges = Dict( + ("u", "x1") => ObservationCoupling(), + ("x1", "x2") => DriftCoupling( + 2, + NonlinearTransform( + base, + first_derivative, + second_derivative, + transform_parameters, + ), + ), + ) + + hgf = init_hgf(nodes = nodes, edges = edges, verbose = false) + + update_hgf!(hgf, 1) + end +end