Skip to content

Commit

Permalink
gaussian cost for delta_R as well #168
Browse files Browse the repository at this point in the history
  • Loading branch information
jacanchaplais committed Nov 23, 2023
1 parent cfe914f commit 8ca9b1e
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 11 deletions.
27 changes: 24 additions & 3 deletions graphicle/calculate.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,6 +654,27 @@ def aggregate_momenta(
return momentum_class(list(it.chain.from_iterable(pmu_sums)))


@nb.vectorize("float64(float64, float64)")
def _pt_distance(pt_1: float, pt_2: float) -> float:
return -math.expm1(-0.5 * pow((pt_1 - pt_2) / min(pt_1, pt_2), 2))
@nb.njit(
"float64[:, :](float64[:], float64[:], complex128[:], complex128[:])",
parallel=True,
)
def _assignment_cost(
rapidity_1: base.DoubleVector,
rapidity_2: base.DoubleVector,
xy_pol_1: base.ComplexVector,
xy_pol_2: base.ComplexVector,
) -> base.DoubleVector:
dist_matrix = _delta_R(rapidity_1, rapidity_2, xy_pol_1, xy_pol_2)
num_partons = dist_matrix.shape[0]
pt_2_cache = np.abs(xy_pol_2)
var_pt_recip = 1.0 / np.var(pt_2_cache)
for parton_idx in nb.prange(num_partons):
row = dist_matrix[parton_idx, :]
pt_1 = abs(xy_pol_1[parton_idx])
var_dR_recip = 1.0 / np.var(row)
for jet_idx, (dR_val, pt_2) in enumerate(zip(row, pt_2_cache)):
dpt = pt_1 - pt_2
dR_cost = math.expm1(-0.5 * var_dR_recip * dR_val * dR_val)
pt_cost = math.expm1(-0.5 * var_pt_recip * dpt * dpt)
row[jet_idx] = math.hypot(dR_cost, pt_cost)
return dist_matrix
22 changes: 14 additions & 8 deletions graphicle/select.py
Original file line number Diff line number Diff line change
Expand Up @@ -1162,7 +1162,9 @@ def clusters(


def arg_closest(
focus: gcl.MomentumArray, candidate: gcl.MomentumArray
focus: gcl.MomentumArray,
candidate: gcl.MomentumArray,
num_threads: int = 1,
) -> ty.List[int]:
"""Assigns four-momenta elements in ``candidate`` to the nearest
four-momenta elements in ``focus``. Elements in ``candidate`` are
Expand All @@ -1183,6 +1185,9 @@ def arg_closest(
candidate : MomentumArray
Four-momenta of candidate objects to draw from until ``focus``
objects have each received an assignment.
num_threads : int
Number of threads to parallelise the cost matrix computation
over. Default is 1.
Returns
-------
Expand Down Expand Up @@ -1216,13 +1221,14 @@ def arg_closest(
Systems*, 52(4):1679-1696, August 2016,
:doi:`10.1109/TAES.2016.140952`
"""
dist_matrix = focus.delta_R(candidate, pseudo=False)
pt_dist = gcl.calculate._pt_distance.outer(focus.pt, candidate.pt)
np.divide(
dist_matrix, dist_matrix.max(axis=1, keepdims=True), out=dist_matrix
)
np.hypot(dist_matrix, pt_dist, out=dist_matrix)
_, idxs = opt.linear_sum_assignment(dist_matrix)
with gcl.calculate._thread_scope(num_threads):
cost_matrix = gcl.calculate._assignment_cost(
focus.rapidity,
candidate.rapidity,
focus._xy_pol,
candidate._xy_pol,
)
_, idxs = opt.linear_sum_assignment(cost_matrix)
return idxs.tolist()


Expand Down

0 comments on commit 8ca9b1e

Please sign in to comment.