pi052: hierarchical select_action + RoboCasa eval video overlay

- modeling_pi052: per-env low-level subtask generation in select_action so
  hierarchical inference is correct for eval.batch_size > 1
- render_messages_processor: always emit a fallback low-level prompt so
  observation.language.tokens are produced when recipe annotations are absent
- lerobot_eval: overlay high-level task + predicted subtask onto recorded
  rollout videos (render path only; does not affect policy observations)

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
pepijn223
2026-06-01 14:35:13 +02:00
parent 1f1541243a
commit bb2c09965b
3 changed files with 365 additions and 4 deletions

View File

@@ -46,7 +46,7 @@ import torch
from torch import Tensor
from torch.nn import functional as F
from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
from lerobot.utils.constants import ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
from ..pi05.configuration_pi05 import PI05Config
from ..pi05.modeling_pi05 import PI05Policy
@@ -461,6 +461,32 @@ class PI052Policy(PI05Policy):
"gradients are blocked in attention."
)
# Per-env hierarchical-inference state. Sized lazily on the first
# select_action() call once the batch size (number of parallel envs)
# is known. ``last_subtasks[i]`` is the subtask currently conditioning
# env ``i``'s action expert; scalar ``last_subtask`` mirrors env 0 for
# back-compat (e.g. the eval video overlay).
self.last_subtasks: list[str] | None = None
self.last_subtasks_raw: list[str] | None = None
self.last_subtasks_source: list[str] | None = None
self._last_good_subtasks: list[str | None] | None = None
self.last_subtask: str | None = None
self.last_subtask_raw: str | None = None
self.last_subtask_source: str = "unset"
self.last_subtask_debug: str = ""
def reset(self):
    """Reset the parent policy state, then clear hierarchical-inference caches.

    After the parent ``reset()`` runs, the per-env subtask buffers and their
    scalar aliases are returned to the same "nothing generated yet" values
    they are given at construction time.
    """
    super().reset()
    # Per-env buffers and the scalar text aliases all go back to ``None``;
    # the buffers are re-sized lazily once the batch size is known again.
    for attr_name in (
        "last_subtasks",
        "last_subtasks_raw",
        "last_subtasks_source",
        "_last_good_subtasks",
        "last_subtask",
        "last_subtask_raw",
    ):
        setattr(self, attr_name, None)
    self.last_subtask_source = "unset"
    self.last_subtask_debug = ""
# ------------------------------------------------------------------
# Head unfreeze helper
# ------------------------------------------------------------------
@@ -1167,6 +1193,226 @@ class PI052Policy(PI05Policy):
self._last_select_message_debug = ""
return decoded
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
    """Return the next action, refilling the action queue when it is empty.

    When the queue is empty, the batch is first rewritten so that its
    language tokens carry a freshly generated low-level subtask prompt
    (see ``_with_low_level_subtask_prompt``). An action chunk is then
    predicted from that rewritten batch, truncated to
    ``config.n_action_steps`` along dim 1, and queued step by step.

    Args:
        batch: Observation batch for the current step.

    Returns:
        The action popped from the front of the queue.

    Raises:
        AssertionError: If RTC is enabled; use ``predict_action_chunk``
            in that case.
    """
    assert not self._rtc_enabled(), (
        "RTC is not supported for select_action, use it with predict_action_chunk"
    )
    self.eval()

    if not len(self._action_queue):
        prompted_batch = self._with_low_level_subtask_prompt(batch)
        chunk = self.predict_action_chunk(prompted_batch)
        horizon = self.config.n_action_steps
        # Swap dims 0 and 1 so that each queue entry is one step for all envs.
        per_step = chunk[:, :horizon].transpose(0, 1)
        self._action_queue.extend(per_step)

    return self._action_queue.popleft()
def _with_low_level_subtask_prompt(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
    """Return a shallow copy of ``batch`` whose language tokens encode per-env subtasks.

    For every env in the batch a subtask string is produced from that env's
    own task and observation slice, tokenized as a single ``user`` message,
    and the resulting rows are padded and stacked into one token/mask pair.
    The per-env state lists are updated by ``_generate_low_level_subtask``;
    the scalar ``last_subtask*`` aliases are refreshed here from env 0.

    Args:
        batch: Observation batch; may contain a ``"task"`` entry.

    Returns:
        A new dict with ``OBS_LANGUAGE_TOKENS`` and
        ``OBS_LANGUAGE_ATTENTION_MASK`` replaced by the stacked prompts.
    """
    from .inference.steps import _build_text_batch  # noqa: PLC0415

    num_envs = self._batch_size_from_observation(batch)
    self._ensure_subtask_state(num_envs)
    env_tasks = self._tasks_from_batch(batch, num_envs)

    # One prompt per env: env i is conditioned on env i's subtask rather than
    # on a broadcast of env 0's subtask.
    prompt_rows: list[tuple[Tensor, Tensor | None]] = []
    tokenizer = None
    for env_idx, env_task in enumerate(env_tasks):
        env_obs = self._slice_observation(batch, env_idx)
        env_subtask = self._generate_low_level_subtask(env_obs, env_task, env_idx)
        encoded = _build_text_batch(
            self,
            [{"role": "user", "content": env_subtask}],
            add_generation_prompt=False,
        )
        prompt_rows.append((encoded["lang_tokens"], encoded["lang_masks"]))
        tokenizer = encoded["tokenizer"]
    tokens, masks = self._stack_token_rows(prompt_rows, tokenizer)

    # Scalar aliases mirror env 0 for back-compat / single-env overlays.
    if self.last_subtasks:
        self.last_subtask = self.last_subtasks[0]
    else:
        self.last_subtask = None
    if self.last_subtasks_raw:
        self.last_subtask_raw = self.last_subtasks_raw[0]
    else:
        self.last_subtask_raw = None
    if self.last_subtasks_source:
        self.last_subtask_source = self.last_subtasks_source[0]
    else:
        self.last_subtask_source = "unset"

    prompted = dict(batch)
    prompted[OBS_LANGUAGE_TOKENS] = tokens
    prompted[OBS_LANGUAGE_ATTENTION_MASK] = masks
    return prompted
def _generate_low_level_subtask(self, obs_i: dict[str, Tensor], task: str, i: int) -> str:
    """Produce the subtask string that conditions env ``i``'s action prompt.

    Generation is attempted only when ``task`` is non-empty. A non-empty,
    non-gibberish generation is whitespace-normalized, remembered as the
    env's last good subtask and returned with source ``"generated"``.
    Otherwise the env's last good subtask is reused (``"reuse_last"``) or,
    if there is none yet, one is derived from the task
    (``"fallback_task"``). The per-env state lists are updated in place
    and the outcome is logged.

    Args:
        obs_i: Observation tensors sliced down to env ``i``.
        task: High-level task string for env ``i`` (may be empty).
        i: Index of the env inside the per-env state lists.

    Returns:
        The subtask text to tokenize as the low-level prompt.
    """
    from .inference.steps import _generate_with_policy, _looks_like_gibberish  # noqa: PLC0415

    generated = (
        _generate_with_policy(
            self,
            [{"role": "user", "content": task}],
            observation=obs_i,
            label=f"eval subtask gen[{i}]",
            suppress_loc_tokens=True,
        )
        if task
        else ""
    )
    self.last_subtasks_raw[i] = generated or ""

    # Accepted generation: use the model's own subtask, only collapsing
    # runs of whitespace.
    if generated and not _looks_like_gibberish(generated):
        cleaned = " ".join(generated.split())
        self._last_good_subtasks[i] = cleaned
        self.last_subtasks[i] = cleaned
        self.last_subtasks_source[i] = "generated"
        logger.info("PI052 eval subtask[%d]: %r (task=%r)", i, cleaned, task)
        return cleaned

    # Generation was empty or rejected: explain why, then fall back.
    debug = getattr(self, "_last_select_message_debug", "") or ""
    if not task:
        reason = "No task string was available in the batch."
    elif generated:
        reason = f"Rejected generated subtask: {generated!r}"
    else:
        reason = f"Empty generated subtask. {debug}".strip()

    previous = self._last_good_subtasks[i]
    if previous:
        # Prefer this env's most recent accepted subtask.
        chosen, source = previous, "reuse_last"
    else:
        # First chunk for this env: derive a command from the task text.
        chosen, source = self._fallback_subtask_from_task(task), "fallback_task"

    self.last_subtasks[i] = chosen
    self.last_subtasks_source[i] = source
    logger.info(
        "PI052 eval subtask[%d] fallback (%s): %s | final=%r task=%r",
        i,
        source,
        reason,
        chosen,
        task,
    )
    return chosen
def _ensure_subtask_state(self, n: int) -> None:
"""(Re)allocate per-env subtask buffers when the env count is first seen."""
if self.last_subtasks is not None and len(self.last_subtasks) == n:
return
self.last_subtasks = ["" for _ in range(n)]
self.last_subtasks_raw = ["" for _ in range(n)]
self.last_subtasks_source = ["unset" for _ in range(n)]
self._last_good_subtasks = [None for _ in range(n)]
@staticmethod
def _slice_observation(batch: dict[str, Tensor], i: int) -> dict[str, Tensor]:
"""Slice the per-env observation tensors for env ``i`` (images/state).
Language keys are excluded so high-level generation uses the freshly
tokenized task prompt, not the preprocessor's low-level fallback tokens.
"""
out: dict[str, Tensor] = {}
for k, v in batch.items():
if not (isinstance(k, str) and k.startswith("observation.")):
continue
if k.startswith("observation.language"):
continue
if torch.is_tensor(v):
out[k] = v[i : i + 1]
return out
@staticmethod
def _stack_token_rows(
rows: list[tuple[Tensor, Tensor | None]], tokenizer: Any
) -> tuple[Tensor, Tensor]:
"""Right-pad per-env ``(1, L_i)`` token/mask rows and stack to ``(n, L)``.
Right-padding with a False attention mask matches the training-time
tokenizer (``padding_side="right"``), so the action expert treats pad
positions as masked.
"""
max_len = max(t.shape[1] for t, _ in rows)
pad_id = getattr(tokenizer, "pad_token_id", None) or 0
tok_rows: list[Tensor] = []
mask_rows: list[Tensor] = []
for tokens, masks in rows:
length = tokens.shape[1]
if masks is None:
masks = torch.ones((1, length), dtype=torch.bool, device=tokens.device)
if length < max_len:
pad = max_len - length
tokens = torch.cat(
[tokens, torch.full((1, pad), pad_id, dtype=tokens.dtype, device=tokens.device)],
dim=1,
)
masks = torch.cat(
[masks, torch.zeros((1, pad), dtype=masks.dtype, device=masks.device)],
dim=1,
)
tok_rows.append(tokens)
mask_rows.append(masks)
return torch.cat(tok_rows, dim=0), torch.cat(mask_rows, dim=0)
@staticmethod
def _fallback_subtask_from_task(task: str) -> str:
    """Derive a stand-in subtask from the task text when generation is unusable.

    Navigation tasks become ``"go to <target>"``; tasks that start with
    "open the stand mixer head" (case-insensitive) become
    ``"pull stand mixer head"``; any other task is returned unchanged.
    """
    nav_target = PI052Policy._navigation_target_from_task(task)
    if nav_target:
        return f"go to {nav_target}"
    opens_mixer_head = task.lower().startswith("open the stand mixer head")
    return "pull stand mixer head" if opens_mixer_head else task
@staticmethod
def _navigation_target_from_task(task: str) -> str:
prefix = "navigate to "
lower = task.lower().strip()
if not lower.startswith(prefix):
return ""
return lower[len(prefix) :].strip().rstrip(".")
@staticmethod
def _tasks_from_batch(batch: dict[str, Any], n: int) -> list[str]:
"""Return one task string per env, padded/truncated to ``n``."""
task = batch.get("task")
if isinstance(task, list):
raw = list(task)
elif task is None:
raw = []
else:
raw = [task]
tasks: list[str] = []
for t in raw:
if hasattr(t, "item"):
t = t.item()
tasks.append(t if isinstance(t, str) else "")
if len(tasks) < n:
tasks += [tasks[-1] if tasks else ""] * (n - len(tasks))
return tasks[:n]
@staticmethod
def _batch_size_from_observation(batch: dict[str, Any]) -> int:
state = batch.get("observation.state")
if torch.is_tensor(state) and state.ndim > 0:
return int(state.shape[0])
for key, value in batch.items():
if isinstance(key, str) and key.startswith("observation.images.") and torch.is_tensor(value):
return int(value.shape[0])
return 1
@staticmethod
def _sample_next_token(logits: Tensor, temperature: float, top_p: float) -> Tensor:
if temperature <= 0.0:

View File

@@ -50,7 +50,14 @@ class RenderMessagesStep(ProcessorStep):
events = complementary_data.get(LANGUAGE_EVENTS) or []
if not persistent and not events:
return transition
rendered = _fallback_low_level_render(complementary_data.get("task"))
if rendered is None:
return transition
new_transition = transition.copy()
new_complementary_data = dict(new_transition.get(TransitionKey.COMPLEMENTARY_DATA) or {})
new_complementary_data.update(rendered)
new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data
return new_transition
if _is_batched_language(persistent) or _is_batched_language(events):
return self._call_batch(transition, complementary_data, persistent, events)
@@ -191,6 +198,22 @@ def _fallback_low_level_render(task: Any) -> dict[str, Any] | None:
"""Keep action-only samples trainable when no recipe branch matches."""
if hasattr(task, "item"):
task = task.item()
if isinstance(task, list):
messages = []
message_streams = []
target_message_indices = []
for t in task:
rendered = _fallback_low_level_render(t)
if rendered is None:
return None
messages.append(rendered["messages"])
message_streams.append(rendered["message_streams"])
target_message_indices.append(rendered["target_message_indices"])
return {
"messages": messages,
"message_streams": message_streams,
"target_message_indices": target_message_indices,
}
if not isinstance(task, str) or not task:
return None
return {

View File

@@ -95,6 +95,67 @@ from lerobot.utils.utils import (
)
def _wrap_text_to_width(text: str, cv2, font, scale: int, thickness: int, max_width: int) -> list[str]:
"""Greedy word-wrap using measured pixel width so text fits the frame."""
words = text.split()
lines: list[str] = []
current = ""
for word in words:
candidate = f"{current} {word}".strip()
(w, _), _ = cv2.getTextSize(candidate, font, scale, thickness)
if w > max_width and current:
lines.append(current)
current = word
else:
current = candidate
if current:
lines.append(current)
return lines or [""]
def _annotate_eval_frames(
frames: np.ndarray, task: str | None, subtask: str | None
) -> np.ndarray:
"""Overlay the high-level task and predicted subtask onto rendered frames.
``frames`` is ``(n_envs, H, W, C)`` uint8. Best-effort: if OpenCV isn't
available the frames are returned unchanged so eval never fails over a
visualization concern.
"""
if frames.ndim != 4 or frames.shape[-1] != 3:
return frames
try:
import cv2 # noqa: PLC0415
except ImportError:
return frames
width = frames.shape[2]
font = cv2.FONT_HERSHEY_SIMPLEX
scale = 0.5
margin = 6
max_width = width - 2 * margin
lines: list[str] = []
if task:
lines += _wrap_text_to_width(f"Task: {task}", cv2, font, scale, 1, max_width)
if subtask:
lines += _wrap_text_to_width(f"Subtask: {subtask}", cv2, font, scale, 1, max_width)
if not lines:
return frames
out = frames.copy()
for i in range(out.shape[0]):
img = np.ascontiguousarray(out[i])
y = 18
for line in lines:
# Black outline then white fill so text stays legible on any scene.
cv2.putText(img, line, (margin, y), font, scale, (0, 0, 0), 3, cv2.LINE_AA)
cv2.putText(img, line, (margin, y), font, scale, (255, 255, 255), 1, cv2.LINE_AA)
y += 20
out[i] = img
return out
def rollout(
env: gym.vector.VectorEnv,
policy: PreTrainedPolicy,
@@ -325,11 +386,42 @@ def eval_policy(
return
n_to_render_now = min(max_episodes_rendered - n_episodes_rendered, env.num_envs)
if isinstance(env, gym.vector.SyncVectorEnv):
ep_frames.append(np.stack([env.envs[i].render() for i in range(n_to_render_now)])) # noqa: B023
frames = np.stack([env.envs[i].render() for i in range(n_to_render_now)]) # noqa: B023
elif hasattr(env, "call"):
# Here we must render all frames and discard any we don't need.
# Covers AsyncVectorEnv and _LazyAsyncVectorEnv (which wraps one).
ep_frames.append(np.stack(env.call("render")[:n_to_render_now]))
frames = np.stack(env.call("render")[:n_to_render_now])
else:
return
# Overlay the high-level task and (for hierarchical policies like
# pi052) the predicted low-level subtask onto each frame. Both are
# best-effort: missing values just skip that line.
try:
tasks = list(env.call("task_description"))
except (AttributeError, NotImplementedError):
try:
tasks = list(env.call("task"))
except (AttributeError, NotImplementedError):
tasks = None
# Per-env subtasks when available (batched hierarchical policies);
# fall back to the scalar last_subtask for single-env / other policies.
subtasks = getattr(policy, "last_subtasks", None)
subtask_scalar = getattr(policy, "last_subtask", None)
annotated = []
for i in range(frames.shape[0]):
if subtasks is not None and i < len(subtasks):
subtask_i = subtasks[i]
else:
subtask_i = subtask_scalar
annotated.append(
_annotate_eval_frames(
frames[i : i + 1],
tasks[i] if tasks is not None and i < len(tasks) else None,
subtask_i,
)[0]
)
ep_frames.append(np.stack(annotated))
if max_episodes_rendered > 0:
video_paths: list[str] = []