Skip to content

Commit

Permalink
Merge pull request #150 from ilabcode/dev
Browse files Browse the repository at this point in the history
Version 0.5.2
  • Loading branch information
PTWaade authored May 9, 2024
2 parents e46e07b + 9fc5ced commit d931cd1
Show file tree
Hide file tree
Showing 18 changed files with 125 additions and 28 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/CI_full.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ jobs:
with:
version: ${{ matrix.version }}
arch: ${{ matrix.arch }}
- uses: julia-actions/cache@v1
- uses: julia-actions/cache@v2
- uses: julia-actions/julia-buildpkg@v1
- uses: julia-actions/julia-runtest@v1
- uses: julia-actions/julia-processcoverage@v1
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/CI_small.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ jobs:
with:
version: ${{ matrix.version }}
arch: ${{ matrix.arch }}
- uses: julia-actions/cache@v1
- uses: julia-actions/cache@v2
- uses: julia-actions/julia-buildpkg@v1
- uses: julia-actions/julia-runtest@v1
- uses: julia-actions/julia-processcoverage@v1
Expand Down
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ authors = [ "Peter Thestrup Waade [email protected]",
"Anna Hedvig Møller [email protected]",
"Jacopo Comoglio [email protected]",
"Christoph Mathys [email protected]"]
version = "0.5.1"
version = "0.5.2"

[deps]
ActionModels = "320cf53b-cc3b-4b34-9a10-0ecb113566a3"
Expand All @@ -15,4 +15,4 @@ RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
ActionModels = "0.5"
Distributions = "0.25"
RecipesBase = "1"
julia = "1.9"
julia = "1.10"
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
[![Build Status](https://github.com/ilabcode/HierarchicalGaussianFiltering.jl/actions/workflows/CI_full.yml/badge.svg?branch=main)](https://github.com/ilabcode/HierarchicalGaussianFiltering.jl/actions/workflows/CI_full.yml?query=branch%3Amain)
[![Coverage](https://codecov.io/gh/ilabcode/HierarchicalGaussianFiltering.jl/branch/main/graph/badge.svg?token=NVFiiPydFA)](https://codecov.io/gh/ilabcode/HierarchicalGaussianFiltering.jl)
[![License: GNU](https://img.shields.io/badge/License-GNU-yellow)](<https://www.gnu.org/licenses/>)
[![Aqua QA](https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg)](https://github.com/JuliaTesting/Aqua.jl)


# Welcome to The Hierarchical Gaussian Filtering Package!
Expand Down
3 changes: 3 additions & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
[deps]
ActionModels = "320cf53b-cc3b-4b34-9a10-0ecb113566a3"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Glob = "c27321d9-0574-5035-807b-f59d2c89b15c"
HierarchicalGaussianFiltering = "63d42c3e-681c-42be-892f-a47f35336a79"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Expand Down
2 changes: 1 addition & 1 deletion src/ActionModels_variations/utils/get_history.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ Gets the history of a state from a specific node in an HGF. A vector of states c
Gets the history of all states for a specific node in an HGF. If only a node object is passed, it will return the history of all states in that node. If only an HGF object is passed, it will return the history of all states in all nodes in the HGF.
"""
function ActionModels.get_history() end
# function ActionModels.get_history() end

### For getting histories of specific states ###
function ActionModels.get_history(hgf::HGF, target_state::Tuple{String,String})
Expand Down
2 changes: 1 addition & 1 deletion src/ActionModels_variations/utils/get_parameters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ Gets a single parameter value from a specific node in an HGF. A vector of parame
Gets all parameter values for a specific node in an HGF. If only a node object is passed, returns all parameters in that node. If only an HGF object is passed, returns all parameters of all nodes in the HGF.
"""
function ActionModels.get_parameters() end
# function ActionModels.get_parameters() end

### For getting a specific parameter from a specific node ###
#For parameters other than coupling strengths
Expand Down
6 changes: 4 additions & 2 deletions src/ActionModels_variations/utils/get_states.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ Gets a single state value from a specific node in an HGF. A vector of states can
Gets all parameter values for a specific node in an HGF. If only a node object is passed, returns all states in that node. If only an HGF object is passed, returns all states of all nodes in the HGF.
"""
function ActionModels.get_states() end
# function ActionModels.get_states() end


### For getting a specific state from a specific node ###
Expand Down Expand Up @@ -36,7 +36,9 @@ function ActionModels.get_states(node::AbstractNode, state_name::String)
#If the state does not exist in the node
if !(Symbol(state_name) in fieldnames(typeof(node.states)))
#throw an error
throw(ArgumentError("The node $node_name does not have the state $state_name"))
throw(
ArgumentError("The node $(node.node_name) does not have the state $state_name"),
)
end

#If a prediction state has been specified
Expand Down
2 changes: 1 addition & 1 deletion src/ActionModels_variations/utils/give_inputs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ give_inputs!(hgf::HGF, inputs)
Give inputs to an agent. Input can be a single value, a vector of values, or an array of values.
"""
function ActionModels.give_inputs!() end
# function ActionModels.give_inputs!() end


### Giving a single input ###
Expand Down
2 changes: 1 addition & 1 deletion src/ActionModels_variations/utils/set_parameters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ Setting a single parameter value for an HGF.
Set mutliple parameters values for an HGF. Takes a dictionary of parameter names and values.
"""
function ActionModels.set_parameters!() end
# function ActionModels.set_parameters!() end

### For setting a single parameter ###

Expand Down
2 changes: 1 addition & 1 deletion src/HierarchicalGaussianFiltering.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ module HierarchicalGaussianFiltering
using ActionModels, Distributions, RecipesBase

#Export functions
export init_node, init_hgf, premade_hgf, check_hgf, check_node, update_hgf!
export init_node, init_hgf, premade_hgf, check_hgf, update_hgf!
export get_prediction, get_surprise
export premade_agent,
init_agent, plot_predictive_simulation, plot_trajectory, plot_trajectory!
Expand Down
2 changes: 1 addition & 1 deletion src/create_hgf/hgf_structs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ end
"""
"""
Base.@kwdef mutable struct OrderedNodes
all_nodes::Vector{AbstractNode} = []
all_nodes::Vector{AbstractNode} = AbstractNode[]
input_nodes::Vector{AbstractInputNode} = []
all_state_nodes::Vector{AbstractStateNode} = []
early_update_state_nodes::Vector{AbstractStateNode} = []
Expand Down
24 changes: 17 additions & 7 deletions src/update_hgf/node_updates/continuous_state_node.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,10 @@ function calculate_prediction_mean(node::ContinuousStateNode, stepsize::Real)
#Transform the parent's value
drift_increment = transform_parent_value(
parent.states.posterior_mean,
coupling_transform,
coupling_transform;
derivation_level = 0,
child = node,
parent = parent,
)

#Add the drift increment
Expand Down Expand Up @@ -225,17 +227,21 @@ function calculate_posterior_precision_increment(
#Calculate the increment
child.states.prediction_precision * (
coupling_strength^2 * transform_parent_value(
coupling_transform,
node.states.posterior_mean,
coupling_transform;
derivation_level = 1,
parent = node,
child = child,
) -
coupling_strength *
node.states.value_prediction_error *
transform_parent_value(
coupling_transform,
node.states.posterior_mean,
coupling_transform;
derivation_level = 2,
)
parent = node,
child = child,
) *
child.states.value_prediction_error
)

end
Expand Down Expand Up @@ -396,9 +402,11 @@ function calculate_posterior_mean_increment(
(
child.parameters.coupling_strengths[node.name] *
transform_parent_value(
child.parameters.coupling_transforms[node.name],
node.states.posterior_mean,
child.parameters.coupling_transforms[node.name];
derivation_level = 1,
parent = node,
child = child,
) *
child.states.prediction_precision
) / node.states.posterior_precision
Expand All @@ -416,9 +424,11 @@ function calculate_posterior_mean_increment(
(
child.parameters.coupling_strengths[node.name] *
transform_parent_value(
child.parameters.coupling_transforms[node.name],
node.states.posterior_mean,
child.parameters.coupling_transforms[node.name];
derivation_level = 1,
parent = node,
child = child,
) *
child.states.prediction_precision
) / node.states.prediction_precision
Expand Down
22 changes: 13 additions & 9 deletions src/update_hgf/nonlinear_transforms.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
#Transformation (of varioius derivations) for a linear transformation
function transform_parent_value(
transform_type::LinearTransform,
parent_value::Real;
parent_value::Real,
transform_type::LinearTransform;
derivation_level::Integer,
parent::AbstractNode,
child::AbstractNode,
)
if derivation_level == 0
return parent_value
elseif derivation_level == 0
elseif derivation_level == 1
return 1
elseif derivation_level == 0
elseif derivation_level == 2
return 0
else
@error "derivation level is misspecified"
Expand All @@ -17,26 +19,28 @@ end

#Transformation (of varioius derivations) for a nonlinear transformation
function transform_parent_value(
transform_type::NonlinearTransform,
parent_value::Real,
transform_type::NonlinearTransform;
derivation_level::Integer,
parent::AbstractNode,
child::AbstractNode,
)

#Get the transformation function that fits the derivation level
if derivation_level == 0
transform_function = node.parameters.coupling_transforms[parent.name].base_function
transform_function = child.parameters.coupling_transforms[parent.name].base_function
elseif derivation_level == 1
transform_function =
node.parameters.coupling_transforms[parent.name].first_derivation
child.parameters.coupling_transforms[parent.name].first_derivation
elseif derivation_level == 2
transform_function =
node.parameters.coupling_transforms[parent.name].second_derivation
child.parameters.coupling_transforms[parent.name].second_derivation
else
@error "derivation level is misspecified"
end

#Get the transformation parameters
transform_parameters = node.parameters.coupling_transforms[parent.name].parameters
transform_parameters = child.parameters.coupling_transforms[parent.name].parameters

#Transform the value
return transform_function(parent_value, transform_parameters)
Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
GR = "28b8d3ca-fb5f-59d9-8090-bfdbd6d07a71"
Glob = "c27321d9-0574-5035-807b-f59d2c89b15c"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
Expand Down
3 changes: 3 additions & 0 deletions test/testsuite/Aqua.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
using HierarchicalGaussianFiltering
using Aqua
Aqua.test_all(HierarchicalGaussianFiltering, ambiguities = false)
32 changes: 32 additions & 0 deletions test/testsuite/test_custom_structures.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
using Test
using HierarchicalGaussianFiltering


@testset "test custom structures" begin

@testset "Many continuous nodes" begin
nodes = [
ContinuousInput(name = "u"),
ContinuousInput(name = "u2"),
ContinuousState(name = "x1"),
ContinuousState(name = "x2"),
ContinuousState(name = "x3"),
ContinuousState(name = "x4"),
ContinuousState(name = "x5"),
]

edges = Dict(
("u", "x1") => ObservationCoupling(),
("u2", "x2") => ObservationCoupling(),
("x1", "x2") => DriftCoupling(),
("x1", "x3") => VolatilityCoupling(),
("u2", "x3") => NoiseCoupling(),
("x2", "x4") => VolatilityCoupling(),
("x3", "x5") => DriftCoupling(),
)

hgf = init_hgf(nodes = nodes, edges = edges, verbose = false)

update_hgf!(hgf, [1, 1])
end
end
41 changes: 41 additions & 0 deletions test/testsuite/test_nonlinear_transforms.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
using Test
using HierarchicalGaussianFiltering

@testset "Testing nonlinear transforms" begin

@testset "Sinoid transform" begin
nodes = [
ContinuousInput(name = "u"),
ContinuousState(name = "x1"),
ContinuousState(name = "x2"),
]

base = function (x, parameters::Dict)
sin(x)
end
first_derivative = function (x, parameters::Dict)
cos(x)
end
second_derivative = function (x, parameters::Dict)
-sin(x)
end
transform_parameters = Dict()

edges = Dict(
("u", "x1") => ObservationCoupling(),
("x1", "x2") => DriftCoupling(
2,
NonlinearTransform(
base,
first_derivative,
second_derivative,
transform_parameters,
),
),
)

hgf = init_hgf(nodes = nodes, edges = edges, verbose = false)

update_hgf!(hgf, 1)
end
end

4 comments on commit d931cd1

@PTWaade
Copy link
Collaborator Author

@PTWaade PTWaade commented on d931cd1 May 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register
minor bugfixes

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/106483

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.5.2 -m "<description of version>" d931cd1438787d497b005f748065a2ee08316810
git push origin v0.5.2

@PTWaade
Copy link
Collaborator Author

@PTWaade PTWaade commented on d931cd1 May 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register

Release notes:

Minor bugfixes
Added Aqua tests

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request updated: JuliaRegistries/General/106483

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.5.2 -m "<description of version>" d931cd1438787d497b005f748065a2ee08316810
git push origin v0.5.2

Please sign in to comment.