mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-30 10:21:24 +00:00
adding instructions for different embodiement + fixing some tests
This commit is contained in:
@@ -66,15 +66,16 @@ All checkpoints use `Qwen/Qwen3-VL-2B-Instruct` as the language backbone.
|
|||||||
|
|
||||||
Key parameters in `VLAJEPAConfig`:
|
Key parameters in `VLAJEPAConfig`:
|
||||||
|
|
||||||
| Parameter | Default | Description |
|
| Parameter | Default | Description |
|
||||||
| ------------------------- | ------- | -------------------------------------------------------------- |
|
| ------------------------- | ------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||||
| `chunk_size` | 7 | Number of actions predicted per inference call |
|
| `chunk_size` | 7 | Number of actions predicted per inference call |
|
||||||
| `n_action_steps` | 7 | Steps executed from the predicted chunk before re-planning |
|
| `n_action_steps` | 7 | Steps executed from the predicted chunk before re-planning |
|
||||||
| `num_video_frames` | 8 | Video clip length fed to the world model |
|
| `num_video_frames` | 8 | Video clip length fed to the world model |
|
||||||
| `enable_world_model` | `True` | Whether to load and train the V-JEPA2 predictor |
|
| `enable_world_model` | `True` | Whether to load and train the V-JEPA2 predictor |
|
||||||
| `world_model_loss_weight` | 0.1 | Weight of the JEPA prediction loss relative to the action loss |
|
| `world_model_loss_weight` | 0.1 | Weight of the JEPA prediction loss relative to the action loss |
|
||||||
| `num_inference_timesteps` | 4 | Euler integration steps for action denoising |
|
| `num_inference_timesteps` | 4 | Euler integration steps for action denoising |
|
||||||
| `freeze_qwen` | `False` | Freeze the Qwen3-VL backbone and only train the action head |
|
| `freeze_qwen` | `False` | Freeze the Qwen3-VL backbone and only train the action head |
|
||||||
|
| `reinit_modules` | `None` | Key prefixes allowed to be randomly re-initialised on load (for cross-embodiment transfer, see [Fine-tuning on a different embodiment](#fine-tuning-on-a-different-embodiment)) |
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
@@ -110,6 +111,29 @@ lerobot-train \
|
|||||||
--dataset.repo_id=your_org/your_dataset
|
--dataset.repo_id=your_org/your_dataset
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Fine-tuning on a different embodiment
|
||||||
|
|
||||||
|
When the target robot has a different action or state dimensionality than the pretrained checkpoint, the input/output projection layers of the action head will have mismatched shapes and cannot be loaded directly. `reinit_modules` lets you list the key prefixes that are allowed to mismatch — those layers are randomly re-initialised while every other weight is reused from the checkpoint. Any shape mismatch outside the listed prefixes raises an error.
|
||||||
|
|
||||||
|
The layers that depend on `action_dim` and `state_dim` are:
|
||||||
|
|
||||||
|
| Layer | Key prefix |
|
||||||
|
| ----------------------------------------- | ----------------------------------- |
|
||||||
|
| Action encoder (action_dim → inner_dim) | `model.action_model.action_encoder` |
|
||||||
|
| Action decoder (hidden_size → action_dim) | `model.action_model.action_decoder` |
|
||||||
|
| State encoder (state_dim → inner_dim) | `model.action_model.state_encoder` |
|
||||||
|
|
||||||
|
```bash
|
||||||
|
lerobot-train \
|
||||||
|
--policy.path=lerobot/VLA-JEPA-Pretrain \
|
||||||
|
--policy.repo_id=your_org/your_repo \
|
||||||
|
--policy.freeze_qwen=true \
|
||||||
|
--policy.reinit_modules='["model.action_model.action_encoder", "model.action_model.action_decoder", "model.action_model.state_encoder"]' \
|
||||||
|
--dataset.repo_id=your_org/your_dataset
|
||||||
|
```
|
||||||
|
|
||||||
|
If your robot has no proprioceptive state, omit `model.action_model.state_encoder` from the list.
|
||||||
|
|
||||||
### Reproducing the LIBERO results
|
### Reproducing the LIBERO results
|
||||||
|
|
||||||
**Training on LIBERO:**
|
**Training on LIBERO:**
|
||||||
@@ -132,7 +156,7 @@ lerobot-eval \
|
|||||||
--env.type=libero \
|
--env.type=libero \
|
||||||
--env.task=libero_spatial,libero_object,libero_goal,libero_10 \
|
--env.task=libero_spatial,libero_object,libero_goal,libero_10 \
|
||||||
--eval.n_episodes=10 \
|
--eval.n_episodes=10 \
|
||||||
--eval.batch_size=5 \
|
--eval.batch_size=5
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -145,9 +169,19 @@ lerobot-eval \
|
|||||||
--env.task=libero_10 \
|
--env.task=libero_10 \
|
||||||
--env.task_ids='[0,1,2]' \
|
--env.task_ids='[0,1,2]' \
|
||||||
--eval.n_episodes=10 \
|
--eval.n_episodes=10 \
|
||||||
--eval.batch_size=5 \
|
--eval.batch_size=5
|
||||||
```
|
```
|
||||||
|
|
||||||
|
**Expected results:**
|
||||||
|
|
||||||
|
| Suite | Episodes | Successes | Success Rate |
|
||||||
|
| -------------- | -------- | --------- | ------------ |
|
||||||
|
| libero_spatial | 100 | 93 | **95.0%** |
|
||||||
|
| libero_object | 100 | 100 | **100.0%** |
|
||||||
|
| libero_goal | 100 | 98 | **98.0%** |
|
||||||
|
| libero_10 | 100 | 96 | **93.0%** |
|
||||||
|
| **Overall** | **400** | **387** | **96.5%** |
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## Fine-tuning on single-camera datasets
|
## Fine-tuning on single-camera datasets
|
||||||
|
|||||||
@@ -27,7 +27,12 @@ class VLAJEPAConfig(PreTrainedConfig):
|
|||||||
jepa_encoder_name: str = "facebook/vjepa2-vitl-fpc64-256"
|
jepa_encoder_name: str = "facebook/vjepa2-vitl-fpc64-256"
|
||||||
freeze_qwen: bool = False
|
freeze_qwen: bool = False
|
||||||
enable_world_model: bool = True
|
enable_world_model: bool = True
|
||||||
reinit_action_head: bool = False
|
# Enables cross-embodiment transfer: when fine-tuning a pretrained model on a robot with a
|
||||||
|
# different action or state dimensionality, the input/output projection layers must be
|
||||||
|
# re-initialised from scratch while the rest of the network keeps its pretrained weights.
|
||||||
|
# List the key prefixes that are allowed to have shape mismatches; anything else raises an error.
|
||||||
|
# e.g. ["model.action_model.action_encoder", "model.action_model.state_encoder"]
|
||||||
|
reinit_modules: list[str] | None = None
|
||||||
|
|
||||||
tokenizer_padding_side: str = "left"
|
tokenizer_padding_side: str = "left"
|
||||||
prompt_template: str = "Your task is {instruction}. Infer the temporal dynamics from frames {actions} and produce the corresponding policy actions {e_actions}."
|
prompt_template: str = "Your task is {instruction}. Infer the temporal dynamics from frames {actions} and produce the corresponding policy actions {e_actions}."
|
||||||
|
|||||||
@@ -219,14 +219,9 @@ class VLAJEPAModel(nn.Module):
|
|||||||
b, v, t_frames, c, h_img, w_img = batch_videos.shape
|
b, v, t_frames, c, h_img, w_img = batch_videos.shape
|
||||||
batch_videos_flat = batch_videos.reshape(b * v, t_frames, c, h_img, w_img)
|
batch_videos_flat = batch_videos.reshape(b * v, t_frames, c, h_img, w_img)
|
||||||
|
|
||||||
video_pixels = []
|
video_pixels = self.video_processor(videos=list(batch_videos_flat), return_tensors="pt")[
|
||||||
for i in range(b * v):
|
"pixel_values_videos"
|
||||||
video_pixels.append(
|
].to(self.video_encoder.device) # [B*V, T, C, H, W]
|
||||||
self.video_processor(videos=batch_videos_flat[i], return_tensors="pt")[
|
|
||||||
"pixel_values_videos"
|
|
||||||
].to(self.video_encoder.device)
|
|
||||||
)
|
|
||||||
video_pixels = torch.cat(video_pixels, dim=0) # [B*V, T, C, H, W]
|
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
video_embeddings = self.video_encoder.get_vision_features(pixel_values_videos=video_pixels)
|
video_embeddings = self.video_encoder.get_vision_features(pixel_values_videos=video_pixels)
|
||||||
@@ -572,11 +567,8 @@ class VLAJEPAPolicy(PreTrainedPolicy):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
|
def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
|
||||||
"""
|
reinit_prefixes = model.config.reinit_modules
|
||||||
Custom loading to enable opt reinit of action head
|
if not reinit_prefixes:
|
||||||
when loading pretrained weights with mismatched action head shapes.
|
|
||||||
"""
|
|
||||||
if not model.config.reinit_action_head:
|
|
||||||
return super()._load_as_safetensor(model, model_file, map_location, strict)
|
return super()._load_as_safetensor(model, model_file, map_location, strict)
|
||||||
|
|
||||||
from safetensors.torch import load_file
|
from safetensors.torch import load_file
|
||||||
@@ -584,20 +576,25 @@ class VLAJEPAPolicy(PreTrainedPolicy):
|
|||||||
state_dict = load_file(model_file, device=map_location)
|
state_dict = load_file(model_file, device=map_location)
|
||||||
current = model.state_dict()
|
current = model.state_dict()
|
||||||
|
|
||||||
mismatched: list[str] = []
|
reinitialized: list[str] = []
|
||||||
filtered: dict = {}
|
filtered: dict = {}
|
||||||
for key, value in state_dict.items():
|
for key, value in state_dict.items():
|
||||||
if key in current and value.shape != current[key].shape:
|
if key in current and value.shape != current[key].shape:
|
||||||
mismatched.append(
|
if not any(key.startswith(p) for p in reinit_prefixes):
|
||||||
f"{key}: checkpoint {tuple(value.shape)} vs model {tuple(current[key].shape)}"
|
raise ValueError(
|
||||||
|
f"Shape mismatch for '{key}' (checkpoint {tuple(value.shape)} vs model "
|
||||||
|
f"{tuple(current[key].shape)}) and its prefix is not in `reinit_modules`."
|
||||||
|
)
|
||||||
|
reinitialized.append(
|
||||||
|
f"{key}: checkpoint {tuple(value.shape)} → model {tuple(current[key].shape)}"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
filtered[key] = value
|
filtered[key] = value
|
||||||
|
|
||||||
if mismatched:
|
if reinitialized:
|
||||||
logging.warning(
|
logging.warning(
|
||||||
f"reinit_action_head=True: skipping {len(mismatched)} tensor(s) with mismatched shapes "
|
f"reinit_modules: skipping {len(reinitialized)} tensor(s) with mismatched shapes "
|
||||||
f"(randomly re-initialised):\n " + "\n ".join(mismatched)
|
f"(randomly re-initialised):\n " + "\n ".join(reinitialized)
|
||||||
)
|
)
|
||||||
|
|
||||||
from lerobot.policies.utils import log_model_loading_keys
|
from lerobot.policies.utils import log_model_loading_keys
|
||||||
|
|||||||
@@ -169,6 +169,7 @@ class _FakeQwenBackbone(nn.Module):
|
|||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
).view(batch_size, seq_len, hidden_size)
|
).view(batch_size, seq_len, hidden_size)
|
||||||
hidden = values / values.numel() + self.weight
|
hidden = values / values.numel() + self.weight
|
||||||
|
self.model(input_ids) # call through so the forward hook on layers[-1] fires
|
||||||
return SimpleNamespace(hidden_states=[hidden])
|
return SimpleNamespace(hidden_states=[hidden])
|
||||||
|
|
||||||
|
|
||||||
@@ -241,9 +242,13 @@ class _FakeVideoEncoder(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class _FakeVideoProcessor:
|
class _FakeVideoProcessor:
|
||||||
def __call__(self, videos: np.ndarray, return_tensors: str) -> dict[str, Tensor]:
|
def __call__(self, videos, return_tensors: str) -> dict[str, Tensor]:
|
||||||
assert return_tensors == "pt"
|
assert return_tensors == "pt"
|
||||||
return {"pixel_values_videos": torch.as_tensor(videos).unsqueeze(0)}
|
if isinstance(videos, list):
|
||||||
|
pixel_values = torch.stack([torch.as_tensor(v) for v in videos])
|
||||||
|
else:
|
||||||
|
pixel_values = torch.as_tensor(videos).unsqueeze(0)
|
||||||
|
return {"pixel_values_videos": pixel_values}
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|||||||
Reference in New Issue
Block a user