Skip to content

Commit

Permalink
Added nonlinear transformations
Browse files Browse the repository at this point in the history
  • Loading branch information
PTWaade committed Apr 15, 2024
1 parent 20792df commit 0eea8bc
Show file tree
Hide file tree
Showing 10 changed files with 346 additions and 174 deletions.
93 changes: 66 additions & 27 deletions src/ActionModels_variations/utils/get_parameters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,21 +43,12 @@ function ActionModels.get_parameters(hgf::HGF, target_param::Tuple{String,String
return param
end

##For coupling strengths
##For coupling strengths and coupling transforms
function ActionModels.get_parameters(hgf::HGF, target_param::Tuple{String,String,String})

#Unpack node name, parent name and param name
(node_name, parent_name, param_name) = target_param

#If the specified parameter is not a coupling strength
if !(param_name == "coupling_strength")
throw(
ArgumentError(
"the parameter $target_param is specified as three strings, but is not a coupling strength",
),
)
end

#If the node does not exist
if !(node_name in keys(hgf.all_nodes))
#Throw an error
Expand All @@ -67,29 +58,59 @@ function ActionModels.get_parameters(hgf::HGF, target_param::Tuple{String,String
#Get out the node
node = hgf.all_nodes[node_name]

#Get out the dictionary of coupling strengths
coupling_strengths = getproperty(node.parameters, :coupling_strengths)
#If the parameter is a coupling strength
if param_name == "coupling_strength"

#If the specified parent is not in the dictionary
if !(parent_name in keys(coupling_strengths))
#Throw an error
throw(
ArgumentError(
"The node $node_name does not have a coupling strength parameter to a parent called $parent_name",
),
)
end
#Get out the dictionary of coupling strengths
coupling_strengths = getproperty(node.parameters, :coupling_strengths)

#If the specified parent is not in the dictionary
if !(parent_name in keys(coupling_strengths))
#Throw an error
throw(
ArgumentError(
"The node $node_name does not have a coupling strength parameter to a parent called $parent_name",
),
)
end

#Get the coupling strength for that given parent
param = coupling_strengths[parent_name]

#Get the coupling strength for that given parent
param = coupling_strengths[parent_name]
else

#Get out the coupling transforms
coupling_transforms = getproperty(node.parameters, :coupling_transforms)

#If the specified parent is not in the dictionary
if !(parent_name in keys(coupling_transforms))
#Throw an error
throw(
ArgumentError(
"The node $node_name does not have a coupling transformation to a parent called $parent_name",
),
)
end

#If the specified parameter does not exist for the transform
if !(param_name in keys(coupling_transforms.parameters))
throw(
ArgumentError(
"There is no parameter called $param_name for the transformation function between $node_name and its parent $parent_name",
),
)
end

#Extract the parameter
param = coupling_transforms.parameters[param_name]
end

return param
end



### For getting all parameters of a specific node ###

### For getting a single-string parameter (a parameter group), or all parameters of a node ###
function ActionModels.get_parameters(hgf::HGF, target_parameter::String)

#If the target parameter is a parameter group
Expand Down Expand Up @@ -196,13 +217,31 @@ function ActionModels.get_parameters(node::AbstractNode)
coupling_strengths = node.parameters.coupling_strengths

#Go through each parent
for parent_name in keys(coupling_strengths)
for (parent_name, coupling_strength) in coupling_strengths

#Add the coupling strength to the ouput dict
parameters[(node.name, parent_name, "coupling_strength")] =
coupling_strengths[parent_name]
coupling_strength

end

#If the parameter is a coupling transform
elseif param_key == :coupling_transforms

#Go through each parent and corresponding transform
for (parent_name, coupling_transform) in node.parameters.coupling_transforms

#Go through each parameter for the transform
for (coupling_parameter, parameter_value) in coupling_transform.parameters

#Add the coupling strength to the ouput dict
parameters[(node.name, parent_name, coupling_parameter)] =
parameter_value

end
end

#For other nodes
else
#And add their values to the dictionary
parameters[(node.name, String(param_key))] =
Expand Down
71 changes: 46 additions & 25 deletions src/ActionModels_variations/utils/set_parameters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ function ActionModels.set_parameters!() end

### For setting a single parameter ###

##For parameters other than coupling strengths
##For parameters other than coupling strengths and transforms
function ActionModels.set_parameters!(
hgf::HGF,
target_param::Tuple{String,String},
Expand Down Expand Up @@ -61,15 +61,6 @@ function ActionModels.set_parameters!(
#Unpack node name, parent name and parameter name
(node_name, parent_name, param_name) = target_param

#If the specified parameter is not a coupling strength
if !(param_name == "coupling_strength")
throw(
ArgumentError(
"the parameter $target_param is specified as three strings, but is not a coupling strength",
),
)
end

#If the node does not exist
if !(node_name in keys(hgf.all_nodes))
#Throw an error
Expand All @@ -79,26 +70,56 @@ function ActionModels.set_parameters!(
#Get the child node
node = hgf.all_nodes[node_name]

#Get coupling_strengths
coupling_strengths = node.parameters.coupling_strengths
#If it is a coupling strength
if param_name == "coupling_strength"

#Get coupling_strengths
coupling_strengths = node.parameters.coupling_strengths

#If the specified parent is not in the dictionary
if !(parent_name in keys(coupling_strengths))
#Throw an error
throw(
ArgumentError(
"The node $node_name does not have a coupling strength parameter to a parent called $parent_name",
),
)
end

#Set the coupling strength to the specified parent to the specified value
coupling_strengths[parent_name] = param_value

else

#Get out the coupling transforms
coupling_transforms = getproperty(node.parameters, :coupling_transforms)

#If the specified parent is not in the dictionary
if !(parent_name in keys(coupling_transforms))
#Throw an error
throw(
ArgumentError(
"The node $node_name does not have a coupling transformation to a parent called $parent_name",
),
)
end

#If the specified parameter does not exist for the transform
if !(param_name in keys(coupling_transforms.parameters))
throw(
ArgumentError(
"There is no parameter called $param_name for the transformation function between $node_name and its parent $parent_name",
),
)
end

#If the specified parent is not in the dictionary
if !(parent_name in keys(coupling_strengths))
#Throw an error
throw(
ArgumentError(
"The node $node_name does not have a coupling strength parameter to a parent called $parent_name",
),
)
#Set the parameter
coupling_transforms.parameters[param_name] = param_value
end

#Set the coupling strength to the specified parent to the specified value
coupling_strengths[parent_name] = param_value

end

### For setting a single parameter ###
function ActionModels.set_parameters!(hgf::HGF, target_param::String, param_value::Any)
function ActionModels.set_parameters!(hgf::HGF, target_param::String, param_value::Real)
#If the target parameter is not in the shared parameters
if !(target_param in keys(hgf.parameter_groups))
throw(
Expand Down
8 changes: 5 additions & 3 deletions src/HierarchicalGaussianFiltering.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ export DriftCoupling,
CategoryCoupling,
ProbabilityCoupling,
VolatilityCoupling,
NoiseCoupling
NoiseCoupling,
LinearTransform,
NonlinearTransform

#Add premade agents to shared dict at initialization
function __init__()
Expand All @@ -47,6 +49,7 @@ include("ActionModels_variations/utils/set_save_history.jl")

#Functions for updating the HGF
include("update_hgf/update_hgf.jl")
include("update_hgf/nonlinear_transforms.jl")
include("update_hgf/node_updates/continuous_input_node.jl")
include("update_hgf/node_updates/continuous_state_node.jl")
include("update_hgf/node_updates/binary_input_node.jl")
Expand All @@ -57,10 +60,9 @@ include("update_hgf/node_updates/categorical_state_node.jl")
#Functions for creating HGFs
include("create_hgf/check_hgf.jl")
include("create_hgf/init_hgf.jl")
include("create_hgf/init_node_edge.jl")
include("create_hgf/create_premade_hgf.jl")

#Plotting functions

#Functions for premade agents
include("premade_models/premade_agents/premade_gaussian.jl")
include("premade_models/premade_agents/premade_predict_category.jl")
Expand Down
19 changes: 18 additions & 1 deletion src/create_hgf/hgf_structs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,18 @@ struct EnhancedUpdate <: HGFUpdateType end
######## Coupling types ########
################################

#Types for specifying nonlinear transformations
abstract type CouplingTransform end

Base.@kwdef mutable struct LinearTransform <: CouplingTransform end

Base.@kwdef mutable struct NonlinearTransform <: CouplingTransform
base_function::Function
first_derivation::Function
second_derivation::Function
parameters::Dict = Dict()
end

#Supertypes for coupling types
abstract type CouplingType end
abstract type ValueCoupling <: CouplingType end
Expand All @@ -45,6 +57,7 @@ abstract type PrecisionCoupling <: CouplingType end
#Concrete value coupling types
Base.@kwdef mutable struct DriftCoupling <: ValueCoupling
strength::Union{Nothing,Real} = nothing
transform::CouplingTransform = LinearTransform()
end
Base.@kwdef mutable struct ProbabilityCoupling <: ValueCoupling
strength::Union{Nothing,Real} = nothing
Expand Down Expand Up @@ -159,9 +172,10 @@ Base.@kwdef mutable struct ContinuousStateNodeParameters
volatility::Real = 0
drift::Real = 0
autoconnection_strength::Real = 1
coupling_strengths::Dict{String,Real} = Dict{String,Real}()
initial_mean::Real = 0
initial_precision::Real = 0
coupling_strengths::Dict{String,Real} = Dict{String,Real}()
coupling_transforms::Dict{String,CouplingTransform} = Dict{String,Real}()
end

"""
Expand Down Expand Up @@ -210,6 +224,8 @@ Base.@kwdef mutable struct ContinuousInputNodeEdges
observation_parents::Vector{<:AbstractContinuousStateNode} =
Vector{ContinuousStateNode}()
noise_parents::Vector{<:AbstractContinuousStateNode} = Vector{ContinuousStateNode}()


end

"""
Expand All @@ -219,6 +235,7 @@ Base.@kwdef mutable struct ContinuousInputNodeParameters
input_noise::Real = 0
bias::Real = 0
coupling_strengths::Dict{String,Real} = Dict{String,Real}()
coupling_transforms::Dict{String,CouplingTransform} = Dict{String,Real}()
end

"""
Expand Down
Loading

0 comments on commit 0eea8bc

Please sign in to comment.