woking on qwen

This commit is contained in:
Jade Choghari
2025-12-08 14:03:47 +00:00
parent a811945336
commit 3568df8a35
3 changed files with 853 additions and 42 deletions

View File

@@ -173,10 +173,7 @@ def create_skill_segmentation_prompt(coarse_goal: str | None = None) -> str:
4. **Natural Language**: Use clear, descriptive names for each skill
5. **Timestamps**: Use seconds (float) for all timestamps
# Analysis Steps
1. First, describe what you observe in the video chronologically
2. Identify distinct motion phases and state changes
3. Determine precise boundaries based on visual state transitions
# Output Format
After your analysis, output ONLY valid JSON with this exact structure:
@@ -393,8 +390,10 @@ class SmolVLM(BaseVLM):
self.processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
self.model = AutoModelForVision2Seq.from_pretrained(
model_name, torch_dtype=torch_dtype, device_map=device, trust_remote_code=True
)
model_name,
torch_dtype=torch_dtype,
# _attn_implementation="flash_attention_2" if device == "cuda" else "eager",
).to(device)
self.console.print(f"[green]✓ Model loaded successfully on {device}[/green]")
@@ -413,15 +412,21 @@ class SmolVLM(BaseVLM):
prompt = create_skill_segmentation_prompt(coarse_goal)
duration_str = f"{int(episode_duration // 60):02d}:{int(episode_duration % 60):02d}"
# Create message with sampled frames
content = [{"type": "text", "text": prompt}]
# Add frames as images (sample up to 8 frames to avoid context overflow)
# Sample frames (up to 8 frames to avoid context overflow)
frame_indices = self._select_frame_indices(len(frames), max_frames=8)
for idx in frame_indices:
frame = frames[idx]
pil_image = PIL.Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
content.append({"type": "image", "image": pil_image})
# Convert frames to PIL images
pil_images = [
PIL.Image.fromarray(cv2.cvtColor(frames[idx], cv2.COLOR_BGR2RGB))
for idx in frame_indices
]
# Create message content with image placeholders
content = [{"type": "text", "text": prompt}]
# Add image placeholders (one for each frame)
for _ in frame_indices:
content.append({"type": "image"})
content.append(
{
@@ -432,17 +437,18 @@ class SmolVLM(BaseVLM):
messages = [{"role": "user", "content": content}]
inputs = self.processor(
text=self.processor.apply_chat_template(messages, add_generation_prompt=True),
images=[PIL.Image.fromarray(cv2.cvtColor(frames[i], cv2.COLOR_BGR2RGB)) for i in frame_indices],
return_tensors="pt",
).to(self.device)
# Apply chat template to get the prompt
prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True)
# Process inputs with both text and images
inputs = self.processor(text=prompt, images=pil_images, return_tensors="pt")
inputs = inputs.to(self.device)
with torch.no_grad():
generated_ids = self.model.generate(**inputs, max_new_tokens=1024, do_sample=True, temperature=0.7)
response = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
return self._parse_skills_response(response, episode_duration)
def _extract_frames(self, video_path: Path, target_fps: int = 1) -> list:
@@ -481,6 +487,7 @@ class SmolVLM(BaseVLM):
try:
data = json.loads(response)
skills_data = data.get("skills", data)
breakpoint()
if isinstance(skills_data, list):
return [Skill.from_dict(s) for s in skills_data]
except json.JSONDecodeError:
@@ -683,30 +690,32 @@ class SkillAnnotator:
# Get coarse task description if available
coarse_goal = self._get_coarse_goal(dataset)
with Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
console=self.console,
) as progress:
task = progress.add_task(f"Annotating {len(episode_indices)} episodes...", total=len(episode_indices))
# with Progress(
# SpinnerColumn(),
# TextColumn("[progress.description]{task.description}"),
# console=self.console,
# ) as progress:
# task = progress.add_task(f"Annotating {len(episode_indices)} episodes...", total=len(episode_indices))
print(f"Annotating {len(episode_indices)} episodes...")
for ep_idx in episode_indices:
progress.update(task, description=f"Processing episode {ep_idx}...")
for ep_idx in episode_indices:
# progress.update(task, description=f"Processing episode {ep_idx}...")
print(f"Processing episode {ep_idx}...")
try:
skills = self._annotate_episode(dataset, ep_idx, video_key, coarse_goal)
annotations[ep_idx] = EpisodeSkills(
episode_index=ep_idx,
description=coarse_goal,
skills=skills,
)
self.console.print(
f"[green]✓ Episode {ep_idx}: {len(skills)} skills identified[/green]"
)
except Exception as e:
self.console.print(f"[red]✗ Episode {ep_idx} failed: {e}[/red]")
try:
skills = self._annotate_episode(dataset, ep_idx, video_key, coarse_goal)
annotations[ep_idx] = EpisodeSkills(
episode_index=ep_idx,
description=coarse_goal,
skills=skills,
)
self.console.print(
f"[green]✓ Episode {ep_idx}: {len(skills)} skills identified[/green]"
)
except Exception as e:
self.console.print(f"[red]✗ Episode {ep_idx} failed: {e}[/red]")
progress.advance(task)
# progress.advance(task)
return annotations