Skip to content

[Issue-122] Pass shortest arg for interp; optionally enforce non-negative scalar … #123

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions spatialmath/base/quaternions.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,6 +549,7 @@ def r2q(
check: Optional[bool] = False,
tol: float = 20,
order: Optional[str] = "sxyz",
shortest: bool = False,
) -> UnitQuaternionArray:
"""
Convert SO(3) rotation matrix to unit-quaternion
Expand All @@ -562,6 +563,8 @@ def r2q(
:param order: the order of the returned quaternion elements. Must be 'sxyz' or
'xyzs'. Defaults to 'sxyz'.
:type order: str
:param shortest: ensures the quaternion has non-negative scalar part.
:type shortest: bool, default to False
:return: unit-quaternion as Euler parameters
:rtype: ndarray(4)
:raises ValueError: for non SO(3) argument
Expand Down Expand Up @@ -633,6 +636,9 @@ def r2q(
e[1] = math.copysign(e[1], R[0, 2] + R[2, 0])
e[2] = math.copysign(e[2], R[2, 1] + R[1, 2])

if shortest and e[0] < 0:
e = -e

if order == "sxyz":
return e
elif order == "xyzs":
Expand Down
12 changes: 9 additions & 3 deletions spatialmath/base/transforms2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -853,16 +853,16 @@ def tr2jac2(T: SE2Array) -> R3x3:


@overload
def trinterp2(start: Optional[SO2Array], end: SO2Array, s: float) -> SO2Array:
def trinterp2(start: Optional[SO2Array], end: SO2Array, s: float, shortest: bool = True) -> SO2Array:
...


@overload
def trinterp2(start: Optional[SE2Array], end: SE2Array, s: float) -> SE2Array:
def trinterp2(start: Optional[SE2Array], end: SE2Array, s: float, shortest: bool = True) -> SE2Array:
...


def trinterp2(start, end, s):
def trinterp2(start, end, s, shortest: bool = True):
"""
Interpolate SE(2) or SO(2) matrices

Expand All @@ -872,6 +872,8 @@ def trinterp2(start, end, s):
:type end: ndarray(3,3) or ndarray(2,2)
:param s: interpolation coefficient, range 0 to 1
:type s: float
:param shortest: take the shortest path along the great circle for the rotation
:type shortest: bool, default to True
:return: interpolated SE(2) or SO(2) matrix value
:rtype: ndarray(3,3) or ndarray(2,2)
:raises ValueError: bad arguments
Expand Down Expand Up @@ -917,6 +919,8 @@ def trinterp2(start, end, s):

th0 = math.atan2(start[1, 0], start[0, 0])
th1 = math.atan2(end[1, 0], end[0, 0])
if shortest:
th1 = th0 + smb.wrap_mpi_pi(th1 - th0)

th = th0 * (1 - s) + s * th1

Expand All @@ -937,6 +941,8 @@ def trinterp2(start, end, s):

th0 = math.atan2(start[1, 0], start[0, 0])
th1 = math.atan2(end[1, 0], end[0, 0])
if shortest:
th1 = th0 + smb.wrap_mpi_pi(th1 - th0)

p0 = transl2(start)
p1 = transl2(end)
Expand Down
16 changes: 9 additions & 7 deletions spatialmath/base/transforms3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -1605,16 +1605,16 @@ def trnorm(T: SE3Array) -> SE3Array:


@overload
def trinterp(start: Optional[SO3Array], end: SO3Array, s: float) -> SO3Array:
def trinterp(start: Optional[SO3Array], end: SO3Array, s: float, shortest: bool = True) -> SO3Array:
...


@overload
def trinterp(start: Optional[SE3Array], end: SE3Array, s: float) -> SE3Array:
def trinterp(start: Optional[SE3Array], end: SE3Array, s: float, shortest: bool = True) -> SE3Array:
...


def trinterp(start, end, s):
def trinterp(start, end, s, shortest=True):
"""
Interpolate SE(3) matrices

Expand All @@ -1624,6 +1624,8 @@ def trinterp(start, end, s):
:type end: ndarray(4,4) or ndarray(3,3)
:param s: interpolation coefficient, range 0 to 1
:type s: float
:param shortest: take the shortest path along the great circle for the rotation
:type shortest: bool, default to True
:return: interpolated SE(3) or SO(3) matrix value
:rtype: ndarray(4,4) or ndarray(3,3)
:raises ValueError: bad arguments
Expand Down Expand Up @@ -1663,12 +1665,12 @@ def trinterp(start, end, s):
if start is None:
# TRINTERP(T, s)
q0 = r2q(end)
qr = qslerp(qeye(), q0, s)
qr = qslerp(qeye(), q0, s, shortest=shortest)
else:
# TRINTERP(T0, T1, s)
q0 = r2q(start)
q1 = r2q(end)
qr = qslerp(q0, q1, s)
qr = qslerp(q0, q1, s, shortest=shortest)

return q2r(qr)

Expand All @@ -1679,7 +1681,7 @@ def trinterp(start, end, s):
q0 = r2q(t2r(end))
p0 = transl(end)

qr = qslerp(qeye(), q0, s)
qr = qslerp(qeye(), q0, s, shortest=shortest)
pr = s * p0
else:
# TRINTERP(T0, T1, s)
Expand All @@ -1689,7 +1691,7 @@ def trinterp(start, end, s):
p0 = transl(start)
p1 = transl(end)

qr = qslerp(q0, q1, s)
qr = qslerp(q0, q1, s, shortest=shortest)
pr = p0 * (1 - s) + s * p1

return rt2tr(q2r(qr), pr)
Expand Down
8 changes: 5 additions & 3 deletions spatialmath/baseposematrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,14 +377,16 @@ def log(self, twist: Optional[bool] = False) -> Union[NDArray, List[NDArray]]:
else:
return log

def interp(self, end: Optional[bool] = None, s: Union[int, float] = None) -> Self:
def interp(self, end: Optional[bool] = None, s: Union[int, float] = None, shortest: bool = True) -> Self:
"""
Interpolate between poses (superclass method)

:param end: final pose
:type end: same as ``self``
:param s: interpolation coefficient, range 0 to 1, or number of steps
:type s: array_like or int
:param shortest: take the shortest path along the great circle for the rotation
:type shortest: bool, default to True
:return: interpolated pose
:rtype: same as ``self``

Expand Down Expand Up @@ -432,13 +434,13 @@ def interp(self, end: Optional[bool] = None, s: Union[int, float] = None) -> Sel
if self.N == 2:
# SO(2) or SE(2)
return self.__class__(
[smb.trinterp2(start=self.A, end=end, s=_s) for _s in s]
[smb.trinterp2(start=self.A, end=end, s=_s, shortest=shortest) for _s in s]
)

elif self.N == 3:
# SO(3) or SE(3)
return self.__class__(
[smb.trinterp(start=self.A, end=end, s=_s) for _s in s]
[smb.trinterp(start=self.A, end=end, s=_s, shortest=shortest) for _s in s]
)

def interp1(self, s: float = None) -> Self:
Expand Down
7 changes: 7 additions & 0 deletions tests/base/test_quaternions.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,13 @@ def test_rotation(self):
)
nt.assert_array_almost_equal(qvmul([0, 1, 0, 0], [0, 0, 1]), np.r_[0, 0, -1])

large_rotation = math.pi + 0.01
q1 = r2q(tr.rotx(large_rotation), shortest=False)
q2 = r2q(tr.rotx(large_rotation), shortest=True)
self.assertLess(q1[0], 0)
self.assertGreater(q2[0], 0)
self.assertTrue(qisequal(q1=q1, q2=q2, unitq=True))

def test_slerp(self):
q1 = np.r_[0, 1, 0, 0]
q2 = np.r_[0, 0, 1, 0]
Expand Down
10 changes: 10 additions & 0 deletions tests/test_pose2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,16 @@ def test_interp(self):
array_compare(I.interp(TT, s=1), TT)
array_compare(I.interp(TT, s=0.5), SE2(1, -2, 0.3))

R1 = SO2(math.pi - 0.1)
R2 = SO2(-math.pi + 0.2)
array_compare(R1.interp(R2, s=0.5, shortest=False), SO2(0.05))
array_compare(R1.interp(R2, s=0.5, shortest=True), SO2(-math.pi + 0.05))

T1 = SE2(0, 0, math.pi - 0.1)
T2 = SE2(0, 0, -math.pi + 0.2)
array_compare(T1.interp(T2, s=0.5, shortest=False), SE2(0, 0, 0.05))
array_compare(T1.interp(T2, s=0.5, shortest=True), SE2(0, 0, -math.pi + 0.05))

def test_miscellany(self):
TT = SE2(1, 2, 0.3)

Expand Down
Loading