|
2 | 2 | # MIT Licence, see details in top-level file: LICENCE
|
3 | 3 |
|
4 | 4 | """
|
5 |
| -Classes for parameterizing a trajectory in SE3 with B-splines. |
6 |
| -
|
7 |
| -Copies parts of the API from scipy's B-spline class. |
| 5 | +Classes for parameterizing a trajectory in SE3 with splines. |
8 | 6 | """
|
9 | 7 |
|
10 | 8 | from typing import Any, Dict, List, Optional
|
|
14 | 12 | import matplotlib.pyplot as plt
|
15 | 13 | from spatialmath.base.transforms3d import tranimate, trplot
|
16 | 14 |
|
| 15 | +from typing import Any, List |
| 16 | + |
| 17 | +import matplotlib.pyplot as plt |
| 18 | +import numpy as np |
| 19 | +import numpy.typing as npt |
| 20 | +from scipy.interpolate import CubicSpline |
| 21 | +from scipy.spatial.transform import Rotation, RotationSpline |
| 22 | +from spatialmath import SE3, SO3, Twist3 |
| 23 | +from spatialmath.base.transforms3d import tranimate |
| 24 | + |
| 25 | + |
| 26 | +class InterpSplineSE3: |
| 27 | + """Class for an interpolated trajectory in SE3 through waypoints with a cubic spline. |
| 28 | +
|
| 29 | + A combination of scipy.interpolate.CubicSpline and scipy.spatial.transform.RotationSpline (itself also cubic) |
| 30 | + under the hood. |
| 31 | + """ |
| 32 | + |
| 33 | + def __init__( |
| 34 | + self, |
| 35 | + timestamps: list[float] | npt.NDArray, |
| 36 | + waypoints: list[SE3], |
| 37 | + *, |
| 38 | + normalize_time: bool = True, |
| 39 | + bc_type: str | tuple = "not-a-knot", # not-a-knot is scipy default; None is invalid |
| 40 | + ) -> None: |
| 41 | + """Construct a InterpSplineSE3 object |
| 42 | +
|
| 43 | + Extends the scipy CubicSpline object |
| 44 | + https://docs.scipy.org/doc/scipy/reference/generated/scipy.interpolate.CubicSpline.html#cubicspline |
| 45 | +
|
| 46 | + Args : |
| 47 | + timestamps : list of times corresponding to provided poses |
| 48 | + waypoints : list of SE3 objects that govern the shape of the spline. |
| 49 | + normalize_time : flag to map times into the range [0, 1] |
| 50 | + bc_type : boundary condition provided to scipy CubicSpline backend. |
| 51 | + string options: ["not-a-knot" (default), "clamped", "natural", "periodic"]. |
| 52 | + For tuple options and details see the scipy docs link above. |
| 53 | + """ |
| 54 | + |
| 55 | + self.waypoints = waypoints |
| 56 | + self.timestamps = np.array(timestamps) |
| 57 | + |
| 58 | + if normalize_time: |
| 59 | + self.timestamps = self.timestamps - self.timestamps[0] |
| 60 | + self.timestamps = self.timestamps / self.timestamps[-1] |
| 61 | + |
| 62 | + self.xyz_data = np.array([pose.t for pose in waypoints]) |
| 63 | + self.so3_data = Rotation.from_matrix(np.array([(pose.R) for pose in waypoints])) |
| 64 | + |
| 65 | + self.spline_xyz = CubicSpline(self.timestamps, self.xyz_data, bc_type=bc_type) |
| 66 | + self.spline_so3 = RotationSpline(self.timestamps, self.so3_data) |
| 67 | + |
| 68 | + self.interpolation_indices = list(range(len(waypoints))) |
| 69 | + |
| 70 | + def __call__(self, t: float) -> Any: |
| 71 | + |
| 72 | + return SE3.Rt(t=self.spline_xyz(t), R=self.spline_so3(t).as_matrix()) |
| 73 | + |
| 74 | + def derivative(self, t: float) -> Twist3: |
| 75 | + linear_vel = self.spline_xyz.derivative()(t) |
| 76 | + angular_vel = self.spline_so3(t, 1) |
| 77 | + return Twist3(linear_vel, angular_vel) |
| 78 | + |
| 79 | + def max_angular_error(self) -> float: |
| 80 | + return np.max(self.angular_errors()) |
| 81 | + |
| 82 | + def angular_errors(self) -> list[float]: |
| 83 | + return [ |
| 84 | + SO3(pose).angdist(SO3(self.spline_so3(timestamp).as_matrix())) |
| 85 | + for pose, timestamp in zip(self.waypoints, self.timestamps, strict=True) |
| 86 | + ] |
| 87 | + |
| 88 | + def max_euclidean_error(self) -> float: |
| 89 | + return np.max(self.euclidean_errors()) |
| 90 | + |
| 91 | + def euclidean_errors(self) -> List[float]: |
| 92 | + return [ |
| 93 | + np.linalg.norm(pose.t - self.spline_xyz(timestamp)) |
| 94 | + for pose, timestamp in zip(self.waypoints, self.timestamps, strict=True) |
| 95 | + ] |
| 96 | + |
| 97 | + def downsample(self, epsilon_xyz: float = 1e-3, epsilon_angle: float = 1e-1) -> int: |
| 98 | + chosen_indices: set[int] = set() |
| 99 | + interpolation_indices = self.interpolation_indices.copy() |
| 100 | + |
| 101 | + for _ in range(len(self.timestamps) - 2): # you must have at least 2 indices |
| 102 | + choices = list(set(interpolation_indices).difference(chosen_indices)) |
| 103 | + |
| 104 | + index = np.random.choice(choices) |
| 105 | + |
| 106 | + chosen_indices.add(index) |
| 107 | + interpolation_indices.remove(index) |
| 108 | + |
| 109 | + self.spline_xyz = CubicSpline(self.timestamps[interpolation_indices], self.xyz_data[interpolation_indices]) |
| 110 | + self.spline_so3 = RotationSpline( |
| 111 | + self.timestamps[interpolation_indices], self.so3_data[interpolation_indices] |
| 112 | + ) |
| 113 | + |
| 114 | + time = self.timestamps[index] |
| 115 | + angular_error = SO3(self.waypoints[index]).angdist(SO3(self.spline_so3(time).as_matrix())) |
| 116 | + euclidean_error = np.linalg.norm(self.waypoints[index].t - self.spline_xyz(time)) |
| 117 | + if angular_error > epsilon_angle or euclidean_error > epsilon_xyz: |
| 118 | + interpolation_indices.insert(int(np.searchsorted(interpolation_indices, index, side="right")), index) |
| 119 | + |
| 120 | + self.interpolation_indices = interpolation_indices |
| 121 | + return len(self.waypoints) - len(interpolation_indices) |
| 122 | + |
| 123 | + def visualize( |
| 124 | + self, |
| 125 | + num_samples: int, |
| 126 | + pose_marker_length: float = 0.2, |
| 127 | + animate: bool = False, |
| 128 | + ax: plt.Axes | None = None, |
| 129 | + ) -> None: |
| 130 | + """Displays an animation of the trajectory with the control poses.""" |
| 131 | + if ax is None: |
| 132 | + fig = plt.figure(figsize=(10, 10)) |
| 133 | + ax = fig.add_subplot(projection="3d") |
| 134 | + |
| 135 | + samples = [self(t) for t in np.linspace(0, 1, num_samples)] |
| 136 | + if not animate: |
| 137 | + x = [pose.x for pose in samples] |
| 138 | + y = [pose.y for pose in samples] |
| 139 | + z = [pose.z for pose in samples] |
| 140 | + ax.plot(x, y, z, "c", linewidth=1.0) # plot spline fit |
| 141 | + |
| 142 | + x = [pose.x for pose in self.waypoints] |
| 143 | + y = [pose.y for pose in self.waypoints] |
| 144 | + z = [pose.z for pose in self.waypoints] |
| 145 | + ax.plot(x, y, z, "r*") # plot source data |
| 146 | + |
| 147 | + x = [self.waypoints[i].x for i in self.interpolation_indices] |
| 148 | + y = [self.waypoints[i].y for i in self.interpolation_indices] |
| 149 | + z = [self.waypoints[i].z for i in self.interpolation_indices] |
| 150 | + ax.plot(x, y, z, "go", fillstyle="none") # plot interpolation indices |
| 151 | + |
| 152 | + if animate: |
| 153 | + tranimate(samples, repeat=True, length=pose_marker_length, wait=True) # animate pose along trajectory |
| 154 | + else: |
| 155 | + plt.show() |
| 156 | + |
| 157 | + def to_numpy(self) -> dict[str, npt.NDArray]: |
| 158 | + """Export spline parameters as dictionary of numpy arrays.""" |
| 159 | + return {"timestamps": self.timestamps, "twists": np.vstack([1.0 * pose.twist().A for pose in self.waypoints])} |
| 160 | + |
| 161 | + def from_numpy(self, data: dict[str, npt.NDArray]) -> None: |
| 162 | + """Reconstruct spline from 'to_numpy' parameters.""" |
| 163 | + self.timestamps = data["timestamps"] |
| 164 | + self.waypoints = [SE3.Exp(twist) for twist in data["twists"]] |
| 165 | + |
| 166 | + |
| 167 | +class SplineFit: |
| 168 | + |
| 169 | + pass |
| 170 | + |
17 | 171 |
|
18 | 172 | class BSplineSE3:
|
19 | 173 | """A class to parameterize a trajectory in SE3 with a 6-dimensional B-spline.
|
|
0 commit comments