fix(policies): remove action from batch for offline evaluation (#1609)

* fix(policies): remove action from batch for offline evaluation in diffusion, tdmpc, and vqbet policies

* style(diffusion): correct comment capitalization for clarity in modeling_diffusion.py
This commit is contained in:
Adil Zouitine
2025-07-28 13:10:34 +02:00
committed by GitHub
parent 664e069c3f
commit c3d5e494c0
3 changed files with 15 additions and 3 deletions

View File

@@ -133,11 +133,15 @@ class DiffusionPolicy(PreTrainedPolicy):
"horizon" may not the best name to describe what the variable actually means, because this period is
actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past.
"""
# NOTE: for offline evaluation, we have action in the batch, so we need to pop it out
if ACTION in batch:
batch.pop(ACTION)
batch = self.normalize_inputs(batch)
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
# Note: It's important that this happens after stacking the images into a single key.
# NOTE: It's important that this happens after stacking the images into a single key.
self._queues = populate_queues(self._queues, batch)
if len(self._queues[ACTION]) == 0: