support wallx

This commit is contained in:
vincentchen
2025-11-05 11:57:48 +08:00
committed by Michel Aractingi
parent 08d2ed8015
commit d3846b0beb
5 changed files with 860 additions and 2 deletions

View File

@@ -22,6 +22,7 @@ from .smolvla.processor_smolvla import SmolVLANewLineProcessor
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
from .xvla.configuration_xvla import XVLAConfig as XVLAConfig
from .wall_x.configuration_wall_x import WallXConfig as WallXConfig
__all__ = [
"ACTConfig",
@@ -33,4 +34,5 @@ __all__ = [
"VQBeTConfig",
"GrootConfig",
"XVLAConfig",
"WallXConfig",
]

View File

@@ -42,6 +42,7 @@ from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.policies.utils import validate_visual_features_consistency
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
from lerobot.policies.wall_x.configuration_wall_x import WallXConfig
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
from lerobot.processor.converters import (
batch_to_transition,
@@ -61,7 +62,7 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
Args:
name: The name of the policy. Supported names are "tdmpc", "diffusion", "act",
"vqbet", "pi0", "pi05", "sac", "reward_classifier", "smolvla".
"vqbet", "pi0", "pi05", "sac", "reward_classifier", "smolvla", "wall_x".
Returns:
The policy class corresponding to the given name.
@@ -113,6 +114,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
from lerobot.policies.xvla.modeling_xvla import XVLAPolicy
return XVLAPolicy
elif name == "wall_x":
from lerobot.policies.wall_x.modeling_wall_x import WallXPolicy
return WallXPolicy
else:
try:
return _get_policy_cls_from_policy_name(name=name)
@@ -130,7 +135,7 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
Args:
policy_type: The type of the policy. Supported types include "tdmpc",
"diffusion", "act", "vqbet", "pi0", "pi05", "sac", "smolvla",
"reward_classifier".
"reward_classifier", "wall_x".
**kwargs: Keyword arguments to be passed to the configuration class constructor.
Returns:
@@ -161,6 +166,8 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return GrootConfig(**kwargs)
elif policy_type == "xvla":
return XVLAConfig(**kwargs)
elif policy_type == "wall_x":
return WallXConfig(**kwargs)
else:
try:
config_cls = PreTrainedConfig.get_choice_class(policy_type)
@@ -344,6 +351,7 @@ def make_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, XVLAConfig):
from lerobot.policies.xvla.processor_xvla import (
make_xvla_pre_post_processors,
@@ -353,6 +361,14 @@ def make_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, WallXConfig):
from lerobot.policies.wall_x.processor_wall_x import make_wall_x_pre_post_processors
processors = make_wall_x_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
else:
try:

View File

@@ -0,0 +1,193 @@
# Copyright 2025 HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
from lerobot.utils.constants import OBS_IMAGES
@PreTrainedConfig.register_subclass("wall_x")
@dataclass
class WallXConfig(PreTrainedConfig):
"""
Configuration class for Wall-X policy.
Wall-X is based on Qwen2.5-VL with action prediction capabilities using flow matching.
It supports cross-embodiment robotic control through unified action representations.
"""
# Input / output structure
n_obs_steps: int = 1
chunk_size: int = 32 # action_horizon in wall-x
n_action_steps: int = 32
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.MEAN_STD,
"ACTION": NormalizationMode.MEAN_STD,
}
)
# Action dimension - wall-x uses hardcoded 20
max_action_dim: int = 20
max_state_dim: int = 20 # For proprioception
# Image preprocessing
resize_imgs_with_padding: tuple[int, int] | None = None # wall-x uses Qwen processor
# Tokenizer
tokenizer_max_length: int = 256
# Model architecture
vlm_model_name: str = "Qwen/Qwen2.5-VL-3B-Instruct"
load_vlm_weights: bool = True
# Vision config
vision_config: dict = field(default_factory=lambda: {
"depth": 32,
"hidden_size": 3584,
"hidden_act": "silu",
"intermediate_size": 3420,
"num_heads": 16,
"patch_size": 14,
"spatial_merge_size": 2,
"temporal_patch_size": 2,
"window_size": 112,
"out_hidden_size": 3584,
})
# Language model config
hidden_size: int = 3584 # 8192 for 7B model
intermediate_size: int = 18944 # 29568 for 7B model
num_hidden_layers: int = 36 # 80 for 7B model
num_attention_heads: int = 28 # 64 for 7B model
num_key_value_heads: int = 4 # 8 for 7B model
vocab_size: int = 152064
# Action prediction mode: "flow" or "fast"
prediction_mode: str = "flow"
# Flow matching parameters
noise_scheduler: dict = field(default_factory=lambda: {
"beta_alpha": 1.5, # Beta distribution concentration1
"beta_beta": 1.0, # Beta distribution concentration0
"s": 0.999, # Scaling factor for time
})
# Decoding parameters
num_inference_timesteps: int = 10 # Number of ODE solver steps
ode_solver_method: str = "euler" # ODE solver method
# Degrees of freedom configuration - example for bimanual robot
dof_config: dict = field(default_factory=lambda: {
"left_ee_pos": 3,
"left_ee_rot": 3,
"left_gripper": 1,
"right_ee_pos": 3,
"right_ee_rot": 3,
"right_gripper": 1,
})
# Proprioception configuration (mirrors dof_config)
agent_pos_config: dict = field(default_factory=lambda: {
"left_ee_pos": 3,
"left_ee_rot": 3,
"left_gripper": 1,
"right_ee_pos": 3,
"right_ee_rot": 3,
"right_gripper": 1,
})
# MoE configuration
num_experts: int = 4
attention_moe: bool = False
mlp_moe: bool = False
# Finetuning settings
freeze_vision_encoder: bool = True
train_expert_only: bool = False # wall-x trains more components
train_action_head: bool = True
# Cache
use_cache: bool = True
# Training presets
optimizer_lr: float = 2e-5
optimizer_betas: tuple[float, float] = (0.9, 0.95)
optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 0.01
optimizer_grad_clip_norm: float = 1.0
scheduler_warmup_steps: int = 1000
scheduler_decay_steps: int = 100000
scheduler_decay_lr: float = 1e-6
# Dataset-specific normalization statistics
# Maps dataset names to {min, delta} for action normalization
action_statistics: dict = field(default_factory=dict)
def __post_init__(self):
super().__post_init__()
"""Input validation"""
if self.n_action_steps > self.chunk_size:
raise ValueError(
f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
)
if self.prediction_mode not in ["flow", "fast"]:
raise ValueError(
f"prediction_mode must be 'flow' or 'fast', got {self.prediction_mode}"
)
# Validate dof_config total doesn't exceed max_action_dim
total_dof = sum(self.dof_config.values())
if total_dof > self.max_action_dim:
raise ValueError(
f"Total DOF ({total_dof}) exceeds max_action_dim ({self.max_action_dim})"
)
def get_optimizer_preset(self) -> AdamWConfig:
return AdamWConfig(
lr=self.optimizer_lr,
betas=self.optimizer_betas,
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
grad_clip_norm=self.optimizer_grad_clip_norm,
)
def get_scheduler_preset(self):
return CosineDecayWithWarmupSchedulerConfig(
peak_lr=self.optimizer_lr,
decay_lr=self.scheduler_decay_lr,
num_warmup_steps=self.scheduler_warmup_steps,
num_decay_steps=self.scheduler_decay_steps,
)
@property
def observation_delta_indices(self) -> list:
return [0]
@property
def action_delta_indices(self) -> list:
return list(range(self.chunk_size))
@property
def reward_delta_indices(self) -> None:
return None

View File

@@ -0,0 +1,466 @@
#!/usr/bin/env python
# Copyright 2025 HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Wall-X: Cross-embodiment robotic control using Qwen2.5-VL with flow matching.
[Paper](https://github.com/x2-robot/wall-x)
Install wall-x extra dependencies:
```bash
pip install -e ".[wall_x]"
```
Example of finetuning a wall-x model:
```bash
lerobot-train \
--policy.type=wall_x \
--dataset.repo_id=your/dataset \
--batch_size=32 \
--steps=100000
```
"""
import math
import sys
from collections import deque
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.distributions import Beta
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.wall_x.configuration_wall_x import WallXConfig
from lerobot.policies.utils import populate_queues
from lerobot.utils.constants import ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS, OBS_STATE
from lerobot.utils.utils import get_safe_dtype
# Add wall-x repo to path if available
WALL_X_PATH = Path("/x2robot_v2/vincent/workspace/lerobot_opensource/wall-x")
if WALL_X_PATH.exists():
sys.path.insert(0, str(WALL_X_PATH))
class SinusoidalPosEmb(nn.Module):
"""Sinusoidal positional embedding for diffusion timesteps."""
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, x):
device = x.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
emb = x[:, None] * emb[None, :]
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
class ActionHead(nn.Module):
"""
Action prediction head with flow matching.
Implements Beta-distributed noise scheduling and temporal embeddings
for action sequence prediction.
"""
def __init__(self, config: WallXConfig):
super().__init__()
self.config = config
self.action_dim = sum(config.dof_config.values())
self.propri_dim = sum(config.agent_pos_config.values())
self.hidden_size = config.hidden_size
# Beta distribution for noise scheduling
noise_config = config.noise_scheduler
self.beta_alpha = noise_config.get("beta_alpha", 1.5)
self.beta_beta = noise_config.get("beta_beta", 1.0)
self.s = noise_config.get("s", 0.999)
# Sinusoidal timestep embedding
self.time_embed = SinusoidalPosEmb(config.hidden_size)
# Action embedding network
# *2 for action + DOF mask concatenation
self.w1 = nn.Linear(self.action_dim * 2, self.hidden_size, bias=False)
self.w2 = nn.Linear(self.hidden_size * 2, self.hidden_size, bias=False) # *2 for action + time
self.w3 = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
self.act_fn = nn.SiLU()
# Project back to action space
self.action_proj_back = nn.Linear(self.hidden_size, self.action_dim, bias=False)
# Proprioception projection
self.propri_proj = nn.Linear(self.propri_dim * 2, self.hidden_size, bias=False)
def sample_time(self, batch_size, device, dtype):
"""Sample timesteps using Beta distribution."""
beta_dist = Beta(
torch.tensor(self.beta_alpha, dtype=dtype, device=device),
torch.tensor(self.beta_beta, dtype=dtype, device=device)
)
sample = beta_dist.sample([batch_size])
time = (1 - sample) / self.s
return time
def forward(self, action_chunk, dof_mask=None):
"""
Process action sequences with noise injection for training.
Args:
action_chunk: Action sequences [batch, seq_len, action_dim]
dof_mask: DOF mask [batch, seq_len, action_dim]
Returns:
tuple: (action_embeddings, flow_target)
"""
batch_size = action_chunk.shape[0]
device = action_chunk.device
dtype = action_chunk.dtype
# Add noise using flow matching
noise = torch.randn_like(action_chunk)
time = self.sample_time(batch_size, device, dtype)
t = time.unsqueeze(-1).unsqueeze(-1)
# Linear interpolation
noisy_action = (1 - t) * noise + t * action_chunk
flow = action_chunk - noise
# Generate time embeddings
time_embed = self.time_embed(time)
# Project noisy actions
if dof_mask is not None:
noisy_action = torch.cat([noisy_action, dof_mask], dim=-1)
noisy_action = noisy_action.to(dtype=self.w1.weight.dtype)
action_embed = self.w1(noisy_action)
# Combine with time embeddings
time_embed = time_embed.unsqueeze(1).repeat(1, action_embed.shape[1], 1)
time_embed = time_embed.to(dtype=self.w2.weight.dtype)
concat_embed = torch.cat([action_embed, time_embed], dim=-1)
concat_embed = self.w2(concat_embed)
embed = self.w3(self.act_fn(concat_embed))
return embed, flow
def step(self, timestep, noisy_action, dof_mask=None):
"""Single denoising step for inference."""
if dof_mask is not None:
noisy_action = torch.cat([noisy_action, dof_mask], dim=-1)
time_embed = self.time_embed(timestep)
action_embed = self.w1(noisy_action)
time_embed = time_embed.unsqueeze(1).repeat(1, action_embed.shape[1], 1)
time_embed = time_embed.to(device=noisy_action.device, dtype=noisy_action.dtype)
concat_embed = torch.cat([action_embed, time_embed], dim=-1)
concat_embed = self.w2(concat_embed)
embed = self.w3(self.act_fn(concat_embed))
return embed
def flow_loss(self, action_hidden_states, flow, dof_mask=None):
"""Compute flow matching loss."""
action_pred = self.action_proj_back(action_hidden_states)
loss = F.mse_loss(action_pred, flow, reduction="none")
if dof_mask is not None:
dof_mask = dof_mask.reshape(-1, dof_mask.shape[-1])
loss = loss * dof_mask
return loss
def project_proprioception(self, proprioception, dof_mask=None):
"""Project proprioceptive data to hidden space."""
proprioception = proprioception.to(
device=self.propri_proj.weight.device,
dtype=self.propri_proj.weight.dtype
)
if dof_mask is not None:
proprioception = torch.cat([proprioception, dof_mask], dim=-1)
return self.propri_proj(proprioception)
class WallXVLMWrapper(nn.Module):
"""
Wrapper around Qwen2.5-VL model from wall-x.
This class attempts to load the wall-x model if available,
otherwise provides a placeholder implementation.
"""
def __init__(self, config: WallXConfig):
super().__init__()
self.config = config
# Try to import wall-x model
try:
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
config.vlm_model_name,
torch_dtype=torch.bfloat16 if config.device != "cpu" else torch.float32,
device_map=config.device if config.device != "cpu" else None,
)
self.processor = AutoProcessor.from_pretrained(config.vlm_model_name)
self.process_vision_info = process_vision_info
self.available = True
# Freeze vision encoder if requested
if config.freeze_vision_encoder:
for param in self.model.visual.parameters():
param.requires_grad = False
except ImportError:
print("Warning: Could not import wall-x dependencies. Using placeholder.")
self.available = False
self.model = None
self.processor = None
def forward(self, **kwargs):
"""Forward pass through VLM."""
if not self.available:
raise RuntimeError("Wall-X VLM not available. Install required dependencies.")
return self.model(**kwargs)
class WallXPolicy(PreTrainedPolicy):
"""
Wall-X policy for cross-embodiment robotic control.
Integrates Qwen2.5-VL vision-language model with action prediction
using flow matching for continuous action spaces.
"""
config_class = WallXConfig
name = "wall_x"
def __init__(self, config: WallXConfig):
super().__init__(config)
config.validate_features()
self.config = config
# Initialize VLM wrapper
self.vlm = WallXVLMWrapper(config)
# Initialize action head
self.action_head = ActionHead(config)
self.reset()
def reset(self):
"""Reset action queue."""
self._queues = {
ACTION: deque(maxlen=self.config.n_action_steps),
}
def get_optim_params(self):
"""Get parameters for optimization."""
params = []
if self.vlm.available:
# Add VLM parameters
if not self.config.train_expert_only:
params.extend(self.vlm.model.parameters())
# Always add action head parameters
if self.config.train_action_head:
params.extend(self.action_head.parameters())
return params
def prepare_images(self, batch):
"""Prepare images for VLM processing."""
images = []
present_img_keys = [key for key in self.config.image_features if key in batch]
if len(present_img_keys) == 0:
raise ValueError("No image features found in batch")
for key in present_img_keys:
img = batch[key][:, -1, :, :, :] if batch[key].ndim == 5 else batch[key]
images.append(img)
return images
def prepare_state(self, batch):
"""Prepare proprioceptive state."""
state = batch[OBS_STATE][:, -1, :] if batch[OBS_STATE].ndim > 2 else batch[OBS_STATE]
# Pad to expected dimension
if state.shape[-1] < self.config.max_state_dim:
padding = torch.zeros(
*state.shape[:-1],
self.config.max_state_dim - state.shape[-1],
device=state.device,
dtype=state.dtype
)
state = torch.cat([state, padding], dim=-1)
return state
def prepare_action(self, batch):
"""Prepare action chunk."""
actions = batch[ACTION]
# Pad to expected dimension
if actions.shape[-1] < self.config.max_action_dim:
padding = torch.zeros(
*actions.shape[:-1],
self.config.max_action_dim - actions.shape[-1],
device=actions.device,
dtype=actions.dtype
)
actions = torch.cat([actions, padding], dim=-1)
return actions
def _create_dof_mask(self, batch_size, device, dtype):
"""Create DOF mask for action dimensions."""
# Create mask showing which dimensions are active
mask = torch.ones(
batch_size,
self.config.chunk_size,
sum(self.config.dof_config.values()),
device=device,
dtype=dtype
)
return mask
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
"""
Training forward pass.
Args:
batch: Dictionary containing observations and actions
Returns:
tuple: (loss, loss_dict)
"""
# Prepare inputs
images = self.prepare_images(batch)
state = self.prepare_state(batch)
actions = self.prepare_action(batch)
batch_size = actions.shape[0]
device = actions.device
dtype = actions.dtype
# Create DOF mask
dof_mask = self._create_dof_mask(batch_size, device, dtype)
# Process actions through action head (adds noise, gets embeddings)
action_embeds, flow_target = self.action_head(actions, dof_mask)
# For now, use simplified loss computation
# In full implementation, would pass through VLM transformer
loss_dict = {}
# Compute flow matching loss
# Note: In full wall-x, action_embeds would go through VLM transformer first
flow_loss = self.action_head.flow_loss(action_embeds, flow_target, dof_mask)
loss = flow_loss.mean()
loss_dict["loss"] = loss.item()
loss_dict["flow_loss"] = loss.item()
return loss, loss_dict
def _sample_actions_flow(self, batch: dict[str, Tensor]) -> Tensor:
"""
Sample actions using flow matching / diffusion.
Args:
batch: Dictionary containing observations
Returns:
Predicted actions [batch, chunk_size, action_dim]
"""
batch_size = 1 # Typically inference is single sample
device = self.config.device
dtype = torch.float32
# Initialize with noise
noisy_action = torch.randn(
batch_size,
self.config.chunk_size,
sum(self.config.dof_config.values()),
device=device,
dtype=dtype
)
# Create DOF mask
dof_mask = self._create_dof_mask(batch_size, device, dtype)
# ODE integration for denoising
num_steps = self.config.num_inference_timesteps
dt = 1.0 / num_steps
for step_idx in range(num_steps):
t = torch.tensor(step_idx * dt, device=device, dtype=dtype)
timestep = t.unsqueeze(0).repeat(batch_size)
# Single denoising step
action_embeds = self.action_head.step(timestep, noisy_action, dof_mask)
# Predict flow (in full implementation, would go through VLM)
flow_pred = self.action_head.action_proj_back(action_embeds)
# Euler integration step
noisy_action = noisy_action + dt * flow_pred
return noisy_action
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
"""Predict action chunk for evaluation."""
self.eval()
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
if self.config.prediction_mode == "flow":
actions = self._sample_actions_flow(batch)
else:
raise NotImplementedError(f"Prediction mode {self.config.prediction_mode} not implemented")
# Unpad actions
original_action_dim = self.config.action_feature.shape[0]
actions = actions[:, :, :original_action_dim]
return actions
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select single action for environment execution."""
self.eval()
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
# Use action queue
if len(self._queues[ACTION]) == 0:
actions = self.predict_action_chunk(batch)
self._queues[ACTION].extend(actions.transpose(0, 1)[:self.config.n_action_steps])
return self._queues[ACTION].popleft()

View File

@@ -0,0 +1,181 @@
#!/usr/bin/env python
# Copyright 2025 HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
import torch
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.policies.wall_x.configuration_wall_x import WallXConfig
from lerobot.processor import (
AddBatchDimensionProcessorStep,
ComplementaryDataProcessorStep,
DeviceProcessorStep,
NormalizerProcessorStep,
PolicyAction,
PolicyProcessorPipeline,
ProcessorStepRegistry,
RenameObservationsProcessorStep,
TokenizerProcessorStep,
UnnormalizerProcessorStep,
)
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
def make_wall_x_pre_post_processors(
config: WallXConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
"""
Constructs pre-processor and post-processor pipelines for the Wall-X policy.
The pre-processing pipeline prepares input data for the model by:
1. Renaming features to match pretrained configurations
2. Adding a batch dimension
3. Tokenizing language task descriptions
4. Normalizing input and output features based on dataset statistics
5. Moving all data to the specified device
The post-processing pipeline handles the model's output by:
1. Unnormalizing the output actions to their original scale
2. Moving data to the CPU
Args:
config: The configuration object for the Wall-X policy
dataset_stats: A dictionary of statistics for normalization
Returns:
A tuple containing the configured pre-processor and post-processor pipelines
"""
# Try to use Qwen processor if available
try:
from transformers import AutoProcessor
tokenizer_name = config.vlm_model_name
qwen_available = True
except ImportError:
tokenizer_name = "Qwen/Qwen2-VL-2B-Instruct" # Fallback
qwen_available = False
input_steps = [
RenameObservationsProcessorStep(rename_map={}),
AddBatchDimensionProcessorStep(),
WallXTaskProcessor(), # Process task description
TokenizerProcessorStep(
tokenizer_name=tokenizer_name,
padding="max_length",
padding_side="right",
max_length=config.tokenizer_max_length,
),
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
DeviceProcessorStep(device=config.device),
]
output_steps = [
UnnormalizerProcessorStep(
features=config.output_features,
norm_map=config.normalization_mapping,
stats=dataset_stats
),
DeviceProcessorStep(device="cpu"),
]
return (
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
steps=input_steps,
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
),
PolicyProcessorPipeline[PolicyAction, PolicyAction](
steps=output_steps,
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
to_transition=policy_action_to_transition,
to_output=transition_to_policy_action,
),
)
@ProcessorStepRegistry.register(name="wall_x_task_processor")
class WallXTaskProcessor(ComplementaryDataProcessorStep):
"""
A processor step that ensures the task description is properly formatted for Wall-X.
This step handles task preprocessing similar to Qwen-VL requirements.
"""
def complementary_data(self, complementary_data):
if "task" not in complementary_data:
return complementary_data
task = complementary_data["task"]
if task is None:
# Provide default task if none specified
complementary_data["task"] = "Execute the robot action."
return complementary_data
new_complementary_data = dict(complementary_data)
# Handle both string and list of strings
if isinstance(task, str):
# Single string: ensure proper formatting
if not task.endswith("."):
new_complementary_data["task"] = f"{task}."
elif isinstance(task, list) and all(isinstance(t, str) for t in task):
# List of strings: format each
new_complementary_data["task"] = [
t if t.endswith(".") else f"{t}." for t in task
]
return new_complementary_data
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
return features
@ProcessorStepRegistry.register(name="wall_x_image_processor")
class WallXImageProcessor(ComplementaryDataProcessorStep):
"""
Image processor for Wall-X using Qwen-VL vision processing.
This handles image formatting according to Qwen-VL requirements.
"""
def __init__(self):
super().__init__()
try:
from transformers import AutoProcessor
self.processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
self.available = True
except ImportError:
self.available = False
def complementary_data(self, complementary_data):
# Image processing is handled by the VLM processor
return complementary_data
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
return features