Skip to content

Commit

Permalink
Merge pull request #162 from ilabcode/dev
Browse files Browse the repository at this point in the history
new version
  • Loading branch information
PTWaade authored Sep 13, 2024
2 parents 2f99df2 + 1734804 commit 9db9b21
Show file tree
Hide file tree
Showing 13 changed files with 104 additions and 139 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ authors = [ "Peter Thestrup Waade [email protected]",
"Anna Hedvig Møller [email protected]",
"Jacopo Comoglio [email protected]",
"Christoph Mathys [email protected]"]
version = "0.5.4"
version = "0.5.5"


[deps]
Expand All @@ -13,7 +13,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"

[compat]
ActionModels = "0.5"
ActionModels = "0.6"
Distributions = "0.25"
RecipesBase = "1"
julia = "1.10"
9 changes: 7 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,12 @@ plot_trajectory!(agent, ("x", "prediction"))
using Distributions
prior = Dict(("xprob", "volatility") => Normal(1, 0.5))
model = fit_model(agent, prior, inputs, actions, n_iterations = 20)
#Create model
model = create_model(agent, prior, inputs, actions;)
#Fit single chain with 10 iterations
fitted_model = fit_model(model; n_iterations = 10, n_chains = 1)
````
![Image1](docs/src/images/readme/fit_model.png)
### Plot chains
Expand All @@ -106,7 +111,7 @@ plot(model)
### Plot prior angainst posterior

````@example index
plot_parameter_distribution(model, prior)
# plot_parameter_distribution(model, prior)
````
![Image1](docs/src/images/readme/prior_posterior.png)
### Get posterior
Expand Down
8 changes: 6 additions & 2 deletions docs/julia_files/index.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,11 @@ plot_trajectory!(agent, ("xbin", "prediction"))
using Distributions
prior = Dict(("xprob", "volatility") => Normal(1, 0.5))

model = fit_model(agent, prior, inputs, actions, n_iterations = 20)
#Create model
model = create_model(agent, prior, inputs, actions)

#Fit single chain with 10 iterations
fitted_model = fit_model(model; n_iterations = 10, n_chains = 1)

#-

Expand All @@ -65,7 +69,7 @@ plot(model)
#-

# ### Plot prior angainst posterior
plot_parameter_distribution(model, prior)
# plot_parameter_distribution(model, prior)
#-
# ### Get posterior
get_posteriors(model)
30 changes: 13 additions & 17 deletions docs/julia_files/tutorials/classic_binary.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,27 +81,23 @@ plot_predictive_simulation(
actions = CSV.read(data_path * "classic_binary_actions.csv", DataFrame)[!, 1];
#-
# Fit the actions
fitted_model = fit_model(
agent,
param_priors,
inputs,
actions,
fixed_parameters = fixed_parameters,
verbose = true,
n_iterations = 10,
)
#Create model
model = create_model(agent, param_priors, inputs, actions)

#Fit single chain with 10 iterations
fitted_model = fit_model(model; n_iterations = 10, n_chains = 1)
#-
#Plot the chains
plot(fitted_model)
#-
# Plot the posterior
plot_parameter_distribution(fitted_model, param_priors)
# plot_parameter_distribution(fitted_model, param_priors)
#-
# Posterior predictive plot
plot_predictive_simulation(
fitted_model,
agent,
inputs,
("xbin", "prediction_mean"),
n_simulations = 3,
)
# plot_predictive_simulation(
# fitted_model,
# agent,
# inputs,
# ("xbin", "prediction_mean"),
# n_simulations = 3,
# )
15 changes: 5 additions & 10 deletions docs/julia_files/tutorials/classic_usdchf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,19 +104,14 @@ plot_predictive_simulation(
)
#-
# Do parameter recovery
fitted_model = fit_model(
agent,
param_priors,
inputs,
actions,
fixed_parameters = fixed_parameters,
verbose = false,
n_iterations = 10,
)
model = create_model(agent, param_priors, inputs, actions)

#Fit single chain with 10 iterations
fitted_model = fit_model(model; n_iterations = 10, n_chains = 1)
#-
# Plot the chains
plot(fitted_model)
#-
# Plot prior posterior distributions
plot_parameter_distribution(fitted_model, param_priors)
# plot_parameter_distribution(fitted_model, param_priors)
#-
32 changes: 14 additions & 18 deletions docs/julia_files/user_guide/fitting_hgf_models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,23 +103,19 @@ param_priors = Dict(("xprob", "volatility") => Normal(-3.0, 0.5));

# We can fit the evolution rate by inputting the variables:

# Fit the actions
fitted_model = fit_model(
agent,
param_priors,
inputs,
actions,
fixed_parameters = fixed_parameters,
verbose = true,
n_iterations = 10,
)
# Create model
model = create_model(agent, param_priors, inputs, actions)

#Fit single chain with 10 iterations
fitted_model = fit_model(model; n_iterations = 10, n_chains = 1)

set_parameters!(agent, hgf_parameters)

# ## Plotting Functions
plot(fitted_model)

# Plot the posterior
plot_parameter_distribution(fitted_model, param_priors)
# plot_parameter_distribution(fitted_model, param_priors)


# # Predictive Simulations with plot\_predictive\_distributions()
Expand Down Expand Up @@ -151,13 +147,13 @@ fitted_model =
set_parameters!(agent, hgf_parameters)

# We can place our turing chain as a our posterior in the function, and get our posterior predictive simulation plot:
plot_predictive_simulation(
fitted_model,
agent,
inputs,
("xbin", "prediction_mean"),
n_simulations = 100,
)
# plot_predictive_simulation(
# fitted_model,
# agent,
# inputs,
# ("xbin", "prediction_mean"),
# n_simulations = 100,
# )

# We can get the posterior
get_posteriors(fitted_model)
Expand Down
2 changes: 1 addition & 1 deletion docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ using Documenter
using Literate


hgf_path = dirname(dirname(pathof(HierarchicalGaussianFiltering)))
hgf_path = dirname(pathof(HierarchicalGaussianFiltering))

juliafiles_path = hgf_path * "/docs/julia_files"
user_guides_path = juliafiles_path * "/user_guide"
Expand Down
File renamed without changes
File renamed without changes
30 changes: 9 additions & 21 deletions src/ActionModels_variations/utils/get_history.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,32 +85,20 @@ end

function ActionModels.get_history(node::AbstractNode)

#Initialize dictionary
state_histories = Dict()

#Go through all states in the node's history
for state_key in fieldnames(typeof(node.history))
#Get the node's name and history
node_name = node.name
node_history = node.history

#And add their histories to the output
state_histories[String(state_key)] = getproperty(node.history, state_key)
#Return a dictionary of the node's history
Dict((node_name,string(key))=>getfield(node_history, key) for key fieldnames(typeof(node_history)))

end

return state_histories
end

function ActionModels.get_history(hgf::HGF)

#Initialize dict for state histories
state_histories = Dict()

#For each node
for node in hgf.ordered_nodes.all_nodes
#Get out the histories of the node
node_histories = get_history(node)
#And merge them with the dict
merge(state_histories, node_histories)
end
#Get the histories of all nodes
merge(
[get_history(node) for node in hgf.ordered_nodes.all_nodes]...
)

return state_histories
end
39 changes: 14 additions & 25 deletions src/ActionModels_variations/utils/get_states.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,38 +100,27 @@ function ActionModels.get_states(hgf::HGF, node_name::String)
throw(ArgumentError("The node $node_name does not exist"))
end

#Initialize dict
states = Dict()

#Get out the node
node = hgf.all_nodes[node_name]
#Get the states of the node
node = get_states(hgf.all_nodes[node_name])
end

#For each state in the node
for state_key in fieldnames(typeof(node.states))
function ActionModels.get_states(node::AbstractNode)

#Add it to the dictionary
states[(node_name, String(state_key))] = get_states(node, String(state_key))
#Get the node's name and states
node_name = node.name
node_states = node.states

end
#Return a dictionary of the node's states
Dict((node_name,string(key))=>getfield(node_states, key) for key fieldnames(typeof(node_states)))

#Get its states
return states
end


### For getting all states of an HGF ###
function ActionModels.get_states(hgf::HGF)

#Initialize dict for state states
states = Dict()

#For each node
for node_name in keys(hgf.all_nodes)
#Get out the states of the node
node_states = get_states(hgf, node_name)
#And merge them with the dict
states = merge(states, node_states)
end

return states
#Get the states of all nodes
merge(
[get_states(node) for node in hgf.ordered_nodes.all_nodes]...
)

end
68 changes: 30 additions & 38 deletions test/testsuite/test_fit_model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,30 +38,26 @@ using Turing
("x", "drift") => Normal(0, 1),
)

#Create model
model = create_model(test_agent, test_param_priors, test_input, test_responses;)

#Fit single chain with defaults
fitted_model = fit_model(
test_agent,
test_param_priors,
test_input,
test_responses;
fixed_parameters = test_fixed_parameters,
verbose = false,
n_iterations = 10,
)
@test fitted_model isa Turing.Chains
fitted_model = fit_model(model; n_iterations = 10, n_chains = 1)

@test fitted_model isa ActionModels.FitModelResults

#Plot the parameter distribution
plot_parameter_distribution(fitted_model, test_param_priors)
# plot_parameter_distribution(fitted_model, test_param_priors)

# Posterior predictive plot
plot_predictive_simulation(
fitted_model,
test_agent,
test_input,
("x", "posterior_mean");
verbose = false,
n_simulations = 3,
)
# plot_predictive_simulation(
# fitted_model,
# test_agent,
# test_input,
# ("x", "posterior_mean");
# verbose = false,
# n_simulations = 3,
# )
end


Expand Down Expand Up @@ -95,29 +91,25 @@ using Turing
("xprob", "volatility") => Normal(-7, 5),
)

#Create model
model = create_model(test_agent, test_param_priors, test_input, test_responses;)

#Fit single chain with defaults
fitted_model = fit_model(
test_agent,
test_param_priors,
test_input,
test_responses;
fixed_parameters = test_fixed_parameters,
verbose = false,
n_iterations = 10,
)
@test fitted_model isa Turing.Chains
fitted_model = fit_model(model; n_iterations = 10, n_chains = 1)

@test fitted_model isa ActionModels.FitModelResults

#Plot the parameter distribution
plot_parameter_distribution(fitted_model, test_param_priors)
# plot_parameter_distribution(fitted_model, test_param_priors)

# Posterior predictive plot
plot_predictive_simulation(
fitted_model,
test_agent,
test_input,
("xbin", "posterior_mean"),
verbose = false,
n_simulations = 3,
)
# plot_predictive_simulation(
# fitted_model,
# test_agent,
# test_input,
# ("xbin", "posterior_mean"),
# verbose = false,
# n_simulations = 3,
# )
end
end
Loading

0 comments on commit 9db9b21

Please sign in to comment.