Skip to content

Commit

Permalink
corrected unsupported numpy inplace overwrite #168
Browse files Browse the repository at this point in the history
  • Loading branch information
jacanchaplais committed Nov 23, 2023
1 parent 0ef34c4 commit 892c3e4
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions graphicle/calculate.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,16 +665,18 @@ def _assignment_cost(
xy_pol_2: base.ComplexVector,
) -> base.DoubleVector:
dist_matrix = _delta_R(rapidity_1, rapidity_2, xy_pol_1, xy_pol_2)
pt_2 = rapidity_2 # recycle memory buffer for transverse momenta
for pol_idx, pol_val in enumerate(xy_pol_2):
pt_2[pol_idx] = abs(pol_val)
var_pt_recip = 1.0 / np.var(pt_2)
num_partons = dist_matrix.shape[0]
pt_2_cache = np.abs(xy_pol_2, out=rapidity_2)
var_pt_recip = 1.0 / np.var(pt_2_cache)
for parton_idx in nb.prange(num_partons):
row = dist_matrix[parton_idx, :]
pt_1 = abs(xy_pol_1[parton_idx])
pt_1_val = abs(xy_pol_1[parton_idx])
var_dR_recip = 1.0 / np.var(row)
for jet_idx, (dR_val, pt_2) in enumerate(zip(row, pt_2_cache)):
dpt = pt_1 - pt_2
for jet_idx, (dR_val, pt_2_val) in enumerate(zip(row, pt_2)):
dpt = pt_1_val - pt_2_val
dR_cost = math.expm1(-0.5 * var_dR_recip * dR_val * dR_val)
pt_cost = math.expm1(-0.5 * var_pt_recip * dpt * dpt)
row[jet_idx] = math.hypot(dR_cost, pt_cost)
row[jet_idx] = -(dR_cost + pt_cost)
return dist_matrix

0 comments on commit 892c3e4

Please sign in to comment.