mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-30 18:31:25 +00:00
Compare commits
29 Commits
feat/depth
...
feat/depth
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4445849b86 | ||
|
|
f43bf75f9b | ||
|
|
b540fa94a9 | ||
|
|
efad15f600 | ||
|
|
407d1882a2 | ||
|
|
0d6e4f3bad | ||
|
|
536b29d963 | ||
|
|
2744e26593 | ||
|
|
de64ad3f7e | ||
|
|
d777359662 | ||
|
|
5d0a20bd9c | ||
|
|
2c796d3352 | ||
|
|
df1648c102 | ||
|
|
3bd96a4346 | ||
|
|
016799dfa1 | ||
|
|
51b9038458 | ||
|
|
cc9a2e5c99 | ||
|
|
a2376389f9 | ||
|
|
57a619ab02 | ||
|
|
7f624adcc5 | ||
|
|
375cf1fdf3 | ||
|
|
b2c2bb7641 | ||
|
|
4a87ee1537 | ||
|
|
e44f86e516 | ||
|
|
a0e3acdb67 | ||
|
|
38ff579bcc | ||
|
|
479e444517 | ||
|
|
9787b8fa26 | ||
|
|
71f39f6912 |
@@ -39,6 +39,7 @@ from tqdm import tqdm
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.video_utils import (
|
||||
VideoEncoderConfig,
|
||||
decode_video_frames,
|
||||
encode_video_frames,
|
||||
)
|
||||
@@ -251,10 +252,13 @@ def benchmark_encoding_decoding(
|
||||
imgs_dir=imgs_dir,
|
||||
video_path=video_path,
|
||||
fps=fps,
|
||||
vcodec=encoding_cfg["vcodec"],
|
||||
pix_fmt=encoding_cfg["pix_fmt"],
|
||||
g=encoding_cfg.get("g"),
|
||||
crf=encoding_cfg.get("crf"),
|
||||
camera_encoder_config=VideoEncoderConfig(
|
||||
vcodec=encoding_cfg["vcodec"],
|
||||
pix_fmt=encoding_cfg["pix_fmt"],
|
||||
g=encoding_cfg.get("g"),
|
||||
crf=encoding_cfg.get("crf"),
|
||||
preset=encoding_cfg.get("preset"),
|
||||
),
|
||||
# fast_decode=encoding_cfg.get("fastdecode"),
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
@@ -90,6 +90,6 @@ lerobot-record \
|
||||
--dataset.single_task="Your task description" \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.vcodec=auto \
|
||||
# --dataset.camera_encoder_config.vcodec=auto \
|
||||
--policy.path=${HF_USER}/act_policy
|
||||
```
|
||||
|
||||
@@ -194,7 +194,7 @@ lerobot-record \
|
||||
--dataset.single_task="Navigate around obstacles" \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.vcodec=auto \
|
||||
# --dataset.camera_encoder_config.vcodec=auto \
|
||||
--display_data=true
|
||||
```
|
||||
|
||||
|
||||
@@ -123,7 +123,7 @@ lerobot-record \
|
||||
--dataset.single_task="Grab and handover the red cube to the other arm" \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.vcodec=auto \
|
||||
# --dataset.camera_encoder_config.vcodec=auto \
|
||||
--policy.path=<user>/groot-bimanual \ # your trained model
|
||||
--dataset.episode_time_s=30 \
|
||||
--dataset.reset_time_s=10
|
||||
|
||||
@@ -232,7 +232,7 @@ lerobot-record \
|
||||
--dataset.private=true \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.vcodec=auto \
|
||||
# --dataset.camera_encoder_config.vcodec=auto \
|
||||
--display_data=true
|
||||
```
|
||||
|
||||
@@ -278,6 +278,6 @@ lerobot-record \
|
||||
--dataset.num_episodes=10 \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.vcodec=auto \
|
||||
# --dataset.camera_encoder_config.vcodec=auto \
|
||||
--policy.path=outputs/train/hopejr_hand/checkpoints/last/pretrained_model
|
||||
```
|
||||
|
||||
@@ -193,7 +193,7 @@ lerobot-record \
|
||||
--dataset.num_episodes=5 \
|
||||
--dataset.single_task="Grab the black cube" \
|
||||
--dataset.streaming_encoding=true \
|
||||
# --dataset.vcodec=auto \
|
||||
# --dataset.camera_encoder_config.vcodec=auto \
|
||||
--dataset.encoder_threads=2
|
||||
```
|
||||
</hfoption>
|
||||
|
||||
@@ -43,7 +43,7 @@ lerobot-record \
|
||||
--dataset.num_episodes=5 \
|
||||
--dataset.single_task="Grab the black cube" \
|
||||
--dataset.streaming_encoding=true \
|
||||
# --dataset.vcodec=auto \
|
||||
# --dataset.camera_encoder_config.vcodec=auto \
|
||||
--dataset.encoder_threads=2
|
||||
```
|
||||
|
||||
|
||||
@@ -161,7 +161,7 @@ lerobot-record \
|
||||
--dataset.private=true \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.vcodec=auto \
|
||||
# --dataset.camera_encoder_config.vcodec=auto \
|
||||
--display_data=true
|
||||
```
|
||||
|
||||
@@ -203,7 +203,7 @@ lerobot-record \
|
||||
--dataset.private=true \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.vcodec=auto \
|
||||
# --dataset.camera_encoder_config.vcodec=auto \
|
||||
--display_data=true
|
||||
```
|
||||
|
||||
|
||||
@@ -108,7 +108,7 @@ lerobot-record \
|
||||
--dataset.num_episodes=10 \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.vcodec=auto \
|
||||
# --dataset.camera_encoder_config.vcodec=auto \
|
||||
# <- Teleop optional if you want to teleoperate in between episodes \
|
||||
# --teleop.type=so100_leader \
|
||||
# --teleop.port=/dev/ttyACM0 \
|
||||
|
||||
@@ -14,12 +14,22 @@ This makes `save_episode()` near-instant (the video is already encoded by the ti
|
||||
|
||||
## 2. Tuning Parameters
|
||||
|
||||
| Parameter | CLI Flag | Type | Default | Description |
|
||||
| ----------------------- | --------------------------------- | ------------- | ------------- | ----------------------------------------------------------------- |
|
||||
| `streaming_encoding` | `--dataset.streaming_encoding` | `bool` | `True` | Enable real-time encoding during capture |
|
||||
| `vcodec` | `--dataset.vcodec` | `str` | `"libsvtav1"` | Video codec. `"auto"` detects best HW encoder |
|
||||
| `encoder_threads` | `--dataset.encoder_threads` | `int \| None` | `None` (auto) | Threads per encoder instance. `None` will leave the vcoded decide |
|
||||
| `encoder_queue_maxsize` | `--dataset.encoder_queue_maxsize` | `int` | `60` | Max buffered frames per camera (~2s at 30fps). Consumes RAM |
|
||||
All encoding parameters are grouped under `camera_encoder_config` (a `VideoEncoderConfig` dataclass), accessible from the CLI via `--dataset.camera_encoder_config.<field>`.
|
||||
|
||||
| Parameter | CLI Flag | Type | Default | Description |
|
||||
| ----------------------- | --------------------------------------------- | ------------- | ------------- | ------------------------------------------------------------------- |
|
||||
| `streaming_encoding` | `--dataset.streaming_encoding` | `bool` | `True` | Enable real-time encoding during capture |
|
||||
| `vcodec` | `--dataset.camera_encoder_config.vcodec` | `str` | `"libsvtav1"` | Video codec. `"auto"` detects best HW encoder |
|
||||
| `pix_fmt` | `--dataset.camera_encoder_config.pix_fmt` | `str` | `"yuv420p"` | Pixel format |
|
||||
| `g` | `--dataset.camera_encoder_config.g` | `int \| None` | `2` | GOP size (keyframe interval) |
|
||||
| `crf` | `--dataset.camera_encoder_config.crf` | `int \| None` | `30` | Quality level (mapped to codec-specific parameter) |
|
||||
| `preset` | `--dataset.camera_encoder_config.preset` | `int \| None` | `12` | Speed preset (libsvtav1 only, 0 = slowest … 13 = fastest) |
|
||||
| `fast_decode` | `--dataset.camera_encoder_config.fast_decode` | `int` | `0` | Fast-decode tuning level |
|
||||
| `encoder_threads` | `--dataset.encoder_threads` | `int \| None` | `None` (auto) | Threads per encoder instance (global). `None` lets the codec decide |
|
||||
| `encoder_queue_maxsize` | `--dataset.encoder_queue_maxsize` | `int` | `60` | Max buffered frames per camera (~2s at 30fps). Consumes RAM |
|
||||
|
||||
> [!TIP]
|
||||
> Not all parameters apply to every codec. `VideoEncoderConfig` will warn at startup if you set a parameter that your chosen codec ignores (e.g. `preset` with `h264_nvenc`).
|
||||
|
||||
## 3. Performance Considerations
|
||||
|
||||
@@ -40,7 +50,7 @@ Streaming encoding means the CPU is encoding video **during** the capture loop,
|
||||
|
||||
### `encoder_threads` Tuning
|
||||
|
||||
This parameter controls how many threads each encoder instance uses internally:
|
||||
This parameter (`--dataset.encoder_threads`) controls how many threads each encoder instance uses internally:
|
||||
|
||||
- **Higher values** (e.g., 4-5): Faster encoding, but uses more CPU cores per camera. Good for high-end systems with many cores.
|
||||
- **Lower values** (e.g., 1-2): Less CPU per camera, freeing cores for capture and visualization. Good for low-res images and capable CPUs.
|
||||
@@ -82,15 +92,15 @@ Use HW encoding when:
|
||||
|
||||
### Available HW Encoders
|
||||
|
||||
| Encoder | Platform | Hardware | CLI Value |
|
||||
| ------------------- | ------------- | ------------------------------------------------------------------------------------------------ | ------------------------------------ |
|
||||
| `h264_videotoolbox` | macOS | Apple Silicon / Intel | `--dataset.vcodec=h264_videotoolbox` |
|
||||
| `hevc_videotoolbox` | macOS | Apple Silicon / Intel | `--dataset.vcodec=hevc_videotoolbox` |
|
||||
| `h264_nvenc` | Linux/Windows | NVIDIA GPU | `--dataset.vcodec=h264_nvenc` |
|
||||
| `hevc_nvenc` | Linux/Windows | NVIDIA GPU | `--dataset.vcodec=hevc_nvenc` |
|
||||
| `h264_vaapi` | Linux | Intel/AMD GPU | `--dataset.vcodec=h264_vaapi` |
|
||||
| `h264_qsv` | Linux/Windows | Intel Quick Sync | `--dataset.vcodec=h264_qsv` |
|
||||
| `auto` | Any | Probes the system for available HW encoders. Falls back to `libsvtav1` if no HW encoder is found | `--dataset.vcodec=auto` |
|
||||
| Encoder | Platform | Hardware | CLI Value |
|
||||
| ------------------- | ------------- | ------------------------------------------------------------------------------------------------ | ---------------------------------------------------------- |
|
||||
| `h264_videotoolbox` | macOS | Apple Silicon / Intel | `--dataset.camera_encoder_config.vcodec=h264_videotoolbox` |
|
||||
| `hevc_videotoolbox` | macOS | Apple Silicon / Intel | `--dataset.camera_encoder_config.vcodec=hevc_videotoolbox` |
|
||||
| `h264_nvenc` | Linux/Windows | NVIDIA GPU | `--dataset.camera_encoder_config.vcodec=h264_nvenc` |
|
||||
| `hevc_nvenc` | Linux/Windows | NVIDIA GPU | `--dataset.camera_encoder_config.vcodec=hevc_nvenc` |
|
||||
| `h264_vaapi` | Linux | Intel/AMD GPU | `--dataset.camera_encoder_config.vcodec=h264_vaapi` |
|
||||
| `h264_qsv` | Linux/Windows | Intel Quick Sync | `--dataset.camera_encoder_config.vcodec=h264_qsv` |
|
||||
| `auto` | Any | Probes the system for available HW encoders. Falls back to `libsvtav1` if no HW encoder is found | `--dataset.camera_encoder_config.vcodec=auto` |
|
||||
|
||||
> [!NOTE]
|
||||
> In order to use the HW accelerated encoders you might need to upgrade your GPU drivers.
|
||||
@@ -100,15 +110,15 @@ Use HW encoding when:
|
||||
|
||||
## 5. Troubleshooting
|
||||
|
||||
| Symptom | Likely Cause | Fix |
|
||||
| ------------------------------------------------------------------ | -------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
|
||||
| System freezes or choppy robot movement or Rerun visualization lag | CPU starved (100% load usage) | Close other apps, reduce encoding throughput, lower `encoder_threads`, use `h264`, use `display_data=False`. If the CPU continues to be at 100% then it might be insufficient for your setup, consider `--dataset.streaming_encoding=false` or HW encoding (`--dataset.vcodec=auto`) |
|
||||
| "Encoder queue full" warnings or dropped frames in dataset | Encoder can't keep up (Queue overflow) | If CPU is not at 100%: Increase `encoder_threads`, increase `encoder_queue_maxsize` or use HW encoding (`--dataset.vcodec=auto`). |
|
||||
| High RAM usage | Queue filling faster than encoding | `encoder_threads` too low or CPU insufficient. Reduce `encoder_queue_maxsize` or use HW encoding |
|
||||
| Large video files | Using HW encoder or H.264 | Expected trade-off. Switch to `libsvtav1` if CPU allows |
|
||||
| `save_episode()` still slow | `streaming_encoding` is `False` | Set `--dataset.streaming_encoding=true` |
|
||||
| Encoder thread crash | Codec not available or invalid settings | Check `vcodec` is installed, try `--dataset.vcodec=auto` |
|
||||
| Recorded dataset is missing frames | CPU/GPU starvation or occasional load spikes | If ~5% of frames are missing, your system is likely overloaded — follow the recommendations above. If fewer frames are missing (~2%), they are probably due to occasional transient load spikes (often at startup) and can be considered expected. |
|
||||
| Symptom | Likely Cause | Fix |
|
||||
| ------------------------------------------------------------------ | -------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| System freezes or choppy robot movement or Rerun visualization lag | CPU starved (100% load usage) | Close other apps, reduce encoding throughput, lower `encoder_threads`, use `h264`, use `display_data=False`. If the CPU continues to be at 100% then it might be insufficient for your setup, consider `--dataset.streaming_encoding=false` or HW encoding (`--dataset.camera_encoder_config.vcodec=auto`) |
|
||||
| "Encoder queue full" warnings or dropped frames in dataset | Encoder can't keep up (Queue overflow) | If CPU is not at 100%: Increase `encoder_threads`, increase `encoder_queue_maxsize` or use HW encoding (`--dataset.camera_encoder_config.vcodec=auto`). |
|
||||
| High RAM usage | Queue filling faster than encoding | `encoder_threads` too low or CPU insufficient. Reduce `encoder_queue_maxsize` or use HW encoding |
|
||||
| Large video files | Using HW encoder or H.264 | Expected trade-off. Switch to `libsvtav1` if CPU allows |
|
||||
| `save_episode()` still slow | `streaming_encoding` is `False` | Set `--dataset.streaming_encoding=true` |
|
||||
| Encoder thread crash | Codec not available or invalid settings | Check `vcodec` is installed, try `--dataset.camera_encoder_config.vcodec=auto` |
|
||||
| Recorded dataset is missing frames | CPU/GPU starvation or occasional load spikes | If ~5% of frames are missing, your system is likely overloaded — follow the recommendations above. If fewer frames are missing (~2%), they are probably due to occasional transient load spikes (often at startup) and can be considered expected. |
|
||||
|
||||
## 6. Recommended Configurations
|
||||
|
||||
@@ -146,10 +156,10 @@ On very constrained systems, streaming encoding may compete too heavily with the
|
||||
# 2camsx 640x480x3 @30fps: Requires some tuning.
|
||||
|
||||
# Use H.264, disable streaming, consider batching encoding
|
||||
lerobot-record --dataset.vcodec=h264 --dataset.streaming_encoding=false ...
|
||||
lerobot-record --dataset.camera_encoder_config.vcodec=h264 --dataset.streaming_encoding=false ...
|
||||
```
|
||||
|
||||
## 7. Closing note
|
||||
|
||||
Performance ultimately depends on your exact setup — frames-per-second, resolution, CPU cores and load, available memory, episode length, and the encoder you choose. Always test with your target workload, be mindful about your CPU & system capabilities and tune `encoder_threads`, `encoder_queue_maxsize`, and
|
||||
`vcodec` reasonably. That said, a common practical configuration (for many applications) is three cameras at 640×480x3 @30fps; this usually runs fine with the default streaming video encoding settings in modern systems. Always verify your recorded dataset is healthy by comparing the video duration to the CLI episode duration and confirming the row count equals FPS × CLI duration.
|
||||
`camera_encoder_config.vcodec` reasonably. That said, a common practical configuration (for many applications) is three cameras at 640×480x3 @30fps; this usually runs fine with the default streaming video encoding settings in modern systems. Always verify your recorded dataset is healthy by comparing the video duration to the CLI episode duration and confirming the row count equals FPS × CLI duration.
|
||||
|
||||
@@ -117,10 +117,10 @@ lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--operation.type convert_image_to_video \
|
||||
--operation.output_dir outputs/pusht_video \
|
||||
--operation.vcodec libsvtav1 \
|
||||
--operation.pix_fmt yuv420p \
|
||||
--operation.g 2 \
|
||||
--operation.crf 30
|
||||
--operation.camera_encoder_config.vcodec libsvtav1 \
|
||||
--operation.camera_encoder_config.pix_fmt yuv420p \
|
||||
--operation.camera_encoder_config.g 2 \
|
||||
--operation.camera_encoder_config.crf 30
|
||||
|
||||
# Convert only specific episodes
|
||||
lerobot-edit-dataset \
|
||||
@@ -147,11 +147,14 @@ lerobot-edit-dataset \
|
||||
**Parameters:**
|
||||
|
||||
- `output_dir`: Custom output directory (optional - by default uses `new_repo_id` or `{repo_id}_video`)
|
||||
- `vcodec`: Video codec to use - options: `h264`, `hevc`, `libsvtav1` (default: `libsvtav1`)
|
||||
- `pix_fmt`: Pixel format - options: `yuv420p`, `yuv444p` (default: `yuv420p`)
|
||||
- `g`: Group of pictures (GOP) size - lower values give better quality but larger files (default: 2)
|
||||
- `crf`: Constant rate factor - lower values give better quality but larger files, 0 is lossless (default: 30)
|
||||
- `fast_decode`: Fast decode tuning option (default: 0)
|
||||
- `camera_encoder_config`: Video encoder settings — all sub-fields accessible via `--operation.camera_encoder_config.<field>`:
|
||||
- `vcodec`: Video codec — `h264`, `hevc`, `libsvtav1`, `auto`, or hardware codecs (default: `libsvtav1`)
|
||||
- `pix_fmt`: Pixel format — `yuv420p`, `yuv444p` (default: `yuv420p`)
|
||||
- `g`: GOP size — lower values give better quality but larger files (default: 2)
|
||||
- `crf`: Quality level — lower is better, 0 is lossless (default: 30)
|
||||
- `preset`: Speed preset, libsvtav1 only (default: 12)
|
||||
- `fast_decode`: Fast-decode tuning (default: 0)
|
||||
- `encoder_threads`: Threads per encoder instance — global setting, separate from `camera_encoder_config` (default: None)
|
||||
- `episode_indices`: List of specific episodes to convert (default: all episodes)
|
||||
- `num_workers`: Number of parallel workers for processing (default: 4)
|
||||
|
||||
|
||||
@@ -133,6 +133,9 @@ class RealSenseCamera(Camera):
|
||||
|
||||
self.rs_pipeline: rs.pipeline | None = None
|
||||
self.rs_profile: rs.pipeline_profile | None = None
|
||||
# Meters per uint16 unit on the depth stream. Queried from the device
|
||||
# at connect() time. Typical D-series value is 0.001 (= 1 mm/unit).
|
||||
self.depth_scale: float | None = None
|
||||
|
||||
self.thread: Thread | None = None
|
||||
self.stop_event: Event | None = None
|
||||
@@ -190,6 +193,17 @@ class RealSenseCamera(Camera):
|
||||
) from e
|
||||
|
||||
self._configure_capture_settings()
|
||||
|
||||
# Query depth scale (meters per uint16 unit) when depth is enabled so
|
||||
# consumers can convert the raw z16 stream to metric distances.
|
||||
if self.use_depth and self.rs_profile is not None:
|
||||
try:
|
||||
depth_sensor = self.rs_profile.get_device().first_depth_sensor()
|
||||
self.depth_scale = float(depth_sensor.get_depth_scale())
|
||||
except RuntimeError as e:
|
||||
logger.warning(f"{self}: failed to query depth scale ({e}); falling back to 0.001 m/unit.")
|
||||
self.depth_scale = 0.001
|
||||
|
||||
self._start_read_thread()
|
||||
|
||||
# NOTE(Steven/Caroline): Enforcing at least one second of warmup as RS cameras need a bit of time before the first read. If we don't wait, the first read from the warmup will raise.
|
||||
@@ -532,7 +546,6 @@ class RealSenseCamera(Camera):
|
||||
self.latest_timestamp = None
|
||||
self.new_frame_event.clear()
|
||||
|
||||
# NOTE(Steven): Missing implementation for depth for now
|
||||
@check_if_not_connected
|
||||
def async_read(self, timeout_ms: float = 200) -> NDArray[Any]:
|
||||
"""
|
||||
@@ -575,7 +588,6 @@ class RealSenseCamera(Camera):
|
||||
|
||||
return frame
|
||||
|
||||
# NOTE(Steven): Missing implementation for depth for now
|
||||
@check_if_not_connected
|
||||
def read_latest(self, max_age_ms: int = 500) -> NDArray[Any]:
|
||||
"""Return the most recent (color) frame captured immediately (Peeking).
|
||||
@@ -611,6 +623,78 @@ class RealSenseCamera(Camera):
|
||||
|
||||
return frame
|
||||
|
||||
|
||||
@check_if_not_connected
|
||||
def async_read_depth(self, timeout_ms: float = 200) -> NDArray[Any]:
|
||||
"""Read the latest depth frame asynchronously, in metric meters.
|
||||
|
||||
Mirrors :meth:`async_read` but returns the depth stream rather than the
|
||||
color stream. Output is ``np.uint16`` of shape ``(H, W)``.
|
||||
|
||||
Raises:
|
||||
DeviceNotConnectedError: If the camera is not connected.
|
||||
RuntimeError: If ``use_depth`` is ``False`` for this camera, or if
|
||||
the background read thread is not running.
|
||||
TimeoutError: If no frame becomes available within ``timeout_ms``.
|
||||
"""
|
||||
if not self.use_depth:
|
||||
raise RuntimeError(
|
||||
f"{self}: cannot read depth — camera was configured with use_depth=False."
|
||||
)
|
||||
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
raise RuntimeError(f"{self} read thread is not running.")
|
||||
|
||||
if not self.new_frame_event.wait(timeout=timeout_ms / 1000.0):
|
||||
raise TimeoutError(
|
||||
f"Timed out waiting for depth frame from camera {self} after {timeout_ms} ms."
|
||||
)
|
||||
|
||||
with self.frame_lock:
|
||||
depth_frame = self.latest_depth_frame
|
||||
self.new_frame_event.clear()
|
||||
|
||||
if depth_frame is None:
|
||||
raise RuntimeError(f"Internal error: Event set but no depth frame available for {self}.")
|
||||
|
||||
return depth_frame
|
||||
|
||||
@check_if_not_connected
|
||||
def read_latest_depth(self, max_age_ms: int = 500) -> NDArray[Any]:
|
||||
"""Return the most recent depth frame in metric meters (peeking).
|
||||
|
||||
Non-blocking counterpart of :meth:`read_latest` for the depth stream.
|
||||
Output is ``np.float32`` of shape ``(H, W)`` in meters.
|
||||
|
||||
Raises:
|
||||
DeviceNotConnectedError: If the camera is not connected.
|
||||
RuntimeError: If ``use_depth`` is ``False`` for this camera, or if
|
||||
no depth frame has been captured yet.
|
||||
TimeoutError: If the latest depth frame is older than ``max_age_ms``.
|
||||
"""
|
||||
if not self.use_depth:
|
||||
raise RuntimeError(
|
||||
f"{self}: cannot read depth — camera was configured with use_depth=False."
|
||||
)
|
||||
|
||||
if self.thread is None or not self.thread.is_alive():
|
||||
raise RuntimeError(f"{self} read thread is not running.")
|
||||
|
||||
with self.frame_lock:
|
||||
depth_frame = self.latest_depth_frame
|
||||
timestamp = self.latest_timestamp
|
||||
|
||||
if depth_frame is None or timestamp is None:
|
||||
raise RuntimeError(f"{self} has not captured any depth frames yet.")
|
||||
|
||||
age_ms = (time.perf_counter() - timestamp) * 1e3
|
||||
if age_ms > max_age_ms:
|
||||
raise TimeoutError(
|
||||
f"{self} latest depth frame is too old: {age_ms:.1f} ms (max allowed: {max_age_ms} ms)."
|
||||
)
|
||||
|
||||
return depth_frame
|
||||
|
||||
def disconnect(self) -> None:
|
||||
"""
|
||||
Disconnects from the camera, stops the pipeline, and cleans up resources.
|
||||
@@ -634,6 +718,8 @@ class RealSenseCamera(Camera):
|
||||
self.rs_pipeline = None
|
||||
self.rs_profile = None
|
||||
|
||||
self.depth_scale = None
|
||||
|
||||
with self.frame_lock:
|
||||
self.latest_color_frame = None
|
||||
self.latest_depth_frame = None
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.transforms import ImageTransformsConfig
|
||||
from lerobot.utils.import_utils import get_safe_default_codec
|
||||
from lerobot.utils.import_utils import get_safe_default_video_backend
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -34,7 +34,7 @@ class DatasetConfig:
|
||||
image_transforms: ImageTransformsConfig = field(default_factory=ImageTransformsConfig)
|
||||
revision: str | None = None
|
||||
use_imagenet_stats: bool = True
|
||||
video_backend: str = field(default_factory=get_safe_default_codec)
|
||||
video_backend: str = field(default_factory=get_safe_default_video_backend)
|
||||
# When True, video frames are returned as uint8 tensors (0-255) instead of float32 (0.0-1.0).
|
||||
# This reduces memory and speeds up DataLoader IPC. The training pipeline handles the conversion.
|
||||
return_uint8: bool = False
|
||||
|
||||
@@ -40,10 +40,21 @@ from .io_utils import load_episodes, write_stats
|
||||
from .lerobot_dataset import LeRobotDataset
|
||||
from .multi_dataset import MultiLeRobotDataset
|
||||
from .pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
||||
from .pyav_utils import (
|
||||
check_video_encoder_config_pyav,
|
||||
detect_available_encoders_pyav,
|
||||
get_codec,
|
||||
)
|
||||
from .sampler import EpisodeAwareSampler
|
||||
from .streaming_dataset import StreamingLeRobotDataset
|
||||
from .utils import DEFAULT_EPISODES_PATH, create_lerobot_dataset_card
|
||||
from .video_utils import VideoEncodingManager
|
||||
from .video_utils import (
|
||||
DepthEncoderConfig,
|
||||
VideoEncoderConfig,
|
||||
VideoEncodingManager,
|
||||
camera_encoder_defaults,
|
||||
depth_encoder_defaults,
|
||||
)
|
||||
|
||||
# NOTE: Low-level I/O functions (cast_stats_to_numpy, get_parquet_file_size_in_mb, etc.)
|
||||
# and legacy migration constants are intentionally NOT re-exported here.
|
||||
@@ -58,15 +69,22 @@ __all__ = [
|
||||
"LeRobotDatasetMetadata",
|
||||
"MultiLeRobotDataset",
|
||||
"StreamingLeRobotDataset",
|
||||
"DepthEncoderConfig",
|
||||
"VideoEncoderConfig",
|
||||
"VideoEncodingManager",
|
||||
"camera_encoder_defaults",
|
||||
"depth_encoder_defaults",
|
||||
"add_features",
|
||||
"aggregate_datasets",
|
||||
"aggregate_pipeline_dataset_features",
|
||||
"aggregate_stats",
|
||||
"check_video_encoder_config_pyav",
|
||||
"convert_image_to_video_dataset",
|
||||
"create_initial_features",
|
||||
"create_lerobot_dataset_card",
|
||||
"delete_episodes",
|
||||
"detect_available_encoders_pyav",
|
||||
"get_codec",
|
||||
"get_feature_stats",
|
||||
"load_episodes",
|
||||
"make_dataset",
|
||||
|
||||
@@ -332,7 +332,6 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
|
||||
videos_idx: Dictionary tracking video chunk and file indices.
|
||||
video_files_size_in_mb: Maximum size for video files in MB (defaults to DEFAULT_VIDEO_FILE_SIZE_IN_MB)
|
||||
chunk_size: Maximum number of files per chunk (defaults to DEFAULT_CHUNK_SIZE)
|
||||
|
||||
Returns:
|
||||
dict: Updated videos_idx with current chunk and file indices.
|
||||
"""
|
||||
@@ -417,6 +416,7 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
|
||||
concatenate_video_files(
|
||||
[dst_path, src_path],
|
||||
dst_path,
|
||||
compatibility_check=True,
|
||||
)
|
||||
# Update duration of this destination file
|
||||
dst_file_durations[dst_key] = current_dst_duration + src_duration
|
||||
|
||||
@@ -48,7 +48,7 @@ from .utils import (
|
||||
is_valid_version,
|
||||
update_chunk_file_indices,
|
||||
)
|
||||
from .video_utils import get_video_info
|
||||
from .video_utils import VideoEncoderConfig, get_video_info
|
||||
|
||||
CODEBASE_VERSION = "v3.0"
|
||||
|
||||
@@ -313,6 +313,20 @@ class LeRobotDatasetMetadata:
|
||||
"""Keys to access visual modalities stored as videos."""
|
||||
return [key for key, ft in self.features.items() if ft["dtype"] == "video"]
|
||||
|
||||
@property
|
||||
def depth_keys(self) -> list[str]:
|
||||
"""Keys to access depth-map modalities stored as videos.
|
||||
|
||||
A depth video key is a feature whose ``info`` dict carries
|
||||
``"video.is_depth_map": True`` (set either at creation time by the user
|
||||
or after the first encoded episode by :meth:`update_video_info`).
|
||||
"""
|
||||
return [
|
||||
key
|
||||
for key, ft in self.features.items()
|
||||
if ft["dtype"] == "video" and ft.get("info", {}).get("video.is_depth_map", False)
|
||||
]
|
||||
|
||||
@property
|
||||
def camera_keys(self) -> list[str]:
|
||||
"""Keys to access visual modalities (regardless of their storage method)."""
|
||||
@@ -510,19 +524,38 @@ class LeRobotDatasetMetadata:
|
||||
self.stats = aggregate_stats([self.stats, episode_stats]) if self.stats is not None else episode_stats
|
||||
write_stats(self.stats, self.root)
|
||||
|
||||
def update_video_info(self, video_key: str | None = None) -> None:
|
||||
"""
|
||||
def update_video_info(
|
||||
self,
|
||||
video_key: str | None = None,
|
||||
camera_encoder_config: VideoEncoderConfig | None = None,
|
||||
) -> None:
|
||||
"""Populate per-feature video info in ``info.json``.
|
||||
|
||||
Warning: this function writes info from first episode videos, implicitly assuming that all videos have
|
||||
been encoded the same way. Also, this means it assumes the first episode exists.
|
||||
|
||||
Args:
|
||||
video_key: If provided, only update this video key. Otherwise update
|
||||
all video keys in the dataset.
|
||||
camera_encoder_config: Encoder configuration used to produce the
|
||||
videos. When provided, its fields are recorded as
|
||||
``video.<field>`` entries alongside the stream-derived
|
||||
``video.*`` entries (see :func:`get_video_info`).
|
||||
"""
|
||||
if video_key is not None and video_key not in self.video_keys:
|
||||
raise ValueError(f"Video key {video_key} not found in dataset")
|
||||
|
||||
video_keys = [video_key] if video_key is not None else self.video_keys
|
||||
for key in video_keys:
|
||||
if not self.features[key].get("info", None):
|
||||
existing = self.features[key].get("info") or {}
|
||||
# Repopulate when codec metadata is missing — preserves user-provided
|
||||
# markers like ``video.is_depth_map`` while still recording stream
|
||||
# info on the first episode.
|
||||
if not existing or "video.codec" not in existing:
|
||||
video_path = self.root / self.video_path.format(video_key=key, chunk_index=0, file_index=0)
|
||||
self.info.features[key]["info"] = get_video_info(video_path)
|
||||
stream_info = get_video_info(video_path, camera_encoder_config=camera_encoder_config)
|
||||
merged = {**existing, **stream_info}
|
||||
self.info.features[key]["info"] = merged
|
||||
|
||||
def update_chunk_settings(
|
||||
self,
|
||||
|
||||
@@ -32,7 +32,13 @@ from .io_utils import (
|
||||
hf_transform_to_torch,
|
||||
load_nested_dataset,
|
||||
)
|
||||
from .video_utils import decode_video_frames
|
||||
from .video_utils import decode_depth_frames, decode_video_frames
|
||||
from .depth_utils import (
|
||||
DEFAULT_DEPTH_MIN,
|
||||
DEFAULT_DEPTH_MAX,
|
||||
DEFAULT_DEPTH_SHIFT,
|
||||
DEFAULT_DEPTH_USE_LOG,
|
||||
)
|
||||
|
||||
|
||||
class DatasetReader:
|
||||
@@ -237,17 +243,31 @@ class DatasetReader:
|
||||
"""
|
||||
ep = self._meta.episodes[ep_idx]
|
||||
|
||||
depth_keys = set(self._meta.depth_keys)
|
||||
|
||||
def _decode_single(vid_key: str, query_ts: list[float]) -> tuple[str, torch.Tensor]:
|
||||
from_timestamp = ep[f"videos/{vid_key}/from_timestamp"]
|
||||
shifted_query_ts = [from_timestamp + ts for ts in query_ts]
|
||||
video_path = self.root / self._meta.get_video_file_path(ep_idx, vid_key)
|
||||
frames = decode_video_frames(
|
||||
video_path,
|
||||
shifted_query_ts,
|
||||
self._tolerance_s,
|
||||
self._video_backend,
|
||||
return_uint8=self._return_uint8,
|
||||
)
|
||||
if vid_key in depth_keys:
|
||||
feature_info = self._meta.features[vid_key].get("info") or {}
|
||||
frames = decode_depth_frames(
|
||||
video_path,
|
||||
shifted_query_ts,
|
||||
self._tolerance_s,
|
||||
depth_min=feature_info.get("video.depth_min", DEFAULT_DEPTH_MIN),
|
||||
depth_max=feature_info.get("video.depth_max", DEFAULT_DEPTH_MAX),
|
||||
shift=feature_info.get("video.shift", DEFAULT_DEPTH_SHIFT),
|
||||
use_log=feature_info.get("video.use_log", DEFAULT_DEPTH_USE_LOG),
|
||||
)
|
||||
else:
|
||||
frames = decode_video_frames(
|
||||
video_path,
|
||||
shifted_query_ts,
|
||||
self._tolerance_s,
|
||||
self._video_backend,
|
||||
return_uint8=self._return_uint8,
|
||||
)
|
||||
return vid_key, frames.squeeze(0)
|
||||
|
||||
items = list(query_timestamps.items())
|
||||
|
||||
@@ -62,7 +62,7 @@ from .utils import (
|
||||
DEFAULT_EPISODES_PATH,
|
||||
update_chunk_file_indices,
|
||||
)
|
||||
from .video_utils import encode_video_frames, get_video_info
|
||||
from .video_utils import VideoEncoderConfig, encode_video_frames, get_video_info
|
||||
|
||||
|
||||
def _load_episode_with_stats(src_dataset: LeRobotDataset, episode_idx: int) -> dict:
|
||||
@@ -92,6 +92,7 @@ def delete_episodes(
|
||||
episode_indices: list[int],
|
||||
output_dir: str | Path | None = None,
|
||||
repo_id: str | None = None,
|
||||
camera_encoder_config: VideoEncoderConfig | None = None,
|
||||
) -> LeRobotDataset:
|
||||
"""Delete episodes from a LeRobotDataset and create a new dataset.
|
||||
|
||||
@@ -100,6 +101,7 @@ def delete_episodes(
|
||||
episode_indices: List of episode indices to delete.
|
||||
output_dir: Root directory where the edited dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id. Equivalent to new_root in EditDatasetConfig.
|
||||
repo_id: Edited dataset identifier. Equivalent to new_repo_id in EditDatasetConfig.
|
||||
camera_encoder_config: Video encoder settings used when re-encoding video segments (default: :class:`VideoEncoderConfig()`).
|
||||
"""
|
||||
if not episode_indices:
|
||||
raise ValueError("No episodes to delete")
|
||||
@@ -132,7 +134,7 @@ def delete_episodes(
|
||||
|
||||
video_metadata = None
|
||||
if dataset.meta.video_keys:
|
||||
video_metadata = _copy_and_reindex_videos(dataset, new_meta, episode_mapping)
|
||||
video_metadata = _copy_and_reindex_videos(dataset, new_meta, episode_mapping, camera_encoder_config)
|
||||
|
||||
data_metadata = _copy_and_reindex_data(dataset, new_meta, episode_mapping)
|
||||
|
||||
@@ -154,6 +156,7 @@ def split_dataset(
|
||||
dataset: LeRobotDataset,
|
||||
splits: dict[str, float | list[int]],
|
||||
output_dir: str | Path | None = None,
|
||||
camera_encoder_config: VideoEncoderConfig | None = None,
|
||||
) -> dict[str, LeRobotDataset]:
|
||||
"""Split a LeRobotDataset into multiple smaller datasets.
|
||||
|
||||
@@ -162,6 +165,7 @@ def split_dataset(
|
||||
splits: Either a dict mapping split names to episode indices, or a dict mapping
|
||||
split names to fractions (must sum to <= 1.0).
|
||||
output_dir: Root directory where the split datasets will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id.
|
||||
camera_encoder_config: Video encoder settings used when re-encoding video segments (default: :class:`VideoEncoderConfig()`).
|
||||
|
||||
Examples:
|
||||
Split by specific episodes
|
||||
@@ -222,7 +226,9 @@ def split_dataset(
|
||||
|
||||
video_metadata = None
|
||||
if dataset.meta.video_keys:
|
||||
video_metadata = _copy_and_reindex_videos(dataset, new_meta, episode_mapping)
|
||||
video_metadata = _copy_and_reindex_videos(
|
||||
dataset, new_meta, episode_mapping, camera_encoder_config
|
||||
)
|
||||
|
||||
data_metadata = _copy_and_reindex_data(dataset, new_meta, episode_mapping)
|
||||
|
||||
@@ -578,8 +584,7 @@ def _keep_episodes_from_video_with_av(
|
||||
output_path: Path,
|
||||
episodes_to_keep: list[tuple[int, int]],
|
||||
fps: float,
|
||||
vcodec: str = "libsvtav1",
|
||||
pix_fmt: str = "yuv420p",
|
||||
camera_encoder_config: VideoEncoderConfig | None = None,
|
||||
) -> None:
|
||||
"""Keep only specified episodes from a video file using PyAV.
|
||||
|
||||
@@ -593,9 +598,10 @@ def _keep_episodes_from_video_with_av(
|
||||
Ranges are half-open intervals: [start_frame, end_frame), where start_frame
|
||||
is inclusive and end_frame is exclusive.
|
||||
fps: Frame rate of the video.
|
||||
vcodec: Video codec to use for encoding.
|
||||
pix_fmt: Pixel format for output video.
|
||||
camera_encoder_config: Video encoder settings (default: :class:`VideoEncoderConfig()`).
|
||||
"""
|
||||
if camera_encoder_config is None:
|
||||
camera_encoder_config = VideoEncoderConfig()
|
||||
from fractions import Fraction
|
||||
|
||||
import av
|
||||
@@ -619,12 +625,12 @@ def _keep_episodes_from_video_with_av(
|
||||
|
||||
# Convert fps to Fraction for PyAV compatibility.
|
||||
fps_fraction = Fraction(fps).limit_denominator(1000)
|
||||
v_out = out.add_stream(vcodec, rate=fps_fraction)
|
||||
v_out = out.add_stream(camera_encoder_config.vcodec, rate=fps_fraction)
|
||||
|
||||
# PyAV type stubs don't distinguish video streams from audio/subtitle streams.
|
||||
v_out.width = v_in.codec_context.width
|
||||
v_out.height = v_in.codec_context.height
|
||||
v_out.pix_fmt = pix_fmt
|
||||
v_out.pix_fmt = camera_encoder_config.pix_fmt
|
||||
|
||||
# Set time_base to match the frame rate for proper timestamp handling.
|
||||
v_out.time_base = Fraction(1, int(fps))
|
||||
@@ -687,8 +693,7 @@ def _copy_and_reindex_videos(
|
||||
src_dataset: LeRobotDataset,
|
||||
dst_meta: LeRobotDatasetMetadata,
|
||||
episode_mapping: dict[int, int],
|
||||
vcodec: str = "libsvtav1",
|
||||
pix_fmt: str = "yuv420p",
|
||||
camera_encoder_config: VideoEncoderConfig | None = None,
|
||||
) -> dict[int, dict]:
|
||||
"""Copy and filter video files, only re-encoding files with deleted episodes.
|
||||
|
||||
@@ -700,10 +705,13 @@ def _copy_and_reindex_videos(
|
||||
src_dataset: Source dataset to copy from
|
||||
dst_meta: Destination metadata object
|
||||
episode_mapping: Mapping from old episode indices to new indices
|
||||
camera_encoder_config: Video encoder settings used when re-encoding segments (default: :class:`VideoEncoderConfig()`).
|
||||
|
||||
Returns:
|
||||
dict mapping episode index to its video metadata (chunk_index, file_index, timestamps)
|
||||
"""
|
||||
if camera_encoder_config is None:
|
||||
camera_encoder_config = VideoEncoderConfig()
|
||||
if src_dataset.meta.episodes is None:
|
||||
src_dataset.meta.episodes = load_episodes(src_dataset.meta.root)
|
||||
|
||||
@@ -792,8 +800,7 @@ def _copy_and_reindex_videos(
|
||||
dst_video_path,
|
||||
episodes_to_keep_ranges,
|
||||
src_dataset.meta.fps,
|
||||
vcodec,
|
||||
pix_fmt,
|
||||
camera_encoder_config,
|
||||
)
|
||||
|
||||
cumulative_ts = 0.0
|
||||
@@ -1264,11 +1271,7 @@ def _estimate_frame_size_via_calibration(
|
||||
episode_indices: list[int],
|
||||
temp_dir: Path,
|
||||
fps: int,
|
||||
vcodec: str,
|
||||
pix_fmt: str,
|
||||
g: int,
|
||||
crf: int,
|
||||
fast_decode: int,
|
||||
camera_encoder_config: VideoEncoderConfig,
|
||||
num_calibration_frames: int = 30,
|
||||
) -> float:
|
||||
"""Estimate MB per frame by encoding a small calibration sample.
|
||||
@@ -1282,11 +1285,7 @@ def _estimate_frame_size_via_calibration(
|
||||
episode_indices: List of episode indices being processed.
|
||||
temp_dir: Temporary directory for calibration files.
|
||||
fps: Frames per second for video encoding.
|
||||
vcodec: Video codec (libsvtav1, h264, hevc).
|
||||
pix_fmt: Pixel format (yuv420p, etc.).
|
||||
g: GOP size (group of pictures).
|
||||
crf: Constant Rate Factor (quality).
|
||||
fast_decode: Fast decode tuning parameter.
|
||||
camera_encoder_config: Video encoder settings used for calibration encoding.
|
||||
num_calibration_frames: Number of frames to use for calibration (default: 30).
|
||||
|
||||
Returns:
|
||||
@@ -1322,11 +1321,7 @@ def _estimate_frame_size_via_calibration(
|
||||
imgs_dir=calibration_dir,
|
||||
video_path=calibration_video_path,
|
||||
fps=fps,
|
||||
vcodec=vcodec,
|
||||
pix_fmt=pix_fmt,
|
||||
g=g,
|
||||
crf=crf,
|
||||
fast_decode=fast_decode,
|
||||
camera_encoder_config=camera_encoder_config,
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
@@ -1644,11 +1639,7 @@ def convert_image_to_video_dataset(
|
||||
dataset: LeRobotDataset,
|
||||
output_dir: Path | None = None,
|
||||
repo_id: str | None = None,
|
||||
vcodec: str = "libsvtav1",
|
||||
pix_fmt: str = "yuv420p",
|
||||
g: int = 2,
|
||||
crf: int = 30,
|
||||
fast_decode: int = 0,
|
||||
camera_encoder_config: VideoEncoderConfig | None = None,
|
||||
episode_indices: list[int] | None = None,
|
||||
num_workers: int = 4,
|
||||
max_episodes_per_batch: int | None = None,
|
||||
@@ -1663,11 +1654,7 @@ def convert_image_to_video_dataset(
|
||||
dataset: The source LeRobot dataset with images
|
||||
output_dir: Root directory where the edited dataset will be stored. If not specified, defaults to $HF_LEROBOT_HOME/repo_id. Equivalent to new_root in EditDatasetConfig.
|
||||
repo_id: Edited dataset identifier. Equivalent to new_repo_id in EditDatasetConfig.
|
||||
vcodec: Video codec (default: libsvtav1)
|
||||
pix_fmt: Pixel format (default: yuv420p)
|
||||
g: Group of pictures size (default: 2)
|
||||
crf: Constant rate factor (default: 30)
|
||||
fast_decode: Fast decode tuning (default: 0)
|
||||
camera_encoder_config: Video encoder settings (default: :class:`VideoEncoderConfig()`).
|
||||
episode_indices: List of episode indices to convert (None = all episodes)
|
||||
num_workers: Number of threads for parallel processing (default: 4)
|
||||
max_episodes_per_batch: Maximum episodes per video batch to avoid memory issues (None = no limit)
|
||||
@@ -1676,6 +1663,9 @@ def convert_image_to_video_dataset(
|
||||
Returns:
|
||||
New LeRobotDataset with images encoded as videos
|
||||
"""
|
||||
if camera_encoder_config is None:
|
||||
camera_encoder_config = VideoEncoderConfig()
|
||||
|
||||
# Check that it's an image dataset
|
||||
if len(dataset.meta.video_keys) > 0:
|
||||
raise ValueError(
|
||||
@@ -1699,7 +1689,10 @@ def convert_image_to_video_dataset(
|
||||
logging.info(
|
||||
f"Converting {len(episode_indices)} episodes with {len(img_keys)} cameras from {dataset.repo_id}"
|
||||
)
|
||||
logging.info(f"Video codec: {vcodec}, pixel format: {pix_fmt}, GOP: {g}, CRF: {crf}")
|
||||
logging.info(
|
||||
f"Video codec: {camera_encoder_config.vcodec}, pixel format: {camera_encoder_config.pix_fmt}, "
|
||||
f"GOP: {camera_encoder_config.g}, CRF: {camera_encoder_config.crf}"
|
||||
)
|
||||
|
||||
# Create new features dict, converting image features to video features
|
||||
new_features = {}
|
||||
@@ -1769,11 +1762,7 @@ def convert_image_to_video_dataset(
|
||||
episode_indices=episode_indices,
|
||||
temp_dir=temp_dir,
|
||||
fps=fps,
|
||||
vcodec=vcodec,
|
||||
pix_fmt=pix_fmt,
|
||||
g=g,
|
||||
crf=crf,
|
||||
fast_decode=fast_decode,
|
||||
camera_encoder_config=camera_encoder_config,
|
||||
)
|
||||
|
||||
logging.info(f"Processing camera: {img_key}")
|
||||
@@ -1815,11 +1804,7 @@ def convert_image_to_video_dataset(
|
||||
imgs_dir=imgs_dir,
|
||||
video_path=video_path,
|
||||
fps=fps,
|
||||
vcodec=vcodec,
|
||||
pix_fmt=pix_fmt,
|
||||
g=g,
|
||||
crf=crf,
|
||||
fast_decode=fast_decode,
|
||||
camera_encoder_config=camera_encoder_config,
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
@@ -1865,7 +1850,9 @@ def convert_image_to_video_dataset(
|
||||
video_path = new_meta.root / new_meta.video_path.format(
|
||||
video_key=img_key, chunk_index=0, file_index=0
|
||||
)
|
||||
new_meta.info.features[img_key]["info"] = get_video_info(video_path)
|
||||
new_meta.info.features[img_key]["info"] = get_video_info(
|
||||
video_path, camera_encoder_config=camera_encoder_config
|
||||
)
|
||||
|
||||
write_info(new_meta.info, new_meta.root)
|
||||
|
||||
|
||||
@@ -46,15 +46,19 @@ from .io_utils import (
|
||||
write_info,
|
||||
)
|
||||
from .utils import (
|
||||
DEFAULT_DEPTH_PATH,
|
||||
DEFAULT_EPISODES_PATH,
|
||||
DEFAULT_IMAGE_PATH,
|
||||
update_chunk_file_indices,
|
||||
)
|
||||
from .video_utils import (
|
||||
DepthEncoderConfig,
|
||||
StreamingVideoEncoder,
|
||||
VideoEncoderConfig,
|
||||
concatenate_video_files,
|
||||
encode_video_frames,
|
||||
get_video_duration_in_s,
|
||||
is_depth_feature,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -65,14 +69,19 @@ def _encode_video_worker(
|
||||
episode_index: int,
|
||||
root: Path,
|
||||
fps: int,
|
||||
vcodec: str = "libsvtav1",
|
||||
camera_encoder_config: VideoEncoderConfig | None = None,
|
||||
encoder_threads: int | None = None,
|
||||
) -> Path:
|
||||
temp_path = Path(tempfile.mkdtemp(dir=root)) / f"{video_key}_{episode_index:03d}.mp4"
|
||||
fpath = DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=episode_index, frame_index=0)
|
||||
img_dir = (root / fpath).parent
|
||||
encode_video_frames(
|
||||
img_dir, temp_path, fps, vcodec=vcodec, overwrite=True, encoder_threads=encoder_threads
|
||||
img_dir,
|
||||
temp_path,
|
||||
fps,
|
||||
camera_encoder_config=camera_encoder_config,
|
||||
encoder_threads=encoder_threads,
|
||||
overwrite=True,
|
||||
)
|
||||
shutil.rmtree(img_dir)
|
||||
return temp_path
|
||||
@@ -89,33 +98,40 @@ class DatasetWriter:
|
||||
self,
|
||||
meta: LeRobotDatasetMetadata,
|
||||
root: Path,
|
||||
vcodec: str,
|
||||
camera_encoder_config: VideoEncoderConfig,
|
||||
encoder_threads: int | None,
|
||||
batch_encoding_size: int,
|
||||
streaming_encoder: StreamingVideoEncoder | None = None,
|
||||
initial_frames: int = 0,
|
||||
depth_encoder_config: DepthEncoderConfig | None = None,
|
||||
):
|
||||
"""Initialize the writer with metadata, codec, and encoding config.
|
||||
"""Initialize the writer with metadata, codec, and encoder config.
|
||||
|
||||
Args:
|
||||
meta: Dataset metadata instance (used for feature schema, chunk
|
||||
settings, and episode persistence).
|
||||
root: Local dataset root directory.
|
||||
vcodec: Video codec for encoding (e.g. ``'libsvtav1'``, ``'h264'``).
|
||||
encoder_threads: Threads per encoder instance. ``None`` for auto.
|
||||
camera_encoder_config: Video encoder settings applied to all cameras.
|
||||
encoder_threads: Number of encoder threads (global). ``None``
|
||||
lets the codec decide.
|
||||
batch_encoding_size: Number of episodes to accumulate before
|
||||
batch-encoding videos.
|
||||
streaming_encoder: Optional pre-built :class:`StreamingVideoEncoder`
|
||||
for real-time encoding. ``None`` disables streaming mode.
|
||||
initial_frames: Starting frame count (non-zero when resuming).
|
||||
depth_encoder_config: Optional depth-map encoder config used in
|
||||
place of ``camera_encoder_config`` for keys present in
|
||||
``meta.depth_keys``.
|
||||
"""
|
||||
self._meta = meta
|
||||
self._root = root
|
||||
self._vcodec = vcodec
|
||||
self._camera_encoder_config = camera_encoder_config
|
||||
self._depth_encoder_config = depth_encoder_config
|
||||
self._encoder_threads = encoder_threads
|
||||
self._batch_encoding_size = batch_encoding_size
|
||||
self._streaming_encoder = streaming_encoder
|
||||
|
||||
|
||||
# Writer state
|
||||
self.image_writer: AsyncImageWriter | None = None
|
||||
self.episode_buffer: dict = self._create_episode_buffer()
|
||||
@@ -135,8 +151,16 @@ class DatasetWriter:
|
||||
ep_buffer[key] = current_ep_idx if key == "episode_index" else []
|
||||
return ep_buffer
|
||||
|
||||
def _is_depth_image_key(self, image_key: str) -> bool:
|
||||
"""Whether *image_key* is a depth feature stored as per-frame images."""
|
||||
ft = self._meta.features.get(image_key)
|
||||
if ft is None or ft.get("dtype") != "image":
|
||||
return False
|
||||
return is_depth_feature(ft.get("info") or {})
|
||||
|
||||
def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path:
|
||||
fpath = DEFAULT_IMAGE_PATH.format(
|
||||
path_template = DEFAULT_DEPTH_PATH if self._is_depth_image_key(image_key) else DEFAULT_IMAGE_PATH
|
||||
fpath = path_template.format(
|
||||
image_key=image_key, episode_index=episode_index, frame_index=frame_index
|
||||
)
|
||||
return self._root / fpath
|
||||
@@ -284,7 +308,7 @@ class DatasetWriter:
|
||||
episode_index,
|
||||
self._root,
|
||||
self._meta.fps,
|
||||
self._vcodec,
|
||||
self._camera_encoder_config,
|
||||
self._encoder_threads,
|
||||
): video_key
|
||||
for video_key in self._meta.video_keys
|
||||
@@ -495,7 +519,13 @@ class DatasetWriter:
|
||||
|
||||
# Update video info (only needed when first episode is encoded)
|
||||
if episode_index == 0:
|
||||
self._meta.update_video_info(video_key)
|
||||
is_depth_key = video_key in set(self._meta.depth_keys)
|
||||
cfg_for_info = (
|
||||
self._depth_encoder_config
|
||||
if is_depth_key and self._depth_encoder_config is not None
|
||||
else self._camera_encoder_config
|
||||
)
|
||||
self._meta.update_video_info(video_key, camera_encoder_config=cfg_for_info)
|
||||
write_info(self._meta.info, self._meta.root)
|
||||
|
||||
metadata = {
|
||||
@@ -564,7 +594,12 @@ class DatasetWriter:
|
||||
def _encode_temporary_episode_video(self, video_key: str, episode_index: int) -> Path:
|
||||
"""Use ffmpeg to convert frames stored as png into mp4 videos."""
|
||||
return _encode_video_worker(
|
||||
video_key, episode_index, self._root, self._meta.fps, self._vcodec, self._encoder_threads
|
||||
video_key,
|
||||
episode_index,
|
||||
self._root,
|
||||
self._meta.fps,
|
||||
self._camera_encoder_config,
|
||||
self._encoder_threads,
|
||||
)
|
||||
|
||||
def close_writer(self) -> None:
|
||||
|
||||
189
src/lerobot/datasets/depth_utils.py
Normal file
189
src/lerobot/datasets/depth_utils.py
Normal file
@@ -0,0 +1,189 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Depth encoding/decoding helpers for :class:`VideoEncoderConfig`.
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import Literal
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from numpy.typing import NDArray
|
||||
|
||||
DEPTH_QUANT_BITS: int = 12
|
||||
DEPTH_QMAX: int = (1 << DEPTH_QUANT_BITS) - 1 # 4095
|
||||
_MM_PER_METRE: float = 1000.0
|
||||
_UINT16_MAX: int = 65535
|
||||
|
||||
DEFAULT_DEPTH_MIN: float = 0.01
|
||||
DEFAULT_DEPTH_MAX: float = 10.0
|
||||
DEFAULT_DEPTH_SHIFT: float = 3.5
|
||||
DEFAULT_DEPTH_USE_LOG: bool = True
|
||||
|
||||
|
||||
def _validate_log_quant_params(depth_min: float, shift: float) -> None:
|
||||
"""Ensure ``log(depth_min + shift)`` is finite."""
|
||||
if depth_min + shift <= 0:
|
||||
raise ValueError(
|
||||
f"depth_min + shift must be positive for logarithmic quantization, "
|
||||
f"got depth_min={depth_min} + shift={shift} = {depth_min + shift}"
|
||||
)
|
||||
|
||||
|
||||
def _depth_input_to_float32_and_unit(
|
||||
depth: NDArray[np.uint16] | NDArray[np.floating] | torch.Tensor,
|
||||
input_unit: Literal["auto", "m", "mm"],
|
||||
) -> tuple[NDArray[np.float32], Literal["m", "mm"]]:
|
||||
"""Depth as float32 in the chosen unit, plus the resolved unit."""
|
||||
if isinstance(depth, torch.Tensor):
|
||||
t = depth.detach().cpu()
|
||||
arr = t.numpy()
|
||||
is_floating = t.is_floating_point()
|
||||
else:
|
||||
arr = np.asarray(depth)
|
||||
is_floating = np.issubdtype(arr.dtype, np.floating)
|
||||
|
||||
resolved_unit: Literal["m", "mm"]
|
||||
if input_unit == "auto":
|
||||
resolved_unit = "m" if is_floating else "mm"
|
||||
else:
|
||||
resolved_unit = input_unit
|
||||
|
||||
# Convert to float32 to keep typing consistency
|
||||
return np.asarray(arr, dtype=np.float32, order="K"), resolved_unit
|
||||
|
||||
|
||||
def quantize_depth(
|
||||
depth: NDArray[np.uint16] | NDArray[np.floating] | torch.Tensor,
|
||||
depth_min: float = DEFAULT_DEPTH_MIN,
|
||||
depth_max: float = DEFAULT_DEPTH_MAX,
|
||||
shift: float = DEFAULT_DEPTH_SHIFT,
|
||||
use_log: bool = DEFAULT_DEPTH_USE_LOG,
|
||||
*,
|
||||
input_unit: Literal["auto", "m", "mm"] = "auto",
|
||||
) -> NDArray[np.uint16]:
|
||||
"""Quantize depth to 12-bit codes (``uint16``, values ``0…DEPTH_QMAX``).
|
||||
|
||||
Depth maps are packed into 12-bit integer frames so they fit in standard
|
||||
high-bit-depth pixel formats (e.g. ``yuv420p12le`` / ``gray12le``)
|
||||
and can be encoded by widely supported video codecs (HEVC Main 12, ffv1).
|
||||
Logarithmic quantization is the default because it allocates more quanta
|
||||
to near-range depth, which matches the (1/depth) error profile of typical
|
||||
depth sensors. Math is ported from BEHAVIOR-1K's ``obs_utils.py``.
|
||||
|
||||
**Input units**:
|
||||
|
||||
- ``input_unit="auto"`` (default): infer from dtype (floating = m, non-floating = mm).
|
||||
- ``input_unit="mm"``: interpret input values as millimetres.
|
||||
- ``input_unit="m"``: interpret input values as metres.
|
||||
|
||||
Quantization math runs in the **resolved input unit**.
|
||||
|
||||
``depth_min``, ``depth_max``, and ``shift`` are always in **metres**.
|
||||
|
||||
Args:
|
||||
depth: Depth map; ``torch.Tensor`` is moved to CPU for conversion.
|
||||
depth_min: Depth (metres) at quantum ``0``.
|
||||
depth_max: Depth (metres) at quantum :data:`DEPTH_QMAX`.
|
||||
shift: Depth shift (metres); used in log mode. Must satisfy ``depth_min + shift > 0``.
|
||||
use_log: If ``True`` (default), quantize in log space.
|
||||
input_unit: Input unit policy (``"auto"``, ``"mm"``, ``"m"``).
|
||||
|
||||
Returns:
|
||||
``numpy.ndarray``, ``dtype=uint16``, same shape as ``depth``, values in
|
||||
``[0, DEPTH_QMAX]``.
|
||||
|
||||
Raises:
|
||||
ValueError: If ``input_unit`` is not ``"auto"``, ``"mm"``, or ``"m"``.
|
||||
ValueError: If ``use_log=True`` and ``depth_min + shift <= 0``.
|
||||
"""
|
||||
if input_unit not in ("auto", "m", "mm"):
|
||||
raise ValueError(f"input_unit must be 'auto', 'm', or 'mm', got {input_unit!r}")
|
||||
|
||||
depth_f, resolved_unit = _depth_input_to_float32_and_unit(depth, input_unit=input_unit)
|
||||
depth_min_u = np.float32(depth_min) if resolved_unit == "m" else np.float32(depth_min * _MM_PER_METRE)
|
||||
depth_max_u = np.float32(depth_max) if resolved_unit == "m" else np.float32(depth_max * _MM_PER_METRE)
|
||||
shift_u = np.float32(shift) if resolved_unit == "m" else np.float32(shift * _MM_PER_METRE)
|
||||
|
||||
if use_log:
|
||||
_validate_log_quant_params(depth_min, shift)
|
||||
log_min = math.log(float(depth_min_u + shift_u))
|
||||
log_max = math.log(float(depth_max_u + shift_u))
|
||||
norm = (np.log(depth_f + shift_u) - log_min) / (log_max - log_min)
|
||||
else:
|
||||
norm = (depth_f - depth_min_u) / (depth_max_u - depth_min_u)
|
||||
|
||||
out = np.rint(norm * DEPTH_QMAX).clip(0, DEPTH_QMAX)
|
||||
return out.astype(np.uint16, copy=False)
|
||||
|
||||
|
||||
def dequantize_depth(
|
||||
quantized: NDArray[np.uint16] | torch.Tensor,
|
||||
depth_min: float = DEFAULT_DEPTH_MIN,
|
||||
depth_max: float = DEFAULT_DEPTH_MAX,
|
||||
shift: float = DEFAULT_DEPTH_SHIFT,
|
||||
use_log: bool = DEFAULT_DEPTH_USE_LOG,
|
||||
*,
|
||||
output_unit: Literal["m", "mm"] = "mm",
|
||||
) -> NDArray[np.uint16] | NDArray[np.float32]:
|
||||
"""Inverse of :func:`quantize_depth`.
|
||||
|
||||
Tuning arguments **must match** :func:`quantize_depth`.
|
||||
|
||||
Decoding inverts the same normalized code mapping as :func:`quantize_depth`
|
||||
using ``depth_min`` / ``depth_max`` / ``shift`` (in metres), then returns
|
||||
the requested output unit.
|
||||
|
||||
Args:
|
||||
quantized: 12-bit codes ``[0, DEPTH_QMAX]``, ``dtype=uint16``.
|
||||
depth_min, depth_max, shift, use_log: Same as :func:`quantize_depth` (metres).
|
||||
output_unit: ``\"mm\"`` returns ``uint16`` millimetres (``rint``, clip
|
||||
``[0, 65535]``). ``\"m\"`` returns ``float32`` metres in
|
||||
``[depth_min, depth_max]``.
|
||||
|
||||
Returns:
|
||||
Depth map in the requested unit and dtype.
|
||||
|
||||
Raises:
|
||||
ValueError: If ``use_log=True`` and ``depth_min + shift <= 0``.
|
||||
ValueError: If ``output_unit`` is not ``\"m\"`` or ``\"mm\"``.
|
||||
"""
|
||||
if output_unit not in ("m", "mm"):
|
||||
raise ValueError(f"output_unit must be 'm' or 'mm', got {output_unit!r}")
|
||||
|
||||
if isinstance(quantized, torch.Tensor):
|
||||
quantized = quantized.detach().cpu().numpy()
|
||||
q = np.asarray(quantized, dtype=np.uint16, order="K")
|
||||
norm = q.astype(np.float32, copy=False) / DEPTH_QMAX
|
||||
|
||||
depth_min_mm = np.float32(depth_min * _MM_PER_METRE)
|
||||
depth_max_mm = np.float32(depth_max * _MM_PER_METRE)
|
||||
shift_mm = np.float32(shift * _MM_PER_METRE)
|
||||
|
||||
if use_log:
|
||||
_validate_log_quant_params(depth_min, shift)
|
||||
log_min = math.log(float(depth_min_mm + shift_mm))
|
||||
log_max = math.log(float(depth_max_mm + shift_mm))
|
||||
depth_mm = np.exp(norm * (log_max - log_min) + log_min) - shift_mm
|
||||
else:
|
||||
depth_mm = norm * (depth_max_mm - depth_min_mm) + depth_min_mm
|
||||
|
||||
depth_mm = np.clip(depth_mm, depth_min_mm, depth_max_mm).astype(np.float32, copy=False)
|
||||
if output_unit == "m":
|
||||
return (depth_mm / np.float32(_MM_PER_METRE)).astype(np.float32, copy=False)
|
||||
mm = np.rint(depth_mm).clip(0, _UINT16_MAX)
|
||||
return mm.astype(np.uint16, copy=False)
|
||||
@@ -294,10 +294,20 @@ def validate_feature_image_or_video(
|
||||
# Note: The check of pixels range ([0,1] for float and [0,255] for uint8) is done by the image writer threads.
|
||||
error_message = ""
|
||||
if isinstance(value, np.ndarray):
|
||||
actual_shape = value.shape
|
||||
c, h, w = expected_shape
|
||||
if len(actual_shape) != 3 or (actual_shape != (c, h, w) and actual_shape != (h, w, c)):
|
||||
error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c, h, w)}' or '{(h, w, c)}'.\n"
|
||||
actual_shape = tuple(value.shape)
|
||||
expected = tuple(expected_shape)
|
||||
if len(expected) == 2:
|
||||
# Single-channel features (e.g. depth maps) — accept (H,W), (1,H,W), (H,W,1)
|
||||
h, w = expected
|
||||
valid = actual_shape in {(h, w), (1, h, w), (h, w, 1)}
|
||||
if not valid:
|
||||
error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(h, w)}', '{(1, h, w)}', or '{(h, w, 1)}'.\n"
|
||||
elif len(expected) == 3:
|
||||
c, h, w = expected
|
||||
if len(actual_shape) != 3 or (actual_shape != (c, h, w) and actual_shape != (h, w, c)):
|
||||
error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c, h, w)}' or '{(h, w, c)}'.\n"
|
||||
else:
|
||||
error_message += f"The feature '{name}' has an unsupported expected_shape '{expected}'.\n"
|
||||
elif isinstance(value, PILImage.Image):
|
||||
pass
|
||||
else:
|
||||
|
||||
@@ -41,15 +41,56 @@ def safe_stop_image_writer(func):
|
||||
return wrapper
|
||||
|
||||
|
||||
def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True) -> PIL.Image.Image:
|
||||
# TODO(aliberts): handle 1 channel and 4 for depth images
|
||||
if image_array.ndim != 3:
|
||||
raise ValueError(f"The array has {image_array.ndim} dimensions, but 3 is expected for an image.")
|
||||
# Single-channel dtypes that PIL natively maps to the matching mode
|
||||
# (``uint8`` → ``L``, ``uint16`` → ``I;16``, ``float32`` → ``F``).
|
||||
GRAYSCALE_DTYPES: tuple[np.dtype, ...] = (
|
||||
np.dtype("uint8"),
|
||||
np.dtype("uint16"),
|
||||
np.dtype("float32"),
|
||||
)
|
||||
|
||||
|
||||
def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True) -> PIL.Image.Image:
|
||||
"""Convert a NumPy array to a PIL Image, preserving precision for grayscale.
|
||||
|
||||
Behaviour by shape:
|
||||
|
||||
- ``(H, W)`` or ``(1, H, W)`` / ``(H, W, 1)``: single-channel grayscale.
|
||||
The native dtype is preserved using the matching PIL mode
|
||||
(``L`` / ``I;16`` / ``F``). This is the path used for raw depth maps (no rescaling, clamping, or downcasting)
|
||||
- ``(3, H, W)`` / ``(H, W, 3)``: RGB. Channels-first inputs are transposed
|
||||
to channels-last. Float inputs in ``[0, 1]`` are scaled to ``uint8``
|
||||
(existing behaviour, gated by ``range_check``).
|
||||
|
||||
Other shapes / channel counts raise ``NotImplementedError`` or
|
||||
``ValueError``.
|
||||
"""
|
||||
if image_array.ndim not in (2, 3):
|
||||
raise ValueError(
|
||||
f"The array has {image_array.ndim} dimensions, but 2 or 3 is expected for an image."
|
||||
)
|
||||
|
||||
# Squeeze 3D single-channel inputs to 2D so depth maps work whether the
|
||||
# caller emits (H, W), (1, H, W), or (H, W, 1).
|
||||
if image_array.ndim == 3:
|
||||
if image_array.shape[0] == 1:
|
||||
image_array = image_array[0]
|
||||
elif image_array.shape[-1] == 1:
|
||||
image_array = image_array[..., 0]
|
||||
|
||||
if image_array.ndim == 2:
|
||||
if image_array.dtype not in GRAYSCALE_DTYPES:
|
||||
raise ValueError(
|
||||
f"Unsupported single-channel image dtype: {image_array.dtype}. "
|
||||
f"Supported dtypes: {sorted(str(d) for d in GRAYSCALE_DTYPES)}."
|
||||
)
|
||||
|
||||
return PIL.Image.fromarray(np.ascontiguousarray(image_array))
|
||||
|
||||
# 3D path: must be RGB (3 channels), channels-first or channels-last.
|
||||
if image_array.shape[0] == 3:
|
||||
# Transpose from pytorch convention (C, H, W) to (H, W, C)
|
||||
image_array = image_array.transpose(1, 2, 0)
|
||||
|
||||
elif image_array.shape[-1] != 3:
|
||||
raise NotImplementedError(
|
||||
f"The image has {image_array.shape[-1]} channels, but 3 is required for now."
|
||||
@@ -71,13 +112,28 @@ def image_array_to_pil_image(image_array: np.ndarray, range_check: bool = True)
|
||||
return PIL.Image.fromarray(image_array)
|
||||
|
||||
|
||||
def save_kwargs_for_path(fpath: Path, compress_level: int) -> dict:
|
||||
"""Pick the right format-specific kwargs for :meth:`PIL.Image.Image.save`.
|
||||
|
||||
PNG uses ``compress_level`` (0–9, zlib). TIFF uses ``compression`` (raw) for lossless raw depth maps.
|
||||
"""
|
||||
suffix = Path(fpath).suffix.lower()
|
||||
if suffix == ".png":
|
||||
return {"compress_level": compress_level}
|
||||
if suffix in (".tif", ".tiff"):
|
||||
return {"compression": "raw"}
|
||||
return {}
|
||||
|
||||
|
||||
def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path, compress_level: int = 1):
|
||||
"""
|
||||
Saves a NumPy array or PIL Image to a file.
|
||||
|
||||
This function handles both NumPy arrays and PIL Image objects, converting
|
||||
the former to a PIL Image before saving. It includes error handling for
|
||||
the save operation.
|
||||
the save operation. The output format is inferred from the *fpath*
|
||||
extension: ``.png`` → PNG with ``compress_level``, ``.tiff`` / ``.tif``
|
||||
→ lossless raw depth maps (TIFF).
|
||||
|
||||
Args:
|
||||
image (np.ndarray | PIL.Image.Image): The image data to save.
|
||||
@@ -101,7 +157,7 @@ def write_image(image: np.ndarray | PIL.Image.Image, fpath: Path, compress_level
|
||||
img = image
|
||||
else:
|
||||
raise TypeError(f"Unsupported image type: {type(image)}")
|
||||
img.save(fpath, compress_level=compress_level)
|
||||
img.save(fpath, **save_kwargs_for_path(Path(fpath), compress_level))
|
||||
except Exception as e:
|
||||
logger.error("Error writing image %s: %s", fpath, e)
|
||||
|
||||
|
||||
@@ -35,9 +35,11 @@ from .utils import (
|
||||
is_valid_version,
|
||||
)
|
||||
from .video_utils import (
|
||||
DepthEncoderConfig,
|
||||
StreamingVideoEncoder,
|
||||
get_safe_default_codec,
|
||||
resolve_vcodec,
|
||||
VideoEncoderConfig,
|
||||
get_safe_default_video_backend,
|
||||
seed_depth_feature_info,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -58,10 +60,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
video_backend: str | None = None,
|
||||
return_uint8: bool = False,
|
||||
batch_encoding_size: int = 1,
|
||||
vcodec: str = "libsvtav1",
|
||||
camera_encoder_config: VideoEncoderConfig | None = None,
|
||||
depth_encoder_config: DepthEncoderConfig | None = None,
|
||||
encoder_threads: int | None = None,
|
||||
streaming_encoding: bool = False,
|
||||
encoder_queue_maxsize: int = 30,
|
||||
encoder_threads: int | None = None,
|
||||
):
|
||||
"""
|
||||
2 modes are available for instantiating this class, depending on 2 different use cases:
|
||||
@@ -177,16 +180,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision.
|
||||
batch_encoding_size (int, optional): Number of episodes to accumulate before batch encoding videos.
|
||||
Set to 1 for immediate encoding (default), or higher for batched encoding. Defaults to 1.
|
||||
vcodec (str, optional): Video codec for encoding videos during recording. Options: 'h264', 'hevc',
|
||||
'libsvtav1', 'auto', or hardware-specific codecs like 'h264_videotoolbox', 'h264_nvenc'.
|
||||
Defaults to 'libsvtav1'. Use 'auto' to auto-detect the best available hardware encoder.
|
||||
camera_encoder_config (VideoEncoderConfig | None, optional): Video encoder settings for cameras
|
||||
(codec, quality, etc.). Defaults to
|
||||
:class:`~lerobot.datasets.video_utils.VideoEncoderConfig` defaults when ``None``.
|
||||
encoder_threads (int | None, optional): Number of encoder threads (global). ``None`` lets the
|
||||
codec decide.
|
||||
streaming_encoding (bool, optional): If True, encode video frames in real-time during capture
|
||||
instead of writing PNG images first. This makes save_episode() near-instant. Defaults to False.
|
||||
encoder_queue_maxsize (int, optional): Maximum number of frames to buffer per camera when using
|
||||
streaming encoding. Defaults to 30 (~1s at 30fps).
|
||||
encoder_threads (int | None, optional): Number of threads per encoder instance. None lets the
|
||||
codec auto-detect (default). Lower values reduce CPU usage per encoder. Maps to 'lp' (via svtav1-params) for
|
||||
libsvtav1 and 'threads' for h264/hevc.
|
||||
|
||||
Note:
|
||||
Write-mode parameters (``streaming_encoding``, ``batch_encoding_size``) passed to
|
||||
@@ -202,10 +204,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self.episodes = episodes
|
||||
self.tolerance_s = tolerance_s
|
||||
self.revision = revision if revision else CODEBASE_VERSION
|
||||
self._video_backend = video_backend if video_backend else get_safe_default_codec()
|
||||
self._video_backend = video_backend if video_backend else get_safe_default_video_backend()
|
||||
self._return_uint8 = return_uint8
|
||||
self._batch_encoding_size = batch_encoding_size
|
||||
self._vcodec = resolve_vcodec(vcodec)
|
||||
if camera_encoder_config is None:
|
||||
camera_encoder_config = VideoEncoderConfig()
|
||||
self._camera_encoder_config = camera_encoder_config
|
||||
self._depth_encoder_config = depth_encoder_config
|
||||
self._encoder_threads = encoder_threads
|
||||
|
||||
if self._requested_root is not None:
|
||||
@@ -248,16 +253,23 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
seed_depth_feature_info(self.meta.features, self._depth_encoder_config)
|
||||
streaming_enc = None
|
||||
if streaming_encoding and len(self.meta.video_keys) > 0:
|
||||
streaming_enc = self._build_streaming_encoder(
|
||||
self.meta.fps, self._vcodec, encoder_queue_maxsize, encoder_threads
|
||||
self.meta.fps,
|
||||
self._camera_encoder_config,
|
||||
self._encoder_threads,
|
||||
encoder_queue_maxsize,
|
||||
depth_encoder_config=self._depth_encoder_config,
|
||||
depth_keys=self.meta.depth_keys,
|
||||
)
|
||||
self.writer = DatasetWriter(
|
||||
meta=self.meta,
|
||||
root=self.root,
|
||||
vcodec=self._vcodec,
|
||||
encoder_threads=encoder_threads,
|
||||
camera_encoder_config=self._camera_encoder_config,
|
||||
depth_encoder_config=self._depth_encoder_config,
|
||||
encoder_threads=self._encoder_threads,
|
||||
batch_encoding_size=batch_encoding_size,
|
||||
streaming_encoder=streaming_enc,
|
||||
initial_frames=self.meta.total_frames,
|
||||
@@ -298,19 +310,20 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
@staticmethod
|
||||
def _build_streaming_encoder(
|
||||
fps: int,
|
||||
vcodec: str,
|
||||
encoder_queue_maxsize: int,
|
||||
camera_encoder_config: VideoEncoderConfig,
|
||||
encoder_threads: int | None,
|
||||
encoder_queue_maxsize: int,
|
||||
*,
|
||||
depth_encoder_config: DepthEncoderConfig | None = None,
|
||||
depth_keys: list[str] | None = None,
|
||||
) -> StreamingVideoEncoder:
|
||||
return StreamingVideoEncoder(
|
||||
fps=fps,
|
||||
vcodec=vcodec,
|
||||
pix_fmt="yuv420p",
|
||||
g=2,
|
||||
crf=30,
|
||||
preset=None,
|
||||
queue_maxsize=encoder_queue_maxsize,
|
||||
camera_encoder_config=camera_encoder_config,
|
||||
encoder_threads=encoder_threads,
|
||||
queue_maxsize=encoder_queue_maxsize,
|
||||
depth_encoder_config=depth_encoder_config,
|
||||
depth_keys=depth_keys,
|
||||
)
|
||||
|
||||
# ── Metadata properties ───────────────────────────────────────────
|
||||
@@ -625,7 +638,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
image_writer_threads: int = 0,
|
||||
video_backend: str | None = None,
|
||||
batch_encoding_size: int = 1,
|
||||
vcodec: str = "libsvtav1",
|
||||
camera_encoder_config: VideoEncoderConfig | None = None,
|
||||
depth_encoder_config: DepthEncoderConfig | None = None,
|
||||
metadata_buffer_size: int = 10,
|
||||
streaming_encoding: bool = False,
|
||||
encoder_queue_maxsize: int = 30,
|
||||
@@ -656,20 +670,23 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
video_backend: Video decoding backend (used when reading back).
|
||||
batch_encoding_size: Number of episodes to accumulate before
|
||||
batch-encoding videos. ``1`` means encode immediately.
|
||||
vcodec: Video codec for encoding. Options include ``'libsvtav1'``,
|
||||
``'h264'``, ``'hevc'``, ``'auto'``.
|
||||
camera_encoder_config: Video encoder settings for cameras; defaults
|
||||
match :class:`~lerobot.datasets.video_utils.VideoEncoderConfig`
|
||||
when ``None``.
|
||||
encoder_threads: Number of encoder threads (global). ``None``
|
||||
lets the codec decide.
|
||||
metadata_buffer_size: Number of episode metadata records to buffer
|
||||
before flushing to parquet.
|
||||
streaming_encoding: If ``True``, encode video frames in real-time
|
||||
during capture instead of writing images first.
|
||||
encoder_queue_maxsize: Max buffered frames per camera when using
|
||||
streaming encoding.
|
||||
encoder_threads: Threads per encoder instance. ``None`` for auto.
|
||||
|
||||
Returns:
|
||||
A new :class:`LeRobotDataset` in write mode.
|
||||
"""
|
||||
vcodec = resolve_vcodec(vcodec)
|
||||
if camera_encoder_config is None:
|
||||
camera_encoder_config = VideoEncoderConfig()
|
||||
obj = cls.__new__(cls)
|
||||
obj.meta = LeRobotDatasetMetadata.create(
|
||||
repo_id=repo_id,
|
||||
@@ -690,23 +707,32 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
obj.image_transforms = None
|
||||
obj.delta_timestamps = None
|
||||
obj.episodes = None
|
||||
obj._video_backend = video_backend if video_backend is not None else get_safe_default_codec()
|
||||
obj._video_backend = video_backend if video_backend is not None else get_safe_default_video_backend()
|
||||
obj._return_uint8 = False
|
||||
obj._batch_encoding_size = batch_encoding_size
|
||||
obj._vcodec = vcodec
|
||||
obj._camera_encoder_config = camera_encoder_config
|
||||
obj._depth_encoder_config = depth_encoder_config
|
||||
obj._encoder_threads = encoder_threads
|
||||
seed_depth_feature_info(obj.meta.features, depth_encoder_config)
|
||||
|
||||
# Reader is lazily created on first access (write-only mode)
|
||||
obj.reader = None
|
||||
|
||||
# Create writer
|
||||
streaming_enc = None
|
||||
if streaming_encoding and len(obj.meta.video_keys) > 0:
|
||||
streaming_enc = cls._build_streaming_encoder(fps, vcodec, encoder_queue_maxsize, encoder_threads)
|
||||
streaming_enc = cls._build_streaming_encoder(
|
||||
fps,
|
||||
camera_encoder_config,
|
||||
encoder_threads,
|
||||
encoder_queue_maxsize,
|
||||
depth_encoder_config=depth_encoder_config,
|
||||
depth_keys=obj.meta.depth_keys,
|
||||
)
|
||||
obj.writer = DatasetWriter(
|
||||
meta=obj.meta,
|
||||
root=obj.root,
|
||||
vcodec=vcodec,
|
||||
camera_encoder_config=camera_encoder_config,
|
||||
depth_encoder_config=depth_encoder_config,
|
||||
encoder_threads=encoder_threads,
|
||||
batch_encoding_size=batch_encoding_size,
|
||||
streaming_encoder=streaming_enc,
|
||||
@@ -729,12 +755,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
force_cache_sync: bool = False,
|
||||
video_backend: str | None = None,
|
||||
batch_encoding_size: int = 1,
|
||||
vcodec: str = "libsvtav1",
|
||||
camera_encoder_config: VideoEncoderConfig | None = None,
|
||||
depth_encoder_config: DepthEncoderConfig | None = None,
|
||||
encoder_threads: int | None = None,
|
||||
image_writer_processes: int = 0,
|
||||
image_writer_threads: int = 0,
|
||||
streaming_encoding: bool = False,
|
||||
encoder_queue_maxsize: int = 30,
|
||||
encoder_threads: int | None = None,
|
||||
) -> "LeRobotDataset":
|
||||
"""Resume recording on an existing dataset.
|
||||
|
||||
@@ -757,13 +784,16 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
video_backend: Video decoding backend for reading back data.
|
||||
batch_encoding_size: Number of episodes to accumulate before
|
||||
batch-encoding videos.
|
||||
vcodec: Video codec for encoding.
|
||||
camera_encoder_config: Video encoder settings for cameras; defaults
|
||||
match :class:`~lerobot.datasets.video_utils.VideoEncoderConfig`
|
||||
when ``None``.
|
||||
encoder_threads: Number of encoder threads (global). ``None``
|
||||
lets the codec decide.
|
||||
image_writer_processes: Subprocesses for async image writing.
|
||||
image_writer_threads: Threads for async image writing.
|
||||
streaming_encoding: If ``True``, encode video in real-time during
|
||||
capture.
|
||||
encoder_queue_maxsize: Max buffered frames per camera for streaming.
|
||||
encoder_threads: Threads per encoder instance. ``None`` for auto.
|
||||
|
||||
Returns:
|
||||
A :class:`LeRobotDataset` in write mode, ready to append episodes.
|
||||
@@ -774,7 +804,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
"Writing into the revision-safe Hub snapshot cache (used when root=None) would corrupt "
|
||||
"the shared cache. Please provide a local directory path."
|
||||
)
|
||||
vcodec = resolve_vcodec(vcodec)
|
||||
obj = cls.__new__(cls)
|
||||
obj.repo_id = repo_id
|
||||
obj._requested_root = Path(root)
|
||||
@@ -783,11 +812,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
obj.image_transforms = None
|
||||
obj.delta_timestamps = None
|
||||
obj.episodes = None
|
||||
obj._video_backend = video_backend if video_backend else get_safe_default_codec()
|
||||
obj._video_backend = video_backend if video_backend else get_safe_default_video_backend()
|
||||
obj._return_uint8 = False
|
||||
obj._batch_encoding_size = batch_encoding_size
|
||||
obj._vcodec = vcodec
|
||||
obj._encoder_threads = encoder_threads
|
||||
|
||||
if obj._requested_root is not None:
|
||||
obj._requested_root.mkdir(exist_ok=True, parents=True)
|
||||
@@ -796,21 +823,33 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
obj.meta = LeRobotDatasetMetadata(
|
||||
obj.repo_id, obj._requested_root, obj.revision, force_cache_sync=force_cache_sync
|
||||
)
|
||||
|
||||
if camera_encoder_config is None:
|
||||
camera_encoder_config = VideoEncoderConfig()
|
||||
obj._camera_encoder_config = camera_encoder_config
|
||||
obj._depth_encoder_config = depth_encoder_config
|
||||
obj._encoder_threads = encoder_threads
|
||||
obj.root = obj.meta.root
|
||||
seed_depth_feature_info(obj.meta.features, depth_encoder_config)
|
||||
|
||||
# Reader is lazily created on first access (write-only mode)
|
||||
obj.reader = None
|
||||
|
||||
# Create writer for appending
|
||||
streaming_enc = None
|
||||
if streaming_encoding and len(obj.meta.video_keys) > 0:
|
||||
streaming_enc = cls._build_streaming_encoder(
|
||||
obj.meta.fps, vcodec, encoder_queue_maxsize, encoder_threads
|
||||
obj.meta.fps,
|
||||
camera_encoder_config,
|
||||
encoder_threads,
|
||||
encoder_queue_maxsize,
|
||||
depth_encoder_config=depth_encoder_config,
|
||||
depth_keys=obj.meta.depth_keys,
|
||||
)
|
||||
obj.writer = DatasetWriter(
|
||||
meta=obj.meta,
|
||||
root=obj.root,
|
||||
vcodec=vcodec,
|
||||
camera_encoder_config=camera_encoder_config,
|
||||
depth_encoder_config=depth_encoder_config,
|
||||
encoder_threads=encoder_threads,
|
||||
batch_encoding_size=batch_encoding_size,
|
||||
streaming_encoder=streaming_enc,
|
||||
|
||||
311
src/lerobot/datasets/pyav_utils.py
Normal file
311
src/lerobot/datasets/pyav_utils.py
Normal file
@@ -0,0 +1,311 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""PyAV-based compatibility checks for :class:`VideoEncoderConfig`.
|
||||
|
||||
Centralises all :mod:`av` introspection of the bundled FFmpeg build.
|
||||
Checks degrade to a no-op when the target codec isn't available locally.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
import av
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.datasets.depth_utils import (
|
||||
DEFAULT_DEPTH_MAX,
|
||||
DEFAULT_DEPTH_MIN,
|
||||
DEFAULT_DEPTH_SHIFT,
|
||||
DEFAULT_DEPTH_USE_LOG,
|
||||
quantize_depth,
|
||||
dequantize_depth,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from lerobot.datasets.video_utils import VideoEncoderConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Pixel formats supported by the depth encode/decode helpers below. Both are
|
||||
# 16-bit-word formats that carry 12 significant bits per sample, matching the
|
||||
# ``DEPTH_QMAX = 4095`` quantization range.
|
||||
DEPTH_PIX_FMTS: tuple[str, ...] = ("yuv420p12le", "gray12le")
|
||||
|
||||
# Neutral chroma for 12-bit YUV (the midpoint of [0, 4095]). Filling the U/V
|
||||
# planes with this value keeps the encoder from spending bits on chroma noise
|
||||
# when only the Y plane carries information.
|
||||
_NEUTRAL_CHROMA_12BIT: int = 2048
|
||||
|
||||
FFMPEG_NUMERIC_OPTION_TYPES = ("INT", "INT64", "UINT64", "FLOAT", "DOUBLE")
|
||||
FFMPEG_INTEGER_OPTION_TYPES = ("INT", "INT64", "UINT64")
|
||||
|
||||
|
||||
def _write_u16_plane(plane: av.video.plane.VideoPlane, src: np.ndarray, fill_value: int | None = None) -> None:
|
||||
"""Copy ``src`` into a uint16 plane respecting FFmpeg line padding."""
|
||||
height, width = src.shape
|
||||
stride_u16 = plane.line_size // np.dtype(np.uint16).itemsize
|
||||
dst = np.frombuffer(plane, dtype=np.uint16).reshape(height, stride_u16)
|
||||
if fill_value is not None:
|
||||
dst.fill(fill_value)
|
||||
dst[:, :width] = src
|
||||
|
||||
|
||||
def encode_depth_frame_pyav(
|
||||
depth: np.ndarray | torch.Tensor,
|
||||
*,
|
||||
pix_fmt: str = "yuv420p12le",
|
||||
depth_min: float = DEFAULT_DEPTH_MIN,
|
||||
depth_max: float = DEFAULT_DEPTH_MAX,
|
||||
shift: float = DEFAULT_DEPTH_SHIFT,
|
||||
use_log: bool = DEFAULT_DEPTH_USE_LOG,
|
||||
input_unit: Literal["auto", "m", "mm"] = "auto",
|
||||
) -> av.VideoFrame:
|
||||
"""Quantize depth and pack it into a 12-bit PyAV video frame.
|
||||
|
||||
Args:
|
||||
depth: Depth frame to encode (H, W). Unit handling follows
|
||||
:func:`lerobot.datasets.depth_utils.quantize_depth`.
|
||||
pix_fmt: Target pixel format. Must be one of :data:`DEPTH_PIX_FMTS`.
|
||||
depth_min, depth_max, shift, use_log, input_unit: Forwarded to
|
||||
:func:`quantize_depth`.
|
||||
|
||||
Returns:
|
||||
An :class:`av.VideoFrame` in ``pix_fmt`` with quantized depth in the
|
||||
luminance plane.
|
||||
"""
|
||||
if pix_fmt not in DEPTH_PIX_FMTS:
|
||||
raise ValueError(f"Unsupported depth pix_fmt={pix_fmt!r}; expected one of {DEPTH_PIX_FMTS}")
|
||||
|
||||
quantized_depth = quantize_depth(
|
||||
depth,
|
||||
depth_min=depth_min,
|
||||
depth_max=depth_max,
|
||||
shift=shift,
|
||||
use_log=use_log,
|
||||
input_unit=input_unit,
|
||||
)
|
||||
if quantized_depth.ndim != 2:
|
||||
raise ValueError(f"depth must be a 2D frame; got shape {quantized_depth.shape}")
|
||||
|
||||
quantized_depth = np.ascontiguousarray(quantized_depth, dtype=np.uint16)
|
||||
height, width = quantized_depth.shape
|
||||
|
||||
if pix_fmt == "gray12le":
|
||||
frame = av.VideoFrame(width=width, height=height, format="gray12le")
|
||||
_write_u16_plane(frame.planes[0], quantized_depth)
|
||||
return frame
|
||||
|
||||
if height % 2 != 0 or width % 2 != 0:
|
||||
raise ValueError("yuv420p12le requires even H and W")
|
||||
|
||||
frame = av.VideoFrame(width=width, height=height, format="yuv420p12le")
|
||||
_write_u16_plane(frame.planes[0], quantized_depth)
|
||||
neutral_chroma = np.full((height // 2, width // 2), _NEUTRAL_CHROMA_12BIT, dtype=np.uint16)
|
||||
_write_u16_plane(frame.planes[1], neutral_chroma, fill_value=_NEUTRAL_CHROMA_12BIT)
|
||||
_write_u16_plane(frame.planes[2], neutral_chroma, fill_value=_NEUTRAL_CHROMA_12BIT)
|
||||
return frame
|
||||
|
||||
|
||||
def decode_depth_frame_pyav(
|
||||
frame: av.VideoFrame | list[av.VideoFrame],
|
||||
*,
|
||||
depth_min: float = DEFAULT_DEPTH_MIN,
|
||||
depth_max: float = DEFAULT_DEPTH_MAX,
|
||||
shift: float = DEFAULT_DEPTH_SHIFT,
|
||||
use_log: bool = DEFAULT_DEPTH_USE_LOG,
|
||||
return_quantized: bool = False,
|
||||
output_unit: Literal["m", "mm"] = "m",
|
||||
) -> np.ndarray:
|
||||
"""Decode one or many depth video frames to quantized or metric depth.
|
||||
|
||||
Args:
|
||||
frame: A single depth frame or a list of depth frames.
|
||||
depth_min, depth_max, shift, use_log: Forwarded to
|
||||
:func:`dequantize_depth`.
|
||||
return_quantized: If ``True``, return raw 12-bit quanta as ``uint16``.
|
||||
output_unit: Unit for dequantized output (``"m"`` or ``"mm"``).
|
||||
|
||||
Returns:
|
||||
``(H, W)`` array for a single frame, or ``(N, H, W)`` for a list.
|
||||
"""
|
||||
frames = frame if isinstance(frame, list) else [frame]
|
||||
quantized = np.stack([f.reformat(format="gray12le").to_ndarray() for f in frames]).astype(np.uint16, copy=False)
|
||||
if return_quantized:
|
||||
return quantized[0] if len(frames) == 1 else quantized
|
||||
|
||||
decoded = dequantize_depth(
|
||||
quantized,
|
||||
depth_min=depth_min,
|
||||
depth_max=depth_max,
|
||||
shift=shift,
|
||||
use_log=use_log,
|
||||
output_unit=output_unit,
|
||||
)
|
||||
return decoded[0] if len(frames) == 1 else decoded
|
||||
|
||||
|
||||
@functools.cache
|
||||
def get_codec(vcodec: str) -> av.codec.Codec | None:
|
||||
"""PyAV write-mode ``Codec`` for *vcodec*, or ``None`` if unavailable."""
|
||||
try:
|
||||
return av.codec.Codec(vcodec, "w")
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
@functools.cache
|
||||
def _get_codec_video_formats(vcodec: str) -> dict[str, av.option.Option]:
|
||||
"""Private-option name → PyAV ``Option`` for *vcodec* (empty if unavailable)."""
|
||||
codec = get_codec(vcodec)
|
||||
if codec is None:
|
||||
return {}
|
||||
return {opt.name: opt for opt in codec.descriptor.options}
|
||||
|
||||
|
||||
@functools.cache
|
||||
def _get_codec_video_formats(vcodec: str) -> tuple[str, ...]:
|
||||
"""Pixel formats accepted by *vcodec* in PyAV's preferred order (empty if unknown)."""
|
||||
codec = get_codec(vcodec)
|
||||
if codec is None:
|
||||
return ()
|
||||
return tuple(fmt.name for fmt in (codec.video_formats or []))
|
||||
|
||||
|
||||
def detect_available_encoders_pyav(encoders: list[str] | str) -> list[str]:
|
||||
"""Return the subset of *encoders* available as video encoders in the local FFmpeg build.
|
||||
|
||||
Each name is probed directly via :func:`get_codec`; input order is preserved.
|
||||
"""
|
||||
if isinstance(encoders, str):
|
||||
encoders = [encoders]
|
||||
|
||||
available: list[str] = []
|
||||
for name in encoders:
|
||||
codec = get_codec(name)
|
||||
if codec is not None and codec.type == "video":
|
||||
available.append(name)
|
||||
else:
|
||||
logger.debug("encoder '%s' not available as video encoder", name)
|
||||
return available
|
||||
|
||||
|
||||
def _check_option_value(vcodec: str, label: str, value: Any, opt: av.option.Option) -> None:
|
||||
"""Range-check numeric *value* and choice-check string *value* against *opt*."""
|
||||
type_name = opt.type.name
|
||||
if type_name in FFMPEG_NUMERIC_OPTION_TYPES:
|
||||
if isinstance(value, bool):
|
||||
raise ValueError(
|
||||
f"{label}={value!r} is not numeric; codec {vcodec!r} expects a number for this option."
|
||||
)
|
||||
elif isinstance(value, str):
|
||||
try:
|
||||
num_val = float(value)
|
||||
except ValueError as e:
|
||||
raise ValueError(
|
||||
f"{label}={value!r} is not numeric; codec {vcodec!r} expects a number for this option."
|
||||
) from e
|
||||
elif isinstance(value, (float, int)):
|
||||
num_val = value
|
||||
else:
|
||||
raise ValueError(
|
||||
f"{label}={value!r} is not numeric; codec {vcodec!r} expects a number for this option."
|
||||
)
|
||||
|
||||
# Check integer type compatibility
|
||||
if type_name in FFMPEG_INTEGER_OPTION_TYPES and not num_val.is_integer():
|
||||
raise ValueError(
|
||||
f"{label}={num_val!r} must be an integer for codec {vcodec!r} "
|
||||
f"(FFmpeg option {opt.name!r} is {type_name}); float values are not allowed."
|
||||
)
|
||||
|
||||
# Check numeric range compatibility
|
||||
lo, hi = float(opt.min), float(opt.max)
|
||||
if lo < hi and not (lo <= num_val <= hi):
|
||||
raise ValueError(
|
||||
f"{label}={num_val} is out of range for codec {vcodec!r}; must be in [{lo}, {hi}]"
|
||||
)
|
||||
|
||||
elif type_name == "STRING":
|
||||
if isinstance(value, bool):
|
||||
raise ValueError(f"{label}={value!r} is not a valid string value for codec {vcodec!r}.")
|
||||
if isinstance(value, str):
|
||||
str_val = value
|
||||
elif isinstance(value, (int, float)):
|
||||
str_val = str(value)
|
||||
else:
|
||||
raise ValueError(f"{label}={value!r} has unsupported type for STRING option on codec {vcodec!r}")
|
||||
|
||||
# Check string choice compatibility
|
||||
choices = [c.name for c in (opt.choices or [])]
|
||||
if choices and str_val not in choices:
|
||||
raise ValueError(
|
||||
f"{label}={str_val!r} is not a supported choice for codec "
|
||||
f"{vcodec!r}; valid choices: {choices}"
|
||||
)
|
||||
else:
|
||||
return
|
||||
|
||||
|
||||
def _check_pixel_format(vcodec: str, pix_fmt: str) -> None:
|
||||
formats = _get_codec_video_formats(vcodec)
|
||||
if formats and pix_fmt not in formats:
|
||||
raise ValueError(
|
||||
f"pix_fmt={pix_fmt!r} is not supported by codec {vcodec!r}; "
|
||||
f"supported pixel formats: {list(formats)}"
|
||||
)
|
||||
|
||||
|
||||
def _check_codec_options(vcodec: str, codec_options: dict[str, Any], config: VideoEncoderConfig) -> None:
|
||||
"""Validate merged encoder options (typed) against the codec's published AVOptions."""
|
||||
supported_options = _get_codec_options_by_name(vcodec)
|
||||
for key, value in codec_options.items():
|
||||
# GOP size is not a codec-specific option, it has to be validated separately.
|
||||
if key == "g":
|
||||
if isinstance(value, bool) or not isinstance(value, int) or value < 1:
|
||||
raise ValueError(f"g={value!r} must be a positive integer for codec {vcodec!r}")
|
||||
continue
|
||||
if key not in supported_options:
|
||||
continue
|
||||
opt = supported_options[key]
|
||||
label = f"extra_options[{key!r}]" if key in config.extra_options else key
|
||||
_check_option_value(vcodec, label, value, opt)
|
||||
|
||||
|
||||
def check_video_encoder_config_pyav(config: VideoEncoderConfig) -> None:
|
||||
"""Verify *config* is compatible with the bundled FFmpeg build.
|
||||
|
||||
Checks pixel format, abstract tuning-field compatibility, and each merged
|
||||
encoder option from :meth:`~lerobot.datasets.video_utils.VideoEncoderConfig.get_codec_options`
|
||||
against PyAV (including numeric ``extra_options`` present in that dict).
|
||||
No-op when ``config.vcodec`` isn't in the local FFmpeg build.
|
||||
|
||||
Raises:
|
||||
ValueError: on the first incompatibility encountered.
|
||||
"""
|
||||
vcodec = config.vcodec
|
||||
options = _get_codec_options_by_name(vcodec)
|
||||
if not options:
|
||||
logger.warning(
|
||||
"Codec %r is not available in the bundled FFmpeg build; ",
|
||||
vcodec,
|
||||
)
|
||||
return
|
||||
_check_pixel_format(config.vcodec, config.pix_fmt)
|
||||
_check_codec_options(config.vcodec, config.get_codec_options(), config)
|
||||
@@ -93,6 +93,10 @@ DEFAULT_EPISODES_PATH = EPISODES_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
|
||||
DEFAULT_DATA_PATH = DATA_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet"
|
||||
DEFAULT_VIDEO_PATH = VIDEO_DIR + "/{video_key}/" + CHUNK_FILE_PATTERN + ".mp4"
|
||||
DEFAULT_IMAGE_PATH = "images/{image_key}/episode-{episode_index:06d}/frame-{frame_index:06d}.png"
|
||||
# Depth maps live alongside images on disk but use TIFF instead of PNG: PNG
|
||||
# cannot natively round-trip float32, and several common loaders silently
|
||||
# downcast 16-bit grayscale.
|
||||
DEFAULT_DEPTH_PATH = "images/{image_key}/episode-{episode_index:06d}/frame-{frame_index:06d}.tiff"
|
||||
|
||||
LEGACY_EPISODES_PATH = "meta/episodes.jsonl"
|
||||
LEGACY_EPISODES_STATS_PATH = "meta/episodes_stats.jsonl"
|
||||
|
||||
@@ -17,12 +17,13 @@ import contextlib
|
||||
import glob
|
||||
import importlib
|
||||
import logging
|
||||
import math
|
||||
import queue
|
||||
import shutil
|
||||
import tempfile
|
||||
import threading
|
||||
import warnings
|
||||
from dataclasses import dataclass, field
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from fractions import Fraction
|
||||
from pathlib import Path
|
||||
from threading import Lock
|
||||
@@ -37,7 +38,23 @@ import torchvision
|
||||
from datasets.features.features import register_feature
|
||||
from PIL import Image
|
||||
|
||||
from lerobot.utils.import_utils import get_safe_default_codec
|
||||
from lerobot.datasets.pyav_utils import (
|
||||
check_video_encoder_config_pyav,
|
||||
depth_to_video_frame,
|
||||
detect_available_encoders_pyav,
|
||||
decode_depth_frame,
|
||||
encode_depth_frame_pyav,
|
||||
decode_depth_frame_pyav,
|
||||
)
|
||||
from lerobot.datasets.depth_utils import (
|
||||
quantize_depth,
|
||||
dequantize_depth,
|
||||
DEFAULT_DEPTH_MIN,
|
||||
DEFAULT_DEPTH_MAX,
|
||||
DEFAULT_DEPTH_SHIFT,
|
||||
DEFAULT_DEPTH_USE_LOG,
|
||||
)
|
||||
from lerobot.utils.import_utils import get_safe_default_video_backend
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -52,70 +69,226 @@ HW_ENCODERS = [
|
||||
"h264_qsv", # Intel Quick Sync
|
||||
]
|
||||
|
||||
VALID_VIDEO_CODECS = {"h264", "hevc", "libsvtav1", "auto"} | set(HW_ENCODERS)
|
||||
VALID_VIDEO_CODECS = {"h264", "hevc", "libsvtav1", "ffv1", "auto"} | set(HW_ENCODERS)
|
||||
|
||||
LIBSVTAV1_DEFAULT_PRESET: int = 12
|
||||
|
||||
|
||||
def _get_codec_options(
|
||||
vcodec: str,
|
||||
g: int | None = 2,
|
||||
crf: int | None = 30,
|
||||
preset: int | None = None,
|
||||
) -> dict:
|
||||
"""Build codec-specific options dict for video encoding."""
|
||||
options = {}
|
||||
@dataclass
|
||||
class VideoEncoderConfig:
|
||||
"""Video encoder configuration.
|
||||
|
||||
# GOP size (keyframe interval) - supported by VideoToolbox and software encoders
|
||||
if g is not None and (vcodec in ("h264_videotoolbox", "hevc_videotoolbox") or vcodec not in HW_ENCODERS):
|
||||
options["g"] = str(g)
|
||||
Attributes:
|
||||
vcodec: FFmpeg encoder name. ``"auto"`` is resolved during
|
||||
construction (HW encoder if available, else ``libsvtav1``).
|
||||
pix_fmt: Pixel format (e.g. ``"yuv420p"``).
|
||||
g: GOP size (keyframe interval).
|
||||
crf: Quality level — mapped to the native quality parameter of the
|
||||
codec (``crf`` for software, ``qp`` for NVENC/VAAPI,
|
||||
``q:v`` for VideoToolbox, ``global_quality`` for QSV).
|
||||
preset: Speed/quality preset. Accepted type is per-codec.
|
||||
fast_decode: Fast-decode tuning. For ``libsvtav1`` this is a level (0-2)
|
||||
embedded in ``svtav1-params``. For ``h264`` and ``hevc`` non-zero values
|
||||
set ``tune=fastdecode``. Ignored for other codecs.
|
||||
video_backend: Python library driving FFmpeg for encoding. Only ``"pyav"``
|
||||
is currently supported.
|
||||
extra_options: Free-form dictionary of additional FFmpeg options
|
||||
(e.g. ``{"tune": "film", "profile:v": "high", "bf": 2}``).
|
||||
"""
|
||||
|
||||
# Quality control (codec-specific parameter names)
|
||||
if crf is not None:
|
||||
if vcodec in ("h264", "hevc", "libsvtav1"):
|
||||
options["crf"] = str(crf)
|
||||
elif vcodec in ("h264_videotoolbox", "hevc_videotoolbox"):
|
||||
quality = max(1, min(100, int(100 - crf * 2)))
|
||||
options["q:v"] = str(quality)
|
||||
elif vcodec in ("h264_nvenc", "hevc_nvenc"):
|
||||
options["rc"] = "constqp"
|
||||
options["qp"] = str(crf)
|
||||
elif vcodec in ("h264_vaapi",):
|
||||
options["qp"] = str(crf)
|
||||
elif vcodec in ("h264_qsv",):
|
||||
options["global_quality"] = str(crf)
|
||||
vcodec: str = "libsvtav1"
|
||||
pix_fmt: str = "yuv420p"
|
||||
g: int | None = 2
|
||||
crf: int | None = 30
|
||||
preset: int | str | None = None
|
||||
fast_decode: int = 0
|
||||
# TODO(CarolinePascal): add torchcodec support + find a way to unify the
|
||||
# two backends (encoding and decoding).
|
||||
video_backend: str = "pyav"
|
||||
extra_options: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
# Preset (only for libsvtav1)
|
||||
if vcodec == "libsvtav1":
|
||||
options["preset"] = str(preset) if preset is not None else "12"
|
||||
# Class-level marker persisted to ``info.json`` (via ``asdict``) so the
|
||||
# reader can tell depth datasets from RGB ones without a separate dispatch
|
||||
# path. ``init=False`` keeps it out of CLI/constructor surface; subclasses
|
||||
# flip the default (see :class:`DepthEncoderConfig`).
|
||||
is_depth_map: bool = field(default=False, init=False)
|
||||
|
||||
return options
|
||||
def __post_init__(self) -> None:
|
||||
self.resolve_vcodec()
|
||||
|
||||
# Empty-constructor ergonomics: ``VideoEncoderConfig()`` must "just work".
|
||||
if self.preset is None and self.vcodec == "libsvtav1":
|
||||
self.preset = LIBSVTAV1_DEFAULT_PRESET
|
||||
|
||||
self.validate()
|
||||
|
||||
def detect_available_encoders(self, encoders: list[str] | str) -> list[str]:
|
||||
"""Detect available encoders based on the video backend."""
|
||||
if self.video_backend == "pyav":
|
||||
return detect_available_encoders_pyav(encoders)
|
||||
else:
|
||||
return []
|
||||
|
||||
def validate(self) -> None:
|
||||
"""Validate the video encoder config."""
|
||||
if self.video_backend == "pyav":
|
||||
check_video_encoder_config_pyav(self)
|
||||
|
||||
def resolve_vcodec(self) -> None:
|
||||
"""Validate vcodec and resolve 'auto' to best available HW encoder, fallback to libsvtav1.
|
||||
|
||||
Any explicitly-requested codec that isn't in the local FFmpeg build is
|
||||
also silently rewritten to ``libsvtav1`` so encoding never hard-fails on
|
||||
a host missing the requested encoder.
|
||||
"""
|
||||
# Backward compatibility: older datasets persist ``vcodec="av1"`` in
|
||||
# ``info.json``. Rewrite to the canonical encoder name *before* the
|
||||
# validation check below so loading those datasets keeps working.
|
||||
if self.vcodec == "av1":
|
||||
self.vcodec = "libsvtav1"
|
||||
|
||||
if self.vcodec not in VALID_VIDEO_CODECS:
|
||||
raise ValueError(f"Invalid vcodec '{self.vcodec}'. Must be one of: {sorted(VALID_VIDEO_CODECS)}")
|
||||
if self.vcodec == "auto":
|
||||
available = self.detect_available_encoders(HW_ENCODERS)
|
||||
for encoder in HW_ENCODERS:
|
||||
if encoder in available:
|
||||
logger.info(f"Auto-selected video codec: {encoder}")
|
||||
self.vcodec = encoder
|
||||
return
|
||||
logger.info("No hardware encoder available, falling back to software encoder 'libsvtav1'")
|
||||
self.vcodec = "libsvtav1"
|
||||
|
||||
if self.detect_available_encoders(self.vcodec):
|
||||
logger.info(f"Using video codec: {self.vcodec}")
|
||||
self.vcodec = self.vcodec
|
||||
return
|
||||
raise ValueError(f"Unsupported video codec: {self.vcodec} with video backend {self.video_backend}")
|
||||
|
||||
def get_codec_options(
|
||||
self, encoder_threads: int | None = None, as_strings: bool = False
|
||||
) -> dict[str, str]:
|
||||
"""Translate the tuning fields to codec-specific FFmpeg options.
|
||||
|
||||
``VideoEncoderConfig.extra_options`` are merged last but never override a structured field.
|
||||
|
||||
Args:
|
||||
encoder_threads: Number of encoder threads set globally for all VideoEncoderConfigs.
|
||||
For libsvtav1, this is mapped to ``lp`` via ``svtav1-params``.
|
||||
For h264/hevc, this is mapped to ``threads``.
|
||||
Hardware encoders ignore this parameter.
|
||||
as_strings: If ``True``, casts values to strings.
|
||||
"""
|
||||
opts: dict[str, Any] = {}
|
||||
|
||||
def set_if(key: str, value: Any) -> None:
|
||||
if value is not None:
|
||||
opts[key] = value if not as_strings else str(value)
|
||||
|
||||
# GOP size is not a codec-specific option, so it is always set.
|
||||
set_if("g", self.g)
|
||||
|
||||
if self.vcodec == "libsvtav1":
|
||||
set_if("crf", self.crf)
|
||||
set_if("preset", self.preset)
|
||||
svtav1_parts: list[str] = []
|
||||
if self.fast_decode is not None:
|
||||
svtav1_parts.append(f"fast-decode={max(0, min(2, self.fast_decode))}")
|
||||
if encoder_threads is not None:
|
||||
svtav1_parts.append(f"lp={encoder_threads}")
|
||||
if svtav1_parts:
|
||||
opts["svtav1-params"] = ":".join(svtav1_parts)
|
||||
elif self.vcodec in ("h264", "hevc"):
|
||||
set_if("crf", self.crf)
|
||||
set_if("preset", self.preset)
|
||||
if self.fast_decode:
|
||||
opts["tune"] = "fastdecode"
|
||||
set_if("threads", encoder_threads)
|
||||
elif self.vcodec in ("h264_videotoolbox", "hevc_videotoolbox"):
|
||||
if self.crf is not None:
|
||||
opts["q:v"] = max(1, min(100, 100 - self.crf * 2))
|
||||
elif self.vcodec in ("h264_nvenc", "hevc_nvenc"):
|
||||
opts["rc"] = "constqp"
|
||||
set_if("qp", self.crf)
|
||||
set_if("preset", self.preset)
|
||||
elif self.vcodec == "h264_vaapi":
|
||||
set_if("qp", self.crf)
|
||||
elif self.vcodec == "h264_qsv":
|
||||
set_if("global_quality", self.crf)
|
||||
set_if("preset", self.preset)
|
||||
elif self.vcodec == "ffv1":
|
||||
# Lossless intra-frame codec. ``crf``/``preset``/``fast_decode``
|
||||
# are not meaningful.
|
||||
set_if("threads", encoder_threads)
|
||||
else:
|
||||
set_if("crf", self.crf)
|
||||
set_if("preset", self.preset)
|
||||
|
||||
# Extra options are merged last but never override structured fields (values are kept as given).
|
||||
for k, v in self.extra_options.items():
|
||||
if k not in opts:
|
||||
set_if(k, v)
|
||||
|
||||
return opts
|
||||
|
||||
|
||||
def detect_available_hw_encoders() -> list[str]:
|
||||
"""Probe PyAV/FFmpeg for available hardware video encoders."""
|
||||
available = []
|
||||
for codec_name in HW_ENCODERS:
|
||||
try:
|
||||
av.codec.Codec(codec_name, "w")
|
||||
available.append(codec_name)
|
||||
except Exception: # nosec B110
|
||||
logger.debug("HW encoder '%s' not available", codec_name) # nosec B110
|
||||
return available
|
||||
@dataclass
|
||||
class DepthEncoderConfig(VideoEncoderConfig):
|
||||
"""Encoder configuration for depth-map streams.
|
||||
|
||||
Inherits the full :class:`VideoEncoderConfig` surface (codec, GOP, CRF,
|
||||
preset, ``extra_options``…) and adds the four parameters of the depth
|
||||
quantization pipeline (:func:`quantize_depth`). Inheritance — rather
|
||||
than composition — keeps the CLI flat: ``--dataset.depth_encoder_config.<field>``
|
||||
works identically to its RGB counterpart.
|
||||
|
||||
Defaults flip ``vcodec`` to ``"hevc"`` (Main 12 profile) and ``pix_fmt``
|
||||
to ``"yuv420p12le"``, the most widely available 12-bit pixel format.
|
||||
For archive-grade lossless storage use ``vcodec="ffv1"`` together with
|
||||
``pix_fmt="gray12le"`` (and clear ``crf``/``preset`` to ``None`` since
|
||||
``ffv1`` doesn't expose those tuning knobs).
|
||||
|
||||
The :attr:`is_depth_map` marker is class-fixed to ``True`` (``init=False``,
|
||||
so it's hidden from CLI and constructor args) and is what the reader
|
||||
side keys on to tell depth datasets from RGB ones.
|
||||
|
||||
Attributes:
|
||||
depth_min: Minimum depth in physical units (e.g. metres) represented
|
||||
by quantum ``0``.
|
||||
depth_max: Maximum depth represented by quantum :data:`DEPTH_QMAX`.
|
||||
shift: Pre-log offset for numerical stability near zero.
|
||||
use_log: ``True`` for logarithmic quantization (default; matches
|
||||
sensor error profile), ``False`` for linear.
|
||||
"""
|
||||
|
||||
vcodec: str = "hevc"
|
||||
pix_fmt: str = "yuv420p12le"
|
||||
|
||||
depth_min: float = DEFAULT_DEPTH_MIN
|
||||
depth_max: float = DEFAULT_DEPTH_MAX
|
||||
shift: float = DEFAULT_DEPTH_SHIFT
|
||||
use_log: bool = DEFAULT_DEPTH_USE_LOG
|
||||
|
||||
# Class invariant — kept out of ``__init__`` (and CLI) but persisted
|
||||
# via ``asdict`` into ``info.json`` for the reader to detect depth.
|
||||
is_depth_map: bool = field(default=True, init=False)
|
||||
|
||||
def quantize(self, depth: torch.Tensor | np.ndarray) -> torch.Tensor:
|
||||
"""Apply :func:`quantize_depth` bound to this config's parameters."""
|
||||
return quantize_depth(depth, self.depth_min, self.depth_max, self.shift, self.use_log)
|
||||
|
||||
def dequantize(self, quantized: torch.Tensor | np.ndarray) -> torch.Tensor:
|
||||
"""Apply :func:`dequantize_depth` bound to this config's parameters."""
|
||||
return dequantize_depth(quantized, self.depth_min, self.depth_max, self.shift, self.use_log)
|
||||
|
||||
|
||||
def resolve_vcodec(vcodec: str) -> str:
|
||||
"""Validate vcodec and resolve 'auto' to best available HW encoder, fallback to libsvtav1."""
|
||||
if vcodec not in VALID_VIDEO_CODECS:
|
||||
raise ValueError(f"Invalid vcodec '{vcodec}'. Must be one of: {sorted(VALID_VIDEO_CODECS)}")
|
||||
if vcodec != "auto":
|
||||
logger.info(f"Using video codec: {vcodec}")
|
||||
return vcodec
|
||||
available = detect_available_hw_encoders()
|
||||
for encoder in HW_ENCODERS:
|
||||
if encoder in available:
|
||||
logger.info(f"Auto-selected video codec: {encoder}")
|
||||
return encoder
|
||||
logger.info("No hardware encoder available, falling back to software encoder 'libsvtav1'")
|
||||
return "libsvtav1"
|
||||
def depth_encoder_defaults() -> DepthEncoderConfig:
|
||||
"""Return a :class:`DepthEncoderConfig` with depth-camera defaults."""
|
||||
return DepthEncoderConfig()
|
||||
|
||||
def camera_encoder_defaults() -> VideoEncoderConfig:
|
||||
"""Return a :class:`VideoEncoderConfig` with RGB-camera defaults."""
|
||||
return VideoEncoderConfig()
|
||||
|
||||
|
||||
def decode_video_frames(
|
||||
@@ -142,7 +315,7 @@ def decode_video_frames(
|
||||
Currently supports torchcodec on cpu and pyav.
|
||||
"""
|
||||
if backend is None:
|
||||
backend = get_safe_default_codec()
|
||||
backend = get_safe_default_video_backend()
|
||||
if backend == "torchcodec":
|
||||
return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s, return_uint8=return_uint8)
|
||||
elif backend in ["pyav", "video_reader"]:
|
||||
@@ -396,22 +569,136 @@ def decode_video_frames_torchcodec(
|
||||
return closest_frames
|
||||
|
||||
|
||||
def decode_depth_frames(
|
||||
video_path: Path | str,
|
||||
timestamps: list[float],
|
||||
tolerance_s: float,
|
||||
*,
|
||||
depth_min: float = DEFAULT_DEPTH_MIN,
|
||||
depth_max: float = DEFAULT_DEPTH_MAX,
|
||||
shift: float = DEFAULT_DEPTH_SHIFT,
|
||||
use_log: bool = DEFAULT_DEPTH_USE_LOG,
|
||||
return_quantized: bool = False,
|
||||
log_loaded_timestamps: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""Decode depth-map frames at the requested timestamps using PyAV.
|
||||
|
||||
Mirrors the timestamp-tolerance / closest-frame contract of
|
||||
:func:`decode_video_frames` but operates entirely through PyAV (the
|
||||
``torchvision`` and ``torchcodec`` backends don't currently round-trip
|
||||
12-bit pixel formats reliably).
|
||||
|
||||
Each decoded frame is reformatted to ``gray12le`` so the same path
|
||||
handles ``yuv420p12le`` (HEVC default) and ``gray12le`` (ffv1 archive)
|
||||
sources transparently.
|
||||
|
||||
Args:
|
||||
video_path: Path to a depth video produced with a
|
||||
:class:`DepthEncoderConfig`.
|
||||
timestamps: Frame timestamps to retrieve, in seconds.
|
||||
tolerance_s: Maximum allowed deviation between the queried and the
|
||||
actually-decoded timestamps.
|
||||
depth_min, depth_max, shift, use_log: Parameters used at quantization
|
||||
time. Should match :func:`info_to_depth_kwargs` extracted from
|
||||
``info.json`` for the source dataset.
|
||||
return_quantized: If ``True``, skip the dequantization step and
|
||||
return raw 12-bit ``uint16`` quanta.
|
||||
log_loaded_timestamps: Debug logging.
|
||||
|
||||
Returns:
|
||||
``torch.Tensor`` of shape ``(N, H, W)``:
|
||||
|
||||
* ``dtype=torch.float32`` (metric depth, default)
|
||||
* ``dtype=torch.uint16`` when ``return_quantized=True``.
|
||||
|
||||
Raises:
|
||||
FrameTimestampError: If a query timestamp can't be matched within
|
||||
*tolerance_s*, or if no frames are decoded.
|
||||
"""
|
||||
video_path_str = str(video_path)
|
||||
first_ts = min(timestamps)
|
||||
last_ts = max(timestamps)
|
||||
|
||||
loaded_frames: list[np.ndarray] = []
|
||||
loaded_ts: list[float] = []
|
||||
|
||||
av.logging.set_level(av.logging.WARNING)
|
||||
with av.open(video_path_str, "r") as container:
|
||||
try:
|
||||
stream = container.streams.video[0]
|
||||
except IndexError as e:
|
||||
raise FrameTimestampError(f"No video stream in {video_path_str}") from e
|
||||
|
||||
# Seek to the keyframe at-or-before first_ts (PyAV doesn't do
|
||||
# accurate seek, so we still iterate forward to the requested range).
|
||||
seek_pts = int(first_ts / stream.time_base)
|
||||
container.seek(seek_pts, stream=stream, any_frame=False, backward=True)
|
||||
|
||||
for frame in container.decode(stream):
|
||||
if frame.pts is None:
|
||||
continue
|
||||
current_ts = float(frame.pts * stream.time_base)
|
||||
if log_loaded_timestamps:
|
||||
logger.info(f"depth frame loaded at timestamp={current_ts:.4f}")
|
||||
loaded_frames.append(
|
||||
decode_depth_frame(
|
||||
frame,
|
||||
depth_min=depth_min,
|
||||
depth_max=depth_max,
|
||||
shift=shift,
|
||||
use_log=use_log,
|
||||
return_quantized=True,
|
||||
)
|
||||
)
|
||||
loaded_ts.append(current_ts)
|
||||
if current_ts >= last_ts:
|
||||
break
|
||||
|
||||
av.logging.restore_default_callback()
|
||||
|
||||
if not loaded_frames:
|
||||
raise FrameTimestampError(
|
||||
f"No depth frames decoded from {video_path_str} for timestamps {timestamps}"
|
||||
)
|
||||
|
||||
query_ts = torch.tensor(timestamps)
|
||||
loaded_ts_t = torch.tensor(loaded_ts)
|
||||
dist = torch.cdist(query_ts[:, None], loaded_ts_t[:, None], p=1)
|
||||
min_, argmin_ = dist.min(1)
|
||||
|
||||
is_within_tol = min_ < tolerance_s
|
||||
if not is_within_tol.all():
|
||||
raise FrameTimestampError(
|
||||
f"One or several query timestamps violate the tolerance "
|
||||
f"({min_[~is_within_tol]} > {tolerance_s=})."
|
||||
f"\nqueried timestamps: {query_ts}"
|
||||
f"\nloaded timestamps: {loaded_ts_t}"
|
||||
f"\nvideo: {video_path_str}"
|
||||
)
|
||||
|
||||
closest = np.stack([loaded_frames[i] for i in argmin_]) # (N, H, W) uint16
|
||||
quantized = torch.from_numpy(closest)
|
||||
|
||||
if return_quantized:
|
||||
return quantized
|
||||
return dequantize_depth(quantized, depth_min, depth_max, shift, use_log)
|
||||
|
||||
|
||||
def encode_video_frames(
|
||||
imgs_dir: Path | str,
|
||||
video_path: Path | str,
|
||||
fps: int,
|
||||
vcodec: str = "libsvtav1",
|
||||
pix_fmt: str = "yuv420p",
|
||||
g: int | None = 2,
|
||||
crf: int | None = 30,
|
||||
fast_decode: int = 0,
|
||||
camera_encoder_config: VideoEncoderConfig | None = None,
|
||||
encoder_threads: int | None = None,
|
||||
*,
|
||||
log_level: int | None = av.logging.WARNING,
|
||||
overwrite: bool = False,
|
||||
preset: int | None = None,
|
||||
encoder_threads: int | None = None,
|
||||
) -> None:
|
||||
"""More info on ffmpeg arguments tuning on `benchmark/video/README.md`"""
|
||||
vcodec = resolve_vcodec(vcodec)
|
||||
if camera_encoder_config is None:
|
||||
camera_encoder_config = VideoEncoderConfig()
|
||||
vcodec = camera_encoder_config.vcodec
|
||||
pix_fmt = camera_encoder_config.pix_fmt
|
||||
|
||||
video_path = Path(video_path)
|
||||
imgs_dir = Path(imgs_dir)
|
||||
@@ -422,42 +709,18 @@ def encode_video_frames(
|
||||
|
||||
video_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Encoders/pixel formats incompatibility check
|
||||
if (vcodec == "libsvtav1" or vcodec == "hevc") and pix_fmt == "yuv444p":
|
||||
logger.warning(
|
||||
f"Incompatible pixel format 'yuv444p' for codec {vcodec}, auto-selecting format 'yuv420p'"
|
||||
)
|
||||
pix_fmt = "yuv420p"
|
||||
|
||||
# Get input frames
|
||||
template = "frame-" + ("[0-9]" * 6) + ".png"
|
||||
input_list = sorted(
|
||||
glob.glob(str(imgs_dir / template)), key=lambda x: int(x.split("-")[-1].split(".")[0])
|
||||
)
|
||||
|
||||
# Define video output frame size (assuming all input frames are the same size)
|
||||
if len(input_list) == 0:
|
||||
raise FileNotFoundError(f"No images found in {imgs_dir}.")
|
||||
with Image.open(input_list[0]) as dummy_image:
|
||||
width, height = dummy_image.size
|
||||
|
||||
# Define video codec options
|
||||
video_options = _get_codec_options(vcodec, g, crf, preset)
|
||||
|
||||
if fast_decode:
|
||||
key = "svtav1-params" if vcodec == "libsvtav1" else "tune"
|
||||
value = f"fast-decode={fast_decode}" if vcodec == "libsvtav1" else "fastdecode"
|
||||
video_options[key] = value
|
||||
|
||||
if encoder_threads is not None:
|
||||
if vcodec == "libsvtav1":
|
||||
lp_param = f"lp={encoder_threads}"
|
||||
if "svtav1-params" in video_options:
|
||||
video_options["svtav1-params"] += f":{lp_param}"
|
||||
else:
|
||||
video_options["svtav1-params"] = lp_param
|
||||
else:
|
||||
video_options["threads"] = str(encoder_threads)
|
||||
video_options = camera_encoder_config.get_codec_options(encoder_threads, as_strings=True)
|
||||
|
||||
# Set logging level
|
||||
if log_level is not None:
|
||||
@@ -494,7 +757,10 @@ def encode_video_frames(
|
||||
|
||||
|
||||
def concatenate_video_files(
|
||||
input_video_paths: list[Path | str], output_video_path: Path, overwrite: bool = True
|
||||
input_video_paths: list[Path | str],
|
||||
output_video_path: Path,
|
||||
overwrite: bool = True,
|
||||
compatibility_check: bool = False,
|
||||
):
|
||||
"""
|
||||
Concatenate multiple video files into a single video file using pyav.
|
||||
@@ -507,6 +773,7 @@ def concatenate_video_files(
|
||||
input_video_paths: Ordered list of input video file paths to concatenate.
|
||||
output_video_path: Path to the output video file.
|
||||
overwrite: Whether to overwrite the output video file if it already exists. Default is True.
|
||||
compatibility_check: Whether to check if the input videos are compatible. Default is False.
|
||||
|
||||
Note:
|
||||
- Creates a temporary directory for intermediate files that is cleaned up after use.
|
||||
@@ -525,6 +792,22 @@ def concatenate_video_files(
|
||||
if len(input_video_paths) == 0:
|
||||
raise FileNotFoundError("No input video paths provided.")
|
||||
|
||||
# This check may be skipped at recording time as videos are encoded with the same encoder config.
|
||||
if compatibility_check:
|
||||
reference_video_info = get_video_info(input_video_paths[0])
|
||||
for input_path in input_video_paths[1:]:
|
||||
video_info = get_video_info(input_path)
|
||||
if (
|
||||
video_info["video.height"] != reference_video_info["video.height"]
|
||||
or video_info["video.width"] != reference_video_info["video.width"]
|
||||
or video_info["video.fps"] != reference_video_info["video.fps"]
|
||||
or video_info["video.codec"] != reference_video_info["video.codec"]
|
||||
or video_info["video.pix_fmt"] != reference_video_info["video.pix_fmt"]
|
||||
):
|
||||
raise ValueError(
|
||||
f"Input video {input_path} is not compatible with the reference video {input_video_paths[0]}."
|
||||
)
|
||||
|
||||
# Create a temporary .ffconcat file to list the input video paths
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".ffconcat", delete=False) as tmp_concatenate_file:
|
||||
tmp_concatenate_file.write("ffconcat version 1.0\n")
|
||||
@@ -591,33 +874,31 @@ class _CameraEncoderThread(threading.Thread):
|
||||
fps: int,
|
||||
vcodec: str,
|
||||
pix_fmt: str,
|
||||
g: int | None,
|
||||
crf: int | None,
|
||||
preset: int | None,
|
||||
codec_options: dict[str, str],
|
||||
frame_queue: queue.Queue,
|
||||
result_queue: queue.Queue,
|
||||
stop_event: threading.Event,
|
||||
encoder_threads: int | None = None,
|
||||
depth_encoder_config: "DepthEncoderConfig | None" = None,
|
||||
):
|
||||
super().__init__(daemon=True)
|
||||
self.video_path = video_path
|
||||
self.fps = fps
|
||||
self.vcodec = vcodec
|
||||
self.pix_fmt = pix_fmt
|
||||
self.g = g
|
||||
self.crf = crf
|
||||
self.preset = preset
|
||||
self.codec_options = codec_options
|
||||
self.frame_queue = frame_queue
|
||||
self.result_queue = result_queue
|
||||
self.stop_event = stop_event
|
||||
self.encoder_threads = encoder_threads
|
||||
self.depth_encoder_config = depth_encoder_config
|
||||
|
||||
|
||||
def run(self) -> None:
|
||||
from .compute_stats import RunningQuantileStats, auto_downsample_height_width
|
||||
|
||||
container = None
|
||||
output_stream = None
|
||||
stats_tracker = RunningQuantileStats()
|
||||
is_depth = self.depth_encoder_config is not None
|
||||
stats_tracker = RunningQuantileStats() if not is_depth else None
|
||||
frame_count = 0
|
||||
|
||||
try:
|
||||
@@ -635,51 +916,45 @@ class _CameraEncoderThread(threading.Thread):
|
||||
# Sentinel: flush and close
|
||||
break
|
||||
|
||||
# Ensure HWC uint8 numpy array
|
||||
# Ensure HWC (RGB or depth) uint8 (RGB only) numpy array
|
||||
if isinstance(frame_data, np.ndarray):
|
||||
if frame_data.ndim == 3 and frame_data.shape[0] == 3:
|
||||
# CHW -> HWC
|
||||
frame_data = frame_data.transpose(1, 2, 0)
|
||||
if frame_data.dtype != np.uint8:
|
||||
if frame_data.dtype != np.uint8 and not is_depth:
|
||||
frame_data = (frame_data * 255).astype(np.uint8)
|
||||
|
||||
# Open container on first frame (to get width/height)
|
||||
if container is None:
|
||||
height, width = frame_data.shape[:2]
|
||||
video_options = _get_codec_options(self.vcodec, self.g, self.crf, self.preset)
|
||||
if self.encoder_threads is not None:
|
||||
if self.vcodec == "libsvtav1":
|
||||
lp_param = f"lp={self.encoder_threads}"
|
||||
if "svtav1-params" in video_options:
|
||||
video_options["svtav1-params"] += f":{lp_param}"
|
||||
else:
|
||||
video_options["svtav1-params"] = lp_param
|
||||
else:
|
||||
video_options["threads"] = str(self.encoder_threads)
|
||||
Path(self.video_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
container = av.open(str(self.video_path), "w")
|
||||
output_stream = container.add_stream(self.vcodec, self.fps, options=video_options)
|
||||
output_stream = container.add_stream(self.vcodec, self.fps, options=self.codec_options)
|
||||
output_stream.pix_fmt = self.pix_fmt
|
||||
output_stream.width = width
|
||||
output_stream.height = height
|
||||
output_stream.time_base = Fraction(1, self.fps)
|
||||
|
||||
# Encode frame with explicit timestamps
|
||||
pil_img = Image.fromarray(frame_data)
|
||||
video_frame = av.VideoFrame.from_image(pil_img)
|
||||
if is_depth:
|
||||
video_frame = encode_depth_frame_pyav(frame_data, pix_fmt=self.pix_fmt, depth_min=self.depth_encoder_config.depth_min, depth_max=self.depth_encoder_config.depth_max, shift=self.depth_encoder_config.shift, use_log=self.depth_encoder_config.use_log)
|
||||
else:
|
||||
pil_img = Image.fromarray(frame_data)
|
||||
video_frame = av.VideoFrame.from_image(pil_img)
|
||||
video_frame.pts = frame_count
|
||||
video_frame.time_base = Fraction(1, self.fps)
|
||||
packet = output_stream.encode(video_frame)
|
||||
if packet:
|
||||
container.mux(packet)
|
||||
|
||||
# Update stats with downsampled frame (per-channel stats like compute_episode_stats)
|
||||
img_chw = frame_data.transpose(2, 0, 1) # HWC -> CHW
|
||||
img_downsampled = auto_downsample_height_width(img_chw)
|
||||
# Reshape CHW to (H*W, C) for per-channel stats
|
||||
channels = img_downsampled.shape[0]
|
||||
img_for_stats = img_downsampled.transpose(1, 2, 0).reshape(-1, channels)
|
||||
stats_tracker.update(img_for_stats)
|
||||
if not is_depth:
|
||||
# Update stats with downsampled frame (per-channel stats like compute_episode_stats)
|
||||
img_chw = frame_data.transpose(2, 0, 1) # HWC -> CHW
|
||||
img_downsampled = auto_downsample_height_width(img_chw)
|
||||
# Reshape CHW to (H*W, C) for per-channel stats
|
||||
channels = img_downsampled.shape[0]
|
||||
img_for_stats = img_downsampled.transpose(1, 2, 0).reshape(-1, channels)
|
||||
stats_tracker.update(img_for_stats)
|
||||
|
||||
frame_count += 1
|
||||
|
||||
@@ -694,8 +969,10 @@ class _CameraEncoderThread(threading.Thread):
|
||||
|
||||
av.logging.restore_default_callback()
|
||||
|
||||
# Get stats and put on result queue
|
||||
if frame_count >= 2:
|
||||
# Get stats and put on result queue (depth streams skip stats)
|
||||
if is_depth:
|
||||
self.result_queue.put(("ok", None))
|
||||
elif frame_count >= 2:
|
||||
stats = stats_tracker.get_statistics()
|
||||
self.result_queue.put(("ok", stats))
|
||||
else:
|
||||
@@ -724,22 +1001,40 @@ class StreamingVideoEncoder:
|
||||
def __init__(
|
||||
self,
|
||||
fps: int,
|
||||
vcodec: str = "libsvtav1",
|
||||
pix_fmt: str = "yuv420p",
|
||||
g: int | None = 2,
|
||||
crf: int | None = 30,
|
||||
preset: int | None = None,
|
||||
queue_maxsize: int = 30,
|
||||
camera_encoder_config: VideoEncoderConfig | None = None,
|
||||
encoder_threads: int | None = None,
|
||||
*,
|
||||
queue_maxsize: int = 30,
|
||||
depth_encoder_config: "DepthEncoderConfig | None" = None,
|
||||
depth_keys: list[str] | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
fps: Frames per second for the output videos.
|
||||
camera_encoder_config: Video encoder settings applied to all cameras.
|
||||
When ``None``, :class:`VideoEncoderConfig` defaults are used.
|
||||
encoder_threads: Number of encoder threads (global setting).
|
||||
``None`` lets the codec decide.
|
||||
queue_maxsize: Max frames to buffer per camera before
|
||||
back-pressure drops frames.
|
||||
depth_encoder_config: Optional depth encoder configuration applied
|
||||
to all depth video keys listed in ``depth_keys``.
|
||||
depth_keys: Video keys (matching the dataset feature names) that
|
||||
must be encoded as quantized depth maps using
|
||||
``depth_encoder_config``. Required when ``depth_encoder_config``
|
||||
is provided.
|
||||
"""
|
||||
self.fps = fps
|
||||
self.vcodec = resolve_vcodec(vcodec)
|
||||
self.pix_fmt = pix_fmt
|
||||
self.g = g
|
||||
self.crf = crf
|
||||
self.preset = preset
|
||||
self._camera_encoder_config = camera_encoder_config or VideoEncoderConfig()
|
||||
self._encoder_threads = encoder_threads
|
||||
self.queue_maxsize = queue_maxsize
|
||||
self.encoder_threads = encoder_threads
|
||||
self._depth_encoder_config = depth_encoder_config
|
||||
self._depth_keys: set[str] = set(depth_keys or [])
|
||||
if self._depth_keys and self._depth_encoder_config is None:
|
||||
raise ValueError(
|
||||
"StreamingVideoEncoder received depth_keys without a depth_encoder_config; "
|
||||
"either pass a DepthEncoderConfig or remove depth_keys."
|
||||
)
|
||||
|
||||
self._frame_queues: dict[str, queue.Queue] = {}
|
||||
self._result_queues: dict[str, queue.Queue] = {}
|
||||
@@ -770,18 +1065,28 @@ class StreamingVideoEncoder:
|
||||
temp_video_dir = Path(tempfile.mkdtemp(dir=temp_dir))
|
||||
video_path = temp_video_dir / f"{video_key.replace('/', '_')}_streaming.mp4"
|
||||
|
||||
is_depth_key = video_key in self._depth_keys
|
||||
encoder_cfg: VideoEncoderConfig
|
||||
depth_cfg = None
|
||||
if is_depth_key:
|
||||
assert self._depth_encoder_config is not None # guaranteed by __init__
|
||||
encoder_cfg = self._depth_encoder_config
|
||||
depth_cfg = self._depth_encoder_config
|
||||
else:
|
||||
encoder_cfg = self._camera_encoder_config
|
||||
|
||||
vcodec = encoder_cfg.vcodec
|
||||
codec_options = encoder_cfg.get_codec_options(self._encoder_threads)
|
||||
encoder_thread = _CameraEncoderThread(
|
||||
video_path=video_path,
|
||||
fps=self.fps,
|
||||
vcodec=self.vcodec,
|
||||
pix_fmt=self.pix_fmt,
|
||||
g=self.g,
|
||||
crf=self.crf,
|
||||
preset=self.preset,
|
||||
vcodec=vcodec,
|
||||
pix_fmt=encoder_cfg.pix_fmt,
|
||||
codec_options=codec_options,
|
||||
frame_queue=frame_queue,
|
||||
result_queue=result_queue,
|
||||
stop_event=stop_event,
|
||||
encoder_threads=self.encoder_threads,
|
||||
depth_encoder_config=depth_cfg,
|
||||
)
|
||||
encoder_thread.start()
|
||||
|
||||
@@ -986,8 +1291,18 @@ def get_audio_info(video_path: Path | str) -> dict:
|
||||
return audio_info
|
||||
|
||||
|
||||
def get_video_info(video_path: Path | str) -> dict:
|
||||
# Set logging level
|
||||
def get_video_info(
|
||||
video_path: Path | str,
|
||||
video_encoder_config: "VideoEncoderConfig | None" = None,
|
||||
) -> dict:
|
||||
"""Build the ``video.*`` / ``audio.*`` info dict persisted in ``info.json``.
|
||||
|
||||
Args:
|
||||
video_path: Path to the encoded video file to probe.
|
||||
video_encoder_config: If provided, record the exact encoder settings used to encode this
|
||||
video. Stream-derived values take precedence — encoder fields are only written for keys
|
||||
not already populated from the video file itself.
|
||||
"""
|
||||
logging.getLogger("libav").setLevel(av.logging.WARNING)
|
||||
|
||||
# Getting video stream information
|
||||
@@ -1004,7 +1319,6 @@ def get_video_info(video_path: Path | str) -> dict:
|
||||
video_info["video.width"] = video_stream.width
|
||||
video_info["video.codec"] = video_stream.codec.canonical_name
|
||||
video_info["video.pix_fmt"] = video_stream.pix_fmt
|
||||
video_info["video.is_depth_map"] = False
|
||||
|
||||
# Calculate fps from r_frame_rate
|
||||
video_info["video.fps"] = int(video_stream.base_rate)
|
||||
@@ -1018,9 +1332,67 @@ def get_video_info(video_path: Path | str) -> dict:
|
||||
# Adding audio stream information
|
||||
video_info.update(**get_audio_info(video_path))
|
||||
|
||||
# Add additional encoder configuration if provided (no override of stream-derived values)
|
||||
# Depth related fields flow naturally through this path.
|
||||
if video_encoder_config is not None:
|
||||
for field_name, field_value in asdict(video_encoder_config).items():
|
||||
video_info.setdefault(f"video.{field_name}", field_value)
|
||||
|
||||
# Fallback case where no encoder config is provided or the video is not a depth map.
|
||||
video_info.setdefault("video.is_depth_map", False)
|
||||
|
||||
return video_info
|
||||
|
||||
|
||||
# ─── Depth metadata helpers (reader side) ────────────────────────────
|
||||
|
||||
|
||||
_DEPTH_INFO_KEYS: tuple[str, ...] = (
|
||||
"video.depth_min",
|
||||
"video.depth_max",
|
||||
"video.shift",
|
||||
"video.use_log",
|
||||
)
|
||||
|
||||
|
||||
def seed_depth_feature_info(
|
||||
features: dict[str, dict],
|
||||
depth_encoder_config: "DepthEncoderConfig | None",
|
||||
) -> None:
|
||||
"""Pre-populate per-feature ``video.<field>`` entries from *depth_encoder_config*.
|
||||
|
||||
``update_video_info`` only runs after the first episode video is encoded,
|
||||
so without this seeding step ``features[key]["info"]`` carries no
|
||||
quantization range until then. Consumers that read the dataset feature
|
||||
spec mid-recording (e.g. the rerun visualizer pinning the depth colormap
|
||||
to ``video.depth_min`` / ``video.depth_max``) would otherwise see no
|
||||
range during episode 1 and re-normalize per frame.
|
||||
|
||||
Stream-derived values written later by :func:`get_video_info` /
|
||||
``update_video_info`` win over these seeds (the merge is
|
||||
``{**existing, **stream_info}``), so callers can safely re-run this on
|
||||
a partially-populated info dict.
|
||||
|
||||
No-op when ``depth_encoder_config`` is ``None`` or no feature is flagged
|
||||
as a depth map.
|
||||
"""
|
||||
if depth_encoder_config is None:
|
||||
return
|
||||
encoder_fields = {
|
||||
f"video.{name}": value for name, value in asdict(depth_encoder_config).items()
|
||||
}
|
||||
for ft in features.values():
|
||||
if ft.get("dtype") != "video":
|
||||
continue
|
||||
info = ft.get("info") or {}
|
||||
if not info.get("video.is_depth_map", False):
|
||||
continue
|
||||
# Only fill fields not already set, so explicit user-provided info is preserved.
|
||||
for k, v in encoder_fields.items():
|
||||
info.setdefault(k, v)
|
||||
ft["info"] = info
|
||||
|
||||
|
||||
def get_video_pixel_channels(pix_fmt: str) -> int:
|
||||
if "gray" in pix_fmt or "depth" in pix_fmt or "monochrome" in pix_fmt:
|
||||
return 1
|
||||
|
||||
@@ -68,9 +68,16 @@ class SOFollower(Robot):
|
||||
|
||||
@property
|
||||
def _cameras_ft(self) -> dict[str, tuple]:
|
||||
return {
|
||||
cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras
|
||||
}
|
||||
features: dict[str, tuple] = {}
|
||||
for cam in self.cameras:
|
||||
cam_cfg = self.config.cameras[cam]
|
||||
features[cam] = (cam_cfg.height, cam_cfg.width, 3)
|
||||
# Cameras with a depth stream (e.g. RealSense with use_depth=True) also
|
||||
# emit a 2D depth feature; hw_to_dataset_features routes 2D shapes to
|
||||
# ``observation.depth.<bare>`` with the depth-map marker.
|
||||
if getattr(cam_cfg, "use_depth", False):
|
||||
features[f"{cam}_depth"] = (cam_cfg.height, cam_cfg.width)
|
||||
return features
|
||||
|
||||
@cached_property
|
||||
def observation_features(self) -> dict[str, type | tuple]:
|
||||
@@ -190,6 +197,14 @@ class SOFollower(Robot):
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms")
|
||||
|
||||
# Cameras with a depth stream populate a sibling ``<cam>_depth`` key
|
||||
# (consumed by hw_to_dataset_features / build_dataset_frame).
|
||||
if getattr(self.config.cameras[cam_key], "use_depth", False):
|
||||
start = time.perf_counter()
|
||||
obs_dict[f"{cam_key}_depth"] = cam.read_latest_depth()
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read {cam_key} depth: {dt_ms:.1f}ms")
|
||||
|
||||
return obs_dict
|
||||
|
||||
@check_if_not_connected
|
||||
|
||||
@@ -49,6 +49,14 @@ Delete episodes and save to a new dataset at a specific path and with a new repo
|
||||
--operation.type delete_episodes \
|
||||
--operation.episode_indices "[0, 2, 5]"
|
||||
|
||||
Delete episodes and re-encode video segments with h264:
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type delete_episodes \
|
||||
--operation.episode_indices "[0, 2, 5]" \
|
||||
--operation.camera_encoder_config.vcodec h264 \
|
||||
--operation.camera_encoder_config.crf 23
|
||||
|
||||
Split dataset by fractions (pusht_train, pusht_val):
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
@@ -74,6 +82,14 @@ Split into more than two splits:
|
||||
--operation.type split \
|
||||
--operation.splits '{"train": 0.6, "val": 0.2, "test": 0.2}'
|
||||
|
||||
Split dataset and re-encode video segments with h264:
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type split \
|
||||
--operation.splits '{"train": 0.8, "val": 0.2}' \
|
||||
--operation.camera_encoder_config.vcodec h264 \
|
||||
--operation.camera_encoder_config.crf 23
|
||||
|
||||
Merge multiple datasets:
|
||||
lerobot-edit-dataset \
|
||||
--new_repo_id lerobot/pusht_merged \
|
||||
@@ -187,7 +203,7 @@ import abc
|
||||
import logging
|
||||
import shutil
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
import draccus
|
||||
@@ -195,6 +211,8 @@ import draccus
|
||||
from lerobot.configs import parser
|
||||
from lerobot.datasets import (
|
||||
LeRobotDataset,
|
||||
VideoEncoderConfig,
|
||||
camera_encoder_defaults,
|
||||
convert_image_to_video_dataset,
|
||||
delete_episodes,
|
||||
merge_datasets,
|
||||
@@ -218,12 +236,14 @@ class OperationConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||
@dataclass
|
||||
class DeleteEpisodesConfig(OperationConfig):
|
||||
episode_indices: list[int] | None = None
|
||||
camera_encoder_config: VideoEncoderConfig = field(default_factory=camera_encoder_defaults)
|
||||
|
||||
|
||||
@OperationConfig.register_subclass("split")
|
||||
@dataclass
|
||||
class SplitConfig(OperationConfig):
|
||||
splits: dict[str, float | list[int]] | None = None
|
||||
camera_encoder_config: VideoEncoderConfig = field(default_factory=camera_encoder_defaults)
|
||||
|
||||
|
||||
@OperationConfig.register_subclass("merge")
|
||||
@@ -250,11 +270,7 @@ class ModifyTasksConfig(OperationConfig):
|
||||
@dataclass
|
||||
class ConvertImageToVideoConfig(OperationConfig):
|
||||
output_dir: str | None = None
|
||||
vcodec: str = "libsvtav1"
|
||||
pix_fmt: str = "yuv420p"
|
||||
g: int = 2
|
||||
crf: int = 30
|
||||
fast_decode: int = 0
|
||||
camera_encoder_config: VideoEncoderConfig = field(default_factory=camera_encoder_defaults)
|
||||
episode_indices: list[int] | None = None
|
||||
num_workers: int = 4
|
||||
max_episodes_per_batch: int | None = None
|
||||
@@ -356,6 +372,7 @@ def handle_delete_episodes(cfg: EditDatasetConfig) -> None:
|
||||
episode_indices=cfg.operation.episode_indices,
|
||||
output_dir=output_dir,
|
||||
repo_id=output_repo_id,
|
||||
camera_encoder_config=cfg.operation.camera_encoder_config,
|
||||
)
|
||||
|
||||
logging.info(f"Dataset saved to {output_dir}")
|
||||
@@ -387,6 +404,7 @@ def handle_split(cfg: EditDatasetConfig) -> None:
|
||||
dataset,
|
||||
splits=cfg.operation.splits,
|
||||
output_dir=cfg.new_root,
|
||||
camera_encoder_config=cfg.operation.camera_encoder_config,
|
||||
)
|
||||
|
||||
for split_name, split_ds in split_datasets.items():
|
||||
@@ -557,11 +575,8 @@ def handle_convert_image_to_video(cfg: EditDatasetConfig) -> None:
|
||||
dataset=dataset,
|
||||
output_dir=output_dir,
|
||||
repo_id=output_repo_id,
|
||||
vcodec=getattr(cfg.operation, "vcodec", "libsvtav1"),
|
||||
pix_fmt=getattr(cfg.operation, "pix_fmt", "yuv420p"),
|
||||
g=getattr(cfg.operation, "g", 2),
|
||||
crf=getattr(cfg.operation, "crf", 30),
|
||||
fast_decode=getattr(cfg.operation, "fast_decode", 0),
|
||||
camera_encoder_config=getattr(cfg.operation, "camera_encoder_config", None)
|
||||
or camera_encoder_defaults(),
|
||||
episode_indices=getattr(cfg.operation, "episode_indices", None),
|
||||
num_workers=getattr(cfg.operation, "num_workers", 4),
|
||||
max_episodes_per_batch=getattr(cfg.operation, "max_episodes_per_batch", None),
|
||||
|
||||
@@ -63,6 +63,27 @@ lerobot-record \\
|
||||
--dataset.streaming_encoding=true \\
|
||||
--dataset.encoder_threads=2
|
||||
```
|
||||
|
||||
Example recording with custom video encoding parameters:
|
||||
```shell
|
||||
lerobot-record \\
|
||||
--robot.type=so100_follower \\
|
||||
--robot.port=/dev/tty.usbmodem58760431541 \\
|
||||
--robot.cameras="{laptop: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \\
|
||||
--robot.id=black \\
|
||||
--teleop.type=so100_leader \\
|
||||
--teleop.port=/dev/tty.usbmodem58760431551 \\
|
||||
--teleop.id=blue \\
|
||||
--dataset.repo_id=<my_username>/<my_dataset_name> \\
|
||||
--dataset.num_episodes=2 \\
|
||||
--dataset.single_task="Grab the cube" \\
|
||||
--dataset.streaming_encoding=true \\
|
||||
--dataset.encoder_threads=2 \\
|
||||
--dataset.camera_encoder_config.vcodec=h264 \\
|
||||
--dataset.camera_encoder_config.preset=fast \\
|
||||
--dataset.camera_encoder_config.extra_options={"tune": "film", "profile:v": "high", "bf": 2} \\
|
||||
--display_data=true
|
||||
```
|
||||
"""
|
||||
|
||||
import logging
|
||||
@@ -83,10 +104,12 @@ from lerobot.common.control_utils import (
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.dataset import DatasetRecordConfig
|
||||
from lerobot.datasets import (
|
||||
DepthEncoderConfig,
|
||||
LeRobotDataset,
|
||||
VideoEncodingManager,
|
||||
aggregate_pipeline_dataset_features,
|
||||
create_initial_features,
|
||||
depth_encoder_defaults,
|
||||
safe_stop_image_writer,
|
||||
)
|
||||
from lerobot.processor import (
|
||||
@@ -305,7 +328,10 @@ def record_loop(
|
||||
|
||||
if display_data:
|
||||
log_rerun_data(
|
||||
observation=obs_processed, action=action_values, compress_images=display_compressed_images
|
||||
observation=obs_processed,
|
||||
action=action_values,
|
||||
compress_images=display_compressed_images,
|
||||
features=dataset.features if dataset is not None else None,
|
||||
)
|
||||
|
||||
dt_s = time.perf_counter() - start_loop_t
|
||||
@@ -377,10 +403,11 @@ def record(
|
||||
cfg.dataset.repo_id,
|
||||
root=cfg.dataset.root,
|
||||
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
||||
vcodec=cfg.dataset.vcodec,
|
||||
camera_encoder_config=cfg.dataset.camera_encoder_config,
|
||||
depth_encoder_config=cfg.dataset.depth_encoder_config,
|
||||
encoder_threads=cfg.dataset.encoder_threads,
|
||||
streaming_encoding=cfg.dataset.streaming_encoding,
|
||||
encoder_queue_maxsize=cfg.dataset.encoder_queue_maxsize,
|
||||
encoder_threads=cfg.dataset.encoder_threads,
|
||||
image_writer_processes=cfg.dataset.num_image_writer_processes if num_cameras > 0 else 0,
|
||||
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera * num_cameras
|
||||
if num_cameras > 0
|
||||
@@ -406,10 +433,11 @@ def record(
|
||||
image_writer_processes=cfg.dataset.num_image_writer_processes,
|
||||
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot.cameras),
|
||||
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
||||
vcodec=cfg.dataset.vcodec,
|
||||
camera_encoder_config=cfg.dataset.camera_encoder_config,
|
||||
depth_encoder_config=cfg.dataset.depth_encoder_config,
|
||||
encoder_threads=cfg.dataset.encoder_threads,
|
||||
streaming_encoding=cfg.dataset.streaming_encoding,
|
||||
encoder_queue_maxsize=cfg.dataset.encoder_queue_maxsize,
|
||||
encoder_threads=cfg.dataset.encoder_threads,
|
||||
)
|
||||
|
||||
robot.connect()
|
||||
@@ -420,7 +448,7 @@ def record(
|
||||
|
||||
if not cfg.dataset.streaming_encoding:
|
||||
logging.info(
|
||||
"Streaming encoding is disabled. If you have capable hardware, consider enabling it for way faster episode saving. --dataset.streaming_encoding=true --dataset.encoder_threads=2 # --dataset.vcodec=auto. More info in the documentation: https://huggingface.co/docs/lerobot/streaming_video_encoding"
|
||||
"Streaming encoding is disabled. If you have capable hardware, consider enabling it for way faster episode saving. --dataset.streaming_encoding=true --dataset.encoder_threads=2 # --dataset.camera_encoder_config.vcodec=auto. More info in the documentation: https://huggingface.co/docs/lerobot/streaming_video_encoding"
|
||||
)
|
||||
|
||||
with VideoEncodingManager(dataset):
|
||||
|
||||
@@ -86,11 +86,24 @@ def hw_to_dataset_features(
|
||||
}
|
||||
|
||||
for key, shape in cam_fts.items():
|
||||
features[f"{prefix}.images.{key}"] = {
|
||||
"dtype": "video" if use_video else "image",
|
||||
"shape": shape,
|
||||
"names": ["height", "width", "channels"],
|
||||
}
|
||||
if len(shape) == 2:
|
||||
# Single-channel feature (e.g. depth map). The hardware-side key is
|
||||
# expected to use a "_depth" suffix to disambiguate from its color
|
||||
# counterpart; we strip it so the dataset feature is published as
|
||||
# ``{prefix}.depth.<bare>`` and aligned with ``observation.images.<bare>``.
|
||||
bare = key.removesuffix("_depth") if key.endswith("_depth") else key
|
||||
features[f"{prefix}.depth.{bare}"] = {
|
||||
"dtype": "video" if use_video else "image",
|
||||
"shape": shape,
|
||||
"names": ["height", "width"],
|
||||
"info": {"video.is_depth_map": True},
|
||||
}
|
||||
else:
|
||||
features[f"{prefix}.images.{key}"] = {
|
||||
"dtype": "video" if use_video else "image",
|
||||
"shape": shape,
|
||||
"names": ["height", "width", "channels"],
|
||||
}
|
||||
|
||||
_validate_feature_names(features)
|
||||
return features
|
||||
@@ -120,7 +133,14 @@ def build_dataset_frame(
|
||||
elif ft["dtype"] == "float32" and len(ft["shape"]) == 1:
|
||||
frame[key] = np.array([values[name] for name in ft["names"]], dtype=np.float32)
|
||||
elif ft["dtype"] in ["image", "video"]:
|
||||
frame[key] = values[key.removeprefix(f"{prefix}.images.")]
|
||||
if key.startswith(f"{prefix}.depth."):
|
||||
bare = key.removeprefix(f"{prefix}.depth.")
|
||||
# Hardware emits depth values under "<bare>_depth" to disambiguate
|
||||
# from the color stream stored at "<bare>" — fall back to the bare
|
||||
# name when the producer already uses dataset-style keys.
|
||||
frame[key] = values.get(f"{bare}_depth", values.get(bare))
|
||||
else:
|
||||
frame[key] = values[key.removeprefix(f"{prefix}.images.")]
|
||||
|
||||
return frame
|
||||
|
||||
|
||||
@@ -69,7 +69,7 @@ def is_package_available(
|
||||
return package_exists
|
||||
|
||||
|
||||
def get_safe_default_codec():
|
||||
def get_safe_default_video_backend():
|
||||
logger = logging.getLogger(__name__)
|
||||
if importlib.util.find_spec("torchcodec"):
|
||||
return "torchcodec"
|
||||
|
||||
@@ -63,10 +63,56 @@ def _is_scalar(x):
|
||||
)
|
||||
|
||||
|
||||
def _derive_depth_obs_ranges(
|
||||
features: dict[str, dict] | None,
|
||||
) -> dict[str, tuple[float, float] | None]:
|
||||
"""Map observation keys of depth features to their ``(depth_min, depth_max)`` range.
|
||||
|
||||
A feature is considered a depth map when its ``info`` dict carries
|
||||
``video.is_depth_map=True`` (the marker set by ``hw_to_dataset_features``
|
||||
and persisted in ``info.json``). For each such feature, we record both
|
||||
the fully-namespaced dataset key (e.g. ``observation.depth.front``) and
|
||||
the corresponding raw observation key forms the robot is likely to emit
|
||||
(``front`` and ``front_depth``) so a single membership check covers all
|
||||
call sites.
|
||||
|
||||
The mapped value is the ``(depth_min, depth_max)`` range stored on the
|
||||
feature (matching the quantization range used at encoding time), or
|
||||
``None`` when the metadata doesn't expose a range — in which case the
|
||||
caller should let Rerun auto-normalize. Anchoring the colormap to a
|
||||
fixed range avoids per-frame re-normalization, which otherwise looks
|
||||
like flicker on near-static scenes.
|
||||
"""
|
||||
ranges: dict[str, tuple[float, float] | None] = {}
|
||||
if not features:
|
||||
return ranges
|
||||
depth_prefix = f"{OBS_STR}.depth."
|
||||
for fk, fv in features.items():
|
||||
info = fv.get("info") if isinstance(fv, dict) else None
|
||||
if not isinstance(info, dict) or not info.get("video.is_depth_map", False):
|
||||
continue
|
||||
depth_min = info.get("video.depth_min")
|
||||
depth_max = info.get("video.depth_max")
|
||||
rng: tuple[float, float] | None = None
|
||||
if (
|
||||
isinstance(depth_min, (int, float))
|
||||
and isinstance(depth_max, (int, float))
|
||||
and depth_max > depth_min
|
||||
):
|
||||
rng = (float(depth_min), float(depth_max))
|
||||
ranges[fk] = rng
|
||||
if fk.startswith(depth_prefix):
|
||||
bare = fk[len(depth_prefix) :]
|
||||
ranges[bare] = rng
|
||||
ranges[f"{bare}_depth"] = rng
|
||||
return ranges
|
||||
|
||||
|
||||
def log_rerun_data(
|
||||
observation: RobotObservation | None = None,
|
||||
action: RobotAction | None = None,
|
||||
compress_images: bool = False,
|
||||
features: dict[str, dict] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Logs observation and action data to Rerun for real-time visualization.
|
||||
@@ -76,6 +122,13 @@ def log_rerun_data(
|
||||
- Scalars values (floats, ints) are logged as `rr.Scalars`.
|
||||
- 3D NumPy arrays that resemble images (e.g., with 1, 3, or 4 channels first) are transposed
|
||||
from CHW to HWC format, (optionally) compressed to JPEG and logged as `rr.Image` or `rr.EncodedImage`.
|
||||
- 2D NumPy arrays whose key matches a depth feature in ``features`` (i.e. carrying
|
||||
``video.is_depth_map=True``) are logged as ``rr.DepthImage`` with the Viridis
|
||||
colormap and ``meter=1.0`` (depth values are expected in metric meters). When
|
||||
the feature exposes ``video.depth_min`` / ``video.depth_max`` (the encoder
|
||||
quantization range, persisted in ``info.json``), the colormap is anchored to
|
||||
that range via ``depth_range`` to keep the visualization stable across frames.
|
||||
Depth images are never JPEG-compressed regardless of ``compress_images``.
|
||||
- 1D NumPy arrays are logged as a series of individual scalars, with each element indexed.
|
||||
- Other multi-dimensional arrays are flattened and logged as individual scalars.
|
||||
|
||||
@@ -85,11 +138,16 @@ def log_rerun_data(
|
||||
observation: An optional dictionary containing observation data to log.
|
||||
action: An optional dictionary containing action data to log.
|
||||
compress_images: Whether to compress images before logging to save bandwidth & memory in exchange for cpu and quality.
|
||||
features: Optional dataset feature spec (e.g. ``LeRobotDataset.features``). When
|
||||
provided, observation entries matching a depth-map feature are rendered with
|
||||
``rr.DepthImage`` instead of the generic ``rr.Image`` path.
|
||||
"""
|
||||
|
||||
require_package("rerun-sdk", extra="viz", import_name="rerun")
|
||||
import rerun as rr
|
||||
|
||||
depth_obs_ranges = _derive_depth_obs_ranges(features)
|
||||
|
||||
if observation:
|
||||
for k, v in observation.items():
|
||||
if v is None:
|
||||
@@ -100,6 +158,20 @@ def log_rerun_data(
|
||||
rr.log(key, rr.Scalars(float(v)))
|
||||
elif isinstance(v, np.ndarray):
|
||||
arr = v
|
||||
is_depth = bool(depth_obs_ranges) and (k in depth_obs_ranges or key in depth_obs_ranges)
|
||||
if is_depth and arr.ndim == 2:
|
||||
# Viridis-colormapped DepthImage; never JPEG-compress (lossy on float metric depth).
|
||||
# Anchor the colormap to the encoder range when available, so the
|
||||
# visualization doesn't flicker as per-frame min/max drift.
|
||||
depth_range = depth_obs_ranges.get(k) or depth_obs_ranges.get(key)
|
||||
depth_kwargs: dict = {
|
||||
"meter": 1.0,
|
||||
"colormap": rr.components.Colormap.Viridis,
|
||||
}
|
||||
if depth_range is not None:
|
||||
depth_kwargs["depth_range"] = depth_range
|
||||
rr.log(key, rr.DepthImage(arr, **depth_kwargs), static=True)
|
||||
continue
|
||||
# Convert CHW -> HWC when needed
|
||||
if arr.ndim == 3 and arr.shape[0] in (1, 3, 4) and arr.shape[-1] not in (1, 3, 4):
|
||||
arr = np.transpose(arr, (1, 2, 0))
|
||||
|
||||
@@ -202,6 +202,31 @@ def test_read_latest_too_old():
|
||||
_ = camera.read_latest(max_age_ms=0) # immediately too old
|
||||
|
||||
|
||||
def test_async_read_depth_without_use_depth_raises():
|
||||
"""``async_read_depth`` must reject cameras configured without ``use_depth=True``."""
|
||||
config = RealSenseCameraConfig(serial_number_or_name="042", warmup_s=0)
|
||||
with RealSenseCamera(config) as camera, pytest.raises(RuntimeError, match="use_depth=False"):
|
||||
_ = camera.async_read_depth()
|
||||
|
||||
|
||||
def test_read_latest_depth_without_use_depth_raises():
|
||||
"""``read_latest_depth`` must reject cameras configured without ``use_depth=True``."""
|
||||
config = RealSenseCameraConfig(serial_number_or_name="042", warmup_s=0)
|
||||
with RealSenseCamera(config) as camera, pytest.raises(RuntimeError, match="use_depth=False"):
|
||||
_ = camera.read_latest_depth()
|
||||
|
||||
|
||||
def test_depth_to_meters_uses_depth_scale():
|
||||
"""``_depth_to_meters`` must scale uint16 raw depth into float32 metric meters."""
|
||||
config = RealSenseCameraConfig(serial_number_or_name="042", warmup_s=0)
|
||||
camera = RealSenseCamera(config)
|
||||
camera.depth_scale = 0.001 # typical D-series scale (1 mm/unit)
|
||||
raw = np.array([[0, 1000, 2500], [4095, 65535, 0]], dtype=np.uint16)
|
||||
meters = camera._depth_to_meters(raw)
|
||||
assert meters.dtype == np.float32
|
||||
np.testing.assert_allclose(meters, raw.astype(np.float32) * 0.001)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"rotation",
|
||||
[
|
||||
|
||||
@@ -142,6 +142,36 @@ def test_create_without_videos_has_no_video_path(tmp_path):
|
||||
assert meta.video_keys == []
|
||||
|
||||
|
||||
def test_depth_keys_property_filters_by_marker(tmp_path):
|
||||
"""``depth_keys`` selects only video features carrying ``video.is_depth_map=True``."""
|
||||
features = {
|
||||
**SIMPLE_FEATURES,
|
||||
"observation.images.cam": {
|
||||
"dtype": "video",
|
||||
"shape": (64, 96, 3),
|
||||
"names": ["height", "width", "channels"],
|
||||
"info": None,
|
||||
},
|
||||
"observation.depth.cam": {
|
||||
"dtype": "video",
|
||||
"shape": (64, 96),
|
||||
"names": ["height", "width"],
|
||||
"info": {"video.is_depth_map": True},
|
||||
},
|
||||
}
|
||||
meta = LeRobotDatasetMetadata.create(
|
||||
repo_id="test/depth_keys", fps=DEFAULT_FPS, features=features, root=tmp_path / "depth_keys"
|
||||
)
|
||||
|
||||
assert set(meta.video_keys) == {"observation.images.cam", "observation.depth.cam"}
|
||||
assert meta.depth_keys == ["observation.depth.cam"]
|
||||
|
||||
def test_depth_keys_empty_when_no_marker(tmp_path):
|
||||
meta = LeRobotDatasetMetadata.create(
|
||||
repo_id="test/no_depth", fps=DEFAULT_FPS, features=VIDEO_FEATURES, root=tmp_path / "no_depth"
|
||||
)
|
||||
assert meta.depth_keys == []
|
||||
|
||||
def test_create_raises_on_existing_directory(tmp_path):
|
||||
"""create() raises if root directory already exists."""
|
||||
root = tmp_path / "existing"
|
||||
|
||||
@@ -20,7 +20,7 @@ import pytest
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
|
||||
from lerobot.datasets.dataset_reader import DatasetReader
|
||||
from lerobot.utils.import_utils import get_safe_default_codec
|
||||
from lerobot.utils.import_utils import get_safe_default_video_backend
|
||||
|
||||
# ── Loading ──────────────────────────────────────────────────────────
|
||||
|
||||
@@ -35,7 +35,7 @@ def test_try_load_returns_true_when_data_exists(tmp_path, lerobot_dataset_factor
|
||||
root=dataset.root,
|
||||
episodes=None,
|
||||
tolerance_s=1e-4,
|
||||
video_backend=get_safe_default_codec(),
|
||||
video_backend=get_safe_default_video_backend(),
|
||||
delta_timestamps=None,
|
||||
image_transforms=None,
|
||||
)
|
||||
@@ -58,7 +58,7 @@ def test_try_load_returns_false_when_no_data(tmp_path):
|
||||
root=meta.root,
|
||||
episodes=None,
|
||||
tolerance_s=1e-4,
|
||||
video_backend=get_safe_default_codec(),
|
||||
video_backend=get_safe_default_video_backend(),
|
||||
delta_timestamps=None,
|
||||
image_transforms=None,
|
||||
)
|
||||
|
||||
@@ -25,6 +25,7 @@ pytest.importorskip("datasets", reason="datasets is required (install lerobot[da
|
||||
|
||||
from lerobot.datasets.dataset_tools import (
|
||||
add_features,
|
||||
convert_image_to_video_dataset,
|
||||
delete_episodes,
|
||||
merge_datasets,
|
||||
modify_features,
|
||||
@@ -32,7 +33,7 @@ from lerobot.datasets.dataset_tools import (
|
||||
remove_feature,
|
||||
split_dataset,
|
||||
)
|
||||
from lerobot.scripts.lerobot_edit_dataset import convert_image_to_video_dataset
|
||||
from lerobot.datasets.video_utils import VideoEncoderConfig
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -1246,10 +1247,12 @@ def test_convert_image_to_video_dataset(tmp_path):
|
||||
dataset=source_dataset,
|
||||
output_dir=output_dir,
|
||||
repo_id="lerobot/pusht_video",
|
||||
vcodec="libsvtav1",
|
||||
pix_fmt="yuv420p",
|
||||
g=2,
|
||||
crf=30,
|
||||
camera_encoder_config=VideoEncoderConfig(
|
||||
vcodec="libsvtav1",
|
||||
pix_fmt="yuv420p",
|
||||
g=2,
|
||||
crf=30,
|
||||
),
|
||||
episode_indices=[0, 1],
|
||||
num_workers=2,
|
||||
)
|
||||
|
||||
@@ -28,6 +28,7 @@ pytest.importorskip("datasets", reason="datasets is required (install lerobot[da
|
||||
from lerobot.datasets.dataset_writer import _encode_video_worker
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.utils import DEFAULT_IMAGE_PATH
|
||||
from lerobot.datasets.video_utils import VideoEncoderConfig
|
||||
from tests.fixtures.constants import DEFAULT_FPS, DUMMY_REPO_ID
|
||||
|
||||
SIMPLE_FEATURES = {
|
||||
@@ -52,8 +53,8 @@ def _make_frame(features: dict, task: str = "Dummy task") -> dict:
|
||||
# ── Existing encode_video_worker tests ───────────────────────────────
|
||||
|
||||
|
||||
def test_encode_video_worker_forwards_vcodec(tmp_path):
|
||||
"""_encode_video_worker correctly forwards the vcodec parameter."""
|
||||
def test_encode_video_worker_forwards_camera_encoder_config(tmp_path):
|
||||
"""_encode_video_worker forwards camera_encoder_config to encode_video_frames."""
|
||||
video_key = "observation.images.laptop"
|
||||
fpath = DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=0, frame_index=0)
|
||||
img_dir = tmp_path / Path(fpath).parent
|
||||
@@ -68,13 +69,21 @@ def test_encode_video_worker_forwards_vcodec(tmp_path):
|
||||
Path(video_path).touch()
|
||||
|
||||
with patch("lerobot.datasets.dataset_writer.encode_video_frames", side_effect=mock_encode):
|
||||
_encode_video_worker(video_key, 0, tmp_path, fps=30, vcodec="h264")
|
||||
_encode_video_worker(
|
||||
video_key,
|
||||
0,
|
||||
tmp_path,
|
||||
fps=30,
|
||||
camera_encoder_config=VideoEncoderConfig(vcodec="h264", preset=None),
|
||||
encoder_threads=4,
|
||||
)
|
||||
|
||||
assert captured_kwargs["vcodec"] == "h264"
|
||||
assert captured_kwargs["camera_encoder_config"].vcodec == "h264"
|
||||
assert captured_kwargs["encoder_threads"] == 4
|
||||
|
||||
|
||||
def test_encode_video_worker_default_vcodec(tmp_path):
|
||||
"""_encode_video_worker uses libsvtav1 as the default codec."""
|
||||
def test_encode_video_worker_default_camera_encoder_config(tmp_path):
|
||||
"""_encode_video_worker passes None camera_encoder_config which encode_video_frames defaults."""
|
||||
video_key = "observation.images.laptop"
|
||||
fpath = DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=0, frame_index=0)
|
||||
img_dir = tmp_path / Path(fpath).parent
|
||||
@@ -91,7 +100,8 @@ def test_encode_video_worker_default_vcodec(tmp_path):
|
||||
with patch("lerobot.datasets.dataset_writer.encode_video_frames", side_effect=mock_encode):
|
||||
_encode_video_worker(video_key, 0, tmp_path, fps=30)
|
||||
|
||||
assert captured_kwargs["vcodec"] == "libsvtav1"
|
||||
assert captured_kwargs["camera_encoder_config"] is None
|
||||
assert captured_kwargs["encoder_threads"] is None
|
||||
|
||||
|
||||
# ── add_frame contracts ──────────────────────────────────────────────
|
||||
|
||||
@@ -43,7 +43,7 @@ from lerobot.datasets.utils import (
|
||||
DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
||||
create_branch,
|
||||
)
|
||||
from lerobot.datasets.video_utils import VALID_VIDEO_CODECS
|
||||
from lerobot.datasets.video_utils import VALID_VIDEO_CODECS, VideoEncoderConfig
|
||||
from lerobot.envs.factory import make_env_config
|
||||
from lerobot.policies.factory import make_policy_config
|
||||
from lerobot.robots import make_robot_from_config
|
||||
@@ -1470,17 +1470,9 @@ def test_frames_in_current_file_calculation(tmp_path, empty_lerobot_dataset_fact
|
||||
|
||||
|
||||
def test_lerobot_dataset_vcodec_validation():
|
||||
"""Test that LeRobotDataset validates the vcodec parameter."""
|
||||
# Test that invalid vcodec raises ValueError
|
||||
"""Invalid vcodec in encoder config is rejected at construction time."""
|
||||
with pytest.raises(ValueError, match="Invalid vcodec"):
|
||||
LeRobotDataset.__new__(LeRobotDataset) # bypass __init__ to test validation directly
|
||||
# Actually test via create since it's easier
|
||||
LeRobotDataset.create(
|
||||
repo_id="test/invalid_codec",
|
||||
fps=30,
|
||||
features={"observation.state": {"dtype": "float32", "shape": (2,), "names": ["x", "y"]}},
|
||||
vcodec="invalid_codec",
|
||||
)
|
||||
VideoEncoderConfig(vcodec="invalid_codec")
|
||||
|
||||
|
||||
def test_valid_video_codecs_constant():
|
||||
@@ -1491,7 +1483,8 @@ def test_valid_video_codecs_constant():
|
||||
assert "auto" in VALID_VIDEO_CODECS
|
||||
assert "h264_videotoolbox" in VALID_VIDEO_CODECS
|
||||
assert "h264_nvenc" in VALID_VIDEO_CODECS
|
||||
assert len(VALID_VIDEO_CODECS) == 10
|
||||
assert "ffv1" in VALID_VIDEO_CODECS
|
||||
assert len(VALID_VIDEO_CODECS) == 11
|
||||
|
||||
|
||||
def test_delta_timestamps_with_episodes_filter(tmp_path, empty_lerobot_dataset_factory):
|
||||
|
||||
@@ -93,9 +93,32 @@ def test_image_array_to_pil_image_pytorch_format(img_array_factory):
|
||||
|
||||
|
||||
def test_image_array_to_pil_image_single_channel(img_array_factory):
|
||||
# Single-channel inputs are routed to grayscale mode for raw depth maps.
|
||||
img_array = img_array_factory(channels=1)
|
||||
with pytest.raises(NotImplementedError):
|
||||
image_array_to_pil_image(img_array)
|
||||
result_image = image_array_to_pil_image(img_array)
|
||||
assert isinstance(result_image, Image.Image)
|
||||
assert result_image.size == (100, 100)
|
||||
assert result_image.mode == "L"
|
||||
assert np.array_equal(np.array(result_image), img_array.squeeze(-1))
|
||||
|
||||
|
||||
def test_image_array_to_pil_image_single_channel_uint16(img_array_factory):
|
||||
img_array = img_array_factory(channels=1, dtype=np.uint16)
|
||||
result_image = image_array_to_pil_image(img_array)
|
||||
assert isinstance(result_image, Image.Image)
|
||||
assert result_image.size == (100, 100)
|
||||
assert result_image.mode == "I;16"
|
||||
# Bit-perfect: no rescaling, no clipping.
|
||||
assert np.array_equal(np.array(result_image), img_array.squeeze(-1))
|
||||
|
||||
|
||||
def test_image_array_to_pil_image_single_channel_float32(img_array_factory):
|
||||
img_array = img_array_factory(channels=1, dtype=np.float32)
|
||||
result_image = image_array_to_pil_image(img_array)
|
||||
assert isinstance(result_image, Image.Image)
|
||||
assert result_image.size == (100, 100)
|
||||
assert result_image.mode == "F"
|
||||
assert np.array_equal(np.array(result_image), img_array.squeeze(-1))
|
||||
|
||||
|
||||
def test_image_array_to_pil_image_4_channels(img_array_factory):
|
||||
@@ -141,6 +164,28 @@ def test_write_image_image(tmp_path, img_factory):
|
||||
assert np.array_equal(image_pil, saved_image)
|
||||
|
||||
|
||||
def test_write_image_tiff_uint16_bitperfect(tmp_path):
|
||||
"""16-bit grayscale TIFF round-trips bit-perfectly (raw depth maps)."""
|
||||
image_array = np.random.randint(0, 65535, size=(32, 48), dtype=np.uint16)
|
||||
fpath = tmp_path / "depth.tiff"
|
||||
write_image(image_array, fpath)
|
||||
assert fpath.exists()
|
||||
saved = np.array(Image.open(fpath))
|
||||
assert saved.dtype == np.uint16
|
||||
assert np.array_equal(saved, image_array)
|
||||
|
||||
|
||||
def test_write_image_tiff_float32_bitperfect(tmp_path):
|
||||
"""Float32 TIFF round-trips bit-perfectly (metric depth in meters)."""
|
||||
image_array = np.random.uniform(0.05, 4.0, size=(32, 48)).astype(np.float32)
|
||||
fpath = tmp_path / "depth.tiff"
|
||||
write_image(image_array, fpath)
|
||||
assert fpath.exists()
|
||||
saved = np.array(Image.open(fpath))
|
||||
assert saved.dtype == np.float32
|
||||
assert np.array_equal(saved, image_array)
|
||||
|
||||
|
||||
def test_write_image_exception(tmp_path):
|
||||
image_array = "invalid data"
|
||||
fpath = tmp_path / DUMMY_IMAGE
|
||||
|
||||
@@ -436,6 +436,37 @@ def test_add_frame_works_in_write_mode(tmp_path):
|
||||
dataset.add_frame(_make_frame()) # should not raise
|
||||
|
||||
|
||||
# ── Depth-feature plumbing ───────────────────────────────────────────
|
||||
|
||||
|
||||
_DEPTH_FEATURES = {
|
||||
**SIMPLE_FEATURES,
|
||||
"observation.depth": {
|
||||
"dtype": "video",
|
||||
"shape": (32, 32),
|
||||
"names": ["height", "width"],
|
||||
"info": {"video.is_depth_map": True},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def test_create_with_depth_streaming_succeeds(tmp_path):
|
||||
"""A depth dataset with streaming_encoding=True is created in write mode."""
|
||||
from lerobot.datasets.video_utils import DepthEncoderConfig
|
||||
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=DUMMY_REPO_ID,
|
||||
fps=DEFAULT_FPS,
|
||||
features=_DEPTH_FEATURES,
|
||||
root=tmp_path / "depth_ds",
|
||||
depth_encoder_config=DepthEncoderConfig(),
|
||||
streaming_encoding=True,
|
||||
)
|
||||
assert isinstance(dataset.writer, DatasetWriter)
|
||||
assert dataset.meta.depth_keys == ["observation.depth"]
|
||||
assert dataset._depth_encoder_config is not None
|
||||
|
||||
|
||||
# ── Resume mode ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
|
||||
@@ -14,11 +14,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for streaming video encoding and hardware-accelerated encoding."""
|
||||
"""Tests for streaming video encoding."""
|
||||
|
||||
import queue
|
||||
import threading
|
||||
from unittest.mock import patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
@@ -27,112 +26,20 @@ pytest.importorskip("av", reason="av is required (install lerobot[dataset])")
|
||||
|
||||
import av # noqa: E402
|
||||
|
||||
from lerobot.datasets.pyav_utils import get_codec
|
||||
from lerobot.datasets.video_utils import (
|
||||
VALID_VIDEO_CODECS,
|
||||
StreamingVideoEncoder,
|
||||
VideoEncoderConfig,
|
||||
_CameraEncoderThread,
|
||||
_get_codec_options,
|
||||
detect_available_hw_encoders,
|
||||
resolve_vcodec,
|
||||
)
|
||||
from lerobot.utils.constants import OBS_IMAGES
|
||||
|
||||
# ─── _get_codec_options tests ───
|
||||
|
||||
|
||||
class TestGetCodecOptions:
|
||||
def test_libsvtav1_defaults(self):
|
||||
opts = _get_codec_options("libsvtav1")
|
||||
assert opts["g"] == "2"
|
||||
assert opts["crf"] == "30"
|
||||
assert opts["preset"] == "12"
|
||||
|
||||
def test_libsvtav1_custom_preset(self):
|
||||
opts = _get_codec_options("libsvtav1", preset=8)
|
||||
assert opts["preset"] == "8"
|
||||
|
||||
def test_h264_options(self):
|
||||
opts = _get_codec_options("h264", g=10, crf=23)
|
||||
assert opts["g"] == "10"
|
||||
assert opts["crf"] == "23"
|
||||
assert "preset" not in opts
|
||||
|
||||
def test_videotoolbox_options(self):
|
||||
opts = _get_codec_options("h264_videotoolbox", g=2, crf=30)
|
||||
assert opts["g"] == "2"
|
||||
# CRF 30 maps to quality = max(1, min(100, 100 - 30*2)) = 40
|
||||
assert opts["q:v"] == "40"
|
||||
assert "crf" not in opts
|
||||
|
||||
def test_nvenc_options(self):
|
||||
opts = _get_codec_options("h264_nvenc", g=2, crf=25)
|
||||
assert opts["rc"] == "constqp"
|
||||
assert opts["qp"] == "25"
|
||||
assert "crf" not in opts
|
||||
# NVENC doesn't support g
|
||||
assert "g" not in opts
|
||||
|
||||
def test_vaapi_options(self):
|
||||
opts = _get_codec_options("h264_vaapi", crf=28)
|
||||
assert opts["qp"] == "28"
|
||||
|
||||
def test_qsv_options(self):
|
||||
opts = _get_codec_options("h264_qsv", crf=25)
|
||||
assert opts["global_quality"] == "25"
|
||||
|
||||
def test_no_g_no_crf(self):
|
||||
opts = _get_codec_options("h264", g=None, crf=None)
|
||||
assert "g" not in opts
|
||||
assert "crf" not in opts
|
||||
|
||||
|
||||
# ─── HW encoder detection tests ───
|
||||
|
||||
|
||||
class TestHWEncoderDetection:
|
||||
def test_detect_available_hw_encoders_returns_list(self):
|
||||
result = detect_available_hw_encoders()
|
||||
assert isinstance(result, list)
|
||||
|
||||
def test_detect_available_hw_encoders_only_valid(self):
|
||||
from lerobot.datasets.video_utils import HW_ENCODERS
|
||||
|
||||
result = detect_available_hw_encoders()
|
||||
for encoder in result:
|
||||
assert encoder in HW_ENCODERS
|
||||
|
||||
def test_resolve_vcodec_passthrough(self):
|
||||
assert resolve_vcodec("libsvtav1") == "libsvtav1"
|
||||
assert resolve_vcodec("h264") == "h264"
|
||||
|
||||
def test_resolve_vcodec_auto_fallback(self):
|
||||
"""When no HW encoders are available, auto should fall back to libsvtav1."""
|
||||
with patch("lerobot.datasets.video_utils.detect_available_hw_encoders", return_value=[]):
|
||||
assert resolve_vcodec("auto") == "libsvtav1"
|
||||
|
||||
def test_resolve_vcodec_auto_picks_hw(self):
|
||||
"""When a HW encoder is available, auto should pick it."""
|
||||
with patch(
|
||||
"lerobot.datasets.video_utils.detect_available_hw_encoders",
|
||||
return_value=["h264_videotoolbox"],
|
||||
):
|
||||
assert resolve_vcodec("auto") == "h264_videotoolbox"
|
||||
|
||||
def test_resolve_vcodec_auto_returns_valid(self):
|
||||
"""Test that resolve_vcodec('auto') returns a known valid codec."""
|
||||
result = resolve_vcodec("auto")
|
||||
assert result in VALID_VIDEO_CODECS
|
||||
|
||||
def test_hw_encoder_names_accepted_in_validation(self):
|
||||
"""Test that HW encoder names pass validation in VALID_VIDEO_CODECS."""
|
||||
assert "auto" in VALID_VIDEO_CODECS
|
||||
assert "h264_videotoolbox" in VALID_VIDEO_CODECS
|
||||
assert "h264_nvenc" in VALID_VIDEO_CODECS
|
||||
|
||||
def test_resolve_vcodec_invalid_raises(self):
|
||||
"""Test that resolve_vcodec raises ValueError for invalid codecs."""
|
||||
with pytest.raises(ValueError, match="Invalid vcodec"):
|
||||
resolve_vcodec("not_a_real_codec")
|
||||
# Cross-codec validation tests only fire when the target codec is present
|
||||
# in the local FFmpeg build; on other platforms validate() is a no-op.
|
||||
_has_videotoolbox = get_codec("h264_videotoolbox") is not None
|
||||
_videotoolbox_only = pytest.mark.skipif(
|
||||
not _has_videotoolbox, reason="h264_videotoolbox not in local FFmpeg build"
|
||||
)
|
||||
|
||||
|
||||
# ─── _CameraEncoderThread tests ───
|
||||
@@ -150,14 +57,13 @@ class TestCameraEncoderThread:
|
||||
result_queue: queue.Queue = queue.Queue(maxsize=1)
|
||||
stop_event = threading.Event()
|
||||
|
||||
enc_cfg = VideoEncoderConfig(vcodec="libsvtav1", pix_fmt="yuv420p", g=2, crf=30, preset=13)
|
||||
encoder_thread = _CameraEncoderThread(
|
||||
video_path=video_path,
|
||||
fps=fps,
|
||||
vcodec="libsvtav1",
|
||||
pix_fmt="yuv420p",
|
||||
g=2,
|
||||
crf=30,
|
||||
preset=13,
|
||||
vcodec=enc_cfg.vcodec,
|
||||
pix_fmt=enc_cfg.pix_fmt,
|
||||
codec_options=enc_cfg.get_codec_options(),
|
||||
frame_queue=frame_queue,
|
||||
result_queue=result_queue,
|
||||
stop_event=stop_event,
|
||||
@@ -202,14 +108,13 @@ class TestCameraEncoderThread:
|
||||
result_queue: queue.Queue = queue.Queue(maxsize=1)
|
||||
stop_event = threading.Event()
|
||||
|
||||
enc_cfg = VideoEncoderConfig(vcodec="libsvtav1", pix_fmt="yuv420p", g=2, crf=30, preset=13)
|
||||
encoder_thread = _CameraEncoderThread(
|
||||
video_path=video_path,
|
||||
fps=fps,
|
||||
vcodec="libsvtav1",
|
||||
pix_fmt="yuv420p",
|
||||
g=2,
|
||||
crf=30,
|
||||
preset=13,
|
||||
vcodec=enc_cfg.vcodec,
|
||||
pix_fmt=enc_cfg.pix_fmt,
|
||||
codec_options=enc_cfg.get_codec_options(),
|
||||
frame_queue=frame_queue,
|
||||
result_queue=result_queue,
|
||||
stop_event=stop_event,
|
||||
@@ -237,14 +142,13 @@ class TestCameraEncoderThread:
|
||||
result_queue: queue.Queue = queue.Queue(maxsize=1)
|
||||
stop_event = threading.Event()
|
||||
|
||||
enc_cfg = VideoEncoderConfig(vcodec="libsvtav1", pix_fmt="yuv420p", g=2, crf=30, preset=13)
|
||||
encoder_thread = _CameraEncoderThread(
|
||||
video_path=video_path,
|
||||
fps=fps,
|
||||
vcodec="libsvtav1",
|
||||
pix_fmt="yuv420p",
|
||||
g=2,
|
||||
crf=30,
|
||||
preset=13,
|
||||
vcodec=enc_cfg.vcodec,
|
||||
pix_fmt=enc_cfg.pix_fmt,
|
||||
codec_options=enc_cfg.get_codec_options(),
|
||||
frame_queue=frame_queue,
|
||||
result_queue=result_queue,
|
||||
stop_event=stop_event,
|
||||
@@ -266,11 +170,20 @@ class TestCameraEncoderThread:
|
||||
|
||||
|
||||
class TestStreamingVideoEncoder:
|
||||
def _make_encoder_config(self, **kwargs):
|
||||
"""Helper to build a VideoEncoderConfig."""
|
||||
return VideoEncoderConfig(**kwargs)
|
||||
|
||||
def test_single_camera_episode(self, tmp_path):
|
||||
"""Test encoding a single camera episode."""
|
||||
encoder = StreamingVideoEncoder(fps=30, vcodec="libsvtav1", pix_fmt="yuv420p", g=2, crf=30, preset=13)
|
||||
|
||||
video_keys = [f"{OBS_IMAGES}.laptop"]
|
||||
encoder = StreamingVideoEncoder(
|
||||
fps=30,
|
||||
camera_encoder_config=self._make_encoder_config(
|
||||
vcodec="libsvtav1", pix_fmt="yuv420p", g=2, crf=30, preset=13
|
||||
),
|
||||
)
|
||||
|
||||
encoder.start_episode(video_keys, tmp_path)
|
||||
|
||||
num_frames = 20
|
||||
@@ -295,9 +208,13 @@ class TestStreamingVideoEncoder:
|
||||
|
||||
def test_multi_camera_episode(self, tmp_path):
|
||||
"""Test encoding multiple cameras simultaneously."""
|
||||
encoder = StreamingVideoEncoder(fps=30, vcodec="libsvtav1", pix_fmt="yuv420p", g=2, crf=30)
|
||||
|
||||
video_keys = [f"{OBS_IMAGES}.laptop", f"{OBS_IMAGES}.phone"]
|
||||
encoder = StreamingVideoEncoder(
|
||||
fps=30,
|
||||
camera_encoder_config=self._make_encoder_config(
|
||||
vcodec="libsvtav1", pix_fmt="yuv420p", g=2, crf=30
|
||||
),
|
||||
)
|
||||
encoder.start_episode(video_keys, tmp_path)
|
||||
|
||||
num_frames = 15
|
||||
@@ -319,8 +236,13 @@ class TestStreamingVideoEncoder:
|
||||
|
||||
def test_sequential_episodes(self, tmp_path):
|
||||
"""Test that multiple sequential episodes work correctly."""
|
||||
encoder = StreamingVideoEncoder(fps=30, vcodec="libsvtav1", pix_fmt="yuv420p", g=2, crf=30)
|
||||
video_keys = [f"{OBS_IMAGES}.cam"]
|
||||
encoder = StreamingVideoEncoder(
|
||||
fps=30,
|
||||
camera_encoder_config=self._make_encoder_config(
|
||||
vcodec="libsvtav1", pix_fmt="yuv420p", g=2, crf=30
|
||||
),
|
||||
)
|
||||
|
||||
for ep in range(3):
|
||||
encoder.start_episode(video_keys, tmp_path)
|
||||
@@ -342,8 +264,13 @@ class TestStreamingVideoEncoder:
|
||||
|
||||
def test_cancel_episode(self, tmp_path):
|
||||
"""Test that canceling an episode cleans up properly."""
|
||||
encoder = StreamingVideoEncoder(fps=30, vcodec="libsvtav1", pix_fmt="yuv420p", g=2, crf=30)
|
||||
video_keys = [f"{OBS_IMAGES}.cam"]
|
||||
encoder = StreamingVideoEncoder(
|
||||
fps=30,
|
||||
camera_encoder_config=self._make_encoder_config(
|
||||
vcodec="libsvtav1", pix_fmt="yuv420p", g=2, crf=30
|
||||
),
|
||||
)
|
||||
|
||||
encoder.start_episode(video_keys, tmp_path)
|
||||
|
||||
@@ -365,28 +292,33 @@ class TestStreamingVideoEncoder:
|
||||
|
||||
def test_feed_without_start_raises(self, tmp_path):
|
||||
"""Test that feeding frames without starting an episode raises."""
|
||||
encoder = StreamingVideoEncoder(fps=30, vcodec="libsvtav1", pix_fmt="yuv420p")
|
||||
encoder = StreamingVideoEncoder(fps=30)
|
||||
with pytest.raises(RuntimeError, match="No active episode"):
|
||||
encoder.feed_frame("cam", np.zeros((64, 96, 3), dtype=np.uint8))
|
||||
encoder.close()
|
||||
|
||||
def test_finish_without_start_raises(self, tmp_path):
|
||||
"""Test that finishing without starting raises."""
|
||||
encoder = StreamingVideoEncoder(fps=30, vcodec="libsvtav1", pix_fmt="yuv420p")
|
||||
encoder = StreamingVideoEncoder(fps=30)
|
||||
with pytest.raises(RuntimeError, match="No active episode"):
|
||||
encoder.finish_episode()
|
||||
encoder.close()
|
||||
|
||||
def test_close_is_idempotent(self, tmp_path):
|
||||
"""Test that close() can be called multiple times safely."""
|
||||
encoder = StreamingVideoEncoder(fps=30, vcodec="libsvtav1", pix_fmt="yuv420p")
|
||||
encoder = StreamingVideoEncoder(fps=30)
|
||||
encoder.close()
|
||||
encoder.close() # Should not raise
|
||||
|
||||
def test_video_duration_matches_frame_count(self, tmp_path):
|
||||
"""Test that encoded video duration matches num_frames / fps."""
|
||||
encoder = StreamingVideoEncoder(fps=30, vcodec="libsvtav1", pix_fmt="yuv420p", g=2, crf=30, preset=13)
|
||||
video_keys = [f"{OBS_IMAGES}.cam"]
|
||||
encoder = StreamingVideoEncoder(
|
||||
fps=30,
|
||||
camera_encoder_config=self._make_encoder_config(
|
||||
vcodec="libsvtav1", pix_fmt="yuv420p", g=2, crf=30, preset=13
|
||||
),
|
||||
)
|
||||
encoder.start_episode(video_keys, tmp_path)
|
||||
|
||||
num_frames = 90 # 3 seconds at 30fps
|
||||
@@ -417,9 +349,13 @@ class TestStreamingVideoEncoder:
|
||||
|
||||
def test_multi_camera_start_episode_called_once(self, tmp_path):
|
||||
"""Test that with multiple cameras, no frames are lost due to double start_episode."""
|
||||
encoder = StreamingVideoEncoder(fps=30, vcodec="libsvtav1", pix_fmt="yuv420p", g=2, crf=30)
|
||||
|
||||
video_keys = [f"{OBS_IMAGES}.cam1", f"{OBS_IMAGES}.cam2"]
|
||||
encoder = StreamingVideoEncoder(
|
||||
fps=30,
|
||||
camera_encoder_config=self._make_encoder_config(
|
||||
vcodec="libsvtav1", pix_fmt="yuv420p", g=2, crf=30
|
||||
),
|
||||
)
|
||||
encoder.start_episode(video_keys, tmp_path)
|
||||
|
||||
num_frames = 30
|
||||
@@ -446,17 +382,24 @@ class TestStreamingVideoEncoder:
|
||||
|
||||
def test_encoder_threads_passed_to_thread(self, tmp_path):
|
||||
"""Test that encoder_threads is stored and passed through to encoder threads."""
|
||||
encoder = StreamingVideoEncoder(
|
||||
fps=30, vcodec="libsvtav1", pix_fmt="yuv420p", g=2, crf=30, encoder_threads=2
|
||||
)
|
||||
assert encoder.encoder_threads == 2
|
||||
|
||||
video_keys = [f"{OBS_IMAGES}.cam"]
|
||||
cfg = VideoEncoderConfig(
|
||||
vcodec="libsvtav1",
|
||||
pix_fmt="yuv420p",
|
||||
g=2,
|
||||
crf=30,
|
||||
)
|
||||
encoder = StreamingVideoEncoder(
|
||||
fps=30,
|
||||
camera_encoder_config=cfg,
|
||||
encoder_threads=2,
|
||||
)
|
||||
assert encoder._encoder_threads == 2
|
||||
encoder.start_episode(video_keys, tmp_path)
|
||||
|
||||
# Verify the thread received the encoder_threads value
|
||||
# Verify codec options include thread tuning for libsvtav1 (lp=…)
|
||||
thread = encoder._threads[f"{OBS_IMAGES}.cam"]
|
||||
assert thread.encoder_threads == 2
|
||||
assert "svtav1-params" in thread.codec_options or "threads" in thread.codec_options
|
||||
|
||||
# Feed some frames and finish to ensure it works end-to-end
|
||||
num_frames = 10
|
||||
@@ -478,16 +421,20 @@ class TestStreamingVideoEncoder:
|
||||
|
||||
def test_encoder_threads_none_by_default(self, tmp_path):
|
||||
"""Test that encoder_threads defaults to None (codec auto-detect)."""
|
||||
encoder = StreamingVideoEncoder(fps=30, vcodec="libsvtav1", pix_fmt="yuv420p")
|
||||
assert encoder.encoder_threads is None
|
||||
encoder = StreamingVideoEncoder(fps=30)
|
||||
assert encoder._encoder_threads is None
|
||||
encoder.close()
|
||||
|
||||
def test_graceful_frame_dropping(self, tmp_path):
|
||||
"""Test that full queue drops frames instead of crashing."""
|
||||
encoder = StreamingVideoEncoder(
|
||||
fps=30, vcodec="libsvtav1", pix_fmt="yuv420p", g=2, crf=30, preset=13, queue_maxsize=1
|
||||
)
|
||||
video_keys = [f"{OBS_IMAGES}.cam"]
|
||||
encoder = StreamingVideoEncoder(
|
||||
fps=30,
|
||||
camera_encoder_config=self._make_encoder_config(
|
||||
vcodec="libsvtav1", pix_fmt="yuv420p", g=2, crf=30, preset=13
|
||||
),
|
||||
queue_maxsize=1,
|
||||
)
|
||||
encoder.start_episode(video_keys, tmp_path)
|
||||
|
||||
# Feed many frames quickly - with queue_maxsize=1, some will be dropped
|
||||
|
||||
581
tests/datasets/test_video_encoding.py
Normal file
581
tests/datasets/test_video_encoding.py
Normal file
@@ -0,0 +1,581 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Unit tests for ``lerobot.datasets.video_utils`` encoding functions and ``VideoEncoderConfig`` config class."""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
pytest.importorskip("av", reason="av is required (install lerobot[dataset])")
|
||||
|
||||
import av # noqa: E402
|
||||
|
||||
from lerobot.datasets.image_writer import write_image
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.pyav_utils import get_codec
|
||||
from lerobot.datasets.utils import INFO_PATH
|
||||
from lerobot.datasets.video_utils import (
|
||||
VALID_VIDEO_CODECS,
|
||||
VideoEncoderConfig,
|
||||
concatenate_video_files,
|
||||
encode_video_frames,
|
||||
get_video_info,
|
||||
)
|
||||
|
||||
|
||||
# Per-codec skip markers — validation tests only fire when the codec is available
|
||||
def _require_encoder(vcodec: str) -> pytest.MarkDecorator:
|
||||
"""Skip the test if ``vcodec`` is not available in the local FFmpeg build."""
|
||||
return pytest.mark.skipif(get_codec(vcodec) is None, reason=f"{vcodec!r} not in local FFmpeg build")
|
||||
|
||||
|
||||
require_libsvtav1 = _require_encoder("libsvtav1")
|
||||
require_h264 = _require_encoder("h264")
|
||||
require_videotoolbox = _require_encoder("h264_videotoolbox")
|
||||
require_nvenc = _require_encoder("h264_nvenc")
|
||||
require_vaapi = _require_encoder("h264_vaapi")
|
||||
require_qsv = _require_encoder("h264_qsv")
|
||||
|
||||
|
||||
# ─── VideoEncoderConfig / codec options ──────────────────────────────
|
||||
|
||||
|
||||
class TestCodecOptions:
|
||||
@require_libsvtav1
|
||||
def test_libsvtav1_defaults(self):
|
||||
cfg = VideoEncoderConfig()
|
||||
opts = cfg.get_codec_options()
|
||||
assert opts["g"] == 2
|
||||
assert opts["crf"] == 30
|
||||
assert opts["preset"] == 12
|
||||
|
||||
@require_libsvtav1
|
||||
def test_libsvtav1_custom_preset(self):
|
||||
cfg = VideoEncoderConfig(preset=8)
|
||||
assert cfg.get_codec_options()["preset"] == 8
|
||||
|
||||
@require_h264
|
||||
def test_h264_options(self):
|
||||
cfg = VideoEncoderConfig(vcodec="h264", g=10, crf=23, preset=None)
|
||||
opts = cfg.get_codec_options()
|
||||
assert opts["g"] == 10
|
||||
assert opts["crf"] == 23
|
||||
assert "preset" not in opts
|
||||
|
||||
@require_videotoolbox
|
||||
def test_videotoolbox_options(self):
|
||||
cfg = VideoEncoderConfig(vcodec="h264_videotoolbox", g=2, crf=30, preset=None)
|
||||
opts = cfg.get_codec_options()
|
||||
assert opts["g"] == 2
|
||||
assert opts["q:v"] == 40
|
||||
assert "crf" not in opts
|
||||
|
||||
@_require_encoder("h264_nvenc")
|
||||
def test_nvenc_options(self):
|
||||
cfg = VideoEncoderConfig(vcodec="h264_nvenc", g=2, crf=25, preset=None)
|
||||
opts = cfg.get_codec_options()
|
||||
assert opts["rc"] == "constqp"
|
||||
assert opts["qp"] == 25
|
||||
assert "crf" not in opts
|
||||
assert "g" not in opts
|
||||
|
||||
@_require_encoder("h264_vaapi")
|
||||
def test_vaapi_options(self):
|
||||
cfg = VideoEncoderConfig(vcodec="h264_vaapi", crf=28, preset=None)
|
||||
assert cfg.get_codec_options()["qp"] == 28
|
||||
|
||||
@_require_encoder("h264_qsv")
|
||||
def test_qsv_options(self):
|
||||
cfg = VideoEncoderConfig(vcodec="h264_qsv", crf=25, preset=None)
|
||||
assert cfg.get_codec_options()["global_quality"] == 25
|
||||
|
||||
@require_h264
|
||||
def test_no_g_no_crf(self):
|
||||
cfg = VideoEncoderConfig(vcodec="h264", g=None, crf=None, preset=None)
|
||||
opts = cfg.get_codec_options()
|
||||
assert "g" not in opts
|
||||
assert "crf" not in opts
|
||||
|
||||
@require_libsvtav1
|
||||
def test_encoder_threads_libsvtav1(self):
|
||||
cfg = VideoEncoderConfig(fast_decode=0)
|
||||
opts = cfg.get_codec_options(encoder_threads=4)
|
||||
assert "lp=4" in opts.get("svtav1-params", "")
|
||||
|
||||
@require_h264
|
||||
def test_encoder_threads_h264(self):
|
||||
cfg = VideoEncoderConfig(vcodec="h264", preset=None)
|
||||
assert cfg.get_codec_options(encoder_threads=2)["threads"] == 2
|
||||
|
||||
@require_libsvtav1
|
||||
def test_fast_decode_libsvtav1(self):
|
||||
cfg = VideoEncoderConfig(fast_decode=1)
|
||||
opts = cfg.get_codec_options()
|
||||
assert "fast-decode=1" in opts.get("svtav1-params", "")
|
||||
|
||||
@require_libsvtav1
|
||||
def test_libsvtav1_fast_decode_clamped_to_svt_range(self):
|
||||
"""Out-of-range fast_decode is clamped to [0, 2] in svtav1-params (SVT-AV1 FastDecode)."""
|
||||
cfg = VideoEncoderConfig(fast_decode=100)
|
||||
assert "fast-decode=2" in cfg.get_codec_options().get("svtav1-params", "")
|
||||
cfg_neg = VideoEncoderConfig(fast_decode=-5)
|
||||
assert "fast-decode=0" in cfg_neg.get_codec_options().get("svtav1-params", "")
|
||||
|
||||
@require_h264
|
||||
def test_fast_decode_h264(self):
|
||||
cfg = VideoEncoderConfig(vcodec="h264", fast_decode=1, preset=None)
|
||||
assert cfg.get_codec_options()["tune"] == "fastdecode"
|
||||
|
||||
@require_libsvtav1
|
||||
def test_pix_fmt_unsupported_raises(self):
|
||||
"""Passing an unsupported pix_fmt is a hard error."""
|
||||
with pytest.raises(ValueError, match="pix_fmt"):
|
||||
VideoEncoderConfig(pix_fmt="yuv444p") # libsvtav1 only supports yuv420p variants
|
||||
|
||||
@require_libsvtav1
|
||||
@require_h264
|
||||
def test_preset_default_behaviour(self):
|
||||
"""Empty constructor picks preset=12 (libsvtav1 path); other codecs stay None."""
|
||||
assert VideoEncoderConfig().preset == 12
|
||||
assert VideoEncoderConfig(vcodec="libsvtav1").preset == 12
|
||||
assert VideoEncoderConfig(vcodec="h264").preset is None
|
||||
assert VideoEncoderConfig(vcodec="h264", preset=None).preset is None
|
||||
|
||||
@require_h264
|
||||
def test_preset_string_on_h264(self):
|
||||
"""h264 accepts string presets and forwards them to FFmpeg."""
|
||||
cfg = VideoEncoderConfig(vcodec="h264", preset="slow")
|
||||
assert cfg.get_codec_options()["preset"] == "slow"
|
||||
|
||||
@require_videotoolbox
|
||||
def test_preset_on_videotoolbox_not_set(self):
|
||||
"""videotoolbox has no preset option at all."""
|
||||
cfg = VideoEncoderConfig(vcodec="h264_videotoolbox", preset="slow")
|
||||
assert "preset" not in cfg.get_codec_options()
|
||||
|
||||
@require_libsvtav1
|
||||
def test_libsvtav1_preset_out_of_range_raises(self):
|
||||
"""libsvtav1 preset must sit in [-2, 13] as exposed by PyAV."""
|
||||
with pytest.raises(ValueError, match="out of range"):
|
||||
VideoEncoderConfig(vcodec="libsvtav1", preset=100)
|
||||
with pytest.raises(ValueError, match="out of range"):
|
||||
VideoEncoderConfig(vcodec="libsvtav1", preset=-3)
|
||||
|
||||
@require_libsvtav1
|
||||
def test_libsvtav1_crf_out_of_range_raises(self):
|
||||
"""libsvtav1 crf must sit in [0, 63]."""
|
||||
with pytest.raises(ValueError, match="crf.*out of range"):
|
||||
VideoEncoderConfig(vcodec="libsvtav1", crf=64)
|
||||
|
||||
@require_libsvtav1
|
||||
def test_libsvtav1_crf_rejects_python_float(self):
|
||||
"""libsvtav1 exposes ``crf`` as an INT AVOption; Python float must not pass validation."""
|
||||
with pytest.raises(ValueError, match="float values are not allowed"):
|
||||
VideoEncoderConfig(vcodec="libsvtav1", crf=2.5)
|
||||
|
||||
@require_libsvtav1
|
||||
def test_libsvtav1_extra_crf_rejects_fractional_string(self):
|
||||
"""INT options reject fractional values even when supplied only via ``extra_options``."""
|
||||
with pytest.raises(ValueError, match="float values are not allowed"):
|
||||
VideoEncoderConfig(
|
||||
vcodec="libsvtav1",
|
||||
crf=None,
|
||||
extra_options={"crf": "2.5"},
|
||||
)
|
||||
|
||||
@require_libsvtav1
|
||||
def test_libsvtav1_extra_crf_rejects_float(self):
|
||||
with pytest.raises(ValueError, match="float values are not allowed"):
|
||||
VideoEncoderConfig(
|
||||
vcodec="libsvtav1",
|
||||
crf=None,
|
||||
extra_options={"crf": 2.5},
|
||||
)
|
||||
|
||||
@require_h264
|
||||
def test_h264_crf_accepts_float_and_int(self):
|
||||
"""x264 exposes crf as a FLOAT option, so both int and float are accepted."""
|
||||
assert VideoEncoderConfig(vcodec="h264", crf=23).get_codec_options()["crf"] == 23
|
||||
assert VideoEncoderConfig(vcodec="h264", crf=23.5).get_codec_options()["crf"] == 23.5
|
||||
|
||||
@require_libsvtav1
|
||||
def test_validate_is_rerunnable(self):
|
||||
"""After mutating a field, validate() re-checks and surfaces new issues."""
|
||||
cfg = VideoEncoderConfig(vcodec="libsvtav1")
|
||||
cfg.preset = 100 # now out of range
|
||||
with pytest.raises(ValueError, match="out of range"):
|
||||
cfg.validate()
|
||||
|
||||
|
||||
class TestExtraOptions:
|
||||
@require_libsvtav1
|
||||
def test_default_is_empty_dict(self):
|
||||
cfg = VideoEncoderConfig()
|
||||
assert cfg.extra_options == {}
|
||||
|
||||
@require_libsvtav1
|
||||
def test_unknown_key_passes_through(self):
|
||||
"""Keys not published as AVOptions are forwarded to FFmpeg."""
|
||||
cfg = VideoEncoderConfig(extra_options={"totally_made_up_option": "value"})
|
||||
assert cfg.extra_options == {"totally_made_up_option": "value"}
|
||||
|
||||
@require_libsvtav1
|
||||
def test_numeric_value_in_range_ok(self):
|
||||
"""libsvtav1 exposes ``qp`` as INT in [0, 63]."""
|
||||
cfg = VideoEncoderConfig(extra_options={"qp": 30})
|
||||
assert cfg.extra_options == {"qp": 30}
|
||||
|
||||
@require_libsvtav1
|
||||
def test_numeric_out_of_range_raises(self):
|
||||
with pytest.raises(ValueError, match=r"extra_options\['qp'\].*out of range"):
|
||||
VideoEncoderConfig(extra_options={"qp": 999})
|
||||
|
||||
@require_libsvtav1
|
||||
def test_numeric_string_accepted_in_range(self):
|
||||
"""Numeric strings are accepted for numeric options (mirrors FFmpeg)."""
|
||||
cfg = VideoEncoderConfig(extra_options={"qp": "18"})
|
||||
assert cfg.extra_options == {"qp": "18"}
|
||||
|
||||
@require_libsvtav1
|
||||
def test_numeric_string_out_of_range_raises(self):
|
||||
with pytest.raises(ValueError, match=r"extra_options\['qp'\].*out of range"):
|
||||
VideoEncoderConfig(extra_options={"qp": "999"})
|
||||
|
||||
@require_libsvtav1
|
||||
def test_non_numeric_string_on_numeric_option_raises(self):
|
||||
with pytest.raises(ValueError, match=r"extra_options\['qp'\].*not numeric"):
|
||||
VideoEncoderConfig(extra_options={"qp": "medium"})
|
||||
|
||||
@require_libsvtav1
|
||||
def test_bool_on_numeric_option_raises(self):
|
||||
"""``bool`` is explicitly rejected for numeric options."""
|
||||
with pytest.raises(ValueError, match=r"extra_options\['qp'\].*not numeric"):
|
||||
VideoEncoderConfig(extra_options={"qp": True})
|
||||
|
||||
@require_h264
|
||||
def test_string_option_passes_through_unchecked(self):
|
||||
"""String-typed AVOptions are NOT enum-checked (too many accept freeform)."""
|
||||
cfg = VideoEncoderConfig(vcodec="h264", preset=None, extra_options={"tune": "some-future-tune"})
|
||||
assert cfg.extra_options == {"tune": "some-future-tune"}
|
||||
|
||||
@require_libsvtav1
|
||||
def test_merged_into_codec_options_and_stringified(self):
|
||||
"""Typed merge by default; ``as_strings=True`` matches FFmpeg option dict."""
|
||||
cfg = VideoEncoderConfig(extra_options={"qp": 20})
|
||||
opts = cfg.get_codec_options()
|
||||
assert opts["qp"] == 20
|
||||
assert isinstance(opts["qp"], int)
|
||||
assert cfg.get_codec_options(as_strings=True)["qp"] == "20"
|
||||
|
||||
@require_libsvtav1
|
||||
def test_structured_fields_win_on_collision(self):
|
||||
"""A colliding extra_options key is discarded; the structured field wins."""
|
||||
cfg = VideoEncoderConfig(crf=30, extra_options={"crf": 18})
|
||||
assert cfg.get_codec_options()["crf"] == 30
|
||||
|
||||
|
||||
class TestEncoderDetection:
|
||||
@require_h264
|
||||
def test_explicit_codec_kept_when_available(self):
|
||||
cfg = VideoEncoderConfig(vcodec="h264")
|
||||
assert cfg.vcodec == "h264"
|
||||
|
||||
@require_videotoolbox
|
||||
def test_auto_picks_videotoolbox_when_available(self):
|
||||
"""``h264_videotoolbox`` sits at the top of ``HW_ENCODERS`` so it wins when present."""
|
||||
cfg = VideoEncoderConfig(vcodec="auto")
|
||||
assert cfg.vcodec == "h264_videotoolbox"
|
||||
|
||||
def test_invalid_codec_raises(self):
|
||||
with pytest.raises(ValueError, match="Invalid vcodec"):
|
||||
VideoEncoderConfig(vcodec="not_a_real_codec")
|
||||
|
||||
def test_hw_encoder_names_listed_as_valid(self):
|
||||
assert "auto" in VALID_VIDEO_CODECS
|
||||
assert "h264_videotoolbox" in VALID_VIDEO_CODECS
|
||||
assert "h264_nvenc" in VALID_VIDEO_CODECS
|
||||
|
||||
def test_av1_alias_resolves_to_libsvtav1(self):
|
||||
"""Older datasets persist ``vcodec="av1"``; backward-compat alias must keep them loadable."""
|
||||
cfg = VideoEncoderConfig(vcodec="av1")
|
||||
assert cfg.vcodec == "libsvtav1"
|
||||
|
||||
def test_av1_alias_persisted_after_resolve(self):
|
||||
"""Repeated calls to ``resolve_vcodec`` should be idempotent (alias only fires once)."""
|
||||
cfg = VideoEncoderConfig(vcodec="av1")
|
||||
assert cfg.vcodec == "libsvtav1"
|
||||
cfg.resolve_vcodec()
|
||||
assert cfg.vcodec == "libsvtav1"
|
||||
|
||||
|
||||
ARTIFACTS = Path(__file__).parent.parent / "fixtures" / "artifacts" / "videos"
|
||||
|
||||
# Default video feature set used by persistence tests.
|
||||
VIDEO_FEATURES = {
|
||||
"observation.images.cam": {
|
||||
"dtype": "video",
|
||||
"shape": (64, 96, 3),
|
||||
"names": ["height", "width", "channels"],
|
||||
},
|
||||
"action": {"dtype": "float32", "shape": (2,), "names": ["a", "b"]},
|
||||
}
|
||||
VIDEO_KEY = "observation.images.cam"
|
||||
|
||||
|
||||
def _write_frames(imgs_dir: Path, num_frames: int = 4, height: int = 64, width: int = 96) -> None:
|
||||
imgs_dir.mkdir(parents=True, exist_ok=True)
|
||||
for i in range(num_frames):
|
||||
arr = np.random.randint(0, 256, (height, width, 3), dtype=np.uint8)
|
||||
write_image(arr, imgs_dir / f"frame-{i:06d}.png")
|
||||
|
||||
|
||||
def _encode_video(
|
||||
path: Path, num_frames: int = 4, fps: int = 30, cfg: VideoEncoderConfig | None = None
|
||||
) -> Path:
|
||||
imgs_dir = path.parent / f"imgs_{path.stem}"
|
||||
_write_frames(imgs_dir, num_frames=num_frames)
|
||||
encode_video_frames(imgs_dir, path, fps=fps, camera_encoder_config=cfg, overwrite=True)
|
||||
return path
|
||||
|
||||
|
||||
def _read_feature_info(dataset: LeRobotDataset) -> dict:
|
||||
info = json.loads((dataset.root / INFO_PATH).read_text())
|
||||
return info["features"][VIDEO_KEY]["info"]
|
||||
|
||||
|
||||
def _add_frames(dataset: LeRobotDataset, num_frames: int) -> None:
|
||||
shape = dataset.meta.features[VIDEO_KEY]["shape"]
|
||||
for _ in range(num_frames):
|
||||
dataset.add_frame(
|
||||
{
|
||||
VIDEO_KEY: np.random.randint(0, 256, shape, dtype=np.uint8),
|
||||
"action": np.zeros(2, dtype=np.float32),
|
||||
"task": "test",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class TestGetVideoInfo:
|
||||
def test_returns_all_stream_fields(self):
|
||||
info = get_video_info(ARTIFACTS / "clip_4frames.mp4")
|
||||
|
||||
assert info["video.height"] == 64
|
||||
assert info["video.width"] == 96
|
||||
assert info["video.pix_fmt"] == "yuv420p"
|
||||
assert info["video.fps"] == 30
|
||||
assert info["video.channels"] == 3
|
||||
assert info["video.is_depth_map"] is False
|
||||
assert info["has_audio"] is False
|
||||
assert "video.g" not in info
|
||||
assert "video.crf" not in info
|
||||
assert "video.preset" not in info
|
||||
|
||||
@require_libsvtav1
|
||||
def test_merges_encoder_config_as_video_prefixed_entries(self):
|
||||
cfg = VideoEncoderConfig(vcodec="libsvtav1", g=2, crf=30, preset=12)
|
||||
|
||||
info = get_video_info(ARTIFACTS / "clip_4frames.mp4", camera_encoder_config=cfg)
|
||||
|
||||
assert info["video.g"] == 2
|
||||
assert info["video.crf"] == 30
|
||||
assert info["video.preset"] == 12
|
||||
assert info["video.fast_decode"] == 0
|
||||
assert info["video.video_backend"] == "pyav"
|
||||
assert info["video.extra_options"] == {}
|
||||
|
||||
@require_libsvtav1
|
||||
def test_stream_derived_keys_take_precedence_over_config(self):
|
||||
cfg = VideoEncoderConfig(vcodec="libsvtav1", pix_fmt="yuv420p")
|
||||
|
||||
info = get_video_info(ARTIFACTS / "clip_4frames.mp4", camera_encoder_config=cfg)
|
||||
|
||||
assert info["video.codec"] # populated from stream, not from config's vcodec
|
||||
assert info["video.pix_fmt"] == "yuv420p"
|
||||
|
||||
|
||||
class TestEncodeVideoFrames:
|
||||
@require_libsvtav1
|
||||
def test_produces_readable_mp4(self, tmp_path):
|
||||
video_path = _encode_video(tmp_path / "out.mp4")
|
||||
|
||||
assert video_path.exists()
|
||||
info = get_video_info(video_path)
|
||||
assert info["video.height"] == 64
|
||||
assert info["video.width"] == 96
|
||||
|
||||
@require_libsvtav1
|
||||
def test_frame_count_and_duration_match_input(self, tmp_path):
|
||||
num_frames = 10
|
||||
fps = 30
|
||||
video_path = _encode_video(tmp_path / "out.mp4", num_frames=num_frames, fps=fps)
|
||||
|
||||
with av.open(str(video_path)) as container:
|
||||
stream = container.streams.video[0]
|
||||
actual_frames = sum(1 for _ in container.decode(stream))
|
||||
duration = (
|
||||
float(stream.duration * stream.time_base)
|
||||
if stream.duration is not None
|
||||
else float(container.duration / av.time_base)
|
||||
)
|
||||
|
||||
assert actual_frames == num_frames
|
||||
assert abs(duration - num_frames / fps) < 0.1
|
||||
|
||||
def test_overwrite_false_skips_existing_file(self, tmp_path):
|
||||
imgs_dir = tmp_path / "imgs"
|
||||
_write_frames(imgs_dir)
|
||||
video_path = tmp_path / "out.mp4"
|
||||
sentinel = b"pre-existing content"
|
||||
video_path.write_bytes(sentinel)
|
||||
|
||||
encode_video_frames(imgs_dir, video_path, fps=30, overwrite=False)
|
||||
|
||||
assert video_path.read_bytes() == sentinel
|
||||
|
||||
@require_libsvtav1
|
||||
def test_overwrite_true_replaces_existing_file(self, tmp_path):
|
||||
imgs_dir = tmp_path / "imgs"
|
||||
_write_frames(imgs_dir)
|
||||
video_path = tmp_path / "out.mp4"
|
||||
video_path.write_bytes(b"stale content")
|
||||
|
||||
encode_video_frames(imgs_dir, video_path, fps=30, overwrite=True)
|
||||
|
||||
info = get_video_info(video_path)
|
||||
assert info["video.height"] == 64
|
||||
|
||||
@require_libsvtav1
|
||||
def test_custom_encoder_config_fields_stored_in_info(self, tmp_path):
|
||||
"""All stream-derived and encoder config fields are present after encoding."""
|
||||
cfg = VideoEncoderConfig(vcodec="libsvtav1", g=4, crf=25, preset=10)
|
||||
video_path = _encode_video(tmp_path / "out.mp4", num_frames=4, fps=30, cfg=cfg)
|
||||
|
||||
info = get_video_info(video_path, camera_encoder_config=cfg)
|
||||
|
||||
# Stream-derived
|
||||
assert info["video.height"] == 64
|
||||
assert info["video.width"] == 96
|
||||
assert info["video.channels"] == 3
|
||||
assert info["video.codec"] == "av1"
|
||||
assert info["video.pix_fmt"] == "yuv420p"
|
||||
assert info["video.fps"] == 30
|
||||
assert info["video.is_depth_map"] is False
|
||||
assert info["has_audio"] is False
|
||||
# Encoder config
|
||||
assert info["video.g"] == 4
|
||||
assert info["video.crf"] == 25
|
||||
assert info["video.preset"] == 10
|
||||
assert info["video.fast_decode"] == 0
|
||||
assert info["video.video_backend"] == "pyav"
|
||||
assert info["video.extra_options"] == {}
|
||||
|
||||
|
||||
class TestConcatenateVideoFiles:
|
||||
def test_two_clips_frame_count(self, tmp_path):
|
||||
"""Output frame count equals the sum of the two input frame counts."""
|
||||
out = tmp_path / "out.mp4"
|
||||
concatenate_video_files([ARTIFACTS / "clip_6frames.mp4", ARTIFACTS / "clip_4frames.mp4"], out)
|
||||
|
||||
with av.open(str(out)) as container:
|
||||
total = sum(1 for _ in container.decode(video=0))
|
||||
assert total == 10
|
||||
|
||||
def test_three_clips_frame_count(self, tmp_path):
|
||||
out = tmp_path / "out.mp4"
|
||||
clip = ARTIFACTS / "clip_5frames.mp4"
|
||||
concatenate_video_files([clip, clip, clip], out)
|
||||
|
||||
with av.open(str(out)) as container:
|
||||
total = sum(1 for _ in container.decode(video=0))
|
||||
assert total == 15
|
||||
|
||||
@require_libsvtav1
|
||||
def test_geometry_preserved(self, tmp_path):
|
||||
"""Output resolution, fps, codec and pixel format must match the inputs."""
|
||||
out = tmp_path / "out.mp4"
|
||||
concatenate_video_files([ARTIFACTS / "clip_4frames.mp4", ARTIFACTS / "clip_4frames.mp4"], out)
|
||||
|
||||
info = get_video_info(out)
|
||||
assert info["video.height"] == 64
|
||||
assert info["video.width"] == 96
|
||||
assert info["video.fps"] == 30
|
||||
assert info["video.codec"] == "av1"
|
||||
assert info["video.pix_fmt"] == "yuv420p"
|
||||
|
||||
def test_compatibility_check_raises_on_different_codec(self, tmp_path):
|
||||
with pytest.raises(ValueError):
|
||||
concatenate_video_files(
|
||||
[ARTIFACTS / "clip_4frames.mp4", ARTIFACTS / "clip_h264.mp4"],
|
||||
tmp_path / "out.mp4",
|
||||
compatibility_check=True,
|
||||
)
|
||||
|
||||
def test_compatibility_check_raises_on_different_resolution(self, tmp_path):
|
||||
with pytest.raises(ValueError):
|
||||
concatenate_video_files(
|
||||
[ARTIFACTS / "clip_4frames.mp4", ARTIFACTS / "clip_32x48.mp4"],
|
||||
tmp_path / "out.mp4",
|
||||
compatibility_check=True,
|
||||
)
|
||||
|
||||
|
||||
class TestEncoderConfigPersistence:
|
||||
"""Encoder config must be stored as ``video.<field>`` entries in
|
||||
``info["features"][key]["info"]`` when the first episode is saved.
|
||||
"""
|
||||
|
||||
@require_libsvtav1
|
||||
def test_first_episode_save_persists_encoder_config(self, tmp_path, empty_lerobot_dataset_factory):
|
||||
cfg = VideoEncoderConfig(vcodec="libsvtav1", g=2, crf=30, preset=12)
|
||||
dataset = empty_lerobot_dataset_factory(
|
||||
root=tmp_path / "ds", features=VIDEO_FEATURES, use_videos=True, camera_encoder_config=cfg
|
||||
)
|
||||
|
||||
_add_frames(dataset, num_frames=4)
|
||||
dataset.save_episode()
|
||||
dataset.finalize()
|
||||
|
||||
info = _read_feature_info(dataset)
|
||||
|
||||
assert info["video.height"] == 64
|
||||
assert info["video.width"] == 96
|
||||
assert info["video.fps"] == 30
|
||||
assert info["video.g"] == 2
|
||||
assert info["video.crf"] == 30
|
||||
assert info["video.preset"] == 12
|
||||
assert info["video.fast_decode"] == 0
|
||||
assert info["video.video_backend"] == "pyav"
|
||||
assert info["video.extra_options"] == {}
|
||||
|
||||
@require_libsvtav1
|
||||
def test_second_episode_does_not_overwrite_encoder_fields(self, tmp_path, empty_lerobot_dataset_factory):
|
||||
cfg = VideoEncoderConfig(vcodec="libsvtav1", g=2, crf=30, preset=12)
|
||||
dataset = empty_lerobot_dataset_factory(
|
||||
root=tmp_path / "ds", features=VIDEO_FEATURES, use_videos=True, camera_encoder_config=cfg
|
||||
)
|
||||
|
||||
_add_frames(dataset, num_frames=4)
|
||||
dataset.save_episode()
|
||||
first_info = dict(_read_feature_info(dataset))
|
||||
|
||||
_add_frames(dataset, num_frames=4)
|
||||
dataset.save_episode()
|
||||
dataset.finalize()
|
||||
|
||||
assert _read_feature_info(dataset) == first_info
|
||||
BIN
tests/fixtures/artifacts/videos/clip_32x48.mp4
LFS
vendored
Normal file
BIN
tests/fixtures/artifacts/videos/clip_32x48.mp4
LFS
vendored
Normal file
Binary file not shown.
BIN
tests/fixtures/artifacts/videos/clip_4frames.mp4
LFS
vendored
Normal file
BIN
tests/fixtures/artifacts/videos/clip_4frames.mp4
LFS
vendored
Normal file
Binary file not shown.
BIN
tests/fixtures/artifacts/videos/clip_5frames.mp4
LFS
vendored
Normal file
BIN
tests/fixtures/artifacts/videos/clip_5frames.mp4
LFS
vendored
Normal file
Binary file not shown.
BIN
tests/fixtures/artifacts/videos/clip_6frames.mp4
LFS
vendored
Normal file
BIN
tests/fixtures/artifacts/videos/clip_6frames.mp4
LFS
vendored
Normal file
Binary file not shown.
BIN
tests/fixtures/artifacts/videos/clip_h264.mp4
LFS
vendored
Normal file
BIN
tests/fixtures/artifacts/videos/clip_h264.mp4
LFS
vendored
Normal file
Binary file not shown.
77
tests/utils/test_feature_utils.py
Normal file
77
tests/utils/test_feature_utils.py
Normal file
@@ -0,0 +1,77 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Unit tests for ``lerobot.utils.feature_utils``."""
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.utils.constants import OBS_STR
|
||||
from lerobot.utils.feature_utils import build_dataset_frame, hw_to_dataset_features
|
||||
|
||||
|
||||
def test_hw_to_dataset_features_routes_3d_shape_to_images():
|
||||
hw = {"front": (480, 640, 3)}
|
||||
out = hw_to_dataset_features(hw, OBS_STR, use_video=True)
|
||||
|
||||
assert "observation.images.front" in out
|
||||
assert out["observation.images.front"]["dtype"] == "video"
|
||||
assert out["observation.images.front"]["shape"] == (480, 640, 3)
|
||||
assert out["observation.images.front"]["names"] == ["height", "width", "channels"]
|
||||
assert "info" not in out["observation.images.front"]
|
||||
|
||||
|
||||
def test_hw_to_dataset_features_routes_2d_shape_to_depth():
|
||||
hw = {"front_depth": (480, 640)}
|
||||
out = hw_to_dataset_features(hw, OBS_STR, use_video=True)
|
||||
|
||||
assert "observation.depth.front" in out, out
|
||||
feat = out["observation.depth.front"]
|
||||
assert feat["dtype"] == "video"
|
||||
assert feat["shape"] == (480, 640)
|
||||
assert feat["names"] == ["height", "width"]
|
||||
assert feat["info"] == {"video.is_depth_map": True}
|
||||
|
||||
|
||||
def test_hw_to_dataset_features_handles_paired_color_and_depth():
|
||||
"""A camera with use_depth=True is expected to emit both keys."""
|
||||
hw = {"front": (480, 640, 3), "front_depth": (480, 640)}
|
||||
out = hw_to_dataset_features(hw, OBS_STR, use_video=True)
|
||||
|
||||
assert set(out) == {"observation.images.front", "observation.depth.front"}
|
||||
assert out["observation.images.front"]["shape"] == (480, 640, 3)
|
||||
assert out["observation.depth.front"]["shape"] == (480, 640)
|
||||
|
||||
|
||||
def test_hw_to_dataset_features_keeps_bare_2d_key_when_no_suffix():
|
||||
"""If the producer didn't use a "_depth" suffix, the bare name flows through."""
|
||||
hw = {"top": (240, 320)}
|
||||
out = hw_to_dataset_features(hw, OBS_STR, use_video=True)
|
||||
|
||||
assert "observation.depth.top" in out
|
||||
|
||||
|
||||
def test_build_dataset_frame_routes_depth_values():
|
||||
ds_features = hw_to_dataset_features(
|
||||
{"front": (4, 6, 3), "front_depth": (4, 6)},
|
||||
OBS_STR,
|
||||
use_video=True,
|
||||
)
|
||||
rgb = np.zeros((4, 6, 3), dtype=np.uint8)
|
||||
depth = np.full((4, 6), 0.5, dtype=np.float32)
|
||||
values = {"front": rgb, "front_depth": depth}
|
||||
|
||||
frame = build_dataset_frame(ds_features, values, OBS_STR)
|
||||
assert frame["observation.images.front"] is rgb
|
||||
assert frame["observation.depth.front"] is depth
|
||||
@@ -43,18 +43,32 @@ def mock_rerun(monkeypatch):
|
||||
def __init__(self, arr):
|
||||
self.arr = arr
|
||||
|
||||
class DummyDepthImage:
|
||||
def __init__(self, arr, meter=None, colormap=None, **kwargs):
|
||||
self.arr = arr
|
||||
self.meter = meter
|
||||
self.colormap = colormap
|
||||
self.kwargs = kwargs
|
||||
|
||||
def dummy_log(key, obj=None, **kwargs):
|
||||
# Accept either positional `obj` or keyword `entity` and record remaining kwargs.
|
||||
if obj is None and "entity" in kwargs:
|
||||
obj = kwargs.pop("entity")
|
||||
calls.append((key, obj, kwargs))
|
||||
|
||||
class _Colormap:
|
||||
Viridis = "viridis"
|
||||
|
||||
dummy_components = SimpleNamespace(Colormap=_Colormap)
|
||||
|
||||
dummy_rr = SimpleNamespace(
|
||||
__name__="rerun",
|
||||
__package__="rerun",
|
||||
__spec__=SimpleNamespace(name="rerun", submodule_search_locations=None),
|
||||
Scalars=DummyScalar,
|
||||
Image=DummyImage,
|
||||
DepthImage=DummyDepthImage,
|
||||
components=dummy_components,
|
||||
log=dummy_log,
|
||||
init=lambda *a, **k: None,
|
||||
spawn=lambda *a, **k: None,
|
||||
@@ -232,3 +246,122 @@ def test_log_rerun_data_kwargs_only(mock_rerun):
|
||||
a = _obj_for(calls, "action.a")
|
||||
assert type(a).__name__ == "DummyScalar"
|
||||
assert a.value == pytest.approx(1.0)
|
||||
|
||||
|
||||
def test_log_rerun_data_routes_depth_to_depth_image(mock_rerun):
|
||||
"""A 2D float depth obs whose key matches a depth feature must use ``rr.DepthImage``.
|
||||
|
||||
Without ``video.depth_min``/``video.depth_max`` in the feature info, the
|
||||
visualizer should not pass a ``depth_range`` (Rerun then auto-normalizes).
|
||||
"""
|
||||
vu, calls = mock_rerun
|
||||
|
||||
features = {
|
||||
"observation.images.front": {
|
||||
"dtype": "video",
|
||||
"shape": (480, 640, 3),
|
||||
"info": {"video.is_depth_map": False},
|
||||
},
|
||||
"observation.depth.front": {
|
||||
"dtype": "video",
|
||||
"shape": (480, 640),
|
||||
"info": {"video.is_depth_map": True},
|
||||
},
|
||||
}
|
||||
obs = {
|
||||
"front": np.zeros((10, 12, 3), dtype=np.uint8),
|
||||
"front_depth": np.full((10, 12), 0.7, dtype=np.float32),
|
||||
}
|
||||
|
||||
vu.log_rerun_data(observation=obs, features=features)
|
||||
|
||||
rgb = _obj_for(calls, "observation.front")
|
||||
assert type(rgb).__name__ == "DummyImage"
|
||||
|
||||
depth = _obj_for(calls, "observation.front_depth")
|
||||
assert type(depth).__name__ == "DummyDepthImage"
|
||||
assert depth.arr.shape == (10, 12)
|
||||
assert depth.meter == pytest.approx(1.0)
|
||||
assert depth.colormap == "viridis"
|
||||
# No range available -> Rerun should auto-normalize; we must not pass `depth_range`.
|
||||
assert "depth_range" not in depth.kwargs
|
||||
assert _kwargs_for(calls, "observation.front_depth").get("static", False) is True
|
||||
|
||||
|
||||
def test_log_rerun_data_depth_range_anchored_from_info(mock_rerun):
|
||||
"""When ``video.depth_min``/``depth_max`` are present, ``depth_range`` is forwarded."""
|
||||
vu, calls = mock_rerun
|
||||
|
||||
features = {
|
||||
"observation.depth.front": {
|
||||
"dtype": "video",
|
||||
"shape": (480, 640),
|
||||
"info": {
|
||||
"video.is_depth_map": True,
|
||||
"video.depth_min": 0.05,
|
||||
"video.depth_max": 4.0,
|
||||
},
|
||||
},
|
||||
}
|
||||
obs = {"front_depth": np.full((10, 12), 0.5, dtype=np.float32)}
|
||||
|
||||
vu.log_rerun_data(observation=obs, features=features)
|
||||
|
||||
depth = _obj_for(calls, "observation.front_depth")
|
||||
assert type(depth).__name__ == "DummyDepthImage"
|
||||
assert depth.kwargs.get("depth_range") == (0.05, 4.0)
|
||||
|
||||
|
||||
def test_log_rerun_data_depth_range_ignored_when_invalid(mock_rerun):
|
||||
"""A degenerate range (max <= min, or non-numeric) must be discarded silently."""
|
||||
vu, calls = mock_rerun
|
||||
|
||||
features = {
|
||||
"observation.depth.front": {
|
||||
"dtype": "video",
|
||||
"shape": (480, 640),
|
||||
"info": {
|
||||
"video.is_depth_map": True,
|
||||
"video.depth_min": 1.0,
|
||||
"video.depth_max": 1.0, # degenerate
|
||||
},
|
||||
},
|
||||
}
|
||||
obs = {"front_depth": np.full((10, 12), 0.5, dtype=np.float32)}
|
||||
|
||||
vu.log_rerun_data(observation=obs, features=features)
|
||||
|
||||
depth = _obj_for(calls, "observation.front_depth")
|
||||
assert type(depth).__name__ == "DummyDepthImage"
|
||||
assert "depth_range" not in depth.kwargs
|
||||
|
||||
|
||||
def test_log_rerun_data_depth_skips_compression(mock_rerun):
|
||||
"""Depth frames must never be JPEG-compressed even when ``compress_images=True``."""
|
||||
vu, calls = mock_rerun
|
||||
|
||||
features = {
|
||||
"observation.depth.front": {
|
||||
"dtype": "video",
|
||||
"shape": (8, 8),
|
||||
"info": {"video.is_depth_map": True},
|
||||
},
|
||||
}
|
||||
obs = {"front_depth": np.full((8, 8), 0.5, dtype=np.float32)}
|
||||
|
||||
vu.log_rerun_data(observation=obs, features=features, compress_images=True)
|
||||
|
||||
depth = _obj_for(calls, "observation.front_depth")
|
||||
assert type(depth).__name__ == "DummyDepthImage"
|
||||
|
||||
|
||||
def test_log_rerun_data_no_features_falls_back_to_image(mock_rerun):
|
||||
"""Without ``features``, a 2D array still goes through the generic Image path (no depth detection)."""
|
||||
vu, calls = mock_rerun
|
||||
|
||||
obs = {"front_depth": np.zeros((8, 8), dtype=np.float32)}
|
||||
|
||||
vu.log_rerun_data(observation=obs)
|
||||
|
||||
obj = _obj_for(calls, "observation.front_depth")
|
||||
assert type(obj).__name__ == "DummyImage"
|
||||
|
||||
Reference in New Issue
Block a user