Skip to content

Commit

Permalink
first JET edits
Browse files Browse the repository at this point in the history
  • Loading branch information
PTWaade committed May 4, 2024
1 parent 169018a commit 9956a27
Show file tree
Hide file tree
Showing 8 changed files with 49 additions and 53 deletions.
1 change: 1 addition & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Glob = "c27321d9-0574-5035-807b-f59d2c89b15c"
HierarchicalGaussianFiltering = "63d42c3e-681c-42be-892f-a47f35336a79"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Expand Down
4 changes: 3 additions & 1 deletion src/ActionModels_variations/utils/get_states.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@ function ActionModels.get_states(node::AbstractNode, state_name::String)
#If the state does not exist in the node
if !(Symbol(state_name) in fieldnames(typeof(node.states)))
#throw an error
throw(ArgumentError("The node $node_name does not have the state $state_name"))
throw(
ArgumentError("The node $(node.node_name) does not have the state $state_name"),
)
end

#If a prediction state has been specified
Expand Down
2 changes: 1 addition & 1 deletion src/create_hgf/hgf_structs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ end
"""
"""
Base.@kwdef mutable struct OrderedNodes
all_nodes::Vector{AbstractNode} = []
all_nodes::Vector{AbstractNode} = AbstractNode[]
input_nodes::Vector{AbstractInputNode} = []
all_state_nodes::Vector{AbstractStateNode} = []
early_update_state_nodes::Vector{AbstractStateNode} = []
Expand Down
4 changes: 2 additions & 2 deletions src/update_hgf/node_updates/continuous_state_node.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ function calculate_prediction_mean(node::ContinuousStateNode, stepsize::Real)
coupling_transform;
derivation_level = 0,
child = node,
parent = parent
parent = parent,
)

#Add the drift increment
Expand Down Expand Up @@ -240,7 +240,7 @@ function calculate_posterior_precision_increment(
derivation_level = 2,
parent = node,
child = child,
) *
) *
child.states.value_prediction_error
)

Expand Down
6 changes: 3 additions & 3 deletions src/update_hgf/nonlinear_transforms.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#Transformation (of varioius derivations) for a linear transformation
function transform_parent_value(
parent_value::Real,
parent_value::Real,
transform_type::LinearTransform;
derivation_level::Integer,
parent::AbstractNode,
Expand Down Expand Up @@ -31,10 +31,10 @@ function transform_parent_value(
transform_function = child.parameters.coupling_transforms[parent.name].base_function
elseif derivation_level == 1
transform_function =
child.parameters.coupling_transforms[parent.name].first_derivation
child.parameters.coupling_transforms[parent.name].first_derivation
elseif derivation_level == 2
transform_function =
child.parameters.coupling_transforms[parent.name].second_derivation
child.parameters.coupling_transforms[parent.name].second_derivation
else
@error "derivation level is misspecified"
end
Expand Down
3 changes: 1 addition & 2 deletions test/testsuite/Aqua.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
using HierarchicalGaussianFiltering
using Aqua
Aqua.test_all(HierarchicalGaussianFiltering,
ambiguities = false)
Aqua.test_all(HierarchicalGaussianFiltering, ambiguities = false)
55 changes: 24 additions & 31 deletions test/testsuite/test_custom_structures.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,35 +5,28 @@ using HierarchicalGaussianFiltering
@testset "test custom structures" begin

@testset "Many continuous nodes" begin
nodes = [
ContinuousInput(name = "u"),
ContinuousInput(name = "u2"),

ContinuousState(name = "x1"),
ContinuousState(name = "x2"),
ContinuousState(name = "x3"),
ContinuousState(name = "x4"),
ContinuousState(name = "x5"),

]

edges = Dict(
("u", "x1") => ObservationCoupling(),
("u2", "x2") => ObservationCoupling(),

("x1", "x2") => DriftCoupling(),
("x1", "x3") => VolatilityCoupling(),
("u2", "x3") => NoiseCoupling(),
("x2", "x4") => VolatilityCoupling(),
("x3", "x5") => DriftCoupling(),
)

hgf = init_hgf(
nodes = nodes,
edges = edges,
verbose = false,
)

update_hgf!(hgf, [1,1])
nodes = [
ContinuousInput(name = "u"),
ContinuousInput(name = "u2"),
ContinuousState(name = "x1"),
ContinuousState(name = "x2"),
ContinuousState(name = "x3"),
ContinuousState(name = "x4"),
ContinuousState(name = "x5"),
]

edges = Dict(
("u", "x1") => ObservationCoupling(),
("u2", "x2") => ObservationCoupling(),
("x1", "x2") => DriftCoupling(),
("x1", "x3") => VolatilityCoupling(),
("u2", "x3") => NoiseCoupling(),
("x2", "x4") => VolatilityCoupling(),
("x3", "x5") => DriftCoupling(),
)

hgf = init_hgf(nodes = nodes, edges = edges, verbose = false)

update_hgf!(hgf, [1, 1])
end
end
end
27 changes: 14 additions & 13 deletions test/testsuite/test_nonlinear_transforms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,39 +2,40 @@ using Test
using HierarchicalGaussianFiltering

@testset "Testing nonlinear transforms" begin

@testset "Sinoid transform" begin
nodes = [
ContinuousInput(name = "u"),
ContinuousState(name = "x1"),
ContinuousState(name = "x2"),
]

base = function(x, parameters::Dict)
base = function (x, parameters::Dict)
sin(x)
end
first_derivative = function(x, parameters::Dict)
first_derivative = function (x, parameters::Dict)
cos(x)
end
second_derivative = function(x, parameters::Dict)
second_derivative = function (x, parameters::Dict)
-sin(x)
end
transform_parameters = Dict()

edges = Dict(
("u", "x1") => ObservationCoupling(),
("x1", "x2") => DriftCoupling(2,NonlinearTransform(base, first_derivative, second_derivative, transform_parameters)),
("x1", "x2") => DriftCoupling(
2,
NonlinearTransform(
base,
first_derivative,
second_derivative,
transform_parameters,
),
),
)

hgf = init_hgf(
nodes = nodes,
edges = edges,
verbose = false,
)
hgf = init_hgf(nodes = nodes, edges = edges, verbose = false)

update_hgf!(hgf, 1)
end
end



0 comments on commit 9956a27

Please sign in to comment.