|
58 | 58 | from psyneulink.core.components.functions.selectionfunctions import OneHot
|
59 | 59 | from psyneulink.core.components.functions.statefulfunctions.integratorfunctions import SimpleIntegrator
|
60 | 60 | from psyneulink.core.components.shellclasses import Projection
|
61 |
| -from psyneulink.core.globals.keywords import \ |
62 |
| - ADDITIVE_PARAM, ALL, BIAS, EXPONENTIAL_FUNCTION, \ |
63 |
| - GAIN, GAUSSIAN_DISTORT_FUNCTION, GAUSSIAN_FUNCTION, HAS_INITIALIZERS, HOLLOW_MATRIX, \ |
64 |
| - IDENTITY_FUNCTION, IDENTITY_MATRIX, INTERCEPT, LEAK, LINEAR_FUNCTION, LINEAR_MATRIX_FUNCTION, LOGISTIC_FUNCTION, \ |
65 |
| - TANH_FUNCTION, MATRIX_KEYWORD_NAMES, MATRIX, MATRIX_KEYWORD_VALUES, MAX_INDICATOR, MAX_VAL, MULTIPLICATIVE_PARAM, \ |
66 |
| - OFF, OFFSET, ON, PER_ITEM, PROB, PRODUCT, OUTPUT_TYPE, PROB_INDICATOR, \ |
67 |
| - RATE, RECEIVER, RELU_FUNCTION, SCALE, SLOPE, SOFTMAX_FUNCTION, STANDARD_DEVIATION, SUM,\ |
68 |
| - TRANSFER_FUNCTION_TYPE, TRANSFER_WITH_COSTS_FUNCTION, VARIANCE, VARIABLE, X_0, PREFERENCE_SET_NAME |
| 61 | +from psyneulink.core.globals.keywords import ADDITIVE_PARAM, ALL, BIAS, EXPONENTIAL_FUNCTION, GAIN, GAUSSIAN_DISTORT_FUNCTION, GAUSSIAN_FUNCTION, HAS_INITIALIZERS, HOLLOW_MATRIX, IDENTITY_FUNCTION, IDENTITY_MATRIX, INTERCEPT, LEAK, LINEAR_FUNCTION, LINEAR_MATRIX_FUNCTION, LOGISTIC_FUNCTION, LSTM_FUNCTION, MATRIX, MATRIX_KEYWORD_NAMES, MATRIX_KEYWORD_VALUES, MAX_INDICATOR, MAX_VAL, MULTIPLICATIVE_PARAM, OFF, OFFSET, ON, OUTPUT_TYPE, PER_ITEM, PREFERENCE_SET_NAME, PROB, PROB_INDICATOR, PRODUCT, RANDOM_CONNECTIVITY_MATRIX, RATE, RECEIVER, RELU_FUNCTION, SCALE, SLOPE, SOFTMAX_FUNCTION, STANDARD_DEVIATION, SUM, TANH_FUNCTION, TRANSFER_FUNCTION_TYPE, TRANSFER_WITH_COSTS_FUNCTION, VARIABLE, VARIANCE, X_0 |
69 | 62 | from psyneulink.core.globals.parameters import \
|
70 | 63 | Parameter, get_validator_by_function
|
71 | 64 | from psyneulink.core.globals.utilities import parameter_spec, get_global_seed, safe_len
|
@@ -2530,7 +2523,221 @@ def derivative(self, output, input=None, context=None):
|
2530 | 2523 |
|
2531 | 2524 | return derivative
|
2532 | 2525 |
|
| 2526 | +# ********************************************************************************************************************** |
| 2527 | +# LSTM |
| 2528 | +# ********************************************************************************************************************** |
| 2529 | + |
class LSTM(TransferFunction):
    """Single step of a long short-term memory (LSTM) cell.

    The variable is read as three items: ``variable[0]`` is the current input
    ``x_t``, ``variable[1]`` the previous hidden state ``h_prev`` and
    ``variable[2]`` the previous cell state ``c_prev``.  One call computes::

        i_t = i_gate_func(i_input_matrix @ x_t + i_hidden_matrix @ h_prev)
        f_t = f_gate_func(f_input_matrix @ x_t + f_hidden_matrix @ h_prev)
        g_t = g_gate_func(g_input_matrix @ x_t + g_hidden_matrix @ h_prev)
        c_t = f_t * c_prev + i_t * g_t
        o_t = o_gate_func(o_input_matrix @ x_t + o_hidden_matrix @ h_prev)
        h_t = o_t * h_gate_func(c_t)

    and yields ``[h_t, c_t]``.  No bias terms are applied.
    """
    componentName = LSTM_FUNCTION

    def __init__(self,
                 default_variable=None,
                 params=None,
                 owner=None,
                 prefs: tc.optional(is_pref_set) = None):
        """Forward every argument, unchanged, to the parent constructor."""
        super().__init__(default_variable=default_variable, params=params, owner=owner, prefs=prefs)

    class Parameters(TransferFunction.Parameters):
        """Weights and activation functions of the four gates plus the output squashing function.

        Each ``*_input_matrix`` multiplies ``x_t`` and each ``*_hidden_matrix``
        multiplies ``h_prev``; matrices left unspecified are filled in by
        `_instantiate_attributes_before_function`.
        """
        # input gate
        i_input_matrix = Parameter(modulable=True)
        i_hidden_matrix = Parameter(modulable=True)
        i_gate_func = Parameter(default_value=Logistic())

        # forget gate
        f_input_matrix = Parameter(modulable=True)
        f_hidden_matrix = Parameter(modulable=True)
        f_gate_func = Parameter(default_value=Logistic())

        # candidate cell update
        g_input_matrix = Parameter(modulable=True)
        g_hidden_matrix = Parameter(modulable=True)
        g_gate_func = Parameter(default_value=Tanh())

        # output gate
        o_input_matrix = Parameter(modulable=True)
        o_hidden_matrix = Parameter(modulable=True)
        o_gate_func = Parameter(default_value=Logistic())

        # applied to the new cell state before it is scaled by the output gate
        h_gate_func = Parameter(default_value=Tanh())

    def _instantiate_attributes_before_function(self, function=None, context=None):
        """Build the weight matrices and resize the gate functions.

        Sizes are taken from ``self.variable``: item 0 gives the input length,
        item 1 the hidden length.  Any matrix attribute that is ``None`` is
        specified as `RANDOM_CONNECTIVITY_MATRIX`; every matrix attribute is
        then replaced by the result of ``get_matrix``.
        """
        n_input = len(self.variable[0])
        n_hidden = len(self.variable[1])

        # (attribute names, number of columns); every matrix has ``n_hidden`` rows.
        # Input matrices are processed before hidden ones, as listed.
        matrix_groups = (
            (("i_input_matrix", "f_input_matrix", "g_input_matrix", "o_input_matrix"), n_input),
            (("i_hidden_matrix", "f_hidden_matrix", "g_hidden_matrix", "o_hidden_matrix"), n_hidden),
        )
        for attr_names, n_cols in matrix_groups:
            for attr_name in attr_names:
                spec = getattr(self, attr_name, None)
                if spec is None:
                    spec = RANDOM_CONNECTIVITY_MATRIX
                setattr(self, attr_name, get_matrix(spec, n_hidden, n_cols, context=context))

        # Give every gate function a zero variable/value of hidden length.
        for attr_name in ("i_gate_func", "f_gate_func", "g_gate_func", "o_gate_func", "h_gate_func"):
            gate_fn = getattr(self, attr_name)
            for field in ("variable", "value"):
                # a fresh array for each assignment, so no two fields share storage
                setattr(gate_fn, "default_" + field, np.zeros(n_hidden))
                setattr(gate_fn.defaults, field, np.zeros(n_hidden))
                setattr(gate_fn, field, np.zeros(n_hidden))

    def _function(self,
                  variable=None,
                  context=None,
                  params=None,
                  ):
        """Compute one LSTM step.

        Arguments
        ---------
        variable : sequence
            ``variable[0]`` is the input, ``variable[1]`` the previous hidden
            state and ``variable[2]`` the previous cell state.
        context
            passed through to ``_get_current_function_param``.
        params
            not used by this method.

        Returns
        -------
        list
            ``[h_t, c_t]``: the new hidden state and the new cell state.
        """
        x_t = variable[0]
        h_prev = variable[1]
        c_prev = variable[2]

        def current(param_name):
            return self._get_current_function_param(param_name, context=context)

        def gate(input_matrix_name, hidden_matrix_name, func_name):
            # activation(W_x . x_t + W_h . h_prev)
            w_x = current(input_matrix_name)
            w_h = current(hidden_matrix_name)
            activation = current(func_name)
            return activation(np.matmul(w_x, x_t) + np.matmul(w_h, h_prev))

        i_t = gate("i_input_matrix", "i_hidden_matrix", "i_gate_func")
        f_t = gate("f_input_matrix", "f_hidden_matrix", "f_gate_func")
        g_t = gate("g_input_matrix", "g_hidden_matrix", "g_gate_func")

        # new cell state: keep part of the old state, add part of the candidate
        c_t = np.multiply(f_t, c_prev) + np.multiply(i_t, g_t)

        o_t = gate("o_input_matrix", "o_hidden_matrix", "o_gate_func")

        # new hidden state
        squash = current("h_gate_func")
        h_t = np.multiply(o_t, squash(c_t))

        return [h_t, c_t]

    def _gen_llvm_function_body(self, ctx, builder, params, state, arg_in, arg_out, *, tags:frozenset):
        """Emit LLVM IR for the same computation as `_function`.

        ``arg_in`` is indexed like ``variable`` (items 0, 1, 2) and the two
        results are stored into items 0 (``h_t``) and 1 (``c_t``) of
        ``arg_out``.  Returns ``builder``.
        """
        # NOTE(review): builtin semantics are inferred from their names and the
        # arguments passed here -- confirm against the builtin definitions.
        mxm_builtin = ctx.import_llvm_function("__pnl_builtin_mxm")
        add_builtin = ctx.import_llvm_function("__pnl_builtin_vec_add")
        hadamard_builtin = ctx.import_llvm_function("__pnl_builtin_vec_hadamard")

        zero = ctx.int32_ty(0)

        def first_elem(array_ptr):
            # pointer to element 0 of the array that ``array_ptr`` points to
            return builder.gep(array_ptr, [zero, zero])

        x_t, h_prev, c_prev = (builder.gep(arg_in, [zero, ctx.int32_ty(item)]) for item in range(3))

        # every temporary vector has the type of the hidden state
        hidden_ty = h_prev.type.pointee

        def mat_vec(matrix, vector):
            # allocate a temporary and fill it through the mxm builtin
            result = builder.alloca(hidden_ty)
            result_ptr = first_elem(result)
            n_rows = len(matrix.type.pointee)
            n_cols = len(matrix.type.pointee.elements[0])
            matrix_ptr = builder.gep(matrix, [zero, zero, zero])
            vector_ptr = first_elem(vector)

            builder.call(mxm_builtin, [matrix_ptr,
                                       vector_ptr,
                                       ctx.int32_ty(n_rows),
                                       ctx.int32_ty(n_cols),
                                       ctx.int32_ty(1),
                                       result_ptr])
            return result

        def elementwise(lhs, rhs):
            # allocate a temporary and fill it through the hadamard builtin
            result = builder.alloca(hidden_ty)
            result_ptr = first_elem(result)
            length = len(lhs.type.pointee)
            lhs_ptr = first_elem(lhs)
            rhs_ptr = first_elem(rhs)

            builder.call(hadamard_builtin, [lhs_ptr,
                                            rhs_ptr,
                                            ctx.int32_ty(length),
                                            result_ptr])
            return result

        def sum_of_products(a1, b1, a2, b2, mul_op=mat_vec):
            # mul_op(a1, b1) + mul_op(a2, b2); the first product's buffer is
            # passed as both an input and the output of the add builtin
            acc = mul_op(a1, b1)
            other = mul_op(a2, b2)
            acc_ptr = first_elem(acc)
            other_ptr = first_elem(other)
            builder.call(add_builtin, [acc_ptr,
                                       other_ptr,
                                       ctx.int32_ty(len(a1.type.pointee)),
                                       acc_ptr])
            return acc

        def apply_gate_func(func_id, in_vec, out_vec):
            # call the compiled form of the function stored in attribute ``func_id``
            func_params = pnlvm.helpers.get_param_ptr(builder, self, params, func_id)
            func_state = pnlvm.helpers.get_state_ptr(builder, self, state, func_id)

            compiled = ctx.import_llvm_function(getattr(self, func_id), tags=tags)
            builder.call(compiled, [func_params, func_state, in_vec, out_vec])

        def gate(input_matrix_id, hidden_matrix_id, func_id):
            # func(W_x . x_t + W_h . h_prev), activation applied in place
            w_x = pnlvm.helpers.get_param_ptr(builder, self, params, input_matrix_id)
            w_h = pnlvm.helpers.get_param_ptr(builder, self, params, hidden_matrix_id)
            vec = sum_of_products(w_x, x_t, w_h, h_prev)
            apply_gate_func(func_id, vec, vec)
            return vec

        i_t = gate('i_input_matrix', 'i_hidden_matrix', "i_gate_func")
        f_t = gate('f_input_matrix', 'f_hidden_matrix', "f_gate_func")
        g_t = gate('g_input_matrix', 'g_hidden_matrix', "g_gate_func")

        # new cell state: f_t * c_prev + i_t * g_t
        c_t = sum_of_products(f_t, c_prev, i_t, g_t, mul_op=elementwise)

        o_t = gate('o_input_matrix', 'o_hidden_matrix', "o_gate_func")

        # new hidden state: o_t * h_gate_func(c_t)
        h_t = builder.alloca(hidden_ty)
        apply_gate_func("h_gate_func", c_t, h_t)
        h_t = elementwise(o_t, h_t)

        # write h_t to output item 0 and c_t to output item 1
        for slot, vec in enumerate((h_t, c_t)):
            builder.store(builder.load(vec), builder.gep(arg_out, [zero, ctx.int32_ty(slot)]))

        return builder
2534 | 2741 | # **********************************************************************************************************************
|
2535 | 2742 | # LinearMatrix
|
2536 | 2743 | # **********************************************************************************************************************
|
|
0 commit comments