mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-02 20:01:25 +00:00
adding instructions for different embodiement + fixing some tests
This commit is contained in:
@@ -27,7 +27,12 @@ class VLAJEPAConfig(PreTrainedConfig):
|
||||
jepa_encoder_name: str = "facebook/vjepa2-vitl-fpc64-256"
|
||||
freeze_qwen: bool = False
|
||||
enable_world_model: bool = True
|
||||
reinit_action_head: bool = False
|
||||
# Enables cross-embodiment transfer: when fine-tuning a pretrained model on a robot with a
|
||||
# different action or state dimensionality, the input/output projection layers must be
|
||||
# re-initialised from scratch while the rest of the network keeps its pretrained weights.
|
||||
# List the key prefixes that are allowed to have shape mismatches; anything else raises an error.
|
||||
# e.g. ["model.action_model.action_encoder", "model.action_model.state_encoder"]
|
||||
reinit_modules: list[str] | None = None
|
||||
|
||||
tokenizer_padding_side: str = "left"
|
||||
prompt_template: str = "Your task is {instruction}. Infer the temporal dynamics from frames {actions} and produce the corresponding policy actions {e_actions}."
|
||||
|
||||
@@ -219,14 +219,9 @@ class VLAJEPAModel(nn.Module):
|
||||
b, v, t_frames, c, h_img, w_img = batch_videos.shape
|
||||
batch_videos_flat = batch_videos.reshape(b * v, t_frames, c, h_img, w_img)
|
||||
|
||||
video_pixels = []
|
||||
for i in range(b * v):
|
||||
video_pixels.append(
|
||||
self.video_processor(videos=batch_videos_flat[i], return_tensors="pt")[
|
||||
"pixel_values_videos"
|
||||
].to(self.video_encoder.device)
|
||||
)
|
||||
video_pixels = torch.cat(video_pixels, dim=0) # [B*V, T, C, H, W]
|
||||
video_pixels = self.video_processor(videos=list(batch_videos_flat), return_tensors="pt")[
|
||||
"pixel_values_videos"
|
||||
].to(self.video_encoder.device) # [B*V, T, C, H, W]
|
||||
|
||||
with torch.no_grad():
|
||||
video_embeddings = self.video_encoder.get_vision_features(pixel_values_videos=video_pixels)
|
||||
@@ -572,11 +567,8 @@ class VLAJEPAPolicy(PreTrainedPolicy):
|
||||
|
||||
@classmethod
|
||||
def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
|
||||
"""
|
||||
Custom loading to enable opt reinit of action head
|
||||
when loading pretrained weights with mismatched action head shapes.
|
||||
"""
|
||||
if not model.config.reinit_action_head:
|
||||
reinit_prefixes = model.config.reinit_modules
|
||||
if not reinit_prefixes:
|
||||
return super()._load_as_safetensor(model, model_file, map_location, strict)
|
||||
|
||||
from safetensors.torch import load_file
|
||||
@@ -584,20 +576,25 @@ class VLAJEPAPolicy(PreTrainedPolicy):
|
||||
state_dict = load_file(model_file, device=map_location)
|
||||
current = model.state_dict()
|
||||
|
||||
mismatched: list[str] = []
|
||||
reinitialized: list[str] = []
|
||||
filtered: dict = {}
|
||||
for key, value in state_dict.items():
|
||||
if key in current and value.shape != current[key].shape:
|
||||
mismatched.append(
|
||||
f"{key}: checkpoint {tuple(value.shape)} vs model {tuple(current[key].shape)}"
|
||||
if not any(key.startswith(p) for p in reinit_prefixes):
|
||||
raise ValueError(
|
||||
f"Shape mismatch for '{key}' (checkpoint {tuple(value.shape)} vs model "
|
||||
f"{tuple(current[key].shape)}) and its prefix is not in `reinit_modules`."
|
||||
)
|
||||
reinitialized.append(
|
||||
f"{key}: checkpoint {tuple(value.shape)} → model {tuple(current[key].shape)}"
|
||||
)
|
||||
else:
|
||||
filtered[key] = value
|
||||
|
||||
if mismatched:
|
||||
if reinitialized:
|
||||
logging.warning(
|
||||
f"reinit_action_head=True: skipping {len(mismatched)} tensor(s) with mismatched shapes "
|
||||
f"(randomly re-initialised):\n " + "\n ".join(mismatched)
|
||||
f"reinit_modules: skipping {len(reinitialized)} tensor(s) with mismatched shapes "
|
||||
f"(randomly re-initialised):\n " + "\n ".join(reinitialized)
|
||||
)
|
||||
|
||||
from lerobot.policies.utils import log_model_loading_keys
|
||||
|
||||
Reference in New Issue
Block a user