mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-30 18:31:25 +00:00
trying to close success rate gap
This commit is contained in:
committed by
Maximellerbach
parent
7e23859c55
commit
8efa5cabe9
@@ -18,7 +18,7 @@ class VLAJEPAConfig(PreTrainedConfig):
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.MEAN_STD,
|
||||
"ACTION": NormalizationMode.MIN_MAX,
|
||||
}
|
||||
)
|
||||
@@ -66,6 +66,7 @@ class VLAJEPAConfig(PreTrainedConfig):
|
||||
|
||||
resize_images_to: tuple[int, int] | None = None
|
||||
binarize_gripper_action: bool = True
|
||||
pre_snap_gripper_action: bool = True
|
||||
clip_normalized_actions: bool = True
|
||||
torch_dtype: str = "bfloat16"
|
||||
|
||||
|
||||
@@ -76,25 +76,141 @@ def _map_checkpoint_key(raw_key: str) -> str | None:
|
||||
return None
|
||||
|
||||
|
||||
def _fetch_action_stats(api: HfApi, source_repo_id: str, subfolder: str) -> dict | None:
|
||||
"""Download dataset_statistics.json and return the action stats dict."""
|
||||
def _fetch_dataset_stats(api: HfApi, source_repo_id: str, subfolder: str) -> dict | None:
|
||||
"""Download dataset_statistics.json and return {action: {...}, state: {...}} stats dict."""
|
||||
import json
|
||||
|
||||
stats_file = f"{subfolder}/dataset_statistics.json"
|
||||
try:
|
||||
local = api.hf_hub_download(source_repo_id, stats_file)
|
||||
data = json.loads(Path(local).read_text())
|
||||
# Original repo nests stats under a robot key, e.g. {"franka": {"action": {...}}}
|
||||
# Original repo nests stats under a robot key, e.g. {"franka": {"action": {...}, "state": {...}}}
|
||||
for robot_key in data:
|
||||
if isinstance(data[robot_key], dict) and "action" in data[robot_key]:
|
||||
log.info(" Loaded action stats from %s (robot key: %s)", stats_file, robot_key)
|
||||
return data[robot_key]["action"]
|
||||
log.warning(" %s found but no 'action' key under any robot key — skipping action stats.", stats_file)
|
||||
robot_data = data[robot_key]
|
||||
if isinstance(robot_data, dict) and "action" in robot_data:
|
||||
log.info(" Loaded dataset stats from %s (robot key: %s)", stats_file, robot_key)
|
||||
result = {"action": robot_data["action"]}
|
||||
if "state" in robot_data:
|
||||
result["observation.state"] = robot_data["state"]
|
||||
log.info(" Also loaded state stats.")
|
||||
return result
|
||||
log.warning(" %s found but no 'action' key under any robot key — skipping stats.", stats_file)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning(" Could not fetch %s: %s — postprocessor will have no unnorm stats.", stats_file, exc)
|
||||
return None
|
||||
|
||||
|
||||
def _set_if_present(d: dict, key: str, value) -> None:
|
||||
if value is not None:
|
||||
d[key] = value
|
||||
|
||||
|
||||
def _deep_get(mapping: dict, path: tuple, default=None):
|
||||
current = mapping
|
||||
for key in path:
|
||||
if not isinstance(current, dict) or key not in current:
|
||||
return default
|
||||
current = current[key]
|
||||
return current
|
||||
|
||||
|
||||
def _fetch_source_config(api: HfApi, source_repo_id: str, subfolder: str) -> dict:
|
||||
"""Download config.yaml from the source HF repo for a given variant subfolder."""
|
||||
try:
|
||||
import yaml
|
||||
except ImportError:
|
||||
log.warning("PyYAML not installed — cannot apply source config.yaml overrides.")
|
||||
return {}
|
||||
config_file = f"{subfolder}/config.yaml"
|
||||
try:
|
||||
local = api.hf_hub_download(source_repo_id, config_file)
|
||||
data = yaml.safe_load(Path(local).read_text()) or {}
|
||||
if isinstance(data, dict):
|
||||
log.info(" Loaded source config from %s", config_file)
|
||||
return data
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning(" Could not fetch %s: %s — using hardcoded defaults.", config_file, exc)
|
||||
return {}
|
||||
|
||||
|
||||
def _apply_source_config(kwargs: dict, source_config: dict) -> None:
|
||||
"""Apply ginwind/VLA-JEPA config.yaml values to kwargs, mirroring todo_converter.py logic."""
|
||||
if not source_config:
|
||||
return
|
||||
|
||||
data_cfg = _deep_get(source_config, ("datasets", "vla_data"), {})
|
||||
action_cfg = _deep_get(source_config, ("framework", "action_model"), {})
|
||||
diffusion_cfg = _deep_get(source_config, ("framework", "action_model", "diffusion_model_cfg"), {})
|
||||
video_cfg = _deep_get(source_config, ("framework", "vj2_model"), {})
|
||||
trainer_cfg = source_config.get("trainer", {})
|
||||
|
||||
prompt_template = data_cfg.get("CoT_prompt")
|
||||
if prompt_template:
|
||||
kwargs["prompt_template"] = str(prompt_template)
|
||||
|
||||
action_horizon = action_cfg.get("action_horizon")
|
||||
if action_horizon is not None:
|
||||
kwargs["chunk_size"] = int(action_horizon)
|
||||
kwargs["n_action_steps"] = int(action_horizon)
|
||||
|
||||
_set_if_present(
|
||||
kwargs,
|
||||
"num_action_tokens_per_timestep",
|
||||
video_cfg.get("num_action_tokens_per_timestep", action_cfg.get("num_action_tokens_per_timestep")),
|
||||
)
|
||||
_set_if_present(
|
||||
kwargs,
|
||||
"num_embodied_action_tokens_per_instruction",
|
||||
video_cfg.get(
|
||||
"num_embodied_action_tokens_per_instruction",
|
||||
action_cfg.get("num_embodied_action_tokens_per_instruction"),
|
||||
),
|
||||
)
|
||||
_set_if_present(kwargs, "num_inference_timesteps", action_cfg.get("num_inference_timesteps"))
|
||||
_set_if_present(kwargs, "special_action_token", video_cfg.get("special_action_token"))
|
||||
_set_if_present(kwargs, "embodied_action_token", video_cfg.get("embodied_action_token"))
|
||||
_set_if_present(
|
||||
kwargs, "action_hidden_size", action_cfg.get("action_hidden_dim", action_cfg.get("hidden_size"))
|
||||
)
|
||||
_set_if_present(kwargs, "action_model_type", action_cfg.get("action_model_type"))
|
||||
_set_if_present(kwargs, "action_noise_beta_alpha", action_cfg.get("noise_beta_alpha"))
|
||||
_set_if_present(kwargs, "action_noise_beta_beta", action_cfg.get("noise_beta_beta"))
|
||||
_set_if_present(kwargs, "action_noise_s", action_cfg.get("noise_s"))
|
||||
_set_if_present(kwargs, "action_num_timestep_buckets", action_cfg.get("num_timestep_buckets"))
|
||||
_set_if_present(kwargs, "repeated_diffusion_steps", action_cfg.get("repeated_diffusion_steps"))
|
||||
_set_if_present(kwargs, "action_num_layers", diffusion_cfg.get("num_layers"))
|
||||
_set_if_present(kwargs, "action_dropout", diffusion_cfg.get("dropout"))
|
||||
|
||||
_set_if_present(kwargs, "num_video_frames", video_cfg.get("num_frames"))
|
||||
_set_if_present(
|
||||
kwargs, "predictor_depth", video_cfg.get("predictor_depth", video_cfg.get("depth"))
|
||||
)
|
||||
_set_if_present(
|
||||
kwargs, "predictor_num_heads", video_cfg.get("predictor_num_heads", video_cfg.get("num_heads"))
|
||||
)
|
||||
_set_if_present(kwargs, "predictor_mlp_ratio", video_cfg.get("predictor_mlp_ratio"))
|
||||
|
||||
_set_if_present(kwargs, "optimizer_grad_clip_norm", trainer_cfg.get("max_grad_norm"))
|
||||
learning_rate = trainer_cfg.get("learning_rate", {})
|
||||
if isinstance(learning_rate, dict):
|
||||
_set_if_present(kwargs, "optimizer_lr", learning_rate.get("action_model"))
|
||||
optimizer_cfg = trainer_cfg.get("optimizer", {})
|
||||
if isinstance(optimizer_cfg, dict):
|
||||
_set_if_present(kwargs, "optimizer_eps", optimizer_cfg.get("eps"))
|
||||
_set_if_present(kwargs, "optimizer_weight_decay", optimizer_cfg.get("weight_decay"))
|
||||
betas = optimizer_cfg.get("betas")
|
||||
if betas is not None:
|
||||
kwargs["optimizer_betas"] = tuple(betas)
|
||||
scheduler = trainer_cfg.get("scheduler", {})
|
||||
if isinstance(scheduler, dict):
|
||||
_set_if_present(kwargs, "scheduler_warmup_steps", scheduler.get("warmup_steps"))
|
||||
_set_if_present(kwargs, "scheduler_decay_lr", scheduler.get("min_lr"))
|
||||
_set_if_present(kwargs, "scheduler_warmup_steps", trainer_cfg.get("num_warmup_steps"))
|
||||
scheduler_kwargs = trainer_cfg.get("scheduler_specific_kwargs", {})
|
||||
if isinstance(scheduler_kwargs, dict):
|
||||
_set_if_present(kwargs, "scheduler_decay_lr", scheduler_kwargs.get("min_lr"))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Architecture — identical across all 4 variants (from config.json)
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -129,7 +245,7 @@ _ARCH = {
|
||||
# LIBERO — confirmed from lerobot/libero_10 meta/info.json
|
||||
_LIBERO_CAMS = [
|
||||
"observation.images.image", # agentview camera
|
||||
"observation.images.wrist_image", # eye-in-hand camera
|
||||
"observation.images.image2", # eye-in-hand camera
|
||||
]
|
||||
|
||||
# DROID pretrain — 2 views match the predictor embed_dim=2 × 1024=2048 in checkpoint
|
||||
@@ -153,23 +269,41 @@ def _build_config(
|
||||
camera_keys: list[str],
|
||||
with_state: bool,
|
||||
enable_world_model: bool = True,
|
||||
source_config: dict | None = None,
|
||||
):
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig
|
||||
|
||||
input_features = {k: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)) for k in camera_keys}
|
||||
kwargs = dict(_ARCH)
|
||||
_apply_source_config(kwargs, source_config or {})
|
||||
|
||||
# Image resolution: prefer source config, fall back to 224
|
||||
data_cfg = _deep_get(source_config or {}, ("datasets", "vla_data"), {})
|
||||
raw_res = data_cfg.get("resolution_size")
|
||||
resolution_size = int(raw_res) if raw_res is not None else 224
|
||||
image_shape = (3, resolution_size, resolution_size)
|
||||
# Always set resize_images_to so the policy resizes env images to the training resolution,
|
||||
# regardless of what resolution the eval env renders at.
|
||||
kwargs["resize_images_to"] = (resolution_size, resolution_size)
|
||||
|
||||
# State / action dims: prefer source config
|
||||
action_cfg = _deep_get(source_config or {}, ("framework", "action_model"), {})
|
||||
state_dim = int(action_cfg["state_dim"]) if "state_dim" in action_cfg else 8
|
||||
action_dim = int(action_cfg["action_dim"]) if "action_dim" in action_cfg else 7
|
||||
|
||||
input_features = {k: PolicyFeature(type=FeatureType.VISUAL, shape=image_shape) for k in camera_keys}
|
||||
if with_state:
|
||||
input_features["observation.state"] = PolicyFeature(type=FeatureType.STATE, shape=(8,))
|
||||
input_features["observation.state"] = PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))
|
||||
|
||||
cfg = VLAJEPAConfig(
|
||||
input_features=input_features,
|
||||
output_features={
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,)),
|
||||
},
|
||||
enable_world_model=enable_world_model,
|
||||
binarize_gripper_action=True,
|
||||
clip_normalized_actions=True,
|
||||
**_ARCH,
|
||||
**kwargs,
|
||||
)
|
||||
cfg.validate_features()
|
||||
return cfg
|
||||
@@ -276,13 +410,12 @@ def main() -> None:
|
||||
log.info(" Skipped sample: %s", skipped_keys[:5])
|
||||
log.info(" First 5 mapped keys: %s", list(mapped_sd)[:5])
|
||||
|
||||
# 3. Fetch action stats (min/max per dim) needed by the postprocessor unnormalizer
|
||||
action_stats_raw = _fetch_action_stats(api, SOURCE_REPO_ID, subfolder)
|
||||
# Wrap as {"action": {...}} for UnnormalizerProcessorStep
|
||||
dataset_stats = {"action": action_stats_raw} if action_stats_raw is not None else None
|
||||
# 3. Fetch action + state stats needed by the pre/postprocessor unnormalizers
|
||||
dataset_stats = _fetch_dataset_stats(api, SOURCE_REPO_ID, subfolder)
|
||||
|
||||
# 4. Build config (no policy instantiation — avoids loading backbone from Hub)
|
||||
config = _build_config(camera_keys, with_state, enable_world_model)
|
||||
source_config = _fetch_source_config(api, SOURCE_REPO_ID, subfolder)
|
||||
config = _build_config(camera_keys, with_state, enable_world_model, source_config)
|
||||
|
||||
# 5. Save everything to a temp dir and upload in one shot
|
||||
api.create_repo(target_repo_id, repo_type="model", exist_ok=True)
|
||||
|
||||
@@ -117,38 +117,21 @@ class VLAJEPAModel(nn.Module):
|
||||
)
|
||||
|
||||
def _qwen_last_decoder_hidden(self, qwen_inputs: dict[str, torch.Tensor]) -> torch.Tensor:
|
||||
"""Return Qwen's final decoder-layer output before the final RMSNorm.
|
||||
"""Return Qwen's last decoder hidden state matching training behaviour.
|
||||
|
||||
starVLA trained its downstream heads on the legacy transformers-4.57
|
||||
`hidden_states[-1]` value, which is the last decoder layer output before
|
||||
Qwen's final RMSNorm. Newer transformers versions expose `hidden_states[-1]`
|
||||
as the post-norm last hidden state, so capture the layer output directly.
|
||||
The original starVLA uses `output_hidden_states=True` and takes `hidden_states[-1]`.
|
||||
In transformers 5.x the `@capture_outputs` decorator explicitly replaces
|
||||
`hidden_states[-1]` with `last_hidden_state` (post-RMSNorm), so this call
|
||||
consistently returns the post-norm output regardless of transformers version,
|
||||
matching what the model was trained with.
|
||||
"""
|
||||
captured: dict[str, torch.Tensor] = {}
|
||||
language_model = self.qwen.model.model.language_model
|
||||
|
||||
def capture_last_layer_output(
|
||||
_module: nn.Module,
|
||||
_inputs: tuple[torch.Tensor, ...],
|
||||
output: torch.Tensor | tuple[torch.Tensor, ...],
|
||||
) -> None:
|
||||
captured["last_hidden"] = output[0] if isinstance(output, tuple) else output
|
||||
return None
|
||||
|
||||
handle = language_model.layers[-1].register_forward_hook(capture_last_layer_output)
|
||||
try:
|
||||
self.qwen.model.model(
|
||||
**qwen_inputs,
|
||||
output_hidden_states=False,
|
||||
output_attentions=False,
|
||||
return_dict=True,
|
||||
)
|
||||
finally:
|
||||
handle.remove()
|
||||
|
||||
if "last_hidden" not in captured:
|
||||
raise RuntimeError("Failed to capture Qwen last decoder hidden states.")
|
||||
return captured["last_hidden"]
|
||||
outputs = self.qwen.model(
|
||||
**qwen_inputs,
|
||||
output_hidden_states=True,
|
||||
output_attentions=False,
|
||||
return_dict=True,
|
||||
)
|
||||
return outputs.hidden_states[-1]
|
||||
|
||||
# ---- Native VLA-JEPA forward (follows original VLA_JEPA.py) ----
|
||||
|
||||
@@ -469,7 +452,7 @@ class VLAJEPAPolicy(PreTrainedPolicy):
|
||||
# Clamp to [0, 255]
|
||||
if t_np.max() <= 1.0:
|
||||
t_np = t_np * 255.0
|
||||
t_np = t_np.clip(0, 255).astype(np.uint8)
|
||||
t_np = np.rint(t_np.clip(0, 255)).astype(np.uint8)
|
||||
sample_views.append(t_np)
|
||||
# Stack views: [V, T, H, W, 3]
|
||||
videos_per_sample.append(np.stack(sample_views, axis=0))
|
||||
|
||||
@@ -38,6 +38,30 @@ class ClipActionsProcessorStep(ProcessorStep):
|
||||
return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register(name="vla_jepa_pre_snap_gripper")
|
||||
class PreSnapGripperProcessorStep(ProcessorStep):
|
||||
"""Snaps gripper dim (index 6) to {0, 1} BEFORE unnormalization.
|
||||
|
||||
Mirrors the original starVLA LIBERO eval:
|
||||
normalized[:, 6] = np.where(normalized[:, 6] < 0.5, 0, 1)
|
||||
This ensures the unnormalizer receives an exact binary value, which is
|
||||
required when the model was trained with gripper in identity (mask=False)
|
||||
space where 0=open and 1=close.
|
||||
"""
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
action = transition.get(TransitionKey.ACTION)
|
||||
if action is not None and action.shape[-1] >= 7:
|
||||
transition = dict(transition)
|
||||
a = action.clone()
|
||||
a[..., 6] = (a[..., 6] >= 0.5).float()
|
||||
transition[TransitionKey.ACTION] = a
|
||||
return transition
|
||||
|
||||
def transform_features(self, features):
|
||||
return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register(name="vla_jepa_binarize_gripper")
|
||||
class BinarizeGripperProcessorStep(ProcessorStep):
|
||||
"""Binarizes gripper dim (index 6) after unnormalization.
|
||||
@@ -80,6 +104,8 @@ def make_vla_jepa_pre_post_processors(
|
||||
output_steps: list[ProcessorStep] = []
|
||||
if config.clip_normalized_actions:
|
||||
output_steps.append(ClipActionsProcessorStep())
|
||||
if config.pre_snap_gripper_action:
|
||||
output_steps.append(PreSnapGripperProcessorStep())
|
||||
output_steps.append(
|
||||
UnnormalizerProcessorStep(
|
||||
features=features,
|
||||
|
||||
@@ -97,7 +97,7 @@ class Qwen3VLInterface(torch.nn.Module):
|
||||
image = image.float()
|
||||
if image.max() <= 1.0:
|
||||
image = image * 255.0
|
||||
image = image.clamp(0, 255).to(torch.uint8).numpy()
|
||||
image = image.clamp(0, 255).round().to(torch.uint8).numpy()
|
||||
if image.shape[-1] == 1:
|
||||
image = np.repeat(image, 3, axis=-1)
|
||||
return Image.fromarray(image)
|
||||
|
||||
Reference in New Issue
Block a user