Files
lerobot-clone/src/lerobot/data_processing/data_annotations/vlm_annotations.py

678 lines
26 KiB
Python
Raw Normal View History

2026-03-11 23:14:22 +00:00
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
2026-03-11 19:51:48 +00:00
# VLM Interface (Abstract Base Class for Modularity)
import json
import re
from abc import ABC, abstractmethod
from pathlib import Path
import torch
2026-03-11 22:49:06 +00:00
from lerobot.data_processing.data_annotations.subtask_annotations import Skill
2026-03-11 19:51:48 +00:00
from lerobot.utils.constants import (
SKILL_SEGMENTATION_PROMPT_TEMPLATE,
format_subtask_labels_section,
)
class BaseVLM(ABC):
    """Interface that every vision-language model backend implements.

    Adding a backend:
        1. Subclass ``BaseVLM``.
        2. Implement ``__init__``, ``segment_skills`` and ``segment_skills_batch``.
        3. Add the model id(s) to the ``VLM_REGISTRY`` dictionary.
    """

    @abstractmethod
    def __init__(self, model_name: str, device: str = "cuda", torch_dtype: torch.dtype = torch.bfloat16):
        """Load ``model_name`` onto ``device`` with weights in ``torch_dtype``."""

    @abstractmethod
    def segment_skills(
        self,
        video_path: Path,
        episode_duration: float,
        coarse_goal: str | None = None,
        subtask_labels: list[str] | None = None,
    ) -> list[Skill]:
        """Split a single episode video into atomic skills.

        Args:
            video_path: Location of the episode video file.
            episode_duration: Length of the episode in seconds.
            coarse_goal: Optional high-level description of the task.
            subtask_labels: Optional closed vocabulary; when given, the model may only
                use these labels.

        Returns:
            The atomic manipulation skills found in the video, as ``Skill`` objects.
        """

    @abstractmethod
    def segment_skills_batch(
        self,
        video_paths: list[Path],
        episode_durations: list[float],
        coarse_goal: str | None = None,
        subtask_labels: list[str] | None = None,
    ) -> list[list[Skill]]:
        """Split several episode videos into atomic skills with one batched call.

        Args:
            video_paths: Locations of the episode video files.
            episode_durations: Episode lengths in seconds, aligned with ``video_paths``.
            coarse_goal: Optional high-level description shared by all episodes.
            subtask_labels: Optional closed vocabulary shared by all episodes.

        Returns:
            One list of skills per input video, in input order.
        """
def _unpack_video_inputs(
video_inputs: list | None,
) -> tuple[list | None, list[dict] | None]:
"""Unpack (tensor, metadata) tuples returned by process_vision_info with return_video_metadata=True."""
if not video_inputs:
return None, None
videos = [v[0] for v in video_inputs]
metadata = [v[1] for v in video_inputs]
return videos, metadata
2026-03-11 19:51:48 +00:00
def create_skill_segmentation_prompt(
coarse_goal: str | None = None,
subtask_labels: list[str] | None = None,
duration_seconds: float | None = None,
) -> str:
"""Create the prompt for skill segmentation using the template from constants.
duration_seconds is required. When subtask_labels is provided, uses closed-vocabulary section.
"""
if duration_seconds is None:
raise ValueError("duration_seconds is required for skill segmentation prompt")
goal_context = f'The overall goal is: "{coarse_goal}"\n\n' if coarse_goal else ""
2026-03-11 22:49:06 +00:00
subtask_labels_section = format_subtask_labels_section(subtask_labels) if subtask_labels else ""
2026-03-11 19:51:48 +00:00
video_duration_mm_ss = f"{int(duration_seconds // 60):02d}:{int(duration_seconds % 60):02d}"
return SKILL_SEGMENTATION_PROMPT_TEMPLATE.format(
goal_context=goal_context,
subtask_labels_section=subtask_labels_section,
video_duration_seconds=duration_seconds,
video_duration_mm_ss=video_duration_mm_ss,
)
# Qwen2-VL Implementation
2026-03-11 22:49:06 +00:00
2026-03-11 19:51:48 +00:00
class Qwen2VL(BaseVLM):
    """Qwen2-VL model for skill segmentation.

    Loads ``Qwen2VLForConditionalGeneration`` and its processor. Each episode video is
    referenced in a chat message at ``fps=1.0`` and run through
    ``qwen_vl_utils.process_vision_info``. The JSON skill list the model generates is
    then parsed. Generation is sampled (``do_sample=True, temperature=0.7``), so
    repeated calls on the same video can produce different segmentations.
    """

    def __init__(self, model_name: str, device: str = "cuda", torch_dtype: torch.dtype = torch.bfloat16):
        from qwen_vl_utils import process_vision_info
        from transformers import AutoProcessor, Qwen2VLForConditionalGeneration

        self.device = device
        self.model_name = model_name
        self.process_vision_info = process_vision_info
        print(f"Loading Qwen2-VL model: {model_name}...")
        self.model = Qwen2VLForConditionalGeneration.from_pretrained(
            model_name, torch_dtype=torch_dtype, device_map=device, trust_remote_code=True
        )
        self.processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
        # Batched generation with a decoder-only model needs left padding, so that every
        # prompt ends exactly where the new tokens begin. Qwen35VL sets the same thing.
        self.processor.tokenizer.padding_side = "left"
        print(f" Model loaded successfully on {device}")

    def _build_messages(
        self,
        video_path: Path,
        duration: float,
        coarse_goal: str | None,
        subtask_labels: list[str] | None,
    ) -> list[dict]:
        """Build the system-prompt + (video, duration reminder) chat messages for one episode."""
        prompt = create_skill_segmentation_prompt(coarse_goal, subtask_labels, duration_seconds=duration)
        duration_str = f"{int(duration // 60):02d}:{int(duration % 60):02d}"
        return [
            {"role": "system", "content": [{"type": "text", "text": prompt}]},
            {
                "role": "user",
                "content": [
                    {"type": "video", "video": str(video_path), "fps": 1.0},
                    {
                        "type": "text",
                        "text": f"Video duration: {duration_str} (exactly {duration:.1f} seconds). Segment into atomic skills. Last skill must end at {duration:.1f}.",
                    },
                ],
            },
        ]

    def _generate_responses(self, inputs) -> list[str]:
        """Run sampled generation on processor ``inputs``; decode only the new tokens of each row."""
        with torch.no_grad():
            generated_ids = self.model.generate(**inputs, max_new_tokens=1024, do_sample=True, temperature=0.7)
        # generate() returns prompt + completion per row, so drop each row's prompt tokens.
        new_tokens = [out[len(inp) :] for inp, out in zip(inputs.input_ids, generated_ids, strict=True)]
        return self.processor.batch_decode(new_tokens, skip_special_tokens=True)

    def segment_skills(
        self,
        video_path: Path,
        episode_duration: float,
        coarse_goal: str | None = None,
        subtask_labels: list[str] | None = None,
    ) -> list[Skill]:
        """Segment video into skills using Qwen2-VL.

        Raises:
            ValueError: If the model response cannot be parsed into a skills list.
        """
        messages = self._build_messages(video_path, episode_duration, coarse_goal, subtask_labels)
        text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        image_inputs, video_inputs = self.process_vision_info(messages)
        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        ).to(self.device)
        response = self._generate_responses(inputs)[0].strip()
        return self._parse_skills_response(response)

    def segment_skills_batch(
        self,
        video_paths: list[Path],
        episode_durations: list[float],
        coarse_goal: str | None = None,
        subtask_labels: list[str] | None = None,
    ) -> list[list[Skill]]:
        """Segment multiple videos into skills using Qwen2-VL in a batch.

        Returns one skill list per video, in input order. A video whose response cannot
        be parsed gets ``[]`` and a printed warning instead of failing the batch.
        An empty input returns ``[]`` without touching the model.
        """
        # zip(strict=True) raises ValueError when the two lists differ in length.
        all_messages = [
            self._build_messages(video_path, duration, coarse_goal, subtask_labels)
            for video_path, duration in zip(video_paths, episode_durations, strict=True)
        ]
        if not all_messages:
            # Nothing to segment: skip the processor and generate() entirely.
            return []

        all_texts = []
        all_video_inputs = []
        for messages in all_messages:
            all_texts.append(
                self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            )
            _, video_inputs = self.process_vision_info(messages)
            all_video_inputs.extend(video_inputs or [])

        inputs = self.processor(
            text=all_texts,
            videos=all_video_inputs or None,
            padding=True,
            return_tensors="pt",
        ).to(self.device)
        return self._parse_batch_responses(self._generate_responses(inputs))

    def _parse_batch_responses(self, responses: list[str]) -> list[list[Skill]]:
        """Parse one response per video; a parse failure yields ``[]`` for that video only."""
        all_skills = []
        for idx, response in enumerate(responses):
            try:
                skills = self._parse_skills_response(response.strip())
            except Exception as e:  # best-effort: one bad response must not sink the batch
                print(f"Warning: Failed to parse response for video {idx}: {e}")
                skills = []
            else:
                if not skills:
                    print(f"Warning: No skills parsed from response for video {idx}")
            all_skills.append(skills)
        return all_skills

    def _parse_skills_response(self, response: str) -> list[Skill]:
        """Parse the VLM response into Skill objects.

        Accepts a JSON object with a ``"skills"`` list or a bare JSON list. Either form may
        be wrapped in a Markdown code fence. If the unfenced text is not valid JSON, the
        outermost ``{...}`` span is tried instead.

        Raises:
            ValueError: If no skills list can be extracted from the response.
        """
        # Strip a ```json ... ``` (or plain ``` ... ```) fence if present.
        if "```json" in response:
            response = response.split("```json")[1].split("```")[0]
        elif "```" in response:
            response = response.split("```")[1].split("```")[0]
        try:
            data = json.loads(response)
        except json.JSONDecodeError:
            return self._parse_embedded_json_object(response)
        # The model may answer {"skills": [...]} or just [...]; only dicts support .get().
        if isinstance(data, dict):
            data = data.get("skills", data)
        if isinstance(data, list):
            return [Skill.from_dict(s) for s in data]
        raise ValueError(f"Could not parse skills from response: {response[:200]}...")

    def _parse_embedded_json_object(self, response: str) -> list[Skill]:
        """Fallback parser: decode the outermost ``{...}`` span of text that is not pure JSON."""
        # Greedy + DOTALL: spans from the first "{" to the last "}", across newlines.
        match = re.search(r"\{.*\}", response, re.DOTALL)
        if match is None:
            raise ValueError(f"Could not parse skills from response: {response[:200]}...")
        try:
            data = json.loads(match.group())
        except json.JSONDecodeError as e:
            raise ValueError(
                f"Could not parse JSON from VLM response (fallback failed): {response[:200]}..."
            ) from e
        # Valid JSON text starting with "{" is always an object; no "skills" key means no skills.
        skills_data = data.get("skills", [])
        if not isinstance(skills_data, list):
            raise ValueError(f"Could not parse skills from response: {response[:200]}...")
        return [Skill.from_dict(s) for s in skills_data]
# Qwen3-VL Implementation (MoE and dense checkpoints)
class Qwen3VL(BaseVLM):
    """Qwen3-VL model for skill segmentation.

    This wrapper was written for the MoE checkpoint (``Qwen/Qwen3-VL-30B-A3B-Instruct``).
    ``get_vlm`` also routes any other name containing "qwen3" here. The model class is
    therefore chosen from the checkpoint config: MoE configs load
    ``Qwen3VLMoeForConditionalGeneration`` and everything else loads
    ``Qwen3VLForConditionalGeneration``.

    Videos are referenced at ``fps=1.0`` and pre-sampled by
    ``qwen_vl_utils.process_vision_info``. The processor receives the matching frame
    metadata and does not resample. Generation is sampled (``temperature=0.7``).
    """

    def __init__(self, model_name: str, device: str = "cuda", torch_dtype: torch.dtype = torch.bfloat16):
        from qwen_vl_utils import process_vision_info
        from transformers import AutoConfig, AutoProcessor

        self.device = device
        self.model_name = model_name
        self.process_vision_info = process_vision_info
        print(f"Loading Qwen3-VL model: {model_name}...")
        # MoE checkpoints report model_type "qwen3_vl_moe"; dense ones report "qwen3_vl".
        # Loading a dense checkpoint with the MoE class (or the reverse) does not work.
        config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
        if "moe" in str(getattr(config, "model_type", "")):
            from transformers import Qwen3VLMoeForConditionalGeneration as model_cls
        else:
            from transformers import Qwen3VLForConditionalGeneration as model_cls
        self.model = model_cls.from_pretrained(
            model_name, torch_dtype=torch_dtype, device_map=device, trust_remote_code=True
        )
        self.processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
        # Batched generation with a decoder-only model needs left padding, so that every
        # prompt ends exactly where the new tokens begin. Qwen35VL sets the same thing.
        self.processor.tokenizer.padding_side = "left"
        print(f" Model loaded successfully on {device}")

    def _build_messages(
        self,
        video_path: Path,
        duration: float,
        coarse_goal: str | None,
        subtask_labels: list[str] | None,
    ) -> list[dict]:
        """Build the system-prompt + (video, duration reminder) chat messages for one episode."""
        prompt = create_skill_segmentation_prompt(coarse_goal, subtask_labels, duration_seconds=duration)
        duration_str = f"{int(duration // 60):02d}:{int(duration % 60):02d}"
        return [
            {"role": "system", "content": [{"type": "text", "text": prompt}]},
            {
                "role": "user",
                "content": [
                    {"type": "video", "video": str(video_path), "fps": 1.0},
                    {
                        "type": "text",
                        "text": f"Video duration: {duration_str} (exactly {duration:.1f} seconds). Segment into atomic skills. Last skill must end at {duration:.1f}.",
                    },
                ],
            },
        ]

    def _generate_responses(self, inputs) -> list[str]:
        """Run sampled generation on processor ``inputs``; decode only the new tokens of each row."""
        with torch.no_grad():
            generated_ids = self.model.generate(**inputs, max_new_tokens=1024, do_sample=True, temperature=0.7)
        # generate() returns prompt + completion per row, so drop each row's prompt tokens.
        new_tokens = [out[len(inp) :] for inp, out in zip(inputs.input_ids, generated_ids, strict=True)]
        return self.processor.batch_decode(new_tokens, skip_special_tokens=True)

    def segment_skills(
        self,
        video_path: Path,
        episode_duration: float,
        coarse_goal: str | None = None,
        subtask_labels: list[str] | None = None,
    ) -> list[Skill]:
        """Segment video into skills using Qwen3-VL.

        Raises:
            ValueError: If the model response cannot be parsed into a skills list.
        """
        messages = self._build_messages(video_path, episode_duration, coarse_goal, subtask_labels)
        text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        image_inputs, video_inputs = self.process_vision_info(messages, return_video_metadata=True)
        videos, video_metadata = _unpack_video_inputs(video_inputs)
        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=videos,
            # Frames are already sampled by process_vision_info, so disable the
            # processor's own sampling and pass the matching metadata through.
            videos_kwargs={
                "video_metadata": video_metadata,
                "do_sample_frames": False,
            },
            padding=True,
            return_tensors="pt",
        ).to(self.device)
        response = self._generate_responses(inputs)[0].strip()
        return self._parse_skills_response(response)

    def segment_skills_batch(
        self,
        video_paths: list[Path],
        episode_durations: list[float],
        coarse_goal: str | None = None,
        subtask_labels: list[str] | None = None,
    ) -> list[list[Skill]]:
        """Segment multiple videos into skills using Qwen3-VL in a batch.

        Returns one skill list per video, in input order. A video whose response cannot
        be parsed gets ``[]`` and a printed warning instead of failing the batch.
        An empty input returns ``[]`` without touching the model.
        """
        # zip(strict=True) raises ValueError when the two lists differ in length.
        all_messages = [
            self._build_messages(video_path, duration, coarse_goal, subtask_labels)
            for video_path, duration in zip(video_paths, episode_durations, strict=True)
        ]
        if not all_messages:
            # Nothing to segment: skip the processor and generate() entirely.
            return []

        all_texts = []
        all_video_tuples = []
        for messages in all_messages:
            all_texts.append(
                self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            )
            _, video_inputs = self.process_vision_info(messages, return_video_metadata=True)
            all_video_tuples.extend(video_inputs or [])

        videos, video_metadata = _unpack_video_inputs(all_video_tuples or None)
        inputs = self.processor(
            text=all_texts,
            videos=videos,
            videos_kwargs={
                "video_metadata": video_metadata,
                "do_sample_frames": False,
            },
            padding=True,
            return_tensors="pt",
        ).to(self.device)
        return self._parse_batch_responses(self._generate_responses(inputs))

    def _parse_batch_responses(self, responses: list[str]) -> list[list[Skill]]:
        """Parse one response per video; a parse failure yields ``[]`` for that video only."""
        all_skills = []
        for idx, response in enumerate(responses):
            try:
                skills = self._parse_skills_response(response.strip())
            except Exception as e:  # best-effort: one bad response must not sink the batch
                print(f"Warning: Failed to parse response for video {idx}: {e}")
                skills = []
            else:
                if not skills:
                    print(f"Warning: No skills parsed from response for video {idx}")
            all_skills.append(skills)
        return all_skills

    def _parse_skills_response(self, response: str) -> list[Skill]:
        """Parse the VLM response into Skill objects.

        Accepts a JSON object with a ``"skills"`` list or a bare JSON list. Either form may
        be wrapped in a Markdown code fence. If the unfenced text is not valid JSON, the
        outermost ``{...}`` span is tried instead.

        Raises:
            ValueError: If no skills list can be extracted from the response.
        """
        # Strip a ```json ... ``` (or plain ``` ... ```) fence if present.
        if "```json" in response:
            response = response.split("```json")[1].split("```")[0]
        elif "```" in response:
            response = response.split("```")[1].split("```")[0]
        try:
            data = json.loads(response)
        except json.JSONDecodeError:
            return self._parse_embedded_json_object(response)
        # The model may answer {"skills": [...]} or just [...]; only dicts support .get().
        if isinstance(data, dict):
            data = data.get("skills", data)
        if isinstance(data, list):
            return [Skill.from_dict(s) for s in data]
        raise ValueError(f"Could not parse skills from response: {response[:200]}...")

    def _parse_embedded_json_object(self, response: str) -> list[Skill]:
        """Fallback parser: decode the outermost ``{...}`` span of text that is not pure JSON."""
        # Greedy + DOTALL: spans from the first "{" to the last "}", across newlines.
        match = re.search(r"\{.*\}", response, re.DOTALL)
        if match is None:
            raise ValueError(f"Could not parse skills from response: {response[:200]}...")
        try:
            data = json.loads(match.group())
        except json.JSONDecodeError as e:
            raise ValueError(
                f"Could not parse JSON from VLM response (fallback failed): {response[:200]}..."
            ) from e
        # Valid JSON text starting with "{" is always an object; no "skills" key means no skills.
        skills_data = data.get("skills", [])
        if not isinstance(skills_data, list):
            raise ValueError(f"Could not parse skills from response: {response[:200]}...")
        return [Skill.from_dict(s) for s in skills_data]
# Qwen3.5-VL Implementation (Qwen3_5ForConditionalGeneration)
2026-03-11 22:49:06 +00:00
class Qwen35VL(BaseVLM):
    """Qwen3.5-VL model for skill segmentation (Qwen3_5ForConditionalGeneration).

    The chat template is rendered with ``enable_thinking=False``, so the model answers
    directly. Videos are referenced at ``fps=1.0`` and pre-sampled by
    ``qwen_vl_utils.process_vision_info``, and the processor receives the matching frame
    metadata without resampling. Generation is sampled (``temperature=0.7``) and capped
    at 512 new tokens.
    """

    def __init__(self, model_name: str, device: str = "cuda", torch_dtype: torch.dtype = torch.bfloat16):
        from qwen_vl_utils import process_vision_info
        from transformers import AutoProcessor, Qwen3_5ForConditionalGeneration

        self.device = device
        self.model_name = model_name
        self.process_vision_info = process_vision_info
        print(f"Loading Qwen3.5-VL model: {model_name}...")
        self.model = Qwen3_5ForConditionalGeneration.from_pretrained(
            model_name, torch_dtype=torch_dtype, device_map=device, trust_remote_code=True
        )
        self.processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
        # Left padding keeps each batched prompt flush against its generated tokens.
        self.processor.tokenizer.padding_side = "left"
        print(f" Model loaded successfully on {device}")

    def _build_messages(
        self,
        video_path: Path,
        duration: float,
        coarse_goal: str | None,
        subtask_labels: list[str] | None,
    ) -> list[dict]:
        """Build the system-prompt + (video, duration reminder) chat messages for one episode."""
        prompt = create_skill_segmentation_prompt(coarse_goal, subtask_labels, duration_seconds=duration)
        duration_str = f"{int(duration // 60):02d}:{int(duration % 60):02d}"
        return [
            {"role": "system", "content": [{"type": "text", "text": prompt}]},
            {
                "role": "user",
                "content": [
                    {"type": "video", "video": str(video_path), "fps": 1.0},
                    {
                        "type": "text",
                        "text": f"Video duration: {duration_str} (exactly {duration:.1f} seconds). Segment into atomic skills. Last skill must end at {duration:.1f}.",
                    },
                ],
            },
        ]

    def _render_prompt(self, messages: list[dict]) -> str:
        """Render chat ``messages`` into prompt text with thinking mode disabled."""
        return self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True, enable_thinking=False
        )

    def _generate_responses(self, inputs) -> list[str]:
        """Run sampled generation on processor ``inputs``; decode only the new tokens of each row."""
        with torch.no_grad():
            generated_ids = self.model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=0.7)
        # generate() returns prompt + completion per row, so drop each row's prompt tokens.
        new_tokens = [out[len(inp) :] for inp, out in zip(inputs.input_ids, generated_ids, strict=True)]
        return self.processor.batch_decode(
            new_tokens,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )

    def segment_skills(
        self,
        video_path: Path,
        episode_duration: float,
        coarse_goal: str | None = None,
        subtask_labels: list[str] | None = None,
    ) -> list[Skill]:
        """Segment video into skills using Qwen3.5-VL.

        Raises:
            ValueError: If the model response cannot be parsed into a skills list.
        """
        messages = self._build_messages(video_path, episode_duration, coarse_goal, subtask_labels)
        text = self._render_prompt(messages)
        image_inputs, video_inputs = self.process_vision_info(messages, return_video_metadata=True)
        videos, video_metadata = _unpack_video_inputs(video_inputs)
        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=videos,
            # Frames are already sampled by process_vision_info, so disable the
            # processor's own sampling and pass the matching metadata through.
            videos_kwargs={
                "video_metadata": video_metadata,
                "do_sample_frames": False,
            },
            padding=True,
            return_tensors="pt",
        ).to(self.device)
        response = self._generate_responses(inputs)[0].strip()
        return self._parse_skills_response(response)

    def segment_skills_batch(
        self,
        video_paths: list[Path],
        episode_durations: list[float],
        coarse_goal: str | None = None,
        subtask_labels: list[str] | None = None,
    ) -> list[list[Skill]]:
        """Segment multiple videos into skills using Qwen3.5-VL in a batch.

        Returns one skill list per video, in input order. A video whose response cannot
        be parsed gets ``[]`` and a printed warning instead of failing the batch.
        An empty input returns ``[]`` without touching the model.
        """
        # zip(strict=True) raises ValueError when the two lists differ in length.
        all_messages = [
            self._build_messages(video_path, duration, coarse_goal, subtask_labels)
            for video_path, duration in zip(video_paths, episode_durations, strict=True)
        ]
        if not all_messages:
            # Nothing to segment: skip the processor and generate() entirely.
            return []

        all_texts = []
        all_video_tuples = []
        for messages in all_messages:
            all_texts.append(self._render_prompt(messages))
            _, video_inputs = self.process_vision_info(messages, return_video_metadata=True)
            all_video_tuples.extend(video_inputs or [])

        videos, video_metadata = _unpack_video_inputs(all_video_tuples or None)
        inputs = self.processor(
            text=all_texts,
            videos=videos,
            videos_kwargs={
                "video_metadata": video_metadata,
                "do_sample_frames": False,
            },
            padding=True,
            return_tensors="pt",
        ).to(self.device)
        return self._parse_batch_responses(self._generate_responses(inputs))

    def _parse_batch_responses(self, responses: list[str]) -> list[list[Skill]]:
        """Parse one response per video; a parse failure yields ``[]`` for that video only."""
        all_skills = []
        for idx, response in enumerate(responses):
            try:
                skills = self._parse_skills_response(response.strip())
            except Exception as e:  # best-effort: one bad response must not sink the batch
                print(f"Warning: Failed to parse response for video {idx}: {e}")
                skills = []
            else:
                if not skills:
                    print(f"Warning: No skills parsed from response for video {idx}")
            all_skills.append(skills)
        return all_skills

    def _parse_skills_response(self, response: str) -> list[Skill]:
        """Parse the VLM response into Skill objects.

        Accepts a JSON object with a ``"skills"`` list or a bare JSON list. Either form may
        be wrapped in a Markdown code fence. If the unfenced text is not valid JSON, the
        outermost ``{...}`` span is tried instead.

        Raises:
            ValueError: If no skills list can be extracted from the response.
        """
        # Strip a ```json ... ``` (or plain ``` ... ```) fence if present.
        if "```json" in response:
            response = response.split("```json")[1].split("```")[0]
        elif "```" in response:
            response = response.split("```")[1].split("```")[0]
        try:
            data = json.loads(response)
        except json.JSONDecodeError:
            return self._parse_embedded_json_object(response)
        # The model may answer {"skills": [...]} or just [...]; only dicts support .get().
        if isinstance(data, dict):
            data = data.get("skills", data)
        if isinstance(data, list):
            return [Skill.from_dict(s) for s in data]
        raise ValueError(f"Could not parse skills from response: {response[:200]}...")

    def _parse_embedded_json_object(self, response: str) -> list[Skill]:
        """Fallback parser: decode the outermost ``{...}`` span of text that is not pure JSON."""
        # Greedy + DOTALL: spans from the first "{" to the last "}", across newlines.
        match = re.search(r"\{.*\}", response, re.DOTALL)
        if match is None:
            raise ValueError(f"Could not parse skills from response: {response[:200]}...")
        try:
            data = json.loads(match.group())
        except json.JSONDecodeError as e:
            raise ValueError(
                f"Could not parse JSON from VLM response (fallback failed): {response[:200]}..."
            ) from e
        # Valid JSON text starting with "{" is always an object; no "skills" key means no skills.
        skills_data = data.get("skills", [])
        if not isinstance(skills_data, list):
            raise ValueError(f"Could not parse skills from response: {response[:200]}...")
        return [Skill.from_dict(s) for s in skills_data]
2026-03-11 22:49:06 +00:00
2026-03-11 19:51:48 +00:00
# VLM Registry - Add new VLMs here.
# Exact Hugging Face model id -> wrapper class. get_vlm() consults this mapping first and
# only falls back to substring matching for ids that are not listed. Key order is kept
# because get_vlm() prints the keys in its error message.
VLM_REGISTRY: dict[str, type[BaseVLM]] = {
    # Qwen2-VL variants: every size shares the same wrapper.
    **dict.fromkeys(
        (
            "Qwen/Qwen2-VL-2B-Instruct",
            "Qwen/Qwen2-VL-7B-Instruct",
            "Qwen/Qwen2-VL-72B-Instruct",
        ),
        Qwen2VL,
    ),
    # Qwen3-VL variants (MoE)
    "Qwen/Qwen3-VL-30B-A3B-Instruct": Qwen3VL,
    # Qwen3.5-VL (Qwen3_5ForConditionalGeneration)
    "Qwen/Qwen3.5-27B": Qwen35VL,
    # NOTE(review): despite its Qwen3-VL name, this id is routed to Qwen35VL
    # (Qwen3_5ForConditionalGeneration). Confirm the checkpoint really loads with that class.
    "Qwen/Qwen3-VL-8B-Instruct": Qwen35VL,
}
def get_vlm(model_name: str, device: str = "cuda", torch_dtype: torch.dtype = torch.bfloat16) -> BaseVLM:
    """Instantiate the VLM wrapper that matches ``model_name``.

    Resolution order:
        1. An exact key in ``VLM_REGISTRY``.
        2. A case-insensitive substring rule: "qwen3.5" -> Qwen35VL, then
           "qwen3" -> Qwen3VL, then "qwen2" or "qwen-vl" -> Qwen2VL.

    Args:
        model_name: Hugging Face model identifier.
        device: Device to load the model on.
        torch_dtype: Data type for the model weights.

    Returns:
        An initialized VLM instance.

    Raises:
        ValueError: If neither the registry nor any substring rule matches.
    """
    registered_cls = VLM_REGISTRY.get(model_name)
    if registered_cls is not None:
        return registered_cls(model_name, device, torch_dtype)

    lowered = model_name.lower()
    # Rule order matters: "qwen3.5" also contains "qwen3", so it must be checked first.
    # NOTE(review): "qwen2" also matches Qwen2.5-VL ids. Those would be loaded through
    # Qwen2VL's Qwen2VLForConditionalGeneration — confirm that is intended.
    substring_rules = (
        (("qwen3.5",), Qwen35VL),
        (("qwen3",), Qwen3VL),
        (("qwen2", "qwen-vl"), Qwen2VL),
    )
    for needles, vlm_cls in substring_rules:
        if any(needle in lowered for needle in needles):
            return vlm_cls(model_name, device, torch_dtype)

    raise ValueError(
        f"Unknown model: {model_name}. "
        f"Supported models: {list(VLM_REGISTRY.keys())}. "
        "Or implement a new VLM class inheriting from BaseVLM."
    )