more styling fixes:

This commit is contained in:
Jade Choghari
2025-11-18 16:01:41 +01:00
parent 769eb27c87
commit a068618faf

View File

@@ -290,29 +290,25 @@ class LiberoProcessorStep(ObservationProcessorStep):
"""
if not isinstance(quat, torch.Tensor):
raise TypeError(
f"_quat2axisangle expected a torch.Tensor, got {type(quat)}"
)
raise TypeError(f"_quat2axisangle expected a torch.Tensor, got {type(quat)}")
if quat.ndim != 2 or quat.shape[1] != 4:
raise ValueError(
f"_quat2axisangle expected shape (B, 4), got {tuple(quat.shape)}"
)
raise ValueError(f"_quat2axisangle expected shape (B, 4), got {tuple(quat.shape)}")
quat = quat.to(dtype=torch.float32)
device = quat.device
B = quat.shape[0]
batch_size = quat.shape[0]
w = quat[:, 3].clamp(-1.0, 1.0)
den = torch.sqrt(torch.clamp(1.0 - w * w, min=0.0))
result = torch.zeros((B, 3), device=device)
result = torch.zeros((batch_size, 3), device=device)
mask = den > 1e-10
if mask.any():
angle = 2.0 * torch.acos(w[mask]) # (M,)
angle = 2.0 * torch.acos(w[mask]) # (M,)
axis = quat[mask, :3] / den[mask].unsqueeze(1)
result[mask] = axis * angle.unsqueeze(1)