[WIP] Update SAC configuration and environment settings

- Reduced frame rate in `ManiskillEnvConfig` from 400 to 200.
- Enhanced `SACConfig` with new dataclasses for actor, learner, and network configurations.
- Improved input and output feature management in `SACConfig`.
- Refactored `actor_server` and `learner_server` to access configuration properties directly.
- Updated training pipeline to validate configurations and handle dataset repo IDs more robustly.
This commit is contained in:
AdilZouitine
2025-03-27 08:13:20 +00:00
parent 626e5dd35c
commit 052a4acfc2
7 changed files with 183 additions and 126 deletions

View File

@@ -73,8 +73,8 @@ def receive_policy(
if grpc_channel is None or learner_client is None:
learner_client, grpc_channel = learner_service_client(
host=cfg.policy.actor_learner_config["learner_host"],
port=cfg.policy.actor_learner_config["learner_port"],
host=cfg.policy.actor_learner_config.learner_host,
port=cfg.policy.actor_learner_config.learner_port,
)
try:
@@ -85,6 +85,7 @@ def receive_policy(
shutdown_event,
log_prefix="[ACTOR] parameters",
)
except grpc.RpcError as e:
logging.error(f"[ACTOR] gRPC error: {e}")
@@ -153,8 +154,8 @@ def send_transitions(
if grpc_channel is None or learner_client is None:
learner_client, grpc_channel = learner_service_client(
host=cfg.policy.actor_learner_config["learner_host"],
port=cfg.policy.actor_learner_config["learner_port"],
host=cfg.policy.actor_learner_config.learner_host,
port=cfg.policy.actor_learner_config.learner_port,
)
try:
@@ -193,8 +194,8 @@ def send_interactions(
if grpc_channel is None or learner_client is None:
learner_client, grpc_channel = learner_service_client(
host=cfg.policy.actor_learner_config["learner_host"],
port=cfg.policy.actor_learner_config["learner_port"],
host=cfg.policy.actor_learner_config.learner_host,
port=cfg.policy.actor_learner_config.learner_port,
)
try:
@@ -286,10 +287,10 @@ def act_with_policy(
logging.info("make_env online")
online_env = make_robot_env(robot=robot, reward_classifier=reward_classifier, cfg=cfg)
online_env = make_robot_env( cfg=cfg.env)
set_seed(cfg.seed)
device = get_safe_torch_device(cfg.device, log=True)
device = get_safe_torch_device(cfg.policy.device, log=True)
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
@@ -302,11 +303,7 @@ def act_with_policy(
# TODO: At some point we should just need make sac policy
policy: SACPolicy = make_policy(
cfg=cfg.policy,
# dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
# Hack: But if we do online training, we do not need dataset_stats
dataset_stats=None,
# TODO: Handle resume training
device=device,
env_cfg=cfg.env,
)
policy = torch.compile(policy)
assert isinstance(policy, nn.Module)
@@ -322,13 +319,13 @@ def act_with_policy(
episode_intervention_steps = 0
episode_total_steps = 0
for interaction_step in range(cfg.training.online_steps):
for interaction_step in range(cfg.policy.online_steps):
start_time = time.perf_counter()
if shutdown_event.is_set():
logging.info("[ACTOR] Shutting down act_with_policy")
return
if interaction_step >= cfg.training.online_step_before_learning:
if interaction_step >= cfg.policy.online_step_before_learning:
# Time policy inference and check if it meets FPS requirement
with TimerManager(
elapsed_time_list=list_policy_time,
@@ -426,9 +423,9 @@ def act_with_policy(
episode_total_steps = 0
obs, info = online_env.reset()
if cfg.fps is not None:
if cfg.env.fps is not None:
dt_time = time.perf_counter() - start_time
busy_wait(1 / cfg.fps - dt_time)
busy_wait(1 / cfg.env.fps - dt_time)
def push_transitions_to_transport_queue(transitions: list, transitions_queue):
@@ -467,9 +464,9 @@ def get_frequency_stats(list_policy_time: list[float]) -> dict[str, float]:
def log_policy_frequency_issue(policy_fps: float, cfg: TrainPipelineConfig, interaction_step: int):
if policy_fps < cfg.fps:
if policy_fps < cfg.env.fps:
logging.warning(
f"[ACTOR] Policy FPS {policy_fps:.1f} below required {cfg.fps} at step {interaction_step}"
f"[ACTOR] Policy FPS {policy_fps:.1f} below required {cfg.env.fps} at step {interaction_step}"
)
@@ -495,7 +492,7 @@ def establish_learner_connection(
def use_threads(cfg: TrainPipelineConfig) -> bool:
return cfg.policy.concurrency["actor"] == "threads"
return cfg.policy.concurrency.actor == "threads"
@parser.wrap()
@@ -511,8 +508,8 @@ def actor_cli(cfg: TrainPipelineConfig):
shutdown_event = setup_process_handlers(use_threads(cfg))
learner_client, grpc_channel = learner_service_client(
host=cfg.policy.actor_learner_config["learner_host"],
port=cfg.policy.actor_learner_config["learner_port"],
host=cfg.policy.actor_learner_config.learner_host,
port=cfg.policy.actor_learner_config.learner_port,
)
logging.info("[ACTOR] Establishing connection with Learner")