Skip to content

Commit b176980

Browse files
committed
lstm: Add LSTMMechanism + compiled support
1 parent 4d4c9bc commit b176980

File tree

6 files changed

+657
-10
lines changed

6 files changed

+657
-10
lines changed

psyneulink/core/components/functions/transferfunctions.py

+215-8
Original file line numberDiff line numberDiff line change
@@ -58,14 +58,7 @@
5858
from psyneulink.core.components.functions.selectionfunctions import OneHot
5959
from psyneulink.core.components.functions.statefulfunctions.integratorfunctions import SimpleIntegrator
6060
from psyneulink.core.components.shellclasses import Projection
61-
from psyneulink.core.globals.keywords import \
62-
ADDITIVE_PARAM, ALL, BIAS, EXPONENTIAL_FUNCTION, \
63-
GAIN, GAUSSIAN_DISTORT_FUNCTION, GAUSSIAN_FUNCTION, HAS_INITIALIZERS, HOLLOW_MATRIX, \
64-
IDENTITY_FUNCTION, IDENTITY_MATRIX, INTERCEPT, LEAK, LINEAR_FUNCTION, LINEAR_MATRIX_FUNCTION, LOGISTIC_FUNCTION, \
65-
TANH_FUNCTION, MATRIX_KEYWORD_NAMES, MATRIX, MATRIX_KEYWORD_VALUES, MAX_INDICATOR, MAX_VAL, MULTIPLICATIVE_PARAM, \
66-
OFF, OFFSET, ON, PER_ITEM, PROB, PRODUCT, OUTPUT_TYPE, PROB_INDICATOR, \
67-
RATE, RECEIVER, RELU_FUNCTION, SCALE, SLOPE, SOFTMAX_FUNCTION, STANDARD_DEVIATION, SUM,\
68-
TRANSFER_FUNCTION_TYPE, TRANSFER_WITH_COSTS_FUNCTION, VARIANCE, VARIABLE, X_0, PREFERENCE_SET_NAME
61+
from psyneulink.core.globals.keywords import ADDITIVE_PARAM, ALL, BIAS, EXPONENTIAL_FUNCTION, GAIN, GAUSSIAN_DISTORT_FUNCTION, GAUSSIAN_FUNCTION, HAS_INITIALIZERS, HOLLOW_MATRIX, IDENTITY_FUNCTION, IDENTITY_MATRIX, INTERCEPT, LEAK, LINEAR_FUNCTION, LINEAR_MATRIX_FUNCTION, LOGISTIC_FUNCTION, LSTM_FUNCTION, MATRIX, MATRIX_KEYWORD_NAMES, MATRIX_KEYWORD_VALUES, MAX_INDICATOR, MAX_VAL, MULTIPLICATIVE_PARAM, OFF, OFFSET, ON, OUTPUT_TYPE, PER_ITEM, PREFERENCE_SET_NAME, PROB, PROB_INDICATOR, PRODUCT, RANDOM_CONNECTIVITY_MATRIX, RATE, RECEIVER, RELU_FUNCTION, SCALE, SLOPE, SOFTMAX_FUNCTION, STANDARD_DEVIATION, SUM, TANH_FUNCTION, TRANSFER_FUNCTION_TYPE, TRANSFER_WITH_COSTS_FUNCTION, VARIABLE, VARIANCE, X_0
6962
from psyneulink.core.globals.parameters import \
7063
Parameter, get_validator_by_function
7164
from psyneulink.core.globals.utilities import parameter_spec, get_global_seed, safe_len
@@ -2530,7 +2523,221 @@ def derivative(self, output, input=None, context=None):
25302523

25312524
return derivative
25322525

2526+
# **********************************************************************************************************************
2527+
# LSTM
2528+
# **********************************************************************************************************************
2529+
2530+
class LSTM(TransferFunction):
    """Single-step long short-term memory (LSTM) cell as a transfer function.

    The variable is read as three items: ``variable[0]`` is the current
    input, ``variable[1]`` the previous hidden state and ``variable[2]`` the
    previous cell state.  The result is the two-item list
    ``[hidden_state, cell_state]``.

    Each of the ``i`` (input), ``f`` (forget), ``g`` (cell update) and ``o``
    (output) gates owns an input matrix, a hidden matrix and an activation
    function.  ``h_gate_func`` is applied to the new cell state before it is
    multiplied elementwise by the output gate.
    """
    componentName = LSTM_FUNCTION

    def __init__(self,
                 default_variable=None,
                 params=None,
                 owner=None,
                 prefs: tc.optional(is_pref_set) = None):
        """Forward all arguments unchanged to the parent constructor."""
        super().__init__(default_variable=default_variable,
                         params=params,
                         owner=owner,
                         prefs=prefs)

    class Parameters(TransferFunction.Parameters):
        # Input gate
        i_input_matrix = Parameter(modulable=True)
        i_hidden_matrix = Parameter(modulable=True)
        i_gate_func = Parameter(default_value=Logistic())

        # Forget gate
        f_input_matrix = Parameter(modulable=True)
        f_hidden_matrix = Parameter(modulable=True)
        f_gate_func = Parameter(default_value=Logistic())

        # Cell-update gate
        g_input_matrix = Parameter(modulable=True)
        g_hidden_matrix = Parameter(modulable=True)
        g_gate_func = Parameter(default_value=Tanh())

        # Output gate
        o_input_matrix = Parameter(modulable=True)
        o_hidden_matrix = Parameter(modulable=True)
        o_gate_func = Parameter(default_value=Logistic())

        # Applied to the new cell state when computing the hidden state
        h_gate_func = Parameter(default_value=Tanh())

    def _instantiate_attributes_before_function(self, function=None, context=None):
        """Size the gate matrices and gate functions from the current variable.

        The input size is ``len(self.variable[0])`` and the hidden size is
        ``len(self.variable[1])``.  Any matrix attribute that is ``None`` (or
        missing) is specified as RANDOM_CONNECTIVITY_MATRIX before being
        passed to ``get_matrix``.
        """
        input_size = len(self.variable[0])
        hidden_size = len(self.variable[1])

        # Instantiate the matrices: all input matrices first, then all hidden
        # matrices, each in gate order i, f, g, o.
        # NOTE(review): keep this order -- random matrix specs presumably draw
        # from a shared random stream, so reordering could change the values.
        matrix_groups = (
            (("i_input_matrix", "f_input_matrix", "g_input_matrix", "o_input_matrix"), input_size),
            (("i_hidden_matrix", "f_hidden_matrix", "g_hidden_matrix", "o_hidden_matrix"), hidden_size),
        )
        for matrix_names, second_dim in matrix_groups:
            for matrix_name in matrix_names:
                spec = getattr(self, matrix_name, None)
                if spec is None:
                    spec = RANDOM_CONNECTIVITY_MATRIX

                # Size arguments are (hidden_size, second_dim) -- presumably
                # (rows, cols); confirm against get_matrix.
                setattr(self, matrix_name, get_matrix(spec, hidden_size, second_dim, context=context))

        # Instantiate function default variables: every gate function gets
        # variable and value placeholders of length hidden_size.  Each
        # attribute receives its own freshly allocated zero vector.
        for func_name in ("i_gate_func", "f_gate_func", "g_gate_func", "o_gate_func", "h_gate_func"):
            gate_func = getattr(self, func_name)
            gate_func.default_variable = np.zeros(hidden_size)
            gate_func.defaults.variable = np.zeros(hidden_size)
            gate_func.variable = np.zeros(hidden_size)
            gate_func.default_value = np.zeros(hidden_size)
            gate_func.defaults.value = np.zeros(hidden_size)
            gate_func.value = np.zeros(hidden_size)

    def _function(self,
                  variable=None,
                  context=None,
                  params=None,
                  ):
        """Compute one LSTM step and return ``[h_t, c_t]``.

        ``variable[0]``, ``variable[1]`` and ``variable[2]`` are used as the
        current input, previous hidden state and previous cell state.
        ``params`` is accepted but not used here.
        """
        x_t = variable[0]
        h_prev = variable[1]
        c_prev = variable[2]

        def _gate(prefix):
            # Activation of one gate:
            #   func(input_matrix . x_t + hidden_matrix . h_prev)
            input_matrix = self._get_current_function_param(prefix + "_input_matrix", context=context)
            hidden_matrix = self._get_current_function_param(prefix + "_hidden_matrix", context=context)
            gate_func = self._get_current_function_param(prefix + "_gate_func", context=context)
            return gate_func(np.matmul(input_matrix, x_t) + np.matmul(hidden_matrix, h_prev))

        # Input and forget gates
        i_t = _gate("i")
        f_t = _gate("f")

        # New cell state: forget-gated old state plus input-gated update
        g_t = _gate("g")
        c_t = np.multiply(f_t, c_prev) + np.multiply(i_t, g_t)

        # Output gate
        o_t = _gate("o")

        # New hidden state: output gate applied to the transformed cell state
        h_gate_func = self._get_current_function_param("h_gate_func", context=context)
        h_t = np.multiply(o_t, h_gate_func(c_t))

        return [h_t, c_t]

    def _gen_llvm_function_body(self, ctx, builder, params, state, arg_in, arg_out, *, tags:frozenset):
        """Emit LLVM IR for the same computation as `_function`.

        Reads input, previous hidden state and previous cell state from items
        0, 1 and 2 of ``arg_in``; stores the new hidden state and cell state
        into items 0 and 1 of ``arg_out``.  Returns ``builder``.
        """
        matmul = ctx.import_llvm_function("__pnl_builtin_mxm")
        vecadd = ctx.import_llvm_function("__pnl_builtin_vec_add")
        vechadamard = ctx.import_llvm_function("__pnl_builtin_vec_hadamard")

        def _idx(*positions):
            # Fresh list of 32-bit integer constants used as GEP indices.
            return [ctx.int32_ty(position) for position in positions]

        x_t = builder.gep(arg_in, _idx(0, 0))
        h_prev = builder.gep(arg_in, _idx(0, 1))
        c_prev = builder.gep(arg_in, _idx(0, 2))

        def _mxv(m, v):
            # Matrix times vector into a new stack buffer typed like h_prev.
            # NOTE(review): the trailing constant 1 presumably tells the
            # builtin that the right operand has a single column -- confirm
            # against the __pnl_builtin_mxm signature.
            result = builder.alloca(h_prev.type.pointee)
            result_ptr = builder.gep(result, _idx(0, 0))
            dim_x = len(m.type.pointee)
            dim_y = len(m.type.pointee.elements[0])
            m_ptr = builder.gep(m, _idx(0, 0, 0))
            v_ptr = builder.gep(v, _idx(0, 0))

            builder.call(matmul, [m_ptr,
                                  v_ptr,
                                  ctx.int32_ty(dim_x),
                                  ctx.int32_ty(dim_y),
                                  ctx.int32_ty(1),
                                  result_ptr])
            return result

        def _vxv(v1, v2):
            # Elementwise (Hadamard) product into a new stack buffer typed
            # like h_prev.
            result = builder.alloca(h_prev.type.pointee)
            result_ptr = builder.gep(result, _idx(0, 0))
            dim_x = len(v1.type.pointee)
            v1_ptr = builder.gep(v1, _idx(0, 0))
            v2_ptr = builder.gep(v2, _idx(0, 0))

            builder.call(vechadamard, [v1_ptr,
                                       v2_ptr,
                                       ctx.int32_ty(dim_x),
                                       result_ptr])
            return result

        def _mac(m1, v1, m2, v2, mul_op=_mxv):
            # mul_op(m1, v1) + mul_op(m2, v2); the sum is written back into
            # the first product's buffer, which is returned.
            first = mul_op(m1, v1)
            second = mul_op(m2, v2)
            first_ptr = builder.gep(first, _idx(0, 0))
            second_ptr = builder.gep(second, _idx(0, 0))
            builder.call(vecadd, [first_ptr,
                                  second_ptr,
                                  ctx.int32_ty(len(m1.type.pointee)),
                                  first_ptr])
            return first

        def _call_func(func_id, in_vec, out_vec):
            # Call the compiled form of the function stored in parameter
            # `func_id`, passing its own param and state pointers.
            func_params = pnlvm.helpers.get_param_ptr(builder, self, params, func_id)
            func_state = pnlvm.helpers.get_state_ptr(builder, self, state, func_id)

            llvm_func = ctx.import_llvm_function(getattr(self, func_id), tags=tags)
            builder.call(llvm_func, [func_params, func_state, in_vec, out_vec])

        def _gate(prefix):
            # input_matrix . x_t + hidden_matrix . h_prev, then the gate
            # function is called with the same buffer as input and output.
            input_matrix = pnlvm.helpers.get_param_ptr(builder, self, params, prefix + "_input_matrix")
            hidden_matrix = pnlvm.helpers.get_param_ptr(builder, self, params, prefix + "_hidden_matrix")
            activation = _mac(input_matrix, x_t, hidden_matrix, h_prev)
            _call_func(prefix + "_gate_func", activation, activation)
            return activation

        # Input and forget gates
        i_t = _gate("i")
        f_t = _gate("f")

        # New cell state: f_t * c_prev + i_t * g_t (elementwise)
        g_t = _gate("g")
        c_t = _mac(f_t, c_prev, i_t, g_t, mul_op=_vxv)

        # Output gate
        o_t = _gate("o")

        # New hidden state: o_t * h_gate_func(c_t).  A separate buffer is
        # used so that c_t is left intact for the writeback below.
        transformed_cell = builder.alloca(h_prev.type.pointee)
        _call_func("h_gate_func", c_t, transformed_cell)
        h_t = _vxv(o_t, transformed_cell)

        # Write back into the value struct
        builder.store(builder.load(h_t), builder.gep(arg_out, _idx(0, 0)))
        builder.store(builder.load(c_t), builder.gep(arg_out, _idx(0, 1)))

        return builder
25342741
# **********************************************************************************************************************
25352742
# LinearMatrix
25362743
# **********************************************************************************************************************

psyneulink/core/compositions/composition.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -6747,7 +6747,7 @@ def bfs(start):
67476747
pathways.append(p)
67486748
continue
67496749
for projection, efferent_node in [(p, p.receiver.owner) for p in curr_node.efferents]:
6750-
if (not hasattr(projection,'learnable')) or (projection.learnable is False) or efferent_node in prev:
6750+
if getattr(projection, 'learnable', False) is False or efferent_node in prev:
67516751
continue
67526752
prev[efferent_node] = projection
67536753
prev[projection] = curr_node

psyneulink/core/globals/keywords.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@
6868
'LEARNING_PATHWAY', 'LEARNING_PROJECTION', 'LEARNING_PROJECTION_PARAMS', 'LEARNING_RATE', 'LEARNING_SIGNAL',
6969
'LEARNING_SIGNAL_SPECS', 'LEARNING_SIGNALS',
7070
'LESS_THAN', 'LESS_THAN_OR_EQUAL', 'LINEAR', 'LINEAR_COMBINATION_FUNCTION', 'LINEAR_FUNCTION',
71-
'LINEAR_MATRIX_FUNCTION', 'LOG_ENTRIES', 'LOGISTIC_FUNCTION', 'LOW', 'LVOC_CONTROL_MECHANISM', 'L0', 'L1',
71+
'LINEAR_MATRIX_FUNCTION', 'LOG_ENTRIES', 'LOGISTIC_FUNCTION', 'LOW', 'LSTM_FUNCTION', 'LVOC_CONTROL_MECHANISM', 'L0', 'L1',
7272
'MAPPING_PROJECTION', 'MAPPING_PROJECTION_PARAMS', 'MASKED_MAPPING_PROJECTION',
7373
'MATRIX', 'MATRIX_KEYWORD_NAMES', 'MATRIX_KEYWORD_SET', 'MATRIX_KEYWORD_VALUES', 'MATRIX_KEYWORDS','MatrixKeywords',
7474
'MAX_ABS_DIFF', 'MAX_ABS_INDICATOR', 'MAX_ONE_HOT', 'MAX_ABS_ONE_HOT', 'MAX_ABS_VAL',
@@ -526,6 +526,7 @@ def _is_metric(metric):
526526
TRANSFER_MECHANISM = "TransferMechanism"
527527
LEABRA_MECHANISM = "LeabraMechanism"
528528
RECURRENT_TRANSFER_MECHANISM = "RecurrentTransferMechanism"
529+
LSTM_MECHANISM = "LSTMMechanism"
529530
CONTRASTIVE_HEBBIAN_MECHANISM = "ContrastiveHebbianMechanism"
530531
LCA_MECHANISM = "LCAMechanism"
531532
KOHONEN_MECHANISM = 'KohonenMechanism'
@@ -557,6 +558,7 @@ def _is_metric(metric):
557558
GAUSSIAN_FUNCTION = "Gaussian Function"
558559
GAUSSIAN_DISTORT_FUNCTION = "GaussianDistort Function"
559560
SOFTMAX_FUNCTION = 'SoftMax Function'
561+
LSTM_FUNCTION = 'LSTM Function'
560562
LINEAR_MATRIX_FUNCTION = "LinearMatrix Function"
561563
TRANSFER_WITH_COSTS_FUNCTION = "TransferWithCosts Function"
562564

0 commit comments

Comments
 (0)