mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 02:41:24 +00:00
remove cropping of images smaller than the crop size
This commit is contained in:
@@ -157,7 +157,7 @@ available_datasets = sorted(
|
||||
)
|
||||
|
||||
# lists all available policies from `lerobot/policies`
|
||||
available_policies = ["act", "multi_task_dit", "diffusion", "tdmpc", "vqbet"]
|
||||
available_policies = ["act", "diffusion", "tdmpc", "vqbet"]
|
||||
|
||||
# lists all available robots from `lerobot/robots`
|
||||
available_robots = [
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
@@ -144,9 +145,12 @@ class MultiTaskDiTConfig(PreTrainedConfig):
|
||||
or self.image_crop_shape[1] > self.image_resize_shape[1]
|
||||
)
|
||||
):
|
||||
raise ValueError(
|
||||
f"image_crop_shape {self.image_crop_shape} must be <= image_resize_shape {self.image_resize_shape}"
|
||||
logging.warning(
|
||||
"image_crop_shape %s must be <= image_resize_shape %s; disabling cropping.",
|
||||
self.image_crop_shape,
|
||||
self.image_resize_shape,
|
||||
)
|
||||
self.image_crop_shape = None
|
||||
|
||||
# Text encoder validation
|
||||
if "clip" not in self.text_encoder_name.lower():
|
||||
@@ -202,16 +206,26 @@ class MultiTaskDiTConfig(PreTrainedConfig):
|
||||
|
||||
def validate_features(self) -> None:
|
||||
"""Validate that required input features are present and properly configured."""
|
||||
# If the configured crop doesn't fit, disable cropping instead of erroring.
|
||||
# Note: if image_resize_shape is set, cropping is applied *after* resizing.
|
||||
if self.image_crop_shape is not None:
|
||||
for key, image_ft in self.image_features.items():
|
||||
if (
|
||||
self.image_crop_shape[0] > image_ft.shape[1]
|
||||
or self.image_crop_shape[1] > image_ft.shape[2]
|
||||
):
|
||||
raise ValueError(
|
||||
f"image_crop_shape {self.image_crop_shape} doesn't fit within image shape {image_ft.shape} "
|
||||
f"for '{key}'"
|
||||
# image_ft.shape is (C, H, W)
|
||||
effective_h, effective_w = (
|
||||
self.image_resize_shape
|
||||
if self.image_resize_shape is not None
|
||||
else (image_ft.shape[1], image_ft.shape[2])
|
||||
)
|
||||
if self.image_crop_shape[0] > effective_h or self.image_crop_shape[1] > effective_w:
|
||||
logging.warning(
|
||||
"image_crop_shape %s doesn't fit within effective image shape (%s, %s) for '%s'; disabling cropping.",
|
||||
self.image_crop_shape,
|
||||
effective_h,
|
||||
effective_w,
|
||||
key,
|
||||
)
|
||||
self.image_crop_shape = None
|
||||
break
|
||||
|
||||
if len(self.image_features) > 0:
|
||||
first_key, first_ft = next(iter(self.image_features.items()))
|
||||
|
||||
Reference in New Issue
Block a user