diff --git a/graphicle/calculate.py b/graphicle/calculate.py index f769c1a..6d22826 100644 --- a/graphicle/calculate.py +++ b/graphicle/calculate.py @@ -654,6 +654,27 @@ def aggregate_momenta( return momentum_class(list(it.chain.from_iterable(pmu_sums))) -@nb.vectorize("float64(float64, float64)") -def _pt_distance(pt_1: float, pt_2: float) -> float: - return -math.expm1(-0.5 * pow((pt_1 - pt_2) / min(pt_1, pt_2), 2)) +@nb.njit( + "float64[:, :](float64[:], float64[:], complex128[:], complex128[:])", + parallel=True, +) +def _assignment_cost( + rapidity_1: base.DoubleVector, + rapidity_2: base.DoubleVector, + xy_pol_1: base.ComplexVector, + xy_pol_2: base.ComplexVector, +) -> base.DoubleVector: + dist_matrix = _delta_R(rapidity_1, rapidity_2, xy_pol_1, xy_pol_2) + num_partons = dist_matrix.shape[0] + pt_2_cache = np.abs(xy_pol_2) + var_pt_recip = 1.0 / np.var(pt_2_cache) + for parton_idx in nb.prange(num_partons): + row = dist_matrix[parton_idx, :] + pt_1 = abs(xy_pol_1[parton_idx]) + var_dR_recip = 1.0 / np.var(row) + for jet_idx, (dR_val, pt_2) in enumerate(zip(row, pt_2_cache)): + dpt = pt_1 - pt_2 + dR_cost = math.expm1(-0.5 * var_dR_recip * dR_val * dR_val) + pt_cost = math.expm1(-0.5 * var_pt_recip * dpt * dpt) + row[jet_idx] = math.hypot(dR_cost, pt_cost) + return dist_matrix diff --git a/graphicle/select.py b/graphicle/select.py index 9f06bf7..61ce626 100644 --- a/graphicle/select.py +++ b/graphicle/select.py @@ -1162,7 +1162,9 @@ def clusters( def arg_closest( - focus: gcl.MomentumArray, candidate: gcl.MomentumArray + focus: gcl.MomentumArray, + candidate: gcl.MomentumArray, + num_threads: int = 1, ) -> ty.List[int]: """Assigns four-momenta elements in ``candidate`` to the nearest four-momenta elements in ``focus``. Elements in ``candidate`` are @@ -1183,6 +1185,9 @@ def arg_closest( candidate : MomentumArray Four-momenta of candidate objects to draw from until ``focus`` objects have each received an assignment. + num_threads : int + Number of threads to parallelise the cost matrix computation + over. Default is 1. Returns ------- @@ -1216,13 +1221,14 @@ def arg_closest( Systems*, 52(4):1679-1696, August 2016, :doi:`10.1109/TAES.2016.140952` """ - dist_matrix = focus.delta_R(candidate, pseudo=False) - pt_dist = gcl.calculate._pt_distance.outer(focus.pt, candidate.pt) - np.divide( - dist_matrix, dist_matrix.max(axis=1, keepdims=True), out=dist_matrix - ) - np.hypot(dist_matrix, pt_dist, out=dist_matrix) - _, idxs = opt.linear_sum_assignment(dist_matrix) + with gcl.calculate._thread_scope(num_threads): + cost_matrix = gcl.calculate._assignment_cost( + focus.rapidity, + candidate.rapidity, + focus._xy_pol, + candidate._xy_pol, + ) + _, idxs = opt.linear_sum_assignment(cost_matrix) return idxs.tolist()