mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 19:01:28 +00:00
393 lines
14 KiB
Python
393 lines
14 KiB
Python
#!/usr/bin/env python
|
||
|
||
"""
|
||
Standalone evaluation script for RLearN models.
|
||
|
||
This script evaluates RLearN reward models on episodes from a dataset,
|
||
generating comparison plots between ground truth rewards and model predictions.
|
||
|
||
Usage:
|
||
python src/lerobot/policies/rlearn/eval_script.py --model MODEL_NAME --dataset DATASET_REPO --episodes N
|
||
|
||
Example:
|
||
python src/lerobot/policies/rlearn/eval_script.py --model pepijn223/rlearn_18 --dataset pepijn223/phone_pipeline_pickup1 --episodes 2
|
||
"""
|
||
|
||
import argparse
|
||
import os
|
||
import sys
|
||
from pathlib import Path
|
||
|
||
# Add src to path for imports
|
||
sys.path.append(str(Path(__file__).parent.parent.parent.parent))
|
||
|
||
import warnings
|
||
|
||
import matplotlib.pyplot as plt
|
||
import numpy as np
|
||
import torch
|
||
from scipy.stats import spearmanr
|
||
from tqdm import tqdm
|
||
|
||
warnings.filterwarnings("ignore")
|
||
|
||
# LeRobot imports
|
||
from lerobot.constants import OBS_IMAGE, OBS_IMAGES, OBS_LANGUAGE
|
||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||
from lerobot.policies.rlearn.modeling_rlearn import RLearNPolicy
|
||
|
||
|
||
def _to_chw_float01(img):
|
||
"""Ensure CHW float in [0,1]."""
|
||
if isinstance(img, np.ndarray):
|
||
img = torch.from_numpy(img)
|
||
# HWC -> CHW if needed
|
||
if len(img.shape) == 3 and img.shape[-1] in (1, 3, 4):
|
||
img = img.permute(2, 0, 1)
|
||
if img.dtype == torch.uint8:
|
||
img = img.float() / 255.0
|
||
else:
|
||
img = img.float()
|
||
return torch.clamp(img, 0.0, 1.0)
|
||
|
||
|
||
def _get_language(frame_data):
|
||
lang = None
|
||
if OBS_LANGUAGE in frame_data:
|
||
lang = frame_data[OBS_LANGUAGE]
|
||
if isinstance(lang, list) and len(lang) > 0:
|
||
lang = lang[0]
|
||
elif "task" in frame_data:
|
||
lang = frame_data["task"]
|
||
return lang if isinstance(lang, str) else "No language provided"
|
||
|
||
|
||
def _get_ground_truth_reward(frame_data):
|
||
"""Try common keys for ground-truth reward. Return None if unavailable."""
|
||
for key in ("reward", "rewards", "gt_reward", "progress"):
|
||
if key in frame_data:
|
||
r = frame_data[key]
|
||
# unwrap single-element lists/arrays
|
||
if isinstance(r, (list, np.ndarray)) and np.array(r).size == 1:
|
||
r = float(np.array(r).reshape(-1)[0])
|
||
try:
|
||
return float(r)
|
||
except Exception:
|
||
pass
|
||
return None
|
||
|
||
|
||
def extract_episode_frames_and_gt(dataset, episode_idx):
|
||
"""Load a full episode: frames (T, C, H, W), language (str), gt_rewards (np.ndarray or None)."""
|
||
ep_start = dataset.episode_data_index["from"][episode_idx].item()
|
||
ep_end = dataset.episode_data_index["to"][episode_idx].item()
|
||
T = ep_end - ep_start
|
||
|
||
frames = []
|
||
gt_rewards = []
|
||
language = None
|
||
|
||
for t in range(T):
|
||
item = dataset[ep_start + t]
|
||
|
||
# image(s)
|
||
if OBS_IMAGES in item:
|
||
img = item[OBS_IMAGES]
|
||
elif OBS_IMAGE in item:
|
||
img = item[OBS_IMAGE]
|
||
else:
|
||
# try to find an image-like key
|
||
img_keys = [k for k in item.keys() if "image" in k.lower()]
|
||
if not img_keys:
|
||
continue
|
||
img = item[img_keys[0]]
|
||
|
||
frames.append(_to_chw_float01(img))
|
||
|
||
# language once
|
||
if language is None:
|
||
language = _get_language(item)
|
||
|
||
# ground-truth reward (optional)
|
||
r = _get_ground_truth_reward(item)
|
||
gt_rewards.append(r)
|
||
|
||
if not frames:
|
||
return None, None, None
|
||
|
||
frames = torch.stack(frames) # (T, C, H, W)
|
||
|
||
# If all GT entries are None, treat as missing
|
||
if all(r is None for r in gt_rewards):
|
||
gt_rewards = None
|
||
else:
|
||
# Replace None by forward filling
|
||
arr = np.array([np.nan if r is None else float(r) for r in gt_rewards], dtype=float)
|
||
# forward/back fill
|
||
if np.isnan(arr[0]):
|
||
first_valid = np.flatnonzero(~np.isnan(arr))
|
||
if len(first_valid) > 0:
|
||
arr[0] = arr[first_valid[0]]
|
||
else:
|
||
arr[0] = 0.0
|
||
for i in range(1, len(arr)):
|
||
if np.isnan(arr[i]):
|
||
arr[i] = arr[i - 1]
|
||
gt_rewards = arr
|
||
|
||
return frames, language or "No language provided", gt_rewards
|
||
|
||
|
||
@torch.no_grad()
|
||
def predict_rewards_sliding(model, frames, language, max_seq_len=16, batch_size=64, device="cuda", temporal_stride: int | None = None):
|
||
"""
|
||
Sliding-window prediction: for each frame i, create a window [max(0, i-L+1) .. i],
|
||
left-pad by repeating the first frame to length L (<= 16), and take the prediction
|
||
corresponding to the current frame's position in the window.
|
||
Returns np.ndarray of shape (T,).
|
||
"""
|
||
T = frames.shape[0]
|
||
cfg = getattr(model, "config", object())
|
||
L = int(getattr(cfg, "max_seq_len", max_seq_len))
|
||
L = min(L, max_seq_len) # hard-cap at 16
|
||
# Use the same temporal stride as training (skip s-1 frames, take 1)
|
||
if temporal_stride is None:
|
||
temporal_stride = int(getattr(cfg, "temporal_sampling_stride", 1))
|
||
temporal_stride = max(1, int(temporal_stride))
|
||
|
||
# Preprocessed tensor on device
|
||
frames = frames.to(device)
|
||
|
||
windows = []
|
||
frame_positions = [] # Track which temporal position each frame should use
|
||
left_pad_counts = [] # Number of left-pad (OOB) frames per window
|
||
|
||
for i in range(T):
|
||
# Build indices with stride s: [..., i-3, i] etc., left-padded by clamping to 0
|
||
idxs = [i - (L - 1 - j) * temporal_stride for j in range(L)]
|
||
pad_needed = sum(1 for k in idxs if k < 0)
|
||
clamped = [0 if k < 0 else (T - 1 if k >= T else k) for k in idxs]
|
||
window = frames[clamped] # (L, C, H, W)
|
||
|
||
# Use the last temporal position (current frame) for reading model output
|
||
frame_pos = L - 1
|
||
|
||
windows.append(window)
|
||
frame_positions.append(frame_pos)
|
||
left_pad_counts.append(pad_needed)
|
||
|
||
preds = np.zeros(T, dtype=float)
|
||
|
||
for s in range(0, T, batch_size):
|
||
e = min(s + batch_size, T)
|
||
batch_windows = torch.stack(windows[s:e]) # (B, L, C, H, W)
|
||
batch_positions = frame_positions[s:e]
|
||
|
||
batch = {OBS_IMAGES: batch_windows, OBS_LANGUAGE: [language] * (e - s)} # expects (B, L, C, H, W)
|
||
|
||
# Model returns (B, L) predictions for each temporal position
|
||
values = model.predict_rewards(batch) # torch.Tensor (B, L)
|
||
|
||
# Apply eval-time padding rule: predictions for left-padded (OOB) frames are zero
|
||
if values.dim() == 2 and len(left_pad_counts) >= (e - s):
|
||
for b_idx in range(e - s):
|
||
pad_n = left_pad_counts[s + b_idx]
|
||
if pad_n > 0:
|
||
values[b_idx, :pad_n] = 0.0
|
||
|
||
# Debug output removed - issue was identified and fixed
|
||
|
||
if values.dim() == 2:
|
||
# Extract the prediction corresponding to each frame's position in its window
|
||
batch_preds = []
|
||
for b_idx, pos in enumerate(batch_positions):
|
||
batch_preds.append(values[b_idx, pos].item())
|
||
preds[s:e] = np.array(batch_preds)
|
||
else:
|
||
# Fallback: if model returns (B,), use as is
|
||
preds[s:e] = values.detach().float().cpu().numpy()
|
||
|
||
return preds
|
||
|
||
|
||
def plot_episode_eval(episode_idx, gt, pred, language, save_path=None, show=False, title_prefix="RLearN Eval"):
|
||
"""Plot GT vs Predicted over time. Saves PNG if save_path is provided."""
|
||
T = len(pred)
|
||
x = np.arange(T)
|
||
|
||
plt.figure(figsize=(14, 8))
|
||
plt.plot(x, pred, linewidth=2.5, marker="o", markersize=3, label="Predicted Reward", color="blue")
|
||
|
||
if gt is not None:
|
||
plt.plot(x, gt, linestyle="--", linewidth=2.5, label="Ground-Truth Reward", color="orange")
|
||
# Correlation between GT and Pred
|
||
corr, p = spearmanr(gt, pred)
|
||
corr_str = f"ρ(GT, Pred) = {0.0 if np.isnan(corr) else corr:.3f} (p={0.0 if np.isnan(p) else p:.3f})"
|
||
else:
|
||
expected = np.linspace(0, 1, T)
|
||
plt.plot(x, expected, linestyle="--", linewidth=2.5, label="Expected Progress (0→1)", color="orange")
|
||
corr, p = spearmanr(x, pred)
|
||
corr_str = f"VOC-S ρ(t, Pred) = {0.0 if np.isnan(corr) else corr:.3f} (p={0.0 if np.isnan(p) else p:.3f})"
|
||
|
||
plt.title(f"{title_prefix} — Episode {episode_idx}\n{language}\n{corr_str}", fontsize=14)
|
||
plt.xlabel("Frame Index", fontsize=12)
|
||
plt.ylabel("Reward / Progress", fontsize=12)
|
||
plt.legend(fontsize=11)
|
||
plt.grid(True, alpha=0.3)
|
||
plt.tight_layout()
|
||
|
||
if save_path is not None:
|
||
plt.savefig(save_path, dpi=200, bbox_inches="tight")
|
||
print(f"Saved eval image to: {save_path}")
|
||
|
||
if show:
|
||
plt.show()
|
||
else:
|
||
plt.close()
|
||
|
||
|
||
def eval_episode_sliding(
|
||
episode_idx, dataset, model, save_dir=".", device="cuda", max_seq_len=16, batch_size=64, title_prefix="RLearN Eval"
|
||
):
|
||
"""End-to-end: load episode, predict with sliding 16-frame windows, and save PNG."""
|
||
frames, language, gt = extract_episode_frames_and_gt(dataset, episode_idx)
|
||
if frames is None:
|
||
print(f"[Episode {episode_idx}] No frames found.")
|
||
return None
|
||
|
||
model.eval()
|
||
|
||
pred = predict_rewards_sliding(
|
||
model=model, frames=frames, language=language, max_seq_len=max_seq_len, batch_size=batch_size, device=device
|
||
)
|
||
|
||
# Basic stats
|
||
print(f"Episode {episode_idx}: T={len(pred)}, pred∈[{pred.min():.3f},{pred.max():.3f}]")
|
||
if gt is not None:
|
||
print(f"GT available: gt∈[{np.nanmin(gt):.3f},{np.nanmax(gt):.3f}]")
|
||
|
||
save_path = f"{save_dir}/episode_{episode_idx:04d}_eval.png"
|
||
plot_episode_eval(
|
||
episode_idx=episode_idx, gt=gt, pred=pred, language=language, save_path=save_path, show=False, title_prefix=title_prefix
|
||
)
|
||
return save_path
|
||
|
||
|
||
def main():
|
||
"""Main evaluation script for RLearN models."""
|
||
# Parse command line arguments
|
||
parser = argparse.ArgumentParser(description="Evaluate RLearN model on episodes with GT vs Predicted rewards")
|
||
parser.add_argument("--model", type=str, required=True, help="Model name/path (e.g., pepijn223/rlearn_mse5)")
|
||
parser.add_argument("--dataset", type=str, required=True, help="Dataset repo (e.g., pepijn223/phone_pipeline_pickup1)")
|
||
parser.add_argument("--episodes", type=int, default=5, help="Number of episodes to evaluate")
|
||
parser.add_argument("--output", type=str, default="./eval_results", help="Output directory for images")
|
||
parser.add_argument(
|
||
"--device",
|
||
type=str,
|
||
default="cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu",
|
||
help="Device to use",
|
||
)
|
||
parser.add_argument("--batch_size", type=int, default=32, help="Batch size for sliding window evaluation")
|
||
|
||
args = parser.parse_args()
|
||
|
||
# Create output directory
|
||
output_dir = Path(args.output)
|
||
output_dir.mkdir(parents=True, exist_ok=True)
|
||
|
||
print("🎯 RLearN Model Evaluation")
|
||
print("=" * 60)
|
||
print(f"Model: {args.model}")
|
||
print(f"Dataset: {args.dataset}")
|
||
print(f"Episodes: {args.episodes}")
|
||
print(f"Device: {args.device}")
|
||
print(f"Output: {output_dir}")
|
||
print("=" * 60)
|
||
|
||
try:
|
||
# Load dataset
|
||
print("📁 Loading dataset...")
|
||
|
||
dataset = LeRobotDataset(
|
||
repo_id=args.dataset,
|
||
episodes=list(range(min(args.episodes, 50))), # Load enough episodes
|
||
download_videos=True,
|
||
)
|
||
|
||
print(f"✅ Dataset loaded: {dataset.num_episodes} episodes, {dataset.num_frames} frames")
|
||
print(f" Features: {list(dataset.features.keys())}")
|
||
print(f" FPS: {dataset.fps}")
|
||
|
||
# Load model
|
||
print("\n🤖 Loading model...")
|
||
|
||
model = RLearNPolicy.from_pretrained(args.model)
|
||
model = model.to(args.device)
|
||
model.eval()
|
||
|
||
print(f"✅ Model loaded on {args.device}")
|
||
print(f" Parameters: {sum(p.numel() for p in model.parameters()):,}")
|
||
print(f" Trainable: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
|
||
print(f" Max sequence length: {model.config.max_seq_len}")
|
||
|
||
# Select episodes to evaluate
|
||
total_available = min(dataset.num_episodes, args.episodes)
|
||
episode_indices = list(range(total_available))
|
||
|
||
print(f"\n📊 Evaluating {len(episode_indices)} episodes...")
|
||
print("=" * 60)
|
||
|
||
# Run sliding window evaluation on each episode
|
||
saved_paths = []
|
||
for i, ep_idx in enumerate(episode_indices):
|
||
print(f"\n[{i+1}/{len(episode_indices)}] Processing Episode {ep_idx}")
|
||
print("-" * 40)
|
||
|
||
try:
|
||
save_path = eval_episode_sliding(
|
||
episode_idx=ep_idx,
|
||
dataset=dataset,
|
||
model=model,
|
||
save_dir=str(output_dir),
|
||
device=args.device,
|
||
batch_size=args.batch_size,
|
||
title_prefix="RLearN Ground Truth vs Predicted",
|
||
)
|
||
|
||
if save_path:
|
||
saved_paths.append(save_path)
|
||
|
||
except Exception as e:
|
||
print(f"❌ Error processing episode {ep_idx}: {e}")
|
||
import traceback
|
||
|
||
traceback.print_exc()
|
||
continue
|
||
|
||
# Summary
|
||
print("\n" + "=" * 60)
|
||
print("✅ EVALUATION COMPLETE")
|
||
print(f"📈 Generated {len(saved_paths)} evaluation plots")
|
||
print(f"📁 Results saved to: {output_dir}")
|
||
print("\nGenerated files:")
|
||
for path in saved_paths:
|
||
print(f" • {path}")
|
||
|
||
if saved_paths:
|
||
print(f"\n💡 View the plots to compare ground truth vs predicted rewards!")
|
||
print(f" Each plot shows the model's sliding 16-frame window predictions")
|
||
print(f" against available ground truth rewards over the episode timeline.")
|
||
|
||
return 0
|
||
|
||
except Exception as e:
|
||
print(f"❌ Error during evaluation: {e}")
|
||
import traceback
|
||
|
||
traceback.print_exc()
|
||
return 1
|
||
|
||
|
||
if __name__ == "__main__":
|
||
exit(main())
|