mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-01 19:31:25 +00:00
modifications to gym_manipulator and buffer
This commit is contained in:
@@ -250,28 +250,18 @@ def act_with_policy(
|
||||
logging.info("[ACTOR] Shutting down act_with_policy")
|
||||
return
|
||||
|
||||
if interaction_step >= cfg.policy.online_step_before_learning:
|
||||
# Time policy inference and check if it meets FPS requirement
|
||||
with TimerManager(
|
||||
elapsed_time_list=list_policy_time,
|
||||
label="Policy inference time",
|
||||
log=False,
|
||||
) as timer: # noqa: F841
|
||||
action = policy.select_action(batch=obs)
|
||||
policy_fps = 1.0 / (list_policy_time[-1] + 1e-9)
|
||||
# Time policy inference and check if it meets FPS requirement
|
||||
with TimerManager(
|
||||
elapsed_time_list=list_policy_time,
|
||||
label="Policy inference time",
|
||||
log=False,
|
||||
) as timer: # noqa: F841
|
||||
action = policy.select_action(batch=obs)
|
||||
policy_fps = 1.0 / (list_policy_time[-1] + 1e-9)
|
||||
|
||||
log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
|
||||
log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
|
||||
|
||||
next_obs, reward, done, truncated, info = online_env.step(action.squeeze(dim=0).cpu().numpy())
|
||||
else:
|
||||
# TODO (azouitine): Make a custom space for torch tensor
|
||||
action = online_env.action_space.sample()
|
||||
next_obs, reward, done, truncated, info = online_env.step(action)
|
||||
|
||||
# HACK: We have only one env but we want to batch it, it will be resolved with the torch box
|
||||
action = (
|
||||
torch.from_numpy(action[0]).to(device, non_blocking=device.type == "cuda").unsqueeze(dim=0)
|
||||
)
|
||||
next_obs, reward, done, truncated, info = online_env.step(action.squeeze(dim=0).cpu().numpy())
|
||||
|
||||
sum_reward_episode += float(reward)
|
||||
# Increment total steps counter for intervention rate
|
||||
|
||||
Reference in New Issue
Block a user