diff --git a/src/create_hgf/create_premade_hgf.jl b/src/create_hgf/create_premade_hgf.jl index 96a042b..5d83f74 100644 --- a/src/create_hgf/create_premade_hgf.jl +++ b/src/create_hgf/create_premade_hgf.jl @@ -17,8 +17,7 @@ function premade_hgf(model_name::String, config::Dict = Dict(); verbose = true) "binary_3level" => premade_binary_3level, #The standard binary input 3 level HGF "JGET" => premade_JGET, #The JGET model "categorical_3level" => premade_categorical_3level, #The standard categorical input 3 level HGF - "categorical_3level_state_transitions" => - premade_categorical_3level_state_transitions, #Categorical 3 level HGF for learning state transitions + "categorical_state_transitions" => premade_categorical_state_transitions, #Categorical 3 level HGF for learning state transitions ) #Check that the specified model is in the list of keys diff --git a/src/premade_models/premade_hgfs/premade_categorical_transitions_3level.jl b/src/premade_models/premade_hgfs/premade_categorical_transitions_3level.jl index 8a77871..eecd12b 100644 --- a/src/premade_models/premade_hgfs/premade_categorical_transitions_3level.jl +++ b/src/premade_models/premade_hgfs/premade_categorical_transitions_3level.jl @@ -1,6 +1,6 @@ """ - premade_categorical_3level_state_transitions(config::Dict; verbose::Bool = true) + premade_categorical_state_transitions(config::Dict; verbose::Bool = true) The categorical state transition 3 level HGF model, learns state transition probabilities between a set of n categorical states. It has one categorical input node u, with a categorical value parent xcat_n for each of the n categories, representing which category was transitioned from. @@ -27,11 +27,13 @@ This HGF has five shared parameters: - ("xvol", "initial_mean"): 0 - ("xvol", "initial_precision"): 1 """ -function premade_categorical_3level_state_transitions(config::Dict; verbose::Bool = true) +function premade_categorical_state_transitions(config::Dict; verbose::Bool = true) #Defaults defaults = Dict( - "n_categories" => 4, + "n_categories_from" => 4, + "n_categories_to" => 4, + "include_volatility_parent" => true, ("xprob", "volatility") => -2, ("xprob", "drift") => 0, ("xprob", "autoconnection_strength") => 1, @@ -74,12 +76,12 @@ function premade_categorical_3level_state_transitions(config::Dict; verbose::Boo grouped_parameters_xprob_xvol_coupling_strength = [] #Go through each category that the transition may have been from - for category_from = 1:config["n_categories"] + for category_from = 1:config["n_categories_from"] #One input node and its state node parent for each push!(input_node_names, "u" * string(category_from)) push!(observation_parent_names, "xcat_" * string(category_from)) #Go through each category that the transition may have been to - for category_to = 1:config["n_categories"] + for category_to = 1:config["n_categories_to"] #Each categorical state node has a binary parent for each push!( category_parent_names, @@ -139,19 +141,23 @@ function premade_categorical_3level_state_transitions(config::Dict; verbose::Boo ) end + #If volatility parent is included + if config["include_volatility_parent"] + #Add the shared volatility parent of the continuous nodes + push!( + nodes, + ContinuousState( + name = "xvol", + volatility = config[("xvol", "volatility")], + drift = config[("xvol", "drift")], + autoconnection_strength = config[("xvol", "autoconnection_strength")], + initial_mean = config[("xvol", "initial_mean")], + initial_precision = config[("xvol", "initial_precision")], + ), + ) + + end - #Add the shared volatility parent of the continuous nodes - push!( - nodes, - ContinuousState( - name = "xvol", - volatility = config[("xvol", "volatility")], - drift = config[("xvol", "drift")], - autoconnection_strength = config[("xvol", "autoconnection_strength")], - initial_mean = config[("xvol", "initial_mean")], - initial_precision = config[("xvol", "initial_precision")], - ), - ) ##Create edges #Initialize list @@ -193,20 +199,27 @@ function premade_categorical_3level_state_transitions(config::Dict; verbose::Boo edges[(category_parent_name, probability_parent_name)] = ProbabilityCoupling(config[("xbin", "xprob", "coupling_strength")]) - #Connect the probability parents to the shared volatility parent - edges[(probability_parent_name, "xvol")] = - VolatilityCoupling(config[("xprob", "xvol", "coupling_strength")]) - + #If there is a volatility parent + if config["include_volatility_parent"] + #Connect the probability parents to the shared volatility parent + edges[(probability_parent_name, "xvol")] = + VolatilityCoupling(config[("xprob", "xvol", "coupling_strength")]) + end #Add the parameters as grouped parameters for shared parameters push!( grouped_parameters_xbin_xprob_coupling_strength, (category_parent_name, probability_parent_name, "coupling_strength"), ) - push!( - grouped_parameters_xprob_xvol_coupling_strength, - (probability_parent_name, "xvol", "coupling_strength"), - ) + + #If volatility parent is included + if config["include_volatility_parent"] + push!( + grouped_parameters_xprob_xvol_coupling_strength, + (probability_parent_name, "xvol", "coupling_strength"), + ) + end + end #Create dictionary with shared parameter information @@ -242,13 +255,20 @@ function premade_categorical_3level_state_transitions(config::Dict; verbose::Boo grouped_parameters_xbin_xprob_coupling_strength, config[("xbin", "xprob", "coupling_strength")], ), - ParameterGroup( - "xprob_xvol_coupling_strength", - grouped_parameters_xprob_xvol_coupling_strength, - config[("xprob", "xvol", "coupling_strength")], - ), ] + #If volatility parent is included + if config["include_volatility_parent"] + push!( + parameter_groups, + ParameterGroup( + "xprob_xvol_coupling_strength", + grouped_parameters_xprob_xvol_coupling_strength, + config[("xprob", "xvol", "coupling_strength")], + ), + ) + end + #Initialize the HGF init_hgf( nodes = nodes, diff --git a/test/testsuite/test_premade_hgf.jl b/test/testsuite/test_premade_hgf.jl index 91f6e05..71deb52 100644 --- a/test/testsuite/test_premade_hgf.jl +++ b/test/testsuite/test_premade_hgf.jl @@ -77,7 +77,7 @@ using Test give_inputs!(HGF_test, test_inputs) end - @testset "Categorical 3 level state transition HGF" begin + @testset "Categorical state transition HGF" begin #Set up test inputs test_inputs = [ @@ -89,7 +89,7 @@ using Test ] #Initialize HGF - HGF_test = premade_hgf("categorical_3level_state_transitions", verbose = false) + HGF_test = premade_hgf("categorical_state_transitions", verbose = false) #Give inputs give_inputs!(HGF_test, test_inputs)