Skip to content

Commit

Permalink
added test for nonlinear transforms; fixed minor bug
Browse files Browse the repository at this point in the history
  • Loading branch information
PTWaade committed Jul 23, 2024
1 parent d931cd1 commit ec2a4f3
Show file tree
Hide file tree
Showing 4 changed files with 202 additions and 2 deletions.
1 change: 1 addition & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,6 @@ JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Pluto = "c3e4b0f8-55cb-11ea-2926-15256bba5781"
StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd"
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
186 changes: 186 additions & 0 deletions docs/julia_files/tutorials/sinoid_transform.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
### A Pluto.jl notebook ###
# v0.19.45

using Markdown
using InteractiveUtils

# ╔═╡ e47a29a4-38b0-4f19-972a-7c1c3e83c2f7
using Pkg;
Pkg.activate("../../"); #Activate the docs environment

# ╔═╡ 5ae02fee-48f1-11ef-209c-539e778f577d
using HierarchicalGaussianFiltering, Distributions, StatsPlots

# ╔═╡ ff1abe62-b477-4c34-927d-4cc758981d60
nodes = [
ContinuousInput(name = "u"),
ContinuousState(name = "x1"),
ContinuousState(name = "x2"),
]

# ╔═╡ b53c65b6-24fb-4df0-85e8-c94e973f92d1
sine_transform = NonlinearTransform(
function (x, parameters::Dict)
sin(x)
end, #base function
function (x, parameters::Dict)
cos(x)
end, #first derivative
function (x, parameters::Dict)
-sin(x)
end, #second derivative
Dict(), #no parameters
)

# ╔═╡ c1e30a10-e9f9-4348-92b5-85fdc0be00b0
begin
edges_linear =
Dict(("u", "x1") => ObservationCoupling(), ("x1", "x2") => DriftCoupling())

edges_nonlinear = Dict(
("u", "x1") => ObservationCoupling(),
("x1", "x2") => DriftCoupling(1, sine_transform),
)

end

# ╔═╡ ec6ceba5-d202-4345-bce8-fc0dada2a019
begin
hgf_linear = init_hgf(nodes = nodes, edges = edges_linear, verbose = false)
hgf_nonlinear = init_hgf(nodes = nodes, edges = edges_nonlinear, verbose = false)
end

# ╔═╡ d2e95dab-7312-4835-b780-91571ba83239
parameters = Dict(
("x1", "autoconnection_strength") => 0,
("x1", "volatility") => -4,
("x2", "volatility") => -4,
("u", "input_noise") => log(0.25),
)

# ╔═╡ e319c480-cb6d-41fd-a4ec-892a12b95a06
begin
set_parameters!(hgf_linear, parameters)

set_parameters!(hgf_nonlinear, parameters)
end

# ╔═╡ cb170251-4dce-4e75-b457-40f2fb959a5b
begin

sample_rate = 20

inputs = sin.(collect(0:1/sample_rate:35))

inputs = rand(Normal(0, 0.25), length(inputs)) + inputs
plot(inputs)
end

# ╔═╡ 847373ac-647d-43a3-93a1-92af9772dce0
begin
reset!(hgf_linear)
give_inputs!(hgf_linear, inputs)
reset!(hgf_nonlinear)
give_inputs!(hgf_nonlinear, inputs)
end

# ╔═╡ c7f92940-18c8-4212-8808-8b29d9b870a1
plot_settings = (; label = "", title = "")

# ╔═╡ 5835d516-f822-4128-900b-16d09b550cb7

plot(
plot_trajectory(hgf_linear, "u"; plot_settings..., title = "inputs"),
plot_trajectory(
hgf_linear,
("u", "prediction");
plot_settings...,
title = "u prediction",
),
plot_trajectory(hgf_linear, ("x1"); plot_settings..., title = "x1 posterior"),
plot_trajectory(hgf_linear, ("x2"); plot_settings..., title = "x2 posterior"),
)

# ╔═╡ f1ad8c91-7642-4478-8091-bbcecc9f2bcf

plot(
plot_trajectory(hgf_nonlinear, "u"; plot_settings..., title = "inputs"),
plot_trajectory(
hgf_nonlinear,
("u", "prediction");
plot_settings...,
title = "u prediction",
),
plot_trajectory(hgf_nonlinear, ("x1"); plot_settings..., title = "x1 posterior"),
plot_trajectory(hgf_nonlinear, ("x2"); plot_settings..., title = "x2 posterior"),
)

# ╔═╡ 50330bc3-e74d-4cfc-908a-a56acbcdd2bc
begin
μ₁_linear = get_history(hgf_linear, ("x1", "posterior_mean"))

#Band mu1 between -1 and 1 (the noise makes it occasionally jump over)
μ₁_linear = min.(μ₁_linear, 1.0)
μ₁_linear = max.(μ₁_linear, -1.0)

μ₂_linear = get_history(hgf_linear, ("x2", "posterior_mean"))



μ₁_nonlinear = get_history(hgf_nonlinear, ("x1", "posterior_mean"))

#Band mu1 between -1 and 1 (the noise makes it occasionally jump over)
μ₁_nonlinear = min.(μ₁_nonlinear, 1.0)
μ₁_nonlinear = max.(μ₁_nonlinear, -1.0)

μ₂_nonlinear = get_history(hgf_nonlinear, ("x2", "posterior_mean"))
end

# ╔═╡ fa51e10d-3ab3-487c-90bc-9362956f563a
#Plot sine transformed x2 against x1 - should be equal
plot(
plot(sin.(μ₂_linear) - μ₁_linear, title = "linear"),
plot(sin.(μ₂_nonlinear) - μ₁_nonlinear, title = "nonlinear"),
)

# ╔═╡ c52a0167-e780-47fc-901c-44cf84b973d4
#Plot asine transformed x1 against x2 - should be equal
plot(
plot(asin.(μ₁_linear) - μ₂_linear, title = "linear"),
plot(asin.(μ₁_nonlinear) - μ₂_nonlinear, title = "nonlinear"),
)
#The linear fares worse

# ╔═╡ af36ec40-ca4f-4aff-ae35-439adc5cae89
begin
linear_plt = plot(asin.(μ₁_linear), label = "μ₁ asin")
plot!(μ₂_linear, label = "μ₂", title = "linear")

nonlinear_plt = plot(asin.(μ₁_nonlinear), label = "μ₁ asin")
plot!(μ₂_nonlinear, label = "μ₂", title = "nonlinear")

plot(linear_plt, nonlinear_plt)
end

# ╔═╡ 0a6546fe-9f5e-4d2a-8c02-59d96332b0a9
###NEXT STEPS: PREDICT FURTHER IN THE FUTURE

# ╔═╡ Cell order:
# ╠═e47a29a4-38b0-4f19-972a-7c1c3e83c2f7
# ╠═5ae02fee-48f1-11ef-209c-539e778f577d
# ╠═ff1abe62-b477-4c34-927d-4cc758981d60
# ╠═b53c65b6-24fb-4df0-85e8-c94e973f92d1
# ╠═c1e30a10-e9f9-4348-92b5-85fdc0be00b0
# ╠═ec6ceba5-d202-4345-bce8-fc0dada2a019
# ╠═d2e95dab-7312-4835-b780-91571ba83239
# ╠═e319c480-cb6d-41fd-a4ec-892a12b95a06
# ╠═cb170251-4dce-4e75-b457-40f2fb959a5b
# ╠═847373ac-647d-43a3-93a1-92af9772dce0
# ╠═c7f92940-18c8-4212-8808-8b29d9b870a1
# ╠═5835d516-f822-4128-900b-16d09b550cb7
# ╠═f1ad8c91-7642-4478-8091-bbcecc9f2bcf
# ╠═50330bc3-e74d-4cfc-908a-a56acbcdd2bc
# ╠═fa51e10d-3ab3-487c-90bc-9362956f563a
# ╠═c52a0167-e780-47fc-901c-44cf84b973d4
# ╠═af36ec40-ca4f-4aff-ae35-439adc5cae89
# ╠═0a6546fe-9f5e-4d2a-8c02-59d96332b0a9
4 changes: 3 additions & 1 deletion src/create_hgf/hgf_structs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ struct EnhancedUpdate <: HGFUpdateType end
#Types for specifying nonlinear transformations
abstract type CouplingTransform end

Base.@kwdef mutable struct LinearTransform <: CouplingTransform end
Base.@kwdef mutable struct LinearTransform <: CouplingTransform
parameters::Dict = Dict()
end

Base.@kwdef mutable struct NonlinearTransform <: CouplingTransform
base_function::Function
Expand Down
13 changes: 12 additions & 1 deletion test/testsuite/test_nonlinear_transforms.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using Test
using HierarchicalGaussianFiltering
using Distributions

@testset "Testing nonlinear transforms" begin

Expand Down Expand Up @@ -36,6 +37,16 @@ using HierarchicalGaussianFiltering

hgf = init_hgf(nodes = nodes, edges = edges, verbose = false)

update_hgf!(hgf, 1)
set_parameters!(
hgf,
Dict(("u", "input_noise") => 4("x1", "autoconnection_strength") => 0),
)

inputs = sin.(collect(0:1/20:1000/20))

#Add gaussian noise
inputs = rand(Normal(0, 0.5), length(inputs)) + inputs

give_inputs!(hgf, inputs)
end
end

0 comments on commit ec2a4f3

Please sign in to comment.