some more fixes to be closer to the original implem

This commit is contained in:
Maximellerbach
2026-05-13 15:51:55 +02:00
parent be9147b131
commit af36b42e90
2 changed files with 19 additions and 22 deletions

View File

@@ -53,7 +53,7 @@ class VLAJEPAConfig(PreTrainedConfig):
action_noise_s: float = 0.999
# total video frames loaded per sample
num_video_frames: int = 4
num_video_frames: int = 16
predictor_depth: int = 6
predictor_num_heads: int = 8
predictor_mlp_ratio: float = 4.0

View File

@@ -188,27 +188,24 @@ class VLAJEPAModel(nn.Module):
video_embeddings = torch.cat(torch.chunk(video_embeddings, chunks=v, dim=0), dim=2)
tubelet_size = self.video_encoder.config.tubelet_size
t_enc_ctx = self.config.n_obs_steps // tubelet_size
t_enc_fut = self.config.n_future_frames // tubelet_size
device_wm = video_embeddings.device
# num_video_frames raw frames → t_enc_total temporal positions after tubelet compression
t_enc_total = self.config.num_video_frames // tubelet_size
if t_enc_ctx < 1 or t_enc_fut < 1:
# not enough frames for one tubelet -> skip world model loss
if t_enc_total < 2:
wm_loss = torch.tensor(0.0, device=device_wm)
else:
t_enc_total = t_enc_ctx + t_enc_fut
# Shift-by-one JEPA split (matches original VLA_JEPA.py lines 231-232):
# input_states: positions 0..T-2, gt_states: positions 1..T-1
t_enc_ctx = t_enc_total - 1
tokens_per_frame = video_embeddings.shape[1] // t_enc_total
# context: encoded frames [t-(n_obs-1)..t+0]
# future: encoded frames [t+1..t+n_future]
input_states = video_embeddings[:, : tokens_per_frame * t_enc_ctx, :]
gt_states = video_embeddings[:, tokens_per_frame * t_enc_ctx :, :]
gt_states = video_embeddings[:, tokens_per_frame:, :]
d_emb = input_states.shape[-1]
# [B, t_enc_ctx*tokens, D] → [B, t_enc_ctx, tokens, D]
input_states_4d = input_states.view(b, t_enc_ctx, tokens_per_frame, d_emb)
# Action tokens conditioning: one group per context step
expected_actions = t_enc_ctx * self.config.num_action_tokens_per_timestep
if action_tokens.shape[1] < expected_actions:
pad = action_tokens[:, -1:].repeat(1, expected_actions - action_tokens.shape[1], 1)
@@ -241,10 +238,10 @@ class VLAJEPAModel(nn.Module):
# repeated_diffusion_steps: draw R independent noise samples per batch item (CogACT-style).
# Effectively multiplies data efficiency of the action head by R with no extra Qwen/JEPA cost.
R = self.config.repeated_diffusion_steps
embodied_rep = embodied_action_tokens.float().repeat(R, 1, 1)
actions_rep = actions_target.repeat(R, 1, 1)
state_rep = state_tensor.repeat(R, 1, 1) if state_tensor is not None else None
num_repeated = self.config.repeated_diffusion_steps
embodied_rep = embodied_action_tokens.float().repeat(num_repeated, 1, 1)
actions_rep = actions_target.repeat(num_repeated, 1, 1)
state_rep = state_tensor.repeat(num_repeated, 1, 1) if state_tensor is not None else None
action_loss = self.action_model(embodied_rep, actions_rep, state_rep)
@@ -371,9 +368,9 @@ class VLAJEPAPolicy(PreTrainedPolicy):
for key in image_keys:
tensor = batch[key] # [B, C, H, W] or [B, T, C, H, W]
if tensor.ndim == 5:
# observation_delta_indices = [-(n_obs_steps-1), ..., 0, ..., n_future_frames]
# frame at index n_obs_steps-1 is t=0 (current observation)
tensor = tensor[:, self.config.n_obs_steps - 1]
# observation_delta_indices = [0, 1, ..., num_video_frames-1]
# index 0 is the current observation (delta=0)
tensor = tensor[:, 0]
for b in range(batch_size):
images_per_sample[b].append(self.model.qwen.tensor_to_pil(tensor[b]))
@@ -417,16 +414,16 @@ class VLAJEPAPolicy(PreTrainedPolicy):
# ---- Collect actions (training only) ----
actions_list = None
if ACTION in batch:
actions_tensor = batch[ACTION] # [B, chunk_size, action_dim]
actions_tensor = batch.get(ACTION)
if actions_tensor is not None:
if actions_tensor.ndim == 2:
actions_tensor = actions_tensor.unsqueeze(1)
actions_list = [actions_tensor[b].detach().cpu().float().numpy() for b in range(batch_size)]
# ---- Collect state ----
state_list = None
if OBS_STATE in batch:
state_tensor = batch[OBS_STATE] # [B, state_dim] or [B, T, state_dim]
state_tensor = batch.get(OBS_STATE)
if state_tensor is not None:
if state_tensor.ndim > 2:
state_tensor = state_tensor[:, -1, :]
if state_tensor.ndim == 2: