Skip to content

Commit

Permalink
regorganization of actionmodels dependencies
Browse files Browse the repository at this point in the history
  • Loading branch information
PTWaade committed Oct 8, 2024
1 parent 6fd9ae5 commit 6d45b8a
Show file tree
Hide file tree
Showing 12 changed files with 29 additions and 58 deletions.
39 changes: 0 additions & 39 deletions src/ActionModels_variations/core/plot_predictive_simulation.jl

This file was deleted.

File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
21 changes: 10 additions & 11 deletions src/HierarchicalGaussianFiltering.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ using ActionModels, Distributions, RecipesBase
export init_node, init_hgf, premade_hgf, check_hgf, update_hgf!
export get_prediction, get_surprise
export premade_agent,
init_agent, plot_predictive_simulation, plot_trajectory, plot_trajectory!
init_agent, plot_trajectory, plot_trajectory!
export get_history,
get_parameters, get_states, set_parameters!, reset!, give_inputs!, set_save_history!
export ParameterGroup
Expand Down Expand Up @@ -36,16 +36,15 @@ end
include("create_hgf/hgf_structs.jl")

#Overloading ActionModels functions
include("ActionModels_variations/core/create_premade_agent.jl")
include("ActionModels_variations/core/plot_predictive_simulation.jl")
include("ActionModels_variations/core/plot_trajectory.jl")
include("ActionModels_variations/utils/get_history.jl")
include("ActionModels_variations/utils/get_parameters.jl")
include("ActionModels_variations/utils/get_states.jl")
include("ActionModels_variations/utils/give_inputs.jl")
include("ActionModels_variations/utils/reset.jl")
include("ActionModels_variations/utils/set_parameters.jl")
include("ActionModels_variations/utils/set_save_history.jl")
include("ActionModels_variations/create_premade_agent.jl")
include("ActionModels_variations/plot_trajectory.jl")
include("ActionModels_variations/get_history.jl")
include("ActionModels_variations/get_parameters.jl")
include("ActionModels_variations/get_states.jl")
include("ActionModels_variations/give_inputs.jl")
include("ActionModels_variations/reset.jl")
include("ActionModels_variations/set_parameters.jl")
include("ActionModels_variations/set_save_history.jl")

#Functions for updating the HGF
include("update_hgf/update_hgf.jl")
Expand Down
27 changes: 19 additions & 8 deletions test/testsuite/test_fit_model.jl
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
using ActionModels
using HierarchicalGaussianFiltering
using Test
using Plots
using StatsPlots
using Distributions
using Turing
using ActionModels: Turing

@testset "Model fitting" begin

@testset "Continuous 2level" begin

#Set inputs and responses
test_input = [1.0, 2, 3, 4, 5]
test_responses = [1.1, 2.2, 3.3, 4.4, 5.5]
test_input = [1.0, 1.0, 1.0, 1.0, 1.0]
test_responses = [1.0, 1.0, 1.0, 1.0, 1.0]

#Create HGF
test_hgf = premade_hgf("continuous_2level", verbose = false)
Expand All @@ -27,15 +26,15 @@ using Turing
("xvol", "initial_precision") => 600,
("x", "xvol", "coupling_strength") => 1.0,
"action_noise" => 0.01,
("xvol", "volatility") => -4,
("xvol", "volatility") => -10,
("u", "input_noise") => 4,
("xvol", "drift") => 1,
("x", "drift") => Normal(0, 1),
("x", "initial_mean") => Normal(1, 0.1),
)

test_param_priors = Dict(
("x", "volatility") => Normal(log(100.0), 4),
("x", "initial_mean") => Normal(1, sqrt(100.0)),
("x", "drift") => Normal(0, 1),
("x", "volatility") => Normal(-10, 0.1),
)

#Create model
Expand All @@ -44,6 +43,18 @@ using Turing
#Fit single chain with defaults
fitted_model = fit_model(model; n_iterations = 10, n_chains = 1)

chains = fitted_model.chains
renamed_model = rename_chains(chains, model)
#Extract agent parameters
agent_parameters = extract_quantities(model, chains)

estimates_df = get_estimates(agent_parameters)
estimates_dict = get_estimates(agent_parameters, Dict)

#Extract state trajectories
state_trajectories = get_trajectories(model, chains, [("x", "value_prediction_error"), "action"])
trajectory_estimates_df = get_estimates(state_trajectories)

@test fitted_model isa ActionModels.FitModelResults

#Plot the parameter distribution
Expand Down

0 comments on commit 6d45b8a

Please sign in to comment.