mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-02 20:01:25 +00:00
fix: pass video_kwargs from process_vision_info to Qwen processor
The Qwen processor needs fps metadata (obtained by calling process_vision_info with return_video_kwargs=True and forwarding the returned video_kwargs to the processor) to compute correct temporal position embeddings. Without it, the processor defaults to fps=24 regardless of the actual video fps, causing shape mismatches between the expected and actual number of video tokens.

Made-with: Cursor
This commit is contained in:
@@ -15,15 +15,12 @@
|
||||
# VLM Interface (Abstract Base Class for Modularity)
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
os.environ.setdefault("FORCE_QWENVL_VIDEO_READER", "torchvision")
|
||||
|
||||
from lerobot.data_processing.data_annotations.subtask_annotations import Skill
|
||||
from lerobot.utils.constants import (
|
||||
SKILL_SEGMENTATION_PROMPT_TEMPLATE,
|
||||
@@ -162,11 +159,14 @@ class Qwen2VL(BaseVLM):
|
||||
]
|
||||
|
||||
text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
||||
image_inputs, video_inputs = self.process_vision_info(messages)
|
||||
image_inputs, video_inputs, video_kwargs = self.process_vision_info(
|
||||
messages, return_video_kwargs=True
|
||||
)
|
||||
inputs = self.processor(
|
||||
text=[text],
|
||||
images=image_inputs,
|
||||
videos=video_inputs,
|
||||
**video_kwargs,
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
).to(self.device)
|
||||
@@ -191,7 +191,6 @@ class Qwen2VL(BaseVLM):
|
||||
subtask_labels: list[str] | None = None,
|
||||
) -> list[list[Skill]]:
|
||||
"""Segment multiple videos into skills using Qwen2-VL in a batch."""
|
||||
# Create messages for each video (prompt includes duration so each gets correct length)
|
||||
all_messages = []
|
||||
for video_path, duration in zip(video_paths, episode_durations, strict=True):
|
||||
prompt = create_skill_segmentation_prompt(coarse_goal, subtask_labels, duration_seconds=duration)
|
||||
@@ -211,22 +210,26 @@ class Qwen2VL(BaseVLM):
|
||||
]
|
||||
all_messages.append(messages)
|
||||
|
||||
# Process all videos in batch
|
||||
all_texts = []
|
||||
all_image_inputs = []
|
||||
all_video_inputs = []
|
||||
all_video_kwargs: dict = {"do_sample_frames": False, "fps": []}
|
||||
|
||||
for messages in all_messages:
|
||||
text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
||||
image_inputs, video_inputs = self.process_vision_info(messages)
|
||||
image_inputs, video_inputs, video_kwargs = self.process_vision_info(
|
||||
messages, return_video_kwargs=True
|
||||
)
|
||||
all_texts.append(text)
|
||||
all_image_inputs.extend(image_inputs or [])
|
||||
all_video_inputs.extend(video_inputs or [])
|
||||
all_video_kwargs["fps"].extend(video_kwargs.get("fps", []))
|
||||
|
||||
inputs = self.processor(
|
||||
text=all_texts,
|
||||
images=all_image_inputs if all_image_inputs else None,
|
||||
videos=all_video_inputs if all_video_inputs else None,
|
||||
**all_video_kwargs,
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
).to(self.device)
|
||||
@@ -335,11 +338,14 @@ class Qwen3VL(BaseVLM):
|
||||
]
|
||||
|
||||
text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
||||
image_inputs, video_inputs = self.process_vision_info(messages)
|
||||
image_inputs, video_inputs, video_kwargs = self.process_vision_info(
|
||||
messages, return_video_kwargs=True
|
||||
)
|
||||
inputs = self.processor(
|
||||
text=[text],
|
||||
images=image_inputs,
|
||||
videos=video_inputs,
|
||||
**video_kwargs,
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
).to(self.device)
|
||||
@@ -363,7 +369,6 @@ class Qwen3VL(BaseVLM):
|
||||
subtask_labels: list[str] | None = None,
|
||||
) -> list[list[Skill]]:
|
||||
"""Segment multiple videos into skills using Qwen3-VL in a batch."""
|
||||
# Create messages for each video (prompt includes duration so each gets correct length)
|
||||
all_messages = []
|
||||
for video_path, duration in zip(video_paths, episode_durations, strict=True):
|
||||
prompt = create_skill_segmentation_prompt(coarse_goal, subtask_labels, duration_seconds=duration)
|
||||
@@ -383,22 +388,26 @@ class Qwen3VL(BaseVLM):
|
||||
]
|
||||
all_messages.append(messages)
|
||||
|
||||
# Process all videos in batch
|
||||
all_texts = []
|
||||
all_image_inputs = []
|
||||
all_video_inputs = []
|
||||
all_video_kwargs: dict = {"do_sample_frames": False, "fps": []}
|
||||
|
||||
for messages in all_messages:
|
||||
text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
||||
image_inputs, video_inputs = self.process_vision_info(messages)
|
||||
image_inputs, video_inputs, video_kwargs = self.process_vision_info(
|
||||
messages, return_video_kwargs=True
|
||||
)
|
||||
all_texts.append(text)
|
||||
all_image_inputs.extend(image_inputs or [])
|
||||
all_video_inputs.extend(video_inputs or [])
|
||||
all_video_kwargs["fps"].extend(video_kwargs.get("fps", []))
|
||||
|
||||
inputs = self.processor(
|
||||
text=all_texts,
|
||||
images=all_image_inputs if all_image_inputs else None,
|
||||
videos=all_video_inputs if all_video_inputs else None,
|
||||
**all_video_kwargs,
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
).to(self.device)
|
||||
@@ -501,11 +510,14 @@ class Qwen35VL(BaseVLM):
|
||||
text = self.processor.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True, enable_thinking=False
|
||||
)
|
||||
image_inputs, video_inputs = self.process_vision_info(messages)
|
||||
image_inputs, video_inputs, video_kwargs = self.process_vision_info(
|
||||
messages, return_video_kwargs=True
|
||||
)
|
||||
inputs = self.processor(
|
||||
text=[text],
|
||||
images=image_inputs,
|
||||
videos=video_inputs,
|
||||
**video_kwargs,
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
).to(self.device)
|
||||
@@ -550,20 +562,25 @@ class Qwen35VL(BaseVLM):
|
||||
all_texts = []
|
||||
all_image_inputs = []
|
||||
all_video_inputs = []
|
||||
all_video_kwargs: dict = {"do_sample_frames": False, "fps": []}
|
||||
|
||||
for messages in all_messages:
|
||||
text = self.processor.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True, enable_thinking=False
|
||||
)
|
||||
image_inputs, video_inputs = self.process_vision_info(messages)
|
||||
image_inputs, video_inputs, video_kwargs = self.process_vision_info(
|
||||
messages, return_video_kwargs=True
|
||||
)
|
||||
all_texts.append(text)
|
||||
all_image_inputs.extend(image_inputs or [])
|
||||
all_video_inputs.extend(video_inputs or [])
|
||||
all_video_kwargs["fps"].extend(video_kwargs.get("fps", []))
|
||||
|
||||
inputs = self.processor(
|
||||
text=all_texts,
|
||||
images=all_image_inputs if all_image_inputs else None,
|
||||
videos=all_video_inputs if all_video_inputs else None,
|
||||
**all_video_kwargs,
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
).to(self.device)
|
||||
|
||||
Reference in New Issue
Block a user