diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index d9d4b22d0..d9cc28b30 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -78,6 +78,7 @@ from lerobot.datasets.video_utils import ( from lerobot.utils.constants import HF_LEROBOT_HOME CODEBASE_VERSION = "v3.0" +VALID_VIDEO_CODECS = {"h264", "hevc", "libsvtav1"} class LeRobotDatasetMetadata: @@ -540,11 +541,13 @@ class LeRobotDatasetMetadata: return obj -def _encode_video_worker(video_key: str, episode_index: int, root: Path, fps: int) -> Path: +def _encode_video_worker( + video_key: str, episode_index: int, root: Path, fps: int, vcodec: str = "libsvtav1" +) -> Path: temp_path = Path(tempfile.mkdtemp(dir=root)) / f"{video_key}_{episode_index:03d}.mp4" fpath = DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=episode_index, frame_index=0) img_dir = (root / fpath).parent - encode_video_frames(img_dir, temp_path, fps, overwrite=True) + encode_video_frames(img_dir, temp_path, fps, vcodec=vcodec, overwrite=True) shutil.rmtree(img_dir) return temp_path @@ -563,6 +566,7 @@ class LeRobotDataset(torch.utils.data.Dataset): download_videos: bool = True, video_backend: str | None = None, batch_encoding_size: int = 1, + vcodec: str = "libsvtav1", ): """ 2 modes are available for instantiating this class, depending on 2 different use cases: @@ -675,8 +679,13 @@ class LeRobotDataset(torch.utils.data.Dataset): You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision. batch_encoding_size (int, optional): Number of episodes to accumulate before batch encoding videos. Set to 1 for immediate encoding (default), or higher for batched encoding. Defaults to 1. + vcodec (str, optional): Video codec for encoding videos during recording. Options: 'h264', 'hevc', + 'libsvtav1'. Defaults to 'libsvtav1'. Use 'h264' for faster encoding on systems where AV1 + encoding is CPU-heavy. """ super().__init__() + if vcodec not in VALID_VIDEO_CODECS: + raise ValueError(f"Invalid vcodec '{vcodec}'. Must be one of: {sorted(VALID_VIDEO_CODECS)}") self.repo_id = repo_id self.root = Path(root) if root else HF_LEROBOT_HOME / repo_id self.image_transforms = image_transforms @@ -688,6 +697,7 @@ class LeRobotDataset(torch.utils.data.Dataset): self.delta_indices = None self.batch_encoding_size = batch_encoding_size self.episodes_since_last_encoding = 0 + self.vcodec = vcodec # Unused attributes self.image_writer = None @@ -1211,6 +1221,7 @@ class LeRobotDataset(torch.utils.data.Dataset): episode_index, self.root, self.fps, + self.vcodec, ): video_key for video_key in self.meta.video_keys } @@ -1526,7 +1537,7 @@ class LeRobotDataset(torch.utils.data.Dataset): Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding, since video encoding with ffmpeg is already using multithreading. """ - return _encode_video_worker(video_key, episode_index, self.root, self.fps) + return _encode_video_worker(video_key, episode_index, self.root, self.fps, self.vcodec) @classmethod def create( @@ -1542,8 +1553,11 @@ class LeRobotDataset(torch.utils.data.Dataset): image_writer_threads: int = 0, video_backend: str | None = None, batch_encoding_size: int = 1, + vcodec: str = "libsvtav1", ) -> "LeRobotDataset": """Create a LeRobot Dataset from scratch in order to record data.""" + if vcodec not in VALID_VIDEO_CODECS: + raise ValueError(f"Invalid vcodec '{vcodec}'. Must be one of: {sorted(VALID_VIDEO_CODECS)}") obj = cls.__new__(cls) obj.meta = LeRobotDatasetMetadata.create( repo_id=repo_id, @@ -1560,6 +1574,7 @@ class LeRobotDataset(torch.utils.data.Dataset): obj.image_writer = None obj.batch_encoding_size = batch_encoding_size obj.episodes_since_last_encoding = 0 + obj.vcodec = vcodec if image_writer_processes or image_writer_threads: obj.start_image_writer(image_writer_processes, image_writer_threads) diff --git a/src/lerobot/scripts/lerobot_record.py b/src/lerobot/scripts/lerobot_record.py index 5d2945e67..a81a5d54e 100644 --- a/src/lerobot/scripts/lerobot_record.py +++ b/src/lerobot/scripts/lerobot_record.py @@ -27,6 +27,8 @@ lerobot-record \ --dataset.num_episodes=2 \ --dataset.single_task="Grab the cube" \ --display_data=true + # <- Optional: specify video codec (h264, hevc, libsvtav1). Default is libsvtav1. \ + # --dataset.vcodec=h264 \ # <- Teleop optional if you want to teleoperate to record or in between episodes with a policy \ # --teleop.type=so100_leader \ # --teleop.port=/dev/tty.usbmodem58760431551 \ @@ -165,6 +167,9 @@ class DatasetRecordConfig: # Number of episodes to record before batch encoding videos # Set to 1 for immediate encoding (default behavior), or higher for batched encoding video_encoding_batch_size: int = 1 + # Video codec for encoding videos. Options: 'h264', 'hevc', 'libsvtav1'. + # Use 'h264' for faster encoding on systems where AV1 encoding is CPU-heavy. + vcodec: str = "libsvtav1" # Rename map for the observation to override the image and state keys rename_map: dict[str, str] = field(default_factory=dict) @@ -427,6 +432,7 @@ def record(cfg: RecordConfig) -> LeRobotDataset: cfg.dataset.repo_id, root=cfg.dataset.root, batch_encoding_size=cfg.dataset.video_encoding_batch_size, + vcodec=cfg.dataset.vcodec, ) if hasattr(robot, "cameras") and len(robot.cameras) > 0: @@ -448,6 +454,7 @@ def record(cfg: RecordConfig) -> LeRobotDataset: image_writer_processes=cfg.dataset.num_image_writer_processes, image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot.cameras), batch_encoding_size=cfg.dataset.video_encoding_batch_size, + vcodec=cfg.dataset.vcodec, ) # Load pretrained policy diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py index 38fdc358d..4c91c55c0 100644 --- a/tests/datasets/test_datasets.py +++ b/tests/datasets/test_datasets.py @@ -31,8 +31,10 @@ from lerobot.configs.train import TrainPipelineConfig from lerobot.datasets.factory import make_dataset from lerobot.datasets.image_writer import image_array_to_pil_image from lerobot.datasets.lerobot_dataset import ( + VALID_VIDEO_CODECS, LeRobotDataset, MultiLeRobotDataset, + _encode_video_worker, ) from lerobot.datasets.utils import ( DEFAULT_CHUNK_SIZE, @@ -1292,3 +1294,101 @@ def test_frames_in_current_file_calculation(tmp_path, empty_lerobot_dataset_fact frame = loaded_dataset[idx] expected_ep = idx // frames_per_episode assert frame["episode_index"].item() == expected_ep + + +def test_encode_video_worker_forwards_vcodec(tmp_path): + """Test that _encode_video_worker correctly forwards the vcodec parameter to encode_video_frames.""" + from unittest.mock import patch + + from lerobot.datasets.utils import DEFAULT_IMAGE_PATH + + # Create the expected directory structure + video_key = "observation.images.laptop" + episode_index = 0 + frame_index = 0 + + fpath = DEFAULT_IMAGE_PATH.format( + image_key=video_key, episode_index=episode_index, frame_index=frame_index + ) + img_dir = tmp_path / Path(fpath).parent + img_dir.mkdir(parents=True, exist_ok=True) + + # Create a dummy image file + dummy_img = Image.new("RGB", (64, 64), color="red") + dummy_img.save(img_dir / "frame-000000.png") + + # Track what vcodec was passed to encode_video_frames + captured_kwargs = {} + + def mock_encode_video_frames(imgs_dir, video_path, fps, **kwargs): + captured_kwargs.update(kwargs) + # Create a dummy output file so the worker doesn't fail + Path(video_path).parent.mkdir(parents=True, exist_ok=True) + Path(video_path).touch() + + with patch("lerobot.datasets.lerobot_dataset.encode_video_frames", side_effect=mock_encode_video_frames): + # Test with h264 codec + _encode_video_worker(video_key, episode_index, tmp_path, fps=30, vcodec="h264") + + assert "vcodec" in captured_kwargs + assert captured_kwargs["vcodec"] == "h264" + + +def test_encode_video_worker_default_vcodec(tmp_path): + """Test that _encode_video_worker uses libsvtav1 as the default codec.""" + from unittest.mock import patch + + from lerobot.datasets.utils import DEFAULT_IMAGE_PATH + + # Create the expected directory structure + video_key = "observation.images.laptop" + episode_index = 0 + frame_index = 0 + + fpath = DEFAULT_IMAGE_PATH.format( + image_key=video_key, episode_index=episode_index, frame_index=frame_index + ) + img_dir = tmp_path / Path(fpath).parent + img_dir.mkdir(parents=True, exist_ok=True) + + # Create a dummy image file + dummy_img = Image.new("RGB", (64, 64), color="red") + dummy_img.save(img_dir / "frame-000000.png") + + # Track what vcodec was passed to encode_video_frames + captured_kwargs = {} + + def mock_encode_video_frames(imgs_dir, video_path, fps, **kwargs): + captured_kwargs.update(kwargs) + # Create a dummy output file so the worker doesn't fail + Path(video_path).parent.mkdir(parents=True, exist_ok=True) + Path(video_path).touch() + + with patch("lerobot.datasets.lerobot_dataset.encode_video_frames", side_effect=mock_encode_video_frames): + # Test with default codec (no vcodec specified) + _encode_video_worker(video_key, episode_index, tmp_path, fps=30) + + assert "vcodec" in captured_kwargs + assert captured_kwargs["vcodec"] == "libsvtav1" + + +def test_lerobot_dataset_vcodec_validation(): + """Test that LeRobotDataset validates the vcodec parameter.""" + # Test that invalid vcodec raises ValueError + with pytest.raises(ValueError, match="Invalid vcodec"): + LeRobotDataset.__new__(LeRobotDataset) # bypass __init__ to test validation directly + # Actually test via create since it's easier + LeRobotDataset.create( + repo_id="test/invalid_codec", + fps=30, + features={"observation.state": {"dtype": "float32", "shape": (2,), "names": ["x", "y"]}}, + vcodec="invalid_codec", + ) + + +def test_valid_video_codecs_constant(): + """Test that VALID_VIDEO_CODECS contains the expected codecs.""" + assert "h264" in VALID_VIDEO_CODECS + assert "hevc" in VALID_VIDEO_CODECS + assert "libsvtav1" in VALID_VIDEO_CODECS + assert len(VALID_VIDEO_CODECS) == 3