fix(rebase): remove residual normalization layer:

This commit is contained in:
Adil Zouitine
2025-07-28 13:16:37 +02:00
committed by Steven Palma
parent fbe9009db2
commit 1feb7b5d88
3 changed files with 0 additions and 5 deletions

View File

@@ -383,8 +383,6 @@ class SmolVLAPolicy(PreTrainedPolicy):
original_action_dim = self.config.action_feature.shape[0]
actions = actions[:, :, :original_action_dim]
actions = self.unnormalize_outputs({ACTION: actions})[ACTION]
if self.config.adapt_to_pi_aloha:
actions = self._pi_aloha_encode_actions(actions)

View File

@@ -138,8 +138,6 @@ class TDMPCPolicy(PreTrainedPolicy):
if ACTION in batch:
batch.pop(ACTION)
batch = self.normalize_inputs(batch)
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch[OBS_IMAGE] = batch[next(iter(self.config.image_features))]

View File

@@ -131,7 +131,6 @@ class VQBeTPolicy(PreTrainedPolicy):
# NOTE: for offline evaluation, we have action in the batch, so we need to pop it out
if ACTION in batch:
batch.pop(ACTION)
batch = self.normalize_inputs(batch)
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
# NOTE: It's important that this happens after stacking the images into a single key.
batch["observation.images"] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)