Skip to content

Commit

Permalink
aln: fix index of input from i to i-1
Browse files Browse the repository at this point in the history
  • Loading branch information
caglorithm committed Aug 30, 2023
1 parent 846c3f7 commit 9cf9480
Showing 1 changed file with 22 additions and 24 deletions.
46 changes: 22 additions & 24 deletions neurolib/models/aln/timeIntegration.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

def timeIntegration(params):
"""Sets up the parameters for time integration
Return:
rates_exc: N*L array : containing the exc. neuron rates in kHz time series of the N nodes
rates_inh: N*L array : containing the inh. neuron rates in kHz time series of the N nodes
Expand Down Expand Up @@ -397,12 +397,11 @@ def timeIntegration_njit_elementwise(
noise_exc,
noise_inh,
):

# squared Jee_max
sq_Jee_max = Jee_max ** 2
sq_Jei_max = Jei_max ** 2
sq_Jie_max = Jie_max ** 2
sq_Jii_max = Jii_max ** 2
sq_Jee_max = Jee_max**2
sq_Jei_max = Jei_max**2
sq_Jie_max = Jie_max**2
sq_Jii_max = Jii_max**2

# initialize so we don't get an error when returning
rd_exc_rhs = 0.0
Expand All @@ -416,7 +415,6 @@ def timeIntegration_njit_elementwise(

### integrate ODE system:
for i in range(startind, startind + len(t)):

if not distr_delay:
# Get the input from one node into another from the rates at time t - connection_delay - 1
# remark: assume Kie == Kee and Kei == Kii
Expand All @@ -430,13 +428,12 @@ def timeIntegration_njit_elementwise(

# loop through all the nodes
for no in range(N):

# To save memory, noise is saved in the rates array
noise_exc[no] = rates_exc[no, i]
noise_inh[no] = rates_inh[no, i]

mue = Jee_max * seem[no] + Jei_max * seim[no] + mue_ou[no] + ext_exc_current[no, i]
mui = Jie_max * siem[no] + Jii_max * siim[no] + mui_ou[no] + ext_inh_current[no, i]
mue = Jee_max * seem[no] + Jei_max * seim[no] + mue_ou[no] + ext_exc_current[no, i - 1]
mui = Jie_max * siem[no] + Jii_max * siim[no] + mui_ou[no] + ext_inh_current[no, i - 1]

# compute row sum of Cmat*rd_exc and Cmat**2*rd_exc
rowsum = 0
Expand All @@ -447,33 +444,35 @@ def timeIntegration_njit_elementwise(

# z1: weighted sum of delayed rates, weights=c*K
z1ee = (
cee * Ke * rd_exc[no, no] + c_gl * Ke_gl * rowsum + c_gl * Ke_gl * ext_exc_rate[no, i]
cee * Ke * rd_exc[no, no] + c_gl * Ke_gl * rowsum + c_gl * Ke_gl * ext_exc_rate[no, i - 1]
) # rate from other regions + exc_ext_rate
z1ei = cei * Ki * rd_inh[no]
z1ie = (
cie * Ke * rd_exc[no, no] + c_gl * Ke_gl * ext_inh_rate[no, i]
cie * Ke * rd_exc[no, no] + c_gl * Ke_gl * ext_inh_rate[no, i - 1]
) # first test of external rate input to inh. population
z1ii = cii * Ki * rd_inh[no]
# z2: weighted sum of delayed rates, weights=c^2*K (see thesis last ch.)
z2ee = (
cee ** 2 * Ke * rd_exc[no, no] + c_gl ** 2 * Ke_gl * rowsumsq + c_gl ** 2 * Ke_gl * ext_exc_rate[no, i]
cee**2 * Ke * rd_exc[no, no]
+ c_gl**2 * Ke_gl * rowsumsq
+ c_gl**2 * Ke_gl * ext_exc_rate[no, i - 1]
)
z2ei = cei ** 2 * Ki * rd_inh[no]
z2ei = cei**2 * Ki * rd_inh[no]
z2ie = (
cie ** 2 * Ke * rd_exc[no, no] + c_gl ** 2 * Ke_gl * ext_inh_rate[no, i]
cie**2 * Ke * rd_exc[no, no] + c_gl**2 * Ke_gl * ext_inh_rate[no, i - 1]
) # external rate input to inh. population
z2ii = cii ** 2 * Ki * rd_inh[no]
z2ii = cii**2 * Ki * rd_inh[no]

sigmae = np.sqrt(
2 * sq_Jee_max * seev[no] * tau_se * taum / ((1 + z1ee) * taum + tau_se)
+ 2 * sq_Jei_max * seiv[no] * tau_si * taum / ((1 + z1ei) * taum + tau_si)
+ sigmae_ext ** 2
+ sigmae_ext**2
) # mV/sqrt(ms)

sigmai = np.sqrt(
2 * sq_Jie_max * siev[no] * tau_se * taum / ((1 + z1ie) * taum + tau_se)
+ 2 * sq_Jii_max * siiv[no] * tau_si * taum / ((1 + z1ii) * taum + tau_si)
+ sigmai_ext ** 2
+ sigmai_ext**2
) # mV/sqrt(ms)

if not filter_sigma:
Expand Down Expand Up @@ -531,10 +530,10 @@ def timeIntegration_njit_elementwise(
seim_rhs = ((1 - seim[no]) * z1ei - seim[no]) / tau_si
siem_rhs = ((1 - siem[no]) * z1ie - siem[no]) / tau_se
siim_rhs = ((1 - siim[no]) * z1ii - siim[no]) / tau_si
seev_rhs = ((1 - seem[no]) ** 2 * z2ee + (z2ee - 2 * tau_se * (z1ee + 1)) * seev[no]) / tau_se ** 2
seiv_rhs = ((1 - seim[no]) ** 2 * z2ei + (z2ei - 2 * tau_si * (z1ei + 1)) * seiv[no]) / tau_si ** 2
siev_rhs = ((1 - siem[no]) ** 2 * z2ie + (z2ie - 2 * tau_se * (z1ie + 1)) * siev[no]) / tau_se ** 2
siiv_rhs = ((1 - siim[no]) ** 2 * z2ii + (z2ii - 2 * tau_si * (z1ii + 1)) * siiv[no]) / tau_si ** 2
seev_rhs = ((1 - seem[no]) ** 2 * z2ee + (z2ee - 2 * tau_se * (z1ee + 1)) * seev[no]) / tau_se**2
seiv_rhs = ((1 - seim[no]) ** 2 * z2ei + (z2ei - 2 * tau_si * (z1ei + 1)) * seiv[no]) / tau_si**2
siev_rhs = ((1 - siem[no]) ** 2 * z2ie + (z2ie - 2 * tau_se * (z1ie + 1)) * siev[no]) / tau_se**2
siiv_rhs = ((1 - siim[no]) ** 2 * z2ii + (z2ii - 2 * tau_si * (z1ii + 1)) * siiv[no]) / tau_si**2

# -------------- integration --------------

Expand Down Expand Up @@ -596,7 +595,6 @@ def interpolate_values(table, xid1, yid1, dxid, dyid):

@numba.njit(locals={"idxX": numba.int64, "idxY": numba.int64})
def lookup_no_interp(x, dx, xi, y, dy, yi):

"""
Return the indices for the closest values for a look-up table
Choose the closest point in the grid
Expand Down Expand Up @@ -636,9 +634,9 @@ def lookup_no_interp(x, dx, xi, y, dy, yi):

return idxX, idxY


@numba.njit(locals={"xid1": numba.int64, "yid1": numba.int64, "dxid": numba.float64, "dyid": numba.float64})
def fast_interp2_opt(x, dx, xi, y, dy, yi):

"""
Returns the values needed for interpolation:
- bilinear (2D) interpolation within ranges,
Expand Down

0 comments on commit 9cf9480

Please sign in to comment.