mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-03 12:21:27 +00:00
final refactor/fix
This commit is contained in:
@@ -62,6 +62,7 @@ import einops
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
import torch
|
||||
from termcolor import colored
|
||||
from torch import Tensor, nn
|
||||
from tqdm import trange
|
||||
|
||||
@@ -73,6 +74,7 @@ from lerobot.policies.factory import make_policy
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.utils import get_device_from_parameters
|
||||
from lerobot.utils.io_utils import write_video
|
||||
from lerobot.utils.random_utils import set_seed
|
||||
from lerobot.utils.utils import (
|
||||
get_safe_torch_device,
|
||||
init_logging,
|
||||
@@ -146,8 +148,7 @@ def rollout(
|
||||
check_env_attributes_and_types(env)
|
||||
while not np.all(done) and step < max_steps:
|
||||
# Numpy array to tensor and changing dictionary keys to LeRobot policy format.
|
||||
# observation = preprocess_observation(observation)
|
||||
observation = preprocess_observation(observation, cfg=policy.config)
|
||||
observation = preprocess_observation(observation)
|
||||
if return_observations:
|
||||
all_observations.append(deepcopy(observation))
|
||||
|
||||
@@ -459,24 +460,8 @@ def _compile_episode_data(
|
||||
|
||||
return data_dict
|
||||
|
||||
|
||||
def set_global_seed(seed):
    """Set seed for reproducibility.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch.
    When CUDA is available, every CUDA device is seeded as well.

    Args:
        seed: Value passed unchanged to each library's seeding function.
    """
    # Kept function-local, as in the original, so the block does not depend
    # on a module-level ``import random``.
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Only touch the CUDA generators when a CUDA runtime is actually present.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
|
||||
|
||||
|
||||
def log_output_dir(out_dir):
    """Log the output directory at INFO level via the root logger.

    Args:
        out_dir: Directory to report; it is rendered with ``%s`` formatting.
    """
    # Lazy %-style arguments: the message is only formatted if the record
    # is actually emitted.
    logging.info("Output dir: %s", out_dir)
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def eval(cfg: EvalPipelineConfig):
|
||||
def eval_main(cfg: EvalPipelineConfig):
|
||||
logging.info(pformat(asdict(cfg)))
|
||||
|
||||
# Check device is available
|
||||
@@ -484,9 +469,9 @@ def eval(cfg: EvalPipelineConfig):
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
set_global_seed(cfg.seed)
|
||||
set_seed(cfg.seed)
|
||||
|
||||
log_output_dir(cfg.output_dir)
|
||||
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
|
||||
|
||||
logging.info("Making environment.")
|
||||
env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)
|
||||
@@ -494,11 +479,9 @@ def eval(cfg: EvalPipelineConfig):
|
||||
logging.info("Making policy.")
|
||||
policy = make_policy(
|
||||
cfg=cfg.policy,
|
||||
# device=device,
|
||||
env_cfg=cfg.env,
|
||||
)
|
||||
policy.eval()
|
||||
|
||||
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
|
||||
if cfg.env.multitask_eval:
|
||||
info = eval_policy_multitask(
|
||||
@@ -663,4 +646,4 @@ def eval_policy_multitask(
|
||||
|
||||
if __name__ == "__main__":
|
||||
init_logging()
|
||||
eval()
|
||||
eval_main()
|
||||
|
||||
@@ -186,7 +186,6 @@ def train(cfg: TrainPipelineConfig):
|
||||
dl_iter = cycle(dataloader)
|
||||
|
||||
policy.train()
|
||||
|
||||
train_metrics = {
|
||||
"loss": AverageMeter("loss", ":.3f"),
|
||||
"grad_norm": AverageMeter("grdn", ":.3f"),
|
||||
@@ -263,15 +262,14 @@ def train(cfg: TrainPipelineConfig):
|
||||
max_parallel_tasks=cfg.env.max_parallel_tasks,
|
||||
)
|
||||
aggregated = eval_info["overall"]["aggregated"]
|
||||
# Print per-suite stats
|
||||
# Print per-suite stats, log?
|
||||
for task_group, task_group_info in eval_info.items():
|
||||
if task_group == "overall":
|
||||
continue # Skip the overall stats since we already printed it
|
||||
print(f"\nAggregated Metrics for {task_group}:")
|
||||
print(task_group_info["aggregated"])
|
||||
breakpoint()
|
||||
breakpoint()
|
||||
else:
|
||||
print("START EVAL")
|
||||
eval_info = eval_policy(
|
||||
eval_env,
|
||||
policy,
|
||||
@@ -280,9 +278,8 @@ def train(cfg: TrainPipelineConfig):
|
||||
max_episodes_rendered=4,
|
||||
start_seed=cfg.seed,
|
||||
)
|
||||
breakpoint()
|
||||
aggregated = eval_info["aggregated"]
|
||||
print("END EVAL")
|
||||
breakpoint()
|
||||
|
||||
eval_metrics = {
|
||||
"avg_sum_reward": AverageMeter("∑rwrd", ":.3f"),
|
||||
|
||||
Reference in New Issue
Block a user