From 6d45b8a758890dc248da652fb638b957ac22b556 Mon Sep 17 00:00:00 2001 From: Peter Thestrup Waade Date: Tue, 8 Oct 2024 17:39:26 +0200 Subject: [PATCH] regorganization of actionmodels dependencies --- .../core/plot_predictive_simulation.jl | 39 ------------------- .../{core => }/create_premade_agent.jl | 0 .../{utils => }/get_history.jl | 0 .../{utils => }/get_parameters.jl | 0 .../{utils => }/get_states.jl | 0 .../{utils => }/give_inputs.jl | 0 .../{core => }/plot_trajectory.jl | 0 .../{utils => }/reset.jl | 0 .../{utils => }/set_parameters.jl | 0 .../{utils => }/set_save_history.jl | 0 src/HierarchicalGaussianFiltering.jl | 21 +++++----- test/testsuite/test_fit_model.jl | 27 +++++++++---- 12 files changed, 29 insertions(+), 58 deletions(-) delete mode 100644 src/ActionModels_variations/core/plot_predictive_simulation.jl rename src/ActionModels_variations/{core => }/create_premade_agent.jl (100%) rename src/ActionModels_variations/{utils => }/get_history.jl (100%) rename src/ActionModels_variations/{utils => }/get_parameters.jl (100%) rename src/ActionModels_variations/{utils => }/get_states.jl (100%) rename src/ActionModels_variations/{utils => }/give_inputs.jl (100%) rename src/ActionModels_variations/{core => }/plot_trajectory.jl (100%) rename src/ActionModels_variations/{utils => }/reset.jl (100%) rename src/ActionModels_variations/{utils => }/set_parameters.jl (100%) rename src/ActionModels_variations/{utils => }/set_save_history.jl (100%) diff --git a/src/ActionModels_variations/core/plot_predictive_simulation.jl b/src/ActionModels_variations/core/plot_predictive_simulation.jl deleted file mode 100644 index bf707f4..0000000 --- a/src/ActionModels_variations/core/plot_predictive_simulation.jl +++ /dev/null @@ -1,39 +0,0 @@ -""" - plot_predictive_simulation(hgf::HGF, parameter_distributions, target_state, inputs; kwargs...) - -Runs and plots results from a predictive simulation using only an HGF, instead of an agent. See the ActionModels documentation for more information. -""" -function ActionModels.plot_predictive_simulation( - hgf::HGF, - parameter_distributions, - target_state::Union{String,Tuple}, - inputs::Vector; - n_simulations::Int = 1000, - median_color::Union{String,Symbol} = :red, - title::String = "", - alpha::Real = 0.1, - linewidth::Real = 2, - verbose::Bool = true, -) - #Set an empty action model - empty_action_model = function () - return nothing - end - - #Create an agent containing the HGF - agent = init_agent(empty_action_model, hgf) - - #Run the plotting function on the agent - plot_predictive_simulation( - agent, - parameter_distributions, - target_state, - inputs; - n_simulations = n_simulations, - median_color = median_color, - title = title, - alpha = alpha, - linewidth = linewidth, - verbose = verbose, - ) -end diff --git a/src/ActionModels_variations/core/create_premade_agent.jl b/src/ActionModels_variations/create_premade_agent.jl similarity index 100% rename from src/ActionModels_variations/core/create_premade_agent.jl rename to src/ActionModels_variations/create_premade_agent.jl diff --git a/src/ActionModels_variations/utils/get_history.jl b/src/ActionModels_variations/get_history.jl similarity index 100% rename from src/ActionModels_variations/utils/get_history.jl rename to src/ActionModels_variations/get_history.jl diff --git a/src/ActionModels_variations/utils/get_parameters.jl b/src/ActionModels_variations/get_parameters.jl similarity index 100% rename from src/ActionModels_variations/utils/get_parameters.jl rename to src/ActionModels_variations/get_parameters.jl diff --git a/src/ActionModels_variations/utils/get_states.jl b/src/ActionModels_variations/get_states.jl similarity index 100% rename from src/ActionModels_variations/utils/get_states.jl rename to src/ActionModels_variations/get_states.jl diff --git a/src/ActionModels_variations/utils/give_inputs.jl b/src/ActionModels_variations/give_inputs.jl similarity index 100% rename from src/ActionModels_variations/utils/give_inputs.jl rename to src/ActionModels_variations/give_inputs.jl diff --git a/src/ActionModels_variations/core/plot_trajectory.jl b/src/ActionModels_variations/plot_trajectory.jl similarity index 100% rename from src/ActionModels_variations/core/plot_trajectory.jl rename to src/ActionModels_variations/plot_trajectory.jl diff --git a/src/ActionModels_variations/utils/reset.jl b/src/ActionModels_variations/reset.jl similarity index 100% rename from src/ActionModels_variations/utils/reset.jl rename to src/ActionModels_variations/reset.jl diff --git a/src/ActionModels_variations/utils/set_parameters.jl b/src/ActionModels_variations/set_parameters.jl similarity index 100% rename from src/ActionModels_variations/utils/set_parameters.jl rename to src/ActionModels_variations/set_parameters.jl diff --git a/src/ActionModels_variations/utils/set_save_history.jl b/src/ActionModels_variations/set_save_history.jl similarity index 100% rename from src/ActionModels_variations/utils/set_save_history.jl rename to src/ActionModels_variations/set_save_history.jl diff --git a/src/HierarchicalGaussianFiltering.jl b/src/HierarchicalGaussianFiltering.jl index 770475c..a15a723 100644 --- a/src/HierarchicalGaussianFiltering.jl +++ b/src/HierarchicalGaussianFiltering.jl @@ -7,7 +7,7 @@ using ActionModels, Distributions, RecipesBase export init_node, init_hgf, premade_hgf, check_hgf, update_hgf! export get_prediction, get_surprise export premade_agent, - init_agent, plot_predictive_simulation, plot_trajectory, plot_trajectory! + init_agent, plot_trajectory, plot_trajectory! export get_history, get_parameters, get_states, set_parameters!, reset!, give_inputs!, set_save_history! export ParameterGroup @@ -36,16 +36,15 @@ end include("create_hgf/hgf_structs.jl") #Overloading ActionModels functions -include("ActionModels_variations/core/create_premade_agent.jl") -include("ActionModels_variations/core/plot_predictive_simulation.jl") -include("ActionModels_variations/core/plot_trajectory.jl") -include("ActionModels_variations/utils/get_history.jl") -include("ActionModels_variations/utils/get_parameters.jl") -include("ActionModels_variations/utils/get_states.jl") -include("ActionModels_variations/utils/give_inputs.jl") -include("ActionModels_variations/utils/reset.jl") -include("ActionModels_variations/utils/set_parameters.jl") -include("ActionModels_variations/utils/set_save_history.jl") +include("ActionModels_variations/create_premade_agent.jl") +include("ActionModels_variations/plot_trajectory.jl") +include("ActionModels_variations/get_history.jl") +include("ActionModels_variations/get_parameters.jl") +include("ActionModels_variations/get_states.jl") +include("ActionModels_variations/give_inputs.jl") +include("ActionModels_variations/reset.jl") +include("ActionModels_variations/set_parameters.jl") +include("ActionModels_variations/set_save_history.jl") #Functions for updating the HGF include("update_hgf/update_hgf.jl") diff --git a/test/testsuite/test_fit_model.jl b/test/testsuite/test_fit_model.jl index 259210c..a30621f 100644 --- a/test/testsuite/test_fit_model.jl +++ b/test/testsuite/test_fit_model.jl @@ -1,18 +1,17 @@ using ActionModels using HierarchicalGaussianFiltering using Test -using Plots using StatsPlots using Distributions -using Turing +using ActionModels: Turing @testset "Model fitting" begin @testset "Continuous 2level" begin #Set inputs and responses - test_input = [1.0, 2, 3, 4, 5] - test_responses = [1.1, 2.2, 3.3, 4.4, 5.5] + test_input = [1.0, 1.0, 1.0, 1.0, 1.0] + test_responses = [1.0, 1.0, 1.0, 1.0, 1.0] #Create HGF test_hgf = premade_hgf("continuous_2level", verbose = false) @@ -27,15 +26,15 @@ using Turing ("xvol", "initial_precision") => 600, ("x", "xvol", "coupling_strength") => 1.0, "action_noise" => 0.01, - ("xvol", "volatility") => -4, + ("xvol", "volatility") => -10, ("u", "input_noise") => 4, ("xvol", "drift") => 1, + ("x", "drift") => Normal(0, 1), + ("x", "initial_mean") => Normal(1, 0.1), ) test_param_priors = Dict( - ("x", "volatility") => Normal(log(100.0), 4), - ("x", "initial_mean") => Normal(1, sqrt(100.0)), - ("x", "drift") => Normal(0, 1), + ("x", "volatility") => Normal(-10, 0.1), ) #Create model @@ -44,6 +43,18 @@ using Turing #Fit single chain with defaults fitted_model = fit_model(model; n_iterations = 10, n_chains = 1) + chains = fitted_model.chains + renamed_model = rename_chains(chains, model) + #Extract agent parameters + agent_parameters = extract_quantities(model, chains) + + estimates_df = get_estimates(agent_parameters) + estimates_dict = get_estimates(agent_parameters, Dict) + + #Extract state trajectories + state_trajectories = get_trajectories(model, chains, [("x", "value_prediction_error"), "action"]) + trajectory_estimates_df = get_estimates(state_trajectories) + @test fitted_model isa ActionModels.FitModelResults #Plot the parameter distribution