Source code for spine.data.out.fragment

"""Module with a data class objects which represent output fragments."""

from dataclasses import dataclass, field

import numpy as np

from spine.constants import SHAPE_LABELS, TRACK_SHP, ParticleShape
from spine.data.decorator import stored_property
from spine.data.field import FieldMetadata
from spine.data.larcv.particle import Particle

from .base import OutBase, RecoBase, TruthBase

__all__ = ["RecoFragment", "TruthFragment"]


@dataclass(eq=False, repr=False)
class FragmentBase(OutBase):
    """Base fragment-specific information.

    Attributes
    ----------
    particle_id : int
        Index of the particle this fragment belongs to
    interaction_id : int
        Index of the interaction this fragment belongs to
    shape : int
        Semantic type (shower (0), track (1), Michel (2), delta (3),
        low energy scatter (4)) of this particle
    is_primary : bool
        Whether this fragment was the first in the particle group
    length : float
        Length of the particle (only assigned to track objects)
    start_point : np.ndarray
        (3) Fragment start point
    end_point : np.ndarray
        (3) Fragment end point (only assigned to track objects)
    """

    # Scalar attributes
    particle_id: int = -1
    interaction_id: int = -1

    is_primary: bool = False

    length: float = field(default=np.nan, metadata=FieldMetadata(units="instance"))

    # Enumerated attributes
    shape: int = field(default=-1, metadata=FieldMetadata(enum=ParticleShape))

    # Vector attributes
    start_point: np.ndarray = field(
        default_factory=lambda: np.full(3, np.nan, dtype=np.float32),
        metadata=FieldMetadata(
            length=3,
            dtype=np.float32,
            position=True,
            units="instance",
        ),
    )
    end_point: np.ndarray = field(
        default_factory=lambda: np.full(3, np.nan, dtype=np.float32),
        metadata=FieldMetadata(
            length=3,
            dtype=np.float32,
            position=True,
            units="instance",
        ),
    )

    def __str__(self):
        """Human-readable string representation of the fragment object.

        Results
        -------
        str
            Basic information about the fragment properties
        """
        shape_label = SHAPE_LABELS[self.shape]
        match = self.match_ids[0] if len(self.match_ids) > 0 else -1
        return (
            f"Fragment(ID: {self.id:<3} | ParticleShape: {shape_label:<11} "
            f"| Primary: {self.is_primary:<2} "
            f"| Size: {self.size:<5} | Match: {match:<3})"
        )


[docs] @dataclass(eq=False, repr=False) class RecoFragment(FragmentBase, RecoBase): """Reconstructed fragment information. Attributes ---------- primary_scores : np.ndarray (2) Array of softmax scores associated with secondary and primary start_dir : np.ndarray (3) Fragment direction w.r.t. the start point end_dir : np.ndarray (3) Fragment direction w.r.t. the end point (only assigned to track objects) """ # Vector attributes start_dir: np.ndarray = field( default_factory=lambda: np.full(3, np.nan, dtype=np.float32), metadata=FieldMetadata(length=3, dtype=np.float32, vector=True), ) end_dir: np.ndarray = field( default_factory=lambda: np.full(3, np.nan, dtype=np.float32), metadata=FieldMetadata(length=3, dtype=np.float32, vector=True), ) primary_scores: np.ndarray = field( default_factory=lambda: np.full(2, np.nan, dtype=np.float32), metadata=FieldMetadata(length=2, dtype=np.float32), ) def __str__(self): """Human-readable string representation of the fragment object. Results ------- str Basic information about the fragment properties """ return "Reco" + super().__str__()
[docs] @dataclass(eq=False, repr=False) class TruthFragment(Particle, FragmentBase, TruthBase): """Truth fragment information. This inherits all of the attributes of :class:`Particle`, which contains the G4 truth information for the fragment. Attributes ---------- orig_interaction_id : int Unaltered index of the interaction in the original MC particle list orig_parent_id : int Unaltered index of the particle parent in the original MC particle list orig_group_id : int Unaltered index of the particle group in the original MC particle list orig_children_id : np.ndarray Unaltered list of the particle children in the original MC particle list children_counts : np.ndarray (P) Number of truth child fragment of each shape reco_length : float Reconstructed length of the fragment (only assigned to track objects) reco_start_dir : np.ndarray (3) Particle direction estimate w.r.t. the start point reco_end_dir : np.ndarray (3) Particle direction estimate w.r.t. the end point (only assigned to track objects) """ # Scalar attributes orig_interaction_id: int = -1 orig_parent_id: int = -1 orig_group_id: int = -1 # Vector attributes orig_children_id: np.ndarray = field( default_factory=lambda: np.empty(0, dtype=np.int32), metadata=FieldMetadata(dtype=np.int32), ) children_counts: np.ndarray = field( default_factory=lambda: np.empty(0, dtype=np.int32), metadata=FieldMetadata(dtype=np.int32), ) reco_start_dir: np.ndarray = field( default_factory=lambda: np.full(3, np.nan, dtype=np.float32), metadata=FieldMetadata(length=3, dtype=np.float32, vector=True), ) reco_end_dir: np.ndarray = field( default_factory=lambda: np.full(3, np.nan, dtype=np.float32), metadata=FieldMetadata(length=3, dtype=np.float32, vector=True), ) def __str__(self): """Human-readable string representation of the fragment object. Results ------- str Basic information about the fragment properties """ return "Truth" + super().__str__() @property @stored_property(length=3, dtype=np.float32, vector=True) def start_dir(self) -> np.ndarray: """Converts the initial momentum to a direction vector. Returns ------- np.ndarray (3) Start direction vector """ norm = np.linalg.norm(self.momentum) if norm > 0.0: return self.momentum / norm return np.full(3, np.nan, dtype=np.float32) @property @stored_property(length=3, dtype=np.float32, vector=True) def end_dir(self) -> np.ndarray: """Converts the final momentum to a direction vector. Note that if a particle stops, this is unreliable as an estimate of the direction of the particle before it stops. Returns ------- np.ndarray (3) End direction vector """ if self.shape == TRACK_SHP: norm = np.linalg.norm(self.end_momentum) if norm > 0.0: return self.end_momentum / norm return np.full(3, np.nan, dtype=np.float32)