mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-03 12:21:27 +00:00
more styling fixes:
This commit is contained in:
@@ -290,29 +290,25 @@ class LiberoProcessorStep(ObservationProcessorStep):
|
||||
"""
|
||||
|
||||
if not isinstance(quat, torch.Tensor):
|
||||
raise TypeError(
|
||||
f"_quat2axisangle expected a torch.Tensor, got {type(quat)}"
|
||||
)
|
||||
raise TypeError(f"_quat2axisangle expected a torch.Tensor, got {type(quat)}")
|
||||
|
||||
if quat.ndim != 2 or quat.shape[1] != 4:
|
||||
raise ValueError(
|
||||
f"_quat2axisangle expected shape (B, 4), got {tuple(quat.shape)}"
|
||||
)
|
||||
raise ValueError(f"_quat2axisangle expected shape (B, 4), got {tuple(quat.shape)}")
|
||||
|
||||
quat = quat.to(dtype=torch.float32)
|
||||
device = quat.device
|
||||
B = quat.shape[0]
|
||||
batch_size = quat.shape[0]
|
||||
|
||||
w = quat[:, 3].clamp(-1.0, 1.0)
|
||||
|
||||
den = torch.sqrt(torch.clamp(1.0 - w * w, min=0.0))
|
||||
|
||||
result = torch.zeros((B, 3), device=device)
|
||||
result = torch.zeros((batch_size, 3), device=device)
|
||||
|
||||
mask = den > 1e-10
|
||||
|
||||
if mask.any():
|
||||
angle = 2.0 * torch.acos(w[mask]) # (M,)
|
||||
angle = 2.0 * torch.acos(w[mask]) # (M,)
|
||||
axis = quat[mask, :3] / den[mask].unsqueeze(1)
|
||||
result[mask] = axis * angle.unsqueeze(1)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user