mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 19:01:28 +00:00
working on qwen
This commit is contained in:
@@ -173,10 +173,7 @@ def create_skill_segmentation_prompt(coarse_goal: str | None = None) -> str:
|
||||
4. **Natural Language**: Use clear, descriptive names for each skill
|
||||
5. **Timestamps**: Use seconds (float) for all timestamps
|
||||
|
||||
# Analysis Steps
|
||||
1. First, describe what you observe in the video chronologically
|
||||
2. Identify distinct motion phases and state changes
|
||||
3. Determine precise boundaries based on visual state transitions
|
||||
|
||||
|
||||
# Output Format
|
||||
After your analysis, output ONLY valid JSON with this exact structure:
|
||||
@@ -393,8 +390,10 @@ class SmolVLM(BaseVLM):
|
||||
|
||||
self.processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
|
||||
self.model = AutoModelForVision2Seq.from_pretrained(
|
||||
model_name, torch_dtype=torch_dtype, device_map=device, trust_remote_code=True
|
||||
)
|
||||
model_name,
|
||||
torch_dtype=torch_dtype,
|
||||
# _attn_implementation="flash_attention_2" if device == "cuda" else "eager",
|
||||
).to(device)
|
||||
|
||||
self.console.print(f"[green]✓ Model loaded successfully on {device}[/green]")
|
||||
|
||||
@@ -413,15 +412,21 @@ class SmolVLM(BaseVLM):
|
||||
prompt = create_skill_segmentation_prompt(coarse_goal)
|
||||
duration_str = f"{int(episode_duration // 60):02d}:{int(episode_duration % 60):02d}"
|
||||
|
||||
# Create message with sampled frames
|
||||
content = [{"type": "text", "text": prompt}]
|
||||
|
||||
# Add frames as images (sample up to 8 frames to avoid context overflow)
|
||||
# Sample frames (up to 8 frames to avoid context overflow)
|
||||
frame_indices = self._select_frame_indices(len(frames), max_frames=8)
|
||||
for idx in frame_indices:
|
||||
frame = frames[idx]
|
||||
pil_image = PIL.Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
|
||||
content.append({"type": "image", "image": pil_image})
|
||||
|
||||
# Convert frames to PIL images
|
||||
pil_images = [
|
||||
PIL.Image.fromarray(cv2.cvtColor(frames[idx], cv2.COLOR_BGR2RGB))
|
||||
for idx in frame_indices
|
||||
]
|
||||
|
||||
# Create message content with image placeholders
|
||||
content = [{"type": "text", "text": prompt}]
|
||||
|
||||
# Add image placeholders (one for each frame)
|
||||
for _ in frame_indices:
|
||||
content.append({"type": "image"})
|
||||
|
||||
content.append(
|
||||
{
|
||||
@@ -432,17 +437,18 @@ class SmolVLM(BaseVLM):
|
||||
|
||||
messages = [{"role": "user", "content": content}]
|
||||
|
||||
inputs = self.processor(
|
||||
text=self.processor.apply_chat_template(messages, add_generation_prompt=True),
|
||||
images=[PIL.Image.fromarray(cv2.cvtColor(frames[i], cv2.COLOR_BGR2RGB)) for i in frame_indices],
|
||||
return_tensors="pt",
|
||||
).to(self.device)
|
||||
# Apply chat template to get the prompt
|
||||
prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True)
|
||||
|
||||
# Process inputs with both text and images
|
||||
inputs = self.processor(text=prompt, images=pil_images, return_tensors="pt")
|
||||
inputs = inputs.to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
generated_ids = self.model.generate(**inputs, max_new_tokens=1024, do_sample=True, temperature=0.7)
|
||||
|
||||
response = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
|
||||
|
||||
|
||||
return self._parse_skills_response(response, episode_duration)
|
||||
|
||||
def _extract_frames(self, video_path: Path, target_fps: int = 1) -> list:
|
||||
@@ -481,6 +487,7 @@ class SmolVLM(BaseVLM):
|
||||
try:
|
||||
data = json.loads(response)
|
||||
skills_data = data.get("skills", data)
|
||||
breakpoint()
|
||||
if isinstance(skills_data, list):
|
||||
return [Skill.from_dict(s) for s in skills_data]
|
||||
except json.JSONDecodeError:
|
||||
@@ -683,30 +690,32 @@ class SkillAnnotator:
|
||||
# Get coarse task description if available
|
||||
coarse_goal = self._get_coarse_goal(dataset)
|
||||
|
||||
with Progress(
|
||||
SpinnerColumn(),
|
||||
TextColumn("[progress.description]{task.description}"),
|
||||
console=self.console,
|
||||
) as progress:
|
||||
task = progress.add_task(f"Annotating {len(episode_indices)} episodes...", total=len(episode_indices))
|
||||
# with Progress(
|
||||
# SpinnerColumn(),
|
||||
# TextColumn("[progress.description]{task.description}"),
|
||||
# console=self.console,
|
||||
# ) as progress:
|
||||
# task = progress.add_task(f"Annotating {len(episode_indices)} episodes...", total=len(episode_indices))
|
||||
print(f"Annotating {len(episode_indices)} episodes...")
|
||||
|
||||
for ep_idx in episode_indices:
|
||||
progress.update(task, description=f"Processing episode {ep_idx}...")
|
||||
for ep_idx in episode_indices:
|
||||
# progress.update(task, description=f"Processing episode {ep_idx}...")
|
||||
print(f"Processing episode {ep_idx}...")
|
||||
|
||||
try:
|
||||
skills = self._annotate_episode(dataset, ep_idx, video_key, coarse_goal)
|
||||
annotations[ep_idx] = EpisodeSkills(
|
||||
episode_index=ep_idx,
|
||||
description=coarse_goal,
|
||||
skills=skills,
|
||||
)
|
||||
self.console.print(
|
||||
f"[green]✓ Episode {ep_idx}: {len(skills)} skills identified[/green]"
|
||||
)
|
||||
except Exception as e:
|
||||
self.console.print(f"[red]✗ Episode {ep_idx} failed: {e}[/red]")
|
||||
try:
|
||||
skills = self._annotate_episode(dataset, ep_idx, video_key, coarse_goal)
|
||||
annotations[ep_idx] = EpisodeSkills(
|
||||
episode_index=ep_idx,
|
||||
description=coarse_goal,
|
||||
skills=skills,
|
||||
)
|
||||
self.console.print(
|
||||
f"[green]✓ Episode {ep_idx}: {len(skills)} skills identified[/green]"
|
||||
)
|
||||
except Exception as e:
|
||||
self.console.print(f"[red]✗ Episode {ep_idx} failed: {e}[/red]")
|
||||
|
||||
progress.advance(task)
|
||||
# progress.advance(task)
|
||||
|
||||
return annotations
|
||||
|
||||
|
||||
Reference in New Issue
Block a user