Fixes @torch.no_grad() usage (#1455)

* fix: decorator calls with parentheses

* fix no grad for normalize too

Signed-off-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com>

---------

Signed-off-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com>
This commit is contained in:
Francesco Capuano
2025-07-08 13:08:32 +02:00
committed by GitHub
parent aec1b29d23
commit a5e0aae13a
9 changed files with 16 additions and 15 deletions

View File

@@ -99,7 +99,7 @@ class DiffusionPolicy(PreTrainedPolicy):
if self.config.env_state_feature:
self._queues["observation.environment_state"] = deque(maxlen=self.config.n_obs_steps)
@torch.no_grad
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
"""Predict a chunk of actions given environment observations."""
# stack n latest observations from the queue
@@ -111,7 +111,7 @@ class DiffusionPolicy(PreTrainedPolicy):
return actions
@torch.no_grad
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a single action given environment observations.