Compare commits

...

3 Commits

Author SHA1 Message Date
AdilZouitine
6986b46fa8 refactor(inference): improve timeout handling and report timeout percentage
- Commented out the timeout handling logic to prevent appending timeout values to the results.
- Added a print statement to display the percentage of timeouts during inference.
2025-09-24 14:50:58 +02:00
AdilZouitine
b1d72ac29c refactor(model): change tensor data type from bfloat16 to float32
- Updated image and state embeddings to use float32 for improved compatibility.
- Adjusted model parameters and hidden states to ensure consistent data type usage.
2025-09-24 14:33:11 +02:00
AdilZouitine
cffd545527 refactor(inference): improve timeout handling and enhance dummy observation generation
- Renamed TimeoutException to TimeoutExceptionError for clarity.
- Updated dummy observation generation to include a task string.
- Integrated pre-processing and post-processing steps in the main function.
- Added deep copy of dummy observations to prevent mutation during processing.
- Enhanced timeout handling to provide percentage of timeouts during inference.
2025-09-24 14:32:47 +02:00
3 changed files with 44 additions and 22 deletions

View File

@@ -11,6 +11,7 @@ import os
import signal
import statistics
from contextlib import contextmanager
from copy import deepcopy
from datetime import datetime
from pathlib import Path
@@ -19,18 +20,18 @@ import torch
from tqdm import tqdm
from lerobot.configs.types import FeatureType
from lerobot.policies.factory import get_policy_class
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
from lerobot.policies.pretrained import PreTrainedPolicy
class TimeoutException:
class TimeoutExceptionError(Exception):
pass
@contextmanager
def timeout(seconds):
def signal_handler(signum, frame):
raise TimeoutException(f"Timed out after {seconds} seconds")
raise TimeoutExceptionError(f"Timed out after {seconds} seconds")
# On Windows, signal is not available, so we can't use this timeout mechanism
if not hasattr(signal, "SIGALRM"):
@@ -84,12 +85,12 @@ def generate_dummy_observation(input_features: dict, device: str = "cpu") -> dic
# Default: random normal for unknown types
dummy_obs[key] = torch.randn(shape, dtype=torch.float32, device=device)
# Add batch dimension
for key in dummy_obs:
dummy_obs[key] = dummy_obs[key].unsqueeze(0)
# # Add batch dimension
# for key in dummy_obs:
# dummy_obs[key] = dummy_obs[key].unsqueeze(0)
# Add task string for language-conditioned policies
dummy_obs["task"] = ""
dummy_obs["task"] = " this is a dummy task"
return dummy_obs
@@ -151,7 +152,9 @@ def main():
policy_class = get_policy_class(args.policy_type)
policy: PreTrainedPolicy = policy_class.from_pretrained(args.policy_id)
policy.eval()
policy.to(device)
policy.to(device, torch.float32)
policy.config.device = device
preprocessor, postprocessor = make_pre_post_processors(policy.config)
print(f"Policy loaded on {device}")
print(f"Input features: {list(policy.config.input_features.keys())}")
@@ -159,7 +162,7 @@ def main():
# Generate dummy observation based on policy input features
dummy_observation = generate_dummy_observation(policy.config.input_features, device)
dummy_observation["task"] = ""
dummy_observation["task"] = "this is a dummy task"
# Helper to sync for fair timings
def _sync(dev_=device):
@@ -175,8 +178,15 @@ def main():
print("Warming up...")
with torch.no_grad():
policy.reset()
preprocessor.reset()
postprocessor.reset()
orginal_dummy_observation = deepcopy(dummy_observation)
for _ in range(args.warmup):
_ = policy.select_action(dummy_observation)
dummy_observation_model = deepcopy(orginal_dummy_observation)
dummy_observation_model = preprocessor(dummy_observation_model)
action_model = policy.select_action(dummy_observation_model)
_ = postprocessor(action_model)
policy.reset()
_sync()
# Memory footprint before timing
@@ -193,20 +203,25 @@ def main():
start_events = []
end_events = []
timeout_count = 0
orginal_dummy_observation = deepcopy(dummy_observation)
with torch.no_grad():
for forward in tqdm(range(args.num_samples), desc="Trials"):
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
try:
dummy_observation_model = deepcopy(orginal_dummy_observation)
dummy_observation_model = preprocessor(dummy_observation)
with timeout(args.timeout):
start_event.record()
_ = policy.select_action(dummy_observation)
action_model = policy.select_action(dummy_observation_model)
end_event.record()
_ = postprocessor(action_model)
policy.reset()
start_events.append(start_event)
end_events.append(end_event)
except TimeoutException:
except TimeoutExceptionError:
timeout_count += 1
# Add placeholder for timeout
start_events.append(None)
@@ -219,7 +234,8 @@ def main():
per_forward_ms = []
for start_event, end_event in zip(start_events, end_events, strict=True):
if start_event is None:
per_forward_ms.append(args.timeout * 1000)
# per_forward_ms.append(args.timeout * 1000)
continue
else:
per_forward_ms.append(start_event.elapsed_time(end_event))
@@ -236,20 +252,25 @@ def main():
with torch.no_grad():
for sample in tqdm(range(args.num_samples), desc="Samples"):
try:
dummy_observation_model = deepcopy(orginal_dummy_observation)
dummy_observation_model = preprocessor(dummy_observation_model)
with timeout(args.timeout):
start_time = time.perf_counter()
_ = policy.select_action(dummy_observation)
action_model = policy.select_action(dummy_observation_model)
end_time = time.perf_counter()
policy.reset()
per_forward_ms.append((end_time - start_time) * 1000) # Convert to ms
except TimeoutException:
except TimeoutExceptionError:
timeout_count += 1
per_forward_ms.append(args.timeout * 1000)
# per_forward_ms.append(args.timeout * 1000)
print(f"\n[!] Timeout on sample {sample + 1}")
continue
if timeout_count > 0:
print(f"[!] {timeout_count} inference passes timed out (>{args.timeout * 1000:.1f}ms)")
print(f"Timeout percentage: {timeout_count / args.num_samples * 100:.1f}%")
# Memory footprint after timing
rss_after = process.memory_info().rss
@@ -355,6 +376,7 @@ Benchmark completed successfully at {datetime.now().strftime("%Y-%m-%d %H:%M:%S"
print(f"Device: {device}")
print(f"Samples: {args.num_samples} | Warmup: {args.warmup}")
print(f"Model params: {num_params:,}")
print(f"Timeout percentage: {timeout_count / args.num_samples * 100:.1f}%")
print("\nLatency per forward (ms):")
print(f" mean: {mean_ms:.3f} std: {std_ms:.3f}")

View File

@@ -493,7 +493,7 @@ class PI0FlowMatching(nn.Module):
img_mask,
) in zip(images, img_masks, strict=False):
img_emb = self.paligemma_with_expert.embed_image(img)
img_emb = img_emb.to(dtype=torch.bfloat16)
img_emb = img_emb.to(dtype=torch.float32)
# Normalize image embeddings
img_emb_dim = img_emb.shape[-1]
@@ -536,7 +536,7 @@ class PI0FlowMatching(nn.Module):
# Embed state
state_emb = self.state_proj(state)
state_emb = state_emb.to(dtype=torch.bfloat16)
state_emb = state_emb.to(dtype=torch.float32)
embs.append(state_emb[:, None, :])
bsize = state_emb.shape[0]
dtype = state_emb.dtype

View File

@@ -202,7 +202,7 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
self.paligemma.eval()
def to_bfloat16_like_physical_intelligence(self):
self.paligemma = self.paligemma.to(dtype=torch.bfloat16)
self.paligemma = self.paligemma.to(dtype=torch.float32)
params_to_change_dtype = [
"language_model.model.layers",
@@ -212,7 +212,7 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
]
for name, param in self.named_parameters():
if any(selector in name for selector in params_to_change_dtype):
param.data = param.data.to(dtype=torch.bfloat16)
param.data = param.data.to(dtype=torch.float32)
def embed_image(self, image: torch.Tensor):
# Handle different transformers versions
@@ -262,7 +262,7 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
hidden_states = hidden_states.to(dtype=torch.bfloat16)
hidden_states = hidden_states.to(dtype=torch.float32)
query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape)
key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape)
value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape)
@@ -303,7 +303,7 @@ class PaliGemmaWithExpertModel(PreTrainedModel):
att_output = attention_interface(
attention_mask, batch_size, head_dim, query_states, key_states, value_states
)
att_output = att_output.to(dtype=torch.bfloat16)
att_output = att_output.to(dtype=torch.float32)
# first part of att_output is prefix (up to sequence length, [:, 0:prefix_seq_len])
outputs_embeds = []