mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 19:01:28 +00:00
Compare commits
3 Commits
feat/add-m
...
feat/add-m
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6986b46fa8 | ||
|
|
b1d72ac29c | ||
|
|
cffd545527 |
@@ -11,6 +11,7 @@ import os
|
||||
import signal
|
||||
import statistics
|
||||
from contextlib import contextmanager
|
||||
from copy import deepcopy
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
@@ -19,18 +20,18 @@ import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.configs.types import FeatureType
|
||||
from lerobot.policies.factory import get_policy_class
|
||||
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
|
||||
|
||||
class TimeoutException:
|
||||
class TimeoutExceptionError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
@contextmanager
|
||||
def timeout(seconds):
|
||||
def signal_handler(signum, frame):
|
||||
raise TimeoutException(f"Timed out after {seconds} seconds")
|
||||
raise TimeoutExceptionError(f"Timed out after {seconds} seconds")
|
||||
|
||||
# On Windows, signal is not available, so we can't use this timeout mechanism
|
||||
if not hasattr(signal, "SIGALRM"):
|
||||
@@ -84,12 +85,12 @@ def generate_dummy_observation(input_features: dict, device: str = "cpu") -> dic
|
||||
# Default: random normal for unknown types
|
||||
dummy_obs[key] = torch.randn(shape, dtype=torch.float32, device=device)
|
||||
|
||||
# Add batch dimension
|
||||
for key in dummy_obs:
|
||||
dummy_obs[key] = dummy_obs[key].unsqueeze(0)
|
||||
# # Add batch dimension
|
||||
# for key in dummy_obs:
|
||||
# dummy_obs[key] = dummy_obs[key].unsqueeze(0)
|
||||
|
||||
# Add task string for language-conditioned policies
|
||||
dummy_obs["task"] = ""
|
||||
dummy_obs["task"] = " this is a dummy task"
|
||||
|
||||
return dummy_obs
|
||||
|
||||
@@ -151,7 +152,9 @@ def main():
|
||||
policy_class = get_policy_class(args.policy_type)
|
||||
policy: PreTrainedPolicy = policy_class.from_pretrained(args.policy_id)
|
||||
policy.eval()
|
||||
policy.to(device)
|
||||
policy.to(device, torch.float32)
|
||||
policy.config.device = device
|
||||
preprocessor, postprocessor = make_pre_post_processors(policy.config)
|
||||
|
||||
print(f"Policy loaded on {device}")
|
||||
print(f"Input features: {list(policy.config.input_features.keys())}")
|
||||
@@ -159,7 +162,7 @@ def main():
|
||||
|
||||
# Generate dummy observation based on policy input features
|
||||
dummy_observation = generate_dummy_observation(policy.config.input_features, device)
|
||||
dummy_observation["task"] = ""
|
||||
dummy_observation["task"] = "this is a dummy task"
|
||||
|
||||
# Helper to sync for fair timings
|
||||
def _sync(dev_=device):
|
||||
@@ -175,8 +178,15 @@ def main():
|
||||
print("Warming up...")
|
||||
with torch.no_grad():
|
||||
policy.reset()
|
||||
preprocessor.reset()
|
||||
postprocessor.reset()
|
||||
orginal_dummy_observation = deepcopy(dummy_observation)
|
||||
for _ in range(args.warmup):
|
||||
_ = policy.select_action(dummy_observation)
|
||||
dummy_observation_model = deepcopy(orginal_dummy_observation)
|
||||
dummy_observation_model = preprocessor(dummy_observation_model)
|
||||
action_model = policy.select_action(dummy_observation_model)
|
||||
_ = postprocessor(action_model)
|
||||
policy.reset()
|
||||
_sync()
|
||||
|
||||
# Memory footprint before timing
|
||||
@@ -193,20 +203,25 @@ def main():
|
||||
start_events = []
|
||||
end_events = []
|
||||
timeout_count = 0
|
||||
orginal_dummy_observation = deepcopy(dummy_observation)
|
||||
|
||||
with torch.no_grad():
|
||||
for forward in tqdm(range(args.num_samples), desc="Trials"):
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
try:
|
||||
dummy_observation_model = deepcopy(orginal_dummy_observation)
|
||||
dummy_observation_model = preprocessor(dummy_observation)
|
||||
with timeout(args.timeout):
|
||||
start_event.record()
|
||||
_ = policy.select_action(dummy_observation)
|
||||
action_model = policy.select_action(dummy_observation_model)
|
||||
end_event.record()
|
||||
_ = postprocessor(action_model)
|
||||
policy.reset()
|
||||
|
||||
start_events.append(start_event)
|
||||
end_events.append(end_event)
|
||||
except TimeoutException:
|
||||
except TimeoutExceptionError:
|
||||
timeout_count += 1
|
||||
# Add placeholder for timeout
|
||||
start_events.append(None)
|
||||
@@ -219,7 +234,8 @@ def main():
|
||||
per_forward_ms = []
|
||||
for start_event, end_event in zip(start_events, end_events, strict=True):
|
||||
if start_event is None:
|
||||
per_forward_ms.append(args.timeout * 1000)
|
||||
# per_forward_ms.append(args.timeout * 1000)
|
||||
continue
|
||||
else:
|
||||
per_forward_ms.append(start_event.elapsed_time(end_event))
|
||||
|
||||
@@ -236,20 +252,25 @@ def main():
|
||||
with torch.no_grad():
|
||||
for sample in tqdm(range(args.num_samples), desc="Samples"):
|
||||
try:
|
||||
dummy_observation_model = deepcopy(orginal_dummy_observation)
|
||||
dummy_observation_model = preprocessor(dummy_observation_model)
|
||||
with timeout(args.timeout):
|
||||
start_time = time.perf_counter()
|
||||
_ = policy.select_action(dummy_observation)
|
||||
action_model = policy.select_action(dummy_observation_model)
|
||||
end_time = time.perf_counter()
|
||||
policy.reset()
|
||||
|
||||
per_forward_ms.append((end_time - start_time) * 1000) # Convert to ms
|
||||
except TimeoutException:
|
||||
except TimeoutExceptionError:
|
||||
timeout_count += 1
|
||||
per_forward_ms.append(args.timeout * 1000)
|
||||
# per_forward_ms.append(args.timeout * 1000)
|
||||
|
||||
print(f"\n[!] Timeout on sample {sample + 1}")
|
||||
continue
|
||||
|
||||
if timeout_count > 0:
|
||||
print(f"[!] {timeout_count} inference passes timed out (>{args.timeout * 1000:.1f}ms)")
|
||||
print(f"Timeout percentage: {timeout_count / args.num_samples * 100:.1f}%")
|
||||
|
||||
# Memory footprint after timing
|
||||
rss_after = process.memory_info().rss
|
||||
@@ -355,6 +376,7 @@ Benchmark completed successfully at {datetime.now().strftime("%Y-%m-%d %H:%M:%S"
|
||||
print(f"Device: {device}")
|
||||
print(f"Samples: {args.num_samples} | Warmup: {args.warmup}")
|
||||
print(f"Model params: {num_params:,}")
|
||||
print(f"Timeout percentage: {timeout_count / args.num_samples * 100:.1f}%")
|
||||
|
||||
print("\nLatency per forward (ms):")
|
||||
print(f" mean: {mean_ms:.3f} std: {std_ms:.3f}")
|
||||
|
||||
@@ -493,7 +493,7 @@ class PI0FlowMatching(nn.Module):
|
||||
img_mask,
|
||||
) in zip(images, img_masks, strict=False):
|
||||
img_emb = self.paligemma_with_expert.embed_image(img)
|
||||
img_emb = img_emb.to(dtype=torch.bfloat16)
|
||||
img_emb = img_emb.to(dtype=torch.float32)
|
||||
|
||||
# Normalize image embeddings
|
||||
img_emb_dim = img_emb.shape[-1]
|
||||
@@ -536,7 +536,7 @@ class PI0FlowMatching(nn.Module):
|
||||
|
||||
# Embed state
|
||||
state_emb = self.state_proj(state)
|
||||
state_emb = state_emb.to(dtype=torch.bfloat16)
|
||||
state_emb = state_emb.to(dtype=torch.float32)
|
||||
embs.append(state_emb[:, None, :])
|
||||
bsize = state_emb.shape[0]
|
||||
dtype = state_emb.dtype
|
||||
|
||||
@@ -202,7 +202,7 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
|
||||
self.paligemma.eval()
|
||||
|
||||
def to_bfloat16_like_physical_intelligence(self):
|
||||
self.paligemma = self.paligemma.to(dtype=torch.bfloat16)
|
||||
self.paligemma = self.paligemma.to(dtype=torch.float32)
|
||||
|
||||
params_to_change_dtype = [
|
||||
"language_model.model.layers",
|
||||
@@ -212,7 +212,7 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
|
||||
]
|
||||
for name, param in self.named_parameters():
|
||||
if any(selector in name for selector in params_to_change_dtype):
|
||||
param.data = param.data.to(dtype=torch.bfloat16)
|
||||
param.data = param.data.to(dtype=torch.float32)
|
||||
|
||||
def embed_image(self, image: torch.Tensor):
|
||||
# Handle different transformers versions
|
||||
@@ -262,7 +262,7 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
|
||||
|
||||
hidden_states = hidden_states.to(dtype=torch.bfloat16)
|
||||
hidden_states = hidden_states.to(dtype=torch.float32)
|
||||
query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape)
|
||||
key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape)
|
||||
value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape)
|
||||
@@ -303,7 +303,7 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
|
||||
att_output = attention_interface(
|
||||
attention_mask, batch_size, head_dim, query_states, key_states, value_states
|
||||
)
|
||||
att_output = att_output.to(dtype=torch.bfloat16)
|
||||
att_output = att_output.to(dtype=torch.float32)
|
||||
|
||||
# first part of att_output is prefix (up to sequence length, [:, 0:prefix_seq_len])
|
||||
outputs_embeds = []
|
||||
|
||||
Reference in New Issue
Block a user