"""Module to draw 3D arrows."""
from __future__ import annotations

import time
import uuid
from typing import Any

import numpy as np
from plotly import graph_objs as go

from .point import scatter_points_3d
from .utils import (
    ColorInput,
    HoverTextInput,
    is_scalar_sequence,
    select_scalar_or_sequence,
)
__all__ = ["scatter_arrows"]


def scatter_arrows(
    points: np.ndarray,
    directions: np.ndarray,
    length: float = 10.0,
    tip_ratio: float = 0.25,
    color: ColorInput = None,
    hovertext: HoverTextInput = None,
    line: dict[str, Any] | None = None,
    linewidth: float = 5,
    name: str | None = None,
) -> list[go.Scatter3d | go.Cone]:
    """Converts a list of points and directions into a set of arrows.

    Each arrow is drawn as a line segment (the trunk) going from its start
    point to ``point + length * direction``, capped by a cone (the tip)
    whose apex sits at the end of the trunk.

    Parameters
    ----------
    points : np.ndarray
        (N, 3) Array of arrow start coordinates
    directions : np.ndarray
        (N, 3) Array of arrow direction vectors
    length : float, default 10.0
        Length of the arrows for unit direction vectors (the arrows scale
        with the norm of the direction vectors)
    tip_ratio : float, default 0.25
        Relative arrow tip size w.r.t. its full length
    color : Union[str, int, float, Sequence], optional
        Color of the arrow trunks, either as one shared scalar value or one
        value per arrow. The tips take this color only when it is a string;
        otherwise they are drawn in black.
    hovertext : Union[int, float, str, Sequence], optional
        Text associated with the arrows, either as one shared label or one
        label per arrow. It is appended to the direction components shown
        when hovering an arrow.
    line : dict, optional
        Arrow trunk line property dictionary
    linewidth : float, default 5
        Width of the arrow trunk lines
    name : str, optional
        Name of the traces

    Returns
    -------
    List[Union[go.Scatter3d, go.Cone]]
        Trunk trace(s) produced by ``scatter_points_3d``, followed by a
        single cone trace holding all the arrow tips

    Raises
    ------
    ValueError
        If the points are not an (N, 3) array, if the directions do not
        have the same shape as the points, or if a per-arrow color or
        hovertext sequence does not have one entry per point.
    """
    # Validate the geometry up front: mismatched inputs would otherwise be
    # silently truncated (trunks) or broadcast (tips) inconsistently
    points = np.asarray(points)
    directions = np.asarray(directions)
    if points.ndim != 2 or points.shape[1] != 3:
        raise ValueError("The points must be provided as an (N, 3) array.")
    if directions.shape != points.shape:
        raise ValueError(
            "The directions must be provided as an (N, 3) array matching "
            "the shape of the points."
        )
    num_arrows = len(points)

    # Process color information for the arrow trunks. Each trunk is made of
    # three vertices (start, end, gap), hence the 3-fold repetition
    color_trunks = color
    if is_scalar_sequence(color):
        if len(color) != num_arrows:
            raise ValueError(
                "If providing a list of colors for the arrows, "
                "its length must match the number of points."
            )
        color_trunks = np.repeat(np.asarray(color), 3)

    # Build one hover label per arrow: its direction components, followed
    # by the user-provided label, if any
    per_arrow_text = hovertext is not None and is_scalar_sequence(hovertext)
    if per_arrow_text and len(hovertext) != num_arrows:
        raise ValueError(
            "If providing a list of hovertexts for the arrows, "
            "its length must match the number of points."
        )
    hovertext_arrows = []
    for i, (vx, vy, vz) in enumerate(directions):
        text = f"vx: {vx:0.3f}<br>vy: {vy:0.3f}<br>vz: {vz:0.3f}"
        if hovertext is not None:
            if per_arrow_text:
                label = select_scalar_or_sequence(hovertext, i)
            else:
                label = hovertext
            text += f"<br>{label}"
        hovertext_arrows.append(text)
    hovertext_trunks = np.repeat(np.asarray(hovertext_arrows), 3)

    # Share one legend group between trunks and tips so that toggling the
    # arrows in the legend hides both. A UUID is used rather than a
    # timestamp, which can collide for arrow sets created in quick succession
    legendgroup = f"group_{uuid.uuid4().hex}"

    # Initialize the arrow trunks as a single polyline in which consecutive
    # segments are separated by a row of None (plotly leaves gaps at missing
    # values, so the segments are not joined together)
    heads = points + length * directions
    if num_arrows > 0:
        vertices = np.empty((3 * num_arrows, 3), dtype=object)
        vertices[0::3] = points
        vertices[1::3] = heads
        vertices[2::3] = None
    else:
        vertices = np.empty((0, 3), dtype=points.dtype)

    traces = scatter_points_3d(
        vertices,
        color=color_trunks,
        hovertext=hovertext_trunks,
        line=line,
        linewidth=linewidth,
        mode="lines",
        hovertemplate="%{text}",
        name=name,
        legendgroup=legendgroup,
    )

    # The tips are drawn in a uniform color: the provided color if it is a
    # string, black otherwise
    tip_color = color if isinstance(color, str) else "black"
    colorscale = [(0, tip_color), (1, tip_color)]

    # Initialize the arrow tips. The cones are anchored at their apex, which
    # is placed at the end of the trunk; the default "cm" anchor (1/4 of the
    # way from tail to tip) would make the apex overshoot the trunk
    tip_vectors = tip_ratio * length * directions
    traces += [
        go.Cone(
            x=heads[:, 0],
            y=heads[:, 1],
            z=heads[:, 2],
            u=tip_vectors[:, 0],
            v=tip_vectors[:, 1],
            w=tip_vectors[:, 2],
            anchor="tip",
            showscale=False,
            showlegend=False,
            sizemode="raw",
            colorscale=colorscale,
            hovertext=hovertext_arrows,
            hovertemplate="%{hovertext}",
            name=name,
            legendgroup=legendgroup,
        )
    ]

    return traces