Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update dev from main #156

Merged
merged 6 commits into from
Aug 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ authors = [ "Peter Thestrup Waade [email protected]",
"Christoph Mathys [email protected]"]
version = "0.5.4"


[deps]
ActionModels = "320cf53b-cc3b-4b34-9a10-0ecb113566a3"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Expand Down
1 change: 1 addition & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,6 @@ JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Pluto = "c3e4b0f8-55cb-11ea-2926-15256bba5781"
StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd"
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
186 changes: 186 additions & 0 deletions docs/julia_files/tutorials/sinoid_transform.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
### A Pluto.jl notebook ###
# v0.19.45

using Markdown
using InteractiveUtils

# ╔═╡ e47a29a4-38b0-4f19-972a-7c1c3e83c2f7
using Pkg;
Pkg.activate("../../"); #Activate the docs environment

# ╔═╡ 5ae02fee-48f1-11ef-209c-539e778f577d
using HierarchicalGaussianFiltering, Distributions, StatsPlots

# ╔═╡ ff1abe62-b477-4c34-927d-4cc758981d60
nodes = [
ContinuousInput(name = "u"),
ContinuousState(name = "x1"),
ContinuousState(name = "x2"),
]

# ╔═╡ b53c65b6-24fb-4df0-85e8-c94e973f92d1
sine_transform = NonlinearTransform(
function (x, parameters::Dict)
sin(x)
end, #base function
function (x, parameters::Dict)
cos(x)
end, #first derivative
function (x, parameters::Dict)
-sin(x)
end, #second derivative
Dict(), #no parameters
)

# ╔═╡ c1e30a10-e9f9-4348-92b5-85fdc0be00b0
begin
edges_linear =
Dict(("u", "x1") => ObservationCoupling(), ("x1", "x2") => DriftCoupling())

edges_nonlinear = Dict(
("u", "x1") => ObservationCoupling(),
("x1", "x2") => DriftCoupling(1, sine_transform),
)

end

# ╔═╡ ec6ceba5-d202-4345-bce8-fc0dada2a019
begin
hgf_linear = init_hgf(nodes = nodes, edges = edges_linear, verbose = false)
hgf_nonlinear = init_hgf(nodes = nodes, edges = edges_nonlinear, verbose = false)
end

# ╔═╡ d2e95dab-7312-4835-b780-91571ba83239
parameters = Dict(
("x1", "autoconnection_strength") => 0,
("x1", "volatility") => -4,
("x2", "volatility") => -4,
("u", "input_noise") => log(0.25),
)

# ╔═╡ e319c480-cb6d-41fd-a4ec-892a12b95a06
begin
set_parameters!(hgf_linear, parameters)

set_parameters!(hgf_nonlinear, parameters)
end

# ╔═╡ cb170251-4dce-4e75-b457-40f2fb959a5b
begin

sample_rate = 20

inputs = sin.(collect(0:1/sample_rate:35))

inputs = rand(Normal(0, 0.25), length(inputs)) + inputs
plot(inputs)
end

# ╔═╡ 847373ac-647d-43a3-93a1-92af9772dce0
begin
reset!(hgf_linear)
give_inputs!(hgf_linear, inputs)
reset!(hgf_nonlinear)
give_inputs!(hgf_nonlinear, inputs)
end

# ╔═╡ c7f92940-18c8-4212-8808-8b29d9b870a1
plot_settings = (; label = "", title = "")

# ╔═╡ 5835d516-f822-4128-900b-16d09b550cb7

plot(
plot_trajectory(hgf_linear, "u"; plot_settings..., title = "inputs"),
plot_trajectory(
hgf_linear,
("u", "prediction");
plot_settings...,
title = "u prediction",
),
plot_trajectory(hgf_linear, ("x1"); plot_settings..., title = "x1 posterior"),
plot_trajectory(hgf_linear, ("x2"); plot_settings..., title = "x2 posterior"),
)

# ╔═╡ f1ad8c91-7642-4478-8091-bbcecc9f2bcf

plot(
plot_trajectory(hgf_nonlinear, "u"; plot_settings..., title = "inputs"),
plot_trajectory(
hgf_nonlinear,
("u", "prediction");
plot_settings...,
title = "u prediction",
),
plot_trajectory(hgf_nonlinear, ("x1"); plot_settings..., title = "x1 posterior"),
plot_trajectory(hgf_nonlinear, ("x2"); plot_settings..., title = "x2 posterior"),
)

# ╔═╡ 50330bc3-e74d-4cfc-908a-a56acbcdd2bc
begin
μ₁_linear = get_history(hgf_linear, ("x1", "posterior_mean"))

#Band mu1 between -1 and 1 (the noise makes it occasionally jump over)
μ₁_linear = min.(μ₁_linear, 1.0)
μ₁_linear = max.(μ₁_linear, -1.0)

μ₂_linear = get_history(hgf_linear, ("x2", "posterior_mean"))



μ₁_nonlinear = get_history(hgf_nonlinear, ("x1", "posterior_mean"))

#Band mu1 between -1 and 1 (the noise makes it occasionally jump over)
μ₁_nonlinear = min.(μ₁_nonlinear, 1.0)
μ₁_nonlinear = max.(μ₁_nonlinear, -1.0)

μ₂_nonlinear = get_history(hgf_nonlinear, ("x2", "posterior_mean"))
end

# ╔═╡ fa51e10d-3ab3-487c-90bc-9362956f563a
#Plot sine transformed x2 against x1 - should be equal
plot(
plot(sin.(μ₂_linear) - μ₁_linear, title = "linear"),
plot(sin.(μ₂_nonlinear) - μ₁_nonlinear, title = "nonlinear"),
)

# ╔═╡ c52a0167-e780-47fc-901c-44cf84b973d4
#Plot asine transformed x1 against x2 - should be equal
plot(
plot(asin.(μ₁_linear) - μ₂_linear, title = "linear"),
plot(asin.(μ₁_nonlinear) - μ₂_nonlinear, title = "nonlinear"),
)
#The linear fares worse

# ╔═╡ af36ec40-ca4f-4aff-ae35-439adc5cae89
begin
linear_plt = plot(asin.(μ₁_linear), label = "μ₁ asin")
plot!(μ₂_linear, label = "μ₂", title = "linear")

nonlinear_plt = plot(asin.(μ₁_nonlinear), label = "μ₁ asin")
plot!(μ₂_nonlinear, label = "μ₂", title = "nonlinear")

plot(linear_plt, nonlinear_plt)
end

# ╔═╡ 0a6546fe-9f5e-4d2a-8c02-59d96332b0a9
###NEXT STEPS: PREDICT FURTHER IN THE FUTURE

# ╔═╡ Cell order:
# ╠═e47a29a4-38b0-4f19-972a-7c1c3e83c2f7
# ╠═5ae02fee-48f1-11ef-209c-539e778f577d
# ╠═ff1abe62-b477-4c34-927d-4cc758981d60
# ╠═b53c65b6-24fb-4df0-85e8-c94e973f92d1
# ╠═c1e30a10-e9f9-4348-92b5-85fdc0be00b0
# ╠═ec6ceba5-d202-4345-bce8-fc0dada2a019
# ╠═d2e95dab-7312-4835-b780-91571ba83239
# ╠═e319c480-cb6d-41fd-a4ec-892a12b95a06
# ╠═cb170251-4dce-4e75-b457-40f2fb959a5b
# ╠═847373ac-647d-43a3-93a1-92af9772dce0
# ╠═c7f92940-18c8-4212-8808-8b29d9b870a1
# ╠═5835d516-f822-4128-900b-16d09b550cb7
# ╠═f1ad8c91-7642-4478-8091-bbcecc9f2bcf
# ╠═50330bc3-e74d-4cfc-908a-a56acbcdd2bc
# ╠═fa51e10d-3ab3-487c-90bc-9362956f563a
# ╠═c52a0167-e780-47fc-901c-44cf84b973d4
# ╠═af36ec40-ca4f-4aff-ae35-439adc5cae89
# ╠═0a6546fe-9f5e-4d2a-8c02-59d96332b0a9
4 changes: 3 additions & 1 deletion src/create_hgf/hgf_structs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ struct EnhancedUpdate <: HGFUpdateType end
#Types for specifying nonlinear transformations
abstract type CouplingTransform end

Base.@kwdef mutable struct LinearTransform <: CouplingTransform end
Base.@kwdef mutable struct LinearTransform <: CouplingTransform
parameters::Dict = Dict()
end

Base.@kwdef mutable struct NonlinearTransform <: CouplingTransform
base_function::Function
Expand Down
16 changes: 15 additions & 1 deletion test/testsuite/test_nonlinear_transforms.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using Test
using HierarchicalGaussianFiltering
using Distributions

@testset "Testing nonlinear transforms" begin

Expand Down Expand Up @@ -36,6 +37,19 @@ using HierarchicalGaussianFiltering

hgf = init_hgf(nodes = nodes, edges = edges, verbose = false)

update_hgf!(hgf, 1)
set_parameters!(
hgf,
Dict(
("u", "input_noise") => 4,
("x1", "autoconnection_strength") => 0
),
)

inputs = sin.(collect(0:1/20:1000/20))

#Add gaussian noise
inputs = rand(Normal(0, 0.5), length(inputs)) + inputs

give_inputs!(hgf, inputs)
end
end
Loading