fixed dtype bugs

This commit is contained in:
Geoffrey19
2025-12-04 17:15:22 +08:00
parent b4a7586b27
commit 56d20caa1e
2 changed files with 70 additions and 29 deletions

View File

@@ -0,0 +1,35 @@
# WALL-OSS
This repository contains the Hugging Face port of **WALL-OSS**, a Vision-Language-Action model for cross-embodiment robotic control based on Qwen2.5-VL with flow matching/FAST action prediction.
---
## Model Overview
| Feature | Description |
| -------------------- | ------------------------------------------------------------------------ |
| Base Model | Qwen2.5-VL (Vision-Language Model) |
| Action Prediction | Flow Matching (diffusion) or FAST (discrete tokens) |
| Architecture | Mixture of Experts (MoE) with action-specific routing | |
| Multi-Modal Inputs | Vision (images/videos), Language, Proprioception |
---
## Citation
If you use this work, please cite:
```bibtex
@article{zhai2025igniting,
title = {Igniting VLMs Toward the Embodied Space},
author = {Zhai, Andy and Liu, Brae and Fang, Bruno and Cai, Chalse and Ma, Ellie and Yin, Ethan and Wang, Hao and Zhou, Hugo and Wang, James and Shi, Lights and Liang, Lucy and Wang, Make and Wang, Qian and Gan, Roy and Yu, Ryan and Li, Shalfun and Liu, Starrick and Chen, Sylas and Chen, Vincent and Xu, Zach},
journal = {arXiv preprint arXiv:2509.11766},
year = {2025}
}
```
---
## License
This port follows the **Apache 2.0 License**.

View File

@@ -140,11 +140,11 @@ class ActionHead(nn.Module):
# Proprioception projection
self.propri_proj = nn.Linear(self.propri_dim * 2, self.hidden_size, bias=False)
def sample_time(self, batch_size, device, dtype):
"""Sample timesteps using Beta distribution."""
def sample_time(self, batch_size, device):
"""Sample timesteps using Beta distribution (always in float32 for numerical stability)."""
beta_dist = Beta(
torch.tensor(self.beta_alpha, dtype=dtype, device=device),
torch.tensor(self.beta_beta, dtype=dtype, device=device)
torch.tensor(self.beta_alpha, dtype=torch.float32, device=device),
torch.tensor(self.beta_beta, dtype=torch.float32, device=device)
)
sample = beta_dist.sample([batch_size])
time = (1 - sample) * self.s
@@ -163,30 +163,30 @@ class ActionHead(nn.Module):
"""
batch_size = action_chunk.shape[0]
device = action_chunk.device
dtype = action_chunk.dtype
weight_dtype = self.w1.weight.dtype
# Add noise using flow matching
noise = torch.randn_like(action_chunk)
time = self.sample_time(batch_size, device, dtype)
# Sample time outside of autocast (Beta distribution needs float32)
time = self.sample_time(batch_size, device)
t = time.unsqueeze(-1).unsqueeze(-1)
# Linear interpolation
noisy_action = (1 - t) * noise + t * action_chunk
flow = action_chunk - noise
# Generate time embeddings
time_embed = self.time_embed(time)
# Noise and flow computation in float32
noise = torch.randn_like(action_chunk, dtype=torch.float32)
action_chunk_f32 = action_chunk.to(torch.float32)
noisy_action = (1 - t) * noise + t * action_chunk_f32
flow = action_chunk_f32 - noise
# Project noisy actions
if dof_mask is not None:
noisy_action = torch.cat([noisy_action, dof_mask], dim=-1)
noisy_action = torch.cat([noisy_action, dof_mask.to(torch.float32)], dim=-1)
noisy_action = noisy_action.to(dtype=self.w1.weight.dtype)
# Convert to weight dtype for linear layers
noisy_action = noisy_action.to(dtype=weight_dtype)
action_embed = self.w1(noisy_action)
# Combine with time embeddings
# Generate time embeddings and combine
time_embed = self.time_embed(time)
time_embed = time_embed.unsqueeze(1).repeat(1, action_embed.shape[1], 1)
time_embed = time_embed.to(dtype=self.w2.weight.dtype)
time_embed = time_embed.to(dtype=weight_dtype)
concat_embed = torch.cat([action_embed, time_embed], dim=-1)
concat_embed = self.w2(concat_embed)
@@ -196,14 +196,17 @@ class ActionHead(nn.Module):
def step(self, timestep, noisy_action, dof_mask=None):
"""Single denoising step for inference."""
weight_dtype = self.w1.weight.dtype
if dof_mask is not None:
noisy_action = torch.cat([noisy_action, dof_mask], dim=-1)
noisy_action = noisy_action.to(dtype=weight_dtype)
time_embed = self.time_embed(timestep)
action_embed = self.w1(noisy_action)
time_embed = time_embed.unsqueeze(1).repeat(1, action_embed.shape[1], 1)
time_embed = time_embed.to(device=noisy_action.device, dtype=noisy_action.dtype)
time_embed = time_embed.to(device=noisy_action.device, dtype=weight_dtype)
concat_embed = torch.cat([action_embed, time_embed], dim=-1)
concat_embed = self.w2(concat_embed)
@@ -212,12 +215,16 @@ class ActionHead(nn.Module):
return embed
def flow_loss(self, action_hidden_states, flow, dof_mask=None):
"""Compute flow matching loss."""
"""Compute flow matching loss (all computations in float32 for stability)."""
# Ensure all inputs are float32
action_hidden_states = action_hidden_states.to(torch.float32)
flow = flow.to(torch.float32)
action_pred = self.action_proj_back(action_hidden_states)
loss = F.mse_loss(action_pred, flow, reduction="none")
if dof_mask is not None:
dof_mask = dof_mask.reshape(-1, dof_mask.shape[-1])
dof_mask = dof_mask.reshape(-1, dof_mask.shape[-1]).to(torch.float32)
loss = loss * dof_mask
return loss
@@ -960,10 +967,10 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
# Compute losses if labels are provided
if labels is not None:
loss = 0
loss = torch.tensor(0.0, device=hidden_states.device, dtype=torch.float32)
# Compute standard cross-entropy loss for language modeling
shift_logits = logits[..., :-1, :].contiguous()
shift_logits = logits[..., :-1, :].contiguous().to(torch.float32)
shift_labels = labels[..., 1:].contiguous()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
@@ -975,12 +982,12 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
cross_entropy_loss = (
_cross_entropy_loss[non_ignored_mask].mean()
if non_ignored_mask.any()
else torch.tensor(0.0, device=shift_logits.device)
else torch.tensor(0.0, device=shift_logits.device, dtype=torch.float32)
)
# Add cross-entropy loss to total loss if valid
if not torch.isnan(cross_entropy_loss):
loss += cross_entropy_loss
loss = loss + cross_entropy_loss.to(torch.float32)
else:
with torch.no_grad():
cross_entropy_loss.detach()
@@ -989,16 +996,16 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration):
action_mask = input_ids == self.action_token_id_set["action_token_id"]
if action_mask.any():
action_hidden_states = hidden_states[action_mask].to(torch.float32)
flow = flow.reshape(-1, flow.shape[-1])
flow = flow.reshape(-1, flow.shape[-1]).to(torch.float32)
_flow_loss = self.action_preprocessor.flow_loss(
action_hidden_states, flow, dof_mask
)
if isinstance(_flow_loss, torch.Tensor):
flow_loss = _flow_loss.mean()
if loss is not None:
loss += self.flow_loss_weight * flow_loss
loss = loss + self.flow_loss_weight * flow_loss.to(torch.float32)
else:
loss = self.flow_loss_weight * flow_loss
loss = self.flow_loss_weight * flow_loss.to(torch.float32)
_flow_loss = _flow_loss.view(
dof_mask.shape[0], dof_mask.shape[1], dof_mask.shape[2]
)
@@ -1828,7 +1835,6 @@ class WallXPolicy(PreTrainedPolicy):
# Initialize the wall-x model
self.model = Qwen2_5_VLMoEForAction.from_pretrained(config.pretrained_name_or_path)
self.model.to(config.device)
# Convert to bfloat16 for Flash Attention compatibility
self.model.to_bfloat16_for_selected_params()
self.reset()