[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot]
2025-03-24 13:41:27 +00:00
committed by AdilZouitine
parent 2945bbb221
commit 7c05755823
123 changed files with 1161 additions and 3425 deletions

View File

@@ -14,18 +14,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from statistics import mean, quantiles
import time
from functools import lru_cache
from lerobot.scripts.server.utils import setup_process_handlers
from queue import Empty
from statistics import mean, quantiles
# from lerobot.scripts.eval import eval_policy
import grpc
import hydra
import torch
from omegaconf import DictConfig
from torch import nn
import time
from torch.multiprocessing import Event, Queue
# TODO: Remove the import of maniskill
# from lerobot.common.envs.factory import make_maniskill_env
@@ -34,34 +34,28 @@ from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.sac.modeling_sac import SACPolicy
from lerobot.common.robot_devices.robots.factory import make_robot
from lerobot.common.robot_devices.robots.utils import Robot
from lerobot.common.robot_devices.utils import busy_wait
from lerobot.common.utils.utils import (
TimerManager,
get_safe_torch_device,
init_logging,
set_global_seed,
)
from lerobot.scripts.server import hilserl_pb2, hilserl_pb2_grpc
from lerobot.scripts.server import hilserl_pb2, hilserl_pb2_grpc, learner_service
from lerobot.scripts.server.buffer import (
Transition,
bytes_to_state_dict,
move_state_dict_to_device,
move_transition_to_device,
python_object_to_bytes,
transitions_to_bytes,
bytes_to_state_dict,
)
from lerobot.scripts.server.gym_manipulator import get_classifier, make_robot_env
from lerobot.scripts.server.network_utils import (
receive_bytes_in_chunks,
send_bytes_in_chunks,
)
from lerobot.scripts.server.gym_manipulator import get_classifier, make_robot_env
from lerobot.scripts.server import learner_service
from lerobot.common.robot_devices.utils import busy_wait
from torch.multiprocessing import Queue, Event
from queue import Empty
from lerobot.common.utils.utils import init_logging
from lerobot.scripts.server.utils import get_last_item_from_queue
from lerobot.scripts.server.utils import get_last_item_from_queue, setup_process_handlers
ACTOR_SHUTDOWN_TIMEOUT = 30
@@ -102,9 +96,7 @@ def receive_policy(
logging.info("[ACTOR] Received policy loop stopped")
def transitions_stream(
shutdown_event: Event, transitions_queue: Queue
) -> hilserl_pb2.Empty:
def transitions_stream(shutdown_event: Event, transitions_queue: Queue) -> hilserl_pb2.Empty:
while not shutdown_event.is_set():
try:
message = transitions_queue.get(block=True, timeout=5)
@@ -169,9 +161,7 @@ def send_transitions(
)
try:
learner_client.SendTransitions(
transitions_stream(shutdown_event, transitions_queue)
)
learner_client.SendTransitions(transitions_stream(shutdown_event, transitions_queue))
except grpc.RpcError as e:
logging.error(f"[ACTOR] gRPC error: {e}")
@@ -211,9 +201,7 @@ def send_interactions(
)
try:
learner_client.SendInteractions(
interactions_stream(shutdown_event, interactions_queue)
)
learner_client.SendInteractions(interactions_stream(shutdown_event, interactions_queue))
except grpc.RpcError as e:
logging.error(f"[ACTOR] gRPC error: {e}")
@@ -301,9 +289,7 @@ def act_with_policy(
logging.info("make_env online")
online_env = make_robot_env(
robot=robot, reward_classifier=reward_classifier, cfg=cfg
)
online_env = make_robot_env(robot=robot, reward_classifier=reward_classifier, cfg=cfg)
set_global_seed(cfg.seed)
device = get_safe_torch_device(cfg.device, log=True)
@@ -355,13 +341,9 @@ def act_with_policy(
action = policy.select_action(batch=obs)
policy_fps = 1.0 / (list_policy_time[-1] + 1e-9)
log_policy_frequency_issue(
policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step
)
log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
next_obs, reward, done, truncated, info = online_env.step(
action.squeeze(dim=0).cpu().numpy()
)
next_obs, reward, done, truncated, info = online_env.step(action.squeeze(dim=0).cpu().numpy())
else:
# TODO (azouitine): Make a custom space for torch tensor
action = online_env.action_space.sample()
@@ -369,9 +351,7 @@ def act_with_policy(
# HACK: We have only one env but we want to batch it, it will be resolved with the torch box
action = (
torch.from_numpy(action[0])
.to(device, non_blocking=device.type == "cuda")
.unsqueeze(dim=0)
torch.from_numpy(action[0]).to(device, non_blocking=device.type == "cuda").unsqueeze(dim=0)
)
sum_reward_episode += float(reward)
@@ -391,9 +371,7 @@ def act_with_policy(
# Check for NaN values in observations
for key, tensor in obs.items():
if torch.isnan(tensor).any():
logging.error(
f"[ACTOR] NaN values found in obs[{key}] at step {interaction_step}"
)
logging.error(f"[ACTOR] NaN values found in obs[{key}] at step {interaction_step}")
list_transition_to_send_to_learner.append(
Transition(
@@ -413,13 +391,9 @@ def act_with_policy(
# Because we are using a single environment we can index at zero
if done or truncated:
# TODO: Handle logging for episode information
logging.info(
f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}"
)
logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}")
update_policy_parameters(
policy=policy.actor, parameters_queue=parameters_queue, device=device
)
update_policy_parameters(policy=policy.actor, parameters_queue=parameters_queue, device=device)
if len(list_transition_to_send_to_learner) > 0:
push_transitions_to_transport_queue(
@@ -495,9 +469,7 @@ def get_frequency_stats(list_policy_time: list[float]) -> dict[str, float]:
return stats
def log_policy_frequency_issue(
policy_fps: float, cfg: DictConfig, interaction_step: int
):
def log_policy_frequency_issue(policy_fps: float, cfg: DictConfig, interaction_step: int):
if policy_fps < cfg.fps:
logging.warning(
f"[ACTOR] Policy FPS {policy_fps:.1f} below required {cfg.fps} at step {interaction_step}"