remove cropping of images smaller than the crop size

This commit is contained in:
Bryson Jones
2025-12-15 22:20:20 -08:00
parent 25ecd16b67
commit 8a2f5aa6cb
2 changed files with 24 additions and 10 deletions

View File

@@ -157,7 +157,7 @@ available_datasets = sorted(
)
# lists all available policies from `lerobot/policies`
available_policies = ["act", "multi_task_dit", "diffusion", "tdmpc", "vqbet"]
available_policies = ["act", "diffusion", "tdmpc", "vqbet"]
# lists all available robots from `lerobot/robots`
available_robots = [

View File

@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig
@@ -144,9 +145,12 @@ class MultiTaskDiTConfig(PreTrainedConfig):
or self.image_crop_shape[1] > self.image_resize_shape[1]
)
):
raise ValueError(
f"image_crop_shape {self.image_crop_shape} must be <= image_resize_shape {self.image_resize_shape}"
logging.warning(
"image_crop_shape %s must be <= image_resize_shape %s; disabling cropping.",
self.image_crop_shape,
self.image_resize_shape,
)
self.image_crop_shape = None
# Text encoder validation
if "clip" not in self.text_encoder_name.lower():
@@ -202,16 +206,26 @@ class MultiTaskDiTConfig(PreTrainedConfig):
def validate_features(self) -> None:
"""Validate that required input features are present and properly configured."""
# If the configured crop doesn't fit, disable cropping instead of erroring.
# Note: if image_resize_shape is set, cropping is applied *after* resizing.
if self.image_crop_shape is not None:
for key, image_ft in self.image_features.items():
if (
self.image_crop_shape[0] > image_ft.shape[1]
or self.image_crop_shape[1] > image_ft.shape[2]
):
raise ValueError(
f"image_crop_shape {self.image_crop_shape} doesn't fit within image shape {image_ft.shape} "
f"for '{key}'"
# image_ft.shape is (C, H, W)
effective_h, effective_w = (
self.image_resize_shape
if self.image_resize_shape is not None
else (image_ft.shape[1], image_ft.shape[2])
)
if self.image_crop_shape[0] > effective_h or self.image_crop_shape[1] > effective_w:
logging.warning(
"image_crop_shape %s doesn't fit within effective image shape (%s, %s) for '%s'; disabling cropping.",
self.image_crop_shape,
effective_h,
effective_w,
key,
)
self.image_crop_shape = None
break
if len(self.image_features) > 0:
first_key, first_ft = next(iter(self.image_features.items()))