Skip to content

Commit

Permalink
EpisodicMemoryMechanism: make memory (_memory_init) a FunctionParameter
Browse files Browse the repository at this point in the history
Shared with its function initializer. Changes conflict behavior to be
consistent with other SharedParameters (function value favored over
owner value). For discussion on this, see
#2600
  • Loading branch information
kmantel committed Feb 9, 2023
1 parent aa6c6dc commit 9defdba
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,7 @@
"""
import copy
import warnings
from typing import Optional, Union

Expand All @@ -416,7 +417,7 @@
from psyneulink.core.components.mechanisms.processing.processingmechanism import ProcessingMechanism_Base
from psyneulink.core.components.ports.inputport import InputPort
from psyneulink.core.globals.keywords import EPISODIC_MEMORY_MECHANISM, INITIALIZER, NAME, OWNER_VALUE, VARIABLE
from psyneulink.core.globals.parameters import Parameter, check_user_specified
from psyneulink.core.globals.parameters import FunctionParameter, Parameter, check_user_specified
from psyneulink.core.globals.preferences.basepreferenceset import is_pref_set
from psyneulink.core.globals.utilities import deprecation_warning, convert_to_np_array, convert_all_elements_to_np_array

Expand Down Expand Up @@ -508,6 +509,13 @@ class Parameters(ProcessingMechanism_Base.Parameters):
"""
variable = Parameter([[0,0]], pnl_internal=True, constructor_argument='default_variable')
function = Parameter(ContentAddressableMemory, stateful=False, loggable=False)
memory = FunctionParameter(None, function_parameter_name='initializer')

def _parse_memory(self, memory):
if memory is None:
return memory

return ContentAddressableMemory._enforce_memory_shape(memory)

@check_user_specified
def __init__(self,
Expand Down Expand Up @@ -538,15 +546,14 @@ def __init__(self,
size += kwargs['assoc_size']
kwargs.pop('assoc_size')

self._memory_init = memory

super().__init__(
default_variable=default_variable,
size=size,
function=function,
params=params,
name=name,
prefs=prefs,
memory=memory,
**kwargs
)

Expand All @@ -564,18 +571,15 @@ def _handle_default_variable(self, default_variable=None, size=None, input_ports
variable_shape = convert_all_elements_to_np_array(default_variable).shape \
if default_variable is not None else None
function_instance = self.function if isinstance(self.function, Function) else None
function_type = self.function if isinstance(self.function, type) else self.function.__class__

# **memory** arg is specified in constructor, so use that to initialize or validate default_variable
if self._memory_init:
try:
self._memory_init = function_type._enforce_memory_shape(self._memory_init)
except:
pass
if self.parameters.memory._user_specified:
memory = self.defaults.memory

if default_variable is None:
default_variable = self._memory_init[0]
default_variable = copy.deepcopy(memory[0])
else:
entry_shape = convert_all_elements_to_np_array(self._memory_init[0]).shape
entry_shape = convert_all_elements_to_np_array(memory[0]).shape
if entry_shape != variable_shape:
raise EpisodicMemoryMechanismError(f"Shape of 'variable' for {self.name} ({variable_shape}) "
f"does not match the shape of entries ({entry_shape}) in "
Expand Down Expand Up @@ -610,14 +614,9 @@ def _instantiate_input_ports(self, context=None):

def _instantiate_function(self, function, function_params, context):
"""Assign memory to function if specified in Mechanism's constructor"""
if self._memory_init is not None:
if isinstance(function, type):
function_params.update({INITIALIZER:self._memory_init})
else:
if len(function.memory):
warnings.warn(f"The 'memory' argument specified for {self.name} will override the specification "
f"for the {repr(INITIALIZER)} argument of its function ({self.function.name}).")
function.reset(self._memory_init)
memory = self.parameters.memory._get(context)
if memory is not None:
function.reset(memory)
super()._instantiate_function(function, function_params, context)

def _instantiate_output_ports(self, context=None):
Expand Down
7 changes: 5 additions & 2 deletions tests/mechanisms/test_episodic_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,8 +221,11 @@ def test_with_contentaddressablememory(name, func, func_params, mech_params, tes
def test_contentaddressable_memory_warnings_and_errors():

# both memory arg of Mechanism and initializer for its function are specified
text = "The 'memory' argument specified for EpisodicMemoryMechanism-0 will override the specification " \
"for the 'initializer' argument of its function"
text = (
r"Specification of the \"memory\" parameter[.\S\s]*The value"
+ r" specified on \(ContentAddressableMemory ContentAddressableMemory"
+ r" Function-\d\) will be used\."
)
with pytest.warns(UserWarning, match=text):
em = EpisodicMemoryMechanism(
memory = [[[1,2,3],[4,5,6]]],
Expand Down

0 comments on commit 9defdba

Please sign in to comment.