trying to close success rate gap

This commit is contained in:
Maxime Ellerbach
2026-05-22 07:59:47 +00:00
committed by Maximellerbach
parent 7e23859c55
commit 8efa5cabe9
5 changed files with 193 additions and 50 deletions

View File

@@ -18,7 +18,7 @@ class VLAJEPAConfig(PreTrainedConfig):
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.MEAN_STD,
"ACTION": NormalizationMode.MIN_MAX,
}
)
@@ -66,6 +66,7 @@ class VLAJEPAConfig(PreTrainedConfig):
resize_images_to: tuple[int, int] | None = None
binarize_gripper_action: bool = True
pre_snap_gripper_action: bool = True
clip_normalized_actions: bool = True
torch_dtype: str = "bfloat16"

View File

@@ -76,25 +76,141 @@ def _map_checkpoint_key(raw_key: str) -> str | None:
return None
def _fetch_action_stats(api: HfApi, source_repo_id: str, subfolder: str) -> dict | None:
"""Download dataset_statistics.json and return the action stats dict."""
def _fetch_dataset_stats(api: HfApi, source_repo_id: str, subfolder: str) -> dict | None:
"""Download dataset_statistics.json and return {action: {...}, state: {...}} stats dict."""
import json
stats_file = f"{subfolder}/dataset_statistics.json"
try:
local = api.hf_hub_download(source_repo_id, stats_file)
data = json.loads(Path(local).read_text())
# Original repo nests stats under a robot key, e.g. {"franka": {"action": {...}}}
# Original repo nests stats under a robot key, e.g. {"franka": {"action": {...}, "state": {...}}}
for robot_key in data:
if isinstance(data[robot_key], dict) and "action" in data[robot_key]:
log.info(" Loaded action stats from %s (robot key: %s)", stats_file, robot_key)
return data[robot_key]["action"]
log.warning(" %s found but no 'action' key under any robot key — skipping action stats.", stats_file)
robot_data = data[robot_key]
if isinstance(robot_data, dict) and "action" in robot_data:
log.info(" Loaded dataset stats from %s (robot key: %s)", stats_file, robot_key)
result = {"action": robot_data["action"]}
if "state" in robot_data:
result["observation.state"] = robot_data["state"]
log.info(" Also loaded state stats.")
return result
log.warning(" %s found but no 'action' key under any robot key — skipping stats.", stats_file)
except Exception as exc: # noqa: BLE001
log.warning(" Could not fetch %s: %s — postprocessor will have no unnorm stats.", stats_file, exc)
return None
def _set_if_present(d: dict, key: str, value) -> None:
if value is not None:
d[key] = value
def _deep_get(mapping: dict, path: tuple, default=None):
current = mapping
for key in path:
if not isinstance(current, dict) or key not in current:
return default
current = current[key]
return current
def _fetch_source_config(api: HfApi, source_repo_id: str, subfolder: str) -> dict:
"""Download config.yaml from the source HF repo for a given variant subfolder."""
try:
import yaml
except ImportError:
log.warning("PyYAML not installed — cannot apply source config.yaml overrides.")
return {}
config_file = f"{subfolder}/config.yaml"
try:
local = api.hf_hub_download(source_repo_id, config_file)
data = yaml.safe_load(Path(local).read_text()) or {}
if isinstance(data, dict):
log.info(" Loaded source config from %s", config_file)
return data
except Exception as exc: # noqa: BLE001
log.warning(" Could not fetch %s: %s — using hardcoded defaults.", config_file, exc)
return {}
def _apply_source_config(kwargs: dict, source_config: dict) -> None:
"""Apply ginwind/VLA-JEPA config.yaml values to kwargs, mirroring todo_converter.py logic."""
if not source_config:
return
data_cfg = _deep_get(source_config, ("datasets", "vla_data"), {})
action_cfg = _deep_get(source_config, ("framework", "action_model"), {})
diffusion_cfg = _deep_get(source_config, ("framework", "action_model", "diffusion_model_cfg"), {})
video_cfg = _deep_get(source_config, ("framework", "vj2_model"), {})
trainer_cfg = source_config.get("trainer", {})
prompt_template = data_cfg.get("CoT_prompt")
if prompt_template:
kwargs["prompt_template"] = str(prompt_template)
action_horizon = action_cfg.get("action_horizon")
if action_horizon is not None:
kwargs["chunk_size"] = int(action_horizon)
kwargs["n_action_steps"] = int(action_horizon)
_set_if_present(
kwargs,
"num_action_tokens_per_timestep",
video_cfg.get("num_action_tokens_per_timestep", action_cfg.get("num_action_tokens_per_timestep")),
)
_set_if_present(
kwargs,
"num_embodied_action_tokens_per_instruction",
video_cfg.get(
"num_embodied_action_tokens_per_instruction",
action_cfg.get("num_embodied_action_tokens_per_instruction"),
),
)
_set_if_present(kwargs, "num_inference_timesteps", action_cfg.get("num_inference_timesteps"))
_set_if_present(kwargs, "special_action_token", video_cfg.get("special_action_token"))
_set_if_present(kwargs, "embodied_action_token", video_cfg.get("embodied_action_token"))
_set_if_present(
kwargs, "action_hidden_size", action_cfg.get("action_hidden_dim", action_cfg.get("hidden_size"))
)
_set_if_present(kwargs, "action_model_type", action_cfg.get("action_model_type"))
_set_if_present(kwargs, "action_noise_beta_alpha", action_cfg.get("noise_beta_alpha"))
_set_if_present(kwargs, "action_noise_beta_beta", action_cfg.get("noise_beta_beta"))
_set_if_present(kwargs, "action_noise_s", action_cfg.get("noise_s"))
_set_if_present(kwargs, "action_num_timestep_buckets", action_cfg.get("num_timestep_buckets"))
_set_if_present(kwargs, "repeated_diffusion_steps", action_cfg.get("repeated_diffusion_steps"))
_set_if_present(kwargs, "action_num_layers", diffusion_cfg.get("num_layers"))
_set_if_present(kwargs, "action_dropout", diffusion_cfg.get("dropout"))
_set_if_present(kwargs, "num_video_frames", video_cfg.get("num_frames"))
_set_if_present(
kwargs, "predictor_depth", video_cfg.get("predictor_depth", video_cfg.get("depth"))
)
_set_if_present(
kwargs, "predictor_num_heads", video_cfg.get("predictor_num_heads", video_cfg.get("num_heads"))
)
_set_if_present(kwargs, "predictor_mlp_ratio", video_cfg.get("predictor_mlp_ratio"))
_set_if_present(kwargs, "optimizer_grad_clip_norm", trainer_cfg.get("max_grad_norm"))
learning_rate = trainer_cfg.get("learning_rate", {})
if isinstance(learning_rate, dict):
_set_if_present(kwargs, "optimizer_lr", learning_rate.get("action_model"))
optimizer_cfg = trainer_cfg.get("optimizer", {})
if isinstance(optimizer_cfg, dict):
_set_if_present(kwargs, "optimizer_eps", optimizer_cfg.get("eps"))
_set_if_present(kwargs, "optimizer_weight_decay", optimizer_cfg.get("weight_decay"))
betas = optimizer_cfg.get("betas")
if betas is not None:
kwargs["optimizer_betas"] = tuple(betas)
scheduler = trainer_cfg.get("scheduler", {})
if isinstance(scheduler, dict):
_set_if_present(kwargs, "scheduler_warmup_steps", scheduler.get("warmup_steps"))
_set_if_present(kwargs, "scheduler_decay_lr", scheduler.get("min_lr"))
_set_if_present(kwargs, "scheduler_warmup_steps", trainer_cfg.get("num_warmup_steps"))
scheduler_kwargs = trainer_cfg.get("scheduler_specific_kwargs", {})
if isinstance(scheduler_kwargs, dict):
_set_if_present(kwargs, "scheduler_decay_lr", scheduler_kwargs.get("min_lr"))
# ---------------------------------------------------------------------------
# Architecture — identical across all 4 variants (from config.json)
# ---------------------------------------------------------------------------
@@ -129,7 +245,7 @@ _ARCH = {
# LIBERO — confirmed from lerobot/libero_10 meta/info.json
_LIBERO_CAMS = [
"observation.images.image", # agentview camera
"observation.images.wrist_image", # eye-in-hand camera
"observation.images.image2", # eye-in-hand camera
]
# DROID pretrain — 2 views match the predictor embed_dim=2 × 1024=2048 in checkpoint
@@ -153,23 +269,41 @@ def _build_config(
camera_keys: list[str],
with_state: bool,
enable_world_model: bool = True,
source_config: dict | None = None,
):
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig
input_features = {k: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)) for k in camera_keys}
kwargs = dict(_ARCH)
_apply_source_config(kwargs, source_config or {})
# Image resolution: prefer source config, fall back to 224
data_cfg = _deep_get(source_config or {}, ("datasets", "vla_data"), {})
raw_res = data_cfg.get("resolution_size")
resolution_size = int(raw_res) if raw_res is not None else 224
image_shape = (3, resolution_size, resolution_size)
# Always set resize_images_to so the policy resizes env images to the training resolution,
# regardless of what resolution the eval env renders at.
kwargs["resize_images_to"] = (resolution_size, resolution_size)
# State / action dims: prefer source config
action_cfg = _deep_get(source_config or {}, ("framework", "action_model"), {})
state_dim = int(action_cfg["state_dim"]) if "state_dim" in action_cfg else 8
action_dim = int(action_cfg["action_dim"]) if "action_dim" in action_cfg else 7
input_features = {k: PolicyFeature(type=FeatureType.VISUAL, shape=image_shape) for k in camera_keys}
if with_state:
input_features["observation.state"] = PolicyFeature(type=FeatureType.STATE, shape=(8,))
input_features["observation.state"] = PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))
cfg = VLAJEPAConfig(
input_features=input_features,
output_features={
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
"action": PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,)),
},
enable_world_model=enable_world_model,
binarize_gripper_action=True,
clip_normalized_actions=True,
**_ARCH,
**kwargs,
)
cfg.validate_features()
return cfg
@@ -276,13 +410,12 @@ def main() -> None:
log.info(" Skipped sample: %s", skipped_keys[:5])
log.info(" First 5 mapped keys: %s", list(mapped_sd)[:5])
# 3. Fetch action stats (min/max per dim) needed by the postprocessor unnormalizer
action_stats_raw = _fetch_action_stats(api, SOURCE_REPO_ID, subfolder)
# Wrap as {"action": {...}} for UnnormalizerProcessorStep
dataset_stats = {"action": action_stats_raw} if action_stats_raw is not None else None
# 3. Fetch action + state stats needed by the pre/postprocessor unnormalizers
dataset_stats = _fetch_dataset_stats(api, SOURCE_REPO_ID, subfolder)
# 4. Build config (no policy instantiation — avoids loading backbone from Hub)
config = _build_config(camera_keys, with_state, enable_world_model)
source_config = _fetch_source_config(api, SOURCE_REPO_ID, subfolder)
config = _build_config(camera_keys, with_state, enable_world_model, source_config)
# 5. Save everything to a temp dir and upload in one shot
api.create_repo(target_repo_id, repo_type="model", exist_ok=True)

View File

@@ -117,38 +117,21 @@ class VLAJEPAModel(nn.Module):
)
def _qwen_last_decoder_hidden(self, qwen_inputs: dict[str, torch.Tensor]) -> torch.Tensor:
"""Return Qwen's final decoder-layer output before the final RMSNorm.
"""Return Qwen's last decoder hidden state matching training behaviour.
starVLA trained its downstream heads on the legacy transformers-4.57
`hidden_states[-1]` value, which is the last decoder layer output before
Qwen's final RMSNorm. Newer transformers versions expose `hidden_states[-1]`
as the post-norm last hidden state, so capture the layer output directly.
The original starVLA uses `output_hidden_states=True` and takes `hidden_states[-1]`.
In transformers 5.x the `@capture_outputs` decorator explicitly replaces
`hidden_states[-1]` with `last_hidden_state` (post-RMSNorm), so this call
consistently returns the post-norm output regardless of transformers version,
matching what the model was trained with.
"""
captured: dict[str, torch.Tensor] = {}
language_model = self.qwen.model.model.language_model
def capture_last_layer_output(
_module: nn.Module,
_inputs: tuple[torch.Tensor, ...],
output: torch.Tensor | tuple[torch.Tensor, ...],
) -> None:
captured["last_hidden"] = output[0] if isinstance(output, tuple) else output
return None
handle = language_model.layers[-1].register_forward_hook(capture_last_layer_output)
try:
self.qwen.model.model(
**qwen_inputs,
output_hidden_states=False,
output_attentions=False,
return_dict=True,
)
finally:
handle.remove()
if "last_hidden" not in captured:
raise RuntimeError("Failed to capture Qwen last decoder hidden states.")
return captured["last_hidden"]
outputs = self.qwen.model(
**qwen_inputs,
output_hidden_states=True,
output_attentions=False,
return_dict=True,
)
return outputs.hidden_states[-1]
# ---- Native VLA-JEPA forward (follows original VLA_JEPA.py) ----
@@ -469,7 +452,7 @@ class VLAJEPAPolicy(PreTrainedPolicy):
# Clamp to [0, 255]
if t_np.max() <= 1.0:
t_np = t_np * 255.0
t_np = t_np.clip(0, 255).astype(np.uint8)
t_np = np.rint(t_np.clip(0, 255)).astype(np.uint8)
sample_views.append(t_np)
# Stack views: [V, T, H, W, 3]
videos_per_sample.append(np.stack(sample_views, axis=0))

View File

@@ -38,6 +38,30 @@ class ClipActionsProcessorStep(ProcessorStep):
return features
@ProcessorStepRegistry.register(name="vla_jepa_pre_snap_gripper")
class PreSnapGripperProcessorStep(ProcessorStep):
"""Snaps gripper dim (index 6) to {0, 1} BEFORE unnormalization.
Mirrors the original starVLA LIBERO eval:
normalized[:, 6] = np.where(normalized[:, 6] < 0.5, 0, 1)
This ensures the unnormalizer receives an exact binary value, which is
required when the model was trained with gripper in identity (mask=False)
space where 0=open and 1=close.
"""
def __call__(self, transition: EnvTransition) -> EnvTransition:
action = transition.get(TransitionKey.ACTION)
if action is not None and action.shape[-1] >= 7:
transition = dict(transition)
a = action.clone()
a[..., 6] = (a[..., 6] >= 0.5).float()
transition[TransitionKey.ACTION] = a
return transition
def transform_features(self, features):
return features
@ProcessorStepRegistry.register(name="vla_jepa_binarize_gripper")
class BinarizeGripperProcessorStep(ProcessorStep):
"""Binarizes gripper dim (index 6) after unnormalization.
@@ -80,6 +104,8 @@ def make_vla_jepa_pre_post_processors(
output_steps: list[ProcessorStep] = []
if config.clip_normalized_actions:
output_steps.append(ClipActionsProcessorStep())
if config.pre_snap_gripper_action:
output_steps.append(PreSnapGripperProcessorStep())
output_steps.append(
UnnormalizerProcessorStep(
features=features,

View File

@@ -97,7 +97,7 @@ class Qwen3VLInterface(torch.nn.Module):
image = image.float()
if image.max() <= 1.0:
image = image * 255.0
image = image.clamp(0, 255).to(torch.uint8).numpy()
image = image.clamp(0, 255).round().to(torch.uint8).numpy()
if image.shape[-1] == 1:
image = np.repeat(image, 3, axis=-1)
return Image.fromarray(image)