chore: replace hard-coded action values with constants throughout all the source code (#2055)

* chore: replace hard-coded 'action' values with constants throughout all the source code

* chore(tests): replace hard-coded action values with constants throughout all the test code
This commit is contained in:
Steven Palma
2025-09-26 13:33:18 +02:00
committed by GitHub
parent 9627765ce2
commit d2782cf66b
47 changed files with 269 additions and 255 deletions

View File

@@ -82,7 +82,7 @@ class DiffusionPolicy(PreTrainedPolicy):
"""Clear observation and action queues. Should be called on `env.reset()`"""
self._queues = {
OBS_STATE: deque(maxlen=self.config.n_obs_steps),
- "action": deque(maxlen=self.config.n_action_steps),
+ ACTION: deque(maxlen=self.config.n_action_steps),
}
if self.config.image_features:
self._queues[OBS_IMAGES] = deque(maxlen=self.config.n_obs_steps)
@@ -306,10 +306,10 @@ class DiffusionModel(nn.Module):
}
"""
# Input validation.
- assert set(batch).issuperset({OBS_STATE, "action", "action_is_pad"})
+ assert set(batch).issuperset({OBS_STATE, ACTION, "action_is_pad"})
assert OBS_IMAGES in batch or OBS_ENV_STATE in batch
n_obs_steps = batch[OBS_STATE].shape[1]
- horizon = batch["action"].shape[1]
+ horizon = batch[ACTION].shape[1]
assert horizon == self.config.horizon
assert n_obs_steps == self.config.n_obs_steps
@@ -317,7 +317,7 @@ class DiffusionModel(nn.Module):
global_cond = self._prepare_global_conditioning(batch) # (B, global_cond_dim)
# Forward diffusion.
- trajectory = batch["action"]
+ trajectory = batch[ACTION]
# Sample noise to add to the trajectory.
eps = torch.randn(trajectory.shape, device=trajectory.device)
# Sample a random noising timestep for each item in the batch.
@@ -338,7 +338,7 @@ class DiffusionModel(nn.Module):
if self.config.prediction_type == "epsilon":
target = eps
elif self.config.prediction_type == "sample":
- target = batch["action"]
+ target = batch[ACTION]
else:
raise ValueError(f"Unsupported prediction type {self.config.prediction_type}")