From 3cc5b82002bd5f3cd4467a631bf36027faaa3842 Mon Sep 17 00:00:00 2001
From: Peter Thestrup Waade
Date: Sun, 15 Sep 2024 11:47:51 +0200
Subject: [PATCH] removed category means and noisy binary

---
 docs/julia_files/tutorials/classic_binary.jl  |  4 --
 .../user_guide/fitting_hgf_models.jl          |  4 --
 .../user_guide/utility_functions.jl           |  2 -
 .../utils/set_parameters.jl                   |  6 ---
 src/create_hgf/hgf_structs.jl                 |  2 -
 .../premade_hgfs/premade_binary_2level.jl     |  2 -
 .../premade_hgfs/premade_binary_3level.jl     |  4 --
 .../node_updates/binary_state_node.jl         | 34 ++-------------
 src/utils/get_surprise.jl                     | 43 ++++---------------
 test/testsuite/Aqua.jl                        |  2 +-
 test/testsuite/test_canonical.jl              |  2 -
 test/testsuite/test_fit_model.jl              |  2 -
 12 files changed, 13 insertions(+), 94 deletions(-)

diff --git a/docs/julia_files/tutorials/classic_binary.jl b/docs/julia_files/tutorials/classic_binary.jl
index f07d2c8..7ecbf51 100644
--- a/docs/julia_files/tutorials/classic_binary.jl
+++ b/docs/julia_files/tutorials/classic_binary.jl
@@ -21,8 +21,6 @@ inputs = CSV.read(data_path * "classic_binary_inputs.csv", DataFrame)[!, 1];
 
 # Create an HGF
 hgf_parameters = Dict(
-    ("u", "category_means") => Real[0.0, 1.0],
-    ("u", "input_precision") => Inf,
     ("xprob", "volatility") => -2.5,
     ("xprob", "initial_mean") => 0,
     ("xprob", "initial_precision") => 1,
@@ -54,8 +52,6 @@ plot_trajectory(agent, ("xvol", "posterior"))
 # Set fixed parameters
 fixed_parameters = Dict(
     "action_noise" => 0.2,
-    ("u", "category_means") => Real[0.0, 1.0],
-    ("u", "input_precision") => Inf,
     ("xprob", "initial_mean") => 0,
     ("xprob", "initial_precision") => 1,
     ("xvol", "initial_mean") => 1,
diff --git a/docs/julia_files/user_guide/fitting_hgf_models.jl b/docs/julia_files/user_guide/fitting_hgf_models.jl
index f742b82..2fa0382 100644
--- a/docs/julia_files/user_guide/fitting_hgf_models.jl
+++ b/docs/julia_files/user_guide/fitting_hgf_models.jl
@@ -45,8 +45,6 @@ using HierarchicalGaussianFiltering
 
 # We will define a binary 3-level HGF and its parameters
 hgf_parameters = Dict(
-    ("u", "category_means") => Real[0.0, 1.0],
-    ("u", "input_precision") => Inf,
     ("xprob", "volatility") => -2.5,
     ("xprob", "initial_mean") => 0,
     ("xprob", "initial_precision") => 1,
@@ -86,8 +84,6 @@ plot_trajectory!(agent, ("xbin", "prediction"))
 # Set fixed parameters. We choose to fit the evolution rate of the xprob node.
 fixed_parameters = Dict(
     "action_noise" => 0.2,
-    ("u", "category_means") => Real[0.0, 1.0],
-    ("u", "input_precision") => Inf,
     ("xprob", "initial_mean") => 0,
     ("xprob", "initial_precision") => 1,
     ("xvol", "initial_mean") => 1,
diff --git a/docs/julia_files/user_guide/utility_functions.jl b/docs/julia_files/user_guide/utility_functions.jl
index cdd278c..941ee58 100644
--- a/docs/julia_files/user_guide/utility_functions.jl
+++ b/docs/julia_files/user_guide/utility_functions.jl
@@ -62,8 +62,6 @@ agent_parameter = Dict("action_noise" => 0.3)
 
 #We also specify our HGF and custom parameter settings:
 hgf_parameters = Dict(
-    ("u", "category_means") => Real[0.0, 1.0],
-    ("u", "input_precision") => Inf,
     ("xprob", "volatility") => -2.5,
     ("xprob", "initial_mean") => 0,
     ("xprob", "initial_precision") => 1,
diff --git a/src/ActionModels_variations/utils/set_parameters.jl b/src/ActionModels_variations/utils/set_parameters.jl
index baa6402..b59f760 100644
--- a/src/ActionModels_variations/utils/set_parameters.jl
+++ b/src/ActionModels_variations/utils/set_parameters.jl
@@ -40,12 +40,6 @@ function ActionModels.set_parameters!(
         )
     end
 
-    #If the param is a vector of category_means
-    if param_value isa Vector
-        #Convert it to a vector of reals
-        param_value = convert(Vector{Real}, param_value)
-    end
-
     #Set the parameter value
     setfield!(node.parameters, Symbol(param_name), param_value)
 
diff --git a/src/create_hgf/hgf_structs.jl b/src/create_hgf/hgf_structs.jl
index 5c8e37e..fc34d82 100644
--- a/src/create_hgf/hgf_structs.jl
+++ b/src/create_hgf/hgf_structs.jl
@@ -340,8 +340,6 @@ end
 Configuration of parameters in binary input node. Default category mean set to [0,1]
 """
 Base.@kwdef mutable struct BinaryInputNodeParameters
-    category_means::Vector{Union{Real}} = [0, 1]
-    input_precision::Real = Inf
     coupling_strengths::Dict{String,Real} = Dict{String,Real}()
 end
 
diff --git a/src/premade_models/premade_hgfs/premade_binary_2level.jl b/src/premade_models/premade_hgfs/premade_binary_2level.jl
index 49750d9..eb59a08 100644
--- a/src/premade_models/premade_hgfs/premade_binary_2level.jl
+++ b/src/premade_models/premade_hgfs/premade_binary_2level.jl
@@ -11,8 +11,6 @@ function premade_binary_2level(config::Dict; verbose::Bool = true)
 
     #Defaults
     spec_defaults = Dict(
-        ("u", "category_means") => [0, 1],
-        ("u", "input_precision") => Inf,
         ("xprob", "volatility") => -2,
         ("xprob", "drift") => 0,
         ("xprob", "autoconnection_strength") => 1,
diff --git a/src/premade_models/premade_hgfs/premade_binary_3level.jl b/src/premade_models/premade_hgfs/premade_binary_3level.jl
index c04f733..43226f2 100644
--- a/src/premade_models/premade_hgfs/premade_binary_3level.jl
+++ b/src/premade_models/premade_hgfs/premade_binary_3level.jl
@@ -13,8 +13,6 @@ This HGF has five shared parameters:
 "coupling_strengths_xprob_xvol"
 
 # Config defaults:
- - ("u", "category_means"): [0, 1]
- - ("u", "input_precision"): Inf
 - ("xprob", "volatility"): -2
 - ("xvol", "volatility"): -2
 - ("xbin", "xprob", "coupling_strength"): 1
@@ -28,8 +26,6 @@ function premade_binary_3level(config::Dict; verbose::Bool = true)
 
     #Defaults
     defaults = Dict(
-        ("u", "category_means") => [0, 1],
-        ("u", "input_precision") => Inf,
         ("xprob", "volatility") => -2,
         ("xprob", "drift") => 0,
         ("xprob", "autoconnection_strength") => 1,
diff --git a/src/update_hgf/node_updates/binary_state_node.jl b/src/update_hgf/node_updates/binary_state_node.jl
index 4e80214..9f7acca 100644
--- a/src/update_hgf/node_updates/binary_state_node.jl
+++ b/src/update_hgf/node_updates/binary_state_node.jl
@@ -99,13 +99,8 @@ function calculate_posterior_precision(node::BinaryStateNode)
         child = node.edges.observation_children[1]
 
         #Simple update with inifinte precision
-        if child.parameters.input_precision == Inf
-            posterior_precision = Inf
-            #Update with finite precision
-        else
-            posterior_precision =
-                1 / (node.states.posterior_mean * (1 - node.states.posterior_mean))
-        end
+        posterior_precision = Inf
+
 
         ## If the child is a category child ##
     elseif length(node.edges.category_children) > 0
@@ -141,30 +136,7 @@ function calculate_posterior_mean(node::BinaryStateNode, update_type::HGFUpdateT
             #Set the posterior to missing
             posterior_mean = missing
         else
-            #Update with infinte input precision
-            if child.parameters.input_precision == Inf
-                posterior_mean = child.states.input_value
-
-                #Update with finite input precision
-            else
-                posterior_mean =
-                    node.states.prediction_mean * exp(
-                        -0.5 *
-                        node.states.prediction_precision *
-                        child.parameters.category_means[1]^2,
-                    ) / (
-                        node.states.prediction_mean * exp(
-                            -0.5 *
-                            node.states.prediction_precision *
-                            child.parameters.category_means[1]^2,
-                        ) +
-                        (1 - node.states.prediction_mean) * exp(
-                            -0.5 *
-                            node.states.prediction_precision *
-                            child.parameters.category_means[2]^2,
-                        )
-                    )
-            end
+            posterior_mean = child.states.input_value
         end
 
         ## If the child is a category child ##
diff --git a/src/utils/get_surprise.jl b/src/utils/get_surprise.jl
index 24b1667..1518b9f 100644
--- a/src/utils/get_surprise.jl
+++ b/src/utils/get_surprise.jl
@@ -103,40 +103,15 @@ function get_surprise(node::BinaryInputNode)
         parents_prediction_mean += parent.states.prediction_mean
     end
 
-    #If the input precision is infinite
-    if node.parameters.input_precision == Inf
-
-        #If a 1 was observed
-        if node.states.input_value == 0
-            #Get surprise
-            surprise = -log(1 - parents_prediction_mean)
-
-            #If a 0 was observed
-        elseif node.states.input_value == 1
-            #Get surprise
-            surprise = -log(parents_prediction_mean)
-        end
-
-        #If the input precision is finite
-    else
-        #Get the surprise
-        surprise =
-            -log(
-                parents_prediction_mean * pdf(
-                    Normal(
-                        node.parameters.category_means[1],
-                        node.parameters.input_precision,
-                    ),
-                    node.states.input_value,
-                ) +
-                (1 - parents_prediction_mean) * pdf(
-                    Normal(
-                        node.parameters.category_means[2],
-                        node.parameters.input_precision,
-                    ),
-                    node.states.input_value,
-                ),
-            )
+    #If a 1 was observed
+    if node.states.input_value == 0
+        #Get surprise
+        surprise = -log(1 - parents_prediction_mean)
+
+        #If a 0 was observed
+    elseif node.states.input_value == 1
+        #Get surprise
+        surprise = -log(parents_prediction_mean)
     end
 
     return surprise
diff --git a/test/testsuite/Aqua.jl b/test/testsuite/Aqua.jl
index ed45564..05dd27d 100644
--- a/test/testsuite/Aqua.jl
+++ b/test/testsuite/Aqua.jl
@@ -1,3 +1,3 @@
 using HierarchicalGaussianFiltering
 using Aqua
-Aqua.test_all(HierarchicalGaussianFiltering, ambiguities = false)
+Aqua.test_all(HierarchicalGaussianFiltering, ambiguities = false, persistent_tasks = false)
diff --git a/test/testsuite/test_canonical.jl b/test/testsuite/test_canonical.jl
index e97ac49..ab57c81 100644
--- a/test/testsuite/test_canonical.jl
+++ b/test/testsuite/test_canonical.jl
@@ -92,8 +92,6 @@ using Plots
 
     #Set parameters
     test_parameters = Dict(
-        ("u", "category_means") => [0.0, 1.0],
-        ("u", "input_precision") => Inf,
         ("xprob", "volatility") => -2.5,
         ("xvol", "volatility") => -6.0,
         ("xbin", "xprob", "coupling_strength") => 1.0,
diff --git a/test/testsuite/test_fit_model.jl b/test/testsuite/test_fit_model.jl
index ee5f716..259210c 100644
--- a/test/testsuite/test_fit_model.jl
+++ b/test/testsuite/test_fit_model.jl
@@ -75,8 +75,6 @@ using Turing
 
    #Set fixed parameters and priors
    test_fixed_parameters = Dict(
-        ("u", "category_means") => Real[0.0, 1.0],
-        ("u", "input_precision") => Inf,
         ("xprob", "initial_mean") => 3.0,
         ("xprob", "initial_precision") => exp(2.306),
         ("xvol", "initial_mean") => 3.2189,