mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-03 04:11:24 +00:00
757 lines
27 KiB
Python
757 lines
27 KiB
Python
|
|
#!/usr/bin/env python
|
|||
|
|
|
|||
|
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
|||
|
|
#
|
|||
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|||
|
|
# you may not use this file except in compliance with the License.
|
|||
|
|
# You may obtain a copy of the License at
|
|||
|
|
#
|
|||
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|||
|
|
#
|
|||
|
|
# Unless required by applicable law or agreed to in writing, software
|
|||
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|||
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|||
|
|
# See the License for the specific language governing permissions and
|
|||
|
|
# limitations under the License.
|
|||
|
|
|
|||
|
|
"""
|
|||
|
|
Synthetic Data Generation for Hi-Robot Style Hierarchical Policy Training.
|
|||
|
|
|
|||
|
|
This script generates synthetic user prompts (ℓ_t) and robot utterances (u_t) for
|
|||
|
|
hierarchical policy training using Qwen VLM as the generator model (pgen).
|
|||
|
|
|
|||
|
|
The pipeline:
|
|||
|
|
1. Loads a LeRobot dataset with skill annotations (from annotate.py)
|
|||
|
|
2. For each frame, generates synthetic dialogue based on:
|
|||
|
|
- Visual context (images at time t)
|
|||
|
|
- Current skill being performed
|
|||
|
|
- History of previous skills
|
|||
|
|
- High-level task description
|
|||
|
|
3. Saves results as high-level tasks and updates dataset with task_index_high_level
|
|||
|
|
|
|||
|
|
Usage:
|
|||
|
|
```bash
|
|||
|
|
python examples/dataset/annotate_pgen.py \
|
|||
|
|
--repo-id lerobot/svla_so101_pickplace \
|
|||
|
|
--model Qwen/Qwen2-VL-7B-Instruct \
|
|||
|
|
--output-dir /path/to/output \
|
|||
|
|
--batch-size 1
|
|||
|
|
```
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import argparse
|
|||
|
|
import json
|
|||
|
|
import re
|
|||
|
|
import textwrap
|
|||
|
|
from pathlib import Path
|
|||
|
|
from typing import Any
|
|||
|
|
|
|||
|
|
import numpy as np
|
|||
|
|
import pandas as pd
|
|||
|
|
import torch
|
|||
|
|
from PIL import Image
|
|||
|
|
from rich.console import Console
|
|||
|
|
from rich.progress import Progress, SpinnerColumn, TextColumn
|
|||
|
|
from tqdm import tqdm
|
|||
|
|
|
|||
|
|
from lerobot.datasets.dataset_tools import add_features
|
|||
|
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
|||
|
|
|
|||
|
|
|
|||
|
|
# =============================================================================
|
|||
|
|
# Prompt Template for pgen
|
|||
|
|
# =============================================================================
|
|||
|
|
|
|||
|
|
# Text prompt sent to the generator VLM alongside the scene images.
# `{task_description}`, `{skill_history}` and `{skill_current}` are filled in with
# `str.format`, which is why the literal braces of the JSON example are doubled.
PGEN_PROMPT_TEMPLATE = textwrap.dedent("""\
    # Role
    You are a robot-assistant dialogue generator for hierarchical robot policies.

    # Task
    You will receive:
    - A list of images showing the current robot scene at time t
    - The high-level task: {task_description}
    - Previous skill steps completed: {skill_history}
    - The next skill to be performed by the robot: {skill_current}

    # Your Goal
    Generate two things that create a natural human-robot interaction:
    1. **user_prompt**: A natural-sounding user request that logically leads to the robot
    performing the skill "{skill_current}" given the task context and history.
    2. **robot_utterance**: A natural robot reply acknowledging or clarifying the request.

    # Guidelines
    - The user prompt should be grounded in the visual scene and task context
    - Vary interaction types: direct commands, implicit requests, corrections, constraints
    - Examples of user prompt styles:
    * Direct: "Can you pick up the red brick?"
    * Implicit: "I need something red for the tower"
    * Negative: "Don't pick up the blue one"
    * Constraint: "Make sure to handle it gently"
    * Correction: "Actually, move to the other box instead"
    - Robot responses should be appropriate: confirmations, clarifications, or error handling
    - Use the skill history to ensure continuity (don't repeat past actions)
    - Consider world knowledge (dietary preferences, object properties, etc.)

    # Scenario Types (choose one that fits):
    - **specific_object**: User specifies exact object/action
    - **negative_task**: User says what NOT to do
    - **situated_correction**: User adjusts based on current state
    - **implicit_request**: User implies need without direct command
    - **constraint_based**: User adds specific constraints

    # Response Types (choose one that fits):
    - **confirmation**: Simple "OK, I'll do X"
    - **clarification**: "Just to confirm, you want me to..."
    - **acknowledgment**: "Got it, [doing action]"
    - **constraint_acknowledgment**: "Sure, I'll [action] while [constraint]"

    # Output Format
    Respond ONLY with valid JSON:
    {{
    "scenario_type": "one of the types above",
    "response_type": "one of the types above",
    "user_prompt": "natural user request here",
    "robot_utterance": "natural robot response here"
    }}

    The responses must be grounded in the visual scene, the task, and the skill history.
    Make it sound like a real human-robot interaction.
    """)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def construct_prompt(
    task_description: str,
    skill_history: list[str],
    skill_current: str,
) -> str:
    """
    Fill the pgen prompt template for one sample.

    Args:
        task_description: High-level task description
        skill_history: Names of the skills completed so far, oldest first
        skill_current: Skill the generated dialogue should lead to

    Returns:
        The formatted prompt text. Only the five most recent history entries are
        quoted; a leading "..." marks that older entries were dropped.
    """
    if not skill_history:
        rendered_history = "None (starting the task)"
    else:
        recent_skills = skill_history[-5:]
        rendered_history = ", ".join(f'"{name}"' for name in recent_skills)
        if len(skill_history) > 5:
            rendered_history = f"... {rendered_history}"

    return PGEN_PROMPT_TEMPLATE.format(
        task_description=task_description,
        skill_history=rendered_history,
        skill_current=skill_current,
    )
|
|||
|
|
|
|||
|
|
|
|||
|
|
# =============================================================================
|
|||
|
|
# Qwen VLM Interface
|
|||
|
|
# =============================================================================
|
|||
|
|
|
|||
|
|
class QwenPgen:
|
|||
|
|
"""Qwen VLM wrapper for synthetic dialogue generation."""
|
|||
|
|
|
|||
|
|
def __init__(
|
|||
|
|
self,
|
|||
|
|
model_name: str,
|
|||
|
|
device: str = "cuda",
|
|||
|
|
torch_dtype: torch.dtype = torch.bfloat16,
|
|||
|
|
temperature: float = 0.7,
|
|||
|
|
):
|
|||
|
|
from qwen_vl_utils import process_vision_info
|
|||
|
|
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
|
|||
|
|
|
|||
|
|
self.console = Console()
|
|||
|
|
self.device = device
|
|||
|
|
self.model_name = model_name
|
|||
|
|
self.temperature = temperature
|
|||
|
|
self.process_vision_info = process_vision_info
|
|||
|
|
|
|||
|
|
self.console.print(f"[cyan]Loading Qwen model: {model_name}...[/cyan]")
|
|||
|
|
|
|||
|
|
# Load model based on name
|
|||
|
|
if "qwen3" in model_name.lower():
|
|||
|
|
from transformers import Qwen3VLMoeForConditionalGeneration
|
|||
|
|
self.model = Qwen3VLMoeForConditionalGeneration.from_pretrained(
|
|||
|
|
model_name, torch_dtype=torch_dtype, device_map=device, trust_remote_code=True
|
|||
|
|
)
|
|||
|
|
else:
|
|||
|
|
self.model = Qwen2VLForConditionalGeneration.from_pretrained(
|
|||
|
|
model_name, torch_dtype=torch_dtype, device_map=device, trust_remote_code=True
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
self.processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
|
|||
|
|
self.console.print(f"[green]✓ Model loaded successfully on {device}[/green]")
|
|||
|
|
|
|||
|
|
def call_qwen(
|
|||
|
|
self,
|
|||
|
|
images: list[Image.Image | str],
|
|||
|
|
prompt: str,
|
|||
|
|
) -> dict[str, str]:
|
|||
|
|
"""
|
|||
|
|
Call Qwen VLM to generate synthetic dialogue.
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
images: List of PIL Images or image paths
|
|||
|
|
prompt: Text prompt for generation
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
Dictionary with keys: scenario_type, response_type, user_prompt, robot_utterance
|
|||
|
|
"""
|
|||
|
|
# Build messages with images and text
|
|||
|
|
content = []
|
|||
|
|
for img in images:
|
|||
|
|
if isinstance(img, str):
|
|||
|
|
content.append({"type": "image", "image": img})
|
|||
|
|
else:
|
|||
|
|
# PIL Image - need to save temporarily or convert
|
|||
|
|
content.append({"type": "image", "image": img})
|
|||
|
|
|
|||
|
|
content.append({"type": "text", "text": prompt})
|
|||
|
|
|
|||
|
|
messages = [
|
|||
|
|
{
|
|||
|
|
"role": "user",
|
|||
|
|
"content": content,
|
|||
|
|
}
|
|||
|
|
]
|
|||
|
|
|
|||
|
|
# Process inputs
|
|||
|
|
text = self.processor.apply_chat_template(
|
|||
|
|
messages, tokenize=False, add_generation_prompt=True
|
|||
|
|
)
|
|||
|
|
image_inputs, video_inputs = self.process_vision_info(messages)
|
|||
|
|
|
|||
|
|
inputs = self.processor(
|
|||
|
|
text=[text],
|
|||
|
|
images=image_inputs,
|
|||
|
|
videos=video_inputs,
|
|||
|
|
padding=True,
|
|||
|
|
return_tensors="pt",
|
|||
|
|
).to(self.device)
|
|||
|
|
|
|||
|
|
# Generate
|
|||
|
|
with torch.no_grad():
|
|||
|
|
generated_ids = self.model.generate(
|
|||
|
|
**inputs,
|
|||
|
|
max_new_tokens=512,
|
|||
|
|
do_sample=True,
|
|||
|
|
temperature=self.temperature,
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
# Decode response
|
|||
|
|
response = self.processor.batch_decode(
|
|||
|
|
[out[len(inp):] for inp, out in zip(inputs.input_ids, generated_ids)],
|
|||
|
|
skip_special_tokens=True,
|
|||
|
|
)[0].strip()
|
|||
|
|
|
|||
|
|
return self._parse_response(response)
|
|||
|
|
|
|||
|
|
def _parse_response(self, response: str) -> dict[str, str]:
|
|||
|
|
"""Parse JSON response from model."""
|
|||
|
|
# Extract JSON from response
|
|||
|
|
if "```json" in response:
|
|||
|
|
response = response.split("```json")[1].split("```")[0]
|
|||
|
|
elif "```" in response:
|
|||
|
|
response = response.split("```")[1].split("```")[0]
|
|||
|
|
|
|||
|
|
try:
|
|||
|
|
data = json.loads(response)
|
|||
|
|
return {
|
|||
|
|
"scenario_type": data.get("scenario_type", "specific_object"),
|
|||
|
|
"response_type": data.get("response_type", "confirmation"),
|
|||
|
|
"user_prompt": data.get("user_prompt", ""),
|
|||
|
|
"robot_utterance": data.get("robot_utterance", ""),
|
|||
|
|
}
|
|||
|
|
except json.JSONDecodeError:
|
|||
|
|
# Try to find JSON object in response
|
|||
|
|
match = re.search(r"\{.*\}", response, re.DOTALL)
|
|||
|
|
if match:
|
|||
|
|
data = json.loads(match.group())
|
|||
|
|
return {
|
|||
|
|
"scenario_type": data.get("scenario_type", "specific_object"),
|
|||
|
|
"response_type": data.get("response_type", "confirmation"),
|
|||
|
|
"user_prompt": data.get("user_prompt", ""),
|
|||
|
|
"robot_utterance": data.get("robot_utterance", ""),
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
raise ValueError(f"Could not parse response: {response[:200]}...")
|
|||
|
|
|
|||
|
|
|
|||
|
|
# =============================================================================
|
|||
|
|
# Annotation Pipeline
|
|||
|
|
# =============================================================================
|
|||
|
|
|
|||
|
|
def load_skills_metadata(dataset_root: Path) -> dict | None:
|
|||
|
|
"""Load skills.json metadata from annotated dataset."""
|
|||
|
|
skills_path = dataset_root / "meta" / "skills.json"
|
|||
|
|
if skills_path.exists():
|
|||
|
|
with open(skills_path) as f:
|
|||
|
|
return json.load(f)
|
|||
|
|
return None
|
|||
|
|
|
|||
|
|
|
|||
|
|
def get_skill_at_timestamp(skills: list[dict], timestamp: float) -> str | None:
|
|||
|
|
"""Find which skill covers a given timestamp."""
|
|||
|
|
for skill in skills:
|
|||
|
|
if skill["start"] <= timestamp < skill["end"]:
|
|||
|
|
return skill["name"]
|
|||
|
|
# Handle last frame
|
|||
|
|
if timestamp >= skill["end"] and skill == skills[-1]:
|
|||
|
|
return skill["name"]
|
|||
|
|
return skills[-1]["name"] if skills else None
|
|||
|
|
|
|||
|
|
|
|||
|
|
def annotate_sample(
    pgen: QwenPgen,
    images: list[Image.Image | str],
    task_description: str,
    skill_history: list[str],
    skill_current: str,
) -> dict[str, str]:
    """
    Produce the synthetic dialogue for one sample.

    Builds the text prompt from the task and skill context, then sends it to the
    model together with the images.

    Args:
        pgen: Qwen model wrapper
        images: Images for the current timestep
        task_description: High-level task description
        skill_history: Skills completed before this sample
        skill_current: Skill the dialogue should lead to

    Returns:
        The dictionary returned by `pgen.call_qwen`
    """
    return pgen.call_qwen(
        images,
        construct_prompt(task_description, skill_history, skill_current),
    )
|
|||
|
|
|
|||
|
|
|
|||
|
|
def generate_synthetic_data(
    dataset: LeRobotDataset,
    pgen: QwenPgen,
    skills_metadata: dict,
    image_keys: list[str],
    sample_interval_seconds: float = 1.0,
    console: Console | None = None,
) -> tuple[pd.DataFrame, np.ndarray, list[dict]]:
    """
    Generate synthetic dialogue data for entire dataset.

    This function processes ALL frames in the dataset, but only calls the VLM
    at specified intervals (sample_interval_seconds). Frames between samples
    inherit the task_index from the most recent sample of the same episode.
    A frame whose generation fails also inherits that index when the episode
    already has one; frames without any usable index keep 0.

    Args:
        dataset: LeRobot dataset with skill annotations
        pgen: Qwen model wrapper
        skills_metadata: Loaded skills.json metadata
        image_keys: List of image observation keys to use
        sample_interval_seconds: Generate dialogue every N seconds (default: 1.0)
        console: Rich console for logging

    Returns:
        Tuple of (tasks_df, task_indices_array, debug_outputs)
        - tasks_df: DataFrame with high-level tasks (user_prompt, robot_utterance, etc.),
          indexed by "task"; empty but with the same columns when nothing was generated
        - task_indices_array: Array of task indices for each frame (full dataset length)
        - debug_outputs: List of debug dictionaries (only for sampled frames)
    """
    if console is None:
        console = Console()

    # Extract metadata
    coarse_description = skills_metadata.get("coarse_description", "Complete the task")
    episodes = skills_metadata.get("episodes", {})

    # (user_prompt, robot_utterance, skill, scenario_type, response_type) -> task_index
    high_level_tasks: dict[tuple, int] = {}

    # Array to store task index for each frame - MUST match full dataset length
    full_dataset_length = len(dataset)
    task_indices = np.zeros(full_dataset_length, dtype=np.int64)

    # For debugging - save to JSONL
    debug_outputs: list[dict] = []

    # Track sampling
    last_sample_timestamp: dict[int, float] = {}  # episode_idx -> last sampled timestamp
    last_task_index: dict[int, int] = {}  # episode_idx -> last generated task_index
    frames_sampled = 0  # frames for which a dialogue was actually generated
    warned: set[tuple[str, int]] = set()

    def warn_once(kind: str, episode: int, message: str) -> None:
        # These conditions hold for every frame of an episode; report them once.
        if (kind, episode) not in warned:
            warned.add((kind, episode))
            console.print(message)

    console.print(f"[cyan]Processing all {full_dataset_length} frames from {dataset.meta.total_episodes} episodes...[/cyan]")
    console.print(f"[cyan]Sampling interval: {sample_interval_seconds}s (fps: {dataset.meta.fps})[/cyan]")

    # Process each frame in the FULL dataset
    for frame_idx in tqdm(range(full_dataset_length), desc="Generating synthetic dialogue"):
        episode_idx = None  # stays None if loading the frame itself fails
        try:
            # Get frame data
            frame = dataset[frame_idx]
            episode_idx = frame["episode_index"].item()
            timestamp = frame["timestamp"].item()

            # Get episode skills
            episode_data = episodes.get(str(episode_idx))
            if episode_data is None:
                warn_once(
                    "episode",
                    episode_idx,
                    f"[yellow]Warning: Episode {episode_idx} not in skills metadata[/yellow]",
                )
                continue

            skills = episode_data.get("skills", [])
            description = episode_data.get("description", coarse_description)

            # Find current skill
            current_skill = get_skill_at_timestamp(skills, timestamp)
            if current_skill is None:
                warn_once(
                    "skill",
                    episode_idx,
                    f"[yellow]Warning: No skill found for timestamp {timestamp}[/yellow]",
                )
                continue

            # Sample the first frame seen for an episode, then whenever enough
            # time has passed since the last sampled frame of that episode.
            previous_sample = last_sample_timestamp.get(episode_idx)
            should_sample = (
                previous_sample is None
                or timestamp - previous_sample >= sample_interval_seconds
            )

            # If not sampling, reuse last task index for this episode
            if not should_sample:
                if episode_idx in last_task_index:
                    task_indices[frame_idx] = last_task_index[episode_idx]
                continue
            last_sample_timestamp[episode_idx] = timestamp

            # Build skill history (all skills that ended by the current timestamp)
            skill_history = [skill["name"] for skill in skills if skill["end"] <= timestamp]

            # Load images
            images = _frame_to_pil_images(frame, image_keys)
            if not images:
                console.print(f"[yellow]Warning: No images found for frame {frame_idx}[/yellow]")
                if episode_idx in last_task_index:
                    task_indices[frame_idx] = last_task_index[episode_idx]
                continue

            # Generate synthetic dialogue
            result = annotate_sample(
                pgen=pgen,
                images=images,
                task_description=description,
                skill_history=skill_history,
                skill_current=current_skill,
            )
            frames_sampled += 1

            # Create unique task key
            task_key = (
                result["user_prompt"],
                result["robot_utterance"],
                current_skill,
                result["scenario_type"],
                result["response_type"],
            )

            # Assign or create task index; new keys get the next free index
            current_task_idx = high_level_tasks.setdefault(task_key, len(high_level_tasks))
            task_indices[frame_idx] = current_task_idx
            last_task_index[episode_idx] = current_task_idx

            # Save for debugging
            debug_outputs.append({
                "episode_id": int(episode_idx),
                "frame_index": frame_idx,
                "timestamp": float(timestamp),
                "skill_current": current_skill,
                "skill_history": skill_history,
                "task_description": description,
                "sampled": True,
                **result,
            })

        except Exception as e:
            # Best effort: one bad frame must not abort the whole run.
            console.print(f"[red]Error processing frame {frame_idx}: {e}[/red]")
            if episode_idx in last_task_index:
                task_indices[frame_idx] = last_task_index[episode_idx]
            continue

    # Guard the percentage against an empty dataset
    sampled_pct = frames_sampled / full_dataset_length * 100 if full_dataset_length else 0.0
    console.print(f"[green]✓ Sampled {frames_sampled} frames out of {full_dataset_length} total ({sampled_pct:.1f}%)[/green]")

    # Create tasks DataFrame
    tasks_data = []
    for task_key, task_idx in sorted(high_level_tasks.items(), key=lambda x: x[1]):
        user_prompt, robot_utterance, skill, scenario_type, response_type = task_key
        tasks_data.append({
            "task": f"{user_prompt} | {robot_utterance}",
            "task_index": task_idx,
            "user_prompt": user_prompt,
            "robot_utterance": robot_utterance,
            "skill": skill,
            "scenario_type": scenario_type,
            "response_type": response_type,
        })

    # Explicit columns keep set_index("task") working when no task was generated
    task_columns = [
        "task",
        "task_index",
        "user_prompt",
        "robot_utterance",
        "skill",
        "scenario_type",
        "response_type",
    ]
    tasks_df = pd.DataFrame(tasks_data, columns=task_columns).set_index("task")

    console.print(f"[green]✓ Generated {len(high_level_tasks)} unique high-level tasks[/green]")

    return tasks_df, task_indices, debug_outputs


def _frame_to_pil_images(frame: dict[str, Any], image_keys: list[str]) -> list[Image.Image]:
    """Convert the requested camera entries of one frame to PIL images, skipping missing keys."""
    images = []
    for img_key in image_keys:
        if img_key not in frame:
            continue
        # Frame images are expected to be tensors (C, H, W) in [0, 1]
        img_tensor = frame[img_key]
        if len(img_tensor.shape) == 4:  # (T, C, H, W)
            img_tensor = img_tensor[-1]  # Take last frame
        img_array = (img_tensor.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
        images.append(Image.fromarray(img_array))
    return images
|
|||
|
|
|
|||
|
|
|
|||
|
|
def save_high_level_tasks(
    tasks_df: pd.DataFrame,
    dataset_root: Path,
    console: Console | None = None,
) -> None:
    """
    Write the high-level task table to `<dataset_root>/meta/tasks_high_level.parquet`.

    Args:
        tasks_df: Table of high-level tasks to store
        dataset_root: Dataset root directory; `meta/` is created if missing
        console: Rich console for logging (a new one is created when omitted)
    """
    console = Console() if console is None else console

    meta_dir = dataset_root / "meta"
    meta_dir.mkdir(parents=True, exist_ok=True)
    output_path = meta_dir / "tasks_high_level.parquet"

    tasks_df.to_parquet(output_path, engine="pyarrow", compression="snappy")
    console.print(f"[green]✓ Saved high-level tasks to {output_path}[/green]")
|
|||
|
|
|
|||
|
|
|
|||
|
|
def save_debug_outputs(
    debug_outputs: list[dict],
    dataset_root: Path,
    console: Console | None = None,
) -> None:
    """
    Save debug outputs to `<dataset_root>/meta/syn_annotations.jsonl`.

    Args:
        debug_outputs: JSON-serialisable records, written one per line
        dataset_root: Dataset root directory; `meta/` is created if missing
        console: Rich console for logging (a new one is created when omitted)
    """
    if console is None:
        console = Console()

    output_path = dataset_root / "meta" / "syn_annotations.jsonl"
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Explicit encoding so the file does not depend on the platform default.
    with open(output_path, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(item) + "\n" for item in debug_outputs)

    console.print(f"[green]✓ Saved debug annotations to {output_path}[/green]")
|
|||
|
|
|
|||
|
|
|
|||
|
|
# =============================================================================
|
|||
|
|
# Main Entry Point
|
|||
|
|
# =============================================================================
|
|||
|
|
|
|||
|
|
def main():
    """
    Main entry point for synthetic data generation.

    Parses the command line, loads the dataset and its skills metadata, runs the
    generator model, writes the task table and debug annotations under the source
    dataset's `meta/` directory, adds the `task_index_high_level` feature to a new
    dataset and optionally pushes that dataset to the Hub.
    """
    parser = argparse.ArgumentParser(
        description="Generate synthetic dialogue data for hierarchical robot policies",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=textwrap.dedent("""\
            Examples:
            # Generate synthetic data for a dataset
            python annotate_pgen.py --repo-id lerobot/svla_so101_pickplace \\
            --model Qwen/Qwen2-VL-7B-Instruct \\
            --output-dir ./output

            # Use Qwen3 model with custom parameters
            python annotate_pgen.py --repo-id lerobot/svla_so101_pickplace \\
            --model Qwen/Qwen3-VL-30B-A3B-Instruct \\
            --temperature 0.8 \\
            --batch-size 1
        """),
    )

    # Data source
    data_group = parser.add_mutually_exclusive_group(required=True)
    data_group.add_argument("--data-dir", type=str, help="Path to local LeRobot dataset")
    data_group.add_argument("--repo-id", type=str, help="HuggingFace Hub dataset repository ID")

    # Model configuration
    parser.add_argument(
        "--model",
        type=str,
        default="Qwen/Qwen2-VL-7B-Instruct",
        help="VLM model to use (default: Qwen/Qwen2-VL-7B-Instruct)",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="Device to run model on (default: cuda)",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="bfloat16",
        choices=["bfloat16", "float16", "float32"],
        help="Model dtype (default: bfloat16)",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.7,
        help="Sampling temperature (default: 0.7)",
    )

    # Processing options
    parser.add_argument(
        "--batch-size",
        type=int,
        default=1,
        help="Batch size for processing (default: 1) [currently unused]",
    )
    parser.add_argument(
        "--num-image-views-per-sample",
        type=int,
        default=1,
        help="Number of camera views to use per sample (default: 1)",
    )
    parser.add_argument(
        "--sample-interval",
        type=float,
        default=1.0,
        help="Generate dialogue every N seconds (default: 1.0). Frames between samples reuse the last generated dialogue. "
        "Use larger intervals (e.g., 2.0 or 5.0) for faster processing during testing.",
    )

    # Output options
    parser.add_argument(
        "--output-dir",
        type=str,
        default=None,
        help="Output directory for modified dataset",
    )
    parser.add_argument(
        "--push-to-hub",
        action="store_true",
        help="Push modified dataset to HuggingFace Hub",
    )

    args = parser.parse_args()
    # A non-positive count would slice the camera keys to an empty or truncated
    # list, leaving every sampled frame without the intended images.
    if args.num_image_views_per_sample < 1:
        parser.error("--num-image-views-per-sample must be at least 1")

    console = Console()

    # Load dataset
    console.print("[cyan]Loading dataset...[/cyan]")
    if args.data_dir:
        dataset = LeRobotDataset(repo_id="local/dataset", root=args.data_dir)
        dataset_root = Path(args.data_dir)
    else:
        dataset = LeRobotDataset(repo_id=args.repo_id)
        dataset_root = dataset.root

    console.print(f"[green]✓ Loaded dataset with {len(dataset)} frames[/green]")

    # Load skills metadata
    console.print("[cyan]Loading skills metadata...[/cyan]")
    skills_metadata = load_skills_metadata(dataset_root)
    if skills_metadata is None:
        console.print("[red]Error: No skills.json found. Run annotate.py first![/red]")
        return

    console.print(f"[green]✓ Loaded skills for {len(skills_metadata.get('episodes', {}))} episodes[/green]")

    # Initialize model
    dtype_map = {
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
        "float32": torch.float32,
    }
    torch_dtype = dtype_map[args.dtype]

    console.print(f"[cyan]Initializing {args.model}...[/cyan]")
    pgen = QwenPgen(
        model_name=args.model,
        device=args.device,
        torch_dtype=torch_dtype,
        temperature=args.temperature,
    )

    # Get image keys
    image_keys = dataset.meta.camera_keys[:args.num_image_views_per_sample]
    console.print(f"[cyan]Using image keys: {image_keys}[/cyan]")

    # Generate synthetic data
    tasks_df, task_indices, debug_outputs = generate_synthetic_data(
        dataset=dataset,
        pgen=pgen,
        skills_metadata=skills_metadata,
        image_keys=image_keys,
        sample_interval_seconds=args.sample_interval,
        console=console,
    )

    # Save high-level tasks
    # NOTE(review): both files are written under the SOURCE dataset root; confirm
    # that add_features carries them over to the new dataset.
    save_high_level_tasks(tasks_df, dataset_root, console)
    save_debug_outputs(debug_outputs, dataset_root, console)

    # Add task_index_high_level feature to dataset
    console.print("[cyan]Adding task_index_high_level feature to dataset...[/cyan]")

    # Determine output directory; the new repo id is the same either way
    output_dir = Path(args.output_dir) if args.output_dir else None
    repo_id = f"{dataset.repo_id}_with_high_level_tasks"

    # Add feature using dataset_tools
    feature_info = {
        "dtype": "int64",
        "shape": (1,),
        "names": None,
    }
    new_dataset = add_features(
        dataset=dataset,
        features={
            "task_index_high_level": (task_indices, feature_info),
        },
        output_dir=output_dir,
        repo_id=repo_id,
    )

    console.print("[bold green]✓ Successfully added task_index_high_level feature![/bold green]")
    console.print(f" New dataset saved to: {new_dataset.root}")
    console.print(f" Total high-level tasks: {len(tasks_df)}")

    # Push to hub if requested
    if args.push_to_hub:
        if args.data_dir:
            console.print("[yellow]Warning: --push-to-hub requires --repo-id, skipping...[/yellow]")
        else:
            console.print("[cyan]Pushing to HuggingFace Hub...[/cyan]")
            try:
                new_dataset.push_to_hub(push_videos=False)
                console.print(f"[green]✓ Pushed to {repo_id}[/green]")
            except Exception as e:
                console.print(f"[red]Push failed: {e}[/red]")
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|||
|
|
|