mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-04 12:51:27 +00:00
adjusting obs steps, tublets size to match original implementation
This commit is contained in:
committed by
Maximellerbach
parent
ea535ad98d
commit
596c72bfc6
@@ -52,6 +52,7 @@ class VLAJEPAConfig(PreTrainedConfig):
|
||||
action_noise_beta_beta: float = 1.0
|
||||
action_noise_s: float = 0.999
|
||||
|
||||
# total video frames loaded per sample
|
||||
num_video_frames: int = 4
|
||||
predictor_depth: int = 6
|
||||
predictor_num_heads: int = 8
|
||||
@@ -59,6 +60,8 @@ class VLAJEPAConfig(PreTrainedConfig):
|
||||
predictor_dropout: float = 0.0
|
||||
world_model_loss_weight: float = 0.1
|
||||
enable_world_model: bool = True
|
||||
jepa_tubelet_size: int = 2 # must match the encoder (e.g. 2 for vjepa2-vitl-fpc64-256)
|
||||
repeated_diffusion_steps: int = 4 # independent noise draws per batch item (CogACT-style)
|
||||
|
||||
resize_images_to: tuple[int, int] | None = None
|
||||
torch_dtype: str = "bfloat16"
|
||||
@@ -78,8 +81,11 @@ class VLAJEPAConfig(PreTrainedConfig):
|
||||
raise ValueError("`n_action_steps` must be <= `chunk_size`.")
|
||||
if self.future_action_window_size + 1 > self.chunk_size:
|
||||
raise ValueError("`chunk_size` must cover the predicted action horizon.")
|
||||
if self.num_video_frames < 2:
|
||||
raise ValueError("`num_video_frames` must be >= 2 for JEPA prediction.")
|
||||
if self.num_video_frames < 2 * self.jepa_tubelet_size:
|
||||
raise ValueError(
|
||||
f"`video_horizon` ({self.num_video_frames}) must be >= 2 * `jepa_tubelet_size` "
|
||||
f"({self.jepa_tubelet_size}) to have at least one context and one GT temporal position."
|
||||
)
|
||||
|
||||
def validate_features(self) -> None:
|
||||
if not self.image_features:
|
||||
@@ -109,7 +115,9 @@ class VLAJEPAConfig(PreTrainedConfig):
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> list[int]:
|
||||
return [0]
|
||||
# load video_horizon frames starting from current timestep: [t, t+1, ..., t+video_horizon-1]
|
||||
# matches original repo's observation_indices=list(range(video_horizon))
|
||||
return list(range(self.num_video_frames))
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list[int]:
|
||||
|
||||
@@ -65,25 +65,34 @@ class VLAJEPAModel(nn.Module):
|
||||
self.action_model = VLAJEPAActionHead(config, cross_attention_dim=self.qwen.model.config.hidden_size)
|
||||
|
||||
# JEPA world model components
|
||||
self.video_encoder = AutoModel.from_pretrained(
|
||||
config.jepa_encoder_name,
|
||||
torch_dtype=self.qwen._get_torch_dtype(config.torch_dtype),
|
||||
)
|
||||
self.video_processor = AutoVideoProcessor.from_pretrained(config.jepa_encoder_name)
|
||||
self.video_predictor = ActionConditionedVideoPredictor(
|
||||
embed_dim=self.video_encoder.config.hidden_size,
|
||||
action_embed_dim=self.qwen.model.config.hidden_size,
|
||||
predictor_embed_dim=self.video_encoder.config.hidden_size,
|
||||
depth=config.predictor_depth,
|
||||
num_heads=config.predictor_num_heads,
|
||||
mlp_ratio=config.predictor_mlp_ratio,
|
||||
num_action_tokens_per_step=config.num_action_tokens_per_timestep,
|
||||
)
|
||||
if config.enable_world_model:
|
||||
self.video_encoder = AutoModel.from_pretrained(
|
||||
config.jepa_encoder_name,
|
||||
torch_dtype=self.qwen._get_torch_dtype(config.torch_dtype),
|
||||
)
|
||||
self.video_processor = AutoVideoProcessor.from_pretrained(config.jepa_encoder_name)
|
||||
num_views = max(len(config.image_features), 1)
|
||||
self.video_predictor = ActionConditionedVideoPredictor(
|
||||
embed_dim=num_views * self.video_encoder.config.hidden_size,
|
||||
action_embed_dim=self.qwen.model.config.hidden_size,
|
||||
predictor_embed_dim=self.video_encoder.config.hidden_size,
|
||||
depth=config.predictor_depth,
|
||||
num_heads=config.predictor_num_heads,
|
||||
mlp_ratio=config.predictor_mlp_ratio,
|
||||
num_action_tokens_per_step=config.num_action_tokens_per_timestep,
|
||||
)
|
||||
else:
|
||||
self.video_encoder = None
|
||||
self.video_processor = None
|
||||
self.video_predictor = None
|
||||
|
||||
# Build prompt placeholders (same as original)
|
||||
# Build prompt placeholders.
|
||||
# Original uses num_frames // tubelet_size - 1 action token groups for the world model predictor.
|
||||
# This matches the number of context temporal positions after tubelet compression.
|
||||
n_wm_action_groups = max(1, self.config.num_video_frames // self.config.jepa_tubelet_size - 1)
|
||||
self.replace_prompt = "".join(
|
||||
token * self.config.num_action_tokens_per_timestep
|
||||
for token in self.action_tokens[: self.config.num_video_frames - 1]
|
||||
for token in self.action_tokens[:n_wm_action_groups]
|
||||
)
|
||||
self.embodied_replace_prompt = (
|
||||
self.config.embodied_action_token * self.config.num_embodied_action_tokens_per_instruction
|
||||
@@ -127,16 +136,18 @@ class VLAJEPAModel(nn.Module):
|
||||
embodied_prompt=self.embodied_replace_prompt,
|
||||
)
|
||||
|
||||
# Locate action and embodied-action tokens in the tokenized sequence
|
||||
action_mask = torch.isin(
|
||||
qwen_inputs["input_ids"],
|
||||
torch.tensor(self.action_token_ids, device=qwen_inputs["input_ids"].device),
|
||||
)
|
||||
action_indices = action_mask.nonzero(as_tuple=True)
|
||||
|
||||
# Locate embodied-action tokens (always needed for action head)
|
||||
embodied_mask = qwen_inputs["input_ids"] == self.embodied_action_token_id
|
||||
embodied_indices = embodied_mask.nonzero(as_tuple=True)
|
||||
|
||||
# Locate action tokens (only needed for world model predictor)
|
||||
if self.config.enable_world_model:
|
||||
action_mask = torch.isin(
|
||||
qwen_inputs["input_ids"],
|
||||
torch.tensor(self.action_token_ids, device=qwen_inputs["input_ids"].device),
|
||||
)
|
||||
action_indices = action_mask.nonzero(as_tuple=True)
|
||||
|
||||
device_type = next(self.parameters()).device.type
|
||||
|
||||
with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
|
||||
@@ -149,68 +160,72 @@ class VLAJEPAModel(nn.Module):
|
||||
last_hidden = qwen_outputs.hidden_states[-1] # [B, seq_len, H]
|
||||
b, _, h = last_hidden.shape
|
||||
|
||||
action_tokens = last_hidden[action_indices[0], action_indices[1], :].view(b, -1, h)
|
||||
if self.config.enable_world_model:
|
||||
action_tokens = last_hidden[action_indices[0], action_indices[1], :].view(b, -1, h)
|
||||
|
||||
embodied_action_tokens = last_hidden[embodied_indices[0], embodied_indices[1], :].view(b, -1, h)
|
||||
|
||||
# ---- Step 2: JEPA Encoder (same as original) ----
|
||||
b, v, t_frames, c, h_img, w_img = batch_videos.shape
|
||||
batch_videos_flat = batch_videos.reshape(b * v, t_frames, c, h_img, w_img)
|
||||
|
||||
video_pixels = []
|
||||
for i in range(b * v):
|
||||
video_pixels.append(
|
||||
self.video_processor(videos=batch_videos_flat[i], return_tensors="pt")[
|
||||
"pixel_values_videos"
|
||||
].to(self.video_encoder.device)
|
||||
)
|
||||
video_pixels = torch.cat(video_pixels, dim=0) # [B*V, T, C, H, W]
|
||||
|
||||
with torch.no_grad():
|
||||
video_embeddings = self.video_encoder.get_vision_features(pixel_values_videos=video_pixels)
|
||||
# Merge views: [B*V, ...] -> [B, ..., V*embed_dim]
|
||||
video_embeddings = torch.cat(torch.chunk(video_embeddings, chunks=v, dim=0), dim=2)
|
||||
|
||||
# ---- Step 3: JEPA Predictor (same as original) ----
|
||||
tubelet_size = self.video_encoder.config.tubelet_size
|
||||
t_enc = t_frames // tubelet_size
|
||||
device_wm = video_embeddings.device
|
||||
|
||||
if t_enc < 2:
|
||||
# Not enough frames for JEPA prediction (need at least 2 encoded frames)
|
||||
# ---- Step 2+3: JEPA Encoder + Predictor ----
|
||||
device_wm = last_hidden.device
|
||||
if not self.config.enable_world_model:
|
||||
wm_loss = torch.tensor(0.0, device=device_wm)
|
||||
else:
|
||||
tokens_per_frame = video_embeddings.shape[1] // t_enc
|
||||
b, v, t_frames, c, h_img, w_img = batch_videos.shape
|
||||
batch_videos_flat = batch_videos.reshape(b * v, t_frames, c, h_img, w_img)
|
||||
|
||||
# input_states: frames 0..T-2 [B, (T-1)*tokens_per_frame, D]
|
||||
# gt_states: frames 1..T-1 [B, (T-1)*tokens_per_frame, D]
|
||||
input_states = video_embeddings[:, : tokens_per_frame * (t_enc - 1), :]
|
||||
gt_states = video_embeddings[:, tokens_per_frame:, :]
|
||||
d_emb = input_states.shape[-1]
|
||||
video_pixels = []
|
||||
for i in range(b * v):
|
||||
video_pixels.append(
|
||||
self.video_processor(videos=batch_videos_flat[i], return_tensors="pt")[
|
||||
"pixel_values_videos"
|
||||
].to(self.video_encoder.device)
|
||||
)
|
||||
video_pixels = torch.cat(video_pixels, dim=0) # [B*V, T, C, H, W]
|
||||
|
||||
# Reshape to 4D for ActionConditionedVideoPredictor:
|
||||
# [B, (T-1)*tokens, D] → [B, T-1, tokens, D]
|
||||
input_states_4d = input_states.view(b, t_enc - 1, tokens_per_frame, d_emb)
|
||||
with torch.no_grad():
|
||||
video_embeddings = self.video_encoder.get_vision_features(pixel_values_videos=video_pixels)
|
||||
# Merge views: [B*V, ...] -> [B, ..., V*embed_dim]
|
||||
video_embeddings = torch.cat(torch.chunk(video_embeddings, chunks=v, dim=0), dim=2)
|
||||
|
||||
# Reshape action tokens: [B, total_acts, D] → [B, T-1, per_step, D]
|
||||
expected_actions = (t_enc - 1) * self.config.num_action_tokens_per_timestep
|
||||
if action_tokens.shape[1] < expected_actions:
|
||||
pad = action_tokens[:, -1:].repeat(1, expected_actions - action_tokens.shape[1], 1)
|
||||
action_tokens = torch.cat([action_tokens, pad], dim=1)
|
||||
act_4d = action_tokens[:, :expected_actions].view(
|
||||
b, t_enc - 1, self.config.num_action_tokens_per_timestep, -1
|
||||
)
|
||||
tubelet_size = self.video_encoder.config.tubelet_size
|
||||
t_enc_ctx = self.config.n_obs_steps // tubelet_size
|
||||
t_enc_fut = self.config.n_future_frames // tubelet_size
|
||||
device_wm = video_embeddings.device
|
||||
|
||||
# Cast to float32 for predictor (Linear layers are float32)
|
||||
pred_4d = self.video_predictor(input_states_4d.float(), act_4d.float())
|
||||
predicted_states = pred_4d.reshape(b, -1, d_emb)
|
||||
if t_enc_ctx < 1 or t_enc_fut < 1:
|
||||
# not enough frames for one tubelet -> skip world model loss
|
||||
wm_loss = torch.tensor(0.0, device=device_wm)
|
||||
else:
|
||||
t_enc_total = t_enc_ctx + t_enc_fut
|
||||
tokens_per_frame = video_embeddings.shape[1] // t_enc_total
|
||||
|
||||
wm_loss = F.l1_loss(predicted_states, gt_states.float(), reduction="mean")
|
||||
# context: encoded frames [t-(n_obs-1)..t+0]
|
||||
# future: encoded frames [t+1..t+n_future]
|
||||
input_states = video_embeddings[:, : tokens_per_frame * t_enc_ctx, :]
|
||||
gt_states = video_embeddings[:, tokens_per_frame * t_enc_ctx :, :]
|
||||
d_emb = input_states.shape[-1]
|
||||
|
||||
# [B, t_enc_ctx*tokens, D] → [B, t_enc_ctx, tokens, D]
|
||||
input_states_4d = input_states.view(b, t_enc_ctx, tokens_per_frame, d_emb)
|
||||
|
||||
# Action tokens conditioning: one group per context step
|
||||
expected_actions = t_enc_ctx * self.config.num_action_tokens_per_timestep
|
||||
if action_tokens.shape[1] < expected_actions:
|
||||
pad = action_tokens[:, -1:].repeat(1, expected_actions - action_tokens.shape[1], 1)
|
||||
action_tokens = torch.cat([action_tokens, pad], dim=1)
|
||||
act_4d = action_tokens[:, :expected_actions].view(
|
||||
b, t_enc_ctx, self.config.num_action_tokens_per_timestep, -1
|
||||
)
|
||||
|
||||
pred_4d = self.video_predictor(input_states_4d.float(), act_4d.float())
|
||||
predicted_states = pred_4d.reshape(b, -1, d_emb)
|
||||
|
||||
wm_loss = F.l1_loss(predicted_states, gt_states.float(), reduction="mean")
|
||||
|
||||
if not has_action:
|
||||
return {"wm_loss": wm_loss}
|
||||
|
||||
# ---- Step 4: Action Head (same as original) ----
|
||||
# ---- Step 4: Action Head ----
|
||||
with torch.autocast(device_type=device_type, dtype=torch.float32):
|
||||
actions_tensor = torch.tensor(
|
||||
np.array(actions), device=last_hidden.device, dtype=torch.float32
|
||||
@@ -224,8 +239,14 @@ class VLAJEPAModel(nn.Module):
|
||||
np.array(state), device=last_hidden.device, dtype=torch.float32
|
||||
) # [B, 1, state_dim]
|
||||
|
||||
# Cast embodied tokens to float32 for action model compatibility
|
||||
action_loss = self.action_model(embodied_action_tokens.float(), actions_target, state_tensor)
|
||||
# repeated_diffusion_steps: draw R independent noise samples per batch item (CogACT-style).
|
||||
# Effectively multiplies data efficiency of the action head by R with no extra Qwen/JEPA cost.
|
||||
R = self.config.repeated_diffusion_steps
|
||||
embodied_rep = embodied_action_tokens.float().repeat(R, 1, 1)
|
||||
actions_rep = actions_target.repeat(R, 1, 1)
|
||||
state_rep = state_tensor.repeat(R, 1, 1) if state_tensor is not None else None
|
||||
|
||||
action_loss = self.action_model(embodied_rep, actions_rep, state_rep)
|
||||
|
||||
return {"action_loss": action_loss, "wm_loss": wm_loss * self.config.world_model_loss_weight}
|
||||
|
||||
@@ -350,8 +371,9 @@ class VLAJEPAPolicy(PreTrainedPolicy):
|
||||
for key in image_keys:
|
||||
tensor = batch[key] # [B, C, H, W] or [B, T, C, H, W]
|
||||
if tensor.ndim == 5:
|
||||
# Multi-frame: take the last frame as the "current" image
|
||||
tensor = tensor[:, -1]
|
||||
# observation_delta_indices = [-(n_obs_steps-1), ..., 0, ..., n_future_frames]
|
||||
# frame at index n_obs_steps-1 is t=0 (current observation)
|
||||
tensor = tensor[:, self.config.n_obs_steps - 1]
|
||||
for b in range(batch_size):
|
||||
images_per_sample[b].append(self.model.qwen.tensor_to_pil(tensor[b]))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user