mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 10:51:35 +00:00
Compare commits
9 Commits
feat/trim_
...
v0.4.4
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8fff0fde7c | ||
|
|
04de496547 | ||
|
|
baf9b50365 | ||
|
|
a0fdbf037a | ||
|
|
c085531b17 | ||
|
|
c7c6205332 | ||
|
|
4e54be1334 | ||
|
|
fde9d08281 | ||
|
|
46044fed75 |
@@ -57,7 +57,7 @@ class DatasetReplayConfig:
|
||||
repo_id: str
|
||||
# Episode to replay.
|
||||
episode: int
|
||||
# Root directory where the dataset will be stored (e.g. 'dataset/path').
|
||||
# Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id.
|
||||
root: str | Path | None = None
|
||||
# Limit the frames per second. By default, uses the policy fps.
|
||||
fps: int = 30
|
||||
|
||||
@@ -49,23 +49,18 @@ import torch
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
|
||||
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
Robot,
|
||||
RobotConfig,
|
||||
bi_so_follower,
|
||||
koch_follower,
|
||||
from lerobot.robots import (
|
||||
RobotConfig, # noqa: F401
|
||||
make_robot_from_config,
|
||||
omx_follower,
|
||||
so_follower,
|
||||
)
|
||||
from lerobot.transport import (
|
||||
services_pb2, # type: ignore
|
||||
services_pb2_grpc, # type: ignore
|
||||
)
|
||||
from lerobot.transport.utils import grpc_channel_options, send_bytes_in_chunks
|
||||
from lerobot.utils.import_utils import register_third_party_plugins
|
||||
|
||||
from .configs import RobotClientConfig
|
||||
from .constants import SUPPORTED_ROBOTS
|
||||
from .helpers import (
|
||||
Action,
|
||||
FPSTracker,
|
||||
@@ -485,8 +480,9 @@ class RobotClient:
|
||||
def async_client(cfg: RobotClientConfig):
|
||||
logging.info(pformat(asdict(cfg)))
|
||||
|
||||
if cfg.robot.type not in SUPPORTED_ROBOTS:
|
||||
raise ValueError(f"Robot {cfg.robot.type} not yet supported!")
|
||||
# TODO: Assert if checking robot support is still needed with the plugin system
|
||||
# if cfg.robot.type not in SUPPORTED_ROBOTS:
|
||||
# raise ValueError(f"Robot {cfg.robot.type} not yet supported!")
|
||||
|
||||
client = RobotClient(cfg)
|
||||
|
||||
@@ -512,4 +508,5 @@ def async_client(cfg: RobotClientConfig):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
register_third_party_plugins()
|
||||
async_client() # run the client
|
||||
|
||||
@@ -27,7 +27,7 @@ class DatasetConfig:
|
||||
# "dataset_index" into the returned item. The index mapping is made according to the order in which the
|
||||
# datasets are provided.
|
||||
repo_id: str
|
||||
# Root directory where the dataset will be stored (e.g. 'dataset/path').
|
||||
# Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id.
|
||||
root: str | None = None
|
||||
episodes: list[int] | None = None
|
||||
image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig)
|
||||
|
||||
@@ -47,7 +47,6 @@ from lerobot.datasets.utils import (
|
||||
DEFAULT_EPISODES_PATH,
|
||||
get_parquet_file_size_in_mb,
|
||||
load_episodes,
|
||||
load_info,
|
||||
update_chunk_file_indices,
|
||||
write_info,
|
||||
write_stats,
|
||||
@@ -1775,296 +1774,3 @@ def convert_image_to_video_dataset(
|
||||
|
||||
# Return new dataset
|
||||
return LeRobotDataset(repo_id=repo_id, root=output_dir)
|
||||
|
||||
|
||||
def trim_episodes_by_frames(
|
||||
dataset: LeRobotDataset,
|
||||
episode_frames_to_keep: dict[int, list[int]],
|
||||
output_dir: str | Path | None = None,
|
||||
repo_id: str | None = None,
|
||||
) -> LeRobotDataset:
|
||||
"""Trim multiple episodes to keep only specific frames.
|
||||
|
||||
This function creates a new dataset where the specified episodes contain only
|
||||
the frames at the given indices. All other episodes are copied as-is.
|
||||
|
||||
Args:
|
||||
dataset: The source LeRobotDataset.
|
||||
episode_frames_to_keep: Dict mapping episode indices to lists of global frame indices to keep.
|
||||
output_dir: Directory to save the new dataset. If None, uses default location.
|
||||
repo_id: Repository ID for the new dataset. If None, appends "_trimmed" to original.
|
||||
|
||||
Returns:
|
||||
A new LeRobotDataset with the trimmed episodes.
|
||||
"""
|
||||
if not episode_frames_to_keep:
|
||||
raise ValueError("No episodes to trim")
|
||||
|
||||
for ep_idx in episode_frames_to_keep:
|
||||
if ep_idx >= dataset.meta.total_episodes:
|
||||
raise ValueError(f"Episode {ep_idx} does not exist")
|
||||
if not episode_frames_to_keep[ep_idx]:
|
||||
raise ValueError(f"No frames to keep for episode {ep_idx}")
|
||||
|
||||
if repo_id is None:
|
||||
repo_id = f"{dataset.repo_id}_trimmed"
|
||||
output_dir = Path(output_dir) if output_dir is not None else HF_LEROBOT_HOME / repo_id
|
||||
|
||||
total_trimmed = sum(len(frames) for frames in episode_frames_to_keep.values())
|
||||
logging.info(f"Trimming {len(episode_frames_to_keep)} episodes, keeping {total_trimmed} frames total")
|
||||
|
||||
# Create new metadata
|
||||
new_meta = LeRobotDatasetMetadata.create(
|
||||
repo_id=repo_id,
|
||||
fps=dataset.meta.fps,
|
||||
features=dataset.meta.features,
|
||||
robot_type=dataset.meta.robot_type,
|
||||
root=output_dir,
|
||||
use_videos=len(dataset.meta.video_keys) > 0,
|
||||
)
|
||||
|
||||
# Build set of all frames to keep (for episodes being trimmed)
|
||||
# and compute new frame counts per episode
|
||||
all_keep_frames: set[int] = set()
|
||||
trimmed_frame_counts: dict[int, int] = {}
|
||||
for ep_idx, frames in episode_frames_to_keep.items():
|
||||
all_keep_frames.update(frames)
|
||||
trimmed_frame_counts[ep_idx] = len(frames)
|
||||
|
||||
# Copy and filter data
|
||||
_copy_and_reindex_data_with_multi_frame_filter(
|
||||
dataset, new_meta, episode_frames_to_keep, all_keep_frames
|
||||
)
|
||||
|
||||
# Handle videos if present
|
||||
if dataset.meta.video_keys:
|
||||
_copy_and_reindex_videos_with_multi_frame_filter(
|
||||
dataset, new_meta, episode_frames_to_keep
|
||||
)
|
||||
|
||||
# Copy episode metadata
|
||||
_copy_and_reindex_episodes_metadata_for_multi_trim(
|
||||
dataset, new_meta, trimmed_frame_counts
|
||||
)
|
||||
|
||||
logging.info(f"Created trimmed dataset with {new_meta.total_frames} frames at {output_dir}")
|
||||
|
||||
# Return the metadata instead of trying to load as LeRobotDataset
|
||||
# This avoids Hub validation issues when the repo doesn't exist yet
|
||||
return new_meta
|
||||
|
||||
|
||||
# Keep old function for backward compatibility
|
||||
def trim_episode_by_frames(
|
||||
dataset: LeRobotDataset,
|
||||
episode_index: int,
|
||||
keep_frame_indices: list[int],
|
||||
output_dir: str | Path | None = None,
|
||||
repo_id: str | None = None,
|
||||
) -> LeRobotDataset:
|
||||
"""Trim a single episode. Wrapper around trim_episodes_by_frames."""
|
||||
return trim_episodes_by_frames(
|
||||
dataset,
|
||||
episode_frames_to_keep={episode_index: keep_frame_indices},
|
||||
output_dir=output_dir,
|
||||
repo_id=repo_id,
|
||||
)
|
||||
|
||||
|
||||
def _copy_and_reindex_data_with_multi_frame_filter(
|
||||
src_dataset: LeRobotDataset,
|
||||
dst_meta: LeRobotDatasetMetadata,
|
||||
episode_frames_to_keep: dict[int, list[int]],
|
||||
all_keep_frames: set[int],
|
||||
) -> None:
|
||||
"""Copy data files with frame-level filtering for multiple episodes."""
|
||||
if src_dataset.meta.episodes is None:
|
||||
src_dataset.meta.episodes = load_episodes(src_dataset.meta.root)
|
||||
|
||||
# Copy tasks
|
||||
if dst_meta.tasks is None and src_dataset.meta.tasks is not None:
|
||||
# Tasks are stored with task string as index
|
||||
dst_meta.save_episode_tasks(list(src_dataset.meta.tasks.index))
|
||||
|
||||
# Get all parquet files
|
||||
data_dir = src_dataset.root / "data"
|
||||
parquet_files = sorted(data_dir.glob("chunk-*/file-*.parquet"))
|
||||
|
||||
trim_episode_set = set(episode_frames_to_keep.keys())
|
||||
global_index = 0
|
||||
|
||||
for parquet_path in tqdm(parquet_files, desc="Processing data files"):
|
||||
df = pd.read_parquet(parquet_path)
|
||||
|
||||
# Filter: keep all frames from non-trimmed episodes,
|
||||
# and only specified frames from trimmed episodes
|
||||
mask = (~df["episode_index"].isin(trim_episode_set)) | (df["index"].isin(all_keep_frames))
|
||||
df = df[mask].copy().reset_index(drop=True)
|
||||
|
||||
if len(df) == 0:
|
||||
continue
|
||||
|
||||
# Reindex
|
||||
df["index"] = range(global_index, global_index + len(df))
|
||||
|
||||
# Recalculate frame_index within each episode
|
||||
for ep_idx in df["episode_index"].unique():
|
||||
ep_mask = df["episode_index"] == ep_idx
|
||||
df.loc[ep_mask, "frame_index"] = range(ep_mask.sum())
|
||||
|
||||
# Recalculate timestamps based on frame_index and fps
|
||||
df["timestamp"] = df["frame_index"] / src_dataset.meta.fps
|
||||
|
||||
# Determine output path (keep same structure)
|
||||
rel_path = parquet_path.relative_to(src_dataset.root)
|
||||
dst_path = dst_meta.root / rel_path
|
||||
dst_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
_write_parquet(df, dst_path, dst_meta)
|
||||
global_index += len(df)
|
||||
|
||||
|
||||
def _copy_and_reindex_videos_with_multi_frame_filter(
|
||||
src_dataset: LeRobotDataset,
|
||||
dst_meta: LeRobotDatasetMetadata,
|
||||
episode_frames_to_keep: dict[int, list[int]],
|
||||
) -> None:
|
||||
"""Copy video files for trimmed dataset.
|
||||
|
||||
In v3.0 datasets, multiple episodes are concatenated into single video files.
|
||||
Each episode has from_timestamp/to_timestamp indicating its portion of the video.
|
||||
|
||||
For trimming, we copy the original video files as-is and update the metadata
|
||||
timestamps in _copy_and_reindex_episodes_metadata_for_multi_trim.
|
||||
"""
|
||||
for video_key in src_dataset.meta.video_keys:
|
||||
video_dir = src_dataset.root / "videos" / video_key
|
||||
dst_video_dir = dst_meta.root / "videos" / video_key
|
||||
|
||||
if not video_dir.exists():
|
||||
logging.warning(f"Video directory not found: {video_dir}")
|
||||
continue
|
||||
|
||||
# Copy all video files (they contain concatenated episodes)
|
||||
# The metadata timestamps will handle which portions to use
|
||||
copied_files = set()
|
||||
for chunk_dir in video_dir.glob("chunk-*"):
|
||||
dst_chunk_dir = dst_video_dir / chunk_dir.name
|
||||
dst_chunk_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
for video_file in chunk_dir.glob("*.mp4"):
|
||||
if video_file.name not in copied_files:
|
||||
dst_path = dst_chunk_dir / video_file.name
|
||||
if not dst_path.exists():
|
||||
shutil.copy(video_file, dst_path)
|
||||
copied_files.add(video_file.name)
|
||||
|
||||
logging.info(f"Copied {len(copied_files)} video files for {video_key}")
|
||||
|
||||
|
||||
def _trim_video_frames(
|
||||
src_path: Path,
|
||||
dst_path: Path,
|
||||
keep_frame_indices: list[int],
|
||||
fps: float,
|
||||
episode_start_idx: int,
|
||||
) -> None:
|
||||
"""Trim a video to keep only specific frames using ffmpeg."""
|
||||
import subprocess
|
||||
|
||||
# Convert global indices to local indices within the episode
|
||||
local_indices = sorted([idx - episode_start_idx for idx in keep_frame_indices])
|
||||
|
||||
if not local_indices:
|
||||
logging.warning(f"No frames to keep for video {src_path}")
|
||||
return
|
||||
|
||||
# Calculate start and end times
|
||||
start_frame = local_indices[0]
|
||||
end_frame = local_indices[-1]
|
||||
|
||||
start_time = start_frame / fps
|
||||
duration = (end_frame - start_frame + 1) / fps
|
||||
|
||||
# Use ffmpeg to trim
|
||||
cmd = [
|
||||
"ffmpeg", "-y",
|
||||
"-ss", str(start_time),
|
||||
"-i", str(src_path),
|
||||
"-t", str(duration),
|
||||
"-c", "copy", # Fast copy without re-encoding
|
||||
str(dst_path)
|
||||
]
|
||||
|
||||
try:
|
||||
subprocess.run(cmd, check=True, capture_output=True)
|
||||
except subprocess.CalledProcessError as e:
|
||||
logging.error(f"Failed to trim video: {e.stderr.decode()}")
|
||||
# Fallback: copy the whole video
|
||||
shutil.copy(src_path, dst_path)
|
||||
|
||||
|
||||
def _copy_and_reindex_episodes_metadata_for_multi_trim(
|
||||
src_dataset: LeRobotDataset,
|
||||
dst_meta: LeRobotDatasetMetadata,
|
||||
trimmed_frame_counts: dict[int, int],
|
||||
) -> None:
|
||||
"""Copy and update episode metadata for trimmed dataset."""
|
||||
if src_dataset.meta.episodes is None:
|
||||
src_dataset.meta.episodes = load_episodes(src_dataset.meta.root)
|
||||
|
||||
# Calculate new frame counts and indices
|
||||
episodes_data = []
|
||||
global_idx = 0
|
||||
|
||||
for old_ep_idx in range(src_dataset.meta.total_episodes):
|
||||
src_ep = src_dataset.meta.episodes[old_ep_idx]
|
||||
|
||||
if old_ep_idx in trimmed_frame_counts:
|
||||
ep_length = trimmed_frame_counts[old_ep_idx]
|
||||
else:
|
||||
ep_length = src_ep["length"]
|
||||
|
||||
ep_data = {
|
||||
"episode_index": old_ep_idx,
|
||||
"tasks": src_ep["tasks"],
|
||||
"length": ep_length,
|
||||
"data/chunk_index": src_ep["data/chunk_index"],
|
||||
"data/file_index": src_ep["data/file_index"],
|
||||
"dataset_from_index": global_idx,
|
||||
"dataset_to_index": global_idx + ep_length,
|
||||
}
|
||||
|
||||
# Copy video metadata - preserve timestamps for concatenated videos
|
||||
for video_key in src_dataset.meta.video_keys:
|
||||
ep_data[f"videos/{video_key}/chunk_index"] = src_ep[f"videos/{video_key}/chunk_index"]
|
||||
ep_data[f"videos/{video_key}/file_index"] = src_ep[f"videos/{video_key}/file_index"]
|
||||
|
||||
# Keep original from_timestamp (start position in concatenated video)
|
||||
orig_from_ts = src_ep[f"videos/{video_key}/from_timestamp"]
|
||||
ep_data[f"videos/{video_key}/from_timestamp"] = orig_from_ts
|
||||
|
||||
# For trimmed episodes, update to_timestamp based on new length
|
||||
# For non-trimmed episodes, keep original to_timestamp
|
||||
if old_ep_idx in trimmed_frame_counts:
|
||||
ep_data[f"videos/{video_key}/to_timestamp"] = orig_from_ts + (ep_length / src_dataset.meta.fps)
|
||||
else:
|
||||
ep_data[f"videos/{video_key}/to_timestamp"] = src_ep[f"videos/{video_key}/to_timestamp"]
|
||||
|
||||
ep_data["meta/episodes/chunk_index"] = 0
|
||||
ep_data["meta/episodes/file_index"] = 0
|
||||
|
||||
episodes_data.append(ep_data)
|
||||
global_idx += ep_length
|
||||
|
||||
# Save episodes metadata
|
||||
df = pd.DataFrame(episodes_data)
|
||||
episodes_path = dst_meta.root / DEFAULT_EPISODES_PATH.format(chunk_index=0, file_index=0)
|
||||
episodes_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
df.to_parquet(episodes_path)
|
||||
|
||||
# Update info.json
|
||||
info = load_info(src_dataset.root)
|
||||
info["total_episodes"] = len(episodes_data)
|
||||
info["total_frames"] = global_idx
|
||||
write_info(info, dst_meta.root)
|
||||
|
||||
@@ -664,11 +664,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
for the README).
|
||||
|
||||
Args:
|
||||
repo_id (str): This is the repo id that will be used to fetch the dataset. Locally, the dataset
|
||||
will be stored under root/repo_id.
|
||||
root (Path | None, optional): Local directory to use for downloading/writing files. You can also
|
||||
set the HF_LEROBOT_HOME environment variable to point to a different location. Defaults to
|
||||
'~/.cache/huggingface/lerobot'.
|
||||
repo_id (str): This is the repo id that will be used to fetch the dataset.
|
||||
root (Path | None, optional): Local directory where the dataset will be downloaded and
|
||||
stored. If set, all dataset files will be stored directly under this path. If not set, the
|
||||
dataset files will be stored under $HF_LEROBOT_HOME/repo_id (configurable via the
|
||||
HF_LEROBOT_HOME environment variable).
|
||||
episodes (list[int] | None, optional): If specified, this will only load episodes specified by
|
||||
their episode_index in this list. Defaults to None.
|
||||
image_transforms (Callable | None, optional): You can pass standard v2 image transforms from
|
||||
@@ -1771,11 +1771,12 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
)
|
||||
for repo_id, ds in zip(self.repo_ids, self._datasets, strict=True):
|
||||
extra_keys = set(ds.features).difference(intersection_features)
|
||||
logging.warning(
|
||||
f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
|
||||
"other datasets."
|
||||
)
|
||||
self.disabled_features.update(extra_keys)
|
||||
if extra_keys:
|
||||
logging.warning(
|
||||
f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
|
||||
"other datasets."
|
||||
)
|
||||
self.disabled_features.update(extra_keys)
|
||||
|
||||
self.image_transforms = image_transforms
|
||||
self.delta_timestamps = delta_timestamps
|
||||
|
||||
@@ -55,10 +55,16 @@ class DiffusionConfig(PreTrainedConfig):
|
||||
normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to
|
||||
a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX)
|
||||
vision_backbone: Name of the torchvision resnet backbone to use for encoding images.
|
||||
crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit
|
||||
within the image size. If None, no cropping is done.
|
||||
crop_is_random: Whether the crop should be random at training time (it's always a center crop in eval
|
||||
mode).
|
||||
resize_shape: (H, W) shape to resize images to as a preprocessing step for the vision
|
||||
backbone. If None, no resizing is done and the original image resolution is used.
|
||||
crop_ratio: Ratio in (0, 1] used to derive the crop size from resize_shape
|
||||
(crop_h = int(resize_shape[0] * crop_ratio), likewise for width).
|
||||
Set to 1.0 to disable cropping. Only takes effect when resize_shape is not None.
|
||||
crop_shape: (H, W) shape to crop images to. When resize_shape is set and crop_ratio < 1.0,
|
||||
this is computed automatically. Can also be set directly for legacy configs that use
|
||||
crop-only (without resize). If None and no derivation applies, no cropping is done.
|
||||
crop_is_random: Whether the crop should be random at training time (it's always a center
|
||||
crop in eval mode).
|
||||
pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone.
|
||||
`None` means no pretrained weights.
|
||||
use_group_norm: Whether to replace batch normalization with group normalization in the backbone.
|
||||
@@ -114,7 +120,9 @@ class DiffusionConfig(PreTrainedConfig):
|
||||
# Architecture / modeling.
|
||||
# Vision backbone.
|
||||
vision_backbone: str = "resnet18"
|
||||
crop_shape: tuple[int, int] | None = (84, 84)
|
||||
resize_shape: tuple[int, int] | None = None
|
||||
crop_ratio: float = 1.0
|
||||
crop_shape: tuple[int, int] | None = None
|
||||
crop_is_random: bool = True
|
||||
pretrained_backbone_weights: str | None = None
|
||||
use_group_norm: bool = True
|
||||
@@ -175,6 +183,25 @@ class DiffusionConfig(PreTrainedConfig):
|
||||
f"Got {self.noise_scheduler_type}."
|
||||
)
|
||||
|
||||
if self.resize_shape is not None and (
|
||||
len(self.resize_shape) != 2 or any(d <= 0 for d in self.resize_shape)
|
||||
):
|
||||
raise ValueError(f"`resize_shape` must be a pair of positive integers. Got {self.resize_shape}.")
|
||||
if not (0 < self.crop_ratio <= 1.0):
|
||||
raise ValueError(f"`crop_ratio` must be in (0, 1]. Got {self.crop_ratio}.")
|
||||
|
||||
if self.resize_shape is not None:
|
||||
if self.crop_ratio < 1.0:
|
||||
self.crop_shape = (
|
||||
int(self.resize_shape[0] * self.crop_ratio),
|
||||
int(self.resize_shape[1] * self.crop_ratio),
|
||||
)
|
||||
else:
|
||||
# Explicitly disable cropping for resize+ratio path when crop_ratio == 1.0.
|
||||
self.crop_shape = None
|
||||
if self.crop_shape is not None and (self.crop_shape[0] <= 0 or self.crop_shape[1] <= 0):
|
||||
raise ValueError(f"`crop_shape` must have positive dimensions. Got {self.crop_shape}.")
|
||||
|
||||
# Check that the horizon size and U-Net downsampling is compatible.
|
||||
# U-Net downsamples by 2 with each stage.
|
||||
downsampling_factor = 2 ** len(self.down_dims)
|
||||
@@ -202,13 +229,12 @@ class DiffusionConfig(PreTrainedConfig):
|
||||
if len(self.image_features) == 0 and self.env_state_feature is None:
|
||||
raise ValueError("You must provide at least one image or the environment state among the inputs.")
|
||||
|
||||
if self.crop_shape is not None:
|
||||
if self.resize_shape is None and self.crop_shape is not None:
|
||||
for key, image_ft in self.image_features.items():
|
||||
if self.crop_shape[0] > image_ft.shape[1] or self.crop_shape[1] > image_ft.shape[2]:
|
||||
raise ValueError(
|
||||
f"`crop_shape` should fit within the images shapes. Got {self.crop_shape} "
|
||||
f"for `crop_shape` and {image_ft.shape} for "
|
||||
f"`{key}`."
|
||||
f"`crop_shape` should fit within the image shapes. Got {self.crop_shape} "
|
||||
f"for `crop_shape` and {image_ft.shape} for `{key}`."
|
||||
)
|
||||
|
||||
# Check that all input images have the same shape.
|
||||
|
||||
@@ -454,12 +454,18 @@ class DiffusionRgbEncoder(nn.Module):
|
||||
def __init__(self, config: DiffusionConfig):
|
||||
super().__init__()
|
||||
# Set up optional preprocessing.
|
||||
if config.crop_shape is not None:
|
||||
if config.resize_shape is not None:
|
||||
self.resize = torchvision.transforms.Resize(config.resize_shape)
|
||||
else:
|
||||
self.resize = None
|
||||
|
||||
crop_shape = config.crop_shape
|
||||
if crop_shape is not None:
|
||||
self.do_crop = True
|
||||
# Always use center crop for eval
|
||||
self.center_crop = torchvision.transforms.CenterCrop(config.crop_shape)
|
||||
self.center_crop = torchvision.transforms.CenterCrop(crop_shape)
|
||||
if config.crop_is_random:
|
||||
self.maybe_random_crop = torchvision.transforms.RandomCrop(config.crop_shape)
|
||||
self.maybe_random_crop = torchvision.transforms.RandomCrop(crop_shape)
|
||||
else:
|
||||
self.maybe_random_crop = self.center_crop
|
||||
else:
|
||||
@@ -485,13 +491,16 @@ class DiffusionRgbEncoder(nn.Module):
|
||||
|
||||
# Set up pooling and final layers.
|
||||
# Use a dry run to get the feature map shape.
|
||||
# The dummy input should take the number of image channels from `config.image_features` and it should
|
||||
# use the height and width from `config.crop_shape` if it is provided, otherwise it should use the
|
||||
# height and width from `config.image_features`.
|
||||
# The dummy shape mirrors the runtime preprocessing order: resize -> crop.
|
||||
|
||||
# Note: we have a check in the config class to make sure all images have the same shape.
|
||||
images_shape = next(iter(config.image_features.values())).shape
|
||||
dummy_shape_h_w = config.crop_shape if config.crop_shape is not None else images_shape[1:]
|
||||
if config.crop_shape is not None:
|
||||
dummy_shape_h_w = config.crop_shape
|
||||
elif config.resize_shape is not None:
|
||||
dummy_shape_h_w = config.resize_shape
|
||||
else:
|
||||
dummy_shape_h_w = images_shape[1:]
|
||||
dummy_shape = (1, images_shape[0], *dummy_shape_h_w)
|
||||
feature_map_shape = get_output_shape(self.backbone, dummy_shape)[1:]
|
||||
|
||||
@@ -507,7 +516,10 @@ class DiffusionRgbEncoder(nn.Module):
|
||||
Returns:
|
||||
(B, D) image feature.
|
||||
"""
|
||||
# Preprocess: maybe crop (if it was set up in the __init__).
|
||||
# Preprocess: resize if configured, then crop if configured.
|
||||
|
||||
if self.resize is not None:
|
||||
x = self.resize(x)
|
||||
if self.do_crop:
|
||||
if self.training: # noqa: SIM108
|
||||
x = self.maybe_random_crop(x)
|
||||
|
||||
@@ -106,6 +106,9 @@ class SmolVLAConfig(PreTrainedConfig):
|
||||
# Real-Time Chunking (RTC) configuration
|
||||
rtc_config: RTCConfig | None = None
|
||||
|
||||
compile_model: bool = False # Whether to use torch.compile for model optimization
|
||||
compile_mode: str = "max-autotune" # Torch compile mode
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
|
||||
@@ -593,6 +593,12 @@ class VLAFlowMatching(nn.Module):
|
||||
self.prefix_length = self.config.prefix_length
|
||||
self.rtc_processor = rtc_processor
|
||||
|
||||
# Compile model if requested
|
||||
if config.compile_model:
|
||||
torch.set_float32_matmul_precision("high")
|
||||
self.sample_actions = torch.compile(self.sample_actions, mode=config.compile_mode)
|
||||
self.forward = torch.compile(self.forward, mode=config.compile_mode)
|
||||
|
||||
def _rtc_enabled(self):
|
||||
return self.config.rtc_config is not None and self.config.rtc_config.enabled
|
||||
|
||||
|
||||
@@ -77,7 +77,6 @@ class SmolVLMWithExpertModel(nn.Module):
|
||||
print(f"Loading {model_id} weights ...")
|
||||
self.vlm = AutoModelForImageTextToText.from_pretrained(
|
||||
model_id,
|
||||
device_map=device,
|
||||
torch_dtype="bfloat16",
|
||||
low_cpu_mem_usage=True,
|
||||
)
|
||||
|
||||
@@ -56,6 +56,7 @@ from lerobot.teleoperators import ( # noqa: F401
|
||||
make_teleoperator_from_config,
|
||||
omx_leader,
|
||||
openarm_leader,
|
||||
openarm_mini,
|
||||
so_leader,
|
||||
unitree_g1,
|
||||
)
|
||||
|
||||
@@ -104,28 +104,6 @@ Convert image dataset to video format and push to hub:
|
||||
--operation.type convert_image_to_video \
|
||||
--push_to_hub true
|
||||
|
||||
Trim single episode to keep only frames within timestamp range:
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--new_repo_id lerobot/pusht_trimmed \
|
||||
--operation.type trim_episode \
|
||||
--operation.episode_index 0 \
|
||||
--operation.start_timestamp 10.0 \
|
||||
--operation.end_timestamp 30.0
|
||||
|
||||
Trim multiple episodes at once (use null for no limit):
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type trim_episode \
|
||||
--operation.episode_trims '{"0": [10.0, 30.0], "2": [5.0, null], "3": [null, 20.0]}'
|
||||
|
||||
Trim and re-upload to same repo (overwrites original):
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type trim_episode \
|
||||
--operation.episode_index 0 \
|
||||
--operation.start_timestamp 10.0 \
|
||||
--push_to_hub true
|
||||
Show dataset information:
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
@@ -226,32 +204,9 @@ class InfoConfig(OperationConfig):
|
||||
show_features: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrimEpisodeConfig:
|
||||
"""Trim episodes to keep only frames within timestamp ranges.
|
||||
|
||||
Supports multiple episodes via episode_trims dict:
|
||||
--operation.episode_trims '{"0": [10.0, 30.0], "2": [5.0, 20.0]}'
|
||||
|
||||
Or single episode via legacy parameters:
|
||||
--operation.episode_index 0 --operation.start_timestamp 10.0 --operation.end_timestamp 30.0
|
||||
"""
|
||||
type: str = "trim_episode"
|
||||
# Multi-episode support: dict mapping episode_index -> [start_timestamp, end_timestamp]
|
||||
# Use null for no limit, e.g. {"0": [10.0, null], "2": [null, 30.0]}
|
||||
episode_trims: dict[str, list[float | None]] | None = None
|
||||
# Legacy single-episode parameters (used if episode_trims is None)
|
||||
episode_index: int | None = None
|
||||
start_timestamp: float | None = None # Keep frames from this timestamp (inclusive)
|
||||
end_timestamp: float | None = None # Keep frames until this timestamp (inclusive)
|
||||
|
||||
|
||||
@dataclass
|
||||
class EditDatasetConfig:
|
||||
repo_id: str
|
||||
operation: (
|
||||
DeleteEpisodesConfig | SplitConfig | MergeConfig | RemoveFeatureConfig | ConvertImageToVideoConfig | TrimEpisodeConfig
|
||||
)
|
||||
operation: OperationConfig
|
||||
root: str | None = None
|
||||
new_repo_id: str | None = None
|
||||
@@ -396,92 +351,6 @@ def handle_remove_feature(cfg: EditDatasetConfig) -> None:
|
||||
LeRobotDataset(output_repo_id, root=output_dir).push_to_hub()
|
||||
|
||||
|
||||
def handle_trim_episode(cfg: EditDatasetConfig) -> None:
|
||||
"""Trim episodes to keep only frames within timestamp ranges."""
|
||||
if not isinstance(cfg.operation, TrimEpisodeConfig):
|
||||
raise ValueError("Operation config must be TrimEpisodeConfig")
|
||||
|
||||
# Parse episode trims - support both multi-episode dict and legacy single episode
|
||||
episode_trims: dict[int, tuple[float | None, float | None]] = {}
|
||||
|
||||
if cfg.operation.episode_trims is not None:
|
||||
# Multi-episode mode
|
||||
for ep_str, ts_range in cfg.operation.episode_trims.items():
|
||||
ep_idx = int(ep_str)
|
||||
start_ts = ts_range[0] if len(ts_range) > 0 else None
|
||||
end_ts = ts_range[1] if len(ts_range) > 1 else None
|
||||
episode_trims[ep_idx] = (start_ts, end_ts)
|
||||
elif cfg.operation.episode_index is not None:
|
||||
# Legacy single-episode mode
|
||||
if cfg.operation.start_timestamp is None and cfg.operation.end_timestamp is None:
|
||||
raise ValueError("At least one of start_timestamp or end_timestamp must be specified")
|
||||
episode_trims[cfg.operation.episode_index] = (
|
||||
cfg.operation.start_timestamp,
|
||||
cfg.operation.end_timestamp,
|
||||
)
|
||||
else:
|
||||
raise ValueError("Either episode_trims or episode_index must be specified")
|
||||
|
||||
dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
|
||||
output_repo_id, output_dir = get_output_path(
|
||||
cfg.repo_id, cfg.new_repo_id, Path(cfg.root) if cfg.root else None
|
||||
)
|
||||
|
||||
if cfg.new_repo_id is None:
|
||||
dataset.root = Path(str(dataset.root) + "_old")
|
||||
|
||||
logging.info(f"Trimming {len(episode_trims)} episode(s) from {cfg.repo_id}")
|
||||
|
||||
# Get episode boundaries and find frames to keep for each episode
|
||||
episodes_info = dataset.meta.episodes
|
||||
all_frames_to_keep: dict[int, list[int]] = {}
|
||||
|
||||
for ep_idx, (start_ts, end_ts) in episode_trims.items():
|
||||
if ep_idx >= len(episodes_info["episode_index"]):
|
||||
raise ValueError(f"Episode {ep_idx} does not exist (dataset has {len(episodes_info['episode_index'])} episodes)")
|
||||
|
||||
from_frame = episodes_info["dataset_from_index"][ep_idx]
|
||||
to_frame = episodes_info["dataset_to_index"][ep_idx]
|
||||
|
||||
logging.info(f"Episode {ep_idx}: trimming to [{start_ts}, {end_ts}]")
|
||||
logging.info(f" Original frames: {from_frame} to {to_frame} ({to_frame - from_frame} frames)")
|
||||
|
||||
# Find frames within timestamp range
|
||||
frames_to_keep = []
|
||||
for frame_idx in range(from_frame, to_frame):
|
||||
frame = dataset.hf_dataset[frame_idx]
|
||||
ts = frame["timestamp"]
|
||||
|
||||
in_range = True
|
||||
if start_ts is not None and ts < start_ts:
|
||||
in_range = False
|
||||
if end_ts is not None and ts > end_ts:
|
||||
in_range = False
|
||||
|
||||
if in_range:
|
||||
frames_to_keep.append(frame_idx)
|
||||
|
||||
if not frames_to_keep:
|
||||
raise ValueError(f"Episode {ep_idx}: No frames found in timestamp range [{start_ts}, {end_ts}]")
|
||||
|
||||
logging.info(f" Keeping {len(frames_to_keep)} frames (indices {frames_to_keep[0]} to {frames_to_keep[-1]})")
|
||||
all_frames_to_keep[ep_idx] = frames_to_keep
|
||||
|
||||
from lerobot.datasets.dataset_tools import trim_episodes_by_frames
|
||||
|
||||
new_dataset = trim_episodes_by_frames(
|
||||
dataset,
|
||||
episode_frames_to_keep=all_frames_to_keep,
|
||||
output_dir=output_dir,
|
||||
repo_id=output_repo_id,
|
||||
)
|
||||
|
||||
logging.info(f"Dataset saved to {output_dir}")
|
||||
logging.info(f"Episodes: {new_dataset.meta.total_episodes}, Frames: {new_dataset.meta.total_frames}")
|
||||
|
||||
if cfg.push_to_hub:
|
||||
logging.info(f"Pushing to hub as {output_repo_id}")
|
||||
LeRobotDataset(output_repo_id, root=output_dir).push_to_hub()
|
||||
def handle_modify_tasks(cfg: EditDatasetConfig) -> None:
|
||||
if not isinstance(cfg.operation, ModifyTasksConfig):
|
||||
raise ValueError("Operation config must be ModifyTasksConfig")
|
||||
@@ -646,8 +515,6 @@ def edit_dataset(cfg: EditDatasetConfig) -> None:
|
||||
handle_modify_tasks(cfg)
|
||||
elif operation_type == "convert_image_to_video":
|
||||
handle_convert_image_to_video(cfg)
|
||||
elif operation_type == "trim_episode":
|
||||
handle_trim_episode(cfg)
|
||||
elif operation_type == "info":
|
||||
handle_info(cfg)
|
||||
else:
|
||||
|
||||
@@ -61,6 +61,7 @@ from lerobot.teleoperators import ( # noqa: F401
|
||||
make_teleoperator_from_config,
|
||||
omx_leader,
|
||||
openarm_leader,
|
||||
openarm_mini,
|
||||
so_leader,
|
||||
)
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
|
||||
@@ -125,6 +125,7 @@ from lerobot.teleoperators import ( # noqa: F401
|
||||
make_teleoperator_from_config,
|
||||
omx_leader,
|
||||
openarm_leader,
|
||||
openarm_mini,
|
||||
reachy2_teleoperator,
|
||||
so_leader,
|
||||
unitree_g1,
|
||||
@@ -154,7 +155,7 @@ class DatasetRecordConfig:
|
||||
repo_id: str
|
||||
# A short but accurate description of the task performed during the recording (e.g. "Pick the Lego block and drop it in the box on the right.")
|
||||
single_task: str
|
||||
# Root directory where the dataset will be stored (e.g. 'dataset/path').
|
||||
# Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id.
|
||||
root: str | Path | None = None
|
||||
# Limit the frames per second.
|
||||
fps: int = 30
|
||||
@@ -333,6 +334,7 @@ def record_loop(
|
||||
preprocessor.reset()
|
||||
postprocessor.reset()
|
||||
|
||||
no_action_count = 0
|
||||
timestamp = 0
|
||||
start_episode_t = time.perf_counter()
|
||||
while timestamp < control_time_s:
|
||||
@@ -380,11 +382,13 @@ def record_loop(
|
||||
act = {**arm_action, **base_action} if len(base_action) > 0 else arm_action
|
||||
act_processed_teleop = teleop_action_processor((act, obs))
|
||||
else:
|
||||
logging.info(
|
||||
"No policy or teleoperator provided, skipping action generation."
|
||||
"This is likely to happen when resetting the environment without a teleop device."
|
||||
"The robot won't be at its rest position at the start of the next episode."
|
||||
)
|
||||
no_action_count += 1
|
||||
if no_action_count == 1 or no_action_count % 10 == 0:
|
||||
logging.warning(
|
||||
"No policy or teleoperator provided, skipping action generation. "
|
||||
"This is likely to happen when resetting the environment without a teleop device. "
|
||||
"The robot won't be at its rest position at the start of the next episode."
|
||||
)
|
||||
continue
|
||||
|
||||
# Applies a pipeline to the action, default is IdentityProcessor
|
||||
|
||||
@@ -80,7 +80,7 @@ class DatasetReplayConfig:
|
||||
repo_id: str
|
||||
# Episode to replay.
|
||||
episode: int
|
||||
# Root directory where the dataset will be stored (e.g. 'dataset/path').
|
||||
# Root directory where the dataset will be stored (e.g. 'dataset/path'). If None, defaults to $HF_LEROBOT_HOME/repo_id.
|
||||
root: str | Path | None = None
|
||||
# Limit the frames per second. By default, uses the policy fps.
|
||||
fps: int = 30
|
||||
|
||||
@@ -94,6 +94,7 @@ from lerobot.teleoperators import ( # noqa: F401
|
||||
make_teleoperator_from_config,
|
||||
omx_leader,
|
||||
openarm_leader,
|
||||
openarm_mini,
|
||||
reachy2_teleoperator,
|
||||
so_leader,
|
||||
unitree_g1,
|
||||
|
||||
@@ -380,10 +380,10 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
"dataloading_s": AverageMeter("data_s", ":.3f"),
|
||||
}
|
||||
|
||||
# Use effective batch size for proper epoch calculation in distributed training
|
||||
# Keep global batch size for logging; MetricsTracker handles world size internally.
|
||||
effective_batch_size = cfg.batch_size * accelerator.num_processes
|
||||
train_tracker = MetricsTracker(
|
||||
effective_batch_size,
|
||||
cfg.batch_size,
|
||||
dataset.num_frames,
|
||||
dataset.num_episodes,
|
||||
train_metrics,
|
||||
|
||||
@@ -104,9 +104,10 @@ class MetricsTracker:
|
||||
self.metrics = metrics
|
||||
|
||||
self.steps = initial_step
|
||||
world_size = accelerator.num_processes if accelerator else 1
|
||||
# A sample is an (observation,action) pair, where observation and action
|
||||
# can be on multiple timestamps. In a batch, we have `batch_size` number of samples.
|
||||
self.samples = self.steps * self._batch_size
|
||||
self.samples = self.steps * self._batch_size * world_size
|
||||
self.episodes = self.samples / self._avg_samples_per_ep
|
||||
self.epochs = self.samples / self._num_frames
|
||||
self.accelerator = accelerator
|
||||
@@ -132,7 +133,8 @@ class MetricsTracker:
|
||||
Updates metrics that depend on 'step' for one step.
|
||||
"""
|
||||
self.steps += 1
|
||||
self.samples += self._batch_size * (self.accelerator.num_processes if self.accelerator else 1)
|
||||
world_size = self.accelerator.num_processes if self.accelerator else 1
|
||||
self.samples += self._batch_size * world_size
|
||||
self.episodes = self.samples / self._avg_samples_per_ep
|
||||
self.epochs = self.samples / self._num_frames
|
||||
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -24,6 +24,11 @@ def mock_metrics():
|
||||
return {"loss": AverageMeter("loss", ":.3f"), "accuracy": AverageMeter("accuracy", ":.2f")}
|
||||
|
||||
|
||||
class MockAccelerator:
|
||||
def __init__(self, num_processes: int):
|
||||
self.num_processes = num_processes
|
||||
|
||||
|
||||
def test_average_meter_initialization():
|
||||
meter = AverageMeter("loss", ":.2f")
|
||||
assert meter.name == "loss"
|
||||
@@ -82,6 +87,37 @@ def test_metrics_tracker_step(mock_metrics):
|
||||
assert tracker.epochs == tracker.samples / 1000
|
||||
|
||||
|
||||
def test_metrics_tracker_initialization_with_accelerator(mock_metrics):
|
||||
tracker = MetricsTracker(
|
||||
batch_size=32,
|
||||
num_frames=1000,
|
||||
num_episodes=50,
|
||||
metrics=mock_metrics,
|
||||
initial_step=10,
|
||||
accelerator=MockAccelerator(num_processes=2),
|
||||
)
|
||||
assert tracker.steps == 10
|
||||
assert tracker.samples == 10 * 32 * 2
|
||||
assert tracker.episodes == tracker.samples / (1000 / 50)
|
||||
assert tracker.epochs == tracker.samples / 1000
|
||||
|
||||
|
||||
def test_metrics_tracker_step_with_accelerator(mock_metrics):
|
||||
tracker = MetricsTracker(
|
||||
batch_size=32,
|
||||
num_frames=1000,
|
||||
num_episodes=50,
|
||||
metrics=mock_metrics,
|
||||
initial_step=5,
|
||||
accelerator=MockAccelerator(num_processes=2),
|
||||
)
|
||||
tracker.step()
|
||||
assert tracker.steps == 6
|
||||
assert tracker.samples == (5 * 32 * 2) + (32 * 2)
|
||||
assert tracker.episodes == tracker.samples / (1000 / 50)
|
||||
assert tracker.epochs == tracker.samples / 1000
|
||||
|
||||
|
||||
def test_metrics_tracker_getattr(mock_metrics):
|
||||
tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics)
|
||||
assert tracker.loss == mock_metrics["loss"]
|
||||
|
||||
Reference in New Issue
Block a user