add tests/fixes

This commit is contained in:
root
2026-03-11 22:49:06 +00:00
parent f0848c6887
commit 819c1b9710
8 changed files with 306 additions and 144 deletions

View File

@@ -7,13 +7,12 @@ from pathlib import Path
import torch
from lerobot.data_processing.data_annotations.subtask_annotations import Skill
from lerobot.utils.constants import (
SKILL_SEGMENTATION_PROMPT_TEMPLATE,
format_subtask_labels_section,
)
from lerobot.data_processing.data_annotations.subtask_annotations import Skill
class BaseVLM(ABC):
"""
@@ -85,9 +84,7 @@ def create_skill_segmentation_prompt(
if duration_seconds is None:
raise ValueError("duration_seconds is required for skill segmentation prompt")
goal_context = f'The overall goal is: "{coarse_goal}"\n\n' if coarse_goal else ""
subtask_labels_section = (
format_subtask_labels_section(subtask_labels) if subtask_labels else ""
)
subtask_labels_section = format_subtask_labels_section(subtask_labels) if subtask_labels else ""
video_duration_mm_ss = f"{int(duration_seconds // 60):02d}:{int(duration_seconds % 60):02d}"
return SKILL_SEGMENTATION_PROMPT_TEMPLATE.format(
goal_context=goal_context,
@@ -99,6 +96,7 @@ def create_skill_segmentation_prompt(
# Qwen2-VL Implementation
class Qwen2VL(BaseVLM):
"""Qwen2-VL model for skill segmentation."""
@@ -157,10 +155,12 @@ class Qwen2VL(BaseVLM):
).to(self.device)
with torch.no_grad():
generated_ids = self.model.generate(**inputs, max_new_tokens=1024, do_sample=True, temperature=0.7)
generated_ids = self.model.generate(
**inputs, max_new_tokens=1024, do_sample=True, temperature=0.7
)
response = self.processor.batch_decode(
[out[len(inp) :] for inp, out in zip(inputs.input_ids, generated_ids)],
[out[len(inp) :] for inp, out in zip(inputs.input_ids, generated_ids, strict=True)],
skip_special_tokens=True,
)[0].strip()
@@ -176,10 +176,8 @@ class Qwen2VL(BaseVLM):
"""Segment multiple videos into skills using Qwen2-VL in a batch."""
# Create messages for each video (prompt includes duration so each gets correct length)
all_messages = []
for video_path, duration in zip(video_paths, episode_durations):
prompt = create_skill_segmentation_prompt(
coarse_goal, subtask_labels, duration_seconds=duration
)
for video_path, duration in zip(video_paths, episode_durations, strict=True):
prompt = create_skill_segmentation_prompt(coarse_goal, subtask_labels, duration_seconds=duration)
duration_str = f"{int(duration // 60):02d}:{int(duration % 60):02d}"
messages = [
{"role": "system", "content": [{"type": "text", "text": prompt}]},
@@ -195,19 +193,19 @@ class Qwen2VL(BaseVLM):
},
]
all_messages.append(messages)
# Process all videos in batch
all_texts = []
all_image_inputs = []
all_video_inputs = []
for messages in all_messages:
text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
image_inputs, video_inputs = self.process_vision_info(messages)
all_texts.append(text)
all_image_inputs.extend(image_inputs or [])
all_video_inputs.extend(video_inputs or [])
inputs = self.processor(
text=all_texts,
images=all_image_inputs if all_image_inputs else None,
@@ -217,13 +215,15 @@ class Qwen2VL(BaseVLM):
).to(self.device)
with torch.no_grad():
generated_ids = self.model.generate(**inputs, max_new_tokens=1024, do_sample=True, temperature=0.7)
generated_ids = self.model.generate(
**inputs, max_new_tokens=1024, do_sample=True, temperature=0.7
)
responses = self.processor.batch_decode(
[out[len(inp):] for inp, out in zip(inputs.input_ids, generated_ids)],
[out[len(inp) :] for inp, out in zip(inputs.input_ids, generated_ids, strict=True)],
skip_special_tokens=True,
)
# Parse each response
all_skills = []
for idx, response in enumerate(responses):
@@ -235,7 +235,7 @@ class Qwen2VL(BaseVLM):
except Exception as e:
print(f"Warning: Failed to parse response for video {idx}: {e}")
all_skills.append([])
return all_skills
def _parse_skills_response(self, response: str) -> list[Skill]:
@@ -321,10 +321,12 @@ class Qwen3VL(BaseVLM):
return_tensors="pt",
).to(self.device)
with torch.no_grad():
generated_ids = self.model.generate(**inputs, max_new_tokens=1024, do_sample=True, temperature=0.7)
generated_ids = self.model.generate(
**inputs, max_new_tokens=1024, do_sample=True, temperature=0.7
)
response = self.processor.batch_decode(
[out[len(inp) :] for inp, out in zip(inputs.input_ids, generated_ids)],
[out[len(inp) :] for inp, out in zip(inputs.input_ids, generated_ids, strict=True)],
skip_special_tokens=True,
)[0].strip()
@@ -340,10 +342,8 @@ class Qwen3VL(BaseVLM):
"""Segment multiple videos into skills using Qwen3-VL in a batch."""
# Create messages for each video (prompt includes duration so each gets correct length)
all_messages = []
for video_path, duration in zip(video_paths, episode_durations):
prompt = create_skill_segmentation_prompt(
coarse_goal, subtask_labels, duration_seconds=duration
)
for video_path, duration in zip(video_paths, episode_durations, strict=True):
prompt = create_skill_segmentation_prompt(coarse_goal, subtask_labels, duration_seconds=duration)
duration_str = f"{int(duration // 60):02d}:{int(duration % 60):02d}"
messages = [
{"role": "system", "content": [{"type": "text", "text": prompt}]},
@@ -359,19 +359,19 @@ class Qwen3VL(BaseVLM):
},
]
all_messages.append(messages)
# Process all videos in batch
all_texts = []
all_image_inputs = []
all_video_inputs = []
for messages in all_messages:
text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
image_inputs, video_inputs = self.process_vision_info(messages)
all_texts.append(text)
all_image_inputs.extend(image_inputs or [])
all_video_inputs.extend(video_inputs or [])
inputs = self.processor(
text=all_texts,
images=all_image_inputs if all_image_inputs else None,
@@ -381,13 +381,15 @@ class Qwen3VL(BaseVLM):
).to(self.device)
with torch.no_grad():
generated_ids = self.model.generate(**inputs, max_new_tokens=1024, do_sample=True, temperature=0.7)
generated_ids = self.model.generate(
**inputs, max_new_tokens=1024, do_sample=True, temperature=0.7
)
responses = self.processor.batch_decode(
[out[len(inp):] for inp, out in zip(inputs.input_ids, generated_ids)],
[out[len(inp) :] for inp, out in zip(inputs.input_ids, generated_ids, strict=True)],
skip_special_tokens=True,
)
# Parse each response
all_skills = []
for idx, response in enumerate(responses):
@@ -399,7 +401,7 @@ class Qwen3VL(BaseVLM):
except Exception as e:
print(f"Warning: Failed to parse response for video {idx}: {e}")
all_skills.append([])
return all_skills
def _parse_skills_response(self, response: str) -> list[Skill]:
@@ -420,14 +422,14 @@ class Qwen3VL(BaseVLM):
data = json.loads(match.group())
skills_data = data.get("skills", [])
return [Skill.from_dict(s) for s in skills_data]
raise ValueError(f"Could not parse skills from response: {response[:200]}...")
# Qwen3.5-VL Implementation (Qwen3_5ForConditionalGeneration)
class Qwen3_5VL(BaseVLM):
class Qwen35VL(BaseVLM):
"""Qwen3.5-VL model for skill segmentation (Qwen3_5ForConditionalGeneration)."""
def __init__(self, model_name: str, device: str = "cuda", torch_dtype: torch.dtype = torch.bfloat16):
@@ -486,7 +488,7 @@ class Qwen3_5VL(BaseVLM):
generated_ids = self.model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=0.7)
response = self.processor.batch_decode(
[out[len(inp) :] for inp, out in zip(inputs.input_ids, generated_ids)],
[out[len(inp) :] for inp, out in zip(inputs.input_ids, generated_ids, strict=True)],
skip_special_tokens=True,
clean_up_tokenization_spaces=False,
)[0].strip()
@@ -502,10 +504,8 @@ class Qwen3_5VL(BaseVLM):
) -> list[list[Skill]]:
"""Segment multiple videos into skills using Qwen3.5-VL in a batch."""
all_messages = []
for video_path, duration in zip(video_paths, episode_durations):
prompt = create_skill_segmentation_prompt(
coarse_goal, subtask_labels, duration_seconds=duration
)
for video_path, duration in zip(video_paths, episode_durations, strict=True):
prompt = create_skill_segmentation_prompt(coarse_goal, subtask_labels, duration_seconds=duration)
duration_str = f"{int(duration // 60):02d}:{int(duration % 60):02d}"
messages = [
{"role": "system", "content": [{"type": "text", "text": prompt}]},
@@ -527,7 +527,9 @@ class Qwen3_5VL(BaseVLM):
all_video_inputs = []
for messages in all_messages:
text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, enable_thinking=False)
text = self.processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True, enable_thinking=False
)
image_inputs, video_inputs = self.process_vision_info(messages)
all_texts.append(text)
all_image_inputs.extend(image_inputs or [])
@@ -545,7 +547,7 @@ class Qwen3_5VL(BaseVLM):
generated_ids = self.model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=0.7)
responses = self.processor.batch_decode(
[out[len(inp):] for inp, out in zip(inputs.input_ids, generated_ids)],
[out[len(inp) :] for inp, out in zip(inputs.input_ids, generated_ids, strict=True)],
skip_special_tokens=True,
clean_up_tokenization_spaces=False,
)
@@ -584,6 +586,7 @@ class Qwen3_5VL(BaseVLM):
raise ValueError(f"Could not parse skills from response: {response[:200]}...")
# VLM Registry - Add new VLMs here
VLM_REGISTRY: dict[str, type[BaseVLM]] = {
@@ -594,8 +597,8 @@ VLM_REGISTRY: dict[str, type[BaseVLM]] = {
# Qwen3-VL variants (MoE)
"Qwen/Qwen3-VL-30B-A3B-Instruct": Qwen3VL,
# Qwen3.5-VL (Qwen3_5ForConditionalGeneration)
"Qwen/Qwen3.5-27B": Qwen3_5VL,
"Qwen/Qwen3-VL-8B-Instruct": Qwen3_5VL,
"Qwen/Qwen3.5-27B": Qwen35VL,
"Qwen/Qwen3-VL-8B-Instruct": Qwen35VL,
}
@@ -621,7 +624,7 @@ def get_vlm(model_name: str, device: str = "cuda", torch_dtype: torch.dtype = to
# Check for partial matches (e.g., "qwen2" in model name)
model_lower = model_name.lower()
if "qwen3.5" in model_lower:
return Qwen3_5VL(model_name, device, torch_dtype)
return Qwen35VL(model_name, device, torch_dtype)
if "qwen3" in model_lower:
return Qwen3VL(model_name, device, torch_dtype)
elif "qwen2" in model_lower or "qwen-vl" in model_lower: