Skip to content

Commit

Permalink
Merge pull request #164 from ilabcode/test
Browse files Browse the repository at this point in the history
removed category means and noisy binary
  • Loading branch information
PTWaade authored Sep 15, 2024
2 parents 9dcf733 + 3cc5b82 commit 93b5017
Show file tree
Hide file tree
Showing 12 changed files with 13 additions and 94 deletions.
4 changes: 0 additions & 4 deletions docs/julia_files/tutorials/classic_binary.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@ inputs = CSV.read(data_path * "classic_binary_inputs.csv", DataFrame)[!, 1];

# Create an HGF
hgf_parameters = Dict(
("u", "category_means") => Real[0.0, 1.0],
("u", "input_precision") => Inf,
("xprob", "volatility") => -2.5,
("xprob", "initial_mean") => 0,
("xprob", "initial_precision") => 1,
Expand Down Expand Up @@ -54,8 +52,6 @@ plot_trajectory(agent, ("xvol", "posterior"))
# Set fixed parameters
fixed_parameters = Dict(
"action_noise" => 0.2,
("u", "category_means") => Real[0.0, 1.0],
("u", "input_precision") => Inf,
("xprob", "initial_mean") => 0,
("xprob", "initial_precision") => 1,
("xvol", "initial_mean") => 1,
Expand Down
4 changes: 0 additions & 4 deletions docs/julia_files/user_guide/fitting_hgf_models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,6 @@ using HierarchicalGaussianFiltering
# We will define a binary 3-level HGF and its parameters

hgf_parameters = Dict(
("u", "category_means") => Real[0.0, 1.0],
("u", "input_precision") => Inf,
("xprob", "volatility") => -2.5,
("xprob", "initial_mean") => 0,
("xprob", "initial_precision") => 1,
Expand Down Expand Up @@ -86,8 +84,6 @@ plot_trajectory!(agent, ("xbin", "prediction"))
# Set fixed parameters. We choose to fit the evolution rate of the xprob node.
fixed_parameters = Dict(
"action_noise" => 0.2,
("u", "category_means") => Real[0.0, 1.0],
("u", "input_precision") => Inf,
("xprob", "initial_mean") => 0,
("xprob", "initial_precision") => 1,
("xvol", "initial_mean") => 1,
Expand Down
2 changes: 0 additions & 2 deletions docs/julia_files/user_guide/utility_functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,6 @@ agent_parameter = Dict("action_noise" => 0.3)
#We also specify our HGF and custom parameter settings:

hgf_parameters = Dict(
("u", "category_means") => Real[0.0, 1.0],
("u", "input_precision") => Inf,
("xprob", "volatility") => -2.5,
("xprob", "initial_mean") => 0,
("xprob", "initial_precision") => 1,
Expand Down
6 changes: 0 additions & 6 deletions src/ActionModels_variations/utils/set_parameters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,6 @@ function ActionModels.set_parameters!(
)
end

#If the param is a vector of category_means
if param_value isa Vector
#Convert it to a vector of reals
param_value = convert(Vector{Real}, param_value)
end

#Set the parameter value
setfield!(node.parameters, Symbol(param_name), param_value)

Expand Down
2 changes: 0 additions & 2 deletions src/create_hgf/hgf_structs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -340,8 +340,6 @@ end
Configuration of parameters in binary input node.
"""
Base.@kwdef mutable struct BinaryInputNodeParameters
category_means::Vector{Union{Real}} = [0, 1]
input_precision::Real = Inf
coupling_strengths::Dict{String,Real} = Dict{String,Real}()
end

Expand Down
2 changes: 0 additions & 2 deletions src/premade_models/premade_hgfs/premade_binary_2level.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@ function premade_binary_2level(config::Dict; verbose::Bool = true)

#Defaults
spec_defaults = Dict(
("u", "category_means") => [0, 1],
("u", "input_precision") => Inf,
("xprob", "volatility") => -2,
("xprob", "drift") => 0,
("xprob", "autoconnection_strength") => 1,
Expand Down
4 changes: 0 additions & 4 deletions src/premade_models/premade_hgfs/premade_binary_3level.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@ This HGF has five shared parameters:
"coupling_strengths_xprob_xvol"
# Config defaults:
- ("u", "category_means"): [0, 1]
- ("u", "input_precision"): Inf
- ("xprob", "volatility"): -2
- ("xvol", "volatility"): -2
- ("xbin", "xprob", "coupling_strength"): 1
Expand All @@ -28,8 +26,6 @@ function premade_binary_3level(config::Dict; verbose::Bool = true)

#Defaults
defaults = Dict(
("u", "category_means") => [0, 1],
("u", "input_precision") => Inf,
("xprob", "volatility") => -2,
("xprob", "drift") => 0,
("xprob", "autoconnection_strength") => 1,
Expand Down
34 changes: 3 additions & 31 deletions src/update_hgf/node_updates/binary_state_node.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,13 +99,8 @@ function calculate_posterior_precision(node::BinaryStateNode)
child = node.edges.observation_children[1]

#Simple update with infinite precision
if child.parameters.input_precision == Inf
posterior_precision = Inf
#Update with finite precision
else
posterior_precision =
1 / (node.states.posterior_mean * (1 - node.states.posterior_mean))
end
posterior_precision = Inf


## If the child is a category child ##
elseif length(node.edges.category_children) > 0
Expand Down Expand Up @@ -141,30 +136,7 @@ function calculate_posterior_mean(node::BinaryStateNode, update_type::HGFUpdateT
#Set the posterior to missing
posterior_mean = missing
else
#Update with infinite input precision
if child.parameters.input_precision == Inf
posterior_mean = child.states.input_value

#Update with finite input precision
else
posterior_mean =
node.states.prediction_mean * exp(
-0.5 *
node.states.prediction_precision *
child.parameters.category_means[1]^2,
) / (
node.states.prediction_mean * exp(
-0.5 *
node.states.prediction_precision *
child.parameters.category_means[1]^2,
) +
(1 - node.states.prediction_mean) * exp(
-0.5 *
node.states.prediction_precision *
child.parameters.category_means[2]^2,
)
)
end
posterior_mean = child.states.input_value
end

## If the child is a category child ##
Expand Down
43 changes: 9 additions & 34 deletions src/utils/get_surprise.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,40 +103,15 @@ function get_surprise(node::BinaryInputNode)
parents_prediction_mean += parent.states.prediction_mean
end

#If the input precision is infinite
if node.parameters.input_precision == Inf

#If a 0 was observed
if node.states.input_value == 0
#Get surprise
surprise = -log(1 - parents_prediction_mean)

#If a 1 was observed
elseif node.states.input_value == 1
#Get surprise
surprise = -log(parents_prediction_mean)
end

#If the input precision is finite
else
#Get the surprise
surprise =
-log(
parents_prediction_mean * pdf(
Normal(
node.parameters.category_means[1],
node.parameters.input_precision,
),
node.states.input_value,
) +
(1 - parents_prediction_mean) * pdf(
Normal(
node.parameters.category_means[2],
node.parameters.input_precision,
),
node.states.input_value,
),
)
#If a 0 was observed
if node.states.input_value == 0
#Get surprise
surprise = -log(1 - parents_prediction_mean)

#If a 1 was observed
elseif node.states.input_value == 1
#Get surprise
surprise = -log(parents_prediction_mean)
end

return surprise
Expand Down
2 changes: 1 addition & 1 deletion test/testsuite/Aqua.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
using HierarchicalGaussianFiltering
using Aqua
Aqua.test_all(HierarchicalGaussianFiltering, ambiguities = false)
Aqua.test_all(HierarchicalGaussianFiltering, ambiguities = false, persistent_tasks = false)
2 changes: 0 additions & 2 deletions test/testsuite/test_canonical.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,6 @@ using Plots

#Set parameters
test_parameters = Dict(
("u", "category_means") => [0.0, 1.0],
("u", "input_precision") => Inf,
("xprob", "volatility") => -2.5,
("xvol", "volatility") => -6.0,
("xbin", "xprob", "coupling_strength") => 1.0,
Expand Down
2 changes: 0 additions & 2 deletions test/testsuite/test_fit_model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,6 @@ using Turing

#Set fixed parameters and priors
test_fixed_parameters = Dict(
("u", "category_means") => Real[0.0, 1.0],
("u", "input_precision") => Inf,
("xprob", "initial_mean") => 3.0,
("xprob", "initial_precision") => exp(2.306),
("xvol", "initial_mean") => 3.2189,
Expand Down

0 comments on commit 93b5017

Please sign in to comment.