diff --git a/src/lerobot/datasets/dataset_tools.py b/src/lerobot/datasets/dataset_tools.py index 123d455c6..b62d7d959 100644 --- a/src/lerobot/datasets/dataset_tools.py +++ b/src/lerobot/datasets/dataset_tools.py @@ -567,20 +567,22 @@ def _copy_and_reindex_data( def _keep_episodes_from_video_with_av( input_path: Path, output_path: Path, - episodes_to_keep: list[tuple[float, float]], + episodes_to_keep: list[tuple[int, int]], fps: float, vcodec: str = "libsvtav1", pix_fmt: str = "yuv420p", ) -> None: """Keep only specified episodes from a video file using PyAV. - This function decodes frames from specified time ranges and re-encodes them with + This function decodes frames from specified frame ranges and re-encodes them with properly reset timestamps to ensure monotonic progression. Args: input_path: Source video file path. output_path: Destination video file path. - episodes_to_keep: List of (start_time, end_time) tuples for episodes to keep. + episodes_to_keep: List of (start_frame, end_frame) tuples for episodes to keep. + Ranges are half-open intervals: [start_frame, end_frame), where start_frame + is inclusive and end_frame is exclusive. fps: Frame rate of the video. vcodec: Video codec to use for encoding. pix_fmt: Pixel format for output video. @@ -622,9 +624,10 @@ def _keep_episodes_from_video_with_av( # Create set of (start, end) ranges for fast lookup. # Convert to a sorted list for efficient checking. - time_ranges = sorted(episodes_to_keep) + frame_ranges = sorted(episodes_to_keep) # Track frame index for setting PTS and current range being processed. + src_frame_count = 0 frame_count = 0 range_idx = 0 @@ -634,21 +637,20 @@ def _keep_episodes_from_video_with_av( if frame is None: continue - # Get frame timestamp. - frame_time = float(frame.pts * frame.time_base) if frame.pts is not None else 0.0 - - # Check if frame is in any of our desired time ranges. + # Check if frame is in any of our desired frame ranges. # Skip ranges that have already passed. - while range_idx < len(time_ranges) and frame_time >= time_ranges[range_idx][1]: + while range_idx < len(frame_ranges) and src_frame_count >= frame_ranges[range_idx][1]: range_idx += 1 # If we've passed all ranges, stop processing. - if range_idx >= len(time_ranges): + if range_idx >= len(frame_ranges): break # Check if frame is in current range. - start_ts, end_ts = time_ranges[range_idx] - if frame_time < start_ts: + start_frame = frame_ranges[range_idx][0] + + if src_frame_count < start_frame: + src_frame_count += 1 continue # Frame is in range - create a new frame with reset timestamps. @@ -661,6 +663,7 @@ def _keep_episodes_from_video_with_av( for pkt in v_out.encode(new_frame): out.mux(pkt) + src_frame_count += 1 frame_count += 1 # Flush encoder. @@ -749,15 +752,17 @@ def _copy_and_reindex_videos( f"videos/{video_key}/to_timestamp" ] else: - # Build list of time ranges to keep, in sorted order. + # Build list of frame ranges to keep, in sorted order. sorted_keep_episodes = sorted(episodes_in_file, key=lambda x: episode_mapping[x]) - episodes_to_keep_ranges: list[tuple[float, float]] = [] - + episodes_to_keep_ranges: list[tuple[int, int]] = [] for old_idx in sorted_keep_episodes: src_ep = src_dataset.meta.episodes[old_idx] - from_ts = src_ep[f"videos/{video_key}/from_timestamp"] - to_ts = src_ep[f"videos/{video_key}/to_timestamp"] - episodes_to_keep_ranges.append((from_ts, to_ts)) + from_frame = round(src_ep[f"videos/{video_key}/from_timestamp"] * src_dataset.meta.fps) + to_frame = round(src_ep[f"videos/{video_key}/to_timestamp"] * src_dataset.meta.fps) + assert src_ep["length"] == to_frame - from_frame, ( + f"Episode length mismatch: {src_ep['length']} vs {to_frame - from_frame}" + ) + episodes_to_keep_ranges.append((from_frame, to_frame)) # Use PyAV filters to efficiently re-encode only the desired segments. assert src_dataset.meta.video_path is not None