Skip to content

Commit

Permalink
Merge pull request #165 from ilabcode/dev
Browse files Browse the repository at this point in the history
minor update
  • Loading branch information
PTWaade authored Sep 15, 2024
2 parents 66b5bec + 6fd9ae5 commit d1b8b76
Show file tree
Hide file tree
Showing 13 changed files with 14 additions and 95 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ authors = [ "Peter Thestrup Waade [email protected]",
"Anna Hedvig Møller [email protected]",
"Jacopo Comoglio [email protected]",
"Christoph Mathys [email protected]"]
version = "0.6.0"
version = "0.6.1"


[deps]
Expand Down
4 changes: 0 additions & 4 deletions docs/julia_files/tutorials/classic_binary.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@ inputs = CSV.read(data_path * "classic_binary_inputs.csv", DataFrame)[!, 1];

# Create an HGF
hgf_parameters = Dict(
("u", "category_means") => Real[0.0, 1.0],
("u", "input_precision") => Inf,
("xprob", "volatility") => -2.5,
("xprob", "initial_mean") => 0,
("xprob", "initial_precision") => 1,
Expand Down Expand Up @@ -54,8 +52,6 @@ plot_trajectory(agent, ("xvol", "posterior"))
# Set fixed parameters
fixed_parameters = Dict(
"action_noise" => 0.2,
("u", "category_means") => Real[0.0, 1.0],
("u", "input_precision") => Inf,
("xprob", "initial_mean") => 0,
("xprob", "initial_precision") => 1,
("xvol", "initial_mean") => 1,
Expand Down
4 changes: 0 additions & 4 deletions docs/julia_files/user_guide/fitting_hgf_models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,6 @@ using HierarchicalGaussianFiltering
# We will define a binary 3-level HGF and its parameters

hgf_parameters = Dict(
("u", "category_means") => Real[0.0, 1.0],
("u", "input_precision") => Inf,
("xprob", "volatility") => -2.5,
("xprob", "initial_mean") => 0,
("xprob", "initial_precision") => 1,
Expand Down Expand Up @@ -86,8 +84,6 @@ plot_trajectory!(agent, ("xbin", "prediction"))
# Set fixed parameters. We choose to fit the evolution rate of the xprob node.
fixed_parameters = Dict(
"action_noise" => 0.2,
("u", "category_means") => Real[0.0, 1.0],
("u", "input_precision") => Inf,
("xprob", "initial_mean") => 0,
("xprob", "initial_precision") => 1,
("xvol", "initial_mean") => 1,
Expand Down
2 changes: 0 additions & 2 deletions docs/julia_files/user_guide/utility_functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,6 @@ agent_parameter = Dict("action_noise" => 0.3)
#We also specify our HGF and custom parameter settings:

hgf_parameters = Dict(
("u", "category_means") => Real[0.0, 1.0],
("u", "input_precision") => Inf,
("xprob", "volatility") => -2.5,
("xprob", "initial_mean") => 0,
("xprob", "initial_precision") => 1,
Expand Down
6 changes: 0 additions & 6 deletions src/ActionModels_variations/utils/set_parameters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,6 @@ function ActionModels.set_parameters!(
)
end

#If the param is a vector of category_means
if param_value isa Vector
#Convert it to a vector of reals
param_value = convert(Vector{Real}, param_value)
end

#Set the parameter value
setfield!(node.parameters, Symbol(param_name), param_value)

Expand Down
2 changes: 0 additions & 2 deletions src/create_hgf/hgf_structs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -340,8 +340,6 @@ end
Configuration of parameters in binary input node. Default category mean set to [0,1]
"""
Base.@kwdef mutable struct BinaryInputNodeParameters
category_means::Vector{Union{Real}} = [0, 1]
input_precision::Real = Inf
coupling_strengths::Dict{String,Real} = Dict{String,Real}()
end

Expand Down
2 changes: 0 additions & 2 deletions src/premade_models/premade_hgfs/premade_binary_2level.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@ function premade_binary_2level(config::Dict; verbose::Bool = true)

#Defaults
spec_defaults = Dict(
("u", "category_means") => [0, 1],
("u", "input_precision") => Inf,
("xprob", "volatility") => -2,
("xprob", "drift") => 0,
("xprob", "autoconnection_strength") => 1,
Expand Down
4 changes: 0 additions & 4 deletions src/premade_models/premade_hgfs/premade_binary_3level.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@ This HGF has five shared parameters:
"coupling_strengths_xprob_xvol"
# Config defaults:
- ("u", "category_means"): [0, 1]
- ("u", "input_precision"): Inf
- ("xprob", "volatility"): -2
- ("xvol", "volatility"): -2
- ("xbin", "xprob", "coupling_strength"): 1
Expand All @@ -28,8 +26,6 @@ function premade_binary_3level(config::Dict; verbose::Bool = true)

#Defaults
defaults = Dict(
("u", "category_means") => [0, 1],
("u", "input_precision") => Inf,
("xprob", "volatility") => -2,
("xprob", "drift") => 0,
("xprob", "autoconnection_strength") => 1,
Expand Down
34 changes: 3 additions & 31 deletions src/update_hgf/node_updates/binary_state_node.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,13 +99,8 @@ function calculate_posterior_precision(node::BinaryStateNode)
child = node.edges.observation_children[1]

#Simple update with infinite precision
if child.parameters.input_precision == Inf
posterior_precision = Inf
#Update with finite precision
else
posterior_precision =
1 / (node.states.posterior_mean * (1 - node.states.posterior_mean))
end
posterior_precision = Inf


## If the child is a category child ##
elseif length(node.edges.category_children) > 0
Expand Down Expand Up @@ -141,30 +136,7 @@ function calculate_posterior_mean(node::BinaryStateNode, update_type::HGFUpdateT
#Set the posterior to missing
posterior_mean = missing
else
#Update with infinite input precision
if child.parameters.input_precision == Inf
posterior_mean = child.states.input_value

#Update with finite input precision
else
posterior_mean =
node.states.prediction_mean * exp(
-0.5 *
node.states.prediction_precision *
child.parameters.category_means[1]^2,
) / (
node.states.prediction_mean * exp(
-0.5 *
node.states.prediction_precision *
child.parameters.category_means[1]^2,
) +
(1 - node.states.prediction_mean) * exp(
-0.5 *
node.states.prediction_precision *
child.parameters.category_means[2]^2,
)
)
end
posterior_mean = child.states.input_value
end

## If the child is a category child ##
Expand Down
43 changes: 9 additions & 34 deletions src/utils/get_surprise.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,40 +103,15 @@ function get_surprise(node::BinaryInputNode)
parents_prediction_mean += parent.states.prediction_mean
end

#If the input precision is infinite
if node.parameters.input_precision == Inf

#If a 0 was observed
if node.states.input_value == 0
#Get surprise
surprise = -log(1 - parents_prediction_mean)

#If a 1 was observed
elseif node.states.input_value == 1
#Get surprise
surprise = -log(parents_prediction_mean)
end

#If the input precision is finite
else
#Get the surprise
surprise =
-log(
parents_prediction_mean * pdf(
Normal(
node.parameters.category_means[1],
node.parameters.input_precision,
),
node.states.input_value,
) +
(1 - parents_prediction_mean) * pdf(
Normal(
node.parameters.category_means[2],
node.parameters.input_precision,
),
node.states.input_value,
),
)
#If a 0 was observed
if node.states.input_value == 0
#Get surprise
surprise = -log(1 - parents_prediction_mean)

#If a 1 was observed
elseif node.states.input_value == 1
#Get surprise
surprise = -log(parents_prediction_mean)
end

return surprise
Expand Down
2 changes: 1 addition & 1 deletion test/testsuite/Aqua.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
using HierarchicalGaussianFiltering
using Aqua
Aqua.test_all(HierarchicalGaussianFiltering, ambiguities = false)
Aqua.test_all(HierarchicalGaussianFiltering, ambiguities = false, persistent_tasks = false)
2 changes: 0 additions & 2 deletions test/testsuite/test_canonical.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,6 @@ using Plots

#Set parameters
test_parameters = Dict(
("u", "category_means") => [0.0, 1.0],
("u", "input_precision") => Inf,
("xprob", "volatility") => -2.5,
("xvol", "volatility") => -6.0,
("xbin", "xprob", "coupling_strength") => 1.0,
Expand Down
2 changes: 0 additions & 2 deletions test/testsuite/test_fit_model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,6 @@ using Turing

#Set fixed parameters and priors
test_fixed_parameters = Dict(
("u", "category_means") => Real[0.0, 1.0],
("u", "input_precision") => Inf,
("xprob", "initial_mean") => 3.0,
("xprob", "initial_precision") => exp(2.306),
("xvol", "initial_mean") => 3.2189,
Expand Down

2 comments on commit d1b8b76

@PTWaade
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register

Removed the noisy binary HGF

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/115222

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.6.1 -m "<description of version>" d1b8b76935aece97e760c7db8ae51692439cb41e
git push origin v0.6.1

Please sign in to comment.