mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-01 11:21:27 +00:00
some more fixes to be closer to the original implem
This commit is contained in:
@@ -53,7 +53,7 @@ class VLAJEPAConfig(PreTrainedConfig):
|
||||
action_noise_s: float = 0.999
|
||||
|
||||
# total video frames loaded per sample
|
||||
num_video_frames: int = 4
|
||||
num_video_frames: int = 16
|
||||
predictor_depth: int = 6
|
||||
predictor_num_heads: int = 8
|
||||
predictor_mlp_ratio: float = 4.0
|
||||
|
||||
@@ -188,27 +188,24 @@ class VLAJEPAModel(nn.Module):
|
||||
video_embeddings = torch.cat(torch.chunk(video_embeddings, chunks=v, dim=0), dim=2)
|
||||
|
||||
tubelet_size = self.video_encoder.config.tubelet_size
|
||||
t_enc_ctx = self.config.n_obs_steps // tubelet_size
|
||||
t_enc_fut = self.config.n_future_frames // tubelet_size
|
||||
device_wm = video_embeddings.device
|
||||
# num_video_frames raw frames → t_enc_total temporal positions after tubelet compression
|
||||
t_enc_total = self.config.num_video_frames // tubelet_size
|
||||
|
||||
if t_enc_ctx < 1 or t_enc_fut < 1:
|
||||
# not enough frames for one tubelet -> skip world model loss
|
||||
if t_enc_total < 2:
|
||||
wm_loss = torch.tensor(0.0, device=device_wm)
|
||||
else:
|
||||
t_enc_total = t_enc_ctx + t_enc_fut
|
||||
# Shift-by-one JEPA split (matches original VLA_JEPA.py lines 231-232):
|
||||
# input_states: positions 0..T-2, gt_states: positions 1..T-1
|
||||
t_enc_ctx = t_enc_total - 1
|
||||
tokens_per_frame = video_embeddings.shape[1] // t_enc_total
|
||||
|
||||
# context: encoded frames [t-(n_obs-1)..t+0]
|
||||
# future: encoded frames [t+1..t+n_future]
|
||||
input_states = video_embeddings[:, : tokens_per_frame * t_enc_ctx, :]
|
||||
gt_states = video_embeddings[:, tokens_per_frame * t_enc_ctx :, :]
|
||||
gt_states = video_embeddings[:, tokens_per_frame:, :]
|
||||
d_emb = input_states.shape[-1]
|
||||
|
||||
# [B, t_enc_ctx*tokens, D] → [B, t_enc_ctx, tokens, D]
|
||||
input_states_4d = input_states.view(b, t_enc_ctx, tokens_per_frame, d_emb)
|
||||
|
||||
# Action tokens conditioning: one group per context step
|
||||
expected_actions = t_enc_ctx * self.config.num_action_tokens_per_timestep
|
||||
if action_tokens.shape[1] < expected_actions:
|
||||
pad = action_tokens[:, -1:].repeat(1, expected_actions - action_tokens.shape[1], 1)
|
||||
@@ -241,10 +238,10 @@ class VLAJEPAModel(nn.Module):
|
||||
|
||||
# repeated_diffusion_steps: draw R independent noise samples per batch item (CogACT-style).
|
||||
# Effectively multiplies data efficiency of the action head by R with no extra Qwen/JEPA cost.
|
||||
R = self.config.repeated_diffusion_steps
|
||||
embodied_rep = embodied_action_tokens.float().repeat(R, 1, 1)
|
||||
actions_rep = actions_target.repeat(R, 1, 1)
|
||||
state_rep = state_tensor.repeat(R, 1, 1) if state_tensor is not None else None
|
||||
num_repeated = self.config.repeated_diffusion_steps
|
||||
embodied_rep = embodied_action_tokens.float().repeat(num_repeated, 1, 1)
|
||||
actions_rep = actions_target.repeat(num_repeated, 1, 1)
|
||||
state_rep = state_tensor.repeat(num_repeated, 1, 1) if state_tensor is not None else None
|
||||
|
||||
action_loss = self.action_model(embodied_rep, actions_rep, state_rep)
|
||||
|
||||
@@ -371,9 +368,9 @@ class VLAJEPAPolicy(PreTrainedPolicy):
|
||||
for key in image_keys:
|
||||
tensor = batch[key] # [B, C, H, W] or [B, T, C, H, W]
|
||||
if tensor.ndim == 5:
|
||||
# observation_delta_indices = [-(n_obs_steps-1), ..., 0, ..., n_future_frames]
|
||||
# frame at index n_obs_steps-1 is t=0 (current observation)
|
||||
tensor = tensor[:, self.config.n_obs_steps - 1]
|
||||
# observation_delta_indices = [0, 1, ..., num_video_frames-1]
|
||||
# index 0 is the current observation (delta=0)
|
||||
tensor = tensor[:, 0]
|
||||
for b in range(batch_size):
|
||||
images_per_sample[b].append(self.model.qwen.tensor_to_pil(tensor[b]))
|
||||
|
||||
@@ -417,16 +414,16 @@ class VLAJEPAPolicy(PreTrainedPolicy):
|
||||
|
||||
# ---- Collect actions (training only) ----
|
||||
actions_list = None
|
||||
if ACTION in batch:
|
||||
actions_tensor = batch[ACTION] # [B, chunk_size, action_dim]
|
||||
actions_tensor = batch.get(ACTION)
|
||||
if actions_tensor is not None:
|
||||
if actions_tensor.ndim == 2:
|
||||
actions_tensor = actions_tensor.unsqueeze(1)
|
||||
actions_list = [actions_tensor[b].detach().cpu().float().numpy() for b in range(batch_size)]
|
||||
|
||||
# ---- Collect state ----
|
||||
state_list = None
|
||||
if OBS_STATE in batch:
|
||||
state_tensor = batch[OBS_STATE] # [B, state_dim] or [B, T, state_dim]
|
||||
state_tensor = batch.get(OBS_STATE)
|
||||
if state_tensor is not None:
|
||||
if state_tensor.ndim > 2:
|
||||
state_tensor = state_tensor[:, -1, :]
|
||||
if state_tensor.ndim == 2:
|
||||
|
||||
Reference in New Issue
Block a user