Skip to content

Commit

Permalink
made compatible with ActionModels 0.6.1
Browse files Browse the repository at this point in the history
  • Loading branch information
PTWaade committed Sep 13, 2024
1 parent 1bd04e6 commit 1734804
Show file tree
Hide file tree
Showing 7 changed files with 78 additions and 90 deletions.
9 changes: 7 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,12 @@ plot_trajectory!(agent, ("x", "prediction"))
using Distributions
prior = Dict(("xprob", "volatility") => Normal(1, 0.5))
model = fit_model(agent, prior, inputs, actions, n_iterations = 20)
#Create model
model = create_model(agent, prior, inputs, actions;)
#Fit single chain with 10 iterations
fitted_model = fit_model(model; n_iterations = 10, n_chains = 1)
````
![Image1](docs/src/images/readme/fit_model.png)
### Plot chains
Expand All @@ -106,7 +111,7 @@ plot(model)
### Plot prior angainst posterior

````@example index
plot_parameter_distribution(model, prior)
# plot_parameter_distribution(model, prior)
````
![Image1](docs/src/images/readme/prior_posterior.png)
### Get posterior
Expand Down
8 changes: 6 additions & 2 deletions docs/julia_files/index.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,11 @@ plot_trajectory!(agent, ("xbin", "prediction"))
using Distributions
prior = Dict(("xprob", "volatility") => Normal(1, 0.5))

model = fit_model(agent, prior, inputs, actions, n_iterations = 20)
#Create model
model = create_model(agent, prior, inputs, actions)

#Fit single chain with 10 iterations
fitted_model = fit_model(model; n_iterations = 10, n_chains = 1)

#-

Expand All @@ -65,7 +69,7 @@ plot(model)
#-

# ### Plot prior angainst posterior
plot_parameter_distribution(model, prior)
# plot_parameter_distribution(model, prior)
#-
# ### Get posterior
get_posteriors(model)
30 changes: 13 additions & 17 deletions docs/julia_files/tutorials/classic_binary.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,27 +81,23 @@ plot_predictive_simulation(
actions = CSV.read(data_path * "classic_binary_actions.csv", DataFrame)[!, 1];
#-
# Fit the actions
fitted_model = fit_model(
agent,
param_priors,
inputs,
actions,
fixed_parameters = fixed_parameters,
verbose = true,
n_iterations = 10,
)
#Create model
model = create_model(agent, param_priors, inputs, actions)

#Fit single chain with 10 iterations
fitted_model = fit_model(model; n_iterations = 10, n_chains = 1)
#-
#Plot the chains
plot(fitted_model)
#-
# Plot the posterior
plot_parameter_distribution(fitted_model, param_priors)
# plot_parameter_distribution(fitted_model, param_priors)
#-
# Posterior predictive plot
plot_predictive_simulation(
fitted_model,
agent,
inputs,
("xbin", "prediction_mean"),
n_simulations = 3,
)
# plot_predictive_simulation(
# fitted_model,
# agent,
# inputs,
# ("xbin", "prediction_mean"),
# n_simulations = 3,
# )
15 changes: 5 additions & 10 deletions docs/julia_files/tutorials/classic_usdchf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,19 +104,14 @@ plot_predictive_simulation(
)
#-
# Do parameter recovery
fitted_model = fit_model(
agent,
param_priors,
inputs,
actions,
fixed_parameters = fixed_parameters,
verbose = false,
n_iterations = 10,
)
model = create_model(agent, param_priors, inputs, actions)

#Fit single chain with 10 iterations
fitted_model = fit_model(model; n_iterations = 10, n_chains = 1)
#-
# Plot the chains
plot(fitted_model)
#-
# Plot prior posterior distributions
plot_parameter_distribution(fitted_model, param_priors)
# plot_parameter_distribution(fitted_model, param_priors)
#-
32 changes: 14 additions & 18 deletions docs/julia_files/user_guide/fitting_hgf_models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,23 +103,19 @@ param_priors = Dict(("xprob", "volatility") => Normal(-3.0, 0.5));

# We can fit the evolution rate by inputting the variables:

# Fit the actions
fitted_model = fit_model(
agent,
param_priors,
inputs,
actions,
fixed_parameters = fixed_parameters,
verbose = true,
n_iterations = 10,
)
# Create model
model = create_model(agent, param_priors, inputs, actions)

#Fit single chain with 10 iterations
fitted_model = fit_model(model; n_iterations = 10, n_chains = 1)

set_parameters!(agent, hgf_parameters)

# ## Plotting Functions
plot(fitted_model)

# Plot the posterior
plot_parameter_distribution(fitted_model, param_priors)
# plot_parameter_distribution(fitted_model, param_priors)


# # Predictive Simulations with plot\_predictive\_distributions()
Expand Down Expand Up @@ -151,13 +147,13 @@ fitted_model =
set_parameters!(agent, hgf_parameters)

# We can place our turing chain as a our posterior in the function, and get our posterior predictive simulation plot:
plot_predictive_simulation(
fitted_model,
agent,
inputs,
("xbin", "prediction_mean"),
n_simulations = 100,
)
# plot_predictive_simulation(
# fitted_model,
# agent,
# inputs,
# ("xbin", "prediction_mean"),
# n_simulations = 100,
# )

# We can get the posterior
get_posteriors(fitted_model)
Expand Down
68 changes: 30 additions & 38 deletions test/testsuite/test_fit_model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,30 +38,26 @@ using Turing
("x", "drift") => Normal(0, 1),
)

#Create model
model = create_model(test_agent, test_param_priors, test_input, test_responses;)

#Fit single chain with defaults
fitted_model = fit_model(
test_agent,
test_param_priors,
test_input,
test_responses;
fixed_parameters = test_fixed_parameters,
verbose = false,
n_iterations = 10,
)
@test fitted_model isa Turing.Chains
fitted_model = fit_model(model; n_iterations = 10, n_chains = 1)

@test fitted_model isa ActionModels.FitModelResults

#Plot the parameter distribution
plot_parameter_distribution(fitted_model, test_param_priors)
# plot_parameter_distribution(fitted_model, test_param_priors)

# Posterior predictive plot
plot_predictive_simulation(
fitted_model,
test_agent,
test_input,
("x", "posterior_mean");
verbose = false,
n_simulations = 3,
)
# plot_predictive_simulation(
# fitted_model,
# test_agent,
# test_input,
# ("x", "posterior_mean");
# verbose = false,
# n_simulations = 3,
# )
end


Expand Down Expand Up @@ -95,29 +91,25 @@ using Turing
("xprob", "volatility") => Normal(-7, 5),
)

#Create model
model = create_model(test_agent, test_param_priors, test_input, test_responses;)

#Fit single chain with defaults
fitted_model = fit_model(
test_agent,
test_param_priors,
test_input,
test_responses;
fixed_parameters = test_fixed_parameters,
verbose = false,
n_iterations = 10,
)
@test fitted_model isa Turing.Chains
fitted_model = fit_model(model; n_iterations = 10, n_chains = 1)

@test fitted_model isa ActionModels.FitModelResults

#Plot the parameter distribution
plot_parameter_distribution(fitted_model, test_param_priors)
# plot_parameter_distribution(fitted_model, test_param_priors)

# Posterior predictive plot
plot_predictive_simulation(
fitted_model,
test_agent,
test_input,
("xbin", "posterior_mean"),
verbose = false,
n_simulations = 3,
)
# plot_predictive_simulation(
# fitted_model,
# test_agent,
# test_input,
# ("xbin", "posterior_mean"),
# verbose = false,
# n_simulations = 3,
# )
end
end
6 changes: 3 additions & 3 deletions test/testsuite/test_premade_agent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ using Test
actions = give_inputs!(test_agent, [0.01, 0.02, 0.03])

#Check that actions are floats
@test actions isa Vector{Any}
@test actions isa Vector

#Check that get_surprise works
@test get_surprise(test_agent.substruct) isa Real
Expand All @@ -36,7 +36,7 @@ using Test
actions = give_inputs!(test_agent, [1, 0, 1])

#Check that actions are floats
@test actions isa Vector{Any}
@test actions isa Vector

#Check that get_surprise works
@test get_surprise(test_agent.substruct) isa Real
Expand All @@ -56,7 +56,7 @@ using Test
actions = give_inputs!(test_agent, [1, 0, 1])

#Check that actions are floats
@test actions isa Vector{Any}
@test actions isa Vector

#Check that get_surprise works
@test get_surprise(test_agent.substruct) isa Real
Expand Down

0 comments on commit 1734804

Please sign in to comment.