From 012bde51cb962eb4070aedca5c21af50f00fd1e9 Mon Sep 17 00:00:00 2001
From: Martino Russi
Date: Wed, 25 Feb 2026 23:09:33 +0100
Subject: [PATCH] trim episodes

---
 src/lerobot/datasets/dataset_tools.py       | 348 ++++++++++++++++++++
 src/lerobot/scripts/lerobot_edit_dataset.py | 135 ++++++++
 2 files changed, 483 insertions(+)

diff --git a/src/lerobot/datasets/dataset_tools.py b/src/lerobot/datasets/dataset_tools.py
index b62d7d959..249aa00bb 100644
--- a/src/lerobot/datasets/dataset_tools.py
+++ b/src/lerobot/datasets/dataset_tools.py
@@ -47,6 +47,7 @@ from lerobot.datasets.utils import (
     DEFAULT_EPISODES_PATH,
     get_parquet_file_size_in_mb,
     load_episodes,
+    load_info,
     update_chunk_file_indices,
     write_info,
     write_stats,
@@ -1774,3 +1775,350 @@ def convert_image_to_video_dataset(
 
     # Return new dataset
     return LeRobotDataset(repo_id=repo_id, root=output_dir)
+
+
+def trim_episodes_by_frames(
+    dataset: LeRobotDataset,
+    episode_frames_to_keep: dict[int, list[int]],
+    output_dir: str | Path | None = None,
+    repo_id: str | None = None,
+) -> LeRobotDatasetMetadata:
+    """Trim multiple episodes to keep only specific frames.
+
+    This function creates a new dataset where the specified episodes contain only
+    the frames at the given indices. All other episodes are copied as-is.
+
+    Args:
+        dataset: The source LeRobotDataset.
+        episode_frames_to_keep: Dict mapping episode indices to lists of global frame
+            indices to keep. Every index must lie inside its episode; duplicates are ignored.
+        output_dir: Directory to save the new dataset. If None, uses default location.
+        repo_id: Repository ID for the new dataset. If None, appends "_trimmed" to original.
+
+    Returns:
+        The LeRobotDatasetMetadata of the trimmed dataset. The metadata (rather than a
+        LeRobotDataset) is returned to avoid Hub validation when the repo doesn't exist yet.
+
+    Raises:
+        ValueError: If no episodes are given, an episode does not exist, an episode has no
+            frames to keep, a frame lies outside its episode, or the dataset has videos and
+            the kept frames of an episode are not contiguous.
+    """
+    if not episode_frames_to_keep:
+        raise ValueError("No episodes to trim")
+
+    if dataset.meta.episodes is None:
+        dataset.meta.episodes = load_episodes(dataset.meta.root)
+
+    has_videos = len(dataset.meta.video_keys) > 0
+
+    # Validate everything before any file is written, and record per trimmed episode
+    # its new length and how many leading frames were dropped.
+    all_keep_frames: set[int] = set()
+    trimmed_frame_counts: dict[int, int] = {}
+    trimmed_start_offsets: dict[int, int] = {}
+    for ep_idx, frames in episode_frames_to_keep.items():
+        if not 0 <= ep_idx < dataset.meta.total_episodes:
+            raise ValueError(f"Episode {ep_idx} does not exist")
+        if not frames:
+            raise ValueError(f"No frames to keep for episode {ep_idx}")
+
+        src_ep = dataset.meta.episodes[ep_idx]
+        ep_start = int(src_ep["dataset_from_index"])
+        ep_end = int(src_ep["dataset_to_index"])
+        kept = sorted(set(frames))
+        if kept[0] < ep_start or kept[-1] >= ep_end:
+            raise ValueError(
+                f"Frames to keep for episode {ep_idx} must be in [{ep_start}, {ep_end})"
+            )
+        # Videos are reused as-is and addressed through a [from, to] timestamp window,
+        # which can only describe one contiguous run of frames.
+        if has_videos and kept[-1] - kept[0] + 1 != len(kept):
+            raise ValueError(
+                f"Frames to keep for episode {ep_idx} must be contiguous for video datasets"
+            )
+
+        all_keep_frames.update(kept)
+        trimmed_frame_counts[ep_idx] = len(kept)
+        trimmed_start_offsets[ep_idx] = kept[0] - ep_start
+
+    if repo_id is None:
+        repo_id = f"{dataset.repo_id}_trimmed"
+    output_dir = Path(output_dir) if output_dir is not None else HF_LEROBOT_HOME / repo_id
+
+    total_trimmed = sum(trimmed_frame_counts.values())
+    logging.info(f"Trimming {len(episode_frames_to_keep)} episodes, keeping {total_trimmed} frames total")
+
+    # Create new metadata
+    new_meta = LeRobotDatasetMetadata.create(
+        repo_id=repo_id,
+        fps=dataset.meta.fps,
+        features=dataset.meta.features,
+        robot_type=dataset.meta.robot_type,
+        root=output_dir,
+        use_videos=has_videos,
+    )
+
+    # Copy and filter data
+    _copy_and_reindex_data_with_multi_frame_filter(
+        dataset, new_meta, episode_frames_to_keep, all_keep_frames
+    )
+
+    # Handle videos if present
+    if has_videos:
+        _copy_and_reindex_videos_with_multi_frame_filter(dataset, new_meta, episode_frames_to_keep)
+
+    # Copy episode metadata
+    _copy_and_reindex_episodes_metadata_for_multi_trim(
+        dataset, new_meta, trimmed_frame_counts, trimmed_start_offsets
+    )
+
+    logging.info(f"Created trimmed dataset with {new_meta.total_frames} frames at {output_dir}")
+
+    return new_meta
+
+
+# Keep old function for backward compatibility
+def trim_episode_by_frames(
+    dataset: LeRobotDataset,
+    episode_index: int,
+    keep_frame_indices: list[int],
+    output_dir: str | Path | None = None,
+    repo_id: str | None = None,
+) -> LeRobotDatasetMetadata:
+    """Trim a single episode. Wrapper around trim_episodes_by_frames."""
+    return trim_episodes_by_frames(
+        dataset,
+        episode_frames_to_keep={episode_index: keep_frame_indices},
+        output_dir=output_dir,
+        repo_id=repo_id,
+    )
+
+
+def _copy_and_reindex_data_with_multi_frame_filter(
+    src_dataset: LeRobotDataset,
+    dst_meta: LeRobotDatasetMetadata,
+    episode_frames_to_keep: dict[int, list[int]],
+    all_keep_frames: set[int],
+) -> None:
+    """Copy data files with frame-level filtering for multiple episodes.
+
+    Rows of episodes that are not trimmed are all kept; rows of trimmed episodes are kept
+    only when their global ``index`` is in ``all_keep_frames``. ``index``, ``frame_index``
+    and ``timestamp`` are then recomputed for the surviving rows.
+    """
+    if src_dataset.meta.episodes is None:
+        src_dataset.meta.episodes = load_episodes(src_dataset.meta.root)
+
+    # Copy tasks
+    if dst_meta.tasks is None and src_dataset.meta.tasks is not None:
+        # Tasks are stored with task string as index
+        dst_meta.save_episode_tasks(list(src_dataset.meta.tasks.index))
+
+    # Get all parquet files
+    data_dir = src_dataset.root / "data"
+    parquet_files = sorted(data_dir.glob("chunk-*/file-*.parquet"))
+
+    trim_episode_set = set(episode_frames_to_keep)
+    global_index = 0
+    # Frames already written per episode, so that frame_index keeps counting when an
+    # episode continues in the next parquet file instead of restarting at 0.
+    frames_written: dict[int, int] = {}
+
+    for parquet_path in tqdm(parquet_files, desc="Processing data files"):
+        df = pd.read_parquet(parquet_path)
+
+        # Filter: keep all frames from non-trimmed episodes,
+        # and only specified frames from trimmed episodes
+        mask = (~df["episode_index"].isin(trim_episode_set)) | (df["index"].isin(all_keep_frames))
+        df = df[mask].copy().reset_index(drop=True)
+
+        if df.empty:
+            continue
+
+        # Reindex
+        df["index"] = range(global_index, global_index + len(df))
+
+        # Recalculate frame_index within each episode
+        position_in_file = df.groupby("episode_index").cumcount()
+        carried_over = df["episode_index"].map(lambda ep: frames_written.get(int(ep), 0))
+        df["frame_index"] = position_in_file + carried_over
+        for ep, count in df["episode_index"].value_counts().items():
+            frames_written[int(ep)] = frames_written.get(int(ep), 0) + int(count)
+
+        # Recalculate timestamps based on frame_index and fps
+        df["timestamp"] = df["frame_index"] / src_dataset.meta.fps
+
+        # Determine output path (keep same structure)
+        rel_path = parquet_path.relative_to(src_dataset.root)
+        dst_path = dst_meta.root / rel_path
+        dst_path.parent.mkdir(parents=True, exist_ok=True)
+
+        _write_parquet(df, dst_path, dst_meta)
+        global_index += len(df)
+
+
+def _copy_and_reindex_videos_with_multi_frame_filter(
+    src_dataset: LeRobotDataset,
+    dst_meta: LeRobotDatasetMetadata,
+    episode_frames_to_keep: dict[int, list[int]],
+) -> None:
+    """Copy video files for trimmed dataset.
+
+    In v3.0 datasets, multiple episodes are concatenated into single video files.
+    Each episode has from_timestamp/to_timestamp indicating its portion of the video.
+
+    For trimming, we copy the original video files as-is and update the metadata
+    timestamps in _copy_and_reindex_episodes_metadata_for_multi_trim.
+    ``episode_frames_to_keep`` is accepted for signature parity and is not used here.
+    """
+    for video_key in src_dataset.meta.video_keys:
+        video_dir = src_dataset.root / "videos" / video_key
+        dst_video_dir = dst_meta.root / "videos" / video_key
+
+        if not video_dir.exists():
+            logging.warning(f"Video directory not found: {video_dir}")
+            continue
+
+        # Copy all video files (they contain concatenated episodes)
+        # The metadata timestamps will handle which portions to use
+        num_files = 0
+        for chunk_dir in video_dir.glob("chunk-*"):
+            dst_chunk_dir = dst_video_dir / chunk_dir.name
+            dst_chunk_dir.mkdir(parents=True, exist_ok=True)
+
+            # File names may repeat from one chunk directory to the next, so every file
+            # of every chunk is copied; only an already existing destination is skipped.
+            for video_file in chunk_dir.glob("*.mp4"):
+                dst_path = dst_chunk_dir / video_file.name
+                if not dst_path.exists():
+                    shutil.copy(video_file, dst_path)
+                num_files += 1
+
+        logging.info(f"Copied {num_files} video files for {video_key}")
+
+
+def _trim_video_frames(
+    src_path: Path,
+    dst_path: Path,
+    keep_frame_indices: list[int],
+    fps: float,
+    episode_start_idx: int,
+) -> None:
+    """Trim a video to keep only specific frames using ffmpeg.
+
+    Only the span from the first to the last kept frame is cut out; frames skipped
+    inside that span are not removed. If ffmpeg fails, the whole video is copied instead.
+    """
+    import subprocess
+
+    # Convert global indices to local indices within the episode
+    local_indices = sorted(idx - episode_start_idx for idx in keep_frame_indices)
+
+    if not local_indices:
+        logging.warning(f"No frames to keep for video {src_path}")
+        return
+
+    # Calculate start and end times
+    start_frame = local_indices[0]
+    end_frame = local_indices[-1]
+
+    start_time = start_frame / fps
+    duration = (end_frame - start_frame + 1) / fps
+
+    # Use ffmpeg to trim
+    cmd = [
+        "ffmpeg", "-y",
+        "-ss", str(start_time),
+        "-i", str(src_path),
+        "-t", str(duration),
+        "-c", "copy",  # Fast copy without re-encoding
+        str(dst_path),
+    ]
+
+    try:
+        subprocess.run(cmd, check=True, capture_output=True)
+    except subprocess.CalledProcessError as e:
+        logging.error(f"Failed to trim video: {e.stderr.decode(errors='replace')}")
+        # Fallback: copy the whole video
+        shutil.copy(src_path, dst_path)
+
+
+def _copy_and_reindex_episodes_metadata_for_multi_trim(
+    src_dataset: LeRobotDataset,
+    dst_meta: LeRobotDatasetMetadata,
+    trimmed_frame_counts: dict[int, int],
+    trimmed_start_offsets: dict[int, int] | None = None,
+) -> None:
+    """Copy and update episode metadata for trimmed dataset.
+
+    Args:
+        src_dataset: Dataset whose episode metadata is copied.
+        dst_meta: Metadata of the dataset being written.
+        trimmed_frame_counts: New length of every trimmed episode, by episode index.
+        trimmed_start_offsets: Number of leading frames dropped from every trimmed
+            episode, by episode index. Missing entries count as 0.
+    """
+    if src_dataset.meta.episodes is None:
+        src_dataset.meta.episodes = load_episodes(src_dataset.meta.root)
+
+    trimmed_start_offsets = trimmed_start_offsets or {}
+    fps = src_dataset.meta.fps
+
+    # Calculate new frame counts and indices
+    episodes_data = []
+    global_idx = 0
+
+    for old_ep_idx in range(src_dataset.meta.total_episodes):
+        src_ep = src_dataset.meta.episodes[old_ep_idx]
+        is_trimmed = old_ep_idx in trimmed_frame_counts
+        ep_length = trimmed_frame_counts[old_ep_idx] if is_trimmed else src_ep["length"]
+
+        ep_data = {
+            "episode_index": old_ep_idx,
+            "tasks": src_ep["tasks"],
+            "length": ep_length,
+            "data/chunk_index": src_ep["data/chunk_index"],
+            "data/file_index": src_ep["data/file_index"],
+            "dataset_from_index": global_idx,
+            "dataset_to_index": global_idx + ep_length,
+        }
+
+        # Copy video metadata - preserve timestamps for concatenated videos
+        for video_key in src_dataset.meta.video_keys:
+            ep_data[f"videos/{video_key}/chunk_index"] = src_ep[f"videos/{video_key}/chunk_index"]
+            ep_data[f"videos/{video_key}/file_index"] = src_ep[f"videos/{video_key}/file_index"]
+
+            orig_from_ts = src_ep[f"videos/{video_key}/from_timestamp"]
+            if is_trimmed:
+                # The recomputed frame timestamps restart at 0, so the video window has
+                # to begin at the first kept frame, not at the original episode start.
+                new_from_ts = orig_from_ts + trimmed_start_offsets.get(old_ep_idx, 0) / fps
+                ep_data[f"videos/{video_key}/from_timestamp"] = new_from_ts
+                ep_data[f"videos/{video_key}/to_timestamp"] = new_from_ts + ep_length / fps
+            else:
+                # Non-trimmed episodes keep their original window
+                ep_data[f"videos/{video_key}/from_timestamp"] = orig_from_ts
+                ep_data[f"videos/{video_key}/to_timestamp"] = src_ep[f"videos/{video_key}/to_timestamp"]
+
+        ep_data["meta/episodes/chunk_index"] = 0
+        ep_data["meta/episodes/file_index"] = 0
+
+        episodes_data.append(ep_data)
+        global_idx += ep_length
+
+    # Save episodes metadata
+    df = pd.DataFrame(episodes_data)
+    episodes_path = dst_meta.root / DEFAULT_EPISODES_PATH.format(chunk_index=0, file_index=0)
+    episodes_path.parent.mkdir(parents=True, exist_ok=True)
+    df.to_parquet(episodes_path)
+
+    # Update info.json
+    info = load_info(src_dataset.root)
+    info["total_episodes"] = len(episodes_data)
+    info["total_frames"] = global_idx
+    write_info(info, dst_meta.root)
+    # Keep the in-memory metadata in sync with the file just written so callers reading
+    # dst_meta.total_frames / total_episodes see the new totals.
+    # NOTE(review): assumes those properties are backed by dst_meta.info - confirm.
+    dst_meta.info = info
diff --git a/src/lerobot/scripts/lerobot_edit_dataset.py b/src/lerobot/scripts/lerobot_edit_dataset.py
index afdc95efd..775c2622f 100644
--- a/src/lerobot/scripts/lerobot_edit_dataset.py
+++ b/src/lerobot/scripts/lerobot_edit_dataset.py
@@ -104,6 +104,29 @@ Convert image dataset to video format and push to hub:
         --operation.type convert_image_to_video \
         --push_to_hub true
 
+Trim single episode to keep only frames within timestamp range:
+    python -m lerobot.scripts.lerobot_edit_dataset \
+        --repo_id lerobot/pusht \
+        --new_repo_id lerobot/pusht_trimmed \
+        --operation.type trim_episode \
+        --operation.episode_index 0 \
+        --operation.start_timestamp 10.0 \
+        --operation.end_timestamp 30.0
+
+Trim multiple episodes at once (use null for no limit):
+    python -m lerobot.scripts.lerobot_edit_dataset \
+        --repo_id lerobot/pusht \
+        --operation.type trim_episode \
+        --operation.episode_trims '{"0": [10.0, 30.0], "2": [5.0, null], "3": [null, 20.0]}'
+
+Trim and re-upload to same repo (overwrites original):
+    python -m lerobot.scripts.lerobot_edit_dataset \
+        --repo_id lerobot/pusht \
+        --operation.type trim_episode \
+        --operation.episode_index 0 \
+        --operation.start_timestamp 10.0 \
+        --push_to_hub true
+
 Show dataset information:
     lerobot-edit-dataset \
         --repo_id lerobot/pusht_image \
@@ -204,9 +227,32 @@ class InfoConfig(OperationConfig):
     show_features: bool = False
 
 
+@dataclass
+class TrimEpisodeConfig(OperationConfig):
+    """Trim episodes to keep only frames within timestamp ranges.
+
+    Supports multiple episodes via episode_trims dict:
+        --operation.episode_trims '{"0": [10.0, 30.0], "2": [5.0, 20.0]}'
+
+    Or single episode via legacy parameters:
+        --operation.episode_index 0 --operation.start_timestamp 10.0 --operation.end_timestamp 30.0
+    """
+
+    # NOTE(review): sibling configs subclass OperationConfig; if they are also registered
+    # with it (e.g. through a decorator), register this class the same way - confirm.
+    type: str = "trim_episode"
+    # Multi-episode support: dict mapping episode_index -> [start_timestamp, end_timestamp]
+    # Use null for no limit, e.g. {"0": [10.0, null], "2": [null, 30.0]}
+    episode_trims: dict[str, list[float | None]] | None = None
+    # Legacy single-episode parameters (used if episode_trims is None)
+    episode_index: int | None = None
+    start_timestamp: float | None = None  # Keep frames from this timestamp (inclusive)
+    end_timestamp: float | None = None  # Keep frames until this timestamp (inclusive)
+
+
 @dataclass
 class EditDatasetConfig:
     repo_id: str
     operation: OperationConfig
     root: str | None = None
     new_repo_id: str | None = None
@@ -351,6 +397,93 @@ def handle_remove_feature(cfg: EditDatasetConfig) -> None:
         LeRobotDataset(output_repo_id, root=output_dir).push_to_hub()
 
 
+def handle_trim_episode(cfg: EditDatasetConfig) -> None:
+    """Trim episodes to keep only frames within timestamp ranges."""
+    if not isinstance(cfg.operation, TrimEpisodeConfig):
+        raise ValueError("Operation config must be TrimEpisodeConfig")
+
+    # Parse episode trims - support both multi-episode dict and legacy single episode
+    episode_trims: dict[int, tuple[float | None, float | None]] = {}
+
+    if cfg.operation.episode_trims is not None:
+        # Multi-episode mode
+        for ep_str, ts_range in cfg.operation.episode_trims.items():
+            ep_idx = int(ep_str)
+            start_ts = ts_range[0] if len(ts_range) > 0 else None
+            end_ts = ts_range[1] if len(ts_range) > 1 else None
+            episode_trims[ep_idx] = (start_ts, end_ts)
+    elif cfg.operation.episode_index is not None:
+        # Legacy single-episode mode
+        if cfg.operation.start_timestamp is None and cfg.operation.end_timestamp is None:
+            raise ValueError("At least one of start_timestamp or end_timestamp must be specified")
+        episode_trims[cfg.operation.episode_index] = (
+            cfg.operation.start_timestamp,
+            cfg.operation.end_timestamp,
+        )
+    else:
+        raise ValueError("Either episode_trims or episode_index must be specified")
+
+    dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
+    output_repo_id, output_dir = get_output_path(
+        cfg.repo_id, cfg.new_repo_id, Path(cfg.root) if cfg.root else None
+    )
+
+    if cfg.new_repo_id is None:
+        dataset.root = Path(str(dataset.root) + "_old")
+
+    logging.info(f"Trimming {len(episode_trims)} episode(s) from {cfg.repo_id}")
+
+    # Get episode boundaries and find frames to keep for each episode
+    episodes_info = dataset.meta.episodes
+    num_episodes = len(episodes_info["episode_index"])
+    all_frames_to_keep: dict[int, list[int]] = {}
+    # Stored timestamps are floats, so the inclusive bounds are compared with a small slack.
+    tolerance_s = 1e-4
+
+    for ep_idx, (start_ts, end_ts) in episode_trims.items():
+        if not 0 <= ep_idx < num_episodes:
+            raise ValueError(f"Episode {ep_idx} does not exist (dataset has {num_episodes} episodes)")
+
+        from_frame = episodes_info["dataset_from_index"][ep_idx]
+        to_frame = episodes_info["dataset_to_index"][ep_idx]
+
+        logging.info(f"Episode {ep_idx}: trimming to [{start_ts}, {end_ts}]")
+        logging.info(f"  Original frames: {from_frame} to {to_frame} ({to_frame - from_frame} frames)")
+
+        # Find frames within timestamp range
+        frames_to_keep = []
+        for frame_idx in range(from_frame, to_frame):
+            ts = float(dataset.hf_dataset[frame_idx]["timestamp"])
+            if start_ts is not None and ts < start_ts - tolerance_s:
+                continue
+            if end_ts is not None and ts > end_ts + tolerance_s:
+                continue
+            frames_to_keep.append(frame_idx)
+
+        if not frames_to_keep:
+            raise ValueError(f"Episode {ep_idx}: No frames found in timestamp range [{start_ts}, {end_ts}]")
+
+        logging.info(f"  Keeping {len(frames_to_keep)} frames (indices {frames_to_keep[0]} to {frames_to_keep[-1]})")
+        all_frames_to_keep[ep_idx] = frames_to_keep
+
+    from lerobot.datasets.dataset_tools import trim_episodes_by_frames
+
+    # trim_episodes_by_frames returns the new dataset's metadata, not a LeRobotDataset
+    new_meta = trim_episodes_by_frames(
+        dataset,
+        episode_frames_to_keep=all_frames_to_keep,
+        output_dir=output_dir,
+        repo_id=output_repo_id,
+    )
+
+    logging.info(f"Dataset saved to {output_dir}")
+    logging.info(f"Episodes: {new_meta.total_episodes}, Frames: {new_meta.total_frames}")
+
+    if cfg.push_to_hub:
+        logging.info(f"Pushing to hub as {output_repo_id}")
+        LeRobotDataset(output_repo_id, root=output_dir).push_to_hub()
+
+
 def handle_modify_tasks(cfg: EditDatasetConfig) -> None:
     if not isinstance(cfg.operation, ModifyTasksConfig):
         raise ValueError("Operation config must be ModifyTasksConfig")
@@ -515,6 +648,8 @@ def edit_dataset(cfg: EditDatasetConfig) -> None:
         handle_modify_tasks(cfg)
     elif operation_type == "convert_image_to_video":
         handle_convert_image_to_video(cfg)
+    elif operation_type == "trim_episode":
+        handle_trim_episode(cfg)
     elif operation_type == "info":
         handle_info(cfg)
     else: