mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-31 19:01:28 +00:00
- Added JointMaskingActionSpace wrapper in gym_manipulator in order to select which joints will be controlled. For example, we can disable the gripper actions for some tasks.
- Added Nan detection mechanisms in the actor, learner and gym_manipulator for the case where we encounter nans in the loop. - changed the non-blocking in the `.to(device)` functions to only work for the case of cuda because they were causing nans when running the policy on mps - Added some joint clipping and limits in the env, robot and policy configs. TODO clean this part and make the limits in one config file only. Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
This commit is contained in:
@@ -101,10 +101,14 @@ class ActorServiceServicer(hilserl_pb2_grpc.ActorServiceServicer):
|
||||
message = message_queue.get(block=True)
|
||||
|
||||
if message.transition is not None:
|
||||
transition_to_send_to_learner = [
|
||||
move_transition_to_device(T, device="cpu") for T in message.transition
|
||||
transition_to_send_to_learner: list[Transition] = [
|
||||
move_transition_to_device(transition=T, device="cpu") for T in message.transition
|
||||
]
|
||||
|
||||
# Check for NaNs in transitions before sending to learner
|
||||
for transition in transition_to_send_to_learner:
|
||||
for key, value in transition["state"].items():
|
||||
if torch.isnan(value).any():
|
||||
logging.warning(f"Found NaN values in transition {key}")
|
||||
buf = io.BytesIO()
|
||||
torch.save(transition_to_send_to_learner, buf)
|
||||
transition_bytes = buf.getvalue()
|
||||
@@ -226,7 +230,7 @@ def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module)
|
||||
with TimerManager(
|
||||
elapsed_time_list=list_policy_time, label="Policy inference time", log=False
|
||||
) as timer: # noqa: F841
|
||||
action = policy.select_action(batch=obs) * 0.0
|
||||
action = policy.select_action(batch=obs)
|
||||
policy_fps = 1.0 / (list_policy_time[-1] + 1e-9)
|
||||
|
||||
log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
|
||||
@@ -238,7 +242,9 @@ def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module)
|
||||
next_obs, reward, done, truncated, info = online_env.step(action)
|
||||
|
||||
# HACK: We have only one env but we want to batch it, it will be resolved with the torch box
|
||||
action = torch.from_numpy(action[0]).to(device, non_blocking=True).unsqueeze(dim=0)
|
||||
action = (
|
||||
torch.from_numpy(action[0]).to(device, non_blocking=device.type == "cuda").unsqueeze(dim=0)
|
||||
)
|
||||
|
||||
sum_reward_episode += float(reward)
|
||||
|
||||
@@ -247,6 +253,11 @@ def act_with_policy(cfg: DictConfig, robot: Robot, reward_classifier: nn.Module)
|
||||
# TODO: Check the shape
|
||||
action = info["action_intervention"]
|
||||
|
||||
# Check for NaN values in observations
|
||||
for key, tensor in obs.items():
|
||||
if torch.isnan(tensor).any():
|
||||
logging.error(f"[ACTOR] NaN values found in obs[{key}] at step {interaction_step}")
|
||||
|
||||
list_transition_to_send_to_learner.append(
|
||||
Transition(
|
||||
state=obs,
|
||||
|
||||
Reference in New Issue
Block a user