diff --git a/src/lerobot/__init__.py b/src/lerobot/__init__.py index 63d6a44f4..eec574296 100644 --- a/src/lerobot/__init__.py +++ b/src/lerobot/__init__.py @@ -157,7 +157,7 @@ available_datasets = sorted( ) # lists all available policies from `lerobot/policies` -available_policies = ["act", "multi_task_dit", "diffusion", "tdmpc", "vqbet"] +available_policies = ["act", "diffusion", "tdmpc", "vqbet"] # lists all available robots from `lerobot/robots` available_robots = [ diff --git a/src/lerobot/policies/multi_task_dit/configuration_multi_task_dit.py b/src/lerobot/policies/multi_task_dit/configuration_multi_task_dit.py index 8286ab8e6..061230687 100644 --- a/src/lerobot/policies/multi_task_dit/configuration_multi_task_dit.py +++ b/src/lerobot/policies/multi_task_dit/configuration_multi_task_dit.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging from dataclasses import dataclass, field from lerobot.configs.policies import PreTrainedConfig @@ -144,9 +145,12 @@ class MultiTaskDiTConfig(PreTrainedConfig): or self.image_crop_shape[1] > self.image_resize_shape[1] ) ): - raise ValueError( - f"image_crop_shape {self.image_crop_shape} must be <= image_resize_shape {self.image_resize_shape}" + logging.warning( + "image_crop_shape %s must be <= image_resize_shape %s; disabling cropping.", + self.image_crop_shape, + self.image_resize_shape, ) + self.image_crop_shape = None # Text encoder validation if "clip" not in self.text_encoder_name.lower(): @@ -202,16 +206,26 @@ class MultiTaskDiTConfig(PreTrainedConfig): def validate_features(self) -> None: """Validate that required input features are present and properly configured.""" + # If the configured crop doesn't fit, disable cropping instead of erroring. + # Note: if image_resize_shape is set, cropping is applied *after* resizing. if self.image_crop_shape is not None: for key, image_ft in self.image_features.items(): - if ( - self.image_crop_shape[0] > image_ft.shape[1] - or self.image_crop_shape[1] > image_ft.shape[2] - ): - raise ValueError( - f"image_crop_shape {self.image_crop_shape} doesn't fit within image shape {image_ft.shape} " - f"for '{key}'" + # image_ft.shape is (C, H, W) + effective_h, effective_w = ( + self.image_resize_shape + if self.image_resize_shape is not None + else (image_ft.shape[1], image_ft.shape[2]) + ) + if self.image_crop_shape[0] > effective_h or self.image_crop_shape[1] > effective_w: + logging.warning( + "image_crop_shape %s doesn't fit within effective image shape (%s, %s) for '%s'; disabling cropping.", + self.image_crop_shape, + effective_h, + effective_w, + key, ) + self.image_crop_shape = None + break if len(self.image_features) > 0: first_key, first_ft = next(iter(self.image_features.items()))