mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-30 10:21:24 +00:00
Merge branch 'feat/add_pi' into feat/validate_pi_libero
This commit is contained in:
@@ -1033,7 +1033,7 @@ class PI05OpenPIPolicy(PreTrainedPolicy):
|
||||
}
|
||||
|
||||
def _preprocess_images(
|
||||
self, batch: dict[str, Tensor], *, train: bool = False
|
||||
self, batch: dict[str, Tensor]
|
||||
) -> tuple[list[Tensor], list[Tensor]]: # see lerobot pi0 `prepare_images`
|
||||
"""Preprocess images for the model.
|
||||
|
||||
@@ -1080,105 +1080,8 @@ class PI05OpenPIPolicy(PreTrainedPolicy):
|
||||
if img.shape[1:3] != self.config.image_resolution:
|
||||
img = resize_with_pad_torch(img, *self.config.image_resolution)
|
||||
|
||||
# from openpi preprocess_observation_pytorch: Training augmentations
|
||||
if train:
|
||||
# Convert from [-1, 1] to [0, 1] for PyTorch augmentations
|
||||
img = img / 2.0 + 0.5
|
||||
|
||||
# Apply PyTorch-based augmentations
|
||||
if "wrist" not in key:
|
||||
# Geometric augmentations for non-wrist cameras
|
||||
height, width = img.shape[1:3]
|
||||
|
||||
# Random crop and resize
|
||||
crop_height = int(height * 0.95)
|
||||
crop_width = int(width * 0.95)
|
||||
|
||||
# Random crop
|
||||
max_h = height - crop_height
|
||||
max_w = width - crop_width
|
||||
if max_h > 0 and max_w > 0:
|
||||
# Use tensor operations instead of .item() for torch.compile compatibility
|
||||
start_h = torch.randint(0, max_h + 1, (1,), device=img.device)
|
||||
start_w = torch.randint(0, max_w + 1, (1,), device=img.device)
|
||||
img = img[:, start_h : start_h + crop_height, start_w : start_w + crop_width, :]
|
||||
|
||||
# Resize back to original size
|
||||
img = torch.nn.functional.interpolate(
|
||||
img.permute(0, 3, 1, 2), # [b, h, w, c] -> [b, c, h, w]
|
||||
size=(height, width),
|
||||
mode="bilinear",
|
||||
align_corners=False,
|
||||
).permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c]
|
||||
|
||||
# Random rotation (small angles)
|
||||
# Use tensor operations instead of .item() for torch.compile compatibility
|
||||
angle = torch.rand(1, device=img.device) * 10 - 5 # Random angle between -5 and 5 degrees
|
||||
if torch.abs(angle) > 0.1: # Only rotate if angle is significant
|
||||
# Convert to radians
|
||||
angle_rad = angle * torch.pi / 180.0
|
||||
|
||||
# Create rotation matrix
|
||||
cos_a = torch.cos(angle_rad)
|
||||
sin_a = torch.sin(angle_rad)
|
||||
|
||||
# Apply rotation using grid_sample
|
||||
grid_x = torch.linspace(-1, 1, width, device=img.device)
|
||||
grid_y = torch.linspace(-1, 1, height, device=img.device)
|
||||
|
||||
# Create meshgrid
|
||||
grid_y, grid_x = torch.meshgrid(grid_y, grid_x, indexing="ij")
|
||||
|
||||
# Expand to batch dimension
|
||||
grid_x = grid_x.unsqueeze(0).expand(img.shape[0], -1, -1)
|
||||
grid_y = grid_y.unsqueeze(0).expand(img.shape[0], -1, -1)
|
||||
|
||||
# Apply rotation transformation
|
||||
grid_x_rot = grid_x * cos_a - grid_y * sin_a
|
||||
grid_y_rot = grid_x * sin_a + grid_y * cos_a
|
||||
|
||||
# Stack and reshape for grid_sample
|
||||
grid = torch.stack([grid_x_rot, grid_y_rot], dim=-1)
|
||||
|
||||
img = torch.nn.functional.grid_sample(
|
||||
img.permute(0, 3, 1, 2), # [b, h, w, c] -> [b, c, h, w]
|
||||
grid,
|
||||
mode="bilinear",
|
||||
padding_mode="zeros",
|
||||
align_corners=False,
|
||||
).permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c]
|
||||
|
||||
# Color augmentations for all cameras
|
||||
# Random brightness
|
||||
# Use tensor operations instead of .item() for torch.compile compatibility
|
||||
brightness_factor = (
|
||||
0.7 + torch.rand(1, device=img.device) * 0.6
|
||||
) # Random factor between 0.7 and 1.3
|
||||
img = img * brightness_factor
|
||||
|
||||
# Random contrast
|
||||
# Use tensor operations instead of .item() for torch.compile compatibility
|
||||
contrast_factor = (
|
||||
0.6 + torch.rand(1, device=img.device) * 0.8
|
||||
) # Random factor between 0.6 and 1.4
|
||||
mean = img.mean(dim=[1, 2, 3], keepdim=True)
|
||||
img = (img - mean) * contrast_factor + mean
|
||||
|
||||
# Random saturation (convert to HSV, modify S, convert back)
|
||||
# For simplicity, we'll just apply a random scaling to the color channels
|
||||
# Use tensor operations instead of .item() for torch.compile compatibility
|
||||
saturation_factor = (
|
||||
0.5 + torch.rand(1, device=img.device) * 1.0
|
||||
) # Random factor between 0.5 and 1.5
|
||||
gray = img.mean(dim=-1, keepdim=True)
|
||||
img = gray + (img - gray) * saturation_factor
|
||||
|
||||
# Clamp values to [0, 1]
|
||||
img = torch.clamp(img, 0, 1)
|
||||
|
||||
else:
|
||||
# from lerobot pi0: Normalize from [0,1] to [-1,1] as expected by siglip
|
||||
img = img * 2.0 - 1.0
|
||||
# from lerobot pi0: Normalize from [0,1] to [-1,1] as expected by siglip
|
||||
img = img * 2.0 - 1.0
|
||||
|
||||
# from openpi preprocess_observation_pytorch: Convert back to [B, C, H, W] format if it was originally channels-first
|
||||
if is_channels_first:
|
||||
@@ -1265,7 +1168,7 @@ class PI05OpenPIPolicy(PreTrainedPolicy):
|
||||
batch = self.normalize_inputs(batch)
|
||||
|
||||
# Prepare inputs
|
||||
images, img_masks = self._preprocess_images(batch, train=False)
|
||||
images, img_masks = self._preprocess_images(batch)
|
||||
lang_tokens, lang_masks = self._tokenize_language(batch)
|
||||
state = self.prepare_state(batch)
|
||||
|
||||
@@ -1285,7 +1188,7 @@ class PI05OpenPIPolicy(PreTrainedPolicy):
|
||||
batch = self.normalize_targets(batch)
|
||||
|
||||
# Prepare inputs
|
||||
images, img_masks = self._preprocess_images(batch, train=True)
|
||||
images, img_masks = self._preprocess_images(batch)
|
||||
lang_tokens, lang_masks = self._tokenize_language(batch)
|
||||
|
||||
state = self.prepare_state(batch)
|
||||
|
||||
@@ -1048,7 +1048,7 @@ class PI0OpenPIPolicy(PreTrainedPolicy):
|
||||
}
|
||||
|
||||
def _preprocess_images(
|
||||
self, batch: dict[str, Tensor], *, train: bool = False
|
||||
self, batch: dict[str, Tensor]
|
||||
) -> tuple[list[Tensor], list[Tensor]]: # see lerobot pi0 `prepare_images`
|
||||
"""Preprocess images for the model.
|
||||
|
||||
@@ -1095,105 +1095,8 @@ class PI0OpenPIPolicy(PreTrainedPolicy):
|
||||
if img.shape[1:3] != self.config.image_resolution:
|
||||
img = resize_with_pad_torch(img, *self.config.image_resolution)
|
||||
|
||||
# from openpi preprocess_observation_pytorch: Training augmentations
|
||||
if train:
|
||||
# Convert from [-1, 1] to [0, 1] for PyTorch augmentations
|
||||
img = img / 2.0 + 0.5
|
||||
|
||||
# Apply PyTorch-based augmentations
|
||||
if "wrist" not in key:
|
||||
# Geometric augmentations for non-wrist cameras
|
||||
height, width = img.shape[1:3]
|
||||
|
||||
# Random crop and resize
|
||||
crop_height = int(height * 0.95)
|
||||
crop_width = int(width * 0.95)
|
||||
|
||||
# Random crop
|
||||
max_h = height - crop_height
|
||||
max_w = width - crop_width
|
||||
if max_h > 0 and max_w > 0:
|
||||
# Use tensor operations instead of .item() for torch.compile compatibility
|
||||
start_h = torch.randint(0, max_h + 1, (1,), device=img.device)
|
||||
start_w = torch.randint(0, max_w + 1, (1,), device=img.device)
|
||||
img = img[:, start_h : start_h + crop_height, start_w : start_w + crop_width, :]
|
||||
|
||||
# Resize back to original size
|
||||
img = torch.nn.functional.interpolate(
|
||||
img.permute(0, 3, 1, 2), # [b, h, w, c] -> [b, c, h, w]
|
||||
size=(height, width),
|
||||
mode="bilinear",
|
||||
align_corners=False,
|
||||
).permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c]
|
||||
|
||||
# Random rotation (small angles)
|
||||
# Use tensor operations instead of .item() for torch.compile compatibility
|
||||
angle = torch.rand(1, device=img.device) * 10 - 5 # Random angle between -5 and 5 degrees
|
||||
if torch.abs(angle) > 0.1: # Only rotate if angle is significant
|
||||
# Convert to radians
|
||||
angle_rad = angle * torch.pi / 180.0
|
||||
|
||||
# Create rotation matrix
|
||||
cos_a = torch.cos(angle_rad)
|
||||
sin_a = torch.sin(angle_rad)
|
||||
|
||||
# Apply rotation using grid_sample
|
||||
grid_x = torch.linspace(-1, 1, width, device=img.device)
|
||||
grid_y = torch.linspace(-1, 1, height, device=img.device)
|
||||
|
||||
# Create meshgrid
|
||||
grid_y, grid_x = torch.meshgrid(grid_y, grid_x, indexing="ij")
|
||||
|
||||
# Expand to batch dimension
|
||||
grid_x = grid_x.unsqueeze(0).expand(img.shape[0], -1, -1)
|
||||
grid_y = grid_y.unsqueeze(0).expand(img.shape[0], -1, -1)
|
||||
|
||||
# Apply rotation transformation
|
||||
grid_x_rot = grid_x * cos_a - grid_y * sin_a
|
||||
grid_y_rot = grid_x * sin_a + grid_y * cos_a
|
||||
|
||||
# Stack and reshape for grid_sample
|
||||
grid = torch.stack([grid_x_rot, grid_y_rot], dim=-1)
|
||||
|
||||
img = torch.nn.functional.grid_sample(
|
||||
img.permute(0, 3, 1, 2), # [b, h, w, c] -> [b, c, h, w]
|
||||
grid,
|
||||
mode="bilinear",
|
||||
padding_mode="zeros",
|
||||
align_corners=False,
|
||||
).permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c]
|
||||
|
||||
# Color augmentations for all cameras
|
||||
# Random brightness
|
||||
# Use tensor operations instead of .item() for torch.compile compatibility
|
||||
brightness_factor = (
|
||||
0.7 + torch.rand(1, device=img.device) * 0.6
|
||||
) # Random factor between 0.7 and 1.3
|
||||
img = img * brightness_factor
|
||||
|
||||
# Random contrast
|
||||
# Use tensor operations instead of .item() for torch.compile compatibility
|
||||
contrast_factor = (
|
||||
0.6 + torch.rand(1, device=img.device) * 0.8
|
||||
) # Random factor between 0.6 and 1.4
|
||||
mean = img.mean(dim=[1, 2, 3], keepdim=True)
|
||||
img = (img - mean) * contrast_factor + mean
|
||||
|
||||
# Random saturation (convert to HSV, modify S, convert back)
|
||||
# For simplicity, we'll just apply a random scaling to the color channels
|
||||
# Use tensor operations instead of .item() for torch.compile compatibility
|
||||
saturation_factor = (
|
||||
0.5 + torch.rand(1, device=img.device) * 1.0
|
||||
) # Random factor between 0.5 and 1.5
|
||||
gray = img.mean(dim=-1, keepdim=True)
|
||||
img = gray + (img - gray) * saturation_factor
|
||||
|
||||
# Clamp values to [0, 1]
|
||||
img = torch.clamp(img, 0, 1)
|
||||
|
||||
else:
|
||||
# from lerobot pi0: Normalize from [0,1] to [-1,1] as expected by siglip
|
||||
img = img * 2.0 - 1.0
|
||||
# from lerobot pi0: Normalize from [0,1] to [-1,1] as expected by siglip
|
||||
img = img * 2.0 - 1.0
|
||||
|
||||
# from openpi preprocess_observation_pytorch: Convert back to [B, C, H, W] format if it was originally channels-first
|
||||
if is_channels_first:
|
||||
@@ -1280,7 +1183,7 @@ class PI0OpenPIPolicy(PreTrainedPolicy):
|
||||
batch = self.normalize_inputs(batch)
|
||||
|
||||
# Prepare inputs
|
||||
images, img_masks = self._preprocess_images(batch, train=False)
|
||||
images, img_masks = self._preprocess_images(batch)
|
||||
lang_tokens, lang_masks = self._tokenize_language(batch)
|
||||
state = self.prepare_state(batch)
|
||||
|
||||
@@ -1300,7 +1203,7 @@ class PI0OpenPIPolicy(PreTrainedPolicy):
|
||||
batch = self.normalize_targets(batch)
|
||||
|
||||
# Prepare inputs
|
||||
images, img_masks = self._preprocess_images(batch, train=True)
|
||||
images, img_masks = self._preprocess_images(batch)
|
||||
lang_tokens, lang_masks = self._tokenize_language(batch)
|
||||
state = self.prepare_state(batch)
|
||||
actions = self.prepare_action(batch)
|
||||
|
||||
Reference in New Issue
Block a user