mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-03 12:21:27 +00:00
support wallx
This commit is contained in:
committed by
Michel Aractingi
parent
08d2ed8015
commit
d3846b0beb
@@ -22,6 +22,7 @@ from .smolvla.processor_smolvla import SmolVLANewLineProcessor
|
||||
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
|
||||
from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
|
||||
from .xvla.configuration_xvla import XVLAConfig as XVLAConfig
|
||||
from .wall_x.configuration_wall_x import WallXConfig as WallXConfig
|
||||
|
||||
__all__ = [
|
||||
"ACTConfig",
|
||||
@@ -33,4 +34,5 @@ __all__ = [
|
||||
"VQBeTConfig",
|
||||
"GrootConfig",
|
||||
"XVLAConfig",
|
||||
"WallXConfig",
|
||||
]
|
||||
|
||||
@@ -42,6 +42,7 @@ from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
from lerobot.policies.utils import validate_visual_features_consistency
|
||||
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
|
||||
from lerobot.policies.wall_x.configuration_wall_x import WallXConfig
|
||||
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
|
||||
from lerobot.processor.converters import (
|
||||
batch_to_transition,
|
||||
@@ -61,7 +62,7 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
|
||||
Args:
|
||||
name: The name of the policy. Supported names are "tdmpc", "diffusion", "act",
|
||||
"vqbet", "pi0", "pi05", "sac", "reward_classifier", "smolvla".
|
||||
"vqbet", "pi0", "pi05", "sac", "reward_classifier", "smolvla", "wall_x".
|
||||
|
||||
Returns:
|
||||
The policy class corresponding to the given name.
|
||||
@@ -113,6 +114,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
from lerobot.policies.xvla.modeling_xvla import XVLAPolicy
|
||||
|
||||
return XVLAPolicy
|
||||
elif name == "wall_x":
|
||||
from lerobot.policies.wall_x.modeling_wall_x import WallXPolicy
|
||||
|
||||
return WallXPolicy
|
||||
else:
|
||||
try:
|
||||
return _get_policy_cls_from_policy_name(name=name)
|
||||
@@ -130,7 +135,7 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
Args:
|
||||
policy_type: The type of the policy. Supported types include "tdmpc",
|
||||
"diffusion", "act", "vqbet", "pi0", "pi05", "sac", "smolvla",
|
||||
"reward_classifier".
|
||||
"reward_classifier", "wall_x".
|
||||
**kwargs: Keyword arguments to be passed to the configuration class constructor.
|
||||
|
||||
Returns:
|
||||
@@ -161,6 +166,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
return GrootConfig(**kwargs)
|
||||
elif policy_type == "xvla":
|
||||
return XVLAConfig(**kwargs)
|
||||
elif policy_type == "wall_x":
|
||||
return WallXConfig(**kwargs)
|
||||
else:
|
||||
try:
|
||||
config_cls = PreTrainedConfig.get_choice_class(policy_type)
|
||||
@@ -344,6 +351,7 @@ def make_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, XVLAConfig):
|
||||
from lerobot.policies.xvla.processor_xvla import (
|
||||
make_xvla_pre_post_processors,
|
||||
@@ -353,6 +361,14 @@ def make_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, WallXConfig):
|
||||
from lerobot.policies.wall_x.processor_wall_x import make_wall_x_pre_post_processors
|
||||
|
||||
processors = make_wall_x_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
else:
|
||||
try:
|
||||
|
||||
193
src/lerobot/policies/wall_x/configuration_wall_x.py
Normal file
193
src/lerobot/policies/wall_x/configuration_wall_x.py
Normal file
@@ -0,0 +1,193 @@
|
||||
# Copyright 2025 HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||
from lerobot.utils.constants import OBS_IMAGES
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("wall_x")
@dataclass
class WallXConfig(PreTrainedConfig):
    """
    Configuration class for Wall-X policy.

    Wall-X is based on Qwen2.5-VL with action prediction capabilities using flow matching.
    It supports cross-embodiment robotic control through unified action representations.

    Construction validates that ``n_action_steps`` does not exceed ``chunk_size``, that
    ``prediction_mode`` is ``"flow"`` or ``"fast"``, and that the entries of ``dof_config``
    sum to at most ``max_action_dim``.
    """

    # Input / output structure
    n_obs_steps: int = 1
    chunk_size: int = 32  # action_horizon in wall-x
    n_action_steps: int = 32

    normalization_mapping: dict[str, NormalizationMode] = field(
        default_factory=lambda: {
            "VISUAL": NormalizationMode.IDENTITY,
            "STATE": NormalizationMode.MEAN_STD,
            "ACTION": NormalizationMode.MEAN_STD,
        }
    )

    # Action dimension - wall-x uses hardcoded 20
    max_action_dim: int = 20
    max_state_dim: int = 20  # For proprioception

    # Image preprocessing
    resize_imgs_with_padding: tuple[int, int] | None = None  # wall-x uses Qwen processor

    # Tokenizer
    tokenizer_max_length: int = 256

    # Model architecture
    vlm_model_name: str = "Qwen/Qwen2.5-VL-3B-Instruct"
    load_vlm_weights: bool = True

    # Vision config
    vision_config: dict = field(
        default_factory=lambda: {
            "depth": 32,
            "hidden_size": 3584,
            "hidden_act": "silu",
            "intermediate_size": 3420,
            "num_heads": 16,
            "patch_size": 14,
            "spatial_merge_size": 2,
            "temporal_patch_size": 2,
            "window_size": 112,
            "out_hidden_size": 3584,
        }
    )

    # Language model config
    # NOTE(review): `hidden_size` also sizes the action head; confirm these defaults
    # (and the size hints in the trailing comments) match the checkpoint named by
    # `vlm_model_name`.
    hidden_size: int = 3584  # 8192 for 7B model
    intermediate_size: int = 18944  # 29568 for 7B model
    num_hidden_layers: int = 36  # 80 for 7B model
    num_attention_heads: int = 28  # 64 for 7B model
    num_key_value_heads: int = 4  # 8 for 7B model
    vocab_size: int = 152064

    # Action prediction mode: "flow" or "fast"
    prediction_mode: str = "flow"

    # Flow matching parameters
    noise_scheduler: dict = field(
        default_factory=lambda: {
            "beta_alpha": 1.5,  # Beta distribution concentration1
            "beta_beta": 1.0,  # Beta distribution concentration0
            "s": 0.999,  # Scaling factor for time
        }
    )

    # Decoding parameters
    num_inference_timesteps: int = 10  # Number of ODE solver steps
    ode_solver_method: str = "euler"  # ODE solver method

    # Degrees of freedom configuration - example for bimanual robot
    dof_config: dict = field(
        default_factory=lambda: {
            "left_ee_pos": 3,
            "left_ee_rot": 3,
            "left_gripper": 1,
            "right_ee_pos": 3,
            "right_ee_rot": 3,
            "right_gripper": 1,
        }
    )

    # Proprioception configuration (mirrors dof_config)
    agent_pos_config: dict = field(
        default_factory=lambda: {
            "left_ee_pos": 3,
            "left_ee_rot": 3,
            "left_gripper": 1,
            "right_ee_pos": 3,
            "right_ee_rot": 3,
            "right_gripper": 1,
        }
    )

    # MoE configuration
    num_experts: int = 4
    attention_moe: bool = False
    mlp_moe: bool = False

    # Finetuning settings
    freeze_vision_encoder: bool = True
    train_expert_only: bool = False  # wall-x trains more components
    train_action_head: bool = True

    # Cache
    use_cache: bool = True

    # Training presets
    optimizer_lr: float = 2e-5
    optimizer_betas: tuple[float, float] = (0.9, 0.95)
    optimizer_eps: float = 1e-8
    optimizer_weight_decay: float = 0.01
    optimizer_grad_clip_norm: float = 1.0

    scheduler_warmup_steps: int = 1000
    scheduler_decay_steps: int = 100000
    scheduler_decay_lr: float = 1e-6

    # Dataset-specific normalization statistics
    # Maps dataset names to {min, delta} for action normalization
    action_statistics: dict = field(default_factory=dict)

    def __post_init__(self):
        """Run the parent initialisation, then validate the Wall-X specific fields.

        Raises:
            ValueError: If `n_action_steps` exceeds `chunk_size`, if `prediction_mode`
                is not "flow" or "fast", or if the DOF sizes sum to more than
                `max_action_dim`.
        """
        super().__post_init__()

        # Input validation
        if self.n_action_steps > self.chunk_size:
            raise ValueError(
                f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
                f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
            )

        if self.prediction_mode not in ["flow", "fast"]:
            raise ValueError(
                f"prediction_mode must be 'flow' or 'fast', got {self.prediction_mode}"
            )

        # Validate dof_config total doesn't exceed max_action_dim
        total_dof = sum(self.dof_config.values())
        if total_dof > self.max_action_dim:
            raise ValueError(
                f"Total DOF ({total_dof}) exceeds max_action_dim ({self.max_action_dim})"
            )

    def validate_features(self) -> None:
        """Check that the configured features can be consumed by the Wall-X policy.

        The policy calls this from its constructor and later requires at least one
        image input when preparing a batch, so the same requirement is enforced here.

        Raises:
            ValueError: If no image feature is configured.
        """
        if not self.image_features:
            raise ValueError("Wall-X requires at least one image feature among the inputs.")

    def get_optimizer_preset(self) -> AdamWConfig:
        """Return the AdamW settings built from the `optimizer_*` fields."""
        return AdamWConfig(
            lr=self.optimizer_lr,
            betas=self.optimizer_betas,
            eps=self.optimizer_eps,
            weight_decay=self.optimizer_weight_decay,
            grad_clip_norm=self.optimizer_grad_clip_norm,
        )

    def get_scheduler_preset(self):
        """Return the cosine-decay-with-warmup settings built from the `scheduler_*` fields."""
        return CosineDecayWithWarmupSchedulerConfig(
            peak_lr=self.optimizer_lr,
            decay_lr=self.scheduler_decay_lr,
            num_warmup_steps=self.scheduler_warmup_steps,
            num_decay_steps=self.scheduler_decay_steps,
        )

    @property
    def observation_delta_indices(self) -> list:
        """Only the current observation (offset 0) is requested."""
        return [0]

    @property
    def action_delta_indices(self) -> list:
        """Offsets 0 .. chunk_size - 1, i.e. one full action chunk."""
        return list(range(self.chunk_size))

    @property
    def reward_delta_indices(self) -> None:
        """No reward offsets are requested."""
        return None
|
||||
466
src/lerobot/policies/wall_x/modeling_wall_x.py
Normal file
466
src/lerobot/policies/wall_x/modeling_wall_x.py
Normal file
@@ -0,0 +1,466 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Wall-X: Cross-embodiment robotic control using Qwen2.5-VL with flow matching.
|
||||
|
||||
[Paper](https://github.com/x2-robot/wall-x)
|
||||
|
||||
Install wall-x extra dependencies:
|
||||
```bash
|
||||
pip install -e ".[wall_x]"
|
||||
```
|
||||
|
||||
Example of finetuning a wall-x model:
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy.type=wall_x \
|
||||
--dataset.repo_id=your/dataset \
|
||||
--batch_size=32 \
|
||||
--steps=100000
|
||||
```
|
||||
"""
|
||||
|
||||
import logging
import math
import os
import sys
from collections import deque
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.distributions import Beta

from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import populate_queues
from lerobot.policies.wall_x.configuration_wall_x import WallXConfig
from lerobot.utils.constants import ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS, OBS_STATE
from lerobot.utils.utils import get_safe_dtype
|
||||
|
||||
# Add a local checkout of the wall-x repo to the import path if available.
# The location can be overridden with the WALL_X_PATH environment variable; the
# default is the path that used to be hard-coded here. Nothing is added to
# sys.path when the directory does not exist.
WALL_X_PATH = Path(
    os.environ.get("WALL_X_PATH", "/x2robot_v2/vincent/workspace/lerobot_opensource/wall-x")
)
if WALL_X_PATH.exists():
    sys.path.insert(0, str(WALL_X_PATH))
|
||||
|
||||
|
||||
class SinusoidalPosEmb(nn.Module):
    """Sinusoidal positional embedding for diffusion timesteps.

    A 1-D tensor of timesteps is mapped to ``2 * (dim // 2)`` features per timestep:
    the sines of the scaled timesteps followed by their cosines.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        """Embed the timesteps in `x`.

        Args:
            x: 1-D tensor of timesteps.

        Returns:
            Tensor with one row per timestep and ``2 * (dim // 2)`` columns.
        """
        half_dim = self.dim // 2
        # With dim 2 or 3, half_dim - 1 is zero and the plain division raised
        # ZeroDivisionError; clamping the denominator yields a single frequency of 1.
        log_step = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=x.device) * -log_step)
        angles = x[:, None] * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)
|
||||
|
||||
|
||||
class ActionHead(nn.Module):
    """
    Action prediction head with flow matching.

    Implements Beta-distributed noise scheduling and temporal embeddings
    for action sequence prediction.

    The action width is the sum of ``config.dof_config`` and the proprioception width
    is the sum of ``config.agent_pos_config``. Every input projection expects its
    features concatenated with a DOF mask of the same width; when a caller passes
    no mask, an all-ones mask is used.
    """

    def __init__(self, config: WallXConfig):
        super().__init__()

        self.config = config
        self.action_dim = sum(config.dof_config.values())
        self.propri_dim = sum(config.agent_pos_config.values())
        self.hidden_size = config.hidden_size

        # Beta distribution for noise scheduling
        noise_config = config.noise_scheduler
        self.beta_alpha = noise_config.get("beta_alpha", 1.5)
        self.beta_beta = noise_config.get("beta_beta", 1.0)
        self.s = noise_config.get("s", 0.999)

        # Sinusoidal timestep embedding
        self.time_embed = SinusoidalPosEmb(config.hidden_size)

        # Action embedding network
        # *2 for action + DOF mask concatenation
        self.w1 = nn.Linear(self.action_dim * 2, self.hidden_size, bias=False)
        self.w2 = nn.Linear(self.hidden_size * 2, self.hidden_size, bias=False)  # *2 for action + time
        self.w3 = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.act_fn = nn.SiLU()

        # Project back to action space
        self.action_proj_back = nn.Linear(self.hidden_size, self.action_dim, bias=False)

        # Proprioception projection (*2 for state + DOF mask concatenation)
        self.propri_proj = nn.Linear(self.propri_dim * 2, self.hidden_size, bias=False)

    def sample_time(self, batch_size, device, dtype):
        """Sample `batch_size` timesteps in [0, s] from a flipped Beta distribution.

        Sampling is done in float32 and cast to `dtype` afterwards so that a
        low-precision `dtype` does not constrain the sampler.
        """
        beta_dist = Beta(
            torch.tensor(self.beta_alpha, dtype=torch.float32, device=device),
            torch.tensor(self.beta_beta, dtype=torch.float32, device=device),
        )
        sample = beta_dist.sample([batch_size])
        # `s` scales the timestep down. Dividing by it (as before) let the timestep
        # exceed 1, which extrapolates past the clean action in the interpolation.
        time = (1 - sample) * self.s
        return time.to(dtype=dtype)

    def _embed(self, noisy_action, time_embed):
        """Fuse masked noisy actions with per-sample time embeddings.

        `noisy_action` already carries the concatenated DOF mask; `time_embed` has
        one row per batch element and is repeated along the sequence dimension.
        """
        noisy_action = noisy_action.to(dtype=self.w1.weight.dtype)
        action_embed = self.w1(noisy_action)

        time_embed = time_embed.unsqueeze(1).repeat(1, action_embed.shape[1], 1)
        time_embed = time_embed.to(device=action_embed.device, dtype=self.w2.weight.dtype)

        concat_embed = self.w2(torch.cat([action_embed, time_embed], dim=-1))
        return self.w3(self.act_fn(concat_embed))

    def forward(self, action_chunk, dof_mask=None):
        """
        Process action sequences with noise injection for training.

        Args:
            action_chunk: Action sequences [batch, seq_len, action_dim]
            dof_mask: DOF mask [batch, seq_len, action_dim]; all ones when omitted

        Returns:
            tuple: (action_embeddings, flow_target)
        """
        batch_size = action_chunk.shape[0]

        # Add noise using flow matching
        noise = torch.randn_like(action_chunk)
        time = self.sample_time(batch_size, action_chunk.device, action_chunk.dtype)
        t = time.unsqueeze(-1).unsqueeze(-1)

        # Linear interpolation between noise (t = 0) and the clean action (t = 1)
        noisy_action = (1 - t) * noise + t * action_chunk
        flow = action_chunk - noise

        # `w1` always expects action + mask features, so a missing mask must be
        # materialised rather than skipped.
        if dof_mask is None:
            dof_mask = torch.ones_like(noisy_action)
        noisy_action = torch.cat([noisy_action, dof_mask], dim=-1)

        embed = self._embed(noisy_action, self.time_embed(time))
        return embed, flow

    def step(self, timestep, noisy_action, dof_mask=None):
        """Single denoising step for inference.

        Args:
            timestep: One timestep per batch element.
            noisy_action: Current noisy actions [batch, seq_len, action_dim]
            dof_mask: DOF mask with the same shape; all ones when omitted

        Returns:
            Action embeddings [batch, seq_len, hidden_size]
        """
        if dof_mask is None:
            dof_mask = torch.ones_like(noisy_action)
        noisy_action = torch.cat([noisy_action, dof_mask], dim=-1)

        return self._embed(noisy_action, self.time_embed(timestep))

    def flow_loss(self, action_hidden_states, flow, dof_mask=None):
        """Compute the element-wise flow matching loss.

        Args:
            action_hidden_states: Hidden states to project back to action space.
            flow: Flow target with the same shape as the projected prediction.
            dof_mask: Optional mask multiplied into the loss.

        Returns:
            Unreduced squared error, masked when `dof_mask` is given.
        """
        action_pred = self.action_proj_back(action_hidden_states)
        loss = F.mse_loss(action_pred, flow, reduction="none")

        if dof_mask is not None:
            if dof_mask.numel() == loss.numel():
                # Same number of elements: align the mask with the loss whatever its
                # rank, so both flattened and [batch, seq, dim] inputs work.
                dof_mask = dof_mask.reshape(loss.shape)
            else:
                # Otherwise keep the previous behaviour and rely on broadcasting.
                dof_mask = dof_mask.reshape(-1, dof_mask.shape[-1])
            loss = loss * dof_mask

        return loss

    def project_proprioception(self, proprioception, dof_mask=None):
        """Project proprioceptive data (concatenated with its DOF mask) to hidden space."""
        proprioception = proprioception.to(
            device=self.propri_proj.weight.device,
            dtype=self.propri_proj.weight.dtype,
        )

        # `propri_proj` expects state + mask features, so default to an all-ones mask.
        if dof_mask is None:
            dof_mask = torch.ones_like(proprioception)
        proprioception = torch.cat([proprioception, dof_mask.to(proprioception)], dim=-1)

        return self.propri_proj(proprioception)
|
||||
|
||||
|
||||
class WallXVLMWrapper(nn.Module):
    """
    Wrapper around the Qwen2.5-VL model used by Wall-X.

    Construction tries to import the optional dependencies and to load the checkpoint
    named by ``config.vlm_model_name``. If an ImportError is raised, the wrapper is left
    in a placeholder state: ``available`` is False, and ``model``, ``processor`` and
    ``process_vision_info`` are None. Calling `forward` in that state raises.
    """

    def __init__(self, config: WallXConfig):
        super().__init__()
        self.config = config

        # Try to import and load the VLM.
        # NOTE(review): `config.load_vlm_weights` is not consulted here; pretrained
        # weights are always loaded.
        try:
            from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
            from qwen_vl_utils import process_vision_info

            on_cpu = config.device == "cpu"
            self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                config.vlm_model_name,
                torch_dtype=torch.float32 if on_cpu else torch.bfloat16,
                device_map=None if on_cpu else config.device,
            )

            self.processor = AutoProcessor.from_pretrained(config.vlm_model_name)
            self.process_vision_info = process_vision_info
            self.available = True

            # Freeze vision encoder if requested
            if config.freeze_vision_encoder:
                for param in self.model.visual.parameters():
                    param.requires_grad = False

        except ImportError:
            logging.warning("Warning: Could not import wall-x dependencies. Using placeholder.")
            self.available = False
            self.model = None
            self.processor = None
            # Define the attribute on this path too, so callers can test it for None
            # instead of hitting AttributeError.
            self.process_vision_info = None

    def forward(self, **kwargs):
        """Forward the keyword arguments to the wrapped VLM.

        Raises:
            RuntimeError: If the VLM could not be loaded at construction time.
        """
        if not self.available:
            raise RuntimeError("Wall-X VLM not available. Install required dependencies.")
        return self.model(**kwargs)
|
||||
|
||||
|
||||
class WallXPolicy(PreTrainedPolicy):
    """
    Wall-X policy for cross-embodiment robotic control.

    Integrates Qwen2.5-VL vision-language model with action prediction
    using flow matching for continuous action spaces.

    In this implementation the training loss and the sampler only run the action
    head; the VLM is constructed but not part of either computation yet.
    """

    config_class = WallXConfig
    name = "wall_x"

    def __init__(self, config: WallXConfig):
        super().__init__(config)
        config.validate_features()
        self.config = config

        # Initialize VLM wrapper
        self.vlm = WallXVLMWrapper(config)

        # Initialize action head
        self.action_head = ActionHead(config)

        self.reset()

    def reset(self):
        """Reset action queue."""
        self._queues = {
            ACTION: deque(maxlen=self.config.n_action_steps),
        }

    def get_optim_params(self):
        """Collect the parameters handed to the optimizer.

        VLM parameters are included when the VLM is available and
        `train_expert_only` is False; action head parameters are included when
        `train_action_head` is True.
        """
        params = []

        if self.vlm.available and not self.config.train_expert_only:
            params.extend(self.vlm.model.parameters())

        if self.config.train_action_head:
            params.extend(self.action_head.parameters())

        return params

    @staticmethod
    def _pad_last_dim(tensor, target_dim):
        """Right-pad the last dimension of `tensor` with zeros up to `target_dim`."""
        missing = target_dim - tensor.shape[-1]
        if missing <= 0:
            return tensor
        padding = torch.zeros(*tensor.shape[:-1], missing, device=tensor.device, dtype=tensor.dtype)
        return torch.cat([tensor, padding], dim=-1)

    @staticmethod
    def _infer_batch_size(batch):
        """Read the batch size from the state tensor, or else the first batched tensor.

        Falls back to 1 when the batch holds no tensor with a leading dimension.
        """
        state = batch.get(OBS_STATE)
        if isinstance(state, Tensor) and state.ndim > 0:
            return state.shape[0]
        for value in batch.values():
            if isinstance(value, Tensor) and value.ndim > 0:
                return value.shape[0]
        return 1

    def prepare_images(self, batch):
        """Collect the configured image tensors present in the batch.

        Five-dimensional entries are reduced to their last index along dim 1.

        Raises:
            ValueError: If none of the configured image features is in the batch.
        """
        present_img_keys = [key for key in self.config.image_features if key in batch]

        if not present_img_keys:
            raise ValueError("No image features found in batch")

        images = []
        for key in present_img_keys:
            img = batch[key][:, -1, :, :, :] if batch[key].ndim == 5 else batch[key]
            images.append(img)

        return images

    def prepare_state(self, batch):
        """Return the latest proprioceptive state, zero-padded to `max_state_dim`."""
        state = batch[OBS_STATE][:, -1, :] if batch[OBS_STATE].ndim > 2 else batch[OBS_STATE]
        return self._pad_last_dim(state, self.config.max_state_dim)

    def prepare_action(self, batch):
        """Return the action chunk zero-padded to the action head's width.

        The head's input layer is sized from the DOF configuration, so the actions
        are padded to that width. Padding to `max_action_dim`, as before, produced
        a shape error whenever the two widths differed.

        Raises:
            ValueError: If the batch's actions are wider than the action head.
        """
        actions = batch[ACTION]
        head_dim = self.action_head.action_dim
        if actions.shape[-1] > head_dim:
            raise ValueError(
                f"Action dimension ({actions.shape[-1]}) exceeds the total DOF of the action head ({head_dim})"
            )
        return self._pad_last_dim(actions, head_dim)

    def _create_dof_mask(self, batch_size, device, dtype, seq_len=None):
        """Create an all-ones DOF mask [batch_size, seq_len, action_dim].

        `seq_len` defaults to `chunk_size`.
        """
        if seq_len is None:
            seq_len = self.config.chunk_size
        return torch.ones(
            batch_size,
            seq_len,
            self.action_head.action_dim,
            device=device,
            dtype=dtype,
        )

    def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
        """
        Training forward pass.

        Args:
            batch: Dictionary containing observations and actions

        Returns:
            tuple: (loss, loss_dict)
        """
        # Images and state are prepared for their input checks; the simplified loss
        # below does not consume them yet.
        self.prepare_images(batch)
        self.prepare_state(batch)
        actions = self.prepare_action(batch)

        # The mask follows the actual sequence length of the batch.
        dof_mask = self._create_dof_mask(
            actions.shape[0], actions.device, actions.dtype, seq_len=actions.shape[1]
        )

        # Process actions through action head (adds noise, gets embeddings)
        action_embeds, flow_target = self.action_head(actions, dof_mask)

        # Simplified loss: in full wall-x the embeddings would go through the VLM
        # transformer first. Inputs are flattened to [batch * seq, dim] so the
        # masking inside `flow_loss` is valid for any batch size.
        flow_loss = self.action_head.flow_loss(
            action_embeds.flatten(0, 1), flow_target.flatten(0, 1), dof_mask
        )
        loss = flow_loss.mean()

        loss_value = loss.item()
        loss_dict = {"loss": loss_value, "flow_loss": loss_value}

        return loss, loss_dict

    def _sample_actions_flow(self, batch: dict[str, Tensor]) -> Tensor:
        """
        Sample actions using flow matching / diffusion.

        Starts from Gaussian noise and applies `num_inference_timesteps` Euler steps.

        Args:
            batch: Dictionary containing observations

        Returns:
            Predicted actions [batch, chunk_size, action_dim]
        """
        # Work on the action head's own device and dtype, and produce one sample per
        # batch element instead of assuming a single sample.
        head_weight = self.action_head.w1.weight
        device, dtype = head_weight.device, head_weight.dtype
        batch_size = self._infer_batch_size(batch)

        # Initialize with noise
        noisy_action = torch.randn(
            batch_size,
            self.config.chunk_size,
            self.action_head.action_dim,
            device=device,
            dtype=dtype,
        )

        # Create DOF mask
        dof_mask = self._create_dof_mask(batch_size, device, dtype)

        # ODE integration for denoising
        num_steps = self.config.num_inference_timesteps
        dt = 1.0 / num_steps

        for step_idx in range(num_steps):
            timestep = torch.full((batch_size,), step_idx * dt, device=device, dtype=dtype)

            # Single denoising step
            action_embeds = self.action_head.step(timestep, noisy_action, dof_mask)

            # Predict flow (in full implementation, would go through VLM)
            flow_pred = self.action_head.action_proj_back(action_embeds)

            # Euler integration step
            noisy_action = noisy_action + dt * flow_pred

        return noisy_action

    @torch.no_grad()
    def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
        """Predict action chunk for evaluation.

        Raises:
            NotImplementedError: If `prediction_mode` is not "flow".
        """
        self.eval()
        self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])

        if self.config.prediction_mode == "flow":
            actions = self._sample_actions_flow(batch)
        else:
            raise NotImplementedError(f"Prediction mode {self.config.prediction_mode} not implemented")

        # Unpad actions
        original_action_dim = self.config.action_feature.shape[0]
        actions = actions[:, :, :original_action_dim]

        return actions

    @torch.no_grad()
    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
        """Select single action for environment execution.

        A new chunk is predicted only when the action queue is empty; otherwise the
        next queued action is returned.
        """
        self.eval()
        self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])

        # Use action queue
        if len(self._queues[ACTION]) == 0:
            actions = self.predict_action_chunk(batch)
            # Queue entries are per-timestep slices: [seq, batch, dim] after transpose.
            self._queues[ACTION].extend(actions.transpose(0, 1)[: self.config.n_action_steps])

        return self._queues[ACTION].popleft()
|
||||
181
src/lerobot/policies/wall_x/processor_wall_x.py
Normal file
181
src/lerobot/policies/wall_x/processor_wall_x.py
Normal file
@@ -0,0 +1,181 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.policies.wall_x.configuration_wall_x import WallXConfig
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
ComplementaryDataProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorStepRegistry,
|
||||
RenameObservationsProcessorStep,
|
||||
TokenizerProcessorStep,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
||||
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
|
||||
|
||||
def make_wall_x_pre_post_processors(
    config: WallXConfig,
    dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[
    PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
    PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
    """
    Constructs pre-processor and post-processor pipelines for the Wall-X policy.

    The pre-processing pipeline prepares input data for the model by:
    1. Renaming features to match pretrained configurations
    2. Adding a batch dimension
    3. Formatting and tokenizing language task descriptions
    4. Normalizing input and output features based on dataset statistics
    5. Moving all data to the specified device

    The post-processing pipeline handles the model's output by:
    1. Unnormalizing the output actions to their original scale
    2. Moving data to the CPU

    Args:
        config: The configuration object for the Wall-X policy
        dataset_stats: A dictionary of statistics for normalization

    Returns:
        A tuple containing the configured pre-processor and post-processor pipelines
    """
    # The tokenizer always comes from the configured VLM checkpoint. The previous
    # ImportError fallback silently switched to a different model's tokenizer, and
    # its `AutoProcessor` import and `qwen_available` flag were never used.
    input_steps = [
        RenameObservationsProcessorStep(rename_map={}),
        AddBatchDimensionProcessorStep(),
        WallXTaskProcessor(),  # Process task description
        TokenizerProcessorStep(
            tokenizer_name=config.vlm_model_name,
            padding="max_length",
            padding_side="right",
            max_length=config.tokenizer_max_length,
        ),
        NormalizerProcessorStep(
            features={**config.input_features, **config.output_features},
            norm_map=config.normalization_mapping,
            stats=dataset_stats,
        ),
        DeviceProcessorStep(device=config.device),
    ]

    output_steps = [
        UnnormalizerProcessorStep(
            features=config.output_features,
            norm_map=config.normalization_mapping,
            stats=dataset_stats,
        ),
        DeviceProcessorStep(device="cpu"),
    ]

    return (
        PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
            steps=input_steps,
            name=POLICY_PREPROCESSOR_DEFAULT_NAME,
        ),
        PolicyProcessorPipeline[PolicyAction, PolicyAction](
            steps=output_steps,
            name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
            to_transition=policy_action_to_transition,
            to_output=transition_to_policy_action,
        ),
    )
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register(name="wall_x_task_processor")
class WallXTaskProcessor(ComplementaryDataProcessorStep):
    """
    A processor step that ensures the task description is properly formatted for Wall-X.

    A missing (None) task is replaced by a default instruction, and string tasks
    (single or in a list) get a trailing period when they lack one. Any other task
    value is passed through unchanged.
    """

    def complementary_data(self, complementary_data):
        """Return complementary data with a normalised "task" entry.

        The input mapping is never modified: when a "task" key is present a shallow
        copy is returned, otherwise the input itself is returned.
        """
        if "task" not in complementary_data:
            return complementary_data

        task = complementary_data["task"]
        # Copy before editing so the caller's dict is left untouched on every
        # branch, including the None case, which used to write in place.
        new_complementary_data = dict(complementary_data)

        if task is None:
            # Provide default task if none specified
            new_complementary_data["task"] = "Execute the robot action."
        elif isinstance(task, str):
            # Single string: ensure it ends with a period
            if not task.endswith("."):
                new_complementary_data["task"] = f"{task}."
        elif isinstance(task, list) and all(isinstance(t, str) for t in task):
            # List of strings: format each
            new_complementary_data["task"] = [
                t if t.endswith(".") else f"{t}." for t in task
            ]

        return new_complementary_data

    def transform_features(
        self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
    ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
        """Return the feature description unchanged."""
        return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register(name="wall_x_image_processor")
class WallXImageProcessor(ComplementaryDataProcessorStep):
    """
    Image processor for Wall-X using Qwen-VL vision processing.

    Construction tries to load the "Qwen/Qwen2-VL-2B-Instruct" processor. On
    ImportError the step stays usable with ``available`` False and ``processor``
    None. The step itself passes complementary data and features through unchanged.
    """

    def __init__(self):
        super().__init__()
        # Define the attribute up front so it exists even when loading fails.
        self.processor = None
        try:
            from transformers import AutoProcessor

            self.processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
            self.available = True
        except ImportError:
            self.available = False

    def complementary_data(self, complementary_data):
        """Return the complementary data unchanged."""
        # Image processing is handled by the VLM processor
        return complementary_data

    def transform_features(
        self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
    ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
        """Return the feature description unchanged."""
        return features
|
||||
Reference in New Issue
Block a user