mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 19:01:28 +00:00
Compare commits
23 Commits
feat/sarm_
...
feat/unitr
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b320530482 | ||
|
|
4bdd2475b0 | ||
|
|
5d9266b024 | ||
|
|
c7834c3db8 | ||
|
|
c65866ddd8 | ||
|
|
bebf9b8480 | ||
|
|
3be342a00d | ||
|
|
e6c16a60b1 | ||
|
|
786f4df529 | ||
|
|
58739f4b7a | ||
|
|
b6e606c28d | ||
|
|
c477c54e3c | ||
|
|
f30e15d411 | ||
|
|
8f06c02c17 | ||
|
|
9a052566a3 | ||
|
|
e5cae6be64 | ||
|
|
56b66b9542 | ||
|
|
30c3bbef7b | ||
|
|
c4bf27772c | ||
|
|
3422f2cb01 | ||
|
|
c365bcd0a5 | ||
|
|
2cbd6649f2 | ||
|
|
c749fba0f5 |
4
.gitignore
vendored
4
.gitignore
vendored
@@ -173,7 +173,3 @@ outputs/
|
||||
|
||||
# Dev folders
|
||||
.cache/*
|
||||
*.stl
|
||||
*.urdf
|
||||
*.xml
|
||||
*.part
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,525 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Visualize SARM Subtask Annotations
|
||||
|
||||
This script creates visualizations of the subtask annotations generated by subtask_annotation.py.
|
||||
For each episode, it shows:
|
||||
- A timeline with dashed vertical lines at subtask boundaries
|
||||
- Sample frames from the episode at key points (start, middle, end of each subtask)
|
||||
- Color-coded subtask segments
|
||||
|
||||
Usage:
|
||||
python visualize_subtask_annotations.py --repo-id pepijn223/mydataset --video-key observation.images.top --num-episodes 5
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import random
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.patches as mpatches
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from matplotlib.lines import Line2D
|
||||
from rich.console import Console
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.utils import load_episodes
|
||||
from lerobot.policies.sarm.sarm_utils import SubtaskAnnotation, Subtask, Timestamp
|
||||
|
||||
|
||||
def timestamp_to_seconds(timestamp: str) -> float:
    """Convert an ``SS``, ``MM:SS`` or ``HH:MM:SS`` timestamp to seconds.

    Each field may carry a fractional part (e.g. ``"00:01.5"``).

    Args:
        timestamp: Colon-separated timestamp string.

    Returns:
        The timestamp expressed in seconds, always as a ``float``.

    Raises:
        ValueError: If a field is not numeric or there are more than three
            colon-separated fields.
    """
    parts = timestamp.strip().split(":")
    if len(parts) > 3:
        raise ValueError(f"Unsupported timestamp format: {timestamp!r}")

    # Fold the fields left to right in base 60 so that one code path covers
    # SS, MM:SS and HH:MM:SS.
    seconds = 0.0
    for part in parts:
        seconds = seconds * 60 + float(part)
    return seconds
|
||||
|
||||
|
||||
def load_annotations_from_dataset(dataset_path: Path) -> dict[int, SubtaskAnnotation]:
    """
    Rebuild per-episode subtask annotations from a dataset's episode metadata.

    The episode table is obtained through ``load_episodes(dataset_path)``. For
    every row whose ``subtask_names`` entry is present, the matching
    ``subtask_start_times`` / ``subtask_end_times`` values are truncated to
    whole seconds and re-encoded as ``MM:SS`` strings.

    Returns:
        A mapping from the integer row label of the episode table to its
        ``SubtaskAnnotation``. The mapping is empty when the table is missing,
        empty, or has no ``subtask_names`` column.
    """

    def _as_mmss(value) -> str:
        # int() drops any fractional seconds before formatting.
        whole = int(value)
        return f"{whole // 60:02d}:{whole % 60:02d}"

    episodes_dataset = load_episodes(dataset_path)
    if episodes_dataset is None or len(episodes_dataset) == 0:
        return {}
    if "subtask_names" not in episodes_dataset.column_names:
        return {}

    # pandas makes per-row, per-column access straightforward.
    table = episodes_dataset.to_pandas()

    collected = {}
    for row_label in table.index:
        names = table.loc[row_label, "subtask_names"]

        # Unannotated episodes show up either as None or as a float NaN.
        unannotated = names is None or (isinstance(names, float) and pd.isna(names))
        if unannotated:
            continue

        starts = table.loc[row_label, "subtask_start_times"]
        ends = table.loc[row_label, "subtask_end_times"]

        rebuilt = [
            Subtask(
                name=label,
                timestamps=Timestamp(start=_as_mmss(starts[pos]), end=_as_mmss(ends[pos])),
            )
            for pos, label in enumerate(names)
        ]
        collected[int(row_label)] = SubtaskAnnotation(subtasks=rebuilt)

    return collected
|
||||
|
||||
|
||||
# Color palette for subtasks (colorblind-friendly).
# visualize_episode indexes this list with ``i % len(SUBTASK_COLORS)``, so the
# colors repeat once an episode has more subtasks than there are entries.
SUBTASK_COLORS = [
    "#E69F00",  # Orange
    "#56B4E9",  # Sky blue
    "#009E73",  # Bluish green
    "#F0E442",  # Yellow
    "#0072B2",  # Blue
    "#D55E00",  # Vermillion
    "#CC79A7",  # Reddish purple
    "#999999",  # Gray
]
|
||||
|
||||
|
||||
def extract_frame_from_video(video_path: Path, timestamp: float) -> np.ndarray | None:
    """Extract a single RGB frame from a video at the given timestamp.

    Args:
        video_path: Path of the video file to open with OpenCV.
        timestamp: Position in seconds; converted to milliseconds for the seek.

    Returns:
        The frame converted from BGR to RGB, or ``None`` when the video cannot
        be opened or no frame can be read at that position.
    """
    cap = cv2.VideoCapture(str(video_path))
    try:
        if not cap.isOpened():
            return None

        # Seek by time, then decode the frame at that position.
        cap.set(cv2.CAP_PROP_POS_MSEC, timestamp * 1000)
        ret, frame = cap.read()
    finally:
        # Always give the capture handle back, even if the seek/read raises.
        cap.release()

    if not ret or frame is None:
        return None

    # OpenCV decodes to BGR; convert to RGB for display.
    return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||
|
||||
|
||||
def _subtask_bounds(subtask) -> tuple[float, float]:
    """Return ``(start, end)`` of ``subtask`` in seconds, parsed from its timestamp strings."""
    return (
        timestamp_to_seconds(subtask.timestamps.start),
        timestamp_to_seconds(subtask.timestamps.end),
    )


def _plot_sample_frames(fig, gs, subtasks, sample_frames, frame_timestamps) -> None:
    """Fill the top grid row with one frame (or a "Frame N/A" placeholder) per subtask."""
    for i, (frame, subtask) in enumerate(zip(sample_frames, subtasks)):
        ax = fig.add_subplot(gs[0, i])
        ax.set_facecolor("#16213e")

        if frame is not None:
            ax.imshow(frame)
        else:
            ax.text(
                0.5, 0.5, "Frame\nN/A",
                ha="center", va="center",
                fontsize=12, color="white", transform=ax.transAxes,
            )

        ax.set_title(
            f"{subtask.name}",
            fontsize=10,
            fontweight="bold",
            color=SUBTASK_COLORS[i % len(SUBTASK_COLORS)],
            pad=8,
        )
        ax.axis("off")

        # Time (relative to episode start) at which the frame was sampled.
        ax.text(
            0.5, -0.08,
            f"t={frame_timestamps[i]:.1f}s",
            ha="center",
            fontsize=9,
            color="#888888",
            transform=ax.transAxes,
        )


def _plot_timeline(ax_timeline, subtasks, bounds) -> None:
    """Draw the colored subtask bars, boundary lines, ticks and legend on ``ax_timeline``."""
    ax_timeline.set_facecolor("#16213e")

    # The timeline spans up to the end of the last subtask.
    total_duration = bounds[-1][1]

    bar_height = 0.6
    bar_y = 0.5
    yellow = "#F0E442"  # light palette entry that needs dark label text

    for i, (subtask, (start_sec, end_sec)) in enumerate(zip(subtasks, bounds)):
        color = SUBTASK_COLORS[i % len(SUBTASK_COLORS)]
        duration = end_sec - start_sec

        ax_timeline.add_patch(
            mpatches.FancyBboxPatch(
                (start_sec, bar_y - bar_height / 2),
                duration,
                bar_height,
                boxstyle="round,pad=0.02,rounding_size=0.1",
                facecolor=color,
                edgecolor="white",
                linewidth=1.5,
                alpha=0.85,
            )
        )

        # Only label segments that are wide enough; tilt labels on narrow ones.
        if duration > total_duration * 0.08:
            ax_timeline.text(
                (start_sec + end_sec) / 2, bar_y,
                subtask.name,
                ha="center", va="center",
                fontsize=9,
                fontweight="bold",
                # Decide by the actual bar color, not by the raw index, so the
                # rule still holds when the palette wraps around.
                color="black" if color == yellow else "white",
                rotation=0 if duration > total_duration * 0.15 else 45,
            )

    # Boundaries: every subtask start (the first one only if it is not t=0)
    # plus the end of the last subtask.
    boundary_times = []
    for i, (start_sec, _end_sec) in enumerate(bounds):
        if i > 0 or start_sec > 0:
            boundary_times.append(start_sec)
    boundary_times.append(bounds[-1][1])

    for t in boundary_times:
        ax_timeline.axvline(
            x=t,
            ymin=0.1, ymax=0.9,
            color="white",
            linestyle="--",
            linewidth=2,
            alpha=0.9,
        )
        ax_timeline.text(
            t, 0.0,
            f"{int(t//60):02d}:{int(t%60):02d}",
            ha="center", va="top",
            fontsize=8,
            color="#cccccc",
        )

    # Solid green marker for the episode start at t=0.
    ax_timeline.axvline(x=0, ymin=0.1, ymax=0.9, color="#00ff00", linestyle="-", linewidth=2.5, alpha=0.9)
    ax_timeline.text(0, 0.0, "00:00", ha="center", va="top", fontsize=8, color="#00ff00", fontweight="bold")

    ax_timeline.set_xlim(-total_duration * 0.02, total_duration * 1.02)
    ax_timeline.set_ylim(-0.3, 1.2)
    ax_timeline.set_xlabel("Time (seconds)", fontsize=11, color="white", labelpad=10)
    ax_timeline.set_ylabel("")

    ax_timeline.spines["top"].set_visible(False)
    ax_timeline.spines["right"].set_visible(False)
    ax_timeline.spines["left"].set_visible(False)
    ax_timeline.spines["bottom"].set_color("#444444")
    ax_timeline.tick_params(axis="x", colors="#888888", labelsize=9)
    ax_timeline.tick_params(axis="y", left=False, labelleft=False)

    # Roughly ten ticks, never closer together than one second.
    tick_interval = max(1, int(total_duration / 10))
    ax_timeline.set_xticks(np.arange(0, total_duration + tick_interval, tick_interval))

    legend_elements = [
        Line2D([0], [0], color="#00ff00", linewidth=2.5, linestyle="-", label="Start"),
        Line2D([0], [0], color="white", linewidth=2, linestyle="--", label="Subtask boundary"),
    ]
    ax_timeline.legend(
        handles=legend_elements,
        loc="upper right",
        framealpha=0.3,
        facecolor="#16213e",
        edgecolor="#444444",
        fontsize=9,
        labelcolor="white",
    )


def visualize_episode(
    episode_idx: int,
    annotation,
    video_path: Path,
    video_start_timestamp: float,
    video_end_timestamp: float,
    fps: int,
    output_path: Path,
    video_key: str,
):
    """
    Create and save the visualization for a single episode.

    Layout:
    - Top row: one sample frame per subtask, taken at the subtask's midpoint
    - Bottom: timeline with colored subtask segments and dashed boundary lines

    Args:
        episode_idx: Episode number, used in the figure title and messages.
        annotation: Object exposing ``.subtasks``; each subtask provides
            ``.name`` and ``.timestamps.start`` / ``.timestamps.end`` strings.
        video_path: Video file the sample frames are read from.
        video_start_timestamp: Offset (seconds) of the episode inside the video;
            added to each subtask midpoint before seeking.
        video_end_timestamp: End offset (seconds); only used to display the
            episode duration.
        fps: Accepted for interface compatibility; not used by this function.
        output_path: Destination of the rendered image.
        video_key: Camera name shown in the subtitle.

    Returns:
        ``output_path`` once the image is written, or ``None`` when the
        annotation has no subtasks (nothing is drawn in that case).
    """
    subtasks = annotation.subtasks
    num_subtasks = len(subtasks)

    if num_subtasks == 0:
        print(f"No subtasks found for episode {episode_idx}")
        return

    episode_duration = video_end_timestamp - video_start_timestamp

    # Parse every subtask's timestamps once and reuse them below.
    bounds = [_subtask_bounds(subtask) for subtask in subtasks]

    # One sample frame per subtask, taken at the middle of the subtask.
    sample_frames = []
    frame_timestamps = []
    for start_sec, end_sec in bounds:
        mid_sec = (start_sec + end_sec) / 2
        frame_timestamps.append(mid_sec)
        # Subtask times are relative to the episode; shift into video time.
        sample_frames.append(extract_frame_from_video(video_path, video_start_timestamp + mid_sec))

    fig = plt.figure(figsize=(16, 10))
    try:
        # Dark background for better contrast.
        fig.patch.set_facecolor("#1a1a2e")

        # Two rows: frames on top (one column per subtask), timeline below.
        gs = fig.add_gridspec(
            2, num_subtasks,
            height_ratios=[2, 1],
            hspace=0.3,
            wspace=0.1,
            left=0.05, right=0.95,
            top=0.88, bottom=0.1,
        )

        fig.suptitle(
            f"Episode {episode_idx} - Subtask Annotations",
            fontsize=18,
            fontweight="bold",
            color="white",
            y=0.96,
        )
        fig.text(
            0.5, 0.91,
            f"Camera: {video_key} | Duration: {episode_duration:.1f}s | {num_subtasks} subtasks",
            ha="center",
            fontsize=11,
            color="#888888",
        )

        _plot_sample_frames(fig, gs, subtasks, sample_frames, frame_timestamps)
        _plot_timeline(fig.add_subplot(gs[1, :]), subtasks, bounds)

        # Save through the figure object rather than pyplot's implicit
        # "current figure", which is not guaranteed to be this one.
        fig.savefig(output_path, dpi=150, facecolor=fig.get_facecolor(), edgecolor="none", bbox_inches="tight")
    finally:
        # Close this figure even when drawing fails, so repeated calls do not
        # accumulate open figures.
        plt.close(fig)

    return output_path
|
||||
|
||||
|
||||
def main():
    """Command-line entry point: render subtask visualizations for selected episodes.

    Steps: parse arguments, load the dataset (downloading videos), pick the
    camera key, load the subtask annotations, choose the episodes (explicit
    ``--episodes`` or a random sample of annotated ones), write one PNG per
    episode into ``--output-dir``, and print a summary.

    Errors while rendering one episode are reported and do not stop the
    remaining episodes.
    """
    parser = argparse.ArgumentParser(
        description="Visualize SARM subtask annotations",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "--repo-id",
        type=str,
        required=True,
        help="HuggingFace dataset repository ID",
    )
    parser.add_argument(
        "--num-episodes",
        type=int,
        default=5,
        help="Number of random episodes to visualize (default: 5)",
    )
    parser.add_argument(
        "--episodes",
        type=int,
        nargs="+",
        default=None,
        help="Specific episode indices to visualize (overrides --num-episodes)",
    )
    parser.add_argument(
        "--video-key",
        type=str,
        default=None,
        help="Camera/video key to use. If not specified, uses first available.",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="./subtask_viz",
        help="Output directory for visualizations (default: ./subtask_viz)",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Random seed for reproducibility",
    )

    args = parser.parse_args()

    console = Console()

    # Seed the sampler only when asked to, so runs are reproducible on demand.
    if args.seed is not None:
        random.seed(args.seed)

    console.print(f"\n[cyan]Loading dataset: {args.repo_id}[/cyan]")
    dataset = LeRobotDataset(args.repo_id, download_videos=True)
    fps = dataset.fps

    # Resolve the camera key: explicit choice, else the first available one.
    video_keys = dataset.meta.video_keys
    if args.video_key:
        if args.video_key not in video_keys:
            console.print(f"[red]Error: Video key '{args.video_key}' not found[/red]")
            console.print(f"[yellow]Available: {', '.join(video_keys)}[/yellow]")
            return
        video_key = args.video_key
    elif video_keys:
        video_key = video_keys[0]
    else:
        # Without this guard, indexing an empty key list raises IndexError.
        console.print("[red]Error: Dataset has no video keys[/red]")
        return

    console.print(f"[cyan]Using camera: {video_key}[/cyan]")
    console.print(f"[cyan]FPS: {fps}[/cyan]")

    console.print("\n[cyan]Loading annotations...[/cyan]")
    annotations = load_annotations_from_dataset(dataset.root)

    if not annotations:
        console.print("[red]Error: No annotations found in dataset[/red]")
        console.print("[yellow]Run subtask_annotation.py first to generate annotations[/yellow]")
        return

    console.print(f"[green]Found {len(annotations)} annotated episodes[/green]")

    # Determine which episodes to visualize.
    if args.episodes:
        for ep in args.episodes:
            if ep not in annotations:
                console.print(f"[yellow]Warning: Episode {ep} has no annotation, skipping[/yellow]")
        episode_indices = [ep for ep in args.episodes if ep in annotations]
    else:
        available_episodes = list(annotations.keys())
        # Clamp to [0, available]; random.sample rejects negative sizes.
        num_to_select = max(0, min(args.num_episodes, len(available_episodes)))
        episode_indices = random.sample(available_episodes, num_to_select)
        episode_indices.sort()

    if not episode_indices:
        console.print("[red]Error: No valid episodes to visualize[/red]")
        return

    console.print(f"[cyan]Visualizing episodes: {episode_indices}[/cyan]")

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Metadata columns holding each episode's time window inside its video file.
    from_key = f"videos/{video_key}/from_timestamp"
    to_key = f"videos/{video_key}/to_timestamp"

    num_saved = 0
    for ep_idx in episode_indices:
        console.print(f"\n[cyan]Processing episode {ep_idx}...[/cyan]")

        annotation = annotations[ep_idx]

        video_path = dataset.root / dataset.meta.get_video_file_path(ep_idx, video_key)
        if not video_path.exists():
            console.print(f"[red]Video not found: {video_path}[/red]")
            continue

        video_start_timestamp = float(dataset.meta.episodes[from_key][ep_idx])
        video_end_timestamp = float(dataset.meta.episodes[to_key][ep_idx])

        output_path = output_dir / f"episode_{ep_idx:04d}_subtasks.png"

        try:
            saved_path = visualize_episode(
                episode_idx=ep_idx,
                annotation=annotation,
                video_path=video_path,
                video_start_timestamp=video_start_timestamp,
                video_end_timestamp=video_end_timestamp,
                fps=fps,
                output_path=output_path,
                video_key=video_key,
            )
        except Exception as e:
            # Per-episode boundary: report the failure and keep going.
            console.print(f"[red]✗ Failed to visualize episode {ep_idx}: {e}[/red]")
            continue

        if saved_path is None:
            # visualize_episode returns None when there was nothing to draw.
            console.print(f"[yellow]Skipped episode {ep_idx}: no subtasks to draw[/yellow]")
            continue

        num_saved += 1
        console.print(f"[green]✓ Saved: {output_path}[/green]")

    console.print(f"\n[bold green]{'=' * 50}[/bold green]")
    console.print("[bold green]Visualization Complete![/bold green]")
    console.print(f"[bold green]{'=' * 50}[/bold green]")
    console.print(f"Output directory: {output_dir.absolute()}")
    # Count only the episodes whose image was actually written.
    console.print(f"Episodes visualized: {num_saved}")


if __name__ == "__main__":
    main()
|
||||
|
||||
107
get_calibration.py
Normal file
107
get_calibration.py
Normal file
@@ -0,0 +1,107 @@
|
||||
import json
|
||||
import time
|
||||
import math
|
||||
from pathlib import Path
|
||||
|
||||
# ---- observation key → (section, joint name, id)
MAP = {
    # LEFT
    "kLeftShoulderPitch.pos": ("left", "shoulder_pitch", 0),
    "kLeftShoulderYaw.pos": ("left", "shoulder_yaw", 1),
    "kLeftShoulderRoll.pos": ("left", "shoulder_roll", 2),
    "kLeftElbow.pos": ("left", "elbow_flex", 3),
    "kLeftWristRoll.pos": ("left", "wrist_roll", 4),
    "kLeftWristYaw.pos": ("left", "wrist_yaw", 5),
    "kLeftWristyaw.pos": ("left", "wrist_yaw", 5),  # tolerate casing variant
    "kLeftWristPitch.pos": ("left", "wrist_pitch", 6),

    # RIGHT
    "kRightShoulderPitch.pos": ("right", "shoulder_pitch", 0),
    "kRightShoulderYaw.pos": ("right", "shoulder_yaw", 1),
    "kRightShoulderRoll.pos": ("right", "shoulder_roll", 2),
    "kRightElbow.pos": ("right", "elbow_flex", 3),
    "kRightWristRoll.pos": ("right", "wrist_roll", 4),
    "kRightWristYaw.pos": ("right", "wrist_yaw", 5),
    "kRightWristPitch.pos": ("right", "wrist_pitch", 6),
}

# Where the calibration is written, and whether ranges are rounded to ints.
CALIB_PATH = Path("calibration.json")
ROUND_TO_INT = False  # set True if you want int ranges

# Running min/max per joint: tracker["left"]["shoulder_pitch"] = {...}
# Ranges start at +inf / -inf so the first observed value always replaces them.
tracker = {"left": {}, "right": {}}
for sec, name, idx in MAP.values():
    # setdefault keeps the first entry when two keys map to the same joint.
    tracker[sec].setdefault(
        name,
        {
            "id": idx,
            "drive_mode": 0,
            "homing_offset": 0,
            "range_min": math.inf,
            "range_max": -math.inf,
        },
    )
|
||||
|
||||
def _to_float(x):
    """Return ``x`` as a Python float, unwrapping it via ``.item()`` first when available.

    A failing ``.item()`` call is ignored and the original object is passed to
    ``float()`` instead; any error from ``float()`` itself propagates.
    """
    # numpy / torch scalars expose .item(); plain numbers and strings do not.
    unwrap = getattr(x, "item", None)
    if unwrap is not None:
        try:
            x = unwrap()
        except Exception:
            pass
    return float(x)
|
||||
|
||||
def update_tracker(obs: dict):
    """Widen the tracked min/max range of every mapped joint found in ``obs``.

    Keys that are not in ``MAP`` are ignored, as are values that cannot be
    converted to a float.
    """
    for key, raw in obs.items():
        target = MAP.get(key)
        if target is None:
            continue
        section, joint, _ = target

        try:
            value = _to_float(raw)
        except Exception:
            # Best effort: skip readings that are not numeric.
            continue

        entry = tracker[section][joint]
        if value < entry["range_min"]:
            entry["range_min"] = value
        if value > entry["range_max"]:
            entry["range_max"] = value
|
||||
|
||||
def dump_calibration(path: Path):
    """Write the tracked joint ranges to ``path`` as indented JSON.

    For each joint in the ``left`` and ``right`` sections of ``tracker`` the
    output holds ``id``, ``drive_mode``, ``homing_offset``, ``range_min`` and
    ``range_max``. A bound that is not a finite number (in particular the
    initial +/-inf placeholders of a joint that was never observed) is written
    as ``null``. When ``ROUND_TO_INT`` is true, finite bounds are rounded to
    integers.

    Args:
        path: Destination file; it is overwritten.
    """

    def _export(bound):
        # Compare by value, not identity: ``-math.inf`` builds a new float
        # object on every evaluation, so an ``is`` check never matches it.
        if not math.isfinite(bound):
            return None
        return int(round(bound)) if ROUND_TO_INT else bound

    out = {
        sec: {
            name: {
                "id": d["id"],
                "drive_mode": d["drive_mode"],
                "homing_offset": d["homing_offset"],
                "range_min": _export(d["range_min"]),
                "range_max": _export(d["range_max"]),
            }
            for name, d in tracker[sec].items()
        }
        for sec in ("left", "right")
    }
    path.write_text(json.dumps(out, indent=4))
    print(f"Saved calibration to {path.resolve()}")
|
||||
|
||||
from lerobot.robots.unitree_g1.unitree_g1 import UnitreeG1, G1_29_JointIndex
|
||||
from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
import time
|
||||
# Build the robot from a config with both mode flags switched off.
# NOTE(review): presumably this targets the physical robot rather than a
# simulator — confirm against UnitreeG1Config.
config = UnitreeG1Config(
    motion_mode=False,
    simulation_mode=False
)

robot = UnitreeG1(config)

# Poll until interrupted: each reading widens the tracked joint ranges and is
# then sent straight back to the robot as an action.
try:
    while True:
        observation = robot.get_observation()
        update_tracker(observation)
        robot.send_action(observation)  # mirror, if desired
        time.sleep(0.01)  # pause 10 ms between iterations
except KeyboardInterrupt:
    # Ctrl-C is the only exit path that writes the calibration file; any other
    # exception propagates without saving.
    dump_calibration(CALIB_PATH)
|
||||
BIN
screenshot.png
Normal file
BIN
screenshot.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 1.5 MiB |
@@ -1,761 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Inference script for SARM (Stage-Aware Reward Model).
|
||||
|
||||
This script loads a trained SARM model and runs inference on a dataset episode,
|
||||
generating visualizations of the predicted task stages and progress over time.
|
||||
|
||||
Example usage:
|
||||
python scripts/visualize_sarm_predictions.py \
|
||||
--model-id username/sarm-model \
|
||||
--dataset-repo lerobot/aloha_sim_insertion_human \
|
||||
--episode-index 0 \
|
||||
--output-dir outputs/sarm_viz \
|
||||
--task-description "insert the peg into the socket"
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.gridspec as gridspec
|
||||
import matplotlib.patches as mpatches
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.policies.sarm.modeling_sarm import SARMRewardModel
|
||||
from lerobot.policies.sarm.sarm_utils import (
|
||||
pad_state_to_max_dim,
|
||||
compute_tau,
|
||||
compute_cumulative_progress_batch,
|
||||
)
|
||||
from lerobot.datasets.utils import load_stats
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def parse_args():
    """Build the command-line parser for SARM inference and parse ``sys.argv``.

    Returns:
        The ``argparse.Namespace`` with model, dataset, output, visualization
        and device options.
    """
    parser = argparse.ArgumentParser(description="Run SARM inference and visualize predictions")
    add = parser.add_argument

    # --- Model ---
    add("--model-id", type=str, required=True,
        help="HuggingFace model ID or local path to trained SARM model")

    # --- Dataset ---
    add("--dataset-repo", type=str, required=True,
        help="HuggingFace dataset repository ID (e.g., lerobot/aloha_sim_insertion_human)")
    add("--episode-index", type=int, default=0,
        help="Index of the episode to visualize (default: 0)")
    add("--task-description", type=str, default="perform the task",
        help="Task description for the reward model (default: 'perform the task')")

    # --- Output ---
    add("--output-dir", type=str, default="outputs/sarm_inference",
        help="Directory to save visualization outputs (default: outputs/sarm_inference)")
    add("--image-key", type=str, default=None,
        help="Key for images in dataset (e.g., observation.images.image). If not specified, uses model config's image_key")
    add("--state-key", type=str, default=None,
        help="Key for joint states in dataset. If None, auto-detects from dataset")

    # --- Visualization ---
    add("--show-frames", action="store_true",
        help="Include sample frames in the visualization")
    add("--num-sample-frames", type=int, default=8,
        help="Number of sample frames to show (default: 8)")
    add("--figsize", type=int, nargs=2, default=[14, 8],
        help="Figure size as width height (default: 14 8)")

    # --- Device ---
    add("--device", type=str, default=None,
        help="Device to run inference on (cuda/cpu, default: auto-detect)")

    return parser.parse_args()
|
||||
|
||||
|
||||
def _to_uint8_hwc(img) -> np.ndarray:
    """Convert one dataset image to a ``uint8`` array with the channel axis last."""
    if isinstance(img, torch.Tensor):
        img = img.cpu().numpy()

    # A leading axis of length 1 or 3 is taken to be the channel axis.
    if img.shape[0] in [1, 3]:
        img = np.transpose(img, (1, 2, 0))

    if img.dtype != np.uint8:
        if img.max() <= 1.0:
            # Values within [.., 1.0] are treated as normalized and rescaled.
            img = (img * 255).astype(np.uint8)
        else:
            img = img.astype(np.uint8)
    return img


def load_episode_data(
    dataset: LeRobotDataset,
    episode_index: int,
    image_key: str,
    state_key: str | None = None
) -> tuple[np.ndarray, np.ndarray | None, int, int, str | None]:
    """
    Load all frames and states from a specific episode.

    Args:
        dataset: LeRobotDataset instance
        episode_index: Index of the episode to load
        image_key: Key for accessing images in the dataset
        state_key: Key for accessing joint states. When None, the first key of
            the episode's first item containing "state" or "qpos"
            (case-insensitive) is used, if there is one.

    Returns:
        Tuple of (frames, states, start_index, end_index, task_description).
        ``states`` is None when no state key is available, and
        ``task_description`` is None when the first item has no "task" entry.

    Raises:
        ValueError: If the episode contains no frames.
    """
    # Episode boundaries inside the flat dataset index.
    episode_data = dataset.meta.episodes
    start_idx = episode_data["dataset_from_index"][episode_index]
    end_idx = episode_data["dataset_to_index"][episode_index]

    logger.info(f"Loading episode {episode_index}: frames {start_idx} to {end_idx} ({end_idx - start_idx} frames)")

    if end_idx <= start_idx:
        # Fail early with a clear message instead of an IndexError further down.
        raise ValueError(f"Episode {episode_index} contains no frames")

    # Fetch the first item once; it serves both state-key detection and task lookup.
    first_item = dataset[start_idx]

    if state_key is None:
        state_keys = [k for k in first_item.keys() if 'state' in k.lower() or 'qpos' in k.lower()]
        if state_keys:
            state_key = state_keys[0]
            logger.info(f"Auto-detected state key: {state_key}")

    task_description = None
    if "task" in first_item:
        task_description = first_item["task"]
        logger.info(f"✓ Extracted task from episode {episode_index}: '{task_description}'")

    frames = []
    states = []
    for idx in tqdm(range(start_idx, end_idx), desc="Loading frames"):
        item = dataset[idx]

        frames.append(_to_uint8_hwc(item[image_key]))

        # States are optional: collect them only when the key is present.
        if state_key and state_key in item:
            state = item[state_key]
            if isinstance(state, torch.Tensor):
                state = state.cpu().numpy()
            states.append(state)

    frames = np.array(frames)
    states = np.array(states) if states else None
    logger.info(f"Loaded {len(frames)} frames with shape {frames[0].shape}")
    if states is not None:
        logger.info(f"Loaded states with shape {states.shape}")

    return frames, states, start_idx, end_idx, task_description
|
||||
|
||||
|
||||
@torch.no_grad()
def run_inference(
    model: SARMRewardModel,
    frames: np.ndarray,
    states: Optional[np.ndarray],
    task_description: str,
    dataset_stats: dict | None = None,
    state_key: str = "observation.state",
    batch_size: int = 32
) -> tuple[np.ndarray, np.ndarray]:
    """
    Run SARM inference on the video frames and joint states of one episode.

    For every frame ``t`` a temporal slice is assembled by adding each offset in
    ``model.config.observation_delta_indices`` to ``t`` and clamping the result
    to ``[0, num_frames - 1]``. The prediction for frame ``t`` is read from
    position 5 of the model output for that slice.

    NOTE(review): position 5 presumes the delta pattern
    ``[initial, t-4g, t-3g, t-2g, t-g, t, t+g, t+2g, t+3g]``; confirm against
    the model config if ``observation_delta_indices`` ever changes.

    Args:
        model: SARM model
        frames: Video frames (num_frames, H, W, C) - all frames from ONE episode
        states: Joint states (num_frames, state_dim), or None to run without states
        task_description: Task description text
        dataset_stats: Dataset statistics for state normalization (same as training)
        state_key: Key for state in dataset_stats
        batch_size: Batch size for processing slices

    Returns:
        Tuple of (progress_predictions, stage_predictions)
            - progress_predictions: (num_frames,)
            - stage_predictions: (num_frames, num_stages)

    Raises:
        ValueError: If the episode contains no frames.
    """
    # Slot of the "current" frame inside each temporal slice (see docstring note).
    current_frame_idx = 5

    logger.info("Encoding video frames with CLIP...")
    video_embeddings = model.encode_images(frames)

    logger.info("Encoding task description with CLIP...")
    text_embedding = model.encode_text(task_description)

    logger.info("Creating video slices (SARM paper: initial frame + 8 consecutive)...")

    # Convert to tensors
    video_embeddings = torch.tensor(video_embeddings, dtype=torch.float32)
    text_embedding = torch.tensor(text_embedding, dtype=torch.float32)

    num_frames = video_embeddings.shape[0]
    if num_frames == 0:
        raise ValueError("Cannot run SARM inference on an episode with no frames")

    state_embeddings = None
    if states is not None:
        state_embeddings = torch.tensor(states, dtype=torch.float32)

        # Normalize states using dataset stats (same as training processor)
        if dataset_stats is not None and state_key in dataset_stats:
            mean = torch.tensor(dataset_stats[state_key]["mean"], dtype=torch.float32)
            std = torch.tensor(dataset_stats[state_key]["std"], dtype=torch.float32)
            state_embeddings = (state_embeddings - mean) / (std + 1e-8)
            logger.info(f"✓ Applied MEAN_STD normalization to states using {state_key}")
        elif dataset_stats is None:
            logger.warning("⚠ No dataset_stats provided - states not normalized (may differ from training)")
        else:
            logger.warning(
                f"⚠ dataset_stats has no entry for '{state_key}' - states not normalized (may differ from training)"
            )

    # Build all slice indices at once: row t holds the clamped indices t + delta.
    deltas = torch.tensor(list(model.config.observation_delta_indices), dtype=torch.long)
    slice_indices = torch.arange(num_frames, dtype=torch.long).unsqueeze(1) + deltas.unsqueeze(0)
    slice_indices = slice_indices.clamp(min=0, max=num_frames - 1)

    # (num_frames, slice_len, embed_dim)
    video_slices = video_embeddings[slice_indices.to(video_embeddings.device)]

    state_slices = None
    if state_embeddings is not None:
        # (num_frames, slice_len, state_dim)
        state_slices = state_embeddings[slice_indices.to(state_embeddings.device)]
        # Pad states to max_state_dim (same as training processor)
        state_slices = pad_state_to_max_dim(state_slices, model.config.max_state_dim)

    logger.info("Running SARM inference on all slices...")
    text_on_device = text_embedding.to(model.device)
    progress_chunks = []
    stage_chunks = []

    for i in tqdm(range(0, num_frames, batch_size), desc="Inference"):
        batch_video = video_slices[i:i + batch_size].to(model.device)
        batch_states = state_slices[i:i + batch_size].to(model.device) if state_slices is not None else None

        # Replicate text embedding for batch (the last batch may be smaller)
        batch_text = text_on_device.unsqueeze(0).repeat(batch_video.shape[0], 1)

        # Get predictions
        stage_logits, stage_probs, progress_preds = model.sarm_transformer(
            batch_video, batch_text, batch_states
        )

        # Keep only the prediction made for the current frame of each slice
        progress_chunks.append(progress_preds[:, current_frame_idx, 0].cpu().numpy())
        stage_chunks.append(stage_probs[:, current_frame_idx, :].cpu().numpy())

    return np.concatenate(progress_chunks), np.concatenate(stage_chunks)
|
||||
|
||||
|
||||
def compute_ground_truth_progress(
    dataset: LeRobotDataset,
    episode_index: int,
    temporal_proportions: dict[str, float],
    subtask_names_ordered: list[str],
) -> tuple[np.ndarray, np.ndarray] | tuple[None, None]:
    """
    Compute ground truth progress and stage labels for an episode using annotations.

    Uses SARM Paper Formula (2):
        y_t = P_{k-1} + ᾱ_k × τ_t

    where:
        - τ_t = (t - s_k) / (e_k - s_k) is within-subtask progress
        - P_{k-1} is cumulative prior (sum of previous subtask proportions)
        - ᾱ_k is the temporal proportion for subtask k

    Frames that fall outside every annotated subtask are handled as follows:
    before the first subtask -> progress 0.0, stage 0; after the last subtask ->
    progress 1.0, last stage; in a gap between two subtasks -> the preceding
    subtask evaluated at τ = 1.0.

    Args:
        dataset: LeRobotDataset instance
        episode_index: Index of the episode
        temporal_proportions: Dict mapping subtask name to proportion
        subtask_names_ordered: Ordered list of subtask names (for consistent stage indexing)

    Returns:
        Tuple of (ground_truth_progress, ground_truth_stages) arrays, or (None, None)
        if the episode has no annotations.
    """
    # Load episode metadata
    episodes_df = dataset.meta.episodes.to_pandas()

    # Check if annotations exist
    if "subtask_names" not in episodes_df.columns:
        logger.warning("No subtask_names column found in episodes metadata")
        return None, None

    ep_subtask_names = episodes_df.loc[episode_index, "subtask_names"]
    if ep_subtask_names is None or (isinstance(ep_subtask_names, float) and pd.isna(ep_subtask_names)):
        logger.warning(f"No annotations found for episode {episode_index}")
        return None, None

    subtask_start_frames = episodes_df.loc[episode_index, "subtask_start_frames"]
    subtask_end_frames = episodes_df.loc[episode_index, "subtask_end_frames"]

    # An empty annotation list carries no information; treat it like a missing one
    # instead of failing later on subtask_start_frames[0].
    segments = list(zip(ep_subtask_names, subtask_start_frames, subtask_end_frames))
    if len(segments) == 0:
        logger.warning(f"No annotations found for episode {episode_index}")
        return None, None

    # Get episode boundaries
    ep_start = dataset.meta.episodes["dataset_from_index"][episode_index]
    ep_end = dataset.meta.episodes["dataset_to_index"][episode_index]
    num_frames = ep_end - ep_start

    # Get temporal proportions as ordered list
    temporal_proportions_list = [
        temporal_proportions.get(name, 0.0) for name in subtask_names_ordered
    ]

    # Global stage index of each subtask name (first occurrence, like list.index)
    stage_index_of: dict[str, int] = {}
    for position, ordered_name in enumerate(subtask_names_ordered):
        stage_index_of.setdefault(ordered_name, position)

    logger.info(f"Computing ground truth for {num_frames} frames using {len(ep_subtask_names)} annotated subtasks")
    logger.info(f"Subtask names in episode: {ep_subtask_names}")
    logger.info(f"Subtask start frames: {subtask_start_frames}")
    logger.info(f"Subtask end frames: {subtask_end_frames}")
    logger.info(f"Temporal proportions (ordered): {dict(zip(subtask_names_ordered, temporal_proportions_list))}")

    # Compute ground truth for each frame
    gt_progress = np.zeros(num_frames)
    gt_stages = np.zeros(num_frames, dtype=np.int32)

    first_start = subtask_start_frames[0]
    last_end = subtask_end_frames[-1]

    for frame_rel in range(num_frames):
        # Find which subtask this frame belongs to (first match wins)
        segment = next(
            (seg for seg in segments if seg[1] <= frame_rel <= seg[2]),
            None,
        )

        if segment is not None:
            name, start_frame, end_frame = segment
            stage_idx = stage_index_of.get(name, 0)

            # τ_t and cumulative progress come from the shared utility functions
            tau = compute_tau(frame_rel, start_frame, end_frame)
            gt_progress[frame_rel] = compute_cumulative_progress_batch(
                tau, stage_idx, temporal_proportions_list
            )
            gt_stages[frame_rel] = stage_idx
            continue

        # Handle frames outside annotated subtasks
        if frame_rel < first_start:
            gt_progress[frame_rel] = 0.0
            gt_stages[frame_rel] = 0
        elif frame_rel > last_end:
            gt_progress[frame_rel] = 1.0
            gt_stages[frame_rel] = len(subtask_names_ordered) - 1
        else:
            # Between subtasks - the previous subtask counts as fully completed
            for j in range(len(segments) - 1):
                if subtask_end_frames[j] < frame_rel < subtask_start_frames[j + 1]:
                    stage_idx = stage_index_of.get(ep_subtask_names[j], j)
                    gt_progress[frame_rel] = compute_cumulative_progress_batch(
                        1.0, stage_idx, temporal_proportions_list
                    )
                    gt_stages[frame_rel] = stage_idx
                    break

    # gt_progress[-1] is undefined for a zero-length episode
    if len(gt_progress) > 0:
        logger.info(f"✓ Ground truth computed: final={gt_progress[-1]:.3f}, max={gt_progress.max():.3f}")
    return gt_progress, gt_stages
|
||||
|
||||
|
||||
def visualize_predictions(
    frames: np.ndarray,
    progress_predictions: np.ndarray,
    stage_predictions: np.ndarray,
    task_description: str,
    output_path: Path,
    num_sample_frames: int = 8,
    figsize: tuple = (14, 8),
    subtask_names: list[str] | None = None,
    temporal_proportions: dict[str, float] | None = None,
    ground_truth_progress: np.ndarray | None = None,
    ground_truth_stages: np.ndarray | None = None,
):
    """
    Create visualization of SARM predictions with optional ground truth comparison.

    The figure has three stacked panels: predicted (and optionally ground truth)
    progress over time, a stacked-area plot of stage probabilities, and a strip
    of evenly spaced sample frames labelled with their predictions. The figure
    is saved to ``output_path`` and always closed, even if plotting fails.

    Args:
        frames: Video frames (num_frames, H, W, C)
        progress_predictions: Progress predictions (num_frames,)
        stage_predictions: Stage probabilities (num_frames, num_stages)
        task_description: Task description
        output_path: Path to save the figure
        num_sample_frames: Number of frames to show
        figsize: Figure size (width, height)
        subtask_names: Optional list of subtask names for labeling
        temporal_proportions: Optional dict of temporal proportions for each subtask
        ground_truth_progress: Optional ground truth progress array (num_frames,)
        ground_truth_stages: Optional ground truth stage indices array (num_frames,)
    """
    num_stages = stage_predictions.shape[1]
    stage_colors = plt.cm.tab10(np.linspace(0, 1, num_stages))

    # Use subtask names if available, otherwise use generic labels
    if subtask_names is not None and len(subtask_names) == num_stages:
        stage_labels = subtask_names
    else:
        stage_labels = [f'Stage {i+1}' for i in range(num_stages)]

    # Create figure with progress plot, stage plot, and sample frames
    fig = plt.figure(figsize=(figsize[0], figsize[1] + 4))
    try:
        gs = gridspec.GridSpec(3, 1, height_ratios=[2, 1, 1], hspace=0.3)

        ax_progress = fig.add_subplot(gs[0])
        ax_stages = fig.add_subplot(gs[1], sharex=ax_progress)
        ax_frames = fig.add_subplot(gs[2])

        frame_indices = np.arange(len(progress_predictions))

        # Plot 1: Progress over time
        ax_progress.plot(frame_indices, progress_predictions, linewidth=2, color='#2E86AB', label='Predicted Progress')
        ax_progress.fill_between(frame_indices, 0, progress_predictions, alpha=0.3, color='#2E86AB')

        # Plot ground truth if available
        if ground_truth_progress is not None:
            ax_progress.plot(frame_indices, ground_truth_progress, linewidth=2, color='#28A745',
                             linestyle='--', label='Ground Truth Progress')
            ax_progress.fill_between(frame_indices, 0, ground_truth_progress, alpha=0.15, color='#28A745')

        ax_progress.axhline(y=1.0, color='gray', linestyle='--', alpha=0.5, linewidth=1)
        ax_progress.set_ylabel('Task Progress', fontsize=12)
        ax_progress.set_title(f'Task: "{task_description}"', fontsize=14, fontweight='bold')
        ax_progress.grid(True, alpha=0.3)
        ax_progress.set_ylim(-0.05, 1.1)
        ax_progress.legend(loc='upper left')

        # Add statistics box
        stats_text = (
            f'Frames: {len(progress_predictions)}\n'
            f'Final Progress: {progress_predictions[-1]:.3f}\n'
            f'Max Progress: {progress_predictions.max():.3f}\n'
            f'Mean Progress: {progress_predictions.mean():.3f}'
        )
        if ground_truth_progress is not None:
            mse = np.mean((progress_predictions - ground_truth_progress) ** 2)
            stats_text += f'\nMSE vs GT: {mse:.4f}'
            stats_text += f'\nGT Final: {ground_truth_progress[-1]:.3f}'

        ax_progress.text(0.98, 0.02, stats_text, transform=ax_progress.transAxes,
                         fontsize=10, verticalalignment='bottom', horizontalalignment='right',
                         bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

        # Plot 2: Stage predictions (stacked area plot)
        ax_stages.stackplot(frame_indices, *[stage_predictions[:, i] for i in range(num_stages)],
                            colors=stage_colors, alpha=0.8, labels=stage_labels)

        # Plot ground truth stage as vertical bands or markers
        if ground_truth_stages is not None:
            # Find stage transition points in ground truth
            stage_changes = np.where(np.diff(ground_truth_stages) != 0)[0] + 1
            for change_idx in stage_changes:
                ax_stages.axvline(x=change_idx, color='black', linestyle='-', alpha=0.7, linewidth=1.5)
                ax_progress.axvline(x=change_idx, color='black', linestyle='-', alpha=0.3, linewidth=1)

            # Add small markers at bottom showing GT stage (every 30th frame)
            marker_frames = frame_indices[::30]
            ax_stages.scatter(marker_frames, np.zeros(len(marker_frames)) + 0.02,
                              c=[stage_colors[s] for s in ground_truth_stages[::30]],
                              s=20, marker='|', alpha=0.8, label='GT Stage Markers')

        ax_stages.set_xlabel('Frame Index', fontsize=12)
        ax_stages.set_ylabel('Stage Probability', fontsize=12)
        ax_stages.set_ylim(0, 1)
        ax_stages.grid(True, alpha=0.3)

        # Adjust legend based on number of stages and label lengths
        if num_stages <= 5:
            ax_stages.legend(loc='upper left', ncol=num_stages, fontsize=8)
        else:
            ax_stages.legend(loc='upper left', ncol=3, fontsize=7)

        # Add vertical lines for expected stage transitions (if temporal proportions available)
        if temporal_proportions is not None and subtask_names is not None:
            cumulative_progress = 0.0
            for name in stage_labels:
                if name in temporal_proportions:
                    # Progress value at which this stage is expected to end
                    stage_end_progress = cumulative_progress + temporal_proportions[name]

                    # Frame whose predicted progress is closest to that value
                    stage_end_frame = np.argmin(np.abs(progress_predictions - stage_end_progress))

                    # Draw vertical line
                    ax_progress.axvline(x=stage_end_frame, color='gray', linestyle=':', alpha=0.5, linewidth=1)
                    ax_stages.axvline(x=stage_end_frame, color='gray', linestyle=':', alpha=0.5, linewidth=1)

                    cumulative_progress = stage_end_progress

        # Plot 3: Sample frames
        frame_indices_to_show = np.linspace(0, len(frames) - 1, num_sample_frames, dtype=int)

        ax_frames.axis('off')

        # Sample frames are laid side by side in one wide RGB image
        frame_height = frames[0].shape[0]
        frame_width = frames[0].shape[1]

        combined_width = frame_width * num_sample_frames
        combined_image = np.zeros((frame_height, combined_width, 3), dtype=np.uint8)

        for i, frame_idx in enumerate(frame_indices_to_show):
            frame = frames[frame_idx]
            if frame.shape[-1] == 1:
                # Expand single-channel frames so they fit the RGB canvas
                frame = np.repeat(frame, 3, axis=-1)

            # Add frame to combined image
            x_start = i * frame_width
            x_end = (i + 1) * frame_width
            combined_image[:, x_start:x_end] = frame

            # Add frame number, progress, and stage
            progress_val = progress_predictions[frame_idx]
            stage_idx = np.argmax(stage_predictions[frame_idx])
            stage_name = stage_labels[stage_idx] if stage_idx < len(stage_labels) else f'{stage_idx+1}'

            # Truncate long stage names for display
            if len(stage_name) > 15:
                stage_name = stage_name[:12] + '...'

            label = f'Frame {frame_idx}\nProg: {progress_val:.2f}\n{stage_name}'

            # Draw label on image
            ax_frames.text(x_start + frame_width / 2, -10, label,
                           ha='center', va='top', fontsize=7,
                           bbox=dict(boxstyle='round', facecolor='white', alpha=0.7))

        ax_frames.imshow(combined_image)
        ax_frames.set_title('Sample Frames', fontsize=12, pad=20)

        fig.tight_layout()
        output_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(output_path, dpi=150, bbox_inches='tight')
        logger.info(f"Saved visualization to {output_path}")
    finally:
        # Release the figure even when plotting or saving fails
        plt.close(fig)
|
||||
|
||||
|
||||
def main():
    """
    Command-line entry point: run SARM inference on one episode and save the results.

    Loads the model and dataset named in the parsed arguments, runs inference on
    the selected episode, computes ground truth from annotations when subtask
    names and temporal proportions are available, then writes a PNG
    visualization and an ``.npz`` file with the raw predictions to the output
    directory.

    Raises:
        ValueError: If the requested episode index is outside the dataset.
    """
    args = parse_args()

    # Setup device
    if args.device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    else:
        device = args.device
    logger.info(f"Using device: {device}")

    # Load model
    logger.info(f"Loading SARM model from {args.model_id}...")
    model = SARMRewardModel.from_pretrained(args.model_id)
    model.to(device)
    model.eval()
    logger.info("Model loaded successfully")

    # Load dataset
    logger.info(f"Loading dataset {args.dataset_repo}...")
    dataset = LeRobotDataset(args.dataset_repo)
    num_episodes = len(dataset.meta.episodes)
    logger.info(f"Dataset loaded: {num_episodes} episodes, {len(dataset)} frames")

    # Validate episode index (negative values are rejected as well)
    if not 0 <= args.episode_index < num_episodes:
        raise ValueError(
            f"Episode index {args.episode_index} out of range. "
            f"Dataset has {num_episodes} episodes."
        )

    image_key = args.image_key if args.image_key is not None else model.config.image_key
    state_key = args.state_key if args.state_key is not None else model.config.state_key
    logger.info(f"Using image key: {image_key}")
    logger.info(f"Using state key: {state_key}")

    # Load dataset stats for state normalization (same as training)
    dataset_stats = load_stats(dataset.root)
    if dataset_stats:
        logger.info(f"✓ Loaded dataset stats from {dataset.root}")
    else:
        logger.warning("⚠ Could not load dataset stats - states will not be normalized")

    # Load episode data
    frames, states, start_idx, end_idx, dataset_task = load_episode_data(
        dataset, args.episode_index, image_key, state_key
    )

    # Use task description from dataset if available, otherwise use command-line argument
    task_description = dataset_task if dataset_task is not None else args.task_description
    logger.info(f"Using task description: '{task_description}'")

    # Run inference
    progress_predictions, stage_predictions = run_inference(
        model, frames, states, task_description,
        dataset_stats=dataset_stats, state_key=state_key
    )

    # Extract subtask names and temporal proportions from model config if available
    subtask_names = getattr(model.config, 'subtask_names', None)
    temporal_proportions = None

    if subtask_names is not None:
        logger.info(f"✓ Found {len(subtask_names)} subtask names in model config: {subtask_names}")

    # Proportions in the config are positional, so they can only be mapped to
    # names when the config also provides the subtask names.
    config_proportions = getattr(model.config, 'temporal_proportions', None)
    if config_proportions is not None and subtask_names is not None:
        temporal_proportions = dict(zip(subtask_names, config_proportions))
        logger.info(f"✓ Loaded temporal proportions from model config: {temporal_proportions}")

    # Fallback: try to load from dataset meta
    if temporal_proportions is None:
        proportions_path = dataset.root / "meta" / "temporal_proportions.json"
        if proportions_path.exists():
            with open(proportions_path, 'r', encoding='utf-8') as f:
                temporal_proportions = json.load(f)
            logger.info(f"✓ Loaded temporal proportions from dataset: {temporal_proportions}")

            # Also extract subtask names from proportions if not already set
            if subtask_names is None:
                subtask_names = sorted(temporal_proportions.keys())
                logger.info(f"✓ Extracted subtask names from proportions: {subtask_names}")

    # Compute ground truth progress if annotations are available
    ground_truth_progress = None
    ground_truth_stages = None

    if temporal_proportions is not None and subtask_names is not None:
        logger.info("Attempting to compute ground truth progress from annotations...")
        ground_truth_progress, ground_truth_stages = compute_ground_truth_progress(
            dataset,
            args.episode_index,
            temporal_proportions,
            subtask_names
        )
        if ground_truth_progress is None:
            logger.warning("⚠ Ground truth not available - annotations may be missing for this episode")
    else:
        logger.warning("⚠ Cannot compute ground truth - temporal_proportions or subtask_names not available")

    output_dir = Path(args.output_dir)
    output_path = output_dir / f"sarm_prediction_ep{args.episode_index}.png"

    visualize_predictions(
        frames,
        progress_predictions,
        stage_predictions,
        task_description,
        output_path,
        num_sample_frames=args.num_sample_frames,
        figsize=tuple(args.figsize),
        subtask_names=subtask_names,
        temporal_proportions=temporal_proportions,
        ground_truth_progress=ground_truth_progress,
        ground_truth_stages=ground_truth_stages,
    )

    # Persist raw predictions next to the figure for later analysis
    predictions_path = output_dir / f"predictions_ep{args.episode_index}.npz"
    save_dict = {
        'progress': progress_predictions,
        'stages': stage_predictions
    }
    if ground_truth_progress is not None:
        save_dict['gt_progress'] = ground_truth_progress
        save_dict['gt_stages'] = ground_truth_stages
    np.savez(predictions_path, **save_dict)
    logger.info(f"Saved predictions to {predictions_path}")
    logger.info(f"\nVisualization: {output_path}")


if __name__ == "__main__":
    main()
|
||||
|
||||
@@ -43,6 +43,11 @@ def make_cameras_from_configs(camera_configs: dict[str, CameraConfig]) -> dict[s
|
||||
|
||||
cameras[key] = Reachy2Camera(cfg)
|
||||
|
||||
elif cfg.type == "zmq":
|
||||
from .zmq import ZMQCamera
|
||||
|
||||
cameras[key] = ZMQCamera(cfg)
|
||||
|
||||
else:
|
||||
try:
|
||||
cameras[key] = cast(Camera, make_device_from_device_class(cfg))
|
||||
|
||||
16
src/lerobot/cameras/zmq/__init__.py
Normal file
16
src/lerobot/cameras/zmq/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .camera_zmq import ZMQCamera
|
||||
from .configuration_zmq import ZMQCameraConfig
|
||||
623
src/lerobot/cameras/zmq/camera_zmq.py
Normal file
623
src/lerobot/cameras/zmq/camera_zmq.py
Normal file
@@ -0,0 +1,623 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Provides the ZMQCamera class for capturing frames from remote cameras via ZeroMQ.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from threading import Event, Lock, Thread
|
||||
from typing import Any
|
||||
import base64
|
||||
import cv2
|
||||
import numpy as np
|
||||
import zmq
|
||||
from numpy.typing import NDArray
|
||||
import base64
|
||||
import msgpack
|
||||
import msgpack_numpy as m
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
|
||||
from ..camera import Camera
|
||||
from ..configs import ColorMode
|
||||
from .configuration_zmq import ZMQCameraConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ZMQCamera(Camera):
|
||||
"""
|
||||
Manages camera interactions using ZeroMQ for remote frame streaming.
|
||||
|
||||
This class provides a high-level interface to connect to remote cameras
|
||||
that stream JPEG-encoded images over ZeroMQ PUB/SUB sockets. It supports
|
||||
both synchronous and asynchronous frame reading.
|
||||
|
||||
The camera server must be running and publishing JPEG images on the specified
|
||||
address and port. Use the provided utility script to find available ZMQ cameras:
|
||||
```bash
|
||||
lerobot-find-cameras zmq
|
||||
```
|
||||
|
||||
Example:
|
||||
```python
|
||||
from lerobot.cameras.zmq import ZMQCamera
|
||||
from lerobot.cameras.zmq.configuration_zmq import ZMQCameraConfig, ColorMode
|
||||
|
||||
# Basic usage
|
||||
config = ZMQCameraConfig(
|
||||
server_address="192.168.123.164",
|
||||
port=5554,
|
||||
camera_name="remote_cam"
|
||||
)
|
||||
camera = ZMQCamera(config)
|
||||
camera.connect()
|
||||
|
||||
# Read 1 frame synchronously
|
||||
color_image = camera.read()
|
||||
print(color_image.shape)
|
||||
|
||||
# Read 1 frame asynchronously
|
||||
async_image = camera.async_read()
|
||||
|
||||
# When done, properly disconnect the camera
|
||||
camera.disconnect()
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, config: ZMQCameraConfig):
    """
    Set up a ZMQCamera from its configuration without opening any connection.

    Args:
        config: The configuration settings for the ZMQ camera.
    """
    super().__init__(config)

    # Connection settings copied from the configuration
    self.config = config
    self.camera_name = config.camera_name
    self.server_address = config.server_address
    self.port = config.port
    self.timeout_ms = config.timeout_ms
    self.color_mode = config.color_mode

    # ZeroMQ handles; both stay None until a connection is established
    self._connected = False
    self.context: zmq.Context | None = None
    self.socket: zmq.Socket | None = None

    # Background-reader state: worker thread, its stop signal, and the
    # lock/event pair guarding the most recent frame
    self.thread: Thread | None = None
    self.stop_event: Event | None = None
    self.frame_lock: Lock = Lock()
    self.new_frame_event: Event = Event()
    self.latest_frame: NDArray[Any] | None = None

    # Wire format detected during connection (msgpack, json, or raw_jpeg)
    self._format_type: str | None = None
|
||||
|
||||
def __str__(self) -> str:
    """Return ``ClassName(camera_name@server_address:port)`` for logs and error messages."""
    endpoint = f"{self.server_address}:{self.port}"
    return f"{self.__class__.__name__}({self.camera_name}@{endpoint})"
|
||||
|
||||
@property
def is_connected(self) -> bool:
    """Report whether the connected flag is set and both ZMQ handles exist."""
    handles_ready = self.context is not None and self.socket is not None
    return self._connected and handles_ready
|
||||
|
||||
def connect(self, warmup: bool = True) -> None:
    """
    Connects to the ZMQ camera server and configures settings.

    Opens a SUB socket subscribed to every topic, then reads one frame to
    validate the connection, trying the msgpack, json and raw_jpeg formats in
    that order and remembering the first one that reads without error. If no
    resolution was configured, it is taken from that first frame.

    On any failure the socket and context are released and the instance is
    left disconnected.

    Args:
        warmup: If True (default), waits briefly (0.1 s) after the first frame
            before returning.

    Raises:
        DeviceAlreadyConnectedError: If the camera is already connected.
        RuntimeError: If connection to the ZMQ server fails or no initial
            frame can be decoded.
    """
    if self.is_connected:
        raise DeviceAlreadyConnectedError(f"{self} is already connected.")

    logger.info(f"Connecting to {self}...")

    try:
        self.context = zmq.Context()
        self.socket = self.context.socket(zmq.SUB)
        self.socket.connect(f"tcp://{self.server_address}:{self.port}")
        self.socket.setsockopt_string(zmq.SUBSCRIBE, "")

        # Set receive timeout
        self.socket.setsockopt(zmq.RCVTIMEO, self.timeout_ms)

        # Flag is set before the probe read below, as in the original flow
        self._connected = True

        # Try to receive one frame to validate connection and detect format
        try:
            test_frame = None
            for format_type in ["msgpack", "json", "raw_jpeg"]:
                try:
                    test_frame = self.read(format=format_type)
                except Exception as e:
                    logger.debug(f"{self} format '{format_type}' failed: {e}")
                    continue
                self._format_type = format_type
                logger.info(f"{self} detected format: {format_type}")
                break

            if test_frame is None:
                raise RuntimeError("Failed to decode frame with any supported format (msgpack, json, raw_jpeg)")

            # Auto-detect resolution if not specified
            if self.width is None or self.height is None:
                h, w = test_frame.shape[:2]
                self.height = h
                self.width = w
                logger.info(f"{self} auto-detected resolution: {w}x{h}")

            logger.info(f"{self} connected successfully.")

            if warmup:
                logger.debug(f"Warming up {self}...")
                time.sleep(0.1)  # Brief warmup period

        except Exception as e:
            # Cleanup happens once, in the outer handler
            raise RuntimeError(f"Failed to receive initial frame from {self}: {e}") from e

    except Exception as e:
        # Single cleanup path for every failure mode
        self._connected = False
        self._format_type = None
        if self.socket is not None:
            # linger=0 so that context.term() below cannot block on this socket
            self.socket.close(linger=0)
        if self.context is not None:
            self.context.term()
        self.socket = None
        self.context = None
        raise RuntimeError(f"Failed to connect to {self}: {e}") from e
|
||||
|
||||
@staticmethod
def find_cameras(
    subnet: str | None = None,
    ports: list[int] | None = None,
    timeout_ms: int = 200,
) -> list[dict[str, Any]]:
    """
    Scans the local network for ZMQ cameras (fast parallel scan).

    Every host/port pair is probed in a worker thread: a SUB socket is connected,
    one message is awaited and the payload is decoded as msgpack, JSON or raw JPEG
    (tried in that order) to confirm that it really carries an image. 127.0.0.1 is
    always probed in addition to the hosts of the subnet.

    Args:
        subnet: Network subnet to scan (e.g., "192.168.1.0/24"). If None, the /24
            network of the local outbound interface is used.
        ports: List of ports to scan. Defaults to [5554, 5555, 5556].
        timeout_ms: Receive timeout per attempt in milliseconds (each target gets up
            to three attempts). Default: 200ms.

    Returns:
        List of dicts, one per detected camera, with the keys "name", "type", "id",
        "server_address", "port", "camera_name", "format" and
        "default_stream_profile". The list is empty when the subnet cannot be
        detected or parsed, or when nothing answers.

    Example:
        >>> cameras = ZMQCamera.find_cameras()
        >>> # Or specify: cameras = ZMQCamera.find_cameras(subnet="10.0.0.0/24", ports=[5554])
    """
    import ipaddress
    import socket
    from concurrent.futures import ThreadPoolExecutor, as_completed

    if ports is None:
        ports = [5554, 5555, 5556]

    # Auto-detect local subnet
    if subnet is None:
        try:
            # Connecting a UDP socket sends no packet; it only makes the OS pick
            # the outbound interface, whose address is then read back.
            with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as probe:
                probe.connect(("8.8.8.8", 80))
                local_ip = probe.getsockname()[0]
        except OSError as e:
            logger.error(f"Failed to auto-detect subnet: {e}")
            return []
        subnet = ".".join(local_ip.split(".")[:-1]) + ".0/24"
        logger.info(f"Auto-detected subnet: {subnet}")

    # Parse subnet
    try:
        network = ipaddress.ip_network(subnet, strict=False)
    except (ValueError, TypeError) as e:
        logger.error(f"Invalid subnet '{subnet}': {e}")
        return []

    # Always include localhost (for MuJoCo sim, local servers), but only once even
    # if the scanned subnet already contains it.
    localhost = ipaddress.IPv4Address("127.0.0.1")
    hosts = [localhost] + [host for host in network.hosts() if host != localhost]

    total = len(hosts) * len(ports)
    logger.info(f"Scanning {len(hosts)} hosts × {len(ports)} ports = {total} targets (this takes ~10-15s)...")

    def decode_jpeg(buf: bytes) -> Any:
        """Decode an encoded image buffer; cv2 returns None when it is not an image."""
        return cv2.imdecode(np.frombuffer(buf, np.uint8), cv2.IMREAD_COLOR)

    def from_msgpack(msg: bytes) -> Any:
        """Extract the first image of a msgpack payload shaped like {"images": {...}}."""
        payload = msgpack.unpackb(msg, object_hook=m.decode)
        if not (isinstance(payload, dict) and "images" in payload and len(payload["images"]) > 0):
            return None
        img = next(iter(payload["images"].values()))
        if isinstance(img, str):
            return decode_jpeg(base64.b64decode(img))
        if isinstance(img, np.ndarray):
            return img
        return None

    def from_json(msg: bytes) -> Any:
        """Return the first JSON value that decodes as a base64-encoded image."""
        payload = json.loads(msg.decode("utf-8"))
        if not isinstance(payload, dict):
            return None
        for value in payload.values():
            # Short strings cannot hold an encoded image; skip them cheaply.
            if not (isinstance(value, str) and len(value) > 100):
                continue
            try:
                frame = decode_jpeg(base64.b64decode(value))
            except Exception:
                # Not base64 / not an image: keep looking at the other values.
                continue
            if frame is not None:
                return frame
        return None

    # Try formats: msgpack → json → raw_jpeg
    decoders = (("msgpack", from_msgpack), ("json", from_json), ("raw_jpeg", decode_jpeg))

    def identify(msg: bytes) -> tuple[str, Any] | None:
        """Return (format, frame) for the first decoder that yields an image."""
        for fmt, decoder in decoders:
            try:
                frame = decoder(msg)
            except Exception:
                # A payload in another format is expected to fail here.
                continue
            if isinstance(frame, np.ndarray) and frame.ndim >= 2:
                return fmt, frame
        return None

    def probe_target(host_ip: str, port: int) -> dict[str, Any] | None:
        """Test one host:port for ZMQ camera."""
        ctx = zmq.Context()
        sock = None
        msg = None
        try:
            sock = ctx.socket(zmq.SUB)
            sock.connect(f"tcp://{host_ip}:{port}")
            sock.setsockopt_string(zmq.SUBSCRIBE, "")
            sock.setsockopt(zmq.RCVTIMEO, timeout_ms)

            # Wait for subscription to establish (ZMQ "slow joiner" problem)
            time.sleep(0.1)

            # Try receiving a few times
            for _ in range(3):
                try:
                    msg = sock.recv()
                    break
                except zmq.Again:
                    time.sleep(0.05)
        except zmq.ZMQError as e:
            logger.debug(f"ZMQ probe of {host_ip}:{port} failed: {e}")
            return None
        finally:
            # Release the socket and context on every path; linger=0 keeps
            # ctx.term() from waiting on undelivered messages.
            if sock is not None:
                sock.close(linger=0)
            ctx.term()

        if msg is None:
            return None

        detected = identify(msg)
        if detected is None:
            return None

        fmt, frame = detected
        h, w = frame.shape[:2]
        return {
            "name": f"ZMQ @ {host_ip}:{port}",
            "type": "ZMQ",
            "id": f"{host_ip}:{port}",
            "server_address": host_ip,
            "port": port,
            "camera_name": f"cam_{host_ip.replace('.', '_')}_{port}",
            "format": fmt,
            "default_stream_profile": {"width": w, "height": h, "format": fmt.upper()},
        }

    # Parallel scan with thread pool
    found: list[dict[str, Any]] = []
    with ThreadPoolExecutor(max_workers=100) as ex:
        futures = [ex.submit(probe_target, str(h), p) for h in hosts for p in ports]
        for i, fut in enumerate(as_completed(futures), 1):
            if i % 100 == 0:
                logger.info(f"  Progress: {i}/{total} ({100*i//total}%)")
            try:
                res = fut.result()
            except Exception as e:
                # One misbehaving target must not abort the whole scan.
                logger.debug(f"ZMQ probe raised unexpectedly: {e}")
                continue
            if res:
                found.append(res)
                logger.info(f"  ✓ {res['server_address']}:{res['port']} ({res['format']})")

    logger.info(f"Scan complete! Found {len(found)} camera(s).")
    return found
|
||||
|
||||
def read(self, color_mode: ColorMode | None = None, format: str | None = None) -> NDArray[Any]:
    """
    Reads a single frame synchronously from the ZMQ camera.

    Supports three message formats:
    1. "msgpack": Msgpack with base64 JPEGs: {"timestamps": {...}, "images": {camera_name: "b64"}}
       (used by MuJoCo sim). The entry named `self.camera_name` is preferred; otherwise
       the first image in the message is used. Images may also arrive as numpy arrays.
    2. "json": JSON with base64 JPEGs: {"state": 0.0, "camera_name": "b64jpeg"}
       (used by LeKiwi-style servers). The key must equal `self.camera_name`.
    3. "raw_jpeg": Raw JPEG bytes (used by Unitree G1 head camera)

    Args:
        color_mode: Target color mode (RGB or BGR). If None, uses self.color_mode.
        format: Message format to use. If None, uses auto-detected format from connect().
            One of: "msgpack", "json", "raw_jpeg". (The name shadows the `format`
            builtin but is part of the public signature.)

    Returns:
        np.ndarray: Decoded frame in shape (height, width, 3)

    Raises:
        DeviceNotConnectedError: If camera is not connected
        TimeoutError: If no frame received within timeout_ms
        ValueError: If `format` is not one of the supported formats
        RuntimeError: If receiving fails or the frame cannot be decoded
    """
    if not self.is_connected:
        raise DeviceNotConnectedError(f"{self} is not connected.")

    if self.socket is None:
        raise DeviceNotConnectedError(f"{self} socket is not initialized")

    # Use detected format if not specified
    if format is None:
        format = self._format_type

    if format is None:
        raise RuntimeError(f"{self} format not specified and not auto-detected during connect()")

    # Reject unknown formats before touching the socket, so a bad argument neither
    # blocks for a frame nor silently discards one.
    if format not in ("msgpack", "json", "raw_jpeg"):
        raise ValueError(f"{self} unsupported format: {format}. Use 'msgpack', 'json', or 'raw_jpeg'")

    start_time = time.perf_counter()

    try:
        message = self.socket.recv()
    except zmq.Again as e:
        raise TimeoutError(f"{self} timeout waiting for frame after {self.timeout_ms}ms") from e
    except Exception as e:
        raise RuntimeError(f"{self} read failed: {e}") from e

    frame = None

    # Decode based on format. Parsing errors from msgpack/json/base64/cv2 are
    # converted to RuntimeError so callers see the documented exception type.
    try:
        if format == "msgpack":
            data = msgpack.unpackb(message, object_hook=m.decode)
            if not isinstance(data, dict) or "images" not in data:
                raise RuntimeError(f"{self} invalid msgpack format: expected dict with 'images' key")

            images_dict = data["images"]

            # Prefer named camera if present
            if self.camera_name in images_dict:
                img_data = images_dict[self.camera_name]
            elif len(images_dict) > 0:
                # Fallback: first available camera
                img_data = next(iter(images_dict.values()))
            else:
                raise RuntimeError(f"{self} no images found in msgpack message")

            # Decode the image data
            if isinstance(img_data, str):
                color_bytes = base64.b64decode(img_data)
                np_img = np.frombuffer(color_bytes, dtype=np.uint8)
                frame = cv2.imdecode(np_img, cv2.IMREAD_COLOR)
            elif isinstance(img_data, np.ndarray):
                frame = img_data
            else:
                raise RuntimeError(f"{self} unknown image payload type: {type(img_data)}")

        elif format == "json":
            data = json.loads(message.decode("utf-8"))
            if not isinstance(data, dict) or self.camera_name not in data:
                raise RuntimeError(f"{self} invalid JSON format: expected dict with '{self.camera_name}' key")

            img_b64 = data[self.camera_name]
            if not isinstance(img_b64, str):
                raise RuntimeError(f"{self} expected base64 string in JSON, got {type(img_b64)}")

            color_bytes = base64.b64decode(img_b64)
            np_img = np.frombuffer(color_bytes, dtype=np.uint8)
            frame = cv2.imdecode(np_img, cv2.IMREAD_COLOR)

        else:  # "raw_jpeg"
            np_img = np.frombuffer(message, dtype=np.uint8)
            frame = cv2.imdecode(np_img, cv2.IMREAD_COLOR)
    except RuntimeError:
        raise
    except Exception as e:
        raise RuntimeError(f"{self} failed to decode image using format '{format}': {e}") from e

    # cv2.imdecode signals failure by returning None rather than raising.
    if frame is None or not isinstance(frame, np.ndarray):
        raise RuntimeError(f"{self} failed to decode image using format '{format}'")

    processed_frame = self._postprocess_image(frame, color_mode)

    read_duration_ms = (time.perf_counter() - start_time) * 1e3
    logger.debug(f"{self} read took: {read_duration_ms:.1f}ms")

    return processed_frame
|
||||
|
||||
def _postprocess_image(self, image: NDArray[Any], color_mode: ColorMode | None = None) -> NDArray[Any]:
    """
    Applies color conversion to a raw frame.

    Args:
        image: The raw image frame (BGR format from cv2.imdecode).
        color_mode: The target color mode (RGB or BGR). If None, uses self.color_mode.

    Returns:
        np.ndarray: The frame unchanged for BGR, or converted with
            cv2.COLOR_BGR2RGB for RGB.

    Raises:
        ValueError: If the requested color_mode is invalid.
        RuntimeError: If the frame is not a 3-dimensional array with exactly
            3 channels. A width/height that differs from the configured one only
            logs a warning.
    """
    requested_color_mode = self.color_mode if color_mode is None else color_mode

    if requested_color_mode not in (ColorMode.RGB, ColorMode.BGR):
        raise ValueError(
            f"Invalid color mode '{requested_color_mode}'. Expected {ColorMode.RGB} or {ColorMode.BGR}."
        )

    # Frames that do not come from cv2.imdecode(IMREAD_COLOR) (e.g. arrays embedded
    # in a msgpack message) are not guaranteed to be 3-D; fail with the documented
    # exception instead of an opaque unpacking error.
    if image.ndim != 3:
        raise RuntimeError(
            f"{self} frame has shape {image.shape}; expected a (height, width, 3) array."
        )

    h, w, c = image.shape

    # Validate dimensions if they were specified
    if self.height is not None and self.width is not None:
        if h != self.height or w != self.width:
            logger.warning(
                f"{self} frame dimensions ({w}x{h}) don't match configured ({self.width}x{self.height}). "
                "This might be expected if the server sends different resolutions."
            )

    if c != 3:
        raise RuntimeError(f"{self} frame channels={c} do not match expected 3 channels (RGB/BGR).")

    processed_image = image
    if requested_color_mode == ColorMode.RGB:
        processed_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    return processed_image
|
||||
|
||||
def _read_loop(self) -> None:
    """
    Internal loop run by the background thread for asynchronous reading.

    On each iteration:
    1. Reads a frame from ZMQ
    2. Stores result in latest_frame (thread-safe)
    3. Sets new_frame_event to notify listeners

    Stops when the stop event that was installed at start-up is set, or on
    DeviceNotConnectedError. Timeouts and other errors are logged and the loop
    continues.

    Raises:
        RuntimeError: If stop_event has not been initialized before the loop starts.
    """
    # Bind the event once: the attribute may be reset to None or replaced by
    # another thread while this loop is still running, and this loop must keep
    # watching the event it was started with.
    stop_event = self.stop_event
    if stop_event is None:
        raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")

    while not stop_event.is_set():
        try:
            frame = self.read()

            with self.frame_lock:
                self.latest_frame = frame
            self.new_frame_event.set()

        except DeviceNotConnectedError:
            break
        except TimeoutError:
            # Timeout is expected occasionally, just continue
            logger.debug(f"{self} read timeout in background thread")
        except Exception as e:
            logger.warning(f"Error reading frame in background thread for {self}: {e}")
|
||||
|
||||
def _start_read_thread(self) -> None:
    """Starts or restarts the background read thread.

    Any previous thread is first asked to stop and given a short grace period to
    exit; then a fresh stop event and a new daemon thread running `_read_loop`
    are installed.
    """
    # Signal before joining: joining a thread that has not been told to stop can
    # only time out.
    if self.stop_event is not None:
        self.stop_event.set()

    if self.thread is not None and self.thread.is_alive():
        self.thread.join(timeout=0.1)

    self.stop_event = Event()
    self.thread = Thread(target=self._read_loop, args=(), name=f"{self}_read_loop")
    # Daemon thread: it must never keep the interpreter alive on exit.
    self.thread.daemon = True
    self.thread.start()
|
||||
|
||||
def _stop_read_thread(self) -> None:
    """Asks the background read thread to finish, waits briefly, then forgets it.

    The stop event (if any) is set, a still-running thread is joined for at most
    two seconds, and both `thread` and `stop_event` are reset to None regardless
    of whether the join succeeded.
    """
    event = self.stop_event
    worker = self.thread

    if event is not None:
        event.set()

    if worker is not None and worker.is_alive():
        worker.join(timeout=2.0)

    self.thread = None
    self.stop_event = None
|
||||
|
||||
def async_read(self, timeout_ms: float = 10000) -> NDArray[Any]:
    """
    Returns the most recent frame produced by the background read thread.

    The background thread is started on demand if it is not running. The call
    never receives from ZMQ itself; it only waits, for at most `timeout_ms`,
    until the thread signals that a frame is available, then hands that frame
    out and clears the signal.

    Args:
        timeout_ms: Maximum time in milliseconds to wait for a frame to become
            available. Defaults to 10000ms.

    Returns:
        np.ndarray: The latest frame stored by the background thread.

    Raises:
        DeviceNotConnectedError: If the camera is not connected.
        TimeoutError: If no frame becomes available within the specified timeout.
        RuntimeError: If the frame signal was set but no frame was stored.
    """
    if not self.is_connected:
        raise DeviceNotConnectedError(f"{self} is not connected.")

    worker = self.thread
    if worker is None or not worker.is_alive():
        self._start_read_thread()

    frame_ready = self.new_frame_event.wait(timeout=timeout_ms / 1000.0)
    if not frame_ready:
        # Report thread liveness to help tell a dead reader from a silent server.
        thread_alive = self.thread is not None and self.thread.is_alive()
        raise TimeoutError(
            f"Timed out waiting for frame from {self} after {timeout_ms} ms. "
            f"Read thread alive: {thread_alive}."
        )

    # Take the frame and reset the signal under the same lock.
    with self.frame_lock:
        latest = self.latest_frame
        self.new_frame_event.clear()

    if latest is None:
        raise RuntimeError(f"Internal error: Event set but no frame available for {self}.")

    return latest
|
||||
|
||||
def disconnect(self) -> None:
    """
    Shuts the camera down and releases its ZMQ resources.

    The background read thread is stopped first (when one exists), then the
    socket is closed and the context terminated, and finally the camera is
    flagged as disconnected.

    Raises:
        DeviceNotConnectedError: If the camera is not connected and no read
            thread exists.
    """
    has_thread = self.thread is not None
    if not (self.is_connected or has_thread):
        raise DeviceNotConnectedError(f"{self} not connected.")

    if has_thread:
        self._stop_read_thread()

    sock = self.socket
    if sock is not None:
        sock.close()
        self.socket = None

    ctx = self.context
    if ctx is not None:
        ctx.term()
        self.context = None

    self._connected = False

    logger.info(f"{self} disconnected.")
|
||||
|
||||
78
src/lerobot/cameras/zmq/configuration_zmq.py
Normal file
78
src/lerobot/cameras/zmq/configuration_zmq.py
Normal file
@@ -0,0 +1,78 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from ..configs import CameraConfig, ColorMode
|
||||
|
||||
__all__ = ["ZMQCameraConfig", "ColorMode"]
|
||||
|
||||
|
||||
@CameraConfig.register_subclass("zmq")
@dataclass
class ZMQCameraConfig(CameraConfig):
    """Configuration class for ZMQ-based remote camera streams.

    This class provides configuration options for cameras accessed through ZeroMQ (ZMQ),
    supporting remote camera streams over the network. The server must be running and
    streaming JPEG-encoded images over a ZMQ PUB socket.

    Example configurations:
    ```python
    # Basic configuration
    ZMQCameraConfig(
        server_address="192.168.123.164",
        port=5554,
        camera_name="remote_cam_1"
    )

    # With custom resolution
    ZMQCameraConfig(
        server_address="10.0.0.100",
        port=5555,
        camera_name="lab_cam",
        width=1280,
        height=480,
        fps=30
    )
    ```

    Attributes:
        server_address: IP address or hostname of the ZMQ image server. Must not be empty.
        port: Port number where the ZMQ server is publishing images (1-65535).
            Defaults to 5554.
        camera_name: Identifier name for this camera. Defaults to "zmq_camera".
        color_mode: Color mode for image output (RGB or BGR). Defaults to RGB.
        timeout_ms: Timeout in milliseconds for receiving frames. Must be positive.
            Defaults to 5000ms.
    """

    server_address: str
    port: int = 5554
    camera_name: str = "zmq_camera"
    color_mode: ColorMode = ColorMode.RGB
    timeout_ms: int = 5000

    def __post_init__(self) -> None:
        """Validate the field values.

        Raises:
            ValueError: If `color_mode` is neither RGB nor BGR, `timeout_ms` is not
                positive, `server_address` is empty, or `port` is outside 1-65535.
        """
        if self.color_mode not in (ColorMode.RGB, ColorMode.BGR):
            raise ValueError(
                f"`color_mode` is expected to be {ColorMode.RGB.value} or {ColorMode.BGR.value}, but {self.color_mode} is provided."
            )

        if self.timeout_ms <= 0:
            raise ValueError(f"`timeout_ms` must be positive, but {self.timeout_ms} is provided.")

        if not self.server_address:
            raise ValueError("`server_address` cannot be empty.")

        if self.port <= 0 or self.port > 65535:
            raise ValueError(f"`port` must be between 1 and 65535, but {self.port} is provided.")
|
||||
@@ -64,26 +64,9 @@ class TrainPipelineConfig(HubMixin):
|
||||
scheduler: LRSchedulerConfig | None = None
|
||||
eval: EvalConfig = field(default_factory=EvalConfig)
|
||||
wandb: WandBConfig = field(default_factory=WandBConfig)
|
||||
|
||||
# RA-BC (Reward-Aligned Behavior Cloning) parameters
|
||||
use_rabc: bool = False # Enable reward-weighted training
|
||||
reward_model_path: str | None = None # Path to pre-trained reward model (e.g., SARM)
|
||||
rabc_kappa: float = 0.01 # Hard threshold for high-quality samples
|
||||
rabc_epsilon: float = 1e-6 # Small constant for numerical stability
|
||||
rabc_update_freq: int = 1 # Compute rewards every N batches (1 = every batch)
|
||||
|
||||
# Rename map for the observation to override the image and state keys
|
||||
rename_map: dict[str, str] = field(default_factory=dict)
|
||||
checkpoint_path: Path | None = field(init=False, default=None)
|
||||
|
||||
|
||||
def validate(self):
|
||||
# Validate RA-BC configuration
|
||||
if self.use_rabc and not self.reward_model_path:
|
||||
raise ValueError(
|
||||
"RA-BC is enabled (use_rabc=True) but no reward_model_path provided. "
|
||||
"Please specify a pre-trained reward model (e.g., SARM) path."
|
||||
)
|
||||
# Rename map for the observation to override the image and state keys
|
||||
rename_map: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
def validate(self) -> None:
|
||||
# HACK: We parse again the cli args here to get the pretrained paths if there was some.
|
||||
|
||||
@@ -999,18 +999,10 @@ def _copy_data_with_feature_changes(
|
||||
df[feature_name] = feature_values
|
||||
else:
|
||||
feature_slice = values[frame_idx:end_idx]
|
||||
if len(feature_slice.shape) == 1:
|
||||
# 1D array - can assign directly
|
||||
df[feature_name] = feature_slice
|
||||
elif len(feature_slice.shape) == 2 and feature_slice.shape[1] == 1:
|
||||
# 2D array with single column - flatten it
|
||||
if len(feature_slice.shape) > 1 and feature_slice.shape[1] == 1:
|
||||
df[feature_name] = feature_slice.flatten()
|
||||
elif len(feature_slice.shape) == 2:
|
||||
# 2D array with multiple columns (e.g., embeddings) - convert to list of lists
|
||||
df[feature_name] = feature_slice.tolist()
|
||||
else:
|
||||
# Higher dimensional - convert to list
|
||||
df[feature_name] = [row.tolist() for row in feature_slice]
|
||||
df[feature_name] = feature_slice
|
||||
frame_idx = end_idx
|
||||
|
||||
# Write using the same chunk/file structure as source
|
||||
|
||||
@@ -1,146 +0,0 @@
|
||||
# LeRobot Embedding Generation Script
|
||||
|
||||
Generate embeddings for LeRobot datasets to make them more lightweight and efficient for training.
|
||||
|
||||
## Overview
|
||||
|
||||
This script processes v3.0 LeRobot datasets and adds pre-computed embeddings for:
|
||||
|
||||
- **Task embeddings**: Language command embeddings using MiniLM
|
||||
- **Image embeddings**: Frame embeddings using DinoV2
|
||||
|
||||
The resulting dataset can be used more efficiently during training by loading pre-computed embeddings instead of running encoders on-the-fly.
|
||||
|
||||
## Supported Encoders
|
||||
|
||||
### Image Encoders (DinoV2)
|
||||
|
||||
DinoV2 is a self-supervised vision transformer that produces high-quality image embeddings:
|
||||
|
||||
- **`dinov2_vits14`**: ViT-S/14 (384-dim) - Fastest, smaller model
|
||||
- **`dinov2_vitb14`**: ViT-B/14 (768-dim) - **Recommended** - Good balance
|
||||
- **`dinov2_vitl14`**: ViT-L/14 (1024-dim) - Best quality, slower
|
||||
|
||||
### Language Encoders (MiniLM)
|
||||
|
||||
MiniLM is a lightweight sentence transformer model:
|
||||
|
||||
- **`minilm-l6`**: MiniLM-L6-v2 (384-dim) - Faster
|
||||
- **`minilm-l12`**: MiniLM-L12-v2 (384-dim) - **Recommended** - Better quality
|
||||
|
||||
## Usage
|
||||
|
||||
### Basic Command
|
||||
|
||||
```bash
|
||||
python src/lerobot/datasets/generating_embeddings/generate_embeddings.py \
|
||||
--repo-id lerobot/utokyo_xarm_bimanual \
|
||||
--output-repo-id your-username/utokyo_xarm_bimanual_embeddings \
|
||||
--image-encoder dinov2_vitb14 \
|
||||
--language-encoder minilm-l12 \
|
||||
--push-to-hub
|
||||
```
|
||||
|
||||
### Lightweight Version (No Videos)
|
||||
|
||||
Removes video files to significantly reduce storage:
|
||||
|
||||
```bash
|
||||
python src/lerobot/datasets/generating_embeddings/generate_embeddings.py \
|
||||
--repo-id lerobot/utokyo_xarm_bimanual \
|
||||
--output-repo-id your-username/utokyo_xarm_bimanual_lightweight \
|
||||
--image-encoder dinov2_vitb14 \
|
||||
--language-encoder minilm-l12 \
|
||||
--remove-videos \
|
||||
--push-to-hub
|
||||
```
|
||||
|
||||
## Output
|
||||
|
||||
The script adds new features to your dataset:
|
||||
|
||||
### New Features
|
||||
|
||||
1. **`task_embedding`**: Language embedding for each frame
|
||||
- Shape: `[384]` (MiniLM)
|
||||
- One embedding per frame based on its task
|
||||
|
||||
2. **`{camera_key}_embedding`**: Image embedding for each camera view
|
||||
- Shape: `[384]`, `[768]`, or `[1024]` depending on DinoV2 model
|
||||
- Examples: `observation.images.top_embedding`, `observation.images.wrist_embedding`
|
||||
|
||||
### Using Embeddings in Training
|
||||
|
||||
```python
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
# Load dataset with embeddings
|
||||
dataset = LeRobotDataset("your-username/utokyo_xarm_bimanual_embeddings")
|
||||
|
||||
# Access embeddings
|
||||
item = dataset[0]
|
||||
task_emb = item["task_embedding"] # Shape: [384]
|
||||
img_emb = item["observation.images.top_embedding"] # Shape: [768]
|
||||
|
||||
# Use in your policy
|
||||
# Instead of running encoders during training, use pre-computed embeddings
|
||||
```
|
||||
|
||||
## Extending with New Encoders
|
||||
|
||||
The script is designed to be easily extensible. To add a new encoder:
|
||||
|
||||
### 1. Create Encoder Class
|
||||
|
||||
```python
|
||||
class MyCustomImageEncoder(ImageEncoder):
|
||||
"""Your custom image encoder."""
|
||||
|
||||
def __init__(self, device: str = "cuda"):
|
||||
super().__init__(device)
|
||||
# Load your model
|
||||
self.model = load_my_model()
|
||||
self.model = self.model.to(self.device)
|
||||
self.model.eval()
|
||||
|
||||
def encode(self, images: list[np.ndarray]) -> np.ndarray:
|
||||
"""Encode a batch of images."""
|
||||
# Your encoding logic here
|
||||
embeddings = []
|
||||
for img in images:
|
||||
emb = self.model(img)
|
||||
embeddings.append(emb)
|
||||
return np.array(embeddings)
|
||||
|
||||
@property
|
||||
def embedding_dim(self) -> int:
|
||||
"""Return embedding dimension."""
|
||||
return 512 # Your embedding dimension
|
||||
```
|
||||
|
||||
### 2. Add to Factory Function
|
||||
|
||||
```python
|
||||
def get_image_encoder(encoder_name: str, device: str = "cuda") -> ImageEncoder:
|
||||
encoders = {
|
||||
"dinov2_vits14": lambda: DinoV2Encoder(model_name="dinov2_vits14", device=device),
|
||||
"dinov2_vitb14": lambda: DinoV2Encoder(model_name="dinov2_vitb14", device=device),
|
||||
"dinov2_vitl14": lambda: DinoV2Encoder(model_name="dinov2_vitl14", device=device),
|
||||
# Add your encoder
|
||||
"my_custom": lambda: MyCustomImageEncoder(device=device),
|
||||
}
|
||||
# ... rest of function
|
||||
```
|
||||
|
||||
## Validating Embeddings
|
||||
|
||||
After generating embeddings, you can validate them using `validate_embeddings.py`:
|
||||
|
||||
```bash
|
||||
python src/lerobot/datasets/generating_embeddings/validate_embeddings.py \
|
||||
--original-repo-id lerobot/utokyo_xarm_bimanual \
|
||||
--embeddings-repo-id your-username/utokyo_xarm_bimanual_embeddings \
|
||||
--image-encoder dinov2_vitb14 \
|
||||
--language-encoder minilm-l12 \
|
||||
--num-samples 20
|
||||
```
|
||||
@@ -1,147 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ImageEncoder:
    """Base class for image encoders.

    Subclasses implement `encode()` and the `embedding_dim` property.
    """

    def __init__(self, device: str = "cuda"):
        """Resolve the torch device the encoder runs on.

        Args:
            device: Requested device string. A CUDA request falls back to "cpu" when
                CUDA is unavailable; any other device is used as requested.
        """
        requested = torch.device(device)
        # Only a CUDA request needs the fallback; an explicit non-CUDA device
        # (e.g. "mps") must not be silently replaced by "cpu".
        if requested.type == "cuda" and not torch.cuda.is_available():
            requested = torch.device("cpu")
        self.device = requested

    def encode(self, images: list[np.ndarray]) -> np.ndarray:
        """Encode a batch of images."""
        raise NotImplementedError

    @property
    def embedding_dim(self) -> int:
        """Return the size of one embedding vector."""
        raise NotImplementedError
|
||||
|
||||
|
||||
class DinoV2Encoder(ImageEncoder):
    """DinoV2 image encoder.

    DinoV2 is a self-supervised vision transformer that produces high-quality image embeddings.
    Supports multiple model sizes (ViT-S/14, ViT-B/14, ViT-L/14).
    """

    def __init__(self, model_name: str = "dinov2_vitb14", device: str = "cuda", batch_size: int = 32):
        """Load the model through torch.hub and build the preprocessing pipeline.

        Args:
            model_name: Entry point name in the "facebookresearch/dinov2" hub repo.
            device: Requested device, resolved by the base class.
            batch_size: Number of images passed through the model at once.
        """
        super().__init__(device)
        self.batch_size = batch_size
        self.model_name = model_name
        logger.info(f"Loading DinoV2 model: {model_name}")
        self.model = torch.hub.load("facebookresearch/dinov2", model_name)  # nosec B614
        self.model = self.model.to(self.device)
        self.model.eval()

        # DinoV2 preprocessing
        from torchvision import transforms

        self.transform = transforms.Compose(
            [
                transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ]
        )

    def encode(self, images: list[np.ndarray]) -> np.ndarray:
        """Encode a batch of images.

        Args:
            images: Images as numpy arrays; each one is cast to uint8 and converted
                to a PIL image before the preprocessing transform is applied.

        Returns:
            Array with one embedding row per input image, in input order. An empty
            input yields an empty array of shape (0, embedding_dim).
        """
        # len() rather than truthiness so an array-like batch is handled as well;
        # np.concatenate below cannot handle zero batches.
        if len(images) == 0:
            return np.empty((0, self.embedding_dim), dtype=np.float32)

        embeddings = []

        with torch.inference_mode():
            for i in range(0, len(images), self.batch_size):
                batch_images = images[i : i + self.batch_size]
                # Convert numpy arrays to PIL Images and apply transforms
                pil_images = [Image.fromarray(img.astype(np.uint8)) for img in batch_images]
                tensors = torch.stack([self.transform(img) for img in pil_images]).to(self.device)

                # Get embeddings
                batch_embeddings = self.model(tensors).cpu().numpy()
                embeddings.append(batch_embeddings)

        return np.concatenate(embeddings, axis=0)

    @property
    def embedding_dim(self) -> int:
        """Return the embedding dimension based on model size."""
        if "vits14" in self.model_name:
            return 384  # DinoV2 ViT-S/14
        elif "vitb14" in self.model_name:
            return 768  # DinoV2 ViT-B/14
        elif "vitl14" in self.model_name:
            return 1024  # DinoV2 ViT-L/14
        else:
            return 768  # Default to ViT-B/14
|
||||
|
||||
|
||||
class LanguageEncoder:
    """Base class for language encoders.

    Subclasses implement `encode()` and the `embedding_dim` property.
    """

    def __init__(self, device: str = "cuda"):
        """Resolve the torch device the encoder runs on.

        Args:
            device: Requested device string. A CUDA request falls back to "cpu" when
                CUDA is unavailable; any other device is used as requested.
        """
        requested = torch.device(device)
        # Only a CUDA request needs the fallback; an explicit non-CUDA device
        # (e.g. "mps") must not be silently replaced by "cpu".
        if requested.type == "cuda" and not torch.cuda.is_available():
            requested = torch.device("cpu")
        self.device = requested

    def encode(self, texts: list[str]) -> np.ndarray:
        """Encode a batch of texts."""
        raise NotImplementedError

    @property
    def embedding_dim(self) -> int:
        """Return the size of one embedding vector."""
        raise NotImplementedError
|
||||
|
||||
|
||||
class MiniLMEncoder(LanguageEncoder):
    """MiniLM language encoder.

    MiniLM is a lightweight sentence transformer model that produces high-quality text embeddings.
    Supports L6 and L12 model sizes.
    """

    def __init__(self, model_name: str = "sentence-transformers/all-MiniLM-L12-v2", device: str = "cuda"):
        """Load the tokenizer and model and put the model in eval mode.

        Args:
            model_name: Model identifier passed to `from_pretrained`.
            device: Requested device, resolved by the base class.
        """
        super().__init__(device)
        self.model_name = model_name
        logger.info(f"Loading MiniLM model: {model_name}")

        from transformers import AutoModel, AutoTokenizer

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name).to(self.device)
        self.model.eval()

    def _mean_pooling(self, model_output, attention_mask):
        """Mean pooling to get sentence embeddings.

        Averages the token embeddings (first element of `model_output`) over the
        sequence axis, weighting each token by `attention_mask` so that masked
        positions do not contribute.
        """
        token_embeddings = model_output[0]
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        summed = torch.sum(token_embeddings * input_mask_expanded, 1)
        # Clamp the token count so an all-masked row cannot divide by zero.
        counts = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return summed / counts

    def encode(self, texts: list[str]) -> np.ndarray:
        """Encode a batch of texts.

        Args:
            texts: Sentences to embed; they are padded and truncated by the tokenizer.

        Returns:
            Array with one embedding row per input text, in input order. An empty
            input yields an empty array of shape (0, embedding_dim) without
            invoking the tokenizer or the model.
        """
        if len(texts) == 0:
            return np.empty((0, self.embedding_dim), dtype=np.float32)

        with torch.inference_mode():
            encoded_input = self.tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
            encoded_input = {k: v.to(self.device) for k, v in encoded_input.items()}

            model_output = self.model(**encoded_input)
            embeddings = self._mean_pooling(model_output, encoded_input["attention_mask"])

        return embeddings.cpu().numpy()

    @property
    def embedding_dim(self) -> int:
        """Return the embedding dimension."""
        return 384  # Both MiniLM-L6 and L12 output 384-dim embeddings
|
||||
@@ -1,329 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Generate embeddings for LeRobot datasets to make them more lightweight and efficient.
|
||||
|
||||
This script:
|
||||
1. Loads a v3.0 LeRobot dataset from the hub
|
||||
2. Computes embeddings for tasks (language commands) and frames (images)
|
||||
3. Stores embeddings as new features in the dataset
|
||||
4. Optionally removes video files to reduce size
|
||||
5. Pushes the converted dataset to the hub
|
||||
|
||||
Current supported encoders:
|
||||
- Image: DinoV2 (dinov2_vits14, dinov2_vitb14, dinov2_vitl14)
|
||||
- Language: MiniLM (minilm-l6, minilm-l12)
|
||||
|
||||
The architecture is extensible - you can add more encoders by:
|
||||
1. Creating a new encoder class inheriting from ImageEncoder or LanguageEncoder
|
||||
2. Implementing the encode() method and embedding_dim property
|
||||
3. Adding it to the get_image_encoder() or get_language_encoder() factory function
|
||||
|
||||
Usage example:
|
||||
python src/lerobot/datasets/generating_embeddings/generate_embeddings.py \
|
||||
--repo-id lerobot/utokyo_xarm_bimanual \
|
||||
--output-repo-id lerobot/utokyo_xarm_bimanual_embeddings \
|
||||
--image-encoder dinov2_vitb14 \
|
||||
--language-encoder minilm-l12 \
|
||||
--remove-videos \
|
||||
--push-to-hub
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.datasets.generating_embeddings.encoders import (
|
||||
DinoV2Encoder,
|
||||
ImageEncoder,
|
||||
LanguageEncoder,
|
||||
MiniLMEncoder,
|
||||
)
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
|
||||
def get_image_encoder(encoder_name: str, device: str = "cuda") -> ImageEncoder:
    """Build the image encoder registered under ``encoder_name``.

    To support another encoder, implement an ImageEncoder subclass (with
    encode() and embedding_dim) and register it in this function.

    Args:
        encoder_name: Registry key of the encoder to build.
        device: Device string handed to the encoder.

    Returns:
        A freshly constructed encoder instance.

    Raises:
        ValueError: If ``encoder_name`` is not a registered encoder.
    """
    # Registry key -> DinoV2 model name (currently identical).
    dinov2_models = {
        "dinov2_vits14": "dinov2_vits14",
        "dinov2_vitb14": "dinov2_vitb14",
        "dinov2_vitl14": "dinov2_vitl14",
    }

    if encoder_name not in dinov2_models:
        known = list(dinov2_models.keys())
        raise ValueError(f"Unknown image encoder: {encoder_name}. Available options: {known}")

    return DinoV2Encoder(model_name=dinov2_models[encoder_name], device=device)
|
||||
|
||||
|
||||
def get_language_encoder(encoder_name: str, device: str = "cuda") -> LanguageEncoder:
    """Build the language encoder registered under ``encoder_name``.

    To support another encoder, implement a LanguageEncoder subclass (with
    encode() and embedding_dim) and register it in this function.

    Args:
        encoder_name: Registry key of the encoder to build.
        device: Device string handed to the encoder.

    Returns:
        A freshly constructed encoder instance.

    Raises:
        ValueError: If ``encoder_name`` is not a registered encoder.
    """
    # Registry key -> Hugging Face checkpoint identifier.
    minilm_checkpoints = {
        "minilm-l6": "sentence-transformers/all-MiniLM-L6-v2",
        "minilm-l12": "sentence-transformers/all-MiniLM-L12-v2",
    }

    if encoder_name not in minilm_checkpoints:
        known = list(minilm_checkpoints.keys())
        raise ValueError(f"Unknown language encoder: {encoder_name}. Available options: {known}")

    return MiniLMEncoder(model_name=minilm_checkpoints[encoder_name], device=device)
|
||||
|
||||
|
||||
def _frame_to_hwc_uint8(img, cam_key, from_video):
    """Convert one camera frame to the array form handed to the image encoder.

    Tensors are permuted from channel-first to channel-last, multiplied by 255
    and cast to uint8; any other object is passed through ``np.array``.
    """
    if not isinstance(img, torch.Tensor):
        return np.array(img)

    if from_video and img.ndim == 4:
        img = img[0]  # (T, C, H, W) -> (C, H, W)

    if img.ndim != 3:
        if from_video:
            raise ValueError(f"Unexpected video frame shape {img.shape} for camera {cam_key}")
        raise ValueError(f"Unexpected image shape {img.shape} for camera {cam_key}")

    return (img.permute(1, 2, 0).numpy() * 255).astype(np.uint8)


def generate_embeddings_for_dataset(
    repo_id: str,
    output_repo_id: str,
    image_encoder: ImageEncoder,
    language_encoder: LanguageEncoder,
    remove_videos: bool = False,
    local_dir: Path | None = None,
    output_local_dir: Path | None = None,
    push_to_hub: bool = False,
):
    """Add task and per-camera image embeddings to a LeRobot dataset.

    The source dataset is loaded, every task string and every frame of every
    camera is encoded, and the results are written as new features
    (``task_embedding`` and ``<camera>_embedding``) of an output dataset.

    Args:
        repo_id: Source dataset repository ID.
        output_repo_id: Output dataset repository ID.
        image_encoder: Encoder applied to the collected camera frames.
        language_encoder: Encoder applied to the task strings.
        remove_videos: Drop the video features and delete the videos directory
            of the output dataset.
        local_dir: Local root of the source dataset.
        output_local_dir: Local directory for the output dataset.
        push_to_hub: Push the output dataset to the hub when done.
    """
    from lerobot.datasets.dataset_tools import modify_features

    print(f"Loading dataset: {repo_id}")

    dataset = LeRobotDataset(repo_id, root=local_dir, download_videos=True)
    print(f"Dataset: {dataset.num_episodes} episodes, {dataset.num_frames} frames")

    # --- one embedding per distinct task string ---
    print("Computing task embeddings...")
    task_embeddings = {}
    for task in tqdm(dataset.meta.tasks.index.tolist(), desc="Encoding tasks"):
        # Normalize the task text before encoding.
        cleaned = task.strip().capitalize().strip(" .,!?-_")
        task_embeddings[task] = language_encoder.encode([cleaned])[0]

    print(f"Computed {len(task_embeddings)} task embeddings")

    # --- walk all frames: look up the task embedding and collect raw images ---
    print("Processing frames and computing embeddings...")
    per_frame_task_embeddings = []
    frames_by_camera = {cam_key: [] for cam_key in dataset.meta.camera_keys}

    for frame_idx in tqdm(range(dataset.num_frames), desc="Processing frames"):
        item = dataset.hf_dataset[frame_idx]
        ep_idx = item["episode_index"].item()

        task = dataset.meta.tasks.iloc[item["task_index"].item()].name
        per_frame_task_embeddings.append(task_embeddings[task])

        for cam_key in dataset.meta.camera_keys:
            is_video = cam_key in dataset.meta.video_keys
            if is_video:
                # Video cameras are decoded at this frame's timestamp.
                current_ts = item["timestamp"].item()
                raw = dataset._query_videos({cam_key: [current_ts]}, ep_idx)[cam_key]
            else:
                raw = item[cam_key]

            frames_by_camera[cam_key].append(_frame_to_hwc_uint8(raw, cam_key, is_video))

    # --- encode the collected images, one camera at a time ---
    print("Computing image embeddings...")
    image_embeddings_dict = {}
    for cam_key, images in frames_by_camera.items():
        print(f" {cam_key}: {len(images)} images")
        image_embeddings_dict[cam_key] = image_encoder.encode(images)

    all_task_embeddings = np.array(per_frame_task_embeddings)
    for cam_key in dataset.meta.camera_keys:
        image_embeddings_dict[cam_key] = np.array(image_embeddings_dict[cam_key])

    img_emb_dim = image_encoder.embedding_dim
    lang_emb_dim = language_encoder.embedding_dim

    # --- describe the new features: (values, feature spec) ---
    add_features_dict = {
        "task_embedding": (
            all_task_embeddings,
            {"dtype": "float32", "shape": [lang_emb_dim], "names": None},
        ),
    }
    for cam_key in dataset.meta.camera_keys:
        add_features_dict[f"{cam_key}_embedding"] = (
            image_embeddings_dict[cam_key],
            {"dtype": "float32", "shape": [img_emb_dim], "names": None},
        )

    print("Adding embeddings to dataset...")
    remove_features_list = dataset.meta.video_keys if remove_videos else None

    output_dataset = modify_features(
        dataset=dataset,
        add_features=add_features_dict,
        remove_features=remove_features_list,
        output_dir=output_local_dir,
        repo_id=output_repo_id,
    )

    if remove_videos:
        print("Removing video files...")
        videos_dir = output_dataset.root / "videos"
        if videos_dir.exists():
            shutil.rmtree(videos_dir)

    print(f"Saved to: {output_dataset.root}")

    if push_to_hub:
        print(f"Pushing to hub: {output_repo_id}")
        output_dataset.push_to_hub(push_videos=not remove_videos)
        print("Done!")
|
||||
|
||||
|
||||
def main():
    """Command-line entry point: parse options, build encoders, run the conversion."""
    usage_epilog = """
Examples:
# Basic usage with default encoders (DinoV2 ViT-B/14 + MiniLM-L12)
python src/lerobot/datasets/generating_embeddings/generate_embeddings.py \\
--repo-id lerobot/utokyo_xarm_bimanual \\
--output-repo-id your-username/utokyo_xarm_bimanual_embeddings \\
--image-encoder dinov2_vitb14 \\
--language-encoder minilm-l12 \\
--push-to-hub

# Generate embeddings and remove videos
python src/lerobot/datasets/generating_embeddings/generate_embeddings.py \\
--repo-id lerobot/utokyo_xarm_bimanual \\
--output-repo-id your-username/utokyo_xarm_bimanual_lightweight \\
--image-encoder dinov2_vitb14 \\
--language-encoder minilm-l12 \\
--remove-videos \\
--push-to-hub

Available image encoders:
- dinov2_vits14: DinoV2 ViT-S/14 (384-dim, faster)
- dinov2_vitb14: DinoV2 ViT-B/14 (768-dim, recommended)
- dinov2_vitl14: DinoV2 ViT-L/14 (1024-dim, best quality)

Available language encoders:
- minilm-l6: MiniLM-L6-v2 (384-dim, faster)
- minilm-l12: MiniLM-L12-v2 (384-dim, recommended)
"""
    cli = argparse.ArgumentParser(
        description="Generate embeddings for LeRobot datasets",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=usage_epilog,
    )

    # Dataset selection.
    cli.add_argument("--repo-id", type=str, required=True, help="Source dataset repository ID")
    cli.add_argument("--output-repo-id", type=str, required=True, help="Output dataset repository ID")

    # Encoder selection.
    cli.add_argument(
        "--image-encoder",
        type=str,
        default="dinov2_vitb14",
        help="Image encoder to use (default: dinov2_vitb14)",
    )
    cli.add_argument(
        "--language-encoder",
        type=str,
        default="minilm-l12",
        help="Language encoder to use (default: minilm-l12)",
    )

    # Output handling.
    cli.add_argument(
        "--remove-videos",
        action="store_true",
        help="Remove video files after generating embeddings",
    )
    cli.add_argument("--local-dir", type=str, default=None, help="Local directory for source dataset")
    cli.add_argument(
        "--output-local-dir", type=str, default=None, help="Local directory for output dataset"
    )
    cli.add_argument(
        "--push-to-hub",
        action="store_true",
        help="Push the converted dataset to the hub",
    )
    cli.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="Device to use for encoding (default: cuda)",
    )

    opts = cli.parse_args()

    # Both encoders are built on the requested device.
    image_encoder = get_image_encoder(opts.image_encoder, device=opts.device)
    language_encoder = get_language_encoder(opts.language_encoder, device=opts.device)

    # Empty / missing directory options are forwarded as None.
    source_root = Path(opts.local_dir) if opts.local_dir else None
    output_root = Path(opts.output_local_dir) if opts.output_local_dir else None

    generate_embeddings_for_dataset(
        repo_id=opts.repo_id,
        output_repo_id=opts.output_repo_id,
        image_encoder=image_encoder,
        language_encoder=language_encoder,
        remove_videos=opts.remove_videos,
        local_dir=source_root,
        output_local_dir=output_root,
        push_to_hub=opts.push_to_hub,
    )
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,222 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Validate pre-computed embeddings against on-the-fly computed embeddings.
|
||||
|
||||
Usage:
|
||||
python src/lerobot/datasets/generating_embeddings/validate_embeddings.py \
|
||||
--original-repo-id lerobot/utokyo_xarm_bimanual \
|
||||
--embeddings-repo-id <your_username>/utokyo_xarm_bimanual_embeddings \
|
||||
--image-encoder dinov2_vitb14 \
|
||||
--language-encoder minilm-l12 \
|
||||
--num-samples 10
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.datasets.generating_embeddings.encoders import ImageEncoder, LanguageEncoder
|
||||
from lerobot.datasets.generating_embeddings.generate_embeddings import (
|
||||
get_image_encoder,
|
||||
get_language_encoder,
|
||||
)
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
|
||||
def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """Compute the cosine similarity between two vectors.

    Args:
        a: First vector.
        b: Second vector, same length as ``a``.

    Returns:
        ``dot(a, b) / (|a| * |b|)`` as a Python float. If either vector has
        zero norm the similarity is undefined and 0.0 is returned instead of
        NaN.
    """
    denom = float(np.linalg.norm(a) * np.linalg.norm(b))
    if denom == 0.0:
        # Avoid a divide-by-zero (NaN plus RuntimeWarning) for all-zero vectors.
        return 0.0
    return float(np.dot(a, b) / denom)
|
||||
|
||||
|
||||
def validate_embeddings(
    original_repo_id: str,
    embeddings_repo_id: str,
    image_encoder: ImageEncoder,
    language_encoder: LanguageEncoder,
    num_samples: int = 10,
    device: str = "cuda",
) -> bool:
    """Compare stored embeddings with embeddings recomputed from the original data.

    A random subset of frames is drawn; for each one the task text and every
    camera image of the original dataset are re-encoded and compared, by cosine
    similarity, with the values stored in the embeddings dataset. The mean
    similarity per feature must reach 0.99 to pass. A summary is printed.

    Args:
        original_repo_id: Original dataset repository ID.
        embeddings_repo_id: Repository ID of the dataset holding pre-computed embeddings.
        image_encoder: Image encoder instance.
        language_encoder: Language encoder instance.
        num_samples: Number of frames to check (capped at the dataset size).
        device: Not used by this function; kept for interface compatibility.

    Returns:
        True when the task feature and every camera feature pass, else False.

    Raises:
        ValueError: If the two datasets differ in frame count, an expected
            embedding feature is missing, or an image tensor is not 3-D.
    """
    print("Loading datasets...")
    original_dataset = LeRobotDataset(original_repo_id, download_videos=True)
    embeddings_dataset = LeRobotDataset(embeddings_repo_id, download_videos=False)

    # Explicit raise rather than `assert`, which is stripped under `python -O`.
    if original_dataset.num_frames != embeddings_dataset.num_frames:
        raise ValueError(
            f"Frame count mismatch: original={original_dataset.num_frames}, "
            f"embeddings={embeddings_dataset.num_frames}"
        )

    camera_keys = original_dataset.meta.camera_keys

    # Every expected embedding feature must be present before sampling.
    expected_features = ["task_embedding"] + [f"{cam}_embedding" for cam in camera_keys]
    for feat in expected_features:
        if feat not in embeddings_dataset.features:
            raise ValueError(f"Embedding feature not found: {feat}")

    # Draw distinct frame indices at random.
    sample_indices = np.random.choice(
        original_dataset.num_frames, size=min(num_samples, original_dataset.num_frames), replace=False
    )
    print(f"Validating {len(sample_indices)} samples...")

    task_similarities = []
    image_similarities = {cam: [] for cam in camera_keys}

    for idx in tqdm(sample_indices, desc="Validating"):
        idx = int(idx)

        embeddings_item = embeddings_dataset[idx]
        precomputed_task_emb = embeddings_item["task_embedding"].numpy()
        precomputed_image_embs = {cam: embeddings_item[f"{cam}_embedding"].numpy() for cam in camera_keys}

        original_item = original_dataset[idx]

        # Clean up task text (same as in generate_embeddings.py)
        task_clean = original_item["task"].strip().capitalize().strip(" .,!?-_")
        onthefly_task_emb = language_encoder.encode([task_clean])[0]
        task_similarities.append(cosine_similarity(precomputed_task_emb, onthefly_task_emb))

        for cam in camera_keys:
            img = original_item[cam]
            if isinstance(img, torch.Tensor):
                if img.ndim != 3:
                    raise ValueError(f"Unexpected image shape: {img.shape}")
                # (C, H, W) tensor -> (H, W, C) uint8 array
                img_np = (img.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
            else:
                img_np = np.array(img)

            onthefly_emb = image_encoder.encode([img_np])[0]
            image_similarities[cam].append(cosine_similarity(precomputed_image_embs[cam], onthefly_emb))

    print("\nResults:")
    task_sim_threshold = 0.99
    img_sim_threshold = 0.99

    task_mean_sim = np.mean(task_similarities)
    task_pass = bool(task_mean_sim >= task_sim_threshold)
    print(f" Task: {task_mean_sim:.4f} {'✓' if task_pass else '✗'}")

    # Compute each camera's mean once and reuse it for the overall verdict.
    camera_results = []
    for cam in camera_keys:
        cam_mean_sim = np.mean(image_similarities[cam])
        cam_pass = bool(cam_mean_sim >= img_sim_threshold)
        camera_results.append(cam_pass)
        print(f" {cam}: {cam_mean_sim:.4f} {'✓' if cam_pass else '✗'}")

    passed = task_pass and all(camera_results)

    print()
    if passed:
        print("✓ PASSED")
    else:
        print("✗ FAILED")

    return passed
|
||||
|
||||
|
||||
def main():
    """Command-line entry point: parse options, build encoders, run the validation."""
    usage_epilog = """
Example:
python src/lerobot/datasets/generating_embeddings/validate_embeddings.py \\
--original-repo-id lerobot/utokyo_xarm_bimanual \\
--embeddings-repo-id lerobot/utokyo_xarm_bimanual_embeddings \\
--image-encoder dinov2_vitb14 \\
--language-encoder minilm-l12 \\
--num-samples 20
"""
    cli = argparse.ArgumentParser(
        description="Validate and compare pre-computed embeddings with on-the-fly embeddings",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=usage_epilog,
    )

    # Datasets to compare.
    cli.add_argument("--original-repo-id", type=str, required=True, help="Original dataset repository ID")
    cli.add_argument(
        "--embeddings-repo-id",
        type=str,
        required=True,
        help="Dataset with pre-computed embeddings repository ID",
    )

    # Encoders used to recompute embeddings.
    cli.add_argument(
        "--image-encoder",
        type=str,
        default="dinov2_vitb14",
        help="Image encoder to use (default: dinov2_vitb14)",
    )
    cli.add_argument(
        "--language-encoder",
        type=str,
        default="minilm-l12",
        help="Language encoder to use (default: minilm-l12)",
    )

    # Sampling and hardware options.
    cli.add_argument(
        "--num-samples",
        type=int,
        default=10,
        help="Number of samples to validate (default: 10)",
    )
    cli.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="Device to use for encoding (default: cuda)",
    )

    opts = cli.parse_args()

    # Both encoders are built on the requested device.
    image_encoder = get_image_encoder(opts.image_encoder, device=opts.device)
    language_encoder = get_language_encoder(opts.language_encoder, device=opts.device)

    validate_embeddings(
        original_repo_id=opts.original_repo_id,
        embeddings_repo_id=opts.embeddings_repo_id,
        image_encoder=image_encoder,
        language_encoder=language_encoder,
        num_samples=opts.num_samples,
        device=opts.device,
    )
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,151 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
SARM Temporal Sampler for reward model training.
|
||||
|
||||
Samples frames uniformly from episodes for SARM's 9-frame symmetric pattern:
|
||||
- 1 initial frame + 4 frames before + current + 3 frames after
|
||||
|
||||
Boundary handling: clamp to first/last frame when indices go out of bounds.
|
||||
This enables truly uniform sampling across entire episodes.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Iterator, Optional
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import Sampler
|
||||
import random
|
||||
|
||||
|
||||
class SARMTemporalSampler(Sampler):
    """Sampler yielding "current frame" indices for SARM reward model training.

    Every frame of every sufficiently long episode is a valid sampling
    position, so sampling is uniform over whole episodes. The sampler only
    yields global dataset indices; assembling the 9-frame SARM window
    (initial frame plus symmetric context spaced by ``frame_gap``, clamped to
    the first/last frame at episode boundaries) is left to the dataset that
    consumes these indices.

    Args:
        dataset_from_index: Start indices of episodes (global dataset indices).
        dataset_to_index: End indices of episodes (global, exclusive).
        frame_gap: Gap between consecutive context frames (default: 30).
            Stored and logged; not used for index generation here.
        shuffle: Draw positions at random (with replacement) instead of
            cycling through them in order.
        seed: Random seed for reproducibility.
        samples_per_epoch: Number of indices yielded per epoch (default: 6400).
        min_episode_length: Minimum episode length to include (default: 1).

    Raises:
        ValueError: If no episode is long enough to provide a sampling position.
    """

    def __init__(
        self,
        dataset_from_index: np.ndarray,
        dataset_to_index: np.ndarray,
        frame_gap: int = 30,
        shuffle: bool = True,
        seed: Optional[int] = None,
        samples_per_epoch: int = 6400,
        min_episode_length: int = 1,
    ):
        self.dataset_from_index = np.array(dataset_from_index)
        self.dataset_to_index = np.array(dataset_to_index)
        self.frame_gap = frame_gap
        self.shuffle = shuffle
        self.samples_per_epoch = samples_per_epoch
        self.min_episode_length = min_episode_length

        # Always define the attribute, even when no seed was given.
        self.seed = seed
        self.generator = torch.Generator()
        if seed is not None:
            # Global seeding is kept because existing callers may rely on it.
            random.seed(seed)
            np.random.seed(seed)
            self.generator.manual_seed(seed)
        else:
            # A fresh torch.Generator starts from a fixed default seed; derive
            # one from NumPy's global RNG so unseeded samplers still differ
            # while remaining reproducible under a globally seeded run.
            self.generator.manual_seed(int(np.random.randint(0, 2**31 - 1)))

        # Compute valid episodes and sampling positions (ALL frames for uniform sampling)
        self._compute_valid_positions()

        logging.info(
            f"SARMTemporalSampler: {len(self.valid_episodes)} valid episodes, "
            f"{len(self.all_valid_positions)} positions (uniform sampling), "
            f"{self.samples_per_epoch} samples per epoch, "
            f"frame_gap={frame_gap}, symmetric bidirectional pattern"
        )

    def _compute_valid_positions(self):
        """Collect the valid episodes and every frame index inside them.

        Sets ``self.valid_episodes`` (rows of ``(episode, start, end)``) and
        ``self.all_valid_positions`` (all global frame indices of those
        episodes).

        Raises:
            ValueError: If no sampling position exists.
        """
        episodes = []
        position_chunks = []

        for ep_idx in range(len(self.dataset_from_index)):
            ep_start = self.dataset_from_index[ep_idx]
            ep_end = self.dataset_to_index[ep_idx]

            # Include all episodes with at least min_episode_length frames
            if ep_end - ep_start >= self.min_episode_length:
                episodes.append((ep_idx, ep_start, ep_end))
                # All frames of the episode at once instead of one append each.
                position_chunks.append(np.arange(ep_start, ep_end))

        self.valid_episodes = np.array(episodes)
        # np.concatenate rejects an empty list, so handle that case explicitly.
        self.all_valid_positions = np.concatenate(position_chunks) if position_chunks else np.array([])

        if len(self.all_valid_positions) == 0:
            raise ValueError(
                f"No valid sampling positions found! "
                f"Check that episodes have at least {self.min_episode_length} frames."
            )

    def __len__(self) -> int:
        """Return the number of indices yielded per epoch."""
        return self.samples_per_epoch

    def __iter__(self) -> Iterator[int]:
        """Yield ``samples_per_epoch`` global dataset indices.

        With ``shuffle`` the indices are drawn uniformly with replacement from
        the sampler's own generator, so equally seeded samplers produce the
        same sequence regardless of other users of the global RNGs. Without
        it, positions are visited in order, wrapping around when exhausted.
        """
        n_positions = len(self.all_valid_positions)

        if self.shuffle:
            picks = torch.randint(
                0, n_positions, (self.samples_per_epoch,), generator=self.generator
            ).tolist()
        else:
            # Sequential sampling with wrap-around
            picks = (i % n_positions for i in range(self.samples_per_epoch))

        for idx in picks:
            yield int(self.all_valid_positions[idx])
|
||||
@@ -111,7 +111,6 @@ def make_env(
|
||||
|
||||
# import and surface clear import errors
|
||||
module = _import_hub_module(local_file, repo_id)
|
||||
|
||||
# call the hub-provided make_env
|
||||
raw_result = _call_make_env(module, n_envs=n_envs, use_async_envs=use_async_envs)
|
||||
|
||||
|
||||
@@ -221,7 +221,22 @@ def _load_module_from_path(path: str, module_name: str | None = None):
|
||||
if spec is None:
|
||||
raise ImportError(f"Could not load module spec for {module_name} from {path}")
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module) # type: ignore
|
||||
|
||||
# Add the module's directory to sys.path so it can import local modules
|
||||
import sys
|
||||
module_dir = os.path.dirname(os.path.abspath(path))
|
||||
sys_path_modified = False
|
||||
if module_dir not in sys.path:
|
||||
sys.path.insert(0, module_dir)
|
||||
sys_path_modified = True
|
||||
|
||||
try:
|
||||
spec.loader.exec_module(module) # type: ignore
|
||||
finally:
|
||||
# Clean up sys.path after import
|
||||
if sys_path_modified:
|
||||
sys.path.remove(module_dir)
|
||||
|
||||
return module
|
||||
|
||||
|
||||
|
||||
@@ -35,7 +35,6 @@ from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.policies.pi05.configuration_pi05 import PI05Config
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.policies.sarm.configuration_sarm import SARMConfig
|
||||
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
|
||||
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
|
||||
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
@@ -104,10 +103,6 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
|
||||
|
||||
return SmolVLAPolicy
|
||||
elif name == "sarm":
|
||||
from lerobot.policies.sarm.modeling_sarm import SARMRewardModel
|
||||
|
||||
return SARMRewardModel
|
||||
elif name == "groot":
|
||||
from lerobot.policies.groot.modeling_groot import GrootPolicy
|
||||
|
||||
@@ -327,14 +322,6 @@ def make_pre_post_processors(
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, SARMConfig):
|
||||
from lerobot.policies.sarm.processor_sarm import make_sarm_pre_post_processors
|
||||
|
||||
processors = make_sarm_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
dataset_meta=kwargs.get("dataset_meta"),
|
||||
)
|
||||
elif isinstance(policy_cfg, GrootConfig):
|
||||
from lerobot.policies.groot.processor_groot import make_groot_pre_post_processors
|
||||
|
||||
@@ -418,13 +405,6 @@ def make_policy(
|
||||
if not cfg.input_features:
|
||||
cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
|
||||
kwargs["config"] = cfg
|
||||
|
||||
# Pass dataset_stats to the policy if available (needed for some policies like SARM)
|
||||
if ds_meta is not None and hasattr(ds_meta, 'stats'):
|
||||
kwargs["dataset_stats"] = ds_meta.stats
|
||||
|
||||
if ds_meta is not None:
|
||||
kwargs["dataset_meta"] = ds_meta
|
||||
|
||||
if cfg.pretrained_path:
|
||||
# Load a pretrained policy and override the config if needed (for example, if there are inference-time
|
||||
|
||||
@@ -1,186 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import PolicyFeature, FeatureType, NormalizationMode
|
||||
from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("sarm")
@dataclass
class SARMConfig(PreTrainedConfig):
    """Configuration class for SARM (Stage-Aware Reward Modeling).

    Groups the CLIP feature sizes, the transformer architecture, training
    hyper-parameters and the dataset keys used by the SARM reward model.
    ``__post_init__`` registers the image/state input features, sizes the
    ``stage``/``progress`` output features from ``num_frames`` and
    ``num_stages``, and validates the architecture settings.
    """

    # CLIP params
    image_dim: int = 512
    text_dim: int = 512
    num_frames: int = 9  # 1 initial + 8 consecutive frames
    frame_gap: int = 30  # Frame gap between frames (at 30 fps = 1 second)

    # Architecture params
    hidden_dim: int = 768
    num_heads: int = 12
    num_layers: int = 8
    max_state_dim: int = 32
    num_stages: int = 5  # Number of task stages (auto-updated from annotations if available)
    subtask_names: list | None = None  # List of subtask names (auto-populated from annotations)
    temporal_proportions: list | None = None  # Temporal proportions for each stage (auto-computed from annotations)
    # Maximum video sequence length. ``None`` (the default) resolves to
    # ``num_frames`` in ``__post_init__``; the previous ``= num_frames`` default
    # was evaluated once at class-definition time and therefore stayed at 9
    # even when ``num_frames`` was overridden.
    max_length: int | None = None
    use_temporal_sampler: bool = True  # Always enable temporal sequence loading

    # Training params
    batch_size: int = 64
    clip_batch_size: int = 64  # Batch size for CLIP encoding
    dropout: float = 0.1
    stage_loss_weight: float = 1.0  # Weight for stage classification loss when using subtask annotations

    pretrained_model_path: str | None = None
    device: str | None = None

    # Processor settings
    image_key: str = "observation.images.top"  # Key for image used from the dataset

    # State key in the dataset (for normalization)
    state_key: str = "observation.state"

    # Populated by the processor (video_features, state_features, text_features)
    input_features: dict = field(default_factory=dict)

    # Output features; placeholder shapes, re-sized in __post_init__
    output_features: dict = field(
        default_factory=lambda: {
            "stage": PolicyFeature(shape=(9, 5), type=FeatureType.REWARD),
            "progress": PolicyFeature(shape=(9, 1), type=FeatureType.REWARD),
        }
    )

    normalization_mapping: dict[str, NormalizationMode] = field(
        default_factory=lambda: {
            "VISUAL": NormalizationMode.IDENTITY,
            "STATE": NormalizationMode.MEAN_STD,
            "LANGUAGE": NormalizationMode.IDENTITY,
            "REWARD": NormalizationMode.IDENTITY,
        }
    )

    def __post_init__(self):
        """Register features and validate the configuration.

        Raises:
            ValueError: If ``hidden_dim`` is not divisible by ``num_heads``,
                if ``max_length`` differs from ``num_frames``, or if
                ``num_stages`` is smaller than 2.
        """
        super().__post_init__()

        # Let max_length follow num_frames unless it was set explicitly.
        if self.max_length is None:
            self.max_length = self.num_frames

        # Add the image_key as VISUAL
        # NOTE(review): the (480, 640, 3) shape is hard-coded here — confirm it
        # matches the camera resolution of the dataset in use.
        if self.image_key:
            self.input_features[self.image_key] = PolicyFeature(
                shape=(480, 640, 3),
                type=FeatureType.VISUAL,
            )

        # Add state_key as STATE
        self.input_features[self.state_key] = PolicyFeature(
            shape=(self.max_state_dim,),  # Single frame state, temporal sampling handles sequence
            type=FeatureType.STATE,
        )

        # Update output features with actual dimensions
        self.output_features["stage"] = PolicyFeature(
            shape=(self.num_frames, self.num_stages),
            type=FeatureType.REWARD,
        )
        self.output_features["progress"] = PolicyFeature(
            shape=(self.num_frames, 1),
            type=FeatureType.REWARD,
        )

        # Validate configuration
        if self.hidden_dim % self.num_heads != 0:
            raise ValueError(
                f"hidden_dim ({self.hidden_dim}) must be divisible by num_heads ({self.num_heads})"
            )

        if self.max_length != self.num_frames:
            raise ValueError(
                f"max_length ({self.max_length}) must equal num_frames ({self.num_frames})"
            )

        if self.num_stages < 2:
            raise ValueError(f"num_stages must be at least 2, got {self.num_stages}")

    def get_optimizer_preset(self) -> AdamWConfig:
        """Get default optimizer configuration for SARM training."""
        return AdamWConfig(
            lr=5e-5,
            weight_decay=1e-3,
            betas=(0.9, 0.999),
            eps=1e-8,
        )

    def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig:
        """Get default learning rate scheduler configuration."""
        return CosineDecayWithWarmupSchedulerConfig(
            peak_lr=5e-5,
            decay_lr=5e-6,
            num_warmup_steps=500,
            num_decay_steps=50000,
        )

    def validate_features(self) -> None:
        """Validate input and output features (currently a no-op)."""
        pass

    @property
    def observation_delta_indices(self) -> list[int]:
        """Frame offsets (relative to the current frame) to load for one sample.

        The first entry is a very large negative delta intended to select the
        initial frame of the episode; the remaining ``num_frames - 1`` entries
        are multiples of ``frame_gap`` centred on the current frame, with the
        extra frame (if any) placed before it.

        With the default ``num_frames = 9`` this yields
        ``[-1_000_000, -4*gap, -3*gap, -2*gap, -gap, 0, gap, 2*gap, 3*gap]``.

        The number of deltas now follows ``num_frames`` so that it stays
        consistent with the output feature shapes set in ``__post_init__``
        (it was previously fixed at 9 regardless of ``num_frames``).

        NOTE(review): clamping of out-of-range indices to the first/last frame
        of the episode is assumed to be done by the dataset loader — confirm.

        Returns:
            ``num_frames`` delta indices.
        """
        initial_frame_delta = -1_000_000

        num_context = self.num_frames - 1
        first_offset = -(num_context // 2)
        context_deltas = [
            offset * self.frame_gap
            for offset in range(first_offset, first_offset + num_context)
        ]

        return [initial_frame_delta, *context_deltas]

    @property
    def action_delta_indices(self) -> None:
        """SARM is a reward model, not an action policy."""
        return None

    @property
    def reward_delta_indices(self) -> None:
        """SARM doesn't use delta rewards."""
        return None
|
||||
|
||||
@@ -1,650 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import List, Union, Optional
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from PIL import Image
|
||||
from transformers import CLIPModel, CLIPProcessor
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.policies.sarm.configuration_sarm import SARMConfig
|
||||
from lerobot.policies.sarm.sarm_utils import compute_cumulative_progress_batch, pad_state_to_max_dim
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
|
||||
class SARMTransformer(nn.Module):
|
||||
"""
|
||||
SARM Transformer model for stage-aware reward prediction.
|
||||
|
||||
This model has a dual-head architecture:
|
||||
1. Stage estimator: Predicts the high-level task stage (classification)
|
||||
2. Subtask estimator: Predicts fine-grained progress within the stage (regression)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
video_dim: int = 512,
|
||||
text_dim: int = 512,
|
||||
max_state_dim: int = 32,
|
||||
hidden_dim: int = 768,
|
||||
num_heads: int = 12,
|
||||
num_layers: int = 8,
|
||||
num_stages: int = 5,
|
||||
max_length: int = 9,
|
||||
dropout: float = 0.1,
|
||||
temporal_proportions: list[float] | None = None
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_dim = hidden_dim
|
||||
self.max_length = max_length
|
||||
self.num_stages = num_stages
|
||||
self.max_state_dim = max_state_dim
|
||||
|
||||
if temporal_proportions is None:
|
||||
raise ValueError(
|
||||
"temporal_proportions is required for SARM. "
|
||||
"Provide subtask annotations in your dataset or set temporal_proportions in config."
|
||||
)
|
||||
|
||||
# ᾱ_k: proportion for each stage
|
||||
alpha = torch.tensor(temporal_proportions, dtype=torch.float32)
|
||||
|
||||
# P_k: cumulative proportion up to stage k (P_0 = 0)
|
||||
cumulative = torch.zeros(num_stages + 1, dtype=torch.float32)
|
||||
cumulative[1:] = torch.cumsum(alpha, dim=0)
|
||||
self.register_buffer('alpha', alpha)
|
||||
self.register_buffer('cumulative_prior', cumulative)
|
||||
|
||||
self.video_proj = nn.Linear(video_dim, hidden_dim)
|
||||
self.text_proj = nn.Linear(text_dim, hidden_dim)
|
||||
self.state_proj = nn.Linear(max_state_dim, hidden_dim)
|
||||
|
||||
# Position embedding only for the first frame
|
||||
self.first_pos_embed = nn.Parameter(torch.randn(1, hidden_dim))
|
||||
|
||||
encoder_layer = nn.TransformerEncoderLayer(
|
||||
d_model=hidden_dim,
|
||||
nhead=num_heads,
|
||||
dim_feedforward=hidden_dim * 4,
|
||||
dropout=dropout,
|
||||
batch_first=True
|
||||
)
|
||||
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
|
||||
|
||||
# Stage estimator head (classification)
|
||||
self.stage_head = nn.Sequential(
|
||||
nn.Linear(hidden_dim, 512),
|
||||
nn.LayerNorm(512),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(512, num_stages)
|
||||
)
|
||||
|
||||
# Subtask estimator head (regression)
|
||||
self.stage_embedding = nn.Embedding(num_stages, hidden_dim // 4)
|
||||
subtask_input_dim = hidden_dim + hidden_dim // 4
|
||||
self.subtask_head = nn.Sequential(
|
||||
nn.Linear(subtask_input_dim, 512),
|
||||
nn.LayerNorm(512),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(512, 1),
|
||||
nn.Sigmoid()
|
||||
)
|
||||
|
||||
# Attention mask
|
||||
self.register_buffer("attention_mask", None, persistent=False)
|
||||
|
||||
def _get_attention_mask(self, seq_length: int, device: torch.device) -> torch.Tensor:
|
||||
"""Generate or retrieve cached causal attention mask."""
|
||||
if self.attention_mask is None or self.attention_mask.shape[0] != seq_length:
|
||||
# Create causal mask
|
||||
mask = nn.Transformer.generate_square_subsequent_mask(seq_length, device=device)
|
||||
self.attention_mask = mask
|
||||
return self.attention_mask
|
||||
|
||||
def forward(
|
||||
self,
|
||||
video_frames: torch.Tensor,
|
||||
text_embed: torch.Tensor,
|
||||
state_features: Optional[torch.Tensor] = None
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Forward pass through the SARM transformer.
|
||||
|
||||
Args:
|
||||
video_frames: Video frame embeddings (batch_size, seq_len, video_dim)
|
||||
text_embed: Text embeddings (batch_size, text_dim)
|
||||
state_features: Joint state features (batch_size, seq_len, state_dim)
|
||||
|
||||
Returns:
|
||||
Tuple of:
|
||||
- Stage logits for each frame (batch_size, seq_len, num_stages)
|
||||
- Stage probabilities (batch_size, seq_len, num_stages)
|
||||
- Progress predictions for each frame (batch_size, seq_len, 1)
|
||||
"""
|
||||
# Project inputs to common dimension
|
||||
video_embed = self.video_proj(video_frames) # [batch_size, seq_len, hidden_dim]
|
||||
text_embed = self.text_proj(text_embed).unsqueeze(1) # [batch_size, 1, hidden_dim]
|
||||
|
||||
# Pad state features to max_state_dim before projection
|
||||
state_features_padded = pad_state_to_max_dim(state_features, self.max_state_dim)
|
||||
|
||||
state_embed = self.state_proj(state_features_padded) # [batch_size, seq_len, hidden_dim]
|
||||
|
||||
# Fuse video and state features
|
||||
video_embed = video_embed + state_embed
|
||||
|
||||
# Add positional embedding to first video frame
|
||||
video_embed[:, 0] += self.first_pos_embed
|
||||
|
||||
# Combine sequence: [text, video_frames]
|
||||
sequence = torch.cat([text_embed, video_embed], dim=1)
|
||||
|
||||
# Get causal attention mask
|
||||
seq_length = sequence.shape[1]
|
||||
attention_mask = self._get_attention_mask(seq_length, sequence.device)
|
||||
|
||||
# Pass through transformer with causal masking
|
||||
transformed = self.transformer(sequence, mask=attention_mask, is_causal=True)
|
||||
|
||||
# Get frame features
|
||||
frame_features = transformed[:, 1:] # [batch_size, seq_len, hidden_dim]
|
||||
|
||||
# Stage estimation
|
||||
stage_logits = self.stage_head(frame_features) # [batch_size, seq_len, num_stages]
|
||||
stage_probs = F.softmax(stage_logits, dim=-1) # [batch_size, seq_len, num_stages]
|
||||
|
||||
# Get predicted stage indices
|
||||
stage_indices = torch.argmax(stage_probs, dim=-1) # [batch_size, seq_len]
|
||||
|
||||
# Get stage embeddings for conditioning
|
||||
stage_embeds = self.stage_embedding(stage_indices)
|
||||
|
||||
# Concatenate frame features with stage embeddings
|
||||
conditioned_features = torch.cat([frame_features, stage_embeds], dim=-1)
|
||||
|
||||
# Subtask progress estimation (conditioned on stage)
|
||||
# τ̂ = within-subtask progress (0-1)
|
||||
tau_preds = self.subtask_head(conditioned_features) # [batch_size, seq_len, 1]
|
||||
|
||||
# Convert τ̂ to cumulative progress ŷ using Paper Formula (2):
|
||||
# ŷ = P_{k-1} + ᾱ_k × τ̂
|
||||
progress_preds = compute_cumulative_progress_batch(
|
||||
tau_preds, stage_indices, self.alpha, self.cumulative_prior
|
||||
)
|
||||
|
||||
return stage_logits, stage_probs, progress_preds
|
||||
|
||||
|
||||
class SARMRewardModel(PreTrainedPolicy):
    """
    SARM Reward Model for stage-aware task completion rewards.

    Per SARM paper (Appendix A.4): "We employ a frozen clip-vit-base-patch32 encoder
    to process both RGB image sequences and task descriptions."

    This model combines:
    - CLIP for encoding video frames AND text descriptions
    - SARMTransformer for predicting task stage and progress

    Only the SARMTransformer parameters are exposed for optimisation; the CLIP
    encoder is kept in eval mode and is only called under ``torch.no_grad()``.
    NOTE(review): the original description also mentions optional RA-BC
    (Reward-Aligned Behavior Cloning) weighting; no such logic is visible in
    this class.
    """

    name = "sarm"
    config_class = SARMConfig

    def __init__(self, config: SARMConfig, dataset_stats: dict | None = None, dataset_meta=None):
        """Load CLIP and build the SARM transformer.

        Args:
            config: SARM configuration.
            dataset_stats: Dataset statistics, stored on the instance.
            dataset_meta: Dataset metadata; when given and the config has no
                ``temporal_proportions`` yet, they are loaded from
                ``<dataset_meta.root>/meta/temporal_proportions.json``.
        """
        super().__init__(config, dataset_stats)
        config.validate_features()
        self.config = config
        self.dataset_stats = dataset_stats

        if config.device:
            device_name = config.device
        else:
            device_name = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device_name)

        # Load temporal proportions from dataset
        if config.temporal_proportions is None and dataset_meta is not None:
            self._load_temporal_proportions(dataset_meta)

        logging.info("Loading CLIP encoder")
        self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", use_fast=True)
        self.clip_model.to(self.device)
        self.clip_model.eval()

        self.sarm_transformer = SARMTransformer(
            video_dim=config.image_dim,
            text_dim=config.text_dim,
            max_state_dim=config.max_state_dim,
            hidden_dim=config.hidden_dim,
            num_heads=config.num_heads,
            num_layers=config.num_layers,
            num_stages=config.num_stages,
            max_length=config.max_length,
            dropout=config.dropout,
            temporal_proportions=config.temporal_proportions
        )
        self.sarm_transformer.to(self.device)
        logging.info(f"SARM initialized on {self.device}")

    def _load_temporal_proportions(self, dataset_meta) -> None:
        """
        Load pre-computed temporal proportions from dataset metadata JSON file.

        The temporal proportions are computed during dataset annotation using SARM Paper Formula (1):
            ᾱ_k = (1/M) × Σ_i (L_{i,k} / T_i)

        Updates ``config.num_stages``, ``config.subtask_names`` and
        ``config.temporal_proportions`` in place.

        Raises:
            ValueError: If the proportions file does not exist.
        """
        import json

        proportions_path = dataset_meta.root / "meta" / "temporal_proportions.json"

        if not proportions_path.exists():
            raise ValueError(
                f"Temporal proportions not found at {proportions_path}. "
                "Run the subtask annotation tool first to compute and save temporal proportions."
            )

        with open(proportions_path, "r", encoding="utf-8") as f:
            temporal_proportions_dict = json.load(f)

        # Sort subtask names for consistent ordering
        # NOTE(review): this is an alphabetical sort; the cumulative-progress formula
        # assumes stage k follows stage k-1 in time — confirm the annotation tool
        # names subtasks so that both orders agree.
        subtask_names = sorted(temporal_proportions_dict)

        self.config.num_stages = len(subtask_names)
        self.config.subtask_names = subtask_names
        self.config.temporal_proportions = [temporal_proportions_dict[name] for name in subtask_names]

        logging.info(f"Loaded {len(subtask_names)} subtasks: {subtask_names}")
        logging.info(f"Temporal proportions: {temporal_proportions_dict}")

    def to(self, device):
        """Override to method to ensure all components move together."""
        super().to(device)
        self.device = device if isinstance(device, torch.device) else torch.device(device)
        self.clip_model.to(device)
        self.sarm_transformer.to(device)
        return self

    @torch.no_grad()
    def encode_images(self, images: np.ndarray) -> np.ndarray:
        """
        Encode video frames using CLIP.

        Args:
            images: Video frames with shape (num_videos, num_frames, H, W, C) in uint8.
                    Can also be (num_frames, H, W, C) for a single video.
                    Frames whose first axis has size 3 are treated as channel-first
                    and transposed; non-uint8 frames are scaled by 255 when their
                    maximum is <= 1.0 and cast to uint8 otherwise.

        Returns:
            Encoded image features (num_videos, num_frames, 512) or (num_frames, 512).
        """
        # Handle single video case
        single_video = False
        if len(images.shape) == 4:
            images = images[np.newaxis, ...]
            single_video = True

        assert len(images.shape) == 5, f"Expected 5D input (num_videos, num_frames, H, W, C), got {images.shape}"

        clip_batch_size = self.config.clip_batch_size
        all_embeddings = []

        for video in images:
            # Convert frames to PIL images for CLIP processor
            frames = []
            for frame in video:
                if frame.shape[0] == 3:  # Channel first
                    frame = frame.transpose(1, 2, 0)
                if frame.dtype != np.uint8:
                    frame = (frame * 255).astype(np.uint8) if frame.max() <= 1.0 else frame.astype(np.uint8)
                frames.append(Image.fromarray(frame))

            # Batch process frames with CLIP
            video_embeddings = []
            for start in range(0, len(frames), clip_batch_size):
                batch = frames[start:start + clip_batch_size]
                inputs = self.clip_processor(images=batch, return_tensors="pt")
                inputs = {k: v.to(self.device) for k, v in inputs.items()}

                # Get image embeddings from CLIP
                embeddings = self.clip_model.get_image_features(**inputs).detach().cpu()

                # Handle single frame case
                if embeddings.dim() == 1:
                    embeddings = embeddings.unsqueeze(0)

                video_embeddings.append(embeddings)

            all_embeddings.append(torch.cat(video_embeddings))

        result = torch.stack(all_embeddings).numpy()

        if single_video:
            result = result[0]

        return result

    @torch.no_grad()
    def encode_text(self, text: Union[str, List[str]]) -> np.ndarray:
        """
        Encode text using CLIP text encoder (per SARM paper A.4).

        Args:
            text: Text string or list of text strings.

        Returns:
            Encoded text features (batch_size, 512) or (512,) for single text.
        """
        single_text = isinstance(text, str)
        if single_text:
            text = [text]

        # Use CLIP's tokenizer directly (avoids image processor validation issues)
        tokenizer = self.clip_processor.tokenizer

        # Process in batches; use the CLIP batch size, consistent with encode_images
        clip_batch_size = self.config.clip_batch_size
        all_embeddings = []
        for start in range(0, len(text), clip_batch_size):
            batch_text = text[start:start + clip_batch_size]

            inputs = tokenizer(batch_text, return_tensors="pt", padding=True, truncation=True)
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            text_embeddings = self.clip_model.get_text_features(**inputs)
            all_embeddings.append(text_embeddings.cpu())

        result = torch.cat(all_embeddings).numpy()

        if single_text:
            result = result[0]

        return result

    @torch.no_grad()
    def calculate_rewards(
        self,
        text_embeddings: Union[np.ndarray, torch.Tensor],
        video_embeddings: Union[np.ndarray, torch.Tensor],
        state_features: Optional[Union[np.ndarray, torch.Tensor]] = None,
        return_all_frames: bool = False,
        return_stages: bool = False
    ) -> Union[np.ndarray, tuple]:
        """
        Calculate rewards for given text, video, and state representations.

        Args:
            text_embeddings: Encoded text representations (batch_size, 512)
            video_embeddings: Encoded video representations (batch_size, num_frames, 512)
            state_features: Joint state features (batch_size, num_frames, state_dim)
            return_all_frames: If True, return rewards for all frames
            return_stages: If True, also return stage predictions

        Returns:
            If return_stages=False:
                Reward values (batch_size,) or (batch_size, num_frames)
            If return_stages=True:
                Tuple of (rewards, stage_probs)
        """
        if isinstance(text_embeddings, np.ndarray):
            text_embeddings = torch.tensor(text_embeddings, dtype=torch.float32)
        if isinstance(video_embeddings, np.ndarray):
            video_embeddings = torch.tensor(video_embeddings, dtype=torch.float32)
        if state_features is not None and isinstance(state_features, np.ndarray):
            state_features = torch.tensor(state_features, dtype=torch.float32)

        # Handle single sample case: a 1-D text embedding means unbatched inputs
        single_sample = text_embeddings.dim() == 1
        if single_sample:
            text_embeddings = text_embeddings.unsqueeze(0)
            video_embeddings = video_embeddings.unsqueeze(0)
            if state_features is not None:
                state_features = state_features.unsqueeze(0)

        # Process in batches
        batch_size = self.config.batch_size
        all_rewards = []
        all_stage_probs = []

        for start in range(0, len(video_embeddings), batch_size):
            stop = start + batch_size
            batch_texts = text_embeddings[start:stop].to(self.device)
            batch_videos = video_embeddings[start:stop].to(self.device)
            batch_states = None
            if state_features is not None:
                batch_states = state_features[start:stop].to(self.device).float()

            # Get predictions
            _, stage_probs, progress_preds = self.sarm_transformer(
                batch_videos.float(), batch_texts.float(), batch_states
            )

            if return_all_frames:
                all_rewards.append(progress_preds.squeeze(-1).cpu())
            else:
                # Return only last frame reward
                all_rewards.append(progress_preds[:, -1, 0].cpu())

            if return_stages:
                all_stage_probs.append(stage_probs.cpu())

        rewards = torch.cat(all_rewards).numpy()

        if single_sample:
            # Drop the batch axis that was added above
            rewards = rewards[0]

        if return_stages:
            stage_probs = torch.cat(all_stage_probs).numpy()
            if single_sample:
                stage_probs = stage_probs[0]
            return rewards, stage_probs

        return rewards

    def train(self, mode: bool = True):
        """Overwrite train method to ensure CLIP encoder stays frozen during training"""
        super().train(mode)
        # super().train() switches every child module, so CLIP must be put back in eval mode
        self.clip_model.eval()
        self.sarm_transformer.train(mode)
        return self

    def eval(self):
        """Overwrite eval method to ensure CLIP encoder stays frozen during evaluation"""
        return self.train(False)

    def parameters(self, recurse: bool = True):
        """Override to return trainable parameters (only SARM transformer, not CLIP encoder).

        ``recurse`` is accepted for compatibility with ``nn.Module.parameters``.
        """
        return self.sarm_transformer.parameters(recurse=recurse)

    def get_optim_params(self):
        """Override to return optimizer parameters (only SARM transformer, not CLIP encoder)."""
        return self.parameters()

    def reset(self):
        """Required by PreTrainedPolicy but not used for reward models."""
        pass

    def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
        """Required by PreTrainedPolicy but not used for reward models."""
        raise NotImplementedError("SARM model does not predict action chunks")

    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
        """Required by PreTrainedPolicy but not used for SARM."""
        raise NotImplementedError("SARM model does not select actions")

    def _apply_temporal_augmentation(
        self,
        video: torch.Tensor,
        progress: torch.Tensor,
        state: torch.Tensor | None,
        max_length: int,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
        """Apply temporal augmentation by appending reversed frames (SARM paper A.4).

        This helps the model learn to handle non-monotonic progress (failures, recoveries).
        Appends 1-4 reversed frames to simulate going backwards in task progress,
        then trims every sequence to ``max_length``.

        NOTE(review): because the result is trimmed to ``max_length``, the
        appended frames are discarded whenever the input already has
        ``max_length`` frames, making the augmentation a no-op in that case —
        confirm whether the leading frames should be shortened to make room.
        """
        if max_length < 2:
            # random.randint(1, 0) would raise; there is no room for a rewind anyway
            return video, progress, state

        num_reverse = random.randint(1, min(4, max_length - 1))
        # Skip index 0 of the reversed sequence: it duplicates the last original frame
        rewind = slice(1, num_reverse + 1)

        def _append_rewind(sequence: torch.Tensor) -> torch.Tensor:
            # Append the reversed tail, then trim back to max_length
            return torch.cat([sequence, sequence.flip(0)[rewind]], dim=0)[:max_length]

        video = _append_rewind(video)
        progress = _append_rewind(progress)
        if state is not None:
            state = _append_rewind(state)

        return video, progress, state

    def _ensure_sequence_length(self, tensor: torch.Tensor, target_len: int) -> torch.Tensor:
        """Pad or trim tensor to target length.

        Short sequences are padded by repeating their last element along dim 0.
        """
        current_len = tensor.shape[0]
        if current_len == target_len:
            return tensor
        if current_len < target_len:
            padding = target_len - current_len
            return torch.cat([tensor, tensor[-1:].expand(padding, *tensor.shape[1:])])
        return tensor[:target_len]

    def forward(self, batch):
        """
        Forward pass for SARM reward model training.

        Uses annotation-based progress targets following SARM paper Eq. 2:
            yt = Pk-1 + α̅k × τt
        where:
            - τt = (t - sk) / (ek - sk) is within-subtask normalized time
            - Pk-1 is cumulative prior (sum of previous subtask proportions)
            - α̅k is the temporal proportion for subtask k

        Args:
            batch: Dictionary with 'observation' containing:
                - 'video_features': (B, T, 512) pre-encoded video features
                - 'text_features': (B, 512) pre-encoded text features (CLIP)
                - 'state_features': (B, T, state_dim) joint state features
                - 'stage_labels': (B, T) stage labels from annotations
                - 'progress_targets': (B, T, 1) progress targets from annotations
              If there is no 'observation' key, ``batch`` itself is used.

        Returns:
            Tuple of (total_loss, output_dict with loss components)

        Raises:
            ValueError: If 'progress_targets' or 'stage_labels' is missing.
        """
        observation = batch.get('observation', batch)

        # Extract required features
        video_features = observation['video_features'].to(self.device)
        text_features = observation['text_features'].to(self.device)
        # State is treated as optional below, so do not call .to() on a missing entry
        state_features = observation.get('state_features')
        if state_features is not None:
            state_features = state_features.to(self.device)

        batch_size = video_features.shape[0]
        max_length = self.config.num_frames

        # Ensure 3D video features (B, T, D)
        if video_features.dim() == 2:
            video_features = video_features.unsqueeze(1).expand(-1, max_length, -1)
        if state_features is not None and state_features.dim() == 2:
            state_features = state_features.unsqueeze(1).expand(-1, max_length, -1)

        # Get annotation-based progress targets (required for SARM paper formula)
        progress_from_annotations = observation.get('progress_targets')
        if progress_from_annotations is None:
            raise ValueError("progress_targets from annotations is required for SARM training")

        # Stage labels are required as well; check before running the model
        stage_labels = observation.get('stage_labels')
        if stage_labels is None:
            raise ValueError("stage_labels from annotations is required for SARM training")

        progress_from_annotations = progress_from_annotations.to(self.device)
        if progress_from_annotations.dim() == 2:
            progress_from_annotations = progress_from_annotations.unsqueeze(-1)
        if progress_from_annotations.dim() == 3 and progress_from_annotations.shape[0] == 1:
            progress_from_annotations = progress_from_annotations.expand(batch_size, -1, -1)

        # Process each sample: apply temporal REWIND augmentation
        processed_videos = []
        processed_states = []
        progress_targets = []

        for i in range(batch_size):
            video = video_features[i]
            state = state_features[i] if state_features is not None else None
            progress = progress_from_annotations[i].squeeze(-1)  # (T,)

            # Apply temporal REWIND augmentation with 50% probability: appends up to 4 reversed frames to simulate failures/recoveries
            # NOTE(review): stage_labels are not rewound together with progress — confirm intended.
            if random.random() < 0.5:
                video, progress, state = self._apply_temporal_augmentation(video, progress, state, max_length)

            # Ensure correct sequence length
            video = self._ensure_sequence_length(video, max_length)
            progress = self._ensure_sequence_length(progress.unsqueeze(-1), max_length).squeeze(-1)
            if state is not None:
                state = self._ensure_sequence_length(state, max_length)

            processed_videos.append(video)
            progress_targets.append(progress)
            if state is not None:
                processed_states.append(state)

        # Stack into batches
        processed_videos = torch.stack(processed_videos)
        progress_targets = torch.stack(progress_targets).unsqueeze(-1)  # (B, T, 1)
        processed_states = torch.stack(processed_states) if processed_states else None

        # Get model predictions
        stage_logits, _, progress_preds = self.sarm_transformer(
            processed_videos, text_features, processed_states
        )

        # Compute progress loss (MSE)
        progress_loss = F.mse_loss(progress_preds, progress_targets)
        output_dict = {'progress_loss': progress_loss.item()}
        total_loss = progress_loss

        # Compute stage loss (cross-entropy)
        stage_labels = stage_labels.to(self.device)
        if stage_labels.dim() == 1:
            stage_labels = stage_labels.unsqueeze(0).expand(batch_size, -1)
        stage_loss = compute_stage_loss(stage_logits, stage_labels)
        total_loss = total_loss + self.config.stage_loss_weight * stage_loss
        output_dict['stage_loss'] = stage_loss.item()

        # Misaligned loss: 20% probability. Videos are paired with shuffled task
        # texts and the predicted progress is pushed towards zero.
        if random.random() < 0.2:
            shuffle_idx = torch.randperm(batch_size, device=self.device)
            _, _, misaligned_preds = self.sarm_transformer(
                processed_videos, text_features[shuffle_idx], processed_states
            )
            misaligned_loss = F.mse_loss(misaligned_preds, torch.zeros_like(misaligned_preds))
            total_loss = total_loss + misaligned_loss
            output_dict['misaligned_loss'] = misaligned_loss.item()

        output_dict['total_loss'] = total_loss.item()
        return total_loss, output_dict
|
||||
|
||||
def compute_stage_loss(stage_logits: torch.Tensor, target_stages: torch.Tensor) -> torch.Tensor:
    """Cross-entropy between per-frame stage logits and target stage indices.

    Args:
        stage_logits: Rank-3 tensor whose last axis indexes the stages.
        target_stages: Stage indices; flattened so that there is one target
            per flattened logit row.

    Returns:
        The loss tensor produced by ``F.cross_entropy`` with its default settings.
    """
    # The three-way unpack doubles as a rank check: a non 3-D tensor raises here.
    _batch, _steps, stage_count = stage_logits.shape
    flat_logits = stage_logits.reshape(-1, stage_count)
    flat_targets = target_stages.reshape(-1)
    return F.cross_entropy(flat_logits, flat_targets)
|
||||
@@ -1,644 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
import pandas as pd
|
||||
from transformers import CLIPModel, CLIPProcessor
|
||||
|
||||
from lerobot.processor.core import TransitionKey
|
||||
from lerobot.policies.sarm.configuration_sarm import SARMConfig
|
||||
from lerobot.policies.sarm.sarm_utils import compute_tau, compute_cumulative_progress_batch, pad_state_to_max_dim
|
||||
from lerobot.processor import (
|
||||
ProcessorStep,
|
||||
PolicyProcessorPipeline,
|
||||
PolicyAction,
|
||||
DeviceProcessorStep,
|
||||
AddBatchDimensionProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
)
|
||||
from lerobot.processor.converters import (
|
||||
policy_action_to_transition,
|
||||
transition_to_policy_action,
|
||||
from_tensor_to_numpy,
|
||||
)
|
||||
from lerobot.processor.pipeline import PipelineFeatureType
|
||||
from lerobot.processor.core import EnvTransition, TransitionKey
|
||||
from lerobot.configs.types import PolicyFeature, FeatureType
|
||||
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
|
||||
|
||||
class SARMEncodingProcessorStep(ProcessorStep):
|
||||
"""ProcessorStep that encodes images and text with CLIP."""
|
||||
def __init__(
    self,
    config: SARMConfig,
    image_key: str | None = None,
    dataset_meta = None,
    dataset_stats: dict | None = None,
):
    """Store the configuration and load the CLIP encoder.

    Args:
        config: SARM configuration; supplies ``image_key``, ``subtask_names``,
            ``temporal_proportions`` and ``device``.
        image_key: Observation key of the images; falls back to ``config.image_key``.
        dataset_meta: Dataset metadata object, stored as-is for later episode lookups.
        dataset_stats: Dataset statistics, stored as-is.
    """
    super().__init__()
    self.config = config
    self.image_key = image_key or config.image_key
    self.dataset_meta = dataset_meta
    self.dataset_stats = dataset_stats

    # Subtask name -> temporal proportion, paired positionally from the config.
    self.subtask_names = self.config.subtask_names
    self.temporal_proportions = dict(zip(self.config.subtask_names, self.config.temporal_proportions))

    # An explicit config device wins; otherwise prefer CUDA when it is available.
    self.device = torch.device(
        self.config.device or ("cuda" if torch.cuda.is_available() else "cpu")
    )

    self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", use_fast=True)
    self.clip_model.to(self.device)
    self.clip_model.eval()
|
||||
|
||||
def _find_episode_for_frame(self, frame_idx: int) -> int:
|
||||
"""Find the episode index for a given frame index."""
|
||||
for ep_idx in range(len(self.dataset_meta.episodes)):
|
||||
ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"]
|
||||
ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"]
|
||||
if ep_start <= frame_idx < ep_end:
|
||||
return ep_idx
|
||||
return 0
|
||||
|
||||
def _get_episode_indices(self, frame_indices: np.ndarray, episode_index) -> np.ndarray:
|
||||
"""Get episode indices for each frame index."""
|
||||
if episode_index is None:
|
||||
return np.array([self._find_episode_for_frame(int(f)) for f in frame_indices])
|
||||
|
||||
episode_indices = np.atleast_1d(np.asarray(from_tensor_to_numpy(episode_index)))
|
||||
|
||||
# If single episode but multiple frames, compute episode for each frame
|
||||
if len(episode_indices) == 1 and len(frame_indices) > 1:
|
||||
return np.array([self._find_episode_for_frame(int(f)) for f in frame_indices])
|
||||
|
||||
return episode_indices
|
||||
|
||||
def _compute_absolute_indices(self, frame_idx: int, ep_start: int, ep_end: int, num_frames: int) -> torch.Tensor:
|
||||
"""Compute absolute frame indices for symmetric bidirectional pattern.
|
||||
|
||||
Pattern: [ep_start, t-4*gap, t-3*gap, t-2*gap, t-gap, t, t+gap, t+2*gap, t+3*gap]
|
||||
|
||||
Boundary handling:
|
||||
- Backward indices clamp to ep_start (first frame)
|
||||
- Forward indices clamp to ep_end - 1 (last frame)
|
||||
"""
|
||||
indices = []
|
||||
indices.append(ep_start) # Initial frame is always episode start
|
||||
|
||||
# Symmetric context: 4 before, current, 3 after
|
||||
num_before = 4
|
||||
num_after = 3
|
||||
last_valid_frame = ep_end - 1
|
||||
|
||||
# Frames before current (clamp to first frame)
|
||||
for i in range(num_before, 0, -1):
|
||||
idx = max(ep_start, frame_idx - i * self.config.frame_gap)
|
||||
indices.append(idx)
|
||||
|
||||
# Current frame
|
||||
indices.append(frame_idx)
|
||||
|
||||
# Frames after current (clamp to last frame)
|
||||
for i in range(1, num_after + 1):
|
||||
idx = min(last_valid_frame, frame_idx + i * self.config.frame_gap)
|
||||
indices.append(idx)
|
||||
|
||||
return torch.tensor(indices)
|
||||
|
||||
def _compute_episode_metadata(
|
||||
self,
|
||||
frame_indices: np.ndarray,
|
||||
episode_indices: np.ndarray,
|
||||
num_frames: int,
|
||||
) -> tuple[list | torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Compute episode metadata for all samples.
|
||||
|
||||
Returns:
|
||||
Tuple of (absolute_frame_indices, remaining_lengths, episode_lengths)
|
||||
"""
|
||||
absolute_indices_list = []
|
||||
remaining_lengths = []
|
||||
episode_lengths = []
|
||||
|
||||
for ep_idx, frame_idx in zip(episode_indices.tolist(), frame_indices.tolist()):
|
||||
ep_idx, frame_idx = int(ep_idx), int(frame_idx)
|
||||
ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"]
|
||||
ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"]
|
||||
|
||||
episode_lengths.append(ep_end - ep_start)
|
||||
abs_indices = self._compute_absolute_indices(frame_idx, ep_start, ep_end, num_frames)
|
||||
absolute_indices_list.append(abs_indices)
|
||||
remaining_lengths.append(ep_end - abs_indices[0].item())
|
||||
|
||||
return absolute_indices_list, torch.tensor(remaining_lengths), torch.tensor(episode_lengths)
|
||||
|
||||
def _compute_stage_and_progress_for_frame(
|
||||
self,
|
||||
current_frame: int,
|
||||
subtask_names: list,
|
||||
subtask_start_frames: list,
|
||||
subtask_end_frames: list,
|
||||
transition_smoothing_frames: int = 15,
|
||||
) -> tuple[int, float, dict[int, float] | None]:
|
||||
"""Compute stage index, cumulative progress, and soft stage labels for a single frame.
|
||||
|
||||
Implements SARM Paper Formula (2):
|
||||
y_t = P_{k-1} + ᾱ_k × τ_t
|
||||
|
||||
where:
|
||||
- τ_t = (t - s_k) / (e_k - s_k) is within-subtask progress
|
||||
- P_{k-1} is cumulative prior (sum of previous subtask proportions)
|
||||
- ᾱ_k is the temporal proportion for subtask k
|
||||
|
||||
Additionally computes soft stage labels near transitions to mitigate discrete jumps
|
||||
in the stage classifier. Near stage boundaries, labels are blended between adjacent
|
||||
stages to encourage smoother predictions.
|
||||
|
||||
Args:
|
||||
current_frame: Frame index relative to episode start
|
||||
subtask_names: List of subtask names for this episode
|
||||
subtask_start_frames: List of subtask start frames
|
||||
subtask_end_frames: List of subtask end frames
|
||||
transition_smoothing_frames: Number of frames over which to smooth labels near transitions
|
||||
|
||||
Returns:
|
||||
Tuple of (stage_idx, cumulative_progress, soft_stage_labels)
|
||||
- stage_idx: Hard stage index (for compatibility)
|
||||
- cumulative_progress: Progress value in [0, 1]
|
||||
- soft_stage_labels: Dict mapping stage_idx -> probability, or None if not near transition
|
||||
"""
|
||||
# Get temporal proportions as list for compute_cumulative_progress
|
||||
temporal_proportions_list = [
|
||||
self.temporal_proportions.get(name, 0.0) for name in self.subtask_names
|
||||
]
|
||||
num_stages = len(self.subtask_names)
|
||||
|
||||
# Find which subtask this frame belongs to
|
||||
for j, (name, start_frame, end_frame) in enumerate(zip(subtask_names, subtask_start_frames, subtask_end_frames)):
|
||||
if current_frame >= start_frame and current_frame <= end_frame:
|
||||
# Found the subtask, get its global index
|
||||
stage_idx = self.subtask_names.index(name) if name in self.subtask_names else 0
|
||||
|
||||
# Compute τ_t using utility function (Paper Formula 2)
|
||||
tau = compute_tau(current_frame, start_frame, end_frame)
|
||||
|
||||
# Compute cumulative progress using utility function (Paper Formula 2)
|
||||
cumulative_progress = compute_cumulative_progress_batch(
|
||||
tau, stage_idx, temporal_proportions_list
|
||||
)
|
||||
|
||||
# Compute soft stage labels near transitions
|
||||
soft_stage_labels = None
|
||||
frames_from_start = current_frame - start_frame
|
||||
frames_to_end = end_frame - current_frame
|
||||
|
||||
if frames_from_start < transition_smoothing_frames and j > 0:
|
||||
# Near start of stage - blend with previous stage
|
||||
blend = frames_from_start / transition_smoothing_frames
|
||||
prev_name = subtask_names[j - 1]
|
||||
prev_stage_idx = self.subtask_names.index(prev_name) if prev_name in self.subtask_names else max(0, stage_idx - 1)
|
||||
soft_stage_labels = {prev_stage_idx: 1.0 - blend, stage_idx: blend}
|
||||
|
||||
elif frames_to_end < transition_smoothing_frames and j < len(subtask_names) - 1:
|
||||
# Near end of stage - blend with next stage
|
||||
blend = frames_to_end / transition_smoothing_frames
|
||||
next_name = subtask_names[j + 1]
|
||||
next_stage_idx = self.subtask_names.index(next_name) if next_name in self.subtask_names else min(num_stages - 1, stage_idx + 1)
|
||||
soft_stage_labels = {stage_idx: blend, next_stage_idx: 1.0 - blend}
|
||||
|
||||
return stage_idx, cumulative_progress, soft_stage_labels
|
||||
|
||||
# No matching subtask found
|
||||
if current_frame < subtask_start_frames[0]:
|
||||
return 0, 0.0, None
|
||||
elif current_frame > subtask_end_frames[-1]:
|
||||
return len(self.subtask_names) - 1, 1.0, None
|
||||
else:
|
||||
# Between subtasks - use previous subtask's end state (tau = 1.0)
|
||||
for j in range(len(subtask_names) - 1):
|
||||
if current_frame > subtask_end_frames[j] and current_frame < subtask_start_frames[j + 1]:
|
||||
name = subtask_names[j]
|
||||
stage_idx = self.subtask_names.index(name) if name in self.subtask_names else j
|
||||
|
||||
# Completed subtask, so tau = 1.0
|
||||
cumulative_progress = compute_cumulative_progress_batch(
|
||||
1.0, stage_idx, temporal_proportions_list
|
||||
)
|
||||
return stage_idx, cumulative_progress, None
|
||||
|
||||
return 0, 0.0, None
|
||||
|
||||
def _compute_labels_for_sample(
|
||||
self,
|
||||
frame_idx: int,
|
||||
ep_idx: int,
|
||||
seq_len: int,
|
||||
episodes_df: pd.DataFrame,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None] | tuple[None, None, None]:
|
||||
"""Compute stage labels, progress targets, and soft stage labels for symmetric bidirectional pattern.
|
||||
|
||||
Pattern: [initial, t-4*gap, t-3*gap, t-2*gap, t-gap, t, t+gap, t+2*gap, t+3*gap]
|
||||
|
||||
Boundary handling:
|
||||
- Before episode start: clamp to frame 0 (progress ~0%)
|
||||
- After episode end: clamp to last frame (progress ~100%)
|
||||
|
||||
Soft stage labels are computed near stage transitions to mitigate discrete jumps.
|
||||
|
||||
Args:
|
||||
frame_idx: The frame index for this sample
|
||||
ep_idx: The episode index
|
||||
seq_len: Number of frames in the sequence
|
||||
episodes_df: DataFrame with episode metadata
|
||||
|
||||
Returns:
|
||||
Tuple of (stage_labels, progress_targets, soft_stage_labels):
|
||||
- stage_labels: (T,) hard stage indices
|
||||
- progress_targets: (T, 1) progress values
|
||||
- soft_stage_labels: (T, num_stages) soft probability labels, or None if no transitions nearby
|
||||
"""
|
||||
# Check if episode has valid annotations
|
||||
if ep_idx >= len(episodes_df):
|
||||
return None, None, None
|
||||
|
||||
subtask_names = episodes_df.loc[ep_idx, 'subtask_names']
|
||||
if subtask_names is None or (isinstance(subtask_names, float) and pd.isna(subtask_names)):
|
||||
return None, None, None
|
||||
|
||||
subtask_start_frames = episodes_df.loc[ep_idx, 'subtask_start_frames']
|
||||
subtask_end_frames = episodes_df.loc[ep_idx, 'subtask_end_frames']
|
||||
ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"]
|
||||
ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"]
|
||||
ep_length = ep_end - ep_start
|
||||
last_valid_frame = ep_length - 1
|
||||
|
||||
num_stages = len(self.subtask_names)
|
||||
|
||||
# Generate labels for each frame in the sequence
|
||||
stage_labels = []
|
||||
progress_targets = []
|
||||
soft_labels_list = [] # List of soft label dicts (or None)
|
||||
has_any_soft_labels = False
|
||||
|
||||
# Symmetric pattern: initial + 4 before + current + 3 after = 9 frames
|
||||
num_before = 4
|
||||
num_after = 3
|
||||
|
||||
for i in range(seq_len):
|
||||
if i == 0:
|
||||
# Position 0: Initial frame of the episode
|
||||
current_frame = 0 # Relative to episode start
|
||||
elif i <= num_before:
|
||||
# Positions 1-4: frames before current (with clamping to first frame)
|
||||
offset = -(num_before - i + 1) * self.config.frame_gap
|
||||
current_frame = max(0, frame_idx + offset - ep_start)
|
||||
elif i == num_before + 1:
|
||||
# Position 5: current frame
|
||||
current_frame = frame_idx - ep_start
|
||||
else:
|
||||
# Positions 6-8: frames after current (with clamping to last frame)
|
||||
offset = (i - num_before - 1) * self.config.frame_gap
|
||||
current_frame = min(last_valid_frame, frame_idx + offset - ep_start)
|
||||
|
||||
stage_idx, cumulative_progress, soft_stage_labels = self._compute_stage_and_progress_for_frame(
|
||||
current_frame, subtask_names, subtask_start_frames, subtask_end_frames
|
||||
)
|
||||
|
||||
stage_labels.append(stage_idx)
|
||||
progress_targets.append(cumulative_progress)
|
||||
soft_labels_list.append(soft_stage_labels)
|
||||
if soft_stage_labels is not None:
|
||||
has_any_soft_labels = True
|
||||
|
||||
stage_labels = torch.tensor(stage_labels, dtype=torch.long)
|
||||
progress_targets = torch.tensor(progress_targets, dtype=torch.float32).unsqueeze(-1)
|
||||
|
||||
# Convert soft labels to tensor if any exist
|
||||
soft_stage_labels_tensor = None
|
||||
if has_any_soft_labels:
|
||||
soft_stage_labels_tensor = torch.zeros(seq_len, num_stages, dtype=torch.float32)
|
||||
for i, soft_dict in enumerate(soft_labels_list):
|
||||
if soft_dict is not None:
|
||||
for stage_idx, prob in soft_dict.items():
|
||||
soft_stage_labels_tensor[i, stage_idx] = prob
|
||||
else:
|
||||
# Use hard one-hot label
|
||||
soft_stage_labels_tensor[i, stage_labels[i]] = 1.0
|
||||
|
||||
return stage_labels, progress_targets, soft_stage_labels_tensor
|
||||
|
||||
def _generate_stage_and_progress_labels(self, frame_index, episode_index, video_features):
|
||||
"""Generate stage labels, progress targets, and soft stage labels from subtask annotations.
|
||||
|
||||
Args:
|
||||
frame_index: Current frame index or tensor of indices
|
||||
episode_index: Episode index or tensor of indices
|
||||
video_features: Video features tensor to determine sequence length
|
||||
|
||||
Returns:
|
||||
Tuple of (stage_labels, progress_targets, soft_stage_labels) or (None, None, None) if no annotations.
|
||||
- stage_labels: (B, T) hard stage indices
|
||||
- progress_targets: (B, T, 1) progress values
|
||||
- soft_stage_labels: (B, T, num_stages) soft probability labels, or None
|
||||
"""
|
||||
if self.temporal_proportions is None or episode_index is None:
|
||||
return None, None, None
|
||||
|
||||
# Normalize inputs to numpy arrays
|
||||
frame_indices = np.atleast_1d(np.asarray(from_tensor_to_numpy(frame_index)))
|
||||
episode_indices = self._get_episode_indices(frame_indices, episode_index)
|
||||
|
||||
# Determine sequence length
|
||||
if video_features is not None and video_features.dim() >= 2:
|
||||
seq_len = video_features.shape[1]
|
||||
else:
|
||||
seq_len = 1
|
||||
|
||||
episodes_df = self.dataset_meta.episodes.to_pandas()
|
||||
num_stages = len(self.subtask_names)
|
||||
|
||||
all_stage_labels = []
|
||||
all_progress_targets = []
|
||||
all_soft_stage_labels = []
|
||||
has_any_soft_labels = False
|
||||
|
||||
for ep_idx, frame_idx in zip(episode_indices.tolist(), frame_indices.tolist()):
|
||||
stage_labels, progress_targets, soft_labels = self._compute_labels_for_sample(
|
||||
int(frame_idx), int(ep_idx), seq_len, episodes_df
|
||||
)
|
||||
|
||||
if stage_labels is None:
|
||||
all_stage_labels.append(torch.zeros(seq_len, dtype=torch.long))
|
||||
all_progress_targets.append(torch.zeros(seq_len, 1, dtype=torch.float32))
|
||||
all_soft_stage_labels.append(None)
|
||||
else:
|
||||
all_stage_labels.append(stage_labels)
|
||||
all_progress_targets.append(progress_targets)
|
||||
all_soft_stage_labels.append(soft_labels)
|
||||
if soft_labels is not None:
|
||||
has_any_soft_labels = True
|
||||
|
||||
stacked_stage_labels = torch.stack(all_stage_labels, dim=0)
|
||||
stacked_progress_targets = torch.stack(all_progress_targets, dim=0)
|
||||
|
||||
# Stack soft labels if any exist
|
||||
stacked_soft_labels = None
|
||||
if has_any_soft_labels:
|
||||
soft_labels_tensors = []
|
||||
for i, soft_labels in enumerate(all_soft_stage_labels):
|
||||
if soft_labels is not None:
|
||||
soft_labels_tensors.append(soft_labels)
|
||||
else:
|
||||
# Create one-hot from hard labels
|
||||
one_hot = torch.zeros(seq_len, num_stages, dtype=torch.float32)
|
||||
for t in range(seq_len):
|
||||
one_hot[t, all_stage_labels[i][t]] = 1.0
|
||||
soft_labels_tensors.append(one_hot)
|
||||
stacked_soft_labels = torch.stack(soft_labels_tensors, dim=0)
|
||||
|
||||
return stacked_stage_labels, stacked_progress_targets, stacked_soft_labels
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
    """Encode images and text with CLIP and attach SARM metadata to the observation.

    The returned transition carries a new observation mapping with these
    additional keys: ``video_features``, ``state_features``,
    ``text_features`` and, when ``dataset_meta`` is set,
    ``absolute_frame_indices``, ``remaining_length``, ``episode_length`` plus
    (when labels can be generated) ``stage_labels``, ``progress_targets`` and
    ``soft_stage_labels``. The input transition and its observation mapping
    are left unmodified.

    Args:
        transition: Transition holding an observation (with the configured
            image and state keys) and complementary data with ``task``,
            ``index`` and ``episode_index``.

    Returns:
        A shallow copy of ``transition`` with the enriched observation.

    Raises:
        ValueError: If the observation, ``index`` or ``episode_index`` is missing.
    """
    new_transition = transition.copy() if hasattr(transition, 'copy') else dict(transition)
    source_observation = new_transition.get(TransitionKey.OBSERVATION)
    if source_observation is None:
        raise ValueError("Observation not found in transition")

    # The transition copy above is shallow, so writing into the original
    # observation mapping would also modify the caller's transition.
    observation = dict(source_observation)

    image = observation.get(self.image_key)
    if isinstance(image, torch.Tensor):
        image = image.cpu().numpy()
    video_features = self._encode_images_batch(image)
    observation['video_features'] = video_features

    # Extract state and pad to max_state_dim (already normalized by NormalizerProcessorStep)
    state_key = self.config.state_key
    state_data = observation.get(state_key)
    if isinstance(state_data, torch.Tensor):
        state_tensor = state_data.float()
    else:
        state_tensor = torch.tensor(state_data, dtype=torch.float32)
    observation['state_features'] = pad_state_to_max_dim(state_tensor, self.config.max_state_dim)

    comp_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})

    # Get task description from dataset (complementary_data["task"])
    task = comp_data.get('task')
    if isinstance(task, list):
        # If batch, take first task (assuming same task for all items in batch)
        task = task[0] if task else ""

    # Encode text with CLIP
    batch_size = video_features.shape[0]
    observation['text_features'] = self._encode_text_clip(task, batch_size)

    frame_index = comp_data.get('index')
    episode_index = comp_data.get('episode_index')
    if frame_index is None:
        raise ValueError("Frame index ('index') not found in COMPLEMENTARY_DATA")
    if episode_index is None:
        raise ValueError("Episode index ('episode_index') not found in COMPLEMENTARY_DATA")

    # Compute episode metadata if dataset_meta is available
    if self.dataset_meta is not None:
        frame_indices = np.atleast_1d(np.asarray(from_tensor_to_numpy(frame_index)))
        episode_indices = self._get_episode_indices(frame_indices, episode_index)

        # Determine number of frames from video features
        num_frames = video_features.shape[1] if video_features.dim() >= 2 else 1

        abs_indices, remaining, ep_lengths = self._compute_episode_metadata(
            frame_indices, episode_indices, num_frames
        )
        observation['absolute_frame_indices'] = abs_indices
        observation['remaining_length'] = remaining
        observation['episode_length'] = ep_lengths

    # Generate stage labels, progress targets, and soft stage labels from subtask annotations
    if self.temporal_proportions is not None and self.dataset_meta is not None:
        stage_labels, progress_targets, soft_stage_labels = self._generate_stage_and_progress_labels(
            frame_index, episode_index, video_features
        )
        if stage_labels is not None:
            observation['stage_labels'] = stage_labels
            observation['progress_targets'] = progress_targets
            if soft_stage_labels is not None:
                observation['soft_stage_labels'] = soft_stage_labels

    new_transition[TransitionKey.OBSERVATION] = observation
    return new_transition
|
||||
|
||||
@torch.no_grad()
def _encode_images_batch(self, images: np.ndarray) -> torch.Tensor:
    """Encode a batch of image sequences with the CLIP image encoder.

    Args:
        images: Array shaped ``(B, T, ...)``; each frame is either
            channel-first (first axis of size 1 or 3) or already channel-last.

    Returns:
        CPU tensor shaped ``(B, T, D)`` where ``D`` is the feature width
        returned by the CLIP model (512 according to the original docstring).
    """
    batch_size, seq_length = images.shape[0], images.shape[1]
    flat = images.reshape(batch_size * seq_length, *images.shape[2:])
    num_frames = flat.shape[0]

    pil_frames = []
    for frame in flat:
        # A leading axis of size 1 or 3 is treated as the channel axis.
        if frame.shape[0] in [1, 3]:
            frame = frame.transpose(1, 2, 0)

        # Grayscale: replicate the single channel three times.
        if frame.shape[-1] == 1:
            frame = np.repeat(frame, 3, axis=-1)

        if frame.dtype != np.uint8:
            if frame.max() <= 1.0:
                # Values up to 1.0 are taken to be in unit range and rescaled.
                frame = (frame * 255).astype(np.uint8)
            else:
                frame = frame.astype(np.uint8)

        pil_frames.append(Image.fromarray(frame))

    # Run CLIP in chunks of clip_batch_size frames.
    chunk_size = self.config.clip_batch_size
    chunks = []
    for offset in range(0, num_frames, chunk_size):
        inputs = self.clip_processor(images=pil_frames[offset:offset + chunk_size], return_tensors="pt")
        inputs = {name: tensor.to(self.device) for name, tensor in inputs.items()}

        features = self.clip_model.get_image_features(**inputs).detach().cpu()
        if features.dim() == 1:
            features = features.unsqueeze(0)
        chunks.append(features)

    # (B*T, D) -> (B, T, D)
    return torch.cat(chunks).reshape(batch_size, seq_length, -1)
|
||||
|
||||
@torch.no_grad()
|
||||
def _encode_text_clip(self, text: str, batch_size: int) -> torch.Tensor:
|
||||
"""Encode text using CLIP text encoder (per SARM paper A.4).
|
||||
|
||||
Args:
|
||||
text: Task description text to encode
|
||||
batch_size: Batch size to replicate for
|
||||
|
||||
Returns:
|
||||
Encoded text features with shape (B, 512)
|
||||
"""
|
||||
# Use CLIP's tokenizer directly for text
|
||||
tokenizer = self.clip_processor.tokenizer
|
||||
inputs = tokenizer([text], return_tensors="pt", padding=True, truncation=True)
|
||||
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
||||
|
||||
# Get text features from CLIP
|
||||
text_embedding = self.clip_model.get_text_features(**inputs).detach().cpu()
|
||||
|
||||
# Replicate for batch (B, 512)
|
||||
text_embedding = text_embedding.expand(batch_size, -1)
|
||||
|
||||
return text_embedding
|
||||
|
||||
def transform_features(
    self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
    """Register the features produced by this step on the observation features.

    Adds ``video_features``, ``text_features`` and ``state_features`` entries
    to ``features[PipelineFeatureType.OBSERVATION]`` in place and returns the
    same ``features`` mapping.
    """
    num_frames = self.config.num_frames
    observation_features = features[PipelineFeatureType.OBSERVATION]

    observation_features['video_features'] = PolicyFeature(
        type=FeatureType.VISUAL,
        shape=(num_frames, self.config.image_dim)
    )
    observation_features['text_features'] = PolicyFeature(
        type=FeatureType.LANGUAGE,
        shape=(self.config.text_dim,)
    )
    observation_features['state_features'] = PolicyFeature(
        type=FeatureType.STATE,
        shape=(num_frames, self.config.max_state_dim)
    )
    return features
|
||||
|
||||
|
||||
def make_sarm_pre_post_processors(
    config: SARMConfig,
    dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
    dataset_meta = None,
) -> tuple[
    PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
    PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
    """Build the SARM pre-processing and post-processing pipelines.

    Pre-processing steps, in order:
        1. ``AddBatchDimensionProcessorStep``
        2. ``NormalizerProcessorStep`` over the config's input and output
           features, using ``config.normalization_mapping`` and ``dataset_stats``
        3. ``SARMEncodingProcessorStep`` (CLIP image/text encoding, state padding)
        4. ``DeviceProcessorStep`` targeting ``config.device``

    Post-processing consists of a single ``DeviceProcessorStep`` targeting the CPU.

    Args:
        config: SARM configuration.
        dataset_stats: Statistics handed to the normalizer and the encoding step.
        dataset_meta: Dataset metadata handed to the encoding step.

    Returns:
        Tuple ``(preprocessor, postprocessor)``.
    """
    preprocessing_steps = [
        AddBatchDimensionProcessorStep(),
        NormalizerProcessorStep(
            features={**config.input_features, **config.output_features},
            norm_map=config.normalization_mapping,
            stats=dataset_stats,
        ),
        SARMEncodingProcessorStep(
            config=config,
            dataset_meta=dataset_meta,
            dataset_stats=dataset_stats
        ),
        DeviceProcessorStep(device=config.device),
    ]
    preprocessor = PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
        steps=preprocessing_steps,
        name=POLICY_PREPROCESSOR_DEFAULT_NAME,
    )

    postprocessor = PolicyProcessorPipeline[PolicyAction, PolicyAction](
        steps=[DeviceProcessorStep(device="cpu")],
        name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
        to_transition=policy_action_to_transition,
        to_output=transition_to_policy_action,
    )

    return preprocessor, postprocessor
|
||||
@@ -1,257 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from typing import Sequence, Any
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# Pydantic Models for SARM-style Annotation
|
||||
class Timestamp(BaseModel):
    """Timestamp in MM:SS or SS format"""
    # NOTE(review): pydantic presumably exposes this docstring and the Field
    # descriptions in the generated schema, so treat them as runtime text —
    # confirm before rewording.
    # Segment start, kept as a string ("MM:SS" or plain seconds per the description).
    start: str = Field(description="Start timestamp (MM:SS or just seconds)")
    # Segment end, same string format as ``start``.
    end: str = Field(description="End timestamp (MM:SS or just seconds)")
|
||||
|
||||
|
||||
class Subtask(BaseModel):
    """Individual subtask/stage - must use EXACT names from provided list"""
    # NOTE(review): pydantic presumably exposes this docstring and the Field
    # description in the generated schema, so treat them as runtime text —
    # confirm before rewording.
    # Name of the subtask; the description asks for an exact match with a predefined list.
    name: str = Field(description="Subtask name - MUST match one from the predefined list exactly")
    # Start/end of the subtask as a nested Timestamp model.
    timestamps: Timestamp
|
||||
|
||||
|
||||
class SubtaskAnnotation(BaseModel):
    """Complete annotation for a robot manipulation episode"""
    # NOTE(review): pydantic presumably exposes this docstring and the Field
    # description in the generated schema, so treat them as runtime text —
    # confirm before rewording.
    # Subtasks of one episode; the description asks for temporal order.
    subtasks: list[Subtask] = Field(description="List of all subtasks in temporal order")
|
||||
|
||||
|
||||
def compute_temporal_proportions(annotations: dict[int, Any], fps: int = 30) -> dict[str, float]:
    """Compute dataset-level temporal proportions (priors) for each subtask.

    Based on SARM Paper Formula (1):
        ᾱ_k = (1/M) × Σ_i (L_{i,k} / T_i)

    For every trajectory the share of each subtask in that trajectory's total
    annotated duration is computed; the shares are then averaged per subtask
    and normalized so they sum to 1.0. Each trajectory therefore counts
    equally regardless of its absolute length. Note that the average for a
    subtask is taken over the trajectories in which that subtask appears.

    A subtask that occurs several times in one trajectory contributes the sum
    of its segments. Trajectories whose total annotated duration is not
    positive are ignored.

    Args:
        annotations: Mapping from episode index to an annotation object with a
            ``.subtasks`` list; each subtask has ``.name`` and
            ``.timestamps.start`` / ``.timestamps.end`` strings in ``SS``,
            ``MM:SS`` or ``HH:MM:SS`` form (integer fields).
        fps: Unused, kept for API compatibility.

    Returns:
        Mapping from subtask name to its normalized temporal proportion, or an
        empty dict when no trajectory has a positive duration.

    Raises:
        ValueError: If a timestamp field is not an integer.
    """

    def to_seconds(stamp: str) -> int:
        """Convert colon-separated base-60 fields (SS, MM:SS, HH:MM:SS) to seconds."""
        seconds = 0
        for part in stamp.split(":"):
            seconds = seconds * 60 + int(part)
        return seconds

    subtask_proportions: dict[str, list[float]] = {}

    for annotation in annotations.values():
        durations: dict[str, int] = {}
        for subtask in annotation.subtasks:
            start_seconds = to_seconds(subtask.timestamps.start)
            end_seconds = to_seconds(subtask.timestamps.end)
            # Accumulate instead of overwrite: a name may appear more than once
            # per episode and the total below counts every segment.
            durations[subtask.name] = durations.get(subtask.name, 0) + (end_seconds - start_seconds)

        total_duration = sum(durations.values())

        # L_{i,k} / T_i for each subtask of this trajectory.
        if total_duration > 0:
            for name, duration in durations.items():
                subtask_proportions.setdefault(name, []).append(duration / total_duration)

    if not subtask_proportions:
        return {}

    avg_proportions = {
        name: sum(props) / len(props)
        for name, props in subtask_proportions.items()
    }

    # Normalize to ensure sum = 1.
    total = sum(avg_proportions.values())
    if total > 0:
        avg_proportions = {name: prop / total for name, prop in avg_proportions.items()}

    return avg_proportions
|
||||
|
||||
|
||||
def compute_tau(
|
||||
current_frame: int | float,
|
||||
subtask_start: int | float,
|
||||
subtask_end: int | float,
|
||||
) -> float:
|
||||
"""
|
||||
Compute within-subtask normalized time τ_t.
|
||||
|
||||
Implements part of SARM Paper Formula (2):
|
||||
τ_t = (t - s_k) / (e_k - s_k) ∈ [0, 1]
|
||||
|
||||
where:
|
||||
- t is the current frame
|
||||
- s_k is the start frame of subtask k
|
||||
- e_k is the end frame of subtask k
|
||||
|
||||
Args:
|
||||
current_frame: Current frame index (t)
|
||||
subtask_start: Start frame of the subtask (s_k)
|
||||
subtask_end: End frame of the subtask (e_k)
|
||||
|
||||
Returns:
|
||||
Within-subtask progress τ_t ∈ [0, 1]
|
||||
"""
|
||||
subtask_duration = subtask_end - subtask_start
|
||||
|
||||
if subtask_duration <= 0:
|
||||
return 1.0
|
||||
|
||||
tau = (current_frame - subtask_start) / subtask_duration
|
||||
|
||||
return float(np.clip(tau, 0.0, 1.0))
|
||||
|
||||
|
||||
def compute_cumulative_progress_batch(
|
||||
tau: torch.Tensor | float,
|
||||
stage_indices: torch.Tensor | int,
|
||||
alpha: torch.Tensor | Sequence[float],
|
||||
cumulative_prior: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | float:
|
||||
"""
|
||||
Compute cumulative normalized progress from within-subtask progress.
|
||||
|
||||
This function implements the core formula used in SARM for both:
|
||||
|
||||
**Formula 2 (Training labels):**
|
||||
y_t = P_{k-1} + ᾱ_k × τ_t ∈ [0, 1]
|
||||
|
||||
Used to compute ground-truth progress labels from subtask annotations.
|
||||
- τ_t comes from annotated frame position: τ_t = (t - s_k) / (e_k - s_k)
|
||||
- k is the known subtask from annotations
|
||||
|
||||
**Formula 4 (Inference predictions):**
|
||||
ŷ_{1:N} = P̂_{k-1, 1:N} + ᾱ_{k, 1:N} × τ̂_{1:N} ∈ [0, 1]
|
||||
|
||||
Used to convert model outputs to cumulative progress during inference.
|
||||
- τ̂ comes from the subtask MLP head (conditioned on predicted stage)
|
||||
- k = Ŝ is the predicted stage from Formula 3: Ŝ = argmax(softmax(Ψ))
|
||||
|
||||
The formulas are mathematically identical; only the source of inputs differs:
|
||||
- Training: τ and k from annotations → ground-truth labels
|
||||
- Inference: τ̂ and Ŝ from model → predicted progress
|
||||
|
||||
where:
|
||||
- P_{k-1} = Σ_{j=1}^{k-1} ᾱ_j is the cumulative prior (sum of previous proportions)
|
||||
- ᾱ_k is the temporal proportion for subtask k (from Formula 1)
|
||||
- τ is within-subtask progress ∈ [0, 1]
|
||||
|
||||
This ensures:
|
||||
- y at start of subtask k = P_{k-1}
|
||||
- y at end of subtask k = P_k
|
||||
|
||||
Supports both scalar and batched tensor inputs:
|
||||
- Scalar: tau (float), stage_indices (int), alpha (list/sequence)
|
||||
- Batch: tau (Tensor), stage_indices (Tensor), alpha (Tensor), cumulative_prior (Tensor)
|
||||
|
||||
Args:
|
||||
tau: Within-subtask progress τ ∈ [0, 1].
|
||||
For training: computed from frame position in annotated subtask.
|
||||
For inference: predicted by subtask MLP head.
|
||||
Scalar float or Tensor with shape (..., 1)
|
||||
stage_indices: Index of current subtask k (0-indexed).
|
||||
For training: known from annotations.
|
||||
For inference: predicted via argmax(stage_probs) (Formula 3).
|
||||
Scalar int or Tensor with shape (...)
|
||||
alpha: Temporal proportions ᾱ with shape (num_stages,) or Sequence[float].
|
||||
Computed from dataset annotations using Formula 1.
|
||||
cumulative_prior: Optional. Cumulative priors P with shape (num_stages + 1,)
|
||||
where cumulative_prior[k] = P_k = Σ_{j=1}^{k} ᾱ_j.
|
||||
If None, will be computed from alpha.
|
||||
|
||||
Returns:
|
||||
Cumulative progress y ∈ [0, 1].
|
||||
Scalar float if inputs are scalar, otherwise Tensor with shape (..., 1)
|
||||
"""
|
||||
if not isinstance(tau, torch.Tensor):
|
||||
if not alpha:
|
||||
raise ValueError("alpha (temporal_proportions) cannot be empty")
|
||||
|
||||
if isinstance(alpha, torch.Tensor):
|
||||
alpha_list = alpha.tolist()
|
||||
else:
|
||||
alpha_list = list(alpha)
|
||||
|
||||
if stage_indices < 0 or stage_indices >= len(alpha_list):
|
||||
raise ValueError(
|
||||
f"stage_indices {stage_indices} out of range "
|
||||
f"for {len(alpha_list)} subtasks"
|
||||
)
|
||||
|
||||
# P_{k-1} = sum of proportions for subtasks 0 to k-1
|
||||
P_k_minus_1 = sum(alpha_list[:stage_indices])
|
||||
|
||||
# ᾱ_k = proportion for current subtask
|
||||
alpha_k = alpha_list[stage_indices]
|
||||
|
||||
# y_t = P_{k-1} + ᾱ_k × τ_t
|
||||
y_t = P_k_minus_1 + alpha_k * tau
|
||||
|
||||
return float(np.clip(y_t, 0.0, 1.0))
|
||||
|
||||
if not isinstance(alpha, torch.Tensor):
|
||||
alpha = torch.tensor(alpha, dtype=torch.float32)
|
||||
|
||||
# Compute cumulative_prior if not provided
|
||||
if cumulative_prior is None:
|
||||
cumulative_prior = torch.zeros(len(alpha) + 1, dtype=alpha.dtype, device=alpha.device)
|
||||
cumulative_prior[1:] = torch.cumsum(alpha, dim=0)
|
||||
|
||||
# P_{k-1} for each predicted stage
|
||||
P_k_minus_1 = cumulative_prior[stage_indices]
|
||||
|
||||
# ᾱ_k for each predicted stage
|
||||
alpha_k = alpha[stage_indices]
|
||||
|
||||
# ŷ = P_{k-1} + ᾱ_k × τ̂
|
||||
progress = P_k_minus_1.unsqueeze(-1) + alpha_k.unsqueeze(-1) * tau
|
||||
|
||||
return progress
|
||||
|
||||
def pad_state_to_max_dim(state: torch.Tensor, max_state_dim: int) -> torch.Tensor:
    """Force the last dimension of ``state`` to ``max_state_dim``.

    A last dimension that is already at least ``max_state_dim`` wide is
    sliced down to its first ``max_state_dim`` entries; a narrower one is
    extended on the right with zeros. Leading dimensions are left as they are.
    """
    missing = max_state_dim - state.shape[-1]

    if missing <= 0:
        # Already wide enough: keep only the leading max_state_dim entries.
        return state[..., :max_state_dim]

    # F.pad takes (left, right) amounts for the last dimension.
    return F.pad(state, (0, missing), mode='constant', value=0)
|
||||
@@ -230,10 +230,6 @@ def validate_visual_features_consistency(
|
||||
) -> None:
|
||||
"""
|
||||
Validates visual feature consistency between a policy config and provided dataset/environment features.
|
||||
|
||||
Validation passes if EITHER:
|
||||
- Policy's expected visuals are a subset of dataset (policy uses some cameras, dataset has more)
|
||||
- Dataset's provided visuals are a subset of policy (policy declares extras for flexibility)
|
||||
|
||||
Args:
|
||||
cfg (PreTrainedConfig): The model or policy configuration containing input_features and type.
|
||||
@@ -241,11 +237,5 @@ def validate_visual_features_consistency(
|
||||
"""
|
||||
expected_visuals = {k for k, v in cfg.input_features.items() if v.type == FeatureType.VISUAL}
|
||||
provided_visuals = {k for k, v in features.items() if v.type == FeatureType.VISUAL}
|
||||
|
||||
# Accept if either direction is a subset
|
||||
policy_subset_of_dataset = expected_visuals.issubset(provided_visuals)
|
||||
dataset_subset_of_policy = provided_visuals.issubset(expected_visuals)
|
||||
|
||||
if not (policy_subset_of_dataset or dataset_subset_of_policy):
|
||||
if not provided_visuals.issubset(expected_visuals):
|
||||
raise_feature_mismatch_error(provided_visuals, expected_visuals)
|
||||
|
||||
|
||||
@@ -170,9 +170,8 @@ def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]:
|
||||
task_key = {"task": batch["task"]} if "task" in batch else {}
|
||||
index_key = {"index": batch["index"]} if "index" in batch else {}
|
||||
task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {}
|
||||
episode_index_key = {"episode_index": batch["episode_index"]} if "episode_index" in batch else {}
|
||||
|
||||
return {**pad_keys, **task_key, **index_key, **task_index_key, **episode_index_key}
|
||||
return {**pad_keys, **task_key, **index_key, **task_index_key}
|
||||
|
||||
|
||||
def create_transition(
|
||||
|
||||
@@ -14,21 +14,5 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from lerobot.policies.sarm.configuration_sarm import SARMConfig
|
||||
from lerobot.policies.sarm.modeling_sarm import (
|
||||
SARMRewardModel,
|
||||
SARMTransformer,
|
||||
)
|
||||
from lerobot.policies.sarm.processor_sarm import (
|
||||
SARMEncodingProcessorStep,
|
||||
make_sarm_pre_post_processors,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"SARMConfig",
|
||||
"SARMRewardModel",
|
||||
"SARMTransformer",
|
||||
"SARMEncodingProcessorStep",
|
||||
"make_sarm_pre_post_processors",
|
||||
]
|
||||
|
||||
from .config_unitree_g1 import UnitreeG1Config
|
||||
from .unitree_g1 import UnitreeG1
|
||||
108
src/lerobot/robots/unitree_g1/arm_calibration.json
Normal file
108
src/lerobot/robots/unitree_g1/arm_calibration.json
Normal file
@@ -0,0 +1,108 @@
|
||||
{
|
||||
"kLeftShoulderPitch.pos": {
|
||||
"id": 0,
|
||||
"drive_mode": 0,
|
||||
"homing_offset": 0,
|
||||
"range_min": -3,
|
||||
"range_max": 1
|
||||
},
|
||||
"kLeftShoulderYaw.pos": {
|
||||
"id": 1,
|
||||
"drive_mode": 0,
|
||||
"homing_offset": 0,
|
||||
"range_min": -2.6,
|
||||
"range_max": 2.6
|
||||
},
|
||||
"kLeftShoulderRoll.pos": {
|
||||
"id": 2,
|
||||
"drive_mode": 0,
|
||||
"homing_offset": 0,
|
||||
"range_min": -0.1,
|
||||
"range_max": 2.2
|
||||
},
|
||||
"kLeftElbow.pos": {
|
||||
"id": 3,
|
||||
"drive_mode": 0,
|
||||
"homing_offset": 0,
|
||||
"range_min": -1,
|
||||
"range_max": 1
|
||||
},
|
||||
"kLeftWristRoll.pos": {
|
||||
"id": 4,
|
||||
"drive_mode": 0,
|
||||
"homing_offset": 0,
|
||||
"range_min": -1.9,
|
||||
"range_max": 1.9
|
||||
},
|
||||
"kLeftWristYaw.pos": {
|
||||
"id": 5,
|
||||
"drive_mode": 0,
|
||||
"homing_offset": 0,
|
||||
"range_min": 0.0,
|
||||
"range_max": 0.0
|
||||
},
|
||||
"kLeftWristyaw.pos": {
|
||||
"id": 5,
|
||||
"drive_mode": 0,
|
||||
"homing_offset": 0,
|
||||
"range_min": 0.0,
|
||||
"range_max": 0.0
|
||||
},
|
||||
"kLeftWristPitch.pos": {
|
||||
"id": 6,
|
||||
"drive_mode": 0,
|
||||
"homing_offset": 0,
|
||||
"range_min": 0.0,
|
||||
"range_max": 0.0
|
||||
},
|
||||
|
||||
"kRightShoulderPitch.pos": {
|
||||
"id": 0,
|
||||
"drive_mode": 0,
|
||||
"homing_offset": 0,
|
||||
"range_min": -3.0,
|
||||
"range_max": 1
|
||||
},
|
||||
"kRightShoulderYaw.pos": {
|
||||
"id": 1,
|
||||
"drive_mode": 0,
|
||||
"homing_offset": 0,
|
||||
"range_min": -2.6,
|
||||
"range_max": 2.6
|
||||
},
|
||||
"kRightShoulderRoll.pos": {
|
||||
"id": 2,
|
||||
"drive_mode": 0,
|
||||
"homing_offset": 0,
|
||||
"range_min": -2.2,
|
||||
"range_max": 0.5
|
||||
},
|
||||
"kRightElbow.pos": {
|
||||
"id": 3,
|
||||
"drive_mode": 0,
|
||||
"homing_offset": 0,
|
||||
"range_min": -1,
|
||||
"range_max": 1
|
||||
},
|
||||
"kRightWristRoll.pos": {
|
||||
"id": 4,
|
||||
"drive_mode": 0,
|
||||
"homing_offset": 0,
|
||||
"range_min": -1.9,
|
||||
"range_max": 1.9
|
||||
},
|
||||
"kRightWristYaw.pos": {
|
||||
"id": 5,
|
||||
"drive_mode": 0,
|
||||
"homing_offset": 0,
|
||||
"range_min": 0.0,
|
||||
"range_max": 0.0
|
||||
},
|
||||
"kRightWristPitch.pos": {
|
||||
"id": 6,
|
||||
"drive_mode": 0,
|
||||
"homing_offset": 0,
|
||||
"range_min": 0.0,
|
||||
"range_max": 0.0
|
||||
}
|
||||
}
|
||||
2
src/lerobot/robots/unitree_g1/assets/g1/.gitignore
vendored
Normal file
2
src/lerobot/robots/unitree_g1/assets/g1/.gitignore
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
*.gv
|
||||
*.pdf
|
||||
33
src/lerobot/robots/unitree_g1/assets/g1/README.md
Normal file
33
src/lerobot/robots/unitree_g1/assets/g1/README.md
Normal file
@@ -0,0 +1,33 @@
|
||||
# Unitree G1 Description (URDF & MJCF)
|
||||
|
||||
## Overview
|
||||
|
||||
This package includes a universal humanoid robot description (URDF & MJCF) for the [Unitree G1](https://www.unitree.com/g1/), developed by [Unitree Robotics](https://www.unitree.com/).
|
||||
|
||||
MJCF/URDF for the G1 robot:
|
||||
|
||||
| MJCF/URDF file name | `mode_machine` | Hip roll reduction ratio | Update status | dof#leg | dof#waist | dof#arm | dof#hand |
|
||||
| ----------------------------- | :------------: | :----------------------: | ------------- | :-----: | :-------: | :-----: | :------: |
|
||||
| `g1_23dof` | 1 | 14.5 | Beta | 6*2 | 1 | 5*2 | 0 |
|
||||
| `g1_29dof` | 2 | 14.5 | Beta | 6*2 | 3 | 7*2 | 0 |
|
||||
| `g1_29dof_with_hand` | 2 | 14.5 | Beta | 6*2 | 3 | 7*2 | 7*2 |
|
||||
| `g1_29dof_lock_waist` | 3 | 14.5 | Beta | 6*2 | 1 | 7*2 | 0 |
|
||||
| `g1_23dof_rev_1_0` | 4 | 22.5 | Up-to-date | 6*2 | 1 | 5*2 | 0 |
|
||||
| `g1_29dof_rev_1_0` | 5 | 22.5 | Up-to-date | 6*2 | 3 | 7*2 | 0 |
|
||||
| `g1_29dof_with_hand_rev_1_0` | 5 | 22.5 | Up-to-date | 6*2 | 3 | 7*2 | 7*2 |
|
||||
| `g1_29dof_lock_waist_rev_1_0` | 6 | 22.5 | Up-to-date | 6*2 | 1 | 7*2 | 0 |
|
||||
| `g1_dual_arm` | 9 | null | Up-to-date | 0 | 0 | 7*2 | 0 |
|
||||
|
||||
## Visualization with [MuJoCo](https://github.com/google-deepmind/mujoco)
|
||||
|
||||
1. Open MuJoCo Viewer
|
||||
|
||||
```bash
|
||||
pip install mujoco
|
||||
python -m mujoco.viewer
|
||||
```
|
||||
|
||||
2. Drag and drop the MJCF/URDF model file (`g1_XXX.xml`/`g1_XXX.urdf`) to the MuJoCo Viewer.
|
||||
|
||||
## Note for teleoperation
|
||||
`g1_body29_hand14` is a modified version of [g1_29dof_with_hand_rev_1_0](https://github.com/unitreerobotics/unitree_ros/blob/master/robots/g1_description/g1_29dof_with_hand_rev_1_0.urdf).
|
||||
903
src/lerobot/robots/unitree_g1/assets/g1/g1_body23.urdf
Normal file
903
src/lerobot/robots/unitree_g1/assets/g1/g1_body23.urdf
Normal file
@@ -0,0 +1,903 @@
|
||||
<robot name="g1_23dof">
|
||||
<mujoco>
|
||||
<compiler meshdir="meshes" discardvisual="false"/>
|
||||
</mujoco>
|
||||
|
||||
<!-- [CAUTION] uncomment when convert to mujoco -->
|
||||
<!-- <link name="world"></link>
|
||||
<joint name="floating_base_joint" type="floating">
|
||||
<parent link="world"/>
|
||||
<child link="pelvis"/>
|
||||
</joint> -->
|
||||
|
||||
<link name="pelvis">
|
||||
<inertial>
|
||||
<origin xyz="0 0 -0.07605" rpy="0 0 0"/>
|
||||
<mass value="3.813"/>
|
||||
<inertia ixx="0.010549" ixy="0" ixz="2.1E-06" iyy="0.0093089" iyz="0" izz="0.0079184"/>
|
||||
</inertial>
|
||||
<visual>
|
||||
<origin xyz="0 0 0" rpy="0 0 0"/>
|
||||
<geometry>
|
||||
<mesh filename="meshes/pelvis.STL"/>
|
||||
</geometry>
|
||||
<material name="dark">
|
||||
<color rgba="0.2 0.2 0.2 1"/>
|
||||
</material>
|
||||
</visual>
|
||||
</link>
|
||||
<link name="pelvis_contour_link">
|
||||
<inertial>
|
||||
<origin xyz="0 0 0" rpy="0 0 0"/>
|
||||
<mass value="0.001"/>
|
||||
<inertia ixx="1e-7" ixy="0" ixz="0" iyy="1e-7" iyz="0" izz="1e-7"/>
|
||||
</inertial>
|
||||
<visual>
|
||||
<origin xyz="0 0 0" rpy="0 0 0"/>
|
||||
<geometry>
|
||||
<mesh filename="meshes/pelvis_contour_link.STL"/>
|
||||
</geometry>
|
||||
<material name="white">
|
||||
<color rgba="0.7 0.7 0.7 1"/>
|
||||
</material>
|
||||
</visual>
|
||||
<collision>
|
||||
<origin xyz="0 0 0" rpy="0 0 0"/>
|
||||
<geometry>
|
||||
<mesh filename="meshes/pelvis_contour_link.STL"/>
|
||||
</geometry>
|
||||
</collision>
|
||||
</link>
|
||||
<joint name="pelvis_contour_joint" type="fixed">
|
||||
<parent link="pelvis"/>
|
||||
<child link="pelvis_contour_link"/>
|
||||
</joint>
|
||||
|
||||
<!-- Legs -->
|
||||
<link name="left_hip_pitch_link">
|
||||
<inertial>
|
||||
<origin xyz="0.002741 0.047791 -0.02606" rpy="0 0 0"/>
|
||||
<mass value="1.35"/>
|
||||
<inertia ixx="0.001811" ixy="3.68E-05" ixz="-3.44E-05" iyy="0.0014193" iyz="0.000171" izz="0.0012812"/>
|
||||
</inertial>
|
||||
<visual>
|
||||
<origin xyz="0 0 0" rpy="0 0 0"/>
|
||||
<geometry>
|
||||
<mesh filename="meshes/left_hip_pitch_link.STL"/>
|
||||
</geometry>
|
||||
<material name="dark">
|
||||
<color rgba="0.2 0.2 0.2 1"/>
|
||||
</material>
|
||||
</visual>
|
||||
<collision>
|
||||
<origin xyz="0 0 0" rpy="0 0 0"/>
|
||||
<geometry>
|
||||
<mesh filename="meshes/left_hip_pitch_link.STL"/>
|
||||
</geometry>
|
||||
</collision>
|
||||
</link>
|
||||
<joint name="left_hip_pitch_joint" type="revolute">
|
||||
<origin xyz="0 0.064452 -0.1027" rpy="0 0 0"/>
|
||||
<parent link="pelvis"/>
|
||||
<child link="left_hip_pitch_link"/>
|
||||
<axis xyz="0 1 0"/>
|
||||
<limit lower="-2.5307" upper="2.8798" effort="88" velocity="32"/>
|
||||
</joint>
|
||||
<link name="left_hip_roll_link">
|
||||
<inertial>
|
||||
<origin xyz="0.029812 -0.001045 -0.087934" rpy="0 0 0"/>
|
||||
<mass value="1.52"/>
|
||||
<inertia ixx="0.0023773" ixy="-3.8E-06" ixz="-0.0003908" iyy="0.0024123" iyz="1.84E-05" izz="0.0016595"/>
|
||||
</inertial>
|
||||
<visual>
|
||||
<origin xyz="0 0 0" rpy="0 0 0"/>
|
||||
<geometry>
|
||||
<mesh filename="meshes/left_hip_roll_link.STL"/>
|
||||
</geometry>
|
||||
<material name="white">
|
||||
<color rgba="0.7 0.7 0.7 1"/>
|
||||
</material>
|
||||
</visual>
|
||||
<collision>
|
||||
<origin xyz="0 0 0" rpy="0 0 0"/>
|
||||
<geometry>
|
||||
<mesh filename="meshes/left_hip_roll_link.STL"/>
|
||||
</geometry>
|
||||
</collision>
|
||||
</link>
|
||||
<joint name="left_hip_roll_joint" type="revolute">
|
||||
<origin xyz="0 0.052 -0.030465" rpy="0 -0.1749 0"/>
|
||||
<parent link="left_hip_pitch_link"/>
|
||||
<child link="left_hip_roll_link"/>
|
||||
<axis xyz="1 0 0"/>
|
||||
<limit lower="-0.5236" upper="2.9671" effort="88" velocity="32"/>
|
||||
</joint>
|
||||
<link name="left_hip_yaw_link">
|
||||
<inertial>
|
||||
<origin xyz="-0.057709 -0.010981 -0.15078" rpy="0 0 0"/>
|
||||
<mass value="1.702"/>
|
||||
<inertia ixx="0.0057774" ixy="-0.0005411" ixz="-0.0023948" iyy="0.0076124" iyz="-0.0007072" izz="0.003149"/>
|
||||
</inertial>
|
||||
<visual>
|
||||
<origin xyz="0 0 0" rpy="0 0 0"/>
|
||||
<geometry>
|
||||
<mesh filename="meshes/left_hip_yaw_link.STL"/>
|
||||
</geometry>
|
||||
<material name="white">
|
||||
<color rgba="0.7 0.7 0.7 1"/>
|
||||
</material>
|
||||
</visual>
|
||||
<collision>
|
||||
<origin xyz="0 0 0" rpy="0 0 0"/>
|
||||
<geometry>
|
||||
<mesh filename="meshes/left_hip_yaw_link.STL"/>
|
||||
</geometry>
|
||||
</collision>
|
||||
</link>
|
||||
<joint name="left_hip_yaw_joint" type="revolute">
|
||||
<origin xyz="0.025001 0 -0.12412" rpy="0 0 0"/>
|
||||
<parent link="left_hip_roll_link"/>
|
||||
<child link="left_hip_yaw_link"/>
|
||||
<axis xyz="0 0 1"/>
|
||||
<limit lower="-2.7576" upper="2.7576" effort="88" velocity="32"/>
|
||||
</joint>
|
||||
<link name="left_knee_link">
|
||||
<inertial>
|
||||
<origin xyz="0.005457 0.003964 -0.12074" rpy="0 0 0"/>
|
||||
<mass value="1.932"/>
|
||||
<inertia ixx="0.011329" ixy="4.82E-05" ixz="-4.49E-05" iyy="0.011277" iyz="-0.0007146" izz="0.0015168"/>
|
||||
</inertial>
|
||||
<visual>
|
||||
<origin xyz="0 0 0" rpy="0 0 0"/>
|
||||
<geometry>
|
||||
<mesh filename="meshes/left_knee_link.STL"/>
|
||||
</geometry>
|
||||
<material name="white">
|
||||
<color rgba="0.7 0.7 0.7 1"/>
|
||||
</material>
|
||||
</visual>
|
||||
<collision>
|
||||
<origin xyz="0 0 0" rpy="0 0 0"/>
|
||||
<geometry>
|
||||
<mesh filename="meshes/left_knee_link.STL"/>
|
||||
</geometry>
|
||||
</collision>
|
||||
</link>
|
||||
<joint name="left_knee_joint" type="revolute">
|
||||
<origin xyz="-0.078273 0.0021489 -0.17734" rpy="0 0.1749 0"/>
|
||||
<parent link="left_hip_yaw_link"/>
|
||||
<child link="left_knee_link"/>
|
||||
<axis xyz="0 1 0"/>
|
||||
<limit lower="-0.087267" upper="2.8798" effort="139" velocity="20"/>
|
||||
</joint>
|
||||
<link name="left_ankle_pitch_link">
|
||||
<inertial>
|
||||
<origin xyz="-0.007269 0 0.011137" rpy="0 0 0"/>
|
||||
<mass value="0.074"/>
|
||||
<inertia ixx="8.4E-06" ixy="0" ixz="-2.9E-06" iyy="1.89E-05" iyz="0" izz="1.26E-05"/>
|
||||
</inertial>
|
||||
<visual>
|
||||
<origin xyz="0 0 0" rpy="0 0 0"/>
|
||||
<geometry>
|
||||
<mesh filename="meshes/left_ankle_pitch_link.STL"/>
|
||||
</geometry>
|
||||
<material name="white">
|
||||
<color rgba="0.7 0.7 0.7 1"/>
|
||||
</material>
|
||||
</visual>
|
||||
<collision>
|
||||
<origin xyz="0 0 0" rpy="0 0 0"/>
|
||||
<geometry>
|
||||
<mesh filename="meshes/left_ankle_pitch_link.STL"/>
|
||||
</geometry>
|
||||
</collision>
|
||||
</link>
|
||||
<joint name="left_ankle_pitch_joint" type="revolute">
|
||||
<origin xyz="0 -9.4445E-05 -0.30001" rpy="0 0 0"/>
|
||||
<parent link="left_knee_link"/>
|
||||
<child link="left_ankle_pitch_link"/>
|
||||
<axis xyz="0 1 0"/>
|
||||
<limit lower="-0.87267" upper="0.5236" effort="50" velocity="37"/>
|
||||
</joint>
|
||||
<link name="left_ankle_roll_link">
|
||||
<inertial>
|
||||
<origin xyz="0.026505 0 -0.016425" rpy="0 0 0"/>
|
||||
<mass value="0.608"/>
|
||||
<inertia ixx="0.0002231" ixy="2E-07" ixz="8.91E-05" iyy="0.0016161" iyz="-1E-07" izz="0.0016667"/>
|
||||
</inertial>
|
||||
<visual>
|
||||
<origin xyz="0 0 0" rpy="0 0 0"/>
|
||||
<geometry>
|
||||
<mesh filename="meshes/left_ankle_roll_link.STL"/>
|
||||
</geometry>
|
||||
<material name="dark">
|
||||
<color rgba="0.2 0.2 0.2 1"/>
|
||||
</material>
|
||||
</visual>
|
||||
<collision>
|
||||
<origin xyz="-0.05 0.025 -0.03" rpy="0 0 0"/>
|
||||
<geometry>
|
||||
<sphere radius="0.005"/>
|
||||
</geometry>
|
||||
</collision>
|
||||
<collision>
|
||||
<origin xyz="-0.05 -0.025 -0.03" rpy="0 0 0"/>
|
||||
<geometry>
|
||||
<sphere radius="0.005"/>
|
||||
</geometry>
|
||||
</collision>
|
||||
<collision>
|
||||
<origin xyz="0.12 0.03 -0.03" rpy="0 0 0"/>
|
||||
<geometry>
|
||||
<sphere radius="0.005"/>
|
||||
</geometry>
|
||||
</collision>
|
||||
<collision>
|
||||
<origin xyz="0.12 -0.03 -0.03" rpy="0 0 0"/>
|
||||
<geometry>
|
||||
<sphere radius="0.005"/>
|
||||
</geometry>
|
||||
</collision>
|
||||
</link>
|
||||
<joint name="left_ankle_roll_joint" type="revolute">
|
||||
<origin xyz="0 0 -0.017558" rpy="0 0 0"/>
|
||||
<parent link="left_ankle_pitch_link"/>
|
||||
<child link="left_ankle_roll_link"/>
|
||||
<axis xyz="1 0 0"/>
|
||||
<limit lower="-0.2618" upper="0.2618" effort="50" velocity="37"/>
|
||||
</joint>
|
||||
<link name="right_hip_pitch_link">
|
||||
<inertial>
|
||||
<origin xyz="0.002741 -0.047791 -0.02606" rpy="0 0 0"/>
|
||||
<mass value="1.35"/>
|
||||
<inertia ixx="0.001811" ixy="-3.68E-05" ixz="-3.44E-05" iyy="0.0014193" iyz="-0.000171" izz="0.0012812"/>
|
||||
</inertial>
|
||||
<visual>
|
||||
<origin xyz="0 0 0" rpy="0 0 0"/>
|
||||
<geometry>
|
||||
<mesh filename="meshes/right_hip_pitch_link.STL"/>
|
||||
</geometry>
|
||||
<material name="dark">
|
||||
<color rgba="0.2 0.2 0.2 1"/>
|
||||
</material>
|
||||
</visual>
|
||||
<collision>
|
||||
<origin xyz="0 0 0" rpy="0 0 0"/>
|
||||
<geometry>
|
||||
<mesh filename="meshes/right_hip_pitch_link.STL"/>
|
||||
</geometry>
|
||||
</collision>
|
||||
</link>
|
||||
<joint name="right_hip_pitch_joint" type="revolute">
|
||||
<origin xyz="0 -0.064452 -0.1027" rpy="0 0 0"/>
|
||||
<parent link="pelvis"/>
|
||||
<child link="right_hip_pitch_link"/>
|
||||
<axis xyz="0 1 0"/>
|
||||
<limit lower="-2.5307" upper="2.8798" effort="88" velocity="32"/>
|
||||
</joint>
|
||||
<link name="right_hip_roll_link">
|
||||
<inertial>
|
||||
<origin xyz="0.029812 0.001045 -0.087934" rpy="0 0 0"/>
|
||||
<mass value="1.52"/>
|
||||
<inertia ixx="0.0023773" ixy="3.8E-06" ixz="-0.0003908" iyy="0.0024123" iyz="-1.84E-05" izz="0.0016595"/>
|
||||
</inertial>
|
||||
<visual>
|
||||
<origin xyz="0 0 0" rpy="0 0 0"/>
|
||||
<geometry>
|
||||
<mesh filename="meshes/right_hip_roll_link.STL"/>
|
||||
</geometry>
|
||||
<material name="white">
|
||||
<color rgba="0.7 0.7 0.7 1"/>
|
||||
</material>
|
||||
</visual>
|
||||
<collision>
|
||||
<origin xyz="0 0 0" rpy="0 0 0"/>
|
||||
<geometry>
|
||||
<mesh filename="meshes/right_hip_roll_link.STL"/>
|
||||
</geometry>
|
||||
</collision>
|
||||
</link>
|
||||
<joint name="right_hip_roll_joint" type="revolute">
|
||||
<origin xyz="0 -0.052 -0.030465" rpy="0 -0.1749 0"/>
|
||||
<parent link="right_hip_pitch_link"/>
|
||||
<child link="right_hip_roll_link"/>
|
||||
<axis xyz="1 0 0"/>
|
||||
<limit lower="-2.9671" upper="0.5236" effort="88" velocity="32"/>
|
||||
</joint>
|
||||
<link name="right_hip_yaw_link">
|
||||
<inertial>
|
||||
<origin xyz="-0.057709 0.010981 -0.15078" rpy="0 0 0"/>
|
||||
<mass value="1.702"/>
|
||||
<inertia ixx="0.0057774" ixy="0.0005411" ixz="-0.0023948" iyy="0.0076124" iyz="0.0007072" izz="0.003149"/>
|
||||
</inertial>
|
||||
<visual>
|
||||
<origin xyz="0 0 0" rpy="0 0 0"/>
|
||||
<geometry>
|
||||
<mesh filename="meshes/right_hip_yaw_link.STL"/>
|
||||
</geometry>
|
||||
<material name="white">
|
||||
<color rgba="0.7 0.7 0.7 1"/>
|
||||
</material>
|
||||
</visual>
|
||||
<collision>
|
||||
<origin xyz="0 0 0" rpy="0 0 0"/>
|
||||
<geometry>
|
||||
<mesh filename="meshes/right_hip_yaw_link.STL"/>
|
||||
</geometry>
|
||||
</collision>
|
||||
</link>
|
||||
<joint name="right_hip_yaw_joint" type="revolute">
|
||||
<origin xyz="0.025001 0 -0.12412" rpy="0 0 0"/>
|
||||
<parent link="right_hip_roll_link"/>
|
||||
<child link="right_hip_yaw_link"/>
|
||||
<axis xyz="0 0 1"/>
|
||||
<limit lower="-2.7576" upper="2.7576" effort="88" velocity="32"/>
|
||||
</joint>
|
||||
<link name="right_knee_link">
|
||||
<inertial>
|
||||
<origin xyz="0.005457 -0.003964 -0.12074" rpy="0 0 0"/>
|
||||
<mass value="1.932"/>
|
||||
<inertia ixx="0.011329" ixy="-4.82E-05" ixz="4.49E-05" iyy="0.011277" iyz="0.0007146" izz="0.0015168"/>
|
||||
</inertial>
|
||||
<visual>
|
||||
<origin xyz="0 0 0" rpy="0 0 0"/>
|
||||
<geometry>
|
||||
<mesh filename="meshes/right_knee_link.STL"/>
|
||||
</geometry>
|
||||
<material name="white">
|
||||
<color rgba="0.7 0.7 0.7 1"/>
|
||||
</material>
|
||||
</visual>
|
||||
<collision>
|
||||
<origin xyz="0 0 0" rpy="0 0 0"/>
|
||||
<geometry>
|
||||
<mesh filename="meshes/right_knee_link.STL"/>
|
||||
</geometry>
|
||||
</collision>
|
||||
</link>
|
||||
<joint name="right_knee_joint" type="revolute">
|
||||
<origin xyz="-0.078273 -0.0021489 -0.17734" rpy="0 0.1749 0"/>
|
||||
<parent link="right_hip_yaw_link"/>
|
||||
<child link="right_knee_link"/>
|
||||
<axis xyz="0 1 0"/>
|
||||
<limit lower="-0.087267" upper="2.8798" effort="139" velocity="20"/>
|
||||
</joint>
|
||||
<link name="right_ankle_pitch_link">
|
||||
<inertial>
|
||||
<origin xyz="-0.007269 0 0.011137" rpy="0 0 0"/>
|
||||
<mass value="0.074"/>
|
||||
<inertia ixx="8.4E-06" ixy="0" ixz="-2.9E-06" iyy="1.89E-05" iyz="0" izz="1.26E-05"/>
|
||||
</inertial>
|
||||
<visual>
|
||||
<origin xyz="0 0 0" rpy="0 0 0"/>
|
||||
<geometry>
|
||||
<mesh filename="meshes/right_ankle_pitch_link.STL"/>
|
||||
</geometry>
|
||||
<material name="white">
|
||||
<color rgba="0.7 0.7 0.7 1"/>
|
||||
</material>
|
||||
</visual>
|
||||
<collision>
|
||||
<origin xyz="0 0 0" rpy="0 0 0"/>
|
||||
<geometry>
|
||||
<mesh filename="meshes/right_ankle_pitch_link.STL"/>
|
||||
</geometry>
|
||||
</collision>
|
||||
</link>
|
||||
<joint name="right_ankle_pitch_joint" type="revolute">
|
||||
<origin xyz="0 9.4445E-05 -0.30001" rpy="0 0 0"/>
|
||||
<parent link="right_knee_link"/>
|
||||
<child link="right_ankle_pitch_link"/>
|
||||
<axis xyz="0 1 0"/>
|
||||
<limit lower="-0.87267" upper="0.5236" effort="50" velocity="37"/>
|
||||
</joint>
|
||||
<link name="right_ankle_roll_link">
|
||||
<inertial>
|
||||
<origin xyz="0.026505 0 -0.016425" rpy="0 0 0"/>
|
||||
<mass value="0.608"/>
|
||||
<inertia ixx="0.0002231" ixy="-2E-07" ixz="8.91E-05" iyy="0.0016161" iyz="1E-07" izz="0.0016667"/>
|
||||
</inertial>
|
||||
<visual>
|
||||
<origin xyz="0 0 0" rpy="0 0 0"/>
|
||||
<geometry>
|
||||
<mesh filename="meshes/right_ankle_roll_link.STL"/>
|
||||
</geometry>
|
||||
<material name="dark">
|
||||
<color rgba="0.2 0.2 0.2 1"/>
|
||||
</material>
|
||||
</visual>
|
||||
<collision>
|
||||
<origin xyz="-0.05 0.025 -0.03" rpy="0 0 0"/>
|
||||
<geometry>
|
||||
<sphere radius="0.005"/>
|
||||
</geometry>
|
||||
</collision>
|
||||
<collision>
|
||||
<origin xyz="-0.05 -0.025 -0.03" rpy="0 0 0"/>
|
||||
<geometry>
|
||||
<sphere radius="0.005"/>
|
||||
</geometry>
|
||||
</collision>
|
||||
<collision>
|
||||
<origin xyz="0.12 0.03 -0.03" rpy="0 0 0"/>
|
||||
<geometry>
|
||||
<sphere radius="0.005"/>
|
||||
</geometry>
|
||||
</collision>
|
||||
<collision>
|
||||
<origin xyz="0.12 -0.03 -0.03" rpy="0 0 0"/>
|
||||
<geometry>
|
||||
<sphere radius="0.005"/>
|
||||
</geometry>
|
||||
</collision>
|
||||
</link>
|
||||
<joint name="right_ankle_roll_joint" type="revolute">
|
||||
<origin xyz="0 0 -0.017558" rpy="0 0 0"/>
|
||||
<parent link="right_ankle_pitch_link"/>
|
||||
<child link="right_ankle_roll_link"/>
|
||||
<axis xyz="1 0 0"/>
|
||||
<limit lower="-0.2618" upper="0.2618" effort="50" velocity="37"/>
|
||||
</joint>
|
||||
|
||||
<!-- Torso -->
|
||||
<link name="waist_yaw_fixed_link">
|
||||
<inertial>
|
||||
<origin xyz="0.003964 0 0.018769" rpy="0 0 0"/>
|
||||
<mass value="0.244"/>
|
||||
<inertia ixx="9.9587E-05" ixy="-1.833E-06" ixz="-1.2617E-05" iyy="0.00012411" iyz="-1.18E-07" izz="0.00015586"/>
|
||||
</inertial>
|
||||
<visual>
|
||||
<origin xyz="0 0 0" rpy="0 0 0"/>
|
||||
<geometry>
|
||||
<mesh filename="meshes/waist_yaw_link.STL"/>
|
||||
</geometry>
|
||||
<material name="white">
|
||||
<color rgba="0.7 0.7 0.7 1"/>
|
||||
</material>
|
||||
</visual>
|
||||
</link>
|
||||
<joint name="waist_yaw_fixed_joint" type="fixed">
|
||||
<origin xyz="0.0039635 0 -0.054" rpy="0 0 0"/>
|
||||
<parent link="torso_link"/>
|
||||
<child link="waist_yaw_fixed_link"/>
|
||||
</joint>
|
||||
<joint name="waist_yaw_joint" type="revolute">
|
||||
<origin xyz="-0.0039635 0 0.054" rpy="0 0 0"/>
|
||||
<parent link="pelvis"/>
|
||||
<child link="torso_link"/>
|
||||
<axis xyz="0 0 1"/>
|
||||
<limit lower="-2.618" upper="2.618" effort="88" velocity="32"/>
|
||||
</joint>
|
||||
<link name="torso_link">
|
||||
<inertial>
|
||||
<origin xyz="0.002601 0.000257 0.153719" rpy="0 0 0"/>
|
||||
<mass value="8.562"/>
|
||||
<inertia ixx="0.065674966" ixy="-8.597E-05" ixz="-0.001737252" iyy="0.053535188" iyz="8.6899E-05" izz="0.030808125"/>
|
||||
</inertial>
|
||||
<visual>
|
||||
<origin xyz="0 0 0" rpy="0 0 0"/>
|
||||
<geometry>
|
||||
<mesh filename="meshes/torso_link.STL"/>
|
||||
</geometry>
|
||||
<material name="white">
|
||||
<color rgba="0.7 0.7 0.7 1"/>
|
||||
</material>
|
||||
</visual>
|
||||
<collision>
|
||||
<origin xyz="0 0 0" rpy="0 0 0"/>
|
||||
<geometry>
|
||||
<mesh filename="meshes/torso_link.STL"/>
|
||||
</geometry>
|
||||
</collision>
|
||||
</link>
|
||||
|
||||
<!-- LOGO -->
|
||||
<joint name="logo_joint" type="fixed">
|
||||
<origin xyz="0.0039635 0 -0.054" rpy="0 0 0"/>
|
||||
<parent link="torso_link"/>
|
||||
<child link="logo_link"/>
|
||||
</joint>
|
||||
<link name="logo_link">
|
||||
<inertial>
|
||||
<origin xyz="0 0 0" rpy="0 0 0"/>
|
||||
<mass value="0.001"/>
|
||||
<inertia ixx="1e-7" ixy="0" ixz="0" iyy="1e-7" iyz="0" izz="1e-7"/>
|
||||
</inertial>
|
||||
<visual>
|
||||
<origin xyz="0 0 0" rpy="0 0 0"/>
|
||||
<geometry>
|
||||
<mesh filename="meshes/logo_link.STL"/>
|
||||
</geometry>
|
||||
<material name="dark">
|
||||
<color rgba="0.2 0.2 0.2 1"/>
|
||||
</material>
|
||||
</visual>
|
||||
<collision>
|
||||
<origin xyz="0 0 0" rpy="0 0 0"/>
|
||||
<geometry>
|
||||
<mesh filename="meshes/logo_link.STL"/>
|
||||
</geometry>
|
||||
</collision>
|
||||
</link>
|
||||
|
||||
<!-- Head -->
|
||||
<link name="head_link">
|
||||
<inertial>
|
||||
<origin xyz="0.005267 0.000299 0.449869" rpy="0 0 0"/>
|
||||
<mass value="1.036"/>
|
||||
<inertia ixx="0.004085051" ixy="-2.543E-06" ixz="-6.9455E-05" iyy="0.004185212" iyz="-3.726E-06" izz="0.001807911"/>
|
||||
</inertial>
|
||||
<visual>
|
||||
<origin xyz="0 0 0" rpy="0 0 0"/>
|
||||
<geometry>
|
||||
<mesh filename="meshes/head_link.STL"/>
|
||||
</geometry>
|
||||
<material name="dark">
|
||||
<color rgba="0.2 0.2 0.2 1"/>
|
||||
</material>
|
||||
</visual>
|
||||
<collision>
|
||||
<origin xyz="0 0 0" rpy="0 0 0"/>
|
||||
<geometry>
|
||||
<mesh filename="meshes/head_link.STL"/>
|
||||
</geometry>
|
||||
</collision>
|
||||
</link>
|
||||
<joint name="head_joint" type="fixed">
|
||||
<origin xyz="0.0039635 0 -0.054" rpy="0 0 0"/>
|
||||
<parent link="torso_link"/>
|
||||
<child link="head_link"/>
|
||||
</joint>
|
||||
|
||||
<!-- Waist Support -->
|
||||
<link name="waist_support_link">
|
||||
<inertial>
|
||||
<origin xyz="0 0 0" rpy="0 0 0"/>
|
||||
<mass value="0.001"/>
|
||||
<inertia ixx="1e-7" ixy="0" ixz="0" iyy="1e-7" iyz="0" izz="1e-7"/>
|
||||
</inertial>
|
||||
<visual>
|
||||
<origin xyz="0 0 0" rpy="0 0 0"/>
|
||||
<geometry>
|
||||
<mesh filename="meshes/waist_support_link.STL"/>
|
||||
</geometry>
|
||||
<material name="white">
|
||||
<color rgba="0.7 0.7 0.7 1"/>
|
||||
</material>
|
||||
</visual>
|
||||
<collision>
|
||||
<origin xyz="0 0 0" rpy="0 0 0"/>
|
||||
<geometry>
|
||||
<mesh filename="meshes/waist_support_link.STL"/>
|
||||
</geometry>
|
||||
</collision>
|
||||
</link>
|
||||
<joint name="waist_support_joint" type="fixed">
|
||||
<origin xyz="0.0039635 0 -0.054" rpy="0 0 0"/>
|
||||
<parent link="torso_link"/>
|
||||
<child link="waist_support_link"/>
|
||||
</joint>
|
||||
|
||||
<!-- IMU -->
|
||||
<link name="imu_in_torso"></link>
|
||||
<joint name="imu_in_torso_joint" type="fixed">
|
||||
<origin xyz="-0.03959 -0.00224 0.13792" rpy="0 0 0"/>
|
||||
<parent link="torso_link"/>
|
||||
<child link="imu_in_torso"/>
|
||||
</joint>
|
||||
|
||||
<link name="imu_in_pelvis"></link>
|
||||
<joint name="imu_in_pelvis_joint" type="fixed">
|
||||
<origin xyz="0.04525 0 -0.08339" rpy="0 0 0"/>
|
||||
<parent link="pelvis"/>
|
||||
<child link="imu_in_pelvis"/>
|
||||
</joint>
|
||||
|
||||
<!-- d435 -->
|
||||
<link name="d435_link"></link>
|
||||
<joint name="d435_joint" type="fixed">
|
||||
<origin xyz="0.0576235 0.01753 0.41987" rpy="0 0.8307767239493009 0"/>
|
||||
<parent link="torso_link"/>
|
||||
<child link="d435_link"/>
|
||||
</joint>
|
||||
|
||||
<!-- mid360 -->
|
||||
<link name="mid360_link"></link>
|
||||
<joint name="mid360_joint" type="fixed">
|
||||
<origin xyz="0.0002835 0.00003 0.40618" rpy="0 0.04014257279586953 0"/>
|
||||
<parent link="torso_link"/>
|
||||
<child link="mid360_link"/>
|
||||
</joint>
|
||||
|
||||
<!-- Arm -->
|
||||
<link name="left_shoulder_pitch_link">
|
||||
<inertial>
|
||||
<origin xyz="0 0.035892 -0.011628" rpy="0 0 0"/>
|
||||
<mass value="0.718"/>
|
||||
<inertia ixx="0.0004291" ixy="-9.2E-06" ixz="6.4E-06" iyy="0.000453" iyz="2.26E-05" izz="0.000423"/>
|
||||
</inertial>
|
||||
<visual>
|
||||
<origin xyz="0 0 0" rpy="0 0 0"/>
|
||||
<geometry>
|
||||
<mesh filename="meshes/left_shoulder_pitch_link.STL"/>
|
||||
</geometry>
|
||||
<material name="white">
|
||||
<color rgba="0.7 0.7 0.7 1"/>
|
||||
</material>
|
||||
</visual>
|
||||
<collision>
|
||||
<origin xyz="0 0.04 -0.01" rpy="0 1.5707963267948966 0"/>
|
||||
<geometry>
|
||||
<cylinder radius="0.03" length="0.05"/>
|
||||
</geometry>
|
||||
</collision>
|
||||
</link>
|
||||
<joint name="left_shoulder_pitch_joint" type="revolute">
|
||||
<origin xyz="0.0039563 0.10022 0.23778" rpy="0.27931 5.4949E-05 -0.00019159"/>
|
||||
<parent link="torso_link"/>
|
||||
<child link="left_shoulder_pitch_link"/>
|
||||
<axis xyz="0 1 0"/>
|
||||
<limit lower="-3.0892" upper="2.6704" effort="25" velocity="37"/>
|
||||
</joint>
|
||||
<link name="left_shoulder_roll_link">
|
||||
<inertial>
|
||||
<origin xyz="-0.000227 0.00727 -0.063243" rpy="0 0 0"/>
|
||||
<mass value="0.643"/>
|
||||
<inertia ixx="0.0006177" ixy="-1E-06" ixz="8.7E-06" iyy="0.0006912" iyz="-5.3E-06" izz="0.0003894"/>
|
||||
</inertial>
|
||||
<visual>
|
||||
<origin xyz="0 0 0" rpy="0 0 0"/>
|
||||
<geometry>
|
||||
<mesh filename="meshes/left_shoulder_roll_link.STL"/>
|
||||
</geometry>
|
||||
<material name="white">
|
||||
<color rgba="0.7 0.7 0.7 1"/>
|
||||
</material>
|
||||
</visual>
|
||||
<collision>
|
||||
<origin xyz="-0.004 0.006 -0.053" rpy="0 0 0"/>
|
||||
<geometry>
|
||||
<cylinder radius="0.03" length="0.03"/>
|
||||
</geometry>
|
||||
</collision>
|
||||
</link>
|
||||
<joint name="left_shoulder_roll_joint" type="revolute">
|
||||
<origin xyz="0 0.038 -0.013831" rpy="-0.27925 0 0"/>
|
||||
<parent link="left_shoulder_pitch_link"/>
|
||||
<child link="left_shoulder_roll_link"/>
|
||||
<axis xyz="1 0 0"/>
|
||||
<limit lower="-1.5882" upper="2.2515" effort="25" velocity="37"/>
|
||||
</joint>
|
||||
<link name="left_shoulder_yaw_link">
|
||||
<inertial>
|
||||
<origin xyz="0.010773 -0.002949 -0.072009" rpy="0 0 0"/>
|
||||
<mass value="0.734"/>
|
||||
<inertia ixx="0.0009988" ixy="7.9E-06" ixz="0.0001412" iyy="0.0010605" iyz="-2.86E-05" izz="0.0004354"/>
|
||||
</inertial>
|
||||
<visual>
|
||||
<origin xyz="0 0 0" rpy="0 0 0"/>
|
||||
<geometry>
|
||||
<mesh filename="meshes/left_shoulder_yaw_link.STL"/>
|
||||
</geometry>
|
||||
<material name="white">
|
||||
<color rgba="0.7 0.7 0.7 1"/>
|
||||
</material>
|
||||
</visual>
|
||||
<collision>
|
||||
<origin xyz="0 0 0" rpy="0 0 0"/>
|
||||
<geometry>
|
||||
<mesh filename="meshes/left_shoulder_yaw_link.STL"/>
|
||||
</geometry>
|
||||
</collision>
|
||||
</link>
|
||||
<joint name="left_shoulder_yaw_joint" type="revolute">
|
||||
<origin xyz="0 0.00624 -0.1032" rpy="0 0 0"/>
|
||||
<parent link="left_shoulder_roll_link"/>
|
||||
<child link="left_shoulder_yaw_link"/>
|
||||
<axis xyz="0 0 1"/>
|
||||
<limit lower="-2.618" upper="2.618" effort="25" velocity="37"/>
|
||||
</joint>
|
||||
<link name="left_elbow_link">
|
||||
<inertial>
|
||||
<origin xyz="0.064956 0.004454 -0.010062" rpy="0 0 0"/>
|
||||
<mass value="0.6"/>
|
||||
<inertia ixx="0.0002891" ixy="6.53E-05" ixz="1.72E-05" iyy="0.0004152" iyz="-5.6E-06" izz="0.0004197"/>
|
||||
</inertial>
|
||||
<visual>
|
||||
<origin xyz="0 0 0" rpy="0 0 0"/>
|
||||
<geometry>
|
||||
<mesh filename="meshes/left_elbow_link.STL"/>
|
||||
</geometry>
|
||||
<material name="white">
|
||||
<color rgba="0.7 0.7 0.7 1"/>
|
||||
</material>
|
||||
</visual>
|
||||
<collision>
|
||||
<origin xyz="0 0 0" rpy="0 0 0"/>
|
||||
<geometry>
|
||||
<mesh filename="meshes/left_elbow_link.STL"/>
|
||||
</geometry>
|
||||
</collision>
|
||||
</link>
|
||||
<joint name="left_elbow_joint" type="revolute">
|
||||
<origin xyz="0.015783 0 -0.080518" rpy="0 0 0"/>
|
||||
<parent link="left_shoulder_yaw_link"/>
|
||||
<child link="left_elbow_link"/>
|
||||
<axis xyz="0 1 0"/>
|
||||
<limit lower="-1.0472" upper="2.0944" effort="25" velocity="37"/>
|
||||
</joint>
|
||||
<joint name="left_wrist_roll_joint" type="revolute">
|
||||
<origin xyz="0.100 0.00188791 -0.010" rpy="0 0 0"/>
|
||||
<axis xyz="1 0 0"/>
|
||||
<parent link="left_elbow_link"/>
|
||||
<child link="left_wrist_roll_rubber_hand"/>
|
||||
<limit effort="25" velocity="37" lower="-1.972222054" upper="1.972222054"/>
|
||||
</joint>
|
||||
<link name="left_wrist_roll_rubber_hand">
|
||||
<inertial>
|
||||
<origin xyz="0.10794656650 0.00163511945 0.00202244863" rpy="0 0 0"/>
|
||||
<mass value="0.35692864"/>
|
||||
<inertia ixx="0.00019613494735" ixy="-0.00000419816908" ixz="-0.00003950860580" iyy="0.00200280358206" iyz="0.00000249774203" izz="0.00194181412808"/>
|
||||
</inertial>
|
||||
<visual>
|
||||
<origin xyz="0 0 0" rpy="0 0 0"/>
|
||||
<geometry>
|
||||
<mesh filename="meshes/left_wrist_roll_rubber_hand.STL"/>
|
||||
</geometry>
|
||||
<material name="white">
|
||||
<color rgba="0.7 0.7 0.7 1"/>
|
||||
</material>
|
||||
</visual>
|
||||
<collision>
|
||||
<origin xyz="0 0 0" rpy="0 0 0"/>
|
||||
<geometry>
|
||||
<mesh filename="meshes/left_wrist_roll_rubber_hand.STL"/>
|
||||
</geometry>
|
||||
</collision>
|
||||
</link>
|
||||
<link name="right_shoulder_pitch_link">
|
||||
<inertial>
|
||||
<origin xyz="0 -0.035892 -0.011628" rpy="0 0 0"/>
|
||||
<mass value="0.718"/>
|
||||
<inertia ixx="0.0004291" ixy="9.2E-06" ixz="6.4E-06" iyy="0.000453" iyz="-2.26E-05" izz="0.000423"/>
|
||||
</inertial>
|
||||
<visual>
|
||||
<origin xyz="0 0 0" rpy="0 0 0"/>
|
||||
<geometry>
|
||||
<mesh filename="meshes/right_shoulder_pitch_link.STL"/>
|
||||
</geometry>
|
||||
<material name="white">
|
||||
<color rgba="0.7 0.7 0.7 1"/>
|
||||
</material>
|
||||
</visual>
|
||||
<collision>
|
||||
<origin xyz="0 -0.04 -0.01" rpy="0 1.5707963267948966 0"/>
|
||||
<geometry>
|
||||
<cylinder radius="0.03" length="0.05"/>
|
||||
</geometry>
|
||||
</collision>
|
||||
</link>
|
||||
<joint name="right_shoulder_pitch_joint" type="revolute">
|
||||
<origin xyz="0.0039563 -0.10021 0.23778" rpy="-0.27931 5.4949E-05 0.00019159"/>
|
||||
<parent link="torso_link"/>
|
||||
<child link="right_shoulder_pitch_link"/>
|
||||
<axis xyz="0 1 0"/>
|
||||
<limit lower="-3.0892" upper="2.6704" effort="25" velocity="37"/>
|
||||
</joint>
|
||||
<link name="right_shoulder_roll_link">
|
||||
<inertial>
|
||||
<origin xyz="-0.000227 -0.00727 -0.063243" rpy="0 0 0"/>
|
||||
<mass value="0.643"/>
|
||||
<inertia ixx="0.0006177" ixy="1E-06" ixz="8.7E-06" iyy="0.0006912" iyz="5.3E-06" izz="0.0003894"/>
|
||||
</inertial>
|
||||
<visual>
|
||||
<origin xyz="0 0 0" rpy="0 0 0"/>
|
||||
<geometry>
|
||||
<mesh filename="meshes/right_shoulder_roll_link.STL"/>
|
||||
</geometry>
|
||||
<material name="white">
|
||||
<color rgba="0.7 0.7 0.7 1"/>
|
||||
</material>
|
||||
</visual>
|
||||
<collision>
|
||||
<origin xyz="-0.004 -0.006 -0.053" rpy="0 0 0"/>
|
||||
<geometry>
|
||||
<cylinder radius="0.03" length="0.03"/>
|
||||
</geometry>
|
||||
</collision>
|
||||
</link>
|
||||
<joint name="right_shoulder_roll_joint" type="revolute">
|
||||
<origin xyz="0 -0.038 -0.013831" rpy="0.27925 0 0"/>
|
||||
<parent link="right_shoulder_pitch_link"/>
|
||||
<child link="right_shoulder_roll_link"/>
|
||||
<axis xyz="1 0 0"/>
|
||||
<limit lower="-2.2515" upper="1.5882" effort="25" velocity="37"/>
|
||||
</joint>
|
||||
<link name="right_shoulder_yaw_link">
|
||||
<inertial>
|
||||
<origin xyz="0.010773 0.002949 -0.072009" rpy="0 0 0"/>
|
||||
<mass value="0.734"/>
|
||||
<inertia ixx="0.0009988" ixy="-7.9E-06" ixz="0.0001412" iyy="0.0010605" iyz="2.86E-05" izz="0.0004354"/>
|
||||
</inertial>
|
||||
<visual>
|
||||
<origin xyz="0 0 0" rpy="0 0 0"/>
|
||||
<geometry>
|
||||
<mesh filename="meshes/right_shoulder_yaw_link.STL"/>
|
||||
</geometry>
|
||||
<material name="white">
|
||||
<color rgba="0.7 0.7 0.7 1"/>
|
||||
</material>
|
||||
</visual>
|
||||
<collision>
|
||||
<origin xyz="0 0 0" rpy="0 0 0"/>
|
||||
<geometry>
|
||||
<mesh filename="meshes/right_shoulder_yaw_link.STL"/>
|
||||
</geometry>
|
||||
</collision>
|
||||
</link>
|
||||
<joint name="right_shoulder_yaw_joint" type="revolute">
|
||||
<origin xyz="0 -0.00624 -0.1032" rpy="0 0 0"/>
|
||||
<parent link="right_shoulder_roll_link"/>
|
||||
<child link="right_shoulder_yaw_link"/>
|
||||
<axis xyz="0 0 1"/>
|
||||
<limit lower="-2.618" upper="2.618" effort="25" velocity="37"/>
|
||||
</joint>
|
||||
<link name="right_elbow_link">
|
||||
<inertial>
|
||||
<origin xyz="0.064956 -0.004454 -0.010062" rpy="0 0 0"/>
|
||||
<mass value="0.6"/>
|
||||
<inertia ixx="0.0002891" ixy="-6.53E-05" ixz="1.72E-05" iyy="0.0004152" iyz="5.6E-06" izz="0.0004197"/>
|
||||
</inertial>
|
||||
<visual>
|
||||
<origin xyz="0 0 0" rpy="0 0 0"/>
|
||||
<geometry>
|
||||
<mesh filename="meshes/right_elbow_link.STL"/>
|
||||
</geometry>
|
||||
<material name="white">
|
||||
<color rgba="0.7 0.7 0.7 1"/>
|
||||
</material>
|
||||
</visual>
|
||||
<collision>
|
||||
<origin xyz="0 0 0" rpy="0 0 0"/>
|
||||
<geometry>
|
||||
<mesh filename="meshes/right_elbow_link.STL"/>
|
||||
</geometry>
|
||||
</collision>
|
||||
</link>
|
||||
<joint name="right_elbow_joint" type="revolute">
|
||||
<origin xyz="0.015783 0 -0.080518" rpy="0 0 0"/>
|
||||
<parent link="right_shoulder_yaw_link"/>
|
||||
<child link="right_elbow_link"/>
|
||||
<axis xyz="0 1 0"/>
|
||||
<limit lower="-1.0472" upper="2.0944" effort="25" velocity="37"/>
|
||||
</joint>
|
||||
<joint name="right_wrist_roll_joint" type="revolute">
|
||||
<origin xyz="0.100 -0.00188791 -0.010" rpy="0 0 0"/>
|
||||
<axis xyz="1 0 0"/>
|
||||
<parent link="right_elbow_link"/>
|
||||
<child link="right_wrist_roll_rubber_hand"/>
|
||||
<limit effort="25" velocity="37" lower="-1.972222054" upper="1.972222054"/>
|
||||
</joint>
|
||||
<link name="right_wrist_roll_rubber_hand">
|
||||
<inertial>
|
||||
<origin xyz="0.10794656650 -0.00163511945 0.00202244863" rpy="0 0 0"/>
|
||||
<mass value="0.35692864"/>
|
||||
<inertia ixx="0.00019613494735" ixy="0.00000419816908" ixz="-0.00003950860580" iyy="0.00200280358206" iyz="-0.00000249774203" izz="0.00194181412808"/>
|
||||
</inertial>
|
||||
<visual>
|
||||
<origin xyz="0 0 0" rpy="0 0 0"/>
|
||||
<geometry>
|
||||
<mesh filename="meshes/right_wrist_roll_rubber_hand.STL"/>
|
||||
</geometry>
|
||||
<material name="white">
|
||||
<color rgba="0.7 0.7 0.7 1"/>
|
||||
</material>
|
||||
</visual>
|
||||
<collision>
|
||||
<origin xyz="0 0 0" rpy="0 0 0"/>
|
||||
<geometry>
|
||||
<mesh filename="meshes/right_wrist_roll_rubber_hand.STL"/>
|
||||
</geometry>
|
||||
</collision>
|
||||
</link>
|
||||
</robot>
|
||||
1476
src/lerobot/robots/unitree_g1/assets/g1/g1_body29_hand14.urdf
Normal file
1476
src/lerobot/robots/unitree_g1/assets/g1/g1_body29_hand14.urdf
Normal file
File diff suppressed because it is too large
Load Diff
408
src/lerobot/robots/unitree_g1/assets/g1/g1_body29_hand14.xml
Normal file
408
src/lerobot/robots/unitree_g1/assets/g1/g1_body29_hand14.xml
Normal file
@@ -0,0 +1,408 @@
|
||||
<mujoco model="g1">
|
||||
<compiler angle="radian" meshdir="meshes"/>
|
||||
|
||||
<asset>
|
||||
<mesh name="pelvis" file="pelvis.STL"/>
|
||||
<mesh name="pelvis_contour_link" file="pelvis_contour_link.STL"/>
|
||||
<mesh name="left_hip_pitch_link" file="left_hip_pitch_link.STL"/>
|
||||
<mesh name="left_hip_roll_link" file="left_hip_roll_link.STL"/>
|
||||
<mesh name="left_hip_yaw_link" file="left_hip_yaw_link.STL"/>
|
||||
<mesh name="left_knee_link" file="left_knee_link.STL"/>
|
||||
<mesh name="left_ankle_pitch_link" file="left_ankle_pitch_link.STL"/>
|
||||
<mesh name="left_ankle_roll_link" file="left_ankle_roll_link.STL"/>
|
||||
<mesh name="right_hip_pitch_link" file="right_hip_pitch_link.STL"/>
|
||||
<mesh name="right_hip_roll_link" file="right_hip_roll_link.STL"/>
|
||||
<mesh name="right_hip_yaw_link" file="right_hip_yaw_link.STL"/>
|
||||
<mesh name="right_knee_link" file="right_knee_link.STL"/>
|
||||
<mesh name="right_ankle_pitch_link" file="right_ankle_pitch_link.STL"/>
|
||||
<mesh name="right_ankle_roll_link" file="right_ankle_roll_link.STL"/>
|
||||
<mesh name="waist_yaw_link" file="waist_yaw_link_rev_1_0.STL"/>
|
||||
<mesh name="waist_roll_link" file="waist_roll_link_rev_1_0.STL"/>
|
||||
<mesh name="torso_link" file="torso_link_rev_1_0.STL"/>
|
||||
<mesh name="logo_link" file="logo_link.STL"/>
|
||||
<mesh name="head_link" file="head_link.STL"/>
|
||||
<mesh name="left_shoulder_pitch_link" file="left_shoulder_pitch_link.STL"/>
|
||||
<mesh name="left_shoulder_roll_link" file="left_shoulder_roll_link.STL"/>
|
||||
<mesh name="left_shoulder_yaw_link" file="left_shoulder_yaw_link.STL"/>
|
||||
<mesh name="left_elbow_link" file="left_elbow_link.STL"/>
|
||||
<mesh name="left_wrist_roll_link" file="left_wrist_roll_link.STL"/>
|
||||
<mesh name="left_wrist_pitch_link" file="left_wrist_pitch_link.STL"/>
|
||||
<mesh name="left_wrist_yaw_link" file="left_wrist_yaw_link.STL"/>
|
||||
<mesh name="left_hand_palm_link" file="left_hand_palm_link.STL"/>
|
||||
<mesh name="left_hand_thumb_0_link" file="left_hand_thumb_0_link.STL"/>
|
||||
<mesh name="left_hand_thumb_1_link" file="left_hand_thumb_1_link.STL"/>
|
||||
<mesh name="left_hand_thumb_2_link" file="left_hand_thumb_2_link.STL"/>
|
||||
<mesh name="left_hand_middle_0_link" file="left_hand_middle_0_link.STL"/>
|
||||
<mesh name="left_hand_middle_1_link" file="left_hand_middle_1_link.STL"/>
|
||||
<mesh name="left_hand_index_0_link" file="left_hand_index_0_link.STL"/>
|
||||
<mesh name="left_hand_index_1_link" file="left_hand_index_1_link.STL"/>
|
||||
<mesh name="right_shoulder_pitch_link" file="right_shoulder_pitch_link.STL"/>
|
||||
<mesh name="right_shoulder_roll_link" file="right_shoulder_roll_link.STL"/>
|
||||
<mesh name="right_shoulder_yaw_link" file="right_shoulder_yaw_link.STL"/>
|
||||
<mesh name="right_elbow_link" file="right_elbow_link.STL"/>
|
||||
<mesh name="right_wrist_roll_link" file="right_wrist_roll_link.STL"/>
|
||||
<mesh name="right_wrist_pitch_link" file="right_wrist_pitch_link.STL"/>
|
||||
<mesh name="right_wrist_yaw_link" file="right_wrist_yaw_link.STL"/>
|
||||
<mesh name="right_hand_palm_link" file="right_hand_palm_link.STL"/>
|
||||
<mesh name="right_hand_thumb_0_link" file="right_hand_thumb_0_link.STL"/>
|
||||
<mesh name="right_hand_thumb_1_link" file="right_hand_thumb_1_link.STL"/>
|
||||
<mesh name="right_hand_thumb_2_link" file="right_hand_thumb_2_link.STL"/>
|
||||
<mesh name="right_hand_middle_0_link" file="right_hand_middle_0_link.STL"/>
|
||||
<mesh name="right_hand_middle_1_link" file="right_hand_middle_1_link.STL"/>
|
||||
<mesh name="right_hand_index_0_link" file="right_hand_index_0_link.STL"/>
|
||||
<mesh name="right_hand_index_1_link" file="right_hand_index_1_link.STL"/>
|
||||
</asset>
|
||||
|
||||
<worldbody>
|
||||
<body name="pelvis" pos="0 0 0.793">
|
||||
<inertial pos="0 0 -0.07605" quat="1 0 -0.000399148 0" mass="3.813" diaginertia="0.010549 0.0093089 0.0079184"/>
|
||||
<joint name="floating_base_joint" type="free" limited="false" actuatorfrclimited="false"/>
|
||||
<geom type="mesh" contype="0" conaffinity="0" group="1" density="0" rgba="0.2 0.2 0.2 1" mesh="pelvis"/>
|
||||
<geom type="mesh" contype="0" conaffinity="0" group="1" density="0" rgba="0.7 0.7 0.7 1" mesh="pelvis_contour_link"/>
|
||||
<geom type="mesh" rgba="0.7 0.7 0.7 1" mesh="pelvis_contour_link"/>
|
||||
<site name="imu_in_pelvis" size="0.01" pos="0.04525 0 -0.08339"/>
|
||||
<body name="left_hip_pitch_link" pos="0 0.064452 -0.1027">
|
||||
<inertial pos="0.002741 0.047791 -0.02606" quat="0.954862 0.293964 0.0302556 0.030122" mass="1.35" diaginertia="0.00181517 0.00153422 0.00116212"/>
|
||||
<joint name="left_hip_pitch_joint" pos="0 0 0" axis="0 1 0" range="-2.5307 2.8798" actuatorfrcrange="-88 88"/>
|
||||
<geom type="mesh" contype="0" conaffinity="0" group="1" density="0" rgba="0.2 0.2 0.2 1" mesh="left_hip_pitch_link"/>
|
||||
<geom type="mesh" rgba="0.2 0.2 0.2 1" mesh="left_hip_pitch_link"/>
|
||||
<body name="left_hip_roll_link" pos="0 0.052 -0.030465" quat="0.996179 0 -0.0873386 0">
|
||||
<inertial pos="0.029812 -0.001045 -0.087934" quat="0.977808 -1.97119e-05 0.205576 -0.0403793" mass="1.52" diaginertia="0.00254986 0.00241169 0.00148755"/>
|
||||
<joint name="left_hip_roll_joint" pos="0 0 0" axis="1 0 0" range="-0.5236 2.9671" actuatorfrcrange="-139 139"/>
|
||||
<geom type="mesh" contype="0" conaffinity="0" group="1" density="0" rgba="0.7 0.7 0.7 1" mesh="left_hip_roll_link"/>
|
||||
<geom type="mesh" rgba="0.7 0.7 0.7 1" mesh="left_hip_roll_link"/>
|
||||
<body name="left_hip_yaw_link" pos="0.025001 0 -0.12412">
|
||||
<inertial pos="-0.057709 -0.010981 -0.15078" quat="0.600598 0.15832 0.223482 0.751181" mass="1.702" diaginertia="0.00776166 0.00717575 0.00160139"/>
|
||||
<joint name="left_hip_yaw_joint" pos="0 0 0" axis="0 0 1" range="-2.7576 2.7576" actuatorfrcrange="-88 88"/>
|
||||
<geom type="mesh" contype="0" conaffinity="0" group="1" density="0" rgba="0.7 0.7 0.7 1" mesh="left_hip_yaw_link"/>
|
||||
<geom type="mesh" rgba="0.7 0.7 0.7 1" mesh="left_hip_yaw_link"/>
|
||||
<body name="left_knee_link" pos="-0.078273 0.0021489 -0.17734" quat="0.996179 0 0.0873386 0">
|
||||
<inertial pos="0.005457 0.003964 -0.12074" quat="0.923418 -0.0327699 0.0158246 0.382067" mass="1.932" diaginertia="0.0113804 0.0112778 0.00146458"/>
|
||||
<joint name="left_knee_joint" pos="0 0 0" axis="0 1 0" range="-0.087267 2.8798" actuatorfrcrange="-139 139"/>
|
||||
<geom type="mesh" contype="0" conaffinity="0" group="1" density="0" rgba="0.7 0.7 0.7 1" mesh="left_knee_link"/>
|
||||
<geom type="mesh" rgba="0.7 0.7 0.7 1" mesh="left_knee_link"/>
|
||||
<body name="left_ankle_pitch_link" pos="0 -9.4445e-05 -0.30001">
|
||||
<inertial pos="-0.007269 0 0.011137" quat="0.603053 0.369225 0.369225 0.603053" mass="0.074" diaginertia="1.89e-05 1.40805e-05 6.9195e-06"/>
|
||||
<joint name="left_ankle_pitch_joint" pos="0 0 0" axis="0 1 0" range="-0.87267 0.5236" actuatorfrcrange="-50 50"/>
|
||||
<geom type="mesh" contype="0" conaffinity="0" group="1" density="0" rgba="0.7 0.7 0.7 1" mesh="left_ankle_pitch_link"/>
|
||||
<geom type="mesh" rgba="0.7 0.7 0.7 1" mesh="left_ankle_pitch_link"/>
|
||||
<body name="left_ankle_roll_link" pos="0 0 -0.017558">
|
||||
<inertial pos="0.026505 0 -0.016425" quat="-0.000481092 0.728482 -0.000618967 0.685065" mass="0.608" diaginertia="0.00167218 0.0016161 0.000217621"/>
|
||||
<joint name="left_ankle_roll_joint" pos="0 0 0" axis="1 0 0" range="-0.2618 0.2618" actuatorfrcrange="-50 50"/>
|
||||
<geom type="mesh" contype="0" conaffinity="0" group="1" density="0" rgba="0.2 0.2 0.2 1" mesh="left_ankle_roll_link"/>
|
||||
<geom size="0.005" pos="-0.05 0.025 -0.03" rgba="0.2 0.2 0.2 1"/>
|
||||
<geom size="0.005" pos="-0.05 -0.025 -0.03" rgba="0.2 0.2 0.2 1"/>
|
||||
<geom size="0.005" pos="0.12 0.03 -0.03" rgba="0.2 0.2 0.2 1"/>
|
||||
<geom size="0.005" pos="0.12 -0.03 -0.03" rgba="0.2 0.2 0.2 1"/>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
<body name="right_hip_pitch_link" pos="0 -0.064452 -0.1027">
|
||||
<inertial pos="0.002741 -0.047791 -0.02606" quat="0.954862 -0.293964 0.0302556 -0.030122" mass="1.35" diaginertia="0.00181517 0.00153422 0.00116212"/>
|
||||
<joint name="right_hip_pitch_joint" pos="0 0 0" axis="0 1 0" range="-2.5307 2.8798" actuatorfrcrange="-88 88"/>
|
||||
<geom type="mesh" contype="0" conaffinity="0" group="1" density="0" rgba="0.2 0.2 0.2 1" mesh="right_hip_pitch_link"/>
|
||||
<geom type="mesh" rgba="0.2 0.2 0.2 1" mesh="right_hip_pitch_link"/>
|
||||
<body name="right_hip_roll_link" pos="0 -0.052 -0.030465" quat="0.996179 0 -0.0873386 0">
|
||||
<inertial pos="0.029812 0.001045 -0.087934" quat="0.977808 1.97119e-05 0.205576 0.0403793" mass="1.52" diaginertia="0.00254986 0.00241169 0.00148755"/>
|
||||
<joint name="right_hip_roll_joint" pos="0 0 0" axis="1 0 0" range="-2.9671 0.5236" actuatorfrcrange="-139 139"/>
|
||||
<geom type="mesh" contype="0" conaffinity="0" group="1" density="0" rgba="0.7 0.7 0.7 1" mesh="right_hip_roll_link"/>
|
||||
<geom type="mesh" rgba="0.7 0.7 0.7 1" mesh="right_hip_roll_link"/>
|
||||
<body name="right_hip_yaw_link" pos="0.025001 0 -0.12412">
|
||||
<inertial pos="-0.057709 0.010981 -0.15078" quat="0.751181 0.223482 0.15832 0.600598" mass="1.702" diaginertia="0.00776166 0.00717575 0.00160139"/>
|
||||
<joint name="right_hip_yaw_joint" pos="0 0 0" axis="0 0 1" range="-2.7576 2.7576" actuatorfrcrange="-88 88"/>
|
||||
<geom type="mesh" contype="0" conaffinity="0" group="1" density="0" rgba="0.7 0.7 0.7 1" mesh="right_hip_yaw_link"/>
|
||||
<geom type="mesh" rgba="0.7 0.7 0.7 1" mesh="right_hip_yaw_link"/>
|
||||
<body name="right_knee_link" pos="-0.078273 -0.0021489 -0.17734" quat="0.996179 0 0.0873386 0">
|
||||
<inertial pos="0.005457 -0.003964 -0.12074" quat="0.923439 0.0345276 0.0116333 -0.382012" mass="1.932" diaginertia="0.011374 0.0112843 0.00146452"/>
|
||||
<joint name="right_knee_joint" pos="0 0 0" axis="0 1 0" range="-0.087267 2.8798" actuatorfrcrange="-139 139"/>
|
||||
<geom type="mesh" contype="0" conaffinity="0" group="1" density="0" rgba="0.7 0.7 0.7 1" mesh="right_knee_link"/>
|
||||
<geom type="mesh" rgba="0.7 0.7 0.7 1" mesh="right_knee_link"/>
|
||||
<body name="right_ankle_pitch_link" pos="0 9.4445e-05 -0.30001">
|
||||
<inertial pos="-0.007269 0 0.011137" quat="0.603053 0.369225 0.369225 0.603053" mass="0.074" diaginertia="1.89e-05 1.40805e-05 6.9195e-06"/>
|
||||
<joint name="right_ankle_pitch_joint" pos="0 0 0" axis="0 1 0" range="-0.87267 0.5236" actuatorfrcrange="-50 50"/>
|
||||
<geom type="mesh" contype="0" conaffinity="0" group="1" density="0" rgba="0.7 0.7 0.7 1" mesh="right_ankle_pitch_link"/>
|
||||
<geom type="mesh" rgba="0.7 0.7 0.7 1" mesh="right_ankle_pitch_link"/>
|
||||
<body name="right_ankle_roll_link" pos="0 0 -0.017558">
|
||||
<inertial pos="0.026505 0 -0.016425" quat="0.000481092 0.728482 0.000618967 0.685065" mass="0.608" diaginertia="0.00167218 0.0016161 0.000217621"/>
|
||||
<joint name="right_ankle_roll_joint" pos="0 0 0" axis="1 0 0" range="-0.2618 0.2618" actuatorfrcrange="-50 50"/>
|
||||
<geom type="mesh" contype="0" conaffinity="0" group="1" density="0" rgba="0.2 0.2 0.2 1" mesh="right_ankle_roll_link"/>
|
||||
<geom size="0.005" pos="-0.05 0.025 -0.03" rgba="0.2 0.2 0.2 1"/>
|
||||
<geom size="0.005" pos="-0.05 -0.025 -0.03" rgba="0.2 0.2 0.2 1"/>
|
||||
<geom size="0.005" pos="0.12 0.03 -0.03" rgba="0.2 0.2 0.2 1"/>
|
||||
<geom size="0.005" pos="0.12 -0.03 -0.03" rgba="0.2 0.2 0.2 1"/>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
<body name="waist_yaw_link">
|
||||
<inertial pos="0.003494 0.000233 0.018034" quat="0.289697 0.591001 -0.337795 0.672821" mass="0.214" diaginertia="0.000163531 0.000107714 0.000102205"/>
|
||||
<joint name="waist_yaw_joint" pos="0 0 0" axis="0 0 1" range="-2.618 2.618" actuatorfrcrange="-88 88"/>
|
||||
<geom type="mesh" contype="0" conaffinity="0" group="1" density="0" rgba="0.7 0.7 0.7 1" mesh="waist_yaw_link"/>
|
||||
<body name="waist_roll_link" pos="-0.0039635 0 0.044">
|
||||
<inertial pos="0 2.3e-05 0" quat="0.5 0.5 -0.5 0.5" mass="0.086" diaginertia="8.245e-06 7.079e-06 6.339e-06"/>
|
||||
<joint name="waist_roll_joint" pos="0 0 0" axis="1 0 0" range="-0.52 0.52" actuatorfrcrange="-50 50"/>
|
||||
<geom type="mesh" contype="0" conaffinity="0" group="1" density="0" rgba="0.7 0.7 0.7 1" mesh="waist_roll_link"/>
|
||||
<body name="torso_link">
|
||||
<inertial pos="0.00203158 0.000339683 0.184568" quat="0.999803 -6.03319e-05 0.0198256 0.00131986" mass="7.818" diaginertia="0.121847 0.109825 0.0273735"/>
|
||||
<joint name="waist_pitch_joint" pos="0 0 0" axis="0 1 0" range="-0.52 0.52" actuatorfrcrange="-50 50"/>
|
||||
<geom type="mesh" contype="0" conaffinity="0" group="1" density="0" rgba="0.7 0.7 0.7 1" mesh="torso_link"/>
|
||||
<geom type="mesh" rgba="0.7 0.7 0.7 1" mesh="torso_link"/>
|
||||
<geom pos="0.0039635 0 -0.044" quat="1 0 0 0" type="mesh" contype="0" conaffinity="0" group="1" density="0" rgba="0.2 0.2 0.2 1" mesh="logo_link"/>
|
||||
<geom pos="0.0039635 0 -0.044" quat="1 0 0 0" type="mesh" rgba="0.2 0.2 0.2 1" mesh="logo_link"/>
|
||||
<geom pos="0.0039635 0 -0.044" type="mesh" contype="0" conaffinity="0" group="1" density="0" rgba="0.2 0.2 0.2 1" mesh="head_link"/>
|
||||
<geom pos="0.0039635 0 -0.044" type="mesh" rgba="0.2 0.2 0.2 1" mesh="head_link"/>
|
||||
<site name="imu_in_torso" size="0.01" pos="-0.03959 -0.00224 0.14792"/>
|
||||
<body name="left_shoulder_pitch_link" pos="0.0039563 0.10022 0.24778" quat="0.990264 0.139201 1.38722e-05 -9.86868e-05">
|
||||
<inertial pos="0 0.035892 -0.011628" quat="0.654152 0.0130458 -0.326267 0.68225" mass="0.718" diaginertia="0.000465864 0.000432842 0.000406394"/>
|
||||
<joint name="left_shoulder_pitch_joint" pos="0 0 0" axis="0 1 0" range="-3.0892 2.6704" actuatorfrcrange="-25 25"/>
|
||||
<geom type="mesh" contype="0" conaffinity="0" group="1" density="0" rgba="0.7 0.7 0.7 1" mesh="left_shoulder_pitch_link"/>
|
||||
<geom size="0.03 0.025" pos="0 0.04 -0.01" quat="0.707107 0 0.707107 0" type="cylinder" rgba="0.7 0.7 0.7 1"/>
|
||||
<body name="left_shoulder_roll_link" pos="0 0.038 -0.013831" quat="0.990268 -0.139172 0 0">
|
||||
<inertial pos="-0.000227 0.00727 -0.063243" quat="0.701256 -0.0196223 -0.00710317 0.712604" mass="0.643" diaginertia="0.000691311 0.000618011 0.000388977"/>
|
||||
<joint name="left_shoulder_roll_joint" pos="0 0 0" axis="1 0 0" range="-1.5882 2.2515" actuatorfrcrange="-25 25"/>
|
||||
<geom type="mesh" contype="0" conaffinity="0" group="1" density="0" rgba="0.7 0.7 0.7 1" mesh="left_shoulder_roll_link"/>
|
||||
<geom size="0.03 0.015" pos="-0.004 0.006 -0.053" type="cylinder" rgba="0.7 0.7 0.7 1"/>
|
||||
<body name="left_shoulder_yaw_link" pos="0 0.00624 -0.1032">
|
||||
<inertial pos="0.010773 -0.002949 -0.072009" quat="0.716879 -0.0964829 -0.0679942 0.687134" mass="0.734" diaginertia="0.00106187 0.00103217 0.000400661"/>
|
||||
<joint name="left_shoulder_yaw_joint" pos="0 0 0" axis="0 0 1" range="-2.618 2.618" actuatorfrcrange="-25 25"/>
|
||||
<geom type="mesh" contype="0" conaffinity="0" group="1" density="0" rgba="0.7 0.7 0.7 1" mesh="left_shoulder_yaw_link"/>
|
||||
<geom type="mesh" rgba="0.7 0.7 0.7 1" mesh="left_shoulder_yaw_link"/>
|
||||
<body name="left_elbow_link" pos="0.015783 0 -0.080518">
|
||||
<inertial pos="0.064956 0.004454 -0.010062" quat="0.541765 0.636132 0.388821 0.388129" mass="0.6" diaginertia="0.000443035 0.000421612 0.000259353"/>
|
||||
<joint name="left_elbow_joint" pos="0 0 0" axis="0 1 0" range="-1.0472 2.0944" actuatorfrcrange="-25 25"/>
|
||||
<geom type="mesh" contype="0" conaffinity="0" group="1" density="0" rgba="0.7 0.7 0.7 1" mesh="left_elbow_link"/>
|
||||
<geom type="mesh" rgba="0.7 0.7 0.7 1" mesh="left_elbow_link"/>
|
||||
<body name="left_wrist_roll_link" pos="0.1 0.00188791 -0.01">
|
||||
<inertial pos="0.0171394 0.000537591 4.8864e-07" quat="0.575338 0.411667 -0.574906 0.411094" mass="0.085445" diaginertia="5.48211e-05 4.96646e-05 3.57798e-05"/>
|
||||
<joint name="left_wrist_roll_joint" pos="0 0 0" axis="1 0 0" range="-1.97222 1.97222" actuatorfrcrange="-25 25"/>
|
||||
<geom type="mesh" contype="0" conaffinity="0" group="1" density="0" rgba="0.7 0.7 0.7 1" mesh="left_wrist_roll_link"/>
|
||||
<geom type="mesh" rgba="0.7 0.7 0.7 1" mesh="left_wrist_roll_link"/>
|
||||
<body name="left_wrist_pitch_link" pos="0.038 0 0">
|
||||
<inertial pos="0.0229999 -0.00111685 -0.00111658" quat="0.249998 0.661363 0.293036 0.643608" mass="0.48405" diaginertia="0.000430353 0.000429873 0.000164648"/>
|
||||
<joint name="left_wrist_pitch_joint" pos="0 0 0" axis="0 1 0" range="-1.61443 1.61443" actuatorfrcrange="-5 5"/>
|
||||
<geom type="mesh" contype="0" conaffinity="0" group="1" density="0" rgba="0.7 0.7 0.7 1" mesh="left_wrist_pitch_link"/>
|
||||
<geom type="mesh" rgba="0.7 0.7 0.7 1" mesh="left_wrist_pitch_link"/>
|
||||
<body name="left_wrist_yaw_link" pos="0.046 0 0">
|
||||
<inertial pos="0.0885506 0.00212216 -0.000374562" quat="0.487149 0.493844 0.513241 0.505358" mass="0.457415" diaginertia="0.00105989 0.000895419 0.000323842"/>
|
||||
<joint name="left_wrist_yaw_joint" pos="0 0 0" axis="0 0 1" range="-1.61443 1.61443" actuatorfrcrange="-5 5"/>
|
||||
<geom type="mesh" contype="0" conaffinity="0" group="1" density="0" rgba="0.7 0.7 0.7 1" mesh="left_wrist_yaw_link"/>
|
||||
<geom type="mesh" rgba="0.7 0.7 0.7 1" mesh="left_wrist_yaw_link"/>
|
||||
<geom pos="0.0415 0.003 0" quat="1 0 0 0" type="mesh" contype="0" conaffinity="0" group="1" density="0" rgba="0.7 0.7 0.7 1" mesh="left_hand_palm_link"/>
|
||||
<geom pos="0.0415 0.003 0" quat="1 0 0 0" type="mesh" rgba="0.7 0.7 0.7 1" mesh="left_hand_palm_link"/>
|
||||
<body name="left_hand_thumb_0_link" pos="0.067 0.003 0">
|
||||
<inertial pos="-0.000884246 -0.00863407 0.000944293" quat="0.462991 0.643965 -0.460173 0.398986" mass="0.0862366" diaginertia="1.6546e-05 1.60058e-05 1.43741e-05"/>
|
||||
<joint name="left_hand_thumb_0_joint" pos="0 0 0" axis="0 1 0" range="-1.0472 1.0472" actuatorfrcrange="-2.45 2.45"/>
|
||||
<geom type="mesh" contype="0" conaffinity="0" group="1" density="0" rgba="0.7 0.7 0.7 1" mesh="left_hand_thumb_0_link"/>
|
||||
<geom type="mesh" rgba="0.7 0.7 0.7 1" mesh="left_hand_thumb_0_link"/>
|
||||
<body name="left_hand_thumb_1_link" pos="-0.0025 -0.0193 0">
|
||||
<inertial pos="-0.000827888 -0.0354744 -0.0003809" quat="0.685598 0.705471 -0.15207 0.0956069" mass="0.0588507" diaginertia="1.28514e-05 1.22902e-05 5.9666e-06"/>
|
||||
<joint name="left_hand_thumb_1_joint" pos="0 0 0" axis="0 0 1" range="-0.724312 1.0472" actuatorfrcrange="-1.4 1.4"/>
|
||||
<geom type="mesh" contype="0" conaffinity="0" group="1" density="0" rgba="0.7 0.7 0.7 1" mesh="left_hand_thumb_1_link"/>
|
||||
<geom size="0.01 0.015 0.01" pos="-0.001 -0.032 0" type="box" rgba="0.7 0.7 0.7 1"/>
|
||||
<body name="left_hand_thumb_2_link" pos="0 -0.0458 0">
|
||||
<inertial pos="-0.00171735 -0.0262819 0.000107789" quat="0.703174 0.710977 -0.00017564 -0.00766553" mass="0.0203063" diaginertia="4.61314e-06 3.86645e-06 1.53495e-06"/>
|
||||
<joint name="left_hand_thumb_2_joint" pos="0 0 0" axis="0 0 1" range="0 1.74533" actuatorfrcrange="-1.4 1.4"/>
|
||||
<geom type="mesh" contype="0" conaffinity="0" group="1" density="0" rgba="0.7 0.7 0.7 1" mesh="left_hand_thumb_2_link"/>
|
||||
<geom type="mesh" rgba="0.7 0.7 0.7 1" mesh="left_hand_thumb_2_link"/>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
<body name="left_hand_middle_0_link" pos="0.1192 0.0046 -0.0285">
|
||||
<inertial pos="0.0354744 0.000827888 0.0003809" quat="0.391313 0.552395 0.417187 0.606373" mass="0.0588507" diaginertia="1.28514e-05 1.22902e-05 5.9666e-06"/>
|
||||
<joint name="left_hand_middle_0_joint" pos="0 0 0" axis="0 0 1" range="-1.5708 0" actuatorfrcrange="-1.4 1.4"/>
|
||||
<geom type="mesh" contype="0" conaffinity="0" group="1" density="0" rgba="0.7 0.7 0.7 1" mesh="left_hand_middle_0_link"/>
|
||||
<geom type="mesh" rgba="0.7 0.7 0.7 1" mesh="left_hand_middle_0_link"/>
|
||||
<body name="left_hand_middle_1_link" pos="0.0458 0 0">
|
||||
<inertial pos="0.0262819 0.00171735 -0.000107789" quat="0.502612 0.491799 0.502639 0.502861" mass="0.0203063" diaginertia="4.61314e-06 3.86645e-06 1.53495e-06"/>
|
||||
<joint name="left_hand_middle_1_joint" pos="0 0 0" axis="0 0 1" range="-1.74533 0" actuatorfrcrange="-1.4 1.4"/>
|
||||
<geom type="mesh" contype="0" conaffinity="0" group="1" density="0" rgba="0.7 0.7 0.7 1" mesh="left_hand_middle_1_link"/>
|
||||
<geom type="mesh" rgba="0.7 0.7 0.7 1" mesh="left_hand_middle_1_link"/>
|
||||
</body>
|
||||
</body>
|
||||
<body name="left_hand_index_0_link" pos="0.1192 0.0046 0.0285">
|
||||
<inertial pos="0.0354744 0.000827888 0.0003809" quat="0.391313 0.552395 0.417187 0.606373" mass="0.0588507" diaginertia="1.28514e-05 1.22902e-05 5.9666e-06"/>
|
||||
<joint name="left_hand_index_0_joint" pos="0 0 0" axis="0 0 1" range="-1.5708 0" actuatorfrcrange="-1.4 1.4"/>
|
||||
<geom type="mesh" contype="0" conaffinity="0" group="1" density="0" rgba="0.7 0.7 0.7 1" mesh="left_hand_index_0_link"/>
|
||||
<geom type="mesh" rgba="0.7 0.7 0.7 1" mesh="left_hand_index_0_link"/>
|
||||
<body name="left_hand_index_1_link" pos="0.0458 0 0">
|
||||
<inertial pos="0.0262819 0.00171735 -0.000107789" quat="0.502612 0.491799 0.502639 0.502861" mass="0.0203063" diaginertia="4.61314e-06 3.86645e-06 1.53495e-06"/>
|
||||
<joint name="left_hand_index_1_joint" pos="0 0 0" axis="0 0 1" range="-1.74533 0" actuatorfrcrange="-1.4 1.4"/>
|
||||
<geom type="mesh" contype="0" conaffinity="0" group="1" density="0" rgba="0.7 0.7 0.7 1" mesh="left_hand_index_1_link"/>
|
||||
<geom type="mesh" rgba="0.7 0.7 0.7 1" mesh="left_hand_index_1_link"/>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
<body name="right_shoulder_pitch_link" pos="0.0039563 -0.10021 0.24778" quat="0.990264 -0.139201 1.38722e-05 9.86868e-05">
|
||||
<inertial pos="0 -0.035892 -0.011628" quat="0.68225 -0.326267 0.0130458 0.654152" mass="0.718" diaginertia="0.000465864 0.000432842 0.000406394"/>
|
||||
<joint name="right_shoulder_pitch_joint" pos="0 0 0" axis="0 1 0" range="-3.0892 2.6704" actuatorfrcrange="-25 25"/>
|
||||
<geom type="mesh" contype="0" conaffinity="0" group="1" density="0" rgba="0.7 0.7 0.7 1" mesh="right_shoulder_pitch_link"/>
|
||||
<geom size="0.03 0.025" pos="0 -0.04 -0.01" quat="0.707107 0 0.707107 0" type="cylinder" rgba="0.7 0.7 0.7 1"/>
|
||||
<body name="right_shoulder_roll_link" pos="0 -0.038 -0.013831" quat="0.990268 0.139172 0 0">
|
||||
<inertial pos="-0.000227 -0.00727 -0.063243" quat="0.712604 -0.00710317 -0.0196223 0.701256" mass="0.643" diaginertia="0.000691311 0.000618011 0.000388977"/>
|
||||
<joint name="right_shoulder_roll_joint" pos="0 0 0" axis="1 0 0" range="-2.2515 1.5882" actuatorfrcrange="-25 25"/>
|
||||
<geom type="mesh" contype="0" conaffinity="0" group="1" density="0" rgba="0.7 0.7 0.7 1" mesh="right_shoulder_roll_link"/>
|
||||
<geom size="0.03 0.015" pos="-0.004 -0.006 -0.053" type="cylinder" rgba="0.7 0.7 0.7 1"/>
|
||||
<body name="right_shoulder_yaw_link" pos="0 -0.00624 -0.1032">
|
||||
<inertial pos="0.010773 0.002949 -0.072009" quat="0.687134 -0.0679942 -0.0964829 0.716879" mass="0.734" diaginertia="0.00106187 0.00103217 0.000400661"/>
|
||||
<joint name="right_shoulder_yaw_joint" pos="0 0 0" axis="0 0 1" range="-2.618 2.618" actuatorfrcrange="-25 25"/>
|
||||
<geom type="mesh" contype="0" conaffinity="0" group="1" density="0" rgba="0.7 0.7 0.7 1" mesh="right_shoulder_yaw_link"/>
|
||||
<geom type="mesh" rgba="0.7 0.7 0.7 1" mesh="right_shoulder_yaw_link"/>
|
||||
<body name="right_elbow_link" pos="0.015783 0 -0.080518">
|
||||
<inertial pos="0.064956 -0.004454 -0.010062" quat="0.388129 0.388821 0.636132 0.541765" mass="0.6" diaginertia="0.000443035 0.000421612 0.000259353"/>
|
||||
<joint name="right_elbow_joint" pos="0 0 0" axis="0 1 0" range="-1.0472 2.0944" actuatorfrcrange="-25 25"/>
|
||||
<geom type="mesh" contype="0" conaffinity="0" group="1" density="0" rgba="0.7 0.7 0.7 1" mesh="right_elbow_link"/>
|
||||
<geom type="mesh" rgba="0.7 0.7 0.7 1" mesh="right_elbow_link"/>
|
||||
<body name="right_wrist_roll_link" pos="0.1 -0.00188791 -0.01">
|
||||
<inertial pos="0.0171394 -0.000537591 4.8864e-07" quat="0.411667 0.575338 -0.411094 0.574906" mass="0.085445" diaginertia="5.48211e-05 4.96646e-05 3.57798e-05"/>
|
||||
<joint name="right_wrist_roll_joint" pos="0 0 0" axis="1 0 0" range="-1.97222 1.97222" actuatorfrcrange="-25 25"/>
|
||||
<geom type="mesh" contype="0" conaffinity="0" group="1" density="0" rgba="0.7 0.7 0.7 1" mesh="right_wrist_roll_link"/>
|
||||
<geom type="mesh" rgba="0.7 0.7 0.7 1" mesh="right_wrist_roll_link"/>
|
||||
<body name="right_wrist_pitch_link" pos="0.038 0 0">
|
||||
<inertial pos="0.0229999 0.00111685 -0.00111658" quat="0.643608 0.293036 0.661363 0.249998" mass="0.48405" diaginertia="0.000430353 0.000429873 0.000164648"/>
|
||||
<joint name="right_wrist_pitch_joint" pos="0 0 0" axis="0 1 0" range="-1.61443 1.61443" actuatorfrcrange="-5 5"/>
|
||||
<geom type="mesh" contype="0" conaffinity="0" group="1" density="0" rgba="0.7 0.7 0.7 1" mesh="right_wrist_pitch_link"/>
|
||||
<geom type="mesh" rgba="0.7 0.7 0.7 1" mesh="right_wrist_pitch_link"/>
|
||||
<body name="right_wrist_yaw_link" pos="0.046 0 0">
|
||||
<inertial pos="0.0885506 -0.00212216 -0.000374562" quat="0.505358 0.513241 0.493844 0.487149" mass="0.457415" diaginertia="0.00105989 0.000895419 0.000323842"/>
|
||||
<joint name="right_wrist_yaw_joint" pos="0 0 0" axis="0 0 1" range="-1.61443 1.61443" actuatorfrcrange="-5 5"/>
|
||||
<geom type="mesh" contype="0" conaffinity="0" group="1" density="0" rgba="0.7 0.7 0.7 1" mesh="right_wrist_yaw_link"/>
|
||||
<geom type="mesh" rgba="0.7 0.7 0.7 1" mesh="right_wrist_yaw_link"/>
|
||||
<geom pos="0.0415 -0.003 0" quat="1 0 0 0" type="mesh" contype="0" conaffinity="0" group="1" density="0" rgba="0.7 0.7 0.7 1" mesh="right_hand_palm_link"/>
|
||||
<geom pos="0.0415 -0.003 0" quat="1 0 0 0" type="mesh" rgba="0.7 0.7 0.7 1" mesh="right_hand_palm_link"/>
|
||||
<body name="right_hand_thumb_0_link" pos="0.067 -0.003 0">
|
||||
<inertial pos="-0.000884246 0.00863407 0.000944293" quat="0.643965 0.462991 -0.398986 0.460173" mass="0.0862366" diaginertia="1.6546e-05 1.60058e-05 1.43741e-05"/>
|
||||
<joint name="right_hand_thumb_0_joint" pos="0 0 0" axis="0 1 0" range="-1.0472 1.0472" actuatorfrcrange="-2.45 2.45"/>
|
||||
<geom type="mesh" contype="0" conaffinity="0" group="1" density="0" rgba="0.7 0.7 0.7 1" mesh="right_hand_thumb_0_link"/>
|
||||
<geom type="mesh" rgba="0.7 0.7 0.7 1" mesh="right_hand_thumb_0_link"/>
|
||||
<body name="right_hand_thumb_1_link" pos="-0.0025 0.0193 0">
|
||||
<inertial pos="-0.000827888 0.0354744 -0.0003809" quat="0.705471 0.685598 -0.0956069 0.15207" mass="0.0588507" diaginertia="1.28514e-05 1.22902e-05 5.9666e-06"/>
|
||||
<joint name="right_hand_thumb_1_joint" pos="0 0 0" axis="0 0 1" range="-1.0472 0.724312" actuatorfrcrange="-1.4 1.4"/>
|
||||
<geom type="mesh" contype="0" conaffinity="0" group="1" density="0" rgba="0.7 0.7 0.7 1" mesh="right_hand_thumb_1_link"/>
|
||||
<geom size="0.01 0.015 0.01" pos="-0.001 0.032 0" type="box" rgba="0.7 0.7 0.7 1"/>
|
||||
<body name="right_hand_thumb_2_link" pos="0 0.0458 0">
|
||||
<inertial pos="-0.00171735 0.0262819 0.000107789" quat="0.710977 0.703174 0.00766553 0.00017564" mass="0.0203063" diaginertia="4.61314e-06 3.86645e-06 1.53495e-06"/>
|
||||
<joint name="right_hand_thumb_2_joint" pos="0 0 0" axis="0 0 1" range="-1.74533 0" actuatorfrcrange="-1.4 1.4"/>
|
||||
<geom type="mesh" contype="0" conaffinity="0" group="1" density="0" rgba="0.7 0.7 0.7 1" mesh="right_hand_thumb_2_link"/>
|
||||
<geom type="mesh" rgba="0.7 0.7 0.7 1" mesh="right_hand_thumb_2_link"/>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
<body name="right_hand_middle_0_link" pos="0.1192 -0.0046 -0.0285">
|
||||
<inertial pos="0.0354744 -0.000827888 0.0003809" quat="0.606373 0.417187 0.552395 0.391313" mass="0.0588507" diaginertia="1.28514e-05 1.22902e-05 5.9666e-06"/>
|
||||
<joint name="right_hand_middle_0_joint" pos="0 0 0" axis="0 0 1" range="0 1.5708" actuatorfrcrange="-1.4 1.4"/>
|
||||
<geom type="mesh" contype="0" conaffinity="0" group="1" density="0" rgba="0.7 0.7 0.7 1" mesh="right_hand_middle_0_link"/>
|
||||
<geom type="mesh" rgba="0.7 0.7 0.7 1" mesh="right_hand_middle_0_link"/>
|
||||
<body name="right_hand_middle_1_link" pos="0.0458 0 0">
|
||||
<inertial pos="0.0262819 -0.00171735 -0.000107789" quat="0.502861 0.502639 0.491799 0.502612" mass="0.0203063" diaginertia="4.61314e-06 3.86645e-06 1.53495e-06"/>
|
||||
<joint name="right_hand_middle_1_joint" pos="0 0 0" axis="0 0 1" range="0 1.74533" actuatorfrcrange="-1.4 1.4"/>
|
||||
<geom type="mesh" contype="0" conaffinity="0" group="1" density="0" rgba="0.7 0.7 0.7 1" mesh="right_hand_middle_1_link"/>
|
||||
<geom type="mesh" rgba="0.7 0.7 0.7 1" mesh="right_hand_middle_1_link"/>
|
||||
</body>
|
||||
</body>
|
||||
<body name="right_hand_index_0_link" pos="0.1192 -0.0046 0.0285">
|
||||
<inertial pos="0.0354744 -0.000827888 0.0003809" quat="0.606373 0.417187 0.552395 0.391313" mass="0.0588507" diaginertia="1.28514e-05 1.22902e-05 5.9666e-06"/>
|
||||
<joint name="right_hand_index_0_joint" pos="0 0 0" axis="0 0 1" range="0 1.5708" actuatorfrcrange="-1.4 1.4"/>
|
||||
<geom type="mesh" contype="0" conaffinity="0" group="1" density="0" rgba="0.7 0.7 0.7 1" mesh="right_hand_index_0_link"/>
|
||||
<geom type="mesh" rgba="0.7 0.7 0.7 1" mesh="right_hand_index_0_link"/>
|
||||
<body name="right_hand_index_1_link" pos="0.0458 0 0">
|
||||
<inertial pos="0.0262819 -0.00171735 -0.000107789" quat="0.502861 0.502639 0.491799 0.502612" mass="0.0203063" diaginertia="4.61314e-06 3.86645e-06 1.53495e-06"/>
|
||||
<joint name="right_hand_index_1_joint" pos="0 0 0" axis="0 0 1" range="0 1.74533" actuatorfrcrange="-1.4 1.4"/>
|
||||
<geom type="mesh" contype="0" conaffinity="0" group="1" density="0" rgba="0.7 0.7 0.7 1" mesh="right_hand_index_1_link"/>
|
||||
<geom type="mesh" rgba="0.7 0.7 0.7 1" mesh="right_hand_index_1_link"/>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</body>
|
||||
</worldbody>
|
||||
|
||||
<actuator>
|
||||
<motor name="left_hip_pitch_joint" joint="left_hip_pitch_joint"/>
|
||||
<motor name="left_hip_roll_joint" joint="left_hip_roll_joint"/>
|
||||
<motor name="left_hip_yaw_joint" joint="left_hip_yaw_joint"/>
|
||||
<motor name="left_knee_joint" joint="left_knee_joint"/>
|
||||
<motor name="left_ankle_pitch_joint" joint="left_ankle_pitch_joint"/>
|
||||
<motor name="left_ankle_roll_joint" joint="left_ankle_roll_joint"/>
|
||||
<motor name="right_hip_pitch_joint" joint="right_hip_pitch_joint"/>
|
||||
<motor name="right_hip_roll_joint" joint="right_hip_roll_joint"/>
|
||||
<motor name="right_hip_yaw_joint" joint="right_hip_yaw_joint"/>
|
||||
<motor name="right_knee_joint" joint="right_knee_joint"/>
|
||||
<motor name="right_ankle_pitch_joint" joint="right_ankle_pitch_joint"/>
|
||||
<motor name="right_ankle_roll_joint" joint="right_ankle_roll_joint"/>
|
||||
<motor name="waist_yaw_joint" joint="waist_yaw_joint"/>
|
||||
<motor name="waist_roll_joint" joint="waist_roll_joint"/>
|
||||
<motor name="waist_pitch_joint" joint="waist_pitch_joint"/>
|
||||
<motor name="left_shoulder_pitch_joint" joint="left_shoulder_pitch_joint"/>
|
||||
<motor name="left_shoulder_roll_joint" joint="left_shoulder_roll_joint"/>
|
||||
<motor name="left_shoulder_yaw_joint" joint="left_shoulder_yaw_joint"/>
|
||||
<motor name="left_elbow_joint" joint="left_elbow_joint"/>
|
||||
<motor name="left_wrist_roll_joint" joint="left_wrist_roll_joint"/>
|
||||
<motor name="left_wrist_pitch_joint" joint="left_wrist_pitch_joint"/>
|
||||
<motor name="left_wrist_yaw_joint" joint="left_wrist_yaw_joint"/>
|
||||
<motor name="left_hand_thumb_0_joint" joint="left_hand_thumb_0_joint"/>
|
||||
<motor name="left_hand_thumb_1_joint" joint="left_hand_thumb_1_joint"/>
|
||||
<motor name="left_hand_thumb_2_joint" joint="left_hand_thumb_2_joint"/>
|
||||
<motor name="left_hand_middle_0_joint" joint="left_hand_middle_0_joint"/>
|
||||
<motor name="left_hand_middle_1_joint" joint="left_hand_middle_1_joint"/>
|
||||
<motor name="left_hand_index_0_joint" joint="left_hand_index_0_joint"/>
|
||||
<motor name="left_hand_index_1_joint" joint="left_hand_index_1_joint"/>
|
||||
<motor name="right_shoulder_pitch_joint" joint="right_shoulder_pitch_joint"/>
|
||||
<motor name="right_shoulder_roll_joint" joint="right_shoulder_roll_joint"/>
|
||||
<motor name="right_shoulder_yaw_joint" joint="right_shoulder_yaw_joint"/>
|
||||
<motor name="right_elbow_joint" joint="right_elbow_joint"/>
|
||||
<motor name="right_wrist_roll_joint" joint="right_wrist_roll_joint"/>
|
||||
<motor name="right_wrist_pitch_joint" joint="right_wrist_pitch_joint"/>
|
||||
<motor name="right_wrist_yaw_joint" joint="right_wrist_yaw_joint"/>
|
||||
<motor name="right_hand_thumb_0_joint" joint="right_hand_thumb_0_joint"/>
|
||||
<motor name="right_hand_thumb_1_joint" joint="right_hand_thumb_1_joint"/>
|
||||
<motor name="right_hand_thumb_2_joint" joint="right_hand_thumb_2_joint"/>
|
||||
<motor name="right_hand_index_0_joint" joint="right_hand_index_0_joint"/>
|
||||
<motor name="right_hand_index_1_joint" joint="right_hand_index_1_joint"/>
|
||||
<motor name="right_hand_middle_0_joint" joint="right_hand_middle_0_joint"/>
|
||||
<motor name="right_hand_middle_1_joint" joint="right_hand_middle_1_joint"/>
|
||||
</actuator>
|
||||
|
||||
<sensor>
|
||||
<gyro name="imu-torso-angular-velocity" site="imu_in_torso" noise="5e-4" cutoff="34.9"/>
|
||||
<accelerometer name="imu-torso-linear-acceleration" site="imu_in_torso" noise="1e-2" cutoff="157"/>
|
||||
<gyro name="imu-pelvis-angular-velocity" site="imu_in_pelvis" noise="5e-4" cutoff="34.9"/>
|
||||
<accelerometer name="imu-pelvis-linear-acceleration" site="imu_in_pelvis" noise="1e-2" cutoff="157"/>
|
||||
</sensor>
|
||||
|
||||
|
||||
<!-- setup scene -->
|
||||
<statistic center="1.0 0.7 1.0" extent="0.8"/>
|
||||
<visual>
|
||||
<headlight diffuse="0.6 0.6 0.6" ambient="0.1 0.1 0.1" specular="0.9 0.9 0.9"/>
|
||||
<rgba haze="0.15 0.25 0.35 1"/>
|
||||
<global azimuth="-140" elevation="-20"/>
|
||||
</visual>
|
||||
<asset>
|
||||
<texture type="skybox" builtin="flat" rgb1="0 0 0" rgb2="0 0 0" width="512" height="3072"/>
|
||||
<texture type="2d" name="groundplane" builtin="checker" mark="edge" rgb1="0.2 0.3 0.4" rgb2="0.1 0.2 0.3" markrgb="0.8 0.8 0.8" width="300" height="300"/>
|
||||
<material name="groundplane" texture="groundplane" texuniform="true" texrepeat="5 5" reflectance="0.2"/>
|
||||
</asset>
|
||||
<worldbody>
|
||||
<light pos="1 0 3.5" dir="0 0 -1" directional="true"/>
|
||||
<geom name="floor" size="0 0 0.05" type="plane" material="groundplane"/>
|
||||
</worldbody>
|
||||
</mujoco>
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
src/lerobot/robots/unitree_g1/assets/g1/locomotion/motion.pt
Normal file
BIN
src/lerobot/robots/unitree_g1/assets/g1/locomotion/motion.pt
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
src/lerobot/robots/unitree_g1/assets/g1/locomotion/motion_500.pt
Normal file
BIN
src/lerobot/robots/unitree_g1/assets/g1/locomotion/motion_500.pt
Normal file
Binary file not shown.
Binary file not shown.
BIN
src/lerobot/robots/unitree_g1/assets/g1/meshes/head_link.STL
Normal file
BIN
src/lerobot/robots/unitree_g1/assets/g1/meshes/head_link.STL
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
src/lerobot/robots/unitree_g1/assets/g1/meshes/logo_link.STL
Normal file
BIN
src/lerobot/robots/unitree_g1/assets/g1/meshes/logo_link.STL
Normal file
Binary file not shown.
BIN
src/lerobot/robots/unitree_g1/assets/g1/meshes/pelvis.STL
Normal file
BIN
src/lerobot/robots/unitree_g1/assets/g1/meshes/pelvis.STL
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user