mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-02 20:01:25 +00:00
final refactor/fix
This commit is contained in:
@@ -186,7 +186,6 @@ def train(cfg: TrainPipelineConfig):
|
||||
dl_iter = cycle(dataloader)
|
||||
|
||||
policy.train()
|
||||
|
||||
train_metrics = {
|
||||
"loss": AverageMeter("loss", ":.3f"),
|
||||
"grad_norm": AverageMeter("grdn", ":.3f"),
|
||||
@@ -263,15 +262,14 @@ def train(cfg: TrainPipelineConfig):
|
||||
max_parallel_tasks=cfg.env.max_parallel_tasks,
|
||||
)
|
||||
aggregated = eval_info["overall"]["aggregated"]
|
||||
# Print per-suite stats
|
||||
# Print per-suite stats, log?
|
||||
for task_group, task_group_info in eval_info.items():
|
||||
if task_group == "overall":
|
||||
continue # Skip the overall stats since we already printed it
|
||||
print(f"\nAggregated Metrics for {task_group}:")
|
||||
print(task_group_info["aggregated"])
|
||||
breakpoint()
|
||||
breakpoint()
|
||||
else:
|
||||
print("START EVAL")
|
||||
eval_info = eval_policy(
|
||||
eval_env,
|
||||
policy,
|
||||
@@ -280,9 +278,8 @@ def train(cfg: TrainPipelineConfig):
|
||||
max_episodes_rendered=4,
|
||||
start_seed=cfg.seed,
|
||||
)
|
||||
breakpoint()
|
||||
aggregated = eval_info["aggregated"]
|
||||
print("END EVAL")
|
||||
breakpoint()
|
||||
|
||||
eval_metrics = {
|
||||
"avg_sum_reward": AverageMeter("∑rwrd", ":.3f"),
|
||||
|
||||
Reference in New Issue
Block a user