annotate: ruff lint + format pass

Quality-gate fixes after the main merge:
  * UP037: drop redundant quotes from PlanConfig forward-ref annotations
    (action_records / task_aug_axes) — safe under 'from __future__ import
    annotations'.
  * ruff format applied to config.py, executor.py, general_vqa.py,
    plan_subtasks_memory.py, validator.py, lerobot_annotate.py.

No behavior change.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-06-02 17:38:18 +02:00
parent 3662c41b85
commit 53c7b4c69a
6 changed files with 54 additions and 50 deletions

View File

@@ -116,7 +116,7 @@ class PlanConfig:
# that record back to canonical subtask text — reducing the VLM's
# "creative" surface to just the perception step. See
# ``ActionRecordsConfig`` for details. Off by default (back-compat).
action_records: "ActionRecordsConfig" = field(default_factory=lambda: ActionRecordsConfig())
action_records: ActionRecordsConfig = field(default_factory=lambda: ActionRecordsConfig())
# Structured 5-axis augmentation taxonomy for the t=0 task variants
# (replaces the free-form ``n_task_rephrasings`` flow when enabled).
@@ -124,7 +124,7 @@ class PlanConfig:
# free-form rephrasings, the VLM produces variants along named
# axes (synonym / omit_arm / omit_orientation / omit_grasp_method /
# combined). Off by default (back-compat).
task_aug_axes: "TaskAugAxesConfig" = field(default_factory=lambda: TaskAugAxesConfig())
task_aug_axes: TaskAugAxesConfig = field(default_factory=lambda: TaskAugAxesConfig())
@dataclass
@@ -136,12 +136,12 @@ class ActionRecordsConfig:
subtask to extract a typed record::
{
"verb": "pick" | "place" | "press" | ..., # closed vocabulary
"object": "<canonical_object_name>",
"arm": "left" | "right" | "both" | null,
"grasp_type": "pinch" | "wrap" | "hook" | ... | null,
"destination": "<canonical_destination>" | null,
"mistake": "<short text>" | null,
"verb": "pick" | "place" | "press" | ..., # closed vocabulary
"object": "<canonical_object_name>",
"arm": "left" | "right" | "both" | null,
"grasp_type": "pinch" | "wrap" | "hook" | ... | null,
"destination": "<canonical_destination>" | null,
"mistake": "<short text>" | null,
}
The record is emitted as a separate row with ``style="action_record"``
@@ -176,16 +176,34 @@ class ActionRecordsConfig:
# exactly one. Override per-dataset (e.g. ``["pick", "place", "open",
# "close"]`` for door-only manipulation) for tighter constraint.
verb_vocabulary: tuple[str, ...] = (
"pick", "place", "push", "pull", "open", "close", "turn",
"press", "lift", "insert", "pour", "move", "reach", "grasp",
"release", "wipe", "dump",
"pick",
"place",
"push",
"pull",
"open",
"close",
"turn",
"press",
"lift",
"insert",
"pour",
"move",
"reach",
"grasp",
"release",
"wipe",
"dump",
)
# Closed grasp-type vocabulary. ``null`` is always allowed (no
# contact / unclear). Adjust per-hardware (e.g. drop ``hook`` /
# ``key`` for parallel-jaw grippers).
grasp_vocabulary: tuple[str, ...] = (
"pinch", "wrap", "hook", "key", "lateral",
"pinch",
"wrap",
"hook",
"key",
"lateral",
)

View File

@@ -238,9 +238,7 @@ class Executor:
prompt path is reused.
"""
if not self.plan.enabled or not self.interjections.enabled:
return PhaseResult(
name="plan_update", episodes_processed=0, episodes_skipped=len(records)
)
return PhaseResult(name="plan_update", episodes_processed=0, episodes_skipped=len(records))
processed = 0
for record in records:
staging = EpisodeStaging(staging_dir, record.episode_index)

View File

@@ -206,9 +206,7 @@ class GeneralVqaModule:
episode_task=record.episode_task,
question_type=question_type,
)
images = self.frame_provider.frames_at(
record, [frame_timestamp], camera_key=camera_key
)
images = self.frame_provider.frames_at(record, [frame_timestamp], camera_key=camera_key)
content = [*to_image_blocks(images), {"type": "text", "text": prompt}]
return [{"role": "user", "content": content}]

View File

@@ -172,9 +172,7 @@ class PlanSubtasksMemoryModule:
# "what's still left" at inference time.
for span in subtask_spans:
boundary_t = snap_to_frame(span["start"], record.frame_timestamps)
plan_text = self._generate_plan(
record, subtask_spans, refresh_t=boundary_t, task=effective_task
)
plan_text = self._generate_plan(record, subtask_spans, refresh_t=boundary_t, task=effective_task)
if plan_text is not None:
rows.append(
{
@@ -336,7 +334,9 @@ class PlanSubtasksMemoryModule:
if not frames:
logger.debug(
"action_record: no frames at span %.2f-%.2f for ep %s; skipping",
start_t, end_t, record.episode_index,
start_t,
end_t,
record.episode_index,
)
return None
@@ -811,12 +811,15 @@ class PlanSubtasksMemoryModule:
import json # noqa: PLC0415
subtasks_json = json.dumps(
{"subtasks": [{"text": s["text"], "start": round(s["start"], 3), "end": round(s["end"], 3)} for s in spans]},
{
"subtasks": [
{"text": s["text"], "start": round(s["start"], 3), "end": round(s["end"], 3)}
for s in spans
]
},
indent=2,
)
prompt = load_prompt("module_1_subtask_verify").format(
episode_task=task, subtasks_json=subtasks_json
)
prompt = load_prompt("module_1_subtask_verify").format(episode_task=task, subtasks_json=subtasks_json)
kept_raw = self._vlm_field(self._video_message(record, prompt, window=window), "subtasks")
# Windowed verify: the video is sampled from the absolute window
# ``[w0, w1]`` but the model perceives it as a clip starting at 0,
@@ -824,9 +827,7 @@ class PlanSubtasksMemoryModule:
# Clamp to that relative range and skip the absolute frame-snap
# dedupe (done once later on the merged absolute-time set).
clamp = (0.0, float(window[1] - window[0])) if window is not None else None
kept = self._clean_spans(
kept_raw, record, bounds=clamp, dedupe=window is None
)
kept = self._clean_spans(kept_raw, record, bounds=clamp, dedupe=window is None)
if not kept:
logger.info(
"episode %d: verify pass returned nothing — keeping the %d "
@@ -927,17 +928,13 @@ class PlanSubtasksMemoryModule:
if not subtask_spans:
return None
remaining = [
s
for s in subtask_spans
if refresh_t is None or float(s.get("start", 0.0)) >= float(refresh_t)
s for s in subtask_spans if refresh_t is None or float(s.get("start", 0.0)) >= float(refresh_t)
]
if not remaining:
# Past the last subtask boundary on a late refresh — nothing
# left to plan; emit None so the caller skips the row.
return None
return "\n".join(
f"{i}. {span.get('text', '').strip()}" for i, span in enumerate(remaining, start=1)
)
return "\n".join(f"{i}. {span.get('text', '').strip()}" for i, span in enumerate(remaining, start=1))
def _generate_memory(
self,

View File

@@ -137,9 +137,7 @@ class StagingValidator:
persistent: list[dict[str, Any]] = []
for row in all_rows:
self._check_column_routing(row, report, record.episode_index)
self._check_camera_field(
row, report, record.episode_index, self.dataset_camera_keys
)
self._check_camera_field(row, report, record.episode_index, self.dataset_camera_keys)
if column_for_style(row.get("style")) == LANGUAGE_PERSISTENT:
persistent.append(row)
else:
@@ -166,15 +164,9 @@ class StagingValidator:
try:
validate_camera_field(style, camera)
except ValueError as exc:
report.add_error(
f"ep={episode_index} module={row.get('_module')}: {exc}"
)
report.add_error(f"ep={episode_index} module={row.get('_module')}: {exc}")
return
if (
is_view_dependent_style(style)
and dataset_camera_keys
and camera not in dataset_camera_keys
):
if is_view_dependent_style(style) and dataset_camera_keys and camera not in dataset_camera_keys:
report.add_error(
f"ep={episode_index} module={row.get('_module')}: camera {camera!r} on style "
f"{style!r} is not one of the dataset's video keys {sorted(dataset_camera_keys)!r}"

View File

@@ -64,9 +64,7 @@ def annotate(cfg: AnnotationPipelineConfig) -> None:
logger.info("annotate: root=%s", root)
vlm = make_vlm_client(cfg.vlm)
frame_provider = make_frame_provider(
root, camera_key=cfg.vlm.camera_key, video_backend=cfg.video_backend
)
frame_provider = make_frame_provider(root, camera_key=cfg.vlm.camera_key, video_backend=cfg.video_backend)
# Surface the resolved cameras up front so a silent vqa-module no-op
# is obvious in job output rather than discovered post-hoc by counting
# parquet rows.
@@ -168,7 +166,10 @@ def _push_to_hub(root: Path, cfg: AnnotationPipelineConfig) -> None:
if isinstance(ds_version, str) and ds_version.startswith("v"):
version_tag = ds_version
except Exception as exc: # noqa: BLE001
print(f"[lerobot-annotate] could not read codebase_version from info.json ({exc}); falling back to {version_tag}", flush=True)
print(
f"[lerobot-annotate] could not read codebase_version from info.json ({exc}); falling back to {version_tag}",
flush=True,
)
revision = getattr(commit_info, "oid", None)
tag_kwargs = {
"repo_id": repo_id,