mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-04 12:51:27 +00:00
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
committed by
AdilZouitine
parent
2945bbb221
commit
7c05755823
@@ -14,18 +14,18 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from statistics import mean, quantiles
|
||||
import time
|
||||
from functools import lru_cache
|
||||
from lerobot.scripts.server.utils import setup_process_handlers
|
||||
from queue import Empty
|
||||
from statistics import mean, quantiles
|
||||
|
||||
# from lerobot.scripts.eval import eval_policy
|
||||
|
||||
import grpc
|
||||
import hydra
|
||||
import torch
|
||||
from omegaconf import DictConfig
|
||||
from torch import nn
|
||||
import time
|
||||
from torch.multiprocessing import Event, Queue
|
||||
|
||||
# TODO: Remove the import of maniskill
|
||||
# from lerobot.common.envs.factory import make_maniskill_env
|
||||
@@ -34,34 +34,28 @@ from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.common.robot_devices.robots.factory import make_robot
|
||||
from lerobot.common.robot_devices.robots.utils import Robot
|
||||
from lerobot.common.robot_devices.utils import busy_wait
|
||||
from lerobot.common.utils.utils import (
|
||||
TimerManager,
|
||||
get_safe_torch_device,
|
||||
init_logging,
|
||||
set_global_seed,
|
||||
)
|
||||
from lerobot.scripts.server import hilserl_pb2, hilserl_pb2_grpc
|
||||
from lerobot.scripts.server import hilserl_pb2, hilserl_pb2_grpc, learner_service
|
||||
from lerobot.scripts.server.buffer import (
|
||||
Transition,
|
||||
bytes_to_state_dict,
|
||||
move_state_dict_to_device,
|
||||
move_transition_to_device,
|
||||
python_object_to_bytes,
|
||||
transitions_to_bytes,
|
||||
bytes_to_state_dict,
|
||||
)
|
||||
from lerobot.scripts.server.gym_manipulator import get_classifier, make_robot_env
|
||||
from lerobot.scripts.server.network_utils import (
|
||||
receive_bytes_in_chunks,
|
||||
send_bytes_in_chunks,
|
||||
)
|
||||
from lerobot.scripts.server.gym_manipulator import get_classifier, make_robot_env
|
||||
from lerobot.scripts.server import learner_service
|
||||
from lerobot.common.robot_devices.utils import busy_wait
|
||||
|
||||
from torch.multiprocessing import Queue, Event
|
||||
from queue import Empty
|
||||
|
||||
from lerobot.common.utils.utils import init_logging
|
||||
|
||||
from lerobot.scripts.server.utils import get_last_item_from_queue
|
||||
from lerobot.scripts.server.utils import get_last_item_from_queue, setup_process_handlers
|
||||
|
||||
ACTOR_SHUTDOWN_TIMEOUT = 30
|
||||
|
||||
@@ -102,9 +96,7 @@ def receive_policy(
|
||||
logging.info("[ACTOR] Received policy loop stopped")
|
||||
|
||||
|
||||
def transitions_stream(
|
||||
shutdown_event: Event, transitions_queue: Queue
|
||||
) -> hilserl_pb2.Empty:
|
||||
def transitions_stream(shutdown_event: Event, transitions_queue: Queue) -> hilserl_pb2.Empty:
|
||||
while not shutdown_event.is_set():
|
||||
try:
|
||||
message = transitions_queue.get(block=True, timeout=5)
|
||||
@@ -169,9 +161,7 @@ def send_transitions(
|
||||
)
|
||||
|
||||
try:
|
||||
learner_client.SendTransitions(
|
||||
transitions_stream(shutdown_event, transitions_queue)
|
||||
)
|
||||
learner_client.SendTransitions(transitions_stream(shutdown_event, transitions_queue))
|
||||
except grpc.RpcError as e:
|
||||
logging.error(f"[ACTOR] gRPC error: {e}")
|
||||
|
||||
@@ -211,9 +201,7 @@ def send_interactions(
|
||||
)
|
||||
|
||||
try:
|
||||
learner_client.SendInteractions(
|
||||
interactions_stream(shutdown_event, interactions_queue)
|
||||
)
|
||||
learner_client.SendInteractions(interactions_stream(shutdown_event, interactions_queue))
|
||||
except grpc.RpcError as e:
|
||||
logging.error(f"[ACTOR] gRPC error: {e}")
|
||||
|
||||
@@ -301,9 +289,7 @@ def act_with_policy(
|
||||
|
||||
logging.info("make_env online")
|
||||
|
||||
online_env = make_robot_env(
|
||||
robot=robot, reward_classifier=reward_classifier, cfg=cfg
|
||||
)
|
||||
online_env = make_robot_env(robot=robot, reward_classifier=reward_classifier, cfg=cfg)
|
||||
|
||||
set_global_seed(cfg.seed)
|
||||
device = get_safe_torch_device(cfg.device, log=True)
|
||||
@@ -355,13 +341,9 @@ def act_with_policy(
|
||||
action = policy.select_action(batch=obs)
|
||||
policy_fps = 1.0 / (list_policy_time[-1] + 1e-9)
|
||||
|
||||
log_policy_frequency_issue(
|
||||
policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step
|
||||
)
|
||||
log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
|
||||
|
||||
next_obs, reward, done, truncated, info = online_env.step(
|
||||
action.squeeze(dim=0).cpu().numpy()
|
||||
)
|
||||
next_obs, reward, done, truncated, info = online_env.step(action.squeeze(dim=0).cpu().numpy())
|
||||
else:
|
||||
# TODO (azouitine): Make a custom space for torch tensor
|
||||
action = online_env.action_space.sample()
|
||||
@@ -369,9 +351,7 @@ def act_with_policy(
|
||||
|
||||
# HACK: We have only one env but we want to batch it, it will be resolved with the torch box
|
||||
action = (
|
||||
torch.from_numpy(action[0])
|
||||
.to(device, non_blocking=device.type == "cuda")
|
||||
.unsqueeze(dim=0)
|
||||
torch.from_numpy(action[0]).to(device, non_blocking=device.type == "cuda").unsqueeze(dim=0)
|
||||
)
|
||||
|
||||
sum_reward_episode += float(reward)
|
||||
@@ -391,9 +371,7 @@ def act_with_policy(
|
||||
# Check for NaN values in observations
|
||||
for key, tensor in obs.items():
|
||||
if torch.isnan(tensor).any():
|
||||
logging.error(
|
||||
f"[ACTOR] NaN values found in obs[{key}] at step {interaction_step}"
|
||||
)
|
||||
logging.error(f"[ACTOR] NaN values found in obs[{key}] at step {interaction_step}")
|
||||
|
||||
list_transition_to_send_to_learner.append(
|
||||
Transition(
|
||||
@@ -413,13 +391,9 @@ def act_with_policy(
|
||||
# Because we are using a single environment we can index at zero
|
||||
if done or truncated:
|
||||
# TODO: Handle logging for episode information
|
||||
logging.info(
|
||||
f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}"
|
||||
)
|
||||
logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}")
|
||||
|
||||
update_policy_parameters(
|
||||
policy=policy.actor, parameters_queue=parameters_queue, device=device
|
||||
)
|
||||
update_policy_parameters(policy=policy.actor, parameters_queue=parameters_queue, device=device)
|
||||
|
||||
if len(list_transition_to_send_to_learner) > 0:
|
||||
push_transitions_to_transport_queue(
|
||||
@@ -495,9 +469,7 @@ def get_frequency_stats(list_policy_time: list[float]) -> dict[str, float]:
|
||||
return stats
|
||||
|
||||
|
||||
def log_policy_frequency_issue(
|
||||
policy_fps: float, cfg: DictConfig, interaction_step: int
|
||||
):
|
||||
def log_policy_frequency_issue(policy_fps: float, cfg: DictConfig, interaction_step: int):
|
||||
if policy_fps < cfg.fps:
|
||||
logging.warning(
|
||||
f"[ACTOR] Policy FPS {policy_fps:.1f} below required {cfg.fps} at step {interaction_step}"
|
||||
|
||||
Reference in New Issue
Block a user