From 3313ff9772345a70314a58481ddab15f97dc7809 Mon Sep 17 00:00:00 2001 From: Peter Thestrup Waade Date: Fri, 10 Nov 2023 10:50:17 +0100 Subject: [PATCH 01/16] fixed deprecated function --- test/test_fit_model.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_fit_model.jl b/test/test_fit_model.jl index 8b4d164..3d8ee28 100644 --- a/test/test_fit_model.jl +++ b/test/test_fit_model.jl @@ -106,7 +106,7 @@ using Turing ) test_param_priors = Dict( - "softmax_action_precision" => Truncated(Normal(100, 20), 0, Inf), + "softmax_action_precision" => truncated(Normal(100, 20), 0, Inf), ("x2", "volatility") => Normal(-7, 5), ) From 6a0a057fd62ad70753fc8ab349e8b9977d78804a Mon Sep 17 00:00:00 2001 From: Peter Thestrup Waade Date: Fri, 10 Nov 2023 11:54:40 +0100 Subject: [PATCH 02/16] renamed premade node names --- README.md | 12 +- .../src/Julia_src_files/fitting_hgf_models.jl | 44 +- docs/src/Julia_src_files/index.jl | 12 +- docs/src/Julia_src_files/premade_HGF.jl | 14 +- docs/src/Julia_src_files/premade_models.jl | 20 +- docs/src/Julia_src_files/utility_functions.jl | 46 +- docs/src/index.md | 12 +- docs/src/tutorials/classic_JGET.jl | 24 +- docs/src/tutorials/classic_binary.jl | 42 +- docs/src/tutorials/classic_usdchf.jl | 38 +- src/create_hgf/init_hgf.jl | 22 +- src/premade_models/premade_agents.jl | 22 +- src/premade_models/premade_hgfs.jl | 720 +++++++++--------- src/utils/get_prediction.jl | 6 +- .../canonical_continuous2level_states.csv | 2 +- test/test_canonical.jl | 76 +- test/test_fit_model.jl | 40 +- 17 files changed, 576 insertions(+), 576 deletions(-) diff --git a/README.md b/README.md index 5e548de..95ccc8d 100644 --- a/README.md +++ b/README.md @@ -55,10 +55,10 @@ get_parameters(agent) ![Image1](docs/src/images/readme/get_parameters.png) -Set a new parameter for initial precision of x2 and define some inputs +Set a new parameter for initial precision of xprob and define some inputs ````@example index -set_parameters!(agent, ("x2", 
"initial_precision"), 0.9) +set_parameters!(agent, ("xprob", "initial_precision"), 0.9) inputs = [1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0]; nothing #hide ```` @@ -75,23 +75,23 @@ actions = give_inputs!(agent, inputs) using StatsPlots using Plots plot_trajectory(agent, ("u", "input_value")) -plot_trajectory!(agent, ("x1", "prediction")) +plot_trajectory!(agent, ("x", "prediction")) ```` ![Image1](docs/src/images/readme/plot_trajectory.png) -Plot state trajectory of input value, action and prediction of x1 +Plot state trajectory of input value, action and prediction of x ````@example index plot_trajectory(agent, ("u", "input_value")) plot_trajectory!(agent, "action") -plot_trajectory!(agent, ("x1", "prediction")) +plot_trajectory!(agent, ("x", "prediction")) ```` ![Image1](docs/src/images/readme/plot_trajectory_2.png) ### Fitting parameters ````@example index using Distributions -prior = Dict(("x2", "volatility") => Normal(1, 0.5)) +prior = Dict(("xprob", "volatility") => Normal(1, 0.5)) model = fit_model(agent, prior, inputs, actions, n_iterations = 20) ```` diff --git a/docs/src/Julia_src_files/fitting_hgf_models.jl b/docs/src/Julia_src_files/fitting_hgf_models.jl index 3970a92..935b197 100644 --- a/docs/src/Julia_src_files/fitting_hgf_models.jl +++ b/docs/src/Julia_src_files/fitting_hgf_models.jl @@ -47,14 +47,14 @@ using HierarchicalGaussianFiltering hgf_parameters = Dict( ("u", "category_means") => Real[0.0, 1.0], ("u", "input_precision") => Inf, - ("x2", "volatility") => -2.5, - ("x2", "initial_mean") => 0, - ("x2", "initial_precision") => 1, - ("x3", "volatility") => -6.0, - ("x3", "initial_mean") => 1, - ("x3", "initial_precision") => 1, - ("x1", "x2", "value_coupling") => 1.0, - ("x2", "x3", "volatility_coupling") => 1.0, + ("xprob", "volatility") => -2.5, + ("xprob", "initial_mean") => 0, + ("xprob", "initial_precision") => 1, + ("xvol", "volatility") => -6.0, + ("xvol", "initial_mean") => 1, + ("xvol", "initial_precision") => 
1, + ("xbin", "xprob", "value_coupling") => 1.0, + ("xprob", "xvol", "volatility_coupling") => 1.0, ) hgf = premade_hgf("binary_3level", hgf_parameters, verbose = false) @@ -76,7 +76,7 @@ actions = give_inputs!(agent, inputs) using StatsPlots using Plots plot_trajectory(agent, ("u", "input_value")) -plot_trajectory!(agent, ("x1", "prediction")) +plot_trajectory!(agent, ("xbin", "prediction")) @@ -84,23 +84,23 @@ plot_trajectory!(agent, ("x1", "prediction")) # We define a set of fixed parameters to use in this fitting process: -# Set fixed parameters. We choose to fit the evolution rate of the x2 node. +# Set fixed parameters. We choose to fit the evolution rate of the xprob node. fixed_parameters = Dict( "sigmoid_action_precision" => 5, ("u", "category_means") => Real[0.0, 1.0], ("u", "input_precision") => Inf, - ("x2", "initial_mean") => 0, - ("x2", "initial_precision") => 1, - ("x3", "initial_mean") => 1, - ("x3", "initial_precision") => 1, - ("x1", "x2", "value_coupling") => 1.0, - ("x2", "x3", "volatility_coupling") => 1.0, - ("x3", "volatility") => -6.0, + ("xprob", "initial_mean") => 0, + ("xprob", "initial_precision") => 1, + ("xvol", "initial_mean") => 1, + ("xvol", "initial_precision") => 1, + ("xbin", "xprob", "value_coupling") => 1.0, + ("xprob", "xvol", "volatility_coupling") => 1.0, + ("xvol", "volatility") => -6.0, ); -# As you can read from the fixed parameters, the evolution rate of x2 is not configured. We set the prior for the x2 evolution rate: +# As you can read from the fixed parameters, the evolution rate of xprob is not configured. We set the prior for the xprob evolution rate: using Distributions -param_priors = Dict(("x2", "volatility") => Normal(-3.0, 0.5)); +param_priors = Dict(("xprob", "volatility") => Normal(-3.0, 0.5)); # We can fit the evolution rate by inputting the variables: @@ -132,7 +132,7 @@ plot_parameter_distribution(fitted_model, param_priors) # We will provide a code example of prior and posterior predictive simulation. 
We can fit a different parameter, and start with a prior predictive check. # Set prior we wish to simulate over -param_priors = Dict(("x3", "initial_precision") => Normal(1.0, 0.5)); +param_priors = Dict(("xvol", "initial_precision") => Normal(1.0, 0.5)); # When we look at our predictive simulation plot we should aim to see actions in the plausible space they could be in. # Prior predictive plot @@ -140,7 +140,7 @@ plot_predictive_simulation( param_priors, agent, inputs, - ("x1", "prediction_mean"), + ("xbin", "prediction_mean"), n_simulations = 100, ) @@ -157,7 +157,7 @@ plot_predictive_simulation( fitted_model, agent, inputs, - ("x1", "prediction_mean"), + ("xbin", "prediction_mean"), n_simulations = 100, ) diff --git a/docs/src/Julia_src_files/index.jl b/docs/src/Julia_src_files/index.jl index f079a20..bef750d 100644 --- a/docs/src/Julia_src_files/index.jl +++ b/docs/src/Julia_src_files/index.jl @@ -31,8 +31,8 @@ get_states(agent) #- get_parameters(agent) -# Set a new parameter for initial precision of x2 and define some inputs -set_parameters!(agent, ("x2", "initial_precision"), 0.9) +# Set a new parameter for initial precision of xprob and define some inputs +set_parameters!(agent, ("xprob", "initial_precision"), 0.9) inputs = [1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0]; # ### Give inputs to the agent @@ -42,18 +42,18 @@ actions = give_inputs!(agent, inputs) using StatsPlots using Plots plot_trajectory(agent, ("u", "input_value")) -plot_trajectory!(agent, ("x1", "prediction")) +plot_trajectory!(agent, ("xbin", "prediction")) -# Plot state trajectory of input value, action and prediction of x1 +# Plot state trajectory of input value, action and prediction of xbin plot_trajectory(agent, ("u", "input_value")) plot_trajectory!(agent, "action") -plot_trajectory!(agent, ("x1", "prediction")) +plot_trajectory!(agent, ("xbin", "prediction")) # ### Fitting parameters using Distributions -prior = Dict(("x2", "volatility") => Normal(1, 0.5)) 
+prior = Dict(("xprob", "volatility") => Normal(1, 0.5)) model = fit_model(agent, prior, inputs, actions, n_iterations = 20) diff --git a/docs/src/Julia_src_files/premade_HGF.jl b/docs/src/Julia_src_files/premade_HGF.jl index 8a660c0..f243c12 100644 --- a/docs/src/Julia_src_files/premade_HGF.jl +++ b/docs/src/Julia_src_files/premade_HGF.jl @@ -61,7 +61,7 @@ agent_continuous_2_level = give_inputs!(agent_continuous_2_level, inputs_continuous); plot_trajectory( agent_continuous_2_level, - "x2", + "xvol", color = "blue", size = (1300, 500), xlims = (0, 615), @@ -88,7 +88,7 @@ agent_JGET = premade_agent("hgf_gaussian_action", JGET, verbose = false); give_inputs!(agent_JGET, inputs_continuous); plot_trajectory( agent_JGET, - "x2", + "xvol", color = "blue", size = (1300, 500), xlims = (0, 615), @@ -115,12 +115,12 @@ agent_binary_2_level = # Evolve agent plot trajetories give_inputs!(agent_binary_2_level, inputs_binary); plot_trajectory(agent_binary_2_level, ("u", "input_value")) -plot_trajectory!(agent_binary_2_level, ("x1", "prediction")) +plot_trajectory!(agent_binary_2_level, ("xbin", "prediction")) #- -plot_trajectory(agent_binary_2_level, ("x2", "posterior")) +plot_trajectory(agent_binary_2_level, ("xprob", "posterior")) # ## Binary 3-level HGF @@ -139,13 +139,13 @@ agent_binary_3_level = # Evolve agent plot trajetories give_inputs!(agent_binary_3_level, inputs_binary); plot_trajectory(agent_binary_3_level, ("u", "input_value")) -plot_trajectory!(agent_binary_3_level, ("x1", "prediction")) +plot_trajectory!(agent_binary_3_level, ("xbin", "prediction")) #- -plot_trajectory(agent_binary_3_level, ("x2", "posterior")) +plot_trajectory(agent_binary_3_level, ("xprob", "posterior")) #- -plot_trajectory(agent_binary_3_level, ("x3", "posterior")) +plot_trajectory(agent_binary_3_level, ("xvol", "posterior")) diff --git a/docs/src/Julia_src_files/premade_models.jl b/docs/src/Julia_src_files/premade_models.jl index 54f4222..e75622b 100644 --- 
a/docs/src/Julia_src_files/premade_models.jl +++ b/docs/src/Julia_src_files/premade_models.jl @@ -15,7 +15,7 @@ # This premade agent model can be found as "hgf_gaussian_action" in the package. The Action distribution is a gaussian distribution with mean of the target state from the chosen HGF, and the standard deviation consisting of the action precision parameter inversed. #md # - Default hgf: contionus_2level -# - Default Target state: (x1, posterior mean) +# - Default Target state: (x, posterior mean) # - Default Parameters: gaussian action precision = 1 # ## HGF Binary Softmax agent @@ -23,7 +23,7 @@ # The action distribution is a Bernoulli distribution, and the parameter is action probability. Action probability is calculated using a softmax on the action precision parameter and the target value from the HGF. #md # - Default hgf: binary_3level -# - Default target state; (x1, prediction mean) +# - Default target state; (xbin, prediction mean) # - Default parameters: softmax action precision = 1 # ## HGF unit square sigmoid agent @@ -31,7 +31,7 @@ # The action distribution is Bernoulli distribution with the parameter beinga a softmax of the target value and action precision. # - Default hgf: binary_3level -# - Default target state; (x1, prediction mean) +# - Default target state; (xbin, prediction mean) # - Default parameters: softmax action precision = 1 @@ -40,7 +40,7 @@ # The action distribution is a categorical distribution. The action model takes the target node from the HGF, and takes out the prediction state. This state is a vector of values for each category. 
The vector is the only thing used in the categorical distribution # - Default hgf: categorical_3level -# - Default target state: Target categorical node x1 +# - Default target state: Target categorical node xcat # - Default parameters: none # ## Using premade agents @@ -61,21 +61,21 @@ agent = premade_agent("hgf_binary_softmax_action") get_parameters(agent) # Get specific parameter in agent: -get_parameters(agent, ("x3", "initial_precision")) +get_parameters(agent, ("xvol", "initial_precision")) # Get all states in an agent: get_states(agent) # Get specific state in an agent: -get_states(agent, ("x1", "posterior_precision")) +get_states(agent, ("xbin", "posterior_precision")) # Set a parameter value -set_parameters!(agent, ("x3", "initial_precision"), 0.4) +set_parameters!(agent, ("xvol", "initial_precision"), 0.4) # Set multiple parameter values set_parameters!( agent, - Dict(("x3", "initial_precision") => 1, ("x3", "volatility") => 0), + Dict(("xvol", "initial_precision") => 1, ("xvol", "volatility") => 0), ) @@ -88,7 +88,7 @@ input = [1, 0, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 0] actions = give_inputs!(agent, input) # Get the history of a single state in the agent -get_history(agent, ("x1", "prediction_mean")) +get_history(agent, ("xbin", "prediction_mean")) # We can plot the input and prediciton means with plot trajectory. Notice, when using plot_trajectory!() you can layer plots. 
@@ -98,7 +98,7 @@ using Plots plot_trajectory(agent, ("u", "input_value")) # Let's add prediction mean on top of the plot -plot_trajectory!(agent, ("x1", "prediction_mean")) +plot_trajectory!(agent, ("xbin", "prediction_mean")) # ### Overview of functions diff --git a/docs/src/Julia_src_files/utility_functions.jl b/docs/src/Julia_src_files/utility_functions.jl index bc5324f..3cb6a2d 100644 --- a/docs/src/Julia_src_files/utility_functions.jl +++ b/docs/src/Julia_src_files/utility_functions.jl @@ -31,10 +31,10 @@ agent = premade_agent("hgf_binary_softmax_action") get_parameters(agent) # getting couplings -# ERROR WITH THIS get_parameters(agent, ("x2", "x3", "volatility_coupling")) +get_parameters(agent, ("xprob", "xvol", "volatility_coupling")) # getting multiple parameters specify them in a vector -get_parameters(agent, [("x3", "volatility"), ("x3", "initial_precision")]) +get_parameters(agent, [("xvol", "volatility"), ("xvol", "initial_precision")]) # ### Getting States @@ -43,10 +43,10 @@ get_parameters(agent, [("x3", "volatility"), ("x3", "initial_precision")]) get_states(agent) #getting a single state -get_states(agent, ("x2", "posterior_precision")) +get_states(agent, ("xprob", "posterior_precision")) #getting multiple states -get_states(agent, [("x2", "posterior_precision"), ("x2", "volatility_weighted_prediction_precision")]) +get_states(agent, [("xprob", "posterior_precision"), ("xprob", "volatility_weighted_prediction_precision")]) # ### Setting Parameters @@ -61,14 +61,14 @@ agent_parameter = Dict("sigmoid_action_precision" => 3) hgf_parameters = Dict( ("u", "category_means") => Real[0.0, 1.0], ("u", "input_precision") => Inf, - ("x2", "volatility") => -2.5, - ("x2", "initial_mean") => 0, - ("x2", "initial_precision") => 1, - ("x3", "volatility") => -6.0, - ("x3", "initial_mean") => 1, - ("x3", "initial_precision") => 1, - ("x1", "x2", "value_coupling") => 1.0, - ("x2", "x3", "volatility_coupling") => 1.0, + ("xprob", "volatility") => -2.5, + ("xprob", 
"initial_mean") => 0, + ("xprob", "initial_precision") => 1, + ("xvol", "volatility") => -6.0, + ("xvol", "initial_mean") => 1, + ("xvol", "initial_precision") => 1, + ("xbin", "xprob", "value_coupling") => 1.0, + ("xprob", "xvol", "volatility_coupling") => 1.0, ) hgf = premade_hgf("binary_3level", hgf_parameters) @@ -79,13 +79,13 @@ agent = premade_agent("hgf_unit_square_sigmoid_action", hgf, agent_parameter) # Changing a single parameter -set_parameters!(agent, ("x3", "initial_precision"), 4) +set_parameters!(agent, ("xvol", "initial_precision"), 4) # Changing multiple parameters set_parameters!( agent, - Dict(("x3", "initial_precision") => 5, ("x1", "x2", "value_coupling") => 2.0), + Dict(("xvol", "initial_precision") => 5, ("xbin", "xprob", "value_coupling") => 2.0), ) # ###Giving Inputs @@ -143,12 +143,12 @@ get_history(agent) #- # getting history of single state -get_history(agent, ("x3", "posterior_precision")) +get_history(agent, ("xvol", "posterior_precision")) #- # getting history of multiple states: -get_history(agent, [("x1", "prediction_mean"), ("x3", "posterior_precision")]) +get_history(agent, [("xbin", "prediction_mean"), ("xvol", "posterior_precision")]) # ### Plotting State Trajectories @@ -158,29 +158,29 @@ using Plots plot_trajectory(agent, ("u", "input_value")) #Adding state trajectory on top -plot_trajectory!(agent, ("x1", "prediction")) +plot_trajectory!(agent, ("xbin", "prediction")) # Plotting more individual states: -## Plot posterior of x2 -plot_trajectory(agent, ("x2", "posterior")) +## Plot posterior of xprob +plot_trajectory(agent, ("xprob", "posterior")) #- -## Plot posterior of x3 -plot_trajectory(agent, ("x3", "posterior")) +## Plot posterior of xvol +plot_trajectory(agent, ("xvol", "posterior")) # ### Getting Predictions -# You can specify an HGF or an agent in the funciton. The default node to extract is the node "x1" which is the first level node in every premade HGF structure. 
+# You can specify an HGF or an agent in the funciton. # get prediction of the last state get_prediction(agent) #specify another node to get predictions from: -get_prediction(agent, "x2") +get_prediction(agent, "xprob") # ### Getting Purprise diff --git a/docs/src/index.md b/docs/src/index.md index c7e3555..7cb3a58 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -49,10 +49,10 @@ get_states(agent) get_parameters(agent) ```` -Set a new parameter for initial precision of x2 and define some inputs +Set a new parameter for initial precision of xprob and define some inputs ````@example index -set_parameters!(agent, ("x2", "initial_precision"), 0.9) +set_parameters!(agent, ("xprob", "initial_precision"), 0.9) inputs = [1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0]; nothing #hide ```` @@ -69,22 +69,22 @@ actions = give_inputs!(agent, inputs) using StatsPlots using Plots plot_trajectory(agent, ("u", "input_value")) -plot_trajectory!(agent, ("x1", "prediction")) +plot_trajectory!(agent, ("x", "prediction")) ```` -Plot state trajectory of input value, action and prediction of x1 +Plot state trajectory of input value, action and prediction of x ````@example index plot_trajectory(agent, ("u", "input_value")) plot_trajectory!(agent, "action") -plot_trajectory!(agent, ("x1", "prediction")) +plot_trajectory!(agent, ("x", "prediction")) ```` ### Fitting parameters ````@example index using Distributions -prior = Dict(("x2", "volatility") => Normal(1, 0.5)) +prior = Dict(("xprob", "volatility") => Normal(1, 0.5)) model = fit_model(agent, prior, inputs, actions, n_iterations = 20) ```` diff --git a/docs/src/tutorials/classic_JGET.jl b/docs/src/tutorials/classic_JGET.jl index 0aae4de..b5f5cec 100644 --- a/docs/src/tutorials/classic_JGET.jl +++ b/docs/src/tutorials/classic_JGET.jl @@ -14,12 +14,12 @@ agent = premade_agent("hgf_gaussian_action", hgf) #Set parameters parameters = Dict( "gaussian_action_precision" => 1, - ("x1", "volatility") => -8, - 
("x2", "volatility") => -5, - ("x3", "volatility") => -5, - ("x4", "volatility") => -5, - ("x1", "x2", "volatility_coupling") => 1, - ("x3", "x4", "volatility_coupling") => 1, + ("x", "volatility") => -8, + ("xvol", "volatility") => -5, + ("xnoise", "volatility") => -5, + ("xnoise_vol", "volatility") => -5, + ("x", "xvol", "volatility_coupling") => 1, + ("xnoise", "xnoise_vol", "volatility_coupling") => 1, ) set_parameters!(agent, parameters) @@ -28,14 +28,14 @@ inputs = data[(data.ID.==20).&(data.session.==1), :].outcome actions = give_inputs!(agent, inputs); #Plot belief trajectories plot_trajectory(agent, "u") -plot_trajectory!(agent, "x1") -plot_trajectory(agent, "x2") -plot_trajectory(agent, "x3") -plot_trajectory(agent, "x4") +plot_trajectory!(agent, "x") +plot_trajectory(agent, "xvol") +plot_trajectory(agent, "xnoise") +plot_trajectory(agent, "xnoise_vol") priors = Dict( "gaussian_action_precision" => LogNormal(-1, 0.1), - ("x1", "volatility") => Normal(-8, 1), + ("x", "volatility") => Normal(-8, 1), ) data_subset = data[(data.ID.∈[[20, 21]]).&(data.session.∈[[1, 2]]), :] @@ -69,4 +69,4 @@ reset!(agent) give_inputs!(agent, inputs) -get_history(agent, ("x1", "value_prediction_error")) +get_history(agent, ("x", "value_prediction_error")) diff --git a/docs/src/tutorials/classic_binary.jl b/docs/src/tutorials/classic_binary.jl index d29f30d..efd5d9c 100644 --- a/docs/src/tutorials/classic_binary.jl +++ b/docs/src/tutorials/classic_binary.jl @@ -23,14 +23,14 @@ inputs = CSV.read(data_path * "classic_binary_inputs.csv", DataFrame)[!, 1]; hgf_parameters = Dict( ("u", "category_means") => Real[0.0, 1.0], ("u", "input_precision") => Inf, - ("x2", "volatility") => -2.5, - ("x2", "initial_mean") => 0, - ("x2", "initial_precision") => 1, - ("x3", "volatility") => -6.0, - ("x3", "initial_mean") => 1, - ("x3", "initial_precision") => 1, - ("x1", "x2", "value_coupling") => 1.0, - ("x2", "x3", "volatility_coupling") => 1.0, + ("xprob", "volatility") => -2.5, + ("xprob", 
"initial_mean") => 0, + ("xprob", "initial_precision") => 1, + ("xvol", "volatility") => -6.0, + ("xvol", "initial_mean") => 1, + ("xvol", "initial_precision") => 1, + ("xbin", "xprob", "value_coupling") => 1.0, + ("xprob", "xvol", "volatility_coupling") => 1.0, ); hgf = premade_hgf("binary_3level", hgf_parameters, verbose = false); @@ -45,37 +45,37 @@ actions = give_inputs!(agent, inputs); # Plot the trajectory of the agent plot_trajectory(agent, ("u", "input_value")) -plot_trajectory!(agent, ("x1", "prediction")) +plot_trajectory!(agent, ("xbin", "prediction")) # - -plot_trajectory(agent, ("x2", "posterior")) -plot_trajectory(agent, ("x3", "posterior")) +plot_trajectory(agent, ("xprob", "posterior")) +plot_trajectory(agent, ("xvol", "posterior")) # Set fixed parameters fixed_parameters = Dict( "sigmoid_action_precision" => 5, ("u", "category_means") => Real[0.0, 1.0], ("u", "input_precision") => Inf, - ("x2", "initial_mean") => 0, - ("x2", "initial_precision") => 1, - ("x3", "initial_mean") => 1, - ("x3", "initial_precision") => 1, - ("x1", "x2", "value_coupling") => 1.0, - ("x2", "x3", "volatility_coupling") => 1.0, - ("x3", "volatility") => -6.0, + ("xprob", "initial_mean") => 0, + ("xprob", "initial_precision") => 1, + ("xvol", "initial_mean") => 1, + ("xvol", "initial_precision") => 1, + ("xbin", "xprob", "value_coupling") => 1.0, + ("xprob", "xvol", "volatility_coupling") => 1.0, + ("xvol", "volatility") => -6.0, ); # Set priors for parameter recovery -param_priors = Dict(("x2", "volatility") => Normal(-3.0, 0.5)); +param_priors = Dict(("xprob", "volatility") => Normal(-3.0, 0.5)); #- # Prior predictive plot plot_predictive_simulation( param_priors, agent, inputs, - ("x1", "prediction_mean"), + ("xbin", "prediction_mean"), n_simulations = 100, ) #- @@ -104,6 +104,6 @@ plot_predictive_simulation( fitted_model, agent, inputs, - ("x1", "prediction_mean"), + ("xbin", "prediction_mean"), n_simulations = 3, ) diff --git a/docs/src/tutorials/classic_usdchf.jl 
b/docs/src/tutorials/classic_usdchf.jl index bb305a6..e9a9141 100644 --- a/docs/src/tutorials/classic_usdchf.jl +++ b/docs/src/tutorials/classic_usdchf.jl @@ -28,15 +28,15 @@ agent = premade_agent("hgf_gaussian_action", hgf, verbose = false); # Set parameters for parameter recover parameters = Dict( - ("u", "x1", "value_coupling") => 1.0, - ("x1", "x2", "volatility_coupling") => 1.0, + ("u", "x", "value_coupling") => 1.0, + ("x", "xvol", "volatility_coupling") => 1.0, ("u", "input_noise") => -log(1e4), - ("x1", "volatility") => -13, - ("x2", "volatility") => -2, - ("x1", "initial_mean") => 1.04, - ("x1", "initial_precision") => 1 / (0.0001), - ("x2", "initial_mean") => 1.0, - ("x2", "initial_precision") => 1 / 0.1, + ("x", "volatility") => -13, + ("xvol", "volatility") => -2, + ("x", "initial_mean") => 1.04, + ("x", "initial_precision") => 1 / (0.0001), + ("xvol", "initial_mean") => 1.0, + ("xvol", "initial_precision") => 1 / 0.1, "gaussian_action_precision" => 100, ); @@ -59,7 +59,7 @@ plot_trajectory( xlabel = "Trading days since 1 January 2010", ) #- -plot_trajectory!(agent, ("x1", "posterior"), color = "red") +plot_trajectory!(agent, ("x", "posterior"), color = "red") plot_trajectory!( agent, "action", @@ -71,7 +71,7 @@ plot_trajectory!( #- plot_trajectory( agent, - "x2", + "xvol", color = "blue", size = (1300, 500), xlims = (0, 615), @@ -81,19 +81,19 @@ plot_trajectory( #- # Set priors for fitting fixed_parameters = Dict( - ("u", "x1", "value_coupling") => 1.0, - ("x1", "x2", "volatility_coupling") => 1.0, - ("x1", "initial_mean") => 0, - ("x1", "initial_precision") => 2000, - ("x2", "initial_mean") => 1.0, - ("x2", "initial_precision") => 600.0, + ("u", "x", "value_coupling") => 1.0, + ("x", "xvol", "volatility_coupling") => 1.0, + ("x", "initial_mean") => 0, + ("x", "initial_precision") => 2000, + ("xvol", "initial_mean") => 1.0, + ("xvol", "initial_precision") => 600.0, "gaussian_action_precision" => 100, ); param_priors = Dict( ("u", "input_noise") => 
Normal(-6, 1), - ("x1", "volatility") => Normal(-4, 1), - ("x2", "volatility") => Normal(-4, 1), + ("x", "volatility") => Normal(-4, 1), + ("xvol", "volatility") => Normal(-4, 1), ); #- # Prior predictive simulation plot @@ -101,7 +101,7 @@ plot_predictive_simulation( param_priors, agent, inputs, - ("x1", "posterior_mean"); + ("x", "posterior_mean"); n_simulations = 100, ) #- diff --git a/src/create_hgf/init_hgf.jl b/src/create_hgf/init_hgf.jl index b348647..b9bdc80 100644 --- a/src/create_hgf/init_hgf.jl +++ b/src/create_hgf/init_hgf.jl @@ -34,14 +34,14 @@ input_nodes = Dict( #List of state nodes state_nodes = [ Dict( - "name" => "x1", + "name" => "x", "type" => "continuous", "volatility" => -2, "initial_mean" => 0, "initial_precision" => 1, ), Dict( - "name" => "x2", + "name" => "xvol", "type" => "continuous", "volatility" => -2, "initial_mean" => 0, @@ -53,11 +53,11 @@ state_nodes = [ edges = [ Dict( "child" => "u", - "value_parents" => ("x1", 1), + "value_parents" => ("x", 1), ), Dict( - "child" => "x1", - "volatility_parents" => ("x2", 1), + "child" => "x", + "volatility_parents" => ("xvol", 1), ), ] @@ -86,8 +86,8 @@ input_nodes = [ ] state_nodes = [ - "x1", - "x2", + "x", + "xvol", "x3", "x4", ] @@ -95,19 +95,19 @@ state_nodes = [ edges = [ Dict( "child" => "u1", - "value_parents" => ["x1", "x2"], + "value_parents" => ["x", "xvol"], "volatility_parents" => "x3" ), Dict( "child" => "u2", - "value_parents" => ["x1"], + "value_parents" => ["x"], ), Dict( - "child" => "x1", + "child" => "x", "volatility_parents" => "x4", ), Dict( - "child" => "x2", + "child" => "xvol", "volatility_parents" => "x4", ), ] diff --git a/src/premade_models/premade_agents.jl b/src/premade_models/premade_agents.jl index b128a12..ffb2d8d 100644 --- a/src/premade_models/premade_agents.jl +++ b/src/premade_models/premade_agents.jl @@ -69,7 +69,7 @@ function premade_hgf_multiple_actions(config::Dict) if "gaussian_target_state" in keys(config) settings["gaussian_target_state"] = 
config["gaussian_target_state"] else - default_target_state = ("x1", "posterior_mean") + default_target_state = ("x", "posterior_mean") settings["gaussian_target_state"] = default_target_state @warn "setting gaussian_target_state was not set by the user. Using the default: $default_target_state" end @@ -90,7 +90,7 @@ function premade_hgf_multiple_actions(config::Dict) if "softmax_target_state" in keys(config) settings["softmax_target_state"] = config["softmax_target_state"] else - default_target_state = ("x1", "prediction_mean") + default_target_state = ("xbin", "prediction_mean") settings["softmax_target_state"] = default_target_state @warn "setting softmax_target_state was not set by the user. Using the default: $default_target_state" end @@ -111,7 +111,7 @@ function premade_hgf_multiple_actions(config::Dict) if "sigmoid_target_state" in keys(config) settings["sigmoid_target_state"] = config["sigmoid_target_state"] else - default_target_state = ("x1", "prediction_mean") + default_target_state = ("xbin", "prediction_mean") settings["sigmoid_target_state"] = default_target_state @warn "setting sigmoid_target_state was not set by the user. Using the default: $default_target_state" end @@ -144,7 +144,7 @@ Create an agent suitable for the HGF Gaussian action model. # Config defaults: - "HGF": "continuous_2level" - "gaussian_action_precision": 1 - - "target_state": ("x1", "posterior_mean") + - "target_state": ("x", "posterior_mean") """ function premade_hgf_gaussian(config::Dict) @@ -153,7 +153,7 @@ function premade_hgf_gaussian(config::Dict) #Default parameters and settings defaults = Dict( "gaussian_action_precision" => 1, - "target_state" => ("x1", "posterior_mean"), + "target_state" => ("x", "posterior_mean"), "HGF" => "continuous_2level", ) @@ -205,7 +205,7 @@ Create an agent suitable for the HGF binary softmax model. 
# Config defaults: - "HGF": "binary_3level" - "softmax_action_precision": 1 - - "target_state": ("x1", "prediction_mean") + - "target_state": ("xbin", "prediction_mean") """ function premade_hgf_binary_softmax(config::Dict) @@ -214,7 +214,7 @@ function premade_hgf_binary_softmax(config::Dict) #Default parameters and settings defaults = Dict( "softmax_action_precision" => 1, - "target_state" => ("x1", "prediction_mean"), + "target_state" => ("xbin", "prediction_mean"), "HGF" => "binary_3level", ) @@ -266,7 +266,7 @@ Create an agent suitable for the HGF unit square sigmoid model. # Config defaults: - "HGF": "binary_3level" - "sigmoid_action_precision": 1 - - "target_state": ("x1", "prediction_mean") + - "target_state": ("xbin", "prediction_mean") """ function premade_hgf_unit_square_sigmoid(config::Dict) @@ -275,7 +275,7 @@ function premade_hgf_unit_square_sigmoid(config::Dict) #Default parameters and settings defaults = Dict( "sigmoid_action_precision" => 1, - "target_state" => ("x1", "prediction_mean"), + "target_state" => ("xbin", "prediction_mean"), "HGF" => "binary_3level", ) @@ -326,14 +326,14 @@ Create an agent suitable for the HGF predict category model. 
# Config defaults: - "HGF": "categorical_3level" - - "target_categorical_node": "x1" + - "target_categorical_node": "xcat" """ function premade_hgf_predict_category(config::Dict) ## Combine defaults and user settings #Default parameters and settings - defaults = Dict("target_categorical_node" => "x1", "HGF" => "categorical_3level") + defaults = Dict("target_categorical_node" => "xcat", "HGF" => "categorical_3level") #If there is no HGF in the user-set parameters if !("HGF" in keys(config)) diff --git a/src/premade_models/premade_hgfs.jl b/src/premade_models/premade_hgfs.jl index fc1372f..6c1f8fd 100644 --- a/src/premade_models/premade_hgfs.jl +++ b/src/premade_models/premade_hgfs.jl @@ -2,18 +2,18 @@ premade_continuous_2level(config::Dict; verbose::Bool = true) The standard 2 level continuous HGF, which filters a continuous input. -It has a continous input node u, with a single value parent x1, which in turn has a single volatility parent x2. +It has a continous input node u, with a single value parent x, which in turn has a single volatility parent xvol. 
# Config defaults: - ("u", "input_noise"): -2 - - ("x1", "volatility"): -2 - - ("x2", "volatility"): -2 - - ("u", "x1", "value_coupling"): 1 - - ("x1", "x2", "volatility_coupling"): 1 - - ("x1", "initial_mean"): 0 - - ("x1", "initial_precision"): 1 - - ("x2", "initial_mean"): 0 - - ("x2", "initial_precision"): 1 + - ("x", "volatility"): -2 + - ("xvol", "volatility"): -2 + - ("u", "x", "value_coupling"): 1 + - ("x", "xvol", "volatility_coupling"): 1 + - ("x", "initial_mean"): 0 + - ("x", "initial_precision"): 1 + - ("xvol", "initial_mean"): 0 + - ("xvol", "initial_precision"): 1 """ function premade_continuous_2level(config::Dict; verbose::Bool = true) @@ -21,22 +21,22 @@ function premade_continuous_2level(config::Dict; verbose::Bool = true) spec_defaults = Dict( ("u", "input_noise") => -2, - ("x1", "volatility") => -2, - ("x1", "drift") => 0, - ("x1", "autoregression_target") => 0, - ("x1", "autoregression_strength") => 0, - ("x1", "initial_mean") => 0, - ("x1", "initial_precision") => 1, - - ("x2", "volatility") => -2, - ("x2", "drift") => 0, - ("x2", "autoregression_target") => 0, - ("x2", "autoregression_strength") => 0, - ("x2", "initial_mean") => 0, - ("x2", "initial_precision") => 1, - - ("u", "x1", "value_coupling") => 1, - ("x1", "x2", "volatility_coupling") => 1, + ("x", "volatility") => -2, + ("x", "drift") => 0, + ("x", "autoregression_target") => 0, + ("x", "autoregression_strength") => 0, + ("x", "initial_mean") => 0, + ("x", "initial_precision") => 1, + + ("xvol", "volatility") => -2, + ("xvol", "drift") => 0, + ("xvol", "autoregression_target") => 0, + ("xvol", "autoregression_strength") => 0, + ("xvol", "initial_mean") => 0, + ("xvol", "initial_precision") => 1, + + ("u", "x", "value_coupling") => 1, + ("x", "xvol", "volatility_coupling") => 1, "update_type" => EnhancedUpdate(), ) @@ -60,24 +60,24 @@ function premade_continuous_2level(config::Dict; verbose::Bool = true) #List of state nodes to create state_nodes = [ Dict( - "name" => "x1", + "name" 
=> "x", "type" => "continuous", - "volatility" => config[("x1", "volatility")], - "drift" => config[("x1", "drift")], - "autoregression_target" => config[("x1", "autoregression_target")], - "autoregression_strength" => config[("x1", "autoregression_strength")], - "initial_mean" => config[("x1", "initial_mean")], - "initial_precision" => config[("x1", "initial_precision")], + "volatility" => config[("x", "volatility")], + "drift" => config[("x", "drift")], + "autoregression_target" => config[("x", "autoregression_target")], + "autoregression_strength" => config[("x", "autoregression_strength")], + "initial_mean" => config[("x", "initial_mean")], + "initial_precision" => config[("x", "initial_precision")], ), Dict( - "name" => "x2", + "name" => "xvol", "type" => "continuous", - "volatility" => config[("x2", "volatility")], - "drift" => config[("x2", "drift")], - "autoregression_target" => config[("x2", "autoregression_target")], - "autoregression_strength" => config[("x2", "autoregression_strength")], - "initial_mean" => config[("x2", "initial_mean")], - "initial_precision" => config[("x2", "initial_precision")], + "volatility" => config[("xvol", "volatility")], + "drift" => config[("xvol", "drift")], + "autoregression_target" => config[("xvol", "autoregression_target")], + "autoregression_strength" => config[("xvol", "autoregression_strength")], + "initial_mean" => config[("xvol", "initial_mean")], + "initial_precision" => config[("xvol", "initial_precision")], ), ] @@ -85,11 +85,11 @@ function premade_continuous_2level(config::Dict; verbose::Bool = true) edges = [ Dict( "child" => "u", - "value_parents" => ("x1", config[("u", "x1", "value_coupling")]), + "value_parents" => ("x", config[("u", "x", "value_coupling")]), ), Dict( - "child" => "x1", - "volatility_parents" => ("x2", config[("x1", "x2", "volatility_coupling")]), + "child" => "x", + "volatility_parents" => ("xvol", config[("x", "xvol", "volatility_coupling")]), ), ] @@ -107,26 +107,26 @@ end """ 
premade_JGET(config::Dict; verbose::Bool = true) -The HGF used in the JGET model. It has a single continuous input node u, with a value parent x1, and a volatility parent x3. x1 has volatility parent x2, and x3 has a volatility parent x4. +The HGF used in the JGET model. It has a single continuous input node u, with a value parent x, and a volatility parent xnoise. x has volatility parent xvol, and xnoise has a volatility parent xnoise_vol. # Config defaults: - ("u", "input_noise"): -2 - - ("x1", "volatility"): -2 - - ("x2", "volatility"): -2 - - ("x3", "volatility"): -2 - - ("x4", "volatility"): -2 - - ("u", "x1", "value_coupling"): 1 - - ("u", "x3", "value_coupling"): 1 - - ("x1", "x2", "volatility_coupling"): 1 - - ("x3", "x4", "volatility_coupling"): 1 - - ("x1", "initial_mean"): 0 - - ("x1", "initial_precision"): 1 - - ("x2", "initial_mean"): 0 - - ("x2", "initial_precision"): 1 - - ("x3", "initial_mean"): 0 - - ("x3", "initial_precision"): 1 - - ("x4", "initial_mean"): 0 - - ("x4", "initial_precision"): 1 + - ("x", "volatility"): -2 + - ("xvol", "volatility"): -2 + - ("xnoise", "volatility"): -2 + - ("xnoise_vol", "volatility"): -2 + - ("u", "x", "value_coupling"): 1 + - ("u", "xnoise", "volatility_coupling"): 1 + - ("x", "xvol", "volatility_coupling"): 1 + - ("xnoise", "xnoise_vol", "volatility_coupling"): 1 + - ("x", "initial_mean"): 0 + - ("x", "initial_precision"): 1 + - ("xvol", "initial_mean"): 0 + - ("xvol", "initial_precision"): 1 + - ("xnoise", "initial_mean"): 0 + - ("xnoise", "initial_precision"): 1 + - ("xnoise_vol", "initial_mean"): 0 + - ("xnoise_vol", "initial_precision"): 1 """ function premade_JGET(config::Dict; verbose::Bool = true) @@ -134,38 +134,38 @@ function premade_JGET(config::Dict; verbose::Bool = true) spec_defaults = Dict( ("u", "input_noise") => -2, - ("x1", "volatility") => -2, - ("x1", "drift") => 0, - ("x1", "autoregression_target") => 0, - ("x1", "autoregression_strength") => 0, - ("x1", "initial_mean") => 0, - ("x1", 
"initial_precision") => 1, - - ("x2", "volatility") => -2, - ("x2", "drift") => 0, - ("x2", "autoregression_target") => 0, - ("x2", "autoregression_strength") => 0, - ("x2", "initial_mean") => 0, - ("x2", "initial_precision") => 1, - - ("x3", "volatility") => -2, - ("x3", "drift") => 0, - ("x3", "autoregression_target") => 0, - ("x3", "autoregression_strength") => 0, - ("x3", "initial_mean") => 0, - ("x3", "initial_precision") => 1, - - ("x4", "volatility") => -2, - ("x4", "drift") => 0, - ("x4", "autoregression_target") => 0, - ("x4", "autoregression_strength") => 0, - ("x4", "initial_mean") => 0, - ("x4", "initial_precision") => 1, - - ("u", "x1", "value_coupling") => 1, - ("u", "x3", "volatility_coupling") => 1, - ("x1", "x2", "volatility_coupling") => 1, - ("x3", "x4", "volatility_coupling") => 1, + ("x", "volatility") => -2, + ("x", "drift") => 0, + ("x", "autoregression_target") => 0, + ("x", "autoregression_strength") => 0, + ("x", "initial_mean") => 0, + ("x", "initial_precision") => 1, + + ("xvol", "volatility") => -2, + ("xvol", "drift") => 0, + ("xvol", "autoregression_target") => 0, + ("xvol", "autoregression_strength") => 0, + ("xvol", "initial_mean") => 0, + ("xvol", "initial_precision") => 1, + + ("xnoise", "volatility") => -2, + ("xnoise", "drift") => 0, + ("xnoise", "autoregression_target") => 0, + ("xnoise", "autoregression_strength") => 0, + ("xnoise", "initial_mean") => 0, + ("xnoise", "initial_precision") => 1, + + ("xnoise_vol", "volatility") => -2, + ("xnoise_vol", "drift") => 0, + ("xnoise_vol", "autoregression_target") => 0, + ("xnoise_vol", "autoregression_strength") => 0, + ("xnoise_vol", "initial_mean") => 0, + ("xnoise_vol", "initial_precision") => 1, + + ("u", "x", "value_coupling") => 1, + ("u", "xnoise", "volatility_coupling") => 1, + ("x", "xvol", "volatility_coupling") => 1, + ("xnoise", "xnoise_vol", "volatility_coupling") => 1, "update_type" => EnhancedUpdate(), ) @@ -189,44 +189,44 @@ function premade_JGET(config::Dict; 
verbose::Bool = true) #List of state nodes to create state_nodes = [ Dict( - "name" => "x1", + "name" => "x", "type" => "continuous", - "volatility" => config[("x1", "volatility")], - "drift" => config[("x1", "drift")], - "autoregression_target" => config[("x1", "autoregression_target")], - "autoregression_strength" => config[("x1", "autoregression_strength")], - "initial_mean" => config[("x1", "initial_mean")], - "initial_precision" => config[("x1", "initial_precision")], + "volatility" => config[("x", "volatility")], + "drift" => config[("x", "drift")], + "autoregression_target" => config[("x", "autoregression_target")], + "autoregression_strength" => config[("x", "autoregression_strength")], + "initial_mean" => config[("x", "initial_mean")], + "initial_precision" => config[("x", "initial_precision")], ), Dict( - "name" => "x2", + "name" => "xvol", "type" => "continuous", - "volatility" => config[("x2", "volatility")], - "drift" => config[("x2", "drift")], - "autoregression_target" => config[("x2", "autoregression_target")], - "autoregression_strength" => config[("x2", "autoregression_strength")], - "initial_mean" => config[("x2", "initial_mean")], - "initial_precision" => config[("x2", "initial_precision")], + "volatility" => config[("xvol", "volatility")], + "drift" => config[("xvol", "drift")], + "autoregression_target" => config[("xvol", "autoregression_target")], + "autoregression_strength" => config[("xvol", "autoregression_strength")], + "initial_mean" => config[("xvol", "initial_mean")], + "initial_precision" => config[("xvol", "initial_precision")], ), Dict( - "name" => "x3", + "name" => "xnoise", "type" => "continuous", - "volatility" => config[("x3", "volatility")], - "drift" => config[("x3", "drift")], - "autoregression_target" => config[("x3", "autoregression_target")], - "autoregression_strength" => config[("x3", "autoregression_strength")], - "initial_mean" => config[("x3", "initial_precision")], - "initial_precision" => config[("x3", 
"initial_precision")], + "volatility" => config[("xnoise", "volatility")], + "drift" => config[("xnoise", "drift")], + "autoregression_target" => config[("xnoise", "autoregression_target")], + "autoregression_strength" => config[("xnoise", "autoregression_strength")], + "initial_mean" => config[("xnoise", "initial_mean")], + "initial_precision" => config[("xnoise", "initial_precision")], ), Dict( - "name" => "x4", + "name" => "xnoise_vol", "type" => "continuous", - "volatility" => config[("x4", "volatility")], - "drift" => config[("x4", "drift")], - "autoregression_target" => config[("x4", "autoregression_target")], - "autoregression_strength" => config[("x4", "autoregression_strength")], - "initial_mean" => config[("x4", "initial_mean")], - "initial_precision" => config[("x4", "initial_precision")], + "volatility" => config[("xnoise_vol", "volatility")], + "drift" => config[("xnoise_vol", "drift")], + "autoregression_target" => config[("xnoise_vol", "autoregression_target")], + "autoregression_strength" => config[("xnoise_vol", "autoregression_strength")], + "initial_mean" => config[("xnoise_vol", "initial_mean")], + "initial_precision" => config[("xnoise_vol", "initial_precision")], ), ] @@ -234,16 +234,16 @@ function premade_JGET(config::Dict; verbose::Bool = true) edges = [ Dict( "child" => "u", - "value_parents" => ("x1", config[("u", "x1", "value_coupling")]), - "volatility_parents" => ("x3", config[("u", "x3", "volatility_coupling")]), + "value_parents" => ("x", config[("u", "x", "value_coupling")]), + "volatility_parents" => ("xnoise", config[("u", "xnoise", "volatility_coupling")]), ), Dict( - "child" => "x1", - "volatility_parents" => ("x2", config[("x1", "x2", "volatility_coupling")]), + "child" => "x", + "volatility_parents" => ("xvol", config[("x", "xvol", "volatility_coupling")]), ), Dict( - "child" => "x3", - "volatility_parents" => ("x4", config[("x3", "x4", "volatility_coupling")]), + "child" => "xnoise", + "volatility_parents" => 
("xnoise_vol", config[("xnoise", "xnoise_vol", "volatility_coupling")]), ), ] @@ -262,15 +262,15 @@ end premade_binary_2level(config::Dict; verbose::Bool = true) The standard binary 2 level HGF model, which takes a binary input, and learns the probability of either outcome. -It has one binary input node u, with a binary value parent x1, which in turn has a continuous value parent x2. +It has one binary input node u, with a binary value parent xbin, which in turn has a continuous value parent xprob. # Config defaults: - ("u", "category_means"): [0, 1] - ("u", "input_precision"): Inf - - ("x2", "volatility"): -2 - - ("x1", "x2", "value_coupling"): 1 - - ("x2", "initial_mean"): 0 - - ("x2", "initial_precision"): 1 + - ("xprob", "volatility"): -2 + - ("xbin", "xprob", "value_coupling"): 1 + - ("xprob", "initial_mean"): 0 + - ("xprob", "initial_precision"): 1 """ function premade_binary_2level(config::Dict; verbose::Bool = true) @@ -279,14 +279,14 @@ function premade_binary_2level(config::Dict; verbose::Bool = true) ("u", "category_means") => [0, 1], ("u", "input_precision") => Inf, - ("x2", "volatility") => -2, - ("x2", "drift") => 0, - ("x2", "autoregression_target") => 0, - ("x2", "autoregression_strength") => 0, - ("x2", "initial_mean") => 0, - ("x2", "initial_precision") => 1, + ("xprob", "volatility") => -2, + ("xprob", "drift") => 0, + ("xprob", "autoregression_target") => 0, + ("xprob", "autoregression_strength") => 0, + ("xprob", "initial_mean") => 0, + ("xprob", "initial_precision") => 1, - ("x1", "x2", "value_coupling") => 1, + ("xbin", "xprob", "value_coupling") => 1, "update_type" => EnhancedUpdate(), ) @@ -310,25 +310,25 @@ function premade_binary_2level(config::Dict; verbose::Bool = true) #List of state nodes to create state_nodes = [ - Dict("name" => "x1", "type" => "binary"), + Dict("name" => "xbin", "type" => "binary"), Dict( - "name" => "x2", + "name" => "xprob", "type" => "continuous", - "volatility" => config[("x2", "volatility")], - "drift" => 
config[("x2", "drift")], - "autoregression_target" => config[("x2", "autoregression_target")], - "autoregression_strength" => config[("x2", "autoregression_strength")], - "initial_mean" => config[("x2", "initial_mean")], - "initial_precision" => config[("x2", "initial_precision")], + "volatility" => config[("xprob", "volatility")], + "drift" => config[("xprob", "drift")], + "autoregression_target" => config[("xprob", "autoregression_target")], + "autoregression_strength" => config[("xprob", "autoregression_strength")], + "initial_mean" => config[("xprob", "initial_mean")], + "initial_precision" => config[("xprob", "initial_precision")], ), ] #List of child-parent relations edges = [ - Dict("child" => "u", "value_parents" => "x1"), + Dict("child" => "u", "value_parents" => "xbin"), Dict( - "child" => "x1", - "value_parents" => ("x2", config[("x1", "x2", "value_coupling")]), + "child" => "xbin", + "value_parents" => ("xprob", config[("xbin", "xprob", "value_coupling")]), ), ] @@ -347,26 +347,26 @@ end premade_binary_3level(config::Dict; verbose::Bool = true) The standard binary 3 level HGF model, which takes a binary input, and learns the probability of either outcome. -It has one binary input node u, with a binary value parent x1, which in turn has a continuous value parent x2. This then has a continunous volatility parent x3. +It has one binary input node u, with a binary value parent xbin, which in turn has a continuous value parent xprob. This then has a continunous volatility parent xvol. 
This HGF has five shared parameters: -"x2_volatility" -"x2_initial_precisions" -"x2_initial_means" -"value_couplings_x1_x2" -"volatility_couplings_x2_x3" +"xprob_volatility" +"xprob_initial_precisions" +"xprob_initial_means" +"value_couplings_xbin_xprob" +"volatility_couplings_xprob_xvol" # Config defaults: - ("u", "category_means"): [0, 1] - ("u", "input_precision"): Inf - - ("x2", "volatility"): -2 - - ("x3", "volatility"): -2 - - ("x1", "x2", "value_coupling"): 1 - - ("x2", "x3", "volatility_coupling"): 1 - - ("x2", "initial_mean"): 0 - - ("x2", "initial_precision"): 1 - - ("x3", "initial_mean"): 0 - - ("x3", "initial_precision"): 1 + - ("xprob", "volatility"): -2 + - ("xvol", "volatility"): -2 + - ("xbin", "xprob", "value_coupling"): 1 + - ("xprob", "xvol", "volatility_coupling"): 1 + - ("xprob", "initial_mean"): 0 + - ("xprob", "initial_precision"): 1 + - ("xvol", "initial_mean"): 0 + - ("xvol", "initial_precision"): 1 """ function premade_binary_3level(config::Dict; verbose::Bool = true) @@ -375,22 +375,22 @@ function premade_binary_3level(config::Dict; verbose::Bool = true) ("u", "category_means") => [0, 1], ("u", "input_precision") => Inf, - ("x2", "volatility") => -2, - ("x2", "drift") => 0, - ("x2", "autoregression_target") => 0, - ("x2", "autoregression_strength") => 0, - ("x2", "initial_mean") => 0, - ("x2", "initial_precision") => 1, + ("xprob", "volatility") => -2, + ("xprob", "drift") => 0, + ("xprob", "autoregression_target") => 0, + ("xprob", "autoregression_strength") => 0, + ("xprob", "initial_mean") => 0, + ("xprob", "initial_precision") => 1, - ("x3", "volatility") => -2, - ("x3", "drift") => 0, - ("x3", "autoregression_target") => 0, - ("x3", "autoregression_strength") => 0, - ("x3", "initial_mean") => 0, - ("x3", "initial_precision") => 1, + ("xvol", "volatility") => -2, + ("xvol", "drift") => 0, + ("xvol", "autoregression_target") => 0, + ("xvol", "autoregression_strength") => 0, + ("xvol", "initial_mean") => 0, + ("xvol", 
"initial_precision") => 1, - ("x1", "x2", "value_coupling") => 1, - ("x2", "x3", "volatility_coupling") => 1, + ("xbin", "xprob", "value_coupling") => 1, + ("xprob", "xvol", "volatility_coupling") => 1, "update_type" => EnhancedUpdate(), ) @@ -414,39 +414,39 @@ function premade_binary_3level(config::Dict; verbose::Bool = true) #List of state nodes to create state_nodes = [ - Dict("name" => "x1", "type" => "binary"), + Dict("name" => "xbin", "type" => "binary"), Dict( - "name" => "x2", + "name" => "xprob", "type" => "continuous", - "volatility" => config[("x2", "volatility")], - "drift" => config[("x2", "drift")], - "autoregression_target" => config[("x2", "autoregression_target")], - "autoregression_strength" => config[("x2", "autoregression_strength")], - "initial_mean" => config[("x2", "initial_mean")], - "initial_precision" => config[("x2", "initial_precision")], + "volatility" => config[("xprob", "volatility")], + "drift" => config[("xprob", "drift")], + "autoregression_target" => config[("xprob", "autoregression_target")], + "autoregression_strength" => config[("xprob", "autoregression_strength")], + "initial_mean" => config[("xprob", "initial_mean")], + "initial_precision" => config[("xprob", "initial_precision")], ), Dict( - "name" => "x3", + "name" => "xvol", "type" => "continuous", - "volatility" => config[("x3", "volatility")], - "drift" => config[("x3", "drift")], - "autoregression_target" => config[("x3", "autoregression_target")], - "autoregression_strength" => config[("x3", "autoregression_strength")], - "initial_mean" => config[("x3", "initial_mean")], - "initial_precision" => config[("x3", "initial_precision")], + "volatility" => config[("xvol", "volatility")], + "drift" => config[("xvol", "drift")], + "autoregression_target" => config[("xvol", "autoregression_target")], + "autoregression_strength" => config[("xvol", "autoregression_strength")], + "initial_mean" => config[("xvol", "initial_mean")], + "initial_precision" => config[("xvol", 
"initial_precision")], ), ] #List of child-parent relations edges = [ - Dict("child" => "u", "value_parents" => "x1"), + Dict("child" => "u", "value_parents" => "xbin"), Dict( - "child" => "x1", - "value_parents" => ("x2", config[("x1", "x2", "value_coupling")]), + "child" => "xbin", + "value_parents" => ("xprob", config[("xbin", "xprob", "value_coupling")]), ), Dict( - "child" => "x2", - "volatility_parents" => ("x3", config[("x2", "x3", "volatility_coupling")]), + "child" => "xprob", + "volatility_parents" => ("xvol", config[("xprob", "xvol", "volatility_coupling")]), ), ] @@ -464,21 +464,21 @@ end premade_categorical_3level(config::Dict; verbose::Bool = true) The categorical 3 level HGF model, which takes an input from one of n categories and learns the probability of a category appearing. -It has one categorical input node u, with a categorical value parent x1. -The categorical node has a binary value parent x1_n for each category n, each of which has a continuous value parent x2_n. -Finally, all of these continuous nodes share a continuous volatility parent x3. -Setting parameter values for x1 and x2 sets that parameter value for each of the x1_n and x2_n nodes. +It has one categorical input node u, with a categorical value parent xcat. +The categorical node has a binary value parent xbin_n for each category n, each of which has a continuous value parent xprob_n. +Finally, all of these continuous nodes share a continuous volatility parent xvol. +Setting parameter values for xbin and xprob sets that parameter value for each of the xbin_n and xprob_n nodes. 
# Config defaults: - "n_categories": 4 - - ("x2", "volatility"): -2 - - ("x3", "volatility"): -2 - - ("x1", "x2", "value_coupling"): 1 - - ("x2", "x3", "volatility_coupling"): 1 - - ("x2", "initial_mean"): 0 - - ("x2", "initial_precision"): 1 - - ("x3", "initial_mean"): 0 - - ("x3", "initial_precision"): 1 + - ("xprob", "volatility"): -2 + - ("xvol", "volatility"): -2 + - ("xbin", "xprob", "value_coupling"): 1 + - ("xprob", "xvol", "volatility_coupling"): 1 + - ("xprob", "initial_mean"): 0 + - ("xprob", "initial_precision"): 1 + - ("xvol", "initial_mean"): 0 + - ("xvol", "initial_precision"): 1 """ function premade_categorical_3level(config::Dict; verbose::Bool = true) @@ -486,22 +486,22 @@ function premade_categorical_3level(config::Dict; verbose::Bool = true) defaults = Dict( "n_categories" => 4, - ("x2", "volatility") => -2, - ("x2", "drift") => 0, - ("x2", "autoregression_target") => 0, - ("x2", "autoregression_strength") => 0, - ("x2", "initial_mean") => 0, - ("x2", "initial_precision") => 1, + ("xprob", "volatility") => -2, + ("xprob", "drift") => 0, + ("xprob", "autoregression_target") => 0, + ("xprob", "autoregression_strength") => 0, + ("xprob", "initial_mean") => 0, + ("xprob", "initial_precision") => 1, - ("x3", "volatility") => -2, - ("x3", "drift") => 0, - ("x3", "autoregression_target") => 0, - ("x3", "autoregression_strength") => 0, - ("x3", "initial_mean") => 0, - ("x3", "initial_precision") => 1, + ("xvol", "volatility") => -2, + ("xvol", "drift") => 0, + ("xvol", "autoregression_target") => 0, + ("xvol", "autoregression_strength") => 0, + ("xvol", "initial_mean") => 0, + ("xvol", "initial_precision") => 1, - ("x1", "x2", "value_coupling") => 1, - ("x2", "x3", "volatility_coupling") => 1, + ("xbin", "xprob", "value_coupling") => 1, + ("xprob", "xvol", "volatility_coupling") => 1, "update_type" => EnhancedUpdate(), ) @@ -522,26 +522,26 @@ function premade_categorical_3level(config::Dict; verbose::Bool = true) binary_continuous_parent_names = 
Vector{String}() #Empty lists for derived parameters - derived_parameters_x2_initial_precision = [] - derived_parameters_x2_initial_mean = [] - derived_parameters_x2_volatility = [] - derived_parameters_x2_drift = [] - derived_parameters_x2_autoregression_target = [] - derived_parameters_x2_autoregression_strength = [] - derived_parameters_x2_x3_volatility_coupling = [] - derived_parameters_value_coupling_x1_x2 = [] + derived_parameters_xprob_initial_precision = [] + derived_parameters_xprob_initial_mean = [] + derived_parameters_xprob_volatility = [] + derived_parameters_xprob_drift = [] + derived_parameters_xprob_autoregression_target = [] + derived_parameters_xprob_autoregression_strength = [] + derived_parameters_xprob_xvol_volatility_coupling = [] + derived_parameters_value_coupling_xbin_xprob = [] #Populate the category node vectors with node names for category_number = 1:config["n_categories"] - push!(category_binary_parent_names, "x1_" * string(category_number)) - push!(binary_continuous_parent_names, "x2_" * string(category_number)) + push!(category_binary_parent_names, "xbin_" * string(category_number)) + push!(binary_continuous_parent_names, "xprob_" * string(category_number)) end ##List of input nodes input_nodes = Dict("name" => "u", "type" => "categorical") ##List of state nodes - state_nodes = [Dict{String,Any}("name" => "x1", "type" => "categorical")] + state_nodes = [Dict{String,Any}("name" => "xcat", "type" => "categorical")] #Add category node binary parents for node_name in category_binary_parent_names @@ -555,43 +555,43 @@ function premade_categorical_3level(config::Dict; verbose::Bool = true) Dict( "name" => node_name, "type" => "continuous", - "initial_mean" => config[("x2", "initial_mean")], - "initial_precision" => config[("x2", "initial_precision")], - "volatility" => config[("x2", "volatility")], - "drift" => config[("x2", "drift")], - "autoregression_target" => config[("x2", "autoregression_target")], - "autoregression_strength" => 
config[("x2", "autoregression_strength")], + "initial_mean" => config[("xprob", "initial_mean")], + "initial_precision" => config[("xprob", "initial_precision")], + "volatility" => config[("xprob", "volatility")], + "drift" => config[("xprob", "drift")], + "autoregression_target" => config[("xprob", "autoregression_target")], + "autoregression_strength" => config[("xprob", "autoregression_strength")], ), ) #Add the derived parameter name to derived parameters vector - push!(derived_parameters_x2_initial_precision, (node_name, "initial_precision")) - push!(derived_parameters_x2_initial_mean, (node_name, "initial_mean")) - push!(derived_parameters_x2_volatility, (node_name, "volatility")) - push!(derived_parameters_x2_drift, (node_name, "drift")) - push!(derived_parameters_x2_autoregression_strength, (node_name, "autoregression_strength")) - push!(derived_parameters_x2_autoregression_target, (node_name, "autoregression_target")) + push!(derived_parameters_xprob_initial_precision, (node_name, "initial_precision")) + push!(derived_parameters_xprob_initial_mean, (node_name, "initial_mean")) + push!(derived_parameters_xprob_volatility, (node_name, "volatility")) + push!(derived_parameters_xprob_drift, (node_name, "drift")) + push!(derived_parameters_xprob_autoregression_strength, (node_name, "autoregression_strength")) + push!(derived_parameters_xprob_autoregression_target, (node_name, "autoregression_target")) end #Add volatility parent push!( state_nodes, Dict( - "name" => "x3", + "name" => "xvol", "type" => "continuous", - "volatility" => config[("x3", "volatility")], - "drift" => config[("x3", "drift")], - "autoregression_target" => config[("x3", "autoregression_target")], - "autoregression_strength" => config[("x3", "autoregression_strength")], - "initial_mean" => config[("x3", "initial_mean")], - "initial_precision" => config[("x3", "initial_precision")], + "volatility" => config[("xvol", "volatility")], + "drift" => config[("xvol", "drift")], + 
"autoregression_target" => config[("xvol", "autoregression_target")], + "autoregression_strength" => config[("xvol", "autoregression_strength")], + "initial_mean" => config[("xvol", "initial_mean")], + "initial_precision" => config[("xvol", "initial_precision")], ), ) ##List of child-parent relations edges = [ - Dict("child" => "u", "value_parents" => "x1"), - Dict("child" => "x1", "value_parents" => category_binary_parent_names), + Dict("child" => "u", "value_parents" => "xcat"), + Dict("child" => "xcat", "value_parents" => category_binary_parent_names), ] #Add relations between binary nodes and their parents @@ -601,12 +601,12 @@ function premade_categorical_3level(config::Dict; verbose::Bool = true) edges, Dict( "child" => child_name, - "value_parents" => (parent_name, config[("x1", "x2", "value_coupling")]), + "value_parents" => (parent_name, config[("xbin", "xprob", "value_coupling")]), ), ) #Add the derived parameter name to derived parameters vector push!( - derived_parameters_value_coupling_x1_x2, + derived_parameters_value_coupling_xbin_xprob, (child_name, parent_name, "value_coupling"), ) end @@ -617,43 +617,43 @@ function premade_categorical_3level(config::Dict; verbose::Bool = true) edges, Dict( "child" => child_name, - "volatility_parents" => ("x3", config[("x2", "x3", "volatility_coupling")]), + "volatility_parents" => ("xvol", config[("xprob", "xvol", "volatility_coupling")]), ), ) #Add the derived parameter name to derived parameters vector push!( - derived_parameters_x2_x3_volatility_coupling, - (child_name, "x3", "volatility_coupling"), + derived_parameters_xprob_xvol_volatility_coupling, + (child_name, "xvol", "volatility_coupling"), ) end #Create dictionary with shared parameter information shared_parameters = Dict() - shared_parameters["x2_volatility"] = - (config[("x2", "volatility")], derived_parameters_x2_volatility) + shared_parameters["xprob_volatility"] = + (config[("xprob", "volatility")], derived_parameters_xprob_volatility) - 
shared_parameters["x2_initial_precisions"] = - (config[("x2", "initial_precision")], derived_parameters_x2_initial_precision) + shared_parameters["xprob_initial_precisions"] = + (config[("xprob", "initial_precision")], derived_parameters_xprob_initial_precision) - shared_parameters["x2_initial_means"] = - (config[("x2", "initial_mean")], derived_parameters_x2_initial_mean) + shared_parameters["xprob_initial_means"] = + (config[("xprob", "initial_mean")], derived_parameters_xprob_initial_mean) - shared_parameters["x2_drifts"] = - (config[("x2", "drift")], derived_parameters_x2_drift) + shared_parameters["xprob_drifts"] = + (config[("xprob", "drift")], derived_parameters_xprob_drift) - shared_parameters["x2_autoregression_strengths"] = - (config[("x2", "autoregression_strength")], derived_parameters_x2_autoregression_strength) + shared_parameters["xprob_autoregression_strengths"] = + (config[("xprob", "autoregression_strength")], derived_parameters_xprob_autoregression_strength) - shared_parameters["x2_autoregression_targets"] = - (config[("x2", "autoregression_target")], derived_parameters_x2_autoregression_target) + shared_parameters["xprob_autoregression_targets"] = + (config[("xprob", "autoregression_target")], derived_parameters_xprob_autoregression_target) - shared_parameters["value_couplings_x1_x2"] = - (config[("x1", "x2", "value_coupling")], derived_parameters_value_coupling_x1_x2) + shared_parameters["value_couplings_xbin_xprob"] = + (config[("xbin", "xprob", "value_coupling")], derived_parameters_value_coupling_xbin_xprob) - shared_parameters["volatility_couplings_x2_x3"] = ( - config[("x2", "x3", "volatility_coupling")], - derived_parameters_x2_x3_volatility_coupling, + shared_parameters["volatility_couplings_xprob_xvol"] = ( + config[("xprob", "xvol", "volatility_coupling")], + derived_parameters_xprob_xvol_volatility_coupling, ) #Initialize the HGF @@ -671,29 +671,29 @@ end premade_categorical_3level_state_transitions(config::Dict; verbose::Bool = true) 
The categorical state transition 3 level HGF model, learns state transition probabilities between a set of n categorical states. -It has one categorical input node u, with a categorical value parent x1_n for each of the n categories, representing which category was transitioned from. -Each categorical node then has a binary parent x1_n_m, representing the category m which the transition was towards. -Each binary node x1_n_m has a continuous parent x2_n_m. -Finally, all of these continuous nodes share a continuous volatility parent x3. -Setting parameter values for x1 and x2 sets that parameter value for each of the x1_n_m and x2_n_m nodes. +It has one categorical input node u, with a categorical value parent xcat_n for each of the n categories, representing which category was transitioned from. +Each categorical node then has a binary parent xbin_n_m, representing the category m which the transition was towards. +Each binary node xbin_n_m has a continuous parent xprob_n_m. +Finally, all of these continuous nodes share a continuous volatility parent xvol. +Setting parameter values for xbin and xprob sets that parameter value for each of the xbin_n_m and xprob_n_m nodes. 
This HGF has five shared parameters: -"x2_volatility" -"x2_initial_precisions" -"x2_initial_means" -"value_couplings_x1_x2" -"volatility_couplings_x2_x3" +"xprob_volatility" +"xprob_initial_precisions" +"xprob_initial_means" +"value_couplings_xbin_xprob" +"volatility_couplings_xprob_xvol" # Config defaults: - "n_categories": 4 - - ("x2", "volatility"): -2 - - ("x3", "volatility"): -2 - - ("x1", "x2", "volatility_coupling"): 1 - - ("x2", "x3", "volatility_coupling"): 1 - - ("x2", "initial_mean"): 0 - - ("x2", "initial_precision"): 1 - - ("x3", "initial_mean"): 0 - - ("x3", "initial_precision"): 1 + - ("xprob", "volatility"): -2 + - ("xvol", "volatility"): -2 + - ("xbin", "xprob", "value_coupling"): 1 + - ("xprob", "xvol", "volatility_coupling"): 1 + - ("xprob", "initial_mean"): 0 + - ("xprob", "initial_precision"): 1 + - ("xvol", "initial_mean"): 0 + - ("xvol", "initial_precision"): 1 """ function premade_categorical_3level_state_transitions(config::Dict; verbose::Bool = true) @@ -701,22 +701,22 @@ function premade_categorical_3level_state_transitions(config::Dict; verbose::Boo defaults = Dict( "n_categories" => 4, - ("x2", "volatility") => -2, - ("x2", "drift") => 0, - ("x2", "autoregression_target") => 0, - ("x2", "autoregression_strength") => 0, - ("x2", "initial_mean") => 0, - ("x2", "initial_precision") => 1, + ("xprob", "volatility") => -2, + ("xprob", "drift") => 0, + ("xprob", "autoregression_target") => 0, + ("xprob", "autoregression_strength") => 0, + ("xprob", "initial_mean") => 0, + ("xprob", "initial_precision") => 1, - ("x3", "volatility") => -2, - ("x3", "drift") => 0, - ("x3", "autoregression_target") => 0, - ("x3", "autoregression_strength") => 0, - ("x3", "initial_mean") => 0, - ("x3", "initial_precision") => 1, + ("xvol", "volatility") => -2, + ("xvol", "drift") => 0, + ("xvol", "autoregression_target") => 0, + ("xvol", "autoregression_strength") => 0, + ("xvol", "initial_mean") => 0, + ("xvol", "initial_precision") => 1, - ("x1", "x2", 
"value_coupling") => 1, - ("x2", "x3", "volatility_coupling") => 1, + ("xbin", "xprob", "value_coupling") => 1, + ("xprob", "xvol", "volatility_coupling") => 1, "update_type" => EnhancedUpdate(), ) @@ -738,31 +738,31 @@ function premade_categorical_3level_state_transitions(config::Dict; verbose::Boo binary_node_continuous_parent_names = Vector{String}() #Empty lists for derived parameters - derived_parameters_x2_initial_precision = [] - derived_parameters_x2_initial_mean = [] - derived_parameters_x2_volatility = [] - derived_parameters_x2_drift = [] - derived_parameters_x2_autoregression_target = [] - derived_parameters_x2_autoregression_strength = [] - derived_parameters_value_coupling_x1_x2 = [] - derived_parameters_x2_x3_volatility_coupling = [] + derived_parameters_xprob_initial_precision = [] + derived_parameters_xprob_initial_mean = [] + derived_parameters_xprob_volatility = [] + derived_parameters_xprob_drift = [] + derived_parameters_xprob_autoregression_target = [] + derived_parameters_xprob_autoregression_strength = [] + derived_parameters_value_coupling_xbin_xprob = [] + derived_parameters_xprob_xvol_volatility_coupling = [] #Go through each category that the transition may have been from for category_from = 1:config["n_categories"] #One input node and its state node parent for each push!(categorical_input_node_names, "u" * string(category_from)) - push!(categorical_state_node_names, "x1_" * string(category_from)) + push!(categorical_state_node_names, "xcat_" * string(category_from)) #Go through each category that the transition may have been to for category_to = 1:config["n_categories"] #Each categorical state node has a binary parent for each push!( categorical_node_binary_parent_names, - "x1_" * string(category_from) * "_" * string(category_to), + "xbin_" * string(category_from) * "_" * string(category_to), ) #And each binary parent has a continuous parent of its own push!( binary_node_continuous_parent_names, - "x2_" * string(category_from) * "_" * 
string(category_to), + "xprob_" * string(category_from) * "_" * string(category_to), ) end end @@ -801,21 +801,21 @@ function premade_categorical_3level_state_transitions(config::Dict; verbose::Boo Dict( "name" => node_name, "type" => "continuous", - "initial_mean" => config[("x2", "initial_mean")], - "initial_precision" => config[("x2", "initial_precision")], - "volatility" => config[("x2", "volatility")], - "drift" => config[("x2", "drift")], - "autoregression_target" => config[("x2", "autoregression_target")], - "autoregression_strength" => config[("x2", "autoregression_strength")], + "initial_mean" => config[("xprob", "initial_mean")], + "initial_precision" => config[("xprob", "initial_precision")], + "volatility" => config[("xprob", "volatility")], + "drift" => config[("xprob", "drift")], + "autoregression_target" => config[("xprob", "autoregression_target")], + "autoregression_strength" => config[("xprob", "autoregression_strength")], ), ) #Add the derived parameter name to derived parameters vector - push!(derived_parameters_x2_initial_precision, (node_name, "initial_precision")) - push!(derived_parameters_x2_initial_mean, (node_name, "initial_mean")) - push!(derived_parameters_x2_volatility, (node_name, "volatility")) - push!(derived_parameters_x2_drift, (node_name, "drift")) - push!(derived_parameters_x2_autoregression_strength, (node_name, "autoregression_strength")) - push!(derived_parameters_x2_autoregression_target, (node_name, "autoregression_target")) + push!(derived_parameters_xprob_initial_precision, (node_name, "initial_precision")) + push!(derived_parameters_xprob_initial_mean, (node_name, "initial_mean")) + push!(derived_parameters_xprob_volatility, (node_name, "volatility")) + push!(derived_parameters_xprob_drift, (node_name, "drift")) + push!(derived_parameters_xprob_autoregression_strength, (node_name, "autoregression_strength")) + push!(derived_parameters_xprob_autoregression_target, (node_name, "autoregression_target")) end @@ -823,14 
+823,14 @@ function premade_categorical_3level_state_transitions(config::Dict; verbose::Boo push!( state_nodes, Dict( - "name" => "x3", + "name" => "xvol", "type" => "continuous", - "volatility" => config[("x3", "volatility")], - "drift" => config[("x3", "drift")], - "autoregression_target" => config[("x3", "autoregression_target")], - "autoregression_strength" => config[("x3", "autoregression_strength")], - "initial_mean" => config[("x3", "initial_mean")], - "initial_precision" => config[("x3", "initial_precision")], + "volatility" => config[("xvol", "volatility")], + "drift" => config[("xvol", "drift")], + "autoregression_target" => config[("xvol", "autoregression_target")], + "autoregression_strength" => config[("xvol", "autoregression_strength")], + "initial_mean" => config[("xvol", "initial_mean")], + "initial_precision" => config[("xvol", "initial_precision")], ), ) @@ -875,12 +875,12 @@ function premade_categorical_3level_state_transitions(config::Dict; verbose::Boo edges, Dict( "child" => child_name, - "value_parents" => (parent_name, config[("x1", "x2", "value_coupling")]), + "value_parents" => (parent_name, config[("xbin", "xprob", "value_coupling")]), ), ) #Add the derived parameter name to derived parameters vector push!( - derived_parameters_value_coupling_x1_x2, + derived_parameters_value_coupling_xbin_xprob, (child_name, parent_name, "value_coupling"), ) end @@ -892,13 +892,13 @@ function premade_categorical_3level_state_transitions(config::Dict; verbose::Boo edges, Dict( "child" => child_name, - "volatility_parents" => ("x3", config[("x2", "x3", "volatility_coupling")]), + "volatility_parents" => ("xvol", config[("xprob", "xvol", "volatility_coupling")]), ), ) #Add the derived parameter name to derived parameters vector push!( - derived_parameters_x2_x3_volatility_coupling, - (child_name, "x3", "volatility_coupling"), + derived_parameters_xprob_xvol_volatility_coupling, + (child_name, "xvol", "volatility_coupling"), ) end @@ -907,30 +907,30 @@ 
function premade_categorical_3level_state_transitions(config::Dict; verbose::Boo shared_parameters = Dict() - shared_parameters["x2_volatility"] = - (config[("x2", "volatility")], derived_parameters_x2_volatility) + shared_parameters["xprob_volatility"] = + (config[("xprob", "volatility")], derived_parameters_xprob_volatility) - shared_parameters["x2_initial_precisions"] = - (config[("x2", "initial_precision")], derived_parameters_x2_initial_precision) + shared_parameters["xprob_initial_precisions"] = + (config[("xprob", "initial_precision")], derived_parameters_xprob_initial_precision) - shared_parameters["x2_initial_means"] = - (config[("x2", "initial_mean")], derived_parameters_x2_initial_mean) + shared_parameters["xprob_initial_means"] = + (config[("xprob", "initial_mean")], derived_parameters_xprob_initial_mean) - shared_parameters["x2_drifts"] = - (config[("x2", "drift")], derived_parameters_x2_drift) + shared_parameters["xprob_drifts"] = + (config[("xprob", "drift")], derived_parameters_xprob_drift) - shared_parameters["x2_autoregression_strengths"] = - (config[("x2", "autoregression_strength")], derived_parameters_x2_autoregression_strength) + shared_parameters["xprob_autoregression_strengths"] = + (config[("xprob", "autoregression_strength")], derived_parameters_xprob_autoregression_strength) - shared_parameters["x2_autoregression_targets"] = - (config[("x2", "autoregression_target")], derived_parameters_x2_autoregression_target) + shared_parameters["xprob_autoregression_targets"] = + (config[("xprob", "autoregression_target")], derived_parameters_xprob_autoregression_target) - shared_parameters["value_couplings_x1_x2"] = - (config[("x1", "x2", "value_coupling")], derived_parameters_value_coupling_x1_x2) + shared_parameters["value_couplings_xbin_xprob"] = + (config[("xbin", "xprob", "value_coupling")], derived_parameters_value_coupling_xbin_xprob) - shared_parameters["volatility_couplings_x2_x3"] = ( - config[("x2", "x3", "volatility_coupling")], - 
derived_parameters_x2_x3_volatility_coupling, + shared_parameters["volatility_couplings_xprob_xvol"] = ( + config[("xprob", "xvol", "volatility_coupling")], + derived_parameters_xprob_xvol_volatility_coupling, ) #Initialize the HGF diff --git a/src/utils/get_prediction.jl b/src/utils/get_prediction.jl index bfced30..7bdd30a 100644 --- a/src/utils/get_prediction.jl +++ b/src/utils/get_prediction.jl @@ -6,15 +6,15 @@ A single node can also be passed. """ function get_prediction end -function get_prediction(agent::Agent, node_name::String = "x1") +function get_prediction(agent::Agent, node_name::String) - #Get prediction form the HGF + #Get prediction from the HGF prediction = get_prediction(agent.substruct, node_name) return prediction end -function get_prediction(hgf::HGF, node_name::String = "x1") +function get_prediction(hgf::HGF, node_name::String) #Get the prediction of the given node return get_prediction(hgf.all_nodes[node_name]) end diff --git a/test/data/canonical_continuous2level_states.csv b/test/data/canonical_continuous2level_states.csv index 0952fc5..aeaab6f 100644 --- a/test/data/canonical_continuous2level_states.csv +++ b/test/data/canonical_continuous2level_states.csv @@ -1,4 +1,4 @@ -,x1_mean,x1_precision,x2_mean,x2_precision +,mu1,pi1,mu2,pi2 0,1.04,10000.0,1.0,10.0 1,1.0377859183728282,19421.14485405236,0.9968176766176965,4.262927405210326 2,1.0356343652765874,27356.60292609981,0.9912464787404184,2.7209034885694074 diff --git a/test/test_canonical.jl b/test/test_canonical.jl index 7a8d5e4..c8e73cc 100644 --- a/test/test_canonical.jl +++ b/test/test_canonical.jl @@ -36,15 +36,15 @@ using Plots ### Set up HGF ### #set parameters and starting states parameters = Dict( - ("u", "x1", "value_coupling") => 1.0, - ("x1", "x2", "volatility_coupling") => 1.0, + ("u", "x", "value_coupling") => 1.0, + ("x", "xvol", "volatility_coupling") => 1.0, ("u", "input_noise") => log(1e-4), - ("x1", "volatility") => -13, - ("x2", "volatility") => -2, - ("x1", 
"initial_mean") => 1.04, - ("x1", "initial_precision") => 1e4, - ("x2", "initial_mean") => 1.0, - ("x2", "initial_precision") => 10, + ("x", "volatility") => -13, + ("xvol", "volatility") => -2, + ("x", "initial_mean") => 1.04, + ("x", "initial_precision") => 1e4, + ("xvol", "initial_mean") => 1.0, + ("xvol", "initial_precision") => 10, "update_type" => ClassicUpdate(), ) @@ -56,26 +56,26 @@ using Plots #Construct result output dataframe result_outputs = DataFrame( - x1_mean = test_hgf.state_nodes["x1"].history.posterior_mean, - x1_precision = test_hgf.state_nodes["x1"].history.posterior_precision, - x2_mean = test_hgf.state_nodes["x2"].history.posterior_mean, - x2_precision = test_hgf.state_nodes["x2"].history.posterior_precision, + x_mean = test_hgf.state_nodes["x"].history.posterior_mean, + x_precision = test_hgf.state_nodes["x"].history.posterior_precision, + xvol_mean = test_hgf.state_nodes["xvol"].history.posterior_mean, + xvol_precision = test_hgf.state_nodes["xvol"].history.posterior_precision, ) #Test if the values are approximately the same @testset "compare output trajectories" begin for i = 1:nrow(result_outputs) - @test result_outputs.x1_mean[i] ≈ target_outputs.x1_mean[i] - @test result_outputs.x1_precision[i] ≈ target_outputs.x1_precision[i] - @test result_outputs.x2_mean[i] ≈ target_outputs.x2_mean[i] - @test result_outputs.x2_precision[i] ≈ target_outputs.x2_precision[i] + @test result_outputs.x_mean[i] ≈ target_outputs.mu1[i] + @test result_outputs.x_precision[i] ≈ target_outputs.pi1[i] + @test result_outputs.xvol_mean[i] ≈ target_outputs.mu2[i] + @test result_outputs.xvol_precision[i] ≈ target_outputs.pi2[i] end end @testset "Trajectory plots" begin #Make trajectory plots plot_trajectory(test_hgf, "u") - plot_trajectory!(test_hgf, ("x1", "posterior")) + plot_trajectory!(test_hgf, ("x", "posterior")) end end @@ -95,14 +95,14 @@ using Plots test_parameters = Dict( ("u", "category_means") => [0.0, 1.0], ("u", "input_precision") => Inf, - ("x2", 
"volatility") => -2.5, - ("x3", "volatility") => -6.0, - ("x1", "x2", "value_coupling") => 1.0, - ("x2", "x3", "volatility_coupling") => 1.0, - ("x2", "initial_mean") => 0.0, - ("x2", "initial_precision") => 1.0, - ("x3", "initial_mean") => 1.0, - ("x3", "initial_precision") => 1.0, + ("xprob", "volatility") => -2.5, + ("xvol", "volatility") => -6.0, + ("xbin", "xprob", "value_coupling") => 1.0, + ("xprob", "xvol", "volatility_coupling") => 1.0, + ("xprob", "initial_mean") => 0.0, + ("xprob", "initial_precision") => 1.0, + ("xvol", "initial_mean") => 1.0, + ("xvol", "initial_precision") => 1.0, "update_type" => ClassicUpdate(), ) @@ -114,12 +114,12 @@ using Plots #Construct result output dataframe result_outputs = DataFrame( - x1_mean = test_hgf.state_nodes["x1"].history.posterior_mean, - x1_precision = test_hgf.state_nodes["x1"].history.posterior_precision, - x2_mean = test_hgf.state_nodes["x2"].history.posterior_mean, - x2_precision = test_hgf.state_nodes["x2"].history.posterior_precision, - x3_mean = test_hgf.state_nodes["x3"].history.posterior_mean, - x3_precision = test_hgf.state_nodes["x3"].history.posterior_precision, + xbin_mean = test_hgf.state_nodes["xbin"].history.posterior_mean, + xbin_precision = test_hgf.state_nodes["xbin"].history.posterior_precision, + xprob_mean = test_hgf.state_nodes["xprob"].history.posterior_mean, + xprob_precision = test_hgf.state_nodes["xprob"].history.posterior_precision, + xvol_mean = test_hgf.state_nodes["xvol"].history.posterior_mean, + xvol_precision = test_hgf.state_nodes["xvol"].history.posterior_precision, ) #Remove the first row @@ -128,25 +128,25 @@ using Plots #Test if the values are approximately the same @testset "compare output trajectories" begin for i = 1:nrow(canonical_trajectory) - @test result_outputs.x1_mean[i] ≈ canonical_trajectory.mu1[i] - @test result_outputs.x1_precision[i] ≈ 1 / canonical_trajectory.sa1[i] + @test result_outputs.xbin_mean[i] ≈ canonical_trajectory.mu1[i] + @test 
result_outputs.xbin_precision[i] ≈ 1 / canonical_trajectory.sa1[i] @test isapprox( - result_outputs.x2_mean[i], + result_outputs.xprob_mean[i], canonical_trajectory.mu2[i], rtol = 0.1, ) @test isapprox( - result_outputs.x2_precision[i], + result_outputs.xprob_precision[i], 1 / canonical_trajectory.sa2[i], rtol = 0.1, ) @test isapprox( - result_outputs.x3_mean[i], + result_outputs.xvol_mean[i], canonical_trajectory.mu3[i], rtol = 0.1, ) @test isapprox( - result_outputs.x3_precision[i], + result_outputs.xvol_precision[i], 1 / canonical_trajectory.sa3[i], rtol = 0.1, ) @@ -156,7 +156,7 @@ using Plots @testset "Trajectory plots" begin #Make trajectory plots plot_trajectory(test_hgf, "u") - plot_trajectory!(test_hgf, ("x1", "prediction")) + plot_trajectory!(test_hgf, ("xbin", "prediction")) end end end diff --git a/test/test_fit_model.jl b/test/test_fit_model.jl index 3d8ee28..a46c6ad 100644 --- a/test/test_fit_model.jl +++ b/test/test_fit_model.jl @@ -22,21 +22,21 @@ using Turing # Set fixed parsmeters and priors for fitting test_fixed_parameters = Dict( - ("x1", "initial_mean") => 100, - ("x2", "initial_mean") => 1.0, - ("x2", "initial_precision") => 600, - ("u", "x1", "value_coupling") => 1.0, - ("x1", "x2", "volatility_coupling") => 1.0, + ("x", "initial_mean") => 100, + ("xvol", "initial_mean") => 1.0, + ("xvol", "initial_precision") => 600, + ("u", "x", "value_coupling") => 1.0, + ("x", "xvol", "volatility_coupling") => 1.0, "gaussian_action_precision" => 100, - ("x2", "volatility") => -4, + ("xvol", "volatility") => -4, ("u", "input_noise") => 4, - ("x2", "drift") => 1, + ("xvol", "drift") => 1, ) test_param_priors = Dict( - ("x1", "volatility") => Normal(log(100.0), 4), - ("x1", "initial_mean") => Normal(1, sqrt(100.0)), - ("x1", "drift") => Normal(0, 1), + ("x", "volatility") => Normal(log(100.0), 4), + ("x", "initial_mean") => Normal(1, sqrt(100.0)), + ("x", "drift") => Normal(0, 1), ) #Fit single chain with defaults @@ -73,7 +73,7 @@ using Turing 
fitted_model, test_agent, test_input, - ("x1", "posterior_mean"); + ("x", "posterior_mean"); verbose = false, n_simulations = 3, ) @@ -96,18 +96,18 @@ using Turing test_fixed_parameters = Dict( ("u", "category_means") => Real[0.0, 1.0], ("u", "input_precision") => Inf, - ("x2", "initial_mean") => 3.0, - ("x2", "initial_precision") => exp(2.306), - ("x3", "initial_mean") => 3.2189, - ("x3", "initial_precision") => exp(-1.0986), - ("x1", "x2", "value_coupling") => 1.0, - ("x2", "x3", "volatility_coupling") => 1.0, - ("x3", "volatility") => -3, + ("xprob", "initial_mean") => 3.0, + ("xprob", "initial_precision") => exp(2.306), + ("xvol", "initial_mean") => 3.2189, + ("xvol", "initial_precision") => exp(-1.0986), + ("xbin", "xprob", "value_coupling") => 1.0, + ("xprob", "xvol", "volatility_coupling") => 1.0, + ("xvol", "volatility") => -3, ) test_param_priors = Dict( "softmax_action_precision" => truncated(Normal(100, 20), 0, Inf), - ("x2", "volatility") => Normal(-7, 5), + ("xprob", "volatility") => Normal(-7, 5), ) #Fit single chain with defaults @@ -144,7 +144,7 @@ using Turing fitted_model, test_agent, test_input, - ("x1", "posterior_mean"), + ("xbin", "posterior_mean"), verbose = false, n_simulations = 3, ) From 00b8a57400eb968e40930d6f3c563b3fca7cc900 Mon Sep 17 00:00:00 2001 From: Peter Thestrup Waade Date: Tue, 28 Nov 2023 14:55:25 +0100 Subject: [PATCH 03/16] Many changes --- .gitignore | 3 +- Project.toml | 2 +- docs/Project.toml | 1 + docs/src/Julia_src_files/building_an_HGF.jl | 10 +- .../src/Julia_src_files/fitting_hgf_models.jl | 14 +- docs/src/Julia_src_files/utility_functions.jl | 19 +- docs/src/tutorials/classic_JGET.jl | 63 +- docs/src/tutorials/classic_binary.jl | 8 +- docs/src/tutorials/classic_usdchf.jl | 6 +- .../utils/get_parameters.jl | 29 +- .../utils/get_states.jl | 1 - src/ActionModels_variations/utils/reset.jl | 140 +-- .../utils/set_parameters.jl | 23 +- src/HierarchicalGaussianFiltering.jl | 29 +- src/create_hgf/check_hgf.jl | 170 +--- 
src/{structs.jl => create_hgf/hgf_structs.jl} | 368 ++++--- src/create_hgf/init_hgf.jl | 295 +++--- src/premade_models/premade_hgfs.jl | 945 ------------------ .../premade_hgfs/premade_JGET.jl | 138 +++ .../premade_hgfs/premade_binary_2level.jl | 73 ++ .../premade_hgfs/premade_binary_3level.jl | 109 ++ .../premade_categorical_3level.jl | 207 ++++ .../premade_categorical_transitions_3level.jl | 272 +++++ .../premade_hgfs/premade_continuous_2level.jl | 93 ++ .../node_updates/binary_input_node.jl | 68 ++ .../node_updates/binary_state_node.jl | 236 +++++ .../node_updates/categorical_input_node.jl | 43 + .../node_updates/categorical_state_node.jl | 167 ++++ .../node_updates/continuous_input_node.jl | 178 ++++ .../node_updates/continuous_state_node.jl | 614 ++++++++++++ src/update_hgf/update_equations.jl | 813 --------------- src/update_hgf/update_hgf.jl | 24 +- src/update_hgf/update_node.jl | 308 ------ src/utils/get_prediction.jl | 18 +- src/utils/get_surprise.jl | 6 +- test/Project.toml | 1 + test/quicktests.jl | 4 +- test/runtests.jl | 77 +- .../data/canonical_binary3level.csv | 0 .../canonical_continuous2level_inputs.dat | 0 .../canonical_continuous2level_states.csv | 0 test/{ => testsuite}/test_canonical.jl | 11 +- test/{ => testsuite}/test_fit_model.jl | 7 +- test/{ => testsuite}/test_initialization.jl | 30 +- test/{ => testsuite}/test_premade_agent.jl | 0 test/{ => testsuite}/test_premade_hgf.jl | 0 .../{ => testsuite}/test_shared_parameters.jl | 6 +- 47 files changed, 2820 insertions(+), 2809 deletions(-) rename src/{structs.jl => create_hgf/hgf_structs.jl} (53%) delete mode 100644 src/premade_models/premade_hgfs.jl create mode 100644 src/premade_models/premade_hgfs/premade_JGET.jl create mode 100644 src/premade_models/premade_hgfs/premade_binary_2level.jl create mode 100644 src/premade_models/premade_hgfs/premade_binary_3level.jl create mode 100644 src/premade_models/premade_hgfs/premade_categorical_3level.jl create mode 100644 
src/premade_models/premade_hgfs/premade_categorical_transitions_3level.jl create mode 100644 src/premade_models/premade_hgfs/premade_continuous_2level.jl create mode 100644 src/update_hgf/node_updates/binary_input_node.jl create mode 100644 src/update_hgf/node_updates/binary_state_node.jl create mode 100644 src/update_hgf/node_updates/categorical_input_node.jl create mode 100644 src/update_hgf/node_updates/categorical_state_node.jl create mode 100644 src/update_hgf/node_updates/continuous_input_node.jl create mode 100644 src/update_hgf/node_updates/continuous_state_node.jl delete mode 100644 src/update_hgf/update_equations.jl delete mode 100644 src/update_hgf/update_node.jl rename test/{ => testsuite}/data/canonical_binary3level.csv (100%) rename test/{ => testsuite}/data/canonical_continuous2level_inputs.dat (100%) rename test/{ => testsuite}/data/canonical_continuous2level_states.csv (100%) rename test/{ => testsuite}/test_canonical.jl (95%) rename test/{ => testsuite}/test_fit_model.jl (95%) rename test/{ => testsuite}/test_initialization.jl (74%) rename test/{ => testsuite}/test_premade_agent.jl (100%) rename test/{ => testsuite}/test_premade_hgf.jl (100%) rename test/{ => testsuite}/test_shared_parameters.jl (92%) diff --git a/.gitignore b/.gitignore index b02f4d6..f0af5a4 100644 --- a/.gitignore +++ b/.gitignore @@ -8,4 +8,5 @@ settings.json Manifest.toml docs/Manifest.toml test/Manifest.toml -/docs/src/generated_markdowns/*.md \ No newline at end of file +/docs/src/generated_markdowns/*.md +.vscode \ No newline at end of file diff --git a/Project.toml b/Project.toml index be97bc3..e8b3640 100644 --- a/Project.toml +++ b/Project.toml @@ -4,7 +4,7 @@ authors = [ "Peter Thestrup Waade ptw@cas.au.dk", "Anna Hedvig Møller hedvig.2808@gmail.com", "Jacopo Comoglio jacopo.comoglio@gmail.com", "Christoph Mathys chmathys@cas.au.dk"] -version = "0.3.3" +version = "0.5.0" [deps] ActionModels = "320cf53b-cc3b-4b34-9a10-0ecb113566a3" diff --git a/docs/Project.toml 
b/docs/Project.toml index 3bfd917..dca5a47 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -2,6 +2,7 @@ ActionModels = "320cf53b-cc3b-4b34-9a10-0ecb113566a3" CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" +Debugger = "31a5f54b-26ea-5ae9-a837-f05ce5417438" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" diff --git a/docs/src/Julia_src_files/building_an_HGF.jl b/docs/src/Julia_src_files/building_an_HGF.jl index 7f1b0d7..7d1a2a5 100644 --- a/docs/src/Julia_src_files/building_an_HGF.jl +++ b/docs/src/Julia_src_files/building_an_HGF.jl @@ -49,12 +49,10 @@ state_nodes = [ # At the buttom of our hierarchy we have the binary input node. The Input node has binary state node as parent. -edges = [ - Dict("child" => "Input_node", "value_parents" => "binary_state_node"), - - ## The next relation is from the point of view of the binary state node. We specify out continous state node as parent with the value coupling as 1. - Dict("child" => "binary_state_node", "value_parents" => ("continuous_state_node", 1)), -]; +edges = Dict( + ("Input_node", "binary_state_node") => ObservationCoupling(), + ("binary_state_node", "continuous_state_node") => ProbabilityCoupling(1), +); # We are ready to initialize our HGF now. 
diff --git a/docs/src/Julia_src_files/fitting_hgf_models.jl b/docs/src/Julia_src_files/fitting_hgf_models.jl index 935b197..8e4b0b9 100644 --- a/docs/src/Julia_src_files/fitting_hgf_models.jl +++ b/docs/src/Julia_src_files/fitting_hgf_models.jl @@ -53,8 +53,8 @@ hgf_parameters = Dict( ("xvol", "volatility") => -6.0, ("xvol", "initial_mean") => 1, ("xvol", "initial_precision") => 1, - ("xbin", "xprob", "value_coupling") => 1.0, - ("xprob", "xvol", "volatility_coupling") => 1.0, + ("xbin", "xprob", "coupling_strength") => 1.0, + ("xprob", "xvol", "coupling_strength") => 1.0, ) hgf = premade_hgf("binary_3level", hgf_parameters, verbose = false) @@ -93,8 +93,8 @@ fixed_parameters = Dict( ("xprob", "initial_precision") => 1, ("xvol", "initial_mean") => 1, ("xvol", "initial_precision") => 1, - ("xbin", "xprob", "value_coupling") => 1.0, - ("xprob", "xvol", "volatility_coupling") => 1.0, + ("xbin", "xprob", "coupling_strength") => 1.0, + ("xprob", "xvol", "coupling_strength") => 1.0, ("xvol", "volatility") => -6.0, ); @@ -114,9 +114,9 @@ fitted_model = fit_model( verbose = true, n_iterations = 10, ) +set_parameters!(agent, hgf_parameters) # ## Plotting Functions - plot(fitted_model) # Plot the posterior @@ -149,10 +149,9 @@ plot_predictive_simulation( # Fit the actions where we use the default parameter values from the HGF. 
fitted_model = fit_model(agent, param_priors, inputs, actions, verbose = true, n_iterations = 10) - +set_parameters!(agent, hgf_parameters) # We can place our turing chain as a our posterior in the function, and get our posterior predictive simulation plot: - plot_predictive_simulation( fitted_model, agent, @@ -162,7 +161,6 @@ plot_predictive_simulation( ) # We can get the posterior - get_posteriors(fitted_model) # plot the chains diff --git a/docs/src/Julia_src_files/utility_functions.jl b/docs/src/Julia_src_files/utility_functions.jl index 3cb6a2d..438b2d0 100644 --- a/docs/src/Julia_src_files/utility_functions.jl +++ b/docs/src/Julia_src_files/utility_functions.jl @@ -31,7 +31,7 @@ agent = premade_agent("hgf_binary_softmax_action") get_parameters(agent) # getting couplings -get_parameters(agent, ("xprob", "xvol", "volatility_coupling")) +get_parameters(agent, ("xprob", "xvol", "coupling_strength")) # getting multiple parameters specify them in a vector get_parameters(agent, [("xvol", "volatility"), ("xvol", "initial_precision")]) @@ -46,7 +46,13 @@ get_states(agent) get_states(agent, ("xprob", "posterior_precision")) #getting multiple states -get_states(agent, [("xprob", "posterior_precision"), ("xprob", "volatility_weighted_prediction_precision")]) +get_states( + agent, + [ + ("xprob", "posterior_precision"), + ("xprob", "volatility_weighted_prediction_precision"), + ], +) # ### Setting Parameters @@ -67,8 +73,8 @@ hgf_parameters = Dict( ("xvol", "volatility") => -6.0, ("xvol", "initial_mean") => 1, ("xvol", "initial_precision") => 1, - ("xbin", "xprob", "value_coupling") => 1.0, - ("xprob", "xvol", "volatility_coupling") => 1.0, + ("xbin", "xprob", "coupling_strength") => 1.0, + ("xprob", "xvol", "coupling_strength") => 1.0, ) hgf = premade_hgf("binary_3level", hgf_parameters) @@ -85,7 +91,7 @@ set_parameters!(agent, ("xvol", "initial_precision"), 4) set_parameters!( agent, - Dict(("xvol", "initial_precision") => 5, ("xbin", "xprob", "value_coupling") => 2.0), 
+ Dict(("xvol", "initial_precision") => 5, ("xbin", "xprob", "coupling_strength") => 2.0), ) # ###Giving Inputs @@ -176,9 +182,6 @@ plot_trajectory(agent, ("xvol", "posterior")) # You can specify an HGF or an agent in the funciton. -# get prediction of the last state -get_prediction(agent) - #specify another node to get predictions from: get_prediction(agent, "xprob") diff --git a/docs/src/tutorials/classic_JGET.jl b/docs/src/tutorials/classic_JGET.jl index b5f5cec..700313d 100644 --- a/docs/src/tutorials/classic_JGET.jl +++ b/docs/src/tutorials/classic_JGET.jl @@ -2,10 +2,16 @@ using ActionModels, HierarchicalGaussianFiltering using CSV, DataFrames using Plots, StatsPlots using Distributions -path = "docs/src/tutorials/data" + + +# Get the path for the HGF superfolder +hgf_path = dirname(dirname(pathof(HierarchicalGaussianFiltering))) +# Add the path to the data files +data_path = hgf_path * "/docs/src/tutorials/data/" #Load data -data = CSV.read("$path/classic_cannonball_data.csv", DataFrame) +data = CSV.read(data_path * "classic_cannonball_data.csv", DataFrame) +inputs = data[(data.ID.==20).&(data.session.==1), :].outcome #Create HGF hgf = premade_hgf("JGET", verbose = false) @@ -14,16 +20,19 @@ agent = premade_agent("hgf_gaussian_action", hgf) #Set parameters parameters = Dict( "gaussian_action_precision" => 1, + ("u", "input_noise") => 0, + ("x", "initial_mean") => first(inputs) + 2, + ("x", "initial_precision") => 0.001, ("x", "volatility") => -8, - ("xvol", "volatility") => -5, - ("xnoise", "volatility") => -5, - ("xnoise_vol", "volatility") => -5, - ("x", "xvol", "volatility_coupling") => 1, - ("xnoise", "xnoise_vol", "volatility_coupling") => 1, + ("xvol", "volatility") => -8, + ("xnoise", "volatility") => -7, + ("xnoise_vol", "volatility") => -2, + ("x", "xvol", "coupling_strength") => 1, + ("xnoise", "xnoise_vol", "coupling_strength") => 1, ) set_parameters!(agent, parameters) +reset!(agent) -inputs = data[(data.ID.==20).&(data.session.==1), :].outcome 
#Simulate updates and actions actions = give_inputs!(agent, inputs); #Plot belief trajectories @@ -32,41 +41,3 @@ plot_trajectory!(agent, "x") plot_trajectory(agent, "xvol") plot_trajectory(agent, "xnoise") plot_trajectory(agent, "xnoise_vol") - -priors = Dict( - "gaussian_action_precision" => LogNormal(-1, 0.1), - ("x", "volatility") => Normal(-8, 1), -) - -data_subset = data[(data.ID.∈[[20, 21]]).&(data.session.∈[[1, 2]]), :] - -using Distributed -addprocs(6, exeflags = "--project") -@everywhere @eval using HierarchicalGaussianFiltering - -results = fit_model( - agent, - priors, - data_subset, - independent_group_cols = [:ID, :session], - input_cols = [:outcome], - action_cols = [:response], - n_cores = 6, -) - -fitted_model = results[(20, 1)] - -plot_parameter_distribution(fitted_model, priors) - - - - -posterior = get_posteriors(fitted_model) - -set_parameters!(agent, posterior) - -reset!(agent) - -give_inputs!(agent, inputs) - -get_history(agent, ("x", "value_prediction_error")) diff --git a/docs/src/tutorials/classic_binary.jl b/docs/src/tutorials/classic_binary.jl index efd5d9c..80e39a2 100644 --- a/docs/src/tutorials/classic_binary.jl +++ b/docs/src/tutorials/classic_binary.jl @@ -29,8 +29,8 @@ hgf_parameters = Dict( ("xvol", "volatility") => -6.0, ("xvol", "initial_mean") => 1, ("xvol", "initial_precision") => 1, - ("xbin", "xprob", "value_coupling") => 1.0, - ("xprob", "xvol", "volatility_coupling") => 1.0, + ("xbin", "xprob", "coupling_strength") => 1.0, + ("xprob", "xvol", "coupling_strength") => 1.0, ); hgf = premade_hgf("binary_3level", hgf_parameters, verbose = false); @@ -62,8 +62,8 @@ fixed_parameters = Dict( ("xprob", "initial_precision") => 1, ("xvol", "initial_mean") => 1, ("xvol", "initial_precision") => 1, - ("xbin", "xprob", "value_coupling") => 1.0, - ("xprob", "xvol", "volatility_coupling") => 1.0, + ("xbin", "xprob", "coupling_strength") => 1.0, + ("xprob", "xvol", "coupling_strength") => 1.0, ("xvol", "volatility") => -6.0, ); diff --git 
a/docs/src/tutorials/classic_usdchf.jl b/docs/src/tutorials/classic_usdchf.jl index e9a9141..5d38109 100644 --- a/docs/src/tutorials/classic_usdchf.jl +++ b/docs/src/tutorials/classic_usdchf.jl @@ -28,8 +28,7 @@ agent = premade_agent("hgf_gaussian_action", hgf, verbose = false); # Set parameters for parameter recover parameters = Dict( - ("u", "x", "value_coupling") => 1.0, - ("x", "xvol", "volatility_coupling") => 1.0, + ("x", "xvol", "coupling_strength") => 1.0, ("u", "input_noise") => -log(1e4), ("x", "volatility") => -13, ("xvol", "volatility") => -2, @@ -81,8 +80,7 @@ plot_trajectory( #- # Set priors for fitting fixed_parameters = Dict( - ("u", "x", "value_coupling") => 1.0, - ("x", "xvol", "volatility_coupling") => 1.0, + ("x", "xvol", "coupling_strength") => 1.0, ("x", "initial_mean") => 0, ("x", "initial_precision") => 2000, ("xvol", "initial_mean") => 1.0, diff --git a/src/ActionModels_variations/utils/get_parameters.jl b/src/ActionModels_variations/utils/get_parameters.jl index 6c6008a..a467715 100644 --- a/src/ActionModels_variations/utils/get_parameters.jl +++ b/src/ActionModels_variations/utils/get_parameters.jl @@ -49,6 +49,14 @@ function ActionModels.get_parameters(hgf::HGF, target_param::Tuple{String,String #Unpack node name, parent name and param name (node_name, parent_name, param_name) = target_param + #If the specified parameter is not a coupling strength + if !(param_name == "coupling_strength") + throw( + ArgumentError( + "the parameter $target_param is specified as three strings, but is not a coupling strength", + ), + ) + end #If the node does not exist if !(node_name in keys(hgf.all_nodes)) @@ -59,26 +67,15 @@ function ActionModels.get_parameters(hgf::HGF, target_param::Tuple{String,String #Get out the node node = hgf.all_nodes[node_name] - - #If the parameter does not exist in the node - if !(Symbol(param_name) in fieldnames(typeof(node.parameters))) - #Throw an error - throw( - ArgumentError( - "The node $node_name does not have the 
parameter $param_name in its parameters", - ), - ) - end - #Get out the dictionary of coupling strengths - coupling_strengths = getproperty(node.parameters, Symbol(param_name)) + coupling_strengths = getproperty(node.parameters, :coupling_strengths) #If the specified parent is not in the dictionary if !(parent_name in keys(coupling_strengths)) #Throw an error throw( ArgumentError( - "The node $node_name does not have a $param_name to a parent called $parent_name", + "The node $node_name does not have a coupling strength parameter to a parent called $parent_name", ), ) end @@ -193,16 +190,16 @@ function ActionModels.get_parameters(node::AbstractNode) for param_key in fieldnames(typeof(node.parameters)) #If the parameter is a coupling strength - if param_key in (:value_coupling, :volatility_coupling) + if param_key == :coupling_strengths #Get out the dict with coupling strengths - coupling_strengths = getproperty(node.parameters, param_key) + coupling_strengths = node.parameters.coupling_strengths #Go through each parent for parent_name in keys(coupling_strengths) #Add the coupling strength to the ouput dict - parameters[(node.name, parent_name, string(param_key))] = + parameters[(node.name, parent_name, "coupling_strength")] = coupling_strengths[parent_name] end diff --git a/src/ActionModels_variations/utils/get_states.jl b/src/ActionModels_variations/utils/get_states.jl index e37da53..ef302a9 100644 --- a/src/ActionModels_variations/utils/get_states.jl +++ b/src/ActionModels_variations/utils/get_states.jl @@ -43,7 +43,6 @@ function ActionModels.get_states(node::AbstractNode, state_name::String) if state_name in [ "prediction", "prediction_mean", - "predicted_volatility", "prediction_precision", "volatility_weighted_prediction_precision", ] diff --git a/src/ActionModels_variations/utils/reset.jl b/src/ActionModels_variations/utils/reset.jl index 88c826c..a877670 100644 --- a/src/ActionModels_variations/utils/reset.jl +++ b/src/ActionModels_variations/utils/reset.jl 
@@ -8,61 +8,8 @@ function ActionModels.reset!(hgf::HGF) #Go through each node for node in hgf.ordered_nodes.all_nodes - #For categorical state nodes - if node isa CategoricalStateNode - #Set states to vectors of missing - node.states.posterior .= missing - node.states.value_prediction_error .= missing - #Empty prediction state - empty!(node.states.prediction) - - #For binary input nodes - elseif node isa BinaryInputNode - #Set states to missing - node.states.value_prediction_error .= missing - node.states.input_value = missing - - #For continuous state nodes - elseif node isa ContinuousStateNode - #Set posterior to initial belief - node.states.posterior_mean = node.parameters.initial_mean - node.states.posterior_precision = node.parameters.initial_precision - #For other states - for state_name in [ - :value_prediction_error, - :volatility_prediction_error, - :prediction_mean, - :predicted_volatility, - :prediction_precision, - :volatility_weighted_prediction_precision, - ] - #Set the state to missing - setfield!(node.states, state_name, missing) - end - - #For continuous input nodes - elseif node isa ContinuousInputNode - - #For all states except auxiliary prediction precision - for state_name in [ - :input_value, - :value_prediction_error, - :volatility_prediction_error, - :predicted_volatility, - :prediction_precision, - ] - #Set the state to missing - setfield!(node.states, state_name, missing) - end - - #For other nodes - else - #For each state - for state_name in fieldnames(typeof(node.states)) - #Set the state to missing - setfield!(node.states, state_name, missing) - end - end + #Reset its state + reset_state!(node) #For each state in the history for state_name in fieldnames(typeof(node.history)) @@ -70,19 +17,76 @@ function ActionModels.reset!(hgf::HGF) #Empty the history empty!(getfield(node.history, state_name)) - #For states other than prediction states - if !( - state_name in [ - :prediction, - :prediction_mean, - :predicted_volatility, - 
:prediction_precision, - :volatility_weighted_prediction_precision, - ] - ) - #Add the new current state as the first state in the history - push!(getfield(node.history, state_name), getfield(node.states, state_name)) - end + #Add the new current state as the first state in the history + push!(getfield(node.history, state_name), getfield(node.states, state_name)) end end end + + +function reset_state!(node::ContinuousStateNode) + + node.states.posterior_mean = node.parameters.initial_mean + node.states.posterior_precision = node.parameters.initial_precision + + node.states.value_prediction_error = missing + node.states.precision_prediction_error = missing + + node.states.prediction_mean = missing + node.states.predicted_volatility = missing + node.states.prediction_precision = missing + node.states.volatility_weighted_prediction_precision = missing + + return nothing +end + +function reset_state!(node::ContinuousInputNode) + + node.states.input_value = missing + + node.states.value_prediction_error = missing + node.states.precision_prediction_error = missing + + node.states.prediction_mean = missing + node.states.prediction_precision = missing + + return nothing +end + +function reset_state!(node::BinaryStateNode) + + node.states.posterior_mean = missing + node.states.posterior_precision = missing + + node.states.value_prediction_error = missing + + node.states.prediction_mean = missing + node.states.prediction_precision = missing + + return nothing +end + +function reset_state!(node::BinaryInputNode) + + node.states.input_value = missing + node.states.value_prediction_error .= missing + + return nothing +end + +function reset_state!(node::CategoricalStateNode) + + node.states.posterior .= missing + node.states.value_prediction_error .= missing + node.states.prediction .= missing + node.states.parent_predictions .= missing + + return nothing +end + +function reset_state!(node::CategoricalInputNode) + + node.states.input_value = missing + + return nothing +end diff 
--git a/src/ActionModels_variations/utils/set_parameters.jl b/src/ActionModels_variations/utils/set_parameters.jl index 78b0a95..b57c021 100644 --- a/src/ActionModels_variations/utils/set_parameters.jl +++ b/src/ActionModels_variations/utils/set_parameters.jl @@ -61,6 +61,14 @@ function ActionModels.set_parameters!( #Unpack node name, parent name and parameter name (node_name, parent_name, param_name) = target_param + #If the specified parameter is not a coupling strength + if !(param_name == "coupling_strength") + throw( + ArgumentError( + "the parameter $target_param is specified as three strings, but is not a coupling strength", + ), + ) + end #If the node does not exist if !(node_name in keys(hgf.all_nodes)) @@ -71,26 +79,15 @@ function ActionModels.set_parameters!( #Get the child node node = hgf.all_nodes[node_name] - - #If the param does not exist in the node - if !(Symbol(param_name) in fieldnames(typeof(node.parameters))) - #Throw an error - throw( - ArgumentError( - "The node $node_name does not have the parameter $param_name in its parameters", - ), - ) - end - #Get coupling_strengths - coupling_strengths = getfield(node.parameters, Symbol(param_name)) + coupling_strengths = node.parameters.coupling_strengths #If the specified parent is not in the dictionary if !(parent_name in keys(coupling_strengths)) #Throw an error throw( ArgumentError( - "The node $node_name does not have a $param_name to a parent called $parent_name", + "The node $node_name does not have a coupling strength parameter to a parent called $parent_name", ), ) end diff --git a/src/HierarchicalGaussianFiltering.jl b/src/HierarchicalGaussianFiltering.jl index 21c6467..f43a36a 100644 --- a/src/HierarchicalGaussianFiltering.jl +++ b/src/HierarchicalGaussianFiltering.jl @@ -14,6 +14,12 @@ export premade_agent, plot_trajectory! export get_history, get_parameters, get_states, set_parameters!, reset!, give_inputs! 
export EnhancedUpdate, ClassicUpdate +export DriftCoupling, + ObservationCoupling, + CategoryCoupling, + ProbabilityCoupling, + VolatilityCoupling, + NoiseCoupling #Add premade agents to shared dict at initialization function __init__() @@ -26,7 +32,7 @@ function __init__() end #Types for HGFs -include("structs.jl") +include("create_hgf/hgf_structs.jl") #Overloading ActionModels functions include("ActionModels_variations/core/create_premade_agent.jl") @@ -39,6 +45,15 @@ include("ActionModels_variations/utils/give_inputs.jl") include("ActionModels_variations/utils/reset.jl") include("ActionModels_variations/utils/set_parameters.jl") +#Functions for updating the HGF +include("update_hgf/update_hgf.jl") +include("update_hgf/node_updates/continuous_input_node.jl") +include("update_hgf/node_updates/continuous_state_node.jl") +include("update_hgf/node_updates/binary_input_node.jl") +include("update_hgf/node_updates/binary_state_node.jl") +include("update_hgf/node_updates/categorical_input_node.jl") +include("update_hgf/node_updates/categorical_state_node.jl") + #Functions for creating HGFs include("create_hgf/check_hgf.jl") include("create_hgf/init_hgf.jl") @@ -46,15 +61,15 @@ include("create_hgf/create_premade_hgf.jl") #Plotting functions -#Functions for updating HGFs based on inputs -include("update_hgf/update_equations.jl") -include("update_hgf/update_hgf.jl") -include("update_hgf/update_node.jl") - #Functions for premade agents include("premade_models/premade_action_models.jl") include("premade_models/premade_agents.jl") -include("premade_models/premade_hgfs.jl") +include("premade_models/premade_hgfs/premade_binary_2level.jl") +include("premade_models/premade_hgfs/premade_binary_3level.jl") +include("premade_models/premade_hgfs/premade_categorical_3level.jl") +include("premade_models/premade_hgfs/premade_categorical_transitions_3level.jl") +include("premade_models/premade_hgfs/premade_continuous_2level.jl") +include("premade_models/premade_hgfs/premade_JGET.jl") 
#Utility functions for HGFs include("utils/get_prediction.jl") diff --git a/src/create_hgf/check_hgf.jl b/src/create_hgf/check_hgf.jl index 4469dda..960d229 100644 --- a/src/create_hgf/check_hgf.jl +++ b/src/create_hgf/check_hgf.jl @@ -5,7 +5,7 @@ Check whether an HGF has specified correctly. A single node can also be passed. """ function check_hgf(hgf::HGF) - ## Check for duplicate names + ## Check for duplicate names ## #Get node names node_names = getfield.(hgf.ordered_nodes.all_nodes, :name) #If there are any duplicate names @@ -18,9 +18,10 @@ function check_hgf(hgf::HGF) ) end + #If there are shared parameters if length(hgf.shared_parameters) > 0 - ## Check for the same derived parameter in multiple shared parameters + ## Check for the same derived parameter in multiple shared parameters ## #Get all derived parameters derived_parameters = [ @@ -29,7 +30,7 @@ function check_hgf(hgf::HGF) parameter_key in keys(hgf.shared_parameters) ] for parameter in list_of_derived_parameters ] - #check for duplicate names + #Check for duplicate names if length(derived_parameters) > length(unique(derived_parameters)) #Throw an error throw( @@ -39,7 +40,7 @@ function check_hgf(hgf::HGF) ) end - ## Check if the shared parameter is part of own derived parameters + ## Check if the shared parameter is part of own derived parameters ## #Go through each specified shared parameter for (shared_parameter_key, dict_value) in hgf.shared_parameters #check if the name of the shared parameter is part of its own derived parameters @@ -54,7 +55,7 @@ function check_hgf(hgf::HGF) end - ## Check each node + ### Check each node ### for node in hgf.ordered_nodes.all_nodes check_hgf(node) end @@ -66,52 +67,17 @@ function check_hgf(node::ContinuousStateNode) #Extract node name for error messages node_name = node.name - #Disallow binary input node value children - if any(isa.(node.value_children, BinaryInputNode)) + #If there are observation children, disallow noise and volatility children + if 
length(node.edges.observation_children) > 0 && + (length(node.edges.volatility_children) > 0 || length(node.edges.noise_children) > 0) throw( ArgumentError( - "The continuous state node $node_name has a value child which is a binary input node. This is not supported.", + "The state node $node_name has observation children. It is not supported for it to also have volatility or noise children, because it disrupts the update order.", ), ) end - #Disallow binary input node volatility children - if any(isa.(node.volatility_children, BinaryInputNode)) - throw( - ArgumentError( - "The continuous state node $node_name has a volatility child which is a binary input node. This is not supported.", - ), - ) - end - - #Disallow having volatility children if a value child is a continuous inputnode - if any(isa.(node.value_children, ContinuousInputNode)) - if length(node.volatility_children) > 0 - throw( - ArgumentError( - "The state node $node_name has a continuous input node as a value child. It also has volatility children, which disrupts the update order. This is not supported.", - ), - ) - end - end - - #Disallow having the same parent as value parent and volatility parent - if any(node.value_parents .∈ Ref(node.volatility_parents)) - throw( - ArgumentError( - "The state node $node_name has the same parent as value parent and volatility parent. This is not supported.", - ), - ) - end - - #Disallow having the same child as value child and volatility child - if any(node.value_children .∈ Ref(node.volatility_children)) - throw( - ArgumentError( - "The state node $node_name has the same child as value child and volatility child. 
This is not supported.", - ), - ) - end + #Disallow having the same node as multiple types of connections return nothing end @@ -121,38 +87,20 @@ function check_hgf(node::BinaryStateNode) #Extract node name for error messages node_name = node.name - #Require exactly one value parent - if length(node.value_parents) != 1 + #Require exactly one probability parent + if length(node.edges.probability_parents) != 1 throw( ArgumentError( - "The binary state node $node_name does not have exactly one value parent. This is not supported.", + "The binary state node $node_name does not have exactly one probability parent. This is not supported.", ), ) end - #Require exactly one value child - if length(node.value_children) != 1 + #Require exactly one observation child or category child + if length(node.edges.observation_children) + length(node.edges.category_children) != 1 throw( ArgumentError( - "The binary state node $node_name does not have exactly one value child. This is not supported.", - ), - ) - end - - # #Allow only binary input node and categorical state node children - #if any(.!(typeof.(node.value_children) in [BinaryInputNode, CategoricalStateNode])) - # throw( - # ArgumentError( - # "The binary state node $node_name has a child which is neither a binary input node nor a categorical state node. This is not supported.", - # ), - # ) - #end - - #Allow only continuous state node node parents - if any(.!isa.(node.value_parents, ContinuousStateNode)) - throw( - ArgumentError( - "The binary state node $node_name has a parent which is not a continuous state node. This is not supported.", + "The binary state node $node_name does not have exactly one observation child. This is not supported.", ), ) end @@ -166,28 +114,10 @@ function check_hgf(node::CategoricalStateNode) node_name = node.name #Require exactly one value child - if length(node.value_children) != 1 - throw( - ArgumentError( - "The categorical state node $node_name does not have exactly one value child. 
This is not supported.", - ), - ) - end - - #Allow only categorical input node children - if any(.!isa.(node.value_children, CategoricalInputNode)) + if length(node.edges.observation_children) != 1 throw( ArgumentError( - "The categorical state node $node_name has a child which is not a categorical input node. This is not supported.", - ), - ) - end - - #Allow only continuous state node node parents - if any(.!isa.(node.value_parents, BinaryStateNode)) - throw( - ArgumentError( - "The categorical state node $node_name has a parent which is not a binary state node. This is not supported.", + "The categorical state node $node_name does not have exactly one observation child. This is not supported.", ), ) end @@ -201,41 +131,11 @@ function check_hgf(node::ContinuousInputNode) #Extract node name for error messages node_name = node.name - #Allow only continuous state node node parents - if any(.!isa.(node.value_parents, ContinuousStateNode)) - throw( - ArgumentError( - "The continuous input node $node_name has a parent which is not a continuous state node. This is not supported.", - ), - ) - end - - #Require at least one value parent - if length(node.value_parents) == 0 - throw( - ArgumentError( - "The input node $node_name does not have any value parents. This is not supported.", - ), - ) - end - - #Disallow multiple value parents if there are volatility parents - if length(node.volatility_parents) > 0 - if length(node.value_parents) > 1 - throw( - ArgumentError( - "The input node $node_name has multiple value parents and at least one volatility parent. 
This is not supported.", - ), - ) - end - end - - - #Disallow having the same parent as value parent and volatility parent - if any(node.value_parents .∈ Ref(node.volatility_parents)) + #Disallow multiple observation parents if there are noise parents + if length(node.edges.noise_parents) > 0 && length(node.edges.observation_parents) > 1 throw( ArgumentError( - "The state node $node_name has the same parent as value parent and volatility parent. This is not supported.", + "The input node $node_name has multiple value parents and at least one volatility parent. This is not supported.", ), ) end @@ -249,19 +149,10 @@ function check_hgf(node::BinaryInputNode) node_name = node.name #Require exactly one value parent - if length(node.value_parents) != 1 + if length(node.edges.observation_parents) != 1 throw( ArgumentError( - "The binary input node $node_name does not have exactly one value parent. This is not supported.", - ), - ) - end - - #Allow only binary state nodes as parents - if any(.!isa.(node.value_parents, BinaryStateNode)) - throw( - ArgumentError( - "The binary input node $node_name has a parent which is not a binary state node. This is not supported.", + "The binary input node $node_name does not have exactly one observation parent. This is not supported.", ), ) end @@ -275,19 +166,10 @@ function check_hgf(node::CategoricalInputNode) node_name = node.name #Require exactly one value parent - if length(node.value_parents) != 1 - throw( - ArgumentError( - "The categorical input node $node_name does not have exactly one value parent. This is not supported.", - ), - ) - end - - #Allow only categorical state nodes as parents - if any(.!isa.(node.value_parents, CategoricalStateNode)) + if length(node.edges.observation_parents) != 1 throw( ArgumentError( - "The categorical input node $node_name has a parent which is not a categorical state node. This is not supported.", + "The categorical input node $node_name does not have exactly one observation parent. 
This is not supported.", ), ) end diff --git a/src/structs.jl b/src/create_hgf/hgf_structs.jl similarity index 53% rename from src/structs.jl rename to src/create_hgf/hgf_structs.jl index 1e7b776..59fa088 100644 --- a/src/structs.jl +++ b/src/create_hgf/hgf_structs.jl @@ -1,6 +1,6 @@ -################################ -######## Abstract Types ######## -################################ +##################################### +######## Abstract node types ######## +##################################### #Top-level node type abstract type AbstractNode end @@ -9,6 +9,18 @@ abstract type AbstractNode end abstract type AbstractStateNode <: AbstractNode end abstract type AbstractInputNode <: AbstractNode end +#Variable type subtypes +abstract type AbstractContinuousStateNode <: AbstractStateNode end +abstract type AbstractContinuousInputNode <: AbstractInputNode end +abstract type AbstractBinaryStateNode <: AbstractStateNode end +abstract type AbstractBinaryInputNode <: AbstractInputNode end +abstract type AbstractCategoricalStateNode <: AbstractStateNode end +abstract type AbstractCategoricalInputNode <: AbstractInputNode end + +################################## +######## HGF update types ######## +################################## + #Supertype for HGF update types abstract type HGFUpdateType end @@ -16,20 +28,87 @@ abstract type HGFUpdateType end struct ClassicUpdate <: HGFUpdateType end struct EnhancedUpdate <: HGFUpdateType end +################################ +######## Coupling types ######## +################################ + +#Supertypes for coupling types +abstract type CouplingType end +abstract type ValueCoupling <: CouplingType end +abstract type PrecisionCoupling <: CouplingType end + +#Concrete value coupling types +Base.@kwdef mutable struct DriftCoupling <: ValueCoupling + strength::Union{Nothing,Real} = nothing +end +Base.@kwdef mutable struct ProbabilityCoupling <: ValueCoupling + strength::Union{Nothing,Real} = nothing +end +Base.@kwdef mutable struct 
CategoryCoupling <: ValueCoupling end +Base.@kwdef mutable struct ObservationCoupling <: ValueCoupling end + +#Concrete precision coupling types +Base.@kwdef mutable struct VolatilityCoupling <: PrecisionCoupling + strength::Union{Nothing,Real} = nothing +end +Base.@kwdef mutable struct NoiseCoupling <: PrecisionCoupling + strength::Union{Nothing,Real} = nothing +end + +############################ +######## HGF Struct ######## +############################ +""" +""" +Base.@kwdef mutable struct OrderedNodes + all_nodes::Vector{AbstractNode} = [] + input_nodes::Vector{AbstractInputNode} = [] + all_state_nodes::Vector{AbstractStateNode} = [] + early_update_state_nodes::Vector{AbstractStateNode} = [] + late_update_state_nodes::Vector{AbstractStateNode} = [] +end + +""" +""" +Base.@kwdef mutable struct HGF + all_nodes::Dict{String,AbstractNode} + input_nodes::Dict{String,AbstractInputNode} + state_nodes::Dict{String,AbstractStateNode} + ordered_nodes::OrderedNodes = OrderedNodes() + shared_parameters::Dict = Dict() +end + + + ####################################### ######## Continuous State Node ######## ####################################### +Base.@kwdef mutable struct ContinuousStateNodeEdges + #Possible parent types + drift_parents::Vector{<:AbstractContinuousStateNode} = Vector{ContinuousStateNode}() + volatility_parents::Vector{<:AbstractContinuousStateNode} = + Vector{ContinuousStateNode}() + + #Possible children types + drift_children::Vector{<:AbstractContinuousStateNode} = Vector{ContinuousStateNode}() + volatility_children::Vector{<:AbstractContinuousStateNode} = + Vector{ContinuousStateNode}() + probability_children::Vector{<:AbstractBinaryStateNode} = Vector{BinaryStateNode}() + observation_children::Vector{<:AbstractContinuousInputNode} = + Vector{ContinuousInputNode}() + noise_children::Vector{<:AbstractContinuousInputNode} = Vector{ContinuousInputNode}() +end + """ -Configuration of continuous state nodes' parameters +Configuration of continuous state 
nodes' parameters """ Base.@kwdef mutable struct ContinuousStateNodeParameters volatility::Real = 0 drift::Real = 0 autoregression_target::Real = 0 autoregression_strength::Real = 0 - value_coupling::Dict{String,Real} = Dict{String,Real}() - volatility_coupling::Dict{String,Real} = Dict{String,Real}() + coupling_strengths::Dict{String,Real} = Dict{String,Real}() initial_mean::Real = 0 initial_precision::Real = 0 end @@ -41,7 +120,7 @@ Base.@kwdef mutable struct ContinuousStateNodeState posterior_mean::Union{Real} = 0 posterior_precision::Union{Real} = 1 value_prediction_error::Union{Real,Missing} = missing - volatility_prediction_error::Union{Real,Missing} = missing + precision_prediction_error::Union{Real,Missing} = missing prediction_mean::Union{Real,Missing} = missing predicted_volatility::Union{Real,Missing} = missing prediction_precision::Union{Real,Missing} = missing @@ -54,22 +133,19 @@ Configuration of continuous state node history Base.@kwdef mutable struct ContinuousStateNodeHistory posterior_mean::Vector{Real} = [] posterior_precision::Vector{Real} = [] - value_prediction_error::Vector{Union{Real,Missing}} = [missing] - volatility_prediction_error::Vector{Union{Real,Missing}} = [missing] - prediction_mean::Vector{Real} = [] - predicted_volatility::Vector{Real} = [] - prediction_precision::Vector{Real} = [] - volatility_weighted_prediction_precision::Vector{Real} = [] + value_prediction_error::Vector{Union{Real,Missing}} = [] + precision_prediction_error::Vector{Union{Real,Missing}} = [] + prediction_mean::Vector{Union{Real,Missing}} = [] + predicted_volatility::Vector{Union{Real,Missing}} = [] + prediction_precision::Vector{Union{Real,Missing}} = [] + volatility_weighted_prediction_precision::Vector{Union{Real,Missing}} = [] end """ """ -Base.@kwdef mutable struct ContinuousStateNode <: AbstractStateNode +Base.@kwdef mutable struct ContinuousStateNode <: AbstractContinuousStateNode name::String - value_parents::Vector{AbstractStateNode} = [] - 
volatility_parents::Vector{AbstractStateNode} = [] - value_children::Vector{AbstractNode} = [] - volatility_children::Vector{AbstractNode} = [] + edges::ContinuousStateNodeEdges = ContinuousStateNodeEdges() parameters::ContinuousStateNodeParameters = ContinuousStateNodeParameters() states::ContinuousStateNodeState = ContinuousStateNodeState() history::ContinuousStateNodeHistory = ContinuousStateNodeHistory() @@ -77,148 +153,121 @@ Base.@kwdef mutable struct ContinuousStateNode <: AbstractStateNode end -################################### -######## Binary State Node ######## -################################### +####################################### +######## Continuous Input Node ######## +####################################### +Base.@kwdef mutable struct ContinuousInputNodeEdges + #Possible parents + observation_parents::Vector{<:AbstractContinuousStateNode} = + Vector{ContinuousStateNode}() + noise_parents::Vector{<:AbstractContinuousStateNode} = Vector{ContinuousStateNode}() +end """ - Configure parameters of binary state node +Configuration of continuous input node parameters """ -Base.@kwdef mutable struct BinaryStateNodeParameters - value_coupling::Dict{String,Real} = Dict{String,Real}() +Base.@kwdef mutable struct ContinuousInputNodeParameters + input_noise::Real = 0 + coupling_strengths::Dict{String,Real} = Dict{String,Real}() end """ -Overview of the states of the binary state node +Configuration of continuous input node states """ -Base.@kwdef mutable struct BinaryStateNodeState - posterior_mean::Union{Real,Missing} = missing - posterior_precision::Union{Real,Missing} = missing +Base.@kwdef mutable struct ContinuousInputNodeState + input_value::Union{Real,Missing} = missing value_prediction_error::Union{Real,Missing} = missing + precision_prediction_error::Union{Real,Missing} = missing prediction_mean::Union{Real,Missing} = missing prediction_precision::Union{Real,Missing} = missing end """ -Overview of the history of the binary state node 
+Configuration of continuous input node history """ -Base.@kwdef mutable struct BinaryStateNodeHistory - posterior_mean::Vector{Union{Real,Missing}} = [] - posterior_precision::Vector{Union{Real,Missing}} = [] - value_prediction_error::Vector{Union{Real,Missing}} = [missing] - prediction_mean::Vector{Real} = [] - prediction_precision::Vector{Real} = [] +Base.@kwdef mutable struct ContinuousInputNodeHistory + input_value::Vector{Union{Real,Missing}} = [] + value_prediction_error::Vector{Union{Real,Missing}} = [] + precision_prediction_error::Vector{Union{Real,Missing}} = [] + prediction_mean::Vector{Union{Real,Missing}} = [] + prediction_precision::Vector{Union{Real,Missing}} = [] end """ -Overview of edge posibilities """ -Base.@kwdef mutable struct BinaryStateNode <: AbstractStateNode +Base.@kwdef mutable struct ContinuousInputNode <: AbstractContinuousInputNode name::String - value_parents::Vector{AbstractStateNode} = [] - volatility_parents::Vector{Nothing} = [] - value_children::Vector{AbstractNode} = [] - volatility_children::Vector{Nothing} = [] - parameters::BinaryStateNodeParameters = BinaryStateNodeParameters() - states::BinaryStateNodeState = BinaryStateNodeState() - history::BinaryStateNodeHistory = BinaryStateNodeHistory() - update_type::HGFUpdateType = ClassicUpdate() -end - - -######################################## -######## Categorical State Node ######## -######################################## -Base.@kwdef mutable struct CategoricalStateNodeParameters end - -""" -Configuration of states in categorical state node -""" -Base.@kwdef mutable struct CategoricalStateNodeState - posterior::Vector{Union{Real,Missing}} = [] - value_prediction_error::Vector{Union{Real,Missing}} = [] - prediction::Vector{Real} = [] - parent_predictions::Vector{Real} = [] + edges::ContinuousInputNodeEdges = ContinuousInputNodeEdges() + parameters::ContinuousInputNodeParameters = ContinuousInputNodeParameters() + states::ContinuousInputNodeState = ContinuousInputNodeState() 
+ history::ContinuousInputNodeHistory = ContinuousInputNodeHistory() end -""" -Configuration of history in categorical state node -""" -Base.@kwdef mutable struct CategoricalStateNodeHistory - posterior::Vector{Vector{Union{Real,Missing}}} = [] - value_prediction_error::Vector{Vector{Union{Real,Missing}}} = [] - prediction::Vector{Vector{Real}} = [] - parent_predictions::Vector{Vector{Real}} = [] -end -""" -Configuration of edges in categorical state node -""" -Base.@kwdef mutable struct CategoricalStateNode <: AbstractStateNode - name::String - value_parents::Vector{AbstractStateNode} = [] - volatility_parents::Vector{Nothing} = [] - value_children::Vector{AbstractNode} = [] - volatility_children::Vector{Nothing} = [] - category_parent_order::Vector{String} = [] - parameters::CategoricalStateNodeParameters = CategoricalStateNodeParameters() - states::CategoricalStateNodeState = CategoricalStateNodeState() - history::CategoricalStateNodeHistory = CategoricalStateNodeHistory() - update_type::HGFUpdateType = ClassicUpdate() +################################### +######## Binary State Node ######## +################################### +Base.@kwdef mutable struct BinaryStateNodeEdges + #Possible parent types + probability_parents::Vector{<:AbstractContinuousStateNode} = + Vector{ContinuousStateNode}() + + #Possible children types + category_children::Vector{<:AbstractCategoricalStateNode} = + Vector{CategoricalStateNode}() + observation_children::Vector{<:AbstractBinaryInputNode} = Vector{BinaryInputNode}() end - -####################################### -######## Continuous Input Node ######## -####################################### """ -Configuration of continuous input node parameters + Configure parameters of binary state node """ -Base.@kwdef mutable struct ContinuousInputNodeParameters - input_noise::Real = 0 - value_coupling::Dict{String,Real} = Dict{String,Real}() - volatility_coupling::Dict{String,Real} = Dict{String,Real}() +Base.@kwdef mutable struct 
BinaryStateNodeParameters + coupling_strengths::Dict{String,Real} = Dict{String,Real}() end """ -Configuration of continuous input node states +Overview of the states of the binary state node """ -Base.@kwdef mutable struct ContinuousInputNodeState - input_value::Union{Real,Missing} = missing +Base.@kwdef mutable struct BinaryStateNodeState + posterior_mean::Union{Real,Missing} = missing + posterior_precision::Union{Real,Missing} = missing value_prediction_error::Union{Real,Missing} = missing - volatility_prediction_error::Union{Real,Missing} = missing - predicted_volatility::Union{Real,Missing} = missing + prediction_mean::Union{Real,Missing} = missing prediction_precision::Union{Real,Missing} = missing - volatility_weighted_prediction_precision::Union{Real} = 1 end """ -Configuration of continuous input node history +Overview of the history of the binary state node """ -Base.@kwdef mutable struct ContinuousInputNodeHistory - input_value::Vector{Union{Real,Missing}} = [missing] - value_prediction_error::Vector{Union{Real,Missing}} = [missing] - volatility_prediction_error::Vector{Union{Real,Missing}} = [missing] - predicted_volatility::Vector{Real} = [] - prediction_precision::Vector{Real} = [] +Base.@kwdef mutable struct BinaryStateNodeHistory + posterior_mean::Vector{Union{Real,Missing}} = [] + posterior_precision::Vector{Union{Real,Missing}} = [] + value_prediction_error::Vector{Union{Real,Missing}} = [] + prediction_mean::Vector{Union{Real,Missing}} = [] + prediction_precision::Vector{Union{Real,Missing}} = [] end """ +Overview of edge posibilities """ -Base.@kwdef mutable struct ContinuousInputNode <: AbstractInputNode +Base.@kwdef mutable struct BinaryStateNode <: AbstractBinaryStateNode name::String - value_parents::Vector{AbstractStateNode} = [] - volatility_parents::Vector{AbstractStateNode} = [] - parameters::ContinuousInputNodeParameters = ContinuousInputNodeParameters() - states::ContinuousInputNodeState = ContinuousInputNodeState() - 
history::ContinuousInputNodeHistory = ContinuousInputNodeHistory() + edges::BinaryStateNodeEdges = BinaryStateNodeEdges() + parameters::BinaryStateNodeParameters = BinaryStateNodeParameters() + states::BinaryStateNodeState = BinaryStateNodeState() + history::BinaryStateNodeHistory = BinaryStateNodeHistory() + update_type::HGFUpdateType = ClassicUpdate() end + ################################### ######## Binary Input Node ######## ################################### +Base.@kwdef mutable struct BinaryInputNodeEdges + observation_parents::Vector{<:AbstractBinaryStateNode} = Vector{BinaryStateNode}() +end """ Configuration of parameters in binary input node. Default category mean set to [0,1] @@ -226,6 +275,7 @@ Configuration of parameters in binary input node. Default category mean set to [ Base.@kwdef mutable struct BinaryInputNodeParameters category_means::Vector{Union{Real}} = [0, 1] input_precision::Real = Inf + coupling_strengths::Dict{String,Real} = Dict{String,Real}() end """ @@ -246,10 +296,9 @@ end """ """ -Base.@kwdef mutable struct BinaryInputNode <: AbstractInputNode +Base.@kwdef mutable struct BinaryInputNode <: AbstractBinaryInputNode name::String - value_parents::Vector{AbstractStateNode} = [] - volatility_parents::Vector{Nothing} = [] + edges::BinaryInputNodeEdges = BinaryInputNodeEdges() parameters::BinaryInputNodeParameters = BinaryInputNodeParameters() states::BinaryInputNodeState = BinaryInputNodeState() history::BinaryInputNodeHistory = BinaryInputNodeHistory() @@ -257,12 +306,70 @@ end +######################################## +######## Categorical State Node ######## +######################################## +Base.@kwdef mutable struct CategoricalStateNodeEdges + #Possible parents + category_parents::Vector{<:AbstractBinaryStateNode} = Vector{BinaryStateNode}() + #The order of the category parents + category_parent_order::Vector{String} = [] + + #Possible children + observation_children::Vector{<:AbstractCategoricalInputNode} = + 
Vector{CategoricalInputNode}() +end + +Base.@kwdef mutable struct CategoricalStateNodeParameters + coupling_strengths::Dict{String,Real} = Dict{String,Real}() +end + +""" +Configuration of states in categorical state node +""" +Base.@kwdef mutable struct CategoricalStateNodeState + posterior::Vector{Union{Real,Missing}} = [] + value_prediction_error::Vector{Union{Real,Missing}} = [] + prediction::Vector{Union{Real,Missing}} = [] + parent_predictions::Vector{Union{Real,Missing}} = [] +end + +""" +Configuration of history in categorical state node +""" +Base.@kwdef mutable struct CategoricalStateNodeHistory + posterior::Vector{Vector{Union{Real,Missing}}} = [] + value_prediction_error::Vector{Vector{Union{Real,Missing}}} = [] + prediction::Vector{Vector{Union{Real,Missing}}} = [] + parent_predictions::Vector{Vector{Union{Real,Missing}}} = [] +end + +""" +Configuration of edges in categorical state node +""" +Base.@kwdef mutable struct CategoricalStateNode <: AbstractCategoricalStateNode + name::String + edges::CategoricalStateNodeEdges = CategoricalStateNodeEdges() + parameters::CategoricalStateNodeParameters = CategoricalStateNodeParameters() + states::CategoricalStateNodeState = CategoricalStateNodeState() + history::CategoricalStateNodeHistory = CategoricalStateNodeHistory() + update_type::HGFUpdateType = ClassicUpdate() +end + + + ######################################## ######## Categorical Input Node ######## ######################################## +Base.@kwdef mutable struct CategoricalInputNodeEdges + observation_parents::Vector{<:AbstractCategoricalStateNode} = + Vector{CategoricalStateNode}() +end -Base.@kwdef mutable struct CategoricalInputNodeParameters end +Base.@kwdef mutable struct CategoricalInputNodeParameters + coupling_strengths::Dict{String,Real} = Dict{String,Real}() +end """ Configuration of states of categorical input node @@ -280,35 +387,10 @@ end """ """ -Base.@kwdef mutable struct CategoricalInputNode <: AbstractInputNode +Base.@kwdef 
mutable struct CategoricalInputNode <: AbstractCategoricalInputNode name::String - value_parents::Vector{AbstractStateNode} = [] - volatility_parents::Vector{Nothing} = [] + edges::CategoricalInputNodeEdges = CategoricalInputNodeEdges() parameters::CategoricalInputNodeParameters = CategoricalInputNodeParameters() states::CategoricalInputNodeState = CategoricalInputNodeState() history::CategoricalInputNodeHistory = CategoricalInputNodeHistory() end - - -############################ -######## HGF Struct ######## -############################ -""" -""" -Base.@kwdef mutable struct OrderedNodes - all_nodes::Vector{AbstractNode} = [] - input_nodes::Vector{AbstractInputNode} = [] - all_state_nodes::Vector{AbstractStateNode} = [] - early_update_state_nodes::Vector{AbstractStateNode} = [] - late_update_state_nodes::Vector{AbstractStateNode} = [] -end - -""" -""" -Base.@kwdef mutable struct HGF - all_nodes::Dict{String,AbstractNode} - input_nodes::Dict{String,AbstractInputNode} - state_nodes::Dict{String,AbstractStateNode} - ordered_nodes::OrderedNodes = OrderedNodes() - shared_parameters::Dict = Dict() -end diff --git a/src/create_hgf/init_hgf.jl b/src/create_hgf/init_hgf.jl index b9bdc80..f05c2f7 100644 --- a/src/create_hgf/init_hgf.jl +++ b/src/create_hgf/init_hgf.jl @@ -24,6 +24,15 @@ Edge information includes 'child', as well as 'value_parents' and/or 'volatility ```julia ##Create a simple 2level continuous HGF## +#Set defaults for nodes +node_defaults = Dict( + "volatility" => -2, + "input_noise" => -2, + "initial_mean" => 0, + "initial_precision" => 1, + "coupling_strength" => 1, +) + #List of input nodes input_nodes = Dict( "name" => "u", @@ -50,68 +59,11 @@ state_nodes = [ ] #List of child-parent relations -edges = [ - Dict( - "child" => "u", - "value_parents" => ("x", 1), - ), - Dict( - "child" => "x", - "volatility_parents" => ("xvol", 1), - ), -] - -#Initialize the HGF -hgf = init_hgf( - input_nodes = input_nodes, - state_nodes = state_nodes, - edges = edges, -) 
-
-##Create a more complicated HGF without specifying information for each node##
-
-#Set defaults for all nodes
-node_defaults = Dict(
-    "volatility" => -2,
-    "input_noise" => -2,
-    "initial_mean" => 0,
-    "initial_precision" => 1,
-    "value_coupling" => 1,
-    "volatility_coupling" => 1,
+edges = Dict(
+    ("u", "x") => ObservationCoupling(),
+    ("x", "xvol") => VolatilityCoupling(),
 )
 
-input_nodes = [
-    "u1",
-    "u2",
-]
-
-state_nodes = [
-    "x",
-    "xvol",
-    "x3",
-    "x4",
-]
-
-edges = [
-    Dict(
-        "child" => "u1",
-        "value_parents" => ["x", "xvol"],
-        "volatility_parents" => "x3"
-    ),
-    Dict(
-        "child" => "u2",
-        "value_parents" => ["x"],
-    ),
-    Dict(
-        "child" => "x",
-        "volatility_parents" => "x4",
-    ),
-    Dict(
-        "child" => "xvol",
-        "volatility_parents" => "x4",
-    ),
-]
-
 hgf = init_hgf(
     input_nodes = input_nodes,
     state_nodes = state_nodes,
@@ -123,7 +75,7 @@ hgf = init_hgf(
 function init_hgf(;
     input_nodes::Union{String,Dict,Vector},
     state_nodes::Union{String,Dict,Vector},
-    edges::Union{Vector{<:Dict},Dict},
+    edges::Dict{Tuple{String,String},<:CouplingType},
     shared_parameters::Dict = Dict(),
     node_defaults::Dict = Dict(),
     update_type::HGFUpdateType = EnhancedUpdate(),
@@ -139,11 +91,10 @@ function init_hgf(;
         "autoregression_strength" => 0,
         "initial_mean" => 0,
         "initial_precision" => 1,
-        "value_coupling" => 1,
-        "volatility_coupling" => 1,
+        "coupling_strength" => 1,
         "category_means" => [0, 1],
         "input_precision" => Inf,
-        "input_noise" => -2
+        "input_noise" => -2,
     )
 
     #If verbose
@@ -220,101 +171,18 @@ function init_hgf(;
 
     ### Set up edges ###
+    #For each specified edge
+    for (node_names, coupling_type) in edges
 
-    #If user has only specified a single edge and not in a vector
-    if edges isa Dict
-        #Put it in a vector
-        edges = [edges]
-    end
-
-    #For each child
-    for edge in edges
-
-        #Find corresponding child node
-        child_node = all_nodes_dict[edge["child"]]
-
-        #Add empty vectors for when the user has not specified any
-        edge = merge(Dict("value_parents" => [], 
"volatility_parents" => []), edge) - - #If there are any value parents - if length(edge["value_parents"]) > 0 - #Get out value parents - value_parents = edge["value_parents"] + #Extract the child and parent names + child_name, parent_name = node_names - #If the value parents were not specified as a vector - if .!isa(value_parents, Vector) - #Make it into one - value_parents = [value_parents] - end - - #For each value parent - for parent_info in value_parents - - #If only the node's name was specified as a string - if parent_info isa String - #Make it a tuple, and give it the default coupling strength - parent_info = (parent_info, node_defaults["value_coupling"]) - end - - #Find the corresponding parent - parent_node = all_nodes_dict[parent_info[1]] - - #Add the parent to the child node - push!(child_node.value_parents, parent_node) - - #Add the child node to the parent node - push!(parent_node.value_children, child_node) - - #Except for binary input nodes and categorical nodes - if !( - typeof(child_node) in - [BinaryInputNode, CategoricalInputNode, CategoricalStateNode] - ) - #Add coupling strength to child node - child_node.parameters.value_coupling[parent_node.name] = parent_info[2] - end - end - end - - #If there are any volatility parents - if length(edge["volatility_parents"]) > 0 - #Get out volatility parents - volatility_parents = edge["volatility_parents"] - - #If the volatility parents were not specified as a vector - if .!isa(volatility_parents, Vector) - #Make it into one - volatility_parents = [volatility_parents] - end + #Find corresponding child node and parent node + child_node = all_nodes_dict[child_name] + parent_node = all_nodes_dict[parent_name] - #For each volatility parent - for parent_info in volatility_parents - - #If only the node's name was specified as a string - if parent_info isa String - #Make it a tuple, and give it the default coupling strength - parent_info = (parent_info, node_defaults["volatility_coupling"]) - end - - #Find the 
corresponding parent - parent_node = all_nodes_dict[parent_info[1]] - - #Add the parent to the child node - push!(child_node.volatility_parents, parent_node) - - #Add the child node to the parent node - push!(parent_node.volatility_children, child_node) - - #Add coupling strength to child node - child_node.parameters.volatility_coupling[parent_node.name] = parent_info[2] - - #If the enhanced HGF update is used - if update_type isa EnhancedUpdate && parent_node isa ContinuousStateNode - #Set the node to use the enhanced HGF update - parent_node.update_type = update_type - end - end - end + #Create the edge + init_edge!(child_node, parent_node, coupling_type, update_type, node_defaults) end ## Determine Update order ## @@ -323,7 +191,7 @@ function init_hgf(; #If verbose if verbose - #Warn that automaitc update order is used + #Warn that automatic update order is used @warn "No update order specified. Using the order in which nodes were inputted" end @@ -377,7 +245,7 @@ function init_hgf(; push!(ordered_nodes.all_state_nodes, node) #If any of the nodes' value children are continuous input nodes - if any(isa.(node.value_children, ContinuousInputNode)) + if any(isa.(node.edges.observation_children, ContinuousInputNode)) #Add it to the early update list push!(ordered_nodes.early_update_state_nodes, node) else @@ -422,44 +290,42 @@ function init_hgf(; ### Check that the HGF has been specified properly ### check_hgf(hgf) - ### Initialize node history ### + ### Initialize states and history ### #For each state node for node in hgf.ordered_nodes.all_state_nodes - - #For categorical state nodes + #If it is a categorical state node if node isa CategoricalStateNode - #Make vector of order of category parents - for parent in node.value_parents - push!(node.category_parent_order, parent.name) + #Make vector with ordered category parents + for parent in node.edges.category_parents + push!(node.edges.category_parent_order, parent.name) end - #Set posterior to vector of zeros equal 
to the number of categories + #Set posterior to vector of missing with length equal to the number of categories node.states.posterior = - Vector{Union{Real,Missing}}(missing, length(node.value_parents)) - push!(node.history.posterior, node.states.posterior) + Vector{Union{Real,Missing}}(missing, length(node.edges.category_parents)) - #Set posterior to vector of missing equal to the number of categories + #Set posterior to vector of missing with length equal to the number of categories node.states.value_prediction_error = - Vector{Union{Real,Missing}}(missing, length(node.value_parents)) - push!(node.history.value_prediction_error, node.states.value_prediction_error) + Vector{Union{Real,Missing}}(missing, length(node.edges.category_parents)) - #Set parent predictions form last timestep to be agnostic - node.states.parent_predictions = - repeat([1 / length(node.value_parents)], length(node.value_parents)) - - #Set predictions form last timestep to be agnostic - node.states.prediction = - repeat([1 / length(node.value_parents)], length(node.value_parents)) + #Set parent predictions from last timestep to be agnostic + node.states.parent_predictions = repeat( + [1 / length(node.edges.category_parents)], + length(node.edges.category_parents), + ) - #For other nodes - else - #Save posterior to node history - push!(node.history.posterior_mean, node.states.posterior_mean) - push!(node.history.posterior_precision, node.states.posterior_precision) + #Set predictions from last timestep to be agnostic + node.states.prediction = repeat( + [1 / length(node.edges.category_parents)], + length(node.edges.category_parents), + ) end end + #Reset the hgf, initializing states and history + reset!(hgf) + return hgf end @@ -565,3 +431,68 @@ function init_node(input_or_state_node, node_defaults, node_info) return node end + +### Function for initializing and edge ### +function init_edge!( + child_node::AbstractNode, + parent_node::AbstractStateNode, + coupling_type::CouplingType, + 
update_type::HGFUpdateType,
+    node_defaults::Dict,
+)
+
+    #Get correct field for storing parents
+    if coupling_type isa DriftCoupling
+        parents_field = :drift_parents
+        children_field = :drift_children
+
+    elseif coupling_type isa ObservationCoupling
+        parents_field = :observation_parents
+        children_field = :observation_children
+
+    elseif coupling_type isa CategoryCoupling
+        parents_field = :category_parents
+        children_field = :category_children
+
+    elseif coupling_type isa ProbabilityCoupling
+        parents_field = :probability_parents
+        children_field = :probability_children
+
+    elseif coupling_type isa VolatilityCoupling
+        parents_field = :volatility_parents
+        children_field = :volatility_children
+
+    elseif coupling_type isa NoiseCoupling
+        parents_field = :noise_parents
+        children_field = :noise_children
+    end
+
+    #Add the parent to the child node
+    push!(getfield(child_node.edges, parents_field), parent_node)
+
+    #Add the child node to the parent node
+    push!(getfield(parent_node.edges, children_field), child_node)
+
+    #If the coupling type has a coupling strength
+    if hasproperty(coupling_type, :strength)
+        #If the user has not specified a coupling strength
+        if isnothing(coupling_type.strength)
+            #Use the default coupling strength
+            coupling_strength = node_defaults["coupling_strength"]
+
+        #Otherwise
+        else
+            #Use the specified coupling strength
+            coupling_strength = coupling_type.strength
+        end
+
+        #And set it as a parameter for the child
+        child_node.parameters.coupling_strengths[parent_node.name] = coupling_strength
+    end
+
+    #If the enhanced HGF update is used, and if it is a precision coupling (volatility or noise)
+    if update_type isa EnhancedUpdate && coupling_type isa PrecisionCoupling
+        #Set the node to use the enhanced HGF update
+        parent_node.update_type = update_type
+    end
+end
diff --git a/src/premade_models/premade_hgfs.jl b/src/premade_models/premade_hgfs.jl
deleted file mode 100644
index 6c1f8fd..0000000
--- 
a/src/premade_models/premade_hgfs.jl +++ /dev/null @@ -1,945 +0,0 @@ -""" - premade_continuous_2level(config::Dict; verbose::Bool = true) - -The standard 2 level continuous HGF, which filters a continuous input. -It has a continous input node u, with a single value parent x, which in turn has a single volatility parent xvol. - -# Config defaults: - - ("u", "input_noise"): -2 - - ("x", "volatility"): -2 - - ("xvol", "volatility"): -2 - - ("u", "x", "value_coupling"): 1 - - ("x", "xvol", "volatility_coupling"): 1 - - ("x", "initial_mean"): 0 - - ("x", "initial_precision"): 1 - - ("xvol", "initial_mean"): 0 - - ("xvol", "initial_precision"): 1 -""" -function premade_continuous_2level(config::Dict; verbose::Bool = true) - - #Defaults - spec_defaults = Dict( - ("u", "input_noise") => -2, - - ("x", "volatility") => -2, - ("x", "drift") => 0, - ("x", "autoregression_target") => 0, - ("x", "autoregression_strength") => 0, - ("x", "initial_mean") => 0, - ("x", "initial_precision") => 1, - - ("xvol", "volatility") => -2, - ("xvol", "drift") => 0, - ("xvol", "autoregression_target") => 0, - ("xvol", "autoregression_strength") => 0, - ("xvol", "initial_mean") => 0, - ("xvol", "initial_precision") => 1, - - ("u", "x", "value_coupling") => 1, - ("x", "xvol", "volatility_coupling") => 1, - - "update_type" => EnhancedUpdate(), - ) - - #Warn the user about used defaults and misspecified keys - if verbose - warn_premade_defaults(spec_defaults, config) - end - - #Merge to overwrite defaults - config = merge(spec_defaults, config) - - - #List of input nodes to create - input_nodes = Dict( - "name" => "u", - "type" => "continuous", - "input_noise" => config[("u", "input_noise")], - ) - - #List of state nodes to create - state_nodes = [ - Dict( - "name" => "x", - "type" => "continuous", - "volatility" => config[("x", "volatility")], - "drift" => config[("x", "drift")], - "autoregression_target" => config[("x", "autoregression_target")], - "autoregression_strength" => config[("x", 
"autoregression_strength")], - "initial_mean" => config[("x", "initial_mean")], - "initial_precision" => config[("x", "initial_precision")], - ), - Dict( - "name" => "xvol", - "type" => "continuous", - "volatility" => config[("xvol", "volatility")], - "drift" => config[("xvol", "drift")], - "autoregression_target" => config[("xvol", "autoregression_target")], - "autoregression_strength" => config[("xvol", "autoregression_strength")], - "initial_mean" => config[("xvol", "initial_mean")], - "initial_precision" => config[("xvol", "initial_precision")], - ), - ] - - #List of child-parent relations - edges = [ - Dict( - "child" => "u", - "value_parents" => ("x", config[("u", "x", "value_coupling")]), - ), - Dict( - "child" => "x", - "volatility_parents" => ("xvol", config[("x", "xvol", "volatility_coupling")]), - ), - ] - - #Initialize the HGF - init_hgf( - input_nodes = input_nodes, - state_nodes = state_nodes, - edges = edges, - verbose = false, - update_type = config["update_type"], - ) -end - - -""" -premade_JGET(config::Dict; verbose::Bool = true) - -The HGF used in the JGET model. It has a single continuous input node u, with a value parent x, and a volatility parent xnoise. x has volatility parent xvol, and xnoise has a volatility parent xnoise_vol. 
- -# Config defaults: - - ("u", "input_noise"): -2 - - ("x", "volatility"): -2 - - ("xvol", "volatility"): -2 - - ("xnoise", "volatility"): -2 - - ("xnoise_vol", "volatility"): -2 - - ("u", "x", "value_coupling"): 1 - - ("u", "xnoise", "value_coupling"): 1 - - ("x", "xvol", "volatility_coupling"): 1 - - ("xnoise", "xnoise_vol", "volatility_coupling"): 1 - - ("x", "initial_mean"): 0 - - ("x", "initial_precision"): 1 - - ("xvol", "initial_mean"): 0 - - ("xvol", "initial_precision"): 1 - - ("xnoise", "initial_mean"): 0 - - ("xnoise", "initial_precision"): 1 - - ("xnoise_vol", "initial_mean"): 0 - - ("xnoise_vol", "initial_precision"): 1 -""" -function premade_JGET(config::Dict; verbose::Bool = true) - - #Defaults - spec_defaults = Dict( - ("u", "input_noise") => -2, - - ("x", "volatility") => -2, - ("x", "drift") => 0, - ("x", "autoregression_target") => 0, - ("x", "autoregression_strength") => 0, - ("x", "initial_mean") => 0, - ("x", "initial_precision") => 1, - - ("xvol", "volatility") => -2, - ("xvol", "drift") => 0, - ("xvol", "autoregression_target") => 0, - ("xvol", "autoregression_strength") => 0, - ("xvol", "initial_mean") => 0, - ("xvol", "initial_precision") => 1, - - ("xnoise", "volatility") => -2, - ("xnoise", "drift") => 0, - ("xnoise", "autoregression_target") => 0, - ("xnoise", "autoregression_strength") => 0, - ("xnoise", "initial_mean") => 0, - ("xnoise", "initial_precision") => 1, - - ("xnoise_vol", "volatility") => -2, - ("xnoise_vol", "drift") => 0, - ("xnoise_vol", "autoregression_target") => 0, - ("xnoise_vol", "autoregression_strength") => 0, - ("xnoise_vol", "initial_mean") => 0, - ("xnoise_vol", "initial_precision") => 1, - - ("u", "x", "value_coupling") => 1, - ("u", "xnoise", "volatility_coupling") => 1, - ("x", "xvol", "volatility_coupling") => 1, - ("xnoise", "xnoise_vol", "volatility_coupling") => 1, - - "update_type" => EnhancedUpdate(), - ) - - #Warn the user about used defaults and misspecified keys - if verbose - 
warn_premade_defaults(spec_defaults, config) - end - - #Merge to overwrite defaults - config = merge(spec_defaults, config) - - - #List of input nodes to create - input_nodes = Dict( - "name" => "u", - "type" => "continuous", - "input_noise" => config[("u", "input_noise")], - ) - - #List of state nodes to create - state_nodes = [ - Dict( - "name" => "x", - "type" => "continuous", - "volatility" => config[("x", "volatility")], - "drift" => config[("x", "drift")], - "autoregression_target" => config[("x", "autoregression_target")], - "autoregression_strength" => config[("x", "autoregression_strength")], - "initial_mean" => config[("x", "initial_mean")], - "initial_precision" => config[("x", "initial_precision")], - ), - Dict( - "name" => "xvol", - "type" => "continuous", - "volatility" => config[("xvol", "volatility")], - "drift" => config[("xvol", "drift")], - "autoregression_target" => config[("xvol", "autoregression_target")], - "autoregression_strength" => config[("xvol", "autoregression_strength")], - "initial_mean" => config[("xvol", "initial_mean")], - "initial_precision" => config[("xvol", "initial_precision")], - ), - Dict( - "name" => "xnoise", - "type" => "continuous", - "volatility" => config[("xnoise", "volatility")], - "drift" => config[("xnoise", "drift")], - "autoregression_target" => config[("xnoise", "autoregression_target")], - "autoregression_strength" => config[("xnoise", "autoregression_strength")], - "initial_mean" => config[("xnoise", "initial_precision")], - "initial_precision" => config[("xnoise", "initial_precision")], - ), - Dict( - "name" => "xnoise_vol", - "type" => "continuous", - "volatility" => config[("xnoise_vol", "volatility")], - "drift" => config[("xnoise_vol", "drift")], - "autoregression_target" => config[("xnoise_vol", "autoregression_target")], - "autoregression_strength" => config[("xnoise_vol", "autoregression_strength")], - "initial_mean" => config[("xnoise_vol", "initial_mean")], - "initial_precision" => 
config[("xnoise_vol", "initial_precision")], - ), - ] - - #List of child-parent relations - edges = [ - Dict( - "child" => "u", - "value_parents" => ("x", config[("u", "x", "value_coupling")]), - "volatility_parents" => ("xnoise", config[("u", "xnoise", "volatility_coupling")]), - ), - Dict( - "child" => "x", - "volatility_parents" => ("xvol", config[("x", "xvol", "volatility_coupling")]), - ), - Dict( - "child" => "xnoise", - "volatility_parents" => ("xnoise_vol", config[("xnoise", "xnoise_vol", "volatility_coupling")]), - ), - ] - - #Initialize the HGF - init_hgf( - input_nodes = input_nodes, - state_nodes = state_nodes, - edges = edges, - verbose = false, - update_type = config["update_type"], - ) -end - - -""" - premade_binary_2level(config::Dict; verbose::Bool = true) - -The standard binary 2 level HGF model, which takes a binary input, and learns the probability of either outcome. -It has one binary input node u, with a binary value parent xbin, which in turn has a continuous value parent xprob. 
- -# Config defaults: - - ("u", "category_means"): [0, 1] - - ("u", "input_precision"): Inf - - ("xprob", "volatility"): -2 - - ("xbin", "xprob", "value_coupling"): 1 - - ("xprob", "initial_mean"): 0 - - ("xprob", "initial_precision"): 1 -""" -function premade_binary_2level(config::Dict; verbose::Bool = true) - - #Defaults - spec_defaults = Dict( - ("u", "category_means") => [0, 1], - ("u", "input_precision") => Inf, - - ("xprob", "volatility") => -2, - ("xprob", "drift") => 0, - ("xprob", "autoregression_target") => 0, - ("xprob", "autoregression_strength") => 0, - ("xprob", "initial_mean") => 0, - ("xprob", "initial_precision") => 1, - - ("xbin", "xprob", "value_coupling") => 1, - - "update_type" => EnhancedUpdate(), - ) - - #Warn the user about used defaults and misspecified keys - if verbose - warn_premade_defaults(spec_defaults, config) - end - - #Merge to overwrite defaults - config = merge(spec_defaults, config) - - - #List of input nodes to create - input_nodes = Dict( - "name" => "u", - "type" => "binary", - "category_means" => config[("u", "category_means")], - "input_precision" => config[("u", "input_precision")], - ) - - #List of state nodes to create - state_nodes = [ - Dict("name" => "xbin", "type" => "binary"), - Dict( - "name" => "xprob", - "type" => "continuous", - "volatility" => config[("xprob", "volatility")], - "drift" => config[("xprob", "drift")], - "autoregression_target" => config[("xprob", "autoregression_target")], - "autoregression_strength" => config[("xprob", "autoregression_strength")], - "initial_mean" => config[("xprob", "initial_mean")], - "initial_precision" => config[("xprob", "initial_precision")], - ), - ] - - #List of child-parent relations - edges = [ - Dict("child" => "u", "value_parents" => "xbin"), - Dict( - "child" => "xbin", - "value_parents" => ("xprob", config[("xbin", "xprob", "value_coupling")]), - ), - ] - - #Initialize the HGF - init_hgf( - input_nodes = input_nodes, - state_nodes = state_nodes, - edges = edges, - 
verbose = false, - update_type = config["update_type"], - ) -end - - -""" - premade_binary_3level(config::Dict; verbose::Bool = true) - -The standard binary 3 level HGF model, which takes a binary input, and learns the probability of either outcome. -It has one binary input node u, with a binary value parent xbin, which in turn has a continuous value parent xprob. This then has a continunous volatility parent xvol. - -This HGF has five shared parameters: -"xprob_volatility" -"xprob_initial_precisions" -"xprob_initial_means" -"value_couplings_xbin_xprob" -"volatility_couplings_xprob_xvol" - -# Config defaults: - - ("u", "category_means"): [0, 1] - - ("u", "input_precision"): Inf - - ("xprob", "volatility"): -2 - - ("xvol", "volatility"): -2 - - ("xbin", "xprob", "value_coupling"): 1 - - ("xprob", "xvol", "volatility_coupling"): 1 - - ("xprob", "initial_mean"): 0 - - ("xprob", "initial_precision"): 1 - - ("xvol", "initial_mean"): 0 - - ("xvol", "initial_precision"): 1 -""" -function premade_binary_3level(config::Dict; verbose::Bool = true) - - #Defaults - defaults = Dict( - ("u", "category_means") => [0, 1], - ("u", "input_precision") => Inf, - - ("xprob", "volatility") => -2, - ("xprob", "drift") => 0, - ("xprob", "autoregression_target") => 0, - ("xprob", "autoregression_strength") => 0, - ("xprob", "initial_mean") => 0, - ("xprob", "initial_precision") => 1, - - ("xvol", "volatility") => -2, - ("xvol", "drift") => 0, - ("xvol", "autoregression_target") => 0, - ("xvol", "autoregression_strength") => 0, - ("xvol", "initial_mean") => 0, - ("xvol", "initial_precision") => 1, - - ("xbin", "xprob", "value_coupling") => 1, - ("xprob", "xvol", "volatility_coupling") => 1, - - "update_type" => EnhancedUpdate(), - ) - - #Warn the user about used defaults and misspecified keys - if verbose - warn_premade_defaults(defaults, config) - end - - #Merge to overwrite defaults - config = merge(defaults, config) - - - #List of input nodes to create - input_nodes = Dict( - "name" => 
"u", - "type" => "binary", - "category_means" => config[("u", "category_means")], - "input_precision" => config[("u", "input_precision")], - ) - - #List of state nodes to create - state_nodes = [ - Dict("name" => "xbin", "type" => "binary"), - Dict( - "name" => "xprob", - "type" => "continuous", - "volatility" => config[("xprob", "volatility")], - "drift" => config[("xprob", "drift")], - "autoregression_target" => config[("xprob", "autoregression_target")], - "autoregression_strength" => config[("xprob", "autoregression_strength")], - "initial_mean" => config[("xprob", "initial_mean")], - "initial_precision" => config[("xprob", "initial_precision")], - ), - Dict( - "name" => "xvol", - "type" => "continuous", - "volatility" => config[("xvol", "volatility")], - "drift" => config[("xvol", "drift")], - "autoregression_target" => config[("xvol", "autoregression_target")], - "autoregression_strength" => config[("xvol", "autoregression_strength")], - "initial_mean" => config[("xvol", "initial_mean")], - "initial_precision" => config[("xvol", "initial_precision")], - ), - ] - - #List of child-parent relations - edges = [ - Dict("child" => "u", "value_parents" => "xbin"), - Dict( - "child" => "xbin", - "value_parents" => ("xprob", config[("xbin", "xprob", "value_coupling")]), - ), - Dict( - "child" => "xprob", - "volatility_parents" => ("xvol", config[("xprob", "xvol", "volatility_coupling")]), - ), - ] - - #Initialize the HGF - init_hgf( - input_nodes = input_nodes, - state_nodes = state_nodes, - edges = edges, - verbose = false, - update_type = config["update_type"], - ) -end - -""" - premade_categorical_3level(config::Dict; verbose::Bool = true) - -The categorical 3 level HGF model, which takes an input from one of n categories and learns the probability of a category appearing. -It has one categorical input node u, with a categorical value parent xcat. 
-The categorical node has a binary value parent xbin_n for each category n, each of which has a continuous value parent xprob_n. -Finally, all of these continuous nodes share a continuous volatility parent xvol. -Setting parameter values for xbin and xprob sets that parameter value for each of the xbin_n and xprob_n nodes. - -# Config defaults: - - "n_categories": 4 - - ("xprob", "volatility"): -2 - - ("xvol", "volatility"): -2 - - ("xbin", "xprob", "value_coupling"): 1 - - ("xprob", "xvol", "volatility_coupling"): 1 - - ("xprob", "initial_mean"): 0 - - ("xprob", "initial_precision"): 1 - - ("xvol", "initial_mean"): 0 - - ("xvol", "initial_precision"): 1 -""" -function premade_categorical_3level(config::Dict; verbose::Bool = true) - - #Defaults - defaults = Dict( - "n_categories" => 4, - - ("xprob", "volatility") => -2, - ("xprob", "drift") => 0, - ("xprob", "autoregression_target") => 0, - ("xprob", "autoregression_strength") => 0, - ("xprob", "initial_mean") => 0, - ("xprob", "initial_precision") => 1, - - ("xvol", "volatility") => -2, - ("xvol", "drift") => 0, - ("xvol", "autoregression_target") => 0, - ("xvol", "autoregression_strength") => 0, - ("xvol", "initial_mean") => 0, - ("xvol", "initial_precision") => 1, - - ("xbin", "xprob", "value_coupling") => 1, - ("xprob", "xvol", "volatility_coupling") => 1, - - "update_type" => EnhancedUpdate(), - ) - - #Warn the user about used defaults and misspecified keys - if verbose - warn_premade_defaults(defaults, config) - end - - #Merge to overwrite defaults - config = merge(defaults, config) - - - ##Prep category node parent names - #Vector for category node binary parent names - category_binary_parent_names = Vector{String}() - #Vector for binary node continuous parent names - binary_continuous_parent_names = Vector{String}() - - #Empty lists for derived parameters - derived_parameters_xprob_initial_precision = [] - derived_parameters_xprob_initial_mean = [] - derived_parameters_xprob_volatility = [] - 
derived_parameters_xprob_drift = [] - derived_parameters_xprob_autoregression_target = [] - derived_parameters_xprob_autoregression_strength = [] - derived_parameters_xprob_xvol_volatility_coupling = [] - derived_parameters_value_coupling_xbin_xprob = [] - - #Populate the category node vectors with node names - for category_number = 1:config["n_categories"] - push!(category_binary_parent_names, "xbin_" * string(category_number)) - push!(binary_continuous_parent_names, "xprob_" * string(category_number)) - end - - ##List of input nodes - input_nodes = Dict("name" => "u", "type" => "categorical") - - ##List of state nodes - state_nodes = [Dict{String,Any}("name" => "xcat", "type" => "categorical")] - - #Add category node binary parents - for node_name in category_binary_parent_names - push!(state_nodes, Dict("name" => node_name, "type" => "binary")) - end - - #Add binary node continuous parents - for node_name in binary_continuous_parent_names - push!( - state_nodes, - Dict( - "name" => node_name, - "type" => "continuous", - "initial_mean" => config[("xprob", "initial_mean")], - "initial_precision" => config[("xprob", "initial_precision")], - "volatility" => config[("xprob", "volatility")], - "drift" => config[("xprob", "drift")], - "autoregression_target" => config[("xprob", "autoregression_target")], - "autoregression_strength" => config[("xprob", "autoregression_strength")], - ), - ) - #Add the derived parameter name to derived parameters vector - push!(derived_parameters_xprob_initial_precision, (node_name, "initial_precision")) - push!(derived_parameters_xprob_initial_mean, (node_name, "initial_mean")) - push!(derived_parameters_xprob_volatility, (node_name, "volatility")) - push!(derived_parameters_xprob_drift, (node_name, "drift")) - push!(derived_parameters_xprob_autoregression_strength, (node_name, "autoregression_strength")) - push!(derived_parameters_xprob_autoregression_target, (node_name, "autoregression_target")) - end - - #Add volatility parent - 
push!( - state_nodes, - Dict( - "name" => "xvol", - "type" => "continuous", - "volatility" => config[("xvol", "volatility")], - "drift" => config[("xvol", "drift")], - "autoregression_target" => config[("xvol", "autoregression_target")], - "autoregression_strength" => config[("xvol", "autoregression_strength")], - "initial_mean" => config[("xvol", "initial_mean")], - "initial_precision" => config[("xvol", "initial_precision")], - ), - ) - - - ##List of child-parent relations - edges = [ - Dict("child" => "u", "value_parents" => "xcat"), - Dict("child" => "xcat", "value_parents" => category_binary_parent_names), - ] - - #Add relations between binary nodes and their parents - for (child_name, parent_name) in - zip(category_binary_parent_names, binary_continuous_parent_names) - push!( - edges, - Dict( - "child" => child_name, - "value_parents" => (parent_name, config[("xbin", "xprob", "value_coupling")]), - ), - ) - #Add the derived parameter name to derived parameters vector - push!( - derived_parameters_value_coupling_xbin_xprob, - (child_name, parent_name, "value_coupling"), - ) - end - - #Add relations between binary node parents and the volatility parent - for child_name in binary_continuous_parent_names - push!( - edges, - Dict( - "child" => child_name, - "volatility_parents" => ("xvol", config[("xprob", "xvol", "volatility_coupling")]), - ), - ) - #Add the derived parameter name to derived parameters vector - push!( - derived_parameters_xprob_xvol_volatility_coupling, - (child_name, "xvol", "volatility_coupling"), - ) - end - - #Create dictionary with shared parameter information - shared_parameters = Dict() - - shared_parameters["xprob_volatility"] = - (config[("xprob", "volatility")], derived_parameters_xprob_volatility) - - shared_parameters["xprob_initial_precisions"] = - (config[("xprob", "initial_precision")], derived_parameters_xprob_initial_precision) - - shared_parameters["xprob_initial_means"] = - (config[("xprob", "initial_mean")], 
derived_parameters_xprob_initial_mean) - - shared_parameters["xprob_drifts"] = - (config[("xprob", "drift")], derived_parameters_xprob_drift) - - shared_parameters["xprob_autoregression_strengths"] = - (config[("xprob", "autoregression_strength")], derived_parameters_xprob_autoregression_strength) - - shared_parameters["xprob_autoregression_targets"] = - (config[("xprob", "autoregression_target")], derived_parameters_xprob_autoregression_target) - - shared_parameters["value_couplings_xbin_xprob"] = - (config[("xbin", "xprob", "value_coupling")], derived_parameters_value_coupling_xbin_xprob) - - shared_parameters["volatility_couplings_xprob_xvol"] = ( - config[("xprob", "xvol", "volatility_coupling")], - derived_parameters_xprob_xvol_volatility_coupling, - ) - - #Initialize the HGF - init_hgf( - input_nodes = input_nodes, - state_nodes = state_nodes, - edges = edges, - shared_parameters = shared_parameters, - verbose = false, - update_type = config["update_type"], - ) -end - -""" - premade_categorical_3level_state_transitions(config::Dict; verbose::Bool = true) - -The categorical state transition 3 level HGF model, learns state transition probabilities between a set of n categorical states. -It has one categorical input node u, with a categorical value parent xcat_n for each of the n categories, representing which category was transitioned from. -Each categorical node then has a binary parent xbin_n_m, representing the category m which the transition was towards. -Each binary node xbin_n_m has a continuous parent xprob_n_m. -Finally, all of these continuous nodes share a continuous volatility parent xvol. -Setting parameter values for xbin and xprob sets that parameter value for each of the xbin_n_m and xprob_n_m nodes. 
- -This HGF has five shared parameters: -"xprob_volatility" -"xprob_initial_precisions" -"xprob_initial_means" -"value_couplings_xbin_xprob" -"volatility_couplings_xprob_xvol" - -# Config defaults: - - "n_categories": 4 - - ("xprob", "volatility"): -2 - - ("xvol", "volatility"): -2 - - ("xbin", "xprob", "volatility_coupling"): 1 - - ("xprob", "xvol", "volatility_coupling"): 1 - - ("xprob", "initial_mean"): 0 - - ("xprob", "initial_precision"): 1 - - ("xvol", "initial_mean"): 0 - - ("xvol", "initial_precision"): 1 -""" -function premade_categorical_3level_state_transitions(config::Dict; verbose::Bool = true) - - #Defaults - defaults = Dict( - "n_categories" => 4, - - ("xprob", "volatility") => -2, - ("xprob", "drift") => 0, - ("xprob", "autoregression_target") => 0, - ("xprob", "autoregression_strength") => 0, - ("xprob", "initial_mean") => 0, - ("xprob", "initial_precision") => 1, - - ("xvol", "volatility") => -2, - ("xvol", "drift") => 0, - ("xvol", "autoregression_target") => 0, - ("xvol", "autoregression_strength") => 0, - ("xvol", "initial_mean") => 0, - ("xvol", "initial_precision") => 1, - - ("xbin", "xprob", "value_coupling") => 1, - ("xprob", "xvol", "volatility_coupling") => 1, - - "update_type" => EnhancedUpdate(), - ) - - #Warn the user about used defaults and misspecified keys - if verbose - warn_premade_defaults(defaults, config) - end - - #Merge to overwrite defaults - config = merge(defaults, config) - - - ##Prepare node names - #Empty lists for node names - categorical_input_node_names = Vector{String}() - categorical_state_node_names = Vector{String}() - categorical_node_binary_parent_names = Vector{String}() - binary_node_continuous_parent_names = Vector{String}() - - #Empty lists for derived parameters - derived_parameters_xprob_initial_precision = [] - derived_parameters_xprob_initial_mean = [] - derived_parameters_xprob_volatility = [] - derived_parameters_xprob_drift = [] - derived_parameters_xprob_autoregression_target = [] - 
derived_parameters_xprob_autoregression_strength = [] - derived_parameters_value_coupling_xbin_xprob = [] - derived_parameters_xprob_xvol_volatility_coupling = [] - - #Go through each category that the transition may have been from - for category_from = 1:config["n_categories"] - #One input node and its state node parent for each - push!(categorical_input_node_names, "u" * string(category_from)) - push!(categorical_state_node_names, "xcat_" * string(category_from)) - #Go through each category that the transition may have been to - for category_to = 1:config["n_categories"] - #Each categorical state node has a binary parent for each - push!( - categorical_node_binary_parent_names, - "xbin_" * string(category_from) * "_" * string(category_to), - ) - #And each binary parent has a continuous parent of its own - push!( - binary_node_continuous_parent_names, - "xprob_" * string(category_from) * "_" * string(category_to), - ) - end - end - - ##Create input nodes - #Initialize list - input_nodes = Vector{Dict}() - - #For each categorical input node - for node_name in categorical_input_node_names - #Add it to the list - push!(input_nodes, Dict("name" => node_name, "type" => "categorical")) - end - - ##Create state nodes - #Initialize list - state_nodes = Vector{Dict}() - - #For each cateogrical state node - for node_name in categorical_state_node_names - #Add it to the list - push!(state_nodes, Dict("name" => node_name, "type" => "categorical")) - end - - #For each categorical node binary parent - for node_name in categorical_node_binary_parent_names - #Add it to the list - push!(state_nodes, Dict("name" => node_name, "type" => "binary")) - end - - #For each binary node continuous parent - for node_name in binary_node_continuous_parent_names - #Add it to the list, with parameter settings from the config - push!( - state_nodes, - Dict( - "name" => node_name, - "type" => "continuous", - "initial_mean" => config[("xprob", "initial_mean")], - "initial_precision" => 
config[("xprob", "initial_precision")], - "volatility" => config[("xprob", "volatility")], - "drift" => config[("xprob", "drift")], - "autoregression_target" => config[("xprob", "autoregression_target")], - "autoregression_strength" => config[("xprob", "autoregression_strength")], - ), - ) - #Add the derived parameter name to derived parameters vector - push!(derived_parameters_xprob_initial_precision, (node_name, "initial_precision")) - push!(derived_parameters_xprob_initial_mean, (node_name, "initial_mean")) - push!(derived_parameters_xprob_volatility, (node_name, "volatility")) - push!(derived_parameters_xprob_drift, (node_name, "drift")) - push!(derived_parameters_xprob_autoregression_strength, (node_name, "autoregression_strength")) - push!(derived_parameters_xprob_autoregression_target, (node_name, "autoregression_target")) - end - - - #Add the shared volatility parent of the continuous nodes - push!( - state_nodes, - Dict( - "name" => "xvol", - "type" => "continuous", - "volatility" => config[("xvol", "volatility")], - "drift" => config[("xvol", "drift")], - "autoregression_target" => config[("xvol", "autoregression_target")], - "autoregression_strength" => config[("xvol", "autoregression_strength")], - "initial_mean" => config[("xvol", "initial_mean")], - "initial_precision" => config[("xvol", "initial_precision")], - ), - ) - - ##Create child-parent relations - #Initialize list - edges = Vector{Dict}() - - #For each categorical input node and its corresponding state node parent - for (child_name, parent_name) in - zip(categorical_input_node_names, categorical_state_node_names) - #Add their relation to the list - push!(edges, Dict("child" => child_name, "value_parents" => parent_name)) - end - - #For each categorical state node - for child_node_name in categorical_state_node_names - #Get the category it represents transitions from - (child_supername, child_category_from) = split(child_node_name, "_") - - #For each potential parent node - for 
parent_node_name in categorical_node_binary_parent_names - #Get the category it represents transitions from - (parent_supername, parent_category_from, parent_category_to) = - split(parent_node_name, "_") - - #If these match - if parent_category_from == child_category_from - #Add the parent as parent of the child - push!( - edges, - Dict("child" => child_node_name, "value_parents" => parent_node_name), - ) - end - end - end - - #For each binary parent of categorical nodes and their corresponding continuous parents - for (child_name, parent_name) in - zip(categorical_node_binary_parent_names, binary_node_continuous_parent_names) - #Add their relations to the list, with the same value coupling - push!( - edges, - Dict( - "child" => child_name, - "value_parents" => (parent_name, config[("xbin", "xprob", "value_coupling")]), - ), - ) - #Add the derived parameter name to derived parameters vector - push!( - derived_parameters_value_coupling_xbin_xprob, - (child_name, parent_name, "value_coupling"), - ) - end - - - #Add the shared continuous node volatility parent to the continuous nodes - for child_name in binary_node_continuous_parent_names - push!( - edges, - Dict( - "child" => child_name, - "volatility_parents" => ("xvol", config[("xprob", "xvol", "volatility_coupling")]), - ), - ) - #Add the derived parameter name to derived parameters vector - push!( - derived_parameters_xprob_xvol_volatility_coupling, - (child_name, "xvol", "volatility_coupling"), - ) - - end - - #Create dictionary with shared parameter information - - shared_parameters = Dict() - - shared_parameters["xprob_volatility"] = - (config[("xprob", "volatility")], derived_parameters_xprob_volatility) - - shared_parameters["xprob_initial_precisions"] = - (config[("xprob", "initial_precision")], derived_parameters_xprob_initial_precision) - - shared_parameters["xprob_initial_means"] = - (config[("xprob", "initial_mean")], derived_parameters_xprob_initial_mean) - - shared_parameters["xprob_drifts"] = - 
(config[("xprob", "drift")], derived_parameters_xprob_drift) - - shared_parameters["xprob_autoregression_strengths"] = - (config[("xprob", "autoregression_strength")], derived_parameters_xprob_autoregression_strength) - - shared_parameters["xprob_autoregression_targets"] = - (config[("xprob", "autoregression_target")], derived_parameters_xprob_autoregression_target) - - shared_parameters["value_couplings_xbin_xprob"] = - (config[("xbin", "xprob", "value_coupling")], derived_parameters_value_coupling_xbin_xprob) - - shared_parameters["volatility_couplings_xprob_xvol"] = ( - config[("xprob", "xvol", "volatility_coupling")], - derived_parameters_xprob_xvol_volatility_coupling, - ) - - #Initialize the HGF - init_hgf( - input_nodes = input_nodes, - state_nodes = state_nodes, - edges = edges, - shared_parameters = shared_parameters, - verbose = false, - update_type = config["update_type"], - ) -end diff --git a/src/premade_models/premade_hgfs/premade_JGET.jl b/src/premade_models/premade_hgfs/premade_JGET.jl new file mode 100644 index 0000000..1685f49 --- /dev/null +++ b/src/premade_models/premade_hgfs/premade_JGET.jl @@ -0,0 +1,138 @@ +""" +premade_JGET(config::Dict; verbose::Bool = true) + +The HGF used in the JGET model. It has a single continuous input node u, with a value parent x, and a volatility parent xnoise. x has volatility parent xvol, and xnoise has a volatility parent xnoise_vol. 
+ +# Config defaults: + - ("u", "input_noise"): -2 + - ("x", "volatility"): -2 + - ("xvol", "volatility"): -2 + - ("xnoise", "volatility"): -2 + - ("xnoise_vol", "volatility"): -2 + - ("u", "x", "coupling_strength"): 1 + - ("u", "xnoise", "coupling_strength"): 1 + - ("x", "xvol", "coupling_strength"): 1 + - ("xnoise", "xnoise_vol", "coupling_strength"): 1 + - ("x", "initial_mean"): 0 + - ("x", "initial_precision"): 1 + - ("xvol", "initial_mean"): 0 + - ("xvol", "initial_precision"): 1 + - ("xnoise", "initial_mean"): 0 + - ("xnoise", "initial_precision"): 1 + - ("xnoise_vol", "initial_mean"): 0 + - ("xnoise_vol", "initial_precision"): 1 +""" +function premade_JGET(config::Dict; verbose::Bool = true) + + #Defaults + spec_defaults = Dict( + ("u", "input_noise") => -2, + ("x", "volatility") => -2, + ("x", "drift") => 0, + ("x", "autoregression_target") => 0, + ("x", "autoregression_strength") => 0, + ("x", "initial_mean") => 0, + ("x", "initial_precision") => 1, + ("xvol", "volatility") => -2, + ("xvol", "drift") => 0, + ("xvol", "autoregression_target") => 0, + ("xvol", "autoregression_strength") => 0, + ("xvol", "initial_mean") => 0, + ("xvol", "initial_precision") => 1, + ("xnoise", "volatility") => -2, + ("xnoise", "drift") => 0, + ("xnoise", "autoregression_target") => 0, + ("xnoise", "autoregression_strength") => 0, + ("xnoise", "initial_mean") => 0, + ("xnoise", "initial_precision") => 1, + ("xnoise_vol", "volatility") => -2, + ("xnoise_vol", "drift") => 0, + ("xnoise_vol", "autoregression_target") => 0, + ("xnoise_vol", "autoregression_strength") => 0, + ("xnoise_vol", "initial_mean") => 0, + ("xnoise_vol", "initial_precision") => 1, + ("u", "xnoise", "coupling_strength") => 1, + ("x", "xvol", "coupling_strength") => 1, + ("xnoise", "xnoise_vol", "coupling_strength") => 1, + "update_type" => EnhancedUpdate(), + ) + + #Warn the user about used defaults and misspecified keys + if verbose + warn_premade_defaults(spec_defaults, config) + end + + #Merge to overwrite 
defaults
+    config = merge(spec_defaults, config)
+
+
+    #List of input nodes to create
+    input_nodes = Dict(
+        "name" => "u",
+        "type" => "continuous",
+        "input_noise" => config[("u", "input_noise")],
+    )
+
+    #List of state nodes to create
+    state_nodes = [
+        Dict(
+            "name" => "x",
+            "type" => "continuous",
+            "volatility" => config[("x", "volatility")],
+            "drift" => config[("x", "drift")],
+            "autoregression_target" => config[("x", "autoregression_target")],
+            "autoregression_strength" => config[("x", "autoregression_strength")],
+            "initial_mean" => config[("x", "initial_mean")],
+            "initial_precision" => config[("x", "initial_precision")],
+        ),
+        Dict(
+            "name" => "xvol",
+            "type" => "continuous",
+            "volatility" => config[("xvol", "volatility")],
+            "drift" => config[("xvol", "drift")],
+            "autoregression_target" => config[("xvol", "autoregression_target")],
+            "autoregression_strength" => config[("xvol", "autoregression_strength")],
+            "initial_mean" => config[("xvol", "initial_mean")],
+            "initial_precision" => config[("xvol", "initial_precision")],
+        ),
+        Dict(
+            "name" => "xnoise",
+            "type" => "continuous",
+            "volatility" => config[("xnoise", "volatility")],
+            "drift" => config[("xnoise", "drift")],
+            "autoregression_target" => config[("xnoise", "autoregression_target")],
+            "autoregression_strength" => config[("xnoise", "autoregression_strength")],
+            "initial_mean" => config[("xnoise", "initial_mean")],
+            "initial_precision" => config[("xnoise", "initial_precision")],
+        ),
+        Dict(
+            "name" => "xnoise_vol",
+            "type" => "continuous",
+            "volatility" => config[("xnoise_vol", "volatility")],
+            "drift" => config[("xnoise_vol", "drift")],
+            "autoregression_target" => config[("xnoise_vol", "autoregression_target")],
+            "autoregression_strength" =>
+                config[("xnoise_vol", "autoregression_strength")],
+            "initial_mean" => config[("xnoise_vol", "initial_mean")],
+            "initial_precision" => config[("xnoise_vol", "initial_precision")],
+        ),
+    ]
+
+    #List of child-parent relations
+ edges = Dict( + ("u", "x") => ObservationCoupling(), + ("u", "xnoise") => NoiseCoupling(config[("u", "xnoise", "coupling_strength")]), + ("x", "xvol") => VolatilityCoupling(config[("x", "xvol", "coupling_strength")]), + ("xnoise", "xnoise_vol") => + VolatilityCoupling(config[("xnoise", "xnoise_vol", "coupling_strength")]), + ) + + #Initialize the HGF + init_hgf( + input_nodes = input_nodes, + state_nodes = state_nodes, + edges = edges, + verbose = false, + update_type = config["update_type"], + ) +end diff --git a/src/premade_models/premade_hgfs/premade_binary_2level.jl b/src/premade_models/premade_hgfs/premade_binary_2level.jl new file mode 100644 index 0000000..1ac8f37 --- /dev/null +++ b/src/premade_models/premade_hgfs/premade_binary_2level.jl @@ -0,0 +1,73 @@ + +""" + premade_binary_2level(config::Dict; verbose::Bool = true) + +The standard binary 2 level HGF model, which takes a binary input, and learns the probability of either outcome. +It has one binary input node u, with a binary value parent xbin, which in turn has a continuous value parent xprob. 
+ +# Config defaults: +""" +function premade_binary_2level(config::Dict; verbose::Bool = true) + + #Defaults + spec_defaults = Dict( + ("u", "category_means") => [0, 1], + ("u", "input_precision") => Inf, + ("xprob", "volatility") => -2, + ("xprob", "drift") => 0, + ("xprob", "autoregression_target") => 0, + ("xprob", "autoregression_strength") => 0, + ("xprob", "initial_mean") => 0, + ("xprob", "initial_precision") => 1, + ("xbin", "xprob", "coupling_strength") => 1, + "update_type" => EnhancedUpdate(), + ) + + #Warn the user about used defaults and misspecified keys + if verbose + warn_premade_defaults(spec_defaults, config) + end + + #Merge to overwrite defaults + config = merge(spec_defaults, config) + + + #List of input nodes to create + input_nodes = Dict( + "name" => "u", + "type" => "binary", + "category_means" => config[("u", "category_means")], + "input_precision" => config[("u", "input_precision")], + ) + + #List of state nodes to create + state_nodes = [ + Dict("name" => "xbin", "type" => "binary"), + Dict( + "name" => "xprob", + "type" => "continuous", + "volatility" => config[("xprob", "volatility")], + "drift" => config[("xprob", "drift")], + "autoregression_target" => config[("xprob", "autoregression_target")], + "autoregression_strength" => config[("xprob", "autoregression_strength")], + "initial_mean" => config[("xprob", "initial_mean")], + "initial_precision" => config[("xprob", "initial_precision")], + ), + ] + + #List of edges + edges = Dict( + ("u", "xbin") => ObservationCoupling(), + ("xbin", "xprob") => + ProbabilityCoupling(config[("xbin", "xprob", "coupling_strength")]), + ) + + #Initialize the HGF + init_hgf( + input_nodes = input_nodes, + state_nodes = state_nodes, + edges = edges, + verbose = false, + update_type = config["update_type"], + ) +end diff --git a/src/premade_models/premade_hgfs/premade_binary_3level.jl b/src/premade_models/premade_hgfs/premade_binary_3level.jl new file mode 100644 index 0000000..90cb496 --- /dev/null +++ 
b/src/premade_models/premade_hgfs/premade_binary_3level.jl @@ -0,0 +1,109 @@ + +""" + premade_binary_3level(config::Dict; verbose::Bool = true) + +The standard binary 3 level HGF model, which takes a binary input, and learns the probability of either outcome. +It has one binary input node u, with a binary value parent xbin, which in turn has a continuous value parent xprob. This then has a continunous volatility parent xvol. + +This HGF has five shared parameters: +"xprob_volatility" +"xprob_initial_precisions" +"xprob_initial_means" +"coupling_strengths_xbin_xprob" +"coupling_strengths_xprob_xvol" + +# Config defaults: + - ("u", "category_means"): [0, 1] + - ("u", "input_precision"): Inf + - ("xprob", "volatility"): -2 + - ("xvol", "volatility"): -2 + - ("xbin", "xprob", "coupling_strength"): 1 + - ("xprob", "xvol", "coupling_strength"): 1 + - ("xprob", "initial_mean"): 0 + - ("xprob", "initial_precision"): 1 + - ("xvol", "initial_mean"): 0 + - ("xvol", "initial_precision"): 1 +""" +function premade_binary_3level(config::Dict; verbose::Bool = true) + + #Defaults + defaults = Dict( + ("u", "category_means") => [0, 1], + ("u", "input_precision") => Inf, + ("xprob", "volatility") => -2, + ("xprob", "drift") => 0, + ("xprob", "autoregression_target") => 0, + ("xprob", "autoregression_strength") => 0, + ("xprob", "initial_mean") => 0, + ("xprob", "initial_precision") => 1, + ("xvol", "volatility") => -2, + ("xvol", "drift") => 0, + ("xvol", "autoregression_target") => 0, + ("xvol", "autoregression_strength") => 0, + ("xvol", "initial_mean") => 0, + ("xvol", "initial_precision") => 1, + ("xbin", "xprob", "coupling_strength") => 1, + ("xprob", "xvol", "coupling_strength") => 1, + "update_type" => EnhancedUpdate(), + ) + + #Warn the user about used defaults and misspecified keys + if verbose + warn_premade_defaults(defaults, config) + end + + #Merge to overwrite defaults + config = merge(defaults, config) + + + #List of input nodes to create + input_nodes = Dict( + "name" 
=> "u", + "type" => "binary", + "category_means" => config[("u", "category_means")], + "input_precision" => config[("u", "input_precision")], + ) + + #List of state nodes to create + state_nodes = [ + Dict("name" => "xbin", "type" => "binary"), + Dict( + "name" => "xprob", + "type" => "continuous", + "volatility" => config[("xprob", "volatility")], + "drift" => config[("xprob", "drift")], + "autoregression_target" => config[("xprob", "autoregression_target")], + "autoregression_strength" => config[("xprob", "autoregression_strength")], + "initial_mean" => config[("xprob", "initial_mean")], + "initial_precision" => config[("xprob", "initial_precision")], + ), + Dict( + "name" => "xvol", + "type" => "continuous", + "volatility" => config[("xvol", "volatility")], + "drift" => config[("xvol", "drift")], + "autoregression_target" => config[("xvol", "autoregression_target")], + "autoregression_strength" => config[("xvol", "autoregression_strength")], + "initial_mean" => config[("xvol", "initial_mean")], + "initial_precision" => config[("xvol", "initial_precision")], + ), + ] + + #List of child-parent relations + edges = Dict( + ("u", "xbin") => ObservationCoupling(), + ("xbin", "xprob") => + ProbabilityCoupling(config[("xbin", "xprob", "coupling_strength")]), + ("xprob", "xvol") => + VolatilityCoupling(config[("xprob", "xvol", "coupling_strength")]), + ) + + #Initialize the HGF + init_hgf( + input_nodes = input_nodes, + state_nodes = state_nodes, + edges = edges, + verbose = false, + update_type = config["update_type"], + ) +end diff --git a/src/premade_models/premade_hgfs/premade_categorical_3level.jl b/src/premade_models/premade_hgfs/premade_categorical_3level.jl new file mode 100644 index 0000000..6f58919 --- /dev/null +++ b/src/premade_models/premade_hgfs/premade_categorical_3level.jl @@ -0,0 +1,207 @@ + + +""" + premade_categorical_3level(config::Dict; verbose::Bool = true) + +The categorical 3 level HGF model, which takes an input from one of n categories and 
learns the probability of a category appearing. +It has one categorical input node u, with a categorical value parent xcat. +The categorical node has a binary value parent xbin_n for each category n, each of which has a continuous value parent xprob_n. +Finally, all of these continuous nodes share a continuous volatility parent xvol. +Setting parameter values for xbin and xprob sets that parameter value for each of the xbin_n and xprob_n nodes. + +# Config defaults: + - "n_categories": 4 + - ("xprob", "volatility"): -2 + - ("xvol", "volatility"): -2 + - ("xbin", "xprob", "coupling_strength"): 1 + - ("xprob", "xvol", "coupling_strength"): 1 + - ("xprob", "initial_mean"): 0 + - ("xprob", "initial_precision"): 1 + - ("xvol", "initial_mean"): 0 + - ("xvol", "initial_precision"): 1 +""" +function premade_categorical_3level(config::Dict; verbose::Bool = true) + + #Defaults + defaults = Dict( + "n_categories" => 4, + ("xprob", "volatility") => -2, + ("xprob", "drift") => 0, + ("xprob", "autoregression_target") => 0, + ("xprob", "autoregression_strength") => 0, + ("xprob", "initial_mean") => 0, + ("xprob", "initial_precision") => 1, + ("xvol", "volatility") => -2, + ("xvol", "drift") => 0, + ("xvol", "autoregression_target") => 0, + ("xvol", "autoregression_strength") => 0, + ("xvol", "initial_mean") => 0, + ("xvol", "initial_precision") => 1, + ("xbin", "xprob", "coupling_strength") => 1, + ("xprob", "xvol", "coupling_strength") => 1, + "update_type" => EnhancedUpdate(), + ) + + #Warn the user about used defaults and misspecified keys + if verbose + warn_premade_defaults(defaults, config) + end + + #Merge to overwrite defaults + config = merge(defaults, config) + + + ##Prep category node parent names + #Vector for category node binary parent names + category_parent_names = Vector{String}() + #Vector for binary node continuous parent names + probability_parent_names = Vector{String}() + + #Empty lists for derived parameters + derived_parameters_xprob_initial_precision = [] 
+ derived_parameters_xprob_initial_mean = [] + derived_parameters_xprob_volatility = [] + derived_parameters_xprob_drift = [] + derived_parameters_xprob_autoregression_target = [] + derived_parameters_xprob_autoregression_strength = [] + derived_parameters_xbin_xprob_coupling_strength = [] + derived_parameters_xprob_xvol_coupling_strength = [] + + #Populate the category node vectors with node names + for category_number = 1:config["n_categories"] + push!(category_parent_names, "xbin_" * string(category_number)) + push!(probability_parent_names, "xprob_" * string(category_number)) + end + + ##List of input nodes + input_nodes = Dict("name" => "u", "type" => "categorical") + + ##List of state nodes + state_nodes = [Dict{String,Any}("name" => "xcat", "type" => "categorical")] + + #Add category node binary parents + for node_name in category_parent_names + push!(state_nodes, Dict("name" => node_name, "type" => "binary")) + end + + #Add binary node continuous parents + for node_name in probability_parent_names + push!( + state_nodes, + Dict( + "name" => node_name, + "type" => "continuous", + "initial_mean" => config[("xprob", "initial_mean")], + "initial_precision" => config[("xprob", "initial_precision")], + "volatility" => config[("xprob", "volatility")], + "drift" => config[("xprob", "drift")], + "autoregression_target" => config[("xprob", "autoregression_target")], + "autoregression_strength" => config[("xprob", "autoregression_strength")], + ), + ) + #Add the derived parameter name to derived parameters vector + push!(derived_parameters_xprob_initial_precision, (node_name, "initial_precision")) + push!(derived_parameters_xprob_initial_mean, (node_name, "initial_mean")) + push!(derived_parameters_xprob_volatility, (node_name, "volatility")) + push!(derived_parameters_xprob_drift, (node_name, "drift")) + push!( + derived_parameters_xprob_autoregression_strength, + (node_name, "autoregression_strength"), + ) + push!( + derived_parameters_xprob_autoregression_target, + 
(node_name, "autoregression_target"), + ) + end + + #Add volatility parent + push!( + state_nodes, + Dict( + "name" => "xvol", + "type" => "continuous", + "volatility" => config[("xvol", "volatility")], + "drift" => config[("xvol", "drift")], + "autoregression_target" => config[("xvol", "autoregression_target")], + "autoregression_strength" => config[("xvol", "autoregression_strength")], + "initial_mean" => config[("xvol", "initial_mean")], + "initial_precision" => config[("xvol", "initial_precision")], + ), + ) + + + ##List of child-parent relations + #Set the input node coupling + edges = Dict{Tuple{String,String},CouplingType}(("u", "xcat") => ObservationCoupling()) + + #For each set of categroy parents and their probability parents + for (category_parent_name, probability_parent_name) in + zip(category_parent_names, probability_parent_names) + + #Connect the binary category parents to the categorical state node + edges[("xcat", category_parent_name)] = CategoryCoupling() + + #Connect each category parent to its probability parent + edges[(category_parent_name, probability_parent_name)] = + ProbabilityCoupling(config[("xbin", "xprob", "coupling_strength")]) + + #Connect the probability parents to the shared volatility parent + edges[(probability_parent_name, "xvol")] = + VolatilityCoupling(config[("xprob", "xvol", "coupling_strength")]) + + #Add the coupling strengths to the lists of derived parameters + push!( + derived_parameters_xbin_xprob_coupling_strength, + (category_parent_name, probability_parent_name, "coupling_strength"), + ) + push!( + derived_parameters_xprob_xvol_coupling_strength, + (probability_parent_name, "xvol", "coupling_strength"), + ) + end + + #Create dictionary with shared parameter information + shared_parameters = Dict() + + shared_parameters["xprob_volatility"] = + (config[("xprob", "volatility")], derived_parameters_xprob_volatility) + + shared_parameters["xprob_initial_precision"] = + (config[("xprob", "initial_precision")], 
derived_parameters_xprob_initial_precision) + + shared_parameters["xprob_initial_mean"] = + (config[("xprob", "initial_mean")], derived_parameters_xprob_initial_mean) + + shared_parameters["xprob_drift"] = + (config[("xprob", "drift")], derived_parameters_xprob_drift) + + shared_parameters["xprob_autoregression_strength"] = ( + config[("xprob", "autoregression_strength")], + derived_parameters_xprob_autoregression_strength, + ) + + shared_parameters["xprob_autoregression_target"] = ( + config[("xprob", "autoregression_target")], + derived_parameters_xprob_autoregression_target, + ) + + shared_parameters["xbin_xprob_coupling_strength"] = ( + config[("xbin", "xprob", "coupling_strength")], + derived_parameters_xbin_xprob_coupling_strength, + ) + + shared_parameters["xprob_xvol_coupling_strength"] = ( + config[("xprob", "xvol", "coupling_strength")], + derived_parameters_xprob_xvol_coupling_strength, + ) + + #Initialize the HGF + init_hgf( + input_nodes = input_nodes, + state_nodes = state_nodes, + edges = edges, + shared_parameters = shared_parameters, + verbose = false, + update_type = config["update_type"], + ) +end diff --git a/src/premade_models/premade_hgfs/premade_categorical_transitions_3level.jl b/src/premade_models/premade_hgfs/premade_categorical_transitions_3level.jl new file mode 100644 index 0000000..b4c5788 --- /dev/null +++ b/src/premade_models/premade_hgfs/premade_categorical_transitions_3level.jl @@ -0,0 +1,272 @@ + +""" + premade_categorical_3level_state_transitions(config::Dict; verbose::Bool = true) + +The categorical state transition 3 level HGF model, learns state transition probabilities between a set of n categorical states. +It has one categorical input node u, with a categorical value parent xcat_n for each of the n categories, representing which category was transitioned from. +Each categorical node then has a binary parent xbin_n_m, representing the category m which the transition was towards. 
+Each binary node xbin_n_m has a continuous parent xprob_n_m.
+Finally, all of these continuous nodes share a continuous volatility parent xvol.
+Setting parameter values for xbin and xprob sets that parameter value for each of the xbin_n_m and xprob_n_m nodes.
+
+This HGF has shared parameters, including:
+"xprob_volatility"
+"xprob_initial_precision"
+"xprob_initial_mean"
+"xbin_xprob_coupling_strength"
+"xprob_xvol_coupling_strength"
+
+# Config defaults:
+ - "n_categories": 4
+ - ("xprob", "volatility"): -2
+ - ("xvol", "volatility"): -2
+ - ("xbin", "xprob", "coupling_strength"): 1
+ - ("xprob", "xvol", "coupling_strength"): 1
+ - ("xprob", "initial_mean"): 0
+ - ("xprob", "initial_precision"): 1
+ - ("xvol", "initial_mean"): 0
+ - ("xvol", "initial_precision"): 1
+"""
+function premade_categorical_3level_state_transitions(config::Dict; verbose::Bool = true)
+
+    #Defaults
+    defaults = Dict(
+        "n_categories" => 4,
+        ("xprob", "volatility") => -2,
+        ("xprob", "drift") => 0,
+        ("xprob", "autoregression_target") => 0,
+        ("xprob", "autoregression_strength") => 0,
+        ("xprob", "initial_mean") => 0,
+        ("xprob", "initial_precision") => 1,
+        ("xvol", "volatility") => -2,
+        ("xvol", "drift") => 0,
+        ("xvol", "autoregression_target") => 0,
+        ("xvol", "autoregression_strength") => 0,
+        ("xvol", "initial_mean") => 0,
+        ("xvol", "initial_precision") => 1,
+        ("xbin", "xprob", "coupling_strength") => 1,
+        ("xprob", "xvol", "coupling_strength") => 1,
+        "update_type" => EnhancedUpdate(),
+    )
+
+    #Warn the user about used defaults and misspecified keys
+    if verbose
+        warn_premade_defaults(defaults, config)
+    end
+
+    #Merge to overwrite defaults
+    config = merge(defaults, config)
+
+
+    ##Prepare node names
+    #Empty lists for node names
+    input_node_names = Vector{String}()
+    observation_parent_names = Vector{String}()
+    category_parent_names = Vector{String}()
+    probability_parent_names = Vector{String}()
+
+    #Empty lists for derived parameters (the per-node parameters each shared parameter controls)
+    derived_parameters_xprob_initial_precision = []
+    derived_parameters_xprob_initial_mean = []
+    derived_parameters_xprob_volatility = []
+    derived_parameters_xprob_drift = []
+    derived_parameters_xprob_autoregression_target = []
+    derived_parameters_xprob_autoregression_strength = []
+    derived_parameters_xbin_xprob_coupling_strength = []
+    derived_parameters_xprob_xvol_coupling_strength = []
+
+    #Go through each category that the transition may have been from
+    for category_from = 1:config["n_categories"]
+        #One input node and its state node parent for each
+        push!(input_node_names, "u" * string(category_from))
+        push!(observation_parent_names, "xcat_" * string(category_from))
+        #Go through each category that the transition may have been to
+        for category_to = 1:config["n_categories"]
+            #Each categorical state node has a binary parent for each
+            push!(
+                category_parent_names,
+                "xbin_" * string(category_from) * "_" * string(category_to),
+            )
+            #And each binary parent has a continuous parent of its own
+            push!(
+                probability_parent_names,
+                "xprob_" * string(category_from) * "_" * string(category_to),
+            )
+        end
+    end
+
+    ##Create input nodes
+    #Initialize list
+    input_nodes = Vector{Dict}()
+
+    #For each categorical input node
+    for node_name in input_node_names
+        #Add it to the list
+        push!(input_nodes, Dict("name" => node_name, "type" => "categorical"))
+    end
+
+    ##Create state nodes
+    #Initialize list
+    state_nodes = Vector{Dict}()
+
+    #For each categorical state node
+    for node_name in observation_parent_names
+        #Add it to the list
+        push!(state_nodes, Dict("name" => node_name, "type" => "categorical"))
+    end
+
+    #For each categorical node binary parent
+    for node_name in category_parent_names
+        #Add it to the list
+        push!(state_nodes, Dict("name" => node_name, "type" => "binary"))
+    end
+
+    #For each binary node continuous parent
+    for node_name in probability_parent_names
+        #Add it to the list, with parameter settings from the config
+        push!(
+            state_nodes,
+            Dict(
+                "name" => node_name,
+                "type" => "continuous",
+                "initial_mean" => config[("xprob", "initial_mean")],
+                "initial_precision" => config[("xprob", "initial_precision")],
+                "volatility" => config[("xprob", "volatility")],
+                "drift" => config[("xprob", "drift")],
+                "autoregression_target" => config[("xprob", "autoregression_target")],
+                "autoregression_strength" => config[("xprob", "autoregression_strength")],
+            ),
+        )
+        #Add the derived parameter name to derived parameters vector
+        push!(derived_parameters_xprob_initial_precision, (node_name, "initial_precision"))
+        push!(derived_parameters_xprob_initial_mean, (node_name, "initial_mean"))
+        push!(derived_parameters_xprob_volatility, (node_name, "volatility"))
+        push!(derived_parameters_xprob_drift, (node_name, "drift"))
+        push!(
+            derived_parameters_xprob_autoregression_strength,
+            (node_name, "autoregression_strength"),
+        )
+        push!(
+            derived_parameters_xprob_autoregression_target,
+            (node_name, "autoregression_target"),
+        )
+    end
+
+
+    #Add the shared volatility parent of the continuous nodes
+    push!(
+        state_nodes,
+        Dict(
+            "name" => "xvol",
+            "type" => "continuous",
+            "volatility" => config[("xvol", "volatility")],
+            "drift" => config[("xvol", "drift")],
+            "autoregression_target" => config[("xvol", "autoregression_target")],
+            "autoregression_strength" => config[("xvol", "autoregression_strength")],
+            "initial_mean" => config[("xvol", "initial_mean")],
+            "initial_precision" => config[("xvol", "initial_precision")],
+        ),
+    )
+
+    ##Create edges
+    #Initialize list
+    edges = Dict{Tuple{String,String},CouplingType}()
+
+    #For each categorical input node and its corresponding state node parent
+    for (input_node_name, observation_parent_name) in
+        zip(input_node_names, observation_parent_names)
+
+        #Add their connection
+        edges[(input_node_name, observation_parent_name)] = ObservationCoupling()
+    end
+
+    #For each categorical state node
+    for observation_parent_name in observation_parent_names
+        #Get the category it represents transitions from
+        (observation_parent_prefix, child_category_from) =
+            split(observation_parent_name, "_")
+
+        #For each potential parent node
+        for category_parent_name in category_parent_names
+            #Get the category it represents transitions to and from
+            (category_parent_prefix, parent_category_from, parent_category_to) =
+                split(category_parent_name, "_")
+
+            #If these match
+            if parent_category_from == child_category_from
+                #Add their connection
+                edges[(observation_parent_name, category_parent_name)] = CategoryCoupling()
+            end
+        end
+    end
+
+    #For each set of category parent and probability parent
+    for (category_parent_name, probability_parent_name) in
+        zip(category_parent_names, probability_parent_names)
+
+        #Connect them
+        edges[(category_parent_name, probability_parent_name)] =
+            ProbabilityCoupling(config[("xbin", "xprob", "coupling_strength")])
+
+        #Connect the probability parents to the shared volatility parent
+        edges[(probability_parent_name, "xvol")] =
+            VolatilityCoupling(config[("xprob", "xvol", "coupling_strength")])
+
+
+        #Add the parameters as derived parameters for shared parameters
+        push!(
+            derived_parameters_xbin_xprob_coupling_strength,
+            (category_parent_name, probability_parent_name, "coupling_strength"),
+        )
+        push!(
+            derived_parameters_xprob_xvol_coupling_strength,
+            (probability_parent_name, "xvol", "coupling_strength"),
+        )
+    end
+
+    #Create dictionary with shared parameter information
+
+    shared_parameters = Dict()
+
+    shared_parameters["xprob_volatility"] =
+        (config[("xprob", "volatility")], derived_parameters_xprob_volatility)
+
+    shared_parameters["xprob_initial_precision"] =
+        (config[("xprob", "initial_precision")], derived_parameters_xprob_initial_precision)
+
+    shared_parameters["xprob_initial_mean"] =
+        (config[("xprob", "initial_mean")], derived_parameters_xprob_initial_mean)
+
+    shared_parameters["xprob_drift"] =
+        (config[("xprob", "drift")], derived_parameters_xprob_drift)
+
+    shared_parameters["xprob_autoregression_strength"] = (
+        config[("xprob", "autoregression_strength")],
+        derived_parameters_xprob_autoregression_strength,
+    )
+
+    shared_parameters["xprob_autoregression_target"] = (
+        config[("xprob", "autoregression_target")],
+        derived_parameters_xprob_autoregression_target,
+    )
+
+    shared_parameters["xbin_xprob_coupling_strength"] = (
+        config[("xbin", "xprob", "coupling_strength")],
+        derived_parameters_xbin_xprob_coupling_strength,
+    )
+
+    shared_parameters["xprob_xvol_coupling_strength"] = (
+        config[("xprob", "xvol", "coupling_strength")],
+        derived_parameters_xprob_xvol_coupling_strength,
+    )
+
+    #Initialize the HGF
+    init_hgf(
+        input_nodes = input_nodes,
+        state_nodes = state_nodes,
+        edges = edges,
+        shared_parameters = shared_parameters,
+        verbose = false,
+        update_type = config["update_type"],
+    )
+end
diff --git a/src/premade_models/premade_hgfs/premade_continuous_2level.jl b/src/premade_models/premade_hgfs/premade_continuous_2level.jl
new file mode 100644
index 0000000..389c93c
--- /dev/null
+++ b/src/premade_models/premade_hgfs/premade_continuous_2level.jl
@@ -0,0 +1,93 @@
+"""
+    premade_continuous_2level(config::Dict; verbose::Bool = true)
+
+The standard 2 level continuous HGF, which filters a continuous input.
+It has a continuous input node u, with a single value parent x, which in turn has a single volatility parent xvol.
+
+# Config defaults:
+ - ("u", "input_noise"): -2
+ - ("x", "volatility"): -2
+ - ("xvol", "volatility"): -2
+ - ("x", "drift"): 0
+ - ("x", "xvol", "coupling_strength"): 1
+ - ("x", "initial_mean"): 0
+ - ("x", "initial_precision"): 1
+ - ("xvol", "initial_mean"): 0
+ - ("xvol", "initial_precision"): 1
+"""
+function premade_continuous_2level(config::Dict; verbose::Bool = true)
+
+    #Defaults
+    spec_defaults = Dict(
+        ("u", "input_noise") => -2,
+        ("x", "volatility") => -2,
+        ("x", "drift") => 0,
+        ("x", "autoregression_target") => 0,
+        ("x", "autoregression_strength") => 0,
+        ("x", "initial_mean") => 0,
+        ("x", "initial_precision") => 1,
+        ("xvol", "volatility") => -2,
+        ("xvol", "drift") => 0,
+        ("xvol", "autoregression_target") => 0,
+        ("xvol", "autoregression_strength") => 0,
+        ("xvol", "initial_mean") => 0,
+        ("xvol", "initial_precision") => 1,
+        ("x", "xvol", "coupling_strength") => 1,
+        "update_type" => EnhancedUpdate(),
+    )
+
+    #Warn the user about used defaults and misspecified keys
+    if verbose
+        warn_premade_defaults(spec_defaults, config)
+    end
+
+    #Merge to overwrite defaults
+    config = merge(spec_defaults, config)
+
+
+    #List of input nodes to create
+    input_nodes = Dict(
+        "name" => "u",
+        "type" => "continuous",
+        "input_noise" => config[("u", "input_noise")],
+    )
+
+    #List of state nodes to create
+    state_nodes = [
+        Dict(
+            "name" => "x",
+            "type" => "continuous",
+            "volatility" => config[("x", "volatility")],
+            "drift" => config[("x", "drift")],
+            "autoregression_target" => config[("x", "autoregression_target")],
+            "autoregression_strength" => config[("x", "autoregression_strength")],
+            "initial_mean" => config[("x", "initial_mean")],
+            "initial_precision" => config[("x", "initial_precision")],
+        ),
+        Dict(
+            "name" => "xvol",
+            "type" => "continuous",
+            "volatility" => config[("xvol", "volatility")],
+            "drift" => config[("xvol", "drift")],
+            "autoregression_target" => config[("xvol", "autoregression_target")],
+            "autoregression_strength" => config[("xvol", "autoregression_strength")],
+            "initial_mean" => config[("xvol", "initial_mean")],
+            "initial_precision" => config[("xvol", "initial_precision")],
+        ),
+    ]
+
+    #List of child-parent relations
+    edges = Dict(
+        ("u", "x") => ObservationCoupling(),
+        ("x", "xvol") => VolatilityCoupling(config[("x", "xvol", "coupling_strength")]),
+    )
+
+    #Initialize the HGF
+    init_hgf(
+        input_nodes = input_nodes,
+        state_nodes = state_nodes,
+        edges = edges,
+        verbose = false,
+        update_type = config["update_type"],
+    )
+end
diff --git a/src/update_hgf/node_updates/binary_input_node.jl b/src/update_hgf/node_updates/binary_input_node.jl
new file mode 100644
index 0000000..2c8ea7c
--- /dev/null
+++ b/src/update_hgf/node_updates/binary_input_node.jl
@@ -0,0 +1,68 @@
+###################################
+######## Update prediction ########
+###################################
+
+##### Superfunction #####
+"""
+    update_node_prediction!(node::BinaryInputNode)
+
+There is no prediction update for binary input nodes, as the prediction precision is constant.
+"""
+function update_node_prediction!(node::BinaryInputNode)
+    return nothing
+end
+
+
+###############################################
+######## Update value prediction error ########
+###############################################
+
+##### Superfunction #####
+"""
+    update_node_value_prediction_error!(node::BinaryInputNode)
+
+Update the value prediction error of a single binary input node.
+"""
+function update_node_value_prediction_error!(node::BinaryInputNode)
+
+    #Calculate value prediction error
+    node.states.value_prediction_error = calculate_value_prediction_error(node)
+    push!(node.history.value_prediction_error, node.states.value_prediction_error)
+
+    return nothing
+end
+
+@doc raw"""
+    calculate_value_prediction_error(node::BinaryInputNode)
+
+Calculates the prediction error of a binary input node with finite precision.
+
+Uses the equation
+`` \delta_n= u - \sum_{j=1}^{j\;value\;parents} \hat{\mu}_{j} ``
+"""
+function calculate_value_prediction_error(node::BinaryInputNode)
+
+    #For missing input
+    if ismissing(node.states.input_value)
+        #Set the prediction error to missing
+        value_prediction_error = [missing, missing]
+    else
+        #Subtract to find the difference to each of the Gaussian means
+        value_prediction_error = node.parameters.category_means .- node.states.input_value
+    end
+end
+
+
+###################################################
+######## Update precision prediction error ########
+###################################################
+
+##### Superfunction #####
+"""
+    update_node_precision_prediction_error!(node::BinaryInputNode)
+
+There is no volatility prediction error update for binary input nodes.
+"""
+function update_node_precision_prediction_error!(node::BinaryInputNode)
+    return nothing
+end
diff --git a/src/update_hgf/node_updates/binary_state_node.jl b/src/update_hgf/node_updates/binary_state_node.jl
new file mode 100644
index 0000000..95fea07
--- /dev/null
+++ b/src/update_hgf/node_updates/binary_state_node.jl
@@ -0,0 +1,236 @@
+###################################
+######## Update prediction ########
+###################################
+
+##### Superfunction #####
+"""
+    update_node_prediction!(node::BinaryStateNode)
+
+Update the prediction of a single binary state node.
+"""
+function update_node_prediction!(node::BinaryStateNode)
+
+    #Update prediction mean
+    node.states.prediction_mean = calculate_prediction_mean(node)
+    push!(node.history.prediction_mean, node.states.prediction_mean)
+
+    #Update prediction precision
+    node.states.prediction_precision = calculate_prediction_precision(node)
+    push!(node.history.prediction_precision, node.states.prediction_precision)
+
+    return nothing
+end
+
+##### Mean update #####
+@doc raw"""
+    calculate_prediction_mean(node::BinaryStateNode)
+
+Calculates a binary state node's prediction mean.
+
+Uses the equation
+`` \hat{\mu}_n= \big(1+e^{\sum_{j=1}^{j\;value \; parents} \hat{\mu}_{j}}\big)^{-1} ``
+"""
+function calculate_prediction_mean(node::BinaryStateNode)
+    probability_parents = node.edges.probability_parents
+
+    prediction_mean = 0
+
+    for parent in probability_parents
+        prediction_mean +=
+            parent.states.prediction_mean * node.parameters.coupling_strengths[parent.name]
+    end
+
+    prediction_mean = 1 / (1 + exp(-prediction_mean))
+
+    return prediction_mean
+end
+
+##### Precision update #####
+@doc raw"""
+    calculate_prediction_precision(node::BinaryStateNode)
+
+Calculates a binary state node's prediction precision.
+
+Uses the equation
+`` \hat{\pi}_n = \frac{1}{\hat{\mu}_n \cdot (1-\hat{\mu}_n)} ``
+"""
+function calculate_prediction_precision(node::BinaryStateNode)
+    1 / (node.states.prediction_mean * (1 - node.states.prediction_mean))
+end
+
+##################################
+######## Update posterior ########
+##################################
+
+##### Superfunction #####
+"""
+    update_node_posterior!(node::BinaryStateNode, update_type::HGFUpdateType)
+
+Update the posterior of a single binary state node. This is the classic HGF update.
+"""
+function update_node_posterior!(node::BinaryStateNode, update_type::HGFUpdateType)
+    #Update posterior precision
+    node.states.posterior_precision = calculate_posterior_precision(node)
+    push!(node.history.posterior_precision, node.states.posterior_precision)
+
+    #Update posterior mean
+    node.states.posterior_mean = calculate_posterior_mean(node, update_type)
+    push!(node.history.posterior_mean, node.states.posterior_mean)
+
+    return nothing
+end
+
+##### Precision update #####
+@doc raw"""
+    calculate_posterior_precision(node::BinaryStateNode)
+
+Calculates a binary node's posterior precision.
+
+Uses the equations
+
+`` \pi_n = inf ``
+if the precision is infinite
+
+ `` \pi_n = \frac{1}{\hat{\mu}_n \cdot (1-\hat{\mu}_n)} ``
+ if the precision is other than infinite
+"""
+function calculate_posterior_precision(node::BinaryStateNode)
+
+    ## If the child is an observation child ##
+    if length(node.edges.observation_children) > 0
+
+        #Extract the observation child
+        child = node.edges.observation_children[1]
+
+        #Simple update with infinite precision
+        if child.parameters.input_precision == Inf
+            posterior_precision = Inf
+            #Update with finite precision
+        else
+            posterior_precision =
+                1 / (node.states.posterior_mean * (1 - node.states.posterior_mean))
+        end
+
+        ## If the child is a category child ##
+    elseif length(node.edges.category_children) > 0
+
+        posterior_precision = Inf
+
+    else
+        @error "the binary state node $(node.name) has neither category nor observation children"
+    end
+
+    return posterior_precision
+end
+
+##### Mean update #####
+@doc raw"""
+    calculate_posterior_mean(node::BinaryStateNode)
+
+Calculates a node's posterior mean.
+
+Uses the equation
+`` \mu = \frac{e^{-0.5 \cdot \pi_n \cdot \eta_1^2}}{\hat{\mu}_n \cdot e^{-0.5 \cdot \pi_n \cdot \eta_1^2} \; + 1-\hat{\mu}_n \cdot e^{-0.5 \cdot \pi_n \cdot \eta_2^2}} ``
+"""
+function calculate_posterior_mean(node::BinaryStateNode, update_type::HGFUpdateType)
+
+    ## If the child is an observation child ##
+    if length(node.edges.observation_children) > 0
+
+        #Extract the child
+        child = node.edges.observation_children[1]
+
+        #For missing inputs
+        if ismissing(child.states.input_value)
+            #Set the posterior to missing
+            posterior_mean = missing
+        else
+            #Update with infinite input precision
+            if child.parameters.input_precision == Inf
+                posterior_mean = child.states.input_value
+
+                #Update with finite input precision
+            else
+                posterior_mean =
+                    node.states.prediction_mean * exp(
+                        -0.5 *
+                        node.states.prediction_precision *
+                        child.parameters.category_means[1]^2,
+                    ) / (
+                        node.states.prediction_mean * exp(
+                            -0.5 *
+                            node.states.prediction_precision *
+                            child.parameters.category_means[1]^2,
+                        ) +
+                        (1 - node.states.prediction_mean) * exp(
+                            -0.5 *
+                            node.states.prediction_precision *
+                            child.parameters.category_means[2]^2,
+                        )
+                    )
+            end
+        end
+
+        ## If the child is a category child ##
+    elseif length(node.edges.category_children) > 0
+
+        #Extract the child
+        child = node.edges.category_children[1]
+
+        #Find the nodes' own category number
+        category_number = findfirst(child.edges.category_parent_order .== node.name)
+
+        #Find the corresponding value in the child
+        posterior_mean = child.states.posterior[category_number]
+
+    else
+        @error "the binary state node $(node.name) has neither category nor observation children"
+    end
+
+    return posterior_mean
+end
+
+
+###############################################
+######## Update value prediction error ########
+###############################################
+
+##### Superfunction #####
+"""
+    update_node_value_prediction_error!(node::BinaryStateNode)
+
+Update the value prediction error of a single binary state node.
+"""
+function update_node_value_prediction_error!(node::BinaryStateNode)
+    #Update value prediction error
+    node.states.value_prediction_error = calculate_value_prediction_error(node)
+    push!(node.history.value_prediction_error, node.states.value_prediction_error)
+
+    return nothing
+end
+
+@doc raw"""
+    calculate_value_prediction_error(node::BinaryStateNode)
+
+Calculates a state node's value prediction error.
+
+Uses the equation
+`` \delta_n = \mu_n - \hat{\mu}_n ``
+"""
+function calculate_value_prediction_error(node::BinaryStateNode)
+    node.states.posterior_mean - node.states.prediction_mean
+end
+
+###################################################
+######## Update precision prediction error ########
+###################################################
+
+##### Superfunction #####
+"""
+    update_node_precision_prediction_error!(node::BinaryStateNode)
+
+There is no volatility prediction error update for binary state nodes.
+"""
+function update_node_precision_prediction_error!(node::BinaryStateNode)
+    return nothing
+end
diff --git a/src/update_hgf/node_updates/categorical_input_node.jl b/src/update_hgf/node_updates/categorical_input_node.jl
new file mode 100644
index 0000000..80fa786
--- /dev/null
+++ b/src/update_hgf/node_updates/categorical_input_node.jl
@@ -0,0 +1,43 @@
+###################################
+######## Update prediction ########
+###################################
+
+##### Superfunction #####
+"""
+    update_node_prediction!(node::CategoricalInputNode)
+
+There is no prediction update for categorical input nodes, as the prediction precision is constant.
+"""
+function update_node_prediction!(node::CategoricalInputNode)
+    return nothing
+end
+
+
+###############################################
+######## Update value prediction error ########
+###############################################
+
+##### Superfunction #####
+"""
+    update_node_value_prediction_error!(node::CategoricalInputNode)
+
+ There is no value prediction error update for categorical input nodes.
+"""
+function update_node_value_prediction_error!(node::CategoricalInputNode)
+    return nothing
+end
+
+
+###################################################
+######## Update precision prediction error ########
+###################################################
+
+##### Superfunction #####
+"""
+    update_node_precision_prediction_error!(node::CategoricalInputNode)
+
+There is no volatility prediction error update for categorical input nodes.
+"""
+function update_node_precision_prediction_error!(node::CategoricalInputNode)
+    return nothing
+end
diff --git a/src/update_hgf/node_updates/categorical_state_node.jl b/src/update_hgf/node_updates/categorical_state_node.jl
new file mode 100644
index 0000000..44b6540
--- /dev/null
+++ b/src/update_hgf/node_updates/categorical_state_node.jl
@@ -0,0 +1,167 @@
+###################################
+######## Update prediction ########
+###################################
+
+##### Superfunction #####
+"""
+    update_node_prediction!(node::CategoricalStateNode)
+
+Update the prediction of a single categorical state node.
+"""
+function update_node_prediction!(node::CategoricalStateNode)
+
+    #Update prediction mean
+    node.states.prediction, node.states.parent_predictions = calculate_prediction(node)
+    push!(node.history.prediction, node.states.prediction)
+    push!(node.history.parent_predictions, node.states.parent_predictions)
+    return nothing
+end
+
+
+@doc raw"""
+    function calculate_prediction(node::CategoricalStateNode)
+
+Calculate the prediction for a categorical state node.
+
+Uses the equation
+`` \vec{\hat{\mu}_n}= \frac{\hat{\mu}_j}{\sum_{j=1}^{j\;binary \;parents} \hat{\mu}_j} ``
+"""
+function calculate_prediction(node::CategoricalStateNode)
+
+    #Get parent posteriors
+    parent_posteriors =
+        map(x -> x.states.posterior_mean, collect(values(node.edges.category_parents)))
+
+    #Get current parent predictions
+    parent_predictions =
+        map(x -> x.states.prediction_mean, collect(values(node.edges.category_parents)))
+
+    #Get previous parent predictions
+    previous_parent_predictions = node.states.parent_predictions
+
+    #If there was an observation
+    if any(!ismissing, node.states.posterior)
+
+        #Calculate implied learning rate
+        implied_learning_rate =
+            (
+                (parent_posteriors .- previous_parent_predictions) ./
+                (parent_predictions .- previous_parent_predictions)
+            ) .- 1
+
+        # calculate the prediction mean
+        prediction =
+            ((implied_learning_rate .* parent_predictions) .+ 1) ./
+            sum(implied_learning_rate .* parent_predictions .+ 1)
+
+        #If there was no observation
+    else
+        #Extract prediction from last timestep
+        prediction = node.states.prediction
+    end
+
+    return prediction, parent_predictions
+
+end
+
+
+##################################
+######## Update posterior ########
+##################################
+
+##### Superfunction #####
+"""
+    update_node_posterior!(node::CategoricalStateNode)
+
+Update the posterior of a single categorical state node.
+"""
+function update_node_posterior!(node::CategoricalStateNode, update_type::ClassicUpdate)
+
+    #Update posterior mean
+    node.states.posterior = calculate_posterior(node)
+    push!(node.history.posterior, node.states.posterior)
+
+    return nothing
+end
+
+
+@doc raw"""
+    calculate_posterior(node::CategoricalStateNode)
+
+Calculate the posterior for a categorical state node.
+
+One hot encoding
+`` \vec{u} = [0, 0, \dots ,1, \dots,0] ``
+"""
+function calculate_posterior(node::CategoricalStateNode)
+
+    #Extract the child
+    child = node.edges.observation_children[1]
+
+    #Initialize posterior as previous posterior
+    posterior = node.states.posterior
+
+    #For missing inputs
+    if ismissing(child.states.input_value)
+        #Set the posterior to be all missing
+        posterior .= missing
+
+    else
+        #Set all values to 0
+        posterior .= zero(Real)
+
+        #Set the posterior for the observed category to 1
+        posterior[child.states.input_value] = 1
+    end
+    return posterior
+end
+
+
+###############################################
+######## Update value prediction error ########
+###############################################
+
+##### Superfunction #####
+"""
+    update_node_value_prediction_error!(node::CategoricalStateNode)
+
+Update the value prediction error of a single categorical state node.
+"""
+function update_node_value_prediction_error!(node::CategoricalStateNode)
+    #Update value prediction error
+    node.states.value_prediction_error = calculate_value_prediction_error(node)
+    push!(node.history.value_prediction_error, node.states.value_prediction_error)
+
+    return nothing
+end
+
+@doc raw"""
+    calculate_value_prediction_error(node::CategoricalStateNode)
+
+Calculate the value prediction error for a categorical state node.
+
+Uses the equation
+`` \delta_n= u - \sum_{j=1}^{j\;value\;parents} \hat{\mu}_{j} ``
+"""
+function calculate_value_prediction_error(node::CategoricalStateNode)
+
+    #Get the prediction error for each category
+    value_prediction_error = node.states.posterior - node.states.prediction
+
+    return value_prediction_error
+end
+
+
+###################################################
+######## Update precision prediction error ########
+###################################################
+
+##### Superfunction #####
+"""
+    update_node_precision_prediction_error!(node::CategoricalStateNode)
+
+There is no volatility prediction error update for categorical state nodes.
+"""
+function update_node_precision_prediction_error!(node::CategoricalStateNode)
+    return nothing
+end
diff --git a/src/update_hgf/node_updates/continuous_input_node.jl b/src/update_hgf/node_updates/continuous_input_node.jl
new file mode 100644
index 0000000..8c4d261
--- /dev/null
+++ b/src/update_hgf/node_updates/continuous_input_node.jl
@@ -0,0 +1,178 @@
+###################################
+######## Update prediction ########
+###################################
+
+##### Superfunction #####
+"""
+    update_node_prediction!(node::ContinuousInputNode)
+
+Update the prediction of a single continuous input node.
+"""
+function update_node_prediction!(node::ContinuousInputNode)
+
+    #Update node prediction mean
+    node.states.prediction_mean = calculate_prediction_mean(node)
+    push!(node.history.prediction_mean, node.states.prediction_mean)
+
+    #Update prediction precision
+    node.states.prediction_precision = calculate_prediction_precision(node)
+    push!(node.history.prediction_precision, node.states.prediction_precision)
+
+    return nothing
+end
+
+
+##### Mean update #####
+function calculate_prediction_mean(node::ContinuousInputNode)
+    #Extract parents
+    observation_parents = node.edges.observation_parents
+
+    #Initialize prediction at 0
+    prediction_mean = 0
+
+    #Sum the predictions of the parents
+    for parent in observation_parents
+        prediction_mean += parent.states.prediction_mean
+    end
+
+    return prediction_mean
+end
+
+
+##### Precision update #####
+@doc raw"""
+    calculate_prediction_precision(node::ContinuousInputNode)
+
+Calculates an input node's prediction precision.
+
+Uses the equation
+`` \hat{\pi}_n = \frac{1}{\nu}_n ``
+"""
+function calculate_prediction_precision(node::ContinuousInputNode)
+
+    #Extract noise parents
+    noise_parents = node.edges.noise_parents
+
+    #Initialize noise from input noise parameter
+    predicted_noise = node.parameters.input_noise
+
+    #Go through each noise parent
+    for parent in noise_parents
+        #Add its mean to the predicted noise
+        predicted_noise +=
+            parent.states.posterior_mean * node.parameters.coupling_strengths[parent.name]
+    end
+
+    #The prediction precision is the inverse of the predicted noise
+    prediction_precision = 1 / exp(predicted_noise)
+
+    return prediction_precision
+end
+
+
+
+###############################################
+######## Update value prediction error ########
+###############################################
+
+##### Superfunction #####
+"""
+    update_node_value_prediction_error!(node::ContinuousInputNode)
+
+Update the value prediction error of a single input node.
+"""
+function update_node_value_prediction_error!(node::ContinuousInputNode)
+
+    #Calculate value prediction error
+    node.states.value_prediction_error = calculate_value_prediction_error(node)
+    push!(node.history.value_prediction_error, node.states.value_prediction_error)
+
+    return nothing
+end
+
+
+
+@doc raw"""
+    calculate_value_prediction_error(node::ContinuousInputNode)
+
+Calculates an input node's value prediction error.
+
+Uses the equation
+``\delta_n= u - \sum_{j=1}^{j\;value\;parents} \hat{\mu}_{j} ``
+"""
+function calculate_value_prediction_error(node::ContinuousInputNode)
+    #For missing input
+    if ismissing(node.states.input_value)
+        #Set the prediction error to missing
+        value_prediction_error = missing
+    else
+        #Get value prediction error between the prediction and the input value
+        value_prediction_error = node.states.input_value - node.states.prediction_mean
+    end
+
+    return value_prediction_error
+end
+
+
+###################################################
+######## Update precision prediction error ########
+###################################################
+
+##### Superfunction #####
+"""
+    update_node_precision_prediction_error!(node::ContinuousInputNode)
+
+Update the precision prediction error of a single input node.
+"""
+function update_node_precision_prediction_error!(node::ContinuousInputNode)
+
+    #Calculate noise prediction error, only if there are noise parents
+    if length(node.edges.noise_parents) > 0
+        node.states.precision_prediction_error = calculate_precision_prediction_error(node)
+        push!(
+            node.history.precision_prediction_error,
+            node.states.precision_prediction_error,
+        )
+    end
+
+    return nothing
+end
+
+@doc raw"""
+    calculate_precision_prediction_error(node::ContinuousInputNode)
+
+Calculates an input node's volatility prediction error.
+
+Uses the equation
+`` \mu'_j=\sum_{j=1}^{j\;value\;parents} \mu_{j} ``
+`` \pi'_j=\frac{{\sum_{j=1}^{j\;value\;parents} \pi_{j}}}{j} ``
+`` \Delta_n=\frac{\hat{\pi}_n}{\pi'_j} + \hat{\mu}_i\cdot (u -\mu'_j^2 )-1 ``
+"""
+function calculate_precision_prediction_error(node::ContinuousInputNode)
+
+    #For missing input
+    if ismissing(node.states.input_value)
+        #Set the prediction error to missing
+        precision_prediction_error = missing
+    else
+        #Extract parents
+        observation_parents = node.edges.observation_parents
+
+        #Average the posterior precision of the observation parents
+        parents_average_posterior_precision = 0
+
+        for parent in observation_parents
+            parents_average_posterior_precision += parent.states.posterior_precision
+        end
+
+        parents_average_posterior_precision =
+            parents_average_posterior_precision / length(observation_parents)
+
+        #Get the noise prediction error using the average parent parents_posterior_precision
+        precision_prediction_error =
+            node.states.prediction_precision / parents_average_posterior_precision +
+            node.states.prediction_precision * node.states.value_prediction_error^2 - 1
+    end
+
+    return precision_prediction_error
+end
diff --git a/src/update_hgf/node_updates/continuous_state_node.jl b/src/update_hgf/node_updates/continuous_state_node.jl
new file mode 100644
index 0000000..9574ead
--- /dev/null
+++ b/src/update_hgf/node_updates/continuous_state_node.jl
@@ -0,0 +1,614 @@
+###################################
+######## Update prediction ########
+###################################
+
+##### Superfunction #####
+"""
+    update_node_prediction!(node::ContinuousStateNode)
+
+Update the prediction of a single state node.
+""" +function update_node_prediction!(node::ContinuousStateNode) + + #Update prediction mean + node.states.prediction_mean = calculate_prediction_mean(node) + push!(node.history.prediction_mean, node.states.prediction_mean) + + #Update prediction volatility + node.states.predicted_volatility = calculate_predicted_volatility(node) + push!(node.history.predicted_volatility, node.states.predicted_volatility) + + #Update prediction precision + node.states.prediction_precision = calculate_prediction_precision(node) + push!(node.history.prediction_precision, node.states.prediction_precision) + + #Get auxiliary prediction precision, only if there are volatility children and/or volatility parents + if length(node.edges.volatility_parents) > 0 || + length(node.edges.volatility_children) > 0 || + length(node.edges.noise_children) > 0 + node.states.volatility_weighted_prediction_precision = + calculate_volatility_weighted_prediction_precision(node) + push!( + node.history.volatility_weighted_prediction_precision, + node.states.volatility_weighted_prediction_precision, + ) + end + + return nothing +end + + +##### Mean update ##### +@doc raw""" + calculate_prediction_mean(node::AbstractNode) + +Calculates a node's prediction mean. 
+
+Uses the equation
+`` \hat{\mu}_i=\mu_i+\sum_{j=1}^{j\;value\;parents} \mu_{j} \cdot \alpha_{i,j} ``
+"""
+function calculate_prediction_mean(node::ContinuousStateNode)
+    #Get out value parents
+    drift_parents = node.edges.drift_parents
+
+    #Initialize the total drift as the baseline drift plus the autoregression drift
+    predicted_drift =
+        node.parameters.drift +
+        node.parameters.autoregression_strength *
+        (node.parameters.autoregression_target - node.states.posterior_mean)
+
+    #Add contributions from value parents
+    for parent in drift_parents
+        predicted_drift +=
+            parent.states.posterior_mean * node.parameters.coupling_strengths[parent.name]
+    end
+
+    #Add the drift to the posterior to get the prediction mean
+    prediction_mean = node.states.posterior_mean + 1 * predicted_drift
+
+    return prediction_mean
+end
+
+##### Predicted volatility update #####
+@doc raw"""
+    calculate_predicted_volatility(node::AbstractNode)
+
+Calculates a node's prediction volatility.
+
+Uses the equation
+`` \nu_i =exp( \omega_i + \sum_{j=1}^{j\;volatility\;parents} \mu_{j} \cdot \kappa_{i,j}) ``
+"""
+function calculate_predicted_volatility(node::ContinuousStateNode)
+    volatility_parents = node.edges.volatility_parents
+
+    predicted_volatility = node.parameters.volatility
+
+    for parent in volatility_parents
+        predicted_volatility +=
+            parent.states.posterior_mean * node.parameters.coupling_strengths[parent.name]
+    end
+
+    return exp(predicted_volatility)
+end
+
+##### Precision update #####
+@doc raw"""
+    calculate_prediction_precision(node::AbstractNode)
+
+Calculates a node's prediction precision.
+
+Uses the equation
+`` \hat{\pi}_i = \frac{1}{\frac{1}{\pi_i}+\nu_i} ``
+"""
+function calculate_prediction_precision(node::ContinuousStateNode)
+    prediction_precision =
+        1 / (1 / node.states.posterior_precision + node.states.predicted_volatility)
+
+    #If the posterior precision is negative
+    if prediction_precision < 0
+        #Throw an error
+        throw(
+            #Of the custom type where samples are rejected
+            RejectParameters(
+                "With these parameters and inputs, the prediction precision of node $(node.name) becomes negative after $(length(node.history.prediction_precision)) inputs",
+            ),
+        )
+    end
+
+    return prediction_precision
+end
+
+@doc raw"""
+    calculate_volatility_weighted_prediction_precision(node::AbstractNode)
+
+Calculates a node's auxiliary prediction precision.
+
+Uses the equation
+`` \gamma_i = \nu_i \cdot \hat{\pi}_i ``
+"""
+function calculate_volatility_weighted_prediction_precision(node::ContinuousStateNode)
+    node.states.predicted_volatility * node.states.prediction_precision
+end
+
+##################################
+######## Update posterior ########
+##################################
+
+##### Superfunction #####
+"""
+    update_node_posterior!(node::AbstractStateNode; update_type::HGFUpdateType)
+
+Update the posterior of a single continuous state node. This is the classic HGF update.
+"""
+function update_node_posterior!(node::ContinuousStateNode, update_type::ClassicUpdate)
+    #Update posterior precision
+    node.states.posterior_precision = calculate_posterior_precision(node)
+    push!(node.history.posterior_precision, node.states.posterior_precision)
+
+    #Update posterior mean
+    node.states.posterior_mean = calculate_posterior_mean(node, update_type)
+    push!(node.history.posterior_mean, node.states.posterior_mean)
+
+    return nothing
+end
+
+"""
+    update_node_posterior!(node::AbstractStateNode)
+
+Update the posterior of a single continuous state node. This is the enhanced HGF update.
+""" +function update_node_posterior!(node::ContinuousStateNode, update_type::EnhancedUpdate) + #Update posterior mean + node.states.posterior_mean = calculate_posterior_mean(node, update_type) + push!(node.history.posterior_mean, node.states.posterior_mean) + + #Update posterior precision + node.states.posterior_precision = calculate_posterior_precision(node) + push!(node.history.posterior_precision, node.states.posterior_precision) + + return nothing +end + +##### Precision update ##### +@doc raw""" + calculate_posterior_precision(node::AbstractNode) + +Calculates a node's posterior precision. + +Uses the equation +`` \pi_i^{'} = \hat{\pi}_i +\underbrace{\sum_{j=1}^{j\;children} \alpha_{j,i} \cdot \hat{\pi}_{j}} _\text{sum \;of \;VAPE \; continuous \; value \;chidren} `` +""" +function calculate_posterior_precision(node::ContinuousStateNode) + + #Initialize as the node's own prediction + posterior_precision = node.states.prediction_precision + + #Add update terms from drift children + for child in node.edges.drift_children + posterior_precision += + calculate_posterior_precision_increment(node, child, DriftCoupling()) + end + + #Add update terms from observation children + for child in node.edges.observation_children + posterior_precision += + calculate_posterior_precision_increment(node, child, ObservationCoupling()) + end + + #Add update terms from probability children + for child in node.edges.probability_children + posterior_precision += + calculate_posterior_precision_increment(node, child, ProbabilityCoupling()) + end + + #Add update terms from volatility children + for child in node.edges.volatility_children + posterior_precision += + calculate_posterior_precision_increment(node, child, VolatilityCoupling()) + end + + #Add update terms from noise children + for child in node.edges.noise_children + posterior_precision += + calculate_posterior_precision_increment(node, child, NoiseCoupling()) + end + + + #If the posterior precision is negative + if 
posterior_precision < 0 + #Throw an error + throw( + #Of the custom type where samples are rejected + RejectParameters( + "With these parameters and inputs, the posterior precision of node $(node.name) becomes negative after $(length(node.history.posterior_precision)) inputs", + ), + ) + end + + return posterior_precision +end + +## Drift coupling ## +function calculate_posterior_precision_increment( + node::ContinuousStateNode, + child::ContinuousStateNode, + coupling_type::DriftCoupling, +) + child.parameters.coupling_strengths[node.name] * child.states.prediction_precision + +end + +## Observation coupling ## +function calculate_posterior_precision_increment( + node::ContinuousStateNode, + child::ContinuousInputNode, + coupling_type::ObservationCoupling, +) + + #If missing input + if ismissing(child.states.input_value) + #No increment + return 0 + else + return child.states.prediction_precision + end +end + +## Probability coupling ## +function calculate_posterior_precision_increment( + node::ContinuousStateNode, + child::BinaryStateNode, + coupling_type::ProbabilityCoupling, +) + + #If there is a missing posterior (due to a missing input) + if ismissing(child.states.posterior_mean) + #No update + return 0 + else + return child.parameters.coupling_strengths[node.name]^2 / + child.states.prediction_precision + end +end + +@doc raw""" + calculate_posterior_precision_vope(node::AbstractNode, child::AbstractNode) + +Calculates the posterior precision update term for a single continuous volatility child to a state node. 
+ +Uses the equation +`` `` +""" +function calculate_posterior_precision_increment( + node::ContinuousStateNode, + child::ContinuousStateNode, + coupling_type::VolatilityCoupling, +) + + 1 / 2 * + ( + child.parameters.coupling_strengths[node.name] * + child.states.volatility_weighted_prediction_precision + )^2 + + child.states.precision_prediction_error * + ( + child.parameters.coupling_strengths[node.name] * + child.states.volatility_weighted_prediction_precision + )^2 - + 1 / 2 * + child.parameters.coupling_strengths[node.name]^2 * + child.states.volatility_weighted_prediction_precision * + child.states.precision_prediction_error +end + +function calculate_posterior_precision_increment( + node::ContinuousStateNode, + child::ContinuousInputNode, + coupling_type::NoiseCoupling, +) + + #If the input node child had a missing input + if ismissing(child.states.input_value) + #No increment + return 0 + else + update_term = + 1 / 2 * (child.parameters.coupling_strengths[node.name])^2 + + child.states.precision_prediction_error * + (child.parameters.coupling_strengths[node.name])^2 - + 1 / 2 * + child.parameters.coupling_strengths[node.name]^2 * + child.states.precision_prediction_error + + return update_term + end +end + +##### Mean update ##### +@doc raw""" + calculate_posterior_mean(node::AbstractNode) + +Calculates a node's posterior mean. 
+ +Uses the equation +`` `` +""" +function calculate_posterior_mean(node::ContinuousStateNode, update_type::HGFUpdateType) + + #Initialize as the prediction + posterior_mean = node.states.prediction_mean + + #Add update terms from drift children + for child in node.edges.drift_children + posterior_mean += + calculate_posterior_mean_increment(node, child, DriftCoupling(), update_type) + end + + #Add update terms from observation children + for child in node.edges.observation_children + posterior_mean += calculate_posterior_mean_increment( + node, + child, + ObservationCoupling(), + update_type, + ) + end + + #Add update terms from probability children + for child in node.edges.probability_children + posterior_mean += calculate_posterior_mean_increment( + node, + child, + ProbabilityCoupling(), + update_type, + ) + end + + #Add update terms from volatility children + for child in node.edges.volatility_children + posterior_mean += calculate_posterior_mean_increment( + node, + child, + VolatilityCoupling(), + update_type, + ) + end + + #Add update terms from noise children + for child in node.edges.noise_children + posterior_mean += + calculate_posterior_mean_increment(node, child, NoiseCoupling(), update_type) + end + + return posterior_mean +end + +## Classic drift coupling ## +function calculate_posterior_mean_increment( + node::ContinuousStateNode, + child::ContinuousStateNode, + coupling_type::DriftCoupling, + update_type::ClassicUpdate, +) + (child.parameters.coupling_strengths[node.name] * child.states.prediction_precision) / + node.states.posterior_precision * child.states.value_prediction_error +end + +## Enhanced drift coupling ## +function calculate_posterior_mean_increment( + node::ContinuousStateNode, + child::ContinuousStateNode, + coupling_type::DriftCoupling, + update_type::EnhancedUpdate, +) + (child.parameters.coupling_strengths[node.name] * child.states.prediction_precision) / + node.states.prediction_precision * child.states.value_prediction_error + 
+end + +## Classic observation coupling ## +function calculate_posterior_mean_increment( + node::ContinuousStateNode, + child::ContinuousInputNode, + coupling_type::ObservationCoupling, + update_type::ClassicUpdate, +) + #For input node children with missing input + if ismissing(child.states.input_value) + #No update + return 0 + else + update_term = + child.states.prediction_precision / node.states.posterior_precision * + child.states.value_prediction_error + + return update_term + end +end + +## Enhanced observation coupling ## +function calculate_posterior_mean_increment( + node::ContinuousStateNode, + child::ContinuousInputNode, + coupling_type::ObservationCoupling, + update_type::EnhancedUpdate, +) + #For input node children with missing input + if ismissing(child.states.input_value) + #No update + return 0 + else + update_term = + child.states.prediction_precision / node.states.prediction_precision * + child.states.value_prediction_error + + return update_term + end +end + +## Classic probability coupling ## +function calculate_posterior_mean_increment( + node::ContinuousStateNode, + child::BinaryStateNode, + coupling_type::ProbabilityCoupling, + update_type::ClassicUpdate, +) + #If the posterior is missing (due to missing inputs) + if ismissing(child.states.posterior_mean) + #No update + return 0 + else + return child.parameters.coupling_strengths[node.name] / + node.states.posterior_precision * child.states.value_prediction_error + end +end + +## Enhanced Probability coupling ## +function calculate_posterior_mean_increment( + node::ContinuousStateNode, + child::BinaryStateNode, + coupling_type::ProbabilityCoupling, + update_type::EnhancedUpdate, +) + #If the posterior is missing (due to missing inputs) + if ismissing(child.states.posterior_mean) + #No update + return 0 + else + return child.parameters.coupling_strengths[node.name] / + node.states.prediction_precision * child.states.value_prediction_error + end +end + +## Classic Volatility coupling ## 
+function calculate_posterior_mean_increment( + node::ContinuousStateNode, + child::ContinuousStateNode, + coupling_type::VolatilityCoupling, + update_type::ClassicUpdate, +) + 1 / 2 * ( + child.parameters.coupling_strengths[node.name] * + child.states.volatility_weighted_prediction_precision + ) / node.states.posterior_precision * child.states.precision_prediction_error +end + +## Enhanced Volatility coupling ## +function calculate_posterior_mean_increment( + node::ContinuousStateNode, + child::ContinuousStateNode, + coupling_type::VolatilityCoupling, + update_type::EnhancedUpdate, +) + 1 / 2 * ( + child.parameters.coupling_strengths[node.name] * + child.states.volatility_weighted_prediction_precision + ) / node.states.prediction_precision * child.states.precision_prediction_error +end + +## Classic Noise coupling ## +function calculate_posterior_mean_increment( + node::ContinuousStateNode, + child::ContinuousInputNode, + coupling_type::NoiseCoupling, + update_type::ClassicUpdate, +) + #For input node children with missing input + if ismissing(child.states.input_value) + #No update + return 0 + else + update_term = + 1 / 2 * (child.parameters.coupling_strengths[node.name]) / + node.states.posterior_precision * child.states.precision_prediction_error + + return update_term + end +end + +## Enhanced Noise coupling ## +function calculate_posterior_mean_increment( + node::ContinuousStateNode, + child::ContinuousInputNode, + coupling_type::NoiseCoupling, + update_type::EnhancedUpdate, +) + #For input node children with missing input + if ismissing(child.states.input_value) + #No update + return 0 + else + update_term = + 1 / 2 * (child.parameters.coupling_strengths[node.name]) / + node.states.prediction_precision * child.states.precision_prediction_error + + return update_term + end +end + +############################################### +######## Update value prediction error ######## +############################################### + +##### Superfunction ##### +""" + 
update_node_value_prediction_error!(node::AbstractStateNode) + +Update the value prediction error of a single state node. +""" +function update_node_value_prediction_error!(node::ContinuousStateNode) + #Update value prediction error + node.states.value_prediction_error = calculate_value_prediction_error(node) + push!(node.history.value_prediction_error, node.states.value_prediction_error) + + return nothing +end + +@doc raw""" + calculate_value_prediction_error(node::AbstractNode) + +Calculate's a state node's value prediction error. + +Uses the equation +`` \delta_n = \mu_n - \hat{\mu}_n `` +""" +function calculate_value_prediction_error(node::ContinuousStateNode) + node.states.posterior_mean - node.states.prediction_mean +end + +################################################### +######## Update precision prediction error ######## +################################################### + +##### Superfunction ##### +""" + update_node_precision_prediction_error!(node::AbstractStateNode) + +Update the volatility prediction error of a single state node. +""" +function update_node_precision_prediction_error!(node::ContinuousStateNode) + + #Update volatility prediction error, only if there are volatility parents + if length(node.edges.volatility_parents) > 0 + node.states.precision_prediction_error = calculate_precision_prediction_error(node) + push!( + node.history.precision_prediction_error, + node.states.precision_prediction_error, + ) + end + + return nothing +end + +@doc raw""" + calculate_precision_prediction_error(node::AbstractNode) + +Calculates a state node's volatility prediction error. 
+ +Uses the equation +`` \Delta_n = \frac{\hat{\pi}_n}{\pi_n} + \hat{\pi}_n \cdot \delta_n^2-1 `` +""" +function calculate_precision_prediction_error(node::ContinuousStateNode) + node.states.prediction_precision / node.states.posterior_precision + + node.states.prediction_precision * node.states.value_prediction_error^2 - 1 +end diff --git a/src/update_hgf/update_equations.jl b/src/update_hgf/update_equations.jl deleted file mode 100644 index d541877..0000000 --- a/src/update_hgf/update_equations.jl +++ /dev/null @@ -1,813 +0,0 @@ -####################################### -######## Continuous State Node ######## -####################################### - -######## Prediction update ######## - -### Mean update ### -@doc raw""" - calculate_prediction_mean(node::AbstractNode) - -Calculates a node's prediction mean. - -Uses the equation -`` \hat{\mu}_i=\mu_i+\sum_{j=1}^{j\;value\;parents} \mu_{j} \cdot \alpha_{i,j} `` -""" -function calculate_prediction_mean(node::AbstractNode) - #Get out value parents - value_parents = node.value_parents - - #Initialize the total drift as the basline drift plus the autoregression drift - predicted_drift = - node.parameters.drift + - node.parameters.autoregression_strength * - (node.parameters.autoregression_target - node.states.posterior_mean) - - #Add contributions from value parents - for parent in value_parents - predicted_drift += - parent.states.posterior_mean * node.parameters.value_coupling[parent.name] - end - - #Add the drift to the posterior to get the prediction mean - prediction_mean = node.states.posterior_mean + 1 * predicted_drift - - return prediction_mean -end - -### Volatility update ### -@doc raw""" - calculate_predicted_volatility(node::AbstractNode) - -Calculates a node's prediction volatility. 
- -Uses the equation -`` \nu_i =exp( \omega_i + \sum_{j=1}^{j\;volatility\;parents} \mu_{j} \cdot \kappa_{i,j}} `` -""" -function calculate_predicted_volatility(node::AbstractNode) - volatility_parents = node.volatility_parents - - predicted_volatility = node.parameters.volatility - - for parent in volatility_parents - predicted_volatility += - parent.states.posterior_mean * node.parameters.volatility_coupling[parent.name] - end - - return exp(predicted_volatility) -end - -### Precision update ### -@doc raw""" - calculate_prediction_precision(node::AbstractNode) - -Calculates a node's prediction precision. - -Uses the equation -`` \hat{\pi}_i^ = \frac{1}{\frac{1}{\pi_i}+\nu_i^} `` -""" -function calculate_prediction_precision(node::AbstractNode) - prediction_precision = - 1 / (1 / node.states.posterior_precision + node.states.predicted_volatility) - - #If the posterior precision is negative - if prediction_precision < 0 - #Throw an error - throw( - #Of the custom type where samples are rejected - RejectParameters( - "With these parameters and inputs, the prediction precision of node $(node.name) becomes negative after $(length(node.history.prediction_precision)) inputs", - ), - ) - end - - return prediction_precision -end - -@doc raw""" - calculate_volatility_weighted_prediction_precision(node::AbstractNode) - -Calculates a node's auxiliary prediction precision. - -Uses the equation -`` \gamma_i = \nu_i \cdot \hat{\pi}_i `` -""" -function calculate_volatility_weighted_prediction_precision(node::AbstractNode) - node.states.predicted_volatility * node.states.prediction_precision -end - -######## Posterior update functions ######## - -### Precision update ### -@doc raw""" - calculate_posterior_precision(node::AbstractNode) - -Calculates a node's posterior precision. 
- -Uses the equation -`` \pi_i^{'} = \hat{\pi}_i +\underbrace{\sum_{j=1}^{j\;children} \alpha_{j,i} \cdot \hat{\pi}_{j}} _\text{sum \;of \;VAPE \; continuous \; value \;chidren} `` -""" -function calculate_posterior_precision(node::AbstractNode) - value_children = node.value_children - volatility_children = node.volatility_children - - #Initialize as the node's own prediction - posterior_precision = node.states.prediction_precision - - #Add update terms from value children - for child in value_children - posterior_precision += calculate_posterior_precision_vape(node, child) - end - - #Add update terms from volatility children - for child in volatility_children - posterior_precision += calculate_posterior_precision_vope(node, child) - end - - #If the posterior precision is negative - if posterior_precision < 0 - #Throw an error - throw( - #Of the custom type where samples are rejected - RejectParameters( - "With these parameters and inputs, the posterior precision of node $(node.name) becomes negative after $(length(node.history.posterior_precision)) inputs", - ), - ) - end - - return posterior_precision -end - -@doc raw""" - calculate_posterior_precision_vape(node::AbstractNode, child::AbstractNode) - -Calculates the posterior precision update term for a single continuous value child to a state node. - -Uses the equation -`` `` -""" -function calculate_posterior_precision_vape(node::AbstractNode, child::AbstractNode) - - #For input node children with missing input - if child isa AbstractInputNode && ismissing(child.states.input_value) - #No update - return 0 - else - return child.parameters.value_coupling[node.name] * - child.states.prediction_precision - end -end - -@doc raw""" - calculate_posterior_precision_vape(node::AbstractNode, child::BinaryStateNode) - -Calculates the posterior precision update term for a single binary value child to a state node. 
- -Uses the equation -`` `` -""" -function calculate_posterior_precision_vape(node::AbstractNode, child::BinaryStateNode) - - #For missing inputs - if ismissing(child.states.posterior_mean) - #No update - return 0 - else - return child.parameters.value_coupling[node.name]^2 / - child.states.prediction_precision - end -end - -@doc raw""" - calculate_posterior_precision_vope(node::AbstractNode, child::AbstractNode) - -Calculates the posterior precision update term for a single continuous volatility child to a state node. - -Uses the equation -`` `` -""" -function calculate_posterior_precision_vope(node::AbstractNode, child::AbstractNode) - - #For input node children with missing input - if child isa AbstractInputNode && ismissing(child.states.input_value) - #No update - return 0 - else - update_term = - 1 / 2 * - ( - child.parameters.volatility_coupling[node.name] * - child.states.volatility_weighted_prediction_precision - )^2 + - child.states.volatility_prediction_error * - ( - child.parameters.volatility_coupling[node.name] * - child.states.volatility_weighted_prediction_precision - )^2 - - 1 / 2 * - child.parameters.volatility_coupling[node.name]^2 * - child.states.volatility_weighted_prediction_precision * - child.states.volatility_prediction_error - - return update_term - end -end - -### Mean update ### -@doc raw""" - calculate_posterior_mean(node::AbstractNode) - -Calculates a node's posterior mean. 
- -Uses the equation -`` `` -""" -function calculate_posterior_mean(node::AbstractNode, update_type::HGFUpdateType) - value_children = node.value_children - volatility_children = node.volatility_children - - #Initialize as the prediction - posterior_mean = node.states.prediction_mean - - #Add update terms from value children - for child in value_children - posterior_mean += - calculate_posterior_mean_value_child_increment(node, child, update_type) - end - - #Add update terms from volatility children - for child in volatility_children - posterior_mean += - calculate_posterior_mean_volatility_child_increment(node, child, update_type) - end - - return posterior_mean -end - -@doc raw""" - calculate_posterior_mean_value_child_increment(node::AbstractNode, child::AbstractNode) - -Calculates the posterior mean update term for a single continuous value child to a state node. -This is the classic HGF update. - -Uses the equation -`` `` -""" -function calculate_posterior_mean_value_child_increment( - node::AbstractNode, - child::AbstractNode, - update_type::HGFUpdateType, -) - #For input node children with missing input - if child isa AbstractInputNode && ismissing(child.states.input_value) - #No update - return 0 - else - update_term = - ( - child.parameters.value_coupling[node.name] * - child.states.prediction_precision - ) / node.states.posterior_precision * child.states.value_prediction_error - - return update_term - end -end - -@doc raw""" - calculate_posterior_mean_value_child_increment(node::AbstractNode, child::AbstractNode) - -Calculates the posterior mean update term for a single continuous value child to a state node. -This is the enhanced HGF update. 
- -Uses the equation -`` `` -""" -function calculate_posterior_mean_value_child_increment( - node::AbstractNode, - child::AbstractNode, - update_type::EnhancedUpdate, -) - #For input node children with missing input - if child isa AbstractInputNode && ismissing(child.states.input_value) - #No update - return 0 - else - update_term = - ( - child.parameters.value_coupling[node.name] * - child.states.prediction_precision - ) / node.states.prediction_precision * child.states.value_prediction_error - - return update_term - end -end - -@doc raw""" - calculate_posterior_mean_value_child_increment(node::AbstractNode, child::BinaryStateNode) - -Calculates the posterior mean update term for a single binary value child to a state node. -This is the classic HGF update. - -Uses the equation -`` `` -""" -function calculate_posterior_mean_value_child_increment( - node::AbstractNode, - child::BinaryStateNode, - update_type::HGFUpdateType, -) - #For missing inputs - if ismissing(child.states.posterior_mean) - #No update - return 0 - else - return child.parameters.value_coupling[node.name] / - (node.states.posterior_precision) * child.states.value_prediction_error - end -end - -@doc raw""" - calculate_posterior_mean_value_child_increment(node::AbstractNode, child::BinaryStateNode) - -Calculates the posterior mean update term for a single binary value child to a state node. -This is the enhanced HGF update. 
- -Uses the equation -`` `` -""" -function calculate_posterior_mean_value_child_increment( - node::AbstractNode, - child::BinaryStateNode, - update_type::EnhancedUpdate, -) - #For missing inputs - if ismissing(child.states.posterior_mean) - #No update - return 0 - else - return child.parameters.value_coupling[node.name] / - (node.states.prediction_precision) * child.states.value_prediction_error - end -end - -@doc raw""" - calculate_posterior_mean_volatility_child_increment(node::AbstractNode, child::AbstractNode) - -Calculates the posterior mean update term for a single continuos volatility child to a state node. - -Uses the equation -`` `` -""" -function calculate_posterior_mean_volatility_child_increment( - node::AbstractNode, - child::AbstractNode, - update_type::HGFUpdateType, -) - #For input node children with missing input - if child isa AbstractInputNode && ismissing(child.states.input_value) - #No update - return 0 - else - update_term = - 1 / 2 * ( - child.parameters.volatility_coupling[node.name] * - child.states.volatility_weighted_prediction_precision - ) / node.states.posterior_precision * child.states.volatility_prediction_error - - return update_term - end -end - -@doc raw""" - calculate_posterior_mean_volatility_child_increment(node::AbstractNode, child::AbstractNode) - -Calculates the posterior mean update term for a single continuos volatility child to a state node. 
- -Uses the equation -`` `` -""" -function calculate_posterior_mean_volatility_child_increment( - node::AbstractNode, - child::AbstractNode, - update_type::EnhancedUpdate, -) - #For input node children with missing input - if child isa AbstractInputNode && ismissing(child.states.input_value) - #No update - return 0 - else - update_term = - 1 / 2 * ( - child.parameters.volatility_coupling[node.name] * - child.states.volatility_weighted_prediction_precision - ) / node.states.prediction_precision * child.states.volatility_prediction_error - - return update_term - end -end - -######## Prediction error update functions ######## -@doc raw""" - calculate_value_prediction_error(node::AbstractNode) - -Calculate's a state node's value prediction error. - -Uses the equation -`` \delta_n = \mu_n - \hat{\mu}_n `` -""" -function calculate_value_prediction_error(node::AbstractNode) - node.states.posterior_mean - node.states.prediction_mean -end - -@doc raw""" - calculate_volatility_prediction_error(node::AbstractNode) - -Calculates a state node's volatility prediction error. - -Uses the equation -`` \Delta_n = \frac{\hat{\pi}_n}{\pi_n} + \hat{\pi}_n \cdot \delta_n^2-1 `` -""" -function calculate_volatility_prediction_error(node::AbstractNode) - node.states.prediction_precision / node.states.posterior_precision + - node.states.prediction_precision * node.states.value_prediction_error^2 - 1 -end - - -############################################## -######## Binary State Node Variations ######## -############################################## - -######## Prediction update ######## - -### Mean update ### -@doc raw""" - calculate_prediction_mean(node::BinaryStateNode) - -Calculates a binary state node's prediction mean. 
- -Uses the equation -`` \hat{\mu}_n= \big(1+e^{\sum_{j=1}^{j\;value \; parents} \hat{\mu}_{j}}\big)^{-1} `` -""" -function calculate_prediction_mean(node::BinaryStateNode) - value_parents = node.value_parents - - prediction_mean = 0 - - for parent in value_parents - prediction_mean += - parent.states.prediction_mean * node.parameters.value_coupling[parent.name] - end - - prediction_mean = 1 / (1 + exp(-prediction_mean)) - - return prediction_mean -end - -### Precision update ### -@doc raw""" - calculate_prediction_precision(node::BinaryStateNode) - -Calculates a binary state node's prediction precision. - -Uses the equation -`` \hat{\pi}_n = \frac{1}{\hat{\mu}_n \cdot (1-\hat{\mu}_n)} `` -""" -function calculate_prediction_precision(node::BinaryStateNode) - 1 / (node.states.prediction_mean * (1 - node.states.prediction_mean)) -end - -######## Posterior update functions ######## - -### Precision update ### -@doc raw""" - calculate_posterior_precision(node::BinaryStateNode) - -Calculates a binary node's posterior precision. - -Uses the equations - -`` \pi_n = inf `` -if the precision is infinite - - `` \pi_n = \frac{1}{\hat{\mu}_n \cdot (1-\hat{\mu}_n)} `` - if the precision is other than infinite -""" -function calculate_posterior_precision(node::BinaryStateNode) - #Extract the child - child = node.value_children[1] - - #Simple update with inifinte precision - if child isa CategoricalStateNode || child.parameters.input_precision == Inf - posterior_precision = Inf - #Update with finite precision - else - posterior_precision = - 1 / (node.states.posterior_mean * (1 - node.states.posterior_mean)) - end - - return posterior_precision -end - -### Mean update ### -@doc raw""" - calculate_posterior_mean(node::BinaryStateNode) - -Calculates a node's posterior mean. 
- -Uses the equation -`` \mu = \frac{e^{-0.5 \cdot \pi_n \cdot \eta_1^2}}{\hat{\mu}_n \cdot e^{-0.5 \cdot \pi_n \cdot \eta_1^2} \; + 1-\hat{\mu}_n \cdot e^{-0.5 \cdot \pi_n \cdot \eta_2^2}} `` -""" -function calculate_posterior_mean(node::BinaryStateNode, update_type::HGFUpdateType) - #Extract the child - child = node.value_children[1] - - #Update with categorical state node child - if child isa CategoricalStateNode - - #Find the nodes' own category number - category_number = findfirst(child.category_parent_order .== node.name) - - #Find the corresponding value in the child - posterior_mean = child.states.posterior[category_number] - - #For binary input node children - elseif child isa BinaryInputNode - - #For missing inputs - if ismissing(child.states.input_value) - #Set the posterior to missing - posterior_mean = missing - else - #Update with infinte input precision - if child.parameters.input_precision == Inf - posterior_mean = child.states.input_value - - #Update with finite input precision - else - posterior_mean = - node.states.prediction_mean * exp( - -0.5 * - node.states.prediction_precision * - child.parameters.category_means[1]^2, - ) / ( - node.states.prediction_mean * exp( - -0.5 * - node.states.prediction_precision * - child.parameters.category_means[1]^2, - ) + - (1 - node.states.prediction_mean) * exp( - -0.5 * - node.states.prediction_precision * - child.parameters.category_means[2]^2, - ) - ) - end - end - end - return posterior_mean -end - - -################################################### -######## Categorical State Node Variations ######## -################################################### -@doc raw""" - calculate_posterior(node::CategoricalStateNode) - -Calculate the posterior for a categorical state node. 
- -One hot encoding -`` \vec{u} = [0, 0, \dots ,1, \dots,0] `` -""" -function calculate_posterior(node::CategoricalStateNode) - - #Get child - child = node.value_children[1] - - #Initialize posterior as previous posterior - posterior = node.states.posterior - - #For missing inputs - if ismissing(child.states.input_value) - #Set the posterior to be all missing - posterior .= missing - - else - #Set all values to 0 - posterior .= zero(Real) - - #Set the posterior for the observed category to 1 - posterior[child.states.input_value] = 1 - end - return posterior -end - -@doc raw""" - function calculate_prediction(node::CategoricalStateNode) - -Calculate the prediction for a categorical state node. - -Uses the equation -`` \vec{\hat{\mu}_n}= \frac{\hat{\mu}_j}{\sum_{j=1}^{j\;binary \;parents} \hat{\mu}_j} `` -""" -function calculate_prediction(node::CategoricalStateNode) - - #Get parent posteriors - parent_posteriors = - map(x -> x.states.posterior_mean, collect(values(node.value_parents))) - - #Get current parent predictions - parent_predictions = - map(x -> x.states.prediction_mean, collect(values(node.value_parents))) - - #Get previous parent predictions - previous_parent_predictions = node.states.parent_predictions - - #If there was an observation - if any(!ismissing, node.states.posterior) - - #Calculate implied learning rate - implied_learning_rate = - ( - (parent_posteriors .- previous_parent_predictions) ./ - (parent_predictions .- previous_parent_predictions) - ) .- 1 - - # calculate the prediction mean - prediction = - ((implied_learning_rate .* parent_predictions) .+ 1) ./ - sum(implied_learning_rate .* parent_predictions .+ 1) - - #If there was no observation - else - #Extract prediction from last timestep - prediction = node.states.prediction - end - - return prediction, parent_predictions - -end - -@doc raw""" - calculate_value_prediction_error(node::CategoricalStateNode) - -Calculate the value prediction error for a categorical state node. 
- -Uses the equation -`` \delta_n= u - \sum_{j=1}^{j\;value\;parents} \hat{\mu}_{j} `` -""" -function calculate_value_prediction_error(node::CategoricalStateNode) - - #Get the prediction error for each category - value_prediction_error = node.states.posterior - node.states.prediction - - return value_prediction_error -end - - -################################################### -######## Conntinuous Input Node Variations ######## -################################################### - -@doc raw""" - calculate_predicted_volatility(node::AbstractInputNode) - -Calculates an input node's prediction volatility. - -Uses the equation -`` \nu_i =exp( \omega_i + \sum_{j=1}^{j\;volatility\;parents} \mu_{j} \cdot \kappa_{i,j}} `` -""" -function calculate_predicted_volatility(node::AbstractInputNode) - volatility_parents = node.volatility_parents - - predicted_volatility = node.parameters.input_noise - - for parent in volatility_parents - predicted_volatility += - parent.states.posterior_mean * node.parameters.volatility_coupling[parent.name] - end - - return exp(predicted_volatility) -end - -@doc raw""" - calculate_prediction_precision(node::AbstractInputNode) - -Calculates an input node's prediction precision. - -Uses the equation -`` \hat{\pi}_n = \frac{1}{\nu}_n `` -""" -function calculate_prediction_precision(node::AbstractInputNode) - - #Doesn't use own posterior precision - 1 / node.states.predicted_volatility -end - -""" - calculate_volatility_weighted_prediction_precision(node::AbstractInputNode) - -An input node's auxiliary prediction precision is always 1. -""" -function calculate_volatility_weighted_prediction_precision(node::AbstractInputNode) - 1 -end - -@doc raw""" - calculate_value_prediction_error(node::ContinuousInputNode) - -Calculate's an input node's value prediction error. 
- -Uses the equation -``\delta_n= u - \sum_{j=1}^{j\;value\;parents} \hat{\mu}_{j} `` -""" -function calculate_value_prediction_error(node::ContinuousInputNode) - #For missing input - if ismissing(node.states.input_value) - #Set the prediction error to missing - value_prediction_error = missing - else - #Extract parents - value_parents = node.value_parents - - #Sum the prediction_means of the parents - parents_prediction_mean = 0 - for parent in value_parents - parents_prediction_mean += parent.states.prediction_mean - end - - #Get VOPE using parents_prediction_mean instead of own - value_prediction_error = node.states.input_value - parents_prediction_mean - end - return value_prediction_error -end - -@doc raw""" - calculate_volatility_prediction_error(node::ContinuousInputNode) - -Calculates an input node's volatility prediction error. - -Uses the equation -`` \mu'_j=\sum_{j=1}^{j\;value\;parents} \mu_{j} `` -`` \pi'_j=\frac{{\sum_{j=1}^{j\;value\;parents} \pi_{j}}}{j} `` -`` \Delta_n=\frac{\hat{\pi}_n}{\pi'_j} + \hat{\mu}_i\cdot (u -\mu'_j^2 )-1 `` -""" -function calculate_volatility_prediction_error(node::ContinuousInputNode) - - #For missing input - if ismissing(node.states.input_value) - #Set the prediction error to missing - volatility_prediction_error = missing - else - #Extract parents - value_parents = node.value_parents - - #Sum the posterior mean and average the posterior precision of the value parents - parents_posterior_mean = 0 - parents_posterior_precision = 0 - - for parent in value_parents - parents_posterior_mean += parent.states.posterior_mean - parents_posterior_precision += parent.states.posterior_precision - end - - parents_posterior_precision = parents_posterior_precision / length(value_parents) - - #Get the VOPE using parents_posterior_precision and parents_posterior_mean - volatility_prediction_error = - node.states.prediction_precision / parents_posterior_precision + - node.states.prediction_precision * - (node.states.input_value - 
parents_posterior_mean)^2 - 1 - end - - return volatility_prediction_error -end - - -############################################## -######## Binary Input Node Variations ######## -############################################## -@doc raw""" - calculate_value_prediction_error(node::BinaryInputNode) - -Calculates the prediciton error of a binary input node with finite precision. - -Uses the equation -`` \delta_n= u - \sum_{j=1}^{j\;value\;parents} \hat{\mu}_{j} `` -""" -function calculate_value_prediction_error(node::BinaryInputNode) - - #For missing input - if ismissing(node.states.input_value) - #Set the prediction error to missing - value_prediction_error = [missing, missing] - else - #Substract to find the difference to each of the Gaussian means - value_prediction_error = node.parameters.category_means .- node.states.input_value - end -end - - -################################################### -######## Categorical Input Node Variations ######## -################################################### diff --git a/src/update_hgf/update_hgf.jl b/src/update_hgf/update_hgf.jl index e644d8c..0af64b8 100644 --- a/src/update_hgf/update_hgf.jl +++ b/src/update_hgf/update_hgf.jl @@ -50,15 +50,15 @@ function update_hgf!( update_node_posterior!(node, node.update_type) #And its value prediction error update_node_value_prediction_error!(node) - #And its volatility prediction error - update_node_volatility_prediction_error!(node) + #And its precision prediction error + update_node_precision_prediction_error!(node) end - ## Update input node volatility prediction errors + ## Update input node precision prediction errors #For each input node, in the specified update order for node in hgf.ordered_nodes.input_nodes #Update its value prediction error - update_node_volatility_prediction_error!(node) + update_node_precision_prediction_error!(node) end ## Update remaining state nodes @@ -69,7 +69,7 @@ function update_hgf!( #And its value prediction error 
update_node_value_prediction_error!(node) #And its volatility prediction error - update_node_volatility_prediction_error!(node) + update_node_precision_prediction_error!(node) end return nothing @@ -117,3 +117,17 @@ function enter_node_inputs!(hgf::HGF, inputs::Dict{String,<:Union{Real,Missing}} return nothing end + + +""" + update_node_input!(node::AbstractInputNode, input::Union{Real,Missing}) + +Update the prediction of a single input node. +""" +function update_node_input!(node::AbstractInputNode, input::Union{Real,Missing}) + #Receive input + node.states.input_value = input + push!(node.history.input_value, node.states.input_value) + + return nothing +end diff --git a/src/update_hgf/update_node.jl b/src/update_hgf/update_node.jl deleted file mode 100644 index b8e88ae..0000000 --- a/src/update_hgf/update_node.jl +++ /dev/null @@ -1,308 +0,0 @@ -####################################### -######## Continuous State Node ######## -####################################### -""" - update_node_prediction!(node::AbstractStateNode) - -Update the prediction of a single state node. 
-""" -function update_node_prediction!(node::AbstractStateNode) - - #Update prediction mean - node.states.prediction_mean = calculate_prediction_mean(node) - push!(node.history.prediction_mean, node.states.prediction_mean) - - #Update prediction volatility - node.states.predicted_volatility = calculate_predicted_volatility(node) - push!(node.history.predicted_volatility, node.states.predicted_volatility) - - #Update prediction precision - node.states.prediction_precision = calculate_prediction_precision(node) - push!(node.history.prediction_precision, node.states.prediction_precision) - - #Get auxiliary prediction precision, only if there are volatility children and/or volatility parents - if length(node.volatility_parents) > 0 || length(node.volatility_children) > 0 - node.states.volatility_weighted_prediction_precision = - calculate_volatility_weighted_prediction_precision(node) - push!( - node.history.volatility_weighted_prediction_precision, - node.states.volatility_weighted_prediction_precision, - ) - end - - return nothing -end - -""" - update_node_posterior!(node::AbstractStateNode; update_type::HGFUpdateType) - -Update the posterior of a single continuous state node. This is the classic HGF update. -""" -function update_node_posterior!(node::AbstractStateNode, update_type::ClassicUpdate) - #Update posterior precision - node.states.posterior_precision = calculate_posterior_precision(node) - push!(node.history.posterior_precision, node.states.posterior_precision) - - #Update posterior mean - node.states.posterior_mean = calculate_posterior_mean(node, update_type) - push!(node.history.posterior_mean, node.states.posterior_mean) - - return nothing -end - -""" - update_node_posterior!(node::AbstractStateNode) - -Update the posterior of a single continuous state node. This is the enahnced HGF update. 
-""" -function update_node_posterior!(node::AbstractStateNode, update_type::EnhancedUpdate) - #Update posterior mean - node.states.posterior_mean = calculate_posterior_mean(node, update_type) - push!(node.history.posterior_mean, node.states.posterior_mean) - - #Update posterior precision - node.states.posterior_precision = calculate_posterior_precision(node) - push!(node.history.posterior_precision, node.states.posterior_precision) - - return nothing -end - -""" - update_node_value_prediction_error!(node::AbstractStateNode) - -Update the value prediction error of a single state node. -""" -function update_node_value_prediction_error!(node::AbstractStateNode) - #Update value prediction error - node.states.value_prediction_error = calculate_value_prediction_error(node) - push!(node.history.value_prediction_error, node.states.value_prediction_error) - - return nothing -end - -""" - update_node_volatility_prediction_error!(node::AbstractStateNode) - -Update the volatility prediction error of a single state node. -""" -function update_node_volatility_prediction_error!(node::AbstractStateNode) - - #Update volatility prediction error, only if there are volatility parents - if length(node.volatility_parents) > 0 - node.states.volatility_prediction_error = - calculate_volatility_prediction_error(node) - push!( - node.history.volatility_prediction_error, - node.states.volatility_prediction_error, - ) - end - - return nothing -end - - -############################################## -######## Binary State Node Variations ######## -############################################## -""" - update_node_prediction!(node::BinaryStateNode) - -Update the prediction of a single binary state node. 
-""" -function update_node_prediction!(node::BinaryStateNode) - - #Update prediction mean - node.states.prediction_mean = calculate_prediction_mean(node) - push!(node.history.prediction_mean, node.states.prediction_mean) - - #Update prediction precision - node.states.prediction_precision = calculate_prediction_precision(node) - push!(node.history.prediction_precision, node.states.prediction_precision) - - return nothing -end - -""" - update_node_volatility_prediction_error!(node::BinaryStateNode) - -There is no volatility prediction error update for binary state nodes. -""" -function update_node_volatility_prediction_error!(node::BinaryStateNode) - return nothing -end - - -################################################### -######## Categorical State Node Variations ######## -################################################### -""" - update_node_prediction!(node::CategoricalStateNode) - -Update the prediction of a single categorical state node. -""" -function update_node_prediction!(node::CategoricalStateNode) - - #Update prediction mean - node.states.prediction, node.states.parent_predictions = calculate_prediction(node) - push!(node.history.prediction, node.states.prediction) - push!(node.history.parent_predictions, node.states.parent_predictions) - return nothing -end - -""" - update_node_posterior!(node::CategoricalStateNode) - -Update the posterior of a single categorical state node. -""" -function update_node_posterior!(node::CategoricalStateNode, update_type::ClassicUpdate) - - #Update posterior mean - node.states.posterior = calculate_posterior(node) - push!(node.history.posterior, node.states.posterior) - - return nothing -end - -""" - update_node_volatility_prediction_error!(node::CategoricalStateNode) - -There is no volatility prediction error update for categorical state nodes. 
-""" -function update_node_volatility_prediction_error!(node::CategoricalStateNode) - return nothing -end - - -################################################### -######## Conntinuous Input Node Variations ######## -################################################### -""" - update_node_input!(node::AbstractInputNode, input::Union{Real,Missing}) - -Update the prediction of a single input node. -""" -function update_node_input!(node::AbstractInputNode, input::Union{Real,Missing}) - #Receive input - node.states.input_value = input - push!(node.history.input_value, node.states.input_value) - - return nothing -end - -""" - update_node_prediction!(node::AbstractInputNode) - -Update the posterior of a single input node. -""" -function update_node_prediction!(node::AbstractInputNode) - #Update prediction volatility - node.states.predicted_volatility = calculate_predicted_volatility(node) - push!(node.history.predicted_volatility, node.states.predicted_volatility) - - #Update prediction precision - node.states.prediction_precision = calculate_prediction_precision(node) - push!(node.history.prediction_precision, node.states.prediction_precision) - - return nothing -end - -""" - update_node_value_prediction_error!(node::AbstractInputNode) - -Update the value prediction error of a single input node. -""" -function update_node_value_prediction_error!(node::AbstractInputNode) - - #Calculate value prediction error - node.states.value_prediction_error = calculate_value_prediction_error(node) - push!(node.history.value_prediction_error, node.states.value_prediction_error) - - return nothing -end - -""" - update_node_volatility_prediction_error!(node::AbstractInputNode) - -Update the value prediction error of a single input node. 
-""" -function update_node_volatility_prediction_error!(node::AbstractInputNode) - - #Calculate volatility prediction error, only if there are volatility parents - if length(node.volatility_parents) > 0 - node.states.volatility_prediction_error = - calculate_volatility_prediction_error(node) - push!( - node.history.volatility_prediction_error, - node.states.volatility_prediction_error, - ) - end - - return nothing -end - - -############################################## -######## Binary Input Node Variations ######## -############################################## -""" - update_node_prediction!(node::BinaryInputNode) - -There is no prediction update for binary input nodes, as the prediction precision is constant. -""" -function update_node_prediction!(node::BinaryInputNode) - return nothing -end - -""" - update_node_value_prediction_error!(node::BinaryInputNode) - -Update the value prediction error of a single binary input node. -""" -function update_node_value_prediction_error!(node::BinaryInputNode) - - #Calculate value prediction error - node.states.value_prediction_error = calculate_value_prediction_error(node) - push!(node.history.value_prediction_error, node.states.value_prediction_error) - - return nothing -end - -""" - update_node_volatility_prediction_error!(node::BinaryInputNode) - -There is no volatility prediction error update for binary input nodes. -""" -function update_node_volatility_prediction_error!(node::BinaryInputNode) - return nothing -end - - -################################################### -######## Categorical Input Node Variations ######## -################################################### -""" - update_node_prediction!(node::CategoricalInputNode) - -There is no prediction update for categorical input nodes, as the prediction precision is constant. 
-""" -function update_node_prediction!(node::CategoricalInputNode) - return nothing -end - -""" - update_node_value_prediction_error!(node::CategoricalInputNode) - - There is no value prediction error update for categorical input nodes. -""" -function update_node_value_prediction_error!(node::CategoricalInputNode) - return nothing -end - -""" - update_node_volatility_prediction_error!(node::CategoricalInputNode) - -There is no volatility prediction error update for categorical input nodes. -""" -function update_node_volatility_prediction_error!(node::CategoricalInputNode) - return nothing -end diff --git a/src/utils/get_prediction.jl b/src/utils/get_prediction.jl index 7bdd30a..9151d40 100644 --- a/src/utils/get_prediction.jl +++ b/src/utils/get_prediction.jl @@ -20,7 +20,7 @@ function get_prediction(hgf::HGF, node_name::String) end ### Single node functions ### -function get_prediction(node::AbstractNode) +function get_prediction(node::ContinuousStateNode) #Save old states old_states = (; @@ -54,7 +54,8 @@ function get_prediction(node::AbstractNode) node.states.prediction_mean = old_states.prediction_mean node.states.predicted_volatility = old_states.predicted_volatility node.states.prediction_precision = old_states.prediction_precision - node.states.volatility_weighted_prediction_precision = old_states.volatility_weighted_prediction_precision + node.states.volatility_weighted_prediction_precision = + old_states.volatility_weighted_prediction_precision return new_states end @@ -104,29 +105,26 @@ function get_prediction(node::CategoricalStateNode) end -function get_prediction(node::AbstractInputNode) +function get_prediction(node::ContinuousInputNode) #Save old states old_states = (; - predicted_volatility = node.states.predicted_volatility, + prediction_mean = node.states.prediction_mean, prediction_precision = node.states.prediction_precision, ) - #Update prediction volatility - node.states.predicted_volatility = calculate_predicted_volatility(node) - #Update 
prediction precision + node.states.prediction_mean = calculate_prediction_mean(node) node.states.prediction_precision = calculate_prediction_precision(node) #Save new states new_states = (; - predicted_volatility = node.states.predicted_volatility, + prediction_mean = node.states.prediction_mean, prediction_precision = node.states.prediction_precision, - volatility_weighted_prediction_precision = 1.0, ) #Change states back to the old states - node.states.predicted_volatility = old_states.predicted_volatility + node.states.prediction_mean = old_states.prediction_mean node.states.prediction_precision = old_states.prediction_precision return new_states diff --git a/src/utils/get_surprise.jl b/src/utils/get_surprise.jl index 9de92f9..acdcf25 100644 --- a/src/utils/get_surprise.jl +++ b/src/utils/get_surprise.jl @@ -63,7 +63,7 @@ function get_surprise(node::ContinuousInputNode) #Sum the predictions of the vaue parents parents_prediction_mean = 0 - for parent in node.value_parents + for parent in node.edges.observation_parents parents_prediction_mean += parent.states.prediction_mean end @@ -87,7 +87,7 @@ function get_surprise(node::BinaryInputNode) #Sum the predictions of the vaue parents parents_prediction_mean = 0 - for parent in node.value_parents + for parent in node.edges.observation_parents parents_prediction_mean += parent.states.prediction_mean end @@ -138,7 +138,7 @@ Calculate the surprise of a categorical input node on seeing the last input. 
function get_surprise(node::CategoricalInputNode) #Get value parent - parent = node.value_parents[1] + parent = node.edges.observation_parents[1] #Get surprise surprise = sum(-log.(exp.(log.(parent.states.prediction) .* parent.states.posterior))) diff --git a/test/Project.toml b/test/Project.toml index 3b91bab..2c924d0 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -3,6 +3,7 @@ ActionModels = "320cf53b-cc3b-4b34-9a10-0ecb113566a3" CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +Glob = "c27321d9-0574-5035-807b-f59d2c89b15c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" diff --git a/test/quicktests.jl b/test/quicktests.jl index 7afa9e3..73f1da7 100644 --- a/test/quicktests.jl +++ b/test/quicktests.jl @@ -1,7 +1,7 @@ using HierarchicalGaussianFiltering -hgf = premade_hgf("continuous_2level", verbose = true) +hgf = premade_hgf("continuous_2level", verbose = false) -update_hgf!(hgf, [0.01, 0.02, 0.06]) +give_inputs!(hgf, [0.01, 0.02, 0.06]) get_states(hgf) diff --git a/test/runtests.jl b/test/runtests.jl index 524fd5d..5cbf2f5 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,58 +1,57 @@ -using ActionModels using HierarchicalGaussianFiltering using Test +using Glob -@testset "Unit tests" begin +#Get the root path of the package +hgf_path = dirname(dirname(pathof(HierarchicalGaussianFiltering))) - # Test the quick tests that are used as pre-commit tests - include("quicktests.jl") +@testset "All tests" begin - # Test that the HGF gives canonical outputs - include("test_canonical.jl") + @testset "Unit tests" begin - # Test initialization - include("test_initialization.jl") + #Get the path to the testing folder + test_path = hgf_path * "/test/" - # Test premade HGF models - include("test_premade_hgf.jl") + @testset "quick tests" begin + # Test 
the quick tests that are used as pre-commit tests + include(test_path * "quicktests.jl") + end - # Test shared parameters - include("test_shared_parameters.jl") + # List the julia filenames in the testsuite + filenames = glob("*.jl", test_path * "testsuite") - # Test premade action models - include("test_premade_agent.jl") - - #Run fitting tests - include("test_fit_model.jl") + # For each file + for filename in filenames + #Run it + include(filename) + end + end - # Test update_hgf - # Test node_update - # Test action models - # Test update equations + @testset "Documentation tests" begin -end + #Set up path for the documentation folder + documentation_path = hgf_path * "/docs/src/" + @testset "Sourcefiles" begin -@testset "Documentation" begin + # List the julia filenames in the documentation source files folder + filenames = glob("*.jl", documentation_path * "/Julia_src_files") - #Set up path for the documentation folder - hgf_path = dirname(dirname(pathof(HierarchicalGaussianFiltering))) - documentation_path = hgf_path * "/docs/src/" + for filename in filenames + include(filename) + end + end - @testset "tutorials" begin - - #Get path for the tutorials subfolder - tutorials_path = documentation_path * "tutorials/" - - #Classic tutorials - include(tutorials_path * "classic_binary.jl") - include(tutorials_path * "classic_usdchf.jl") - end + @testset "Tutorials" begin - @testset "sourcefiles" begin - - #Get path for the tutorials subfolder - sourcefiles_path = documentation_path * "Julia_src_files/" + # List the julia filenames in the tutorials folder + filenames = glob("*.jl", documentation_path * "/tutorials") + # For each file + for filename in filenames + #Run it + include(filename) + end + end end end diff --git a/test/data/canonical_binary3level.csv b/test/testsuite/data/canonical_binary3level.csv similarity index 100% rename from test/data/canonical_binary3level.csv rename to test/testsuite/data/canonical_binary3level.csv diff --git 
a/test/data/canonical_continuous2level_inputs.dat b/test/testsuite/data/canonical_continuous2level_inputs.dat similarity index 100% rename from test/data/canonical_continuous2level_inputs.dat rename to test/testsuite/data/canonical_continuous2level_inputs.dat diff --git a/test/data/canonical_continuous2level_states.csv b/test/testsuite/data/canonical_continuous2level_states.csv similarity index 100% rename from test/data/canonical_continuous2level_states.csv rename to test/testsuite/data/canonical_continuous2level_states.csv diff --git a/test/test_canonical.jl b/test/testsuite/test_canonical.jl similarity index 95% rename from test/test_canonical.jl rename to test/testsuite/test_canonical.jl index c8e73cc..e97ac49 100644 --- a/test/test_canonical.jl +++ b/test/testsuite/test_canonical.jl @@ -10,7 +10,7 @@ using Plots #Get the path for the HGF superfolder hgf_path = dirname(dirname(pathof(HierarchicalGaussianFiltering))) #Add the path to the data files - data_path = hgf_path * "/test/data/" + data_path = hgf_path * "/test/testsuite/data/" @testset "Canonical Continuous 2level" begin @@ -36,15 +36,14 @@ using Plots ### Set up HGF ### #set parameters and starting states parameters = Dict( - ("u", "x", "value_coupling") => 1.0, - ("x", "xvol", "volatility_coupling") => 1.0, ("u", "input_noise") => log(1e-4), ("x", "volatility") => -13, - ("xvol", "volatility") => -2, ("x", "initial_mean") => 1.04, ("x", "initial_precision") => 1e4, + ("xvol", "volatility") => -2, ("xvol", "initial_mean") => 1.0, ("xvol", "initial_precision") => 10, + ("x", "xvol", "coupling_strength") => 1.0, "update_type" => ClassicUpdate(), ) @@ -97,8 +96,8 @@ using Plots ("u", "input_precision") => Inf, ("xprob", "volatility") => -2.5, ("xvol", "volatility") => -6.0, - ("xbin", "xprob", "value_coupling") => 1.0, - ("xprob", "xvol", "volatility_coupling") => 1.0, + ("xbin", "xprob", "coupling_strength") => 1.0, + ("xprob", "xvol", "coupling_strength") => 1.0, ("xprob", "initial_mean") => 0.0, 
("xprob", "initial_precision") => 1.0, ("xvol", "initial_mean") => 1.0, diff --git a/test/test_fit_model.jl b/test/testsuite/test_fit_model.jl similarity index 95% rename from test/test_fit_model.jl rename to test/testsuite/test_fit_model.jl index a46c6ad..bb65660 100644 --- a/test/test_fit_model.jl +++ b/test/testsuite/test_fit_model.jl @@ -25,8 +25,7 @@ using Turing ("x", "initial_mean") => 100, ("xvol", "initial_mean") => 1.0, ("xvol", "initial_precision") => 600, - ("u", "x", "value_coupling") => 1.0, - ("x", "xvol", "volatility_coupling") => 1.0, + ("x", "xvol", "coupling_strength") => 1.0, "gaussian_action_precision" => 100, ("xvol", "volatility") => -4, ("u", "input_noise") => 4, @@ -100,8 +99,8 @@ using Turing ("xprob", "initial_precision") => exp(2.306), ("xvol", "initial_mean") => 3.2189, ("xvol", "initial_precision") => exp(-1.0986), - ("xbin", "xprob", "value_coupling") => 1.0, - ("xprob", "xvol", "volatility_coupling") => 1.0, + ("xbin", "xprob", "coupling_strength") => 1.0, + ("xprob", "xvol", "coupling_strength") => 1.0, ("xvol", "volatility") => -3, ) diff --git a/test/test_initialization.jl b/test/testsuite/test_initialization.jl similarity index 74% rename from test/test_initialization.jl rename to test/testsuite/test_initialization.jl index bd21fc8..044b0a6 100644 --- a/test/test_initialization.jl +++ b/test/testsuite/test_initialization.jl @@ -10,7 +10,7 @@ using Test "input_precision" => Inf, "initial_mean" => 1, "initial_precision" => 2, - "value_coupling" => 1, + "coupling_strength" => 1, "drift" => 2, ) @@ -28,20 +28,19 @@ using Test "volatility" => 2, "initial_mean" => 4, "initial_precision" => 3, - "drift" => 5 + "drift" => 5, ), ] #List of child-parent relations - edges = [ - Dict("child" => "u1", "value_parents" => "x1"), - Dict("child" => "u2", "value_parents" => "x2", "volatility_parents" => "x3"), - Dict( - "child" => "x1", - "value_parents" => ("x3", 2), - "volatility_parents" => [("x4", 2), "x5"], - ), - ] + edges = Dict( + ("u1", 
"x1") => ObservationCoupling(), + ("u2", "x2") => ObservationCoupling(), + ("u2", "x3") => NoiseCoupling(), + ("x1", "x3") => DriftCoupling(2), + ("x1", "x4") => VolatilityCoupling(2), + ("x1", "x5") => VolatilityCoupling(), + ) #Initialize an HGF test_hgf = init_hgf( @@ -65,12 +64,9 @@ using Test @test test_hgf.state_nodes["x1"].parameters.drift == 2 @test test_hgf.state_nodes["x5"].parameters.drift == 5 - @test test_hgf.input_nodes["u1"].parameters.value_coupling["x1"] == 1 - @test test_hgf.input_nodes["u2"].parameters.value_coupling["x2"] == 1 - @test test_hgf.input_nodes["u2"].parameters.volatility_coupling["x3"] == 1 - @test test_hgf.state_nodes["x1"].parameters.value_coupling["x3"] == 2 - @test test_hgf.state_nodes["x1"].parameters.volatility_coupling["x4"] == 2 - @test test_hgf.state_nodes["x1"].parameters.volatility_coupling["x5"] == 1 + @test test_hgf.state_nodes["x1"].parameters.coupling_strengths["x3"] == 2 + @test test_hgf.state_nodes["x1"].parameters.coupling_strengths["x4"] == 2 + @test test_hgf.state_nodes["x1"].parameters.coupling_strengths["x5"] == 1 @test test_hgf.state_nodes["x1"].states.posterior_mean == 1 @test test_hgf.state_nodes["x1"].states.posterior_precision == 2 diff --git a/test/test_premade_agent.jl b/test/testsuite/test_premade_agent.jl similarity index 100% rename from test/test_premade_agent.jl rename to test/testsuite/test_premade_agent.jl diff --git a/test/test_premade_hgf.jl b/test/testsuite/test_premade_hgf.jl similarity index 100% rename from test/test_premade_hgf.jl rename to test/testsuite/test_premade_hgf.jl diff --git a/test/test_shared_parameters.jl b/test/testsuite/test_shared_parameters.jl similarity index 92% rename from test/test_shared_parameters.jl rename to test/testsuite/test_shared_parameters.jl index 9231e01..1551d05 100644 --- a/test/test_shared_parameters.jl +++ b/test/testsuite/test_shared_parameters.jl @@ -25,11 +25,7 @@ state_nodes = [ ] #List of child-parent relations -edges = [ - Dict("child" => "u", 
"value_parents" => ("x1", 1)), - Dict("child" => "x1", "volatility_parents" => ("x2", 1)), -] - +edges = Dict(("u", "x1") => ObservationCoupling(), ("x1", "x2") => VolatilityCoupling(1)) # one shared parameter shared_parameters_1 = From 8f5a2b9f6b699f167456e7dd587f42a9097d8ef2 Mon Sep 17 00:00:00 2001 From: Peter Thestrup Waade Date: Tue, 28 Nov 2023 18:02:39 +0100 Subject: [PATCH 04/16] changed node input syntax and update type --- docs/src/Julia_src_files/building_an_HGF.jl | 36 +-- src/ActionModels_variations/utils/reset.jl | 1 - src/HierarchicalGaussianFiltering.jl | 3 + src/create_hgf/hgf_structs.jl | 55 +++- src/create_hgf/init_hgf.jl | 277 +++++------------- .../premade_agents/premade_gaussian_action.jl | 1 + .../premade_multiple_actions.jl | 1 + .../premade_predict_category.jl | 1 + .../premade_agents/premade_sigmoid_action.jl | 1 + .../premade_agents/premade_softmax_action.jl | 1 + .../premade_hgfs/premade_JGET.jl | 87 +++--- .../premade_hgfs/premade_binary_2level.jl | 38 +-- .../premade_hgfs/premade_binary_3level.jl | 55 ++-- .../premade_categorical_3level.jl | 55 ++-- .../premade_categorical_transitions_3level.jl | 58 ++-- .../premade_hgfs/premade_continuous_2level.jl | 52 ++-- .../node_updates/binary_input_node.jl | 26 -- test/testsuite/test_initialization.jl | 48 ++- test/testsuite/test_shared_parameters.jl | 38 +-- 19 files changed, 313 insertions(+), 521 deletions(-) create mode 100644 src/premade_models/premade_agents/premade_gaussian_action.jl create mode 100644 src/premade_models/premade_agents/premade_multiple_actions.jl create mode 100644 src/premade_models/premade_agents/premade_predict_category.jl create mode 100644 src/premade_models/premade_agents/premade_sigmoid_action.jl create mode 100644 src/premade_models/premade_agents/premade_softmax_action.jl diff --git a/docs/src/Julia_src_files/building_an_HGF.jl b/docs/src/Julia_src_files/building_an_HGF.jl index 7d1a2a5..d1b35be 100644 --- a/docs/src/Julia_src_files/building_an_HGF.jl +++ 
b/docs/src/Julia_src_files/building_an_HGF.jl @@ -17,12 +17,16 @@ # We can recall from the HGF nodes, that a binary input node's parameters are category means and input precision. We will set category means to [0,1] and the input precision to Inf. -input_nodes = Dict( - "name" => "Input_node", - "type" => "binary", - "category_means" => [0, 1], - "input_precision" => Inf, -); +nodes = [ + BinaryInput("Input_node"), + BinaryState("binary_state_node"), + ContinuousState( + name = "continuous_state_node", + volatility = -2, + initial_mean = 0, + initial_precision = 1, + ), +] # ## Defining State Nodes @@ -30,19 +34,6 @@ input_nodes = Dict( # The continuous state node have evolution rate, initial mean and initial precision parameters which we specify as well. -state_nodes = [ - ## Configuring the first binary state node - Dict("name" => "binary_state_node", "type" => "binary"), - ## Configuring the continuous state node - Dict( - "name" => "continuous_state_node", - "type" => "continuous", - "volatility" => -2, - "initial_mean" => 0, - "initial_precision" => 1, - ), -]; - # ## Defining Edges # When defining the edges we start by sepcifying which node the perspective is from. So, when we specify the edges we start by specifying what the child in the relation is. 
@@ -59,12 +50,7 @@ edges = Dict( using HierarchicalGaussianFiltering using ActionModels -Binary_2_level_hgf = init_hgf( - input_nodes = input_nodes, - state_nodes = state_nodes, - edges = edges, - verbose = false, -); +Binary_2_level_hgf = init_hgf(nodes = nodes, edges = edges, verbose = false); # We can access the states in our HGF: get_states(Binary_2_level_hgf) #- diff --git a/src/ActionModels_variations/utils/reset.jl b/src/ActionModels_variations/utils/reset.jl index a877670..65ca1c3 100644 --- a/src/ActionModels_variations/utils/reset.jl +++ b/src/ActionModels_variations/utils/reset.jl @@ -69,7 +69,6 @@ end function reset_state!(node::BinaryInputNode) node.states.input_value = missing - node.states.value_prediction_error .= missing return nothing end diff --git a/src/HierarchicalGaussianFiltering.jl b/src/HierarchicalGaussianFiltering.jl index f43a36a..d5969c4 100644 --- a/src/HierarchicalGaussianFiltering.jl +++ b/src/HierarchicalGaussianFiltering.jl @@ -14,6 +14,9 @@ export premade_agent, plot_trajectory! export get_history, get_parameters, get_states, set_parameters!, reset!, give_inputs! 
export EnhancedUpdate, ClassicUpdate +export NodeDefaults +export ContinuousState, + ContinuousInput, BinaryState, BinaryInput, CategoricalState, CategoricalInput export DriftCoupling, ObservationCoupling, CategoryCoupling, diff --git a/src/create_hgf/hgf_structs.jl b/src/create_hgf/hgf_structs.jl index 59fa088..915cbed 100644 --- a/src/create_hgf/hgf_structs.jl +++ b/src/create_hgf/hgf_structs.jl @@ -17,6 +17,11 @@ abstract type AbstractBinaryInputNode <: AbstractInputNode end abstract type AbstractCategoricalStateNode <: AbstractStateNode end abstract type AbstractCategoricalInputNode <: AbstractInputNode end +#Abstract type for node information +abstract type AbstractNodeInfo end +abstract type AbstractInputNodeInfo <: AbstractNodeInfo end +abstract type AbstractStateNodeInfo <: AbstractNodeInfo end + ################################## ######## HGF update types ######## ################################## @@ -78,6 +83,51 @@ Base.@kwdef mutable struct HGF shared_parameters::Dict = Dict() end +################################## +######## HGF Info Structs ######## +################################## +Base.@kwdef struct NodeDefaults + input_noise::Real = -2 + volatility::Real = -2 + drift::Real = 0 + autoregression_target::Real = 0 + autoregression_strength::Real = 0 + initial_mean::Real = 0 + initial_precision::Real = 1 + coupling_strength::Real = 1 + update_type::HGFUpdateType = EnhancedUpdate() +end + +Base.@kwdef mutable struct ContinuousState <: AbstractStateNodeInfo + name::String + volatility::Union{Real,Nothing} = nothing + drift::Union{Real,Nothing} = nothing + autoregression_target::Union{Real,Nothing} = nothing + autoregression_strength::Union{Real,Nothing} = nothing + initial_mean::Union{Real,Nothing} = nothing + initial_precision::Union{Real,Nothing} = nothing +end + +Base.@kwdef mutable struct ContinuousInput <: AbstractInputNodeInfo + name::String + input_noise::Union{Real,Nothing} = nothing +end + +Base.@kwdef mutable struct BinaryState <: 
AbstractStateNodeInfo + name::String +end + +Base.@kwdef mutable struct BinaryInput <: AbstractInputNodeInfo + name::String +end + +Base.@kwdef mutable struct CategoricalState <: AbstractStateNodeInfo + name::String +end + +Base.@kwdef mutable struct CategoricalInput <: AbstractInputNodeInfo + name::String +end @@ -203,7 +253,6 @@ Base.@kwdef mutable struct ContinuousInputNode <: AbstractContinuousInputNode history::ContinuousInputNodeHistory = ContinuousInputNodeHistory() end - ################################### ######## Binary State Node ######## ################################### @@ -261,7 +310,6 @@ end - ################################### ######## Binary Input Node ######## ################################### @@ -283,7 +331,6 @@ Configuration of states of binary input node """ Base.@kwdef mutable struct BinaryInputNodeState input_value::Union{Real,Missing} = missing - value_prediction_error::Vector{Union{Real,Missing}} = [missing, missing] end """ @@ -291,7 +338,6 @@ Configuration of history of binary input node """ Base.@kwdef mutable struct BinaryInputNodeHistory input_value::Vector{Union{Real,Missing}} = [missing] - value_prediction_error::Vector{Vector{Union{Real,Missing}}} = [[missing, missing]] end """ @@ -358,7 +404,6 @@ end - ######################################## ######## Categorical Input Node ######## ######################################## diff --git a/src/create_hgf/init_hgf.jl b/src/create_hgf/init_hgf.jl index f05c2f7..71cf55e 100644 --- a/src/create_hgf/init_hgf.jl +++ b/src/create_hgf/init_hgf.jl @@ -73,103 +73,55 @@ hgf = init_hgf( ``` """ function init_hgf(; - input_nodes::Union{String,Dict,Vector}, - state_nodes::Union{String,Dict,Vector}, + nodes::Vector{<:AbstractNodeInfo}, edges::Dict{Tuple{String,String},<:CouplingType}, + node_defaults::NodeDefaults = NodeDefaults(), shared_parameters::Dict = Dict(), - node_defaults::Dict = Dict(), - update_type::HGFUpdateType = EnhancedUpdate(), update_order::Union{Nothing,Vector{String}} = 
nothing, verbose::Bool = true, ) - ### Defaults ### - preset_node_defaults = Dict( - "type" => "continuous", - "volatility" => -2, - "drift" => 0, - "autoregression_target" => 0, - "autoregression_strength" => 0, - "initial_mean" => 0, - "initial_precision" => 1, - "coupling_strength" => 1, - "category_means" => [0, 1], - "input_precision" => Inf, - "input_noise" => -2, - ) - - #If verbose - if verbose - #If some node defaults have been specified - if length(node_defaults) > 0 - #Warn the user of unspecified defaults and errors - warn_premade_defaults( - preset_node_defaults, - node_defaults, - "in the node defaults,", - ) - end - end - - #Use presets wherever node defaults were not given - node_defaults = merge(preset_node_defaults, node_defaults) - ### Initialize nodes ### #Initialize empty dictionaries for storing nodes all_nodes_dict = Dict{String,AbstractNode}() input_nodes_dict = Dict{String,AbstractInputNode}() state_nodes_dict = Dict{String,AbstractStateNode}() - - ## Input nodes ## - - #If user has only specified a single node and not in a vector - if input_nodes isa Dict - #Put it in a vector - input_nodes = [input_nodes] - end + input_nodes_inputted_order = Vector{String}() + state_nodes_inputted_order = Vector{String}() #For each specified input node - for node_info in input_nodes - - #If only the node's name was specified as a string - if node_info isa String - #Make it into a dictionary - node_info = Dict("name" => node_info) + for node_info in nodes + #For each field in the node info + for fieldname in fieldnames(typeof(node_info)) + #If it hasn't been specified by the user + if isnothing(getfield(node_info, fieldname)) + #Set the node_defaults' value instead + setfield!(node_info, fieldname, getfield(node_defaults, fieldname)) + end end #Create the node - node = init_node("input_node", node_defaults, node_info) - - #Add it to the dictionary - all_nodes_dict[node.name] = node - input_nodes_dict[node.name] = node - end - - ## State nodes ## - #If user 
has only specified a single node and not in a vector - if state_nodes isa Dict - #Put it in a vector - state_nodes = [state_nodes] - end + node = init_node(node_info) - #For each specified state node - for node_info in state_nodes + #Add it to the large dictionary + all_nodes_dict[node_info.name] = node - #If only the node's name was specified as a string - if node_info isa String - #Make it into a named tuple - node_info = Dict("name" => node_info) + #If it is an input node + if node isa AbstractInputNode + #Add it to the input node dict + input_nodes_dict[node_info.name] = node + #Store its name in the inputted order + push!(input_nodes_inputted_order, node_info.name) + + #If it is a state node + elseif node isa AbstractStateNode + #Add it to the state node dict + state_nodes_dict[node_info.name] = node + #Store its name in the inputted order + push!(state_nodes_inputted_order, node_info.name) end - - #Create the node - node = init_node("state_node", node_defaults, node_info) - - #Add it to the dictionary - all_nodes_dict[node.name] = node - state_nodes_dict[node.name] = node end - ### Set up edges ### #For each specified edge for (node_names, coupling_type) in edges @@ -182,7 +134,7 @@ function init_hgf(; parent_node = all_nodes_dict[parent_name] #Create the edge - init_edge!(child_node, parent_node, coupling_type, update_type, node_defaults) + init_edge!(child_node, parent_node, coupling_type, node_defaults) end ## Determine Update order ## @@ -195,34 +147,8 @@ function init_hgf(; @warn "No update order specified. 
Using the order in which nodes were inputted" end - #Initialize empty vector for storing the update order - update_order = [] - - #For each input node, in the order inputted - for node_info in input_nodes - - #If only the node's name was specified as a string - if node_info isa String - #Make it into a named tuple - node_info = Dict("name" => node_info) - end - - #Add the node to the vector - push!(update_order, all_nodes_dict[node_info["name"]]) - end - - #For each state node, in the order inputted - for node_info in state_nodes - - #If only the node's name was specified as a string - if node_info isa String - #Make it into a named tuple - node_info = Dict("name" => node_info) - end - - #Add the node to the vector - push!(update_order, all_nodes_dict[node_info["name"]]) - end + #Use the order that the nodes were specified in + update_order = append!(input_nodes_inputted_order, state_nodes_inputted_order) end ## Order nodes ## @@ -230,7 +156,10 @@ function init_hgf(; ordered_nodes = OrderedNodes() #For each node, in the specified update order - for node in update_order + for node_name in update_order + + #Extract node + node = all_nodes_dict[node_name] #Have a field for all nodes push!(ordered_nodes.all_nodes, node) @@ -331,114 +260,51 @@ end -""" - init_node(input_or_state_node, node_defaults, node_info) -Function for creating a node, given specifications -""" -function init_node(input_or_state_node, node_defaults, node_info) - - #Get parameters and starting state. Specific node settings supercede node defaults, which again supercede the function's defaults. 
- parameters = merge(node_defaults, node_info) - - #For an input node - if input_or_state_node == "input_node" - #If it is continuous - if parameters["type"] == "continuous" - #Initialize it - node = ContinuousInputNode( - name = parameters["name"], - parameters = ContinuousInputNodeParameters( - input_noise = parameters["input_noise"], - ), - states = ContinuousInputNodeState(), - ) - #If it is binary - elseif parameters["type"] == "binary" - #Initialize it - node = BinaryInputNode( - name = parameters["name"], - parameters = BinaryInputNodeParameters( - category_means = parameters["category_means"], - input_precision = parameters["input_precision"], - ), - states = BinaryInputNodeState(), - ) - #If it is categorical - elseif parameters["type"] == "categorical" - #Initialize it - node = CategoricalInputNode( - name = parameters["name"], - parameters = CategoricalInputNodeParameters(), - states = CategoricalInputNodeState(), - ) - else - #The node has been misspecified. Throw an error - throw( - ArgumentError("the type of node $parameters['name'] has been misspecified"), - ) - end - - #For a state node - elseif input_or_state_node == "state_node" - #If it is continuous - if parameters["type"] == "continuous" - #Initialize it - node = ContinuousStateNode( - name = parameters["name"], - #Set parameters - parameters = ContinuousStateNodeParameters( - volatility = parameters["volatility"], - drift = parameters["drift"], - initial_mean = parameters["initial_mean"], - initial_precision = parameters["initial_precision"], - autoregression_target = parameters["autoregression_target"], - autoregression_strength = parameters["autoregression_strength"], - ), - #Set states - states = ContinuousStateNodeState( - posterior_mean = parameters["initial_mean"], - posterior_precision = parameters["initial_precision"], - ), - ) +function init_node(node_info::ContinuousState) + ContinuousStateNode( + name = node_info.name, + parameters = ContinuousStateNodeParameters( + volatility = 
node_info.volatility, + drift = node_info.drift, + initial_mean = node_info.initial_mean, + initial_precision = node_info.initial_precision, + autoregression_target = node_info.autoregression_target, + autoregression_strength = node_info.autoregression_strength, + ), + ) +end - #If it is binary - elseif parameters["type"] == "binary" - #Initialize it - node = BinaryStateNode( - name = parameters["name"], - parameters = BinaryStateNodeParameters(), - states = BinaryStateNodeState(), - ) +function init_node(node_info::ContinuousInput) + ContinuousInputNode( + name = node_info.name, + parameters = ContinuousInputNodeParameters(input_noise = node_info.input_noise), + ) +end - #If it categorical - elseif parameters["type"] == "categorical" +function init_node(node_info::BinaryState) + BinaryStateNode(name = node_info.name) +end - #Initialize it - node = CategoricalStateNode( - name = parameters["name"], - parameters = CategoricalStateNodeParameters(), - states = CategoricalStateNodeState(), - ) +function init_node(node_info::BinaryInput) + BinaryInputNode(name = node_info.name) +end - else - #The node has been misspecified. 
Throw an error - throw( - ArgumentError("the type of node $parameters['name'] has been misspecified"), - ) - end - end +function init_node(node_info::CategoricalState) + CategoricalStateNode(name = node_info.name) +end - return node +function init_node(node_info::CategoricalInput) + CategoricalInputNode(name = node_info.name) end + ### Function for initializing and edge ### function init_edge!( child_node::AbstractNode, parent_node::AbstractStateNode, coupling_type::CouplingType, - update_type::HGFUpdateType, - node_defaults::Dict, + node_defaults::NodeDefaults, ) #Get correct field for storing parents @@ -478,7 +344,7 @@ function init_edge!( #If the user has not specified a coupling strength if isnothing(coupling_type.strength) #Use the defaults coupling strength - coupling_strength = node_defaults["coupling_strength"] + coupling_strength = node_defaults.coupling_strength #Otherwise else @@ -490,9 +356,10 @@ function init_edge!( child_node.parameters.coupling_strengths[parent_node.name] = coupling_strength end - #If the enhanced HGF update is used, and if it is a precision coupling (volatility or noise) - if update_type isa EnhancedUpdate && coupling_type isa PrecisionCoupling + + #If the enhanced HGF update is the defaults, and if it is a precision coupling (volatility or noise) + if node_defaults.update_type isa EnhancedUpdate && coupling_type isa PrecisionCoupling #Set the node to use the enhanced HGF update - parent_node.update_type = update_type + parent_node.update_type = node_defaults.update_type end end diff --git a/src/premade_models/premade_agents/premade_gaussian_action.jl b/src/premade_models/premade_agents/premade_gaussian_action.jl new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/src/premade_models/premade_agents/premade_gaussian_action.jl @@ -0,0 +1 @@ + diff --git a/src/premade_models/premade_agents/premade_multiple_actions.jl b/src/premade_models/premade_agents/premade_multiple_actions.jl new file mode 100644 index 0000000..8b13789 
--- /dev/null +++ b/src/premade_models/premade_agents/premade_multiple_actions.jl @@ -0,0 +1 @@ + diff --git a/src/premade_models/premade_agents/premade_predict_category.jl b/src/premade_models/premade_agents/premade_predict_category.jl new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/src/premade_models/premade_agents/premade_predict_category.jl @@ -0,0 +1 @@ + diff --git a/src/premade_models/premade_agents/premade_sigmoid_action.jl b/src/premade_models/premade_agents/premade_sigmoid_action.jl new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/src/premade_models/premade_agents/premade_sigmoid_action.jl @@ -0,0 +1 @@ + diff --git a/src/premade_models/premade_agents/premade_softmax_action.jl b/src/premade_models/premade_agents/premade_softmax_action.jl new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/src/premade_models/premade_agents/premade_softmax_action.jl @@ -0,0 +1 @@ + diff --git a/src/premade_models/premade_hgfs/premade_JGET.jl b/src/premade_models/premade_hgfs/premade_JGET.jl index 1685f49..ec5c74f 100644 --- a/src/premade_models/premade_hgfs/premade_JGET.jl +++ b/src/premade_models/premade_hgfs/premade_JGET.jl @@ -65,56 +65,44 @@ function premade_JGET(config::Dict; verbose::Bool = true) #Merge to overwrite defaults config = merge(spec_defaults, config) - - #List of input nodes to create - input_nodes = Dict( - "name" => "u", - "type" => "continuous", - "input_noise" => config[("u", "input_noise")], - ) - - #List of state nodes to create - state_nodes = [ - Dict( - "name" => "x", - "type" => "continuous", - "volatility" => config[("x", "volatility")], - "drift" => config[("x", "drift")], - "autoregression_target" => config[("x", "autoregression_target")], - "autoregression_strength" => config[("x", "autoregression_strength")], - "initial_mean" => config[("x", "initial_mean")], - "initial_precision" => config[("x", "initial_precision")], + #List of nodes + nodes = [ + ContinuousInput(name = "u", input_noise = 
config[("u", "input_noise")]), + ContinuousState( + name = "x", + volatility = config[("x", "volatility")], + drift = config[("x", "drift")], + autoregression_target = config[("x", "autoregression_target")], + autoregression_strength = config[("x", "autoregression_strength")], + initial_mean = config[("x", "initial_mean")], + initial_precision = config[("x", "initial_precision")], ), - Dict( - "name" => "xvol", - "type" => "continuous", - "volatility" => config[("xvol", "volatility")], - "drift" => config[("xvol", "drift")], - "autoregression_target" => config[("xvol", "autoregression_target")], - "autoregression_strength" => config[("xvol", "autoregression_strength")], - "initial_mean" => config[("xvol", "initial_mean")], - "initial_precision" => config[("xvol", "initial_precision")], + ContinuousState( + name = "xvol", + volatility = config[("xvol", "volatility")], + drift = config[("xvol", "drift")], + autoregression_target = config[("xvol", "autoregression_target")], + autoregression_strength = config[("xvol", "autoregression_strength")], + initial_mean = config[("xvol", "initial_mean")], + initial_precision = config[("xvol", "initial_precision")], ), - Dict( - "name" => "xnoise", - "type" => "continuous", - "volatility" => config[("xnoise", "volatility")], - "drift" => config[("xnoise", "drift")], - "autoregression_target" => config[("xnoise", "autoregression_target")], - "autoregression_strength" => config[("xnoise", "autoregression_strength")], - "initial_mean" => config[("xnoise", "initial_precision")], - "initial_precision" => config[("xnoise", "initial_precision")], + ContinuousState( + name = "xnoise", + volatility = config[("xnoise", "volatility")], + drift = config[("xnoise", "drift")], + autoregression_target = config[("xnoise", "autoregression_target")], + autoregression_strength = config[("xnoise", "autoregression_strength")], + initial_mean = config[("xnoise", "initial_mean")], + initial_precision = config[("xnoise", "initial_precision")], ), - 
Dict( - "name" => "xnoise_vol", - "type" => "continuous", - "volatility" => config[("xnoise_vol", "volatility")], - "drift" => config[("xnoise_vol", "drift")], - "autoregression_target" => config[("xnoise_vol", "autoregression_target")], - "autoregression_strength" => - config[("xnoise_vol", "autoregression_strength")], - "initial_mean" => config[("xnoise_vol", "initial_mean")], - "initial_precision" => config[("xnoise_vol", "initial_precision")], + ContinuousState( + name = "xnoise_vol", + volatility = config[("xnoise_vol", "volatility")], + drift = config[("xnoise_vol", "drift")], + autoregression_target = config[("xnoise_vol", "autoregression_target")], + autoregression_strength = config[("xnoise_vol", "autoregression_strength")], + initial_mean = config[("xnoise_vol", "initial_mean")], + initial_precision = config[("xnoise_vol", "initial_precision")], ), ] @@ -129,10 +117,9 @@ function premade_JGET(config::Dict; verbose::Bool = true) #Initialize the HGF init_hgf( - input_nodes = input_nodes, - state_nodes = state_nodes, + nodes = nodes, edges = edges, verbose = false, - update_type = config["update_type"], + node_defaults = NodeDefaults(update_type = config["update_type"]), ) end diff --git a/src/premade_models/premade_hgfs/premade_binary_2level.jl b/src/premade_models/premade_hgfs/premade_binary_2level.jl index 1ac8f37..779d2dc 100644 --- a/src/premade_models/premade_hgfs/premade_binary_2level.jl +++ b/src/premade_models/premade_hgfs/premade_binary_2level.jl @@ -31,27 +31,18 @@ function premade_binary_2level(config::Dict; verbose::Bool = true) #Merge to overwrite defaults config = merge(spec_defaults, config) - - #List of input nodes to create - input_nodes = Dict( - "name" => "u", - "type" => "binary", - "category_means" => config[("u", "category_means")], - "input_precision" => config[("u", "input_precision")], - ) - - #List of state nodes to create - state_nodes = [ - Dict("name" => "xbin", "type" => "binary"), - Dict( - "name" => "xprob", - "type" => 
"continuous", - "volatility" => config[("xprob", "volatility")], - "drift" => config[("xprob", "drift")], - "autoregression_target" => config[("xprob", "autoregression_target")], - "autoregression_strength" => config[("xprob", "autoregression_strength")], - "initial_mean" => config[("xprob", "initial_mean")], - "initial_precision" => config[("xprob", "initial_precision")], + #List of nodes + nodes = [ + BinaryInput("u"), + BinaryState("xbin"), + ContinuousState( + name = "xprob", + volatility = config[("xprob", "volatility")], + drift = config[("xprob", "drift")], + autoregression_target = config[("xprob", "autoregression_target")], + autoregression_strength = config[("xprob", "autoregression_strength")], + initial_mean = config[("xprob", "initial_mean")], + initial_precision = config[("xprob", "initial_precision")], ), ] @@ -64,10 +55,9 @@ function premade_binary_2level(config::Dict; verbose::Bool = true) #Initialize the HGF init_hgf( - input_nodes = input_nodes, - state_nodes = state_nodes, + nodes = nodes, edges = edges, verbose = false, - update_type = config["update_type"], + node_defaults = NodeDefaults(update_type = config["update_type"]), ) end diff --git a/src/premade_models/premade_hgfs/premade_binary_3level.jl b/src/premade_models/premade_hgfs/premade_binary_3level.jl index 90cb496..24cb05c 100644 --- a/src/premade_models/premade_hgfs/premade_binary_3level.jl +++ b/src/premade_models/premade_hgfs/premade_binary_3level.jl @@ -55,37 +55,27 @@ function premade_binary_3level(config::Dict; verbose::Bool = true) #Merge to overwrite defaults config = merge(defaults, config) - - #List of input nodes to create - input_nodes = Dict( - "name" => "u", - "type" => "binary", - "category_means" => config[("u", "category_means")], - "input_precision" => config[("u", "input_precision")], - ) - - #List of state nodes to create - state_nodes = [ - Dict("name" => "xbin", "type" => "binary"), - Dict( - "name" => "xprob", - "type" => "continuous", - "volatility" => 
config[("xprob", "volatility")], - "drift" => config[("xprob", "drift")], - "autoregression_target" => config[("xprob", "autoregression_target")], - "autoregression_strength" => config[("xprob", "autoregression_strength")], - "initial_mean" => config[("xprob", "initial_mean")], - "initial_precision" => config[("xprob", "initial_precision")], + #List of nodes + nodes = [ + BinaryInput("u"), + BinaryState("xbin"), + ContinuousState( + name = "xprob", + volatility = config[("xprob", "volatility")], + drift = config[("xprob", "drift")], + autoregression_target = config[("xprob", "autoregression_target")], + autoregression_strength = config[("xprob", "autoregression_strength")], + initial_mean = config[("xprob", "initial_mean")], + initial_precision = config[("xprob", "initial_precision")], ), - Dict( - "name" => "xvol", - "type" => "continuous", - "volatility" => config[("xvol", "volatility")], - "drift" => config[("xvol", "drift")], - "autoregression_target" => config[("xvol", "autoregression_target")], - "autoregression_strength" => config[("xvol", "autoregression_strength")], - "initial_mean" => config[("xvol", "initial_mean")], - "initial_precision" => config[("xvol", "initial_precision")], + ContinuousState( + name = "xvol", + volatility = config[("xvol", "volatility")], + drift = config[("xvol", "drift")], + autoregression_target = config[("xvol", "autoregression_target")], + autoregression_strength = config[("xvol", "autoregression_strength")], + initial_mean = config[("xvol", "initial_mean")], + initial_precision = config[("xvol", "initial_precision")], ), ] @@ -100,10 +90,9 @@ function premade_binary_3level(config::Dict; verbose::Bool = true) #Initialize the HGF init_hgf( - input_nodes = input_nodes, - state_nodes = state_nodes, + nodes = nodes, edges = edges, verbose = false, - update_type = config["update_type"], + node_defaults = NodeDefaults(update_type = config["update_type"]), ) end diff --git 
a/src/premade_models/premade_hgfs/premade_categorical_3level.jl b/src/premade_models/premade_hgfs/premade_categorical_3level.jl index 6f58919..54d693c 100644 --- a/src/premade_models/premade_hgfs/premade_categorical_3level.jl +++ b/src/premade_models/premade_hgfs/premade_categorical_3level.jl @@ -73,30 +73,26 @@ function premade_categorical_3level(config::Dict; verbose::Bool = true) push!(probability_parent_names, "xprob_" * string(category_number)) end - ##List of input nodes - input_nodes = Dict("name" => "u", "type" => "categorical") - - ##List of state nodes - state_nodes = [Dict{String,Any}("name" => "xcat", "type" => "categorical")] + #Initialize list of nodes + nodes = [CategoricalInput("u"), CategoricalState("xcat")] #Add category node binary parents for node_name in category_parent_names - push!(state_nodes, Dict("name" => node_name, "type" => "binary")) + push!(nodes, BinaryState(node_name)) end #Add binary node continuous parents for node_name in probability_parent_names push!( - state_nodes, - Dict( - "name" => node_name, - "type" => "continuous", - "initial_mean" => config[("xprob", "initial_mean")], - "initial_precision" => config[("xprob", "initial_precision")], - "volatility" => config[("xprob", "volatility")], - "drift" => config[("xprob", "drift")], - "autoregression_target" => config[("xprob", "autoregression_target")], - "autoregression_strength" => config[("xprob", "autoregression_strength")], + nodes, + ContinuousState( + name = node_name, + volatility = config[("xprob", "volatility")], + drift = config[("xprob", "drift")], + autoregression_target = config[("xprob", "autoregression_target")], + autoregression_strength = config[("xprob", "autoregression_strength")], + initial_mean = config[("xprob", "initial_mean")], + initial_precision = config[("xprob", "initial_precision")], ), ) #Add the derived parameter name to derived parameters vector @@ -116,21 +112,19 @@ function premade_categorical_3level(config::Dict; verbose::Bool = true) #Add 
volatility parent push!( - state_nodes, - Dict( - "name" => "xvol", - "type" => "continuous", - "volatility" => config[("xvol", "volatility")], - "drift" => config[("xvol", "drift")], - "autoregression_target" => config[("xvol", "autoregression_target")], - "autoregression_strength" => config[("xvol", "autoregression_strength")], - "initial_mean" => config[("xvol", "initial_mean")], - "initial_precision" => config[("xvol", "initial_precision")], + nodes, + ContinuousState( + name = "xvol", + volatility = config[("xvol", "volatility")], + drift = config[("xvol", "drift")], + autoregression_target = config[("xvol", "autoregression_target")], + autoregression_strength = config[("xvol", "autoregression_strength")], + initial_mean = config[("xvol", "initial_mean")], + initial_precision = config[("xvol", "initial_precision")], ), ) - - ##List of child-parent relations + ##List of edges #Set the input node coupling edges = Dict{Tuple{String,String},CouplingType}(("u", "xcat") => ObservationCoupling()) @@ -197,11 +191,10 @@ function premade_categorical_3level(config::Dict; verbose::Bool = true) #Initialize the HGF init_hgf( - input_nodes = input_nodes, - state_nodes = state_nodes, + nodes = nodes, edges = edges, shared_parameters = shared_parameters, verbose = false, - update_type = config["update_type"], + node_defaults = NodeDefaults(update_type = config["update_type"]), ) end diff --git a/src/premade_models/premade_hgfs/premade_categorical_transitions_3level.jl b/src/premade_models/premade_hgfs/premade_categorical_transitions_3level.jl index b4c5788..0e74804 100644 --- a/src/premade_models/premade_hgfs/premade_categorical_transitions_3level.jl +++ b/src/premade_models/premade_hgfs/premade_categorical_transitions_3level.jl @@ -95,46 +95,40 @@ function premade_categorical_3level_state_transitions(config::Dict; verbose::Boo end end - ##Create input nodes - #Initialize list - input_nodes = Vector{Dict}() + ##List of nodes + nodes = Vector{AbstractNodeInfo}() #For each 
categorical input node for node_name in input_node_names #Add it to the list - push!(input_nodes, Dict("name" => node_name, "type" => "categorical")) + push!(nodes, CategoricalInput(node_name)) end - ##Create state nodes - #Initialize list - state_nodes = Vector{Dict}() - #For each categorical state node for node_name in observation_parent_names #Add it to the list - push!(state_nodes, Dict("name" => node_name, "type" => "categorical")) + push!(nodes, CategoricalState(node_name)) end #For each categorical node binary parent for node_name in category_parent_names #Add it to the list - push!(state_nodes, Dict("name" => node_name, "type" => "binary")) + push!(nodes, BinaryState(node_name)) end #For each binary node continuous parent for node_name in probability_parent_names #Add it to the list, with parameter settings from the config push!( - state_nodes, - Dict( - "name" => node_name, - "type" => "continuous", - "initial_mean" => config[("xprob", "initial_mean")], - "initial_precision" => config[("xprob", "initial_precision")], - "volatility" => config[("xprob", "volatility")], - "drift" => config[("xprob", "drift")], - "autoregression_target" => config[("xprob", "autoregression_target")], - "autoregression_strength" => config[("xprob", "autoregression_strength")], + nodes, + ContinuousState( + name = node_name, + volatility = config[("xprob", "volatility")], + drift = config[("xprob", "drift")], + autoregression_target = config[("xprob", "autoregression_target")], + autoregression_strength = config[("xprob", "autoregression_strength")], + initial_mean = config[("xprob", "initial_mean")], + initial_precision = config[("xprob", "initial_precision")], ), ) #Add the derived parameter name to derived parameters vector @@ -155,16 +149,15 @@ function premade_categorical_3level_state_transitions(config::Dict; verbose::Boo #Add the shared volatility parent of the continuous nodes push!( - state_nodes, - Dict( - "name" => "xvol", - "type" => "continuous", - "volatility" => 
config[("xvol", "volatility")], - "drift" => config[("xvol", "drift")], - "autoregression_target" => config[("xvol", "autoregression_target")], - "autoregression_strength" => config[("xvol", "autoregression_strength")], - "initial_mean" => config[("xvol", "initial_mean")], - "initial_precision" => config[("xvol", "initial_precision")], + nodes, + ContinuousState( + name = "xvol", + volatility = config[("xvol", "volatility")], + drift = config[("xvol", "drift")], + autoregression_target = config[("xvol", "autoregression_target")], + autoregression_strength = config[("xvol", "autoregression_strength")], + initial_mean = config[("xvol", "initial_mean")], + initial_precision = config[("xvol", "initial_precision")], ), ) @@ -262,11 +255,10 @@ function premade_categorical_3level_state_transitions(config::Dict; verbose::Boo #Initialize the HGF init_hgf( - input_nodes = input_nodes, - state_nodes = state_nodes, + nodes = nodes, edges = edges, shared_parameters = shared_parameters, verbose = false, - update_type = config["update_type"], + node_defaults = NodeDefaults(update_type = config["update_type"]), ) end diff --git a/src/premade_models/premade_hgfs/premade_continuous_2level.jl b/src/premade_models/premade_hgfs/premade_continuous_2level.jl index 389c93c..7eb03c3 100644 --- a/src/premade_models/premade_hgfs/premade_continuous_2level.jl +++ b/src/premade_models/premade_hgfs/premade_continuous_2level.jl @@ -44,35 +44,26 @@ function premade_continuous_2level(config::Dict; verbose::Bool = true) #Merge to overwrite defaults config = merge(spec_defaults, config) - - #List of input nodes to create - input_nodes = Dict( - "name" => "u", - "type" => "continuous", - "input_noise" => config[("u", "input_noise")], - ) - - #List of state nodes to create - state_nodes = [ - Dict( - "name" => "x", - "type" => "continuous", - "volatility" => config[("x", "volatility")], - "drift" => config[("x", "drift")], - "autoregression_target" => config[("x", "autoregression_target")], - 
"autoregression_strength" => config[("x", "autoregression_strength")], - "initial_mean" => config[("x", "initial_mean")], - "initial_precision" => config[("x", "initial_precision")], + #List of nodes + nodes = [ + ContinuousInput(name = "u", input_noise = config[("u", "input_noise")]), + ContinuousState( + name = "x", + volatility = config[("x", "volatility")], + drift = config[("x", "drift")], + autoregression_target = config[("x", "autoregression_target")], + autoregression_strength = config[("x", "autoregression_strength")], + initial_mean = config[("x", "initial_mean")], + initial_precision = config[("x", "initial_precision")], ), - Dict( - "name" => "xvol", - "type" => "continuous", - "volatility" => config[("xvol", "volatility")], - "drift" => config[("xvol", "drift")], - "autoregression_target" => config[("xvol", "autoregression_target")], - "autoregression_strength" => config[("xvol", "autoregression_strength")], - "initial_mean" => config[("xvol", "initial_mean")], - "initial_precision" => config[("xvol", "initial_precision")], + ContinuousState( + name = "xvol", + volatility = config[("xvol", "volatility")], + drift = config[("xvol", "drift")], + autoregression_target = config[("xvol", "autoregression_target")], + autoregression_strength = config[("xvol", "autoregression_strength")], + initial_mean = config[("xvol", "initial_mean")], + initial_precision = config[("xvol", "initial_precision")], ), ] @@ -84,10 +75,9 @@ function premade_continuous_2level(config::Dict; verbose::Bool = true) #Initialize the HGF init_hgf( - input_nodes = input_nodes, - state_nodes = state_nodes, + nodes = nodes, edges = edges, verbose = false, - update_type = config["update_type"], + node_defaults = NodeDefaults(update_type = config["update_type"]), ) end diff --git a/src/update_hgf/node_updates/binary_input_node.jl b/src/update_hgf/node_updates/binary_input_node.jl index 2c8ea7c..3e05c18 100644 --- a/src/update_hgf/node_updates/binary_input_node.jl +++ 
b/src/update_hgf/node_updates/binary_input_node.jl @@ -24,35 +24,9 @@ end Update the value prediction error of a single binary input node. """ function update_node_value_prediction_error!(node::BinaryInputNode) - - #Calculate value prediction error - node.states.value_prediction_error = calculate_value_prediction_error(node) - push!(node.history.value_prediction_error, node.states.value_prediction_error) - return nothing end -@doc raw""" - calculate_value_prediction_error(node::BinaryInputNode) - -Calculates the prediciton error of a binary input node with finite precision. - -Uses the equation -`` \delta_n= u - \sum_{j=1}^{j\;value\;parents} \hat{\mu}_{j} `` -""" -function calculate_value_prediction_error(node::BinaryInputNode) - - #For missing input - if ismissing(node.states.input_value) - #Set the prediction error to missing - value_prediction_error = [missing, missing] - else - #Substract to find the difference to each of the Gaussian means - value_prediction_error = node.parameters.category_means .- node.states.input_value - end -end - - ################################################### ######## Update precision prediction error ######## ################################################### diff --git a/test/testsuite/test_initialization.jl b/test/testsuite/test_initialization.jl index 044b0a6..706e058 100644 --- a/test/testsuite/test_initialization.jl +++ b/test/testsuite/test_initialization.jl @@ -3,32 +3,29 @@ using Test @testset "Initialization" begin #Parameter values to be used for all nodes unless other values are given - node_defaults = Dict( - "volatility" => 3, - "input_noise" => -2, - "category_means" => [0, 1], - "input_precision" => Inf, - "initial_mean" => 1, - "initial_precision" => 2, - "coupling_strength" => 1, - "drift" => 2, + node_defaults = NodeDefaults( + volatility = 3, + input_noise = -2, + initial_mean = 1, + initial_precision = 2, + coupling_strength = 1, + drift = 2, ) - #List of input nodes to create - input_nodes = [Dict("name" => 
"u1", "input_noise" => 2), "u2"] - - #List of state nodes to create - state_nodes = [ - "x1", - "x2", - "x3", - Dict("name" => "x4", "volatility" => 2), - Dict( - "name" => "x5", - "volatility" => 2, - "initial_mean" => 4, - "initial_precision" => 3, - "drift" => 5, + #List of nodes + nodes = [ + ContinuousInput(name = "u1", input_noise = 2), + ContinuousInput(name = "u2"), + ContinuousState(name = "x1"), + ContinuousState(name = "x2"), + ContinuousState(name = "x3"), + ContinuousState(name = "x4", volatility = 2), + ContinuousState( + name = "x5", + volatility = 2, + initial_mean = 4, + initial_precision = 3, + drift = 5, ), ] @@ -44,8 +41,7 @@ using Test #Initialize an HGF test_hgf = init_hgf( - input_nodes = input_nodes, - state_nodes = state_nodes, + nodes = nodes, edges = edges, node_defaults = node_defaults, verbose = false, diff --git a/test/testsuite/test_shared_parameters.jl b/test/testsuite/test_shared_parameters.jl index 1551d05..19ac42f 100644 --- a/test/testsuite/test_shared_parameters.jl +++ b/test/testsuite/test_shared_parameters.jl @@ -3,25 +3,11 @@ using Test # Test of custom HGF with shared parameters -#List of input nodes to create -input_nodes = Dict("name" => "u", "type" => "continuous", "input_noise" => 2) - -#List of state nodes to create -state_nodes = [ - Dict( - "name" => "x1", - "type" => "continuous", - "volatility" => 2, - "initial_mean" => 1, - "initial_precision" => 1, - ), - Dict( - "name" => "x2", - "type" => "continuous", - "volatility" => 2, - "initial_mean" => 1, - "initial_precision" => 1, - ), +#List of nodes +nodes = [ + ContinuousInput(name = "u", input_noise = 2), + ContinuousState(name = "x1", volatility = 2, initial_mean = 1, initial_precision = 1), + ContinuousState(name = "x2", volatility = 2, initial_mean = 1, initial_precision = 1), ] #List of child-parent relations @@ -32,12 +18,7 @@ shared_parameters_1 = Dict("volatilitys" => (9, [("x1", "volatility"), ("x2", "volatility")])) #Initialize the HGF -hgf_1 = init_hgf( - 
input_nodes = input_nodes, - state_nodes = state_nodes, - edges = edges, - shared_parameters = shared_parameters_1, -) +hgf_1 = init_hgf(nodes = nodes, edges = edges, shared_parameters = shared_parameters_1) #get shared parameter get_parameters(hgf_1) @@ -54,12 +35,7 @@ shared_parameters_2 = Dict( #Initialize the HGF -hgf_2 = init_hgf( - input_nodes = input_nodes, - state_nodes = state_nodes, - edges = edges, - shared_parameters = shared_parameters_2, -) +hgf_2 = init_hgf(nodes = nodes, edges = edges, shared_parameters = shared_parameters_2) #get all parameters get_parameters(hgf_2) From 4533a0cf6f0067cfe6e992be891c0ba1c0297c82 Mon Sep 17 00:00:00 2001 From: Peter Thestrup Waade Date: Wed, 13 Dec 2023 00:01:51 +0100 Subject: [PATCH 05/16] fixed premade_agents --- docs/src/Julia_src_files/building_an_HGF.jl | 11 +- .../src/Julia_src_files/fitting_hgf_models.jl | 4 +- docs/src/Julia_src_files/utility_functions.jl | 2 +- docs/src/tutorials/classic_JGET.jl | 2 +- docs/src/tutorials/classic_binary.jl | 4 +- docs/src/tutorials/classic_usdchf.jl | 4 +- src/HierarchicalGaussianFiltering.jl | 13 +- src/premade_models/premade_action_models.jl | 257 ------------ src/premade_models/premade_agents.jl | 376 ------------------ .../premade_agents/premade_gaussian_action.jl | 99 +++++ .../premade_multiple_actions.jl | 1 - .../premade_predict_category.jl | 106 +++++ .../premade_agents/premade_sigmoid_action.jl | 106 +++++ .../premade_agents/premade_softmax_action.jl | 105 +++++ test/runtests.jl | 10 +- test/testsuite/test_fit_model.jl | 4 +- 16 files changed, 442 insertions(+), 662 deletions(-) delete mode 100644 src/premade_models/premade_action_models.jl delete mode 100644 src/premade_models/premade_agents.jl delete mode 100644 src/premade_models/premade_agents/premade_multiple_actions.jl diff --git a/docs/src/Julia_src_files/building_an_HGF.jl b/docs/src/Julia_src_files/building_an_HGF.jl index d1b35be..8cd6451 100644 --- a/docs/src/Julia_src_files/building_an_HGF.jl +++ 
b/docs/src/Julia_src_files/building_an_HGF.jl @@ -80,7 +80,7 @@ function binary_softmax_action(agent, input) target_state = agent.settings["target_state"] ##Take out the parameter from our agent - action_precision = agent.parameters["softmax_action_precision"] + action_noise = agent.parameters["action_noise"] ##Get the specified state out of the hgf target_value = get_states(hgf, target_state) @@ -88,7 +88,7 @@ function binary_softmax_action(agent, input) ##--------------- Update step starts ----------------- ##Use sotmax to get the action probability - action_probability = 1 / (1 + exp(-action_precision * target_value)) + action_probability = 1 / (1 + exp(action_noise * target_value)) ##---------------- Update step end ------------------ ##If the action probability is not between 0 and 1 @@ -121,7 +121,7 @@ action_model = binary_softmax_action; # The parameter of the agent is just softmax action precision. We set this value to 1 -parameters = Dict("softmax_action_precision" => 1); +parameters = Dict("action_noise" => 1); # The states of the agent are empty, but the states from the HGF will be accessible. @@ -129,10 +129,7 @@ states = Dict() # In the settings we specify what our target state is. We want it to be the prediction mean of our binary state node. 
-settings = Dict( - "hgf_actions" => "softmax_action", - "target_state" => ("binary_state_node", "prediction_mean"), -); +settings = Dict("target_state" => ("binary_state_node", "prediction_mean")); ## Let's initialize our agent agent = init_agent( diff --git a/docs/src/Julia_src_files/fitting_hgf_models.jl b/docs/src/Julia_src_files/fitting_hgf_models.jl index 8e4b0b9..9f89554 100644 --- a/docs/src/Julia_src_files/fitting_hgf_models.jl +++ b/docs/src/Julia_src_files/fitting_hgf_models.jl @@ -60,7 +60,7 @@ hgf_parameters = Dict( hgf = premade_hgf("binary_3level", hgf_parameters, verbose = false) # Create an agent -agent_parameters = Dict("sigmoid_action_precision" => 5); +agent_parameters = Dict("action_noise" => 0.2); agent = premade_agent("hgf_unit_square_sigmoid_action", hgf, agent_parameters, verbose = false); @@ -86,7 +86,7 @@ plot_trajectory!(agent, ("xbin", "prediction")) # Set fixed parameters. We choose to fit the evolution rate of the xprob node. fixed_parameters = Dict( - "sigmoid_action_precision" => 5, + "action_noise" => 0.2, ("u", "category_means") => Real[0.0, 1.0], ("u", "input_precision") => Inf, ("xprob", "initial_mean") => 0, diff --git a/docs/src/Julia_src_files/utility_functions.jl b/docs/src/Julia_src_files/utility_functions.jl index 438b2d0..434f650 100644 --- a/docs/src/Julia_src_files/utility_functions.jl +++ b/docs/src/Julia_src_files/utility_functions.jl @@ -60,7 +60,7 @@ get_states( # you can set parameters before you initialize your agent, you can set them after and change them when you wish to. # Let's try an initialize a new agent with parameters. We start by choosing the premade unit square sigmoid action agent whose parameter is sigmoid action precision. 
-agent_parameter = Dict("sigmoid_action_precision" => 3) +agent_parameter = Dict("action_noise" => 0.3) #We also specify our HGF and custom parameter settings: diff --git a/docs/src/tutorials/classic_JGET.jl b/docs/src/tutorials/classic_JGET.jl index 700313d..1d1fb27 100644 --- a/docs/src/tutorials/classic_JGET.jl +++ b/docs/src/tutorials/classic_JGET.jl @@ -19,7 +19,7 @@ hgf = premade_hgf("JGET", verbose = false) agent = premade_agent("hgf_gaussian_action", hgf) #Set parameters parameters = Dict( - "gaussian_action_precision" => 1, + "action_noise" => 1, ("u", "input_noise") => 0, ("x", "initial_mean") => first(inputs) + 2, ("x", "initial_precision") => 0.001, diff --git a/docs/src/tutorials/classic_binary.jl b/docs/src/tutorials/classic_binary.jl index 80e39a2..d4c7f65 100644 --- a/docs/src/tutorials/classic_binary.jl +++ b/docs/src/tutorials/classic_binary.jl @@ -36,7 +36,7 @@ hgf_parameters = Dict( hgf = premade_hgf("binary_3level", hgf_parameters, verbose = false); # Create an agent -agent_parameters = Dict("sigmoid_action_precision" => 5); +agent_parameters = Dict("action_noise" => 0.2); agent = premade_agent("hgf_unit_square_sigmoid_action", hgf, agent_parameters, verbose = false); @@ -55,7 +55,7 @@ plot_trajectory(agent, ("xvol", "posterior")) # Set fixed parameters fixed_parameters = Dict( - "sigmoid_action_precision" => 5, + "action_noise" => 0.2, ("u", "category_means") => Real[0.0, 1.0], ("u", "input_precision") => Inf, ("xprob", "initial_mean") => 0, diff --git a/docs/src/tutorials/classic_usdchf.jl b/docs/src/tutorials/classic_usdchf.jl index 5d38109..25157c2 100644 --- a/docs/src/tutorials/classic_usdchf.jl +++ b/docs/src/tutorials/classic_usdchf.jl @@ -36,7 +36,7 @@ parameters = Dict( ("x", "initial_precision") => 1 / (0.0001), ("xvol", "initial_mean") => 1.0, ("xvol", "initial_precision") => 1 / 0.1, - "gaussian_action_precision" => 100, + "action_noise" => 0.01, ); set_parameters!(agent, parameters) @@ -85,7 +85,7 @@ fixed_parameters = Dict( ("x", 
"initial_precision") => 2000, ("xvol", "initial_mean") => 1.0, ("xvol", "initial_precision") => 600.0, - "gaussian_action_precision" => 100, + "action_noise" => 0.01, ); param_priors = Dict( diff --git a/src/HierarchicalGaussianFiltering.jl b/src/HierarchicalGaussianFiltering.jl index d5969c4..cce12b3 100644 --- a/src/HierarchicalGaussianFiltering.jl +++ b/src/HierarchicalGaussianFiltering.jl @@ -7,11 +7,7 @@ using ActionModels, Distributions, RecipesBase export init_node, init_hgf, premade_hgf, check_hgf, check_node, update_hgf! export get_prediction, get_surprise, hgf_multiple_actions export premade_agent, - init_agent, - multiple_actions, - plot_predictive_simulation, - plot_trajectory, - plot_trajectory! + init_agent, plot_predictive_simulation, plot_trajectory, plot_trajectory! export get_history, get_parameters, get_states, set_parameters!, reset!, give_inputs! export EnhancedUpdate, ClassicUpdate export NodeDefaults @@ -65,8 +61,11 @@ include("create_hgf/create_premade_hgf.jl") #Plotting functions #Functions for premade agents -include("premade_models/premade_action_models.jl") -include("premade_models/premade_agents.jl") +include("premade_models/premade_agents/premade_gaussian_action.jl") +include("premade_models/premade_agents/premade_predict_category.jl") +include("premade_models/premade_agents/premade_sigmoid_action.jl") +include("premade_models/premade_agents/premade_softmax_action.jl") + include("premade_models/premade_hgfs/premade_binary_2level.jl") include("premade_models/premade_hgfs/premade_binary_3level.jl") include("premade_models/premade_hgfs/premade_categorical_3level.jl") diff --git a/src/premade_models/premade_action_models.jl b/src/premade_models/premade_action_models.jl deleted file mode 100644 index eb5c330..0000000 --- a/src/premade_models/premade_action_models.jl +++ /dev/null @@ -1,257 +0,0 @@ -###### Multiple Actions ###### -""" - update_hgf_multiple_actions(agent::Agent, input) - -Action model that first updates the HGF, and then runs 
multiple action models. -""" -function update_hgf_multiple_actions(agent::Agent, input) - - #Update the hgf - hgf = agent.substruct - update_hgf!(hgf, input) - - #Extract vector of action models - action_models = agent.settings["hgf_actions"] - - #Initialize vector for action distributions - action_distributions = [] - - #Do each action model separately - for action_model in action_models - #And append them to the vector of action distributions - push!(action_distributions, action_model(agent, input)) - end - - return action_distributions -end - - -###### Gaussian Action ###### -""" - update_hgf_gaussian_action(agent::Agent, input) - -Action model that first updates the HGF, and then reports a given HGF state with Gaussian noise. - -In addition to the HGF substruct, the following must be present in the agent: -Parameters: "gaussian_action_precision" -Settings: "target_state" -""" -function update_hgf_gaussian_action(agent::Agent, input) - - #Update the HGF - update_hgf!(agent.substruct, input) - - #Run the action model - action_distribution = hgf_gaussian_action(agent, input) - - return action_distribution -end - -""" - hgf_gaussian_action(agent::Agent, input) - -Action model which reports a given HGF state with Gaussian noise. - -In addition to the HGF substruct, the following must be present in the agent: -Parameters: "gaussian_action_precision" -Settings: "target_state" -""" -function hgf_gaussian_action(agent::Agent, input) - - #Get out hgf, settings and parameters - hgf = agent.substruct - target_state = agent.settings["target_state"] - action_precision = agent.parameters["gaussian_action_precision"] - - #Get the specified state - action_mean = get_states(hgf, target_state) - - #If the gaussian mean becomes a NaN - if isnan(action_mean) - #Throw an error that will reject samples when fitted - throw( - RejectParameters( - "With these parameters and inputs, the mean of the gaussian action became $action_mean, which is invalid. 
Try other parameter settings", - ), - ) - end - - #Create normal distribution with mean of the target value and a standard deviation from parameters - distribution = Distributions.Normal(action_mean, 1 / action_precision) - - #Return the action distribution - return distribution -end - - -###### Softmax Action ###### -""" - update_hgf_softmax_action(agent::Agent, input) - -Action model that first updates the HGF, and then passes a state from the HGF through a softmax to give a binary action. - -In addition to the HGF substruct, the following must be present in the agent: -Parameters: "softmax_action_precision" -Settings: "target_state" -""" -function update_hgf_binary_softmax_action(agent::Agent, input) - - #Update the HGF - update_hgf!(agent.substruct, input) - - #Run the action model - action_distribution = hgf_binary_softmax_action(agent, input) - - return action_distribution -end - -""" - hgf_binary_softmax_action(agent, input) - -Action model which gives a binary action. The action probability is the softmax of a specified state of a node. - -In addition to the HGF substruct, the following must be present in the agent: -Parameters: "softmax_action_precision" -Settings: "target_state" -""" -function hgf_binary_softmax_action(agent::Agent, input) - - #Get out HGF, settings and parameters - hgf = agent.substruct - target_state = agent.settings["target_state"] - action_precision = agent.parameters["softmax_action_precision"] - - #Get the specified state - target_value = get_states(hgf, target_state) - - #Use sotmax to get the action probability - action_probability = 1 / (1 + exp(-action_precision * target_value)) - - #If the action probability is not between 0 and 1 - if !(0 <= action_probability <= 1) - #Throw an error that will reject samples when fitted - throw( - RejectParameters( - "With these parameters and inputs, the action probability became $action_probability, which should be between 0 and 1. 
Try other parameter settings", - ), - ) - end - - #Create Bernoulli normal distribution with mean of the target value and a standard deviation from parameters - distribution = Distributions.Bernoulli(action_probability) - - #Return the action distribution - return distribution -end - -###### Unit Square Sigmoid Action ###### -""" - update_hgf_unit_square_sigmoid_action(agent::Agent, input) - -Action model that first updates the HGF, and then passes a state from the HGF through a unit square sigmoid transform to give a binary action. - -In addition to the HGF substruct, the following must be present in the agent: -Parameters: "sigmoid_action_precision" -Settings: "target_state" -""" -function update_hgf_unit_square_sigmoid_action(agent::Agent, input) - - #Update the HGF - update_hgf!(agent.substruct, input) - - #Run the action model - action_distribution = hgf_unit_square_sigmoid_action(agent, input) - - return action_distribution -end - -""" - unit_square_sigmoid_action(agent, input) - -Action model which gives a binary action. The action probability is the unit square sigmoid of a specified state of a node. 
- -In addition to the HGF substruct, the following must be present in the agent: -Parameters: "sigmoid_action_precision" -Settings: "target_state" -""" -function hgf_unit_square_sigmoid_action(agent::Agent, input) - - #Get out settings and parameters - target_state = agent.settings["target_state"] - action_precision = agent.parameters["sigmoid_action_precision"] - - #Get out the HGF - hgf = agent.substruct - - #Get the specified state - target_value = get_states(hgf, target_state) - - #Use softmax to get the action probability - action_probability = - target_value^action_precision / - (target_value^action_precision + (1 - target_value)^action_precision) - - #If the action probability is not between 0 and 1 - if !(0 <= action_probability <= 1) - #Throw an error that will reject samples when fitted - throw( - RejectParameters( - "With these parameters and inputs, the action probability became $action_probability, which should be between 0 and 1. Try other parameter settings", - ), - ) - end - - #Create Bernoulli normal distribution with mean of the target value and a standard deviation from parameters - distribution = Distributions.Bernoulli(action_probability) - - #Return the action distribution - return distribution -end - - - -###### Categorical Prediction Action ###### -""" - update_hgf_predict_category_action(agent::Agent, input) - -Action model that first updates the HGF, and then returns a categorical prediction of the input. The HGF used must be a categorical HGF. 
- -In addition to the HGF substruct, the following must be present in the agent: -Settings: "target_categorical_node" -""" -function update_hgf_predict_category_action(agent::Agent, input) - - #Update the HGF - update_hgf!(agent.substruct, input) - - #Run the action model - action_distribution = hgf_predict_category_action(agent, input) - - return action_distribution -end - -""" - hgf_predict_category_action(agent::Agent, input) - -Action model which gives a categorical prediction of the input, based on an HGF. The HGF used must be a categorical HGF. - -In addition to the HGF substruct, the following must be present in the agent: -Settings: "target_categorical_node" -""" -function hgf_predict_category_action(agent::Agent, input) - - #Get out settings and parameters - target_node = agent.settings["target_categorical_node"] - - #Get out the HGF - hgf = agent.substruct - - #Get the specified state - predicted_category_probabilities = get_states(hgf, (target_node, "prediction")) - - #Create Bernoulli normal distribution with mean of the target value and a standard deviation from parameters - distribution = Distributions.Categorical(predicted_category_probabilities) - - #Return the action distribution - return distribution -end diff --git a/src/premade_models/premade_agents.jl b/src/premade_models/premade_agents.jl deleted file mode 100644 index ffb2d8d..0000000 --- a/src/premade_models/premade_agents.jl +++ /dev/null @@ -1,376 +0,0 @@ -""" - premade_hgf_multiple_actions(config::Dict) - -Create an agent suitable for multiple aciton models that can depend on an HGF substruct. The used action models are specified as a vector. 
- -# Config defaults: - - "HGF": "continuous_2level" - - "hgf_actions": ["gaussian_action", "softmax_action", "unit_square_sigmoid_action"] -""" -function premade_hgf_multiple_actions(config::Dict) - - ## Combine defaults and user settings - - #Default parameters and settings - defaults = Dict( - "HGF" => "continuous_2level", - "hgf_actions" => - ["gaussian_action", "softmax_action", "unit_square_sigmoid_action"], - ) - - #If there is no HGF in the user-set parameters - if !("HGF" in keys(config)) - HGF_name = defaults["HGF"] - #Make a default HGF - config["HGF"] = premade_hgf(HGF_name) - #And warn them - @warn "an HGF was not set by the user. Using the default: a $HGF_name HGF with default settings" - end - - #Warn the user about used defaults and misspecified keys - warn_premade_defaults(defaults, config) - - #Merge to overwrite defaults - config = merge(defaults, config) - - - ## Create agent - #Set the action model - action_model = update_hgf_multiple_actions - - #Set the HGF - hgf = config["HGF"] - - #Set parameters - parameters = Dict() - #Set states - states = Dict() - #Set settings - settings = Dict("hgf_actions" => config["hgf_actions"]) - - - ## Add parameters for each action type - for action_string in config["hgf_actions"] - - #Parameters for the gaussian action - if action_string == "gaussian_action" - - #Action precision parameter - if "gaussian_action_precision" in keys(config) - parameters["gaussian_action_precision"] = - config["gaussian_action_precision"] - else - default_action_precision = 1 - parameters["gaussian_action_precision"] = default_action_precision - @warn "parameter gaussian_action_precision was not set by the user. 
Using the default: $default_action_precision" - end - - #Target state setting - if "gaussian_target_state" in keys(config) - settings["gaussian_target_state"] = config["gaussian_target_state"] - else - default_target_state = ("x", "posterior_mean") - settings["gaussian_target_state"] = default_target_state - @warn "setting gaussian_target_state was not set by the user. Using the default: $default_target_state" - end - - #Parameters for the softmax action - elseif action_string == "softmax_action" - - #Action precision parameter - if "softmax_action_precision" in keys(config) - parameters["softmax_action_precision"] = config["softmax_action_precision"] - else - default_action_precision = 1 - parameters["softmax_action_precision"] = default_action_precision - @warn "parameter softmax_action_precision was not set by the user. Using the default: $default_action_precision" - end - - #Target state setting - if "softmax_target_state" in keys(config) - settings["softmax_target_state"] = config["softmax_target_state"] - else - default_target_state = ("xbin", "prediction_mean") - settings["softmax_target_state"] = default_target_state - @warn "setting softmax_target_state was not set by the user. Using the default: $default_target_state" - end - - #Parameters for the unit square sigmoid action - elseif action_string == "unit_square_sigmoid_action" - - #Action precision parameter - if "sigmoid_action_precision" in keys(config) - parameters["sigmoid_action_precision"] = config["sigmoid_action_precision"] - else - default_action_precision = 1 - parameters["sigmoid_action_precision"] = default_action_precision - @warn "parameter sigmoid_action_precision was not set by the user. 
Using the default: $default_action_precision" - end - - #Target state setting - if "sigmoid_target_state" in keys(config) - settings["sigmoid_target_state"] = config["sigmoid_target_state"] - else - default_target_state = ("xbin", "prediction_mean") - settings["sigmoid_target_state"] = default_target_state - @warn "setting sigmoid_target_state was not set by the user. Using the default: $default_target_state" - end - - else - throw( - ArgumentError( - "$action_string is not a valid action type. Valid action types are: gaussian_action, softmax_action, unit_square_sigmoid_action", - ), - ) - end - end - - #Create the agent - return init_agent( - action_model; - substruct = hgf, - parameters = parameters, - states = states, - settings = settings, - ) -end - - -""" - premade_hgf_gaussian(config::Dict) - -Create an agent suitable for the HGF Gaussian action model. - -# Config defaults: - - "HGF": "continuous_2level" - - "gaussian_action_precision": 1 - - "target_state": ("x", "posterior_mean") -""" -function premade_hgf_gaussian(config::Dict) - - ## Combine defaults and user settings - - #Default parameters and settings - defaults = Dict( - "gaussian_action_precision" => 1, - "target_state" => ("x", "posterior_mean"), - "HGF" => "continuous_2level", - ) - - #If there is no HGF in the user-set parameters - if !("HGF" in keys(config)) - HGF_name = defaults["HGF"] - #Make a default HGF - config["HGF"] = premade_hgf(HGF_name) - #And warn them - @warn "an HGF was not set by the user. 
Using the default: a $HGF_name HGF with default settings" - end - - #Warn the user about used defaults and misspecified keys - warn_premade_defaults(defaults, config) - - #Merge to overwrite defaults - config = merge(defaults, config) - - - ## Create agent - #Set the action model - action_model = update_hgf_gaussian_action - - #Set the HGF - hgf = config["HGF"] - - #Set parameters - parameters = Dict("gaussian_action_precision" => config["gaussian_action_precision"]) - #Set states - states = Dict() - #Set settings - settings = Dict("target_state" => config["target_state"]) - - #Create the agent - return init_agent( - action_model; - substruct = hgf, - parameters = parameters, - states = states, - settings = settings, - ) -end - -""" - premade_hgf_binary_softmax(config::Dict) - -Create an agent suitable for the HGF binary softmax model. - -# Config defaults: - - "HGF": "binary_3level" - - "softmax_action_precision": 1 - - "target_state": ("xbin", "prediction_mean") -""" -function premade_hgf_binary_softmax(config::Dict) - - ## Combine defaults and user settings - - #Default parameters and settings - defaults = Dict( - "softmax_action_precision" => 1, - "target_state" => ("xbin", "prediction_mean"), - "HGF" => "binary_3level", - ) - - #If there is no HGF in the user-set parameters - if !("HGF" in keys(config)) - HGF_name = defaults["HGF"] - #Make a default HGF - config["HGF"] = premade_hgf(HGF_name) - #And warn them - @warn "an HGF was not set by the user. 
Using the default: a $HGF_name HGF with default settings" - end - - #Warn the user about used defaults and misspecified keys - warn_premade_defaults(defaults, config) - - #Merge to overwrite defaults - config = merge(defaults, config) - - - ## Create agent - #Set the action model - action_model = update_hgf_binary_softmax_action - - #Set the HGF - hgf = config["HGF"] - - #Set parameters - parameters = Dict("softmax_action_precision" => config["softmax_action_precision"]) - #Set states - states = Dict() - #Set settings - settings = Dict("target_state" => config["target_state"]) - - #Create the agent - return init_agent( - action_model, - substruct = hgf, - parameters = parameters, - states = states, - settings = settings, - ) -end - -""" - premade_hgf_binary_softmax(config::Dict) - -Create an agent suitable for the HGF unit square sigmoid model. - -# Config defaults: - - "HGF": "binary_3level" - - "sigmoid_action_precision": 1 - - "target_state": ("xbin", "prediction_mean") -""" -function premade_hgf_unit_square_sigmoid(config::Dict) - - ## Combine defaults and user settings - - #Default parameters and settings - defaults = Dict( - "sigmoid_action_precision" => 1, - "target_state" => ("xbin", "prediction_mean"), - "HGF" => "binary_3level", - ) - - #If there is no HGF in the user-set parameters - if !("HGF" in keys(config)) - HGF_name = defaults["HGF"] - #Make a default HGF - config["HGF"] = premade_hgf(HGF_name) - #And warn them - @warn "an HGF was not set by the user. 
Using the default: a $HGF_name HGF with default settings" - end - - #Warn the user about used defaults and misspecified keys - warn_premade_defaults(defaults, config) - - #Merge to overwrite defaults - config = merge(defaults, config) - - - ## Create agent - #Set the action model - action_model = update_hgf_unit_square_sigmoid_action - - #Set the HGF - hgf = config["HGF"] - - #Set parameters - parameters = Dict("sigmoid_action_precision" => config["sigmoid_action_precision"]) - #Set states - states = Dict() - #Set settings - settings = Dict("target_state" => config["target_state"]) - - #Create the agent - return init_agent( - action_model, - substruct = hgf, - parameters = parameters, - states = states, - settings = settings, - ) -end - -""" - premade_hgf_predict_category(config::Dict) - -Create an agent suitable for the HGF predict category model. - -# Config defaults: - - "HGF": "categorical_3level" - - "target_categorical_node": "xcat" -""" -function premade_hgf_predict_category(config::Dict) - - ## Combine defaults and user settings - - #Default parameters and settings - defaults = Dict("target_categorical_node" => "xcat", "HGF" => "categorical_3level") - - #If there is no HGF in the user-set parameters - if !("HGF" in keys(config)) - HGF_name = defaults["HGF"] - #Make a default HGF - config["HGF"] = premade_hgf(HGF_name) - #And warn them - @warn "an HGF was not set by the user. 
Using the default: a $HGF_name HGF with default settings" - end - - #Warn the user about used defaults and misspecified keys - warn_premade_defaults(defaults, config) - - #Merge to overwrite defaults - config = merge(defaults, config) - - - ## Create agent - #Set the action model - action_model = update_hgf_predict_category_action - - #Set the HGF - hgf = config["HGF"] - - #Set parameters - parameters = Dict() - #Set states - states = Dict() - #Set settings - settings = Dict("target_categorical_node" => config["target_categorical_node"]) - - #Create the agent - return init_agent( - action_model, - substruct = hgf, - parameters = parameters, - states = states, - settings = settings, - ) -end diff --git a/src/premade_models/premade_agents/premade_gaussian_action.jl b/src/premade_models/premade_agents/premade_gaussian_action.jl index 8b13789..e5ff210 100644 --- a/src/premade_models/premade_agents/premade_gaussian_action.jl +++ b/src/premade_models/premade_agents/premade_gaussian_action.jl @@ -1 +1,100 @@ +""" + hgf_gaussian_action(agent::Agent, input) +Action model which reports a given HGF state with Gaussian noise. + +In addition to the HGF substruct, the following must be present in the agent: +Parameters: "gaussian_action_precision" +Settings: "target_state" +""" +function hgf_gaussian_action(agent::Agent, input) + + #Extract HGF, settings and parameters + hgf = agent.substruct + target_state = agent.settings["target_state"] + action_noise = agent.parameters["action_noise"] + + #Update the HGF + update_hgf!(agent.substruct, input) + + #Extract specified belief state + action_mean = get_states(hgf, target_state) + + #If the gaussian mean becomes a NaN + if isnan(action_mean) + #Throw an error that will reject samples when fitted + throw( + RejectParameters( + "With these parameters and inputs, the mean of the gaussian action became $action_mean, which is invalid. 
Try other parameter settings", + ), + ) + end + + #Create normal distribution with mean of the target value and a standard deviation from parameters + distribution = Distributions.Normal(action_mean, action_noise) + + #Return the action distribution + return distribution +end + + +""" + premade_hgf_gaussian(config::Dict) + +Create an agent suitable for the HGF Gaussian action model. + +# Config defaults: + - "HGF": "continuous_2level" + - "gaussian_action_precision": 1 + - "target_state": ("x", "posterior_mean") +""" +function premade_hgf_gaussian(config::Dict) + + ## Combine defaults and user settings + + #Default parameters and settings + defaults = Dict( + "action_noise" => 1, + "target_state" => ("x", "posterior_mean"), + "HGF" => "continuous_2level", + ) + + #If there is no HGF in the user-set parameters + if !("HGF" in keys(config)) + HGF_name = defaults["HGF"] + #Make a default HGF + config["HGF"] = premade_hgf(HGF_name) + #And warn them + @warn "an HGF was not set by the user. Using the default: a $HGF_name HGF with default settings" + end + + #Warn the user about used defaults and misspecified keys + warn_premade_defaults(defaults, config) + + #Merge to overwrite defaults + config = merge(defaults, config) + + + ## Create agent + #Set the action model + action_model = hgf_gaussian_action + + #Set the HGF + hgf = config["HGF"] + + #Set parameters + parameters = Dict("action_noise" => config["action_noise"]) + #Set states + states = Dict() + #Set settings + settings = Dict("target_state" => config["target_state"]) + + #Create the agent + return init_agent( + action_model; + substruct = hgf, + parameters = parameters, + states = states, + settings = settings, + ) +end diff --git a/src/premade_models/premade_agents/premade_multiple_actions.jl b/src/premade_models/premade_agents/premade_multiple_actions.jl deleted file mode 100644 index 8b13789..0000000 --- a/src/premade_models/premade_agents/premade_multiple_actions.jl +++ /dev/null @@ -1 +0,0 @@ - diff --git 
a/src/premade_models/premade_agents/premade_predict_category.jl b/src/premade_models/premade_agents/premade_predict_category.jl index 8b13789..b98dde7 100644 --- a/src/premade_models/premade_agents/premade_predict_category.jl +++ b/src/premade_models/premade_agents/premade_predict_category.jl @@ -1 +1,107 @@ + +###### Categorical Prediction Action ###### +""" + update_hgf_predict_category_action(agent::Agent, input) + +Action model that first updates the HGF, and then returns a categorical prediction of the input. The HGF used must be a categorical HGF. + +In addition to the HGF substruct, the following must be present in the agent: +Settings: "target_categorical_node" +""" +function update_hgf_predict_category_action(agent::Agent, input) + + #Update the HGF + update_hgf!(agent.substruct, input) + + #Run the action model + action_distribution = hgf_predict_category_action(agent, input) + + return action_distribution +end + +""" + hgf_predict_category_action(agent::Agent, input) + +Action model which gives a categorical prediction of the input, based on an HGF. The HGF used must be a categorical HGF. + +In addition to the HGF substruct, the following must be present in the agent: +Settings: "target_categorical_node" +""" +function hgf_predict_category_action(agent::Agent, input) + + #Get out settings and parameters + target_node = agent.settings["target_categorical_node"] + + #Get out the HGF + hgf = agent.substruct + + #Get the specified state + predicted_category_probabilities = get_states(hgf, (target_node, "prediction")) + + #Create Bernoulli normal distribution with mean of the target value and a standard deviation from parameters + distribution = Distributions.Categorical(predicted_category_probabilities) + + #Return the action distribution + return distribution +end + + + + + +""" + premade_hgf_predict_category(config::Dict) + +Create an agent suitable for the HGF predict category model. 
+ +# Config defaults: + - "HGF": "categorical_3level" + - "target_categorical_node": "xcat" +""" +function premade_hgf_predict_category(config::Dict) + + ## Combine defaults and user settings + + #Default parameters and settings + defaults = Dict("target_categorical_node" => "xcat", "HGF" => "categorical_3level") + + #If there is no HGF in the user-set parameters + if !("HGF" in keys(config)) + HGF_name = defaults["HGF"] + #Make a default HGF + config["HGF"] = premade_hgf(HGF_name) + #And warn them + @warn "an HGF was not set by the user. Using the default: a $HGF_name HGF with default settings" + end + + #Warn the user about used defaults and misspecified keys + warn_premade_defaults(defaults, config) + + #Merge to overwrite defaults + config = merge(defaults, config) + + + ## Create agent + #Set the action model + action_model = update_hgf_predict_category_action + + #Set the HGF + hgf = config["HGF"] + + #Set parameters + parameters = Dict() + #Set states + states = Dict() + #Set settings + settings = Dict("target_categorical_node" => config["target_categorical_node"]) + + #Create the agent + return init_agent( + action_model, + substruct = hgf, + parameters = parameters, + states = states, + settings = settings, + ) +end diff --git a/src/premade_models/premade_agents/premade_sigmoid_action.jl b/src/premade_models/premade_agents/premade_sigmoid_action.jl index 8b13789..e702dc9 100644 --- a/src/premade_models/premade_agents/premade_sigmoid_action.jl +++ b/src/premade_models/premade_agents/premade_sigmoid_action.jl @@ -1 +1,107 @@ +""" + unit_square_sigmoid_action(agent, input) +Action model which gives a binary action. The action probability is the unit square sigmoid of a specified state of a node. 
+ +In addition to the HGF substruct, the following must be present in the agent: +Parameters: "action_precision" +Settings: "target_state" +""" +function hgf_unit_square_sigmoid_action(agent::Agent, input) + + #Extract HGF, settings and parameters + hgf = agent.substruct + target_state = agent.settings["target_state"] + action_noise = agent.parameters["action_noise"] + + #Update the HGF + update_hgf!(hgf, input) + + #Get the specified state + target_value = get_states(hgf, target_state) + + #Use softmax to get the action probability + action_precision = 1 / action_noise + + action_probability = + target_value^action_precision / + (target_value^action_precision + (1 - target_value)^action_precision) + + #If the action probability is not between 0 and 1 + if !(0 <= action_probability <= 1) + #Throw an error that will reject samples when fitted + throw( + RejectParameters( + "With these parameters and inputs, the action probability became $action_probability, which should be between 0 and 1. Try other parameter settings", + ), + ) + end + + #Create Bernoulli normal distribution with mean of the target value and a standard deviation from parameters + distribution = Distributions.Bernoulli(action_probability) + + #Return the action distribution + return distribution +end + + +""" + premade_hgf_binary_softmax(config::Dict) + +Create an agent suitable for the HGF unit square sigmoid model. 
+ +# Config defaults: + - "HGF": "binary_3level" + - "sigmoid_action_precision": 1 + - "target_state": ("xbin", "prediction_mean") +""" +function premade_hgf_unit_square_sigmoid(config::Dict) + + ## Combine defaults and user settings + + #Default parameters and settings + defaults = Dict( + "action_noise" => 1, + "target_state" => ("xbin", "prediction_mean"), + "HGF" => "binary_3level", + ) + + #If there is no HGF in the user-set parameters + if !("HGF" in keys(config)) + HGF_name = defaults["HGF"] + #Make a default HGF + config["HGF"] = premade_hgf(HGF_name) + #And warn them + @warn "an HGF was not set by the user. Using the default: a $HGF_name HGF with default settings" + end + + #Warn the user about used defaults and misspecified keys + warn_premade_defaults(defaults, config) + + #Merge to overwrite defaults + config = merge(defaults, config) + + + ## Create agent + #Set the action model + action_model = hgf_unit_square_sigmoid_action + + #Set the HGF + hgf = config["HGF"] + + #Set parameters + parameters = Dict("action_noise" => config["action_noise"]) + #Set states + states = Dict() + #Set settings + settings = Dict("target_state" => config["target_state"]) + + #Create the agent + return init_agent( + action_model, + substruct = hgf, + parameters = parameters, + states = states, + settings = settings, + ) +end diff --git a/src/premade_models/premade_agents/premade_softmax_action.jl b/src/premade_models/premade_agents/premade_softmax_action.jl index 8b13789..e72c336 100644 --- a/src/premade_models/premade_agents/premade_softmax_action.jl +++ b/src/premade_models/premade_agents/premade_softmax_action.jl @@ -1 +1,106 @@ +""" + hgf_binary_softmax_action(agent, input) +Action model which gives a binary action. The action probability is the softmax of a specified state of a node. 
+ +In addition to the HGF substruct, the following must be present in the agent: +Parameters: "softmax_action_precision" +Settings: "target_state" +""" +function hgf_binary_softmax_action(agent::Agent, input) + + #Get out HGF, settings and parameters + hgf = agent.substruct + target_state = agent.settings["target_state"] + action_noise = agent.parameters["action_noise"] + + #Update the HGF + update_hgf!(hgf, input) + + #Get the specified state + target_value = get_states(hgf, target_state) + + #Use sotmax to get the action probability + action_probability = 1 / (1 + exp(action_noise * target_value)) + + #If the action probability is not between 0 and 1 + if !(0 <= action_probability <= 1) + #Throw an error that will reject samples when fitted + throw( + RejectParameters( + "With these parameters and inputs, the action probability became $action_probability, which should be between 0 and 1. Try other parameter settings", + ), + ) + end + + #Create Bernoulli normal distribution with mean of the target value and a standard deviation from parameters + distribution = Distributions.Bernoulli(action_probability) + + #Return the action distribution + return distribution +end + + + + + +""" + premade_hgf_binary_softmax(config::Dict) + +Create an agent suitable for the HGF binary softmax model. + +# Config defaults: + - "HGF": "binary_3level" + - "softmax_action_precision": 1 + - "target_state": ("xbin", "prediction_mean") +""" +function premade_hgf_binary_softmax(config::Dict) + + ## Combine defaults and user settings + + #Default parameters and settings + defaults = Dict( + "action_noise" => 1, + "target_state" => ("xbin", "prediction_mean"), + "HGF" => "binary_3level", + ) + + #If there is no HGF in the user-set parameters + if !("HGF" in keys(config)) + HGF_name = defaults["HGF"] + #Make a default HGF + config["HGF"] = premade_hgf(HGF_name) + #And warn them + @warn "an HGF was not set by the user. 
Using the default: a $HGF_name HGF with default settings" + end + + #Warn the user about used defaults and misspecified keys + warn_premade_defaults(defaults, config) + + #Merge to overwrite defaults + config = merge(defaults, config) + + + ## Create agent + #Set the action model + action_model = hgf_binary_softmax_action + + #Set the HGF + hgf = config["HGF"] + + #Set parameters + parameters = Dict("action_noise" => config["action_noise"]) + #Set states + states = Dict() + #Set settings + settings = Dict("target_state" => config["target_state"]) + + #Create the agent + return init_agent( + action_model, + substruct = hgf, + parameters = parameters, + states = states, + settings = settings, + ) +end diff --git a/test/runtests.jl b/test/runtests.jl index 5cbf2f5..f072d30 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -38,7 +38,9 @@ hgf_path = dirname(dirname(pathof(HierarchicalGaussianFiltering))) filenames = glob("*.jl", documentation_path * "/Julia_src_files") for filename in filenames - include(filename) + @testset "$filename" begin + include(filename) + end end end @@ -47,10 +49,10 @@ hgf_path = dirname(dirname(pathof(HierarchicalGaussianFiltering))) # List the julia filenames in the tutorials folder filenames = glob("*.jl", documentation_path * "/tutorials") - # For each file for filename in filenames - #Run it - include(filename) + @testset "$filename" begin + include(filename) + end end end end diff --git a/test/testsuite/test_fit_model.jl b/test/testsuite/test_fit_model.jl index bb65660..082acbd 100644 --- a/test/testsuite/test_fit_model.jl +++ b/test/testsuite/test_fit_model.jl @@ -26,7 +26,7 @@ using Turing ("xvol", "initial_mean") => 1.0, ("xvol", "initial_precision") => 600, ("x", "xvol", "coupling_strength") => 1.0, - "gaussian_action_precision" => 100, + "action_noise" => 0.01, ("xvol", "volatility") => -4, ("u", "input_noise") => 4, ("xvol", "drift") => 1, @@ -105,7 +105,7 @@ using Turing ) test_param_priors = Dict( - 
"softmax_action_precision" => truncated(Normal(100, 20), 0, Inf), + "action_noise" => truncated(Normal(0.01, 20), 0, Inf), ("xprob", "volatility") => Normal(-7, 5), ) From 2a8bc4db60125bc9fe4b061507643ad2c34c9036 Mon Sep 17 00:00:00 2001 From: Peter Thestrup Waade Date: Fri, 22 Dec 2023 10:54:56 +0100 Subject: [PATCH 06/16] moved docs files --- .gitignore | 11 +- docs/Project.toml | 3 +- .../Julia_src_files => julia_files}/index.jl | 0 .../tutorials/classic_JGET.jl | 0 .../tutorials/classic_binary.jl | 3 + .../tutorials/classic_usdchf.jl | 0 .../tutorials/data/classic_binary_actions.csv | 0 .../tutorials/data/classic_binary_inputs.csv | 0 .../data/classic_cannonball_data.csv | 0 .../tutorials/data/classic_usdchf_inputs.dat | 0 .../user_guide}/all_functions.jl | 0 .../user_guide}/building_an_HGF.jl | 0 .../user_guide}/fitting_hgf_models.jl | 0 .../user_guide}/premade_HGF.jl | 4 +- .../user_guide}/premade_models.jl | 0 .../user_guide}/the_HGF_nodes.jl | 0 .../user_guide}/updating_the_HGF.jl | 0 .../user_guide}/utility_functions.jl | 0 docs/make.jl | 215 ++++++++++++++---- docs/src/index.md | 113 --------- 20 files changed, 186 insertions(+), 163 deletions(-) rename docs/{src/Julia_src_files => julia_files}/index.jl (100%) rename docs/{src => julia_files}/tutorials/classic_JGET.jl (100%) rename docs/{src => julia_files}/tutorials/classic_binary.jl (98%) rename docs/{src => julia_files}/tutorials/classic_usdchf.jl (100%) rename docs/{src => julia_files}/tutorials/data/classic_binary_actions.csv (100%) rename docs/{src => julia_files}/tutorials/data/classic_binary_inputs.csv (100%) rename docs/{src => julia_files}/tutorials/data/classic_cannonball_data.csv (100%) rename docs/{src => julia_files}/tutorials/data/classic_usdchf_inputs.dat (100%) rename docs/{src/Julia_src_files => julia_files/user_guide}/all_functions.jl (100%) rename docs/{src/Julia_src_files => julia_files/user_guide}/building_an_HGF.jl (100%) rename docs/{src/Julia_src_files => 
julia_files/user_guide}/fitting_hgf_models.jl (100%) rename docs/{src/Julia_src_files => julia_files/user_guide}/premade_HGF.jl (98%) rename docs/{src/Julia_src_files => julia_files/user_guide}/premade_models.jl (100%) rename docs/{src/Julia_src_files => julia_files/user_guide}/the_HGF_nodes.jl (100%) rename docs/{src/Julia_src_files => julia_files/user_guide}/updating_the_HGF.jl (100%) rename docs/{src/Julia_src_files => julia_files/user_guide}/utility_functions.jl (100%) delete mode 100644 docs/src/index.md diff --git a/.gitignore b/.gitignore index f0af5a4..5c16191 100644 --- a/.gitignore +++ b/.gitignore @@ -1,12 +1,17 @@ *.jl.*.cov *.jl.cov *.jl.mem -/docs/build/ .DS_Store +.vscode + testing_script*.jl settings.json + Manifest.toml docs/Manifest.toml test/Manifest.toml -/docs/src/generated_markdowns/*.md -.vscode \ No newline at end of file + +/docs/src/generated +/docs/src/index.md +/docs/build/ +/build \ No newline at end of file diff --git a/docs/Project.toml b/docs/Project.toml index dca5a47..b3a6a54 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -2,14 +2,13 @@ ActionModels = "320cf53b-cc3b-4b34-9a10-0ecb113566a3" CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" -Debugger = "31a5f54b-26ea-5ae9-a837-f05ce5417438" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +Glob = "c27321d9-0574-5035-807b-f59d2c89b15c" HierarchicalGaussianFiltering = "63d42c3e-681c-42be-892f-a47f35336a79" JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899" Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" -RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd" Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" diff --git a/docs/src/Julia_src_files/index.jl b/docs/julia_files/index.jl similarity index 100% rename from 
docs/src/Julia_src_files/index.jl rename to docs/julia_files/index.jl diff --git a/docs/src/tutorials/classic_JGET.jl b/docs/julia_files/tutorials/classic_JGET.jl similarity index 100% rename from docs/src/tutorials/classic_JGET.jl rename to docs/julia_files/tutorials/classic_JGET.jl diff --git a/docs/src/tutorials/classic_binary.jl b/docs/julia_files/tutorials/classic_binary.jl similarity index 98% rename from docs/src/tutorials/classic_binary.jl rename to docs/julia_files/tutorials/classic_binary.jl index d4c7f65..719d77f 100644 --- a/docs/src/tutorials/classic_binary.jl +++ b/docs/julia_files/tutorials/classic_binary.jl @@ -11,6 +11,9 @@ using Plots using StatsPlots using Distributions +CSV.read(pwd(), DataFrame) +print(pwd()) + # Get the path for the HGF superfolder hgf_path = dirname(dirname(pathof(HierarchicalGaussianFiltering))) # Add the path to the data files diff --git a/docs/src/tutorials/classic_usdchf.jl b/docs/julia_files/tutorials/classic_usdchf.jl similarity index 100% rename from docs/src/tutorials/classic_usdchf.jl rename to docs/julia_files/tutorials/classic_usdchf.jl diff --git a/docs/src/tutorials/data/classic_binary_actions.csv b/docs/julia_files/tutorials/data/classic_binary_actions.csv similarity index 100% rename from docs/src/tutorials/data/classic_binary_actions.csv rename to docs/julia_files/tutorials/data/classic_binary_actions.csv diff --git a/docs/src/tutorials/data/classic_binary_inputs.csv b/docs/julia_files/tutorials/data/classic_binary_inputs.csv similarity index 100% rename from docs/src/tutorials/data/classic_binary_inputs.csv rename to docs/julia_files/tutorials/data/classic_binary_inputs.csv diff --git a/docs/src/tutorials/data/classic_cannonball_data.csv b/docs/julia_files/tutorials/data/classic_cannonball_data.csv similarity index 100% rename from docs/src/tutorials/data/classic_cannonball_data.csv rename to docs/julia_files/tutorials/data/classic_cannonball_data.csv diff --git 
a/docs/src/tutorials/data/classic_usdchf_inputs.dat b/docs/julia_files/tutorials/data/classic_usdchf_inputs.dat similarity index 100% rename from docs/src/tutorials/data/classic_usdchf_inputs.dat rename to docs/julia_files/tutorials/data/classic_usdchf_inputs.dat diff --git a/docs/src/Julia_src_files/all_functions.jl b/docs/julia_files/user_guide/all_functions.jl similarity index 100% rename from docs/src/Julia_src_files/all_functions.jl rename to docs/julia_files/user_guide/all_functions.jl diff --git a/docs/src/Julia_src_files/building_an_HGF.jl b/docs/julia_files/user_guide/building_an_HGF.jl similarity index 100% rename from docs/src/Julia_src_files/building_an_HGF.jl rename to docs/julia_files/user_guide/building_an_HGF.jl diff --git a/docs/src/Julia_src_files/fitting_hgf_models.jl b/docs/julia_files/user_guide/fitting_hgf_models.jl similarity index 100% rename from docs/src/Julia_src_files/fitting_hgf_models.jl rename to docs/julia_files/user_guide/fitting_hgf_models.jl diff --git a/docs/src/Julia_src_files/premade_HGF.jl b/docs/julia_files/user_guide/premade_HGF.jl similarity index 98% rename from docs/src/Julia_src_files/premade_HGF.jl rename to docs/julia_files/user_guide/premade_HGF.jl index f243c12..e4ea569 100644 --- a/docs/src/Julia_src_files/premade_HGF.jl +++ b/docs/julia_files/user_guide/premade_HGF.jl @@ -19,7 +19,9 @@ using DataFrames #hide using Plots #hide using StatsPlots #hide -hgf_path_continuous = dirname(dirname(pathof(HierarchicalGaussianFiltering))); #hide +#CSV.read(pwd(), DataFrame) + +hgf_path_continuous = dirname(pathof(HierarchicalGaussianFiltering)); #hide hgf_path_continuous = hgf_path_continuous * "/docs/src/tutorials/data/"; #hide inputs_continuous = Float64[]; #hide diff --git a/docs/src/Julia_src_files/premade_models.jl b/docs/julia_files/user_guide/premade_models.jl similarity index 100% rename from docs/src/Julia_src_files/premade_models.jl rename to docs/julia_files/user_guide/premade_models.jl diff --git 
a/docs/src/Julia_src_files/the_HGF_nodes.jl b/docs/julia_files/user_guide/the_HGF_nodes.jl similarity index 100% rename from docs/src/Julia_src_files/the_HGF_nodes.jl rename to docs/julia_files/user_guide/the_HGF_nodes.jl diff --git a/docs/src/Julia_src_files/updating_the_HGF.jl b/docs/julia_files/user_guide/updating_the_HGF.jl similarity index 100% rename from docs/src/Julia_src_files/updating_the_HGF.jl rename to docs/julia_files/user_guide/updating_the_HGF.jl diff --git a/docs/src/Julia_src_files/utility_functions.jl b/docs/julia_files/user_guide/utility_functions.jl similarity index 100% rename from docs/src/Julia_src_files/utility_functions.jl rename to docs/julia_files/user_guide/utility_functions.jl diff --git a/docs/make.jl b/docs/make.jl index f6bd8a8..3a0907b 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -2,43 +2,60 @@ using HierarchicalGaussianFiltering using Documenter using Literate + +hgf_path = dirname(dirname(pathof(HierarchicalGaussianFiltering))) + +juliafiles_path = hgf_path * "/docs/julia_files" +user_guides_path = juliafiles_path * "/user_guide" +tutorials_path = juliafiles_path * "/tutorials" + +markdown_src_path = hgf_path * "/docs/src" +theory_path = markdown_src_path * "/theory" +generated_user_guide_path = markdown_src_path * "/generated/user_guide" +generated_tutorials_path = markdown_src_path * "/generated/tutorials" + + #Remove old tutorial markdown files -for filename in readdir("docs/src/generated_markdowns") - rm("docs/src/generated_markdowns/" * filename) +for filename in readdir(generated_user_guide_path) + if endswith(filename, ".md") + rm(generated_user_guide_path * "/" * filename) + end end -rm("docs/src/index.md") -#Generate new markdown files from the documentation source files -for filename in readdir("docs/src/Julia_src_files") - if endswith(filename, ".jl") +for filename in readdir(generated_tutorials_path) + if endswith(filename, ".md") + rm(generated_tutorials_path * "/" * filename) + end +end +rm(markdown_src_path * 
"/" * "index.md") - #Place the index file in another folder than the rest of the documentation - if startswith(filename, "index") - Literate.markdown( - "docs/src/Julia_src_files/" * filename, - "docs/src", - documenter = true, - ) - else - Literate.markdown( - "docs/src/Julia_src_files/" * filename, - "docs/src/generated_markdowns", - documenter = true, - ) - end +#Generate index markdown file +Literate.markdown(juliafiles_path * "/" * "index.jl", markdown_src_path, documenter = true) + +#Generate markdown files for user guide +for filename in readdir(user_guides_path) + if endswith(filename, ".jl") + Literate.markdown( + user_guides_path * "/" * filename, + generated_user_guide_path, + documenter = true, + ) end end -#Generate new tutorial markdown files from the tutorials -for filename in readdir("docs/src/tutorials") +#Generate markdown files for tutorials +for filename in readdir(tutorials_path) if endswith(filename, ".jl") Literate.markdown( - "docs/src/tutorials/" * filename, - "docs/src/generated_markdowns", + tutorials_path * "/" * filename, + generated_tutorials_path, documenter = true, ) end end + + + #Set documenter metadata DocMeta.setdocmeta!( HierarchicalGaussianFiltering, @@ -61,30 +78,140 @@ makedocs(; ), pages = [ "Introduction to Hierarchical Gaussian Filtering" => "./index.md", - "Theory" => [ - "./theory/genmodel.md", - "./theory/node.md", - "./theory/vape.md", - "./theory/vope.md", - ], - "Using the package" => [ - "The HGF Nodes" => "./generated_markdowns/the_HGF_nodes.md", - "Building an HGF" => "./generated_markdowns/building_an_HGF.md", - "Updating the HGF" => "./generated_markdowns/updating_the_HGF.md", - "List Of Premade Agent Models" => "./generated_markdowns/premade_models.md", - "List Of Premade HGF's" => "./generated_markdowns/premade_HGF.md", - "Fitting an HGF-agent model to data" => "./generated_markdowns/fitting_hgf_models.md", - "Utility Functions" => "./generated_markdowns/utility_functions.md", - ], - "Tutorials" => [ - 
"classic binary" => "./generated_markdowns/classic_binary.md", - "classic continouous" => "./generated_markdowns/classic_usdchf.md", - ], - "All Functions" => "./generated_markdowns/all_functions.md", + # "Theory" => [ + # "./theory" * "/genmodel.md", + # "./theory" * "/node.md", + # "./theory" * "/vape.md", + # "./theory" * "/vope.md", + # ], + # "Using the package" => [ + # "The HGF Nodes" => "./generated/user_guide" * "/the_HGF_nodes.md", + # "Building an HGF" => "./generated/user_guide" * "/building_an_HGF.md", + # "Updating the HGF" => "./generated/user_guide" * "/updating_the_HGF.md", + # "List Of Premade Agent Models" => "./generated/user_guide" * "/premade_models.md", + # "List Of Premade HGF's" => "./generated/user_guide" * "/premade_HGF.md", + # "Fitting an HGF-agent model to data" => "./generated/user_guide" * "/fitting_hgf_models.md", + # "Utility Functions" => "./generated/user_guide" * "/utility_functions.md", + # ], + # "Tutorials" => [ + # "classic binary" => "./generated/tutorials" * "/classic_binary.md", + # "classic continouous" => "./generated/tutorials" * "/classic_usdchf.md", + # "classic JGET" => "./generated/tutorials" * "/classic_JGET.md", + # ], + # "All Functions" => "./generated/user_guide" * "/all_functions.md", ], ) + deploydocs(; repo = "github.com/ilabcode/HierarchicalGaussianFiltering.jl", devbranch = "main", push_preview = false, ) + + + + + + + + + + + + + + + + + + +# using HierarchicalGaussianFiltering +# using Documenter +# using Literate + +# #Remove old tutorial markdown files +# for filename in readdir("docs/src/generated_markdowns") +# rm("docs/src/generated_markdowns/" * filename) +# end +# rm("docs/src/index.md") +# #Generate new markdown files from the documentation source files +# for filename in readdir("docs/src/user_guide") +# if endswith(filename, ".jl") + +# #Place the index file in another folder than the rest of the documentation +# if startswith(filename, "index") +# Literate.markdown( +# "docs/src/user_guide/" 
* filename, +# "docs/src", +# documenter = true, +# ) +# else +# Literate.markdown( +# "docs/src/Julia_src_files/" * filename, +# "docs/src/generated_markdowns", +# documenter = true, +# ) +# end +# end +# end + +# #Generate new tutorial markdown files from the tutorials +# for filename in readdir("docs/src/tutorials") +# if endswith(filename, ".jl") +# Literate.markdown( +# "docs/src/tutorials/" * filename, +# "docs/src/generated_markdowns", +# documenter = true, +# ) +# end +# end + +# #Set documenter metadata +# DocMeta.setdocmeta!( +# HierarchicalGaussianFiltering, +# :DocTestSetup, +# :(using HierarchicalGaussianFiltering); +# recursive = true, +# ) + +# #Create documentation +# makedocs(; +# modules = [HierarchicalGaussianFiltering], +# authors = "Peter Thestrup Waade ptw@cas.au.dk, Jacopo Comoglio jacopo.comoglio@gmail.com, Christoph Mathys chmathys@cas.au.dk +# and contributors", +# repo = "https://github.com/ilabcode/HierarchicalGaussianFiltering.jl/blob/{commit}{path}#{line}", +# sitename = "HierarchicalGaussianFiltering.jl", +# format = Documenter.HTML(; +# prettyurls = get(ENV, "CI", "false") == "true", +# canonical = "https://ilabcode.github.io/HierarchicalGaussianFiltering.jl", +# assets = String[], +# ), +# pages = [ +# "Introduction to Hierarchical Gaussian Filtering" => "./index.md", +# "Theory" => [ +# "./theory/genmodel.md", +# "./theory/node.md", +# "./theory/vape.md", +# "./theory/vope.md", +# ], +# "Using the package" => [ +# "The HGF Nodes" => "./generated_markdowns/the_HGF_nodes.md", +# "Building an HGF" => "./generated_markdowns/building_an_HGF.md", +# "Updating the HGF" => "./generated_markdowns/updating_the_HGF.md", +# "List Of Premade Agent Models" => "./generated_markdowns/premade_models.md", +# "List Of Premade HGF's" => "./generated_markdowns/premade_HGF.md", +# "Fitting an HGF-agent model to data" => "./generated_markdowns/fitting_hgf_models.md", +# "Utility Functions" => "./generated_markdowns/utility_functions.md", +# ], +# 
"Tutorials" => [ +# "classic binary" => "./generated_markdowns/classic_binary.md", +# "classic continouous" => "./generated_markdowns/classic_usdchf.md", +# ], +# "All Functions" => "./generated_markdowns/all_functions.md", +# ], +# ) +# deploydocs(; +# repo = "github.com/ilabcode/HierarchicalGaussianFiltering.jl", +# devbranch = "main", +# push_preview = false, +# ) diff --git a/docs/src/index.md b/docs/src/index.md deleted file mode 100644 index 7cb3a58..0000000 --- a/docs/src/index.md +++ /dev/null @@ -1,113 +0,0 @@ -```@meta -EditURL = "/docs/src/Julia_src_files/index.jl" -``` - -# Welcome to The Hierarchical Gaussian Filtering Package! - -Hierarchical Gaussian Filtering (HGF) is a novel and adaptive package for doing cognitive and behavioral modelling. With the HGF you can fit time series data fit participant-level individual parameters, measure group differences based on model-specific parameters or use the model for any time series with underlying change in uncertainty. - -The HGF consists of a network of probabilistic nodes hierarchically structured. The hierarchy is determined by the coupling between nodes. A node (child node) in the network can inheret either its value or volatility sufficient statistics from a node higher in the hierarchy (a parent node). - -The presentation of a new observation at the lower level of the hierarchy (i.e. the input node) trigger a recursuve update of the nodes belief throught the bottom-up propagation of precision-weigthed prediction error. - -The HGF will be explained in more detail in the theory section of the documentation - -It is also recommended to check out the ActionModels.jl pacakge for stronger intuition behind the use of agents and action models. 
- -## Getting started - -The last official release can be downloaded from Julia with "] add HierarchicalGaussianFiltering" - -We provide a script for getting started with commonly used functions and use cases - -Load packages - -````@example index -using HierarchicalGaussianFiltering -using ActionModels -```` - -### Get premade agent - -````@example index -premade_agent("help") -```` - -### Create agent - -````@example index -agent = premade_agent("hgf_binary_softmax_action") -```` - -### Get states and parameters - -````@example index -get_states(agent) -```` - -````@example index -get_parameters(agent) -```` - -Set a new parameter for initial precision of xprob and define some inputs - -````@example index -set_parameters!(agent, ("xprob", "initial_precision"), 0.9) -inputs = [1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0]; -nothing #hide -```` - -### Give inputs to the agent - -````@example index -actions = give_inputs!(agent, inputs) -```` - -### Plot state trajectories of input and prediction - -````@example index -using StatsPlots -using Plots -plot_trajectory(agent, ("u", "input_value")) -plot_trajectory!(agent, ("x", "prediction")) -```` - -Plot state trajectory of input value, action and prediction of x - -````@example index -plot_trajectory(agent, ("u", "input_value")) -plot_trajectory!(agent, "action") -plot_trajectory!(agent, ("x", "prediction")) -```` - -### Fitting parameters - -````@example index -using Distributions -prior = Dict(("xprob", "volatility") => Normal(1, 0.5)) - -model = fit_model(agent, prior, inputs, actions, n_iterations = 20) -```` - -### Plot chains - -````@example index -plot(model) -```` - -### Plot prior angainst posterior - -````@example index -plot_parameter_distribution(model, prior) -```` - -### Get posterior - -````@example index -get_posteriors(model) -```` - ---- - -*This page was generated using [Literate.jl](https://github.com/fredrikekre/Literate.jl).* - From 
ab6f9be94f3d0203078bd33756cf6b094305912a Mon Sep 17 00:00:00 2001 From: Peter Thestrup Waade Date: Fri, 22 Mar 2024 13:22:36 +0100 Subject: [PATCH 07/16] minor --- docs/julia_files/tutorials/classic_binary.jl | 2 +- docs/julia_files/tutorials/classic_usdchf.jl | 6 +- docs/make.jl | 108 ------------------ .../node_updates/categorical_state_node.jl | 2 +- .../node_updates/continuous_state_node.jl | 32 +++--- test/Project.toml | 3 + 6 files changed, 25 insertions(+), 128 deletions(-) diff --git a/docs/julia_files/tutorials/classic_binary.jl b/docs/julia_files/tutorials/classic_binary.jl index 719d77f..b53c618 100644 --- a/docs/julia_files/tutorials/classic_binary.jl +++ b/docs/julia_files/tutorials/classic_binary.jl @@ -17,7 +17,7 @@ print(pwd()) # Get the path for the HGF superfolder hgf_path = dirname(dirname(pathof(HierarchicalGaussianFiltering))) # Add the path to the data files -data_path = hgf_path * "/docs/src/tutorials/data/" +data_path = hgf_path * "/docs/julia_files/tutorials/data/" # Load the data inputs = CSV.read(data_path * "classic_binary_inputs.csv", DataFrame)[!, 1]; diff --git a/docs/julia_files/tutorials/classic_usdchf.jl b/docs/julia_files/tutorials/classic_usdchf.jl index 25157c2..d783fc1 100644 --- a/docs/julia_files/tutorials/classic_usdchf.jl +++ b/docs/julia_files/tutorials/classic_usdchf.jl @@ -12,7 +12,7 @@ using Distributions # Get the path for the HGF superfolder hgf_path = dirname(dirname(pathof(HierarchicalGaussianFiltering))) # Add the path to the data files -data_path = hgf_path * "/docs/src/tutorials/data/" +data_path = hgf_path * "/docs/julia_files/tutorials/data/" # Load the data inputs = Float64[] @@ -85,13 +85,13 @@ fixed_parameters = Dict( ("x", "initial_precision") => 2000, ("xvol", "initial_mean") => 1.0, ("xvol", "initial_precision") => 600.0, - "action_noise" => 0.01, ); param_priors = Dict( ("u", "input_noise") => Normal(-6, 1), ("x", "volatility") => Normal(-4, 1), ("xvol", "volatility") => Normal(-4, 1), + "action_noise" 
=> LogNormal(log(0.01), 1), ); #- # Prior predictive simulation plot @@ -111,7 +111,7 @@ fitted_model = fit_model( actions, fixed_parameters = fixed_parameters, verbose = false, - n_iterations = 10, + n_iterations = 4000, ) #- # Plot the chains diff --git a/docs/make.jl b/docs/make.jl index 3a0907b..a05db21 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -107,111 +107,3 @@ deploydocs(; devbranch = "main", push_preview = false, ) - - - - - - - - - - - - - - - - - - -# using HierarchicalGaussianFiltering -# using Documenter -# using Literate - -# #Remove old tutorial markdown files -# for filename in readdir("docs/src/generated_markdowns") -# rm("docs/src/generated_markdowns/" * filename) -# end -# rm("docs/src/index.md") -# #Generate new markdown files from the documentation source files -# for filename in readdir("docs/src/user_guide") -# if endswith(filename, ".jl") - -# #Place the index file in another folder than the rest of the documentation -# if startswith(filename, "index") -# Literate.markdown( -# "docs/src/user_guide/" * filename, -# "docs/src", -# documenter = true, -# ) -# else -# Literate.markdown( -# "docs/src/Julia_src_files/" * filename, -# "docs/src/generated_markdowns", -# documenter = true, -# ) -# end -# end -# end - -# #Generate new tutorial markdown files from the tutorials -# for filename in readdir("docs/src/tutorials") -# if endswith(filename, ".jl") -# Literate.markdown( -# "docs/src/tutorials/" * filename, -# "docs/src/generated_markdowns", -# documenter = true, -# ) -# end -# end - -# #Set documenter metadata -# DocMeta.setdocmeta!( -# HierarchicalGaussianFiltering, -# :DocTestSetup, -# :(using HierarchicalGaussianFiltering); -# recursive = true, -# ) - -# #Create documentation -# makedocs(; -# modules = [HierarchicalGaussianFiltering], -# authors = "Peter Thestrup Waade ptw@cas.au.dk, Jacopo Comoglio jacopo.comoglio@gmail.com, Christoph Mathys chmathys@cas.au.dk -# and contributors", -# repo = 
"https://github.com/ilabcode/HierarchicalGaussianFiltering.jl/blob/{commit}{path}#{line}", -# sitename = "HierarchicalGaussianFiltering.jl", -# format = Documenter.HTML(; -# prettyurls = get(ENV, "CI", "false") == "true", -# canonical = "https://ilabcode.github.io/HierarchicalGaussianFiltering.jl", -# assets = String[], -# ), -# pages = [ -# "Introduction to Hierarchical Gaussian Filtering" => "./index.md", -# "Theory" => [ -# "./theory/genmodel.md", -# "./theory/node.md", -# "./theory/vape.md", -# "./theory/vope.md", -# ], -# "Using the package" => [ -# "The HGF Nodes" => "./generated_markdowns/the_HGF_nodes.md", -# "Building an HGF" => "./generated_markdowns/building_an_HGF.md", -# "Updating the HGF" => "./generated_markdowns/updating_the_HGF.md", -# "List Of Premade Agent Models" => "./generated_markdowns/premade_models.md", -# "List Of Premade HGF's" => "./generated_markdowns/premade_HGF.md", -# "Fitting an HGF-agent model to data" => "./generated_markdowns/fitting_hgf_models.md", -# "Utility Functions" => "./generated_markdowns/utility_functions.md", -# ], -# "Tutorials" => [ -# "classic binary" => "./generated_markdowns/classic_binary.md", -# "classic continouous" => "./generated_markdowns/classic_usdchf.md", -# ], -# "All Functions" => "./generated_markdowns/all_functions.md", -# ], -# ) -# deploydocs(; -# repo = "github.com/ilabcode/HierarchicalGaussianFiltering.jl", -# devbranch = "main", -# push_preview = false, -# ) diff --git a/src/update_hgf/node_updates/categorical_state_node.jl b/src/update_hgf/node_updates/categorical_state_node.jl index 44b6540..458daba 100644 --- a/src/update_hgf/node_updates/categorical_state_node.jl +++ b/src/update_hgf/node_updates/categorical_state_node.jl @@ -49,7 +49,7 @@ function calculate_prediction(node::CategoricalStateNode) (parent_predictions .- previous_parent_predictions) ) .- 1 - # calculate the prediction mean + #Calculate the prediction mean prediction = ((implied_learning_rate .* parent_predictions) .+ 1) ./ 
sum(implied_learning_rate .* parent_predictions .+ 1) diff --git a/src/update_hgf/node_updates/continuous_state_node.jl b/src/update_hgf/node_updates/continuous_state_node.jl index 9574ead..58fe6cf 100644 --- a/src/update_hgf/node_updates/continuous_state_node.jl +++ b/src/update_hgf/node_updates/continuous_state_node.jl @@ -234,7 +234,7 @@ function calculate_posterior_precision_increment( child::ContinuousStateNode, coupling_type::DriftCoupling, ) - child.parameters.coupling_strengths[node.name] * child.states.prediction_precision + child.parameters.coupling_strengths[node.name]^2 * child.states.prediction_precision end @@ -390,8 +390,12 @@ function calculate_posterior_mean_increment( coupling_type::DriftCoupling, update_type::ClassicUpdate, ) - (child.parameters.coupling_strengths[node.name] * child.states.prediction_precision) / - node.states.posterior_precision * child.states.value_prediction_error + ( + ( + child.parameters.coupling_strengths[node.name] * + child.states.prediction_precision + ) / node.states.posterior_precision + ) * child.states.value_prediction_error end ## Enhanced drift coupling ## @@ -401,8 +405,12 @@ function calculate_posterior_mean_increment( coupling_type::DriftCoupling, update_type::EnhancedUpdate, ) - (child.parameters.coupling_strengths[node.name] * child.states.prediction_precision) / - node.states.prediction_precision * child.states.value_prediction_error + ( + ( + child.parameters.coupling_strengths[node.name] * + child.states.prediction_precision + ) / node.states.prediction_precision + ) * child.states.value_prediction_error end @@ -418,11 +426,8 @@ function calculate_posterior_mean_increment( #No update return 0 else - update_term = - child.states.prediction_precision / node.states.posterior_precision * - child.states.value_prediction_error - - return update_term + return (child.states.prediction_precision / node.states.posterior_precision) * + child.states.value_prediction_error end end @@ -438,11 +443,8 @@ function 
calculate_posterior_mean_increment( #No update return 0 else - update_term = - child.states.prediction_precision / node.states.prediction_precision * - child.states.value_prediction_error - - return update_term + return (child.states.prediction_precision / node.states.prediction_precision) * + child.states.value_prediction_error end end diff --git a/test/Project.toml b/test/Project.toml index 2c924d0..38e7d75 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,9 +1,12 @@ [deps] ActionModels = "320cf53b-cc3b-4b34-9a10-0ecb113566a3" +Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Glob = "c27321d9-0574-5035-807b-f59d2c89b15c" +JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" +JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" From 194005e247f711ea463a44cca4e52eae1da95786 Mon Sep 17 00:00:00 2001 From: Peter Thestrup Waade Date: Sun, 7 Apr 2024 23:43:45 +0200 Subject: [PATCH 08/16] added bias and autoconnectionstrength --- src/create_hgf/hgf_structs.jl | 12 +++---- src/create_hgf/init_hgf.jl | 3 +- .../premade_hgfs/premade_JGET.jl | 31 ++++++++--------- .../premade_hgfs/premade_binary_2level.jl | 6 ++-- .../premade_hgfs/premade_binary_3level.jl | 12 +++---- .../premade_categorical_3level.jl | 34 ++++++------------- .../premade_categorical_transitions_3level.jl | 34 ++++++------------- .../premade_hgfs/premade_continuous_2level.jl | 19 ++++++----- .../node_updates/continuous_input_node.jl | 2 +- .../node_updates/continuous_state_node.jl | 11 +++--- 10 files changed, 63 insertions(+), 101 deletions(-) diff --git a/src/create_hgf/hgf_structs.jl b/src/create_hgf/hgf_structs.jl index 915cbed..6b79cd2 100644 --- a/src/create_hgf/hgf_structs.jl +++ 
b/src/create_hgf/hgf_structs.jl @@ -88,10 +88,10 @@ end ################################## Base.@kwdef struct NodeDefaults input_noise::Real = -2 + bias::Real = 0 volatility::Real = -2 drift::Real = 0 - autoregression_target::Real = 0 - autoregression_strength::Real = 0 + autoconnection_strength::Real = 1 initial_mean::Real = 0 initial_precision::Real = 1 coupling_strength::Real = 1 @@ -102,8 +102,7 @@ Base.@kwdef mutable struct ContinuousState <: AbstractStateNodeInfo name::String volatility::Union{Real,Nothing} = nothing drift::Union{Real,Nothing} = nothing - autoregression_target::Union{Real,Nothing} = nothing - autoregression_strength::Union{Real,Nothing} = nothing + autoconnection_strength::Union{Real,Nothing} = nothing initial_mean::Union{Real,Nothing} = nothing initial_precision::Union{Real,Nothing} = nothing end @@ -111,6 +110,7 @@ end Base.@kwdef mutable struct ContinuousInput <: AbstractInputNodeInfo name::String input_noise::Union{Real,Nothing} = nothing + bias::Union{Real,Nothing} = nothing end Base.@kwdef mutable struct BinaryState <: AbstractStateNodeInfo @@ -156,8 +156,7 @@ Configuration of continuous state nodes' parameters Base.@kwdef mutable struct ContinuousStateNodeParameters volatility::Real = 0 drift::Real = 0 - autoregression_target::Real = 0 - autoregression_strength::Real = 0 + autoconnection_strength::Real = 1 coupling_strengths::Dict{String,Real} = Dict{String,Real}() initial_mean::Real = 0 initial_precision::Real = 0 @@ -218,6 +217,7 @@ Configuration of continuous input node parameters """ Base.@kwdef mutable struct ContinuousInputNodeParameters input_noise::Real = 0 + bias::Real = 0 coupling_strengths::Dict{String,Real} = Dict{String,Real}() end diff --git a/src/create_hgf/init_hgf.jl b/src/create_hgf/init_hgf.jl index 71cf55e..93ccaca 100644 --- a/src/create_hgf/init_hgf.jl +++ b/src/create_hgf/init_hgf.jl @@ -269,8 +269,7 @@ function init_node(node_info::ContinuousState) drift = node_info.drift, initial_mean = node_info.initial_mean, 
initial_precision = node_info.initial_precision, - autoregression_target = node_info.autoregression_target, - autoregression_strength = node_info.autoregression_strength, + autoconnection_strength = node_info.autoconnection_strength, ), ) end diff --git a/src/premade_models/premade_hgfs/premade_JGET.jl b/src/premade_models/premade_hgfs/premade_JGET.jl index ec5c74f..557e091 100644 --- a/src/premade_models/premade_hgfs/premade_JGET.jl +++ b/src/premade_models/premade_hgfs/premade_JGET.jl @@ -27,28 +27,25 @@ function premade_JGET(config::Dict; verbose::Bool = true) #Defaults spec_defaults = Dict( ("u", "input_noise") => -2, + ("u", "bias") => 0, ("x", "volatility") => -2, ("x", "drift") => 0, - ("x", "autoregression_target") => 0, - ("x", "autoregression_strength") => 0, + ("x", "autoconnection_strength") => 1, ("x", "initial_mean") => 0, ("x", "initial_precision") => 1, ("xvol", "volatility") => -2, ("xvol", "drift") => 0, - ("xvol", "autoregression_target") => 0, - ("xvol", "autoregression_strength") => 0, + ("xvol", "autoconnection_strength") => 1, ("xvol", "initial_mean") => 0, ("xvol", "initial_precision") => 1, ("xnoise", "volatility") => -2, ("xnoise", "drift") => 0, - ("xnoise", "autoregression_target") => 0, - ("xnoise", "autoregression_strength") => 0, + ("xnoise", "autoconnection_strength") => 1, ("xnoise", "initial_mean") => 0, ("xnoise", "initial_precision") => 1, ("xnoise_vol", "volatility") => -2, ("xnoise_vol", "drift") => 0, - ("xnoise_vol", "autoregression_target") => 0, - ("xnoise_vol", "autoregression_strength") => 0, + ("xnoise_vol", "autoconnection_strength") => 1, ("xnoise_vol", "initial_mean") => 0, ("xnoise_vol", "initial_precision") => 1, ("u", "xnoise", "coupling_strength") => 1, @@ -67,13 +64,16 @@ function premade_JGET(config::Dict; verbose::Bool = true) #List of nodes nodes = [ - ContinuousInput(name = "u", input_noise = config[("u", "input_noise")]), + ContinuousInput( + name = "u", + input_noise = config[("u", "input_noise")], + bias = 
config[("u", "bias")], + ), ContinuousState( name = "x", volatility = config[("x", "volatility")], drift = config[("x", "drift")], - autoregression_target = config[("x", "autoregression_target")], - autoregression_strength = config[("x", "autoregression_strength")], + autoconnection_strength = config[("x", "autoconnection_strength")], initial_mean = config[("x", "initial_mean")], initial_precision = config[("x", "initial_precision")], ), @@ -81,8 +81,7 @@ function premade_JGET(config::Dict; verbose::Bool = true) name = "xvol", volatility = config[("xvol", "volatility")], drift = config[("xvol", "drift")], - autoregression_target = config[("xvol", "autoregression_target")], - autoregression_strength = config[("xvol", "autoregression_strength")], + autoconnection_strength = config[("xvol", "autoconnection_strength")], initial_mean = config[("xvol", "initial_mean")], initial_precision = config[("xvol", "initial_precision")], ), @@ -90,8 +89,7 @@ function premade_JGET(config::Dict; verbose::Bool = true) name = "xnoise", volatility = config[("xnoise", "volatility")], drift = config[("xnoise", "drift")], - autoregression_target = config[("xnoise", "autoregression_target")], - autoregression_strength = config[("xnoise", "autoregression_strength")], + autoconnection_strength = config[("xnoise", "autoconnection_strength")], initial_mean = config[("xnoise", "initial_mean")], initial_precision = config[("xnoise", "initial_precision")], ), @@ -99,8 +97,7 @@ function premade_JGET(config::Dict; verbose::Bool = true) name = "xnoise_vol", volatility = config[("xnoise_vol", "volatility")], drift = config[("xnoise_vol", "drift")], - autoregression_target = config[("xnoise_vol", "autoregression_target")], - autoregression_strength = config[("xnoise_vol", "autoregression_strength")], + autoconnection_strength = config[("xnoise_vol", "autoconnection_strength")], initial_mean = config[("xnoise_vol", "initial_mean")], initial_precision = config[("xnoise_vol", "initial_precision")], ), 
diff --git a/src/premade_models/premade_hgfs/premade_binary_2level.jl b/src/premade_models/premade_hgfs/premade_binary_2level.jl index 779d2dc..0413edd 100644 --- a/src/premade_models/premade_hgfs/premade_binary_2level.jl +++ b/src/premade_models/premade_hgfs/premade_binary_2level.jl @@ -15,8 +15,7 @@ function premade_binary_2level(config::Dict; verbose::Bool = true) ("u", "input_precision") => Inf, ("xprob", "volatility") => -2, ("xprob", "drift") => 0, - ("xprob", "autoregression_target") => 0, - ("xprob", "autoregression_strength") => 0, + ("xprob", "autoconnection_strength") => 1, ("xprob", "initial_mean") => 0, ("xprob", "initial_precision") => 1, ("xbin", "xprob", "coupling_strength") => 1, @@ -39,8 +38,7 @@ function premade_binary_2level(config::Dict; verbose::Bool = true) name = "xprob", volatility = config[("xprob", "volatility")], drift = config[("xprob", "drift")], - autoregression_target = config[("xprob", "autoregression_target")], - autoregression_strength = config[("xprob", "autoregression_strength")], + autoconnection_strength = config[("xprob", "autoconnection_strength")], initial_mean = config[("xprob", "initial_mean")], initial_precision = config[("xprob", "initial_precision")], ), diff --git a/src/premade_models/premade_hgfs/premade_binary_3level.jl b/src/premade_models/premade_hgfs/premade_binary_3level.jl index 24cb05c..1ee9c0b 100644 --- a/src/premade_models/premade_hgfs/premade_binary_3level.jl +++ b/src/premade_models/premade_hgfs/premade_binary_3level.jl @@ -32,14 +32,12 @@ function premade_binary_3level(config::Dict; verbose::Bool = true) ("u", "input_precision") => Inf, ("xprob", "volatility") => -2, ("xprob", "drift") => 0, - ("xprob", "autoregression_target") => 0, - ("xprob", "autoregression_strength") => 0, + ("xprob", "autoconnection_strength") => 1, ("xprob", "initial_mean") => 0, ("xprob", "initial_precision") => 1, ("xvol", "volatility") => -2, ("xvol", "drift") => 0, - ("xvol", "autoregression_target") => 0, - ("xvol", 
"autoregression_strength") => 0, + ("xvol", "autoconnection_strength") => 1, ("xvol", "initial_mean") => 0, ("xvol", "initial_precision") => 1, ("xbin", "xprob", "coupling_strength") => 1, @@ -63,8 +61,7 @@ function premade_binary_3level(config::Dict; verbose::Bool = true) name = "xprob", volatility = config[("xprob", "volatility")], drift = config[("xprob", "drift")], - autoregression_target = config[("xprob", "autoregression_target")], - autoregression_strength = config[("xprob", "autoregression_strength")], + autoconnection_strength = config[("xprob", "autoconnection_strength")], initial_mean = config[("xprob", "initial_mean")], initial_precision = config[("xprob", "initial_precision")], ), @@ -72,8 +69,7 @@ function premade_binary_3level(config::Dict; verbose::Bool = true) name = "xvol", volatility = config[("xvol", "volatility")], drift = config[("xvol", "drift")], - autoregression_target = config[("xvol", "autoregression_target")], - autoregression_strength = config[("xvol", "autoregression_strength")], + autoconnection_strength = config[("xvol", "autoconnection_strength")], initial_mean = config[("xvol", "initial_mean")], initial_precision = config[("xvol", "initial_precision")], ), diff --git a/src/premade_models/premade_hgfs/premade_categorical_3level.jl b/src/premade_models/premade_hgfs/premade_categorical_3level.jl index 54d693c..50e278d 100644 --- a/src/premade_models/premade_hgfs/premade_categorical_3level.jl +++ b/src/premade_models/premade_hgfs/premade_categorical_3level.jl @@ -27,14 +27,12 @@ function premade_categorical_3level(config::Dict; verbose::Bool = true) "n_categories" => 4, ("xprob", "volatility") => -2, ("xprob", "drift") => 0, - ("xprob", "autoregression_target") => 0, - ("xprob", "autoregression_strength") => 0, + ("xprob", "autoconnection_strength") => 1, ("xprob", "initial_mean") => 0, ("xprob", "initial_precision") => 1, ("xvol", "volatility") => -2, ("xvol", "drift") => 0, - ("xvol", "autoregression_target") => 0, - ("xvol", 
"autoregression_strength") => 0, + ("xvol", "autoconnection_strength") => 1, ("xvol", "initial_mean") => 0, ("xvol", "initial_precision") => 1, ("xbin", "xprob", "coupling_strength") => 1, @@ -62,8 +60,7 @@ function premade_categorical_3level(config::Dict; verbose::Bool = true) derived_parameters_xprob_initial_mean = [] derived_parameters_xprob_volatility = [] derived_parameters_xprob_drift = [] - derived_parameters_xprob_autoregression_target = [] - derived_parameters_xprob_autoregression_strength = [] + derived_parameters_xprob_autoconnection_strength = [] derived_parameters_xbin_xprob_coupling_strength = [] derived_parameters_xprob_xvol_coupling_strength = [] @@ -89,8 +86,7 @@ function premade_categorical_3level(config::Dict; verbose::Bool = true) name = node_name, volatility = config[("xprob", "volatility")], drift = config[("xprob", "drift")], - autoregression_target = config[("xprob", "autoregression_target")], - autoregression_strength = config[("xprob", "autoregression_strength")], + autoconnection_strength = config[("xprob", "autoconnection_strength")], initial_mean = config[("xprob", "initial_mean")], initial_precision = config[("xprob", "initial_precision")], ), @@ -101,12 +97,8 @@ function premade_categorical_3level(config::Dict; verbose::Bool = true) push!(derived_parameters_xprob_volatility, (node_name, "volatility")) push!(derived_parameters_xprob_drift, (node_name, "drift")) push!( - derived_parameters_xprob_autoregression_strength, - (node_name, "autoregression_strength"), - ) - push!( - derived_parameters_xprob_autoregression_target, - (node_name, "autoregression_target"), + derived_parameters_xprob_autoconnection_strength, + (node_name, "autoconnection_strength"), ) end @@ -117,8 +109,7 @@ function premade_categorical_3level(config::Dict; verbose::Bool = true) name = "xvol", volatility = config[("xvol", "volatility")], drift = config[("xvol", "drift")], - autoregression_target = config[("xvol", "autoregression_target")], - autoregression_strength 
= config[("xvol", "autoregression_strength")], + autoconnection_strength = config[("xvol", "autoconnection_strength")], initial_mean = config[("xvol", "initial_mean")], initial_precision = config[("xvol", "initial_precision")], ), @@ -169,14 +160,9 @@ function premade_categorical_3level(config::Dict; verbose::Bool = true) shared_parameters["xprob_drift"] = (config[("xprob", "drift")], derived_parameters_xprob_drift) - shared_parameters["xprob_autoregression_strength"] = ( - config[("xprob", "autoregression_strength")], - derived_parameters_xprob_autoregression_strength, - ) - - shared_parameters["xprob_autoregression_target"] = ( - config[("xprob", "autoregression_target")], - derived_parameters_xprob_autoregression_target, + shared_parameters["autoconnection_strength"] = ( + config[("xprob", "autoconnection_strength")], + derived_parameters_xprob_autoconnection_strength, ) shared_parameters["xbin_xprob_coupling_strength"] = ( diff --git a/src/premade_models/premade_hgfs/premade_categorical_transitions_3level.jl b/src/premade_models/premade_hgfs/premade_categorical_transitions_3level.jl index 0e74804..a96a185 100644 --- a/src/premade_models/premade_hgfs/premade_categorical_transitions_3level.jl +++ b/src/premade_models/premade_hgfs/premade_categorical_transitions_3level.jl @@ -34,14 +34,12 @@ function premade_categorical_3level_state_transitions(config::Dict; verbose::Boo "n_categories" => 4, ("xprob", "volatility") => -2, ("xprob", "drift") => 0, - ("xprob", "autoregression_target") => 0, - ("xprob", "autoregression_strength") => 0, + ("xprob", "autoconnection_strength") => 1, ("xprob", "initial_mean") => 0, ("xprob", "initial_precision") => 1, ("xvol", "volatility") => -2, ("xvol", "drift") => 0, - ("xvol", "autoregression_target") => 0, - ("xvol", "autoregression_strength") => 0, + ("xvol", "autoconnection_strength") => 1, ("xvol", "initial_mean") => 0, ("xvol", "initial_precision") => 1, ("xbin", "xprob", "coupling_strength") => 1, @@ -70,8 +68,7 @@ function 
premade_categorical_3level_state_transitions(config::Dict; verbose::Boo derived_parameters_xprob_initial_mean = [] derived_parameters_xprob_volatility = [] derived_parameters_xprob_drift = [] - derived_parameters_xprob_autoregression_target = [] - derived_parameters_xprob_autoregression_strength = [] + derived_parameters_xprob_autoconnection_strength = [] derived_parameters_xbin_xprob_coupling_strength = [] derived_parameters_xprob_xvol_coupling_strength = [] @@ -125,8 +122,7 @@ function premade_categorical_3level_state_transitions(config::Dict; verbose::Boo name = node_name, volatility = config[("xprob", "volatility")], drift = config[("xprob", "drift")], - autoregression_target = config[("xprob", "autoregression_target")], - autoregression_strength = config[("xprob", "autoregression_strength")], + autoconnection_strength = config[("xprob", "autoconnection_strength")], initial_mean = config[("xprob", "initial_mean")], initial_precision = config[("xprob", "initial_precision")], ), @@ -137,12 +133,8 @@ function premade_categorical_3level_state_transitions(config::Dict; verbose::Boo push!(derived_parameters_xprob_volatility, (node_name, "volatility")) push!(derived_parameters_xprob_drift, (node_name, "drift")) push!( - derived_parameters_xprob_autoregression_strength, - (node_name, "autoregression_strength"), - ) - push!( - derived_parameters_xprob_autoregression_target, - (node_name, "autoregression_target"), + derived_parameters_xprob_autoconnection_strength, + (node_name, "autoconnection_strength"), ) end @@ -154,8 +146,7 @@ function premade_categorical_3level_state_transitions(config::Dict; verbose::Boo name = "xvol", volatility = config[("xvol", "volatility")], drift = config[("xvol", "drift")], - autoregression_target = config[("xvol", "autoregression_target")], - autoregression_strength = config[("xvol", "autoregression_strength")], + autoconnection_strength = config[("xvol", "autoconnection_strength")], initial_mean = config[("xvol", "initial_mean")], 
initial_precision = config[("xvol", "initial_precision")], ), @@ -233,14 +224,9 @@ function premade_categorical_3level_state_transitions(config::Dict; verbose::Boo shared_parameters["xprob_drift"] = (config[("xprob", "drift")], derived_parameters_xprob_drift) - shared_parameters["xprob_autoregression_strength"] = ( - config[("xprob", "autoregression_strength")], - derived_parameters_xprob_autoregression_strength, - ) - - shared_parameters["xprob_autoregression_target"] = ( - config[("xprob", "autoregression_target")], - derived_parameters_xprob_autoregression_target, + shared_parameters["autoconnection_strength"] = ( + config[("xprob", "autoconnection_strength")], + derived_parameters_xprob_autoconnection_strength, ) shared_parameters["xbin_xprob_coupling_strength"] = ( diff --git a/src/premade_models/premade_hgfs/premade_continuous_2level.jl b/src/premade_models/premade_hgfs/premade_continuous_2level.jl index 7eb03c3..4421bac 100644 --- a/src/premade_models/premade_hgfs/premade_continuous_2level.jl +++ b/src/premade_models/premade_hgfs/premade_continuous_2level.jl @@ -20,16 +20,15 @@ function premade_continuous_2level(config::Dict; verbose::Bool = true) #Defaults spec_defaults = Dict( ("u", "input_noise") => -2, + ("u", "bias") => 0, ("x", "volatility") => -2, ("x", "drift") => 0, - ("x", "autoregression_target") => 0, - ("x", "autoregression_strength") => 0, + ("x", "autoconnection_strength") => 1, ("x", "initial_mean") => 0, ("x", "initial_precision") => 1, ("xvol", "volatility") => -2, ("xvol", "drift") => 0, - ("xvol", "autoregression_target") => 0, - ("xvol", "autoregression_strength") => 0, + ("xvol", "autoconnection_strength") => 1, ("xvol", "initial_mean") => 0, ("xvol", "initial_precision") => 1, ("x", "xvol", "coupling_strength") => 1, @@ -46,13 +45,16 @@ function premade_continuous_2level(config::Dict; verbose::Bool = true) #List of nodes nodes = [ - ContinuousInput(name = "u", input_noise = config[("u", "input_noise")]), + ContinuousInput( + name = 
"u", + input_noise = config[("u", "input_noise")], + bias = config[("u", "bias")], + ), ContinuousState( name = "x", volatility = config[("x", "volatility")], drift = config[("x", "drift")], - autoregression_target = config[("x", "autoregression_target")], - autoregression_strength = config[("x", "autoregression_strength")], + autoconnection_strength = config[("x", "autoconnection_strength")], initial_mean = config[("x", "initial_mean")], initial_precision = config[("x", "initial_precision")], ), @@ -60,8 +62,7 @@ function premade_continuous_2level(config::Dict; verbose::Bool = true) name = "xvol", volatility = config[("xvol", "volatility")], drift = config[("xvol", "drift")], - autoregression_target = config[("xvol", "autoregression_target")], - autoregression_strength = config[("xvol", "autoregression_strength")], + autoconnection_strength = config[("xvol", "autoconnection_strength")], initial_mean = config[("xvol", "initial_mean")], initial_precision = config[("xvol", "initial_precision")], ), diff --git a/src/update_hgf/node_updates/continuous_input_node.jl b/src/update_hgf/node_updates/continuous_input_node.jl index 8c4d261..a2bec4d 100644 --- a/src/update_hgf/node_updates/continuous_input_node.jl +++ b/src/update_hgf/node_updates/continuous_input_node.jl @@ -28,7 +28,7 @@ function calculate_prediction_mean(node::ContinuousInputNode) observation_parents = node.edges.observation_parents #Initialize prediction at 0 - prediction_mean = 0 + prediction_mean = node.parameters.bias #Sum the predictions of the parents for parent in observation_parents diff --git a/src/update_hgf/node_updates/continuous_state_node.jl b/src/update_hgf/node_updates/continuous_state_node.jl index 58fe6cf..e260d71 100644 --- a/src/update_hgf/node_updates/continuous_state_node.jl +++ b/src/update_hgf/node_updates/continuous_state_node.jl @@ -51,11 +51,8 @@ function calculate_prediction_mean(node::ContinuousStateNode) #Get out value parents drift_parents = node.edges.drift_parents - 
#Initialize the total drift as the baseline drift plus the autoregression drift - predicted_drift = - node.parameters.drift + - node.parameters.autoregression_strength * - (node.parameters.autoregression_target - node.states.posterior_mean) + #Initialize the total drift as the baseline drift + predicted_drift = node.parameters.drift #Add contributions from value parents for parent in drift_parents @@ -64,7 +61,9 @@ function calculate_prediction_mean(node::ContinuousStateNode) end #Add the drift to the posterior to get the prediction mean - prediction_mean = node.states.posterior_mean + 1 * predicted_drift + prediction_mean = + node.parameters.autoconnection_strength * node.states.posterior_mean + + 1 * predicted_drift return prediction_mean end From f21a9a7588bc7225b8da26c9c492e4e43ab18d38 Mon Sep 17 00:00:00 2001 From: Peter Thestrup Waade Date: Sun, 7 Apr 2024 23:50:19 +0200 Subject: [PATCH 09/16] minor --- src/update_hgf/update_hgf.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/update_hgf/update_hgf.jl b/src/update_hgf/update_hgf.jl index 0af64b8..5c0cf16 100644 --- a/src/update_hgf/update_hgf.jl +++ b/src/update_hgf/update_hgf.jl @@ -83,7 +83,7 @@ Set input values in input nodes. 
Can either take a single value, a vector of val function enter_node_inputs!(hgf::HGF, input::Union{Real,Missing}) #Update the input node by passing the specified input to it - update_node_input!(hgf.ordered_nodes.input_nodes[1], input) + update_node_input!(first(hgf.ordered_nodes.input_nodes), input) return nothing end @@ -93,7 +93,7 @@ function enter_node_inputs!(hgf::HGF, inputs::Vector{<:Union{Real,Missing}}) #If the vector of inputs only contain a single input if length(inputs) == 1 #Just input that into the first input node - update_node_input!(hgf.ordered_nodes.input_nodes[1], inputs[1]) + enter_node_inputs!(hgf, first(inputs)) else From 1316109704bc9b36cd97ccbdccedd94b2ca5f4fb Mon Sep 17 00:00:00 2001 From: Peter Thestrup Waade Date: Fri, 12 Apr 2024 14:49:33 +0200 Subject: [PATCH 10/16] added timesteps and change to effective prediction precision --- docs/julia_files/tutorials/classic_binary.jl | 4 - docs/julia_files/tutorials/classic_usdchf.jl | 2 +- .../user_guide/utility_functions.jl | 5 +- .../core/plot_trajectory.jl | 7 +- .../utils/get_states.jl | 2 +- .../utils/give_inputs.jl | 49 ++++++++-- src/ActionModels_variations/utils/reset.jl | 6 +- src/create_hgf/hgf_structs.jl | 7 +- src/create_hgf/init_hgf.jl | 1 + .../node_updates/binary_input_node.jl | 2 +- .../node_updates/binary_state_node.jl | 2 +- .../node_updates/categorical_input_node.jl | 2 +- .../node_updates/categorical_state_node.jl | 2 +- .../node_updates/continuous_input_node.jl | 2 +- .../node_updates/continuous_state_node.jl | 90 +++++++------------ src/update_hgf/update_hgf.jl | 10 ++- src/utils/get_prediction.jl | 41 ++++----- 17 files changed, 118 insertions(+), 116 deletions(-) diff --git a/docs/julia_files/tutorials/classic_binary.jl b/docs/julia_files/tutorials/classic_binary.jl index b53c618..dddd856 100644 --- a/docs/julia_files/tutorials/classic_binary.jl +++ b/docs/julia_files/tutorials/classic_binary.jl @@ -11,9 +11,6 @@ using Plots using StatsPlots using Distributions 
-CSV.read(pwd(), DataFrame) -print(pwd()) - # Get the path for the HGF superfolder hgf_path = dirname(dirname(pathof(HierarchicalGaussianFiltering))) # Add the path to the data files @@ -50,7 +47,6 @@ actions = give_inputs!(agent, inputs); plot_trajectory(agent, ("u", "input_value")) plot_trajectory!(agent, ("xbin", "prediction")) - # - plot_trajectory(agent, ("xprob", "posterior")) diff --git a/docs/julia_files/tutorials/classic_usdchf.jl b/docs/julia_files/tutorials/classic_usdchf.jl index d783fc1..f75ce59 100644 --- a/docs/julia_files/tutorials/classic_usdchf.jl +++ b/docs/julia_files/tutorials/classic_usdchf.jl @@ -111,7 +111,7 @@ fitted_model = fit_model( actions, fixed_parameters = fixed_parameters, verbose = false, - n_iterations = 4000, + n_iterations = 10, ) #- # Plot the chains diff --git a/docs/julia_files/user_guide/utility_functions.jl b/docs/julia_files/user_guide/utility_functions.jl index 434f650..3e2bd21 100644 --- a/docs/julia_files/user_guide/utility_functions.jl +++ b/docs/julia_files/user_guide/utility_functions.jl @@ -48,10 +48,7 @@ get_states(agent, ("xprob", "posterior_precision")) #getting multiple states get_states( agent, - [ - ("xprob", "posterior_precision"), - ("xprob", "volatility_weighted_prediction_precision"), - ], + [("xprob", "posterior_precision"), ("xprob", "effective_prediction_precision")], ) diff --git a/src/ActionModels_variations/core/plot_trajectory.jl b/src/ActionModels_variations/core/plot_trajectory.jl index 110f07c..d616da5 100644 --- a/src/ActionModels_variations/core/plot_trajectory.jl +++ b/src/ActionModels_variations/core/plot_trajectory.jl @@ -160,6 +160,9 @@ end #Get the node node = hgf.all_nodes[node_name] + #Get the timesteps + timesteps = hgf.timesteps + #If the entire distribution is to be plotted if state_name in ["posterior", "prediction"] && !(node isa CategoricalStateNode) @@ -187,7 +190,7 @@ end end #Plot the history of means - history_mean + (timesteps, history_mean) end #If single state is specified 
@@ -227,7 +230,7 @@ end title --> "State trajectory" #Plot the history - state_history + (timesteps, state_history) end end end diff --git a/src/ActionModels_variations/utils/get_states.jl b/src/ActionModels_variations/utils/get_states.jl index ef302a9..02db849 100644 --- a/src/ActionModels_variations/utils/get_states.jl +++ b/src/ActionModels_variations/utils/get_states.jl @@ -44,7 +44,7 @@ function ActionModels.get_states(node::AbstractNode, state_name::String) "prediction", "prediction_mean", "prediction_precision", - "volatility_weighted_prediction_precision", + "effective_prediction_precision", ] #Get the new prediction prediction = get_prediction(node) diff --git a/src/ActionModels_variations/utils/give_inputs.jl b/src/ActionModels_variations/utils/give_inputs.jl index de013d6..a34e05a 100644 --- a/src/ActionModels_variations/utils/give_inputs.jl +++ b/src/ActionModels_variations/utils/give_inputs.jl @@ -5,26 +5,49 @@ Give inputs to an agent. Input can be a single value, a vector of values, or an """ function ActionModels.give_inputs!() end -function ActionModels.give_inputs!(hgf::HGF, inputs::Real) + +### Giving a single input ### +function ActionModels.give_inputs!(hgf::HGF, inputs::Real; stepsizes::Real = 1) #Input the value to the hgf - update_hgf!(hgf, inputs) + update_hgf!(hgf, inputs, stepsize = stepsizes) return nothing end -function ActionModels.give_inputs!(hgf::HGF, inputs::Vector) +### Giving a vector of inputs ### +function ActionModels.give_inputs!( + hgf::HGF, + inputs::Vector; + stepsizes::Union{Real,Vector} = 1, +) + + #Create vector of stepsizes + if stepsizes isa Real + stepsizes = fill(stepsizes, length(inputs)) + end + + #Check that inputs and stepsizes are the same length + if length(inputs) != length(stepsizes) + @error "The number of inputs and stepsizes must be the same." 
+ end #Each entry in the vector is an input - for input in inputs + for (input, stepsize) in zip(inputs, stepsizes) #Input it to the hgf - update_hgf!(hgf, input) + update_hgf!(hgf, input; stepsize = stepsize) end return nothing end -function ActionModels.give_inputs!(hgf::HGF, inputs::Array) + +### Giving a matrix of inputs (multiple per timestep) ### +function ActionModels.give_inputs!( + hgf::HGF, + inputs::Array; + stepsizes::Union{Real,Vector} = 1, +) #If number of column in input is diffferent from amount of input nodes if size(inputs, 2) != length(hgf.input_nodes) @@ -36,10 +59,20 @@ function ActionModels.give_inputs!(hgf::HGF, inputs::Array) ) end + #Create vector of stepsizes + if stepsizes isa Real + stepsizes = fill(stepsizes, length(inputs)) + end + + #Check that inputs and stepsizes are the same length + if size(inputs, 1) != length(stepsizes) + @error "The number of inputs and stepsizes must be the same." + end + #Take each row in the array - for input in eachrow(inputs) + for (input, stepsize) in zip(eachrow(inputs), stepsizes) #Input it to the hgf - update_hgf!(hgf, Vector(input)) + update_hgf!(hgf, Vector(input), stepsize = stepsize) end return nothing diff --git a/src/ActionModels_variations/utils/reset.jl b/src/ActionModels_variations/utils/reset.jl index 65ca1c3..a614628 100644 --- a/src/ActionModels_variations/utils/reset.jl +++ b/src/ActionModels_variations/utils/reset.jl @@ -5,6 +5,9 @@ Reset an HGF to its initial state. 
""" function ActionModels.reset!(hgf::HGF) + #Reset the timesteps for the HGF + hgf.timesteps = [0] + #Go through each node for node in hgf.ordered_nodes.all_nodes @@ -33,9 +36,8 @@ function reset_state!(node::ContinuousStateNode) node.states.precision_prediction_error = missing node.states.prediction_mean = missing - node.states.predicted_volatility = missing node.states.prediction_precision = missing - node.states.volatility_weighted_prediction_precision = missing + node.states.effective_prediction_precision = missing return nothing end diff --git a/src/create_hgf/hgf_structs.jl b/src/create_hgf/hgf_structs.jl index 6b79cd2..13f75ba 100644 --- a/src/create_hgf/hgf_structs.jl +++ b/src/create_hgf/hgf_structs.jl @@ -81,6 +81,7 @@ Base.@kwdef mutable struct HGF state_nodes::Dict{String,AbstractStateNode} ordered_nodes::OrderedNodes = OrderedNodes() shared_parameters::Dict = Dict() + timesteps::Vector{Real} = [0] end ################################## @@ -171,9 +172,8 @@ Base.@kwdef mutable struct ContinuousStateNodeState value_prediction_error::Union{Real,Missing} = missing precision_prediction_error::Union{Real,Missing} = missing prediction_mean::Union{Real,Missing} = missing - predicted_volatility::Union{Real,Missing} = missing prediction_precision::Union{Real,Missing} = missing - volatility_weighted_prediction_precision::Union{Real,Missing} = missing + effective_prediction_precision::Union{Real,Missing} = missing end """ @@ -185,9 +185,8 @@ Base.@kwdef mutable struct ContinuousStateNodeHistory value_prediction_error::Vector{Union{Real,Missing}} = [] precision_prediction_error::Vector{Union{Real,Missing}} = [] prediction_mean::Vector{Union{Real,Missing}} = [] - predicted_volatility::Vector{Union{Real,Missing}} = [] prediction_precision::Vector{Union{Real,Missing}} = [] - volatility_weighted_prediction_precision::Vector{Union{Real,Missing}} = [] + effective_prediction_precision::Vector{Union{Real,Missing}} = [] end """ diff --git a/src/create_hgf/init_hgf.jl 
b/src/create_hgf/init_hgf.jl index 93ccaca..01ee21b 100644 --- a/src/create_hgf/init_hgf.jl +++ b/src/create_hgf/init_hgf.jl @@ -214,6 +214,7 @@ function init_hgf(; state_nodes_dict, ordered_nodes, shared_parameters_dict, + [0], ) ### Check that the HGF has been specified properly ### diff --git a/src/update_hgf/node_updates/binary_input_node.jl b/src/update_hgf/node_updates/binary_input_node.jl index 3e05c18..a6b77d1 100644 --- a/src/update_hgf/node_updates/binary_input_node.jl +++ b/src/update_hgf/node_updates/binary_input_node.jl @@ -8,7 +8,7 @@ There is no prediction update for binary input nodes, as the prediction precision is constant. """ -function update_node_prediction!(node::BinaryInputNode) +function update_node_prediction!(node::BinaryInputNode, stepsize::Real) return nothing end diff --git a/src/update_hgf/node_updates/binary_state_node.jl b/src/update_hgf/node_updates/binary_state_node.jl index 95fea07..400caf6 100644 --- a/src/update_hgf/node_updates/binary_state_node.jl +++ b/src/update_hgf/node_updates/binary_state_node.jl @@ -8,7 +8,7 @@ Update the prediction of a single binary state node. """ -function update_node_prediction!(node::BinaryStateNode) +function update_node_prediction!(node::BinaryStateNode, stepsize::Real) #Update prediction mean node.states.prediction_mean = calculate_prediction_mean(node) diff --git a/src/update_hgf/node_updates/categorical_input_node.jl b/src/update_hgf/node_updates/categorical_input_node.jl index 80fa786..1c2e5f6 100644 --- a/src/update_hgf/node_updates/categorical_input_node.jl +++ b/src/update_hgf/node_updates/categorical_input_node.jl @@ -8,7 +8,7 @@ There is no prediction update for categorical input nodes, as the prediction precision is constant. 
""" -function update_node_prediction!(node::CategoricalInputNode) +function update_node_prediction!(node::CategoricalInputNode, stepsize::Real) return nothing end diff --git a/src/update_hgf/node_updates/categorical_state_node.jl b/src/update_hgf/node_updates/categorical_state_node.jl index 458daba..4c751fc 100644 --- a/src/update_hgf/node_updates/categorical_state_node.jl +++ b/src/update_hgf/node_updates/categorical_state_node.jl @@ -8,7 +8,7 @@ Update the prediction of a single categorical state node. """ -function update_node_prediction!(node::CategoricalStateNode) +function update_node_prediction!(node::CategoricalStateNode, stepsize::Real) #Update prediction mean node.states.prediction, node.states.parent_predictions = calculate_prediction(node) diff --git a/src/update_hgf/node_updates/continuous_input_node.jl b/src/update_hgf/node_updates/continuous_input_node.jl index a2bec4d..510e1f1 100644 --- a/src/update_hgf/node_updates/continuous_input_node.jl +++ b/src/update_hgf/node_updates/continuous_input_node.jl @@ -8,7 +8,7 @@ Update the posterior of a single input node. """ -function update_node_prediction!(node::ContinuousInputNode) +function update_node_prediction!(node::ContinuousInputNode, stepsize::Real) #Update node prediction mean node.states.prediction_mean = calculate_prediction_mean(node) diff --git a/src/update_hgf/node_updates/continuous_state_node.jl b/src/update_hgf/node_updates/continuous_state_node.jl index e260d71..20b2f73 100644 --- a/src/update_hgf/node_updates/continuous_state_node.jl +++ b/src/update_hgf/node_updates/continuous_state_node.jl @@ -8,36 +8,24 @@ Update the prediction of a single state node. 
""" -function update_node_prediction!(node::ContinuousStateNode) +function update_node_prediction!(node::ContinuousStateNode, stepsize::Real) #Update prediction mean - node.states.prediction_mean = calculate_prediction_mean(node) + node.states.prediction_mean = calculate_prediction_mean(node, stepsize) push!(node.history.prediction_mean, node.states.prediction_mean) - #Update prediction volatility - node.states.predicted_volatility = calculate_predicted_volatility(node) - push!(node.history.predicted_volatility, node.states.predicted_volatility) - #Update prediction precision - node.states.prediction_precision = calculate_prediction_precision(node) + node.states.prediction_precision, node.states.effective_prediction_precision = + calculate_prediction_precision(node, stepsize) push!(node.history.prediction_precision, node.states.prediction_precision) - - #Get auxiliary prediction precision, only if there are volatility children and/or volatility parents - if length(node.edges.volatility_parents) > 0 || - length(node.edges.volatility_children) > 0 || - length(node.edges.noise_children) > 0 - node.states.volatility_weighted_prediction_precision = - calculate_volatility_weighted_prediction_precision(node) - push!( - node.history.volatility_weighted_prediction_precision, - node.states.volatility_weighted_prediction_precision, - ) - end + push!( + node.history.effective_prediction_precision, + node.states.effective_prediction_precision, + ) return nothing end - ##### Mean update ##### @doc raw""" calculate_prediction_mean(node::AbstractNode) @@ -47,7 +35,7 @@ Calculates a node's prediction mean. 
Uses the equation `` \hat{\mu}_i=\mu_i+\sum_{j=1}^{j\;value\;parents} \mu_{j} \cdot \alpha_{i,j} `` """ -function calculate_prediction_mean(node::ContinuousStateNode) +function calculate_prediction_mean(node::ContinuousStateNode, stepsize::Real) #Get out value parents drift_parents = node.edges.drift_parents @@ -60,48 +48,47 @@ function calculate_prediction_mean(node::ContinuousStateNode) parent.states.posterior_mean * node.parameters.coupling_strengths[parent.name] end + #Multiply with stepsize + predicted_drift = stepsize * predicted_drift + #Add the drift to the posterior to get the prediction mean prediction_mean = node.parameters.autoconnection_strength * node.states.posterior_mean + - 1 * predicted_drift + predicted_drift return prediction_mean end -##### Predicted volatility update ##### +##### Precision update ##### @doc raw""" - calculate_predicted_volatility(node::AbstractNode) + calculate_prediction_precision(node::AbstractNode) -Calculates a node's prediction volatility. +Calculates a node's prediction precision. Uses the equation -`` \nu_i =exp( \omega_i + \sum_{j=1}^{j\;volatility\;parents} \mu_{j} \cdot \kappa_{i,j}} `` +`` \hat{\pi}_i^ = `` """ -function calculate_predicted_volatility(node::ContinuousStateNode) +function calculate_prediction_precision(node::ContinuousStateNode, stepsize::Real) + #Extract volatility parents volatility_parents = node.edges.volatility_parents + #Initialize the predicted volatility as the baseline volatility predicted_volatility = node.parameters.volatility + #Add contributions from volatility parents for parent in volatility_parents predicted_volatility += parent.states.posterior_mean * node.parameters.coupling_strengths[parent.name] end - return exp(predicted_volatility) -end - -##### Precision update ##### -@doc raw""" - calculate_prediction_precision(node::AbstractNode) + #Exponentiate and multiply with stepsize + predicted_volatility = stepsize * exp(predicted_volatility) -Calculates a node's prediction precision. 
+ #Calculate prediction precision + prediction_precision = 1 / (1 / node.states.posterior_precision + predicted_volatility) -Uses the equation -`` \hat{\pi}_i^ = \frac{1}{\frac{1}{\pi_i}+\nu_i^} `` -""" -function calculate_prediction_precision(node::ContinuousStateNode) - prediction_precision = - 1 / (1 / node.states.posterior_precision + node.states.predicted_volatility) + #Calculate the volatility-weighted effective precision + effective_prediction_precision = predicted_volatility * prediction_precision #If the posterior precision is negative if prediction_precision < 0 @@ -114,20 +101,9 @@ function calculate_prediction_precision(node::ContinuousStateNode) ) end - return prediction_precision + return prediction_precision, effective_prediction_precision end -@doc raw""" - calculate_volatility_weighted_prediction_precision(node::AbstractNode) - -Calculates a node's auxiliary prediction precision. - -Uses the equation -`` \gamma_i = \nu_i \cdot \hat{\pi}_i `` -""" -function calculate_volatility_weighted_prediction_precision(node::ContinuousStateNode) - node.states.predicted_volatility * node.states.prediction_precision -end ################################## ######## Update posterior ######## @@ -287,16 +263,16 @@ function calculate_posterior_precision_increment( 1 / 2 * ( child.parameters.coupling_strengths[node.name] * - child.states.volatility_weighted_prediction_precision + child.states.effective_prediction_precision )^2 + child.states.precision_prediction_error * ( child.parameters.coupling_strengths[node.name] * - child.states.volatility_weighted_prediction_precision + child.states.effective_prediction_precision )^2 - 1 / 2 * child.parameters.coupling_strengths[node.name]^2 * - child.states.volatility_weighted_prediction_precision * + child.states.effective_prediction_precision * child.states.precision_prediction_error end @@ -490,7 +466,7 @@ function calculate_posterior_mean_increment( ) 1 / 2 * ( child.parameters.coupling_strengths[node.name] * - 
child.states.volatility_weighted_prediction_precision + child.states.effective_prediction_precision ) / node.states.posterior_precision * child.states.precision_prediction_error end @@ -503,7 +479,7 @@ function calculate_posterior_mean_increment( ) 1 / 2 * ( child.parameters.coupling_strengths[node.name] * - child.states.volatility_weighted_prediction_precision + child.states.effective_prediction_precision ) / node.states.prediction_precision * child.states.precision_prediction_error end diff --git a/src/update_hgf/update_hgf.jl b/src/update_hgf/update_hgf.jl index 5c0cf16..4cba09b 100644 --- a/src/update_hgf/update_hgf.jl +++ b/src/update_hgf/update_hgf.jl @@ -18,19 +18,23 @@ function update_hgf!( Missing, Vector{<:Union{Real,Missing}}, Dict{String,<:Union{Real,Missing}}, - }, + }; + stepsize::Real = 1, ) ## Update node predictions from last timestep + #Update the timepoint + push!(hgf.timesteps, hgf.timesteps[end] + stepsize) + #For each node (in the opposite update order) for node in reverse(hgf.ordered_nodes.all_state_nodes) #Update its prediction from last trial - update_node_prediction!(node) + update_node_prediction!(node, stepsize) end #For each input node, in the specified update order for node in reverse(hgf.ordered_nodes.input_nodes) #Update its prediction from last trial - update_node_prediction!(node) + update_node_prediction!(node, stepsize) end ## Supply inputs to input nodes diff --git a/src/utils/get_prediction.jl b/src/utils/get_prediction.jl index 9151d40..ecdc59a 100644 --- a/src/utils/get_prediction.jl +++ b/src/utils/get_prediction.jl @@ -6,61 +6,52 @@ A single node can also be passed. 
""" function get_prediction end -function get_prediction(agent::Agent, node_name::String) +function get_prediction(agent::Agent, node_name::String, stepsize::Real = 1) #Get prediction from the HGF - prediction = get_prediction(agent.substruct, node_name) + prediction = get_prediction(agent.substruct, node_name, stepsize) return prediction end -function get_prediction(hgf::HGF, node_name::String) +function get_prediction(hgf::HGF, node_name::String, stepsize::Real = 1) #Get the prediction of the given node - return get_prediction(hgf.all_nodes[node_name]) + return get_prediction(hgf.all_nodes[node_name], stepsize) end ### Single node functions ### -function get_prediction(node::ContinuousStateNode) +function get_prediction(node::ContinuousStateNode, stepsize::Real = 1) #Save old states old_states = (; prediction_mean = node.states.prediction_mean, - predicted_volatility = node.states.predicted_volatility, prediction_precision = node.states.prediction_precision, - volatility_weighted_prediction_precision = node.states.volatility_weighted_prediction_precision, + effective_prediction_precision = node.states.effective_prediction_precision, ) #Update prediction mean - node.states.prediction_mean = calculate_prediction_mean(node) - - #Update prediction volatility - node.states.predicted_volatility = calculate_predicted_volatility(node) + node.states.prediction_mean = calculate_prediction_mean(node, stepsize) #Update prediction precision - node.states.prediction_precision = calculate_prediction_precision(node) - - node.states.volatility_weighted_prediction_precision = - calculate_volatility_weighted_prediction_precision(node) + node.states.prediction_precision, node.states.effective_prediction_precision = + calculate_prediction_precision(node, stepsize) #Save new states new_states = (; prediction_mean = node.states.prediction_mean, - predicted_volatility = node.states.predicted_volatility, prediction_precision = node.states.prediction_precision, - 
volatility_weighted_prediction_precision = node.states.volatility_weighted_prediction_precision, + effective_prediction_precision = node.states.effective_prediction_precision, ) #Change states back to the old states node.states.prediction_mean = old_states.prediction_mean - node.states.predicted_volatility = old_states.predicted_volatility node.states.prediction_precision = old_states.prediction_precision - node.states.volatility_weighted_prediction_precision = - old_states.volatility_weighted_prediction_precision + node.states.effective_prediction_precision = old_states.effective_prediction_precision return new_states end -function get_prediction(node::BinaryStateNode) +function get_prediction(node::BinaryStateNode, stepsize::Real = 1) #Save old states old_states = (; @@ -87,7 +78,7 @@ function get_prediction(node::BinaryStateNode) return new_states end -function get_prediction(node::CategoricalStateNode) +function get_prediction(node::CategoricalStateNode, stepsize::Real = 1) #Save old states old_states = (; prediction = node.states.prediction) @@ -105,7 +96,7 @@ function get_prediction(node::CategoricalStateNode) end -function get_prediction(node::ContinuousInputNode) +function get_prediction(node::ContinuousInputNode, stepsize::Real = 1) #Save old states old_states = (; @@ -130,7 +121,7 @@ function get_prediction(node::ContinuousInputNode) return new_states end -function get_prediction(node::BinaryInputNode) +function get_prediction(node::BinaryInputNode, stepsize::Real = 1) #Binary input nodes have no prediction states new_states = (;) @@ -138,7 +129,7 @@ function get_prediction(node::BinaryInputNode) return new_states end -function get_prediction(node::CategoricalInputNode) +function get_prediction(node::CategoricalInputNode, stepsize::Real = 1) #Binary input nodes have no prediction states new_states = (;) From c90877ea3dd9e2d19bd4cf3585bce95224fcae72 Mon Sep 17 00:00:00 2001 From: Peter Thestrup Waade Date: Fri, 12 Apr 2024 14:53:19 +0200 Subject: [PATCH 
11/16] minor update to test --- test/testsuite/test_fit_model.jl | 28 ---------------------------- 1 file changed, 28 deletions(-) diff --git a/test/testsuite/test_fit_model.jl b/test/testsuite/test_fit_model.jl index 082acbd..4d6698c 100644 --- a/test/testsuite/test_fit_model.jl +++ b/test/testsuite/test_fit_model.jl @@ -50,20 +50,6 @@ using Turing ) @test fitted_model isa Turing.Chains - #Fit with multiple chains and HMC - fitted_model = fit_model( - test_agent, - test_param_priors, - test_input, - test_responses; - fixed_parameters = test_fixed_parameters, - sampler = HMC(0.01, 5), - n_chains = 4, - verbose = false, - n_iterations = 10, - ) - @test fitted_model isa Turing.Chains - #Plot the parameter distribution plot_parameter_distribution(fitted_model, test_param_priors) @@ -121,20 +107,6 @@ using Turing ) @test fitted_model isa Turing.Chains - #Fit with multiple chains and HMC - fitted_model = fit_model( - test_agent, - test_param_priors, - test_input, - test_responses; - fixed_parameters = test_fixed_parameters, - sampler = HMC(0.01, 5), - n_chains = 4, - verbose = false, - n_iterations = 10, - ) - @test fitted_model isa Turing.Chains - #Plot the parameter distribution plot_parameter_distribution(fitted_model, test_param_priors) From 27d5659817c3e4f51a9ab7a9391bf1a27ed228e9 Mon Sep 17 00:00:00 2001 From: Peter Thestrup Waade Date: Fri, 12 Apr 2024 15:04:47 +0200 Subject: [PATCH 12/16] minor bug --- src/ActionModels_variations/utils/give_inputs.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ActionModels_variations/utils/give_inputs.jl b/src/ActionModels_variations/utils/give_inputs.jl index a34e05a..4720e54 100644 --- a/src/ActionModels_variations/utils/give_inputs.jl +++ b/src/ActionModels_variations/utils/give_inputs.jl @@ -61,7 +61,7 @@ function ActionModels.give_inputs!( #Create vector of stepsizes if stepsizes isa Real - stepsizes = fill(stepsizes, length(inputs)) + stepsizes = fill(stepsizes, size(inputs, 1)) end #Check that 
inputs and stepsizes are the same length From ef1a88f2b5ddedb65e80cf7f11a29b582b0f67fa Mon Sep 17 00:00:00 2001 From: Peter Thestrup Waade Date: Fri, 12 Apr 2024 17:21:44 +0200 Subject: [PATCH 13/16] minor name change to premade hgfs + added save_history --- README.md | 2 +- docs/julia_files/index.jl | 2 +- docs/julia_files/tutorials/classic_JGET.jl | 2 +- docs/julia_files/tutorials/classic_binary.jl | 3 +- docs/julia_files/tutorials/classic_usdchf.jl | 2 +- .../julia_files/user_guide/building_an_HGF.jl | 4 +-- .../user_guide/fitting_hgf_models.jl | 3 +- docs/julia_files/user_guide/premade_HGF.jl | 8 ++--- docs/julia_files/user_guide/premade_models.jl | 4 +-- .../user_guide/utility_functions.jl | 4 +-- src/HierarchicalGaussianFiltering.jl | 18 +++++----- src/create_hgf/hgf_structs.jl | 1 + src/create_hgf/init_hgf.jl | 2 ++ ...gaussian_action.jl => premade_gaussian.jl} | 6 ++-- .../premade_predict_category.jl | 12 +++---- ...e_sigmoid_action.jl => premade_sigmoid.jl} | 0 ...e_softmax_action.jl => premade_softmax.jl} | 0 .../premade_hgfs/premade_JGET.jl | 2 ++ .../premade_hgfs/premade_binary_2level.jl | 2 ++ .../premade_hgfs/premade_binary_3level.jl | 2 ++ .../premade_categorical_3level.jl | 2 ++ .../premade_categorical_transitions_3level.jl | 2 ++ .../premade_hgfs/premade_continuous_2level.jl | 2 ++ .../node_updates/binary_state_node.jl | 5 --- .../node_updates/categorical_state_node.jl | 5 +-- .../node_updates/continuous_input_node.jl | 17 ++++------ .../node_updates/continuous_state_node.jl | 27 +++++---------- src/update_hgf/update_hgf.jl | 34 +++++++++++++------ test/testsuite/test_fit_model.jl | 4 +-- test/testsuite/test_premade_agent.jl | 12 +++---- 30 files changed, 97 insertions(+), 92 deletions(-) rename src/premade_models/premade_agents/{premade_gaussian_action.jl => premade_gaussian.jl} (95%) rename src/premade_models/premade_agents/{premade_sigmoid_action.jl => premade_sigmoid.jl} (100%) rename src/premade_models/premade_agents/{premade_softmax_action.jl 
=> premade_softmax.jl} (100%) diff --git a/README.md b/README.md index 95ccc8d..255aefc 100644 --- a/README.md +++ b/README.md @@ -38,7 +38,7 @@ premade_agent("help") ### Create agent ````@example index -agent = premade_agent("hgf_binary_softmax_action") +agent = premade_agent("hgf_binary_softmax") ```` ### Get states and parameters diff --git a/docs/julia_files/index.jl b/docs/julia_files/index.jl index bef750d..6d5a5d8 100644 --- a/docs/julia_files/index.jl +++ b/docs/julia_files/index.jl @@ -24,7 +24,7 @@ using ActionModels premade_agent("help") # ### Create agent -agent = premade_agent("hgf_binary_softmax_action") +agent = premade_agent("hgf_binary_softmax") # ### Get states and parameters get_states(agent) diff --git a/docs/julia_files/tutorials/classic_JGET.jl b/docs/julia_files/tutorials/classic_JGET.jl index 1d1fb27..40e6303 100644 --- a/docs/julia_files/tutorials/classic_JGET.jl +++ b/docs/julia_files/tutorials/classic_JGET.jl @@ -16,7 +16,7 @@ inputs = data[(data.ID.==20).&(data.session.==1), :].outcome #Create HGF hgf = premade_hgf("JGET", verbose = false) #Create agent -agent = premade_agent("hgf_gaussian_action", hgf) +agent = premade_agent("hgf_gaussian", hgf) #Set parameters parameters = Dict( "action_noise" => 1, diff --git a/docs/julia_files/tutorials/classic_binary.jl b/docs/julia_files/tutorials/classic_binary.jl index dddd856..8be1e9d 100644 --- a/docs/julia_files/tutorials/classic_binary.jl +++ b/docs/julia_files/tutorials/classic_binary.jl @@ -37,8 +37,7 @@ hgf = premade_hgf("binary_3level", hgf_parameters, verbose = false); # Create an agent agent_parameters = Dict("action_noise" => 0.2); -agent = - premade_agent("hgf_unit_square_sigmoid_action", hgf, agent_parameters, verbose = false); +agent = premade_agent("hgf_unit_square_sigmoid", hgf, agent_parameters, verbose = false); # Evolve agent and save actions actions = give_inputs!(agent, inputs); diff --git a/docs/julia_files/tutorials/classic_usdchf.jl 
b/docs/julia_files/tutorials/classic_usdchf.jl index f75ce59..676fea2 100644 --- a/docs/julia_files/tutorials/classic_usdchf.jl +++ b/docs/julia_files/tutorials/classic_usdchf.jl @@ -24,7 +24,7 @@ end #Create HGF hgf = premade_hgf("continuous_2level", verbose = false); -agent = premade_agent("hgf_gaussian_action", hgf, verbose = false); +agent = premade_agent("hgf_gaussian", hgf, verbose = false); # Set parameters for parameter recover parameters = Dict( diff --git a/docs/julia_files/user_guide/building_an_HGF.jl b/docs/julia_files/user_guide/building_an_HGF.jl index 8cd6451..0329070 100644 --- a/docs/julia_files/user_guide/building_an_HGF.jl +++ b/docs/julia_files/user_guide/building_an_HGF.jl @@ -68,7 +68,7 @@ get_parameters(Binary_2_level_hgf) # We initialize the action model and create it. In a softmax action model we need a parameter from the agent called softmax action precision which is used in the update step of the action model. using Distributions -function binary_softmax_action(agent, input) +function binary_softmax(agent, input) ##------- Staty by getting all information --------- @@ -117,7 +117,7 @@ end # Let's define our action model -action_model = binary_softmax_action; +action_model = binary_softmax; # The parameter of the agent is just softmax action precision. 
We set this value to 1 diff --git a/docs/julia_files/user_guide/fitting_hgf_models.jl b/docs/julia_files/user_guide/fitting_hgf_models.jl index 9f89554..d72a56f 100644 --- a/docs/julia_files/user_guide/fitting_hgf_models.jl +++ b/docs/julia_files/user_guide/fitting_hgf_models.jl @@ -61,8 +61,7 @@ hgf = premade_hgf("binary_3level", hgf_parameters, verbose = false) # Create an agent agent_parameters = Dict("action_noise" => 0.2); -agent = - premade_agent("hgf_unit_square_sigmoid_action", hgf, agent_parameters, verbose = false); +agent = premade_agent("hgf_unit_square_sigmoid", hgf, agent_parameters, verbose = false); # Define a set of inputs inputs = diff --git a/docs/julia_files/user_guide/premade_HGF.jl b/docs/julia_files/user_guide/premade_HGF.jl index e4ea569..5e29ea8 100644 --- a/docs/julia_files/user_guide/premade_HGF.jl +++ b/docs/julia_files/user_guide/premade_HGF.jl @@ -57,7 +57,7 @@ inputs_binary = CSV.read(hgf_path_binary * "classic_binary_inputs.csv", DataFram #Create HGF and Agent continuous_2_level = premade_hgf("continuous_2level"); agent_continuous_2_level = - premade_agent("hgf_gaussian_action", continuous_2_level, verbose = false); + premade_agent("hgf_gaussian", continuous_2_level, verbose = false); # Evolve agent plot trajetories give_inputs!(agent_continuous_2_level, inputs_continuous); @@ -84,7 +84,7 @@ plot_trajectory( #Create HGF and Agent JGET = premade_hgf("JGET"); -agent_JGET = premade_agent("hgf_gaussian_action", JGET, verbose = false); +agent_JGET = premade_agent("hgf_gaussian", JGET, verbose = false); # Evolve agent plot trajetories give_inputs!(agent_JGET, inputs_continuous); @@ -112,7 +112,7 @@ hgf_binary_2_level = premade_hgf("binary_2level", verbose = false); # Create an agent agent_binary_2_level = - premade_agent("hgf_unit_square_sigmoid_action", hgf_binary_2_level, verbose = false); + premade_agent("hgf_unit_square_sigmoid", hgf_binary_2_level, verbose = false); # Evolve agent plot trajetories give_inputs!(agent_binary_2_level, 
inputs_binary); @@ -136,7 +136,7 @@ hgf_binary_3_level = premade_hgf("binary_3level", verbose = false); # Create an agent agent_binary_3_level = - premade_agent("hgf_unit_square_sigmoid_action", hgf_binary_3_level, verbose = false); + premade_agent("hgf_unit_square_sigmoid", hgf_binary_3_level, verbose = false); # Evolve agent plot trajetories give_inputs!(agent_binary_3_level, inputs_binary); diff --git a/docs/julia_files/user_guide/premade_models.jl b/docs/julia_files/user_guide/premade_models.jl index e75622b..ec6351c 100644 --- a/docs/julia_files/user_guide/premade_models.jl +++ b/docs/julia_files/user_guide/premade_models.jl @@ -12,7 +12,7 @@ # ## HGF with Gaussian Action Noise agent -# This premade agent model can be found as "hgf_gaussian_action" in the package. The Action distribution is a gaussian distribution with mean of the target state from the chosen HGF, and the standard deviation consisting of the action precision parameter inversed. #md +# This premade agent model can be found as "hgf_gaussian" in the package. The Action distribution is a gaussian distribution with mean of the target state from the chosen HGF, and the standard deviation consisting of the action precision parameter inversed. 
#md # - Default hgf: contionus_2level # - Default Target state: (x, posterior mean) @@ -53,7 +53,7 @@ using HierarchicalGaussianFiltering premade_agent("help") # Define an agent with default parameter values and default HGF -agent = premade_agent("hgf_binary_softmax_action") +agent = premade_agent("hgf_binary_softmax") # ## Utility functions for accessing parameters and states diff --git a/docs/julia_files/user_guide/utility_functions.jl b/docs/julia_files/user_guide/utility_functions.jl index 3e2bd21..cdd278c 100644 --- a/docs/julia_files/user_guide/utility_functions.jl +++ b/docs/julia_files/user_guide/utility_functions.jl @@ -20,7 +20,7 @@ using HierarchicalGaussianFiltering premade_agent("help") # set agent -agent = premade_agent("hgf_binary_softmax_action") +agent = premade_agent("hgf_binary_softmax") # ### Getting Parameters @@ -77,7 +77,7 @@ hgf_parameters = Dict( hgf = premade_hgf("binary_3level", hgf_parameters) # Define our agent with the HGF and agent parameter settings -agent = premade_agent("hgf_unit_square_sigmoid_action", hgf, agent_parameter) +agent = premade_agent("hgf_unit_square_sigmoid", hgf, agent_parameter) # Changing a single parameter diff --git a/src/HierarchicalGaussianFiltering.jl b/src/HierarchicalGaussianFiltering.jl index cce12b3..fc73469 100644 --- a/src/HierarchicalGaussianFiltering.jl +++ b/src/HierarchicalGaussianFiltering.jl @@ -5,7 +5,7 @@ using ActionModels, Distributions, RecipesBase #Export functions export init_node, init_hgf, premade_hgf, check_hgf, check_node, update_hgf! -export get_prediction, get_surprise, hgf_multiple_actions +export get_prediction, get_surprise export premade_agent, init_agent, plot_predictive_simulation, plot_trajectory, plot_trajectory! export get_history, get_parameters, get_states, set_parameters!, reset!, give_inputs! 
@@ -22,12 +22,10 @@ export DriftCoupling, #Add premade agents to shared dict at initialization function __init__() - ActionModels.premade_agents["hgf_gaussian_action"] = premade_hgf_gaussian - ActionModels.premade_agents["hgf_binary_softmax_action"] = premade_hgf_binary_softmax - ActionModels.premade_agents["hgf_unit_square_sigmoid_action"] = - premade_hgf_unit_square_sigmoid - ActionModels.premade_agents["hgf_predict_category_action"] = - premade_hgf_predict_category + ActionModels.premade_agents["hgf_gaussian"] = premade_hgf_gaussian + ActionModels.premade_agents["hgf_binary_softmax"] = premade_hgf_binary_softmax + ActionModels.premade_agents["hgf_unit_square_sigmoid"] = premade_hgf_unit_square_sigmoid + ActionModels.premade_agents["hgf_predict_category"] = premade_hgf_predict_category end #Types for HGFs @@ -61,10 +59,10 @@ include("create_hgf/create_premade_hgf.jl") #Plotting functions #Functions for premade agents -include("premade_models/premade_agents/premade_gaussian_action.jl") +include("premade_models/premade_agents/premade_gaussian.jl") include("premade_models/premade_agents/premade_predict_category.jl") -include("premade_models/premade_agents/premade_sigmoid_action.jl") -include("premade_models/premade_agents/premade_softmax_action.jl") +include("premade_models/premade_agents/premade_sigmoid.jl") +include("premade_models/premade_agents/premade_softmax.jl") include("premade_models/premade_hgfs/premade_binary_2level.jl") include("premade_models/premade_hgfs/premade_binary_3level.jl") diff --git a/src/create_hgf/hgf_structs.jl b/src/create_hgf/hgf_structs.jl index 13f75ba..da2feda 100644 --- a/src/create_hgf/hgf_structs.jl +++ b/src/create_hgf/hgf_structs.jl @@ -81,6 +81,7 @@ Base.@kwdef mutable struct HGF state_nodes::Dict{String,AbstractStateNode} ordered_nodes::OrderedNodes = OrderedNodes() shared_parameters::Dict = Dict() + save_history::Bool = true timesteps::Vector{Real} = [0] end diff --git a/src/create_hgf/init_hgf.jl b/src/create_hgf/init_hgf.jl 
index 01ee21b..b18c58d 100644 --- a/src/create_hgf/init_hgf.jl +++ b/src/create_hgf/init_hgf.jl @@ -79,6 +79,7 @@ function init_hgf(; shared_parameters::Dict = Dict(), update_order::Union{Nothing,Vector{String}} = nothing, verbose::Bool = true, + save_history::Bool = true, ) ### Initialize nodes ### @@ -214,6 +215,7 @@ function init_hgf(; state_nodes_dict, ordered_nodes, shared_parameters_dict, + save_history, [0], ) diff --git a/src/premade_models/premade_agents/premade_gaussian_action.jl b/src/premade_models/premade_agents/premade_gaussian.jl similarity index 95% rename from src/premade_models/premade_agents/premade_gaussian_action.jl rename to src/premade_models/premade_agents/premade_gaussian.jl index e5ff210..d672221 100644 --- a/src/premade_models/premade_agents/premade_gaussian_action.jl +++ b/src/premade_models/premade_agents/premade_gaussian.jl @@ -1,5 +1,5 @@ """ - hgf_gaussian_action(agent::Agent, input) + hgf_gaussian(agent::Agent, input) Action model which reports a given HGF state with Gaussian noise. 
@@ -7,7 +7,7 @@ In addition to the HGF substruct, the following must be present in the agent: Parameters: "gaussian_action_precision" Settings: "target_state" """ -function hgf_gaussian_action(agent::Agent, input) +function hgf_gaussian(agent::Agent, input) #Extract HGF, settings and parameters hgf = agent.substruct @@ -77,7 +77,7 @@ function premade_hgf_gaussian(config::Dict) ## Create agent #Set the action model - action_model = hgf_gaussian_action + action_model = hgf_gaussian #Set the HGF hgf = config["HGF"] diff --git a/src/premade_models/premade_agents/premade_predict_category.jl b/src/premade_models/premade_agents/premade_predict_category.jl index b98dde7..e6c2574 100644 --- a/src/premade_models/premade_agents/premade_predict_category.jl +++ b/src/premade_models/premade_agents/premade_predict_category.jl @@ -2,33 +2,33 @@ ###### Categorical Prediction Action ###### """ - update_hgf_predict_category_action(agent::Agent, input) + update_hgf_predict_category(agent::Agent, input) Action model that first updates the HGF, and then returns a categorical prediction of the input. The HGF used must be a categorical HGF. In addition to the HGF substruct, the following must be present in the agent: Settings: "target_categorical_node" """ -function update_hgf_predict_category_action(agent::Agent, input) +function update_hgf_predict_category(agent::Agent, input) #Update the HGF update_hgf!(agent.substruct, input) #Run the action model - action_distribution = hgf_predict_category_action(agent, input) + action_distribution = hgf_predict_category(agent, input) return action_distribution end """ - hgf_predict_category_action(agent::Agent, input) + hgf_predict_category(agent::Agent, input) Action model which gives a categorical prediction of the input, based on an HGF. The HGF used must be a categorical HGF. 
In addition to the HGF substruct, the following must be present in the agent: Settings: "target_categorical_node" """ -function hgf_predict_category_action(agent::Agent, input) +function hgf_predict_category(agent::Agent, input) #Get out settings and parameters target_node = agent.settings["target_categorical_node"] @@ -84,7 +84,7 @@ function premade_hgf_predict_category(config::Dict) ## Create agent #Set the action model - action_model = update_hgf_predict_category_action + action_model = update_hgf_predict_category #Set the HGF hgf = config["HGF"] diff --git a/src/premade_models/premade_agents/premade_sigmoid_action.jl b/src/premade_models/premade_agents/premade_sigmoid.jl similarity index 100% rename from src/premade_models/premade_agents/premade_sigmoid_action.jl rename to src/premade_models/premade_agents/premade_sigmoid.jl diff --git a/src/premade_models/premade_agents/premade_softmax_action.jl b/src/premade_models/premade_agents/premade_softmax.jl similarity index 100% rename from src/premade_models/premade_agents/premade_softmax_action.jl rename to src/premade_models/premade_agents/premade_softmax.jl diff --git a/src/premade_models/premade_hgfs/premade_JGET.jl b/src/premade_models/premade_hgfs/premade_JGET.jl index 557e091..beb170b 100644 --- a/src/premade_models/premade_hgfs/premade_JGET.jl +++ b/src/premade_models/premade_hgfs/premade_JGET.jl @@ -52,6 +52,7 @@ function premade_JGET(config::Dict; verbose::Bool = true) ("x", "xvol", "coupling_strength") => 1, ("xnoise", "xnoise_vol", "coupling_strength") => 1, "update_type" => EnhancedUpdate(), + "save_history" => true, ) #Warn the user about used defaults and misspecified keys @@ -118,5 +119,6 @@ function premade_JGET(config::Dict; verbose::Bool = true) edges = edges, verbose = false, node_defaults = NodeDefaults(update_type = config["update_type"]), + save_history = config["save_history"], ) end diff --git a/src/premade_models/premade_hgfs/premade_binary_2level.jl 
b/src/premade_models/premade_hgfs/premade_binary_2level.jl index 0413edd..49750d9 100644 --- a/src/premade_models/premade_hgfs/premade_binary_2level.jl +++ b/src/premade_models/premade_hgfs/premade_binary_2level.jl @@ -20,6 +20,7 @@ function premade_binary_2level(config::Dict; verbose::Bool = true) ("xprob", "initial_precision") => 1, ("xbin", "xprob", "coupling_strength") => 1, "update_type" => EnhancedUpdate(), + "save_history" => true, ) #Warn the user about used defaults and misspecified keys @@ -57,5 +58,6 @@ function premade_binary_2level(config::Dict; verbose::Bool = true) edges = edges, verbose = false, node_defaults = NodeDefaults(update_type = config["update_type"]), + save_history = config["save_history"], ) end diff --git a/src/premade_models/premade_hgfs/premade_binary_3level.jl b/src/premade_models/premade_hgfs/premade_binary_3level.jl index 1ee9c0b..c04f733 100644 --- a/src/premade_models/premade_hgfs/premade_binary_3level.jl +++ b/src/premade_models/premade_hgfs/premade_binary_3level.jl @@ -43,6 +43,7 @@ function premade_binary_3level(config::Dict; verbose::Bool = true) ("xbin", "xprob", "coupling_strength") => 1, ("xprob", "xvol", "coupling_strength") => 1, "update_type" => EnhancedUpdate(), + "save_history" => true, ) #Warn the user about used defaults and misspecified keys @@ -90,5 +91,6 @@ function premade_binary_3level(config::Dict; verbose::Bool = true) edges = edges, verbose = false, node_defaults = NodeDefaults(update_type = config["update_type"]), + save_history = config["save_history"], ) end diff --git a/src/premade_models/premade_hgfs/premade_categorical_3level.jl b/src/premade_models/premade_hgfs/premade_categorical_3level.jl index 50e278d..d658a3b 100644 --- a/src/premade_models/premade_hgfs/premade_categorical_3level.jl +++ b/src/premade_models/premade_hgfs/premade_categorical_3level.jl @@ -38,6 +38,7 @@ function premade_categorical_3level(config::Dict; verbose::Bool = true) ("xbin", "xprob", "coupling_strength") => 1, ("xprob", 
"xvol", "coupling_strength") => 1, "update_type" => EnhancedUpdate(), + "save_history" => true, ) #Warn the user about used defaults and misspecified keys @@ -182,5 +183,6 @@ function premade_categorical_3level(config::Dict; verbose::Bool = true) shared_parameters = shared_parameters, verbose = false, node_defaults = NodeDefaults(update_type = config["update_type"]), + save_history = config["save_history"], ) end diff --git a/src/premade_models/premade_hgfs/premade_categorical_transitions_3level.jl b/src/premade_models/premade_hgfs/premade_categorical_transitions_3level.jl index a96a185..e66d9f3 100644 --- a/src/premade_models/premade_hgfs/premade_categorical_transitions_3level.jl +++ b/src/premade_models/premade_hgfs/premade_categorical_transitions_3level.jl @@ -45,6 +45,7 @@ function premade_categorical_3level_state_transitions(config::Dict; verbose::Boo ("xbin", "xprob", "coupling_strength") => 1, ("xprob", "xvol", "coupling_strength") => 1, "update_type" => EnhancedUpdate(), + "save_history" => true, ) #Warn the user about used defaults and misspecified keys @@ -246,5 +247,6 @@ function premade_categorical_3level_state_transitions(config::Dict; verbose::Boo shared_parameters = shared_parameters, verbose = false, node_defaults = NodeDefaults(update_type = config["update_type"]), + save_history = config["save_history"], ) end diff --git a/src/premade_models/premade_hgfs/premade_continuous_2level.jl b/src/premade_models/premade_hgfs/premade_continuous_2level.jl index 4421bac..a4c01f4 100644 --- a/src/premade_models/premade_hgfs/premade_continuous_2level.jl +++ b/src/premade_models/premade_hgfs/premade_continuous_2level.jl @@ -33,6 +33,7 @@ function premade_continuous_2level(config::Dict; verbose::Bool = true) ("xvol", "initial_precision") => 1, ("x", "xvol", "coupling_strength") => 1, "update_type" => EnhancedUpdate(), + "save_history" => true, ) #Warn the user about used defaults and misspecified keys @@ -80,5 +81,6 @@ function 
premade_continuous_2level(config::Dict; verbose::Bool = true) edges = edges, verbose = false, node_defaults = NodeDefaults(update_type = config["update_type"]), + save_history = config["save_history"], ) end diff --git a/src/update_hgf/node_updates/binary_state_node.jl b/src/update_hgf/node_updates/binary_state_node.jl index 400caf6..4e80214 100644 --- a/src/update_hgf/node_updates/binary_state_node.jl +++ b/src/update_hgf/node_updates/binary_state_node.jl @@ -12,11 +12,9 @@ function update_node_prediction!(node::BinaryStateNode, stepsize::Real) #Update prediction mean node.states.prediction_mean = calculate_prediction_mean(node) - push!(node.history.prediction_mean, node.states.prediction_mean) #Update prediction precision node.states.prediction_precision = calculate_prediction_precision(node) - push!(node.history.prediction_precision, node.states.prediction_precision) return nothing end @@ -71,11 +69,9 @@ Update the posterior of a single continuous state node. This is the classic HGF function update_node_posterior!(node::BinaryStateNode, update_type::HGFUpdateType) #Update posterior precision node.states.posterior_precision = calculate_posterior_precision(node) - push!(node.history.posterior_precision, node.states.posterior_precision) #Update posterior mean node.states.posterior_mean = calculate_posterior_mean(node, update_type) - push!(node.history.posterior_mean, node.states.posterior_mean) return nothing end @@ -204,7 +200,6 @@ Update the value prediction error of a single state node. 
function update_node_value_prediction_error!(node::BinaryStateNode) #Update value prediction error node.states.value_prediction_error = calculate_value_prediction_error(node) - push!(node.history.value_prediction_error, node.states.value_prediction_error) return nothing end diff --git a/src/update_hgf/node_updates/categorical_state_node.jl b/src/update_hgf/node_updates/categorical_state_node.jl index 4c751fc..f2bf1c0 100644 --- a/src/update_hgf/node_updates/categorical_state_node.jl +++ b/src/update_hgf/node_updates/categorical_state_node.jl @@ -12,8 +12,7 @@ function update_node_prediction!(node::CategoricalStateNode, stepsize::Real) #Update prediction mean node.states.prediction, node.states.parent_predictions = calculate_prediction(node) - push!(node.history.prediction, node.states.prediction) - push!(node.history.parent_predictions, node.states.parent_predictions) + return nothing end @@ -79,7 +78,6 @@ function update_node_posterior!(node::CategoricalStateNode, update_type::Classic #Update posterior mean node.states.posterior = calculate_posterior(node) - push!(node.history.posterior, node.states.posterior) return nothing end @@ -130,7 +128,6 @@ Update the value prediction error of a single state node. 
function update_node_value_prediction_error!(node::CategoricalStateNode) #Update value prediction error node.states.value_prediction_error = calculate_value_prediction_error(node) - push!(node.history.value_prediction_error, node.states.value_prediction_error) return nothing end diff --git a/src/update_hgf/node_updates/continuous_input_node.jl b/src/update_hgf/node_updates/continuous_input_node.jl index 510e1f1..5c6bddc 100644 --- a/src/update_hgf/node_updates/continuous_input_node.jl +++ b/src/update_hgf/node_updates/continuous_input_node.jl @@ -12,11 +12,9 @@ function update_node_prediction!(node::ContinuousInputNode, stepsize::Real) #Update node prediction mean node.states.prediction_mean = calculate_prediction_mean(node) - push!(node.history.prediction_mean, node.states.prediction_mean) #Update prediction precision node.states.prediction_precision = calculate_prediction_precision(node) - push!(node.history.prediction_precision, node.states.prediction_precision) return nothing end @@ -85,7 +83,6 @@ function update_node_value_prediction_error!(node::ContinuousInputNode) #Calculate value prediction error node.states.value_prediction_error = calculate_value_prediction_error(node) - push!(node.history.value_prediction_error, node.states.value_prediction_error) return nothing end @@ -127,13 +124,7 @@ Update the value prediction error of a single input node. 
function update_node_precision_prediction_error!(node::ContinuousInputNode) #Calculate volatility prediction error, only if there are volatility parents - if length(node.edges.noise_parents) > 0 - node.states.precision_prediction_error = calculate_precision_prediction_error(node) - push!( - node.history.precision_prediction_error, - node.states.precision_prediction_error, - ) - end + node.states.precision_prediction_error = calculate_precision_prediction_error(node) return nothing end @@ -150,6 +141,12 @@ Uses the equation """ function calculate_precision_prediction_error(node::ContinuousInputNode) + #If there are no noise parents + if length(node.edges.noise_parents) == 0 + #Skip + return missing + end + #For missing input if ismissing(node.states.input_value) #Set the prediction error to missing diff --git a/src/update_hgf/node_updates/continuous_state_node.jl b/src/update_hgf/node_updates/continuous_state_node.jl index 20b2f73..148709e 100644 --- a/src/update_hgf/node_updates/continuous_state_node.jl +++ b/src/update_hgf/node_updates/continuous_state_node.jl @@ -12,16 +12,10 @@ function update_node_prediction!(node::ContinuousStateNode, stepsize::Real) #Update prediction mean node.states.prediction_mean = calculate_prediction_mean(node, stepsize) - push!(node.history.prediction_mean, node.states.prediction_mean) #Update prediction precision node.states.prediction_precision, node.states.effective_prediction_precision = calculate_prediction_precision(node, stepsize) - push!(node.history.prediction_precision, node.states.prediction_precision) - push!( - node.history.effective_prediction_precision, - node.states.effective_prediction_precision, - ) return nothing end @@ -118,11 +112,9 @@ Update the posterior of a single continuous state node. 
This is the classic HGF function update_node_posterior!(node::ContinuousStateNode, update_type::ClassicUpdate) #Update posterior precision node.states.posterior_precision = calculate_posterior_precision(node) - push!(node.history.posterior_precision, node.states.posterior_precision) #Update posterior mean node.states.posterior_mean = calculate_posterior_mean(node, update_type) - push!(node.history.posterior_mean, node.states.posterior_mean) return nothing end @@ -135,11 +127,9 @@ Update the posterior of a single continuous state node. This is the enahnced HGF function update_node_posterior!(node::ContinuousStateNode, update_type::EnhancedUpdate) #Update posterior mean node.states.posterior_mean = calculate_posterior_mean(node, update_type) - push!(node.history.posterior_mean, node.states.posterior_mean) #Update posterior precision node.states.posterior_precision = calculate_posterior_precision(node) - push!(node.history.posterior_precision, node.states.posterior_precision) return nothing end @@ -536,7 +526,6 @@ Update the value prediction error of a single state node. function update_node_value_prediction_error!(node::ContinuousStateNode) #Update value prediction error node.states.value_prediction_error = calculate_value_prediction_error(node) - push!(node.history.value_prediction_error, node.states.value_prediction_error) return nothing end @@ -566,13 +555,7 @@ Update the volatility prediction error of a single state node. 
function update_node_precision_prediction_error!(node::ContinuousStateNode) #Update volatility prediction error, only if there are volatility parents - if length(node.edges.volatility_parents) > 0 - node.states.precision_prediction_error = calculate_precision_prediction_error(node) - push!( - node.history.precision_prediction_error, - node.states.precision_prediction_error, - ) - end + node.states.precision_prediction_error = calculate_precision_prediction_error(node) return nothing end @@ -586,6 +569,14 @@ Uses the equation `` \Delta_n = \frac{\hat{\pi}_n}{\pi_n} + \hat{\pi}_n \cdot \delta_n^2-1 `` """ function calculate_precision_prediction_error(node::ContinuousStateNode) + + #If there are no volatility parents + if length(node.edges.volatility_parents) == 0 + #Skip + return missing + end + node.states.prediction_precision / node.states.posterior_precision + node.states.prediction_precision * node.states.value_prediction_error^2 - 1 + end diff --git a/src/update_hgf/update_hgf.jl b/src/update_hgf/update_hgf.jl index 4cba09b..05737f7 100644 --- a/src/update_hgf/update_hgf.jl +++ b/src/update_hgf/update_hgf.jl @@ -21,10 +21,7 @@ function update_hgf!( }; stepsize::Real = 1, ) - ## Update node predictions from last timestep - #Update the timepoint - push!(hgf.timesteps, hgf.timesteps[end] + stepsize) - + ### Update node predictions from last timestep ### #For each node (in the opposite update order) for node in reverse(hgf.ordered_nodes.all_state_nodes) #Update its prediction from last trial @@ -37,17 +34,17 @@ function update_hgf!( update_node_prediction!(node, stepsize) end - ## Supply inputs to input nodes + ### Supply inputs to input nodes ### enter_node_inputs!(hgf, inputs) - ## Update input node value prediction errors + ### Update input node value prediction errors ### #For each input node, in the specified update order for node in hgf.ordered_nodes.input_nodes #Update its value prediction error update_node_value_prediction_error!(node) end - ## Update input 
node value parent posteriors + ### Update input node value parent posteriors ### #For each node that is a value parent of an input node for node in hgf.ordered_nodes.early_update_state_nodes #Update its posterior @@ -58,14 +55,14 @@ function update_hgf!( update_node_precision_prediction_error!(node) end - ## Update input node precision prediction errors + ### Update input node precision prediction errors ### #For each input node, in the specified update order for node in hgf.ordered_nodes.input_nodes #Update its value prediction error update_node_precision_prediction_error!(node) end - ## Update remaining state nodes + ### Update remaining state nodes ### #For each state node, in the specified update order for node in hgf.ordered_nodes.late_update_state_nodes #Update its posterior @@ -76,6 +73,24 @@ function update_hgf!( update_node_precision_prediction_error!(node) end + ### Save the history for each node ### + #If save history is enabled + if hgf.save_history + + #Update the timepoint + push!(hgf.timesteps, hgf.timesteps[end] + stepsize) + + #Go through each node + for node in hgf.ordered_nodes.all_nodes + + #Go through each state + for state_name in fieldnames(typeof(node.states)) + #Add that state to the history + push!(getfield(node.history, state_name), getfield(node.states, state_name)) + end + end + end + return nothing end @@ -131,7 +146,6 @@ Update the prediction of a single input node. 
function update_node_input!(node::AbstractInputNode, input::Union{Real,Missing}) #Receive input node.states.input_value = input - push!(node.history.input_value, node.states.input_value) return nothing end diff --git a/test/testsuite/test_fit_model.jl b/test/testsuite/test_fit_model.jl index 4d6698c..2825740 100644 --- a/test/testsuite/test_fit_model.jl +++ b/test/testsuite/test_fit_model.jl @@ -18,7 +18,7 @@ using Turing test_hgf = premade_hgf("continuous_2level", verbose = false) #Create agent - test_agent = premade_agent("hgf_gaussian_action", test_hgf, verbose = false) + test_agent = premade_agent("hgf_gaussian", test_hgf, verbose = false) # Set fixed parsmeters and priors for fitting test_fixed_parameters = Dict( @@ -75,7 +75,7 @@ using Turing test_hgf = premade_hgf("binary_3level", verbose = false) #Create agent - test_agent = premade_agent("hgf_binary_softmax_action", test_hgf, verbose = false) + test_agent = premade_agent("hgf_binary_softmax", test_hgf, verbose = false) #Set fixed parameters and priors test_fixed_parameters = Dict( diff --git a/test/testsuite/test_premade_agent.jl b/test/testsuite/test_premade_agent.jl index 7811fd1..1f6961e 100644 --- a/test/testsuite/test_premade_agent.jl +++ b/test/testsuite/test_premade_agent.jl @@ -4,11 +4,11 @@ using Test @testset "Premade Action Models" begin - @testset "hgf_gaussian_action" begin + @testset "hgf_gaussian" begin #Create an HGF agent with the gaussian response test_agent = premade_agent( - "hgf_gaussian_action", + "hgf_gaussian", premade_hgf("continuous_2level", verbose = false), verbose = false, ) @@ -23,11 +23,11 @@ using Test @test get_surprise(test_agent.substruct) isa Real end - @testset "hgf_binary_softmax_action" begin + @testset "hgf_binary_softmax" begin #Create HGF agent with binary softmax action test_agent = premade_agent( - "hgf_binary_softmax_action", + "hgf_binary_softmax", premade_hgf("binary_3level", verbose = false), verbose = false, ) @@ -43,11 +43,11 @@ using Test end - @testset 
"hgf_unit_square_sigmoid_action" begin + @testset "hgf_unit_square_sigmoid" begin #Create HGF agent with binary softmax action test_agent = premade_agent( - "hgf_unit_square_sigmoid_action", + "hgf_unit_square_sigmoid", premade_hgf("binary_3level", verbose = false), verbose = false, ) From 2386c6817fe5f8a196eddcf1603a27b24034f504 Mon Sep 17 00:00:00 2001 From: Peter Thestrup Waade Date: Mon, 15 Apr 2024 09:39:16 +0200 Subject: [PATCH 14/16] changed groupd parameter syntax --- Project.toml | 2 +- .../utils/get_parameters.jl | 22 ++-- .../utils/set_parameters.jl | 20 ++-- src/HierarchicalGaussianFiltering.jl | 5 +- src/create_hgf/check_hgf.jl | 30 +++--- src/create_hgf/hgf_structs.jl | 2 +- src/create_hgf/init_hgf.jl | 26 ++--- .../premade_categorical_3level.jl | 101 ++++++++++-------- .../premade_categorical_transitions_3level.jl | 101 ++++++++++-------- test/testsuite/test_grouped_parameters.jl | 69 ++++++++++++ test/testsuite/test_shared_parameters.jl | 50 --------- 11 files changed, 229 insertions(+), 199 deletions(-) create mode 100644 test/testsuite/test_grouped_parameters.jl delete mode 100644 test/testsuite/test_shared_parameters.jl diff --git a/Project.toml b/Project.toml index e8b3640..cf96c33 100644 --- a/Project.toml +++ b/Project.toml @@ -12,7 +12,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" [compat] -ActionModels = "0.4" +ActionModels = "0.5" Distributions = "0.25" RecipesBase = "1" julia = "1.9" diff --git a/src/ActionModels_variations/utils/get_parameters.jl b/src/ActionModels_variations/utils/get_parameters.jl index a467715..27077cb 100644 --- a/src/ActionModels_variations/utils/get_parameters.jl +++ b/src/ActionModels_variations/utils/get_parameters.jl @@ -92,10 +92,10 @@ end function ActionModels.get_parameters(hgf::HGF, target_parameter::String) - #If the target parameter is a shared parameter - if target_parameter in keys(hgf.shared_parameters) - #Acess the parameter value in 
shared_parameters - return hgf.shared_parameters[target_parameter].value + #If the target parameter is a parameter group + if target_parameter in keys(hgf.parameter_groups) + #Access the parameter value in parameter_groups + return hgf.parameter_groups[target_parameter].value #If the target parameter is a node elseif target_parameter in keys(hgf.all_nodes) #Take out the node @@ -106,7 +106,7 @@ function ActionModels.get_parameters(hgf::HGF, target_parameter::String) #If the target parameter is neither a node nor in the shared parameters throw an error throw( ArgumentError( - "The node or parameter $target_parameter does not exist in the HGF or in shared parameters", + "The node or parameter $target_parameter does not exist in the HGF's nodes or parameter groups", ), ) end @@ -167,13 +167,13 @@ function ActionModels.get_parameters(hgf::HGF) end #If there are shared parameters - if length(hgf.shared_parameters) > 0 + if length(hgf.parameter_groups) > 0 #Go through each shared parameter - for (shared_parameter_key, shared_parameter_value) in hgf.shared_parameters - #Remove derived parameters from the list - filter!(x -> x[1] ∉ shared_parameter_value.derived_parameters, parameters) - #Set the shared parameter value - parameters[shared_parameter_key] = shared_parameter_value.value + for (parameter_group, grouped_parameters) in hgf.parameter_groups + #Remove grouped parameters from the list + filter!(x -> x[1] ∉ grouped_parameters.grouped_parameters, parameters) + #Add the parameter group parameter instead + parameters[parameter_group] = grouped_parameters.value end end diff --git a/src/ActionModels_variations/utils/set_parameters.jl b/src/ActionModels_variations/utils/set_parameters.jl index b57c021..7758f03 100644 --- a/src/ActionModels_variations/utils/set_parameters.jl +++ b/src/ActionModels_variations/utils/set_parameters.jl @@ -100,27 +100,27 @@ end ### For setting a single parameter ### function ActionModels.set_parameters!(hgf::HGF, target_param::String, 
param_value::Any) #If the target parameter is not in the shared parameters - if !(target_param in keys(hgf.shared_parameters)) + if !(target_param in keys(hgf.parameter_groups)) throw( ArgumentError( - "the parameter $target_param is passed to the HGF but is not in the HGF's shared parameters. Check that it is specified correctly", + "the parameter $target_param is a string, but is not in the HGF's grouped parameters. Check that it is specified correctly", ), ) end #Get out the shared parameter struct - shared_parameter = hgf.shared_parameters[target_param] + parameter_group = hgf.parameter_groups[target_param] - #Set the value in the shared parameter - setfield!(shared_parameter, :value, param_value) + #Set the value in the parameter group + setfield!(parameter_group, :value, param_value) - #Get out the derived parameters - derived_parameters = shared_parameter.derived_parameters + #Get out the grouped parameters + grouped_parameters = parameter_group.grouped_parameters - #For each derived parameter - for derived_parameter_key in derived_parameters + #For each grouped parameter + for grouped_parameter_key in grouped_parameters #Set the parameter - set_parameters!(hgf, derived_parameter_key, param_value) + set_parameters!(hgf, grouped_parameter_key, param_value) end end diff --git a/src/HierarchicalGaussianFiltering.jl b/src/HierarchicalGaussianFiltering.jl index fc73469..ad9229a 100644 --- a/src/HierarchicalGaussianFiltering.jl +++ b/src/HierarchicalGaussianFiltering.jl @@ -8,7 +8,9 @@ export init_node, init_hgf, premade_hgf, check_hgf, check_node, update_hgf! export get_prediction, get_surprise export premade_agent, init_agent, plot_predictive_simulation, plot_trajectory, plot_trajectory! -export get_history, get_parameters, get_states, set_parameters!, reset!, give_inputs! +export get_history, + get_parameters, get_states, set_parameters!, reset!, give_inputs!, set_save_history! 
+export ParameterGroup export EnhancedUpdate, ClassicUpdate export NodeDefaults export ContinuousState, @@ -41,6 +43,7 @@ include("ActionModels_variations/utils/get_states.jl") include("ActionModels_variations/utils/give_inputs.jl") include("ActionModels_variations/utils/reset.jl") include("ActionModels_variations/utils/set_parameters.jl") +include("ActionModels_variations/utils/set_save_history.jl") #Functions for updating the HGF include("update_hgf/update_hgf.jl") diff --git a/src/create_hgf/check_hgf.jl b/src/create_hgf/check_hgf.jl index 960d229..edc4767 100644 --- a/src/create_hgf/check_hgf.jl +++ b/src/create_hgf/check_hgf.jl @@ -19,35 +19,35 @@ function check_hgf(hgf::HGF) end #If there are shared parameters - if length(hgf.shared_parameters) > 0 + if length(hgf.parameter_groups) > 0 - ## Check for the same derived parameter in multiple shared parameters ## + ## Check for the same grouped parameter in multiple shared parameters ## - #Get all derived parameters - derived_parameters = [ - parameter for list_of_derived_parameters in [ - hgf.shared_parameters[parameter_key].derived_parameters for - parameter_key in keys(hgf.shared_parameters) - ] for parameter in list_of_derived_parameters + #Get all grouped parameters + grouped_parameters = [ + parameter for list_of_grouped_parameters in [ + hgf.parameter_groups[parameter_key].grouped_parameters for + parameter_key in keys(hgf.parameter_groups) + ] for parameter in list_of_grouped_parameters ] #Check for duplicate names - if length(derived_parameters) > length(unique(derived_parameters)) + if length(grouped_parameters) > length(unique(grouped_parameters)) #Throw an error throw( ArgumentError( - "At least one parameter is set by multiple shared parameters. This is not supported.", + "At least one parameter is set by multiple parameter groups. 
This is not supported.", ), ) end - ## Check if the shared parameter is part of own derived parameters ## + ## Check if the shared parameter is part of own grouped parameters ## #Go through each specified shared parameter - for (shared_parameter_key, dict_value) in hgf.shared_parameters - #check if the name of the shared parameter is part of its own derived parameters - if shared_parameter_key in dict_value.derived_parameters + for (parameter_group_key, grouped_parameters) in hgf.parameter_groups + #check if the name of the shared parameter is part of its own grouped parameters + if parameter_group_key in grouped_parameters.grouped_parameters throw( ArgumentError( - "The shared parameter is part of the list of derived parameters", + "The parameter group name $parameter_group_key is part of the list of parameters in the group", ), ) end diff --git a/src/create_hgf/hgf_structs.jl b/src/create_hgf/hgf_structs.jl index da2feda..0e4c288 100644 --- a/src/create_hgf/hgf_structs.jl +++ b/src/create_hgf/hgf_structs.jl @@ -80,7 +80,7 @@ Base.@kwdef mutable struct HGF input_nodes::Dict{String,AbstractInputNode} state_nodes::Dict{String,AbstractStateNode} ordered_nodes::OrderedNodes = OrderedNodes() - shared_parameters::Dict = Dict() + parameter_groups::Dict = Dict() save_history::Bool = true timesteps::Vector{Real} = [0] end diff --git a/src/create_hgf/init_hgf.jl b/src/create_hgf/init_hgf.jl index b18c58d..f45d1c5 100644 --- a/src/create_hgf/init_hgf.jl +++ b/src/create_hgf/init_hgf.jl @@ -76,7 +76,7 @@ function init_hgf(; nodes::Vector{<:AbstractNodeInfo}, edges::Dict{Tuple{String,String},<:CouplingType}, node_defaults::NodeDefaults = NodeDefaults(), - shared_parameters::Dict = Dict(), + parameter_groups::Vector{ParameterGroup} = Vector{ParameterGroup}(), update_order::Union{Nothing,Vector{String}} = nothing, verbose::Bool = true, save_history::Bool = true, @@ -186,25 +186,15 @@ function init_hgf(; end #initializing shared parameters - shared_parameters_dict = Dict() + 
parameter_groups_dict = Dict() #Go through each specified shared parameter - for (shared_parameter_key, dict_value) in shared_parameters - #Unpack the shared parameter value and the derived parameters - (shared_parameter_value, derived_parameters) = dict_value - #check if the name of the shared parameter is part of its own derived parameters - if shared_parameter_key in derived_parameters - throw( - ArgumentError( - "The shared parameter is part of the list of derived parameters", - ), - ) - end + for parameter_group in parameter_groups - #Add as a SharedParameter to the shared parameter dictionary - shared_parameters_dict[shared_parameter_key] = SharedParameter( - value = shared_parameter_value, - derived_parameters = derived_parameters, + #Add as a GroupedParameters to the shared parameter dictionary + parameter_groups_dict[parameter_group.name] = ActionModels.GroupedParameters( + value = parameter_group.value, + grouped_parameters = parameter_group.parameters, ) end @@ -214,7 +204,7 @@ function init_hgf(; input_nodes_dict, state_nodes_dict, ordered_nodes, - shared_parameters_dict, + parameter_groups_dict, save_history, [0], ) diff --git a/src/premade_models/premade_hgfs/premade_categorical_3level.jl b/src/premade_models/premade_hgfs/premade_categorical_3level.jl index d658a3b..d040a3c 100644 --- a/src/premade_models/premade_hgfs/premade_categorical_3level.jl +++ b/src/premade_models/premade_hgfs/premade_categorical_3level.jl @@ -56,14 +56,14 @@ function premade_categorical_3level(config::Dict; verbose::Bool = true) #Vector for binary node continuous parent names probability_parent_names = Vector{String}() - #Empty lists for derived parameters - derived_parameters_xprob_initial_precision = [] - derived_parameters_xprob_initial_mean = [] - derived_parameters_xprob_volatility = [] - derived_parameters_xprob_drift = [] - derived_parameters_xprob_autoconnection_strength = [] - derived_parameters_xbin_xprob_coupling_strength = [] - 
derived_parameters_xprob_xvol_coupling_strength = [] + #Empty lists for grouped parameters + grouped_parameters_xprob_initial_precision = [] + grouped_parameters_xprob_initial_mean = [] + grouped_parameters_xprob_volatility = [] + grouped_parameters_xprob_drift = [] + grouped_parameters_xprob_autoconnection_strength = [] + grouped_parameters_xbin_xprob_coupling_strength = [] + grouped_parameters_xprob_xvol_coupling_strength = [] #Populate the category node vectors with node names for category_number = 1:config["n_categories"] @@ -92,13 +92,13 @@ function premade_categorical_3level(config::Dict; verbose::Bool = true) initial_precision = config[("xprob", "initial_precision")], ), ) - #Add the derived parameter name to derived parameters vector - push!(derived_parameters_xprob_initial_precision, (node_name, "initial_precision")) - push!(derived_parameters_xprob_initial_mean, (node_name, "initial_mean")) - push!(derived_parameters_xprob_volatility, (node_name, "volatility")) - push!(derived_parameters_xprob_drift, (node_name, "drift")) + #Add the grouped parameter name to grouped parameters vector + push!(grouped_parameters_xprob_initial_precision, (node_name, "initial_precision")) + push!(grouped_parameters_xprob_initial_mean, (node_name, "initial_mean")) + push!(grouped_parameters_xprob_volatility, (node_name, "volatility")) + push!(grouped_parameters_xprob_drift, (node_name, "drift")) push!( - derived_parameters_xprob_autoconnection_strength, + grouped_parameters_xprob_autoconnection_strength, (node_name, "autoconnection_strength"), ) end @@ -135,52 +135,61 @@ function premade_categorical_3level(config::Dict; verbose::Bool = true) edges[(probability_parent_name, "xvol")] = VolatilityCoupling(config[("xprob", "xvol", "coupling_strength")]) - #Add the coupling strengths to the lists of derived parameters + #Add the coupling strengths to the lists of grouped parameters push!( - derived_parameters_xbin_xprob_coupling_strength, + 
grouped_parameters_xbin_xprob_coupling_strength, (category_parent_name, probability_parent_name, "coupling_strength"), ) push!( - derived_parameters_xprob_xvol_coupling_strength, + grouped_parameters_xprob_xvol_coupling_strength, (probability_parent_name, "xvol", "coupling_strength"), ) end #Create dictionary with shared parameter information - shared_parameters = Dict() - - shared_parameters["xprob_volatility"] = - (config[("xprob", "volatility")], derived_parameters_xprob_volatility) - - shared_parameters["xprob_initial_precision"] = - (config[("xprob", "initial_precision")], derived_parameters_xprob_initial_precision) - - shared_parameters["xprob_initial_mean"] = - (config[("xprob", "initial_mean")], derived_parameters_xprob_initial_mean) - - shared_parameters["xprob_drift"] = - (config[("xprob", "drift")], derived_parameters_xprob_drift) - - shared_parameters["autoconnection_strength"] = ( - config[("xprob", "autoconnection_strength")], - derived_parameters_xprob_autoconnection_strength, - ) - - shared_parameters["xbin_xprob_coupling_strength"] = ( - config[("xbin", "xprob", "coupling_strength")], - derived_parameters_xbin_xprob_coupling_strength, - ) - - shared_parameters["xprob_xvol_coupling_strength"] = ( - config[("xprob", "xvol", "coupling_strength")], - derived_parameters_xprob_xvol_coupling_strength, - ) + parameter_groups = [ + ParameterGroup( + "xprob_volatility", + grouped_parameters_xprob_volatility, + config[("xprob", "volatility")], + ), + ParameterGroup( + "xprob_initial_precision", + grouped_parameters_xprob_initial_precision, + config[("xprob", "initial_precision")], + ), + ParameterGroup( + "xprob_initial_mean", + grouped_parameters_xprob_initial_mean, + config[("xprob", "initial_mean")], + ), + ParameterGroup( + "xprob_drift", + grouped_parameters_xprob_drift, + config[("xprob", "drift")], + ), + ParameterGroup( + "xprob_autoconnection_strength", + grouped_parameters_xprob_autoconnection_strength, + config[("xprob", 
"autoconnection_strength")], + ), + ParameterGroup( + "xbin_xprob_coupling_strength", + grouped_parameters_xbin_xprob_coupling_strength, + config[("xbin", "xprob", "coupling_strength")], + ), + ParameterGroup( + "xprob_xvol_coupling_strength", + grouped_parameters_xprob_xvol_coupling_strength, + config[("xprob", "xvol", "coupling_strength")], + ), + ] #Initialize the HGF init_hgf( nodes = nodes, edges = edges, - shared_parameters = shared_parameters, + parameter_groups = parameter_groups, verbose = false, node_defaults = NodeDefaults(update_type = config["update_type"]), save_history = config["save_history"], diff --git a/src/premade_models/premade_hgfs/premade_categorical_transitions_3level.jl b/src/premade_models/premade_hgfs/premade_categorical_transitions_3level.jl index e66d9f3..8a77871 100644 --- a/src/premade_models/premade_hgfs/premade_categorical_transitions_3level.jl +++ b/src/premade_models/premade_hgfs/premade_categorical_transitions_3level.jl @@ -64,14 +64,14 @@ function premade_categorical_3level_state_transitions(config::Dict; verbose::Boo category_parent_names = Vector{String}() probability_parent_names = Vector{String}() - #Empty lists for derived parameters - derived_parameters_xprob_initial_precision = [] - derived_parameters_xprob_initial_mean = [] - derived_parameters_xprob_volatility = [] - derived_parameters_xprob_drift = [] - derived_parameters_xprob_autoconnection_strength = [] - derived_parameters_xbin_xprob_coupling_strength = [] - derived_parameters_xprob_xvol_coupling_strength = [] + #Empty lists for grouped parameters + grouped_parameters_xprob_initial_precision = [] + grouped_parameters_xprob_initial_mean = [] + grouped_parameters_xprob_volatility = [] + grouped_parameters_xprob_drift = [] + grouped_parameters_xprob_autoconnection_strength = [] + grouped_parameters_xbin_xprob_coupling_strength = [] + grouped_parameters_xprob_xvol_coupling_strength = [] #Go through each category that the transition may have been from for category_from 
= 1:config["n_categories"] @@ -128,13 +128,13 @@ function premade_categorical_3level_state_transitions(config::Dict; verbose::Boo initial_precision = config[("xprob", "initial_precision")], ), ) - #Add the derived parameter name to derived parameters vector - push!(derived_parameters_xprob_initial_precision, (node_name, "initial_precision")) - push!(derived_parameters_xprob_initial_mean, (node_name, "initial_mean")) - push!(derived_parameters_xprob_volatility, (node_name, "volatility")) - push!(derived_parameters_xprob_drift, (node_name, "drift")) + #Add the grouped parameter name to grouped parameters vector + push!(grouped_parameters_xprob_initial_precision, (node_name, "initial_precision")) + push!(grouped_parameters_xprob_initial_mean, (node_name, "initial_mean")) + push!(grouped_parameters_xprob_volatility, (node_name, "volatility")) + push!(grouped_parameters_xprob_drift, (node_name, "drift")) push!( - derived_parameters_xprob_autoconnection_strength, + grouped_parameters_xprob_autoconnection_strength, (node_name, "autoconnection_strength"), ) end @@ -198,53 +198,62 @@ function premade_categorical_3level_state_transitions(config::Dict; verbose::Boo VolatilityCoupling(config[("xprob", "xvol", "coupling_strength")]) - #Add the parameters as derived parameters for shared parameters + #Add the parameters as grouped parameters for shared parameters push!( - derived_parameters_xbin_xprob_coupling_strength, + grouped_parameters_xbin_xprob_coupling_strength, (category_parent_name, probability_parent_name, "coupling_strength"), ) push!( - derived_parameters_xprob_xvol_coupling_strength, + grouped_parameters_xprob_xvol_coupling_strength, (probability_parent_name, "xvol", "coupling_strength"), ) end #Create dictionary with shared parameter information - shared_parameters = Dict() - - shared_parameters["xprob_volatility"] = - (config[("xprob", "volatility")], derived_parameters_xprob_volatility) - - shared_parameters["xprob_initial_precision"] = - (config[("xprob", 
"initial_precision")], derived_parameters_xprob_initial_precision) - - shared_parameters["xprob_initial_mean"] = - (config[("xprob", "initial_mean")], derived_parameters_xprob_initial_mean) - - shared_parameters["xprob_drift"] = - (config[("xprob", "drift")], derived_parameters_xprob_drift) - - shared_parameters["autoconnection_strength"] = ( - config[("xprob", "autoconnection_strength")], - derived_parameters_xprob_autoconnection_strength, - ) - - shared_parameters["xbin_xprob_coupling_strength"] = ( - config[("xbin", "xprob", "coupling_strength")], - derived_parameters_xbin_xprob_coupling_strength, - ) - - shared_parameters["xprob_xvol_coupling_strength"] = ( - config[("xprob", "xvol", "coupling_strength")], - derived_parameters_xprob_xvol_coupling_strength, - ) + parameter_groups = [ + ParameterGroup( + "xprob_volatility", + grouped_parameters_xprob_volatility, + config[("xprob", "volatility")], + ), + ParameterGroup( + "xprob_initial_precision", + grouped_parameters_xprob_initial_precision, + config[("xprob", "initial_precision")], + ), + ParameterGroup( + "xprob_initial_mean", + grouped_parameters_xprob_initial_mean, + config[("xprob", "initial_mean")], + ), + ParameterGroup( + "xprob_drift", + grouped_parameters_xprob_drift, + config[("xprob", "drift")], + ), + ParameterGroup( + "xprob_autoconnection_strength", + grouped_parameters_xprob_autoconnection_strength, + config[("xprob", "autoconnection_strength")], + ), + ParameterGroup( + "xbin_xprob_coupling_strength", + grouped_parameters_xbin_xprob_coupling_strength, + config[("xbin", "xprob", "coupling_strength")], + ), + ParameterGroup( + "xprob_xvol_coupling_strength", + grouped_parameters_xprob_xvol_coupling_strength, + config[("xprob", "xvol", "coupling_strength")], + ), + ] #Initialize the HGF init_hgf( nodes = nodes, edges = edges, - shared_parameters = shared_parameters, + parameter_groups = parameter_groups, verbose = false, node_defaults = NodeDefaults(update_type = config["update_type"]), 
save_history = config["save_history"], diff --git a/test/testsuite/test_grouped_parameters.jl b/test/testsuite/test_grouped_parameters.jl new file mode 100644 index 0000000..02aeca2 --- /dev/null +++ b/test/testsuite/test_grouped_parameters.jl @@ -0,0 +1,69 @@ +using HierarchicalGaussianFiltering +using Test + +@testset "Grouped parameters" begin + + # Test of custom HGF with shared parameters + + #List of nodes + nodes = [ + ContinuousInput(name = "u", input_noise = 2), + ContinuousState( + name = "x1", + volatility = 2, + initial_mean = 1, + initial_precision = 1, + ), + ContinuousState( + name = "x2", + volatility = 2, + initial_mean = 1, + initial_precision = 1, + ), + ] + + #List of child-parent relations + edges = + Dict(("u", "x1") => ObservationCoupling(), ("x1", "x2") => VolatilityCoupling(1)) + + # one shared parameter + parameter_groups_1 = + [ParameterGroup("volatilities", [("x1", "volatility"), ("x2", "volatility")], 9)] + + #Initialize the HGF + hgf_1 = init_hgf(nodes = nodes, edges = edges, parameter_groups = parameter_groups_1) + + #get shared parameter + get_parameters(hgf_1) + + @test get_parameters(hgf_1, "volatilities") == 9 + + #set shared parameter + set_parameters!(hgf_1, "volatilities", 2) + + parameter_groups_2 = [ + ParameterGroup( + "initial_means", + [("x1", "initial_mean"), ("x2", "initial_mean")], + 9, + ), + ParameterGroup("volatilities", [("x1", "volatility"), ("x2", "volatility")], 9), + ] + + #Initialize the HGF + hgf_2 = init_hgf(nodes = nodes, edges = edges, parameter_groups = parameter_groups_2) + + #get all parameters + get_parameters(hgf_2) + + #get shared parameter + @test get_parameters(hgf_2, "volatilities") == 9 + + #set shared parameter + set_parameters!(hgf_2, Dict("volatilities" => -2, "initial_means" => 1)) + + @test get_parameters(hgf_2, "volatilities") == -2 + @test get_parameters(hgf_2, "initial_means") == 1 + + +end diff --git a/test/testsuite/test_shared_parameters.jl b/test/testsuite/test_shared_parameters.jl 
deleted file mode 100644 index 19ac42f..0000000 --- a/test/testsuite/test_shared_parameters.jl +++ /dev/null @@ -1,50 +0,0 @@ -using HierarchicalGaussianFiltering -using Test - -# Test of custom HGF with shared parameters - -#List of nodes -nodes = [ - ContinuousInput(name = "u", input_noise = 2), - ContinuousState(name = "x1", volatility = 2, initial_mean = 1, initial_precision = 1), - ContinuousState(name = "x2", volatility = 2, initial_mean = 1, initial_precision = 1), -] - -#List of child-parent relations -edges = Dict(("u", "x1") => ObservationCoupling(), ("x1", "x2") => VolatilityCoupling(1)) - -# one shared parameter -shared_parameters_1 = - Dict("volatilitys" => (9, [("x1", "volatility"), ("x2", "volatility")])) - -#Initialize the HGF -hgf_1 = init_hgf(nodes = nodes, edges = edges, shared_parameters = shared_parameters_1) - -#get shared parameter -get_parameters(hgf_1) - -@test get_parameters(hgf_1, "volatilitys") == 9 - -#set shared parameter -set_parameters!(hgf_1, "volatilitys", 2) - -shared_parameters_2 = Dict( - "initial_means" => (9, [("x1", "initial_mean"), ("x2", "initial_mean")]), - "volatilitys" => (9, [("x1", "volatility"), ("x2", "volatility")]), -) - - -#Initialize the HGF -hgf_2 = init_hgf(nodes = nodes, edges = edges, shared_parameters = shared_parameters_2) - -#get all parameters -get_parameters(hgf_2) - -#get shared parameter -@test get_parameters(hgf_2, "volatilitys") == 9 - -#set shared parameter -set_parameters!(hgf_2, Dict("volatilitys" => -2, "initial_means" => 1)) - -@test get_parameters(hgf_2, "volatilitys") == -2 -@test get_parameters(hgf_2, "initial_means") == 1 From 20792df531235f643c6355faa055e1b5fe44721e Mon Sep 17 00:00:00 2001 From: Peter Thestrup Waade Date: Mon, 15 Apr 2024 09:41:23 +0200 Subject: [PATCH 15/16] set save history function --- src/ActionModels_variations/utils/set_save_history.jl | 4 ++++ 1 file changed, 4 insertions(+) create mode 100644 src/ActionModels_variations/utils/set_save_history.jl diff --git 
a/src/ActionModels_variations/utils/set_save_history.jl b/src/ActionModels_variations/utils/set_save_history.jl new file mode 100644 index 0000000..9862b31 --- /dev/null +++ b/src/ActionModels_variations/utils/set_save_history.jl @@ -0,0 +1,4 @@ +#To set save history +function ActionModels.set_save_history!(hgf::HGF, save_history::Bool) + hgf.save_history = save_history +end From 0eea8bc7bf730d1802e9ba9f9a8ed920c1ee3be5 Mon Sep 17 00:00:00 2001 From: Peter Thestrup Waade Date: Mon, 15 Apr 2024 17:30:08 +0200 Subject: [PATCH 16/16] Added nonlinear transformations --- .../utils/get_parameters.jl | 93 ++++++++++----- .../utils/set_parameters.jl | 71 ++++++++---- src/HierarchicalGaussianFiltering.jl | 8 +- src/create_hgf/hgf_structs.jl | 19 ++- src/create_hgf/init_hgf.jl | 105 ----------------- src/create_hgf/init_node_edge.jl | 108 ++++++++++++++++++ .../node_updates/continuous_input_node.jl | 2 +- .../node_updates/continuous_state_node.jl | 67 +++++++++-- src/update_hgf/nonlinear_transforms.jl | 43 +++++++ test/testsuite/test_initialization.jl | 4 +- 10 files changed, 346 insertions(+), 174 deletions(-) create mode 100644 src/create_hgf/init_node_edge.jl create mode 100644 src/update_hgf/nonlinear_transforms.jl diff --git a/src/ActionModels_variations/utils/get_parameters.jl b/src/ActionModels_variations/utils/get_parameters.jl index 27077cb..c1a890c 100644 --- a/src/ActionModels_variations/utils/get_parameters.jl +++ b/src/ActionModels_variations/utils/get_parameters.jl @@ -43,21 +43,12 @@ function ActionModels.get_parameters(hgf::HGF, target_param::Tuple{String,String return param end -##For coupling strengths +##For coupling strengths and coupling transforms function ActionModels.get_parameters(hgf::HGF, target_param::Tuple{String,String,String}) #Unpack node name, parent name and param name (node_name, parent_name, param_name) = target_param - #If the specified parameter is not a coupling strength - if !(param_name == "coupling_strength") - throw( - 
ArgumentError( - "the parameter $target_param is specified as three strings, but is not a coupling strength", - ), - ) - end - #If the node does not exist if !(node_name in keys(hgf.all_nodes)) #Throw an error @@ -67,29 +58,59 @@ function ActionModels.get_parameters(hgf::HGF, target_param::Tuple{String,String #Get out the node node = hgf.all_nodes[node_name] - #Get out the dictionary of coupling strengths - coupling_strengths = getproperty(node.parameters, :coupling_strengths) + #If the parameter is a coupling strength + if param_name == "coupling_strength" - #If the specified parent is not in the dictionary - if !(parent_name in keys(coupling_strengths)) - #Throw an error - throw( - ArgumentError( - "The node $node_name does not have a coupling strength parameter to a parent called $parent_name", - ), - ) - end + #Get out the dictionary of coupling strengths + coupling_strengths = getproperty(node.parameters, :coupling_strengths) + + #If the specified parent is not in the dictionary + if !(parent_name in keys(coupling_strengths)) + #Throw an error + throw( + ArgumentError( + "The node $node_name does not have a coupling strength parameter to a parent called $parent_name", + ), + ) + end + + #Get the coupling strength for that given parent + param = coupling_strengths[parent_name] - #Get the coupling strength for that given parent - param = coupling_strengths[parent_name] + else + + #Get out the coupling transforms + coupling_transforms = getproperty(node.parameters, :coupling_transforms) + + #If the specified parent is not in the dictionary + if !(parent_name in keys(coupling_transforms)) + #Throw an error + throw( + ArgumentError( + "The node $node_name does not have a coupling transformation to a parent called $parent_name", + ), + ) + end + + #If the specified parameter does not exist for the transform + if !(param_name in keys(coupling_transforms.parameters)) + throw( + ArgumentError( + "There is no parameter called $param_name for the transformation function 
between $node_name and its parent $parent_name", + ), + ) + end + + #Extract the parameter + param = coupling_transforms.parameters[param_name] + end return param end -### For getting all parameters of a specific node ### - +### For getting a single-string parameter (a parameter group), or all parameters of a node ### function ActionModels.get_parameters(hgf::HGF, target_parameter::String) #If the target parameter is a parameter group @@ -196,13 +217,31 @@ function ActionModels.get_parameters(node::AbstractNode) coupling_strengths = node.parameters.coupling_strengths #Go through each parent - for parent_name in keys(coupling_strengths) + for (parent_name, coupling_strength) in coupling_strengths #Add the coupling strength to the ouput dict parameters[(node.name, parent_name, "coupling_strength")] = - coupling_strengths[parent_name] + coupling_strength end + + #If the parameter is a coupling transform + elseif param_key == :coupling_transforms + + #Go through each parent and corresponding transform + for (parent_name, coupling_transform) in node.parameters.coupling_transforms + + #Go through each parameter for the transform + for (coupling_parameter, parameter_value) in coupling_transform.parameters + + #Add the coupling strength to the ouput dict + parameters[(node.name, parent_name, coupling_parameter)] = + parameter_value + + end + end + + #For other nodes else #And add their values to the dictionary parameters[(node.name, String(param_key))] = diff --git a/src/ActionModels_variations/utils/set_parameters.jl b/src/ActionModels_variations/utils/set_parameters.jl index 7758f03..e83f2c6 100644 --- a/src/ActionModels_variations/utils/set_parameters.jl +++ b/src/ActionModels_variations/utils/set_parameters.jl @@ -11,7 +11,7 @@ function ActionModels.set_parameters!() end ### For setting a single parameter ### -##For parameters other than coupling strengths +##For parameters other than coupling strengths and transforms function ActionModels.set_parameters!( hgf::HGF, 
target_param::Tuple{String,String}, @@ -61,15 +61,6 @@ function ActionModels.set_parameters!( #Unpack node name, parent name and parameter name (node_name, parent_name, param_name) = target_param - #If the specified parameter is not a coupling strength - if !(param_name == "coupling_strength") - throw( - ArgumentError( - "the parameter $target_param is specified as three strings, but is not a coupling strength", - ), - ) - end - #If the node does not exist if !(node_name in keys(hgf.all_nodes)) #Throw an error @@ -79,26 +70,56 @@ function ActionModels.set_parameters!( #Get the child node node = hgf.all_nodes[node_name] - #Get coupling_strengths - coupling_strengths = node.parameters.coupling_strengths + #If it is a coupling strength + if param_name == "coupling_strength" + + #Get coupling_strengths + coupling_strengths = node.parameters.coupling_strengths + + #If the specified parent is not in the dictionary + if !(parent_name in keys(coupling_strengths)) + #Throw an error + throw( + ArgumentError( + "The node $node_name does not have a coupling strength parameter to a parent called $parent_name", + ), + ) + end + + #Set the coupling strength to the specified parent to the specified value + coupling_strengths[parent_name] = param_value + + else + + #Get out the coupling transforms + coupling_transforms = getproperty(node.parameters, :coupling_transforms) + + #If the specified parent is not in the dictionary + if !(parent_name in keys(coupling_transforms)) + #Throw an error + throw( + ArgumentError( + "The node $node_name does not have a coupling transformation to a parent called $parent_name", + ), + ) + end + + #If the specified parameter does not exist for the transform + if !(param_name in keys(coupling_transforms.parameters)) + throw( + ArgumentError( + "There is no parameter called $param_name for the transformation function between $node_name and its parent $parent_name", + ), + ) + end - #If the specified parent is not in the dictionary - if !(parent_name in 
keys(coupling_strengths)) - #Throw an error - throw( - ArgumentError( - "The node $node_name does not have a coupling strength parameter to a parent called $parent_name", - ), - ) + #Set the parameter + coupling_transforms.parameters[param_name] = param_value end - - #Set the coupling strength to the specified parent to the specified value - coupling_strengths[parent_name] = param_value - end ### For setting a single parameter ### -function ActionModels.set_parameters!(hgf::HGF, target_param::String, param_value::Any) +function ActionModels.set_parameters!(hgf::HGF, target_param::String, param_value::Real) #If the target parameter is not in the shared parameters if !(target_param in keys(hgf.parameter_groups)) throw( diff --git a/src/HierarchicalGaussianFiltering.jl b/src/HierarchicalGaussianFiltering.jl index ad9229a..fce8b48 100644 --- a/src/HierarchicalGaussianFiltering.jl +++ b/src/HierarchicalGaussianFiltering.jl @@ -20,7 +20,9 @@ export DriftCoupling, CategoryCoupling, ProbabilityCoupling, VolatilityCoupling, - NoiseCoupling + NoiseCoupling, + LinearTransform, + NonlinearTransform #Add premade agents to shared dict at initialization function __init__() @@ -47,6 +49,7 @@ include("ActionModels_variations/utils/set_save_history.jl") #Functions for updating the HGF include("update_hgf/update_hgf.jl") +include("update_hgf/nonlinear_transforms.jl") include("update_hgf/node_updates/continuous_input_node.jl") include("update_hgf/node_updates/continuous_state_node.jl") include("update_hgf/node_updates/binary_input_node.jl") @@ -57,10 +60,9 @@ include("update_hgf/node_updates/categorical_state_node.jl") #Functions for creating HGFs include("create_hgf/check_hgf.jl") include("create_hgf/init_hgf.jl") +include("create_hgf/init_node_edge.jl") include("create_hgf/create_premade_hgf.jl") -#Plotting functions - #Functions for premade agents include("premade_models/premade_agents/premade_gaussian.jl") include("premade_models/premade_agents/premade_predict_category.jl") diff 
--git a/src/create_hgf/hgf_structs.jl b/src/create_hgf/hgf_structs.jl index 0e4c288..a60f6ee 100644 --- a/src/create_hgf/hgf_structs.jl +++ b/src/create_hgf/hgf_structs.jl @@ -37,6 +37,18 @@ struct EnhancedUpdate <: HGFUpdateType end ######## Coupling types ######## ################################ +#Types for specifying nonlinear transformations +abstract type CouplingTransform end + +Base.@kwdef mutable struct LinearTransform <: CouplingTransform end + +Base.@kwdef mutable struct NonlinearTransform <: CouplingTransform + base_function::Function + first_derivation::Function + second_derivation::Function + parameters::Dict = Dict() +end + #Supertypes for coupling types abstract type CouplingType end abstract type ValueCoupling <: CouplingType end @@ -45,6 +57,7 @@ abstract type PrecisionCoupling <: CouplingType end #Concrete value coupling types Base.@kwdef mutable struct DriftCoupling <: ValueCoupling strength::Union{Nothing,Real} = nothing + transform::CouplingTransform = LinearTransform() end Base.@kwdef mutable struct ProbabilityCoupling <: ValueCoupling strength::Union{Nothing,Real} = nothing @@ -159,9 +172,10 @@ Base.@kwdef mutable struct ContinuousStateNodeParameters volatility::Real = 0 drift::Real = 0 autoconnection_strength::Real = 1 - coupling_strengths::Dict{String,Real} = Dict{String,Real}() initial_mean::Real = 0 initial_precision::Real = 0 + coupling_strengths::Dict{String,Real} = Dict{String,Real}() + coupling_transforms::Dict{String,CouplingTransform} = Dict{String,Real}() end """ @@ -210,6 +224,8 @@ Base.@kwdef mutable struct ContinuousInputNodeEdges observation_parents::Vector{<:AbstractContinuousStateNode} = Vector{ContinuousStateNode}() noise_parents::Vector{<:AbstractContinuousStateNode} = Vector{ContinuousStateNode}() + + end """ @@ -219,6 +235,7 @@ Base.@kwdef mutable struct ContinuousInputNodeParameters input_noise::Real = 0 bias::Real = 0 coupling_strengths::Dict{String,Real} = Dict{String,Real}() + 
coupling_transforms::Dict{String,CouplingTransform} = Dict{String,Real}() end """ diff --git a/src/create_hgf/init_hgf.jl b/src/create_hgf/init_hgf.jl index f45d1c5..ca308fb 100644 --- a/src/create_hgf/init_hgf.jl +++ b/src/create_hgf/init_hgf.jl @@ -250,108 +250,3 @@ function init_hgf(; return hgf end - - - - -function init_node(node_info::ContinuousState) - ContinuousStateNode( - name = node_info.name, - parameters = ContinuousStateNodeParameters( - volatility = node_info.volatility, - drift = node_info.drift, - initial_mean = node_info.initial_mean, - initial_precision = node_info.initial_precision, - autoconnection_strength = node_info.autoconnection_strength, - ), - ) -end - -function init_node(node_info::ContinuousInput) - ContinuousInputNode( - name = node_info.name, - parameters = ContinuousInputNodeParameters(input_noise = node_info.input_noise), - ) -end - -function init_node(node_info::BinaryState) - BinaryStateNode(name = node_info.name) -end - -function init_node(node_info::BinaryInput) - BinaryInputNode(name = node_info.name) -end - -function init_node(node_info::CategoricalState) - CategoricalStateNode(name = node_info.name) -end - -function init_node(node_info::CategoricalInput) - CategoricalInputNode(name = node_info.name) -end - - -### Function for initializing and edge ### -function init_edge!( - child_node::AbstractNode, - parent_node::AbstractStateNode, - coupling_type::CouplingType, - node_defaults::NodeDefaults, -) - - #Get correct field for storing parents - if coupling_type isa DriftCoupling - parents_field = :drift_parents - children_field = :drift_children - - elseif coupling_type isa ObservationCoupling - parents_field = :observation_parents - children_field = :observation_children - - elseif coupling_type isa CategoryCoupling - parents_field = :category_parents - children_field = :category_children - - elseif coupling_type isa ProbabilityCoupling - parents_field = :probability_parents - children_field = :probability_children - - elseif 
coupling_type isa VolatilityCoupling - parents_field = :volatility_parents - children_field = :volatility_children - - elseif coupling_type isa NoiseCoupling - parents_field = :noise_parents - children_field = :noise_children - end - - #Add the parent to the child node - push!(getfield(child_node.edges, parents_field), parent_node) - - #Add the child node to the parent node - push!(getfield(parent_node.edges, children_field), child_node) - - #If the coupling type has a coupling strength - if hasproperty(coupling_type, :strength) - #If the user has not specified a coupling strength - if isnothing(coupling_type.strength) - #Use the defaults coupling strength - coupling_strength = node_defaults.coupling_strength - - #Otherwise - else - #Use the specified coupling strength - coupling_strength = coupling_type.strength - end - - #And set it as a parameter for the child - child_node.parameters.coupling_strengths[parent_node.name] = coupling_strength - end - - - #If the enhanced HGF update is the defaults, and if it is a precision coupling (volatility or noise) - if node_defaults.update_type isa EnhancedUpdate && coupling_type isa PrecisionCoupling - #Set the node to use the enhanced HGF update - parent_node.update_type = node_defaults.update_type - end -end diff --git a/src/create_hgf/init_node_edge.jl b/src/create_hgf/init_node_edge.jl new file mode 100644 index 0000000..29c5944 --- /dev/null +++ b/src/create_hgf/init_node_edge.jl @@ -0,0 +1,108 @@ +### Function for initializing a node ### +function init_node(node_info::ContinuousState) + ContinuousStateNode( + name = node_info.name, + parameters = ContinuousStateNodeParameters( + volatility = node_info.volatility, + drift = node_info.drift, + initial_mean = node_info.initial_mean, + initial_precision = node_info.initial_precision, + autoconnection_strength = node_info.autoconnection_strength, + ), + ) +end + +function init_node(node_info::ContinuousInput) + ContinuousInputNode( + name = node_info.name, + parameters = 
ContinuousInputNodeParameters(input_noise = node_info.input_noise), + ) +end + +function init_node(node_info::BinaryState) + BinaryStateNode(name = node_info.name) +end + +function init_node(node_info::BinaryInput) + BinaryInputNode(name = node_info.name) +end + +function init_node(node_info::CategoricalState) + CategoricalStateNode(name = node_info.name) +end + +function init_node(node_info::CategoricalInput) + CategoricalInputNode(name = node_info.name) +end + + +### Function for initializing an edge ### +function init_edge!( + child_node::AbstractNode, + parent_node::AbstractStateNode, + coupling_type::CouplingType, + node_defaults::NodeDefaults, +) + + #Get correct field for storing parents + if coupling_type isa DriftCoupling + parents_field = :drift_parents + children_field = :drift_children + + elseif coupling_type isa ObservationCoupling + parents_field = :observation_parents + children_field = :observation_children + + elseif coupling_type isa CategoryCoupling + parents_field = :category_parents + children_field = :category_children + + elseif coupling_type isa ProbabilityCoupling + parents_field = :probability_parents + children_field = :probability_children + + elseif coupling_type isa VolatilityCoupling + parents_field = :volatility_parents + children_field = :volatility_children + + elseif coupling_type isa NoiseCoupling + parents_field = :noise_parents + children_field = :noise_children + end + + #Add the parent to the child node + push!(getfield(child_node.edges, parents_field), parent_node) + + #Add the child node to the parent node + push!(getfield(parent_node.edges, children_field), child_node) + + #If the coupling type has a coupling strength + if hasproperty(coupling_type, :strength) + #If the user has not specified a coupling strength + if isnothing(coupling_type.strength) + #Use the defaults coupling strength + coupling_strength = node_defaults.coupling_strength + + #Otherwise + else + #Use the specified coupling strength + coupling_strength = 
coupling_type.strength + end + + #And set it as a parameter for the child + child_node.parameters.coupling_strengths[parent_node.name] = coupling_strength + end + + #If the coupling can have a transformation + if hasproperty(coupling_type, :transform) + #Save that transformation in the node + child_node.parameters.coupling_transforms[parent_node.name] = + coupling_type.transform + end + + #If the enhanced HGF update is the defaults, and if it is a precision coupling (volatility or noise) + if node_defaults.update_type isa EnhancedUpdate && coupling_type isa PrecisionCoupling + #Set the node to use the enhanced HGF update + parent_node.update_type = node_defaults.update_type + end +end diff --git a/src/update_hgf/node_updates/continuous_input_node.jl b/src/update_hgf/node_updates/continuous_input_node.jl index 5c6bddc..b418591 100644 --- a/src/update_hgf/node_updates/continuous_input_node.jl +++ b/src/update_hgf/node_updates/continuous_input_node.jl @@ -25,7 +25,7 @@ function calculate_prediction_mean(node::ContinuousInputNode) #Extract parents observation_parents = node.edges.observation_parents - #Initialize prediction at 0 + #Initialize prediction at the bias prediction_mean = node.parameters.bias #Sum the predictions of the parents diff --git a/src/update_hgf/node_updates/continuous_state_node.jl b/src/update_hgf/node_updates/continuous_state_node.jl index 148709e..5876fa3 100644 --- a/src/update_hgf/node_updates/continuous_state_node.jl +++ b/src/update_hgf/node_updates/continuous_state_node.jl @@ -30,16 +30,27 @@ Uses the equation `` \hat{\mu}_i=\mu_i+\sum_{j=1}^{j\;value\;parents} \mu_{j} \cdot \alpha_{i,j} `` """ function calculate_prediction_mean(node::ContinuousStateNode, stepsize::Real) - #Get out value parents + #Get out drift parents drift_parents = node.edges.drift_parents #Initialize the total drift as the baseline drift predicted_drift = node.parameters.drift - #Add contributions from value parents + #For each drift parent for parent in drift_parents 
- predicted_drift += - parent.states.posterior_mean * node.parameters.coupling_strengths[parent.name] + + #Get out the coupling transform + coupling_transform = node.parameters.coupling_transforms[parent.name] + + #Transform the parent's value + drift_increment = transform_parent_value( + parent.states.posterior_mean, + coupling_transform, + derivation_level = 0, + ) + + #Add the drift increment + predicted_drift += drift_increment * node.parameters.coupling_strengths[parent.name] end #Multiply with stepsize @@ -111,7 +122,7 @@ Update the posterior of a single continuous state node. This is the classic HGF """ function update_node_posterior!(node::ContinuousStateNode, update_type::ClassicUpdate) #Update posterior precision - node.states.posterior_precision = calculate_posterior_precision(node) + node.states.posterior_precision = calculate_posterior_precision(node, update_type) #Update posterior mean node.states.posterior_mean = calculate_posterior_mean(node, update_type) @@ -129,7 +140,7 @@ function update_node_posterior!(node::ContinuousStateNode, update_type::Enhanced node.states.posterior_mean = calculate_posterior_mean(node, update_type) #Update posterior precision - node.states.posterior_precision = calculate_posterior_precision(node) + node.states.posterior_precision = calculate_posterior_precision(node, update_type) return nothing end @@ -143,15 +154,22 @@ Calculates a node's posterior precision. 
Uses the equation `` \pi_i^{'} = \hat{\pi}_i +\underbrace{\sum_{j=1}^{j\;children} \alpha_{j,i} \cdot \hat{\pi}_{j}} _\text{sum \;of \;VAPE \; continuous \; value \;chidren} `` """ -function calculate_posterior_precision(node::ContinuousStateNode) +function calculate_posterior_precision( + node::ContinuousStateNode, + update_type::HGFUpdateType, +) #Initialize as the node's own prediction posterior_precision = node.states.prediction_precision #Add update terms from drift children for child in node.edges.drift_children - posterior_precision += - calculate_posterior_precision_increment(node, child, DriftCoupling()) + posterior_precision += calculate_posterior_precision_increment( + node, + child, + DriftCoupling(), + update_type, + ) end #Add update terms from observation children @@ -198,8 +216,27 @@ function calculate_posterior_precision_increment( node::ContinuousStateNode, child::ContinuousStateNode, coupling_type::DriftCoupling, + update_type::HGFUpdateType, ) - child.parameters.coupling_strengths[node.name]^2 * child.states.prediction_precision + #Get out the coupling strength and coupling stransform + coupling_strength = child.parameters.coupling_strengths[node.name] + coupling_transform = child.parameters.coupling_transforms[node.name] + + #Calculate the increment + child.states.prediction_precision * ( + coupling_strength^2 * transform_parent_value( + coupling_transform, + node.states.posterior_mean, + derivation_level = 1, + ) - + coupling_strength * + node.states.value_prediction_error * + transform_parent_value( + coupling_transform, + node.states.posterior_mean, + derivation_level = 2, + ) + ) end @@ -358,6 +395,11 @@ function calculate_posterior_mean_increment( ( ( child.parameters.coupling_strengths[node.name] * + transform_parent_value( + child.parameters.coupling_transforms[node.name], + node.states.posterior_mean, + derivation_level = 1, + ) * child.states.prediction_precision ) / node.states.posterior_precision ) * 
child.states.value_prediction_error @@ -373,6 +415,11 @@ function calculate_posterior_mean_increment( ( ( child.parameters.coupling_strengths[node.name] * + transform_parent_value( + child.parameters.coupling_transforms[node.name], + node.states.posterior_mean, + derivation_level = 1, + ) * child.states.prediction_precision ) / node.states.prediction_precision ) * child.states.value_prediction_error diff --git a/src/update_hgf/nonlinear_transforms.jl b/src/update_hgf/nonlinear_transforms.jl new file mode 100644 index 0000000..2217e17 --- /dev/null +++ b/src/update_hgf/nonlinear_transforms.jl @@ -0,0 +1,43 @@ +#Transformation (of varioius derivations) for a linear transformation +function transform_parent_value( + transform_type::LinearTransform, + parent_value::Real; + derivation_level::Integer, +) + if derivation_level == 0 + return parent_value + elseif derivation_level == 0 + return 1 + elseif derivation_level == 0 + return 0 + else + @error "derivation level is misspecified" + end +end + +#Transformation (of varioius derivations) for a nonlinear transformation +function transform_parent_value( + transform_type::NonlinearTransform, + parent_value::Real, + derivation_level::Integer, +) + + #Get the transformation function that fits the derivation level + if derivation_level == 0 + transform_function = node.parameters.coupling_transforms[parent.name].base_function + elseif derivation_level == 1 + transform_function = + node.parameters.coupling_transforms[parent.name].first_derivation + elseif derivation_level == 2 + transform_function = + node.parameters.coupling_transforms[parent.name].second_derivation + else + @error "derivation level is misspecified" + end + + #Get the transformation parameters + transform_parameters = node.parameters.coupling_transforms[parent.name].parameters + + #Transform the value + return transform_function(parent_value, transform_parameters) +end diff --git a/test/testsuite/test_initialization.jl b/test/testsuite/test_initialization.jl 
index 706e058..56556da 100644 --- a/test/testsuite/test_initialization.jl +++ b/test/testsuite/test_initialization.jl @@ -34,8 +34,8 @@ using Test ("u1", "x1") => ObservationCoupling(), ("u2", "x2") => ObservationCoupling(), ("u2", "x3") => NoiseCoupling(), - ("x1", "x3") => DriftCoupling(2), - ("x1", "x4") => VolatilityCoupling(2), + ("x1", "x3") => DriftCoupling(strength = 2), + ("x1", "x4") => VolatilityCoupling(strength = 2), ("x1", "x5") => VolatilityCoupling(), )