Skip to content

Commit

Permalink
Merge branch 'dev' into documentation_polish
Browse files Browse the repository at this point in the history
  • Loading branch information
PTWaade authored Nov 2, 2023
2 parents 7828d7d + 2be18d5 commit 124c88f
Show file tree
Hide file tree
Showing 38 changed files with 48,325 additions and 293 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/CI_full.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ jobs:
arch:
- x64
steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v1
with:
version: ${{ matrix.version }}
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/CI_small.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ jobs:
arch:
- x64
steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v1
with:
version: ${{ matrix.version }}
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/documentation.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ jobs:
name: Documentation
runs-on: macOS-latest
steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v1
with:
version: '1'
Expand Down
11 changes: 7 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
name = "HierarchicalGaussianFiltering"
uuid = "63d42c3e-681c-42be-892f-a47f35336a79"
authors = ["Peter Thestrup Waade [email protected]", "Anna Hedvig Møller [email protected]", "Jacopo Comoglio [email protected]", "Christoph Mathys [email protected]\n and contributors"]
version = "0.3.1"
authors = [ "Peter Thestrup Waade [email protected]",
"Anna Hedvig Møller [email protected]",
"Jacopo Comoglio [email protected]",
"Christoph Mathys [email protected]"]
version = "0.3.3"

[deps]
ActionModels = "320cf53b-cc3b-4b34-9a10-0ecb113566a3"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"

[compat]
ActionModels = "0.3"
ActionModels = "0.4"
Distributions = "0.25"
RecipesBase = "1"
julia = "1.8"
julia = "1.9"
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ plot_trajectory!(agent, ("x1", "prediction"))

````@example index
using Distributions
prior = Dict(("x2", "evolution_rate") => Normal(1, 0.5))
prior = Dict(("x2", "volatility") => Normal(1, 0.5))
model = fit_model(agent, prior, inputs, actions, n_iterations = 20)
````
Expand Down
2 changes: 2 additions & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
ActionModels = "320cf53b-cc3b-4b34-9a10-0ecb113566a3"
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
HierarchicalGaussianFiltering = "63d42c3e-681c-42be-892f-a47f35336a79"
Expand All @@ -10,3 +11,4 @@ Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd"
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
2 changes: 1 addition & 1 deletion docs/src/Julia_src_files/building_an_HGF.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ state_nodes = [
Dict(
"name" => "continuous_state_node",
"type" => "continuous",
"evolution_rate" => -2,
"volatility" => -2,
"initial_mean" => 0,
"initial_precision" => 1,
),
Expand Down
8 changes: 4 additions & 4 deletions docs/src/Julia_src_files/fitting_hgf_models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,10 @@ using HierarchicalGaussianFiltering
hgf_parameters = Dict(
("u", "category_means") => Real[0.0, 1.0],
("u", "input_precision") => Inf,
("x2", "evolution_rate") => -2.5,
("x2", "volatility") => -2.5,
("x2", "initial_mean") => 0,
("x2", "initial_precision") => 1,
("x3", "evolution_rate") => -6.0,
("x3", "volatility") => -6.0,
("x3", "initial_mean") => 1,
("x3", "initial_precision") => 1,
("x1", "x2", "value_coupling") => 1.0,
Expand Down Expand Up @@ -95,12 +95,12 @@ fixed_parameters = Dict(
("x3", "initial_precision") => 1,
("x1", "x2", "value_coupling") => 1.0,
("x2", "x3", "volatility_coupling") => 1.0,
("x3", "evolution_rate") => -6.0,
("x3", "volatility") => -6.0,
);

# As you can read from the fixed parameters, the evolution rate of x2 is not configured. We set the prior for the x2 evolution rate:
using Distributions
param_priors = Dict(("x2", "evolution_rate") => Normal(-3.0, 0.5));
param_priors = Dict(("x2", "volatility") => Normal(-3.0, 0.5));

# We can fit the evolution rate by inputting the variables:

Expand Down
2 changes: 1 addition & 1 deletion docs/src/Julia_src_files/index.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ plot_trajectory!(agent, ("x1", "prediction"))
# ### Fitting parameters

using Distributions
prior = Dict(("x2", "evolution_rate") => Normal(1, 0.5))
prior = Dict(("x2", "volatility") => Normal(1, 0.5))

model = fit_model(agent, prior, inputs, actions, n_iterations = 20)

Expand Down
10 changes: 5 additions & 5 deletions docs/src/Julia_src_files/premade_HGF.jl
Original file line number Diff line number Diff line change
Expand Up @@ -164,10 +164,10 @@ plot_trajectory(agent_binary_3_level, ("x3", "posterior"))

# ## Categorical 3-level state transition HGF

# The categorical 3-level HGF model learns state transition probabilities between a set of n categorical startes.
# The categorical 3-level HGF model learns state transition probabilities between a set of categorical states.

# - input node: categorical
# - input node: categorical input nodes
# - state nodes:
# - 1st level: n categorical state nodes (value coupling to input node)
# - 2nd level: n binary state nodes pr. n categorical state nodes (value coupling from each categorical state node to n binary state nodes)
# - 3rd level: continous (volatility coupling to all nodes in 2nd level (n x n nodes))
# - 1st level: categorical state nodes (value coupling to input node)
# - 2nd level: binary state nodes for each categorical state node (value coupling from each categorical state node to binary state nodes)
# - 3rd level: continous (volatility coupling to all nodes in 2nd level)
2 changes: 1 addition & 1 deletion docs/src/Julia_src_files/premade_models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ set_parameters!(agent, ("x3", "initial_precision"), 0.4)
# Set multiple parameter values
set_parameters!(
agent,
Dict(("x3", "initial_precision") => 1, ("x3", "evolution_rate") => 0),
Dict(("x3", "initial_precision") => 1, ("x3", "volatility") => 0),
)


Expand Down
49 changes: 49 additions & 0 deletions docs/src/Julia_src_files/the_HGF_nodes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,55 @@
# ## Building principles

# The following rules apply for connecting nodes, when customizing your own HGF structure:
# ### Parameters

# - no parameters in the categorical state node

# ### The states of Categorical input nodes and parameters

# - input value

# ### Parameters

# - no parameters in the categorical state node

# ## Continuous Nodes

# ### The states of Continuous state nodes and parameters

# #### States

# - posterior mean
# - posterior precision
# - value prediction error
# - volatility prediction error
# - prediction mean
# - prediciton volatility
# - prediction precision
# - auxiliary prediction precision

# ### Parameters

# - evolution rate (default is 0)
# - value coupling
# - volatility coupling
# - initial mean (default is 0)
# - initital precision (default is 0)


# ### The states of Continuous input nodes and parameters

# - input value
# - value prediction error
# - volatility prediction error
# - prediction volatility
# - prediction precision

# ### Parameters

# - input noise (default is 0)
# - value coupling
# - volatility coupling

# ### Binary state node rules:

Expand Down
8 changes: 4 additions & 4 deletions docs/src/Julia_src_files/utility_functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ get_parameters(agent)
# ERROR WITH THIS get_parameters(agent, ("x2", "x3", "volatility_coupling"))

# getting multiple parameters specify them in a vector
get_parameters(agent, [("x3", "evolution_rate"), ("x3", "initial_precision")])
get_parameters(agent, [("x3", "volatility"), ("x3", "initial_precision")])


# ### Getting States
Expand All @@ -46,7 +46,7 @@ get_states(agent)
get_states(agent, ("x2", "posterior_precision"))

#getting multiple states
get_states(agent, [("x2", "posterior_precision"), ("x2", "auxiliary_prediction_precision")])
get_states(agent, [("x2", "posterior_precision"), ("x2", "volatility_weighted_prediction_precision")])


# ### Setting Parameters
Expand All @@ -61,10 +61,10 @@ agent_parameter = Dict("sigmoid_action_precision" => 3)
hgf_parameters = Dict(
("u", "category_means") => Real[0.0, 1.0],
("u", "input_precision") => Inf,
("x2", "evolution_rate") => -2.5,
("x2", "volatility") => -2.5,
("x2", "initial_mean") => 0,
("x2", "initial_precision") => 1,
("x3", "evolution_rate") => -6.0,
("x3", "volatility") => -6.0,
("x3", "initial_mean") => 1,
("x3", "initial_precision") => 1,
("x1", "x2", "value_coupling") => 1.0,
Expand Down
2 changes: 1 addition & 1 deletion docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ plot_trajectory!(agent, ("x1", "prediction"))

````@example index
using Distributions
prior = Dict(("x2", "evolution_rate") => Normal(1, 0.5))
prior = Dict(("x2", "volatility") => Normal(1, 0.5))
model = fit_model(agent, prior, inputs, actions, n_iterations = 20)
````
Expand Down
72 changes: 72 additions & 0 deletions docs/src/tutorials/classic_JGET.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
using ActionModels, HierarchicalGaussianFiltering
using CSV, DataFrames
using Plots, StatsPlots
using Distributions
path = "docs/src/tutorials/data"

#Load data
data = CSV.read("$path/classic_cannonball_data.csv", DataFrame)

#Create HGF
hgf = premade_hgf("JGET", verbose = false)
#Create agent
agent = premade_agent("hgf_gaussian_action", hgf)
#Set parameters
parameters = Dict(
"gaussian_action_precision" => 1,
("x1", "volatility") => -8,
("x2", "volatility") => -5,
("x3", "volatility") => -5,
("x4", "volatility") => -5,
("x1", "x2", "volatility_coupling") => 1,
("x3", "x4", "volatility_coupling") => 1,
)
set_parameters!(agent, parameters)

inputs = data[(data.ID.==20).&(data.session.==1), :].outcome
#Simulate updates and actions
actions = give_inputs!(agent, inputs);
#Plot belief trajectories
plot_trajectory(agent, "u")
plot_trajectory!(agent, "x1")
plot_trajectory(agent, "x2")
plot_trajectory(agent, "x3")
plot_trajectory(agent, "x4")

priors = Dict(
"gaussian_action_precision" => LogNormal(-1, 0.1),
("x1", "volatility") => Normal(-8, 1),
)

data_subset = data[(data.ID.∈[[20, 21]]).&(data.session.∈[[1, 2]]), :]

using Distributed
addprocs(6, exeflags = "--project")
@everywhere @eval using HierarchicalGaussianFiltering

results = fit_model(
agent,
priors,
data_subset,
independent_group_cols = [:ID, :session],
input_cols = [:outcome],
action_cols = [:response],
n_cores = 6,
)

fitted_model = results[(20, 1)]

plot_parameter_distribution(fitted_model, priors)




posterior = get_posteriors(fitted_model)

set_parameters!(agent, posterior)

reset!(agent)

give_inputs!(agent, inputs)

get_history(agent, ("x1", "value_prediction_error"))
8 changes: 4 additions & 4 deletions docs/src/tutorials/classic_binary.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@ inputs = CSV.read(data_path * "classic_binary_inputs.csv", DataFrame)[!, 1];
hgf_parameters = Dict(
("u", "category_means") => Real[0.0, 1.0],
("u", "input_precision") => Inf,
("x2", "evolution_rate") => -2.5,
("x2", "volatility") => -2.5,
("x2", "initial_mean") => 0,
("x2", "initial_precision") => 1,
("x3", "evolution_rate") => -6.0,
("x3", "volatility") => -6.0,
("x3", "initial_mean") => 1,
("x3", "initial_precision") => 1,
("x1", "x2", "value_coupling") => 1.0,
Expand Down Expand Up @@ -64,11 +64,11 @@ fixed_parameters = Dict(
("x3", "initial_precision") => 1,
("x1", "x2", "value_coupling") => 1.0,
("x2", "x3", "volatility_coupling") => 1.0,
("x3", "evolution_rate") => -6.0,
("x3", "volatility") => -6.0,
);

# Set priors for parameter recovery
param_priors = Dict(("x2", "evolution_rate") => Normal(-3.0, 0.5));
param_priors = Dict(("x2", "volatility") => Normal(-3.0, 0.5));
#-
# Prior predictive plot
plot_predictive_simulation(
Expand Down
26 changes: 9 additions & 17 deletions docs/src/tutorials/classic_usdchf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,13 @@ end
hgf = premade_hgf("continuous_2level", verbose = false);
agent = premade_agent("hgf_gaussian_action", hgf, verbose = false);

# Set parameters for parameter recovyer
# Set parameters for parameter recover
parameters = Dict(
("u", "x1", "value_coupling") => 1.0,
("x1", "x2", "volatility_coupling") => 1.0,
("u", "evolution_rate") => -log(1e4),
("x1", "evolution_rate") => -13,
("x2", "evolution_rate") => -2,
("u", "input_noise") => -log(1e4),
("x1", "volatility") => -13,
("x2", "volatility") => -2,
("x1", "initial_mean") => 1.04,
("x1", "initial_precision") => 1 / (0.0001),
("x2", "initial_mean") => 1.0,
Expand Down Expand Up @@ -91,9 +91,9 @@ fixed_parameters = Dict(
);

param_priors = Dict(
("u", "evolution_rate") => Normal(-10, 2),
("x1", "evolution_rate") => Normal(-10, 4),
("x2", "evolution_rate") => Normal(-4, 4),
("u", "input_noise") => Normal(-6, 1),
("x1", "volatility") => Normal(-4, 1),
("x2", "volatility") => Normal(-4, 1),
);
#-
# Prior predictive simulation plot
Expand All @@ -102,7 +102,7 @@ plot_predictive_simulation(
agent,
inputs,
("x1", "posterior_mean");
n_simulations = 3,
n_simulations = 100,
)
#-
# Do parameter recovery
Expand All @@ -112,7 +112,7 @@ fitted_model = fit_model(
inputs,
actions,
fixed_parameters = fixed_parameters,
verbose = true,
verbose = false,
n_iterations = 10,
)
#-
Expand All @@ -122,11 +122,3 @@ plot(fitted_model)
# Plot prior posterior distributions
plot_parameter_distribution(fitted_model, param_priors)
#-
# Posterior predictive plot
plot_predictive_simulation(
fitted_model,
agent,
inputs,
("x1", "posterior_mean");
n_simulations = 3,
)
Loading

0 comments on commit 124c88f

Please sign in to comment.