mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-02 11:51:25 +00:00
add tests/fixes
This commit is contained in:
2
src/lerobot/data_processing/data_annotations/__init__.py
Normal file
2
src/lerobot/data_processing/data_annotations/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
# Data annotations for subtasks and VLM-based labeling.
|
||||
@@ -5,13 +5,12 @@ from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from lerobot.datasets.dataset_tools import add_features
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.utils import (
|
||||
create_subtasks_dataframe,
|
||||
create_subtask_index_array,
|
||||
create_subtasks_dataframe,
|
||||
save_subtasks,
|
||||
)
|
||||
|
||||
@@ -57,6 +56,7 @@ class EpisodeSkills:
|
||||
|
||||
# Video Extraction Utilities
|
||||
|
||||
|
||||
class VideoExtractor:
|
||||
"""Utilities for extracting and processing video segments from LeRobot datasets."""
|
||||
|
||||
@@ -82,9 +82,8 @@ class VideoExtractor:
|
||||
Returns:
|
||||
Path to the extracted temporary video file
|
||||
"""
|
||||
tmp_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
|
||||
tmp_path = Path(tmp_file.name)
|
||||
tmp_file.close()
|
||||
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_file:
|
||||
tmp_path = Path(tmp_file.name)
|
||||
|
||||
duration = end_timestamp - start_timestamp
|
||||
|
||||
@@ -115,8 +114,8 @@ class VideoExtractor:
|
||||
subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True)
|
||||
except subprocess.CalledProcessError as e:
|
||||
raise RuntimeError(f"FFmpeg failed: {e}") from e
|
||||
except FileNotFoundError:
|
||||
raise RuntimeError("FFmpeg not found. Please install ffmpeg.")
|
||||
except FileNotFoundError as e:
|
||||
raise RuntimeError("FFmpeg not found. Please install ffmpeg.") from e
|
||||
|
||||
if not tmp_path.exists() or tmp_path.stat().st_size < 1024:
|
||||
if tmp_path.exists():
|
||||
@@ -131,9 +130,8 @@ class VideoExtractor:
|
||||
Used so the VLM can read the timestamp from the image instead of relying on file metadata.
|
||||
Draws a black box with white text at top-right. Writes to a new temporary file and returns its path.
|
||||
"""
|
||||
out_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
|
||||
out_path = Path(out_file.name)
|
||||
out_file.close()
|
||||
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as out_file:
|
||||
out_path = Path(out_file.name)
|
||||
|
||||
cap = cv2.VideoCapture(str(video_path))
|
||||
if not cap.isOpened():
|
||||
@@ -271,7 +269,7 @@ class SkillAnnotator:
|
||||
# Check if skills list exists and is not empty
|
||||
if "skills" in episode_data and episode_data["skills"]:
|
||||
existing_episode_indices.add(idx)
|
||||
|
||||
|
||||
original_count = len(episode_indices)
|
||||
episode_indices = [ep for ep in episode_indices if ep not in existing_episode_indices]
|
||||
skipped_count = original_count - len(episode_indices)
|
||||
@@ -288,14 +286,16 @@ class SkillAnnotator:
|
||||
for batch_start in range(0, len(episode_indices), self.batch_size):
|
||||
batch_end = min(batch_start + self.batch_size, len(episode_indices))
|
||||
batch_episodes = episode_indices[batch_start:batch_end]
|
||||
|
||||
print(f"Processing batch {batch_start//self.batch_size + 1}/{(len(episode_indices) + self.batch_size - 1)//self.batch_size} (episodes {batch_episodes[0]} to {batch_episodes[-1]})...")
|
||||
|
||||
print(
|
||||
f"Processing batch {batch_start // self.batch_size + 1}/{(len(episode_indices) + self.batch_size - 1) // self.batch_size} (episodes {batch_episodes[0]} to {batch_episodes[-1]})..."
|
||||
)
|
||||
|
||||
try:
|
||||
batch_annotations = self._annotate_episodes_batch(
|
||||
dataset, batch_episodes, video_key, coarse_goal, subtask_labels
|
||||
)
|
||||
|
||||
|
||||
for ep_idx in batch_episodes:
|
||||
if ep_idx in batch_annotations and batch_annotations[ep_idx]:
|
||||
skills = batch_annotations[ep_idx]
|
||||
@@ -337,9 +337,7 @@ class SkillAnnotator:
|
||||
for ep_idx, error_msg in list(failed_episodes.items()):
|
||||
print(f"Retry attempt for episode {ep_idx} (previous error: {error_msg})")
|
||||
try:
|
||||
skills = self._annotate_episode(
|
||||
dataset, ep_idx, video_key, coarse_goal, subtask_labels
|
||||
)
|
||||
skills = self._annotate_episode(dataset, ep_idx, video_key, coarse_goal, subtask_labels)
|
||||
if skills:
|
||||
annotations[ep_idx] = EpisodeSkills(
|
||||
episode_index=ep_idx,
|
||||
@@ -354,10 +352,10 @@ class SkillAnnotator:
|
||||
except Exception as retry_error:
|
||||
failed_episodes[ep_idx] = str(retry_error)
|
||||
print(f"✗ Episode {ep_idx} (retry) failed: {retry_error}")
|
||||
|
||||
|
||||
if retry_count > 0:
|
||||
print(f"Successfully recovered {retry_count} episodes on retry")
|
||||
|
||||
|
||||
if failed_episodes:
|
||||
print(f"\n⚠ Warning: {len(failed_episodes)} episodes still failed after retry:")
|
||||
for ep_idx, error_msg in failed_episodes.items():
|
||||
@@ -391,27 +389,27 @@ class SkillAnnotator:
|
||||
paths_for_vlm = []
|
||||
durations = []
|
||||
valid_episode_indices = []
|
||||
|
||||
|
||||
for ep_idx in episode_indices:
|
||||
try:
|
||||
# Get video path and timestamps
|
||||
video_path = dataset.root / dataset.meta.get_video_file_path(ep_idx, video_key)
|
||||
|
||||
|
||||
if not video_path.exists():
|
||||
print(f"Warning: Video not found for episode {ep_idx}")
|
||||
continue
|
||||
|
||||
|
||||
# Get episode timestamps from metadata
|
||||
ep = dataset.meta.episodes[ep_idx]
|
||||
start_ts = float(ep[f"videos/{video_key}/from_timestamp"])
|
||||
end_ts = float(ep[f"videos/{video_key}/to_timestamp"])
|
||||
duration = end_ts - start_ts
|
||||
|
||||
|
||||
# Extract episode segment to temporary file
|
||||
extracted_path = self.video_extractor.extract_episode_video(
|
||||
video_path, start_ts, end_ts, target_fps=dataset.meta.fps
|
||||
)
|
||||
|
||||
|
||||
if self.add_timer_overlay:
|
||||
video_for_vlm = self.video_extractor.add_timer_overlay(extracted_path)
|
||||
extracted_paths.append(extracted_path)
|
||||
@@ -424,27 +422,25 @@ class SkillAnnotator:
|
||||
paths_for_vlm.append(video_for_vlm)
|
||||
durations.append(duration)
|
||||
valid_episode_indices.append(ep_idx)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to extract video for episode {ep_idx}: {e}")
|
||||
continue
|
||||
|
||||
|
||||
if not paths_for_vlm:
|
||||
return {}
|
||||
|
||||
|
||||
try:
|
||||
# Run VLM skill segmentation in batch
|
||||
all_skills = self.vlm.segment_skills_batch(
|
||||
paths_for_vlm, durations, coarse_goal, subtask_labels
|
||||
)
|
||||
|
||||
all_skills = self.vlm.segment_skills_batch(paths_for_vlm, durations, coarse_goal, subtask_labels)
|
||||
|
||||
# Map results back to episode indices
|
||||
results = {}
|
||||
for ep_idx, skills in zip(valid_episode_indices, all_skills):
|
||||
for ep_idx, skills in zip(valid_episode_indices, all_skills, strict=True):
|
||||
results[ep_idx] = skills
|
||||
|
||||
|
||||
return results
|
||||
|
||||
|
||||
finally:
|
||||
# Clean up all temporary files (extracted and timer-overlay)
|
||||
for path in extracted_paths:
|
||||
@@ -486,9 +482,7 @@ class SkillAnnotator:
|
||||
|
||||
try:
|
||||
# Run VLM skill segmentation
|
||||
skills = self.vlm.segment_skills(
|
||||
video_for_vlm, duration, coarse_goal, subtask_labels
|
||||
)
|
||||
skills = self.vlm.segment_skills(video_for_vlm, duration, coarse_goal, subtask_labels)
|
||||
return skills
|
||||
finally:
|
||||
# Clean up temporary files (extracted and optionally timer-overlay)
|
||||
@@ -552,7 +546,7 @@ def save_skill_annotations(
|
||||
# Step 1: Create subtasks DataFrame
|
||||
print("Creating subtasks DataFrame...")
|
||||
subtasks_df, skill_to_subtask_idx = create_subtasks_dataframe(annotations)
|
||||
|
||||
|
||||
# Step 2: Create subtask_index array for all frames
|
||||
print("Creating subtask_index array...")
|
||||
subtask_indices = create_subtask_index_array(dataset, annotations, skill_to_subtask_idx)
|
||||
@@ -563,41 +557,47 @@ def save_skill_annotations(
|
||||
# Step 4: Save the raw skill annotations as JSON for reference
|
||||
skills_path = dataset.root / "meta" / "skills.json"
|
||||
skills_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
# Load existing skills data if it exists and is not empty
|
||||
existing_skills_data = None
|
||||
if skills_path.exists():
|
||||
try:
|
||||
with open(skills_path, "r") as f:
|
||||
with open(skills_path) as f:
|
||||
existing_skills_data = json.load(f)
|
||||
if existing_skills_data and len(existing_skills_data.get("episodes", {})) > 0:
|
||||
print(f"Found existing skills.json with {len(existing_skills_data.get('episodes', {}))} episodes, merging...")
|
||||
except (json.JSONDecodeError, IOError):
|
||||
print(
|
||||
f"Found existing skills.json with {len(existing_skills_data.get('episodes', {}))} episodes, merging..."
|
||||
)
|
||||
except (OSError, json.JSONDecodeError):
|
||||
print("Warning: Could not load existing skills.json, will create new file")
|
||||
existing_skills_data = None
|
||||
|
||||
|
||||
# Prepare new annotations
|
||||
new_episodes = {str(ep_idx): ann.to_dict() for ep_idx, ann in annotations.items()}
|
||||
|
||||
|
||||
# Merge with existing data if available
|
||||
if existing_skills_data:
|
||||
# Preserve existing episodes that are not being updated
|
||||
merged_episodes = existing_skills_data.get("episodes", {}).copy()
|
||||
merged_episodes.update(new_episodes)
|
||||
|
||||
|
||||
# Merge skill_to_subtask_index mappings
|
||||
merged_skill_to_subtask = existing_skills_data.get("skill_to_subtask_index", {}).copy()
|
||||
merged_skill_to_subtask.update(skill_to_subtask_idx)
|
||||
|
||||
|
||||
# Use existing coarse_description if available, otherwise use new one
|
||||
coarse_desc = existing_skills_data.get("coarse_description", annotations[next(iter(annotations))].description)
|
||||
|
||||
coarse_desc = existing_skills_data.get(
|
||||
"coarse_description", annotations[next(iter(annotations))].description
|
||||
)
|
||||
|
||||
skills_data = {
|
||||
"coarse_description": coarse_desc,
|
||||
"skill_to_subtask_index": merged_skill_to_subtask,
|
||||
"episodes": merged_episodes,
|
||||
}
|
||||
print(f"Updated {len(new_episodes)} episode(s), total episodes in skills.json: {len(merged_episodes)}")
|
||||
print(
|
||||
f"Updated {len(new_episodes)} episode(s), total episodes in skills.json: {len(merged_episodes)}"
|
||||
)
|
||||
else:
|
||||
# No existing data, create new
|
||||
skills_data = {
|
||||
@@ -613,16 +613,13 @@ def save_skill_annotations(
|
||||
|
||||
# Step 5: Add subtask_index feature to dataset using add_features
|
||||
print("Adding subtask_index feature to dataset...")
|
||||
|
||||
|
||||
# Determine output directory and repo_id
|
||||
if output_dir is None:
|
||||
output_dir = dataset.root.parent / f"{dataset.root.name}"
|
||||
else:
|
||||
output_dir = Path(output_dir)
|
||||
|
||||
output_dir = dataset.root.parent / f"{dataset.root.name}" if output_dir is None else Path(output_dir)
|
||||
|
||||
if repo_id is None:
|
||||
repo_id = f"{dataset.repo_id}"
|
||||
|
||||
|
||||
# Add feature using dataset_tools
|
||||
feature_info = {
|
||||
"dtype": "int64",
|
||||
@@ -637,22 +634,17 @@ def save_skill_annotations(
|
||||
output_dir=output_dir,
|
||||
repo_id=repo_id,
|
||||
)
|
||||
|
||||
|
||||
# Copy subtasks.parquet to new output directory
|
||||
import shutil
|
||||
shutil.copy(
|
||||
dataset.root / "meta" / "subtasks.parquet",
|
||||
output_dir / "meta" / "subtasks.parquet"
|
||||
)
|
||||
shutil.copy(
|
||||
dataset.root / "meta" / "skills.json",
|
||||
output_dir / "meta" / "skills.json"
|
||||
)
|
||||
|
||||
|
||||
shutil.copy(dataset.root / "meta" / "subtasks.parquet", output_dir / "meta" / "subtasks.parquet")
|
||||
shutil.copy(dataset.root / "meta" / "skills.json", output_dir / "meta" / "skills.json")
|
||||
|
||||
print(" Successfully added subtask_index feature!")
|
||||
print(f" New dataset saved to: {new_dataset.root}")
|
||||
print(f" Total subtasks: {len(subtasks_df)}")
|
||||
|
||||
|
||||
return new_dataset
|
||||
|
||||
|
||||
|
||||
@@ -7,13 +7,12 @@ from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.data_processing.data_annotations.subtask_annotations import Skill
|
||||
from lerobot.utils.constants import (
|
||||
SKILL_SEGMENTATION_PROMPT_TEMPLATE,
|
||||
format_subtask_labels_section,
|
||||
)
|
||||
|
||||
from lerobot.data_processing.data_annotations.subtask_annotations import Skill
|
||||
|
||||
|
||||
class BaseVLM(ABC):
|
||||
"""
|
||||
@@ -85,9 +84,7 @@ def create_skill_segmentation_prompt(
|
||||
if duration_seconds is None:
|
||||
raise ValueError("duration_seconds is required for skill segmentation prompt")
|
||||
goal_context = f'The overall goal is: "{coarse_goal}"\n\n' if coarse_goal else ""
|
||||
subtask_labels_section = (
|
||||
format_subtask_labels_section(subtask_labels) if subtask_labels else ""
|
||||
)
|
||||
subtask_labels_section = format_subtask_labels_section(subtask_labels) if subtask_labels else ""
|
||||
video_duration_mm_ss = f"{int(duration_seconds // 60):02d}:{int(duration_seconds % 60):02d}"
|
||||
return SKILL_SEGMENTATION_PROMPT_TEMPLATE.format(
|
||||
goal_context=goal_context,
|
||||
@@ -99,6 +96,7 @@ def create_skill_segmentation_prompt(
|
||||
|
||||
# Qwen2-VL Implementation
|
||||
|
||||
|
||||
class Qwen2VL(BaseVLM):
|
||||
"""Qwen2-VL model for skill segmentation."""
|
||||
|
||||
@@ -157,10 +155,12 @@ class Qwen2VL(BaseVLM):
|
||||
).to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
generated_ids = self.model.generate(**inputs, max_new_tokens=1024, do_sample=True, temperature=0.7)
|
||||
generated_ids = self.model.generate(
|
||||
**inputs, max_new_tokens=1024, do_sample=True, temperature=0.7
|
||||
)
|
||||
|
||||
response = self.processor.batch_decode(
|
||||
[out[len(inp) :] for inp, out in zip(inputs.input_ids, generated_ids)],
|
||||
[out[len(inp) :] for inp, out in zip(inputs.input_ids, generated_ids, strict=True)],
|
||||
skip_special_tokens=True,
|
||||
)[0].strip()
|
||||
|
||||
@@ -176,10 +176,8 @@ class Qwen2VL(BaseVLM):
|
||||
"""Segment multiple videos into skills using Qwen2-VL in a batch."""
|
||||
# Create messages for each video (prompt includes duration so each gets correct length)
|
||||
all_messages = []
|
||||
for video_path, duration in zip(video_paths, episode_durations):
|
||||
prompt = create_skill_segmentation_prompt(
|
||||
coarse_goal, subtask_labels, duration_seconds=duration
|
||||
)
|
||||
for video_path, duration in zip(video_paths, episode_durations, strict=True):
|
||||
prompt = create_skill_segmentation_prompt(coarse_goal, subtask_labels, duration_seconds=duration)
|
||||
duration_str = f"{int(duration // 60):02d}:{int(duration % 60):02d}"
|
||||
messages = [
|
||||
{"role": "system", "content": [{"type": "text", "text": prompt}]},
|
||||
@@ -195,19 +193,19 @@ class Qwen2VL(BaseVLM):
|
||||
},
|
||||
]
|
||||
all_messages.append(messages)
|
||||
|
||||
|
||||
# Process all videos in batch
|
||||
all_texts = []
|
||||
all_image_inputs = []
|
||||
all_video_inputs = []
|
||||
|
||||
|
||||
for messages in all_messages:
|
||||
text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
||||
image_inputs, video_inputs = self.process_vision_info(messages)
|
||||
all_texts.append(text)
|
||||
all_image_inputs.extend(image_inputs or [])
|
||||
all_video_inputs.extend(video_inputs or [])
|
||||
|
||||
|
||||
inputs = self.processor(
|
||||
text=all_texts,
|
||||
images=all_image_inputs if all_image_inputs else None,
|
||||
@@ -217,13 +215,15 @@ class Qwen2VL(BaseVLM):
|
||||
).to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
generated_ids = self.model.generate(**inputs, max_new_tokens=1024, do_sample=True, temperature=0.7)
|
||||
generated_ids = self.model.generate(
|
||||
**inputs, max_new_tokens=1024, do_sample=True, temperature=0.7
|
||||
)
|
||||
|
||||
responses = self.processor.batch_decode(
|
||||
[out[len(inp):] for inp, out in zip(inputs.input_ids, generated_ids)],
|
||||
[out[len(inp) :] for inp, out in zip(inputs.input_ids, generated_ids, strict=True)],
|
||||
skip_special_tokens=True,
|
||||
)
|
||||
|
||||
|
||||
# Parse each response
|
||||
all_skills = []
|
||||
for idx, response in enumerate(responses):
|
||||
@@ -235,7 +235,7 @@ class Qwen2VL(BaseVLM):
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to parse response for video {idx}: {e}")
|
||||
all_skills.append([])
|
||||
|
||||
|
||||
return all_skills
|
||||
|
||||
def _parse_skills_response(self, response: str) -> list[Skill]:
|
||||
@@ -321,10 +321,12 @@ class Qwen3VL(BaseVLM):
|
||||
return_tensors="pt",
|
||||
).to(self.device)
|
||||
with torch.no_grad():
|
||||
generated_ids = self.model.generate(**inputs, max_new_tokens=1024, do_sample=True, temperature=0.7)
|
||||
generated_ids = self.model.generate(
|
||||
**inputs, max_new_tokens=1024, do_sample=True, temperature=0.7
|
||||
)
|
||||
|
||||
response = self.processor.batch_decode(
|
||||
[out[len(inp) :] for inp, out in zip(inputs.input_ids, generated_ids)],
|
||||
[out[len(inp) :] for inp, out in zip(inputs.input_ids, generated_ids, strict=True)],
|
||||
skip_special_tokens=True,
|
||||
)[0].strip()
|
||||
|
||||
@@ -340,10 +342,8 @@ class Qwen3VL(BaseVLM):
|
||||
"""Segment multiple videos into skills using Qwen3-VL in a batch."""
|
||||
# Create messages for each video (prompt includes duration so each gets correct length)
|
||||
all_messages = []
|
||||
for video_path, duration in zip(video_paths, episode_durations):
|
||||
prompt = create_skill_segmentation_prompt(
|
||||
coarse_goal, subtask_labels, duration_seconds=duration
|
||||
)
|
||||
for video_path, duration in zip(video_paths, episode_durations, strict=True):
|
||||
prompt = create_skill_segmentation_prompt(coarse_goal, subtask_labels, duration_seconds=duration)
|
||||
duration_str = f"{int(duration // 60):02d}:{int(duration % 60):02d}"
|
||||
messages = [
|
||||
{"role": "system", "content": [{"type": "text", "text": prompt}]},
|
||||
@@ -359,19 +359,19 @@ class Qwen3VL(BaseVLM):
|
||||
},
|
||||
]
|
||||
all_messages.append(messages)
|
||||
|
||||
|
||||
# Process all videos in batch
|
||||
all_texts = []
|
||||
all_image_inputs = []
|
||||
all_video_inputs = []
|
||||
|
||||
|
||||
for messages in all_messages:
|
||||
text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
||||
image_inputs, video_inputs = self.process_vision_info(messages)
|
||||
all_texts.append(text)
|
||||
all_image_inputs.extend(image_inputs or [])
|
||||
all_video_inputs.extend(video_inputs or [])
|
||||
|
||||
|
||||
inputs = self.processor(
|
||||
text=all_texts,
|
||||
images=all_image_inputs if all_image_inputs else None,
|
||||
@@ -381,13 +381,15 @@ class Qwen3VL(BaseVLM):
|
||||
).to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
generated_ids = self.model.generate(**inputs, max_new_tokens=1024, do_sample=True, temperature=0.7)
|
||||
generated_ids = self.model.generate(
|
||||
**inputs, max_new_tokens=1024, do_sample=True, temperature=0.7
|
||||
)
|
||||
|
||||
responses = self.processor.batch_decode(
|
||||
[out[len(inp):] for inp, out in zip(inputs.input_ids, generated_ids)],
|
||||
[out[len(inp) :] for inp, out in zip(inputs.input_ids, generated_ids, strict=True)],
|
||||
skip_special_tokens=True,
|
||||
)
|
||||
|
||||
|
||||
# Parse each response
|
||||
all_skills = []
|
||||
for idx, response in enumerate(responses):
|
||||
@@ -399,7 +401,7 @@ class Qwen3VL(BaseVLM):
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to parse response for video {idx}: {e}")
|
||||
all_skills.append([])
|
||||
|
||||
|
||||
return all_skills
|
||||
|
||||
def _parse_skills_response(self, response: str) -> list[Skill]:
|
||||
@@ -420,14 +422,14 @@ class Qwen3VL(BaseVLM):
|
||||
data = json.loads(match.group())
|
||||
skills_data = data.get("skills", [])
|
||||
return [Skill.from_dict(s) for s in skills_data]
|
||||
|
||||
|
||||
raise ValueError(f"Could not parse skills from response: {response[:200]}...")
|
||||
|
||||
|
||||
# Qwen3.5-VL Implementation (Qwen3_5ForConditionalGeneration)
|
||||
|
||||
|
||||
class Qwen3_5VL(BaseVLM):
|
||||
class Qwen35VL(BaseVLM):
|
||||
"""Qwen3.5-VL model for skill segmentation (Qwen3_5ForConditionalGeneration)."""
|
||||
|
||||
def __init__(self, model_name: str, device: str = "cuda", torch_dtype: torch.dtype = torch.bfloat16):
|
||||
@@ -486,7 +488,7 @@ class Qwen3_5VL(BaseVLM):
|
||||
generated_ids = self.model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=0.7)
|
||||
|
||||
response = self.processor.batch_decode(
|
||||
[out[len(inp) :] for inp, out in zip(inputs.input_ids, generated_ids)],
|
||||
[out[len(inp) :] for inp, out in zip(inputs.input_ids, generated_ids, strict=True)],
|
||||
skip_special_tokens=True,
|
||||
clean_up_tokenization_spaces=False,
|
||||
)[0].strip()
|
||||
@@ -502,10 +504,8 @@ class Qwen3_5VL(BaseVLM):
|
||||
) -> list[list[Skill]]:
|
||||
"""Segment multiple videos into skills using Qwen3.5-VL in a batch."""
|
||||
all_messages = []
|
||||
for video_path, duration in zip(video_paths, episode_durations):
|
||||
prompt = create_skill_segmentation_prompt(
|
||||
coarse_goal, subtask_labels, duration_seconds=duration
|
||||
)
|
||||
for video_path, duration in zip(video_paths, episode_durations, strict=True):
|
||||
prompt = create_skill_segmentation_prompt(coarse_goal, subtask_labels, duration_seconds=duration)
|
||||
duration_str = f"{int(duration // 60):02d}:{int(duration % 60):02d}"
|
||||
messages = [
|
||||
{"role": "system", "content": [{"type": "text", "text": prompt}]},
|
||||
@@ -527,7 +527,9 @@ class Qwen3_5VL(BaseVLM):
|
||||
all_video_inputs = []
|
||||
|
||||
for messages in all_messages:
|
||||
text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, enable_thinking=False)
|
||||
text = self.processor.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True, enable_thinking=False
|
||||
)
|
||||
image_inputs, video_inputs = self.process_vision_info(messages)
|
||||
all_texts.append(text)
|
||||
all_image_inputs.extend(image_inputs or [])
|
||||
@@ -545,7 +547,7 @@ class Qwen3_5VL(BaseVLM):
|
||||
generated_ids = self.model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=0.7)
|
||||
|
||||
responses = self.processor.batch_decode(
|
||||
[out[len(inp):] for inp, out in zip(inputs.input_ids, generated_ids)],
|
||||
[out[len(inp) :] for inp, out in zip(inputs.input_ids, generated_ids, strict=True)],
|
||||
skip_special_tokens=True,
|
||||
clean_up_tokenization_spaces=False,
|
||||
)
|
||||
@@ -584,6 +586,7 @@ class Qwen3_5VL(BaseVLM):
|
||||
|
||||
raise ValueError(f"Could not parse skills from response: {response[:200]}...")
|
||||
|
||||
|
||||
# VLM Registry - Add new VLMs here
|
||||
|
||||
VLM_REGISTRY: dict[str, type[BaseVLM]] = {
|
||||
@@ -594,8 +597,8 @@ VLM_REGISTRY: dict[str, type[BaseVLM]] = {
|
||||
# Qwen3-VL variants (MoE)
|
||||
"Qwen/Qwen3-VL-30B-A3B-Instruct": Qwen3VL,
|
||||
# Qwen3.5-VL (Qwen3_5ForConditionalGeneration)
|
||||
"Qwen/Qwen3.5-27B": Qwen3_5VL,
|
||||
"Qwen/Qwen3-VL-8B-Instruct": Qwen3_5VL,
|
||||
"Qwen/Qwen3.5-27B": Qwen35VL,
|
||||
"Qwen/Qwen3-VL-8B-Instruct": Qwen35VL,
|
||||
}
|
||||
|
||||
|
||||
@@ -621,7 +624,7 @@ def get_vlm(model_name: str, device: str = "cuda", torch_dtype: torch.dtype = to
|
||||
# Check for partial matches (e.g., "qwen2" in model name)
|
||||
model_lower = model_name.lower()
|
||||
if "qwen3.5" in model_lower:
|
||||
return Qwen3_5VL(model_name, device, torch_dtype)
|
||||
return Qwen35VL(model_name, device, torch_dtype)
|
||||
if "qwen3" in model_lower:
|
||||
return Qwen3VL(model_name, device, torch_dtype)
|
||||
elif "qwen2" in model_lower or "qwen-vl" in model_lower:
|
||||
|
||||
Reference in New Issue
Block a user