diff --git a/src/lerobot/policies/wall_x/README.md b/src/lerobot/policies/wall_x/README.md new file mode 100644 index 000000000..b43be409e --- /dev/null +++ b/src/lerobot/policies/wall_x/README.md @@ -0,0 +1,35 @@ +# WALL-OSS + +This repository contains the Hugging Face port of **WALL-OSS**, a Vision-Language-Action model for cross-embodiment robotic control based on Qwen2.5-VL with flow matching/FAST action prediction. + +--- + +## Model Overview + +| Feature | Description | +| -------------------- | ------------------------------------------------------------------------ | +| Base Model | Qwen2.5-VL (Vision-Language Model) | +| Action Prediction | Flow Matching (diffusion) or FAST (discrete tokens) | +| Architecture | Mixture of Experts (MoE) with action-specific routing | +| Multi-Modal Inputs | Vision (images/videos), Language, Proprioception | +--- + +## Citation + +If you use this work, please cite: + +```bibtex +@article{zhai2025igniting, + title = {Igniting VLMs Toward the Embodied Space}, + author = {Zhai, Andy and Liu, Brae and Fang, Bruno and Cai, Chalse and Ma, Ellie and Yin, Ethan and Wang, Hao and Zhou, Hugo and Wang, James and Shi, Lights and Liang, Lucy and Wang, Make and Wang, Qian and Gan, Roy and Yu, Ryan and Li, Shalfun and Liu, Starrick and Chen, Sylas and Chen, Vincent and Xu, Zach}, + journal = {arXiv preprint arXiv:2509.11766}, + year = {2025} +} +``` + +--- + +## License + +This port follows the **Apache 2.0 License**. 
+ diff --git a/src/lerobot/policies/wall_x/modeling_wall_x.py b/src/lerobot/policies/wall_x/modeling_wall_x.py index dbd147872..d0c192fb4 100644 --- a/src/lerobot/policies/wall_x/modeling_wall_x.py +++ b/src/lerobot/policies/wall_x/modeling_wall_x.py @@ -140,11 +140,11 @@ class ActionHead(nn.Module): # Proprioception projection self.propri_proj = nn.Linear(self.propri_dim * 2, self.hidden_size, bias=False) - def sample_time(self, batch_size, device, dtype): - """Sample timesteps using Beta distribution.""" + def sample_time(self, batch_size, device): + """Sample timesteps using Beta distribution (always in float32 for numerical stability).""" beta_dist = Beta( - torch.tensor(self.beta_alpha, dtype=dtype, device=device), - torch.tensor(self.beta_beta, dtype=dtype, device=device) + torch.tensor(self.beta_alpha, dtype=torch.float32, device=device), + torch.tensor(self.beta_beta, dtype=torch.float32, device=device) ) sample = beta_dist.sample([batch_size]) time = (1 - sample) * self.s @@ -163,30 +163,30 @@ class ActionHead(nn.Module): """ batch_size = action_chunk.shape[0] device = action_chunk.device - dtype = action_chunk.dtype + weight_dtype = self.w1.weight.dtype - # Add noise using flow matching - noise = torch.randn_like(action_chunk) - time = self.sample_time(batch_size, device, dtype) + # Sample time outside of autocast (Beta distribution needs float32) + time = self.sample_time(batch_size, device) t = time.unsqueeze(-1).unsqueeze(-1) - # Linear interpolation - noisy_action = (1 - t) * noise + t * action_chunk - flow = action_chunk - noise - - # Generate time embeddings - time_embed = self.time_embed(time) + # Noise and flow computation in float32 + noise = torch.randn_like(action_chunk, dtype=torch.float32) + action_chunk_f32 = action_chunk.to(torch.float32) + noisy_action = (1 - t) * noise + t * action_chunk_f32 + flow = action_chunk_f32 - noise # Project noisy actions if dof_mask is not None: - noisy_action = torch.cat([noisy_action, dof_mask], dim=-1) + 
noisy_action = torch.cat([noisy_action, dof_mask.to(torch.float32)], dim=-1) - noisy_action = noisy_action.to(dtype=self.w1.weight.dtype) + # Convert to weight dtype for linear layers + noisy_action = noisy_action.to(dtype=weight_dtype) action_embed = self.w1(noisy_action) - # Combine with time embeddings + # Generate time embeddings and combine + time_embed = self.time_embed(time) time_embed = time_embed.unsqueeze(1).repeat(1, action_embed.shape[1], 1) - time_embed = time_embed.to(dtype=self.w2.weight.dtype) + time_embed = time_embed.to(dtype=weight_dtype) concat_embed = torch.cat([action_embed, time_embed], dim=-1) concat_embed = self.w2(concat_embed) @@ -196,14 +196,17 @@ class ActionHead(nn.Module): def step(self, timestep, noisy_action, dof_mask=None): """Single denoising step for inference.""" + weight_dtype = self.w1.weight.dtype + if dof_mask is not None: noisy_action = torch.cat([noisy_action, dof_mask], dim=-1) + noisy_action = noisy_action.to(dtype=weight_dtype) time_embed = self.time_embed(timestep) action_embed = self.w1(noisy_action) time_embed = time_embed.unsqueeze(1).repeat(1, action_embed.shape[1], 1) - time_embed = time_embed.to(device=noisy_action.device, dtype=noisy_action.dtype) + time_embed = time_embed.to(device=noisy_action.device, dtype=weight_dtype) concat_embed = torch.cat([action_embed, time_embed], dim=-1) concat_embed = self.w2(concat_embed) @@ -212,12 +215,16 @@ class ActionHead(nn.Module): return embed def flow_loss(self, action_hidden_states, flow, dof_mask=None): - """Compute flow matching loss.""" + """Compute flow matching loss (all computations in float32 for stability).""" + # Ensure all inputs are float32 + action_hidden_states = action_hidden_states.to(torch.float32) + flow = flow.to(torch.float32) + action_pred = self.action_proj_back(action_hidden_states) loss = F.mse_loss(action_pred, flow, reduction="none") if dof_mask is not None: - dof_mask = dof_mask.reshape(-1, dof_mask.shape[-1]) + dof_mask = dof_mask.reshape(-1, 
dof_mask.shape[-1]).to(torch.float32) loss = loss * dof_mask return loss @@ -960,10 +967,10 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration): # Compute losses if labels are provided if labels is not None: - loss = 0 + loss = torch.tensor(0.0, device=hidden_states.device, dtype=torch.float32) # Compute standard cross-entropy loss for language modeling - shift_logits = logits[..., :-1, :].contiguous() + shift_logits = logits[..., :-1, :].contiguous().to(torch.float32) shift_labels = labels[..., 1:].contiguous() shift_logits = shift_logits.view(-1, self.config.vocab_size) shift_labels = shift_labels.view(-1) @@ -975,12 +982,12 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration): cross_entropy_loss = ( _cross_entropy_loss[non_ignored_mask].mean() if non_ignored_mask.any() - else torch.tensor(0.0, device=shift_logits.device) + else torch.tensor(0.0, device=shift_logits.device, dtype=torch.float32) ) # Add cross-entropy loss to total loss if valid if not torch.isnan(cross_entropy_loss): - loss += cross_entropy_loss + loss = loss + cross_entropy_loss.to(torch.float32) else: with torch.no_grad(): cross_entropy_loss.detach() @@ -989,16 +996,16 @@ class Qwen2_5_VLMoEForAction(Qwen2_5_VLForConditionalGeneration): action_mask = input_ids == self.action_token_id_set["action_token_id"] if action_mask.any(): action_hidden_states = hidden_states[action_mask].to(torch.float32) - flow = flow.reshape(-1, flow.shape[-1]) + flow = flow.reshape(-1, flow.shape[-1]).to(torch.float32) _flow_loss = self.action_preprocessor.flow_loss( action_hidden_states, flow, dof_mask ) if isinstance(_flow_loss, torch.Tensor): flow_loss = _flow_loss.mean() if loss is not None: - loss += self.flow_loss_weight * flow_loss + loss = loss + self.flow_loss_weight * flow_loss.to(torch.float32) else: - loss = self.flow_loss_weight * flow_loss + loss = self.flow_loss_weight * flow_loss.to(torch.float32) _flow_loss = _flow_loss.view( dof_mask.shape[0], dof_mask.shape[1], 
dof_mask.shape[2] ) @@ -1828,7 +1835,6 @@ class WallXPolicy(PreTrainedPolicy): # Initialize the wall-x model self.model = Qwen2_5_VLMoEForAction.from_pretrained(config.pretrained_name_or_path) self.model.to(config.device) - # Convert to bfloat16 for Flash Attention compatibility self.model.to_bfloat16_for_selected_params() self.reset()