
Commit 7047302

Refactor/autodiff/track pnl (#3030)
• composition.py, autodiffcomposition.py and relevant subordinate methods:
  - implement synch and track parameter dictionaries that are passed to relevant methods
  - add/rename attributes:
    - PytorchCompositionWrapper: retained_outputs, retained_targets, retained_losses, _nodes_to_execute_after_gradient_calc
    - PytorchMechanismWrapper: value -> output; add input
  - add methods:
    - synch_with_psyneulink(): centralize copying of parameters and values to PNL using the methods below
    - copy_node_variables_to_psyneulink(): centralize updating of node (Mechanism & Composition) variables in PNL
    - copy_node_values_to_psyneulink(): centralize updating of node (Mechanism & Composition) values in PNL
    - copy_results_to_psyneulink(): centralize updating of autodiffcomposition.results
    - retain_in_psyneulink(): centralize tracking of PyTorch results in PNL using the methods below
    - retain_torch_outputs: keeps a record of outputs and copies them to the AutodiffComposition at the end of the call to learn()
    - retain_torch_targets: keeps a record of targets and copies them to AutodiffComposition.pytorch_targets at the end of the call to learn()
    - retain_torch_losses: keeps a record of losses and copies them to AutodiffComposition.pytorch_losses at the end of the call to learn()
• compositionrunner.py, autodiffcomposition.py, pytorchwrappers.py:
  - move loss tracking from a Parameter on the autodiff composition to an attribute on its pytorch_rep
  - batch_inputs(): add calls to synch_with_psyneulink() and retain_in_psyneulink()
  - batch_function_inputs(): still needs calls to synch_with_psyneulink() and retain_in_psyneulink()
• composition.py:
  - run(): add _update_results() as a helper method that can be overridden (e.g., by autodiffcomposition) for less frequent updating
• autodiffcomposition.py:
  - restrict calls to copy_weights_to_psyneulink based on the copy_parameters_to_psyneulink_after arg/attribute
  - implement handling of optimizations_per_minibatch and copy_parameters_to_psyneulink as attributes and as args to learn()
  - autodiff_training(): fix bug in call to pytorch_rep.forward()
  - implement synch and track Parameters
  - _manage_synch_and_retain_args()
  - run(): support specification of synch and retain args when called directly
  - autodiff._update_learning_parameters -> do_optimization():
    - calculates loss for the current trial
    - calls autodiff_backward() to calculate gradients and update parameters
    - updates tracked_loss over trials
  - autodiff_backward(): new method, called from do_optimization(), that calculates and updates the gradients
  - self.loss -> self.loss_function
  - _update_results(): overridden to call pytorch_rep.retain_for_psyneulink(RUN: trial_output)
  - learn(): move tracked_loss for each minibatch from a Parameter on the autodiff composition to an attribute on its pytorch_rep (which is already context-dependent, and avoids calls to pnl.parameters._set on every call to forward())
  - synch_with_pnl_options: implement as a dict that consolidates synch_projection_matrices_with_torch, synch_node_values_with_torch, and the other synch options passed to learning methods
  - retain_in_pnl_options: implement as a dict that consolidates the retain_torch_outputs_in_results, retain_torch_targets, and retain_torch_losses options passed to learning methods
• pytorchwrappers.py:
  - subclass PytorchCompositionWrapper from torch.jit.ScriptModule
  - retain_for_psyneulink(): implemented; stores outputs, targets, and losses from PyTorch execution for copying to PsyNeuLink at the end of learn()
  - PytorchMechanismWrapper: .value -> .output; add .input
• pytorchEMcompositionwrapper.py:
  - store_memory():
    - implement a single call to linalg over memory
    - only execute storage_node after the last optimization_rep
• keywords.py:
  - implement LearningScale keywords class
• AutoAssociativeProjection: make dependent on MaskedMappingProjection in preparation for allowing LCAMechanism to modulate auto/hetero parameters
• fix Literals import
• Factorize scripts into:
  - ScriptControl.py
  - TestParams.py
  - [MODEL].py
---------
Co-authored-by: jdcpni <pniintel55>
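
For orientation, here is a minimal usage sketch of the new learn() options described above. It is an illustration rather than code from this diff: my_autodiff_comp and training_inputs are placeholders, and the exact keyword names and accepted values (booleans vs. LearningScale keywords such as RUN or TRIAL) should be checked against the AutodiffComposition documentation.

import psyneulink as pnl

# 'my_autodiff_comp' is assumed to be a previously constructed AutodiffComposition;
# 'training_inputs' is a dict of inputs (and targets) for its INPUT/TARGET nodes.
results = my_autodiff_comp.learn(
    inputs=training_inputs,
    execution_mode=pnl.ExecutionMode.PyTorch,
    optimizations_per_minibatch=10,                # new arg: optimization steps per minibatch
    synch_projection_matrices_with_torch=pnl.RUN,  # consolidated internally into synch_with_pnl_options
    synch_node_values_with_torch=pnl.RUN,
    retain_torch_targets=True,                     # consolidated internally into retain_in_pnl_options
    retain_torch_losses=True,
)
# Per the notes above, retained losses are copied to AutodiffComposition.pytorch_losses
# at the end of learn(), rather than being set as a Parameter on every forward() call.
print(my_autodiff_comp.pytorch_losses)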
1 parent 310afb1 commit 7047302

34 files changed (+2033, -2298 lines)

Scripts/Models (Under Development)/EGO/EGO Model (sim 2) - CSW using EMComposition (BACKUP).py (-433): This file was deleted.

Scripts/Models (Under Development)/EGO/EGO Model (sim 2) - CSW using EMComposition.py (-433): This file was deleted.

Scripts/Models (Under Development)/EGO/EGO Model (sim 2) - CSW with Integrator and Learning.py (-406): This file was deleted.

Scripts/Models (Under Development)/EGO/EGO Model - MDP OLD.py (-500): This file was deleted.

Scripts/Models (Under Development)/EGO/Tutorial/Declan's EGO Tutorial.ipynb (+399): Large diffs are not rendered by default.
@@ -0,0 +1,93 @@
"""
DECLAN Params: **************************************************************************
√ episodic_lr = 1            # learning rate for the episodic pathway
√ temperature = 0.1          # temperature for EM retrieval (lower is more argmax-like)
√ n_optimization_steps = 10  # number of update steps
sim_thresh = 0.8             # threshold for discarding bad seeds -- can probably ignore this for now
Filter runs whose context representations are too uniform (i.e. not similar to "checkerboard" foil)

May need to pad the context reps because there will be 999 reps
def filter_run(run_em, thresh=0.8):
    foil = np.zeros([4, 4])
    foil[::2, ::2] = 1
    foil[1::2, 1::2] = 1
    run_em = run_em.reshape(200, 5, 11).mean(axis=1)
    mat = cosine_similarity(run_em, run_em)
    vec = mat[:160, :160].reshape(4, 40, 4, 40).mean(axis=(1, 3)).ravel()
    return cosine_similarity(foil.reshape(1, -1), vec.reshape(1, -1))[0][0]

# Stack the model predictions (should be 999x11), pad with zeros, and reshape into trials for averaging.
em_preds = np.vstack([em_preds, np.zeros([1, 11])]).reshape(-1, 5, 11)

# Stack the ground truth states (should be 999x11), pad with zeros, and reshape into trials for averaging.
ys = np.vstack([data_loader.dataset.ys.cpu().numpy(), np.zeros([1, 11])]).reshape(-1, 5, 11)

# compute the probability as a performance metric
def calc_prob(em_preds, test_ys):
    em_preds, test_ys = em_preds[:, 2:-1, :], test_ys[:, 2:-1, :]
    em_probability = (em_preds * test_ys).sum(-1).mean(-1)
    trial_probs = (em_preds * test_ys)
    return em_probability, trial_probs

Calculate the retrieval probability of the correct response as a performance metric (probs)
probs, trial_probs = calc_prob(em_preds, test_ys)
"""
from psyneulink.core.llvm import ExecutionMode
from psyneulink.core.globals.keywords import ALL, ADAPTIVE, CONTROL, CPU, Loss, MPS, OPTIMIZATION_STEP, RUN, TRIAL

model_params = dict(

    # Names:
    name = "EGO Model CSW",
    state_input_layer_name = "STATE",
    previous_state_layer_name = "PREVIOUS STATE",
    context_layer_name = 'CONTEXT',
    em_name = "EM",
    prediction_layer_name = "PREDICTION",

    # Structural
    state_d = 11,                      # length of state vector
    previous_state_d = 11,             # length of state vector
    context_d = 11,                    # length of context vector
    memory_capacity = ALL,             # number of entries in EM memory; ALL => match to number of stims
    memory_init = (0,.0001),           # Initialize memory with random values in interval
    # memory_init = None,              # Initialize with zeros
    concatenate_keys = False,
    # concatenate_keys = True,

    # environment
    # curriculum_type = 'Interleaved',
    curriculum_type = 'Blocked',
    # num_stims = 100,                 # Integer or ALL
    num_stims = ALL,                   # Integer or ALL

    # Processing
    integration_rate = .69,            # rate at which state is integrated into new context
    # state_weight = 1,                # weight of the state used during memory retrieval
    # context_weight = 1,              # weight of the context used during memory retrieval
    state_weight = .5,                 # weight of the state used during memory retrieval
    context_weight = .5,               # weight of the context used during memory retrieval
    normalize_field_weights = False,   # whether to normalize the field weights during memory retrieval
    # normalize_field_weights = True,  # whether to normalize the field weights during memory retrieval
    # softmax_temperature = None,      # temperature of the softmax used during memory retrieval (smaller is more argmax-like)
    softmax_temperature = .1,          # temperature of the softmax used during memory retrieval (smaller is more argmax-like)
    # softmax_temperature = ADAPTIVE,
    # softmax_temperature = CONTROL,
    # softmax_threshold = None,        # threshold used to mask out small values in softmax
    softmax_threshold = .001,          # threshold used to mask out small values in softmax
    enable_learning=[True, False, False],  # Enable learning for PREDICTION (STATE) but not CONTEXT or PREVIOUS STATE
    learn_field_weights = False,
    loss_spec = Loss.BINARY_CROSS_ENTROPY,
    # loss_spec = Loss.MSE,
    learning_rate = .5,
    # num_optimization_steps = 1,
    num_optimization_steps = 10,
    synch_weights = RUN,
    synch_values = RUN,
    synch_results = RUN,
    # execution_mode = ExecutionMode.Python,
    execution_mode = ExecutionMode.PyTorch,
    device = CPU,
    # device = MPS,
)
#endregion
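
A hypothetical sketch of how the synch_* and optimization entries above might be forwarded to an AutodiffComposition's learn() call. The mapping from these script keys to learn() keyword names is an assumption made for illustration (it is not shown in this diff), and 'model' and 'training_inputs' are placeholders.

# Illustrative mapping of script parameters to learn() keyword arguments (assumed, not from this diff).
learn_kwargs = dict(
    execution_mode=model_params['execution_mode'],
    learning_rate=model_params['learning_rate'],
    optimizations_per_minibatch=model_params['num_optimization_steps'],
    synch_projection_matrices_with_torch=model_params['synch_weights'],  # RUN
    synch_node_values_with_torch=model_params['synch_values'],           # RUN
)
# model.learn(inputs=training_inputs, **learn_kwargs)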
+3 -3
@@ -147,8 +147,8 @@
 MEMORY_CAPACITY = 5
 CONSTRUCT_MODEL = True               # THIS MUST BE SET TO True to run the script
 DISPLAY_MODEL = (                    # Only one of the following can be uncommented:
-    None                             # suppress display of model
-    # {}                             # show simple visual display of model
+    # None                           # suppress display of model
+    {}                               # show simple visual display of model
     # {'show_node_structure': True}  # show detailed view of node structures and projections
     )
 RUN_MODEL = True                     # True => run the model
@@ -404,7 +404,7 @@ def construct_model(model_name:str=MODEL_NAME,
 model = construct_model()
 assert 'DEBUGGING BREAK POINT'
 # print(model.scheduler.consideration_queue)
-# gs.output_graph_image(model.scheduler.graph, 'EGO_comp-scheduler.png')
+# gs.output_graph_image(model.scheduler.graph, 'show_graph OUTPUT/EGO_comp-scheduler.png')

 if DISPLAY_MODEL is not None:
     if model:
