adjusting obs steps, tublets size to match original implementation

This commit is contained in:
Maxime Ellerbach
2026-05-13 12:38:42 +00:00
committed by Maximellerbach
parent ea535ad98d
commit 596c72bfc6
2 changed files with 108 additions and 78 deletions

View File

@@ -52,6 +52,7 @@ class VLAJEPAConfig(PreTrainedConfig):
action_noise_beta_beta: float = 1.0
action_noise_s: float = 0.999
# total video frames loaded per sample
num_video_frames: int = 4
predictor_depth: int = 6
predictor_num_heads: int = 8
@@ -59,6 +60,8 @@ class VLAJEPAConfig(PreTrainedConfig):
predictor_dropout: float = 0.0
world_model_loss_weight: float = 0.1
enable_world_model: bool = True
jepa_tubelet_size: int = 2 # must match the encoder (e.g. 2 for vjepa2-vitl-fpc64-256)
repeated_diffusion_steps: int = 4 # independent noise draws per batch item (CogACT-style)
resize_images_to: tuple[int, int] | None = None
torch_dtype: str = "bfloat16"
@@ -78,8 +81,11 @@ class VLAJEPAConfig(PreTrainedConfig):
raise ValueError("`n_action_steps` must be <= `chunk_size`.")
if self.future_action_window_size + 1 > self.chunk_size:
raise ValueError("`chunk_size` must cover the predicted action horizon.")
if self.num_video_frames < 2:
raise ValueError("`num_video_frames` must be >= 2 for JEPA prediction.")
if self.num_video_frames < 2 * self.jepa_tubelet_size:
raise ValueError(
f"`video_horizon` ({self.num_video_frames}) must be >= 2 * `jepa_tubelet_size` "
f"({self.jepa_tubelet_size}) to have at least one context and one GT temporal position."
)
def validate_features(self) -> None:
if not self.image_features:
@@ -109,7 +115,9 @@ class VLAJEPAConfig(PreTrainedConfig):
@property
def observation_delta_indices(self) -> list[int]:
return [0]
# load video_horizon frames starting from current timestep: [t, t+1, ..., t+video_horizon-1]
# matches original repo's observation_indices=list(range(video_horizon))
return list(range(self.num_video_frames))
@property
def action_delta_indices(self) -> list[int]:

View File

@@ -65,25 +65,34 @@ class VLAJEPAModel(nn.Module):
self.action_model = VLAJEPAActionHead(config, cross_attention_dim=self.qwen.model.config.hidden_size)
# JEPA world model components
self.video_encoder = AutoModel.from_pretrained(
config.jepa_encoder_name,
torch_dtype=self.qwen._get_torch_dtype(config.torch_dtype),
)
self.video_processor = AutoVideoProcessor.from_pretrained(config.jepa_encoder_name)
self.video_predictor = ActionConditionedVideoPredictor(
embed_dim=self.video_encoder.config.hidden_size,
action_embed_dim=self.qwen.model.config.hidden_size,
predictor_embed_dim=self.video_encoder.config.hidden_size,
depth=config.predictor_depth,
num_heads=config.predictor_num_heads,
mlp_ratio=config.predictor_mlp_ratio,
num_action_tokens_per_step=config.num_action_tokens_per_timestep,
)
if config.enable_world_model:
self.video_encoder = AutoModel.from_pretrained(
config.jepa_encoder_name,
torch_dtype=self.qwen._get_torch_dtype(config.torch_dtype),
)
self.video_processor = AutoVideoProcessor.from_pretrained(config.jepa_encoder_name)
num_views = max(len(config.image_features), 1)
self.video_predictor = ActionConditionedVideoPredictor(
embed_dim=num_views * self.video_encoder.config.hidden_size,
action_embed_dim=self.qwen.model.config.hidden_size,
predictor_embed_dim=self.video_encoder.config.hidden_size,
depth=config.predictor_depth,
num_heads=config.predictor_num_heads,
mlp_ratio=config.predictor_mlp_ratio,
num_action_tokens_per_step=config.num_action_tokens_per_timestep,
)
else:
self.video_encoder = None
self.video_processor = None
self.video_predictor = None
# Build prompt placeholders (same as original)
# Build prompt placeholders.
# Original uses num_frames // tubelet_size - 1 action token groups for the world model predictor.
# This matches the number of context temporal positions after tubelet compression.
n_wm_action_groups = max(1, self.config.num_video_frames // self.config.jepa_tubelet_size - 1)
self.replace_prompt = "".join(
token * self.config.num_action_tokens_per_timestep
for token in self.action_tokens[: self.config.num_video_frames - 1]
for token in self.action_tokens[:n_wm_action_groups]
)
self.embodied_replace_prompt = (
self.config.embodied_action_token * self.config.num_embodied_action_tokens_per_instruction
@@ -127,16 +136,18 @@ class VLAJEPAModel(nn.Module):
embodied_prompt=self.embodied_replace_prompt,
)
# Locate action and embodied-action tokens in the tokenized sequence
action_mask = torch.isin(
qwen_inputs["input_ids"],
torch.tensor(self.action_token_ids, device=qwen_inputs["input_ids"].device),
)
action_indices = action_mask.nonzero(as_tuple=True)
# Locate embodied-action tokens (always needed for action head)
embodied_mask = qwen_inputs["input_ids"] == self.embodied_action_token_id
embodied_indices = embodied_mask.nonzero(as_tuple=True)
# Locate action tokens (only needed for world model predictor)
if self.config.enable_world_model:
action_mask = torch.isin(
qwen_inputs["input_ids"],
torch.tensor(self.action_token_ids, device=qwen_inputs["input_ids"].device),
)
action_indices = action_mask.nonzero(as_tuple=True)
device_type = next(self.parameters()).device.type
with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
@@ -149,68 +160,72 @@ class VLAJEPAModel(nn.Module):
last_hidden = qwen_outputs.hidden_states[-1] # [B, seq_len, H]
b, _, h = last_hidden.shape
action_tokens = last_hidden[action_indices[0], action_indices[1], :].view(b, -1, h)
if self.config.enable_world_model:
action_tokens = last_hidden[action_indices[0], action_indices[1], :].view(b, -1, h)
embodied_action_tokens = last_hidden[embodied_indices[0], embodied_indices[1], :].view(b, -1, h)
# ---- Step 2: JEPA Encoder (same as original) ----
b, v, t_frames, c, h_img, w_img = batch_videos.shape
batch_videos_flat = batch_videos.reshape(b * v, t_frames, c, h_img, w_img)
video_pixels = []
for i in range(b * v):
video_pixels.append(
self.video_processor(videos=batch_videos_flat[i], return_tensors="pt")[
"pixel_values_videos"
].to(self.video_encoder.device)
)
video_pixels = torch.cat(video_pixels, dim=0) # [B*V, T, C, H, W]
with torch.no_grad():
video_embeddings = self.video_encoder.get_vision_features(pixel_values_videos=video_pixels)
# Merge views: [B*V, ...] -> [B, ..., V*embed_dim]
video_embeddings = torch.cat(torch.chunk(video_embeddings, chunks=v, dim=0), dim=2)
# ---- Step 3: JEPA Predictor (same as original) ----
tubelet_size = self.video_encoder.config.tubelet_size
t_enc = t_frames // tubelet_size
device_wm = video_embeddings.device
if t_enc < 2:
# Not enough frames for JEPA prediction (need at least 2 encoded frames)
# ---- Step 2+3: JEPA Encoder + Predictor ----
device_wm = last_hidden.device
if not self.config.enable_world_model:
wm_loss = torch.tensor(0.0, device=device_wm)
else:
tokens_per_frame = video_embeddings.shape[1] // t_enc
b, v, t_frames, c, h_img, w_img = batch_videos.shape
batch_videos_flat = batch_videos.reshape(b * v, t_frames, c, h_img, w_img)
# input_states: frames 0..T-2 [B, (T-1)*tokens_per_frame, D]
# gt_states: frames 1..T-1 [B, (T-1)*tokens_per_frame, D]
input_states = video_embeddings[:, : tokens_per_frame * (t_enc - 1), :]
gt_states = video_embeddings[:, tokens_per_frame:, :]
d_emb = input_states.shape[-1]
video_pixels = []
for i in range(b * v):
video_pixels.append(
self.video_processor(videos=batch_videos_flat[i], return_tensors="pt")[
"pixel_values_videos"
].to(self.video_encoder.device)
)
video_pixels = torch.cat(video_pixels, dim=0) # [B*V, T, C, H, W]
# Reshape to 4D for ActionConditionedVideoPredictor:
# [B, (T-1)*tokens, D] → [B, T-1, tokens, D]
input_states_4d = input_states.view(b, t_enc - 1, tokens_per_frame, d_emb)
with torch.no_grad():
video_embeddings = self.video_encoder.get_vision_features(pixel_values_videos=video_pixels)
# Merge views: [B*V, ...] -> [B, ..., V*embed_dim]
video_embeddings = torch.cat(torch.chunk(video_embeddings, chunks=v, dim=0), dim=2)
# Reshape action tokens: [B, total_acts, D] → [B, T-1, per_step, D]
expected_actions = (t_enc - 1) * self.config.num_action_tokens_per_timestep
if action_tokens.shape[1] < expected_actions:
pad = action_tokens[:, -1:].repeat(1, expected_actions - action_tokens.shape[1], 1)
action_tokens = torch.cat([action_tokens, pad], dim=1)
act_4d = action_tokens[:, :expected_actions].view(
b, t_enc - 1, self.config.num_action_tokens_per_timestep, -1
)
tubelet_size = self.video_encoder.config.tubelet_size
t_enc_ctx = self.config.n_obs_steps // tubelet_size
t_enc_fut = self.config.n_future_frames // tubelet_size
device_wm = video_embeddings.device
# Cast to float32 for predictor (Linear layers are float32)
pred_4d = self.video_predictor(input_states_4d.float(), act_4d.float())
predicted_states = pred_4d.reshape(b, -1, d_emb)
if t_enc_ctx < 1 or t_enc_fut < 1:
# not enough frames for one tubelet -> skip world model loss
wm_loss = torch.tensor(0.0, device=device_wm)
else:
t_enc_total = t_enc_ctx + t_enc_fut
tokens_per_frame = video_embeddings.shape[1] // t_enc_total
wm_loss = F.l1_loss(predicted_states, gt_states.float(), reduction="mean")
# context: encoded frames [t-(n_obs-1)..t+0]
# future: encoded frames [t+1..t+n_future]
input_states = video_embeddings[:, : tokens_per_frame * t_enc_ctx, :]
gt_states = video_embeddings[:, tokens_per_frame * t_enc_ctx :, :]
d_emb = input_states.shape[-1]
# [B, t_enc_ctx*tokens, D] → [B, t_enc_ctx, tokens, D]
input_states_4d = input_states.view(b, t_enc_ctx, tokens_per_frame, d_emb)
# Action tokens conditioning: one group per context step
expected_actions = t_enc_ctx * self.config.num_action_tokens_per_timestep
if action_tokens.shape[1] < expected_actions:
pad = action_tokens[:, -1:].repeat(1, expected_actions - action_tokens.shape[1], 1)
action_tokens = torch.cat([action_tokens, pad], dim=1)
act_4d = action_tokens[:, :expected_actions].view(
b, t_enc_ctx, self.config.num_action_tokens_per_timestep, -1
)
pred_4d = self.video_predictor(input_states_4d.float(), act_4d.float())
predicted_states = pred_4d.reshape(b, -1, d_emb)
wm_loss = F.l1_loss(predicted_states, gt_states.float(), reduction="mean")
if not has_action:
return {"wm_loss": wm_loss}
# ---- Step 4: Action Head (same as original) ----
# ---- Step 4: Action Head ----
with torch.autocast(device_type=device_type, dtype=torch.float32):
actions_tensor = torch.tensor(
np.array(actions), device=last_hidden.device, dtype=torch.float32
@@ -224,8 +239,14 @@ class VLAJEPAModel(nn.Module):
np.array(state), device=last_hidden.device, dtype=torch.float32
) # [B, 1, state_dim]
# Cast embodied tokens to float32 for action model compatibility
action_loss = self.action_model(embodied_action_tokens.float(), actions_target, state_tensor)
# repeated_diffusion_steps: draw R independent noise samples per batch item (CogACT-style).
# Effectively multiplies data efficiency of the action head by R with no extra Qwen/JEPA cost.
R = self.config.repeated_diffusion_steps
embodied_rep = embodied_action_tokens.float().repeat(R, 1, 1)
actions_rep = actions_target.repeat(R, 1, 1)
state_rep = state_tensor.repeat(R, 1, 1) if state_tensor is not None else None
action_loss = self.action_model(embodied_rep, actions_rep, state_rep)
return {"action_loss": action_loss, "wm_loss": wm_loss * self.config.world_model_loss_weight}
@@ -350,8 +371,9 @@ class VLAJEPAPolicy(PreTrainedPolicy):
for key in image_keys:
tensor = batch[key] # [B, C, H, W] or [B, T, C, H, W]
if tensor.ndim == 5:
# Multi-frame: take the last frame as the "current" image
tensor = tensor[:, -1]
# observation_delta_indices = [-(n_obs_steps-1), ..., 0, ..., n_future_frames]
# frame at index n_obs_steps-1 is t=0 (current observation)
tensor = tensor[:, self.config.n_obs_steps - 1]
for b in range(batch_size):
images_per_sample[b].append(self.model.qwen.tensor_to_pil(tensor[b]))