diff --git a/.gitignore b/.gitignore index b02f4d6..5c16191 100644 --- a/.gitignore +++ b/.gitignore @@ -1,11 +1,17 @@ *.jl.*.cov *.jl.cov *.jl.mem -/docs/build/ .DS_Store +.vscode + testing_script*.jl settings.json + Manifest.toml docs/Manifest.toml test/Manifest.toml -/docs/src/generated_markdowns/*.md \ No newline at end of file + +/docs/src/generated +/docs/src/index.md +/docs/build/ +/build \ No newline at end of file diff --git a/Project.toml b/Project.toml index be97bc3..cf96c33 100644 --- a/Project.toml +++ b/Project.toml @@ -4,7 +4,7 @@ authors = [ "Peter Thestrup Waade ptw@cas.au.dk", "Anna Hedvig Møller hedvig.2808@gmail.com", "Jacopo Comoglio jacopo.comoglio@gmail.com", "Christoph Mathys chmathys@cas.au.dk"] -version = "0.3.3" +version = "0.5.0" [deps] ActionModels = "320cf53b-cc3b-4b34-9a10-0ecb113566a3" @@ -12,7 +12,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" [compat] -ActionModels = "0.4" +ActionModels = "0.5" Distributions = "0.25" RecipesBase = "1" julia = "1.9" diff --git a/README.md b/README.md index 5e548de..255aefc 100644 --- a/README.md +++ b/README.md @@ -38,7 +38,7 @@ premade_agent("help") ### Create agent ````@example index -agent = premade_agent("hgf_binary_softmax_action") +agent = premade_agent("hgf_binary_softmax") ```` ### Get states and parameters @@ -55,10 +55,10 @@ get_parameters(agent) ![Image1](docs/src/images/readme/get_parameters.png) -Set a new parameter for initial precision of x2 and define some inputs +Set a new parameter for initial precision of xprob and define some inputs ````@example index -set_parameters!(agent, ("x2", "initial_precision"), 0.9) +set_parameters!(agent, ("xprob", "initial_precision"), 0.9) inputs = [1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0]; nothing #hide ```` @@ -75,23 +75,23 @@ actions = give_inputs!(agent, inputs) using StatsPlots using Plots plot_trajectory(agent, ("u", "input_value")) 
-plot_trajectory!(agent, ("x1", "prediction")) +plot_trajectory!(agent, ("xbin", "prediction")) ```` ![Image1](docs/src/images/readme/plot_trajectory.png) -Plot state trajectory of input value, action and prediction of x1 +Plot state trajectory of input value, action and prediction of xbin ````@example index plot_trajectory(agent, ("u", "input_value")) plot_trajectory!(agent, "action") -plot_trajectory!(agent, ("x1", "prediction")) +plot_trajectory!(agent, ("xbin", "prediction")) ```` ![Image1](docs/src/images/readme/plot_trajectory_2.png) ### Fitting parameters ````@example index using Distributions -prior = Dict(("x2", "volatility") => Normal(1, 0.5)) +prior = Dict(("xprob", "volatility") => Normal(1, 0.5)) model = fit_model(agent, prior, inputs, actions, n_iterations = 20) ```` diff --git a/docs/Project.toml b/docs/Project.toml index 3bfd917..b3a6a54 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -5,10 +5,10 @@ DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +Glob = "c27321d9-0574-5035-807b-f59d2c89b15c" HierarchicalGaussianFiltering = "63d42c3e-681c-42be-892f-a47f35336a79" JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899" Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" -RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd" Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" diff --git a/docs/src/Julia_src_files/index.jl b/docs/julia_files/index.jl similarity index 87% rename from docs/src/Julia_src_files/index.jl rename to docs/julia_files/index.jl index f079a20..6d5a5d8 100644 --- a/docs/src/Julia_src_files/index.jl +++ b/docs/julia_files/index.jl @@ -24,15 +24,15 @@ using ActionModels premade_agent("help") # ### Create agent -agent = premade_agent("hgf_binary_softmax_action") +agent =
premade_agent("hgf_binary_softmax") # ### Get states and parameters get_states(agent) #- get_parameters(agent) -# Set a new parameter for initial precision of x2 and define some inputs -set_parameters!(agent, ("x2", "initial_precision"), 0.9) +# Set a new parameter for initial precision of xprob and define some inputs +set_parameters!(agent, ("xprob", "initial_precision"), 0.9) inputs = [1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0]; # ### Give inputs to the agent @@ -42,18 +42,18 @@ actions = give_inputs!(agent, inputs) using StatsPlots using Plots plot_trajectory(agent, ("u", "input_value")) -plot_trajectory!(agent, ("x1", "prediction")) +plot_trajectory!(agent, ("xbin", "prediction")) -# Plot state trajectory of input value, action and prediction of x1 +# Plot state trajectory of input value, action and prediction of xbin plot_trajectory(agent, ("u", "input_value")) plot_trajectory!(agent, "action") -plot_trajectory!(agent, ("x1", "prediction")) +plot_trajectory!(agent, ("xbin", "prediction")) # ### Fitting parameters using Distributions -prior = Dict(("x2", "volatility") => Normal(1, 0.5)) +prior = Dict(("xprob", "volatility") => Normal(1, 0.5)) model = fit_model(agent, prior, inputs, actions, n_iterations = 20) diff --git a/docs/julia_files/tutorials/classic_JGET.jl b/docs/julia_files/tutorials/classic_JGET.jl new file mode 100644 index 0000000..40e6303 --- /dev/null +++ b/docs/julia_files/tutorials/classic_JGET.jl @@ -0,0 +1,43 @@ +using ActionModels, HierarchicalGaussianFiltering +using CSV, DataFrames +using Plots, StatsPlots +using Distributions + + +# Get the path for the HGF superfolder +hgf_path = dirname(dirname(pathof(HierarchicalGaussianFiltering))) +# Add the path to the data files +data_path = hgf_path * "/docs/julia_files/tutorials/data/" + +#Load data +data = CSV.read(data_path * "classic_cannonball_data.csv", DataFrame) +inputs = data[(data.ID.==20).&(data.session.==1), :].outcome + +#Create HGF +hgf = premade_hgf("JGET",
verbose = false) +#Create agent +agent = premade_agent("hgf_gaussian", hgf) +#Set parameters +parameters = Dict( + "action_noise" => 1, + ("u", "input_noise") => 0, + ("x", "initial_mean") => first(inputs) + 2, + ("x", "initial_precision") => 0.001, + ("x", "volatility") => -8, + ("xvol", "volatility") => -8, + ("xnoise", "volatility") => -7, + ("xnoise_vol", "volatility") => -2, + ("x", "xvol", "coupling_strength") => 1, + ("xnoise", "xnoise_vol", "coupling_strength") => 1, +) +set_parameters!(agent, parameters) +reset!(agent) + +#Simulate updates and actions +actions = give_inputs!(agent, inputs); +#Plot belief trajectories +plot_trajectory(agent, "u") +plot_trajectory!(agent, "x") +plot_trajectory(agent, "xvol") +plot_trajectory(agent, "xnoise") +plot_trajectory(agent, "xnoise_vol") diff --git a/docs/src/tutorials/classic_binary.jl b/docs/julia_files/tutorials/classic_binary.jl similarity index 60% rename from docs/src/tutorials/classic_binary.jl rename to docs/julia_files/tutorials/classic_binary.jl index d29f30d..8be1e9d 100644 --- a/docs/src/tutorials/classic_binary.jl +++ b/docs/julia_files/tutorials/classic_binary.jl @@ -14,7 +14,7 @@ using Distributions # Get the path for the HGF superfolder hgf_path = dirname(dirname(pathof(HierarchicalGaussianFiltering))) # Add the path to the data files -data_path = hgf_path * "/docs/src/tutorials/data/" +data_path = hgf_path * "/docs/julia_files/tutorials/data/" # Load the data inputs = CSV.read(data_path * "classic_binary_inputs.csv", DataFrame)[!, 1]; @@ -23,59 +23,57 @@ inputs = CSV.read(data_path * "classic_binary_inputs.csv", DataFrame)[!, 1]; hgf_parameters = Dict( ("u", "category_means") => Real[0.0, 1.0], ("u", "input_precision") => Inf, - ("x2", "volatility") => -2.5, - ("x2", "initial_mean") => 0, - ("x2", "initial_precision") => 1, - ("x3", "volatility") => -6.0, - ("x3", "initial_mean") => 1, - ("x3", "initial_precision") => 1, - ("x1", "x2", "value_coupling") => 1.0, - ("x2", "x3", "volatility_coupling") 
=> 1.0, + ("xprob", "volatility") => -2.5, + ("xprob", "initial_mean") => 0, + ("xprob", "initial_precision") => 1, + ("xvol", "volatility") => -6.0, + ("xvol", "initial_mean") => 1, + ("xvol", "initial_precision") => 1, + ("xbin", "xprob", "coupling_strength") => 1.0, + ("xprob", "xvol", "coupling_strength") => 1.0, ); hgf = premade_hgf("binary_3level", hgf_parameters, verbose = false); # Create an agent -agent_parameters = Dict("sigmoid_action_precision" => 5); -agent = - premade_agent("hgf_unit_square_sigmoid_action", hgf, agent_parameters, verbose = false); +agent_parameters = Dict("action_noise" => 0.2); +agent = premade_agent("hgf_unit_square_sigmoid", hgf, agent_parameters, verbose = false); # Evolve agent and save actions actions = give_inputs!(agent, inputs); # Plot the trajectory of the agent plot_trajectory(agent, ("u", "input_value")) -plot_trajectory!(agent, ("x1", "prediction")) - +plot_trajectory!(agent, ("xbin", "prediction")) # - -plot_trajectory(agent, ("x2", "posterior")) -plot_trajectory(agent, ("x3", "posterior")) +plot_trajectory(agent, ("xprob", "posterior")) +plot_trajectory(agent, ("xvol", "posterior")) # Set fixed parameters fixed_parameters = Dict( - "sigmoid_action_precision" => 5, + "action_noise" => 0.2, ("u", "category_means") => Real[0.0, 1.0], ("u", "input_precision") => Inf, - ("x2", "initial_mean") => 0, - ("x2", "initial_precision") => 1, - ("x3", "initial_mean") => 1, - ("x3", "initial_precision") => 1, - ("x1", "x2", "value_coupling") => 1.0, - ("x2", "x3", "volatility_coupling") => 1.0, - ("x3", "volatility") => -6.0, + ("xprob", "initial_mean") => 0, + ("xprob", "initial_precision") => 1, + ("xvol", "initial_mean") => 1, + ("xvol", "initial_precision") => 1, + ("xbin", "xprob", "coupling_strength") => 1.0, + ("xprob", "xvol", "coupling_strength") => 1.0, + ("xvol", "volatility") => -6.0, ); # Set priors for parameter recovery -param_priors = Dict(("x2", "volatility") => Normal(-3.0, 0.5)); +param_priors = Dict(("xprob", 
"volatility") => Normal(-3.0, 0.5)); #- # Prior predictive plot plot_predictive_simulation( param_priors, agent, inputs, - ("x1", "prediction_mean"), + ("xbin", "prediction_mean"), n_simulations = 100, ) #- @@ -104,6 +102,6 @@ plot_predictive_simulation( fitted_model, agent, inputs, - ("x1", "prediction_mean"), + ("xbin", "prediction_mean"), n_simulations = 3, ) diff --git a/docs/src/tutorials/classic_usdchf.jl b/docs/julia_files/tutorials/classic_usdchf.jl similarity index 68% rename from docs/src/tutorials/classic_usdchf.jl rename to docs/julia_files/tutorials/classic_usdchf.jl index bb305a6..676fea2 100644 --- a/docs/src/tutorials/classic_usdchf.jl +++ b/docs/julia_files/tutorials/classic_usdchf.jl @@ -12,7 +12,7 @@ using Distributions # Get the path for the HGF superfolder hgf_path = dirname(dirname(pathof(HierarchicalGaussianFiltering))) # Add the path to the data files -data_path = hgf_path * "/docs/src/tutorials/data/" +data_path = hgf_path * "/docs/julia_files/tutorials/data/" # Load the data inputs = Float64[] @@ -24,20 +24,19 @@ end #Create HGF hgf = premade_hgf("continuous_2level", verbose = false); -agent = premade_agent("hgf_gaussian_action", hgf, verbose = false); +agent = premade_agent("hgf_gaussian", hgf, verbose = false); # Set parameters for parameter recover parameters = Dict( - ("u", "x1", "value_coupling") => 1.0, - ("x1", "x2", "volatility_coupling") => 1.0, + ("x", "xvol", "coupling_strength") => 1.0, ("u", "input_noise") => -log(1e4), - ("x1", "volatility") => -13, - ("x2", "volatility") => -2, - ("x1", "initial_mean") => 1.04, - ("x1", "initial_precision") => 1 / (0.0001), - ("x2", "initial_mean") => 1.0, - ("x2", "initial_precision") => 1 / 0.1, - "gaussian_action_precision" => 100, + ("x", "volatility") => -13, + ("xvol", "volatility") => -2, + ("x", "initial_mean") => 1.04, + ("x", "initial_precision") => 1 / (0.0001), + ("xvol", "initial_mean") => 1.0, + ("xvol", "initial_precision") => 1 / 0.1, + "action_noise" => 0.01, ); 
set_parameters!(agent, parameters) @@ -59,7 +58,7 @@ plot_trajectory( xlabel = "Trading days since 1 January 2010", ) #- -plot_trajectory!(agent, ("x1", "posterior"), color = "red") +plot_trajectory!(agent, ("x", "posterior"), color = "red") plot_trajectory!( agent, "action", @@ -71,7 +70,7 @@ plot_trajectory!( #- plot_trajectory( agent, - "x2", + "xvol", color = "blue", size = (1300, 500), xlims = (0, 615), @@ -81,19 +80,18 @@ plot_trajectory( #- # Set priors for fitting fixed_parameters = Dict( - ("u", "x1", "value_coupling") => 1.0, - ("x1", "x2", "volatility_coupling") => 1.0, - ("x1", "initial_mean") => 0, - ("x1", "initial_precision") => 2000, - ("x2", "initial_mean") => 1.0, - ("x2", "initial_precision") => 600.0, - "gaussian_action_precision" => 100, + ("x", "xvol", "coupling_strength") => 1.0, + ("x", "initial_mean") => 0, + ("x", "initial_precision") => 2000, + ("xvol", "initial_mean") => 1.0, + ("xvol", "initial_precision") => 600.0, ); param_priors = Dict( ("u", "input_noise") => Normal(-6, 1), - ("x1", "volatility") => Normal(-4, 1), - ("x2", "volatility") => Normal(-4, 1), + ("x", "volatility") => Normal(-4, 1), + ("xvol", "volatility") => Normal(-4, 1), + "action_noise" => LogNormal(log(0.01), 1), ); #- # Prior predictive simulation plot @@ -101,7 +99,7 @@ plot_predictive_simulation( param_priors, agent, inputs, - ("x1", "posterior_mean"); + ("x", "posterior_mean"); n_simulations = 100, ) #- diff --git a/docs/src/tutorials/data/classic_binary_actions.csv b/docs/julia_files/tutorials/data/classic_binary_actions.csv similarity index 100% rename from docs/src/tutorials/data/classic_binary_actions.csv rename to docs/julia_files/tutorials/data/classic_binary_actions.csv diff --git a/docs/src/tutorials/data/classic_binary_inputs.csv b/docs/julia_files/tutorials/data/classic_binary_inputs.csv similarity index 100% rename from docs/src/tutorials/data/classic_binary_inputs.csv rename to docs/julia_files/tutorials/data/classic_binary_inputs.csv diff --git 
a/docs/src/tutorials/data/classic_cannonball_data.csv b/docs/julia_files/tutorials/data/classic_cannonball_data.csv similarity index 100% rename from docs/src/tutorials/data/classic_cannonball_data.csv rename to docs/julia_files/tutorials/data/classic_cannonball_data.csv diff --git a/docs/src/tutorials/data/classic_usdchf_inputs.dat b/docs/julia_files/tutorials/data/classic_usdchf_inputs.dat similarity index 100% rename from docs/src/tutorials/data/classic_usdchf_inputs.dat rename to docs/julia_files/tutorials/data/classic_usdchf_inputs.dat diff --git a/docs/src/Julia_src_files/all_functions.jl b/docs/julia_files/user_guide/all_functions.jl similarity index 100% rename from docs/src/Julia_src_files/all_functions.jl rename to docs/julia_files/user_guide/all_functions.jl diff --git a/docs/src/Julia_src_files/building_an_HGF.jl b/docs/julia_files/user_guide/building_an_HGF.jl similarity index 78% rename from docs/src/Julia_src_files/building_an_HGF.jl rename to docs/julia_files/user_guide/building_an_HGF.jl index 7f1b0d7..0329070 100644 --- a/docs/src/Julia_src_files/building_an_HGF.jl +++ b/docs/julia_files/user_guide/building_an_HGF.jl @@ -17,12 +17,16 @@ # We can recall from the HGF nodes, that a binary input node's parameters are category means and input precision. We will set category means to [0,1] and the input precision to Inf. -input_nodes = Dict( - "name" => "Input_node", - "type" => "binary", - "category_means" => [0, 1], - "input_precision" => Inf, -); +nodes = [ + BinaryInput("Input_node"), + BinaryState("binary_state_node"), + ContinuousState( + name = "continuous_state_node", + volatility = -2, + initial_mean = 0, + initial_precision = 1, + ), +] # ## Defining State Nodes @@ -30,43 +34,23 @@ input_nodes = Dict( # The continuous state node have evolution rate, initial mean and initial precision parameters which we specify as well. 
-state_nodes = [ - ## Configuring the first binary state node - Dict("name" => "binary_state_node", "type" => "binary"), - ## Configuring the continuous state node - Dict( - "name" => "continuous_state_node", - "type" => "continuous", - "volatility" => -2, - "initial_mean" => 0, - "initial_precision" => 1, - ), -]; - # ## Defining Edges # When defining the edges we start by sepcifying which node the perspective is from. So, when we specify the edges we start by specifying what the child in the relation is. # At the buttom of our hierarchy we have the binary input node. The Input node has binary state node as parent. -edges = [ - Dict("child" => "Input_node", "value_parents" => "binary_state_node"), - - ## The next relation is from the point of view of the binary state node. We specify out continous state node as parent with the value coupling as 1. - Dict("child" => "binary_state_node", "value_parents" => ("continuous_state_node", 1)), -]; +edges = Dict( + ("Input_node", "binary_state_node") => ObservationCoupling(), + ("binary_state_node", "continuous_state_node") => ProbabilityCoupling(1), +); # We are ready to initialize our HGF now. using HierarchicalGaussianFiltering using ActionModels -Binary_2_level_hgf = init_hgf( - input_nodes = input_nodes, - state_nodes = state_nodes, - edges = edges, - verbose = false, -); +Binary_2_level_hgf = init_hgf(nodes = nodes, edges = edges, verbose = false); # We can access the states in our HGF: get_states(Binary_2_level_hgf) #- @@ -84,7 +68,7 @@ get_parameters(Binary_2_level_hgf) # We initialize the action model and create it. In a softmax action model we need a parameter from the agent called softmax action precision which is used in the update step of the action model. 
using Distributions -function binary_softmax_action(agent, input) +function binary_softmax(agent, input) ##------- Staty by getting all information --------- @@ -96,7 +80,7 @@ function binary_softmax_action(agent, input) target_state = agent.settings["target_state"] ##Take out the parameter from our agent - action_precision = agent.parameters["softmax_action_precision"] + action_noise = agent.parameters["action_noise"] ##Get the specified state out of the hgf target_value = get_states(hgf, target_state) @@ -104,7 +88,7 @@ ##--------------- Update step starts ----------------- ##Use sotmax to get the action probability - action_probability = 1 / (1 + exp(-action_precision * target_value)) + action_probability = 1 / (1 + exp(-target_value / action_noise)) ##---------------- Update step end ------------------ ##If the action probability is not between 0 and 1 @@ -133,11 +117,11 @@ end # Let's define our action model -action_model = binary_softmax_action; +action_model = binary_softmax; # The parameter of the agent is just softmax action precision. We set this value to 1 -parameters = Dict("softmax_action_precision" => 1); +parameters = Dict("action_noise" => 1); # The states of the agent are empty, but the states from the HGF will be accessible. @@ -145,10 +129,7 @@ states = Dict() # In the settings we specify what our target state is. We want it to be the prediction mean of our binary state node.
-settings = Dict( - "hgf_actions" => "softmax_action", - "target_state" => ("binary_state_node", "prediction_mean"), -); +settings = Dict("target_state" => ("binary_state_node", "prediction_mean")); ## Let's initialize our agent agent = init_agent( diff --git a/docs/src/Julia_src_files/fitting_hgf_models.jl b/docs/julia_files/user_guide/fitting_hgf_models.jl similarity index 80% rename from docs/src/Julia_src_files/fitting_hgf_models.jl rename to docs/julia_files/user_guide/fitting_hgf_models.jl index 3970a92..d72a56f 100644 --- a/docs/src/Julia_src_files/fitting_hgf_models.jl +++ b/docs/julia_files/user_guide/fitting_hgf_models.jl @@ -47,22 +47,21 @@ using HierarchicalGaussianFiltering hgf_parameters = Dict( ("u", "category_means") => Real[0.0, 1.0], ("u", "input_precision") => Inf, - ("x2", "volatility") => -2.5, - ("x2", "initial_mean") => 0, - ("x2", "initial_precision") => 1, - ("x3", "volatility") => -6.0, - ("x3", "initial_mean") => 1, - ("x3", "initial_precision") => 1, - ("x1", "x2", "value_coupling") => 1.0, - ("x2", "x3", "volatility_coupling") => 1.0, + ("xprob", "volatility") => -2.5, + ("xprob", "initial_mean") => 0, + ("xprob", "initial_precision") => 1, + ("xvol", "volatility") => -6.0, + ("xvol", "initial_mean") => 1, + ("xvol", "initial_precision") => 1, + ("xbin", "xprob", "coupling_strength") => 1.0, + ("xprob", "xvol", "coupling_strength") => 1.0, ) hgf = premade_hgf("binary_3level", hgf_parameters, verbose = false) # Create an agent -agent_parameters = Dict("sigmoid_action_precision" => 5); -agent = - premade_agent("hgf_unit_square_sigmoid_action", hgf, agent_parameters, verbose = false); +agent_parameters = Dict("action_noise" => 0.2); +agent = premade_agent("hgf_unit_square_sigmoid", hgf, agent_parameters, verbose = false); # Define a set of inputs inputs = @@ -76,7 +75,7 @@ actions = give_inputs!(agent, inputs) using StatsPlots using Plots plot_trajectory(agent, ("u", "input_value")) -plot_trajectory!(agent, ("x1", "prediction")) 
+plot_trajectory!(agent, ("xbin", "prediction")) @@ -84,23 +83,23 @@ plot_trajectory!(agent, ("x1", "prediction")) # We define a set of fixed parameters to use in this fitting process: -# Set fixed parameters. We choose to fit the evolution rate of the x2 node. +# Set fixed parameters. We choose to fit the evolution rate of the xprob node. fixed_parameters = Dict( - "sigmoid_action_precision" => 5, + "action_noise" => 0.2, ("u", "category_means") => Real[0.0, 1.0], ("u", "input_precision") => Inf, - ("x2", "initial_mean") => 0, - ("x2", "initial_precision") => 1, - ("x3", "initial_mean") => 1, - ("x3", "initial_precision") => 1, - ("x1", "x2", "value_coupling") => 1.0, - ("x2", "x3", "volatility_coupling") => 1.0, - ("x3", "volatility") => -6.0, + ("xprob", "initial_mean") => 0, + ("xprob", "initial_precision") => 1, + ("xvol", "initial_mean") => 1, + ("xvol", "initial_precision") => 1, + ("xbin", "xprob", "coupling_strength") => 1.0, + ("xprob", "xvol", "coupling_strength") => 1.0, + ("xvol", "volatility") => -6.0, ); -# As you can read from the fixed parameters, the evolution rate of x2 is not configured. We set the prior for the x2 evolution rate: +# As you can read from the fixed parameters, the evolution rate of xprob is not configured. We set the prior for the xprob evolution rate: using Distributions -param_priors = Dict(("x2", "volatility") => Normal(-3.0, 0.5)); +param_priors = Dict(("xprob", "volatility") => Normal(-3.0, 0.5)); # We can fit the evolution rate by inputting the variables: @@ -114,9 +113,9 @@ fitted_model = fit_model( verbose = true, n_iterations = 10, ) +set_parameters!(agent, hgf_parameters) # ## Plotting Functions - plot(fitted_model) # Plot the posterior @@ -132,7 +131,7 @@ plot_parameter_distribution(fitted_model, param_priors) # We will provide a code example of prior and posterior predictive simulation. We can fit a different parameter, and start with a prior predictive check. 
# Set prior we wish to simulate over -param_priors = Dict(("x3", "initial_precision") => Normal(1.0, 0.5)); +param_priors = Dict(("xvol", "initial_precision") => Normal(1.0, 0.5)); # When we look at our predictive simulation plot we should aim to see actions in the plausible space they could be in. # Prior predictive plot @@ -140,7 +139,7 @@ plot_predictive_simulation( param_priors, agent, inputs, - ("x1", "prediction_mean"), + ("xbin", "prediction_mean"), n_simulations = 100, ) @@ -149,20 +148,18 @@ plot_predictive_simulation( # Fit the actions where we use the default parameter values from the HGF. fitted_model = fit_model(agent, param_priors, inputs, actions, verbose = true, n_iterations = 10) - +set_parameters!(agent, hgf_parameters) # We can place our turing chain as a our posterior in the function, and get our posterior predictive simulation plot: - plot_predictive_simulation( fitted_model, agent, inputs, - ("x1", "prediction_mean"), + ("xbin", "prediction_mean"), n_simulations = 100, ) # We can get the posterior - get_posteriors(fitted_model) # plot the chains diff --git a/docs/src/Julia_src_files/premade_HGF.jl b/docs/julia_files/user_guide/premade_HGF.jl similarity index 86% rename from docs/src/Julia_src_files/premade_HGF.jl rename to docs/julia_files/user_guide/premade_HGF.jl index 8a660c0..5e29ea8 100644 --- a/docs/src/Julia_src_files/premade_HGF.jl +++ b/docs/julia_files/user_guide/premade_HGF.jl @@ -19,7 +19,9 @@ using DataFrames #hide using Plots #hide using StatsPlots #hide -hgf_path_continuous = dirname(dirname(pathof(HierarchicalGaussianFiltering))); #hide +#CSV.read(pwd(), DataFrame) + +hgf_path_continuous = dirname(dirname(pathof(HierarchicalGaussianFiltering))); #hide hgf_path_continuous = hgf_path_continuous * "/docs/src/tutorials/data/"; #hide inputs_continuous = Float64[]; #hide @@ -55,13 +57,13 @@ inputs_binary = CSV.read(hgf_path_binary * "classic_binary_inputs.csv", DataFram #Create HGF and Agent continuous_2_level =
premade_hgf("continuous_2level"); agent_continuous_2_level = - premade_agent("hgf_gaussian_action", continuous_2_level, verbose = false); + premade_agent("hgf_gaussian", continuous_2_level, verbose = false); # Evolve agent plot trajetories give_inputs!(agent_continuous_2_level, inputs_continuous); plot_trajectory( agent_continuous_2_level, - "x2", + "xvol", color = "blue", size = (1300, 500), xlims = (0, 615), @@ -82,13 +84,13 @@ plot_trajectory( #Create HGF and Agent JGET = premade_hgf("JGET"); -agent_JGET = premade_agent("hgf_gaussian_action", JGET, verbose = false); +agent_JGET = premade_agent("hgf_gaussian", JGET, verbose = false); # Evolve agent plot trajetories give_inputs!(agent_JGET, inputs_continuous); plot_trajectory( agent_JGET, - "x2", + "xvol", color = "blue", size = (1300, 500), xlims = (0, 615), @@ -110,17 +112,17 @@ hgf_binary_2_level = premade_hgf("binary_2level", verbose = false); # Create an agent agent_binary_2_level = - premade_agent("hgf_unit_square_sigmoid_action", hgf_binary_2_level, verbose = false); + premade_agent("hgf_unit_square_sigmoid", hgf_binary_2_level, verbose = false); # Evolve agent plot trajetories give_inputs!(agent_binary_2_level, inputs_binary); plot_trajectory(agent_binary_2_level, ("u", "input_value")) -plot_trajectory!(agent_binary_2_level, ("x1", "prediction")) +plot_trajectory!(agent_binary_2_level, ("xbin", "prediction")) #- -plot_trajectory(agent_binary_2_level, ("x2", "posterior")) +plot_trajectory(agent_binary_2_level, ("xprob", "posterior")) # ## Binary 3-level HGF @@ -134,18 +136,18 @@ hgf_binary_3_level = premade_hgf("binary_3level", verbose = false); # Create an agent agent_binary_3_level = - premade_agent("hgf_unit_square_sigmoid_action", hgf_binary_3_level, verbose = false); + premade_agent("hgf_unit_square_sigmoid", hgf_binary_3_level, verbose = false); # Evolve agent plot trajetories give_inputs!(agent_binary_3_level, inputs_binary); plot_trajectory(agent_binary_3_level, ("u", "input_value")) 
-plot_trajectory!(agent_binary_3_level, ("x1", "prediction")) +plot_trajectory!(agent_binary_3_level, ("xbin", "prediction")) #- -plot_trajectory(agent_binary_3_level, ("x2", "posterior")) +plot_trajectory(agent_binary_3_level, ("xprob", "posterior")) #- -plot_trajectory(agent_binary_3_level, ("x3", "posterior")) +plot_trajectory(agent_binary_3_level, ("xvol", "posterior")) diff --git a/docs/src/Julia_src_files/premade_models.jl b/docs/julia_files/user_guide/premade_models.jl similarity index 77% rename from docs/src/Julia_src_files/premade_models.jl rename to docs/julia_files/user_guide/premade_models.jl index 54f4222..ec6351c 100644 --- a/docs/src/Julia_src_files/premade_models.jl +++ b/docs/julia_files/user_guide/premade_models.jl @@ -12,10 +12,10 @@ # ## HGF with Gaussian Action Noise agent -# This premade agent model can be found as "hgf_gaussian_action" in the package. The Action distribution is a gaussian distribution with mean of the target state from the chosen HGF, and the standard deviation consisting of the action precision parameter inversed. #md +# This premade agent model can be found as "hgf_gaussian" in the package. The Action distribution is a gaussian distribution with mean of the target state from the chosen HGF, and the standard deviation consisting of the action precision parameter inversed. #md # - Default hgf: contionus_2level -# - Default Target state: (x1, posterior mean) +# - Default Target state: (x, posterior mean) # - Default Parameters: gaussian action precision = 1 # ## HGF Binary Softmax agent @@ -23,7 +23,7 @@ # The action distribution is a Bernoulli distribution, and the parameter is action probability. Action probability is calculated using a softmax on the action precision parameter and the target value from the HGF. 
#md # - Default hgf: binary_3level -# - Default target state; (x1, prediction mean) +# - Default target state; (xbin, prediction mean) # - Default parameters: softmax action precision = 1 # ## HGF unit square sigmoid agent @@ -31,7 +31,7 @@ # The action distribution is Bernoulli distribution with the parameter beinga a softmax of the target value and action precision. # - Default hgf: binary_3level -# - Default target state; (x1, prediction mean) +# - Default target state; (xbin, prediction mean) # - Default parameters: softmax action precision = 1 @@ -40,7 +40,7 @@ # The action distribution is a categorical distribution. The action model takes the target node from the HGF, and takes out the prediction state. This state is a vector of values for each category. The vector is the only thing used in the categorical distribution # - Default hgf: categorical_3level -# - Default target state: Target categorical node x1 +# - Default target state: Target categorical node xcat # - Default parameters: none # ## Using premade agents @@ -53,7 +53,7 @@ using HierarchicalGaussianFiltering premade_agent("help") # Define an agent with default parameter values and default HGF -agent = premade_agent("hgf_binary_softmax_action") +agent = premade_agent("hgf_binary_softmax") # ## Utility functions for accessing parameters and states @@ -61,21 +61,21 @@ agent = premade_agent("hgf_binary_softmax_action") get_parameters(agent) # Get specific parameter in agent: -get_parameters(agent, ("x3", "initial_precision")) +get_parameters(agent, ("xvol", "initial_precision")) # Get all states in an agent: get_states(agent) # Get specific state in an agent: -get_states(agent, ("x1", "posterior_precision")) +get_states(agent, ("xbin", "posterior_precision")) # Set a parameter value -set_parameters!(agent, ("x3", "initial_precision"), 0.4) +set_parameters!(agent, ("xvol", "initial_precision"), 0.4) # Set multiple parameter values set_parameters!( agent, - Dict(("x3", "initial_precision") => 1, ("x3", 
"volatility") => 0), + Dict(("xvol", "initial_precision") => 1, ("xvol", "volatility") => 0), ) @@ -88,7 +88,7 @@ input = [1, 0, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 0] actions = give_inputs!(agent, input) # Get the history of a single state in the agent -get_history(agent, ("x1", "prediction_mean")) +get_history(agent, ("xbin", "prediction_mean")) # We can plot the input and prediciton means with plot trajectory. Notice, when using plot_trajectory!() you can layer plots. @@ -98,7 +98,7 @@ using Plots plot_trajectory(agent, ("u", "input_value")) # Let's add prediction mean on top of the plot -plot_trajectory!(agent, ("x1", "prediction_mean")) +plot_trajectory!(agent, ("xbin", "prediction_mean")) # ### Overview of functions diff --git a/docs/src/Julia_src_files/the_HGF_nodes.jl b/docs/julia_files/user_guide/the_HGF_nodes.jl similarity index 100% rename from docs/src/Julia_src_files/the_HGF_nodes.jl rename to docs/julia_files/user_guide/the_HGF_nodes.jl diff --git a/docs/src/Julia_src_files/updating_the_HGF.jl b/docs/julia_files/user_guide/updating_the_HGF.jl similarity index 100% rename from docs/src/Julia_src_files/updating_the_HGF.jl rename to docs/julia_files/user_guide/updating_the_HGF.jl diff --git a/docs/src/Julia_src_files/utility_functions.jl b/docs/julia_files/user_guide/utility_functions.jl similarity index 69% rename from docs/src/Julia_src_files/utility_functions.jl rename to docs/julia_files/user_guide/utility_functions.jl index bc5324f..cdd278c 100644 --- a/docs/src/Julia_src_files/utility_functions.jl +++ b/docs/julia_files/user_guide/utility_functions.jl @@ -20,7 +20,7 @@ using HierarchicalGaussianFiltering premade_agent("help") # set agent -agent = premade_agent("hgf_binary_softmax_action") +agent = premade_agent("hgf_binary_softmax") # ### Getting Parameters @@ -31,10 +31,10 @@ agent = premade_agent("hgf_binary_softmax_action") get_parameters(agent) # getting couplings -# ERROR WITH THIS get_parameters(agent, ("x2", "x3", "volatility_coupling")) 
+get_parameters(agent, ("xprob", "xvol", "coupling_strength")) # getting multiple parameters specify them in a vector -get_parameters(agent, [("x3", "volatility"), ("x3", "initial_precision")]) +get_parameters(agent, [("xvol", "volatility"), ("xvol", "initial_precision")]) # ### Getting States @@ -43,10 +43,13 @@ get_parameters(agent, [("x3", "volatility"), ("x3", "initial_precision")]) get_states(agent) #getting a single state -get_states(agent, ("x2", "posterior_precision")) +get_states(agent, ("xprob", "posterior_precision")) #getting multiple states -get_states(agent, [("x2", "posterior_precision"), ("x2", "volatility_weighted_prediction_precision")]) +get_states( + agent, + [("xprob", "posterior_precision"), ("xprob", "effective_prediction_precision")], +) # ### Setting Parameters @@ -54,38 +57,38 @@ get_states(agent, [("x2", "posterior_precision"), ("x2", "volatility_weighted_pr # you can set parameters before you initialize your agent, you can set them after and change them when you wish to. # Let's try an initialize a new agent with parameters. We start by choosing the premade unit square sigmoid action agent whose parameter is sigmoid action precision. 
-agent_parameter = Dict("sigmoid_action_precision" => 3) +agent_parameter = Dict("action_noise" => 0.3) #We also specify our HGF and custom parameter settings: hgf_parameters = Dict( ("u", "category_means") => Real[0.0, 1.0], ("u", "input_precision") => Inf, - ("x2", "volatility") => -2.5, - ("x2", "initial_mean") => 0, - ("x2", "initial_precision") => 1, - ("x3", "volatility") => -6.0, - ("x3", "initial_mean") => 1, - ("x3", "initial_precision") => 1, - ("x1", "x2", "value_coupling") => 1.0, - ("x2", "x3", "volatility_coupling") => 1.0, + ("xprob", "volatility") => -2.5, + ("xprob", "initial_mean") => 0, + ("xprob", "initial_precision") => 1, + ("xvol", "volatility") => -6.0, + ("xvol", "initial_mean") => 1, + ("xvol", "initial_precision") => 1, + ("xbin", "xprob", "coupling_strength") => 1.0, + ("xprob", "xvol", "coupling_strength") => 1.0, ) hgf = premade_hgf("binary_3level", hgf_parameters) # Define our agent with the HGF and agent parameter settings -agent = premade_agent("hgf_unit_square_sigmoid_action", hgf, agent_parameter) +agent = premade_agent("hgf_unit_square_sigmoid", hgf, agent_parameter) # Changing a single parameter -set_parameters!(agent, ("x3", "initial_precision"), 4) +set_parameters!(agent, ("xvol", "initial_precision"), 4) # Changing multiple parameters set_parameters!( agent, - Dict(("x3", "initial_precision") => 5, ("x1", "x2", "value_coupling") => 2.0), + Dict(("xvol", "initial_precision") => 5, ("xbin", "xprob", "coupling_strength") => 2.0), ) # ###Giving Inputs @@ -143,12 +146,12 @@ get_history(agent) #- # getting history of single state -get_history(agent, ("x3", "posterior_precision")) +get_history(agent, ("xvol", "posterior_precision")) #- # getting history of multiple states: -get_history(agent, [("x1", "prediction_mean"), ("x3", "posterior_precision")]) +get_history(agent, [("xbin", "prediction_mean"), ("xvol", "posterior_precision")]) # ### Plotting State Trajectories @@ -158,29 +161,26 @@ using Plots plot_trajectory(agent, ("u", 
"input_value")) #Adding state trajectory on top -plot_trajectory!(agent, ("x1", "prediction")) +plot_trajectory!(agent, ("xbin", "prediction")) # Plotting more individual states: -## Plot posterior of x2 -plot_trajectory(agent, ("x2", "posterior")) +## Plot posterior of xprob +plot_trajectory(agent, ("xprob", "posterior")) #- -## Plot posterior of x3 -plot_trajectory(agent, ("x3", "posterior")) +## Plot posterior of xvol +plot_trajectory(agent, ("xvol", "posterior")) # ### Getting Predictions -# You can specify an HGF or an agent in the funciton. The default node to extract is the node "x1" which is the first level node in every premade HGF structure. - -# get prediction of the last state -get_prediction(agent) +# You can specify an HGF or an agent in the funciton. #specify another node to get predictions from: -get_prediction(agent, "x2") +get_prediction(agent, "xprob") # ### Getting Purprise diff --git a/docs/make.jl b/docs/make.jl index f6bd8a8..a05db21 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -2,43 +2,60 @@ using HierarchicalGaussianFiltering using Documenter using Literate + +hgf_path = dirname(dirname(pathof(HierarchicalGaussianFiltering))) + +juliafiles_path = hgf_path * "/docs/julia_files" +user_guides_path = juliafiles_path * "/user_guide" +tutorials_path = juliafiles_path * "/tutorials" + +markdown_src_path = hgf_path * "/docs/src" +theory_path = markdown_src_path * "/theory" +generated_user_guide_path = markdown_src_path * "/generated/user_guide" +generated_tutorials_path = markdown_src_path * "/generated/tutorials" + + #Remove old tutorial markdown files -for filename in readdir("docs/src/generated_markdowns") - rm("docs/src/generated_markdowns/" * filename) +for filename in readdir(generated_user_guide_path) + if endswith(filename, ".md") + rm(generated_user_guide_path * "/" * filename) + end end -rm("docs/src/index.md") -#Generate new markdown files from the documentation source files -for filename in readdir("docs/src/Julia_src_files") - if 
endswith(filename, ".jl") +for filename in readdir(generated_tutorials_path) + if endswith(filename, ".md") + rm(generated_tutorials_path * "/" * filename) + end +end +rm(markdown_src_path * "/" * "index.md") - #Place the index file in another folder than the rest of the documentation - if startswith(filename, "index") - Literate.markdown( - "docs/src/Julia_src_files/" * filename, - "docs/src", - documenter = true, - ) - else - Literate.markdown( - "docs/src/Julia_src_files/" * filename, - "docs/src/generated_markdowns", - documenter = true, - ) - end +#Generate index markdown file +Literate.markdown(juliafiles_path * "/" * "index.jl", markdown_src_path, documenter = true) + +#Generate markdown files for user guide +for filename in readdir(user_guides_path) + if endswith(filename, ".jl") + Literate.markdown( + user_guides_path * "/" * filename, + generated_user_guide_path, + documenter = true, + ) end end -#Generate new tutorial markdown files from the tutorials -for filename in readdir("docs/src/tutorials") +#Generate markdown files for tutorials +for filename in readdir(tutorials_path) if endswith(filename, ".jl") Literate.markdown( - "docs/src/tutorials/" * filename, - "docs/src/generated_markdowns", + tutorials_path * "/" * filename, + generated_tutorials_path, documenter = true, ) end end + + + #Set documenter metadata DocMeta.setdocmeta!( HierarchicalGaussianFiltering, @@ -61,28 +78,30 @@ makedocs(; ), pages = [ "Introduction to Hierarchical Gaussian Filtering" => "./index.md", - "Theory" => [ - "./theory/genmodel.md", - "./theory/node.md", - "./theory/vape.md", - "./theory/vope.md", - ], - "Using the package" => [ - "The HGF Nodes" => "./generated_markdowns/the_HGF_nodes.md", - "Building an HGF" => "./generated_markdowns/building_an_HGF.md", - "Updating the HGF" => "./generated_markdowns/updating_the_HGF.md", - "List Of Premade Agent Models" => "./generated_markdowns/premade_models.md", - "List Of Premade HGF's" => "./generated_markdowns/premade_HGF.md", - 
"Fitting an HGF-agent model to data" => "./generated_markdowns/fitting_hgf_models.md", - "Utility Functions" => "./generated_markdowns/utility_functions.md", - ], - "Tutorials" => [ - "classic binary" => "./generated_markdowns/classic_binary.md", - "classic continouous" => "./generated_markdowns/classic_usdchf.md", - ], - "All Functions" => "./generated_markdowns/all_functions.md", + # "Theory" => [ + # "./theory" * "/genmodel.md", + # "./theory" * "/node.md", + # "./theory" * "/vape.md", + # "./theory" * "/vope.md", + # ], + # "Using the package" => [ + # "The HGF Nodes" => "./generated/user_guide" * "/the_HGF_nodes.md", + # "Building an HGF" => "./generated/user_guide" * "/building_an_HGF.md", + # "Updating the HGF" => "./generated/user_guide" * "/updating_the_HGF.md", + # "List Of Premade Agent Models" => "./generated/user_guide" * "/premade_models.md", + # "List Of Premade HGF's" => "./generated/user_guide" * "/premade_HGF.md", + # "Fitting an HGF-agent model to data" => "./generated/user_guide" * "/fitting_hgf_models.md", + # "Utility Functions" => "./generated/user_guide" * "/utility_functions.md", + # ], + # "Tutorials" => [ + # "classic binary" => "./generated/tutorials" * "/classic_binary.md", + # "classic continouous" => "./generated/tutorials" * "/classic_usdchf.md", + # "classic JGET" => "./generated/tutorials" * "/classic_JGET.md", + # ], + # "All Functions" => "./generated/user_guide" * "/all_functions.md", ], ) + deploydocs(; repo = "github.com/ilabcode/HierarchicalGaussianFiltering.jl", devbranch = "main", diff --git a/docs/src/index.md b/docs/src/index.md deleted file mode 100644 index c7e3555..0000000 --- a/docs/src/index.md +++ /dev/null @@ -1,113 +0,0 @@ -```@meta -EditURL = "/docs/src/Julia_src_files/index.jl" -``` - -# Welcome to The Hierarchical Gaussian Filtering Package! - -Hierarchical Gaussian Filtering (HGF) is a novel and adaptive package for doing cognitive and behavioral modelling. 
With the HGF you can fit time series data fit participant-level individual parameters, measure group differences based on model-specific parameters or use the model for any time series with underlying change in uncertainty. - -The HGF consists of a network of probabilistic nodes hierarchically structured. The hierarchy is determined by the coupling between nodes. A node (child node) in the network can inheret either its value or volatility sufficient statistics from a node higher in the hierarchy (a parent node). - -The presentation of a new observation at the lower level of the hierarchy (i.e. the input node) trigger a recursuve update of the nodes belief throught the bottom-up propagation of precision-weigthed prediction error. - -The HGF will be explained in more detail in the theory section of the documentation - -It is also recommended to check out the ActionModels.jl pacakge for stronger intuition behind the use of agents and action models. - -## Getting started - -The last official release can be downloaded from Julia with "] add HierarchicalGaussianFiltering" - -We provide a script for getting started with commonly used functions and use cases - -Load packages - -````@example index -using HierarchicalGaussianFiltering -using ActionModels -```` - -### Get premade agent - -````@example index -premade_agent("help") -```` - -### Create agent - -````@example index -agent = premade_agent("hgf_binary_softmax_action") -```` - -### Get states and parameters - -````@example index -get_states(agent) -```` - -````@example index -get_parameters(agent) -```` - -Set a new parameter for initial precision of x2 and define some inputs - -````@example index -set_parameters!(agent, ("x2", "initial_precision"), 0.9) -inputs = [1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0]; -nothing #hide -```` - -### Give inputs to the agent - -````@example index -actions = give_inputs!(agent, inputs) -```` - -### Plot state trajectories of input and prediction - 
-````@example index -using StatsPlots -using Plots -plot_trajectory(agent, ("u", "input_value")) -plot_trajectory!(agent, ("x1", "prediction")) -```` - -Plot state trajectory of input value, action and prediction of x1 - -````@example index -plot_trajectory(agent, ("u", "input_value")) -plot_trajectory!(agent, "action") -plot_trajectory!(agent, ("x1", "prediction")) -```` - -### Fitting parameters - -````@example index -using Distributions -prior = Dict(("x2", "volatility") => Normal(1, 0.5)) - -model = fit_model(agent, prior, inputs, actions, n_iterations = 20) -```` - -### Plot chains - -````@example index -plot(model) -```` - -### Plot prior angainst posterior - -````@example index -plot_parameter_distribution(model, prior) -```` - -### Get posterior - -````@example index -get_posteriors(model) -```` - ---- - -*This page was generated using [Literate.jl](https://github.com/fredrikekre/Literate.jl).* - diff --git a/docs/src/tutorials/classic_JGET.jl b/docs/src/tutorials/classic_JGET.jl deleted file mode 100644 index 0aae4de..0000000 --- a/docs/src/tutorials/classic_JGET.jl +++ /dev/null @@ -1,72 +0,0 @@ -using ActionModels, HierarchicalGaussianFiltering -using CSV, DataFrames -using Plots, StatsPlots -using Distributions -path = "docs/src/tutorials/data" - -#Load data -data = CSV.read("$path/classic_cannonball_data.csv", DataFrame) - -#Create HGF -hgf = premade_hgf("JGET", verbose = false) -#Create agent -agent = premade_agent("hgf_gaussian_action", hgf) -#Set parameters -parameters = Dict( - "gaussian_action_precision" => 1, - ("x1", "volatility") => -8, - ("x2", "volatility") => -5, - ("x3", "volatility") => -5, - ("x4", "volatility") => -5, - ("x1", "x2", "volatility_coupling") => 1, - ("x3", "x4", "volatility_coupling") => 1, -) -set_parameters!(agent, parameters) - -inputs = data[(data.ID.==20).&(data.session.==1), :].outcome -#Simulate updates and actions -actions = give_inputs!(agent, inputs); -#Plot belief trajectories -plot_trajectory(agent, "u") 
-plot_trajectory!(agent, "x1") -plot_trajectory(agent, "x2") -plot_trajectory(agent, "x3") -plot_trajectory(agent, "x4") - -priors = Dict( - "gaussian_action_precision" => LogNormal(-1, 0.1), - ("x1", "volatility") => Normal(-8, 1), -) - -data_subset = data[(data.ID.∈[[20, 21]]).&(data.session.∈[[1, 2]]), :] - -using Distributed -addprocs(6, exeflags = "--project") -@everywhere @eval using HierarchicalGaussianFiltering - -results = fit_model( - agent, - priors, - data_subset, - independent_group_cols = [:ID, :session], - input_cols = [:outcome], - action_cols = [:response], - n_cores = 6, -) - -fitted_model = results[(20, 1)] - -plot_parameter_distribution(fitted_model, priors) - - - - -posterior = get_posteriors(fitted_model) - -set_parameters!(agent, posterior) - -reset!(agent) - -give_inputs!(agent, inputs) - -get_history(agent, ("x1", "value_prediction_error")) diff --git a/src/ActionModels_variations/core/plot_trajectory.jl b/src/ActionModels_variations/core/plot_trajectory.jl index 110f07c..d616da5 100644 --- a/src/ActionModels_variations/core/plot_trajectory.jl +++ b/src/ActionModels_variations/core/plot_trajectory.jl @@ -160,6 +160,9 @@ end #Get the node node = hgf.all_nodes[node_name] + #Get the timesteps + timesteps = hgf.timesteps + #If the entire distribution is to be plotted if state_name in ["posterior", "prediction"] && !(node isa CategoricalStateNode) @@ -187,7 +190,7 @@ end end #Plot the history of means - history_mean + (timesteps, history_mean) end #If single state is specified @@ -227,7 +230,7 @@ end title --> "State trajectory" #Plot the history - state_history + (timesteps, state_history) end end end diff --git a/src/ActionModels_variations/utils/get_parameters.jl b/src/ActionModels_variations/utils/get_parameters.jl index 6c6008a..c1a890c 100644 --- a/src/ActionModels_variations/utils/get_parameters.jl +++ b/src/ActionModels_variations/utils/get_parameters.jl @@ -43,13 +43,12 @@ function ActionModels.get_parameters(hgf::HGF, 
target_param::Tuple{String,String return param end -##For coupling strengths +##For coupling strengths and coupling transforms function ActionModels.get_parameters(hgf::HGF, target_param::Tuple{String,String,String}) #Unpack node name, parent name and param name (node_name, parent_name, param_name) = target_param - #If the node does not exist if !(node_name in keys(hgf.all_nodes)) #Throw an error @@ -59,46 +58,65 @@ function ActionModels.get_parameters(hgf::HGF, target_param::Tuple{String,String #Get out the node node = hgf.all_nodes[node_name] + #If the parameter is a coupling strength + if param_name == "coupling_strength" - #If the parameter does not exist in the node - if !(Symbol(param_name) in fieldnames(typeof(node.parameters))) - #Throw an error - throw( - ArgumentError( - "The node $node_name does not have the parameter $param_name in its parameters", - ), - ) - end + #Get out the dictionary of coupling strengths + coupling_strengths = getproperty(node.parameters, :coupling_strengths) - #Get out the dictionary of coupling strengths - coupling_strengths = getproperty(node.parameters, Symbol(param_name)) + #If the specified parent is not in the dictionary + if !(parent_name in keys(coupling_strengths)) + #Throw an error + throw( + ArgumentError( + "The node $node_name does not have a coupling strength parameter to a parent called $parent_name", + ), + ) + end - #If the specified parent is not in the dictionary - if !(parent_name in keys(coupling_strengths)) - #Throw an error - throw( - ArgumentError( - "The node $node_name does not have a $param_name to a parent called $parent_name", - ), - ) - end + #Get the coupling strength for that given parent + param = coupling_strengths[parent_name] - #Get the coupling strength for that given parent - param = coupling_strengths[parent_name] + else + + #Get out the coupling transforms + coupling_transforms = getproperty(node.parameters, :coupling_transforms) + + #If the specified parent is not in the dictionary + if 
!(parent_name in keys(coupling_transforms)) + #Throw an error + throw( + ArgumentError( + "The node $node_name does not have a coupling transformation to a parent called $parent_name", + ), + ) + end + + #If the specified parameter does not exist for the transform + if !(param_name in keys(coupling_transforms.parameters)) + throw( + ArgumentError( + "There is no parameter called $param_name for the transformation function between $node_name and its parent $parent_name", + ), + ) + end + + #Extract the parameter + param = coupling_transforms.parameters[param_name] + end return param end -### For getting all parameters of a specific node ### - +### For getting a single-string parameter (a parameter group), or all parameters of a node ### function ActionModels.get_parameters(hgf::HGF, target_parameter::String) - #If the target parameter is a shared parameter - if target_parameter in keys(hgf.shared_parameters) - #Acess the parameter value in shared_parameters - return hgf.shared_parameters[target_parameter].value + #If the target parameter is a parameter group + if target_parameter in keys(hgf.parameter_groups) + #Access the parameter value in parameter_groups + return hgf.parameter_groups[target_parameter].value #If the target parameter is a node elseif target_parameter in keys(hgf.all_nodes) #Take out the node @@ -109,7 +127,7 @@ function ActionModels.get_parameters(hgf::HGF, target_parameter::String) #If the target parameter is neither a node nor in the shared parameters throw an error throw( ArgumentError( - "The node or parameter $target_parameter does not exist in the HGF or in shared parameters", + "The node or parameter $target_parameter does not exist in the HGF's nodes or parameter groups", ), ) end @@ -170,13 +188,13 @@ function ActionModels.get_parameters(hgf::HGF) end #If there are shared parameters - if length(hgf.shared_parameters) > 0 + if length(hgf.parameter_groups) > 0 #Go through each shared parameter - for (shared_parameter_key, 
shared_parameter_value) in hgf.shared_parameters - #Remove derived parameters from the list - filter!(x -> x[1] ∉ shared_parameter_value.derived_parameters, parameters) - #Set the shared parameter value - parameters[shared_parameter_key] = shared_parameter_value.value + for (parameter_group, grouped_parameters) in hgf.parameter_groups + #Remove grouped parameters from the list + filter!(x -> x[1] ∉ grouped_parameters.grouped_parameters, parameters) + #Add the parameter group parameter instead + parameters[parameter_group] = grouped_parameters.value end end @@ -193,19 +211,37 @@ function ActionModels.get_parameters(node::AbstractNode) for param_key in fieldnames(typeof(node.parameters)) #If the parameter is a coupling strength - if param_key in (:value_coupling, :volatility_coupling) + if param_key == :coupling_strengths #Get out the dict with coupling strengths - coupling_strengths = getproperty(node.parameters, param_key) + coupling_strengths = node.parameters.coupling_strengths #Go through each parent - for parent_name in keys(coupling_strengths) + for (parent_name, coupling_strength) in coupling_strengths #Add the coupling strength to the ouput dict - parameters[(node.name, parent_name, string(param_key))] = - coupling_strengths[parent_name] + parameters[(node.name, parent_name, "coupling_strength")] = + coupling_strength end + + #If the parameter is a coupling transform + elseif param_key == :coupling_transforms + + #Go through each parent and corresponding transform + for (parent_name, coupling_transform) in node.parameters.coupling_transforms + + #Go through each parameter for the transform + for (coupling_parameter, parameter_value) in coupling_transform.parameters + + #Add the coupling strength to the ouput dict + parameters[(node.name, parent_name, coupling_parameter)] = + parameter_value + + end + end + + #For other nodes else #And add their values to the dictionary parameters[(node.name, String(param_key))] = diff --git 
a/src/ActionModels_variations/utils/get_states.jl b/src/ActionModels_variations/utils/get_states.jl index e37da53..02db849 100644 --- a/src/ActionModels_variations/utils/get_states.jl +++ b/src/ActionModels_variations/utils/get_states.jl @@ -43,9 +43,8 @@ function ActionModels.get_states(node::AbstractNode, state_name::String) if state_name in [ "prediction", "prediction_mean", - "predicted_volatility", "prediction_precision", - "volatility_weighted_prediction_precision", + "effective_prediction_precision", ] #Get the new prediction prediction = get_prediction(node) diff --git a/src/ActionModels_variations/utils/give_inputs.jl b/src/ActionModels_variations/utils/give_inputs.jl index de013d6..4720e54 100644 --- a/src/ActionModels_variations/utils/give_inputs.jl +++ b/src/ActionModels_variations/utils/give_inputs.jl @@ -5,26 +5,49 @@ Give inputs to an agent. Input can be a single value, a vector of values, or an """ function ActionModels.give_inputs!() end -function ActionModels.give_inputs!(hgf::HGF, inputs::Real) + +### Giving a single input ### +function ActionModels.give_inputs!(hgf::HGF, inputs::Real; stepsizes::Real = 1) #Input the value to the hgf - update_hgf!(hgf, inputs) + update_hgf!(hgf, inputs, stepsize = stepsizes) return nothing end -function ActionModels.give_inputs!(hgf::HGF, inputs::Vector) +### Giving a vector of inputs ### +function ActionModels.give_inputs!( + hgf::HGF, + inputs::Vector; + stepsizes::Union{Real,Vector} = 1, +) + + #Create vector of stepsizes + if stepsizes isa Real + stepsizes = fill(stepsizes, length(inputs)) + end + + #Check that inputs and stepsizes are the same length + if length(inputs) != length(stepsizes) + @error "The number of inputs and stepsizes must be the same." 
+ end #Each entry in the vector is an input - for input in inputs + for (input, stepsize) in zip(inputs, stepsizes) #Input it to the hgf - update_hgf!(hgf, input) + update_hgf!(hgf, input; stepsize = stepsize) end return nothing end -function ActionModels.give_inputs!(hgf::HGF, inputs::Array) + +### Giving a matrix of inputs (multiple per timestep) ### +function ActionModels.give_inputs!( + hgf::HGF, + inputs::Array; + stepsizes::Union{Real,Vector} = 1, +) #If number of column in input is diffferent from amount of input nodes if size(inputs, 2) != length(hgf.input_nodes) @@ -36,10 +59,20 @@ function ActionModels.give_inputs!(hgf::HGF, inputs::Array) ) end + #Create vector of stepsizes + if stepsizes isa Real + stepsizes = fill(stepsizes, size(inputs, 1)) + end + + #Check that inputs and stepsizes are the same length + if size(inputs, 1) != length(stepsizes) + @error "The number of inputs and stepsizes must be the same." + end + #Take each row in the array - for input in eachrow(inputs) + for (input, stepsize) in zip(eachrow(inputs), stepsizes) #Input it to the hgf - update_hgf!(hgf, Vector(input)) + update_hgf!(hgf, Vector(input), stepsize = stepsize) end return nothing diff --git a/src/ActionModels_variations/utils/reset.jl b/src/ActionModels_variations/utils/reset.jl index 88c826c..a614628 100644 --- a/src/ActionModels_variations/utils/reset.jl +++ b/src/ActionModels_variations/utils/reset.jl @@ -5,64 +5,14 @@ Reset an HGF to its initial state. 
""" function ActionModels.reset!(hgf::HGF) + #Reset the timesteps for the HGF + hgf.timesteps = [0] + #Go through each node for node in hgf.ordered_nodes.all_nodes - #For categorical state nodes - if node isa CategoricalStateNode - #Set states to vectors of missing - node.states.posterior .= missing - node.states.value_prediction_error .= missing - #Empty prediction state - empty!(node.states.prediction) - - #For binary input nodes - elseif node isa BinaryInputNode - #Set states to missing - node.states.value_prediction_error .= missing - node.states.input_value = missing - - #For continuous state nodes - elseif node isa ContinuousStateNode - #Set posterior to initial belief - node.states.posterior_mean = node.parameters.initial_mean - node.states.posterior_precision = node.parameters.initial_precision - #For other states - for state_name in [ - :value_prediction_error, - :volatility_prediction_error, - :prediction_mean, - :predicted_volatility, - :prediction_precision, - :volatility_weighted_prediction_precision, - ] - #Set the state to missing - setfield!(node.states, state_name, missing) - end - - #For continuous input nodes - elseif node isa ContinuousInputNode - - #For all states except auxiliary prediction precision - for state_name in [ - :input_value, - :value_prediction_error, - :volatility_prediction_error, - :predicted_volatility, - :prediction_precision, - ] - #Set the state to missing - setfield!(node.states, state_name, missing) - end - - #For other nodes - else - #For each state - for state_name in fieldnames(typeof(node.states)) - #Set the state to missing - setfield!(node.states, state_name, missing) - end - end + #Reset its state + reset_state!(node) #For each state in the history for state_name in fieldnames(typeof(node.history)) @@ -70,19 +20,74 @@ function ActionModels.reset!(hgf::HGF) #Empty the history empty!(getfield(node.history, state_name)) - #For states other than prediction states - if !( - state_name in [ - :prediction, - 
:prediction_mean, - :predicted_volatility, - :prediction_precision, - :volatility_weighted_prediction_precision, - ] - ) - #Add the new current state as the first state in the history - push!(getfield(node.history, state_name), getfield(node.states, state_name)) - end + #Add the new current state as the first state in the history + push!(getfield(node.history, state_name), getfield(node.states, state_name)) end end end + + +function reset_state!(node::ContinuousStateNode) + + node.states.posterior_mean = node.parameters.initial_mean + node.states.posterior_precision = node.parameters.initial_precision + + node.states.value_prediction_error = missing + node.states.precision_prediction_error = missing + + node.states.prediction_mean = missing + node.states.prediction_precision = missing + node.states.effective_prediction_precision = missing + + return nothing +end + +function reset_state!(node::ContinuousInputNode) + + node.states.input_value = missing + + node.states.value_prediction_error = missing + node.states.precision_prediction_error = missing + + node.states.prediction_mean = missing + node.states.prediction_precision = missing + + return nothing +end + +function reset_state!(node::BinaryStateNode) + + node.states.posterior_mean = missing + node.states.posterior_precision = missing + + node.states.value_prediction_error = missing + + node.states.prediction_mean = missing + node.states.prediction_precision = missing + + return nothing +end + +function reset_state!(node::BinaryInputNode) + + node.states.input_value = missing + + return nothing +end + +function reset_state!(node::CategoricalStateNode) + + node.states.posterior .= missing + node.states.value_prediction_error .= missing + node.states.prediction .= missing + node.states.parent_predictions .= missing + + return nothing +end + +function reset_state!(node::CategoricalInputNode) + + node.states.input_value = missing + + return nothing +end diff --git 
a/src/ActionModels_variations/utils/set_parameters.jl b/src/ActionModels_variations/utils/set_parameters.jl index 78b0a95..e83f2c6 100644 --- a/src/ActionModels_variations/utils/set_parameters.jl +++ b/src/ActionModels_variations/utils/set_parameters.jl @@ -11,7 +11,7 @@ function ActionModels.set_parameters!() end ### For setting a single parameter ### -##For parameters other than coupling strengths +##For parameters other than coupling strengths and transforms function ActionModels.set_parameters!( hgf::HGF, target_param::Tuple{String,String}, @@ -61,7 +61,6 @@ function ActionModels.set_parameters!( #Unpack node name, parent name and parameter name (node_name, parent_name, param_name) = target_param - #If the node does not exist if !(node_name in keys(hgf.all_nodes)) #Throw an error @@ -71,59 +70,78 @@ function ActionModels.set_parameters!( #Get the child node node = hgf.all_nodes[node_name] + #If it is a coupling strength + if param_name == "coupling_strength" + + #Get coupling_strengths + coupling_strengths = node.parameters.coupling_strengths + + #If the specified parent is not in the dictionary + if !(parent_name in keys(coupling_strengths)) + #Throw an error + throw( + ArgumentError( + "The node $node_name does not have a coupling strength parameter to a parent called $parent_name", + ), + ) + end + + #Set the coupling strength to the specified parent to the specified value + coupling_strengths[parent_name] = param_value + + else + + #Get out the coupling transforms + coupling_transforms = getproperty(node.parameters, :coupling_transforms) + + #If the specified parent is not in the dictionary + if !(parent_name in keys(coupling_transforms)) + #Throw an error + throw( + ArgumentError( + "The node $node_name does not have a coupling transformation to a parent called $parent_name", + ), + ) + end + + #If the specified parameter does not exist for the transform + if !(param_name in keys(coupling_transforms.parameters)) + throw( + ArgumentError( + "There is no 
parameter called $param_name for the transformation function between $node_name and its parent $parent_name", + ), + ) + end - #If the param does not exist in the node - if !(Symbol(param_name) in fieldnames(typeof(node.parameters))) - #Throw an error - throw( - ArgumentError( - "The node $node_name does not have the parameter $param_name in its parameters", - ), - ) - end - - #Get coupling_strengths - coupling_strengths = getfield(node.parameters, Symbol(param_name)) - - #If the specified parent is not in the dictionary - if !(parent_name in keys(coupling_strengths)) - #Throw an error - throw( - ArgumentError( - "The node $node_name does not have a $param_name to a parent called $parent_name", - ), - ) + #Set the parameter + coupling_transforms.parameters[param_name] = param_value end - - #Set the coupling strength to the specified parent to the specified value - coupling_strengths[parent_name] = param_value - end ### For setting a single parameter ### -function ActionModels.set_parameters!(hgf::HGF, target_param::String, param_value::Any) +function ActionModels.set_parameters!(hgf::HGF, target_param::String, param_value::Real) #If the target parameter is not in the shared parameters - if !(target_param in keys(hgf.shared_parameters)) + if !(target_param in keys(hgf.parameter_groups)) throw( ArgumentError( - "the parameter $target_param is passed to the HGF but is not in the HGF's shared parameters. Check that it is specified correctly", + "the parameter $target_param is a string, but is not in the HGF's grouped parameters. 
Check that it is specified correctly", ), ) end #Get out the shared parameter struct - shared_parameter = hgf.shared_parameters[target_param] + parameter_group = hgf.parameter_groups[target_param] - #Set the value in the shared parameter - setfield!(shared_parameter, :value, param_value) + #Set the value in the parameter group + setfield!(parameter_group, :value, param_value) - #Get out the derived parameters - derived_parameters = shared_parameter.derived_parameters + #Get out the grouped parameters + grouped_parameters = parameter_group.grouped_parameters - #For each derived parameter - for derived_parameter_key in derived_parameters + #For each grouped parameter + for grouped_parameter_key in grouped_parameters #Set the parameter - set_parameters!(hgf, derived_parameter_key, param_value) + set_parameters!(hgf, grouped_parameter_key, param_value) end end diff --git a/src/ActionModels_variations/utils/set_save_history.jl b/src/ActionModels_variations/utils/set_save_history.jl new file mode 100644 index 0000000..9862b31 --- /dev/null +++ b/src/ActionModels_variations/utils/set_save_history.jl @@ -0,0 +1,4 @@ +#To set save history +function ActionModels.set_save_history!(hgf::HGF, save_history::Bool) + hgf.save_history = save_history +end diff --git a/src/HierarchicalGaussianFiltering.jl b/src/HierarchicalGaussianFiltering.jl index 21c6467..fce8b48 100644 --- a/src/HierarchicalGaussianFiltering.jl +++ b/src/HierarchicalGaussianFiltering.jl @@ -5,28 +5,35 @@ using ActionModels, Distributions, RecipesBase #Export functions export init_node, init_hgf, premade_hgf, check_hgf, check_node, update_hgf! -export get_prediction, get_surprise, hgf_multiple_actions +export get_prediction, get_surprise export premade_agent, - init_agent, - multiple_actions, - plot_predictive_simulation, - plot_trajectory, - plot_trajectory! -export get_history, get_parameters, get_states, set_parameters!, reset!, give_inputs! 
+ init_agent, plot_predictive_simulation, plot_trajectory, plot_trajectory! +export get_history, + get_parameters, get_states, set_parameters!, reset!, give_inputs!, set_save_history! +export ParameterGroup export EnhancedUpdate, ClassicUpdate +export NodeDefaults +export ContinuousState, + ContinuousInput, BinaryState, BinaryInput, CategoricalState, CategoricalInput +export DriftCoupling, + ObservationCoupling, + CategoryCoupling, + ProbabilityCoupling, + VolatilityCoupling, + NoiseCoupling, + LinearTransform, + NonlinearTransform #Add premade agents to shared dict at initialization function __init__() - ActionModels.premade_agents["hgf_gaussian_action"] = premade_hgf_gaussian - ActionModels.premade_agents["hgf_binary_softmax_action"] = premade_hgf_binary_softmax - ActionModels.premade_agents["hgf_unit_square_sigmoid_action"] = - premade_hgf_unit_square_sigmoid - ActionModels.premade_agents["hgf_predict_category_action"] = - premade_hgf_predict_category + ActionModels.premade_agents["hgf_gaussian"] = premade_hgf_gaussian + ActionModels.premade_agents["hgf_binary_softmax"] = premade_hgf_binary_softmax + ActionModels.premade_agents["hgf_unit_square_sigmoid"] = premade_hgf_unit_square_sigmoid + ActionModels.premade_agents["hgf_predict_category"] = premade_hgf_predict_category end #Types for HGFs -include("structs.jl") +include("create_hgf/hgf_structs.jl") #Overloading ActionModels functions include("ActionModels_variations/core/create_premade_agent.jl") @@ -38,23 +45,36 @@ include("ActionModels_variations/utils/get_states.jl") include("ActionModels_variations/utils/give_inputs.jl") include("ActionModels_variations/utils/reset.jl") include("ActionModels_variations/utils/set_parameters.jl") +include("ActionModels_variations/utils/set_save_history.jl") + +#Functions for updating the HGF +include("update_hgf/update_hgf.jl") +include("update_hgf/nonlinear_transforms.jl") +include("update_hgf/node_updates/continuous_input_node.jl") 
+include("update_hgf/node_updates/continuous_state_node.jl") +include("update_hgf/node_updates/binary_input_node.jl") +include("update_hgf/node_updates/binary_state_node.jl") +include("update_hgf/node_updates/categorical_input_node.jl") +include("update_hgf/node_updates/categorical_state_node.jl") #Functions for creating HGFs include("create_hgf/check_hgf.jl") include("create_hgf/init_hgf.jl") +include("create_hgf/init_node_edge.jl") include("create_hgf/create_premade_hgf.jl") -#Plotting functions - -#Functions for updating HGFs based on inputs -include("update_hgf/update_equations.jl") -include("update_hgf/update_hgf.jl") -include("update_hgf/update_node.jl") - #Functions for premade agents -include("premade_models/premade_action_models.jl") -include("premade_models/premade_agents.jl") -include("premade_models/premade_hgfs.jl") +include("premade_models/premade_agents/premade_gaussian.jl") +include("premade_models/premade_agents/premade_predict_category.jl") +include("premade_models/premade_agents/premade_sigmoid.jl") +include("premade_models/premade_agents/premade_softmax.jl") + +include("premade_models/premade_hgfs/premade_binary_2level.jl") +include("premade_models/premade_hgfs/premade_binary_3level.jl") +include("premade_models/premade_hgfs/premade_categorical_3level.jl") +include("premade_models/premade_hgfs/premade_categorical_transitions_3level.jl") +include("premade_models/premade_hgfs/premade_continuous_2level.jl") +include("premade_models/premade_hgfs/premade_JGET.jl") #Utility functions for HGFs include("utils/get_prediction.jl") diff --git a/src/create_hgf/check_hgf.jl b/src/create_hgf/check_hgf.jl index 4469dda..edc4767 100644 --- a/src/create_hgf/check_hgf.jl +++ b/src/create_hgf/check_hgf.jl @@ -5,7 +5,7 @@ Check whether an HGF has specified correctly. A single node can also be passed. 
""" function check_hgf(hgf::HGF) - ## Check for duplicate names + ## Check for duplicate names ## #Get node names node_names = getfield.(hgf.ordered_nodes.all_nodes, :name) #If there are any duplicate names @@ -18,35 +18,36 @@ function check_hgf(hgf::HGF) ) end - if length(hgf.shared_parameters) > 0 + #If there are shared parameters + if length(hgf.parameter_groups) > 0 - ## Check for the same derived parameter in multiple shared parameters + ## Check for the same grouped parameter in multiple shared parameters ## - #Get all derived parameters - derived_parameters = [ - parameter for list_of_derived_parameters in [ - hgf.shared_parameters[parameter_key].derived_parameters for - parameter_key in keys(hgf.shared_parameters) - ] for parameter in list_of_derived_parameters + #Get all grouped parameters + grouped_parameters = [ + parameter for list_of_grouped_parameters in [ + hgf.parameter_groups[parameter_key].grouped_parameters for + parameter_key in keys(hgf.parameter_groups) + ] for parameter in list_of_grouped_parameters ] - #check for duplicate names - if length(derived_parameters) > length(unique(derived_parameters)) + #Check for duplicate names + if length(grouped_parameters) > length(unique(grouped_parameters)) #Throw an error throw( ArgumentError( - "At least one parameter is set by multiple shared parameters. This is not supported.", + "At least one parameter is set by multiple parameter groups. 
This is not supported.", ), ) end - ## Check if the shared parameter is part of own derived parameters + ## Check if the shared parameter is part of own grouped parameters ## #Go through each specified shared parameter - for (shared_parameter_key, dict_value) in hgf.shared_parameters - #check if the name of the shared parameter is part of its own derived parameters - if shared_parameter_key in dict_value.derived_parameters + for (parameter_group_key, grouped_parameters) in hgf.parameter_groups + #check if the name of the shared parameter is part of its own grouped parameters + if parameter_group_key in grouped_parameters.grouped_parameters throw( ArgumentError( - "The shared parameter is part of the list of derived parameters", + "The parameter group name $parameter_group_key is part of the list of parameters in the group", ), ) end @@ -54,7 +55,7 @@ function check_hgf(hgf::HGF) end - ## Check each node + ### Check each node ### for node in hgf.ordered_nodes.all_nodes check_hgf(node) end @@ -66,52 +67,17 @@ function check_hgf(node::ContinuousStateNode) #Extract node name for error messages node_name = node.name - #Disallow binary input node value children - if any(isa.(node.value_children, BinaryInputNode)) + #If there are observation children, disallow noise and volatility children + if length(node.edges.observation_children) > 0 && + (length(node.edges.volatility_children) > 0 || length(node.edges.noise_children) > 0) throw( ArgumentError( - "The continuous state node $node_name has a value child which is a binary input node. This is not supported.", + "The state node $node_name has observation children. It is not supported for it to also have volatility or noise children, because it disrupts the update order.", ), ) end - #Disallow binary input node volatility children - if any(isa.(node.volatility_children, BinaryInputNode)) - throw( - ArgumentError( - "The continuous state node $node_name has a volatility child which is a binary input node. 
This is not supported.", - ), - ) - end - - #Disallow having volatility children if a value child is a continuous inputnode - if any(isa.(node.value_children, ContinuousInputNode)) - if length(node.volatility_children) > 0 - throw( - ArgumentError( - "The state node $node_name has a continuous input node as a value child. It also has volatility children, which disrupts the update order. This is not supported.", - ), - ) - end - end - - #Disallow having the same parent as value parent and volatility parent - if any(node.value_parents .∈ Ref(node.volatility_parents)) - throw( - ArgumentError( - "The state node $node_name has the same parent as value parent and volatility parent. This is not supported.", - ), - ) - end - - #Disallow having the same child as value child and volatility child - if any(node.value_children .∈ Ref(node.volatility_children)) - throw( - ArgumentError( - "The state node $node_name has the same child as value child and volatility child. This is not supported.", - ), - ) - end + #Disallow having the same node as multiple types of connections return nothing end @@ -121,38 +87,20 @@ function check_hgf(node::BinaryStateNode) #Extract node name for error messages node_name = node.name - #Require exactly one value parent - if length(node.value_parents) != 1 + #Require exactly one probability parent + if length(node.edges.probability_parents) != 1 throw( ArgumentError( - "The binary state node $node_name does not have exactly one value parent. This is not supported.", + "The binary state node $node_name does not have exactly one probability parent. This is not supported.", ), ) end - #Require exactly one value child - if length(node.value_children) != 1 + #Require exactly one observation child or category child + if length(node.edges.observation_children) + length(node.edges.category_children) != 1 throw( ArgumentError( - "The binary state node $node_name does not have exactly one value child. 
This is not supported.", - ), - ) - end - - # #Allow only binary input node and categorical state node children - #if any(.!(typeof.(node.value_children) in [BinaryInputNode, CategoricalStateNode])) - # throw( - # ArgumentError( - # "The binary state node $node_name has a child which is neither a binary input node nor a categorical state node. This is not supported.", - # ), - # ) - #end - - #Allow only continuous state node node parents - if any(.!isa.(node.value_parents, ContinuousStateNode)) - throw( - ArgumentError( - "The binary state node $node_name has a parent which is not a continuous state node. This is not supported.", + "The binary state node $node_name does not have exactly one observation child. This is not supported.", ), ) end @@ -166,28 +114,10 @@ function check_hgf(node::CategoricalStateNode) node_name = node.name #Require exactly one value child - if length(node.value_children) != 1 - throw( - ArgumentError( - "The categorical state node $node_name does not have exactly one value child. This is not supported.", - ), - ) - end - - #Allow only categorical input node children - if any(.!isa.(node.value_children, CategoricalInputNode)) + if length(node.edges.observation_children) != 1 throw( ArgumentError( - "The categorical state node $node_name has a child which is not a categorical input node. This is not supported.", - ), - ) - end - - #Allow only continuous state node node parents - if any(.!isa.(node.value_parents, BinaryStateNode)) - throw( - ArgumentError( - "The categorical state node $node_name has a parent which is not a binary state node. This is not supported.", + "The categorical state node $node_name does not have exactly one observation child. 
This is not supported.", ), ) end @@ -201,41 +131,11 @@ function check_hgf(node::ContinuousInputNode) #Extract node name for error messages node_name = node.name - #Allow only continuous state node node parents - if any(.!isa.(node.value_parents, ContinuousStateNode)) - throw( - ArgumentError( - "The continuous input node $node_name has a parent which is not a continuous state node. This is not supported.", - ), - ) - end - - #Require at least one value parent - if length(node.value_parents) == 0 - throw( - ArgumentError( - "The input node $node_name does not have any value parents. This is not supported.", - ), - ) - end - - #Disallow multiple value parents if there are volatility parents - if length(node.volatility_parents) > 0 - if length(node.value_parents) > 1 - throw( - ArgumentError( - "The input node $node_name has multiple value parents and at least one volatility parent. This is not supported.", - ), - ) - end - end - - - #Disallow having the same parent as value parent and volatility parent - if any(node.value_parents .∈ Ref(node.volatility_parents)) + #Disallow multiple observation parents if there are noise parents + if length(node.edges.noise_parents) > 0 && length(node.edges.observation_parents) > 1 throw( ArgumentError( - "The state node $node_name has the same parent as value parent and volatility parent. This is not supported.", + "The input node $node_name has multiple value parents and at least one volatility parent. This is not supported.", ), ) end @@ -249,19 +149,10 @@ function check_hgf(node::BinaryInputNode) node_name = node.name #Require exactly one value parent - if length(node.value_parents) != 1 + if length(node.edges.observation_parents) != 1 throw( ArgumentError( - "The binary input node $node_name does not have exactly one value parent. 
This is not supported.", - ), - ) - end - - #Allow only binary state nodes as parents - if any(.!isa.(node.value_parents, BinaryStateNode)) - throw( - ArgumentError( - "The binary input node $node_name has a parent which is not a binary state node. This is not supported.", + "The binary input node $node_name does not have exactly one observation parent. This is not supported.", ), ) end @@ -275,19 +166,10 @@ function check_hgf(node::CategoricalInputNode) node_name = node.name #Require exactly one value parent - if length(node.value_parents) != 1 - throw( - ArgumentError( - "The categorical input node $node_name does not have exactly one value parent. This is not supported.", - ), - ) - end - - #Allow only categorical state nodes as parents - if any(.!isa.(node.value_parents, CategoricalStateNode)) + if length(node.edges.observation_parents) != 1 throw( ArgumentError( - "The categorical input node $node_name has a parent which is not a categorical state node. This is not supported.", + "The categorical input node $node_name does not have exactly one observation parent. 
This is not supported.", ), ) end diff --git a/src/create_hgf/hgf_structs.jl b/src/create_hgf/hgf_structs.jl new file mode 100644 index 0000000..a60f6ee --- /dev/null +++ b/src/create_hgf/hgf_structs.jl @@ -0,0 +1,458 @@ +##################################### +######## Abstract node types ######## +##################################### + +#Top-level node type +abstract type AbstractNode end + +#Input and state node subtypes +abstract type AbstractStateNode <: AbstractNode end +abstract type AbstractInputNode <: AbstractNode end + +#Variable type subtypes +abstract type AbstractContinuousStateNode <: AbstractStateNode end +abstract type AbstractContinuousInputNode <: AbstractInputNode end +abstract type AbstractBinaryStateNode <: AbstractStateNode end +abstract type AbstractBinaryInputNode <: AbstractInputNode end +abstract type AbstractCategoricalStateNode <: AbstractStateNode end +abstract type AbstractCategoricalInputNode <: AbstractInputNode end + +#Abstract type for node information +abstract type AbstractNodeInfo end +abstract type AbstractInputNodeInfo <: AbstractNodeInfo end +abstract type AbstractStateNodeInfo <: AbstractNodeInfo end + +################################## +######## HGF update types ######## +################################## + +#Supertype for HGF update types +abstract type HGFUpdateType end + +#Classic and enhance dupdate types +struct ClassicUpdate <: HGFUpdateType end +struct EnhancedUpdate <: HGFUpdateType end + +################################ +######## Coupling types ######## +################################ + +#Types for specifying nonlinear transformations +abstract type CouplingTransform end + +Base.@kwdef mutable struct LinearTransform <: CouplingTransform end + +Base.@kwdef mutable struct NonlinearTransform <: CouplingTransform + base_function::Function + first_derivation::Function + second_derivation::Function + parameters::Dict = Dict() +end + +#Supertypes for coupling types +abstract type CouplingType end +abstract type 
ValueCoupling <: CouplingType end +abstract type PrecisionCoupling <: CouplingType end + +#Concrete value coupling types +Base.@kwdef mutable struct DriftCoupling <: ValueCoupling + strength::Union{Nothing,Real} = nothing + transform::CouplingTransform = LinearTransform() +end +Base.@kwdef mutable struct ProbabilityCoupling <: ValueCoupling + strength::Union{Nothing,Real} = nothing +end +Base.@kwdef mutable struct CategoryCoupling <: ValueCoupling end +Base.@kwdef mutable struct ObservationCoupling <: ValueCoupling end + +#Concrete precision coupling types +Base.@kwdef mutable struct VolatilityCoupling <: PrecisionCoupling + strength::Union{Nothing,Real} = nothing +end +Base.@kwdef mutable struct NoiseCoupling <: PrecisionCoupling + strength::Union{Nothing,Real} = nothing +end + +############################ +######## HGF Struct ######## +############################ +""" +""" +Base.@kwdef mutable struct OrderedNodes + all_nodes::Vector{AbstractNode} = [] + input_nodes::Vector{AbstractInputNode} = [] + all_state_nodes::Vector{AbstractStateNode} = [] + early_update_state_nodes::Vector{AbstractStateNode} = [] + late_update_state_nodes::Vector{AbstractStateNode} = [] +end + +""" +""" +Base.@kwdef mutable struct HGF + all_nodes::Dict{String,AbstractNode} + input_nodes::Dict{String,AbstractInputNode} + state_nodes::Dict{String,AbstractStateNode} + ordered_nodes::OrderedNodes = OrderedNodes() + parameter_groups::Dict = Dict() + save_history::Bool = true + timesteps::Vector{Real} = [0] +end + +################################## +######## HGF Info Structs ######## +################################## +Base.@kwdef struct NodeDefaults + input_noise::Real = -2 + bias::Real = 0 + volatility::Real = -2 + drift::Real = 0 + autoconnection_strength::Real = 1 + initial_mean::Real = 0 + initial_precision::Real = 1 + coupling_strength::Real = 1 + update_type::HGFUpdateType = EnhancedUpdate() +end + +Base.@kwdef mutable struct ContinuousState <: AbstractStateNodeInfo + name::String + 
volatility::Union{Real,Nothing} = nothing + drift::Union{Real,Nothing} = nothing + autoconnection_strength::Union{Real,Nothing} = nothing + initial_mean::Union{Real,Nothing} = nothing + initial_precision::Union{Real,Nothing} = nothing +end + +Base.@kwdef mutable struct ContinuousInput <: AbstractInputNodeInfo + name::String + input_noise::Union{Real,Nothing} = nothing + bias::Union{Real,Nothing} = nothing +end + +Base.@kwdef mutable struct BinaryState <: AbstractStateNodeInfo + name::String +end + +Base.@kwdef mutable struct BinaryInput <: AbstractInputNodeInfo + name::String +end + +Base.@kwdef mutable struct CategoricalState <: AbstractStateNodeInfo + name::String +end + +Base.@kwdef mutable struct CategoricalInput <: AbstractInputNodeInfo + name::String +end + + + +####################################### +######## Continuous State Node ######## +####################################### +Base.@kwdef mutable struct ContinuousStateNodeEdges + #Possible parent types + drift_parents::Vector{<:AbstractContinuousStateNode} = Vector{ContinuousStateNode}() + volatility_parents::Vector{<:AbstractContinuousStateNode} = + Vector{ContinuousStateNode}() + + #Possible children types + drift_children::Vector{<:AbstractContinuousStateNode} = Vector{ContinuousStateNode}() + volatility_children::Vector{<:AbstractContinuousStateNode} = + Vector{ContinuousStateNode}() + probability_children::Vector{<:AbstractBinaryStateNode} = Vector{BinaryStateNode}() + observation_children::Vector{<:AbstractContinuousInputNode} = + Vector{ContinuousInputNode}() + noise_children::Vector{<:AbstractContinuousInputNode} = Vector{ContinuousInputNode}() +end + +""" +Configuration of continuous state nodes' parameters +""" +Base.@kwdef mutable struct ContinuousStateNodeParameters + volatility::Real = 0 + drift::Real = 0 + autoconnection_strength::Real = 1 + initial_mean::Real = 0 + initial_precision::Real = 0 + coupling_strengths::Dict{String,Real} = Dict{String,Real}() + 
coupling_transforms::Dict{String,CouplingTransform} = Dict{String,Real}() +end + +""" +Configurations of the continuous state node states +""" +Base.@kwdef mutable struct ContinuousStateNodeState + posterior_mean::Union{Real} = 0 + posterior_precision::Union{Real} = 1 + value_prediction_error::Union{Real,Missing} = missing + precision_prediction_error::Union{Real,Missing} = missing + prediction_mean::Union{Real,Missing} = missing + prediction_precision::Union{Real,Missing} = missing + effective_prediction_precision::Union{Real,Missing} = missing +end + +""" +Configuration of continuous state node history +""" +Base.@kwdef mutable struct ContinuousStateNodeHistory + posterior_mean::Vector{Real} = [] + posterior_precision::Vector{Real} = [] + value_prediction_error::Vector{Union{Real,Missing}} = [] + precision_prediction_error::Vector{Union{Real,Missing}} = [] + prediction_mean::Vector{Union{Real,Missing}} = [] + prediction_precision::Vector{Union{Real,Missing}} = [] + effective_prediction_precision::Vector{Union{Real,Missing}} = [] +end + +""" +""" +Base.@kwdef mutable struct ContinuousStateNode <: AbstractContinuousStateNode + name::String + edges::ContinuousStateNodeEdges = ContinuousStateNodeEdges() + parameters::ContinuousStateNodeParameters = ContinuousStateNodeParameters() + states::ContinuousStateNodeState = ContinuousStateNodeState() + history::ContinuousStateNodeHistory = ContinuousStateNodeHistory() + update_type::HGFUpdateType = ClassicUpdate() +end + + +####################################### +######## Continuous Input Node ######## +####################################### +Base.@kwdef mutable struct ContinuousInputNodeEdges + #Possible parents + observation_parents::Vector{<:AbstractContinuousStateNode} = + Vector{ContinuousStateNode}() + noise_parents::Vector{<:AbstractContinuousStateNode} = Vector{ContinuousStateNode}() + + +end + +""" +Configuration of continuous input node parameters +""" +Base.@kwdef mutable struct ContinuousInputNodeParameters + 
input_noise::Real = 0 + bias::Real = 0 + coupling_strengths::Dict{String,Real} = Dict{String,Real}() + coupling_transforms::Dict{String,CouplingTransform} = Dict{String,Real}() +end + +""" +Configuration of continuous input node states +""" +Base.@kwdef mutable struct ContinuousInputNodeState + input_value::Union{Real,Missing} = missing + value_prediction_error::Union{Real,Missing} = missing + precision_prediction_error::Union{Real,Missing} = missing + prediction_mean::Union{Real,Missing} = missing + prediction_precision::Union{Real,Missing} = missing +end + +""" +Configuration of continuous input node history +""" +Base.@kwdef mutable struct ContinuousInputNodeHistory + input_value::Vector{Union{Real,Missing}} = [] + value_prediction_error::Vector{Union{Real,Missing}} = [] + precision_prediction_error::Vector{Union{Real,Missing}} = [] + prediction_mean::Vector{Union{Real,Missing}} = [] + prediction_precision::Vector{Union{Real,Missing}} = [] +end + +""" +""" +Base.@kwdef mutable struct ContinuousInputNode <: AbstractContinuousInputNode + name::String + edges::ContinuousInputNodeEdges = ContinuousInputNodeEdges() + parameters::ContinuousInputNodeParameters = ContinuousInputNodeParameters() + states::ContinuousInputNodeState = ContinuousInputNodeState() + history::ContinuousInputNodeHistory = ContinuousInputNodeHistory() +end + +################################### +######## Binary State Node ######## +################################### +Base.@kwdef mutable struct BinaryStateNodeEdges + #Possible parent types + probability_parents::Vector{<:AbstractContinuousStateNode} = + Vector{ContinuousStateNode}() + + #Possible children types + category_children::Vector{<:AbstractCategoricalStateNode} = + Vector{CategoricalStateNode}() + observation_children::Vector{<:AbstractBinaryInputNode} = Vector{BinaryInputNode}() +end + +""" + Configure parameters of binary state node +""" +Base.@kwdef mutable struct BinaryStateNodeParameters + coupling_strengths::Dict{String,Real} = 
Dict{String,Real}() +end + +""" +Overview of the states of the binary state node +""" +Base.@kwdef mutable struct BinaryStateNodeState + posterior_mean::Union{Real,Missing} = missing + posterior_precision::Union{Real,Missing} = missing + value_prediction_error::Union{Real,Missing} = missing + prediction_mean::Union{Real,Missing} = missing + prediction_precision::Union{Real,Missing} = missing +end + +""" +Overview of the history of the binary state node +""" +Base.@kwdef mutable struct BinaryStateNodeHistory + posterior_mean::Vector{Union{Real,Missing}} = [] + posterior_precision::Vector{Union{Real,Missing}} = [] + value_prediction_error::Vector{Union{Real,Missing}} = [] + prediction_mean::Vector{Union{Real,Missing}} = [] + prediction_precision::Vector{Union{Real,Missing}} = [] +end + +""" +Overview of edge posibilities +""" +Base.@kwdef mutable struct BinaryStateNode <: AbstractBinaryStateNode + name::String + edges::BinaryStateNodeEdges = BinaryStateNodeEdges() + parameters::BinaryStateNodeParameters = BinaryStateNodeParameters() + states::BinaryStateNodeState = BinaryStateNodeState() + history::BinaryStateNodeHistory = BinaryStateNodeHistory() + update_type::HGFUpdateType = ClassicUpdate() +end + + + +################################### +######## Binary Input Node ######## +################################### +Base.@kwdef mutable struct BinaryInputNodeEdges + observation_parents::Vector{<:AbstractBinaryStateNode} = Vector{BinaryStateNode}() +end + +""" +Configuration of parameters in binary input node. 
Default category mean set to [0,1] +""" +Base.@kwdef mutable struct BinaryInputNodeParameters + category_means::Vector{Union{Real}} = [0, 1] + input_precision::Real = Inf + coupling_strengths::Dict{String,Real} = Dict{String,Real}() +end + +""" +Configuration of states of binary input node +""" +Base.@kwdef mutable struct BinaryInputNodeState + input_value::Union{Real,Missing} = missing +end + +""" +Configuration of history of binary input node +""" +Base.@kwdef mutable struct BinaryInputNodeHistory + input_value::Vector{Union{Real,Missing}} = [missing] +end + +""" +""" +Base.@kwdef mutable struct BinaryInputNode <: AbstractBinaryInputNode + name::String + edges::BinaryInputNodeEdges = BinaryInputNodeEdges() + parameters::BinaryInputNodeParameters = BinaryInputNodeParameters() + states::BinaryInputNodeState = BinaryInputNodeState() + history::BinaryInputNodeHistory = BinaryInputNodeHistory() +end + + + +######################################## +######## Categorical State Node ######## +######################################## +Base.@kwdef mutable struct CategoricalStateNodeEdges + #Possible parents + category_parents::Vector{<:AbstractBinaryStateNode} = Vector{BinaryStateNode}() + #The order of the category parents + category_parent_order::Vector{String} = [] + + #Possible children + observation_children::Vector{<:AbstractCategoricalInputNode} = + Vector{CategoricalInputNode}() +end + +Base.@kwdef mutable struct CategoricalStateNodeParameters + coupling_strengths::Dict{String,Real} = Dict{String,Real}() +end + +""" +Configuration of states in categorical state node +""" +Base.@kwdef mutable struct CategoricalStateNodeState + posterior::Vector{Union{Real,Missing}} = [] + value_prediction_error::Vector{Union{Real,Missing}} = [] + prediction::Vector{Union{Real,Missing}} = [] + parent_predictions::Vector{Union{Real,Missing}} = [] +end + +""" +Configuration of history in categorical state node +""" +Base.@kwdef mutable struct CategoricalStateNodeHistory + 
posterior::Vector{Vector{Union{Real,Missing}}} = [] + value_prediction_error::Vector{Vector{Union{Real,Missing}}} = [] + prediction::Vector{Vector{Union{Real,Missing}}} = [] + parent_predictions::Vector{Vector{Union{Real,Missing}}} = [] +end + +""" +Configuration of edges in categorical state node +""" +Base.@kwdef mutable struct CategoricalStateNode <: AbstractCategoricalStateNode + name::String + edges::CategoricalStateNodeEdges = CategoricalStateNodeEdges() + parameters::CategoricalStateNodeParameters = CategoricalStateNodeParameters() + states::CategoricalStateNodeState = CategoricalStateNodeState() + history::CategoricalStateNodeHistory = CategoricalStateNodeHistory() + update_type::HGFUpdateType = ClassicUpdate() +end + + + +######################################## +######## Categorical Input Node ######## +######################################## +Base.@kwdef mutable struct CategoricalInputNodeEdges + observation_parents::Vector{<:AbstractCategoricalStateNode} = + Vector{CategoricalStateNode}() +end + +Base.@kwdef mutable struct CategoricalInputNodeParameters + coupling_strengths::Dict{String,Real} = Dict{String,Real}() +end + +""" +Configuration of states of categorical input node +""" +Base.@kwdef mutable struct CategoricalInputNodeState + input_value::Union{Real,Missing} = missing +end + +""" +History of categorical input node +""" +Base.@kwdef mutable struct CategoricalInputNodeHistory + input_value::Vector{Union{Real,Missing}} = [missing] +end + +""" +""" +Base.@kwdef mutable struct CategoricalInputNode <: AbstractCategoricalInputNode + name::String + edges::CategoricalInputNodeEdges = CategoricalInputNodeEdges() + parameters::CategoricalInputNodeParameters = CategoricalInputNodeParameters() + states::CategoricalInputNodeState = CategoricalInputNodeState() + history::CategoricalInputNodeHistory = CategoricalInputNodeHistory() +end diff --git a/src/create_hgf/init_hgf.jl b/src/create_hgf/init_hgf.jl index b348647..ca308fb 100644 --- 
a/src/create_hgf/init_hgf.jl +++ b/src/create_hgf/init_hgf.jl @@ -24,6 +24,15 @@ Edge information includes 'child', as well as 'value_parents' and/or 'volatility ```julia ##Create a simple 2level continuous HGF## +#Set defaults for nodes +node_defaults = Dict( + "volatility" => -2, + "input_noise" => -2, + "initial_mean" => 0, + "initial_precision" => 1, + "coupling_strength" => 1, +) + #List of input nodes input_nodes = Dict( "name" => "u", @@ -34,14 +43,14 @@ input_nodes = Dict( #List of state nodes state_nodes = [ Dict( - "name" => "x1", + "name" => "x", "type" => "continuous", "volatility" => -2, "initial_mean" => 0, "initial_precision" => 1, ), Dict( - "name" => "x2", + "name" => "xvol", "type" => "continuous", "volatility" => -2, "initial_mean" => 0, @@ -50,68 +59,11 @@ state_nodes = [ ] #List of child-parent relations -edges = [ - Dict( - "child" => "u", - "value_parents" => ("x1", 1), - ), - Dict( - "child" => "x1", - "volatility_parents" => ("x2", 1), - ), -] - -#Initialize the HGF -hgf = init_hgf( - input_nodes = input_nodes, - state_nodes = state_nodes, - edges = edges, -) - -##Create a more complicated HGF without specifying information for each node## - -#Set defaults for all nodes -node_defaults = Dict( - "volatility" => -2, - "input_noise" => -2, - "initial_mean" => 0, - "initial_precision" => 1, - "value_coupling" => 1, - "volatility_coupling" => 1, +edges = Dict( + ("u", "x") => ObservationCoupling(), + ("x", "xvol") => VolatilityCoupling(), ) -input_nodes = [ - "u1", - "u2", -] - -state_nodes = [ - "x1", - "x2", - "x3", - "x4", -] - -edges = [ - Dict( - "child" => "u1", - "value_parents" => ["x1", "x2"], - "volatility_parents" => "x3" - ), - Dict( - "child" => "u2", - "value_parents" => ["x1"], - ), - Dict( - "child" => "x1", - "volatility_parents" => "x4", - ), - Dict( - "child" => "x2", - "volatility_parents" => "x4", - ), -] - hgf = init_hgf( input_nodes = input_nodes, state_nodes = state_nodes, @@ -121,200 +73,69 @@ hgf = init_hgf( ``` """
function init_hgf(; - input_nodes::Union{String,Dict,Vector}, - state_nodes::Union{String,Dict,Vector}, - edges::Union{Vector{<:Dict},Dict}, - shared_parameters::Dict = Dict(), - node_defaults::Dict = Dict(), - update_type::HGFUpdateType = EnhancedUpdate(), + nodes::Vector{<:AbstractNodeInfo}, + edges::Dict{Tuple{String,String},<:CouplingType}, + node_defaults::NodeDefaults = NodeDefaults(), + parameter_groups::Vector{ParameterGroup} = Vector{ParameterGroup}(), update_order::Union{Nothing,Vector{String}} = nothing, verbose::Bool = true, + save_history::Bool = true, ) - ### Defaults ### - preset_node_defaults = Dict( - "type" => "continuous", - "volatility" => -2, - "drift" => 0, - "autoregression_target" => 0, - "autoregression_strength" => 0, - "initial_mean" => 0, - "initial_precision" => 1, - "value_coupling" => 1, - "volatility_coupling" => 1, - "category_means" => [0, 1], - "input_precision" => Inf, - "input_noise" => -2 - ) - - #If verbose - if verbose - #If some node defaults have been specified - if length(node_defaults) > 0 - #Warn the user of unspecified defaults and errors - warn_premade_defaults( - preset_node_defaults, - node_defaults, - "in the node defaults,", - ) - end - end - - #Use presets wherever node defaults were not given - node_defaults = merge(preset_node_defaults, node_defaults) - ### Initialize nodes ### #Initialize empty dictionaries for storing nodes all_nodes_dict = Dict{String,AbstractNode}() input_nodes_dict = Dict{String,AbstractInputNode}() state_nodes_dict = Dict{String,AbstractStateNode}() - - ## Input nodes ## - - #If user has only specified a single node and not in a vector - if input_nodes isa Dict - #Put it in a vector - input_nodes = [input_nodes] - end + input_nodes_inputted_order = Vector{String}() + state_nodes_inputted_order = Vector{String}() #For each specified input node - for node_info in input_nodes - - #If only the node's name was specified as a string - if node_info isa String - #Make it into a dictionary - 
node_info = Dict("name" => node_info) + for node_info in nodes + #For each field in the node info + for fieldname in fieldnames(typeof(node_info)) + #If it hasn't been specified by the user + if isnothing(getfield(node_info, fieldname)) + #Set the node_defaults' value instead + setfield!(node_info, fieldname, getfield(node_defaults, fieldname)) + end end #Create the node - node = init_node("input_node", node_defaults, node_info) + node = init_node(node_info) - #Add it to the dictionary - all_nodes_dict[node.name] = node - input_nodes_dict[node.name] = node - end - - ## State nodes ## - #If user has only specified a single node and not in a vector - if state_nodes isa Dict - #Put it in a vector - state_nodes = [state_nodes] - end - - #For each specified state node - for node_info in state_nodes + #Add it to the large dictionary + all_nodes_dict[node_info.name] = node - #If only the node's name was specified as a string - if node_info isa String - #Make it into a named tuple - node_info = Dict("name" => node_info) + #If it is an input node + if node isa AbstractInputNode + #Add it to the input node dict + input_nodes_dict[node_info.name] = node + #Store its name in the inputted order + push!(input_nodes_inputted_order, node_info.name) + + #If it is a state node + elseif node isa AbstractStateNode + #Add it to the state node dict + state_nodes_dict[node_info.name] = node + #Store its name in the inputted order + push!(state_nodes_inputted_order, node_info.name) end - - #Create the node - node = init_node("state_node", node_defaults, node_info) - - #Add it to the dictionary - all_nodes_dict[node.name] = node - state_nodes_dict[node.name] = node end - ### Set up edges ### + #For each specified edge + for (node_names, coupling_type) in edges - #If user has only specified a single edge and not in a vector - if edges isa Dict - #Put it in a vector - edges = [edges] - end - - #For each child - for edge in edges - - #Find corresponding child node - child_node = 
all_nodes_dict[edge["child"]] - - #Add empty vectors for when the user has not specified any - edge = merge(Dict("value_parents" => [], "volatility_parents" => []), edge) - - #If there are any value parents - if length(edge["value_parents"]) > 0 - #Get out value parents - value_parents = edge["value_parents"] - - #If the value parents were not specified as a vector - if .!isa(value_parents, Vector) - #Make it into one - value_parents = [value_parents] - end - - #For each value parent - for parent_info in value_parents - - #If only the node's name was specified as a string - if parent_info isa String - #Make it a tuple, and give it the default coupling strength - parent_info = (parent_info, node_defaults["value_coupling"]) - end - - #Find the corresponding parent - parent_node = all_nodes_dict[parent_info[1]] - - #Add the parent to the child node - push!(child_node.value_parents, parent_node) - - #Add the child node to the parent node - push!(parent_node.value_children, child_node) - - #Except for binary input nodes and categorical nodes - if !( - typeof(child_node) in - [BinaryInputNode, CategoricalInputNode, CategoricalStateNode] - ) - #Add coupling strength to child node - child_node.parameters.value_coupling[parent_node.name] = parent_info[2] - end - end - end - - #If there are any volatility parents - if length(edge["volatility_parents"]) > 0 - #Get out volatility parents - volatility_parents = edge["volatility_parents"] - - #If the volatility parents were not specified as a vector - if .!isa(volatility_parents, Vector) - #Make it into one - volatility_parents = [volatility_parents] - end - - #For each volatility parent - for parent_info in volatility_parents - - #If only the node's name was specified as a string - if parent_info isa String - #Make it a tuple, and give it the default coupling strength - parent_info = (parent_info, node_defaults["volatility_coupling"]) - end - - #Find the corresponding parent - parent_node = all_nodes_dict[parent_info[1]] - - 
#Add the parent to the child node - push!(child_node.volatility_parents, parent_node) - - #Add the child node to the parent node - push!(parent_node.volatility_children, child_node) + #Extract the child and parent names + child_name, parent_name = node_names - #Add coupling strength to child node - child_node.parameters.volatility_coupling[parent_node.name] = parent_info[2] + #Find corresponding child node and parent node + child_node = all_nodes_dict[child_name] + parent_node = all_nodes_dict[parent_name] - #If the enhanced HGF update is used - if update_type isa EnhancedUpdate && parent_node isa ContinuousStateNode - #Set the node to use the enhanced HGF update - parent_node.update_type = update_type - end - end - end + #Create the edge + init_edge!(child_node, parent_node, coupling_type, node_defaults) end ## Determine Update order ## @@ -323,38 +144,12 @@ function init_hgf(; #If verbose if verbose - #Warn that automaitc update order is used + #Warn that automatic update order is used @warn "No update order specified. 
Using the order in which nodes were inputted" end - #Initialize empty vector for storing the update order - update_order = [] - - #For each input node, in the order inputted - for node_info in input_nodes - - #If only the node's name was specified as a string - if node_info isa String - #Make it into a named tuple - node_info = Dict("name" => node_info) - end - - #Add the node to the vector - push!(update_order, all_nodes_dict[node_info["name"]]) - end - - #For each state node, in the order inputted - for node_info in state_nodes - - #If only the node's name was specified as a string - if node_info isa String - #Make it into a named tuple - node_info = Dict("name" => node_info) - end - - #Add the node to the vector - push!(update_order, all_nodes_dict[node_info["name"]]) - end + #Use the order that the nodes were specified in + update_order = append!(input_nodes_inputted_order, state_nodes_inputted_order) end ## Order nodes ## @@ -362,7 +157,10 @@ function init_hgf(; ordered_nodes = OrderedNodes() #For each node, in the specified update order - for node in update_order + for node_name in update_order + + #Extract node + node = all_nodes_dict[node_name] #Have a field for all nodes push!(ordered_nodes.all_nodes, node) @@ -377,7 +175,7 @@ function init_hgf(; push!(ordered_nodes.all_state_nodes, node) #If any of the nodes' value children are continuous input nodes - if any(isa.(node.value_children, ContinuousInputNode)) + if any(isa.(node.edges.observation_children, ContinuousInputNode)) #Add it to the early update list push!(ordered_nodes.early_update_state_nodes, node) else @@ -388,25 +186,15 @@ function init_hgf(; end #initializing shared parameters - shared_parameters_dict = Dict() + parameter_groups_dict = Dict() #Go through each specified shared parameter - for (shared_parameter_key, dict_value) in shared_parameters - #Unpack the shared parameter value and the derived parameters - (shared_parameter_value, derived_parameters) = dict_value - #check if the name of 
the shared parameter is part of its own derived parameters - if shared_parameter_key in derived_parameters - throw( - ArgumentError( - "The shared parameter is part of the list of derived parameters", - ), - ) - end + for parameter_group in parameter_groups - #Add as a SharedParameter to the shared parameter dictionary - shared_parameters_dict[shared_parameter_key] = SharedParameter( - value = shared_parameter_value, - derived_parameters = derived_parameters, + #Add as a GroupedParameters to the shared parameter dictionary + parameter_groups_dict[parameter_group.name] = ActionModels.GroupedParameters( + value = parameter_group.value, + grouped_parameters = parameter_group.parameters, ) end @@ -416,152 +204,49 @@ function init_hgf(; input_nodes_dict, state_nodes_dict, ordered_nodes, - shared_parameters_dict, + parameter_groups_dict, + save_history, + [0], ) ### Check that the HGF has been specified properly ### check_hgf(hgf) - ### Initialize node history ### + ### Initialize states and history ### #For each state node for node in hgf.ordered_nodes.all_state_nodes - - #For categorical state nodes + #If it is a categorical state node if node isa CategoricalStateNode - #Make vector of order of category parents - for parent in node.value_parents - push!(node.category_parent_order, parent.name) + #Make vector with ordered category parents + for parent in node.edges.category_parents + push!(node.edges.category_parent_order, parent.name) end - #Set posterior to vector of zeros equal to the number of categories + #Set posterior to vector of missing with length equal to the number of categories node.states.posterior = - Vector{Union{Real,Missing}}(missing, length(node.value_parents)) - push!(node.history.posterior, node.states.posterior) - #Set posterior to vector of missing equal to the number of categories + #Set value prediction error to vector of missing with length equal to the number of categories
node.states.value_prediction_error = - Vector{Union{Real,Missing}}(missing, length(node.value_parents)) - push!(node.history.value_prediction_error, node.states.value_prediction_error) - - #Set parent predictions form last timestep to be agnostic - node.states.parent_predictions = - repeat([1 / length(node.value_parents)], length(node.value_parents)) - - #Set predictions form last timestep to be agnostic - node.states.prediction = - repeat([1 / length(node.value_parents)], length(node.value_parents)) - - #For other nodes - else - #Save posterior to node history - push!(node.history.posterior_mean, node.states.posterior_mean) - push!(node.history.posterior_precision, node.states.posterior_precision) - end - end - - return hgf -end - + Vector{Union{Real,Missing}}(missing, length(node.edges.category_parents)) - -""" - init_node(input_or_state_node, node_defaults, node_info) - -Function for creating a node, given specifications -""" -function init_node(input_or_state_node, node_defaults, node_info) - - #Get parameters and starting state. Specific node settings supercede node defaults, which again supercede the function's defaults. 
- parameters = merge(node_defaults, node_info) - - #For an input node - if input_or_state_node == "input_node" - #If it is continuous - if parameters["type"] == "continuous" - #Initialize it - node = ContinuousInputNode( - name = parameters["name"], - parameters = ContinuousInputNodeParameters( - input_noise = parameters["input_noise"], - ), - states = ContinuousInputNodeState(), - ) - #If it is binary - elseif parameters["type"] == "binary" - #Initialize it - node = BinaryInputNode( - name = parameters["name"], - parameters = BinaryInputNodeParameters( - category_means = parameters["category_means"], - input_precision = parameters["input_precision"], - ), - states = BinaryInputNodeState(), - ) - #If it is categorical - elseif parameters["type"] == "categorical" - #Initialize it - node = CategoricalInputNode( - name = parameters["name"], - parameters = CategoricalInputNodeParameters(), - states = CategoricalInputNodeState(), - ) - else - #The node has been misspecified. Throw an error - throw( - ArgumentError("the type of node $parameters['name'] has been misspecified"), - ) - end - - #For a state node - elseif input_or_state_node == "state_node" - #If it is continuous - if parameters["type"] == "continuous" - #Initialize it - node = ContinuousStateNode( - name = parameters["name"], - #Set parameters - parameters = ContinuousStateNodeParameters( - volatility = parameters["volatility"], - drift = parameters["drift"], - initial_mean = parameters["initial_mean"], - initial_precision = parameters["initial_precision"], - autoregression_target = parameters["autoregression_target"], - autoregression_strength = parameters["autoregression_strength"], - ), - #Set states - states = ContinuousStateNodeState( - posterior_mean = parameters["initial_mean"], - posterior_precision = parameters["initial_precision"], - ), + #Set parent predictions from last timestep to be agnostic + node.states.parent_predictions = repeat( + [1 / length(node.edges.category_parents)], + 
length(node.edges.category_parents), ) - #If it is binary - elseif parameters["type"] == "binary" - #Initialize it - node = BinaryStateNode( - name = parameters["name"], - parameters = BinaryStateNodeParameters(), - states = BinaryStateNodeState(), - ) - - #If it categorical - elseif parameters["type"] == "categorical" - - #Initialize it - node = CategoricalStateNode( - name = parameters["name"], - parameters = CategoricalStateNodeParameters(), - states = CategoricalStateNodeState(), - ) - - else - #The node has been misspecified. Throw an error - throw( - ArgumentError("the type of node $parameters['name'] has been misspecified"), + #Set predictions from last timestep to be agnostic + node.states.prediction = repeat( + [1 / length(node.edges.category_parents)], + length(node.edges.category_parents), ) end end - return node + #Reset the hgf, initializing states and history + reset!(hgf) + + return hgf end diff --git a/src/create_hgf/init_node_edge.jl b/src/create_hgf/init_node_edge.jl new file mode 100644 index 0000000..29c5944 --- /dev/null +++ b/src/create_hgf/init_node_edge.jl @@ -0,0 +1,108 @@ +### Function for initializing a node ### +function init_node(node_info::ContinuousState) + ContinuousStateNode( + name = node_info.name, + parameters = ContinuousStateNodeParameters( + volatility = node_info.volatility, + drift = node_info.drift, + initial_mean = node_info.initial_mean, + initial_precision = node_info.initial_precision, + autoconnection_strength = node_info.autoconnection_strength, + ), + ) +end + +function init_node(node_info::ContinuousInput) + ContinuousInputNode( + name = node_info.name, + parameters = ContinuousInputNodeParameters(input_noise = node_info.input_noise), + ) +end + +function init_node(node_info::BinaryState) + BinaryStateNode(name = node_info.name) +end + +function init_node(node_info::BinaryInput) + BinaryInputNode(name = node_info.name) +end + +function init_node(node_info::CategoricalState) + CategoricalStateNode(name = 
node_info.name) +end + +function init_node(node_info::CategoricalInput) + CategoricalInputNode(name = node_info.name) +end + + +### Function for initializing an edge ### +function init_edge!( + child_node::AbstractNode, + parent_node::AbstractStateNode, + coupling_type::CouplingType, + node_defaults::NodeDefaults, +) + + #Get correct field for storing parents + if coupling_type isa DriftCoupling + parents_field = :drift_parents + children_field = :drift_children + + elseif coupling_type isa ObservationCoupling + parents_field = :observation_parents + children_field = :observation_children + + elseif coupling_type isa CategoryCoupling + parents_field = :category_parents + children_field = :category_children + + elseif coupling_type isa ProbabilityCoupling + parents_field = :probability_parents + children_field = :probability_children + + elseif coupling_type isa VolatilityCoupling + parents_field = :volatility_parents + children_field = :volatility_children + + elseif coupling_type isa NoiseCoupling + parents_field = :noise_parents + children_field = :noise_children + end + + #Add the parent to the child node + push!(getfield(child_node.edges, parents_field), parent_node) + + #Add the child node to the parent node + push!(getfield(parent_node.edges, children_field), child_node) + + #If the coupling type has a coupling strength + if hasproperty(coupling_type, :strength) + #If the user has not specified a coupling strength + if isnothing(coupling_type.strength) + #Use the defaults coupling strength + coupling_strength = node_defaults.coupling_strength + + #Otherwise + else + #Use the specified coupling strength + coupling_strength = coupling_type.strength + end + + #And set it as a parameter for the child + child_node.parameters.coupling_strengths[parent_node.name] = coupling_strength + end + + #If the coupling can have a transformation + if hasproperty(coupling_type, :transform) + #Save that transformation in the node + 
child_node.parameters.coupling_transforms[parent_node.name] = + coupling_type.transform + end + + #If the enhanced HGF update is the defaults, and if it is a precision coupling (volatility or noise) + if node_defaults.update_type isa EnhancedUpdate && coupling_type isa PrecisionCoupling + #Set the node to use the enhanced HGF update + parent_node.update_type = node_defaults.update_type + end +end diff --git a/src/premade_models/premade_action_models.jl b/src/premade_models/premade_action_models.jl deleted file mode 100644 index eb5c330..0000000 --- a/src/premade_models/premade_action_models.jl +++ /dev/null @@ -1,257 +0,0 @@ -###### Multiple Actions ###### -""" - update_hgf_multiple_actions(agent::Agent, input) - -Action model that first updates the HGF, and then runs multiple action models. -""" -function update_hgf_multiple_actions(agent::Agent, input) - - #Update the hgf - hgf = agent.substruct - update_hgf!(hgf, input) - - #Extract vector of action models - action_models = agent.settings["hgf_actions"] - - #Initialize vector for action distributions - action_distributions = [] - - #Do each action model separately - for action_model in action_models - #And append them to the vector of action distributions - push!(action_distributions, action_model(agent, input)) - end - - return action_distributions -end - - -###### Gaussian Action ###### -""" - update_hgf_gaussian_action(agent::Agent, input) - -Action model that first updates the HGF, and then reports a given HGF state with Gaussian noise. 
- -In addition to the HGF substruct, the following must be present in the agent: -Parameters: "gaussian_action_precision" -Settings: "target_state" -""" -function update_hgf_gaussian_action(agent::Agent, input) - - #Update the HGF - update_hgf!(agent.substruct, input) - - #Run the action model - action_distribution = hgf_gaussian_action(agent, input) - - return action_distribution -end - -""" - hgf_gaussian_action(agent::Agent, input) - -Action model which reports a given HGF state with Gaussian noise. - -In addition to the HGF substruct, the following must be present in the agent: -Parameters: "gaussian_action_precision" -Settings: "target_state" -""" -function hgf_gaussian_action(agent::Agent, input) - - #Get out hgf, settings and parameters - hgf = agent.substruct - target_state = agent.settings["target_state"] - action_precision = agent.parameters["gaussian_action_precision"] - - #Get the specified state - action_mean = get_states(hgf, target_state) - - #If the gaussian mean becomes a NaN - if isnan(action_mean) - #Throw an error that will reject samples when fitted - throw( - RejectParameters( - "With these parameters and inputs, the mean of the gaussian action became $action_mean, which is invalid. Try other parameter settings", - ), - ) - end - - #Create normal distribution with mean of the target value and a standard deviation from parameters - distribution = Distributions.Normal(action_mean, 1 / action_precision) - - #Return the action distribution - return distribution -end - - -###### Softmax Action ###### -""" - update_hgf_softmax_action(agent::Agent, input) - -Action model that first updates the HGF, and then passes a state from the HGF through a softmax to give a binary action. 
- -In addition to the HGF substruct, the following must be present in the agent: -Parameters: "softmax_action_precision" -Settings: "target_state" -""" -function update_hgf_binary_softmax_action(agent::Agent, input) - - #Update the HGF - update_hgf!(agent.substruct, input) - - #Run the action model - action_distribution = hgf_binary_softmax_action(agent, input) - - return action_distribution -end - -""" - hgf_binary_softmax_action(agent, input) - -Action model which gives a binary action. The action probability is the softmax of a specified state of a node. - -In addition to the HGF substruct, the following must be present in the agent: -Parameters: "softmax_action_precision" -Settings: "target_state" -""" -function hgf_binary_softmax_action(agent::Agent, input) - - #Get out HGF, settings and parameters - hgf = agent.substruct - target_state = agent.settings["target_state"] - action_precision = agent.parameters["softmax_action_precision"] - - #Get the specified state - target_value = get_states(hgf, target_state) - - #Use sotmax to get the action probability - action_probability = 1 / (1 + exp(-action_precision * target_value)) - - #If the action probability is not between 0 and 1 - if !(0 <= action_probability <= 1) - #Throw an error that will reject samples when fitted - throw( - RejectParameters( - "With these parameters and inputs, the action probability became $action_probability, which should be between 0 and 1. Try other parameter settings", - ), - ) - end - - #Create Bernoulli normal distribution with mean of the target value and a standard deviation from parameters - distribution = Distributions.Bernoulli(action_probability) - - #Return the action distribution - return distribution -end - -###### Unit Square Sigmoid Action ###### -""" - update_hgf_unit_square_sigmoid_action(agent::Agent, input) - -Action model that first updates the HGF, and then passes a state from the HGF through a unit square sigmoid transform to give a binary action. 
- -In addition to the HGF substruct, the following must be present in the agent: -Parameters: "sigmoid_action_precision" -Settings: "target_state" -""" -function update_hgf_unit_square_sigmoid_action(agent::Agent, input) - - #Update the HGF - update_hgf!(agent.substruct, input) - - #Run the action model - action_distribution = hgf_unit_square_sigmoid_action(agent, input) - - return action_distribution -end - -""" - unit_square_sigmoid_action(agent, input) - -Action model which gives a binary action. The action probability is the unit square sigmoid of a specified state of a node. - -In addition to the HGF substruct, the following must be present in the agent: -Parameters: "sigmoid_action_precision" -Settings: "target_state" -""" -function hgf_unit_square_sigmoid_action(agent::Agent, input) - - #Get out settings and parameters - target_state = agent.settings["target_state"] - action_precision = agent.parameters["sigmoid_action_precision"] - - #Get out the HGF - hgf = agent.substruct - - #Get the specified state - target_value = get_states(hgf, target_state) - - #Use softmax to get the action probability - action_probability = - target_value^action_precision / - (target_value^action_precision + (1 - target_value)^action_precision) - - #If the action probability is not between 0 and 1 - if !(0 <= action_probability <= 1) - #Throw an error that will reject samples when fitted - throw( - RejectParameters( - "With these parameters and inputs, the action probability became $action_probability, which should be between 0 and 1. 
Try other parameter settings", - ), - ) - end - - #Create Bernoulli normal distribution with mean of the target value and a standard deviation from parameters - distribution = Distributions.Bernoulli(action_probability) - - #Return the action distribution - return distribution -end - - - -###### Categorical Prediction Action ###### -""" - update_hgf_predict_category_action(agent::Agent, input) - -Action model that first updates the HGF, and then returns a categorical prediction of the input. The HGF used must be a categorical HGF. - -In addition to the HGF substruct, the following must be present in the agent: -Settings: "target_categorical_node" -""" -function update_hgf_predict_category_action(agent::Agent, input) - - #Update the HGF - update_hgf!(agent.substruct, input) - - #Run the action model - action_distribution = hgf_predict_category_action(agent, input) - - return action_distribution -end - -""" - hgf_predict_category_action(agent::Agent, input) - -Action model which gives a categorical prediction of the input, based on an HGF. The HGF used must be a categorical HGF. 
- -In addition to the HGF substruct, the following must be present in the agent: -Settings: "target_categorical_node" -""" -function hgf_predict_category_action(agent::Agent, input) - - #Get out settings and parameters - target_node = agent.settings["target_categorical_node"] - - #Get out the HGF - hgf = agent.substruct - - #Get the specified state - predicted_category_probabilities = get_states(hgf, (target_node, "prediction")) - - #Create Bernoulli normal distribution with mean of the target value and a standard deviation from parameters - distribution = Distributions.Categorical(predicted_category_probabilities) - - #Return the action distribution - return distribution -end diff --git a/src/premade_models/premade_agents.jl b/src/premade_models/premade_agents.jl deleted file mode 100644 index b128a12..0000000 --- a/src/premade_models/premade_agents.jl +++ /dev/null @@ -1,376 +0,0 @@ -""" - premade_hgf_multiple_actions(config::Dict) - -Create an agent suitable for multiple aciton models that can depend on an HGF substruct. The used action models are specified as a vector. - -# Config defaults: - - "HGF": "continuous_2level" - - "hgf_actions": ["gaussian_action", "softmax_action", "unit_square_sigmoid_action"] -""" -function premade_hgf_multiple_actions(config::Dict) - - ## Combine defaults and user settings - - #Default parameters and settings - defaults = Dict( - "HGF" => "continuous_2level", - "hgf_actions" => - ["gaussian_action", "softmax_action", "unit_square_sigmoid_action"], - ) - - #If there is no HGF in the user-set parameters - if !("HGF" in keys(config)) - HGF_name = defaults["HGF"] - #Make a default HGF - config["HGF"] = premade_hgf(HGF_name) - #And warn them - @warn "an HGF was not set by the user. 
Using the default: a $HGF_name HGF with default settings" - end - - #Warn the user about used defaults and misspecified keys - warn_premade_defaults(defaults, config) - - #Merge to overwrite defaults - config = merge(defaults, config) - - - ## Create agent - #Set the action model - action_model = update_hgf_multiple_actions - - #Set the HGF - hgf = config["HGF"] - - #Set parameters - parameters = Dict() - #Set states - states = Dict() - #Set settings - settings = Dict("hgf_actions" => config["hgf_actions"]) - - - ## Add parameters for each action type - for action_string in config["hgf_actions"] - - #Parameters for the gaussian action - if action_string == "gaussian_action" - - #Action precision parameter - if "gaussian_action_precision" in keys(config) - parameters["gaussian_action_precision"] = - config["gaussian_action_precision"] - else - default_action_precision = 1 - parameters["gaussian_action_precision"] = default_action_precision - @warn "parameter gaussian_action_precision was not set by the user. Using the default: $default_action_precision" - end - - #Target state setting - if "gaussian_target_state" in keys(config) - settings["gaussian_target_state"] = config["gaussian_target_state"] - else - default_target_state = ("x1", "posterior_mean") - settings["gaussian_target_state"] = default_target_state - @warn "setting gaussian_target_state was not set by the user. Using the default: $default_target_state" - end - - #Parameters for the softmax action - elseif action_string == "softmax_action" - - #Action precision parameter - if "softmax_action_precision" in keys(config) - parameters["softmax_action_precision"] = config["softmax_action_precision"] - else - default_action_precision = 1 - parameters["softmax_action_precision"] = default_action_precision - @warn "parameter softmax_action_precision was not set by the user. 
Using the default: $default_action_precision" - end - - #Target state setting - if "softmax_target_state" in keys(config) - settings["softmax_target_state"] = config["softmax_target_state"] - else - default_target_state = ("x1", "prediction_mean") - settings["softmax_target_state"] = default_target_state - @warn "setting softmax_target_state was not set by the user. Using the default: $default_target_state" - end - - #Parameters for the unit square sigmoid action - elseif action_string == "unit_square_sigmoid_action" - - #Action precision parameter - if "sigmoid_action_precision" in keys(config) - parameters["sigmoid_action_precision"] = config["sigmoid_action_precision"] - else - default_action_precision = 1 - parameters["sigmoid_action_precision"] = default_action_precision - @warn "parameter sigmoid_action_precision was not set by the user. Using the default: $default_action_precision" - end - - #Target state setting - if "sigmoid_target_state" in keys(config) - settings["sigmoid_target_state"] = config["sigmoid_target_state"] - else - default_target_state = ("x1", "prediction_mean") - settings["sigmoid_target_state"] = default_target_state - @warn "setting sigmoid_target_state was not set by the user. Using the default: $default_target_state" - end - - else - throw( - ArgumentError( - "$action_string is not a valid action type. Valid action types are: gaussian_action, softmax_action, unit_square_sigmoid_action", - ), - ) - end - end - - #Create the agent - return init_agent( - action_model; - substruct = hgf, - parameters = parameters, - states = states, - settings = settings, - ) -end - - -""" - premade_hgf_gaussian(config::Dict) - -Create an agent suitable for the HGF Gaussian action model. 
- -# Config defaults: - - "HGF": "continuous_2level" - - "gaussian_action_precision": 1 - - "target_state": ("x1", "posterior_mean") -""" -function premade_hgf_gaussian(config::Dict) - - ## Combine defaults and user settings - - #Default parameters and settings - defaults = Dict( - "gaussian_action_precision" => 1, - "target_state" => ("x1", "posterior_mean"), - "HGF" => "continuous_2level", - ) - - #If there is no HGF in the user-set parameters - if !("HGF" in keys(config)) - HGF_name = defaults["HGF"] - #Make a default HGF - config["HGF"] = premade_hgf(HGF_name) - #And warn them - @warn "an HGF was not set by the user. Using the default: a $HGF_name HGF with default settings" - end - - #Warn the user about used defaults and misspecified keys - warn_premade_defaults(defaults, config) - - #Merge to overwrite defaults - config = merge(defaults, config) - - - ## Create agent - #Set the action model - action_model = update_hgf_gaussian_action - - #Set the HGF - hgf = config["HGF"] - - #Set parameters - parameters = Dict("gaussian_action_precision" => config["gaussian_action_precision"]) - #Set states - states = Dict() - #Set settings - settings = Dict("target_state" => config["target_state"]) - - #Create the agent - return init_agent( - action_model; - substruct = hgf, - parameters = parameters, - states = states, - settings = settings, - ) -end - -""" - premade_hgf_binary_softmax(config::Dict) - -Create an agent suitable for the HGF binary softmax model. 
- -# Config defaults: - - "HGF": "binary_3level" - - "softmax_action_precision": 1 - - "target_state": ("x1", "prediction_mean") -""" -function premade_hgf_binary_softmax(config::Dict) - - ## Combine defaults and user settings - - #Default parameters and settings - defaults = Dict( - "softmax_action_precision" => 1, - "target_state" => ("x1", "prediction_mean"), - "HGF" => "binary_3level", - ) - - #If there is no HGF in the user-set parameters - if !("HGF" in keys(config)) - HGF_name = defaults["HGF"] - #Make a default HGF - config["HGF"] = premade_hgf(HGF_name) - #And warn them - @warn "an HGF was not set by the user. Using the default: a $HGF_name HGF with default settings" - end - - #Warn the user about used defaults and misspecified keys - warn_premade_defaults(defaults, config) - - #Merge to overwrite defaults - config = merge(defaults, config) - - - ## Create agent - #Set the action model - action_model = update_hgf_binary_softmax_action - - #Set the HGF - hgf = config["HGF"] - - #Set parameters - parameters = Dict("softmax_action_precision" => config["softmax_action_precision"]) - #Set states - states = Dict() - #Set settings - settings = Dict("target_state" => config["target_state"]) - - #Create the agent - return init_agent( - action_model, - substruct = hgf, - parameters = parameters, - states = states, - settings = settings, - ) -end - -""" - premade_hgf_binary_softmax(config::Dict) - -Create an agent suitable for the HGF unit square sigmoid model. 
- -# Config defaults: - - "HGF": "binary_3level" - - "sigmoid_action_precision": 1 - - "target_state": ("x1", "prediction_mean") -""" -function premade_hgf_unit_square_sigmoid(config::Dict) - - ## Combine defaults and user settings - - #Default parameters and settings - defaults = Dict( - "sigmoid_action_precision" => 1, - "target_state" => ("x1", "prediction_mean"), - "HGF" => "binary_3level", - ) - - #If there is no HGF in the user-set parameters - if !("HGF" in keys(config)) - HGF_name = defaults["HGF"] - #Make a default HGF - config["HGF"] = premade_hgf(HGF_name) - #And warn them - @warn "an HGF was not set by the user. Using the default: a $HGF_name HGF with default settings" - end - - #Warn the user about used defaults and misspecified keys - warn_premade_defaults(defaults, config) - - #Merge to overwrite defaults - config = merge(defaults, config) - - - ## Create agent - #Set the action model - action_model = update_hgf_unit_square_sigmoid_action - - #Set the HGF - hgf = config["HGF"] - - #Set parameters - parameters = Dict("sigmoid_action_precision" => config["sigmoid_action_precision"]) - #Set states - states = Dict() - #Set settings - settings = Dict("target_state" => config["target_state"]) - - #Create the agent - return init_agent( - action_model, - substruct = hgf, - parameters = parameters, - states = states, - settings = settings, - ) -end - -""" - premade_hgf_predict_category(config::Dict) - -Create an agent suitable for the HGF predict category model. 
- -# Config defaults: - - "HGF": "categorical_3level" - - "target_categorical_node": "x1" -""" -function premade_hgf_predict_category(config::Dict) - - ## Combine defaults and user settings - - #Default parameters and settings - defaults = Dict("target_categorical_node" => "x1", "HGF" => "categorical_3level") - - #If there is no HGF in the user-set parameters - if !("HGF" in keys(config)) - HGF_name = defaults["HGF"] - #Make a default HGF - config["HGF"] = premade_hgf(HGF_name) - #And warn them - @warn "an HGF was not set by the user. Using the default: a $HGF_name HGF with default settings" - end - - #Warn the user about used defaults and misspecified keys - warn_premade_defaults(defaults, config) - - #Merge to overwrite defaults - config = merge(defaults, config) - - - ## Create agent - #Set the action model - action_model = update_hgf_predict_category_action - - #Set the HGF - hgf = config["HGF"] - - #Set parameters - parameters = Dict() - #Set states - states = Dict() - #Set settings - settings = Dict("target_categorical_node" => config["target_categorical_node"]) - - #Create the agent - return init_agent( - action_model, - substruct = hgf, - parameters = parameters, - states = states, - settings = settings, - ) -end diff --git a/src/premade_models/premade_agents/premade_gaussian.jl b/src/premade_models/premade_agents/premade_gaussian.jl new file mode 100644 index 0000000..d672221 --- /dev/null +++ b/src/premade_models/premade_agents/premade_gaussian.jl @@ -0,0 +1,100 @@ +""" + hgf_gaussian(agent::Agent, input) + +Action model which reports a given HGF state with Gaussian noise. 
+ +In addition to the HGF substruct, the following must be present in the agent: +Parameters: "gaussian_action_precision" +Settings: "target_state" +""" +function hgf_gaussian(agent::Agent, input) + + #Extract HGF, settings and parameters + hgf = agent.substruct + target_state = agent.settings["target_state"] + action_noise = agent.parameters["action_noise"] + + #Update the HGF + update_hgf!(agent.substruct, input) + + #Extract specified belief state + action_mean = get_states(hgf, target_state) + + #If the gaussian mean becomes a NaN + if isnan(action_mean) + #Throw an error that will reject samples when fitted + throw( + RejectParameters( + "With these parameters and inputs, the mean of the gaussian action became $action_mean, which is invalid. Try other parameter settings", + ), + ) + end + + #Create normal distribution with mean of the target value and a standard deviation from parameters + distribution = Distributions.Normal(action_mean, action_noise) + + #Return the action distribution + return distribution +end + + +""" + premade_hgf_gaussian(config::Dict) + +Create an agent suitable for the HGF Gaussian action model. + +# Config defaults: + - "HGF": "continuous_2level" + - "gaussian_action_precision": 1 + - "target_state": ("x", "posterior_mean") +""" +function premade_hgf_gaussian(config::Dict) + + ## Combine defaults and user settings + + #Default parameters and settings + defaults = Dict( + "action_noise" => 1, + "target_state" => ("x", "posterior_mean"), + "HGF" => "continuous_2level", + ) + + #If there is no HGF in the user-set parameters + if !("HGF" in keys(config)) + HGF_name = defaults["HGF"] + #Make a default HGF + config["HGF"] = premade_hgf(HGF_name) + #And warn them + @warn "an HGF was not set by the user. 
Using the default: a $HGF_name HGF with default settings" + end + + #Warn the user about used defaults and misspecified keys + warn_premade_defaults(defaults, config) + + #Merge to overwrite defaults + config = merge(defaults, config) + + + ## Create agent + #Set the action model + action_model = hgf_gaussian + + #Set the HGF + hgf = config["HGF"] + + #Set parameters + parameters = Dict("action_noise" => config["action_noise"]) + #Set states + states = Dict() + #Set settings + settings = Dict("target_state" => config["target_state"]) + + #Create the agent + return init_agent( + action_model; + substruct = hgf, + parameters = parameters, + states = states, + settings = settings, + ) +end diff --git a/src/premade_models/premade_agents/premade_predict_category.jl b/src/premade_models/premade_agents/premade_predict_category.jl new file mode 100644 index 0000000..e6c2574 --- /dev/null +++ b/src/premade_models/premade_agents/premade_predict_category.jl @@ -0,0 +1,107 @@ + + +###### Categorical Prediction Action ###### +""" + update_hgf_predict_category(agent::Agent, input) + +Action model that first updates the HGF, and then returns a categorical prediction of the input. The HGF used must be a categorical HGF. + +In addition to the HGF substruct, the following must be present in the agent: +Settings: "target_categorical_node" +""" +function update_hgf_predict_category(agent::Agent, input) + + #Update the HGF + update_hgf!(agent.substruct, input) + + #Run the action model + action_distribution = hgf_predict_category(agent, input) + + return action_distribution +end + +""" + hgf_predict_category(agent::Agent, input) + +Action model which gives a categorical prediction of the input, based on an HGF. The HGF used must be a categorical HGF. 
+ +In addition to the HGF substruct, the following must be present in the agent: +Settings: "target_categorical_node" +""" +function hgf_predict_category(agent::Agent, input) + + #Get out settings and parameters + target_node = agent.settings["target_categorical_node"] + + #Get out the HGF + hgf = agent.substruct + + #Get the specified state + predicted_category_probabilities = get_states(hgf, (target_node, "prediction")) + + #Create Bernoulli normal distribution with mean of the target value and a standard deviation from parameters + distribution = Distributions.Categorical(predicted_category_probabilities) + + #Return the action distribution + return distribution +end + + + + + +""" + premade_hgf_predict_category(config::Dict) + +Create an agent suitable for the HGF predict category model. + +# Config defaults: + - "HGF": "categorical_3level" + - "target_categorical_node": "xcat" +""" +function premade_hgf_predict_category(config::Dict) + + ## Combine defaults and user settings + + #Default parameters and settings + defaults = Dict("target_categorical_node" => "xcat", "HGF" => "categorical_3level") + + #If there is no HGF in the user-set parameters + if !("HGF" in keys(config)) + HGF_name = defaults["HGF"] + #Make a default HGF + config["HGF"] = premade_hgf(HGF_name) + #And warn them + @warn "an HGF was not set by the user. 
Using the default: a $HGF_name HGF with default settings" + end + + #Warn the user about used defaults and misspecified keys + warn_premade_defaults(defaults, config) + + #Merge to overwrite defaults + config = merge(defaults, config) + + + ## Create agent + #Set the action model + action_model = update_hgf_predict_category + + #Set the HGF + hgf = config["HGF"] + + #Set parameters + parameters = Dict() + #Set states + states = Dict() + #Set settings + settings = Dict("target_categorical_node" => config["target_categorical_node"]) + + #Create the agent + return init_agent( + action_model, + substruct = hgf, + parameters = parameters, + states = states, + settings = settings, + ) +end diff --git a/src/premade_models/premade_agents/premade_sigmoid.jl b/src/premade_models/premade_agents/premade_sigmoid.jl new file mode 100644 index 0000000..e702dc9 --- /dev/null +++ b/src/premade_models/premade_agents/premade_sigmoid.jl @@ -0,0 +1,107 @@ +""" + unit_square_sigmoid_action(agent, input) + +Action model which gives a binary action. The action probability is the unit square sigmoid of a specified state of a node. 
+ +In addition to the HGF substruct, the following must be present in the agent: +Parameters: "action_precision" +Settings: "target_state" +""" +function hgf_unit_square_sigmoid_action(agent::Agent, input) + + #Extract HGF, settings and parameters + hgf = agent.substruct + target_state = agent.settings["target_state"] + action_noise = agent.parameters["action_noise"] + + #Update the HGF + update_hgf!(hgf, input) + + #Get the specified state + target_value = get_states(hgf, target_state) + + #Use softmax to get the action probability + action_precision = 1 / action_noise + + action_probability = + target_value^action_precision / + (target_value^action_precision + (1 - target_value)^action_precision) + + #If the action probability is not between 0 and 1 + if !(0 <= action_probability <= 1) + #Throw an error that will reject samples when fitted + throw( + RejectParameters( + "With these parameters and inputs, the action probability became $action_probability, which should be between 0 and 1. Try other parameter settings", + ), + ) + end + + #Create Bernoulli normal distribution with mean of the target value and a standard deviation from parameters + distribution = Distributions.Bernoulli(action_probability) + + #Return the action distribution + return distribution +end + + +""" + premade_hgf_binary_softmax(config::Dict) + +Create an agent suitable for the HGF unit square sigmoid model. 
+ +# Config defaults: + - "HGF": "binary_3level" + - "sigmoid_action_precision": 1 + - "target_state": ("xbin", "prediction_mean") +""" +function premade_hgf_unit_square_sigmoid(config::Dict) + + ## Combine defaults and user settings + + #Default parameters and settings + defaults = Dict( + "action_noise" => 1, + "target_state" => ("xbin", "prediction_mean"), + "HGF" => "binary_3level", + ) + + #If there is no HGF in the user-set parameters + if !("HGF" in keys(config)) + HGF_name = defaults["HGF"] + #Make a default HGF + config["HGF"] = premade_hgf(HGF_name) + #And warn them + @warn "an HGF was not set by the user. Using the default: a $HGF_name HGF with default settings" + end + + #Warn the user about used defaults and misspecified keys + warn_premade_defaults(defaults, config) + + #Merge to overwrite defaults + config = merge(defaults, config) + + + ## Create agent + #Set the action model + action_model = hgf_unit_square_sigmoid_action + + #Set the HGF + hgf = config["HGF"] + + #Set parameters + parameters = Dict("action_noise" => config["action_noise"]) + #Set states + states = Dict() + #Set settings + settings = Dict("target_state" => config["target_state"]) + + #Create the agent + return init_agent( + action_model, + substruct = hgf, + parameters = parameters, + states = states, + settings = settings, + ) +end diff --git a/src/premade_models/premade_agents/premade_softmax.jl b/src/premade_models/premade_agents/premade_softmax.jl new file mode 100644 index 0000000..e72c336 --- /dev/null +++ b/src/premade_models/premade_agents/premade_softmax.jl @@ -0,0 +1,106 @@ +""" + hgf_binary_softmax_action(agent, input) + +Action model which gives a binary action. The action probability is the softmax of a specified state of a node. 
+ +In addition to the HGF substruct, the following must be present in the agent: +Parameters: "softmax_action_precision" +Settings: "target_state" +""" +function hgf_binary_softmax_action(agent::Agent, input) + + #Get out HGF, settings and parameters + hgf = agent.substruct + target_state = agent.settings["target_state"] + action_noise = agent.parameters["action_noise"] + + #Update the HGF + update_hgf!(hgf, input) + + #Get the specified state + target_value = get_states(hgf, target_state) + + #Use sotmax to get the action probability + action_probability = 1 / (1 + exp(action_noise * target_value)) + + #If the action probability is not between 0 and 1 + if !(0 <= action_probability <= 1) + #Throw an error that will reject samples when fitted + throw( + RejectParameters( + "With these parameters and inputs, the action probability became $action_probability, which should be between 0 and 1. Try other parameter settings", + ), + ) + end + + #Create Bernoulli normal distribution with mean of the target value and a standard deviation from parameters + distribution = Distributions.Bernoulli(action_probability) + + #Return the action distribution + return distribution +end + + + + + +""" + premade_hgf_binary_softmax(config::Dict) + +Create an agent suitable for the HGF binary softmax model. + +# Config defaults: + - "HGF": "binary_3level" + - "softmax_action_precision": 1 + - "target_state": ("xbin", "prediction_mean") +""" +function premade_hgf_binary_softmax(config::Dict) + + ## Combine defaults and user settings + + #Default parameters and settings + defaults = Dict( + "action_noise" => 1, + "target_state" => ("xbin", "prediction_mean"), + "HGF" => "binary_3level", + ) + + #If there is no HGF in the user-set parameters + if !("HGF" in keys(config)) + HGF_name = defaults["HGF"] + #Make a default HGF + config["HGF"] = premade_hgf(HGF_name) + #And warn them + @warn "an HGF was not set by the user. 
Using the default: a $HGF_name HGF with default settings" + end + + #Warn the user about used defaults and misspecified keys + warn_premade_defaults(defaults, config) + + #Merge to overwrite defaults + config = merge(defaults, config) + + + ## Create agent + #Set the action model + action_model = hgf_binary_softmax_action + + #Set the HGF + hgf = config["HGF"] + + #Set parameters + parameters = Dict("action_noise" => config["action_noise"]) + #Set states + states = Dict() + #Set settings + settings = Dict("target_state" => config["target_state"]) + + #Create the agent + return init_agent( + action_model, + substruct = hgf, + parameters = parameters, + states = states, + settings = settings, + ) +end diff --git a/src/premade_models/premade_hgfs.jl b/src/premade_models/premade_hgfs.jl deleted file mode 100644 index fc1372f..0000000 --- a/src/premade_models/premade_hgfs.jl +++ /dev/null @@ -1,945 +0,0 @@ -""" - premade_continuous_2level(config::Dict; verbose::Bool = true) - -The standard 2 level continuous HGF, which filters a continuous input. -It has a continous input node u, with a single value parent x1, which in turn has a single volatility parent x2. 
- -# Config defaults: - - ("u", "input_noise"): -2 - - ("x1", "volatility"): -2 - - ("x2", "volatility"): -2 - - ("u", "x1", "value_coupling"): 1 - - ("x1", "x2", "volatility_coupling"): 1 - - ("x1", "initial_mean"): 0 - - ("x1", "initial_precision"): 1 - - ("x2", "initial_mean"): 0 - - ("x2", "initial_precision"): 1 -""" -function premade_continuous_2level(config::Dict; verbose::Bool = true) - - #Defaults - spec_defaults = Dict( - ("u", "input_noise") => -2, - - ("x1", "volatility") => -2, - ("x1", "drift") => 0, - ("x1", "autoregression_target") => 0, - ("x1", "autoregression_strength") => 0, - ("x1", "initial_mean") => 0, - ("x1", "initial_precision") => 1, - - ("x2", "volatility") => -2, - ("x2", "drift") => 0, - ("x2", "autoregression_target") => 0, - ("x2", "autoregression_strength") => 0, - ("x2", "initial_mean") => 0, - ("x2", "initial_precision") => 1, - - ("u", "x1", "value_coupling") => 1, - ("x1", "x2", "volatility_coupling") => 1, - - "update_type" => EnhancedUpdate(), - ) - - #Warn the user about used defaults and misspecified keys - if verbose - warn_premade_defaults(spec_defaults, config) - end - - #Merge to overwrite defaults - config = merge(spec_defaults, config) - - - #List of input nodes to create - input_nodes = Dict( - "name" => "u", - "type" => "continuous", - "input_noise" => config[("u", "input_noise")], - ) - - #List of state nodes to create - state_nodes = [ - Dict( - "name" => "x1", - "type" => "continuous", - "volatility" => config[("x1", "volatility")], - "drift" => config[("x1", "drift")], - "autoregression_target" => config[("x1", "autoregression_target")], - "autoregression_strength" => config[("x1", "autoregression_strength")], - "initial_mean" => config[("x1", "initial_mean")], - "initial_precision" => config[("x1", "initial_precision")], - ), - Dict( - "name" => "x2", - "type" => "continuous", - "volatility" => config[("x2", "volatility")], - "drift" => config[("x2", "drift")], - "autoregression_target" => config[("x2", 
"autoregression_target")], - "autoregression_strength" => config[("x2", "autoregression_strength")], - "initial_mean" => config[("x2", "initial_mean")], - "initial_precision" => config[("x2", "initial_precision")], - ), - ] - - #List of child-parent relations - edges = [ - Dict( - "child" => "u", - "value_parents" => ("x1", config[("u", "x1", "value_coupling")]), - ), - Dict( - "child" => "x1", - "volatility_parents" => ("x2", config[("x1", "x2", "volatility_coupling")]), - ), - ] - - #Initialize the HGF - init_hgf( - input_nodes = input_nodes, - state_nodes = state_nodes, - edges = edges, - verbose = false, - update_type = config["update_type"], - ) -end - - -""" -premade_JGET(config::Dict; verbose::Bool = true) - -The HGF used in the JGET model. It has a single continuous input node u, with a value parent x1, and a volatility parent x3. x1 has volatility parent x2, and x3 has a volatility parent x4. - -# Config defaults: - - ("u", "input_noise"): -2 - - ("x1", "volatility"): -2 - - ("x2", "volatility"): -2 - - ("x3", "volatility"): -2 - - ("x4", "volatility"): -2 - - ("u", "x1", "value_coupling"): 1 - - ("u", "x3", "value_coupling"): 1 - - ("x1", "x2", "volatility_coupling"): 1 - - ("x3", "x4", "volatility_coupling"): 1 - - ("x1", "initial_mean"): 0 - - ("x1", "initial_precision"): 1 - - ("x2", "initial_mean"): 0 - - ("x2", "initial_precision"): 1 - - ("x3", "initial_mean"): 0 - - ("x3", "initial_precision"): 1 - - ("x4", "initial_mean"): 0 - - ("x4", "initial_precision"): 1 -""" -function premade_JGET(config::Dict; verbose::Bool = true) - - #Defaults - spec_defaults = Dict( - ("u", "input_noise") => -2, - - ("x1", "volatility") => -2, - ("x1", "drift") => 0, - ("x1", "autoregression_target") => 0, - ("x1", "autoregression_strength") => 0, - ("x1", "initial_mean") => 0, - ("x1", "initial_precision") => 1, - - ("x2", "volatility") => -2, - ("x2", "drift") => 0, - ("x2", "autoregression_target") => 0, - ("x2", "autoregression_strength") => 0, - ("x2", 
"initial_mean") => 0, - ("x2", "initial_precision") => 1, - - ("x3", "volatility") => -2, - ("x3", "drift") => 0, - ("x3", "autoregression_target") => 0, - ("x3", "autoregression_strength") => 0, - ("x3", "initial_mean") => 0, - ("x3", "initial_precision") => 1, - - ("x4", "volatility") => -2, - ("x4", "drift") => 0, - ("x4", "autoregression_target") => 0, - ("x4", "autoregression_strength") => 0, - ("x4", "initial_mean") => 0, - ("x4", "initial_precision") => 1, - - ("u", "x1", "value_coupling") => 1, - ("u", "x3", "volatility_coupling") => 1, - ("x1", "x2", "volatility_coupling") => 1, - ("x3", "x4", "volatility_coupling") => 1, - - "update_type" => EnhancedUpdate(), - ) - - #Warn the user about used defaults and misspecified keys - if verbose - warn_premade_defaults(spec_defaults, config) - end - - #Merge to overwrite defaults - config = merge(spec_defaults, config) - - - #List of input nodes to create - input_nodes = Dict( - "name" => "u", - "type" => "continuous", - "input_noise" => config[("u", "input_noise")], - ) - - #List of state nodes to create - state_nodes = [ - Dict( - "name" => "x1", - "type" => "continuous", - "volatility" => config[("x1", "volatility")], - "drift" => config[("x1", "drift")], - "autoregression_target" => config[("x1", "autoregression_target")], - "autoregression_strength" => config[("x1", "autoregression_strength")], - "initial_mean" => config[("x1", "initial_mean")], - "initial_precision" => config[("x1", "initial_precision")], - ), - Dict( - "name" => "x2", - "type" => "continuous", - "volatility" => config[("x2", "volatility")], - "drift" => config[("x2", "drift")], - "autoregression_target" => config[("x2", "autoregression_target")], - "autoregression_strength" => config[("x2", "autoregression_strength")], - "initial_mean" => config[("x2", "initial_mean")], - "initial_precision" => config[("x2", "initial_precision")], - ), - Dict( - "name" => "x3", - "type" => "continuous", - "volatility" => config[("x3", "volatility")], - 
"drift" => config[("x3", "drift")], - "autoregression_target" => config[("x3", "autoregression_target")], - "autoregression_strength" => config[("x3", "autoregression_strength")], - "initial_mean" => config[("x3", "initial_precision")], - "initial_precision" => config[("x3", "initial_precision")], - ), - Dict( - "name" => "x4", - "type" => "continuous", - "volatility" => config[("x4", "volatility")], - "drift" => config[("x4", "drift")], - "autoregression_target" => config[("x4", "autoregression_target")], - "autoregression_strength" => config[("x4", "autoregression_strength")], - "initial_mean" => config[("x4", "initial_mean")], - "initial_precision" => config[("x4", "initial_precision")], - ), - ] - - #List of child-parent relations - edges = [ - Dict( - "child" => "u", - "value_parents" => ("x1", config[("u", "x1", "value_coupling")]), - "volatility_parents" => ("x3", config[("u", "x3", "volatility_coupling")]), - ), - Dict( - "child" => "x1", - "volatility_parents" => ("x2", config[("x1", "x2", "volatility_coupling")]), - ), - Dict( - "child" => "x3", - "volatility_parents" => ("x4", config[("x3", "x4", "volatility_coupling")]), - ), - ] - - #Initialize the HGF - init_hgf( - input_nodes = input_nodes, - state_nodes = state_nodes, - edges = edges, - verbose = false, - update_type = config["update_type"], - ) -end - - -""" - premade_binary_2level(config::Dict; verbose::Bool = true) - -The standard binary 2 level HGF model, which takes a binary input, and learns the probability of either outcome. -It has one binary input node u, with a binary value parent x1, which in turn has a continuous value parent x2. 
- -# Config defaults: - - ("u", "category_means"): [0, 1] - - ("u", "input_precision"): Inf - - ("x2", "volatility"): -2 - - ("x1", "x2", "value_coupling"): 1 - - ("x2", "initial_mean"): 0 - - ("x2", "initial_precision"): 1 -""" -function premade_binary_2level(config::Dict; verbose::Bool = true) - - #Defaults - spec_defaults = Dict( - ("u", "category_means") => [0, 1], - ("u", "input_precision") => Inf, - - ("x2", "volatility") => -2, - ("x2", "drift") => 0, - ("x2", "autoregression_target") => 0, - ("x2", "autoregression_strength") => 0, - ("x2", "initial_mean") => 0, - ("x2", "initial_precision") => 1, - - ("x1", "x2", "value_coupling") => 1, - - "update_type" => EnhancedUpdate(), - ) - - #Warn the user about used defaults and misspecified keys - if verbose - warn_premade_defaults(spec_defaults, config) - end - - #Merge to overwrite defaults - config = merge(spec_defaults, config) - - - #List of input nodes to create - input_nodes = Dict( - "name" => "u", - "type" => "binary", - "category_means" => config[("u", "category_means")], - "input_precision" => config[("u", "input_precision")], - ) - - #List of state nodes to create - state_nodes = [ - Dict("name" => "x1", "type" => "binary"), - Dict( - "name" => "x2", - "type" => "continuous", - "volatility" => config[("x2", "volatility")], - "drift" => config[("x2", "drift")], - "autoregression_target" => config[("x2", "autoregression_target")], - "autoregression_strength" => config[("x2", "autoregression_strength")], - "initial_mean" => config[("x2", "initial_mean")], - "initial_precision" => config[("x2", "initial_precision")], - ), - ] - - #List of child-parent relations - edges = [ - Dict("child" => "u", "value_parents" => "x1"), - Dict( - "child" => "x1", - "value_parents" => ("x2", config[("x1", "x2", "value_coupling")]), - ), - ] - - #Initialize the HGF - init_hgf( - input_nodes = input_nodes, - state_nodes = state_nodes, - edges = edges, - verbose = false, - update_type = config["update_type"], - ) -end - - 
-""" - premade_binary_3level(config::Dict; verbose::Bool = true) - -The standard binary 3 level HGF model, which takes a binary input, and learns the probability of either outcome. -It has one binary input node u, with a binary value parent x1, which in turn has a continuous value parent x2. This then has a continunous volatility parent x3. - -This HGF has five shared parameters: -"x2_volatility" -"x2_initial_precisions" -"x2_initial_means" -"value_couplings_x1_x2" -"volatility_couplings_x2_x3" - -# Config defaults: - - ("u", "category_means"): [0, 1] - - ("u", "input_precision"): Inf - - ("x2", "volatility"): -2 - - ("x3", "volatility"): -2 - - ("x1", "x2", "value_coupling"): 1 - - ("x2", "x3", "volatility_coupling"): 1 - - ("x2", "initial_mean"): 0 - - ("x2", "initial_precision"): 1 - - ("x3", "initial_mean"): 0 - - ("x3", "initial_precision"): 1 -""" -function premade_binary_3level(config::Dict; verbose::Bool = true) - - #Defaults - defaults = Dict( - ("u", "category_means") => [0, 1], - ("u", "input_precision") => Inf, - - ("x2", "volatility") => -2, - ("x2", "drift") => 0, - ("x2", "autoregression_target") => 0, - ("x2", "autoregression_strength") => 0, - ("x2", "initial_mean") => 0, - ("x2", "initial_precision") => 1, - - ("x3", "volatility") => -2, - ("x3", "drift") => 0, - ("x3", "autoregression_target") => 0, - ("x3", "autoregression_strength") => 0, - ("x3", "initial_mean") => 0, - ("x3", "initial_precision") => 1, - - ("x1", "x2", "value_coupling") => 1, - ("x2", "x3", "volatility_coupling") => 1, - - "update_type" => EnhancedUpdate(), - ) - - #Warn the user about used defaults and misspecified keys - if verbose - warn_premade_defaults(defaults, config) - end - - #Merge to overwrite defaults - config = merge(defaults, config) - - - #List of input nodes to create - input_nodes = Dict( - "name" => "u", - "type" => "binary", - "category_means" => config[("u", "category_means")], - "input_precision" => config[("u", "input_precision")], - ) - - #List of state 
nodes to create - state_nodes = [ - Dict("name" => "x1", "type" => "binary"), - Dict( - "name" => "x2", - "type" => "continuous", - "volatility" => config[("x2", "volatility")], - "drift" => config[("x2", "drift")], - "autoregression_target" => config[("x2", "autoregression_target")], - "autoregression_strength" => config[("x2", "autoregression_strength")], - "initial_mean" => config[("x2", "initial_mean")], - "initial_precision" => config[("x2", "initial_precision")], - ), - Dict( - "name" => "x3", - "type" => "continuous", - "volatility" => config[("x3", "volatility")], - "drift" => config[("x3", "drift")], - "autoregression_target" => config[("x3", "autoregression_target")], - "autoregression_strength" => config[("x3", "autoregression_strength")], - "initial_mean" => config[("x3", "initial_mean")], - "initial_precision" => config[("x3", "initial_precision")], - ), - ] - - #List of child-parent relations - edges = [ - Dict("child" => "u", "value_parents" => "x1"), - Dict( - "child" => "x1", - "value_parents" => ("x2", config[("x1", "x2", "value_coupling")]), - ), - Dict( - "child" => "x2", - "volatility_parents" => ("x3", config[("x2", "x3", "volatility_coupling")]), - ), - ] - - #Initialize the HGF - init_hgf( - input_nodes = input_nodes, - state_nodes = state_nodes, - edges = edges, - verbose = false, - update_type = config["update_type"], - ) -end - -""" - premade_categorical_3level(config::Dict; verbose::Bool = true) - -The categorical 3 level HGF model, which takes an input from one of n categories and learns the probability of a category appearing. -It has one categorical input node u, with a categorical value parent x1. -The categorical node has a binary value parent x1_n for each category n, each of which has a continuous value parent x2_n. -Finally, all of these continuous nodes share a continuous volatility parent x3. -Setting parameter values for x1 and x2 sets that parameter value for each of the x1_n and x2_n nodes. 
- -# Config defaults: - - "n_categories": 4 - - ("x2", "volatility"): -2 - - ("x3", "volatility"): -2 - - ("x1", "x2", "value_coupling"): 1 - - ("x2", "x3", "volatility_coupling"): 1 - - ("x2", "initial_mean"): 0 - - ("x2", "initial_precision"): 1 - - ("x3", "initial_mean"): 0 - - ("x3", "initial_precision"): 1 -""" -function premade_categorical_3level(config::Dict; verbose::Bool = true) - - #Defaults - defaults = Dict( - "n_categories" => 4, - - ("x2", "volatility") => -2, - ("x2", "drift") => 0, - ("x2", "autoregression_target") => 0, - ("x2", "autoregression_strength") => 0, - ("x2", "initial_mean") => 0, - ("x2", "initial_precision") => 1, - - ("x3", "volatility") => -2, - ("x3", "drift") => 0, - ("x3", "autoregression_target") => 0, - ("x3", "autoregression_strength") => 0, - ("x3", "initial_mean") => 0, - ("x3", "initial_precision") => 1, - - ("x1", "x2", "value_coupling") => 1, - ("x2", "x3", "volatility_coupling") => 1, - - "update_type" => EnhancedUpdate(), - ) - - #Warn the user about used defaults and misspecified keys - if verbose - warn_premade_defaults(defaults, config) - end - - #Merge to overwrite defaults - config = merge(defaults, config) - - - ##Prep category node parent names - #Vector for category node binary parent names - category_binary_parent_names = Vector{String}() - #Vector for binary node continuous parent names - binary_continuous_parent_names = Vector{String}() - - #Empty lists for derived parameters - derived_parameters_x2_initial_precision = [] - derived_parameters_x2_initial_mean = [] - derived_parameters_x2_volatility = [] - derived_parameters_x2_drift = [] - derived_parameters_x2_autoregression_target = [] - derived_parameters_x2_autoregression_strength = [] - derived_parameters_x2_x3_volatility_coupling = [] - derived_parameters_value_coupling_x1_x2 = [] - - #Populate the category node vectors with node names - for category_number = 1:config["n_categories"] - push!(category_binary_parent_names, "x1_" * string(category_number)) - 
push!(binary_continuous_parent_names, "x2_" * string(category_number)) - end - - ##List of input nodes - input_nodes = Dict("name" => "u", "type" => "categorical") - - ##List of state nodes - state_nodes = [Dict{String,Any}("name" => "x1", "type" => "categorical")] - - #Add category node binary parents - for node_name in category_binary_parent_names - push!(state_nodes, Dict("name" => node_name, "type" => "binary")) - end - - #Add binary node continuous parents - for node_name in binary_continuous_parent_names - push!( - state_nodes, - Dict( - "name" => node_name, - "type" => "continuous", - "initial_mean" => config[("x2", "initial_mean")], - "initial_precision" => config[("x2", "initial_precision")], - "volatility" => config[("x2", "volatility")], - "drift" => config[("x2", "drift")], - "autoregression_target" => config[("x2", "autoregression_target")], - "autoregression_strength" => config[("x2", "autoregression_strength")], - ), - ) - #Add the derived parameter name to derived parameters vector - push!(derived_parameters_x2_initial_precision, (node_name, "initial_precision")) - push!(derived_parameters_x2_initial_mean, (node_name, "initial_mean")) - push!(derived_parameters_x2_volatility, (node_name, "volatility")) - push!(derived_parameters_x2_drift, (node_name, "drift")) - push!(derived_parameters_x2_autoregression_strength, (node_name, "autoregression_strength")) - push!(derived_parameters_x2_autoregression_target, (node_name, "autoregression_target")) - end - - #Add volatility parent - push!( - state_nodes, - Dict( - "name" => "x3", - "type" => "continuous", - "volatility" => config[("x3", "volatility")], - "drift" => config[("x3", "drift")], - "autoregression_target" => config[("x3", "autoregression_target")], - "autoregression_strength" => config[("x3", "autoregression_strength")], - "initial_mean" => config[("x3", "initial_mean")], - "initial_precision" => config[("x3", "initial_precision")], - ), - ) - - - ##List of child-parent relations - edges = [ - 
Dict("child" => "u", "value_parents" => "x1"), - Dict("child" => "x1", "value_parents" => category_binary_parent_names), - ] - - #Add relations between binary nodes and their parents - for (child_name, parent_name) in - zip(category_binary_parent_names, binary_continuous_parent_names) - push!( - edges, - Dict( - "child" => child_name, - "value_parents" => (parent_name, config[("x1", "x2", "value_coupling")]), - ), - ) - #Add the derived parameter name to derived parameters vector - push!( - derived_parameters_value_coupling_x1_x2, - (child_name, parent_name, "value_coupling"), - ) - end - - #Add relations between binary node parents and the volatility parent - for child_name in binary_continuous_parent_names - push!( - edges, - Dict( - "child" => child_name, - "volatility_parents" => ("x3", config[("x2", "x3", "volatility_coupling")]), - ), - ) - #Add the derived parameter name to derived parameters vector - push!( - derived_parameters_x2_x3_volatility_coupling, - (child_name, "x3", "volatility_coupling"), - ) - end - - #Create dictionary with shared parameter information - shared_parameters = Dict() - - shared_parameters["x2_volatility"] = - (config[("x2", "volatility")], derived_parameters_x2_volatility) - - shared_parameters["x2_initial_precisions"] = - (config[("x2", "initial_precision")], derived_parameters_x2_initial_precision) - - shared_parameters["x2_initial_means"] = - (config[("x2", "initial_mean")], derived_parameters_x2_initial_mean) - - shared_parameters["x2_drifts"] = - (config[("x2", "drift")], derived_parameters_x2_drift) - - shared_parameters["x2_autoregression_strengths"] = - (config[("x2", "autoregression_strength")], derived_parameters_x2_autoregression_strength) - - shared_parameters["x2_autoregression_targets"] = - (config[("x2", "autoregression_target")], derived_parameters_x2_autoregression_target) - - shared_parameters["value_couplings_x1_x2"] = - (config[("x1", "x2", "value_coupling")], derived_parameters_value_coupling_x1_x2) - - 
shared_parameters["volatility_couplings_x2_x3"] = ( - config[("x2", "x3", "volatility_coupling")], - derived_parameters_x2_x3_volatility_coupling, - ) - - #Initialize the HGF - init_hgf( - input_nodes = input_nodes, - state_nodes = state_nodes, - edges = edges, - shared_parameters = shared_parameters, - verbose = false, - update_type = config["update_type"], - ) -end - -""" - premade_categorical_3level_state_transitions(config::Dict; verbose::Bool = true) - -The categorical state transition 3 level HGF model, learns state transition probabilities between a set of n categorical states. -It has one categorical input node u, with a categorical value parent x1_n for each of the n categories, representing which category was transitioned from. -Each categorical node then has a binary parent x1_n_m, representing the category m which the transition was towards. -Each binary node x1_n_m has a continuous parent x2_n_m. -Finally, all of these continuous nodes share a continuous volatility parent x3. -Setting parameter values for x1 and x2 sets that parameter value for each of the x1_n_m and x2_n_m nodes. 
- -This HGF has five shared parameters: -"x2_volatility" -"x2_initial_precisions" -"x2_initial_means" -"value_couplings_x1_x2" -"volatility_couplings_x2_x3" - -# Config defaults: - - "n_categories": 4 - - ("x2", "volatility"): -2 - - ("x3", "volatility"): -2 - - ("x1", "x2", "volatility_coupling"): 1 - - ("x2", "x3", "volatility_coupling"): 1 - - ("x2", "initial_mean"): 0 - - ("x2", "initial_precision"): 1 - - ("x3", "initial_mean"): 0 - - ("x3", "initial_precision"): 1 -""" -function premade_categorical_3level_state_transitions(config::Dict; verbose::Bool = true) - - #Defaults - defaults = Dict( - "n_categories" => 4, - - ("x2", "volatility") => -2, - ("x2", "drift") => 0, - ("x2", "autoregression_target") => 0, - ("x2", "autoregression_strength") => 0, - ("x2", "initial_mean") => 0, - ("x2", "initial_precision") => 1, - - ("x3", "volatility") => -2, - ("x3", "drift") => 0, - ("x3", "autoregression_target") => 0, - ("x3", "autoregression_strength") => 0, - ("x3", "initial_mean") => 0, - ("x3", "initial_precision") => 1, - - ("x1", "x2", "value_coupling") => 1, - ("x2", "x3", "volatility_coupling") => 1, - - "update_type" => EnhancedUpdate(), - ) - - #Warn the user about used defaults and misspecified keys - if verbose - warn_premade_defaults(defaults, config) - end - - #Merge to overwrite defaults - config = merge(defaults, config) - - - ##Prepare node names - #Empty lists for node names - categorical_input_node_names = Vector{String}() - categorical_state_node_names = Vector{String}() - categorical_node_binary_parent_names = Vector{String}() - binary_node_continuous_parent_names = Vector{String}() - - #Empty lists for derived parameters - derived_parameters_x2_initial_precision = [] - derived_parameters_x2_initial_mean = [] - derived_parameters_x2_volatility = [] - derived_parameters_x2_drift = [] - derived_parameters_x2_autoregression_target = [] - derived_parameters_x2_autoregression_strength = [] - derived_parameters_value_coupling_x1_x2 = [] - 
derived_parameters_x2_x3_volatility_coupling = [] - - #Go through each category that the transition may have been from - for category_from = 1:config["n_categories"] - #One input node and its state node parent for each - push!(categorical_input_node_names, "u" * string(category_from)) - push!(categorical_state_node_names, "x1_" * string(category_from)) - #Go through each category that the transition may have been to - for category_to = 1:config["n_categories"] - #Each categorical state node has a binary parent for each - push!( - categorical_node_binary_parent_names, - "x1_" * string(category_from) * "_" * string(category_to), - ) - #And each binary parent has a continuous parent of its own - push!( - binary_node_continuous_parent_names, - "x2_" * string(category_from) * "_" * string(category_to), - ) - end - end - - ##Create input nodes - #Initialize list - input_nodes = Vector{Dict}() - - #For each categorical input node - for node_name in categorical_input_node_names - #Add it to the list - push!(input_nodes, Dict("name" => node_name, "type" => "categorical")) - end - - ##Create state nodes - #Initialize list - state_nodes = Vector{Dict}() - - #For each cateogrical state node - for node_name in categorical_state_node_names - #Add it to the list - push!(state_nodes, Dict("name" => node_name, "type" => "categorical")) - end - - #For each categorical node binary parent - for node_name in categorical_node_binary_parent_names - #Add it to the list - push!(state_nodes, Dict("name" => node_name, "type" => "binary")) - end - - #For each binary node continuous parent - for node_name in binary_node_continuous_parent_names - #Add it to the list, with parameter settings from the config - push!( - state_nodes, - Dict( - "name" => node_name, - "type" => "continuous", - "initial_mean" => config[("x2", "initial_mean")], - "initial_precision" => config[("x2", "initial_precision")], - "volatility" => config[("x2", "volatility")], - "drift" => config[("x2", "drift")], - 
"autoregression_target" => config[("x2", "autoregression_target")], - "autoregression_strength" => config[("x2", "autoregression_strength")], - ), - ) - #Add the derived parameter name to derived parameters vector - push!(derived_parameters_x2_initial_precision, (node_name, "initial_precision")) - push!(derived_parameters_x2_initial_mean, (node_name, "initial_mean")) - push!(derived_parameters_x2_volatility, (node_name, "volatility")) - push!(derived_parameters_x2_drift, (node_name, "drift")) - push!(derived_parameters_x2_autoregression_strength, (node_name, "autoregression_strength")) - push!(derived_parameters_x2_autoregression_target, (node_name, "autoregression_target")) - end - - - #Add the shared volatility parent of the continuous nodes - push!( - state_nodes, - Dict( - "name" => "x3", - "type" => "continuous", - "volatility" => config[("x3", "volatility")], - "drift" => config[("x3", "drift")], - "autoregression_target" => config[("x3", "autoregression_target")], - "autoregression_strength" => config[("x3", "autoregression_strength")], - "initial_mean" => config[("x3", "initial_mean")], - "initial_precision" => config[("x3", "initial_precision")], - ), - ) - - ##Create child-parent relations - #Initialize list - edges = Vector{Dict}() - - #For each categorical input node and its corresponding state node parent - for (child_name, parent_name) in - zip(categorical_input_node_names, categorical_state_node_names) - #Add their relation to the list - push!(edges, Dict("child" => child_name, "value_parents" => parent_name)) - end - - #For each categorical state node - for child_node_name in categorical_state_node_names - #Get the category it represents transitions from - (child_supername, child_category_from) = split(child_node_name, "_") - - #For each potential parent node - for parent_node_name in categorical_node_binary_parent_names - #Get the category it represents transitions from - (parent_supername, parent_category_from, parent_category_to) = - 
split(parent_node_name, "_") - - #If these match - if parent_category_from == child_category_from - #Add the parent as parent of the child - push!( - edges, - Dict("child" => child_node_name, "value_parents" => parent_node_name), - ) - end - end - end - - #For each binary parent of categorical nodes and their corresponding continuous parents - for (child_name, parent_name) in - zip(categorical_node_binary_parent_names, binary_node_continuous_parent_names) - #Add their relations to the list, with the same value coupling - push!( - edges, - Dict( - "child" => child_name, - "value_parents" => (parent_name, config[("x1", "x2", "value_coupling")]), - ), - ) - #Add the derived parameter name to derived parameters vector - push!( - derived_parameters_value_coupling_x1_x2, - (child_name, parent_name, "value_coupling"), - ) - end - - - #Add the shared continuous node volatility parent to the continuous nodes - for child_name in binary_node_continuous_parent_names - push!( - edges, - Dict( - "child" => child_name, - "volatility_parents" => ("x3", config[("x2", "x3", "volatility_coupling")]), - ), - ) - #Add the derived parameter name to derived parameters vector - push!( - derived_parameters_x2_x3_volatility_coupling, - (child_name, "x3", "volatility_coupling"), - ) - - end - - #Create dictionary with shared parameter information - - shared_parameters = Dict() - - shared_parameters["x2_volatility"] = - (config[("x2", "volatility")], derived_parameters_x2_volatility) - - shared_parameters["x2_initial_precisions"] = - (config[("x2", "initial_precision")], derived_parameters_x2_initial_precision) - - shared_parameters["x2_initial_means"] = - (config[("x2", "initial_mean")], derived_parameters_x2_initial_mean) - - shared_parameters["x2_drifts"] = - (config[("x2", "drift")], derived_parameters_x2_drift) - - shared_parameters["x2_autoregression_strengths"] = - (config[("x2", "autoregression_strength")], derived_parameters_x2_autoregression_strength) - - 
shared_parameters["x2_autoregression_targets"] = - (config[("x2", "autoregression_target")], derived_parameters_x2_autoregression_target) - - shared_parameters["value_couplings_x1_x2"] = - (config[("x1", "x2", "value_coupling")], derived_parameters_value_coupling_x1_x2) - - shared_parameters["volatility_couplings_x2_x3"] = ( - config[("x2", "x3", "volatility_coupling")], - derived_parameters_x2_x3_volatility_coupling, - ) - - #Initialize the HGF - init_hgf( - input_nodes = input_nodes, - state_nodes = state_nodes, - edges = edges, - shared_parameters = shared_parameters, - verbose = false, - update_type = config["update_type"], - ) -end diff --git a/src/premade_models/premade_hgfs/premade_JGET.jl b/src/premade_models/premade_hgfs/premade_JGET.jl new file mode 100644 index 0000000..beb170b --- /dev/null +++ b/src/premade_models/premade_hgfs/premade_JGET.jl @@ -0,0 +1,124 @@ +""" +premade_JGET(config::Dict; verbose::Bool = true) + +The HGF used in the JGET model. It has a single continuous input node u, with a value parent x, and a volatility parent xnoise. x has volatility parent xvol, and xnoise has a volatility parent xnoise_vol. 
+ +# Config defaults: + - ("u", "input_noise"): -2 + - ("x", "volatility"): -2 + - ("xvol", "volatility"): -2 + - ("xnoise", "volatility"): -2 + - ("xnoise_vol", "volatility"): -2 + - ("u", "x", "coupling_strength"): 1 + - ("u", "xnoise", "coupling_strength"): 1 + - ("x", "xvol", "coupling_strength"): 1 + - ("xnoise", "xnoise_vol", "coupling_strength"): 1 + - ("x", "initial_mean"): 0 + - ("x", "initial_precision"): 1 + - ("xvol", "initial_mean"): 0 + - ("xvol", "initial_precision"): 1 + - ("xnoise", "initial_mean"): 0 + - ("xnoise", "initial_precision"): 1 + - ("xnoise_vol", "initial_mean"): 0 + - ("xnoise_vol", "initial_precision"): 1 +""" +function premade_JGET(config::Dict; verbose::Bool = true) + + #Defaults + spec_defaults = Dict( + ("u", "input_noise") => -2, + ("u", "bias") => 0, + ("x", "volatility") => -2, + ("x", "drift") => 0, + ("x", "autoconnection_strength") => 1, + ("x", "initial_mean") => 0, + ("x", "initial_precision") => 1, + ("xvol", "volatility") => -2, + ("xvol", "drift") => 0, + ("xvol", "autoconnection_strength") => 1, + ("xvol", "initial_mean") => 0, + ("xvol", "initial_precision") => 1, + ("xnoise", "volatility") => -2, + ("xnoise", "drift") => 0, + ("xnoise", "autoconnection_strength") => 1, + ("xnoise", "initial_mean") => 0, + ("xnoise", "initial_precision") => 1, + ("xnoise_vol", "volatility") => -2, + ("xnoise_vol", "drift") => 0, + ("xnoise_vol", "autoconnection_strength") => 1, + ("xnoise_vol", "initial_mean") => 0, + ("xnoise_vol", "initial_precision") => 1, + ("u", "xnoise", "coupling_strength") => 1, + ("x", "xvol", "coupling_strength") => 1, + ("xnoise", "xnoise_vol", "coupling_strength") => 1, + "update_type" => EnhancedUpdate(), + "save_history" => true, + ) + + #Warn the user about used defaults and misspecified keys + if verbose + warn_premade_defaults(spec_defaults, config) + end + + #Merge to overwrite defaults + config = merge(spec_defaults, config) + + #List of nodes + nodes = [ + ContinuousInput( + name = "u", + 
input_noise = config[("u", "input_noise")], + bias = config[("u", "bias")], + ), + ContinuousState( + name = "x", + volatility = config[("x", "volatility")], + drift = config[("x", "drift")], + autoconnection_strength = config[("x", "autoconnection_strength")], + initial_mean = config[("x", "initial_mean")], + initial_precision = config[("x", "initial_precision")], + ), + ContinuousState( + name = "xvol", + volatility = config[("xvol", "volatility")], + drift = config[("xvol", "drift")], + autoconnection_strength = config[("xvol", "autoconnection_strength")], + initial_mean = config[("xvol", "initial_mean")], + initial_precision = config[("xvol", "initial_precision")], + ), + ContinuousState( + name = "xnoise", + volatility = config[("xnoise", "volatility")], + drift = config[("xnoise", "drift")], + autoconnection_strength = config[("xnoise", "autoconnection_strength")], + initial_mean = config[("xnoise", "initial_mean")], + initial_precision = config[("xnoise", "initial_precision")], + ), + ContinuousState( + name = "xnoise_vol", + volatility = config[("xnoise_vol", "volatility")], + drift = config[("xnoise_vol", "drift")], + autoconnection_strength = config[("xnoise_vol", "autoconnection_strength")], + initial_mean = config[("xnoise_vol", "initial_mean")], + initial_precision = config[("xnoise_vol", "initial_precision")], + ), + ] + + #List of child-parent relations + edges = Dict( + ("u", "x") => ObservationCoupling(), + ("u", "xnoise") => NoiseCoupling(config[("u", "xnoise", "coupling_strength")]), + ("x", "xvol") => VolatilityCoupling(config[("x", "xvol", "coupling_strength")]), + ("xnoise", "xnoise_vol") => + VolatilityCoupling(config[("xnoise", "xnoise_vol", "coupling_strength")]), + ) + + #Initialize the HGF + init_hgf( + nodes = nodes, + edges = edges, + verbose = false, + node_defaults = NodeDefaults(update_type = config["update_type"]), + save_history = config["save_history"], + ) +end diff --git a/src/premade_models/premade_hgfs/premade_binary_2level.jl 
b/src/premade_models/premade_hgfs/premade_binary_2level.jl new file mode 100644 index 0000000..49750d9 --- /dev/null +++ b/src/premade_models/premade_hgfs/premade_binary_2level.jl @@ -0,0 +1,63 @@ + +""" + premade_binary_2level(config::Dict; verbose::Bool = true) + +The standard binary 2 level HGF model, which takes a binary input, and learns the probability of either outcome. +It has one binary input node u, with a binary value parent xbin, which in turn has a continuous value parent xprob. + +# Config defaults: +""" +function premade_binary_2level(config::Dict; verbose::Bool = true) + + #Defaults + spec_defaults = Dict( + ("u", "category_means") => [0, 1], + ("u", "input_precision") => Inf, + ("xprob", "volatility") => -2, + ("xprob", "drift") => 0, + ("xprob", "autoconnection_strength") => 1, + ("xprob", "initial_mean") => 0, + ("xprob", "initial_precision") => 1, + ("xbin", "xprob", "coupling_strength") => 1, + "update_type" => EnhancedUpdate(), + "save_history" => true, + ) + + #Warn the user about used defaults and misspecified keys + if verbose + warn_premade_defaults(spec_defaults, config) + end + + #Merge to overwrite defaults + config = merge(spec_defaults, config) + + #List of nodes + nodes = [ + BinaryInput("u"), + BinaryState("xbin"), + ContinuousState( + name = "xprob", + volatility = config[("xprob", "volatility")], + drift = config[("xprob", "drift")], + autoconnection_strength = config[("xprob", "autoconnection_strength")], + initial_mean = config[("xprob", "initial_mean")], + initial_precision = config[("xprob", "initial_precision")], + ), + ] + + #List of edges + edges = Dict( + ("u", "xbin") => ObservationCoupling(), + ("xbin", "xprob") => + ProbabilityCoupling(config[("xbin", "xprob", "coupling_strength")]), + ) + + #Initialize the HGF + init_hgf( + nodes = nodes, + edges = edges, + verbose = false, + node_defaults = NodeDefaults(update_type = config["update_type"]), + save_history = config["save_history"], + ) +end diff --git 
a/src/premade_models/premade_hgfs/premade_binary_3level.jl b/src/premade_models/premade_hgfs/premade_binary_3level.jl new file mode 100644 index 0000000..c04f733 --- /dev/null +++ b/src/premade_models/premade_hgfs/premade_binary_3level.jl @@ -0,0 +1,96 @@ + +""" + premade_binary_3level(config::Dict; verbose::Bool = true) + +The standard binary 3 level HGF model, which takes a binary input, and learns the probability of either outcome. +It has one binary input node u, with a binary value parent xbin, which in turn has a continuous value parent xprob. This then has a continuous volatility parent xvol. + +This HGF has five shared parameters: +"xprob_volatility" +"xprob_initial_precisions" +"xprob_initial_means" +"coupling_strengths_xbin_xprob" +"coupling_strengths_xprob_xvol" + +# Config defaults: + - ("u", "category_means"): [0, 1] + - ("u", "input_precision"): Inf + - ("xprob", "volatility"): -2 + - ("xvol", "volatility"): -2 + - ("xbin", "xprob", "coupling_strength"): 1 + - ("xprob", "xvol", "coupling_strength"): 1 + - ("xprob", "initial_mean"): 0 + - ("xprob", "initial_precision"): 1 + - ("xvol", "initial_mean"): 0 + - ("xvol", "initial_precision"): 1 +""" +function premade_binary_3level(config::Dict; verbose::Bool = true) + + #Defaults + defaults = Dict( + ("u", "category_means") => [0, 1], + ("u", "input_precision") => Inf, + ("xprob", "volatility") => -2, + ("xprob", "drift") => 0, + ("xprob", "autoconnection_strength") => 1, + ("xprob", "initial_mean") => 0, + ("xprob", "initial_precision") => 1, + ("xvol", "volatility") => -2, + ("xvol", "drift") => 0, + ("xvol", "autoconnection_strength") => 1, + ("xvol", "initial_mean") => 0, + ("xvol", "initial_precision") => 1, + ("xbin", "xprob", "coupling_strength") => 1, + ("xprob", "xvol", "coupling_strength") => 1, + "update_type" => EnhancedUpdate(), + "save_history" => true, + ) + + #Warn the user about used defaults and misspecified keys + if verbose + warn_premade_defaults(defaults, config) + end + + #Merge to 
overwrite defaults + config = merge(defaults, config) + + #List of nodes + nodes = [ + BinaryInput("u"), + BinaryState("xbin"), + ContinuousState( + name = "xprob", + volatility = config[("xprob", "volatility")], + drift = config[("xprob", "drift")], + autoconnection_strength = config[("xprob", "autoconnection_strength")], + initial_mean = config[("xprob", "initial_mean")], + initial_precision = config[("xprob", "initial_precision")], + ), + ContinuousState( + name = "xvol", + volatility = config[("xvol", "volatility")], + drift = config[("xvol", "drift")], + autoconnection_strength = config[("xvol", "autoconnection_strength")], + initial_mean = config[("xvol", "initial_mean")], + initial_precision = config[("xvol", "initial_precision")], + ), + ] + + #List of child-parent relations + edges = Dict( + ("u", "xbin") => ObservationCoupling(), + ("xbin", "xprob") => + ProbabilityCoupling(config[("xbin", "xprob", "coupling_strength")]), + ("xprob", "xvol") => + VolatilityCoupling(config[("xprob", "xvol", "coupling_strength")]), + ) + + #Initialize the HGF + init_hgf( + nodes = nodes, + edges = edges, + verbose = false, + node_defaults = NodeDefaults(update_type = config["update_type"]), + save_history = config["save_history"], + ) +end diff --git a/src/premade_models/premade_hgfs/premade_categorical_3level.jl b/src/premade_models/premade_hgfs/premade_categorical_3level.jl new file mode 100644 index 0000000..d040a3c --- /dev/null +++ b/src/premade_models/premade_hgfs/premade_categorical_3level.jl @@ -0,0 +1,197 @@ + + +""" + premade_categorical_3level(config::Dict; verbose::Bool = true) + +The categorical 3 level HGF model, which takes an input from one of n categories and learns the probability of a category appearing. +It has one categorical input node u, with a categorical value parent xcat. +The categorical node has a binary value parent xbin_n for each category n, each of which has a continuous value parent xprob_n. 
+Finally, all of these continuous nodes share a continuous volatility parent xvol. +Setting parameter values for xbin and xprob sets that parameter value for each of the xbin_n and xprob_n nodes. + +# Config defaults: + - "n_categories": 4 + - ("xprob", "volatility"): -2 + - ("xvol", "volatility"): -2 + - ("xbin", "xprob", "coupling_strength"): 1 + - ("xprob", "xvol", "coupling_strength"): 1 + - ("xprob", "initial_mean"): 0 + - ("xprob", "initial_precision"): 1 + - ("xvol", "initial_mean"): 0 + - ("xvol", "initial_precision"): 1 +""" +function premade_categorical_3level(config::Dict; verbose::Bool = true) + + #Defaults + defaults = Dict( + "n_categories" => 4, + ("xprob", "volatility") => -2, + ("xprob", "drift") => 0, + ("xprob", "autoconnection_strength") => 1, + ("xprob", "initial_mean") => 0, + ("xprob", "initial_precision") => 1, + ("xvol", "volatility") => -2, + ("xvol", "drift") => 0, + ("xvol", "autoconnection_strength") => 1, + ("xvol", "initial_mean") => 0, + ("xvol", "initial_precision") => 1, + ("xbin", "xprob", "coupling_strength") => 1, + ("xprob", "xvol", "coupling_strength") => 1, + "update_type" => EnhancedUpdate(), + "save_history" => true, + ) + + #Warn the user about used defaults and misspecified keys + if verbose + warn_premade_defaults(defaults, config) + end + + #Merge to overwrite defaults + config = merge(defaults, config) + + + ##Prep category node parent names + #Vector for category node binary parent names + category_parent_names = Vector{String}() + #Vector for binary node continuous parent names + probability_parent_names = Vector{String}() + + #Empty lists for grouped parameters + grouped_parameters_xprob_initial_precision = [] + grouped_parameters_xprob_initial_mean = [] + grouped_parameters_xprob_volatility = [] + grouped_parameters_xprob_drift = [] + grouped_parameters_xprob_autoconnection_strength = [] + grouped_parameters_xbin_xprob_coupling_strength = [] + grouped_parameters_xprob_xvol_coupling_strength = [] + + #Populate the 
category node vectors with node names + for category_number = 1:config["n_categories"] + push!(category_parent_names, "xbin_" * string(category_number)) + push!(probability_parent_names, "xprob_" * string(category_number)) + end + + #Initialize list of nodes + nodes = [CategoricalInput("u"), CategoricalState("xcat")] + + #Add category node binary parents + for node_name in category_parent_names + push!(nodes, BinaryState(node_name)) + end + + #Add binary node continuous parents + for node_name in probability_parent_names + push!( + nodes, + ContinuousState( + name = node_name, + volatility = config[("xprob", "volatility")], + drift = config[("xprob", "drift")], + autoconnection_strength = config[("xprob", "autoconnection_strength")], + initial_mean = config[("xprob", "initial_mean")], + initial_precision = config[("xprob", "initial_precision")], + ), + ) + #Add the grouped parameter name to grouped parameters vector + push!(grouped_parameters_xprob_initial_precision, (node_name, "initial_precision")) + push!(grouped_parameters_xprob_initial_mean, (node_name, "initial_mean")) + push!(grouped_parameters_xprob_volatility, (node_name, "volatility")) + push!(grouped_parameters_xprob_drift, (node_name, "drift")) + push!( + grouped_parameters_xprob_autoconnection_strength, + (node_name, "autoconnection_strength"), + ) + end + + #Add volatility parent + push!( + nodes, + ContinuousState( + name = "xvol", + volatility = config[("xvol", "volatility")], + drift = config[("xvol", "drift")], + autoconnection_strength = config[("xvol", "autoconnection_strength")], + initial_mean = config[("xvol", "initial_mean")], + initial_precision = config[("xvol", "initial_precision")], + ), + ) + + ##List of edges + #Set the input node coupling + edges = Dict{Tuple{String,String},CouplingType}(("u", "xcat") => ObservationCoupling()) + + #For each set of category parents and their probability parents + for (category_parent_name, probability_parent_name) in + zip(category_parent_names, 
probability_parent_names) + + #Connect the binary category parents to the categorical state node + edges[("xcat", category_parent_name)] = CategoryCoupling() + + #Connect each category parent to its probability parent + edges[(category_parent_name, probability_parent_name)] = + ProbabilityCoupling(config[("xbin", "xprob", "coupling_strength")]) + + #Connect the probability parents to the shared volatility parent + edges[(probability_parent_name, "xvol")] = + VolatilityCoupling(config[("xprob", "xvol", "coupling_strength")]) + + #Add the coupling strengths to the lists of grouped parameters + push!( + grouped_parameters_xbin_xprob_coupling_strength, + (category_parent_name, probability_parent_name, "coupling_strength"), + ) + push!( + grouped_parameters_xprob_xvol_coupling_strength, + (probability_parent_name, "xvol", "coupling_strength"), + ) + end + + #Create dictionary with shared parameter information + parameter_groups = [ + ParameterGroup( + "xprob_volatility", + grouped_parameters_xprob_volatility, + config[("xprob", "volatility")], + ), + ParameterGroup( + "xprob_initial_precision", + grouped_parameters_xprob_initial_precision, + config[("xprob", "initial_precision")], + ), + ParameterGroup( + "xprob_initial_mean", + grouped_parameters_xprob_initial_mean, + config[("xprob", "initial_mean")], + ), + ParameterGroup( + "xprob_drift", + grouped_parameters_xprob_drift, + config[("xprob", "drift")], + ), + ParameterGroup( + "xprob_autoconnection_strength", + grouped_parameters_xprob_autoconnection_strength, + config[("xprob", "autoconnection_strength")], + ), + ParameterGroup( + "xbin_xprob_coupling_strength", + grouped_parameters_xbin_xprob_coupling_strength, + config[("xbin", "xprob", "coupling_strength")], + ), + ParameterGroup( + "xprob_xvol_coupling_strength", + grouped_parameters_xprob_xvol_coupling_strength, + config[("xprob", "xvol", "coupling_strength")], + ), + ] + + #Initialize the HGF + init_hgf( + nodes = nodes, + edges = edges, + parameter_groups = 
parameter_groups, + verbose = false, + node_defaults = NodeDefaults(update_type = config["update_type"]), + save_history = config["save_history"], + ) +end diff --git a/src/premade_models/premade_hgfs/premade_categorical_transitions_3level.jl b/src/premade_models/premade_hgfs/premade_categorical_transitions_3level.jl new file mode 100644 index 0000000..8a77871 --- /dev/null +++ b/src/premade_models/premade_hgfs/premade_categorical_transitions_3level.jl @@ -0,0 +1,261 @@ + +""" + premade_categorical_3level_state_transitions(config::Dict; verbose::Bool = true) + +The categorical state transition 3 level HGF model, learns state transition probabilities between a set of n categorical states. +It has one categorical input node u, with a categorical value parent xcat_n for each of the n categories, representing which category was transitioned from. +Each categorical node then has a binary parent xbin_n_m, representing the category m which the transition was towards. +Each binary node xbin_n_m has a continuous parent xprob_n_m. +Finally, all of these continuous nodes share a continuous volatility parent xvol. +Setting parameter values for xbin and xprob sets that parameter value for each of the xbin_n_m and xprob_n_m nodes. 
+
+This HGF has the following shared parameter groups:
+"xprob_volatility"
+"xprob_initial_precision"
+"xprob_initial_mean"
+"xbin_xprob_coupling_strength"
+"xprob_xvol_coupling_strength"
+
+# Config defaults:
+ - "n_categories": 4
+ - ("xprob", "volatility"): -2
+ - ("xvol", "volatility"): -2
+ - ("xbin", "xprob", "coupling_strength"): 1
+ - ("xprob", "xvol", "coupling_strength"): 1
+ - ("xprob", "initial_mean"): 0
+ - ("xprob", "initial_precision"): 1
+ - ("xvol", "initial_mean"): 0
+ - ("xvol", "initial_precision"): 1
+"""
+function premade_categorical_3level_state_transitions(config::Dict; verbose::Bool = true)
+
+    #Defaults
+    defaults = Dict(
+        "n_categories" => 4,
+        ("xprob", "volatility") => -2,
+        ("xprob", "drift") => 0,
+        ("xprob", "autoconnection_strength") => 1,
+        ("xprob", "initial_mean") => 0,
+        ("xprob", "initial_precision") => 1,
+        ("xvol", "volatility") => -2,
+        ("xvol", "drift") => 0,
+        ("xvol", "autoconnection_strength") => 1,
+        ("xvol", "initial_mean") => 0,
+        ("xvol", "initial_precision") => 1,
+        ("xbin", "xprob", "coupling_strength") => 1,
+        ("xprob", "xvol", "coupling_strength") => 1,
+        "update_type" => EnhancedUpdate(),
+        "save_history" => true,
+    )
+
+    #Warn the user about used defaults and misspecified keys
+    if verbose
+        warn_premade_defaults(defaults, config)
+    end
+
+    #Merge to overwrite defaults
+    config = merge(defaults, config)
+
+
+    ##Prepare node names
+    #Empty lists for node names
+    input_node_names = Vector{String}()
+    observation_parent_names = Vector{String}()
+    category_parent_names = Vector{String}()
+    probability_parent_names = Vector{String}()
+
+    #Empty lists for grouped parameters
+    grouped_parameters_xprob_initial_precision = []
+    grouped_parameters_xprob_initial_mean = []
+    grouped_parameters_xprob_volatility = []
+    grouped_parameters_xprob_drift = []
+    grouped_parameters_xprob_autoconnection_strength = []
+    grouped_parameters_xbin_xprob_coupling_strength = []
+    grouped_parameters_xprob_xvol_coupling_strength = []
+
+    #Go through 
each category that the transition may have been from + for category_from = 1:config["n_categories"] + #One input node and its state node parent for each + push!(input_node_names, "u" * string(category_from)) + push!(observation_parent_names, "xcat_" * string(category_from)) + #Go through each category that the transition may have been to + for category_to = 1:config["n_categories"] + #Each categorical state node has a binary parent for each + push!( + category_parent_names, + "xbin_" * string(category_from) * "_" * string(category_to), + ) + #And each binary parent has a continuous parent of its own + push!( + probability_parent_names, + "xprob_" * string(category_from) * "_" * string(category_to), + ) + end + end + + ##List of nodes + nodes = Vector{AbstractNodeInfo}() + + #For each categorical input node + for node_name in input_node_names + #Add it to the list + push!(nodes, CategoricalInput(node_name)) + end + + #For each categorical state node + for node_name in observation_parent_names + #Add it to the list + push!(nodes, CategoricalState(node_name)) + end + + #For each categorical node binary parent + for node_name in category_parent_names + #Add it to the list + push!(nodes, BinaryState(node_name)) + end + + #For each binary node continuous parent + for node_name in probability_parent_names + #Add it to the list, with parameter settings from the config + push!( + nodes, + ContinuousState( + name = node_name, + volatility = config[("xprob", "volatility")], + drift = config[("xprob", "drift")], + autoconnection_strength = config[("xprob", "autoconnection_strength")], + initial_mean = config[("xprob", "initial_mean")], + initial_precision = config[("xprob", "initial_precision")], + ), + ) + #Add the grouped parameter name to grouped parameters vector + push!(grouped_parameters_xprob_initial_precision, (node_name, "initial_precision")) + push!(grouped_parameters_xprob_initial_mean, (node_name, "initial_mean")) + push!(grouped_parameters_xprob_volatility, 
(node_name, "volatility")) + push!(grouped_parameters_xprob_drift, (node_name, "drift")) + push!( + grouped_parameters_xprob_autoconnection_strength, + (node_name, "autoconnection_strength"), + ) + end + + + #Add the shared volatility parent of the continuous nodes + push!( + nodes, + ContinuousState( + name = "xvol", + volatility = config[("xvol", "volatility")], + drift = config[("xvol", "drift")], + autoconnection_strength = config[("xvol", "autoconnection_strength")], + initial_mean = config[("xvol", "initial_mean")], + initial_precision = config[("xvol", "initial_precision")], + ), + ) + + ##Create edges + #Initialize list + edges = Dict{Tuple{String,String},CouplingType}() + + #For each categorical input node and its corresponding state node parent + for (input_node_name, observation_parent_name) in + zip(input_node_names, observation_parent_names) + + #Add their connection + edges[(input_node_name, observation_parent_name)] = ObservationCoupling() + end + + #For each categorical state node + for observation_parent_name in observation_parent_names + #Get the category it represents transitions from + (observation_parent_prefix, child_category_from) = + split(observation_parent_name, "_") + + #For each potential parent node + for category_parent_name in category_parent_names + #Get the category it represents transitions to and from + (category_parent_prefix, parent_category_from, parent_category_to) = + split(category_parent_name, "_") + + #If these match + if parent_category_from == child_category_from + #Add their connection + edges[(observation_parent_name, category_parent_name)] = CategoryCoupling() + end + end + end + + #For each set of category parent and probability parent + for (category_parent_name, probability_parent_name) in + zip(category_parent_names, probability_parent_names) + + #Connect them + edges[(category_parent_name, probability_parent_name)] = + ProbabilityCoupling(config[("xbin", "xprob", "coupling_strength")]) + + #Connect the 
probability parents to the shared volatility parent + edges[(probability_parent_name, "xvol")] = + VolatilityCoupling(config[("xprob", "xvol", "coupling_strength")]) + + + #Add the parameters as grouped parameters for shared parameters + push!( + grouped_parameters_xbin_xprob_coupling_strength, + (category_parent_name, probability_parent_name, "coupling_strength"), + ) + push!( + grouped_parameters_xprob_xvol_coupling_strength, + (probability_parent_name, "xvol", "coupling_strength"), + ) + end + + #Create dictionary with shared parameter information + + parameter_groups = [ + ParameterGroup( + "xprob_volatility", + grouped_parameters_xprob_volatility, + config[("xprob", "volatility")], + ), + ParameterGroup( + "xprob_initial_precision", + grouped_parameters_xprob_initial_precision, + config[("xprob", "initial_precision")], + ), + ParameterGroup( + "xprob_initial_mean", + grouped_parameters_xprob_initial_mean, + config[("xprob", "initial_mean")], + ), + ParameterGroup( + "xprob_drift", + grouped_parameters_xprob_drift, + config[("xprob", "drift")], + ), + ParameterGroup( + "xprob_autoconnection_strength", + grouped_parameters_xprob_autoconnection_strength, + config[("xprob", "autoconnection_strength")], + ), + ParameterGroup( + "xbin_xprob_coupling_strength", + grouped_parameters_xbin_xprob_coupling_strength, + config[("xbin", "xprob", "coupling_strength")], + ), + ParameterGroup( + "xprob_xvol_coupling_strength", + grouped_parameters_xprob_xvol_coupling_strength, + config[("xprob", "xvol", "coupling_strength")], + ), + ] + + #Initialize the HGF + init_hgf( + nodes = nodes, + edges = edges, + parameter_groups = parameter_groups, + verbose = false, + node_defaults = NodeDefaults(update_type = config["update_type"]), + save_history = config["save_history"], + ) +end diff --git a/src/premade_models/premade_hgfs/premade_continuous_2level.jl b/src/premade_models/premade_hgfs/premade_continuous_2level.jl new file mode 100644 index 0000000..a4c01f4 --- /dev/null +++ 
b/src/premade_models/premade_hgfs/premade_continuous_2level.jl @@ -0,0 +1,86 @@ +""" + premade_continuous_2level(config::Dict; verbose::Bool = true) + +The standard 2 level continuous HGF, which filters a continuous input. +It has a continous input node u, with a single value parent x, which in turn has a single volatility parent xvol. + +# Config defaults: + - ("u", "input_noise"): -2 + - ("x", "volatility"): -2 + - ("xvol", "volatility"): -2 + - ("u", "x", "coupling_strength"): 1 + - ("x", "xvol", "coupling_strength"): 1 + - ("x", "initial_mean"): 0 + - ("x", "initial_precision"): 1 + - ("xvol", "initial_mean"): 0 + - ("xvol", "initial_precision"): 1 +""" +function premade_continuous_2level(config::Dict; verbose::Bool = true) + + #Defaults + spec_defaults = Dict( + ("u", "input_noise") => -2, + ("u", "bias") => 0, + ("x", "volatility") => -2, + ("x", "drift") => 0, + ("x", "autoconnection_strength") => 1, + ("x", "initial_mean") => 0, + ("x", "initial_precision") => 1, + ("xvol", "volatility") => -2, + ("xvol", "drift") => 0, + ("xvol", "autoconnection_strength") => 1, + ("xvol", "initial_mean") => 0, + ("xvol", "initial_precision") => 1, + ("x", "xvol", "coupling_strength") => 1, + "update_type" => EnhancedUpdate(), + "save_history" => true, + ) + + #Warn the user about used defaults and misspecified keys + if verbose + warn_premade_defaults(spec_defaults, config) + end + + #Merge to overwrite defaults + config = merge(spec_defaults, config) + + #List of nodes + nodes = [ + ContinuousInput( + name = "u", + input_noise = config[("u", "input_noise")], + bias = config[("u", "bias")], + ), + ContinuousState( + name = "x", + volatility = config[("x", "volatility")], + drift = config[("x", "drift")], + autoconnection_strength = config[("x", "autoconnection_strength")], + initial_mean = config[("x", "initial_mean")], + initial_precision = config[("x", "initial_precision")], + ), + ContinuousState( + name = "xvol", + volatility = config[("xvol", "volatility")], + drift 
= config[("xvol", "drift")], + autoconnection_strength = config[("xvol", "autoconnection_strength")], + initial_mean = config[("xvol", "initial_mean")], + initial_precision = config[("xvol", "initial_precision")], + ), + ] + + #List of child-parent relations + edges = Dict( + ("u", "x") => ObservationCoupling(), + ("x", "xvol") => VolatilityCoupling(config[("x", "xvol", "coupling_strength")]), + ) + + #Initialize the HGF + init_hgf( + nodes = nodes, + edges = edges, + verbose = false, + node_defaults = NodeDefaults(update_type = config["update_type"]), + save_history = config["save_history"], + ) +end diff --git a/src/structs.jl b/src/structs.jl deleted file mode 100644 index 1e7b776..0000000 --- a/src/structs.jl +++ /dev/null @@ -1,314 +0,0 @@ -################################ -######## Abstract Types ######## -################################ - -#Top-level node type -abstract type AbstractNode end - -#Input and state node subtypes -abstract type AbstractStateNode <: AbstractNode end -abstract type AbstractInputNode <: AbstractNode end - -#Supertype for HGF update types -abstract type HGFUpdateType end - -#Classic and enhance dupdate types -struct ClassicUpdate <: HGFUpdateType end -struct EnhancedUpdate <: HGFUpdateType end - - -####################################### -######## Continuous State Node ######## -####################################### -""" -Configuration of continuous state nodes' parameters -""" -Base.@kwdef mutable struct ContinuousStateNodeParameters - volatility::Real = 0 - drift::Real = 0 - autoregression_target::Real = 0 - autoregression_strength::Real = 0 - value_coupling::Dict{String,Real} = Dict{String,Real}() - volatility_coupling::Dict{String,Real} = Dict{String,Real}() - initial_mean::Real = 0 - initial_precision::Real = 0 -end - -""" -Configurations of the continuous state node states -""" -Base.@kwdef mutable struct ContinuousStateNodeState - posterior_mean::Union{Real} = 0 - posterior_precision::Union{Real} = 1 - 
value_prediction_error::Union{Real,Missing} = missing - volatility_prediction_error::Union{Real,Missing} = missing - prediction_mean::Union{Real,Missing} = missing - predicted_volatility::Union{Real,Missing} = missing - prediction_precision::Union{Real,Missing} = missing - volatility_weighted_prediction_precision::Union{Real,Missing} = missing -end - -""" -Configuration of continuous state node history -""" -Base.@kwdef mutable struct ContinuousStateNodeHistory - posterior_mean::Vector{Real} = [] - posterior_precision::Vector{Real} = [] - value_prediction_error::Vector{Union{Real,Missing}} = [missing] - volatility_prediction_error::Vector{Union{Real,Missing}} = [missing] - prediction_mean::Vector{Real} = [] - predicted_volatility::Vector{Real} = [] - prediction_precision::Vector{Real} = [] - volatility_weighted_prediction_precision::Vector{Real} = [] -end - -""" -""" -Base.@kwdef mutable struct ContinuousStateNode <: AbstractStateNode - name::String - value_parents::Vector{AbstractStateNode} = [] - volatility_parents::Vector{AbstractStateNode} = [] - value_children::Vector{AbstractNode} = [] - volatility_children::Vector{AbstractNode} = [] - parameters::ContinuousStateNodeParameters = ContinuousStateNodeParameters() - states::ContinuousStateNodeState = ContinuousStateNodeState() - history::ContinuousStateNodeHistory = ContinuousStateNodeHistory() - update_type::HGFUpdateType = ClassicUpdate() -end - - -################################### -######## Binary State Node ######## -################################### - -""" - Configure parameters of binary state node -""" -Base.@kwdef mutable struct BinaryStateNodeParameters - value_coupling::Dict{String,Real} = Dict{String,Real}() -end - -""" -Overview of the states of the binary state node -""" -Base.@kwdef mutable struct BinaryStateNodeState - posterior_mean::Union{Real,Missing} = missing - posterior_precision::Union{Real,Missing} = missing - value_prediction_error::Union{Real,Missing} = missing - 
prediction_mean::Union{Real,Missing} = missing - prediction_precision::Union{Real,Missing} = missing -end - -""" -Overview of the history of the binary state node -""" -Base.@kwdef mutable struct BinaryStateNodeHistory - posterior_mean::Vector{Union{Real,Missing}} = [] - posterior_precision::Vector{Union{Real,Missing}} = [] - value_prediction_error::Vector{Union{Real,Missing}} = [missing] - prediction_mean::Vector{Real} = [] - prediction_precision::Vector{Real} = [] -end - -""" -Overview of edge posibilities -""" -Base.@kwdef mutable struct BinaryStateNode <: AbstractStateNode - name::String - value_parents::Vector{AbstractStateNode} = [] - volatility_parents::Vector{Nothing} = [] - value_children::Vector{AbstractNode} = [] - volatility_children::Vector{Nothing} = [] - parameters::BinaryStateNodeParameters = BinaryStateNodeParameters() - states::BinaryStateNodeState = BinaryStateNodeState() - history::BinaryStateNodeHistory = BinaryStateNodeHistory() - update_type::HGFUpdateType = ClassicUpdate() -end - - -######################################## -######## Categorical State Node ######## -######################################## -Base.@kwdef mutable struct CategoricalStateNodeParameters end - -""" -Configuration of states in categorical state node -""" -Base.@kwdef mutable struct CategoricalStateNodeState - posterior::Vector{Union{Real,Missing}} = [] - value_prediction_error::Vector{Union{Real,Missing}} = [] - prediction::Vector{Real} = [] - parent_predictions::Vector{Real} = [] -end - -""" -Configuration of history in categorical state node -""" -Base.@kwdef mutable struct CategoricalStateNodeHistory - posterior::Vector{Vector{Union{Real,Missing}}} = [] - value_prediction_error::Vector{Vector{Union{Real,Missing}}} = [] - prediction::Vector{Vector{Real}} = [] - parent_predictions::Vector{Vector{Real}} = [] -end - -""" -Configuration of edges in categorical state node -""" -Base.@kwdef mutable struct CategoricalStateNode <: AbstractStateNode - name::String - 
value_parents::Vector{AbstractStateNode} = [] - volatility_parents::Vector{Nothing} = [] - value_children::Vector{AbstractNode} = [] - volatility_children::Vector{Nothing} = [] - category_parent_order::Vector{String} = [] - parameters::CategoricalStateNodeParameters = CategoricalStateNodeParameters() - states::CategoricalStateNodeState = CategoricalStateNodeState() - history::CategoricalStateNodeHistory = CategoricalStateNodeHistory() - update_type::HGFUpdateType = ClassicUpdate() -end - - -####################################### -######## Continuous Input Node ######## -####################################### -""" -Configuration of continuous input node parameters -""" -Base.@kwdef mutable struct ContinuousInputNodeParameters - input_noise::Real = 0 - value_coupling::Dict{String,Real} = Dict{String,Real}() - volatility_coupling::Dict{String,Real} = Dict{String,Real}() -end - -""" -Configuration of continuous input node states -""" -Base.@kwdef mutable struct ContinuousInputNodeState - input_value::Union{Real,Missing} = missing - value_prediction_error::Union{Real,Missing} = missing - volatility_prediction_error::Union{Real,Missing} = missing - predicted_volatility::Union{Real,Missing} = missing - prediction_precision::Union{Real,Missing} = missing - volatility_weighted_prediction_precision::Union{Real} = 1 -end - -""" -Configuration of continuous input node history -""" -Base.@kwdef mutable struct ContinuousInputNodeHistory - input_value::Vector{Union{Real,Missing}} = [missing] - value_prediction_error::Vector{Union{Real,Missing}} = [missing] - volatility_prediction_error::Vector{Union{Real,Missing}} = [missing] - predicted_volatility::Vector{Real} = [] - prediction_precision::Vector{Real} = [] -end - -""" -""" -Base.@kwdef mutable struct ContinuousInputNode <: AbstractInputNode - name::String - value_parents::Vector{AbstractStateNode} = [] - volatility_parents::Vector{AbstractStateNode} = [] - parameters::ContinuousInputNodeParameters = 
ContinuousInputNodeParameters() - states::ContinuousInputNodeState = ContinuousInputNodeState() - history::ContinuousInputNodeHistory = ContinuousInputNodeHistory() -end - - - -################################### -######## Binary Input Node ######## -################################### - -""" -Configuration of parameters in binary input node. Default category mean set to [0,1] -""" -Base.@kwdef mutable struct BinaryInputNodeParameters - category_means::Vector{Union{Real}} = [0, 1] - input_precision::Real = Inf -end - -""" -Configuration of states of binary input node -""" -Base.@kwdef mutable struct BinaryInputNodeState - input_value::Union{Real,Missing} = missing - value_prediction_error::Vector{Union{Real,Missing}} = [missing, missing] -end - -""" -Configuration of history of binary input node -""" -Base.@kwdef mutable struct BinaryInputNodeHistory - input_value::Vector{Union{Real,Missing}} = [missing] - value_prediction_error::Vector{Vector{Union{Real,Missing}}} = [[missing, missing]] -end - -""" -""" -Base.@kwdef mutable struct BinaryInputNode <: AbstractInputNode - name::String - value_parents::Vector{AbstractStateNode} = [] - volatility_parents::Vector{Nothing} = [] - parameters::BinaryInputNodeParameters = BinaryInputNodeParameters() - states::BinaryInputNodeState = BinaryInputNodeState() - history::BinaryInputNodeHistory = BinaryInputNodeHistory() -end - - - - -######################################## -######## Categorical Input Node ######## -######################################## - -Base.@kwdef mutable struct CategoricalInputNodeParameters end - -""" -Configuration of states of categorical input node -""" -Base.@kwdef mutable struct CategoricalInputNodeState - input_value::Union{Real,Missing} = missing -end - -""" -History of categorical input node -""" -Base.@kwdef mutable struct CategoricalInputNodeHistory - input_value::Vector{Union{Real,Missing}} = [missing] -end - -""" -""" -Base.@kwdef mutable struct CategoricalInputNode <: AbstractInputNode - 
name::String - value_parents::Vector{AbstractStateNode} = [] - volatility_parents::Vector{Nothing} = [] - parameters::CategoricalInputNodeParameters = CategoricalInputNodeParameters() - states::CategoricalInputNodeState = CategoricalInputNodeState() - history::CategoricalInputNodeHistory = CategoricalInputNodeHistory() -end - - -############################ -######## HGF Struct ######## -############################ -""" -""" -Base.@kwdef mutable struct OrderedNodes - all_nodes::Vector{AbstractNode} = [] - input_nodes::Vector{AbstractInputNode} = [] - all_state_nodes::Vector{AbstractStateNode} = [] - early_update_state_nodes::Vector{AbstractStateNode} = [] - late_update_state_nodes::Vector{AbstractStateNode} = [] -end - -""" -""" -Base.@kwdef mutable struct HGF - all_nodes::Dict{String,AbstractNode} - input_nodes::Dict{String,AbstractInputNode} - state_nodes::Dict{String,AbstractStateNode} - ordered_nodes::OrderedNodes = OrderedNodes() - shared_parameters::Dict = Dict() -end diff --git a/src/update_hgf/node_updates/binary_input_node.jl b/src/update_hgf/node_updates/binary_input_node.jl new file mode 100644 index 0000000..a6b77d1 --- /dev/null +++ b/src/update_hgf/node_updates/binary_input_node.jl @@ -0,0 +1,42 @@ +################################### +######## Update prediction ######## +################################### + +##### Superfunction ##### +""" + update_node_prediction!(node::BinaryInputNode) + +There is no prediction update for binary input nodes, as the prediction precision is constant. +""" +function update_node_prediction!(node::BinaryInputNode, stepsize::Real) + return nothing +end + + +############################################### +######## Update value prediction error ######## +############################################### + +##### Superfunction ##### +""" + update_node_value_prediction_error!(node::BinaryInputNode) + +Update the value prediction error of a single binary input node. 
+""" +function update_node_value_prediction_error!(node::BinaryInputNode) + return nothing +end + +################################################### +######## Update precision prediction error ######## +################################################### + +##### Superfunction ##### +""" + update_node_precision_prediction_error!(node::BinaryInputNode) + +There is no volatility prediction error update for binary input nodes. +""" +function update_node_precision_prediction_error!(node::BinaryInputNode) + return nothing +end diff --git a/src/update_hgf/node_updates/binary_state_node.jl b/src/update_hgf/node_updates/binary_state_node.jl new file mode 100644 index 0000000..4e80214 --- /dev/null +++ b/src/update_hgf/node_updates/binary_state_node.jl @@ -0,0 +1,231 @@ +################################### +######## Update prediction ######## +################################### + +##### Superfunction ##### +""" + update_node_prediction!(node::BinaryStateNode) + +Update the prediction of a single binary state node. +""" +function update_node_prediction!(node::BinaryStateNode, stepsize::Real) + + #Update prediction mean + node.states.prediction_mean = calculate_prediction_mean(node) + + #Update prediction precision + node.states.prediction_precision = calculate_prediction_precision(node) + + return nothing +end + +##### Mean update ##### +@doc raw""" + calculate_prediction_mean(node::BinaryStateNode) + +Calculates a binary state node's prediction mean. 
+ +Uses the equation +`` \hat{\mu}_n= \big(1+e^{\sum_{j=1}^{j\;value \; parents} \hat{\mu}_{j}}\big)^{-1} `` +""" +function calculate_prediction_mean(node::BinaryStateNode) + probability_parents = node.edges.probability_parents + + prediction_mean = 0 + + for parent in probability_parents + prediction_mean += + parent.states.prediction_mean * node.parameters.coupling_strengths[parent.name] + end + + prediction_mean = 1 / (1 + exp(-prediction_mean)) + + return prediction_mean +end + +##### Precision update ##### +@doc raw""" + calculate_prediction_precision(node::BinaryStateNode) + +Calculates a binary state node's prediction precision. + +Uses the equation +`` \hat{\pi}_n = \frac{1}{\hat{\mu}_n \cdot (1-\hat{\mu}_n)} `` +""" +function calculate_prediction_precision(node::BinaryStateNode) + 1 / (node.states.prediction_mean * (1 - node.states.prediction_mean)) +end + +################################## +######## Update posterior ######## +################################## + +##### Superfunction ##### +""" + update_node_posterior!(node::AbstractStateNode; update_type::HGFUpdateType) + +Update the posterior of a single continuous state node. This is the classic HGF update. +""" +function update_node_posterior!(node::BinaryStateNode, update_type::HGFUpdateType) + #Update posterior precision + node.states.posterior_precision = calculate_posterior_precision(node) + + #Update posterior mean + node.states.posterior_mean = calculate_posterior_mean(node, update_type) + + return nothing +end + +##### Precision update ##### +@doc raw""" + calculate_posterior_precision(node::BinaryStateNode) + +Calculates a binary node's posterior precision. 
+
+Uses the equations
+
+`` \pi_n = inf ``
+if the precision is infinite
+
+ `` \pi_n = \frac{1}{\hat{\mu}_n \cdot (1-\hat{\mu}_n)} ``
+ if the precision is other than infinite
+"""
+function calculate_posterior_precision(node::BinaryStateNode)
+
+    ## If the child is an observation child ##
+    if length(node.edges.observation_children) > 0
+
+        #Extract the observation child
+        child = node.edges.observation_children[1]
+
+        #Simple update with infinite precision
+        if child.parameters.input_precision == Inf
+            posterior_precision = Inf
+            #Update with finite precision
+        else
+            posterior_precision =
+                1 / (node.states.posterior_mean * (1 - node.states.posterior_mean))
+        end
+
+        ## If the child is a category child ##
+    elseif length(node.edges.category_children) > 0
+
+        posterior_precision = Inf
+
+    else
+        @error "the binary state node $(node.name) has neither category nor observation children"
+    end
+
+    return posterior_precision
+end
+
+##### Mean update #####
+@doc raw"""
+    calculate_posterior_mean(node::BinaryStateNode)
+
+Calculates a node's posterior mean. 
+
+Uses the equation
+`` \mu = \frac{e^{-0.5 \cdot \pi_n \cdot \eta_1^2}}{\hat{\mu}_n \cdot e^{-0.5 \cdot \pi_n \cdot \eta_1^2} \; + 1-\hat{\mu}_n \cdot e^{-0.5 \cdot \pi_n \cdot \eta_2^2}} ``
+"""
+function calculate_posterior_mean(node::BinaryStateNode, update_type::HGFUpdateType)
+
+    ## If the child is an observation child ##
+    if length(node.edges.observation_children) > 0
+
+        #Extract the child
+        child = node.edges.observation_children[1]
+
+        #For missing inputs
+        if ismissing(child.states.input_value)
+            #Set the posterior to missing
+            posterior_mean = missing
+        else
+            #Update with infinite input precision
+            if child.parameters.input_precision == Inf
+                posterior_mean = child.states.input_value
+
+                #Update with finite input precision
+            else
+                posterior_mean =
+                    node.states.prediction_mean * exp(
+                        -0.5 *
+                        node.states.prediction_precision *
+                        child.parameters.category_means[1]^2,
+                    ) / (
+                        node.states.prediction_mean * exp(
+                            -0.5 *
+                            node.states.prediction_precision *
+                            child.parameters.category_means[1]^2,
+                        ) +
+                        (1 - node.states.prediction_mean) * exp(
+                            -0.5 *
+                            node.states.prediction_precision *
+                            child.parameters.category_means[2]^2,
+                        )
+                    )
+            end
+        end
+
+        ## If the child is a category child ##
+    elseif length(node.edges.category_children) > 0
+
+        #Extract the child
+        child = node.edges.category_children[1]
+
+        #Find the nodes' own category number
+        category_number = findfirst(child.edges.category_parent_order .== node.name)
+
+        #Find the corresponding value in the child
+        posterior_mean = child.states.posterior[category_number]
+
+    else
+        @error "the binary state node $(node.name) has neither category nor observation children"
+    end
+
+    return posterior_mean
+end
+
+
+###############################################
+######## Update value prediction error ########
+###############################################
+
+##### Superfunction #####
+"""
+    update_node_value_prediction_error!(node::AbstractStateNode)
+
+Update the value prediction error of a 
single state node. +""" +function update_node_value_prediction_error!(node::BinaryStateNode) + #Update value prediction error + node.states.value_prediction_error = calculate_value_prediction_error(node) + + return nothing +end + +@doc raw""" + calculate_value_prediction_error(node::AbstractNode) + +Calculate's a state node's value prediction error. + +Uses the equation +`` \delta_n = \mu_n - \hat{\mu}_n `` +""" +function calculate_value_prediction_error(node::BinaryStateNode) + node.states.posterior_mean - node.states.prediction_mean +end + +################################################### +######## Update precision prediction error ######## +################################################### + +##### Superfunction ##### +""" + update_node_precision_prediction_error!(node::BinaryStateNode) + +There is no volatility prediction error update for binary state nodes. +""" +function update_node_precision_prediction_error!(node::BinaryStateNode) + return nothing +end diff --git a/src/update_hgf/node_updates/categorical_input_node.jl b/src/update_hgf/node_updates/categorical_input_node.jl new file mode 100644 index 0000000..1c2e5f6 --- /dev/null +++ b/src/update_hgf/node_updates/categorical_input_node.jl @@ -0,0 +1,43 @@ +################################### +######## Update prediction ######## +################################### + +##### Superfunction ##### +""" + update_node_prediction!(node::CategoricalInputNode) + +There is no prediction update for categorical input nodes, as the prediction precision is constant. +""" +function update_node_prediction!(node::CategoricalInputNode, stepsize::Real) + return nothing +end + + +############################################### +######## Update value prediction error ######## +############################################### + +##### Superfunction ##### +""" + update_node_value_prediction_error!(node::CategoricalInputNode) + + There is no value prediction error update for categorical input nodes. 
+""" +function update_node_value_prediction_error!(node::CategoricalInputNode) + return nothing +end + + +################################################### +######## Update precision prediction error ######## +################################################### + +##### Superfunction ##### +""" + update_node_precision_prediction_error!(node::CategoricalInputNode) + +There is no volatility prediction error update for categorical input nodes. +""" +function update_node_precision_prediction_error!(node::CategoricalInputNode) + return nothing +end diff --git a/src/update_hgf/node_updates/categorical_state_node.jl b/src/update_hgf/node_updates/categorical_state_node.jl new file mode 100644 index 0000000..f2bf1c0 --- /dev/null +++ b/src/update_hgf/node_updates/categorical_state_node.jl @@ -0,0 +1,164 @@ +################################### +######## Update prediction ######## +################################### + +##### Superfunction ##### +""" + update_node_prediction!(node::CategoricalStateNode) + +Update the prediction of a single categorical state node. +""" +function update_node_prediction!(node::CategoricalStateNode, stepsize::Real) + + #Update prediction mean + node.states.prediction, node.states.parent_predictions = calculate_prediction(node) + + return nothing +end + + +@doc raw""" + function calculate_prediction(node::CategoricalStateNode) + +Calculate the prediction for a categorical state node. 
+ +Uses the equation +`` \vec{\hat{\mu}_n}= \frac{\hat{\mu}_j}{\sum_{j=1}^{j\;binary \;parents} \hat{\mu}_j} `` +""" +function calculate_prediction(node::CategoricalStateNode) + + #Get parent posteriors + parent_posteriors = + map(x -> x.states.posterior_mean, collect(values(node.edges.category_parents))) + + #Get current parent predictions + parent_predictions = + map(x -> x.states.prediction_mean, collect(values(node.edges.category_parents))) + + #Get previous parent predictions + previous_parent_predictions = node.states.parent_predictions + + #If there was an observation + if any(!ismissing, node.states.posterior) + + #Calculate implied learning rate + implied_learning_rate = + ( + (parent_posteriors .- previous_parent_predictions) ./ + (parent_predictions .- previous_parent_predictions) + ) .- 1 + + #Calculate the prediction mean + prediction = + ((implied_learning_rate .* parent_predictions) .+ 1) ./ + sum(implied_learning_rate .* parent_predictions .+ 1) + + #If there was no observation + else + #Extract prediction from last timestep + prediction = node.states.prediction + end + + return prediction, parent_predictions + +end + + +################################## +######## Update posterior ######## +################################## + +##### Superfunction ##### +""" + update_node_posterior!(node::CategoricalStateNode) + +Update the posterior of a single categorical state node. +""" +function update_node_posterior!(node::CategoricalStateNode, update_type::ClassicUpdate) + + #Update posterior mean + node.states.posterior = calculate_posterior(node) + + return nothing +end + + +@doc raw""" + calculate_posterior(node::CategoricalStateNode) + +Calculate the posterior for a categorical state node. 
+ +One hot encoding +`` \vec{u} = [0, 0, \dots ,1, \dots,0] `` +""" +function calculate_posterior(node::CategoricalStateNode) + + #Extract the child + child = node.edges.observation_children[1] + + #Initialize posterior as previous posterior + posterior = node.states.posterior + + #For missing inputs + if ismissing(child.states.input_value) + #Set the posterior to be all missing + posterior .= missing + + else + #Set all values to 0 + posterior .= zero(Real) + + #Set the posterior for the observed category to 1 + posterior[child.states.input_value] = 1 + end + return posterior +end + + +############################################### +######## Update value prediction error ######## +############################################### + +##### Superfunction ##### +""" + update_node_value_prediction_error!(node::AbstractStateNode) + +Update the value prediction error of a single state node. +""" +function update_node_value_prediction_error!(node::CategoricalStateNode) + #Update value prediction error + node.states.value_prediction_error = calculate_value_prediction_error(node) + + return nothing +end + +@doc raw""" + calculate_value_prediction_error(node::CategoricalStateNode) + +Calculate the value prediction error for a categorical state node. + +Uses the equation +`` \delta_n= u - \sum_{j=1}^{j\;value\;parents} \hat{\mu}_{j} `` +""" +function calculate_value_prediction_error(node::CategoricalStateNode) + + #Get the prediction error for each category + value_prediction_error = node.states.posterior - node.states.prediction + + return value_prediction_error +end + + +################################################### +######## Update precision prediction error ######## +################################################### + +##### Superfunction ##### +""" + update_node_precision_prediction_error!(node::CategoricalStateNode) + +There is no volatility prediction error update for categorical state nodes. 
+""" +function update_node_precision_prediction_error!(node::CategoricalStateNode) + return nothing +end diff --git a/src/update_hgf/node_updates/continuous_input_node.jl b/src/update_hgf/node_updates/continuous_input_node.jl new file mode 100644 index 0000000..b418591 --- /dev/null +++ b/src/update_hgf/node_updates/continuous_input_node.jl @@ -0,0 +1,175 @@ +################################### +######## Update prediction ######## +################################### + +##### Superfunction ##### +""" + update_node_prediction!(node::AbstractInputNode) + +Update the posterior of a single input node. +""" +function update_node_prediction!(node::ContinuousInputNode, stepsize::Real) + + #Update node prediction mean + node.states.prediction_mean = calculate_prediction_mean(node) + + #Update prediction precision + node.states.prediction_precision = calculate_prediction_precision(node) + + return nothing +end + + +##### Mean update ##### +function calculate_prediction_mean(node::ContinuousInputNode) + #Extract parents + observation_parents = node.edges.observation_parents + + #Initialize prediction at the bias + prediction_mean = node.parameters.bias + + #Sum the predictions of the parents + for parent in observation_parents + prediction_mean += parent.states.prediction_mean + end + + return prediction_mean +end + + +##### Precision update ##### +@doc raw""" + calculate_prediction_precision(node::AbstractInputNode) + +Calculates an input node's prediction precision. 
+ +Uses the equation +`` \hat{\pi}_n = \frac{1}{\nu}_n `` +""" +function calculate_prediction_precision(node::ContinuousInputNode) + + #Extract noise parents + noise_parents = node.edges.noise_parents + + #Initialize noise from input noise parameter + predicted_noise = node.parameters.input_noise + + #Go through each noise parent + for parent in noise_parents + #Add its mean to the predicted noise + predicted_noise += + parent.states.posterior_mean * node.parameters.coupling_strengths[parent.name] + end + + #The prediction precision is the inverse of the predicted noise + prediction_precision = 1 / exp(predicted_noise) + + return prediction_precision +end + + + +############################################### +######## Update value prediction error ######## +############################################### + +##### Superfunction ##### +""" + update_node_value_prediction_error!(node::AbstractInputNode) + +Update the value prediction error of a single input node. +""" +function update_node_value_prediction_error!(node::ContinuousInputNode) + + #Calculate value prediction error + node.states.value_prediction_error = calculate_value_prediction_error(node) + + return nothing +end + + + +@doc raw""" + calculate_value_prediction_error(node::ContinuousInputNode) + +Calculate's an input node's value prediction error. 
+ +Uses the equation +``\delta_n= u - \sum_{j=1}^{j\;value\;parents} \hat{\mu}_{j} `` +""" +function calculate_value_prediction_error(node::ContinuousInputNode) + #For missing input + if ismissing(node.states.input_value) + #Set the prediction error to missing + value_prediction_error = missing + else + #Get value prediction error between the prediction and the input value + value_prediction_error = node.states.input_value - node.states.prediction_mean + end + + return value_prediction_error +end + + +################################################### +######## Update precision prediction error ######## +################################################### + +##### Superfunction ##### +""" + update_node_precision_prediction_error!(node::AbstractInputNode) + +Update the value prediction error of a single input node. +""" +function update_node_precision_prediction_error!(node::ContinuousInputNode) + + #Calculate volatility prediction error, only if there are volatility parents + node.states.precision_prediction_error = calculate_precision_prediction_error(node) + + return nothing +end + +@doc raw""" + calculate_precision_prediction_error(node::ContinuousInputNode) + +Calculates an input node's volatility prediction error. 
+ +Uses the equation +`` \mu'_j=\sum_{j=1}^{j\;value\;parents} \mu_{j} `` +`` \pi'_j=\frac{{\sum_{j=1}^{j\;value\;parents} \pi_{j}}}{j} `` +`` \Delta_n=\frac{\hat{\pi}_n}{\pi'_j} + \hat{\mu}_i\cdot (u -\mu'_j^2 )-1 `` +""" +function calculate_precision_prediction_error(node::ContinuousInputNode) + + #If there are no noise parents + if length(node.edges.noise_parents) == 0 + #Skip + return missing + end + + #For missing input + if ismissing(node.states.input_value) + #Set the prediction error to missing + precision_prediction_error = missing + else + #Extract parents + observation_parents = node.edges.observation_parents + + #Average the posterior precision of the observation parents + parents_average_posterior_precision = 0 + + for parent in observation_parents + parents_average_posterior_precision += parent.states.posterior_precision + end + + parents_average_posterior_precision = + parents_average_posterior_precision / length(observation_parents) + + #Get the noise prediction error using the average parent parents_posterior_precision + precision_prediction_error = + node.states.prediction_precision / parents_average_posterior_precision + + node.states.prediction_precision * node.states.value_prediction_error^2 - 1 + end + + return precision_prediction_error +end diff --git a/src/update_hgf/node_updates/continuous_state_node.jl b/src/update_hgf/node_updates/continuous_state_node.jl new file mode 100644 index 0000000..5876fa3 --- /dev/null +++ b/src/update_hgf/node_updates/continuous_state_node.jl @@ -0,0 +1,629 @@ +################################### +######## Update prediction ######## +################################### + +##### Superfunction ##### +""" + update_node_prediction!(node::ContinuousStateNode) + +Update the prediction of a single state node. 
+""" +function update_node_prediction!(node::ContinuousStateNode, stepsize::Real) + + #Update prediction mean + node.states.prediction_mean = calculate_prediction_mean(node, stepsize) + + #Update prediction precision + node.states.prediction_precision, node.states.effective_prediction_precision = + calculate_prediction_precision(node, stepsize) + + return nothing +end + +##### Mean update ##### +@doc raw""" + calculate_prediction_mean(node::AbstractNode) + +Calculates a node's prediction mean. + +Uses the equation +`` \hat{\mu}_i=\mu_i+\sum_{j=1}^{j\;value\;parents} \mu_{j} \cdot \alpha_{i,j} `` +""" +function calculate_prediction_mean(node::ContinuousStateNode, stepsize::Real) + #Get out drift parents + drift_parents = node.edges.drift_parents + + #Initialize the total drift as the baseline drift + predicted_drift = node.parameters.drift + + #For each drift parent + for parent in drift_parents + + #Get out the coupling transform + coupling_transform = node.parameters.coupling_transforms[parent.name] + + #Transform the parent's value + drift_increment = transform_parent_value( + parent.states.posterior_mean, + coupling_transform, + derivation_level = 0, + ) + + #Add the drift increment + predicted_drift += drift_increment * node.parameters.coupling_strengths[parent.name] + end + + #Multiply with stepsize + predicted_drift = stepsize * predicted_drift + + #Add the drift to the posterior to get the prediction mean + prediction_mean = + node.parameters.autoconnection_strength * node.states.posterior_mean + + predicted_drift + + return prediction_mean +end + +##### Precision update ##### +@doc raw""" + calculate_prediction_precision(node::AbstractNode) + +Calculates a node's prediction precision. 
+ +Uses the equation +`` \hat{\pi}_i^ = `` +""" +function calculate_prediction_precision(node::ContinuousStateNode, stepsize::Real) + #Extract volatility parents + volatility_parents = node.edges.volatility_parents + + #Initialize the predicted volatility as the baseline volatility + predicted_volatility = node.parameters.volatility + + #Add contributions from volatility parents + for parent in volatility_parents + predicted_volatility += + parent.states.posterior_mean * node.parameters.coupling_strengths[parent.name] + end + + #Exponentiate and multiply with stepsize + predicted_volatility = stepsize * exp(predicted_volatility) + + #Calculate prediction precision + prediction_precision = 1 / (1 / node.states.posterior_precision + predicted_volatility) + + #Calculate the volatility-weighted effective precision + effective_prediction_precision = predicted_volatility * prediction_precision + + #If the posterior precision is negative + if prediction_precision < 0 + #Throw an error + throw( + #Of the custom type where samples are rejected + RejectParameters( + "With these parameters and inputs, the prediction precision of node $(node.name) becomes negative after $(length(node.history.prediction_precision)) inputs", + ), + ) + end + + return prediction_precision, effective_prediction_precision +end + + +################################## +######## Update posterior ######## +################################## + +##### Superfunction ##### +""" + update_node_posterior!(node::AbstractStateNode; update_type::HGFUpdateType) + +Update the posterior of a single continuous state node. This is the classic HGF update. 
+""" +function update_node_posterior!(node::ContinuousStateNode, update_type::ClassicUpdate) + #Update posterior precision + node.states.posterior_precision = calculate_posterior_precision(node, update_type) + + #Update posterior mean + node.states.posterior_mean = calculate_posterior_mean(node, update_type) + + return nothing +end + +""" + update_node_posterior!(node::AbstractStateNode) + +Update the posterior of a single continuous state node. This is the enahnced HGF update. +""" +function update_node_posterior!(node::ContinuousStateNode, update_type::EnhancedUpdate) + #Update posterior mean + node.states.posterior_mean = calculate_posterior_mean(node, update_type) + + #Update posterior precision + node.states.posterior_precision = calculate_posterior_precision(node, update_type) + + return nothing +end + +##### Precision update ##### +@doc raw""" + calculate_posterior_precision(node::AbstractNode) + +Calculates a node's posterior precision. + +Uses the equation +`` \pi_i^{'} = \hat{\pi}_i +\underbrace{\sum_{j=1}^{j\;children} \alpha_{j,i} \cdot \hat{\pi}_{j}} _\text{sum \;of \;VAPE \; continuous \; value \;chidren} `` +""" +function calculate_posterior_precision( + node::ContinuousStateNode, + update_type::HGFUpdateType, +) + + #Initialize as the node's own prediction + posterior_precision = node.states.prediction_precision + + #Add update terms from drift children + for child in node.edges.drift_children + posterior_precision += calculate_posterior_precision_increment( + node, + child, + DriftCoupling(), + update_type, + ) + end + + #Add update terms from observation children + for child in node.edges.observation_children + posterior_precision += + calculate_posterior_precision_increment(node, child, ObservationCoupling()) + end + + #Add update terms from probability children + for child in node.edges.probability_children + posterior_precision += + calculate_posterior_precision_increment(node, child, ProbabilityCoupling()) + end + + #Add update terms from 
volatility children + for child in node.edges.volatility_children + posterior_precision += + calculate_posterior_precision_increment(node, child, VolatilityCoupling()) + end + + #Add update terms from noise children + for child in node.edges.noise_children + posterior_precision += + calculate_posterior_precision_increment(node, child, NoiseCoupling()) + end + + + #If the posterior precision is negative + if posterior_precision < 0 + #Throw an error + throw( + #Of the custom type where samples are rejected + RejectParameters( + "With these parameters and inputs, the posterior precision of node $(node.name) becomes negative after $(length(node.history.posterior_precision)) inputs", + ), + ) + end + + return posterior_precision +end + +## Drift coupling ## +function calculate_posterior_precision_increment( + node::ContinuousStateNode, + child::ContinuousStateNode, + coupling_type::DriftCoupling, + update_type::HGFUpdateType, +) + #Get out the coupling strength and coupling stransform + coupling_strength = child.parameters.coupling_strengths[node.name] + coupling_transform = child.parameters.coupling_transforms[node.name] + + #Calculate the increment + child.states.prediction_precision * ( + coupling_strength^2 * transform_parent_value( + coupling_transform, + node.states.posterior_mean, + derivation_level = 1, + ) - + coupling_strength * + node.states.value_prediction_error * + transform_parent_value( + coupling_transform, + node.states.posterior_mean, + derivation_level = 2, + ) + ) + +end + +## Observation coupling ## +function calculate_posterior_precision_increment( + node::ContinuousStateNode, + child::ContinuousInputNode, + coupling_type::ObservationCoupling, +) + + #If missing input + if ismissing(child.states.input_value) + #No increment + return 0 + else + return child.states.prediction_precision + end +end + +## Probability coupling ## +function calculate_posterior_precision_increment( + node::ContinuousStateNode, + child::BinaryStateNode, + 
coupling_type::ProbabilityCoupling, +) + + #If there is a missing posterior (due to a missing input) + if ismissing(child.states.posterior_mean) + #No update + return 0 + else + return child.parameters.coupling_strengths[node.name]^2 / + child.states.prediction_precision + end +end + +@doc raw""" + calculate_posterior_precision_vope(node::AbstractNode, child::AbstractNode) + +Calculates the posterior precision update term for a single continuous volatility child to a state node. + +Uses the equation +`` `` +""" +function calculate_posterior_precision_increment( + node::ContinuousStateNode, + child::ContinuousStateNode, + coupling_type::VolatilityCoupling, +) + + 1 / 2 * + ( + child.parameters.coupling_strengths[node.name] * + child.states.effective_prediction_precision + )^2 + + child.states.precision_prediction_error * + ( + child.parameters.coupling_strengths[node.name] * + child.states.effective_prediction_precision + )^2 - + 1 / 2 * + child.parameters.coupling_strengths[node.name]^2 * + child.states.effective_prediction_precision * + child.states.precision_prediction_error +end + +function calculate_posterior_precision_increment( + node::ContinuousStateNode, + child::ContinuousInputNode, + coupling_type::NoiseCoupling, +) + + #If the input node child had a missing input + if ismissing(child.states.input_value) + #No increment + return 0 + else + update_term = + 1 / 2 * (child.parameters.coupling_strengths[node.name])^2 + + child.states.precision_prediction_error * + (child.parameters.coupling_strengths[node.name])^2 - + 1 / 2 * + child.parameters.coupling_strengths[node.name]^2 * + child.states.precision_prediction_error + + return update_term + end +end + +##### Mean update ##### +@doc raw""" + calculate_posterior_mean(node::AbstractNode) + +Calculates a node's posterior mean. 
+ +Uses the equation +`` `` +""" +function calculate_posterior_mean(node::ContinuousStateNode, update_type::HGFUpdateType) + + #Initialize as the prediction + posterior_mean = node.states.prediction_mean + + #Add update terms from drift children + for child in node.edges.drift_children + posterior_mean += + calculate_posterior_mean_increment(node, child, DriftCoupling(), update_type) + end + + #Add update terms from observation children + for child in node.edges.observation_children + posterior_mean += calculate_posterior_mean_increment( + node, + child, + ObservationCoupling(), + update_type, + ) + end + + #Add update terms from probability children + for child in node.edges.probability_children + posterior_mean += calculate_posterior_mean_increment( + node, + child, + ProbabilityCoupling(), + update_type, + ) + end + + #Add update terms from volatility children + for child in node.edges.volatility_children + posterior_mean += calculate_posterior_mean_increment( + node, + child, + VolatilityCoupling(), + update_type, + ) + end + + #Add update terms from noise children + for child in node.edges.noise_children + posterior_mean += + calculate_posterior_mean_increment(node, child, NoiseCoupling(), update_type) + end + + return posterior_mean +end + +## Classic drift coupling ## +function calculate_posterior_mean_increment( + node::ContinuousStateNode, + child::ContinuousStateNode, + coupling_type::DriftCoupling, + update_type::ClassicUpdate, +) + ( + ( + child.parameters.coupling_strengths[node.name] * + transform_parent_value( + child.parameters.coupling_transforms[node.name], + node.states.posterior_mean, + derivation_level = 1, + ) * + child.states.prediction_precision + ) / node.states.posterior_precision + ) * child.states.value_prediction_error +end + +## Enhanced drift coupling ## +function calculate_posterior_mean_increment( + node::ContinuousStateNode, + child::ContinuousStateNode, + coupling_type::DriftCoupling, + update_type::EnhancedUpdate, +) + ( + ( + 
child.parameters.coupling_strengths[node.name] * + transform_parent_value( + child.parameters.coupling_transforms[node.name], + node.states.posterior_mean, + derivation_level = 1, + ) * + child.states.prediction_precision + ) / node.states.prediction_precision + ) * child.states.value_prediction_error + +end + +## Classic observation coupling ## +function calculate_posterior_mean_increment( + node::ContinuousStateNode, + child::ContinuousInputNode, + coupling_type::ObservationCoupling, + update_type::ClassicUpdate, +) + #For input node children with missing input + if ismissing(child.states.input_value) + #No update + return 0 + else + return (child.states.prediction_precision / node.states.posterior_precision) * + child.states.value_prediction_error + end +end + +## Enhanced observation coupling ## +function calculate_posterior_mean_increment( + node::ContinuousStateNode, + child::ContinuousInputNode, + coupling_type::ObservationCoupling, + update_type::EnhancedUpdate, +) + #For input node children with missing input + if ismissing(child.states.input_value) + #No update + return 0 + else + return (child.states.prediction_precision / node.states.prediction_precision) * + child.states.value_prediction_error + end +end + +## Classic probability coupling ## +function calculate_posterior_mean_increment( + node::ContinuousStateNode, + child::BinaryStateNode, + coupling_type::ProbabilityCoupling, + update_type::ClassicUpdate, +) + #If the posterior is missing (due to missing inputs) + if ismissing(child.states.posterior_mean) + #No update + return 0 + else + return child.parameters.coupling_strengths[node.name] / + node.states.posterior_precision * child.states.value_prediction_error + end +end + +## Enhanced Probability coupling ## +function calculate_posterior_mean_increment( + node::ContinuousStateNode, + child::BinaryStateNode, + coupling_type::ProbabilityCoupling, + update_type::EnhancedUpdate, +) + #If the posterior is missing (due to missing inputs) + if 
ismissing(child.states.posterior_mean) + #No update + return 0 + else + return child.parameters.coupling_strengths[node.name] / + node.states.prediction_precision * child.states.value_prediction_error + end +end + +## Classic Volatility coupling ## +function calculate_posterior_mean_increment( + node::ContinuousStateNode, + child::ContinuousStateNode, + coupling_type::VolatilityCoupling, + update_type::ClassicUpdate, +) + 1 / 2 * ( + child.parameters.coupling_strengths[node.name] * + child.states.effective_prediction_precision + ) / node.states.posterior_precision * child.states.precision_prediction_error +end + +## Enhanced Volatility coupling ## +function calculate_posterior_mean_increment( + node::ContinuousStateNode, + child::ContinuousStateNode, + coupling_type::VolatilityCoupling, + update_type::EnhancedUpdate, +) + 1 / 2 * ( + child.parameters.coupling_strengths[node.name] * + child.states.effective_prediction_precision + ) / node.states.prediction_precision * child.states.precision_prediction_error +end + +## Classic Noise coupling ## +function calculate_posterior_mean_increment( + node::ContinuousStateNode, + child::ContinuousInputNode, + coupling_type::NoiseCoupling, + update_type::ClassicUpdate, +) + #For input node children with missing input + if ismissing(child.states.input_value) + #No update + return 0 + else + update_term = + 1 / 2 * (child.parameters.coupling_strengths[node.name]) / + node.states.posterior_precision * child.states.precision_prediction_error + + return update_term + end +end + +## Enhanced Noise coupling ## +function calculate_posterior_mean_increment( + node::ContinuousStateNode, + child::ContinuousInputNode, + coupling_type::NoiseCoupling, + update_type::EnhancedUpdate, +) + #For input node children with missing input + if ismissing(child.states.input_value) + #No update + return 0 + else + update_term = + 1 / 2 * (child.parameters.coupling_strengths[node.name]) / + node.states.prediction_precision * 
child.states.precision_prediction_error + + return update_term + end +end + +############################################### +######## Update value prediction error ######## +############################################### + +##### Superfunction ##### +""" + update_node_value_prediction_error!(node::AbstractStateNode) + +Update the value prediction error of a single state node. +""" +function update_node_value_prediction_error!(node::ContinuousStateNode) + #Update value prediction error + node.states.value_prediction_error = calculate_value_prediction_error(node) + + return nothing +end + +@doc raw""" + calculate_value_prediction_error(node::AbstractNode) + +Calculate's a state node's value prediction error. + +Uses the equation +`` \delta_n = \mu_n - \hat{\mu}_n `` +""" +function calculate_value_prediction_error(node::ContinuousStateNode) + node.states.posterior_mean - node.states.prediction_mean +end + +################################################### +######## Update precision prediction error ######## +################################################### + +##### Superfunction ##### +""" + update_node_precision_prediction_error!(node::AbstractStateNode) + +Update the volatility prediction error of a single state node. +""" +function update_node_precision_prediction_error!(node::ContinuousStateNode) + + #Update volatility prediction error, only if there are volatility parents + node.states.precision_prediction_error = calculate_precision_prediction_error(node) + + return nothing +end + +@doc raw""" + calculate_precision_prediction_error(node::AbstractNode) + +Calculates a state node's volatility prediction error. 
+
+Uses the equation
+`` \Delta_n = \frac{\hat{\pi}_n}{\pi_n} + \hat{\pi}_n \cdot \delta_n^2-1 ``
+"""
+function calculate_precision_prediction_error(node::ContinuousStateNode)
+
+    #If there are no volatility parents
+    if length(node.edges.volatility_parents) == 0
+        #Skip
+        return missing
+    end
+
+    node.states.prediction_precision / node.states.posterior_precision +
+    node.states.prediction_precision * node.states.value_prediction_error^2 - 1
+
+end
diff --git a/src/update_hgf/nonlinear_transforms.jl b/src/update_hgf/nonlinear_transforms.jl
new file mode 100644
index 0000000..2217e17
--- /dev/null
+++ b/src/update_hgf/nonlinear_transforms.jl
@@ -0,0 +1,43 @@
+#Transformation (of various derivations) for a linear transformation
+function transform_parent_value(
+    transform_type::LinearTransform,
+    parent_value::Real;
+    derivation_level::Integer,
+)
+    if derivation_level == 0
+        return parent_value
+    elseif derivation_level == 1
+        return 1
+    elseif derivation_level == 2
+        return 0
+    else
+        @error "derivation level is misspecified"
+    end
+end
+
+#Transformation (of various derivations) for a nonlinear transformation
+function transform_parent_value(
+    transform_type::NonlinearTransform,
+    parent_value::Real;
+    derivation_level::Integer,
+)
+
+    #Get the transformation function that fits the derivation level
+    if derivation_level == 0
+        transform_function = transform_type.base_function
+    elseif derivation_level == 1
+        transform_function =
+            transform_type.first_derivation
+    elseif derivation_level == 2
+        transform_function =
+            transform_type.second_derivation
+    else
+        @error "derivation level is misspecified"
+    end
+
+    #Get the transformation parameters
+    transform_parameters = transform_type.parameters
+
+    #Transform the value
+    return transform_function(parent_value, transform_parameters)
+end
diff --git 
a/src/update_hgf/update_equations.jl b/src/update_hgf/update_equations.jl deleted file mode 100644 index d541877..0000000 --- a/src/update_hgf/update_equations.jl +++ /dev/null @@ -1,813 +0,0 @@ -####################################### -######## Continuous State Node ######## -####################################### - -######## Prediction update ######## - -### Mean update ### -@doc raw""" - calculate_prediction_mean(node::AbstractNode) - -Calculates a node's prediction mean. - -Uses the equation -`` \hat{\mu}_i=\mu_i+\sum_{j=1}^{j\;value\;parents} \mu_{j} \cdot \alpha_{i,j} `` -""" -function calculate_prediction_mean(node::AbstractNode) - #Get out value parents - value_parents = node.value_parents - - #Initialize the total drift as the basline drift plus the autoregression drift - predicted_drift = - node.parameters.drift + - node.parameters.autoregression_strength * - (node.parameters.autoregression_target - node.states.posterior_mean) - - #Add contributions from value parents - for parent in value_parents - predicted_drift += - parent.states.posterior_mean * node.parameters.value_coupling[parent.name] - end - - #Add the drift to the posterior to get the prediction mean - prediction_mean = node.states.posterior_mean + 1 * predicted_drift - - return prediction_mean -end - -### Volatility update ### -@doc raw""" - calculate_predicted_volatility(node::AbstractNode) - -Calculates a node's prediction volatility. 
- -Uses the equation -`` \nu_i =exp( \omega_i + \sum_{j=1}^{j\;volatility\;parents} \mu_{j} \cdot \kappa_{i,j}} `` -""" -function calculate_predicted_volatility(node::AbstractNode) - volatility_parents = node.volatility_parents - - predicted_volatility = node.parameters.volatility - - for parent in volatility_parents - predicted_volatility += - parent.states.posterior_mean * node.parameters.volatility_coupling[parent.name] - end - - return exp(predicted_volatility) -end - -### Precision update ### -@doc raw""" - calculate_prediction_precision(node::AbstractNode) - -Calculates a node's prediction precision. - -Uses the equation -`` \hat{\pi}_i^ = \frac{1}{\frac{1}{\pi_i}+\nu_i^} `` -""" -function calculate_prediction_precision(node::AbstractNode) - prediction_precision = - 1 / (1 / node.states.posterior_precision + node.states.predicted_volatility) - - #If the posterior precision is negative - if prediction_precision < 0 - #Throw an error - throw( - #Of the custom type where samples are rejected - RejectParameters( - "With these parameters and inputs, the prediction precision of node $(node.name) becomes negative after $(length(node.history.prediction_precision)) inputs", - ), - ) - end - - return prediction_precision -end - -@doc raw""" - calculate_volatility_weighted_prediction_precision(node::AbstractNode) - -Calculates a node's auxiliary prediction precision. - -Uses the equation -`` \gamma_i = \nu_i \cdot \hat{\pi}_i `` -""" -function calculate_volatility_weighted_prediction_precision(node::AbstractNode) - node.states.predicted_volatility * node.states.prediction_precision -end - -######## Posterior update functions ######## - -### Precision update ### -@doc raw""" - calculate_posterior_precision(node::AbstractNode) - -Calculates a node's posterior precision. 
- -Uses the equation -`` \pi_i^{'} = \hat{\pi}_i +\underbrace{\sum_{j=1}^{j\;children} \alpha_{j,i} \cdot \hat{\pi}_{j}} _\text{sum \;of \;VAPE \; continuous \; value \;chidren} `` -""" -function calculate_posterior_precision(node::AbstractNode) - value_children = node.value_children - volatility_children = node.volatility_children - - #Initialize as the node's own prediction - posterior_precision = node.states.prediction_precision - - #Add update terms from value children - for child in value_children - posterior_precision += calculate_posterior_precision_vape(node, child) - end - - #Add update terms from volatility children - for child in volatility_children - posterior_precision += calculate_posterior_precision_vope(node, child) - end - - #If the posterior precision is negative - if posterior_precision < 0 - #Throw an error - throw( - #Of the custom type where samples are rejected - RejectParameters( - "With these parameters and inputs, the posterior precision of node $(node.name) becomes negative after $(length(node.history.posterior_precision)) inputs", - ), - ) - end - - return posterior_precision -end - -@doc raw""" - calculate_posterior_precision_vape(node::AbstractNode, child::AbstractNode) - -Calculates the posterior precision update term for a single continuous value child to a state node. - -Uses the equation -`` `` -""" -function calculate_posterior_precision_vape(node::AbstractNode, child::AbstractNode) - - #For input node children with missing input - if child isa AbstractInputNode && ismissing(child.states.input_value) - #No update - return 0 - else - return child.parameters.value_coupling[node.name] * - child.states.prediction_precision - end -end - -@doc raw""" - calculate_posterior_precision_vape(node::AbstractNode, child::BinaryStateNode) - -Calculates the posterior precision update term for a single binary value child to a state node. 
- -Uses the equation -`` `` -""" -function calculate_posterior_precision_vape(node::AbstractNode, child::BinaryStateNode) - - #For missing inputs - if ismissing(child.states.posterior_mean) - #No update - return 0 - else - return child.parameters.value_coupling[node.name]^2 / - child.states.prediction_precision - end -end - -@doc raw""" - calculate_posterior_precision_vope(node::AbstractNode, child::AbstractNode) - -Calculates the posterior precision update term for a single continuous volatility child to a state node. - -Uses the equation -`` `` -""" -function calculate_posterior_precision_vope(node::AbstractNode, child::AbstractNode) - - #For input node children with missing input - if child isa AbstractInputNode && ismissing(child.states.input_value) - #No update - return 0 - else - update_term = - 1 / 2 * - ( - child.parameters.volatility_coupling[node.name] * - child.states.volatility_weighted_prediction_precision - )^2 + - child.states.volatility_prediction_error * - ( - child.parameters.volatility_coupling[node.name] * - child.states.volatility_weighted_prediction_precision - )^2 - - 1 / 2 * - child.parameters.volatility_coupling[node.name]^2 * - child.states.volatility_weighted_prediction_precision * - child.states.volatility_prediction_error - - return update_term - end -end - -### Mean update ### -@doc raw""" - calculate_posterior_mean(node::AbstractNode) - -Calculates a node's posterior mean. 
- -Uses the equation -`` `` -""" -function calculate_posterior_mean(node::AbstractNode, update_type::HGFUpdateType) - value_children = node.value_children - volatility_children = node.volatility_children - - #Initialize as the prediction - posterior_mean = node.states.prediction_mean - - #Add update terms from value children - for child in value_children - posterior_mean += - calculate_posterior_mean_value_child_increment(node, child, update_type) - end - - #Add update terms from volatility children - for child in volatility_children - posterior_mean += - calculate_posterior_mean_volatility_child_increment(node, child, update_type) - end - - return posterior_mean -end - -@doc raw""" - calculate_posterior_mean_value_child_increment(node::AbstractNode, child::AbstractNode) - -Calculates the posterior mean update term for a single continuous value child to a state node. -This is the classic HGF update. - -Uses the equation -`` `` -""" -function calculate_posterior_mean_value_child_increment( - node::AbstractNode, - child::AbstractNode, - update_type::HGFUpdateType, -) - #For input node children with missing input - if child isa AbstractInputNode && ismissing(child.states.input_value) - #No update - return 0 - else - update_term = - ( - child.parameters.value_coupling[node.name] * - child.states.prediction_precision - ) / node.states.posterior_precision * child.states.value_prediction_error - - return update_term - end -end - -@doc raw""" - calculate_posterior_mean_value_child_increment(node::AbstractNode, child::AbstractNode) - -Calculates the posterior mean update term for a single continuous value child to a state node. -This is the enhanced HGF update. 
- -Uses the equation -`` `` -""" -function calculate_posterior_mean_value_child_increment( - node::AbstractNode, - child::AbstractNode, - update_type::EnhancedUpdate, -) - #For input node children with missing input - if child isa AbstractInputNode && ismissing(child.states.input_value) - #No update - return 0 - else - update_term = - ( - child.parameters.value_coupling[node.name] * - child.states.prediction_precision - ) / node.states.prediction_precision * child.states.value_prediction_error - - return update_term - end -end - -@doc raw""" - calculate_posterior_mean_value_child_increment(node::AbstractNode, child::BinaryStateNode) - -Calculates the posterior mean update term for a single binary value child to a state node. -This is the classic HGF update. - -Uses the equation -`` `` -""" -function calculate_posterior_mean_value_child_increment( - node::AbstractNode, - child::BinaryStateNode, - update_type::HGFUpdateType, -) - #For missing inputs - if ismissing(child.states.posterior_mean) - #No update - return 0 - else - return child.parameters.value_coupling[node.name] / - (node.states.posterior_precision) * child.states.value_prediction_error - end -end - -@doc raw""" - calculate_posterior_mean_value_child_increment(node::AbstractNode, child::BinaryStateNode) - -Calculates the posterior mean update term for a single binary value child to a state node. -This is the enhanced HGF update. 
- -Uses the equation -`` `` -""" -function calculate_posterior_mean_value_child_increment( - node::AbstractNode, - child::BinaryStateNode, - update_type::EnhancedUpdate, -) - #For missing inputs - if ismissing(child.states.posterior_mean) - #No update - return 0 - else - return child.parameters.value_coupling[node.name] / - (node.states.prediction_precision) * child.states.value_prediction_error - end -end - -@doc raw""" - calculate_posterior_mean_volatility_child_increment(node::AbstractNode, child::AbstractNode) - -Calculates the posterior mean update term for a single continuos volatility child to a state node. - -Uses the equation -`` `` -""" -function calculate_posterior_mean_volatility_child_increment( - node::AbstractNode, - child::AbstractNode, - update_type::HGFUpdateType, -) - #For input node children with missing input - if child isa AbstractInputNode && ismissing(child.states.input_value) - #No update - return 0 - else - update_term = - 1 / 2 * ( - child.parameters.volatility_coupling[node.name] * - child.states.volatility_weighted_prediction_precision - ) / node.states.posterior_precision * child.states.volatility_prediction_error - - return update_term - end -end - -@doc raw""" - calculate_posterior_mean_volatility_child_increment(node::AbstractNode, child::AbstractNode) - -Calculates the posterior mean update term for a single continuos volatility child to a state node. 
- -Uses the equation -`` `` -""" -function calculate_posterior_mean_volatility_child_increment( - node::AbstractNode, - child::AbstractNode, - update_type::EnhancedUpdate, -) - #For input node children with missing input - if child isa AbstractInputNode && ismissing(child.states.input_value) - #No update - return 0 - else - update_term = - 1 / 2 * ( - child.parameters.volatility_coupling[node.name] * - child.states.volatility_weighted_prediction_precision - ) / node.states.prediction_precision * child.states.volatility_prediction_error - - return update_term - end -end - -######## Prediction error update functions ######## -@doc raw""" - calculate_value_prediction_error(node::AbstractNode) - -Calculate's a state node's value prediction error. - -Uses the equation -`` \delta_n = \mu_n - \hat{\mu}_n `` -""" -function calculate_value_prediction_error(node::AbstractNode) - node.states.posterior_mean - node.states.prediction_mean -end - -@doc raw""" - calculate_volatility_prediction_error(node::AbstractNode) - -Calculates a state node's volatility prediction error. - -Uses the equation -`` \Delta_n = \frac{\hat{\pi}_n}{\pi_n} + \hat{\pi}_n \cdot \delta_n^2-1 `` -""" -function calculate_volatility_prediction_error(node::AbstractNode) - node.states.prediction_precision / node.states.posterior_precision + - node.states.prediction_precision * node.states.value_prediction_error^2 - 1 -end - - -############################################## -######## Binary State Node Variations ######## -############################################## - -######## Prediction update ######## - -### Mean update ### -@doc raw""" - calculate_prediction_mean(node::BinaryStateNode) - -Calculates a binary state node's prediction mean. 
- -Uses the equation -`` \hat{\mu}_n= \big(1+e^{\sum_{j=1}^{j\;value \; parents} \hat{\mu}_{j}}\big)^{-1} `` -""" -function calculate_prediction_mean(node::BinaryStateNode) - value_parents = node.value_parents - - prediction_mean = 0 - - for parent in value_parents - prediction_mean += - parent.states.prediction_mean * node.parameters.value_coupling[parent.name] - end - - prediction_mean = 1 / (1 + exp(-prediction_mean)) - - return prediction_mean -end - -### Precision update ### -@doc raw""" - calculate_prediction_precision(node::BinaryStateNode) - -Calculates a binary state node's prediction precision. - -Uses the equation -`` \hat{\pi}_n = \frac{1}{\hat{\mu}_n \cdot (1-\hat{\mu}_n)} `` -""" -function calculate_prediction_precision(node::BinaryStateNode) - 1 / (node.states.prediction_mean * (1 - node.states.prediction_mean)) -end - -######## Posterior update functions ######## - -### Precision update ### -@doc raw""" - calculate_posterior_precision(node::BinaryStateNode) - -Calculates a binary node's posterior precision. - -Uses the equations - -`` \pi_n = inf `` -if the precision is infinite - - `` \pi_n = \frac{1}{\hat{\mu}_n \cdot (1-\hat{\mu}_n)} `` - if the precision is other than infinite -""" -function calculate_posterior_precision(node::BinaryStateNode) - #Extract the child - child = node.value_children[1] - - #Simple update with inifinte precision - if child isa CategoricalStateNode || child.parameters.input_precision == Inf - posterior_precision = Inf - #Update with finite precision - else - posterior_precision = - 1 / (node.states.posterior_mean * (1 - node.states.posterior_mean)) - end - - return posterior_precision -end - -### Mean update ### -@doc raw""" - calculate_posterior_mean(node::BinaryStateNode) - -Calculates a node's posterior mean. 
- -Uses the equation -`` \mu = \frac{e^{-0.5 \cdot \pi_n \cdot \eta_1^2}}{\hat{\mu}_n \cdot e^{-0.5 \cdot \pi_n \cdot \eta_1^2} \; + 1-\hat{\mu}_n \cdot e^{-0.5 \cdot \pi_n \cdot \eta_2^2}} `` -""" -function calculate_posterior_mean(node::BinaryStateNode, update_type::HGFUpdateType) - #Extract the child - child = node.value_children[1] - - #Update with categorical state node child - if child isa CategoricalStateNode - - #Find the nodes' own category number - category_number = findfirst(child.category_parent_order .== node.name) - - #Find the corresponding value in the child - posterior_mean = child.states.posterior[category_number] - - #For binary input node children - elseif child isa BinaryInputNode - - #For missing inputs - if ismissing(child.states.input_value) - #Set the posterior to missing - posterior_mean = missing - else - #Update with infinte input precision - if child.parameters.input_precision == Inf - posterior_mean = child.states.input_value - - #Update with finite input precision - else - posterior_mean = - node.states.prediction_mean * exp( - -0.5 * - node.states.prediction_precision * - child.parameters.category_means[1]^2, - ) / ( - node.states.prediction_mean * exp( - -0.5 * - node.states.prediction_precision * - child.parameters.category_means[1]^2, - ) + - (1 - node.states.prediction_mean) * exp( - -0.5 * - node.states.prediction_precision * - child.parameters.category_means[2]^2, - ) - ) - end - end - end - return posterior_mean -end - - -################################################### -######## Categorical State Node Variations ######## -################################################### -@doc raw""" - calculate_posterior(node::CategoricalStateNode) - -Calculate the posterior for a categorical state node. 
- -One hot encoding -`` \vec{u} = [0, 0, \dots ,1, \dots,0] `` -""" -function calculate_posterior(node::CategoricalStateNode) - - #Get child - child = node.value_children[1] - - #Initialize posterior as previous posterior - posterior = node.states.posterior - - #For missing inputs - if ismissing(child.states.input_value) - #Set the posterior to be all missing - posterior .= missing - - else - #Set all values to 0 - posterior .= zero(Real) - - #Set the posterior for the observed category to 1 - posterior[child.states.input_value] = 1 - end - return posterior -end - -@doc raw""" - function calculate_prediction(node::CategoricalStateNode) - -Calculate the prediction for a categorical state node. - -Uses the equation -`` \vec{\hat{\mu}_n}= \frac{\hat{\mu}_j}{\sum_{j=1}^{j\;binary \;parents} \hat{\mu}_j} `` -""" -function calculate_prediction(node::CategoricalStateNode) - - #Get parent posteriors - parent_posteriors = - map(x -> x.states.posterior_mean, collect(values(node.value_parents))) - - #Get current parent predictions - parent_predictions = - map(x -> x.states.prediction_mean, collect(values(node.value_parents))) - - #Get previous parent predictions - previous_parent_predictions = node.states.parent_predictions - - #If there was an observation - if any(!ismissing, node.states.posterior) - - #Calculate implied learning rate - implied_learning_rate = - ( - (parent_posteriors .- previous_parent_predictions) ./ - (parent_predictions .- previous_parent_predictions) - ) .- 1 - - # calculate the prediction mean - prediction = - ((implied_learning_rate .* parent_predictions) .+ 1) ./ - sum(implied_learning_rate .* parent_predictions .+ 1) - - #If there was no observation - else - #Extract prediction from last timestep - prediction = node.states.prediction - end - - return prediction, parent_predictions - -end - -@doc raw""" - calculate_value_prediction_error(node::CategoricalStateNode) - -Calculate the value prediction error for a categorical state node. 
- -Uses the equation -`` \delta_n= u - \sum_{j=1}^{j\;value\;parents} \hat{\mu}_{j} `` -""" -function calculate_value_prediction_error(node::CategoricalStateNode) - - #Get the prediction error for each category - value_prediction_error = node.states.posterior - node.states.prediction - - return value_prediction_error -end - - -################################################### -######## Conntinuous Input Node Variations ######## -################################################### - -@doc raw""" - calculate_predicted_volatility(node::AbstractInputNode) - -Calculates an input node's prediction volatility. - -Uses the equation -`` \nu_i =exp( \omega_i + \sum_{j=1}^{j\;volatility\;parents} \mu_{j} \cdot \kappa_{i,j}} `` -""" -function calculate_predicted_volatility(node::AbstractInputNode) - volatility_parents = node.volatility_parents - - predicted_volatility = node.parameters.input_noise - - for parent in volatility_parents - predicted_volatility += - parent.states.posterior_mean * node.parameters.volatility_coupling[parent.name] - end - - return exp(predicted_volatility) -end - -@doc raw""" - calculate_prediction_precision(node::AbstractInputNode) - -Calculates an input node's prediction precision. - -Uses the equation -`` \hat{\pi}_n = \frac{1}{\nu}_n `` -""" -function calculate_prediction_precision(node::AbstractInputNode) - - #Doesn't use own posterior precision - 1 / node.states.predicted_volatility -end - -""" - calculate_volatility_weighted_prediction_precision(node::AbstractInputNode) - -An input node's auxiliary prediction precision is always 1. -""" -function calculate_volatility_weighted_prediction_precision(node::AbstractInputNode) - 1 -end - -@doc raw""" - calculate_value_prediction_error(node::ContinuousInputNode) - -Calculate's an input node's value prediction error. 
- -Uses the equation -``\delta_n= u - \sum_{j=1}^{j\;value\;parents} \hat{\mu}_{j} `` -""" -function calculate_value_prediction_error(node::ContinuousInputNode) - #For missing input - if ismissing(node.states.input_value) - #Set the prediction error to missing - value_prediction_error = missing - else - #Extract parents - value_parents = node.value_parents - - #Sum the prediction_means of the parents - parents_prediction_mean = 0 - for parent in value_parents - parents_prediction_mean += parent.states.prediction_mean - end - - #Get VOPE using parents_prediction_mean instead of own - value_prediction_error = node.states.input_value - parents_prediction_mean - end - return value_prediction_error -end - -@doc raw""" - calculate_volatility_prediction_error(node::ContinuousInputNode) - -Calculates an input node's volatility prediction error. - -Uses the equation -`` \mu'_j=\sum_{j=1}^{j\;value\;parents} \mu_{j} `` -`` \pi'_j=\frac{{\sum_{j=1}^{j\;value\;parents} \pi_{j}}}{j} `` -`` \Delta_n=\frac{\hat{\pi}_n}{\pi'_j} + \hat{\mu}_i\cdot (u -\mu'_j^2 )-1 `` -""" -function calculate_volatility_prediction_error(node::ContinuousInputNode) - - #For missing input - if ismissing(node.states.input_value) - #Set the prediction error to missing - volatility_prediction_error = missing - else - #Extract parents - value_parents = node.value_parents - - #Sum the posterior mean and average the posterior precision of the value parents - parents_posterior_mean = 0 - parents_posterior_precision = 0 - - for parent in value_parents - parents_posterior_mean += parent.states.posterior_mean - parents_posterior_precision += parent.states.posterior_precision - end - - parents_posterior_precision = parents_posterior_precision / length(value_parents) - - #Get the VOPE using parents_posterior_precision and parents_posterior_mean - volatility_prediction_error = - node.states.prediction_precision / parents_posterior_precision + - node.states.prediction_precision * - (node.states.input_value - 
parents_posterior_mean)^2 - 1 - end - - return volatility_prediction_error -end - - -############################################## -######## Binary Input Node Variations ######## -############################################## -@doc raw""" - calculate_value_prediction_error(node::BinaryInputNode) - -Calculates the prediciton error of a binary input node with finite precision. - -Uses the equation -`` \delta_n= u - \sum_{j=1}^{j\;value\;parents} \hat{\mu}_{j} `` -""" -function calculate_value_prediction_error(node::BinaryInputNode) - - #For missing input - if ismissing(node.states.input_value) - #Set the prediction error to missing - value_prediction_error = [missing, missing] - else - #Substract to find the difference to each of the Gaussian means - value_prediction_error = node.parameters.category_means .- node.states.input_value - end -end - - -################################################### -######## Categorical Input Node Variations ######## -################################################### diff --git a/src/update_hgf/update_hgf.jl b/src/update_hgf/update_hgf.jl index e644d8c..05737f7 100644 --- a/src/update_hgf/update_hgf.jl +++ b/src/update_hgf/update_hgf.jl @@ -18,50 +18,51 @@ function update_hgf!( Missing, Vector{<:Union{Real,Missing}}, Dict{String,<:Union{Real,Missing}}, - }, + }; + stepsize::Real = 1, ) - ## Update node predictions from last timestep + ### Update node predictions from last timestep ### #For each node (in the opposite update order) for node in reverse(hgf.ordered_nodes.all_state_nodes) #Update its prediction from last trial - update_node_prediction!(node) + update_node_prediction!(node, stepsize) end #For each input node, in the specified update order for node in reverse(hgf.ordered_nodes.input_nodes) #Update its prediction from last trial - update_node_prediction!(node) + update_node_prediction!(node, stepsize) end - ## Supply inputs to input nodes + ### Supply inputs to input nodes ### enter_node_inputs!(hgf, inputs) - ## Update 
input node value prediction errors + ### Update input node value prediction errors ### #For each input node, in the specified update order for node in hgf.ordered_nodes.input_nodes #Update its value prediction error update_node_value_prediction_error!(node) end - ## Update input node value parent posteriors + ### Update input node value parent posteriors ### #For each node that is a value parent of an input node for node in hgf.ordered_nodes.early_update_state_nodes #Update its posterior update_node_posterior!(node, node.update_type) #And its value prediction error update_node_value_prediction_error!(node) - #And its volatility prediction error - update_node_volatility_prediction_error!(node) + #And its precision prediction error + update_node_precision_prediction_error!(node) end - ## Update input node volatility prediction errors + ### Update input node precision prediction errors ### #For each input node, in the specified update order for node in hgf.ordered_nodes.input_nodes #Update its value prediction error - update_node_volatility_prediction_error!(node) + update_node_precision_prediction_error!(node) end - ## Update remaining state nodes + ### Update remaining state nodes ### #For each state node, in the specified update order for node in hgf.ordered_nodes.late_update_state_nodes #Update its posterior @@ -69,7 +70,25 @@ function update_hgf!( #And its value prediction error update_node_value_prediction_error!(node) #And its volatility prediction error - update_node_volatility_prediction_error!(node) + update_node_precision_prediction_error!(node) + end + + ### Save the history for each node ### + #If save history is enabled + if hgf.save_history + + #Update the timepoint + push!(hgf.timesteps, hgf.timesteps[end] + stepsize) + + #Go through each node + for node in hgf.ordered_nodes.all_nodes + + #Go through each state + for state_name in fieldnames(typeof(node.states)) + #Add that state to the history + push!(getfield(node.history, state_name), 
getfield(node.states, state_name)) + end + end end return nothing @@ -83,7 +102,7 @@ Set input values in input nodes. Can either take a single value, a vector of val function enter_node_inputs!(hgf::HGF, input::Union{Real,Missing}) #Update the input node by passing the specified input to it - update_node_input!(hgf.ordered_nodes.input_nodes[1], input) + update_node_input!(first(hgf.ordered_nodes.input_nodes), input) return nothing end @@ -93,7 +112,7 @@ function enter_node_inputs!(hgf::HGF, inputs::Vector{<:Union{Real,Missing}}) #If the vector of inputs only contain a single input if length(inputs) == 1 #Just input that into the first input node - update_node_input!(hgf.ordered_nodes.input_nodes[1], inputs[1]) + enter_node_inputs!(hgf, first(inputs)) else @@ -117,3 +136,16 @@ function enter_node_inputs!(hgf::HGF, inputs::Dict{String,<:Union{Real,Missing}} return nothing end + + +""" + update_node_input!(node::AbstractInputNode, input::Union{Real,Missing}) + +Update the prediction of a single input node. +""" +function update_node_input!(node::AbstractInputNode, input::Union{Real,Missing}) + #Receive input + node.states.input_value = input + + return nothing +end diff --git a/src/update_hgf/update_node.jl b/src/update_hgf/update_node.jl deleted file mode 100644 index b8e88ae..0000000 --- a/src/update_hgf/update_node.jl +++ /dev/null @@ -1,308 +0,0 @@ -####################################### -######## Continuous State Node ######## -####################################### -""" - update_node_prediction!(node::AbstractStateNode) - -Update the prediction of a single state node. 
-""" -function update_node_prediction!(node::AbstractStateNode) - - #Update prediction mean - node.states.prediction_mean = calculate_prediction_mean(node) - push!(node.history.prediction_mean, node.states.prediction_mean) - - #Update prediction volatility - node.states.predicted_volatility = calculate_predicted_volatility(node) - push!(node.history.predicted_volatility, node.states.predicted_volatility) - - #Update prediction precision - node.states.prediction_precision = calculate_prediction_precision(node) - push!(node.history.prediction_precision, node.states.prediction_precision) - - #Get auxiliary prediction precision, only if there are volatility children and/or volatility parents - if length(node.volatility_parents) > 0 || length(node.volatility_children) > 0 - node.states.volatility_weighted_prediction_precision = - calculate_volatility_weighted_prediction_precision(node) - push!( - node.history.volatility_weighted_prediction_precision, - node.states.volatility_weighted_prediction_precision, - ) - end - - return nothing -end - -""" - update_node_posterior!(node::AbstractStateNode; update_type::HGFUpdateType) - -Update the posterior of a single continuous state node. This is the classic HGF update. -""" -function update_node_posterior!(node::AbstractStateNode, update_type::ClassicUpdate) - #Update posterior precision - node.states.posterior_precision = calculate_posterior_precision(node) - push!(node.history.posterior_precision, node.states.posterior_precision) - - #Update posterior mean - node.states.posterior_mean = calculate_posterior_mean(node, update_type) - push!(node.history.posterior_mean, node.states.posterior_mean) - - return nothing -end - -""" - update_node_posterior!(node::AbstractStateNode) - -Update the posterior of a single continuous state node. This is the enahnced HGF update. 
-""" -function update_node_posterior!(node::AbstractStateNode, update_type::EnhancedUpdate) - #Update posterior mean - node.states.posterior_mean = calculate_posterior_mean(node, update_type) - push!(node.history.posterior_mean, node.states.posterior_mean) - - #Update posterior precision - node.states.posterior_precision = calculate_posterior_precision(node) - push!(node.history.posterior_precision, node.states.posterior_precision) - - return nothing -end - -""" - update_node_value_prediction_error!(node::AbstractStateNode) - -Update the value prediction error of a single state node. -""" -function update_node_value_prediction_error!(node::AbstractStateNode) - #Update value prediction error - node.states.value_prediction_error = calculate_value_prediction_error(node) - push!(node.history.value_prediction_error, node.states.value_prediction_error) - - return nothing -end - -""" - update_node_volatility_prediction_error!(node::AbstractStateNode) - -Update the volatility prediction error of a single state node. -""" -function update_node_volatility_prediction_error!(node::AbstractStateNode) - - #Update volatility prediction error, only if there are volatility parents - if length(node.volatility_parents) > 0 - node.states.volatility_prediction_error = - calculate_volatility_prediction_error(node) - push!( - node.history.volatility_prediction_error, - node.states.volatility_prediction_error, - ) - end - - return nothing -end - - -############################################## -######## Binary State Node Variations ######## -############################################## -""" - update_node_prediction!(node::BinaryStateNode) - -Update the prediction of a single binary state node. 
-""" -function update_node_prediction!(node::BinaryStateNode) - - #Update prediction mean - node.states.prediction_mean = calculate_prediction_mean(node) - push!(node.history.prediction_mean, node.states.prediction_mean) - - #Update prediction precision - node.states.prediction_precision = calculate_prediction_precision(node) - push!(node.history.prediction_precision, node.states.prediction_precision) - - return nothing -end - -""" - update_node_volatility_prediction_error!(node::BinaryStateNode) - -There is no volatility prediction error update for binary state nodes. -""" -function update_node_volatility_prediction_error!(node::BinaryStateNode) - return nothing -end - - -################################################### -######## Categorical State Node Variations ######## -################################################### -""" - update_node_prediction!(node::CategoricalStateNode) - -Update the prediction of a single categorical state node. -""" -function update_node_prediction!(node::CategoricalStateNode) - - #Update prediction mean - node.states.prediction, node.states.parent_predictions = calculate_prediction(node) - push!(node.history.prediction, node.states.prediction) - push!(node.history.parent_predictions, node.states.parent_predictions) - return nothing -end - -""" - update_node_posterior!(node::CategoricalStateNode) - -Update the posterior of a single categorical state node. -""" -function update_node_posterior!(node::CategoricalStateNode, update_type::ClassicUpdate) - - #Update posterior mean - node.states.posterior = calculate_posterior(node) - push!(node.history.posterior, node.states.posterior) - - return nothing -end - -""" - update_node_volatility_prediction_error!(node::CategoricalStateNode) - -There is no volatility prediction error update for categorical state nodes. 
-""" -function update_node_volatility_prediction_error!(node::CategoricalStateNode) - return nothing -end - - -################################################### -######## Conntinuous Input Node Variations ######## -################################################### -""" - update_node_input!(node::AbstractInputNode, input::Union{Real,Missing}) - -Update the prediction of a single input node. -""" -function update_node_input!(node::AbstractInputNode, input::Union{Real,Missing}) - #Receive input - node.states.input_value = input - push!(node.history.input_value, node.states.input_value) - - return nothing -end - -""" - update_node_prediction!(node::AbstractInputNode) - -Update the posterior of a single input node. -""" -function update_node_prediction!(node::AbstractInputNode) - #Update prediction volatility - node.states.predicted_volatility = calculate_predicted_volatility(node) - push!(node.history.predicted_volatility, node.states.predicted_volatility) - - #Update prediction precision - node.states.prediction_precision = calculate_prediction_precision(node) - push!(node.history.prediction_precision, node.states.prediction_precision) - - return nothing -end - -""" - update_node_value_prediction_error!(node::AbstractInputNode) - -Update the value prediction error of a single input node. -""" -function update_node_value_prediction_error!(node::AbstractInputNode) - - #Calculate value prediction error - node.states.value_prediction_error = calculate_value_prediction_error(node) - push!(node.history.value_prediction_error, node.states.value_prediction_error) - - return nothing -end - -""" - update_node_volatility_prediction_error!(node::AbstractInputNode) - -Update the value prediction error of a single input node. 
-""" -function update_node_volatility_prediction_error!(node::AbstractInputNode) - - #Calculate volatility prediction error, only if there are volatility parents - if length(node.volatility_parents) > 0 - node.states.volatility_prediction_error = - calculate_volatility_prediction_error(node) - push!( - node.history.volatility_prediction_error, - node.states.volatility_prediction_error, - ) - end - - return nothing -end - - -############################################## -######## Binary Input Node Variations ######## -############################################## -""" - update_node_prediction!(node::BinaryInputNode) - -There is no prediction update for binary input nodes, as the prediction precision is constant. -""" -function update_node_prediction!(node::BinaryInputNode) - return nothing -end - -""" - update_node_value_prediction_error!(node::BinaryInputNode) - -Update the value prediction error of a single binary input node. -""" -function update_node_value_prediction_error!(node::BinaryInputNode) - - #Calculate value prediction error - node.states.value_prediction_error = calculate_value_prediction_error(node) - push!(node.history.value_prediction_error, node.states.value_prediction_error) - - return nothing -end - -""" - update_node_volatility_prediction_error!(node::BinaryInputNode) - -There is no volatility prediction error update for binary input nodes. -""" -function update_node_volatility_prediction_error!(node::BinaryInputNode) - return nothing -end - - -################################################### -######## Categorical Input Node Variations ######## -################################################### -""" - update_node_prediction!(node::CategoricalInputNode) - -There is no prediction update for categorical input nodes, as the prediction precision is constant. 
-""" -function update_node_prediction!(node::CategoricalInputNode) - return nothing -end - -""" - update_node_value_prediction_error!(node::CategoricalInputNode) - - There is no value prediction error update for categorical input nodes. -""" -function update_node_value_prediction_error!(node::CategoricalInputNode) - return nothing -end - -""" - update_node_volatility_prediction_error!(node::CategoricalInputNode) - -There is no volatility prediction error update for categorical input nodes. -""" -function update_node_volatility_prediction_error!(node::CategoricalInputNode) - return nothing -end diff --git a/src/utils/get_prediction.jl b/src/utils/get_prediction.jl index bfced30..ecdc59a 100644 --- a/src/utils/get_prediction.jl +++ b/src/utils/get_prediction.jl @@ -6,60 +6,52 @@ A single node can also be passed. """ function get_prediction end -function get_prediction(agent::Agent, node_name::String = "x1") +function get_prediction(agent::Agent, node_name::String, stepsize::Real = 1) - #Get prediction form the HGF - prediction = get_prediction(agent.substruct, node_name) + #Get prediction from the HGF + prediction = get_prediction(agent.substruct, node_name, stepsize) return prediction end -function get_prediction(hgf::HGF, node_name::String = "x1") +function get_prediction(hgf::HGF, node_name::String, stepsize::Real = 1) #Get the prediction of the given node - return get_prediction(hgf.all_nodes[node_name]) + return get_prediction(hgf.all_nodes[node_name], stepsize) end ### Single node functions ### -function get_prediction(node::AbstractNode) +function get_prediction(node::ContinuousStateNode, stepsize::Real = 1) #Save old states old_states = (; prediction_mean = node.states.prediction_mean, - predicted_volatility = node.states.predicted_volatility, prediction_precision = node.states.prediction_precision, - volatility_weighted_prediction_precision = node.states.volatility_weighted_prediction_precision, + effective_prediction_precision = 
node.states.effective_prediction_precision, ) #Update prediction mean - node.states.prediction_mean = calculate_prediction_mean(node) - - #Update prediction volatility - node.states.predicted_volatility = calculate_predicted_volatility(node) + node.states.prediction_mean = calculate_prediction_mean(node, stepsize) #Update prediction precision - node.states.prediction_precision = calculate_prediction_precision(node) - - node.states.volatility_weighted_prediction_precision = - calculate_volatility_weighted_prediction_precision(node) + node.states.prediction_precision, node.states.effective_prediction_precision = + calculate_prediction_precision(node, stepsize) #Save new states new_states = (; prediction_mean = node.states.prediction_mean, - predicted_volatility = node.states.predicted_volatility, prediction_precision = node.states.prediction_precision, - volatility_weighted_prediction_precision = node.states.volatility_weighted_prediction_precision, + effective_prediction_precision = node.states.effective_prediction_precision, ) #Change states back to the old states node.states.prediction_mean = old_states.prediction_mean - node.states.predicted_volatility = old_states.predicted_volatility node.states.prediction_precision = old_states.prediction_precision - node.states.volatility_weighted_prediction_precision = old_states.volatility_weighted_prediction_precision + node.states.effective_prediction_precision = old_states.effective_prediction_precision return new_states end -function get_prediction(node::BinaryStateNode) +function get_prediction(node::BinaryStateNode, stepsize::Real = 1) #Save old states old_states = (; @@ -86,7 +78,7 @@ function get_prediction(node::BinaryStateNode) return new_states end -function get_prediction(node::CategoricalStateNode) +function get_prediction(node::CategoricalStateNode, stepsize::Real = 1) #Save old states old_states = (; prediction = node.states.prediction) @@ -104,35 +96,32 @@ function get_prediction(node::CategoricalStateNode) 
end -function get_prediction(node::AbstractInputNode) +function get_prediction(node::ContinuousInputNode, stepsize::Real = 1) #Save old states old_states = (; - predicted_volatility = node.states.predicted_volatility, + prediction_mean = node.states.prediction_mean, prediction_precision = node.states.prediction_precision, ) - #Update prediction volatility - node.states.predicted_volatility = calculate_predicted_volatility(node) - #Update prediction precision + node.states.prediction_mean = calculate_prediction_mean(node) node.states.prediction_precision = calculate_prediction_precision(node) #Save new states new_states = (; - predicted_volatility = node.states.predicted_volatility, + prediction_mean = node.states.prediction_mean, prediction_precision = node.states.prediction_precision, - volatility_weighted_prediction_precision = 1.0, ) #Change states back to the old states - node.states.predicted_volatility = old_states.predicted_volatility + node.states.prediction_mean = old_states.prediction_mean node.states.prediction_precision = old_states.prediction_precision return new_states end -function get_prediction(node::BinaryInputNode) +function get_prediction(node::BinaryInputNode, stepsize::Real = 1) #Binary input nodes have no prediction states new_states = (;) @@ -140,7 +129,7 @@ function get_prediction(node::BinaryInputNode) return new_states end -function get_prediction(node::CategoricalInputNode) +function get_prediction(node::CategoricalInputNode, stepsize::Real = 1) #Binary input nodes have no prediction states new_states = (;) diff --git a/src/utils/get_surprise.jl b/src/utils/get_surprise.jl index 9de92f9..acdcf25 100644 --- a/src/utils/get_surprise.jl +++ b/src/utils/get_surprise.jl @@ -63,7 +63,7 @@ function get_surprise(node::ContinuousInputNode) #Sum the predictions of the vaue parents parents_prediction_mean = 0 - for parent in node.value_parents + for parent in node.edges.observation_parents parents_prediction_mean += parent.states.prediction_mean 
end @@ -87,7 +87,7 @@ function get_surprise(node::BinaryInputNode) #Sum the predictions of the vaue parents parents_prediction_mean = 0 - for parent in node.value_parents + for parent in node.edges.observation_parents parents_prediction_mean += parent.states.prediction_mean end @@ -138,7 +138,7 @@ Calculate the surprise of a categorical input node on seeing the last input. function get_surprise(node::CategoricalInputNode) #Get value parent - parent = node.value_parents[1] + parent = node.edges.observation_parents[1] #Get surprise surprise = sum(-log.(exp.(log.(parent.states.prediction) .* parent.states.posterior))) diff --git a/test/Project.toml b/test/Project.toml index 3b91bab..38e7d75 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,8 +1,12 @@ [deps] ActionModels = "320cf53b-cc3b-4b34-9a10-0ecb113566a3" +Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +Glob = "c27321d9-0574-5035-807b-f59d2c89b15c" +JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" +JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" diff --git a/test/quicktests.jl b/test/quicktests.jl index 7afa9e3..73f1da7 100644 --- a/test/quicktests.jl +++ b/test/quicktests.jl @@ -1,7 +1,7 @@ using HierarchicalGaussianFiltering -hgf = premade_hgf("continuous_2level", verbose = true) +hgf = premade_hgf("continuous_2level", verbose = false) -update_hgf!(hgf, [0.01, 0.02, 0.06]) +give_inputs!(hgf, [0.01, 0.02, 0.06]) get_states(hgf) diff --git a/test/runtests.jl b/test/runtests.jl index 524fd5d..f072d30 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,58 +1,59 @@ -using ActionModels using HierarchicalGaussianFiltering using Test +using Glob -@testset "Unit tests" begin +#Get the root path of 
the package +hgf_path = dirname(dirname(pathof(HierarchicalGaussianFiltering))) - # Test the quick tests that are used as pre-commit tests - include("quicktests.jl") +@testset "All tests" begin - # Test that the HGF gives canonical outputs - include("test_canonical.jl") + @testset "Unit tests" begin - # Test initialization - include("test_initialization.jl") + #Get the path to the testing folder + test_path = hgf_path * "/test/" - # Test premade HGF models - include("test_premade_hgf.jl") + @testset "quick tests" begin + # Test the quick tests that are used as pre-commit tests + include(test_path * "quicktests.jl") + end - # Test shared parameters - include("test_shared_parameters.jl") + # List the julia filenames in the testsuite + filenames = glob("*.jl", test_path * "testsuite") - # Test premade action models - include("test_premade_agent.jl") - - #Run fitting tests - include("test_fit_model.jl") + # For each file + for filename in filenames + #Run it + include(filename) + end + end - # Test update_hgf - # Test node_update - # Test action models - # Test update equations + @testset "Documentation tests" begin -end + #Set up path for the documentation folder + documentation_path = hgf_path * "/docs/src/" + @testset "Sourcefiles" begin -@testset "Documentation" begin + # List the julia filenames in the documentation source files folder + filenames = glob("*.jl", documentation_path * "/Julia_src_files") - #Set up path for the documentation folder - hgf_path = dirname(dirname(pathof(HierarchicalGaussianFiltering))) - documentation_path = hgf_path * "/docs/src/" + for filename in filenames + @testset "$filename" begin + include(filename) + end + end + end - @testset "tutorials" begin - - #Get path for the tutorials subfolder - tutorials_path = documentation_path * "tutorials/" - - #Classic tutorials - include(tutorials_path * "classic_binary.jl") - include(tutorials_path * "classic_usdchf.jl") - end + @testset "Tutorials" begin - @testset "sourcefiles" begin - - #Get 
path for the tutorials subfolder - sourcefiles_path = documentation_path * "Julia_src_files/" + # List the julia filenames in the tutorials folder + filenames = glob("*.jl", documentation_path * "/tutorials") + for filename in filenames + @testset "$filename" begin + include(filename) + end + end + end end end diff --git a/test/test_shared_parameters.jl b/test/test_shared_parameters.jl deleted file mode 100644 index 9231e01..0000000 --- a/test/test_shared_parameters.jl +++ /dev/null @@ -1,78 +0,0 @@ -using HierarchicalGaussianFiltering -using Test - -# Test of custom HGF with shared parameters - -#List of input nodes to create -input_nodes = Dict("name" => "u", "type" => "continuous", "input_noise" => 2) - -#List of state nodes to create -state_nodes = [ - Dict( - "name" => "x1", - "type" => "continuous", - "volatility" => 2, - "initial_mean" => 1, - "initial_precision" => 1, - ), - Dict( - "name" => "x2", - "type" => "continuous", - "volatility" => 2, - "initial_mean" => 1, - "initial_precision" => 1, - ), -] - -#List of child-parent relations -edges = [ - Dict("child" => "u", "value_parents" => ("x1", 1)), - Dict("child" => "x1", "volatility_parents" => ("x2", 1)), -] - - -# one shared parameter -shared_parameters_1 = - Dict("volatilitys" => (9, [("x1", "volatility"), ("x2", "volatility")])) - -#Initialize the HGF -hgf_1 = init_hgf( - input_nodes = input_nodes, - state_nodes = state_nodes, - edges = edges, - shared_parameters = shared_parameters_1, -) - -#get shared parameter -get_parameters(hgf_1) - -@test get_parameters(hgf_1, "volatilitys") == 9 - -#set shared parameter -set_parameters!(hgf_1, "volatilitys", 2) - -shared_parameters_2 = Dict( - "initial_means" => (9, [("x1", "initial_mean"), ("x2", "initial_mean")]), - "volatilitys" => (9, [("x1", "volatility"), ("x2", "volatility")]), -) - - -#Initialize the HGF -hgf_2 = init_hgf( - input_nodes = input_nodes, - state_nodes = state_nodes, - edges = edges, - shared_parameters = shared_parameters_2, -) - -#get 
all parameters -get_parameters(hgf_2) - -#get shared parameter -@test get_parameters(hgf_2, "volatilitys") == 9 - -#set shared parameter -set_parameters!(hgf_2, Dict("volatilitys" => -2, "initial_means" => 1)) - -@test get_parameters(hgf_2, "volatilitys") == -2 -@test get_parameters(hgf_2, "initial_means") == 1 diff --git a/test/data/canonical_binary3level.csv b/test/testsuite/data/canonical_binary3level.csv similarity index 100% rename from test/data/canonical_binary3level.csv rename to test/testsuite/data/canonical_binary3level.csv diff --git a/test/data/canonical_continuous2level_inputs.dat b/test/testsuite/data/canonical_continuous2level_inputs.dat similarity index 100% rename from test/data/canonical_continuous2level_inputs.dat rename to test/testsuite/data/canonical_continuous2level_inputs.dat diff --git a/test/data/canonical_continuous2level_states.csv b/test/testsuite/data/canonical_continuous2level_states.csv similarity index 99% rename from test/data/canonical_continuous2level_states.csv rename to test/testsuite/data/canonical_continuous2level_states.csv index 0952fc5..aeaab6f 100644 --- a/test/data/canonical_continuous2level_states.csv +++ b/test/testsuite/data/canonical_continuous2level_states.csv @@ -1,4 +1,4 @@ -,x1_mean,x1_precision,x2_mean,x2_precision +,mu1,pi1,mu2,pi2 0,1.04,10000.0,1.0,10.0 1,1.0377859183728282,19421.14485405236,0.9968176766176965,4.262927405210326 2,1.0356343652765874,27356.60292609981,0.9912464787404184,2.7209034885694074 diff --git a/test/test_canonical.jl b/test/testsuite/test_canonical.jl similarity index 60% rename from test/test_canonical.jl rename to test/testsuite/test_canonical.jl index 7a8d5e4..e97ac49 100644 --- a/test/test_canonical.jl +++ b/test/testsuite/test_canonical.jl @@ -10,7 +10,7 @@ using Plots #Get the path for the HGF superfolder hgf_path = dirname(dirname(pathof(HierarchicalGaussianFiltering))) #Add the path to the data files - data_path = hgf_path * "/test/data/" + data_path = hgf_path * 
"/test/testsuite/data/" @testset "Canonical Continuous 2level" begin @@ -36,15 +36,14 @@ using Plots ### Set up HGF ### #set parameters and starting states parameters = Dict( - ("u", "x1", "value_coupling") => 1.0, - ("x1", "x2", "volatility_coupling") => 1.0, ("u", "input_noise") => log(1e-4), - ("x1", "volatility") => -13, - ("x2", "volatility") => -2, - ("x1", "initial_mean") => 1.04, - ("x1", "initial_precision") => 1e4, - ("x2", "initial_mean") => 1.0, - ("x2", "initial_precision") => 10, + ("x", "volatility") => -13, + ("x", "initial_mean") => 1.04, + ("x", "initial_precision") => 1e4, + ("xvol", "volatility") => -2, + ("xvol", "initial_mean") => 1.0, + ("xvol", "initial_precision") => 10, + ("x", "xvol", "coupling_strength") => 1.0, "update_type" => ClassicUpdate(), ) @@ -56,26 +55,26 @@ using Plots #Construct result output dataframe result_outputs = DataFrame( - x1_mean = test_hgf.state_nodes["x1"].history.posterior_mean, - x1_precision = test_hgf.state_nodes["x1"].history.posterior_precision, - x2_mean = test_hgf.state_nodes["x2"].history.posterior_mean, - x2_precision = test_hgf.state_nodes["x2"].history.posterior_precision, + x_mean = test_hgf.state_nodes["x"].history.posterior_mean, + x_precision = test_hgf.state_nodes["x"].history.posterior_precision, + xvol_mean = test_hgf.state_nodes["xvol"].history.posterior_mean, + xvol_precision = test_hgf.state_nodes["xvol"].history.posterior_precision, ) #Test if the values are approximately the same @testset "compare output trajectories" begin for i = 1:nrow(result_outputs) - @test result_outputs.x1_mean[i] ≈ target_outputs.x1_mean[i] - @test result_outputs.x1_precision[i] ≈ target_outputs.x1_precision[i] - @test result_outputs.x2_mean[i] ≈ target_outputs.x2_mean[i] - @test result_outputs.x2_precision[i] ≈ target_outputs.x2_precision[i] + @test result_outputs.x_mean[i] ≈ target_outputs.mu1[i] + @test result_outputs.x_precision[i] ≈ target_outputs.pi1[i] + @test result_outputs.xvol_mean[i] ≈ 
target_outputs.mu2[i] + @test result_outputs.xvol_precision[i] ≈ target_outputs.pi2[i] end end @testset "Trajectory plots" begin #Make trajectory plots plot_trajectory(test_hgf, "u") - plot_trajectory!(test_hgf, ("x1", "posterior")) + plot_trajectory!(test_hgf, ("x", "posterior")) end end @@ -95,14 +94,14 @@ using Plots test_parameters = Dict( ("u", "category_means") => [0.0, 1.0], ("u", "input_precision") => Inf, - ("x2", "volatility") => -2.5, - ("x3", "volatility") => -6.0, - ("x1", "x2", "value_coupling") => 1.0, - ("x2", "x3", "volatility_coupling") => 1.0, - ("x2", "initial_mean") => 0.0, - ("x2", "initial_precision") => 1.0, - ("x3", "initial_mean") => 1.0, - ("x3", "initial_precision") => 1.0, + ("xprob", "volatility") => -2.5, + ("xvol", "volatility") => -6.0, + ("xbin", "xprob", "coupling_strength") => 1.0, + ("xprob", "xvol", "coupling_strength") => 1.0, + ("xprob", "initial_mean") => 0.0, + ("xprob", "initial_precision") => 1.0, + ("xvol", "initial_mean") => 1.0, + ("xvol", "initial_precision") => 1.0, "update_type" => ClassicUpdate(), ) @@ -114,12 +113,12 @@ using Plots #Construct result output dataframe result_outputs = DataFrame( - x1_mean = test_hgf.state_nodes["x1"].history.posterior_mean, - x1_precision = test_hgf.state_nodes["x1"].history.posterior_precision, - x2_mean = test_hgf.state_nodes["x2"].history.posterior_mean, - x2_precision = test_hgf.state_nodes["x2"].history.posterior_precision, - x3_mean = test_hgf.state_nodes["x3"].history.posterior_mean, - x3_precision = test_hgf.state_nodes["x3"].history.posterior_precision, + xbin_mean = test_hgf.state_nodes["xbin"].history.posterior_mean, + xbin_precision = test_hgf.state_nodes["xbin"].history.posterior_precision, + xprob_mean = test_hgf.state_nodes["xprob"].history.posterior_mean, + xprob_precision = test_hgf.state_nodes["xprob"].history.posterior_precision, + xvol_mean = test_hgf.state_nodes["xvol"].history.posterior_mean, + xvol_precision = 
test_hgf.state_nodes["xvol"].history.posterior_precision, ) #Remove the first row @@ -128,25 +127,25 @@ using Plots #Test if the values are approximately the same @testset "compare output trajectories" begin for i = 1:nrow(canonical_trajectory) - @test result_outputs.x1_mean[i] ≈ canonical_trajectory.mu1[i] - @test result_outputs.x1_precision[i] ≈ 1 / canonical_trajectory.sa1[i] + @test result_outputs.xbin_mean[i] ≈ canonical_trajectory.mu1[i] + @test result_outputs.xbin_precision[i] ≈ 1 / canonical_trajectory.sa1[i] @test isapprox( - result_outputs.x2_mean[i], + result_outputs.xprob_mean[i], canonical_trajectory.mu2[i], rtol = 0.1, ) @test isapprox( - result_outputs.x2_precision[i], + result_outputs.xprob_precision[i], 1 / canonical_trajectory.sa2[i], rtol = 0.1, ) @test isapprox( - result_outputs.x3_mean[i], + result_outputs.xvol_mean[i], canonical_trajectory.mu3[i], rtol = 0.1, ) @test isapprox( - result_outputs.x3_precision[i], + result_outputs.xvol_precision[i], 1 / canonical_trajectory.sa3[i], rtol = 0.1, ) @@ -156,7 +155,7 @@ using Plots @testset "Trajectory plots" begin #Make trajectory plots plot_trajectory(test_hgf, "u") - plot_trajectory!(test_hgf, ("x1", "prediction")) + plot_trajectory!(test_hgf, ("xbin", "prediction")) end end end diff --git a/test/test_fit_model.jl b/test/testsuite/test_fit_model.jl similarity index 54% rename from test/test_fit_model.jl rename to test/testsuite/test_fit_model.jl index 8b4d164..2825740 100644 --- a/test/test_fit_model.jl +++ b/test/testsuite/test_fit_model.jl @@ -18,25 +18,24 @@ using Turing test_hgf = premade_hgf("continuous_2level", verbose = false) #Create agent - test_agent = premade_agent("hgf_gaussian_action", test_hgf, verbose = false) + test_agent = premade_agent("hgf_gaussian", test_hgf, verbose = false) # Set fixed parsmeters and priors for fitting test_fixed_parameters = Dict( - ("x1", "initial_mean") => 100, - ("x2", "initial_mean") => 1.0, - ("x2", "initial_precision") => 600, - ("u", "x1", 
"value_coupling") => 1.0, - ("x1", "x2", "volatility_coupling") => 1.0, - "gaussian_action_precision" => 100, - ("x2", "volatility") => -4, + ("x", "initial_mean") => 100, + ("xvol", "initial_mean") => 1.0, + ("xvol", "initial_precision") => 600, + ("x", "xvol", "coupling_strength") => 1.0, + "action_noise" => 0.01, + ("xvol", "volatility") => -4, ("u", "input_noise") => 4, - ("x2", "drift") => 1, + ("xvol", "drift") => 1, ) test_param_priors = Dict( - ("x1", "volatility") => Normal(log(100.0), 4), - ("x1", "initial_mean") => Normal(1, sqrt(100.0)), - ("x1", "drift") => Normal(0, 1), + ("x", "volatility") => Normal(log(100.0), 4), + ("x", "initial_mean") => Normal(1, sqrt(100.0)), + ("x", "drift") => Normal(0, 1), ) #Fit single chain with defaults @@ -51,20 +50,6 @@ using Turing ) @test fitted_model isa Turing.Chains - #Fit with multiple chains and HMC - fitted_model = fit_model( - test_agent, - test_param_priors, - test_input, - test_responses; - fixed_parameters = test_fixed_parameters, - sampler = HMC(0.01, 5), - n_chains = 4, - verbose = false, - n_iterations = 10, - ) - @test fitted_model isa Turing.Chains - #Plot the parameter distribution plot_parameter_distribution(fitted_model, test_param_priors) @@ -73,7 +58,7 @@ using Turing fitted_model, test_agent, test_input, - ("x1", "posterior_mean"); + ("x", "posterior_mean"); verbose = false, n_simulations = 3, ) @@ -90,24 +75,24 @@ using Turing test_hgf = premade_hgf("binary_3level", verbose = false) #Create agent - test_agent = premade_agent("hgf_binary_softmax_action", test_hgf, verbose = false) + test_agent = premade_agent("hgf_binary_softmax", test_hgf, verbose = false) #Set fixed parameters and priors test_fixed_parameters = Dict( ("u", "category_means") => Real[0.0, 1.0], ("u", "input_precision") => Inf, - ("x2", "initial_mean") => 3.0, - ("x2", "initial_precision") => exp(2.306), - ("x3", "initial_mean") => 3.2189, - ("x3", "initial_precision") => exp(-1.0986), - ("x1", "x2", "value_coupling") => 1.0, - 
("x2", "x3", "volatility_coupling") => 1.0, - ("x3", "volatility") => -3, + ("xprob", "initial_mean") => 3.0, + ("xprob", "initial_precision") => exp(2.306), + ("xvol", "initial_mean") => 3.2189, + ("xvol", "initial_precision") => exp(-1.0986), + ("xbin", "xprob", "coupling_strength") => 1.0, + ("xprob", "xvol", "coupling_strength") => 1.0, + ("xvol", "volatility") => -3, ) test_param_priors = Dict( - "softmax_action_precision" => Truncated(Normal(100, 20), 0, Inf), - ("x2", "volatility") => Normal(-7, 5), + "action_noise" => truncated(Normal(0.01, 20), 0, Inf), + ("xprob", "volatility") => Normal(-7, 5), ) #Fit single chain with defaults @@ -122,20 +107,6 @@ using Turing ) @test fitted_model isa Turing.Chains - #Fit with multiple chains and HMC - fitted_model = fit_model( - test_agent, - test_param_priors, - test_input, - test_responses; - fixed_parameters = test_fixed_parameters, - sampler = HMC(0.01, 5), - n_chains = 4, - verbose = false, - n_iterations = 10, - ) - @test fitted_model isa Turing.Chains - #Plot the parameter distribution plot_parameter_distribution(fitted_model, test_param_priors) @@ -144,7 +115,7 @@ using Turing fitted_model, test_agent, test_input, - ("x1", "posterior_mean"), + ("xbin", "posterior_mean"), verbose = false, n_simulations = 3, ) diff --git a/test/testsuite/test_grouped_parameters.jl b/test/testsuite/test_grouped_parameters.jl new file mode 100644 index 0000000..02aeca2 --- /dev/null +++ b/test/testsuite/test_grouped_parameters.jl @@ -0,0 +1,69 @@ +using HierarchicalGaussianFiltering +using Test + +@testset "Grouped parameters" begin + + # Test of custom HGF with shared parameters + + #List of nodes + nodes = [ + ContinuousInput(name = "u", input_noise = 2), + ContinuousState( + name = "x1", + volatility = 2, + initial_mean = 1, + initial_precision = 1, + ), + ContinuousState( + name = "x2", + volatility = 2, + initial_mean = 1, + initial_precision = 1, + ), + ] + + #List of child-parent relations + edges = + Dict(("u", "x1") => 
ObservationCoupling(), ("x1", "x2") => VolatilityCoupling(1)) + + # one shared parameter + parameter_groups_1 = + [ParameterGroup("volatilities", [("x1", "volatility"), ("x2", "volatility")], 9)] + + #Initialize the HGF + hgf_1 = init_hgf(nodes = nodes, edges = edges, parameter_groups = parameter_groups_1) + + #get shared parameter + get_parameters(hgf_1) + + @test get_parameters(hgf_1, "volatilities") == 9 + + #set shared parameter + set_parameters!(hgf_1, "volatilities", 2) + + parameter_groups_2 = [ + ParameterGroup( + "initial_means", + [("x1", "initial_mean"), ("x2", "initial_mean")], + 9, + ), + ParameterGroup("volatilities", [("x1", "volatility"), ("x2", "volatility")], 9), + ] + + #Initialize the HGF + hgf_2 = init_hgf(nodes = nodes, edges = edges, parameter_groups = parameter_groups_2) + + #get all parameters + get_parameters(hgf_2) + + #get shared parameter + @test get_parameters(hgf_2, "volatilities") == 9 + + #set shared parameter + set_parameters!(hgf_2, Dict("volatilities" => -2, "initial_means" => 1)) + + @test get_parameters(hgf_2, "volatilities") == -2 + @test get_parameters(hgf_2, "initial_means") == 1 + + +end diff --git a/test/test_initialization.jl b/test/testsuite/test_initialization.jl similarity index 53% rename from test/test_initialization.jl rename to test/testsuite/test_initialization.jl index bd21fc8..56556da 100644 --- a/test/test_initialization.jl +++ b/test/testsuite/test_initialization.jl @@ -3,50 +3,45 @@ using Test @testset "Initialization" begin #Parameter values to be used for all nodes unless other values are given - node_defaults = Dict( - "volatility" => 3, - "input_noise" => -2, - "category_means" => [0, 1], - "input_precision" => Inf, - "initial_mean" => 1, - "initial_precision" => 2, - "value_coupling" => 1, - "drift" => 2, + node_defaults = NodeDefaults( + volatility = 3, + input_noise = -2, + initial_mean = 1, + initial_precision = 2, + coupling_strength = 1, + drift = 2, ) - #List of input nodes to create - input_nodes 
= [Dict("name" => "u1", "input_noise" => 2), "u2"] - - #List of state nodes to create - state_nodes = [ - "x1", - "x2", - "x3", - Dict("name" => "x4", "volatility" => 2), - Dict( - "name" => "x5", - "volatility" => 2, - "initial_mean" => 4, - "initial_precision" => 3, - "drift" => 5 + #List of nodes + nodes = [ + ContinuousInput(name = "u1", input_noise = 2), + ContinuousInput(name = "u2"), + ContinuousState(name = "x1"), + ContinuousState(name = "x2"), + ContinuousState(name = "x3"), + ContinuousState(name = "x4", volatility = 2), + ContinuousState( + name = "x5", + volatility = 2, + initial_mean = 4, + initial_precision = 3, + drift = 5, ), ] #List of child-parent relations - edges = [ - Dict("child" => "u1", "value_parents" => "x1"), - Dict("child" => "u2", "value_parents" => "x2", "volatility_parents" => "x3"), - Dict( - "child" => "x1", - "value_parents" => ("x3", 2), - "volatility_parents" => [("x4", 2), "x5"], - ), - ] + edges = Dict( + ("u1", "x1") => ObservationCoupling(), + ("u2", "x2") => ObservationCoupling(), + ("u2", "x3") => NoiseCoupling(), + ("x1", "x3") => DriftCoupling(strength = 2), + ("x1", "x4") => VolatilityCoupling(strength = 2), + ("x1", "x5") => VolatilityCoupling(), + ) #Initialize an HGF test_hgf = init_hgf( - input_nodes = input_nodes, - state_nodes = state_nodes, + nodes = nodes, edges = edges, node_defaults = node_defaults, verbose = false, @@ -65,12 +60,9 @@ using Test @test test_hgf.state_nodes["x1"].parameters.drift == 2 @test test_hgf.state_nodes["x5"].parameters.drift == 5 - @test test_hgf.input_nodes["u1"].parameters.value_coupling["x1"] == 1 - @test test_hgf.input_nodes["u2"].parameters.value_coupling["x2"] == 1 - @test test_hgf.input_nodes["u2"].parameters.volatility_coupling["x3"] == 1 - @test test_hgf.state_nodes["x1"].parameters.value_coupling["x3"] == 2 - @test test_hgf.state_nodes["x1"].parameters.volatility_coupling["x4"] == 2 - @test test_hgf.state_nodes["x1"].parameters.volatility_coupling["x5"] == 1 + @test 
test_hgf.state_nodes["x1"].parameters.coupling_strengths["x3"] == 2 + @test test_hgf.state_nodes["x1"].parameters.coupling_strengths["x4"] == 2 + @test test_hgf.state_nodes["x1"].parameters.coupling_strengths["x5"] == 1 @test test_hgf.state_nodes["x1"].states.posterior_mean == 1 @test test_hgf.state_nodes["x1"].states.posterior_precision == 2 diff --git a/test/test_premade_agent.jl b/test/testsuite/test_premade_agent.jl similarity index 85% rename from test/test_premade_agent.jl rename to test/testsuite/test_premade_agent.jl index 7811fd1..1f6961e 100644 --- a/test/test_premade_agent.jl +++ b/test/testsuite/test_premade_agent.jl @@ -4,11 +4,11 @@ using Test @testset "Premade Action Models" begin - @testset "hgf_gaussian_action" begin + @testset "hgf_gaussian" begin #Create an HGF agent with the gaussian response test_agent = premade_agent( - "hgf_gaussian_action", + "hgf_gaussian", premade_hgf("continuous_2level", verbose = false), verbose = false, ) @@ -23,11 +23,11 @@ using Test @test get_surprise(test_agent.substruct) isa Real end - @testset "hgf_binary_softmax_action" begin + @testset "hgf_binary_softmax" begin #Create HGF agent with binary softmax action test_agent = premade_agent( - "hgf_binary_softmax_action", + "hgf_binary_softmax", premade_hgf("binary_3level", verbose = false), verbose = false, ) @@ -43,11 +43,11 @@ using Test end - @testset "hgf_unit_square_sigmoid_action" begin + @testset "hgf_unit_square_sigmoid" begin #Create HGF agent with binary softmax action test_agent = premade_agent( - "hgf_unit_square_sigmoid_action", + "hgf_unit_square_sigmoid", premade_hgf("binary_3level", verbose = false), verbose = false, ) diff --git a/test/test_premade_hgf.jl b/test/testsuite/test_premade_hgf.jl similarity index 100% rename from test/test_premade_hgf.jl rename to test/testsuite/test_premade_hgf.jl