mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 19:01:28 +00:00
fixed dtype bugs
This commit is contained in:
35
src/lerobot/policies/wall_x/README.md
Normal file
35
src/lerobot/policies/wall_x/README.md
Normal file
@@ -0,0 +1,35 @@
|
||||
# WALL-OSS
|
||||
|
||||
This repository contains the Hugging Face port of **WALL-OSS**, a Vision-Language-Action model for cross-embodiment robotic control based on Qwen2.5-VL with flow matching/FAST action prediction.
|
||||
|
||||
---
|
||||
|
||||
## Model Overview
|
||||
|
||||
| Feature | Description |
|
||||
| -------------------- | ------------------------------------------------------------------------ |
|
||||
| Base Model | Qwen2.5-VL (Vision-Language Model) |
|
||||
| Action Prediction | Flow Matching (diffusion) or FAST (discrete tokens) |
|
||||
| Architecture | Mixture of Experts (MoE) with action-specific routing | |
|
||||
| Multi-Modal Inputs | Vision (images/videos), Language, Proprioception |
|
||||
---
|
||||
|
||||
## Citation
|
||||
|
||||
If you use this work, please cite:
|
||||
|
||||
```bibtex
|
||||
@article{zhai2025igniting,
|
||||
title = {Igniting VLMs Toward the Embodied Space},
|
||||
author = {Zhai, Andy and Liu, Brae and Fang, Bruno and Cai, Chalse and Ma, Ellie and Yin, Ethan and Wang, Hao and Zhou, Hugo and Wang, James and Shi, Lights and Liang, Lucy and Wang, Make and Wang, Qian and Gan, Roy and Yu, Ryan and Li, Shalfun and Liu, Starrick and Chen, Sylas and Chen, Vincent and Xu, Zach},
|
||||
journal = {arXiv preprint arXiv:2509.11766},
|
||||
year = {2025}
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## License
|
||||
|
||||
This port follows the **Apache 2.0 License**.
|
||||
|
||||
@@ -140,11 +140,11 @@ class ActionHead(nn.Module):
|
||||
# Proprioception projection
|
||||
self.propri_proj = nn.Linear(self.propri_dim * 2, self.hidden_size, bias=False)
|
||||
|
||||
def sample_time(self, batch_size, device, dtype):
|
||||
"""Sample timesteps using Beta distribution."""
|
||||
def sample_time(self, batch_size, device):
|
||||
"""Sample timesteps using Beta distribution (always in float32 for numerical stability)."""
|
||||
beta_dist = Beta(
|
||||
torch.tensor(self.beta_alpha, dtype=dtype, device=device),
|
||||
torch.tensor(self.beta_beta, dtype=dtype, device=device)
|
||||
torch.tensor(self.beta_alpha, dtype=torch.float32, device=device),
|
||||
torch.tensor(self.beta_beta, dtype=torch.float32, device=device)
|
||||
)
|
||||
sample = beta_dist.sample([batch_size])
|
||||
time = (1 - sample) * self.s
|
||||
@@ -163,30 +163,30 @@ class ActionHead(nn.Module):
|
||||
"""
|
||||
batch_size = action_chunk.shape[0]
|
||||
device = action_chunk.device
|
||||
dtype = action_chunk.dtype
|
||||
weight_dtype = self.w1.weight.dtype
|
||||
|
||||
# Add noise using flow matching
|
||||
noise = torch.randn_like(action_chunk)
|
||||
time = self.sample_time(batch_size, device, dtype)
|
||||
# Sample time outside of autocast (Beta distribution needs float32)
|
||||
time = self.sample_time(batch_size, device)
|
||||
t = time.unsqueeze(-1).unsqueeze(-1)
|
||||
|
||||
# Linear interpolation
|
||||
noisy_action = (1 - t) * noise + t * action_chunk
|
||||
flow = action_chunk - noise
|
||||
|
||||
# Generate time embeddings
|
||||
time_embed = self.time_embed(time)
|
||||
# Noise and flow computation in float32
|
||||
noise = torch.randn_like(action_chunk, dtype=torch.float32)
|
||||
action_chunk_f32 = action_chunk.to(torch.float32)
|
||||
noisy_action = (1 - t) * noise + t * action_chunk_f32
|
||||
flow = action_chunk_f32 - noise
|
||||
|
||||
# Project noisy actions
|
||||
if dof_mask is not None:
|
||||
noisy_action = torch.cat([noisy_action, dof_mask], dim=-1)
|
||||
noisy_action = torch.cat([noisy_action, dof_mask.to(torch.float32)], dim=-1)
|
||||
|
||||
noisy_action = noisy_action.to(dtype=self.w1.weight.dtype)
|
||||
# Convert to weight dtype for linear layers
|
||||
noisy_action = noisy_action.to(dtype=weight_dtype)
|
||||
action_embed = self.w1(noisy_action)
|
||||
|
||||
# Combine with time embeddings
|
||||
# Generate time embeddings and combine
|
||||
time_embed = self.time_embed(time)
|
||||
time_embed = time_embed.unsqueeze(1).repeat(1, action_embed.shape[1], 1)
|
||||
time_embed = time_embed.to(dtype=self.w2.weight.dtype)
|
||||
time_embed = time_embed.to(dtype=weight_dtype)
|
||||
|
||||
concat_embed = torch.cat([action_embed, time_embed], dim=-1)
|
||||
concat_embed = self.w2(concat_embed)
|
||||
@@ -196,14 +196,17 @@ class ActionHead(nn.Module):
|
||||
|
||||
def step(self, timestep, noisy_action, dof_mask=None):
|
||||
"""Single denoising step for inference."""
|
||||
weight_dtype = self.w1.weight.dtype
|
||||
|
||||
if dof_mask is not None:
|
||||
noisy_action = torch.cat([noisy_action, dof_mask], dim=-1)
|
||||
noisy_action = noisy_action.to(dtype=weight_dtype)
|
||||
|
||||
time_embed = self.time_embed(timestep)
|
||||
action_embed = self.w1(noisy_action)
|
||||
|
||||
time_embed = time_embed.unsqueeze(1).repeat(1, action_embed.shape[1], 1)
|
||||
time_embed = time_embed.to(device=noisy_action.device, dtype=noisy_action.dtype)
|
||||
time_embed = time_embed.to(device=noisy_action.device, dtype=weight_dtype)
|
||||
|
||||
concat_embed = torch.cat([action_embed, time_embed], dim=-1)
|
||||
concat_embed = self.w2(concat_embed)
|
||||
@@ -212,12 +215,16 @@ class ActionHead(nn.Module):
|
||||
return embed
|
||||
|
||||
def flow_loss(self, action_hidden_states, flow, dof_mask=None):
|
||||
"""Compute flow matching loss."""
|
||||
"""Compute flow matching loss (all computations in float32 for stability)."""
|
||||
# Ensure all inputs are float32
|
||||
action_hidden_states = action_hidden_states.to(torch.float32)
|
||||
flow = flow.to(torch.float32)
|
||||
|
||||
action_pred = self.action_proj_back(action_hidden_states)
|
||||
loss = F.mse_loss(action_pred, flow, reduction="none")
|
||||
|
||||
if dof_mask is not None:
|
||||
dof_mask = dof_mask.reshape(-1, dof_mask.shape[-1])
|
||||
dof_mask = dof_mask.reshape(-1, dof_mask.shape[-1]).to(torch.float32)
|
||||
loss = loss * dof_mask
|
||||
|
||||
return loss
|
||||
@@ -960,10 +967,10 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
||||
|
||||
# Compute losses if labels are provided
|
||||
if labels is not None:
|
||||
loss = 0
|
||||
loss = torch.tensor(0.0, device=hidden_states.device, dtype=torch.float32)
|
||||
|
||||
# Compute standard cross-entropy loss for language modeling
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_logits = logits[..., :-1, :].contiguous().to(torch.float32)
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||
shift_labels = shift_labels.view(-1)
|
||||
@@ -975,12 +982,12 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
||||
cross_entropy_loss = (
|
||||
_cross_entropy_loss[non_ignored_mask].mean()
|
||||
if non_ignored_mask.any()
|
||||
else torch.tensor(0.0, device=shift_logits.device)
|
||||
else torch.tensor(0.0, device=shift_logits.device, dtype=torch.float32)
|
||||
)
|
||||
|
||||
# Add cross-entropy loss to total loss if valid
|
||||
if not torch.isnan(cross_entropy_loss):
|
||||
loss += cross_entropy_loss
|
||||
loss = loss + cross_entropy_loss.to(torch.float32)
|
||||
else:
|
||||
with torch.no_grad():
|
||||
cross_entropy_loss.detach()
|
||||
@@ -989,16 +996,16 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
|
||||
action_mask = input_ids == self.action_token_id_set["action_token_id"]
|
||||
if action_mask.any():
|
||||
action_hidden_states = hidden_states[action_mask].to(torch.float32)
|
||||
flow = flow.reshape(-1, flow.shape[-1])
|
||||
flow = flow.reshape(-1, flow.shape[-1]).to(torch.float32)
|
||||
_flow_loss = self.action_preprocessor.flow_loss(
|
||||
action_hidden_states, flow, dof_mask
|
||||
)
|
||||
if isinstance(_flow_loss, torch.Tensor):
|
||||
flow_loss = _flow_loss.mean()
|
||||
if loss is not None:
|
||||
loss += self.flow_loss_weight * flow_loss
|
||||
loss = loss + self.flow_loss_weight * flow_loss.to(torch.float32)
|
||||
else:
|
||||
loss = self.flow_loss_weight * flow_loss
|
||||
loss = self.flow_loss_weight * flow_loss.to(torch.float32)
|
||||
_flow_loss = _flow_loss.view(
|
||||
dof_mask.shape[0], dof_mask.shape[1], dof_mask.shape[2]
|
||||
)
|
||||
@@ -1828,7 +1835,6 @@ class WallXPolicy(PreTrainedPolicy):
|
||||
# Initialize the wall-x model
|
||||
self.model = Qwen2_5_VLMoEForAction.from_pretrained(config.pretrained_name_or_path)
|
||||
self.model.to(config.device)
|
||||
# Convert to bfloat16 for Flash Attention compatibility
|
||||
self.model.to_bfloat16_for_selected_params()
|
||||
|
||||
self.reset()
|
||||
|
||||
Reference in New Issue
Block a user