Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Uhgf (step 1) #168

Open
wants to merge 5 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions docs/julia_files/tutorials/classic_binary.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,19 @@ using StatsPlots
using Distributions

# Get the path for the HGF superfolder
hgf_path = dirname(dirname(pathof(HierarchicalGaussianFiltering)))
#hgf_path = dirname(dirname(pathof(HierarchicalGaussianFiltering)))
# Add the path to the data files
data_path = hgf_path * "/docs/julia_files/tutorials/data/"
data_path = "docs/julia_files/tutorials/data/"

# Load the data
inputs = CSV.read(data_path * "classic_binary_inputs.csv", DataFrame)[!, 1];

# Create an HGF
hgf_parameters = Dict(
("xprob", "volatility") => -2.5,
("xprob", "volatility") => -3,
("xprob", "initial_mean") => 0,
("xprob", "initial_precision") => 1,
("xvol", "volatility") => -6.0,
("xvol", "volatility") => -3,
("xvol", "initial_mean") => 1,
("xvol", "initial_precision") => 1,
("xbin", "xprob", "coupling_strength") => 1.0,
Expand Down Expand Up @@ -58,11 +58,13 @@ fixed_parameters = Dict(
("xvol", "initial_precision") => 1,
("xbin", "xprob", "coupling_strength") => 1.0,
("xprob", "xvol", "coupling_strength") => 1.0,
("xvol", "volatility") => -6.0,
);

# Set priors for parameter recovery
param_priors = Dict(("xprob", "volatility") => Normal(-3.0, 0.5));
param_priors = Dict(
("xprob", "volatility") => Normal(-2.0, 1.0),
("xvol", "volatility") => Normal(-2.0, 1.0),
);
#-
# Prior predictive plot
plot_predictive_simulation(
Expand All @@ -74,14 +76,14 @@ plot_predictive_simulation(
)
#-
# Get the actions from the MATLAB tutorial
actions = CSV.read(data_path * "classic_binary_actions.csv", DataFrame)[!, 1];
# actions = CSV.read(data_path * "classic_binary_actions.csv", DataFrame)[!, 1];
#-
# Fit the actions
#Create model
model = create_model(agent, param_priors, inputs, actions)

#Fit single chain with 10 iterations
fitted_model = fit_model(model; n_iterations = 10, n_chains = 1)
fitted_model = fit_model(model; n_iterations = 1000)
#-
#Plot the chains
plot(fitted_model)
Expand Down
103 changes: 103 additions & 0 deletions docs/notebooks/uhgf.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
### A Pluto.jl notebook ###
# v0.19.42

using Markdown
using InteractiveUtils

# ╔═╡ 4aa60710-27bc-11ef-3c69-fb6843531225
begin
import Pkg
# activate the shared project environment
Pkg.activate(Base.current_project())
# instantiate, i.e. make sure that all packages are downloaded
Pkg.instantiate()
using ActionModels
using HierarchicalGaussianFiltering
using CSV
using DataFrames
using Plots
using StatsPlots
using Distributions
end


# ╔═╡ 17cd12ca-907e-4346-a01a-83d905a195a4
md"# Testing the uHGF"

# ╔═╡ cf614e5a-a713-4400-9144-60c7887b3a46
md"## Define a standard binary 3-level HGF"

# ╔═╡ aa28184d-9773-4fad-bc01-7962d22fac78
hgf_parameters = Dict(
("u", "category_means") => Real[0.0, 1.0],
("u", "input_precision") => Inf,
("xprob", "volatility") => 2.0,
("xprob", "initial_mean") => 0,
("xprob", "initial_precision") => 1,
("xvol", "volatility") => -3.0,
("xvol", "initial_mean") => 1,
("xvol", "initial_precision") => 1,
("xbin", "xprob", "coupling_strength") => 1.0,
("xprob", "xvol", "coupling_strength") => 1.0,
);

# ╔═╡ f962d241-48f9-4110-83fb-45726579c709
hgf = premade_hgf("binary_3level", hgf_parameters, verbose = false)

# ╔═╡ dd1a3a31-29b0-4791-84bf-a9affa9df3f2
md"## Create an agent"

# ╔═╡ f1b4804c-c867-46ad-8cd1-e0cfba2ade45
agent_parameters = Dict("action_noise" => 0.2);

# ╔═╡ 25844c3b-c4ee-46a7-b6bb-e3ad876c785a
agent = premade_agent("hgf_unit_square_sigmoid", hgf, agent_parameters, verbose = false)

# ╔═╡ b00a41be-7050-4b4f-a188-57e840acb53a
md"## Get inputs and evolve the agent"

# ╔═╡ df44d01d-4297-43ac-9b0d-627a6a17de21
md"## Trajectories"

# ╔═╡ b641bdc2-30c4-4cbd-bc8d-369d8912e182
begin
plot_trajectory(agent, ("u", "input_value"))
plot_trajectory!(agent, ("xbin", "prediction"))
end

# ╔═╡ f2d9b80e-41cf-425d-a1ce-c3c3316ecd1d
plot_trajectory(agent, ("xprob", "posterior"))

# ╔═╡ 4c4956a7-d9b4-42ce-85fe-4f156d180224
plot_trajectory(agent, ("xvol", "posterior"))

# ╔═╡ acfc9fe9-e557-4882-8e9e-b567916661a7
md"## Configuration"

# ╔═╡ 6eacb1dc-074c-4eae-81d4-73e078f1dbb7
data_path = "../julia_files/tutorials/data/"

# ╔═╡ 292e9b1d-0281-4e9f-b0db-d220cd29b322
inputs = CSV.read(data_path * "classic_binary_inputs.csv", DataFrame)[!, 1];

# ╔═╡ ce870931-2a2b-4f5c-98ea-847fc652a69f
actions = give_inputs!(agent, inputs)

# ╔═╡ Cell order:
# ╟─17cd12ca-907e-4346-a01a-83d905a195a4
# ╟─cf614e5a-a713-4400-9144-60c7887b3a46
# ╠═aa28184d-9773-4fad-bc01-7962d22fac78
# ╠═f962d241-48f9-4110-83fb-45726579c709
# ╟─dd1a3a31-29b0-4791-84bf-a9affa9df3f2
# ╠═f1b4804c-c867-46ad-8cd1-e0cfba2ade45
# ╠═25844c3b-c4ee-46a7-b6bb-e3ad876c785a
# ╟─b00a41be-7050-4b4f-a188-57e840acb53a
# ╠═292e9b1d-0281-4e9f-b0db-d220cd29b322
# ╠═ce870931-2a2b-4f5c-98ea-847fc652a69f
# ╟─df44d01d-4297-43ac-9b0d-627a6a17de21
# ╠═b641bdc2-30c4-4cbd-bc8d-369d8912e182
# ╠═f2d9b80e-41cf-425d-a1ce-c3c3316ecd1d
# ╠═4c4956a7-d9b4-42ce-85fe-4f156d180224
# ╟─acfc9fe9-e557-4882-8e9e-b567916661a7
# ╠═6eacb1dc-074c-4eae-81d4-73e078f1dbb7
# ╠═4aa60710-27bc-11ef-3c69-fb6843531225
1 change: 1 addition & 0 deletions src/HierarchicalGaussianFiltering.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,5 +79,6 @@ include("premade_models/premade_hgfs/premade_JGET.jl")
include("utils/get_prediction.jl")
include("utils/get_surprise.jl")
include("utils/pretty_printing.jl")
include("utils/helper_functions.jl")

end
2 changes: 1 addition & 1 deletion src/premade_models/premade_agents/premade_softmax.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ function hgf_binary_softmax_action(agent::Agent, input)
target_value = get_states(hgf, target_state)

#Use sotmax to get the action probability
action_probability = 1 / (1 + exp(action_noise * target_value))
action_probability = 1 / (1 + capped_exp(action_noise * target_value))

#If the action probability is not between 0 and 1
if !(0 <= action_probability <= 1)
Expand Down
3 changes: 2 additions & 1 deletion src/update_hgf/node_updates/binary_state_node.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ function calculate_prediction_mean(node::BinaryStateNode)
parent.states.prediction_mean * node.parameters.coupling_strengths[parent.name]
end

prediction_mean = 1 / (1 + exp(-prediction_mean))
# Logistic transform into probability
prediction_mean = capped_logistic(prediction_mean)

return prediction_mean
end
Expand Down
2 changes: 1 addition & 1 deletion src/update_hgf/node_updates/continuous_input_node.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ function calculate_prediction_precision(node::ContinuousInputNode)
end

#The prediction precision is the inverse of the predicted noise
prediction_precision = 1 / exp(predicted_noise)
prediction_precision = 1 / capped_exp(predicted_noise)

return prediction_precision
end
Expand Down
14 changes: 2 additions & 12 deletions src/update_hgf/node_updates/continuous_state_node.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ function calculate_prediction_precision(node::ContinuousStateNode, stepsize::Rea
end

#Exponentiate and multiply with stepsize
predicted_volatility = stepsize * exp(predicted_volatility)
predicted_volatility = stepsize * capped_exp(predicted_volatility)

#Calculate prediction precision
prediction_precision = 1 / (1 / node.states.posterior_precision + predicted_volatility)
Expand Down Expand Up @@ -293,20 +293,10 @@ function calculate_posterior_precision_increment(
coupling_type::VolatilityCoupling,
)

1 / 2 *
(
child.parameters.coupling_strengths[node.name] *
child.states.effective_prediction_precision
)^2 +
child.states.precision_prediction_error *
(
child.parameters.coupling_strengths[node.name] *
child.states.effective_prediction_precision
)^2 -
1 / 2 *
child.parameters.coupling_strengths[node.name]^2 *
child.states.effective_prediction_precision *
child.states.precision_prediction_error
(1 - child.states.effective_prediction_precision)
end

function calculate_posterior_precision_increment(
Expand Down
14 changes: 14 additions & 0 deletions src/utils/helper_functions.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Function to calculate the exponential of a number, but capped to avoid numerical instability
function capped_exp(x::T) where {T<:Real}
return exp(max(min(x, 700), -700))
end

# Capped logistic transform, to avoid numerical instability
function capped_logistic(x::T) where {T<:Real}

#Do the logistic transform with a capped exp
out = 1 / (1 + capped_exp(-x))

#Ensure numerical stability by avoiding extremes
return min(max(out, 0 + eps(T)), 1 - eps(T))
end
Loading