mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-04 04:41:24 +00:00
pi052: hierarchical select_action + RoboCasa eval video overlay
- modeling_pi052: per-env low-level subtask generation in select_action so hierarchical inference is correct for eval.batch_size > 1 - render_messages_processor: always emit a fallback low-level prompt so observation.language.tokens are produced when recipe annotations are absent - lerobot_eval: overlay high-level task + predicted subtask onto recorded rollout videos (render path only; does not affect policy observations) Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -46,7 +46,7 @@ import torch
|
||||
from torch import Tensor
|
||||
from torch.nn import functional as F
|
||||
|
||||
from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
|
||||
from lerobot.utils.constants import ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
|
||||
|
||||
from ..pi05.configuration_pi05 import PI05Config
|
||||
from ..pi05.modeling_pi05 import PI05Policy
|
||||
@@ -461,6 +461,32 @@ class PI052Policy(PI05Policy):
|
||||
"gradients are blocked in attention."
|
||||
)
|
||||
|
||||
# Per-env hierarchical-inference state. Sized lazily on the first
|
||||
# select_action() call once the batch size (number of parallel envs)
|
||||
# is known. ``last_subtasks[i]`` is the subtask currently conditioning
|
||||
# env ``i``'s action expert; scalar ``last_subtask`` mirrors env 0 for
|
||||
# back-compat (e.g. the eval video overlay).
|
||||
self.last_subtasks: list[str] | None = None
|
||||
self.last_subtasks_raw: list[str] | None = None
|
||||
self.last_subtasks_source: list[str] | None = None
|
||||
self._last_good_subtasks: list[str | None] | None = None
|
||||
self.last_subtask: str | None = None
|
||||
self.last_subtask_raw: str | None = None
|
||||
self.last_subtask_source: str = "unset"
|
||||
self.last_subtask_debug: str = ""
|
||||
|
||||
def reset(self):
|
||||
"""Reset action and high-level inference state."""
|
||||
super().reset()
|
||||
self.last_subtasks = None
|
||||
self.last_subtasks_raw = None
|
||||
self.last_subtasks_source = None
|
||||
self._last_good_subtasks = None
|
||||
self.last_subtask = None
|
||||
self.last_subtask_raw = None
|
||||
self.last_subtask_source = "unset"
|
||||
self.last_subtask_debug = ""
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Head unfreeze helper
|
||||
# ------------------------------------------------------------------
|
||||
@@ -1167,6 +1193,226 @@ class PI052Policy(PI05Policy):
|
||||
self._last_select_message_debug = ""
|
||||
return decoded
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Select an action via PI052's high-level → low-level inference path.
|
||||
|
||||
At action-chunk boundaries, first generate a low-level subtask from
|
||||
the high-level task prompt. Then retokenize that subtask as the
|
||||
low-level action prompt before sampling the action chunk. This keeps
|
||||
the public policy API identical to PI05 (`Tensor` action out), while
|
||||
matching the PI052 training/runtime conditioning more closely.
|
||||
"""
|
||||
assert not self._rtc_enabled(), (
|
||||
"RTC is not supported for select_action, use it with predict_action_chunk"
|
||||
)
|
||||
|
||||
self.eval()
|
||||
|
||||
if len(self._action_queue) == 0:
|
||||
action_batch = self._with_low_level_subtask_prompt(batch)
|
||||
actions = self.predict_action_chunk(action_batch)[:, : self.config.n_action_steps]
|
||||
self._action_queue.extend(actions.transpose(0, 1))
|
||||
|
||||
return self._action_queue.popleft()
|
||||
|
||||
def _with_low_level_subtask_prompt(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
from .inference.steps import _build_text_batch # noqa: PLC0415
|
||||
|
||||
n = self._batch_size_from_observation(batch)
|
||||
self._ensure_subtask_state(n)
|
||||
tasks = self._tasks_from_batch(batch, n)
|
||||
|
||||
# Generate one subtask per parallel env, each conditioned on that env's
|
||||
# own task + observation, then stack the per-env prompts into a single
|
||||
# (n, L) batch for the action expert. This keeps batch_size > 1 correct
|
||||
# (env i is conditioned on env i's subtask, not a broadcast of env 0).
|
||||
rows: list[tuple[Tensor, Tensor | None]] = []
|
||||
tokenizer = None
|
||||
for i in range(n):
|
||||
obs_i = self._slice_observation(batch, i)
|
||||
subtask = self._generate_low_level_subtask(obs_i, tasks[i], i)
|
||||
text_batch = _build_text_batch(
|
||||
self,
|
||||
[{"role": "user", "content": subtask}],
|
||||
add_generation_prompt=False,
|
||||
)
|
||||
rows.append((text_batch["lang_tokens"], text_batch["lang_masks"]))
|
||||
tokenizer = text_batch["tokenizer"]
|
||||
|
||||
tokens, masks = self._stack_token_rows(rows, tokenizer)
|
||||
|
||||
# Scalar aliases mirror env 0 for back-compat / single-env overlays.
|
||||
self.last_subtask = self.last_subtasks[0] if self.last_subtasks else None
|
||||
self.last_subtask_raw = self.last_subtasks_raw[0] if self.last_subtasks_raw else None
|
||||
self.last_subtask_source = (
|
||||
self.last_subtasks_source[0] if self.last_subtasks_source else "unset"
|
||||
)
|
||||
|
||||
out = dict(batch)
|
||||
out[OBS_LANGUAGE_TOKENS] = tokens
|
||||
out[OBS_LANGUAGE_ATTENTION_MASK] = masks
|
||||
return out
|
||||
|
||||
def _generate_low_level_subtask(self, obs_i: dict[str, Tensor], task: str, i: int) -> str:
|
||||
from .inference.steps import _generate_with_policy, _looks_like_gibberish # noqa: PLC0415
|
||||
|
||||
msg = ""
|
||||
if task:
|
||||
msg = _generate_with_policy(
|
||||
self,
|
||||
[{"role": "user", "content": task}],
|
||||
observation=obs_i,
|
||||
label=f"eval subtask gen[{i}]",
|
||||
suppress_loc_tokens=True,
|
||||
)
|
||||
self.last_subtasks_raw[i] = msg or ""
|
||||
|
||||
# Faithful hierarchical inference: condition the action expert on the
|
||||
# model's own generated subtask verbatim (this is exactly what the
|
||||
# ``low_level_execution`` recipe did at training — ``user: ${subtask}``).
|
||||
if msg and not _looks_like_gibberish(msg):
|
||||
subtask = " ".join(msg.strip().split())
|
||||
self._last_good_subtasks[i] = subtask
|
||||
self.last_subtasks[i] = subtask
|
||||
self.last_subtasks_source[i] = "generated"
|
||||
logger.info("PI052 eval subtask[%d]: %r (task=%r)", i, subtask, task)
|
||||
return subtask
|
||||
|
||||
# Generation unusable (empty / gibberish). Training never fed such a
|
||||
# prompt to the action expert, so the least-OOD choice is to reuse this
|
||||
# env's last accepted subtask; on the first chunk (none yet) derive one
|
||||
# from the task so the action expert still gets an imperative command
|
||||
# rather than the raw high-level instruction.
|
||||
debug = getattr(self, "_last_select_message_debug", "") or ""
|
||||
if not task:
|
||||
reason = "No task string was available in the batch."
|
||||
elif msg:
|
||||
reason = f"Rejected generated subtask: {msg!r}"
|
||||
else:
|
||||
reason = f"Empty generated subtask. {debug}".strip()
|
||||
if self._last_good_subtasks[i]:
|
||||
subtask = self._last_good_subtasks[i]
|
||||
source = "reuse_last"
|
||||
else:
|
||||
subtask = self._fallback_subtask_from_task(task)
|
||||
source = "fallback_task"
|
||||
self.last_subtasks[i] = subtask
|
||||
self.last_subtasks_source[i] = source
|
||||
logger.info(
|
||||
"PI052 eval subtask[%d] fallback (%s): %s | final=%r task=%r",
|
||||
i,
|
||||
source,
|
||||
reason,
|
||||
subtask,
|
||||
task,
|
||||
)
|
||||
return subtask
|
||||
|
||||
def _ensure_subtask_state(self, n: int) -> None:
|
||||
"""(Re)allocate per-env subtask buffers when the env count is first seen."""
|
||||
if self.last_subtasks is not None and len(self.last_subtasks) == n:
|
||||
return
|
||||
self.last_subtasks = ["" for _ in range(n)]
|
||||
self.last_subtasks_raw = ["" for _ in range(n)]
|
||||
self.last_subtasks_source = ["unset" for _ in range(n)]
|
||||
self._last_good_subtasks = [None for _ in range(n)]
|
||||
|
||||
@staticmethod
|
||||
def _slice_observation(batch: dict[str, Tensor], i: int) -> dict[str, Tensor]:
|
||||
"""Slice the per-env observation tensors for env ``i`` (images/state).
|
||||
|
||||
Language keys are excluded so high-level generation uses the freshly
|
||||
tokenized task prompt, not the preprocessor's low-level fallback tokens.
|
||||
"""
|
||||
out: dict[str, Tensor] = {}
|
||||
for k, v in batch.items():
|
||||
if not (isinstance(k, str) and k.startswith("observation.")):
|
||||
continue
|
||||
if k.startswith("observation.language"):
|
||||
continue
|
||||
if torch.is_tensor(v):
|
||||
out[k] = v[i : i + 1]
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def _stack_token_rows(
|
||||
rows: list[tuple[Tensor, Tensor | None]], tokenizer: Any
|
||||
) -> tuple[Tensor, Tensor]:
|
||||
"""Right-pad per-env ``(1, L_i)`` token/mask rows and stack to ``(n, L)``.
|
||||
|
||||
Right-padding with a False attention mask matches the training-time
|
||||
tokenizer (``padding_side="right"``), so the action expert treats pad
|
||||
positions as masked.
|
||||
"""
|
||||
max_len = max(t.shape[1] for t, _ in rows)
|
||||
pad_id = getattr(tokenizer, "pad_token_id", None) or 0
|
||||
tok_rows: list[Tensor] = []
|
||||
mask_rows: list[Tensor] = []
|
||||
for tokens, masks in rows:
|
||||
length = tokens.shape[1]
|
||||
if masks is None:
|
||||
masks = torch.ones((1, length), dtype=torch.bool, device=tokens.device)
|
||||
if length < max_len:
|
||||
pad = max_len - length
|
||||
tokens = torch.cat(
|
||||
[tokens, torch.full((1, pad), pad_id, dtype=tokens.dtype, device=tokens.device)],
|
||||
dim=1,
|
||||
)
|
||||
masks = torch.cat(
|
||||
[masks, torch.zeros((1, pad), dtype=masks.dtype, device=masks.device)],
|
||||
dim=1,
|
||||
)
|
||||
tok_rows.append(tokens)
|
||||
mask_rows.append(masks)
|
||||
return torch.cat(tok_rows, dim=0), torch.cat(mask_rows, dim=0)
|
||||
|
||||
@staticmethod
|
||||
def _fallback_subtask_from_task(task: str) -> str:
|
||||
target = PI052Policy._navigation_target_from_task(task)
|
||||
if target:
|
||||
return f"go to {target}"
|
||||
if task.lower().startswith("open the stand mixer head"):
|
||||
return "pull stand mixer head"
|
||||
return task
|
||||
|
||||
@staticmethod
|
||||
def _navigation_target_from_task(task: str) -> str:
|
||||
prefix = "navigate to "
|
||||
lower = task.lower().strip()
|
||||
if not lower.startswith(prefix):
|
||||
return ""
|
||||
return lower[len(prefix) :].strip().rstrip(".")
|
||||
|
||||
@staticmethod
|
||||
def _tasks_from_batch(batch: dict[str, Any], n: int) -> list[str]:
|
||||
"""Return one task string per env, padded/truncated to ``n``."""
|
||||
task = batch.get("task")
|
||||
if isinstance(task, list):
|
||||
raw = list(task)
|
||||
elif task is None:
|
||||
raw = []
|
||||
else:
|
||||
raw = [task]
|
||||
tasks: list[str] = []
|
||||
for t in raw:
|
||||
if hasattr(t, "item"):
|
||||
t = t.item()
|
||||
tasks.append(t if isinstance(t, str) else "")
|
||||
if len(tasks) < n:
|
||||
tasks += [tasks[-1] if tasks else ""] * (n - len(tasks))
|
||||
return tasks[:n]
|
||||
|
||||
@staticmethod
|
||||
def _batch_size_from_observation(batch: dict[str, Any]) -> int:
|
||||
state = batch.get("observation.state")
|
||||
if torch.is_tensor(state) and state.ndim > 0:
|
||||
return int(state.shape[0])
|
||||
for key, value in batch.items():
|
||||
if isinstance(key, str) and key.startswith("observation.images.") and torch.is_tensor(value):
|
||||
return int(value.shape[0])
|
||||
return 1
|
||||
|
||||
@staticmethod
|
||||
def _sample_next_token(logits: Tensor, temperature: float, top_p: float) -> Tensor:
|
||||
if temperature <= 0.0:
|
||||
|
||||
@@ -50,7 +50,14 @@ class RenderMessagesStep(ProcessorStep):
|
||||
events = complementary_data.get(LANGUAGE_EVENTS) or []
|
||||
|
||||
if not persistent and not events:
|
||||
return transition
|
||||
rendered = _fallback_low_level_render(complementary_data.get("task"))
|
||||
if rendered is None:
|
||||
return transition
|
||||
new_transition = transition.copy()
|
||||
new_complementary_data = dict(new_transition.get(TransitionKey.COMPLEMENTARY_DATA) or {})
|
||||
new_complementary_data.update(rendered)
|
||||
new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data
|
||||
return new_transition
|
||||
|
||||
if _is_batched_language(persistent) or _is_batched_language(events):
|
||||
return self._call_batch(transition, complementary_data, persistent, events)
|
||||
@@ -191,6 +198,22 @@ def _fallback_low_level_render(task: Any) -> dict[str, Any] | None:
|
||||
"""Keep action-only samples trainable when no recipe branch matches."""
|
||||
if hasattr(task, "item"):
|
||||
task = task.item()
|
||||
if isinstance(task, list):
|
||||
messages = []
|
||||
message_streams = []
|
||||
target_message_indices = []
|
||||
for t in task:
|
||||
rendered = _fallback_low_level_render(t)
|
||||
if rendered is None:
|
||||
return None
|
||||
messages.append(rendered["messages"])
|
||||
message_streams.append(rendered["message_streams"])
|
||||
target_message_indices.append(rendered["target_message_indices"])
|
||||
return {
|
||||
"messages": messages,
|
||||
"message_streams": message_streams,
|
||||
"target_message_indices": target_message_indices,
|
||||
}
|
||||
if not isinstance(task, str) or not task:
|
||||
return None
|
||||
return {
|
||||
|
||||
@@ -95,6 +95,67 @@ from lerobot.utils.utils import (
|
||||
)
|
||||
|
||||
|
||||
def _wrap_text_to_width(text: str, cv2, font, scale: int, thickness: int, max_width: int) -> list[str]:
|
||||
"""Greedy word-wrap using measured pixel width so text fits the frame."""
|
||||
words = text.split()
|
||||
lines: list[str] = []
|
||||
current = ""
|
||||
for word in words:
|
||||
candidate = f"{current} {word}".strip()
|
||||
(w, _), _ = cv2.getTextSize(candidate, font, scale, thickness)
|
||||
if w > max_width and current:
|
||||
lines.append(current)
|
||||
current = word
|
||||
else:
|
||||
current = candidate
|
||||
if current:
|
||||
lines.append(current)
|
||||
return lines or [""]
|
||||
|
||||
|
||||
def _annotate_eval_frames(
|
||||
frames: np.ndarray, task: str | None, subtask: str | None
|
||||
) -> np.ndarray:
|
||||
"""Overlay the high-level task and predicted subtask onto rendered frames.
|
||||
|
||||
``frames`` is ``(n_envs, H, W, C)`` uint8. Best-effort: if OpenCV isn't
|
||||
available the frames are returned unchanged so eval never fails over a
|
||||
visualization concern.
|
||||
"""
|
||||
if frames.ndim != 4 or frames.shape[-1] != 3:
|
||||
return frames
|
||||
try:
|
||||
import cv2 # noqa: PLC0415
|
||||
except ImportError:
|
||||
return frames
|
||||
|
||||
width = frames.shape[2]
|
||||
font = cv2.FONT_HERSHEY_SIMPLEX
|
||||
scale = 0.5
|
||||
margin = 6
|
||||
max_width = width - 2 * margin
|
||||
|
||||
lines: list[str] = []
|
||||
if task:
|
||||
lines += _wrap_text_to_width(f"Task: {task}", cv2, font, scale, 1, max_width)
|
||||
if subtask:
|
||||
lines += _wrap_text_to_width(f"Subtask: {subtask}", cv2, font, scale, 1, max_width)
|
||||
if not lines:
|
||||
return frames
|
||||
|
||||
out = frames.copy()
|
||||
for i in range(out.shape[0]):
|
||||
img = np.ascontiguousarray(out[i])
|
||||
y = 18
|
||||
for line in lines:
|
||||
# Black outline then white fill so text stays legible on any scene.
|
||||
cv2.putText(img, line, (margin, y), font, scale, (0, 0, 0), 3, cv2.LINE_AA)
|
||||
cv2.putText(img, line, (margin, y), font, scale, (255, 255, 255), 1, cv2.LINE_AA)
|
||||
y += 20
|
||||
out[i] = img
|
||||
return out
|
||||
|
||||
|
||||
def rollout(
|
||||
env: gym.vector.VectorEnv,
|
||||
policy: PreTrainedPolicy,
|
||||
@@ -325,11 +386,42 @@ def eval_policy(
|
||||
return
|
||||
n_to_render_now = min(max_episodes_rendered - n_episodes_rendered, env.num_envs)
|
||||
if isinstance(env, gym.vector.SyncVectorEnv):
|
||||
ep_frames.append(np.stack([env.envs[i].render() for i in range(n_to_render_now)])) # noqa: B023
|
||||
frames = np.stack([env.envs[i].render() for i in range(n_to_render_now)]) # noqa: B023
|
||||
elif hasattr(env, "call"):
|
||||
# Here we must render all frames and discard any we don't need.
|
||||
# Covers AsyncVectorEnv and _LazyAsyncVectorEnv (which wraps one).
|
||||
ep_frames.append(np.stack(env.call("render")[:n_to_render_now]))
|
||||
frames = np.stack(env.call("render")[:n_to_render_now])
|
||||
else:
|
||||
return
|
||||
|
||||
# Overlay the high-level task and (for hierarchical policies like
|
||||
# pi052) the predicted low-level subtask onto each frame. Both are
|
||||
# best-effort: missing values just skip that line.
|
||||
try:
|
||||
tasks = list(env.call("task_description"))
|
||||
except (AttributeError, NotImplementedError):
|
||||
try:
|
||||
tasks = list(env.call("task"))
|
||||
except (AttributeError, NotImplementedError):
|
||||
tasks = None
|
||||
# Per-env subtasks when available (batched hierarchical policies);
|
||||
# fall back to the scalar last_subtask for single-env / other policies.
|
||||
subtasks = getattr(policy, "last_subtasks", None)
|
||||
subtask_scalar = getattr(policy, "last_subtask", None)
|
||||
annotated = []
|
||||
for i in range(frames.shape[0]):
|
||||
if subtasks is not None and i < len(subtasks):
|
||||
subtask_i = subtasks[i]
|
||||
else:
|
||||
subtask_i = subtask_scalar
|
||||
annotated.append(
|
||||
_annotate_eval_frames(
|
||||
frames[i : i + 1],
|
||||
tasks[i] if tasks is not None and i < len(tasks) else None,
|
||||
subtask_i,
|
||||
)[0]
|
||||
)
|
||||
ep_frames.append(np.stack(annotated))
|
||||
|
||||
if max_episodes_rendered > 0:
|
||||
video_paths: list[str] = []
|
||||
|
||||
Reference in New Issue
Block a user