Add script to generate embedding for dataset (#2138)

* Add generate and validate script

* fix precommit

* Improve generate embeddings function by using dataset tools (#2206)

---------

Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
This commit is contained in:
Pepijn
2025-11-18 17:13:55 +01:00
committed by GitHub
parent 52b080fd8c
commit 9bd69bb236
5 changed files with 855 additions and 3 deletions

View File

@@ -1003,10 +1003,18 @@ def _copy_data_with_feature_changes(
df[feature_name] = feature_values
else:
feature_slice = values[frame_idx:end_idx]
if len(feature_slice.shape) > 1 and feature_slice.shape[1] == 1:
df[feature_name] = feature_slice.flatten()
else:
if len(feature_slice.shape) == 1:
# 1D array - can assign directly
df[feature_name] = feature_slice
elif len(feature_slice.shape) == 2 and feature_slice.shape[1] == 1:
# 2D array with single column - flatten it
df[feature_name] = feature_slice.flatten()
elif len(feature_slice.shape) == 2:
# 2D array with multiple columns (e.g., embeddings) - convert to list of lists
df[feature_name] = feature_slice.tolist()
else:
# Higher dimensional - convert to list
df[feature_name] = [row.tolist() for row in feature_slice]
frame_idx = end_idx
# Write using the preserved chunk_idx and file_idx from source

View File

@@ -0,0 +1,146 @@
# LeRobot Embedding Generation Script
Generate embeddings for LeRobot datasets to make them more lightweight and efficient for training.
## Overview
This script processes v3.0 LeRobot datasets and adds pre-computed embeddings for:
- **Task embeddings**: Language command embeddings using MiniLM
- **Image embeddings**: Frame embeddings using DinoV2
The resulting dataset can be used more efficiently during training by loading pre-computed embeddings instead of running encoders on-the-fly.
## Supported Encoders
### Image Encoders (DinoV2)
DinoV2 is a self-supervised vision transformer that produces high-quality image embeddings:
- **`dinov2_vits14`**: ViT-S/14 (384-dim) - Fastest, smaller model
- **`dinov2_vitb14`**: ViT-B/14 (768-dim) - **Recommended** - Good balance
- **`dinov2_vitl14`**: ViT-L/14 (1024-dim) - Best quality, slower
### Language Encoders (MiniLM)
MiniLM is a lightweight sentence transformer model:
- **`minilm-l6`**: MiniLM-L6-v2 (384-dim) - Faster
- **`minilm-l12`**: MiniLM-L12-v2 (384-dim) - **Recommended** - Better quality
## Usage
### Basic Command
```bash
python src/lerobot/datasets/generating_embeddings/generate_embeddings.py \
--repo-id lerobot/utokyo_xarm_bimanual \
--output-repo-id your-username/utokyo_xarm_bimanual_embeddings \
--image-encoder dinov2_vitb14 \
--language-encoder minilm-l12 \
--push-to-hub
```
### Lightweight Version (No Videos)
Removes video files to significantly reduce storage:
```bash
python src/lerobot/datasets/generating_embeddings/generate_embeddings.py \
--repo-id lerobot/utokyo_xarm_bimanual \
--output-repo-id your-username/utokyo_xarm_bimanual_lightweight \
--image-encoder dinov2_vitb14 \
--language-encoder minilm-l12 \
--remove-videos \
--push-to-hub
```
## Output
The script adds new features to your dataset:
### New Features
1. **`task_embedding`**: Language embedding for each frame
- Shape: `[384]` (MiniLM)
- One embedding per frame based on its task
2. **`{camera_key}_embedding`**: Image embedding for each camera view
- Shape: `[384]`, `[768]`, or `[1024]` depending on DinoV2 model
- Examples: `observation.images.top_embedding`, `observation.images.wrist_embedding`
### Using Embeddings in Training
```python
from lerobot.datasets.lerobot_dataset import LeRobotDataset
# Load dataset with embeddings
dataset = LeRobotDataset("your-username/utokyo_xarm_bimanual_embeddings")
# Access embeddings
item = dataset[0]
task_emb = item["task_embedding"] # Shape: [384]
img_emb = item["observation.images.top_embedding"] # Shape: [768]
# Use in your policy
# Instead of running encoders during training, use pre-computed embeddings
```
## Extending with New Encoders
The script is designed to be easily extensible. To add a new encoder:
### 1. Create Encoder Class
```python
class MyCustomImageEncoder(ImageEncoder):
"""Your custom image encoder."""
def __init__(self, device: str = "cuda"):
super().__init__(device)
# Load your model
self.model = load_my_model()
self.model = self.model.to(self.device)
self.model.eval()
def encode(self, images: list[np.ndarray]) -> np.ndarray:
"""Encode a batch of images."""
# Your encoding logic here
embeddings = []
for img in images:
emb = self.model(img)
embeddings.append(emb)
return np.array(embeddings)
@property
def embedding_dim(self) -> int:
"""Return embedding dimension."""
return 512 # Your embedding dimension
```
### 2. Add to Factory Function
```python
def get_image_encoder(encoder_name: str, device: str = "cuda") -> ImageEncoder:
encoders = {
"dinov2_vits14": lambda: DinoV2Encoder(model_name="dinov2_vits14", device=device),
"dinov2_vitb14": lambda: DinoV2Encoder(model_name="dinov2_vitb14", device=device),
"dinov2_vitl14": lambda: DinoV2Encoder(model_name="dinov2_vitl14", device=device),
# Add your encoder
"my_custom": lambda: MyCustomImageEncoder(device=device),
}
# ... rest of function
```
## Validating Embeddings
After generating embeddings, you can validate them using `validate_embeddings.py`:
```bash
python src/lerobot/datasets/generating_embeddings/validate_embeddings.py \
--original-repo-id lerobot/utokyo_xarm_bimanual \
--embeddings-repo-id pepijn223/utokyo_xarm_bimanual_embeddings \
--image-encoder dinov2_vitb14 \
--language-encoder minilm-l12 \
--num-samples 20
```

View File

@@ -0,0 +1,147 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import numpy as np
import torch
from PIL import Image
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class ImageEncoder:
"""Base class for image encoders."""
def __init__(self, device: str = "cuda"):
self.device = torch.device(device if torch.cuda.is_available() else "cpu")
def encode(self, images: list[np.ndarray]) -> np.ndarray:
"""Encode a batch of images."""
raise NotImplementedError
class DinoV2Encoder(ImageEncoder):
"""DinoV2 image encoder.
DinoV2 is a self-supervised vision transformer that produces high-quality image embeddings.
Supports multiple model sizes (ViT-S/14, ViT-B/14, ViT-L/14).
"""
def __init__(self, model_name: str = "dinov2_vitb14", device: str = "cuda", batch_size: int = 32):
super().__init__(device)
self.batch_size = batch_size
self.model_name = model_name
logger.info(f"Loading DinoV2 model: {model_name}")
self.model = torch.hub.load("facebookresearch/dinov2", model_name) # nosec B614
self.model = self.model.to(self.device)
self.model.eval()
# DinoV2 preprocessing
from torchvision import transforms
self.transform = transforms.Compose(
[
transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
)
def encode(self, images: list[np.ndarray]) -> np.ndarray:
"""Encode a batch of images."""
embeddings = []
with torch.inference_mode():
for i in range(0, len(images), self.batch_size):
batch_images = images[i : i + self.batch_size]
# Convert numpy arrays to PIL Images and apply transforms
pil_images = [Image.fromarray(img.astype(np.uint8)) for img in batch_images]
tensors = torch.stack([self.transform(img) for img in pil_images]).to(self.device)
# Get embeddings
batch_embeddings = self.model(tensors).cpu().numpy()
embeddings.append(batch_embeddings)
return np.concatenate(embeddings, axis=0)
@property
def embedding_dim(self) -> int:
"""Return the embedding dimension based on model size."""
if "vits14" in self.model_name:
return 384 # DinoV2 ViT-S/14
elif "vitb14" in self.model_name:
return 768 # DinoV2 ViT-B/14
elif "vitl14" in self.model_name:
return 1024 # DinoV2 ViT-L/14
else:
return 768 # Default to ViT-B/14
class LanguageEncoder:
"""Base class for language encoders."""
def __init__(self, device: str = "cuda"):
self.device = torch.device(device if torch.cuda.is_available() else "cpu")
def encode(self, texts: list[str]) -> np.ndarray:
"""Encode a batch of texts."""
raise NotImplementedError
class MiniLMEncoder(LanguageEncoder):
"""MiniLM language encoder.
MiniLM is a lightweight sentence transformer model that produces high-quality text embeddings.
Supports L6 and L12 model sizes.
"""
def __init__(self, model_name: str = "sentence-transformers/all-MiniLM-L12-v2", device: str = "cuda"):
super().__init__(device)
self.model_name = model_name
logger.info(f"Loading MiniLM model: {model_name}")
from transformers import AutoModel, AutoTokenizer
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModel.from_pretrained(model_name).to(self.device)
self.model.eval()
def _mean_pooling(self, model_output, attention_mask):
"""Mean pooling to get sentence embeddings."""
token_embeddings = model_output[0]
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
input_mask_expanded.sum(1), min=1e-9
)
def encode(self, texts: list[str]) -> np.ndarray:
"""Encode a batch of texts."""
with torch.inference_mode():
encoded_input = self.tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
encoded_input = {k: v.to(self.device) for k, v in encoded_input.items()}
model_output = self.model(**encoded_input)
embeddings = self._mean_pooling(model_output, encoded_input["attention_mask"])
return embeddings.cpu().numpy()
@property
def embedding_dim(self) -> int:
"""Return the embedding dimension."""
return 384 # Both MiniLM-L6 and L12 output 384-dim embeddings

View File

@@ -0,0 +1,329 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Generate embeddings for LeRobot datasets to make them more lightweight and efficient.
This script:
1. Loads a v3.0 LeRobot dataset from the hub
2. Computes embeddings for tasks (language commands) and frames (images)
3. Stores embeddings as new features in the dataset
4. Optionally removes video files to reduce size
5. Pushes the converted dataset to the hub
Current supported encoders:
- Image: DinoV2 (dinov2_vits14, dinov2_vitb14, dinov2_vitl14)
- Language: MiniLM (minilm-l6, minilm-l12)
The architecture is extensible - you can add more encoders by:
1. Creating a new encoder class inheriting from ImageEncoder or LanguageEncoder
2. Implementing the encode() method and embedding_dim property
3. Adding it to the get_image_encoder() or get_language_encoder() factory function
Usage example:
python src/lerobot/datasets/generating_embeddings/generate_embeddings.py \
--repo-id lerobot/utokyo_xarm_bimanual \
--output-repo-id lerobot/utokyo_xarm_bimanual_embeddings \
--image-encoder dinov2_vitb14 \
--language-encoder minilm-l12 \
--remove-videos \
--push-to-hub
"""
import argparse
import shutil
from pathlib import Path
import numpy as np
import torch
from tqdm import tqdm
from lerobot.datasets.generating_embeddings.encoders import (
DinoV2Encoder,
ImageEncoder,
LanguageEncoder,
MiniLMEncoder,
)
from lerobot.datasets.lerobot_dataset import LeRobotDataset
def get_image_encoder(encoder_name: str, device: str = "cuda") -> ImageEncoder:
"""Factory function to get image encoder.
To add a new encoder:
1. Create a new class inheriting from ImageEncoder
2. Implement encode() and embedding_dim property
3. Add it to the encoders dictionary below
"""
encoders = {
"dinov2_vits14": lambda: DinoV2Encoder(model_name="dinov2_vits14", device=device),
"dinov2_vitb14": lambda: DinoV2Encoder(model_name="dinov2_vitb14", device=device),
"dinov2_vitl14": lambda: DinoV2Encoder(model_name="dinov2_vitl14", device=device),
}
if encoder_name not in encoders:
raise ValueError(f"Unknown image encoder: {encoder_name}. Available options: {list(encoders.keys())}")
return encoders[encoder_name]()
def get_language_encoder(encoder_name: str, device: str = "cuda") -> LanguageEncoder:
"""Factory function to get language encoder.
To add a new encoder:
1. Create a new class inheriting from LanguageEncoder
2. Implement encode() and embedding_dim property
3. Add it to the encoders dictionary below
"""
encoders = {
"minilm-l6": lambda: MiniLMEncoder(
model_name="sentence-transformers/all-MiniLM-L6-v2", device=device
),
"minilm-l12": lambda: MiniLMEncoder(
model_name="sentence-transformers/all-MiniLM-L12-v2", device=device
),
}
if encoder_name not in encoders:
raise ValueError(
f"Unknown language encoder: {encoder_name}. Available options: {list(encoders.keys())}"
)
return encoders[encoder_name]()
def generate_embeddings_for_dataset(
repo_id: str,
output_repo_id: str,
image_encoder: ImageEncoder,
language_encoder: LanguageEncoder,
remove_videos: bool = False,
local_dir: Path | None = None,
output_local_dir: Path | None = None,
push_to_hub: bool = False,
):
"""Generate embeddings for a LeRobot dataset.
Args:
repo_id: Source dataset repository ID
output_repo_id: Output dataset repository ID
image_encoder: Image encoder instance
language_encoder: Language encoder instance
remove_videos: Whether to remove video files
local_dir: Local directory for source dataset
output_local_dir: Local directory for output dataset
push_to_hub: Whether to push to hub after conversion
"""
from lerobot.datasets.dataset_tools import modify_features
print(f"Loading dataset: {repo_id}")
dataset = LeRobotDataset(repo_id, root=local_dir, download_videos=True)
print(f"Dataset: {dataset.num_episodes} episodes, {dataset.num_frames} frames")
print("Computing task embeddings...")
unique_tasks = dataset.meta.tasks.index.tolist()
task_embeddings = {}
for task in tqdm(unique_tasks, desc="Encoding tasks"):
# Clean up task text
task_clean = task.strip().capitalize().strip(" .,!?-_")
embedding = language_encoder.encode([task_clean])[0]
task_embeddings[task] = embedding
print(f"Computed {len(task_embeddings)} task embeddings")
print("Processing frames and computing embeddings...")
all_task_embeddings = []
all_image_embeddings_dict = {cam_key: [] for cam_key in dataset.meta.camera_keys}
for frame_idx in tqdm(range(dataset.num_frames), desc="Processing frames"):
item = dataset.hf_dataset[frame_idx]
ep_idx = item["episode_index"].item()
task = dataset.meta.tasks.iloc[item["task_index"].item()].name
task_emb = task_embeddings[task]
all_task_embeddings.append(task_emb)
for cam_key in dataset.meta.camera_keys:
if cam_key in dataset.meta.video_keys:
current_ts = item["timestamp"].item()
video_frames = dataset._query_videos({cam_key: [current_ts]}, ep_idx)
img = video_frames[cam_key]
if isinstance(img, torch.Tensor):
if img.ndim == 4:
img = img[0] # (T, C, H, W) -> (C, H, W)
elif img.ndim != 3:
raise ValueError(f"Unexpected video frame shape {img.shape} for camera {cam_key}")
img_np = (img.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
else:
img_np = np.array(img)
else:
img = item[cam_key]
if isinstance(img, torch.Tensor):
if img.ndim == 3:
img_np = (img.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
else:
raise ValueError(f"Unexpected image shape {img.shape} for camera {cam_key}")
else:
img_np = np.array(img)
all_image_embeddings_dict[cam_key].append(img_np)
print("Computing image embeddings...")
image_embeddings_dict = {}
for cam_key, images in all_image_embeddings_dict.items():
print(f" {cam_key}: {len(images)} images")
embeddings = image_encoder.encode(images)
image_embeddings_dict[cam_key] = embeddings
all_task_embeddings = np.array(all_task_embeddings)
for cam_key in dataset.meta.camera_keys:
image_embeddings_dict[cam_key] = np.array(image_embeddings_dict[cam_key])
img_emb_dim = image_encoder.embedding_dim
lang_emb_dim = language_encoder.embedding_dim
add_features_dict = {
"task_embedding": (
all_task_embeddings,
{"dtype": "float32", "shape": [lang_emb_dim], "names": None},
),
}
for cam_key in dataset.meta.camera_keys:
add_features_dict[f"{cam_key}_embedding"] = (
image_embeddings_dict[cam_key],
{"dtype": "float32", "shape": [img_emb_dim], "names": None},
)
print("Adding embeddings to dataset...")
remove_features_list = None
if remove_videos:
remove_features_list = dataset.meta.video_keys
output_dataset = modify_features(
dataset=dataset,
add_features=add_features_dict,
remove_features=remove_features_list,
output_dir=output_local_dir,
repo_id=output_repo_id,
)
if remove_videos:
print("Removing video files...")
videos_dir = output_dataset.root / "videos"
if videos_dir.exists():
shutil.rmtree(videos_dir)
print(f"Saved to: {output_dataset.root}")
if push_to_hub:
print(f"Pushing to hub: {output_repo_id}")
output_dataset.push_to_hub(push_videos=not remove_videos)
print("Done!")
def main():
parser = argparse.ArgumentParser(
description="Generate embeddings for LeRobot datasets",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
# Basic usage with default encoders (DinoV2 ViT-B/14 + MiniLM-L12)
python src/lerobot/datasets/generating_embeddings/generate_embeddings.py \\
--repo-id lerobot/utokyo_xarm_bimanual \\
--output-repo-id your-username/utokyo_xarm_bimanual_embeddings \\
--image-encoder dinov2_vitb14 \\
--language-encoder minilm-l12 \\
--push-to-hub
# Generate embeddings and remove videos
python src/lerobot/datasets/generating_embeddings/generate_embeddings.py \\
--repo-id lerobot/utokyo_xarm_bimanual \\
--output-repo-id your-username/utokyo_xarm_bimanual_lightweight \\
--image-encoder dinov2_vitb14 \\
--language-encoder minilm-l12 \\
--remove-videos \\
--push-to-hub
Available image encoders:
- dinov2_vits14: DinoV2 ViT-S/14 (384-dim, faster)
- dinov2_vitb14: DinoV2 ViT-B/14 (768-dim, recommended)
- dinov2_vitl14: DinoV2 ViT-L/14 (1024-dim, best quality)
Available language encoders:
- minilm-l6: MiniLM-L6-v2 (384-dim, faster)
- minilm-l12: MiniLM-L12-v2 (384-dim, recommended)
""",
)
parser.add_argument("--repo-id", type=str, required=True, help="Source dataset repository ID")
parser.add_argument("--output-repo-id", type=str, required=True, help="Output dataset repository ID")
parser.add_argument(
"--image-encoder",
type=str,
default="dinov2_vitb14",
help="Image encoder to use (default: dinov2_vitb14)",
)
parser.add_argument(
"--language-encoder",
type=str,
default="minilm-l12",
help="Language encoder to use (default: minilm-l12)",
)
parser.add_argument(
"--remove-videos",
action="store_true",
help="Remove video files after generating embeddings",
)
parser.add_argument("--local-dir", type=str, default=None, help="Local directory for source dataset")
parser.add_argument(
"--output-local-dir", type=str, default=None, help="Local directory for output dataset"
)
parser.add_argument(
"--push-to-hub",
action="store_true",
help="Push the converted dataset to the hub",
)
parser.add_argument(
"--device",
type=str,
default="cuda",
help="Device to use for encoding (default: cuda)",
)
args = parser.parse_args()
# Load encoders
image_encoder = get_image_encoder(args.image_encoder, device=args.device)
language_encoder = get_language_encoder(args.language_encoder, device=args.device)
# Generate embeddings
generate_embeddings_for_dataset(
repo_id=args.repo_id,
output_repo_id=args.output_repo_id,
image_encoder=image_encoder,
language_encoder=language_encoder,
remove_videos=args.remove_videos,
local_dir=Path(args.local_dir) if args.local_dir else None,
output_local_dir=Path(args.output_local_dir) if args.output_local_dir else None,
push_to_hub=args.push_to_hub,
)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,222 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Validate pre-computed embeddings against on-the-fly computed embeddings.
Usage:
python src/lerobot/datasets/generating_embeddings/validate_embeddings.py \
--original-repo-id lerobot/utokyo_xarm_bimanual \
--embeddings-repo-id <your_username>/utokyo_xarm_bimanual_embeddings \
--image-encoder dinov2_vitb14 \
--language-encoder minilm-l12 \
--num-samples 10
"""
import argparse
import numpy as np
import torch
from tqdm import tqdm
from lerobot.datasets.generating_embeddings.encoders import ImageEncoder, LanguageEncoder
from lerobot.datasets.generating_embeddings.generate_embeddings import (
get_image_encoder,
get_language_encoder,
)
from lerobot.datasets.lerobot_dataset import LeRobotDataset
def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
"""Compute cosine similarity between two vectors."""
return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
def validate_embeddings(
original_repo_id: str,
embeddings_repo_id: str,
image_encoder: ImageEncoder,
language_encoder: LanguageEncoder,
num_samples: int = 10,
device: str = "cuda",
):
"""Validate pre-computed embeddings against on-the-fly embeddings.
Args:
original_repo_id: Original dataset repository ID
embeddings_repo_id: Dataset with pre-computed embeddings repository ID
image_encoder: Image encoder instance
language_encoder: Language encoder instance
num_samples: Number of samples to validate
device: Device to use for encoding
"""
# Load both datasets
print("Loading datasets...")
original_dataset = LeRobotDataset(original_repo_id, download_videos=True)
embeddings_dataset = LeRobotDataset(embeddings_repo_id, download_videos=False)
# Verify both datasets have the same number of frames
assert original_dataset.num_frames == embeddings_dataset.num_frames, (
f"Frame count mismatch: original={original_dataset.num_frames}, "
f"embeddings={embeddings_dataset.num_frames}"
)
camera_keys = original_dataset.meta.camera_keys
# Check embedding features exist
expected_features = ["task_embedding"] + [f"{cam}_embedding" for cam in camera_keys]
for feat in expected_features:
if feat not in embeddings_dataset.features:
raise ValueError(f"Embedding feature not found: {feat}")
# Select random sample indices
sample_indices = np.random.choice(
original_dataset.num_frames, size=min(num_samples, original_dataset.num_frames), replace=False
)
print(f"Validating {len(sample_indices)} samples...")
# Track statistics
task_similarities = []
image_similarities = {cam: [] for cam in camera_keys}
for idx in tqdm(sample_indices, desc="Validating"):
idx = int(idx)
embeddings_item = embeddings_dataset[idx]
precomputed_task_emb = embeddings_item["task_embedding"].numpy()
precomputed_image_embs = {cam: embeddings_item[f"{cam}_embedding"].numpy() for cam in camera_keys}
original_item = original_dataset[idx]
# Get task and compute embedding
task = original_item["task"]
# Clean up task text (same as in generate_embeddings.py)
task_clean = task.strip().capitalize().strip(" .,!?-_")
onthefly_task_emb = language_encoder.encode([task_clean])[0]
# Get images and compute embeddings
onthefly_image_embs = {}
for cam in camera_keys:
img = original_item[cam]
# Convert to numpy if needed
if isinstance(img, torch.Tensor):
if img.ndim == 3: # (C, H, W)
img_np = (img.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
else:
raise ValueError(f"Unexpected image shape: {img.shape}")
else:
img_np = np.array(img)
onthefly_image_embs[cam] = image_encoder.encode([img_np])[0]
# Task embedding comparison
task_sim = cosine_similarity(precomputed_task_emb, onthefly_task_emb)
task_similarities.append(task_sim)
# Image embedding comparison
for cam in camera_keys:
img_sim = cosine_similarity(precomputed_image_embs[cam], onthefly_image_embs[cam])
image_similarities[cam].append(img_sim)
# Results
print("\nResults:")
task_sim_threshold = 0.99
img_sim_threshold = 0.99
task_mean_sim = np.mean(task_similarities)
task_pass = task_mean_sim >= task_sim_threshold
print(f" Task: {task_mean_sim:.4f} {'' if task_pass else ''}")
for cam in camera_keys:
cam_mean_sim = np.mean(image_similarities[cam])
cam_pass = cam_mean_sim >= img_sim_threshold
print(f" {cam}: {cam_mean_sim:.4f} {'' if cam_pass else ''}")
image_pass = all(np.mean(image_similarities[cam]) >= img_sim_threshold for cam in camera_keys)
print()
if task_pass and image_pass:
print("✓ PASSED")
else:
print("✗ FAILED")
def main():
parser = argparse.ArgumentParser(
description="Validate and compare pre-computed embeddings with on-the-fly embeddings",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Example:
python src/lerobot/datasets/generating_embeddings/validate_embeddings.py \\
--original-repo-id lerobot/utokyo_xarm_bimanual \\
--embeddings-repo-id lerobot/utokyo_xarm_bimanual_embeddings \\
--image-encoder dinov2_vitb14 \\
--language-encoder minilm-l12 \\
--num-samples 20
""",
)
parser.add_argument("--original-repo-id", type=str, required=True, help="Original dataset repository ID")
parser.add_argument(
"--embeddings-repo-id",
type=str,
required=True,
help="Dataset with pre-computed embeddings repository ID",
)
parser.add_argument(
"--image-encoder",
type=str,
default="dinov2_vitb14",
help="Image encoder to use (default: dinov2_vitb14)",
)
parser.add_argument(
"--language-encoder",
type=str,
default="minilm-l12",
help="Language encoder to use (default: minilm-l12)",
)
parser.add_argument(
"--num-samples",
type=int,
default=10,
help="Number of samples to validate (default: 10)",
)
parser.add_argument(
"--device",
type=str,
default="cuda",
help="Device to use for encoding (default: cuda)",
)
args = parser.parse_args()
# Load encoders
image_encoder = get_image_encoder(args.image_encoder, device=args.device)
language_encoder = get_language_encoder(args.language_encoder, device=args.device)
# Validate embeddings
validate_embeddings(
original_repo_id=args.original_repo_id,
embeddings_repo_id=args.embeddings_repo_id,
image_encoder=image_encoder,
language_encoder=language_encoder,
num_samples=args.num_samples,
device=args.device,
)
if __name__ == "__main__":
main()