removing conversion script

This commit is contained in:
Maximellerbach
2026-05-28 11:47:00 +02:00
parent c2088fd760
commit d26d2da138
2 changed files with 1 additions and 456 deletions

View File

@@ -1,455 +0,0 @@
#!/usr/bin/env python
"""
Convert all VLA-JEPA .pt checkpoints (ginwind/VLA-JEPA) to LeRobot safetensors
format and upload them to maximellerbach org inside a HF collection.
Usage:
uv run python convert_vla_jepa_checkpoints.py
For each variant the script:
1. Downloads the .pt checkpoint.
2. Extracts the state dict.
3. Instantiates VLAJEPAPolicy with the variant's confirmed config.
4. Loads the state dict (strict=False — mismatches printed to stdout).
5. push_to_hub → writes model.safetensors + config.json in LeRobot format.
6. Adds the new repo to a shared HF collection.
Config sources
--------------
Numeric hyper-params : ginwind/VLA-JEPA/<variant>/config.json
Image keys LIBERO : lerobot/libero_10 meta/info.json
Image keys Pretrain : lerobot/droid_1.0.1 meta/info.json
Image keys SimplerEnv: OXE Bridge/RT1 single-camera (view x2)
"""
from __future__ import annotations
import logging
import tempfile
from pathlib import Path
import torch
from huggingface_hub import HfApi
from safetensors.torch import save_file as save_safetensors
from lerobot.policies.vla_jepa.processor_vla_jepa import make_vla_jepa_pre_post_processors
# ---------------------------------------------------------------------------
# Top-level settings
# ---------------------------------------------------------------------------
SOURCE_REPO_ID = "ginwind/VLA-JEPA"
TARGET_ORG = "maximellerbach"
COLLECTION_TITLE = "VLA-JEPA"
COLLECTION_DESCRIPTION = (
"VLA-JEPA model checkpoints (LIBERO, Pretrain, SimplerEnv) converted from .pt to safetensors via LeRobot."
)
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
log = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Key mapping — mirrors todo_converter.py map_key() so both converters
# produce identical safetensors layouts that match the LeRobot action_head code.
# ---------------------------------------------------------------------------
def _normalize_source_key(key: str) -> str:
return key[len("module.") :] if key.startswith("module.") else key
def _map_checkpoint_key(raw_key: str) -> str | None:
"""Map original VLA-JEPA state-dict keys to LeRobot vla_jepa layout."""
key = _normalize_source_key(raw_key)
if key.startswith("qwen_vl_interface."):
return "model.qwen." + key[len("qwen_vl_interface.") :]
if key.startswith("vj_encoder."):
return "model.video_encoder." + key[len("vj_encoder.") :]
if key.startswith("vj_predictor."):
return "model.video_predictor." + key[len("vj_predictor.") :]
if key.startswith("action_model."):
# LeRobot code uses the same sub-key names as the source checkpoint,
# so only the top-level "model." prefix needs to be added.
return "model." + key
return None
def _fetch_dataset_stats(api: HfApi, source_repo_id: str, subfolder: str) -> dict | None:
"""Download dataset_statistics.json and return {action: {...}, state: {...}} stats dict."""
import json
stats_file = f"{subfolder}/dataset_statistics.json"
try:
local = api.hf_hub_download(source_repo_id, stats_file)
data = json.loads(Path(local).read_text())
# Original repo nests stats under a robot key, e.g. {"franka": {"action": {...}, "state": {...}}}
for robot_key in data:
robot_data = data[robot_key]
if isinstance(robot_data, dict) and "action" in robot_data:
log.info(" Loaded dataset stats from %s (robot key: %s)", stats_file, robot_key)
result = {"action": robot_data["action"]}
if "state" in robot_data:
result["observation.state"] = robot_data["state"]
log.info(" Also loaded state stats.")
return result
log.warning(" %s found but no 'action' key under any robot key — skipping stats.", stats_file)
except Exception as exc: # noqa: BLE001
log.warning(" Could not fetch %s: %s — postprocessor will have no unnorm stats.", stats_file, exc)
return None
def _set_if_present(d: dict, key: str, value) -> None:
if value is not None:
d[key] = value
def _deep_get(mapping: dict, path: tuple, default=None):
current = mapping
for key in path:
if not isinstance(current, dict) or key not in current:
return default
current = current[key]
return current
def _fetch_source_config(api: HfApi, source_repo_id: str, subfolder: str) -> dict:
"""Download config.yaml from the source HF repo for a given variant subfolder."""
try:
import yaml
except ImportError:
log.warning("PyYAML not installed — cannot apply source config.yaml overrides.")
return {}
config_file = f"{subfolder}/config.yaml"
try:
local = api.hf_hub_download(source_repo_id, config_file)
data = yaml.safe_load(Path(local).read_text()) or {}
if isinstance(data, dict):
log.info(" Loaded source config from %s", config_file)
return data
except Exception as exc: # noqa: BLE001
log.warning(" Could not fetch %s: %s — using hardcoded defaults.", config_file, exc)
return {}
def _apply_source_config(kwargs: dict, source_config: dict) -> None:
"""Apply ginwind/VLA-JEPA config.yaml values to kwargs, mirroring todo_converter.py logic."""
if not source_config:
return
data_cfg = _deep_get(source_config, ("datasets", "vla_data"), {})
action_cfg = _deep_get(source_config, ("framework", "action_model"), {})
diffusion_cfg = _deep_get(source_config, ("framework", "action_model", "diffusion_model_cfg"), {})
video_cfg = _deep_get(source_config, ("framework", "vj2_model"), {})
trainer_cfg = source_config.get("trainer", {})
prompt_template = data_cfg.get("CoT_prompt")
if prompt_template:
kwargs["prompt_template"] = str(prompt_template)
action_horizon = action_cfg.get("action_horizon")
if action_horizon is not None:
kwargs["chunk_size"] = int(action_horizon)
kwargs["n_action_steps"] = int(action_horizon)
_set_if_present(
kwargs,
"num_action_tokens_per_timestep",
video_cfg.get("num_action_tokens_per_timestep", action_cfg.get("num_action_tokens_per_timestep")),
)
_set_if_present(
kwargs,
"num_embodied_action_tokens_per_instruction",
video_cfg.get(
"num_embodied_action_tokens_per_instruction",
action_cfg.get("num_embodied_action_tokens_per_instruction"),
),
)
_set_if_present(kwargs, "num_inference_timesteps", action_cfg.get("num_inference_timesteps"))
_set_if_present(kwargs, "special_action_token", video_cfg.get("special_action_token"))
_set_if_present(kwargs, "embodied_action_token", video_cfg.get("embodied_action_token"))
_set_if_present(
kwargs, "action_hidden_size", action_cfg.get("action_hidden_dim", action_cfg.get("hidden_size"))
)
_set_if_present(kwargs, "action_model_type", action_cfg.get("action_model_type"))
_set_if_present(kwargs, "action_noise_beta_alpha", action_cfg.get("noise_beta_alpha"))
_set_if_present(kwargs, "action_noise_beta_beta", action_cfg.get("noise_beta_beta"))
_set_if_present(kwargs, "action_noise_s", action_cfg.get("noise_s"))
_set_if_present(kwargs, "action_num_timestep_buckets", action_cfg.get("num_timestep_buckets"))
_set_if_present(kwargs, "repeated_diffusion_steps", action_cfg.get("repeated_diffusion_steps"))
_set_if_present(kwargs, "action_num_layers", diffusion_cfg.get("num_layers"))
_set_if_present(kwargs, "action_dropout", diffusion_cfg.get("dropout"))
_set_if_present(kwargs, "num_video_frames", video_cfg.get("num_frames"))
_set_if_present(kwargs, "predictor_depth", video_cfg.get("predictor_depth", video_cfg.get("depth")))
_set_if_present(
kwargs, "predictor_num_heads", video_cfg.get("predictor_num_heads", video_cfg.get("num_heads"))
)
_set_if_present(kwargs, "predictor_mlp_ratio", video_cfg.get("predictor_mlp_ratio"))
_set_if_present(kwargs, "optimizer_grad_clip_norm", trainer_cfg.get("max_grad_norm"))
learning_rate = trainer_cfg.get("learning_rate", {})
if isinstance(learning_rate, dict):
_set_if_present(kwargs, "optimizer_lr", learning_rate.get("action_model"))
optimizer_cfg = trainer_cfg.get("optimizer", {})
if isinstance(optimizer_cfg, dict):
_set_if_present(kwargs, "optimizer_eps", optimizer_cfg.get("eps"))
_set_if_present(kwargs, "optimizer_weight_decay", optimizer_cfg.get("weight_decay"))
betas = optimizer_cfg.get("betas")
if betas is not None:
kwargs["optimizer_betas"] = tuple(betas)
scheduler = trainer_cfg.get("scheduler", {})
if isinstance(scheduler, dict):
_set_if_present(kwargs, "scheduler_warmup_steps", scheduler.get("warmup_steps"))
_set_if_present(kwargs, "scheduler_decay_lr", scheduler.get("min_lr"))
_set_if_present(kwargs, "scheduler_warmup_steps", trainer_cfg.get("num_warmup_steps"))
scheduler_kwargs = trainer_cfg.get("scheduler_specific_kwargs", {})
if isinstance(scheduler_kwargs, dict):
_set_if_present(kwargs, "scheduler_decay_lr", scheduler_kwargs.get("min_lr"))
# ---------------------------------------------------------------------------
# Architecture — identical across all 4 variants (from config.json)
# ---------------------------------------------------------------------------
_ARCH = {
"qwen_model_name": "Qwen/Qwen3-VL-2B-Instruct", # 2B, NOT the default 4B
"chunk_size": 7,
"n_action_steps": 7,
"num_video_frames": 8,
"jepa_tubelet_size": 2,
"num_action_tokens_per_timestep": 8,
"num_embodied_action_tokens_per_instruction": 32,
"num_inference_timesteps": 4,
"action_hidden_size": 1024,
"action_model_type": "DiT-B",
# Explicit dims matching DiT-B preset and ginwind checkpoint shape
"action_num_heads": 12,
"action_attention_head_dim": 64,
"action_num_layers": 16,
"action_dropout": 0.2,
"repeated_diffusion_steps": 8,
"action_noise_beta_alpha": 1.5,
"action_noise_beta_beta": 1.0,
"action_noise_s": 0.999,
"action_num_timestep_buckets": 1000,
# World model predictor (12 blocks, confirmed from checkpoint)
"predictor_depth": 12,
}
# ---------------------------------------------------------------------------
# Image-key sets (confirmed sources in module docstring)
# ---------------------------------------------------------------------------
# LIBERO — confirmed from lerobot/libero_10 meta/info.json
_LIBERO_CAMS = [
"observation.images.image", # agentview camera
"observation.images.image2", # eye-in-hand camera
]
# DROID pretrain — 2 views match the predictor embed_dim=2 × 1024=2048 in checkpoint
_DROID_CAMS = [
"observation.images.exterior_1_left",
"observation.images.exterior_2_left",
]
# OXE Bridge + RT1 — single-camera; view is duplicated at runtime to match the 2-view world model
_OXE_CAMS = [
"observation.images.image",
]
# ---------------------------------------------------------------------------
# Config factories
# ---------------------------------------------------------------------------
def _build_config(
camera_keys: list[str],
with_state: bool,
enable_world_model: bool = True,
source_config: dict | None = None,
):
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.vla_jepa.configuration_vla_jepa import VLAJEPAConfig
kwargs = dict(_ARCH)
_apply_source_config(kwargs, source_config or {})
# Image resolution: prefer source config, fall back to 224
data_cfg = _deep_get(source_config or {}, ("datasets", "vla_data"), {})
raw_res = data_cfg.get("resolution_size")
resolution_size = int(raw_res) if raw_res is not None else 224
image_shape = (3, resolution_size, resolution_size)
# Always set resize_images_to so the policy resizes env images to the training resolution,
# regardless of what resolution the eval env renders at.
kwargs["resize_images_to"] = (resolution_size, resolution_size)
# State / action dims: prefer source config
action_cfg = _deep_get(source_config or {}, ("framework", "action_model"), {})
state_dim = int(action_cfg["state_dim"]) if "state_dim" in action_cfg else 8
action_dim = int(action_cfg["action_dim"]) if "action_dim" in action_cfg else 7
input_features = {k: PolicyFeature(type=FeatureType.VISUAL, shape=image_shape) for k in camera_keys}
if with_state:
input_features["observation.state"] = PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))
cfg = VLAJEPAConfig(
input_features=input_features,
output_features={
"action": PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,)),
},
enable_world_model=enable_world_model,
binarize_gripper_action=True,
clip_normalized_actions=True,
**kwargs,
)
cfg.validate_features()
return cfg
# Maps each subfolder in SOURCE_REPO_ID to (camera_keys, with_state, enable_world_model, repo_suffix)
VARIANTS: dict[str, tuple] = {
"LIBERO": (_LIBERO_CAMS, True, True, "LIBERO"),
"Pretrain": (_DROID_CAMS, False, True, "Pretrain"),
# SimplerEnv uses a single camera; the single view is duplicated at runtime to produce
# the 2-view input the world model expects (embed_dim=2048).
"SimplerEnv": (_OXE_CAMS, False, True, "SimplerEnv"),
}
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def extract_state_dict(ckpt: object) -> dict[str, torch.Tensor]:
if isinstance(ckpt, dict):
sd = ckpt.get("state_dict") or ckpt.get("model_state_dict") or ckpt.get("model")
if sd is None:
sd = ckpt
else:
sd = ckpt
return {k: v for k, v in sd.items() if isinstance(v, torch.Tensor)}
def subfolder_of(pt_path: str) -> str | None:
for part in Path(pt_path).parts:
if part in VARIANTS:
return part
return None
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main() -> None:
api = HfApi()
log.info("Listing .pt files in %s", SOURCE_REPO_ID)
pt_files = [f for f in api.list_repo_files(SOURCE_REPO_ID) if f.endswith(".pt")]
if not pt_files:
log.error("No .pt files found.")
return
for f in pt_files:
log.info(" %s", f)
# Create / reuse the collection once
collection = api.create_collection(
title=COLLECTION_TITLE,
description=COLLECTION_DESCRIPTION,
namespace=TARGET_ORG,
exists_ok=True,
)
log.info("Collection: %s", collection.url)
for pt_filename in pt_files:
log.info("\n=== %s ===", pt_filename)
subfolder = subfolder_of(pt_filename)
if subfolder is None:
log.warning(" No variant entry for '%s' — skipping.", pt_filename)
continue
camera_keys, with_state, enable_world_model, repo_suffix = VARIANTS[subfolder]
target_repo_id = f"{TARGET_ORG}/VLA-JEPA-{repo_suffix}"
log.info(
" cameras=%d with_state=%s wm=%s%s",
len(camera_keys),
with_state,
enable_world_model,
target_repo_id,
)
# 1. Download
local_pt = api.hf_hub_download(SOURCE_REPO_ID, pt_filename)
log.info(" Downloaded → %s", local_pt)
# 2. Load checkpoint
try:
ckpt = torch.load(local_pt, map_location="cpu", mmap=True, weights_only=False) # nosec B614
except TypeError:
ckpt = torch.load(local_pt, map_location="cpu") # nosec B614
sd = extract_state_dict(ckpt)
# Map source key names → LeRobot layout (handles layer1→w1, transformer_blocks→blocks, etc.)
mapped_sd: dict[str, torch.Tensor] = {}
skipped_keys: list[str] = []
for raw_key, value in sd.items():
target_key = _map_checkpoint_key(raw_key)
if target_key is None:
skipped_keys.append(raw_key)
else:
mapped_sd[target_key] = value
log.info(" %d tensors mapped, %d skipped", len(mapped_sd), len(skipped_keys))
if skipped_keys:
log.info(" Skipped sample: %s", skipped_keys[:5])
log.info(" First 5 mapped keys: %s", list(mapped_sd)[:5])
# 3. Fetch action + state stats needed by the pre/postprocessor unnormalizers
dataset_stats = _fetch_dataset_stats(api, SOURCE_REPO_ID, subfolder)
# 4. Build config (no policy instantiation — avoids loading backbone from Hub)
source_config = _fetch_source_config(api, SOURCE_REPO_ID, subfolder)
config = _build_config(camera_keys, with_state, enable_world_model, source_config)
# 5. Save everything to a temp dir and upload in one shot
api.create_repo(target_repo_id, repo_type="model", exist_ok=True)
with tempfile.TemporaryDirectory() as tmp:
save_dir = Path(tmp)
log.info(" Saving model.safetensors …")
save_safetensors(mapped_sd, save_dir / "model.safetensors")
preprocessor, postprocessor = make_vla_jepa_pre_post_processors(config, dataset_stats)
preprocessor.save_pretrained(save_dir) # writes policy_preprocessor.json
postprocessor.save_pretrained(save_dir) # writes policy_postprocessor.json
config.device = None # don't bake in the conversion machine's device
config._save_pretrained(save_dir) # writes config.json via draccus
log.info(" Uploading …")
commit_url = api.upload_folder(
folder_path=save_dir,
repo_id=target_repo_id,
repo_type="model",
commit_message=f"Convert {Path(pt_filename).name} to safetensors",
)
log.info(" Uploaded → %s", commit_url)
# 6. Add to collection
api.add_collection_item(
collection_slug=collection.slug,
item_id=target_repo_id,
item_type="model",
exists_ok=True,
)
log.info(" Added to collection.")
log.info("\nAll done. Collection: %s", collection.url)
if __name__ == "__main__":
main()

View File

@@ -287,7 +287,7 @@ def _make_hub_inference_batch(policy: VLAJEPAPolicy, batch_size: int = 1) -> dic
return batch
_CP_ROOT = "lerobot" # TODO: upload converted checkpoints
_CP_ROOT = "lerobot"
# Each tuple: (repo_id, enable_world_model)
_HUB_VARIANTS = [