final refactor/fix

This commit is contained in:
Jade Choghari (jchoghar)
2025-08-25 06:25:02 -04:00
parent afad90ffaa
commit 8d2c66abd2
7 changed files with 47 additions and 75 deletions

View File

@@ -62,6 +62,7 @@ import einops
import gymnasium as gym
import numpy as np
import torch
from termcolor import colored
from torch import Tensor, nn
from tqdm import trange
@@ -73,6 +74,7 @@ from lerobot.policies.factory import make_policy
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import get_device_from_parameters
from lerobot.utils.io_utils import write_video
from lerobot.utils.random_utils import set_seed
from lerobot.utils.utils import (
get_safe_torch_device,
init_logging,
@@ -146,8 +148,7 @@ def rollout(
check_env_attributes_and_types(env)
while not np.all(done) and step < max_steps:
# Numpy array to tensor and changing dictionary keys to LeRobot policy format.
# observation = preprocess_observation(observation)
observation = preprocess_observation(observation, cfg=policy.config)
observation = preprocess_observation(observation)
if return_observations:
all_observations.append(deepcopy(observation))
@@ -459,24 +460,8 @@ def _compile_episode_data(
return data_dict
def set_global_seed(seed):
"""Set seed for reproducibility."""
import random
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
def log_output_dir(out_dir):
logging.info("Output dir:" + f" {out_dir}")
@parser.wrap()
def eval(cfg: EvalPipelineConfig):
def eval_main(cfg: EvalPipelineConfig):
logging.info(pformat(asdict(cfg)))
# Check device is available
@@ -484,9 +469,9 @@ def eval(cfg: EvalPipelineConfig):
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
set_global_seed(cfg.seed)
set_seed(cfg.seed)
log_output_dir(cfg.output_dir)
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
logging.info("Making environment.")
env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)
@@ -494,11 +479,9 @@ def eval(cfg: EvalPipelineConfig):
logging.info("Making policy.")
policy = make_policy(
cfg=cfg.policy,
# device=device,
env_cfg=cfg.env,
)
policy.eval()
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
if cfg.env.multitask_eval:
info = eval_policy_multitask(
@@ -663,4 +646,4 @@ def eval_policy_multitask(
if __name__ == "__main__":
init_logging()
eval()
eval_main()

View File

@@ -186,7 +186,6 @@ def train(cfg: TrainPipelineConfig):
dl_iter = cycle(dataloader)
policy.train()
train_metrics = {
"loss": AverageMeter("loss", ":.3f"),
"grad_norm": AverageMeter("grdn", ":.3f"),
@@ -263,15 +262,14 @@ def train(cfg: TrainPipelineConfig):
max_parallel_tasks=cfg.env.max_parallel_tasks,
)
aggregated = eval_info["overall"]["aggregated"]
# Print per-suite stats
# Print per-suite stats, log?
for task_group, task_group_info in eval_info.items():
if task_group == "overall":
continue # Skip the overall stats since we already printed it
print(f"\nAggregated Metrics for {task_group}:")
print(task_group_info["aggregated"])
breakpoint()
breakpoint()
else:
print("START EVAL")
eval_info = eval_policy(
eval_env,
policy,
@@ -280,9 +278,8 @@ def train(cfg: TrainPipelineConfig):
max_episodes_rendered=4,
start_seed=cfg.seed,
)
breakpoint()
aggregated = eval_info["aggregated"]
print("END EVAL")
breakpoint()
eval_metrics = {
"avg_sum_reward": AverageMeter("∑rwrd", ":.3f"),