diff --git a/src/lerobot/policies/vla_jepa/configuration_vla_jepa.py b/src/lerobot/policies/vla_jepa/configuration_vla_jepa.py index 03760a9de..f803432c6 100644 --- a/src/lerobot/policies/vla_jepa/configuration_vla_jepa.py +++ b/src/lerobot/policies/vla_jepa/configuration_vla_jepa.py @@ -52,6 +52,7 @@ class VLAJEPAConfig(PreTrainedConfig): action_noise_beta_beta: float = 1.0 action_noise_s: float = 0.999 + # total video frames loaded per sample num_video_frames: int = 4 predictor_depth: int = 6 predictor_num_heads: int = 8 @@ -59,6 +60,8 @@ class VLAJEPAConfig(PreTrainedConfig): predictor_dropout: float = 0.0 world_model_loss_weight: float = 0.1 enable_world_model: bool = True + jepa_tubelet_size: int = 2 # must match the encoder (e.g. 2 for vjepa2-vitl-fpc64-256) + repeated_diffusion_steps: int = 4 # independent noise draws per batch item (CogACT-style) resize_images_to: tuple[int, int] | None = None torch_dtype: str = "bfloat16" @@ -78,8 +81,11 @@ class VLAJEPAConfig(PreTrainedConfig): raise ValueError("`n_action_steps` must be <= `chunk_size`.") if self.future_action_window_size + 1 > self.chunk_size: raise ValueError("`chunk_size` must cover the predicted action horizon.") - if self.num_video_frames < 2: - raise ValueError("`num_video_frames` must be >= 2 for JEPA prediction.") + if self.num_video_frames < 2 * self.jepa_tubelet_size: + raise ValueError( + f"`video_horizon` ({self.num_video_frames}) must be >= 2 * `jepa_tubelet_size` " + f"({self.jepa_tubelet_size}) to have at least one context and one GT temporal position." + ) def validate_features(self) -> None: if not self.image_features: @@ -109,7 +115,9 @@ class VLAJEPAConfig(PreTrainedConfig): @property def observation_delta_indices(self) -> list[int]: - return [0] + # load video_horizon frames starting from current timestep: [t, t+1, ..., t+video_horizon-1] + # matches original repo's observation_indices=list(range(video_horizon)) + return list(range(self.num_video_frames)) @property def action_delta_indices(self) -> list[int]: diff --git a/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py b/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py index 448fc34f8..ded318992 100644 --- a/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py +++ b/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py @@ -65,25 +65,34 @@ class VLAJEPAModel(nn.Module): self.action_model = VLAJEPAActionHead(config, cross_attention_dim=self.qwen.model.config.hidden_size) # JEPA world model components - self.video_encoder = AutoModel.from_pretrained( - config.jepa_encoder_name, - torch_dtype=self.qwen._get_torch_dtype(config.torch_dtype), - ) - self.video_processor = AutoVideoProcessor.from_pretrained(config.jepa_encoder_name) - self.video_predictor = ActionConditionedVideoPredictor( - embed_dim=self.video_encoder.config.hidden_size, - action_embed_dim=self.qwen.model.config.hidden_size, - predictor_embed_dim=self.video_encoder.config.hidden_size, - depth=config.predictor_depth, - num_heads=config.predictor_num_heads, - mlp_ratio=config.predictor_mlp_ratio, - num_action_tokens_per_step=config.num_action_tokens_per_timestep, - ) + if config.enable_world_model: + self.video_encoder = AutoModel.from_pretrained( + config.jepa_encoder_name, + torch_dtype=self.qwen._get_torch_dtype(config.torch_dtype), + ) + self.video_processor = AutoVideoProcessor.from_pretrained(config.jepa_encoder_name) + num_views = max(len(config.image_features), 1) + self.video_predictor = ActionConditionedVideoPredictor( + embed_dim=num_views * self.video_encoder.config.hidden_size, + action_embed_dim=self.qwen.model.config.hidden_size, + predictor_embed_dim=self.video_encoder.config.hidden_size, + depth=config.predictor_depth, + num_heads=config.predictor_num_heads, + mlp_ratio=config.predictor_mlp_ratio, + num_action_tokens_per_step=config.num_action_tokens_per_timestep, + ) + else: + self.video_encoder = None + self.video_processor = None + self.video_predictor = None - # Build prompt placeholders (same as original) + # Build prompt placeholders. + # Original uses num_frames // tubelet_size - 1 action token groups for the world model predictor. + # This matches the number of context temporal positions after tubelet compression. + n_wm_action_groups = max(1, self.config.num_video_frames // self.config.jepa_tubelet_size - 1) self.replace_prompt = "".join( token * self.config.num_action_tokens_per_timestep - for token in self.action_tokens[: self.config.num_video_frames - 1] + for token in self.action_tokens[:n_wm_action_groups] ) self.embodied_replace_prompt = ( self.config.embodied_action_token * self.config.num_embodied_action_tokens_per_instruction @@ -127,16 +136,18 @@ class VLAJEPAModel(nn.Module): embodied_prompt=self.embodied_replace_prompt, ) - # Locate action and embodied-action tokens in the tokenized sequence - action_mask = torch.isin( - qwen_inputs["input_ids"], - torch.tensor(self.action_token_ids, device=qwen_inputs["input_ids"].device), - ) - action_indices = action_mask.nonzero(as_tuple=True) - + # Locate embodied-action tokens (always needed for action head) embodied_mask = qwen_inputs["input_ids"] == self.embodied_action_token_id embodied_indices = embodied_mask.nonzero(as_tuple=True) + # Locate action tokens (only needed for world model predictor) + if self.config.enable_world_model: + action_mask = torch.isin( + qwen_inputs["input_ids"], + torch.tensor(self.action_token_ids, device=qwen_inputs["input_ids"].device), + ) + action_indices = action_mask.nonzero(as_tuple=True) + device_type = next(self.parameters()).device.type with torch.autocast(device_type=device_type, dtype=torch.bfloat16): @@ -149,68 +160,72 @@ class VLAJEPAModel(nn.Module): last_hidden = qwen_outputs.hidden_states[-1] # [B, seq_len, H] b, _, h = last_hidden.shape - action_tokens = last_hidden[action_indices[0], action_indices[1], :].view(b, -1, h) + if self.config.enable_world_model: + action_tokens = last_hidden[action_indices[0], action_indices[1], :].view(b, -1, h) embodied_action_tokens = last_hidden[embodied_indices[0], embodied_indices[1], :].view(b, -1, h) - # ---- Step 2: JEPA Encoder (same as original) ---- - b, v, t_frames, c, h_img, w_img = batch_videos.shape - batch_videos_flat = batch_videos.reshape(b * v, t_frames, c, h_img, w_img) - - video_pixels = [] - for i in range(b * v): - video_pixels.append( - self.video_processor(videos=batch_videos_flat[i], return_tensors="pt")[ - "pixel_values_videos" - ].to(self.video_encoder.device) - ) - video_pixels = torch.cat(video_pixels, dim=0) # [B*V, T, C, H, W] - - with torch.no_grad(): - video_embeddings = self.video_encoder.get_vision_features(pixel_values_videos=video_pixels) - # Merge views: [B*V, ...] -> [B, ..., V*embed_dim] - video_embeddings = torch.cat(torch.chunk(video_embeddings, chunks=v, dim=0), dim=2) - - # ---- Step 3: JEPA Predictor (same as original) ---- - tubelet_size = self.video_encoder.config.tubelet_size - t_enc = t_frames // tubelet_size - device_wm = video_embeddings.device - - if t_enc < 2: - # Not enough frames for JEPA prediction (need at least 2 encoded frames) + # ---- Step 2+3: JEPA Encoder + Predictor ---- + device_wm = last_hidden.device + if not self.config.enable_world_model: wm_loss = torch.tensor(0.0, device=device_wm) else: - tokens_per_frame = video_embeddings.shape[1] // t_enc + b, v, t_frames, c, h_img, w_img = batch_videos.shape + batch_videos_flat = batch_videos.reshape(b * v, t_frames, c, h_img, w_img) - # input_states: frames 0..T-2 [B, (T-1)*tokens_per_frame, D] - # gt_states: frames 1..T-1 [B, (T-1)*tokens_per_frame, D] - input_states = video_embeddings[:, : tokens_per_frame * (t_enc - 1), :] - gt_states = video_embeddings[:, tokens_per_frame:, :] - d_emb = input_states.shape[-1] + video_pixels = [] + for i in range(b * v): + video_pixels.append( + self.video_processor(videos=batch_videos_flat[i], return_tensors="pt")[ + "pixel_values_videos" + ].to(self.video_encoder.device) + ) + video_pixels = torch.cat(video_pixels, dim=0) # [B*V, T, C, H, W] - # Reshape to 4D for ActionConditionedVideoPredictor: - # [B, (T-1)*tokens, D] → [B, T-1, tokens, D] - input_states_4d = input_states.view(b, t_enc - 1, tokens_per_frame, d_emb) + with torch.no_grad(): + video_embeddings = self.video_encoder.get_vision_features(pixel_values_videos=video_pixels) + # Merge views: [B*V, ...] -> [B, ..., V*embed_dim] + video_embeddings = torch.cat(torch.chunk(video_embeddings, chunks=v, dim=0), dim=2) - # Reshape action tokens: [B, total_acts, D] → [B, T-1, per_step, D] - expected_actions = (t_enc - 1) * self.config.num_action_tokens_per_timestep - if action_tokens.shape[1] < expected_actions: - pad = action_tokens[:, -1:].repeat(1, expected_actions - action_tokens.shape[1], 1) - action_tokens = torch.cat([action_tokens, pad], dim=1) - act_4d = action_tokens[:, :expected_actions].view( - b, t_enc - 1, self.config.num_action_tokens_per_timestep, -1 - ) + tubelet_size = self.video_encoder.config.tubelet_size + t_enc_ctx = self.config.n_obs_steps // tubelet_size + t_enc_fut = self.config.n_future_frames // tubelet_size + device_wm = video_embeddings.device - # Cast to float32 for predictor (Linear layers are float32) - pred_4d = self.video_predictor(input_states_4d.float(), act_4d.float()) - predicted_states = pred_4d.reshape(b, -1, d_emb) + if t_enc_ctx < 1 or t_enc_fut < 1: + # not enough frames for one tubelet -> skip world model loss + wm_loss = torch.tensor(0.0, device=device_wm) + else: + t_enc_total = t_enc_ctx + t_enc_fut + tokens_per_frame = video_embeddings.shape[1] // t_enc_total - wm_loss = F.l1_loss(predicted_states, gt_states.float(), reduction="mean") + # context: encoded frames [t-(n_obs-1)..t+0] + # future: encoded frames [t+1..t+n_future] + input_states = video_embeddings[:, : tokens_per_frame * t_enc_ctx, :] + gt_states = video_embeddings[:, tokens_per_frame * t_enc_ctx :, :] + d_emb = input_states.shape[-1] + + # [B, t_enc_ctx*tokens, D] → [B, t_enc_ctx, tokens, D] + input_states_4d = input_states.view(b, t_enc_ctx, tokens_per_frame, d_emb) + + # Action tokens conditioning: one group per context step + expected_actions = t_enc_ctx * self.config.num_action_tokens_per_timestep + if action_tokens.shape[1] < expected_actions: + pad = action_tokens[:, -1:].repeat(1, expected_actions - action_tokens.shape[1], 1) + action_tokens = torch.cat([action_tokens, pad], dim=1) + act_4d = action_tokens[:, :expected_actions].view( + b, t_enc_ctx, self.config.num_action_tokens_per_timestep, -1 + ) + + pred_4d = self.video_predictor(input_states_4d.float(), act_4d.float()) + predicted_states = pred_4d.reshape(b, -1, d_emb) + + wm_loss = F.l1_loss(predicted_states, gt_states.float(), reduction="mean") if not has_action: return {"wm_loss": wm_loss} - # ---- Step 4: Action Head (same as original) ---- + # ---- Step 4: Action Head ---- with torch.autocast(device_type=device_type, dtype=torch.float32): actions_tensor = torch.tensor( np.array(actions), device=last_hidden.device, dtype=torch.float32 @@ -224,8 +239,14 @@ class VLAJEPAModel(nn.Module): np.array(state), device=last_hidden.device, dtype=torch.float32 ) # [B, 1, state_dim] - # Cast embodied tokens to float32 for action model compatibility - action_loss = self.action_model(embodied_action_tokens.float(), actions_target, state_tensor) + # repeated_diffusion_steps: draw R independent noise samples per batch item (CogACT-style). + # Effectively multiplies data efficiency of the action head by R with no extra Qwen/JEPA cost. + R = self.config.repeated_diffusion_steps + embodied_rep = embodied_action_tokens.float().repeat(R, 1, 1) + actions_rep = actions_target.repeat(R, 1, 1) + state_rep = state_tensor.repeat(R, 1, 1) if state_tensor is not None else None + + action_loss = self.action_model(embodied_rep, actions_rep, state_rep) return {"action_loss": action_loss, "wm_loss": wm_loss * self.config.world_model_loss_weight} @@ -350,8 +371,9 @@ class VLAJEPAPolicy(PreTrainedPolicy): for key in image_keys: tensor = batch[key] # [B, C, H, W] or [B, T, C, H, W] if tensor.ndim == 5: - # Multi-frame: take the last frame as the "current" image - tensor = tensor[:, -1] + # observation_delta_indices = [-(n_obs_steps-1), ..., 0, ..., n_future_frames] + # frame at index n_obs_steps-1 is t=0 (current observation) + tensor = tensor[:, self.config.n_obs_steps - 1] for b in range(batch_size): images_per_sample[b].append(self.model.qwen.tensor_to_pil(tensor[b]))