Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Small fixes #149

Merged
merged 3 commits into from
May 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ authors = [ "Peter Thestrup Waade [email protected]",
"Anna Hedvig Møller [email protected]",
"Jacopo Comoglio [email protected]",
"Christoph Mathys [email protected]"]
version = "0.5.1"
version = "0.5.2"

[deps]
ActionModels = "320cf53b-cc3b-4b34-9a10-0ecb113566a3"
Expand All @@ -15,4 +15,4 @@ RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
ActionModels = "0.5"
Distributions = "0.25"
RecipesBase = "1"
julia = "1.9"
julia = "1.10"
2 changes: 2 additions & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
[deps]
ActionModels = "320cf53b-cc3b-4b34-9a10-0ecb113566a3"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Expand All @@ -8,6 +9,7 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Glob = "c27321d9-0574-5035-807b-f59d2c89b15c"
HierarchicalGaussianFiltering = "63d42c3e-681c-42be-892f-a47f35336a79"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Expand Down
2 changes: 1 addition & 1 deletion src/ActionModels_variations/utils/get_history.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ Gets the history of a state from a specific node in an HGF. A vector of states c

Gets the history of all states for a specific node in an HGF. If only a node object is passed, it will return the history of all states in that node. If only an HGF object is passed, it will return the history of all states in all nodes in the HGF.
"""
function ActionModels.get_history() end
# function ActionModels.get_history() end

### For getting histories of specific states ###
function ActionModels.get_history(hgf::HGF, target_state::Tuple{String,String})
Expand Down
2 changes: 1 addition & 1 deletion src/ActionModels_variations/utils/get_parameters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ Gets a single parameter value from a specific node in an HGF. A vector of parame

Gets all parameter values for a specific node in an HGF. If only a node object is passed, returns all parameters in that node. If only an HGF object is passed, returns all parameters of all nodes in the HGF.
"""
function ActionModels.get_parameters() end
# function ActionModels.get_parameters() end

### For getting a specific parameter from a specific node ###
#For parameters other than coupling strengths
Expand Down
6 changes: 4 additions & 2 deletions src/ActionModels_variations/utils/get_states.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

Gets all parameter values for a specific node in an HGF. If only a node object is passed, returns all states in that node. If only an HGF object is passed, returns all states of all nodes in the HGF.
"""
function ActionModels.get_states() end
# function ActionModels.get_states() end


### For getting a specific state from a specific node ###
Expand Down Expand Up @@ -36,7 +36,9 @@
#If the state does not exist in the node
if !(Symbol(state_name) in fieldnames(typeof(node.states)))
#throw an error
throw(ArgumentError("The node $node_name does not have the state $state_name"))
throw(

Check warning on line 39 in src/ActionModels_variations/utils/get_states.jl

View check run for this annotation

Codecov / codecov/patch

src/ActionModels_variations/utils/get_states.jl#L39

Added line #L39 was not covered by tests
ArgumentError("The node $(node.node_name) does not have the state $state_name"),
)
end

#If a prediction state has been specified
Expand Down
2 changes: 1 addition & 1 deletion src/ActionModels_variations/utils/give_inputs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ give_inputs!(hgf::HGF, inputs)

Give inputs to an agent. Input can be a single value, a vector of values, or an array of values.
"""
function ActionModels.give_inputs!() end
# function ActionModels.give_inputs!() end


### Giving a single input ###
Expand Down
2 changes: 1 addition & 1 deletion src/ActionModels_variations/utils/set_parameters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ Setting a single parameter value for an HGF.

Set mutliple parameters values for an HGF. Takes a dictionary of parameter names and values.
"""
function ActionModels.set_parameters!() end
# function ActionModels.set_parameters!() end

### For setting a single parameter ###

Expand Down
2 changes: 1 addition & 1 deletion src/HierarchicalGaussianFiltering.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ module HierarchicalGaussianFiltering
using ActionModels, Distributions, RecipesBase

#Export functions
export init_node, init_hgf, premade_hgf, check_hgf, check_node, update_hgf!
export init_node, init_hgf, premade_hgf, check_hgf, update_hgf!
export get_prediction, get_surprise
export premade_agent,
init_agent, plot_predictive_simulation, plot_trajectory, plot_trajectory!
Expand Down
2 changes: 1 addition & 1 deletion src/create_hgf/hgf_structs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ end
"""
"""
Base.@kwdef mutable struct OrderedNodes
all_nodes::Vector{AbstractNode} = []
all_nodes::Vector{AbstractNode} = AbstractNode[]
input_nodes::Vector{AbstractInputNode} = []
all_state_nodes::Vector{AbstractStateNode} = []
early_update_state_nodes::Vector{AbstractStateNode} = []
Expand Down
4 changes: 2 additions & 2 deletions src/update_hgf/node_updates/continuous_state_node.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ function calculate_prediction_mean(node::ContinuousStateNode, stepsize::Real)
coupling_transform;
derivation_level = 0,
child = node,
parent = parent
parent = parent,
)

#Add the drift increment
Expand Down Expand Up @@ -240,7 +240,7 @@ function calculate_posterior_precision_increment(
derivation_level = 2,
parent = node,
child = child,
) *
) *
child.states.value_prediction_error
)

Expand Down
6 changes: 3 additions & 3 deletions src/update_hgf/nonlinear_transforms.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#Transformation (of varioius derivations) for a linear transformation
function transform_parent_value(
parent_value::Real,
parent_value::Real,
transform_type::LinearTransform;
derivation_level::Integer,
parent::AbstractNode,
Expand Down Expand Up @@ -31,10 +31,10 @@ function transform_parent_value(
transform_function = child.parameters.coupling_transforms[parent.name].base_function
elseif derivation_level == 1
transform_function =
child.parameters.coupling_transforms[parent.name].first_derivation
child.parameters.coupling_transforms[parent.name].first_derivation
elseif derivation_level == 2
transform_function =
child.parameters.coupling_transforms[parent.name].second_derivation
child.parameters.coupling_transforms[parent.name].second_derivation
else
@error "derivation level is misspecified"
end
Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
GR = "28b8d3ca-fb5f-59d9-8090-bfdbd6d07a71"
Glob = "c27321d9-0574-5035-807b-f59d2c89b15c"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
Expand Down
3 changes: 3 additions & 0 deletions test/testsuite/Aqua.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
using HierarchicalGaussianFiltering
using Aqua
Aqua.test_all(HierarchicalGaussianFiltering, ambiguities = false)
55 changes: 24 additions & 31 deletions test/testsuite/test_custom_structures.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,35 +5,28 @@ using HierarchicalGaussianFiltering
@testset "test custom structures" begin

@testset "Many continuous nodes" begin
nodes = [
ContinuousInput(name = "u"),
ContinuousInput(name = "u2"),

ContinuousState(name = "x1"),
ContinuousState(name = "x2"),
ContinuousState(name = "x3"),
ContinuousState(name = "x4"),
ContinuousState(name = "x5"),

]

edges = Dict(
("u", "x1") => ObservationCoupling(),
("u2", "x2") => ObservationCoupling(),

("x1", "x2") => DriftCoupling(),
("x1", "x3") => VolatilityCoupling(),
("u2", "x3") => NoiseCoupling(),
("x2", "x4") => VolatilityCoupling(),
("x3", "x5") => DriftCoupling(),
)

hgf = init_hgf(
nodes = nodes,
edges = edges,
verbose = false,
)

update_hgf!(hgf, [1,1])
nodes = [
ContinuousInput(name = "u"),
ContinuousInput(name = "u2"),
ContinuousState(name = "x1"),
ContinuousState(name = "x2"),
ContinuousState(name = "x3"),
ContinuousState(name = "x4"),
ContinuousState(name = "x5"),
]

edges = Dict(
("u", "x1") => ObservationCoupling(),
("u2", "x2") => ObservationCoupling(),
("x1", "x2") => DriftCoupling(),
("x1", "x3") => VolatilityCoupling(),
("u2", "x3") => NoiseCoupling(),
("x2", "x4") => VolatilityCoupling(),
("x3", "x5") => DriftCoupling(),
)

hgf = init_hgf(nodes = nodes, edges = edges, verbose = false)

update_hgf!(hgf, [1, 1])
end
end
end
27 changes: 14 additions & 13 deletions test/testsuite/test_nonlinear_transforms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,39 +2,40 @@ using Test
using HierarchicalGaussianFiltering

@testset "Testing nonlinear transforms" begin

@testset "Sinoid transform" begin
nodes = [
ContinuousInput(name = "u"),
ContinuousState(name = "x1"),
ContinuousState(name = "x2"),
]

base = function(x, parameters::Dict)
base = function (x, parameters::Dict)
sin(x)
end
first_derivative = function(x, parameters::Dict)
first_derivative = function (x, parameters::Dict)
cos(x)
end
second_derivative = function(x, parameters::Dict)
second_derivative = function (x, parameters::Dict)
-sin(x)
end
transform_parameters = Dict()

edges = Dict(
("u", "x1") => ObservationCoupling(),
("x1", "x2") => DriftCoupling(2,NonlinearTransform(base, first_derivative, second_derivative, transform_parameters)),
("x1", "x2") => DriftCoupling(
2,
NonlinearTransform(
base,
first_derivative,
second_derivative,
transform_parameters,
),
),
)

hgf = init_hgf(
nodes = nodes,
edges = edges,
verbose = false,
)
hgf = init_hgf(nodes = nodes, edges = edges, verbose = false)

update_hgf!(hgf, 1)
end
end