diff --git a/graphicle/select.py b/graphicle/select.py index 1d21292..3c7114d 100644 --- a/graphicle/select.py +++ b/graphicle/select.py @@ -1199,7 +1199,8 @@ def arg_closest( def monte_carlo_tag( particles: gcl.ParticleSet, - cluster_masks: ty.List[gcl.MaskArray], + cluster_masks: ty.Sequence[gcl.MaskArray], + clustered_pmu: ty.Optional[gcl.MomentumArray] = None, intermediate: bool = False, outgoing: bool = True, sign_sensitive: bool = False, @@ -1214,13 +1215,23 @@ def monte_carlo_tag( .. versionadded:: 0.2.14 + .. versionchanged:: 0.3.6 + Added ``clustered_pmu`` parameter, enabling tagging on clusters + with detector-level cuts. + Parameters ---------- particles : ParticleSet Monte-Carlo particle data record for the whole event. - cluster_masks : list[MaskArray] - List of boolean masks identifying which particles belong to each - of the clusterings. These are defined over the final particles. + cluster_masks : Sequence[MaskArray] + Boolean masks identifying which particles belong to each of the + clusterings. These are defined over the final particles, or + clustered_pmu, see below. + clustered_pmu : MomentumArray, optional + MomentumArray containing the data as passed to the clustering + algorithm. This is useful for when cuts have been applied before + clustering. If unset, the final state momentum for the whole + event will be assumed. intermediate : bool If ``True`` includes partons from the intermediate stage of the hard process. Default is ``False``. @@ -1250,6 +1261,10 @@ def monte_carlo_tag( If ``intermediate`` and ``outgoing`` are simultaneously set to ``False``, or ``blacklist`` and ``whitelist`` are simultaneously not ``None``. + ValueError + If cluster_masks is empty. Additionally, if the elements have a + size mismatch with either the number of particles in the final + state, or the length of clustered_pmu when passed. IndexError If after applying ``blacklist`` or ``whitelist``, no matching partons remain in the hard process. @@ -1295,6 +1310,8 @@ def monte_carlo_tag( ... np.sum(final_pmu[tagged_clusters], axis=0).mass array([163.33889956]) """ + if not cluster_masks: + raise ValueError("cluster_masks is an empty sequence.") portions = [] if outgoing: portions.append("outgoing") @@ -1319,10 +1336,17 @@ def monte_carlo_tag( raise IndexError("No partons matching filters found.") hard_pmu = hard_pmu[hard_mask] hard_pdg = hard_pdg[hard_mask] - cluster_pmu = gcl.calculate.aggregate_momenta( - particles.pmu[particles.final], cluster_masks - ) - idxs = arg_closest(hard_pmu, cluster_pmu) + ref_length = "clustered_pmu" + if clustered_pmu is None: + clustered_pmu = particles.pmu[particles.final] + ref_length = "particles.final" + if len(clustered_pmu) != len(cluster_masks[0]): + raise ValueError( + "shape mismatch: length of elements in cluster_masks must be the " + f"same as the length of {ref_length}." + ) + jets_pmu = gcl.calculate.aggregate_momenta(clustered_pmu, cluster_masks) + idxs = arg_closest(hard_pmu, jets_pmu) tagged_clusters = op.itemgetter(*idxs)(cluster_masks) if len(idxs) == 1: tagged_clusters = (tagged_clusters,)