mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-02 20:01:25 +00:00
annotate: stitch subtasks to full-episode coverage
The verify pass prunes subtasks, which could leave the first subtask
starting after t0 or leave gaps between spans — so the subtask timeline
no longer tiled the episode and frames fell through with no active
subtask label.
New deterministic post-step (no VLM call), default on via
PlanConfig.subtask_full_coverage:
* first subtask start pulled back to the episode's first frame t0
(idle / approach before the first labelled action folds into it)
* each subtask end snapped to the next subtask start (gaps closed)
* last subtask end extended to the last frame t_last
Runs after segment + verify in _generate_subtasks. Starts other than
the first are left as the VLM/verify produced them (already frame-
snapped + distinct), so the cover is contiguous and non-overlapping.
Disable with --plan.subtask_full_coverage=false if a consumer wants
sparse subtasks.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -73,6 +73,16 @@ class PlanConfig:
|
||||
# nothing). +1 VLM call/ep.
|
||||
subtask_verify: bool = True
|
||||
|
||||
# ``subtask_full_coverage``: deterministic post-step (no VLM call)
|
||||
# that stitches the surviving subtask spans into a contiguous cover
|
||||
# of the whole episode — first subtask pulled back to t0, each span's
|
||||
# end snapped to the next span's start, last span extended to t_last.
|
||||
# Without it the verify pass (which prunes spans) can leave the
|
||||
# subtask timeline starting late or full of gaps, so frames fall
|
||||
# through with no active subtask. On by default; disable only if a
|
||||
# downstream consumer genuinely wants sparse (non-tiling) subtasks.
|
||||
subtask_full_coverage: bool = True
|
||||
|
||||
# When True (and backend supports it, e.g. ``openai``), the ``plan``
|
||||
# module sends a ``video_url`` block pointing at a per-episode mp4
|
||||
# subclip and lets the server sample frames at ``use_video_url_fps``.
|
||||
|
||||
@@ -571,9 +571,52 @@ class PlanSubtasksMemoryModule:
|
||||
# ---- Pass 3 (optional): verification / pruning ---------------
|
||||
if getattr(self.config, "subtask_verify", False):
|
||||
cleaned = self._verify_subtasks(record, effective_task, cleaned)
|
||||
if not cleaned:
|
||||
return []
|
||||
|
||||
# ---- Full-episode coverage stitch ----------------------------
|
||||
# The VLM (especially after the verify pass prunes spans) can
|
||||
# leave the first subtask starting after t0 or leave gaps between
|
||||
# spans, so the subtask timeline no longer tiles the whole
|
||||
# episode and frames fall through with no active subtask. Stitch
|
||||
# the surviving spans into a contiguous cover of [t0, t_last].
|
||||
if getattr(self.config, "subtask_full_coverage", True):
|
||||
cleaned = self._stitch_full_coverage(cleaned, record)
|
||||
|
||||
return cleaned
|
||||
|
||||
def _stitch_full_coverage(
|
||||
self, spans: list[dict[str, Any]], record: EpisodeRecord
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Make subtask spans tile the full episode with no gaps.
|
||||
|
||||
* The first subtask starts at the episode's first frame ``t0``
|
||||
(any idle / approach before the first labelled action is folded
|
||||
into it), so every early frame has an active subtask.
|
||||
* Each subtask's ``end`` is snapped to the next subtask's
|
||||
``start`` (gaps between spans are closed), and the final
|
||||
subtask's ``end`` extends to the last frame ``t_last``.
|
||||
|
||||
Starts are otherwise left as the (already frame-snapped, distinct)
|
||||
values the VLM + verify produced — only the FIRST start is pulled
|
||||
back to ``t0``, which can't collide with a later span because it
|
||||
was already the earliest. Purely deterministic; runs after the
|
||||
VLM passes.
|
||||
"""
|
||||
if not spans or not record.frame_timestamps:
|
||||
return spans
|
||||
t0 = float(record.frame_timestamps[0])
|
||||
t_last = float(record.frame_timestamps[-1])
|
||||
spans = sorted(spans, key=lambda s: float(s["start"]))
|
||||
spans[0]["start"] = t0
|
||||
for i in range(len(spans) - 1):
|
||||
spans[i]["end"] = float(spans[i + 1]["start"])
|
||||
spans[-1]["end"] = t_last
|
||||
for s in spans:
|
||||
if float(s["end"]) < float(s["start"]):
|
||||
s["end"] = float(s["start"])
|
||||
return spans
|
||||
|
||||
def _clean_spans(
|
||||
self, spans: Any, record: EpisodeRecord
|
||||
) -> list[dict[str, Any]]:
|
||||
|
||||
Reference in New Issue
Block a user