Skip to content

Commit

Permalink
Merge pull request #137 from ilabcode/dev
Browse files Browse the repository at this point in the history
Version 0.5.1
  • Loading branch information
PTWaade authored Apr 17, 2024
2 parents f00c50c + 4901bd6 commit 3cbac46
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 34 deletions.
3 changes: 1 addition & 2 deletions src/create_hgf/create_premade_hgf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@ function premade_hgf(model_name::String, config::Dict = Dict(); verbose = true)
"binary_3level" => premade_binary_3level, #The standard binary input 3 level HGF
"JGET" => premade_JGET, #The JGET model
"categorical_3level" => premade_categorical_3level, #The standard categorical input 3 level HGF
"categorical_3level_state_transitions" =>
premade_categorical_3level_state_transitions, #Categorical 3 level HGF for learning state transitions
"categorical_state_transitions" => premade_categorical_state_transitions, #Categorical 3 level HGF for learning state transitions
)

#Check that the specified model is in the list of keys
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@

"""
premade_categorical_3level_state_transitions(config::Dict; verbose::Bool = true)
premade_categorical_state_transitions(config::Dict; verbose::Bool = true)
The categorical state transition 3 level HGF model, learns state transition probabilities between a set of n categorical states.
It has one categorical input node u, with a categorical value parent xcat_n for each of the n categories, representing which category was transitioned from.
Expand All @@ -27,11 +27,13 @@ This HGF has five shared parameters:
- ("xvol", "initial_mean"): 0
- ("xvol", "initial_precision"): 1
"""
function premade_categorical_3level_state_transitions(config::Dict; verbose::Bool = true)
function premade_categorical_state_transitions(config::Dict; verbose::Bool = true)

#Defaults
defaults = Dict(
"n_categories" => 4,
"n_categories_from" => 4,
"n_categories_to" => 4,
"include_volatility_parent" => true,
("xprob", "volatility") => -2,
("xprob", "drift") => 0,
("xprob", "autoconnection_strength") => 1,
Expand Down Expand Up @@ -74,12 +76,12 @@ function premade_categorical_3level_state_transitions(config::Dict; verbose::Boo
grouped_parameters_xprob_xvol_coupling_strength = []

#Go through each category that the transition may have been from
for category_from = 1:config["n_categories"]
for category_from = 1:config["n_categories_from"]
#One input node and its state node parent for each
push!(input_node_names, "u" * string(category_from))
push!(observation_parent_names, "xcat_" * string(category_from))
#Go through each category that the transition may have been to
for category_to = 1:config["n_categories"]
for category_to = 1:config["n_categories_to"]
#Each categorical state node has a binary parent for each
push!(
category_parent_names,
Expand Down Expand Up @@ -139,19 +141,23 @@ function premade_categorical_3level_state_transitions(config::Dict; verbose::Boo
)
end

#If volatility parent is included
if config["include_volatility_parent"]
#Add the shared volatility parent of the continuous nodes
push!(
nodes,
ContinuousState(
name = "xvol",
volatility = config[("xvol", "volatility")],
drift = config[("xvol", "drift")],
autoconnection_strength = config[("xvol", "autoconnection_strength")],
initial_mean = config[("xvol", "initial_mean")],
initial_precision = config[("xvol", "initial_precision")],
),
)

end

#Add the shared volatility parent of the continuous nodes
push!(
nodes,
ContinuousState(
name = "xvol",
volatility = config[("xvol", "volatility")],
drift = config[("xvol", "drift")],
autoconnection_strength = config[("xvol", "autoconnection_strength")],
initial_mean = config[("xvol", "initial_mean")],
initial_precision = config[("xvol", "initial_precision")],
),
)

##Create edges
#Initialize list
Expand Down Expand Up @@ -193,20 +199,27 @@ function premade_categorical_3level_state_transitions(config::Dict; verbose::Boo
edges[(category_parent_name, probability_parent_name)] =
ProbabilityCoupling(config[("xbin", "xprob", "coupling_strength")])

#Connect the probability parents to the shared volatility parent
edges[(probability_parent_name, "xvol")] =
VolatilityCoupling(config[("xprob", "xvol", "coupling_strength")])

#If there is a volatility parent
if config["include_volatility_parent"]
#Connect the probability parents to the shared volatility parent
edges[(probability_parent_name, "xvol")] =
VolatilityCoupling(config[("xprob", "xvol", "coupling_strength")])
end

#Add the parameters as grouped parameters for shared parameters
push!(
grouped_parameters_xbin_xprob_coupling_strength,
(category_parent_name, probability_parent_name, "coupling_strength"),
)
push!(
grouped_parameters_xprob_xvol_coupling_strength,
(probability_parent_name, "xvol", "coupling_strength"),
)

#If volatility parent is included
if config["include_volatility_parent"]
push!(
grouped_parameters_xprob_xvol_coupling_strength,
(probability_parent_name, "xvol", "coupling_strength"),
)
end

end

#Create dictionary with shared parameter information
Expand Down Expand Up @@ -242,13 +255,20 @@ function premade_categorical_3level_state_transitions(config::Dict; verbose::Boo
grouped_parameters_xbin_xprob_coupling_strength,
config[("xbin", "xprob", "coupling_strength")],
),
ParameterGroup(
"xprob_xvol_coupling_strength",
grouped_parameters_xprob_xvol_coupling_strength,
config[("xprob", "xvol", "coupling_strength")],
),
]

#If volatility parent is included
if config["include_volatility_parent"]
push!(
parameter_groups,
ParameterGroup(
"xprob_xvol_coupling_strength",
grouped_parameters_xprob_xvol_coupling_strength,
config[("xprob", "xvol", "coupling_strength")],
),
)
end

#Initialize the HGF
init_hgf(
nodes = nodes,
Expand Down
4 changes: 2 additions & 2 deletions test/testsuite/test_premade_hgf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ using Test
give_inputs!(HGF_test, test_inputs)
end

@testset "Categorical 3 level state transition HGF" begin
@testset "Categorical state transition HGF" begin

#Set up test inputs
test_inputs = [
Expand All @@ -89,7 +89,7 @@ using Test
]

#Initialize HGF
HGF_test = premade_hgf("categorical_3level_state_transitions", verbose = false)
HGF_test = premade_hgf("categorical_state_transitions", verbose = false)

#Give inputs
give_inputs!(HGF_test, test_inputs)
Expand Down

0 comments on commit 3cbac46

Please sign in to comment.