diff --git a/src/lerobot/policies/pi05_openpi/modeling_pi05openpi.py b/src/lerobot/policies/pi05_openpi/modeling_pi05openpi.py index 5b6da2f91..63681112f 100644 --- a/src/lerobot/policies/pi05_openpi/modeling_pi05openpi.py +++ b/src/lerobot/policies/pi05_openpi/modeling_pi05openpi.py @@ -1033,7 +1033,7 @@ class PI05OpenPIPolicy(PreTrainedPolicy): } def _preprocess_images( - self, batch: dict[str, Tensor], *, train: bool = False + self, batch: dict[str, Tensor] ) -> tuple[list[Tensor], list[Tensor]]: # see lerobot pi0 `prepare_images` """Preprocess images for the model. @@ -1080,105 +1080,8 @@ class PI05OpenPIPolicy(PreTrainedPolicy): if img.shape[1:3] != self.config.image_resolution: img = resize_with_pad_torch(img, *self.config.image_resolution) - # from openpi preprocess_observation_pytorch: Training augmentations - if train: - # Convert from [-1, 1] to [0, 1] for PyTorch augmentations - img = img / 2.0 + 0.5 - - # Apply PyTorch-based augmentations - if "wrist" not in key: - # Geometric augmentations for non-wrist cameras - height, width = img.shape[1:3] - - # Random crop and resize - crop_height = int(height * 0.95) - crop_width = int(width * 0.95) - - # Random crop - max_h = height - crop_height - max_w = width - crop_width - if max_h > 0 and max_w > 0: - # Use tensor operations instead of .item() for torch.compile compatibility - start_h = torch.randint(0, max_h + 1, (1,), device=img.device) - start_w = torch.randint(0, max_w + 1, (1,), device=img.device) - img = img[:, start_h : start_h + crop_height, start_w : start_w + crop_width, :] - - # Resize back to original size - img = torch.nn.functional.interpolate( - img.permute(0, 3, 1, 2), # [b, h, w, c] -> [b, c, h, w] - size=(height, width), - mode="bilinear", - align_corners=False, - ).permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c] - - # Random rotation (small angles) - # Use tensor operations instead of .item() for torch.compile compatibility - angle = torch.rand(1, device=img.device) * 10 - 5 # Random angle between -5 and 5 degrees - if torch.abs(angle) > 0.1: # Only rotate if angle is significant - # Convert to radians - angle_rad = angle * torch.pi / 180.0 - - # Create rotation matrix - cos_a = torch.cos(angle_rad) - sin_a = torch.sin(angle_rad) - - # Apply rotation using grid_sample - grid_x = torch.linspace(-1, 1, width, device=img.device) - grid_y = torch.linspace(-1, 1, height, device=img.device) - - # Create meshgrid - grid_y, grid_x = torch.meshgrid(grid_y, grid_x, indexing="ij") - - # Expand to batch dimension - grid_x = grid_x.unsqueeze(0).expand(img.shape[0], -1, -1) - grid_y = grid_y.unsqueeze(0).expand(img.shape[0], -1, -1) - - # Apply rotation transformation - grid_x_rot = grid_x * cos_a - grid_y * sin_a - grid_y_rot = grid_x * sin_a + grid_y * cos_a - - # Stack and reshape for grid_sample - grid = torch.stack([grid_x_rot, grid_y_rot], dim=-1) - - img = torch.nn.functional.grid_sample( - img.permute(0, 3, 1, 2), # [b, h, w, c] -> [b, c, h, w] - grid, - mode="bilinear", - padding_mode="zeros", - align_corners=False, - ).permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c] - - # Color augmentations for all cameras - # Random brightness - # Use tensor operations instead of .item() for torch.compile compatibility - brightness_factor = ( - 0.7 + torch.rand(1, device=img.device) * 0.6 - ) # Random factor between 0.7 and 1.3 - img = img * brightness_factor - - # Random contrast - # Use tensor operations instead of .item() for torch.compile compatibility - contrast_factor = ( - 0.6 + torch.rand(1, device=img.device) * 0.8 - ) # Random factor between 0.6 and 1.4 - mean = img.mean(dim=[1, 2, 3], keepdim=True) - img = (img - mean) * contrast_factor + mean - - # Random saturation (convert to HSV, modify S, convert back) - # For simplicity, we'll just apply a random scaling to the color channels - # Use tensor operations instead of .item() for torch.compile compatibility - saturation_factor = ( - 0.5 + torch.rand(1, device=img.device) * 1.0 - ) # Random factor between 0.5 and 1.5 - gray = img.mean(dim=-1, keepdim=True) - img = gray + (img - gray) * saturation_factor - - # Clamp values to [0, 1] - img = torch.clamp(img, 0, 1) - - else: - # from lerobot pi0: Normalize from [0,1] to [-1,1] as expected by siglip - img = img * 2.0 - 1.0 + # from lerobot pi0: Normalize from [0,1] to [-1,1] as expected by siglip + img = img * 2.0 - 1.0 # from openpi preprocess_observation_pytorch: Convert back to [B, C, H, W] format if it was originally channels-first if is_channels_first: @@ -1265,7 +1168,7 @@ class PI05OpenPIPolicy(PreTrainedPolicy): batch = self.normalize_inputs(batch) # Prepare inputs - images, img_masks = self._preprocess_images(batch, train=False) + images, img_masks = self._preprocess_images(batch) lang_tokens, lang_masks = self._tokenize_language(batch) state = self.prepare_state(batch) @@ -1285,7 +1188,7 @@ class PI05OpenPIPolicy(PreTrainedPolicy): batch = self.normalize_targets(batch) # Prepare inputs - images, img_masks = self._preprocess_images(batch, train=True) + images, img_masks = self._preprocess_images(batch) lang_tokens, lang_masks = self._tokenize_language(batch) state = self.prepare_state(batch) diff --git a/src/lerobot/policies/pi0_openpi/modeling_pi0openpi.py b/src/lerobot/policies/pi0_openpi/modeling_pi0openpi.py index ef737010e..cf873d843 100644 --- a/src/lerobot/policies/pi0_openpi/modeling_pi0openpi.py +++ b/src/lerobot/policies/pi0_openpi/modeling_pi0openpi.py @@ -1048,7 +1048,7 @@ class PI0OpenPIPolicy(PreTrainedPolicy): } def _preprocess_images( - self, batch: dict[str, Tensor], *, train: bool = False + self, batch: dict[str, Tensor] ) -> tuple[list[Tensor], list[Tensor]]: # see lerobot pi0 `prepare_images` """Preprocess images for the model. @@ -1095,105 +1095,8 @@ class PI0OpenPIPolicy(PreTrainedPolicy): if img.shape[1:3] != self.config.image_resolution: img = resize_with_pad_torch(img, *self.config.image_resolution) - # from openpi preprocess_observation_pytorch: Training augmentations - if train: - # Convert from [-1, 1] to [0, 1] for PyTorch augmentations - img = img / 2.0 + 0.5 - - # Apply PyTorch-based augmentations - if "wrist" not in key: - # Geometric augmentations for non-wrist cameras - height, width = img.shape[1:3] - - # Random crop and resize - crop_height = int(height * 0.95) - crop_width = int(width * 0.95) - - # Random crop - max_h = height - crop_height - max_w = width - crop_width - if max_h > 0 and max_w > 0: - # Use tensor operations instead of .item() for torch.compile compatibility - start_h = torch.randint(0, max_h + 1, (1,), device=img.device) - start_w = torch.randint(0, max_w + 1, (1,), device=img.device) - img = img[:, start_h : start_h + crop_height, start_w : start_w + crop_width, :] - - # Resize back to original size - img = torch.nn.functional.interpolate( - img.permute(0, 3, 1, 2), # [b, h, w, c] -> [b, c, h, w] - size=(height, width), - mode="bilinear", - align_corners=False, - ).permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c] - - # Random rotation (small angles) - # Use tensor operations instead of .item() for torch.compile compatibility - angle = torch.rand(1, device=img.device) * 10 - 5 # Random angle between -5 and 5 degrees - if torch.abs(angle) > 0.1: # Only rotate if angle is significant - # Convert to radians - angle_rad = angle * torch.pi / 180.0 - - # Create rotation matrix - cos_a = torch.cos(angle_rad) - sin_a = torch.sin(angle_rad) - - # Apply rotation using grid_sample - grid_x = torch.linspace(-1, 1, width, device=img.device) - grid_y = torch.linspace(-1, 1, height, device=img.device) - - # Create meshgrid - grid_y, grid_x = torch.meshgrid(grid_y, grid_x, indexing="ij") - - # Expand to batch dimension - grid_x = grid_x.unsqueeze(0).expand(img.shape[0], -1, -1) - grid_y = grid_y.unsqueeze(0).expand(img.shape[0], -1, -1) - - # Apply rotation transformation - grid_x_rot = grid_x * cos_a - grid_y * sin_a - grid_y_rot = grid_x * sin_a + grid_y * cos_a - - # Stack and reshape for grid_sample - grid = torch.stack([grid_x_rot, grid_y_rot], dim=-1) - - img = torch.nn.functional.grid_sample( - img.permute(0, 3, 1, 2), # [b, h, w, c] -> [b, c, h, w] - grid, - mode="bilinear", - padding_mode="zeros", - align_corners=False, - ).permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c] - - # Color augmentations for all cameras - # Random brightness - # Use tensor operations instead of .item() for torch.compile compatibility - brightness_factor = ( - 0.7 + torch.rand(1, device=img.device) * 0.6 - ) # Random factor between 0.7 and 1.3 - img = img * brightness_factor - - # Random contrast - # Use tensor operations instead of .item() for torch.compile compatibility - contrast_factor = ( - 0.6 + torch.rand(1, device=img.device) * 0.8 - ) # Random factor between 0.6 and 1.4 - mean = img.mean(dim=[1, 2, 3], keepdim=True) - img = (img - mean) * contrast_factor + mean - - # Random saturation (convert to HSV, modify S, convert back) - # For simplicity, we'll just apply a random scaling to the color channels - # Use tensor operations instead of .item() for torch.compile compatibility - saturation_factor = ( - 0.5 + torch.rand(1, device=img.device) * 1.0 - ) # Random factor between 0.5 and 1.5 - gray = img.mean(dim=-1, keepdim=True) - img = gray + (img - gray) * saturation_factor - - # Clamp values to [0, 1] - img = torch.clamp(img, 0, 1) - - else: - # from lerobot pi0: Normalize from [0,1] to [-1,1] as expected by siglip - img = img * 2.0 - 1.0 + # from lerobot pi0: Normalize from [0,1] to [-1,1] as expected by siglip + img = img * 2.0 - 1.0 # from openpi preprocess_observation_pytorch: Convert back to [B, C, H, W] format if it was originally channels-first if is_channels_first: @@ -1280,7 +1183,7 @@ class PI0OpenPIPolicy(PreTrainedPolicy): batch = self.normalize_inputs(batch) # Prepare inputs - images, img_masks = self._preprocess_images(batch, train=False) + images, img_masks = self._preprocess_images(batch) lang_tokens, lang_masks = self._tokenize_language(batch) state = self.prepare_state(batch) @@ -1300,7 +1203,7 @@ class PI0OpenPIPolicy(PreTrainedPolicy): batch = self.normalize_targets(batch) # Prepare inputs - images, img_masks = self._preprocess_images(batch, train=True) + images, img_masks = self._preprocess_images(batch) lang_tokens, lang_masks = self._tokenize_language(batch) state = self.prepare_state(batch) actions = self.prepare_action(batch)