Skip to content

Commit

Permalink
Merge pull request #149 from ilabcode/small_fixes
Browse files Browse the repository at this point in the history
Small fixes
  • Loading branch information
PTWaade authored May 9, 2024
2 parents 1129702 + 8354515 commit 9fc5ced
Show file tree
Hide file tree
Showing 15 changed files with 61 additions and 59 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ authors = [ "Peter Thestrup Waade [email protected]",
"Anna Hedvig Møller [email protected]",
"Jacopo Comoglio [email protected]",
"Christoph Mathys [email protected]"]
version = "0.5.1"
version = "0.5.2"

[deps]
ActionModels = "320cf53b-cc3b-4b34-9a10-0ecb113566a3"
Expand All @@ -15,4 +15,4 @@ RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
ActionModels = "0.5"
Distributions = "0.25"
RecipesBase = "1"
julia = "1.9"
julia = "1.10"
2 changes: 2 additions & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
[deps]
ActionModels = "320cf53b-cc3b-4b34-9a10-0ecb113566a3"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Expand All @@ -8,6 +9,7 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Glob = "c27321d9-0574-5035-807b-f59d2c89b15c"
HierarchicalGaussianFiltering = "63d42c3e-681c-42be-892f-a47f35336a79"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Expand Down
2 changes: 1 addition & 1 deletion src/ActionModels_variations/utils/get_history.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ Gets the history of a state from a specific node in an HGF. A vector of states c
Gets the history of all states for a specific node in an HGF. If only a node object is passed, it will return the history of all states in that node. If only an HGF object is passed, it will return the history of all states in all nodes in the HGF.
"""
function ActionModels.get_history() end
# function ActionModels.get_history() end

### For getting histories of specific states ###
function ActionModels.get_history(hgf::HGF, target_state::Tuple{String,String})
Expand Down
2 changes: 1 addition & 1 deletion src/ActionModels_variations/utils/get_parameters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ Gets a single parameter value from a specific node in an HGF. A vector of parame
Gets all parameter values for a specific node in an HGF. If only a node object is passed, returns all parameters in that node. If only an HGF object is passed, returns all parameters of all nodes in the HGF.
"""
function ActionModels.get_parameters() end
# function ActionModels.get_parameters() end

### For getting a specific parameter from a specific node ###
#For parameters other than coupling strengths
Expand Down
6 changes: 4 additions & 2 deletions src/ActionModels_variations/utils/get_states.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ Gets a single state value from a specific node in an HGF. A vector of states can
Gets all parameter values for a specific node in an HGF. If only a node object is passed, returns all states in that node. If only an HGF object is passed, returns all states of all nodes in the HGF.
"""
function ActionModels.get_states() end
# function ActionModels.get_states() end


### For getting a specific state from a specific node ###
Expand Down Expand Up @@ -36,7 +36,9 @@ function ActionModels.get_states(node::AbstractNode, state_name::String)
#If the state does not exist in the node
if !(Symbol(state_name) in fieldnames(typeof(node.states)))
#throw an error
throw(ArgumentError("The node $node_name does not have the state $state_name"))
throw(

Check warning on line 39 in src/ActionModels_variations/utils/get_states.jl

View check run for this annotation

Codecov / codecov/patch

src/ActionModels_variations/utils/get_states.jl#L39

Added line #L39 was not covered by tests
ArgumentError("The node $(node.node_name) does not have the state $state_name"),
)
end

#If a prediction state has been specified
Expand Down
2 changes: 1 addition & 1 deletion src/ActionModels_variations/utils/give_inputs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ give_inputs!(hgf::HGF, inputs)
Give inputs to an agent. Input can be a single value, a vector of values, or an array of values.
"""
function ActionModels.give_inputs!() end
# function ActionModels.give_inputs!() end


### Giving a single input ###
Expand Down
2 changes: 1 addition & 1 deletion src/ActionModels_variations/utils/set_parameters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ Setting a single parameter value for an HGF.
Set mutliple parameters values for an HGF. Takes a dictionary of parameter names and values.
"""
function ActionModels.set_parameters!() end
# function ActionModels.set_parameters!() end

### For setting a single parameter ###

Expand Down
2 changes: 1 addition & 1 deletion src/HierarchicalGaussianFiltering.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ module HierarchicalGaussianFiltering
using ActionModels, Distributions, RecipesBase

#Export functions
export init_node, init_hgf, premade_hgf, check_hgf, check_node, update_hgf!
export init_node, init_hgf, premade_hgf, check_hgf, update_hgf!
export get_prediction, get_surprise
export premade_agent,
init_agent, plot_predictive_simulation, plot_trajectory, plot_trajectory!
Expand Down
2 changes: 1 addition & 1 deletion src/create_hgf/hgf_structs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ end
"""
"""
Base.@kwdef mutable struct OrderedNodes
all_nodes::Vector{AbstractNode} = []
all_nodes::Vector{AbstractNode} = AbstractNode[]
input_nodes::Vector{AbstractInputNode} = []
all_state_nodes::Vector{AbstractStateNode} = []
early_update_state_nodes::Vector{AbstractStateNode} = []
Expand Down
4 changes: 2 additions & 2 deletions src/update_hgf/node_updates/continuous_state_node.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ function calculate_prediction_mean(node::ContinuousStateNode, stepsize::Real)
coupling_transform;
derivation_level = 0,
child = node,
parent = parent
parent = parent,
)

#Add the drift increment
Expand Down Expand Up @@ -240,7 +240,7 @@ function calculate_posterior_precision_increment(
derivation_level = 2,
parent = node,
child = child,
) *
) *
child.states.value_prediction_error
)

Expand Down
6 changes: 3 additions & 3 deletions src/update_hgf/nonlinear_transforms.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#Transformation (of varioius derivations) for a linear transformation
function transform_parent_value(
parent_value::Real,
parent_value::Real,
transform_type::LinearTransform;
derivation_level::Integer,
parent::AbstractNode,
Expand Down Expand Up @@ -31,10 +31,10 @@ function transform_parent_value(
transform_function = child.parameters.coupling_transforms[parent.name].base_function
elseif derivation_level == 1
transform_function =
child.parameters.coupling_transforms[parent.name].first_derivation
child.parameters.coupling_transforms[parent.name].first_derivation
elseif derivation_level == 2
transform_function =
child.parameters.coupling_transforms[parent.name].second_derivation
child.parameters.coupling_transforms[parent.name].second_derivation
else
@error "derivation level is misspecified"
end
Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
GR = "28b8d3ca-fb5f-59d9-8090-bfdbd6d07a71"
Glob = "c27321d9-0574-5035-807b-f59d2c89b15c"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
Expand Down
3 changes: 3 additions & 0 deletions test/testsuite/Aqua.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
using HierarchicalGaussianFiltering
using Aqua
Aqua.test_all(HierarchicalGaussianFiltering, ambiguities = false)
55 changes: 24 additions & 31 deletions test/testsuite/test_custom_structures.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,35 +5,28 @@ using HierarchicalGaussianFiltering
@testset "test custom structures" begin

@testset "Many continuous nodes" begin
nodes = [
ContinuousInput(name = "u"),
ContinuousInput(name = "u2"),

ContinuousState(name = "x1"),
ContinuousState(name = "x2"),
ContinuousState(name = "x3"),
ContinuousState(name = "x4"),
ContinuousState(name = "x5"),

]

edges = Dict(
("u", "x1") => ObservationCoupling(),
("u2", "x2") => ObservationCoupling(),

("x1", "x2") => DriftCoupling(),
("x1", "x3") => VolatilityCoupling(),
("u2", "x3") => NoiseCoupling(),
("x2", "x4") => VolatilityCoupling(),
("x3", "x5") => DriftCoupling(),
)

hgf = init_hgf(
nodes = nodes,
edges = edges,
verbose = false,
)

update_hgf!(hgf, [1,1])
nodes = [
ContinuousInput(name = "u"),
ContinuousInput(name = "u2"),
ContinuousState(name = "x1"),
ContinuousState(name = "x2"),
ContinuousState(name = "x3"),
ContinuousState(name = "x4"),
ContinuousState(name = "x5"),
]

edges = Dict(
("u", "x1") => ObservationCoupling(),
("u2", "x2") => ObservationCoupling(),
("x1", "x2") => DriftCoupling(),
("x1", "x3") => VolatilityCoupling(),
("u2", "x3") => NoiseCoupling(),
("x2", "x4") => VolatilityCoupling(),
("x3", "x5") => DriftCoupling(),
)

hgf = init_hgf(nodes = nodes, edges = edges, verbose = false)

update_hgf!(hgf, [1, 1])
end
end
end
27 changes: 14 additions & 13 deletions test/testsuite/test_nonlinear_transforms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,39 +2,40 @@ using Test
using HierarchicalGaussianFiltering

@testset "Testing nonlinear transforms" begin

@testset "Sinoid transform" begin
nodes = [
ContinuousInput(name = "u"),
ContinuousState(name = "x1"),
ContinuousState(name = "x2"),
]

base = function(x, parameters::Dict)
base = function (x, parameters::Dict)
sin(x)
end
first_derivative = function(x, parameters::Dict)
first_derivative = function (x, parameters::Dict)
cos(x)
end
second_derivative = function(x, parameters::Dict)
second_derivative = function (x, parameters::Dict)
-sin(x)
end
transform_parameters = Dict()

edges = Dict(
("u", "x1") => ObservationCoupling(),
("x1", "x2") => DriftCoupling(2,NonlinearTransform(base, first_derivative, second_derivative, transform_parameters)),
("x1", "x2") => DriftCoupling(
2,
NonlinearTransform(
base,
first_derivative,
second_derivative,
transform_parameters,
),
),
)

hgf = init_hgf(
nodes = nodes,
edges = edges,
verbose = false,
)
hgf = init_hgf(nodes = nodes, edges = edges, verbose = false)

update_hgf!(hgf, 1)
end
end



0 comments on commit 9fc5ced

Please sign in to comment.