add franka action

This commit is contained in:
Jade Choghari
2025-11-07 14:28:36 +01:00
parent 8a65623dec
commit 3cb14248a4
10 changed files with 508 additions and 458 deletions

View File

@@ -39,8 +39,8 @@ from lerobot.policies.sac.reward_model.configuration_classifier import RewardCla
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.policies.utils import validate_visual_features_consistency
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
from lerobot.processor.converters import (
batch_to_transition,

View File

@@ -15,18 +15,21 @@
# ------------------------------------------------------------------------------
from __future__ import annotations
from typing import Iterable, Tuple, Dict, Type
from collections.abc import Iterable
import torch
import torch.nn as nn
# =============================================================================
# Registry
# =============================================================================
ACTION_REGISTRY: Dict[str, Type["BaseActionSpace"]] = {}
ACTION_REGISTRY: dict[str, type[BaseActionSpace]] = {}
def register_action(name: str):
"""Decorator for registering a new action space."""
def _wrap(cls):
key = name.lower()
if key in ACTION_REGISTRY:
@@ -34,10 +37,11 @@ def register_action(name: str):
ACTION_REGISTRY[key] = cls
cls.name = key
return cls
return _wrap
def build_action_space(name: str, **kwargs) -> "BaseActionSpace":
def build_action_space(name: str, **kwargs) -> BaseActionSpace:
"""Instantiate a registered action space by name."""
key = name.lower()
if key not in ACTION_REGISTRY:
@@ -62,7 +66,7 @@ class BaseActionSpace(nn.Module):
name: str = "base"
dim_action: int = 0
gripper_idx: Tuple[int, ...] = ()
gripper_idx: tuple[int, ...] = ()
def __init__(self):
super().__init__()
@@ -70,10 +74,10 @@ class BaseActionSpace(nn.Module):
# ---------------------------------------------------------------------
# Core supervised loss
# ---------------------------------------------------------------------
def compute_loss(self, pred: torch.Tensor, target: torch.Tensor) -> Dict[str, torch.Tensor]:
def compute_loss(self, pred: torch.Tensor, target: torch.Tensor) -> dict[str, torch.Tensor]:
raise NotImplementedError
def forward(self, pred: torch.Tensor, target: torch.Tensor) -> Dict[str, torch.Tensor]:
def forward(self, pred: torch.Tensor, target: torch.Tensor) -> dict[str, torch.Tensor]:
"""Alias for compute_loss."""
return self.compute_loss(pred, target)
@@ -85,7 +89,7 @@ class BaseActionSpace(nn.Module):
proprio: torch.Tensor,
action: torch.Tensor,
mode: str = "train",
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
"""Default: return unchanged."""
return proprio, action
@@ -137,14 +141,14 @@ class EE6DActionSpace(BaseActionSpace):
# XYZ position
pos_loss = (
self.mse(pred[:, :, self.POS_IDX_1], target[:, :, self.POS_IDX_1]) +
self.mse(pred[:, :, self.POS_IDX_2], target[:, :, self.POS_IDX_2])
self.mse(pred[:, :, self.POS_IDX_1], target[:, :, self.POS_IDX_1])
+ self.mse(pred[:, :, self.POS_IDX_2], target[:, :, self.POS_IDX_2])
) * self.XYZ_SCALE
# Rotation 6D
rot_loss = (
self.mse(pred[:, :, self.ROT_IDX_1], target[:, :, self.ROT_IDX_1]) +
self.mse(pred[:, :, self.ROT_IDX_2], target[:, :, self.ROT_IDX_2])
self.mse(pred[:, :, self.ROT_IDX_1], target[:, :, self.ROT_IDX_1])
+ self.mse(pred[:, :, self.ROT_IDX_2], target[:, :, self.ROT_IDX_2])
) * self.ROT_SCALE
return {
@@ -236,14 +240,16 @@ class AGIBOTEE6DActionSpace(BaseActionSpace):
B, T, D = pred.shape
_ensure_indices_valid(D, self.gripper_idx, "gripper_idx")
gripper_loss = self.mse(pred[:, :, self.gripper_idx], target[:, :, self.gripper_idx]) * self.GRIPPER_SCALE
gripper_loss = (
self.mse(pred[:, :, self.gripper_idx], target[:, :, self.gripper_idx]) * self.GRIPPER_SCALE
)
pos_loss = (
self.mse(pred[:, :, self.POS_IDX_1], target[:, :, self.POS_IDX_1]) +
self.mse(pred[:, :, self.POS_IDX_2], target[:, :, self.POS_IDX_2])
self.mse(pred[:, :, self.POS_IDX_1], target[:, :, self.POS_IDX_1])
+ self.mse(pred[:, :, self.POS_IDX_2], target[:, :, self.POS_IDX_2])
) * self.XYZ_SCALE
rot_loss = (
self.mse(pred[:, :, self.ROT_IDX_1], target[:, :, self.ROT_IDX_1]) +
self.mse(pred[:, :, self.ROT_IDX_2], target[:, :, self.ROT_IDX_2])
self.mse(pred[:, :, self.ROT_IDX_1], target[:, :, self.ROT_IDX_1])
+ self.mse(pred[:, :, self.ROT_IDX_2], target[:, :, self.ROT_IDX_2])
) * self.ROT_SCALE
return {
@@ -261,6 +267,32 @@ class AGIBOTEE6DActionSpace(BaseActionSpace):
return action
@register_action("franka_joint7")
class FrankaJoint7ActionSpace(BaseActionSpace):
    """Franka Panda joint-space action: 7 joints, no gripper.

    The supervised loss is a single MSE term computed over the whole action
    tensor and multiplied by ``JOINTS_SCALE``. ``preprocess`` and
    ``postprocess`` are identity mappings.
    """

    # Number of action dimensions (one per joint).
    dim_action = 7
    # Multiplier applied to the MSE term in ``compute_loss``.
    JOINTS_SCALE = 1.0

    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()

    def compute_loss(self, pred: torch.Tensor, target: torch.Tensor) -> dict[str, torch.Tensor]:
        """Return the scaled MSE between ``pred`` and ``target``.

        Args:
            pred: Predicted actions; must be a 3-D tensor
                (presumably ``[batch, time, dim]`` -- confirm against caller).
            target: Ground-truth actions with exactly the same shape as ``pred``.

        Returns:
            A dict with the single key ``"joints_loss"``.

        Raises:
            AssertionError: If ``pred`` and ``target`` shapes differ.
            ValueError: If ``pred`` is not 3-dimensional.
        """
        # Raised explicitly rather than via ``assert`` so the check is not
        # stripped under ``python -O``; the exception type and message are
        # unchanged for existing callers.
        if pred.shape != target.shape:
            raise AssertionError("pred/target shapes must match")
        # Explicit rank check in place of unpacking the shape into three
        # otherwise unused locals; a wrong rank still raises ValueError.
        if pred.dim() != 3:
            raise ValueError(f"expected a 3-D action tensor, got shape {tuple(pred.shape)}")

        joints_loss = self.mse(pred, target) * self.JOINTS_SCALE
        return {"joints_loss": joints_loss}

    def preprocess(
        self,
        proprio: torch.Tensor,
        action: torch.Tensor,
        mode: str = "train",
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``proprio`` and ``action`` unchanged; ``mode`` is ignored."""
        return proprio, action

    def postprocess(self, action: torch.Tensor) -> torch.Tensor:
        """Return ``action`` unchanged (no sigmoid since there is no gripper)."""
        return action
# =============================================================================
# Exports
# =============================================================================
@@ -271,5 +303,6 @@ __all__ = [
"EE6DActionSpace",
"JointActionSpace",
"AGIBOTEE6DActionSpace",
"FrankaJoint7ActionSpace",
"ACTION_REGISTRY",
]
]

View File

@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,20 +11,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
""" Florence-2 configuration"""
from typing import Optional
from transformers import AutoConfig
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
logger = logging.get_logger(__name__)
class Florence2VisionConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Florence2VisionModel`]. It is used to instantiate a Florence2VisionModel
according to the specified arguments, defining the model architecture. Instantiating a configuration with the
according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the Florence2VisionModel architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
@@ -118,7 +117,6 @@ class Florence2VisionConfig(PretrainedConfig):
super().__init__(**kwargs)
class Florence2LanguageConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Florence2LanguagePreTrainedModel`]. It is used to instantiate a BART
@@ -269,10 +267,11 @@ class Florence2LanguageConfig(PretrainedConfig):
"The config can simply be saved and uploaded again to be fixed."
)
class Florence2Config(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Florence2ForConditionalGeneration`]. It is used to instantiate a
Florence-2 model according to the specified arguments, defining the model architecture.
Florence-2 model according to the specified arguments, defining the model architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
@@ -281,7 +280,7 @@ class Florence2Config(PretrainedConfig):
vision_config (`Florence2VisionConfig`, *optional*):
Custom vision config or dict
text_config (`Union[AutoConfig, dict]`, *optional*):
The config object of the text backbone.
The config object of the text backbone.
ignore_index (`int`, *optional*, defaults to -100):
The ignore index for the loss function.
vocab_size (`int`, *optional*, defaults to 51289):
@@ -335,5 +334,4 @@ class Florence2Config(PretrainedConfig):
if text_config is not None:
self.text_config = Florence2LanguageConfig(**text_config)
super().__init__(**kwargs)

View File

@@ -124,43 +124,37 @@ class XVLAConfig(PreTrainedConfig):
# TODO: jadechoghari: provide default way, and do not hardcode
# Ensure vision_config and text_config are provided with defaults if not specified
config_dict = dict(self.florence_config)
if 'vision_config' not in config_dict or config_dict['vision_config'] is None:
if "vision_config" not in config_dict or config_dict["vision_config"] is None:
# Provide default vision config
config_dict['vision_config'] = {
'model_type': 'davit',
'drop_path_rate': 0.1,
'patch_size': [14, 7, 7, 7],
'patch_stride': [4, 2, 2, 2],
'patch_padding': [3, 1, 1, 1],
'patch_prenorm': [False, True, True, True],
'dim_embed': [256, 512, 1024, 2048],
'num_heads': [8, 16, 32, 64],
'num_groups': [8, 16, 32, 64],
'depths': [1, 1, 9, 1],
'window_size': 12,
'projection_dim': 1024,
'visual_temporal_embedding': {
'type': 'COSINE',
'max_temporal_embeddings': 100
},
'image_pos_embed': {
'type': 'learned_abs_2d',
'max_pos_embeddings': 50
},
'image_feature_source': ['spatial_avg_pool', 'temporal_avg_pool']
config_dict["vision_config"] = {
"model_type": "davit",
"drop_path_rate": 0.1,
"patch_size": [14, 7, 7, 7],
"patch_stride": [4, 2, 2, 2],
"patch_padding": [3, 1, 1, 1],
"patch_prenorm": [False, True, True, True],
"dim_embed": [256, 512, 1024, 2048],
"num_heads": [8, 16, 32, 64],
"num_groups": [8, 16, 32, 64],
"depths": [1, 1, 9, 1],
"window_size": 12,
"projection_dim": 1024,
"visual_temporal_embedding": {"type": "COSINE", "max_temporal_embeddings": 100},
"image_pos_embed": {"type": "learned_abs_2d", "max_pos_embeddings": 50},
"image_feature_source": ["spatial_avg_pool", "temporal_avg_pool"],
}
if 'text_config' not in config_dict or config_dict['text_config'] is None:
if "text_config" not in config_dict or config_dict["text_config"] is None:
# Provide default text config
config_dict['text_config'] = {
'model_type': 'florence2_language',
'vocab_size': 51289,
'd_model': 1024,
'encoder_layers': 12,
'decoder_layers': 12,
'encoder_attention_heads': 16,
'decoder_attention_heads': 16,
'encoder_ffn_dim': 4096,
'decoder_ffn_dim': 4096,
config_dict["text_config"] = {
"model_type": "florence2_language",
"vocab_size": 51289,
"d_model": 1024,
"encoder_layers": 12,
"decoder_layers": 12,
"encoder_attention_heads": 16,
"decoder_attention_heads": 16,
"encoder_ffn_dim": 4096,
"decoder_ffn_dim": 4096,
}
self._florence_config_obj = Florence2Config(**config_dict)
return self._florence_config_obj

File diff suppressed because it is too large Load Diff

View File

@@ -19,7 +19,6 @@
from __future__ import annotations
from collections import deque
from typing import Dict
import torch
import torch.nn.functional as F # noqa: N812
@@ -87,7 +86,7 @@ class XVLAModel(nn.Module):
input_ids: torch.LongTensor,
pixel_values: torch.FloatTensor,
image_mask: torch.Tensor,
) -> Dict[str, torch.Tensor]:
) -> dict[str, torch.Tensor]:
"""
Encode text and multi-view images via Florence2 encoder.
"""
@@ -129,13 +128,14 @@ class XVLAModel(nn.Module):
domain_id: torch.LongTensor,
proprio: torch.Tensor,
action: torch.Tensor,
) -> Dict[str, torch.Tensor]:
) -> dict[str, torch.Tensor]:
enc = self.forward_vlm(input_ids, image_input, image_mask)
batch_size = input_ids.shape[0]
t = (torch.rand(1, device=input_ids.device) + torch.arange(batch_size, device=input_ids.device) / batch_size) % (
1 - 1e-5
)
t = (
torch.rand(1, device=input_ids.device)
+ torch.arange(batch_size, device=input_ids.device) / batch_size
) % (1 - 1e-5)
action_noisy = torch.randn_like(action) * t.view(-1, 1, 1) + action * (1 - t).view(-1, 1, 1)
proprio_m, action_noisy_m = self.action_space.preprocess(proprio, action_noisy)
@@ -350,7 +350,9 @@ def resize_with_pad(img: torch.Tensor, height: int, width: int, pad_value: float
ratio = max(current_width / width, current_height / height)
resized_height = int(current_height / ratio)
resized_width = int(current_width / ratio)
resized_img = F.interpolate(img, size=(resized_height, resized_width), mode="bilinear", align_corners=False)
resized_img = F.interpolate(
img, size=(resized_height, resized_width), mode="bilinear", align_corners=False
)
pad_height = max(0, height - resized_height)
pad_width = max(0, width - resized_width)

View File

@@ -14,7 +14,7 @@
# limitations under the License.
# ------------------------------------------------------------------------------
from typing import Any, Dict, List, Optional, Union
from typing import Any
import torch
from transformers import ProcessorMixin
@@ -88,7 +88,7 @@ class XVLAProcessor(ProcessorMixin):
super().__init__(image_processor, tokenizer)
# ================== LANGUAGE ENCODING ==================
def encode_language(self, language_instruction: Union[str, List[str]]) -> Dict[str, torch.Tensor]:
def encode_language(self, language_instruction: str | list[str]) -> dict[str, torch.Tensor]:
"""
Tokenize one or more language instructions.
@@ -117,11 +117,7 @@ class XVLAProcessor(ProcessorMixin):
return {"input_ids": inputs["input_ids"]}
# ================== IMAGE ENCODING ==================
def encode_image(
self,
images: Union[List, List[List]],
**kwargs
) -> Dict[str, torch.Tensor]:
def encode_image(self, images: list | list[list], **kwargs) -> dict[str, torch.Tensor]:
"""
Preprocess one or more sets of multi-view images.
@@ -157,8 +153,7 @@ class XVLAProcessor(ProcessorMixin):
# Pad to self.num_views
if V_exist < self.num_views:
processed = torch.cat(
[processed,
processed.new_zeros(self.num_views - V_exist, *processed.shape[1:])],
[processed, processed.new_zeros(self.num_views - V_exist, *processed.shape[1:])],
dim=0,
)
@@ -177,10 +172,10 @@ class XVLAProcessor(ProcessorMixin):
# ================== COMBINED CALL ==================
def __call__(
self,
images: Optional[Union[List, List[List]]] = None,
language_instruction: Optional[Union[str, List[str]]] = None,
**kwargs
) -> Dict[str, torch.Tensor]:
images: list | list[list] | None = None,
language_instruction: str | list[str] | None = None,
**kwargs,
) -> dict[str, torch.Tensor]:
"""
Combine image and text encoding into a unified multimodal input.
@@ -202,7 +197,7 @@ class XVLAProcessor(ProcessorMixin):
"image_mask": [B, num_views], optional
}
"""
outputs: Dict[str, Any] = {}
outputs: dict[str, Any] = {}
# Encode language if provided
if language_instruction is not None:
@@ -243,7 +238,9 @@ def make_xvla_pre_post_processors(
padding_side=config.tokenizer_padding_side,
),
DeviceProcessorStep(device=config.device),
NormalizerProcessorStep(features=features, norm_map=config.normalization_mapping, stats=dataset_stats),
NormalizerProcessorStep(
features=features, norm_map=config.normalization_mapping, stats=dataset_stats
),
]
output_steps = [
UnnormalizerProcessorStep(

View File

@@ -17,17 +17,18 @@
from __future__ import annotations
import math
from collections.abc import Iterable
from functools import partial
from typing import Final, Iterable, Tuple
from typing import Final
import torch
import torch.nn as nn
import torch.nn.functional as F
# ------------------------------- Small utils ----------------------------------
def _to_2tuple(x) -> Tuple:
def _to_2tuple(x) -> tuple:
"""Minimal replacement for timm.layers.to_2tuple."""
if isinstance(x, Iterable) and not isinstance(x, (str, bytes)):
t = tuple(x)
@@ -42,6 +43,7 @@ def _has_sdp_attention() -> bool:
# ---------------------------------- MLP --------------------------------------
class Mlp(nn.Module):
"""
MLP used in ViT-style blocks.
@@ -55,8 +57,8 @@ class Mlp(nn.Module):
hidden_features: int | None = None,
out_features: int | None = None,
norm_layer: type[nn.Module] | None = None,
bias: bool | Tuple[bool, bool] = True,
drop: float | Tuple[float, float] = 0.0,
bias: bool | tuple[bool, bool] = True,
drop: float | tuple[float, float] = 0.0,
use_conv: bool = False,
) -> None:
super().__init__()
@@ -86,6 +88,7 @@ class Mlp(nn.Module):
# -------------------------------- Attention ----------------------------------
class Attention(nn.Module):
"""
Multi-Head Self-Attention with optional fused SDPA fallback.
@@ -110,7 +113,7 @@ class Attention(nn.Module):
assert dim % num_heads == 0, "dim should be divisible by num_heads"
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.scale = self.head_dim**-0.5
self.fused_attn = _has_sdp_attention()
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
@@ -143,17 +146,19 @@ class Attention(nn.Module):
if self.fused_attn:
x = F.scaled_dot_product_attention(
q, k, v,
q,
k,
v,
dropout_p=self.attn_drop.p if self.training else 0.0,
) # [B, H, T, Dh]
else:
q = q * self.scale
attn = q @ k.transpose(-2, -1) # [B, H, T, T]
attn = q @ k.transpose(-2, -1) # [B, H, T, T]
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = attn @ v # [B, H, T, Dh]
x = attn @ v # [B, H, T, Dh]
x = x.transpose(1, 2).reshape(B, T, C) # [B, T, C]
x = x.transpose(1, 2).reshape(B, T, C) # [B, T, C]
x = self.proj(x)
x = self.proj_drop(x)
return x
@@ -161,6 +166,7 @@ class Attention(nn.Module):
# ------------------------------- Utilities -----------------------------------
def basic_init(module: nn.Module) -> None:
"""
Apply a basic initialization scheme to Linear layers.
@@ -194,9 +200,7 @@ def timestep_embedding(t: torch.Tensor, dim: int, max_period: int = 100) -> torc
"""
half = dim // 2
freqs = torch.exp(
-math.log(max_period)
* torch.arange(start=0, end=half, dtype=t.dtype, device=t.device)
/ half
-math.log(max_period) * torch.arange(start=0, end=half, dtype=t.dtype, device=t.device) / half
)
args = t[:, None] * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
@@ -207,6 +211,7 @@ def timestep_embedding(t: torch.Tensor, dim: int, max_period: int = 100) -> torc
# ------------------------------- Core Layers ----------------------------------
class DomainAwareLinear(nn.Module):
"""
Linear layer with domain-conditioned parameters (per-sample).
@@ -283,6 +288,7 @@ class TransformerBlock(nn.Module):
# --------------------------- Main Model ---------------------------------------
class SoftPromptedTransformer(nn.Module):
"""
Multi-modal, domain-aware Transformer with optional soft prompts.
@@ -318,7 +324,9 @@ class SoftPromptedTransformer(nn.Module):
if use_hetero_proj:
self.vlm_proj = DomainAwareLinear(multi_modal_input_size, hidden_size, num_domains=num_domains)
self.aux_visual_proj = DomainAwareLinear(multi_modal_input_size, hidden_size, num_domains=num_domains)
self.aux_visual_proj = DomainAwareLinear(
multi_modal_input_size, hidden_size, num_domains=num_domains
)
else:
self.vlm_proj = nn.Linear(multi_modal_input_size, hidden_size)
self.aux_visual_proj = nn.Linear(multi_modal_input_size, hidden_size)
@@ -367,16 +375,20 @@ class SoftPromptedTransformer(nn.Module):
B, num_actions = action_with_noise.shape[:2]
# Encode (action + proprio + time) → tokens
time_emb = timestep_embedding(t, self.dim_time) # [B, dim_time]
time_emb = timestep_embedding(t, self.dim_time) # [B, dim_time]
time_tokens = time_emb.unsqueeze(1).expand(B, num_actions, self.dim_time)
proprio_tokens = proprio.unsqueeze(1).expand(B, num_actions, proprio.shape[-1])
action_tokens = torch.cat([action_with_noise, proprio_tokens, time_tokens], dim=-1)
x = self.action_encoder(action_tokens, domain_id) # [B, T_action, H]
x = self.action_encoder(action_tokens, domain_id) # [B, T_action, H]
# Project visual streams and concatenate
if self.use_hetero_proj:
x = torch.cat(
[x, self.vlm_proj(vlm_features, domain_id), self.aux_visual_proj(aux_visual_inputs, domain_id)],
[
x,
self.vlm_proj(vlm_features, domain_id),
self.aux_visual_proj(aux_visual_inputs, domain_id),
],
dim=1,
)
else:
@@ -385,9 +397,7 @@ class SoftPromptedTransformer(nn.Module):
# Add positional embeddings (truncate if needed)
seq_len = x.shape[1]
if seq_len > self.pos_emb.shape[1]:
raise ValueError(
f"Sequence length {seq_len} exceeds max_len_seq={self.pos_emb.shape[1]}."
)
raise ValueError(f"Sequence length {seq_len} exceeds max_len_seq={self.pos_emb.shape[1]}.")
x = x + self.pos_emb[:, :seq_len, :]
# Append soft prompts
@@ -400,4 +410,4 @@ class SoftPromptedTransformer(nn.Module):
x = block(x)
# Decode only the action segment
return self.action_decoder(self.norm(x[:, :num_actions]), domain_id)
return self.action_decoder(self.norm(x[:, :num_actions]), domain_id)

View File

@@ -1,6 +1,5 @@
from lerobot.policies.factory import make_policy
from lerobot.policies.factory import make_policy_config
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.policies.factory import make_policy, make_policy_config
cfg = make_policy_config("xvla")
@@ -8,4 +7,4 @@ dataset_id = "lerobot/svla_so101_pickplace"
# This only downloads the metadata for the dataset, ~10s of MB even for large-scale datasets
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
policy = make_policy(cfg=cfg, ds_meta=dataset_metadata)
print(policy)
print(policy)

View File

@@ -4,6 +4,7 @@ lerobot-train \
--output_dir=outputs/train/act_your_dataset \
--job_name=xvla_so101_pickplace \
--policy.device=cuda \
--policy.action_mode=franka_joint7 \
--wandb.enable=true \
--policy.repo_id=jadechoghari/xvla_policy \
--steps=10000