# Source code for spine.vis.trace.cylinder

"""Module to draw cylinders."""

from __future__ import annotations

import time
from typing import Any

import numpy as np
import plotly.graph_objs as go

from .utils import (
    ColorInput,
    HoverTextInput,
    IntensityInput,
    NumericOrSequence,
    is_scalar_sequence,
    require_matching_length,
    rotation_matrix_from_z,
    select_numeric_or_sequence,
    select_scalar_or_sequence,
)

__all__ = ["cylinder_trace", "cylinder_traces"]


def cylinder_trace(
    centroid: np.ndarray,
    axis: np.ndarray,
    height: float,
    diameter: float,
    num_samples: int = 10,
    color: ColorInput = None,
    intensity: IntensityInput = None,
    hovertext: HoverTextInput = None,
    showscale: bool = False,
    **kwargs: Any,
) -> go.Mesh3d:
    """Build a :class:`plotly.graph_objs.Mesh3d` cylinder around a centroid.

    The lateral surface is sampled on a regular (azimuth, height) grid and
    rendered as the convex hull of those samples (``alphahull=0``).

    Parameters
    ----------
    centroid : np.ndarray (3)
        Center of the cylinder
    axis : np.ndarray (3)
        Direction of the cylinder axis
    height : float
        Length of the cylinder along its axis
    diameter : float
        Diameter of the cylinder cross-section
    num_samples : int, default 10
        Number of samples along each of the azimuthal angle and the height,
        i.e. ``num_samples**2`` surface points in total. More samples give a
        smoother cylinder.
    color : Union[str, int, float, Sequence], optional
        Color of the cylinder. A string is forwarded as a Plotly color; any
        other value is converted into per-vertex intensities.
    intensity : Union[int, float, Sequence], optional
        Position(s) along the colorscale axis, either a single numeric value
        or a per-vertex sequence. Must not be combined with a non-string
        `color`.
    hovertext : Union[int, float, str, Sequence], optional
        Hover label(s). A value that is not a scalar sequence (e.g. a single
        label) is written directly into the hover template; a scalar
        sequence is forwarded to the trace as per-vertex hover text.
    showscale : bool, default False
        If True, display the colorscale of the mesh
    **kwargs : dict, optional
        Extra keyword arguments forwarded to
        :class:`plotly.graph_objs.Mesh3d`

    Returns
    -------
    plotly.graph_objs.Mesh3d
        Mesh trace of the cylinder

    Raises
    ------
    ValueError
        If both a non-string `color` and `intensity` are provided
    """
    # Sample a grid of azimuths and heights on a cylinder of unit diameter
    # (radius 0.5) and unit height, centered on the origin along z
    azimuths, heights = np.meshgrid(
        np.linspace(0, 2 * np.pi, num=num_samples),
        np.linspace(-0.5, 0.5, num=num_samples),
    )
    unit_points = np.vstack(
        (
            (0.5 * np.cos(azimuths)).flatten(),
            (0.5 * np.sin(azimuths)).flatten(),
            heights.flatten(),
        )
    ).T

    # Stretch to the requested size, orient the z axis along `axis` using
    # the helper's rotation matrix, then translate onto the centroid
    stretch = np.diag([diameter, diameter, height])
    rotation = rotation_matrix_from_z(axis)
    points = centroid + np.dot(unit_points.dot(stretch), rotation)

    # Non-string colors are numeric values: express them as intensities
    if color is not None and not isinstance(color, str):
        if intensity is not None:
            raise ValueError("Must not provide both `color` and `intensity`.")
        intensity, color = np.full(len(points), color), None

    # Coordinates are always shown on hover; labels are appended if provided
    template = "x: %{x}<br>y: %{y}<br>z: %{z}"
    if hovertext is not None:
        if not is_scalar_sequence(hovertext):
            # A single label is baked into the template and not forwarded
            template = f"{template}<br>{hovertext}"
            hovertext = None
        else:
            template = template + "<br>%{text}"

    return go.Mesh3d(
        x=points[:, 0],
        y=points[:, 1],
        z=points[:, 2],
        color=color,
        intensity=intensity,
        alphahull=0,
        showscale=showscale,
        hovertext=hovertext,
        hovertemplate=template,
        **kwargs,
    )
def cylinder_traces(
    centroids: np.ndarray,
    axis: np.ndarray,
    height: NumericOrSequence,
    diameter: NumericOrSequence,
    color: ColorInput = None,
    hovertext: HoverTextInput = None,
    cmin: float | None = None,
    cmax: float | None = None,
    shared_legend: bool = True,
    legendgroup: str | None = None,
    showlegend: bool = True,
    name: str | None = None,
    **kwargs: Any,
) -> list[go.Mesh3d]:
    """Produce one plotly cylinder trace per centroid, given one shared or
    per-cylinder geometric description.

    Parameters
    ----------
    centroids : np.ndarray (N, 3)
        Positions of each of the cylinder centroids
    axis : np.ndarray (3,) or (N, 3)
        Axis direction of the cylinders, either shared or one per cylinder.
        Any array-like (e.g. a list) is accepted.
    height : Union[int, float, Sequence]
        Height of the cylinders, either as one shared value or one value per
        cylinder.
    diameter : Union[int, float, Sequence]
        Diameter of the cylinders, either as one shared value or one value
        per cylinder.
    color : Union[str, int, float, Sequence], optional
        Color of the cylinders, either as one shared value or one value per
        cylinder.
    hovertext : Union[int, float, str, Sequence], optional
        Text associated with the cylinders, either as one shared label or one
        label per cylinder. If omitted and one numeric color is given per
        cylinder, each cylinder is labeled with its color value.
    cmin : float, optional
        Minimum value along the color scale. Defaults to the minimum of the
        per-cylinder numeric colors, if any.
    cmax : float, optional
        Maximum value along the color scale. Defaults to the maximum of the
        per-cylinder numeric colors, if any.
    shared_legend : bool, default True
        If True, the plotly legend of all cylinders is shared as one
    legendgroup : str, optional
        Legend group to be shared between all cylinders
    showlegend : bool, default True
        Whether to show legends or not
    name : str, optional
        Name of the trace(s). With `shared_legend=False`, each trace is named
        ``f"{name} {i}"``; if `name` is None, trace names are left unset.
    **kwargs : dict, optional
        Additional arguments passed to each underlying
        :class:`plotly.graph_objs.Mesh3d`

    Returns
    -------
    List[plotly.graph_objs.Mesh3d]
        One cylinder trace per centroid

    Raises
    ------
    ValueError
        If `axis` is neither of shape (3,) nor (N, 3). Length mismatches of
        the other per-cylinder arguments are reported by
        `require_matching_length`.
    """
    # Accept any array-like axis (e.g. a plain list), not only numpy arrays
    axis = np.asarray(axis)
    num_cylinders = len(centroids)

    # Check the parameters
    if axis.shape != (3,) and axis.shape != (num_cylinders, 3):
        raise ValueError(
            "Specify one axis for all cylinders, or one axis per cylinder."
        )
    require_matching_length(
        height,
        num_cylinders,
        "Specify one height for all cylinders, or one height per cylinder.",
    )
    require_matching_length(
        diameter,
        num_cylinders,
        "Specify one diameter for all cylinders, or one diameter per cylinder.",
    )
    require_matching_length(
        color,
        num_cylinders,
        "Specify one color for all cylinders, or one color per cylinder.",
    )
    require_matching_length(
        hovertext,
        num_cylinders,
        "Specify one hovertext for all cylinders, or one hovertext per cylinder.",
    )

    # Per-cylinder numeric colors drive the default hovertext and the shared
    # colorscale range. String entries are excluded: they cannot be formatted
    # with `0.3f` nor meaningfully reduced with min/max.
    # NOTE(review): assumes `is_scalar_sequence` may also accept strings.
    numeric_colors = (
        color is not None
        and is_scalar_sequence(color)
        and not any(isinstance(v, str) for v in color)
    )

    # If one color is provided per cylinder, give an associated hovertext
    if hovertext is None and numeric_colors:
        hovertext = [f"Value: {v:0.3f}" for v in color]

    # If cmin/cmax are not provided, build them so that all cylinders share
    # the same colorscale range (not guaranteed otherwise)
    if numeric_colors and len(color) > 0:
        color_values = np.asarray(color)
        if cmin is None:
            cmin = np.min(color_values)
        if cmax is None:
            cmax = np.max(color_values)

    # If the legend is to be shared, make sure there is a common legend group
    if shared_legend and legendgroup is None:
        legendgroup = "group_" + str(time.time())

    # Loop over the list of cylinder centroids
    traces = []
    for i, centroid in enumerate(centroids):
        # Fetch the right color/hovertext combination
        col = select_scalar_or_sequence(color, i)
        hov = select_scalar_or_sequence(hovertext, i)

        # If the legend is shared, only draw the legend of the first trace
        if shared_legend:
            showlegend_i = showlegend and i == 0
            name_i = name
        else:
            showlegend_i = showlegend
            # Avoid legend entries literally named "None 0", "None 1", ...
            name_i = None if name is None else f"{name} {i}"

        # If any of the axis, height or diameter are arrays, fetch the right one
        axis_i = axis[i] if axis.ndim == 2 else axis
        height_i = select_numeric_or_sequence(height, i)
        diameter_i = select_numeric_or_sequence(diameter, i)

        # Append list of traces
        traces.append(
            cylinder_trace(
                centroid=centroid,
                axis=axis_i,
                height=float(height_i),
                diameter=float(diameter_i),
                contour=None,
                color=col,
                hovertext=hov,
                cmin=cmin,
                cmax=cmax,
                legendgroup=legendgroup,
                showlegend=showlegend_i,
                name=name_i,
                **kwargs,
            )
        )

    return traces