diff --git a/src/lerobot/policies/pi052/modeling_pi052.py b/src/lerobot/policies/pi052/modeling_pi052.py index cb491b7ed..ce8c3abc6 100644 --- a/src/lerobot/policies/pi052/modeling_pi052.py +++ b/src/lerobot/policies/pi052/modeling_pi052.py @@ -46,7 +46,7 @@ import torch from torch import Tensor from torch.nn import functional as F -from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS +from lerobot.utils.constants import ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS from ..pi05.configuration_pi05 import PI05Config from ..pi05.modeling_pi05 import PI05Policy @@ -461,6 +461,32 @@ class PI052Policy(PI05Policy): "gradients are blocked in attention." ) + # Per-env hierarchical-inference state. Sized lazily on the first + # select_action() call once the batch size (number of parallel envs) + # is known. ``last_subtasks[i]`` is the subtask currently conditioning + # env ``i``'s action expert; scalar ``last_subtask`` mirrors env 0 for + # back-compat (e.g. the eval video overlay). + self.last_subtasks: list[str] | None = None + self.last_subtasks_raw: list[str] | None = None + self.last_subtasks_source: list[str] | None = None + self._last_good_subtasks: list[str | None] | None = None + self.last_subtask: str | None = None + self.last_subtask_raw: str | None = None + self.last_subtask_source: str = "unset" + self.last_subtask_debug: str = "" + + def reset(self): + """Reset action and high-level inference state.""" + super().reset() + self.last_subtasks = None + self.last_subtasks_raw = None + self.last_subtasks_source = None + self._last_good_subtasks = None + self.last_subtask = None + self.last_subtask_raw = None + self.last_subtask_source = "unset" + self.last_subtask_debug = "" + # ------------------------------------------------------------------ # Head unfreeze helper # ------------------------------------------------------------------ @@ -1167,6 +1193,226 @@ class PI052Policy(PI05Policy): self._last_select_message_debug = "" return decoded + @torch.no_grad() + def select_action(self, batch: dict[str, Tensor]) -> Tensor: + """Select an action via PI052's high-level → low-level inference path. + + At action-chunk boundaries, first generate a low-level subtask from + the high-level task prompt. Then retokenize that subtask as the + low-level action prompt before sampling the action chunk. This keeps + the public policy API identical to PI05 (`Tensor` action out), while + matching the PI052 training/runtime conditioning more closely. + """ + assert not self._rtc_enabled(), ( + "RTC is not supported for select_action, use it with predict_action_chunk" + ) + + self.eval() + + if len(self._action_queue) == 0: + action_batch = self._with_low_level_subtask_prompt(batch) + actions = self.predict_action_chunk(action_batch)[:, : self.config.n_action_steps] + self._action_queue.extend(actions.transpose(0, 1)) + + return self._action_queue.popleft() + + def _with_low_level_subtask_prompt(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: + from .inference.steps import _build_text_batch # noqa: PLC0415 + + n = self._batch_size_from_observation(batch) + self._ensure_subtask_state(n) + tasks = self._tasks_from_batch(batch, n) + + # Generate one subtask per parallel env, each conditioned on that env's + # own task + observation, then stack the per-env prompts into a single + # (n, L) batch for the action expert. This keeps batch_size > 1 correct + # (env i is conditioned on env i's subtask, not a broadcast of env 0). + rows: list[tuple[Tensor, Tensor | None]] = [] + tokenizer = None + for i in range(n): + obs_i = self._slice_observation(batch, i) + subtask = self._generate_low_level_subtask(obs_i, tasks[i], i) + text_batch = _build_text_batch( + self, + [{"role": "user", "content": subtask}], + add_generation_prompt=False, + ) + rows.append((text_batch["lang_tokens"], text_batch["lang_masks"])) + tokenizer = text_batch["tokenizer"] + + tokens, masks = self._stack_token_rows(rows, tokenizer) + + # Scalar aliases mirror env 0 for back-compat / single-env overlays. + self.last_subtask = self.last_subtasks[0] if self.last_subtasks else None + self.last_subtask_raw = self.last_subtasks_raw[0] if self.last_subtasks_raw else None + self.last_subtask_source = ( + self.last_subtasks_source[0] if self.last_subtasks_source else "unset" + ) + + out = dict(batch) + out[OBS_LANGUAGE_TOKENS] = tokens + out[OBS_LANGUAGE_ATTENTION_MASK] = masks + return out + + def _generate_low_level_subtask(self, obs_i: dict[str, Tensor], task: str, i: int) -> str: + from .inference.steps import _generate_with_policy, _looks_like_gibberish # noqa: PLC0415 + + msg = "" + if task: + msg = _generate_with_policy( + self, + [{"role": "user", "content": task}], + observation=obs_i, + label=f"eval subtask gen[{i}]", + suppress_loc_tokens=True, + ) + self.last_subtasks_raw[i] = msg or "" + + # Faithful hierarchical inference: condition the action expert on the + # model's own generated subtask verbatim (this is exactly what the + # ``low_level_execution`` recipe did at training — ``user: ${subtask}``). + if msg and not _looks_like_gibberish(msg): + subtask = " ".join(msg.strip().split()) + self._last_good_subtasks[i] = subtask + self.last_subtasks[i] = subtask + self.last_subtasks_source[i] = "generated" + logger.info("PI052 eval subtask[%d]: %r (task=%r)", i, subtask, task) + return subtask + + # Generation unusable (empty / gibberish). Training never fed such a + # prompt to the action expert, so the least-OOD choice is to reuse this + # env's last accepted subtask; on the first chunk (none yet) derive one + # from the task so the action expert still gets an imperative command + # rather than the raw high-level instruction. + debug = getattr(self, "_last_select_message_debug", "") or "" + if not task: + reason = "No task string was available in the batch." + elif msg: + reason = f"Rejected generated subtask: {msg!r}" + else: + reason = f"Empty generated subtask. {debug}".strip() + if self._last_good_subtasks[i]: + subtask = self._last_good_subtasks[i] + source = "reuse_last" + else: + subtask = self._fallback_subtask_from_task(task) + source = "fallback_task" + self.last_subtasks[i] = subtask + self.last_subtasks_source[i] = source + logger.info( + "PI052 eval subtask[%d] fallback (%s): %s | final=%r task=%r", + i, + source, + reason, + subtask, + task, + ) + return subtask + + def _ensure_subtask_state(self, n: int) -> None: + """(Re)allocate per-env subtask buffers when the env count is first seen.""" + if self.last_subtasks is not None and len(self.last_subtasks) == n: + return + self.last_subtasks = ["" for _ in range(n)] + self.last_subtasks_raw = ["" for _ in range(n)] + self.last_subtasks_source = ["unset" for _ in range(n)] + self._last_good_subtasks = [None for _ in range(n)] + + @staticmethod + def _slice_observation(batch: dict[str, Tensor], i: int) -> dict[str, Tensor]: + """Slice the per-env observation tensors for env ``i`` (images/state). + + Language keys are excluded so high-level generation uses the freshly + tokenized task prompt, not the preprocessor's low-level fallback tokens. + """ + out: dict[str, Tensor] = {} + for k, v in batch.items(): + if not (isinstance(k, str) and k.startswith("observation.")): + continue + if k.startswith("observation.language"): + continue + if torch.is_tensor(v): + out[k] = v[i : i + 1] + return out + + @staticmethod + def _stack_token_rows( + rows: list[tuple[Tensor, Tensor | None]], tokenizer: Any + ) -> tuple[Tensor, Tensor]: + """Right-pad per-env ``(1, L_i)`` token/mask rows and stack to ``(n, L)``. + + Right-padding with a False attention mask matches the training-time + tokenizer (``padding_side="right"``), so the action expert treats pad + positions as masked. + """ + max_len = max(t.shape[1] for t, _ in rows) + pad_id = getattr(tokenizer, "pad_token_id", None) or 0 + tok_rows: list[Tensor] = [] + mask_rows: list[Tensor] = [] + for tokens, masks in rows: + length = tokens.shape[1] + if masks is None: + masks = torch.ones((1, length), dtype=torch.bool, device=tokens.device) + if length < max_len: + pad = max_len - length + tokens = torch.cat( + [tokens, torch.full((1, pad), pad_id, dtype=tokens.dtype, device=tokens.device)], + dim=1, + ) + masks = torch.cat( + [masks, torch.zeros((1, pad), dtype=masks.dtype, device=masks.device)], + dim=1, + ) + tok_rows.append(tokens) + mask_rows.append(masks) + return torch.cat(tok_rows, dim=0), torch.cat(mask_rows, dim=0) + + @staticmethod + def _fallback_subtask_from_task(task: str) -> str: + target = PI052Policy._navigation_target_from_task(task) + if target: + return f"go to {target}" + if task.lower().startswith("open the stand mixer head"): + return "pull stand mixer head" + return task + + @staticmethod + def _navigation_target_from_task(task: str) -> str: + prefix = "navigate to " + lower = task.lower().strip() + if not lower.startswith(prefix): + return "" + return lower[len(prefix) :].strip().rstrip(".") + + @staticmethod + def _tasks_from_batch(batch: dict[str, Any], n: int) -> list[str]: + """Return one task string per env, padded/truncated to ``n``.""" + task = batch.get("task") + if isinstance(task, list): + raw = list(task) + elif task is None: + raw = [] + else: + raw = [task] + tasks: list[str] = [] + for t in raw: + if hasattr(t, "item"): + t = t.item() + tasks.append(t if isinstance(t, str) else "") + if len(tasks) < n: + tasks += [tasks[-1] if tasks else ""] * (n - len(tasks)) + return tasks[:n] + + @staticmethod + def _batch_size_from_observation(batch: dict[str, Any]) -> int: + state = batch.get("observation.state") + if torch.is_tensor(state) and state.ndim > 0: + return int(state.shape[0]) + for key, value in batch.items(): + if isinstance(key, str) and key.startswith("observation.images.") and torch.is_tensor(value): + return int(value.shape[0]) + return 1 + @staticmethod def _sample_next_token(logits: Tensor, temperature: float, top_p: float) -> Tensor: if temperature <= 0.0: diff --git a/src/lerobot/processor/render_messages_processor.py b/src/lerobot/processor/render_messages_processor.py index e3dc6361a..0b5e4923f 100644 --- a/src/lerobot/processor/render_messages_processor.py +++ b/src/lerobot/processor/render_messages_processor.py @@ -50,7 +50,14 @@ class RenderMessagesStep(ProcessorStep): events = complementary_data.get(LANGUAGE_EVENTS) or [] if not persistent and not events: - return transition + rendered = _fallback_low_level_render(complementary_data.get("task")) + if rendered is None: + return transition + new_transition = transition.copy() + new_complementary_data = dict(new_transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}) + new_complementary_data.update(rendered) + new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data + return new_transition if _is_batched_language(persistent) or _is_batched_language(events): return self._call_batch(transition, complementary_data, persistent, events) @@ -191,6 +198,22 @@ def _fallback_low_level_render(task: Any) -> dict[str, Any] | None: """Keep action-only samples trainable when no recipe branch matches.""" if hasattr(task, "item"): task = task.item() + if isinstance(task, list): + messages = [] + message_streams = [] + target_message_indices = [] + for t in task: + rendered = _fallback_low_level_render(t) + if rendered is None: + return None + messages.append(rendered["messages"]) + message_streams.append(rendered["message_streams"]) + target_message_indices.append(rendered["target_message_indices"]) + return { + "messages": messages, + "message_streams": message_streams, + "target_message_indices": target_message_indices, + } if not isinstance(task, str) or not task: return None return { diff --git a/src/lerobot/scripts/lerobot_eval.py b/src/lerobot/scripts/lerobot_eval.py index d45483d21..fc8885a91 100644 --- a/src/lerobot/scripts/lerobot_eval.py +++ b/src/lerobot/scripts/lerobot_eval.py @@ -95,6 +95,67 @@ from lerobot.utils.utils import ( ) +def _wrap_text_to_width(text: str, cv2, font, scale: int, thickness: int, max_width: int) -> list[str]: + """Greedy word-wrap using measured pixel width so text fits the frame.""" + words = text.split() + lines: list[str] = [] + current = "" + for word in words: + candidate = f"{current} {word}".strip() + (w, _), _ = cv2.getTextSize(candidate, font, scale, thickness) + if w > max_width and current: + lines.append(current) + current = word + else: + current = candidate + if current: + lines.append(current) + return lines or [""] + + +def _annotate_eval_frames( + frames: np.ndarray, task: str | None, subtask: str | None +) -> np.ndarray: + """Overlay the high-level task and predicted subtask onto rendered frames. + + ``frames`` is ``(n_envs, H, W, C)`` uint8. Best-effort: if OpenCV isn't + available the frames are returned unchanged so eval never fails over a + visualization concern. + """ + if frames.ndim != 4 or frames.shape[-1] != 3: + return frames + try: + import cv2 # noqa: PLC0415 + except ImportError: + return frames + + width = frames.shape[2] + font = cv2.FONT_HERSHEY_SIMPLEX + scale = 0.5 + margin = 6 + max_width = width - 2 * margin + + lines: list[str] = [] + if task: + lines += _wrap_text_to_width(f"Task: {task}", cv2, font, scale, 1, max_width) + if subtask: + lines += _wrap_text_to_width(f"Subtask: {subtask}", cv2, font, scale, 1, max_width) + if not lines: + return frames + + out = frames.copy() + for i in range(out.shape[0]): + img = np.ascontiguousarray(out[i]) + y = 18 + for line in lines: + # Black outline then white fill so text stays legible on any scene. + cv2.putText(img, line, (margin, y), font, scale, (0, 0, 0), 3, cv2.LINE_AA) + cv2.putText(img, line, (margin, y), font, scale, (255, 255, 255), 1, cv2.LINE_AA) + y += 20 + out[i] = img + return out + + def rollout( env: gym.vector.VectorEnv, policy: PreTrainedPolicy, @@ -325,11 +386,42 @@ def eval_policy( return n_to_render_now = min(max_episodes_rendered - n_episodes_rendered, env.num_envs) if isinstance(env, gym.vector.SyncVectorEnv): - ep_frames.append(np.stack([env.envs[i].render() for i in range(n_to_render_now)])) # noqa: B023 + frames = np.stack([env.envs[i].render() for i in range(n_to_render_now)]) # noqa: B023 elif hasattr(env, "call"): # Here we must render all frames and discard any we don't need. # Covers AsyncVectorEnv and _LazyAsyncVectorEnv (which wraps one). - ep_frames.append(np.stack(env.call("render")[:n_to_render_now])) + frames = np.stack(env.call("render")[:n_to_render_now]) + else: + return + + # Overlay the high-level task and (for hierarchical policies like + # pi052) the predicted low-level subtask onto each frame. Both are + # best-effort: missing values just skip that line. + try: + tasks = list(env.call("task_description")) + except (AttributeError, NotImplementedError): + try: + tasks = list(env.call("task")) + except (AttributeError, NotImplementedError): + tasks = None + # Per-env subtasks when available (batched hierarchical policies); + # fall back to the scalar last_subtask for single-env / other policies. + subtasks = getattr(policy, "last_subtasks", None) + subtask_scalar = getattr(policy, "last_subtask", None) + annotated = [] + for i in range(frames.shape[0]): + if subtasks is not None and i < len(subtasks): + subtask_i = subtasks[i] + else: + subtask_i = subtask_scalar + annotated.append( + _annotate_eval_frames( + frames[i : i + 1], + tasks[i] if tasks is not None and i < len(tasks) else None, + subtask_i, + )[0] + ) + ep_frames.append(np.stack(annotated)) if max_episodes_rendered > 0: video_paths: list[str] = []