mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-03 20:31:25 +00:00
Fixes @torch.no_grad() usage (#1455)
* fix: decorator calls with parentheses * fix no grad for normalize too Signed-off-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com> --------- Signed-off-by: Francesco Capuano <74058581+fracapuano@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
aec1b29d23
commit
a5e0aae13a
@@ -99,7 +99,7 @@ class DiffusionPolicy(PreTrainedPolicy):
|
||||
if self.config.env_state_feature:
|
||||
self._queues["observation.environment_state"] = deque(maxlen=self.config.n_obs_steps)
|
||||
|
||||
@torch.no_grad
|
||||
@torch.no_grad()
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Predict a chunk of actions given environment observations."""
|
||||
# stack n latest observations from the queue
|
||||
@@ -111,7 +111,7 @@ class DiffusionPolicy(PreTrainedPolicy):
|
||||
|
||||
return actions
|
||||
|
||||
@torch.no_grad
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Select a single action given environment observations.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user