From 8efa5cabe9cc00b6d112e0238373733183a7c056 Mon Sep 17 00:00:00 2001 From: Maxime Ellerbach Date: Fri, 22 May 2026 07:59:47 +0000 Subject: [PATCH] trying to close success rate gap --- .../vla_jepa/configuration_vla_jepa.py | 3 +- .../vla_jepa/convert_vla_jepa_checkpoints.py | 167 ++++++++++++++++-- .../policies/vla_jepa/modeling_vla_jepa.py | 45 ++--- .../policies/vla_jepa/processor_vla_jepa.py | 26 +++ .../policies/vla_jepa/qwen_interface.py | 2 +- 5 files changed, 193 insertions(+), 50 deletions(-) diff --git a/src/lerobot/policies/vla_jepa/configuration_vla_jepa.py b/src/lerobot/policies/vla_jepa/configuration_vla_jepa.py index 9e1dd0ffe..1794a5c46 100644 --- a/src/lerobot/policies/vla_jepa/configuration_vla_jepa.py +++ b/src/lerobot/policies/vla_jepa/configuration_vla_jepa.py @@ -18,7 +18,7 @@ class VLAJEPAConfig(PreTrainedConfig): normalization_mapping: dict[str, NormalizationMode] = field( default_factory=lambda: { "VISUAL": NormalizationMode.IDENTITY, - "STATE": NormalizationMode.IDENTITY, + "STATE": NormalizationMode.MEAN_STD, "ACTION": NormalizationMode.MIN_MAX, } ) @@ -66,6 +66,7 @@ class VLAJEPAConfig(PreTrainedConfig): resize_images_to: tuple[int, int] | None = None binarize_gripper_action: bool = True + pre_snap_gripper_action: bool = True clip_normalized_actions: bool = True torch_dtype: str = "bfloat16" diff --git a/src/lerobot/policies/vla_jepa/convert_vla_jepa_checkpoints.py b/src/lerobot/policies/vla_jepa/convert_vla_jepa_checkpoints.py index 90120c9bf..591ce5db9 100644 --- a/src/lerobot/policies/vla_jepa/convert_vla_jepa_checkpoints.py +++ b/src/lerobot/policies/vla_jepa/convert_vla_jepa_checkpoints.py @@ -76,25 +76,141 @@ def _map_checkpoint_key(raw_key: str) -> str | None: return None -def _fetch_action_stats(api: HfApi, source_repo_id: str, subfolder: str) -> dict | None: - """Download dataset_statistics.json and return the action stats dict.""" +def _fetch_dataset_stats(api: HfApi, source_repo_id: str, subfolder: str) -> dict | None: + """Download dataset_statistics.json and return {action: {...}, state: {...}} stats dict.""" import json stats_file = f"{subfolder}/dataset_statistics.json" try: local = api.hf_hub_download(source_repo_id, stats_file) data = json.loads(Path(local).read_text()) - # Original repo nests stats under a robot key, e.g. {"franka": {"action": {...}}} + # Original repo nests stats under a robot key, e.g. {"franka": {"action": {...}, "state": {...}}} for robot_key in data: - if isinstance(data[robot_key], dict) and "action" in data[robot_key]: - log.info(" Loaded action stats from %s (robot key: %s)", stats_file, robot_key) - return data[robot_key]["action"] - log.warning(" %s found but no 'action' key under any robot key — skipping action stats.", stats_file) + robot_data = data[robot_key] + if isinstance(robot_data, dict) and "action" in robot_data: + log.info(" Loaded dataset stats from %s (robot key: %s)", stats_file, robot_key) + result = {"action": robot_data["action"]} + if "state" in robot_data: + result["observation.state"] = robot_data["state"] + log.info(" Also loaded state stats.") + return result + log.warning(" %s found but no 'action' key under any robot key — skipping stats.", stats_file) except Exception as exc: # noqa: BLE001 log.warning(" Could not fetch %s: %s — postprocessor will have no unnorm stats.", stats_file, exc) return None +def _set_if_present(d: dict, key: str, value) -> None: + if value is not None: + d[key] = value + + +def _deep_get(mapping: dict, path: tuple, default=None): + current = mapping + for key in path: + if not isinstance(current, dict) or key not in current: + return default + current = current[key] + return current + + +def _fetch_source_config(api: HfApi, source_repo_id: str, subfolder: str) -> dict: + """Download config.yaml from the source HF repo for a given variant subfolder.""" + try: + import yaml + except ImportError: + log.warning("PyYAML not installed — cannot apply source config.yaml overrides.") + return {} + config_file = f"{subfolder}/config.yaml" + try: + local = api.hf_hub_download(source_repo_id, config_file) + data = yaml.safe_load(Path(local).read_text()) or {} + if isinstance(data, dict): + log.info(" Loaded source config from %s", config_file) + return data + except Exception as exc: # noqa: BLE001 + log.warning(" Could not fetch %s: %s — using hardcoded defaults.", config_file, exc) + return {} + + +def _apply_source_config(kwargs: dict, source_config: dict) -> None: + """Apply ginwind/VLA-JEPA config.yaml values to kwargs, mirroring todo_converter.py logic.""" + if not source_config: + return + + data_cfg = _deep_get(source_config, ("datasets", "vla_data"), {}) + action_cfg = _deep_get(source_config, ("framework", "action_model"), {}) + diffusion_cfg = _deep_get(source_config, ("framework", "action_model", "diffusion_model_cfg"), {}) + video_cfg = _deep_get(source_config, ("framework", "vj2_model"), {}) + trainer_cfg = source_config.get("trainer", {}) + + prompt_template = data_cfg.get("CoT_prompt") + if prompt_template: + kwargs["prompt_template"] = str(prompt_template) + + action_horizon = action_cfg.get("action_horizon") + if action_horizon is not None: + kwargs["chunk_size"] = int(action_horizon) + kwargs["n_action_steps"] = int(action_horizon) + + _set_if_present( + kwargs, + "num_action_tokens_per_timestep", + video_cfg.get("num_action_tokens_per_timestep", action_cfg.get("num_action_tokens_per_timestep")), + ) + _set_if_present( + kwargs, + "num_embodied_action_tokens_per_instruction", + video_cfg.get( + "num_embodied_action_tokens_per_instruction", + action_cfg.get("num_embodied_action_tokens_per_instruction"), + ), + ) + _set_if_present(kwargs, "num_inference_timesteps", action_cfg.get("num_inference_timesteps")) + _set_if_present(kwargs, "special_action_token", video_cfg.get("special_action_token")) + _set_if_present(kwargs, "embodied_action_token", video_cfg.get("embodied_action_token")) + _set_if_present( + kwargs, "action_hidden_size", action_cfg.get("action_hidden_dim", action_cfg.get("hidden_size")) + ) + _set_if_present(kwargs, "action_model_type", action_cfg.get("action_model_type")) + _set_if_present(kwargs, "action_noise_beta_alpha", action_cfg.get("noise_beta_alpha")) + _set_if_present(kwargs, "action_noise_beta_beta", action_cfg.get("noise_beta_beta")) + _set_if_present(kwargs, "action_noise_s", action_cfg.get("noise_s")) + _set_if_present(kwargs, "action_num_timestep_buckets", action_cfg.get("num_timestep_buckets")) + _set_if_present(kwargs, "repeated_diffusion_steps", action_cfg.get("repeated_diffusion_steps")) + _set_if_present(kwargs, "action_num_layers", diffusion_cfg.get("num_layers")) + _set_if_present(kwargs, "action_dropout", diffusion_cfg.get("dropout")) + + _set_if_present(kwargs, "num_video_frames", video_cfg.get("num_frames")) + _set_if_present( + kwargs, "predictor_depth", video_cfg.get("predictor_depth", video_cfg.get("depth")) + ) + _set_if_present( + kwargs, "predictor_num_heads", video_cfg.get("predictor_num_heads", video_cfg.get("num_heads")) + ) + _set_if_present(kwargs, "predictor_mlp_ratio", video_cfg.get("predictor_mlp_ratio")) + + _set_if_present(kwargs, "optimizer_grad_clip_norm", trainer_cfg.get("max_grad_norm")) + learning_rate = trainer_cfg.get("learning_rate", {}) + if isinstance(learning_rate, dict): + _set_if_present(kwargs, "optimizer_lr", learning_rate.get("action_model")) + optimizer_cfg = trainer_cfg.get("optimizer", {}) + if isinstance(optimizer_cfg, dict): + _set_if_present(kwargs, "optimizer_eps", optimizer_cfg.get("eps")) + _set_if_present(kwargs, "optimizer_weight_decay", optimizer_cfg.get("weight_decay")) + betas = optimizer_cfg.get("betas") + if betas is not None: + kwargs["optimizer_betas"] = tuple(betas) + scheduler = trainer_cfg.get("scheduler", {}) + if isinstance(scheduler, dict): + _set_if_present(kwargs, "scheduler_warmup_steps", scheduler.get("warmup_steps")) + _set_if_present(kwargs, "scheduler_decay_lr", scheduler.get("min_lr")) + _set_if_present(kwargs, "scheduler_warmup_steps", trainer_cfg.get("num_warmup_steps")) + scheduler_kwargs = trainer_cfg.get("scheduler_specific_kwargs", {}) + if isinstance(scheduler_kwargs, dict): + _set_if_present(kwargs, "scheduler_decay_lr", scheduler_kwargs.get("min_lr")) + + # --------------------------------------------------------------------------- # Architecture — identical across all 4 variants (from config.json) # --------------------------------------------------------------------------- @@ -129,7 +245,7 @@ _ARCH = { # LIBERO — confirmed from lerobot/libero_10 meta/info.json _LIBERO_CAMS = [ "observation.images.image", # agentview camera - "observation.images.wrist_image", # eye-in-hand camera + "observation.images.image2", # eye-in-hand camera ] # DROID pretrain — 2 views match the predictor embed_dim=2 × 1024=2048 in checkpoint @@ -153,23 +269,41 @@ def _build_config( camera_keys: list[str], with_state: bool, enable_world_model: bool = True, + source_config: dict | None = None, ): from lerobot.configs.types import FeatureType, PolicyFeature from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig - input_features = {k: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)) for k in camera_keys} + kwargs = dict(_ARCH) + _apply_source_config(kwargs, source_config or {}) + + # Image resolution: prefer source config, fall back to 224 + data_cfg = _deep_get(source_config or {}, ("datasets", "vla_data"), {}) + raw_res = data_cfg.get("resolution_size") + resolution_size = int(raw_res) if raw_res is not None else 224 + image_shape = (3, resolution_size, resolution_size) + # Always set resize_images_to so the policy resizes env images to the training resolution, + # regardless of what resolution the eval env renders at. + kwargs["resize_images_to"] = (resolution_size, resolution_size) + + # State / action dims: prefer source config + action_cfg = _deep_get(source_config or {}, ("framework", "action_model"), {}) + state_dim = int(action_cfg["state_dim"]) if "state_dim" in action_cfg else 8 + action_dim = int(action_cfg["action_dim"]) if "action_dim" in action_cfg else 7 + + input_features = {k: PolicyFeature(type=FeatureType.VISUAL, shape=image_shape) for k in camera_keys} if with_state: - input_features["observation.state"] = PolicyFeature(type=FeatureType.STATE, shape=(8,)) + input_features["observation.state"] = PolicyFeature(type=FeatureType.STATE, shape=(state_dim,)) cfg = VLAJEPAConfig( input_features=input_features, output_features={ - "action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)), + "action": PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,)), }, enable_world_model=enable_world_model, binarize_gripper_action=True, clip_normalized_actions=True, - **_ARCH, + **kwargs, ) cfg.validate_features() return cfg @@ -276,13 +410,12 @@ def main() -> None: log.info(" Skipped sample: %s", skipped_keys[:5]) log.info(" First 5 mapped keys: %s", list(mapped_sd)[:5]) - # 3. Fetch action stats (min/max per dim) needed by the postprocessor unnormalizer - action_stats_raw = _fetch_action_stats(api, SOURCE_REPO_ID, subfolder) - # Wrap as {"action": {...}} for UnnormalizerProcessorStep - dataset_stats = {"action": action_stats_raw} if action_stats_raw is not None else None + # 3. Fetch action + state stats needed by the pre/postprocessor unnormalizers + dataset_stats = _fetch_dataset_stats(api, SOURCE_REPO_ID, subfolder) # 4. Build config (no policy instantiation — avoids loading backbone from Hub) - config = _build_config(camera_keys, with_state, enable_world_model) + source_config = _fetch_source_config(api, SOURCE_REPO_ID, subfolder) + config = _build_config(camera_keys, with_state, enable_world_model, source_config) # 5. Save everything to a temp dir and upload in one shot api.create_repo(target_repo_id, repo_type="model", exist_ok=True) diff --git a/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py b/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py index 88b2fbbdd..adf8b7540 100644 --- a/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py +++ b/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py @@ -117,38 +117,21 @@ class VLAJEPAModel(nn.Module): ) def _qwen_last_decoder_hidden(self, qwen_inputs: dict[str, torch.Tensor]) -> torch.Tensor: - """Return Qwen's final decoder-layer output before the final RMSNorm. + """Return Qwen's last decoder hidden state matching training behaviour. - starVLA trained its downstream heads on the legacy transformers-4.57 - `hidden_states[-1]` value, which is the last decoder layer output before - Qwen's final RMSNorm. Newer transformers versions expose `hidden_states[-1]` - as the post-norm last hidden state, so capture the layer output directly. + The original starVLA uses `output_hidden_states=True` and takes `hidden_states[-1]`. + In transformers 5.x the `@capture_outputs` decorator explicitly replaces + `hidden_states[-1]` with `last_hidden_state` (post-RMSNorm), so this call + consistently returns the post-norm output regardless of transformers version, + matching what the model was trained with. """ - captured: dict[str, torch.Tensor] = {} - language_model = self.qwen.model.model.language_model - - def capture_last_layer_output( - _module: nn.Module, - _inputs: tuple[torch.Tensor, ...], - output: torch.Tensor | tuple[torch.Tensor, ...], - ) -> None: - captured["last_hidden"] = output[0] if isinstance(output, tuple) else output - return None - - handle = language_model.layers[-1].register_forward_hook(capture_last_layer_output) - try: - self.qwen.model.model( - **qwen_inputs, - output_hidden_states=False, - output_attentions=False, - return_dict=True, - ) - finally: - handle.remove() - - if "last_hidden" not in captured: - raise RuntimeError("Failed to capture Qwen last decoder hidden states.") - return captured["last_hidden"] + outputs = self.qwen.model( + **qwen_inputs, + output_hidden_states=True, + output_attentions=False, + return_dict=True, + ) + return outputs.hidden_states[-1] # ---- Native VLA-JEPA forward (follows original VLA_JEPA.py) ---- @@ -469,7 +452,7 @@ class VLAJEPAPolicy(PreTrainedPolicy): # Clamp to [0, 255] if t_np.max() <= 1.0: t_np = t_np * 255.0 - t_np = t_np.clip(0, 255).astype(np.uint8) + t_np = np.rint(t_np.clip(0, 255)).astype(np.uint8) sample_views.append(t_np) # Stack views: [V, T, H, W, 3] videos_per_sample.append(np.stack(sample_views, axis=0)) diff --git a/src/lerobot/policies/vla_jepa/processor_vla_jepa.py b/src/lerobot/policies/vla_jepa/processor_vla_jepa.py index a455737e6..e661dde8f 100644 --- a/src/lerobot/policies/vla_jepa/processor_vla_jepa.py +++ b/src/lerobot/policies/vla_jepa/processor_vla_jepa.py @@ -38,6 +38,30 @@ class ClipActionsProcessorStep(ProcessorStep): return features +@ProcessorStepRegistry.register(name="vla_jepa_pre_snap_gripper") +class PreSnapGripperProcessorStep(ProcessorStep): + """Snaps gripper dim (index 6) to {0, 1} BEFORE unnormalization. + + Mirrors the original starVLA LIBERO eval: + normalized[:, 6] = np.where(normalized[:, 6] < 0.5, 0, 1) + This ensures the unnormalizer receives an exact binary value, which is + required when the model was trained with gripper in identity (mask=False) + space where 0=open and 1=close. + """ + + def __call__(self, transition: EnvTransition) -> EnvTransition: + action = transition.get(TransitionKey.ACTION) + if action is not None and action.shape[-1] >= 7: + transition = dict(transition) + a = action.clone() + a[..., 6] = (a[..., 6] >= 0.5).float() + transition[TransitionKey.ACTION] = a + return transition + + def transform_features(self, features): + return features + + @ProcessorStepRegistry.register(name="vla_jepa_binarize_gripper") class BinarizeGripperProcessorStep(ProcessorStep): """Binarizes gripper dim (index 6) after unnormalization. @@ -80,6 +104,8 @@ def make_vla_jepa_pre_post_processors( output_steps: list[ProcessorStep] = [] if config.clip_normalized_actions: output_steps.append(ClipActionsProcessorStep()) + if config.pre_snap_gripper_action: + output_steps.append(PreSnapGripperProcessorStep()) output_steps.append( UnnormalizerProcessorStep( features=features, diff --git a/src/lerobot/policies/vla_jepa/qwen_interface.py b/src/lerobot/policies/vla_jepa/qwen_interface.py index 1031f837b..2df74f51c 100644 --- a/src/lerobot/policies/vla_jepa/qwen_interface.py +++ b/src/lerobot/policies/vla_jepa/qwen_interface.py @@ -97,7 +97,7 @@ class Qwen3VLInterface(torch.nn.Module): image = image.float() if image.max() <= 1.0: image = image * 255.0 - image = image.clamp(0, 255).to(torch.uint8).numpy() + image = image.clamp(0, 255).round().to(torch.uint8).numpy() if image.shape[-1] == 1: image = np.repeat(image, 3, axis=-1) return Image.fromarray(image)