mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-01 03:11:29 +00:00
[WIP] Update SAC configuration and environment settings
- Reduced frame rate in `ManiskillEnvConfig` from 400 to 200. - Enhanced `SACConfig` with new dataclasses for actor, learner, and network configurations. - Improved input and output feature management in `SACConfig`. - Refactored `actor_server` and `learner_server` to access configuration properties directly. - Updated training pipeline to validate configurations and handle dataset repo IDs more robustly.
This commit is contained in:
@@ -73,8 +73,8 @@ def receive_policy(
|
||||
|
||||
if grpc_channel is None or learner_client is None:
|
||||
learner_client, grpc_channel = learner_service_client(
|
||||
host=cfg.policy.actor_learner_config["learner_host"],
|
||||
port=cfg.policy.actor_learner_config["learner_port"],
|
||||
host=cfg.policy.actor_learner_config.learner_host,
|
||||
port=cfg.policy.actor_learner_config.learner_port,
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -85,6 +85,7 @@ def receive_policy(
|
||||
shutdown_event,
|
||||
log_prefix="[ACTOR] parameters",
|
||||
)
|
||||
|
||||
except grpc.RpcError as e:
|
||||
logging.error(f"[ACTOR] gRPC error: {e}")
|
||||
|
||||
@@ -153,8 +154,8 @@ def send_transitions(
|
||||
|
||||
if grpc_channel is None or learner_client is None:
|
||||
learner_client, grpc_channel = learner_service_client(
|
||||
host=cfg.policy.actor_learner_config["learner_host"],
|
||||
port=cfg.policy.actor_learner_config["learner_port"],
|
||||
host=cfg.policy.actor_learner_config.learner_host,
|
||||
port=cfg.policy.actor_learner_config.learner_port,
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -193,8 +194,8 @@ def send_interactions(
|
||||
|
||||
if grpc_channel is None or learner_client is None:
|
||||
learner_client, grpc_channel = learner_service_client(
|
||||
host=cfg.policy.actor_learner_config["learner_host"],
|
||||
port=cfg.policy.actor_learner_config["learner_port"],
|
||||
host=cfg.policy.actor_learner_config.learner_host,
|
||||
port=cfg.policy.actor_learner_config.learner_port,
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -286,10 +287,10 @@ def act_with_policy(
|
||||
|
||||
logging.info("make_env online")
|
||||
|
||||
online_env = make_robot_env(robot=robot, reward_classifier=reward_classifier, cfg=cfg)
|
||||
online_env = make_robot_env( cfg=cfg.env)
|
||||
|
||||
set_seed(cfg.seed)
|
||||
device = get_safe_torch_device(cfg.device, log=True)
|
||||
device = get_safe_torch_device(cfg.policy.device, log=True)
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
@@ -302,11 +303,7 @@ def act_with_policy(
|
||||
# TODO: At some point we should just need make sac policy
|
||||
policy: SACPolicy = make_policy(
|
||||
cfg=cfg.policy,
|
||||
# dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
|
||||
# Hack: But if we do online training, we do not need dataset_stats
|
||||
dataset_stats=None,
|
||||
# TODO: Handle resume training
|
||||
device=device,
|
||||
env_cfg=cfg.env,
|
||||
)
|
||||
policy = torch.compile(policy)
|
||||
assert isinstance(policy, nn.Module)
|
||||
@@ -322,13 +319,13 @@ def act_with_policy(
|
||||
episode_intervention_steps = 0
|
||||
episode_total_steps = 0
|
||||
|
||||
for interaction_step in range(cfg.training.online_steps):
|
||||
for interaction_step in range(cfg.policy.online_steps):
|
||||
start_time = time.perf_counter()
|
||||
if shutdown_event.is_set():
|
||||
logging.info("[ACTOR] Shutting down act_with_policy")
|
||||
return
|
||||
|
||||
if interaction_step >= cfg.training.online_step_before_learning:
|
||||
if interaction_step >= cfg.policy.online_step_before_learning:
|
||||
# Time policy inference and check if it meets FPS requirement
|
||||
with TimerManager(
|
||||
elapsed_time_list=list_policy_time,
|
||||
@@ -426,9 +423,9 @@ def act_with_policy(
|
||||
episode_total_steps = 0
|
||||
obs, info = online_env.reset()
|
||||
|
||||
if cfg.fps is not None:
|
||||
if cfg.env.fps is not None:
|
||||
dt_time = time.perf_counter() - start_time
|
||||
busy_wait(1 / cfg.fps - dt_time)
|
||||
busy_wait(1 / cfg.env.fps - dt_time)
|
||||
|
||||
|
||||
def push_transitions_to_transport_queue(transitions: list, transitions_queue):
|
||||
@@ -467,9 +464,9 @@ def get_frequency_stats(list_policy_time: list[float]) -> dict[str, float]:
|
||||
|
||||
|
||||
def log_policy_frequency_issue(policy_fps: float, cfg: TrainPipelineConfig, interaction_step: int):
|
||||
if policy_fps < cfg.fps:
|
||||
if policy_fps < cfg.env.fps:
|
||||
logging.warning(
|
||||
f"[ACTOR] Policy FPS {policy_fps:.1f} below required {cfg.fps} at step {interaction_step}"
|
||||
f"[ACTOR] Policy FPS {policy_fps:.1f} below required {cfg.env.fps} at step {interaction_step}"
|
||||
)
|
||||
|
||||
|
||||
@@ -495,7 +492,7 @@ def establish_learner_connection(
|
||||
|
||||
|
||||
def use_threads(cfg: TrainPipelineConfig) -> bool:
|
||||
return cfg.policy.concurrency["actor"] == "threads"
|
||||
return cfg.policy.concurrency.actor == "threads"
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
@@ -511,8 +508,8 @@ def actor_cli(cfg: TrainPipelineConfig):
|
||||
shutdown_event = setup_process_handlers(use_threads(cfg))
|
||||
|
||||
learner_client, grpc_channel = learner_service_client(
|
||||
host=cfg.policy.actor_learner_config["learner_host"],
|
||||
port=cfg.policy.actor_learner_config["learner_port"],
|
||||
host=cfg.policy.actor_learner_config.learner_host,
|
||||
port=cfg.policy.actor_learner_config.learner_port,
|
||||
)
|
||||
|
||||
logging.info("[ACTOR] Establishing connection with Learner")
|
||||
|
||||
Reference in New Issue
Block a user