From af36b42e90cd97c6ae0b007236f9dfa6fdc8762a Mon Sep 17 00:00:00 2001 From: Maximellerbach Date: Wed, 13 May 2026 15:51:55 +0200 Subject: [PATCH] some more fixes to be closer to the original implem --- .../vla_jepa/configuration_vla_jepa.py | 2 +- .../policies/vla_jepa/modeling_vla_jepa.py | 39 +++++++++---------- 2 files changed, 19 insertions(+), 22 deletions(-) diff --git a/src/lerobot/policies/vla_jepa/configuration_vla_jepa.py b/src/lerobot/policies/vla_jepa/configuration_vla_jepa.py index f803432c6..65070f62b 100644 --- a/src/lerobot/policies/vla_jepa/configuration_vla_jepa.py +++ b/src/lerobot/policies/vla_jepa/configuration_vla_jepa.py @@ -53,7 +53,7 @@ class VLAJEPAConfig(PreTrainedConfig): action_noise_s: float = 0.999 # total video frames loaded per sample - num_video_frames: int = 4 + num_video_frames: int = 16 predictor_depth: int = 6 predictor_num_heads: int = 8 predictor_mlp_ratio: float = 4.0 diff --git a/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py b/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py index ded318992..ecdaef978 100644 --- a/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py +++ b/src/lerobot/policies/vla_jepa/modeling_vla_jepa.py @@ -188,27 +188,24 @@ class VLAJEPAModel(nn.Module): video_embeddings = torch.cat(torch.chunk(video_embeddings, chunks=v, dim=0), dim=2) tubelet_size = self.video_encoder.config.tubelet_size - t_enc_ctx = self.config.n_obs_steps // tubelet_size - t_enc_fut = self.config.n_future_frames // tubelet_size device_wm = video_embeddings.device + # num_video_frames raw frames → t_enc_total temporal positions after tubelet compression + t_enc_total = self.config.num_video_frames // tubelet_size - if t_enc_ctx < 1 or t_enc_fut < 1: - # not enough frames for one tubelet -> skip world model loss + if t_enc_total < 2: wm_loss = torch.tensor(0.0, device=device_wm) else: - t_enc_total = t_enc_ctx + t_enc_fut + # Shift-by-one JEPA split (matches original VLA_JEPA.py lines 231-232): + # input_states: positions 0..T-2, gt_states: positions 1..T-1 + t_enc_ctx = t_enc_total - 1 tokens_per_frame = video_embeddings.shape[1] // t_enc_total - # context: encoded frames [t-(n_obs-1)..t+0] - # future: encoded frames [t+1..t+n_future] input_states = video_embeddings[:, : tokens_per_frame * t_enc_ctx, :] - gt_states = video_embeddings[:, tokens_per_frame * t_enc_ctx :, :] + gt_states = video_embeddings[:, tokens_per_frame:, :] d_emb = input_states.shape[-1] - # [B, t_enc_ctx*tokens, D] → [B, t_enc_ctx, tokens, D] input_states_4d = input_states.view(b, t_enc_ctx, tokens_per_frame, d_emb) - # Action tokens conditioning: one group per context step expected_actions = t_enc_ctx * self.config.num_action_tokens_per_timestep if action_tokens.shape[1] < expected_actions: pad = action_tokens[:, -1:].repeat(1, expected_actions - action_tokens.shape[1], 1) @@ -241,10 +238,10 @@ class VLAJEPAModel(nn.Module): # repeated_diffusion_steps: draw R independent noise samples per batch item (CogACT-style). # Effectively multiplies data efficiency of the action head by R with no extra Qwen/JEPA cost. - R = self.config.repeated_diffusion_steps - embodied_rep = embodied_action_tokens.float().repeat(R, 1, 1) - actions_rep = actions_target.repeat(R, 1, 1) - state_rep = state_tensor.repeat(R, 1, 1) if state_tensor is not None else None + num_repeated = self.config.repeated_diffusion_steps + embodied_rep = embodied_action_tokens.float().repeat(num_repeated, 1, 1) + actions_rep = actions_target.repeat(num_repeated, 1, 1) + state_rep = state_tensor.repeat(num_repeated, 1, 1) if state_tensor is not None else None action_loss = self.action_model(embodied_rep, actions_rep, state_rep) @@ -371,9 +368,9 @@ class VLAJEPAPolicy(PreTrainedPolicy): for key in image_keys: tensor = batch[key] # [B, C, H, W] or [B, T, C, H, W] if tensor.ndim == 5: - # observation_delta_indices = [-(n_obs_steps-1), ..., 0, ..., n_future_frames] - # frame at index n_obs_steps-1 is t=0 (current observation) - tensor = tensor[:, self.config.n_obs_steps - 1] + # observation_delta_indices = [0, 1, ..., num_video_frames-1] + # index 0 is the current observation (delta=0) + tensor = tensor[:, 0] for b in range(batch_size): images_per_sample[b].append(self.model.qwen.tensor_to_pil(tensor[b])) @@ -417,16 +414,16 @@ class VLAJEPAPolicy(PreTrainedPolicy): # ---- Collect actions (training only) ---- actions_list = None - if ACTION in batch: - actions_tensor = batch[ACTION] # [B, chunk_size, action_dim] + actions_tensor = batch.get(ACTION) + if actions_tensor is not None: if actions_tensor.ndim == 2: actions_tensor = actions_tensor.unsqueeze(1) actions_list = [actions_tensor[b].detach().cpu().float().numpy() for b in range(batch_size)] # ---- Collect state ---- state_list = None - if OBS_STATE in batch: - state_tensor = batch[OBS_STATE] # [B, state_dim] or [B, T, state_dim] + state_tensor = batch.get(OBS_STATE) + if state_tensor is not None: if state_tensor.ndim > 2: state_tensor = state_tensor[:, -1, :] if state_tensor.ndim == 2: