diff --git a/spine/data/neutrino.py b/spine/data/neutrino.py index 59f0a075..8aea01da 100644 --- a/spine/data/neutrino.py +++ b/spine/data/neutrino.py @@ -34,7 +34,7 @@ class Neutrino(PosDataBase): pdg_code : int PDG code of the neutrino lepton_pdg_code : int - PDF code of the outgoing lepton + PDG code of the outgoing lepton current_type : int Enumerated current type of the neutrino interaction interaction_mode : int diff --git a/spine/data/optical.py b/spine/data/optical.py index 3222cfc7..1ce5fbc2 100644 --- a/spine/data/optical.py +++ b/spine/data/optical.py @@ -21,7 +21,7 @@ class Flash(PosDataBase): id : int Index of the flash in the list volume_id : int - Index of the optical volume in which the flahs was recorded + Index of the optical volume in which the flash was recorded time : float Time with respect to the trigger in microseconds time_width : float @@ -111,3 +111,66 @@ def from_larcv(cls, flash): time_abs=flash.absTime(), time_width=flash.timeWidth(), total_pe=flash.TotalPE(), pe_per_ch=pe_per_ch, center=center, width=width) + + def merge(self, other): + """Merge another flash into this one. + + The merging strategy proceeds as follows: + - The earlier flash takes precedence over the later flash as far as + all timing-related information is concerned (flash time, etc.) 
+ - The combined flash centroid is produced by taking the weighted average + of the two existing flash centroids + - The PE values in each light collection system are added together, so + is the total PE value of the combined flash + + Parameters + ---------- + other : Flash + Flash to merge into this one + """ + # Check that the position units are the same + assert self.units == other.units, ( + "The units of the flash to be merged do not match.") + + # Determine the flash window end points (to merge time widths later) + end_i, end_j = self.time + self.time_width, other.time + other.time_width + + # If the other flash happened first, update the timing information + if self.time > other.time: + self.time = other.time + self.time_abs = other.time_abs + self.on_beam_time = other.on_beam_time + self.frame = other.frame + self.in_beam_frame = other.in_beam_frame + + # Take the union of the two time widths as the new combined width + self.time_width = max(end_i, end_j) - self.time + + # Take the weighted average of the centroids to compute the new one + valid_mask = (self.width > 0.) & (other.width > 0.) + + w_i, w_j = 1./self.width**2, 1./other.width**2 + self.center = (w_i*self.center + w_j*other.center)/(w_i + w_j) + + self.width = 1./np.sqrt(w_i + w_j) + self.width[~valid_mask] = -1. 
+
+        # Compute the new total PE and fast light component to total ratio
+        t_i, t_j = self.total_pe, other.total_pe
+        self.total_pe = t_i + t_j
+
+        r_i, r_j = self.fast_to_total, other.fast_to_total
+        self.fast_to_total = (r_i*t_i + r_j*t_j)/(t_i + t_j)
+
+        # Merge the PE count in each PMT
+        pe_per_ch = np.zeros(
+                max(len(self.pe_per_ch), len(other.pe_per_ch)),
+                dtype=self.pe_per_ch.dtype)
+        pe_per_ch[:len(self.pe_per_ch)] += self.pe_per_ch
+        pe_per_ch[:len(other.pe_per_ch)] += other.pe_per_ch
+
+        self.pe_per_ch = pe_per_ch
+
+        # The new volume ID is invalid if the two original volumes differ
+        if self.volume_id != other.volume_id:
+            self.volume_id = -1
diff --git a/spine/data/out/base.py b/spine/data/out/base.py
index 7a7688ba..bec968c9 100644
--- a/spine/data/out/base.py
+++ b/spine/data/out/base.py
@@ -89,6 +89,11 @@ class OutBase(PosDataBase):
     # Attributes that must not be stored to file when storing lite files
     _lite_skip_attrs = ('index',)
 
+    def reset_match(self):
+        """Resets the reco/truth matching information for the object."""
+        self.is_matched = False
+        self.match_ids = np.empty(0, dtype=np.int64)
+
     @property
     def size(self):
         """Total number of voxels that make up the object.
diff --git a/spine/data/out/interaction.py b/spine/data/out/interaction.py
index d498e91a..5ec8cbd6 100644
--- a/spine/data/out/interaction.py
+++ b/spine/data/out/interaction.py
@@ -120,6 +120,28 @@ def __str__(self):
 
         return info
 
+    def reset_flash_match(self, typed=True):
+        """Reset all the flash matching attributes.
+
+        Parameters
+        ----------
+        typed : bool, default True
+            If `True`, the underlying arrays are reset to typed empty arrays
+        """
+        self.is_flash_matched = False
+        self.flash_total_pe = -1.
+        self.flash_hypo_pe = -1.
+ if typed: + self.flash_ids = np.empty(0, dtype=np.int32) + self.flash_volume_ids = np.empty(0, dtype=np.int32) + self.flash_times = np.empty(0, dtype=np.float32) + self.flash_scores = np.empty(0, dtype=np.float32) + else: + self.flash_ids = [] + self.flash_volume_ids = [] + self.flash_times = [] + self.flash_scores = [] + @property def primary_particles(self): """List of primary particles associated with this interaction. diff --git a/spine/data/out/particle.py b/spine/data/out/particle.py index b81fd786..bee912ea 100644 --- a/spine/data/out/particle.py +++ b/spine/data/out/particle.py @@ -196,13 +196,6 @@ def p(self): def p(self, p): pass - def unmatch(self): - """ - Unmatch the particle from its reco or truth particle match. - """ - self.match_ids = [] - self.is_matched = False - @dataclass(eq=False) @inherit_docstring(RecoBase, ParticleBase) @@ -224,6 +217,8 @@ class RecoParticle(ParticleBase, RecoBase): interaction vertex position in cm start_dedx : float dE/dx around a user-defined neighborhood of the start point in MeV/cm + end_dedx : float + dE/dx around a user-defined neighborhood of the end point in MeV/cm start_straightness : float Explained variance ratio of the beginning of the particle directional_spread : float @@ -238,6 +233,7 @@ class RecoParticle(ParticleBase, RecoBase): ppn_points: np.ndarray = None vertex_distance: float = -1. start_dedx: float = -1. + end_dedx: float = -1. start_straightness: float = -1. directional_spread: float = -1. 
axial_spread: float = -np.inf diff --git a/spine/io/collate.py b/spine/io/collate.py index e18b2138..76421105 100644 --- a/spine/io/collate.py +++ b/spine/io/collate.py @@ -253,7 +253,7 @@ def stack_feat_tensors(self, batch, key): # Dispatch if not self.split or sources is None: tensor = np.concatenate([sample[key].features for sample in batch]) - counts = [len(sample[key]) for sample in batch] + counts = [len(sample[key].features) for sample in batch] else: batch_size = len(batch) diff --git a/spine/io/parse/misc.py b/spine/io/parse/misc.py index c12fac6a..23e4634d 100644 --- a/spine/io/parse/misc.py +++ b/spine/io/parse/misc.py @@ -4,7 +4,6 @@ - :class:`Meta2DParser` - :class:`Meta3DParser` - :class:`RunInfoParser` -- :class:`OpFlashParser` - :class:`CRTHitParser` - :class:`TriggerParser` """ @@ -13,6 +12,7 @@ from spine.data import Meta, RunInfo, Flash, CRTHit, Trigger from spine.utils.conditional import larcv +from spine.utils.optical import FlashMerger from .base import ParserBase from .data import ParserObjectList @@ -67,7 +67,7 @@ def __init__(self, projection_id=None, **kwargs): projection_id : int, optional Projection ID to get the 2D image from (if fetching from 2D) **kwargs : dict, optional - Data product arguments to be passed to the `process` function + data product arguments to be passed to the `process` function """ # Initialize the parent class super().__init__(**kwargs) @@ -186,6 +186,24 @@ class FlashParser(ParserBase): # Type of object(s) returned by the parser returns = 'object_list' + def __init__(self, merge=None, **kwargs): + """Initialize the flash parser. 
+
+        Parameters
+        ----------
+        merge : dict, optional
+            Flash merging configuration
+        **kwargs : dict, optional
+            data product arguments to be passed to the `process` function
+        """
+        # Initialize the parent class
+        super().__init__(**kwargs)
+
+        # Initialize the flash merging class, if needed
+        self.merger = None
+        if merge is not None:
+            self.merger = FlashMerger(**merge)
+
     def __call__(self, trees):
         """Parse one entry.
 
@@ -238,6 +256,10 @@ def process(self, flash_event=None, flash_event_list=None):
             flashes.append(flash)
             idx += 1
 
+        # If requested, merge flashes which match in time
+        if self.merger is not None:
+            flashes, _ = self.merger(flashes)
+
         return ParserObjectList(flashes, Flash())
diff --git a/spine/io/parse/sparse.py b/spine/io/parse/sparse.py
index 2a809009..74e3a828 100644
--- a/spine/io/parse/sparse.py
+++ b/spine/io/parse/sparse.py
@@ -13,7 +13,7 @@
 from spine.data import Meta
 from spine.utils.globals import GHOST_SHP, SHAPE_PREC
-from spine.utils.ghost import compute_rescaled_charge
+from spine.utils.ghost import ChargeRescaler
 from spine.utils.conditional import larcv
 
 from .base import ParserBase
@@ -457,9 +457,8 @@ def __init__(self, dtype, collection_only=False, collection_id=2, **kwargs):
         # Initialize the parent class
         super().__init__(dtype, **kwargs)
 
-        # Store the revelant attributes
-        self.collection_only = collection_only
-        self.collection_id = collection_id
+        # Initialize the charge rescaler
+        self.rescaler = ChargeRescaler(collection_only, collection_id)
 
     def __call__(self, trees):
         """Parse one entry.
@@ -497,10 +496,7 @@ def process_rescale(self, sparse_event_list): # Use individual hit informations to compute a rescaled charge deghost_mask = np.where(tensor.features[:, -1] < GHOST_SHP)[0] - charges = compute_rescaled_charge( - tensor.features[deghost_mask, :-1], - collection_only=self.collection_only, - collection_id=self.collection_id) + charges = self.rescaler.process_single(tensor.features[deghost_mask, :-1]) tensor.features = charges[:, None] diff --git a/spine/model/full_chain.py b/spine/model/full_chain.py index 9c3c5407..f84dff30 100644 --- a/spine/model/full_chain.py +++ b/spine/model/full_chain.py @@ -17,13 +17,14 @@ # TODO: raname it something more generic like ParticleClusterImageClassifier? from spine.data import TensorBatch, IndexBatch, RunInfo + from spine.utils.logger import logger from spine.utils.globals import ( COORD_COLS, VALUE_COL, CLUST_COL, SHAPE_COL, SHOWR_SHP, TRACK_SHP, MICHL_SHP, DELTA_SHP, GHOST_SHP) +from spine.utils.ghost import ChargeRescaler from spine.utils.calib import CalibrationManager from spine.utils.ppn import ParticlePointPredictor -from spine.utils.ghost import compute_rescaled_charge_batch from spine.utils.cluster.label import ClusterLabelAdapter from spine.utils.gnn.cluster import ( form_clusters_batch, get_cluster_label_batch) @@ -164,6 +165,11 @@ def __init__(self, chain, uresnet_deghost=None, uresnet=None, "`uresnet_deghost` configuration block.") self.uresnet_deghost = UResNetSegmentation(uresnet_deghost) + # Initialize the charge rescaling process (adapt to ghost predictions) + if self.charge_rescaling is not None: + self.charge_rescaler = ChargeRescaler( + collection_only=self.charge_rescaling == 'collection') + # Initialize the semantic segmentation model (+ point proposal) if self.segmentation is not None and self.segmentation == 'uresnet': assert (uresnet is not None) ^ (uresnet_ppn is not None), ( @@ -363,8 +369,7 @@ def run_deghosting(self, data, sources=None, seg_label=None, # Rescale the charge, if 
requested if self.charge_rescaling is not None: - charges = compute_rescaled_charge_batch( - data_adapt, self.charge_rescaling == 'collection') + charges = self.charge_rescaler(data_adapt) tensor_deghost = data_adapt.tensor[:, :-6] tensor_deghost[:, VALUE_COL] = charges data_adapt.data = tensor_deghost diff --git a/spine/post/optical/flash_matching.py b/spine/post/optical/flash_matching.py index ed12a491..5b385998 100644 --- a/spine/post/optical/flash_matching.py +++ b/spine/post/optical/flash_matching.py @@ -3,12 +3,14 @@ import numpy as np from warnings import warn +import copy from spine.post.base import PostBase from spine.data.out.base import OutBase from spine.utils.geo import Geometry +from spine.utils.optical import FlashMerger from .barycenter import BarycenterFlashMatcher from .likelihood import LikelihoodFlashMatcher @@ -31,7 +33,8 @@ class FlashMatchProcessor(PostBase): def __init__(self, flash_key, volume, ref_volume_id=None, method='likelihood', detector=None, geometry_file=None, run_mode='reco', truth_point_mode='points', - truth_dep_mode='depositions', parent_path=None, **kwargs): + truth_dep_mode='depositions', parent_path=None, merge=None, + update_flashes=False, **kwargs): """Initialize the flash matching algorithm. Parameters @@ -53,6 +56,11 @@ def __init__(self, flash_key, volume, ref_volume_id=None, parent_path : str, optional Path to the parent directory of the main analysis configuration. This allows for the use of relative paths in the post-processors. 
+ merge : dict, optional + Flash merging configuration + update_flashes : bool, default False + If `True` and merging flashes, replaces the original list of + flashes in place with the list of merged flashes **kwargs : dict Keyword arguments to pass to specific flash matching algorithms """ @@ -85,6 +93,12 @@ def __init__(self, flash_key, volume, ref_volume_id=None, else: raise ValueError(f'Flash matching method not recognized: {method}') + # Initialize the flash merging class, if needed + self.merger = None + if merge is not None: + self.merger = FlashMerger(**merge) + self.update_flashes = update_flashes + def process(self, data): """Find [interaction, flash] pairs. @@ -105,6 +119,8 @@ def process(self, data): The flash optical volume IDs in the flash list - interaction.flash_times: np.ndarray The flash time(s) in microseconds + - interaction.flash_scores: np.ndarray + The flash scores(s) (larger is better) - interaction.flash_total_pe: float Total number of PEs associated with the matched flash(es) - interaction.flash_hypo_pe: float, optional @@ -113,7 +129,31 @@ def process(self, data): # Fetch the optical volume each flash belongs to flashes = data[self.flash_key] volume_ids = np.asarray([f.volume_id for f in flashes]) - + + # Resize the PE vectors to match the optical geometry + # TODO: ideally this should not happen in the flash matcher... 
+ for flash in flashes: + # Reshape the flash based on geometry + pe_per_ch = np.zeros( + self.geo.optical.num_detectors_per_volume, + dtype=flash.pe_per_ch.dtype) + if (self.ref_volume_id is not None and + len(flash.pe_per_ch) > len(pe_per_ch)): + # If the flash spans > 1 optical volume, reshape + lower = flash.volume_id*len(pe_per_ch) + upper = (flash.volume_id + 1)*len(pe_per_ch) + pe_per_ch = flash.pe_per_ch[lower:upper] + + else: + # Otherwise, just pad if it does not fill the full length + pe_per_ch[:len(flash.pe_per_ch)] = flash.pe_per_ch + + flash.pe_per_ch = pe_per_ch + + # Merge flashes based on timing, if requested + if self.merger is not None: + flashes, orig_ids = self.merger(flashes) + # Loop over the optical volumes, run flash matching for k in self.interaction_keys: # Fetch interactions, nothing to do if there are not any @@ -126,41 +166,15 @@ def process(self, data): # Clear previous flash matching information for inter in interactions: - inter.flash_ids = [] - inter.flash_volume_ids = [] - inter.flash_times = [] - inter.flash_scores = [] - if inter.is_flash_matched: - inter.is_flash_matched = False - inter.flash_total_pe = -1. - inter.flash_hypo_pe = -1. 
+ inter.reset_flash_match(typed=False) # Loop over the optical volumes for volume_id in np.unique(volume_ids): # Get the list of flashes associated with this optical volume flashes_v = [] for flash in flashes: - # Skip if the flash is not associated with the right volume - if flash.volume_id != volume_id: - continue - - # Reshape the flash based on geometry - pe_per_ch = np.zeros( - self.geo.optical.num_detectors_per_volume, - dtype=flash.pe_per_ch.dtype) - if (self.ref_volume_id is not None and - len(flash.pe_per_ch) > len(pe_per_ch)): - # If the flash spans > 1 optical volume, reshape - lower = flash.volume_id*len(pe_per_ch) - upper = (flash.volume_id + 1)*len(pe_per_ch) - pe_per_ch = flash.pe_per_ch[lower:upper] - - else: - # Otherwise, just pad if it does not fill the full length - pe_per_ch[:len(flash.pe_per_ch)] = flash.pe_per_ch - - flash.pe_per_ch = pe_per_ch - flashes_v.append(flash) + if flash.volume_id == volume_id: + flashes_v.append(flash) # Crop interactions to only include depositions in the optical volume interactions_v = [] @@ -207,19 +221,26 @@ def process(self, data): if hasattr(match, 'score'): score = float(match.score) - # Append - inter.flash_ids.append(int(flash.id)) - inter.flash_volume_ids.append(int(flash.volume_id)) - inter.flash_times.append(float(flash.time)) - inter.flash_scores.append(score) - if inter.is_flash_matched: + # Update + if not inter.is_flash_matched: + inter.is_flash_matched = True + inter.flash_total_pe = float(flash.total_pe) + inter.flash_hypo_pe = hypo_pe + else: inter.flash_total_pe += float(flash.total_pe) inter.flash_hypo_pe += hypo_pe + if self.merger is not None and not self.update_flashes: + orig_flashes = [data[self.flash_key][i] for i in orig_ids[flash.id]] + inter.flash_ids.extend([f.id for f in orig_flashes]) + inter.flash_volume_ids.extend([f.volume_id for f in orig_flashes]) + inter.flash_times.extend([f.time for f in orig_flashes]) + inter.flash_scores.extend([score for _ in orig_flashes]) else: - 
inter.is_flash_matched = True - inter.flash_total_pe = float(flash.total_pe) - inter.flash_hypo_pe = hypo_pe + inter.flash_ids.append(int(flash.id)) + inter.flash_volume_ids.append(int(flash.volume_id)) + inter.flash_times.append(float(flash.time)) + inter.flash_scores.append(score) # Cast list attributes to numpy arrays for inter in interactions: @@ -227,3 +248,7 @@ def process(self, data): inter.flash_volume_ids = np.asarray(inter.flash_volume_ids, dtype=np.int32) inter.flash_times = np.asarray(inter.flash_times, dtype=np.float32) inter.flash_scores = np.asarray(inter.flash_scores, dtype=np.float32) + + # Return an updated flash list, if requested + if self.update_flashes: + return {self.flash_key: flashes} diff --git a/spine/post/optical/likelihood.py b/spine/post/optical/likelihood.py index 3a6ba672..29559e32 100644 --- a/spine/post/optical/likelihood.py +++ b/spine/post/optical/likelihood.py @@ -12,8 +12,7 @@ class LikelihoodFlashMatcher: See https://github.com/drinkingkazu/OpT0Finder for more details about it. """ - def __init__(self, cfg, detector, parent_path=None, - reflash_merging_window=None, scaling=1., alpha=0.21, + def __init__(self, cfg, detector, parent_path=None, scaling=1., alpha=0.21, recombination_mip=0.65, legacy=False): """Initialize the likelihood-based flash matching algorithm. @@ -25,8 +24,6 @@ def __init__(self, cfg, detector, parent_path=None, Detector to get the geometry from parent_path : str, optional Path to the parent configuration file (allows for relative paths) - reflash_merging_window : float, optional - Maximum time between successive flashes to be considered a reflash scaling : Union[float, str], default 1. 
Global scaling factor for the depositions (can be an expression) alpha : float, default 0.21 @@ -40,7 +37,6 @@ def __init__(self, cfg, detector, parent_path=None, self.initialize_backend(cfg, detector, parent_path) # Get the external parameters - self.reflash_merging_window = reflash_merging_window self.scaling = scaling if isinstance(self.scaling, str): self.scaling = eval(self.scaling) @@ -222,23 +218,6 @@ def make_flash_list(self, flashes): List[Flash_t] List of flashmatch::Flash_t objects """ - # If requested, merge flashes that are compatible in time - if self.reflash_merging_window is not None: - times = [f.time for f in flashes] - perm = np.argsort(times) - new_flashes = [flashes[perm[0]]] - for i in range(1, len(perm)): - prev, curr = perm[i-1], perm[i] - if ((flashes[curr].time - flashes[prev].time) - < self.reflash_merging_window): - # If compatible, simply add up the PEs - pe_v = flashes[prev].pe_per_ch + flashes[curr].pe_per_ch - new_flashes[-1].pe_per_ch = pe_v - else: - new_flashes.append(flashes[curr]) - - flashes = new_flashes - # Loop over the optical flashes from flashmatch import flashmatch flash_v = [] @@ -343,7 +322,6 @@ def get_flash(self, idx, array=False): raise Exception('Flash {idx} does not exist in self.flash_v') - def get_match(self, idx): """Fetch a match for a given TPC interaction ID. 
diff --git a/spine/post/reco/cathode_cross.py b/spine/post/reco/cathode_cross.py index 4536ec89..fc0bb79c 100644 --- a/spine/post/reco/cathode_cross.py +++ b/spine/post/reco/cathode_cross.py @@ -1,11 +1,11 @@ """Cathode crossing identification + merging module.""" import numpy as np +from scipy.spatial.distance import cdist from spine.data import RecoInteraction, TruthInteraction -from spine.math.distance import cdist, farthest_pair -from scipy.spatial.distance import cdist as scipy_cdist +from spine.math.distance import farthest_pair from spine.utils.globals import TRACK_SHP from spine.utils.geo import Geometry @@ -82,7 +82,6 @@ def __init__(self, crossing_point_tolerance, offset_tolerance, keys['points'] = True if run_mode != 'reco': keys[truth_point_mode] = True - keys['meta'] = True #Needed to find shift in the cathode self.update_keys(keys) def process(self, data): @@ -93,8 +92,11 @@ def process(self, data): data : dict Dictionary of data products """ - #Get the drift pixel resolution - dx_res = data['meta'].size[0] + # Reset all particle/interaction matches, they are broken by merging + for obj_key in self.obj_keys: + for obj in data[obj_key]: + obj.reset_match() + # Loop over particle types update_dict = {} for part_key in self.particle_keys: @@ -135,7 +137,7 @@ def process(self, data): if (part.is_cathode_crosser and self.adjust_crossers and len(tpcs) == 2): # Adjust positions - self.adjust_positions(data, i,dx_res) + self.adjust_positions(data, i) # If we do not want to merge broken crossers, our job here is done if not self.merge_crossers: @@ -159,14 +161,13 @@ def process(self, data): continue # Get the cathode position, drift axis and cathode plane axes - daxis = self.geo.tpc[modules_i[0]].drift_axis cpos = self.geo.tpc[modules_i[0]].cathode_pos + daxis = self.geo.tpc[modules_i[0]].drift_axis caxes = np.array([i for i in range(3) if i != daxis]) # Store the distance of the particle to the cathode tpc_offset = self.geo.get_min_volume_offset( 
end_points_i, modules_i[0], tpcs_i[0])[daxis] - cdists = end_points_i[:, daxis] - tpc_offset - cpos # Loop over other tracks j = i + 1 @@ -191,8 +192,7 @@ def process(self, data): # Check if the two particles stop at roughly the same # position in the plane of the cathode compat = True - dist_mat = scipy_cdist( - end_points_i[:, caxes], end_points_j[:, caxes]) + dist_mat = cdist(end_points_i[:, caxes], end_points_j[:, caxes]) argmin = np.argmin(dist_mat) pair_i, pair_j = np.unravel_index(argmin, (2, 2)) compat &= ( @@ -213,7 +213,7 @@ def process(self, data): # If compatible, merge if compat: # Merge particle and adjust positions - self.adjust_positions(data, ci,dx_res, cj, truth=pi.is_truth) + self.adjust_positions(data, ci, cj, truth=pi.is_truth) # Update the candidate list to remove matched particle candidate_ids[j:-1] = candidate_ids[j+1:] - 1 @@ -246,7 +246,7 @@ def process(self, data): return update_dict - def adjust_positions(self, data, idx_i,dx_res, idx_j=None, truth=False): + def adjust_positions(self, data, idx_i, idx_j=None, truth=False): """Given a cathode crosser (either in one or two pieces), apply the necessary position offsets to match it at the cathode. @@ -256,8 +256,6 @@ def adjust_positions(self, data, idx_i,dx_res, idx_j=None, truth=False): Dictionary of data products idx_i : int Index of a cathode crosser (or a cathode crosser fragment) - dx_res : float - Drift pixel resolution [cm]. Offset the drift position by this amount. 
idx_j : int, optional Index of a matched cathode crosser fragment truth : bool, default False @@ -274,10 +272,6 @@ def adjust_positions(self, data, idx_i,dx_res, idx_j=None, truth=False): points_key = 'points' if not truth else self.truth_point_key particles = data[part_key] if idx_j is not None: - # Unmatch the particles from their interactions - particles[idx_i].unmatch() - particles[idx_j].unmatch() - # Merge particles int_id_i = particles[idx_i].interaction_id int_id_j = particles[idx_j].interaction_id @@ -310,14 +304,12 @@ def adjust_positions(self, data, idx_i,dx_res, idx_j=None, truth=False): int_id = particle.interaction_id sisters = [p for p in particles if p.interaction_id == int_id] - # Get the cathode position + # Get the drift axis m = modules[0] daxis = self.geo.tpc[m].drift_axis - cpos = self.geo.tpc[m].cathode_pos # Loop over contributing TPCs, shift the points in each independently - offsets, global_offset = self.get_cathode_offsets( - particle, m, tpcs) + offsets, global_offset = self.get_cathode_offsets(particle, m, tpcs) for i, t in enumerate(tpcs): # Move each of the sister particles by the same amount for sister in sisters: @@ -329,18 +321,18 @@ def adjust_positions(self, data, idx_i,dx_res, idx_j=None, truth=False): continue # Update the sister position and the main position tensor - self.get_points(sister)[tpc_index, daxis] -= offsets[i] + dx_res - data[points_key][index, daxis] -= offsets[i] + dx_res + self.get_points(sister)[tpc_index, daxis] -= offsets[i] + data[points_key][index, daxis] -= offsets[i] # Update the start/end points appropriately if sister.id == idx_i: for attr, closest_tpc in closest_tpcs.items(): if closest_tpc == t: - getattr(sister, attr)[daxis] -= offsets[i] + dx_res + getattr(sister, attr)[daxis] -= offsets[i] else: - sister.start_point[daxis] -= offsets[i] + dx_res - sister.end_point[daxis] -= offsets[i] + dx_res + sister.start_point[daxis] -= offsets[i] + sister.end_point[daxis] -= offsets[i] # Store crosser 
information particle.is_cathode_crosser = True @@ -400,10 +392,8 @@ def get_cathode_offsets(self, particle, module, tpcs): float General offset for this particle (proxy of out-of-time displacement) """ - # Get the cathode position + # Get the drift axis daxis = self.geo.tpc[module].drift_axis - cpos = self.geo.tpc[module].cathode_pos - dvector = (np.arange(3) == daxis).astype(float) # Check which side of the cathode each TPC lives flip = (-1) ** ( @@ -411,7 +401,7 @@ def get_cathode_offsets(self, particle, module, tpcs): > self.geo.tpc[module, tpcs[1]].boundaries[daxis].mean()) # Loop over the contributing TPCs - closest_points = np.empty((2, 3)) + # closest_points = np.empty((2, 3)) offsets = np.empty(2) for i, t in enumerate(tpcs): # Get the end points of the track segment @@ -424,14 +414,17 @@ def get_cathode_offsets(self, particle, module, tpcs): # Find the point closest to the cathode tpc_offset = self.geo.get_min_volume_offset( end_points, module, t)[daxis] + cpos = self.geo.tpc[module][t].cathode_pos cdists = end_points[:, daxis] - tpc_offset - cpos argmin = np.argmin(np.abs(cdists)) - closest_points[i] = end_points[argmin] + # closest_points[i] = end_points[argmin] # Compute the offset to bring it to the cathode offsets[i] = cdists[argmin] + tpc_offset # Now optimize the offsets based on angular matching + # cpos = self.geo.tpc[module].cathode_pos + # dvector = (np.arange(3) == daxis).astype(float) # xing_point = np.mean(closest_points, axis=0) # xing_point[daxis] = cpos # for i, t in enumerate(tpcs): @@ -443,7 +436,10 @@ def get_cathode_offsets(self, particle, module, tpcs): # disp = np.dot(dplane, vplane)/np.dot(vplane, vplane) # offsets[i] = [disp, offsets[i]][np.argmin(np.abs([disp, offsets[i]]))] - # Take the average offset as the value to use + # Align the offsets to match the smallest of the two + offsets = np.sign(offsets) * np.min(np.abs(offsets)) + + # Take the smallest of the two offsets (avoid moving into the cathode) global_offset = flip * 
(offsets[1] - offsets[0])/2. return offsets, global_offset diff --git a/spine/post/reco/topology.py b/spine/post/reco/topology.py index 31cf2947..66d8b82c 100644 --- a/spine/post/reco/topology.py +++ b/spine/post/reco/topology.py @@ -4,22 +4,25 @@ from scipy.stats import pearsonr from sklearn.decomposition import PCA -from spine.utils.globals import PHOT_PID, ELEC_PID +from spine.utils.globals import SHOWR_SHP, PHOT_PID, ELEC_PID from spine.utils.gnn.cluster import cluster_dedx, cluster_dedx_dir from spine.post.base import PostBase -__all__ = ['ParticleStartDEDXProcessor', 'ParticleStartStraightnessProcessor', +__all__ = ['ParticleDEDXProcessor', 'ParticleStartStraightnessProcessor', 'ParticleSpreadProcessor'] -class ParticleStartDEDXProcessor(PostBase): +class ParticleDEDXProcessor(PostBase): """Compute the dE/dx of the particle start by summing the energy depositions along the particle start and dividing by the total length of the start. """ # Name of the post-processor (as specified in the configuration) - name = 'start_dedx' + name = 'local_dedx' + + # Aliases for the post-processor + aliases = ('start_dedx', 'end_dedx') # List of recognized dE/dx computation modes _modes = ('default', 'direction') @@ -79,19 +82,31 @@ def process(self, data): if part.pid not in self.include_pids: continue - # Compute the particle start dE/dx - if self.mode == 'default': - # Use all depositions within a radius of the particle start - part.start_dedx = cluster_dedx( - part.points, part.depositions, part.start_point, - max_dist=self.radius, anchor=self.anchor) - - else: - # Use the particle direction estimate as a guide - part.start_dedx = cluster_dedx_dir( - part.points, part.depositions, part.start_point, - part.start_dir, max_dist=self.radius, - anchor=self.anchor)[0] + # Loop over the two sides of the particle + for side in ('start', 'end'): + # Showers have no end points, skip + if side == 'end' and part.shape == SHOWR_SHP: + continue + + # Fetch the point + ref_point = 
getattr(part, f'{side}_point')
+
+                # Compute the particle end dE/dx
+                if self.mode == 'default':
+                    # Use all depositions within a radius of the particle point
+                    dedx = cluster_dedx(
+                            part.points, part.depositions, ref_point,
+                            max_dist=self.radius, anchor=self.anchor)
+
+                else:
+                    # Use the particle direction estimate as a guide
+                    ref_dir = getattr(part, f'{side}_dir')*(-1)**(side == 'end')
+                    dedx = cluster_dedx_dir(
+                            part.points, part.depositions, ref_point, ref_dir,
+                            max_dist=self.radius, anchor=self.anchor)[0]
+
+                # Store the dE/dx
+                setattr(part, f'{side}_dedx', dedx)
 
 
 class ParticleStartStraightnessProcessor(PostBase):
diff --git a/spine/utils/geo/detector/tpc.py b/spine/utils/geo/detector/tpc.py
index 5d87212f..21fc384f 100644
--- a/spine/utils/geo/detector/tpc.py
+++ b/spine/utils/geo/detector/tpc.py
@@ -195,6 +195,17 @@ def cathode_pos(self):
         """
         return np.mean([c.cathode_pos for c in self.chambers])
 
+    @property
+    def cathode_thickness(self):
+        """Thickness of the cathode.
+
+        Returns
+        -------
+        float
+            Thickness of the cathode
+        """
+        return abs(self.chambers[1].cathode_pos - self.chambers[0].cathode_pos)
+
     def __len__(self):
         """Returns the number of TPCs in the module.
diff --git a/spine/utils/geo/source/icarus_geometry.yaml b/spine/utils/geo/source/icarus_geometry.yaml index 1af96b93..e306d414 100644 --- a/spine/utils/geo/source/icarus_geometry.yaml +++ b/spine/utils/geo/source/icarus_geometry.yaml @@ -1,5 +1,5 @@ tpc: - dimensions: [148.2, 316.82, 1789.901] + dimensions: [148.2, 316.82, 1789.902] module_ids: [0, 0, 1, 1] det_ids: [0, 0, 1, 1] positions: diff --git a/spine/utils/geo/source/sbnd_geometry.yaml b/spine/utils/geo/source/sbnd_geometry.yaml index 5df9531a..460654a1 100644 --- a/spine/utils/geo/source/sbnd_geometry.yaml +++ b/spine/utils/geo/source/sbnd_geometry.yaml @@ -1,9 +1,9 @@ tpc: - dimensions: [201.3, 400.016, 499.51562] + dimensions: [201.1, 400.016, 499.51562] module_ids: [0, 0] positions: - - [-100.65, 0.0, 254.70019] - - [100.65, 0.0, 254.70019] + - [-100.75, 0.0, 254.70019] + - [100.75, 0.0, 254.70019] optical: volume: module shape: [box, ellipsoid] diff --git a/spine/utils/ghost.py b/spine/utils/ghost.py index 12795ac5..b8c3cca3 100644 --- a/spine/utils/ghost.py +++ b/spine/utils/ghost.py @@ -8,82 +8,98 @@ from .globals import SHOWR_SHP, TRACK_SHP, MICHL_SHP, DELTA_SHP -def compute_rescaled_charge_batch(data, collection_only=False, collection_id=2): - """Batched version of :func:`compute_rescaled_charge`. 
- - Parameters - ---------- - data : TensorBatch - (N, 1 + D + N_f + 6) tensor of voxel/value pairs - collection_only : bool, default False - Only use the collection plane to estimate the rescaled charge - collection_id : int, default 2 - Index of the collection plane - - Returns - ------- - Union[np.ndarray, torch.Tensor] - (N) Rescaled charge values - """ - charges = data._empty(len(data.tensor)) - for b in range(data.batch_size): - lower, upper = data.edges[b], data.edges[b+1] - charges[lower:upper] = compute_rescaled_charge( - data[b], collection_only, collection_id) - - return charges - - -def compute_rescaled_charge(data, collection_only=False, collection_id=2): - """Computes rescaled charge after deghosting. - - The last 6 columns of the input tensor *MUST* contain: - - charge in each of the projective planes (3) - - index of the hit in each 2D projection (3) - - Notes - ----- - This function should work on numpy arrays or Torch tensors. - - Parameters - ---------- - data : Union[np.ndarray, torch.Tensor] - (N, 1 + D + N_f + 6) tensor of voxel/value pairs - collection_only : bool, default False - Only use the collection plane to estimate the rescaled charge - collection_id : int, default 2 - Index of the collection plane - - Returns - ------- - data : Union[np.ndarray, torch.Tensor] - (N) Rescaled charge values +class ChargeRescaler: + """Rescales the space point charge based on the deghosting output. + + It ensures that the amount of charge carried by each hit that makes up at + least one space point is not duplicated by distributing said hit charge + across all the space points formed with it. 
""" - # Define operations on the basis of the input type - if torch.is_tensor(data): - unique = torch.unique - empty = lambda shape: torch.empty(shape, dtype=torch.long, - device=data.device) - sum = lambda x: torch.sum(x, dim=1) - else: - unique = np.unique - empty = np.empty - sum = lambda x: np.sum(x, axis=1) - - # Count how many times each wire hit is used to form a space point - hit_ids = data[:, -3:] - _, inverse, counts = unique( - hit_ids, return_inverse=True, return_counts=True) - multiplicity = counts[inverse].reshape(-1, 3) - - # Rescale the charge on the basis of hit multiplicity - hit_charges = data[:, -6:-3] - if not collection_only: - # Take the average of the charge estimates from each active plane - pmask = hit_ids > -1 - charges = sum((hit_charges*pmask)/multiplicity)/sum(pmask) - else: - # Only use the collection plane measurement - charges = hit_charges[:, collection_id]/multiplicity[:, collection_id] - - return charges + + def __init__(self, collection_only=False, collection_id=2): + """Initialize the charge rescaler. + + Parameters + ---------- + collection_only : bool, default False + If `True`, only use the collection plane to estimate the rescaled charge + collection_id : int, default 2 + Index of the collection plane + """ + # Save the parameters + self.collection_only = collection_only + self.collection_id = collection_id + + def __call__(self, data): + """Rescale the charge of one batch of deghosted data. + + Parameters + ---------- + data : TensorBatch + (N, 1 + D + N_f + 6) tensor of voxel/value pairs + + Returns + ------- + data : Union[np.ndarray, torch.Tensor] + (N) Rescaled charge values + """ + charges = data._empty(len(data.tensor)) + for b in range(data.batch_size): + lower, upper = data.edges[b], data.edges[b+1] + charges[lower:upper] = self.process_single(data[b]) + + return charges + + def process_single(self, data): + """Rescale the charge of one event. 
+ + The last 6 columns of the input tensor *MUST* contain: + - charge in each of the projection planes (3) + - unique index of the hit in each 2D projection (3) + + Notes + ----- + This function should work on numpy arrays or Torch tensors. + + Parameters + ---------- + data : Union[np.ndarray, torch.Tensor] + (N, 1 + D + N_f + 6) tensor of voxel/value pairs + + Returns + ------- + data : Union[np.ndarray, torch.Tensor] + (N) Rescaled charge values + """ + # Define operations on the basis of the input type + if torch.is_tensor(data): + unique, where = torch.unique, torch.where + sum = lambda x: torch.sum(x, dim=1) + else: + unique, where = np.unique, np.where + sum = lambda x: np.sum(x, axis=1) + + # Count how many times each wire hit is used to form a space point + hit_ids = data[:, -3:] + _, inverse, counts = unique( + hit_ids, return_inverse=True, return_counts=True) + mult = counts[inverse].reshape(-1, 3) + + # Rescale the charge on the basis of hit multiplicity + hit_charges = data[:, -6:-3] + if not self.collection_only: + # Take the average of the charge estimates from each active plane + pmask = hit_ids > -1 + charges = sum((hit_charges*pmask)/mult)/sum(pmask) + else: + # Only use the collection plane measurement, when available + charges = hit_charges[:, self.collection_id]/mult[:, self.collection_id] + + # Fallback on the average if there is no collection hit + bad_index = where(hit_ids[:, self.collection_id] < 0)[0] + if len(bad_index) > 0: + pmask = hit_ids[bad_index] > -1 + charges[bad_index] = sum( + (hit_charges[bad_index]*pmask)/mult[bad_index])/sum(pmask) + + return charges diff --git a/spine/utils/optical.py b/spine/utils/optical.py new file mode 100644 index 00000000..2359e3ef --- /dev/null +++ b/spine/utils/optical.py @@ -0,0 +1,123 @@ +"""Defines objects and methods related to optical information.""" + +from copy import deepcopy + +import numpy as np + +from .geo import Geometry + +__all__ = ['FlashMerger'] + + +class FlashMerger: + """Class 
which takes care of merging flashes together.""" + + def __init__(self, threshold=1.0, window=None, combine_volumes=True): + """Initialize the flash merging class. + + Parameters + ---------- + threshold : float, default 1.0 + Maximum time difference (in us) between two successive flashes for + them to be merged into one combined flash + window : List[float], optional + Time window (in us) within which to merge flashes. If flash times + are outside of this window they are not considered for merging. + combine_volumes : bool, default True + If `True`, merge flashes from different optical volumes + """ + # Store merging parameters + self.threshold = threshold + self.window = window + self.combine_volumes = combine_volumes + + # Check on the merging time window formatting + assert window is None or (hasattr(window, '__len__') and len(window) == 2), ( + "The `window` parameter should be a list/tuple of two numbers.") + + def __call__(self, flashes): + """Combine flashes if they are compatible in time. 
+ + Parameters + ---------- + flashes : List[Flash] + List of flash objects + + Returns + ------- + List[Flash] + (M) List of merged flashes + List[List[int]] + (M) List of original flash indexes which make up the merged flashes + """ + # If there are fewer than two flashes, nothing to do + if len(flashes) < 2: + return flashes, np.arange(len(flashes)) + + # Dispatch + if not self.combine_volumes: + # Only merge flashes when they belong to the same volume + volume_ids = np.array([f.volume_id for f in flashes]) + new_flashes, orig_ids = [], [] + for volume_id in np.unique(volume_ids): + index = np.where(volume_ids == volume_id)[0] + flashes_i, orig_ids_i = self.merge([flashes[i] for i in index]) + new_flashes.extend(flashes_i) + for ids in orig_ids_i: + orig_ids.append(index[ids]) + + return new_flashes, orig_ids + + else: + # Merge flashes regardless of their optical volume + return self.merge(flashes) + + def merge(self, flashes): + """Merge flashes if they are compatible in time. + + Parameters + ---------- + flashes : List[Flash] + List of flash objects + + Returns + ------- + List[Flash] + (M) List of merged flashes + List[List[int]] + (M) List of original flash indexes which make up the merged flashes + """ + # Order the flashes in time, merge them if they are compatible + times = [f.time for f in flashes] + perm = np.argsort(times) + new_flashes = [deepcopy(flashes[perm[0]])] + new_flashes[-1].id = 0 + orig_ids = [[perm[0]]] + in_window = True + for i in range(1, len(perm)): + # Check that both flashes to be merged are in the merging window + prev, curr = flashes[perm[i-1]], flashes[perm[i]] + if self.window is not None: + in_window = (prev.time > self.window[0] and + curr.time < self.window[1]) + + # Check that the two consecutive flashes are compatible in time + if in_window and (curr.time - prev.time) < self.threshold: + # Merge the successive flashes if they are compatible in time + new_flashes[-1].merge(curr) + orig_ids[-1].append(perm[i]) + + else: + # If 
the two flashes are not compatible, add a new one + new_flashes.append(deepcopy(curr)) + orig_ids.append([perm[i]]) + + # Reset the flash index to match the new list + new_flashes[-1].id = len(new_flashes) - 1 + + # Reset the volume IDs, if necessary + if self.combine_volumes: + for flash in new_flashes: + flash.volume_id = 0 + + return new_flashes, orig_ids diff --git a/spine/utils/tracking.py b/spine/utils/tracking.py index 817bb914..6db23d62 100644 --- a/spine/utils/tracking.py +++ b/spine/utils/tracking.py @@ -195,9 +195,9 @@ def get_track_deposition_gradient(coordinates: nb.float32[:,:], seg_dedxs = seg_dedxs[valid_index] seg_rrs = seg_rrs[valid_index] - # Compute the dE/dx gradient - gradient = np.cov(seg_rrs, seg_dedxs)[0,1]/np.std(seg_rrs)**2 \ - if np.std(seg_rrs) > 0. else 0. + # Compute the dE/dx gradient = Cov(x,y) / Var(x) + cov = np.cov(seg_rrs, seg_dedxs) + gradient = cov[0,1]/cov[0,0] if cov[0,0] > 0. else 0. return gradient, seg_dedxs, seg_rrs, seg_lengths diff --git a/spine/version.py b/spine/version.py index 7a581a0f..3ae95a15 100644 --- a/spine/version.py +++ b/spine/version.py @@ -1,3 +1,3 @@ """Module which stores the current software version.""" -__version__ = '0.5.0' +__version__ = '0.5.1' diff --git a/spine/vis/geo.py b/spine/vis/geo.py index eab826ff..62eaeee3 100644 --- a/spine/vis/geo.py +++ b/spine/vis/geo.py @@ -90,8 +90,9 @@ def tpc_traces(self, meta=None, draw_faces=False, shared_legend=True, return detectors def optical_traces(self, meta=None, shared_legend=True, legendgroup=None, - name='Optical', color='rgba(0,0,255,0.25)', cmin=None, - cmax=None, zero_supress=False, volume_id=None, **kwargs): + name='Optical', color='rgba(0,0,255,0.25)', + hovertext=None, cmin=None, cmax=None, zero_supress=False, + volume_id=None, **kwargs): """Function which produces a list of traces which represent the optical detectors in a 3D event display. 
@@ -108,6 +109,8 @@ def optical_traces(self, meta=None, shared_legend=True, legendgroup=None, Name(s) of the detector volumes color : Union[int, str, np.ndarray] Color of optical detectors or list of color of optical detectors + hovertext : Union[str, List[str]], optional + Label or list of labels associated with each optical detector cmin : float, optional Minimum value along the color scale cmax : float, optional @@ -157,6 +160,21 @@ def optical_traces(self, meta=None, shared_legend=True, legendgroup=None, assert len(color) == len(positions), ( "Must provide one value for each optical detector.") + # Build the hovertext vectors + if hovertext is not None: + if np.isscalar(hovertext): + hovertext = [hovertext]*len(positions) + elif len(hovertext) != len(positions): + raise ValueError( + "The `hovertext` attribute should be provided as a scalar, " + "one value per point or one value per optical detector.") + + else: + hovertext = [f'PD ID: {i}' for i in range(len(positions))] + if color is not None and not np.isscalar(color): + for i, hc in enumerate(hovertext): + hovertext[i] = hc + f'
Value: {color[i]:.3f}' + # If cmin/cmax are not provided, must build them so that all optical # detectors share the same colorscale range (not guaranteed otherwise) if color is not None and not np.isscalar(color) and len(color) > 0: @@ -176,6 +194,7 @@ def optical_traces(self, meta=None, shared_legend=True, legendgroup=None, if shape_ids is None: pos = positions col = color + ht = hovertext else: index = np.where(np.asarray(shape_ids) == i)[0] pos = positions[index] @@ -183,6 +202,7 @@ def optical_traces(self, meta=None, shared_legend=True, legendgroup=None, col = color[index] else: col = color + ht = [hovertext[i] for i in index] # If zero-supression is requested, only draw the optical detectors # which record a non-zero signal @@ -190,6 +210,7 @@ def optical_traces(self, meta=None, shared_legend=True, legendgroup=None, index = np.where(np.asarray(col) != 0)[0] pos = pos[index] col = col[index] + ht = [ht[i] for i in index] # Determine wheter to show legends or not showlegend = not shared_legend or i == 0 @@ -204,7 +225,8 @@ def optical_traces(self, meta=None, shared_legend=True, legendgroup=None, traces += box_traces( lower, upper, shared_legend=shared_legend, name=name, color=col, cmin=cmin, cmax=cmax, draw_faces=True, - legendgroup=legendgroup, showlegend=showlegend, **kwargs) + hovertext=ht, legendgroup=legendgroup, + showlegend=showlegend, **kwargs) else: # Convert the optical detector dimensions to a covariance matrix @@ -214,7 +236,8 @@ def optical_traces(self, meta=None, shared_legend=True, legendgroup=None, traces += ellipsoid_traces( pos, covmat, shared_legend=shared_legend, name=name, color=col, cmin=cmin, cmax=cmax, - legendgroup=legendgroup, showlegend=showlegend, **kwargs) + hovertext=ht, legendgroup=legendgroup, + showlegend=showlegend, **kwargs) return traces